train.py 11.3 KB
Newer Older
zhangqha's avatar
zhangqha committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
"""DeePMD training entrypoint script.

Can handle local or distributed training.
"""

import json
import logging
import time
import os
from typing import Dict, List, Optional, Any

from deepmd.common import data_requirement, expand_sys_str, j_loader, j_must_have
from deepmd.env import tf, reset_default_tf_session_config, GLOBAL_ENER_FLOAT_PRECISION
from deepmd.infer.data_modifier import DipoleChargeModifier
from deepmd.train.run_options import BUILD, CITATION, WELCOME, RunOptions
from deepmd.train.trainer import DPTrainer
from deepmd.utils import random as dp_random
from deepmd.utils.argcheck import normalize
from deepmd.utils.compat import update_deepmd_input
from deepmd.utils.data_system import DeepmdDataSystem
from deepmd.utils.sess import run_sess
from deepmd.utils.neighbor_stat import NeighborStat
from deepmd.utils.path import DPPath

# Public API of this module: only the ``train`` entry point is exported by
# ``from ... import *``.
__all__ = ["train"]

# Module-level logger, named after this module, used by every function below.
log = logging.getLogger(__name__)


def train(
    *,
    INPUT: str,
    init_model: Optional[str],
    restart: Optional[str],
    output: str,
    init_frz_model: str,
    mpi_log: str,
    log_level: int,
    log_path: Optional[str],
    is_compress: bool = False,
    skip_neighbor_stat: bool = False,
    **kwargs,
):
    """Entry point for DeePMD model training (or model compression).

    The control file is loaded, upgraded to the current input format,
    normalized, optionally completed with neighbor statistics, dumped to
    ``output`` and finally handed to ``_do_work``.

    Parameters
    ----------
    INPUT : str
        json/yaml control file
    init_model : Optional[str]
        path to checkpoint folder or None
    restart : Optional[str]
        path to checkpoint folder or None
    output : str
        path of the file the final (normalized) arguments are dumped to
    init_frz_model : str
        path to frozen model or None
    mpi_log : str
        mpi logging mode
    log_level : int
        logging level defined by int 0-3
    log_path : Optional[str]
        logging file path or None if logs are to be output only to stdout
    is_compress : bool
        indicates whether in the model compress mode
    skip_neighbor_stat : bool, default=False
        skip checking neighbor statistics
    **kwargs
        extra keyword arguments are accepted and ignored

    Raises
    ------
    RuntimeError
        if distributed training job name is wrong
        (NOTE(review): not raised in this function itself; presumably
        comes from ``RunOptions`` - confirm)
    """
    run_opt = RunOptions(
        init_model=init_model,
        restart=restart,
        init_frz_model=init_frz_model,
        log_path=log_path,
        log_level=log_level,
        mpi_log=mpi_log,
    )
    if run_opt.is_distrib:
        if len(run_opt.gpus or []) > 1:
            # avoid conflict of visible gpus among multiple tf sessions in
            # one process
            reset_default_tf_session_config(cpu_only=True)

    # load the control file, convert it to the current format and normalize
    config = normalize(
        update_deepmd_input(
            j_loader(INPUT), warning=True, dump="input_v2_compat.json"
        )
    )

    # neighbor statistics are only gathered for a regular training run
    if not (is_compress or skip_neighbor_stat):
        config = update_sel(config)

    with open(output, "w") as dump_file:
        json.dump(config, dump_file, indent=4)

    # save the training script into the graph
    # remove white spaces as it is not compressed
    compact_script = json.dumps(config, separators=(",", ":"))
    tf.constant(
        compact_script, name="train_attr/training_script", dtype=tf.string
    )

    for banner_line in WELCOME + CITATION + BUILD:
        log.info(banner_line)

    run_opt.print_resource_summary()
    _do_work(config, run_opt, is_compress)


