Commit 56d2b104 authored by zhangqha

BladeDISC DeePMD code update

parent 6b33aeb8
import logging
from typing import Callable, Tuple
import numpy as np
from deepmd.utils.errors import OutOfMemoryError
class AutoBatchSize:
"""This class allows DeePMD-kit to automatically decide the maximum
batch size that will not cause an OOM error.
Notes
-----
    We assume all OOM errors will raise :class:`OutOfMemoryError`.
Parameters
----------
initial_batch_size : int, default: 1024
initial batch size (number of total atoms)
factor : float, default: 2.
        factor by which the batch size is increased or decreased
Attributes
----------
current_batch_size : int
current batch size (number of total atoms)
maximum_working_batch_size : int
maximum working batch size
minimal_not_working_batch_size : int
minimal not working batch size
"""
def __init__(self, initial_batch_size: int = 1024, factor: float = 2.) -> None:
# See also PyTorchLightning/pytorch-lightning#1638
# TODO: discuss a proper initial batch size
self.current_batch_size = initial_batch_size
self.maximum_working_batch_size = 0
self.minimal_not_working_batch_size = 2**31
self.factor = factor
def execute(self, callable: Callable, start_index: int, natoms: int) -> Tuple[int, tuple]:
"""Excuate a method with given batch size.
Parameters
----------
callable : Callable
            The method should accept the batch size and start_index as parameters,
            and return the executed batch size and the data.
start_index : int
start index
natoms : int
            the number of atoms in each frame
Returns
-------
int
            the executed batch size (number of frames actually processed)
tuple
result from callable, None if failing to execute
Raises
------
OutOfMemoryError
OOM when batch size is 1
"""
try:
n_batch, result = callable(max(self.current_batch_size // natoms, 1), start_index)
except OutOfMemoryError as e:
# TODO: it's very slow to catch OOM error; I don't know what TF is doing here
# but luckily we only need to catch once
self.minimal_not_working_batch_size = min(self.minimal_not_working_batch_size, self.current_batch_size)
if self.maximum_working_batch_size >= self.minimal_not_working_batch_size:
self.maximum_working_batch_size = int(self.minimal_not_working_batch_size / self.factor)
if self.minimal_not_working_batch_size <= natoms:
raise OutOfMemoryError("The callable still throws an out-of-memory (OOM) error even when batch size is 1!") from e
# adjust the next batch size
self._adjust_batch_size(1./self.factor)
return 0, None
else:
n_tot = n_batch * natoms
self.maximum_working_batch_size = max(self.maximum_working_batch_size, n_tot)
# adjust the next batch size
if n_tot + natoms > self.current_batch_size and self.current_batch_size * self.factor < self.minimal_not_working_batch_size:
self._adjust_batch_size(self.factor)
return n_batch, result
def _adjust_batch_size(self, factor: float):
old_batch_size = self.current_batch_size
self.current_batch_size = int(self.current_batch_size * factor)
        logging.info("Adjust batch size from %d to %d", old_batch_size, self.current_batch_size)
def execute_all(self, callable: Callable, total_size: int, natoms: int, *args, **kwargs) -> Tuple[np.ndarray]:
"""Excuate a method with all given data.
Parameters
----------
callable : Callable
            The method should accept *args and **kwargs as input and return an array (or a tuple of arrays) whose first axis is the batch axis.
total_size : int
Total size
natoms : int
The number of atoms
        *args
            arguments passed to the callable; np.ndarray arguments with more than
            one dimension are sliced along the first (batch) axis
        **kwargs
            keyword arguments passed to the callable, sliced in the same way
"""
def execute_with_batch_size(batch_size: int, start_index: int) -> Tuple[int, Tuple[np.ndarray]]:
end_index = start_index + batch_size
end_index = min(end_index, total_size)
return (end_index - start_index), callable(
*[(vv[start_index:end_index] if isinstance(vv, np.ndarray) and vv.ndim > 1 else vv) for vv in args],
**{kk: (vv[start_index:end_index] if isinstance(vv, np.ndarray) and vv.ndim > 1 else vv) for kk, vv in kwargs.items()},
)
index = 0
results = []
while index < total_size:
n_batch, result = self.execute(execute_with_batch_size, index, natoms)
if not isinstance(result, tuple):
result = (result,)
index += n_batch
            if n_batch:
                # reshape() returns a new array, so keep the reshaped copies
                result = tuple(rr.reshape((n_batch, -1)) for rr in result)
                results.append(result)
r = tuple([np.concatenate(r, axis=0) for r in zip(*results)])
if len(r) == 1:
# avoid returning tuple if callable doesn't return tuple
r = r[0]
return r
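# --- Usage sketch (not part of the original code) ---
# A minimal, hypothetical example of driving a batched evaluation with
# AutoBatchSize.execute_all. `toy_model` stands in for any callable that maps a
# (nframes, natoms * 3) coordinate array to per-frame results; only the
# batching behaviour comes from the class above.
if __name__ == "__main__":
    def toy_model(coords: np.ndarray) -> np.ndarray:
        # pretend per-frame "energies" with shape (nframes, 1)
        return coords.sum(axis=1, keepdims=True)
    natoms = 64
    nframes = 1000
    coords = np.random.rand(nframes, natoms * 3)
    auto_batch = AutoBatchSize(initial_batch_size=1024)
    # execute_all slices `coords` along its first axis, shrinks the batch size
    # on OutOfMemoryError, grows it otherwise, and concatenates the results.
    energies = auto_batch.execute_all(toy_model, nframes, natoms, coords)
    assert energies.shape == (nframes, 1)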
"""Module providing compatibility between `0.x.x` and `1.x.x` input versions."""
import json
import warnings
from pathlib import Path
from typing import Any, Dict, Optional, Sequence, Union
import numpy as np
from deepmd.common import j_must_have
def convert_input_v0_v1(
jdata: Dict[str, Any], warning: bool = True, dump: Optional[Union[str, Path]] = None
) -> Dict[str, Any]:
"""Convert input from v0 format to v1.
Parameters
----------
jdata : Dict[str, Any]
loaded json/yaml file
warning : bool, optional
whether to show deprecation warning, by default True
dump : Optional[Union[str, Path]], optional
whether to dump converted file, by default None
Returns
-------
Dict[str, Any]
converted output
"""
output = {}
output["model"] = _model(jdata, jdata["use_smooth"])
output["learning_rate"] = _learning_rate(jdata)
output["loss"] = _loss(jdata)
output["training"] = _training(jdata)
if warning:
_warning_input_v0_v1(dump)
if dump is not None:
with open(dump, "w") as fp:
json.dump(output, fp, indent=4)
return output
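# Sketch of the regrouping performed by convert_input_v0_v1 (key names are taken
# from the helper functions below; the grouping is illustrative, not exhaustive):
#     flat v0 keys                                  v1 section
#     ------------                                  ----------
#     use_smooth, sel_a, rcut, filter_neuron,  ->   model["descriptor"]
#     axis_neuron, filter_resnet_dt
#     fitting_neuron, resnet_dt                ->   model["fitting_net"]
#     start_lr, decay_steps, decay_rate        ->   learning_rate
#     start_pref_e/f/v, limit_pref_e/f/v       ->   loss
#     systems, set_prefix, batch_size,         ->   training
#     stop_batch, disp_freq, save_freq, ...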
def _warning_input_v0_v1(fname: Optional[Union[str, Path]]):
msg = "It seems that you are using a deepmd-kit input of version 0.x.x, " \
"which is deprecated. we have converted the input to >2.0.0 compatible"
if fname is not None:
msg += f", and output it to file {fname}"
warnings.warn(msg)
def _model(jdata: Dict[str, Any], smooth: bool) -> Dict[str, Dict[str, Any]]:
"""Convert data to v1 input for non-smooth model.
Parameters
----------
jdata : Dict[str, Any]
parsed input json/yaml data
smooth : bool
whether to use smooth or non-smooth descriptor version
Returns
-------
Dict[str, Dict[str, Any]]
dictionary with model input parameters and sub-dictionaries for descriptor and
fitting net
"""
model = {}
model["descriptor"] = (
_smth_descriptor(jdata) if smooth else _nonsmth_descriptor(jdata)
)
model["fitting_net"] = _fitting_net(jdata)
return model
def _nonsmth_descriptor(jdata: Dict[str, Any]) -> Dict[str, Any]:
"""Convert data to v1 input for non-smooth descriptor.
Parameters
----------
jdata : Dict[str, Any]
parsed input json/yaml data
Returns
-------
Dict[str, Any]
dict with descriptor parameters
"""
descriptor = {}
descriptor["type"] = "loc_frame"
_jcopy(jdata, descriptor, ("sel_a", "sel_r", "rcut", "axis_rule"))
return descriptor
def _smth_descriptor(jdata: Dict[str, Any]) -> Dict[str, Any]:
"""Convert data to v1 input for smooth descriptor.
Parameters
----------
jdata : Dict[str, Any]
parsed input json/yaml data
Returns
-------
Dict[str, Any]
dict with descriptor parameters
"""
descriptor = {}
seed = jdata.get("seed", None)
if seed is not None:
descriptor["seed"] = seed
descriptor["type"] = "se_a"
descriptor["sel"] = jdata["sel_a"]
_jcopy(jdata, descriptor, ("rcut", ))
descriptor["rcut_smth"] = jdata.get("rcut_smth", descriptor["rcut"])
descriptor["neuron"] = j_must_have(jdata, "filter_neuron")
descriptor["axis_neuron"] = j_must_have(jdata, "axis_neuron", ["n_axis_neuron"])
descriptor["resnet_dt"] = False
if "resnet_dt" in jdata:
descriptor["resnet_dt"] = jdata["filter_resnet_dt"]
return descriptor
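# Summary of the key mapping performed by _smth_descriptor above
# (v0 key -> v1 descriptor field):
#     sel_a                        -> sel
#     rcut                         -> rcut
#     rcut_smth (optional)         -> rcut_smth (defaults to rcut)
#     filter_neuron                -> neuron
#     axis_neuron / n_axis_neuron  -> axis_neuron
#     filter_resnet_dt (optional)  -> resnet_dt (defaults to False)
#     seed (optional)              -> seed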
def _fitting_net(jdata: Dict[str, Any]) -> Dict[str, Any]:
"""Convert data to v1 input for fitting net.
Parameters
----------
jdata : Dict[str, Any]
parsed input json/yaml data
Returns
-------
Dict[str, Any]
dict with fitting net parameters
"""
fitting_net = {}
seed = jdata.get("seed", None)
if seed is not None:
fitting_net["seed"] = seed
fitting_net["neuron"] = j_must_have(jdata, "fitting_neuron", ["n_neuron"])
fitting_net["resnet_dt"] = True
if "resnet_dt" in jdata:
fitting_net["resnet_dt"] = jdata["resnet_dt"]
if "fitting_resnet_dt" in jdata:
fitting_net["resnet_dt"] = jdata["fitting_resnet_dt"]
return fitting_net
def _learning_rate(jdata: Dict[str, Any]) -> Dict[str, Any]:
"""Convert data to v1 input for learning rate section.
Parameters
----------
jdata : Dict[str, Any]
parsed input json/yaml data
Returns
-------
Dict[str, Any]
dict with learning rate parameters
"""
learning_rate = {}
learning_rate["type"] = "exp"
_jcopy(jdata, learning_rate, ("decay_steps", "decay_rate", "start_lr"))
return learning_rate
def _loss(jdata: Dict[str, Any]) -> Dict[str, Any]:
"""Convert data to v1 input for loss function.
Parameters
----------
jdata : Dict[str, Any]
parsed input json/yaml data
Returns
-------
Dict[str, Any]
dict with loss function parameters
"""
loss: Dict[str, Any] = {}
_jcopy(
jdata,
loss,
(
"start_pref_e",
"limit_pref_e",
"start_pref_f",
"limit_pref_f",
"start_pref_v",
"limit_pref_v",
),
)
if "start_pref_ae" in jdata:
loss["start_pref_ae"] = jdata["start_pref_ae"]
if "limit_pref_ae" in jdata:
loss["limit_pref_ae"] = jdata["limit_pref_ae"]
return loss
def _training(jdata: Dict[str, Any]) -> Dict[str, Any]:
"""Convert data to v1 input for training.
Parameters
----------
jdata : Dict[str, Any]
parsed input json/yaml data
Returns
-------
Dict[str, Any]
dict with training parameters
"""
training = {}
seed = jdata.get("seed", None)
if seed is not None:
training["seed"] = seed
_jcopy(jdata, training, ("systems", "set_prefix", "stop_batch", "batch_size"))
training["disp_file"] = "lcurve.out"
if "disp_file" in jdata:
training["disp_file"] = jdata["disp_file"]
training["disp_freq"] = j_must_have(jdata, "disp_freq")
training["numb_test"] = j_must_have(jdata, "numb_test")
training["save_freq"] = j_must_have(jdata, "save_freq")
training["save_ckpt"] = j_must_have(jdata, "save_ckpt")
training["disp_training"] = j_must_have(jdata, "disp_training")
training["time_training"] = j_must_have(jdata, "time_training")
if "profiling" in jdata:
training["profiling"] = jdata["profiling"]
if training["profiling"]:
training["profiling_file"] = j_must_have(jdata, "profiling_file")
return training
def _jcopy(src: Dict[str, Any], dst: Dict[str, Any], keys: Sequence[str]):
"""Copy specified keys from one dict to another.
Parameters
----------
src : Dict[str, Any]
source dictionary
dst : Dict[str, Any]
destination dictionary, will be modified in place
keys : Sequence[str]
list of keys to copy
"""
for k in keys:
dst[k] = src[k]
def remove_decay_rate(jdata: Dict[str, Any]):
"""convert decay_rate to stop_lr.
Parameters
----------
jdata: Dict[str, Any]
input data
"""
lr = jdata["learning_rate"]
if "decay_rate" in lr:
decay_rate = lr["decay_rate"]
start_lr = lr["start_lr"]
stop_step = jdata["training"]["stop_batch"]
decay_steps = lr["decay_steps"]
stop_lr = np.exp(np.log(decay_rate) * (stop_step / decay_steps)) * start_lr
lr["stop_lr"] = stop_lr
lr.pop("decay_rate")
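# A small numerical sketch of the conversion above (all values are made up):
# stop_lr = start_lr * decay_rate ** (stop_step / decay_steps).
if __name__ == "__main__":
    lr_example = {
        "learning_rate": {"start_lr": 1.0e-3, "decay_rate": 0.95, "decay_steps": 5000},
        "training": {"stop_batch": 1000000},
    }
    remove_decay_rate(lr_example)
    # 0.95 ** (1000000 / 5000) = 0.95 ** 200 ~ 3.5e-5, so stop_lr ~ 3.5e-8
    assert "decay_rate" not in lr_example["learning_rate"]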
def convert_input_v1_v2(jdata: Dict[str, Any],
warning: bool = True,
dump: Optional[Union[str, Path]] = None) -> Dict[str, Any]:
    """Convert input from v1 format to v2.
    Parameters
    ----------
    jdata : Dict[str, Any]
        loaded json/yaml file
    warning : bool, optional
        whether to show deprecation warning, by default True
    dump : Optional[Union[str, Path]], optional
        whether to dump converted file, by default None
    Returns
    -------
    Dict[str, Any]
        converted output
    """
    tr_cfg = jdata["training"]
tr_data_keys = {
"systems",
"set_prefix",
"batch_size",
"sys_prob",
"auto_prob",
# alias included
"sys_weights",
"auto_prob_style"
}
tr_data_cfg = {k: v for k, v in tr_cfg.items() if k in tr_data_keys}
new_tr_cfg = {k: v for k, v in tr_cfg.items() if k not in tr_data_keys}
new_tr_cfg["training_data"] = tr_data_cfg
jdata["training"] = new_tr_cfg
# remove deprecated arguments
remove_decay_rate(jdata)
if warning:
_warning_input_v1_v2(dump)
if dump is not None:
with open(dump, "w") as fp:
json.dump(jdata, fp, indent=4)
return jdata
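# Sketch of what convert_input_v1_v2 does (the input below is hypothetical and
# deliberately minimal): data-related keys of the "training" section are moved
# into a nested "training_data" dict and "decay_rate" is folded into "stop_lr".
if __name__ == "__main__":
    v1_example = {
        "model": {"descriptor": {"type": "se_a"}, "fitting_net": {}},
        "learning_rate": {"type": "exp", "start_lr": 1.0e-3,
                          "decay_rate": 0.95, "decay_steps": 5000},
        "training": {"systems": ["system_0"], "set_prefix": "set",
                     "batch_size": 32, "stop_batch": 1000000},
    }
    converted = convert_input_v1_v2(v1_example, warning=False)
    assert "training_data" in converted["training"]
    assert "stop_lr" in converted["learning_rate"]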
def _warning_input_v1_v2(fname: Optional[Union[str, Path]]):
msg = "It seems that you are using a deepmd-kit input of version 1.x.x, " \
"which is deprecated. we have converted the input to >2.0.0 compatible"
if fname is not None:
msg += f", and output it to file {fname}"
warnings.warn(msg)
def deprecate_numb_test(jdata: Dict[str, Any],
warning: bool = True,
dump: Optional[Union[str, Path]] = None) -> Dict[str, Any]:
"""Deprecate `numb_test` since v2.1. It has taken no effect since v2.0.
See `#1243 <https://github.com/deepmodeling/deepmd-kit/discussions/1243>`_.
Parameters
----------
jdata : Dict[str, Any]
loaded json/yaml file
warning : bool, optional
whether to show deprecation warning, by default True
dump : Optional[Union[str, Path]], optional
whether to dump converted file, by default None
Returns
-------
Dict[str, Any]
converted output
"""
try:
jdata.get("training", {}).pop("numb_test")
except KeyError:
pass
else:
if warning:
warnings.warn(
"The argument training->numb_test has been deprecated since v2.0.0. "
"Use training->validation_data->batch_size instead."
)
if dump is not None:
with open(dump, "w") as fp:
json.dump(jdata, fp, indent=4)
return jdata
def update_deepmd_input(jdata: Dict[str, Any],
warning: bool = True,
dump: Optional[Union[str, Path]] = None) -> Dict[str, Any]:
    """Upgrade the input to a >2.0.0 compatible format, converting deprecated v0/v1 inputs as needed."""
    def is_deepmd_v0_input(jdata):
return "model" not in jdata.keys()
def is_deepmd_v1_input(jdata):
return "systems" in j_must_have(jdata, "training").keys()
if is_deepmd_v0_input(jdata):
jdata = convert_input_v0_v1(jdata, warning, None)
jdata = convert_input_v1_v2(jdata, False, None)
jdata = deprecate_numb_test(jdata, False, dump)
elif is_deepmd_v1_input(jdata):
jdata = convert_input_v1_v2(jdata, warning, None)
jdata = deprecate_numb_test(jdata, False, dump)
else:
jdata = deprecate_numb_test(jdata, warning, dump)
return jdata
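# --- Usage sketch (hypothetical, minimal input; not part of the original code) ---
# update_deepmd_input dispatches on the input version: v0 inputs (no "model" key)
# run the full v0 -> v1 -> v2 chain, v1 inputs (a flat "training" section that
# still contains "systems") run the v1 -> v2 conversion, and anything newer only
# has the deprecated "numb_test" argument removed.
if __name__ == "__main__":
    already_v2 = {
        "model": {"descriptor": {"type": "se_a"}, "fitting_net": {}},
        "training": {"training_data": {"systems": ["system_0"], "batch_size": 32},
                     "numb_test": 10},
    }
    updated = update_deepmd_input(already_v2, warning=True)
    assert "numb_test" not in updated["training"]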