Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
nivren
ICT-CSP
Commits
73866b01
Unverified
Commit
73866b01
authored
Aug 24, 2025
by
zcxzcx1
Committed by
GitHub
Aug 24, 2025
Browse files
Add files via upload
parent
ca86f720
Changes
18
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
3985 additions
and
0 deletions
+3985
-0
mace-bench/3rdparty/SevenNet/sevenn/__init__.py
mace-bench/3rdparty/SevenNet/sevenn/__init__.py
+13
-0
mace-bench/3rdparty/SevenNet/sevenn/_const.py
mace-bench/3rdparty/SevenNet/sevenn/_const.py
+310
-0
mace-bench/3rdparty/SevenNet/sevenn/_keys.py
mace-bench/3rdparty/SevenNet/sevenn/_keys.py
+226
-0
mace-bench/3rdparty/SevenNet/sevenn/atom_graph_data.py
mace-bench/3rdparty/SevenNet/sevenn/atom_graph_data.py
+75
-0
mace-bench/3rdparty/SevenNet/sevenn/build.ninja
mace-bench/3rdparty/SevenNet/sevenn/build.ninja
+33
-0
mace-bench/3rdparty/SevenNet/sevenn/calculator.py
mace-bench/3rdparty/SevenNet/sevenn/calculator.py
+846
-0
mace-bench/3rdparty/SevenNet/sevenn/checkpoint.py
mace-bench/3rdparty/SevenNet/sevenn/checkpoint.py
+552
-0
mace-bench/3rdparty/SevenNet/sevenn/error_recorder.py
mace-bench/3rdparty/SevenNet/sevenn/error_recorder.py
+430
-0
mace-bench/3rdparty/SevenNet/sevenn/logger.py
mace-bench/3rdparty/SevenNet/sevenn/logger.py
+336
-0
mace-bench/3rdparty/SevenNet/sevenn/logo_ascii
mace-bench/3rdparty/SevenNet/sevenn/logo_ascii
+20
-0
mace-bench/3rdparty/SevenNet/sevenn/model_build.py
mace-bench/3rdparty/SevenNet/sevenn/model_build.py
+556
-0
mace-bench/3rdparty/SevenNet/sevenn/pair_d3.so
mace-bench/3rdparty/SevenNet/sevenn/pair_d3.so
+0
-0
mace-bench/3rdparty/SevenNet/sevenn/pair_d3_for_ase.cuda.o
mace-bench/3rdparty/SevenNet/sevenn/pair_d3_for_ase.cuda.o
+0
-0
mace-bench/3rdparty/SevenNet/sevenn/parse_input.py
mace-bench/3rdparty/SevenNet/sevenn/parse_input.py
+246
-0
mace-bench/3rdparty/SevenNet/sevenn/py.typed
mace-bench/3rdparty/SevenNet/sevenn/py.typed
+0
-0
mace-bench/3rdparty/SevenNet/sevenn/sevenn_logger.py
mace-bench/3rdparty/SevenNet/sevenn/sevenn_logger.py
+6
-0
mace-bench/3rdparty/SevenNet/sevenn/sevennet_calculator.py
mace-bench/3rdparty/SevenNet/sevenn/sevennet_calculator.py
+6
-0
mace-bench/3rdparty/SevenNet/sevenn/util.py
mace-bench/3rdparty/SevenNet/sevenn/util.py
+330
-0
No files found.
mace-bench/3rdparty/SevenNet/sevenn/__init__.py
0 → 100644
View file @
73866b01
# Package version is taken from installed metadata (single source of truth).
from importlib.metadata import version
from packaging.version import Version

__version__ = version('sevenn')

from e3nn import __version__ as e3nn_ver

# Hard fail at import time: e3nn changed its Clebsch-Gordan coefficient
# convention in 0.5.0, and running with an older e3nn would silently
# produce wrong results rather than crash.
if Version(e3nn_ver) < Version('0.5.0'):
    raise ValueError(
        'The e3nn version MUST be 0.5.0 or later due to changes in CG coefficient '
        'convention.'
    )
mace-bench/3rdparty/SevenNet/sevenn/_const.py
0 → 100644
View file @
73866b01
import os
from enum import Enum
from typing import Dict

import torch

import sevenn._keys as KEY
from sevenn.nn.activation import ShiftedSoftPlus

NUM_UNIV_ELEMENT = 119  # Z = 0 ~ 118

# Registries of string options accepted by the model-building config.
IMPLEMENTED_RADIAL_BASIS = ['bessel']
IMPLEMENTED_CUTOFF_FUNCTION = ['poly_cut', 'XPLOR']
# TODO: support None. This became difficult because of parallel model
IMPLEMENTED_SELF_CONNECTION_TYPE = ['nequip', 'linear']
IMPLEMENTED_INTERACTION_TYPE = ['nequip']

# Accepted strategies for energy shift / force scale statistics.
IMPLEMENTED_SHIFT = ['per_atom_energy_mean', 'elemwise_reference_energies']
IMPLEMENTED_SCALE = ['force_rms', 'per_atom_energy_std', 'elemwise_force_rms']

# Metric / error-type names accepted in the `error_record` config entry;
# validated by error_record_condition() below.
SUPPORTING_METRICS = ['RMSE', 'ComponentRMSE', 'MAE', 'Loss']
SUPPORTING_ERROR_TYPES = [
    'TotalEnergy',
    'Energy',
    'Force',
    'Stress',
    'Stress_GPa',
    'TotalLoss',
]

IMPLEMENTED_MODEL = ['E3_equivariant_model']

# string input to real torch function
ACTIVATION = {
    'relu': torch.nn.functional.relu,
    'silu': torch.nn.functional.silu,
    'tanh': torch.tanh,
    'abs': torch.abs,
    'ssp': ShiftedSoftPlus,
    'sigmoid': torch.sigmoid,
    'elu': torch.nn.functional.elu,
}
# Parity-separated activation choices for the gated nonlinearity:
# even ('e') irreps need even functions, odd ('o') irreps need odd ones.
ACTIVATION_FOR_EVEN = {
    'ssp': ShiftedSoftPlus,
    'silu': torch.nn.functional.silu,
}
ACTIVATION_FOR_ODD = {'tanh': torch.tanh, 'abs': torch.abs}
ACTIVATION_DICT = {'e': ACTIVATION_FOR_EVEN, 'o': ACTIVATION_FOR_ODD}

# Absolute paths of pretrained checkpoints shipped next to this module.
_prefix = os.path.abspath(f'{os.path.dirname(__file__)}/pretrained_potentials')
SEVENNET_0_11Jul2024 = f'{_prefix}/SevenNet_0__11Jul2024/checkpoint_sevennet_0.pth'
SEVENNET_0_22May2024 = f'{_prefix}/SevenNet_0__22May2024/checkpoint_sevennet_0.pth'
SEVENNET_l3i5 = f'{_prefix}/SevenNet_l3i5/checkpoint_l3i5.pth'
SEVENNET_MF_0 = f'{_prefix}/SevenNet_MF_0/checkpoint_sevennet_mf_0.pth'
SEVENNET_MF_ompa = f'{_prefix}/SevenNet_MF_ompa/checkpoint_sevennet_mf_ompa.pth'
SEVENNET_omat = f'{_prefix}/SevenNet_omat/checkpoint_sevennet_omat.pth'

# Checkpoints not bundled in the package are fetched from GitHub releases.
_git_prefix = 'https://github.com/MDIL-SNU/SevenNet/releases/download'
CHECKPOINT_DOWNLOAD_LINKS = {
    SEVENNET_MF_ompa: f'{_git_prefix}/v0.11.0.cp/checkpoint_sevennet_mf_ompa.pth',
    SEVENNET_omat: f'{_git_prefix}/v0.11.0.cp/checkpoint_sevennet_omat.pth',
}

# to avoid torch script to compile torch_geometry.data
AtomGraphDataType = Dict[str, torch.Tensor]
class LossType(Enum):
    """Legacy loss-component tags with their physical units.

    # only used for train_v1, do not use it afterwards
    """

    ENERGY = 'energy'  # eV or eV/atom
    FORCE = 'force'  # eV/A
    STRESS = 'stress'  # kB
def error_record_condition(x):
    """Validate a user-supplied `error_record` config entry.

    A valid value is a list of two-element lists ``[error_type, metric]``
    where ``error_type`` is one of ``SUPPORTING_ERROR_TYPES`` and
    ``metric`` is one of ``SUPPORTING_METRICS``.  The metric of a
    ``'TotalLoss'`` entry is not checked (it has no per-metric choice).

    Returns:
        bool: True when the whole structure is valid, False otherwise.
    """
    # isinstance (instead of `type(x) is list`) also accepts list subclasses.
    if not isinstance(x, list):
        return False
    for v in x:
        if not isinstance(v, list) or len(v) != 2:
            return False
        if v[0] not in SUPPORTING_ERROR_TYPES:
            return False
        if v[0] == 'TotalLoss':
            # TotalLoss carries no meaningful metric name; skip metric check.
            continue
        if v[1] not in SUPPORTING_METRICS:
            return False
    return True
# Default hyperparameters for the E3-equivariant model; consumed by
# model_defaults() below.  NOTE(review): model_defaults() pops entries from
# this dict, so treat it as read-only everywhere else.
DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG = {
    KEY.CUTOFF: 4.5,
    KEY.NODE_FEATURE_MULTIPLICITY: 32,
    KEY.IRREPS_MANUAL: False,
    KEY.LMAX: 1,
    KEY.LMAX_EDGE: -1,  # -1 means lmax_edge = lmax
    KEY.LMAX_NODE: -1,  # -1 means lmax_node = lmax
    KEY.IS_PARITY: True,
    KEY.NUM_CONVOLUTION: 3,
    KEY.RADIAL_BASIS: {
        KEY.RADIAL_BASIS_NAME: 'bessel',
    },
    KEY.CUTOFF_FUNCTION: {
        KEY.CUTOFF_FUNCTION_NAME: 'poly_cut',
    },
    KEY.ACTIVATION_RADIAL: 'silu',
    KEY.ACTIVATION_SCARLAR: {'e': 'silu', 'o': 'tanh'},
    KEY.ACTIVATION_GATE: {'e': 'silu', 'o': 'tanh'},
    KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS: [64, 64],
    # KEY.AVG_NUM_NEIGH: True,  # deprecated
    # KEY.TRAIN_AVG_NUM_NEIGH: False,  # deprecated
    KEY.CONV_DENOMINATOR: 'avg_num_neigh',
    KEY.TRAIN_DENOMINTAOR: False,
    KEY.TRAIN_SHIFT_SCALE: False,
    # KEY.OPTIMIZE_BY_REDUCE: True,  # deprecated, always True
    KEY.USE_BIAS_IN_LINEAR: False,
    KEY.USE_MODAL_NODE_EMBEDDING: False,
    KEY.USE_MODAL_SELF_INTER_INTRO: False,
    KEY.USE_MODAL_SELF_INTER_OUTRO: False,
    KEY.USE_MODAL_OUTPUT_BLOCK: False,
    KEY.READOUT_AS_FCN: False,
    # The two entries below apply only when readout_as_fcn is True
    KEY.READOUT_FCN_HIDDEN_NEURONS: [30, 30],
    KEY.READOUT_FCN_ACTIVATION: 'relu',
    KEY.SELF_CONNECTION_TYPE: 'nequip',
    KEY.INTERACTION_TYPE: 'nequip',
    KEY._NORMALIZE_SPH: True,
    KEY.CUEQUIVARIANCE_CONFIG: {},
}
# Basically, "If provided, it should be type of ..."
# Each value is either a type (checked by the parser, presumably via
# isinstance) or a predicate returning bool — TODO confirm against
# parse_input.py, which consumes this table.
MODEL_CONFIG_CONDITION = {
    KEY.NODE_FEATURE_MULTIPLICITY: int,
    KEY.LMAX: int,
    KEY.LMAX_EDGE: int,
    KEY.LMAX_NODE: int,
    KEY.IS_PARITY: bool,
    KEY.RADIAL_BASIS: {
        KEY.RADIAL_BASIS_NAME: lambda x: x in IMPLEMENTED_RADIAL_BASIS,
    },
    KEY.CUTOFF_FUNCTION: {
        KEY.CUTOFF_FUNCTION_NAME: lambda x: x in IMPLEMENTED_CUTOFF_FUNCTION,
    },
    KEY.CUTOFF: float,
    KEY.NUM_CONVOLUTION: int,
    # Either an explicit float denominator or one of the named strategies.
    KEY.CONV_DENOMINATOR: lambda x: isinstance(x, float)
    or x
    in [
        'avg_num_neigh',
        'sqrt_avg_num_neigh',
    ],
    KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS: list,
    KEY.TRAIN_SHIFT_SCALE: bool,
    KEY.TRAIN_DENOMINTAOR: bool,
    KEY.USE_BIAS_IN_LINEAR: bool,
    KEY.USE_MODAL_NODE_EMBEDDING: bool,
    KEY.USE_MODAL_SELF_INTER_INTRO: bool,
    KEY.USE_MODAL_SELF_INTER_OUTRO: bool,
    KEY.USE_MODAL_OUTPUT_BLOCK: bool,
    KEY.READOUT_AS_FCN: bool,
    KEY.READOUT_FCN_HIDDEN_NEURONS: list,
    KEY.READOUT_FCN_ACTIVATION: str,
    KEY.ACTIVATION_RADIAL: str,
    # A single type name, or a per-layer list of type names.
    KEY.SELF_CONNECTION_TYPE: lambda x: (
        x in IMPLEMENTED_SELF_CONNECTION_TYPE
        or (
            isinstance(x, list)
            and all(sc in IMPLEMENTED_SELF_CONNECTION_TYPE for sc in x)
        )
    ),
    KEY.INTERACTION_TYPE: lambda x: x in IMPLEMENTED_INTERACTION_TYPE,
    KEY._NORMALIZE_SPH: bool,
    KEY.CUEQUIVARIANCE_CONFIG: dict,
}
def model_defaults(config):
    """Return model defaults applicable to the given (partial) config.

    If ``config`` does not state ``readout_as_fcn`` it is filled in from the
    defaults (this mutates ``config`` in place, matching the original
    behavior).  When the readout is not an FCN, the FCN-only keys are
    dropped from the returned defaults.

    Returns:
        dict: a fresh dict of defaults; safe for the caller to mutate.
    """
    # Copy first: the original assigned the module-level constant and then
    # pop()-ed from it, permanently mutating DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG
    # for every later caller.
    defaults = dict(DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG)
    if KEY.READOUT_AS_FCN not in config:
        config[KEY.READOUT_AS_FCN] = defaults[KEY.READOUT_AS_FCN]
    if config[KEY.READOUT_AS_FCN] is False:
        # These two options only make sense for the FCN readout.
        defaults.pop(KEY.READOUT_FCN_ACTIVATION, None)
        defaults.pop(KEY.READOUT_FCN_HIDDEN_NEURONS, None)
    return defaults
