Unverified Commit 73866b01 authored by zcxzcx1's avatar zcxzcx1 Committed by GitHub

Add files via upload

parent ca86f720
from importlib.metadata import version
from packaging.version import Version
__version__ = version('sevenn')
from e3nn import __version__ as e3nn_ver
if Version(e3nn_ver) < Version('0.5.0'):
raise ValueError(
'The e3nn version MUST be 0.5.0 or later due to changes in CG coefficient '
'convention.'
)
import os
from enum import Enum
from typing import Dict
import torch
import sevenn._keys as KEY
from sevenn.nn.activation import ShiftedSoftPlus
NUM_UNIV_ELEMENT = 119 # Z = 0 ~ 118
IMPLEMENTED_RADIAL_BASIS = ['bessel']
IMPLEMENTED_CUTOFF_FUNCTION = ['poly_cut', 'XPLOR']
# TODO: support None. This became difficult because of parallel model
IMPLEMENTED_SELF_CONNECTION_TYPE = ['nequip', 'linear']
IMPLEMENTED_INTERACTION_TYPE = ['nequip']
IMPLEMENTED_SHIFT = ['per_atom_energy_mean', 'elemwise_reference_energies']
IMPLEMENTED_SCALE = ['force_rms', 'per_atom_energy_std', 'elemwise_force_rms']
SUPPORTING_METRICS = ['RMSE', 'ComponentRMSE', 'MAE', 'Loss']
SUPPORTING_ERROR_TYPES = [
'TotalEnergy',
'Energy',
'Force',
'Stress',
'Stress_GPa',
'TotalLoss',
]
IMPLEMENTED_MODEL = ['E3_equivariant_model']
# string input to real torch function
ACTIVATION = {
'relu': torch.nn.functional.relu,
'silu': torch.nn.functional.silu,
'tanh': torch.tanh,
'abs': torch.abs,
'ssp': ShiftedSoftPlus,
'sigmoid': torch.sigmoid,
'elu': torch.nn.functional.elu,
}
ACTIVATION_FOR_EVEN = {
'ssp': ShiftedSoftPlus,
'silu': torch.nn.functional.silu,
}
ACTIVATION_FOR_ODD = {'tanh': torch.tanh, 'abs': torch.abs}
ACTIVATION_DICT = {'e': ACTIVATION_FOR_EVEN, 'o': ACTIVATION_FOR_ODD}
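# Illustrative sketch (not part of the module): resolving the parity-keyed
# activation strings of a config into torch callables, as model_build does.
# >>> act_cfg = {'e': 'silu', 'o': 'tanh'}
# >>> acts = {p: ACTIVATION_DICT[p][name] for p, name in act_cfg.items()}
# >>> acts['e'] is torch.nn.functional.silu
# True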
_prefix = os.path.abspath(f'{os.path.dirname(__file__)}/pretrained_potentials')
SEVENNET_0_11Jul2024 = f'{_prefix}/SevenNet_0__11Jul2024/checkpoint_sevennet_0.pth'
SEVENNET_0_22May2024 = f'{_prefix}/SevenNet_0__22May2024/checkpoint_sevennet_0.pth'
SEVENNET_l3i5 = f'{_prefix}/SevenNet_l3i5/checkpoint_l3i5.pth'
SEVENNET_MF_0 = f'{_prefix}/SevenNet_MF_0/checkpoint_sevennet_mf_0.pth'
SEVENNET_MF_ompa = f'{_prefix}/SevenNet_MF_ompa/checkpoint_sevennet_mf_ompa.pth'
SEVENNET_omat = f'{_prefix}/SevenNet_omat/checkpoint_sevennet_omat.pth'
_git_prefix = 'https://github.com/MDIL-SNU/SevenNet/releases/download'
CHECKPOINT_DOWNLOAD_LINKS = {
SEVENNET_MF_ompa: f'{_git_prefix}/v0.11.0.cp/checkpoint_sevennet_mf_ompa.pth',
SEVENNET_omat: f'{_git_prefix}/v0.11.0.cp/checkpoint_sevennet_omat.pth',
}
# to avoid TorchScript compiling torch_geometric.data
AtomGraphDataType = Dict[str, torch.Tensor]
class LossType(Enum): # only used for train_v1, do not use it afterwards
ENERGY = 'energy' # eV or eV/atom
FORCE = 'force' # eV/A
STRESS = 'stress' # kB
def error_record_condition(x):
if type(x) is not list:
return False
for v in x:
if type(v) is not list or len(v) != 2:
return False
if v[0] not in SUPPORTING_ERROR_TYPES:
return False
if v[0] == 'TotalLoss':
continue
if v[1] not in SUPPORTING_METRICS:
return False
return True
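# Example of an error_record value accepted above; 'TotalLoss' skips the
# metric check, so its second element (conventionally 'None') is free-form.
# >>> error_record_condition(
# ...     [['Energy', 'RMSE'], ['Force', 'MAE'], ['TotalLoss', 'None']]
# ... )
# True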
DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG = {
KEY.CUTOFF: 4.5,
KEY.NODE_FEATURE_MULTIPLICITY: 32,
KEY.IRREPS_MANUAL: False,
KEY.LMAX: 1,
KEY.LMAX_EDGE: -1, # -1 means lmax_edge = lmax
KEY.LMAX_NODE: -1, # -1 means lmax_node = lmax
KEY.IS_PARITY: True,
KEY.NUM_CONVOLUTION: 3,
KEY.RADIAL_BASIS: {
KEY.RADIAL_BASIS_NAME: 'bessel',
},
KEY.CUTOFF_FUNCTION: {
KEY.CUTOFF_FUNCTION_NAME: 'poly_cut',
},
KEY.ACTIVATION_RADIAL: 'silu',
KEY.ACTIVATION_SCARLAR: {'e': 'silu', 'o': 'tanh'},
KEY.ACTIVATION_GATE: {'e': 'silu', 'o': 'tanh'},
KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS: [64, 64],
# KEY.AVG_NUM_NEIGH: True, # deprecated
# KEY.TRAIN_AVG_NUM_NEIGH: False, # deprecated
KEY.CONV_DENOMINATOR: 'avg_num_neigh',
KEY.TRAIN_DENOMINTAOR: False,
KEY.TRAIN_SHIFT_SCALE: False,
# KEY.OPTIMIZE_BY_REDUCE: True, # deprecated, always True
KEY.USE_BIAS_IN_LINEAR: False,
KEY.USE_MODAL_NODE_EMBEDDING: False,
KEY.USE_MODAL_SELF_INTER_INTRO: False,
KEY.USE_MODAL_SELF_INTER_OUTRO: False,
KEY.USE_MODAL_OUTPUT_BLOCK: False,
KEY.READOUT_AS_FCN: False,
# Applied if readout_as_fcn is True
KEY.READOUT_FCN_HIDDEN_NEURONS: [30, 30],
KEY.READOUT_FCN_ACTIVATION: 'relu',
KEY.SELF_CONNECTION_TYPE: 'nequip',
KEY.INTERACTION_TYPE: 'nequip',
KEY._NORMALIZE_SPH: True,
KEY.CUEQUIVARIANCE_CONFIG: {},
}
# Each condition reads: "if provided, the value must be of type ..."
MODEL_CONFIG_CONDITION = {
KEY.NODE_FEATURE_MULTIPLICITY: int,
KEY.LMAX: int,
KEY.LMAX_EDGE: int,
KEY.LMAX_NODE: int,
KEY.IS_PARITY: bool,
KEY.RADIAL_BASIS: {
KEY.RADIAL_BASIS_NAME: lambda x: x in IMPLEMENTED_RADIAL_BASIS,
},
KEY.CUTOFF_FUNCTION: {
KEY.CUTOFF_FUNCTION_NAME: lambda x: x in IMPLEMENTED_CUTOFF_FUNCTION,
},
KEY.CUTOFF: float,
KEY.NUM_CONVOLUTION: int,
KEY.CONV_DENOMINATOR: lambda x: isinstance(x, float)
or x
in [
'avg_num_neigh',
'sqrt_avg_num_neigh',
],
KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS: list,
KEY.TRAIN_SHIFT_SCALE: bool,
KEY.TRAIN_DENOMINTAOR: bool,
KEY.USE_BIAS_IN_LINEAR: bool,
KEY.USE_MODAL_NODE_EMBEDDING: bool,
KEY.USE_MODAL_SELF_INTER_INTRO: bool,
KEY.USE_MODAL_SELF_INTER_OUTRO: bool,
KEY.USE_MODAL_OUTPUT_BLOCK: bool,
KEY.READOUT_AS_FCN: bool,
KEY.READOUT_FCN_HIDDEN_NEURONS: list,
KEY.READOUT_FCN_ACTIVATION: str,
KEY.ACTIVATION_RADIAL: str,
KEY.SELF_CONNECTION_TYPE: lambda x: (
x in IMPLEMENTED_SELF_CONNECTION_TYPE
or (
isinstance(x, list)
and all(sc in IMPLEMENTED_SELF_CONNECTION_TYPE for sc in x)
)
),
KEY.INTERACTION_TYPE: lambda x: x in IMPLEMENTED_INTERACTION_TYPE,
KEY._NORMALIZE_SPH: bool,
KEY.CUEQUIVARIANCE_CONFIG: dict,
}
def model_defaults(config):
defaults = dict(DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG)  # copy; pops below must not mutate the module-level dict
if KEY.READOUT_AS_FCN not in config:
config[KEY.READOUT_AS_FCN] = defaults[KEY.READOUT_AS_FCN]
if config[KEY.READOUT_AS_FCN] is False:
defaults.pop(KEY.READOUT_FCN_ACTIVATION, None)
defaults.pop(KEY.READOUT_FCN_HIDDEN_NEURONS, None)
return defaults
DEFAULT_DATA_CONFIG = {
KEY.DTYPE: 'single',
KEY.DATA_FORMAT: 'ase',
KEY.DATA_FORMAT_ARGS: {},
KEY.SAVE_DATASET: False,
KEY.SAVE_BY_LABEL: False,
KEY.SAVE_BY_TRAIN_VALID: False,
KEY.RATIO: 0.0,
KEY.BATCH_SIZE: 6,
KEY.PREPROCESS_NUM_CORES: 1,
KEY.COMPUTE_STATISTICS: True,
KEY.DATASET_TYPE: 'graph',
# KEY.USE_SPECIES_WISE_SHIFT_SCALE: False,
KEY.USE_MODAL_WISE_SHIFT: False,
KEY.USE_MODAL_WISE_SCALE: False,
KEY.SHIFT: 'per_atom_energy_mean',
KEY.SCALE: 'force_rms',
# KEY.DATA_SHUFFLE: True,
# KEY.DATA_WEIGHT: False,
# KEY.DATA_MODALITY: False,
}
DATA_CONFIG_CONDITION = {
KEY.DTYPE: str,
KEY.DATA_FORMAT: str,
KEY.DATA_FORMAT_ARGS: dict,
KEY.SAVE_DATASET: str,
KEY.SAVE_BY_LABEL: bool,
KEY.SAVE_BY_TRAIN_VALID: bool,
KEY.RATIO: float,
KEY.BATCH_SIZE: int,
KEY.PREPROCESS_NUM_CORES: int,
KEY.DATASET_TYPE: lambda x: x in ['graph', 'atoms'],
# KEY.USE_SPECIES_WISE_SHIFT_SCALE: bool,
KEY.SHIFT: lambda x: type(x) in [float, list] or x in IMPLEMENTED_SHIFT,
KEY.SCALE: lambda x: type(x) in [float, list] or x in IMPLEMENTED_SCALE,
KEY.USE_MODAL_WISE_SHIFT: bool,
KEY.USE_MODAL_WISE_SCALE: bool,
# KEY.DATA_SHUFFLE: bool,
KEY.COMPUTE_STATISTICS: bool,
# KEY.DATA_WEIGHT: bool,
# KEY.DATA_MODALITY: bool,
}
def data_defaults(config):
defaults = dict(DEFAULT_DATA_CONFIG)  # copy, as in model_defaults
if KEY.LOAD_VALIDSET in config:
defaults.pop(KEY.RATIO, None)
return defaults
DEFAULT_TRAINING_CONFIG = {
KEY.RANDOM_SEED: 1,
KEY.EPOCH: 300,
KEY.LOSS: 'mse',
KEY.LOSS_PARAM: {},
KEY.OPTIMIZER: 'adam',
KEY.OPTIM_PARAM: {},
KEY.SCHEDULER: 'exponentiallr',
KEY.SCHEDULER_PARAM: {},
KEY.FORCE_WEIGHT: 0.1,
KEY.STRESS_WEIGHT: 1e-6, # SIMPLE-NN default
KEY.PER_EPOCH: 5,
# KEY.USE_TESTSET: False,
KEY.CONTINUE: {
KEY.CHECKPOINT: False,
KEY.RESET_OPTIMIZER: False,
KEY.RESET_SCHEDULER: False,
KEY.RESET_EPOCH: False,
KEY.USE_STATISTIC_VALUES_OF_CHECKPOINT: True,
KEY.USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY: True,
},
# KEY.DEFAULT_MODAL: 'common',
KEY.CSV_LOG: 'log.csv',
KEY.NUM_WORKERS: 0,
KEY.IS_TRAIN_STRESS: True,
KEY.TRAIN_SHUFFLE: True,
KEY.ERROR_RECORD: [
['Energy', 'RMSE'],
['Force', 'RMSE'],
['Stress', 'RMSE'],
['TotalLoss', 'None'],
],
KEY.BEST_METRIC: 'TotalLoss',
KEY.USE_WEIGHT: False,
KEY.USE_MODALITY: False,
}
TRAINING_CONFIG_CONDITION = {
KEY.RANDOM_SEED: int,
KEY.EPOCH: int,
KEY.FORCE_WEIGHT: float,
KEY.STRESS_WEIGHT: float,
KEY.USE_TESTSET: None, # Not used
KEY.NUM_WORKERS: int,
KEY.PER_EPOCH: int,
KEY.CONTINUE: {
KEY.CHECKPOINT: str,
KEY.RESET_OPTIMIZER: bool,
KEY.RESET_SCHEDULER: bool,
KEY.RESET_EPOCH: bool,
KEY.USE_STATISTIC_VALUES_OF_CHECKPOINT: bool,
KEY.USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY: bool,
},
KEY.DEFAULT_MODAL: str,
KEY.IS_TRAIN_STRESS: bool,
KEY.TRAIN_SHUFFLE: bool,
KEY.ERROR_RECORD: error_record_condition,
KEY.BEST_METRIC: str,
KEY.CSV_LOG: str,
KEY.USE_MODALITY: bool,
KEY.USE_WEIGHT: bool,
}
def train_defaults(config):
defaults = dict(DEFAULT_TRAINING_CONFIG)  # copy, as in model_defaults
if KEY.IS_TRAIN_STRESS not in config:
config[KEY.IS_TRAIN_STRESS] = defaults[KEY.IS_TRAIN_STRESS]
if not config[KEY.IS_TRAIN_STRESS]:
defaults.pop(KEY.STRESS_WEIGHT, None)
return defaults
"""
How do I add a new feature?
1. Add the new key to this file.
2. Add the new key to _const.py.
2.1. If the type of the input is consistent,
write an adequate condition and default in _const.py.
2.2. If the type of the input is not consistent,
add your own input validation code to
parse_input.py.
"""
from typing import Final
# see https://github.com/pytorch/pytorch/issues/52312
# for background
# ~~ keys ~~ #
# PyG : primitive key of torch_geometric.data.Data type
# ==================================================#
# ~~~~~~~~~~~~~~~~~ KEY for data ~~~~~~~~~~~~~~~~~~ #
# ==================================================#
# some raw properties of graph
ATOMIC_NUMBERS: Final[str] = 'atomic_numbers' # (N)
POS: Final[str] = 'pos' # (N, 3) PyG
CELL: Final[str] = 'cell_lattice_vectors' # (3, 3)
CELL_SHIFT: Final[str] = 'pbc_shift' # (N, 3)
CELL_VOLUME: Final[str] = 'cell_volume'
EDGE_VEC: Final[str] = 'edge_vec' # (N_edge, 3)
EDGE_LENGTH: Final[str] = 'edge_length' # (N_edge, 1)
# some primary data of graph
EDGE_IDX: Final[str] = 'edge_index' # (2, N_edge) PyG
ATOM_TYPE: Final[str] = 'atom_type' # (N) one-hot index of nodes
NODE_FEATURE: Final[str] = 'x' # (N, ?) PyG
NODE_FEATURE_GHOST: Final[str] = 'x_ghost'
NODE_ATTR: Final[str] = 'node_attr' # (N, N_species) from one_hot
MODAL_ATTR: Final[str] = (
'modal_attr' # (1, N_modalities) for handling multi-modal
)
MODAL_TYPE: Final[str] = 'modal_type' # (1) one-hot index of modal
EDGE_ATTR: Final[str] = 'edge_attr' # (from spherical harmonics)
EDGE_EMBEDDING: Final[str] = 'edge_embedding' # (from edge embedding)
# inputs of loss function
ENERGY: Final[str] = 'total_energy' # (1)
FORCE: Final[str] = 'force_of_atoms' # (N, 3)
STRESS: Final[str] = 'stress' # (6)
# This is for training, per atom scale.
SCALED_ENERGY: Final[str] = 'scaled_total_energy'
# general outputs of models
SCALED_ATOMIC_ENERGY: Final[str] = 'scaled_atomic_energy'
ATOMIC_ENERGY: Final[str] = 'atomic_energy'
PRED_TOTAL_ENERGY: Final[str] = 'inferred_total_energy'
PRED_PER_ATOM_ENERGY: Final[str] = 'inferred_per_atom_energy'
PER_ATOM_ENERGY: Final[str] = 'per_atom_energy'
PRED_FORCE: Final[str] = 'inferred_force'
SCALED_FORCE: Final[str] = 'scaled_force'
PRED_STRESS: Final[str] = 'inferred_stress'
SCALED_STRESS: Final[str] = 'scaled_stress'
# very general data property for AtomGraphData
NUM_ATOMS: Final[str] = 'num_atoms' # int
NUM_GHOSTS: Final[str] = 'num_ghosts'
NLOCAL: Final[str] = 'nlocal' # only for lammps parallel, must be on cpu
USER_LABEL: Final[str] = 'user_label'
DATA_WEIGHT: Final[str] = 'data_weight' # weight for given data
DATA_MODALITY: Final[str] = (
'data_modality' # modality of given data. e.g. PBE and SCAN
)
BATCH: Final[str] = 'batch'
TAG = 'tag' # replace USER_LABEL
# etc
SELF_CONNECTION_TEMP: Final[str] = 'self_cont_tmp'
BATCH_SIZE: Final[str] = 'batch_size'
INFO: Final[str] = 'data_info'
# something special
LABEL_NONE: Final[str] = 'No_label'
# ==================================================#
# ~~~~~~ KEY for train/data configuration ~~~~~~~~ #
# ==================================================#
PREPROCESS_NUM_CORES = 'preprocess_num_cores'
SAVE_DATASET = 'save_dataset_path'
SAVE_BY_LABEL = 'save_by_label'
SAVE_BY_TRAIN_VALID = 'save_by_train_valid'
DATA_FORMAT = 'data_format'
DATA_FORMAT_ARGS = 'data_format_args'
STRUCTURE_LIST = 'structure_list'
LOAD_DATASET = 'load_dataset_path' # not used in v2
LOAD_TRAINSET = 'load_trainset_path'
LOAD_VALIDSET = 'load_validset_path'
LOAD_TESTSET = 'load_testset_path'
FORMAT_OUTPUTS = 'format_outputs_for_ase'
COMPUTE_STATISTICS = 'compute_statistics'
DATASET_TYPE = 'dataset_type'
RANDOM_SEED = 'random_seed'
RATIO = 'data_divide_ratio'
USE_TESTSET = 'use_testset'
EPOCH = 'epoch'
LOSS = 'loss'
LOSS_PARAM = 'loss_param'
OPTIMIZER = 'optimizer'
OPTIM_PARAM = 'optim_param'
SCHEDULER = 'scheduler'
SCHEDULER_PARAM = 'scheduler_param'
FORCE_WEIGHT = 'force_loss_weight'
STRESS_WEIGHT = 'stress_loss_weight'
DEVICE = 'device'
DTYPE = 'dtype'
TRAIN_SHUFFLE = 'train_shuffle'
IS_TRAIN_STRESS = 'is_train_stress'
CONTINUE = 'continue'
CHECKPOINT = 'checkpoint'
RESET_OPTIMIZER = 'reset_optimizer'
RESET_SCHEDULER = 'reset_scheduler'
RESET_EPOCH = 'reset_epoch'
USE_STATISTIC_VALUES_OF_CHECKPOINT = 'use_statistic_values_of_checkpoint'
USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY = (
'use_statistic_values_for_cp_modal_only'
)
CSV_LOG = 'csv_log'
ERROR_RECORD = 'error_record'
BEST_METRIC = 'best_metric'
NUM_WORKERS = 'num_workers'  # does not work
RANK = 'rank'
LOCAL_RANK = 'local_rank'
WORLD_SIZE = 'world_size'
IS_DDP = 'is_ddp'
DDP_BACKEND = 'ddp_backend'
PER_EPOCH = 'per_epoch'
USE_WEIGHT = 'use_weight'
USE_MODALITY = 'use_modality'
DEFAULT_MODAL = 'default_modal'
# ==================================================#
# ~~~~~~~~ KEY for model configuration ~~~~~~~~~~~ #
# ==================================================#
# ~~ global model configuration ~~ #
# note that these names are used directly as user-input keys in input.yaml
MODEL_TYPE = '_model_type'
CUTOFF = 'cutoff'
CHEMICAL_SPECIES = 'chemical_species'
MODAL_LIST = 'modal_list'
CHEMICAL_SPECIES_BY_ATOMIC_NUMBER = '_chemical_species_by_atomic_number'
NUM_SPECIES = '_number_of_species'
NUM_MODALITIES = '_number_of_modalities'
TYPE_MAP = '_type_map'
MODAL_MAP = '_modal_map'
# ~~ E3 equivariant model build configuration keys ~~ #
# see model_build default_config for type
IRREPS_MANUAL = 'irreps_manual'
NODE_FEATURE_MULTIPLICITY = 'channel'
RADIAL_BASIS = 'radial_basis'
BESSEL_BASIS_NUM = 'bessel_basis_num'
CUTOFF_FUNCTION = 'cutoff_function'
POLY_CUT_P = 'poly_cut_p_value'
LMAX = 'lmax'
LMAX_EDGE = 'lmax_edge'
LMAX_NODE = 'lmax_node'
IS_PARITY = 'is_parity'
CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS = 'weight_nn_hidden_neurons'
NUM_CONVOLUTION = 'num_convolution_layer'
ACTIVATION_SCARLAR = 'act_scalar'
ACTIVATION_GATE = 'act_gate'
ACTIVATION_RADIAL = 'act_radial'
SELF_CONNECTION_TYPE = 'self_connection_type'
RADIAL_BASIS_NAME = 'radial_basis_name'
CUTOFF_FUNCTION_NAME = 'cutoff_function_name'
USE_BIAS_IN_LINEAR = 'use_bias_in_linear'
USE_MODAL_NODE_EMBEDDING = 'use_modal_node_embedding'
USE_MODAL_SELF_INTER_INTRO = 'use_modal_self_inter_intro'
USE_MODAL_SELF_INTER_OUTRO = 'use_modal_self_inter_outro'
USE_MODAL_OUTPUT_BLOCK = 'use_modal_output_block'
READOUT_AS_FCN = 'readout_as_fcn'
READOUT_FCN_HIDDEN_NEURONS = 'readout_fcn_hidden_neurons'
READOUT_FCN_ACTIVATION = 'readout_fcn_activation'
AVG_NUM_NEIGH = 'avg_num_neigh'
CONV_DENOMINATOR = 'conv_denominator'
SHIFT = 'shift'
SCALE = 'scale'
USE_SPECIES_WISE_SHIFT_SCALE = 'use_species_wise_shift_scale'
USE_MODAL_WISE_SHIFT = 'use_modal_wise_shift'
USE_MODAL_WISE_SCALE = 'use_modal_wise_scale'
TRAIN_SHIFT_SCALE = 'train_shift_scale'
TRAIN_DENOMINTAOR = 'train_denominator'
INTERACTION_TYPE = 'interaction_type'
TRAIN_AVG_NUM_NEIGH = 'train_avg_num_neigh' # deprecated
CUEQUIVARIANCE_CONFIG = 'cuequivariance_config'
_NORMALIZE_SPH = '_normalize_sph'
OPTIMIZE_BY_REDUCE = 'optimize_by_reduce'
from typing import Optional
import torch
import torch_geometric.data
import sevenn._keys as KEY
import sevenn.util
class AtomGraphData(torch_geometric.data.Data):
"""
Args:
x (Tensor, optional): atomic numbers with shape :obj:`[num_nodes,
atomic_numbers]`. (default: :obj:`None`)
edge_index (LongTensor, optional): Graph connectivity in coordinate
format with shape :obj:`[2, num_edges]`. (default: :obj:`None`)
edge_attr (Tensor, optional): Edge feature matrix with shape
:obj:`[num_edges, num_edge_features]`. (default: :obj:`None`)
y_energy: scalar # unit of eV (VASP raw)
y_force: [num_nodes, 3] # unit of eV/A (VASP raw)
y_stress: [6] # [xx, yy, zz, xy, yz, zx] # unit of eV/A^3 (VASP raw)
pos (Tensor, optional): Node position matrix with shape
:obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`)
**kwargs (optional): Additional attributes.
x, y_force, pos should be aligned with each other.
"""
def __init__(
self,
x: Optional[torch.Tensor] = None,
edge_index: Optional[torch.Tensor] = None,
pos: Optional[torch.Tensor] = None,
edge_attr: Optional[torch.Tensor] = None,
**kwargs
):
super(AtomGraphData, self).__init__(x, edge_index, edge_attr, pos=pos)
self[KEY.NODE_ATTR] = x  # placeholder; overwritten with one-hot by OnehotEmbedding
for k, v in kwargs.items():
self[k] = v
def to_numpy_dict(self):
# This is not debugged yet!
dct = {
k: v.detach().cpu().numpy() if type(v) is torch.Tensor else v
for k, v in self.items()
}
return dct
def fit_dimension(self):
per_atom_keys = [
KEY.ATOMIC_NUMBERS,
KEY.ATOMIC_ENERGY,
KEY.POS,
KEY.FORCE,
KEY.PRED_FORCE,
]
natoms = self.num_atoms.item()
for k, v in self.items():
if not isinstance(v, torch.Tensor):
continue
if natoms == 1 and k in per_atom_keys:
self[k] = v.squeeze().unsqueeze(0)
else:
self[k] = v.squeeze()
return self
@staticmethod
def from_numpy_dict(dct):
for k, v in dct.items():
if k == KEY.CELL_SHIFT:
dct[k] = torch.Tensor(v) # this is special
else:
dct[k] = sevenn.util.dtype_correct(v)
return AtomGraphData(**dct)
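# Minimal usage sketch (tensor values are illustrative only):
# >>> data = AtomGraphData(
# ...     x=torch.tensor([8, 1, 1]),  # atomic numbers
# ...     edge_index=torch.zeros(2, 4, dtype=torch.long),
# ...     pos=torch.rand(3, 3),
# ... )
# >>> data[KEY.NODE_ATTR] is data.x
# True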
ninja_required_version = 1.3
cxx = c++
nvcc = /usr/local/cuda/bin/nvcc
cflags = -DTORCH_EXTENSION_NAME=pair_d3 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include/TH -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17
post_cflags =
cuda_cflags = -DTORCH_EXTENSION_NAME=pair_d3 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include/TH -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_89,code=sm_89 -gencode=arch=compute_90,code=sm_90 --compiler-options '-fPIC' -O3 --expt-relaxed-constexpr -fmad=false -std=c++17
cuda_post_cflags =
cuda_dlink_post_cflags =
ldflags = -shared -L/home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda -ltorch -ltorch_python -L/usr/local/cuda/lib64 -lcudart
rule compile
command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags
depfile = $out.d
deps = gcc
rule cuda_compile
depfile = $out.d
deps = gcc
command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out $cuda_post_cflags
rule link
command = $cxx $in $ldflags -o $out
build pair_d3_for_ase.cuda.o: cuda_compile /home/mazhaojia/mace-project/mace-bench/3rdparty/SevenNet/sevenn/pair_e3gnn/pair_d3_for_ase.cu
build pair_d3.so: link pair_d3_for_ase.cuda.o
default pair_d3.so
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
import torch.distributed as dist
import sevenn._keys as KEY
from sevenn.train.loss import LossDefinition
from .atom_graph_data import AtomGraphData
from .train.optim import loss_dict
_ERROR_TYPES = {
'TotalEnergy': {
'name': 'Energy',
'ref_key': KEY.ENERGY,
'pred_key': KEY.PRED_TOTAL_ENERGY,
'unit': 'eV',
'vdim': 1,
},
'Energy': { # by default per-atom for energy
'name': 'Energy',
'ref_key': KEY.ENERGY,
'pred_key': KEY.PRED_TOTAL_ENERGY,
'unit': 'eV/atom',
'per_atom': True,
'vdim': 1,
},
'Force': {
'name': 'Force',
'ref_key': KEY.FORCE,
'pred_key': KEY.PRED_FORCE,
'unit': 'eV/Å',
'vdim': 3,
},
'Stress': {
'name': 'Stress',
'ref_key': KEY.STRESS,
'pred_key': KEY.PRED_STRESS,
'unit': 'kbar',
'coeff': 1602.1766208,
'vdim': 6,
},
'Stress_GPa': {
'name': 'Stress',
'ref_key': KEY.STRESS,
'pred_key': KEY.PRED_STRESS,
'unit': 'GPa',
'coeff': 160.21766208,
'vdim': 6,
},
'TotalLoss': {
'name': 'TotalLoss',
'unit': None,
},
}
def get_err_type(name: str) -> Dict[str, Any]:
return deepcopy(_ERROR_TYPES[name])
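# get_err_type returns a deep copy, so callers may freely mutate the kwargs:
# >>> kwargs = get_err_type('Force')
# >>> kwargs['name'], kwargs['unit'], kwargs['vdim']
# ('Force', 'eV/Å', 3)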
def _get_loss_function_from_name(loss_functions, name):
for loss_def, w in loss_functions:
if loss_def.name.lower() == name.lower():
return loss_def, w
return None, None
class AverageNumber:
def __init__(self):
self._sum = 0.0
self._count = 0
def update(self, values: torch.Tensor):
self._sum += values.sum().item()
self._count += values.numel()
def _ddp_reduce(self, device):
_sum = torch.tensor(self._sum, device=device)
_count = torch.tensor(self._count, device=device)
dist.all_reduce(_sum, op=dist.ReduceOp.SUM)
dist.all_reduce(_count, op=dist.ReduceOp.SUM)
self._sum = _sum.item()
self._count = _count.item()
def get(self):
if self._count == 0:
return torch.nan
return self._sum / self._count
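# Behavior sketch: AverageNumber keeps a running mean over every element seen.
# >>> avg = AverageNumber()
# >>> avg.update(torch.tensor([1.0, 2.0, 3.0]))
# >>> avg.update(torch.tensor([4.0]))
# >>> avg.get()
# 2.5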
class ErrorMetric:
"""
Base class for error metrics. Errors are averaged over the number of
structures and collected incrementally during iteration (via AverageNumber).
"""
def __init__(
self,
name: str,
ref_key: str,
pred_key: str,
coeff: float = 1.0,
unit: Optional[str] = None,
per_atom: bool = False,
ignore_unlabeled: bool = True,
**kwargs,
):
self.name = name
self.unit = unit
self.coeff = coeff
self.ref_key = ref_key
self.pred_key = pred_key
self.per_atom = per_atom
self.ignore_unlabeled = ignore_unlabeled
self.value = AverageNumber()
def update(self, output: AtomGraphData):
raise NotImplementedError
def _retrieve(self, output: AtomGraphData):
y_ref = output[self.ref_key] * self.coeff
y_pred = output[self.pred_key] * self.coeff
if self.per_atom:
assert y_ref.dim() == 1 and y_pred.dim() == 1
natoms = output[KEY.NUM_ATOMS]
y_ref = y_ref / natoms
y_pred = y_pred / natoms
if self.ignore_unlabeled:
unlabelled_idx = torch.isnan(y_ref)
y_ref = y_ref[~unlabelled_idx]
y_pred = y_pred[~unlabelled_idx]
return y_ref, y_pred
def ddp_reduce(self, device):
self.value._ddp_reduce(device)
def reset(self):
self.value = AverageNumber()
def get(self):
return self.value.get()
def key_str(self, with_unit=True):
if self.unit is None or not with_unit:
return self.name
else:
return f'{self.name} ({self.unit})'
def __str__(self):
return f'{self.key_str()}: {self.value.get():.6f}'
class RMSError(ErrorMetric):
"""
Root mean square error; squared errors are summed over the vector
dimension (vdim) before averaging
"""
def __init__(self, vdim: int = 1, **kwargs):
super().__init__(**kwargs)
self.vdim = vdim
self._se = torch.nn.MSELoss(reduction='none')
def _square_error(self, y_ref, y_pred, vdim: int):
return self._se(y_ref.view(-1, vdim), y_pred.view(-1, vdim)).sum(dim=1)
def update(self, output: AtomGraphData):
y_ref, y_pred = self._retrieve(output)
se = self._square_error(y_ref, y_pred, self.vdim)
self.value.update(se)
def get(self):
return self.value.get() ** 0.5
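# Note on the convention above: squared errors are summed over the vector
# dimension before averaging, so for forces (vdim=3) RMSError is exactly
# sqrt(3) times larger than the per-component RMSE of ComponentRMSError below.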
class ComponentRMSError(ErrorMetric):
"""
Ignores the vector dimension and averages over all components,
which results in a smaller error than RMSError
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._se = torch.nn.MSELoss(reduction='none')
def _square_error(self, y_ref, y_pred):
return self._se(y_ref, y_pred)
def update(self, output: AtomGraphData):
y_ref, y_pred = self._retrieve(output)
y_ref = y_ref.view(-1)
y_pred = y_pred.view(-1)
se = self._square_error(y_ref, y_pred)
self.value.update(se)
def get(self):
return self.value.get() ** 0.5
class MAError(ErrorMetric):
"""
Mean absolute error, averaged over all components
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def _abs_error(self, y_ref, y_pred):
return torch.abs(y_ref - y_pred)
def update(self, output: AtomGraphData):
y_ref, y_pred = self._retrieve(output)
y_ref = y_ref.reshape((-1,))
y_pred = y_pred.reshape((-1,))
ae = self._abs_error(y_ref, y_pred)
self.value.update(ae)
class CustomError(ErrorMetric):
"""
Custom error metric
Args:
func: a function that takes y_ref and y_pred
and returns a list of errors
"""
def __init__(self, func: Callable, **kwargs):
super().__init__(**kwargs)
self.func = func
def update(self, output: AtomGraphData):
y_ref, y_pred = self._retrieve(output)
se = self.func(y_ref, y_pred) if len(y_ref) > 0 else torch.tensor([])
self.value.update(se)
class LossError(ErrorMetric):
"""
Error metric that records the loss
"""
def __init__(
self,
name: str,
loss_def: LossDefinition,
**kwargs,
):
super().__init__(
name,
ignore_unlabeled=loss_def.ignore_unlabeled,
**kwargs,
)
self.loss_def = loss_def
def update(self, output: AtomGraphData):
loss = self.loss_def.get_loss(output) # type: ignore
self.value.update(loss) # type: ignore
class CombinedError(ErrorMetric):
"""
Combines multiple error metrics with weights;
corresponds to a weighted sum of errors (normally used as the loss)
"""
def __init__(self, metrics: List[Tuple[ErrorMetric, float]], **kwargs):
super().__init__(**kwargs)
self.metrics = metrics
assert kwargs['unit'] is None
def update(self, output: AtomGraphData):
for metric, _ in self.metrics:
metric.update(output)
def reset(self):
for metric, _ in self.metrics:
metric.reset()
def ddp_reduce(self, device): # override
for metric, _ in self.metrics:
metric.value._ddp_reduce(device)
def get(self):
val = 0.0
for metric, weight in self.metrics:
val += metric.get() * weight
return val
class ErrorRecorder:
"""
record errors of a model
"""
METRIC_DICT = {
'RMSE': RMSError,
'ComponentRMSE': ComponentRMSError,
'MAE': MAError,
'Loss': LossError,
}
def __init__(self, metrics: List[ErrorMetric]):
self.history = []
self.metrics = metrics
def _update(self, output: AtomGraphData):
for metric in self.metrics:
metric.update(output)
def update(self, output: AtomGraphData, no_grad=True):
if no_grad:
with torch.no_grad():
self._update(output)
else:
self._update(output)
def get_metric_dict(self, with_unit=True):
return {metric.key_str(with_unit): metric.get() for metric in self.metrics}
def get_current(self):
dct = {}
for metric in self.metrics:
dct[metric.name] = {
'value': metric.get(),
'unit': metric.unit,
'ref_key': metric.ref_key,
'pred_key': metric.pred_key,
}
return dct
def get_dct(self, prefix=''):
dct = {}
if prefix.endswith('_') is False and prefix != '':
prefix = prefix + '_'
for metric in self.metrics:
dct[f'{prefix}{metric.name}'] = f'{metric.get():.6f}'
return dct
def get_key_str(self, name: str):
for metric in self.metrics:
if name == metric.name:
return metric.key_str()
return None
def epoch_forward(self):
self.history.append(self.get_current())
pretty = self.get_metric_dict(with_unit=True)
for metric in self.metrics:
metric.reset()
return pretty # for print
@staticmethod
def init_total_loss_metric(
config,
criteria: Optional[Callable] = None,
loss_functions: Optional[List[Tuple[LossDefinition, float]]] = None,
):
if criteria is None and loss_functions is None:
raise ValueError('both criteria and loss functions not given')
is_stress = config[KEY.IS_TRAIN_STRESS]
metrics = []
if criteria is not None:
energy_metric = CustomError(criteria, **get_err_type('Energy'))
metrics.append((energy_metric, 1))
force_metric = CustomError(criteria, **get_err_type('Force'))
metrics.append((force_metric, config[KEY.FORCE_WEIGHT]))
if is_stress:
stress_metric = CustomError(criteria, **get_err_type('Stress'))
metrics.append((stress_metric, config[KEY.STRESS_WEIGHT]))
else: # TODO: this is hard-coded
for efs in ['Energy', 'Force', 'Stress']:
if efs == 'Stress' and not is_stress:
continue
lf, w = _get_loss_function_from_name(loss_functions, efs)
if lf is None:
raise ValueError(f'{efs} not found from loss_functions')
metric = LossError(loss_def=lf, **get_err_type(efs))
metrics.append((metric, w))
total_loss_metric = CombinedError(
metrics, name='TotalLoss', unit=None, ref_key=None, pred_key=None
)
return total_loss_metric
@staticmethod
def from_config(config: dict, loss_functions=None):
loss_cls = loss_dict[config.get(KEY.LOSS, 'mse').lower()]
loss_param = config.get(KEY.LOSS_PARAM, {})
criteria = loss_cls(**loss_param) if loss_functions is None else None
err_config = config.get(KEY.ERROR_RECORD, False)
if not err_config:
raise ValueError(
'No error_record config found. Consider util.get_error_recorder'
)
err_config_n = []
if not config.get(KEY.IS_TRAIN_STRESS, True):
for err_type, metric_name in err_config:
if 'Stress' in err_type:
continue
err_config_n.append((err_type, metric_name))
err_config = err_config_n
err_metrics = []
for err_type, metric_name in err_config:
metric_kwargs = get_err_type(err_type)
if err_type == 'TotalLoss': # special case
err_metrics.append(
ErrorRecorder.init_total_loss_metric(
config, criteria, loss_functions
)
)
continue
metric_cls = ErrorRecorder.METRIC_DICT[metric_name]
assert isinstance(metric_kwargs['name'], str)
if metric_name == 'Loss':
if loss_functions is not None:
metric_cls = LossError
metric_kwargs['loss_def'], _ = _get_loss_function_from_name(
loss_functions, metric_kwargs['name']
)
else:
metric_cls = CustomError
metric_kwargs['func'] = criteria
metric_kwargs.pop('unit', None)
metric_kwargs['name'] += f'_{metric_name}'
err_metrics.append(metric_cls(**metric_kwargs))
return ErrorRecorder(err_metrics)
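# Hedged usage sketch; `train_config`, `model`, and `loader` are assumptions
# for illustration, not objects defined in this file.
# >>> recorder = ErrorRecorder.from_config(train_config)
# >>> for batch in loader:
# ...     output = model(batch)
# ...     recorder.update(output)
# >>> recorder.epoch_forward()  # e.g. {'Energy (eV/atom)': ..., ...}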
import os
import time
import traceback
from datetime import datetime
from typing import Any, Dict, List, Optional
from ase.data import atomic_numbers
import sevenn._keys as KEY
from sevenn import __version__
CHEM_SYMBOLS = {v: k for k, v in atomic_numbers.items()}
class Singleton(type):
_instances = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
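# Doctest sketch: any class using this metaclass yields one shared instance.
# >>> class Cfg(metaclass=Singleton):
# ...     pass
# >>> Cfg() is Cfg()
# True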
class Logger(metaclass=Singleton):
SCREEN_WIDTH = 120 # half size of my screen / changed due to stress output
def __init__(
self, filename: Optional[str] = None, screen: bool = False, rank: int = 0
):
self.rank = rank
self._filename = filename
if rank == 0:
# if filename is not None:
# self.logfile = open(filename, 'a', buffering=1)
self.logfile = None
self.files = {}
self.screen = screen
else:
self.logfile = None
self.screen = False
self.timer_dct = {}
self.active = True
def __enter__(self):
if self.rank != 0:
return self
if self.logfile is None and self._filename is not None:
try:
self.logfile = open(
self._filename, 'a', buffering=1, encoding='utf-8'
)
except IOError as e:
print(f'Failed to re-open log file {self._filename}: {e}')
self.logfile = None
self.files = {}
return self
def __exit__(self, exc_type, exc_value, traceback):
if self.rank != 0:
return self
try:
if self.logfile is not None:
self.logfile.close()
self.logfile = None
for f in self.files.values():
f.close()
except IOError as e:
print(f'Failed to close log files: {e}')
finally:
self.logfile = None
self.files = {}
def switch_file(self, new_filename: str):
if self.rank != 0:
return self
if self.logfile is not None:
raise ValueError('Current logfile is not yet closed')
self._filename = new_filename
return self
def write(self, content: str):
if self.rank != 0:
return
# no newline!
if self.logfile is not None and self.active:
self.logfile.write(content)
if self.screen and self.active:
print(content, end='')
def writeline(self, content: str):
content = content + '\n'
self.write(content)
def init_csv(self, filename: str, header: list):
"""
Deprecated
"""
if self.rank == 0:
self.files[filename] = open(filename, 'w', buffering=1, encoding='utf-8')
self.files[filename].write(','.join(header) + '\n')
else:
pass
def append_csv(self, filename: str, content: list, decimal: int = 6):
"""
Deprecated
"""
if self.rank == 0:
if filename not in self.files:
self.files[filename] = open(filename, 'a', buffering=1)
str_content = []
for c in content:
if isinstance(c, float):
str_content.append(f'{c:.{decimal}f}')
else:
str_content.append(str(c))
self.files[filename].write(','.join(str_content) + '\n')
else:
pass
def natoms_write(self, natoms: Dict[str, Dict]):
content = ''
total_natom = {}
for label, natom in natoms.items():
content += self.format_k_v(label, natom)
for specie, num in natom.items():
try:
total_natom[specie] += num
except KeyError:
total_natom[specie] = num
content += self.format_k_v('Total, label wise', total_natom)
content += self.format_k_v('Total', sum(total_natom.values()))
self.write(content)
def statistic_write(self, statistic: Dict[str, Dict]):
content = ''
for label, dct in statistic.items():
if label.startswith('_'):
continue
if not isinstance(dct, dict):
continue
dct_new = {}
for k, v in dct.items():
if k.startswith('_'):
continue
if isinstance(v, int):
dct_new[k] = v
else:
dct_new[k] = f'{v:.3f}'
content += self.format_k_v(label, dct_new)
self.write(content)
# TODO: refactor; despite the name, this writes RMSE, not loss
def epoch_write_specie_wise_loss(self, train_loss, valid_loss):
lb_pad = 21
fs = 6
pad = 21 - fs
ln = '-' * fs
total_atom_type = train_loss.keys()
content = ''
for at in total_atom_type:
t_F = train_loss[at]
v_F = valid_loss[at]
at_sym = CHEM_SYMBOLS[at]
content += '{label:{lb_pad}}{t_E:<{pad}.{fs}s}{v_E:<{pad}.{fs}s}'.format(
label=at_sym, t_E=ln, v_E=ln, lb_pad=lb_pad, pad=pad, fs=fs
) + '{t_F:<{pad}.{fs}f}{v_F:<{pad}.{fs}f}'.format(
t_F=t_F, v_F=v_F, pad=pad, fs=fs
)
content += '{t_S:<{pad}.{fs}s}{v_S:<{pad}.{fs}s}'.format(
t_S=ln, v_S=ln, pad=pad, fs=fs
)
content += '\n'
self.write(content)
def write_full_table(
self,
dict_list: List[Dict],
row_labels: List[str],
decimal_places: int = 6,
pad: int = 2,
):
"""
Assumes dict_list is a list of dicts with the same keys
"""
assert len(dict_list) == len(row_labels)
label_len = max(map(len, row_labels))
# Extract the column names and create a 2D array of values
col_names = list(dict_list[0].keys())
values = [list(d.values()) for d in dict_list]
# Format the numbers with the given decimal places
formatted_values = [
[f'{value:.{decimal_places}f}' for value in row] for row in values
]
# Calculate padding lengths for each column (with extra padding)
max_col_lengths = [
max(len(str(value)) for value in col) + pad
for col in zip(col_names, *formatted_values)
]
# Create header row and separator
header = ' ' * (label_len + pad) + ' '.join(
col_name.ljust(pad) for col_name, pad in zip(col_names, max_col_lengths)
)
separator = '-'.join('-' * pad for pad in max_col_lengths) + '-' * (
label_len + pad
)
# Print header and separator
self.writeline(header)
self.writeline(separator)
# Print the data rows with row labels
for row_label, row in zip(row_labels, formatted_values):
data_row = ' '.join(
value.rjust(pad) for value, pad in zip(row, max_col_lengths)
)
self.writeline(f'{row_label.ljust(label_len)}{data_row}')
def format_k_v(self, key: Any, val: Any, write: bool = False):
"""
key and val should be str convertible
"""
MAX_KEY_SIZE = 20
SEPARATOR = ', '
EMPTY_PADDING = ' ' * (MAX_KEY_SIZE + 3)
NEW_LINE_LEN = Logger.SCREEN_WIDTH - 5
key = str(key)
val = str(val)
content = f'{key:<{MAX_KEY_SIZE}}: {val}'
if len(content) > NEW_LINE_LEN:
content = f'{key:<{MAX_KEY_SIZE}}: '
# separate val by separator
val_list = val.split(SEPARATOR)
current_len = len(content)
for val_compo in val_list:
current_len += len(val_compo)
if current_len > NEW_LINE_LEN:
newline_content = f'{EMPTY_PADDING}{val_compo}{SEPARATOR}'
content += f'\\\n{newline_content}'
current_len = len(newline_content)
else:
content += f'{val_compo}{SEPARATOR}'
if content.endswith(f'{SEPARATOR}'):
content = content[: -len(SEPARATOR)]
content += '\n'
if write is False:
return content
else:
self.write(content)
return ''
def greeting(self):
LOGO_ASCII_FILE = f'{os.path.dirname(__file__)}/logo_ascii'
with open(LOGO_ASCII_FILE, 'r') as logo_f:
logo_ascii = logo_f.read()
content = 'SevenNet: Scalable EquiVariance-Enabled Neural Network\n'
content += f'version {__version__}, {time.ctime()}\n'
self.write(content)
self.write(logo_ascii)
def bar(self):
content = '-' * Logger.SCREEN_WIDTH + '\n'
self.write(content)
def print_config(
self,
model_config: Dict[str, Any],
data_config: Dict[str, Any],
train_config: Dict[str, Any],
):
"""
print some important information from config
"""
content = 'successfully read yaml config!\n\n' + 'from model configuration\n'
for k, v in model_config.items():
content += self.format_k_v(k, str(v))
content += '\nfrom train configuration\n'
for k, v in train_config.items():
content += self.format_k_v(k, str(v))
content += '\nfrom data configuration\n'
for k, v in data_config.items():
content += self.format_k_v(k, str(v))
self.write(content)
# TODO: this is not ideal; define a custom exception
def error(self, e: Exception):
content = ''
if type(e) is ValueError:
content += 'Error occurred!\n'
content += str(e) + '\n'
else:
content += 'Unknown error occurred!\n'
content += traceback.format_exc()
self.write(content)
def timer_start(self, name: str):
self.timer_dct[name] = datetime.now()
def timer_end(self, name: str, message: str, remove: bool = True):
"""
writes f"{message}: {elapsed}" to the log
"""
elapsed = str(datetime.now() - self.timer_dct[name])
# elapsed = elapsed.strftime('%H-%M-%S')
if remove:
del self.timer_dct[name]
self.write(f'{message}: {elapsed[:-4]}\n')
# TODO: print it without config
# TODO: refactoring, readout part name :(
def print_model_info(self, model, config):
from functools import partial
kv_write = partial(self.format_k_v, write=True)
self.writeline('Irreps of features')
kv_write('edge_feature', model.get_irreps_in('edge_embedding', 'irreps_out'))
for i in range(config[KEY.NUM_CONVOLUTION]):
kv_write(
f'{i}th node',
model.get_irreps_in(f'{i}_self_interaction_1'),
)
i = config[KEY.NUM_CONVOLUTION] - 1
kv_write(
'readout irreps',
model.get_irreps_in(f'{i}_equivariant_gate', 'irreps_out'),
)
num_weights = sum(p.numel() for p in model.parameters() if p.requires_grad)
self.writeline(f'# learnable parameters: {num_weights}\n')
****
******** .
*//////, .. . ,*.
,,***. .. , ********. ./,
. . .. /////. ., . *///////// /////////.
.&@&/ . .(((((((.. / *//////*. ... *((((((((((.
@@@@@@@@@@* @@@@@@@@@@ @@@@@ *((@@@@@ ( %@@@@@@@@@@ .@@@@@@ ..@@@@. @@@@@@* .(@@@@@(((*
@@@@@. @@@@ @@@@@ . @@@@@ # %@@@@ @@@@@@@@ @@@@(, @@@@@@@@. @@@@@(*.
%@@@@@@@& @@@@@@@@@@ @@@@@ @@@@@ # ., .%@@@@@@@@@ @@@@@@@@@@ @@@@, @@@@@@@@@@ @@@@@
,(%@@@@@@@@@ @@@@@@@@@@ @@@@@ @@@@& (//////%@@@@@@@@@ @@@@ @@@@@@ @@@@ . @@@@@ @@@@@.@@@@@
. @@@@@ @@@@ . . @@@@@@@@% . . ( .////,%@@@@ @@@@ @@@@@@@@@ @@@@@ @@@@@@@@@
(@@@@@@@@@@@ @@@@@@@@@@**. @@@@@@* *. .%@@@@@@@@@@ @@@@ . @@@@@@@ @@@@@ .@@@@@@@
@@@@@@@@@. @@@@@@@@@@///, @@@@. . / %@@@@@@@@@@ @@@@***, @@@@@ @@@@@ @@@@@
. //////////*. / . .*******... . ,.
.&&&&&... ,//////*. ...////. / ,*/. . ,////, .,/////
&@@@@@@ ,(/((, * ,((((((. .***.
,/@(, .. * ,((((*
,
.
import copy
import warnings
from collections import OrderedDict
from typing import List, Literal, Union, overload
from e3nn.o3 import Irreps
import sevenn._const as _const
import sevenn._keys as KEY
import sevenn.util as util
from .nn.convolution import IrrepsConvolution
from .nn.edge_embedding import (
BesselBasis,
EdgeEmbedding,
PolynomialCutoff,
SphericalEncoding,
XPLORCutoff,
)
from .nn.force_output import ForceStressOutputFromEdge
from .nn.interaction_blocks import NequIP_interaction_block
from .nn.linear import AtomReduce, FCN_e3nn, IrrepsLinear
from .nn.node_embedding import OnehotEmbedding
from .nn.scale import ModalWiseRescale, Rescale, SpeciesWiseRescale
from .nn.self_connection import (
SelfConnectionIntro,
SelfConnectionLinearIntro,
SelfConnectionOutro,
)
from .nn.sequential import AtomGraphSequential
# warning from PyTorch, about e3nn type annotations
warnings.filterwarnings(
'ignore',
message=(
"The TorchScript type system doesn't " 'support instance-level annotations'
),
)
def _insert_after(module_name_after, key_module_pair, layers):
idx = -1
for i, (key, _) in enumerate(layers):
if key == module_name_after:
idx = i
break
if idx == -1:
return layers # do nothing if not found
layers.insert(idx + 1, key_module_pair)
return layers
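# Illustrative behavior (the list is modified in place and returned):
# >>> _insert_after('a', ('a2', 10), [('a', 1), ('b', 2)])
# [('a', 1), ('a2', 10), ('b', 2)]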
def init_self_connection(config):
self_connection_type_list = config[KEY.SELF_CONNECTION_TYPE]
num_conv = config[KEY.NUM_CONVOLUTION]
if isinstance(self_connection_type_list, str):
self_connection_type_list = [self_connection_type_list] * num_conv
io_pair_list = []
for sc_type in self_connection_type_list:
if sc_type == 'none':
io_pair = None
elif sc_type == 'nequip':
io_pair = SelfConnectionIntro, SelfConnectionOutro
elif sc_type == 'linear':
io_pair = SelfConnectionLinearIntro, SelfConnectionOutro
else:
raise ValueError(f'Unknown self_connection_type found: {sc_type}')
io_pair_list.append(io_pair)
return io_pair_list
def init_edge_embedding(config):
_cutoff_param = {'cutoff_length': config[KEY.CUTOFF]}
rbf, env, sph = None, None, None
rbf_dct = copy.deepcopy(config[KEY.RADIAL_BASIS])
rbf_dct.update(_cutoff_param)
rbf_name = rbf_dct.pop(KEY.RADIAL_BASIS_NAME)
if rbf_name == 'bessel':
rbf = BesselBasis(**rbf_dct)
envelop_dct = copy.deepcopy(config[KEY.CUTOFF_FUNCTION])
envelop_dct.update(_cutoff_param)
envelop_name = envelop_dct.pop(KEY.CUTOFF_FUNCTION_NAME)
if envelop_name == 'poly_cut':
env = PolynomialCutoff(**envelop_dct)
elif envelop_name == 'XPLOR':
env = XPLORCutoff(**envelop_dct)
lmax_edge = config[KEY.LMAX]
if config[KEY.LMAX_EDGE] > 0:
lmax_edge = config[KEY.LMAX_EDGE]
parity = -1 if config[KEY.IS_PARITY] else 1
_normalize_sph = config[KEY._NORMALIZE_SPH]
sph = SphericalEncoding(lmax_edge, parity, normalize=_normalize_sph)
return EdgeEmbedding(basis_module=rbf, cutoff_module=env, spherical_module=sph)
def init_feature_reduce(config, irreps_x):
# features per node to scalar per node
layers = OrderedDict()
if config[KEY.READOUT_AS_FCN] is False:
hidden_irreps = Irreps([(irreps_x.dim // 2, (0, 1))])
layers.update(
{
'reduce_input_to_hidden': IrrepsLinear(
irreps_x,
hidden_irreps,
data_key_in=KEY.NODE_FEATURE,
biases=config[KEY.USE_BIAS_IN_LINEAR],
),
'reduce_hidden_to_energy': IrrepsLinear(
hidden_irreps,
Irreps([(1, (0, 1))]),
data_key_in=KEY.NODE_FEATURE,
data_key_out=KEY.SCALED_ATOMIC_ENERGY,
biases=config[KEY.USE_BIAS_IN_LINEAR],
),
}
)
else:
act = _const.ACTIVATION[config[KEY.READOUT_FCN_ACTIVATION]]
hidden_neurons = config[KEY.READOUT_FCN_HIDDEN_NEURONS]
layers.update(
{
'readout_FCN': FCN_e3nn(
dim_out=1,
hidden_neurons=hidden_neurons,
activation=act,
data_key_in=KEY.NODE_FEATURE,
data_key_out=KEY.SCALED_ATOMIC_ENERGY,
irreps_in=irreps_x,
)
}
)
return layers
def init_shift_scale(config):
# for mm, ex, shift: modal_idx -> shifts
shift_scale = []
train_shift_scale = config[KEY.TRAIN_SHIFT_SCALE]
type_map = config[KEY.TYPE_MAP]
# in case of modal, shift or scale has more dims [][]
# correct typing (I really want static python)
for s in (config[KEY.SHIFT], config[KEY.SCALE]):
if hasattr(s, 'tolist'): # numpy or torch
s = s.tolist()
if isinstance(s, dict):
s = {k: v.tolist() if hasattr(v, 'tolist') else v for k, v in s.items()}
if isinstance(s, list) and len(s) == 1:
s = s[0]
shift_scale.append(s)
shift, scale = shift_scale
rescale_module = None
if config.get(KEY.USE_MODALITY, False):
rescale_module = ModalWiseRescale.from_mappers( # type: ignore
shift,
scale,
config[KEY.USE_MODAL_WISE_SHIFT],
config[KEY.USE_MODAL_WISE_SCALE],
type_map=type_map,
modal_map=config[KEY.MODAL_MAP],
train_shift_scale=train_shift_scale,
)
elif all([isinstance(s, float) for s in shift_scale]):
rescale_module = Rescale(shift, scale, train_shift_scale=train_shift_scale)
elif any([isinstance(s, list) for s in shift_scale]):
rescale_module = SpeciesWiseRescale.from_mappers( # type: ignore
shift, scale, type_map=type_map, train_shift_scale=train_shift_scale
)
else:
raise ValueError('shift, scale should be list of float or float')
return rescale_module
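# Accepted (shift, scale) forms, per the branches above: plain floats give a
# global Rescale, per-species lists give SpeciesWiseRescale, and modal runs
# use ModalWiseRescale. Minimal sketch with illustrative values:
# >>> cfg = {KEY.SHIFT: -3.1, KEY.SCALE: 1.2, KEY.TRAIN_SHIFT_SCALE: False,
# ...        KEY.TYPE_MAP: {8: 0, 1: 1}}
# >>> isinstance(init_shift_scale(cfg), Rescale)
# True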
def patch_modality(layers: OrderedDict, config):
"""
Postprocess 7net-model to multimodal model.
1. prepend modality one-hot embedding layer
2. patch modalities of IrrepsLinear layers
Modality aware shift scale is handled by init_shift_scale, not here
"""
cfg = config
if not cfg.get(KEY.USE_MODALITY, False):
return layers
_layers = list(layers.items())
_layers = _insert_after(
'onehot_idx_to_onehot',
(
'one_hot_modality',
OnehotEmbedding(
num_classes=config[KEY.NUM_MODALITIES],
data_key_x=KEY.MODAL_TYPE,
data_key_out=KEY.MODAL_ATTR,
data_key_save=None,
data_key_additional=None,
),
),
_layers,
)
layers = OrderedDict(_layers)
num_modal = config[KEY.NUM_MODALITIES]
for k, module in layers.items():
if not isinstance(module, IrrepsLinear):
continue
if (
(cfg[KEY.USE_MODAL_NODE_EMBEDDING] and k.endswith('onehot_to_feature_x'))
or (
cfg[KEY.USE_MODAL_SELF_INTER_INTRO]
and k.endswith('self_interaction_1')
)
or (
cfg[KEY.USE_MODAL_SELF_INTER_OUTRO]
and k.endswith('self_interaction_2')
)
or (cfg[KEY.USE_MODAL_OUTPUT_BLOCK] and k == 'reduce_input_to_hidden')
):
module.set_num_modalities(num_modal)
return layers
def patch_cue(layers: OrderedDict, config):
import sevenn.nn.cue_helper as cue_helper
cue_cfg = copy.deepcopy(config.get(KEY.CUEQUIVARIANCE_CONFIG, {}))
if not cue_cfg.pop('use', False):
return layers
if not cue_helper.is_cue_available():
warnings.warn(
(
'cuEquivariance is requested, but the package is not installed. '
+ 'Fallback to original code.'
)
)
return layers
if not cue_helper.is_cue_cuda_available_model(config):
return layers
group = 'O3' if config[KEY.IS_PARITY] else 'SO3'
cueq_module_params = dict(layout='mul_ir')
cueq_module_params.update(cue_cfg)
updates = {}
for k, module in layers.items():
if isinstance(module, (IrrepsLinear, SelfConnectionLinearIntro)):
if k == 'reduce_hidden_to_energy': # TODO: has bug with 0 shape
continue
module_patched = cue_helper.patch_linear(
module, group, **cueq_module_params
)
updates[k] = module_patched
elif isinstance(module, SelfConnectionIntro):
module_patched = cue_helper.patch_fully_connected(
module, group, **cueq_module_params
)
updates[k] = module_patched
elif isinstance(module, IrrepsConvolution):
module_patched = cue_helper.patch_convolution(
module, group, **cueq_module_params
)
updates[k] = module_patched
layers.update(updates)
return layers
def patch_modules(layers: OrderedDict, config):
layers = patch_modality(layers, config)
layers = patch_cue(layers, config)
return layers
def _to_parallel_model(layers: OrderedDict, config):
num_classes = layers['onehot_idx_to_onehot'].num_classes
one_hot_irreps = Irreps(f'{num_classes}x0e')
irreps_node_zero = layers['onehot_to_feature_x'].irreps_out
_layers = list(layers.items())
layers_list = []
num_convolution_layer = config[KEY.NUM_CONVOLUTION]
def slice_until_this(module_name, layers):
idx = -1
for i, (key, _) in enumerate(layers):
if key == module_name:
idx = i
break
first_to = layers[: idx + 1]
remain = layers[idx + 1 :]
return first_to, remain
_layers = _insert_after(
'onehot_to_feature_x',
(
'one_hot_ghost',
OnehotEmbedding(
data_key_x=KEY.NODE_FEATURE_GHOST,
num_classes=num_classes,
data_key_save=None,
data_key_additional=None,
),
),
_layers,
)
_layers = _insert_after(
'one_hot_ghost',
(
'ghost_onehot_to_feature_x',
IrrepsLinear(
irreps_in=one_hot_irreps,
irreps_out=irreps_node_zero,
data_key_in=KEY.NODE_FEATURE_GHOST,
biases=config[KEY.USE_BIAS_IN_LINEAR],
),
),
_layers,
)
_layers = _insert_after(
'0_self_interaction_1',
(
'ghost_0_self_interaction_1',
IrrepsLinear(
irreps_node_zero,
irreps_node_zero,
data_key_in=KEY.NODE_FEATURE_GHOST,
biases=config[KEY.USE_BIAS_IN_LINEAR],
),
),
_layers,
)
# assign modules (before first communications)
# initialize edge related to retain position gradients
for i in range(1, num_convolution_layer):
sliced, _layers = slice_until_this(f'{i}_self_interaction_1', _layers)
layers_list.append(OrderedDict(sliced))
_layers.insert(0, ('edge_embedding', init_edge_embedding(config)))
layers_list.append(OrderedDict(_layers))
del layers_list[-1]['force_output'] # done in LAMMPS
return layers_list
@overload
def build_E3_equivariant_model(
config: dict, parallel: Literal[False] = False
) -> AtomGraphSequential: # noqa
...
@overload
def build_E3_equivariant_model(
config: dict, parallel: Literal[True]
) -> List[AtomGraphSequential]: # noqa
...
def build_E3_equivariant_model(
config: dict, parallel: bool = False
) -> Union[AtomGraphSequential, List[AtomGraphSequential]]:
"""
output shapes (w/o batch)
PRED_TOTAL_ENERGY: (),
ATOMIC_ENERGY: (natoms, 1), # intended
PRED_FORCE: (natoms, 3),
PRED_STRESS: (6,),
for data w/o cell volume, pred_stress has garbage values
"""
layers = OrderedDict()
cutoff = config[KEY.CUTOFF]
num_species = config[KEY.NUM_SPECIES]
feature_multiplicity = config[KEY.NODE_FEATURE_MULTIPLICITY]
num_convolution_layer = config[KEY.NUM_CONVOLUTION]
interaction_type = config[KEY.INTERACTION_TYPE]
use_bias_in_linear = config[KEY.USE_BIAS_IN_LINEAR]
lmax_node = config[KEY.LMAX] # ignore second (lmax_edge)
# if config[KEY.LMAX_EDGE] > 0: # not yet used
# _ = config[KEY.LMAX_EDGE]
if config[KEY.LMAX_NODE] > 0:
lmax_node = config[KEY.LMAX_NODE]
act_radial = _const.ACTIVATION[config[KEY.ACTIVATION_RADIAL]]
self_connection_pair_list = init_self_connection(config)
irreps_manual = None
if config[KEY.IRREPS_MANUAL] is not False:
irreps_manual = config[KEY.IRREPS_MANUAL]
try:
irreps_manual = [Irreps(irr) for irr in irreps_manual]
assert len(irreps_manual) == num_convolution_layer + 1
except Exception:
raise RuntimeError('invalid irreps_manual input given')
conv_denominator = config[KEY.CONV_DENOMINATOR]
if not isinstance(conv_denominator, list):
conv_denominator = [conv_denominator] * num_convolution_layer
train_conv_denominator = config[KEY.TRAIN_DENOMINTAOR]
edge_embedding = init_edge_embedding(config)
irreps_filter = edge_embedding.spherical.irreps_out
radial_basis_num = edge_embedding.basis_function.num_basis
layers.update({'edge_embedding': edge_embedding})
one_hot_irreps = Irreps(f'{num_species}x0e')
irreps_x = (
Irreps(f'{feature_multiplicity}x0e')
if irreps_manual is None
else irreps_manual[0]
)
layers.update(
{
'onehot_idx_to_onehot': OnehotEmbedding(
num_classes=num_species,
data_key_x=KEY.NODE_FEATURE,
data_key_out=KEY.NODE_FEATURE,
data_key_save=KEY.ATOM_TYPE, # atomic numbers
data_key_additional=KEY.NODE_ATTR, # one-hot embeddings
),
'onehot_to_feature_x': IrrepsLinear(
irreps_in=one_hot_irreps,
irreps_out=irreps_x,
data_key_in=KEY.NODE_FEATURE,
biases=use_bias_in_linear,
),
}
)
weight_nn_hidden = config[KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS]
weight_nn_layers = [radial_basis_num] + weight_nn_hidden
param_interaction_block = {
'irreps_filter': irreps_filter,
'weight_nn_layers': weight_nn_layers,
'train_conv_denominator': train_conv_denominator,
'act_radial': act_radial,
'bias_in_linear': use_bias_in_linear,
'num_species': num_species,
'parallel': parallel,
}
interaction_builder = None
if interaction_type in ['nequip']:
act_scalar = {}
act_gate = {}
for k, v in config[KEY.ACTIVATION_SCARLAR].items():
act_scalar[k] = _const.ACTIVATION_DICT[k][v]
for k, v in config[KEY.ACTIVATION_GATE].items():
act_gate[k] = _const.ACTIVATION_DICT[k][v]
param_interaction_block.update(
{
'act_scalar': act_scalar,
'act_gate': act_gate,
}
)
if interaction_type == 'nequip':
interaction_builder = NequIP_interaction_block
else:
raise ValueError(f'Unknown interaction type: {interaction_type}')
for t in range(num_convolution_layer):
param_interaction_block.update(
{
'irreps_x': irreps_x,
't': t,
'conv_denominator': conv_denominator[t],
'self_connection_pair': self_connection_pair_list[t],
}
)
if interaction_type == 'nequip':
parity_mode = 'full'
fix_multiplicity = False
if t == num_convolution_layer - 1:
lmax_node = 0
parity_mode = 'even'
# TODO: irreps_manual is applicable to both irreps_out_tp and irreps_out
irreps_out = (
util.infer_irreps_out(
irreps_x, # type: ignore
irreps_filter,
lmax_node, # type: ignore
parity_mode,
fix_multiplicity=feature_multiplicity,
)
if irreps_manual is None
else irreps_manual[t + 1]
)
irreps_out_tp = util.infer_irreps_out(
irreps_x, # type: ignore
irreps_filter,
irreps_out.lmax, # type: ignore
parity_mode,
fix_multiplicity,
)
else:
raise ValueError(f'Unknown interaction type: {interaction_type}')
param_interaction_block.update(
{
'irreps_out_tp': irreps_out_tp,
'irreps_out': irreps_out,
}
)
layers.update(interaction_builder(**param_interaction_block))
irreps_x = irreps_out
layers.update(init_feature_reduce(config, irreps_x))
layers.update(
{
'rescale_atomic_energy': init_shift_scale(config),
'reduce_total_enegy': AtomReduce(  # sic; renaming would change saved state_dict keys
data_key_in=KEY.ATOMIC_ENERGY,
data_key_out=KEY.PRED_TOTAL_ENERGY,
),
}
)
gradient_module = ForceStressOutputFromEdge()
grad_key = gradient_module.get_grad_key()
layers.update({'force_output': gradient_module})
common_args = {
'cutoff': cutoff,
'type_map': config[KEY.TYPE_MAP],
'modal_map': config.get(KEY.MODAL_MAP, None),
'eval_type_map': not parallel,
'eval_modal_map': config.get(KEY.USE_MODALITY, False) and not parallel,
'data_key_grad': grad_key,
}
if parallel:
layers_list = _to_parallel_model(layers, config)
return [
AtomGraphSequential(patch_modules(layers, config), **common_args)
for layers in layers_list
]
else:
return AtomGraphSequential(patch_modules(layers, config), **common_args)
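# Hedged usage sketch; a complete `config` normally comes from a checkpoint
# or from parse_input.read_config_yaml, not hand-written:
# >>> model = build_E3_equivariant_model(config)            # serial model
# >>> segments = build_E3_equivariant_model(config, True)   # list, for LAMMPS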
import glob
import os
import warnings
from typing import Any, Callable, Dict
import torch
import yaml
import sevenn._const as _const
import sevenn._keys as KEY
import sevenn.util as util
def config_initialize(
key: str,
config: Dict,
default: Any,
conditions: Dict,
):
# default value exists & no user input -> return default
if key not in config.keys():
return default
# no validation method exists => accept user input
user_input = config[key]
if key in conditions:
condition = conditions[key]
else:
return user_input
if type(default) is dict and isinstance(condition, dict):
for i_key, val in default.items():
user_input[i_key] = config_initialize(
i_key, user_input, val, condition
)
return user_input
elif isinstance(condition, type):
if isinstance(user_input, condition):
return user_input
else:
try:
return condition(user_input) # try type casting
except ValueError:
raise ValueError(
f"Expect '{user_input}' for '{key}' is {condition}"
)
elif isinstance(condition, Callable) and condition(user_input):
return user_input
else:
raise ValueError(
f"Given input '{user_input}' for '{key}' is not valid"
)
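# Behavior sketch for the type-casting branch above (illustrative values):
# >>> config_initialize('epoch', {'epoch': '300'}, 300, {'epoch': int})
# 300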
def init_model_config(config: Dict):
# defaults = _const.model_defaults(config)
model_meta = {}
# init complicated ones
if KEY.CHEMICAL_SPECIES not in config.keys():
raise ValueError('required key chemical_species does not exist')
input_chem = config[KEY.CHEMICAL_SPECIES]
if isinstance(input_chem, str) and input_chem.lower() == 'auto':
model_meta[KEY.CHEMICAL_SPECIES] = 'auto'
model_meta[KEY.NUM_SPECIES] = 'auto'
model_meta[KEY.TYPE_MAP] = 'auto'
elif isinstance(input_chem, str) and 'univ' in input_chem.lower():
model_meta.update(util.chemical_species_preprocess([], universal=True))
else:
if isinstance(input_chem, list) and all(
isinstance(x, str) for x in input_chem
):
pass
elif isinstance(input_chem, str):
input_chem = (
input_chem.replace('-', ',').replace(' ', ',').split(',')
)
input_chem = [chem for chem in input_chem if len(chem) != 0]
else:
raise ValueError(f'given {KEY.CHEMICAL_SPECIES} input is invalid')
model_meta.update(util.chemical_species_preprocess(input_chem))
# deprecation warnings
if KEY.AVG_NUM_NEIGH in config:
warnings.warn(
"key 'avg_num_neigh' is deprecated. Please use 'conv_denominator'."
' We use the default, the average number of neighbors in the'
' dataset, if not provided.',
UserWarning,
)
config.pop(KEY.AVG_NUM_NEIGH)
if KEY.TRAIN_AVG_NUM_NEIGH in config:
warnings.warn(
"key 'train_avg_num_neigh' is deprecated. Please use"
" 'train_denominator'. We overwrite train_denominator as given"
' train_avg_num_neigh',
UserWarning,
)
config[KEY.TRAIN_DENOMINTAOR] = config[KEY.TRAIN_AVG_NUM_NEIGH]
config.pop(KEY.TRAIN_AVG_NUM_NEIGH)
if KEY.OPTIMIZE_BY_REDUCE in config:
warnings.warn(
"key 'optimize_by_reduce' is deprecated. Always true",
UserWarning,
)
config.pop(KEY.OPTIMIZE_BY_REDUCE)
# init simpler ones
for key, default in _const.DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG.items():
model_meta[key] = config_initialize(
key, config, default, _const.MODEL_CONFIG_CONDITION
)
unknown_keys = [
key for key in config.keys() if key not in model_meta.keys()
]
if len(unknown_keys) != 0:
warnings.warn(
f'Unexpected model keys: {unknown_keys} will be ignored',
UserWarning,
)
return model_meta
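# Accepted 'chemical_species' inputs, per the branching above:
#   'auto'                  -> species and type map resolved later from data
#   any string with 'univ'  -> universal mode (all elements known to ase)
#   ['Hf', 'O']             -> explicit list of symbols
#   'Hf, O' or 'Hf-O'       -> delimiter-separated string, split on ',', '-', ' '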
def init_train_config(config: Dict):
train_meta = {}
# defaults = _const.train_defaults(config)
try:
device_input = config[KEY.DEVICE]
train_meta[KEY.DEVICE] = torch.device(device_input)
except KeyError:
train_meta[KEY.DEVICE] = (
torch.device('cuda')
if torch.cuda.is_available()
else torch.device('cpu')
)
train_meta[KEY.DEVICE] = str(train_meta[KEY.DEVICE])
# init simpler ones
for key, default in _const.DEFAULT_TRAINING_CONFIG.items():
train_meta[key] = config_initialize(
key, config, default, _const.TRAINING_CONFIG_CONDITION
)
if KEY.CONTINUE in config.keys():
cnt_dct = config[KEY.CONTINUE]
if KEY.CHECKPOINT not in cnt_dct.keys():
raise ValueError('no checkpoint is given in continue')
checkpoint = cnt_dct[KEY.CHECKPOINT]
if os.path.isfile(checkpoint):
checkpoint_file = checkpoint
else:
checkpoint_file = util.pretrained_name_to_path(checkpoint)
train_meta[KEY.CONTINUE].update({KEY.CHECKPOINT: checkpoint_file})
unknown_keys = [
key for key in config.keys() if key not in train_meta.keys()
]
if len(unknown_keys) != 0:
warnings.warn(
f'Unexpected train keys: {unknown_keys} will be ignored',
UserWarning,
)
return train_meta
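# Sketch of a 'continue' block (hypothetical values): the checkpoint may be a
# file path or a pretrained model name resolved via util.pretrained_name_to_path:
#   continue:
#     checkpoint: '7net-0'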
def init_data_config(config: Dict):
    data_meta = {}
    # defaults = _const.data_defaults(config)
    # collect dataset path keys of the form load_*_path
    load_data_keys = [
        k for k in config if k.startswith('load_') and k.endswith('_path')
    ]
    for load_data_key in load_data_keys:
        inp = config[load_data_key]
        extended = []
        if not isinstance(inp, (str, list)):
            raise ValueError(f'unexpected input {inp} for {load_data_key}')
        if isinstance(inp, str):
            extended = glob.glob(inp)
        else:
            for i in inp:
                if isinstance(i, str):
                    extended.extend(glob.glob(i))
                elif isinstance(i, dict):
                    extended.append(i)
        if len(extended) == 0:
            raise ValueError(
                f'Cannot find {inp} for {load_data_key}, or path is not given'
            )
        data_meta[load_data_key] = extended
for key, default in _const.DEFAULT_DATA_CONFIG.items():
data_meta[key] = config_initialize(
key, config, default, _const.DATA_CONFIG_CONDITION
)
unknown_keys = [
key for key in config.keys() if key not in data_meta.keys()
]
if len(unknown_keys) != 0:
warnings.warn(
f'Unexpected data keys: {unknown_keys} will be ignored',
UserWarning,
)
return data_meta
def read_config_yaml(filename: str, return_separately: bool = False):
with open(filename, 'r') as fstream:
inputs = yaml.safe_load(fstream)
model_meta, train_meta, data_meta = {}, {}, {}
for key, config in inputs.items():
if key == 'model':
model_meta = init_model_config(config)
elif key == 'train':
train_meta = init_train_config(config)
elif key == 'data':
data_meta = init_data_config(config)
        else:
            raise ValueError(f"Unexpected top-level key '{key}' given")
if return_separately:
return model_meta, train_meta, data_meta
else:
model_meta.update(train_meta)
model_meta.update(data_meta)
return model_meta
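# A minimal input.yaml sketch (illustrative; 'load_dataset_path' stands in for
# any load_*_path key, and the values are placeholders):
#   model:
#     chemical_species: 'auto'
#   train:
#     device: 'cpu'
#   data:
#     load_dataset_path: ['./data/*.sevenn_data']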
def main():
filename = './input.yaml'
read_config_yaml(filename)
if __name__ == '__main__':
main()
import warnings
from .logger import * # noqa: F403
warnings.warn('Please use sevenn.logger instead of sevenn.sevenn_logger',
DeprecationWarning, stacklevel=2)
import warnings
from .calculator import * # noqa: F403
warnings.warn('Please use sevenn.calculator instead of sevenn.sevennet_calculator',
DeprecationWarning, stacklevel=2)
import os
import os.path as osp
import pathlib
import shutil
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import requests
import torch
import torch.nn
from e3nn.o3 import FullTensorProduct, Irreps
from tqdm import tqdm
import sevenn._const as _const
import sevenn._keys as KEY
def to_atom_graph_list(atom_graph_batch):
"""
torch_geometric batched data to separate list
original to_data_list() by PyG is not enough since
it doesn't handle inferred tensors
"""
is_stress = KEY.PRED_STRESS in atom_graph_batch
data_list = atom_graph_batch.to_data_list()
indices = atom_graph_batch[KEY.NUM_ATOMS].tolist()
atomic_energy_list = torch.split(atom_graph_batch[KEY.ATOMIC_ENERGY], indices)
inferred_total_energy_list = torch.unbind(
atom_graph_batch[KEY.PRED_TOTAL_ENERGY]
)
inferred_force_list = torch.split(atom_graph_batch[KEY.PRED_FORCE], indices)
inferred_stress_list = None
if is_stress:
inferred_stress_list = torch.unbind(atom_graph_batch[KEY.PRED_STRESS])
for i, data in enumerate(data_list):
data[KEY.ATOMIC_ENERGY] = atomic_energy_list[i]
data[KEY.PRED_TOTAL_ENERGY] = inferred_total_energy_list[i]
data[KEY.PRED_FORCE] = inferred_force_list[i]
# To fit with KEY.STRESS (ref) format
if is_stress and inferred_stress_list is not None:
data[KEY.PRED_STRESS] = torch.unsqueeze(inferred_stress_list[i], 0)
return data_list
def error_recorder_from_loss_functions(loss_functions):
from .error_recorder import ErrorRecorder, MAError, RMSError, get_err_type
from .train.loss import ForceLoss, PerAtomEnergyLoss, StressLoss
metrics = []
for loss_function, _ in loss_functions:
ref_key = loss_function.ref_key
pred_key = loss_function.pred_key
# unit = loss_function.unit
criterion = loss_function.criterion
name = loss_function.name
base = None
if type(loss_function) is PerAtomEnergyLoss:
base = get_err_type('Energy')
elif type(loss_function) is ForceLoss:
base = get_err_type('Force')
elif type(loss_function) is StressLoss:
base = get_err_type('Stress')
else:
base = {}
base['name'] = name
base['ref_key'] = ref_key
base['pred_key'] = pred_key
if type(criterion) is torch.nn.MSELoss:
base['name'] = base['name'] + '_RMSE'
metrics.append(RMSError(**base))
elif type(criterion) is torch.nn.L1Loss:
metrics.append(MAError(**base))
return ErrorRecorder(metrics)
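# Sketch: for loss_functions like [(PerAtomEnergyLoss(...), w1),
# (ForceLoss(...), w2)] with MSE criteria, this yields RMSError metrics named
# '<loss name>_RMSE'; an L1Loss criterion maps to MAError instead.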
def onehot_to_chem(one_hot_indices: List[int], type_map: Dict[int, int]):
from ase.data import chemical_symbols
type_map_rev = {v: k for k, v in type_map.items()}
return [chemical_symbols[type_map_rev[x]] for x in one_hot_indices]
def model_from_checkpoint(
checkpoint: str,
) -> Tuple[torch.nn.Module, Dict]:
cp = load_checkpoint(checkpoint)
model = cp.build_model()
return model, cp.config
def model_from_checkpoint_with_backend(
checkpoint: str,
backend: str = 'e3nn',
) -> Tuple[torch.nn.Module, Dict]:
cp = load_checkpoint(checkpoint)
model = cp.build_model(backend)
return model, cp.config
def unlabeled_atoms_to_input(atoms, cutoff: float, grad_key: str = KEY.EDGE_VEC):
from .atom_graph_data import AtomGraphData
from .train.dataload import unlabeled_atoms_to_graph
atom_graph = AtomGraphData.from_numpy_dict(
unlabeled_atoms_to_graph(atoms, cutoff)
)
atom_graph[grad_key].requires_grad_(True)
atom_graph[KEY.BATCH] = torch.zeros([0])
return atom_graph
def chemical_species_preprocess(input_chem: List[str], universal: bool = False):
from ase.data import atomic_numbers, chemical_symbols
from .nn.node_embedding import get_type_mapper_from_specie
config = {}
if not universal:
input_chem = list(set(input_chem))
chemical_specie = sorted([x.strip() for x in input_chem])
config[KEY.CHEMICAL_SPECIES] = chemical_specie
config[KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER] = [
atomic_numbers[x] for x in chemical_specie
]
config[KEY.NUM_SPECIES] = len(chemical_specie)
config[KEY.TYPE_MAP] = get_type_mapper_from_specie(chemical_specie)
else:
config[KEY.CHEMICAL_SPECIES] = chemical_symbols
len_univ = len(chemical_symbols)
config[KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER] = list(range(len_univ))
config[KEY.NUM_SPECIES] = len_univ
config[KEY.TYPE_MAP] = {z: z for z in range(len_univ)}
return config
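# Usage sketch: chemical_species_preprocess(['O', 'Hf', 'O']) deduplicates and
# sorts to ['Hf', 'O'], so NUM_SPECIES is 2 and the atomic numbers are
# [72, 8]; the exact TYPE_MAP is delegated to get_type_mapper_from_specie.
# With universal=True, TYPE_MAP is the identity over atomic numbers.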
def dtype_correct(
v: Union[np.ndarray, torch.Tensor, int, float],
float_dtype: torch.dtype = torch.float32,
int_dtype: torch.dtype = torch.int64,
):
    if isinstance(v, np.ndarray):
        if np.issubdtype(v.dtype, np.floating):
            return torch.from_numpy(v).to(float_dtype)
        elif np.issubdtype(v.dtype, np.integer):
            return torch.from_numpy(v).to(int_dtype)
        # other dtypes (e.g. bool) pass through unchanged, mirroring the
        # non-numeric scalar case below
        return v
elif isinstance(v, torch.Tensor):
if v.dtype.is_floating_point:
return v.to(float_dtype) # convert to specified float dtype
else: # assuming non-floating point tensors are integers
return v.to(int_dtype) # convert to specified int dtype
else: # scalar values
if isinstance(v, int):
return torch.tensor(v, dtype=int_dtype)
elif isinstance(v, float):
return torch.tensor(v, dtype=float_dtype)
else: # Not numeric
return v
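# e.g. dtype_correct(np.arange(3)) -> int64 tensor,
#      dtype_correct(1.0)          -> float32 scalar tensor,
#      dtype_correct('keep')       -> returned unchanged (non-numeric).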
def infer_irreps_out(
irreps_x: Irreps,
irreps_operand: Irreps,
drop_l: Union[bool, int] = False,
parity_mode: str = 'full',
fix_multiplicity: Union[bool, int] = False,
):
assert parity_mode in ['full', 'even', 'sph']
# (mul, (ir, p))
irreps_out = FullTensorProduct(irreps_x, irreps_operand).irreps_out.simplify()
new_irreps_elem = []
for mul, (l, p) in irreps_out: # noqa
elem = (mul, (l, p))
if drop_l is not False and l > drop_l:
continue
if parity_mode == 'even' and p == -1:
continue
elif parity_mode == 'sph' and p != (-1) ** l:
continue
if fix_multiplicity:
elem = (fix_multiplicity, (l, p))
new_irreps_elem.append(elem)
return Irreps(new_irreps_elem)
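# Illustrative example (multiplicities follow e3nn's FullTensorProduct):
#   >>> infer_irreps_out(Irreps('1x0e+1x1o'), Irreps('1x0e+1x1o'),
#   ...                  drop_l=1, parity_mode='sph')
#   2x0e+2x1o
# the 1x1e and 1x2e products are dropped (wrong parity, or l above drop_l).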
def download_checkpoint(path: str, url: str):
fname = osp.basename(path)
temp_path = path + '.partial'
try:
        # os.makedirs raises PermissionError if it fails
        os.makedirs(osp.dirname(path), exist_ok=True)
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status() # Raise exception for bad status codes
total_size = int(response.headers.get('content-length', 0))
block_size = 1024 # 1 KB chunks
progress_bar = tqdm(
total=total_size,
unit='B',
unit_scale=True,
desc=f'Downloading {fname}',
)
with open(temp_path, 'wb') as file:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
file.write(data)
progress_bar.close()
shutil.move(temp_path, path)
print(f'Checkpoint downloaded: {path}')
return path
except PermissionError:
raise
    except Exception as e:
        # Clean up partial downloads on failure
        # (cleanup may be incomplete if the error was swallowed internally,
        # e.g. by tqdm)
        print(f'Download failed: {str(e)}')
if os.path.exists(temp_path):
print(f'Cleaning up partial download: {temp_path}')
os.remove(temp_path)
raise
def pretrained_name_to_path(name: str) -> str:
name = name.lower()
heads = ['sevennet', '7net']
checkpoint_path = None
url = None
if ( # TODO: regex
name in [f'{n}-0_11july2024' for n in heads]
or name in [f'{n}-0_11jul2024' for n in heads]
or name in ['sevennet-0', '7net-0']
):
checkpoint_path = _const.SEVENNET_0_11Jul2024
elif name in [f'{n}-0_22may2024' for n in heads]:
checkpoint_path = _const.SEVENNET_0_22May2024
elif name in [f'{n}-l3i5' for n in heads]:
checkpoint_path = _const.SEVENNET_l3i5
elif name in [f'{n}-mf-0' for n in heads]:
checkpoint_path = _const.SEVENNET_MF_0
elif name in [f'{n}-mf-ompa' for n in heads]:
checkpoint_path = _const.SEVENNET_MF_ompa
elif name in [f'{n}-omat' for n in heads]:
checkpoint_path = _const.SEVENNET_omat
else:
raise ValueError('Not a valid pretrained model name')
url = _const.CHECKPOINT_DOWNLOAD_LINKS.get(checkpoint_path)
paths = [
checkpoint_path,
checkpoint_path.replace(_const._prefix, osp.expanduser('~/.cache/sevennet')),
]
for path in paths:
if osp.exists(path):
return path
    # file not found; check the url and try to download
if url is None:
raise FileNotFoundError(checkpoint_path)
try:
return download_checkpoint(paths[0], url) # 7net package path
except PermissionError:
return download_checkpoint(paths[1], url) # ~/.cache
def load_checkpoint(checkpoint: Union[pathlib.Path, str]):
from sevenn.checkpoint import SevenNetCheckpoint
    suggests = ['7net-0', '7net-l3i5', '7net-mf-ompa', '7net-omat']
if osp.isfile(checkpoint):
checkpoint_path = checkpoint
else:
try:
checkpoint_path = pretrained_name_to_path(str(checkpoint))
except ValueError:
            raise ValueError(
                f'Given {checkpoint} does not exist and is not a '
                f'pretrained model name.\n'
                f'Valid pretrained model names: {suggests}'
            )
return SevenNetCheckpoint(checkpoint_path)
def unique_filepath(filepath: str) -> str:
if not os.path.isfile(filepath):
return filepath
else:
dirname = os.path.dirname(filepath)
fname = os.path.basename(filepath)
name, ext = os.path.splitext(fname)
cnt = 0
new_name = f'{name}{cnt}{ext}'
new_path = os.path.join(dirname, new_name)
while os.path.exists(new_path):
cnt += 1
new_name = f'{name}{cnt}{ext}'
new_path = os.path.join(dirname, new_name)
return new_path
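# e.g. if 'log.txt' already exists, this returns 'log0.txt' (then 'log1.txt',
# and so on), appending the first unused counter before the extension.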
def get_error_recorder(
    recorder_tuples: Optional[List[Tuple[str, str]]] = None,
):
    # TODO add criterion argument and loss recorder selections
    import sevenn.error_recorder as error_recorder
    # avoid a mutable default argument; fall back to RMSE and MAE for
    # energy, force, and stress
    if recorder_tuples is None:
        recorder_tuples = [
            ('Energy', 'RMSE'),
            ('Force', 'RMSE'),
            ('Stress', 'RMSE'),
            ('Energy', 'MAE'),
            ('Force', 'MAE'),
            ('Stress', 'MAE'),
        ]
    config = recorder_tuples
    err_metrics = []
for err_type, metric_name in config:
metric_kwargs = error_recorder.get_err_type(err_type).copy()
metric_kwargs['name'] += f'_{metric_name}'
metric_cls = error_recorder.ErrorRecorder.METRIC_DICT[metric_name]
err_metrics.append(metric_cls(**metric_kwargs))
return error_recorder.ErrorRecorder(err_metrics)
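# Usage sketch: get_error_recorder([('Energy', 'MAE')]) builds an
# ErrorRecorder with a single metric, named from the 'Energy' base plus the
# '_MAE' suffix.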