Unverified commit 273418ee, authored by zcxzcx1, committed by GitHub

Merge pull request #1 from hjhk258/main

Fix

parents 7fb73825 9d7b4f63
import math
import torch
@torch.jit.script
def ShiftedSoftPlus(x: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.softplus(x) - math.log(2.0)
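# Hedged sanity check (illustrative, not part of the library's tests): since
# softplus(0) = log(2), the shifted activation passes through the origin.
if __name__ == '__main__':
    _x = torch.zeros(3)
    assert torch.allclose(ShiftedSoftPlus(_x), _x, atol=1e-6)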
from typing import List
import torch
import torch.nn as nn
from e3nn.nn import FullyConnectedNet
from e3nn.o3 import Irreps, TensorProduct
from e3nn.util.jit import compile_mode
import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType
from .activation import ShiftedSoftPlus
from .util import broadcast
def message_gather(
node_features: torch.Tensor,
edge_dst: torch.Tensor,
message: torch.Tensor
):
index = broadcast(edge_dst, message, 0)
out_shape = [len(node_features)] + list(message.shape[1:])
out = torch.zeros(
out_shape,
dtype=node_features.dtype,
device=node_features.device
)
out.scatter_reduce_(0, index, message, reduce='sum')
return out
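# Hedged example of what message_gather computes (toy sizes are assumptions,
# and it relies on .util.broadcast expanding the index to the message shape):
# messages flowing along edges are summed onto their destination nodes.
if __name__ == '__main__':
    _node = torch.zeros(3, 4)            # 3 nodes, feature dim 4
    _edge_dst = torch.tensor([0, 0, 2])  # destination node of each edge
    _msg = torch.ones(3, 4)              # one message per edge
    _out = message_gather(_node, _edge_dst, _msg)
    assert torch.allclose(_out[0], 2 * torch.ones(4))  # node 0 receives two messages
    assert torch.allclose(_out[1], torch.zeros(4))     # node 1 receives none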
@compile_mode('script')
class IrrepsConvolution(nn.Module):
"""
convolution of (fig 2.b), comm. in LAMMPS
"""
def __init__(
self,
irreps_x: Irreps,
irreps_filter: Irreps,
irreps_out: Irreps,
weight_layer_input_to_hidden: List[int],
weight_layer_act=ShiftedSoftPlus,
denominator: float = 1.0,
train_denominator: bool = False,
data_key_x: str = KEY.NODE_FEATURE,
data_key_filter: str = KEY.EDGE_ATTR,
data_key_weight_input: str = KEY.EDGE_EMBEDDING,
data_key_edge_idx: str = KEY.EDGE_IDX,
lazy_layer_instantiate: bool = True,
is_parallel: bool = False,
):
super().__init__()
self.denominator = nn.Parameter(
torch.FloatTensor([denominator]), requires_grad=train_denominator
)
self.key_x = data_key_x
self.key_filter = data_key_filter
self.key_weight_input = data_key_weight_input
self.key_edge_idx = data_key_edge_idx
self.is_parallel = is_parallel
instructions = []
irreps_mid = []
weight_numel = 0
for i, (mul_x, ir_x) in enumerate(irreps_x):
for j, (_, ir_filter) in enumerate(irreps_filter):
for ir_out in ir_x * ir_filter:
if ir_out in irreps_out: # here we drop l > lmax
k = len(irreps_mid)
weight_numel += mul_x * 1 # path shape
irreps_mid.append((mul_x, ir_out))
instructions.append((i, j, k, 'uvu', True))
irreps_mid = Irreps(irreps_mid)
irreps_mid, p, _ = irreps_mid.sort() # type: ignore
instructions = [
(i_in1, i_in2, p[i_out], mode, train)
for i_in1, i_in2, i_out, mode, train in instructions
]
        # From v0.11.x, to be compatible with cuEquivariance
self._instructions_before_sort = instructions
instructions = sorted(instructions, key=lambda x: x[2])
self.convolution_kwargs = dict(
irreps_in1=irreps_x,
irreps_in2=irreps_filter,
irreps_out=irreps_mid,
instructions=instructions,
shared_weights=False,
internal_weights=False,
)
self.weight_nn_kwargs = dict(
hs=weight_layer_input_to_hidden + [weight_numel],
act=weight_layer_act
)
self.convolution = None
self.weight_nn = None
self.layer_instantiated = False
self.convolution_cls = TensorProduct
self.weight_nn_cls = FullyConnectedNet
if not lazy_layer_instantiate:
self.instantiate()
self._comm_size = irreps_x.dim # used in parallel
def instantiate(self):
if self.convolution is not None:
raise ValueError('Convolution layer already exists')
if self.weight_nn is not None:
raise ValueError('Weight_nn layer already exists')
self.convolution = self.convolution_cls(**self.convolution_kwargs)
self.weight_nn = self.weight_nn_cls(**self.weight_nn_kwargs)
self.layer_instantiated = True
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
assert self.convolution is not None, 'Convolution is not instantiated'
assert self.weight_nn is not None, 'Weight_nn is not instantiated'
weight = self.weight_nn(data[self.key_weight_input])
x = data[self.key_x]
if self.is_parallel:
x = torch.cat([x, data[KEY.NODE_FEATURE_GHOST]])
        # note the edge index convention: row 1 -> src, row 0 -> dst
edge_src = data[self.key_edge_idx][1]
edge_dst = data[self.key_edge_idx][0]
message = self.convolution(x[edge_src], data[self.key_filter], weight)
x = message_gather(x, edge_dst, message)
x = x.div(self.denominator)
if self.is_parallel:
x = torch.tensor_split(x, data[KEY.NLOCAL])[0]
data[self.key_x] = x
return data
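# Hedged construction sketch: the irreps, layer widths, and denominator below
# are illustrative assumptions, not values taken from the original configs.
if __name__ == '__main__':
    conv = IrrepsConvolution(
        irreps_x=Irreps('8x0e + 4x1o'),
        irreps_filter=Irreps('1x0e + 1x1o'),
        irreps_out=Irreps('8x0e + 8x1o'),
        weight_layer_input_to_hidden=[8, 16],
        denominator=10.0,               # roughly the average number of neighbors
        lazy_layer_instantiate=False,   # build TensorProduct / weight NN eagerly
    )
    print(conv.convolution)             # e3nn TensorProduct with 'uvu' paths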
import itertools
import warnings
from typing import Iterator, Literal, Union
import e3nn.o3 as o3
import numpy as np
from .convolution import IrrepsConvolution
from .linear import IrrepsLinear
from .self_connection import SelfConnectionIntro, SelfConnectionLinearIntro
try:
import cuequivariance as cue
import cuequivariance_torch as cuet
_CUE_AVAILABLE = True
    # Obtained from MACE
class O3_e3nn(cue.O3):
def __mul__( # type: ignore
rep1: 'O3_e3nn', rep2: 'O3_e3nn'
) -> Iterator['O3_e3nn']:
return [ # type: ignore
O3_e3nn(l=ir.l, p=ir.p) for ir in cue.O3.__mul__(rep1, rep2)
]
@classmethod
def clebsch_gordan( # type: ignore
cls, rep1: 'O3_e3nn', rep2: 'O3_e3nn', rep3: 'O3_e3nn'
) -> np.ndarray:
rep1, rep2, rep3 = cls._from(rep1), cls._from(rep2), cls._from(rep3)
if rep1.p * rep2.p == rep3.p:
return o3.wigner_3j(rep1.l, rep2.l, rep3.l).numpy()[None] * np.sqrt(
rep3.dim
)
return np.zeros((0, rep1.dim, rep2.dim, rep3.dim))
def __lt__( # type: ignore
rep1: 'O3_e3nn', rep2: 'O3_e3nn'
) -> bool:
rep2 = rep1._from(rep2) # type: ignore
return (rep1.l, rep1.p) < (rep2.l, rep2.p)
@classmethod
def iterator(cls) -> Iterator['O3_e3nn']:
for l in itertools.count(0):
yield O3_e3nn(l=l, p=1 * (-1) ** l)
yield O3_e3nn(l=l, p=-1 * (-1) ** l)
except ImportError:
_CUE_AVAILABLE = False
def is_cue_available():
return _CUE_AVAILABLE
def cue_needed(func):
def wrapper(*args, **kwargs):
if is_cue_available():
return func(*args, **kwargs)
else:
raise ImportError('cue is not available')
return wrapper
def _check_may_not_compatible(orig_kwargs, defaults):
for k, v in defaults.items():
v_given = orig_kwargs.pop(k, v)
if v_given != v:
            warnings.warn(f'{k}: {v_given} is ignored when using cuEquivariance')
def is_cue_cuda_available_model(config):
if config.get('use_bias_in_linear', False):
        warnings.warn('Bias in linear cannot be used with cueq; falling back to e3nn')
return False
else:
return True
@cue_needed
def as_cue_irreps(irreps: o3.Irreps, group: Literal['SO3', 'O3']):
"""Convert e3nn irreps to given group's cue irreps"""
if group == 'SO3':
assert all(irrep.ir.p == 1 for irrep in irreps)
return cue.Irreps('SO3', str(irreps).replace('e', '')) # type: ignore
elif group == 'O3':
return cue.Irreps(O3_e3nn, str(irreps)) # type: ignore
else:
raise ValueError(f'Unknown group: {group}')
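# Hedged sketch (irreps strings are illustrative): the conversion only runs
# when the optional cuequivariance packages are installed.
if __name__ == '__main__' and is_cue_available():
    print(as_cue_irreps(o3.Irreps('32x0e + 16x1e'), 'SO3'))  # parity labels dropped: 32x0 + 16x1
    print(as_cue_irreps(o3.Irreps('32x0e + 16x1o'), 'O3'))   # wrapped as O3_e3nn irreps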
@cue_needed
def patch_linear(
module: Union[IrrepsLinear, SelfConnectionLinearIntro],
group: Literal['SO3', 'O3'],
**cue_kwargs,
):
assert not module.layer_instantiated
module.irreps_in = as_cue_irreps(module.irreps_in, group) # type: ignore
module.irreps_out = as_cue_irreps(module.irreps_out, group) # type: ignore
orig_kwargs = module.linear_kwargs
may_not_compatible_default = dict(
f_in=None,
f_out=None,
instructions=None,
biases=False,
path_normalization='element',
_optimize_einsums=None,
)
# pop may_not_compatible_defaults
_check_may_not_compatible(orig_kwargs, may_not_compatible_default)
module.linear_cls = cuet.Linear # type: ignore
orig_kwargs.update(**cue_kwargs)
return module
@cue_needed
def patch_convolution(
module: IrrepsConvolution,
group: Literal['SO3', 'O3'],
**cue_kwargs,
):
assert not module.layer_instantiated
# conv_kwargs will be patched in place
conv_kwargs = module.convolution_kwargs
conv_kwargs.update(
dict(
irreps_in1=as_cue_irreps(conv_kwargs.get('irreps_in1'), group),
irreps_in2=as_cue_irreps(conv_kwargs.get('irreps_in2'), group),
filter_irreps_out=as_cue_irreps(conv_kwargs.pop('irreps_out'), group),
)
)
inst_orig = conv_kwargs.pop('instructions')
inst_sorted = sorted(inst_orig, key=lambda x: x[2])
assert all([a == b for a, b in zip(inst_orig, inst_sorted)])
may_not_compatible_default = dict(
in1_var=None,
in2_var=None,
out_var=None,
irrep_normalization=False,
path_normalization='element',
compile_left_right=True,
compile_right=False,
_specialized_code=None,
_optimize_einsums=None,
)
# pop may_not_compatible_defaults
_check_may_not_compatible(conv_kwargs, may_not_compatible_default)
module.convolution_cls = cuet.ChannelWiseTensorProduct # type: ignore
conv_kwargs.update(**cue_kwargs)
return module
@cue_needed
def patch_fully_connected(
module: SelfConnectionIntro,
group: Literal['SO3', 'O3'],
**cue_kwargs,
):
assert not module.layer_instantiated
module.irreps_in1 = as_cue_irreps(module.irreps_in1, group) # type: ignore
module.irreps_in2 = as_cue_irreps(module.irreps_in2, group) # type: ignore
module.irreps_out = as_cue_irreps(module.irreps_out, group) # type: ignore
may_not_compatible_default = dict(
irrep_normalization=None,
path_normalization=None,
)
# pop may_not_compatible_defaults
_check_may_not_compatible(
module.fc_tensor_product_kwargs, may_not_compatible_default
)
module.fc_tensor_product_cls = cuet.FullyConnectedTensorProduct # type: ignore
module.fc_tensor_product_kwargs.update(**cue_kwargs)
return module
import math
import torch
import torch.nn as nn
from e3nn.o3 import Irreps, SphericalHarmonics
from e3nn.util.jit import compile_mode
import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType
@compile_mode('script')
class EdgePreprocess(nn.Module):
"""
preprocessing pos to edge vectors and edge lengths
currently used in sevenn/scripts/deploy for lammps serial model
"""
def __init__(self, is_stress: bool):
super().__init__()
# controlled by 'AtomGraphSequential'
self.is_stress = is_stress
self._is_batch_data = True
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
if self._is_batch_data:
cell = data[KEY.CELL].view(-1, 3, 3)
else:
cell = data[KEY.CELL].view(3, 3)
cell_shift = data[KEY.CELL_SHIFT]
pos = data[KEY.POS]
batch = data[KEY.BATCH] # for deploy, must be defined first
if self.is_stress:
if self._is_batch_data:
num_batch = int(batch.max().cpu().item()) + 1
strain = torch.zeros(
(num_batch, 3, 3),
dtype=pos.dtype,
device=pos.device,
)
strain.requires_grad_(True)
data['_strain'] = strain
sym_strain = 0.5 * (strain + strain.transpose(-1, -2))
pos = pos + torch.bmm(
pos.unsqueeze(-2), sym_strain[batch]
).squeeze(-2)
cell = cell + torch.bmm(cell, sym_strain)
else:
strain = torch.zeros(
(3, 3),
dtype=pos.dtype,
device=pos.device,
)
strain.requires_grad_(True)
data['_strain'] = strain
sym_strain = 0.5 * (strain + strain.transpose(-1, -2))
pos = pos + torch.mm(pos, sym_strain)
cell = cell + torch.mm(cell, sym_strain)
idx_src = data[KEY.EDGE_IDX][0]
idx_dst = data[KEY.EDGE_IDX][1]
edge_vec = pos[idx_dst] - pos[idx_src]
if self._is_batch_data:
edge_vec = edge_vec + torch.einsum(
'ni,nij->nj', cell_shift, cell[batch[idx_src]]
)
else:
edge_vec = edge_vec + torch.einsum(
'ni,ij->nj', cell_shift, cell.squeeze(0)
)
data[KEY.EDGE_VEC] = edge_vec
data[KEY.EDGE_LENGTH] = torch.linalg.norm(edge_vec, dim=-1)
return data
class BesselBasis(nn.Module):
"""
f : (*, 1) -> (*, bessel_basis_num)
"""
def __init__(
self,
cutoff_length: float,
bessel_basis_num: int = 8,
trainable_coeff: bool = True,
):
super().__init__()
self.num_basis = bessel_basis_num
self.prefactor = 2.0 / cutoff_length
self.coeffs = torch.FloatTensor([
n * math.pi / cutoff_length for n in range(1, bessel_basis_num + 1)
])
if trainable_coeff:
self.coeffs = nn.Parameter(self.coeffs)
def forward(self, r: torch.Tensor) -> torch.Tensor:
ur = r.unsqueeze(-1) # to fit dimension
return self.prefactor * torch.sin(self.coeffs * ur) / ur
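# Hedged worked check (the cutoff length is an illustrative assumption): every
# basis function f_n(r) = (2 / r_c) * sin(n * pi * r / r_c) / r vanishes at r = r_c.
if __name__ == '__main__':
    _basis = BesselBasis(cutoff_length=5.0, bessel_basis_num=8)
    assert _basis(torch.tensor([5.0])).abs().max() < 1e-5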
class PolynomialCutoff(nn.Module):
"""
f : (*, 1) -> (*, 1)
https://arxiv.org/pdf/2003.03123.pdf
"""
def __init__(
self,
cutoff_length: float,
poly_cut_p_value: int = 6,
):
super().__init__()
p = poly_cut_p_value
self.cutoff_length = cutoff_length
self.p = p
self.coeff_p0 = (p + 1.0) * (p + 2.0) / 2.0
self.coeff_p1 = p * (p + 2.0)
self.coeff_p2 = p * (p + 1.0) / 2.0
def forward(self, r: torch.Tensor) -> torch.Tensor:
r = r / self.cutoff_length
return (
1
- self.coeff_p0 * torch.pow(r, self.p)
+ self.coeff_p1 * torch.pow(r, self.p + 1.0)
- self.coeff_p2 * torch.pow(r, self.p + 2.0)
)
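# Hedged worked check (the cutoff length is an illustrative assumption): with
# the default p = 6, the envelope equals 1 at r = 0 and 0 at the cutoff,
# since 1 - 28 + 48 - 21 = 0.
if __name__ == '__main__':
    _cutoff = PolynomialCutoff(cutoff_length=5.0)
    _vals = _cutoff(torch.tensor([0.0, 5.0]))
    assert torch.allclose(_vals, torch.tensor([1.0, 0.0]), atol=1e-6)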
class XPLORCutoff(nn.Module):
"""
https://hoomd-blue.readthedocs.io/en/latest/module-md-pair.html
"""
def __init__(
self,
cutoff_length: float,
cutoff_on: float,
):
super().__init__()
self.r_on = cutoff_on
self.r_cut = cutoff_length
assert self.r_on < self.r_cut
def forward(self, r: torch.Tensor) -> torch.Tensor:
r_sq = r * r
r_on_sq = self.r_on * self.r_on
r_cut_sq = self.r_cut * self.r_cut
return torch.where(
r < self.r_on,
1.0,
(r_cut_sq - r_sq) ** 2
* (r_cut_sq + 2 * r_sq - 3 * r_on_sq)
/ (r_cut_sq - r_on_sq) ** 3,
)
@compile_mode('script')
class SphericalEncoding(nn.Module):
def __init__(
self,
lmax: int,
parity: int = -1,
normalization: str = 'component',
normalize: bool = True,
):
super().__init__()
self.lmax = lmax
self.normalization = normalization
self.irreps_in = Irreps('1x1o') if parity == -1 else Irreps('1x1e')
self.irreps_out = Irreps.spherical_harmonics(lmax, parity)
self.sph = SphericalHarmonics(
self.irreps_out,
normalize=normalize,
normalization=normalization,
irreps_in=self.irreps_in,
)
def forward(self, r: torch.Tensor) -> torch.Tensor:
return self.sph(r)
@compile_mode('script')
class EdgeEmbedding(nn.Module):
"""
embedding layer of |r| by
RadialBasis(|r|)*CutOff(|r|)
f : (N_edge) -> (N_edge, basis_num)
"""
def __init__(
self,
basis_module: nn.Module,
cutoff_module: nn.Module,
spherical_module: nn.Module,
):
super().__init__()
self.basis_function = basis_module
self.cutoff_function = cutoff_module
self.spherical = spherical_module
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
rvec = data[KEY.EDGE_VEC]
r = torch.linalg.norm(data[KEY.EDGE_VEC], dim=-1)
data[KEY.EDGE_LENGTH] = r
data[KEY.EDGE_EMBEDDING] = self.basis_function(
r
) * self.cutoff_function(r).unsqueeze(-1)
data[KEY.EDGE_ATTR] = self.spherical(rvec)
return data
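# Hedged wiring sketch: EdgeEmbedding composes the modules defined above; the
# cutoff, basis size, lmax, and edge count below are illustrative assumptions.
if __name__ == '__main__':
    _embed = EdgeEmbedding(
        basis_module=BesselBasis(cutoff_length=5.0, bessel_basis_num=8),
        cutoff_module=PolynomialCutoff(cutoff_length=5.0),
        spherical_module=SphericalEncoding(lmax=2),
    )
    _data = {KEY.EDGE_VEC: torch.randn(10, 3)}
    _data = _embed(_data)
    print(_data[KEY.EDGE_EMBEDDING].shape)  # (10, 8): radial basis * cutoff
    print(_data[KEY.EDGE_ATTR].shape)       # (10, 9): spherical harmonics up to l = 2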
from typing import Callable, Dict
import torch.nn as nn
from e3nn.nn import Gate
from e3nn.o3 import Irreps
from e3nn.util.jit import compile_mode
import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType
@compile_mode('script')
class EquivariantGate(nn.Module):
def __init__(
self,
irreps_x: Irreps,
act_scalar_dict: Dict[int, Callable],
act_gate_dict: Dict[int, Callable],
data_key_x: str = KEY.NODE_FEATURE,
):
super().__init__()
self.key_x = data_key_x
parity_mapper = {'e': 1, 'o': -1}
act_scalar_dict = {
parity_mapper[k]: v for k, v in act_scalar_dict.items()
}
act_gate_dict = {parity_mapper[k]: v for k, v in act_gate_dict.items()}
irreps_gated_elem = []
irreps_scalars_elem = []
        # non-scalar irreps -> gated / scalar irreps -> scalars
for mul, irreps in irreps_x:
if irreps.l > 0:
irreps_gated_elem.append((mul, irreps))
else:
irreps_scalars_elem.append((mul, irreps))
irreps_scalars = Irreps(irreps_scalars_elem)
irreps_gated = Irreps(irreps_gated_elem)
irreps_gates_parity = 1 if '0e' in irreps_scalars else -1
irreps_gates = Irreps(
[(mul, (0, irreps_gates_parity)) for mul, _ in irreps_gated]
)
act_scalars = [act_scalar_dict[p] for _, (_, p) in irreps_scalars]
act_gates = [act_gate_dict[p] for _, (_, p) in irreps_gates]
self.gate = Gate(
irreps_scalars, act_scalars, irreps_gates, act_gates, irreps_gated
)
def get_gate_irreps_in(self):
"""
user must call this function to get proper irreps in for forward
"""
return self.gate.irreps_in
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
data[self.key_x] = self.gate(data[self.key_x])
return data
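# Hedged construction sketch: the activation dicts are keyed by parity label
# ('e' / 'o'); the irreps and activations below are illustrative assumptions.
if __name__ == '__main__':
    import torch
    _gate = EquivariantGate(
        irreps_x=Irreps('8x0e + 4x1o'),
        act_scalar_dict={'e': nn.SiLU(), 'o': nn.Tanh()},
        act_gate_dict={'e': nn.Sigmoid(), 'o': nn.Tanh()},
    )
    _irreps_in = _gate.get_gate_irreps_in()      # scalars + gate scalars + gated irreps
    _data = {KEY.NODE_FEATURE: torch.randn(5, _irreps_in.dim)}
    print(_gate(_data)[KEY.NODE_FEATURE].shape)  # (5, 20): back to 8x0e + 4x1o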
import torch
import torch.nn as nn
from e3nn.util.jit import compile_mode
import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType
from .util import broadcast
@compile_mode('script')
class ForceOutput(nn.Module):
"""
works when pos.requires_grad_ is True
"""
def __init__(
self,
data_key_pos: str = KEY.POS,
data_key_energy: str = KEY.PRED_TOTAL_ENERGY,
data_key_force: str = KEY.PRED_FORCE,
):
super().__init__()
self.key_pos = data_key_pos
self.key_energy = data_key_energy
self.key_force = data_key_force
def get_grad_key(self):
return self.key_pos
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
pos_tensor = [data[self.key_pos]]
energy = [(data[self.key_energy]).sum()]
        # `materialize_grads` is not supported by older PyTorch versions,
        # and the model cannot be deployed (TorchScript) when it is used.
        # Not using it, however, causes problems in force/stress inference
        # for sparse systems.
        # TODO: use it only in sevennet_calculator?
grad = torch.autograd.grad(
energy,
pos_tensor,
create_graph=self.training,
allow_unused=True,
# materialize_grads=True,
)[0]
# For torchscript
if grad is not None:
data[self.key_force] = torch.neg(grad)
return data
@compile_mode('script')
class ForceStressOutput(nn.Module):
"""
Compute stress and force from positions.
Used in serial torchscipt models
"""
def __init__(
self,
data_key_pos: str = KEY.POS,
data_key_energy: str = KEY.PRED_TOTAL_ENERGY,
data_key_force: str = KEY.PRED_FORCE,
data_key_stress: str = KEY.PRED_STRESS,
data_key_cell_volume: str = KEY.CELL_VOLUME,
):
super().__init__()
self.key_pos = data_key_pos
self.key_energy = data_key_energy
self.key_force = data_key_force
self.key_stress = data_key_stress
self.key_cell_volume = data_key_cell_volume
self._is_batch_data = True
def get_grad_key(self):
return self.key_pos
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
pos_tensor = data[self.key_pos]
energy = [(data[self.key_energy]).sum()]
        # `materialize_grads` is not supported by older PyTorch versions,
        # and the model cannot be deployed (TorchScript) when it is used.
        # Not using it, however, causes problems in force/stress inference
        # for sparse systems.
        # TODO: use it only in sevennet_calculator?
grad = torch.autograd.grad(
energy,
[pos_tensor, data['_strain']],
create_graph=self.training,
allow_unused=True,
# materialize_grads=True,
)
        # ensure grad is not Optional[Tensor] (for TorchScript)
fgrad = grad[0]
if fgrad is not None:
data[self.key_force] = torch.neg(fgrad)
sgrad = grad[1]
volume = data[self.key_cell_volume]
        vlim = 1e-3  # guard against zero cell volume for non-PBC structures
if self._is_batch_data:
volume[volume < vlim] = vlim
elif volume < vlim:
volume = torch.tensor(vlim)
if sgrad is not None:
if self._is_batch_data:
stress = sgrad / volume.view(-1, 1, 1)
stress = torch.neg(stress)
virial_stress = torch.vstack((
stress[:, 0, 0],
stress[:, 1, 1],
stress[:, 2, 2],
stress[:, 0, 1],
stress[:, 1, 2],
stress[:, 0, 2],
))
data[self.key_stress] = virial_stress.transpose(0, 1)
else:
stress = sgrad / volume
stress = torch.neg(stress)
virial_stress = torch.stack((
stress[0, 0],
stress[1, 1],
stress[2, 2],
stress[0, 1],
stress[1, 2],
stress[0, 2],
))
data[self.key_stress] = virial_stress
return data
@compile_mode('script')
class ForceStressOutputFromEdge(nn.Module):
"""
Compute stress and force from edge.
Used in parallel torchscipt models, and training
"""
def __init__(
self,
data_key_edge: str = KEY.EDGE_VEC,
data_key_edge_idx: str = KEY.EDGE_IDX,
data_key_energy: str = KEY.PRED_TOTAL_ENERGY,
data_key_force: str = KEY.PRED_FORCE,
data_key_stress: str = KEY.PRED_STRESS,
data_key_cell_volume: str = KEY.CELL_VOLUME,
):
super().__init__()
self.key_edge = data_key_edge
self.key_edge_idx = data_key_edge_idx
self.key_energy = data_key_energy
self.key_force = data_key_force
self.key_stress = data_key_stress
self.key_cell_volume = data_key_cell_volume
self._is_batch_data = True
def get_grad_key(self):
return self.key_edge
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
tot_num = torch.sum(data[KEY.NUM_ATOMS]) # ? item?
rij = data[self.key_edge]
energy = [(data[self.key_energy]).sum()]
edge_idx = data[self.key_edge_idx]
grad = torch.autograd.grad(
energy,
[rij],
create_graph=self.training,
allow_unused=True
)
        # ensure grad is not Optional[Tensor] (for TorchScript)
fij = grad[0]
if fij is not None:
# compute force
pf = torch.zeros(tot_num, 3, dtype=fij.dtype, device=fij.device)
nf = torch.zeros(tot_num, 3, dtype=fij.dtype, device=fij.device)
_edge_src = broadcast(edge_idx[0], fij, 0)
_edge_dst = broadcast(edge_idx[1], fij, 0)
pf.scatter_reduce_(0, _edge_src, fij, reduce='sum')
nf.scatter_reduce_(0, _edge_dst, fij, reduce='sum')
data[self.key_force] = pf - nf
# compute virial
diag = rij * fij
s12 = rij[..., 0] * fij[..., 1]
s23 = rij[..., 1] * fij[..., 2]
s31 = rij[..., 2] * fij[..., 0]
# cat last dimension
_virial = torch.cat([
diag,
s12.unsqueeze(-1),
s23.unsqueeze(-1),
s31.unsqueeze(-1)
], dim=-1)
_s = torch.zeros(tot_num, 6, dtype=fij.dtype, device=fij.device)
_edge_dst6 = broadcast(edge_idx[1], _virial, 0)
_s.scatter_reduce_(0, _edge_dst6, _virial, reduce='sum')
if self._is_batch_data:
batch = data[KEY.BATCH] # for deploy, must be defined first
nbatch = int(batch.max().cpu().item()) + 1
sout = torch.zeros(
(nbatch, 6), dtype=_virial.dtype, device=_virial.device
)
_batch = broadcast(batch, _s, 0)
sout.scatter_reduce_(0, _batch, _s, reduce='sum')
else:
sout = torch.sum(_s, dim=0)
data[self.key_stress] =\
torch.neg(sout) / data[self.key_cell_volume].unsqueeze(-1)
return data
from typing import Callable, List, Tuple
from e3nn.o3 import Irreps
import sevenn._keys as KEY
from .convolution import IrrepsConvolution
from .equivariant_gate import EquivariantGate
from .linear import IrrepsLinear
def NequIP_interaction_block(
irreps_x: Irreps,
irreps_filter: Irreps,
irreps_out_tp: Irreps,
irreps_out: Irreps,
weight_nn_layers: List[int],
conv_denominator: float,
train_conv_denominator: bool,
self_connection_pair: Tuple[Callable, Callable],
act_scalar: Callable,
act_gate: Callable,
act_radial: Callable,
bias_in_linear: bool,
num_species: int,
t: int, # interaction layer index
data_key_x: str = KEY.NODE_FEATURE,
data_key_weight_input: str = KEY.EDGE_EMBEDDING,
parallel: bool = False,
**conv_kwargs,
):
block = {}
irreps_node_attr = Irreps(f'{num_species}x0e')
sc_intro, sc_outro = self_connection_pair
gate_layer = EquivariantGate(irreps_out, act_scalar, act_gate)
irreps_for_gate_in = gate_layer.get_gate_irreps_in()
block[f'{t}_self_connection_intro'] = sc_intro(
irreps_x,
irreps_operand=irreps_node_attr,
irreps_out=irreps_for_gate_in,
)
block[f'{t}_self_interaction_1'] = IrrepsLinear(
irreps_x, irreps_x,
data_key_in=data_key_x,
biases=bias_in_linear,
)
# convolution part, l>lmax is dropped as defined in irreps_out
block[f'{t}_convolution'] = IrrepsConvolution(
irreps_x=irreps_x,
irreps_filter=irreps_filter,
irreps_out=irreps_out_tp,
data_key_weight_input=data_key_weight_input,
weight_layer_input_to_hidden=weight_nn_layers,
weight_layer_act=act_radial,
denominator=conv_denominator,
train_denominator=train_conv_denominator,
is_parallel=parallel,
**conv_kwargs,
)
# irreps of x increase to gate_irreps_in
block[f'{t}_self_interaction_2'] = IrrepsLinear(
irreps_out_tp,
irreps_for_gate_in,
data_key_in=data_key_x,
biases=bias_in_linear,
)
block[f'{t}_self_connection_outro'] = sc_outro()
block[f'{t}_equivariant_gate'] = gate_layer
return block
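# Hedged usage note (not from the original docs): the returned dict is ordered
#   {t}_self_connection_intro -> {t}_self_interaction_1 -> {t}_convolution
#   -> {t}_self_interaction_2 -> {t}_self_connection_outro -> {t}_equivariant_gate
# and is typically merged, per interaction layer t, into the OrderedDict
# consumed by the sequential model wrapper (AtomGraphSequential).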
from typing import Callable, List, Optional
import torch
import torch.nn as nn
from e3nn.nn import FullyConnectedNet
from e3nn.o3 import Irreps, Linear
from e3nn.util.jit import compile_mode
import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType
@compile_mode('script')
class IrrepsLinear(nn.Module):
"""
wrapper class of e3nn Linear to operate on AtomGraphData
"""
def __init__(
self,
irreps_in: Irreps,
irreps_out: Irreps,
data_key_in: str,
data_key_out: Optional[str] = None,
data_key_modal_attr: str = KEY.MODAL_ATTR,
num_modalities: int = 0,
lazy_layer_instantiate: bool = True,
**linear_kwargs,
):
super().__init__()
self.key_input = data_key_in
if data_key_out is None:
self.key_output = data_key_in
else:
self.key_output = data_key_out
self.key_modal_attr = data_key_modal_attr
self._irreps_in_wo_modal = irreps_in
self.irreps_in = irreps_in
self.irreps_out = irreps_out
self.linear_kwargs = linear_kwargs
self.linear = None
self.layer_instantiated = False
self.num_modalities = num_modalities
self._is_batch_data = True
# use getter setter
self.linear_cls = Linear
if num_modalities > 1: # in case of multi-modal
self.set_num_modalities(num_modalities)
if not lazy_layer_instantiate:
self.instantiate()
def instantiate(self):
if self.linear is not None:
raise ValueError('Linear layer already exists')
self.linear = self.linear_cls(
self.irreps_in, self.irreps_out, **self.linear_kwargs
)
self.layer_instantiated = True
def set_num_modalities(self, num_modalities):
if self.layer_instantiated:
raise ValueError('Layer already instantiated, can not change modalities')
irreps_in = self._irreps_in_wo_modal + Irreps(f'{num_modalities}x0e')
self.num_modalities = num_modalities
self.irreps_in = irreps_in
def _patch_modal_to_data(self, data: AtomGraphDataType) -> AtomGraphDataType:
if self._is_batch_data:
batch = data[KEY.BATCH]
batch_modality_onehot = data[self.key_modal_attr].reshape(
-1, self.num_modalities
)
batch_modality_onehot = batch_modality_onehot.type(
data[self.key_input].dtype
)
data[self.key_input] = torch.cat(
[data[self.key_input], batch_modality_onehot[batch]], dim=1
)
else:
modality_onehot = data[self.key_modal_attr].expand(
len(data[self.key_input]), -1
)
modality_onehot = modality_onehot.type(data[self.key_input].dtype)
data[self.key_input] = torch.cat(
[data[self.key_input], modality_onehot], dim=1
)
return data
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
assert self.linear is not None, 'Layer is not instantiated'
if self.num_modalities > 1:
data = self._patch_modal_to_data(data)
data[self.key_output] = self.linear(data[self.key_input])
return data
@compile_mode('script')
class AtomReduce(nn.Module):
"""
atomic energy -> total energy
constant is multiplied to data
"""
def __init__(
self,
data_key_in: str,
data_key_out: str,
reduce: str = 'sum',
constant: float = 1.0,
):
super().__init__()
self.key_input = data_key_in
self.key_output = data_key_out
self.constant = constant
self.reduce = reduce
        # controlled by the uppermost wrapper 'AtomGraphSequential'
self._is_batch_data = True
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
if self._is_batch_data:
src = data[self.key_input].squeeze(1)
size = int(data[KEY.BATCH].max()) + 1
output = torch.zeros(
(size),
dtype=src.dtype,
device=src.device,
)
output.scatter_reduce_(0, data[KEY.BATCH], src, reduce='sum')
data[self.key_output] = output * self.constant
else:
data[self.key_output] = torch.sum(data[self.key_input]) * self.constant
return data
@compile_mode('script')
class FCN_e3nn(nn.Module):
"""
wrapper class of e3nn FullyConnectedNet
"""
def __init__(
self,
        irreps_in: Irreps,  # must be scalar irreps; determines the input size
dim_out: int,
hidden_neurons: List[int],
activation: Callable,
data_key_in: str,
data_key_out: Optional[str] = None,
**e3nn_kwargs,
):
super().__init__()
self.key_input = data_key_in
self.irreps_in = irreps_in
if data_key_out is None:
self.key_output = data_key_in
else:
self.key_output = data_key_out
for _, irrep in irreps_in:
assert irrep.is_scalar()
inp_dim = irreps_in.dim
self.fcn = FullyConnectedNet(
[inp_dim] + hidden_neurons + [dim_out],
activation,
**e3nn_kwargs,
)
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
data[self.key_output] = self.fcn(data[self.key_input])
return data
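# Hedged usage sketch for IrrepsLinear defined above; the irreps and sizes are
# illustrative assumptions.
if __name__ == '__main__':
    _linear = IrrepsLinear(
        irreps_in=Irreps('8x0e + 4x1o'),
        irreps_out=Irreps('16x0e'),
        data_key_in=KEY.NODE_FEATURE,
        lazy_layer_instantiate=False,
    )
    _data = {KEY.NODE_FEATURE: torch.randn(5, 8 + 4 * 3)}  # dim of 8x0e + 4x1o
    print(_linear(_data)[KEY.NODE_FEATURE].shape)           # (5, 16)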
from typing import Dict, List, Optional
import torch
import torch.nn as nn
import torch.nn.functional
from ase.symbols import symbols2numbers
from e3nn.util.jit import compile_mode
import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType
# TODO: move this into model_build and stop one-hot preprocessing the data here
@compile_mode('script')
class OnehotEmbedding(nn.Module):
"""
x : tensor of shape (N, 1)
x_after : tensor of shape (N, num_classes)
It overwrite data_key_x
and saves input to data_key_save and output to data_key_additional
I know this is strange but it is for compatibility with previous version
and to specie wise shift scale work
ex) [0 1 1 0] -> [[1, 0] [0, 1] [0, 1] [1, 0]] (num_classes = 2)
"""
def __init__(
self,
num_classes: int,
data_key_x: str = KEY.NODE_FEATURE,
data_key_out: Optional[str] = None,
data_key_save: Optional[str] = None,
data_key_additional: Optional[str] = None, # additional output
):
super().__init__()
self.num_classes = num_classes
self.key_x = data_key_x
if data_key_out is None:
self.key_output = data_key_x
else:
self.key_output = data_key_out
self.key_save = data_key_save
self.key_additional_output = data_key_additional
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
inp = data[self.key_x]
embd = torch.nn.functional.one_hot(inp, self.num_classes)
embd = embd.float()
data[self.key_output] = embd
if self.key_additional_output is not None:
data[self.key_additional_output] = embd # for self-connection
if self.key_save is not None:
data[self.key_save] = inp # for elemwise shift scale
return data
def get_type_mapper_from_specie(specie_list: List[str]):
"""
from ['Hf', 'O']
return {72: 0, 8: 1}
"""
specie_list = sorted(specie_list)
type_map = {}
unique_counter = 0
for specie in specie_list:
atomic_num = symbols2numbers(specie)[0]
if atomic_num in type_map:
continue
type_map[atomic_num] = unique_counter
unique_counter += 1
return type_map
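# Hedged example mirroring the docstring above: symbols are sorted, then mapped
# to contiguous type indices.
if __name__ == '__main__':
    assert get_type_mapper_from_specie(['Hf', 'O']) == {72: 0, 8: 1}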
# deprecated
def one_hot_atom_embedding(
atomic_numbers: List[int], type_map: Dict[int, int]
):
"""
atomic numbers from ase.get_atomic_numbers
type_map from get_type_mapper_from_specie()
"""
num_classes = len(type_map)
try:
type_numbers = torch.LongTensor(
[type_map[num] for num in atomic_numbers]
)
except KeyError as e:
raise ValueError(f'Atomic number {e.args[0]} is not expected')
embd = torch.nn.functional.one_hot(type_numbers, num_classes)
embd = embd.to(torch.get_default_dtype())
return embd
from typing import Any, Dict, List, Optional, Union
import torch
import torch.nn as nn
from e3nn.util.jit import compile_mode
import sevenn._keys as KEY
from sevenn._const import NUM_UNIV_ELEMENT, AtomGraphDataType
def _as_univ(
ss: List[float], type_map: Dict[int, int], default: float
) -> List[float]:
assert len(ss) <= NUM_UNIV_ELEMENT, 'shift scale is too long'
return [
ss[type_map[z]] if z in type_map else default
for z in range(NUM_UNIV_ELEMENT)
]
@compile_mode('script')
class Rescale(nn.Module):
"""
Scaling and shifting energy (and automatically force and stress)
"""
def __init__(
self,
shift: float,
scale: float,
data_key_in: str = KEY.SCALED_ATOMIC_ENERGY,
data_key_out: str = KEY.ATOMIC_ENERGY,
train_shift_scale: bool = False,
**kwargs,
):
assert isinstance(shift, float) and isinstance(scale, float)
super().__init__()
self.shift = nn.Parameter(
torch.FloatTensor([shift]), requires_grad=train_shift_scale
)
self.scale = nn.Parameter(
torch.FloatTensor([scale]), requires_grad=train_shift_scale
)
self.key_input = data_key_in
self.key_output = data_key_out
def get_shift(self) -> float:
return self.shift.detach().cpu().tolist()[0]
def get_scale(self) -> float:
return self.scale.detach().cpu().tolist()[0]
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
data[self.key_output] = data[self.key_input] * self.scale + self.shift
return data
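# Usage sketch for Rescale, assuming a plain dict can stand in for
# AtomGraphDataType (forward only indexes it by key). The key names used
# here are placeholders, not the library's defaults.
import torch

rescale_example = Rescale(
    shift=-3.5, scale=1.2, data_key_in='scaled_E', data_key_out='E'
)
example_data = {'scaled_E': torch.tensor([[0.0], [1.0]])}
example_data = rescale_example(example_data)
# example_data['E'] == scaled_E * 1.2 + (-3.5) -> [[-3.5], [-2.3]]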
@compile_mode('script')
class SpeciesWiseRescale(nn.Module):
"""
Scaling and shifting energy (and automatically force and stress)
Use as it is if given list, expand to list if one of them is float
If two lists are given and length is not the same, raise error
"""
def __init__(
self,
shift: Union[List[float], float],
scale: Union[List[float], float],
data_key_in: str = KEY.SCALED_ATOMIC_ENERGY,
data_key_out: str = KEY.ATOMIC_ENERGY,
data_key_indices: str = KEY.ATOM_TYPE,
train_shift_scale: bool = False,
):
super().__init__()
assert isinstance(shift, float) or isinstance(shift, list)
assert isinstance(scale, float) or isinstance(scale, list)
if (
isinstance(shift, list)
and isinstance(scale, list)
and len(shift) != len(scale)
):
raise ValueError('List length should be same')
if isinstance(shift, list):
num_species = len(shift)
elif isinstance(scale, list):
num_species = len(scale)
else:
            raise ValueError('Neither shift nor scale is a list')
shift = [shift] * num_species if isinstance(shift, float) else shift
scale = [scale] * num_species if isinstance(scale, float) else scale
self.shift = nn.Parameter(
torch.FloatTensor(shift), requires_grad=train_shift_scale
)
self.scale = nn.Parameter(
torch.FloatTensor(scale), requires_grad=train_shift_scale
)
self.key_input = data_key_in
self.key_output = data_key_out
self.key_indices = data_key_indices
def get_shift(self, type_map: Optional[Dict[int, int]] = None) -> List[float]:
"""
Return shift in list of float. If type_map is given, return type_map reversed
shift, which index equals atomic_number. 0.0 is assigned for atomis not found
"""
shift = self.shift.detach().cpu().tolist()
if type_map:
shift = _as_univ(shift, type_map, 0.0)
return shift
def get_scale(self, type_map: Optional[Dict[int, int]] = None) -> List[float]:
"""
Return scale in list of float. If type_map is given, return type_map reversed
scale, which index equals atomic_number. 1.0 is assigned for atomis not found
"""
scale = self.scale.detach().cpu().tolist()
if type_map:
scale = _as_univ(scale, type_map, 1.0)
return scale
@staticmethod
def from_mappers(
shift: Union[float, List[float]],
scale: Union[float, List[float]],
type_map: Dict[int, int],
**kwargs,
):
"""
Fit dimensions or mapping raw shift scale values to that is valid under
the given type_map: (atomic_numbers -> type_indices)
"""
shift_scale = []
n_atom_types = len(type_map)
for s in (shift, scale):
if isinstance(s, list) and len(s) > n_atom_types:
if len(s) != NUM_UNIV_ELEMENT:
                    raise ValueError('given shift or scale has an unexpected length')
s = [s[z] for z in sorted(type_map, key=lambda x: type_map[x])]
# s = [s[z] for z in sorted(type_map, key=type_map.get)]
elif isinstance(s, float):
s = [s] * n_atom_types
elif isinstance(s, list) and len(s) == 1:
s = s * n_atom_types
shift_scale.append(s)
assert all([len(s) == n_atom_types for s in shift_scale])
shift, scale = shift_scale
return SpeciesWiseRescale(shift, scale, **kwargs)
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
indices = data[self.key_indices]
data[self.key_output] = data[self.key_input] * self.scale[indices].view(
-1, 1
) + self.shift[indices].view(-1, 1)
return data
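# Sketch of SpeciesWiseRescale.from_mappers with a toy type_map
# (atomic number -> type index); the numbers are hypothetical.
_type_map_example = {72: 0, 8: 1}  # e.g. a Hf/O model
# A scalar shift is broadcast to every species; a per-species scale is kept as-is.
_sw_example = SpeciesWiseRescale.from_mappers(
    shift=-3.5, scale=[1.1, 0.9], type_map=_type_map_example
)
# _sw_example.get_shift()                       -> [-3.5, -3.5]
# _sw_example.get_scale(_type_map_example)[72]  -> 1.1 (index equals atomic number)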
@compile_mode('script')
class ModalWiseRescale(nn.Module):
"""
Scaling and shifting energy (and automatically force and stress)
Given shift or scale is either modal-wise and atom-wise or
not modal-wise but atom-wise. It is always interpreted as atom-wise.
"""
def __init__(
self,
shift: List[List[float]],
scale: List[List[float]],
data_key_in: str = KEY.SCALED_ATOMIC_ENERGY,
data_key_out: str = KEY.ATOMIC_ENERGY,
data_key_modal_indices: str = KEY.MODAL_TYPE,
data_key_atom_indices: str = KEY.ATOM_TYPE,
use_modal_wise_shift: bool = False,
use_modal_wise_scale: bool = False,
train_shift_scale: bool = False,
):
super().__init__()
self.shift = nn.Parameter(
torch.FloatTensor(shift), requires_grad=train_shift_scale
)
self.scale = nn.Parameter(
torch.FloatTensor(scale), requires_grad=train_shift_scale
)
self.key_input = data_key_in
self.key_output = data_key_out
self.key_atom_indices = data_key_atom_indices
self.key_modal_indices = data_key_modal_indices
self.use_modal_wise_shift = use_modal_wise_shift
self.use_modal_wise_scale = use_modal_wise_scale
self._is_batch_data = True
def get_shift(
self,
type_map: Optional[Dict[int, int]] = None,
modal_map: Optional[Dict[str, int]] = None,
) -> Union[List[float], Dict[str, List[float]]]:
"""
Nothing is given: return as it is
type_map is given but not modal wise shift: return univ shift
both type_map and modal_map is given and modal wise shift: return fully
resolved modalwise univ shift
"""
shift = self.shift.detach().cpu().tolist()
if type_map and not self.use_modal_wise_shift:
shift = _as_univ(shift, type_map, 0.0)
elif self.use_modal_wise_shift and modal_map and type_map:
shift = [_as_univ(s, type_map, 0.0) for s in shift]
shift = {modal: shift[idx] for modal, idx in modal_map.items()}
return shift
def get_scale(
self,
type_map: Optional[Dict[int, int]] = None,
modal_map: Optional[Dict[str, int]] = None,
) -> Union[List[float], Dict[str, List[float]]]:
"""
Nothing is given: return as it is
type_map is given but not modal wise scale: return univ scale
both type_map and modal_map is given and modal wise scale: return fully
resolved modalwise univ scale
"""
scale = self.scale.detach().cpu().tolist()
        if type_map and not self.use_modal_wise_scale:
            scale = _as_univ(scale, type_map, 1.0)  # 1.0 (identity) for atoms not found
        elif self.use_modal_wise_scale and modal_map and type_map:
            scale = [_as_univ(s, type_map, 1.0) for s in scale]
scale = {modal: scale[idx] for modal, idx in modal_map.items()}
return scale
@staticmethod
def from_mappers(
shift: Union[float, List[float], Dict[str, Any]],
scale: Union[float, List[float], Dict[str, Any]],
use_modal_wise_shift: bool,
use_modal_wise_scale: bool,
type_map: Dict[int, int],
modal_map: Dict[str, int],
**kwargs,
):
"""
Fit dimensions or mapping raw shift scale values to that is valid under
the given type_map: (atomic_numbers -> type_indices)
If given List[float] and its length matches length of _const.NUM_UNIV_ELEMENT
, assume it is element-wise list
otherwise, it is modal-wise list
"""
def solve_mapper(arr, map):
            # values are unique indices; keys are either atomic numbers or modal names
return [arr[z] for z in sorted(map, key=lambda x: map[x])]
shift_scale = []
n_atom_types = len(type_map)
n_modals = len(modal_map)
for s, use_mw in (
(shift, use_modal_wise_shift),
(scale, use_modal_wise_scale),
):
            # resolve element-wise values, or broadcast
if isinstance(s, float):
# given, modal-wise: no, elem-wise: no => broadcast
shape = (n_modals, n_atom_types) if use_mw else (n_atom_types,)
res = torch.full(shape, s).tolist() # TODO: w/o torch
elif isinstance(s, list) and len(s) == NUM_UNIV_ELEMENT:
# given, modal-wise: no, elem-wise: yes(univ) => solve elem map
s = solve_mapper(s, type_map)
res = [s] * n_modals if use_mw else s
elif ( # given, modal-wise: yes, elem-wise: no => broadcast to elemwise
isinstance(s, list)
and isinstance(s[0], float)
and len(s) == n_modals
and use_mw
):
res = [[v] * n_atom_types for v in s]
elif ( # given, modal-wise: no, elem-wise: yes => as it is
isinstance(s, list)
and isinstance(s[0], float)
and len(s) == n_atom_types
and not use_mw
):
res = s
elif ( # given, modal-wise: yes, elem-wise: yes => as it is
isinstance(s, list)
and isinstance(s[0], list)
and len(s) == n_modals
and len(s[0]) == n_atom_types
and use_mw
):
res = s
elif isinstance(s, dict) and use_mw:
# solve modal dict, modal-wise: yes
s = solve_mapper(s, modal_map)
res = []
for v in s:
if isinstance(v, list) and len(v) == NUM_UNIV_ELEMENT:
# elem-wise: yes(univ) => solve elem map
v = solve_mapper(v, type_map)
elif isinstance(v, float):
# elem-wise: no => broadcast to elemwise
v = [v] * n_atom_types
else:
raise ValueError(f'Invalid shift or scale {s}')
res.append(v)
else:
raise ValueError(f'Invalid shift or scale {s}')
if use_mw:
assert (
isinstance(res, list)
and isinstance(res[0], list)
and len(res) == n_modals
)
assert all([len(r) == n_atom_types for r in res]) # type: ignore
else:
assert (
isinstance(res, list)
and isinstance(res[0], float)
and len(res) == n_atom_types
)
shift_scale.append(res)
shift, scale = shift_scale
return ModalWiseRescale(
shift,
scale,
use_modal_wise_shift=use_modal_wise_shift,
use_modal_wise_scale=use_modal_wise_scale,
**kwargs,
)
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
if self._is_batch_data:
batch = data[KEY.BATCH]
modal_indices = data[self.key_modal_indices][batch]
else:
modal_indices = data[self.key_modal_indices]
atom_indices = data[self.key_atom_indices]
shift = (
self.shift[modal_indices, atom_indices]
if self.use_modal_wise_shift
else self.shift[atom_indices]
)
scale = (
self.scale[modal_indices, atom_indices]
if self.use_modal_wise_scale
else self.scale[atom_indices]
)
data[self.key_output] = data[self.key_input] * scale.view(
-1, 1
) + shift.view(-1, 1)
return data
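# The forward above picks one (modal, species) entry per atom via paired
# advanced indexing. A tensor-only sketch of that gather (hypothetical values):
import torch

_table_example = torch.tensor([[0.0, 1.0],   # rows: modals
                               [2.0, 3.0]])  # columns: atom types
_modal_idx_example = torch.tensor([0, 0, 1])  # per-atom modal index
_atom_idx_example = torch.tensor([1, 0, 1])   # per-atom type index
# _table_example[_modal_idx_example, _atom_idx_example] -> tensor([1., 0., 3.])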
def get_resolved_shift_scale(
module: Union[Rescale, SpeciesWiseRescale, ModalWiseRescale],
type_map: Optional[Dict[int, int]] = None,
modal_map: Optional[Dict[str, int]] = None,
):
"""
Return resolved shift and scale from scale modules. For element wise case,
convert to list of floats where idx is atomic number. For modal wise case, return
dictionary of shift scale where key is modal name given in modal_map
Return:
Tuple of solved shift and scale
"""
if isinstance(module, Rescale):
return (module.get_shift(), module.get_scale())
elif isinstance(module, SpeciesWiseRescale):
return (module.get_shift(type_map), module.get_scale(type_map))
elif isinstance(module, ModalWiseRescale):
return (
module.get_shift(type_map, modal_map),
module.get_scale(type_map, modal_map),
)
    raise ValueError('Not a scale module')
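# Sketch of get_resolved_shift_scale on a species-wise module (toy numbers,
# using the classes defined above):
_resolved_module = SpeciesWiseRescale([0.1, 0.2], [1.0, 1.0])
_shift_resolved, _scale_resolved = get_resolved_shift_scale(
    _resolved_module, type_map={72: 0, 8: 1}
)
# _shift_resolved[72] == 0.1 and _shift_resolved[8] == 0.2; elements absent
# from the type_map get the defaults (0.0 for shift, 1.0 for scale).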
import torch.nn as nn
from e3nn.o3 import FullyConnectedTensorProduct, Irreps, Linear
from e3nn.util.jit import compile_mode
import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType
@compile_mode('script')
class SelfConnectionIntro(nn.Module):
"""
do TensorProduct of x and some data(here attribute of x)
and save it (to concatenate updated x at SelfConnectionOutro)
"""
def __init__(
self,
irreps_in: Irreps,
irreps_operand: Irreps,
irreps_out: Irreps,
data_key_x: str = KEY.NODE_FEATURE,
data_key_operand: str = KEY.NODE_ATTR,
lazy_layer_instantiate: bool = True,
**kwargs, # for compatibility
):
super().__init__()
self.irreps_in1 = irreps_in
self.irreps_in2 = irreps_operand
self.irreps_out = irreps_out
self.key_x = data_key_x
self.key_operand = data_key_operand
self.fc_tensor_product = None
self.layer_instantiated = False
self.fc_tensor_product_cls = FullyConnectedTensorProduct
self.fc_tensor_product_kwargs = kwargs
if not lazy_layer_instantiate:
self.instantiate()
def instantiate(self):
if self.fc_tensor_product is not None:
raise ValueError('fc_tensor_product layer already exists')
self.fc_tensor_product = self.fc_tensor_product_cls(
self.irreps_in1,
self.irreps_in2,
self.irreps_out,
shared_weights=True,
internal_weights=None, # same as True
**self.fc_tensor_product_kwargs,
)
self.layer_instantiated = True
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
assert self.fc_tensor_product is not None, 'Layer is not instantiated'
data[KEY.SELF_CONNECTION_TEMP] = self.fc_tensor_product(
data[self.key_x], data[self.key_operand]
)
return data
@compile_mode('script')
class SelfConnectionLinearIntro(nn.Module):
"""
Linear style self connection update
"""
def __init__(
self,
irreps_in: Irreps,
irreps_out: Irreps,
data_key_x: str = KEY.NODE_FEATURE,
lazy_layer_instantiate: bool = True,
**kwargs,
):
super().__init__()
self.irreps_in = irreps_in
self.irreps_out = irreps_out
self.key_x = data_key_x
self.linear = None
self.layer_instantiated = False
self.linear_cls = Linear
# TODO: better to have SelfConnectionIntro super class
kwargs.pop('irreps_operand')
self.linear_kwargs = kwargs
if not lazy_layer_instantiate:
self.instantiate()
def instantiate(self):
if self.linear is not None:
raise ValueError('Linear layer already exists')
self.linear = self.linear_cls(
self.irreps_in, self.irreps_out, **self.linear_kwargs
)
self.layer_instantiated = True
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
assert self.linear is not None, 'Layer is not instantiated'
data[KEY.SELF_CONNECTION_TEMP] = self.linear(data[self.key_x])
return data
@compile_mode('script')
class SelfConnectionOutro(nn.Module):
"""
do TensorProduct of x and some data(here attribute of x)
and save it (to concatenate updated x at SelfConnectionOutro)
"""
def __init__(
self,
data_key_x: str = KEY.NODE_FEATURE,
):
super().__init__()
self.key_x = data_key_x
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
data[self.key_x] = data[self.key_x] + data[KEY.SELF_CONNECTION_TEMP]
del data[KEY.SELF_CONNECTION_TEMP]
return data
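# Sketch of the intro/outro pairing: the intro caches a self-connection term,
# some update block (e.g. a convolution) rewrites the node features, and the
# outro adds the cached term back. A plain dict stands in for AtomGraphDataType,
# and the irreps/key names below are hypothetical.
from e3nn.o3 import Irreps

_sc_irreps = Irreps('4x0e')
_sc_intro = SelfConnectionLinearIntro(
    _sc_irreps,
    _sc_irreps,
    data_key_x='x',
    lazy_layer_instantiate=False,
    irreps_operand=_sc_irreps,  # accepted (and popped) only for signature compatibility
)
_sc_outro = SelfConnectionOutro(data_key_x='x')
_sc_data = {'x': _sc_irreps.randn(5, -1)}
_sc_data = _sc_intro(_sc_data)  # stores the linear self-connection under the temp key
# ... node features would normally be updated by a convolution here ...
_sc_data = _sc_outro(_sc_data)  # adds the stored term back onto 'x' and drops the temp key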
import warnings
from collections import OrderedDict
from typing import Dict, Optional
import torch
import torch.nn as nn
from e3nn.util.jit import compile_mode
import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType
def _instantiate_modules(modules):
# see IrrepsLinear of linear.py
for module in modules.values():
if not getattr(module, 'layer_instantiated', True):
module.instantiate()
@compile_mode('script')
class _ModalInputPrepare(nn.Module):
def __init__(
self,
modal_idx: int
):
super().__init__()
self.modal_idx = modal_idx
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
data[KEY.MODAL_TYPE] = torch.tensor(
self.modal_idx,
dtype=torch.int64,
device=data['x'].device,
)
return data
@compile_mode('script')
class AtomGraphSequential(nn.Sequential):
"""
Wrapper of SevenNet model
Args:
modules: OrderedDict of nn.Modules
cutoff: not used internally, but makes sense to have
type_map: atomic_numbers => onehot index (see nn/node_embedding.py)
        eval_type_map: perform index mapping using type_map; defaults to True
        data_key_atomic_numbers: used when eval_type_map is True
        data_key_node_feature: used when eval_type_map is True
        data_key_grad: if given, sets its requires_grad to True before prediction
"""
def __init__(
self,
modules: Dict[str, nn.Module],
cutoff: float = 0.0,
type_map: Optional[Dict[int, int]] = None,
modal_map: Optional[Dict[str, int]] = None,
eval_type_map: bool = True,
eval_modal_map: bool = False,
data_key_atomic_numbers: str = KEY.ATOMIC_NUMBERS,
data_key_node_feature: str = KEY.NODE_FEATURE,
data_key_grad: Optional[str] = None,
):
if not isinstance(modules, OrderedDict): # backward compat
modules = OrderedDict(modules)
self.cutoff = cutoff
self.type_map = type_map
self.eval_type_map = eval_type_map
self.is_batch_data = True
if cutoff == 0.0:
warnings.warn('cutoff is 0.0 or not given', UserWarning)
if self.type_map is None:
warnings.warn('type_map is not given', UserWarning)
self.eval_type_map = False
else:
z_to_onehot_tensor = torch.neg(torch.ones(120, dtype=torch.long))
for z, onehot in self.type_map.items():
z_to_onehot_tensor[z] = onehot
self.z_to_onehot_tensor = z_to_onehot_tensor
if eval_modal_map and modal_map is None:
raise ValueError('eval_modal_map is True but modal_map is None')
self.eval_modal_map = eval_modal_map
self.modal_map = modal_map
self.key_atomic_numbers = data_key_atomic_numbers
self.key_node_feature = data_key_node_feature
self.key_grad = data_key_grad
_instantiate_modules(modules)
super().__init__(modules)
if not isinstance(self._modules, OrderedDict): # backward compat
self._modules = OrderedDict(self._modules)
def set_is_batch_data(self, flag: bool):
        # Depending on whether the given data is batched, some modules have to
        # change their behavior. Checking for batched data inside forward makes
        # it harder to convert the model to TorchScript.
for module in self:
try: # Easier to ask for forgiveness than permission.
module._is_batch_data = flag # type: ignore
except AttributeError:
pass
self.is_batch_data = flag
    def get_irreps_in(self, module_name: str, attr_key: str = 'irreps_in'):
        tg_module = self._modules[module_name]
for m in tg_module.modules():
try:
return repr(m.__getattribute__(attr_key))
except AttributeError:
pass
return None
def prepand_module(self, key: str, module: nn.Module):
self._modules.update({key: module})
self._modules.move_to_end(key, last=False) # type: ignore
def replace_module(self, key: str, module: nn.Module):
self._modules.update({key: module})
def delete_module_by_key(self, key: str):
if key in self._modules.keys():
del self._modules[key]
@torch.jit.unused
def _atomic_numbers_to_onehot(self, atomic_numbers: torch.Tensor):
assert atomic_numbers.dtype == torch.int64
device = atomic_numbers.device
z_to_onehot_tensor = self.z_to_onehot_tensor.to(device)
return torch.index_select(
input=z_to_onehot_tensor, dim=0, index=atomic_numbers
)
@torch.jit.unused
def _eval_modal_map(self, data: AtomGraphDataType):
assert self.modal_map is not None
# modal_map: dict[str, int]
if not self.is_batch_data:
modal_idx = self.modal_map[data[KEY.DATA_MODALITY]] # type: ignore
else:
modal_idx = [
self.modal_map[ii] # type: ignore
for ii in data[KEY.DATA_MODALITY]
]
modal_idx = torch.tensor(
modal_idx,
dtype=torch.int64,
device=data.x.device, # type: ignore
)
data[KEY.MODAL_TYPE] = modal_idx
def _preprocess(self, data: AtomGraphDataType) -> AtomGraphDataType:
if self.eval_type_map:
atomic_numbers = data[self.key_atomic_numbers]
onehot = self._atomic_numbers_to_onehot(atomic_numbers)
data[self.key_node_feature] = onehot
if self.eval_modal_map:
self._eval_modal_map(data)
if self.key_grad is not None:
data[self.key_grad].requires_grad_(True)
return data
def prepare_modal_deploy(self, modal: str):
if self.modal_map is None:
return
self.eval_modal_map = False
self.set_is_batch_data(False)
modal_idx = self.modal_map[modal] # type: ignore
self.prepand_module('modal_input_prepare', _ModalInputPrepare(modal_idx))
def forward(self, input: AtomGraphDataType) -> AtomGraphDataType:
data = self._preprocess(input)
for module in self:
data = module(data)
return data
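# Sketch of wrapping layers into AtomGraphSequential. Real models are assembled
# by the model builder; the single Rescale layer here (the module from
# nn/scale.py above, assumed to be in scope) is only a stand-in so the wrapper
# has something to iterate over.
from collections import OrderedDict

_example_layers = OrderedDict([('rescale', Rescale(shift=0.0, scale=1.0))])
_example_model = AtomGraphSequential(
    _example_layers, cutoff=5.0, type_map={72: 0, 8: 1}
)
_example_model.set_is_batch_data(False)  # single-structure (e.g. deploy-time) inference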
import torch
def broadcast(
src: torch.Tensor,
other: torch.Tensor,
dim: int
):
if dim < 0:
dim = other.dim() + dim
if src.dim() == 1:
for _ in range(0, dim):
src = src.unsqueeze(0)
for _ in range(src.dim(), other.dim()):
src = src.unsqueeze(-1)
src = src.expand_as(other)
return src
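# Sketch of the broadcast helper above: lift a 1-D per-edge tensor to the shape
# of a per-edge feature tensor so it can be used in element-wise ops or scatter
# calls along dim 0 (values are hypothetical).
import torch

_edge_weights = torch.tensor([0.5, 2.0, 1.0])            # one scalar per edge
_edge_features = torch.randn(3, 4)                       # one 4-dim feature per edge
_expanded = broadcast(_edge_weights, _edge_features, 0)  # (3, 4); each scalar repeated across features
_scaled_features = _edge_features * _expanded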
#!/bin/bash
lammps_root=$1
cxx_standard=$2 # 14, 17
d3_support=$3 # 1, 0
SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}")
###########################################
# Check if the given arguments are valid #
###########################################
# Check the number of arguments
if [ "$#" -ne 3 ]; then
echo "Usage: sh patch_lammps.sh {lammps_root} {cxx_standard} {d3_support}"
echo " {lammps_root}: Root directory of LAMMPS source"
echo " {cxx_standard}: C++ standard (14, 17)"
echo " {d3_support}: Support for pair_d3 (1, 0)"
exit 1
fi
# Check if the lammps_root directory exists
if [ ! -d "$lammps_root" ]; then
echo "Error: No such directory: $lammps_root"
exit 1
fi
# Check if the given directory is the root of LAMMPS source
if [ ! -d "$lammps_root/cmake" ] && [ ! -d "$lammps_root/potentials" ]; then
echo "Error: Given $lammps_root is not a root of LAMMPS source"
exit 1
fi
# Check if the script is being run from the root of SevenNet
if [ ! -f "${SCRIPT_DIR}/pair_e3gnn.cpp" ]; then
echo "Error: Script executed in a wrong directory"
exit 1
fi
# Check if the patch is already applied
if [ -f "$lammps_root/src/pair_e3gnn.cpp" ]; then
echo "----------------------------------------------------------"
echo "Seems like given LAMMPS is already patched."
echo "Try again after removing src/pair_e3gnn.cpp to force patch"
echo "----------------------------------------------------------"
echo "Example build commands, under LAMMPS root"
echo " mkdir build; cd build"
echo " cmake ../cmake -DCMAKE_PREFIX_PATH=$(python -c 'import torch;print(torch.utils.cmake_prefix_path)')"
echo " make -j 4"
exit 0
fi
# Check if OpenMPI exists and if it is CUDA-aware
if command -v ompi_info &> /dev/null; then
cuda_support=$(ompi_info --parsable --all | grep mpi_built_with_cuda_support:value)
if [[ -z "$cuda_support" ]]; then
echo "OpenMPI not found, parallel performance is not optimal"
elif [[ "$cuda_support" == *"true" ]]; then
echo "OpenMPI is CUDA aware"
else
echo "This system's OpenMPI is not 'CUDA aware', parallel performance is not optimal"
fi
else
echo "OpenMPI not found, parallel performance is not optimal"
fi
# Extract LAMMPS version and update
lammps_version=$(grep "#define LAMMPS_VERSION" $lammps_root/src/version.h | awk '{print $3, $4, $5}' | tr -d '"')
# Combine version and update
detected_version="$lammps_version"
required_version="2 Aug 2023" # Example required version
# Check if the detected version is compatible
if [[ "$detected_version" != "$required_version" ]]; then
echo "Warning: Detected LAMMPS version ($detected_version) may not be compatible. Required version: $required_version"
fi
###########################################
# Backup original LAMMPS source code #
###########################################
# Create a backup directory if it doesn't exist
backup_dir="$lammps_root/_backups"
mkdir -p $backup_dir
# Copy comm_* from original LAMMPS source as backup
cp $lammps_root/src/comm_brick.cpp $backup_dir/
cp $lammps_root/src/comm_brick.h $backup_dir/
# Copy cmake/CMakeLists.txt from original source as backup
cp $lammps_root/cmake/CMakeLists.txt $backup_dir/CMakeLists.txt
###########################################
# Patch LAMMPS source code: e3gnn #
###########################################
# 1. Copy pair_e3gnn files to LAMMPS source
cp $SCRIPT_DIR/{pair_e3gnn,pair_e3gnn_parallel,comm_brick}.cpp $lammps_root/src/
cp $SCRIPT_DIR/{pair_e3gnn,pair_e3gnn_parallel,comm_brick}.h $lammps_root/src/
# 2. Patch cmake/CMakeLists.txt
sed -i "s/set(CMAKE_CXX_STANDARD 11)/set(CMAKE_CXX_STANDARD $cxx_standard)/" $lammps_root/cmake/CMakeLists.txt
cat >> $lammps_root/cmake/CMakeLists.txt << "EOF"
find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
target_link_libraries(lammps PUBLIC "${TORCH_LIBRARIES}")
EOF
###########################################
# Patch LAMMPS source code: d3 #
###########################################
if [ "$d3_support" -ne 0 ]; then
# 1. Copy pair_d3 files to LAMMPS source
cp $SCRIPT_DIR/pair_d3.cu $lammps_root/src/
cp $SCRIPT_DIR/pair_d3.h $lammps_root/src/
cp $SCRIPT_DIR/pair_d3_pars.h $lammps_root/src/
# 2. Patch cmake/CMakeLists.txt
sed -i "s/project(lammps CXX)/project(lammps CXX CUDA)/" $lammps_root/cmake/CMakeLists.txt
sed -i "s/\${LAMMPS_SOURCE_DIR}\/\[\^.\]\*\.cpp/\${LAMMPS_SOURCE_DIR}\/\[\^.\]\*\.cpp \${LAMMPS_SOURCE_DIR}\/\[\^.\]\*\.cu/" $lammps_root/cmake/CMakeLists.txt
cat >> $lammps_root/cmake/CMakeLists.txt << "EOF"
find_package(CUDA)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -fmad=false -O3")
string(REPLACE "-gencode arch=compute_50,code=sm_50" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
target_link_libraries(lammps PUBLIC ${CUDA_LIBRARIES} cuda)
EOF
fi
###########################################
# Print changes and backup file locations #
###########################################
# Print changes and backup file locations
echo "Changes made:"
echo " - Original LAMMPS files (src/comm_brick.*, cmake/CMakeList.txt) are in {lammps_root}/_backups"
echo " - Copied contents of pair_e3gnn to $lammps_root/src/"
echo " - Patched CMakeLists.txt: include LibTorch, CXX_STANDARD $cxx_standard"
if [ "$d3_support" -ne 0 ]; then
echo " - Copied contents of pair_d3 to $lammps_root/src/"
echo " - Patched CMakeLists.txt: include CUDA"
fi
# Provide example cmake command to the user
echo "Example build commands, under LAMMPS root"
echo " mkdir build; cd build"
echo " cmake ../cmake -DCMAKE_PREFIX_PATH=$(python -c 'import torch;print(torch.utils.cmake_prefix_path)')"
echo " make -j 4"
exit 0
import glob
import os
import warnings
from typing import Any, Callable, Dict
import torch
import yaml
import sevenn._const as _const
import sevenn._keys as KEY
import sevenn.util as util
def config_initialize(
key: str,
config: Dict,
default: Any,
conditions: Dict,
):
    # default value exists & no user input -> return default
    if key not in config.keys():
        return default
    user_input = config[key]
    if key in conditions:
        condition = conditions[key]
    else:
        # no validation method exists => accept user input as-is
        return user_input
if type(default) is dict and isinstance(condition, dict):
for i_key, val in default.items():
user_input[i_key] = config_initialize(
i_key, user_input, val, condition
)
return user_input
elif isinstance(condition, type):
if isinstance(user_input, condition):
return user_input
else:
try:
return condition(user_input) # try type casting
except ValueError:
raise ValueError(
f"Expect '{user_input}' for '{key}' is {condition}"
)
elif isinstance(condition, Callable) and condition(user_input):
return user_input
else:
raise ValueError(
f"Given input '{user_input}' for '{key}' is not valid"
)
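# Sketch of config_initialize with toy defaults and conditions (the keys here
# are hypothetical, not the library's actual config keys):
_toy_config = {'cutoff': '5.0'}                    # user supplied a string
_toy_conditions = {'cutoff': float, 'shuffle': bool}
# config_initialize('cutoff', _toy_config, 4.5, _toy_conditions)   -> 5.0 (cast to float)
# config_initialize('shuffle', _toy_config, True, _toy_conditions) -> True (default kept)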
def init_model_config(config: Dict):
# defaults = _const.model_defaults(config)
model_meta = {}
# init complicated ones
if KEY.CHEMICAL_SPECIES not in config.keys():
        raise ValueError('required key chemical_species does not exist')
input_chem = config[KEY.CHEMICAL_SPECIES]
if isinstance(input_chem, str) and input_chem.lower() == 'auto':
model_meta[KEY.CHEMICAL_SPECIES] = 'auto'
model_meta[KEY.NUM_SPECIES] = 'auto'
model_meta[KEY.TYPE_MAP] = 'auto'
elif isinstance(input_chem, str) and 'univ' in input_chem.lower():
model_meta.update(util.chemical_species_preprocess([], universal=True))
else:
if isinstance(input_chem, list) and all(
isinstance(x, str) for x in input_chem
):
pass
elif isinstance(input_chem, str):
input_chem = (
input_chem.replace('-', ',').replace(' ', ',').split(',')
)
input_chem = [chem for chem in input_chem if len(chem) != 0]
else:
raise ValueError(f'given {KEY.CHEMICAL_SPECIES} input is strange')
model_meta.update(util.chemical_species_preprocess(input_chem))
# deprecation warnings
if KEY.AVG_NUM_NEIGH in config:
warnings.warn(
"key 'avg_num_neigh' is deprecated. Please use 'conv_denominator'."
' We use the default, the average number of neighbors in the'
' dataset, if not provided.',
UserWarning,
)
config.pop(KEY.AVG_NUM_NEIGH)
if KEY.TRAIN_AVG_NUM_NEIGH in config:
warnings.warn(
"key 'train_avg_num_neigh' is deprecated. Please use"
" 'train_denominator'. We overwrite train_denominator as given"
' train_avg_num_neigh',
UserWarning,
)
config[KEY.TRAIN_DENOMINTAOR] = config[KEY.TRAIN_AVG_NUM_NEIGH]
config.pop(KEY.TRAIN_AVG_NUM_NEIGH)
if KEY.OPTIMIZE_BY_REDUCE in config:
warnings.warn(
"key 'optimize_by_reduce' is deprecated. Always true",
UserWarning,
)
config.pop(KEY.OPTIMIZE_BY_REDUCE)
# init simpler ones
for key, default in _const.DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG.items():
model_meta[key] = config_initialize(
key, config, default, _const.MODEL_CONFIG_CONDITION
)
unknown_keys = [
key for key in config.keys() if key not in model_meta.keys()
]
if len(unknown_keys) != 0:
warnings.warn(
f'Unexpected model keys: {unknown_keys} will be ignored',
UserWarning,
)
return model_meta
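# Sketch of the chemical_species string normalization used above: commas,
# hyphens and spaces are all accepted as separators.
_input_chem_example = 'Hf-O, C'
_tokens = _input_chem_example.replace('-', ',').replace(' ', ',').split(',')
# [chem for chem in _tokens if len(chem) != 0] -> ['Hf', 'O', 'C']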
def init_train_config(config: Dict):
train_meta = {}
# defaults = _const.train_defaults(config)
try:
device_input = config[KEY.DEVICE]
train_meta[KEY.DEVICE] = torch.device(device_input)
except KeyError:
train_meta[KEY.DEVICE] = (
torch.device('cuda')
if torch.cuda.is_available()
else torch.device('cpu')
)
train_meta[KEY.DEVICE] = str(train_meta[KEY.DEVICE])
# init simpler ones
for key, default in _const.DEFAULT_TRAINING_CONFIG.items():
train_meta[key] = config_initialize(
key, config, default, _const.TRAINING_CONFIG_CONDITION
)
if KEY.CONTINUE in config.keys():
cnt_dct = config[KEY.CONTINUE]
if KEY.CHECKPOINT not in cnt_dct.keys():
raise ValueError('no checkpoint is given in continue')
checkpoint = cnt_dct[KEY.CHECKPOINT]
if os.path.isfile(checkpoint):
checkpoint_file = checkpoint
else:
checkpoint_file = util.pretrained_name_to_path(checkpoint)
train_meta[KEY.CONTINUE].update({KEY.CHECKPOINT: checkpoint_file})
unknown_keys = [
key for key in config.keys() if key not in train_meta.keys()
]
if len(unknown_keys) != 0:
warnings.warn(
f'Unexpected train keys: {unknown_keys} will be ignored',
UserWarning,
)
return train_meta
def init_data_config(config: Dict):
data_meta = {}
# defaults = _const.data_defaults(config)
load_data_keys = []
for k in config:
if k.startswith('load_') and k.endswith('_path'):
load_data_keys.append(k)
for load_data_key in load_data_keys:
if load_data_key in config.keys():
inp = config[load_data_key]
extended = []
if type(inp) not in [str, list]:
                raise ValueError(f'unexpected input {inp} for {load_data_key}')
if type(inp) is str:
extended = glob.glob(inp)
elif type(inp) is list:
for i in inp:
if isinstance(i, str):
extended.extend(glob.glob(i))
elif isinstance(i, dict):
extended.append(i)
if len(extended) == 0:
raise ValueError(
f'Cannot find {inp} for {load_data_key}'
+ ' or path is not given'
)
data_meta[load_data_key] = extended
else:
data_meta[load_data_key] = False
for key, default in _const.DEFAULT_DATA_CONFIG.items():
data_meta[key] = config_initialize(
key, config, default, _const.DATA_CONFIG_CONDITION
)
unknown_keys = [
key for key in config.keys() if key not in data_meta.keys()
]
if len(unknown_keys) != 0:
warnings.warn(
f'Unexpected data keys: {unknown_keys} will be ignored',
UserWarning,
)
return data_meta
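
# Hedged sketch of the load_*_path expansion above (the key name is chosen to
# match the load_*_path pattern and the file path is hypothetical):
#
#   data_meta = init_data_config({'load_dataset_path': 'data/train_*.sevenn_data'})
#   # -> glob.glob expands the pattern into a list of files; a list value may
#   #    mix glob strings and dict entries, and an empty expansion raises ValueError.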
def read_config_yaml(filename: str, return_separately: bool = False):
with open(filename, 'r') as fstream:
inputs = yaml.safe_load(fstream)
model_meta, train_meta, data_meta = {}, {}, {}
for key, config in inputs.items():
if key == 'model':
model_meta = init_model_config(config)
elif key == 'train':
train_meta = init_train_config(config)
elif key == 'data':
data_meta = init_data_config(config)
else:
raise ValueError(f'Unexpected input {key} given')
if return_separately:
return model_meta, train_meta, data_meta
else:
model_meta.update(train_meta)
model_meta.update(data_meta)
return model_meta
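
# Hedged end-to-end sketch (the YAML layout below is a minimal illustration,
# not a complete or verified SevenNet input file):
#
#   # input.yaml
#   #   model:
#   #     chemical_species: 'Hf, O'
#   #   train:
#   #     device: 'cpu'
#   #   data:
#   #     load_dataset_path: 'data/train_*.sevenn_data'
#
#   model_meta, train_meta, data_meta = read_config_yaml(
#       'input.yaml', return_separately=True
#   )
#   # With return_separately=False (the default) the three dicts are merged
#   # into a single configuration dict.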
def main():
filename = './input.yaml'
read_config_yaml(filename)
if __name__ == '__main__':
main()