# Default data-pipeline settings; consumed by data_defaults() below.
# NOTE(review): data_defaults() pops entries, so treat as read-only elsewhere.
DEFAULT_DATA_CONFIG = {
    KEY.DTYPE: 'single',
    KEY.DATA_FORMAT: 'ase',
    KEY.DATA_FORMAT_ARGS: {},
    KEY.SAVE_DATASET: False,
    KEY.SAVE_BY_LABEL: False,
    KEY.SAVE_BY_TRAIN_VALID: False,
    KEY.RATIO: 0.0,  # train/valid split ratio; dropped when a validset is given
    KEY.BATCH_SIZE: 6,
    KEY.PREPROCESS_NUM_CORES: 1,
    KEY.COMPUTE_STATISTICS: True,
    KEY.DATASET_TYPE: 'graph',
    # KEY.USE_SPECIES_WISE_SHIFT_SCALE: False,
    KEY.USE_MODAL_WISE_SHIFT: False,
    KEY.USE_MODAL_WISE_SCALE: False,
    KEY.SHIFT: 'per_atom_energy_mean',
    KEY.SCALE: 'force_rms',
    # KEY.DATA_SHUFFLE: True,
    # KEY.DATA_WEIGHT: False,
    # KEY.DATA_MODALITY: False,
}
# Type / predicate constraints for user-provided data-config entries.
DATA_CONFIG_CONDITION = {
    KEY.DTYPE: str,
    KEY.DATA_FORMAT: str,
    KEY.DATA_FORMAT_ARGS: dict,
    # NOTE(review): default SAVE_DATASET is False (bool) but the condition
    # is str — verify intended type against parse_input.py.
    KEY.SAVE_DATASET: str,
    KEY.SAVE_BY_LABEL: bool,
    KEY.SAVE_BY_TRAIN_VALID: bool,
    KEY.RATIO: float,
    KEY.BATCH_SIZE: int,
    KEY.PREPROCESS_NUM_CORES: int,
    KEY.DATASET_TYPE: lambda x: x in ['graph', 'atoms'],
    # KEY.USE_SPECIES_WISE_SHIFT_SCALE: bool,
    # Shift/scale: explicit number(s) or one of the named strategies.
    KEY.SHIFT: lambda x: type(x) in [float, list] or x in IMPLEMENTED_SHIFT,
    KEY.SCALE: lambda x: type(x) in [float, list] or x in IMPLEMENTED_SCALE,
    KEY.USE_MODAL_WISE_SHIFT: bool,
    KEY.USE_MODAL_WISE_SCALE: bool,
    # KEY.DATA_SHUFFLE: bool,
    KEY.COMPUTE_STATISTICS: bool,
    # KEY.DATA_WEIGHT: bool,
    # KEY.DATA_MODALITY: bool,
}
def data_defaults(config):
    """Return data-pipeline defaults applicable to the given config.

    When an explicit validation set is provided, the automatic
    train/valid split ratio is meaningless and is removed.

    Returns:
        dict: a fresh dict of defaults; safe for the caller to mutate.
    """
    # Copy first: the original pop()-ed directly from the module-level
    # DEFAULT_DATA_CONFIG, mutating it for every later caller.
    defaults = dict(DEFAULT_DATA_CONFIG)
    if KEY.LOAD_VALIDSET in config:
        defaults.pop(KEY.RATIO, None)
    return defaults
# Default training settings; consumed by train_defaults() below.
# NOTE(review): train_defaults() pops entries, so treat as read-only elsewhere.
DEFAULT_TRAINING_CONFIG = {
    KEY.RANDOM_SEED: 1,
    KEY.EPOCH: 300,
    KEY.LOSS: 'mse',
    KEY.LOSS_PARAM: {},
    KEY.OPTIMIZER: 'adam',
    KEY.OPTIM_PARAM: {},
    KEY.SCHEDULER: 'exponentiallr',
    KEY.SCHEDULER_PARAM: {},
    KEY.FORCE_WEIGHT: 0.1,
    KEY.STRESS_WEIGHT: 1e-6,  # SIMPLE-NN default
    KEY.PER_EPOCH: 5,  # checkpoint save interval (epochs)
    # KEY.USE_TESTSET: False,
    KEY.CONTINUE: {
        KEY.CHECKPOINT: False,
        KEY.RESET_OPTIMIZER: False,
        KEY.RESET_SCHEDULER: False,
        KEY.RESET_EPOCH: False,
        KEY.USE_STATISTIC_VALUES_OF_CHECKPOINT: True,
        KEY.USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY: True,
    },
    # KEY.DEFAULT_MODAL: 'common',
    KEY.CSV_LOG: 'log.csv',
    KEY.NUM_WORKERS: 0,
    KEY.IS_TRAIN_STRESS: True,
    KEY.TRAIN_SHUFFLE: True,
    # Default per-epoch error reporting (see error_record_condition).
    KEY.ERROR_RECORD: [
        ['Energy', 'RMSE'],
        ['Force', 'RMSE'],
        ['Stress', 'RMSE'],
        ['TotalLoss', 'None'],
    ],
    KEY.BEST_METRIC: 'TotalLoss',
    KEY.USE_WEIGHT: False,
    KEY.USE_MODALITY: False,
}
# Type / predicate constraints for user-provided training-config entries.
TRAINING_CONFIG_CONDITION = {
    KEY.RANDOM_SEED: int,
    KEY.EPOCH: int,
    KEY.FORCE_WEIGHT: float,
    KEY.STRESS_WEIGHT: float,
    KEY.USE_TESTSET: None,  # Not used
    KEY.NUM_WORKERS: int,
    KEY.PER_EPOCH: int,
    KEY.CONTINUE: {
        KEY.CHECKPOINT: str,
        KEY.RESET_OPTIMIZER: bool,
        KEY.RESET_SCHEDULER: bool,
        KEY.RESET_EPOCH: bool,
        KEY.USE_STATISTIC_VALUES_OF_CHECKPOINT: bool,
        KEY.USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY: bool,
    },
    KEY.DEFAULT_MODAL: str,
    KEY.IS_TRAIN_STRESS: bool,
    KEY.TRAIN_SHUFFLE: bool,
    # Structured value: validated by the predicate defined above.
    KEY.ERROR_RECORD: error_record_condition,
    KEY.BEST_METRIC: str,
    KEY.CSV_LOG: str,
    KEY.USE_MODALITY: bool,
    KEY.USE_WEIGHT: bool,
}
def train_defaults(config):
    """Return training defaults applicable to the given (partial) config.

    If ``config`` does not state ``is_train_stress`` it is filled in from
    the defaults (mutating ``config`` in place, matching the original
    behavior).  When stress training is off, the stress loss weight is
    dropped from the returned defaults.

    Returns:
        dict: a fresh dict of defaults; safe for the caller to mutate.
    """
    # Copy first: the original pop()-ed directly from the module-level
    # DEFAULT_TRAINING_CONFIG, mutating it for every later caller.
    defaults = dict(DEFAULT_TRAINING_CONFIG)
    if KEY.IS_TRAIN_STRESS not in config:
        config[KEY.IS_TRAIN_STRESS] = defaults[KEY.IS_TRAIN_STRESS]
    if not config[KEY.IS_TRAIN_STRESS]:
        defaults.pop(KEY.STRESS_WEIGHT, None)
    return defaults
mace-bench/3rdparty/SevenNet/sevenn/_keys.py
0 → 100644
View file @
73866b01
"""
How to add new feature?
1. Add new key to this file.
2. Add new key to _const.py
2.1. if the type of input is consistent,
write adequate condition and default to _const.py.
2.2. if the type of input is not consistent,
you must add your own input validation code to
parse_input.py
"""
from
typing
import
Final
# see
# https://github.com/pytorch/pytorch/issues/52312
# for FYI
# ~~ keys ~~ #
# PyG : primitive key of torch_geometric.data.Data type
# ==================================================#
# ~~~~~~~~~~~~~~~~~ KEY for data ~~~~~~~~~~~~~~~~~~ #
# ==================================================#
# some raw properties of graph
ATOMIC_NUMBERS
:
Final
[
str
]
=
'atomic_numbers'
# (N)
POS
:
Final
[
str
]
=
'pos'
# (N, 3) PyG
CELL
:
Final
[
str
]
=
'cell_lattice_vectors'
# (3, 3)
CELL_SHIFT
:
Final
[
str
]
=
'pbc_shift'
# (N, 3)
CELL_VOLUME
:
Final
[
str
]
=
'cell_volume'
EDGE_VEC
:
Final
[
str
]
=
'edge_vec'
# (N_edge, 3)
EDGE_LENGTH
:
Final
[
str
]
=
'edge_length'
# (N_edge, 1)
# some primary data of graph
EDGE_IDX
:
Final
[
str
]
=
'edge_index'
# (2, N_edge) PyG
ATOM_TYPE
:
Final
[
str
]
=
'atom_type'
# (N) one-hot index of nodes
NODE_FEATURE
:
Final
[
str
]
=
'x'
# (N, ?) PyG
NODE_FEATURE_GHOST
:
Final
[
str
]
=
'x_ghost'
NODE_ATTR
:
Final
[
str
]
=
'node_attr'
# (N, N_species) from one_hot
MODAL_ATTR
:
Final
[
str
]
=
(
'modal_attr'
# (1, N_modalities) for handling multi-modal
)
MODAL_TYPE
:
Final
[
str
]
=
'modal_type'
# (1) one-hot index of modal
EDGE_ATTR
:
Final
[
str
]
=
'edge_attr'
# (from spherical harmonics)
EDGE_EMBEDDING
:
Final
[
str
]
=
'edge_embedding'
# (from edge embedding)
# inputs of loss function
ENERGY
:
Final
[
str
]
=
'total_energy'
# (1)
FORCE
:
Final
[
str
]
=
'force_of_atoms'
# (N, 3)
STRESS
:
Final
[
str
]
=
'stress'
# (6)
# This is for training, per atom scale.
SCALED_ENERGY
:
Final
[
str
]
=
'scaled_total_energy'
# general outputs of models
SCALED_ATOMIC_ENERGY
:
Final
[
str
]
=
'scaled_atomic_energy'
ATOMIC_ENERGY
:
Final
[
str
]
=
'atomic_energy'
PRED_TOTAL_ENERGY
:
Final
[
str
]
=
'inferred_total_energy'
PRED_PER_ATOM_ENERGY
:
Final
[
str
]
=
'inferred_per_atom_energy'
PER_ATOM_ENERGY
:
Final
[
str
]
=
'per_atom_energy'
PRED_FORCE
:
Final
[
str
]
=
'inferred_force'
SCALED_FORCE
:
Final
[
str
]
=
'scaled_force'
PRED_STRESS
:
Final
[
str
]
=
'inferred_stress'
SCALED_STRESS
:
Final
[
str
]
=
'scaled_stress'
# very general data property for AtomGraphData
NUM_ATOMS
:
Final
[
str
]
=
'num_atoms'
# int
NUM_GHOSTS
:
Final
[
str
]
=
'num_ghosts'
NLOCAL
:
Final
[
str
]
=
'nlocal'
# only for lammps parallel, must be on cpu
USER_LABEL
:
Final
[
str
]
=
'user_label'
DATA_WEIGHT
:
Final
[
str
]
=
'data_weight'
# weight for given data
DATA_MODALITY
:
Final
[
str
]
=
(
'data_modality'
# modality of given data. e.g. PBE and SCAN
)
BATCH
:
Final
[
str
]
=
'batch'
TAG
=
'tag'
# replace USER_LABEL
# etc
SELF_CONNECTION_TEMP
:
Final
[
str
]
=
'self_cont_tmp'
BATCH_SIZE
:
Final
[
str
]
=
'batch_size'
INFO
:
Final
[
str
]
=
'data_info'
# something special
LABEL_NONE
:
Final
[
str
]
=
'No_label'
# ==================================================#
# ~~~~~~ KEY for train/data configuration ~~~~~~~~ #
# ==================================================#
PREPROCESS_NUM_CORES
=
'preprocess_num_cores'
SAVE_DATASET
=
'save_dataset_path'
SAVE_BY_LABEL
=
'save_by_label'
SAVE_BY_TRAIN_VALID
=
'save_by_train_valid'
DATA_FORMAT
=
'data_format'
DATA_FORMAT_ARGS
=
'data_format_args'
STRUCTURE_LIST
=
'structure_list'
LOAD_DATASET
=
'load_dataset_path'
# not used in v2
LOAD_TRAINSET
=
'load_trainset_path'
LOAD_VALIDSET
=
'load_validset_path'
LOAD_TESTSET
=
'load_testset_path'
FORMAT_OUTPUTS
=
'format_outputs_for_ase'
COMPUTE_STATISTICS
=
'compute_statistics'
DATASET_TYPE
=
'dataset_type'
RANDOM_SEED
=
'random_seed'
RATIO
=
'data_divide_ratio'
USE_TESTSET
=
'use_testset'
EPOCH
=
'epoch'
LOSS
=
'loss'
LOSS_PARAM
=
'loss_param'
OPTIMIZER
=
'optimizer'
OPTIM_PARAM
=
'optim_param'
SCHEDULER
=
'scheduler'
SCHEDULER_PARAM
=
'scheduler_param'
FORCE_WEIGHT
=
'force_loss_weight'
STRESS_WEIGHT
=
'stress_loss_weight'
DEVICE
=
'device'
DTYPE
=
'dtype'
TRAIN_SHUFFLE
=
'train_shuffle'
IS_TRAIN_STRESS
=
'is_train_stress'
CONTINUE
=
'continue'
CHECKPOINT
=
'checkpoint'
RESET_OPTIMIZER
=
'reset_optimizer'
RESET_SCHEDULER
=
'reset_scheduler'
RESET_EPOCH
=
'reset_epoch'
USE_STATISTIC_VALUES_OF_CHECKPOINT
=
'use_statistic_values_of_checkpoint'
USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY
=
(
'use_statistic_values_for_cp_modal_only'
)
CSV_LOG
=
'csv_log'
ERROR_RECORD
=
'error_record'
BEST_METRIC
=
'best_metric'
NUM_WORKERS
=
'num_workers'
# not work
RANK
=
'rank'
LOCAL_RANK
=
'local_rank'
WORLD_SIZE
=
'world_size'
IS_DDP
=
'is_ddp'
DDP_BACKEND
=
'ddp_backend'
PER_EPOCH
=
'per_epoch'
USE_WEIGHT
=
'use_weight'
USE_MODALITY
=
'use_modality'
DEFAULT_MODAL
=
'default_modal'
# ==================================================#
# ~~~~~~~~ KEY for model configuration ~~~~~~~~~~~ #
# ==================================================#
# ~~ global model configuration ~~ #
# note that these names are directly used for input.yaml for user input
MODEL_TYPE
=
'_model_type'
CUTOFF
=
'cutoff'
CHEMICAL_SPECIES
=
'chemical_species'
MODAL_LIST
=
'modal_list'
CHEMICAL_SPECIES_BY_ATOMIC_NUMBER
=
'_chemical_species_by_atomic_number'
NUM_SPECIES
=
'_number_of_species'
NUM_MODALITIES
=
'_number_of_modalities'
TYPE_MAP
=
'_type_map'
MODAL_MAP
=
'_modal_map'
# ~~ E3 equivariant model build configuration keys ~~ #
# see model_build default_config for type
IRREPS_MANUAL
=
'irreps_manual'
NODE_FEATURE_MULTIPLICITY
=
'channel'
RADIAL_BASIS
=
'radial_basis'
BESSEL_BASIS_NUM
=
'bessel_basis_num'
CUTOFF_FUNCTION
=
'cutoff_function'
POLY_CUT_P
=
'poly_cut_p_value'
LMAX
=
'lmax'
LMAX_EDGE
=
'lmax_edge'
LMAX_NODE
=
'lmax_node'
IS_PARITY
=
'is_parity'
CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS
=
'weight_nn_hidden_neurons'
NUM_CONVOLUTION
=
'num_convolution_layer'
ACTIVATION_SCARLAR
=
'act_scalar'
ACTIVATION_GATE
=
'act_gate'
ACTIVATION_RADIAL
=
'act_radial'
SELF_CONNECTION_TYPE
=
'self_connection_type'
RADIAL_BASIS_NAME
=
'radial_basis_name'
CUTOFF_FUNCTION_NAME
=
'cutoff_function_name'
USE_BIAS_IN_LINEAR
=
'use_bias_in_linear'
USE_MODAL_NODE_EMBEDDING
=
'use_modal_node_embedding'
USE_MODAL_SELF_INTER_INTRO
=
'use_modal_self_inter_intro'
USE_MODAL_SELF_INTER_OUTRO
=
'use_modal_self_inter_outro'
USE_MODAL_OUTPUT_BLOCK
=
'use_modal_output_block'
READOUT_AS_FCN
=
'readout_as_fcn'
READOUT_FCN_HIDDEN_NEURONS
=
'readout_fcn_hidden_neurons'
READOUT_FCN_ACTIVATION
=
'readout_fcn_activation'
AVG_NUM_NEIGH
=
'avg_num_neigh'
CONV_DENOMINATOR
=
'conv_denominator'
SHIFT
=
'shift'
SCALE
=
'scale'
USE_SPECIES_WISE_SHIFT_SCALE
=
'use_species_wise_shift_scale'
USE_MODAL_WISE_SHIFT
=
'use_modal_wise_shift'
USE_MODAL_WISE_SCALE
=
'use_modal_wise_scale'
TRAIN_SHIFT_SCALE
=
'train_shift_scale'
TRAIN_DENOMINTAOR
=
'train_denominator'
INTERACTION_TYPE
=
'interaction_type'
TRAIN_AVG_NUM_NEIGH
=
'train_avg_num_neigh'
# deprecated
CUEQUIVARIANCE_CONFIG
=
'cuequivariance_config'
_NORMALIZE_SPH
=
'_normalize_sph'
OPTIMIZE_BY_REDUCE
=
'optimize_by_reduce'
mace-bench/3rdparty/SevenNet/sevenn/atom_graph_data.py
0 → 100644
View file @
73866b01
from
typing
import
Optional
import
torch
import
torch_geometric.data
import
sevenn._keys
as
KEY
import
sevenn.util
class AtomGraphData(torch_geometric.data.Data):
    """
    Args:
        x (Tensor, optional): atomic numbers with shape :obj:`[num_nodes,
            atomic_numbers]`. (default: :obj:`None`)
        edge_index (LongTensor, optional): Graph connectivity in coordinate
            format with shape :obj:`[2, num_edges]`. (default: :obj:`None`)
        edge_attr (Tensor, optional): Edge feature matrix with shape
            :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`)
        y_energy: scalar  # unit of eV (VASP raw)
        y_force: [num_nodes, 3]  # unit of eV/A (VASP raw)
        y_stress: [6]  # [xx, yy, zz, xy, yz, zx] # unit of eV/A^3 (VASP raw)
        pos (Tensor, optional): Node position matrix with shape
            :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`)
        **kwargs (optional): Additional attributes.

    x, y_force, pos should be aligned with each other.
    """

    def __init__(
        self,
        x: Optional[torch.Tensor] = None,
        edge_index: Optional[torch.Tensor] = None,
        pos: Optional[torch.Tensor] = None,
        edge_attr: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        super(AtomGraphData, self).__init__(x, edge_index, edge_attr, pos=pos)
        # x is stored twice: as PyG's `x` (via super) and as node_attr.  # ?
        self[KEY.NODE_ATTR] = x
        # Any extra keyword becomes a graph attribute (PyG item assignment).
        for k, v in kwargs.items():
            self[k] = v

    def to_numpy_dict(self):
        """Return a plain dict with every torch.Tensor value converted to numpy."""
        # This is not debugged yet!
        dct = {
            k: v.detach().cpu().numpy() if type(v) is torch.Tensor else v
            for k, v in self.items()
        }
        return dct

    def fit_dimension(self):
        """Squeeze singleton dims of every tensor attribute, in place.

        Per-atom quantities of a single-atom graph keep a leading atom
        axis of size 1 so downstream code can still index by atom.
        Returns self for chaining.
        """
        per_atom_keys = [
            KEY.ATOMIC_NUMBERS,
            KEY.ATOMIC_ENERGY,
            KEY.POS,
            KEY.FORCE,
            KEY.PRED_FORCE,
        ]
        natoms = self.num_atoms.item()
        for k, v in self.items():
            if not isinstance(v, torch.Tensor):
                continue
            if natoms == 1 and k in per_atom_keys:
                # keep shape (1, ...) instead of squeezing the atom axis away
                self[k] = v.squeeze().unsqueeze(0)
            else:
                self[k] = v.squeeze()
        return self

    @staticmethod
    def from_numpy_dict(dct):
        """Build an AtomGraphData from a numpy dict (inverse of to_numpy_dict).

        NOTE: mutates `dct` in place while converting values; cell-shift
        values are always promoted to float tensors via torch.Tensor.
        """
        for k, v in dct.items():
            if k == KEY.CELL_SHIFT:
                dct[k] = torch.Tensor(v)  # this is special
            else:
                dct[k] = sevenn.util.dtype_correct(v)
        return AtomGraphData(**dct)
mace-bench/3rdparty/SevenNet/sevenn/build.ninja
0 → 100644
View file @
73866b01
ninja_required_version = 1.3
cxx = c++
nvcc = /usr/local/cuda/bin/nvcc
cflags = -DTORCH_EXTENSION_NAME=pair_d3 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include/TH -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17
post_cflags =
cuda_cflags = -DTORCH_EXTENSION_NAME=pair_d3 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include/TH -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_89,code=sm_89 -gencode=arch=compute_90,code=sm_90 --compiler-options '-fPIC' -O3 --expt-relaxed-constexpr -fmad=false -std=c++17
cuda_post_cflags =
cuda_dlink_post_cflags =
ldflags = -shared -L/home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda -ltorch -ltorch_python -L/usr/local/cuda/lib64 -lcudart
rule compile
command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags
depfile = $out.d
deps = gcc
rule cuda_compile
depfile = $out.d
deps = gcc
command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out $cuda_post_cflags
rule link
command = $cxx $in $ldflags -o $out
build pair_d3_for_ase.cuda.o: cuda_compile /home/mazhaojia/mace-project/mace-bench/3rdparty/SevenNet/sevenn/pair_e3gnn/pair_d3_for_ase.cu
build pair_d3.so: link pair_d3_for_ase.cuda.o
default pair_d3.so
mace-bench/3rdparty/SevenNet/sevenn/calculator.py
0 → 100644
View file @
73866b01
This diff is collapsed.
Click to expand it.
mace-bench/3rdparty/SevenNet/sevenn/checkpoint.py
0 → 100644
View file @
73866b01
This diff is collapsed.
Click to expand it.
mace-bench/3rdparty/SevenNet/sevenn/error_recorder.py
0 → 100644
View file @
73866b01
from
copy
import
deepcopy
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch.distributed
as
dist
import
sevenn._keys
as
KEY
from
sevenn.train.loss
import
LossDefinition
from
.atom_graph_data
import
AtomGraphData
from
.train.optim
import
loss_dict
# Registry of error-metric specifications, keyed by the names accepted in the
# `error_record` config.  Each spec feeds ErrorMetric.__init__ via **kwargs;
# `coeff` converts eV/A^3 stress into kbar / GPa respectively.
_ERROR_TYPES = {
    'TotalEnergy': {
        'name': 'Energy',
        'ref_key': KEY.ENERGY,
        'pred_key': KEY.PRED_TOTAL_ENERGY,
        'unit': 'eV',
        'vdim': 1,
    },
    'Energy': {  # by default per-atom for energy
        'name': 'Energy',
        'ref_key': KEY.ENERGY,
        'pred_key': KEY.PRED_TOTAL_ENERGY,
        'unit': 'eV/atom',
        'per_atom': True,
        'vdim': 1,
    },
    'Force': {
        'name': 'Force',
        'ref_key': KEY.FORCE,
        'pred_key': KEY.PRED_FORCE,
        'unit': 'eV/Å',
        'vdim': 3,
    },
    'Stress': {
        'name': 'Stress',
        'ref_key': KEY.STRESS,
        'pred_key': KEY.PRED_STRESS,
        'unit': 'kbar',
        'coeff': 1602.1766208,
        'vdim': 6,
    },
    'Stress_GPa': {
        'name': 'Stress',
        'ref_key': KEY.STRESS,
        'pred_key': KEY.PRED_STRESS,
        'unit': 'GPa',
        'coeff': 160.21766208,
        'vdim': 6,
    },
    'TotalLoss': {
        'name': 'TotalLoss',
        'unit': None,
    },
}
def get_err_type(name: str) -> Dict[str, Any]:
    """Look up the error-metric spec registered under ``name``.

    A deep copy is returned, so callers may freely mutate the result
    without corrupting the shared ``_ERROR_TYPES`` registry.
    """
    spec = _ERROR_TYPES[name]
    return deepcopy(spec)
def
_get_loss_function_from_name
(
loss_functions
,
name
):
for
loss_def
,
w
in
loss_functions
:
if
loss_def
.
name
.
lower
()
==
name
.
lower
():
return
loss_def
,
w
return
None
,
None
class AverageNumber:
    """Running-mean accumulator used while iterating over batches."""

    def __init__(self):
        # Running total and the number of elements folded in so far.
        self._sum = 0.0
        self._count = 0

    def update(self, values: torch.Tensor):
        """Fold every element of ``values`` into the running average."""
        self._count += values.numel()
        self._sum += values.sum().item()

    def _ddp_reduce(self, device):
        """Sum partial results across all DDP ranks (collective call)."""
        total = torch.tensor(self._sum, device=device)
        count = torch.tensor(self._count, device=device)
        dist.all_reduce(total, op=dist.ReduceOp.SUM)
        dist.all_reduce(count, op=dist.ReduceOp.SUM)
        self._sum = total.item()
        self._count = count.item()

    def get(self):
        """Return the mean so far; NaN when nothing has been recorded."""
        return torch.nan if self._count == 0 else self._sum / self._count
class ErrorMetric:
    """
    Base class for error metrics We always average error by # of structures,
    and designed to collect errors in the middle of iteration (by AverageNumber)
    """

    def __init__(
        self,
        name: str,
        ref_key: str,
        pred_key: str,
        coeff: float = 1.0,
        unit: Optional[str] = None,
        per_atom: bool = False,
        ignore_unlabeled: bool = True,
        **kwargs,
    ):
        # name/unit are for display; ref_key/pred_key index AtomGraphData;
        # coeff converts units (e.g. eV/A^3 -> kbar); extra kwargs from the
        # _ERROR_TYPES specs are accepted and ignored here.
        self.name = name
        self.unit = unit
        self.coeff = coeff
        self.ref_key = ref_key
        self.pred_key = pred_key
        self.per_atom = per_atom
        self.ignore_unlabeled = ignore_unlabeled
        self.value = AverageNumber()

    def update(self, output: AtomGraphData):
        """Accumulate this batch's error; implemented by subclasses."""
        raise NotImplementedError

    def _retrieve(self, output: AtomGraphData):
        """Fetch (reference, prediction) pairs, unit-scaled and filtered.

        Optionally divides by atom count (per-atom energies) and drops
        entries whose reference is NaN (unlabeled data).
        """
        y_ref = output[self.ref_key] * self.coeff
        y_pred = output[self.pred_key] * self.coeff
        if self.per_atom:
            assert y_ref.dim() == 1 and y_pred.dim() == 1
            natoms = output[KEY.NUM_ATOMS]
            y_ref = y_ref / natoms
            y_pred = y_pred / natoms
        if self.ignore_unlabeled:
            # NaN in the reference marks an unlabeled structure.
            unlabelled_idx = torch.isnan(y_ref)
            y_ref = y_ref[~unlabelled_idx]
            y_pred = y_pred[~unlabelled_idx]
        return y_ref, y_pred

    def ddp_reduce(self, device):
        """Merge partial statistics across DDP ranks."""
        self.value._ddp_reduce(device)

    def reset(self):
        """Drop accumulated statistics (start of a new epoch)."""
        self.value = AverageNumber()

    def get(self):
        """Return the accumulated error value."""
        return self.value.get()

    def key_str(self, with_unit=True):
        """Display name, optionally suffixed with '(unit)'."""
        if self.unit is None or not with_unit:
            return self.name
        else:
            return f'{self.name} ({self.unit})'

    def __str__(self):
        return f'{self.key_str()}: {self.value.get():.6f}'
class RMSError(ErrorMetric):
    """
    Vector squared error

    Squared errors are summed over the vector dimension (``vdim``) before
    averaging, so get() is the RMS of per-vector errors.
    """

    def __init__(self, vdim: int = 1, **kwargs):
        super().__init__(**kwargs)
        self.vdim = vdim  # vector length per sample (1 energy, 3 force, 6 stress)
        self._se = torch.nn.MSELoss(reduction='none')

    def _square_error(self, y_ref, y_pred, vdim: int):
        # Per-vector squared error: reshape to (-1, vdim), sum components.
        return self._se(y_ref.view(-1, vdim), y_pred.view(-1, vdim)).sum(dim=1)

    def update(self, output: AtomGraphData):
        """Accumulate squared errors of this batch."""
        y_ref, y_pred = self._retrieve(output)
        se = self._square_error(y_ref, y_pred, self.vdim)
        self.value.update(se)

    def get(self):
        """Return sqrt of the mean squared error collected so far."""
        return self.value.get() ** 0.5
class ComponentRMSError(ErrorMetric):
    """
    Ignore vector dim and just average over components
    Results smaller error
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._se = torch.nn.MSELoss(reduction='none')

    def _square_error(self, y_ref, y_pred):
        # Element-wise squared error, no reduction.
        return self._se(y_ref, y_pred)

    def update(self, output: AtomGraphData):
        """Accumulate per-component squared errors of this batch."""
        y_ref, y_pred = self._retrieve(output)
        # Flatten so every component counts once in the average.
        y_ref = y_ref.view(-1)
        y_pred = y_pred.view(-1)
        se = self._square_error(y_ref, y_pred)
        self.value.update(se)

    def get(self):
        """Return sqrt of the mean per-component squared error."""
        return self.value.get() ** 0.5
class MAError(ErrorMetric):
    """
    Average over all component

    Mean absolute error over every flattened component.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    # NOTE(review): despite its name this computes the ABSOLUTE error, not a
    # squared one — the name mirrors the sibling classes' interface.
    def _square_error(self, y_ref, y_pred):
        return torch.abs(y_ref - y_pred)

    def update(self, output: AtomGraphData):
        """Accumulate absolute errors of this batch."""
        y_ref, y_pred = self._retrieve(output)
        y_ref = y_ref.reshape((-1,))
        y_pred = y_pred.reshape((-1,))
        se = self._square_error(y_ref, y_pred)
        self.value.update(se)
class CustomError(ErrorMetric):
    """
    Custom error metric

    Args:
        func: a function that takes y_ref and y_pred
              and returns a list of errors
    """

    def __init__(self, func: Callable, **kwargs):
        super().__init__(**kwargs)
        self.func = func

    def update(self, output: AtomGraphData):
        """Accumulate func(ref, pred); empty batches contribute nothing."""
        y_ref, y_pred = self._retrieve(output)
        # Guard: func may not handle zero-length inputs (all unlabeled).
        se = self.func(y_ref, y_pred) if len(y_ref) > 0 else torch.tensor([])
        self.value.update(se)
class LossError(ErrorMetric):
    """
    Error metric that record loss

    Wraps a LossDefinition so its per-batch loss values are averaged with
    the same machinery as the other metrics.
    """

    def __init__(
        self,
        name: str,
        loss_def: LossDefinition,
        **kwargs,
    ):
        # Bug fix: the keyword was misspelled `ignore_unlabeld`, which was
        # silently swallowed by ErrorMetric's **kwargs — so the base class
        # always kept its default (True) regardless of the loss definition.
        super().__init__(
            name,
            ignore_unlabeled=loss_def.ignore_unlabeled,
            **kwargs,
        )
        self.loss_def = loss_def

    def update(self, output: AtomGraphData):
        """Accumulate the wrapped loss for this batch."""
        loss = self.loss_def.get_loss(output)  # type: ignore
        self.value.update(loss)  # type: ignore
class CombinedError(ErrorMetric):
    """
    Combine multiple error metrics with weights
    corresponds to a weighted sum of errors (normally used in loss)
    """

    def __init__(self, metrics: List[Tuple[ErrorMetric, float]], **kwargs):
        super().__init__(**kwargs)
        self.metrics = metrics
        # A combined value has no single physical unit; callers must pass
        # unit=None explicitly.
        assert kwargs['unit'] is None

    def update(self, output: AtomGraphData):
        """Forward the batch to every child metric."""
        for metric, _ in self.metrics:
            metric.update(output)

    def reset(self):
        """Reset every child metric."""
        for metric, _ in self.metrics:
            metric.reset()

    def ddp_reduce(self, device):
        # override: reduce the children; this wrapper holds no state itself.
        for metric, _ in self.metrics:
            metric.value._ddp_reduce(device)

    def get(self):
        """Return the weighted sum of the child metrics' values."""
        val = 0.0
        for metric, weight in self.metrics:
            val += metric.get() * weight
        return val
class ErrorRecorder:
    """
    record errors of a model

    Holds a list of ErrorMetric instances, feeds every model output batch
    to them, and snapshots their values per epoch into `history`.
    """

    # Maps `error_record` metric names to their implementing classes.
    METRIC_DICT = {
        'RMSE': RMSError,
        'ComponentRMSE': ComponentRMSError,
        'MAE': MAError,
        'Loss': LossError,
    }

    def __init__(self, metrics: List[ErrorMetric]):
        self.history = []  # one snapshot dict per finished epoch
        self.metrics = metrics
    def _update(self, output: AtomGraphData):
        """Feed one output batch to every metric (no grad handling here)."""
        for metric in self.metrics:
            metric.update(output)
    def update(self, output: AtomGraphData, no_grad=True):
        """Accumulate errors for one batch.

        With no_grad=True (default) the bookkeeping runs under
        torch.no_grad() so it never builds autograd graph.
        """
        if no_grad:
            with torch.no_grad():
                self._update(output)
        else:
            self._update(output)
    def get_metric_dict(self, with_unit=True):
        """Map each metric's display key (optionally with unit) to its value."""
        return {metric.key_str(with_unit): metric.get() for metric in self.metrics}
    def get_current(self):
        """Snapshot every metric as {name: {value, unit, ref_key, pred_key}}."""
        dct = {}
        for metric in self.metrics:
            dct[metric.name] = {
                'value': metric.get(),
                'unit': metric.unit,
                'ref_key': metric.ref_key,
                'pred_key': metric.pred_key,
            }
        return dct
    def get_dct(self, prefix=''):
        """Return {prefix_name: formatted value} for e.g. CSV logging.

        A non-empty prefix is joined to metric names with '_'.
        """
        dct = {}
        if prefix.endswith('_') is False and prefix != '':
            prefix = prefix + '_'
        for metric in self.metrics:
            # NOTE(review): format spec is ':6f' here but ':.6f' in
            # ErrorMetric.__str__; both print 6 decimals — confirm intended.
            dct[f'{prefix}{metric.name}'] = f'{metric.get():6f}'
        return dct
    def get_key_str(self, name: str):
        """Return the display key of the metric called ``name``, or None."""
        for metric in self.metrics:
            if name == metric.name:
                return metric.key_str()
        return None
    def epoch_forward(self):
        """Close out an epoch: archive values, reset metrics, return summary."""
        self.history.append(self.get_current())
        pretty = self.get_metric_dict(with_unit=True)
        for metric in self.metrics:
            metric.reset()
        return pretty  # for print
@
staticmethod
def
init_total_loss_metric
(
config
,
criteria
:
Optional
[
Callable
]
=
None
,
loss_functions
:
Optional
[
List
[
Tuple
[
LossDefinition
,
float
]]]
=
None
,
):
if
criteria
is
None
and
loss_functions
is
None
:
raise
ValueError
(
'both criteria and loss functions not given'
)
is_stress
=
config
[
KEY
.
IS_TRAIN_STRESS
]
metrics
=
[]
if
criteria
is
not
None
:
energy_metric
=
CustomError
(
criteria
,
**
get_err_type
(
'Energy'
))
metrics
.
append
((
energy_metric
,
1
))
force_metric
=
CustomError
(
criteria
,
**
get_err_type
(
'Force'
))
metrics
.
append
((
force_metric
,
config
[
KEY
.
FORCE_WEIGHT
]))
if
is_stress
:
stress_metric
=
CustomError
(
criteria
,
**
get_err_type
(
'Stress'
))
metrics
.
append
((
stress_metric
,
config
[
KEY
.
STRESS_WEIGHT
]))
else
:
# TODO: this is hard-coded
for
efs
in
[
'Energy'
,
'Force'
,
'Stress'
]:
if
efs
==
'Stress'
and
not
is_stress
:
continue
lf
,
w
=
_get_loss_function_from_name
(
loss_functions
,
efs
)
if
lf
is
None
:
raise
ValueError
(
f
'
{
efs
}
not found from loss_functions'
)
metric
=
LossError
(
loss_def
=
lf
,
**
get_err_type
(
efs
))
metrics
.
append
((
metric
,
w
))
total_loss_metric
=
CombinedError
(
metrics
,
name
=
'TotalLoss'
,
unit
=
None
,
ref_key
=
None
,
pred_key
=
None
)
return
total_loss_metric
@
staticmethod
def
from_config
(
config
:
dict
,
loss_functions
=
None
):
loss_cls
=
loss_dict
[
config
.
get
(
KEY
.
LOSS
,
'mse'
).
lower
()]
loss_param
=
config
.
get
(
KEY
.
LOSS_PARAM
,
{})
criteria
=
loss_cls
(
**
loss_param
)
if
loss_functions
is
None
else
None
err_config
=
config
.
get
(
KEY
.
ERROR_RECORD
,
False
)
if
not
err_config
:
raise
ValueError
(
'No error_record config found. Consider util.get_error_recorder'
)
err_config_n
=
[]
if
not
config
.
get
(
KEY
.
IS_TRAIN_STRESS
,
True
):
for
err_type
,
metric_name
in
err_config
:
if
'Stress'
in
err_type
:
continue
err_config_n
.
append
((
err_type
,
metric_name
))
err_config
=
err_config_n
err_metrics
=
[]
for
err_type
,
metric_name
in
err_config
:
metric_kwargs
=
get_err_type
(
err_type
)
if
err_type
==
'TotalLoss'
:
# special case
err_metrics
.
append
(
ErrorRecorder
.
init_total_loss_metric
(
config
,
criteria
,
loss_functions
)
)
continue
metric_cls
=
ErrorRecorder
.
METRIC_DICT
[
metric_name
]
assert
isinstance
(
metric_kwargs
[
'name'
],
str
)
if
metric_name
==
'Loss'
:
if
loss_functions
is
not
None
:
metric_cls
=
LossError
metric_kwargs
[
'loss_def'
],
_
=
_get_loss_function_from_name
(
loss_functions
,
metric_kwargs
[
'name'
]
)
else
:
metric_cls
=
CustomError
metric_kwargs
[
'func'
]
=
criteria
metric_kwargs
.
pop
(
'unit'
,
None
)
metric_kwargs
[
'name'
]
+=
f
'_
{
metric_name
}
'
err_metrics
.
append
(
metric_cls
(
**
metric_kwargs
))
return
ErrorRecorder
(
err_metrics
)
mace-bench/3rdparty/SevenNet/sevenn/logger.py
0 → 100644
View file @
73866b01
import
os
import
time
import
traceback
from
datetime
import
datetime
from
typing
import
Any
,
Dict
,
List
,
Optional
from
ase.data
import
atomic_numbers
import
sevenn._keys
as
KEY
from
sevenn
import
__version__
# inverse of ase.data.atomic_numbers: atomic number (int) -> chemical symbol (str)
CHEM_SYMBOLS = {v: k for k, v in atomic_numbers.items()}
class Singleton(type):
    """Metaclass that makes each class using it a process-wide singleton.

    The first instantiation constructs and caches the instance; every later
    call returns the cached one without re-running ``__init__``.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        # EAFP: construct only on a cache miss
        try:
            return cls._instances[cls]
        except KeyError:
            instance = super().__call__(*args, **kwargs)
            cls._instances[cls] = instance
            return instance
class Logger(metaclass=Singleton):
    """Rank-aware singleton logger writing to a file and/or the screen.

    Only rank 0 actually writes; all other ranks turn every method into a
    no-op. File handles are managed through the context-manager protocol.
    """

    SCREEN_WIDTH = 120  # half size of my screen / changed due to stress output

    def __init__(
        self,
        filename: Optional[str] = None,
        screen: bool = False,
        rank: int = 0,
    ):
        self.rank = rank
        self._filename = filename  # opened lazily in __enter__
        if rank == 0:
            # if filename is not None:
            #     self.logfile = open(filename, 'a', buffering=1)
            self.logfile = None
            self.files = {}  # extra csv files, keyed by filename
            self.screen = screen
        else:
            self.logfile = None
            self.screen = False
        self.timer_dct = {}  # name -> datetime start, see timer_start/timer_end
        self.active = True  # master switch consulted by write()

    def __enter__(self):
        """Open the log file (append, line-buffered) on rank 0."""
        if self.rank != 0:
            return self
        if self.logfile is None and self._filename is not None:
            try:
                self.logfile = open(
                    self._filename, 'a', buffering=1, encoding='utf-8'
                )
            except IOError as e:
                # fall back to screen-only logging on open failure
                print(f'Failed to re-open log file {self._filename}: {e}')
                self.logfile = None
        self.files = {}
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the log file and any csv files; always drop the handles."""
        if self.rank != 0:
            return self
        try:
            if self.logfile is not None:
                self.logfile.close()
                self.logfile = None
            for f in self.files.values():
                f.close()
        except IOError as e:
            print(f'Failed to close log files: {e}')
        finally:
            self.logfile = None
            self.files = {}

    def switch_file(self, new_filename: str):
        """Change the target filename; the current file must be closed first."""
        if self.rank != 0:
            return self
        if self.logfile is not None:
            raise ValueError('Current logfile is not yet closed')
        self._filename = new_filename
        return self

    def write(self, content: str) -> None:
        """Write `content` verbatim to file and/or screen (rank 0 only)."""
        if self.rank != 0:
            return
        # no newline!
        if self.logfile is not None and self.active:
            self.logfile.write(content)
        if self.screen and self.active:
            print(content, end='')

    def writeline(self, content: str) -> None:
        """write() with a trailing newline appended."""
        content = content + '\n'
        self.write(content)

    def init_csv(self, filename: str, header: list) -> None:
        """
        Deprecated
        """
        if self.rank == 0:
            self.files[filename] = open(
                filename, 'w', buffering=1, encoding='utf-8'
            )
            self.files[filename].write(','.join(header) + '\n')
        else:
            pass

    def append_csv(self, filename: str, content: list, decimal: int = 6) -> None:
        """
        Deprecated
        """
        if self.rank == 0:
            if filename not in self.files:
                self.files[filename] = open(filename, 'a', buffering=1)
            str_content = []
            for c in content:
                if isinstance(c, float):
                    str_content.append(f'{c:.{decimal}f}')
                else:
                    str_content.append(str(c))
            self.files[filename].write(','.join(str_content) + '\n')
        else:
            pass

    def natoms_write(self, natoms: Dict[str, Dict]) -> None:
        """Log per-label atom counts plus species-wise and grand totals."""
        content = ''
        total_natom = {}
        for label, natom in natoms.items():
            content += self.format_k_v(label, natom)
            for specie, num in natom.items():
                try:
                    total_natom[specie] += num
                except KeyError:
                    total_natom[specie] = num
        content += self.format_k_v('Total, label wise', total_natom)
        content += self.format_k_v('Total', sum(total_natom.values()))
        self.write(content)

    def statistic_write(self, statistic: Dict[str, Dict]) -> None:
        """Log nested statistics dicts; keys starting with '_' are skipped,
        float values are formatted to three decimals."""
        content = ''
        for label, dct in statistic.items():
            if label.startswith('_'):
                continue
            if not isinstance(dct, dict):
                continue
            dct_new = {}
            for k, v in dct.items():
                if k.startswith('_'):
                    continue
                if isinstance(v, int):
                    dct_new[k] = v
                else:
                    dct_new[k] = f'{v:.3f}'
            content += self.format_k_v(label, dct_new)
        self.write(content)

    # TODO : refactoring!!!, this is not loss, rmse
    def epoch_write_specie_wise_loss(self, train_loss, valid_loss) -> None:
        """Log per-species train/valid force errors in fixed-width columns.

        NOTE(review): train_loss/valid_loss appear to map atomic number ->
        scalar error; energy/stress columns are filled with dashes.
        """
        lb_pad = 21
        fs = 6
        pad = 21 - fs
        ln = '-' * fs  # placeholder for columns with no data
        total_atom_type = train_loss.keys()
        content = ''
        for at in total_atom_type:
            t_F = train_loss[at]
            v_F = valid_loss[at]
            at_sym = CHEM_SYMBOLS[at]
            content += '{label:{lb_pad}}{t_E:<{pad}.{fs}s}{v_E:<{pad}.{fs}s}'.format(
                label=at_sym, t_E=ln, v_E=ln, lb_pad=lb_pad, pad=pad, fs=fs
            ) + '{t_F:<{pad}.{fs}f}{v_F:<{pad}.{fs}f}'.format(
                t_F=t_F, v_F=v_F, pad=pad, fs=fs
            )
            content += '{t_S:<{pad}.{fs}s}{v_S:<{pad}.{fs}s}'.format(
                t_S=ln, v_S=ln, pad=pad, fs=fs
            )
            content += '\n'
        self.write(content)

    def write_full_table(
        self,
        dict_list: List[Dict],
        row_labels: List[str],
        decimal_places: int = 6,
        pad: int = 2,
    ) -> None:
        """
        Assume data_list is list of dict with same keys
        """
        assert len(dict_list) == len(row_labels)
        label_len = max(map(len, row_labels))

        # Extract the column names and create a 2D array of values
        col_names = list(dict_list[0].keys())
        values = [list(d.values()) for d in dict_list]

        # Format the numbers with the given decimal places
        formatted_values = [
            [f'{value:.{decimal_places}f}' for value in row] for row in values
        ]

        # Calculate padding lengths for each column (with extra padding)
        max_col_lengths = [
            max(len(str(value)) for value in col) + pad
            for col in zip(col_names, *formatted_values)
        ]

        # Create header row and separator
        header = ' ' * (label_len + pad) + ' '.join(
            col_name.ljust(pad)
            for col_name, pad in zip(col_names, max_col_lengths)
        )
        separator = '-'.join('-' * pad for pad in max_col_lengths) + '-' * (
            label_len + pad
        )

        # Print header and separator
        self.writeline(header)
        self.writeline(separator)

        # Print the data rows with row labels
        for row_label, row in zip(row_labels, formatted_values):
            data_row = ' '.join(
                value.rjust(pad) for value, pad in zip(row, max_col_lengths)
            )
            self.writeline(f'{row_label.ljust(label_len)}{data_row}')

    def format_k_v(self, key: Any, val: Any, write: bool = False):
        """
        key and val should be str convertible

        Formats 'key : val' padded to MAX_KEY_SIZE; long values are wrapped
        at SCREEN_WIDTH with backslash continuations. Returns the string,
        or writes it and returns '' when write=True.
        """
        MAX_KEY_SIZE = 20
        SEPARATOR = ', '
        EMPTY_PADDING = ' ' * (MAX_KEY_SIZE + 3)
        NEW_LINE_LEN = Logger.SCREEN_WIDTH - 5
        key = str(key)
        val = str(val)
        content = f'{key:<{MAX_KEY_SIZE}}: {val}'
        if len(content) > NEW_LINE_LEN:
            content = f'{key:<{MAX_KEY_SIZE}}: '
            # septate val by separator
            val_list = val.split(SEPARATOR)
            current_len = len(content)
            for val_compo in val_list:
                current_len += len(val_compo)
                if current_len > NEW_LINE_LEN:
                    newline_content = f'{EMPTY_PADDING}{val_compo}{SEPARATOR}'
                    content += f'\\\n{newline_content}'
                    current_len = len(newline_content)
                else:
                    content += f'{val_compo}{SEPARATOR}'
            if content.endswith(f'{SEPARATOR}'):
                content = content[: -len(SEPARATOR)]
        content += '\n'
        if write is False:
            return content
        else:
            self.write(content)
            return ''

    def greeting(self) -> None:
        """Log the SevenNet banner, version, timestamp, and ASCII logo."""
        LOGO_ASCII_FILE = f'{os.path.dirname(__file__)}/logo_ascii'
        with open(LOGO_ASCII_FILE, 'r') as logo_f:
            logo_ascii = logo_f.read()
        content = 'SevenNet: Scalable EquiVariance-Enabled Neural Network\n'
        content += f'version {__version__}, {time.ctime()}\n'
        self.write(content)
        self.write(logo_ascii)

    def bar(self) -> None:
        """Log a horizontal rule of SCREEN_WIDTH dashes."""
        content = '-' * Logger.SCREEN_WIDTH + '\n'
        self.write(content)

    def print_config(
        self,
        model_config: Dict[str, Any],
        data_config: Dict[str, Any],
        train_config: Dict[str, Any],
    ) -> None:
        """
        print some important information from config
        """
        content = (
            'successfully read yaml config!\n\n' + 'from model configuration\n'
        )
        for k, v in model_config.items():
            content += self.format_k_v(k, str(v))
        content += '\nfrom train configuration\n'
        for k, v in train_config.items():
            content += self.format_k_v(k, str(v))
        content += '\nfrom data configuration\n'
        for k, v in data_config.items():
            content += self.format_k_v(k, str(v))
        self.write(content)

    # TODO: This is not good make own exception
    def error(self, e: Exception) -> None:
        """Log `e`: its message for ValueError, a full traceback otherwise."""
        content = ''
        if type(e) is ValueError:
            content += 'Error occurred!\n'
            content += str(e) + '\n'
        else:
            content += 'Unknown error occurred!\n'
            content += traceback.format_exc()
        self.write(content)

    def timer_start(self, name: str) -> None:
        """Record the current time under `name` for a later timer_end()."""
        self.timer_dct[name] = datetime.now()

    def timer_end(self, name: str, message: str, remove: bool = True) -> None:
        """
        print f"{message}: {elapsed}"
        """
        elapsed = str(datetime.now() - self.timer_dct[name])
        # elapsed = elapsed.strftime('%H-%M-%S')
        if remove:
            del self.timer_dct[name]
        # [:-4] trims microsecond digits from the timedelta string
        self.write(f'{message}: {elapsed[:-4]}\n')

    # TODO: print it without config
    # TODO: refactoring, readout part name :(
    def print_model_info(self, model, config) -> None:
        """Log feature irreps of each layer and the learnable parameter count."""
        from functools import partial

        kv_write = partial(self.format_k_v, write=True)
        self.writeline('Irreps of features')
        kv_write(
            'edge_feature', model.get_irreps_in('edge_embedding', 'irreps_out')
        )
        for i in range(config[KEY.NUM_CONVOLUTION]):
            kv_write(
                f'{i}th node',
                model.get_irreps_in(f'{i}_self_interaction_1'),
            )
        i = config[KEY.NUM_CONVOLUTION] - 1
        kv_write(
            'readout irreps',
            model.get_irreps_in(f'{i}_equivariant_gate', 'irreps_out'),
        )
        num_weights = sum(
            p.numel() for p in model.parameters() if p.requires_grad
        )
        self.writeline(f'# learnable parameters: {num_weights}\n')
mace-bench/3rdparty/SevenNet/sevenn/logo_ascii
0 → 100644
View file @
73866b01
****
******** .
*//////, .. . ,*.
,,***. .. , ********. ./,
. . .. /////. ., . *///////// /////////.
.&@&/ . .(((((((.. / *//////*. ... *((((((((((.
@@@@@@@@@@* @@@@@@@@@@ @@@@@ *((@@@@@ ( %@@@@@@@@@@ .@@@@@@ ..@@@@. @@@@@@* .(@@@@@(((*
@@@@@. @@@@ @@@@@ . @@@@@ # %@@@@ @@@@@@@@ @@@@(, @@@@@@@@. @@@@@(*.
%@@@@@@@& @@@@@@@@@@ @@@@@ @@@@@ # ., .%@@@@@@@@@ @@@@@@@@@@ @@@@, @@@@@@@@@@ @@@@@
,(%@@@@@@@@@ @@@@@@@@@@ @@@@@ @@@@& (//////%@@@@@@@@@ @@@@ @@@@@@ @@@@ . @@@@@ @@@@@.@@@@@
. @@@@@ @@@@ . . @@@@@@@@% . . ( .////,%@@@@ @@@@ @@@@@@@@@ @@@@@ @@@@@@@@@
(@@@@@@@@@@@ @@@@@@@@@@**. @@@@@@* *. .%@@@@@@@@@@ @@@@ . @@@@@@@ @@@@@ .@@@@@@@
@@@@@@@@@. @@@@@@@@@@///, @@@@. . / %@@@@@@@@@@ @@@@***, @@@@@ @@@@@ @@@@@
. //////////*. / . .*******... . ,.
.&&&&&... ,//////*. ...////. / ,*/. . ,////, .,/////
&@@@@@@ ,(/((, * ,((((((. .***.
,/@(, .. * ,((((*
,
.
mace-bench/3rdparty/SevenNet/sevenn/model_build.py
0 → 100644
View file @
73866b01
import
copy
import
warnings
from
collections
import
OrderedDict
from
typing
import
List
,
Literal
,
Union
,
overload
from
e3nn.o3
import
Irreps
import
sevenn._const
as
_const
import
sevenn._keys
as
KEY
import
sevenn.util
as
util
from
.nn.convolution
import
IrrepsConvolution
from
.nn.edge_embedding
import
(
BesselBasis
,
EdgeEmbedding
,
PolynomialCutoff
,
SphericalEncoding
,
XPLORCutoff
,
)
from
.nn.force_output
import
ForceStressOutputFromEdge
from
.nn.interaction_blocks
import
NequIP_interaction_block
from
.nn.linear
import
AtomReduce
,
FCN_e3nn
,
IrrepsLinear
from
.nn.node_embedding
import
OnehotEmbedding
from
.nn.scale
import
ModalWiseRescale
,
Rescale
,
SpeciesWiseRescale
from
.nn.self_connection
import
(
SelfConnectionIntro
,
SelfConnectionLinearIntro
,
SelfConnectionOutro
,
)
from
.nn.sequential
import
AtomGraphSequential
# warning from PyTorch, about e3nn type annotations
# (silence it module-wide: e3nn modules trigger it on every TorchScript pass)
warnings.filterwarnings(
    'ignore',
    message=(
        "The TorchScript type system doesn't "
        'support instance-level annotations'
    ),
)
def
_insert_after
(
module_name_after
,
key_module_pair
,
layers
):
idx
=
-
1
for
i
,
(
key
,
_
)
in
enumerate
(
layers
):
if
key
==
module_name_after
:
idx
=
i
break
if
idx
==
-
1
:
return
layers
# do nothing if not found
layers
.
insert
(
idx
+
1
,
key_module_pair
)
return
layers
def init_self_connection(config):
    """Return one (intro, outro) self-connection class pair per conv layer.

    A single string in the config is broadcast to all layers; 'none' yields
    None for that layer. Raises ValueError on an unknown type string.
    """
    self_connection_type_list = config[KEY.SELF_CONNECTION_TYPE]
    num_conv = config[KEY.NUM_CONVOLUTION]
    if isinstance(self_connection_type_list, str):
        self_connection_type_list = [self_connection_type_list] * num_conv
    io_pair_list = []
    for sc_type in self_connection_type_list:
        if sc_type == 'none':
            io_pair = None
        elif sc_type == 'nequip':
            io_pair = SelfConnectionIntro, SelfConnectionOutro
        elif sc_type == 'linear':
            io_pair = SelfConnectionLinearIntro, SelfConnectionOutro
        else:
            raise ValueError(f'Unknown self_connection_type found: {sc_type}')
        io_pair_list.append(io_pair)
    return io_pair_list
def init_edge_embedding(config):
    """Assemble the EdgeEmbedding: radial basis + cutoff envelope + spherical
    harmonics, all parameterized from `config`.

    NOTE(review): rbf/env stay None for unrecognized basis/envelope names and
    are passed through to EdgeEmbedding as-is — confirm EdgeEmbedding accepts
    None modules.
    """
    _cutoff_param = {'cutoff_length': config[KEY.CUTOFF]}
    rbf, env, sph = None, None, None

    # radial basis (deep-copied so the pop doesn't mutate the config)
    rbf_dct = copy.deepcopy(config[KEY.RADIAL_BASIS])
    rbf_dct.update(_cutoff_param)
    rbf_name = rbf_dct.pop(KEY.RADIAL_BASIS_NAME)
    if rbf_name == 'bessel':
        rbf = BesselBasis(**rbf_dct)

    # cutoff envelope function
    envelop_dct = copy.deepcopy(config[KEY.CUTOFF_FUNCTION])
    envelop_dct.update(_cutoff_param)
    envelop_name = envelop_dct.pop(KEY.CUTOFF_FUNCTION_NAME)
    if envelop_name == 'poly_cut':
        env = PolynomialCutoff(**envelop_dct)
    elif envelop_name == 'XPLOR':
        env = XPLORCutoff(**envelop_dct)

    # spherical encoding: LMAX_EDGE overrides LMAX when positive
    lmax_edge = config[KEY.LMAX]
    if config[KEY.LMAX_EDGE] > 0:
        lmax_edge = config[KEY.LMAX_EDGE]
    parity = -1 if config[KEY.IS_PARITY] else 1
    _normalize_sph = config[KEY._NORMALIZE_SPH]
    sph = SphericalEncoding(lmax_edge, parity, normalize=_normalize_sph)

    return EdgeEmbedding(
        basis_module=rbf, cutoff_module=env, spherical_module=sph
    )
def init_feature_reduce(config, irreps_x):
    # features per node to scalar per node
    """Build the readout layers mapping node features to a scaled per-atom
    energy: either two IrrepsLinear layers (default) or one FCN_e3nn."""
    layers = OrderedDict()
    if config[KEY.READOUT_AS_FCN] is False:
        # hidden width is half the input dimension, scalars (0e) only
        hidden_irreps = Irreps([(irreps_x.dim // 2, (0, 1))])
        layers.update(
            {
                'reduce_input_to_hidden': IrrepsLinear(
                    irreps_x,
                    hidden_irreps,
                    data_key_in=KEY.NODE_FEATURE,
                    biases=config[KEY.USE_BIAS_IN_LINEAR],
                ),
                'reduce_hidden_to_energy': IrrepsLinear(
                    hidden_irreps,
                    Irreps([(1, (0, 1))]),  # single scalar output
                    data_key_in=KEY.NODE_FEATURE,
                    data_key_out=KEY.SCALED_ATOMIC_ENERGY,
                    biases=config[KEY.USE_BIAS_IN_LINEAR],
                ),
            }
        )
    else:
        act = _const.ACTIVATION[config[KEY.READOUT_FCN_ACTIVATION]]
        hidden_neurons = config[KEY.READOUT_FCN_HIDDEN_NEURONS]
        layers.update(
            {
                'readout_FCN': FCN_e3nn(
                    dim_out=1,
                    hidden_neurons=hidden_neurons,
                    activation=act,
                    data_key_in=KEY.NODE_FEATURE,
                    data_key_out=KEY.SCALED_ATOMIC_ENERGY,
                    irreps_in=irreps_x,
                )
            }
        )
    return layers
def init_shift_scale(config):
    # for mm, ex, shift: modal_idx -> shifts
    """Build the energy rescale module from config shift/scale values.

    Picks ModalWiseRescale when modality is enabled, plain Rescale for
    scalar shift/scale, SpeciesWiseRescale for per-species lists.
    """
    shift_scale = []
    train_shift_scale = config[KEY.TRAIN_SHIFT_SCALE]
    type_map = config[KEY.TYPE_MAP]
    # in case of modal, shift or scale has more dims [][]
    # correct typing (I really want static python)
    for s in (config[KEY.SHIFT], config[KEY.SCALE]):
        if hasattr(s, 'tolist'):
            # numpy or torch
            s = s.tolist()
        if isinstance(s, dict):
            s = {
                k: v.tolist() if hasattr(v, 'tolist') else v
                for k, v in s.items()
            }
        if isinstance(s, list) and len(s) == 1:
            # single-element list collapses to its scalar
            s = s[0]
        shift_scale.append(s)
    shift, scale = shift_scale

    rescale_module = None
    if config.get(KEY.USE_MODALITY, False):
        rescale_module = ModalWiseRescale.from_mappers(  # type: ignore
            shift,
            scale,
            config[KEY.USE_MODAL_WISE_SHIFT],
            config[KEY.USE_MODAL_WISE_SCALE],
            type_map=type_map,
            modal_map=config[KEY.MODAL_MAP],
            train_shift_scale=train_shift_scale,
        )
    elif all([isinstance(s, float) for s in shift_scale]):
        rescale_module = Rescale(
            shift, scale, train_shift_scale=train_shift_scale
        )
    elif any([isinstance(s, list) for s in shift_scale]):
        rescale_module = SpeciesWiseRescale.from_mappers(  # type: ignore
            shift, scale, type_map=type_map, train_shift_scale=train_shift_scale
        )
    else:
        raise ValueError('shift, scale should be list of float or float')
    return rescale_module
def patch_modality(layers: OrderedDict, config):
    """
    Postprocess 7net-model to multimodal model.
    1. prepend modality one-hot embedding layer
    2. patch modalities of IrrepsLinear layers
    Modality aware shift scale is handled by init_shift_scale, not here
    """
    cfg = config
    if not cfg.get(KEY.USE_MODALITY, False):
        # not a multimodal model: nothing to patch
        return layers

    _layers = list(layers.items())
    _layers = _insert_after(
        'onehot_idx_to_onehot',
        (
            'one_hot_modality',
            OnehotEmbedding(
                num_classes=config[KEY.NUM_MODALITIES],
                data_key_x=KEY.MODAL_TYPE,
                data_key_out=KEY.MODAL_ATTR,
                data_key_save=None,
                data_key_additional=None,
            ),
        ),
        _layers,
    )
    layers = OrderedDict(_layers)

    # make the selected IrrepsLinear layers modality-aware
    num_modal = config[KEY.NUM_MODALITIES]
    for k, module in layers.items():
        if not isinstance(module, IrrepsLinear):
            continue
        if (
            (
                cfg[KEY.USE_MODAL_NODE_EMBEDDING]
                and k.endswith('onehot_to_feature_x')
            )
            or (
                cfg[KEY.USE_MODAL_SELF_INTER_INTRO]
                and k.endswith('self_interaction_1')
            )
            or (
                cfg[KEY.USE_MODAL_SELF_INTER_OUTRO]
                and k.endswith('self_interaction_2')
            )
            or (
                cfg[KEY.USE_MODAL_OUTPUT_BLOCK]
                and k == 'reduce_input_to_hidden'
            )
        ):
            module.set_num_modalities(num_modal)
    return layers
def patch_cue(layers: OrderedDict, config):
    """Replace e3nn-based layers with cuEquivariance-accelerated equivalents.

    Returns `layers` unchanged when cue is not requested, not installed,
    or not applicable to this model on CUDA.
    """
    import sevenn.nn.cue_helper as cue_helper

    cue_cfg = copy.deepcopy(config.get(KEY.CUEQUIVARIANCE_CONFIG, {}))
    if not cue_cfg.pop('use', False):
        return layers
    if not cue_helper.is_cue_available():
        warnings.warn(
            (
                'cuEquivariance is requested, but the package is not installed. '
                + 'Fallback to original code.'
            )
        )
        return layers
    if not cue_helper.is_cue_cuda_available_model(config):
        return layers

    group = 'O3' if config[KEY.IS_PARITY] else 'SO3'
    cueq_module_params = dict(layout='mul_ir')
    cueq_module_params.update(cue_cfg)  # user config overrides defaults

    updates = {}
    for k, module in layers.items():
        if isinstance(module, (IrrepsLinear, SelfConnectionLinearIntro)):
            if k == 'reduce_hidden_to_energy':
                # TODO: has bug with 0 shape
                continue
            module_patched = cue_helper.patch_linear(
                module, group, **cueq_module_params
            )
            updates[k] = module_patched
        elif isinstance(module, SelfConnectionIntro):
            module_patched = cue_helper.patch_fully_connected(
                module, group, **cueq_module_params
            )
            updates[k] = module_patched
        elif isinstance(module, IrrepsConvolution):
            module_patched = cue_helper.patch_convolution(
                module, group, **cueq_module_params
            )
            updates[k] = module_patched
    layers.update(updates)
    return layers
def patch_modules(layers: OrderedDict, config):
    """Apply post-build patches in order: multimodality, then cuEquivariance."""
    return patch_cue(patch_modality(layers, config), config)
def _to_parallel_model(layers: OrderedDict, config):
    """Split a serial model's layers into per-segment OrderedDicts for the
    LAMMPS-parallel model, inserting ghost-atom embedding layers.

    Segments break right after each `{i}_self_interaction_1` (the point
    before inter-rank communication); each later segment gets its own
    edge_embedding so position gradients are retained.
    """
    num_classes = layers['onehot_idx_to_onehot'].num_classes
    one_hot_irreps = Irreps(f'{num_classes}x0e')
    irreps_node_zero = layers['onehot_to_feature_x'].irreps_out
    _layers = list(layers.items())
    layers_list = []
    num_convolution_layer = config[KEY.NUM_CONVOLUTION]

    def slice_until_this(module_name, layers):
        # split the (name, module) list just after `module_name`
        idx = -1
        for i, (key, _) in enumerate(layers):
            if key == module_name:
                idx = i
                break
        first_to = layers[: idx + 1]
        remain = layers[idx + 1 :]
        return first_to, remain

    _layers = _insert_after(
        'onehot_to_feature_x',
        (
            'one_hot_ghost',
            OnehotEmbedding(
                data_key_x=KEY.NODE_FEATURE_GHOST,
                num_classes=num_classes,
                data_key_save=None,
                data_key_additional=None,
            ),
        ),
        _layers,
    )
    _layers = _insert_after(
        'one_hot_ghost',
        (
            'ghost_onehot_to_feature_x',
            IrrepsLinear(
                irreps_in=one_hot_irreps,
                irreps_out=irreps_node_zero,
                data_key_in=KEY.NODE_FEATURE_GHOST,
                biases=config[KEY.USE_BIAS_IN_LINEAR],
            ),
        ),
        _layers,
    )
    _layers = _insert_after(
        '0_self_interaction_1',
        (
            'ghost_0_self_interaction_1',
            IrrepsLinear(
                irreps_node_zero,
                irreps_node_zero,
                data_key_in=KEY.NODE_FEATURE_GHOST,
                biases=config[KEY.USE_BIAS_IN_LINEAR],
            ),
        ),
        _layers,
    )

    # assign modules (before first communications)
    # initialize edge related to retain position gradients
    for i in range(1, num_convolution_layer):
        sliced, _layers = slice_until_this(
            f'{i}_self_interaction_1', _layers
        )
        layers_list.append(OrderedDict(sliced))
        _layers.insert(0, ('edge_embedding', init_edge_embedding(config)))
    layers_list.append(OrderedDict(_layers))
    del layers_list[-1]['force_output']  # done in LAMMPS
    return layers_list
# typing overloads: parallel=False -> one model; parallel=True -> a list of
# segment models (for LAMMPS parallel execution)
@overload
def build_E3_equivariant_model(
    config: dict, parallel: Literal[False] = False
) -> AtomGraphSequential:  # noqa
    ...


@overload
def build_E3_equivariant_model(
    config: dict, parallel: Literal[True]
) -> List[AtomGraphSequential]:  # noqa
    ...
def build_E3_equivariant_model(
    config: dict, parallel: bool = False
) -> Union[AtomGraphSequential, List[AtomGraphSequential]]:
    """
    output shapes (w/o batch)
    PRED_TOTAL_ENERGY: (),
    ATOMIC_ENERGY: (natoms, 1), # intended
    PRED_FORCE: (natoms, 3),
    PRED_STRESS: (6,),
    for data w/o cell volume, pred_stress has garbage values
    """
    layers = OrderedDict()
    cutoff = config[KEY.CUTOFF]
    num_species = config[KEY.NUM_SPECIES]
    feature_multiplicity = config[KEY.NODE_FEATURE_MULTIPLICITY]
    num_convolution_layer = config[KEY.NUM_CONVOLUTION]
    interaction_type = config[KEY.INTERACTION_TYPE]
    use_bias_in_linear = config[KEY.USE_BIAS_IN_LINEAR]

    lmax_node = config[KEY.LMAX]  # ignore second (lmax_edge)
    # if config[KEY.LMAX_EDGE] > 0:  # not yet used
    #     _ = config[KEY.LMAX_EDGE]
    if config[KEY.LMAX_NODE] > 0:
        lmax_node = config[KEY.LMAX_NODE]

    act_radial = _const.ACTIVATION[config[KEY.ACTIVATION_RADIAL]]
    self_connection_pair_list = init_self_connection(config)

    # optional explicit per-layer irreps override
    irreps_manual = None
    if config[KEY.IRREPS_MANUAL] is not False:
        irreps_manual = config[KEY.IRREPS_MANUAL]
        try:
            irreps_manual = [Irreps(irr) for irr in irreps_manual]
            assert len(irreps_manual) == num_convolution_layer + 1
        except Exception:
            raise RuntimeError('invalid irreps_manual input given')

    conv_denominator = config[KEY.CONV_DENOMINATOR]
    if not isinstance(conv_denominator, list):
        # scalar denominator is broadcast to every conv layer
        conv_denominator = [conv_denominator] * num_convolution_layer
    train_conv_denominator = config[KEY.TRAIN_DENOMINTAOR]

    edge_embedding = init_edge_embedding(config)
    irreps_filter = edge_embedding.spherical.irreps_out
    radial_basis_num = edge_embedding.basis_function.num_basis
    layers.update({'edge_embedding': edge_embedding})

    # node embedding: species one-hot -> initial node features
    one_hot_irreps = Irreps(f'{num_species}x0e')
    irreps_x = (
        Irreps(f'{feature_multiplicity}x0e')
        if irreps_manual is None
        else irreps_manual[0]
    )
    layers.update(
        {
            'onehot_idx_to_onehot': OnehotEmbedding(
                num_classes=num_species,
                data_key_x=KEY.NODE_FEATURE,
                data_key_out=KEY.NODE_FEATURE,
                data_key_save=KEY.ATOM_TYPE,  # atomic numbers
                data_key_additional=KEY.NODE_ATTR,  # one-hot embeddings
            ),
            'onehot_to_feature_x': IrrepsLinear(
                irreps_in=one_hot_irreps,
                irreps_out=irreps_x,
                data_key_in=KEY.NODE_FEATURE,
                biases=use_bias_in_linear,
            ),
        }
    )

    weight_nn_hidden = config[KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS]
    weight_nn_layers = [radial_basis_num] + weight_nn_hidden

    # args shared by every interaction block; per-layer keys updated below
    param_interaction_block = {
        'irreps_filter': irreps_filter,
        'weight_nn_layers': weight_nn_layers,
        'train_conv_denominator': train_conv_denominator,
        'act_radial': act_radial,
        'bias_in_linear': use_bias_in_linear,
        'num_species': num_species,
        'parallel': parallel,
    }

    interaction_builder = None
    if interaction_type in ['nequip']:
        act_scalar = {}
        act_gate = {}
        for k, v in config[KEY.ACTIVATION_SCARLAR].items():
            act_scalar[k] = _const.ACTIVATION_DICT[k][v]
        for k, v in config[KEY.ACTIVATION_GATE].items():
            act_gate[k] = _const.ACTIVATION_DICT[k][v]
        param_interaction_block.update(
            {
                'act_scalar': act_scalar,
                'act_gate': act_gate,
            }
        )
        if interaction_type == 'nequip':
            interaction_builder = NequIP_interaction_block
    else:
        raise ValueError(f'Unknown interaction type: {interaction_type}')

    for t in range(num_convolution_layer):
        param_interaction_block.update(
            {
                'irreps_x': irreps_x,
                't': t,
                'conv_denominator': conv_denominator[t],
                'self_connection_pair': self_connection_pair_list[t],
            }
        )
        if interaction_type == 'nequip':
            parity_mode = 'full'
            fix_multiplicity = False
            if t == num_convolution_layer - 1:
                # last layer outputs even scalars only
                lmax_node = 0
                parity_mode = 'even'
            # TODO: irreps_manual is applicable to both irreps_out_tp and irreps_out
            irreps_out = (
                util.infer_irreps_out(
                    irreps_x,  # type: ignore
                    irreps_filter,
                    lmax_node,  # type: ignore
                    parity_mode,
                    fix_multiplicity=feature_multiplicity,
                )
                if irreps_manual is None
                else irreps_manual[t + 1]
            )
            irreps_out_tp = util.infer_irreps_out(
                irreps_x,  # type: ignore
                irreps_filter,
                irreps_out.lmax,  # type: ignore
                parity_mode,
                fix_multiplicity,
            )
        else:
            raise ValueError(f'Unknown interaction type: {interaction_type}')
        param_interaction_block.update(
            {
                'irreps_out_tp': irreps_out_tp,
                'irreps_out': irreps_out,
            }
        )
        layers.update(interaction_builder(**param_interaction_block))
        irreps_x = irreps_out  # output of this layer feeds the next

    layers.update(init_feature_reduce(config, irreps_x))
    layers.update(
        {
            'rescale_atomic_energy': init_shift_scale(config),
            'reduce_total_enegy': AtomReduce(
                data_key_in=KEY.ATOMIC_ENERGY,
                data_key_out=KEY.PRED_TOTAL_ENERGY,
            ),
        }
    )
    gradient_module = ForceStressOutputFromEdge()
    grad_key = gradient_module.get_grad_key()
    layers.update({'force_output': gradient_module})

    common_args = {
        'cutoff': cutoff,
        'type_map': config[KEY.TYPE_MAP],
        'modal_map': config.get(KEY.MODAL_MAP, None),
        'eval_type_map': False if parallel else True,
        'eval_modal_map': False
        if not config.get(KEY.USE_MODALITY, False) or parallel
        else True,
        'data_key_grad': grad_key,
    }

    if parallel:
        layers_list = _to_parallel_model(layers, config)
        return [
            AtomGraphSequential(patch_modules(layers, config), **common_args)
            for layers in layers_list
        ]
    else:
        return AtomGraphSequential(patch_modules(layers, config), **common_args)
mace-bench/3rdparty/SevenNet/sevenn/pair_d3.so
0 → 100644
View file @
73866b01
File added
mace-bench/3rdparty/SevenNet/sevenn/pair_d3_for_ase.cuda.o
0 → 100644
View file @
73866b01
File added
mace-bench/3rdparty/SevenNet/sevenn/parse_input.py
0 → 100644
View file @
73866b01
import
glob
import
os
import
warnings
from
typing
import
Any
,
Callable
,
Dict
import
torch
import
yaml
import
sevenn._const
as
_const
import
sevenn._keys
as
KEY
import
sevenn.util
as
util
def config_initialize(
    key: str,
    config: Dict,
    default: Any,
    conditions: Dict,
):
    """Resolve one config entry.

    Returns the default when the user gave no input; otherwise validates the
    user input against the registered condition (a dict of sub-conditions, a
    type to isinstance-check/cast with, or a predicate callable). Raises
    ValueError when validation fails.
    """
    # default value exist & no user input -> return default
    if key not in config.keys():
        return default

    user_input = config[key]
    # No validation method exist => accept user input
    if key not in conditions:
        return user_input
    condition = conditions[key]

    if type(default) is dict and isinstance(condition, dict):
        # nested dict: resolve each defaulted sub-key recursively
        for sub_key, sub_default in default.items():
            user_input[sub_key] = config_initialize(
                sub_key, user_input, sub_default, condition
            )
        return user_input

    if isinstance(condition, type):
        if isinstance(user_input, condition):
            return user_input
        try:
            return condition(user_input)  # try type casting
        except ValueError:
            raise ValueError(
                f"Expect '{user_input}' for '{key}' is {condition}"
            )

    if isinstance(condition, Callable) and condition(user_input):
        return user_input
    raise ValueError(f"Given input '{user_input}' for '{key}' is not valid")
def init_model_config(config: Dict):
    """Normalize the user's model section of the yaml config.

    Resolves chemical_species ('auto', a 'univ*' preset, a list, or a
    comma/dash/space separated string), strips deprecated keys with
    warnings, fills remaining keys from defaults, and warns about
    unrecognized keys. Returns the normalized model config dict.
    """
    # defaults = _const.model_defaults(config)
    model_meta = {}

    # init complicated ones
    if KEY.CHEMICAL_SPECIES not in config.keys():
        raise ValueError('required key chemical_species not exist')
    input_chem = config[KEY.CHEMICAL_SPECIES]
    if isinstance(input_chem, str) and input_chem.lower() == 'auto':
        # species resolved later from the dataset
        model_meta[KEY.CHEMICAL_SPECIES] = 'auto'
        model_meta[KEY.NUM_SPECIES] = 'auto'
        model_meta[KEY.TYPE_MAP] = 'auto'
    elif isinstance(input_chem, str) and 'univ' in input_chem.lower():
        # universal model: whole periodic table
        model_meta.update(
            util.chemical_species_preprocess([], universal=True)
        )
    else:
        if isinstance(input_chem, list) and all(
            isinstance(x, str) for x in input_chem
        ):
            pass
        elif isinstance(input_chem, str):
            # accept 'H-C-O', 'H C O', 'H,C,O' and mixtures thereof
            input_chem = (
                input_chem.replace('-', ',').replace(' ', ',').split(',')
            )
            input_chem = [chem for chem in input_chem if len(chem) != 0]
        else:
            raise ValueError(
                f'given {KEY.CHEMICAL_SPECIES} input is strange'
            )
        model_meta.update(util.chemical_species_preprocess(input_chem))

    # deprecation warnings
    if KEY.AVG_NUM_NEIGH in config:
        warnings.warn(
            "key 'avg_num_neigh' is deprecated. Please use 'conv_denominator'."
            ' We use the default, the average number of neighbors in the'
            ' dataset, if not provided.',
            UserWarning,
        )
        config.pop(KEY.AVG_NUM_NEIGH)
    if KEY.TRAIN_AVG_NUM_NEIGH in config:
        warnings.warn(
            "key 'train_avg_num_neigh' is deprecated. Please use"
            " 'train_denominator'. We overwrite train_denominator as given"
            ' train_avg_num_neigh',
            UserWarning,
        )
        config[KEY.TRAIN_DENOMINTAOR] = config[KEY.TRAIN_AVG_NUM_NEIGH]
        config.pop(KEY.TRAIN_AVG_NUM_NEIGH)
    if KEY.OPTIMIZE_BY_REDUCE in config:
        warnings.warn(
            "key 'optimize_by_reduce' is deprecated. Always true",
            UserWarning,
        )
        config.pop(KEY.OPTIMIZE_BY_REDUCE)

    # init simpler ones
    for key, default in _const.DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG.items():
        model_meta[key] = config_initialize(
            key, config, default, _const.MODEL_CONFIG_CONDITION
        )

    unknown_keys = [
        key for key in config.keys() if key not in model_meta.keys()
    ]
    if len(unknown_keys) != 0:
        warnings.warn(
            f'Unexpected model keys: {unknown_keys} will be ignored',
            UserWarning,
        )
    return model_meta
def init_train_config(config: Dict):
    """Build the training part of the configuration from the raw yaml dict.

    Resolves the compute device, fills defaults, resolves the continue
    checkpoint path, and warns about unknown keys.
    """
    train_meta = {}

    # Device: honor the user's choice; otherwise prefer cuda when available.
    if KEY.DEVICE in config:
        device = torch.device(config[KEY.DEVICE])
    else:
        device = (
            torch.device('cuda')
            if torch.cuda.is_available()
            else torch.device('cpu')
        )
    # Stored as a string so the config stays serializable.
    train_meta[KEY.DEVICE] = str(device)

    # Fill defaults and validate each remaining training key.
    for key, default in _const.DEFAULT_TRAINING_CONFIG.items():
        train_meta[key] = config_initialize(
            key, config, default, _const.TRAINING_CONFIG_CONDITION
        )

    if KEY.CONTINUE in config.keys():
        continue_dct = config[KEY.CONTINUE]
        if KEY.CHECKPOINT not in continue_dct.keys():
            raise ValueError('no checkpoint is given in continue')
        checkpoint = continue_dct[KEY.CHECKPOINT]
        # Accept either a literal file path or a pretrained model name.
        if os.path.isfile(checkpoint):
            checkpoint_file = checkpoint
        else:
            checkpoint_file = util.pretrained_name_to_path(checkpoint)
        train_meta[KEY.CONTINUE].update({KEY.CHECKPOINT: checkpoint_file})

    unknown_keys = [k for k in config.keys() if k not in train_meta.keys()]
    if unknown_keys:
        warnings.warn(
            f'Unexpected train keys: {unknown_keys} will be ignored',
            UserWarning,
        )
    return train_meta
def init_data_config(config: Dict):
    """Build the data part of the configuration from the raw yaml dict.

    Every ``load_*_path`` entry is expanded via glob (dict entries inside a
    list are kept as-is), defaults are filled for the remaining keys, and
    unknown keys produce a warning.

    Raises:
        ValueError: when a load key holds neither a str nor a list, or when
            globbing yields nothing.
    """
    data_meta = {}

    # Collect every dataset-path key (load_*_path) present in the config.
    load_data_keys = [
        k for k in config if k.startswith('load_') and k.endswith('_path')
    ]

    for load_data_key in load_data_keys:
        # Keys come straight from config, so lookup always succeeds
        # (the original code had an unreachable `else: ... = False` branch).
        inp = config[load_data_key]
        if not isinstance(inp, (str, list)):
            # Fixed: original message had a typo ('sturcture_list') and a
            # hard-coded name; report the actual offending key instead.
            raise ValueError(f'unexpected input {inp} for {load_data_key}')
        extended = []
        if isinstance(inp, str):
            extended = glob.glob(inp)
        else:
            # A list may mix glob patterns (str) and pre-built dict entries.
            for item in inp:
                if isinstance(item, str):
                    extended.extend(glob.glob(item))
                elif isinstance(item, dict):
                    extended.append(item)
        if len(extended) == 0:
            raise ValueError(
                f'Cannot find {inp} for {load_data_key}'
                + ' or path is not given'
            )
        data_meta[load_data_key] = extended

    # Fill defaults and validate each remaining data key.
    for key, default in _const.DEFAULT_DATA_CONFIG.items():
        data_meta[key] = config_initialize(
            key, config, default, _const.DATA_CONFIG_CONDITION
        )

    unknown_keys = [k for k in config.keys() if k not in data_meta.keys()]
    if len(unknown_keys) != 0:
        warnings.warn(
            f'Unexpected data keys: {unknown_keys} will be ignored',
            UserWarning,
        )
    return data_meta
def read_config_yaml(filename: str, return_separately: bool = False):
    """Parse an input yaml file into model/train/data configuration dicts.

    Returns the three dicts separately when ``return_separately`` is True;
    otherwise they are merged into one dict (train/data keys win).
    """
    with open(filename, 'r') as fstream:
        inputs = yaml.safe_load(fstream)

    # Each top-level yaml section has a dedicated initializer.
    section_parsers = {
        'model': init_model_config,
        'train': init_train_config,
        'data': init_data_config,
    }
    parsed = {'model': {}, 'train': {}, 'data': {}}
    for key, config in inputs.items():
        if key not in section_parsers:
            raise ValueError(f'Unexpected input {key} given')
        parsed[key] = section_parsers[key](config)

    model_meta = parsed['model']
    train_meta = parsed['train']
    data_meta = parsed['data']
    if return_separately:
        return model_meta, train_meta, data_meta
    model_meta.update(train_meta)
    model_meta.update(data_meta)
    return model_meta
def main():
    """Script entry point: parse the default './input.yaml'."""
    read_config_yaml('./input.yaml')


if __name__ == '__main__':
    main()
mace-bench/3rdparty/SevenNet/sevenn/py.typed
0 → 100644
View file @
73866b01
mace-bench/3rdparty/SevenNet/sevenn/sevenn_logger.py
0 → 100644
View file @
73866b01
"""Deprecated alias module; import ``sevenn.logger`` instead."""
import warnings

# Re-export everything from the new module so old imports keep working.
from .logger import *  # noqa: F403

# stacklevel=2 attributes the warning to the importing module, not this shim.
warnings.warn(
    'Please use sevenn.logger instead of sevenn.sevenn_logger',
    DeprecationWarning,
    stacklevel=2,
)
mace-bench/3rdparty/SevenNet/sevenn/sevennet_calculator.py
0 → 100644
View file @
73866b01
"""Deprecated alias module; import ``sevenn.calculator`` instead."""
import warnings

# Re-export everything from the new module so old imports keep working.
from .calculator import *  # noqa: F403

# stacklevel=2 attributes the warning to the importing module, not this shim.
warnings.warn(
    'Please use sevenn.calculator instead of sevenn.sevennet_calculator',
    DeprecationWarning,
    stacklevel=2,
)
mace-bench/3rdparty/SevenNet/sevenn/util.py
0 → 100644
View file @
73866b01
import
os
import
os.path
as
osp
import
pathlib
import
shutil
from
typing
import
Dict
,
List
,
Tuple
,
Union
import
numpy
as
np
import
requests
import
torch
import
torch.nn
from
e3nn.o3
import
FullTensorProduct
,
Irreps
from
tqdm
import
tqdm
import
sevenn._const
as
_const
import
sevenn._keys
as
KEY
def to_atom_graph_list(atom_graph_batch):
    """Split a torch_geometric batch back into a list of separate graphs.

    PyG's ``to_data_list()`` alone is not enough: tensors inferred by the
    model (per-atom energies, total energies, forces, stresses) live on the
    batch and must be scattered back onto each graph by hand.
    """
    has_stress = KEY.PRED_STRESS in atom_graph_batch

    graphs = atom_graph_batch.to_data_list()
    natoms = atom_graph_batch[KEY.NUM_ATOMS].tolist()

    # Per-atom quantities are split by atom counts; per-structure ones
    # are unbound along the batch dimension.
    atomic_energies = torch.split(atom_graph_batch[KEY.ATOMIC_ENERGY], natoms)
    total_energies = torch.unbind(atom_graph_batch[KEY.PRED_TOTAL_ENERGY])
    forces = torch.split(atom_graph_batch[KEY.PRED_FORCE], natoms)
    stresses = (
        torch.unbind(atom_graph_batch[KEY.PRED_STRESS]) if has_stress else None
    )

    for idx, graph in enumerate(graphs):
        graph[KEY.ATOMIC_ENERGY] = atomic_energies[idx]
        graph[KEY.PRED_TOTAL_ENERGY] = total_energies[idx]
        graph[KEY.PRED_FORCE] = forces[idx]
        if stresses is not None:
            # Unsqueeze to fit the KEY.STRESS (reference) format.
            graph[KEY.PRED_STRESS] = torch.unsqueeze(stresses[idx], 0)
    return graphs
def error_recorder_from_loss_functions(loss_functions):
    """Create an ErrorRecorder whose metrics mirror the given loss functions.

    MSE criteria map to RMS-error metrics (name suffixed with '_RMSE');
    L1 criteria map to mean-absolute-error metrics. Loss functions with any
    other criterion type contribute no metric.
    """
    from .error_recorder import ErrorRecorder, MAError, RMSError, get_err_type
    from .train.loss import ForceLoss, PerAtomEnergyLoss, StressLoss

    # Exact-type dispatch (matches the original `type(x) is Cls` checks).
    err_label_by_loss_type = {
        PerAtomEnergyLoss: 'Energy',
        ForceLoss: 'Force',
        StressLoss: 'Stress',
    }

    metrics = []
    for loss_function, _ in loss_functions:
        label = err_label_by_loss_type.get(type(loss_function))
        base = get_err_type(label) if label is not None else {}
        base['name'] = loss_function.name
        base['ref_key'] = loss_function.ref_key
        base['pred_key'] = loss_function.pred_key
        criterion = loss_function.criterion
        if type(criterion) is torch.nn.MSELoss:
            base['name'] = base['name'] + '_RMSE'
            metrics.append(RMSError(**base))
        elif type(criterion) is torch.nn.L1Loss:
            metrics.append(MAError(**base))
    return ErrorRecorder(metrics)
def onehot_to_chem(one_hot_indices: List[int], type_map: Dict[int, int]):
    """Map one-hot (type) indices back to chemical symbols.

    ``type_map`` maps atomic number -> one-hot index; it is inverted here
    so each index resolves to its element symbol.
    """
    from ase.data import chemical_symbols

    index_to_z = {idx: z for z, idx in type_map.items()}
    symbols = []
    for idx in one_hot_indices:
        symbols.append(chemical_symbols[index_to_z[idx]])
    return symbols
def model_from_checkpoint(
    checkpoint: str,
) -> Tuple[torch.nn.Module, Dict]:
    """Load a checkpoint (path or pretrained name) and return (model, config)."""
    loaded = load_checkpoint(checkpoint)
    return loaded.build_model(), loaded.config
def model_from_checkpoint_with_backend(
    checkpoint: str,
    backend: str = 'e3nn',
) -> Tuple[torch.nn.Module, Dict]:
    """Load a checkpoint and build its model with the requested backend."""
    loaded = load_checkpoint(checkpoint)
    return loaded.build_model(backend), loaded.config
def unlabeled_atoms_to_input(atoms, cutoff: float, grad_key: str = KEY.EDGE_VEC):
    """Convert an unlabeled atoms object into a model-ready graph.

    Gradient tracking is enabled on ``grad_key`` so derived quantities
    (forces/stress) can be obtained by autograd.
    """
    from .atom_graph_data import AtomGraphData
    from .train.dataload import unlabeled_atoms_to_graph

    graph_dict = unlabeled_atoms_to_graph(atoms, cutoff)
    atom_graph = AtomGraphData.from_numpy_dict(graph_dict)
    atom_graph[grad_key].requires_grad_(True)
    # Single-structure input: placeholder batch tensor (empty, as original).
    atom_graph[KEY.BATCH] = torch.zeros([0])
    return atom_graph
def chemical_species_preprocess(input_chem: List[str], universal: bool = False):
    """Derive species-related config entries from a list of element symbols.

    With ``universal=True`` the given list is ignored: the whole periodic
    table is used with an identity type map.
    """
    from ase.data import atomic_numbers, chemical_symbols

    from .nn.node_embedding import get_type_mapper_from_specie

    config = {}
    if universal:
        n_all = len(chemical_symbols)
        config[KEY.CHEMICAL_SPECIES] = chemical_symbols
        config[KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER] = list(range(n_all))
        config[KEY.NUM_SPECIES] = n_all
        config[KEY.TYPE_MAP] = {z: z for z in range(n_all)}
    else:
        # Deduplicate first, then strip and sort (order preserved from the
        # original implementation: dedup happens before stripping).
        species = sorted([sym.strip() for sym in set(input_chem)])
        config[KEY.CHEMICAL_SPECIES] = species
        config[KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER] = [
            atomic_numbers[sym] for sym in species
        ]
        config[KEY.NUM_SPECIES] = len(species)
        config[KEY.TYPE_MAP] = get_type_mapper_from_specie(species)
    return config
def dtype_correct(
    v: Union[np.ndarray, torch.Tensor, int, float],
    float_dtype: torch.dtype = torch.float32,
    int_dtype: torch.dtype = torch.int64,
):
    """Coerce arrays/tensors/scalars onto the requested torch dtypes.

    Floating inputs become ``float_dtype`` tensors, integer inputs become
    ``int_dtype`` tensors, and non-numeric values are returned unchanged.
    NOTE(review): numpy arrays that are neither floating nor integer
    (e.g. bool) fall through and return None — behavior kept as-is.
    """
    if isinstance(v, np.ndarray):
        if np.issubdtype(v.dtype, np.floating):
            return torch.from_numpy(v).to(float_dtype)
        elif np.issubdtype(v.dtype, np.integer):
            return torch.from_numpy(v).to(int_dtype)
    elif isinstance(v, torch.Tensor):
        # Non-floating tensors are assumed to be integer-typed.
        target = float_dtype if v.dtype.is_floating_point else int_dtype
        return v.to(target)
    elif isinstance(v, int):
        return torch.tensor(v, dtype=int_dtype)
    elif isinstance(v, float):
        return torch.tensor(v, dtype=float_dtype)
    else:
        return v  # not numeric: pass through untouched
def infer_irreps_out(
    irreps_x: Irreps,
    irreps_operand: Irreps,
    drop_l: Union[bool, int] = False,
    parity_mode: str = 'full',
    fix_multiplicity: Union[bool, int] = False,
):
    """Infer the output irreps of a full tensor product, with filtering.

    ``drop_l``: when an int, discard paths with l above it (False keeps all).
    ``parity_mode``: 'full' keeps everything, 'even' drops odd parity,
    'sph' keeps only spherical-harmonics parity p == (-1)**l.
    ``fix_multiplicity``: when truthy, override every multiplicity with it.
    """
    assert parity_mode in ['full', 'even', 'sph']

    # Entries are (mul, (l, p)) after simplification.
    tp_out = FullTensorProduct(irreps_x, irreps_operand).irreps_out.simplify()

    kept = []
    for mul, (l, p) in tp_out:  # noqa
        if drop_l is not False and l > drop_l:
            continue
        if parity_mode == 'even':
            if p == -1:
                continue
        elif parity_mode == 'sph':
            if p != (-1) ** l:
                continue
        out_mul = fix_multiplicity if fix_multiplicity else mul
        kept.append((out_mul, (l, p)))
    return Irreps(kept)
def download_checkpoint(path: str, url: str):
    """Download ``url`` to ``path`` with a progress bar.

    The payload is first streamed to '<path>.partial' and moved into place
    on success; on any non-permission failure the partial file is removed
    before re-raising.
    """
    fname = osp.basename(path)
    temp_path = path + '.partial'
    try:
        # Raises PermissionError when the target directory is not writable.
        os.makedirs(osp.dirname(path), exist_ok=True)
        response = requests.get(url, stream=True, timeout=30)
        response.raise_for_status()

        total_size = int(response.headers.get('content-length', 0))
        progress_bar = tqdm(
            total=total_size,
            unit='B',
            unit_scale=True,
            desc=f'Downloading {fname}',
        )
        chunk_size = 1024  # stream in 1 KB pieces
        with open(temp_path, 'wb') as out:
            for chunk in response.iter_content(chunk_size):
                progress_bar.update(len(chunk))
                out.write(chunk)
        progress_bar.close()

        shutil.move(temp_path, path)
        print(f'Checkpoint downloaded: {path}')
        return path
    except PermissionError:
        raise  # caller may retry at a user-writable location
    except Exception as e:
        # Best-effort cleanup of the partial download before re-raising.
        # May not work as errors handled internally by tqdm etc.
        print(f'Download failed: {str(e)}')
        if os.path.exists(temp_path):
            print(f'Cleaning up partial download: {temp_path}')
            os.remove(temp_path)
        raise
def pretrained_name_to_path(name: str) -> str:
    """Resolve a pretrained model name to a checkpoint file path.

    Looks in the package directory first, then in ~/.cache/sevennet;
    when the file is missing and a download URL is registered, the
    checkpoint is downloaded. Raises ValueError for unrecognized names.
    """
    name = name.lower()
    heads = ['sevennet', '7net']

    # Accepted name suffix (per head prefix) -> _const attribute holding
    # the checkpoint path. getattr is deferred until a match is found.
    alias_suffixes = {
        '-0_11july2024': 'SEVENNET_0_11Jul2024',
        '-0_11jul2024': 'SEVENNET_0_11Jul2024',
        '-0': 'SEVENNET_0_11Jul2024',
        '-0_22may2024': 'SEVENNET_0_22May2024',
        '-l3i5': 'SEVENNET_l3i5',
        '-mf-0': 'SEVENNET_MF_0',
        '-mf-ompa': 'SEVENNET_MF_ompa',
        '-omat': 'SEVENNET_omat',
    }
    checkpoint_path = None
    for suffix, attr in alias_suffixes.items():
        if name in [f'{head}{suffix}' for head in heads]:
            checkpoint_path = getattr(_const, attr)
            break
    if checkpoint_path is None:
        raise ValueError('Not a valid pretrained model name')

    url = _const.CHECKPOINT_DOWNLOAD_LINKS.get(checkpoint_path)

    # Candidate locations: package directory, then the user cache.
    candidates = [
        checkpoint_path,
        checkpoint_path.replace(
            _const._prefix, osp.expanduser('~/.cache/sevennet')
        ),
    ]
    for candidate in candidates:
        if osp.exists(candidate):
            return candidate

    # Not found locally: download when a URL is registered.
    if url is None:
        raise FileNotFoundError(checkpoint_path)
    try:
        return download_checkpoint(candidates[0], url)  # 7net package path
    except PermissionError:
        return download_checkpoint(candidates[1], url)  # ~/.cache
def load_checkpoint(checkpoint: Union[pathlib.Path, str]):
    """Return a SevenNetCheckpoint from a file path or a pretrained name."""
    from sevenn.checkpoint import SevenNetCheckpoint

    suggests = ['7net-0, 7net-l3i5, 7net-mf-ompa, 7net-omat']
    if osp.isfile(checkpoint):
        resolved = checkpoint
    else:
        # Not a file on disk: try to interpret it as a pretrained name.
        try:
            resolved = pretrained_name_to_path(str(checkpoint))
        except ValueError:
            raise ValueError(
                f'Given {checkpoint} is not exists and not a pre-trained name.\n'
                f'Valid pretrained model names: {suggests}'
            )
    return SevenNetCheckpoint(resolved)
def unique_filepath(filepath: str) -> str:
    """Return a path that does not collide with an existing file.

    If ``filepath`` itself is unused it is returned as-is; otherwise the
    first '<name><N><ext>' variant (N = 0, 1, 2, ...) that does not exist
    is returned.
    """
    if not os.path.isfile(filepath):
        return filepath

    parent = os.path.dirname(filepath)
    stem, suffix = os.path.splitext(os.path.basename(filepath))
    counter = 0
    while True:
        candidate = os.path.join(parent, f'{stem}{counter}{suffix}')
        if not os.path.exists(candidate):
            return candidate
        counter += 1
def get_error_recorder(
    recorder_tuples: Union[List[Tuple[str, str]], None] = None,
):
    """Build an ErrorRecorder from (error_type, metric_name) pairs.

    Fixed: the default was a mutable list literal (shared across calls);
    it is now a ``None`` sentinel expanded inside the function, which is
    behavior-compatible since the default was only ever read.

    Args:
        recorder_tuples: pairs such as ('Energy', 'RMSE'). Defaults to RMSE
            and MAE for energy, force, and stress.

    Returns:
        sevenn.error_recorder.ErrorRecorder holding one metric per pair.
    """
    # TODO add criterion argument and loss recorder selections
    import sevenn.error_recorder as error_recorder

    if recorder_tuples is None:
        recorder_tuples = [
            ('Energy', 'RMSE'),
            ('Force', 'RMSE'),
            ('Stress', 'RMSE'),
            ('Energy', 'MAE'),
            ('Force', 'MAE'),
            ('Stress', 'MAE'),
        ]

    err_metrics = []
    for err_type, metric_name in recorder_tuples:
        # Copy so the suffix mutation does not leak into the shared template.
        metric_kwargs = error_recorder.get_err_type(err_type).copy()
        metric_kwargs['name'] += f'_{metric_name}'
        metric_cls = error_recorder.ErrorRecorder.METRIC_DICT[metric_name]
        err_metrics.append(metric_cls(**metric_kwargs))
    return error_recorder.ErrorRecorder(err_metrics)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment