Commit cc6e6b7d authored by Wang, Leping's avatar Wang, Leping
Browse files

- Add config.sh with all pipeline parameters organized by category

  (molecular, crystal structure, compute, run mode, path)
- Refactor search_gen_proc.sh to source config.sh instead of
  hardcoding parameters, with optional config path argument
- Refactor structure_generate.py to load config.sh via exec(),
  replacing hardcoded values with config-driven parameters
- Remove mace-bench (the relaxation part; it will be replaced by an updated separate mace-bench project)
parent 61ec3ad9
import ctypes
import os
import pathlib
import warnings
from typing import Any, Dict, Optional, Union
import numpy as np
import torch
import torch.jit
import torch.jit._script
from ase.calculators.calculator import Calculator, all_changes
from ase.calculators.mixing import SumCalculator
from ase.data import chemical_symbols
import sevenn._keys as KEY
import sevenn.util as util
from sevenn.atom_graph_data import AtomGraphData
from sevenn.nn.sequential import AtomGraphSequential
from sevenn.train.dataload import unlabeled_atoms_to_graph
import logging
# Runtime type of a TorchScript-compiled module; used with isinstance() below to
# detect deployed (torchscript) models, which need manual element-index mapping.
torch_script_type = torch.jit._script.RecursiveScriptModule
class SevenNetCalculator(Calculator):
    """ASE calculator backed by a SevenNet interatomic potential.

    Supporting properties:
        'free_energy', 'energy', 'forces', 'stress', 'energies'
    free_energy equals energy. 'energies' stores atomic energy.

    Multi-GPU acceleration is not supported with ASE calculator.
    You should use LAMMPS for the acceleration.
    """

    def __init__(
        self,
        model: Union[str, pathlib.PurePath, AtomGraphSequential] = '7net-0',
        file_type: str = 'checkpoint',
        device: Union[torch.device, str] = 'auto',
        modal: Optional[str] = None,
        enable_cueq: bool = False,
        sevennet_config: Optional[Dict] = None,  # Not used in logic, just meta info
        **kwargs,
    ):
        """Initialize SevenNetCalculator.

        Parameters
        ----------
        model: str | Path | AtomGraphSequential, default='7net-0'
            Name of pretrained models (7net-mf-ompa, 7net-omat, 7net-l3i5, 7net-0) or
            path to the checkpoint, deployed model or the model itself
        file_type: str, default='checkpoint'
            one of 'checkpoint' | 'torchscript' | 'model_instance'
        device: str | torch.device, default='auto'
            if not given, use CUDA if available
        modal: str | None, default=None
            modal (fidelity) if given model is multi-modal model. for 7net-mf-ompa,
            it should be one of 'mpa' (MPtrj + sAlex) or 'omat24' (OMat24)
            case insensitive
        enable_cueq: bool, default=False
            if True, use cuEquivariant to accelerate inference.
        sevennet_config: dict | None, default=None
            Not used, but can be used to carry meta information of this calculator

        Raises
        ------
        ValueError
            for an unknown file_type, a model missing type_map/cutoff, or an
            unknown/missing modal on a multi-modal model.
        """
        # Was a stray `print("&&& ...")`; routed through logging so library
        # users can silence it.
        logging.debug('Initializing SevenNetCalculator')
        super().__init__(**kwargs)
        self.sevennet_config = None

        if isinstance(model, pathlib.PurePath):
            model = str(model)

        allowed_file_types = ['checkpoint', 'torchscript', 'model_instance']
        file_type = file_type.lower()
        if file_type not in allowed_file_types:
            raise ValueError(f'file_type not in {allowed_file_types}')

        if enable_cueq and file_type in ['model_instance', 'torchscript']:
            warnings.warn(
                'file_type should be checkpoint to enable cueq. cueq set to False'
            )
            enable_cueq = False

        if isinstance(device, str):  # TODO: do we really need this?
            if device == 'auto':
                self.device = torch.device(
                    'cuda' if torch.cuda.is_available() else 'cpu'
                )
            else:
                self.device = torch.device(device)
        else:
            self.device = device

        if file_type == 'checkpoint' and isinstance(model, str):
            cp = util.load_checkpoint(model)
            backend = 'e3nn' if not enable_cueq else 'cueq'
            model_loaded = cp.build_model(backend)
            model_loaded.set_is_batch_data(False)
            self.type_map = cp.config[KEY.TYPE_MAP]
            self.cutoff = cp.config[KEY.CUTOFF]
            self.sevennet_config = cp.config
        elif file_type == 'torchscript' and isinstance(model, str):
            if modal:
                raise NotImplementedError()
            # Model metadata travels as "extra files" inside the torchscript
            # archive; requesting the keys fills in their byte values.
            extra_dict = {
                'chemical_symbols_to_index': b'',
                'cutoff': b'',
                'num_species': b'',
                'model_type': b'',
                'version': b'',
                'dtype': b'',
                'time': b'',
            }
            model_loaded = torch.jit.load(
                model, _extra_files=extra_dict, map_location=self.device
            )
            chem_symbols = extra_dict['chemical_symbols_to_index'].decode('utf-8')
            sym_to_num = {sym: n for n, sym in enumerate(chemical_symbols)}
            self.type_map = {
                sym_to_num[sym]: i for i, sym in enumerate(chem_symbols.split())
            }
            self.cutoff = float(extra_dict['cutoff'].decode('utf-8'))
        elif isinstance(model, AtomGraphSequential):
            if model.type_map is None:
                raise ValueError(
                    'Model must have the type_map to be used with calculator'
                )
            if model.cutoff == 0.0:
                raise ValueError('Model cutoff seems not initialized')
            model.eval_type_map = torch.tensor(True)  # ?
            model.set_is_batch_data(False)
            model_loaded = model
            self.type_map = model.type_map
            self.cutoff = model.cutoff
        else:
            raise ValueError('Unexpected input combinations')

        if self.sevennet_config is None and sevennet_config is not None:
            self.sevennet_config = sevennet_config

        self.model = model_loaded

        self.modal = None
        if isinstance(self.model, AtomGraphSequential):
            modal_map = self.model.modal_map
            if modal_map:
                modal_ava = list(modal_map.keys())
                if not modal:
                    raise ValueError(f'modal argument missing (avail: {modal_ava})')
                elif modal not in modal_ava:
                    raise ValueError(f'unknown modal {modal} (not in {modal_ava})')
                self.modal = modal
            elif not self.model.modal_map and modal:
                warnings.warn(f'modal={modal} is ignored as model has no modal_map')

        self.model.to(self.device)
        self.model.eval()

        self.implemented_properties = [
            'free_energy',
            'energy',
            'forces',
            'stress',
            'energies',
        ]

    def set_atoms(self, atoms):
        """Validate that the model knows every element in `atoms`.

        Called by ASE when `atoms.calc = calc` is assigned.
        """
        zs = tuple(set(atoms.get_atomic_numbers()))
        for z in zs:
            if z not in self.type_map:
                sp = list(self.type_map.keys())
                raise ValueError(
                    f'Model do not know atomic number: {z}, (knows: {sp})'
                )

    def output_to_results(self, output):
        """Convert a raw model output graph into the ASE results dict."""
        energy = output[KEY.PRED_TOTAL_ENERGY].detach().cpu().item()
        num_atoms = output['num_atoms'].item()
        atomic_energies = output[KEY.ATOMIC_ENERGY].detach().cpu().numpy().flatten()
        forces = output[KEY.PRED_FORCE].detach().cpu().numpy()[:num_atoms, :]
        stress = np.array(
            (-output[KEY.PRED_STRESS])
            .detach()
            .cpu()
            .numpy()[[0, 1, 2, 4, 5, 3]]  # as voigt notation
        )
        # Store results
        return {
            'free_energy': energy,
            'energy': energy,
            'energies': atomic_energies,
            'forces': forces,
            'stress': stress,
            'num_edges': output[KEY.EDGE_IDX].shape[1],
        }

    def _atoms_to_model_input(self, atoms):
        """Build the (single-structure) model input graph for `atoms`.

        For torchscript models the node features are mapped from atomic
        numbers to model element indices and the graph is converted to a
        plain dict without 'data_info'.
        """
        data = AtomGraphData.from_numpy_dict(
            unlabeled_atoms_to_graph(atoms, self.cutoff)
        )
        if self.modal:
            data[KEY.DATA_MODALITY] = self.modal
        data.to(self.device)  # type: ignore
        if isinstance(self.model, torch_script_type):
            data[KEY.NODE_FEATURE] = torch.tensor(
                [self.type_map[z.item()] for z in data[KEY.NODE_FEATURE]],
                dtype=torch.int64,
                device=self.device,
            )
            data[KEY.POS].requires_grad_(True)  # backward compatibility
            data[KEY.EDGE_VEC].requires_grad_(True)  # backward compatibility
            data = data.to_dict()
            del data['data_info']
        return data

    def calculate(self, atoms=None, properties=None, system_changes=all_changes):
        """Evaluate `atoms` and populate `self.results` (ASE entry point)."""
        # call parent class to set necessary atom attributes
        Calculator.calculate(self, atoms, properties, system_changes)
        if atoms is None:
            raise ValueError('No atoms to evaluate')
        data = self._atoms_to_model_input(atoms)
        # Lazy %-style args: formatting only happens when DEBUG is enabled.
        logging.debug('data: %s', data)
        logging.debug('data[cell_lattice_vectors]: %s', data['cell_lattice_vectors'])
        logging.debug('data[cell_volume]: %s', data['cell_volume'])
        output = self.model(data)
        self.results = self.output_to_results(output)

    def predict_one(self, atoms):
        """Run a forward pass on a single structure; return the raw output."""
        if atoms is None:
            raise ValueError('No atoms to evaluate')
        data = self._atoms_to_model_input(atoms)
        return self.model(data)

    def predict(self, atoms_list, properties=None):
        """Batched prediction over a list of ASE Atoms.

        Returns a dict of float64 tensors: 'energy' (n,), 'forces'
        (total_atoms, 3) and 'stress' (n, 3, 3) on `self.device`.
        """
        if not atoms_list:
            raise ValueError("Empty atoms_list provided")
        if not isinstance(atoms_list, list):
            atoms_list = [atoms_list]

        # Convert atoms to graph data
        graph_list = []
        for atoms in atoms_list:
            data = AtomGraphData.from_numpy_dict(
                unlabeled_atoms_to_graph(atoms, self.cutoff)
            )
            if self.modal:
                data[KEY.DATA_MODALITY] = self.modal
            if isinstance(self.model, torch_script_type):
                data[KEY.NODE_FEATURE] = torch.tensor(
                    [self.type_map[z.item()] for z in data[KEY.NODE_FEATURE]],
                    dtype=torch.int64,
                    device=self.device,
                )
            data[KEY.POS].requires_grad_(True)  # backward compatibility
            data[KEY.EDGE_VEC].requires_grad_(True)  # backward compatibility
            graph_list.append(data)

        # Switch the (non-script) model into batch mode for this call.
        if isinstance(self.model, AtomGraphSequential):
            self.model.set_is_batch_data(True)
            self.model.eval()

        # Batch the data if there are multiple atoms
        from torch_geometric.loader.dataloader import Collater

        batched_data = Collater(graph_list)(graph_list)
        batched_data = batched_data.to(self.device)
        logging.debug('batched_data: %s', batched_data)
        logging.debug(
            'batched_data[cell_lattice_vectors]: %s',
            batched_data['cell_lattice_vectors'],
        )
        logging.debug('batched_data[cell_volume]: %s', batched_data['cell_volume'])

        # Run model on batched data
        if isinstance(self.model, torch_script_type):
            batched_dict = batched_data.to_dict()
            if 'data_info' in batched_dict:
                del batched_dict['data_info']
            output = self.model(batched_dict)
        else:
            output = self.model(batched_data)

        predictions = {}
        predictions['energy'] = (
            output[KEY.PRED_TOTAL_ENERGY].to(torch.float64).detach()
        )
        predictions['forces'] = output[KEY.PRED_FORCE].to(torch.float64).detach()
        # Reorder stress to voigt and expand each 6-vector to a full 3x3.
        voigt = (
            (-output[KEY.PRED_STRESS])[:, [0, 1, 2, 4, 5, 3]]
            .to(torch.float64)
            .detach()
        )
        stress_list = [
            self._stress2tensor(voigt[i, :]) for i in range(voigt.shape[0])
        ]
        predictions['stress'] = (
            torch.stack(stress_list, dim=0).view(-1, 3, 3).detach()
        )
        return predictions

    def _stress2tensor(self, stress):
        """Expand a voigt 6-vector (xx, yy, zz, yz, xz, xy) to a 3x3 tensor."""
        tensor = torch.tensor(
            [
                [stress[0], stress[5], stress[4]],
                [stress[5], stress[1], stress[3]],
                [stress[4], stress[3], stress[2]],
            ],
            device=self.device,
        )
        return tensor
class SevenNetD3Calculator(SumCalculator):
    """SevenNet potential plus the (CUDA-accelerated) D3 dispersion correction.

    Energy, forces and stress are the sum of the SevenNet prediction and the
    D3 van der Waals correction. CUDA required.
    """

    def __init__(
        self,
        model: Union[str, pathlib.PurePath, AtomGraphSequential] = '7net-0',
        file_type: str = 'checkpoint',
        device: Union[torch.device, str] = 'auto',
        sevennet_config: Optional[Any] = None,  # hold meta information
        damping_type: str = 'damp_bj',
        functional_name: str = 'pbe',
        vdw_cutoff: float = 9000,  # au^2, 0.52917726 angstrom = 1 au
        cn_cutoff: float = 1600,  # au^2, 0.52917726 angstrom = 1 au
        batch_size=10,
        **kwargs,
    ):
        """Initialize SevenNetD3Calculator. CUDA required.

        Parameters
        ----------
        model: str | Path | AtomGraphSequential
            Name of pretrained models (7net-mf-ompa, 7net-omat, 7net-l3i5, 7net-0) or
            path to the checkpoint, deployed model or the model itself
        file_type: str, default='checkpoint'
            one of 'checkpoint' | 'torchscript' | 'model_instance'
        device: str | torch.device, default='auto'
            if not given, use CUDA if available
        damping_type: str, default='damp_bj'
            Damping type of D3, one of 'damp_bj' | 'damp_zero'
        functional_name: str, default='pbe'
            Target functional name of D3 parameters.
        vdw_cutoff: float, default=9000
            vdw cutoff of D3 calculator in au
        cn_cutoff: float, default=1600
            cn cutoff of D3 calculator in au
        batch_size: int, default=10
            size of the pool of D3 calculators reused by `predict`
        """
        self.d3_calc = D3Calculator(
            damping_type=damping_type,
            functional_name=functional_name,
            vdw_cutoff=vdw_cutoff,
            cn_cutoff=cn_cutoff,
            **kwargs,
        )
        self.sevennet_calc = SevenNetCalculator(
            model=model,
            file_type=file_type,
            device=device,
            sevennet_config=sevennet_config,
            **kwargs,
        )
        super().__init__([self.sevennet_calc, self.d3_calc])
        self.device = device
        # Pool of D3 calculators; `predict` cycles over it round-robin.
        self.d3_calcs = [
            D3Calculator(
                damping_type=damping_type,
                functional_name=functional_name,
                vdw_cutoff=vdw_cutoff,
                cn_cutoff=cn_cutoff,
                **kwargs,
            )
            for _ in range(batch_size)
        ]

    def predict(self, atoms_list):
        """Predict the energy and forces for a list of atoms.

        Returns the SevenNet batched prediction with the per-structure D3
        correction added to 'energy', 'forces' and 'stress'.
        """
        # Call the predict method of the first calculator (SevenNetCalculator)
        predictions = self.sevennet_calc.predict(atoms_list)
        energy_list = []
        forces_list = []
        stress_list = []
        predictions3d = {}
        for i, atoms in enumerate(atoms_list):
            # Round-robin over the pool so atoms_list may be longer than
            # batch_size (plain d3_calcs[i] raised IndexError in that case).
            calc = self.d3_calcs[i % len(self.d3_calcs)]
            prediction = calc.predict_one(atoms)
            energy_list.append(torch.tensor(prediction['energy']))
            forces_list.append(torch.from_numpy(prediction['forces']).to(self.device))
            stress_list.append(
                self._stress2tensor(torch.from_numpy(prediction['stress']))
            )
        # Convert lists to tensors
        predictions3d['energy'] = torch.stack(energy_list, dim=0).to(self.device)
        predictions3d['forces'] = torch.cat(forces_list, dim=0).view(-1, 3)
        predictions3d['stress'] = torch.stack(stress_list, dim=0).view(-1, 3, 3)
        predictions['energy'] += predictions3d['energy'].detach()
        predictions['forces'] += predictions3d['forces'].detach()
        predictions['stress'] += predictions3d['stress'].detach()
        return predictions

    def _stress2tensor(self, stress):
        """Expand a voigt 6-vector (xx, yy, zz, yz, xz, xy) to a 3x3 tensor."""
        tensor = torch.tensor(
            [
                [stress[0], stress[5], stress[4]],
                [stress[5], stress[1], stress[3]],
                [stress[4], stress[3], stress[2]],
            ],
            device=self.device
        )
        return tensor
def _load(name: str) -> ctypes.CDLL:
    """Load the compiled `name` extension, building it with nvcc if absent.

    Looks for a prebuilt shared library next to this package first, then in
    torch's extension build cache; if neither exists, JIT-compiles the CUDA
    source and loads the result via ctypes.
    """
    from torch.utils.cpp_extension import LIB_EXT, _get_build_directory, load

    package_dir = os.path.dirname(os.path.abspath(__file__))
    cache_dir = _get_build_directory(name, verbose=False)

    # Try the prebuilt candidates before paying the compilation cost.
    for candidate_dir in (package_dir, cache_dir):
        try:
            return ctypes.CDLL(os.path.join(candidate_dir, f'{name}{LIB_EXT}'))
        except OSError:
            continue

    # Not found anywhere: compile from source.
    if os.access(package_dir, os.W_OK):
        compile_dir = package_dir
    else:
        print('Warning: package directory is not writable. Using cache directory.')
        compile_dir = cache_dir

    if 'TORCH_CUDA_ARCH_LIST' not in os.environ:
        print('Warning: TORCH_CUDA_ARCH_LIST is not set.')
        print('Warning: Use default CUDA architectures: 61, 70, 75, 80, 86, 89, 90')
        os.environ['TORCH_CUDA_ARCH_LIST'] = '6.1;7.0;7.5;8.0;8.6;8.9;9.0'

    load(
        name=name,
        sources=[os.path.join(package_dir, 'pair_e3gnn', 'pair_d3_for_ase.cu')],
        extra_cuda_cflags=['-O3', '--expt-relaxed-constexpr', '-fmad=false'],
        build_directory=compile_dir,
        verbose=True,
        is_python_module=False,
    )
    return ctypes.CDLL(os.path.join(compile_dir, f'{name}{LIB_EXT}'))
class PairD3(ctypes.Structure):
    """Opaque handle to the C-side PairD3 object; only ever used as a pointer."""
