nivren / ICT-CSP · Commits

Commit 73866b01 (unverified)
Authored Aug 24, 2025 by zcxzcx1; committed by GitHub, Aug 24, 2025
Parent: ca86f720
Message: Add files via upload

Changes: 18 files, +3985 additions, -0 deletions
mace-bench/3rdparty/SevenNet/sevenn/__init__.py              +13   -0
mace-bench/3rdparty/SevenNet/sevenn/_const.py                +310  -0
mace-bench/3rdparty/SevenNet/sevenn/_keys.py                 +226  -0
mace-bench/3rdparty/SevenNet/sevenn/atom_graph_data.py       +75   -0
mace-bench/3rdparty/SevenNet/sevenn/build.ninja              +33   -0
mace-bench/3rdparty/SevenNet/sevenn/calculator.py            +846  -0
mace-bench/3rdparty/SevenNet/sevenn/checkpoint.py            +552  -0
mace-bench/3rdparty/SevenNet/sevenn/error_recorder.py        +430  -0
mace-bench/3rdparty/SevenNet/sevenn/logger.py                +336  -0
mace-bench/3rdparty/SevenNet/sevenn/logo_ascii               +20   -0
mace-bench/3rdparty/SevenNet/sevenn/model_build.py           +556  -0
mace-bench/3rdparty/SevenNet/sevenn/pair_d3.so               +0    -0
mace-bench/3rdparty/SevenNet/sevenn/pair_d3_for_ase.cuda.o   +0    -0
mace-bench/3rdparty/SevenNet/sevenn/parse_input.py           +246  -0
mace-bench/3rdparty/SevenNet/sevenn/py.typed                 +0    -0
mace-bench/3rdparty/SevenNet/sevenn/sevenn_logger.py         +6    -0
mace-bench/3rdparty/SevenNet/sevenn/sevennet_calculator.py   +6    -0
mace-bench/3rdparty/SevenNet/sevenn/util.py                  +330  -0
mace-bench/3rdparty/SevenNet/sevenn/__init__.py (new file, mode 100644)
from importlib.metadata import version

from packaging.version import Version

__version__ = version('sevenn')

from e3nn import __version__ as e3nn_ver

if Version(e3nn_ver) < Version('0.5.0'):
    raise ValueError(
        'The e3nn version MUST be 0.5.0 or later due to changes in CG coefficient '
        'convention.'
    )
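
As a quick aside on the guard above: packaging.version.Version implements PEP 440 ordering, so pre-releases sort below their final release. A minimal sketch with illustrative version strings (not from the diff):

from packaging.version import Version

# The guard in sevenn/__init__.py relies on this ordering.
assert Version('0.5.1') >= Version('0.5.0')    # new enough, import succeeds
assert Version('0.4.4') < Version('0.5.0')     # too old, ValueError is raised
assert Version('0.5.0rc1') < Version('0.5.0')  # pre-releases order before finals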
mace-bench/3rdparty/SevenNet/sevenn/_const.py (new file, mode 100644)
import os
from enum import Enum
from typing import Dict

import torch

import sevenn._keys as KEY
from sevenn.nn.activation import ShiftedSoftPlus

NUM_UNIV_ELEMENT = 119  # Z = 0 ~ 118

IMPLEMENTED_RADIAL_BASIS = ['bessel']
IMPLEMENTED_CUTOFF_FUNCTION = ['poly_cut', 'XPLOR']
# TODO: support None. This became difficult because of the parallel model
IMPLEMENTED_SELF_CONNECTION_TYPE = ['nequip', 'linear']
IMPLEMENTED_INTERACTION_TYPE = ['nequip']

IMPLEMENTED_SHIFT = ['per_atom_energy_mean', 'elemwise_reference_energies']
IMPLEMENTED_SCALE = ['force_rms', 'per_atom_energy_std', 'elemwise_force_rms']

SUPPORTING_METRICS = ['RMSE', 'ComponentRMSE', 'MAE', 'Loss']
SUPPORTING_ERROR_TYPES = [
    'TotalEnergy',
    'Energy',
    'Force',
    'Stress',
    'Stress_GPa',
    'TotalLoss',
]

IMPLEMENTED_MODEL = ['E3_equivariant_model']

# string input to real torch function
ACTIVATION = {
    'relu': torch.nn.functional.relu,
    'silu': torch.nn.functional.silu,
    'tanh': torch.tanh,
    'abs': torch.abs,
    'ssp': ShiftedSoftPlus,
    'sigmoid': torch.sigmoid,
    'elu': torch.nn.functional.elu,
}
ACTIVATION_FOR_EVEN = {
    'ssp': ShiftedSoftPlus,
    'silu': torch.nn.functional.silu,
}
ACTIVATION_FOR_ODD = {'tanh': torch.tanh, 'abs': torch.abs}
ACTIVATION_DICT = {'e': ACTIVATION_FOR_EVEN, 'o': ACTIVATION_FOR_ODD}

_prefix = os.path.abspath(f'{os.path.dirname(__file__)}/pretrained_potentials')
SEVENNET_0_11Jul2024 = f'{_prefix}/SevenNet_0__11Jul2024/checkpoint_sevennet_0.pth'
SEVENNET_0_22May2024 = f'{_prefix}/SevenNet_0__22May2024/checkpoint_sevennet_0.pth'
SEVENNET_l3i5 = f'{_prefix}/SevenNet_l3i5/checkpoint_l3i5.pth'
SEVENNET_MF_0 = f'{_prefix}/SevenNet_MF_0/checkpoint_sevennet_mf_0.pth'
SEVENNET_MF_ompa = f'{_prefix}/SevenNet_MF_ompa/checkpoint_sevennet_mf_ompa.pth'
SEVENNET_omat = f'{_prefix}/SevenNet_omat/checkpoint_sevennet_omat.pth'

_git_prefix = 'https://github.com/MDIL-SNU/SevenNet/releases/download'
CHECKPOINT_DOWNLOAD_LINKS = {
    SEVENNET_MF_ompa: f'{_git_prefix}/v0.11.0.cp/checkpoint_sevennet_mf_ompa.pth',
    SEVENNET_omat: f'{_git_prefix}/v0.11.0.cp/checkpoint_sevennet_omat.pth',
}

# to keep TorchScript from compiling torch_geometric.data
AtomGraphDataType = Dict[str, torch.Tensor]


class LossType(Enum):
    # only used for train_v1, do not use it afterwards
    ENERGY = 'energy'  # eV or eV/atom
    FORCE = 'force'  # eV/A
    STRESS = 'stress'  # kB


def error_record_condition(x):
    if type(x) is not list:
        return False
    for v in x:
        if type(v) is not list or len(v) != 2:
            return False
        if v[0] not in SUPPORTING_ERROR_TYPES:
            return False
        if v[0] == 'TotalLoss':
            continue
        if v[1] not in SUPPORTING_METRICS:
            return False
    return True
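

# A quick illustration (editor's sketch, not part of the upstream file) of
# what error_record_condition accepts; the first list mirrors the default
# KEY.ERROR_RECORD value defined further below.
#
#     ok = [['Energy', 'RMSE'], ['Force', 'RMSE'], ['TotalLoss', 'None']]
#     assert error_record_condition(ok)
#     assert not error_record_condition([['Energy', 'R2']])  # unknown metric
#     assert not error_record_condition([['Energy']])        # entries must be pairs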
DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG = {
    KEY.CUTOFF: 4.5,
    KEY.NODE_FEATURE_MULTIPLICITY: 32,
    KEY.IRREPS_MANUAL: False,
    KEY.LMAX: 1,
    KEY.LMAX_EDGE: -1,  # -1 means lmax_edge = lmax
    KEY.LMAX_NODE: -1,  # -1 means lmax_node = lmax
    KEY.IS_PARITY: True,
    KEY.NUM_CONVOLUTION: 3,
    KEY.RADIAL_BASIS: {
        KEY.RADIAL_BASIS_NAME: 'bessel',
    },
    KEY.CUTOFF_FUNCTION: {
        KEY.CUTOFF_FUNCTION_NAME: 'poly_cut',
    },
    KEY.ACTIVATION_RADIAL: 'silu',
    KEY.ACTIVATION_SCARLAR: {'e': 'silu', 'o': 'tanh'},
    KEY.ACTIVATION_GATE: {'e': 'silu', 'o': 'tanh'},
    KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS: [64, 64],
    # KEY.AVG_NUM_NEIGH: True,  # deprecated
    # KEY.TRAIN_AVG_NUM_NEIGH: False,  # deprecated
    KEY.CONV_DENOMINATOR: 'avg_num_neigh',
    KEY.TRAIN_DENOMINTAOR: False,
    KEY.TRAIN_SHIFT_SCALE: False,
    # KEY.OPTIMIZE_BY_REDUCE: True,  # deprecated, always True
    KEY.USE_BIAS_IN_LINEAR: False,
    KEY.USE_MODAL_NODE_EMBEDDING: False,
    KEY.USE_MODAL_SELF_INTER_INTRO: False,
    KEY.USE_MODAL_SELF_INTER_OUTRO: False,
    KEY.USE_MODAL_OUTPUT_BLOCK: False,
    KEY.READOUT_AS_FCN: False,
    # Applied if readout_as_fcn is True
    KEY.READOUT_FCN_HIDDEN_NEURONS: [30, 30],
    KEY.READOUT_FCN_ACTIVATION: 'relu',
    KEY.SELF_CONNECTION_TYPE: 'nequip',
    KEY.INTERACTION_TYPE: 'nequip',
    KEY._NORMALIZE_SPH: True,
    KEY.CUEQUIVARIANCE_CONFIG: {},
}

# Basically, "If provided, it should be type of ..."
MODEL_CONFIG_CONDITION = {
    KEY.NODE_FEATURE_MULTIPLICITY: int,
    KEY.LMAX: int,
    KEY.LMAX_EDGE: int,
    KEY.LMAX_NODE: int,
    KEY.IS_PARITY: bool,
    KEY.RADIAL_BASIS: {
        KEY.RADIAL_BASIS_NAME: lambda x: x in IMPLEMENTED_RADIAL_BASIS,
    },
    KEY.CUTOFF_FUNCTION: {
        KEY.CUTOFF_FUNCTION_NAME: lambda x: x in IMPLEMENTED_CUTOFF_FUNCTION,
    },
    KEY.CUTOFF: float,
    KEY.NUM_CONVOLUTION: int,
    KEY.CONV_DENOMINATOR: lambda x: isinstance(x, float)
    or x
    in [
        'avg_num_neigh',
        'sqrt_avg_num_neigh',
    ],
    KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS: list,
    KEY.TRAIN_SHIFT_SCALE: bool,
    KEY.TRAIN_DENOMINTAOR: bool,
    KEY.USE_BIAS_IN_LINEAR: bool,
    KEY.USE_MODAL_NODE_EMBEDDING: bool,
    KEY.USE_MODAL_SELF_INTER_INTRO: bool,
    KEY.USE_MODAL_SELF_INTER_OUTRO: bool,
    KEY.USE_MODAL_OUTPUT_BLOCK: bool,
    KEY.READOUT_AS_FCN: bool,
    KEY.READOUT_FCN_HIDDEN_NEURONS: list,
    KEY.READOUT_FCN_ACTIVATION: str,
    KEY.ACTIVATION_RADIAL: str,
    KEY.SELF_CONNECTION_TYPE: lambda x: (
        x in IMPLEMENTED_SELF_CONNECTION_TYPE
        or (
            isinstance(x, list)
            and all(sc in IMPLEMENTED_SELF_CONNECTION_TYPE for sc in x)
        )
    ),
    KEY.INTERACTION_TYPE: lambda x: x in IMPLEMENTED_INTERACTION_TYPE,
    KEY._NORMALIZE_SPH: bool,
    KEY.CUEQUIVARIANCE_CONFIG: dict,
}


def model_defaults(config):
    defaults = DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG
    if KEY.READOUT_AS_FCN not in config:
        config[KEY.READOUT_AS_FCN] = defaults[KEY.READOUT_AS_FCN]
    if config[KEY.READOUT_AS_FCN] is False:
        defaults.pop(KEY.READOUT_FCN_ACTIVATION, None)
        defaults.pop(KEY.READOUT_FCN_HIDDEN_NEURONS, None)
    return defaults


DEFAULT_DATA_CONFIG = {
    KEY.DTYPE: 'single',
    KEY.DATA_FORMAT: 'ase',
    KEY.DATA_FORMAT_ARGS: {},
    KEY.SAVE_DATASET: False,
    KEY.SAVE_BY_LABEL: False,
    KEY.SAVE_BY_TRAIN_VALID: False,
    KEY.RATIO: 0.0,
    KEY.BATCH_SIZE: 6,
    KEY.PREPROCESS_NUM_CORES: 1,
    KEY.COMPUTE_STATISTICS: True,
    KEY.DATASET_TYPE: 'graph',
    # KEY.USE_SPECIES_WISE_SHIFT_SCALE: False,
    KEY.USE_MODAL_WISE_SHIFT: False,
    KEY.USE_MODAL_WISE_SCALE: False,
    KEY.SHIFT: 'per_atom_energy_mean',
    KEY.SCALE: 'force_rms',
    # KEY.DATA_SHUFFLE: True,
    # KEY.DATA_WEIGHT: False,
    # KEY.DATA_MODALITY: False,
}

DATA_CONFIG_CONDITION = {
    KEY.DTYPE: str,
    KEY.DATA_FORMAT: str,
    KEY.DATA_FORMAT_ARGS: dict,
    KEY.SAVE_DATASET: str,
    KEY.SAVE_BY_LABEL: bool,
    KEY.SAVE_BY_TRAIN_VALID: bool,
    KEY.RATIO: float,
    KEY.BATCH_SIZE: int,
    KEY.PREPROCESS_NUM_CORES: int,
    KEY.DATASET_TYPE: lambda x: x in ['graph', 'atoms'],
    # KEY.USE_SPECIES_WISE_SHIFT_SCALE: bool,
    KEY.SHIFT: lambda x: type(x) in [float, list] or x in IMPLEMENTED_SHIFT,
    KEY.SCALE: lambda x: type(x) in [float, list] or x in IMPLEMENTED_SCALE,
    KEY.USE_MODAL_WISE_SHIFT: bool,
    KEY.USE_MODAL_WISE_SCALE: bool,
    # KEY.DATA_SHUFFLE: bool,
    KEY.COMPUTE_STATISTICS: bool,
    # KEY.DATA_WEIGHT: bool,
    # KEY.DATA_MODALITY: bool,
}


def data_defaults(config):
    defaults = DEFAULT_DATA_CONFIG
    if KEY.LOAD_VALIDSET in config:
        defaults.pop(KEY.RATIO, None)
    return defaults


DEFAULT_TRAINING_CONFIG = {
    KEY.RANDOM_SEED: 1,
    KEY.EPOCH: 300,
    KEY.LOSS: 'mse',
    KEY.LOSS_PARAM: {},
    KEY.OPTIMIZER: 'adam',
    KEY.OPTIM_PARAM: {},
    KEY.SCHEDULER: 'exponentiallr',
    KEY.SCHEDULER_PARAM: {},
    KEY.FORCE_WEIGHT: 0.1,
    KEY.STRESS_WEIGHT: 1e-6,  # SIMPLE-NN default
    KEY.PER_EPOCH: 5,
    # KEY.USE_TESTSET: False,
    KEY.CONTINUE: {
        KEY.CHECKPOINT: False,
        KEY.RESET_OPTIMIZER: False,
        KEY.RESET_SCHEDULER: False,
        KEY.RESET_EPOCH: False,
        KEY.USE_STATISTIC_VALUES_OF_CHECKPOINT: True,
        KEY.USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY: True,
    },
    # KEY.DEFAULT_MODAL: 'common',
    KEY.CSV_LOG: 'log.csv',
    KEY.NUM_WORKERS: 0,
    KEY.IS_TRAIN_STRESS: True,
    KEY.TRAIN_SHUFFLE: True,
    KEY.ERROR_RECORD: [
        ['Energy', 'RMSE'],
        ['Force', 'RMSE'],
        ['Stress', 'RMSE'],
        ['TotalLoss', 'None'],
    ],
    KEY.BEST_METRIC: 'TotalLoss',
    KEY.USE_WEIGHT: False,
    KEY.USE_MODALITY: False,
}

TRAINING_CONFIG_CONDITION = {
    KEY.RANDOM_SEED: int,
    KEY.EPOCH: int,
    KEY.FORCE_WEIGHT: float,
    KEY.STRESS_WEIGHT: float,
    KEY.USE_TESTSET: None,  # Not used
    KEY.NUM_WORKERS: int,
    KEY.PER_EPOCH: int,
    KEY.CONTINUE: {
        KEY.CHECKPOINT: str,
        KEY.RESET_OPTIMIZER: bool,
        KEY.RESET_SCHEDULER: bool,
        KEY.RESET_EPOCH: bool,
        KEY.USE_STATISTIC_VALUES_OF_CHECKPOINT: bool,
        KEY.USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY: bool,
    },
    KEY.DEFAULT_MODAL: str,
    KEY.IS_TRAIN_STRESS: bool,
    KEY.TRAIN_SHUFFLE: bool,
    KEY.ERROR_RECORD: error_record_condition,
    KEY.BEST_METRIC: str,
    KEY.CSV_LOG: str,
    KEY.USE_MODALITY: bool,
    KEY.USE_WEIGHT: bool,
}


def train_defaults(config):
    defaults = DEFAULT_TRAINING_CONFIG
    if KEY.IS_TRAIN_STRESS not in config:
        config[KEY.IS_TRAIN_STRESS] = defaults[KEY.IS_TRAIN_STRESS]
    if not config[KEY.IS_TRAIN_STRESS]:
        defaults.pop(KEY.STRESS_WEIGHT, None)
    return defaults
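
To make the *_defaults helpers concrete, a small usage sketch (editor's illustration, not from the diff; it assumes the common pattern of merging the returned defaults under the user config):

user_cfg = {'epoch': 50, 'is_train_stress': False}
defaults = train_defaults(user_cfg)
assert 'stress_loss_weight' not in defaults  # popped because stress training is off
merged = {**defaults, **user_cfg}
assert merged['epoch'] == 50
assert merged['force_loss_weight'] == 0.1    # untouched default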
mace-bench/3rdparty/SevenNet/sevenn/_keys.py (new file, mode 100644)
"""
How to add new feature?
1. Add new key to this file.
2. Add new key to _const.py
2.1. if the type of input is consistent,
write adequate condition and default to _const.py.
2.2. if the type of input is not consistent,
you must add your own input validation code to
parse_input.py
"""
from
typing
import
Final
# see
# https://github.com/pytorch/pytorch/issues/52312
# for FYI
# ~~ keys ~~ #
# PyG : primitive key of torch_geometric.data.Data type
# ==================================================#
# ~~~~~~~~~~~~~~~~~ KEY for data ~~~~~~~~~~~~~~~~~~ #
# ==================================================#
# some raw properties of graph
ATOMIC_NUMBERS
:
Final
[
str
]
=
'atomic_numbers'
# (N)
POS
:
Final
[
str
]
=
'pos'
# (N, 3) PyG
CELL
:
Final
[
str
]
=
'cell_lattice_vectors'
# (3, 3)
CELL_SHIFT
:
Final
[
str
]
=
'pbc_shift'
# (N, 3)
CELL_VOLUME
:
Final
[
str
]
=
'cell_volume'
EDGE_VEC
:
Final
[
str
]
=
'edge_vec'
# (N_edge, 3)
EDGE_LENGTH
:
Final
[
str
]
=
'edge_length'
# (N_edge, 1)
# some primary data of graph
EDGE_IDX
:
Final
[
str
]
=
'edge_index'
# (2, N_edge) PyG
ATOM_TYPE
:
Final
[
str
]
=
'atom_type'
# (N) one-hot index of nodes
NODE_FEATURE
:
Final
[
str
]
=
'x'
# (N, ?) PyG
NODE_FEATURE_GHOST
:
Final
[
str
]
=
'x_ghost'
NODE_ATTR
:
Final
[
str
]
=
'node_attr'
# (N, N_species) from one_hot
MODAL_ATTR
:
Final
[
str
]
=
(
'modal_attr'
# (1, N_modalities) for handling multi-modal
)
MODAL_TYPE
:
Final
[
str
]
=
'modal_type'
# (1) one-hot index of modal
EDGE_ATTR
:
Final
[
str
]
=
'edge_attr'
# (from spherical harmonics)
EDGE_EMBEDDING
:
Final
[
str
]
=
'edge_embedding'
# (from edge embedding)
# inputs of loss function
ENERGY
:
Final
[
str
]
=
'total_energy'
# (1)
FORCE
:
Final
[
str
]
=
'force_of_atoms'
# (N, 3)
STRESS
:
Final
[
str
]
=
'stress'
# (6)
# This is for training, per atom scale.
SCALED_ENERGY
:
Final
[
str
]
=
'scaled_total_energy'
# general outputs of models
SCALED_ATOMIC_ENERGY
:
Final
[
str
]
=
'scaled_atomic_energy'
ATOMIC_ENERGY
:
Final
[
str
]
=
'atomic_energy'
PRED_TOTAL_ENERGY
:
Final
[
str
]
=
'inferred_total_energy'
PRED_PER_ATOM_ENERGY
:
Final
[
str
]
=
'inferred_per_atom_energy'
PER_ATOM_ENERGY
:
Final
[
str
]
=
'per_atom_energy'
PRED_FORCE
:
Final
[
str
]
=
'inferred_force'
SCALED_FORCE
:
Final
[
str
]
=
'scaled_force'
PRED_STRESS
:
Final
[
str
]
=
'inferred_stress'
SCALED_STRESS
:
Final
[
str
]
=
'scaled_stress'
# very general data property for AtomGraphData
NUM_ATOMS
:
Final
[
str
]
=
'num_atoms'
# int
NUM_GHOSTS
:
Final
[
str
]
=
'num_ghosts'
NLOCAL
:
Final
[
str
]
=
'nlocal'
# only for lammps parallel, must be on cpu
USER_LABEL
:
Final
[
str
]
=
'user_label'
DATA_WEIGHT
:
Final
[
str
]
=
'data_weight'
# weight for given data
DATA_MODALITY
:
Final
[
str
]
=
(
'data_modality'
# modality of given data. e.g. PBE and SCAN
)
BATCH
:
Final
[
str
]
=
'batch'
TAG
=
'tag'
# replace USER_LABEL
# etc
SELF_CONNECTION_TEMP
:
Final
[
str
]
=
'self_cont_tmp'
BATCH_SIZE
:
Final
[
str
]
=
'batch_size'
INFO
:
Final
[
str
]
=
'data_info'
# something special
LABEL_NONE
:
Final
[
str
]
=
'No_label'
# ==================================================#
# ~~~~~~ KEY for train/data configuration ~~~~~~~~ #
# ==================================================#
PREPROCESS_NUM_CORES
=
'preprocess_num_cores'
SAVE_DATASET
=
'save_dataset_path'
SAVE_BY_LABEL
=
'save_by_label'
SAVE_BY_TRAIN_VALID
=
'save_by_train_valid'
DATA_FORMAT
=
'data_format'
DATA_FORMAT_ARGS
=
'data_format_args'
STRUCTURE_LIST
=
'structure_list'
LOAD_DATASET
=
'load_dataset_path'
# not used in v2
LOAD_TRAINSET
=
'load_trainset_path'
LOAD_VALIDSET
=
'load_validset_path'
LOAD_TESTSET
=
'load_testset_path'
FORMAT_OUTPUTS
=
'format_outputs_for_ase'
COMPUTE_STATISTICS
=
'compute_statistics'
DATASET_TYPE
=
'dataset_type'
RANDOM_SEED
=
'random_seed'
RATIO
=
'data_divide_ratio'
USE_TESTSET
=
'use_testset'
EPOCH
=
'epoch'
LOSS
=
'loss'
LOSS_PARAM
=
'loss_param'
OPTIMIZER
=
'optimizer'
OPTIM_PARAM
=
'optim_param'
SCHEDULER
=
'scheduler'
SCHEDULER_PARAM
=
'scheduler_param'
FORCE_WEIGHT
=
'force_loss_weight'
STRESS_WEIGHT
=
'stress_loss_weight'
DEVICE
=
'device'
DTYPE
=
'dtype'
TRAIN_SHUFFLE
=
'train_shuffle'
IS_TRAIN_STRESS
=
'is_train_stress'
CONTINUE
=
'continue'
CHECKPOINT
=
'checkpoint'
RESET_OPTIMIZER
=
'reset_optimizer'
RESET_SCHEDULER
=
'reset_scheduler'
RESET_EPOCH
=
'reset_epoch'
USE_STATISTIC_VALUES_OF_CHECKPOINT
=
'use_statistic_values_of_checkpoint'
USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY
=
(
'use_statistic_values_for_cp_modal_only'
)
CSV_LOG
=
'csv_log'
ERROR_RECORD
=
'error_record'
BEST_METRIC
=
'best_metric'
NUM_WORKERS
=
'num_workers'
# not work
RANK
=
'rank'
LOCAL_RANK
=
'local_rank'
WORLD_SIZE
=
'world_size'
IS_DDP
=
'is_ddp'
DDP_BACKEND
=
'ddp_backend'
PER_EPOCH
=
'per_epoch'
USE_WEIGHT
=
'use_weight'
USE_MODALITY
=
'use_modality'
DEFAULT_MODAL
=
'default_modal'
# ==================================================#
# ~~~~~~~~ KEY for model configuration ~~~~~~~~~~~ #
# ==================================================#
# ~~ global model configuration ~~ #
# note that these names are directly used for input.yaml for user input
MODEL_TYPE
=
'_model_type'
CUTOFF
=
'cutoff'
CHEMICAL_SPECIES
=
'chemical_species'
MODAL_LIST
=
'modal_list'
CHEMICAL_SPECIES_BY_ATOMIC_NUMBER
=
'_chemical_species_by_atomic_number'
NUM_SPECIES
=
'_number_of_species'
NUM_MODALITIES
=
'_number_of_modalities'
TYPE_MAP
=
'_type_map'
MODAL_MAP
=
'_modal_map'
# ~~ E3 equivariant model build configuration keys ~~ #
# see model_build default_config for type
IRREPS_MANUAL
=
'irreps_manual'
NODE_FEATURE_MULTIPLICITY
=
'channel'
RADIAL_BASIS
=
'radial_basis'
BESSEL_BASIS_NUM
=
'bessel_basis_num'
CUTOFF_FUNCTION
=
'cutoff_function'
POLY_CUT_P
=
'poly_cut_p_value'
LMAX
=
'lmax'
LMAX_EDGE
=
'lmax_edge'
LMAX_NODE
=
'lmax_node'
IS_PARITY
=
'is_parity'
CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS
=
'weight_nn_hidden_neurons'
NUM_CONVOLUTION
=
'num_convolution_layer'
ACTIVATION_SCARLAR
=
'act_scalar'
ACTIVATION_GATE
=
'act_gate'
ACTIVATION_RADIAL
=
'act_radial'
SELF_CONNECTION_TYPE
=
'self_connection_type'
RADIAL_BASIS_NAME
=
'radial_basis_name'
CUTOFF_FUNCTION_NAME
=
'cutoff_function_name'
USE_BIAS_IN_LINEAR
=
'use_bias_in_linear'
USE_MODAL_NODE_EMBEDDING
=
'use_modal_node_embedding'
USE_MODAL_SELF_INTER_INTRO
=
'use_modal_self_inter_intro'
USE_MODAL_SELF_INTER_OUTRO
=
'use_modal_self_inter_outro'
USE_MODAL_OUTPUT_BLOCK
=
'use_modal_output_block'
READOUT_AS_FCN
=
'readout_as_fcn'
READOUT_FCN_HIDDEN_NEURONS
=
'readout_fcn_hidden_neurons'
READOUT_FCN_ACTIVATION
=
'readout_fcn_activation'
AVG_NUM_NEIGH
=
'avg_num_neigh'
CONV_DENOMINATOR
=
'conv_denominator'
SHIFT
=
'shift'
SCALE
=
'scale'
USE_SPECIES_WISE_SHIFT_SCALE
=
'use_species_wise_shift_scale'
USE_MODAL_WISE_SHIFT
=
'use_modal_wise_shift'
USE_MODAL_WISE_SCALE
=
'use_modal_wise_scale'
TRAIN_SHIFT_SCALE
=
'train_shift_scale'
TRAIN_DENOMINTAOR
=
'train_denominator'
INTERACTION_TYPE
=
'interaction_type'
TRAIN_AVG_NUM_NEIGH
=
'train_avg_num_neigh'
# deprecated
CUEQUIVARIANCE_CONFIG
=
'cuequivariance_config'
_NORMALIZE_SPH
=
'_normalize_sph'
OPTIMIZE_BY_REDUCE
=
'optimize_by_reduce'
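
The constants above are plain strings: models and data pipelines index TorchScript-friendly Dict[str, Tensor] graphs with them. A minimal sketch with hand-made tensors (illustrative, not from the diff):

import torch

import sevenn._keys as KEY

graph = {
    KEY.POS: torch.zeros(4, 3),                          # (N, 3) positions
    KEY.ATOMIC_NUMBERS: torch.tensor([8, 1, 1, 8]),      # (N,)
    KEY.EDGE_IDX: torch.zeros(2, 6, dtype=torch.int64),  # (2, N_edge)
}
assert KEY.POS == 'pos'
print(graph[KEY.POS].shape)  # torch.Size([4, 3])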
mace-bench/3rdparty/SevenNet/sevenn/atom_graph_data.py (new file, mode 100644)
from typing import Optional

import torch
import torch_geometric.data

import sevenn._keys as KEY
import sevenn.util


class AtomGraphData(torch_geometric.data.Data):
    """
    Args:
        x (Tensor, optional): atomic numbers with shape :obj:`[num_nodes,
            atomic_numbers]`. (default: :obj:`None`)
        edge_index (LongTensor, optional): Graph connectivity in coordinate
            format with shape :obj:`[2, num_edges]`. (default: :obj:`None`)
        edge_attr (Tensor, optional): Edge feature matrix with shape
            :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`)
        y_energy: scalar  # unit of eV (VASP raw)
        y_force: [num_nodes, 3]  # unit of eV/A (VASP raw)
        y_stress: [6]  # [xx, yy, zz, xy, yz, zx]  # unit of eV/A^3 (VASP raw)
        pos (Tensor, optional): Node position matrix with shape
            :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`)
        **kwargs (optional): Additional attributes.

    x, y_force, pos should be aligned with each other.
    """

    def __init__(
        self,
        x: Optional[torch.Tensor] = None,
        edge_index: Optional[torch.Tensor] = None,
        pos: Optional[torch.Tensor] = None,
        edge_attr: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        super(AtomGraphData, self).__init__(x, edge_index, edge_attr, pos=pos)
        self[KEY.NODE_ATTR] = x  # ?
        for k, v in kwargs.items():
            self[k] = v

    def to_numpy_dict(self):
        # This is not debugged yet!
        dct = {
            k: v.detach().cpu().numpy() if type(v) is torch.Tensor else v
            for k, v in self.items()
        }
        return dct

    def fit_dimension(self):
        per_atom_keys = [
            KEY.ATOMIC_NUMBERS,
            KEY.ATOMIC_ENERGY,
            KEY.POS,
            KEY.FORCE,
            KEY.PRED_FORCE,
        ]
        natoms = self.num_atoms.item()
        for k, v in self.items():
            if not isinstance(v, torch.Tensor):
                continue
            if natoms == 1 and k in per_atom_keys:
                self[k] = v.squeeze().unsqueeze(0)
            else:
                self[k] = v.squeeze()
        return self

    @staticmethod
    def from_numpy_dict(dct):
        for k, v in dct.items():
            if k == KEY.CELL_SHIFT:
                dct[k] = torch.Tensor(v)  # this is special
            else:
                dct[k] = sevenn.util.dtype_correct(v)
        return AtomGraphData(**dct)
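
A minimal construction sketch (editor's illustration with placeholder shapes; real graphs come from sevenn.train.dataload):

import torch

data = AtomGraphData(
    x=torch.tensor([0, 1]),                           # per-atom type indices
    edge_index=torch.zeros(2, 1, dtype=torch.int64),  # (2, N_edge)
    pos=torch.zeros(2, 3),
    num_atoms=torch.tensor([2]),                      # extra kwargs become item keys
)
print(data['pos'].shape, data.num_atoms.item())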
mace-bench/3rdparty/SevenNet/sevenn/build.ninja (new file, mode 100644)
ninja_required_version = 1.3
cxx = c++
nvcc = /usr/local/cuda/bin/nvcc
cflags = -DTORCH_EXTENSION_NAME=pair_d3 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include/TH -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17
post_cflags =
cuda_cflags = -DTORCH_EXTENSION_NAME=pair_d3 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include/TH -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_89,code=sm_89 -gencode=arch=compute_90,code=sm_90 --compiler-options '-fPIC' -O3 --expt-relaxed-constexpr -fmad=false -std=c++17
cuda_post_cflags =
cuda_dlink_post_cflags =
ldflags = -shared -L/home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda -ltorch -ltorch_python -L/usr/local/cuda/lib64 -lcudart
rule compile
command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags
depfile = $out.d
deps = gcc
rule cuda_compile
depfile = $out.d
deps = gcc
command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out $cuda_post_cflags
rule link
command = $cxx $in $ldflags -o $out
build pair_d3_for_ase.cuda.o: cuda_compile /home/mazhaojia/mace-project/mace-bench/3rdparty/SevenNet/sevenn/pair_e3gnn/pair_d3_for_ase.cu
build pair_d3.so: link pair_d3_for_ase.cuda.o
default pair_d3.so
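
This build.ninja is the file torch's JIT extension builder writes out; a sketch of the generating call (it mirrors `_load` in calculator.py below, with the source path relative to the sevenn package):

from torch.utils.cpp_extension import load

load(
    name='pair_d3',
    sources=['pair_e3gnn/pair_d3_for_ase.cu'],
    extra_cuda_cflags=['-O3', '--expt-relaxed-constexpr', '-fmad=false'],
    verbose=True,
    is_python_module=False,  # emits pair_d3.so for ctypes, not a Python module
)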
mace-bench/3rdparty/SevenNet/sevenn/calculator.py (new file, mode 100644)
import ctypes
import logging
import os
import pathlib
import warnings
from typing import Any, Dict, Optional, Union

import numpy as np
import torch
import torch.jit
import torch.jit._script
from ase.calculators.calculator import Calculator, all_changes
from ase.calculators.mixing import SumCalculator
from ase.data import chemical_symbols

import sevenn._keys as KEY
import sevenn.util as util
from sevenn.atom_graph_data import AtomGraphData
from sevenn.nn.sequential import AtomGraphSequential
from sevenn.train.dataload import unlabeled_atoms_to_graph

torch_script_type = torch.jit._script.RecursiveScriptModule


class SevenNetCalculator(Calculator):
    """Supporting properties:
    'free_energy', 'energy', 'forces', 'stress', 'energies'
    free_energy equals energy. 'energies' stores atomic energy.

    Multi-GPU acceleration is not supported with the ASE calculator.
    You should use LAMMPS for the acceleration.
    """

    def __init__(
        self,
        model: Union[str, pathlib.PurePath, AtomGraphSequential] = '7net-0',
        file_type: str = 'checkpoint',
        device: Union[torch.device, str] = 'auto',
        modal: Optional[str] = None,
        enable_cueq: bool = False,
        sevennet_config: Optional[Dict] = None,  # not used in logic, just meta info
        **kwargs,
    ):
        """Initialize SevenNetCalculator.

        Parameters
        ----------
        model: str | Path | AtomGraphSequential, default='7net-0'
            Name of a pretrained model (7net-mf-ompa, 7net-omat, 7net-l3i5,
            7net-0), a path to a checkpoint or deployed model, or the model itself
        file_type: str, default='checkpoint'
            one of 'checkpoint' | 'torchscript' | 'model_instance'
        device: str | torch.device, default='auto'
            if not given, use CUDA if available
        modal: str | None, default=None
            modal (fidelity) if the given model is a multi-modal model. For
            7net-mf-ompa, it should be one of 'mpa' (MPtrj + sAlex) or
            'omat24' (OMat24), case insensitive
        enable_cueq: bool, default=False
            if True, use cuEquivariance to accelerate inference.
        sevennet_config: dict | None, default=None
            Not used, but can be used to carry meta information of this calculator
        """
        print('&&& Initializing SevenNetCalculator')
        super().__init__(**kwargs)
        self.sevennet_config = None

        if isinstance(model, pathlib.PurePath):
            model = str(model)

        allowed_file_types = ['checkpoint', 'torchscript', 'model_instance']
        file_type = file_type.lower()
        if file_type not in allowed_file_types:
            raise ValueError(f'file_type not in {allowed_file_types}')

        if enable_cueq and file_type in ['model_instance', 'torchscript']:
            warnings.warn(
                'file_type should be checkpoint to enable cueq. cueq set to False'
            )
            enable_cueq = False

        if isinstance(device, str):  # TODO: do we really need this?
            if device == 'auto':
                self.device = torch.device(
                    'cuda' if torch.cuda.is_available() else 'cpu'
                )
            else:
                self.device = torch.device(device)
        else:
            self.device = device

        if file_type == 'checkpoint' and isinstance(model, str):
            cp = util.load_checkpoint(model)
            backend = 'e3nn' if not enable_cueq else 'cueq'
            model_loaded = cp.build_model(backend)
            model_loaded.set_is_batch_data(False)
            self.type_map = cp.config[KEY.TYPE_MAP]
            self.cutoff = cp.config[KEY.CUTOFF]
            self.sevennet_config = cp.config
        elif file_type == 'torchscript' and isinstance(model, str):
            if modal:
                raise NotImplementedError()
            extra_dict = {
                'chemical_symbols_to_index': b'',
                'cutoff': b'',
                'num_species': b'',
                'model_type': b'',
                'version': b'',
                'dtype': b'',
                'time': b'',
            }
            model_loaded = torch.jit.load(
                model, _extra_files=extra_dict, map_location=self.device
            )
            chem_symbols = extra_dict['chemical_symbols_to_index'].decode('utf-8')
            sym_to_num = {sym: n for n, sym in enumerate(chemical_symbols)}
            self.type_map = {
                sym_to_num[sym]: i for i, sym in enumerate(chem_symbols.split())
            }
            self.cutoff = float(extra_dict['cutoff'].decode('utf-8'))
        elif isinstance(model, AtomGraphSequential):
            if model.type_map is None:
                raise ValueError(
                    'Model must have the type_map to be used with calculator'
                )
            if model.cutoff == 0.0:
                raise ValueError('Model cutoff seems not initialized')
            model.eval_type_map = torch.tensor(True)  # ?
            model.set_is_batch_data(False)
            model_loaded = model
            self.type_map = model.type_map
            self.cutoff = model.cutoff
        else:
            raise ValueError('Unexpected input combinations')

        if self.sevennet_config is None and sevennet_config is not None:
            self.sevennet_config = sevennet_config

        self.model = model_loaded

        self.modal = None
        if isinstance(self.model, AtomGraphSequential):
            modal_map = self.model.modal_map
            if modal_map:
                modal_ava = list(modal_map.keys())
                if not modal:
                    raise ValueError(f'modal argument missing (avail: {modal_ava})')
                elif modal not in modal_ava:
                    raise ValueError(f'unknown modal {modal} (not in {modal_ava})')
                self.modal = modal
            elif not self.model.modal_map and modal:
                warnings.warn(f'modal={modal} is ignored as model has no modal_map')

        self.model.to(self.device)
        self.model.eval()

        self.implemented_properties = [
            'free_energy',
            'energy',
            'forces',
            'stress',
            'energies',
        ]

    def set_atoms(self, atoms):
        # called by ase, when atoms.calc = calc
        zs = tuple(set(atoms.get_atomic_numbers()))
        for z in zs:
            if z not in self.type_map:
                sp = list(self.type_map.keys())
                raise ValueError(
                    f'Model does not know atomic number: {z}, (knows: {sp})'
                )

    def output_to_results(self, output):
        energy = output[KEY.PRED_TOTAL_ENERGY].detach().cpu().item()
        num_atoms = output['num_atoms'].item()
        atomic_energies = (
            output[KEY.ATOMIC_ENERGY].detach().cpu().numpy().flatten()
        )
        forces = output[KEY.PRED_FORCE].detach().cpu().numpy()[:num_atoms, :]
        stress = np.array(
            (-output[KEY.PRED_STRESS])
            .detach()
            .cpu()
            .numpy()[[0, 1, 2, 4, 5, 3]]  # as voigt notation
        )
        # Store results
        return {
            'free_energy': energy,
            'energy': energy,
            'energies': atomic_energies,
            'forces': forces,
            'stress': stress,
            'num_edges': output[KEY.EDGE_IDX].shape[1],
        }

    def calculate(self, atoms=None, properties=None, system_changes=all_changes):
        # call parent class to set necessary atom attributes
        Calculator.calculate(self, atoms, properties, system_changes)
        if atoms is None:
            raise ValueError('No atoms to evaluate')
        data = AtomGraphData.from_numpy_dict(
            unlabeled_atoms_to_graph(atoms, self.cutoff)
        )
        if self.modal:
            data[KEY.DATA_MODALITY] = self.modal

        data.to(self.device)  # type: ignore

        if isinstance(self.model, torch_script_type):
            data[KEY.NODE_FEATURE] = torch.tensor(
                [self.type_map[z.item()] for z in data[KEY.NODE_FEATURE]],
                dtype=torch.int64,
                device=self.device,
            )
            data[KEY.POS].requires_grad_(True)  # backward compatibility
            data[KEY.EDGE_VEC].requires_grad_(True)  # backward compatibility
            data = data.to_dict()
            del data['data_info']

        logging.debug(f'data: {data}')
        # logging.debug(f"data[pos]: {data['pos']}")
        # logging.debug(f"data[x]: {data['x']}")
        logging.debug(f"data[cell_lattice_vectors]: {data['cell_lattice_vectors']}")
        logging.debug(f"data[cell_volume]: {data['cell_volume']}")

        output = self.model(data)
        # logging.info(f'input: {data}')
        # logging.info(f'output[{KEY.PRED_TOTAL_ENERGY}] = {output[KEY.PRED_TOTAL_ENERGY]}')
        # logging.info(f'output[{KEY.PRED_FORCE}] = {output[KEY.PRED_FORCE]}')
        # logging.info(f'output[{KEY.PRED_STRESS}] = {output[KEY.PRED_STRESS]}')
        self.results = self.output_to_results(output)
        # logging.debug(f"results['energy'] = {self.results['energy']}")
        # logging.debug(f"results['forces'] = {self.results['forces']}")
        # logging.debug(f"results['stress'] = {self.results['stress']}")

    def predict_one(self, atoms):
        if atoms is None:
            raise ValueError('No atoms to evaluate')
        data = AtomGraphData.from_numpy_dict(
            unlabeled_atoms_to_graph(atoms, self.cutoff)
        )
        if self.modal:
            data[KEY.DATA_MODALITY] = self.modal

        data.to(self.device)  # type: ignore

        if isinstance(self.model, torch_script_type):
            data[KEY.NODE_FEATURE] = torch.tensor(
                [self.type_map[z.item()] for z in data[KEY.NODE_FEATURE]],
                dtype=torch.int64,
                device=self.device,
            )
            data[KEY.POS].requires_grad_(True)  # backward compatibility
            data[KEY.EDGE_VEC].requires_grad_(True)  # backward compatibility
            data = data.to_dict()
            del data['data_info']

        return self.model(data)

    def predict(self, atoms_list, properties=None):
        # if len(atoms_list) == 1:
        #     output = self.predict_one(atoms_list[0])
        #     predictions = {}
        #     predictions['energy'] = output[KEY.PRED_TOTAL_ENERGY].to(torch.float64).unsqueeze(0)
        #     predictions['forces'] = output[KEY.PRED_FORCE].to(torch.float64).unsqueeze(0)
        #     voigt = (-output[KEY.PRED_STRESS])[[0, 1, 2, 4, 5, 3]].to(torch.float64).unsqueeze(0)
        #     stress_list = []
        #     for i in range(voigt.shape[0]):
        #         stress_list.append(self._stress2tensor(voigt[i, :]))
        #     predictions['stress'] = torch.stack(stress_list, dim=0).view(-1, 3, 3)
        #     return predictions
        if not atoms_list:
            raise ValueError('Empty atoms_list provided')
        if not isinstance(atoms_list, list):
            atoms_list = [atoms_list]

        # Convert atoms to graph data
        graph_list = []
        for atoms in atoms_list:
            data = AtomGraphData.from_numpy_dict(
                unlabeled_atoms_to_graph(atoms, self.cutoff)
            )
            if self.modal:
                data[KEY.DATA_MODALITY] = self.modal
            if isinstance(self.model, torch_script_type):
                data[KEY.NODE_FEATURE] = torch.tensor(
                    [self.type_map[z.item()] for z in data[KEY.NODE_FEATURE]],
                    dtype=torch.int64,
                    device=self.device,
                )
            data[KEY.POS].requires_grad_(True)  # backward compatibility
            data[KEY.EDGE_VEC].requires_grad_(True)  # backward compatibility
            graph_list.append(data)

        # Process graphs based on model type
        # was_batch_mode = True
        if isinstance(self.model, AtomGraphSequential):
            # was_batch_mode = self.model.is_batch_data
            self.model.set_is_batch_data(True)
            self.model.eval()

        # Batch the data if there are multiple atoms
        from torch_geometric.loader.dataloader import Collater

        batched_data = Collater(graph_list)(graph_list)
        batched_data = batched_data.to(self.device)

        logging.debug(f'batched_data: {batched_data}')
        # logging.debug(f"batched_data[pos]: {batched_data['pos']}")
        # logging.debug(f"batched_data[x]: {batched_data['x']}")
        logging.debug(
            f"batched_data[cell_lattice_vectors]: "
            f"{batched_data['cell_lattice_vectors']}"
        )
        logging.debug(f"batched_data[cell_volume]: {batched_data['cell_volume']}")

        # Run model on batched data
        if isinstance(self.model, torch_script_type):
            batched_dict = batched_data.to_dict()
            if 'data_info' in batched_dict:
                del batched_dict['data_info']
            output = self.model(batched_dict)
        else:
            output = self.model(batched_data)

        # Convert to list of individual outputs using util.to_atom_graph_list
        # logging.info(f'input: {batched_data}')
        # logging.info(f'output[{KEY.PRED_TOTAL_ENERGY}] = {output[KEY.PRED_TOTAL_ENERGY]}')
        # logging.info(f'output[{KEY.PRED_FORCE}] = {output[KEY.PRED_FORCE]}')
        # logging.info(f'output[{KEY.PRED_STRESS}] = {output[KEY.PRED_STRESS]}')
        predictions = {}
        predictions['energy'] = (
            output[KEY.PRED_TOTAL_ENERGY].to(torch.float64).detach()
        )
        predictions['forces'] = output[KEY.PRED_FORCE].to(torch.float64).detach()
        voigt = (
            (-output[KEY.PRED_STRESS])[:, [0, 1, 2, 4, 5, 3]]
            .to(torch.float64)
            .detach()
        )
        stress_list = []
        for i in range(voigt.shape[0]):
            stress_list.append(self._stress2tensor(voigt[i, :]))
        predictions['stress'] = (
            torch.stack(stress_list, dim=0).view(-1, 3, 3).detach()
        )
        # logging.debug(f"predictions['energy'] = {predictions['energy']}")
        # logging.debug(f"predictions['forces'] = {predictions['forces']}")
        # logging.debug(f"predictions['stress'] = {predictions['stress']}")
        return predictions

    def _stress2tensor(self, stress):
        tensor = torch.tensor(
            [
                [stress[0], stress[5], stress[4]],
                [stress[5], stress[1], stress[3]],
                [stress[4], stress[3], stress[2]],
            ],
            device=self.device,
        )
        return tensor
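
    # Usage sketch for SevenNetCalculator (editor's illustration, not part of
    # the upstream file). Assumes the '7net-0' checkpoint can be resolved; the
    # device falls back to CPU when CUDA is absent.
    #
    #     from ase.build import bulk
    #     from sevenn.calculator import SevenNetCalculator
    #
    #     calc = SevenNetCalculator(model='7net-0', device='auto')
    #     atoms = bulk('Si', 'diamond', a=5.43)
    #     atoms.calc = calc
    #     print(atoms.get_potential_energy())  # eV
    #     print(atoms.get_forces().shape)      # (natoms, 3), eV/Angstrom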
class SevenNetD3Calculator(SumCalculator):
    def __init__(
        self,
        model: Union[str, pathlib.PurePath, AtomGraphSequential] = '7net-0',
        file_type: str = 'checkpoint',
        device: Union[torch.device, str] = 'auto',
        sevennet_config: Optional[Any] = None,  # hold meta information
        damping_type: str = 'damp_bj',
        functional_name: str = 'pbe',
        vdw_cutoff: float = 9000,  # au^2, 0.52917726 angstrom = 1 au
        cn_cutoff: float = 1600,  # au^2, 0.52917726 angstrom = 1 au
        batch_size=10,
        **kwargs,
    ):
        """Initialize SevenNetD3Calculator. CUDA required.

        Parameters
        ----------
        model: str | Path | AtomGraphSequential
            Name of a pretrained model (7net-mf-ompa, 7net-omat, 7net-l3i5,
            7net-0), a path to a checkpoint or deployed model, or the model itself
        file_type: str, default='checkpoint'
            one of 'checkpoint' | 'torchscript' | 'model_instance'
        device: str | torch.device, default='auto'
            if not given, use CUDA if available
        damping_type: str, default='damp_bj'
            Damping type of D3, one of 'damp_bj' | 'damp_zero'
        functional_name: str, default='pbe'
            Target functional name of D3 parameters.
        vdw_cutoff: float, default=9000
            vdW cutoff of the D3 calculator in au
        cn_cutoff: float, default=1600
            CN cutoff of the D3 calculator in au
        """
        self.d3_calc = D3Calculator(
            damping_type=damping_type,
            functional_name=functional_name,
            vdw_cutoff=vdw_cutoff,
            cn_cutoff=cn_cutoff,
            **kwargs,
        )
        self.sevennet_calc = SevenNetCalculator(
            model=model,
            file_type=file_type,
            device=device,
            sevennet_config=sevennet_config,
            **kwargs,
        )
        super().__init__([self.sevennet_calc, self.d3_calc])
        self.device = device
        self.d3_calcs = []
        for _ in range(batch_size):
            self.d3_calcs.append(
                D3Calculator(
                    damping_type=damping_type,
                    functional_name=functional_name,
                    vdw_cutoff=vdw_cutoff,
                    cn_cutoff=cn_cutoff,
                    **kwargs,
                )
            )

    def predict(self, atoms_list):
        """Predict the energy and forces for a list of atoms."""
        # Call the predict method of the first calculator (SevenNetCalculator)
        predictions = self.sevennet_calc.predict(atoms_list)
        energy_list = []
        forces_list = []
        stress_list = []
        predictions3d = {}
        for i, atoms in enumerate(atoms_list):
            prediction = self.d3_calcs[i].predict_one(atoms)
            energy_list.append(torch.tensor(prediction['energy']))
            forces_list.append(
                torch.from_numpy(prediction['forces']).to(self.device)
            )
            stress_list.append(
                self._stress2tensor(torch.from_numpy(prediction['stress']))
            )
        # Convert lists to tensors
        predictions3d['energy'] = torch.stack(energy_list, dim=0).to(self.device)
        predictions3d['forces'] = torch.cat(forces_list, dim=0).view(-1, 3)
        predictions3d['stress'] = torch.stack(stress_list, dim=0).view(-1, 3, 3)

        predictions['energy'] += predictions3d['energy'].detach()
        predictions['forces'] += predictions3d['forces'].detach()
        predictions['stress'] += predictions3d['stress'].detach()
        return predictions

    def _stress2tensor(self, stress):
        tensor = torch.tensor(
            [
                # [stress[0], stress[3], stress[4]],
                # [stress[3], stress[1], stress[5]],
                # [stress[4], stress[5], stress[2]],
                [stress[0], stress[5], stress[4]],
                [stress[5], stress[1], stress[3]],
                [stress[4], stress[3], stress[2]],
            ],
            device=self.device,
        )
        return tensor
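
    # Batched usage sketch for SevenNetD3Calculator (editor's illustration,
    # not part of the upstream file). D3Calculator requires CUDA; batch_size
    # bounds how many structures predict() can handle at once.
    #
    #     from ase.build import bulk
    #
    #     calc = SevenNetD3Calculator(model='7net-0', batch_size=2)
    #     preds = calc.predict([bulk('Si', 'diamond', a=5.43), bulk('Cu', 'fcc', a=3.6)])
    #     print(preds['energy'].shape, preds['forces'].shape, preds['stress'].shape)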
def _load(name: str) -> ctypes.CDLL:
    from torch.utils.cpp_extension import LIB_EXT, _get_build_directory, load

    # Load the library from the candidate locations
    package_dir = os.path.dirname(os.path.abspath(__file__))
    try:
        return ctypes.CDLL(os.path.join(package_dir, f'{name}{LIB_EXT}'))
    except OSError:
        pass

    cache_dir = _get_build_directory(name, verbose=False)
    try:
        return ctypes.CDLL(os.path.join(cache_dir, f'{name}{LIB_EXT}'))
    except OSError:
        pass

    # Compile the library if it is not found
    if os.access(package_dir, os.W_OK):
        compile_dir = package_dir
    else:
        print('Warning: package directory is not writable. Using cache directory.')
        compile_dir = cache_dir

    if 'TORCH_CUDA_ARCH_LIST' not in os.environ:
        print('Warning: TORCH_CUDA_ARCH_LIST is not set.')
        print('Warning: Use default CUDA architectures: 61, 70, 75, 80, 86, 89, 90')
        os.environ['TORCH_CUDA_ARCH_LIST'] = '6.1;7.0;7.5;8.0;8.6;8.9;9.0'

    load(
        name=name,
        sources=[os.path.join(package_dir, 'pair_e3gnn', 'pair_d3_for_ase.cu')],
        extra_cuda_cflags=['-O3', '--expt-relaxed-constexpr', '-fmad=false'],
        build_directory=compile_dir,
        verbose=True,
        is_python_module=False,
    )
    return ctypes.CDLL(os.path.join(compile_dir, f'{name}{LIB_EXT}'))


class PairD3(ctypes.Structure):
    pass  # Opaque structure; only used as a pointer
class D3Calculator(Calculator):
    """ASE calculator for accelerated D3 van der Waals (vdW) correction.

    Example:

        from ase.calculators.mixing import SumCalculator

        calc_1 = SevenNetCalculator()
        calc_2 = D3Calculator()
        return SumCalculator([calc_1, calc_2])

    This calculator interfaces with the `libpaird3.so` library,
    which is compiled by nvcc during the package installation.
    If you encounter any errors, please verify
    the installation process and the compilation options in `setup.py`.

    Note: Multi-GPU parallel MD is not supported in this mode.
    Note: Cffi could be used, but it was avoided to reduce dependencies.
    """

    # Here, free_energy = energy
    implemented_properties = ['free_energy', 'energy', 'forces', 'stress']

    def __init__(
        self,
        damping_type: str = 'damp_bj',  # damp_bj, damp_zero
        functional_name: str = 'pbe',  # check the source code
        vdw_cutoff: float = 9000,  # au^2, 0.52917726 angstrom = 1 au
        cn_cutoff: float = 1600,  # au^2, 0.52917726 angstrom = 1 au
        **kwargs,
    ):
        super().__init__(**kwargs)
        if not torch.cuda.is_available():
            raise NotImplementedError('CPU + D3 is not implemented yet')

        self.rthr = vdw_cutoff
        self.cnthr = cn_cutoff
        self.damp_name = damping_type.lower()
        self.func_name = functional_name.lower()

        if self.damp_name not in ['damp_bj', 'damp_zero']:
            raise ValueError('Error: Invalid damping type.')

        self._lib = _load('pair_d3')

        self._lib.pair_init.restype = ctypes.POINTER(PairD3)
        self.pair = self._lib.pair_init()

        self._lib.pair_set_atom.argtypes = [
            ctypes.POINTER(PairD3),           # PairD3* pair
            ctypes.c_int,                     # int natoms
            ctypes.c_int,                     # int ntypes
            ctypes.POINTER(ctypes.c_int),     # int* types
            ctypes.POINTER(ctypes.c_double),  # double* x
        ]
        self._lib.pair_set_atom.restype = None

        self._lib.pair_set_domain.argtypes = [
            ctypes.POINTER(PairD3),           # PairD3* pair
            ctypes.c_int,                     # int xperiodic
            ctypes.c_int,                     # int yperiodic
            ctypes.c_int,                     # int zperiodic
            ctypes.POINTER(ctypes.c_double),  # double* boxlo
            ctypes.POINTER(ctypes.c_double),  # double* boxhi
            ctypes.c_double,                  # double xy
            ctypes.c_double,                  # double xz
            ctypes.c_double,                  # double yz
        ]
        self._lib.pair_set_domain.restype = None

        self._lib.pair_run_settings.argtypes = [
            ctypes.POINTER(PairD3),  # PairD3* pair
            ctypes.c_double,         # double rthr
            ctypes.c_double,         # double cnthr
            ctypes.c_char_p,         # const char* damp_name
            ctypes.c_char_p,         # const char* func_name
        ]
        self._lib.pair_run_settings.restype = None

        self._lib.pair_run_coeff.argtypes = [
            ctypes.POINTER(PairD3),        # PairD3* pair
            ctypes.POINTER(ctypes.c_int),  # int* atomic_numbers
        ]
        self._lib.pair_run_coeff.restype = None

        self._lib.pair_run_compute.argtypes = [ctypes.POINTER(PairD3)]
        self._lib.pair_run_compute.restype = None

        self._lib.pair_get_energy.argtypes = [ctypes.POINTER(PairD3)]
        self._lib.pair_get_energy.restype = ctypes.c_double

        self._lib.pair_get_force.argtypes = [ctypes.POINTER(PairD3)]
        self._lib.pair_get_force.restype = ctypes.POINTER(ctypes.c_double)

        self._lib.pair_get_stress.argtypes = [ctypes.POINTER(PairD3)]
        self._lib.pair_get_stress.restype = ctypes.POINTER(ctypes.c_double * 6)

        self._lib.pair_fin.argtypes = [ctypes.POINTER(PairD3)]
        self._lib.pair_fin.restype = None

    def _idx_to_numbers(self, Z_of_atoms):
        unique_numbers = list(dict.fromkeys(Z_of_atoms))
        return unique_numbers

    def _idx_to_types(self, Z_of_atoms):
        unique_numbers = list(dict.fromkeys(Z_of_atoms))
        mapping = {num: idx + 1 for idx, num in enumerate(unique_numbers)}
        atom_types = [mapping[num] for num in Z_of_atoms]
        return atom_types

    def _convert_domain_ase2lammps(self, cell):
        qtrans, ltrans = np.linalg.qr(cell.T, mode='complete')
        lammps_cell = ltrans.T
        signs = np.sign(np.diag(lammps_cell))
        lammps_cell = lammps_cell * signs
        qtrans = qtrans * signs
        lammps_cell = lammps_cell[(0, 1, 2, 1, 2, 2), (0, 1, 2, 0, 0, 1)]
        rotator = qtrans.T
        return lammps_cell, rotator

    def _stress2tensor(self, stress):
        tensor = np.array(
            [
                [stress[0], stress[3], stress[4]],
                [stress[3], stress[1], stress[5]],
                [stress[4], stress[5], stress[2]],
            ]
        )
        return tensor

    def _tensor2stress(self, tensor):
        stress = -np.array(
            [
                tensor[0, 0],
                tensor[1, 1],
                tensor[2, 2],
                tensor[1, 2],
                tensor[0, 2],
                tensor[0, 1],
            ]
        )
        return stress

    def calculate(self, atoms=None, properties=None, system_changes=all_changes):
        Calculator.calculate(self, atoms, properties, system_changes)
        if atoms is None:
            raise ValueError('No atoms to evaluate')

        if atoms.get_cell().sum() == 0:
            print(
                'Warning: D3Calculator requires a cell.\n'
                'Warning: An orthogonal cell large enough is generated.'
            )
            positions = atoms.get_positions()
            min_pos = positions.min(axis=0)
            max_pos = positions.max(axis=0)
            max_cutoff = np.sqrt(max(self.rthr, self.cnthr)) * 0.52917726
            cell_lengths = max_pos - min_pos + max_cutoff + 1.0  # extra margin
            cell = np.eye(3) * cell_lengths
            atoms.set_cell(cell)
            atoms.set_pbc([True, True, True])  # for minus positions

        cell, rotator = self._convert_domain_ase2lammps(atoms.get_cell())

        Z_of_atoms = atoms.get_atomic_numbers()
        natoms = len(atoms)
        ntypes = len(set(Z_of_atoms))
        types = (ctypes.c_int * natoms)(*self._idx_to_types(Z_of_atoms))
        positions = atoms.get_positions() @ rotator.T
        x_flat = (ctypes.c_double * (natoms * 3))(*positions.flatten())
        atomic_numbers = (ctypes.c_int * ntypes)(*self._idx_to_numbers(Z_of_atoms))
        boxlo = (ctypes.c_double * 3)(0.0, 0.0, 0.0)
        boxhi = (ctypes.c_double * 3)(cell[0], cell[1], cell[2])
        xy = cell[3]
        xz = cell[4]
        yz = cell[5]
        xperiodic, yperiodic, zperiodic = atoms.get_pbc()

        lib = self._lib
        assert lib is not None
        lib.pair_set_atom(self.pair, natoms, ntypes, types, x_flat)
        xperiodic = xperiodic.astype(int)
        yperiodic = yperiodic.astype(int)
        zperiodic = zperiodic.astype(int)
        lib.pair_set_domain(
            self.pair, xperiodic, yperiodic, zperiodic, boxlo, boxhi, xy, xz, yz
        )
        lib.pair_run_settings(
            self.pair,
            self.rthr,
            self.cnthr,
            self.damp_name.encode('utf-8'),
            self.func_name.encode('utf-8'),
        )
        lib.pair_run_coeff(self.pair, atomic_numbers)
        lib.pair_run_compute(self.pair)

        result_E = lib.pair_get_energy(self.pair)
        result_F_ptr = lib.pair_get_force(self.pair)
        result_F_size = natoms * 3
        result_F = np.ctypeslib.as_array(
            result_F_ptr, shape=(result_F_size,)
        ).reshape((natoms, 3))
        result_F = np.array(result_F)
        result_F = result_F @ rotator
        result_S = lib.pair_get_stress(self.pair)
        result_S = np.array(result_S.contents)
        result_S = (
            self._tensor2stress(rotator.T @ self._stress2tensor(result_S) @ rotator)
            / atoms.get_volume()
        )

        self.results = {
            'free_energy': result_E,
            'energy': result_E,
            'forces': result_F,
            'stress': result_S,
        }

    def predict_one(self, atoms):
        if atoms is None:
            raise ValueError('No atoms to evaluate')
        atoms = atoms.copy()

        if atoms.get_cell().sum() == 0:
            print(
                'Warning: D3Calculator requires a cell.\n'
                'Warning: An orthogonal cell large enough is generated.'
            )
            positions = atoms.get_positions()
            min_pos = positions.min(axis=0)
            max_pos = positions.max(axis=0)
            max_cutoff = np.sqrt(max(self.rthr, self.cnthr)) * 0.52917726
            cell_lengths = max_pos - min_pos + max_cutoff + 1.0  # extra margin
            cell = np.eye(3) * cell_lengths
            atoms.set_cell(cell)
            atoms.set_pbc([True, True, True])  # for minus positions

        cell, rotator = self._convert_domain_ase2lammps(atoms.get_cell())

        Z_of_atoms = atoms.get_atomic_numbers()
        natoms = len(atoms)
        ntypes = len(set(Z_of_atoms))
        types = (ctypes.c_int * natoms)(*self._idx_to_types(Z_of_atoms))
        positions = atoms.get_positions() @ rotator.T
        x_flat = (ctypes.c_double * (natoms * 3))(*positions.flatten())
        atomic_numbers = (ctypes.c_int * ntypes)(*self._idx_to_numbers(Z_of_atoms))
        boxlo = (ctypes.c_double * 3)(0.0, 0.0, 0.0)
        boxhi = (ctypes.c_double * 3)(cell[0], cell[1], cell[2])
        xy = cell[3]
        xz = cell[4]
        yz = cell[5]
        xperiodic, yperiodic, zperiodic = atoms.get_pbc()

        lib = self._lib
        assert lib is not None
        lib.pair_set_atom(self.pair, natoms, ntypes, types, x_flat)
        xperiodic = xperiodic.astype(int)
        yperiodic = yperiodic.astype(int)
        zperiodic = zperiodic.astype(int)
        lib.pair_set_domain(
            self.pair, xperiodic, yperiodic, zperiodic, boxlo, boxhi, xy, xz, yz
        )
        lib.pair_run_settings(
            self.pair,
            self.rthr,
            self.cnthr,
            self.damp_name.encode('utf-8'),
            self.func_name.encode('utf-8'),
        )
        lib.pair_run_coeff(self.pair, atomic_numbers)
        lib.pair_run_compute(self.pair)

        result_E = lib.pair_get_energy(self.pair)
        result_F_ptr = lib.pair_get_force(self.pair)
        result_F_size = natoms * 3
        result_F = np.ctypeslib.as_array(
            result_F_ptr, shape=(result_F_size,)
        ).reshape((natoms, 3))
        result_F = np.array(result_F)
        result_F = result_F @ rotator
        result_S = lib.pair_get_stress(self.pair)
        result_S = np.array(result_S.contents)
        result_S = (
            self._tensor2stress(rotator.T @ self._stress2tensor(result_S) @ rotator)
            / atoms.get_volume()
        )

        prediction = {
            'free_energy': float(result_E),
            'energy': float(result_E),
            'forces': result_F.copy(),
            'stress': result_S.copy(),
        }
        return prediction

    def __del__(self):
        if self._lib is not None:
            self._lib.pair_fin(self.pair)
            self._lib = None
            self.pair = None
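
The D3 stress handling above hops between a six-component vector in the LAMMPS-style [xx, yy, zz, xy, xz, yz] order returned by pair_get_stress and ASE's negated Voigt order [xx, yy, zz, yz, xz, xy]. A small numpy sketch of that conversion, mirroring _stress2tensor / _tensor2stress (values are illustrative):

import numpy as np

s = np.array([1.0, 2.0, 3.0, 0.1, 0.2, 0.3])  # xx, yy, zz, xy, xz, yz
tensor = np.array([
    [s[0], s[3], s[4]],
    [s[3], s[1], s[5]],
    [s[4], s[5], s[2]],
])
voigt_ase = -np.array([
    tensor[0, 0], tensor[1, 1], tensor[2, 2],
    tensor[1, 2], tensor[0, 2], tensor[0, 1],
])
print(voigt_ase)  # [-1.  -2.  -3.  -0.3 -0.2 -0.1]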
mace-bench/3rdparty/SevenNet/sevenn/checkpoint.py (new file, mode 100644)
import
os
import
pathlib
import
uuid
import
warnings
from
copy
import
deepcopy
from
datetime
import
datetime
from
typing
import
Any
,
Dict
,
Optional
,
Union
import
pandas
as
pd
from
packaging.version
import
Version
from
torch
import
Tensor
from
torch
import
load
as
torch_load
import
sevenn
import
sevenn._const
as
consts
import
sevenn._keys
as
KEY
import
sevenn.scripts.backward_compatibility
as
compat
from
sevenn
import
model_build
from
sevenn.nn.scale
import
get_resolved_shift_scale
from
sevenn.nn.sequential
import
AtomGraphSequential
def
assert_atoms
(
atoms1
,
atoms2
,
rtol
=
1e-5
,
atol
=
1e-6
):
import
numpy
as
np
def
acl
(
a
,
b
,
rtol
=
rtol
,
atol
=
atol
):
return
np
.
allclose
(
a
,
b
,
rtol
=
rtol
,
atol
=
atol
)
assert
len
(
atoms1
)
==
len
(
atoms2
)
assert
acl
(
atoms1
.
get_cell
(),
atoms2
.
get_cell
())
assert
acl
(
atoms1
.
get_potential_energy
(),
atoms2
.
get_potential_energy
())
assert
acl
(
atoms1
.
get_forces
(),
atoms2
.
get_forces
(),
rtol
*
10
,
atol
*
10
)
assert
acl
(
atoms1
.
get_stress
(
voigt
=
False
),
atoms2
.
get_stress
(
voigt
=
False
),
rtol
*
10
,
atol
*
10
,
)
# assert acl(atoms1.get_potential_energies(), atoms2.get_potential_energies())
def
copy_state_dict
(
state_dict
)
->
dict
:
if
isinstance
(
state_dict
,
dict
):
return
{
key
:
copy_state_dict
(
value
)
for
key
,
value
in
state_dict
.
items
()}
elif
isinstance
(
state_dict
,
list
):
return
[
copy_state_dict
(
item
)
for
item
in
state_dict
]
# type: ignore
elif
isinstance
(
state_dict
,
Tensor
):
return
state_dict
.
clone
()
# type: ignore
else
:
# For non-tensor values (e.g., scalars, None), return as-is
return
state_dict
def
_config_cp_routine
(
config
):
cp_ver
=
Version
(
config
.
get
(
'version'
,
None
))
this_ver
=
Version
(
sevenn
.
__version__
)
if
cp_ver
>
this_ver
:
warnings
.
warn
(
f
'The checkpoint version (
{
cp_ver
}
) is newer than this source'
f
'(
{
this_ver
}
). This may cause unexpected behaviors'
)
defaults
=
{
**
consts
.
model_defaults
(
config
)}
config
=
compat
.
patch_old_config
(
config
)
# type: ignore
scaler
=
model_build
.
init_shift_scale
(
config
)
shift
,
scale
=
get_resolved_shift_scale
(
scaler
,
config
.
get
(
KEY
.
TYPE_MAP
),
config
.
get
(
KEY
.
MODAL_MAP
,
None
)
)
config
[
'shift'
]
=
shift
config
[
'scale'
]
=
scale
for
k
,
v
in
defaults
.
items
():
if
k
in
config
:
continue
if
os
.
getenv
(
'SEVENN_DEBUG'
,
False
):
warnings
.
warn
(
f
'
{
k
}
not in config, use default value
{
v
}
'
,
UserWarning
)
config
[
k
]
=
v
for
k
,
v
in
config
.
items
():
if
isinstance
(
v
,
Tensor
):
config
[
k
]
=
v
.
cpu
()
return
config
def _convert_e3nn_and_cueq(stct_src, stct_dst, src_config, from_cueq):
    """
    manually check keys and assert if something unexpected happens
    """
    n_layer = src_config['num_convolution_layer']
    linear_module_names = [
        'onehot_to_feature_x',
        'reduce_input_to_hidden',
        'reduce_hidden_to_energy',
    ]
    convolution_module_names = []
    fc_tensor_product_module_names = []
    for i in range(n_layer):
        linear_module_names.append(f'{i}_self_interaction_1')
        linear_module_names.append(f'{i}_self_interaction_2')
        if src_config.get(KEY.SELF_CONNECTION_TYPE) == 'linear':
            linear_module_names.append(f'{i}_self_connection_intro')
        elif src_config.get(KEY.SELF_CONNECTION_TYPE) == 'nequip':
            fc_tensor_product_module_names.append(f'{i}_self_connection_intro')
        convolution_module_names.append(f'{i}_convolution')

    # Rule: those keys can be safely ignored before state dict load,
    # except for linear.bias. This should be aborted in advance to
    # this function. Others are not parameters but constants.
    cue_only_linear_followers = ['linear.f.tp.f_fx.module.c']
    e3nn_only_linear_followers = ['linear.bias', 'linear.output_mask']
    ignores_in_linear = cue_only_linear_followers + e3nn_only_linear_followers

    cue_only_conv_followers = [
        'convolution.f.tp.f_fx.module.c',
        'convolution.f.tp.module.module.f.module.module._f.data',
    ]
    e3nn_only_conv_followers = [
        'convolution._compiled_main_left_right._w3j',
        'convolution.weight',
        'convolution.output_mask',
    ]
    ignores_in_conv = cue_only_conv_followers + e3nn_only_conv_followers

    cue_only_fc_followers = ['fc_tensor_product.f.tp.f_fx.module.c']
    e3nn_only_fc_followers = [
        'fc_tensor_product.output_mask',
    ]
    ignores_in_fc = cue_only_fc_followers + e3nn_only_fc_followers

    updated_keys = []
    for k, v in stct_src.items():
        module_name = k.split('.')[0]
        flag = False
        if module_name in linear_module_names:
            for ignore in ignores_in_linear:
                if '.'.join([module_name, ignore]) in k:
                    flag = True
                    break
            if not flag and k == '.'.join([module_name, 'linear.weight']):
                updated_keys.append(k)
                stct_dst[k] = v.clone().reshape(stct_dst[k].shape)
                flag = True
            assert flag, f'Unexpected key from linear: {k}'
        elif module_name in convolution_module_names:
            for ignore in ignores_in_conv:
                if '.'.join([module_name, ignore]) in k:
                    flag = True
                    break
            if not flag and (
                k.startswith(f'{module_name}.weight_nn')
                or k == '.'.join([module_name, 'denominator'])
            ):
                updated_keys.append(k)
                stct_dst[k] = v.clone().reshape(stct_dst[k].shape)
                flag = True
            assert flag, f'Unexpected key from convolution: {k}'
        elif module_name in fc_tensor_product_module_names:
            for ignore in ignores_in_fc:
                if '.'.join([module_name, ignore]) in k:
                    flag = True
                    break
            if not flag and k == '.'.join(
                [module_name, 'fc_tensor_product.weight']
            ):
                updated_keys.append(k)
                stct_dst[k] = v.clone().reshape(stct_dst[k].shape)
                flag = True
            assert flag, f'Unexpected key from fc tensor product: {k}'
        else:
            # assert k in stct_dst
            updated_keys.append(k)
            stct_dst[k] = v.clone().reshape(stct_dst[k].shape)
    return stct_dst
class SevenNetCheckpoint:
    """
    Tool box for checkpoints produced by SevenNet.
    """

    def __init__(self, checkpoint_path: Union[pathlib.Path, str]):
        self._checkpoint_path = os.path.abspath(checkpoint_path)
        self._config = None
        self._epoch = None
        self._model_state_dict = None
        self._optimizer_state_dict = None
        self._scheduler_state_dict = None
        self._hash = None
        self._time = None
        self._loaded = False

    def __repr__(self) -> str:
        cfg = self.config  # just alias
        if len(cfg) == 0:
            return ''
        dct = {
            'Sevennet version': cfg.get('version', 'Not found'),
            'When': self.time,
            'Hash': self.hash,
            'Cutoff': cfg.get('cutoff'),
            'Channel': cfg.get('channel'),
            'Lmax': cfg.get('lmax'),
            'Group (parity)': 'O3' if cfg.get('is_parity') else 'SO3',
            'Interaction layers': cfg.get('num_convolution_layer'),
            'Self connection type': cfg.get('self_connection_type', 'nequip'),
            'Last epoch': self.epoch,
            'Elements': len(cfg.get('chemical_species', [])),
        }
        if cfg.get('use_modality', False):
            dct['Modality'] = ', '.join(list(cfg.get('_modal_map', {}).keys()))
        df = pd.DataFrame.from_dict([dct]).T  # type: ignore
        df.columns = ['']
        return df.to_string()

    @property
    def checkpoint_path(self) -> str:
        return str(self._checkpoint_path)

    @property
    def config(self) -> Dict[str, Any]:
        if not self._loaded:
            self._load()
        assert isinstance(self._config, dict)
        return deepcopy(self._config)

    @property
    def model_state_dict(self) -> Dict[str, Any]:
        if not self._loaded:
            self._load()
        assert isinstance(self._model_state_dict, dict)
        return copy_state_dict(self._model_state_dict)

    @property
    def optimizer_state_dict(self) -> Dict[str, Any]:
        if not self._loaded:
            self._load()
        assert isinstance(self._optimizer_state_dict, dict)
        return copy_state_dict(self._optimizer_state_dict)

    @property
    def scheduler_state_dict(self) -> Dict[str, Any]:
        if not self._loaded:
            self._load()
        assert isinstance(self._scheduler_state_dict, dict)
        return copy_state_dict(self._scheduler_state_dict)

    @property
    def epoch(self) -> Optional[int]:
        if not self._loaded:
            self._load()
        return self._epoch

    @property
    def time(self) -> str:
        if not self._loaded:
            self._load()
        assert isinstance(self._time, str)
        return self._time

    @property
    def hash(self) -> str:
        if not self._loaded:
            self._load()
        assert isinstance(self._hash, str)
        return self._hash

    def _load(self) -> None:
        assert not self._loaded
        cp_path = self.checkpoint_path  # just alias
        cp = torch_load(cp_path, weights_only=False, map_location='cpu')
        self._config_original = cp.get('config', {})
        self._model_state_dict = cp.get('model_state_dict', {})
        self._optimizer_state_dict = cp.get('optimizer_state_dict', {})
        self._scheduler_state_dict = cp.get('scheduler_state_dict', {})
        self._epoch = cp.get('epoch', None)
        self._time = cp.get('time', 'Not found')
        self._hash = cp.get('hash', 'Not found')

        if len(self._config_original) == 0:
            warnings.warn(f'config is not found from {cp_path}')
            self._config = {}
        else:
            self._config = _config_cp_routine(self._config_original)

        if len(self._model_state_dict) == 0:
            warnings.warn(f'model_state_dict is not found from {cp_path}')
        self._loaded = True

    def build_model(self, backend: Optional[str] = None) -> AtomGraphSequential:
        from .model_build import build_E3_equivariant_model

        use_cue = not backend or backend.lower() in ['cue', 'cueq']
        try:
            cp_using_cue = self.config[KEY.CUEQUIVARIANCE_CONFIG]['use']
        except KeyError:
            cp_using_cue = False

        if (not backend) or (use_cue == cp_using_cue):
            # backend not given, or checkpoint backend is same as requested
            model = build_E3_equivariant_model(self.config)
            state_dict = compat.patch_state_dict_if_old(
                self.model_state_dict, self.config, model
            )
        else:
            cfg_new = self.config
            cfg_new[KEY.CUEQUIVARIANCE_CONFIG] = {'use': use_cue}
            model = build_E3_equivariant_model(cfg_new)
            stct_src = compat.patch_state_dict_if_old(
                self.model_state_dict, self.config, model
            )
            state_dict = _convert_e3nn_and_cueq(
                stct_src, model.state_dict(), self.config, from_cueq=cp_using_cue
            )

        missing, not_used = model.load_state_dict(state_dict, strict=False)
        if len(not_used) > 0:
            warnings.warn(f'Some keys are not used: {not_used}', UserWarning)
        assert len(missing) == 0, f'Missing keys: {missing}'
        return model

    def yaml_dict(self, mode: str) -> dict:
        """
        Return dict for input.yaml from checkpoint config
        Dataset paths and statistic values are removed intentionally
        """
        if mode not in ['reproduce', 'continue', 'continue_modal']:
            raise ValueError(f'Unknown mode: {mode}')
        ignore = [
            'when',
            KEY.DDP_BACKEND,
            KEY.LOCAL_RANK,
            KEY.IS_DDP,
            KEY.DEVICE,
            KEY.MODEL_TYPE,
            KEY.SHIFT,
            KEY.SCALE,
            KEY.CONV_DENOMINATOR,
            KEY.SAVE_DATASET,
            KEY.SAVE_BY_LABEL,
            KEY.SAVE_BY_TRAIN_VALID,
            KEY.CONTINUE,
            KEY.LOAD_DATASET,  # old
        ]
        cfg = self.config
        len_atoms = len(cfg[KEY.TYPE_MAP])
        world_size = cfg.pop(KEY.WORLD_SIZE, 1)
        cfg[KEY.BATCH_SIZE] = cfg[KEY.BATCH_SIZE] * world_size
        cfg[KEY.LOAD_TRAINSET] = '**path_to_training_set**'

        major, minor, _ = cfg.pop('version', '0.0.0').split('.')[:3]
        if int(major) == 0 and int(minor) <= 9:
            warnings.warn('checkpoint version too old, yaml may be wrong')

        ret = {'model': {}, 'train': {}, 'data': {}}
        for k, v in cfg.items():
            if k.startswith('_') or k in ignore or k.endswith('set_path'):
                continue
            if k in consts.DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG:
                ret['model'][k] = v
            elif k in consts.DEFAULT_TRAINING_CONFIG:
                ret['train'][k] = v
            elif k in consts.DEFAULT_DATA_CONFIG:
                ret['data'][k] = v

        ret['model'][KEY.CHEMICAL_SPECIES] = (
            'univ' if len_atoms == consts.NUM_UNIV_ELEMENT else 'auto'
        )
        ret['data'][KEY.LOAD_TRAINSET] = '**path_to_trainset**'
        ret['data'][KEY.LOAD_VALIDSET] = '**path_to_validset**'
        # TODO
        ret['data'][KEY.SHIFT] = '**failed to infer shift, should be set**'
        ret['data'][KEY.SCALE] = '**failed to infer scale, should be set**'

        if mode.startswith('continue'):
            ret['train'].update(
                {KEY.CONTINUE: {KEY.CHECKPOINT: self.checkpoint_path}}
            )

        modal_names = None
        if mode == 'continue_modal' and not cfg.get(KEY.USE_MODALITY, False):
            ret['train'][KEY.USE_MODALITY] = True
            # suggest defaults
            ret['model'][KEY.USE_MODAL_NODE_EMBEDDING] = False
            ret['model'][KEY.USE_MODAL_SELF_INTER_INTRO] = True
            ret['model'][KEY.USE_MODAL_SELF_INTER_OUTRO] = True
            ret['model'][KEY.USE_MODAL_OUTPUT_BLOCK] = True
            ret['data'][KEY.USE_MODAL_WISE_SHIFT] = True
            ret['data'][KEY.USE_MODAL_WISE_SCALE] = False
            modal_names = ['my_modal1', 'my_modal2']
        elif cfg.get(KEY.USE_MODALITY, False):
            modal_names = list(cfg[KEY.MODAL_MAP].keys())

        if modal_names:
            ret['data'][KEY.LOAD_TRAINSET] = [
                {'data_modality': mm, 'file_list': [{'file': f'**path_to_{mm}**'}]}
                for mm in modal_names
            ]
        return ret

    def append_modal(
        self,
        dst_config,
        original_modal_name: str = 'origin',
        working_dir: str = os.getcwd(),
    ):
        """Append new modalities from dst_config's datasets to this checkpoint."""
        import sevenn.train.modal_dataset as modal_dataset
        from sevenn.model_build import init_shift_scale
        from sevenn.scripts.convert_model_modality import _append_modal_weight

        src_config = self.config
        src_has_no_modal = not src_config.get(KEY.USE_MODALITY, False)

        # inherit element things first
        chem_keys = [
            KEY.TYPE_MAP,
            KEY.NUM_SPECIES,
            KEY.CHEMICAL_SPECIES,
            KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER,
        ]
        dst_config.update({k: src_config[k] for k in chem_keys})

        if dst_config[KEY.USE_MODAL_WISE_SHIFT] and (
            KEY.SHIFT not in dst_config
            or not isinstance(dst_config[KEY.SHIFT], str)
        ):
            raise ValueError('To use modal wise shift, keyword shift is required')
        if dst_config[KEY.USE_MODAL_WISE_SCALE] and (
            KEY.SCALE not in dst_config
            or not isinstance(dst_config[KEY.SCALE], str)
        ):
            raise ValueError('To use modal wise scale, keyword scale is required')

        if src_has_no_modal and not dst_config[KEY.USE_MODAL_WISE_SHIFT]:
            dst_config[KEY.SHIFT] = src_config[KEY.SHIFT]
        if src_has_no_modal and not dst_config[KEY.USE_MODAL_WISE_SCALE]:
            dst_config[KEY.SCALE] = src_config[KEY.SCALE]

        # get statistics of given datasets of yaml
        # dst_config updated
        _ = modal_dataset.from_config(dst_config, working_dir=working_dir)
        dst_modal_map = dst_config[KEY.MODAL_MAP]
        found_modal_names = list(dst_modal_map.keys())
        if len(found_modal_names) == 0:
            raise ValueError('No modality is found from config')

        # Check difference btw given modals and new modal map
        orig_modal_map = src_config.get(KEY.MODAL_MAP, {original_modal_name: 0})
        assert isinstance(orig_modal_map, dict)
        new_modal_map = orig_modal_map.copy()
        for modal_name in found_modal_names:
            if modal_name in orig_modal_map:
                # duplicate, skipping
                continue
            new_modal_map[modal_name] = len(new_modal_map)  # assign new
        print(f'New modals: {list(new_modal_map.keys())}')

        if src_has_no_modal:
            append_num = len(new_modal_map)
        else:
            append_num = len(new_modal_map) - len(orig_modal_map)
        if append_num == 0:
            raise ValueError('Nothing to append from checkpoint')

        dst_config[KEY.NUM_MODALITIES] = len(new_modal_map)
        dst_config[KEY.MODAL_MAP] = new_modal_map

        # update dst_config's shift scales based on src_config
        for ss_key, use_mw in (
            (KEY.SHIFT, dst_config[KEY.USE_MODAL_WISE_SHIFT]),
            (KEY.SCALE, dst_config[KEY.USE_MODAL_WISE_SCALE]),
        ):
            if not use_mw:
                # not using mw ss, just assign
                assert not isinstance(dst_config[ss_key], dict)
                dst_config[ss_key] = src_config[ss_key]
            elif src_has_no_modal:
                assert isinstance(dst_config[ss_key], dict)
                # mw ss, update by dict but use original_modal_name
                dst_config[ss_key].update(
                    {original_modal_name: src_config[ss_key]}
                )
            else:
                assert isinstance(dst_config[ss_key], dict)
                # mw ss, update by dict
                dst_config[ss_key].update(src_config[ss_key])
        scaler = init_shift_scale(dst_config)

        # finally, prepare updated continuable state dict using above
        orig_model = self.build_model()
        orig_state_dict = orig_model.state_dict()
        new_state_dict = copy_state_dict(orig_state_dict)
        for stct_key in orig_state_dict:
            sp = stct_key.split('.')
            k, follower = sp[0], '.'.join(sp[1:])
            if k == 'rescale_atomic_energy' and follower == 'shift':
                new_state_dict[stct_key] = scaler.shift.clone()
            elif k == 'rescale_atomic_energy' and follower == 'scale':
                new_state_dict[stct_key] = scaler.scale.clone()
            elif follower == 'linear.weight' and (
                # append linear layer
                (
                    dst_config[KEY.USE_MODAL_NODE_EMBEDDING]
                    and k.endswith('onehot_to_feature_x')
                )
                or (
                    dst_config[KEY.USE_MODAL_SELF_INTER_INTRO]
                    and k.endswith('self_interaction_1')
                )
                or (
                    dst_config[KEY.USE_MODAL_SELF_INTER_OUTRO]
                    and k.endswith('self_interaction_2')
                )
                or (
                    dst_config[KEY.USE_MODAL_OUTPUT_BLOCK]
                    and k == 'reduce_input_to_hidden'
                )
            ):
                orig_linear = getattr(orig_model._modules[k], 'linear')
                # assert normalization element
                new_state_dict[stct_key] = _append_modal_weight(
                    orig_state_dict,
                    k,
                    orig_linear.irreps_in,
                    orig_linear.irreps_out,
                    append_num,
                )
        dst_config['version'] = sevenn.__version__
        return new_state_dict

    def get_checkpoint_dict(self) -> dict:
        """
        Return duplicate of this checkpoint with new hash and time.
        Convenient for creating variant of the checkpoint
        """
        return {
            'config': self.config,
            'epoch': self.epoch,
            'model_state_dict': self.model_state_dict,
            'optimizer_state_dict': self.optimizer_state_dict,
            'scheduler_state_dict': self.scheduler_state_dict,
            'time': datetime.now().strftime('%Y-%m-%d %H:%M'),
            'hash': uuid.uuid4().hex,
        }
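# A minimal usage sketch of SevenNetCheckpoint, assuming this module is
# importable as sevenn.checkpoint; the checkpoint path is hypothetical and
# stands for any SevenNet training checkpoint that stores a full config.
#
#     from sevenn.checkpoint import SevenNetCheckpoint
#
#     cp = SevenNetCheckpoint('checkpoint_best.pth')  # lazy, nothing read yet
#     print(cp)                          # triggers _load(), prints summary table
#     model = cp.build_model()           # rebuild with the stored backend
#     yaml_like = cp.yaml_dict('continue')  # dict for a continue-run input.yaml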
mace-bench/3rdparty/SevenNet/sevenn/error_recorder.py
0 → 100644
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist

import sevenn._keys as KEY
from sevenn.train.loss import LossDefinition

from .atom_graph_data import AtomGraphData
from .train.optim import loss_dict

_ERROR_TYPES = {
    'TotalEnergy': {
        'name': 'Energy',
        'ref_key': KEY.ENERGY,
        'pred_key': KEY.PRED_TOTAL_ENERGY,
        'unit': 'eV',
        'vdim': 1,
    },
    'Energy': {  # by default per-atom for energy
        'name': 'Energy',
        'ref_key': KEY.ENERGY,
        'pred_key': KEY.PRED_TOTAL_ENERGY,
        'unit': 'eV/atom',
        'per_atom': True,
        'vdim': 1,
    },
    'Force': {
        'name': 'Force',
        'ref_key': KEY.FORCE,
        'pred_key': KEY.PRED_FORCE,
        'unit': 'eV/Å',
        'vdim': 3,
    },
    'Stress': {
        'name': 'Stress',
        'ref_key': KEY.STRESS,
        'pred_key': KEY.PRED_STRESS,
        'unit': 'kbar',
        'coeff': 1602.1766208,
        'vdim': 6,
    },
    'Stress_GPa': {
        'name': 'Stress',
        'ref_key': KEY.STRESS,
        'pred_key': KEY.PRED_STRESS,
        'unit': 'GPa',
        'coeff': 160.21766208,
        'vdim': 6,
    },
    'TotalLoss': {
        'name': 'TotalLoss',
        'unit': None,
    },
}


def get_err_type(name: str) -> Dict[str, Any]:
    return deepcopy(_ERROR_TYPES[name])


def _get_loss_function_from_name(loss_functions, name):
    for loss_def, w in loss_functions:
        if loss_def.name.lower() == name.lower():
            return loss_def, w
    return None, None


class AverageNumber:
    def __init__(self):
        self._sum = 0.0
        self._count = 0

    def update(self, values: torch.Tensor):
        self._sum += values.sum().item()
        self._count += values.numel()

    def _ddp_reduce(self, device):
        _sum = torch.tensor(self._sum, device=device)
        _count = torch.tensor(self._count, device=device)
        dist.all_reduce(_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(_count, op=dist.ReduceOp.SUM)
        self._sum = _sum.item()
        self._count = _count.item()

    def get(self):
        if self._count == 0:
            return torch.nan
        return self._sum / self._count


class ErrorMetric:
    """
    Base class for error metrics. We always average errors over the number
    of structures, and it is designed to collect errors in the middle of an
    iteration (via AverageNumber).
    """

    def __init__(
        self,
        name: str,
        ref_key: str,
        pred_key: str,
        coeff: float = 1.0,
        unit: Optional[str] = None,
        per_atom: bool = False,
        ignore_unlabeled: bool = True,
        **kwargs,
    ):
        self.name = name
        self.unit = unit
        self.coeff = coeff
        self.ref_key = ref_key
        self.pred_key = pred_key
        self.per_atom = per_atom
        self.ignore_unlabeled = ignore_unlabeled
        self.value = AverageNumber()

    def update(self, output: AtomGraphData):
        raise NotImplementedError

    def _retrieve(self, output: AtomGraphData):
        y_ref = output[self.ref_key] * self.coeff
        y_pred = output[self.pred_key] * self.coeff
        if self.per_atom:
            assert y_ref.dim() == 1 and y_pred.dim() == 1
            natoms = output[KEY.NUM_ATOMS]
            y_ref = y_ref / natoms
            y_pred = y_pred / natoms
        if self.ignore_unlabeled:
            unlabelled_idx = torch.isnan(y_ref)
            y_ref = y_ref[~unlabelled_idx]
            y_pred = y_pred[~unlabelled_idx]
        return y_ref, y_pred

    def ddp_reduce(self, device):
        self.value._ddp_reduce(device)

    def reset(self):
        self.value = AverageNumber()

    def get(self):
        return self.value.get()

    def key_str(self, with_unit=True):
        if self.unit is None or not with_unit:
            return self.name
        else:
            return f'{self.name} ({self.unit})'

    def __str__(self):
        return f'{self.key_str()}: {self.value.get():.6f}'


class RMSError(ErrorMetric):
    """
    Vector squared error
    """

    def __init__(self, vdim: int = 1, **kwargs):
        super().__init__(**kwargs)
        self.vdim = vdim
        self._se = torch.nn.MSELoss(reduction='none')

    def _square_error(self, y_ref, y_pred, vdim: int):
        return self._se(y_ref.view(-1, vdim), y_pred.view(-1, vdim)).sum(dim=1)

    def update(self, output: AtomGraphData):
        y_ref, y_pred = self._retrieve(output)
        se = self._square_error(y_ref, y_pred, self.vdim)
        self.value.update(se)

    def get(self):
        return self.value.get() ** 0.5


class ComponentRMSError(ErrorMetric):
    """
    Ignore vector dim and just average over components.
    Results in a smaller error.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._se = torch.nn.MSELoss(reduction='none')

    def _square_error(self, y_ref, y_pred):
        return self._se(y_ref, y_pred)

    def update(self, output: AtomGraphData):
        y_ref, y_pred = self._retrieve(output)
        y_ref = y_ref.view(-1)
        y_pred = y_pred.view(-1)
        se = self._square_error(y_ref, y_pred)
        self.value.update(se)

    def get(self):
        return self.value.get() ** 0.5


class MAError(ErrorMetric):
    """
    Mean absolute error, averaged over all components.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _absolute_error(self, y_ref, y_pred):
        return torch.abs(y_ref - y_pred)

    def update(self, output: AtomGraphData):
        y_ref, y_pred = self._retrieve(output)
        y_ref = y_ref.reshape((-1,))
        y_pred = y_pred.reshape((-1,))
        ae = self._absolute_error(y_ref, y_pred)
        self.value.update(ae)


class CustomError(ErrorMetric):
    """
    Custom error metric

    Args:
        func: a function that takes y_ref and y_pred
              and returns a list of errors
    """

    def __init__(self, func: Callable, **kwargs):
        super().__init__(**kwargs)
        self.func = func

    def update(self, output: AtomGraphData):
        y_ref, y_pred = self._retrieve(output)
        se = self.func(y_ref, y_pred) if len(y_ref) > 0 else torch.tensor([])
        self.value.update(se)


class LossError(ErrorMetric):
    """
    Error metric that records loss
    """

    def __init__(
        self,
        name: str,
        loss_def: LossDefinition,
        **kwargs,
    ):
        super().__init__(
            name,
            ignore_unlabeled=loss_def.ignore_unlabeled,
            **kwargs,
        )
        self.loss_def = loss_def

    def update(self, output: AtomGraphData):
        loss = self.loss_def.get_loss(output)  # type: ignore
        self.value.update(loss)  # type: ignore


class CombinedError(ErrorMetric):
    """
    Combine multiple error metrics with weights;
    corresponds to a weighted sum of errors (normally used in loss)
    """

    def __init__(self, metrics: List[Tuple[ErrorMetric, float]], **kwargs):
        super().__init__(**kwargs)
        self.metrics = metrics
        assert kwargs['unit'] is None

    def update(self, output: AtomGraphData):
        for metric, _ in self.metrics:
            metric.update(output)

    def reset(self):
        for metric, _ in self.metrics:
            metric.reset()

    def ddp_reduce(self, device):  # override
        for metric, _ in self.metrics:
            metric.value._ddp_reduce(device)

    def get(self):
        val = 0.0
        for metric, weight in self.metrics:
            val += metric.get() * weight
        return val


class ErrorRecorder:
    """
    Records errors of a model.
    """

    METRIC_DICT = {
        'RMSE': RMSError,
        'ComponentRMSE': ComponentRMSError,
        'MAE': MAError,
        'Loss': LossError,
    }

    def __init__(self, metrics: List[ErrorMetric]):
        self.history = []
        self.metrics = metrics

    def _update(self, output: AtomGraphData):
        for metric in self.metrics:
            metric.update(output)

    def update(self, output: AtomGraphData, no_grad=True):
        if no_grad:
            with torch.no_grad():
                self._update(output)
        else:
            self._update(output)

    def get_metric_dict(self, with_unit=True):
        return {
            metric.key_str(with_unit): metric.get() for metric in self.metrics
        }

    def get_current(self):
        dct = {}
        for metric in self.metrics:
            dct[metric.name] = {
                'value': metric.get(),
                'unit': metric.unit,
                'ref_key': metric.ref_key,
                'pred_key': metric.pred_key,
            }
        return dct

    def get_dct(self, prefix=''):
        dct = {}
        if prefix.endswith('_') is False and prefix != '':
            prefix = prefix + '_'
        for metric in self.metrics:
            dct[f'{prefix}{metric.name}'] = f'{metric.get():6f}'
        return dct

    def get_key_str(self, name: str):
        for metric in self.metrics:
            if name == metric.name:
                return metric.key_str()
        return None

    def epoch_forward(self):
        self.history.append(self.get_current())
        pretty = self.get_metric_dict(with_unit=True)
        for metric in self.metrics:
            metric.reset()
        return pretty  # for print

    @staticmethod
    def init_total_loss_metric(
        config,
        criteria: Optional[Callable] = None,
        loss_functions: Optional[List[Tuple[LossDefinition, float]]] = None,
    ):
        if criteria is None and loss_functions is None:
            raise ValueError('both criteria and loss functions not given')
        is_stress = config[KEY.IS_TRAIN_STRESS]
        metrics = []
        if criteria is not None:
            energy_metric = CustomError(criteria, **get_err_type('Energy'))
            metrics.append((energy_metric, 1))
            force_metric = CustomError(criteria, **get_err_type('Force'))
            metrics.append((force_metric, config[KEY.FORCE_WEIGHT]))
            if is_stress:
                stress_metric = CustomError(criteria, **get_err_type('Stress'))
                metrics.append((stress_metric, config[KEY.STRESS_WEIGHT]))
        else:
            # TODO: this is hard-coded
            for efs in ['Energy', 'Force', 'Stress']:
                if efs == 'Stress' and not is_stress:
                    continue
                lf, w = _get_loss_function_from_name(loss_functions, efs)
                if lf is None:
                    raise ValueError(f'{efs} not found from loss_functions')
                metric = LossError(loss_def=lf, **get_err_type(efs))
                metrics.append((metric, w))
        total_loss_metric = CombinedError(
            metrics, name='TotalLoss', unit=None, ref_key=None, pred_key=None
        )
        return total_loss_metric

    @staticmethod
    def from_config(config: dict, loss_functions=None):
        loss_cls = loss_dict[config.get(KEY.LOSS, 'mse').lower()]
        loss_param = config.get(KEY.LOSS_PARAM, {})
        criteria = loss_cls(**loss_param) if loss_functions is None else None

        err_config = config.get(KEY.ERROR_RECORD, False)
        if not err_config:
            raise ValueError(
                'No error_record config found. Consider util.get_error_recorder'
            )

        err_config_n = []
        if not config.get(KEY.IS_TRAIN_STRESS, True):
            for err_type, metric_name in err_config:
                if 'Stress' in err_type:
                    continue
                err_config_n.append((err_type, metric_name))
            err_config = err_config_n

        err_metrics = []
        for err_type, metric_name in err_config:
            metric_kwargs = get_err_type(err_type)
            if err_type == 'TotalLoss':  # special case
                err_metrics.append(
                    ErrorRecorder.init_total_loss_metric(
                        config, criteria, loss_functions
                    )
                )
                continue
            metric_cls = ErrorRecorder.METRIC_DICT[metric_name]
            assert isinstance(metric_kwargs['name'], str)
            if metric_name == 'Loss':
                if loss_functions is not None:
                    metric_cls = LossError
                    metric_kwargs['loss_def'], _ = _get_loss_function_from_name(
                        loss_functions, metric_kwargs['name']
                    )
                else:
                    metric_cls = CustomError
                    metric_kwargs['func'] = criteria
                metric_kwargs.pop('unit', None)
            metric_kwargs['name'] += f'_{metric_name}'
            err_metrics.append(metric_cls(**metric_kwargs))
        return ErrorRecorder(err_metrics)
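# A minimal usage sketch of RMSError, assuming this module is importable as
# sevenn.error_recorder. The metric only looks up ref_key/pred_key via
# __getitem__, so a plain dict can stand in for AtomGraphData here. With
# vdim=3 the per-atom squared force error is summed over components before
# averaging, and get() takes the square root.
#
#     import torch
#     import sevenn._keys as KEY
#     from sevenn.error_recorder import RMSError, get_err_type
#
#     metric = RMSError(**get_err_type('Force'))  # vdim=3, unit eV/Å
#     output = {
#         KEY.FORCE: torch.zeros(4, 3),            # reference, 4 atoms
#         KEY.PRED_FORCE: torch.full((4, 3), 0.1),
#     }
#     metric.update(output)
#     print(metric)  # Force (eV/Å): sqrt(3 * 0.1**2) ≈ 0.173205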
mace-bench/3rdparty/SevenNet/sevenn/logger.py
0 → 100644
import os
import time
import traceback
from datetime import datetime
from typing import Any, Dict, List, Optional

from ase.data import atomic_numbers

import sevenn._keys as KEY
from sevenn import __version__

CHEM_SYMBOLS = {v: k for k, v in atomic_numbers.items()}


class Singleton(type):
    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]


class Logger(metaclass=Singleton):
    SCREEN_WIDTH = 120  # half size of my screen / changed due to stress output

    def __init__(
        self, filename: Optional[str] = None, screen: bool = False, rank: int = 0
    ):
        self.rank = rank
        self._filename = filename
        if rank == 0:
            # if filename is not None:
            #     self.logfile = open(filename, 'a', buffering=1)
            self.logfile = None
            self.files = {}
            self.screen = screen
        else:
            self.logfile = None
            self.screen = False
        self.timer_dct = {}
        self.active = True

    def __enter__(self):
        if self.rank != 0:
            return self
        if self.logfile is None and self._filename is not None:
            try:
                self.logfile = open(
                    self._filename, 'a', buffering=1, encoding='utf-8'
                )
            except IOError as e:
                print(f'Failed to re-open log file {self._filename}: {e}')
                self.logfile = None
        self.files = {}
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.rank != 0:
            return self
        try:
            if self.logfile is not None:
                self.logfile.close()
                self.logfile = None
            for f in self.files.values():
                f.close()
        except IOError as e:
            print(f'Failed to close log files: {e}')
        finally:
            self.logfile = None
            self.files = {}

    def switch_file(self, new_filename: str):
        if self.rank != 0:
            return self
        if self.logfile is not None:
            raise ValueError('Current logfile is not yet closed')
        self._filename = new_filename
        return self

    def write(self, content: str):
        if self.rank != 0:
            return
        # no newline!
        if self.logfile is not None and self.active:
            self.logfile.write(content)
        if self.screen and self.active:
            print(content, end='')

    def writeline(self, content: str):
        content = content + '\n'
        self.write(content)

    def init_csv(self, filename: str, header: list):
        """
        Deprecated
        """
        if self.rank == 0:
            self.files[filename] = open(
                filename, 'w', buffering=1, encoding='utf-8'
            )
            self.files[filename].write(','.join(header) + '\n')
        else:
            pass

    def append_csv(self, filename: str, content: list, decimal: int = 6):
        """
        Deprecated
        """
        if self.rank == 0:
            if filename not in self.files:
                self.files[filename] = open(filename, 'a', buffering=1)
            str_content = []
            for c in content:
                if isinstance(c, float):
                    str_content.append(f'{c:.{decimal}f}')
                else:
                    str_content.append(str(c))
            self.files[filename].write(','.join(str_content) + '\n')
        else:
            pass

    def natoms_write(self, natoms: Dict[str, Dict]):
        content = ''
        total_natom = {}
        for label, natom in natoms.items():
            content += self.format_k_v(label, natom)
            for specie, num in natom.items():
                try:
                    total_natom[specie] += num
                except KeyError:
                    total_natom[specie] = num
        content += self.format_k_v('Total, label wise', total_natom)
        content += self.format_k_v('Total', sum(total_natom.values()))
        self.write(content)

    def statistic_write(self, statistic: Dict[str, Dict]):
        content = ''
        for label, dct in statistic.items():
            if label.startswith('_'):
                continue
            if not isinstance(dct, dict):
                continue
            dct_new = {}
            for k, v in dct.items():
                if k.startswith('_'):
                    continue
                if isinstance(v, int):
                    dct_new[k] = v
                else:
                    dct_new[k] = f'{v:.3f}'
            content += self.format_k_v(label, dct_new)
        self.write(content)

    # TODO: refactoring!!! this is not loss, it is rmse
    def epoch_write_specie_wise_loss(self, train_loss, valid_loss):
        lb_pad = 21
        fs = 6
        pad = 21 - fs
        ln = '-' * fs
        total_atom_type = train_loss.keys()
        content = ''
        for at in total_atom_type:
            t_F = train_loss[at]
            v_F = valid_loss[at]
            at_sym = CHEM_SYMBOLS[at]
            content += '{label:{lb_pad}}{t_E:<{pad}.{fs}s}{v_E:<{pad}.{fs}s}'.format(
                label=at_sym, t_E=ln, v_E=ln, lb_pad=lb_pad, pad=pad, fs=fs
            ) + '{t_F:<{pad}.{fs}f}{v_F:<{pad}.{fs}f}'.format(
                t_F=t_F, v_F=v_F, pad=pad, fs=fs
            )
            content += '{t_S:<{pad}.{fs}s}{v_S:<{pad}.{fs}s}'.format(
                t_S=ln, v_S=ln, pad=pad, fs=fs
            )
            content += '\n'
        self.write(content)

    def write_full_table(
        self,
        dict_list: List[Dict],
        row_labels: List[str],
        decimal_places: int = 6,
        pad: int = 2,
    ):
        """
        Assume dict_list is a list of dicts with the same keys.
        """
        assert len(dict_list) == len(row_labels)
        label_len = max(map(len, row_labels))

        # Extract the column names and create a 2D array of values
        col_names = list(dict_list[0].keys())
        values = [list(d.values()) for d in dict_list]

        # Format the numbers with the given decimal places
        formatted_values = [
            [f'{value:.{decimal_places}f}' for value in row] for row in values
        ]

        # Calculate padding lengths for each column (with extra padding)
        max_col_lengths = [
            max(len(str(value)) for value in col) + pad
            for col in zip(col_names, *formatted_values)
        ]

        # Create header row and separator
        header = ' ' * (label_len + pad) + ' '.join(
            col_name.ljust(pad)
            for col_name, pad in zip(col_names, max_col_lengths)
        )
        separator = '-'.join('-' * pad for pad in max_col_lengths) + '-' * (
            label_len + pad
        )

        # Print header and separator
        self.writeline(header)
        self.writeline(separator)

        # Print the data rows with row labels
        for row_label, row in zip(row_labels, formatted_values):
            data_row = ' '.join(
                value.rjust(pad) for value, pad in zip(row, max_col_lengths)
            )
            self.writeline(f'{row_label.ljust(label_len)}{data_row}')

    def format_k_v(self, key: Any, val: Any, write: bool = False):
        """
        key and val should be str convertible
        """
        MAX_KEY_SIZE = 20
        SEPARATOR = ', '
        EMPTY_PADDING = ' ' * (MAX_KEY_SIZE + 3)
        NEW_LINE_LEN = Logger.SCREEN_WIDTH - 5

        key = str(key)
        val = str(val)
        content = f'{key:<{MAX_KEY_SIZE}}: {val}'
        if len(content) > NEW_LINE_LEN:
            content = f'{key:<{MAX_KEY_SIZE}}: '
            # separate val by separator
            val_list = val.split(SEPARATOR)
            current_len = len(content)
            for val_compo in val_list:
                current_len += len(val_compo)
                if current_len > NEW_LINE_LEN:
                    newline_content = f'{EMPTY_PADDING}{val_compo}{SEPARATOR}'
                    content += f'\\\n{newline_content}'
                    current_len = len(newline_content)
                else:
                    content += f'{val_compo}{SEPARATOR}'
            if content.endswith(f'{SEPARATOR}'):
                content = content[: -len(SEPARATOR)]
        content += '\n'

        if write is False:
            return content
        else:
            self.write(content)
            return ''

    def greeting(self):
        LOGO_ASCII_FILE = f'{os.path.dirname(__file__)}/logo_ascii'
        with open(LOGO_ASCII_FILE, 'r') as logo_f:
            logo_ascii = logo_f.read()
        content = 'SevenNet: Scalable EquiVariance-Enabled Neural Network\n'
        content += f'version {__version__}, {time.ctime()}\n'
        self.write(content)
        self.write(logo_ascii)

    def bar(self):
        content = '-' * Logger.SCREEN_WIDTH + '\n'
        self.write(content)

    def print_config(
        self,
        model_config: Dict[str, Any],
        data_config: Dict[str, Any],
        train_config: Dict[str, Any],
    ):
        """
        print some important information from config
        """
        content = (
            'successfully read yaml config!\n\n' + 'from model configuration\n'
        )
        for k, v in model_config.items():
            content += self.format_k_v(k, str(v))
        content += '\nfrom train configuration\n'
        for k, v in train_config.items():
            content += self.format_k_v(k, str(v))
        content += '\nfrom data configuration\n'
        for k, v in data_config.items():
            content += self.format_k_v(k, str(v))
        self.write(content)

    # TODO: This is not good, make own exception
    def error(self, e: Exception):
        content = ''
        if type(e) is ValueError:
            content += 'Error occurred!\n'
            content += str(e) + '\n'
        else:
            content += 'Unknown error occurred!\n'
            content += traceback.format_exc()
        self.write(content)

    def timer_start(self, name: str):
        self.timer_dct[name] = datetime.now()

    def timer_end(self, name: str, message: str, remove: bool = True):
        """
        print f"{message}: {elapsed}"
        """
        elapsed = str(datetime.now() - self.timer_dct[name])
        # elapsed = elapsed.strftime('%H-%M-%S')
        if remove:
            del self.timer_dct[name]
        self.write(f'{message}: {elapsed[:-4]}\n')

    # TODO: print it without config
    # TODO: refactoring, readout part name :(
    def print_model_info(self, model, config):
        from functools import partial

        kv_write = partial(self.format_k_v, write=True)
        self.writeline('Irreps of features')
        kv_write(
            'edge_feature', model.get_irreps_in('edge_embedding', 'irreps_out')
        )
        for i in range(config[KEY.NUM_CONVOLUTION]):
            kv_write(
                f'{i}th node',
                model.get_irreps_in(f'{i}_self_interaction_1'),
            )
        i = config[KEY.NUM_CONVOLUTION] - 1
        kv_write(
            'readout irreps',
            model.get_irreps_in(f'{i}_equivariant_gate', 'irreps_out'),
        )
        num_weights = sum(
            p.numel() for p in model.parameters() if p.requires_grad
        )
        self.writeline(f'# learnable parameters: {num_weights}\n')
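# A minimal usage sketch of Logger, assuming this module is importable as
# sevenn.logger. Note that Logger is a process-wide singleton, so the first
# construction fixes filename/screen/rank; the context manager opens and
# closes the log file.
#
#     from sevenn.logger import Logger
#
#     with Logger(filename='train.log', screen=True) as log:
#         log.greeting()                      # version banner + ASCII logo
#         log.format_k_v('cutoff', 5.0, write=True)
#         log.timer_start('epoch')
#         log.timer_end('epoch', 'Elapsed')   # writes "Elapsed: H:MM:SS.xx"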
mace-bench/3rdparty/SevenNet/sevenn/logo_ascii
0 → 100644
****
******** .
*//////, .. . ,*.
,,***. .. , ********. ./,
. . .. /////. ., . *///////// /////////.
.&@&/ . .(((((((.. / *//////*. ... *((((((((((.
@@@@@@@@@@* @@@@@@@@@@ @@@@@ *((@@@@@ ( %@@@@@@@@@@ .@@@@@@ ..@@@@. @@@@@@* .(@@@@@(((*
@@@@@. @@@@ @@@@@ . @@@@@ # %@@@@ @@@@@@@@ @@@@(, @@@@@@@@. @@@@@(*.
%@@@@@@@& @@@@@@@@@@ @@@@@ @@@@@ # ., .%@@@@@@@@@ @@@@@@@@@@ @@@@, @@@@@@@@@@ @@@@@
,(%@@@@@@@@@ @@@@@@@@@@ @@@@@ @@@@& (//////%@@@@@@@@@ @@@@ @@@@@@ @@@@ . @@@@@ @@@@@.@@@@@
. @@@@@ @@@@ . . @@@@@@@@% . . ( .////,%@@@@ @@@@ @@@@@@@@@ @@@@@ @@@@@@@@@
(@@@@@@@@@@@ @@@@@@@@@@**. @@@@@@* *. .%@@@@@@@@@@ @@@@ . @@@@@@@ @@@@@ .@@@@@@@
@@@@@@@@@. @@@@@@@@@@///, @@@@. . / %@@@@@@@@@@ @@@@***, @@@@@ @@@@@ @@@@@
. //////////*. / . .*******... . ,.
.&&&&&... ,//////*. ...////. / ,*/. . ,////, .,/////
&@@@@@@ ,(/((, * ,((((((. .***.
,/@(, .. * ,((((*
,
.
mace-bench/3rdparty/SevenNet/sevenn/model_build.py
0 → 100644
import copy
import warnings
from collections import OrderedDict
from typing import List, Literal, Union, overload

from e3nn.o3 import Irreps

import sevenn._const as _const
import sevenn._keys as KEY
import sevenn.util as util

from .nn.convolution import IrrepsConvolution
from .nn.edge_embedding import (
    BesselBasis,
    EdgeEmbedding,
    PolynomialCutoff,
    SphericalEncoding,
    XPLORCutoff,
)
from .nn.force_output import ForceStressOutputFromEdge
from .nn.interaction_blocks import NequIP_interaction_block
from .nn.linear import AtomReduce, FCN_e3nn, IrrepsLinear
from .nn.node_embedding import OnehotEmbedding
from .nn.scale import ModalWiseRescale, Rescale, SpeciesWiseRescale
from .nn.self_connection import (
    SelfConnectionIntro,
    SelfConnectionLinearIntro,
    SelfConnectionOutro,
)
from .nn.sequential import AtomGraphSequential

# warning from PyTorch, about e3nn type annotations
warnings.filterwarnings(
    'ignore',
    message=(
        "The TorchScript type system doesn't "
        'support instance-level annotations'
    ),
)


def _insert_after(module_name_after, key_module_pair, layers):
    idx = -1
    for i, (key, _) in enumerate(layers):
        if key == module_name_after:
            idx = i
            break
    if idx == -1:
        return layers  # do nothing if not found
    layers.insert(idx + 1, key_module_pair)
    return layers


def init_self_connection(config):
    self_connection_type_list = config[KEY.SELF_CONNECTION_TYPE]
    num_conv = config[KEY.NUM_CONVOLUTION]
    if isinstance(self_connection_type_list, str):
        self_connection_type_list = [self_connection_type_list] * num_conv
    io_pair_list = []
    for sc_type in self_connection_type_list:
        if sc_type == 'none':
            io_pair = None
        elif sc_type == 'nequip':
            io_pair = SelfConnectionIntro, SelfConnectionOutro
        elif sc_type == 'linear':
            io_pair = SelfConnectionLinearIntro, SelfConnectionOutro
        else:
            raise ValueError(f'Unknown self_connection_type found: {sc_type}')
        io_pair_list.append(io_pair)
    return io_pair_list


def init_edge_embedding(config):
    _cutoff_param = {'cutoff_length': config[KEY.CUTOFF]}
    rbf, env, sph = None, None, None

    rbf_dct = copy.deepcopy(config[KEY.RADIAL_BASIS])
    rbf_dct.update(_cutoff_param)
    rbf_name = rbf_dct.pop(KEY.RADIAL_BASIS_NAME)
    if rbf_name == 'bessel':
        rbf = BesselBasis(**rbf_dct)

    envelop_dct = copy.deepcopy(config[KEY.CUTOFF_FUNCTION])
    envelop_dct.update(_cutoff_param)
    envelop_name = envelop_dct.pop(KEY.CUTOFF_FUNCTION_NAME)
    if envelop_name == 'poly_cut':
        env = PolynomialCutoff(**envelop_dct)
    elif envelop_name == 'XPLOR':
        env = XPLORCutoff(**envelop_dct)

    lmax_edge = config[KEY.LMAX]
    if config[KEY.LMAX_EDGE] > 0:
        lmax_edge = config[KEY.LMAX_EDGE]
    parity = -1 if config[KEY.IS_PARITY] else 1
    _normalize_sph = config[KEY._NORMALIZE_SPH]
    sph = SphericalEncoding(lmax_edge, parity, normalize=_normalize_sph)

    return EdgeEmbedding(
        basis_module=rbf, cutoff_module=env, spherical_module=sph
    )


def init_feature_reduce(config, irreps_x):
    # features per node to scalar per node
    layers = OrderedDict()
    if config[KEY.READOUT_AS_FCN] is False:
        hidden_irreps = Irreps([(irreps_x.dim // 2, (0, 1))])
        layers.update(
            {
                'reduce_input_to_hidden': IrrepsLinear(
                    irreps_x,
                    hidden_irreps,
                    data_key_in=KEY.NODE_FEATURE,
                    biases=config[KEY.USE_BIAS_IN_LINEAR],
                ),
                'reduce_hidden_to_energy': IrrepsLinear(
                    hidden_irreps,
                    Irreps([(1, (0, 1))]),
                    data_key_in=KEY.NODE_FEATURE,
                    data_key_out=KEY.SCALED_ATOMIC_ENERGY,
                    biases=config[KEY.USE_BIAS_IN_LINEAR],
                ),
            }
        )
    else:
        act = _const.ACTIVATION[config[KEY.READOUT_FCN_ACTIVATION]]
        hidden_neurons = config[KEY.READOUT_FCN_HIDDEN_NEURONS]
        layers.update(
            {
                'readout_FCN': FCN_e3nn(
                    dim_out=1,
                    hidden_neurons=hidden_neurons,
                    activation=act,
                    data_key_in=KEY.NODE_FEATURE,
                    data_key_out=KEY.SCALED_ATOMIC_ENERGY,
                    irreps_in=irreps_x,
                )
            }
        )
    return layers


def init_shift_scale(config):
    # for mm, ex, shift: modal_idx -> shifts
    shift_scale = []
    train_shift_scale = config[KEY.TRAIN_SHIFT_SCALE]
    type_map = config[KEY.TYPE_MAP]
    # in case of modal, shift or scale has more dims [][]
    # correct typing (I really want static python)
    for s in (config[KEY.SHIFT], config[KEY.SCALE]):
        if hasattr(s, 'tolist'):  # numpy or torch
            s = s.tolist()
        if isinstance(s, dict):
            s = {
                k: v.tolist() if hasattr(v, 'tolist') else v
                for k, v in s.items()
            }
        if isinstance(s, list) and len(s) == 1:
            s = s[0]
        shift_scale.append(s)

    shift, scale = shift_scale
    rescale_module = None
    if config.get(KEY.USE_MODALITY, False):
        rescale_module = ModalWiseRescale.from_mappers(  # type: ignore
            shift,
            scale,
            config[KEY.USE_MODAL_WISE_SHIFT],
            config[KEY.USE_MODAL_WISE_SCALE],
            type_map=type_map,
            modal_map=config[KEY.MODAL_MAP],
            train_shift_scale=train_shift_scale,
        )
    elif all([isinstance(s, float) for s in shift_scale]):
        rescale_module = Rescale(
            shift, scale, train_shift_scale=train_shift_scale
        )
    elif any([isinstance(s, list) for s in shift_scale]):
        rescale_module = SpeciesWiseRescale.from_mappers(  # type: ignore
            shift, scale, type_map=type_map, train_shift_scale=train_shift_scale
        )
    else:
        raise ValueError('shift, scale should be list of float or float')
    return rescale_module


def patch_modality(layers: OrderedDict, config):
    """
    Postprocess 7net-model to multimodal model.
    1. prepend modality one-hot embedding layer
    2. patch modalities of IrrepsLinear layers
    Modality aware shift scale is handled by init_shift_scale, not here
    """
    cfg = config
    if not cfg.get(KEY.USE_MODALITY, False):
        return layers

    _layers = list(layers.items())
    _layers = _insert_after(
        'onehot_idx_to_onehot',
        (
            'one_hot_modality',
            OnehotEmbedding(
                num_classes=config[KEY.NUM_MODALITIES],
                data_key_x=KEY.MODAL_TYPE,
                data_key_out=KEY.MODAL_ATTR,
                data_key_save=None,
                data_key_additional=None,
            ),
        ),
        _layers,
    )
    layers = OrderedDict(_layers)

    num_modal = config[KEY.NUM_MODALITIES]
    for k, module in layers.items():
        if not isinstance(module, IrrepsLinear):
            continue
        if (
            (
                cfg[KEY.USE_MODAL_NODE_EMBEDDING]
                and k.endswith('onehot_to_feature_x')
            )
            or (
                cfg[KEY.USE_MODAL_SELF_INTER_INTRO]
                and k.endswith('self_interaction_1')
            )
            or (
                cfg[KEY.USE_MODAL_SELF_INTER_OUTRO]
                and k.endswith('self_interaction_2')
            )
            or (
                cfg[KEY.USE_MODAL_OUTPUT_BLOCK]
                and k == 'reduce_input_to_hidden'
            )
        ):
            module.set_num_modalities(num_modal)
    return layers


def patch_cue(layers: OrderedDict, config):
    import sevenn.nn.cue_helper as cue_helper

    cue_cfg = copy.deepcopy(config.get(KEY.CUEQUIVARIANCE_CONFIG, {}))
    if not cue_cfg.pop('use', False):
        return layers
    if not cue_helper.is_cue_available():
        warnings.warn(
            (
                'cuEquivariance is requested, but the package is not installed. '
                + 'Fallback to original code.'
            )
        )
        return layers
    if not cue_helper.is_cue_cuda_available_model(config):
        return layers

    group = 'O3' if config[KEY.IS_PARITY] else 'SO3'
    cueq_module_params = dict(layout='mul_ir')
    cueq_module_params.update(cue_cfg)

    updates = {}
    for k, module in layers.items():
        if isinstance(module, (IrrepsLinear, SelfConnectionLinearIntro)):
            if k == 'reduce_hidden_to_energy':
                # TODO: has bug with 0 shape
                continue
            module_patched = cue_helper.patch_linear(
                module, group, **cueq_module_params
            )
            updates[k] = module_patched
        elif isinstance(module, SelfConnectionIntro):
            module_patched = cue_helper.patch_fully_connected(
                module, group, **cueq_module_params
            )
            updates[k] = module_patched
        elif isinstance(module, IrrepsConvolution):
            module_patched = cue_helper.patch_convolution(
                module, group, **cueq_module_params
            )
            updates[k] = module_patched
    layers.update(updates)
    return layers


def patch_modules(layers: OrderedDict, config):
    layers = patch_modality(layers, config)
    layers = patch_cue(layers, config)
    return layers


def _to_parallel_model(layers: OrderedDict, config):
    num_classes = layers['onehot_idx_to_onehot'].num_classes
    one_hot_irreps = Irreps(f'{num_classes}x0e')
    irreps_node_zero = layers['onehot_to_feature_x'].irreps_out
    _layers = list(layers.items())
    layers_list = []
    num_convolution_layer = config[KEY.NUM_CONVOLUTION]

    def slice_until_this(module_name, layers):
        idx = -1
        for i, (key, _) in enumerate(layers):
            if key == module_name:
                idx = i
                break
        first_to = layers[: idx + 1]
        remain = layers[idx + 1 :]
        return first_to, remain

    _layers = _insert_after(
        'onehot_to_feature_x',
        (
            'one_hot_ghost',
            OnehotEmbedding(
                data_key_x=KEY.NODE_FEATURE_GHOST,
                num_classes=num_classes,
                data_key_save=None,
                data_key_additional=None,
            ),
        ),
        _layers,
    )
    _layers = _insert_after(
        'one_hot_ghost',
        (
            'ghost_onehot_to_feature_x',
            IrrepsLinear(
                irreps_in=one_hot_irreps,
                irreps_out=irreps_node_zero,
                data_key_in=KEY.NODE_FEATURE_GHOST,
                biases=config[KEY.USE_BIAS_IN_LINEAR],
            ),
        ),
        _layers,
    )
    _layers = _insert_after(
        '0_self_interaction_1',
        (
            'ghost_0_self_interaction_1',
            IrrepsLinear(
                irreps_node_zero,
                irreps_node_zero,
                data_key_in=KEY.NODE_FEATURE_GHOST,
                biases=config[KEY.USE_BIAS_IN_LINEAR],
            ),
        ),
        _layers,
    )

    # assign modules (before first communications)
    # initialize edge related to retain position gradients
    for i in range(1, num_convolution_layer):
        sliced, _layers = slice_until_this(f'{i}_self_interaction_1', _layers)
        layers_list.append(OrderedDict(sliced))
        _layers.insert(0, ('edge_embedding', init_edge_embedding(config)))
    layers_list.append(OrderedDict(_layers))
    del layers_list[-1]['force_output']  # done in LAMMPS
    return layers_list


@overload
def build_E3_equivariant_model(
    config: dict, parallel: Literal[False] = False
) -> AtomGraphSequential:  # noqa
    ...


@overload
def build_E3_equivariant_model(
    config: dict, parallel: Literal[True]
) -> List[AtomGraphSequential]:  # noqa
    ...


def build_E3_equivariant_model(
    config: dict, parallel: bool = False
) -> Union[AtomGraphSequential, List[AtomGraphSequential]]:
    """
    output shapes (w/o batch)
    PRED_TOTAL_ENERGY: (),
    ATOMIC_ENERGY: (natoms, 1),  # intended
    PRED_FORCE: (natoms, 3),
    PRED_STRESS: (6,),
    for data w/o cell volume, pred_stress has garbage values
    """
    layers = OrderedDict()
    cutoff = config[KEY.CUTOFF]
    num_species = config[KEY.NUM_SPECIES]
    feature_multiplicity = config[KEY.NODE_FEATURE_MULTIPLICITY]
    num_convolution_layer = config[KEY.NUM_CONVOLUTION]
    interaction_type = config[KEY.INTERACTION_TYPE]
    use_bias_in_linear = config[KEY.USE_BIAS_IN_LINEAR]

    lmax_node = config[KEY.LMAX]  # ignore second (lmax_edge)
    # if config[KEY.LMAX_EDGE] > 0:  # not yet used
    #     _ = config[KEY.LMAX_EDGE]
    if config[KEY.LMAX_NODE] > 0:
        lmax_node = config[KEY.LMAX_NODE]

    act_radial = _const.ACTIVATION[config[KEY.ACTIVATION_RADIAL]]
    self_connection_pair_list = init_self_connection(config)

    irreps_manual = None
    if config[KEY.IRREPS_MANUAL] is not False:
        irreps_manual = config[KEY.IRREPS_MANUAL]
        try:
            irreps_manual = [Irreps(irr) for irr in irreps_manual]
            assert len(irreps_manual) == num_convolution_layer + 1
        except Exception:
            raise RuntimeError('invalid irreps_manual input given')

    conv_denominator = config[KEY.CONV_DENOMINATOR]
    if not isinstance(conv_denominator, list):
        conv_denominator = [conv_denominator] * num_convolution_layer
    train_conv_denominator = config[KEY.TRAIN_DENOMINTAOR]

    edge_embedding = init_edge_embedding(config)
    irreps_filter = edge_embedding.spherical.irreps_out
    radial_basis_num = edge_embedding.basis_function.num_basis
    layers.update({'edge_embedding': edge_embedding})

    one_hot_irreps = Irreps(f'{num_species}x0e')
    irreps_x = (
        Irreps(f'{feature_multiplicity}x0e')
        if irreps_manual is None
        else irreps_manual[0]
    )
    layers.update(
        {
            'onehot_idx_to_onehot': OnehotEmbedding(
                num_classes=num_species,
                data_key_x=KEY.NODE_FEATURE,
                data_key_out=KEY.NODE_FEATURE,
                data_key_save=KEY.ATOM_TYPE,  # atomic numbers
                data_key_additional=KEY.NODE_ATTR,  # one-hot embeddings
            ),
            'onehot_to_feature_x': IrrepsLinear(
                irreps_in=one_hot_irreps,
                irreps_out=irreps_x,
                data_key_in=KEY.NODE_FEATURE,
                biases=use_bias_in_linear,
            ),
        }
    )

    weight_nn_hidden = config[KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS]
    weight_nn_layers = [radial_basis_num] + weight_nn_hidden

    param_interaction_block = {
        'irreps_filter': irreps_filter,
        'weight_nn_layers': weight_nn_layers,
        'train_conv_denominator': train_conv_denominator,
        'act_radial': act_radial,
        'bias_in_linear': use_bias_in_linear,
        'num_species': num_species,
        'parallel': parallel,
    }

    interaction_builder = None
    if interaction_type in ['nequip']:
        act_scalar = {}
        act_gate = {}
        for k, v in config[KEY.ACTIVATION_SCARLAR].items():
            act_scalar[k] = _const.ACTIVATION_DICT[k][v]
        for k, v in config[KEY.ACTIVATION_GATE].items():
            act_gate[k] = _const.ACTIVATION_DICT[k][v]
        param_interaction_block.update(
            {
                'act_scalar': act_scalar,
                'act_gate': act_gate,
            }
        )
    if interaction_type == 'nequip':
        interaction_builder = NequIP_interaction_block
    else:
        raise ValueError(f'Unknown interaction type: {interaction_type}')

    for t in range(num_convolution_layer):
        param_interaction_block.update(
            {
                'irreps_x': irreps_x,
                't': t,
                'conv_denominator': conv_denominator[t],
                'self_connection_pair': self_connection_pair_list[t],
            }
        )
        if interaction_type == 'nequip':
            parity_mode = 'full'
            fix_multiplicity = False
            if t == num_convolution_layer - 1:
                lmax_node = 0
                parity_mode = 'even'
            # TODO: irreps_manual is applicable to both irreps_out_tp and irreps_out
            irreps_out = (
                util.infer_irreps_out(
                    irreps_x,  # type: ignore
                    irreps_filter,
                    lmax_node,  # type: ignore
                    parity_mode,
                    fix_multiplicity=feature_multiplicity,
                )
                if irreps_manual is None
                else irreps_manual[t + 1]
            )
            irreps_out_tp = util.infer_irreps_out(
                irreps_x,  # type: ignore
                irreps_filter,
                irreps_out.lmax,  # type: ignore
                parity_mode,
                fix_multiplicity,
            )
        else:
            raise ValueError(f'Unknown interaction type: {interaction_type}')

        param_interaction_block.update(
            {
                'irreps_out_tp': irreps_out_tp,
                'irreps_out': irreps_out,
            }
        )
        layers.update(interaction_builder(**param_interaction_block))
        irreps_x = irreps_out

    layers.update(init_feature_reduce(config, irreps_x))
    layers.update(
        {
            'rescale_atomic_energy': init_shift_scale(config),
            'reduce_total_enegy': AtomReduce(
                data_key_in=KEY.ATOMIC_ENERGY,
                data_key_out=KEY.PRED_TOTAL_ENERGY,
            ),
        }
    )
    gradient_module = ForceStressOutputFromEdge()
    grad_key = gradient_module.get_grad_key()
    layers.update({'force_output': gradient_module})

    common_args = {
        'cutoff': cutoff,
        'type_map': config[KEY.TYPE_MAP],
        'modal_map': config.get(KEY.MODAL_MAP, None),
        'eval_type_map': False if parallel else True,
        'eval_modal_map': (
            False
            if not config.get(KEY.USE_MODALITY, False) or parallel
            else True
        ),
        'data_key_grad': grad_key,
    }

    if parallel:
        layers_list = _to_parallel_model(layers, config)
        return [
            AtomGraphSequential(patch_modules(layers, config), **common_args)
            for layers in layers_list
        ]
    else:
        return AtomGraphSequential(patch_modules(layers, config), **common_args)
mace-bench/3rdparty/SevenNet/sevenn/pair_d3.so
0 → 100644
File added
mace-bench/3rdparty/SevenNet/sevenn/pair_d3_for_ase.cuda.o
0 → 100644
File added
mace-bench/3rdparty/SevenNet/sevenn/parse_input.py
0 → 100644
import
glob
import
os
import
warnings
from
typing
import
Any
,
Callable
,
Dict
import
torch
import
yaml
import
sevenn._const
as
_const
import
sevenn._keys
as
KEY
import
sevenn.util
as
util
def
config_initialize
(
key
:
str
,
config
:
Dict
,
default
:
Any
,
conditions
:
Dict
,
):
# default value exist & no user input -> return default
if
key
not
in
config
.
keys
():
return
default
# No validation method exist => accept user input
user_input
=
config
[
key
]
if
key
in
conditions
:
condition
=
conditions
[
key
]
else
:
return
user_input
if
type
(
default
)
is
dict
and
isinstance
(
condition
,
dict
):
for
i_key
,
val
in
default
.
items
():
user_input
[
i_key
]
=
config_initialize
(
i_key
,
user_input
,
val
,
condition
)
return
user_input
elif
isinstance
(
condition
,
type
):
if
isinstance
(
user_input
,
condition
):
return
user_input
else
:
try
:
return
condition
(
user_input
)
# try type casting
except
ValueError
:
raise
ValueError
(
f
"Expect '
{
user_input
}
' for '
{
key
}
' is
{
condition
}
"
)
elif
isinstance
(
condition
,
Callable
)
and
condition
(
user_input
):
return
user_input
else
:
raise
ValueError
(
f
"Given input '
{
user_input
}
' for '
{
key
}
' is not valid"
)
def
init_model_config
(
config
:
Dict
):
# defaults = _const.model_defaults(config)
model_meta
=
{}
# init complicated ones
if
KEY
.
CHEMICAL_SPECIES
not
in
config
.
keys
():
raise
ValueError
(
'required key chemical_species not exist'
)
input_chem
=
config
[
KEY
.
CHEMICAL_SPECIES
]
if
isinstance
(
input_chem
,
str
)
and
input_chem
.
lower
()
==
'auto'
:
model_meta
[
KEY
.
CHEMICAL_SPECIES
]
=
'auto'
model_meta
[
KEY
.
NUM_SPECIES
]
=
'auto'
model_meta
[
KEY
.
TYPE_MAP
]
=
'auto'
elif
isinstance
(
input_chem
,
str
)
and
'univ'
in
input_chem
.
lower
():
model_meta
.
update
(
util
.
chemical_species_preprocess
([],
universal
=
True
))
else
:
if
isinstance
(
input_chem
,
list
)
and
all
(
isinstance
(
x
,
str
)
for
x
in
input_chem
):
pass
elif
isinstance
(
input_chem
,
str
):
input_chem
=
(
input_chem
.
replace
(
'-'
,
','
).
replace
(
' '
,
','
).
split
(
','
)
)
input_chem
=
[
chem
for
chem
in
input_chem
if
len
(
chem
)
!=
0
]
else
:
raise
ValueError
(
f
'given
{
KEY
.
CHEMICAL_SPECIES
}
input is strange'
)
model_meta
.
update
(
util
.
chemical_species_preprocess
(
input_chem
))
# deprecation warnings
if
KEY
.
AVG_NUM_NEIGH
in
config
:
warnings
.
warn
(
"key 'avg_num_neigh' is deprecated. Please use 'conv_denominator'."
' We use the default, the average number of neighbors in the'
' dataset, if not provided.'
,
UserWarning
,
)
config
.
pop
(
KEY
.
AVG_NUM_NEIGH
)
if
KEY
.
TRAIN_AVG_NUM_NEIGH
in
config
:
warnings
.
warn
(
"key 'train_avg_num_neigh' is deprecated. Please use"
" 'train_denominator'. We overwrite train_denominator as given"
' train_avg_num_neigh'
,
UserWarning
,
)
config
[
KEY
.
TRAIN_DENOMINTAOR
]
=
config
[
KEY
.
TRAIN_AVG_NUM_NEIGH
]
config
.
pop
(
KEY
.
TRAIN_AVG_NUM_NEIGH
)
if
KEY
.
OPTIMIZE_BY_REDUCE
in
config
:
warnings
.
warn
(
"key 'optimize_by_reduce' is deprecated. Always true"
,
UserWarning
,
)
config
.
pop
(
KEY
.
OPTIMIZE_BY_REDUCE
)
# init simpler ones
for
key
,
default
in
_const
.
DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG
.
items
():
model_meta
[
key
]
=
config_initialize
(
key
,
config
,
default
,
_const
.
MODEL_CONFIG_CONDITION
)
unknown_keys
=
[
key
for
key
in
config
.
keys
()
if
key
not
in
model_meta
.
keys
()
]
if
len
(
unknown_keys
)
!=
0
:
warnings
.
warn
(
f
'Unexpected model keys:
{
unknown_keys
}
will be ignored'
,
UserWarning
,
)
return
model_meta
def
init_train_config
(
config
:
Dict
):
train_meta
=
{}
# defaults = _const.train_defaults(config)
try
:
device_input
=
config
[
KEY
.
DEVICE
]
train_meta
[
KEY
.
DEVICE
]
=
torch
.
device
(
device_input
)
except
KeyError
:
train_meta
[
KEY
.
DEVICE
]
=
(
torch
.
device
(
'cuda'
)
if
torch
.
cuda
.
is_available
()
else
torch
.
device
(
'cpu'
)
)
train_meta
[
KEY
.
DEVICE
]
=
str
(
train_meta
[
KEY
.
DEVICE
])
# init simpler ones
for
key
,
default
in
_const
.
DEFAULT_TRAINING_CONFIG
.
items
():
train_meta
[
key
]
=
config_initialize
(
key
,
config
,
default
,
_const
.
TRAINING_CONFIG_CONDITION
)
if
KEY
.
CONTINUE
in
config
.
keys
():
cnt_dct
=
config
[
KEY
.
CONTINUE
]
if
KEY
.
CHECKPOINT
not
in
cnt_dct
.
keys
():
raise
ValueError
(
'no checkpoint is given in continue'
)
checkpoint
=
cnt_dct
[
KEY
.
CHECKPOINT
]
if
os
.
path
.
isfile
(
checkpoint
):
checkpoint_file
=
checkpoint
else
:
checkpoint_file
=
util
.
pretrained_name_to_path
(
checkpoint
)
train_meta
[
KEY
.
CONTINUE
].
update
({
KEY
.
CHECKPOINT
:
checkpoint_file
})
unknown_keys
=
[
key
for
key
in
config
.
keys
()
if
key
not
in
train_meta
.
keys
()
]
if
len
(
unknown_keys
)
!=
0
:
warnings
.
warn
(
f
'Unexpected train keys:
{
unknown_keys
}
will be ignored'
,
UserWarning
,
)
return
train_meta
def init_data_config(config: Dict):
    data_meta = {}
    # defaults = _const.data_defaults(config)

    load_data_keys = []
    for k in config:
        if k.startswith('load_') and k.endswith('_path'):
            load_data_keys.append(k)

    for load_data_key in load_data_keys:
        if load_data_key in config.keys():
            inp = config[load_data_key]
            extended = []
            if type(inp) not in [str, list]:
                raise ValueError(f'unexpected input {inp} for structure_list')
            if type(inp) is str:
                extended = glob.glob(inp)
            elif type(inp) is list:
                for i in inp:
                    if isinstance(i, str):
                        extended.extend(glob.glob(i))
                    elif isinstance(i, dict):
                        extended.append(i)
            if len(extended) == 0:
                raise ValueError(
                    f'Cannot find {inp} for {load_data_key}'
                    + ' or path is not given'
                )
            data_meta[load_data_key] = extended
        else:
            data_meta[load_data_key] = False

    for key, default in _const.DEFAULT_DATA_CONFIG.items():
        data_meta[key] = config_initialize(
            key, config, default, _const.DATA_CONFIG_CONDITION
        )

    unknown_keys = [key for key in config.keys() if key not in data_meta.keys()]
    if len(unknown_keys) != 0:
        warnings.warn(
            f'Unexpected data keys: {unknown_keys} will be ignored',
            UserWarning,
        )
    return data_meta
def read_config_yaml(filename: str, return_separately: bool = False):
    with open(filename, 'r') as fstream:
        inputs = yaml.safe_load(fstream)

    model_meta, train_meta, data_meta = {}, {}, {}
    for key, config in inputs.items():
        if key == 'model':
            model_meta = init_model_config(config)
        elif key == 'train':
            train_meta = init_train_config(config)
        elif key == 'data':
            data_meta = init_data_config(config)
        else:
            raise ValueError(f'Unexpected input {key} given')

    if return_separately:
        return model_meta, train_meta, data_meta
    else:
        model_meta.update(train_meta)
        model_meta.update(data_meta)
        return model_meta


def main():
    filename = './input.yaml'
    read_config_yaml(filename)


if __name__ == '__main__':
    main()
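# Usage sketch (the filename is illustrative): parse a SevenNet input
# yaml whose top-level sections are 'model', 'train', and 'data', as
# handled above.
#
#   model_cfg, train_cfg, data_cfg = read_config_yaml(
#       'input.yaml', return_separately=True
#   )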
mace-bench/3rdparty/SevenNet/sevenn/py.typed
0 → 100644
View file @
73866b01
mace-bench/3rdparty/SevenNet/sevenn/sevenn_logger.py
0 → 100644
View file @
73866b01
import warnings

from .logger import *  # noqa: F403

warnings.warn(
    'Please use sevenn.logger instead of sevenn.sevenn_logger',
    DeprecationWarning,
    stacklevel=2,
)
mace-bench/3rdparty/SevenNet/sevenn/sevennet_calculator.py
0 → 100644
View file @
73866b01
import warnings

from .calculator import *  # noqa: F403

warnings.warn(
    'Please use sevenn.calculator instead of sevenn.sevennet_calculator',
    DeprecationWarning,
    stacklevel=2,
)
mace-bench/3rdparty/SevenNet/sevenn/util.py
0 → 100644
View file @
73866b01
import os
import os.path as osp
import pathlib
import shutil
from typing import Dict, List, Tuple, Union

import numpy as np
import requests
import torch
import torch.nn
from e3nn.o3 import FullTensorProduct, Irreps
from tqdm import tqdm

import sevenn._const as _const
import sevenn._keys as KEY
def to_atom_graph_list(atom_graph_batch):
    """Split torch_geometric batched data into a list of separate graphs.

    PyG's original to_data_list() is not enough since it doesn't handle
    inferred tensors.
    """
    is_stress = KEY.PRED_STRESS in atom_graph_batch

    data_list = atom_graph_batch.to_data_list()
    indices = atom_graph_batch[KEY.NUM_ATOMS].tolist()

    atomic_energy_list = torch.split(
        atom_graph_batch[KEY.ATOMIC_ENERGY], indices
    )
    inferred_total_energy_list = torch.unbind(
        atom_graph_batch[KEY.PRED_TOTAL_ENERGY]
    )
    inferred_force_list = torch.split(
        atom_graph_batch[KEY.PRED_FORCE], indices
    )
    inferred_stress_list = None
    if is_stress:
        inferred_stress_list = torch.unbind(atom_graph_batch[KEY.PRED_STRESS])

    for i, data in enumerate(data_list):
        data[KEY.ATOMIC_ENERGY] = atomic_energy_list[i]
        data[KEY.PRED_TOTAL_ENERGY] = inferred_total_energy_list[i]
        data[KEY.PRED_FORCE] = inferred_force_list[i]
        # To fit with KEY.STRESS (ref) format
        if is_stress and inferred_stress_list is not None:
            data[KEY.PRED_STRESS] = torch.unsqueeze(inferred_stress_list[i], 0)

    return data_list
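# Usage sketch (assumes `batch` has already passed through a SevenNet
# model, so the PRED_* keys are present):
#
#   graphs = to_atom_graph_list(batch)
#   energies = [g[KEY.PRED_TOTAL_ENERGY] for g in graphs]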
def error_recorder_from_loss_functions(loss_functions):
    from .error_recorder import ErrorRecorder, MAError, RMSError, get_err_type
    from .train.loss import ForceLoss, PerAtomEnergyLoss, StressLoss

    metrics = []
    for loss_function, _ in loss_functions:
        ref_key = loss_function.ref_key
        pred_key = loss_function.pred_key
        # unit = loss_function.unit
        criterion = loss_function.criterion
        name = loss_function.name

        base = None
        if type(loss_function) is PerAtomEnergyLoss:
            base = get_err_type('Energy')
        elif type(loss_function) is ForceLoss:
            base = get_err_type('Force')
        elif type(loss_function) is StressLoss:
            base = get_err_type('Stress')
        else:
            base = {}
        base['name'] = name
        base['ref_key'] = ref_key
        base['pred_key'] = pred_key

        if type(criterion) is torch.nn.MSELoss:
            base['name'] = base['name'] + '_RMSE'
            metrics.append(RMSError(**base))
        elif type(criterion) is torch.nn.L1Loss:
            metrics.append(MAError(**base))
    return ErrorRecorder(metrics)
def onehot_to_chem(one_hot_indices: List[int], type_map: Dict[int, int]):
    from ase.data import chemical_symbols

    type_map_rev = {v: k for k, v in type_map.items()}
    return [chemical_symbols[type_map_rev[x]] for x in one_hot_indices]
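# Worked example: type_map maps atomic number -> one-hot index, so with
# {1: 1, 8: 0} (H -> 1, O -> 0) the indices [0, 1] resolve to symbols:
#
#   onehot_to_chem([0, 1], {1: 1, 8: 0})  # -> ['O', 'H']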
def model_from_checkpoint(
    checkpoint: str,
) -> Tuple[torch.nn.Module, Dict]:
    cp = load_checkpoint(checkpoint)
    model = cp.build_model()
    return model, cp.config


def model_from_checkpoint_with_backend(
    checkpoint: str,
    backend: str = 'e3nn',
) -> Tuple[torch.nn.Module, Dict]:
    cp = load_checkpoint(checkpoint)
    model = cp.build_model(backend)
    return model, cp.config
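# Usage sketch: `checkpoint` may be a file path or a pretrained model
# name resolved via pretrained_name_to_path below (e.g. '7net-0'):
#
#   model, config = model_from_checkpoint('7net-0')
#   model.eval()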
def unlabeled_atoms_to_input(
    atoms, cutoff: float, grad_key: str = KEY.EDGE_VEC
):
    from .atom_graph_data import AtomGraphData
    from .train.dataload import unlabeled_atoms_to_graph

    atom_graph = AtomGraphData.from_numpy_dict(
        unlabeled_atoms_to_graph(atoms, cutoff)
    )
    atom_graph[grad_key].requires_grad_(True)
    atom_graph[KEY.BATCH] = torch.zeros([0])
    return atom_graph
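# Usage sketch (ASE required; the 5.0 Angstrom cutoff is illustrative):
#
#   from ase.build import bulk
#   graph = unlabeled_atoms_to_input(bulk('Si'), cutoff=5.0)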
def chemical_species_preprocess(input_chem: List[str], universal: bool = False):
    from ase.data import atomic_numbers, chemical_symbols

    from .nn.node_embedding import get_type_mapper_from_specie

    config = {}
    if not universal:
        input_chem = list(set(input_chem))
        chemical_specie = sorted([x.strip() for x in input_chem])
        config[KEY.CHEMICAL_SPECIES] = chemical_specie
        config[KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER] = [
            atomic_numbers[x] for x in chemical_specie
        ]
        config[KEY.NUM_SPECIES] = len(chemical_specie)
        config[KEY.TYPE_MAP] = get_type_mapper_from_specie(chemical_specie)
    else:
        config[KEY.CHEMICAL_SPECIES] = chemical_symbols
        len_univ = len(chemical_symbols)
        config[KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER] = list(range(len_univ))
        config[KEY.NUM_SPECIES] = len_univ
        config[KEY.TYPE_MAP] = {z: z for z in range(len_univ)}
    return config
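# Worked example: duplicates are dropped and symbols sorted, so
# ['O', 'H', 'O'] gives CHEMICAL_SPECIES ['H', 'O'] and NUM_SPECIES 2;
# TYPE_MAP then maps atomic number to one-hot index (the direction is
# inferred from its use in onehot_to_chem above).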
def dtype_correct(
    v: Union[np.ndarray, torch.Tensor, int, float],
    float_dtype: torch.dtype = torch.float32,
    int_dtype: torch.dtype = torch.int64,
):
    if isinstance(v, np.ndarray):
        if np.issubdtype(v.dtype, np.floating):
            return torch.from_numpy(v).to(float_dtype)
        elif np.issubdtype(v.dtype, np.integer):
            return torch.from_numpy(v).to(int_dtype)
    elif isinstance(v, torch.Tensor):
        if v.dtype.is_floating_point:
            return v.to(float_dtype)  # convert to specified float dtype
        else:
            # assuming non-floating point tensors are integers
            return v.to(int_dtype)  # convert to specified int dtype
    else:  # scalar values
        if isinstance(v, int):
            return torch.tensor(v, dtype=int_dtype)
        elif isinstance(v, float):
            return torch.tensor(v, dtype=float_dtype)
        else:
            return v  # Not numeric
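# Worked examples: floating inputs become float32 tensors and integer
# inputs become int64 tensors by default:
#
#   dtype_correct(np.array([1.0, 2.0])).dtype  # torch.float32
#   dtype_correct(3).dtype                     # torch.int64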
def infer_irreps_out(
    irreps_x: Irreps,
    irreps_operand: Irreps,
    drop_l: Union[bool, int] = False,
    parity_mode: str = 'full',
    fix_multiplicity: Union[bool, int] = False,
):
    assert parity_mode in ['full', 'even', 'sph']
    # (mul, (ir, p))
    irreps_out = FullTensorProduct(irreps_x, irreps_operand).irreps_out.simplify()
    new_irreps_elem = []
    for mul, (l, p) in irreps_out:  # noqa
        elem = (mul, (l, p))
        if drop_l is not False and l > drop_l:
            continue
        if parity_mode == 'even' and p == -1:
            continue
        elif parity_mode == 'sph' and p != (-1) ** l:
            continue
        if fix_multiplicity:
            elem = (fix_multiplicity, (l, p))
        new_irreps_elem.append(elem)
    return Irreps(new_irreps_elem)
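# Worked example: the full tensor product 1o x 1o decomposes into
# 0e + 1e + 2e; drop_l=1 removes the l=2 path:
#
#   infer_irreps_out(Irreps('1x1o'), Irreps('1x1o'), drop_l=1)
#   # -> 1x0e+1x1e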
def download_checkpoint(path: str, url: str):
    fname = osp.basename(path)
    temp_path = path + '.partial'
    try:
        # raises permission error if fails
        os.makedirs(osp.dirname(path), exist_ok=True)
        response = requests.get(url, stream=True, timeout=30)
        response.raise_for_status()  # Raise exception for bad status codes

        total_size = int(response.headers.get('content-length', 0))
        block_size = 1024  # 1 KB chunks
        progress_bar = tqdm(
            total=total_size,
            unit='B',
            unit_scale=True,
            desc=f'Downloading {fname}',
        )

        with open(temp_path, 'wb') as file:
            for data in response.iter_content(block_size):
                progress_bar.update(len(data))
                file.write(data)
        progress_bar.close()

        shutil.move(temp_path, path)
        print(f'Checkpoint downloaded: {path}')
        return path
    except PermissionError:
        raise
    except Exception as e:
        # Clean up partial downloads on failure
        # May not work as errors handled internally by tqdm etc.
        print(f'Download failed: {str(e)}')
        if os.path.exists(temp_path):
            print(f'Cleaning up partial download: {temp_path}')
            os.remove(temp_path)
        raise
def pretrained_name_to_path(name: str) -> str:
    name = name.lower()
    heads = ['sevennet', '7net']
    checkpoint_path = None
    url = None

    if (  # TODO: regex
        name in [f'{n}-0_11july2024' for n in heads]
        or name in [f'{n}-0_11jul2024' for n in heads]
        or name in ['sevennet-0', '7net-0']
    ):
        checkpoint_path = _const.SEVENNET_0_11Jul2024
    elif name in [f'{n}-0_22may2024' for n in heads]:
        checkpoint_path = _const.SEVENNET_0_22May2024
    elif name in [f'{n}-l3i5' for n in heads]:
        checkpoint_path = _const.SEVENNET_l3i5
    elif name in [f'{n}-mf-0' for n in heads]:
        checkpoint_path = _const.SEVENNET_MF_0
    elif name in [f'{n}-mf-ompa' for n in heads]:
        checkpoint_path = _const.SEVENNET_MF_ompa
    elif name in [f'{n}-omat' for n in heads]:
        checkpoint_path = _const.SEVENNET_omat
    else:
        raise ValueError('Not a valid pretrained model name')

    url = _const.CHECKPOINT_DOWNLOAD_LINKS.get(checkpoint_path)

    paths = [
        checkpoint_path,
        checkpoint_path.replace(
            _const._prefix, osp.expanduser('~/.cache/sevennet')
        ),
    ]
    for path in paths:
        if osp.exists(path):
            return path

    # File not found, check url and try download
    if url is None:
        raise FileNotFoundError(checkpoint_path)
    try:
        return download_checkpoint(paths[0], url)  # 7net package path
    except PermissionError:
        return download_checkpoint(paths[1], url)  # ~/.cache
def load_checkpoint(checkpoint: Union[pathlib.Path, str]):
    from sevenn.checkpoint import SevenNetCheckpoint

    suggests = ['7net-0', '7net-l3i5', '7net-mf-ompa', '7net-omat']
    if osp.isfile(checkpoint):
        checkpoint_path = checkpoint
    else:
        try:
            checkpoint_path = pretrained_name_to_path(str(checkpoint))
        except ValueError:
            raise ValueError(
                f'Given {checkpoint} does not exist and is not a'
                f' pre-trained name.\n'
                f'Valid pretrained model names: {suggests}'
            )
    return SevenNetCheckpoint(checkpoint_path)
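# Usage sketch: works with either form of input (the .pth filename is
# illustrative):
#
#   cp = load_checkpoint('7net-0')          # pretrained name
#   cp = load_checkpoint('checkpoint.pth')  # or a local file path
#   model = cp.build_model()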
def unique_filepath(filepath: str) -> str:
    if not os.path.isfile(filepath):
        return filepath
    else:
        dirname = os.path.dirname(filepath)
        fname = os.path.basename(filepath)
        name, ext = os.path.splitext(fname)
        cnt = 0
        new_name = f'{name}{cnt}{ext}'
        new_path = os.path.join(dirname, new_name)
        while os.path.exists(new_path):
            cnt += 1
            new_name = f'{name}{cnt}{ext}'
            new_path = os.path.join(dirname, new_name)
        return new_path
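# Worked example: if 'run.log' already exists, the first free name is
# returned ('run0.log', then 'run1.log', ...) instead of overwriting:
#
#   log_path = unique_filepath('run.log')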
def get_error_recorder(
    recorder_tuples: List[Tuple[str, str]] = [
        ('Energy', 'RMSE'),
        ('Force', 'RMSE'),
        ('Stress', 'RMSE'),
        ('Energy', 'MAE'),
        ('Force', 'MAE'),
        ('Stress', 'MAE'),
    ],
):
    # TODO: add criterion argument and loss recorder selections
    import sevenn.error_recorder as error_recorder

    config = recorder_tuples
    err_metrics = []
    for err_type, metric_name in config:
        metric_kwargs = error_recorder.get_err_type(err_type).copy()
        metric_kwargs['name'] += f'_{metric_name}'
        metric_cls = error_recorder.ErrorRecorder.METRIC_DICT[metric_name]
        err_metrics.append(metric_cls(**metric_kwargs))
    return error_recorder.ErrorRecorder(err_metrics)
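# Usage sketch: the default records RMSE and MAE for energy, force, and
# stress; a subset can be requested explicitly:
#
#   recorder = get_error_recorder([('Energy', 'MAE'), ('Force', 'RMSE')])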