Unverified Commit ca86f720 authored by zcxzcx1's avatar zcxzcx1 Committed by GitHub
Browse files

Add files via upload

parent b75ed73c
import math
import torch
@torch.jit.script
def ShiftedSoftPlus(x: torch.Tensor) -> torch.Tensor:
    """Softplus shifted down by log(2) so that f(0) == 0."""
    shift = math.log(2.0)
    return torch.nn.functional.softplus(x) - shift
from typing import List
import torch
import torch.nn as nn
from e3nn.nn import FullyConnectedNet
from e3nn.o3 import Irreps, TensorProduct
from e3nn.util.jit import compile_mode
import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType
from .activation import ShiftedSoftPlus
from .util import broadcast
def message_gather(
    node_features: torch.Tensor,
    edge_dst: torch.Tensor,
    message: torch.Tensor
):
    """Scatter-sum per-edge messages into per-node rows given by ``edge_dst``."""
    gather_idx = broadcast(edge_dst, message, 0)
    shape = [len(node_features)]
    shape.extend(message.shape[1:])
    accumulated = torch.zeros(
        shape,
        dtype=node_features.dtype,
        device=node_features.device
    )
    accumulated.scatter_reduce_(0, gather_idx, message, reduce='sum')
    return accumulated
@compile_mode('script')
class IrrepsConvolution(nn.Module):
    """
    convolution of (fig 2.b), comm. in LAMMPS

    Tensor-product message passing: source-node features are combined with
    edge filters via an e3nn TensorProduct whose per-path weights come from
    a radial MLP applied to the edge embedding. Layers can be instantiated
    lazily so that ``convolution_cls``/``convolution_kwargs`` may be patched
    (e.g. for cuEquivariance) before building.
    """
    def __init__(
        self,
        irreps_x: Irreps,
        irreps_filter: Irreps,
        irreps_out: Irreps,
        weight_layer_input_to_hidden: List[int],
        weight_layer_act=ShiftedSoftPlus,
        denominator: float = 1.0,
        train_denominator: bool = False,
        data_key_x: str = KEY.NODE_FEATURE,
        data_key_filter: str = KEY.EDGE_ATTR,
        data_key_weight_input: str = KEY.EDGE_EMBEDDING,
        data_key_edge_idx: str = KEY.EDGE_IDX,
        lazy_layer_instantiate: bool = True,
        is_parallel: bool = False,
    ):
        super().__init__()
        # normalization constant applied after message aggregation
        self.denominator = nn.Parameter(
            torch.FloatTensor([denominator]), requires_grad=train_denominator
        )
        self.key_x = data_key_x
        self.key_filter = data_key_filter
        self.key_weight_input = data_key_weight_input
        self.key_edge_idx = data_key_edge_idx
        self.is_parallel = is_parallel

        # enumerate every allowed tensor-product path ('uvu' = channel-wise,
        # trainable); one weight per channel per path (mul_x * 1)
        instructions = []
        irreps_mid = []
        weight_numel = 0
        for i, (mul_x, ir_x) in enumerate(irreps_x):
            for j, (_, ir_filter) in enumerate(irreps_filter):
                for ir_out in ir_x * ir_filter:
                    if ir_out in irreps_out:  # here we drop l > lmax
                        k = len(irreps_mid)
                        weight_numel += mul_x * 1  # path shape
                        irreps_mid.append((mul_x, ir_out))
                        instructions.append((i, j, k, 'uvu', True))
        irreps_mid = Irreps(irreps_mid)
        irreps_mid, p, _ = irreps_mid.sort()  # type: ignore

        # remap each instruction's output index to the sorted irreps order
        instructions = [
            (i_in1, i_in2, p[i_out], mode, train)
            for i_in1, i_in2, i_out, mode, train in instructions
        ]

        # From v0.11.x, to compatible with cuEquivariance
        self._instructions_before_sort = instructions
        instructions = sorted(instructions, key=lambda x: x[2])

        # kwargs are stored (not applied) so they can be patched before
        # instantiate() is called
        self.convolution_kwargs = dict(
            irreps_in1=irreps_x,
            irreps_in2=irreps_filter,
            irreps_out=irreps_mid,
            instructions=instructions,
            shared_weights=False,
            internal_weights=False,
        )
        self.weight_nn_kwargs = dict(
            hs=weight_layer_input_to_hidden + [weight_numel],
            act=weight_layer_act
        )
        self.convolution = None
        self.weight_nn = None
        self.layer_instantiated = False
        # layer classes may be swapped (e.g. by patch_convolution) before
        # instantiate() builds the actual modules
        self.convolution_cls = TensorProduct
        self.weight_nn_cls = FullyConnectedNet
        if not lazy_layer_instantiate:
            self.instantiate()

        self._comm_size = irreps_x.dim  # used in parallel

    def instantiate(self):
        """Build the convolution and radial-weight layers from stored kwargs.

        Raises ValueError if either layer was already built.
        """
        if self.convolution is not None:
            raise ValueError('Convolution layer already exists')
        if self.weight_nn is not None:
            raise ValueError('Weight_nn layer already exists')
        self.convolution = self.convolution_cls(**self.convolution_kwargs)
        self.weight_nn = self.weight_nn_cls(**self.weight_nn_kwargs)
        self.layer_instantiated = True

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        """Compute per-edge messages and scatter-sum them into dst nodes."""
        assert self.convolution is not None, 'Convolution is not instantiated'
        assert self.weight_nn is not None, 'Weight_nn is not instantiated'
        weight = self.weight_nn(data[self.key_weight_input])
        x = data[self.key_x]
        if self.is_parallel:
            # append ghost-atom features received from neighboring domains
            x = torch.cat([x, data[KEY.NODE_FEATURE_GHOST]])

        # note that 1 -> src 0 -> dst
        edge_src = data[self.key_edge_idx][1]
        edge_dst = data[self.key_edge_idx][0]

        message = self.convolution(x[edge_src], data[self.key_filter], weight)

        x = message_gather(x, edge_dst, message)
        x = x.div(self.denominator)
        if self.is_parallel:
            # drop ghost rows, keeping only locally-owned atoms
            x = torch.tensor_split(x, data[KEY.NLOCAL])[0]
        data[self.key_x] = x
        return data
import functools
import itertools
import warnings
from typing import Iterator, Literal, Union

import e3nn.o3 as o3
import numpy as np
from .convolution import IrrepsConvolution
from .linear import IrrepsLinear
from .self_connection import SelfConnectionIntro, SelfConnectionLinearIntro
try:
    import cuequivariance as cue
    import cuequivariance_torch as cuet

    _CUE_AVAILABLE = True

    # Obtained from MACE
    class O3_e3nn(cue.O3):
        """cue.O3 variant whose Clebsch-Gordan values come from e3nn's
        ``wigner_3j`` (so cue tensor products match e3nn's conventions).

        NOTE(review): following the cue API style, the dunder methods below
        take ``rep1`` in place of ``self``.
        """

        def __mul__(  # type: ignore
            rep1: 'O3_e3nn', rep2: 'O3_e3nn'
        ) -> Iterator['O3_e3nn']:
            # re-wrap products of the parent class into O3_e3nn instances
            return [  # type: ignore
                O3_e3nn(l=ir.l, p=ir.p) for ir in cue.O3.__mul__(rep1, rep2)
            ]

        @classmethod
        def clebsch_gordan(  # type: ignore
            cls, rep1: 'O3_e3nn', rep2: 'O3_e3nn', rep3: 'O3_e3nn'
        ) -> np.ndarray:
            rep1, rep2, rep3 = cls._from(rep1), cls._from(rep2), cls._from(rep3)
            if rep1.p * rep2.p == rep3.p:
                # e3nn Wigner 3j symbol, scaled by sqrt(dim of the output rep)
                return o3.wigner_3j(rep1.l, rep2.l, rep3.l).numpy()[None] * np.sqrt(
                    rep3.dim
                )
            # parity-forbidden coupling: no CG path
            return np.zeros((0, rep1.dim, rep2.dim, rep3.dim))

        def __lt__(  # type: ignore
            rep1: 'O3_e3nn', rep2: 'O3_e3nn'
        ) -> bool:
            # order irreps by angular momentum first, then parity
            rep2 = rep1._from(rep2)  # type: ignore
            return (rep1.l, rep1.p) < (rep2.l, rep2.p)

        @classmethod
        def iterator(cls) -> Iterator['O3_e3nn']:
            # enumerate irreps as 0e, 0o, 1o, 1e, 2e, 2o, ...
            for l in itertools.count(0):
                yield O3_e3nn(l=l, p=1 * (-1) ** l)
                yield O3_e3nn(l=l, p=-1 * (-1) ** l)

except ImportError:
    _CUE_AVAILABLE = False
def is_cue_available():
    """Return True if the optional cuEquivariance backend was imported."""
    return _CUE_AVAILABLE
def cue_needed(func):
    """Decorator for functions that require the cuEquivariance backend.

    Calling the wrapped function raises ImportError when cuEquivariance is
    not installed. ``functools.wraps`` preserves the wrapped function's
    name, docstring and metadata (the original wrapper dropped them, which
    hurt debugging and introspection).
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not is_cue_available():
            raise ImportError('cue is not available')
        return func(*args, **kwargs)

    return wrapper
def _check_may_not_compatible(orig_kwargs, defaults):
for k, v in defaults.items():
v_given = orig_kwargs.pop(k, v)
if v_given != v:
warnings.warn(f'{k}: {v} is ignored to use cuEquivariance')
def is_cue_cuda_available_model(config):
    """Return True when the model config is compatible with cuEquivariance;
    otherwise warn and return False (caller falls back to e3nn)."""
    if not config.get('use_bias_in_linear', False):
        return True
    warnings.warn('Bias in linear can not be used with cueq, fallback to e3nn')
    return False
@cue_needed
def as_cue_irreps(irreps: o3.Irreps, group: Literal['SO3', 'O3']):
    """Convert e3nn irreps to given group's cue irreps"""
    if group == 'SO3':
        # SO3 carries no parity: only even irreps may be converted, and the
        # 'e' suffix is stripped from the string form ('1x0e' -> '1x0')
        assert all(irrep.ir.p == 1 for irrep in irreps)
        return cue.Irreps('SO3', str(irreps).replace('e', ''))  # type: ignore
    elif group == 'O3':
        # use the e3nn-convention O3 subclass defined above
        return cue.Irreps(O3_e3nn, str(irreps))  # type: ignore
    else:
        raise ValueError(f'Unknown group: {group}')