class D3Calculator(Calculator):
    """ASE calculator for accelerated D3 van der Waals (vdW) correction.

    Example:
        from ase.calculators.mixing import SumCalculator
        calc_1 = SevenNetCalculator()
        calc_2 = D3Calculator()
        return SumCalculator([calc_1, calc_2])

    This calculator interfaces with the `libpaird3.so` library,
    which is compiled by nvcc during the package installation.
    If you encounter any errors, please verify
    the installation process and the compilation options in `setup.py`.

    Note: Multi-GPU parallel MD is not supported in this mode.
    Note: Cffi could be used, but it was avoided to reduce dependencies.
    """

    # Here, free_energy = energy
    implemented_properties = ['free_energy', 'energy', 'forces', 'stress']

    def __init__(
        self,
        damping_type: str = 'damp_bj',  # damp_bj, damp_zero
        functional_name: str = 'pbe',  # check the source code
        vdw_cutoff: float = 9000,  # au^2, 0.52917726 angstrom = 1 au
        cn_cutoff: float = 1600,  # au^2, 0.52917726 angstrom = 1 au
        **kwargs,
    ):
        """Initialize the D3 calculator and bind the compiled CUDA library.

        Parameters
        ----------
        damping_type: str, default='damp_bj'
            one of 'damp_bj' | 'damp_zero'
        functional_name: str, default='pbe'
            target functional name of the D3 parameters
        vdw_cutoff: float, default=9000
            vdw cutoff in au^2
        cn_cutoff: float, default=1600
            coordination-number cutoff in au^2

        Raises
        ------
        NotImplementedError
            when CUDA is not available (CPU D3 is not implemented yet)
        ValueError
            for an unknown damping type
        """
        super().__init__(**kwargs)
        if not torch.cuda.is_available():
            raise NotImplementedError('CPU + D3 is not implemented yet')
        self.rthr = vdw_cutoff
        self.cnthr = cn_cutoff
        self.damp_name = damping_type.lower()
        self.func_name = functional_name.lower()
        if self.damp_name not in ['damp_bj', 'damp_zero']:
            raise ValueError('Error: Invalid damping type.')

        self._lib = _load('pair_d3')
        self._lib.pair_init.restype = ctypes.POINTER(PairD3)
        self.pair = self._lib.pair_init()

        # Declare C signatures once so ctypes marshals arguments correctly.
        self._lib.pair_set_atom.argtypes = [
            ctypes.POINTER(PairD3),  # PairD3* pair
            ctypes.c_int,  # int natoms
            ctypes.c_int,  # int ntypes
            ctypes.POINTER(ctypes.c_int),  # int* types
            ctypes.POINTER(ctypes.c_double),  # double* x
        ]
        self._lib.pair_set_atom.restype = None
        self._lib.pair_set_domain.argtypes = [
            ctypes.POINTER(PairD3),  # PairD3* pair
            ctypes.c_int,  # int xperiodic
            ctypes.c_int,  # int yperiodic
            ctypes.c_int,  # int zperiodic
            ctypes.POINTER(ctypes.c_double),  # double* boxlo
            ctypes.POINTER(ctypes.c_double),  # double* boxhi
            ctypes.c_double,  # double xy
            ctypes.c_double,  # double xz
            ctypes.c_double,  # double yz
        ]
        self._lib.pair_set_domain.restype = None
        self._lib.pair_run_settings.argtypes = [
            ctypes.POINTER(PairD3),  # PairD3* pair
            ctypes.c_double,  # double rthr
            ctypes.c_double,  # double cnthr
            ctypes.c_char_p,  # const char* damp_name
            ctypes.c_char_p,  # const char* func_name
        ]
        self._lib.pair_run_settings.restype = None
        self._lib.pair_run_coeff.argtypes = [
            ctypes.POINTER(PairD3),  # PairD3* pair
            ctypes.POINTER(ctypes.c_int),  # int* atomic_numbers
        ]
        self._lib.pair_run_coeff.restype = None
        self._lib.pair_run_compute.argtypes = [ctypes.POINTER(PairD3)]
        self._lib.pair_run_compute.restype = None
        self._lib.pair_get_energy.argtypes = [ctypes.POINTER(PairD3)]
        self._lib.pair_get_energy.restype = ctypes.c_double
        self._lib.pair_get_force.argtypes = [ctypes.POINTER(PairD3)]
        self._lib.pair_get_force.restype = ctypes.POINTER(ctypes.c_double)
        self._lib.pair_get_stress.argtypes = [ctypes.POINTER(PairD3)]
        self._lib.pair_get_stress.restype = ctypes.POINTER(ctypes.c_double * 6)
        self._lib.pair_fin.argtypes = [ctypes.POINTER(PairD3)]
        self._lib.pair_fin.restype = None

    def _idx_to_numbers(self, Z_of_atoms):
        """Unique atomic numbers in first-seen order (type index -> Z)."""
        unique_numbers = list(dict.fromkeys(Z_of_atoms))
        return unique_numbers

    def _idx_to_types(self, Z_of_atoms):
        """Map each atom to a 1-based LAMMPS-style type index."""
        unique_numbers = list(dict.fromkeys(Z_of_atoms))
        mapping = {num: idx + 1 for idx, num in enumerate(unique_numbers)}
        atom_types = [mapping[num] for num in Z_of_atoms]
        return atom_types

    def _convert_domain_ase2lammps(self, cell):
        """Convert an ASE cell to a LAMMPS triclinic box.

        Returns (lammps_cell, rotator): the six box parameters
        (lx, ly, lz, xy, xz, yz) and the rotation matrix mapping ASE
        coordinates into the LAMMPS frame.
        """
        qtrans, ltrans = np.linalg.qr(cell.T, mode='complete')
        lammps_cell = ltrans.T
        # QR may flip axes; normalize to positive box lengths.
        signs = np.sign(np.diag(lammps_cell))
        lammps_cell = lammps_cell * signs
        qtrans = qtrans * signs
        lammps_cell = lammps_cell[(0, 1, 2, 1, 2, 2), (0, 1, 2, 0, 0, 1)]
        rotator = qtrans.T
        return lammps_cell, rotator

    def _stress2tensor(self, stress):
        """Expand a 6-vector (xx, yy, zz, xy, xz, yz) to a symmetric 3x3."""
        tensor = np.array(
            [
                [stress[0], stress[3], stress[4]],
                [stress[3], stress[1], stress[5]],
                [stress[4], stress[5], stress[2]],
            ]
        )
        return tensor

    def _tensor2stress(self, tensor):
        """Collapse a 3x3 tensor to negated ASE voigt (xx, yy, zz, yz, xz, xy)."""
        stress = -np.array(
            [
                tensor[0, 0],
                tensor[1, 1],
                tensor[2, 2],
                tensor[1, 2],
                tensor[0, 2],
                tensor[0, 1],
            ]
        )
        return stress

    def _ensure_cell(self, atoms):
        """Generate a large orthogonal cell in place when `atoms` has none."""
        if atoms.get_cell().sum() == 0:
            print(
                'Warning: D3Calculator requires a cell.\n'
                'Warning: An orthogonal cell large enough is generated.'
            )
            positions = atoms.get_positions()
            min_pos = positions.min(axis=0)
            max_pos = positions.max(axis=0)
            max_cutoff = np.sqrt(max(self.rthr, self.cnthr)) * 0.52917726
            cell_lengths = max_pos - min_pos + max_cutoff + 1.0  # extra margin
            cell = np.eye(3) * cell_lengths
            atoms.set_cell(cell)
            atoms.set_pbc([True, True, True])  # for minus positions

    def _compute_d3(self, atoms):
        """Run the compiled D3 kernel on `atoms`.

        Shared by `calculate` and `predict_one` (previously duplicated).
        Returns (energy, forces, stress) in the ASE frame.
        """
        self._ensure_cell(atoms)
        cell, rotator = self._convert_domain_ase2lammps(atoms.get_cell())
        Z_of_atoms = atoms.get_atomic_numbers()
        natoms = len(atoms)
        ntypes = len(set(Z_of_atoms))
        types = (ctypes.c_int * natoms)(*self._idx_to_types(Z_of_atoms))
        # Positions must be rotated into the LAMMPS frame.
        positions = atoms.get_positions() @ rotator.T
        x_flat = (ctypes.c_double * (natoms * 3))(*positions.flatten())
        atomic_numbers = (ctypes.c_int * ntypes)(*self._idx_to_numbers(Z_of_atoms))
        boxlo = (ctypes.c_double * 3)(0.0, 0.0, 0.0)
        boxhi = (ctypes.c_double * 3)(cell[0], cell[1], cell[2])
        xy = cell[3]
        xz = cell[4]
        yz = cell[5]
        xperiodic, yperiodic, zperiodic = atoms.get_pbc()

        lib = self._lib
        assert lib is not None
        lib.pair_set_atom(self.pair, natoms, ntypes, types, x_flat)
        lib.pair_set_domain(
            self.pair,
            xperiodic.astype(int),
            yperiodic.astype(int),
            zperiodic.astype(int),
            boxlo,
            boxhi,
            xy,
            xz,
            yz,
        )
        lib.pair_run_settings(
            self.pair,
            self.rthr,
            self.cnthr,
            self.damp_name.encode('utf-8'),
            self.func_name.encode('utf-8'),
        )
        lib.pair_run_coeff(self.pair, atomic_numbers)
        lib.pair_run_compute(self.pair)

        result_E = lib.pair_get_energy(self.pair)
        result_F_ptr = lib.pair_get_force(self.pair)
        result_F = np.ctypeslib.as_array(
            result_F_ptr, shape=(natoms * 3,)
        ).reshape((natoms, 3))
        # Copy out of the C buffer, then rotate back to the ASE frame.
        result_F = np.array(result_F)
        result_F = result_F @ rotator
        result_S = lib.pair_get_stress(self.pair)
        result_S = np.array(result_S.contents)
        result_S = (
            self._tensor2stress(rotator.T @ self._stress2tensor(result_S) @ rotator)
            / atoms.get_volume()
        )
        return result_E, result_F, result_S

    def calculate(self, atoms=None, properties=None, system_changes=all_changes):
        """ASE entry point; fills self.results with the D3 correction.

        NOTE: as in the original implementation, `atoms` may be mutated in
        place when a cell has to be generated.
        """
        Calculator.calculate(self, atoms, properties, system_changes)
        if atoms is None:
            raise ValueError('No atoms to evaluate')
        result_E, result_F, result_S = self._compute_d3(atoms)
        self.results = {
            'free_energy': result_E,
            'energy': result_E,
            'forces': result_F,
            'stress': result_S,
        }

    def predict_one(self, atoms):
        """Compute the D3 correction for one structure without touching it."""
        # Fixed: the None-check must precede atoms.copy(); the original
        # copied first and raised AttributeError on None input.
        if atoms is None:
            raise ValueError('No atoms to evaluate')
        atoms = atoms.copy()  # work on a copy so the caller's atoms stay intact
        result_E, result_F, result_S = self._compute_d3(atoms)
        prediction = {
            'free_energy': float(result_E),
            'energy': float(result_E),
            'forces': result_F.copy(),
            'stress': result_S.copy(),
        }
        return prediction

    def __del__(self):
        # Release the C-side PairD3 object exactly once.
        if self._lib is not None:
            self._lib.pair_fin(self.pair)
            self._lib = None
            self.pair = None
import os
import pathlib
import uuid
import warnings
from copy import deepcopy
from datetime import datetime
from typing import Any, Dict, Optional, Union
import pandas as pd
from packaging.version import Version
from torch import Tensor
from torch import load as torch_load
import sevenn
import sevenn._const as consts
import sevenn._keys as KEY
import sevenn.scripts.backward_compatibility as compat
from sevenn import model_build
from sevenn.nn.scale import get_resolved_shift_scale
from sevenn.nn.sequential import AtomGraphSequential
def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6):
    """Assert two Atoms objects agree on size, cell, energy, forces, stress."""
    import numpy as np

    def close(a, b, r=rtol, t=atol):
        return np.allclose(a, b, rtol=r, atol=t)

    assert len(atoms1) == len(atoms2)
    assert close(atoms1.get_cell(), atoms2.get_cell())
    assert close(atoms1.get_potential_energy(), atoms2.get_potential_energy())
    # Forces and stress are compared with 10x looser tolerances.
    assert close(atoms1.get_forces(), atoms2.get_forces(), rtol * 10, atol * 10)
    assert close(
        atoms1.get_stress(voigt=False),
        atoms2.get_stress(voigt=False),
        rtol * 10,
        atol * 10,
    )
    # assert close(atoms1.get_potential_energies(), atoms2.get_potential_energies())
def copy_state_dict(state_dict) -> dict:
    """Recursively deep-copy a state dict, cloning any tensors encountered."""
    # A Tensor is neither a dict nor a list, so checking it first is
    # equivalent to the original dict/list/Tensor order.
    if isinstance(state_dict, Tensor):
        return state_dict.clone()  # type: ignore
    if isinstance(state_dict, dict):
        return {k: copy_state_dict(v) for k, v in state_dict.items()}
    if isinstance(state_dict, list):
        return [copy_state_dict(v) for v in state_dict]  # type: ignore
    # Non-tensor leaves (scalars, None, strings, ...) are returned as-is.
    return state_dict
def _config_cp_routine(config):
    """Normalize a checkpoint config dict for use with this SevenNet version.

    Warns when the checkpoint is newer than the installed package, patches
    old-format keys, resolves shift/scale, fills missing keys with model
    defaults, and moves any tensor values to CPU. Returns the patched config.
    """
    # Old checkpoints may lack the 'version' key entirely; treat them as
    # very old instead of crashing on Version(None).
    cp_ver = Version(config.get('version', '0.0.0'))
    this_ver = Version(sevenn.__version__)
    if cp_ver > this_ver:
        warnings.warn(
            f'The checkpoint version ({cp_ver}) is newer than this source '
            f'({this_ver}). This may cause unexpected behaviors'
        )
    defaults = {**consts.model_defaults(config)}
    config = compat.patch_old_config(config)  # type: ignore
    scaler = model_build.init_shift_scale(config)
    shift, scale = get_resolved_shift_scale(
        scaler, config.get(KEY.TYPE_MAP), config.get(KEY.MODAL_MAP, None)
    )
    config['shift'] = shift
    config['scale'] = scale
    for k, v in defaults.items():
        if k in config:
            continue
        if os.getenv('SEVENN_DEBUG', False):
            warnings.warn(f'{k} not in config, use default value {v}', UserWarning)
        config[k] = v
    # Keep tensors on CPU so the config is safe to serialize and compare.
    for k, v in config.items():
        if isinstance(v, Tensor):
            config[k] = v.cpu()
    return config
def _convert_e3nn_and_cueq(stct_src, stct_dst, src_config, from_cueq):
    """Copy weights between e3nn and cuEquivariance state dicts.

    manually check keys and assert if something unexpected happens

    Parameters
    ----------
    stct_src, stct_dst: dict
        source and destination model state dicts
    src_config: dict
        config of the source checkpoint (layer count, self-connection type)
    from_cueq: bool
        direction flag; kept for interface compatibility (not read here)

    Returns
    -------
    dict
        stct_dst with all transferable weights copied (reshaped) in.
    """
    n_layer = src_config['num_convolution_layer']
    linear_module_names = [
        'onehot_to_feature_x',
        'reduce_input_to_hidden',
        'reduce_hidden_to_energy',
    ]
    convolution_module_names = []
    fc_tensor_product_module_names = []
    for i in range(n_layer):
        linear_module_names.append(f'{i}_self_interaction_1')
        linear_module_names.append(f'{i}_self_interaction_2')
        if src_config.get(KEY.SELF_CONNECTION_TYPE) == 'linear':
            linear_module_names.append(f'{i}_self_connection_intro')
        elif src_config.get(KEY.SELF_CONNECTION_TYPE) == 'nequip':
            fc_tensor_product_module_names.append(f'{i}_self_connection_intro')
        convolution_module_names.append(f'{i}_convolution')

    # Rule: those keys can be safely ignored before state dict load,
    # except for linear.bias. This should be aborted in advance to
    # this function. Others are not parameters but constants.
    cue_only_linear_followers = ['linear.f.tp.f_fx.module.c']
    e3nn_only_linear_followers = ['linear.bias', 'linear.output_mask']
    ignores_in_linear = cue_only_linear_followers + e3nn_only_linear_followers

    cue_only_conv_followers = [
        'convolution.f.tp.f_fx.module.c',
        'convolution.f.tp.module.module.f.module.module._f.data',
    ]
    e3nn_only_conv_followers = [
        'convolution._compiled_main_left_right._w3j',
        'convolution.weight',
        'convolution.output_mask',
    ]
    ignores_in_conv = cue_only_conv_followers + e3nn_only_conv_followers

    cue_only_fc_followers = ['fc_tensor_product.f.tp.f_fx.module.c']
    e3nn_only_fc_followers = [
        'fc_tensor_product.output_mask',
    ]
    ignores_in_fc = cue_only_fc_followers + e3nn_only_fc_followers

    updated_keys = []
    for k, v in stct_src.items():
        module_name = k.split('.')[0]
        flag = False
        if module_name in linear_module_names:
            for ignore in ignores_in_linear:
                if '.'.join([module_name, ignore]) in k:
                    flag = True
                    break
            if not flag and k == '.'.join([module_name, 'linear.weight']):
                updated_keys.append(k)
                stct_dst[k] = v.clone().reshape(stct_dst[k].shape)
                flag = True
            assert flag, f'Unexpected key from linear: {k}'
        elif module_name in convolution_module_names:
            for ignore in ignores_in_conv:
                if '.'.join([module_name, ignore]) in k:
                    flag = True
                    break
            if not flag and (
                k.startswith(f'{module_name}.weight_nn')
                or k == '.'.join([module_name, 'denominator'])
            ):
                updated_keys.append(k)
                stct_dst[k] = v.clone().reshape(stct_dst[k].shape)
                flag = True
            # Fixed copy-paste: this branch handles convolution, not linear.
            assert flag, f'Unexpected key from convolution: {k}'
        elif module_name in fc_tensor_product_module_names:
            for ignore in ignores_in_fc:
                if '.'.join([module_name, ignore]) in k:
                    flag = True
                    break
            if not flag and k == '.'.join([module_name, 'fc_tensor_product.weight']):
                updated_keys.append(k)
                stct_dst[k] = v.clone().reshape(stct_dst[k].shape)
                flag = True
            assert flag, f'Unexpected key from fc tensor product: {k}'
        else:
            # assert k in stct_dst
            updated_keys.append(k)
            stct_dst[k] = v.clone().reshape(stct_dst[k].shape)
    return stct_dst
