deep_tensor.py 11.6 KB
Newer Older
zhangqha's avatar
zhangqha committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
import os
from typing import List, Optional, TYPE_CHECKING, Tuple

import numpy as np
from deepmd.common import make_default_mesh
from deepmd.env import default_tf_session_config, tf
from deepmd.infer.deep_eval import DeepEval
from deepmd.utils.sess import run_sess

if TYPE_CHECKING:
    from pathlib import Path

class DeepTensor(DeepEval):
    """Evaluates a tensor model.

    Parameters
    ----------
    model_file : Path
        The name of the frozen model file.
    load_prefix : str
        The prefix in the load computational graph
    default_tf_graph : bool
        If True, use the default tf graph; otherwise build a new tf graph for evaluation
    """

    # Mapping from attribute name to graph tensor name. __init__ iterates over
    # this mapping and hands every (tensor_name, attr_name) pair to
    # self._get_tensor. There is no "t_tensor" entry here although __init__
    # reads self.tensors["t_tensor"]; presumably subclasses add it -- confirm.
    tensors = {
        # descriptor attrs
        "t_ntypes": "descrpt_attr/ntypes:0",
        "t_rcut": "descrpt_attr/rcut:0",
        # model attrs
        "t_tmap": "model_attr/tmap:0",
        "t_sel_type": "model_attr/sel_type:0",
        # NOTE: the key is spelled "t_ouput_dim" (sic). _run_default_sess reads
        # self.t_ouput_dim, so the spelling must stay in sync with that method.
        "t_ouput_dim": "model_attr/output_dim:0",
        # inputs
        "t_coord": "t_coord:0",
        "t_type": "t_type:0",
        "t_natoms": "t_natoms:0",
        "t_box": "t_box:0",
        "t_mesh": "t_mesh:0",
    }

    def __init__(
        self,
        model_file: "Path",
        load_prefix: str = 'load',
        default_tf_graph: bool = False
    ) -> None:
        """Load a frozen tensor model and bind its graph tensors to attributes.

        Parameters
        ----------
        model_file : Path
            The frozen model file; forwarded to ``DeepEval.__init__``.
        load_prefix : str
            Forwarded to ``DeepEval.__init__``.
        default_tf_graph : bool
            Forwarded to ``DeepEval.__init__``.

        Raises
        ------
        KeyError
            If ``self.tensors`` has no ``"t_tensor"`` entry. The ``tensors``
            mapping defined on this class does not contain one, so it is
            presumably supplied by a subclass.
        AssertionError
            If ``self.model_type`` differs from the model type derived from
            the ``"t_tensor"`` tensor name.
        """
        DeepEval.__init__(
            self,
            model_file,
            load_prefix=load_prefix,
            default_tf_graph=default_tf_graph
        )
        # check model type: the expected type is the output tensor name with
        # its first two and last two characters stripped (presumably the
        # "o_" prefix and ":0" suffix of a name like "o_<model_type>:0")
        model_type = self.tensors["t_tensor"][2:-2]
        assert self.model_type == model_type, \
            f"expect {model_type} model but got {self.model_type}"

        # now load tensors to object attributes
        for attr_name, tensor_name in self.tensors.items():
            self._get_tensor(tensor_name, attr_name)

        # load optional tensors if possible
        optional_tensors = {
            "t_global_tensor": f"o_global_{model_type}:0",
            "t_force": "o_force:0",
            "t_virial": "o_virial:0",
            "t_atom_virial": "o_atom_virial:0"
        }
        try:
            # first make sure these tensors all exist, without binding any
            # attribute, so a model lacking some of them leaves self untouched
            for tensor_name in optional_tensors.values():
                self._get_tensor(tensor_name)
            # then put those into self.attrs
            for attr_name, tensor_name in optional_tensors.items():
                self._get_tensor(tensor_name, attr_name)
        except KeyError:
            self._support_gfv = False
        else:
            # Rebind a merged copy on the instance instead of calling
            # ``update`` in place: ``self.tensors`` may be a dict shared at
            # class level, and mutating it would turn these optional tensors
            # into required ones for every instance created afterwards.
            self.tensors = {**self.tensors, **optional_tensors}
            self._support_gfv = True

        self._run_default_sess()
        # tmap is fetched as bytes; decode it and split on whitespace
        self.tmap = self.tmap.decode('UTF-8').split()

    def _run_default_sess(self):
        """Fetch the model's attribute tensors in a single session run.

        The five fetched values are stored, in order, as ``ntypes``, ``rcut``,
        ``tmap``, ``tselt`` and ``output_dim``.
        """
        session = self.sess
        fetches = [
            self.t_ntypes,
            self.t_rcut,
            self.t_tmap,
            self.t_sel_type,
            self.t_ouput_dim,  # attribute name is misspelled on purpose (matches `tensors`)
        ]
        fetched = run_sess(session, fetches)
        (
            self.ntypes,
            self.rcut,
            self.tmap,
            self.tselt,
            self.output_dim,
        ) = fetched

    def get_ntypes(self) -> int:
        """Get the number of atom types of this model."""
        return self.ntypes

    def get_rcut(self) -> float:
        """Get the cut-off radius of this model."""
        return self.rcut

    def get_type_map(self) -> List[int]:
        """Get the type map (element name of the atom types) of this model."""
        return self.tmap

    def get_sel_type(self) -> List[int]:
        """Get the selected atom types of this model."""        
        return self.tselt

    def get_dim_fparam(self) -> int:
        """Get the number (dimension) of frame parameters of this DP."""
        return self.dfparam

    def get_dim_aparam(self) -> int:
        """Get the number (dimension) of atomic parameters of this DP."""
        return self.daparam

    def eval(
        self,
        coords: np.ndarray,
        cells: np.ndarray,
        atom_types: List[int],
        atomic: bool = True,
        fparam: Optional[np.ndarray] = None,
        aparam: Optional[np.ndarray] = None,
        efield: Optional[np.ndarray] = None
    ) -> np.ndarray:
        """Evaluate the model.

        Parameters
        ----------
        coords
            The coordinates of atoms. 
            The array should be of size nframes x natoms x 3
        cells
            The cell of the region. 
            If None then non-PBC is assumed, otherwise using PBC. 
            The array should be of size nframes x 9
        atom_types
            The atom types
            The list should contain natoms ints
        atomic
            If True (default), return the atomic tensor
            Otherwise return the global tensor
        fparam
            Not used in this model
        aparam
            Not used in this model
        efield
            Not used in this model

        Returns
        -------
        tensor
                The returned tensor
                If atomic == False then of size nframes x output_dim
                else of size nframes x natoms x output_dim
        """
        # standarize the shape of inputs
        atom_types = np.array(atom_types, dtype = int).reshape([-1])
        natoms = atom_types.size
        coords = np.reshape(np.array(coords), [-1, natoms * 3])
        nframes = coords.shape[0]
        if cells is None:
            pbc = False
            cells = np.tile(np.eye(3), [nframes, 1]).reshape([nframes, 9])
        else:
            pbc = True
            cells = np.array(cells).reshape([nframes, 9])

        # sort inputs
        coords, atom_types, imap, sel_at, sel_imap = self.sort_input(coords, atom_types, sel_atoms = self.get_sel_type())

        # make natoms_vec and default_mesh
        natoms_vec = self.make_natoms_vec(atom_types)
        assert(natoms_vec[0] == natoms)

        # evaluate
        feed_dict_test = {}
        feed_dict_test[self.t_natoms] = natoms_vec
        feed_dict_test[self.t_type  ] = np.tile(atom_types, [nframes,1]).reshape([-1])
        feed_dict_test[self.t_coord] = np.reshape(coords, [-1])
        feed_dict_test[self.t_box  ] = np.reshape(cells , [-1])
        if pbc:
            feed_dict_test[self.t_mesh ] = make_default_mesh(cells)
        else:
            feed_dict_test[self.t_mesh ] = np.array([], dtype = np.int32)

        if atomic:
            assert "global" not in self.model_type, \
                f"cannot do atomic evaluation with model type {self.model_type}"
            t_out = [self.t_tensor]
        else:
            assert self._support_gfv or "global" in self.model_type, \
                f"do not support global tensor evaluation with old {self.model_type} model"
            t_out = [self.t_global_tensor if self._support_gfv else self.t_tensor]
        v_out = self.sess.run (t_out, feed_dict = feed_dict_test)
        tensor = v_out[0]

        # reverse map of the outputs
        if atomic:
            tensor = np.array(tensor)
            tensor = self.reverse_map(np.reshape(tensor, [nframes,-1,self.output_dim]), sel_imap)
            tensor = np.reshape(tensor, [nframes, len(sel_at), self.output_dim])
        else:
            tensor = np.reshape(tensor, [nframes, self.output_dim])
        
        return tensor

    def eval_full(
        self,
        coords: np.ndarray,
        cells: np.ndarray,
        atom_types: List[int],
        atomic: bool = False,
        fparam: Optional[np.array] = None,
        aparam: Optional[np.array] = None,
        efield: Optional[np.array] = None
    ) -> Tuple[np.ndarray, ...]:
        """Evaluate the model with interface similar to the energy model.
        Will return global tensor, component-wise force and virial
        and optionally atomic tensor and atomic virial.

        Parameters
        ----------
        coords
            The coordinates of atoms. 
            The array should be of size nframes x natoms x 3
        cells
            The cell of the region. 
            If None then non-PBC is assumed, otherwise using PBC. 
            The array should be of size nframes x 9
        atom_types
            The atom types
            The list should contain natoms ints
        atomic
            Whether to calculate atomic tensor and virial
        fparam
            Not used in this model
        aparam
            Not used in this model
        efield
            Not used in this model

        Returns
        -------
        tensor
            The global tensor. 
            shape: [nframes x nout]
        force
            The component-wise force (negative derivative) on each atom.
            shape: [nframes x nout x natoms x 3]
        virial
            The component-wise virial of the tensor.
            shape: [nframes x nout x 9]
        atom_tensor
            The atomic tensor. Only returned when atomic == True
            shape: [nframes x natoms x nout]
        atom_virial
            The atomic virial. Only returned when atomic == True
            shape: [nframes x nout x natoms x 9]
        """
        assert self._support_gfv, \
            f"do not support eval_full with old tensor model"

        # standarize the shape of inputs
        atom_types = np.array(atom_types, dtype = int).reshape([-1])
        natoms = atom_types.size
        coords = np.reshape(np.array(coords), [-1, natoms * 3])
        nframes = coords.shape[0]
        if cells is None:
            pbc = False
            cells = np.tile(np.eye(3), [nframes, 1]).reshape([nframes, 9])
        else:
            pbc = True
            cells = np.array(cells).reshape([nframes, 9])
        nout = self.output_dim

        # sort inputs
        coords, atom_types, imap, sel_at, sel_imap = self.sort_input(coords, atom_types, sel_atoms = self.get_sel_type())

        # make natoms_vec and default_mesh
        natoms_vec = self.make_natoms_vec(atom_types)
        assert(natoms_vec[0] == natoms)

        # evaluate
        feed_dict_test = {}
        feed_dict_test[self.t_natoms] = natoms_vec
        feed_dict_test[self.t_type  ] = np.tile(atom_types, [nframes,1]).reshape([-1])
        feed_dict_test[self.t_coord] = np.reshape(coords, [-1])
        feed_dict_test[self.t_box  ] = np.reshape(cells , [-1])
        if pbc:
            feed_dict_test[self.t_mesh ] = make_default_mesh(cells)
        else:
            feed_dict_test[self.t_mesh ] = np.array([], dtype = np.int32)
        
        t_out = [self.t_global_tensor, 
                 self.t_force, 
                 self.t_virial]
        if atomic :
            t_out += [self.t_tensor, 
                      self.t_atom_virial]
        
        v_out = self.sess.run (t_out, feed_dict = feed_dict_test)
        gt     = v_out[0] # global tensor
        force  = v_out[1]
        virial = v_out[2]
        if atomic:
            at = v_out[3] # atom tensor
            av = v_out[4] # atom virial

        # please note here the shape are wrong!
        force  = self.reverse_map(np.reshape(force, [nframes*nout, natoms ,3]), imap)
        if atomic:
            at = self.reverse_map(np.reshape(at, [nframes, len(sel_at), nout]), sel_imap)
            av = self.reverse_map(np.reshape(av, [nframes*nout, natoms, 9]), imap)
        
        # make sure the shapes are correct here
        gt     = np.reshape(gt, [nframes, nout])
        force  = np.reshape(force, [nframes, nout, natoms, 3])
        virial = np.reshape(virial, [nframes, nout, 9])
        if atomic:
            at = np.reshape(at, [nframes, len(sel_at), self.output_dim])
            av = np.reshape(av, [nframes, nout, natoms, 9])
            return gt, force, virial, at, av
        else:
            return gt, force, virial