@cue_needed
def patch_linear(
    module: Union[IrrepsLinear, SelfConnectionLinearIntro],
    group: Literal['SO3', 'O3'],
    **cue_kwargs,
):
    """Patch a not-yet-instantiated linear module to use ``cuet.Linear``.

    Converts the module's irreps to cue irreps, drops kwargs that cannot be
    forwarded (warning if the user changed them from their defaults), swaps
    the layer class, and merges ``cue_kwargs``. Mutates and returns the
    same module; must be called before ``module.instantiate()``.
    """
    assert not module.layer_instantiated
    module.irreps_in = as_cue_irreps(module.irreps_in, group)  # type: ignore
    module.irreps_out = as_cue_irreps(module.irreps_out, group)  # type: ignore

    orig_kwargs = module.linear_kwargs
    # e3nn Linear kwargs that may not carry over to cuet.Linear
    may_not_compatible_default = dict(
        f_in=None,
        f_out=None,
        instructions=None,
        biases=False,
        path_normalization='element',
        _optimize_einsums=None,
    )
    # pop may_not_compatible_defaults
    _check_may_not_compatible(orig_kwargs, may_not_compatible_default)
    module.linear_cls = cuet.Linear  # type: ignore
    orig_kwargs.update(**cue_kwargs)
    return module
@cue_needed
def patch_convolution(
    module: IrrepsConvolution,
    group: Literal['SO3', 'O3'],
    **cue_kwargs,
):
    """Patch a not-yet-instantiated IrrepsConvolution to use
    ``cuet.ChannelWiseTensorProduct``.

    Mutates and returns the same module; must be called before
    ``module.instantiate()``.
    """
    assert not module.layer_instantiated
    # conv_kwargs will be patched in place
    conv_kwargs = module.convolution_kwargs
    conv_kwargs.update(
        dict(
            irreps_in1=as_cue_irreps(conv_kwargs.get('irreps_in1'), group),
            irreps_in2=as_cue_irreps(conv_kwargs.get('irreps_in2'), group),
            # cuet names this 'filter_irreps_out' rather than 'irreps_out'
            filter_irreps_out=as_cue_irreps(conv_kwargs.pop('irreps_out'), group),
        )
    )
    # explicit instructions are dropped (presumably cuet derives its own
    # paths — verify); assert they were already sorted by output index so
    # nothing order-dependent is lost
    inst_orig = conv_kwargs.pop('instructions')
    inst_sorted = sorted(inst_orig, key=lambda x: x[2])
    assert all([a == b for a, b in zip(inst_orig, inst_sorted)])
    # e3nn TensorProduct kwargs that may not carry over to cuet
    may_not_compatible_default = dict(
        in1_var=None,
        in2_var=None,
        out_var=None,
        irrep_normalization=False,
        path_normalization='element',
        compile_left_right=True,
        compile_right=False,
        _specialized_code=None,
        _optimize_einsums=None,
    )
    # pop may_not_compatible_defaults
    _check_may_not_compatible(conv_kwargs, may_not_compatible_default)
    module.convolution_cls = cuet.ChannelWiseTensorProduct  # type: ignore
    conv_kwargs.update(**cue_kwargs)
    return module
@cue_needed
def patch_fully_connected(
    module: SelfConnectionIntro,
    group: Literal['SO3', 'O3'],
    **cue_kwargs,
):
    """Patch a not-yet-instantiated SelfConnectionIntro to use
    ``cuet.FullyConnectedTensorProduct``.

    Mutates and returns the same module; must be called before
    ``module.instantiate()``.
    """
    assert not module.layer_instantiated
    module.irreps_in1 = as_cue_irreps(module.irreps_in1, group)  # type: ignore
    module.irreps_in2 = as_cue_irreps(module.irreps_in2, group)  # type: ignore
    module.irreps_out = as_cue_irreps(module.irreps_out, group)  # type: ignore
    # normalization kwargs that may not carry over to cuet
    may_not_compatible_default = dict(
        irrep_normalization=None,
        path_normalization=None,
    )
    # pop may_not_compatible_defaults
    _check_may_not_compatible(
        module.fc_tensor_product_kwargs, may_not_compatible_default
    )
    module.fc_tensor_product_cls = cuet.FullyConnectedTensorProduct  # type: ignore
    module.fc_tensor_product_kwargs.update(**cue_kwargs)
    return module
import math
import torch
import torch.nn as nn
from e3nn.o3 import Irreps, SphericalHarmonics
from e3nn.util.jit import compile_mode
import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType
@compile_mode('script')
class EdgePreprocess(nn.Module):
    """
    preprocessing pos to edge vectors and edge lengths
    currently used in sevenn/scripts/deploy for lammps serial model
    """
    def __init__(self, is_stress: bool):
        super().__init__()
        # controlled by 'AtomGraphSequential'
        self.is_stress = is_stress
        # batch mode vs single-graph mode
        self._is_batch_data = True

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        """Build edge vectors/lengths; when stress is requested, also insert
        a zero '_strain' tensor whose gradient is consumed downstream
        (see ForceStressOutput) to obtain the stress."""
        if self._is_batch_data:
            cell = data[KEY.CELL].view(-1, 3, 3)
        else:
            cell = data[KEY.CELL].view(3, 3)
        cell_shift = data[KEY.CELL_SHIFT]
        pos = data[KEY.POS]

        batch = data[KEY.BATCH]  # for deploy, must be defined first
        if self.is_stress:
            if self._is_batch_data:
                num_batch = int(batch.max().cpu().item()) + 1
                # zero strain with grad enabled; symmetrized before use so
                # the resulting stress tensor is symmetric
                strain = torch.zeros(
                    (num_batch, 3, 3),
                    dtype=pos.dtype,
                    device=pos.device,
                )
                strain.requires_grad_(True)
                data['_strain'] = strain
                sym_strain = 0.5 * (strain + strain.transpose(-1, -2))
                # apply the (infinitesimal) deformation to positions and cells
                pos = pos + torch.bmm(
                    pos.unsqueeze(-2), sym_strain[batch]
                ).squeeze(-2)
                cell = cell + torch.bmm(cell, sym_strain)
            else:
                strain = torch.zeros(
                    (3, 3),
                    dtype=pos.dtype,
                    device=pos.device,
                )
                strain.requires_grad_(True)
                data['_strain'] = strain
                sym_strain = 0.5 * (strain + strain.transpose(-1, -2))
                pos = pos + torch.mm(pos, sym_strain)
                cell = cell + torch.mm(cell, sym_strain)

        idx_src = data[KEY.EDGE_IDX][0]
        idx_dst = data[KEY.EDGE_IDX][1]

        edge_vec = pos[idx_dst] - pos[idx_src]
        if self._is_batch_data:
            # add the periodic-image offset: cell_shift @ (per-edge cell)
            edge_vec = edge_vec + torch.einsum(
                'ni,nij->nj', cell_shift, cell[batch[idx_src]]
            )
        else:
            edge_vec = edge_vec + torch.einsum(
                'ni,ij->nj', cell_shift, cell.squeeze(0)
            )
        data[KEY.EDGE_VEC] = edge_vec
        data[KEY.EDGE_LENGTH] = torch.linalg.norm(edge_vec, dim=-1)
        return data
class BesselBasis(nn.Module):
    """
    f : (*, 1) -> (*, bessel_basis_num)

    Sine ("Bessel") radial basis: (2/r_c) * sin(n*pi*r/r_c) / r.

    Fix: when ``trainable_coeff`` is False, the coefficients used to be a
    plain tensor attribute, so ``module.to(device)`` / ``.cuda()`` did not
    move them. They are now registered as a non-persistent buffer (device
    moves work; the state_dict layout is unchanged).

    NOTE(review): forward divides by r, so r == 0 yields inf/nan —
    presumably edge lengths are always positive; verify against callers.
    """
    def __init__(
        self,
        cutoff_length: float,
        bessel_basis_num: int = 8,
        trainable_coeff: bool = True,
    ):
        super().__init__()
        self.num_basis = bessel_basis_num
        self.prefactor = 2.0 / cutoff_length
        coeffs = torch.FloatTensor([
            n * math.pi / cutoff_length for n in range(1, bessel_basis_num + 1)
        ])
        if trainable_coeff:
            self.coeffs = nn.Parameter(coeffs)
        else:
            # non-persistent: follows .to()/.cuda() but stays out of state_dict
            self.register_buffer('coeffs', coeffs, persistent=False)

    def forward(self, r: torch.Tensor) -> torch.Tensor:
        ur = r.unsqueeze(-1)  # to fit dimension
        return self.prefactor * torch.sin(self.coeffs * ur) / ur
class PolynomialCutoff(nn.Module):
    """
    f : (*, 1) -> (*, 1)

    Polynomial envelope that decays smoothly to 0 at the cutoff.
    https://arxiv.org/pdf/2003.03123.pdf
    """
    def __init__(
        self,
        cutoff_length: float,
        poly_cut_p_value: int = 6,
    ):
        super().__init__()
        p = poly_cut_p_value
        self.cutoff_length = cutoff_length
        self.p = p
        # envelope coefficients derived from the exponent p
        self.coeff_p0 = (p + 1.0) * (p + 2.0) / 2.0
        self.coeff_p1 = p * (p + 2.0)
        self.coeff_p2 = p * (p + 1.0) / 2.0

    def forward(self, r: torch.Tensor) -> torch.Tensor:
        scaled = r / self.cutoff_length
        term_p0 = torch.pow(scaled, self.p)
        term_p1 = torch.pow(scaled, self.p + 1.0)
        term_p2 = torch.pow(scaled, self.p + 2.0)
        return (
            1
            - self.coeff_p0 * term_p0
            + self.coeff_p1 * term_p1
            - self.coeff_p2 * term_p2
        )
class XPLORCutoff(nn.Module):
    """
    XPLOR smoothing: 1 below ``cutoff_on``, then a smooth switch that
    reaches 0 at ``cutoff_length``.
    https://hoomd-blue.readthedocs.io/en/latest/module-md-pair.html
    """
    def __init__(
        self,
        cutoff_length: float,
        cutoff_on: float,
    ):
        super().__init__()
        self.r_on = cutoff_on
        self.r_cut = cutoff_length
        assert self.r_on < self.r_cut

    def forward(self, r: torch.Tensor) -> torch.Tensor:
        r_sq = r * r
        on_sq = self.r_on * self.r_on
        cut_sq = self.r_cut * self.r_cut
        switch = (
            (cut_sq - r_sq) ** 2
            * (cut_sq + 2 * r_sq - 3 * on_sq)
            / (cut_sq - on_sq) ** 3
        )
        # inside r_on the interaction is untouched (factor 1)
        return torch.where(r < self.r_on, 1.0, switch)
