neighbor_stat.py 1.52 KB
Newer Older
zhangqha's avatar
zhangqha committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import logging
from typing import List

from deepmd.common import expand_sys_str
from deepmd.utils.data_system import DeepmdDataSystem
from deepmd.utils.neighbor_stat import NeighborStat

# Module-level logger; neighbor_stat() reports its computed statistics through it.
log = logging.getLogger(__name__)

def neighbor_stat(
    *,
    system: str,
    rcut: float,
    type_map: List[str],
    one_type: bool = False,
    **kwargs,
):
    """Calculate neighbor statistics.

    Parameters
    ----------
    system : str
        system to stat; expanded into a list of systems with ``expand_sys_str``
    rcut : float
        cutoff radius
    type_map : list[str]
        type map
    one_type : bool, optional, default=False
        treat all types as a single type
    **kwargs
        accepted and ignored

    Returns
    -------
    min_nbor_dist
        minimal neighbor distance, as returned by ``NeighborStat.get_stat``
    max_nbor_size
        maximal neighbor counts, as returned by ``NeighborStat.get_stat``

    Raises
    ------
    RuntimeError
        if ``system`` does not expand to any system

    Examples
    --------
    >>> neighbor_stat(system='.', rcut=6., type_map=["C", "H", "O", "N", "P", "S", "Mg", "Na", "HW", "OW", "mNa", "mCl", "mC", "mH", "mMg", "mN", "mO", "mP"])  # doctest: +SKIP

    Both statistics are also written to the module logger at INFO level, e.g.::

        min_nbor_dist: 0.659951
        max_nbor_size: [23, 26, 19, 16, 2, 2, 1, 1, 72, 37, 5, 0, 31, 29, 1, 21, 20, 5]
    """
    all_sys = expand_sys_str(system)
    if len(all_sys) == 0:
        raise RuntimeError("Did not find valid system")
    data = DeepmdDataSystem(
        systems=all_sys,
        batch_size=1,
        test_size=1,
        rcut=rcut,
        type_map=type_map,
    )
    # NOTE(review): the returned batch is discarded; presumably this call is
    # needed to initialize internal state of the data system before the
    # statistics are gathered -- confirm against DeepmdDataSystem.
    data.get_batch()
    nei = NeighborStat(data.get_ntypes(), rcut, one_type=one_type)
    min_nbor_dist, max_nbor_size = nei.get_stat(data)
    # Lazy %-style arguments: the message is only formatted if INFO is enabled.
    log.info("min_nbor_dist: %f", min_nbor_dist)
    log.info("max_nbor_size: %s", max_nbor_size)
    return min_nbor_dist, max_nbor_size