class SevenNetCheckpoint:
    """
    Tool box for checkpoint processed from SevenNet.

    The checkpoint file is read lazily: nothing touches the disk until one
    of the accessor properties (``config``, ``model_state_dict``, ...) is
    used for the first time, at which point ``_load`` caches everything.
    """
    def __init__(self, checkpoint_path: Union[pathlib.Path, str]):
        # Resolve once to an absolute path; all later loads use this.
        self._checkpoint_path = os.path.abspath(checkpoint_path)
        self._config = None
        self._epoch = None
        self._model_state_dict = None
        self._optimizer_state_dict = None
        self._scheduler_state_dict = None
        self._hash = None
        self._time = None
        self._loaded = False  # flipped to True by _load() on first access
    def __repr__(self) -> str:
        """Return a human-readable metadata table (empty str if no config)."""
        cfg = self.config  # just alias
        if len(cfg) == 0:
            return ''
        dct = {
            'Sevennet version': cfg.get('version', 'Not found'),
            'When': self.time,
            'Hash': self.hash,
            'Cutoff': cfg.get('cutoff'),
            'Channel': cfg.get('channel'),
            'Lmax': cfg.get('lmax'),
            'Group (parity)': 'O3' if cfg.get('is_parity') else 'SO3',
            'Interaction layers': cfg.get('num_convolution_layer'),
            'Self connection type': cfg.get('self_connection_type', 'nequip'),
            'Last epoch': self.epoch,
            'Elements': len(cfg.get('chemical_species', [])),
        }
        if cfg.get('use_modality', False):
            dct['Modality'] = ', '.join(list(cfg.get('_modal_map', {}).keys()))
        # NOTE(review): `pd` is presumably pandas imported elsewhere in this
        # module; from_dict is fed a list here — verify against pandas docs.
        df = pd.DataFrame.from_dict([dct]).T  # type: ignore
        df.columns = ['']
        return df.to_string()
    @property
    def checkpoint_path(self) -> str:
        # Absolute path of the checkpoint file.
        return str(self._checkpoint_path)
    @property
    def config(self) -> Dict[str, Any]:
        # Deep copy so callers can freely mutate what they get back.
        if not self._loaded:
            self._load()
        assert isinstance(self._config, dict)
        return deepcopy(self._config)
    @property
    def model_state_dict(self) -> Dict[str, Any]:
        # Returned as a copy (tensors cloned by copy_state_dict).
        if not self._loaded:
            self._load()
        assert isinstance(self._model_state_dict, dict)
        return copy_state_dict(self._model_state_dict)
    @property
    def optimizer_state_dict(self) -> Dict[str, Any]:
        if not self._loaded:
            self._load()
        assert isinstance(self._optimizer_state_dict, dict)
        return copy_state_dict(self._optimizer_state_dict)
    @property
    def scheduler_state_dict(self) -> Dict[str, Any]:
        if not self._loaded:
            self._load()
        assert isinstance(self._scheduler_state_dict, dict)
        return copy_state_dict(self._scheduler_state_dict)
    @property
    def epoch(self) -> Optional[int]:
        # May legitimately be None when the checkpoint lacks an 'epoch' entry.
        if not self._loaded:
            self._load()
        return self._epoch
    @property
    def time(self) -> str:
        if not self._loaded:
            self._load()
        assert isinstance(self._time, str)
        return self._time
    @property
    def hash(self) -> str:
        if not self._loaded:
            self._load()
        assert isinstance(self._hash, str)
        return self._hash
    def _load(self) -> None:
        """Read the checkpoint from disk once and cache its pieces."""
        assert not self._loaded
        cp_path = self.checkpoint_path  # just alias
        cp = torch_load(cp_path, weights_only=False, map_location='cpu')
        # NOTE(review): _config_original is first assigned here, not in
        # __init__ — consider initializing it there for consistency.
        self._config_original = cp.get('config', {})
        self._model_state_dict = cp.get('model_state_dict', {})
        self._optimizer_state_dict = cp.get('optimizer_state_dict', {})
        self._scheduler_state_dict = cp.get('scheduler_state_dict', {})
        self._epoch = cp.get('epoch', None)
        self._time = cp.get('time', 'Not found')
        self._hash = cp.get('hash', 'Not found')
        if len(self._config_original) == 0:
            warnings.warn(f'config is not found from {cp_path}')
            self._config = {}
        else:
            # Normalize old/new config formats into the current schema.
            self._config = _config_cp_routine(self._config_original)
        if len(self._model_state_dict) == 0:
            warnings.warn(f'model_state_dict is not found from {cp_path}')
        self._loaded = True
    def build_model(self, backend: Optional[str] = None) -> AtomGraphSequential:
        """Build a model from this checkpoint, optionally converting between
        e3nn and cuEquivariance backends.

        backend: None keeps the checkpoint's own backend; 'cue'/'cueq'
        requests cuEquivariance, anything else requests e3nn.
        """
        from .model_build import build_E3_equivariant_model
        # When backend is None this evaluates True, but the None case is
        # short-circuited below and the checkpoint backend is kept as-is.
        use_cue = not backend or backend.lower() in ['cue', 'cueq']
        try:
            cp_using_cue = self.config[KEY.CUEQUIVARIANCE_CONFIG]['use']
        except KeyError:
            cp_using_cue = False
        if (not backend) or (use_cue == cp_using_cue):
            # backend not given, or checkpoint backend is same as requested
            model = build_E3_equivariant_model(self.config)
            state_dict = compat.patch_state_dict_if_old(
                self.model_state_dict, self.config, model
            )
        else:
            # Backend differs: rebuild with the requested backend and
            # convert the weights between e3nn and cueq layouts.
            cfg_new = self.config
            cfg_new[KEY.CUEQUIVARIANCE_CONFIG] = {'use': use_cue}
            model = build_E3_equivariant_model(cfg_new)
            stct_src = compat.patch_state_dict_if_old(
                self.model_state_dict, self.config, model
            )
            state_dict = _convert_e3nn_and_cueq(
                stct_src, model.state_dict(), self.config, from_cueq=cp_using_cue
            )
        missing, not_used = model.load_state_dict(state_dict, strict=False)
        if len(not_used) > 0:
            warnings.warn(f'Some keys are not used: {not_used}', UserWarning)
        # Missing keys would mean an incompatible architecture; hard fail.
        assert len(missing) == 0, f'Missing keys: {missing}'
        return model
    def yaml_dict(self, mode: str) -> dict:
        """
        Return dict for input.yaml from checkpoint config
        Dataset paths and statistic values are removed intentionally

        mode: 'reproduce' | 'continue' | 'continue_modal'; the continue
        modes additionally inject a continue-from-checkpoint section.
        """
        if mode not in ['reproduce', 'continue', 'continue_modal']:
            raise ValueError(f'Unknown mode: {mode}')
        # Keys that are runtime state or dataset-specific and must not leak
        # into a regenerated input.yaml.
        ignore = [
            'when',
            KEY.DDP_BACKEND,
            KEY.LOCAL_RANK,
            KEY.IS_DDP,
            KEY.DEVICE,
            KEY.MODEL_TYPE,
            KEY.SHIFT,
            KEY.SCALE,
            KEY.CONV_DENOMINATOR,
            KEY.SAVE_DATASET,
            KEY.SAVE_BY_LABEL,
            KEY.SAVE_BY_TRAIN_VALID,
            KEY.CONTINUE,
            KEY.LOAD_DATASET,  # old
        ]
        cfg = self.config
        len_atoms = len(cfg[KEY.TYPE_MAP])
        # Checkpoint stores per-rank batch size; restore the global one.
        world_size = cfg.pop(KEY.WORLD_SIZE, 1)
        cfg[KEY.BATCH_SIZE] = cfg[KEY.BATCH_SIZE] * world_size
        cfg[KEY.LOAD_TRAINSET] = '**path_to_training_set**'
        major, minor, _ = cfg.pop('version', '0.0.0').split('.')[:3]
        if int(major) == 0 and int(minor) <= 9:
            warnings.warn('checkpoint version too old, yaml may wrong')
        ret = {'model': {}, 'train': {}, 'data': {}}
        # Route each config key into its yaml section by the defaults tables.
        for k, v in cfg.items():
            if k.startswith('_') or k in ignore or k.endswith('set_path'):
                continue
            if k in consts.DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG:
                ret['model'][k] = v
            elif k in consts.DEFAULT_TRAINING_CONFIG:
                ret['train'][k] = v
            elif k in consts.DEFAULT_DATA_CONFIG:
                ret['data'][k] = v
        ret['model'][KEY.CHEMICAL_SPECIES] = (
            'univ' if len_atoms == consts.NUM_UNIV_ELEMENT else 'auto'
        )
        ret['data'][KEY.LOAD_TRAINSET] = '**path_to_trainset**'
        ret['data'][KEY.LOAD_VALIDSET] = '**path_to_validset**'
        # TODO
        ret['data'][KEY.SHIFT] = '**failed to infer shift, should be set**'
        ret['data'][KEY.SCALE] = '**failed to infer scale, should be set**'
        if mode.startswith('continue'):
            ret['train'].update(
                {KEY.CONTINUE: {KEY.CHECKPOINT: self.checkpoint_path}}
            )
        modal_names = None
        if mode == 'continue_modal' and not cfg.get(KEY.USE_MODALITY, False):
            # Checkpoint has no modality yet: emit suggested defaults.
            ret['train'][KEY.USE_MODALITY] = True
            # suggest defaults
            ret['model'][KEY.USE_MODAL_NODE_EMBEDDING] = False
            ret['model'][KEY.USE_MODAL_SELF_INTER_INTRO] = True
            ret['model'][KEY.USE_MODAL_SELF_INTER_OUTRO] = True
            ret['model'][KEY.USE_MODAL_OUTPUT_BLOCK] = True
            ret['data'][KEY.USE_MODAL_WISE_SHIFT] = True
            ret['data'][KEY.USE_MODAL_WISE_SCALE] = False
            modal_names = ['my_modal1', 'my_modal2']
        elif cfg.get(KEY.USE_MODALITY, False):
            modal_names = list(cfg[KEY.MODAL_MAP].keys())
        if modal_names:
            # Per-modal trainset placeholders.
            ret['data'][KEY.LOAD_TRAINSET] = [
                {'data_modality': mm, 'file_list': [{'file': f'**path_to_{mm}**'}]}
                for mm in modal_names
            ]
        return ret
    def append_modal(
        self,
        dst_config,
        original_modal_name: str = 'origin',
        working_dir: str = os.getcwd(),
    ):
        """Extend this (possibly non-modal) checkpoint with new modalities.

        dst_config is updated in place (modal map, shift/scale, chemistry)
        and a continuable state dict with appended modal weights returned.
        """
        import sevenn.train.modal_dataset as modal_dataset
        from sevenn.model_build import init_shift_scale
        from sevenn.scripts.convert_model_modality import _append_modal_weight
        src_config = self.config
        src_has_no_modal = not src_config.get(KEY.USE_MODALITY, False)
        # inherit element things first
        chem_keys = [
            KEY.TYPE_MAP,
            KEY.NUM_SPECIES,
            KEY.CHEMICAL_SPECIES,
            KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER,
        ]
        dst_config.update({k: src_config[k] for k in chem_keys})
        # Modal-wise shift/scale requires the user to have given a string
        # spec (e.g. a method name) rather than numbers.
        if dst_config[KEY.USE_MODAL_WISE_SHIFT] and (
            KEY.SHIFT not in dst_config or not isinstance(dst_config[KEY.SHIFT], str)
        ):
            raise ValueError('To use modal wise shift, keyword shift is required')
        if dst_config[KEY.USE_MODAL_WISE_SCALE] and (
            KEY.SCALE not in dst_config or not isinstance(dst_config[KEY.SCALE], str)
        ):
            raise ValueError('To use modal wise scale, keyword scale is required')
        if src_has_no_modal and not dst_config[KEY.USE_MODAL_WISE_SHIFT]:
            dst_config[KEY.SHIFT] = src_config[KEY.SHIFT]
        if src_has_no_modal and not dst_config[KEY.USE_MODAL_WISE_SCALE]:
            dst_config[KEY.SCALE] = src_config[KEY.SCALE]
        # get statistics of given datasets of yaml
        # dst_config updated
        _ = modal_dataset.from_config(dst_config, working_dir=working_dir)
        dst_modal_map = dst_config[KEY.MODAL_MAP]
        found_modal_names = list(dst_modal_map.keys())
        if len(found_modal_names) == 0:
            raise ValueError('No modality is found from config')
        # Check difference btw given modals and new modal map
        orig_modal_map = src_config.get(KEY.MODAL_MAP, {original_modal_name: 0})
        assert isinstance(orig_modal_map, dict)
        new_modal_map = orig_modal_map.copy()
        for modal_name in found_modal_names:
            if modal_name in orig_modal_map:  # duplicate, skipping
                continue
            new_modal_map[modal_name] = len(new_modal_map)  # assign new
        print(f'New modals: {list(new_modal_map.keys())}')
        if src_has_no_modal:
            append_num = len(new_modal_map)
        else:
            append_num = len(new_modal_map) - len(orig_modal_map)
        if append_num == 0:
            raise ValueError('Nothing to append from checkpoint')
        dst_config[KEY.NUM_MODALITIES] = len(new_modal_map)
        dst_config[KEY.MODAL_MAP] = new_modal_map
        # update dst_config's shift scales based on src_config
        for ss_key, use_mw in (
            (KEY.SHIFT, dst_config[KEY.USE_MODAL_WISE_SHIFT]),
            (KEY.SCALE, dst_config[KEY.USE_MODAL_WISE_SCALE]),
        ):
            if not use_mw:  # not using mw ss, just assign
                assert not isinstance(dst_config[ss_key], dict)
                dst_config[ss_key] = src_config[ss_key]
            elif src_has_no_modal:
                assert isinstance(dst_config[ss_key], dict)
                # mw ss, update by dict but use original_modal_name
                dst_config[ss_key].update({original_modal_name: src_config[ss_key]})
            else:
                assert isinstance(dst_config[ss_key], dict)
                # mw ss, update by dict
                dst_config[ss_key].update(src_config[ss_key])
        scaler = init_shift_scale(dst_config)
        # finally, prepare updated continuable state dict using above
        orig_model = self.build_model()
        orig_state_dict = orig_model.state_dict()
        new_state_dict = copy_state_dict(orig_state_dict)
        for stct_key in orig_state_dict:
            sp = stct_key.split('.')
            k, follower = sp[0], '.'.join(sp[1:])
            if k == 'rescale_atomic_energy' and follower == 'shift':
                new_state_dict[stct_key] = scaler.shift.clone()
            elif k == 'rescale_atomic_energy' and follower == 'scale':
                new_state_dict[stct_key] = scaler.scale.clone()
            elif follower == 'linear.weight' and (  # append linear layer
                (
                    dst_config[KEY.USE_MODAL_NODE_EMBEDDING]
                    and k.endswith('onehot_to_feature_x')
                )
                or (
                    dst_config[KEY.USE_MODAL_SELF_INTER_INTRO]
                    and k.endswith('self_interaction_1')
                )
                or (
                    dst_config[KEY.USE_MODAL_SELF_INTER_OUTRO]
                    and k.endswith('self_interaction_2')
                )
                or (
                    dst_config[KEY.USE_MODAL_OUTPUT_BLOCK]
                    and k == 'reduce_input_to_hidden'
                )
            ):
                orig_linear = getattr(orig_model._modules[k], 'linear')
                # assert normalization element
                new_state_dict[stct_key] = _append_modal_weight(
                    orig_state_dict,
                    k,
                    orig_linear.irreps_in,
                    orig_linear.irreps_out,
                    append_num,
                )
        dst_config['version'] = sevenn.__version__
        return new_state_dict
    def get_checkpoint_dict(self) -> dict:
        """
        Return duplicate of this checkpoint with new hash and time.
        Convenient for creating variant of the checkpoint
        """
        return {
            'config': self.config,
            'epoch': self.epoch,
            'model_state_dict': self.model_state_dict,
            'optimizer_state_dict': self.optimizer_state_dict,
            'scheduler_state_dict': self.scheduler_state_dict,
            'time': datetime.now().strftime('%Y-%m-%d %H:%M'),
            'hash': uuid.uuid4().hex,
        }
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
import torch.distributed as dist
import sevenn._keys as KEY
from sevenn.train.loss import LossDefinition
from .atom_graph_data import AtomGraphData
from .train.optim import loss_dict
# Registry of predefined error-metric specs, consumed via get_err_type().
# Fields: ref_key/pred_key select tensors from the model output graph,
# 'coeff' rescales values into the display unit, 'vdim' is the vector
# dimension per item, 'per_atom' divides by the structure's atom count.
_ERROR_TYPES = {
    'TotalEnergy': {
        'name': 'Energy',
        'ref_key': KEY.ENERGY,
        'pred_key': KEY.PRED_TOTAL_ENERGY,
        'unit': 'eV',
        'vdim': 1,
    },
    'Energy': {  # by default per-atom for energy
        'name': 'Energy',
        'ref_key': KEY.ENERGY,
        'pred_key': KEY.PRED_TOTAL_ENERGY,
        'unit': 'eV/atom',
        'per_atom': True,
        'vdim': 1,
    },
    'Force': {
        'name': 'Force',
        'ref_key': KEY.FORCE,
        'pred_key': KEY.PRED_FORCE,
        'unit': 'eV/Å',
        'vdim': 3,
    },
    'Stress': {
        'name': 'Stress',
        'ref_key': KEY.STRESS,
        'pred_key': KEY.PRED_STRESS,
        'unit': 'kbar',
        # presumably eV/Å³ → kbar; consistent with the GPa factor below ×10
        'coeff': 1602.1766208,
        'vdim': 6,  # Voigt notation
    },
    'Stress_GPa': {
        'name': 'Stress',
        'ref_key': KEY.STRESS,
        'pred_key': KEY.PRED_STRESS,
        'unit': 'GPa',
        'coeff': 160.21766208,
        'vdim': 6,
    },
    'TotalLoss': {
        'name': 'TotalLoss',
        'unit': None,
    },
}
def get_err_type(name: str) -> Dict[str, Any]:
    """Return a fresh, mutation-safe copy of the named error-type spec.

    Raises KeyError if *name* is not a key of ``_ERROR_TYPES``.
    """
    spec = _ERROR_TYPES[name]
    return deepcopy(spec)
def _get_loss_function_from_name(loss_functions, name):
for loss_def, w in loss_functions:
if loss_def.name.lower() == name.lower():
return loss_def, w
return None, None
class AverageNumber:
    """Streaming average of tensor elements, DDP-reducible across ranks."""

    def __init__(self):
        self._sum = 0.0
        self._count = 0

    def update(self, values: torch.Tensor):
        """Fold every element of *values* into the running average."""
        self._count += values.numel()
        self._sum += values.sum().item()

    def _ddp_reduce(self, device):
        # Sum both the numerator and the element count over all ranks so
        # get() afterwards yields the global average.
        total = torch.tensor(self._sum, device=device)
        count = torch.tensor(self._count, device=device)
        dist.all_reduce(total, op=dist.ReduceOp.SUM)
        dist.all_reduce(count, op=dist.ReduceOp.SUM)
        self._sum = total.item()
        self._count = count.item()

    def get(self):
        """Return the current average, or NaN when nothing was recorded."""
        if self._count == 0:
            return torch.nan
        return self._sum / self._count
class ErrorMetric:
    """
    Base class for error metrics. We always average error by # of structures,
    and designed to collect errors in the middle of iteration (by AverageNumber)

    Subclasses implement update(); common extraction/scaling lives in
    _retrieve().
    """
    def __init__(
        self,
        name: str,
        ref_key: str,
        pred_key: str,
        coeff: float = 1.0,
        unit: Optional[str] = None,
        per_atom: bool = False,
        ignore_unlabeled: bool = True,
        **kwargs,
    ):
        # name/unit are for display; ref_key/pred_key index the output graph.
        self.name = name
        self.unit = unit
        self.coeff = coeff  # rescales raw values into `unit`
        self.ref_key = ref_key
        self.pred_key = pred_key
        self.per_atom = per_atom
        self.ignore_unlabeled = ignore_unlabeled
        self.value = AverageNumber()  # running-average accumulator
    def update(self, output: AtomGraphData):
        # Subclass responsibility: fold one batch's errors into self.value.
        raise NotImplementedError
    def _retrieve(self, output: AtomGraphData):
        """Extract scaled (reference, prediction) pairs from a model output.

        Applies coeff, optional per-atom normalization, and drops entries
        whose reference is NaN when ignore_unlabeled is set.
        """
        y_ref = output[self.ref_key] * self.coeff
        y_pred = output[self.pred_key] * self.coeff
        if self.per_atom:
            # Only meaningful for per-structure scalars (e.g. total energy).
            assert y_ref.dim() == 1 and y_pred.dim() == 1
            natoms = output[KEY.NUM_ATOMS]
            y_ref = y_ref / natoms
            y_pred = y_pred / natoms
        if self.ignore_unlabeled:
            # NaN reference marks an unlabeled sample; mask it out of both.
            unlabelled_idx = torch.isnan(y_ref)
            y_ref = y_ref[~unlabelled_idx]
            y_pred = y_pred[~unlabelled_idx]
        return y_ref, y_pred
    def ddp_reduce(self, device):
        # Aggregate the accumulator across distributed ranks.
        self.value._ddp_reduce(device)
    def reset(self):
        # Start a fresh accumulation (e.g. at epoch boundaries).
        self.value = AverageNumber()
    def get(self):
        # Current averaged error (NaN if nothing accumulated).
        return self.value.get()
    def key_str(self, with_unit=True):
        """Display label, optionally suffixed with the unit."""
        if self.unit is None or not with_unit:
            return self.name
        else:
            return f'{self.name} ({self.unit})'
    def __str__(self):
        return f'{self.key_str()}: {self.value.get():.6f}'
class RMSError(ErrorMetric):
    """
    Root-mean-square error over vectors of dimension `vdim`.

    Squared component errors are summed per vector, averaged over all
    vectors, and square-rooted in get().
    """
    def __init__(self, vdim: int = 1, **kwargs):
        super().__init__(**kwargs)
        self.vdim = vdim
        self._se = torch.nn.MSELoss(reduction='none')

    def _square_error(self, y_ref, y_pred, vdim: int):
        ref = y_ref.view(-1, vdim)
        pred = y_pred.view(-1, vdim)
        return self._se(ref, pred).sum(dim=1)

    def update(self, output: AtomGraphData):
        """Accumulate one batch's per-vector squared errors."""
        ref, pred = self._retrieve(output)
        self.value.update(self._square_error(ref, pred, self.vdim))

    def get(self):
        """sqrt of the running mean squared error."""
        return self.value.get() ** 0.5
class ComponentRMSError(ErrorMetric):
    """
    RMSE that flattens vectors and averages over individual components.
    Yields a smaller number than the vector-wise RMSError.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._se = torch.nn.MSELoss(reduction='none')

    def _square_error(self, y_ref, y_pred):
        return self._se(y_ref, y_pred)

    def update(self, output: AtomGraphData):
        """Accumulate one batch's per-component squared errors."""
        ref, pred = self._retrieve(output)
        flat_ref, flat_pred = ref.view(-1), pred.view(-1)
        self.value.update(self._square_error(flat_ref, flat_pred))

    def get(self):
        """sqrt of the running mean squared component error."""
        return self.value.get() ** 0.5
class MAError(ErrorMetric):
    """
    Mean absolute error, averaged over all components.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _abs_error(self, y_ref, y_pred):
        # Renamed from `_square_error`: this metric computes the absolute
        # error, and the old name misleadingly suggested otherwise.
        return torch.abs(y_ref - y_pred)

    def update(self, output: AtomGraphData):
        """Accumulate one batch's per-component absolute errors."""
        y_ref, y_pred = self._retrieve(output)
        y_ref = y_ref.reshape((-1,))
        y_pred = y_pred.reshape((-1,))
        self.value.update(self._abs_error(y_ref, y_pred))
class CustomError(ErrorMetric):
    """
    Error metric backed by a user-supplied error function.
    Args:
        func: a function that takes y_ref and y_pred
            and returns a list of errors
    """
    def __init__(self, func: Callable, **kwargs):
        super().__init__(**kwargs)
        self.func = func

    def update(self, output: AtomGraphData):
        """Apply the custom error function to one batch and accumulate."""
        y_ref, y_pred = self._retrieve(output)
        if len(y_ref) > 0:
            errors = self.func(y_ref, y_pred)
        else:
            # Nothing labeled in this batch: contribute an empty tensor.
            errors = torch.tensor([])
        self.value.update(errors)
class LossError(ErrorMetric):
    """
    Error metric that records the value of a training loss function.
    """
    def __init__(
        self,
        name: str,
        loss_def: LossDefinition,
        **kwargs,
    ):
        # Bug fix: the keyword was misspelled 'ignore_unlabeld', so it was
        # silently swallowed by ErrorMetric.__init__'s **kwargs and the
        # loss definition's ignore_unlabeled setting was never applied.
        super().__init__(
            name,
            ignore_unlabeled=loss_def.ignore_unlabeled,
            **kwargs,
        )
        self.loss_def = loss_def
    def update(self, output: AtomGraphData):
        """Accumulate the loss value computed for one batch."""
        loss = self.loss_def.get_loss(output)  # type: ignore
        self.value.update(loss)  # type: ignore
class CombinedError(ErrorMetric):
    """
    Weighted combination of error metrics; get() returns the weighted sum
    (normally used to report a total loss).
    """
    def __init__(self, metrics: List[Tuple[ErrorMetric, float]], **kwargs):
        super().__init__(**kwargs)
        self.metrics = metrics
        # A combined value has no single physical unit.
        assert kwargs['unit'] is None

    def update(self, output: AtomGraphData):
        for inner, _ in self.metrics:
            inner.update(output)

    def reset(self):
        for inner, _ in self.metrics:
            inner.reset()

    def ddp_reduce(self, device):  # override
        # Reduce each child accumulator; this metric holds no state itself.
        for inner, _ in self.metrics:
            inner.value._ddp_reduce(device)

    def get(self):
        total = 0.0
        for inner, weight in self.metrics:
            total += inner.get() * weight
        return total
class ErrorRecorder:
    """
    Record errors of a model across an epoch.

    Holds a list of ErrorMetric instances, updates them batch by batch,
    and archives each epoch's snapshot in self.history.
    """
    # Maps metric-name strings from config to metric classes.
    METRIC_DICT = {
        'RMSE': RMSError,
        'ComponentRMSE': ComponentRMSError,
        'MAE': MAError,
        'Loss': LossError,
    }
    def __init__(self, metrics: List[ErrorMetric]):
        self.history = []  # one get_current() snapshot per finished epoch
        self.metrics = metrics
    def _update(self, output: AtomGraphData):
        for metric in self.metrics:
            metric.update(output)
    def update(self, output: AtomGraphData, no_grad=True):
        """Feed one model output to every metric (gradient-free by default)."""
        if no_grad:
            with torch.no_grad():
                self._update(output)
        else:
            self._update(output)
    def get_metric_dict(self, with_unit=True):
        """Map display label -> current averaged value for every metric."""
        return {metric.key_str(with_unit): metric.get() for metric in self.metrics}
    def get_current(self):
        """Structured snapshot (value, unit, keys) keyed by metric name."""
        dct = {}
        for metric in self.metrics:
            dct[metric.name] = {
                'value': metric.get(),
                'unit': metric.unit,
                'ref_key': metric.ref_key,
                'pred_key': metric.pred_key,
            }
        return dct
    def get_dct(self, prefix=''):
        """Flat {prefix_name: formatted value} dict, e.g. for CSV logging."""
        dct = {}
        if prefix.endswith('_') is False and prefix != '':
            prefix = prefix + '_'
        for metric in self.metrics:
            dct[f'{prefix}{metric.name}'] = f'{metric.get():6f}'
        return dct
    def get_key_str(self, name: str):
        # Display label for the metric called `name`, or None if absent.
        for metric in self.metrics:
            if name == metric.name:
                return metric.key_str()
        return None
    def epoch_forward(self):
        """Archive this epoch's values, reset all metrics, return a printable dict."""
        self.history.append(self.get_current())
        pretty = self.get_metric_dict(with_unit=True)
        for metric in self.metrics:
            metric.reset()
        return pretty  # for print
    @staticmethod
    def init_total_loss_metric(
        config,
        criteria: Optional[Callable] = None,
        loss_functions: Optional[List[Tuple[LossDefinition, float]]] = None,
    ):
        """Build the weighted 'TotalLoss' CombinedError from either a bare
        criteria callable or explicit (loss_def, weight) pairs."""
        if criteria is None and loss_functions is None:
            raise ValueError('both criteria and loss functions not given')
        is_stress = config[KEY.IS_TRAIN_STRESS]
        metrics = []
        if criteria is not None:
            # Legacy path: one criteria applied to energy/force(/stress)
            # with weights taken from config.
            energy_metric = CustomError(criteria, **get_err_type('Energy'))
            metrics.append((energy_metric, 1))
            force_metric = CustomError(criteria, **get_err_type('Force'))
            metrics.append((force_metric, config[KEY.FORCE_WEIGHT]))
            if is_stress:
                stress_metric = CustomError(criteria, **get_err_type('Stress'))
                metrics.append((stress_metric, config[KEY.STRESS_WEIGHT]))
        else:  # TODO: this is hard-coded
            for efs in ['Energy', 'Force', 'Stress']:
                if efs == 'Stress' and not is_stress:
                    continue
                lf, w = _get_loss_function_from_name(loss_functions, efs)
                if lf is None:
                    raise ValueError(f'{efs} not found from loss_functions')
                metric = LossError(loss_def=lf, **get_err_type(efs))
                metrics.append((metric, w))
        total_loss_metric = CombinedError(
            metrics, name='TotalLoss', unit=None, ref_key=None, pred_key=None
        )
        return total_loss_metric
    @staticmethod
    def from_config(config: dict, loss_functions=None):
        """Build an ErrorRecorder from the training config's error_record list."""
        loss_cls = loss_dict[config.get(KEY.LOSS, 'mse').lower()]
        loss_param = config.get(KEY.LOSS_PARAM, {})
        # criteria is only used on the legacy path (no explicit loss_functions)
        criteria = loss_cls(**loss_param) if loss_functions is None else None
        err_config = config.get(KEY.ERROR_RECORD, False)
        if not err_config:
            raise ValueError(
                'No error_record config found. Consider util.get_error_recorder'
            )
        # Drop stress entries when stress training is disabled.
        err_config_n = []
        if not config.get(KEY.IS_TRAIN_STRESS, True):
            for err_type, metric_name in err_config:
                if 'Stress' in err_type:
                    continue
                err_config_n.append((err_type, metric_name))
            err_config = err_config_n
        err_metrics = []
        for err_type, metric_name in err_config:
            metric_kwargs = get_err_type(err_type)
            if err_type == 'TotalLoss':  # special case
                err_metrics.append(
                    ErrorRecorder.init_total_loss_metric(
                        config, criteria, loss_functions
                    )
                )
                continue
            metric_cls = ErrorRecorder.METRIC_DICT[metric_name]
            assert isinstance(metric_kwargs['name'], str)
            if metric_name == 'Loss':
                if loss_functions is not None:
                    metric_cls = LossError
                    metric_kwargs['loss_def'], _ = _get_loss_function_from_name(
                        loss_functions, metric_kwargs['name']
                    )
                else:
                    metric_cls = CustomError
                    metric_kwargs['func'] = criteria
                # Loss values carry no physical unit.
                metric_kwargs.pop('unit', None)
            # Disambiguate e.g. Energy_RMSE vs Energy_MAE in reports.
            metric_kwargs['name'] += f'_{metric_name}'
            err_metrics.append(metric_cls(**metric_kwargs))
        return ErrorRecorder(err_metrics)
import os
import time
import traceback
from datetime import datetime
from typing import Any, Dict, List, Optional
from ase.data import atomic_numbers
import sevenn._keys as KEY
from sevenn import __version__
# Inverse of ase.data.atomic_numbers: atomic number (int) -> chemical symbol.
CHEM_SYMBOLS = {v: k for k, v in atomic_numbers.items()}
class Singleton(type):
    """Metaclass that caps each class at a single shared instance.

    The first instantiation is cached; every later call returns it.
    """
    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            instance = super().__call__(*args, **kwargs)
            cls._instances[cls] = instance
        return cls._instances[cls]
class Logger(metaclass=Singleton):
    """Process-wide (singleton) training logger.

    Rank 0 writes to a logfile and optionally the screen; all other ranks
    are silenced. Use as a context manager to open/close the logfile.
    """
    SCREEN_WIDTH = 120  # half size of my screen / changed due to stress output
    def __init__(
        self, filename: Optional[str] = None, screen: bool = False, rank: int = 0
    ):
        self.rank = rank
        self._filename = filename
        if rank == 0:
            # File is opened lazily in __enter__, not here.
            # if filename is not None:
            # self.logfile = open(filename, 'a', buffering=1)
            self.logfile = None
            self.files = {}
            self.screen = screen
        else:
            # Non-zero ranks never write anywhere.
            self.logfile = None
            self.screen = False
        self.timer_dct = {}  # name -> start datetime, see timer_start/end
        self.active = True
    def __enter__(self):
        """Open the logfile (rank 0 only); failures fall back to no file."""
        if self.rank != 0:
            return self
        if self.logfile is None and self._filename is not None:
            try:
                self.logfile = open(
                    self._filename, 'a', buffering=1, encoding='utf-8'
                )
            except IOError as e:
                print(f'Failed to re-open log file {self._filename}: {e}')
                self.logfile = None
        self.files = {}
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        """Close the logfile and any auxiliary CSV files (rank 0 only)."""
        # NOTE: parameter `traceback` shadows the module import of the
        # same name, which is unused here.
        if self.rank != 0:
            return self
        try:
            if self.logfile is not None:
                self.logfile.close()
                self.logfile = None
            for f in self.files.values():
                f.close()
        except IOError as e:
            print(f'Failed to close log files: {e}')
        finally:
            self.logfile = None
            self.files = {}
    def switch_file(self, new_filename: str):
        """Point the logger at a new file; takes effect on next __enter__."""
        if self.rank != 0:
            return self
        if self.logfile is not None:
            raise ValueError('Current logfile is not yet closed')
        self._filename = new_filename
        return self
    def write(self, content: str):
        """Write raw content (no newline added) to file and/or screen."""
        if self.rank != 0:
            return
        # no newline!
        if self.logfile is not None and self.active:
            self.logfile.write(content)
        if self.screen and self.active:
            print(content, end='')
    def writeline(self, content: str):
        """write() with a trailing newline appended."""
        content = content + '\n'
        self.write(content)
    def init_csv(self, filename: str, header: list):
        """
        Deprecated
        """
        if self.rank == 0:
            self.files[filename] = open(filename, 'w', buffering=1, encoding='utf-8')
            self.files[filename].write(','.join(header) + '\n')
        else:
            pass
    def append_csv(self, filename: str, content: list, decimal: int = 6):
        """
        Deprecated
        """
        if self.rank == 0:
            if filename not in self.files:
                # NOTE(review): opened without encoding='utf-8', unlike
                # init_csv above — likely an oversight.
                self.files[filename] = open(filename, 'a', buffering=1)
            str_content = []
            for c in content:
                if isinstance(c, float):
                    str_content.append(f'{c:.{decimal}f}')
                else:
                    str_content.append(str(c))
            self.files[filename].write(','.join(str_content) + '\n')
        else:
            pass
    def natoms_write(self, natoms: Dict[str, Dict]):
        """Log per-label atom counts plus species-wise and grand totals."""
        content = ''
        total_natom = {}
        for label, natom in natoms.items():
            content += self.format_k_v(label, natom)
            for specie, num in natom.items():
                try:
                    total_natom[specie] += num
                except KeyError:
                    total_natom[specie] = num
        content += self.format_k_v('Total, label wise', total_natom)
        content += self.format_k_v('Total', sum(total_natom.values()))
        self.write(content)
    def statistic_write(self, statistic: Dict[str, Dict]):
        """Log nested statistics, skipping '_'-prefixed keys; floats to 3dp."""
        content = ''
        for label, dct in statistic.items():
            if label.startswith('_'):
                continue
            if not isinstance(dct, dict):
                continue
            dct_new = {}
            for k, v in dct.items():
                if k.startswith('_'):
                    continue
                if isinstance(v, int):
                    dct_new[k] = v
                else:
                    dct_new[k] = f'{v:.3f}'
            content += self.format_k_v(label, dct_new)
        self.write(content)
    # TODO : refactoring!!!, this is not loss, rmse
    def epoch_write_specie_wise_loss(self, train_loss, valid_loss):
        """Log a per-species table of force errors (energy/stress columns
        are dashed out, since those are not species-resolved)."""
        lb_pad = 21
        fs = 6
        pad = 21 - fs
        ln = '-' * fs
        total_atom_type = train_loss.keys()
        content = ''
        for at in total_atom_type:
            t_F = train_loss[at]
            v_F = valid_loss[at]
            at_sym = CHEM_SYMBOLS[at]  # atomic number -> symbol
            content += '{label:{lb_pad}}{t_E:<{pad}.{fs}s}{v_E:<{pad}.{fs}s}'.format(
                label=at_sym, t_E=ln, v_E=ln, lb_pad=lb_pad, pad=pad, fs=fs
            ) + '{t_F:<{pad}.{fs}f}{v_F:<{pad}.{fs}f}'.format(
                t_F=t_F, v_F=v_F, pad=pad, fs=fs
            )
            content += '{t_S:<{pad}.{fs}s}{v_S:<{pad}.{fs}s}'.format(
                t_S=ln, v_S=ln, pad=pad, fs=fs
            )
            content += '\n'
        self.write(content)
    def write_full_table(
        self,
        dict_list: List[Dict],
        row_labels: List[str],
        decimal_places: int = 6,
        pad: int = 2,
    ):
        """
        Assume data_list is list of dict with same keys

        Renders an aligned table: one row per dict, columns from the keys.
        """
        assert len(dict_list) == len(row_labels)
        label_len = max(map(len, row_labels))
        # Extract the column names and create a 2D array of values
        col_names = list(dict_list[0].keys())
        values = [list(d.values()) for d in dict_list]
        # Format the numbers with the given decimal places
        formatted_values = [
            [f'{value:.{decimal_places}f}' for value in row] for row in values
        ]
        # Calculate padding lengths for each column (with extra padding)
        max_col_lengths = [
            max(len(str(value)) for value in col) + pad
            for col in zip(col_names, *formatted_values)
        ]
        # Create header row and separator
        header = ' ' * (label_len + pad) + ' '.join(
            col_name.ljust(pad) for col_name, pad in zip(col_names, max_col_lengths)
        )
        separator = '-'.join('-' * pad for pad in max_col_lengths) + '-' * (
            label_len + pad
        )
        # Print header and separator
        self.writeline(header)
        self.writeline(separator)
        # Print the data rows with row labels
        for row_label, row in zip(row_labels, formatted_values):
            data_row = ' '.join(
                value.rjust(pad) for value, pad in zip(row, max_col_lengths)
            )
            self.writeline(f'{row_label.ljust(label_len)}{data_row}')
    def format_k_v(self, key: Any, val: Any, write: bool = False):
        """
        key and val should be str convertible

        Formats 'key : val' with the key left-padded; long values are
        wrapped at SCREEN_WIDTH onto continuation lines. Returns the
        string, or writes it and returns '' when write=True.
        """
        MAX_KEY_SIZE = 20
        SEPARATOR = ', '
        EMPTY_PADDING = ' ' * (MAX_KEY_SIZE + 3)
        NEW_LINE_LEN = Logger.SCREEN_WIDTH - 5
        key = str(key)
        val = str(val)
        content = f'{key:<{MAX_KEY_SIZE}}: {val}'
        if len(content) > NEW_LINE_LEN:
            content = f'{key:<{MAX_KEY_SIZE}}: '
            # septate val by separator
            val_list = val.split(SEPARATOR)
            current_len = len(content)
            for val_compo in val_list:
                current_len += len(val_compo)
                if current_len > NEW_LINE_LEN:
                    newline_content = f'{EMPTY_PADDING}{val_compo}{SEPARATOR}'
                    content += f'\\\n{newline_content}'
                    current_len = len(newline_content)
                else:
                    content += f'{val_compo}{SEPARATOR}'
            if content.endswith(f'{SEPARATOR}'):
                content = content[: -len(SEPARATOR)]
        content += '\n'
        if write is False:
            return content
        else:
            self.write(content)
            return ''
    def greeting(self):
        """Print the SevenNet banner (ASCII logo + version + timestamp)."""
        LOGO_ASCII_FILE = f'{os.path.dirname(__file__)}/logo_ascii'
        with open(LOGO_ASCII_FILE, 'r') as logo_f:
            logo_ascii = logo_f.read()
        content = 'SevenNet: Scalable EquiVariance-Enabled Neural Network\n'
        content += f'version {__version__}, {time.ctime()}\n'
        self.write(content)
        self.write(logo_ascii)
    def bar(self):
        """Print a full-width horizontal rule."""
        content = '-' * Logger.SCREEN_WIDTH + '\n'
        self.write(content)
    def print_config(
        self,
        model_config: Dict[str, Any],
        data_config: Dict[str, Any],
        train_config: Dict[str, Any],
    ):
        """
        print some important information from config
        """
        content = 'successfully read yaml config!\n\n' + 'from model configuration\n'
        for k, v in model_config.items():
            content += self.format_k_v(k, str(v))
        content += '\nfrom train configuration\n'
        for k, v in train_config.items():
            content += self.format_k_v(k, str(v))
        content += '\nfrom data configuration\n'
        for k, v in data_config.items():
            content += self.format_k_v(k, str(v))
        self.write(content)
    # TODO: This is not good make own exception
    def error(self, e: Exception):
        """Log an exception; non-ValueError gets a full traceback."""
        content = ''
        if type(e) is ValueError:
            content += 'Error occurred!\n'
            content += str(e) + '\n'
        else:
            content += 'Unknown error occurred!\n'
            content += traceback.format_exc()
        self.write(content)
    def timer_start(self, name: str):
        """Record the start time for the named timer."""
        self.timer_dct[name] = datetime.now()
    def timer_end(self, name: str, message: str, remove: bool = True):
        """
        print f"{message}: {elapsed}"
        """
        elapsed = str(datetime.now() - self.timer_dct[name])
        # elapsed = elapsed.strftime('%H-%M-%S')
        if remove:
            del self.timer_dct[name]
        # [:-4] trims microseconds down to 2 digits
        self.write(f'{message}: {elapsed[:-4]}\n')
    # TODO: print it without config
    # TODO: refactoring, readout part name :(
    def print_model_info(self, model, config):
        """Log irreps of each layer plus the trainable-parameter count."""
        from functools import partial
        kv_write = partial(self.format_k_v, write=True)
        self.writeline('Irreps of features')
        kv_write('edge_feature', model.get_irreps_in('edge_embedding', 'irreps_out'))
        for i in range(config[KEY.NUM_CONVOLUTION]):
            kv_write(
                f'{i}th node',
                model.get_irreps_in(f'{i}_self_interaction_1'),
            )
        i = config[KEY.NUM_CONVOLUTION] - 1
        kv_write(
            'readout irreps',
            model.get_irreps_in(f'{i}_equivariant_gate', 'irreps_out'),
        )
        num_weights = sum(p.numel() for p in model.parameters() if p.requires_grad)
        self.writeline(f'# learnable parameters: {num_weights}\n')