def _do_work(jdata: Dict[str, Any], run_opt: RunOptions, is_compress: bool = False):
    """Build the trainer, then either train the model or save it compressed.

    Parameters
    ----------
    jdata : Dict[str, Any]
        arguments read from the json/yaml control file
    run_opt : RunOptions
        object with run configuration
    is_compress : bool
        indicates whether in model compress mode; in that mode no data
        systems are loaded and ``save_compressed`` is called instead of
        ``train``

    Raises
    ------
    RuntimeError
        If unsupported modifier type is selected for model
    """
    # make necessary checks
    assert "training" in jdata

    # init the model
    trainer = DPTrainer(jdata, run_opt=run_opt, is_compress=is_compress)
    rcut = trainer.model.get_rcut()
    model_type_map = trainer.model.get_type_map()
    # an empty type map is passed on to the data systems as "no type map"
    ipt_type_map = model_type_map if len(model_type_map) != 0 else None

    training_params = jdata["training"]

    # init random seed of data systems
    seed = training_params.get("seed", None)
    if seed is not None:
        # offset by the rank to avoid the same batch sequence among workers
        seed = (seed + run_opt.my_rank) % (2 ** 32)
    dp_random.seed(seed)

    # setup data modifier
    modifier = get_modifier(jdata["model"].get("modifier", None))

    # decouple the training data from the model compress process
    train_data = valid_data = None
    if not is_compress:
        # init data
        train_data = get_data(
            training_params["training_data"], rcut, ipt_type_map, modifier
        )
        train_data.print_summary("training")
        valid_params = training_params.get("validation_data", None)
        if valid_params is not None:
            valid_data = get_data(
                valid_params, rcut, train_data.type_map, modifier
            )
            valid_data.print_summary("validation")

    # get training info
    stop_batch = j_must_have(training_params, "numb_steps")
    trainer.build(train_data, stop_batch)

    if is_compress:
        trainer.save_compressed()
        log.info("finished compressing")
        return

    # train the model with the provided systems in a cyclic way
    start_time = time.time()
    trainer.train(train_data, valid_data)
    elapsed = time.time() - start_time
    log.info("finished training")
    log.info(f"wall time: {elapsed:.3f} s")


def get_data(jdata: Dict[str, Any], rcut, type_map, modifier):
    """Validate the listed data systems and wrap them in a DeepmdDataSystem.

    Parameters
    ----------
    jdata : Dict[str, Any]
        one data section of the input; "systems" and "batch_size" are
        mandatory, "sys_probs" and "auto_prob" are optional
    rcut
        cut-off radius forwarded to ``DeepmdDataSystem``
    type_map
        type map forwarded to ``DeepmdDataSystem``
    modifier
        data modifier forwarded to ``DeepmdDataSystem``

    Returns
    -------
    DeepmdDataSystem
        the data system, with ``data_requirement`` registered via
        ``add_dict``

    Raises
    ------
    IOError
        if no system is given, or a system is not a directory, or a system
        directory does not contain a ``type.raw`` file
    """
    systems = j_must_have(jdata, "systems")
    if isinstance(systems, str):
        # a single string is expanded into a list of systems
        systems = expand_sys_str(systems)

    help_msg = 'Please check your setting for data systems'

    def _abort(msg):
        # log the problem, then raise it together with the generic hint
        log.fatal(msg)
        raise IOError(msg, help_msg)

    # check length of systems
    if len(systems) == 0:
        _abort('cannot find valid a data system')

    # roughly check all items in systems are valid
    for system in systems:
        sys_path = DPPath(system)
        if not sys_path.is_dir():
            _abort(f'dir {sys_path} is not a valid dir')
        if not (sys_path / 'type.raw').is_file():
            _abort(f'dir {sys_path} is not a valid data system dir')

    batch_size = j_must_have(jdata, "batch_size")
    sys_probs = jdata.get("sys_probs", None)
    auto_prob_style = jdata.get("auto_prob", "prob_sys_size")

    data_system = DeepmdDataSystem(
        systems=systems,
        batch_size=batch_size,
        test_size=1,        # to satisfy the old api
        shuffle_test=True,  # to satisfy the old api
        rcut=rcut,
        type_map=type_map,
        modifier=modifier,
        trn_all_set=True,   # sample from all sets
        sys_probs=sys_probs,
        auto_prob_style=auto_prob_style,
    )
    data_system.add_dict(data_requirement)
    return data_system


