__init__.py 2.3 KB
Newer Older
zhangqha's avatar
zhangqha committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""Submodule containing all the implemented potentials."""

from pathlib import Path
from typing import Union

from .data_modifier import DipoleChargeModifier
from .deep_dipole import DeepDipole
from .deep_eval import DeepEval
from .deep_polar import DeepGlobalPolar, DeepPolar
from .deep_pot import DeepPot
from .deep_wfc import DeepWFC
from .ewald_recp import EwaldRecp
from .model_devi import calc_model_devi

# Names re-exported by ``from <this package> import *``: the factory
# ``DeepPotential`` (defined in this module) followed by the evaluators and
# helpers imported above.
__all__ = [
    "DeepPotential",
    "DeepDipole",
    "DeepEval",
    "DeepGlobalPolar",
    "DeepPolar",
    "DeepPot",
    "DeepWFC",
    "DipoleChargeModifier",
    "EwaldRecp",
    "calc_model_devi",
]


def DeepPotential(
    model_file: Union[str, Path],
    load_prefix: str = "load",
    default_tf_graph: bool = False,
) -> Union[DeepDipole, DeepGlobalPolar, DeepPolar, DeepPot, DeepWFC]:
    """Factory function that initializes the appropriate potential read from `model_file`.

    Parameters
    ----------
    model_file : Union[str, Path]
        The name of the frozen model file.
    load_prefix : str
        The prefix in the load computational graph
    default_tf_graph : bool
        If uses the default tf graph, otherwise build a new tf graph for evaluation

    Returns
    -------
    Union[DeepDipole, DeepGlobalPolar, DeepPolar, DeepPot, DeepWFC]
        one of the available potentials

    Raises
    ------
    RuntimeError
        if model file does not correspond to any implemented potential
    """
    mf = Path(model_file)

    # A generic DeepEval is instantiated only to read the model's type tag;
    # the concrete evaluator chosen below is then constructed separately
    # from the same file with the same options.
    model_type = DeepEval(
        mf, load_prefix=load_prefix, default_tf_graph=default_tf_graph
    ).model_type

    # Dispatch table from the model type tag to the evaluator class. Built at
    # call time so the module does not depend on it during import.
    potentials = {
        "ener": DeepPot,
        "dipole": DeepDipole,
        "polar": DeepPolar,
        "global_polar": DeepGlobalPolar,
        "wfc": DeepWFC,
    }
    potential_cls = potentials.get(model_type)
    if potential_cls is None:
        raise RuntimeError(f"unknown model type {model_type}")

    return potential_cls(
        mf, load_prefix=load_prefix, default_tf_graph=default_tf_graph
    )