****
******** .
*//////, .. . ,*.
,,***. .. , ********. ./,
. . .. /////. ., . *///////// /////////.
.&@&/ . .(((((((.. / *//////*. ... *((((((((((.
@@@@@@@@@@* @@@@@@@@@@ @@@@@ *((@@@@@ ( %@@@@@@@@@@ .@@@@@@ ..@@@@. @@@@@@* .(@@@@@(((*
@@@@@. @@@@ @@@@@ . @@@@@ # %@@@@ @@@@@@@@ @@@@(, @@@@@@@@. @@@@@(*.
%@@@@@@@& @@@@@@@@@@ @@@@@ @@@@@ # ., .%@@@@@@@@@ @@@@@@@@@@ @@@@, @@@@@@@@@@ @@@@@
,(%@@@@@@@@@ @@@@@@@@@@ @@@@@ @@@@& (//////%@@@@@@@@@ @@@@ @@@@@@ @@@@ . @@@@@ @@@@@.@@@@@
. @@@@@ @@@@ . . @@@@@@@@% . . ( .////,%@@@@ @@@@ @@@@@@@@@ @@@@@ @@@@@@@@@
(@@@@@@@@@@@ @@@@@@@@@@**. @@@@@@* *. .%@@@@@@@@@@ @@@@ . @@@@@@@ @@@@@ .@@@@@@@
@@@@@@@@@. @@@@@@@@@@///, @@@@. . / %@@@@@@@@@@ @@@@***, @@@@@ @@@@@ @@@@@
. //////////*. / . .*******... . ,.
.&&&&&... ,//////*. ...////. / ,*/. . ,////, .,/////
&@@@@@@ ,(/((, * ,((((((. .***.
,/@(, .. * ,((((*
,
.
import argparse
import os
import sys
import time
from sevenn import __version__
# Help strings for the `sevenn train` argparse subcommand (see
# cmd_parser_train below).
description = 'train a model given the input.yaml'
input_yaml_help = 'input.yaml for training'
mode_help = 'main training script to run. Default is train.'
working_dir_help = 'path to write output. Default is cwd.'
screen_help = 'print log to stdout'
distributed_help = 'set this flag if it is distributed training'
distributed_backend_help = 'backend for distributed training. Supported: nccl, mpi'
# Metainfo will be saved to checkpoint
global_config = {
    'version': __version__,
    'when': time.ctime(),
    '_model_type': 'E3_equivariant_model',
}
def run(args):
    """
    main function of sevenn

    Parses the CLI namespace, sets up (optionally distributed) training
    state, reads input.yaml, seeds RNGs, and dispatches to train/train_v2.
    """
    import random
    import sys
    import torch
    import torch.distributed as dist
    import sevenn._keys as KEY
    from sevenn.logger import Logger
    from sevenn.parse_input import read_config_yaml
    from sevenn.scripts.train import train, train_v2
    from sevenn.util import unique_filepath
    input_yaml = args.input_yaml
    mode = args.mode
    working_dir = args.working_dir
    log = args.log
    screen = args.screen
    distributed = args.distributed
    distributed_backend = args.distributed_backend
    use_cue = args.enable_cueq
    if use_cue:
        # Fail early if cuEquivariance was requested but is unavailable.
        import sevenn.nn.cue_helper
        if not sevenn.nn.cue_helper.is_cue_available():
            raise ImportError('cuEquivariance not installed.')
    if working_dir is None:
        working_dir = os.getcwd()
    elif not os.path.isdir(working_dir):
        os.makedirs(working_dir, exist_ok=True)
    world_size = 1
    if distributed:
        # Rank/world-size come from the launcher's environment variables
        # (torchrun for nccl, OpenMPI for mpi).
        if distributed_backend == 'nccl':
            local_rank = int(os.environ['LOCAL_RANK'])
            rank = int(os.environ['RANK'])
            world_size = int(os.environ['WORLD_SIZE'])
        elif distributed_backend == 'mpi':
            local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
            rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
            world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        else:
            raise ValueError(f'Unknown distributed backend: {distributed_backend}')
        dist.init_process_group(
            backend=distributed_backend, world_size=world_size, rank=rank
        )
    else:
        local_rank, rank, world_size = 0, 0, 1
    # unique_filepath avoids clobbering an existing log file.
    log_fname = unique_filepath(f'{os.path.abspath(working_dir)}/{log}')
    with Logger(filename=log_fname, screen=screen, rank=rank) as logger:
        logger.greeting()
        if distributed:
            logger.writeline(
                f'Distributed training enabled, total world size is {world_size}'
            )
        try:
            model_config, train_config, data_config = read_config_yaml(
                input_yaml, return_separately=True
            )
        except Exception as e:
            logger.writeline('Failed to parsing input.yaml')
            logger.error(e)
            sys.exit(1)
        # Runtime/DDP state is stored alongside the parsed train config.
        train_config[KEY.IS_DDP] = distributed
        train_config[KEY.DDP_BACKEND] = distributed_backend
        train_config[KEY.LOCAL_RANK] = local_rank
        train_config[KEY.RANK] = rank
        train_config[KEY.WORLD_SIZE] = world_size
        if distributed:
            torch.cuda.set_device(torch.device('cuda', local_rank))
        if use_cue:
            # CLI flag overrides/extends whatever the yaml specified.
            if KEY.CUEQUIVARIANCE_CONFIG not in model_config:
                model_config[KEY.CUEQUIVARIANCE_CONFIG] = {'use': True}
            else:
                model_config[KEY.CUEQUIVARIANCE_CONFIG].update({'use': True})
        logger.print_config(model_config, data_config, train_config)
        # don't have to distinguish configs inside program
        global_config.update(model_config)
        global_config.update(train_config)
        global_config.update(data_config)
        # Not implemented
        if global_config[KEY.DTYPE] == 'double':
            raise Exception('double precision is not implemented yet')
            # torch.set_default_dtype(torch.double)
        seed = global_config[KEY.RANDOM_SEED]
        random.seed(seed)
        torch.manual_seed(seed)
        # run train
        if mode == 'train_v1':
            train(global_config, working_dir)
        elif mode == 'train_v2':
            train_v2(global_config, working_dir)
def cmd_parser_train(parser):
    """Attach every CLI option of the `train` sub-command to *parser*."""
    p = parser
    p.add_argument('input_yaml', help=input_yaml_help, type=str)
    p.add_argument(
        '-m', '--mode',
        type=str,
        default='train_v2',
        choices=['train_v1', 'train_v2'],
        help=mode_help,
    )
    p.add_argument(
        '-cueq', '--enable_cueq',
        action='store_true',
        help='(Not stable!) use cuEquivariance for training',
    )
    p.add_argument(
        '-w', '--working_dir',
        type=str,
        nargs='?',
        const=os.getcwd(),
        help=working_dir_help,
    )
    p.add_argument(
        '-l', '--log',
        type=str,
        default='log.sevenn',
        help='name of logfile, default is log.sevenn',
    )
    p.add_argument('-s', '--screen', action='store_true', help=screen_help)
    p.add_argument('-d', '--distributed', action='store_true', help=distributed_help)
    p.add_argument(
        '--distributed_backend',
        type=str,
        default='nccl',
        choices=['nccl', 'mpi'],
        help=distributed_backend_help,
    )
def add_parser(subparsers):
    """Create the 'train' sub-parser and populate its arguments."""
    train_parser = subparsers.add_parser('train', help=description)
    cmd_parser_train(train_parser)
def set_default_subparser(self, name, args=None, positional_args=0):
    """default subparser selection. Call after setup, just before parse_args()
    name: is the name of the subparser to call by default
    args: if set is the argument list handed to parse_args()
    Hack copied from stack overflow
    """
    # Collect every registered subparser name once.
    known_names = [
        sub_name
        for action in self._subparsers._actions
        if isinstance(action, argparse._SubParsersAction)
        for sub_name in action._name_parser_map
    ]
    found = False
    for token in sys.argv[1:]:
        if token in ('-h', '--help'):  # global help if no subparser
            break
        if any(n in sys.argv[1:] for n in known_names):
            found = True
    if not found:
        # insert default in last position before global positional
        # arguments, this implies no global options are specified after
        # first positional argument
        target = args if args is not None else sys.argv
        target.insert(len(target) - positional_args, name)


argparse.ArgumentParser.set_default_subparser = set_default_subparser  # type: ignore
def main():
    """Entry point of the unified `sevenn` CLI; dispatches to sub-commands."""
    import sevenn.main.sevenn_cp as checkpoint_cmd
    import sevenn.main.sevenn_get_model as get_model_cmd
    import sevenn.main.sevenn_graph_build as graph_build_cmd
    import sevenn.main.sevenn_inference as inference_cmd
    import sevenn.main.sevenn_patch_lammps as patch_lammps_cmd
    import sevenn.main.sevenn_preset as preset_cmd

    root = argparse.ArgumentParser(f'SevenNet version={__version__}')
    subparsers = root.add_subparsers(dest='command', help='Sub-commands')
    add_parser(subparsers)  # register 'train' itself
    for cmd_module in (
        checkpoint_cmd,
        inference_cmd,
        graph_build_cmd,
        preset_cmd,
        get_model_cmd,
        patch_lammps_cmd,
    ):
        cmd_module.add_parser(subparsers)
    root.set_default_subparser('train')  # type: ignore
    args = root.parse_args()
    if args.command is None:  # backward compatibility
        args.command = 'train'
    if args.command == 'train':
        run(args)
    elif args.command == 'preset':
        preset_cmd.run(args)


if __name__ == '__main__':
    main()
import argparse
import os.path as osp
from sevenn import __version__
# Short help text shown for the `checkpoint` sub-command.
description = 'tool box for sevennet checkpoints'
def add_parser(subparsers):
    """Register the 'checkpoint' sub-command (alias: 'cp')."""
    cp_parser = subparsers.add_parser('checkpoint', help=description, aliases=['cp'])
    add_args(cp_parser)
def add_args(parser):
    """Define arguments of the checkpoint tool.

    `--get_yaml` and `--append_modal_yaml` are mutually exclusive.
    """
    parser.add_argument('checkpoint', help='checkpoint or pretrained', type=str)
    exclusive = parser.add_mutually_exclusive_group(required=False)
    exclusive.add_argument(
        '--get_yaml',
        type=str,
        choices=['reproduce', 'continue', 'continue_modal'],
        help='create input.yaml based on the given checkpoint',
    )
    exclusive.add_argument(
        '--append_modal_yaml',
        type=str,
        help='append modality with given yaml.',
    )
    parser.add_argument(
        '--original_modal_name',
        type=str,
        default='origin',
        help=(
            'when the append_modal is used and checkpoint is not multi-modal, '
            'used to name previously trained modality. defaults to "origin"'
        ),
    )
def run(args):
    """Execute the checkpoint tool.

    Depending on the flags, either dumps an input.yaml derived from the
    checkpoint, appends a modality and saves a patched checkpoint, or
    prints the checkpoint summary.

    Raises
    ------
    FileNotFoundError: if `--append_modal_yaml` points to a missing file.
    """
    import torch
    import yaml

    from sevenn.parse_input import read_config_yaml
    from sevenn.util import load_checkpoint

    checkpoint = load_checkpoint(args.checkpoint)
    if args.get_yaml:
        # Reconstruct an input.yaml from the checkpoint's stored config.
        mode = args.get_yaml
        cfg = checkpoint.yaml_dict(mode)
        print(yaml.dump(cfg, indent=4, sort_keys=False, default_flow_style=False))
    elif args.append_modal_yaml:
        dst_yaml = args.append_modal_yaml
        if not osp.exists(dst_yaml):
            raise FileNotFoundError(f'No yaml file {dst_yaml}')
        dst_config = read_config_yaml(dst_yaml, return_separately=False)
        model_state_dict = checkpoint.append_modal(
            dst_config, args.original_modal_name
        )
        to_save = checkpoint.get_checkpoint_dict()
        to_save.update({'config': dst_config, 'model_state_dict': model_state_dict})
        torch.save(to_save, 'checkpoint_modal_appended.pth')
        print('checkpoint_modal_appended.pth is successfully saved.')
        # Fixed typo in the user-facing message: 'blow' -> 'below'.
        print(f'update continue of {dst_yaml} as below (recommend) to continue')
        cont_dct = {
            'continue': {
                'checkpoint': 'checkpoint_modal_appended.pth',
                'reset_epoch': True,
                'reset_optimizer': True,
                'reset_scheduler': True,
            }
        }
        print(
            yaml.dump(cont_dct, indent=4, sort_keys=False, default_flow_style=False)
        )
    else:
        print(checkpoint)
def main(args=None):
    """Stand-alone entry point (legacy `sevenn_cp` executable)."""
    parser = argparse.ArgumentParser(description=description)
    add_args(parser)
    run(parser.parse_args())
import argparse
import os
from sevenn import __version__
# Help texts for the `get_model` (deploy) sub-command.
description_get_model = 'deploy LAMMPS model from the checkpoint'
checkpoint_help = (
    'path to the checkpoint | SevenNet-0 | 7net-0 |'
    ' {SevenNet-0|7net-0}_{11July2024|22May2024}'
)
output_name_help = 'filename prefix'
get_parallel_help = 'deploy parallel model'
def add_parser(subparsers):
    """Register the 'get_model' sub-command (alias: 'deploy')."""
    sub = subparsers.add_parser(
        'get_model', help=description_get_model, aliases=['deploy']
    )
    add_args(sub)
def add_args(parser):
    """Define CLI arguments for model deployment."""
    parser.add_argument('checkpoint', help=checkpoint_help, type=str)
    parser.add_argument(
        '-o', '--output_prefix', nargs='?', type=str, help=output_name_help
    )
    parser.add_argument(
        '-p', '--get_parallel', action='store_true', help=get_parallel_help
    )
    parser.add_argument(
        '-m', '--modal', type=str, help='Modality of multi-modal model'
    )
def run(args):
    """Deploy a serial or parallel LAMMPS model from a checkpoint."""
    import sevenn.util
    from sevenn.scripts.deploy import deploy, deploy_parallel

    checkpoint = args.checkpoint
    output_prefix = args.output_prefix
    get_parallel = args.get_parallel
    get_serial = not get_parallel
    modal = args.modal
    if output_prefix is None:
        output_prefix = 'deployed_serial' if get_serial else 'deployed_parallel'
    # Resolve a pretrained-model name to a concrete file path when needed.
    if os.path.isfile(checkpoint):
        checkpoint_path = checkpoint
    else:
        checkpoint_path = sevenn.util.pretrained_name_to_path(checkpoint)
    if get_serial:
        deploy(checkpoint_path, output_prefix, modal)
    else:
        deploy_parallel(checkpoint_path, output_prefix, modal)
# legacy way
def main():
    """Stand-alone entry point (legacy `sevenn_get_model` executable)."""
    parser = argparse.ArgumentParser(description=description_get_model)
    add_args(parser)
    run(parser.parse_args())
import argparse
import glob
import os
import sys
from datetime import datetime
from sevenn import __version__
# Help texts for the `graph_build` sub-command.
description = 'create `sevenn_data/dataset.pt` from ase readable'
source_help = 'source data to build graph, knows *'
cutoff_help = 'cutoff radius of edges in Angstrom'
filename_help = (
    'Name of the dataset, default is graph.pt. '
    'The dataset will be written under "sevenn_data", '
    'for example, {out}/sevenn_data/graph.pt.'
)
legacy_help = 'build legacy .sevenn_data'
def add_parser(subparsers):
    """Register the 'graph_build' sub-command."""
    sub = subparsers.add_parser('graph_build', help=description)
    add_args(sub)
def add_args(parser):
    """Define arguments of the graph-build tool."""
    parser.add_argument('source', help=source_help, type=str)
    parser.add_argument('cutoff', help=cutoff_help, type=float)
    parser.add_argument(
        '-n', '--num_cores',
        type=int,
        default=1,
        help='number of cores to build graph in parallel',
    )
    parser.add_argument(
        '-o', '--out',
        type=str,
        default='./',
        help='Existing path to write outputs.',
    )
    parser.add_argument(
        '-f', '--filename',
        type=str,
        default='graph.pt',
        help=filename_help,
    )
    parser.add_argument('--legacy', action='store_true', help=legacy_help)
    parser.add_argument(
        '-s', '--screen',
        action='store_true',
        help='print log to the screen',
    )
    parser.add_argument(
        '--kwargs',
        nargs=argparse.REMAINDER,
        help='will be passed to ase.io.read, or can be used to specify EFS key',
    )
def run(args):
    """Build a graph dataset from ASE-readable source files.

    Globs `args.source`, then writes either the modern
    `{out}/sevenn_data/{filename}` dataset or (with --legacy) the
    old-style `.sevenn_data` file.

    Raises
    ------
    NotADirectoryError: if the output directory does not exist.
    FileExistsError: if the target dataset file already exists.
    """
    import sevenn.scripts.graph_build as graph_build
    from sevenn.logger import Logger

    source = glob.glob(args.source)
    cutoff = args.cutoff
    num_cores = args.num_cores
    filename = args.filename
    out = args.out
    legacy = args.legacy
    fmt_kwargs = {}
    if args.kwargs:
        for kwarg in args.kwargs:
            # Split on the first '=' only so values may contain '='
            # (e.g. index=0:10, or selections with '=' inside).
            k, v = kwarg.split('=', 1)
            fmt_kwargs[k] = v
    if len(source) == 0:
        print('Source has zero len, nothing to read')
        sys.exit(0)
    if not os.path.isdir(out):
        raise NotADirectoryError(f'No such directory: {out}')
    to_be_written = os.path.join(out, 'sevenn_data', filename)
    if os.path.isfile(to_be_written):
        raise FileExistsError(f'File already exist: {to_be_written}')
    metadata = {
        'sevenn_version': __version__,
        'when': datetime.now().strftime('%Y-%m-%d'),
        'cutoff': cutoff,
    }
    with Logger(filename=None, screen=args.screen) as logger:
        logger.writeline(description)
        if not legacy:
            graph_build.build_sevennet_graph_dataset(
                source,
                cutoff,
                num_cores,
                out,
                filename,
                metadata,
                **fmt_kwargs,
            )
        else:
            # Legacy output path drops the extension of `filename`.
            out = os.path.join(out, filename.split('.')[0])
            graph_build.build_script(  # build .sevenn_data
                source,
                cutoff,
                num_cores,
                out,
                metadata,
                **fmt_kwargs,
            )
def main(args=None):
    """Stand-alone entry point (legacy `sevenn_graph_build` executable)."""
    parser = argparse.ArgumentParser(description=description)
    add_args(parser)
    run(parser.parse_args())
import argparse
import glob
import os
import sys
# Help texts for the `inference` sub-command.
description = 'evaluate sevenn_data/ase readable with a model (checkpoint).'
checkpoint_help = 'Checkpoint or pre-trained model name'
target_help = 'Target files to evaluate'
def add_parser(subparsers):
    """Register the 'inference' sub-command (alias: 'inf')."""
    sub = subparsers.add_parser('inference', help=description, aliases=['inf'])
    add_args(sub)
def add_args(parser):
    """Define arguments of the inference tool."""
    ag = parser
    ag.add_argument('checkpoint', type=str, help=checkpoint_help)
    ag.add_argument('targets', type=str, nargs='+', help=target_help)
    ag.add_argument(
        '-d', '--device',
        type=str,
        default='auto',
        help='cpu/cuda/cuda:x',
    )
    ag.add_argument(
        '-nw', '--nworkers',
        type=int,
        default=1,
        help='Number of cores to build graph, defaults to 1',
    )
    ag.add_argument(
        '-o', '--output',
        type=str,
        default='./inference_results',
        help='A directory name to write outputs',
    )
    ag.add_argument(
        '-b', '--batch',
        type=int,
        # Was the string '4': argparse happened to re-parse it through
        # `type`, but an int default is the intended, robust form.
        default=4,
        help='batch size, useful for GPU',
    )
    ag.add_argument(
        '-s', '--save_graph',
        action='store_true',
        help='Additionally, save preprocessed graph as sevenn_data',
    )
    ag.add_argument(
        '-au', '--allow_unlabeled',
        action='store_true',
        help='Allow energy or force unlabeled data',
    )
    ag.add_argument(
        '-m', '--modal',
        type=str,
        default=None,
        help='modality for multi-modal inference',
    )
    ag.add_argument(
        '--kwargs',
        nargs=argparse.REMAINDER,
        help='will be passed to reader, or can be used to specify EFS key',
    )
def run(args):
    """Run inference of a checkpoint over the target structure files.

    Raises
    ------
    FileExistsError: if the output directory already exists.
    ValueError: if both --save_graph and --allow_unlabeled are given.
    """
    import torch

    from sevenn.scripts.inference import inference
    from sevenn.util import pretrained_name_to_path

    out = args.output
    if os.path.exists(out):
        raise FileExistsError(f'Directory {out} already exists')
    device = args.device
    if device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    targets = []
    for target in args.targets:
        targets.extend(glob.glob(target))
    if len(targets) == 0:
        print('No targets (data to inference) are found')
        sys.exit(0)
    cp = args.checkpoint
    if not os.path.isfile(cp):
        cp = pretrained_name_to_path(cp)  # raises value error
    fmt_kwargs = {}
    if args.kwargs:
        for kwarg in args.kwargs:
            # Split on the first '=' only so values may contain '='.
            k, v = kwarg.split('=', 1)
            fmt_kwargs[k] = v
    if args.save_graph and args.allow_unlabeled:
        raise ValueError('save_graph and allow_unlabeled are mutually exclusive')
    inference(
        cp,
        targets,
        out,
        args.nworkers,
        device,
        args.batch,
        args.save_graph,
        args.allow_unlabeled,
        args.modal,
        **fmt_kwargs,
    )
def main(args=None):
    """Stand-alone entry point (legacy `sevenn_inference` executable)."""
    parser = argparse.ArgumentParser(description=description)
    add_args(parser)
    run(parser.parse_args())
import argparse
import os
import subprocess
from sevenn import __version__
# python wrapper of patch_lammps.sh script
# importlib.resources is correct way to do these things
# but it changes so frequently to use
pair_e3gnn_dir = os.path.abspath(
    os.path.join(os.path.dirname(__file__), '..', 'pair_e3gnn')
)
description = 'patch LAMMPS with e3gnn(7net) pair-styles before compile'
def add_parser(subparsers):
    """Register the 'patch_lammps' sub-command."""
    sub = subparsers.add_parser('patch_lammps', help=description)
    add_args(sub)
def add_args(parser):
    """Define arguments of the LAMMPS patch tool."""
    parser.add_argument('lammps_dir', help='Path to LAMMPS source', type=str)
    parser.add_argument('--d3', help='Enable D3 support', action='store_true')
    # cxx_standard is detected automatically
def run(args):
    """Invoke patch_lammps.sh on the given LAMMPS source tree.

    Returns the exit code of the patch script.
    """
    lammps_dir = os.path.abspath(args.lammps_dir)
    print('Patching LAMMPS with the following settings:')
    print(' - LAMMPS source directory:', lammps_dir)
    cxx_standard = '17'  # always 17
    if args.d3:
        d3_support = '1'
        print(' - D3 support enabled')
    else:
        d3_support = '0'
        print(' - D3 support disabled')
    script = f'{pair_e3gnn_dir}/patch_lammps.sh'
    # Pass an argv list (shell=False) instead of building a string and
    # splitting it: robust to spaces in paths, no shell interpretation.
    res = subprocess.run([script, lammps_dir, cxx_standard, d3_support])
    return res.returncode  # is it meaningless?
def main(args=None):
    """Stand-alone entry point (legacy `sevenn_patch_lammps` executable)."""
    parser = argparse.ArgumentParser(description=description)
    add_args(parser)
    run(parser.parse_args())


if __name__ == '__main__':
    main()
import argparse
import os
from sevenn import __version__
# Help texts for the `preset` sub-command.
description = (
    'print the selected preset for training. '
    'ex) sevennet_preset fine_tune > my_input.yaml'
)
preset_help = 'Name of preset'
def add_parser(subparsers):
    """Register the 'preset' sub-command."""
    sub = subparsers.add_parser('preset', help=description)
    add_args(sub)
def add_args(parser):
    """Define arguments of the preset printer."""
    known_presets = [
        'fine_tune',
        'fine_tune_le',
        'sevennet-0',
        'sevennet-l3i5',
        'base',
        'multi_modal',
    ]
    parser.add_argument('preset', choices=known_presets, help=preset_help)
def run(args):
    """Print the chosen preset yaml (bundled with the package) to stdout."""
    preset_dir = os.path.abspath(f'{os.path.dirname(__file__)}/../presets')
    with open(f'{preset_dir}/{args.preset}.yaml', 'r') as f:
        print(f.read())
# When executed as sevenn_preset (legacy way)
def main(args=None):
    """Stand-alone entry point (legacy `sevenn_preset` executable)."""
    parser = argparse.ArgumentParser(description=description)
    add_args(parser)
    run(parser.parse_args())
import copy
import warnings
from collections import OrderedDict
from typing import List, Literal, Union, overload
from e3nn.o3 import Irreps
import sevenn._const as _const
import sevenn._keys as KEY
import sevenn.util as util
from .nn.convolution import IrrepsConvolution
from .nn.edge_embedding import (
BesselBasis,
EdgeEmbedding,
PolynomialCutoff,
SphericalEncoding,
XPLORCutoff,
)
from .nn.force_output import ForceStressOutputFromEdge
from .nn.interaction_blocks import NequIP_interaction_block
from .nn.linear import AtomReduce, FCN_e3nn, IrrepsLinear
from .nn.node_embedding import OnehotEmbedding
from .nn.scale import ModalWiseRescale, Rescale, SpeciesWiseRescale
from .nn.self_connection import (
SelfConnectionIntro,
SelfConnectionLinearIntro,
SelfConnectionOutro,
)
from .nn.sequential import AtomGraphSequential
# warning from PyTorch, about e3nn type annotations
warnings.filterwarnings(
'ignore',
message=(
"The TorchScript type system doesn't " 'support instance-level annotations'
),
)
def _insert_after(module_name_after, key_module_pair, layers):
idx = -1
for i, (key, _) in enumerate(layers):
if key == module_name_after:
idx = i
break
if idx == -1:
return layers # do nothing if not found
layers.insert(idx + 1, key_module_pair)
return layers
def init_self_connection(config):
    """Resolve per-layer self-connection (intro, outro) module classes.

    Returns a list with one entry per convolution layer: either None
    (type 'none') or a tuple of (intro_cls, outro_cls).

    Raises ValueError on an unknown self_connection_type.
    """
    sc_types = config[KEY.SELF_CONNECTION_TYPE]
    num_conv = config[KEY.NUM_CONVOLUTION]
    # A single string applies to every convolution layer.
    if isinstance(sc_types, str):
        sc_types = [sc_types] * num_conv
    pair_by_name = {
        'none': None,
        'nequip': (SelfConnectionIntro, SelfConnectionOutro),
        'linear': (SelfConnectionLinearIntro, SelfConnectionOutro),
    }
    pairs = []
    for sc_type in sc_types:
        if sc_type not in pair_by_name:
            raise ValueError(f'Unknown self_connection_type found: {sc_type}')
        pairs.append(pair_by_name[sc_type])
    return pairs
def init_edge_embedding(config):
    """Build the EdgeEmbedding module (radial basis + cutoff + spherical)."""
    cutoff_param = {'cutoff_length': config[KEY.CUTOFF]}

    basis_kwargs = copy.deepcopy(config[KEY.RADIAL_BASIS])
    basis_kwargs.update(cutoff_param)
    basis_name = basis_kwargs.pop(KEY.RADIAL_BASIS_NAME)
    basis = BesselBasis(**basis_kwargs) if basis_name == 'bessel' else None

    cutoff_kwargs = copy.deepcopy(config[KEY.CUTOFF_FUNCTION])
    cutoff_kwargs.update(cutoff_param)
    cutoff_name = cutoff_kwargs.pop(KEY.CUTOFF_FUNCTION_NAME)
    if cutoff_name == 'poly_cut':
        cutoff_fn = PolynomialCutoff(**cutoff_kwargs)
    elif cutoff_name == 'XPLOR':
        cutoff_fn = XPLORCutoff(**cutoff_kwargs)
    else:
        cutoff_fn = None

    # Edge lmax defaults to the node lmax unless explicitly overridden.
    lmax_edge = config[KEY.LMAX]
    if config[KEY.LMAX_EDGE] > 0:
        lmax_edge = config[KEY.LMAX_EDGE]
    parity = -1 if config[KEY.IS_PARITY] else 1
    spherical = SphericalEncoding(
        lmax_edge, parity, normalize=config[KEY._NORMALIZE_SPH]
    )
    return EdgeEmbedding(
        basis_module=basis, cutoff_module=cutoff_fn, spherical_module=spherical
    )
def init_feature_reduce(config, irreps_x):
    """Build readout layers reducing node features to a scalar atomic energy.

    Either two IrrepsLinear layers (features -> hidden scalars -> energy)
    or a single FCN readout, depending on `readout_as_fcn`.
    """
    layers = OrderedDict()
    if config[KEY.READOUT_AS_FCN] is False:
        # Hidden width: half of the incoming feature dimension, scalars only.
        hidden_irreps = Irreps([(irreps_x.dim // 2, (0, 1))])
        layers['reduce_input_to_hidden'] = IrrepsLinear(
            irreps_x,
            hidden_irreps,
            data_key_in=KEY.NODE_FEATURE,
            biases=config[KEY.USE_BIAS_IN_LINEAR],
        )
        layers['reduce_hidden_to_energy'] = IrrepsLinear(
            hidden_irreps,
            Irreps([(1, (0, 1))]),
            data_key_in=KEY.NODE_FEATURE,
            data_key_out=KEY.SCALED_ATOMIC_ENERGY,
            biases=config[KEY.USE_BIAS_IN_LINEAR],
        )
    else:
        layers['readout_FCN'] = FCN_e3nn(
            dim_out=1,
            hidden_neurons=config[KEY.READOUT_FCN_HIDDEN_NEURONS],
            activation=_const.ACTIVATION[config[KEY.READOUT_FCN_ACTIVATION]],
            data_key_in=KEY.NODE_FEATURE,
            data_key_out=KEY.SCALED_ATOMIC_ENERGY,
            irreps_in=irreps_x,
        )
    return layers
def init_shift_scale(config):
    """Build the rescale module for atomic energies from configured shift/scale.

    Normalizes the configured shift/scale values to plain python
    floats/lists/dicts, then selects ModalWiseRescale (multi-modal),
    Rescale (both scalar) or SpeciesWiseRescale (per-species lists).

    Raises ValueError when neither form applies.
    """
    # for mm, ex, shift: modal_idx -> shifts
    shift_scale = []
    train_shift_scale = config[KEY.TRAIN_SHIFT_SCALE]
    type_map = config[KEY.TYPE_MAP]
    # in case of modal, shift or scale has more dims [][]
    # correct typing (I really want static python)
    for s in (config[KEY.SHIFT], config[KEY.SCALE]):
        if hasattr(s, 'tolist'):  # numpy or torch
            s = s.tolist()
        if isinstance(s, dict):
            # dict values may themselves be numpy/torch arrays
            s = {k: v.tolist() if hasattr(v, 'tolist') else v for k, v in s.items()}
        if isinstance(s, list) and len(s) == 1:
            # a single-element list collapses to its scalar
            s = s[0]
        shift_scale.append(s)
    shift, scale = shift_scale
    rescale_module = None
    if config.get(KEY.USE_MODALITY, False):
        rescale_module = ModalWiseRescale.from_mappers(  # type: ignore
            shift,
            scale,
            config[KEY.USE_MODAL_WISE_SHIFT],
            config[KEY.USE_MODAL_WISE_SCALE],
            type_map=type_map,
            modal_map=config[KEY.MODAL_MAP],
            train_shift_scale=train_shift_scale,
        )
    elif all([isinstance(s, float) for s in shift_scale]):
        rescale_module = Rescale(shift, scale, train_shift_scale=train_shift_scale)
    elif any([isinstance(s, list) for s in shift_scale]):
        rescale_module = SpeciesWiseRescale.from_mappers(  # type: ignore
            shift, scale, type_map=type_map, train_shift_scale=train_shift_scale
        )
    else:
        raise ValueError('shift, scale should be list of float or float')
    return rescale_module
def patch_modality(layers: OrderedDict, config):
    """
    Postprocess 7net-model to multimodal model.
    1. prepend modality one-hot embedding layer
    2. patch modalities of IrrepsLinear layers
    Modality aware shift scale is handled by init_shift_scale, not here
    """
    cfg = config
    if not cfg.get(KEY.USE_MODALITY, False):
        # Single-modal model: nothing to patch.
        return layers
    _layers = list(layers.items())
    # Insert the modality one-hot right after the species one-hot embedding.
    _layers = _insert_after(
        'onehot_idx_to_onehot',
        (
            'one_hot_modality',
            OnehotEmbedding(
                num_classes=config[KEY.NUM_MODALITIES],
                data_key_x=KEY.MODAL_TYPE,
                data_key_out=KEY.MODAL_ATTR,
                data_key_save=None,
                data_key_additional=None,
            ),
        ),
        _layers,
    )
    layers = OrderedDict(_layers)
    num_modal = config[KEY.NUM_MODALITIES]
    for k, module in layers.items():
        if not isinstance(module, IrrepsLinear):
            continue
        # Enable modality-dependent weights only on the linear layers the
        # config opts in to: node embedding, self-interaction intro/outro,
        # and the first output-block layer.
        if (
            (cfg[KEY.USE_MODAL_NODE_EMBEDDING] and k.endswith('onehot_to_feature_x'))
            or (
                cfg[KEY.USE_MODAL_SELF_INTER_INTRO]
                and k.endswith('self_interaction_1')
            )
            or (
                cfg[KEY.USE_MODAL_SELF_INTER_OUTRO]
                and k.endswith('self_interaction_2')
            )
            or (cfg[KEY.USE_MODAL_OUTPUT_BLOCK] and k == 'reduce_input_to_hidden')
        ):
            module.set_num_modalities(num_modal)
    return layers
def patch_cue(layers: OrderedDict, config):
    """Swap e3nn linear/tensor-product modules for cuEquivariance versions.

    No-op when cueq is disabled in the config; warns and falls back when
    the package is not installed or the model settings are incompatible.
    """
    import sevenn.nn.cue_helper as cue_helper

    cue_cfg = copy.deepcopy(config.get(KEY.CUEQUIVARIANCE_CONFIG, {}))
    if not cue_cfg.pop('use', False):
        return layers
    if not cue_helper.is_cue_available():
        warnings.warn(
            (
                'cuEquivariance is requested, but the package is not installed. '
                + 'Fallback to original code.'
            )
        )
        return layers
    if not cue_helper.is_cue_cuda_available_model(config):
        return layers
    group = 'O3' if config[KEY.IS_PARITY] else 'SO3'
    cueq_module_params = dict(layout='mul_ir')
    cueq_module_params.update(cue_cfg)
    updates = {}
    for k, module in layers.items():
        if isinstance(module, (IrrepsLinear, SelfConnectionLinearIntro)):
            if k == 'reduce_hidden_to_energy':  # TODO: has bug with 0 shape
                continue
            module_patched = cue_helper.patch_linear(
                module, group, **cueq_module_params
            )
            updates[k] = module_patched
        elif isinstance(module, SelfConnectionIntro):
            module_patched = cue_helper.patch_fully_connected(
                module, group, **cueq_module_params
            )
            updates[k] = module_patched
        elif isinstance(module, IrrepsConvolution):
            module_patched = cue_helper.patch_convolution(
                module, group, **cueq_module_params
            )
            updates[k] = module_patched
    # Apply replacements after the loop to avoid mutating while iterating.
    layers.update(updates)
    return layers
def patch_modules(layers: OrderedDict, config):
    """Apply post-build patches (multi-modal, then cuEquivariance) to layers."""
    return patch_cue(patch_modality(layers, config), config)
def _to_parallel_model(layers: OrderedDict, config):
    """Split a built model into per-communication-stage stacks for LAMMPS.

    Inserts ghost-atom embedding layers, slices the stack at each layer's
    first self-interaction (where halo communication happens), and drops
    the final force output (computed inside LAMMPS instead).
    """
    num_classes = layers['onehot_idx_to_onehot'].num_classes
    one_hot_irreps = Irreps(f'{num_classes}x0e')
    irreps_node_zero = layers['onehot_to_feature_x'].irreps_out
    _layers = list(layers.items())
    layers_list = []
    num_convolution_layer = config[KEY.NUM_CONVOLUTION]

    def slice_until_this(module_name, layers):
        # Split the (key, module) list right after `module_name`.
        idx = -1
        for i, (key, _) in enumerate(layers):
            if key == module_name:
                idx = i
                break
        first_to = layers[: idx + 1]
        remain = layers[idx + 1 :]
        return first_to, remain

    # One-hot embedding for ghost (halo) atoms received from neighbor ranks.
    _layers = _insert_after(
        'onehot_to_feature_x',
        (
            'one_hot_ghost',
            OnehotEmbedding(
                data_key_x=KEY.NODE_FEATURE_GHOST,
                num_classes=num_classes,
                data_key_save=None,
                data_key_additional=None,
            ),
        ),
        _layers,
    )
    # Linear embedding for the ghost one-hots, mirroring the local atoms'.
    _layers = _insert_after(
        'one_hot_ghost',
        (
            'ghost_onehot_to_feature_x',
            IrrepsLinear(
                irreps_in=one_hot_irreps,
                irreps_out=irreps_node_zero,
                data_key_in=KEY.NODE_FEATURE_GHOST,
                biases=config[KEY.USE_BIAS_IN_LINEAR],
            ),
        ),
        _layers,
    )
    # Ghost counterpart of the first layer's first self-interaction.
    _layers = _insert_after(
        '0_self_interaction_1',
        (
            'ghost_0_self_interaction_1',
            IrrepsLinear(
                irreps_node_zero,
                irreps_node_zero,
                data_key_in=KEY.NODE_FEATURE_GHOST,
                biases=config[KEY.USE_BIAS_IN_LINEAR],
            ),
        ),
        _layers,
    )
    # assign modules (before first communications)
    # initialize edge related to retain position gradients
    for i in range(1, num_convolution_layer):
        sliced, _layers = slice_until_this(f'{i}_self_interaction_1', _layers)
        layers_list.append(OrderedDict(sliced))
    _layers.insert(0, ('edge_embedding', init_edge_embedding(config)))
    layers_list.append(OrderedDict(_layers))
    del layers_list[-1]['force_output']  # done in LAMMPS
    return layers_list
# Typing overloads: parallel=False yields a single model; parallel=True
# yields the list of per-communication-stage model shards.
@overload
def build_E3_equivariant_model(
    config: dict, parallel: Literal[False] = False
) -> AtomGraphSequential:  # noqa
    ...


@overload
def build_E3_equivariant_model(
    config: dict, parallel: Literal[True]
) -> List[AtomGraphSequential]:  # noqa
    ...
def build_E3_equivariant_model(
    config: dict, parallel: bool = False
) -> Union[AtomGraphSequential, List[AtomGraphSequential]]:
    """Assemble the full E3-equivariant model from a parsed config dict.

    output shapes (w/o batch)
    PRED_TOTAL_ENERGY: (),
    ATOMIC_ENERGY: (natoms, 1), # intended
    PRED_FORCE: (natoms, 3),
    PRED_STRESS: (6,),
    for data w/o cell volume, pred_stress has garbage values

    When *parallel* is True, returns a list of model shards suitable for
    LAMMPS halo communication (see _to_parallel_model); otherwise a single
    AtomGraphSequential.
    """
    layers = OrderedDict()
    cutoff = config[KEY.CUTOFF]
    num_species = config[KEY.NUM_SPECIES]
    feature_multiplicity = config[KEY.NODE_FEATURE_MULTIPLICITY]
    num_convolution_layer = config[KEY.NUM_CONVOLUTION]
    interaction_type = config[KEY.INTERACTION_TYPE]
    use_bias_in_linear = config[KEY.USE_BIAS_IN_LINEAR]
    lmax_node = config[KEY.LMAX]  # ignore second (lmax_edge)
    # if config[KEY.LMAX_EDGE] > 0:  # not yet used
    #     _ = config[KEY.LMAX_EDGE]
    if config[KEY.LMAX_NODE] > 0:
        lmax_node = config[KEY.LMAX_NODE]
    act_radial = _const.ACTIVATION[config[KEY.ACTIVATION_RADIAL]]
    self_connection_pair_list = init_self_connection(config)
    # Optional explicit per-layer irreps; must give num_conv + 1 entries.
    irreps_manual = None
    if config[KEY.IRREPS_MANUAL] is not False:
        irreps_manual = config[KEY.IRREPS_MANUAL]
        try:
            irreps_manual = [Irreps(irr) for irr in irreps_manual]
            assert len(irreps_manual) == num_convolution_layer + 1
        except Exception:
            raise RuntimeError('invalid irreps_manual input given')
    conv_denominator = config[KEY.CONV_DENOMINATOR]
    if not isinstance(conv_denominator, list):
        conv_denominator = [conv_denominator] * num_convolution_layer
    train_conv_denominator = config[KEY.TRAIN_DENOMINTAOR]
    edge_embedding = init_edge_embedding(config)
    irreps_filter = edge_embedding.spherical.irreps_out
    radial_basis_num = edge_embedding.basis_function.num_basis
    layers.update({'edge_embedding': edge_embedding})
    one_hot_irreps = Irreps(f'{num_species}x0e')
    # Initial node features: scalar channels unless irreps_manual overrides.
    irreps_x = (
        Irreps(f'{feature_multiplicity}x0e')
        if irreps_manual is None
        else irreps_manual[0]
    )
    layers.update(
        {
            'onehot_idx_to_onehot': OnehotEmbedding(
                num_classes=num_species,
                data_key_x=KEY.NODE_FEATURE,
                data_key_out=KEY.NODE_FEATURE,
                data_key_save=KEY.ATOM_TYPE,  # atomic numbers
                data_key_additional=KEY.NODE_ATTR,  # one-hot embeddings
            ),
            'onehot_to_feature_x': IrrepsLinear(
                irreps_in=one_hot_irreps,
                irreps_out=irreps_x,
                data_key_in=KEY.NODE_FEATURE,
                biases=use_bias_in_linear,
            ),
        }
    )
    # Radial weight MLP: radial basis values -> tensor-product weights.
    weight_nn_hidden = config[KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS]
    weight_nn_layers = [radial_basis_num] + weight_nn_hidden
    param_interaction_block = {
        'irreps_filter': irreps_filter,
        'weight_nn_layers': weight_nn_layers,
        'train_conv_denominator': train_conv_denominator,
        'act_radial': act_radial,
        'bias_in_linear': use_bias_in_linear,
        'num_species': num_species,
        'parallel': parallel,
    }
    interaction_builder = None
    if interaction_type in ['nequip']:
        act_scalar = {}
        act_gate = {}
        for k, v in config[KEY.ACTIVATION_SCARLAR].items():
            act_scalar[k] = _const.ACTIVATION_DICT[k][v]
        for k, v in config[KEY.ACTIVATION_GATE].items():
            act_gate[k] = _const.ACTIVATION_DICT[k][v]
        param_interaction_block.update(
            {
                'act_scalar': act_scalar,
                'act_gate': act_gate,
            }
        )
    if interaction_type == 'nequip':
        interaction_builder = NequIP_interaction_block
    else:
        raise ValueError(f'Unknown interaction type: {interaction_type}')
    # Stack the interaction blocks; irreps_x is threaded layer to layer.
    for t in range(num_convolution_layer):
        param_interaction_block.update(
            {
                'irreps_x': irreps_x,
                't': t,
                'conv_denominator': conv_denominator[t],
                'self_connection_pair': self_connection_pair_list[t],
            }
        )
        if interaction_type == 'nequip':
            parity_mode = 'full'
            fix_multiplicity = False
            if t == num_convolution_layer - 1:
                # Last layer keeps only even scalars for the energy readout.
                lmax_node = 0
                parity_mode = 'even'
            # TODO: irreps_manual is applicable to both irreps_out_tp and irreps_out
            irreps_out = (
                util.infer_irreps_out(
                    irreps_x,  # type: ignore
                    irreps_filter,
                    lmax_node,  # type: ignore
                    parity_mode,
                    fix_multiplicity=feature_multiplicity,
                )
                if irreps_manual is None
                else irreps_manual[t + 1]
            )
            irreps_out_tp = util.infer_irreps_out(
                irreps_x,  # type: ignore
                irreps_filter,
                irreps_out.lmax,  # type: ignore
                parity_mode,
                fix_multiplicity,
            )
        else:
            raise ValueError(f'Unknown interaction type: {interaction_type}')
        param_interaction_block.update(
            {
                'irreps_out_tp': irreps_out_tp,
                'irreps_out': irreps_out,
            }
        )
        layers.update(interaction_builder(**param_interaction_block))
        irreps_x = irreps_out
    layers.update(init_feature_reduce(config, irreps_x))
    layers.update(
        {
            'rescale_atomic_energy': init_shift_scale(config),
            'reduce_total_enegy': AtomReduce(
                data_key_in=KEY.ATOMIC_ENERGY,
                data_key_out=KEY.PRED_TOTAL_ENERGY,
            ),
        }
    )
    # Forces/stress come from autograd w.r.t. edge vectors.
    gradient_module = ForceStressOutputFromEdge()
    grad_key = gradient_module.get_grad_key()
    layers.update({'force_output': gradient_module})
    common_args = {
        'cutoff': cutoff,
        'type_map': config[KEY.TYPE_MAP],
        'modal_map': config.get(KEY.MODAL_MAP, None),
        'eval_type_map': False if parallel else True,
        'eval_modal_map': False
        if not config.get(KEY.USE_MODALITY, False) or parallel
        else True,
        'data_key_grad': grad_key,
    }
    if parallel:
        layers_list = _to_parallel_model(layers, config)
        return [
            AtomGraphSequential(patch_modules(layers, config), **common_args)
            for layers in layers_list
        ]
    else:
        return AtomGraphSequential(patch_modules(layers, config), **common_args)
import math
import torch
@torch.jit.script
def ShiftedSoftPlus(x: torch.Tensor) -> torch.Tensor:
    """Softplus shifted down by log(2), so that f(0) == 0."""
    shift = math.log(2.0)
    return torch.nn.functional.softplus(x) - shift
from typing import List
import torch
import torch.nn as nn
from e3nn.nn import FullyConnectedNet
from e3nn.o3 import Irreps, TensorProduct
from e3nn.util.jit import compile_mode
import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType
from .activation import ShiftedSoftPlus
from .util import broadcast
def message_gather(
    node_features: torch.Tensor,
    edge_dst: torch.Tensor,
    message: torch.Tensor
):
    """Scatter-sum per-edge messages onto their destination nodes.

    Returns a tensor of shape (num_nodes, *message.shape[1:]) on the same
    device/dtype as *node_features*.
    """
    index = broadcast(edge_dst, message, 0)
    gathered_shape = [len(node_features)] + list(message.shape[1:])
    gathered = torch.zeros(
        gathered_shape,
        dtype=node_features.dtype,
        device=node_features.device
    )
    gathered.scatter_reduce_(0, index, message, reduce='sum')
    return gathered
@compile_mode('script')
class IrrepsConvolution(nn.Module):
    """
    convolution of (fig 2.b), comm. in LAMMPS

    Equivariant message passing: each message is a tensor product of the
    source node's features and the edge attributes, weighted by an MLP of
    the radial edge embedding; messages are scatter-summed onto destination
    nodes and divided by `denominator`.
    """
    def __init__(
        self,
        irreps_x: Irreps,
        irreps_filter: Irreps,
        irreps_out: Irreps,
        weight_layer_input_to_hidden: List[int],
        weight_layer_act=ShiftedSoftPlus,
        denominator: float = 1.0,
        train_denominator: bool = False,
        data_key_x: str = KEY.NODE_FEATURE,
        data_key_filter: str = KEY.EDGE_ATTR,
        data_key_weight_input: str = KEY.EDGE_EMBEDDING,
        data_key_edge_idx: str = KEY.EDGE_IDX,
        lazy_layer_instantiate: bool = True,
        is_parallel: bool = False,
    ):
        super().__init__()
        # Normalization constant; trainable only when train_denominator.
        self.denominator = nn.Parameter(
            torch.FloatTensor([denominator]), requires_grad=train_denominator
        )
        self.key_x = data_key_x
        self.key_filter = data_key_filter
        self.key_weight_input = data_key_weight_input
        self.key_edge_idx = data_key_edge_idx
        self.is_parallel = is_parallel
        # Enumerate 'uvu' tensor-product paths whose outputs stay in
        # irreps_out; one learnable weight vector (mul_x) per path.
        instructions = []
        irreps_mid = []
        weight_numel = 0
        for i, (mul_x, ir_x) in enumerate(irreps_x):
            for j, (_, ir_filter) in enumerate(irreps_filter):
                for ir_out in ir_x * ir_filter:
                    if ir_out in irreps_out:  # here we drop l > lmax
                        k = len(irreps_mid)
                        weight_numel += mul_x * 1  # path shape
                        irreps_mid.append((mul_x, ir_out))
                        instructions.append((i, j, k, 'uvu', True))
        irreps_mid = Irreps(irreps_mid)
        # Sort output irreps and remap instruction output indices to match.
        irreps_mid, p, _ = irreps_mid.sort()  # type: ignore
        instructions = [
            (i_in1, i_in2, p[i_out], mode, train)
            for i_in1, i_in2, i_out, mode, train in instructions
        ]
        # From v0.11.x, to compatible with cuEquivariance
        self._instructions_before_sort = instructions
        instructions = sorted(instructions, key=lambda x: x[2])
        # Layer construction is deferred (lazy) so patching (e.g. cueq)
        # can swap convolution_cls/weight_nn_cls before instantiation.
        self.convolution_kwargs = dict(
            irreps_in1=irreps_x,
            irreps_in2=irreps_filter,
            irreps_out=irreps_mid,
            instructions=instructions,
            shared_weights=False,
            internal_weights=False,
        )
        self.weight_nn_kwargs = dict(
            hs=weight_layer_input_to_hidden + [weight_numel],
            act=weight_layer_act
        )
        self.convolution = None
        self.weight_nn = None
        self.layer_instantiated = False
        self.convolution_cls = TensorProduct
        self.weight_nn_cls = FullyConnectedNet
        if not lazy_layer_instantiate:
            self.instantiate()
        self._comm_size = irreps_x.dim  # used in parallel

    def instantiate(self):
        """Build the tensor product and weight MLP from the stored kwargs.

        Raises ValueError when called twice.
        """
        if self.convolution is not None:
            raise ValueError('Convolution layer already exists')
        if self.weight_nn is not None:
            raise ValueError('Weight_nn layer already exists')
        self.convolution = self.convolution_cls(**self.convolution_kwargs)
        self.weight_nn = self.weight_nn_cls(**self.weight_nn_kwargs)
        self.layer_instantiated = True

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        """Run one message-passing step, updating data[self.key_x] in place."""
        assert self.convolution is not None, 'Convolution is not instantiated'
        assert self.weight_nn is not None, 'Weight_nn is not instantiated'
        weight = self.weight_nn(data[self.key_weight_input])
        x = data[self.key_x]
        if self.is_parallel:
            # Append ghost (halo) atom features so edges may source them.
            x = torch.cat([x, data[KEY.NODE_FEATURE_GHOST]])
        # note that 1 -> src 0 -> dst
        edge_src = data[self.key_edge_idx][1]
        edge_dst = data[self.key_edge_idx][0]
        message = self.convolution(x[edge_src], data[self.key_filter], weight)
        x = message_gather(x, edge_dst, message)
        x = x.div(self.denominator)
        if self.is_parallel:
            # Keep only the locally-owned atoms' features.
            x = torch.tensor_split(x, data[KEY.NLOCAL])[0]
        data[self.key_x] = x
        return data
import functools
import itertools
import warnings
from typing import Iterator, Literal, Union

import e3nn.o3 as o3
import numpy as np

from .convolution import IrrepsConvolution
from .linear import IrrepsLinear
from .self_connection import SelfConnectionIntro, SelfConnectionLinearIntro
# cuEquivariance is an optional dependency; helpers below are guarded by
# the `cue_needed` decorator when the import fails.
try:
    import cuequivariance as cue
    import cuequivariance_torch as cuet
    _CUE_AVAILABLE = True
    # Obtained from MACE: adapter that makes cue's O3 group follow e3nn
    # conventions (Wigner-3j coefficients and irrep ordering).
    class O3_e3nn(cue.O3):
        def __mul__(  # type: ignore
            rep1: 'O3_e3nn', rep2: 'O3_e3nn'
        ) -> Iterator['O3_e3nn']:
            # product of two irreps; re-wrap each result as O3_e3nn
            return [  # type: ignore
                O3_e3nn(l=ir.l, p=ir.p) for ir in cue.O3.__mul__(rep1, rep2)
            ]
        @classmethod
        def clebsch_gordan(  # type: ignore
            cls, rep1: 'O3_e3nn', rep2: 'O3_e3nn', rep3: 'O3_e3nn'
        ) -> np.ndarray:
            rep1, rep2, rep3 = cls._from(rep1), cls._from(rep2), cls._from(rep3)
            # coupling allowed only when parities are compatible;
            # otherwise return an empty (0, ...) array
            if rep1.p * rep2.p == rep3.p:
                return o3.wigner_3j(rep1.l, rep2.l, rep3.l).numpy()[None] * np.sqrt(
                    rep3.dim
                )
            return np.zeros((0, rep1.dim, rep2.dim, rep3.dim))
        def __lt__(  # type: ignore
            rep1: 'O3_e3nn', rep2: 'O3_e3nn'
        ) -> bool:
            # sort order: ascending l, then parity
            rep2 = rep1._from(rep2)  # type: ignore
            return (rep1.l, rep1.p) < (rep2.l, rep2.p)
        @classmethod
        def iterator(cls) -> Iterator['O3_e3nn']:
            # enumerate irreps in e3nn order: 0e, 0o, 1o, 1e, 2e, 2o, ...
            for l in itertools.count(0):
                yield O3_e3nn(l=l, p=1 * (-1) ** l)
                yield O3_e3nn(l=l, p=-1 * (-1) ** l)
except ImportError:
    _CUE_AVAILABLE = False
def is_cue_available():
    """Return True if the optional cuEquivariance packages were importable."""
    return _CUE_AVAILABLE
def cue_needed(func):
    """Decorator guarding *func* behind cuEquivariance availability.

    The wrapped function runs normally when cuequivariance is importable
    and raises ImportError otherwise.

    Raises:
        ImportError: (from the wrapper) when cuEquivariance is missing.
    """
    import functools

    # functools.wraps preserves the wrapped function's __name__ / __doc__,
    # which the original plain wrapper clobbered ('wrapper' everywhere).
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if is_cue_available():
            return func(*args, **kwargs)
        else:
            raise ImportError('cue is not available')
    return wrapper
def _check_may_not_compatible(orig_kwargs, defaults):
    """Remove keys listed in *defaults* from *orig_kwargs* in place.

    Emits a warning for every key whose supplied value differs from the
    cue-compatible default, since that value will be ignored.
    """
    for k, v in defaults.items():
        if orig_kwargs.pop(k, v) != v:
            warnings.warn(f'{k}: {v} is ignored to use cuEquivariance')
def is_cue_cuda_available_model(config):
    """Return whether a model described by *config* can use cuEquivariance.

    A linear bias is the only known incompatibility; in that case a
    warning is emitted and False is returned (caller falls back to e3nn).
    """
    if not config.get('use_bias_in_linear', False):
        return True
    warnings.warn('Bias in linear can not be used with cueq, fallback to e3nn')
    return False
@cue_needed
def as_cue_irreps(irreps: o3.Irreps, group: Literal['SO3', 'O3']):
    """Convert e3nn irreps to given group's cue irreps"""
    if group == 'O3':
        return cue.Irreps(O3_e3nn, str(irreps))  # type: ignore
    if group == 'SO3':
        # SO3 has no parity; all input irreps must be even, and the
        # parity suffix is stripped from the string representation
        assert all(mul_ir.ir.p == 1 for mul_ir in irreps)
        return cue.Irreps('SO3', str(irreps).replace('e', ''))  # type: ignore
    raise ValueError(f'Unknown group: {group}')
@cue_needed
def patch_linear(
    module: Union[IrrepsLinear, SelfConnectionLinearIntro],
    group: Literal['SO3', 'O3'],
    **cue_kwargs,
):
    """Patch an un-instantiated linear module to use cuet.Linear.

    Converts its irreps to cue irreps, strips e3nn-only kwargs (warning
    when a non-default value is dropped), and returns the module.
    """
    assert not module.layer_instantiated
    module.irreps_in = as_cue_irreps(module.irreps_in, group)  # type: ignore
    module.irreps_out = as_cue_irreps(module.irreps_out, group)  # type: ignore
    # e3nn-only options that cuet.Linear does not accept
    e3nn_only_defaults = dict(
        f_in=None,
        f_out=None,
        instructions=None,
        biases=False,
        path_normalization='element',
        _optimize_einsums=None,
    )
    kwargs = module.linear_kwargs
    # pops the e3nn-only keys in place, warning on non-default values
    _check_may_not_compatible(kwargs, e3nn_only_defaults)
    module.linear_cls = cuet.Linear  # type: ignore
    kwargs.update(**cue_kwargs)
    return module
@cue_needed
def patch_convolution(
    module: IrrepsConvolution,
    group: Literal['SO3', 'O3'],
    **cue_kwargs,
):
    """Patch an un-instantiated IrrepsConvolution to use
    cuet.ChannelWiseTensorProduct instead of e3nn's TensorProduct.

    Converts irreps kwargs to cue irreps, drops e3nn-only options
    (warning when a non-default value is discarded), and returns the
    patched module.
    """
    assert not module.layer_instantiated
    # conv_kwargs will be patched in place
    conv_kwargs = module.convolution_kwargs
    conv_kwargs.update(
        dict(
            irreps_in1=as_cue_irreps(conv_kwargs.get('irreps_in1'), group),
            irreps_in2=as_cue_irreps(conv_kwargs.get('irreps_in2'), group),
            # cuet uses 'filter_irreps_out' where e3nn used 'irreps_out'
            filter_irreps_out=as_cue_irreps(conv_kwargs.pop('irreps_out'), group),
        )
    )
    # cuet derives instructions itself; verify that the e3nn instructions
    # were already sorted by output index before discarding them
    inst_orig = conv_kwargs.pop('instructions')
    inst_sorted = sorted(inst_orig, key=lambda x: x[2])
    assert all([a == b for a, b in zip(inst_orig, inst_sorted)])
    may_not_compatible_default = dict(
        in1_var=None,
        in2_var=None,
        out_var=None,
        irrep_normalization=False,
        path_normalization='element',
        compile_left_right=True,
        compile_right=False,
        _specialized_code=None,
        _optimize_einsums=None,
    )
    # pop may_not_compatible_defaults
    _check_may_not_compatible(conv_kwargs, may_not_compatible_default)
    module.convolution_cls = cuet.ChannelWiseTensorProduct  # type: ignore
    conv_kwargs.update(**cue_kwargs)
    return module
@cue_needed
def patch_fully_connected(
    module: SelfConnectionIntro,
    group: Literal['SO3', 'O3'],
    **cue_kwargs,
):
    """Patch an un-instantiated self-connection module to use
    cuet.FullyConnectedTensorProduct.

    Converts its irreps to cue irreps, strips e3nn-only normalization
    kwargs (warning when a non-default value is dropped), and returns
    the module.
    """
    assert not module.layer_instantiated
    for attr in ('irreps_in1', 'irreps_in2', 'irreps_out'):
        setattr(module, attr, as_cue_irreps(getattr(module, attr), group))  # type: ignore
    # e3nn-only normalization options that cuet does not accept
    e3nn_only_defaults = dict(
        irrep_normalization=None,
        path_normalization=None,
    )
    # pops the e3nn-only keys in place, warning on non-default values
    _check_may_not_compatible(
        module.fc_tensor_product_kwargs, e3nn_only_defaults
    )
    module.fc_tensor_product_cls = cuet.FullyConnectedTensorProduct  # type: ignore
    module.fc_tensor_product_kwargs.update(**cue_kwargs)
    return module
import math
import torch
import torch.nn as nn
from e3nn.o3 import Irreps, SphericalHarmonics
from e3nn.util.jit import compile_mode
import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType
@compile_mode('script')
class EdgePreprocess(nn.Module):
    """
    preprocessing pos to edge vectors and edge lengths
    currently used in sevenn/scripts/deploy for lammps serial model

    When is_stress is set, a zero-valued strain tensor with
    requires_grad is injected so that stress can later be obtained via
    autograd as d(energy)/d(strain).
    """
    def __init__(self, is_stress: bool):
        super().__init__()
        # controlled by 'AtomGraphSequential'
        self.is_stress = is_stress
        # True: data holds a batch of graphs; False: a single graph
        self._is_batch_data = True
    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        if self._is_batch_data:
            cell = data[KEY.CELL].view(-1, 3, 3)
        else:
            cell = data[KEY.CELL].view(3, 3)
        cell_shift = data[KEY.CELL_SHIFT]
        pos = data[KEY.POS]
        batch = data[KEY.BATCH]  # for deploy, must be defined first
        if self.is_stress:
            if self._is_batch_data:
                num_batch = int(batch.max().cpu().item()) + 1
                # one zero strain per structure; stored in data so callers
                # can take gradients w.r.t. it
                strain = torch.zeros(
                    (num_batch, 3, 3),
                    dtype=pos.dtype,
                    device=pos.device,
                )
                strain.requires_grad_(True)
                data['_strain'] = strain
                # symmetrize: only the symmetric part of strain deforms
                sym_strain = 0.5 * (strain + strain.transpose(-1, -2))
                # apply the (zero) strain to positions and cell so that
                # they enter the autograd graph through it
                pos = pos + torch.bmm(
                    pos.unsqueeze(-2), sym_strain[batch]
                ).squeeze(-2)
                cell = cell + torch.bmm(cell, sym_strain)
            else:
                # single-structure variant of the branch above
                strain = torch.zeros(
                    (3, 3),
                    dtype=pos.dtype,
                    device=pos.device,
                )
                strain.requires_grad_(True)
                data['_strain'] = strain
                sym_strain = 0.5 * (strain + strain.transpose(-1, -2))
                pos = pos + torch.mm(pos, sym_strain)
                cell = cell + torch.mm(cell, sym_strain)
        idx_src = data[KEY.EDGE_IDX][0]
        idx_dst = data[KEY.EDGE_IDX][1]
        # edge vector dst - src, corrected by the periodic cell shift
        edge_vec = pos[idx_dst] - pos[idx_src]
        if self._is_batch_data:
            edge_vec = edge_vec + torch.einsum(
                'ni,nij->nj', cell_shift, cell[batch[idx_src]]
            )
        else:
            edge_vec = edge_vec + torch.einsum(
                'ni,ij->nj', cell_shift, cell.squeeze(0)
            )
        data[KEY.EDGE_VEC] = edge_vec
        data[KEY.EDGE_LENGTH] = torch.linalg.norm(edge_vec, dim=-1)
        return data
class BesselBasis(nn.Module):
    """Radial Bessel basis expansion of edge lengths.

    f : (*, 1) -> (*, bessel_basis_num)

    Each basis function is prefactor * sin(n*pi*r/rc) / r.
    """

    def __init__(
        self,
        cutoff_length: float,
        bessel_basis_num: int = 8,
        trainable_coeff: bool = True,
    ):
        super().__init__()
        self.num_basis = bessel_basis_num
        self.prefactor = 2.0 / cutoff_length
        # frequencies n*pi/rc for n = 1..num_basis
        freqs = [
            n * math.pi / cutoff_length
            for n in range(1, bessel_basis_num + 1)
        ]
        coeffs = torch.FloatTensor(freqs)
        # wrapping in Parameter makes the frequencies learnable
        self.coeffs = nn.Parameter(coeffs) if trainable_coeff else coeffs

    def forward(self, r: torch.Tensor) -> torch.Tensor:
        # unsqueeze to broadcast distances against the basis dimension
        ur = r.unsqueeze(-1)
        return self.prefactor * torch.sin(self.coeffs * ur) / ur
class PolynomialCutoff(nn.Module):
    """Smooth polynomial cutoff envelope.

    f : (*, 1) -> (*, 1)
    https://arxiv.org/pdf/2003.03123.pdf

    f(0) = 1 and f and its first derivatives vanish at the cutoff.
    """

    def __init__(
        self,
        cutoff_length: float,
        poly_cut_p_value: int = 6,
    ):
        super().__init__()
        p = poly_cut_p_value
        self.p = p
        self.cutoff_length = cutoff_length
        # coefficients of the r^p, r^(p+1), r^(p+2) terms
        self.coeff_p0 = (p + 1.0) * (p + 2.0) / 2.0
        self.coeff_p1 = p * (p + 2.0)
        self.coeff_p2 = p * (p + 1.0) / 2.0

    def forward(self, r: torch.Tensor) -> torch.Tensor:
        s = r / self.cutoff_length  # scale to [0, 1] inside the cutoff
        t0 = self.coeff_p0 * torch.pow(s, self.p)
        t1 = self.coeff_p1 * torch.pow(s, self.p + 1.0)
        t2 = self.coeff_p2 * torch.pow(s, self.p + 2.0)
        return 1 - t0 + t1 - t2
class XPLORCutoff(nn.Module):
    """XPLOR smoothing cutoff: 1 below r_on, smooth decay to 0 at r_cut.

    https://hoomd-blue.readthedocs.io/en/latest/module-md-pair.html
    """

    def __init__(
        self,
        cutoff_length: float,
        cutoff_on: float,
    ):
        super().__init__()
        self.r_on = cutoff_on
        self.r_cut = cutoff_length
        # the smoothing window (r_on, r_cut) must be non-empty
        assert self.r_on < self.r_cut

    def forward(self, r: torch.Tensor) -> torch.Tensor:
        r2 = r * r
        ron2 = self.r_on * self.r_on
        rcut2 = self.r_cut * self.r_cut
        # smooth interpolation between 1 at r_on and 0 at r_cut
        smooth = (
            (rcut2 - r2) ** 2
            * (rcut2 + 2 * r2 - 3 * ron2)
            / (rcut2 - ron2) ** 3
        )
        return torch.where(r < self.r_on, 1.0, smooth)
@compile_mode('script')
class SphericalEncoding(nn.Module):
    """Spherical-harmonics encoding of edge vectors up to lmax.

    parity=-1 expects odd vector input ('1x1o'), parity=1 even ('1x1e').
    """

    def __init__(
        self,
        lmax: int,
        parity: int = -1,
        normalization: str = 'component',
        normalize: bool = True,
    ):
        super().__init__()
        self.lmax = lmax
        self.normalization = normalization
        if parity == -1:
            self.irreps_in = Irreps('1x1o')
        else:
            self.irreps_in = Irreps('1x1e')
        self.irreps_out = Irreps.spherical_harmonics(lmax, parity)
        self.sph = SphericalHarmonics(
            self.irreps_out,
            normalize=normalize,
            normalization=normalization,
            irreps_in=self.irreps_in,
        )

    def forward(self, r: torch.Tensor) -> torch.Tensor:
        return self.sph(r)
@compile_mode('script')
class EdgeEmbedding(nn.Module):
    """Edge embedding from the edge vector.

    Stores RadialBasis(|r|) * CutOff(|r|) as the edge embedding and the
    spherical encoding of the direction as the edge attribute.

    f : (N_edge) -> (N_edge, basis_num)
    """

    def __init__(
        self,
        basis_module: nn.Module,
        cutoff_module: nn.Module,
        spherical_module: nn.Module,
    ):
        super().__init__()
        self.basis_function = basis_module
        self.cutoff_function = cutoff_module
        self.spherical = spherical_module

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        edge_vec = data[KEY.EDGE_VEC]
        length = torch.linalg.norm(edge_vec, dim=-1)
        data[KEY.EDGE_LENGTH] = length
        radial = self.basis_function(length)
        envelope = self.cutoff_function(length).unsqueeze(-1)
        data[KEY.EDGE_EMBEDDING] = radial * envelope
        data[KEY.EDGE_ATTR] = self.spherical(edge_vec)
        return data
from typing import Callable, Dict
import torch.nn as nn
from e3nn.nn import Gate
from e3nn.o3 import Irreps
from e3nn.util.jit import compile_mode
import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType
@compile_mode('script')
class EquivariantGate(nn.Module):
    """Equivariant nonlinearity built on e3nn's Gate.

    Scalar (l=0) irreps pass through a scalar activation; each
    non-scalar (l>0) irrep is multiplied by an activated scalar 'gate'
    channel appended to the input.
    """
    def __init__(
        self,
        irreps_x: Irreps,
        act_scalar_dict: Dict[str, Callable],
        act_gate_dict: Dict[str, Callable],
        data_key_x: str = KEY.NODE_FEATURE,
    ):
        """
        Args:
            irreps_x: irreps of the feature to be gated
            act_scalar_dict: parity char ('e'/'o') -> scalar activation
            act_gate_dict: parity char ('e'/'o') -> gate activation
            data_key_x: dict key of the feature to gate
        """
        super().__init__()
        self.key_x = data_key_x
        parity_mapper = {'e': 1, 'o': -1}
        # re-key activation dicts by numeric parity (+1 even / -1 odd)
        act_scalar_dict = {
            parity_mapper[k]: v for k, v in act_scalar_dict.items()
        }
        act_gate_dict = {parity_mapper[k]: v for k, v in act_gate_dict.items()}
        irreps_gated_elem = []
        irreps_scalars_elem = []
        # non scalar irreps > gated / scalar irreps > scalars
        for mul, irreps in irreps_x:
            if irreps.l > 0:
                irreps_gated_elem.append((mul, irreps))
            else:
                irreps_scalars_elem.append((mul, irreps))
        irreps_scalars = Irreps(irreps_scalars_elem)
        irreps_gated = Irreps(irreps_gated_elem)
        # gates are even scalars when 0e is present, odd scalars otherwise
        irreps_gates_parity = 1 if '0e' in irreps_scalars else -1
        # one gate scalar per gated irrep (same multiplicity)
        irreps_gates = Irreps(
            [(mul, (0, irreps_gates_parity)) for mul, _ in irreps_gated]
        )
        act_scalars = [act_scalar_dict[p] for _, (_, p) in irreps_scalars]
        act_gates = [act_gate_dict[p] for _, (_, p) in irreps_gates]
        self.gate = Gate(
            irreps_scalars, act_scalars, irreps_gates, act_gates, irreps_gated
        )
    def get_gate_irreps_in(self):
        """
        user must call this function to get proper irreps in for forward
        (the input must include the extra gate scalars expected by Gate)
        """
        return self.gate.irreps_in
    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        data[self.key_x] = self.gate(data[self.key_x])
        return data
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment