"""DeePMD training entrypoint script. Can handle local or distributed training. """ import json import logging import time import os from typing import Dict, List, Optional, Any from deepmd.common import data_requirement, expand_sys_str, j_loader, j_must_have from deepmd.env import tf, reset_default_tf_session_config, GLOBAL_ENER_FLOAT_PRECISION from deepmd.infer.data_modifier import DipoleChargeModifier from deepmd.train.run_options import BUILD, CITATION, WELCOME, RunOptions from deepmd.train.trainer import DPTrainer from deepmd.utils import random as dp_random from deepmd.utils.argcheck import normalize from deepmd.utils.compat import update_deepmd_input from deepmd.utils.data_system import DeepmdDataSystem from deepmd.utils.sess import run_sess from deepmd.utils.neighbor_stat import NeighborStat from deepmd.utils.path import DPPath __all__ = ["train"] log = logging.getLogger(__name__) def train( *, INPUT: str, init_model: Optional[str], restart: Optional[str], output: str, init_frz_model: str, mpi_log: str, log_level: int, log_path: Optional[str], is_compress: bool = False, skip_neighbor_stat: bool = False, **kwargs, ): """Run DeePMD model training. Parameters ---------- INPUT : str json/yaml control file init_model : Optional[str] path to checkpoint folder or None restart : Optional[str] path to checkpoint folder or None output : str path for dump file with arguments init_frz_model : str path to frozen model or None mpi_log : str mpi logging mode log_level : int logging level defined by int 0-3 log_path : Optional[str] logging file path or None if logs are to be output only to stdout is_compress: bool indicates whether in the model compress mode skip_neighbor_stat : bool, default=False skip checking neighbor statistics Raises ------ RuntimeError if distributed training job nem is wrong """ run_opt = RunOptions( init_model=init_model, restart=restart, init_frz_model=init_frz_model, log_path=log_path, log_level=log_level, mpi_log=mpi_log ) if run_opt.is_distrib and len(run_opt.gpus or []) > 1: # avoid conflict of visible gpus among multipe tf sessions in one process reset_default_tf_session_config(cpu_only=True) # load json database jdata = j_loader(INPUT) jdata = update_deepmd_input(jdata, warning=True, dump="input_v2_compat.json") jdata = normalize(jdata) if not is_compress and not skip_neighbor_stat: jdata = update_sel(jdata) with open(output, "w") as fp: json.dump(jdata, fp, indent=4) # save the training script into the graph # remove white spaces as it is not compressed tf.constant(json.dumps(jdata, separators=(',', ':')), name='train_attr/training_script', dtype=tf.string) for message in WELCOME + CITATION + BUILD: log.info(message) run_opt.print_resource_summary() _do_work(jdata, run_opt, is_compress) def _do_work(jdata: Dict[str, Any], run_opt: RunOptions, is_compress: bool = False): """Run serial model training. 
def _do_work(jdata: Dict[str, Any], run_opt: RunOptions, is_compress: bool = False):
    """Run serial model training.

    Parameters
    ----------
    jdata : Dict[str, Any]
        arguments read from json/yaml control file
    run_opt : RunOptions
        object with run configuration
    is_compress : bool
        indicates whether in model compress mode

    Raises
    ------
    RuntimeError
        if an unsupported modifier type is selected for the model
    """
    # make necessary checks
    assert "training" in jdata

    # init the model
    model = DPTrainer(jdata, run_opt=run_opt, is_compress=is_compress)
    rcut = model.model.get_rcut()
    type_map = model.model.get_type_map()
    if len(type_map) == 0:
        ipt_type_map = None
    else:
        ipt_type_map = type_map

    # init random seed of data systems
    seed = jdata["training"].get("seed", None)
    if seed is not None:
        # avoid the same batch sequence among workers
        seed += run_opt.my_rank
        seed = seed % (2 ** 32)
    dp_random.seed(seed)

    # setup data modifier
    modifier = get_modifier(jdata["model"].get("modifier", None))

    # decouple the training data from the model compress process
    train_data = None
    valid_data = None
    if not is_compress:
        # init data
        train_data = get_data(jdata["training"]["training_data"], rcut, ipt_type_map, modifier)
        train_data.print_summary("training")
        if jdata["training"].get("validation_data", None) is not None:
            valid_data = get_data(
                jdata["training"]["validation_data"],
                rcut,
                train_data.type_map,
                modifier,
            )
            valid_data.print_summary("validation")

    # get training info
    stop_batch = j_must_have(jdata["training"], "numb_steps")
    model.build(train_data, stop_batch)

    if not is_compress:
        # train the model with the provided systems in a cyclic way
        start_time = time.time()
        model.train(train_data, valid_data)
        end_time = time.time()
        log.info("finished training")
        log.info(f"wall time: {(end_time - start_time):.3f} s")
    else:
        model.save_compressed()
        log.info("finished compressing")


def get_data(jdata: Dict[str, Any], rcut, type_map, modifier):
    systems = j_must_have(jdata, "systems")
    if isinstance(systems, str):
        systems = expand_sys_str(systems)
    help_msg = 'Please check your setting for data systems'
    # check length of systems
    if len(systems) == 0:
        msg = 'cannot find a valid data system'
        log.fatal(msg)
        raise IOError(msg, help_msg)
    # roughly check that all items in systems are valid
    for ii in systems:
        ii = DPPath(ii)
        if not ii.is_dir():
            msg = f'dir {ii} is not a valid dir'
            log.fatal(msg)
            raise IOError(msg, help_msg)
        if not (ii / 'type.raw').is_file():
            msg = f'dir {ii} is not a valid data system dir'
            log.fatal(msg)
            raise IOError(msg, help_msg)

    batch_size = j_must_have(jdata, "batch_size")
    sys_probs = jdata.get("sys_probs", None)
    auto_prob = jdata.get("auto_prob", "prob_sys_size")

    data = DeepmdDataSystem(
        systems=systems,
        batch_size=batch_size,
        test_size=1,  # to satisfy the old api
        shuffle_test=True,  # to satisfy the old api
        rcut=rcut,
        type_map=type_map,
        modifier=modifier,
        trn_all_set=True,  # sample from all sets
        sys_probs=sys_probs,
        auto_prob_style=auto_prob,
    )
    data.add_dict(data_requirement)

    return data


def get_modifier(modi_data=None):
    modifier: Optional[DipoleChargeModifier]
    if modi_data is not None:
        if modi_data["type"] == "dipole_charge":
            modifier = DipoleChargeModifier(
                modi_data["model_name"],
                modi_data["model_charge_map"],
                modi_data["sys_charge_map"],
                modi_data["ewald_h"],
                modi_data["ewald_beta"],
            )
        else:
            raise RuntimeError("unknown modifier type " + str(modi_data["type"]))
    else:
        modifier = None
    return modifier


def get_rcut(jdata):
    descrpt_data = jdata['model']['descriptor']
    rcut_list = []
    if descrpt_data['type'] == 'hybrid':
        for ii in descrpt_data['list']:
            rcut_list.append(ii['rcut'])
    else:
        rcut_list.append(descrpt_data['rcut'])
    return max(rcut_list)

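
# Illustration of ``get_rcut`` on a hybrid descriptor. The jdata fragment below
# is a hypothetical example; for a hybrid descriptor the largest cutoff among
# the sub-descriptors is returned:
#
#     jdata = {"model": {"descriptor": {"type": "hybrid", "list": [
#         {"type": "se_e2_a", "rcut": 6.0},
#         {"type": "se_e2_r", "rcut": 8.0},
#     ]}}}
#     get_rcut(jdata)  # -> 8.0
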
def get_type_map(jdata):
    return jdata['model'].get('type_map', None)


def get_nbor_stat(jdata, rcut, one_type: bool = False):
    max_rcut = get_rcut(jdata)
    type_map = get_type_map(jdata)

    if type_map and len(type_map) == 0:
        type_map = None
    train_data = get_data(jdata["training"]["training_data"], max_rcut, type_map, None)
    train_data.get_batch()
    data_ntypes = train_data.get_ntypes()
    if type_map is not None:
        map_ntypes = len(type_map)
    else:
        map_ntypes = data_ntypes
    ntypes = max([map_ntypes, data_ntypes])

    neistat = NeighborStat(ntypes, rcut, one_type=one_type)

    min_nbor_dist, max_nbor_size = neistat.get_stat(train_data)

    # moved from trainer.py as duplicated
    # TODO: this is a simple fix but we should have a clear
    # architecture to call neighbor stat
    tf.constant(min_nbor_dist,
                name='train_attr/min_nbor_dist',
                dtype=GLOBAL_ENER_FLOAT_PRECISION)
    tf.constant(max_nbor_size,
                name='train_attr/max_nbor_size',
                dtype=tf.int32)
    return min_nbor_dist, max_nbor_size


def get_sel(jdata, rcut, one_type: bool = False):
    _, max_nbor_size = get_nbor_stat(jdata, rcut, one_type=one_type)
    return max_nbor_size


def get_min_nbor_dist(jdata, rcut):
    min_nbor_dist, _ = get_nbor_stat(jdata, rcut)
    return min_nbor_dist


def parse_auto_sel(sel):
    if not isinstance(sel, str):
        return False
    words = sel.split(':')
    if words[0] == 'auto':
        return True
    else:
        return False


def parse_auto_sel_ratio(sel):
    if not parse_auto_sel(sel):
        raise RuntimeError(f'invalid auto sel format {sel}')
    else:
        words = sel.split(':')
        if len(words) == 1:
            ratio = 1.1
        elif len(words) == 2:
            ratio = float(words[1])
        else:
            raise RuntimeError(f'invalid auto sel format {sel}')
        return ratio


def wrap_up_4(xx):
    return 4 * ((int(xx) + 3) // 4)


def update_one_sel(jdata, descriptor):
    if descriptor['type'] == 'loc_frame':
        return descriptor
    rcut = descriptor['rcut']
    tmp_sel = get_sel(jdata, rcut, one_type=descriptor['type'] in ('se_atten',))
    sel = descriptor['sel']
    if isinstance(sel, int):
        # convert to list and finally convert back to int
        sel = [sel]
    if parse_auto_sel(descriptor['sel']):
        ratio = parse_auto_sel_ratio(descriptor['sel'])
        descriptor['sel'] = sel = [int(wrap_up_4(ii * ratio)) for ii in tmp_sel]
    else:
        # sel is set by user
        for ii, (tt, dd) in enumerate(zip(tmp_sel, sel)):
            if dd and tt > dd:
                # we may skip warning for sel=0, where the user is likely
                # to exclude such type in the descriptor
                log.warning(
                    "sel of type %d is not enough! The expected value is "
                    "not less than %d, but you set it to %d. The accuracy"
                    " of your model may get worse." % (ii, tt, dd)
                )
    if descriptor['type'] in ('se_atten',):
        descriptor['sel'] = sel = sum(sel)
    return descriptor


def update_sel(jdata):
    log.info("Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)")
    descrpt_data = jdata['model']['descriptor']
    if descrpt_data['type'] == 'hybrid':
        for ii in range(len(descrpt_data['list'])):
            descrpt_data['list'][ii] = update_one_sel(jdata, descrpt_data['list'][ii])
    else:
        descrpt_data = update_one_sel(jdata, descrpt_data)
    jdata['model']['descriptor'] = descrpt_data
    return jdata
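
# Behaviour of the "auto" sel helpers above on hand-checked inputs (these are
# direct evaluations of the functions, not measured neighbor statistics):
#
#     parse_auto_sel("auto")            # -> True
#     parse_auto_sel(16)                # -> False (not a string)
#     parse_auto_sel_ratio("auto")      # -> 1.1 (default ratio)
#     parse_auto_sel_ratio("auto:1.5")  # -> 1.5
#     wrap_up_4(10)                     # -> 12, rounded up to a multiple of 4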