config.py 8.79 KB
Newer Older
zhangqha's avatar
zhangqha committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
#!/usr/bin/env python3
"""Quickly create a configuration file for smooth model."""

import glob
import json
from pathlib import Path
from typing import Any, Dict, List, Tuple

import numpy as np
import yaml

# Only the ``config`` entry point is exported by ``from <module> import *``.
__all__ = ["config"]


# Template for the generated input file. Every entry holding -1 or [] is a
# placeholder that ``config()`` overwrites with a value derived from the
# user's systems and answers; the rest are written out unchanged.
DEFAULT_DATA: Dict[str, Any] = dict(
    use_smooth=True,
    sel_a=[],
    rcut_smth=-1,
    rcut=-1,
    filter_neuron=[20, 40, 80],
    filter_resnet_dt=False,
    axis_neuron=8,
    fitting_neuron=[240, 240, 240],
    fitting_resnet_dt=True,
    coord_norm=True,
    type_fitting_net=False,

    systems=[],
    set_prefix="set",
    stop_batch=-1,
    batch_size=-1,

    start_lr=0.001,
    decay_steps=-1,
    decay_rate=0.95,

    start_pref_e=0.02,
    limit_pref_e=1,
    start_pref_f=1000,
    limit_pref_f=1,
    start_pref_v=0,
    limit_pref_v=0,

    seed=1,
    disp_file="lcurve.out",
    disp_freq=1000,
    numb_test=10,
    save_freq=10000,
    save_ckpt="model.ckpt",
    disp_training=True,
    time_training=True,
)


def valid_dir(path: Path):
    """Check that a directory looks like a deepmd system directory.

    Parameters
    ----------
    path : Path
        path to the candidate system directory

    Raises
    ------
    OSError
        if ``type.raw`` is missing from ``path``, if ``path`` has no ``set.*``
        subdirectories, or if ``box.npy`` or ``coord.npy`` is missing from any
        of them; the message names the offending directory and file
    """
    if not (path / "type.raw").is_file():
        raise OSError(f"invalid system {path}: missing 'type.raw'")
    set_dirs = sorted(path.glob("set.*"))
    # A system without any set has no cells to load (load_systems would end
    # up stacking an empty list), so reject it here with a clear message.
    if not set_dirs:
        raise OSError(f"invalid system {path}: no 'set.*' subdirectories")
    for set_dir in set_dirs:
        for required in ("box.npy", "coord.npy"):
            if not (set_dir / required).is_file():
                raise OSError(f"invalid set {set_dir}: missing '{required}'")


def load_systems(dirs: List[Path]) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """Load atom types and cells of each system from disk into memory.

    Parameters
    ----------
    dirs : List[Path]
        list of system directories paths

    Returns
    -------
    Tuple[List[np.ndarray], List[np.ndarray]]
        per system: a 1-D integer array of atom types read from ``type.raw``,
        and the ``box.npy`` arrays of all ``set.*`` subdirectories stacked
        row-wise (N x 9 when each file holds one 9-element row per frame)

    Raises
    ------
    ValueError
        if a system directory contains no ``set.*`` subdirectory
    """
    all_type = []
    all_box = []
    for d in dirs:
        # ndmin=1 keeps a single-atom type.raw as a 1-element array instead of
        # a 0-d scalar, which would break len() and iteration downstream.
        sys_type = np.loadtxt(d / "type.raw", dtype=int, ndmin=1)
        # sorted() makes the frame order deterministic (glob order is not).
        set_dirs = sorted(d.glob("set.*"))
        if not set_dirs:
            raise ValueError(f"no 'set.*' directories found in {d}")
        sys_box = np.vstack([np.load(s / "box.npy") for s in set_dirs])
        all_type.append(sys_type)
        all_box.append(sys_box)
    return all_type, all_box


def get_system_names() -> List[Path]:
    """Read system directory patterns from stdin and expand them.

    Each whitespace-separated token is expanded as a glob pattern, relative to
    the current working directory or absolute. Every match is checked with
    ``valid_dir`` and returned as an absolute path.

    Returns
    -------
    List[Path]
        list of system directories paths, in sorted order per pattern

    Raises
    ------
    OSError
        propagated from ``valid_dir`` when a match is not a valid system
    """
    dirs = input("Enter system path(s) (seperated by space, wild card supported): \n")
    system_dirs = []
    for dir_str in dirs.split():
        # glob.glob accepts absolute as well as relative patterns, whereas
        # Path.glob raises NotImplementedError for absolute ones.
        # recursive=True keeps "**" matching nested directories as Path.glob did.
        for match in sorted(glob.glob(dir_str, recursive=True)):
            # absolute() reproduces the previous cwd-anchored result for
            # relative patterns.
            d = Path(match).absolute()
            valid_dir(d)
            system_dirs.append(d)

    return system_dirs


def get_rcut() -> float:
    """Prompt the user on stdin for the cutoff radius.

    An empty answer silently selects the default (6.0). An answer that cannot
    be parsed as a float prints a notice and also falls back to the default.

    Returns
    -------
    float
        input rcut length converted to float

    Raises
    ------
    ValueError
        if rcut is not a finite number greater than 0
    """
    dv = 6.0
    rcut_input = input(f"Enter rcut (default {dv:.1f} A): \n").strip()
    if not rcut_input:
        # Pressing Enter is the normal way to accept the default, not an error.
        rcut = dv
    else:
        try:
            rcut = float(rcut_input)
        except ValueError as e:
            print(f"invalid rcut: {e} setting to default: {dv:.1f}")
            rcut = dv
    # float() accepts "nan"/"inf"; NaN would slip past a plain `rcut <= 0`
    # check because every comparison with NaN is False.
    if not (rcut > 0 and np.isfinite(rcut)):
        raise ValueError("rcut should be a finite number > 0")
    return rcut


def get_batch_size_rule() -> int:
    """Prompt the user on stdin for the minimal number of atoms per batch.

    An empty answer silently selects the default (32). An answer that cannot
    be parsed as an int prints a notice and also falls back to the default.

    Returns
    -------
    int
        minimal number of atoms in a batch

    Raises
    ------
    ValueError
        if the number is <= 0
    """
    dv = 32
    matom_input = input(
        f"Enter the minimal number of atoms in a batch (default {dv:d}): \n"
    ).strip()
    if not matom_input:
        # Pressing Enter is the normal way to accept the default, not an error.
        matom = dv
    else:
        try:
            matom = int(matom_input)
        except ValueError as e:
            print(f"invalid batch size: {e} setting to default: {dv:d}")
            matom = dv
    if matom <= 0:
        raise ValueError("the number should be > 0")
    return matom


def get_stop_batch() -> int:
    """Prompt the user on stdin for the training stop batch.

    An empty answer silently selects the default (1000000). An answer that
    cannot be parsed as an int prints a notice and also falls back to the
    default.

    Returns
    -------
    int
        the stop batch number

    Raises
    ------
    ValueError
        if stop batch is <= 0
    """
    dv = 1000000
    sb_input = input(f"Enter the stop batch (default {dv:d}): \n").strip()
    if not sb_input:
        # Pressing Enter is the normal way to accept the default, not an error.
        sb = dv
    else:
        try:
            sb = int(sb_input)
        except ValueError as e:
            print(f"invalid stop batch: {e} setting to default: {dv:d}")
            sb = dv
    if sb <= 0:
        raise ValueError("the number should be > 0")
    return sb


def get_ntypes(all_type: List[np.ndarray]) -> int:
    """Count the number of distinct element types across all systems.

    Parameters
    ----------
    all_type : List[np.ndarray]
        list with arrays specifying elements of structures; the arrays may
        have different lengths

    Returns
    -------
    int
        number of unique elements (0 for an empty list)
    """
    if len(all_type) == 0:
        return 0
    # Flatten and concatenate first: np.unique on a ragged list of arrays
    # (systems with different atom counts) cannot form a regular array.
    flat = np.concatenate([np.ravel(tt) for tt in all_type])
    return len(np.unique(flat))


def get_max_density(
    all_type: List[np.ndarray], all_box: List[np.ndarray]
) -> np.ndarray:
    """Compute the per-element maximum number density over all systems.

    For each system, the atom count of every type is divided by the smallest
    cell volume found in that system. The element-wise maximum over systems is
    then returned.

    Parameters
    ----------
    all_type : List[np.ndarray]
        list with integer arrays of atom types, one per system
    all_box : List[np.ndarray]
        list with cell arrays, one per system; each is reshaped to
        (nframes, 3, 3)

    Returns
    -------
    np.ndarray
        maximum atom density for each type index ``0 .. max_type``
    """
    types = [np.ravel(np.asarray(tt, dtype=int)) for tt in all_type]
    # Size the count vector by the largest type index, not by the number of
    # unique types, so non-contiguous type indices are not dropped.
    ntypes = 1 + max((int(tt.max()) for tt in types if tt.size), default=-1)
    all_max = []
    for tt, bb in zip(types, all_box):
        cells = np.reshape(bb, [-1, 3, 3])
        # abs(): a left-handed cell has a negative determinant, which would
        # otherwise yield a negative "volume" and negative densities.
        min_v = np.min(np.abs(np.linalg.det(cells)))
        type_count = np.bincount(tt, minlength=ntypes)
        all_max.append(type_count / min_v)
    return np.max(all_max, axis=0)


def suggest_sel(
    all_type: List[np.ndarray],
    all_box: List[np.ndarray],
    rcut: float,
    ratio: float = 1.5,
) -> List[int]:
    """Suggest the ``sel_a`` neighbour count for each element type.

    The estimate is the number of atoms of each type expected inside a
    sphere of radius ``rcut`` at the highest density found in the data,
    multiplied by ``ratio`` and truncated to an integer.

    Parameters
    ----------
    all_type : List[np.ndarray]
        list with arrays specifying elements of structures
    all_box : List[np.ndarray]
        list with arrays specifying cells for all structures
    rcut : float
        cutoff radius
    ratio : float, optional
        safety margin to multiply the estimated value by, by default 1.5

    Returns
    -------
    List[int]
        suggested neighbour count for each element type
    """
    densest = get_max_density(all_type, all_box)
    # One left-to-right product, so float rounding (and hence truncation)
    # is exactly that of density * sphere volume * ratio evaluated in order.
    expected_neighbours = densest * 4.0 / 3.0 * np.pi * rcut ** 3 * ratio
    return list(map(int, expected_neighbours))


def suggest_batch_size(all_type: List[np.ndarray], min_atom: int) -> List[int]:
    """Suggest, per system, the smallest batch size reaching ``min_atom`` atoms.

    Each system's batch size is ``ceil(min_atom / natoms)``, where ``natoms``
    is the length of that system's type array.

    Parameters
    ----------
    all_type : List[np.ndarray]
        list with arrays specifying elements of structures
    min_atom : int
        minimal number of atoms in batch

    Returns
    -------
    List[int]
        suggested batch sizes for each system
    """
    # -(-a // b) is integer ceiling division, exact for arbitrary-size ints.
    return [-(-min_atom // len(types)) for types in all_type]


def suggest_decay(
    stop_batch: int, *, n_decays: int = 200, decay_rate: float = 0.95
) -> Tuple[int, float]:
    """Suggest number of decay steps and decay rate.

    The decay interval is chosen so that about ``n_decays`` decays happen
    before ``stop_batch``. It is never below 1.

    Parameters
    ----------
    stop_batch : int
        stop batch number
    n_decays : int, optional
        approximate number of decays over the whole run, by default 200
    decay_rate : float, optional
        decay rate to suggest, by default 0.95

    Returns
    -------
    Tuple[int, float]
        number of decay steps and decay rate
    """
    # Clamp to 1: a stop batch below n_decays would otherwise give
    # decay_steps == 0, an unusable interval.
    decay_steps = max(1, int(stop_batch // n_decays))
    return decay_steps, decay_rate


def config(*, output: str, **kwargs):
    """Interactively build a training input file and write it to ``output``.

    The user is prompted for system paths, rcut, the minimal number of atoms
    per batch and the stop batch. The systems are loaded and neighbour
    selection, batch sizes and learning-rate decay are suggested from them.
    ``DEFAULT_DATA`` is filled in with those values and serialised.

    Parameters
    ----------
    output : str
        file to write config file; must end in ``json``, ``yml`` or ``yaml``
    **kwargs
        accepted for interface compatibility and ignored

    Raises
    ------
    RuntimeError
        if user does not input any systems
    ValueError
        if output file is of wrong type; checked before any prompt, so no file
        is created or truncated in that case
    """
    # Validate the output format first: previously the user answered every
    # prompt and the file was opened (and truncated) before this error.
    if output.endswith("json"):
        dump_format = "json"
    elif output.endswith(("yml", "yaml")):
        dump_format = "yaml"
    else:
        raise ValueError("output file must be of type json or yaml")

    all_sys = get_system_names()
    if not all_sys:
        raise RuntimeError("no system specified")
    rcut = get_rcut()
    matom = get_batch_size_rule()
    stop_batch = get_stop_batch()

    all_type, all_box = load_systems(all_sys)
    sel = suggest_sel(all_type, all_box, rcut, ratio=1.5)
    bs = suggest_batch_size(all_type, matom)
    decay_steps, decay_rate = suggest_decay(stop_batch)

    # Shallow copy suffices: only top-level keys are reassigned below.
    jdata = DEFAULT_DATA.copy()
    jdata["systems"] = [str(ii) for ii in all_sys]
    jdata["sel_a"] = sel
    jdata["rcut"] = rcut
    jdata["rcut_smth"] = rcut - 0.2
    jdata["stop_batch"] = stop_batch
    jdata["batch_size"] = bs
    jdata["decay_steps"] = decay_steps
    jdata["decay_rate"] = decay_rate

    with open(output, "w", encoding="utf-8") as fp:
        if dump_format == "json":
            json.dump(jdata, fp, indent=4)
        else:
            yaml.safe_dump(jdata, fp, default_flow_style=False)