Unverified Commit ca86f720 authored by zcxzcx1's avatar zcxzcx1 Committed by GitHub
Browse files

Add files via upload

parent b75ed73c
import math
import torch
@torch.jit.script
def ShiftedSoftPlus(x: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.softplus(x) - math.log(2.0)
from typing import List
import torch
import torch.nn as nn
from e3nn.nn import FullyConnectedNet
from e3nn.o3 import Irreps, TensorProduct
from e3nn.util.jit import compile_mode
import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType
from .activation import ShiftedSoftPlus
from .util import broadcast
def message_gather(
node_features: torch.Tensor,
edge_dst: torch.Tensor,
message: torch.Tensor
):
index = broadcast(edge_dst, message, 0)
out_shape = [len(node_features)] + list(message.shape[1:])
out = torch.zeros(
out_shape,
dtype=node_features.dtype,
device=node_features.device
)
out.scatter_reduce_(0, index, message, reduce='sum')
return out
@compile_mode('script')
class IrrepsConvolution(nn.Module):
"""
convolution of (fig 2.b), comm. in LAMMPS
"""
def __init__(
self,
irreps_x: Irreps,
irreps_filter: Irreps,
irreps_out: Irreps,
weight_layer_input_to_hidden: List[int],
weight_layer_act=ShiftedSoftPlus,
denominator: float = 1.0,
train_denominator: bool = False,
data_key_x: str = KEY.NODE_FEATURE,
data_key_filter: str = KEY.EDGE_ATTR,
data_key_weight_input: str = KEY.EDGE_EMBEDDING,
data_key_edge_idx: str = KEY.EDGE_IDX,
lazy_layer_instantiate: bool = True,
is_parallel: bool = False,
):
super().__init__()
self.denominator = nn.Parameter(
torch.FloatTensor([denominator]), requires_grad=train_denominator
)
self.key_x = data_key_x
self.key_filter = data_key_filter
self.key_weight_input = data_key_weight_input
self.key_edge_idx = data_key_edge_idx
self.is_parallel = is_parallel
instructions = []
irreps_mid = []
weight_numel = 0
for i, (mul_x, ir_x) in enumerate(irreps_x):
for j, (_, ir_filter) in enumerate(irreps_filter):
for ir_out in ir_x * ir_filter:
if ir_out in irreps_out: # here we drop l > lmax
k = len(irreps_mid)
weight_numel += mul_x * 1 # path shape
irreps_mid.append((mul_x, ir_out))
instructions.append((i, j, k, 'uvu', True))
irreps_mid = Irreps(irreps_mid)
irreps_mid, p, _ = irreps_mid.sort() # type: ignore
instructions = [
(i_in1, i_in2, p[i_out], mode, train)
for i_in1, i_in2, i_out, mode, train in instructions
]
# From v0.11.x, to compatible with cuEquivariance
self._instructions_before_sort = instructions
instructions = sorted(instructions, key=lambda x: x[2])
self.convolution_kwargs = dict(
irreps_in1=irreps_x,
irreps_in2=irreps_filter,
irreps_out=irreps_mid,
instructions=instructions,
shared_weights=False,
internal_weights=False,
)
self.weight_nn_kwargs = dict(
hs=weight_layer_input_to_hidden + [weight_numel],
act=weight_layer_act
)
self.convolution = None
self.weight_nn = None
self.layer_instantiated = False
self.convolution_cls = TensorProduct
self.weight_nn_cls = FullyConnectedNet
if not lazy_layer_instantiate:
self.instantiate()
self._comm_size = irreps_x.dim # used in parallel
def instantiate(self):
if self.convolution is not None:
raise ValueError('Convolution layer already exists')
if self.weight_nn is not None:
raise ValueError('Weight_nn layer already exists')
self.convolution = self.convolution_cls(**self.convolution_kwargs)
self.weight_nn = self.weight_nn_cls(**self.weight_nn_kwargs)
self.layer_instantiated = True
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
assert self.convolution is not None, 'Convolution is not instantiated'
assert self.weight_nn is not None, 'Weight_nn is not instantiated'
weight = self.weight_nn(data[self.key_weight_input])
x = data[self.key_x]
if self.is_parallel:
x = torch.cat([x, data[KEY.NODE_FEATURE_GHOST]])
# note that 1 -> src 0 -> dst
edge_src = data[self.key_edge_idx][1]
edge_dst = data[self.key_edge_idx][0]
message = self.convolution(x[edge_src], data[self.key_filter], weight)
x = message_gather(x, edge_dst, message)
x = x.div(self.denominator)
if self.is_parallel:
x = torch.tensor_split(x, data[KEY.NLOCAL])[0]
data[self.key_x] = x
return data
import itertools
import warnings
from typing import Iterator, Literal, Union
import e3nn.o3 as o3
import numpy as np
from .convolution import IrrepsConvolution
from .linear import IrrepsLinear
from .self_connection import SelfConnectionIntro, SelfConnectionLinearIntro
try:
import cuequivariance as cue
import cuequivariance_torch as cuet
_CUE_AVAILABLE = True
# Obatained from MACE
class O3_e3nn(cue.O3):
def __mul__( # type: ignore
rep1: 'O3_e3nn', rep2: 'O3_e3nn'
) -> Iterator['O3_e3nn']:
return [ # type: ignore
O3_e3nn(l=ir.l, p=ir.p) for ir in cue.O3.__mul__(rep1, rep2)
]
@classmethod
def clebsch_gordan( # type: ignore
cls, rep1: 'O3_e3nn', rep2: 'O3_e3nn', rep3: 'O3_e3nn'
) -> np.ndarray:
rep1, rep2, rep3 = cls._from(rep1), cls._from(rep2), cls._from(rep3)
if rep1.p * rep2.p == rep3.p:
return o3.wigner_3j(rep1.l, rep2.l, rep3.l).numpy()[None] * np.sqrt(
rep3.dim
)
return np.zeros((0, rep1.dim, rep2.dim, rep3.dim))
def __lt__( # type: ignore
rep1: 'O3_e3nn', rep2: 'O3_e3nn'
) -> bool:
rep2 = rep1._from(rep2) # type: ignore
return (rep1.l, rep1.p) < (rep2.l, rep2.p)
@classmethod
def iterator(cls) -> Iterator['O3_e3nn']:
for l in itertools.count(0):
yield O3_e3nn(l=l, p=1 * (-1) ** l)
yield O3_e3nn(l=l, p=-1 * (-1) ** l)
except ImportError:
_CUE_AVAILABLE = False
def is_cue_available():
return _CUE_AVAILABLE
def cue_needed(func):
def wrapper(*args, **kwargs):
if is_cue_available():
return func(*args, **kwargs)
else:
raise ImportError('cue is not available')
return wrapper
def _check_may_not_compatible(orig_kwargs, defaults):
for k, v in defaults.items():
v_given = orig_kwargs.pop(k, v)
if v_given != v:
warnings.warn(f'{k}: {v} is ignored to use cuEquivariance')
def is_cue_cuda_available_model(config):
if config.get('use_bias_in_linear', False):
warnings.warn('Bias in linear can not be used with cueq, fallback to e3nn')
return False
else:
return True
@cue_needed
def as_cue_irreps(irreps: o3.Irreps, group: Literal['SO3', 'O3']):
"""Convert e3nn irreps to given group's cue irreps"""
if group == 'SO3':
assert all(irrep.ir.p == 1 for irrep in irreps)
return cue.Irreps('SO3', str(irreps).replace('e', '')) # type: ignore
elif group == 'O3':
return cue.Irreps(O3_e3nn, str(irreps)) # type: ignore
else:
raise ValueError(f'Unknown group: {group}')
@cue_needed
def patch_linear(
module: Union[IrrepsLinear, SelfConnectionLinearIntro],
group: Literal['SO3', 'O3'],
**cue_kwargs,
):
assert not module.layer_instantiated
module.irreps_in = as_cue_irreps(module.irreps_in, group) # type: ignore
module.irreps_out = as_cue_irreps(module.irreps_out, group) # type: ignore
orig_kwargs = module.linear_kwargs
may_not_compatible_default = dict(
f_in=None,
f_out=None,
instructions=None,
biases=False,
path_normalization='element',
_optimize_einsums=None,
)
# pop may_not_compatible_defaults
_check_may_not_compatible(orig_kwargs, may_not_compatible_default)
module.linear_cls = cuet.Linear # type: ignore
orig_kwargs.update(**cue_kwargs)
return module
@cue_needed
def patch_convolution(
module: IrrepsConvolution,
group: Literal['SO3', 'O3'],
**cue_kwargs,
):
assert not module.layer_instantiated
# conv_kwargs will be patched in place
conv_kwargs = module.convolution_kwargs
conv_kwargs.update(
dict(
irreps_in1=as_cue_irreps(conv_kwargs.get('irreps_in1'), group),
irreps_in2=as_cue_irreps(conv_kwargs.get('irreps_in2'), group),
filter_irreps_out=as_cue_irreps(conv_kwargs.pop('irreps_out'), group),
)
)
inst_orig = conv_kwargs.pop('instructions')
inst_sorted = sorted(inst_orig, key=lambda x: x[2])
assert all([a == b for a, b in zip(inst_orig, inst_sorted)])
may_not_compatible_default = dict(
in1_var=None,
in2_var=None,
out_var=None,
irrep_normalization=False,
path_normalization='element',
compile_left_right=True,
compile_right=False,
_specialized_code=None,
_optimize_einsums=None,
)
# pop may_not_compatible_defaults
_check_may_not_compatible(conv_kwargs, may_not_compatible_default)
module.convolution_cls = cuet.ChannelWiseTensorProduct # type: ignore
conv_kwargs.update(**cue_kwargs)
return module
@cue_needed
def patch_fully_connected(
module: SelfConnectionIntro,
group: Literal['SO3', 'O3'],
**cue_kwargs,
):
assert not module.layer_instantiated
module.irreps_in1 = as_cue_irreps(module.irreps_in1, group) # type: ignore
module.irreps_in2 = as_cue_irreps(module.irreps_in2, group) # type: ignore
module.irreps_out = as_cue_irreps(module.irreps_out, group) # type: ignore
may_not_compatible_default = dict(
irrep_normalization=None,
path_normalization=None,
)
# pop may_not_compatible_defaults
_check_may_not_compatible(
module.fc_tensor_product_kwargs, may_not_compatible_default
)
module.fc_tensor_product_cls = cuet.FullyConnectedTensorProduct # type: ignore
module.fc_tensor_product_kwargs.update(**cue_kwargs)
return module
import math
import torch
import torch.nn as nn
from e3nn.o3 import Irreps, SphericalHarmonics
from e3nn.util.jit import compile_mode
import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType
@compile_mode('script')
class EdgePreprocess(nn.Module):
"""
preprocessing pos to edge vectors and edge lengths
currently used in sevenn/scripts/deploy for lammps serial model
"""
def __init__(self, is_stress: bool):
super().__init__()
# controlled by 'AtomGraphSequential'
self.is_stress = is_stress
self._is_batch_data = True
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
if self._is_batch_data:
cell = data[KEY.CELL].view(-1, 3, 3)
else:
cell = data[KEY.CELL].view(3, 3)
cell_shift = data[KEY.CELL_SHIFT]
pos = data[KEY.POS]
batch = data[KEY.BATCH] # for deploy, must be defined first
if self.is_stress:
if self._is_batch_data:
num_batch = int(batch.max().cpu().item()) + 1
strain = torch.zeros(
(num_batch, 3, 3),
dtype=pos.dtype,
device=pos.device,
)
strain.requires_grad_(True)
data['_strain'] = strain
sym_strain = 0.5 * (strain + strain.transpose(-1, -2))
pos = pos + torch.bmm(
pos.unsqueeze(-2), sym_strain[batch]
).squeeze(-2)
cell = cell + torch.bmm(cell, sym_strain)
else:
strain = torch.zeros(
(3, 3),
dtype=pos.dtype,
device=pos.device,
)
strain.requires_grad_(True)
data['_strain'] = strain
sym_strain = 0.5 * (strain + strain.transpose(-1, -2))
pos = pos + torch.mm(pos, sym_strain)
cell = cell + torch.mm(cell, sym_strain)
idx_src = data[KEY.EDGE_IDX][0]
idx_dst = data[KEY.EDGE_IDX][1]
edge_vec = pos[idx_dst] - pos[idx_src]
if self._is_batch_data:
edge_vec = edge_vec + torch.einsum(
'ni,nij->nj', cell_shift, cell[batch[idx_src]]
)
else:
edge_vec = edge_vec + torch.einsum(
'ni,ij->nj', cell_shift, cell.squeeze(0)
)
data[KEY.EDGE_VEC] = edge_vec
data[KEY.EDGE_LENGTH] = torch.linalg.norm(edge_vec, dim=-1)
return data
class BesselBasis(nn.Module):
"""
f : (*, 1) -> (*, bessel_basis_num)
"""
def __init__(
self,
cutoff_length: float,
bessel_basis_num: int = 8,
trainable_coeff: bool = True,
):
super().__init__()
self.num_basis = bessel_basis_num
self.prefactor = 2.0 / cutoff_length
self.coeffs = torch.FloatTensor([
n * math.pi / cutoff_length for n in range(1, bessel_basis_num + 1)
])
if trainable_coeff:
self.coeffs = nn.Parameter(self.coeffs)
def forward(self, r: torch.Tensor) -> torch.Tensor:
ur = r.unsqueeze(-1) # to fit dimension
return self.prefactor * torch.sin(self.coeffs * ur) / ur
class PolynomialCutoff(nn.Module):
"""
f : (*, 1) -> (*, 1)
https://arxiv.org/pdf/2003.03123.pdf
"""
def __init__(
self,
cutoff_length: float,
poly_cut_p_value: int = 6,
):
super().__init__()
p = poly_cut_p_value
self.cutoff_length = cutoff_length
self.p = p
self.coeff_p0 = (p + 1.0) * (p + 2.0) / 2.0
self.coeff_p1 = p * (p + 2.0)
self.coeff_p2 = p * (p + 1.0) / 2.0
def forward(self, r: torch.Tensor) -> torch.Tensor:
r = r / self.cutoff_length
return (
1
- self.coeff_p0 * torch.pow(r, self.p)
+ self.coeff_p1 * torch.pow(r, self.p + 1.0)
- self.coeff_p2 * torch.pow(r, self.p + 2.0)
)
class XPLORCutoff(nn.Module):
"""
https://hoomd-blue.readthedocs.io/en/latest/module-md-pair.html
"""
def __init__(
self,
cutoff_length: float,
cutoff_on: float,
):
super().__init__()
self.r_on = cutoff_on
self.r_cut = cutoff_length
assert self.r_on < self.r_cut
def forward(self, r: torch.Tensor) -> torch.Tensor:
r_sq = r * r
r_on_sq = self.r_on * self.r_on
r_cut_sq = self.r_cut * self.r_cut
return torch.where(
r < self.r_on,
1.0,
(r_cut_sq - r_sq) ** 2
* (r_cut_sq + 2 * r_sq - 3 * r_on_sq)
/ (r_cut_sq - r_on_sq) ** 3,
)
@compile_mode('script')
class SphericalEncoding(nn.Module):
def __init__(
self,
lmax: int,
parity: int = -1,
normalization: str = 'component',
normalize: bool = True,
):
super().__init__()
self.lmax = lmax
self.normalization = normalization
self.irreps_in = Irreps('1x1o') if parity == -1 else Irreps('1x1e')
self.irreps_out = Irreps.spherical_harmonics(lmax, parity)
self.sph = SphericalHarmonics(
self.irreps_out,
normalize=normalize,
normalization=normalization,
irreps_in=self.irreps_in,
)
def forward(self, r: torch.Tensor) -> torch.Tensor:
return self.sph(r)
@compile_mode('script')
class EdgeEmbedding(nn.Module):
"""
embedding layer of |r| by
RadialBasis(|r|)*CutOff(|r|)
f : (N_edge) -> (N_edge, basis_num)
"""
def __init__(
self,
basis_module: nn.Module,
cutoff_module: nn.Module,
spherical_module: nn.Module,
):
super().__init__()
self.basis_function = basis_module
self.cutoff_function = cutoff_module
self.spherical = spherical_module
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
rvec = data[KEY.EDGE_VEC]
r = torch.linalg.norm(data[KEY.EDGE_VEC], dim=-1)
data[KEY.EDGE_LENGTH] = r
data[KEY.EDGE_EMBEDDING] = self.basis_function(
r
) * self.cutoff_function(r).unsqueeze(-1)
data[KEY.EDGE_ATTR] = self.spherical(rvec)
return data
from typing import Callable, Dict
import torch.nn as nn
from e3nn.nn import Gate
from e3nn.o3 import Irreps
from e3nn.util.jit import compile_mode
import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType
@compile_mode('script')
class EquivariantGate(nn.Module):
def __init__(
self,
irreps_x: Irreps,
act_scalar_dict: Dict[int, Callable],
act_gate_dict: Dict[int, Callable],
data_key_x: str = KEY.NODE_FEATURE,
):
super().__init__()
self.key_x = data_key_x
parity_mapper = {'e': 1, 'o': -1}
act_scalar_dict = {
parity_mapper[k]: v for k, v in act_scalar_dict.items()
}
act_gate_dict = {parity_mapper[k]: v for k, v in act_gate_dict.items()}
irreps_gated_elem = []
irreps_scalars_elem = []
# non scalar irreps > gated / scalar irreps > scalars
for mul, irreps in irreps_x:
if irreps.l > 0:
irreps_gated_elem.append((mul, irreps))
else:
irreps_scalars_elem.append((mul, irreps))
irreps_scalars = Irreps(irreps_scalars_elem)
irreps_gated = Irreps(irreps_gated_elem)
irreps_gates_parity = 1 if '0e' in irreps_scalars else -1
irreps_gates = Irreps(
[(mul, (0, irreps_gates_parity)) for mul, _ in irreps_gated]
)
act_scalars = [act_scalar_dict[p] for _, (_, p) in irreps_scalars]
act_gates = [act_gate_dict[p] for _, (_, p) in irreps_gates]
self.gate = Gate(
irreps_scalars, act_scalars, irreps_gates, act_gates, irreps_gated
)
def get_gate_irreps_in(self):
"""
user must call this function to get proper irreps in for forward
"""
return self.gate.irreps_in
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
data[self.key_x] = self.gate(data[self.key_x])
return data
import torch
import torch.nn as nn
from e3nn.util.jit import compile_mode
import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType
from .util import broadcast
@compile_mode('script')
class ForceOutput(nn.Module):
"""
works when pos.requires_grad_ is True
"""
def __init__(
self,
data_key_pos: str = KEY.POS,
data_key_energy: str = KEY.PRED_TOTAL_ENERGY,
data_key_force: str = KEY.PRED_FORCE,
):
super().__init__()
self.key_pos = data_key_pos
self.key_energy = data_key_energy
self.key_force = data_key_force
def get_grad_key(self):
return self.key_pos
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
pos_tensor = [data[self.key_pos]]
energy = [(data[self.key_energy]).sum()]
# `materialize_grads` not supported in low version of pytorch
# Also can not be deployed when using it.
# But not using it makes problem in
# force/stress inference in sparse systems
# TODO: use it only in sevennet_calculator?
grad = torch.autograd.grad(
energy,
pos_tensor,
create_graph=self.training,
allow_unused=True,
# materialize_grads=True,
)[0]
# For torchscript
if grad is not None:
data[self.key_force] = torch.neg(grad)
return data
@compile_mode('script')
class ForceStressOutput(nn.Module):
"""
Compute stress and force from positions.
Used in serial torchscipt models
"""
def __init__(
self,
data_key_pos: str = KEY.POS,
data_key_energy: str = KEY.PRED_TOTAL_ENERGY,
data_key_force: str = KEY.PRED_FORCE,
data_key_stress: str = KEY.PRED_STRESS,
data_key_cell_volume: str = KEY.CELL_VOLUME,
):
super().__init__()
self.key_pos = data_key_pos
self.key_energy = data_key_energy
self.key_force = data_key_force
self.key_stress = data_key_stress
self.key_cell_volume = data_key_cell_volume
self._is_batch_data = True
def get_grad_key(self):
return self.key_pos
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
pos_tensor = data[self.key_pos]
energy = [(data[self.key_energy]).sum()]
# `materialize_grads` not supported in low version of pytorch
# Also can not be deployed when using it.
# But not using it makes problem in
# force/stress inference in sparse systems
# TODO: use it only in sevennet_calculator?
grad = torch.autograd.grad(
energy,
[pos_tensor, data['_strain']],
create_graph=self.training,
allow_unused=True,
# materialize_grads=True,
)
# make grad is not Optional[Tensor]
fgrad = grad[0]
if fgrad is not None:
data[self.key_force] = torch.neg(fgrad)
sgrad = grad[1]
volume = data[self.key_cell_volume]
vlim = 1e-3 # for cell volume = 0 for non PBC structures
if self._is_batch_data:
volume[volume < vlim] = vlim
elif volume < vlim:
volume = torch.tensor(vlim)
if sgrad is not None:
if self._is_batch_data:
stress = sgrad / volume.view(-1, 1, 1)
stress = torch.neg(stress)
virial_stress = torch.vstack((
stress[:, 0, 0],
stress[:, 1, 1],
stress[:, 2, 2],
stress[:, 0, 1],
stress[:, 1, 2],
stress[:, 0, 2],
))
data[self.key_stress] = virial_stress.transpose(0, 1)
else:
stress = sgrad / volume
stress = torch.neg(stress)
virial_stress = torch.stack((
stress[0, 0],
stress[1, 1],
stress[2, 2],
stress[0, 1],
stress[1, 2],
stress[0, 2],
))
data[self.key_stress] = virial_stress
return data
@compile_mode('script')
class ForceStressOutputFromEdge(nn.Module):
"""
Compute stress and force from edge.
Used in parallel torchscipt models, and training
"""
def __init__(
self,
data_key_edge: str = KEY.EDGE_VEC,
data_key_edge_idx: str = KEY.EDGE_IDX,
data_key_energy: str = KEY.PRED_TOTAL_ENERGY,
data_key_force: str = KEY.PRED_FORCE,
data_key_stress: str = KEY.PRED_STRESS,
data_key_cell_volume: str = KEY.CELL_VOLUME,
):
super().__init__()
self.key_edge = data_key_edge
self.key_edge_idx = data_key_edge_idx
self.key_energy = data_key_energy
self.key_force = data_key_force
self.key_stress = data_key_stress
self.key_cell_volume = data_key_cell_volume
self._is_batch_data = True
def get_grad_key(self):
return self.key_edge
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
tot_num = torch.sum(data[KEY.NUM_ATOMS]) # ? item?
rij = data[self.key_edge]
energy = [(data[self.key_energy]).sum()]
edge_idx = data[self.key_edge_idx]
grad = torch.autograd.grad(
energy,
[rij],
create_graph=self.training,
allow_unused=True
)
# make grad is not Optional[Tensor]
fij = grad[0]
if fij is not None:
# compute force
pf = torch.zeros(tot_num, 3, dtype=fij.dtype, device=fij.device)
nf = torch.zeros(tot_num, 3, dtype=fij.dtype, device=fij.device)
_edge_src = broadcast(edge_idx[0], fij, 0)
_edge_dst = broadcast(edge_idx[1], fij, 0)
pf.scatter_reduce_(0, _edge_src, fij, reduce='sum')
nf.scatter_reduce_(0, _edge_dst, fij, reduce='sum')
data[self.key_force] = pf - nf
# compute virial
diag = rij * fij
s12 = rij[..., 0] * fij[..., 1]
s23 = rij[..., 1] * fij[..., 2]
s31 = rij[..., 2] * fij[..., 0]
# cat last dimension
_virial = torch.cat([
diag,
s12.unsqueeze(-1),
s23.unsqueeze(-1),
s31.unsqueeze(-1)
], dim=-1)
_s = torch.zeros(tot_num, 6, dtype=fij.dtype, device=fij.device)
_edge_dst6 = broadcast(edge_idx[1], _virial, 0)
_s.scatter_reduce_(0, _edge_dst6, _virial, reduce='sum')
if self._is_batch_data:
batch = data[KEY.BATCH] # for deploy, must be defined first
nbatch = int(batch.max().cpu().item()) + 1
sout = torch.zeros(
(nbatch, 6), dtype=_virial.dtype, device=_virial.device
)
_batch = broadcast(batch, _s, 0)
sout.scatter_reduce_(0, _batch, _s, reduce='sum')
else:
sout = torch.sum(_s, dim=0)
data[self.key_stress] =\
torch.neg(sout) / data[self.key_cell_volume].unsqueeze(-1)
return data
from typing import Callable, List, Tuple
from e3nn.o3 import Irreps
import sevenn._keys as KEY
from .convolution import IrrepsConvolution
from .equivariant_gate import EquivariantGate
from .linear import IrrepsLinear
def NequIP_interaction_block(
irreps_x: Irreps,
irreps_filter: Irreps,
irreps_out_tp: Irreps,
irreps_out: Irreps,
weight_nn_layers: List[int],
conv_denominator: float,
train_conv_denominator: bool,
self_connection_pair: Tuple[Callable, Callable],
act_scalar: Callable,
act_gate: Callable,
act_radial: Callable,
bias_in_linear: bool,
num_species: int,
t: int, # interaction layer index
data_key_x: str = KEY.NODE_FEATURE,
data_key_weight_input: str = KEY.EDGE_EMBEDDING,
parallel: bool = False,
**conv_kwargs,
):
block = {}
irreps_node_attr = Irreps(f'{num_species}x0e')
sc_intro, sc_outro = self_connection_pair
gate_layer = EquivariantGate(irreps_out, act_scalar, act_gate)
irreps_for_gate_in = gate_layer.get_gate_irreps_in()
block[f'{t}_self_connection_intro'] = sc_intro(
irreps_x,
irreps_operand=irreps_node_attr,
irreps_out=irreps_for_gate_in,
)
block[f'{t}_self_interaction_1'] = IrrepsLinear(
irreps_x, irreps_x,
data_key_in=data_key_x,
biases=bias_in_linear,
)
# convolution part, l>lmax is dropped as defined in irreps_out
block[f'{t}_convolution'] = IrrepsConvolution(
irreps_x=irreps_x,
irreps_filter=irreps_filter,
irreps_out=irreps_out_tp,
data_key_weight_input=data_key_weight_input,
weight_layer_input_to_hidden=weight_nn_layers,
weight_layer_act=act_radial,
denominator=conv_denominator,
train_denominator=train_conv_denominator,
is_parallel=parallel,
**conv_kwargs,
)
# irreps of x increase to gate_irreps_in
block[f'{t}_self_interaction_2'] = IrrepsLinear(
irreps_out_tp,
irreps_for_gate_in,
data_key_in=data_key_x,
biases=bias_in_linear,
)
block[f'{t}_self_connection_outro'] = sc_outro()
block[f'{t}_equivariant_gate'] = gate_layer
return block
from typing import Callable, List, Optional
import torch
import torch.nn as nn
from e3nn.nn import FullyConnectedNet
from e3nn.o3 import Irreps, Linear
from e3nn.util.jit import compile_mode
import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType
@compile_mode('script')
class IrrepsLinear(nn.Module):
"""
wrapper class of e3nn Linear to operate on AtomGraphData
"""
def __init__(
self,
irreps_in: Irreps,
irreps_out: Irreps,
data_key_in: str,
data_key_out: Optional[str] = None,
data_key_modal_attr: str = KEY.MODAL_ATTR,
num_modalities: int = 0,
lazy_layer_instantiate: bool = True,
**linear_kwargs,
):
super().__init__()
self.key_input = data_key_in
if data_key_out is None:
self.key_output = data_key_in
else:
self.key_output = data_key_out
self.key_modal_attr = data_key_modal_attr
self._irreps_in_wo_modal = irreps_in
self.irreps_in = irreps_in
self.irreps_out = irreps_out
self.linear_kwargs = linear_kwargs
self.linear = None
self.layer_instantiated = False
self.num_modalities = num_modalities
self._is_batch_data = True
# use getter setter
self.linear_cls = Linear
if num_modalities > 1: # in case of multi-modal
self.set_num_modalities(num_modalities)
if not lazy_layer_instantiate:
self.instantiate()
def instantiate(self):
if self.linear is not None:
raise ValueError('Linear layer already exists')
self.linear = self.linear_cls(
self.irreps_in, self.irreps_out, **self.linear_kwargs
)
self.layer_instantiated = True
def set_num_modalities(self, num_modalities):
if self.layer_instantiated:
raise ValueError('Layer already instantiated, can not change modalities')
irreps_in = self._irreps_in_wo_modal + Irreps(f'{num_modalities}x0e')
self.num_modalities = num_modalities
self.irreps_in = irreps_in
def _patch_modal_to_data(self, data: AtomGraphDataType) -> AtomGraphDataType:
if self._is_batch_data:
batch = data[KEY.BATCH]
batch_modality_onehot = data[self.key_modal_attr].reshape(
-1, self.num_modalities
)
batch_modality_onehot = batch_modality_onehot.type(
data[self.key_input].dtype
)
data[self.key_input] = torch.cat(
[data[self.key_input], batch_modality_onehot[batch]], dim=1
)
else:
modality_onehot = data[self.key_modal_attr].expand(
len(data[self.key_input]), -1
)
modality_onehot = modality_onehot.type(data[self.key_input].dtype)
data[self.key_input] = torch.cat(
[data[self.key_input], modality_onehot], dim=1
)
return data
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
assert self.linear is not None, 'Layer is not instantiated'
if self.num_modalities > 1:
data = self._patch_modal_to_data(data)
data[self.key_output] = self.linear(data[self.key_input])
return data
@compile_mode('script')
class AtomReduce(nn.Module):
"""
atomic energy -> total energy
constant is multiplied to data
"""
def __init__(
self,
data_key_in: str,
data_key_out: str,
reduce: str = 'sum',
constant: float = 1.0,
):
super().__init__()
self.key_input = data_key_in
self.key_output = data_key_out
self.constant = constant
self.reduce = reduce
# controlled by the upper most wrapper 'AtomGraphSequential'
self._is_batch_data = True
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
if self._is_batch_data:
src = data[self.key_input].squeeze(1)
size = int(data[KEY.BATCH].max()) + 1
output = torch.zeros(
(size),
dtype=src.dtype,
device=src.device,
)
output.scatter_reduce_(0, data[KEY.BATCH], src, reduce='sum')
data[self.key_output] = output * self.constant
else:
data[self.key_output] = torch.sum(data[self.key_input]) * self.constant
return data
@compile_mode('script')
class FCN_e3nn(nn.Module):
"""
wrapper class of e3nn FullyConnectedNet
"""
def __init__(
self,
irreps_in: Irreps, # confirm it is scalar & input size
dim_out: int,
hidden_neurons: List[int],
activation: Callable,
data_key_in: str,
data_key_out: Optional[str] = None,
**e3nn_kwargs,
):
super().__init__()
self.key_input = data_key_in
self.irreps_in = irreps_in
if data_key_out is None:
self.key_output = data_key_in
else:
self.key_output = data_key_out
for _, irrep in irreps_in:
assert irrep.is_scalar()
inp_dim = irreps_in.dim
self.fcn = FullyConnectedNet(
[inp_dim] + hidden_neurons + [dim_out],
activation,
**e3nn_kwargs,
)
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
data[self.key_output] = self.fcn(data[self.key_input])
return data
from typing import Dict, List, Optional
import torch
import torch.nn as nn
import torch.nn.functional
from ase.symbols import symbols2numbers
from e3nn.util.jit import compile_mode
import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType
# TODO: put this to model_build and do not preprocess data by onehot
@compile_mode('script')
class OnehotEmbedding(nn.Module):
"""
x : tensor of shape (N, 1)
x_after : tensor of shape (N, num_classes)
It overwrite data_key_x
and saves input to data_key_save and output to data_key_additional
I know this is strange but it is for compatibility with previous version
and to specie wise shift scale work
ex) [0 1 1 0] -> [[1, 0] [0, 1] [0, 1] [1, 0]] (num_classes = 2)
"""
def __init__(
self,
num_classes: int,
data_key_x: str = KEY.NODE_FEATURE,
data_key_out: Optional[str] = None,
data_key_save: Optional[str] = None,
data_key_additional: Optional[str] = None, # additional output
):
super().__init__()
self.num_classes = num_classes
self.key_x = data_key_x
if data_key_out is None:
self.key_output = data_key_x
else:
self.key_output = data_key_out
self.key_save = data_key_save
self.key_additional_output = data_key_additional
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
inp = data[self.key_x]
embd = torch.nn.functional.one_hot(inp, self.num_classes)
embd = embd.float()
data[self.key_output] = embd
if self.key_additional_output is not None:
data[self.key_additional_output] = embd # for self-connection
if self.key_save is not None:
data[self.key_save] = inp # for elemwise shift scale
return data
def get_type_mapper_from_specie(specie_list: List[str]):
"""
from ['Hf', 'O']
return {72: 0, 8: 1}
"""
specie_list = sorted(specie_list)
type_map = {}
unique_counter = 0
for specie in specie_list:
atomic_num = symbols2numbers(specie)[0]
if atomic_num in type_map:
continue
type_map[atomic_num] = unique_counter
unique_counter += 1
return type_map
# deprecated
def one_hot_atom_embedding(
atomic_numbers: List[int], type_map: Dict[int, int]
):
"""
atomic numbers from ase.get_atomic_numbers
type_map from get_type_mapper_from_specie()
"""
num_classes = len(type_map)
try:
type_numbers = torch.LongTensor(
[type_map[num] for num in atomic_numbers]
)
except KeyError as e:
raise ValueError(f'Atomic number {e.args[0]} is not expected')
embd = torch.nn.functional.one_hot(type_numbers, num_classes)
embd = embd.to(torch.get_default_dtype())
return embd
from typing import Any, Dict, List, Optional, Union
import torch
import torch.nn as nn
from e3nn.util.jit import compile_mode
import sevenn._keys as KEY
from sevenn._const import NUM_UNIV_ELEMENT, AtomGraphDataType
def _as_univ(
ss: List[float], type_map: Dict[int, int], default: float
) -> List[float]:
assert len(ss) <= NUM_UNIV_ELEMENT, 'shift scale is too long'
return [
ss[type_map[z]] if z in type_map else default
for z in range(NUM_UNIV_ELEMENT)
]
@compile_mode('script')
class Rescale(nn.Module):
"""
Scaling and shifting energy (and automatically force and stress)
"""
def __init__(
self,
shift: float,
scale: float,
data_key_in: str = KEY.SCALED_ATOMIC_ENERGY,
data_key_out: str = KEY.ATOMIC_ENERGY,
train_shift_scale: bool = False,
**kwargs,
):
assert isinstance(shift, float) and isinstance(scale, float)
super().__init__()
self.shift = nn.Parameter(
torch.FloatTensor([shift]), requires_grad=train_shift_scale
)
self.scale = nn.Parameter(
torch.FloatTensor([scale]), requires_grad=train_shift_scale
)
self.key_input = data_key_in
self.key_output = data_key_out
def get_shift(self) -> float:
return self.shift.detach().cpu().tolist()[0]
def get_scale(self) -> float:
return self.scale.detach().cpu().tolist()[0]
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
data[self.key_output] = data[self.key_input] * self.scale + self.shift
return data
@compile_mode('script')
class SpeciesWiseRescale(nn.Module):
"""
Scaling and shifting energy (and automatically force and stress)
Use as it is if given list, expand to list if one of them is float
If two lists are given and length is not the same, raise error
"""
def __init__(
self,
shift: Union[List[float], float],
scale: Union[List[float], float],
data_key_in: str = KEY.SCALED_ATOMIC_ENERGY,
data_key_out: str = KEY.ATOMIC_ENERGY,
data_key_indices: str = KEY.ATOM_TYPE,
train_shift_scale: bool = False,
):
super().__init__()
assert isinstance(shift, float) or isinstance(shift, list)
assert isinstance(scale, float) or isinstance(scale, list)
if (
isinstance(shift, list)
and isinstance(scale, list)
and len(shift) != len(scale)
):
raise ValueError('List length should be same')
if isinstance(shift, list):
num_species = len(shift)
elif isinstance(scale, list):
num_species = len(scale)
else:
raise ValueError('Both shift and scale is not a list')
shift = [shift] * num_species if isinstance(shift, float) else shift
scale = [scale] * num_species if isinstance(scale, float) else scale
self.shift = nn.Parameter(
torch.FloatTensor(shift), requires_grad=train_shift_scale
)
self.scale = nn.Parameter(
torch.FloatTensor(scale), requires_grad=train_shift_scale
)
self.key_input = data_key_in
self.key_output = data_key_out
self.key_indices = data_key_indices
def get_shift(self, type_map: Optional[Dict[int, int]] = None) -> List[float]:
"""
Return shift in list of float. If type_map is given, return type_map reversed
shift, which index equals atomic_number. 0.0 is assigned for atomis not found
"""
shift = self.shift.detach().cpu().tolist()
if type_map:
shift = _as_univ(shift, type_map, 0.0)
return shift
def get_scale(self, type_map: Optional[Dict[int, int]] = None) -> List[float]:
"""
Return scale in list of float. If type_map is given, return type_map reversed
scale, which index equals atomic_number. 1.0 is assigned for atomis not found
"""
scale = self.scale.detach().cpu().tolist()
if type_map:
scale = _as_univ(scale, type_map, 1.0)
return scale
@staticmethod
def from_mappers(
shift: Union[float, List[float]],
scale: Union[float, List[float]],
type_map: Dict[int, int],
**kwargs,
):
"""
Fit dimensions or mapping raw shift scale values to that is valid under
the given type_map: (atomic_numbers -> type_indices)
"""
shift_scale = []
n_atom_types = len(type_map)
for s in (shift, scale):
if isinstance(s, list) and len(s) > n_atom_types:
if len(s) != NUM_UNIV_ELEMENT:
raise ValueError('given shift or scale is strange')
s = [s[z] for z in sorted(type_map, key=lambda x: type_map[x])]
# s = [s[z] for z in sorted(type_map, key=type_map.get)]
elif isinstance(s, float):
s = [s] * n_atom_types
elif isinstance(s, list) and len(s) == 1:
s = s * n_atom_types
shift_scale.append(s)
assert all([len(s) == n_atom_types for s in shift_scale])
shift, scale = shift_scale
return SpeciesWiseRescale(shift, scale, **kwargs)
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
indices = data[self.key_indices]
data[self.key_output] = data[self.key_input] * self.scale[indices].view(
-1, 1
) + self.shift[indices].view(-1, 1)
return data
@compile_mode('script')
class ModalWiseRescale(nn.Module):
"""
Scaling and shifting energy (and automatically force and stress)
Given shift or scale is either modal-wise and atom-wise or
not modal-wise but atom-wise. It is always interpreted as atom-wise.
"""
def __init__(
self,
shift: List[List[float]],
scale: List[List[float]],
data_key_in: str = KEY.SCALED_ATOMIC_ENERGY,
data_key_out: str = KEY.ATOMIC_ENERGY,
data_key_modal_indices: str = KEY.MODAL_TYPE,
data_key_atom_indices: str = KEY.ATOM_TYPE,
use_modal_wise_shift: bool = False,
use_modal_wise_scale: bool = False,
train_shift_scale: bool = False,
):
super().__init__()
self.shift = nn.Parameter(
torch.FloatTensor(shift), requires_grad=train_shift_scale
)
self.scale = nn.Parameter(
torch.FloatTensor(scale), requires_grad=train_shift_scale
)
self.key_input = data_key_in
self.key_output = data_key_out
self.key_atom_indices = data_key_atom_indices
self.key_modal_indices = data_key_modal_indices
self.use_modal_wise_shift = use_modal_wise_shift
self.use_modal_wise_scale = use_modal_wise_scale
self._is_batch_data = True
def get_shift(
self,
type_map: Optional[Dict[int, int]] = None,
modal_map: Optional[Dict[str, int]] = None,
) -> Union[List[float], Dict[str, List[float]]]:
"""
Nothing is given: return as it is
type_map is given but not modal wise shift: return univ shift
both type_map and modal_map is given and modal wise shift: return fully
resolved modalwise univ shift
"""
shift = self.shift.detach().cpu().tolist()
if type_map and not self.use_modal_wise_shift:
shift = _as_univ(shift, type_map, 0.0)
elif self.use_modal_wise_shift and modal_map and type_map:
shift = [_as_univ(s, type_map, 0.0) for s in shift]
shift = {modal: shift[idx] for modal, idx in modal_map.items()}
return shift
def get_scale(
self,
type_map: Optional[Dict[int, int]] = None,
modal_map: Optional[Dict[str, int]] = None,
) -> Union[List[float], Dict[str, List[float]]]:
"""
Nothing is given: return as it is
type_map is given but not modal wise scale: return univ scale
both type_map and modal_map is given and modal wise scale: return fully
resolved modalwise univ scale
"""
scale = self.scale.detach().cpu().tolist()
if type_map and not self.use_modal_wise_scale:
scale = _as_univ(scale, type_map, 0.0)
elif self.use_modal_wise_scale and modal_map and type_map:
scale = [_as_univ(s, type_map, 0.0) for s in scale]
scale = {modal: scale[idx] for modal, idx in modal_map.items()}
return scale
@staticmethod
def from_mappers(
shift: Union[float, List[float], Dict[str, Any]],
scale: Union[float, List[float], Dict[str, Any]],
use_modal_wise_shift: bool,
use_modal_wise_scale: bool,
type_map: Dict[int, int],
modal_map: Dict[str, int],
**kwargs,
):
"""
Fit dimensions or mapping raw shift scale values to that is valid under
the given type_map: (atomic_numbers -> type_indices)
If given List[float] and its length matches length of _const.NUM_UNIV_ELEMENT
, assume it is element-wise list
otherwise, it is modal-wise list
"""
def solve_mapper(arr, map):
# value is attr index and never overlap, key is either 'z' or modal str
return [arr[z] for z in sorted(map, key=lambda x: map[x])]
shift_scale = []
n_atom_types = len(type_map)
n_modals = len(modal_map)
for s, use_mw in (
(shift, use_modal_wise_shift),
(scale, use_modal_wise_scale),
):
# solve elemewise, or broadcast
if isinstance(s, float):
# given, modal-wise: no, elem-wise: no => broadcast
shape = (n_modals, n_atom_types) if use_mw else (n_atom_types,)
res = torch.full(shape, s).tolist() # TODO: w/o torch
elif isinstance(s, list) and len(s) == NUM_UNIV_ELEMENT:
# given, modal-wise: no, elem-wise: yes(univ) => solve elem map
s = solve_mapper(s, type_map)
res = [s] * n_modals if use_mw else s
elif ( # given, modal-wise: yes, elem-wise: no => broadcast to elemwise
isinstance(s, list)
and isinstance(s[0], float)
and len(s) == n_modals
and use_mw
):
res = [[v] * n_atom_types for v in s]
elif ( # given, modal-wise: no, elem-wise: yes => as it is
isinstance(s, list)
and isinstance(s[0], float)
and len(s) == n_atom_types
and not use_mw
):
res = s
elif ( # given, modal-wise: yes, elem-wise: yes => as it is
isinstance(s, list)
and isinstance(s[0], list)
and len(s) == n_modals
and len(s[0]) == n_atom_types
and use_mw
):
res = s
elif isinstance(s, dict) and use_mw:
# solve modal dict, modal-wise: yes
s = solve_mapper(s, modal_map)
res = []
for v in s:
if isinstance(v, list) and len(v) == NUM_UNIV_ELEMENT:
# elem-wise: yes(univ) => solve elem map
v = solve_mapper(v, type_map)
elif isinstance(v, float):
# elem-wise: no => broadcast to elemwise
v = [v] * n_atom_types
else:
raise ValueError(f'Invalid shift or scale {s}')
res.append(v)
else:
raise ValueError(f'Invalid shift or scale {s}')
if use_mw:
assert (
isinstance(res, list)
and isinstance(res[0], list)
and len(res) == n_modals
)
assert all([len(r) == n_atom_types for r in res]) # type: ignore
else:
assert (
isinstance(res, list)
and isinstance(res[0], float)
and len(res) == n_atom_types
)
shift_scale.append(res)
shift, scale = shift_scale
return ModalWiseRescale(
shift,
scale,
use_modal_wise_shift=use_modal_wise_shift,
use_modal_wise_scale=use_modal_wise_scale,
**kwargs,
)
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
if self._is_batch_data:
batch = data[KEY.BATCH]
modal_indices = data[self.key_modal_indices][batch]
else:
modal_indices = data[self.key_modal_indices]
atom_indices = data[self.key_atom_indices]
shift = (
self.shift[modal_indices, atom_indices]
if self.use_modal_wise_shift
else self.shift[atom_indices]
)
scale = (
self.scale[modal_indices, atom_indices]
if self.use_modal_wise_scale
else self.scale[atom_indices]
)
data[self.key_output] = data[self.key_input] * scale.view(
-1, 1
) + shift.view(-1, 1)
return data
def get_resolved_shift_scale(
module: Union[Rescale, SpeciesWiseRescale, ModalWiseRescale],
type_map: Optional[Dict[int, int]] = None,
modal_map: Optional[Dict[str, int]] = None,
):
"""
Return resolved shift and scale from scale modules. For element wise case,
convert to list of floats where idx is atomic number. For modal wise case, return
dictionary of shift scale where key is modal name given in modal_map
Return:
Tuple of solved shift and scale
"""
if isinstance(module, Rescale):
return (module.get_shift(), module.get_scale())
elif isinstance(module, SpeciesWiseRescale):
return (module.get_shift(type_map), module.get_scale(type_map))
elif isinstance(module, ModalWiseRescale):
return (
module.get_shift(type_map, modal_map),
module.get_scale(type_map, modal_map),
)
raise ValueError('Not scale module')
import torch.nn as nn
from e3nn.o3 import FullyConnectedTensorProduct, Irreps, Linear
from e3nn.util.jit import compile_mode
import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType
@compile_mode('script')
class SelfConnectionIntro(nn.Module):
"""
do TensorProduct of x and some data(here attribute of x)
and save it (to concatenate updated x at SelfConnectionOutro)
"""
def __init__(
self,
irreps_in: Irreps,
irreps_operand: Irreps,
irreps_out: Irreps,
data_key_x: str = KEY.NODE_FEATURE,
data_key_operand: str = KEY.NODE_ATTR,
lazy_layer_instantiate: bool = True,
**kwargs, # for compatibility
):
super().__init__()
self.fc_tensor_product = FullyConnectedTensorProduct(
irreps_in, irreps_operand, irreps_out
)
self.irreps_in1 = irreps_in
self.irreps_in2 = irreps_operand
self.irreps_out = irreps_out
self.key_x = data_key_x
self.key_operand = data_key_operand
self.fc_tensor_product = None
self.layer_instantiated = False
self.fc_tensor_product_cls = FullyConnectedTensorProduct
self.fc_tensor_product_kwargs = kwargs
if not lazy_layer_instantiate:
self.instantiate()
def instantiate(self):
if self.fc_tensor_product is not None:
raise ValueError('fc_tensor_product layer already exists')
self.fc_tensor_product = self.fc_tensor_product_cls(
self.irreps_in1,
self.irreps_in2,
self.irreps_out,
shared_weights=True,
internal_weights=None, # same as True
**self.fc_tensor_product_kwargs,
)
self.layer_instantiated = True
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
assert self.fc_tensor_product is not None, 'Layer is not instantiated'
data[KEY.SELF_CONNECTION_TEMP] = self.fc_tensor_product(
data[self.key_x], data[self.key_operand]
)
return data
@compile_mode('script')
class SelfConnectionLinearIntro(nn.Module):
"""
Linear style self connection update
"""
def __init__(
self,
irreps_in: Irreps,
irreps_out: Irreps,
data_key_x: str = KEY.NODE_FEATURE,
lazy_layer_instantiate: bool = True,
**kwargs,
):
super().__init__()
self.irreps_in = irreps_in
self.irreps_out = irreps_out
self.key_x = data_key_x
self.linear = None
self.layer_instantiated = False
self.linear_cls = Linear
# TODO: better to have SelfConnectionIntro super class
kwargs.pop('irreps_operand')
self.linear_kwargs = kwargs
if not lazy_layer_instantiate:
self.instantiate()
def instantiate(self):
if self.linear is not None:
raise ValueError('Linear layer already exists')
self.linear = self.linear_cls(
self.irreps_in, self.irreps_out, **self.linear_kwargs
)
self.layer_instantiated = True
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
assert self.linear is not None, 'Layer is not instantiated'
data[KEY.SELF_CONNECTION_TEMP] = self.linear(data[self.key_x])
return data
@compile_mode('script')
class SelfConnectionOutro(nn.Module):
"""
do TensorProduct of x and some data(here attribute of x)
and save it (to concatenate updated x at SelfConnectionOutro)
"""
def __init__(
self,
data_key_x: str = KEY.NODE_FEATURE,
):
super().__init__()
self.key_x = data_key_x
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
data[self.key_x] = data[self.key_x] + data[KEY.SELF_CONNECTION_TEMP]
del data[KEY.SELF_CONNECTION_TEMP]
return data
import warnings
from collections import OrderedDict
from typing import Dict, Optional
import torch
import torch.nn as nn
from e3nn.util.jit import compile_mode
import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType
def _instantiate_modules(modules):
# see IrrepsLinear of linear.py
for module in modules.values():
if not getattr(module, 'layer_instantiated', True):
module.instantiate()
@compile_mode('script')
class _ModalInputPrepare(nn.Module):
def __init__(
self,
modal_idx: int
):
super().__init__()
self.modal_idx = modal_idx
def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
data[KEY.MODAL_TYPE] = torch.tensor(
self.modal_idx,
dtype=torch.int64,
device=data['x'].device,
)
return data
@compile_mode('script')
class AtomGraphSequential(nn.Sequential):
"""
Wrapper of SevenNet model
Args:
modules: OrderedDict of nn.Modules
cutoff: not used internally, but makes sense to have
type_map: atomic_numbers => onehot index (see nn/node_embedding.py)
eval_type_map: perform index mapping using type_map defaults to True
data_key_atomic_numbers: used when eval_type_map is True
data_key_node_feature: used when eval_type_map is True
data_key_grad: if given, sets its requires grad True before pred
"""
def __init__(
self,
modules: Dict[str, nn.Module],
cutoff: float = 0.0,
type_map: Optional[Dict[int, int]] = None,
modal_map: Optional[Dict[str, int]] = None,
eval_type_map: bool = True,
eval_modal_map: bool = False,
data_key_atomic_numbers: str = KEY.ATOMIC_NUMBERS,
data_key_node_feature: str = KEY.NODE_FEATURE,
data_key_grad: Optional[str] = None,
):
if not isinstance(modules, OrderedDict): # backward compat
modules = OrderedDict(modules)
self.cutoff = cutoff
self.type_map = type_map
self.eval_type_map = eval_type_map
self.is_batch_data = True
if cutoff == 0.0:
warnings.warn('cutoff is 0.0 or not given', UserWarning)
if self.type_map is None:
warnings.warn('type_map is not given', UserWarning)
self.eval_type_map = False
else:
z_to_onehot_tensor = torch.neg(torch.ones(120, dtype=torch.long))
for z, onehot in self.type_map.items():
z_to_onehot_tensor[z] = onehot
self.z_to_onehot_tensor = z_to_onehot_tensor
if eval_modal_map and modal_map is None:
raise ValueError('eval_modal_map is True but modal_map is None')
self.eval_modal_map = eval_modal_map
self.modal_map = modal_map
self.key_atomic_numbers = data_key_atomic_numbers
self.key_node_feature = data_key_node_feature
self.key_grad = data_key_grad
_instantiate_modules(modules)
super().__init__(modules)
if not isinstance(self._modules, OrderedDict): # backward compat
self._modules = OrderedDict(self._modules)
def set_is_batch_data(self, flag: bool):
# whether given data is batched or not some module have to change
# its behavior. checking whether data is batched or not inside
# forward function make problem harder when make it into torchscript
for module in self:
try: # Easier to ask for forgiveness than permission.
module._is_batch_data = flag # type: ignore
except AttributeError:
pass
self.is_batch_data = flag
def get_irreps_in(self, modlue_name: str, attr_key: str = 'irreps_in'):
tg_module = self._modules[modlue_name]
for m in tg_module.modules():
try:
return repr(m.__getattribute__(attr_key))
except AttributeError:
pass
return None
def prepand_module(self, key: str, module: nn.Module):
self._modules.update({key: module})
self._modules.move_to_end(key, last=False) # type: ignore
def replace_module(self, key: str, module: nn.Module):
self._modules.update({key: module})
def delete_module_by_key(self, key: str):
if key in self._modules.keys():
del self._modules[key]
@torch.jit.unused
def _atomic_numbers_to_onehot(self, atomic_numbers: torch.Tensor):
assert atomic_numbers.dtype == torch.int64
device = atomic_numbers.device
z_to_onehot_tensor = self.z_to_onehot_tensor.to(device)
return torch.index_select(
input=z_to_onehot_tensor, dim=0, index=atomic_numbers
)
@torch.jit.unused
def _eval_modal_map(self, data: AtomGraphDataType):
assert self.modal_map is not None
# modal_map: dict[str, int]
if not self.is_batch_data:
modal_idx = self.modal_map[data[KEY.DATA_MODALITY]] # type: ignore
else:
modal_idx = [
self.modal_map[ii] # type: ignore
for ii in data[KEY.DATA_MODALITY]
]
modal_idx = torch.tensor(
modal_idx,
dtype=torch.int64,
device=data.x.device, # type: ignore
)
data[KEY.MODAL_TYPE] = modal_idx
def _preprocess(self, data: AtomGraphDataType) -> AtomGraphDataType:
if self.eval_type_map:
atomic_numbers = data[self.key_atomic_numbers]
onehot = self._atomic_numbers_to_onehot(atomic_numbers)
data[self.key_node_feature] = onehot
if self.eval_modal_map:
self._eval_modal_map(data)
if self.key_grad is not None:
data[self.key_grad].requires_grad_(True)
return data
def prepare_modal_deploy(self, modal: str):
if self.modal_map is None:
return
self.eval_modal_map = False
self.set_is_batch_data(False)
modal_idx = self.modal_map[modal] # type: ignore
self.prepand_module('modal_input_prepare', _ModalInputPrepare(modal_idx))
def forward(self, input: AtomGraphDataType) -> AtomGraphDataType:
data = self._preprocess(input)
for module in self:
data = module(data)
return data
import torch
def broadcast(
src: torch.Tensor,
other: torch.Tensor,
dim: int
):
if dim < 0:
dim = other.dim() + dim
if src.dim() == 1:
for _ in range(0, dim):
src = src.unsqueeze(0)
for _ in range(src.dim(), other.dim()):
src = src.unsqueeze(-1)
src = src.expand_as(other)
return src
We support the LAMMPS pair style `d3` of the Grimme's D3 dispersion (van der Waals) correction scheme accelerated with CUDA, which can be used within LAMMPS in conjunction with SevenNet.
**PLEASE NOTE:** Currently, this D3 code does not support mulit-GPU parallelism yet. So it can only be run on a single GPU.
# About Grimme's D3 code accelerated with CUDA
This is LAMMPS implementation of [Grimme's D3 method](https://doi.org/10.1063/1.3382344). We have ported the code from the [original fortran code](https://www.chemie.uni-bonn.de/grimme/de/software/dft-d3) to a LAMMPS pair style written in CUDA/C++.
While D3 method is significantly faster than DFT, existing CPU implementations were slower than SevenNet. To address this, we have adopted CUDA and single precision (FP32) operations to accelerate the code.
## Installation for LAMMPS
Simply run,
```bash
sevenn_patch_lammps ./lammps_sevenn --d3
```
You can follow the remaining installation steps in the [SevenNet documentation](../../README.md#installation-for-lammps).
Also, this code requires a GPU with a compute capability of **at least 6.0**. If you try to compile it with version 5.0, you may encounter an `atomicAdd` error.
The target compute capability of this code follows the setting of LibTorch in SevenNet, except for version 5.0.
You can manually select the target capability using the `TORCH_CUDA_ARCH_LIST` environment variable. For example, you can use: `export TORCH_CUDA_ARCH_LIST="6.1;7.0;8.0;8.6;8.9;9.0"`.
## Usage for LAMMPS
You can use the D3 dispersion correction in LAMMPS with SevenNet through the `pair/hybrid` command:
```txt
pair_style hybrid/overlay e3gnn d3 {cutoff_d3_r} {cutoff_d3_cn} {type_of_damping} {name_of_functional}
pair_coeff * * e3gnn {path_to_serial_model} {space_separated_chemical_species}
pair_coeff * * d3 {space_separated_chemical_species}
```
for example,
```txt
pair_style hybrid/overlay e3gnn d3 9000 1600 damp_bj pbe
pair_coeff * * e3gnn ./deployed_serial.pt C H O
pair_coeff * * d3 C H O
```
`cutoff_d3_r` and `cutoff_d3_cn` are square of cutoff radii for energy/force and coordination number, respectively. Units are Bohr radius: 1 (Bohr radius) = 0.52917721 (Å). Default values are `9000` and `1600`, respectively. this is also the default values used in VASP.[^1]
Available `type_of_damping` are as follows:
- `damp_zero`: Zero damping
- `damp_bj`: Becke-Johnson damping
Available `name_of_functional` options are the same as in the original Fortran code. SevenNet-0 is trained on the 'PBE' functional, so you should specify 'pbe' in the script when using it. For other supporting functionals, check 'List of parametrized functionals' in [here](https://www.chemie.uni-bonn.de/grimme/de/software/dft-d3).
## Features
- Selective(or no) periodic boundary condition: implemented, But only PBC/noPBC can be checked through original FORTRAN code; selective PBC cannot
- 3-body term, n > 8 term: not implemented (as to VASP)
- Modified versions of zero and bj damping
## Cautions
- It can be slower than the CPU with a small number of atoms.
- The maximum number of atoms that can be calculated is 46,340 (overflow issue).
- There can be occurred small amounts of numerical error
- The introduction of some FP32 operations can lead to minor numerical errors, particularly in pressure calculations, but these are generally smaller than those seen with SevenNet.
- If the error is too large, ensure that the `fmad=false` option in `patch_lammps.sh` is correctly applied during build.
## To do
- Remove atom_modify / compute virial dependency.
- Add support for ASE as calculator interface.
- Add support for multi GPUs (with `e3gnn/parallel`).
- Implement without Unified Memory.
- Unfix the `threadsPerBlock=128`.
- Unroll the repetition loop `k` (for small number of atoms).
## Contributors
- Hyungmin An: Ported the original Fortran D3 code to C++ with OpenMP and MPI.
- Gijin Kim: Accelerated the C++ D3 code with OpenACC[^2] and CUDA, and currently maintains it.
[^1]: On the [VASP DFT-D3](https://www.vasp.at/wiki/index.php/DFT-D3) page, the `VDW_RADIUS` and `VDW_CNRADIUS` are `50.2` and `20.0`, respectively (units are Å). However, when running VASP 6.3.2 with D3 using zero damping (BJ does not provide such a log), the default values in the OUTCAR file are `50.2022` and `21.1671`. These values are the same as our defaults.
[^2]: Since OpenACC is not compatible with libtorch, we chose to use the CUDA.
// clang-format off
/* ----------------------------------------------------------------------
LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
https://www.lammps.org/, Sandia National Laboratories
LAMMPS development team: developers@lammps.org
Copyright (2003) Sandia Corporation. Under the terms of Contract
DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains
certain rights in this software. This software is distributed under
the GNU General Public License.
See the README file in the top-level LAMMPS directory.
------------------------------------------------------------------------- */
/* ----------------------------------------------------------------------
Contributing author (triclinic) : Pieter in 't Veld (SNL)
------------------------------------------------------------------------- */
#include "comm_brick.h"
#include "atom.h"
#include "atom_vec.h"
#include "bond.h"
#include "compute.h"
#include "domain.h"
#include "dump.h"
#include "error.h"
#include "fix.h"
#include "memory.h"
#include "neighbor.h"
#include "pair.h"
#include <cmath>
#include <cstring>
#include "pair_e3gnn_parallel.h"
using namespace LAMMPS_NS;
#define BUFFACTOR 1.5
#define BUFMIN 1024
#define BIG 1.0e20
/* ---------------------------------------------------------------------- */
CommBrick::CommBrick(LAMMPS *lmp) :
Comm(lmp),
sendnum(nullptr), recvnum(nullptr), sendproc(nullptr), recvproc(nullptr),
size_forward_recv(nullptr), size_reverse_send(nullptr), size_reverse_recv(nullptr),
slablo(nullptr), slabhi(nullptr), multilo(nullptr), multihi(nullptr),
multioldlo(nullptr), multioldhi(nullptr), cutghostmulti(nullptr), cutghostmultiold(nullptr),
pbc_flag(nullptr), pbc(nullptr), firstrecv(nullptr), sendlist(nullptr),
localsendlist(nullptr), maxsendlist(nullptr), buf_send(nullptr), buf_recv(nullptr)
{
style = Comm::BRICK;
layout = Comm::LAYOUT_UNIFORM;
pbc_flag = nullptr;
init_buffers();
}
/* ---------------------------------------------------------------------- */
CommBrick::~CommBrick()
{
CommBrick::free_swap();
if (mode == Comm::MULTI) {
CommBrick::free_multi();
memory->destroy(cutghostmulti);
}
if (mode == Comm::MULTIOLD) {
CommBrick::free_multiold();
memory->destroy(cutghostmultiold);
}
if (sendlist) for (int i = 0; i < maxswap; i++) memory->destroy(sendlist[i]);
if (localsendlist) memory->destroy(localsendlist);
memory->sfree(sendlist);
memory->destroy(maxsendlist);
memory->destroy(buf_send);
memory->destroy(buf_recv);
}
/* ---------------------------------------------------------------------- */
//IMPORTANT: we *MUST* pass "*oldcomm" to the Comm initializer here, as
// the code below *requires* that the (implicit) copy constructor
// for Comm is run and thus creating a shallow copy of "oldcomm".
// The call to Comm::copy_arrays() then converts the shallow copy
// into a deep copy of the class with the new layout.
CommBrick::CommBrick(LAMMPS * /*lmp*/, Comm *oldcomm) : Comm(*oldcomm)
{
if (oldcomm->layout == Comm::LAYOUT_TILED)
error->all(FLERR,"Cannot change to comm_style brick from tiled layout");
style = Comm::BRICK;
layout = oldcomm->layout;
Comm::copy_arrays(oldcomm);
init_buffers();
}
/* ----------------------------------------------------------------------
initialize comm buffers and other data structs local to CommBrick
------------------------------------------------------------------------- */
void CommBrick::init_buffers()
{
multilo = multihi = nullptr;
cutghostmulti = nullptr;
multioldlo = multioldhi = nullptr;
cutghostmultiold = nullptr;
buf_send = buf_recv = nullptr;
maxsend = maxrecv = BUFMIN;
CommBrick::grow_send(maxsend,2);
memory->create(buf_recv,maxrecv,"comm:buf_recv");
nswap = 0;
maxswap = 6;
CommBrick::allocate_swap(maxswap);
sendlist = (int **) memory->smalloc(maxswap*sizeof(int *),"comm:sendlist");
memory->create(maxsendlist,maxswap,"comm:maxsendlist");
for (int i = 0; i < maxswap; i++) {
maxsendlist[i] = BUFMIN;
memory->create(sendlist[i],BUFMIN,"comm:sendlist[i]");
}
}
/* ---------------------------------------------------------------------- */
void CommBrick::init()
{
Comm::init();
int bufextra_old = bufextra;
init_exchange();
if (bufextra > bufextra_old) grow_send(maxsend+bufextra,2);
// memory for multi style communication
// allocate in setup
if (mode == Comm::MULTI) {
// If inconsitent # of collections, destroy any preexisting arrays (may be missized)
if (ncollections != neighbor->ncollections) {
ncollections = neighbor->ncollections;
if (multilo != nullptr) {
free_multi();
memory->destroy(cutghostmulti);
}
}
// delete any old user cutoffs if # of collections chanaged
if (cutusermulti && ncollections != ncollections_cutoff) {
if(me == 0) error->warning(FLERR, "cutoff/multi settings discarded, must be defined"
" after customizing collections in neigh_modify");
memory->destroy(cutusermulti);
cutusermulti = nullptr;
}
if (multilo == nullptr) {
allocate_multi(maxswap);
memory->create(cutghostmulti,ncollections,3,"comm:cutghostmulti");
}
}
if ((mode == Comm::SINGLE || mode == Comm::MULTIOLD) && multilo) {
free_multi();
memory->destroy(cutghostmulti);
}
// memory for multi/old-style communication
if (mode == Comm::MULTIOLD && multioldlo == nullptr) {
allocate_multiold(maxswap);
memory->create(cutghostmultiold,atom->ntypes+1,3,"comm:cutghostmultiold");
}
if ((mode == Comm::SINGLE || mode == Comm::MULTI) && multioldlo) {
free_multiold();
memory->destroy(cutghostmultiold);
}
}
/* ----------------------------------------------------------------------
setup spatial-decomposition communication patterns
function of neighbor cutoff(s) & cutghostuser & current box size
single mode sets slab boundaries (slablo,slabhi) based on max cutoff
multi mode sets collection-dependent slab boundaries (multilo,multihi)
multi/old mode sets type-dependent slab boundaries (multioldlo,multioldhi)
------------------------------------------------------------------------- */
void CommBrick::setup()
{
// cutghost[] = max distance at which ghost atoms need to be acquired
// for orthogonal:
// cutghost is in box coords = neigh->cutghost in all 3 dims
// for triclinic:
// neigh->cutghost = distance between tilted planes in box coords
// cutghost is in lamda coords = distance between those planes
// for multi:
// cutghostmulti = same as cutghost, only for each atom collection
// for multi/old:
// cutghostmultiold = same as cutghost, only for each atom type
int i,j;
int ntypes = atom->ntypes;
double *prd,*sublo,*subhi;
double cut = get_comm_cutoff();
if ((cut == 0.0) && (me == 0))
error->warning(FLERR,"Communication cutoff is 0.0. No ghost atoms "
"will be generated. Atoms may get lost.");
if (mode == Comm::MULTI) {
double **cutcollectionsq = neighbor->cutcollectionsq;
// build collection array for atom exchange
neighbor->build_collection(0);
// If using multi/reduce, communicate particles a distance equal
// to the max cutoff with equally sized or smaller collections
// If not, communicate the maximum cutoff of the entire collection
for (i = 0; i < ncollections; i++) {
if (cutusermulti) {
cutghostmulti[i][0] = cutusermulti[i];
cutghostmulti[i][1] = cutusermulti[i];
cutghostmulti[i][2] = cutusermulti[i];
} else {
cutghostmulti[i][0] = 0.0;
cutghostmulti[i][1] = 0.0;
cutghostmulti[i][2] = 0.0;
}
for (j = 0; j < ncollections; j++){
if (multi_reduce && (cutcollectionsq[j][j] > cutcollectionsq[i][i])) continue;
cutghostmulti[i][0] = MAX(cutghostmulti[i][0],sqrt(cutcollectionsq[i][j]));
cutghostmulti[i][1] = MAX(cutghostmulti[i][1],sqrt(cutcollectionsq[i][j]));
cutghostmulti[i][2] = MAX(cutghostmulti[i][2],sqrt(cutcollectionsq[i][j]));
}
}
}
if (mode == Comm::MULTIOLD) {
double *cuttype = neighbor->cuttype;
for (i = 1; i <= ntypes; i++) {
double tmp = 0.0;
if (cutusermultiold) tmp = cutusermultiold[i];
cutghostmultiold[i][0] = MAX(tmp,cuttype[i]);
cutghostmultiold[i][1] = MAX(tmp,cuttype[i]);
cutghostmultiold[i][2] = MAX(tmp,cuttype[i]);
}
}
if (triclinic == 0) {
prd = domain->prd;
sublo = domain->sublo;
subhi = domain->subhi;
cutghost[0] = cutghost[1] = cutghost[2] = cut;
} else {
prd = domain->prd_lamda;
sublo = domain->sublo_lamda;
subhi = domain->subhi_lamda;
double *h_inv = domain->h_inv;
double length0,length1,length2;
length0 = sqrt(h_inv[0]*h_inv[0] + h_inv[5]*h_inv[5] + h_inv[4]*h_inv[4]);
cutghost[0] = cut * length0;
length1 = sqrt(h_inv[1]*h_inv[1] + h_inv[3]*h_inv[3]);
cutghost[1] = cut * length1;
length2 = h_inv[2];
cutghost[2] = cut * length2;
if (mode == Comm::MULTI) {
for (i = 0; i < ncollections; i++) {
cutghostmulti[i][0] *= length0;
cutghostmulti[i][1] *= length1;
cutghostmulti[i][2] *= length2;
}
}
if (mode == Comm::MULTIOLD) {
for (i = 1; i <= ntypes; i++) {
cutghostmultiold[i][0] *= length0;
cutghostmultiold[i][1] *= length1;
cutghostmultiold[i][2] *= length2;
}
}
}
// recvneed[idim][0/1] = # of procs away I recv atoms from, within cutghost
// 0 = from left, 1 = from right
// do not cross non-periodic boundaries, need[2] = 0 for 2d
// sendneed[idim][0/1] = # of procs away I send atoms to
// 0 = to left, 1 = to right
// set equal to recvneed[idim][1/0] of neighbor proc
// maxneed[idim] = max procs away any proc recvs atoms in either direction
// layout = UNIFORM = uniform sized sub-domains:
// maxneed is directly computable from sub-domain size
// limit to procgrid-1 for non-PBC
// recvneed = maxneed except for procs near non-PBC
// sendneed = recvneed of neighbor on each side
// layout = NONUNIFORM = non-uniform sized sub-domains:
// compute recvneed via updown() which accounts for non-PBC
// sendneed = recvneed of neighbor on each side
// maxneed via Allreduce() of recvneed
int *periodicity = domain->periodicity;
int left,right;
if (layout == Comm::LAYOUT_UNIFORM) {
maxneed[0] = static_cast<int> (cutghost[0] * procgrid[0] / prd[0]) + 1;
maxneed[1] = static_cast<int> (cutghost[1] * procgrid[1] / prd[1]) + 1;
maxneed[2] = static_cast<int> (cutghost[2] * procgrid[2] / prd[2]) + 1;
if (domain->dimension == 2) maxneed[2] = 0;
if (!periodicity[0]) maxneed[0] = MIN(maxneed[0],procgrid[0]-1);
if (!periodicity[1]) maxneed[1] = MIN(maxneed[1],procgrid[1]-1);
if (!periodicity[2]) maxneed[2] = MIN(maxneed[2],procgrid[2]-1);
if (!periodicity[0]) {
recvneed[0][0] = MIN(maxneed[0],myloc[0]);
recvneed[0][1] = MIN(maxneed[0],procgrid[0]-myloc[0]-1);
left = myloc[0] - 1;
if (left < 0) left = procgrid[0] - 1;
sendneed[0][0] = MIN(maxneed[0],procgrid[0]-left-1);
right = myloc[0] + 1;
if (right == procgrid[0]) right = 0;
sendneed[0][1] = MIN(maxneed[0],right);
} else recvneed[0][0] = recvneed[0][1] =
sendneed[0][0] = sendneed[0][1] = maxneed[0];
if (!periodicity[1]) {
recvneed[1][0] = MIN(maxneed[1],myloc[1]);
recvneed[1][1] = MIN(maxneed[1],procgrid[1]-myloc[1]-1);
left = myloc[1] - 1;
if (left < 0) left = procgrid[1] - 1;
sendneed[1][0] = MIN(maxneed[1],procgrid[1]-left-1);
right = myloc[1] + 1;
if (right == procgrid[1]) right = 0;
sendneed[1][1] = MIN(maxneed[1],right);
} else recvneed[1][0] = recvneed[1][1] =
sendneed[1][0] = sendneed[1][1] = maxneed[1];
if (!periodicity[2]) {
recvneed[2][0] = MIN(maxneed[2],myloc[2]);
recvneed[2][1] = MIN(maxneed[2],procgrid[2]-myloc[2]-1);
left = myloc[2] - 1;
if (left < 0) left = procgrid[2] - 1;
sendneed[2][0] = MIN(maxneed[2],procgrid[2]-left-1);
right = myloc[2] + 1;
if (right == procgrid[2]) right = 0;
sendneed[2][1] = MIN(maxneed[2],right);
} else recvneed[2][0] = recvneed[2][1] =
sendneed[2][0] = sendneed[2][1] = maxneed[2];
} else {
recvneed[0][0] = updown(0,0,myloc[0],prd[0],periodicity[0],xsplit);
recvneed[0][1] = updown(0,1,myloc[0],prd[0],periodicity[0],xsplit);
left = myloc[0] - 1;
if (left < 0) left = procgrid[0] - 1;
sendneed[0][0] = updown(0,1,left,prd[0],periodicity[0],xsplit);
right = myloc[0] + 1;
if (right == procgrid[0]) right = 0;
sendneed[0][1] = updown(0,0,right,prd[0],periodicity[0],xsplit);
recvneed[1][0] = updown(1,0,myloc[1],prd[1],periodicity[1],ysplit);
recvneed[1][1] = updown(1,1,myloc[1],prd[1],periodicity[1],ysplit);
left = myloc[1] - 1;
if (left < 0) left = procgrid[1] - 1;
sendneed[1][0] = updown(1,1,left,prd[1],periodicity[1],ysplit);
right = myloc[1] + 1;
if (right == procgrid[1]) right = 0;
sendneed[1][1] = updown(1,0,right,prd[1],periodicity[1],ysplit);
if (domain->dimension == 3) {
recvneed[2][0] = updown(2,0,myloc[2],prd[2],periodicity[2],zsplit);
recvneed[2][1] = updown(2,1,myloc[2],prd[2],periodicity[2],zsplit);
left = myloc[2] - 1;
if (left < 0) left = procgrid[2] - 1;
sendneed[2][0] = updown(2,1,left,prd[2],periodicity[2],zsplit);
right = myloc[2] + 1;
if (right == procgrid[2]) right = 0;
sendneed[2][1] = updown(2,0,right,prd[2],periodicity[2],zsplit);
} else recvneed[2][0] = recvneed[2][1] =
sendneed[2][0] = sendneed[2][1] = 0;
int all[6];
MPI_Allreduce(&recvneed[0][0],all,6,MPI_INT,MPI_MAX,world);
maxneed[0] = MAX(all[0],all[1]);
maxneed[1] = MAX(all[2],all[3]);
maxneed[2] = MAX(all[4],all[5]);
}
// allocate comm memory
nswap = 2 * (maxneed[0]+maxneed[1]+maxneed[2]);
if (nswap > maxswap) grow_swap(nswap);
// setup parameters for each exchange:
// sendproc = proc to send to at each swap
// recvproc = proc to recv from at each swap
// for mode SINGLE:
// slablo/slabhi = boundaries for slab of atoms to send at each swap
// use -BIG/midpt/BIG to ensure all atoms included even if round-off occurs
// if round-off, atoms recvd across PBC can be < or > than subbox boundary
// note that borders() only loops over subset of atoms during each swap
// treat all as PBC here, non-PBC is handled in borders() via r/s need[][]
// for mode MULTI:
// multilo/multihi is same, with slablo/slabhi for each atom type
// pbc_flag: 0 = nothing across a boundary, 1 = something across a boundary
// pbc = -1/0/1 for PBC factor in each of 3/6 orthogonal/triclinic dirs
// for triclinic, slablo/hi and pbc_border will be used in lamda (0-1) coords
// 1st part of if statement is sending to the west/south/down
// 2nd part of if statement is sending to the east/north/up
int dim,ineed;
int iswap = 0;
for (dim = 0; dim < 3; dim++) {
for (ineed = 0; ineed < 2*maxneed[dim]; ineed++) {
pbc_flag[iswap] = 0;
pbc[iswap][0] = pbc[iswap][1] = pbc[iswap][2] =
pbc[iswap][3] = pbc[iswap][4] = pbc[iswap][5] = 0;
if (ineed % 2 == 0) {
sendproc[iswap] = procneigh[dim][0];
recvproc[iswap] = procneigh[dim][1];
if (mode == Comm::SINGLE) {
if (ineed < 2) slablo[iswap] = -BIG;
else slablo[iswap] = 0.5 * (sublo[dim] + subhi[dim]);
slabhi[iswap] = sublo[dim] + cutghost[dim];
} else if (mode == Comm::MULTI) {
for (i = 0; i < ncollections; i++) {
if (ineed < 2) multilo[iswap][i] = -BIG;
else multilo[iswap][i] = 0.5 * (sublo[dim] + subhi[dim]);
multihi[iswap][i] = sublo[dim] + cutghostmulti[i][dim];
}
} else {
for (i = 1; i <= ntypes; i++) {
if (ineed < 2) multioldlo[iswap][i] = -BIG;
else multioldlo[iswap][i] = 0.5 * (sublo[dim] + subhi[dim]);
multioldhi[iswap][i] = sublo[dim] + cutghostmultiold[i][dim];
}
}
if (myloc[dim] == 0) {
pbc_flag[iswap] = 1;
pbc[iswap][dim] = 1;
if (triclinic) {
if (dim == 1) pbc[iswap][5] = 1;
else if (dim == 2) pbc[iswap][4] = pbc[iswap][3] = 1;
}
}
} else {
sendproc[iswap] = procneigh[dim][1];
recvproc[iswap] = procneigh[dim][0];
if (mode == Comm::SINGLE) {
slablo[iswap] = subhi[dim] - cutghost[dim];
if (ineed < 2) slabhi[iswap] = BIG;
else slabhi[iswap] = 0.5 * (sublo[dim] + subhi[dim]);
} else if (mode == Comm::MULTI) {
for (i = 0; i < ncollections; i++) {
multilo[iswap][i] = subhi[dim] - cutghostmulti[i][dim];
if (ineed < 2) multihi[iswap][i] = BIG;
else multihi[iswap][i] = 0.5 * (sublo[dim] + subhi[dim]);
}
} else {
for (i = 1; i <= ntypes; i++) {
multioldlo[iswap][i] = subhi[dim] - cutghostmultiold[i][dim];
if (ineed < 2) multioldhi[iswap][i] = BIG;
else multioldhi[iswap][i] = 0.5 * (sublo[dim] + subhi[dim]);
}
}
if (myloc[dim] == procgrid[dim]-1) {
pbc_flag[iswap] = 1;
pbc[iswap][dim] = -1;
if (triclinic) {
if (dim == 1) pbc[iswap][5] = -1;
else if (dim == 2) pbc[iswap][4] = pbc[iswap][3] = -1;
}
}
}
iswap++;
}
}
}
/* ----------------------------------------------------------------------
walk up/down the extent of nearby processors in dim and dir
loc = myloc of proc to start at
dir = 0/1 = walk to left/right
do not cross non-periodic boundaries
is not called for z dim in 2d
return how many procs away are needed to encompass cutghost away from loc
------------------------------------------------------------------------- */
int CommBrick::updown(int dim, int dir, int loc, double prd, int periodicity, double *split)
{
int index,count;
double frac,delta;
if (dir == 0) {
frac = cutghost[dim]/prd;
index = loc - 1;
delta = 0.0;
count = 0;
while (delta < frac) {
if (index < 0) {
if (!periodicity) break;
index = procgrid[dim] - 1;
}
count++;
delta += split[index+1] - split[index];
index--;
}
} else {
frac = cutghost[dim]/prd;
index = loc + 1;
delta = 0.0;
count = 0;
while (delta < frac) {
if (index >= procgrid[dim]) {
if (!periodicity) break;
index = 0;
}
count++;
delta += split[index+1] - split[index];
index++;
}
}
return count;
}
/* ----------------------------------------------------------------------
forward communication of atom coords every timestep
other per-atom attributes may also be sent via pack/unpack routines
------------------------------------------------------------------------- */
void CommBrick::forward_comm(int /*dummy*/)
{
int n;
MPI_Request request;
AtomVec *avec = atom->avec;
double **x = atom->x;
double *buf;
// exchange data with another proc
// if other proc is self, just copy
// if comm_x_only set, exchange or copy directly to x, don't unpack
for (int iswap = 0; iswap < nswap; iswap++) {
if (sendproc[iswap] != me) {
if (comm_x_only) {
if (size_forward_recv[iswap]) {
buf = x[firstrecv[iswap]];
MPI_Irecv(buf,size_forward_recv[iswap],MPI_DOUBLE,recvproc[iswap],0,world,&request);
}
n = avec->pack_comm(sendnum[iswap],sendlist[iswap],buf_send,pbc_flag[iswap],pbc[iswap]);
if (n) MPI_Send(buf_send,n,MPI_DOUBLE,sendproc[iswap],0,world);
if (size_forward_recv[iswap]) MPI_Wait(&request,MPI_STATUS_IGNORE);
} else if (ghost_velocity) {
if (size_forward_recv[iswap])
MPI_Irecv(buf_recv,size_forward_recv[iswap],MPI_DOUBLE,recvproc[iswap],0,world,&request);
n = avec->pack_comm_vel(sendnum[iswap],sendlist[iswap],buf_send,pbc_flag[iswap],pbc[iswap]);
if (n) MPI_Send(buf_send,n,MPI_DOUBLE,sendproc[iswap],0,world);
if (size_forward_recv[iswap]) MPI_Wait(&request,MPI_STATUS_IGNORE);
avec->unpack_comm_vel(recvnum[iswap],firstrecv[iswap],buf_recv);
} else {
if (size_forward_recv[iswap])
MPI_Irecv(buf_recv,size_forward_recv[iswap],MPI_DOUBLE,
recvproc[iswap],0,world,&request);
n = avec->pack_comm(sendnum[iswap],sendlist[iswap],buf_send,pbc_flag[iswap],pbc[iswap]);
if (n) MPI_Send(buf_send,n,MPI_DOUBLE,sendproc[iswap],0,world);
if (size_forward_recv[iswap]) MPI_Wait(&request,MPI_STATUS_IGNORE);
avec->unpack_comm(recvnum[iswap],firstrecv[iswap],buf_recv);
}
} else {
if (comm_x_only) {
if (sendnum[iswap])
avec->pack_comm(sendnum[iswap],sendlist[iswap],
x[firstrecv[iswap]],pbc_flag[iswap],pbc[iswap]);
} else if (ghost_velocity) {
avec->pack_comm_vel(sendnum[iswap],sendlist[iswap],buf_send,pbc_flag[iswap],pbc[iswap]);
avec->unpack_comm_vel(recvnum[iswap],firstrecv[iswap],buf_send);
} else {
avec->pack_comm(sendnum[iswap],sendlist[iswap],buf_send,pbc_flag[iswap],pbc[iswap]);
avec->unpack_comm(recvnum[iswap],firstrecv[iswap],buf_send);
}
}
}
}
/* ----------------------------------------------------------------------
reverse communication of forces on atoms every timestep
other per-atom attributes may also be sent via pack/unpack routines
------------------------------------------------------------------------- */
void CommBrick::reverse_comm()
{
int n;
MPI_Request request;
AtomVec *avec = atom->avec;
double **f = atom->f;
double *buf;
// exchange data with another proc
// if other proc is self, just copy
// if comm_f_only set, exchange or copy directly from f, don't pack
for (int iswap = nswap-1; iswap >= 0; iswap--) {
if (sendproc[iswap] != me) {
if (comm_f_only) {
if (size_reverse_recv[iswap])
MPI_Irecv(buf_recv,size_reverse_recv[iswap],MPI_DOUBLE,sendproc[iswap],0,world,&request);
if (size_reverse_send[iswap]) {
buf = f[firstrecv[iswap]];
MPI_Send(buf,size_reverse_send[iswap],MPI_DOUBLE,recvproc[iswap],0,world);
}
if (size_reverse_recv[iswap]) MPI_Wait(&request,MPI_STATUS_IGNORE);
} else {
if (size_reverse_recv[iswap])
MPI_Irecv(buf_recv,size_reverse_recv[iswap],MPI_DOUBLE,sendproc[iswap],0,world,&request);
n = avec->pack_reverse(recvnum[iswap],firstrecv[iswap],buf_send);
if (n) MPI_Send(buf_send,n,MPI_DOUBLE,recvproc[iswap],0,world);
if (size_reverse_recv[iswap]) MPI_Wait(&request,MPI_STATUS_IGNORE);
}
avec->unpack_reverse(sendnum[iswap],sendlist[iswap],buf_recv);
} else {
if (comm_f_only) {
if (sendnum[iswap])
avec->unpack_reverse(sendnum[iswap],sendlist[iswap],f[firstrecv[iswap]]);
} else {
avec->pack_reverse(recvnum[iswap],firstrecv[iswap],buf_send);
avec->unpack_reverse(sendnum[iswap],sendlist[iswap],buf_send);
}
}
}
}
/* ----------------------------------------------------------------------
exchange: move atoms to correct processors
atoms exchanged with all 6 stencil neighbors
send out atoms that have left my box, receive ones entering my box
atoms will be lost if not inside a stencil proc's box
can happen if atom moves outside of non-periodic boundary
or if atom moves more than one proc away
this routine called before every reneighboring
for triclinic, atoms must be in lamda coords (0-1) before exchange is called
------------------------------------------------------------------------- */
void CommBrick::exchange()
{
int i,m,nsend,nrecv,nrecv1,nrecv2,nlocal;
double lo,hi,value;
double **x;
double *sublo,*subhi;
MPI_Request request;
AtomVec *avec = atom->avec;
// clear global->local map for owned and ghost atoms
// b/c atoms migrate to new procs in exchange() and
// new ghosts are created in borders()
// map_set() is done at end of borders()
// clear ghost count and any ghost bonus data internal to AtomVec
if (map_style != Atom::MAP_NONE) atom->map_clear();
atom->nghost = 0;
atom->avec->clear_bonus();
// ensure send buf has extra space for a single atom
// only need to reset if a fix can dynamically add to size of single atom
if (maxexchange_fix_dynamic) {
int bufextra_old = bufextra;
init_exchange();
if (bufextra > bufextra_old) grow_send(maxsend+bufextra,2);
}
// subbox bounds for orthogonal or triclinic
if (triclinic == 0) {
sublo = domain->sublo;
subhi = domain->subhi;
} else {
sublo = domain->sublo_lamda;
subhi = domain->subhi_lamda;
}
// loop over dimensions
int dimension = domain->dimension;
for (int dim = 0; dim < dimension; dim++) {
// fill buffer with atoms leaving my box, using < and >=
// when atom is deleted, fill it in with last atom
x = atom->x;
lo = sublo[dim];
hi = subhi[dim];
nlocal = atom->nlocal;
i = nsend = 0;
while (i < nlocal) {
if (x[i][dim] < lo || x[i][dim] >= hi) {
if (nsend > maxsend) grow_send(nsend,1);
nsend += avec->pack_exchange(i,&buf_send[nsend]);
avec->copy(nlocal-1,i,1);
nlocal--;
} else i++;
}
atom->nlocal = nlocal;
// send/recv atoms in both directions
// send size of message first so receiver can realloc buf_recv if needed
// if 1 proc in dimension, no send/recv
// set nrecv = 0 so buf_send atoms will be lost
// if 2 procs in dimension, single send/recv
// if more than 2 procs in dimension, send/recv to both neighbors
if (procgrid[dim] == 1) nrecv = 0;
else {
MPI_Sendrecv(&nsend,1,MPI_INT,procneigh[dim][0],0,
&nrecv1,1,MPI_INT,procneigh[dim][1],0,world,MPI_STATUS_IGNORE);
nrecv = nrecv1;
if (procgrid[dim] > 2) {
MPI_Sendrecv(&nsend,1,MPI_INT,procneigh[dim][1],0,
&nrecv2,1,MPI_INT,procneigh[dim][0],0,world,MPI_STATUS_IGNORE);
nrecv += nrecv2;
}
if (nrecv > maxrecv) grow_recv(nrecv);
MPI_Irecv(buf_recv,nrecv1,MPI_DOUBLE,procneigh[dim][1],0,world,&request);
MPI_Send(buf_send,nsend,MPI_DOUBLE,procneigh[dim][0],0,world);
MPI_Wait(&request,MPI_STATUS_IGNORE);
if (procgrid[dim] > 2) {
MPI_Irecv(&buf_recv[nrecv1],nrecv2,MPI_DOUBLE,procneigh[dim][0],0,world,&request);
MPI_Send(buf_send,nsend,MPI_DOUBLE,procneigh[dim][1],0,world);
MPI_Wait(&request,MPI_STATUS_IGNORE);
}
}
// check incoming atoms to see if they are in my box
// if so, add to my list
// box check is only for this dimension,
// atom may be passed to another proc in later dims
m = 0;
while (m < nrecv) {
value = buf_recv[m+dim+1];
if (value >= lo && value < hi) m += avec->unpack_exchange(&buf_recv[m]);
else m += static_cast<int> (buf_recv[m]);
}
}
if (atom->firstgroupname) atom->first_reorder();
}
/* ----------------------------------------------------------------------
borders: list nearby atoms to send to neighboring procs at every timestep
one list is created for every swap that will be made
as list is made, actually do swaps
this does equivalent of a forward_comm(), so don't need to explicitly
call forward_comm() on reneighboring timestep
this routine is called before every reneighboring
for triclinic, atoms must be in lamda coords (0-1) before borders is called
------------------------------------------------------------------------- */
void CommBrick::borders()
{
int i,n,itype,icollection,iswap,dim,ineed,twoneed;
int nsend,nrecv,sendflag,nfirst,nlast,ngroup,nprior;
double lo,hi;
int *type;
int *collection;
double **x;
double *buf,*mlo,*mhi;
MPI_Request request;
AtomVec *avec = atom->avec;
// After exchanging/sorting, need to reconstruct collection array for border communication
if (mode == Comm::MULTI) neighbor->build_collection(0);
// do swaps over all 3 dimensions
iswap = 0;
smax = rmax = 0;
for (dim = 0; dim < 3; dim++) {
nlast = 0;
twoneed = 2*maxneed[dim];
for (ineed = 0; ineed < twoneed; ineed++) {
// find atoms within slab boundaries lo/hi using <= and >=
// check atoms between nfirst and nlast
// for first swaps in a dim, check owned and ghost
// for later swaps in a dim, only check newly arrived ghosts
// store sent atom indices in sendlist for use in future timesteps
x = atom->x;
if (mode == Comm::SINGLE) {
lo = slablo[iswap];
hi = slabhi[iswap];
} else if (mode == Comm::MULTI) {
collection = neighbor->collection;
mlo = multilo[iswap];
mhi = multihi[iswap];
} else {
type = atom->type;
mlo = multioldlo[iswap];
mhi = multioldhi[iswap];
}
if (ineed % 2 == 0) {
nfirst = nlast;
nlast = atom->nlocal + atom->nghost;
}
nsend = 0;
// sendflag = 0 if I do not send on this swap
// sendneed test indicates receiver no longer requires data
// e.g. due to non-PBC or non-uniform sub-domains
if (ineed/2 >= sendneed[dim][ineed % 2]) sendflag = 0;
else sendflag = 1;
// find send atoms according to SINGLE vs MULTI
// all atoms eligible versus only atoms in bordergroup
// can only limit loop to bordergroup for first sends (ineed < 2)
// on these sends, break loop in two: owned (in group) and ghost
if (sendflag) {
if (!bordergroup || ineed >= 2) {
if (mode == Comm::SINGLE) {
for (i = nfirst; i < nlast; i++)
if (x[i][dim] >= lo && x[i][dim] <= hi) {
if (nsend == maxsendlist[iswap]) grow_list(iswap,nsend);
sendlist[iswap][nsend++] = i;
}
} else if (mode == Comm::MULTI) {
for (i = nfirst; i < nlast; i++) {
icollection = collection[i];
if (x[i][dim] >= mlo[icollection] && x[i][dim] <= mhi[icollection]) {
if (nsend == maxsendlist[iswap]) grow_list(iswap,nsend);
sendlist[iswap][nsend++] = i;
}
}
} else {
for (i = nfirst; i < nlast; i++) {
itype = type[i];
if (x[i][dim] >= mlo[itype] && x[i][dim] <= mhi[itype]) {
if (nsend == maxsendlist[iswap]) grow_list(iswap,nsend);
sendlist[iswap][nsend++] = i;
}
}
}
} else {
if (mode == Comm::SINGLE) {
ngroup = atom->nfirst;
for (i = 0; i < ngroup; i++)
if (x[i][dim] >= lo && x[i][dim] <= hi) {
if (nsend == maxsendlist[iswap]) grow_list(iswap,nsend);
sendlist[iswap][nsend++] = i;
}
for (i = atom->nlocal; i < nlast; i++)
if (x[i][dim] >= lo && x[i][dim] <= hi) {
if (nsend == maxsendlist[iswap]) grow_list(iswap,nsend);
sendlist[iswap][nsend++] = i;
}
} else if (mode == Comm::MULTI) {
ngroup = atom->nfirst;
for (i = 0; i < ngroup; i++) {
icollection = collection[i];
if (x[i][dim] >= mlo[icollection] && x[i][dim] <= mhi[icollection]) {
if (nsend == maxsendlist[iswap]) grow_list(iswap,nsend);
sendlist[iswap][nsend++] = i;
}
}
for (i = atom->nlocal; i < nlast; i++) {
icollection = collection[i];
if (x[i][dim] >= mlo[icollection] && x[i][dim] <= mhi[icollection]) {
if (nsend == maxsendlist[iswap]) grow_list(iswap,nsend);
sendlist[iswap][nsend++] = i;
}
}
} else {
ngroup = atom->nfirst;
for (i = 0; i < ngroup; i++) {
itype = type[i];
if (x[i][dim] >= mlo[itype] && x[i][dim] <= mhi[itype]) {
if (nsend == maxsendlist[iswap]) grow_list(iswap,nsend);
sendlist[iswap][nsend++] = i;
}
}
for (i = atom->nlocal; i < nlast; i++) {
itype = type[i];
if (x[i][dim] >= mlo[itype] && x[i][dim] <= mhi[itype]) {
if (nsend == maxsendlist[iswap]) grow_list(iswap,nsend);
sendlist[iswap][nsend++] = i;
}
}
}
}
}
// pack up list of border atoms
if (nsend*size_border > maxsend) grow_send(nsend*size_border,0);
if (ghost_velocity)
n = avec->pack_border_vel(nsend,sendlist[iswap],buf_send,pbc_flag[iswap],pbc[iswap]);
else
n = avec->pack_border(nsend,sendlist[iswap],buf_send,pbc_flag[iswap],pbc[iswap]);
// swap atoms with other proc
// no MPI calls except SendRecv if nsend/nrecv = 0
// put incoming ghosts at end of my atom arrays
// if swapping with self, simply copy, no messages
if (sendproc[iswap] != me) {
MPI_Sendrecv(&nsend,1,MPI_INT,sendproc[iswap],0,
&nrecv,1,MPI_INT,recvproc[iswap],0,world,MPI_STATUS_IGNORE);
if (nrecv*size_border > maxrecv) grow_recv(nrecv*size_border);
if (nrecv) MPI_Irecv(buf_recv,nrecv*size_border,MPI_DOUBLE,
recvproc[iswap],0,world,&request);
if (n) MPI_Send(buf_send,n,MPI_DOUBLE,sendproc[iswap],0,world);
if (nrecv) MPI_Wait(&request,MPI_STATUS_IGNORE);
buf = buf_recv;
} else {
nrecv = nsend;
buf = buf_send;
}
// unpack buffer
if (ghost_velocity)
avec->unpack_border_vel(nrecv,atom->nlocal+atom->nghost,buf);
else
avec->unpack_border(nrecv,atom->nlocal+atom->nghost,buf);
// set all pointers & counters
smax = MAX(smax,nsend);
rmax = MAX(rmax,nrecv);
sendnum[iswap] = nsend;
recvnum[iswap] = nrecv;
size_forward_recv[iswap] = nrecv*size_forward;
size_reverse_send[iswap] = nrecv*size_reverse;
size_reverse_recv[iswap] = nsend*size_reverse;
firstrecv[iswap] = atom->nlocal + atom->nghost;
nprior = atom->nlocal + atom->nghost;
atom->nghost += nrecv;
if (neighbor->style == Neighbor::MULTI) neighbor->build_collection(nprior);
iswap++;
}
}
// For molecular systems we lose some bits for local atom indices due
// to encoding of special pairs in neighbor lists. Check for overflows.
if ((atom->molecular != Atom::ATOMIC)
&& ((atom->nlocal + atom->nghost) > NEIGHMASK))
error->one(FLERR,"Per-processor number of atoms is too large for "
"molecular neighbor lists");
// ensure send/recv buffers are long enough for all forward & reverse comm
int max = MAX(maxforward*smax,maxreverse*rmax);
if (max > maxsend) grow_send(max,0);
max = MAX(maxforward*rmax,maxreverse*smax);
if (max > maxrecv) grow_recv(max);
// reset global->local map
if (map_style != Atom::MAP_NONE) atom->map_set();
}
/* ----------------------------------------------------------------------
forward communication invoked by a Pair
nsize used only to set recv buffer limit
------------------------------------------------------------------------- */
void CommBrick::forward_comm(Pair *pair)
{
int iswap,n;
double *buf;
MPI_Request request;
int nsize = pair->comm_forward;
for (iswap = 0; iswap < nswap; iswap++) {
// pack buffer
n = pair->pack_forward_comm(sendnum[iswap],sendlist[iswap],buf_send,pbc_flag[iswap],pbc[iswap]);
// exchange with another proc
// if self, set recv buffer to send buffer
if (sendproc[iswap] != me) {
if (recvnum[iswap])
MPI_Irecv(buf_recv,nsize*recvnum[iswap],MPI_DOUBLE,recvproc[iswap],0,world,&request);
if (sendnum[iswap])
MPI_Send(buf_send,n,MPI_DOUBLE,sendproc[iswap],0,world);
if (recvnum[iswap]) MPI_Wait(&request,MPI_STATUS_IGNORE);
buf = buf_recv;
} else buf = buf_send;
// unpack buffer
pair->unpack_forward_comm(recvnum[iswap],firstrecv[iswap],buf);
}
}
/* ----------------------------------------------------------------------
reverse communication invoked by a Pair
nsize used only to set recv buffer limit
------------------------------------------------------------------------- */
void CommBrick::reverse_comm(Pair *pair)
{
int iswap,n;
double *buf;
MPI_Request request;
int nsize = MAX(pair->comm_reverse,pair->comm_reverse_off);
for (iswap = nswap-1; iswap >= 0; iswap--) {
// pack buffer
n = pair->pack_reverse_comm(recvnum[iswap],firstrecv[iswap],buf_send);
// exchange with another proc
// if self, set recv buffer to send buffer
if (sendproc[iswap] != me) {
if (sendnum[iswap])
MPI_Irecv(buf_recv,nsize*sendnum[iswap],MPI_DOUBLE,sendproc[iswap],0,world,&request);
if (recvnum[iswap])
MPI_Send(buf_send,n,MPI_DOUBLE,recvproc[iswap],0,world);
if (sendnum[iswap]) MPI_Wait(&request,MPI_STATUS_IGNORE);
buf = buf_recv;
} else buf = buf_send;
// unpack buffer
pair->unpack_reverse_comm(sendnum[iswap],sendlist[iswap],buf);
}
}
void CommBrick::forward_comm(PairE3GNNParallel *pair)
{
int iswap,n;
MPI_Request request;
const bool comm_preprocess_done = pair->is_comm_preprocess_done();
const int nsize = pair->get_x_dim();
float *buf_send_, *buf_recv_;
if(pair->use_cuda_mpi_()) {
DeviceBuffManager::getInstance().get_buffer(maxsend+bufextra, maxrecv,buf_send_, buf_recv_);
} else {
buf_send_ = reinterpret_cast<float*>(buf_send);
buf_recv_ = reinterpret_cast<float*>(buf_recv);
}
if(!comm_preprocess_done) {
pair->notify_proc_ids(sendproc, recvproc);
}
if (nswap > 6) error->all(FLERR,"PairE3GNNParallel: Cell size is too small. Please use a single GPU or make a supercell");
for (iswap = 0; iswap < nswap; iswap++) {
if(sendproc[iswap] == me) continue;
if(!comm_preprocess_done) {
pair->pack_forward_init(sendnum[iswap], sendlist[iswap], iswap);
pair->unpack_forward_init(recvnum[iswap], firstrecv[iswap], iswap);
} else {
n = pair->pack_forward_comm_gnn(buf_send_, iswap);
if (recvnum[iswap])
MPI_Irecv(buf_recv_,nsize*recvnum[iswap],MPI_FLOAT,recvproc[iswap],0,world,&request);
if (sendnum[iswap])
MPI_Send(buf_send_,n,MPI_FLOAT,sendproc[iswap],0,world);
if (recvnum[iswap]) MPI_Wait(&request,MPI_STATUS_IGNORE);
pair->unpack_forward_comm_gnn(buf_recv_, iswap);
}
}
}
void CommBrick::reverse_comm(PairE3GNNParallel *pair)
{
int iswap,n;
MPI_Request request;
const bool comm_preprocess_done = pair->is_comm_preprocess_done();
int nsize = pair->get_x_dim();
float *buf_send_, *buf_recv_;
if(pair->use_cuda_mpi_()) {
DeviceBuffManager::getInstance().get_buffer(maxsend+bufextra, maxrecv,buf_send_, buf_recv_);
} else {
buf_send_ = reinterpret_cast<float*>(buf_send);
buf_recv_ = reinterpret_cast<float*>(buf_recv);
}
for (iswap = nswap-1; iswap >= 0; iswap--) {
if(sendproc[iswap] == me) continue;
n = pair->pack_reverse_comm_gnn(buf_send_, iswap);
if (sendnum[iswap])
MPI_Irecv(buf_recv_,nsize*sendnum[iswap],MPI_FLOAT,sendproc[iswap],0,world,&request);
if (recvnum[iswap])
MPI_Send(buf_send_,n,MPI_FLOAT,recvproc[iswap],0,world);
if (sendnum[iswap]) MPI_Wait(&request,MPI_STATUS_IGNORE);
pair->unpack_reverse_comm_gnn(buf_recv_, iswap);
}
}
/* ----------------------------------------------------------------------
forward communication invoked by a Bond
nsize used only to set recv buffer limit
------------------------------------------------------------------------- */
void CommBrick::forward_comm(Bond *bond)
{
int iswap,n;
double *buf;
MPI_Request request;
int nsize = bond->comm_forward;
for (iswap = 0; iswap < nswap; iswap++) {
// pack buffer
n = bond->pack_forward_comm(sendnum[iswap],sendlist[iswap],buf_send,pbc_flag[iswap],pbc[iswap]);
// exchange with another proc
// if self, set recv buffer to send buffer
if (sendproc[iswap] != me) {
if (recvnum[iswap])
MPI_Irecv(buf_recv,nsize*recvnum[iswap],MPI_DOUBLE,recvproc[iswap],0,world,&request);
if (sendnum[iswap])
MPI_Send(buf_send,n,MPI_DOUBLE,sendproc[iswap],0,world);
if (recvnum[iswap]) MPI_Wait(&request,MPI_STATUS_IGNORE);
buf = buf_recv;
} else buf = buf_send;
// unpack buffer
bond->unpack_forward_comm(recvnum[iswap],firstrecv[iswap],buf);
}
}
/* ----------------------------------------------------------------------
reverse communication invoked by a Bond
nsize used only to set recv buffer limit
------------------------------------------------------------------------- */
void CommBrick::reverse_comm(Bond *bond)
{
int iswap,n;
double *buf;
MPI_Request request;
int nsize = MAX(bond->comm_reverse,bond->comm_reverse_off);
for (iswap = nswap-1; iswap >= 0; iswap--) {
// pack buffer
n = bond->pack_reverse_comm(recvnum[iswap],firstrecv[iswap],buf_send);
// exchange with another proc
// if self, set recv buffer to send buffer
if (sendproc[iswap] != me) {
if (sendnum[iswap])
MPI_Irecv(buf_recv,nsize*sendnum[iswap],MPI_DOUBLE,sendproc[iswap],0,world,&request);
if (recvnum[iswap])
MPI_Send(buf_send,n,MPI_DOUBLE,recvproc[iswap],0,world);
if (sendnum[iswap]) MPI_Wait(&request,MPI_STATUS_IGNORE);
buf = buf_recv;
} else buf = buf_send;
// unpack buffer
bond->unpack_reverse_comm(sendnum[iswap],sendlist[iswap],buf);
}
}
/* ----------------------------------------------------------------------
forward communication invoked by a Fix
size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_forward from Fix
size > 0 -> Fix passes max size per atom
the latter is only useful if Fix does several comm modes,
some are smaller than max stored in its comm_forward
------------------------------------------------------------------------- */
void CommBrick::forward_comm(Fix *fix, int size)
{
int iswap,n,nsize;
double *buf;
MPI_Request request;
if (size) nsize = size;
else nsize = fix->comm_forward;
for (iswap = 0; iswap < nswap; iswap++) {
// pack buffer
n = fix->pack_forward_comm(sendnum[iswap],sendlist[iswap],buf_send,pbc_flag[iswap],pbc[iswap]);
// exchange with another proc
// if self, set recv buffer to send buffer
if (sendproc[iswap] != me) {
if (recvnum[iswap])
MPI_Irecv(buf_recv,nsize*recvnum[iswap],MPI_DOUBLE,recvproc[iswap],0,world,&request);
if (sendnum[iswap])
MPI_Send(buf_send,n,MPI_DOUBLE,sendproc[iswap],0,world);
if (recvnum[iswap]) MPI_Wait(&request,MPI_STATUS_IGNORE);
buf = buf_recv;
} else buf = buf_send;
// unpack buffer
fix->unpack_forward_comm(recvnum[iswap],firstrecv[iswap],buf);
}
}
/* ----------------------------------------------------------------------
reverse communication invoked by a Fix
size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_forward from Fix
size > 0 -> Fix passes max size per atom
the latter is only useful if Fix does several comm modes,
some are smaller than max stored in its comm_forward
------------------------------------------------------------------------- */
void CommBrick::reverse_comm(Fix *fix, int size)
{
int iswap,n,nsize;
double *buf;
MPI_Request request;
if (size) nsize = size;
else nsize = fix->comm_reverse;
for (iswap = nswap-1; iswap >= 0; iswap--) {
// pack buffer
n = fix->pack_reverse_comm(recvnum[iswap],firstrecv[iswap],buf_send);
// exchange with another proc
// if self, set recv buffer to send buffer
if (sendproc[iswap] != me) {
if (sendnum[iswap])
MPI_Irecv(buf_recv,nsize*sendnum[iswap],MPI_DOUBLE,sendproc[iswap],0,world,&request);
if (recvnum[iswap])
MPI_Send(buf_send,n,MPI_DOUBLE,recvproc[iswap],0,world);
if (sendnum[iswap]) MPI_Wait(&request,MPI_STATUS_IGNORE);
buf = buf_recv;
} else buf = buf_send;
// unpack buffer
fix->unpack_reverse_comm(sendnum[iswap],sendlist[iswap],buf);
}
}
/* ----------------------------------------------------------------------
reverse communication invoked by a Fix with variable size data
query fix for pack size to ensure buf_send is big enough
handshake sizes before each Irecv/Send to ensure buf_recv is big enough
------------------------------------------------------------------------- */
void CommBrick::reverse_comm_variable(Fix *fix)
{
int iswap,nsend,nrecv;
double *buf;
MPI_Request request;
for (iswap = nswap-1; iswap >= 0; iswap--) {
// pack buffer
nsend = fix->pack_reverse_comm_size(recvnum[iswap],firstrecv[iswap]);
if (nsend > maxsend) grow_send(nsend,0);
nsend = fix->pack_reverse_comm(recvnum[iswap],firstrecv[iswap],buf_send);
// exchange with another proc
// if self, set recv buffer to send buffer
if (sendproc[iswap] != me) {
MPI_Sendrecv(&nsend,1,MPI_INT,recvproc[iswap],0,
&nrecv,1,MPI_INT,sendproc[iswap],0,world,MPI_STATUS_IGNORE);
if (sendnum[iswap]) {
if (nrecv > maxrecv) grow_recv(nrecv);
MPI_Irecv(buf_recv,maxrecv,MPI_DOUBLE,sendproc[iswap],0,world,&request);
}
if (recvnum[iswap])
MPI_Send(buf_send,nsend,MPI_DOUBLE,recvproc[iswap],0,world);
if (sendnum[iswap]) MPI_Wait(&request,MPI_STATUS_IGNORE);
buf = buf_recv;
} else buf = buf_send;
// unpack buffer
fix->unpack_reverse_comm(sendnum[iswap],sendlist[iswap],buf);
}
}
/* ----------------------------------------------------------------------
forward communication invoked by a Compute
nsize used only to set recv buffer limit
------------------------------------------------------------------------- */
void CommBrick::forward_comm(Compute *compute)
{
int iswap,n;
double *buf;
MPI_Request request;
int nsize = compute->comm_forward;
for (iswap = 0; iswap < nswap; iswap++) {
// pack buffer
n = compute->pack_forward_comm(sendnum[iswap],sendlist[iswap],
buf_send,pbc_flag[iswap],pbc[iswap]);
// exchange with another proc
// if self, set recv buffer to send buffer
if (sendproc[iswap] != me) {
if (recvnum[iswap])
MPI_Irecv(buf_recv,nsize*recvnum[iswap],MPI_DOUBLE,recvproc[iswap],0,world,&request);
if (sendnum[iswap])
MPI_Send(buf_send,n,MPI_DOUBLE,sendproc[iswap],0,world);
if (recvnum[iswap]) MPI_Wait(&request,MPI_STATUS_IGNORE);
buf = buf_recv;
} else buf = buf_send;
// unpack buffer
compute->unpack_forward_comm(recvnum[iswap],firstrecv[iswap],buf);
}
}
/* ----------------------------------------------------------------------
reverse communication invoked by a Compute
nsize used only to set recv buffer limit
------------------------------------------------------------------------- */
void CommBrick::reverse_comm(Compute *compute)
{
int iswap,n;
double *buf;
MPI_Request request;
int nsize = compute->comm_reverse;
for (iswap = nswap-1; iswap >= 0; iswap--) {
// pack buffer
n = compute->pack_reverse_comm(recvnum[iswap],firstrecv[iswap],buf_send);
// exchange with another proc
// if self, set recv buffer to send buffer
if (sendproc[iswap] != me) {
if (sendnum[iswap])
MPI_Irecv(buf_recv,nsize*sendnum[iswap],MPI_DOUBLE,sendproc[iswap],0,world,&request);
if (recvnum[iswap])
MPI_Send(buf_send,n,MPI_DOUBLE,recvproc[iswap],0,world);
if (sendnum[iswap]) MPI_Wait(&request,MPI_STATUS_IGNORE);
buf = buf_recv;
} else buf = buf_send;
// unpack buffer
compute->unpack_reverse_comm(sendnum[iswap],sendlist[iswap],buf);
}
}
/* ----------------------------------------------------------------------
forward communication invoked by a Dump
nsize used only to set recv buffer limit
------------------------------------------------------------------------- */
void CommBrick::forward_comm(Dump *dump)
{
int iswap,n;
double *buf;
MPI_Request request;
int nsize = dump->comm_forward;
for (iswap = 0; iswap < nswap; iswap++) {
// pack buffer
n = dump->pack_forward_comm(sendnum[iswap],sendlist[iswap],
buf_send,pbc_flag[iswap],pbc[iswap]);
// exchange with another proc
// if self, set recv buffer to send buffer
if (sendproc[iswap] != me) {
if (recvnum[iswap])
MPI_Irecv(buf_recv,nsize*recvnum[iswap],MPI_DOUBLE,recvproc[iswap],0,world,&request);
if (sendnum[iswap])
MPI_Send(buf_send,n,MPI_DOUBLE,sendproc[iswap],0,world);
if (recvnum[iswap]) MPI_Wait(&request,MPI_STATUS_IGNORE);
buf = buf_recv;
} else buf = buf_send;
// unpack buffer
dump->unpack_forward_comm(recvnum[iswap],firstrecv[iswap],buf);
}
}
/* ----------------------------------------------------------------------
reverse communication invoked by a Dump
nsize used only to set recv buffer limit
------------------------------------------------------------------------- */
void CommBrick::reverse_comm(Dump *dump)
{
int iswap,n;
double *buf;
MPI_Request request;
int nsize = dump->comm_reverse;
for (iswap = nswap-1; iswap >= 0; iswap--) {
// pack buffer
n = dump->pack_reverse_comm(recvnum[iswap],firstrecv[iswap],buf_send);
// exchange with another proc
// if self, set recv buffer to send buffer
if (sendproc[iswap] != me) {
if (sendnum[iswap])
MPI_Irecv(buf_recv,nsize*sendnum[iswap],MPI_DOUBLE,sendproc[iswap],0,world,&request);
if (recvnum[iswap])
MPI_Send(buf_send,n,MPI_DOUBLE,recvproc[iswap],0,world);
if (sendnum[iswap]) MPI_Wait(&request,MPI_STATUS_IGNORE);
buf = buf_recv;
} else buf = buf_send;
// unpack buffer
dump->unpack_reverse_comm(sendnum[iswap],sendlist[iswap],buf);
}
}
/* ----------------------------------------------------------------------
forward communication of N values in per-atom array
------------------------------------------------------------------------- */
void CommBrick::forward_comm_array(int nsize, double **array)
{
int i,j,k,m,iswap,last;
double *buf;
MPI_Request request;
// ensure send/recv bufs are big enough for nsize
// based on smax/rmax from most recent borders() invocation
if (nsize > maxforward) {
maxforward = nsize;
if (maxforward*smax > maxsend) grow_send(maxforward*smax,0);
if (maxforward*rmax > maxrecv) grow_recv(maxforward*rmax);
}
for (iswap = 0; iswap < nswap; iswap++) {
// pack buffer
m = 0;
for (i = 0; i < sendnum[iswap]; i++) {
j = sendlist[iswap][i];
for (k = 0; k < nsize; k++)
buf_send[m++] = array[j][k];
}
// exchange with another proc
// if self, set recv buffer to send buffer
if (sendproc[iswap] != me) {
if (recvnum[iswap])
MPI_Irecv(buf_recv,nsize*recvnum[iswap],MPI_DOUBLE,recvproc[iswap],0,world,&request);
if (sendnum[iswap])
MPI_Send(buf_send,nsize*sendnum[iswap],MPI_DOUBLE,sendproc[iswap],0,world);
if (recvnum[iswap]) MPI_Wait(&request,MPI_STATUS_IGNORE);
buf = buf_recv;
} else buf = buf_send;
// unpack buffer
m = 0;
last = firstrecv[iswap] + recvnum[iswap];
for (i = firstrecv[iswap]; i < last; i++)
for (k = 0; k < nsize; k++)
array[i][k] = buf[m++];
}
}
/* ----------------------------------------------------------------------
realloc the size of the send buffer as needed with BUFFACTOR and bufextra
flag = 0, don't need to realloc with copy, just free/malloc w/ BUFFACTOR
flag = 1, realloc with BUFFACTOR
flag = 2, free/malloc w/out BUFFACTOR
------------------------------------------------------------------------- */
void CommBrick::grow_send(int n, int flag)
{
if (flag == 0) {
maxsend = static_cast<int> (BUFFACTOR * n);
memory->destroy(buf_send);
memory->create(buf_send,maxsend+bufextra,"comm:buf_send");
} else if (flag == 1) {
maxsend = static_cast<int> (BUFFACTOR * n);
memory->grow(buf_send,maxsend+bufextra,"comm:buf_send");
} else {
memory->destroy(buf_send);
memory->grow(buf_send,maxsend+bufextra,"comm:buf_send");
}
}
/* ----------------------------------------------------------------------
free/malloc the size of the recv buffer as needed with BUFFACTOR
------------------------------------------------------------------------- */
void CommBrick::grow_recv(int n)
{
maxrecv = static_cast<int> (BUFFACTOR * n);
memory->destroy(buf_recv);
memory->create(buf_recv,maxrecv,"comm:buf_recv");
}
/* ----------------------------------------------------------------------
realloc the size of the iswap sendlist as needed with BUFFACTOR
------------------------------------------------------------------------- */
void CommBrick::grow_list(int iswap, int n)
{
maxsendlist[iswap] = static_cast<int> (BUFFACTOR * n);
memory->grow(sendlist[iswap],maxsendlist[iswap],"comm:sendlist[iswap]");
}
/* ----------------------------------------------------------------------
realloc the buffers needed for swaps
------------------------------------------------------------------------- */
void CommBrick::grow_swap(int n)
{
free_swap();
allocate_swap(n);
if (mode == Comm::MULTI) {
free_multi();
allocate_multi(n);
}
if (mode == Comm::MULTIOLD) {
free_multiold();
allocate_multiold(n);
}
sendlist = (int **)
memory->srealloc(sendlist,n*sizeof(int *),"comm:sendlist");
memory->grow(maxsendlist,n,"comm:maxsendlist");
for (int i = maxswap; i < n; i++) {
maxsendlist[i] = BUFMIN;
memory->create(sendlist[i],BUFMIN,"comm:sendlist[i]");
}
maxswap = n;
}
/* ----------------------------------------------------------------------
allocation of swap info
------------------------------------------------------------------------- */
void CommBrick::allocate_swap(int n)
{
memory->create(sendnum,n,"comm:sendnum");
memory->create(recvnum,n,"comm:recvnum");
memory->create(sendproc,n,"comm:sendproc");
memory->create(recvproc,n,"comm:recvproc");
memory->create(size_forward_recv,n,"comm:size");
memory->create(size_reverse_send,n,"comm:size");
memory->create(size_reverse_recv,n,"comm:size");
memory->create(slablo,n,"comm:slablo");
memory->create(slabhi,n,"comm:slabhi");
memory->create(firstrecv,n,"comm:firstrecv");
memory->create(pbc_flag,n,"comm:pbc_flag");
memory->create(pbc,n,6,"comm:pbc");
}
/* ----------------------------------------------------------------------
allocation of multi-collection swap info
------------------------------------------------------------------------- */
void CommBrick::allocate_multi(int n)
{
multilo = memory->create(multilo,n,ncollections,"comm:multilo");
multihi = memory->create(multihi,n,ncollections,"comm:multihi");
}
/* ----------------------------------------------------------------------
allocation of multi/old-type swap info
------------------------------------------------------------------------- */
void CommBrick::allocate_multiold(int n)
{
multioldlo = memory->create(multioldlo,n,atom->ntypes+1,"comm:multioldlo");
multioldhi = memory->create(multioldhi,n,atom->ntypes+1,"comm:multioldhi");
}
/* ----------------------------------------------------------------------
free memory for swaps
------------------------------------------------------------------------- */
void CommBrick::free_swap()
{
memory->destroy(sendnum);
memory->destroy(recvnum);
memory->destroy(sendproc);
memory->destroy(recvproc);
memory->destroy(size_forward_recv);
memory->destroy(size_reverse_send);
memory->destroy(size_reverse_recv);
memory->destroy(slablo);
memory->destroy(slabhi);
memory->destroy(firstrecv);
memory->destroy(pbc_flag);
memory->destroy(pbc);
}
/* ----------------------------------------------------------------------
free memory for multi-collection swaps
------------------------------------------------------------------------- */
void CommBrick::free_multi()
{
memory->destroy(multilo);
memory->destroy(multihi);
multilo = multihi = nullptr;
}
/* ----------------------------------------------------------------------
free memory for multi/old-type swaps
------------------------------------------------------------------------- */
void CommBrick::free_multiold()
{
memory->destroy(multioldlo);
memory->destroy(multioldhi);
multioldlo = multioldhi = nullptr;
}
/* ----------------------------------------------------------------------
extract data potentially useful to other classes
------------------------------------------------------------------------- */
void *CommBrick::extract(const char *str, int &dim)
{
dim = 0;
if (strcmp(str,"localsendlist") == 0) {
int i, iswap, isend;
dim = 1;
if (!localsendlist)
memory->create(localsendlist,atom->nlocal,"comm:localsendlist");
else
memory->grow(localsendlist,atom->nlocal,"comm:localsendlist");
for (i = 0; i < atom->nlocal; i++)
localsendlist[i] = 0;
for (iswap = 0; iswap < nswap; iswap++)
for (isend = 0; isend < sendnum[iswap]; isend++)
if (sendlist[iswap][isend] < atom->nlocal)
localsendlist[sendlist[iswap][isend]] = 1;
return (void *) localsendlist;
}
return nullptr;
}
/* ----------------------------------------------------------------------
return # of bytes of allocated memory
------------------------------------------------------------------------- */
double CommBrick::memory_usage()
{
double bytes = 0;
bytes += (double)nprocs * sizeof(int); // grid2proc
for (int i = 0; i < nswap; i++)
bytes += memory->usage(sendlist[i],maxsendlist[i]);
bytes += memory->usage(buf_send,maxsend+bufextra);
bytes += memory->usage(buf_recv,maxrecv);
return bytes;
}
/* -*- c++ -*- ----------------------------------------------------------
LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
https://www.lammps.org/, Sandia National Laboratories
LAMMPS development team: developers@lammps.org
Copyright (2003) Sandia Corporation. Under the terms of Contract
DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains
certain rights in this software. This software is distributed under
the GNU General Public License.
See the README file in the top-level LAMMPS directory.
------------------------------------------------------------------------- */
#ifndef LMP_COMM_BRICK_H
#define LMP_COMM_BRICK_H
#include "comm.h"
namespace LAMMPS_NS {
class CommBrick : public Comm {
public:
CommBrick(class LAMMPS *);
CommBrick(class LAMMPS *, class Comm *);
~CommBrick() override;
void init() override;
void setup() override; // setup 3d comm pattern
void forward_comm(int dummy = 0) override; // forward comm of atom coords
void reverse_comm() override; // reverse comm of forces
void exchange() override; // move atoms to new procs
void borders() override; // setup list of atoms to comm
void forward_comm(class Pair *) override; // forward comm from a Pair
void reverse_comm(class Pair *) override; // reverse comm from a Pair
void forward_comm(class Bond *) override; // forward comm from a Bond
void reverse_comm(class Bond *) override; // reverse comm from a Bond
void forward_comm(class Fix *, int size = 0) override; // forward comm from a Fix
void reverse_comm(class Fix *, int size = 0) override; // reverse comm from a Fix
void reverse_comm_variable(class Fix *) override; // variable size reverse comm from a Fix
void forward_comm(class Compute *) override; // forward from a Compute
void reverse_comm(class Compute *) override; // reverse from a Compute
void forward_comm(class Dump *) override; // forward comm from a Dump
void reverse_comm(class Dump *) override; // reverse comm from a Dump
void forward_comm_array(int, double **) override; // forward comm of array
void *extract(const char *, int &) override;
double memory_usage() override;
// patched from SevenNet //
void forward_comm(class PairE3GNNParallel *);
void reverse_comm(class PairE3GNNParallel *);
// patched from SevenNet //
protected:
int nswap; // # of swaps to perform = sum of maxneed
int recvneed[3][2]; // # of procs away I recv atoms from
int sendneed[3][2]; // # of procs away I send atoms to
int maxneed[3]; // max procs away any proc needs, per dim
int maxswap; // max # of swaps memory is allocated for
int *sendnum, *recvnum; // # of atoms to send/recv in each swap
int *sendproc, *recvproc; // proc to send/recv to/from at each swap
int *size_forward_recv; // # of values to recv in each forward comm
int *size_reverse_send; // # to send in each reverse comm
int *size_reverse_recv; // # to recv in each reverse comm
double *slablo, *slabhi; // bounds of slab to send at each swap
double **multilo, **multihi; // bounds of slabs for multi-collection swap
double **multioldlo, **multioldhi; // bounds of slabs for multi-type swap
double **cutghostmulti; // cutghost on a per-collection basis
double **cutghostmultiold; // cutghost on a per-type basis
int *pbc_flag; // general flag for sending atoms thru PBC
int **pbc; // dimension flags for PBC adjustments
int *firstrecv; // where to put 1st recv atom in each swap
int **sendlist; // list of atoms to send in each swap
int *localsendlist; // indexed list of local sendlist atoms
int *maxsendlist; // max size of send list for each swap
double *buf_send; // send buffer for all comm
double *buf_recv; // recv buffer for all comm
int maxsend, maxrecv; // current size of send/recv buffer
int smax, rmax; // max size in atoms of single borders send/recv
// NOTE: init_buffers is called from a constructor and must not be made virtual
void init_buffers();
int updown(int, int, int, double, int, double *);
// compare cutoff to procs
virtual void grow_send(int, int); // reallocate send buffer
virtual void grow_recv(int); // free/allocate recv buffer
virtual void grow_list(int, int); // reallocate one sendlist
virtual void grow_swap(int); // grow swap, multi, and multi/old arrays
virtual void allocate_swap(int); // allocate swap arrays
virtual void allocate_multi(int); // allocate multi arrays
virtual void allocate_multiold(int); // allocate multi/old arrays
virtual void free_swap(); // free swap arrays
virtual void free_multi(); // free multi arrays
virtual void free_multiold(); // free multi/old arrays
};
} // namespace LAMMPS_NS
#endif
/* ----------------------------------------------------------------------
LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
https://www.lammps.org/, Sandia National Laboratories
LAMMPS development team: developers@lammps.org
Copyright (2003) Sandia Corporation. Under the terms of Contract
DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains
certain rights in this software. This software is distributed under
the GNU General Public License.
See the README file in the top-level LAMMPS directory.
------------------------------------------------------------------------- */
/* ----------------------------------------------------------------------
Contributing author: Gijin Kim, Hyungmin An (SNU)
------------------------------------------------------------------------- */
#include "pair_d3.h"
using namespace LAMMPS_NS;
/* --------- Macros for CUDA error handling --------- */
#define START_CUDA_TIMER() \
cudaEvent_t start, stop; \
cudaEventCreate(&start); \
cudaEventCreate(&stop); \
cudaEventRecord(start);
#define STOP_CUDA_TIMER(tag) \
cudaEventRecord(stop); \
cudaEventSynchronize(stop); \
float msec = 0; \
cudaEventElapsedTime(&msec, start, stop); \
printf("Elapsed time for %s: %f ms\n", tag, msec); \
cudaEventDestroy(start); \
cudaEventDestroy(stop);
#define CHECK_CUDA(call) do { \
cudaError_t status_ = call; \
if (status_ != cudaSuccess) { \
fprintf(stderr, "CUDA Error (%s:%d) -> %s: %s\n", __FILE__, __LINE__, \
cudaGetErrorName(status_), cudaGetErrorString(status_)); \
exit(EXIT_FAILURE); \
} \
} while (0)
#define CHECK_CUDA_ERROR() do { \
cudaDeviceSynchronize(); \
cudaError_t status_ = cudaGetLastError(); \
if (status_ != cudaSuccess) { \
fprintf(stderr, "CUDA Error (%s:%d) -> %s: %s\n", __FILE__, __LINE__, \
cudaGetErrorName(status_), cudaGetErrorString(status_)); \
exit(EXIT_FAILURE); \
} \
} while (0)
#define CHECK_CUDA_DEVICES() do { \
int deviceCount = 0; \
if (cudaGetDeviceCount(&deviceCount) != cudaSuccess || deviceCount == 0) { \
fprintf(stderr, "CUDA Error (%s:%d) -> No CUDA devices found\n", \
__FILE__, __LINE__); \
exit(EXIT_FAILURE); \
} \
} while(0)
/* --------- Macros for CUDA error handling --------- */
/* --------- Math functions for CUDA compatibility --------- */
inline __host__ __device__ void ij_at_linij(int linij, int &i, int &j) {
i = static_cast<int>((sqrt(1 + 8 * linij) - 1) / 2);
j = linij - i * (i + 1) / 2;
} // unroll the triangular loop
inline __host__ __device__ float lensq3(const float *v)
{
return v[0] * v[0] + v[1] * v[1] + v[2] * v[2];
} // from MathExtra::lensq3
/* --------- Math functions for CUDA compatibility --------- */
/* ----------------------------------------------------------------------
Constructor (Required)
------------------------------------------------------------------------- */
PairD3::PairD3(LAMMPS* lmp) : Pair(lmp) {
single_enable = 0; // potential is not pair-wise additive.
restartinfo = 0; // Many-body potentials are usually not
// written to binary restart files.
one_coeff = 1; // Many-body potnetials typically read all
// parameters from a file, so only one
// pair_coeff statement is needed.
manybody_flag = 1;
no_virial_fdotr_compute = 1;
}
/* ----------------------------------------------------------------------
Destructor (Required)
------------------------------------------------------------------------- */
PairD3::~PairD3() {
if (allocated) {
int n = atom->natoms;
int np1 = atom->ntypes + 1;
int vdw_range_x = 2 * rep_vdw[0] + 1;
int vdw_range_y = 2 * rep_vdw[1] + 1;
int vdw_range_z = 2 * rep_vdw[2] + 1;
int cn_range_x = 2 * rep_cn[0] + 1;
int cn_range_y = 2 * rep_cn[1] + 1;
int cn_range_z = 2 * rep_cn[2] + 1;
for (int i = 0; i < np1; i++) { cudaFree(setflag[i]); }; cudaFree(setflag);
for (int i = 0; i < np1; i++) { cudaFree(cutsq[i]); }; cudaFree(cutsq);
cudaFree(r2r4);
cudaFree(rcov);
cudaFree(mxc);
for (int i = 0; i < np1; i++) { cudaFree(r0ab[i]); }; cudaFree(r0ab);
for (int i = 0; i < np1; i++) {
for (int j = 0; j < np1; j++) {
for (int k = 0; k < MAXC; k++) {
for (int l = 0; l < MAXC; l++) {
cudaFree(c6ab[i][j][k][l]);
}
cudaFree(c6ab[i][j][k]);
}
cudaFree(c6ab[i][j]);
}
cudaFree(c6ab[i]);
}
cudaFree(c6ab);
cudaFree(lat_v_1);
cudaFree(lat_v_2);
cudaFree(lat_v_3);
cudaFree(rep_vdw);
cudaFree(rep_cn);
cudaFree(cn);
for (int i = 0; i < n; i++) { cudaFree(x[i]); }; cudaFree(x);
cudaFree(dc6i);
for (int i = 0; i < n; i++) { cudaFree(f[i]); }; cudaFree(f);
for (int i = 0; i < 3; i++) { cudaFree(sigma[i]); }; cudaFree(sigma);
cudaFree(dc6_iji_tot);
cudaFree(dc6_ijj_tot);
cudaFree(c6_ij_tot);
for (int i = 0; i < vdw_range_x; i++) {
for (int j = 0; j < vdw_range_y; j++) {
for (int k = 0; k < vdw_range_z; k++) {
cudaFree(tau_vdw[i][j][k]);
}
cudaFree(tau_vdw[i][j]);
}
cudaFree(tau_vdw[i]);
}
cudaFree(tau_vdw);
for (int i = 0; i < cn_range_x; i++) {
for (int j = 0; j < cn_range_y; j++) {
for (int k = 0; k < cn_range_z; k++) {
cudaFree(tau_cn[i][j][k]);
}
cudaFree(tau_cn[i][j]);
}
cudaFree(tau_cn[i]);
}
cudaFree(tau_cn);
cudaFree(tau_idx_vdw);
cudaFree(tau_idx_cn);
cudaFree(atomtype);
cudaFree(disp);
}
}
/* ----------------------------------------------------------------------
Allocate all arrays (Required)
------------------------------------------------------------------------- */
void PairD3::allocate() {
CHECK_CUDA_DEVICES();
allocated = 1;
/* atom->ntypes : # of elements; element index starts from 1 */
int n = atom->natoms;
int np1 = atom->ntypes + 1;
n_save = n;
cudaMallocManaged(&setflag, np1 * sizeof(int*)); for (int i = 0; i < np1; i++) { cudaMallocManaged(&setflag[i], np1 * sizeof(int)); }
cudaMallocManaged(&cutsq, np1 * sizeof(double*)); for (int i = 0; i < np1; i++) { cudaMallocManaged(&cutsq[i], np1 * sizeof(double)); }
cudaMallocManaged(&r2r4, np1 * sizeof(float));
cudaMallocManaged(&rcov, np1 * sizeof(float));
cudaMallocManaged(&mxc, np1 * sizeof(int));
cudaMallocManaged(&r0ab, np1 * sizeof(float*)); for (int i = 0; i < np1; i++) { cudaMallocManaged(&r0ab[i], np1 * sizeof(float)); }
cudaMallocManaged(&c6ab, np1 * sizeof(float****));
for (int i = 0; i < np1; i++) {
cudaMallocManaged(&c6ab[i], np1 * sizeof(float***));
for (int j = 0; j < np1; j++) {
cudaMallocManaged(&c6ab[i][j], MAXC * sizeof(float**));
for (int k = 0; k < MAXC; k++) {
cudaMallocManaged(&c6ab[i][j][k], MAXC * sizeof(float*));
for (int l = 0; l < MAXC; l++) {
cudaMallocManaged(&c6ab[i][j][k][l], 3 * sizeof(float));
}
}
}
}
cudaMallocManaged(&lat_v_1, 3 * sizeof(float));
cudaMallocManaged(&lat_v_2, 3 * sizeof(float));
cudaMallocManaged(&lat_v_3, 3 * sizeof(float));
cudaMallocManaged(&rep_vdw, 3 * sizeof(int));
cudaMallocManaged(&rep_cn, 3 * sizeof(int));
cudaMallocManaged(&sigma, 3 * sizeof(double*)); for (int i = 0; i < 3; i++) { cudaMallocManaged(&sigma[i], 3 * sizeof(double)); }
cudaMallocManaged(&cn, n * sizeof(double));
cudaMallocManaged(&x, n * sizeof(float*)); for (int i = 0; i < n; i++) { cudaMallocManaged(&x[i], 3 * sizeof(float)); }
cudaMallocManaged(&dc6i, n * sizeof(double));
cudaMallocManaged(&f, n * sizeof(double*)); for (int i = 0; i < n; i++) { cudaMallocManaged(&f[i], 3 * sizeof(double)); }
// Initialization
// Initialize for lattice -> set_lattice_vectors()
tau_idx_vdw_total_size = -1;
tau_idx_cn_total_size = -1;
for (int i = 0; i < 3; i++) {
rep_vdw[i] = -1;
rep_cn[i] = -1;
}
for (int i = 1; i < np1; i++) {
for (int j = 1; j < np1; j++) {
setflag[i][j] = 0;
}
}
for (int idx1 = 0; idx1 < np1; idx1++) {
for (int idx2 = 0; idx2 < np1; idx2++) {
for (int idx3 = 0; idx3 < MAXC; idx3++) {
for (int idx4 = 0; idx4 < MAXC; idx4++) {
for (int idx5 = 0; idx5 < 3; idx5++) {
c6ab[idx1][idx2][idx3][idx4][idx5] = -1;
}
}
}
}
}
int n_ij_combination = n * (n + 1) / 2;
cudaMallocManaged(&dc6_iji_tot, n_ij_combination * sizeof(float));
cudaMallocManaged(&dc6_ijj_tot, n_ij_combination * sizeof(float));
cudaMallocManaged(&c6_ij_tot, n_ij_combination * sizeof(float));
cudaMallocManaged(&atomtype, n * sizeof(int));
cudaMallocManaged(&disp, sizeof(double));
}
/* ----------------------------------------------------------------------
Settings : read from pair_style (Required) -> pair_style d3 vdw_sq cn_sq damp_name func_name
------------------------------------------------------------------------- */
void PairD3::settings(int narg, char **arg) {
if (narg != 4) {
error->all(FLERR,
"Pair_style d3 needs Four arguments:\n"
"\t rthr: cutoff radius for dispersion interaction (a.u.^2)\n"
"\t cnthr: cutoff raius for coordination number (a.u.^2)\n"
"\t damping: name of the damping function (e.g., damp_zero, damp_bj)\n"
"\t functional: name of the functional (e.g., pbe, b3-lyp)\n"
);
}
rthr = utils::numeric(FLERR, arg[0], false, lmp);
cnthr = utils::numeric(FLERR, arg[1], false, lmp);
std::map<std::string, int> commandMap = {
{"damp_zero", 0}, {"damp_bj", 1}, {"damp_zerom", 2}, {"damp_bjm", 3},
};
if (commandMap.find(arg[2]) == commandMap.end()) {
error->all(FLERR, "Unknown damping function");
}
damping = commandMap[arg[2]];
functional = arg[3];
setfuncpar();
}
/* ----------------------------------------------------------------------
finds atomic number (used in PairD3::coeff)
------------------------------------------------------------------------- */
int PairD3::find_atomic_number(std::string& key) {
std::transform(key.begin(), key.end(), key.begin(), ::tolower);
if (key.length() == 1) { key += " "; }
key.resize(2);
std::vector<std::string> element_table = {
"h ","he",
"li","be","b ","c ","n ","o ","f ","ne",
"na","mg","al","si","p ","s ","cl","ar",
"k ","ca","sc","ti","v ","cr","mn","fe","co","ni","cu",
"zn","ga","ge","as","se","br","kr",
"rb","sr","y ","zr","nb","mo","tc","ru","rh","pd","ag",
"cd","in","sn","sb","te","i ","xe",
"cs","ba","la","ce","pr","nd","pm","sm","eu","gd","tb","dy",
"ho","er","tm","yb","lu","hf","ta","w ","re","os","ir","pt",
"au","hg","tl","pb","bi","po","at","rn",
"fr","ra","ac","th","pa","u ","np","pu"
};
for (size_t i = 0; i < element_table.size(); ++i) {
if (element_table[i] == key) {
int atomic_number = i + 1;
return atomic_number;
}
}
// if not the case
return -1;
}
/* ----------------------------------------------------------------------
Check whether an integer value in an integer array (used in PairD3::coeff)
------------------------------------------------------------------------- */
int PairD3::is_int_in_array(int arr[], int size, int value) {
for (int i = 0; i < size; i++) {
if (arr[i] == value) { return i; } // returns the index
}
return -1;
}
/* ----------------------------------------------------------------------
Read r0ab values from the table (used in PairD3::coeff)
------------------------------------------------------------------------- */
void PairD3::read_r0ab(int* atomic_numbers, int ntypes) {
const double r0ab_table[94][94] = R0AB_TABLE;
for (int i = 1; i <= ntypes; i++) {
for (int j = 1; j <= ntypes; j++) {
r0ab[i][j] = r0ab_table[atomic_numbers[i-1]-1][atomic_numbers[j-1]-1] / AU_TO_ANG;
}
}
}
/* ----------------------------------------------------------------------
Get atom pair indices and grid indices (used in PairD3::read_c6ab)
------------------------------------------------------------------------- */
void PairD3::get_limit_in_pars_array(int& idx_atom_1, int& idx_atom_2, int& idx_i, int& idx_j) {
const int shift = 100;
idx_i = (idx_atom_1 - 1) / shift + 1;
idx_j = (idx_atom_2 - 1) / shift + 1;
idx_atom_1 = (idx_atom_1 - 1) % shift + 1;
idx_atom_2 = (idx_atom_2 - 1) % shift + 1;
// the code above replaces the code below
//idx_i = 1;
//idx_j = 1;
//int shift = 100;
//while (idx_atom_1 > shift) { idx_atom_1 -= shift; idx_i++; }
//while (idx_atom_2 > shift) { idx_atom_2 -= shift; idx_j++; }
}
/* ----------------------------------------------------------------------
Read c6ab values from the table (used in PairD3::coeff)
------------------------------------------------------------------------- */
void PairD3::read_c6ab(int* atomic_numbers, int ntypes) {
for (int i = 1; i <= ntypes; i++) { mxc[i] = 0; }
int grid_i = 0, grid_j = 0;
const double c6ab_table[32385][5] = C6AB_TABLE;
for (int i = 0; i < 32385; i++) {
const double ref_c6 = c6ab_table[i][0];
int atom_number_1 = static_cast<int>(c6ab_table[i][1]);
int atom_number_2 = static_cast<int>(c6ab_table[i][2]);
get_limit_in_pars_array(atom_number_1, atom_number_2, grid_i, grid_j);
const int idx_atom_1 = is_int_in_array(atomic_numbers, ntypes, atom_number_1);
if (idx_atom_1 < 0) { continue; }
const int idx_atom_2 = is_int_in_array(atomic_numbers, ntypes, atom_number_2);
if (idx_atom_2 < 0) { continue; }
const double ref_cn1 = c6ab_table[i][3];
const double ref_cn2 = c6ab_table[i][4];
mxc[idx_atom_1 + 1] = std::max(mxc[idx_atom_1 + 1], grid_i);
mxc[idx_atom_2 + 1] = std::max(mxc[idx_atom_2 + 1], grid_j);
c6ab[idx_atom_1 + 1][idx_atom_2 + 1][grid_i - 1][grid_j - 1][0] = ref_c6;
c6ab[idx_atom_1 + 1][idx_atom_2 + 1][grid_i - 1][grid_j - 1][1] = ref_cn1;
c6ab[idx_atom_1 + 1][idx_atom_2 + 1][grid_i - 1][grid_j - 1][2] = ref_cn2;
c6ab[idx_atom_2 + 1][idx_atom_1 + 1][grid_j - 1][grid_i - 1][0] = ref_c6;
c6ab[idx_atom_2 + 1][idx_atom_1 + 1][grid_j - 1][grid_i - 1][1] = ref_cn2;
c6ab[idx_atom_2 + 1][idx_atom_1 + 1][grid_j - 1][grid_i - 1][2] = ref_cn1;
}
}
/* ----------------------------------------------------------------------
Set functional parameters (used in PairD3::coeff)
------------------------------------------------------------------------- */
void PairD3::setfuncpar_zero() {
s6 = 1.0;
alp = 14.0;
rs18 = 1.0;
// default def2-QZVP (almost basis set limit)
std::unordered_map<std::string, int> commandMap = {
{ "slater-dirac-exchange", 1}, { "b-lyp", 2 }, { "b-p", 3 }, { "b97-d", 4 }, { "revpbe", 5 },
{ "pbe", 6 }, { "pbesol", 7 }, { "rpw86-pbe", 8 }, { "rpbe", 9 }, { "tpss", 10 },
{ "b3-lyp", 11 }, { "pbe0", 12 }, { "hse06", 13 }, { "revpbe38", 14 }, { "pw6b95", 15 },
{ "tpss0", 16 }, { "b2-plyp", 17 }, { "pwpb95", 18 }, { "b2gp-plyp", 19 }, { "ptpss", 20 },
{ "hf", 21 }, { "mpwlyp", 22 }, { "bpbe", 23 }, { "bh-lyp", 24 }, { "tpssh", 25 },
{ "pwb6k", 26 }, { "b1b95", 27 }, { "bop", 28 }, { "o-lyp", 29 }, { "o-pbe", 30 },
{ "ssb", 31 }, { "revssb", 32 }, { "otpss", 33 }, { "b3pw91", 34 }, { "revpbe0", 35 },
{ "pbe38", 36 }, { "mpw1b95", 37 }, { "mpwb1k", 38 }, { "bmk", 39 }, { "cam-b3lyp", 40 },
{ "lc-wpbe", 41 }, { "m05", 42 }, { "m052x", 43 }, { "m06l", 44 }, { "m06", 45 },
{ "m062x", 46 }, { "m06hf", 47 }, { "hcth120", 48 }
};
int commandCode = commandMap[functional];
switch (commandCode) {
case 1: rs6 = 0.999; s18 = -1.957; rs18 = 0.697; break;
case 2: rs6 = 1.094; s18 = 1.682; break;
case 3: rs6 = 1.139; s18 = 1.683; break;
case 4: rs6 = 0.892; s18 = 0.909; break;
case 5: rs6 = 0.923; s18 = 1.010; break;
case 6: rs6 = 1.217; s18 = 0.722; break;
case 7: rs6 = 1.345; s18 = 0.612; break;
case 8: rs6 = 1.224; s18 = 0.901; break;
case 9: rs6 = 0.872; s18 = 0.514; break;
case 10: rs6 = 1.166; s18 = 1.105; break;
case 11: rs6 = 1.261; s18 = 1.703; break;
case 12: rs6 = 1.287; s18 = 0.928; break;
case 13: rs6 = 1.129; s18 = 0.109; break;
case 14: rs6 = 1.021; s18 = 0.862; break;
case 15: rs6 = 1.532; s18 = 0.862; break;
case 16: rs6 = 1.252; s18 = 1.242; break;
case 17: rs6 = 1.427; s18 = 1.022; s6 = 0.64; break;
case 18: rs6 = 1.557; s18 = 0.705; s6 = 0.82; break;
case 19: rs6 = 1.586; s18 = 0.760; s6 = 0.56; break;
case 20: rs6 = 1.541; s18 = 0.879; s6 = 0.75; break;
case 21: rs6 = 1.158; s18 = 1.746; break;
case 22: rs6 = 1.239; s18 = 1.098; break;
case 23: rs6 = 1.087; s18 = 2.033; break;
case 24: rs6 = 1.370; s18 = 1.442; break;
case 25: rs6 = 1.223; s18 = 1.219; break;
case 26: rs6 = 1.660; s18 = 0.550; break;
case 27: rs6 = 1.613; s18 = 1.868; break;
case 28: rs6 = 0.929; s18 = 1.975; break;
case 29: rs6 = 0.806; s18 = 1.764; break;
case 30: rs6 = 0.837; s18 = 2.055; break;
case 31: rs6 = 1.215; s18 = 0.663; break;
case 32: rs6 = 1.221; s18 = 0.560; break;
case 33: rs6 = 1.128; s18 = 1.494; break;
case 34: rs6 = 1.176; s18 = 1.775; break;
case 35: rs6 = 0.949; s18 = 0.792; break;
case 36: rs6 = 1.333; s18 = 0.998; break;
case 37: rs6 = 1.605; s18 = 1.118; break;
case 38: rs6 = 1.671; s18 = 1.061; break;
case 39: rs6 = 1.931; s18 = 2.168; break;
case 40: rs6 = 1.378; s18 = 1.217; break;
case 41: rs6 = 1.355; s18 = 1.279; break;
case 42: rs6 = 1.373; s18 = 0.595; break;
case 43: rs6 = 1.417; s18 = 0.000; break;
case 44: rs6 = 1.581; s18 = 0.000; break;
case 45: rs6 = 1.325; s18 = 0.000; break;
case 46: rs6 = 1.619; s18 = 0.000; break;
case 47: rs6 = 1.446; s18 = 0.000; break;
/* DFTB3(zeta = 4.0), old deprecated parameters; case ("dftb3"); rs6 = 1.235; s18 = 0.673; */
case 48: rs6 = 1.221; s18 = 1.206; break;
default:
error->all(FLERR, "Functional name unknown");
break;
}
}
void PairD3::setfuncpar_bj() {
s6 = 1.0;
alp = 14.0;
std::unordered_map<std::string, int> commandMap = {
{"b-p", 1}, {"b-lyp", 2}, {"revpbe", 3}, {"rpbe", 4}, {"b97-d", 5}, {"pbe", 6},
{"rpw86-pbe", 7}, {"b3-lyp", 8}, {"tpss", 9}, {"hf", 10}, {"tpss0", 11}, {"pbe0", 12},
{"hse06", 13}, {"revpbe38", 14}, {"pw6b95", 15}, {"b2-plyp", 16}, {"dsd-blyp", 17},
{"dsd-blyp-fc", 18}, {"bop", 19}, {"mpwlyp", 20}, {"o-lyp", 21}, {"pbesol", 22}, {"bpbe", 23},
{"opbe", 24}, {"ssb", 25}, {"revssb", 26}, {"otpss", 27}, {"b3pw91", 28}, {"bh-lyp", 29},
{"revpbe0", 30}, {"tpssh", 31}, {"mpw1b95", 32}, {"pwb6k", 33}, {"b1b95", 34}, {"bmk", 35},
{"cam-b3lyp", 36}, {"lc-wpbe", 37}, {"b2gp-plyp", 38}, {"ptpss", 39}, {"pwpb95", 40},
{"hf/mixed", 41}, {"hf/sv", 42}, {"hf/minis", 43}, {"b3-lyp/6-31gd", 44}, {"hcth120", 45},
{"pw1pw", 46}, {"pwgga", 47}, {"hsesol", 48}, {"hf3c", 49}, {"hf3cv", 50}, {"pbeh3c", 51},
{"pbeh-3c", 52}, {"wb97m", 53}
};
int commandCode = commandMap[functional];
switch (commandCode) {
case 1: rs6 = 0.3946; s18 = 3.2822; rs18 = 4.8516; break;
case 2: rs6 = 0.4298; s18 = 2.6996; rs18 = 4.2359; break;
case 3: rs6 = 0.5238; s18 = 2.3550; rs18 = 3.5016; break;
case 4: rs6 = 0.1820; s18 = 0.8318; rs18 = 4.0094; break;
case 5: rs6 = 0.5545; s18 = 2.2609; rs18 = 3.2297; break;
case 6: rs6 = 0.4289; s18 = 0.7875; rs18 = 4.4407; break;
case 7: rs6 = 0.4613; s18 = 1.3845; rs18 = 4.5062; break;
case 8: rs6 = 0.3981; s18 = 1.9889; rs18 = 4.4211; break;
case 9: rs6 = 0.4535; s18 = 1.9435; rs18 = 4.4752; break;
case 10: rs6 = 0.3385; s18 = 0.9171; rs18 = 2.8830; break;
case 11: rs6 = 0.3768; s18 = 1.2576; rs18 = 4.5865; break;
case 12: rs6 = 0.4145; s18 = 1.2177; rs18 = 4.8593; break;
case 13: rs6 = 0.383; s18 = 2.310; rs18 = 5.685; break;
case 14: rs6 = 0.4309; s18 = 1.4760; rs18 = 3.9446; break;
case 15: rs6 = 0.2076; s18 = 0.7257; rs18 = 6.3750; break;
case 16: rs6 = 0.3065; s18 = 0.9147; rs18 = 5.0570; break; s6 = 0.64;
case 17: rs6 = 0.0000; s18 = 0.2130; rs18 = 6.0519; s6 = 0.50; break;
case 18: rs6 = 0.0009; s18 = 0.2112; rs18 = 5.9807; s6 = 0.50; break;
case 19: rs6 = 0.4870; s18 = 3.2950; rs18 = 3.5043; break;
case 20: rs6 = 0.4831; s18 = 2.0077; rs18 = 4.5323; break;
case 21: rs6 = 0.5299; s18 = 2.6205; rs18 = 2.8065; break;
case 22: rs6 = 0.4466; s18 = 2.9491; rs18 = 6.1742; break;
case 23: rs6 = 0.4567; s18 = 4.0728; rs18 = 4.3908; break;
case 24: rs6 = 0.5512; s18 = 3.3816; rs18 = 2.9444; break;
case 25: rs6 = -0.0952; s18 = -0.1744; rs18 = 5.2170; break;
case 26: rs6 = 0.4720; s18 = 0.4389; rs18 = 4.0986; break;
case 27: rs6 = 0.4634; s18 = 2.7495; rs18 = 4.3153; break;
case 28: rs6 = 0.4312; s18 = 2.8524; rs18 = 4.4693; break;
case 29: rs6 = 0.2793; s18 = 1.0354; rs18 = 4.9615; break;
case 30: rs6 = 0.4679; s18 = 1.7588; rs18 = 3.7619; break;
case 31: rs6 = 0.4529; s18 = 2.2382; rs18 = 4.6550; break;
case 32: rs6 = 0.1955; s18 = 1.0508; rs18 = 6.4177; break;
case 33: rs6 = 0.1805; s18 = 0.9383; rs18 = 7.7627; break;
case 34: rs6 = 0.2092; s18 = 1.4507; rs18 = 5.5545; break;
case 35: rs6 = 0.1940; s18 = 2.0860; rs18 = 5.9197; break;
case 36: rs6 = 0.3708; s18 = 2.0674; rs18 = 5.4743; break;
case 37: rs6 = 0.3919; s18 = 1.8541; rs18 = 5.0897; break;
case 38: rs6 = 0.0000; s18 = 0.2597; rs18 = 6.3332; s6 = 0.560; break;
case 39: rs6 = 0.0000; s18 = 0.2804; rs18 = 6.5745; s6 = 0.750; break;
case 40: rs6 = 0.0000; s18 = 0.2904; rs18 = 7.3141; s6 = 0.820; break;
// special HF / DFT with eBSSE correction;
case 41: rs6 = 0.5607; s18 = 3.9027; rs18 = 4.5622; break;
case 42: rs6 = 0.4249; s18 = 2.1849; rs18 = 4.2783; break;
case 43: rs6 = 0.1702; s18 = 0.9841; rs18 = 3.8506; break;
case 44: rs6 = 0.5014; s18 = 4.0672; rs18 = 4.8409; break;
case 45: rs6 = 0.3563; s18 = 1.0821; rs18 = 4.3359; break;
/* DFTB3 old, deprecated parameters : ;
* case ("dftb3"); rs6 = 0.7461; s18 = 3.209; rs18 = 4.1906;
* special SCC - DFTB parametrization;
* full third order DFTB, self consistent charges, hydrogen pair damping with; exponent 4.2;
*/
case 46: rs6 = 0.3807; s18 = 2.3363; rs18 = 5.8844; break;
case 47: rs6 = 0.2211; s18 = 2.6910; rs18 = 6.7278; break;
case 48: rs6 = 0.4650; s18 = 2.9215; rs18 = 6.2003; break;
// special HF - D3 - gCP - SRB / MINIX parametrization;
case 49: rs6 = 0.4171; s18 = 0.8777; rs18 = 2.9149; break;
// special HF - D3 - gCP - SRB2 / ECP - 2G parametrization;
case 50: rs6 = 0.3063; s18 = 0.5022; rs18 = 3.9856; break;
// special PBEh - D3 - gCP / def2 - mSVP parametrization;
case 51: rs6 = 0.4860; s18 = 0.0000; rs18 = 4.5000; break;
case 52: rs6 = 0.4860; s18 = 0.0000; rs18 = 4.5000; break;
case 53: rs6 = 0.5660; s18 = 0.3908; rs18 = 3.1280; break;
default:
error->all(FLERR, "Functional name unknown");
break;
}
}
void PairD3::setfuncpar_zerom() {
s6 = 1.0;
alp = 14.0;
std::unordered_map<std::string, int> commandMap = {
{"b2-plyp", 1}, {"b3-lyp", 2}, {"b97-d", 3}, {"b-lyp", 4},
{"b-p", 5}, {"pbe", 6}, {"pbe0", 7}, {"lc-wpbe", 8}
};
int commandCode = commandMap[functional];
switch (commandCode) {
case 1: rs6 = 1.313134; s18 = 0.717543; rs18 = 0.016035; s6 = 0.640000; break;
case 2: rs6 = 1.338153; s18 = 1.532981; rs18 = 0.013988; break;
case 3: rs6 = 1.151808; s18 = 1.020078; rs18 = 0.035964; break;
case 4: rs6 = 1.279637; s18 = 1.841686; rs18 = 0.014370; break;
case 5: rs6 = 1.233460; s18 = 1.945174; rs18 = 0.000000; break;
case 6: rs6 = 2.340218; s18 = 0.000000; rs18 = 0.129434; break;
case 7: rs6 = 2.077949; s18 = 0.000081; rs18 = 0.116755; break;
case 8: rs6 = 1.366361; s18 = 1.280619; rs18 = 0.003160; break;
default:
error->all(FLERR, "Functional name unknown");
break;
}
}
void PairD3::setfuncpar_bjm() {
s6 = 1.0;
alp = 14.0;
std::unordered_map<std::string, int> commandMap = {
{"b2-plyp", 1}, {"b3-lyp", 2}, {"b97-d", 3}, {"b-lyp", 4},
{"b-p", 5}, {"pbe", 6}, {"pbe0", 7}, {"lc-wpbe", 8}
};
int commandCode = commandMap[functional];
switch (commandCode) {
case 1: rs6 = 0.486434; s18 = 0.672820; rs18 = 3.656466; s6 = 0.640000; break;
case 2: rs6 = 0.278672; s18 = 1.466677; rs18 = 4.606311; break;
case 3: rs6 = 0.240184; s18 = 1.206988; rs18 = 3.864426; break;
case 4: rs6 = 0.448486; s18 = 1.875007; rs18 = 3.610679; break;
case 5: rs6 = 0.821850; s18 = 3.140281; rs18 = 2.728151; break;
case 6: rs6 = 0.012092; s18 = 0.358940; rs18 = 5.938951; break;
case 7: rs6 = 0.007912; s18 = 0.528823; rs18 = 6.162326; break;
case 8: rs6 = 0.563761; s18 = 0.906564; rs18 = 3.593680; break;
default:
error->all(FLERR, "Functional name unknown");
break;
}
}
void PairD3::setfuncpar() {
void (PairD3::*setfuncpar_damp[4])() = {
&PairD3::setfuncpar_zero,
&PairD3::setfuncpar_bj,
&PairD3::setfuncpar_zerom,
&PairD3::setfuncpar_bjm
};
(this->*setfuncpar_damp[damping])();
rs8 = rs18;
alp6 = alp;
alp8 = alp + 2.0;
// rs10 = rs18
// alp10 = alp + 4.0;
a1 = rs6;
a2 = rs8;
s8 = s18;
// s6 is already defined
}
/* ----------------------------------------------------------------------
Coeff : read from pair_coeff (Required) -> pair_coeff * * element1 element2 ...
------------------------------------------------------------------------- */
void PairD3::coeff(int narg, char **arg) {
if (!allocated) allocate();
int ntypes = atom->ntypes;
if (narg != ntypes + 2) { error->all(FLERR, "Pair_coeff needs: * * element1 element2 ..."); }
std::string element;
int* atomic_numbers = (int*)malloc(sizeof(int)*ntypes);
for (int i = 0; i < ntypes; i++) {
element = arg[i+2];
atomic_numbers[i] = find_atomic_number(element);
}
int count = 0;
for (int i = 1; i <= ntypes; i++) {
for (int j = 1; j <= ntypes; j++) {
setflag[i][j] = 1;
count++;
}
}
if (count == 0) error->all(FLERR, "Incorrect args for pair coefficients");
/*
scale r4/r2 values of the atoms by sqrt(Z)
sqrt is also globally close to optimum
together with the factor 1/2 this yield reasonable
c8 for he, ne and ar. for larger Z, C8 becomes too large
which effectively mimics higher R^n terms neglected due
to stability reasons
r2r4 =sqrt(0.5*r2r4(i)*dfloat(i)**0.5 ) with i=elementnumber
the large number of digits is just to keep the results consistent
with older versions. They should not imply any higher accuracy than
the old values
*/
double r2r4_ref[94] = {
2.00734898, 1.56637132, 5.01986934, 3.85379032, 3.64446594,
3.10492822, 2.71175247, 2.59361680, 2.38825250, 2.21522516,
6.58585536, 5.46295967, 5.65216669, 4.88284902, 4.29727576,
4.04108902, 3.72932356, 3.44677275, 7.97762753, 7.07623947,
6.60844053, 6.28791364, 6.07728703, 5.54643096, 5.80491167,
5.58415602, 5.41374528, 5.28497229, 5.22592821, 5.09817141,
6.12149689, 5.54083734, 5.06696878, 4.87005108, 4.59089647,
4.31176304, 9.55461698, 8.67396077, 7.97210197, 7.43439917,
6.58711862, 6.19536215, 6.01517290, 5.81623410, 5.65710424,
5.52640661, 5.44263305, 5.58285373, 7.02081898, 6.46815523,
5.98089120, 5.81686657, 5.53321815, 5.25477007, 11.02204549,
10.15679528, 9.35167836, 9.06926079, 8.97241155, 8.90092807,
8.85984840, 8.81736827, 8.79317710, 7.89969626, 8.80588454,
8.42439218, 8.54289262, 8.47583370, 8.45090888, 8.47339339,
7.83525634, 8.20702843, 7.70559063, 7.32755997, 7.03887381,
6.68978720, 6.05450052, 5.88752022, 5.70661499, 5.78450695,
7.79780729, 7.26443867, 6.78151984, 6.67883169, 6.39024318,
6.09527958, 11.79156076, 11.10997644, 9.51377795, 8.67197068,
8.77140725, 8.65402716, 8.53923501, 8.85024712
}; // atomic <r^2>/<r^4> values
/*
covalent radii (taken from Pyykko and Atsumi, Chem. Eur. J. 15, 2009, 188-197)
values for metals decreased by 10 %
! data rcov/
! . 0.32, 0.46, 1.20, 0.94, 0.77, 0.75, 0.71, 0.63, 0.64, 0.67
! ., 1.40, 1.25, 1.13, 1.04, 1.10, 1.02, 0.99, 0.96, 1.76, 1.54
! ., 1.33, 1.22, 1.21, 1.10, 1.07, 1.04, 1.00, 0.99, 1.01, 1.09
! ., 1.12, 1.09, 1.15, 1.10, 1.14, 1.17, 1.89, 1.67, 1.47, 1.39
! ., 1.32, 1.24, 1.15, 1.13, 1.13, 1.08, 1.15, 1.23, 1.28, 1.26
! ., 1.26, 1.23, 1.32, 1.31, 2.09, 1.76, 1.62, 1.47, 1.58, 1.57
! ., 1.56, 1.55, 1.51, 1.52, 1.51, 1.50, 1.49, 1.49, 1.48, 1.53
! ., 1.46, 1.37, 1.31, 1.23, 1.18, 1.16, 1.11, 1.12, 1.13, 1.32
! ., 1.30, 1.30, 1.36, 1.31, 1.38, 1.42, 2.01, 1.81, 1.67, 1.58
! ., 1.52, 1.53, 1.54, 1.55 /
these new data are scaled with k2=4./3. and converted a_0 via
autoang=0.52917726d0
*/
double rcov_ref[94] = {
0.80628308, 1.15903197, 3.02356173, 2.36845659, 1.94011865,
1.88972601, 1.78894056, 1.58736983, 1.61256616, 1.68815527,
3.52748848, 3.14954334, 2.84718717, 2.62041997, 2.77159820,
2.57002732, 2.49443835, 2.41884923, 4.43455700, 3.88023730,
3.35111422, 3.07395437, 3.04875805, 2.77159820, 2.69600923,
2.62041997, 2.51963467, 2.49443835, 2.54483100, 2.74640188,
2.82199085, 2.74640188, 2.89757982, 2.77159820, 2.87238349,
2.94797246, 4.76210950, 4.20778980, 3.70386304, 3.50229216,
3.32591790, 3.12434702, 2.89757982, 2.84718717, 2.84718717,
2.72120556, 2.89757982, 3.09915070, 3.22513231, 3.17473967,
3.17473967, 3.09915070, 3.32591790, 3.30072128, 5.26603625,
4.43455700, 4.08180818, 3.70386304, 3.98102289, 3.95582657,
3.93062995, 3.90543362, 3.80464833, 3.82984466, 3.80464833,
3.77945201, 3.75425569, 3.75425569, 3.72905937, 3.85504098,
3.67866672, 3.45189952, 3.30072128, 3.09915070, 2.97316878,
2.92277614, 2.79679452, 2.82199085, 2.84718717, 3.32591790,
3.27552496, 3.27552496, 3.42670319, 3.30072128, 3.47709584,
3.57788113, 5.06446567, 4.56053862, 4.20778980, 3.98102289,
3.82984466, 3.85504098, 3.88023730, 3.90543362
}; // covalent radii
for (int i = 0; i < ntypes; i++) {
r2r4[i+1] = r2r4_ref[atomic_numbers[i]-1];
rcov[i+1] = rcov_ref[atomic_numbers[i]-1];
}
// set r0ab
read_r0ab(atomic_numbers, ntypes);
// read c6ab
read_c6ab(atomic_numbers, ntypes);
free(atomic_numbers);
}
/* ----------------------------------------------------------------------
Get derivative of C6 w.r.t. CN (used in PairD3::compute)
C6 = C6(CN_A, CN_B) == W(CN_A, CN_B) / Z(CN_A, CN_B)
This gives below from chain rule:
d(C6)/dr = d(C6)/d(CN_A) * d(CN_A)/dr + d(C6)/d(CN_B) * d(CN_B)/dr
So we can pre-calculate the d(C6)/d(CN_A), d(C6)/d(CN_B) part.
d(C6)/d(CN_i) = (dW/d(CN_i) * Z - W * dZ/d(CN_i)) / (W * W)
W : "denominator"
Z : "numerator"
dW/d(CN_i) : "d_denominator_i"
dZ/d(CN_j) : "d_numerator_j"
Z = Sum( L_ij(CN_A, CN_B) * C6_ref(CN_A_i, CN_B_j) ) over i, j
W = Sum( L_ij(CN_A, CN_B) ) over i, j
And the resulting derivative term is saved into
"dc6_iji_tot", "dc6_ijj_tot" array,
where we can find the value of d(C6)/d(CN_i)
by knowing the index of "iat", and "jat". ("idx_linij")
Also, c6 values will also be saved into "c6_ij_tot" array.
Here, as we only interested in *pair* of atoms, assume "iat" >= "jat".
Then "idx_linij" = "jat + (iat + 1) * iat / 2" have the order below.
idx_linij | j = 0 j = 1 j = 2 j = 3 ...
---------------------------------------------
i = 0 | 0
i = 1 | 1 2
i = 2 | 3 4 5
i = 3 | 6 7 8 9
... | ... ... ... ... ...
------------------------------------------------------------------------- */
__global__ void kernel_get_dC6_dCNij(
int maxij, float K3,
double *cn, int *mxc, float *****c6ab, int *type,
float *c6_ij_tot, float *dc6_iji_tot, float *dc6_ijj_tot
) {
int iter = blockIdx.x * blockDim.x + threadIdx.x;
if (iter < maxij) {
int iat, jat;
ij_at_linij(iter, iat, jat);
const int atomtype_i = type[iat];
const int atomtype_j = type[jat];
const float cni = cn[iat];
const int mxci = mxc[atomtype_i];
const float cnj = cn[jat];
const int mxcj = mxc[atomtype_j];
float c6mem = -1e99f;
float r_save = 9999.0f;
double numerator = 0.0;
double denominator = 0.0;
double d_numerator_i = 0.0;
double d_denominator_i = 0.0;
double d_numerator_j = 0.0;
double d_denominator_j = 0.0;
for (int a = 0; a < mxci; a++) {
for (int b = 0; b < mxcj; b++) {
float c6ref = c6ab[atomtype_i][atomtype_j][a][b][0];
if (c6ref > 0.0f) {
float cn_refi = c6ab[atomtype_i][atomtype_j][a][b][1];
float cn_refj = c6ab[atomtype_i][atomtype_j][a][b][2];
float r = (cn_refi - cni) * (cn_refi - cni) + (cn_refj - cnj) * (cn_refj - cnj);
if (r < r_save) {
r_save = r;
c6mem = c6ref;
}
double expterm = exp(static_cast<double>(K3) * static_cast<double>(r)); // must be double
numerator += c6ref * expterm;
denominator += expterm;
expterm *= 2.0f * K3;
double term = expterm * (cni - cn_refi);
d_numerator_i += c6ref * term;
d_denominator_i += term;
term = expterm * (cnj - cn_refj);
d_numerator_j += c6ref * term;
d_denominator_j += term;
}
}
}
if (denominator > 1e-99) {
const double denominator_rc = 1.0 / denominator; // must be double
const double unit_frac = numerator * denominator_rc;
c6_ij_tot[iter] = unit_frac;
dc6_iji_tot[iter] = denominator_rc * fma(unit_frac, -d_denominator_i, d_numerator_i); // must be double
dc6_ijj_tot[iter] = denominator_rc * fma(unit_frac, -d_denominator_j, d_numerator_j); // must be double
//const double denominator_rc = 1.0 / denominator;
//const float unit_frac = numerator * denominator_rc;
//c6_ij_tot[iter] = unit_frac;
//dc6_iji_tot[iter] = \
static_cast<float>(d_numerator_i * denominator_rc) - static_cast<float>(d_denominator_i * denominator_rc) * unit_frac;
//dc6_ijj_tot[iter] = \
static_cast<float>(d_numerator_j * denominator_rc) - static_cast<float>(d_denominator_j * denominator_rc) * unit_frac;
}
else {
c6_ij_tot[iter] = c6mem;
dc6_iji_tot[iter] = 0.0f;
dc6_ijj_tot[iter] = 0.0f;
}
}
}
void PairD3::get_dC6_dCNij() {
int n = atom->natoms;
int maxij = n * (n + 1) / 2;
//START_CUDA_TIMER();
int threadsPerBlock = 128;
int blocksPerGrid = (maxij + threadsPerBlock - 1) / threadsPerBlock;
kernel_get_dC6_dCNij<<<blocksPerGrid, threadsPerBlock>>>(
maxij, K3,
cn, mxc, c6ab, atomtype,
c6_ij_tot, dc6_iji_tot, dc6_ijj_tot
);
cudaDeviceSynchronize();
//STOP_CUDA_TIMER("get_dC6dCNij");
}
/* ----------------------------------------------------------------------
Get lattice vectors (used in PairD3::compute)
1) Save lattice vectors into "lat_v_1", "lat_v_2", "lat_v_3"
2) Calculate repetition criteria for vdw, cn
3) precaluclate tau (xyz shift due to cell repetition)
------------------------------------------------------------------------- */
void PairD3::set_lattice_vectors() {
double boxxlo = domain->boxlo[0];
double boxxhi = domain->boxhi[0];
double boxylo = domain->boxlo[1];
double boxyhi = domain->boxhi[1];
double boxzlo = domain->boxlo[2];
double boxzhi = domain->boxhi[2];
double xy = domain->xy;
double xz = domain->xz;
double yz = domain->yz;
lat_v_1[0] = (boxxhi - boxxlo) / AU_TO_ANG;
lat_v_1[1] = 0.0;
lat_v_1[2] = 0.0;
lat_v_2[0] = xy / AU_TO_ANG;
lat_v_2[1] = (boxyhi - boxylo) / AU_TO_ANG;
lat_v_2[2] = 0.0;
lat_v_3[0] = xz / AU_TO_ANG;
lat_v_3[1] = yz / AU_TO_ANG;
lat_v_3[2] = (boxzhi - boxzlo) / AU_TO_ANG;
int vdwrx_save = 2 * rep_vdw[0] + 1;
int vdwry_save = 2 * rep_vdw[1] + 1;
int vdwrz_save = 2 * rep_vdw[2] + 1;
int cnrx_save = 2 * rep_cn[0] + 1;
int cnry_save = 2 * rep_cn[1] + 1;
int cnrz_save = 2 * rep_cn[2] + 1;
set_lattice_repetition_criteria(rthr, rep_vdw);
set_lattice_repetition_criteria(cnthr, rep_cn);
int vdw_range_x = 2 * rep_vdw[0] + 1;
int vdw_range_y = 2 * rep_vdw[1] + 1;
int vdw_range_z = 2 * rep_vdw[2] + 1;
int tau_loop_size_vdw = vdw_range_x * vdw_range_y * vdw_range_z * 3;
if (tau_loop_size_vdw != tau_idx_vdw_total_size) {
if (tau_idx_vdw != nullptr) {
for (int i = 0; i < vdwrx_save; i++) {
for (int j = 0; j < vdwry_save; j++) {
for (int k = 0; k < vdwrz_save; k++) {
cudaFree(tau_vdw[i][j][k]);
}
cudaFree(tau_vdw[i][j]);
}
cudaFree(tau_vdw[i]);
}
cudaFree(tau_vdw);
cudaFree(tau_idx_vdw);
}
tau_idx_vdw_total_size = tau_loop_size_vdw;
cudaMallocManaged(&tau_vdw, vdw_range_x * sizeof(float***));
for (int i = 0; i < vdw_range_x; i++) {
cudaMallocManaged(&tau_vdw[i], vdw_range_y * sizeof(float**));
for (int j = 0; j < vdw_range_y; j++) {
cudaMallocManaged(&tau_vdw[i][j], vdw_range_z * sizeof(float*));
for (int k = 0; k < vdw_range_z; k++) {
cudaMallocManaged(&tau_vdw[i][j][k], 3 * sizeof(float));
}
}
}
cudaMallocManaged(&tau_idx_vdw, tau_idx_vdw_total_size * sizeof(int));
}
int cn_range_x = 2 * rep_cn[0] + 1;
int cn_range_y = 2 * rep_cn[1] + 1;
int cn_range_z = 2 * rep_cn[2] + 1;
int tau_loop_size_cn = cn_range_x * cn_range_y * cn_range_z * 3;
if (tau_loop_size_cn != tau_idx_cn_total_size) {
if (tau_idx_cn != nullptr) {
for (int i = 0; i < cnrx_save; i++) {
for (int j = 0; j < cnry_save; j++) {
for (int k = 0; k < cnrz_save; k++) {
cudaFree(tau_cn[i][j][k]);
}
cudaFree(tau_cn[i][j]);
}
cudaFree(tau_cn[i]);
}
cudaFree(tau_cn);
cudaFree(tau_idx_cn);
}
tau_idx_cn_total_size = tau_loop_size_cn;
cudaMallocManaged(&tau_cn, cn_range_x * sizeof(float***));
for (int i = 0; i < cn_range_x; i++) {
cudaMallocManaged(&tau_cn[i], cn_range_y * sizeof(float**));
for (int j = 0; j < cn_range_y; j++) {
cudaMallocManaged(&tau_cn[i][j], cn_range_z * sizeof(float*));
for (int k = 0; k < cn_range_z; k++) {
cudaMallocManaged(&tau_cn[i][j][k], 3 * sizeof(float));
}
}
}
cudaMallocManaged(&tau_idx_cn, tau_idx_cn_total_size * sizeof(int));
}
}
/* ----------------------------------------------------------------------
Set repetition criteria (used in PairD3::compute)
Needed as Periodic Boundary Condition should be considered.
As the cell may *not* be orthorhombic,
the dot product should be used between x/y/z direction and
corresponding cross product vector.
------------------------------------------------------------------------- */
void PairD3::set_lattice_repetition_criteria(float r_threshold, int* rep_v) {
double r_cutoff = sqrt(r_threshold);
double lat_cp_12[3], lat_cp_23[3], lat_cp_31[3];
double cos_value;
MathExtra::cross3(lat_v_1, lat_v_2, lat_cp_12);
MathExtra::cross3(lat_v_2, lat_v_3, lat_cp_23);
MathExtra::cross3(lat_v_3, lat_v_1, lat_cp_31);
cos_value = MathExtra::dot3(lat_cp_23, lat_v_1) / MathExtra::len3(lat_cp_23);
rep_v[0] = static_cast<int>(std::abs(r_cutoff / cos_value)) + 1;
cos_value = MathExtra::dot3(lat_cp_31, lat_v_2) / MathExtra::len3(lat_cp_31);
rep_v[1] = static_cast<int>(std::abs(r_cutoff / cos_value)) + 1;
cos_value = MathExtra::dot3(lat_cp_12, lat_v_3) / MathExtra::len3(lat_cp_12);
rep_v[2] = static_cast<int>(std::abs(r_cutoff / cos_value)) + 1;
if (domain->xperiodic == 0) { rep_v[0] = 0; }
if (domain->yperiodic == 0) { rep_v[1] = 0; }
if (domain->zperiodic == 0) { rep_v[2] = 0; }
}
/* ----------------------------------------------------------------------
Calculate Coordination Number (used in PairD3::compute)
------------------------------------------------------------------------- */
__global__ void kernel_get_coordination_number(
int maxij, int maxtau, float cnthr, float K1,
float *rcov, int *rep_cn, float ****tau_cn, int *tau_idx_cn, int *type, float **x,
double *cn
) {
int iter = blockIdx.x * blockDim.x + threadIdx.x;
if (iter < maxij) {
int iat, jat;
ij_at_linij(iter, iat, jat);
float cn_local = 0.0f;
if (iat == jat) {
const float rcov_sum = rcov[type[iat]] * 2.0f;
for (int k = maxtau - 1; k >= 0; k -= 3) {
const int idx1 = tau_idx_cn[k-2];
const int idx2 = tau_idx_cn[k-1];
const int idx3 = tau_idx_cn[k];
if (idx1 == rep_cn[0] && idx2 == rep_cn[1] && idx3 == rep_cn[2]) { continue; }
const float rx = tau_cn[idx1][idx2][idx3][0];
const float ry = tau_cn[idx1][idx2][idx3][1];
const float rz = tau_cn[idx1][idx2][idx3][2];
const float r2 = rx * rx + ry * ry + rz * rz;
if (r2 <= cnthr) {
const float r_rc = rsqrtf(r2);
const float damp = 1.0f / (1.0f + expf(-K1 * ((rcov_sum * r_rc) - 1.0f)));
cn_local += damp;
}
}
atomicAdd(&cn[iat], cn_local);
}
else {
const float rcov_sum = rcov[type[iat]] + rcov[type[jat]];
for (int k = maxtau - 1; k >= 0; k -= 3) {
const int idx1 = tau_idx_cn[k-2];
const int idx2 = tau_idx_cn[k-1];
const int idx3 = tau_idx_cn[k];
const float rx = x[jat][0] - x[iat][0] + tau_cn[idx1][idx2][idx3][0];
const float ry = x[jat][1] - x[iat][1] + tau_cn[idx1][idx2][idx3][1];
const float rz = x[jat][2] - x[iat][2] + tau_cn[idx1][idx2][idx3][2];
const float r2 = rx * rx + ry * ry + rz * rz;
if (r2 <= cnthr) {
const float r_rc = rsqrtf(r2);
const float damp = 1.0f / (1.0f + expf(-K1 * ((rcov_sum * r_rc) - 1.0f)));
cn_local += damp;
}
}
atomicAdd(&cn[iat], cn_local);
atomicAdd(&cn[jat], cn_local);
}
}
}
void PairD3::get_coordination_number() {
int n = atom->natoms;
int maxij = n * (n + 1) / 2;
int maxtau = tau_idx_cn_total_size;
for (int i = 0; i < n; i++) {
cn[i] = 0.0;
}
//START_CUDA_TIMER();
int threadsPerBlock = 128;
int blocksPerGrid = (maxij + threadsPerBlock - 1) / threadsPerBlock;
kernel_get_coordination_number<<<blocksPerGrid, threadsPerBlock>>>(
maxij, maxtau, cnthr, K1,
rcov, rep_cn, tau_cn, tau_idx_cn, atomtype, x,
cn
);
cudaDeviceSynchronize();
//STOP_CUDA_TIMER("get_coord");
}
/* ----------------------------------------------------------------------
reallcate memory if the number of atoms has changed (used in PairD3::compute)
------------------------------------------------------------------------- */
void PairD3::reallocate_arrays() {
/* -------------- Destroy previous arrays -------------- */
cudaFree(cn);
for (int i = 0; i < n_save; i++) { cudaFree(x[i]); }; cudaFree(x);
cudaFree(dc6i);
for (int i = 0; i < n_save; i++) { cudaFree(f[i]); }; cudaFree(f);
cudaFree(dc6_iji_tot);
cudaFree(dc6_ijj_tot);
cudaFree(c6_ij_tot);
cudaFree(atomtype);
/* -------------- Destroy previous arrays -------------- */
/* -------------- Create new arrays -------------- */
int n = atom->natoms;
n_save = n;
cudaMallocManaged(&cn, n * sizeof(double));
cudaMallocManaged(&x, n * sizeof(float*)); for (int i = 0; i < n; i++) { cudaMallocManaged(&x[i], 3 * sizeof(float)); }
cudaMallocManaged(&dc6i, n * sizeof(double));
cudaMallocManaged(&f, n * sizeof(double*)); for (int i = 0; i < n; i++) { cudaMallocManaged(&f[i], 3 * sizeof(double)); }
int n_ij_combination = n * (n + 1) / 2;
cudaMallocManaged(&dc6_iji_tot, n_ij_combination * sizeof(float));
cudaMallocManaged(&dc6_ijj_tot, n_ij_combination * sizeof(float));
cudaMallocManaged(&c6_ij_tot, n_ij_combination * sizeof(float));
cudaMallocManaged(&atomtype, n * sizeof(int));
/* -------------- Create new arrays -------------- */
}
/* ----------------------------------------------------------------------
Initialize atomic positions & types (used in PairD3::compute)
As the default xyz from lammps does not assure that atoms are within unit cell,
this function shifts atoms into the unit cell.
------------------------------------------------------------------------- */
void PairD3::load_atom_info() {
double lat[3][3];
lat[0][0] = lat_v_1[0];
lat[0][1] = lat_v_2[0];
lat[0][2] = lat_v_3[0];
lat[1][0] = lat_v_1[1];
lat[1][1] = lat_v_2[1];
lat[1][2] = lat_v_3[1];
lat[2][0] = lat_v_1[2];
lat[2][1] = lat_v_2[2];
lat[2][2] = lat_v_3[2];
double det = lat[0][0] * lat[1][1] * lat[2][2]
+ lat[0][1] * lat[1][2] * lat[2][0]
+ lat[0][2] * lat[1][0] * lat[2][1]
- lat[0][2] * lat[1][1] * lat[2][0]
- lat[0][1] * lat[1][0] * lat[2][2]
- lat[0][0] * lat[1][2] * lat[2][1];
double lat_inv[3][3];
lat_inv[0][0] = (lat[1][1] * lat[2][2] - lat[1][2] * lat[2][1]) / det;
lat_inv[1][0] = (lat[1][2] * lat[2][0] - lat[1][0] * lat[2][2]) / det;
lat_inv[2][0] = (lat[1][0] * lat[2][1] - lat[1][1] * lat[2][0]) / det;
lat_inv[0][1] = (lat[0][2] * lat[2][1] - lat[0][1] * lat[2][2]) / det;
lat_inv[1][1] = (lat[0][0] * lat[2][2] - lat[0][2] * lat[2][0]) / det;
lat_inv[2][1] = (lat[0][1] * lat[2][0] - lat[0][0] * lat[2][1]) / det;
lat_inv[0][2] = (lat[0][1] * lat[1][2] - lat[0][2] * lat[1][1]) / det;
lat_inv[1][2] = (lat[0][2] * lat[1][0] - lat[0][0] * lat[1][2]) / det;
lat_inv[2][2] = (lat[0][0] * lat[1][1] - lat[0][1] * lat[1][0]) / det;
double a[3] = { 0.0 };
for (int iat = 0; iat < atom->natoms; iat++) {
for (int i = 0; i < 3; i++) {
a[i] = lat_inv[i][0] * (atom->x)[iat][0] / AU_TO_ANG +
lat_inv[i][1] * (atom->x)[iat][1] / AU_TO_ANG +
lat_inv[i][2] * (atom->x)[iat][2] / AU_TO_ANG;
a[i] -= floor(a[i]); // replaces the code below
//if (a[i] > 1) { while (a[i] > 1) { a[i]--; } }
//else if (a[i] < 0) { while (a[i] < 0) { a[i]++; } }
}
for (int i = 0; i < 3; i++) {
x[iat][i] = (lat[i][0] * a[0] + lat[i][1] * a[1] + lat[i][2] * a[2]);
}
}
}
/* ----------------------------------------------------------------------
Precalculate tau array
------------------------------------------------------------------------- */
void PairD3::precalculate_tau_array() {
int xlim = rep_vdw[0];
int ylim = rep_vdw[1];
int zlim = rep_vdw[2];
int index = 0;
for (int taux = -xlim; taux <= xlim; taux++) {
for (int tauy = -ylim; tauy <= ylim; tauy++) {
for (int tauz = -zlim; tauz <= zlim; tauz++) {
tau_vdw[taux + xlim][tauy + ylim][tauz + zlim][0] = lat_v_1[0] * taux + lat_v_2[0] * tauy + lat_v_3[0] * tauz;
tau_vdw[taux + xlim][tauy + ylim][tauz + zlim][1] = lat_v_1[1] * taux + lat_v_2[1] * tauy + lat_v_3[1] * tauz;
tau_vdw[taux + xlim][tauy + ylim][tauz + zlim][2] = lat_v_1[2] * taux + lat_v_2[2] * tauy + lat_v_3[2] * tauz;
tau_idx_vdw[index++] = taux + xlim;
tau_idx_vdw[index++] = tauy + ylim;
tau_idx_vdw[index++] = tauz + zlim;
}
}
}
xlim = rep_cn[0];
ylim = rep_cn[1];
zlim = rep_cn[2];
index = 0;
for (int taux = -xlim; taux <= xlim; taux++) {
for (int tauy = -ylim; tauy <= ylim; tauy++) {
for (int tauz = -zlim; tauz <= zlim; tauz++) {
tau_cn[taux + xlim][tauy + ylim][tauz + zlim][0] = lat_v_1[0] * taux + lat_v_2[0] * tauy + lat_v_3[0] * tauz;
tau_cn[taux + xlim][tauy + ylim][tauz + zlim][1] = lat_v_1[1] * taux + lat_v_2[1] * tauy + lat_v_3[1] * tauz;
tau_cn[taux + xlim][tauy + ylim][tauz + zlim][2] = lat_v_1[2] * taux + lat_v_2[2] * tauy + lat_v_3[2] * tauz;
tau_idx_cn[index++] = taux + xlim;
tau_idx_cn[index++] = tauy + ylim;
tau_idx_cn[index++] = tauz + zlim;
}
}
}
}
/* ----------------------------------------------------------------------
Get forces (Zero damping)
------------------------------------------------------------------------- */
__global__ void kernel_get_forces_without_dC6_zero(
int maxij, int maxtau, float rthr, float s6, float s8, float a1, float a2, float alp6, float alp8,
float *r2r4, float **r0ab, int *rep_vdw, float ****tau_vdw, int *tau_idx_vdw, int *type, float **x,
float *c6_ij_tot, float *dc6_iji_tot, float *dc6_ijj_tot,
double *dc6i, double *disp, double **f, double **sigma
) {
int iter = blockIdx.x * blockDim.x + threadIdx.x;
__shared__ float sigma_00[128];
__shared__ float sigma_01[128];
__shared__ float sigma_02[128];
__shared__ float sigma_10[128];
__shared__ float sigma_11[128];
__shared__ float sigma_12[128];
__shared__ float sigma_20[128];
__shared__ float sigma_21[128];
__shared__ float sigma_22[128];
__shared__ float disp_shared[128];
float sigma_local_00 = 0.0f;
float sigma_local_01 = 0.0f;
float sigma_local_02 = 0.0f;
float sigma_local_10 = 0.0f;
float sigma_local_11 = 0.0f;
float sigma_local_12 = 0.0f;
float sigma_local_20 = 0.0f;
float sigma_local_21 = 0.0f;
float sigma_local_22 = 0.0f;
float disp_local = 0.0f;
if (iter < maxij) {
int iat, jat;
ij_at_linij(iter, iat, jat);
float f_local[3] = { 0.0f };
float dc6i_local_i = 0.0f;
float dc6i_local_j = 0.0f;
const float c6 = c6_ij_tot[iter];
const float dc6iji = dc6_iji_tot[iter];
const float dc6ijj = dc6_ijj_tot[iter];
if (iat == jat) {
const int atomtype_i = type[iat];
const float r0 = r0ab[atomtype_i][atomtype_i];
const float unit_r2r4 = r2r4[atomtype_i];
const float r42 = unit_r2r4 * unit_r2r4;
const float unit_a1 = (a1 * r0);
const float unit_a2 = (a2 * r0);
const float s8r42 = s8 * r42;
for (int k = maxtau - 1; k >= 0; k -= 3) {
const int idx1 = tau_idx_vdw[k-2];
const int idx2 = tau_idx_vdw[k-1];
const int idx3 = tau_idx_vdw[k];
if (idx1 == rep_vdw[0] && idx2 == rep_vdw[1] && idx3 == rep_vdw[2]) { continue; }
const float rij[3] = {
tau_vdw[idx1][idx2][idx3][0],
tau_vdw[idx1][idx2][idx3][1],
tau_vdw[idx1][idx2][idx3][2]
};
const float r2 = lensq3(rij);
if (r2 > rthr) { continue; }
const float r_rc = rsqrtf(r2);
float unit_rc_a1 = unit_a1 * r_rc;
float t6 = unit_rc_a1 * unit_rc_a1; // ^2
t6 *= unit_rc_a1; // ^3
t6 *= t6; // ^6
t6 *= unit_rc_a1; // ^7
t6 *= t6; // ^14
const float damp6 = 1.0f / fmaf(t6, 6.0f, 1.0f);
float unit_rc_a2 = unit_a2 * r_rc;
float t8 = unit_rc_a2 * unit_rc_a2; // ^2
t8 *= t8; // ^4
t8 *= t8; // ^8
t8 *= t8; // ^16
const float damp8 = 1.0f / fmaf(t8, 6.0f, 1.0f);
const float r2_rc = r_rc * r_rc; // 1.0 / r2
const float r6_rc = r2_rc * r2_rc * r2_rc;
const float r8_rc = r6_rc * r2_rc;
const float x1 = 3.0f * c6 * r8_rc * fmaf(r2_rc, s8r42 * damp8 * fmaf(3.0f * alp8 * t8, damp8, -4.0f), s6 * damp6 * fmaf(alp6 * t6, damp6, -1.0f));
//const float x1 = 0.5 * 6.0 * c6 * r8_rc * (s6 * damp6 * (14.0 * t6 * damp6 - 1.0) + s8r42 * r2_rc * damp8 * (48.0 * t8 * damp8 - 4.0));
//3.0 * alp6 = 48.0
const float vec[3] = {
x1 * rij[0],
x1 * rij[1],
x1 * rij[2]
};
sigma_local_00 += vec[0] * rij[0];
sigma_local_01 += vec[0] * rij[1];
sigma_local_02 += vec[0] * rij[2];
sigma_local_10 += vec[1] * rij[0];
sigma_local_11 += vec[1] * rij[1];
sigma_local_12 += vec[1] * rij[2];
sigma_local_20 += vec[2] * rij[0];
sigma_local_21 += vec[2] * rij[1];
sigma_local_22 += vec[2] * rij[2];
const float dc6_rest = 0.5f * r6_rc * fmaf(3.0f * r2_rc, s8r42 * damp8, s6 * damp6);
//const float dc6_rest = 0.5 * r6_rc * (s6 * damp6 + 3.0 * s8r42 * damp8 * r2_rc);
disp_local -= dc6_rest * c6;
dc6i_local_i += dc6_rest * dc6iji;
dc6i_local_j += dc6_rest * dc6ijj;
}
atomicAdd(&dc6i[iat], dc6i_local_i);
atomicAdd(&dc6i[jat], dc6i_local_j);
}
else {
const int atomtype_i = type[iat];
const int atomtype_j = type[jat];
const float r0 = r0ab[atomtype_i][atomtype_j];
const float r42 = r2r4[atomtype_i] * r2r4[atomtype_j];
const float unit_a1 = (a1 * r0);
const float unit_a2 = (a2 * r0);
const float s8r42 = s8 * r42;
for (int k = maxtau - 1; k >= 0; k -= 3) {
const int idx1 = tau_idx_vdw[k-2];
const int idx2 = tau_idx_vdw[k-1];
const int idx3 = tau_idx_vdw[k];
const float rij[3] = {
x[jat][0] - x[iat][0] + tau_vdw[idx1][idx2][idx3][0],
x[jat][1] - x[iat][1] + tau_vdw[idx1][idx2][idx3][1],
x[jat][2] - x[iat][2] + tau_vdw[idx1][idx2][idx3][2]
};
const float r2 = lensq3(rij);
if (r2 > rthr) { continue; }
const float r_rc = rsqrtf(r2);
float unit_rc_a1 = unit_a1 * r_rc;
float t6 = unit_rc_a1 * unit_rc_a1; // ^2
t6 *= unit_rc_a1; // ^3
t6 *= t6; // ^6
t6 *= unit_rc_a1; // ^7
t6 *= t6; // ^14
const float damp6 = 1.0f / fmaf(t6, 6.0f, 1.0f);
float unit_rc_a2 = unit_a2 * r_rc;
float t8 = unit_rc_a2 * unit_rc_a2; // ^2
t8 *= t8; // ^4
t8 *= t8; // ^8
t8 *= t8; // ^16
const float damp8 = 1.0f / fmaf(t8, 6.0f, 1.0f);
const float r2_rc = r_rc * r_rc; // 1.0 / r2
const float r6_rc = r2_rc * r2_rc * r2_rc;
const float r8_rc = r6_rc * r2_rc;
const float x1 = 6.0f * c6 * r8_rc * fmaf(r2_rc, s8r42 * damp8 * fmaf(3.0f * alp8 * t8, damp8, -4.0f), s6 * damp6 * fmaf(alp6 * t6, damp6, -1.0f));
//const float x1 = 6.0 * c6 * r8_rc * (s6 * damp6 * (14.0 * t6 * damp6 - 1.0) + s8r42 * r2_rc * damp8 * (48.0 * t8 * damp8 - 4.0));
//3.0 * alp6 = 48.0
const float vec[3] = {
x1 * rij[0],
x1 * rij[1],
x1 * rij[2]
};
f_local[0] -= vec[0];
f_local[1] -= vec[1];
f_local[2] -= vec[2];
sigma_local_00 += vec[0] * rij[0];
sigma_local_01 += vec[0] * rij[1];
sigma_local_02 += vec[0] * rij[2];
sigma_local_10 += vec[1] * rij[0];
sigma_local_11 += vec[1] * rij[1];
sigma_local_12 += vec[1] * rij[2];
sigma_local_20 += vec[2] * rij[0];
sigma_local_21 += vec[2] * rij[1];
sigma_local_22 += vec[2] * rij[2];
const float dc6_rest = r6_rc * fmaf(3.0f * r2_rc, s8r42 * damp8, s6 * damp6);
//const float dc6_rest = r6_rc * (s6 * damp6 + 3.0 * s8r42 * damp8 * r2_rc);
disp_local -= dc6_rest * c6;
dc6i_local_i += dc6_rest * dc6iji;
dc6i_local_j += dc6_rest * dc6ijj;
}
atomicAdd(&dc6i[iat], dc6i_local_i);
atomicAdd(&dc6i[jat], dc6i_local_j);
atomicAdd(&f[iat][0], f_local[0]);
atomicAdd(&f[iat][1], f_local[1]);
atomicAdd(&f[iat][2], f_local[2]);
atomicAdd(&f[jat][0], -f_local[0]);
atomicAdd(&f[jat][1], -f_local[1]);
atomicAdd(&f[jat][2], -f_local[2]);
}
}
sigma_00[threadIdx.x] = sigma_local_00;
sigma_01[threadIdx.x] = sigma_local_01;
sigma_02[threadIdx.x] = sigma_local_02;
sigma_10[threadIdx.x] = sigma_local_10;
sigma_11[threadIdx.x] = sigma_local_11;
sigma_12[threadIdx.x] = sigma_local_12;
sigma_20[threadIdx.x] = sigma_local_20;
sigma_21[threadIdx.x] = sigma_local_21;
sigma_22[threadIdx.x] = sigma_local_22;
disp_shared[threadIdx.x] = disp_local;
__syncthreads();
for (int s=blockDim.x/2; s>0; s>>=1) {
if (threadIdx.x < s) {
sigma_00[threadIdx.x] += sigma_00[threadIdx.x + s];
sigma_01[threadIdx.x] += sigma_01[threadIdx.x + s];
sigma_02[threadIdx.x] += sigma_02[threadIdx.x + s];
sigma_10[threadIdx.x] += sigma_10[threadIdx.x + s];
sigma_11[threadIdx.x] += sigma_11[threadIdx.x + s];
sigma_12[threadIdx.x] += sigma_12[threadIdx.x + s];
sigma_20[threadIdx.x] += sigma_20[threadIdx.x + s];
sigma_21[threadIdx.x] += sigma_21[threadIdx.x + s];
sigma_22[threadIdx.x] += sigma_22[threadIdx.x + s];
disp_shared[threadIdx.x] += disp_shared[threadIdx.x + s];
}
__syncthreads();
}
if (threadIdx.x == 0) {
atomicAdd(&sigma[0][0], sigma_00[0]);
atomicAdd(&sigma[0][1], sigma_01[0]);
atomicAdd(&sigma[0][2], sigma_02[0]);
atomicAdd(&sigma[1][0], sigma_10[0]);
atomicAdd(&sigma[1][1], sigma_11[0]);
atomicAdd(&sigma[1][2], sigma_12[0]);
atomicAdd(&sigma[2][0], sigma_20[0]);
atomicAdd(&sigma[2][1], sigma_21[0]);
atomicAdd(&sigma[2][2], sigma_22[0]);
atomicAdd(disp, disp_shared[0]);
}
}
void PairD3::get_forces_without_dC6_zero() {
int n = atom->natoms;
int maxij = n * (n + 1) / 2;
int maxtau = tau_idx_vdw_total_size;
*disp = 0.0;
for (int dim = 0; dim < n; dim++) { dc6i[dim] = 0.0; }
for (int i = 0; i < n; i++) {
for (int j = 0; j < 3; j++) {
f[i][j] = 0.0;
}
}
for (int ii = 0; ii < 3; ii++) {
for (int jj = 0; jj < 3; jj++) {
sigma[ii][jj] = 0.0;
}
}
//START_CUDA_TIMER();
int threadsPerBlock = 128;
int blocksPerGrid = (maxij + threadsPerBlock - 1) / threadsPerBlock;
kernel_get_forces_without_dC6_zero<<<blocksPerGrid, threadsPerBlock>>>(
maxij, maxtau, rthr, s6, s8, a1, a2, alp6, alp8,
r2r4, r0ab, rep_vdw, tau_vdw, tau_idx_vdw, atomtype, x,
c6_ij_tot, dc6_iji_tot, dc6_ijj_tot,
dc6i, disp, f, sigma
);
cudaDeviceSynchronize();
disp_total = *disp;
//STOP_CUDA_TIMER("get_forces_without");
}
__global__ void kernel_get_forces_without_dC6_bj(
int maxij, int maxtau, float rthr, float s6, float s8, float a1, float a2,
float *r2r4, int *rep_vdw, float ****tau_vdw, int *tau_idx_vdw, int *type, float **x,
float *c6_ij_tot, float *dc6_iji_tot, float *dc6_ijj_tot,
double *dc6i, double *disp, double **f, double **sigma
) {
int iter = blockIdx.x * blockDim.x + threadIdx.x;
__shared__ float sigma_00[128];
__shared__ float sigma_01[128];
__shared__ float sigma_02[128];
__shared__ float sigma_10[128];
__shared__ float sigma_11[128];
__shared__ float sigma_12[128];
__shared__ float sigma_20[128];
__shared__ float sigma_21[128];
__shared__ float sigma_22[128];
__shared__ float disp_shared[128];
float sigma_local_00 = 0.0f;
float sigma_local_01 = 0.0f;
float sigma_local_02 = 0.0f;
float sigma_local_10 = 0.0f;
float sigma_local_11 = 0.0f;
float sigma_local_12 = 0.0f;
float sigma_local_20 = 0.0f;
float sigma_local_21 = 0.0f;
float sigma_local_22 = 0.0f;
float disp_local = 0.0f;
if (iter < maxij) {
int iat, jat;
ij_at_linij(iter, iat, jat);
float f_local[3] = { 0.0f };
float dc6i_local_i = 0.0f;
float dc6i_local_j = 0.0f;
const float c6 = c6_ij_tot[iter];
const float dc6iji = dc6_iji_tot[iter];
const float dc6ijj = dc6_ijj_tot[iter];
if (iat == jat) {
const float unit_r2r4 = r2r4[type[iat]];
const float r42x3 = unit_r2r4 * unit_r2r4 * 3.0f;
const float R0 = fmaf(a1, sqrtf(r42x3), a2);
const float R0_2 = R0 * R0;
const float R0_6 = R0_2 * R0_2 * R0_2;
const float R0_8 = R0_6 * R0_2;
const float s8r42x3 = s8 * r42x3;
for (int k = maxtau - 1; k >= 0; k -= 3) {
const int idx1 = tau_idx_vdw[k-2];
const int idx2 = tau_idx_vdw[k-1];
const int idx3 = tau_idx_vdw[k];
if (idx1 == rep_vdw[0] && idx2 == rep_vdw[1] && idx3 == rep_vdw[2]) { continue; }
const float rij[3] = {
tau_vdw[idx1][idx2][idx3][0],
tau_vdw[idx1][idx2][idx3][1],
tau_vdw[idx1][idx2][idx3][2]
};
const float r2 = lensq3(rij);
if (r2 > rthr) { continue; }
const float r = sqrtf(r2);
const float r5 = r2 * r2 * r;
const float r7 = r5 * r2;
const float t6_rc = 1.0f / fmaf(r5, r, R0_6);
const float t8_rc = 1.0f / fmaf(r7, r, R0_8);
const float t6_sqrc = t6_rc * t6_rc;
const float t8_sqrc = t8_rc * t8_rc;
const float x1 = -c6 * fmaf(4.0f * s8r42x3 * r7, t8_sqrc, 3.0f * s6 * r5 * t6_sqrc);
//const float x1 = 0.5 * -c6 * (6.0 * s6 * r5 * t6_sqrc + 8.0 * s8r42x3 * r7 * t8_sqrc;
const float r_rc = 1.0f / r; // rsqrt(r2)
const float vec[3] = {
x1 * rij[0] * r_rc,
x1 * rij[1] * r_rc,
x1 * rij[2] * r_rc
};
sigma_local_00 += vec[0] * rij[0];
sigma_local_01 += vec[0] * rij[1];
sigma_local_02 += vec[0] * rij[2];
sigma_local_10 += vec[1] * rij[0];
sigma_local_11 += vec[1] * rij[1];
sigma_local_12 += vec[1] * rij[2];
sigma_local_20 += vec[2] * rij[0];
sigma_local_21 += vec[2] * rij[1];
sigma_local_22 += vec[2] * rij[2];
const float dc6_rest = 0.5f * fmaf(s8r42x3, t8_rc, s6 * t6_rc);
//const float dc6_rest = 0.5 * s6 * t6_rc + s8r42x3 * t8_rc;
disp_local -= dc6_rest * c6;
dc6i_local_i += dc6_rest * dc6iji;
dc6i_local_j += dc6_rest * dc6ijj;
}
atomicAdd(&dc6i[iat], dc6i_local_i);
atomicAdd(&dc6i[jat], dc6i_local_j);
}
else {
const float r42x3 = r2r4[type[iat]] * r2r4[type[jat]] * 3.0f;
const float R0 = fmaf(a1, sqrtf(r42x3), a2);
const float R0_2 = R0 * R0;
const float R0_6 = R0_2 * R0_2 * R0_2;
const float R0_8 = R0_6 * R0_2;
const float s8r42x3 = s8 * r42x3;
for (int k = maxtau - 1; k >= 0; k -= 3) {
const int idx1 = tau_idx_vdw[k-2];
const int idx2 = tau_idx_vdw[k-1];
const int idx3 = tau_idx_vdw[k];
const float rij[3] = {
x[jat][0] - x[iat][0] + tau_vdw[idx1][idx2][idx3][0],
x[jat][1] - x[iat][1] + tau_vdw[idx1][idx2][idx3][1],
x[jat][2] - x[iat][2] + tau_vdw[idx1][idx2][idx3][2]
};
const float r2 = lensq3(rij);
if (r2 > rthr) { continue; }
const float r = sqrtf(r2);
const float r5 = r2 * r2 * r;
const float r7 = r5 * r2;
const float t6_rc = 1.0f / fmaf(r5, r, R0_6);
const float t8_rc = 1.0f / fmaf(r7, r, R0_8);
const float t6_sqrc = t6_rc * t6_rc;
const float t8_sqrc = t8_rc * t8_rc;
const float x1 = -c6 * fmaf(8.0f * s8r42x3 * r7, t8_sqrc, 6.0f * s6 * r5 * t6_sqrc);
//const float x1 = -c6 * (6.0 * s6 * r5 * t6_sqrc + 8.0 * s8r42x3 * r7 * t8_sqrc;
const float r_rc = 1.0f / r; // rsqrt(r2)
const float vec[3] = {
x1 * rij[0] * r_rc,
x1 * rij[1] * r_rc,
x1 * rij[2] * r_rc
};
f_local[0] -= vec[0];
f_local[1] -= vec[1];
f_local[2] -= vec[2];
sigma_local_00 += vec[0] * rij[0];
sigma_local_01 += vec[0] * rij[1];
sigma_local_02 += vec[0] * rij[2];
sigma_local_10 += vec[1] * rij[0];
sigma_local_11 += vec[1] * rij[1];
sigma_local_12 += vec[1] * rij[2];
sigma_local_20 += vec[2] * rij[0];
sigma_local_21 += vec[2] * rij[1];
sigma_local_22 += vec[2] * rij[2];
const float dc6_rest = fmaf(s8r42x3, t8_rc, s6 * t6_rc);
//const float dc6_rest = s6 * t6_rc + s8r42x3 * t8_rc;
disp_local -= dc6_rest * c6;
dc6i_local_i += dc6_rest * dc6iji;
dc6i_local_j += dc6_rest * dc6ijj;
}
atomicAdd(&dc6i[iat], dc6i_local_i);
atomicAdd(&dc6i[jat], dc6i_local_j);
atomicAdd(&f[iat][0], f_local[0]);
atomicAdd(&f[iat][1], f_local[1]);
atomicAdd(&f[iat][2], f_local[2]);
atomicAdd(&f[jat][0], -f_local[0]);
atomicAdd(&f[jat][1], -f_local[1]);
atomicAdd(&f[jat][2], -f_local[2]);
}
}
sigma_00[threadIdx.x] = sigma_local_00;
sigma_01[threadIdx.x] = sigma_local_01;
sigma_02[threadIdx.x] = sigma_local_02;
sigma_10[threadIdx.x] = sigma_local_10;
sigma_11[threadIdx.x] = sigma_local_11;
sigma_12[threadIdx.x] = sigma_local_12;
sigma_20[threadIdx.x] = sigma_local_20;
sigma_21[threadIdx.x] = sigma_local_21;
sigma_22[threadIdx.x] = sigma_local_22;
disp_shared[threadIdx.x] = disp_local;
__syncthreads();
for (int s=blockDim.x/2; s>0; s>>=1) {
if (threadIdx.x < s) {
sigma_00[threadIdx.x] += sigma_00[threadIdx.x + s];
sigma_01[threadIdx.x] += sigma_01[threadIdx.x + s];
sigma_02[threadIdx.x] += sigma_02[threadIdx.x + s];
sigma_10[threadIdx.x] += sigma_10[threadIdx.x + s];
sigma_11[threadIdx.x] += sigma_11[threadIdx.x + s];
sigma_12[threadIdx.x] += sigma_12[threadIdx.x + s];
sigma_20[threadIdx.x] += sigma_20[threadIdx.x + s];
sigma_21[threadIdx.x] += sigma_21[threadIdx.x + s];
sigma_22[threadIdx.x] += sigma_22[threadIdx.x + s];
disp_shared[threadIdx.x] += disp_shared[threadIdx.x + s];
}
__syncthreads();
}
if (threadIdx.x == 0) {
atomicAdd(&sigma[0][0], sigma_00[0]);
atomicAdd(&sigma[0][1], sigma_01[0]);
atomicAdd(&sigma[0][2], sigma_02[0]);
atomicAdd(&sigma[1][0], sigma_10[0]);
atomicAdd(&sigma[1][1], sigma_11[0]);
atomicAdd(&sigma[1][2], sigma_12[0]);
atomicAdd(&sigma[2][0], sigma_20[0]);
atomicAdd(&sigma[2][1], sigma_21[0]);
atomicAdd(&sigma[2][2], sigma_22[0]);
atomicAdd(disp, disp_shared[0]);
}
}
void PairD3::get_forces_without_dC6_bj() {
int n = atom->natoms;
int maxij = n * (n + 1) / 2;
int maxtau = tau_idx_vdw_total_size;
*disp = 0.0;
for (int dim = 0; dim < n; dim++) { dc6i[dim] = 0.0; }
for (int i = 0; i < n; i++) {
for (int j = 0; j < 3; j++) {
f[i][j] = 0.0;
}
}
for (int ii = 0; ii < 3; ii++) {
for (int jj = 0; jj < 3; jj++) {
sigma[ii][jj] = 0.0;
}
}
//START_CUDA_TIMER();
int threadsPerBlock = 128;
int blocksPerGrid = (maxij + threadsPerBlock - 1) / threadsPerBlock;
kernel_get_forces_without_dC6_bj<<<blocksPerGrid, threadsPerBlock>>>(
maxij, maxtau, rthr, s6, s8, a1, a2,
r2r4, rep_vdw, tau_vdw, tau_idx_vdw, atomtype, x,
c6_ij_tot, dc6_iji_tot, dc6_ijj_tot,
dc6i, disp, f, sigma
);
cudaDeviceSynchronize();
disp_total = *disp;
//STOP_CUDA_TIMER("get_forces_without");
}
void PairD3::get_forces_without_dC6_zerom() {}
void PairD3::get_forces_without_dC6_bjm() {}
void PairD3::get_forces_without_dC6() {
void (PairD3::*get_forces_without_dC6_damp[4])() = {
&PairD3::get_forces_without_dC6_zero,
&PairD3::get_forces_without_dC6_bj,
&PairD3::get_forces_without_dC6_zerom,
&PairD3::get_forces_without_dC6_bjm
};
(this->*get_forces_without_dC6_damp[damping])();
}
__global__ void kernel_get_forces_with_dC6(
int maxij, int maxtau, float cnthr, float K1,
double *dc6i, float *rcov, int *rep_cn, float ****tau_cn, int *tau_idx_cn, int *type, float **x,
double **f, double **sigma
) {
int iter = blockIdx.x * blockDim.x + threadIdx.x;
__shared__ float sigma_00[128];
__shared__ float sigma_01[128];
__shared__ float sigma_02[128];
__shared__ float sigma_10[128];
__shared__ float sigma_11[128];
__shared__ float sigma_12[128];
__shared__ float sigma_20[128];
__shared__ float sigma_21[128];
__shared__ float sigma_22[128];
float sigma_local_00 = 0.0f;
float sigma_local_01 = 0.0f;
float sigma_local_02 = 0.0f;
float sigma_local_10 = 0.0f;
float sigma_local_11 = 0.0f;
float sigma_local_12 = 0.0f;
float sigma_local_20 = 0.0f;
float sigma_local_21 = 0.0f;
float sigma_local_22 = 0.0f;
float f_local[3] = { 0.0f };
if (iter < maxij) {
int iat, jat;
ij_at_linij(iter, iat, jat);
if (iat == jat) {
const float rcov_sum = rcov[type[iat]] * 2.0f;
const float dc6i_sum = dc6i[iat];
for (int k = maxtau - 1; k >= 0; k -= 3) {
const int idx1 = tau_idx_cn[k-2];
const int idx2 = tau_idx_cn[k-1];
const int idx3 = tau_idx_cn[k];
if (idx1 == rep_cn[0] && idx2 == rep_cn[1] && idx3 == rep_cn[2]) { continue; }
const float rij[3] = {
tau_cn[idx1][idx2][idx3][0],
tau_cn[idx1][idx2][idx3][1],
tau_cn[idx1][idx2][idx3][2],
};
const float r2 = lensq3(rij);
if (r2 >= cnthr) { continue; }
const float r_rc = rsqrtf(r2);
const float expterm = expf(-K1 * (rcov_sum * r_rc - 1.0f));
const float unit_rc = 1.0f / (r2 * (expterm + 1.0f) * (expterm + 1.0f));
const float dcnn = -K1 * rcov_sum * expterm * unit_rc;
const float x1 = dcnn * dc6i_sum;
const float vec[3] = {
x1 * rij[0] * r_rc,
x1 * rij[1] * r_rc,
x1 * rij[2] * r_rc
};
sigma_local_00 += vec[0] * rij[0];
sigma_local_01 += vec[0] * rij[1];
sigma_local_02 += vec[0] * rij[2];
sigma_local_10 += vec[1] * rij[0];
sigma_local_11 += vec[1] * rij[1];
sigma_local_12 += vec[1] * rij[2];
sigma_local_20 += vec[2] * rij[0];
sigma_local_21 += vec[2] * rij[1];
sigma_local_22 += vec[2] * rij[2];
}
}
else {
const float rcov_sum = rcov[type[iat]] + rcov[type[jat]];
const float dc6i_sum = dc6i[iat] + dc6i[jat];
for (int k = maxtau - 1; k >= 0; k -= 3) {
const int idx1 = tau_idx_cn[k-2];
const int idx2 = tau_idx_cn[k-1];
const int idx3 = tau_idx_cn[k];
const float rij[3] = {
x[jat][0] - x[iat][0] + tau_cn[idx1][idx2][idx3][0],
x[jat][1] - x[iat][1] + tau_cn[idx1][idx2][idx3][1],
x[jat][2] - x[iat][2] + tau_cn[idx1][idx2][idx3][2]
};
const float r2 = lensq3(rij);
if (r2 >= cnthr) { continue; }
const float r_rc = rsqrtf(r2);
const float expterm = expf(-K1 * (rcov_sum * r_rc - 1.0f));
const float unit_rc = 1.0f / (r2 * (expterm + 1.0f) * (expterm + 1.0f));
const float dcnn = -K1 * rcov_sum * expterm * unit_rc;
const float x1 = dcnn * dc6i_sum;
const float vec[3] = {
x1 * rij[0] * r_rc,
x1 * rij[1] * r_rc,
x1 * rij[2] * r_rc
};
f_local[0] -= vec[0];
f_local[1] -= vec[1];
f_local[2] -= vec[2];
sigma_local_00 += vec[0] * rij[0];
sigma_local_01 += vec[0] * rij[1];
sigma_local_02 += vec[0] * rij[2];
sigma_local_10 += vec[1] * rij[0];
sigma_local_11 += vec[1] * rij[1];
sigma_local_12 += vec[1] * rij[2];
sigma_local_20 += vec[2] * rij[0];
sigma_local_21 += vec[2] * rij[1];
sigma_local_22 += vec[2] * rij[2];
}
atomicAdd(&f[iat][0], f_local[0]);
atomicAdd(&f[iat][1], f_local[1]);
atomicAdd(&f[iat][2], f_local[2]);
atomicAdd(&f[jat][0], -f_local[0]);
atomicAdd(&f[jat][1], -f_local[1]);
atomicAdd(&f[jat][2], -f_local[2]);
}
}
sigma_00[threadIdx.x] = sigma_local_00;
sigma_01[threadIdx.x] = sigma_local_01;
sigma_02[threadIdx.x] = sigma_local_02;
sigma_10[threadIdx.x] = sigma_local_10;
sigma_11[threadIdx.x] = sigma_local_11;
sigma_12[threadIdx.x] = sigma_local_12;
sigma_20[threadIdx.x] = sigma_local_20;
sigma_21[threadIdx.x] = sigma_local_21;
sigma_22[threadIdx.x] = sigma_local_22;
__syncthreads();
for (int s=blockDim.x/2; s>0; s>>=1) {
if (threadIdx.x < s) {
sigma_00[threadIdx.x] += sigma_00[threadIdx.x + s];
sigma_01[threadIdx.x] += sigma_01[threadIdx.x + s];
sigma_02[threadIdx.x] += sigma_02[threadIdx.x + s];
sigma_10[threadIdx.x] += sigma_10[threadIdx.x + s];
sigma_11[threadIdx.x] += sigma_11[threadIdx.x + s];
sigma_12[threadIdx.x] += sigma_12[threadIdx.x + s];
sigma_20[threadIdx.x] += sigma_20[threadIdx.x + s];
sigma_21[threadIdx.x] += sigma_21[threadIdx.x + s];
sigma_22[threadIdx.x] += sigma_22[threadIdx.x + s];
}
__syncthreads();
}
if (threadIdx.x == 0) {
atomicAdd(&sigma[0][0], sigma_00[0]);
atomicAdd(&sigma[0][1], sigma_01[0]);
atomicAdd(&sigma[0][2], sigma_02[0]);
atomicAdd(&sigma[1][0], sigma_10[0]);
atomicAdd(&sigma[1][1], sigma_11[0]);
atomicAdd(&sigma[1][2], sigma_12[0]);
atomicAdd(&sigma[2][0], sigma_20[0]);
atomicAdd(&sigma[2][1], sigma_21[0]);
atomicAdd(&sigma[2][2], sigma_22[0]);
}
}
void PairD3::get_forces_with_dC6() {
int n = atom->natoms;
int maxij = n * (n + 1) / 2;
int maxtau = tau_idx_cn_total_size;
//START_CUDA_TIMER();
int threadsPerBlock = 128;
int blocksPerGrid = (maxij + threadsPerBlock - 1) / threadsPerBlock;
kernel_get_forces_with_dC6<<<blocksPerGrid, threadsPerBlock>>>(
maxij, maxtau, cnthr, K1,
dc6i, rcov, rep_cn, tau_cn, tau_idx_cn, atomtype, x,
f, sigma
);
cudaDeviceSynchronize();
//STOP_CUDA_TIMER("get_forces_with");
}
/* ----------------------------------------------------------------------
Update energy, force, and stress
------------------------------------------------------------------------- */
void PairD3::update(int eflag, int vflag) {
int n = atom->natoms;
if (eflag) { eng_vdwl += disp_total * AU_TO_EV; } // Energy update
double** f_local = atom->f; // Force update
for (int i = 0; i < n; i++) {
for (int j = 0; j < 3; j++) {
f_local[i][j] += f[i][j] * AU_TO_EV / AU_TO_ANG;
}
}
if (vflag) {
virial[0] += sigma[0][0] * AU_TO_EV;
virial[1] += sigma[1][1] * AU_TO_EV;
virial[2] += sigma[2][2] * AU_TO_EV;
virial[3] += sigma[0][1] * AU_TO_EV;
virial[4] += sigma[0][2] * AU_TO_EV;
virial[5] += sigma[1][2] * AU_TO_EV;
} // Stress update
}
/* ----------------------------------------------------------------------
Compute : energy, force, and stress (Required)
------------------------------------------------------------------------- */
void PairD3::compute(int eflag, int vflag) {
if (eflag || vflag) { ev_setup(eflag, vflag); }
if (atom->natoms != n_save) { reallocate_arrays(); }
set_lattice_vectors();
precalculate_tau_array();
load_atom_info();
cudaMemcpy(atomtype, atom->type, atom->natoms * sizeof(int), cudaMemcpyHostToDevice);
get_coordination_number();
get_dC6_dCNij();
get_forces_without_dC6();
get_forces_with_dC6();
update(eflag, vflag);
CHECK_CUDA_ERROR();
}
/* ----------------------------------------------------------------------
init for one type pair i,j and corresponding j,i
------------------------------------------------------------------------- */
double PairD3::init_one(int i, int j) {
if (setflag[i][j] == 0) error->all(FLERR, "All pair coeffs are not set");
// No need to count local neighbor in D3
/* return std::sqrt(rthr * std::pow(au_to_ang, 2)); */
return 0.0;
}
/* ----------------------------------------------------------------------
init specific to this pair style (Optional)
------------------------------------------------------------------------- */
void PairD3::init_style() {
neighbor->add_request(this, NeighConst::REQ_FULL);
}
/* ----------------------------------------------------------------------
proc 0 writes to restart file
------------------------------------------------------------------------- */
void PairD3::write_restart(FILE *fp) {}
/* ----------------------------------------------------------------------
proc 0 reads from restart file, bcasts
------------------------------------------------------------------------- */
void PairD3::read_restart(FILE *fp) {}
/* ----------------------------------------------------------------------
proc 0 writes to restart file
------------------------------------------------------------------------- */
void PairD3::write_restart_settings(FILE *fp) {}
/* ----------------------------------------------------------------------
proc 0 reads from restart file, bcasts
------------------------------------------------------------------------- */
void PairD3::read_restart_settings(FILE *fp) {}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment