deep_eval.py 9.54 KB
Newer Older
zhangqha's avatar
zhangqha committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
import os
import posixpath
from functools import lru_cache
from typing import List, Optional, TYPE_CHECKING, Union

import numpy as np

from deepmd.common import make_default_mesh
from deepmd.env import default_tf_session_config, tf, MODEL_VERSION
from deepmd.utils.sess import run_sess
from deepmd.utils.batch_size import AutoBatchSize

if TYPE_CHECKING:
    from pathlib import Path


class DeepEval:
    """Common methods for DeepPot, DeepWFC, DeepPolar, ...
    
    Parameters
    ----------
    model_file : Path
        The name of the frozen model file.
    load_prefix : str
        The prefix (name scope) given to the loaded computational graph
    default_tf_graph : bool
        If True, use the default tf graph; otherwise build a new tf graph for evaluation
    auto_batch_size : bool or int or AutoBatchSize, default: False
        If True, automatic batch size will be used. If int, it will be used
        as the initial batch size.
    """

    load_prefix: str  # set by subclass

    def __init__(
        self,
        model_file: "Path",
        load_prefix: str = "load",
        default_tf_graph: bool = False,
        auto_batch_size: Union[bool, int, AutoBatchSize] = False,
    ):
        """Load the frozen model and configure batching.

        Parameters
        ----------
        model_file : Path
            The name of the frozen model file.
        load_prefix : str
            The prefix under which the graph is imported.
        default_tf_graph : bool
            If True, import into the default tf graph, otherwise into a new graph.
        auto_batch_size : bool or int or AutoBatchSize, default: False
            True creates a default AutoBatchSize, an int is used as its initial
            batch size, an AutoBatchSize instance is used as is, and False
            disables automatic batching.

        Raises
        ------
        RuntimeError
            If the model version stored in the graph is not compatible with
            the version supported by the current code.
        TypeError
            If auto_batch_size is not a bool, int, or AutoBatchSize.
        """
        self.graph = self._load_graph(
            model_file, prefix=load_prefix, default_tf_graph=default_tf_graph
        )
        self.load_prefix = load_prefix

        # graph_compatable should be called after graph and prefix are set
        if not self._graph_compatable():
            # note the trailing space: adjacent string literals are concatenated
            # verbatim, so without it the message reads "incompatiblewith"
            raise RuntimeError(
                f"model in graph (version {self.model_version}) is incompatible "
                f"with the model (version {MODEL_VERSION}) supported by the current code."
            )

        # set default to False, as subclasses may not support
        # bool must be tested before int because bool is a subclass of int
        if isinstance(auto_batch_size, bool):
            if auto_batch_size:
                self.auto_batch_size = AutoBatchSize()
            else:
                self.auto_batch_size = None
        elif isinstance(auto_batch_size, int):
            self.auto_batch_size = AutoBatchSize(auto_batch_size)
        elif isinstance(auto_batch_size, AutoBatchSize):
            self.auto_batch_size = auto_batch_size
        else:
            raise TypeError("auto_batch_size should be bool, int, or AutoBatchSize")

    @property
    def model_type(self) -> str:
        """Get type of model.

        The value is read from the graph on first access and then stored on
        the instance. A per-instance attribute is used instead of
        ``functools.lru_cache`` because the latter keeps a strong reference to
        ``self`` in a class-level cache, which prevents the evaluator from
        ever being garbage collected.

        Returns
        -------
        str
            type of model
        """
        try:
            return self._cached_model_type
        except AttributeError:
            t_mt = self._get_tensor("model_attr/model_type:0")
            [mt] = run_sess(self.sess, [t_mt], feed_dict={})
            self._cached_model_type = mt.decode("utf-8")
            return self._cached_model_type

    @property
    @lru_cache(maxsize=None)
    def model_version(self) -> str:
        """Get version of model.

        Returns
        -------
        str
            version of model
        """
        try:
            t_mt = self._get_tensor("model_attr/model_version:0")
        except KeyError:
            # For deepmd-kit version 0.x - 1.x, set model version to 0.0
            return "0.0"
        else:
            [mt] = run_sess(self.sess, [t_mt], feed_dict={})
            return mt.decode("utf-8")

    @property
    def sess(self) -> tf.Session:
        """Get TF session.

        The session is created on first access and then stored on the
        instance, so every access returns the same session. A per-instance
        attribute is used instead of ``functools.lru_cache`` because the
        latter keeps a strong reference to ``self`` in a class-level cache,
        which would keep the evaluator, its graph and its session alive for
        the lifetime of the process.
        """
        try:
            return self._cached_sess
        except AttributeError:
            # start a tf session associated to the graph
            self._cached_sess = tf.Session(
                graph=self.graph, config=default_tf_session_config
            )
            return self._cached_sess

    def _graph_compatable(
        self
    ) -> bool :
        """Check whether the loaded model can be used by the current code.

        The model is accepted when its major version equals the major version
        of ``MODEL_VERSION`` and its minor version is not greater than the
        minor version of ``MODEL_VERSION``.

        Returns
        -------
        bool
            True if the model stored in the graph file is compatible with the
            current code, False otherwise
        """
        graph_fields = self.model_version.split('.')
        graph_major = int(graph_fields[0])
        graph_minor = int(graph_fields[1])

        code_fields = MODEL_VERSION.split('.')
        code_major = int(code_fields[0])
        code_minor = int(code_fields[1])

        return graph_major == code_major and graph_minor <= code_minor

    def _get_tensor(
        self, tensor_name: str, attr_name: Optional[str] = None
    ) -> tf.Tensor:
        """Get TF graph tensor and assign it to class namespace.

        Parameters
        ----------
        tensor_name : str
            name of tensor to get
        attr_name : Optional[str], optional
            if specified, class attribute with this name will be created and tensor will
            be assigned to it, by default None

        Returns
        -------
        tf.Tensor
            loaded tensor
        """
        tensor_path = os.path.join(self.load_prefix, tensor_name)
        tensor = self.graph.get_tensor_by_name(tensor_path)
        if attr_name:
            setattr(self, attr_name, tensor)
            return tensor
        else:
            return tensor

    @staticmethod
    def _load_graph(
        frozen_graph_filename: "Path", prefix: str = "load", default_tf_graph: bool = False
    ):
        """Read a frozen model file and import it into a tf graph.

        Parameters
        ----------
        frozen_graph_filename : Path
            path of the serialized (protobuf) graph definition
        prefix : str
            name passed to ``tf.import_graph_def`` for the imported graph
        default_tf_graph : bool
            if True, import into the current default graph and return it;
            otherwise import into a newly created graph and return that

        Returns
        -------
        The graph that received the imported definition.
        """
        # Read the protobuf file from disk and deserialize it into a GraphDef
        graph_def = tf.GraphDef()
        with tf.gfile.GFile(str(frozen_graph_filename), "rb") as pb_file:
            graph_def.ParseFromString(pb_file.read())

        if not default_tf_graph:
            # Import into a fresh graph, made the default only for the
            # duration of the import
            with tf.Graph().as_default() as new_graph:
                tf.import_graph_def(
                    graph_def,
                    input_map=None,
                    return_elements=None,
                    name=prefix,
                    producer_op_list=None
                )
            return new_graph

        # Import straight into whatever graph is currently the default
        tf.import_graph_def(
            graph_def,
            input_map=None,
            return_elements=None,
            name=prefix,
            producer_op_list=None
        )
        return tf.get_default_graph()

    @staticmethod
    def sort_input(
        coord : np.ndarray, atom_type : np.ndarray, sel_atoms : List[int] = None
    ):
        """
        Sort atoms in the system according their types.
        
        Parameters
        ----------
        coord
                The coordinates of atoms.
                Should be of shape [nframes, natoms, 3]
        atom_type
                The type of atoms
                Should be of shape [natoms]
        sel_atom
                The selected atoms by type
        
        Returns
        -------
        coord_out
                The coordinates after sorting
        atom_type_out
                The atom types after sorting
        idx_map
                The index mapping from the input to the output. 
                For example coord_out = coord[:,idx_map,:]
        sel_atom_type
                Only output if sel_atoms is not None
                The sorted selected atom types
        sel_idx_map
                Only output if sel_atoms is not None
                The index mapping from the selected atoms to sorted selected atoms.
        """
        if sel_atoms is not None:
            selection = [False] * np.size(atom_type)
            for ii in sel_atoms:
                selection += (atom_type == ii)
            sel_atom_type = atom_type[selection]
        natoms = atom_type.size
        idx = np.arange (natoms)
        idx_map = np.lexsort ((idx, atom_type))
        nframes = coord.shape[0]
        coord = coord.reshape([nframes, -1, 3])
        coord = np.reshape(coord[:,idx_map,:], [nframes, -1])
        atom_type = atom_type[idx_map]
        if sel_atoms is not None:
            sel_natoms = np.size(sel_atom_type)
            sel_idx = np.arange(sel_natoms)
            sel_idx_map = np.lexsort((sel_idx, sel_atom_type))
            sel_atom_type = sel_atom_type[sel_idx_map]
            return coord, atom_type, idx_map, sel_atom_type, sel_idx_map
        else:
            return coord, atom_type, idx_map

    @staticmethod
    def reverse_map(vec : np.ndarray, imap : List[int]) -> np.ndarray:
        """Reverse mapping of a vector according to the index map

        Parameters
        ----------
        vec
                Input vector. Be of shape [nframes, natoms, -1]
        imap
                Index map. Be of shape [natoms]
        
        Returns
        -------
        vec_out
                Reverse mapped vector.
        """
        ret = np.zeros(vec.shape)        
        # for idx,ii in enumerate(imap) :
        #     ret[:,ii,:] = vec[:,idx,:]
        ret[:, imap, :] = vec
        return ret


    def make_natoms_vec(self, atom_types : np.ndarray) -> np.ndarray :
        """Make the natom vector used by deepmd-kit.

        Parameters
        ----------
        atom_types
                The type of atoms
        
        Returns
        -------
        natoms
                The number of atoms. This tensor has the length of Ntypes + 2
                natoms[0]: number of local atoms
                natoms[1]: total number of atoms held by this processor
                natoms[i]: 2 <= i < Ntypes+2, number of type i atoms
  
        """
        natoms_vec = np.zeros (self.ntypes+2).astype(int)
        natoms = atom_types.size
        natoms_vec[0] = natoms
        natoms_vec[1] = natoms
        for ii in range (self.ntypes) :
            natoms_vec[ii+2] = np.count_nonzero(atom_types == ii)
        return natoms_vec