se.py 4 KB
Newer Older
zhangqha's avatar
zhangqha committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from typing import Tuple, List

from deepmd.env import tf
from deepmd.utils.graph import get_embedding_net_variables_from_graph_def, get_tensor_by_name_from_graph
from .descriptor import Descriptor


class DescrptSe(Descriptor):
    """A base class for smooth versions of descriptors.

    Notes
    -----
    All of these descriptors have an environmental matrix and an
    embedding network (:meth:`deepmd.utils.network.embedding_net`), so
    they can share some similar methods without defining them twice.

    This base class does not build the tensors listed below itself: apart
    from :meth:`pass_tensors_from_frz_model` and :meth:`init_variables`,
    subclasses are expected to create them while building the descriptor.
    Subclasses must also define ``filter_precision``, which is what
    :attr:`precision` returns.

    Attributes
    ----------
    embedding_net_variables : dict
        initial embedding network variables
    davg
        value of the ``descrpt_attr{suffix}/t_avg`` tensor, set by
        :meth:`init_variables`
    dstd
        value of the ``descrpt_attr{suffix}/t_std`` tensor, set by
        :meth:`init_variables`
    descrpt_reshape : tf.Tensor
        the reshaped descriptor
    descrpt_deriv : tf.Tensor
        the descriptor derivative
    rij : tf.Tensor
        distances between two atoms
    nlist : tf.Tensor
        the neighbor list
    """

    def _identity_tensors(self, suffix: str = "") -> None:
        """Give the descriptor tensors fixed graph names so they can be stored and restored.

        Notes
        -----
        Each tensor is wrapped in ``tf.identity`` under the following name
        (with ``suffix`` appended):
            self.descrpt_reshape : o_rmat
            self.descrpt_deriv : o_rmat_deriv
            self.rij : o_rij
            self.nlist : o_nlist
        Thus, this method should be called while building the descriptor,
        after these tensors have been initialized.

        Parameters
        ----------
        suffix : str
            The suffix of the scope
        """
        # These names must stay in sync with ``get_tensor_names``, which is
        # used to look the same tensors up again in a frozen graph.
        self.descrpt_reshape = tf.identity(self.descrpt_reshape, name='o_rmat' + suffix)
        self.descrpt_deriv = tf.identity(self.descrpt_deriv, name='o_rmat_deriv' + suffix)
        self.rij = tf.identity(self.rij, name='o_rij' + suffix)
        self.nlist = tf.identity(self.nlist, name='o_nlist' + suffix)

    def get_tensor_names(self, suffix: str = "") -> Tuple[str, ...]:
        """Get the graph names of the tensors named by ``_identity_tensors``.

        Parameters
        ----------
        suffix : str
            The suffix of the scope

        Returns
        -------
        Tuple[str, ...]
            Four tensor names (``name:0`` form), in the order
            descrpt_reshape, descrpt_deriv, rij, nlist -- the same order as
            the parameters of ``pass_tensors_from_frz_model``.
        """
        return (f'o_rmat{suffix}:0', f'o_rmat_deriv{suffix}:0', f'o_rij{suffix}:0', f'o_nlist{suffix}:0')

    def pass_tensors_from_frz_model(self,
                                    descrpt_reshape: tf.Tensor,
                                    descrpt_deriv: tf.Tensor,
                                    rij: tf.Tensor,
                                    nlist: tf.Tensor,
    ) -> None:
        """Store descriptor tensors taken from a frozen model's graph_def.

        The four tensors are assigned directly to the corresponding
        attributes; no computation is performed.

        Parameters
        ----------
        descrpt_reshape
            The passed descrpt_reshape tensor
        descrpt_deriv
            The passed descrpt_deriv tensor
        rij
            The passed rij tensor
        nlist
            The passed nlist tensor
        """
        self.rij = rij
        self.nlist = nlist
        self.descrpt_deriv = descrpt_deriv
        self.descrpt_reshape = descrpt_reshape

    def init_variables(self,
                       graph: tf.Graph,
                       graph_def: tf.GraphDef,
                       suffix: str = "",
    ) -> None:
        """Initialize the embedding net variables and statistics from a frozen model.

        Sets ``self.embedding_net_variables`` from ``graph_def`` and reads
        ``self.davg`` / ``self.dstd`` from the ``descrpt_attr{suffix}/t_avg``
        and ``descrpt_attr{suffix}/t_std`` tensors of ``graph``.

        Parameters
        ----------
        graph : tf.Graph
            The input frozen model graph
        graph_def : tf.GraphDef
            The input frozen model graph_def
        suffix : str, optional
            The suffix of the scope
        """
        self.embedding_net_variables = get_embedding_net_variables_from_graph_def(graph_def, suffix=suffix)
        self.davg = get_tensor_by_name_from_graph(graph, f'descrpt_attr{suffix}/t_avg')
        self.dstd = get_tensor_by_name_from_graph(graph, f'descrpt_attr{suffix}/t_std')

    @property
    def precision(self) -> tf.DType:
        """Precision of the filter network.

        Returns ``self.filter_precision``, which this base class does not
        define; a subclass must set it, otherwise accessing this property
        raises ``AttributeError``.
        """
        return self.filter_precision