def get_modifier(modi_data=None):
    """Create the data modifier described by ``modi_data``.

    Parameters
    ----------
    modi_data : dict, optional
        modifier section of the model input; ``None`` means no modifier

    Returns
    -------
    Optional[DipoleChargeModifier]
        ``None`` when ``modi_data`` is ``None``, otherwise a
        ``DipoleChargeModifier`` built from the "model_name",
        "model_charge_map", "sys_charge_map", "ewald_h" and "ewald_beta"
        entries

    Raises
    ------
    RuntimeError
        if ``modi_data["type"]`` is anything other than "dipole_charge"
    """
    if modi_data is None:
        return None

    modifier_type = modi_data["type"]
    if modifier_type != "dipole_charge":
        raise RuntimeError("unknown modifier type " + str(modifier_type))

    return DipoleChargeModifier(
        modi_data["model_name"],
        modi_data["model_charge_map"],
        modi_data["sys_charge_map"],
        modi_data["ewald_h"],
        modi_data["ewald_beta"],
    )


def get_rcut(jdata):
    """Return the largest cut-off radius used by the model descriptor.

    For a descriptor of type "hybrid" this is the maximum "rcut" over the
    entries of its "list"; for any other type it is the descriptor's own
    "rcut".
    """
    descriptor = jdata['model']['descriptor']
    if descriptor['type'] != 'hybrid':
        return descriptor['rcut']
    return max([sub_descriptor['rcut'] for sub_descriptor in descriptor['list']])


def get_type_map(jdata):
    """Return the "type_map" entry of the model section, or None if absent."""
    model_section = jdata['model']
    return model_section.get('type_map')


def get_nbor_stat(jdata, rcut, one_type: bool = False):
    """Compute neighbor statistics of the training data.

    Parameters
    ----------
    jdata : dict
        full (normalized) input; the "model" and "training" sections are read
    rcut
        cut-off radius handed to ``NeighborStat``
    one_type : bool, default=False
        forwarded to ``NeighborStat``

    Returns
    -------
    min_nbor_dist, max_nbor_size
        the two values returned by ``NeighborStat.get_stat``; both are also
        stored in the TF graph as the constants ``train_attr/min_nbor_dist``
        and ``train_attr/max_nbor_size``
    """
    # the data systems are loaded with the largest rcut of all descriptors
    max_rcut = get_rcut(jdata)
    # An empty type map carries no information and is treated as absent,
    # consistent with how _do_work handles an empty type map.
    type_map = get_type_map(jdata) or None

    train_data = get_data(jdata["training"]["training_data"], max_rcut, type_map, None)
    # NOTE(review): the batch is discarded; presumably this call is needed to
    # load data before statistics are gathered - confirm
    train_data.get_batch()

    data_ntypes = train_data.get_ntypes()
    map_ntypes = len(type_map) if type_map is not None else data_ntypes
    ntypes = max(map_ntypes, data_ntypes)

    neistat = NeighborStat(ntypes, rcut, one_type=one_type)
    min_nbor_dist, max_nbor_size = neistat.get_stat(train_data)

    # moved from trainer.py as duplicated
    # TODO: this is a simple fix but we should have a clear
    #       architecture to call neighbor stat
    tf.constant(
        min_nbor_dist,
        name='train_attr/min_nbor_dist',
        dtype=GLOBAL_ENER_FLOAT_PRECISION,
    )
    tf.constant(
        max_nbor_size,
        name='train_attr/max_nbor_size',
        dtype=tf.int32,
    )
    return min_nbor_dist, max_nbor_size

def get_sel(jdata, rcut, one_type: bool = False):
    """Return only the ``max_nbor_size`` part of ``get_nbor_stat``."""
    nbor_stat = get_nbor_stat(jdata, rcut, one_type=one_type)
    return nbor_stat[1]

def get_min_nbor_dist(jdata, rcut):
    """Return only the ``min_nbor_dist`` part of ``get_nbor_stat``."""
    nbor_stat = get_nbor_stat(jdata, rcut)
    return nbor_stat[0]