@compile_mode('script')
class SphericalEncoding(nn.Module):
    """Spherical-harmonics encoding of 3-vectors up to ``lmax``.

    ``parity=-1`` treats the input as an odd vector (1o), ``parity=1`` as
    an even one (1e); output irreps follow e3nn's spherical harmonics.
    """
    def __init__(
        self,
        lmax: int,
        parity: int = -1,
        normalization: str = 'component',
        normalize: bool = True,
    ):
        super().__init__()
        self.lmax = lmax
        self.normalization = normalization
        if parity == -1:
            self.irreps_in = Irreps('1x1o')
        else:
            self.irreps_in = Irreps('1x1e')
        self.irreps_out = Irreps.spherical_harmonics(lmax, parity)
        self.sph = SphericalHarmonics(
            self.irreps_out,
            normalize=normalize,
            normalization=normalization,
            irreps_in=self.irreps_in,
        )

    def forward(self, r: torch.Tensor) -> torch.Tensor:
        return self.sph(r)
@compile_mode('script')
class EdgeEmbedding(nn.Module):
    """
    embedding layer of |r| by
    RadialBasis(|r|)*CutOff(|r|)
    f : (N_edge) -> (N_edge, basis_num)

    Also stores |r| and the spherical-harmonics edge attributes.
    """
    def __init__(
        self,
        basis_module: nn.Module,
        cutoff_module: nn.Module,
        spherical_module: nn.Module,
    ):
        super().__init__()
        self.basis_function = basis_module
        self.cutoff_function = cutoff_module
        self.spherical = spherical_module

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        rvec = data[KEY.EDGE_VEC]
        r = torch.linalg.norm(rvec, dim=-1)
        data[KEY.EDGE_LENGTH] = r
        radial = self.basis_function(r)
        envelope = self.cutoff_function(r).unsqueeze(-1)
        data[KEY.EDGE_EMBEDDING] = radial * envelope
        data[KEY.EDGE_ATTR] = self.spherical(rvec)
        return data
from typing import Callable, Dict
import torch.nn as nn
from e3nn.nn import Gate
from e3nn.o3 import Irreps
from e3nn.util.jit import compile_mode
import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType
@compile_mode('script')
class EquivariantGate(nn.Module):
    """e3nn Gate wrapper operating on AtomGraphData.

    Scalar irreps pass through scalar activations; non-scalar irreps are
    gated by extra scalar channels that callers must append to the input
    (use ``get_gate_irreps_in`` to size the preceding layer).
    """
    def __init__(
        self,
        irreps_x: Irreps,
        act_scalar_dict: Dict[str, Callable],
        act_gate_dict: Dict[str, Callable],
        data_key_x: str = KEY.NODE_FEATURE,
    ):
        super().__init__()
        self.key_x = data_key_x

        # activation dicts are keyed by parity letter; convert to +/-1
        parity_mapper = {'e': 1, 'o': -1}
        act_scalar_dict = {
            parity_mapper[k]: v for k, v in act_scalar_dict.items()
        }
        act_gate_dict = {parity_mapper[k]: v for k, v in act_gate_dict.items()}

        irreps_gated_elem = []
        irreps_scalars_elem = []
        # non scalar irreps > gated / scalar irreps > scalars
        for mul, irreps in irreps_x:
            if irreps.l > 0:
                irreps_gated_elem.append((mul, irreps))
            else:
                irreps_scalars_elem.append((mul, irreps))
        irreps_scalars = Irreps(irreps_scalars_elem)
        irreps_gated = Irreps(irreps_gated_elem)

        # gates are even (0e) when even scalars exist, otherwise odd (0o)
        irreps_gates_parity = 1 if '0e' in irreps_scalars else -1
        irreps_gates = Irreps(
            [(mul, (0, irreps_gates_parity)) for mul, _ in irreps_gated]
        )

        # pick per-irrep activations by parity
        act_scalars = [act_scalar_dict[p] for _, (_, p) in irreps_scalars]
        act_gates = [act_gate_dict[p] for _, (_, p) in irreps_gates]

        self.gate = Gate(
            irreps_scalars, act_scalars, irreps_gates, act_gates, irreps_gated
        )

    def get_gate_irreps_in(self):
        """
        user must call this function to get proper irreps in for forward
        """
        return self.gate.irreps_in

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        data[self.key_x] = self.gate(data[self.key_x])
        return data
import torch
import torch.nn as nn
from e3nn.util.jit import compile_mode
import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType
from .util import broadcast
@compile_mode('script')
class ForceOutput(nn.Module):
    """Predict forces as -dE/d(pos).

    works when pos.requires_grad_ is True
    """
    def __init__(
        self,
        data_key_pos: str = KEY.POS,
        data_key_energy: str = KEY.PRED_TOTAL_ENERGY,
        data_key_force: str = KEY.PRED_FORCE,
    ):
        super().__init__()
        self.key_pos = data_key_pos
        self.key_energy = data_key_energy
        self.key_force = data_key_force

    def get_grad_key(self):
        return self.key_pos

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        total_energy = [(data[self.key_energy]).sum()]
        wrt_pos = [data[self.key_pos]]
        # `materialize_grads` not supported in low version of pytorch
        # Also can not be deployed when using it.
        # But not using it makes problem in
        # force/stress inference in sparse systems
        # TODO: use it only in sevennet_calculator?
        grads = torch.autograd.grad(
            total_energy,
            wrt_pos,
            create_graph=self.training,
            allow_unused=True,
            # materialize_grads=True,
        )
        grad = grads[0]
        # For torchscript
        if grad is not None:
            data[self.key_force] = torch.neg(grad)
        return data
@compile_mode('script')
class ForceStressOutput(nn.Module):
    """
    Compute stress and force from positions.
    Used in serial torchscipt models

    Force = -dE/d(pos); stress = -dE/d(strain) / volume, where '_strain'
    is inserted upstream (see EdgePreprocess).
    """
    def __init__(
        self,
        data_key_pos: str = KEY.POS,
        data_key_energy: str = KEY.PRED_TOTAL_ENERGY,
        data_key_force: str = KEY.PRED_FORCE,
        data_key_stress: str = KEY.PRED_STRESS,
        data_key_cell_volume: str = KEY.CELL_VOLUME,
    ):
        super().__init__()
        self.key_pos = data_key_pos
        self.key_energy = data_key_energy
        self.key_force = data_key_force
        self.key_stress = data_key_stress
        self.key_cell_volume = data_key_cell_volume
        # batch mode vs single-graph mode (cf. same flag on EdgePreprocess)
        self._is_batch_data = True

    def get_grad_key(self):
        return self.key_pos

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        pos_tensor = data[self.key_pos]
        energy = [(data[self.key_energy]).sum()]
        # `materialize_grads` not supported in low version of pytorch
        # Also can not be deployed when using it.
        # But not using it makes problem in
        # force/stress inference in sparse systems
        # TODO: use it only in sevennet_calculator?
        grad = torch.autograd.grad(
            energy,
            [pos_tensor, data['_strain']],
            create_graph=self.training,
            allow_unused=True,
            # materialize_grads=True,
        )

        # make grad is not Optional[Tensor]
        fgrad = grad[0]
        if fgrad is not None:
            data[self.key_force] = torch.neg(fgrad)

        sgrad = grad[1]
        volume = data[self.key_cell_volume]
        vlim = 1e-3  # for cell volume = 0 for non PBC structures
        # clamp near-zero volumes to avoid division blow-up
        if self._is_batch_data:
            volume[volume < vlim] = vlim
        elif volume < vlim:
            volume = torch.tensor(vlim)

        if sgrad is not None:
            if self._is_batch_data:
                stress = sgrad / volume.view(-1, 1, 1)
                stress = torch.neg(stress)
                # 6-component virial order: xx, yy, zz, xy, yz, zx
                virial_stress = torch.vstack((
                    stress[:, 0, 0],
                    stress[:, 1, 1],
                    stress[:, 2, 2],
                    stress[:, 0, 1],
                    stress[:, 1, 2],
                    stress[:, 0, 2],
                ))
                data[self.key_stress] = virial_stress.transpose(0, 1)
            else:
                stress = sgrad / volume
                stress = torch.neg(stress)
                virial_stress = torch.stack((
                    stress[0, 0],
                    stress[1, 1],
                    stress[2, 2],
                    stress[0, 1],
                    stress[1, 2],
                    stress[0, 2],
                ))
                data[self.key_stress] = virial_stress
        return data
@compile_mode('script')
class ForceStressOutputFromEdge(nn.Module):
    """
    Compute stress and force from edge.
    Used in parallel torchscipt models, and training

    Differentiates the energy w.r.t. edge vectors r_ij and scatters the
    per-edge gradients back onto atoms (forces) and graphs (virial/stress).
    """
    def __init__(
        self,
        data_key_edge: str = KEY.EDGE_VEC,
        data_key_edge_idx: str = KEY.EDGE_IDX,
        data_key_energy: str = KEY.PRED_TOTAL_ENERGY,
        data_key_force: str = KEY.PRED_FORCE,
        data_key_stress: str = KEY.PRED_STRESS,
        data_key_cell_volume: str = KEY.CELL_VOLUME,
    ):
        super().__init__()
        self.key_edge = data_key_edge
        self.key_edge_idx = data_key_edge_idx
        self.key_energy = data_key_energy
        self.key_force = data_key_force
        self.key_stress = data_key_stress
        self.key_cell_volume = data_key_cell_volume
        # batch mode vs single-graph mode
        self._is_batch_data = True

    def get_grad_key(self):
        return self.key_edge

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        tot_num = torch.sum(data[KEY.NUM_ATOMS])  # ? item?
        rij = data[self.key_edge]
        energy = [(data[self.key_energy]).sum()]
        edge_idx = data[self.key_edge_idx]

        grad = torch.autograd.grad(
            energy,
            [rij],
            create_graph=self.training,
            allow_unused=True
        )

        # make grad is not Optional[Tensor]
        fij = grad[0]

        if fij is not None:
            # accumulate per-edge gradients onto src and dst atoms; the
            # force is their difference
            pf = torch.zeros(tot_num, 3, dtype=fij.dtype, device=fij.device)
            nf = torch.zeros(tot_num, 3, dtype=fij.dtype, device=fij.device)
            _edge_src = broadcast(edge_idx[0], fij, 0)
            _edge_dst = broadcast(edge_idx[1], fij, 0)
            pf.scatter_reduce_(0, _edge_src, fij, reduce='sum')
            nf.scatter_reduce_(0, _edge_dst, fij, reduce='sum')
            data[self.key_force] = pf - nf

            # compute virial: diagonal (xx, yy, zz) and off-diagonal terms
            diag = rij * fij
            s12 = rij[..., 0] * fij[..., 1]
            s23 = rij[..., 1] * fij[..., 2]
            s31 = rij[..., 2] * fij[..., 0]
            # cat last dimension
            _virial = torch.cat([
                diag,
                s12.unsqueeze(-1),
                s23.unsqueeze(-1),
                s31.unsqueeze(-1)
            ], dim=-1)

            # per-atom 6-component virial, scattered by destination atom
            _s = torch.zeros(tot_num, 6, dtype=fij.dtype, device=fij.device)
            _edge_dst6 = broadcast(edge_idx[1], _virial, 0)
            _s.scatter_reduce_(0, _edge_dst6, _virial, reduce='sum')

            if self._is_batch_data:
                batch = data[KEY.BATCH]  # for deploy, must be defined first
                nbatch = int(batch.max().cpu().item()) + 1
                # reduce per-atom virials to one 6-vector per graph
                sout = torch.zeros(
                    (nbatch, 6), dtype=_virial.dtype, device=_virial.device
                )
                _batch = broadcast(batch, _s, 0)
                sout.scatter_reduce_(0, _batch, _s, reduce='sum')
            else:
                sout = torch.sum(_s, dim=0)

            data[self.key_stress] =\
                torch.neg(sout) / data[self.key_cell_volume].unsqueeze(-1)

        return data
from typing import Callable, List, Tuple
from e3nn.o3 import Irreps
import sevenn._keys as KEY
from .convolution import IrrepsConvolution
from .equivariant_gate import EquivariantGate
from .linear import IrrepsLinear
def NequIP_interaction_block(
    irreps_x: Irreps,
    irreps_filter: Irreps,
    irreps_out_tp: Irreps,
    irreps_out: Irreps,
    weight_nn_layers: List[int],
    conv_denominator: float,
    train_conv_denominator: bool,
    self_connection_pair: Tuple[Callable, Callable],
    act_scalar: Callable,
    act_gate: Callable,
    act_radial: Callable,
    bias_in_linear: bool,
    num_species: int,
    t: int,  # interaction layer index
    data_key_x: str = KEY.NODE_FEATURE,
    data_key_weight_input: str = KEY.EDGE_EMBEDDING,
    parallel: bool = False,
    **conv_kwargs,
):
    """Assemble one NequIP interaction block as a dict of named layers.

    Insertion order defines execution order: self-connection intro ->
    linear -> convolution -> linear -> self-connection outro -> gate.
    Keys are prefixed with the layer index ``t``.
    """
    block = {}
    # one-hot species attribute: num_species even scalars
    irreps_node_attr = Irreps(f'{num_species}x0e')
    sc_intro, sc_outro = self_connection_pair

    # the gate is built first: its required input irreps determine the
    # output irreps of the layers that feed it
    gate_layer = EquivariantGate(irreps_out, act_scalar, act_gate)
    irreps_for_gate_in = gate_layer.get_gate_irreps_in()

    block[f'{t}_self_connection_intro'] = sc_intro(
        irreps_x,
        irreps_operand=irreps_node_attr,
        irreps_out=irreps_for_gate_in,
    )
    block[f'{t}_self_interaction_1'] = IrrepsLinear(
        irreps_x, irreps_x,
        data_key_in=data_key_x,
        biases=bias_in_linear,
    )
    # convolution part, l>lmax is dropped as defined in irreps_out
    block[f'{t}_convolution'] = IrrepsConvolution(
        irreps_x=irreps_x,
        irreps_filter=irreps_filter,
        irreps_out=irreps_out_tp,
        data_key_weight_input=data_key_weight_input,
        weight_layer_input_to_hidden=weight_nn_layers,
        weight_layer_act=act_radial,
        denominator=conv_denominator,
        train_denominator=train_conv_denominator,
        is_parallel=parallel,
        **conv_kwargs,
    )
    # irreps of x increase to gate_irreps_in
    block[f'{t}_self_interaction_2'] = IrrepsLinear(
        irreps_out_tp,
        irreps_for_gate_in,
        data_key_in=data_key_x,
        biases=bias_in_linear,
    )
    block[f'{t}_self_connection_outro'] = sc_outro()
    block[f'{t}_equivariant_gate'] = gate_layer
    return block
from typing import Callable, List, Optional
import torch
import torch.nn as nn
from e3nn.nn import FullyConnectedNet
from e3nn.o3 import Irreps, Linear
from e3nn.util.jit import compile_mode
import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType
@compile_mode('script')
class IrrepsLinear(nn.Module):
    """
    wrapper class of e3nn Linear to operate on AtomGraphData

    Supports lazy instantiation (so ``linear_cls``/``linear_kwargs`` may be
    patched, e.g. for cuEquivariance, before building) and optional
    multi-modal input: one extra 0e channel per modality is concatenated
    to the features before the linear layer.
    """
    def __init__(
        self,
        irreps_in: Irreps,
        irreps_out: Irreps,
        data_key_in: str,
        data_key_out: Optional[str] = None,
        data_key_modal_attr: str = KEY.MODAL_ATTR,
        num_modalities: int = 0,
        lazy_layer_instantiate: bool = True,
        **linear_kwargs,
    ):
        super().__init__()
        self.key_input = data_key_in
        if data_key_out is None:
            # in-place update of the input key by default
            self.key_output = data_key_in
        else:
            self.key_output = data_key_out
        self.key_modal_attr = data_key_modal_attr

        # input irreps before modal one-hot channels are appended
        self._irreps_in_wo_modal = irreps_in
        self.irreps_in = irreps_in
        self.irreps_out = irreps_out
        self.linear_kwargs = linear_kwargs
        self.linear = None
        self.layer_instantiated = False
        self.num_modalities = num_modalities
        self._is_batch_data = True

        # use getter setter
        self.linear_cls = Linear

        if num_modalities > 1:  # in case of multi-modal
            self.set_num_modalities(num_modalities)

        if not lazy_layer_instantiate:
            self.instantiate()

    def instantiate(self):
        """Build the underlying linear layer from the stored class/kwargs."""
        if self.linear is not None:
            raise ValueError('Linear layer already exists')
        self.linear = self.linear_cls(
            self.irreps_in, self.irreps_out, **self.linear_kwargs
        )
        self.layer_instantiated = True

    def set_num_modalities(self, num_modalities):
        """Extend input irreps with one 0e channel per modality.

        Must be called before the layer is instantiated.
        """
        if self.layer_instantiated:
            raise ValueError('Layer already instantiated, can not change modalities')
        irreps_in = self._irreps_in_wo_modal + Irreps(f'{num_modalities}x0e')
        self.num_modalities = num_modalities
        self.irreps_in = irreps_in

    def _patch_modal_to_data(self, data: AtomGraphDataType) -> AtomGraphDataType:
        """Concatenate the modality one-hot onto the input features."""
        if self._is_batch_data:
            batch = data[KEY.BATCH]
            # one modal one-hot row per graph in the batch
            batch_modality_onehot = data[self.key_modal_attr].reshape(
                -1, self.num_modalities
            )
            batch_modality_onehot = batch_modality_onehot.type(
                data[self.key_input].dtype
            )
            # broadcast graph-level one-hot to its atoms via the batch index
            data[self.key_input] = torch.cat(
                [data[self.key_input], batch_modality_onehot[batch]], dim=1
            )
        else:
            # single graph: expand the one-hot row to every atom
            modality_onehot = data[self.key_modal_attr].expand(
                len(data[self.key_input]), -1
            )
            modality_onehot = modality_onehot.type(data[self.key_input].dtype)
            data[self.key_input] = torch.cat(
                [data[self.key_input], modality_onehot], dim=1
            )
        return data

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        assert self.linear is not None, 'Layer is not instantiated'
        if self.num_modalities > 1:
            data = self._patch_modal_to_data(data)
        data[self.key_output] = self.linear(data[self.key_input])
        return data
@compile_mode('script')
class AtomReduce(nn.Module):
    """
    atomic energy -> total energy
    constant is multiplied to data

    Fix: the ``reduce`` argument used to be stored but ignored — the
    aggregation was hard-coded to 'sum' in both branches. It is now
    honored ('sum' behavior is unchanged).
    """
    def __init__(
        self,
        data_key_in: str,
        data_key_out: str,
        reduce: str = 'sum',
        constant: float = 1.0,
    ):
        super().__init__()
        self.key_input = data_key_in
        self.key_output = data_key_out
        self.constant = constant
        self.reduce = reduce

        # controlled by the upper most wrapper 'AtomGraphSequential'
        self._is_batch_data = True

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        if self._is_batch_data:
            src = data[self.key_input].squeeze(1)
            size = int(data[KEY.BATCH].max()) + 1
            output = torch.zeros(
                (size),
                dtype=src.dtype,
                device=src.device,
            )
            # include_self=False so the zero-initialized entries do not
            # take part in the reduction (identical result for 'sum',
            # correct counts for 'mean'); untouched entries stay zero
            output.scatter_reduce_(
                0, data[KEY.BATCH], src, reduce=self.reduce, include_self=False
            )
            data[self.key_output] = output * self.constant
        else:
            if self.reduce == 'mean':
                data[self.key_output] = (
                    torch.mean(data[self.key_input]) * self.constant
                )
            else:
                data[self.key_output] = (
                    torch.sum(data[self.key_input]) * self.constant
                )
        return data
@compile_mode('script')
class FCN_e3nn(nn.Module):
    """
    wrapper class of e3nn FullyConnectedNet

    Input irreps must be purely scalar; their total dimension fixes the
    input width of the MLP.
    """
    def __init__(
        self,
        irreps_in: Irreps,  # confirm it is scalar & input size
        dim_out: int,
        hidden_neurons: List[int],
        activation: Callable,
        data_key_in: str,
        data_key_out: Optional[str] = None,
        **e3nn_kwargs,
    ):
        super().__init__()
        self.key_input = data_key_in
        self.irreps_in = irreps_in
        self.key_output = data_key_in if data_key_out is None else data_key_out

        # only scalar irreps are allowed as MLP input
        for _, irrep in irreps_in:
            assert irrep.is_scalar()
        layer_sizes = [irreps_in.dim] + hidden_neurons + [dim_out]
        self.fcn = FullyConnectedNet(
            layer_sizes,
            activation,
            **e3nn_kwargs,
        )

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        data[self.key_output] = self.fcn(data[self.key_input])
        return data
from typing import Dict, List, Optional
import torch
import torch.nn as nn
import torch.nn.functional
from ase.symbols import symbols2numbers
from e3nn.util.jit import compile_mode
import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType
# TODO: put this to model_build and do not preprocess data by onehot
@compile_mode('script')
class OnehotEmbedding(nn.Module):
    """One-hot encode integer node labels.

    x : tensor of shape (N, 1) -> x_after : tensor of shape (N, num_classes)
    ex) [0 1 1 0] -> [[1, 0] [0, 1] [0, 1] [1, 0]] (num_classes = 2)

    Overwrites ``data_key_x`` by default; optionally stores the raw input
    under ``data_key_save`` (for specie-wise shift/scale) and a second copy
    of the one-hot under ``data_key_additional`` (for self-connection).
    Kept this way for compatibility with the previous version.
    """
    def __init__(
        self,
        num_classes: int,
        data_key_x: str = KEY.NODE_FEATURE,
        data_key_out: Optional[str] = None,
        data_key_save: Optional[str] = None,
        data_key_additional: Optional[str] = None,  # additional output
    ):
        super().__init__()
        self.num_classes = num_classes
        self.key_x = data_key_x
        self.key_output = data_key_x if data_key_out is None else data_key_out
        self.key_save = data_key_save
        self.key_additional_output = data_key_additional

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        raw = data[self.key_x]
        onehot = torch.nn.functional.one_hot(raw, self.num_classes).float()
        data[self.key_output] = onehot
        if self.key_additional_output is not None:
            data[self.key_additional_output] = onehot  # for self-connection
        if self.key_save is not None:
            data[self.key_save] = raw  # for elemwise shift scale
        return data
def get_type_mapper_from_specie(specie_list: List[str]):
    """
    from ['Hf', 'O']
    return {72: 0, 8: 1}

    Species are sorted by symbol before assigning type indices;
    duplicates are ignored.
    """
    type_map = {}
    for specie in sorted(specie_list):
        atomic_num = symbols2numbers(specie)[0]
        if atomic_num not in type_map:
            # next unused type index == current map size
            type_map[atomic_num] = len(type_map)
    return type_map
# deprecated
def one_hot_atom_embedding(
    atomic_numbers: List[int], type_map: Dict[int, int]
):
    """One-hot encode atomic numbers using a (atomic number -> type) map.

    atomic numbers from ase.get_atomic_numbers
    type_map from get_type_mapper_from_specie()

    Returns a (len(atomic_numbers), len(type_map)) tensor in the default
    dtype. Raises ValueError (chained from the original KeyError, for
    easier debugging) when an atomic number is not in ``type_map``.
    """
    num_classes = len(type_map)
    try:
        type_numbers = torch.LongTensor(
            [type_map[num] for num in atomic_numbers]
        )
    except KeyError as e:
        raise ValueError(f'Atomic number {e.args[0]} is not expected') from e
    embd = torch.nn.functional.one_hot(type_numbers, num_classes)
    embd = embd.to(torch.get_default_dtype())
    return embd
from typing import Any, Dict, List, Optional, Union
import torch
import torch.nn as nn
from e3nn.util.jit import compile_mode
import sevenn._keys as KEY
from sevenn._const import NUM_UNIV_ELEMENT, AtomGraphDataType
def _as_univ(
    ss: List[float], type_map: Dict[int, int], default: float
) -> List[float]:
    """Expand per-type values to a list indexed by atomic number.

    Entry z is ``ss[type_map[z]]`` for known species and ``default``
    otherwise; the result has NUM_UNIV_ELEMENT entries.
    """
    assert len(ss) <= NUM_UNIV_ELEMENT, 'shift scale is too long'
    univ = []
    for z in range(NUM_UNIV_ELEMENT):
        if z in type_map:
            univ.append(ss[type_map[z]])
        else:
            univ.append(default)
    return univ
@compile_mode('script')
class Rescale(nn.Module):
    """
    Scaling and shifting energy (and automatically force and stress)

    out = in * scale + shift, with scalar shift/scale stored as
    (optionally trainable) one-element parameters.
    """
    def __init__(
        self,
        shift: float,
        scale: float,
        data_key_in: str = KEY.SCALED_ATOMIC_ENERGY,
        data_key_out: str = KEY.ATOMIC_ENERGY,
        train_shift_scale: bool = False,
        **kwargs,  # accepted and ignored
    ):
        assert isinstance(shift, float) and isinstance(scale, float)
        super().__init__()
        self.shift = nn.Parameter(
            torch.FloatTensor([shift]), requires_grad=train_shift_scale
        )
        self.scale = nn.Parameter(
            torch.FloatTensor([scale]), requires_grad=train_shift_scale
        )
        self.key_input = data_key_in
        self.key_output = data_key_out

    def get_shift(self) -> float:
        """Current shift as a plain float."""
        return self.shift.detach().cpu().tolist()[0]

    def get_scale(self) -> float:
        """Current scale as a plain float."""
        return self.scale.detach().cpu().tolist()[0]

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        rescaled = data[self.key_input] * self.scale + self.shift
        data[self.key_output] = rescaled
        return data
@compile_mode('script')
class SpeciesWiseRescale(nn.Module):
    """
    Scaling and shifting energy (and automatically force and stress)
    Use as it is if given list, expand to list if one of them is float
    If two lists are given and length is not the same, raise error
    """
    def __init__(
        self,
        shift: Union[List[float], float],
        scale: Union[List[float], float],
        data_key_in: str = KEY.SCALED_ATOMIC_ENERGY,
        data_key_out: str = KEY.ATOMIC_ENERGY,
        data_key_indices: str = KEY.ATOM_TYPE,
        train_shift_scale: bool = False,
    ):
        super().__init__()
        assert isinstance(shift, float) or isinstance(shift, list)
        assert isinstance(scale, float) or isinstance(scale, list)
        if (
            isinstance(shift, list)
            and isinstance(scale, list)
            and len(shift) != len(scale)
        ):
            raise ValueError('List length should be same')

        # at least one of shift/scale must be a list to fix num_species
        if isinstance(shift, list):
            num_species = len(shift)
        elif isinstance(scale, list):
            num_species = len(scale)
        else:
            raise ValueError('Both shift and scale is not a list')

        # broadcast a scalar to a per-species list
        shift = [shift] * num_species if isinstance(shift, float) else shift
        scale = [scale] * num_species if isinstance(scale, float) else scale

        self.shift = nn.Parameter(
            torch.FloatTensor(shift), requires_grad=train_shift_scale
        )
        self.scale = nn.Parameter(
            torch.FloatTensor(scale), requires_grad=train_shift_scale
        )
        self.key_input = data_key_in
        self.key_output = data_key_out
        self.key_indices = data_key_indices

    def get_shift(self, type_map: Optional[Dict[int, int]] = None) -> List[float]:
        """
        Return shift in list of float. If type_map is given, return type_map reversed
        shift, which index equals atomic_number. 0.0 is assigned for atoms not found
        """
        shift = self.shift.detach().cpu().tolist()
        if type_map:
            shift = _as_univ(shift, type_map, 0.0)
        return shift

    def get_scale(self, type_map: Optional[Dict[int, int]] = None) -> List[float]:
        """
        Return scale in list of float. If type_map is given, return type_map reversed
        scale, which index equals atomic_number. 1.0 is assigned for atoms not found
        """
        scale = self.scale.detach().cpu().tolist()
        if type_map:
            scale = _as_univ(scale, type_map, 1.0)
        return scale

    @staticmethod
    def from_mappers(
        shift: Union[float, List[float]],
        scale: Union[float, List[float]],
        type_map: Dict[int, int],
        **kwargs,
    ):
        """
        Fit dimensions or mapping raw shift scale values to that is valid under
        the given type_map: (atomic_numbers -> type_indices)
        """
        shift_scale = []
        n_atom_types = len(type_map)
        for s in (shift, scale):
            if isinstance(s, list) and len(s) > n_atom_types:
                # universal-length list indexed by atomic number: pick the
                # entries for the mapped species, in type-index order
                if len(s) != NUM_UNIV_ELEMENT:
                    raise ValueError('given shift or scale is strange')
                s = [s[z] for z in sorted(type_map, key=lambda x: type_map[x])]
                # s = [s[z] for z in sorted(type_map, key=type_map.get)]
            elif isinstance(s, float):
                # scalar: broadcast to every atom type
                s = [s] * n_atom_types
            elif isinstance(s, list) and len(s) == 1:
                # single-element list: broadcast likewise
                s = s * n_atom_types
            shift_scale.append(s)
        assert all([len(s) == n_atom_types for s in shift_scale])
        shift, scale = shift_scale
        return SpeciesWiseRescale(shift, scale, **kwargs)

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        # per-atom scale/shift selected by the atom-type index
        indices = data[self.key_indices]
        data[self.key_output] = data[self.key_input] * self.scale[indices].view(
            -1, 1
        ) + self.shift[indices].view(-1, 1)
        return data
@compile_mode('script')
class ModalWiseRescale(nn.Module):
    """
    Scaling and shifting energy (and automatically force and stress).

    Given shift or scale is either modal-wise and atom-wise (2D) or not
    modal-wise but atom-wise (1D). It is always interpreted as atom-wise.
    """

    def __init__(
        self,
        shift: List[List[float]],
        scale: List[List[float]],
        data_key_in: str = KEY.SCALED_ATOMIC_ENERGY,
        data_key_out: str = KEY.ATOMIC_ENERGY,
        data_key_modal_indices: str = KEY.MODAL_TYPE,
        data_key_atom_indices: str = KEY.ATOM_TYPE,
        use_modal_wise_shift: bool = False,
        use_modal_wise_scale: bool = False,
        train_shift_scale: bool = False,
    ):
        """
        Args:
            shift, scale: atom-wise (1D) or modal-and-atom-wise (2D) values
            use_modal_wise_shift/scale: whether the respective tensor is 2D
            train_shift_scale: make shift/scale trainable parameters
        """
        super().__init__()
        self.shift = nn.Parameter(
            torch.FloatTensor(shift), requires_grad=train_shift_scale
        )
        self.scale = nn.Parameter(
            torch.FloatTensor(scale), requires_grad=train_shift_scale
        )
        self.key_input = data_key_in
        self.key_output = data_key_out
        self.key_atom_indices = data_key_atom_indices
        self.key_modal_indices = data_key_modal_indices
        self.use_modal_wise_shift = use_modal_wise_shift
        self.use_modal_wise_scale = use_modal_wise_scale
        # toggled externally (see AtomGraphSequential.set_is_batch_data)
        self._is_batch_data = True

    def get_shift(
        self,
        type_map: Optional[Dict[int, int]] = None,
        modal_map: Optional[Dict[str, int]] = None,
    ) -> Union[List[float], Dict[str, List[float]]]:
        """
        Nothing is given: return as it is.
        type_map is given but not modal wise shift: return univ shift.
        Both type_map and modal_map are given and modal wise shift: return
        fully resolved modal-wise univ shift (dict keyed by modal name).
        0.0 is assigned for atoms not found in type_map.
        """
        shift = self.shift.detach().cpu().tolist()
        if type_map and not self.use_modal_wise_shift:
            shift = _as_univ(shift, type_map, 0.0)
        elif self.use_modal_wise_shift and modal_map and type_map:
            shift = [_as_univ(s, type_map, 0.0) for s in shift]
            shift = {modal: shift[idx] for modal, idx in modal_map.items()}
        return shift

    def get_scale(
        self,
        type_map: Optional[Dict[int, int]] = None,
        modal_map: Optional[Dict[str, int]] = None,
    ) -> Union[List[float], Dict[str, List[float]]]:
        """
        Nothing is given: return as it is.
        type_map is given but not modal wise scale: return univ scale.
        Both type_map and modal_map are given and modal wise scale: return
        fully resolved modal-wise univ scale (dict keyed by modal name).
        1.0 is assigned for atoms not found in type_map.
        """
        scale = self.scale.detach().cpu().tolist()
        # BUGFIX: fill value for absent elements is 1.0 (multiplicative
        # identity), consistent with SpeciesWiseRescale.get_scale; the
        # previous 0.0 would zero out energies of unmapped species.
        if type_map and not self.use_modal_wise_scale:
            scale = _as_univ(scale, type_map, 1.0)
        elif self.use_modal_wise_scale and modal_map and type_map:
            scale = [_as_univ(s, type_map, 1.0) for s in scale]
            scale = {modal: scale[idx] for modal, idx in modal_map.items()}
        return scale

    @staticmethod
    def from_mappers(
        shift: Union[float, List[float], Dict[str, Any]],
        scale: Union[float, List[float], Dict[str, Any]],
        use_modal_wise_shift: bool,
        use_modal_wise_scale: bool,
        type_map: Dict[int, int],
        modal_map: Dict[str, int],
        **kwargs,
    ):
        """
        Fit dimensions or map raw shift/scale values so they are valid under
        the given type_map (atomic_numbers -> type_indices) and modal_map.

        If given List[float] whose length matches _const.NUM_UNIV_ELEMENT,
        assume it is an element-wise (per-atomic-number) list; otherwise it
        is a modal-wise list. A dict is keyed by modal name.
        """

        def solve_mapper(arr, mapping):
            # mapping values are attr indices and never overlap;
            # keys are either atomic number 'z' or a modal name string
            return [arr[z] for z in sorted(mapping, key=lambda x: mapping[x])]

        shift_scale = []
        n_atom_types = len(type_map)
        n_modals = len(modal_map)
        for s, use_mw in (
            (shift, use_modal_wise_shift),
            (scale, use_modal_wise_scale),
        ):
            # solve element-wise mapping, or broadcast to the required shape
            if isinstance(s, float):
                # modal-wise: no, elem-wise: no => broadcast
                shape = (n_modals, n_atom_types) if use_mw else (n_atom_types,)
                res = torch.full(shape, s).tolist()  # TODO: w/o torch
            elif isinstance(s, list) and len(s) == NUM_UNIV_ELEMENT:
                # modal-wise: no, elem-wise: yes (univ) => solve elem map
                s = solve_mapper(s, type_map)
                res = [s] * n_modals if use_mw else s
            elif (  # modal-wise: yes, elem-wise: no => broadcast to elemwise
                isinstance(s, list)
                and isinstance(s[0], float)
                and len(s) == n_modals
                and use_mw
            ):
                res = [[v] * n_atom_types for v in s]
            elif (  # modal-wise: no, elem-wise: yes => as it is
                isinstance(s, list)
                and isinstance(s[0], float)
                and len(s) == n_atom_types
                and not use_mw
            ):
                res = s
            elif (  # modal-wise: yes, elem-wise: yes => as it is
                isinstance(s, list)
                and isinstance(s[0], list)
                and len(s) == n_modals
                and len(s[0]) == n_atom_types
                and use_mw
            ):
                res = s
            elif isinstance(s, dict) and use_mw:
                # dict keyed by modal name; each value is univ list or float
                s = solve_mapper(s, modal_map)
                res = []
                for v in s:
                    if isinstance(v, list) and len(v) == NUM_UNIV_ELEMENT:
                        # elem-wise: yes (univ) => solve elem map
                        v = solve_mapper(v, type_map)
                    elif isinstance(v, float):
                        # elem-wise: no => broadcast to elemwise
                        v = [v] * n_atom_types
                    else:
                        raise ValueError(f'Invalid shift or scale {s}')
                    res.append(v)
            else:
                raise ValueError(f'Invalid shift or scale {s}')

            # shape sanity checks before handing off to the constructor
            if use_mw:
                assert (
                    isinstance(res, list)
                    and isinstance(res[0], list)
                    and len(res) == n_modals
                )
                assert all([len(r) == n_atom_types for r in res])  # type: ignore
            else:
                assert (
                    isinstance(res, list)
                    and isinstance(res[0], float)
                    and len(res) == n_atom_types
                )
            shift_scale.append(res)

        shift, scale = shift_scale
        return ModalWiseRescale(
            shift,
            scale,
            use_modal_wise_shift=use_modal_wise_shift,
            use_modal_wise_scale=use_modal_wise_scale,
            **kwargs,
        )

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        """Apply (modal- and/or atom-wise) scale and shift to the input."""
        if self._is_batch_data:
            # per-graph modal index, expanded to per-atom via the batch vector
            batch = data[KEY.BATCH]
            modal_indices = data[self.key_modal_indices][batch]
        else:
            modal_indices = data[self.key_modal_indices]
        atom_indices = data[self.key_atom_indices]
        shift = (
            self.shift[modal_indices, atom_indices]
            if self.use_modal_wise_shift
            else self.shift[atom_indices]
        )
        scale = (
            self.scale[modal_indices, atom_indices]
            if self.use_modal_wise_scale
            else self.scale[atom_indices]
        )
        data[self.key_output] = data[self.key_input] * scale.view(
            -1, 1
        ) + shift.view(-1, 1)
        return data
def get_resolved_shift_scale(
    module: Union[Rescale, SpeciesWiseRescale, ModalWiseRescale],
    type_map: Optional[Dict[int, int]] = None,
    modal_map: Optional[Dict[str, int]] = None,
):
    """
    Return resolved shift and scale from scale modules. For the element-wise
    case, convert to a list of floats indexed by atomic number. For the
    modal-wise case, return a dict of shift/scale keyed by modal name given
    in modal_map.

    Return:
        Tuple of solved (shift, scale)

    Raises:
        ValueError: if module is not one of the scale module types.
    """
    # dispatch on concrete module type; the three classes are unrelated,
    # so the check order does not matter
    if isinstance(module, ModalWiseRescale):
        return (
            module.get_shift(type_map, modal_map),
            module.get_scale(type_map, modal_map),
        )
    if isinstance(module, SpeciesWiseRescale):
        return (module.get_shift(type_map), module.get_scale(type_map))
    if isinstance(module, Rescale):
        return (module.get_shift(), module.get_scale())
    raise ValueError('Not scale module')
import torch.nn as nn
from e3nn.o3 import FullyConnectedTensorProduct, Irreps, Linear
from e3nn.util.jit import compile_mode
import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType
@compile_mode('script')
class SelfConnectionIntro(nn.Module):
    """
    Do TensorProduct of x and some data (here, attribute of x)
    and save it (to concatenate with updated x at SelfConnectionOutro).

    The tensor-product layer may be instantiated lazily via instantiate().
    """

    def __init__(
        self,
        irreps_in: Irreps,
        irreps_operand: Irreps,
        irreps_out: Irreps,
        data_key_x: str = KEY.NODE_FEATURE,
        data_key_operand: str = KEY.NODE_ATTR,
        lazy_layer_instantiate: bool = True,
        **kwargs,  # for compatibility
    ):
        """
        Args:
            irreps_in: irreps of node features (first TP input)
            irreps_operand: irreps of the operand, e.g. node attrs (second TP input)
            irreps_out: irreps of the TP output
            data_key_x / data_key_operand: data dict keys of the two inputs
            lazy_layer_instantiate: postpone layer creation to instantiate()
        """
        super().__init__()
        self.irreps_in1 = irreps_in
        self.irreps_in2 = irreps_operand
        self.irreps_out = irreps_out
        self.key_x = data_key_x
        self.key_operand = data_key_operand
        # BUGFIX: the layer is only created in instantiate(); the previous
        # code eagerly built a FullyConnectedTensorProduct here and then
        # immediately overwrote it with None, wasting the construction.
        self.fc_tensor_product = None
        self.layer_instantiated = False
        self.fc_tensor_product_cls = FullyConnectedTensorProduct
        self.fc_tensor_product_kwargs = kwargs
        if not lazy_layer_instantiate:
            self.instantiate()

    def instantiate(self):
        """Create the tensor-product layer; raises if already instantiated."""
        if self.fc_tensor_product is not None:
            raise ValueError('fc_tensor_product layer already exists')
        self.fc_tensor_product = self.fc_tensor_product_cls(
            self.irreps_in1,
            self.irreps_in2,
            self.irreps_out,
            shared_weights=True,
            internal_weights=None,  # same as True
            **self.fc_tensor_product_kwargs,
        )
        self.layer_instantiated = True

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        """Compute TP(x, operand) and stash it under SELF_CONNECTION_TEMP."""
        assert self.fc_tensor_product is not None, 'Layer is not instantiated'
        data[KEY.SELF_CONNECTION_TEMP] = self.fc_tensor_product(
            data[self.key_x], data[self.key_operand]
        )
        return data
