freeze.py 1.63 KB
Newer Older
zhangqha's avatar
zhangqha committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48

#!/usr/bin/env python3

from deepmd.env import tf
from deepmd.nvnmd.utils.fio import FioDic


def filter_tensorVariableList(tensorVariableList) -> dict:
    r"""Select the variables used by NVNMD, keyed by a normalized name.

    A variable is kept when its normalized name starts with one of the
    prefixes listed below, and contains neither ``'Adam'`` nor ``'XXX'``.
    Names are normalized by removing ``':0'`` and replacing ``'/'`` with
    ``'.'``. The expected raw names look like:

    | :code:`descrpt_attr/t_avg:0`
    | :code:`descrpt_attr/t_std:0`
    | :code:`filter_type_{atom i}/matrix_{layer l}_{atomj}:0`
    | :code:`filter_type_{atom i}/bias_{layer l}_{atomj}:0`
    | :code:`layer_{layer l}_type_{atom i}/matrix:0`
    | :code:`layer_{layer l}_type_{atom i}/bias:0`
    | :code:`final_layer_type_{atom i}/matrix:0`
    | :code:`final_layer_type_{atom i}/bias:0`

    Parameters
    ----------
    tensorVariableList
        Sequence of objects exposing a ``name`` attribute
        (e.g. the result of ``tf.global_variables()``).

    Returns
    -------
    dict
        Mapping ``normalized name -> variable`` in input order. When two
        variables normalize to the same name, the later one wins.
    """
    wanted_prefixes = (
        'descrpt_attr',
        'filter_type_',
        'layer_',
        'final_layer_type_',
    )
    selected = {}
    for variable in tensorVariableList:
        key = variable.name.replace(':0', '').replace('/', '.')
        if not key.startswith(wanted_prefixes):
            continue
        # Skip optimizer slot variables (Adam) and names tagged with 'XXX'.
        if 'Adam' in key or 'XXX' in key:
            continue
        selected[key] = variable
    return selected


def save_weight(sess, file_name: str = 'nvnmd/weight.npy'):
    r"""Save the dictionary of weight to a npy file

    Collects all global variables accepted by
    :func:`filter_tensorVariableList`, evaluates them in ``sess``, and
    writes the resulting ``{name: value}`` dictionary via ``FioDic().save``.

    Parameters
    ----------
    sess
        TensorFlow session used to evaluate the variables.
    file_name
        Output path forwarded to ``FioDic().save``.
    """
    tvs = tf.global_variables()
    dic_key_tv = filter_tensorVariableList(tvs)
    # Evaluate every selected variable in one Session.run call instead of
    # one call per variable. A dict of fetches yields a dict of values with
    # the same keys.
    dic_key_value = sess.run(dic_key_tv)
    FioDic().save(file_name, dic_key_value)