def parse_auto_sel(sel):
    """Tell whether ``sel`` is an automatic-selection string.

    Parameters
    ----------
    sel
        the descriptor's "sel" value; may be a string or any other object

    Returns
    -------
    bool
        True if ``sel`` is a string (or an instance of a ``str`` subclass)
        whose part before the first ':' is exactly "auto", e.g. "auto" or
        "auto:1.5"; False for every other value, including non-strings
    """
    # isinstance (rather than an exact type comparison) so that str
    # subclasses are recognized as strings too
    if not isinstance(sel, str):
        return False
    return sel.split(':')[0] == 'auto'

    
def parse_auto_sel_ratio(sel):
    """Extract the ratio from an automatic-selection string.

    Parameters
    ----------
    sel
        value accepted by ``parse_auto_sel``: "auto" or "auto:<ratio>"

    Returns
    -------
    float
        1.1 for plain "auto", otherwise the number after the ':'

    Raises
    ------
    RuntimeError
        if ``parse_auto_sel`` rejects ``sel`` or it has more than one ':'
    """
    if not parse_auto_sel(sel):
        raise RuntimeError(f'invalid auto sel format {sel}')

    fields = sel.split(':')
    if len(fields) > 2:
        raise RuntimeError(f'invalid auto sel format {sel}')
    # a bare "auto" falls back to the default ratio
    return float(fields[1]) if len(fields) == 2 else 1.1


def wrap_up_4(xx):
    """Convert ``xx`` with ``int()`` and round it up to a multiple of 4."""
    whole = int(xx)
    # negated floor division is ceiling division
    return -(-whole // 4) * 4


def update_one_sel(jdata, descriptor):
    """Check or fill in the "sel" entry of a single descriptor.

    Descriptors of type "loc_frame" are returned untouched. For every other
    type the neighbor statistics are obtained through ``get_sel``:

    * if "sel" is an automatic-selection string, it is replaced by the
      statistics scaled by the parsed ratio and wrapped up to a multiple
      of 4;
    * otherwise the user's values are kept and a warning is logged for each
      non-zero entry that is smaller than the statistics.

    For type "se_atten" the resulting list is finally replaced by its sum.

    Parameters
    ----------
    jdata : dict
        full input, forwarded to ``get_sel``
    descriptor : dict
        descriptor section; modified in place

    Returns
    -------
    dict
        the same ``descriptor`` object
    """
    descrpt_type = descriptor['type']
    if descrpt_type == 'loc_frame':
        return descriptor

    is_atten = descrpt_type in ('se_atten',)
    stat_sel = get_sel(jdata, descriptor['rcut'], one_type=is_atten)

    raw_sel = descriptor['sel']
    if parse_auto_sel(raw_sel):
        ratio = parse_auto_sel_ratio(raw_sel)
        sel_list = [int(wrap_up_4(needed * ratio)) for needed in stat_sel]
        descriptor['sel'] = sel_list
    else:
        # sel is set by user; a bare int is treated as a one-element list
        sel_list = [raw_sel] if isinstance(raw_sel, int) else raw_sel
        for type_idx, (needed, given) in enumerate(zip(stat_sel, sel_list)):
            # we may skip warning for sel=0, where the user is likely
            # to exclude such type in the descriptor
            if given and needed > given:
                log.warning(
                    "sel of type %d is not enough! The expected value is "
                    "not less than %d, but you set it to %d. The accuracy"
                    " of your model may get worse." % (type_idx, needed, given)
                )

    if is_atten:
        # se_atten takes a single int: the sum over the list
        descriptor['sel'] = sum(sel_list)
    return descriptor


def update_sel(jdata):
    """Run ``update_one_sel`` on the model descriptor(s) of ``jdata``.

    For a descriptor of type "hybrid" every entry of its "list" is updated;
    otherwise the descriptor itself is updated. ``jdata`` is modified in
    place and also returned.
    """
    log.info("Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)")
    descriptor = jdata['model']['descriptor']
    if descriptor['type'] == 'hybrid':
        sub_descriptors = descriptor['list']
        for idx, sub_descriptor in enumerate(sub_descriptors):
            sub_descriptors[idx] = update_one_sel(jdata, sub_descriptor)
    else:
        descriptor = update_one_sel(jdata, descriptor)
    jdata['model']['descriptor'] = descriptor
    return jdata