Commit 2409a22f authored by fanding2000

Format fix. More options in readme

parent ce29afea
import math

import torch


@torch.jit.script
def ShiftedSoftPlus(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.softplus(x) - math.log(2.0)
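
# Quick check (illustrative sketch, not part of the original commit): the log(2)
# shift makes the activation pass through the origin, since softplus(0) = log(2).
if __name__ == '__main__':
    assert torch.allclose(
        ShiftedSoftPlus(torch.zeros(3)), torch.zeros(3), atol=1e-6
    )
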
from typing import List

import torch
import torch.nn as nn
from e3nn.nn import FullyConnectedNet
from e3nn.o3 import Irreps, TensorProduct
from e3nn.util.jit import compile_mode

import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType

from .activation import ShiftedSoftPlus
from .util import broadcast


def message_gather(
    node_features: torch.Tensor,
    edge_dst: torch.Tensor,
    message: torch.Tensor
):
    index = broadcast(edge_dst, message, 0)
    out_shape = [len(node_features)] + list(message.shape[1:])
    out = torch.zeros(
        out_shape,
        dtype=node_features.dtype,
        device=node_features.device
    )
    out.scatter_reduce_(0, index, message, reduce='sum')
    return out
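
# Minimal sketch of what message_gather computes (assumed toy numbers, not part of
# the original commit): per-edge messages are scatter-summed onto their destination
# nodes, i.e. out[i] = sum of message[e] over edges e with edge_dst[e] == i.
if __name__ == '__main__':
    _nodes = torch.zeros(3, 2)                       # 3 nodes, feature dim 2
    _edge_dst = torch.tensor([0, 0, 2])              # two edges into node 0, one into node 2
    _msg = torch.tensor([[1.0, 1.0], [2.0, 2.0], [5.0, 5.0]])
    print(message_gather(_nodes, _edge_dst, _msg))   # [[3, 3], [0, 0], [5, 5]]
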
@compile_mode('script')
class IrrepsConvolution(nn.Module):
    """
    convolution of (fig 2.b), comm. in LAMMPS
    """

    def __init__(
        self,
        irreps_x: Irreps,
        irreps_filter: Irreps,
        irreps_out: Irreps,
        weight_layer_input_to_hidden: List[int],
        weight_layer_act=ShiftedSoftPlus,
        denominator: float = 1.0,
        train_denominator: bool = False,
        data_key_x: str = KEY.NODE_FEATURE,
        data_key_filter: str = KEY.EDGE_ATTR,
        data_key_weight_input: str = KEY.EDGE_EMBEDDING,
        data_key_edge_idx: str = KEY.EDGE_IDX,
        lazy_layer_instantiate: bool = True,
        is_parallel: bool = False,
    ):
        super().__init__()
        self.denominator = nn.Parameter(
            torch.FloatTensor([denominator]), requires_grad=train_denominator
        )
        self.key_x = data_key_x
        self.key_filter = data_key_filter
        self.key_weight_input = data_key_weight_input
        self.key_edge_idx = data_key_edge_idx
        self.is_parallel = is_parallel

        instructions = []
        irreps_mid = []
        weight_numel = 0
        for i, (mul_x, ir_x) in enumerate(irreps_x):
            for j, (_, ir_filter) in enumerate(irreps_filter):
                for ir_out in ir_x * ir_filter:
                    if ir_out in irreps_out:  # here we drop l > lmax
                        k = len(irreps_mid)
                        weight_numel += mul_x * 1  # path shape
                        irreps_mid.append((mul_x, ir_out))
                        instructions.append((i, j, k, 'uvu', True))
        irreps_mid = Irreps(irreps_mid)
        irreps_mid, p, _ = irreps_mid.sort()  # type: ignore

        instructions = [
            (i_in1, i_in2, p[i_out], mode, train)
            for i_in1, i_in2, i_out, mode, train in instructions
        ]
        # From v0.11.x, to be compatible with cuEquivariance
        self._instructions_before_sort = instructions
        instructions = sorted(instructions, key=lambda x: x[2])

        self.convolution_kwargs = dict(
            irreps_in1=irreps_x,
            irreps_in2=irreps_filter,
            irreps_out=irreps_mid,
            instructions=instructions,
            shared_weights=False,
            internal_weights=False,
        )
        self.weight_nn_kwargs = dict(
            hs=weight_layer_input_to_hidden + [weight_numel],
            act=weight_layer_act
        )

        self.convolution = None
        self.weight_nn = None
        self.layer_instantiated = False
        self.convolution_cls = TensorProduct
        self.weight_nn_cls = FullyConnectedNet
        if not lazy_layer_instantiate:
            self.instantiate()

        self._comm_size = irreps_x.dim  # used in parallel

    def instantiate(self):
        if self.convolution is not None:
            raise ValueError('Convolution layer already exists')
        if self.weight_nn is not None:
            raise ValueError('Weight_nn layer already exists')
        self.convolution = self.convolution_cls(**self.convolution_kwargs)
        self.weight_nn = self.weight_nn_cls(**self.weight_nn_kwargs)
        self.layer_instantiated = True

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        assert self.convolution is not None, 'Convolution is not instantiated'
        assert self.weight_nn is not None, 'Weight_nn is not instantiated'
        weight = self.weight_nn(data[self.key_weight_input])
        x = data[self.key_x]
        if self.is_parallel:
            x = torch.cat([x, data[KEY.NODE_FEATURE_GHOST]])

        # note that 1 -> src, 0 -> dst
        edge_src = data[self.key_edge_idx][1]
        edge_dst = data[self.key_edge_idx][0]

        message = self.convolution(x[edge_src], data[self.key_filter], weight)

        x = message_gather(x, edge_dst, message)
        x = x.div(self.denominator)
        if self.is_parallel:
            x = torch.tensor_split(x, data[KEY.NLOCAL])[0]
        data[self.key_x] = x
        return data
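
# Construction sketch (assumed toy irreps and sizes, not part of the original commit):
# the nested loops in __init__ enumerate 'uvu' tensor-product paths and keep only
# outputs whose irrep appears in irreps_out; eager instantiation builds the e3nn
# TensorProduct and the radial weight MLP.
if __name__ == '__main__':
    _conv = IrrepsConvolution(
        irreps_x=Irreps('8x0e + 4x1o'),
        irreps_filter=Irreps('1x0e + 1x1o'),     # e.g. spherical harmonics up to l=1
        irreps_out=Irreps('8x0e + 4x1o'),
        weight_layer_input_to_hidden=[8, 16],    # radial MLP sizes (assumed)
        denominator=10.0,
        lazy_layer_instantiate=False,
    )
    print(_conv.convolution)  # TensorProduct holding only the kept 'uvu' paths
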
import itertools
import warnings
from typing import Iterator, Literal, Union

import e3nn.o3 as o3
import numpy as np

from .convolution import IrrepsConvolution
from .linear import IrrepsLinear
from .self_connection import SelfConnectionIntro, SelfConnectionLinearIntro

try:
    import cuequivariance as cue
    import cuequivariance_torch as cuet

    _CUE_AVAILABLE = True

    # Obtained from MACE
    class O3_e3nn(cue.O3):
        def __mul__(  # type: ignore
            rep1: 'O3_e3nn', rep2: 'O3_e3nn'
        ) -> Iterator['O3_e3nn']:
            return [  # type: ignore
                O3_e3nn(l=ir.l, p=ir.p) for ir in cue.O3.__mul__(rep1, rep2)
            ]

        @classmethod
        def clebsch_gordan(  # type: ignore
            cls, rep1: 'O3_e3nn', rep2: 'O3_e3nn', rep3: 'O3_e3nn'
        ) -> np.ndarray:
            rep1, rep2, rep3 = cls._from(rep1), cls._from(rep2), cls._from(rep3)
            if rep1.p * rep2.p == rep3.p:
                return o3.wigner_3j(rep1.l, rep2.l, rep3.l).numpy()[None] * np.sqrt(
                    rep3.dim
                )
            return np.zeros((0, rep1.dim, rep2.dim, rep3.dim))

        def __lt__(  # type: ignore
            rep1: 'O3_e3nn', rep2: 'O3_e3nn'
        ) -> bool:
            rep2 = rep1._from(rep2)  # type: ignore
            return (rep1.l, rep1.p) < (rep2.l, rep2.p)

        @classmethod
        def iterator(cls) -> Iterator['O3_e3nn']:
            for l in itertools.count(0):
                yield O3_e3nn(l=l, p=1 * (-1) ** l)
                yield O3_e3nn(l=l, p=-1 * (-1) ** l)

except ImportError:
    _CUE_AVAILABLE = False


def is_cue_available():
    return _CUE_AVAILABLE


def cue_needed(func):
    def wrapper(*args, **kwargs):
        if is_cue_available():
            return func(*args, **kwargs)
        else:
            raise ImportError('cue is not available')
    return wrapper


def _check_may_not_compatible(orig_kwargs, defaults):
    for k, v in defaults.items():
        v_given = orig_kwargs.pop(k, v)
        if v_given != v:
            warnings.warn(f'{k}: {v} is ignored to use cuEquivariance')


def is_cue_cuda_available_model(config):
    if config.get('use_bias_in_linear', False):
        warnings.warn('Bias in linear can not be used with cueq, fallback to e3nn')
        return False
    else:
        return True


@cue_needed
def as_cue_irreps(irreps: o3.Irreps, group: Literal['SO3', 'O3']):
    """Convert e3nn irreps to given group's cue irreps"""
    if group == 'SO3':
        assert all(irrep.ir.p == 1 for irrep in irreps)
        return cue.Irreps('SO3', str(irreps).replace('e', ''))  # type: ignore
    elif group == 'O3':
        return cue.Irreps(O3_e3nn, str(irreps))  # type: ignore
    else:
        raise ValueError(f'Unknown group: {group}')


@cue_needed
def patch_linear(
    module: Union[IrrepsLinear, SelfConnectionLinearIntro],
    group: Literal['SO3', 'O3'],
    **cue_kwargs,
):
    assert not module.layer_instantiated
    module.irreps_in = as_cue_irreps(module.irreps_in, group)  # type: ignore
    module.irreps_out = as_cue_irreps(module.irreps_out, group)  # type: ignore
    orig_kwargs = module.linear_kwargs
    may_not_compatible_default = dict(
        f_in=None,
        f_out=None,
        instructions=None,
        biases=False,
        path_normalization='element',
        _optimize_einsums=None,
    )
    # pop may_not_compatible_defaults
    _check_may_not_compatible(orig_kwargs, may_not_compatible_default)
    module.linear_cls = cuet.Linear  # type: ignore
    orig_kwargs.update(**cue_kwargs)
    return module


@cue_needed
def patch_convolution(
    module: IrrepsConvolution,
    group: Literal['SO3', 'O3'],
    **cue_kwargs,
):
    assert not module.layer_instantiated
    # conv_kwargs will be patched in place
    conv_kwargs = module.convolution_kwargs
    conv_kwargs.update(
        dict(
            irreps_in1=as_cue_irreps(conv_kwargs.get('irreps_in1'), group),
            irreps_in2=as_cue_irreps(conv_kwargs.get('irreps_in2'), group),
            filter_irreps_out=as_cue_irreps(conv_kwargs.pop('irreps_out'), group),
        )
    )
    inst_orig = conv_kwargs.pop('instructions')
    inst_sorted = sorted(inst_orig, key=lambda x: x[2])
    assert all([a == b for a, b in zip(inst_orig, inst_sorted)])
    may_not_compatible_default = dict(
        in1_var=None,
        in2_var=None,
        out_var=None,
        irrep_normalization=False,
        path_normalization='element',
        compile_left_right=True,
        compile_right=False,
        _specialized_code=None,
        _optimize_einsums=None,
    )
    # pop may_not_compatible_defaults
    _check_may_not_compatible(conv_kwargs, may_not_compatible_default)
    module.convolution_cls = cuet.ChannelWiseTensorProduct  # type: ignore
    conv_kwargs.update(**cue_kwargs)
    return module


@cue_needed
def patch_fully_connected(
    module: SelfConnectionIntro,
    group: Literal['SO3', 'O3'],
    **cue_kwargs,
):
    assert not module.layer_instantiated
    module.irreps_in1 = as_cue_irreps(module.irreps_in1, group)  # type: ignore
    module.irreps_in2 = as_cue_irreps(module.irreps_in2, group)  # type: ignore
    module.irreps_out = as_cue_irreps(module.irreps_out, group)  # type: ignore
    may_not_compatible_default = dict(
        irrep_normalization=None,
        path_normalization=None,
    )
    # pop may_not_compatible_defaults
    _check_may_not_compatible(
        module.fc_tensor_product_kwargs, may_not_compatible_default
    )
    module.fc_tensor_product_cls = cuet.FullyConnectedTensorProduct  # type: ignore
    module.fc_tensor_product_kwargs.update(**cue_kwargs)
    return module
import math

import torch
import torch.nn as nn
from e3nn.o3 import Irreps, SphericalHarmonics
from e3nn.util.jit import compile_mode

import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType


@compile_mode('script')
class EdgePreprocess(nn.Module):
    """
    preprocessing pos to edge vectors and edge lengths
    currently used in sevenn/scripts/deploy for lammps serial model
    """

    def __init__(self, is_stress: bool):
        super().__init__()
        # controlled by 'AtomGraphSequential'
        self.is_stress = is_stress
        self._is_batch_data = True

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        if self._is_batch_data:
            cell = data[KEY.CELL].view(-1, 3, 3)
        else:
            cell = data[KEY.CELL].view(3, 3)
        cell_shift = data[KEY.CELL_SHIFT]
        pos = data[KEY.POS]

        batch = data[KEY.BATCH]  # for deploy, must be defined first
        if self.is_stress:
            if self._is_batch_data:
                num_batch = int(batch.max().cpu().item()) + 1
                strain = torch.zeros(
                    (num_batch, 3, 3),
                    dtype=pos.dtype,
                    device=pos.device,
                )
                strain.requires_grad_(True)
                data['_strain'] = strain

                sym_strain = 0.5 * (strain + strain.transpose(-1, -2))
                pos = pos + torch.bmm(
                    pos.unsqueeze(-2), sym_strain[batch]
                ).squeeze(-2)
                cell = cell + torch.bmm(cell, sym_strain)
            else:
                strain = torch.zeros(
                    (3, 3),
                    dtype=pos.dtype,
                    device=pos.device,
                )
                strain.requires_grad_(True)
                data['_strain'] = strain

                sym_strain = 0.5 * (strain + strain.transpose(-1, -2))
                pos = pos + torch.mm(pos, sym_strain)
                cell = cell + torch.mm(cell, sym_strain)

        idx_src = data[KEY.EDGE_IDX][0]
        idx_dst = data[KEY.EDGE_IDX][1]

        edge_vec = pos[idx_dst] - pos[idx_src]
        if self._is_batch_data:
            edge_vec = edge_vec + torch.einsum(
                'ni,nij->nj', cell_shift, cell[batch[idx_src]]
            )
        else:
            edge_vec = edge_vec + torch.einsum(
                'ni,ij->nj', cell_shift, cell.squeeze(0)
            )
        data[KEY.EDGE_VEC] = edge_vec
        data[KEY.EDGE_LENGTH] = torch.linalg.norm(edge_vec, dim=-1)
        return data


class BesselBasis(nn.Module):
    """
    f : (*, 1) -> (*, bessel_basis_num)
    """

    def __init__(
        self,
        cutoff_length: float,
        bessel_basis_num: int = 8,
        trainable_coeff: bool = True,
    ):
        super().__init__()
        self.num_basis = bessel_basis_num
        self.prefactor = 2.0 / cutoff_length
        self.coeffs = torch.FloatTensor([
            n * math.pi / cutoff_length for n in range(1, bessel_basis_num + 1)
        ])
        if trainable_coeff:
            self.coeffs = nn.Parameter(self.coeffs)

    def forward(self, r: torch.Tensor) -> torch.Tensor:
        ur = r.unsqueeze(-1)  # to fit dimension
        return self.prefactor * torch.sin(self.coeffs * ur) / ur
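
# Shape sketch (illustrative values, not part of the original commit): distances of
# shape (E,) become (E, bessel_basis_num) smooth radial features,
# (2 / r_c) * sin(n * pi * r / r_c) / r for n = 1..bessel_basis_num.
if __name__ == '__main__':
    _rbf = BesselBasis(cutoff_length=5.0, bessel_basis_num=8)
    _r = torch.linspace(0.5, 4.5, 10)
    print(_rbf(_r).shape)  # torch.Size([10, 8])
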

class PolynomialCutoff(nn.Module):
    """
    f : (*, 1) -> (*, 1)
    https://arxiv.org/pdf/2003.03123.pdf
    """

    def __init__(
        self,
        cutoff_length: float,
        poly_cut_p_value: int = 6,
    ):
        super().__init__()
        p = poly_cut_p_value
        self.cutoff_length = cutoff_length
        self.p = p
        self.coeff_p0 = (p + 1.0) * (p + 2.0) / 2.0
        self.coeff_p1 = p * (p + 2.0)
        self.coeff_p2 = p * (p + 1.0) / 2.0

    def forward(self, r: torch.Tensor) -> torch.Tensor:
        r = r / self.cutoff_length
        return (
            1
            - self.coeff_p0 * torch.pow(r, self.p)
            + self.coeff_p1 * torch.pow(r, self.p + 1.0)
            - self.coeff_p2 * torch.pow(r, self.p + 2.0)
        )
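
# Explanatory note (added comment, not in the original commit): with u = r / cutoff_length
# and p = poly_cut_p_value, the expression above is the polynomial envelope of the cited
# paper (DimeNet),
#   f(u) = 1 - (p+1)(p+2)/2 * u**p + p(p+2) * u**(p+1) - p(p+1)/2 * u**(p+2),
# which gives f(0) = 1 and vanishing value, first, and second derivatives at u = 1.
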

class XPLORCutoff(nn.Module):
    """
    https://hoomd-blue.readthedocs.io/en/latest/module-md-pair.html
    """

    def __init__(
        self,
        cutoff_length: float,
        cutoff_on: float,
    ):
        super().__init__()
        self.r_on = cutoff_on
        self.r_cut = cutoff_length
        assert self.r_on < self.r_cut

    def forward(self, r: torch.Tensor) -> torch.Tensor:
        r_sq = r * r
        r_on_sq = self.r_on * self.r_on
        r_cut_sq = self.r_cut * self.r_cut
        return torch.where(
            r < self.r_on,
            1.0,
            (r_cut_sq - r_sq) ** 2
            * (r_cut_sq + 2 * r_sq - 3 * r_on_sq)
            / (r_cut_sq - r_on_sq) ** 3,
        )


@compile_mode('script')
class SphericalEncoding(nn.Module):
    def __init__(
        self,
        lmax: int,
        parity: int = -1,
        normalization: str = 'component',
        normalize: bool = True,
    ):
        super().__init__()
        self.lmax = lmax
        self.normalization = normalization
        self.irreps_in = Irreps('1x1o') if parity == -1 else Irreps('1x1e')
        self.irreps_out = Irreps.spherical_harmonics(lmax, parity)
        self.sph = SphericalHarmonics(
            self.irreps_out,
            normalize=normalize,
            normalization=normalization,
            irreps_in=self.irreps_in,
        )

    def forward(self, r: torch.Tensor) -> torch.Tensor:
        return self.sph(r)


@compile_mode('script')
class EdgeEmbedding(nn.Module):
    """
    embedding layer of |r| by
    RadialBasis(|r|)*CutOff(|r|)
    f : (N_edge) -> (N_edge, basis_num)
    """

    def __init__(
        self,
        basis_module: nn.Module,
        cutoff_module: nn.Module,
        spherical_module: nn.Module,
    ):
        super().__init__()
        self.basis_function = basis_module
        self.cutoff_function = cutoff_module
        self.spherical = spherical_module

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        rvec = data[KEY.EDGE_VEC]
        r = torch.linalg.norm(data[KEY.EDGE_VEC], dim=-1)
        data[KEY.EDGE_LENGTH] = r
        data[KEY.EDGE_EMBEDDING] = self.basis_function(
            r
        ) * self.cutoff_function(r).unsqueeze(-1)
        data[KEY.EDGE_ATTR] = self.spherical(rvec)
        return data
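
# Composition sketch (assumed hyperparameters, not part of the original commit): the
# edge embedding multiplies a radial basis by a smooth cutoff of |r| and, separately,
# encodes edge directions with spherical harmonics.
if __name__ == '__main__':
    _cutoff = 5.0
    _embed = EdgeEmbedding(
        basis_module=BesselBasis(_cutoff, bessel_basis_num=8),
        cutoff_module=PolynomialCutoff(_cutoff, poly_cut_p_value=6),
        spherical_module=SphericalEncoding(lmax=2),
    )
    _data = {KEY.EDGE_VEC: torch.randn(16, 3)}
    _data = _embed(_data)
    print(_data[KEY.EDGE_EMBEDDING].shape)  # torch.Size([16, 8])
    print(_data[KEY.EDGE_ATTR].shape)       # torch.Size([16, 9]); 1 + 3 + 5 for lmax = 2
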
from typing import Callable, Dict

import torch.nn as nn
from e3nn.nn import Gate
from e3nn.o3 import Irreps
from e3nn.util.jit import compile_mode

import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType


@compile_mode('script')
class EquivariantGate(nn.Module):
    def __init__(
        self,
        irreps_x: Irreps,
        act_scalar_dict: Dict[int, Callable],
        act_gate_dict: Dict[int, Callable],
        data_key_x: str = KEY.NODE_FEATURE,
    ):
        super().__init__()
        self.key_x = data_key_x

        parity_mapper = {'e': 1, 'o': -1}
        act_scalar_dict = {
            parity_mapper[k]: v for k, v in act_scalar_dict.items()
        }
        act_gate_dict = {parity_mapper[k]: v for k, v in act_gate_dict.items()}

        irreps_gated_elem = []
        irreps_scalars_elem = []
        # non-scalar irreps -> gated / scalar irreps -> scalars
        for mul, irreps in irreps_x:
            if irreps.l > 0:
                irreps_gated_elem.append((mul, irreps))
            else:
                irreps_scalars_elem.append((mul, irreps))
        irreps_scalars = Irreps(irreps_scalars_elem)
        irreps_gated = Irreps(irreps_gated_elem)

        irreps_gates_parity = 1 if '0e' in irreps_scalars else -1
        irreps_gates = Irreps(
            [(mul, (0, irreps_gates_parity)) for mul, _ in irreps_gated]
        )

        act_scalars = [act_scalar_dict[p] for _, (_, p) in irreps_scalars]
        act_gates = [act_gate_dict[p] for _, (_, p) in irreps_gates]

        self.gate = Gate(
            irreps_scalars, act_scalars, irreps_gates, act_gates, irreps_gated
        )

    def get_gate_irreps_in(self):
        """
        The caller must use this function to get the proper input irreps for forward.
        """
        return self.gate.irreps_in

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        data[self.key_x] = self.gate(data[self.key_x])
        return data
import torch
import torch.nn as nn
from e3nn.util.jit import compile_mode

import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType

from .util import broadcast


@compile_mode('script')
class ForceOutput(nn.Module):
    """
    works when pos.requires_grad_ is True
    """

    def __init__(
        self,
        data_key_pos: str = KEY.POS,
        data_key_energy: str = KEY.PRED_TOTAL_ENERGY,
        data_key_force: str = KEY.PRED_FORCE,
    ):
        super().__init__()
        self.key_pos = data_key_pos
        self.key_energy = data_key_energy
        self.key_force = data_key_force

    def get_grad_key(self):
        return self.key_pos

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        pos_tensor = [data[self.key_pos]]
        energy = [(data[self.key_energy]).sum()]

        # `materialize_grads` is not supported in older versions of pytorch,
        # and the model can not be deployed when it is used.
        # But not using it causes problems in
        # force/stress inference for sparse systems.
        # TODO: use it only in sevennet_calculator?
        grad = torch.autograd.grad(
            energy,
            pos_tensor,
            create_graph=self.training,
            allow_unused=True,
            # materialize_grads=True,
        )[0]

        # For torchscript
        if grad is not None:
            data[self.key_force] = torch.neg(grad)
        return data


@compile_mode('script')
class ForceStressOutput(nn.Module):
    """
    Compute stress and force from positions.
    Used in serial torchscript models
    """

    def __init__(
        self,
        data_key_pos: str = KEY.POS,
        data_key_energy: str = KEY.PRED_TOTAL_ENERGY,
        data_key_force: str = KEY.PRED_FORCE,
        data_key_stress: str = KEY.PRED_STRESS,
        data_key_cell_volume: str = KEY.CELL_VOLUME,
    ):
        super().__init__()
        self.key_pos = data_key_pos
        self.key_energy = data_key_energy
        self.key_force = data_key_force
        self.key_stress = data_key_stress
        self.key_cell_volume = data_key_cell_volume
        self._is_batch_data = True

    def get_grad_key(self):
        return self.key_pos

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        pos_tensor = data[self.key_pos]
        energy = [(data[self.key_energy]).sum()]

        # `materialize_grads` is not supported in older versions of pytorch,
        # and the model can not be deployed when it is used.
        # But not using it causes problems in
        # force/stress inference for sparse systems.
        # TODO: use it only in sevennet_calculator?
        grad = torch.autograd.grad(
            energy,
            [pos_tensor, data['_strain']],
            create_graph=self.training,
            allow_unused=True,
            # materialize_grads=True,
        )

        # ensure grad is not Optional[Tensor]
        fgrad = grad[0]
        if fgrad is not None:
            data[self.key_force] = torch.neg(fgrad)

        sgrad = grad[1]
        volume = data[self.key_cell_volume]
        vlim = 1e-3  # for cell volume = 0 (non-PBC structures)
        if self._is_batch_data:
            volume[volume < vlim] = vlim
        elif volume < vlim:
            volume = torch.tensor(vlim)

        if sgrad is not None:
            if self._is_batch_data:
                stress = sgrad / volume.view(-1, 1, 1)
                stress = torch.neg(stress)
                virial_stress = torch.vstack((
                    stress[:, 0, 0],
                    stress[:, 1, 1],
                    stress[:, 2, 2],
                    stress[:, 0, 1],
                    stress[:, 1, 2],
                    stress[:, 0, 2],
                ))
                data[self.key_stress] = virial_stress.transpose(0, 1)
            else:
                stress = sgrad / volume
                stress = torch.neg(stress)
                virial_stress = torch.stack((
                    stress[0, 0],
                    stress[1, 1],
                    stress[2, 2],
                    stress[0, 1],
                    stress[1, 2],
                    stress[0, 2],
                ))
                data[self.key_stress] = virial_stress

        return data


@compile_mode('script')
class ForceStressOutputFromEdge(nn.Module):
    """
    Compute stress and force from edge.
    Used in parallel torchscript models, and training
    """

    def __init__(
        self,
        data_key_edge: str = KEY.EDGE_VEC,
        data_key_edge_idx: str = KEY.EDGE_IDX,
        data_key_energy: str = KEY.PRED_TOTAL_ENERGY,
        data_key_force: str = KEY.PRED_FORCE,
        data_key_stress: str = KEY.PRED_STRESS,
        data_key_cell_volume: str = KEY.CELL_VOLUME,
    ):
        super().__init__()
        self.key_edge = data_key_edge
        self.key_edge_idx = data_key_edge_idx
        self.key_energy = data_key_energy
        self.key_force = data_key_force
        self.key_stress = data_key_stress
        self.key_cell_volume = data_key_cell_volume
        self._is_batch_data = True

    def get_grad_key(self):
        return self.key_edge

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        tot_num = torch.sum(data[KEY.NUM_ATOMS])  # ? item?
        rij = data[self.key_edge]
        energy = [(data[self.key_energy]).sum()]
        edge_idx = data[self.key_edge_idx]

        grad = torch.autograd.grad(
            energy,
            [rij],
            create_graph=self.training,
            allow_unused=True
        )

        # ensure grad is not Optional[Tensor]
        fij = grad[0]

        if fij is not None:
            # compute force
            pf = torch.zeros(tot_num, 3, dtype=fij.dtype, device=fij.device)
            nf = torch.zeros(tot_num, 3, dtype=fij.dtype, device=fij.device)
            _edge_src = broadcast(edge_idx[0], fij, 0)
            _edge_dst = broadcast(edge_idx[1], fij, 0)
            pf.scatter_reduce_(0, _edge_src, fij, reduce='sum')
            nf.scatter_reduce_(0, _edge_dst, fij, reduce='sum')
            data[self.key_force] = pf - nf

            # compute virial
            diag = rij * fij
            s12 = rij[..., 0] * fij[..., 1]
            s23 = rij[..., 1] * fij[..., 2]
            s31 = rij[..., 2] * fij[..., 0]
            # cat along the last dimension
            _virial = torch.cat([
                diag,
                s12.unsqueeze(-1),
                s23.unsqueeze(-1),
                s31.unsqueeze(-1)
            ], dim=-1)

            _s = torch.zeros(tot_num, 6, dtype=fij.dtype, device=fij.device)
            _edge_dst6 = broadcast(edge_idx[1], _virial, 0)
            _s.scatter_reduce_(0, _edge_dst6, _virial, reduce='sum')

            if self._is_batch_data:
                batch = data[KEY.BATCH]  # for deploy, must be defined first
                nbatch = int(batch.max().cpu().item()) + 1
                sout = torch.zeros(
                    (nbatch, 6), dtype=_virial.dtype, device=_virial.device
                )
                _batch = broadcast(batch, _s, 0)
                sout.scatter_reduce_(0, _batch, _s, reduce='sum')
            else:
                sout = torch.sum(_s, dim=0)

            data[self.key_stress] = \
                torch.neg(sout) / data[self.key_cell_volume].unsqueeze(-1)

        return data
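
# Note on the virial above (explanatory comment, not in the original commit): per edge,
# the six accumulated components are r_x*f_x, r_y*f_y, r_z*f_z, r_x*f_y, r_y*f_z, r_z*f_x,
# scatter-summed onto the destination atoms and then over the batch; negating and dividing
# by the cell volume yields the stress in the same (xx, yy, zz, xy, yz, zx) component order
# used by ForceStressOutput.
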
from typing import Callable, List, Tuple

from e3nn.o3 import Irreps

import sevenn._keys as KEY

from .convolution import IrrepsConvolution
from .equivariant_gate import EquivariantGate
from .linear import IrrepsLinear


def NequIP_interaction_block(
    irreps_x: Irreps,
    irreps_filter: Irreps,
    irreps_out_tp: Irreps,
    irreps_out: Irreps,
    weight_nn_layers: List[int],
    conv_denominator: float,
    train_conv_denominator: bool,
    self_connection_pair: Tuple[Callable, Callable],
    act_scalar: Callable,
    act_gate: Callable,
    act_radial: Callable,
    bias_in_linear: bool,
    num_species: int,
    t: int,  # interaction layer index
    data_key_x: str = KEY.NODE_FEATURE,
    data_key_weight_input: str = KEY.EDGE_EMBEDDING,
    parallel: bool = False,
    **conv_kwargs,
):
    block = {}
    irreps_node_attr = Irreps(f'{num_species}x0e')
    sc_intro, sc_outro = self_connection_pair

    gate_layer = EquivariantGate(irreps_out, act_scalar, act_gate)
    irreps_for_gate_in = gate_layer.get_gate_irreps_in()

    block[f'{t}_self_connection_intro'] = sc_intro(
        irreps_x,
        irreps_operand=irreps_node_attr,
        irreps_out=irreps_for_gate_in,
    )

    block[f'{t}_self_interaction_1'] = IrrepsLinear(
        irreps_x, irreps_x,
        data_key_in=data_key_x,
        biases=bias_in_linear,
    )

    # convolution part; l > lmax is dropped as defined in irreps_out
    block[f'{t}_convolution'] = IrrepsConvolution(
        irreps_x=irreps_x,
        irreps_filter=irreps_filter,
        irreps_out=irreps_out_tp,
        data_key_weight_input=data_key_weight_input,
        weight_layer_input_to_hidden=weight_nn_layers,
        weight_layer_act=act_radial,
        denominator=conv_denominator,
        train_denominator=train_conv_denominator,
        is_parallel=parallel,
        **conv_kwargs,
    )

    # irreps of x increase to gate_irreps_in
    block[f'{t}_self_interaction_2'] = IrrepsLinear(
        irreps_out_tp,
        irreps_for_gate_in,
        data_key_in=data_key_x,
        biases=bias_in_linear,
    )

    block[f'{t}_self_connection_outro'] = sc_outro()
    block[f'{t}_equivariant_gate'] = gate_layer

    return block
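
# Reading note (explanatory comment, not in the original commit): for layer index t, the
# returned dict lists, in order, self_connection_intro -> self_interaction_1 (linear) ->
# convolution -> self_interaction_2 (linear) -> self_connection_outro -> equivariant_gate,
# i.e. a NequIP-style interaction with a self-connection (skip) path around the convolution
# and a gated nonlinearity at the end.
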
from typing import Callable, List, Optional

import torch
import torch.nn as nn
from e3nn.nn import FullyConnectedNet
from e3nn.o3 import Irreps, Linear
from e3nn.util.jit import compile_mode

import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType


@compile_mode('script')
class IrrepsLinear(nn.Module):
    """
    wrapper class of e3nn Linear to operate on AtomGraphData
    """

    def __init__(
        self,
        irreps_in: Irreps,
        irreps_out: Irreps,
        data_key_in: str,
        data_key_out: Optional[str] = None,
        data_key_modal_attr: str = KEY.MODAL_ATTR,
        num_modalities: int = 0,
        lazy_layer_instantiate: bool = True,
        **linear_kwargs,
    ):
        super().__init__()
        self.key_input = data_key_in
        if data_key_out is None:
            self.key_output = data_key_in
        else:
            self.key_output = data_key_out
        self.key_modal_attr = data_key_modal_attr

        self._irreps_in_wo_modal = irreps_in
        self.irreps_in = irreps_in
        self.irreps_out = irreps_out
        self.linear_kwargs = linear_kwargs
        self.linear = None
        self.layer_instantiated = False
        self.num_modalities = num_modalities
        self._is_batch_data = True

        # use getter/setter
        self.linear_cls = Linear

        if num_modalities > 1:  # in case of multi-modal
            self.set_num_modalities(num_modalities)

        if not lazy_layer_instantiate:
            self.instantiate()

    def instantiate(self):
        if self.linear is not None:
            raise ValueError('Linear layer already exists')
        self.linear = self.linear_cls(
            self.irreps_in, self.irreps_out, **self.linear_kwargs
        )
        self.layer_instantiated = True

    def set_num_modalities(self, num_modalities):
        if self.layer_instantiated:
            raise ValueError('Layer already instantiated, can not change modalities')
        irreps_in = self._irreps_in_wo_modal + Irreps(f'{num_modalities}x0e')
        self.num_modalities = num_modalities
        self.irreps_in = irreps_in

    def _patch_modal_to_data(self, data: AtomGraphDataType) -> AtomGraphDataType:
        if self._is_batch_data:
            batch = data[KEY.BATCH]
            batch_modality_onehot = data[self.key_modal_attr].reshape(
                -1, self.num_modalities
            )
            batch_modality_onehot = batch_modality_onehot.type(
                data[self.key_input].dtype
            )
            data[self.key_input] = torch.cat(
                [data[self.key_input], batch_modality_onehot[batch]], dim=1
            )
        else:
            modality_onehot = data[self.key_modal_attr].expand(
                len(data[self.key_input]), -1
            )
            modality_onehot = modality_onehot.type(data[self.key_input].dtype)
            data[self.key_input] = torch.cat(
                [data[self.key_input], modality_onehot], dim=1
            )
        return data

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        assert self.linear is not None, 'Layer is not instantiated'
        if self.num_modalities > 1:
            data = self._patch_modal_to_data(data)
        data[self.key_output] = self.linear(data[self.key_input])
        return data


@compile_mode('script')
class AtomReduce(nn.Module):
    """
    atomic energy -> total energy
    the constant is multiplied into the reduced output
    """

    def __init__(
        self,
        data_key_in: str,
        data_key_out: str,
        reduce: str = 'sum',
        constant: float = 1.0,
    ):
        super().__init__()
        self.key_input = data_key_in
        self.key_output = data_key_out
        self.constant = constant
        self.reduce = reduce

        # controlled by the uppermost wrapper, 'AtomGraphSequential'
        self._is_batch_data = True

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        if self._is_batch_data:
            src = data[self.key_input].squeeze(1)
            size = int(data[KEY.BATCH].max()) + 1
            output = torch.zeros(
                (size),
                dtype=src.dtype,
                device=src.device,
            )
            output.scatter_reduce_(0, data[KEY.BATCH], src, reduce='sum')
            data[self.key_output] = output * self.constant
        else:
            data[self.key_output] = torch.sum(data[self.key_input]) * self.constant

        return data


@compile_mode('script')
class FCN_e3nn(nn.Module):
    """
    wrapper class of e3nn FullyConnectedNet
    """

    def __init__(
        self,
        irreps_in: Irreps,  # confirm it is scalar & input size
        dim_out: int,
        hidden_neurons: List[int],
        activation: Callable,
        data_key_in: str,
        data_key_out: Optional[str] = None,
        **e3nn_kwargs,
    ):
        super().__init__()
        self.key_input = data_key_in
        self.irreps_in = irreps_in
        if data_key_out is None:
            self.key_output = data_key_in
        else:
            self.key_output = data_key_out

        for _, irrep in irreps_in:
            assert irrep.is_scalar()
        inp_dim = irreps_in.dim

        self.fcn = FullyConnectedNet(
            [inp_dim] + hidden_neurons + [dim_out],
            activation,
            **e3nn_kwargs,
        )

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        data[self.key_output] = self.fcn(data[self.key_input])
        return data
from typing import Dict, List, Optional from typing import Dict, List, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional import torch.nn.functional
from ase.symbols import symbols2numbers from ase.symbols import symbols2numbers
from e3nn.util.jit import compile_mode from e3nn.util.jit import compile_mode
import sevenn._keys as KEY import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType from sevenn._const import AtomGraphDataType
# TODO: put this to model_build and do not preprocess data by onehot # TODO: put this to model_build and do not preprocess data by onehot
@compile_mode('script') @compile_mode('script')
class OnehotEmbedding(nn.Module): class OnehotEmbedding(nn.Module):
""" """
x : tensor of shape (N, 1) x : tensor of shape (N, 1)
x_after : tensor of shape (N, num_classes) x_after : tensor of shape (N, num_classes)
It overwrite data_key_x It overwrite data_key_x
and saves input to data_key_save and output to data_key_additional and saves input to data_key_save and output to data_key_additional
I know this is strange but it is for compatibility with previous version I know this is strange but it is for compatibility with previous version
and to specie wise shift scale work and to specie wise shift scale work
ex) [0 1 1 0] -> [[1, 0] [0, 1] [0, 1] [1, 0]] (num_classes = 2) ex) [0 1 1 0] -> [[1, 0] [0, 1] [0, 1] [1, 0]] (num_classes = 2)
""" """
def __init__( def __init__(
self, self,
num_classes: int, num_classes: int,
data_key_x: str = KEY.NODE_FEATURE, data_key_x: str = KEY.NODE_FEATURE,
data_key_out: Optional[str] = None, data_key_out: Optional[str] = None,
data_key_save: Optional[str] = None, data_key_save: Optional[str] = None,
data_key_additional: Optional[str] = None, # additional output data_key_additional: Optional[str] = None, # additional output
): ):
super().__init__() super().__init__()
self.num_classes = num_classes self.num_classes = num_classes
self.key_x = data_key_x self.key_x = data_key_x
if data_key_out is None: if data_key_out is None:
self.key_output = data_key_x self.key_output = data_key_x
else: else:
self.key_output = data_key_out self.key_output = data_key_out
self.key_save = data_key_save self.key_save = data_key_save
self.key_additional_output = data_key_additional self.key_additional_output = data_key_additional
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
inp = data[self.key_x] inp = data[self.key_x]
embd = torch.nn.functional.one_hot(inp, self.num_classes) embd = torch.nn.functional.one_hot(inp, self.num_classes)
embd = embd.float() embd = embd.float()
data[self.key_output] = embd data[self.key_output] = embd
if self.key_additional_output is not None: if self.key_additional_output is not None:
data[self.key_additional_output] = embd # for self-connection data[self.key_additional_output] = embd # for self-connection
if self.key_save is not None: if self.key_save is not None:
data[self.key_save] = inp # for elemwise shift scale data[self.key_save] = inp # for elemwise shift scale
return data return data
def get_type_mapper_from_specie(specie_list: List[str]):
    """
    from ['Hf', 'O']
    return {72: 0, 8: 1}
    """
    specie_list = sorted(specie_list)
    type_map = {}
    unique_counter = 0
    for specie in specie_list:
        atomic_num = symbols2numbers(specie)[0]
        if atomic_num in type_map:
            continue
        type_map[atomic_num] = unique_counter
        unique_counter += 1
    return type_map


# deprecated
def one_hot_atom_embedding(
    atomic_numbers: List[int], type_map: Dict[int, int]
):
    """
    atomic numbers from ase.get_atomic_numbers
    type_map from get_type_mapper_from_specie()
    """
    num_classes = len(type_map)
    try:
        type_numbers = torch.LongTensor(
            [type_map[num] for num in atomic_numbers]
        )
    except KeyError as e:
        raise ValueError(f'Atomic number {e.args[0]} is not expected')
    embd = torch.nn.functional.one_hot(type_numbers, num_classes)
    embd = embd.to(torch.get_default_dtype())
    return embd
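
# A hedged usage sketch (not part of the original file): build a type map
# from chemical symbols and one-hot encode atomic numbers with the deprecated
# helper above. The values follow the docstrings; nothing here is new API.
def _example_type_map_and_onehot():
    type_map = get_type_mapper_from_specie(['Hf', 'O'])  # {72: 0, 8: 1}
    embd = one_hot_atom_embedding([72, 8, 8], type_map)
    # embd has shape (3, 2): [[1, 0], [0, 1], [0, 1]]
    return type_map, embd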
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn as nn
from e3nn.util.jit import compile_mode

import sevenn._keys as KEY
from sevenn._const import NUM_UNIV_ELEMENT, AtomGraphDataType


def _as_univ(
    ss: List[float], type_map: Dict[int, int], default: float
) -> List[float]:
    assert len(ss) <= NUM_UNIV_ELEMENT, 'shift scale is too long'
    return [
        ss[type_map[z]] if z in type_map else default
        for z in range(NUM_UNIV_ELEMENT)
    ]
@compile_mode('script')
class Rescale(nn.Module):
    """
    Scaling and shifting energy (and automatically force and stress)
    """

    def __init__(
        self,
        shift: float,
        scale: float,
        data_key_in: str = KEY.SCALED_ATOMIC_ENERGY,
        data_key_out: str = KEY.ATOMIC_ENERGY,
        train_shift_scale: bool = False,
        **kwargs,
    ):
        assert isinstance(shift, float) and isinstance(scale, float)
        super().__init__()
        self.shift = nn.Parameter(
            torch.FloatTensor([shift]), requires_grad=train_shift_scale
        )
        self.scale = nn.Parameter(
            torch.FloatTensor([scale]), requires_grad=train_shift_scale
        )
        self.key_input = data_key_in
        self.key_output = data_key_out

    def get_shift(self) -> float:
        return self.shift.detach().cpu().tolist()[0]

    def get_scale(self) -> float:
        return self.scale.detach().cpu().tolist()[0]

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        data[self.key_output] = data[self.key_input] * self.scale + self.shift
        return data
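
# A minimal sketch (not part of the original file) of Rescale: scaled per-atom
# energies times `scale` plus `shift` become physical atomic energies. A plain
# dict stands in for AtomGraphDataType; the numbers are illustrative only.
def _example_rescale():
    layer = Rescale(shift=-3.5, scale=2.0)
    data = {KEY.SCALED_ATOMIC_ENERGY: torch.tensor([[0.0], [1.0]])}
    data = layer(data)
    # data[KEY.ATOMIC_ENERGY] -> [[-3.5], [-1.5]]
    return data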
@compile_mode('script')
class SpeciesWiseRescale(nn.Module):
    """
    Scaling and shifting energy (and automatically force and stress),
    element-wise.
    If lists are given, they are used as is; if one of shift/scale is a float,
    it is expanded to a list of the same length as the other.
    If two lists of different lengths are given, an error is raised.
    """

    def __init__(
        self,
        shift: Union[List[float], float],
        scale: Union[List[float], float],
        data_key_in: str = KEY.SCALED_ATOMIC_ENERGY,
        data_key_out: str = KEY.ATOMIC_ENERGY,
        data_key_indices: str = KEY.ATOM_TYPE,
        train_shift_scale: bool = False,
    ):
        super().__init__()
        assert isinstance(shift, float) or isinstance(shift, list)
        assert isinstance(scale, float) or isinstance(scale, list)
        if (
            isinstance(shift, list)
            and isinstance(scale, list)
            and len(shift) != len(scale)
        ):
            raise ValueError('shift and scale lists should have the same length')
        if isinstance(shift, list):
            num_species = len(shift)
        elif isinstance(scale, list):
            num_species = len(scale)
        else:
            raise ValueError('at least one of shift and scale should be a list')
        shift = [shift] * num_species if isinstance(shift, float) else shift
        scale = [scale] * num_species if isinstance(scale, float) else scale
        self.shift = nn.Parameter(
            torch.FloatTensor(shift), requires_grad=train_shift_scale
        )
        self.scale = nn.Parameter(
            torch.FloatTensor(scale), requires_grad=train_shift_scale
        )
        self.key_input = data_key_in
        self.key_output = data_key_out
        self.key_indices = data_key_indices

    def get_shift(self, type_map: Optional[Dict[int, int]] = None) -> List[float]:
        """
        Return shift as a list of floats. If type_map is given, return the
        shift resolved through the reversed type_map, so that the index equals
        the atomic number. 0.0 is assigned for atoms not found.
        """
        shift = self.shift.detach().cpu().tolist()
        if type_map:
            shift = _as_univ(shift, type_map, 0.0)
        return shift

    def get_scale(self, type_map: Optional[Dict[int, int]] = None) -> List[float]:
        """
        Return scale as a list of floats. If type_map is given, return the
        scale resolved through the reversed type_map, so that the index equals
        the atomic number. 1.0 is assigned for atoms not found.
        """
        scale = self.scale.detach().cpu().tolist()
        if type_map:
            scale = _as_univ(scale, type_map, 1.0)
        return scale

    @staticmethod
    def from_mappers(
        shift: Union[float, List[float]],
        scale: Union[float, List[float]],
        type_map: Dict[int, int],
        **kwargs,
    ):
        """
        Fit dimensions of, or map, raw shift/scale values so that they are
        valid under the given type_map (atomic_numbers -> type_indices).
        """
        shift_scale = []
        n_atom_types = len(type_map)
        for s in (shift, scale):
            if isinstance(s, list) and len(s) > n_atom_types:
                if len(s) != NUM_UNIV_ELEMENT:
                    raise ValueError('given shift or scale is strange')
                s = [s[z] for z in sorted(type_map, key=lambda x: type_map[x])]
                # s = [s[z] for z in sorted(type_map, key=type_map.get)]
            elif isinstance(s, float):
                s = [s] * n_atom_types
            elif isinstance(s, list) and len(s) == 1:
                s = s * n_atom_types
            shift_scale.append(s)
        assert all([len(s) == n_atom_types for s in shift_scale])
        shift, scale = shift_scale
        return SpeciesWiseRescale(shift, scale, **kwargs)

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        indices = data[self.key_indices]
        data[self.key_output] = data[self.key_input] * self.scale[indices].view(
            -1, 1
        ) + self.shift[indices].view(-1, 1)
        return data
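
# A hedged sketch (not part of the original file) of SpeciesWiseRescale: one
# shift/scale per atom type, selected through KEY.ATOM_TYPE indices. A plain
# dict stands in for AtomGraphDataType; the numbers are illustrative only.
def _example_species_wise_rescale():
    layer = SpeciesWiseRescale(shift=[-3.0, -5.0], scale=1.5)
    data = {
        KEY.SCALED_ATOMIC_ENERGY: torch.tensor([[1.0], [1.0], [1.0]]),
        KEY.ATOM_TYPE: torch.LongTensor([0, 1, 1]),
    }
    data = layer(data)
    # data[KEY.ATOMIC_ENERGY] -> [[-1.5], [-3.5], [-3.5]]
    return data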
@compile_mode('script')
class ModalWiseRescale(nn.Module):
    """
    Scaling and shifting energy (and automatically force and stress).
    The given shift or scale is either both modal-wise and atom-wise, or
    atom-wise only; in either case it is interpreted atom-wise.
    """

    def __init__(
        self,
        shift: List[List[float]],
        scale: List[List[float]],
        data_key_in: str = KEY.SCALED_ATOMIC_ENERGY,
        data_key_out: str = KEY.ATOMIC_ENERGY,
        data_key_modal_indices: str = KEY.MODAL_TYPE,
        data_key_atom_indices: str = KEY.ATOM_TYPE,
        use_modal_wise_shift: bool = False,
        use_modal_wise_scale: bool = False,
        train_shift_scale: bool = False,
    ):
        super().__init__()
        self.shift = nn.Parameter(
            torch.FloatTensor(shift), requires_grad=train_shift_scale
        )
        self.scale = nn.Parameter(
            torch.FloatTensor(scale), requires_grad=train_shift_scale
        )
        self.key_input = data_key_in
        self.key_output = data_key_out
        self.key_atom_indices = data_key_atom_indices
        self.key_modal_indices = data_key_modal_indices
        self.use_modal_wise_shift = use_modal_wise_shift
        self.use_modal_wise_scale = use_modal_wise_scale
        self._is_batch_data = True

    def get_shift(
        self,
        type_map: Optional[Dict[int, int]] = None,
        modal_map: Optional[Dict[str, int]] = None,
    ) -> Union[List[float], Dict[str, List[float]]]:
        """
        Nothing is given: return as it is.
        type_map is given but shift is not modal-wise: return the universal shift.
        Both type_map and modal_map are given and shift is modal-wise: return
        the fully resolved modal-wise universal shift.
        """
        shift = self.shift.detach().cpu().tolist()
        if type_map and not self.use_modal_wise_shift:
            shift = _as_univ(shift, type_map, 0.0)
        elif self.use_modal_wise_shift and modal_map and type_map:
            shift = [_as_univ(s, type_map, 0.0) for s in shift]
            shift = {modal: shift[idx] for modal, idx in modal_map.items()}
        return shift
    def get_scale(
        self,
        type_map: Optional[Dict[int, int]] = None,
        modal_map: Optional[Dict[str, int]] = None,
    ) -> Union[List[float], Dict[str, List[float]]]:
        """
        Nothing is given: return as it is.
        type_map is given but scale is not modal-wise: return the universal scale.
        Both type_map and modal_map are given and scale is modal-wise: return
        the fully resolved modal-wise universal scale.
        """
        scale = self.scale.detach().cpu().tolist()
        if type_map and not self.use_modal_wise_scale:
            scale = _as_univ(scale, type_map, 1.0)  # 1.0 (identity) for atoms not found
        elif self.use_modal_wise_scale and modal_map and type_map:
            scale = [_as_univ(s, type_map, 1.0) for s in scale]
            scale = {modal: scale[idx] for modal, idx in modal_map.items()}
        return scale
    @staticmethod
    def from_mappers(
        shift: Union[float, List[float], Dict[str, Any]],
        scale: Union[float, List[float], Dict[str, Any]],
        use_modal_wise_shift: bool,
        use_modal_wise_scale: bool,
        type_map: Dict[int, int],
        modal_map: Dict[str, int],
        **kwargs,
    ):
        """
        Fit dimensions of, or map, raw shift/scale values so that they are
        valid under the given type_map (atomic_numbers -> type_indices).
        If a List[float] is given and its length matches
        _const.NUM_UNIV_ELEMENT, it is assumed to be an element-wise list;
        otherwise, it is treated as a modal-wise list.
        """

        def solve_mapper(arr, map):
            # value is the attr index and never overlaps; key is either an
            # atomic number ('z') or a modal string
            return [arr[z] for z in sorted(map, key=lambda x: map[x])]

        shift_scale = []
        n_atom_types = len(type_map)
        n_modals = len(modal_map)
        for s, use_mw in (
            (shift, use_modal_wise_shift),
            (scale, use_modal_wise_scale),
        ):
            # resolve element-wise values, or broadcast
            if isinstance(s, float):
                # given, modal-wise: no, elem-wise: no => broadcast
                shape = (n_modals, n_atom_types) if use_mw else (n_atom_types,)
                res = torch.full(shape, s).tolist()  # TODO: w/o torch
            elif isinstance(s, list) and len(s) == NUM_UNIV_ELEMENT:
                # given, modal-wise: no, elem-wise: yes(univ) => solve elem map
                s = solve_mapper(s, type_map)
                res = [s] * n_modals if use_mw else s
            elif (  # given, modal-wise: yes, elem-wise: no => broadcast to elemwise
                isinstance(s, list)
                and isinstance(s[0], float)
                and len(s) == n_modals
                and use_mw
            ):
                res = [[v] * n_atom_types for v in s]
            elif (  # given, modal-wise: no, elem-wise: yes => as it is
                isinstance(s, list)
                and isinstance(s[0], float)
                and len(s) == n_atom_types
                and not use_mw
            ):
                res = s
            elif (  # given, modal-wise: yes, elem-wise: yes => as it is
                isinstance(s, list)
                and isinstance(s[0], list)
                and len(s) == n_modals
                and len(s[0]) == n_atom_types
                and use_mw
            ):
                res = s
            elif isinstance(s, dict) and use_mw:
                # solve modal dict, modal-wise: yes
                s = solve_mapper(s, modal_map)
                res = []
                for v in s:
                    if isinstance(v, list) and len(v) == NUM_UNIV_ELEMENT:
                        # elem-wise: yes(univ) => solve elem map
                        v = solve_mapper(v, type_map)
                    elif isinstance(v, float):
                        # elem-wise: no => broadcast to elemwise
                        v = [v] * n_atom_types
                    else:
                        raise ValueError(f'Invalid shift or scale {s}')
                    res.append(v)
            else:
                raise ValueError(f'Invalid shift or scale {s}')
            if use_mw:
                assert (
                    isinstance(res, list)
                    and isinstance(res[0], list)
                    and len(res) == n_modals
                )
                assert all([len(r) == n_atom_types for r in res])  # type: ignore
            else:
                assert (
                    isinstance(res, list)
                    and isinstance(res[0], float)
                    and len(res) == n_atom_types
                )
            shift_scale.append(res)
        shift, scale = shift_scale
        return ModalWiseRescale(
            shift,
            scale,
            use_modal_wise_shift=use_modal_wise_shift,
            use_modal_wise_scale=use_modal_wise_scale,
            **kwargs,
        )
    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        if self._is_batch_data:
            batch = data[KEY.BATCH]
            modal_indices = data[self.key_modal_indices][batch]
        else:
            modal_indices = data[self.key_modal_indices]
        atom_indices = data[self.key_atom_indices]
        shift = (
            self.shift[modal_indices, atom_indices]
            if self.use_modal_wise_shift
            else self.shift[atom_indices]
        )
        scale = (
            self.scale[modal_indices, atom_indices]
            if self.use_modal_wise_scale
            else self.scale[atom_indices]
        )
        data[self.key_output] = data[self.key_input] * scale.view(
            -1, 1
        ) + shift.view(-1, 1)
        return data
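
# A hedged sketch (not part of the original file) of building ModalWiseRescale
# via from_mappers. The modal names ('pbe', 'scan') and all numbers are made
# up; the point is only the broadcasting behavior: a modal-wise list of shifts
# is expanded per atom type, while a single float scale is broadcast everywhere.
def _example_modal_wise_rescale():
    type_map = {8: 0, 72: 1}          # O -> 0, Hf -> 1
    modal_map = {'pbe': 0, 'scan': 1}
    layer = ModalWiseRescale.from_mappers(
        shift=[-3.0, -4.0],           # one shift per modal
        scale=1.5,                    # same scale for every modal and element
        use_modal_wise_shift=True,
        use_modal_wise_scale=False,
        type_map=type_map,
        modal_map=modal_map,
    )
    # layer.shift has shape (2 modals, 2 atom types); layer.scale has shape (2,)
    return layer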
def get_resolved_shift_scale(
    module: Union[Rescale, SpeciesWiseRescale, ModalWiseRescale],
    type_map: Optional[Dict[int, int]] = None,
    modal_map: Optional[Dict[str, int]] = None,
):
    """
    Return resolved shift and scale from scale modules. For the element-wise
    case, convert to a list of floats whose index is the atomic number. For
    the modal-wise case, return a dictionary of shift/scale whose keys are the
    modal names given in modal_map.

    Return:
        Tuple of resolved shift and scale
    """
    if isinstance(module, Rescale):
        return (module.get_shift(), module.get_scale())
    elif isinstance(module, SpeciesWiseRescale):
        return (module.get_shift(type_map), module.get_scale(type_map))
    elif isinstance(module, ModalWiseRescale):
        return (
            module.get_shift(type_map, modal_map),
            module.get_scale(type_map, modal_map),
        )
    raise ValueError('Not a scale module')
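
# A short sketch (not part of the original file): get_resolved_shift_scale
# maps the learned per-type values back to atomic-number indexing, so that
# e.g. index 8 holds oxygen's shift. The values are illustrative only.
def _example_get_resolved_shift_scale():
    module = SpeciesWiseRescale(shift=[-3.0, -5.0], scale=[1.5, 1.5])
    shift, scale = get_resolved_shift_scale(module, type_map={8: 0, 72: 1})
    # shift[8] == -3.0, shift[72] == -5.0, 0.0 elsewhere;
    # scale[8] == scale[72] == 1.5, 1.0 elsewhere
    return shift, scale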
import torch.nn as nn
from e3nn.o3 import FullyConnectedTensorProduct, Irreps, Linear
from e3nn.util.jit import compile_mode

import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType


@compile_mode('script')
class SelfConnectionIntro(nn.Module):
    """
    Compute a TensorProduct of x and some data (here, the attribute of x)
    and save the result, to be added to the updated x at SelfConnectionOutro.
    """
    def __init__(
        self,
        irreps_x: Irreps,
        irreps_operand: Irreps,
        irreps_out: Irreps,
        data_key_x: str = KEY.NODE_FEATURE,
        data_key_operand: str = KEY.NODE_ATTR,
        lazy_layer_instantiate: bool = True,
        **kwargs,  # for compatibility
    ):
        super().__init__()
        self.irreps_in1 = irreps_x
        self.irreps_in2 = irreps_operand
        self.irreps_out = irreps_out
        self.key_x = data_key_x
        self.key_operand = data_key_operand

        # The layer is created in instantiate(); keeping it None here avoids
        # building a FullyConnectedTensorProduct that would be discarded
        # immediately under lazy instantiation.
        self.fc_tensor_product = None
        self.layer_instantiated = False
        self.fc_tensor_product_cls = FullyConnectedTensorProduct
        self.fc_tensor_product_kwargs = kwargs
        if not lazy_layer_instantiate:
            self.instantiate()

    def instantiate(self):
        if self.fc_tensor_product is not None:
            raise ValueError('fc_tensor_product layer already exists')
        self.fc_tensor_product = self.fc_tensor_product_cls(
            self.irreps_in1,
            self.irreps_in2,
            self.irreps_out,
            shared_weights=True,
            internal_weights=None,  # same as True
            **self.fc_tensor_product_kwargs,
        )
        self.layer_instantiated = True

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        assert self.fc_tensor_product is not None, 'Layer is not instantiated'
        data[KEY.SELF_CONNECTION_TEMP] = self.fc_tensor_product(
            data[self.key_x], data[self.key_operand]
        )
        return data
@compile_mode('script')
class SelfConnectionLinearIntro(nn.Module):
    """
    Linear style self connection update
    """

    def __init__(
        self,
        irreps_in: Irreps,
        irreps_out: Irreps,
        data_key_x: str = KEY.NODE_FEATURE,
        lazy_layer_instantiate: bool = True,
        **kwargs,
    ):
        super().__init__()
        self.irreps_in = irreps_in
        self.irreps_out = irreps_out
        self.key_x = data_key_x

        self.linear = None
        self.layer_instantiated = False
        self.linear_cls = Linear
        # TODO: better to have a SelfConnectionIntro super class
        kwargs.pop('irreps_operand', None)  # accepted for compatibility, not used here
        self.linear_kwargs = kwargs
        if not lazy_layer_instantiate:
            self.instantiate()

    def instantiate(self):
        if self.linear is not None:
            raise ValueError('Linear layer already exists')
        self.linear = self.linear_cls(
            self.irreps_in, self.irreps_out, **self.linear_kwargs
        )
        self.layer_instantiated = True

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        assert self.linear is not None, 'Layer is not instantiated'
        data[KEY.SELF_CONNECTION_TEMP] = self.linear(data[self.key_x])
        return data
@compile_mode('script')
class SelfConnectionOutro(nn.Module):
    """
    Add the self-connection saved by SelfConnectionIntro (or its variants)
    back onto x, then drop the temporary key.
    """

    def __init__(
        self,
        data_key_x: str = KEY.NODE_FEATURE,
    ):
        super().__init__()
        self.key_x = data_key_x

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        data[self.key_x] = data[self.key_x] + data[KEY.SELF_CONNECTION_TEMP]
        del data[KEY.SELF_CONNECTION_TEMP]
        return data
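
# A minimal sketch (not part of the original file) showing how the intro and
# outro modules pair up: the intro stashes a self-connection term under
# KEY.SELF_CONNECTION_TEMP and the outro adds it back onto the node features.
# The linear variant is used so no node attribute is needed; the irreps and
# the dict standing in for AtomGraphDataType are illustrative assumptions.
def _example_self_connection_pair():
    import torch

    intro = SelfConnectionLinearIntro(
        Irreps('4x0e'),
        Irreps('4x0e'),
        irreps_operand=Irreps('4x0e'),  # kept for interface compatibility, ignored
        lazy_layer_instantiate=False,
    )
    outro = SelfConnectionOutro()
    data = {KEY.NODE_FEATURE: torch.randn(3, 4)}
    data = intro(data)   # writes KEY.SELF_CONNECTION_TEMP
    data = outro(data)   # node features += self connection, temp key removed
    return data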
import warnings
from collections import OrderedDict
from typing import Dict, Optional

import torch
import torch.nn as nn
from e3nn.util.jit import compile_mode

import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType


def _instantiate_modules(modules):
    # see IrrepsLinear of linear.py
    for module in modules.values():
        if not getattr(module, 'layer_instantiated', True):
            module.instantiate()


@compile_mode('script')
class _ModalInputPrepare(nn.Module):
    def __init__(
        self,
        modal_idx: int
    ):
        super().__init__()
        self.modal_idx = modal_idx

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        data[KEY.MODAL_TYPE] = torch.tensor(
            self.modal_idx,
            dtype=torch.int64,
            device=data['x'].device,
        )
        return data
@compile_mode('script')
class AtomGraphSequential(nn.Sequential):
    """
    Wrapper of a SevenNet model

    Args:
        modules: OrderedDict of nn.Modules
        cutoff: not used internally, but it makes sense to have it
        type_map: atomic_numbers => onehot index (see nn/node_embedding.py)
        eval_type_map: perform index mapping using type_map; defaults to True
        data_key_atomic_numbers: used when eval_type_map is True
        data_key_node_feature: used when eval_type_map is True
        data_key_grad: if given, set its requires_grad to True before prediction
    """

    def __init__(
        self,
        modules: Dict[str, nn.Module],
        cutoff: float = 0.0,
        type_map: Optional[Dict[int, int]] = None,
        modal_map: Optional[Dict[str, int]] = None,
        eval_type_map: bool = True,
        eval_modal_map: bool = False,
        data_key_atomic_numbers: str = KEY.ATOMIC_NUMBERS,
        data_key_node_feature: str = KEY.NODE_FEATURE,
        data_key_grad: Optional[str] = None,
    ):
        if not isinstance(modules, OrderedDict):  # backward compat
            modules = OrderedDict(modules)
        self.cutoff = cutoff
        self.type_map = type_map
        self.eval_type_map = eval_type_map
        self.is_batch_data = True

        if cutoff == 0.0:
            warnings.warn('cutoff is 0.0 or not given', UserWarning)
        if self.type_map is None:
            warnings.warn('type_map is not given', UserWarning)
            self.eval_type_map = False
        else:
            z_to_onehot_tensor = torch.neg(torch.ones(120, dtype=torch.long))
            for z, onehot in self.type_map.items():
                z_to_onehot_tensor[z] = onehot
            self.z_to_onehot_tensor = z_to_onehot_tensor

        if eval_modal_map and modal_map is None:
            raise ValueError('eval_modal_map is True but modal_map is None')
        self.eval_modal_map = eval_modal_map
        self.modal_map = modal_map

        self.key_atomic_numbers = data_key_atomic_numbers
        self.key_node_feature = data_key_node_feature
        self.key_grad = data_key_grad

        _instantiate_modules(modules)
        super().__init__(modules)
        if not isinstance(self._modules, OrderedDict):  # backward compat
            self._modules = OrderedDict(self._modules)
    def set_is_batch_data(self, flag: bool):
        # Depending on whether the given data is batched or not, some modules
        # have to change their behavior. Checking whether data is batched
        # inside the forward function makes things harder when converting the
        # model to TorchScript.
        for module in self:
            try:  # Easier to ask for forgiveness than permission.
                module._is_batch_data = flag  # type: ignore
            except AttributeError:
                pass
        self.is_batch_data = flag

    def get_irreps_in(self, modlue_name: str, attr_key: str = 'irreps_in'):
        tg_module = self._modules[modlue_name]
        for m in tg_module.modules():
            try:
                return repr(m.__getattribute__(attr_key))
            except AttributeError:
                pass
        return None

    def prepand_module(self, key: str, module: nn.Module):
        self._modules.update({key: module})
        self._modules.move_to_end(key, last=False)  # type: ignore

    def replace_module(self, key: str, module: nn.Module):
        self._modules.update({key: module})

    def delete_module_by_key(self, key: str):
        if key in self._modules.keys():
            del self._modules[key]
    @torch.jit.unused
    def _atomic_numbers_to_onehot(self, atomic_numbers: torch.Tensor):
        assert atomic_numbers.dtype == torch.int64
        device = atomic_numbers.device
        z_to_onehot_tensor = self.z_to_onehot_tensor.to(device)
        return torch.index_select(
            input=z_to_onehot_tensor, dim=0, index=atomic_numbers
        )

    @torch.jit.unused
    def _eval_modal_map(self, data: AtomGraphDataType):
        assert self.modal_map is not None
        # modal_map: dict[str, int]
        if not self.is_batch_data:
            modal_idx = self.modal_map[data[KEY.DATA_MODALITY]]  # type: ignore
        else:
            modal_idx = [
                self.modal_map[ii]  # type: ignore
                for ii in data[KEY.DATA_MODALITY]
            ]
        modal_idx = torch.tensor(
            modal_idx,
            dtype=torch.int64,
            device=data.x.device,  # type: ignore
        )
        data[KEY.MODAL_TYPE] = modal_idx

    def _preprocess(self, data: AtomGraphDataType) -> AtomGraphDataType:
        if self.eval_type_map:
            atomic_numbers = data[self.key_atomic_numbers]
            onehot = self._atomic_numbers_to_onehot(atomic_numbers)
            data[self.key_node_feature] = onehot
        if self.eval_modal_map:
            self._eval_modal_map(data)
        if self.key_grad is not None:
            data[self.key_grad].requires_grad_(True)
        return data

    def prepare_modal_deploy(self, modal: str):
        if self.modal_map is None:
            return
        self.eval_modal_map = False
        self.set_is_batch_data(False)
        modal_idx = self.modal_map[modal]  # type: ignore
        self.prepand_module('modal_input_prepare', _ModalInputPrepare(modal_idx))

    def forward(self, input: AtomGraphDataType) -> AtomGraphDataType:
        data = self._preprocess(input)
        for module in self:
            data = module(data)
        return data
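
# A hedged sketch (not part of the original file): the smallest meaningful
# AtomGraphSequential. With eval_type_map=True (the default), atomic numbers
# are mapped to one-hot indices in _preprocess, and an OnehotEmbedding module
# expands them into one-hot node features. The import path, the type_map
# values, and the dict standing in for AtomGraphDataType are assumptions.
def _example_atom_graph_sequential():
    from sevenn.nn.node_embedding import OnehotEmbedding

    model = AtomGraphSequential(
        OrderedDict([('onehot', OnehotEmbedding(num_classes=2))]),
        cutoff=4.5,
        type_map={8: 0, 72: 1},
    )
    data = {KEY.ATOMIC_NUMBERS: torch.LongTensor([72, 8, 8])}
    data = model(data)
    # data[KEY.NODE_FEATURE] -> one-hot node features of shape (3, 2)
    return data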
import torch


def broadcast(
    src: torch.Tensor,
    other: torch.Tensor,
    dim: int
):
    if dim < 0:
        dim = other.dim() + dim
    if src.dim() == 1:
        for _ in range(0, dim):
            src = src.unsqueeze(0)
    for _ in range(src.dim(), other.dim()):
        src = src.unsqueeze(-1)
    src = src.expand_as(other)
    return src
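
# A short sketch (not part of the original file) of what broadcast is for: it
# reshapes a 1-D index so it can drive scatter-style reductions over
# multi-dimensional tensors, e.g. summing per-edge messages onto nodes.
def _example_broadcast():
    edge_dst = torch.tensor([0, 0, 1])        # destination node per edge
    message = torch.randn(3, 4)               # one feature row per edge
    index = broadcast(edge_dst, message, 0)   # shape (3, 4), rows repeated
    out = torch.zeros(2, 4)
    out.scatter_reduce_(0, index, message, reduce='sum')
    return out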
#!/bin/bash

lammps_root=$1
cxx_standard=$2 # 14, 17
d3_support=$3 # 1, 0

SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}")

###########################################
# Check if the given arguments are valid #
###########################################
# Check the number of arguments
if [ "$#" -ne 3 ]; then
    echo "Usage: sh patch_lammps.sh {lammps_root} {cxx_standard} {d3_support}"
    echo " {lammps_root}: Root directory of LAMMPS source"
    echo " {cxx_standard}: C++ standard (14, 17)"
    echo " {d3_support}: Support for pair_d3 (1, 0)"
    exit 1
fi

# Check if the lammps_root directory exists
if [ ! -d "$lammps_root" ]; then
    echo "Error: No such directory: $lammps_root"
    exit 1
fi

# Check if the given directory is the root of LAMMPS source
if [ ! -d "$lammps_root/cmake" ] && [ ! -d "$lammps_root/potentials" ]; then
    echo "Error: Given $lammps_root is not a root of LAMMPS source"
    exit 1
fi

# Check if the script is being run from the root of SevenNet
if [ ! -f "${SCRIPT_DIR}/pair_e3gnn.cpp" ]; then
    echo "Error: Script executed in a wrong directory"
    exit 1
fi

# Check if the patch is already applied
if [ -f "$lammps_root/src/pair_e3gnn.cpp" ]; then
    echo "----------------------------------------------------------"
    echo "Seems like the given LAMMPS is already patched."
    echo "Try again after removing src/pair_e3gnn.cpp to force patch"
    echo "----------------------------------------------------------"
    echo "Example build commands, under LAMMPS root"
    echo " mkdir build; cd build"
    echo " cmake ../cmake -DCMAKE_PREFIX_PATH=$(python -c 'import torch;print(torch.utils.cmake_prefix_path)')"
    echo " make -j 4"
    exit 0
fi

# Check if OpenMPI exists and if it is CUDA-aware
if command -v ompi_info &> /dev/null; then
    cuda_support=$(ompi_info --parsable --all | grep mpi_built_with_cuda_support:value)
    if [[ -z "$cuda_support" ]]; then
        echo "Could not determine whether OpenMPI is CUDA aware; parallel performance may not be optimal"
    elif [[ "$cuda_support" == *"true" ]]; then
        echo "OpenMPI is CUDA aware"
    else
        echo "This system's OpenMPI is not 'CUDA aware', parallel performance is not optimal"
    fi
else
    echo "OpenMPI not found, parallel performance is not optimal"
fi

# Extract LAMMPS version and update
lammps_version=$(grep "#define LAMMPS_VERSION" $lammps_root/src/version.h | awk '{print $3, $4, $5}' | tr -d '"')
# Combine version and update
detected_version="$lammps_version"
required_version="2 Aug 2023" # Example required version

# Check if the detected version is compatible
if [[ "$detected_version" != "$required_version" ]]; then
    echo "Warning: Detected LAMMPS version ($detected_version) may not be compatible. Required version: $required_version"
fi

###########################################
# Backup original LAMMPS source code #
###########################################
# Create a backup directory if it doesn't exist
backup_dir="$lammps_root/_backups"
mkdir -p $backup_dir

# Copy comm_* from original LAMMPS source as backup
cp $lammps_root/src/comm_brick.cpp $backup_dir/
cp $lammps_root/src/comm_brick.h $backup_dir/

# Copy cmake/CMakeLists.txt from original source as backup
cp $lammps_root/cmake/CMakeLists.txt $backup_dir/CMakeLists.txt

###########################################
# Patch LAMMPS source code: e3gnn #
###########################################
# 1. Copy pair_e3gnn files to LAMMPS source
cp $SCRIPT_DIR/{pair_e3gnn,pair_e3gnn_parallel,comm_brick}.cpp $lammps_root/src/
cp $SCRIPT_DIR/{pair_e3gnn,pair_e3gnn_parallel,comm_brick}.h $lammps_root/src/

# 2. Patch cmake/CMakeLists.txt
sed -i "s/set(CMAKE_CXX_STANDARD 11)/set(CMAKE_CXX_STANDARD $cxx_standard)/" $lammps_root/cmake/CMakeLists.txt
cat >> $lammps_root/cmake/CMakeLists.txt << "EOF"
find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
target_link_libraries(lammps PUBLIC "${TORCH_LIBRARIES}")
EOF

###########################################
# Patch LAMMPS source code: d3 #
###########################################
if [ "$d3_support" -ne 0 ]; then
    # 1. Copy pair_d3 files to LAMMPS source
    cp $SCRIPT_DIR/pair_d3.cu $lammps_root/src/
    cp $SCRIPT_DIR/pair_d3.h $lammps_root/src/
    cp $SCRIPT_DIR/pair_d3_pars.h $lammps_root/src/

    # 2. Patch cmake/CMakeLists.txt
    sed -i "s/project(lammps CXX)/project(lammps CXX CUDA)/" $lammps_root/cmake/CMakeLists.txt
    sed -i "s/\${LAMMPS_SOURCE_DIR}\/\[\^.\]\*\.cpp/\${LAMMPS_SOURCE_DIR}\/\[\^.\]\*\.cpp \${LAMMPS_SOURCE_DIR}\/\[\^.\]\*\.cu/" $lammps_root/cmake/CMakeLists.txt
    cat >> $lammps_root/cmake/CMakeLists.txt << "EOF"
find_package(CUDA)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -fmad=false -O3")
string(REPLACE "-gencode arch=compute_50,code=sm_50" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
target_link_libraries(lammps PUBLIC ${CUDA_LIBRARIES} cuda)
EOF
fi

###########################################
# Print changes and backup file locations #
###########################################
# Print changes and backup file locations
echo "Changes made:"
echo " - Original LAMMPS files (src/comm_brick.*, cmake/CMakeLists.txt) are in {lammps_root}/_backups"
echo " - Copied contents of pair_e3gnn to $lammps_root/src/"
echo " - Patched CMakeLists.txt: include LibTorch, CXX_STANDARD $cxx_standard"
if [ "$d3_support" -ne 0 ]; then
    echo " - Copied contents of pair_d3 to $lammps_root/src/"
    echo " - Patched CMakeLists.txt: include CUDA"
fi

# Provide example cmake command to the user
echo "Example build commands, under LAMMPS root"
echo " mkdir build; cd build"
echo " cmake ../cmake -DCMAKE_PREFIX_PATH=$(python -c 'import torch;print(torch.utils.cmake_prefix_path)')"
echo " make -j 4"
exit 0
import glob
import os
import warnings
from typing import Any, Callable, Dict

import torch
import yaml

import sevenn._const as _const
import sevenn._keys as KEY
import sevenn.util as util


def config_initialize(
    key: str,
    config: Dict,
    default: Any,
    conditions: Dict,
):
    # A default value exists & there is no user input -> return the default
    if key not in config.keys():
        return default

    # No validation method exists => accept the user input
    user_input = config[key]
    if key in conditions:
        condition = conditions[key]
    else:
        return user_input

    if type(default) is dict and isinstance(condition, dict):
        for i_key, val in default.items():
            user_input[i_key] = config_initialize(
                i_key, user_input, val, condition
            )
        return user_input
    elif isinstance(condition, type):
        if isinstance(user_input, condition):
            return user_input
        else:
            try:
                return condition(user_input)  # try type casting
            except ValueError:
                raise ValueError(
                    f"Expected a value of type {condition} for '{key}', got '{user_input}'"
                )
    elif isinstance(condition, Callable) and condition(user_input):
        return user_input
    else:
        raise ValueError(
            f"Given input '{user_input}' for '{key}' is not valid"
        )
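
# A hedged example (not part of the original file) of how config_initialize
# resolves a user value against a default and a condition. The keys and
# conditions here are made up for illustration; the real keys and conditions
# live in sevenn._const.
def _example_config_initialize():
    config = {'cutoff': '5.0', 'random_seed': 1}
    conditions = {'cutoff': float, 'random_seed': lambda v: isinstance(v, int)}
    cutoff = config_initialize('cutoff', config, 4.5, conditions)    # cast to 5.0
    seed = config_initialize('random_seed', config, 0, conditions)   # kept as 1
    epoch = config_initialize('epoch', config, 100, conditions)      # default 100
    return cutoff, seed, epoch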
def init_model_config(config: Dict): def init_model_config(config: Dict):
# defaults = _const.model_defaults(config) # defaults = _const.model_defaults(config)
model_meta = {} model_meta = {}
# init complicated ones # init complicated ones
if KEY.CHEMICAL_SPECIES not in config.keys(): if KEY.CHEMICAL_SPECIES not in config.keys():
raise ValueError('required key chemical_species not exist') raise ValueError('required key chemical_species not exist')
input_chem = config[KEY.CHEMICAL_SPECIES] input_chem = config[KEY.CHEMICAL_SPECIES]
if isinstance(input_chem, str) and input_chem.lower() == 'auto': if isinstance(input_chem, str) and input_chem.lower() == 'auto':
model_meta[KEY.CHEMICAL_SPECIES] = 'auto' model_meta[KEY.CHEMICAL_SPECIES] = 'auto'
model_meta[KEY.NUM_SPECIES] = 'auto' model_meta[KEY.NUM_SPECIES] = 'auto'
model_meta[KEY.TYPE_MAP] = 'auto' model_meta[KEY.TYPE_MAP] = 'auto'
elif isinstance(input_chem, str) and 'univ' in input_chem.lower(): elif isinstance(input_chem, str) and 'univ' in input_chem.lower():
model_meta.update(util.chemical_species_preprocess([], universal=True)) model_meta.update(util.chemical_species_preprocess([], universal=True))
else: else:
if isinstance(input_chem, list) and all( if isinstance(input_chem, list) and all(
isinstance(x, str) for x in input_chem isinstance(x, str) for x in input_chem
): ):
pass pass
elif isinstance(input_chem, str): elif isinstance(input_chem, str):
input_chem = ( input_chem = (
input_chem.replace('-', ',').replace(' ', ',').split(',') input_chem.replace('-', ',').replace(' ', ',').split(',')
) )
input_chem = [chem for chem in input_chem if len(chem) != 0] input_chem = [chem for chem in input_chem if len(chem) != 0]
else: else:
raise ValueError(f'given {KEY.CHEMICAL_SPECIES} input is strange') raise ValueError(f'given {KEY.CHEMICAL_SPECIES} input is strange')
model_meta.update(util.chemical_species_preprocess(input_chem)) model_meta.update(util.chemical_species_preprocess(input_chem))

    # deprecation warnings
    if KEY.AVG_NUM_NEIGH in config:
        warnings.warn(
            "key 'avg_num_neigh' is deprecated. Please use 'conv_denominator'."
            ' If not given, the default is the average number of neighbors in'
            ' the dataset.',
            UserWarning,
        )
        config.pop(KEY.AVG_NUM_NEIGH)
    if KEY.TRAIN_AVG_NUM_NEIGH in config:
        warnings.warn(
            "key 'train_avg_num_neigh' is deprecated. Please use"
            " 'train_denominator'. It is overwritten with the given"
            " 'train_avg_num_neigh' value.",
            UserWarning,
        )
        config[KEY.TRAIN_DENOMINTAOR] = config[KEY.TRAIN_AVG_NUM_NEIGH]
        config.pop(KEY.TRAIN_AVG_NUM_NEIGH)
    if KEY.OPTIMIZE_BY_REDUCE in config:
        warnings.warn(
            "key 'optimize_by_reduce' is deprecated. It is now always enabled.",
            UserWarning,
        )
        config.pop(KEY.OPTIMIZE_BY_REDUCE)
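
    # Migration sketch for the deprecated keys handled above (hypothetical
    # values, not defaults):
    #   avg_num_neigh: 20.0       -> drop it; 'conv_denominator' falls back to
    #                                the dataset's average number of neighbors
    #   train_avg_num_neigh: True -> rename to 'train_denominator: True'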

    # init simpler ones
    for key, default in _const.DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG.items():
        model_meta[key] = config_initialize(
            key, config, default, _const.MODEL_CONFIG_CONDITION
        )

    unknown_keys = [
        key for key in config.keys() if key not in model_meta.keys()
    ]
    if len(unknown_keys) != 0:
        warnings.warn(
            f'Unexpected model keys: {unknown_keys} will be ignored',
            UserWarning,
        )

    return model_meta
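
# Illustrative chemical_species inputs accepted by init_model_config above
# (hypothetical configs, only meant to mirror the branching logic):
#   {'chemical_species': 'auto'}       -> species resolved later from the dataset
#   {'chemical_species': ['Hf', 'O']}  -> explicit list of element symbols
#   {'chemical_species': 'Hf, O'}      -> comma/hyphen/space separated string,
#                                         split and cleaned before preprocessing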


def init_train_config(config: Dict):
    train_meta = {}
    # defaults = _const.train_defaults(config)

    try:
        device_input = config[KEY.DEVICE]
        train_meta[KEY.DEVICE] = torch.device(device_input)
    except KeyError:
        train_meta[KEY.DEVICE] = (
            torch.device('cuda')
            if torch.cuda.is_available()
            else torch.device('cpu')
        )
    train_meta[KEY.DEVICE] = str(train_meta[KEY.DEVICE])

    # init simpler ones
    for key, default in _const.DEFAULT_TRAINING_CONFIG.items():
        train_meta[key] = config_initialize(
            key, config, default, _const.TRAINING_CONFIG_CONDITION
        )

    if KEY.CONTINUE in config.keys():
        cnt_dct = config[KEY.CONTINUE]
        if KEY.CHECKPOINT not in cnt_dct.keys():
            raise ValueError("no checkpoint is given under 'continue'")
        checkpoint = cnt_dct[KEY.CHECKPOINT]
        if os.path.isfile(checkpoint):
            checkpoint_file = checkpoint
        else:
            checkpoint_file = util.pretrained_name_to_path(checkpoint)
        train_meta[KEY.CONTINUE].update({KEY.CHECKPOINT: checkpoint_file})

    unknown_keys = [
        key for key in config.keys() if key not in train_meta.keys()
    ]
    if len(unknown_keys) != 0:
        warnings.warn(
            f'Unexpected train keys: {unknown_keys} will be ignored',
            UserWarning,
        )

    return train_meta
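
# Illustrative device resolution performed by init_train_config above
# (hypothetical inputs):
#   {'device': 'cuda:0'} -> torch.device('cuda:0'), stored as the string 'cuda:0'
#   {}                   -> 'cuda' if torch.cuda.is_available() else 'cpu'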


def init_data_config(config: Dict):
    data_meta = {}
    # defaults = _const.data_defaults(config)

    load_data_keys = []
    for k in config:
        if k.startswith('load_') and k.endswith('_path'):
            load_data_keys.append(k)

    for load_data_key in load_data_keys:
        if load_data_key in config.keys():
            inp = config[load_data_key]
            extended = []
            if type(inp) not in [str, list]:
                raise ValueError(f'unexpected input {inp} for {load_data_key}')
            if type(inp) is str:
                extended = glob.glob(inp)
            elif type(inp) is list:
                for i in inp:
                    if isinstance(i, str):
                        extended.extend(glob.glob(i))
                    elif isinstance(i, dict):
                        extended.append(i)
            if len(extended) == 0:
                raise ValueError(
                    f'Cannot find {inp} for {load_data_key},'
                    ' or no path is given'
                )
            data_meta[load_data_key] = extended
        else:
            data_meta[load_data_key] = False

    for key, default in _const.DEFAULT_DATA_CONFIG.items():
        data_meta[key] = config_initialize(
            key, config, default, _const.DATA_CONFIG_CONDITION
        )

    unknown_keys = [
        key for key in config.keys() if key not in data_meta.keys()
    ]
    if len(unknown_keys) != 0:
        warnings.warn(
            f'Unexpected data keys: {unknown_keys} will be ignored',
            UserWarning,
        )

    return data_meta
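
# Illustrative expansion of load_*_path keys by init_data_config above
# (hypothetical key name and file patterns):
#   {'load_dataset_path': 'data/*.extxyz'}            -> glob-expanded file list
#   {'load_dataset_path': ['a.extxyz', {'f': 'b'}]}   -> str entries are
#       glob-expanded, dict entries are kept as-is
# An empty expansion raises the ValueError above.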


def read_config_yaml(filename: str, return_separately: bool = False):
    with open(filename, 'r') as fstream:
        inputs = yaml.safe_load(fstream)

    model_meta, train_meta, data_meta = {}, {}, {}
    for key, config in inputs.items():
        if key == 'model':
            model_meta = init_model_config(config)
        elif key == 'train':
            train_meta = init_train_config(config)
        elif key == 'data':
            data_meta = init_data_config(config)
        else:
            raise ValueError(f'Unexpected input {key} given')

    if return_separately:
        return model_meta, train_meta, data_meta
    else:
        model_meta.update(train_meta)
        model_meta.update(data_meta)
        return model_meta
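
# Illustrative input.yaml layout expected by read_config_yaml (a sketch; the
# nested keys are hypothetical, but the top-level sections must be one of
# 'model', 'train', or 'data'):
#   model:
#     chemical_species: 'auto'
#   train:
#     device: 'cpu'
#   data:
#     load_dataset_path: 'data/*.extxyz'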


def main():
    filename = './input.yaml'
    read_config_yaml(filename)


if __name__ == '__main__':
    main()