@compile_mode('script')
class SelfConnectionLinearIntro(nn.Module):
    """
    Linear style self connection update.

    Computes a Linear map of node features and stashes it under
    SELF_CONNECTION_TEMP for SelfConnectionOutro to add back.
    """

    def __init__(
        self,
        irreps_in: Irreps,
        irreps_out: Irreps,
        data_key_x: str = KEY.NODE_FEATURE,
        lazy_layer_instantiate: bool = True,
        **kwargs,
    ):
        """
        Args:
            irreps_in / irreps_out: irreps of the Linear layer
            data_key_x: data dict key of node features
            lazy_layer_instantiate: postpone layer creation to instantiate()
        """
        super().__init__()
        self.irreps_in = irreps_in
        self.irreps_out = irreps_out
        self.key_x = data_key_x
        self.linear = None
        self.layer_instantiated = False
        self.linear_cls = Linear
        # TODO: better to have SelfConnectionIntro super class
        # BUGFIX: 'irreps_operand' is accepted only for interface
        # compatibility with SelfConnectionIntro and is unused here; pop with
        # a default so its absence no longer raises KeyError.
        kwargs.pop('irreps_operand', None)
        self.linear_kwargs = kwargs
        if not lazy_layer_instantiate:
            self.instantiate()

    def instantiate(self):
        """Create the Linear layer; raises if already instantiated."""
        if self.linear is not None:
            raise ValueError('Linear layer already exists')
        self.linear = self.linear_cls(
            self.irreps_in, self.irreps_out, **self.linear_kwargs
        )
        self.layer_instantiated = True

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        """Compute Linear(x) and stash it under SELF_CONNECTION_TEMP."""
        assert self.linear is not None, 'Layer is not instantiated'
        data[KEY.SELF_CONNECTION_TEMP] = self.linear(data[self.key_x])
        return data
@compile_mode('script')
class SelfConnectionOutro(nn.Module):
    """
    Counterpart of SelfConnectionIntro / SelfConnectionLinearIntro:
    adds the stored self-connection term back onto x and removes the
    temporary entry from the data dict.
    """

    def __init__(self, data_key_x: str = KEY.NODE_FEATURE):
        super().__init__()
        self.key_x = data_key_x

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        """x += SELF_CONNECTION_TEMP, then drop the temporary."""
        updated = data[self.key_x] + data[KEY.SELF_CONNECTION_TEMP]
        data[self.key_x] = updated
        del data[KEY.SELF_CONNECTION_TEMP]
        return data
import warnings
from collections import OrderedDict
from typing import Dict, Optional
import torch
import torch.nn as nn
from e3nn.util.jit import compile_mode
import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType
def _instantiate_modules(modules):
# see IrrepsLinear of linear.py
for module in modules.values():
if not getattr(module, 'layer_instantiated', True):
module.instantiate()
@compile_mode('script')
class _ModalInputPrepare(nn.Module):
    """Stamps a fixed modal index onto the data dict.

    Prepended to a deployed model (see AtomGraphSequential.prepare_modal_deploy)
    so downstream modal-aware modules see a constant MODAL_TYPE.
    """

    def __init__(self, modal_idx: int):
        super().__init__()
        self.modal_idx = modal_idx

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        """Write the fixed modal index as an int64 scalar tensor."""
        device = data['x'].device
        data[KEY.MODAL_TYPE] = torch.tensor(
            self.modal_idx, dtype=torch.int64, device=device
        )
        return data
@compile_mode('script')
class AtomGraphSequential(nn.Sequential):
    """
    Wrapper of SevenNet model

    Args:
        modules: OrderedDict of nn.Modules
        cutoff: not used internally, but makes sense to have
        type_map: atomic_numbers => onehot index (see nn/node_embedding.py)
        eval_type_map: perform index mapping using type_map defaults to True
        data_key_atomic_numbers: used when eval_type_map is True
        data_key_node_feature: used when eval_type_map is True
        data_key_grad: if given, sets its requires grad True before pred
    """

    def __init__(
        self,
        modules: Dict[str, nn.Module],
        cutoff: float = 0.0,
        type_map: Optional[Dict[int, int]] = None,
        modal_map: Optional[Dict[str, int]] = None,
        eval_type_map: bool = True,
        eval_modal_map: bool = False,
        data_key_atomic_numbers: str = KEY.ATOMIC_NUMBERS,
        data_key_node_feature: str = KEY.NODE_FEATURE,
        data_key_grad: Optional[str] = None,
    ):
        if not isinstance(modules, OrderedDict):  # backward compat
            modules = OrderedDict(modules)
        self.cutoff = cutoff
        self.type_map = type_map
        self.eval_type_map = eval_type_map
        self.is_batch_data = True
        if cutoff == 0.0:
            warnings.warn('cutoff is 0.0 or not given', UserWarning)
        if self.type_map is None:
            # without a type_map the onehot lookup cannot be built,
            # so index mapping is disabled
            warnings.warn('type_map is not given', UserWarning)
            self.eval_type_map = False
        else:
            # dense lookup table: atomic number (0..119) -> onehot index;
            # species absent from type_map map to -1
            z_to_onehot_tensor = torch.neg(torch.ones(120, dtype=torch.long))
            for z, onehot in self.type_map.items():
                z_to_onehot_tensor[z] = onehot
            self.z_to_onehot_tensor = z_to_onehot_tensor
        if eval_modal_map and modal_map is None:
            raise ValueError('eval_modal_map is True but modal_map is None')
        self.eval_modal_map = eval_modal_map
        self.modal_map = modal_map
        self.key_atomic_numbers = data_key_atomic_numbers
        self.key_node_feature = data_key_node_feature
        self.key_grad = data_key_grad
        # build any lazily-declared layers before registering the modules
        _instantiate_modules(modules)
        super().__init__(modules)
        if not isinstance(self._modules, OrderedDict):  # backward compat
            self._modules = OrderedDict(self._modules)

    def set_is_batch_data(self, flag: bool):
        """Propagate batched/unbatched mode to submodules via `_is_batch_data`."""
        # whether given data is batched or not some module have to change
        # its behavior. checking whether data is batched or not inside
        # forward function make problem harder when make it into torchscript
        for module in self:
            try:  # Easier to ask for forgiveness than permission.
                module._is_batch_data = flag  # type: ignore
            except AttributeError:
                pass
        self.is_batch_data = flag

    def get_irreps_in(self, modlue_name: str, attr_key: str = 'irreps_in'):
        """Return repr of `attr_key` from the first submodule of the named
        module that defines it, or None if no submodule has it.

        NOTE(review): parameter name 'modlue_name' is a typo of
        'module_name'; kept as-is for keyword-caller compatibility.
        """
        tg_module = self._modules[modlue_name]
        for m in tg_module.modules():
            try:
                return repr(m.__getattribute__(attr_key))
            except AttributeError:
                pass
        return None

    def prepand_module(self, key: str, module: nn.Module):
        """Insert `module` at the front of the sequence.

        NOTE(review): 'prepand' is a typo of 'prepend'; kept for compatibility.
        """
        self._modules.update({key: module})
        self._modules.move_to_end(key, last=False)  # type: ignore

    def replace_module(self, key: str, module: nn.Module):
        """Replace (or add) the module registered under `key`."""
        self._modules.update({key: module})

    def delete_module_by_key(self, key: str):
        """Remove the module registered under `key`, if present."""
        if key in self._modules.keys():
            del self._modules[key]

    @torch.jit.unused
    def _atomic_numbers_to_onehot(self, atomic_numbers: torch.Tensor):
        """Map atomic numbers to onehot indices via the lookup table
        (unknown species become -1)."""
        assert atomic_numbers.dtype == torch.int64
        device = atomic_numbers.device
        z_to_onehot_tensor = self.z_to_onehot_tensor.to(device)
        return torch.index_select(
            input=z_to_onehot_tensor, dim=0, index=atomic_numbers
        )

    @torch.jit.unused
    def _eval_modal_map(self, data: AtomGraphDataType):
        """Resolve DATA_MODALITY string(s) to modal indices and store them."""
        assert self.modal_map is not None
        # modal_map: dict[str, int]
        if not self.is_batch_data:
            modal_idx = self.modal_map[data[KEY.DATA_MODALITY]]  # type: ignore
        else:
            modal_idx = [
                self.modal_map[ii]  # type: ignore
                for ii in data[KEY.DATA_MODALITY]
            ]
        # NOTE(review): attribute access (data.x) assumes a torch_geometric
        # style Data object; other paths use data['x'] — confirm
        modal_idx = torch.tensor(
            modal_idx,
            dtype=torch.int64,
            device=data.x.device,  # type: ignore
        )
        data[KEY.MODAL_TYPE] = modal_idx

    def _preprocess(self, data: AtomGraphDataType) -> AtomGraphDataType:
        """Apply type/modal index mapping and grad flag before the forward pass."""
        if self.eval_type_map:
            atomic_numbers = data[self.key_atomic_numbers]
            onehot = self._atomic_numbers_to_onehot(atomic_numbers)
            data[self.key_node_feature] = onehot
        if self.eval_modal_map:
            self._eval_modal_map(data)
        if self.key_grad is not None:
            data[self.key_grad].requires_grad_(True)
        return data

    def prepare_modal_deploy(self, modal: str):
        """Freeze the given modality for deployment: disable runtime modal
        mapping, switch to unbatched mode, and prepend a module that stamps
        the fixed modal index. No-op if the model has no modal_map."""
        if self.modal_map is None:
            return
        self.eval_modal_map = False
        self.set_is_batch_data(False)
        modal_idx = self.modal_map[modal]  # type: ignore
        self.prepand_module('modal_input_prepare', _ModalInputPrepare(modal_idx))

    def forward(self, input: AtomGraphDataType) -> AtomGraphDataType:
        """Preprocess the data, then run it through all modules in order."""
        data = self._preprocess(input)
        for module in self:
            data = module(data)
        return data
import torch
def broadcast(
    src: torch.Tensor,
    other: torch.Tensor,
    dim: int
):
    """Expand 1D `src` so it is broadcastable against `other` along `dim`.

    torch_scatter-style index broadcasting: a 1D `src` gets leading singleton
    axes until it sits at `dim`, then trailing singleton axes, and is finally
    expanded to `other`'s shape. Tensors with dim != 1 only get trailing axes
    and the expand.
    """
    if dim < 0:
        dim = other.dim() + dim
    if src.dim() == 1:
        # insert leading singleton axes so src lands on position `dim`
        for _ in range(dim):
            src = src.unsqueeze(0)
    # append trailing singleton axes up to other's rank, then expand (no copy)
    while src.dim() < other.dim():
        src = src.unsqueeze(-1)
    return src.expand_as(other)
We support the LAMMPS pair style `d3` of the Grimme's D3 dispersion (van der Waals) correction scheme accelerated with CUDA, which can be used within LAMMPS in conjunction with SevenNet.
**PLEASE NOTE:** Currently, this D3 code does not yet support multi-GPU parallelism, so it can only be run on a single GPU.
# About Grimme's D3 code accelerated with CUDA
This is LAMMPS implementation of [Grimme's D3 method](https://doi.org/10.1063/1.3382344). We have ported the code from the [original fortran code](https://www.chemie.uni-bonn.de/grimme/de/software/dft-d3) to a LAMMPS pair style written in CUDA/C++.
While D3 method is significantly faster than DFT, existing CPU implementations were slower than SevenNet. To address this, we have adopted CUDA and single precision (FP32) operations to accelerate the code.
## Installation for LAMMPS
Simply run,
```bash
sevenn_patch_lammps ./lammps_sevenn --d3
```
You can follow the remaining installation steps in the [SevenNet documentation](../../README.md#installation-for-lammps).
Also, this code requires a GPU with a compute capability of **at least 6.0**. If you try to compile it with version 5.0, you may encounter an `atomicAdd` error.
The target compute capability of this code follows the setting of LibTorch in SevenNet, except for version 5.0.
You can manually select the target capability using the `TORCH_CUDA_ARCH_LIST` environment variable. For example, you can use: `export TORCH_CUDA_ARCH_LIST="6.1;7.0;8.0;8.6;8.9;9.0"`.
## Usage for LAMMPS
You can use the D3 dispersion correction in LAMMPS with SevenNet through the `pair/hybrid` command:
```txt
pair_style hybrid/overlay e3gnn d3 {cutoff_d3_r} {cutoff_d3_cn} {type_of_damping} {name_of_functional}
pair_coeff * * e3gnn {path_to_serial_model} {space_separated_chemical_species}
pair_coeff * * d3 {space_separated_chemical_species}
```
for example,
```txt
pair_style hybrid/overlay e3gnn d3 9000 1600 damp_bj pbe
pair_coeff * * e3gnn ./deployed_serial.pt C H O
pair_coeff * * d3 C H O
```
`cutoff_d3_r` and `cutoff_d3_cn` are the squares of the cutoff radii for energy/force and coordination number, respectively. Units are Bohr radius: 1 (Bohr radius) = 0.52917721 (Å). Default values are `9000` and `1600`, respectively; these are also the default values used in VASP.[^1]
Available `type_of_damping` are as follows:
- `damp_zero`: Zero damping
- `damp_bj`: Becke-Johnson damping
Available `name_of_functional` options are the same as in the original Fortran code. SevenNet-0 is trained on the 'PBE' functional, so you should specify 'pbe' in the script when using it. For other supporting functionals, check 'List of parametrized functionals' in [here](https://www.chemie.uni-bonn.de/grimme/de/software/dft-d3).
## Features
- Selective (or no) periodic boundary condition: implemented, but only PBC/noPBC can be verified against the original FORTRAN code; selective PBC cannot
- 3-body term, n > 8 term: not implemented (as in VASP)
- Modified versions of zero and bj damping
## Cautions
- It can be slower than the CPU with a small number of atoms.
- The maximum number of atoms that can be calculated is 46,340 (overflow issue).
- Small amounts of numerical error can occur
- The introduction of some FP32 operations can lead to minor numerical errors, particularly in pressure calculations, but these are generally smaller than those seen with SevenNet.
- If the error is too large, ensure that the `fmad=false` option in `patch_lammps.sh` is correctly applied during build.
## To do
- Remove atom_modify / compute virial dependency.
- Add support for ASE as calculator interface.
- Add support for multi GPUs (with `e3gnn/parallel`).
- Implement without Unified Memory.
- Unfix the `threadsPerBlock=128`.
- Unroll the repetition loop `k` (for small number of atoms).
## Contributors
- Hyungmin An: Ported the original Fortran D3 code to C++ with OpenMP and MPI.
- Gijin Kim: Accelerated the C++ D3 code with OpenACC[^2] and CUDA, and currently maintains it.
[^1]: On the [VASP DFT-D3](https://www.vasp.at/wiki/index.php/DFT-D3) page, the `VDW_RADIUS` and `VDW_CNRADIUS` are `50.2` and `20.0`, respectively (units are Å). However, when running VASP 6.3.2 with D3 using zero damping (BJ does not provide such a log), the default values in the OUTCAR file are `50.2022` and `21.1671`. These values are the same as our defaults.
[^2]: Since OpenACC is not compatible with libtorch, we chose to use the CUDA.
This diff is collapsed.
/* -*- c++ -*- ----------------------------------------------------------
LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
https://www.lammps.org/, Sandia National Laboratories
LAMMPS development team: developers@lammps.org
Copyright (2003) Sandia Corporation. Under the terms of Contract
DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains
certain rights in this software. This software is distributed under
the GNU General Public License.
See the README file in the top-level LAMMPS directory.
------------------------------------------------------------------------- */
#ifndef LMP_COMM_BRICK_H
#define LMP_COMM_BRICK_H
#include "comm.h"
namespace LAMMPS_NS {
// CommBrick: MPI communication of atom data between neighboring processors
// using a 3d "brick" (slab-based swap) pattern. SevenNet adds two extra
// forward/reverse comm overloads for PairE3GNNParallel (marked below).
class CommBrick : public Comm {
 public:
  CommBrick(class LAMMPS *);
  CommBrick(class LAMMPS *, class Comm *);
  ~CommBrick() override;

  void init() override;
  void setup() override;                       // setup 3d comm pattern
  void forward_comm(int dummy = 0) override;   // forward comm of atom coords
  void reverse_comm() override;                // reverse comm of forces
  void exchange() override;                    // move atoms to new procs
  void borders() override;                     // setup list of atoms to comm

  void forward_comm(class Pair *) override;    // forward comm from a Pair
  void reverse_comm(class Pair *) override;    // reverse comm from a Pair
  void forward_comm(class Bond *) override;    // forward comm from a Bond
  void reverse_comm(class Bond *) override;    // reverse comm from a Bond
  void forward_comm(class Fix *, int size = 0) override;    // forward comm from a Fix
  void reverse_comm(class Fix *, int size = 0) override;    // reverse comm from a Fix
  void reverse_comm_variable(class Fix *) override;  // variable size reverse comm from a Fix
  void forward_comm(class Compute *) override;       // forward from a Compute
  void reverse_comm(class Compute *) override;       // reverse from a Compute
  void forward_comm(class Dump *) override;          // forward comm from a Dump
  void reverse_comm(class Dump *) override;          // reverse comm from a Dump
  void forward_comm_array(int, double **) override;  // forward comm of array
  void *extract(const char *, int &) override;
  double memory_usage() override;

  // patched from SevenNet //
  // comm overloads used by the parallel e3gnn pair style
  void forward_comm(class PairE3GNNParallel *);
  void reverse_comm(class PairE3GNNParallel *);
  // patched from SevenNet //

 protected:
  int nswap;                        // # of swaps to perform = sum of maxneed
  int recvneed[3][2];               // # of procs away I recv atoms from
  int sendneed[3][2];               // # of procs away I send atoms to
  int maxneed[3];                   // max procs away any proc needs, per dim
  int maxswap;                      // max # of swaps memory is allocated for
  int *sendnum, *recvnum;           // # of atoms to send/recv in each swap
  int *sendproc, *recvproc;         // proc to send/recv to/from at each swap
  int *size_forward_recv;           // # of values to recv in each forward comm
  int *size_reverse_send;           // # to send in each reverse comm
  int *size_reverse_recv;           // # to recv in each reverse comm
  double *slablo, *slabhi;          // bounds of slab to send at each swap
  double **multilo, **multihi;      // bounds of slabs for multi-collection swap
  double **multioldlo, **multioldhi;  // bounds of slabs for multi-type swap
  double **cutghostmulti;           // cutghost on a per-collection basis
  double **cutghostmultiold;        // cutghost on a per-type basis
  int *pbc_flag;                    // general flag for sending atoms thru PBC
  int **pbc;                        // dimension flags for PBC adjustments
  int *firstrecv;                   // where to put 1st recv atom in each swap
  int **sendlist;                   // list of atoms to send in each swap
  int *localsendlist;               // indexed list of local sendlist atoms
  int *maxsendlist;                 // max size of send list for each swap
  double *buf_send;                 // send buffer for all comm
  double *buf_recv;                 // recv buffer for all comm
  int maxsend, maxrecv;             // current size of send/recv buffer
  int smax, rmax;                   // max size in atoms of single borders send/recv

  // NOTE: init_buffers is called from a constructor and must not be made virtual
  void init_buffers();
  int updown(int, int, int, double, int, double *);
  // compare cutoff to procs
  virtual void grow_send(int, int);       // reallocate send buffer
  virtual void grow_recv(int);            // free/allocate recv buffer
  virtual void grow_list(int, int);       // reallocate one sendlist
  virtual void grow_swap(int);            // grow swap, multi, and multi/old arrays
  virtual void allocate_swap(int);        // allocate swap arrays
  virtual void allocate_multi(int);       // allocate multi arrays
  virtual void allocate_multiold(int);    // allocate multi/old arrays
  virtual void free_swap();               // free swap arrays
  virtual void free_multi();              // free multi arrays
  virtual void free_multiold();           // free multi/old arrays
};
} // namespace LAMMPS_NS
#endif
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment