tensor.py 8.01 KB
Newer Older
zhangqha's avatar
zhangqha committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import numpy as np
from deepmd.env import tf
from deepmd.common import ClassArg, add_data_requirement

from deepmd.env import global_cvt_2_tf_float
from deepmd.env import global_cvt_2_ener_float
from deepmd.utils.sess import run_sess
from .loss import Loss


class TensorLoss(Loss) :
    """
    Loss function for tensorial properties (e.g. dipole / polarizability).

    Combines two L2 terms, each optionally active:
      * a per-atom ("local") loss against atomic labels, weighted by `pref_atomic`;
      * a per-frame ("global") loss against frame-summed labels, weighted by `pref`.
    The graph tensors built in `build` are cached on the instance and fetched
    later by `eval` / `print_on_training`.
    """
    def __init__ (self, jdata, **kwarg) :
        """
        Parameters
        ----------
        jdata : dict
            Loss parameters. Must not be None (asserted below); recognized keys
            are 'scale', 'pref' (global weight) and 'pref_atomic' (local weight).
        **kwarg
            Required keys: 'tensor_name' (key into model_dict), 'tensor_size'
            (number of components per atom), 'label_name' (key into label_dict).
            Optional: 'model', used only to query the selected atom types.
        """
        model = kwarg.get('model', None)
        if model is not None:
            # presumably the list of atom types the fitting net is applied to;
            # used below to restrict label loading and atom counting — TODO confirm
            self.type_sel = model.get_sel_type()
        else:
            self.type_sel = None
        self.tensor_name = kwarg['tensor_name']
        self.tensor_size = kwarg['tensor_size']
        self.label_name = kwarg['label_name']
        # NOTE(review): this None-branch is dead code — jdata is asserted
        # non-None immediately below, so the else arm can never execute.
        if jdata is not None:
            self.scale = jdata.get('scale', 1.0)
        else:
            self.scale = 1.0

        # YHT: added for global / local dipole combination
        assert jdata is not None, "Please provide loss parameters!"
        # YWolfeee: modify, use pref / pref_atomic, instead of pref_weight / pref_atomic_weight
        self.local_weight = jdata.get('pref_atomic', None)
        self.global_weight = jdata.get('pref', None)

        # Both prefactors must be given, non-negative, and not both zero.
        assert (self.local_weight is not None and self.global_weight is not None), "Both `pref` and `pref_atomic` should be provided."
        assert self.local_weight >= 0.0 and self.global_weight >= 0.0, "Can not assign negative weight to `pref` and `pref_atomic`"
        assert (self.local_weight >0.0) or (self.global_weight>0.0), AssertionError('Can not assian zero weight both to `pref` and `pref_atomic`')

        # data required
        # Register both the atomic (per-atom) and the global (per-frame) label;
        # neither is strictly required in the data (must=False) — the find_*
        # flags read in `build` tell us which one is actually present.
        add_data_requirement("atomic_" + self.label_name, 
                             self.tensor_size, 
                             atomic=True,  
                             must=False, 
                             high_prec=False, 
                             type_sel = self.type_sel)
        add_data_requirement(self.label_name, 
                             self.tensor_size, 
                             atomic=False,  
                             must=False, 
                             high_prec=False, 
                             type_sel = self.type_sel)

    def build (self, 
               learning_rate,
               natoms,
               model_dict,
               label_dict,
               suffix):        
        """
        Build the loss graph.

        Parameters
        ----------
        learning_rate
            Unused here; kept for the common Loss.build interface.
        natoms : tensor
            deepmd natoms vector — natoms[0] is used as the total atom count
            and natoms[2+w] as the count of atoms of type w (TODO confirm
            against the deepmd natoms convention).
        model_dict, label_dict : dict
            Predicted tensors and data labels; indexed by tensor_name /
            label_name. label_dict also carries 'find_*' availability flags.
        suffix : str
            Appended to op names so multiple losses can coexist in one graph.

        Returns
        -------
        (l2_loss, more_loss) : the scalar total loss and a dict with the
        'local_loss' / 'global_loss' components (zero tensors when inactive).
        """
        polar_hat = label_dict[self.label_name]
        atomic_polar_hat = label_dict["atomic_" + self.label_name]
        # Flatten the prediction; its length is nframes * natoms_sel * tensor_size.
        polar = tf.reshape(model_dict[self.tensor_name], [-1])

        # 0/1-style flags: whether the (global / atomic) label was found in the
        # data; multiplying by them disables a term when its label is absent.
        find_global = label_dict['find_' + self.label_name]
        find_atomic = label_dict['find_atomic_' + self.label_name]
        
        

        # YHT: added for global / local dipole combination
        l2_loss = global_cvt_2_tf_float(0.0)
        more_loss = {
            "local_loss":global_cvt_2_tf_float(0.0),
            "global_loss":global_cvt_2_tf_float(0.0)
        }

        
        if self.local_weight > 0.0:
            # Mean squared per-component error against the atomic label.
            local_loss = global_cvt_2_tf_float(find_atomic) * tf.reduce_mean( tf.square(self.scale*(polar - atomic_polar_hat)), name='l2_'+suffix)
            more_loss['local_loss'] = local_loss
            l2_loss += self.local_weight * local_loss
            # Summary op is stored so print_on_training can merge it for TensorBoard.
            self.l2_loss_local_summary = tf.summary.scalar('l2_local_loss', 
                                            tf.sqrt(more_loss['local_loss']))
        

        if self.global_weight > 0.0:    # Need global loss
            # Count only the selected atom types (they are the ones carrying
            # the tensor); fall back to all atoms when no selection is set.
            atoms = 0
            if self.type_sel is not None:
                for w in self.type_sel:
                    atoms += natoms[2+w]
            else:
                atoms = natoms[0]     
            nframes = tf.shape(polar)[0] // self.tensor_size // atoms
            # get global results: sum the per-atom tensor over atoms in each frame
            global_polar = tf.reshape(tf.reduce_sum(tf.reshape(
                polar, [nframes, -1, self.tensor_size]), axis=1),[-1])
            #if self.atomic: # If label is local, however
            #    global_polar_hat = tf.reshape(tf.reduce_sum(tf.reshape(
            #        polar_hat, [nframes, -1, self.tensor_size]), axis=1),[-1])
            #else:
            #    global_polar_hat = polar_hat
            
            global_loss = global_cvt_2_tf_float(find_global) * tf.reduce_mean( tf.square(self.scale*(global_polar - polar_hat)), name='l2_'+suffix)

            more_loss['global_loss'] = global_loss
            # Summary reports the per-atom RMSE (sqrt(loss)/atoms), matching
            # the normalization used in eval/print below.
            self.l2_loss_global_summary = tf.summary.scalar('l2_global_loss', 
                                            tf.sqrt(more_loss['global_loss']) / global_cvt_2_tf_float(atoms))

            # YWolfeee: should only consider atoms with dipole, i.e. atoms
            # atom_norm  = 1./ global_cvt_2_tf_float(natoms[0])  
            atom_norm  = 1./ global_cvt_2_tf_float(atoms)  
            global_loss *= atom_norm   

            l2_loss += self.global_weight * global_loss
            
        # Cache graph tensors for later session fetches in eval/print_on_training.
        self.l2_more = more_loss
        self.l2_l = l2_loss

        self.l2_loss_summary = tf.summary.scalar('l2_loss', tf.sqrt(l2_loss))
        return l2_loss, more_loss

    def eval(self, sess, feed_dict, natoms):
        """
        Run the loss tensors in `sess` and return RMSE-style numbers:
        'rmse' (total), plus 'rmse_lc' / 'rmse_gl' for whichever components
        are active. The global RMSE is divided by the selected-atom count.
        """
        atoms = 0
        if self.type_sel is not None:
            for w in self.type_sel:
                atoms += natoms[2+w]
        else:
            atoms = natoms[0]

        run_data = [self.l2_l, self.l2_more['local_loss'], self.l2_more['global_loss']]
        error, error_lc, error_gl = run_sess(sess, run_data, feed_dict=feed_dict)

        results = {"natoms": atoms, "rmse": np.sqrt(error)}
        if self.local_weight > 0.0:
            results["rmse_lc"] = np.sqrt(error_lc)
        if self.global_weight > 0.0:
            results["rmse_gl"] = np.sqrt(error_gl) / atoms
        return results

    def print_header(self):  # deprecated
        """Return the column-header string matching print_on_training's output."""
        prop_fmt = '   %11s %11s'
        print_str = ''
        print_str += prop_fmt % ('rmse_tst', 'rmse_trn')
        if self.local_weight > 0.0:
            print_str += prop_fmt % ('rmse_lc_tst', 'rmse_lc_trn')
        if self.global_weight > 0.0:
            print_str += prop_fmt % ('rmse_gl_tst', 'rmse_gl_trn')
        return print_str

    def print_on_training(self, 
                          tb_writer,
                          cur_batch,
                          sess, 
                          natoms,
                          feed_dict_test,
                          feed_dict_batch) :  # deprecated
        """
        Evaluate the loss on one train batch and one test batch, optionally
        writing merged TensorBoard summaries, and return a formatted row of
        test/train RMSE values (columns as in print_header).
        """

        # YHT: added to calculate the atoms number
        atoms = 0
        if self.type_sel is not None:
            for w in self.type_sel:
                atoms += natoms[2+w]                   
        else:
            atoms = natoms[0]

        run_data = [self.l2_l, self.l2_more['local_loss'], self.l2_more['global_loss']]
        # Only merge summaries for terms that were actually built in `build`.
        summary_list = [self.l2_loss_summary]
        if self.local_weight > 0.0:
            summary_list.append(self.l2_loss_local_summary)
        if self.global_weight > 0.0:
            summary_list.append(self.l2_loss_global_summary)

        # first train data
        error_train = run_sess(sess, run_data, feed_dict=feed_dict_batch)

        # than test data, if tensorboard log writter is present, commpute summary
        # and write tensorboard logs
        if tb_writer:
            #summary_merged_op = tf.summary.merge([self.l2_loss_summary])
            summary_merged_op = tf.summary.merge(summary_list)
            # Prepend so the summary comes back first and can be popped off below.
            run_data.insert(0, summary_merged_op)

        test_out = run_sess(sess, run_data, feed_dict=feed_dict_test)

        if tb_writer:
            summary = test_out.pop(0)
            tb_writer.add_summary(summary, cur_batch)

        error_test = test_out  
        
        print_str = ""
        prop_fmt = "   %11.2e %11.2e"
        print_str += prop_fmt % (np.sqrt(error_test[0]), np.sqrt(error_train[0]))
        if self.local_weight > 0.0:
            print_str += prop_fmt % (np.sqrt(error_test[1]), np.sqrt(error_train[1]) )
        if self.global_weight > 0.0:
            # Per-atom normalization matches the 'rmse_gl' convention in eval().
            print_str += prop_fmt % (np.sqrt(error_test[2])/atoms, np.sqrt(error_train[2])/atoms)

        return print_str