Unverified Commit 251f5af2 authored by zcxzcx1, committed by GitHub

Add files via upload

parent 73ff4f3a
###########################################################################################
# Radial basis and cutoff
# Authors: Ilyes Batatia, Gregor Simm
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
import logging
import ase
import numpy as np
import torch
from e3nn.util.jit import compile_mode
from mace.tools.scatter import scatter_sum
@compile_mode("script")
class BesselBasis(torch.nn.Module):
"""
Equation (7)
"""
def __init__(self, r_max: float, num_basis=8, trainable=False):
super().__init__()
bessel_weights = (
np.pi
/ r_max
* torch.linspace(
start=1.0,
end=num_basis,
steps=num_basis,
dtype=torch.get_default_dtype(),
)
)
if trainable:
self.bessel_weights = torch.nn.Parameter(bessel_weights)
else:
self.register_buffer("bessel_weights", bessel_weights)
self.register_buffer(
"r_max", torch.tensor(r_max, dtype=torch.get_default_dtype())
)
self.register_buffer(
"prefactor",
torch.tensor(np.sqrt(2.0 / r_max), dtype=torch.get_default_dtype()),
)
def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1]
numerator = torch.sin(self.bessel_weights * x) # [..., num_basis]
return self.prefactor * (numerator / x)
def __repr__(self):
return (
f"{self.__class__.__name__}(r_max={self.r_max}, num_basis={len(self.bessel_weights)}, "
f"trainable={self.bessel_weights.requires_grad})"
)
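# Illustrative usage (a sketch, not part of the upstream file; shapes follow
# the comments in forward):
#
#     basis = BesselBasis(r_max=5.0, num_basis=8)
#     lengths = 0.1 + 4.9 * torch.rand(32, 1)  # [n_edges, 1], strictly > 0
#     feats = basis(lengths)                   # [n_edges, 8]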
@compile_mode("script")
class ChebychevBasis(torch.nn.Module):
"""
Equation (7)
"""
def __init__(self, r_max: float, num_basis=8):
super().__init__()
self.register_buffer(
"n",
torch.arange(1, num_basis + 1, dtype=torch.get_default_dtype()).unsqueeze(
0
),
)
self.num_basis = num_basis
self.r_max = r_max
def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1]
x = x.repeat(1, self.num_basis)
n = self.n.repeat(len(x), 1)
return torch.special.chebyshev_polynomial_t(x, n)
    def __repr__(self):
        return f"{self.__class__.__name__}(r_max={self.r_max}, num_basis={self.num_basis})"
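# Sketch of usage (illustrative): chebyshev_polynomial_t is well-behaved on
# [-1, 1], so callers would typically rescale distances before embedding.
#
#     basis = ChebychevBasis(r_max=5.0, num_basis=8)
#     x = torch.linspace(-1.0, 1.0, 32).unsqueeze(-1)  # [n_edges, 1]
#     feats = basis(x)                                 # [n_edges, 8]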
@compile_mode("script")
class GaussianBasis(torch.nn.Module):
"""
Gaussian basis functions
"""
def __init__(self, r_max: float, num_basis=128, trainable=False):
super().__init__()
gaussian_weights = torch.linspace(
start=0.0, end=r_max, steps=num_basis, dtype=torch.get_default_dtype()
)
if trainable:
self.gaussian_weights = torch.nn.Parameter(
gaussian_weights, requires_grad=True
)
else:
self.register_buffer("gaussian_weights", gaussian_weights)
self.coeff = -0.5 / (r_max / (num_basis - 1)) ** 2
def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1]
x = x - self.gaussian_weights
return torch.exp(self.coeff * torch.pow(x, 2))
@compile_mode("script")
class PolynomialCutoff(torch.nn.Module):
"""Polynomial cutoff function that goes from 1 to 0 as x goes from 0 to r_max.
Equation (8) -- TODO: from where?
"""
p: torch.Tensor
r_max: torch.Tensor
def __init__(self, r_max: float, p=6):
super().__init__()
self.register_buffer("p", torch.tensor(p, dtype=torch.int))
self.register_buffer(
"r_max", torch.tensor(r_max, dtype=torch.get_default_dtype())
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.calculate_envelope(x, self.r_max, self.p.to(torch.int))
@staticmethod
def calculate_envelope(
x: torch.Tensor, r_max: torch.Tensor, p: torch.Tensor
) -> torch.Tensor:
r_over_r_max = x / r_max
envelope = (
1.0
- ((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(r_over_r_max, p)
+ p * (p + 2.0) * torch.pow(r_over_r_max, p + 1)
- (p * (p + 1.0) / 2) * torch.pow(r_over_r_max, p + 2)
)
return envelope * (x < r_max)
def __repr__(self):
return f"{self.__class__.__name__}(p={self.p}, r_max={self.r_max})"
@compile_mode("script")
class ZBLBasis(torch.nn.Module):
"""Implementation of the Ziegler-Biersack-Littmark (ZBL) potential
with a polynomial cutoff envelope.
"""
p: torch.Tensor
def __init__(self, p=6, trainable=False, **kwargs):
super().__init__()
if "r_max" in kwargs:
logging.warning(
"r_max is deprecated. r_max is determined from the covalent radii."
)
# Pre-calculate the p coefficients for the ZBL potential
self.register_buffer(
"c",
torch.tensor(
[0.1818, 0.5099, 0.2802, 0.02817], dtype=torch.get_default_dtype()
),
)
self.register_buffer("p", torch.tensor(p, dtype=torch.int))
self.register_buffer(
"covalent_radii",
torch.tensor(
ase.data.covalent_radii,
dtype=torch.get_default_dtype(),
),
)
if trainable:
self.a_exp = torch.nn.Parameter(torch.tensor(0.300, requires_grad=True))
self.a_prefactor = torch.nn.Parameter(
torch.tensor(0.4543, requires_grad=True)
)
else:
self.register_buffer("a_exp", torch.tensor(0.300))
self.register_buffer("a_prefactor", torch.tensor(0.4543))
def forward(
self,
x: torch.Tensor,
node_attrs: torch.Tensor,
edge_index: torch.Tensor,
atomic_numbers: torch.Tensor,
) -> torch.Tensor:
sender = edge_index[0]
receiver = edge_index[1]
node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze(
-1
)
Z_u = node_atomic_numbers[sender]
Z_v = node_atomic_numbers[receiver]
a = (
self.a_prefactor
* 0.529
/ (torch.pow(Z_u, self.a_exp) + torch.pow(Z_v, self.a_exp))
)
r_over_a = x / a
phi = (
self.c[0] * torch.exp(-3.2 * r_over_a)
+ self.c[1] * torch.exp(-0.9423 * r_over_a)
+ self.c[2] * torch.exp(-0.4028 * r_over_a)
+ self.c[3] * torch.exp(-0.2016 * r_over_a)
)
v_edges = (14.3996 * Z_u * Z_v) / x * phi
r_max = self.covalent_radii[Z_u] + self.covalent_radii[Z_v]
envelope = PolynomialCutoff.calculate_envelope(x, r_max, self.p)
v_edges = 0.5 * v_edges * envelope
V_ZBL = scatter_sum(v_edges, receiver, dim=0, dim_size=node_attrs.size(0))
return V_ZBL.squeeze(-1)
def __repr__(self):
return f"{self.__class__.__name__}(c={self.c})"
@compile_mode("script")
class AgnesiTransform(torch.nn.Module):
"""Agnesi transform - see section on Radial transformations in
ACEpotentials.jl, JCP 2023 (https://doi.org/10.1063/5.0158783).
"""
def __init__(
self,
q: float = 0.9183,
p: float = 4.5791,
a: float = 1.0805,
trainable=False,
):
super().__init__()
self.register_buffer("q", torch.tensor(q, dtype=torch.get_default_dtype()))
self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype()))
self.register_buffer("a", torch.tensor(a, dtype=torch.get_default_dtype()))
self.register_buffer(
"covalent_radii",
torch.tensor(
ase.data.covalent_radii,
dtype=torch.get_default_dtype(),
),
)
        if trainable:
            # Use the passed-in defaults rather than re-hardcoding them
            self.a = torch.nn.Parameter(torch.tensor(a, requires_grad=True))
            self.q = torch.nn.Parameter(torch.tensor(q, requires_grad=True))
            self.p = torch.nn.Parameter(torch.tensor(p, requires_grad=True))
def forward(
self,
x: torch.Tensor,
node_attrs: torch.Tensor,
edge_index: torch.Tensor,
atomic_numbers: torch.Tensor,
) -> torch.Tensor:
sender = edge_index[0]
receiver = edge_index[1]
node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze(
-1
)
Z_u = node_atomic_numbers[sender]
Z_v = node_atomic_numbers[receiver]
r_0: torch.Tensor = 0.5 * (self.covalent_radii[Z_u] + self.covalent_radii[Z_v])
r_over_r_0 = x / r_0
return (
1
+ (
self.a
* torch.pow(r_over_r_0, self.q)
/ (1 + torch.pow(r_over_r_0, self.q - self.p))
)
).reciprocal_()
def __repr__(self):
return (
f"{self.__class__.__name__}(a={self.a:.4f}, q={self.q:.4f}, p={self.p:.4f})"
)
@compile_mode("script")
class SoftTransform(torch.nn.Module):
"""
Tanh-based smooth transformation:
T(x) = p1 + (x - p1)*0.5*[1 + tanh(alpha*(x - m))],
which smoothly transitions from ~p1 for x << p1 to ~x for x >> r0.
"""
def __init__(self, alpha: float = 4.0, trainable=False):
"""
Args:
p1 (float): Lower "clamp" point.
alpha (float): Steepness; if None, defaults to ~6/(r0-p1).
trainable (bool): Whether to make parameters trainable.
"""
super().__init__()
# Initialize parameters
self.register_buffer(
"alpha", torch.tensor(alpha, dtype=torch.get_default_dtype())
)
if trainable:
self.alpha = torch.nn.Parameter(self.alpha.clone())
self.register_buffer(
"covalent_radii",
torch.tensor(
ase.data.covalent_radii,
dtype=torch.get_default_dtype(),
),
)
def compute_r_0(
self,
node_attrs: torch.Tensor,
edge_index: torch.Tensor,
atomic_numbers: torch.Tensor,
) -> torch.Tensor:
"""
Compute r_0 based on atomic information.
Args:
node_attrs (torch.Tensor): Node attributes (one-hot encoding of atomic numbers).
edge_index (torch.Tensor): Edge index indicating connections.
atomic_numbers (torch.Tensor): Atomic numbers.
Returns:
torch.Tensor: r_0 values for each edge.
"""
sender = edge_index[0]
receiver = edge_index[1]
node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze(
-1
)
Z_u = node_atomic_numbers[sender]
Z_v = node_atomic_numbers[receiver]
r_0: torch.Tensor = self.covalent_radii[Z_u] + self.covalent_radii[Z_v]
return r_0
def forward(
self,
x: torch.Tensor,
node_attrs: torch.Tensor,
edge_index: torch.Tensor,
atomic_numbers: torch.Tensor,
) -> torch.Tensor:
r_0 = self.compute_r_0(node_attrs, edge_index, atomic_numbers)
p_0 = (3 / 4) * r_0
p_1 = (4 / 3) * r_0
m = 0.5 * (p_0 + p_1)
alpha = self.alpha / (p_1 - p_0)
s_x = 0.5 * (1.0 + torch.tanh(alpha * (x - m)))
return p_0 + (x - p_0) * s_x
def __repr__(self):
return f"{self.__class__.__name__}(alpha={self.alpha.item():.4f})"
###########################################################################################
# Implementation of the symmetric contraction algorithm presented in the MACE paper
# (Batatia et al., MACE: Higher Order Equivariant Message Passing Neural Networks for Fast and Accurate Force Fields, Eqs. 10 and 11)
# Authors: Ilyes Batatia
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
from typing import Dict, Optional, Union
import opt_einsum_fx
import torch
import torch.fx
from e3nn import o3
from e3nn.util.codegen import CodeGenMixin
from e3nn.util.jit import compile_mode
from mace.tools.cg import U_matrix_real
BATCH_EXAMPLE = 10
ALPHABET = ["w", "x", "v", "n", "z", "r", "t", "y", "u", "o", "p", "s"]
@compile_mode("script")
class SymmetricContraction(CodeGenMixin, torch.nn.Module):
def __init__(
self,
irreps_in: o3.Irreps,
irreps_out: o3.Irreps,
correlation: Union[int, Dict[str, int]],
irrep_normalization: str = "component",
path_normalization: str = "element",
internal_weights: Optional[bool] = None,
shared_weights: Optional[bool] = None,
num_elements: Optional[int] = None,
) -> None:
super().__init__()
if irrep_normalization is None:
irrep_normalization = "component"
if path_normalization is None:
path_normalization = "element"
assert irrep_normalization in ["component", "norm", "none"]
assert path_normalization in ["element", "path", "none"]
self.irreps_in = o3.Irreps(irreps_in)
self.irreps_out = o3.Irreps(irreps_out)
del irreps_in, irreps_out
        if not isinstance(correlation, dict):
corr = correlation
correlation = {}
for irrep_out in self.irreps_out:
correlation[irrep_out] = corr
assert shared_weights or not internal_weights
if internal_weights is None:
internal_weights = True
self.internal_weights = internal_weights
self.shared_weights = shared_weights
del internal_weights, shared_weights
self.contractions = torch.nn.ModuleList()
for irrep_out in self.irreps_out:
self.contractions.append(
Contraction(
irreps_in=self.irreps_in,
irrep_out=o3.Irreps(str(irrep_out.ir)),
correlation=correlation[irrep_out],
internal_weights=self.internal_weights,
num_elements=num_elements,
weights=self.shared_weights,
)
)
def forward(self, x: torch.Tensor, y: torch.Tensor):
outs = [contraction(x, y) for contraction in self.contractions]
return torch.cat(outs, dim=-1)
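# Sketch of the expected shapes (inferred from the einsum subscripts "bci" and
# "be" in Contraction below; illustrative values): x holds node features as
# [n_nodes, channels, (l_max + 1) ** 2] and y is the one-hot element encoding
# [n_nodes, num_elements].
#
#     sc = SymmetricContraction(
#         irreps_in=o3.Irreps("32x0e + 32x1o"),
#         irreps_out=o3.Irreps("32x0e"),
#         correlation=3,
#         num_elements=2,
#     )
#     out = sc(x, y)  # [n_nodes, dim(irreps_out)]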
@compile_mode("script")
class Contraction(torch.nn.Module):
def __init__(
self,
irreps_in: o3.Irreps,
irrep_out: o3.Irreps,
correlation: int,
internal_weights: bool = True,
num_elements: Optional[int] = None,
weights: Optional[torch.Tensor] = None,
) -> None:
super().__init__()
self.num_features = irreps_in.count((0, 1))
self.coupling_irreps = o3.Irreps([irrep.ir for irrep in irreps_in])
self.correlation = correlation
dtype = torch.get_default_dtype()
for nu in range(1, correlation + 1):
U_matrix = U_matrix_real(
irreps_in=self.coupling_irreps,
irreps_out=irrep_out,
correlation=nu,
dtype=dtype,
)[-1]
self.register_buffer(f"U_matrix_{nu}", U_matrix)
# Tensor contraction equations
self.contractions_weighting = torch.nn.ModuleList()
self.contractions_features = torch.nn.ModuleList()
# Create weight for product basis
self.weights = torch.nn.ParameterList([])
for i in range(correlation, 0, -1):
            # Define shapes
num_params = self.U_tensors(i).size()[-1]
num_equivariance = 2 * irrep_out.lmax + 1
num_ell = self.U_tensors(i).size()[-2]
if i == correlation:
parse_subscript_main = (
[ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1) - 1)]
+ ["ik,ekc,bci,be -> bc"]
+ [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1) - 1)]
)
graph_module_main = torch.fx.symbolic_trace(
lambda x, y, w, z: torch.einsum(
"".join(parse_subscript_main), x, y, w, z
)
)
# Optimizing the contractions
self.graph_opt_main = opt_einsum_fx.optimize_einsums_full(
model=graph_module_main,
example_inputs=(
torch.randn(
[num_equivariance] + [num_ell] * i + [num_params]
).squeeze(0),
torch.randn((num_elements, num_params, self.num_features)),
torch.randn((BATCH_EXAMPLE, self.num_features, num_ell)),
torch.randn((BATCH_EXAMPLE, num_elements)),
),
)
# Parameters for the product basis
w = torch.nn.Parameter(
torch.randn((num_elements, num_params, self.num_features))
/ num_params
)
self.weights_max = w
else:
# Generate optimized contractions equations
parse_subscript_weighting = (
[ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1))]
+ ["k,ekc,be->bc"]
+ [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1))]
)
parse_subscript_features = (
["bc"]
+ [ALPHABET[j] for j in range(i - 1 + min(irrep_out.lmax, 1))]
+ ["i,bci->bc"]
+ [ALPHABET[j] for j in range(i - 1 + min(irrep_out.lmax, 1))]
)
# Symbolic tracing of contractions
graph_module_weighting = torch.fx.symbolic_trace(
lambda x, y, z: torch.einsum(
"".join(parse_subscript_weighting), x, y, z
)
)
graph_module_features = torch.fx.symbolic_trace(
lambda x, y: torch.einsum("".join(parse_subscript_features), x, y)
)
# Optimizing the contractions
graph_opt_weighting = opt_einsum_fx.optimize_einsums_full(
model=graph_module_weighting,
example_inputs=(
torch.randn(
[num_equivariance] + [num_ell] * i + [num_params]
).squeeze(0),
torch.randn((num_elements, num_params, self.num_features)),
torch.randn((BATCH_EXAMPLE, num_elements)),
),
)
graph_opt_features = opt_einsum_fx.optimize_einsums_full(
model=graph_module_features,
example_inputs=(
torch.randn(
[BATCH_EXAMPLE, self.num_features, num_equivariance]
+ [num_ell] * i
).squeeze(2),
torch.randn((BATCH_EXAMPLE, self.num_features, num_ell)),
),
)
self.contractions_weighting.append(graph_opt_weighting)
self.contractions_features.append(graph_opt_features)
# Parameters for the product basis
w = torch.nn.Parameter(
torch.randn((num_elements, num_params, self.num_features))
/ num_params
)
self.weights.append(w)
if not internal_weights:
self.weights = weights[:-1]
self.weights_max = weights[-1]
def forward(self, x: torch.Tensor, y: torch.Tensor):
out = self.graph_opt_main(
self.U_tensors(self.correlation),
self.weights_max,
x,
y,
)
for i, (weight, contract_weights, contract_features) in enumerate(
zip(self.weights, self.contractions_weighting, self.contractions_features)
):
c_tensor = contract_weights(
self.U_tensors(self.correlation - i - 1),
weight,
y,
)
c_tensor = c_tensor + out
out = contract_features(c_tensor, x)
return out.view(out.shape[0], -1)
def U_tensors(self, nu: int):
return dict(self.named_buffers())[f"U_matrix_{nu}"]
###########################################################################################
# Utilities
# Authors: Ilyes Batatia, Gregor Simm and David Kovacs
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
import logging
from typing import Dict, List, NamedTuple, Optional, Tuple
import numpy as np
import torch
import torch.utils.data
from scipy.constants import c, e
from mace.tools import to_numpy
from mace.tools.scatter import scatter_mean, scatter_std, scatter_sum
from mace.tools.torch_geometric.batch import Batch
from .blocks import AtomicEnergiesBlock
def compute_forces(
energy: torch.Tensor, positions: torch.Tensor, training: bool = True
) -> torch.Tensor:
grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)]
gradient = torch.autograd.grad(
outputs=[energy], # [n_graphs, ]
inputs=[positions], # [n_nodes, 3]
grad_outputs=grad_outputs,
retain_graph=training, # Make sure the graph is not destroyed during training
create_graph=training, # Create graph for second derivative
        allow_unused=True,  # needed when the energy does not depend on positions (complete dissociation)
)[
0
] # [n_nodes, 3]
if gradient is None:
return torch.zeros_like(positions)
return -1 * gradient
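# Sketch: forces are the negative gradient of the energy. With a toy harmonic
# energy E = sum(|r_i|^2) the force on each atom is -2 * r_i:
#
#     positions = torch.randn(4, 3, requires_grad=True)
#     energy = (positions ** 2).sum()
#     forces = compute_forces(energy, positions, training=False)  # -2 * positions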
def compute_forces_virials(
energy: torch.Tensor,
positions: torch.Tensor,
displacement: torch.Tensor,
cell: torch.Tensor,
training: bool = True,
compute_stress: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)]
forces, virials = torch.autograd.grad(
outputs=[energy], # [n_graphs, ]
inputs=[positions, displacement], # [n_nodes, 3]
grad_outputs=grad_outputs,
retain_graph=training, # Make sure the graph is not destroyed during training
create_graph=training, # Create graph for second derivative
allow_unused=True,
)
stress = torch.zeros_like(displacement)
if compute_stress and virials is not None:
cell = cell.view(-1, 3, 3)
volume = torch.linalg.det(cell).abs().unsqueeze(-1)
stress = virials / volume.view(-1, 1, 1)
stress = torch.where(torch.abs(stress) < 1e10, stress, torch.zeros_like(stress))
if forces is None:
forces = torch.zeros_like(positions)
if virials is None:
virials = torch.zeros((1, 3, 3))
return -1 * forces, -1 * virials, stress
def get_symmetric_displacement(
positions: torch.Tensor,
unit_shifts: torch.Tensor,
cell: Optional[torch.Tensor],
edge_index: torch.Tensor,
num_graphs: int,
batch: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if cell is None:
cell = torch.zeros(
num_graphs * 3,
3,
dtype=positions.dtype,
device=positions.device,
)
sender = edge_index[0]
displacement = torch.zeros(
(num_graphs, 3, 3),
dtype=positions.dtype,
device=positions.device,
)
displacement.requires_grad_(True)
symmetric_displacement = 0.5 * (
displacement + displacement.transpose(-1, -2)
) # From https://github.com/mir-group/nequip
positions = positions + torch.einsum(
"be,bec->bc", positions, symmetric_displacement[batch]
)
cell = cell.view(-1, 3, 3)
cell = cell + torch.matmul(cell, symmetric_displacement)
shifts = torch.einsum(
"be,bec->bc",
unit_shifts,
cell[batch[sender]],
)
return positions, shifts, displacement
@torch.jit.unused
def compute_hessians_vmap(
forces: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
forces_flatten = forces.view(-1)
num_elements = forces_flatten.shape[0]
def get_vjp(v):
return torch.autograd.grad(
-1 * forces_flatten,
positions,
v,
retain_graph=True,
create_graph=False,
allow_unused=False,
)
I_N = torch.eye(num_elements).to(forces.device)
try:
chunk_size = 1 if num_elements < 64 else 16
gradient = torch.vmap(get_vjp, in_dims=0, out_dims=0, chunk_size=chunk_size)(
I_N
)[0]
except RuntimeError:
gradient = compute_hessians_loop(forces, positions)
if gradient is None:
return torch.zeros((positions.shape[0], forces.shape[0], 3, 3))
return gradient
@torch.jit.unused
def compute_hessians_loop(
forces: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
hessian = []
for grad_elem in forces.view(-1):
hess_row = torch.autograd.grad(
outputs=[-1 * grad_elem],
inputs=[positions],
grad_outputs=torch.ones_like(grad_elem),
retain_graph=True,
create_graph=False,
allow_unused=False,
)[0]
        if hess_row is None:
            hessian.append(torch.zeros_like(positions))
        else:
            # Detaching trades differentiability of the Hessian for lower memory use
            hessian.append(hess_row.detach())
hessian = torch.stack(hessian)
return hessian
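# Note: each hess_row above is d(-force component)/d(positions) with shape
# [n_atoms, 3], so the stacked Hessian has shape [3 * n_atoms, n_atoms, 3].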
def get_outputs(
energy: torch.Tensor,
positions: torch.Tensor,
cell: torch.Tensor,
displacement: Optional[torch.Tensor],
vectors: Optional[torch.Tensor] = None,
training: bool = False,
compute_force: bool = True,
compute_virials: bool = True,
compute_stress: bool = True,
compute_hessian: bool = False,
compute_edge_forces: bool = False,
) -> Tuple[
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[torch.Tensor],
]:
if (compute_virials or compute_stress) and displacement is not None:
forces, virials, stress = compute_forces_virials(
energy=energy,
positions=positions,
displacement=displacement,
cell=cell,
compute_stress=compute_stress,
training=(training or compute_hessian or compute_edge_forces),
)
elif compute_force:
forces, virials, stress = (
compute_forces(
energy=energy,
positions=positions,
training=(training or compute_hessian or compute_edge_forces),
),
None,
None,
)
else:
forces, virials, stress = (None, None, None)
if compute_hessian:
assert forces is not None, "Forces must be computed to get the hessian"
hessian = compute_hessians_vmap(forces, positions)
else:
hessian = None
if compute_edge_forces and vectors is not None:
edge_forces = compute_forces(
energy=energy,
positions=vectors,
training=(training or compute_hessian),
)
if edge_forces is not None:
edge_forces = -1 * edge_forces # Match LAMMPS sign convention
else:
edge_forces = None
return forces, virials, stress, hessian, edge_forces
def get_atomic_virials_stresses(
edge_forces: torch.Tensor, # [n_edges, 3]
edge_index: torch.Tensor, # [2, n_edges]
vectors: torch.Tensor, # [n_edges, 3]
num_atoms: int,
batch: torch.Tensor,
cell: torch.Tensor, # [n_graphs, 3, 3]
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Compute atomic virials and optionally atomic stresses from edge forces and vectors.
From pobo95 PR #528.
Returns:
Tuple of:
- Atomic virials [num_atoms, 3, 3]
- Atomic stresses [num_atoms, 3, 3] (None if not computed)
"""
edge_virial = torch.einsum("zi,zj->zij", edge_forces, vectors)
atom_virial_sender = scatter_sum(
src=edge_virial, index=edge_index[0], dim=0, dim_size=num_atoms
)
atom_virial_receiver = scatter_sum(
src=edge_virial, index=edge_index[1], dim=0, dim_size=num_atoms
)
atom_virial = (atom_virial_sender + atom_virial_receiver) / 2
atom_virial = (atom_virial + atom_virial.transpose(-1, -2)) / 2
atom_stress = None
cell = cell.view(-1, 3, 3)
volume = torch.linalg.det(cell).abs().unsqueeze(-1)
atom_volume = volume[batch].view(-1, 1, 1)
atom_stress = atom_virial / atom_volume
atom_stress = torch.where(
torch.abs(atom_stress) < 1e10, atom_stress, torch.zeros_like(atom_stress)
)
return -1 * atom_virial, atom_stress
def get_edge_vectors_and_lengths(
positions: torch.Tensor, # [n_nodes, 3]
edge_index: torch.Tensor, # [2, n_edges]
shifts: torch.Tensor, # [n_edges, 3]
normalize: bool = False,
eps: float = 1e-9,
) -> Tuple[torch.Tensor, torch.Tensor]:
sender = edge_index[0]
receiver = edge_index[1]
vectors = positions[receiver] - positions[sender] + shifts # [n_edges, 3]
lengths = torch.linalg.norm(vectors, dim=-1, keepdim=True) # [n_edges, 1]
if normalize:
vectors_normed = vectors / (lengths + eps)
return vectors_normed, lengths
return vectors, lengths
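# Sketch (minimal non-periodic example):
#
#     positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
#     edge_index = torch.tensor([[0, 1], [1, 0]])
#     shifts = torch.zeros(2, 3)  # no periodic images
#     vectors, lengths = get_edge_vectors_and_lengths(positions, edge_index, shifts)
#     # vectors: [2, 3]; lengths: [2, 1], all lengths equal to 1.0 here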
def _check_non_zero(std):
if np.any(std == 0):
        logging.warning(
            "Standard deviation of the scaling is zero; changing to no scaling"
        )
std[std == 0] = 1
return std
def extract_invariant(x: torch.Tensor, num_layers: int, num_features: int, l_max: int):
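    # The first num_features columns of each layer's feature block are its
    # scalar (l = 0) channels; layer i's block starts at offset
    # i * (l_max + 1) ** 2 * num_features.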
out = []
out.append(x[:, :num_features])
for i in range(1, num_layers):
out.append(
x[
:,
i
* (l_max + 1) ** 2
* num_features : (i * (l_max + 1) ** 2 + 1)
* num_features,
]
)
return torch.cat(out, dim=-1)
def compute_mean_std_atomic_inter_energy(
data_loader: torch.utils.data.DataLoader,
atomic_energies: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies)
avg_atom_inter_es_list = []
head_list = []
for batch in data_loader:
node_e0 = atomic_energies_fn(batch.node_attrs)
graph_e0s = scatter_sum(
src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs
)[torch.arange(batch.num_graphs), batch.head]
graph_sizes = batch.ptr[1:] - batch.ptr[:-1]
avg_atom_inter_es_list.append(
(batch.energy - graph_e0s) / graph_sizes
) # {[n_graphs], }
head_list.append(batch.head)
avg_atom_inter_es = torch.cat(avg_atom_inter_es_list) # [total_n_graphs]
head = torch.cat(head_list, dim=0) # [total_n_graphs]
# mean = to_numpy(torch.mean(avg_atom_inter_es)).item()
# std = to_numpy(torch.std(avg_atom_inter_es)).item()
mean = to_numpy(scatter_mean(src=avg_atom_inter_es, index=head, dim=0).squeeze(-1))
std = to_numpy(scatter_std(src=avg_atom_inter_es, index=head, dim=0).squeeze(-1))
std = _check_non_zero(std)
return mean, std
def _compute_mean_std_atomic_inter_energy(
batch: Batch,
atomic_energies_fn: AtomicEnergiesBlock,
) -> torch.Tensor:
head = batch.head
node_e0 = atomic_energies_fn(batch.node_attrs)
graph_e0s = scatter_sum(
src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs
)[torch.arange(batch.num_graphs), head]
graph_sizes = batch.ptr[1:] - batch.ptr[:-1]
atom_energies = (batch.energy - graph_e0s) / graph_sizes
return atom_energies
def compute_mean_rms_energy_forces(
data_loader: torch.utils.data.DataLoader,
atomic_energies: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies)
atom_energy_list = []
forces_list = []
head_list = []
head_batch = []
for batch in data_loader:
head = batch.head
node_e0 = atomic_energies_fn(batch.node_attrs)
graph_e0s = scatter_sum(
src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs
)[torch.arange(batch.num_graphs), head]
graph_sizes = batch.ptr[1:] - batch.ptr[:-1]
atom_energy_list.append(
(batch.energy - graph_e0s) / graph_sizes
) # {[n_graphs], }
forces_list.append(batch.forces) # {[n_graphs*n_atoms,3], }
head_list.append(head)
head_batch.append(head[batch.batch])
atom_energies = torch.cat(atom_energy_list, dim=0) # [total_n_graphs]
forces = torch.cat(forces_list, dim=0) # {[total_n_graphs*n_atoms,3], }
head = torch.cat(head_list, dim=0) # [total_n_graphs]
head_batch = torch.cat(head_batch, dim=0) # [total_n_graphs]
# mean = to_numpy(torch.mean(atom_energies)).item()
# rms = to_numpy(torch.sqrt(torch.mean(torch.square(forces)))).item()
mean = to_numpy(scatter_mean(src=atom_energies, index=head, dim=0).squeeze(-1))
rms = to_numpy(
torch.sqrt(
scatter_mean(src=torch.square(forces), index=head_batch, dim=0).mean(-1)
)
)
rms = _check_non_zero(rms)
return mean, rms
def _compute_mean_rms_energy_forces(
batch: Batch,
atomic_energies_fn: AtomicEnergiesBlock,
) -> Tuple[torch.Tensor, torch.Tensor]:
head = batch.head
node_e0 = atomic_energies_fn(batch.node_attrs)
graph_e0s = scatter_sum(
src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs
)[torch.arange(batch.num_graphs), head]
graph_sizes = batch.ptr[1:] - batch.ptr[:-1]
atom_energies = (batch.energy - graph_e0s) / graph_sizes # {[n_graphs], }
forces = batch.forces # {[n_graphs*n_atoms,3], }
return atom_energies, forces
def compute_avg_num_neighbors(data_loader: torch.utils.data.DataLoader) -> float:
num_neighbors = []
for batch in data_loader:
_, receivers = batch.edge_index
_, counts = torch.unique(receivers, return_counts=True)
num_neighbors.append(counts)
avg_num_neighbors = torch.mean(
torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype())
)
return to_numpy(avg_num_neighbors).item()
def compute_statistics(
data_loader: torch.utils.data.DataLoader,
atomic_energies: np.ndarray,
) -> Tuple[float, np.ndarray, np.ndarray]:
atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies)
atom_energy_list = []
forces_list = []
num_neighbors = []
head_list = []
head_batch = []
for batch in data_loader:
head = batch.head
node_e0 = atomic_energies_fn(batch.node_attrs)
graph_e0s = scatter_sum(
src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs
)[torch.arange(batch.num_graphs), head]
graph_sizes = batch.ptr[1:] - batch.ptr[:-1]
atom_energy_list.append(
(batch.energy - graph_e0s) / graph_sizes
) # {[n_graphs], }
forces_list.append(batch.forces) # {[n_graphs*n_atoms,3], }
head_list.append(head) # {[n_graphs], }
head_batch.append(head[batch.batch])
_, receivers = batch.edge_index
_, counts = torch.unique(receivers, return_counts=True)
num_neighbors.append(counts)
atom_energies = torch.cat(atom_energy_list, dim=0) # [total_n_graphs]
forces = torch.cat(forces_list, dim=0) # {[total_n_graphs*n_atoms,3], }
head = torch.cat(head_list, dim=0) # [total_n_graphs]
head_batch = torch.cat(head_batch, dim=0) # [total_n_graphs]
# mean = to_numpy(torch.mean(atom_energies)).item()
mean = to_numpy(scatter_mean(src=atom_energies, index=head, dim=0).squeeze(-1))
rms = to_numpy(
torch.sqrt(
scatter_mean(src=torch.square(forces), index=head_batch, dim=0).mean(-1)
)
)
avg_num_neighbors = torch.mean(
torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype())
)
return to_numpy(avg_num_neighbors).item(), mean, rms
def compute_rms_dipoles(
    data_loader: torch.utils.data.DataLoader,
) -> float:
    dipoles_list = []
    for batch in data_loader:
        dipoles_list.append(batch.dipole)  # {[n_graphs,3], }
    dipoles = torch.cat(dipoles_list, dim=0)  # {[total_n_graphs,3], }
    # Keep a 1-element array so _check_non_zero can replace zeros in place
    rms = to_numpy(torch.sqrt(torch.mean(torch.square(dipoles)))).reshape(1)
    rms = _check_non_zero(rms)
    return rms.item()
def compute_fixed_charge_dipole(
charges: torch.Tensor,
positions: torch.Tensor,
batch: torch.Tensor,
num_graphs: int,
) -> torch.Tensor:
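    # Dividing by (1e-11 / c / e) converts e*Angstrom to Debye (1 e*A ~ 4.803 D)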
mu = positions * charges.unsqueeze(-1) / (1e-11 / c / e) # [N_atoms,3]
return scatter_sum(
src=mu, index=batch.unsqueeze(-1), dim=0, dim_size=num_graphs
) # [N_graphs,3]
class InteractionKwargs(NamedTuple):
lammps_class: Optional[torch.Tensor]
lammps_natoms: Tuple[int, int] = (0, 0)
class GraphContext(NamedTuple):
is_lammps: bool
num_graphs: int
num_atoms_arange: torch.Tensor
displacement: Optional[torch.Tensor]
positions: torch.Tensor
vectors: torch.Tensor
lengths: torch.Tensor
cell: torch.Tensor
node_heads: torch.Tensor
interaction_kwargs: InteractionKwargs
def prepare_graph(
data: Dict[str, torch.Tensor],
compute_virials: bool = False,
compute_stress: bool = False,
compute_displacement: bool = False,
lammps_mliap: bool = False,
) -> GraphContext:
if torch.jit.is_scripting():
lammps_mliap = False
node_heads = (
data["head"][data["batch"]]
if "head" in data
else torch.zeros_like(data["batch"])
)
if lammps_mliap:
n_real, n_total = data["natoms"][0], data["natoms"][1]
num_graphs = 2
num_atoms_arange = torch.arange(n_real, device=data["node_attrs"].device)
displacement = None
positions = torch.zeros(
(int(n_real), 3),
dtype=data["vectors"].dtype,
device=data["vectors"].device,
)
cell = torch.zeros(
(num_graphs, 3, 3),
dtype=data["vectors"].dtype,
device=data["vectors"].device,
)
vectors = data["vectors"].requires_grad_(True)
lengths = torch.linalg.vector_norm(vectors, dim=1, keepdim=True)
ikw = InteractionKwargs(data["lammps_class"], (n_real, n_total))
else:
data["positions"].requires_grad_(True)
positions = data["positions"]
cell = data["cell"]
num_atoms_arange = torch.arange(positions.shape[0], device=positions.device)
num_graphs = int(data["ptr"].numel() - 1)
displacement = torch.zeros(
(num_graphs, 3, 3), dtype=positions.dtype, device=positions.device
)
if compute_virials or compute_stress or compute_displacement:
p, s, displacement = get_symmetric_displacement(
positions=positions,
unit_shifts=data["unit_shifts"],
cell=cell,
edge_index=data["edge_index"],
num_graphs=num_graphs,
batch=data["batch"],
)
data["positions"], data["shifts"] = p, s
vectors, lengths = get_edge_vectors_and_lengths(
positions=data["positions"],
edge_index=data["edge_index"],
shifts=data["shifts"],
)
ikw = InteractionKwargs(None, (0, 0))
return GraphContext(
is_lammps=lammps_mliap,
num_graphs=num_graphs,
num_atoms_arange=num_atoms_arange,
displacement=displacement,
positions=positions,
vectors=vectors,
lengths=lengths,
cell=cell,
node_heads=node_heads,
interaction_kwargs=ikw,
)
"""
Wrapper class for o3.Linear that optionally uses cuet.Linear
"""
import dataclasses
from typing import List, Optional
import torch
from e3nn import o3
from mace.modules.symmetric_contraction import SymmetricContraction
from mace.tools.cg import O3_e3nn
try:
import cuequivariance as cue
import cuequivariance_torch as cuet
CUET_AVAILABLE = True
except ImportError:
CUET_AVAILABLE = False
@dataclasses.dataclass
class CuEquivarianceConfig:
"""Configuration for cuequivariance acceleration"""
enabled: bool = False
layout: str = "mul_ir" # One of: mul_ir, ir_mul
layout_str: str = "mul_ir"
group: str = "O3"
optimize_all: bool = False # Set to True to enable all optimizations
optimize_linear: bool = False
optimize_channelwise: bool = False
optimize_symmetric: bool = False
optimize_fctp: bool = False
def __post_init__(self):
if self.enabled and CUET_AVAILABLE:
self.layout_str = self.layout
self.layout = getattr(cue, self.layout)
self.group = (
O3_e3nn if self.group == "O3_e3nn" else getattr(cue, self.group)
)
if not CUET_AVAILABLE:
self.enabled = False
class Linear:
"""Returns either a cuet.Linear or o3.Linear based on config"""
def __new__(
cls,
irreps_in: o3.Irreps,
irreps_out: o3.Irreps,
shared_weights: bool = True,
internal_weights: bool = True,
cueq_config: Optional[CuEquivarianceConfig] = None,
):
if (
CUET_AVAILABLE
and cueq_config is not None
and cueq_config.enabled
and (cueq_config.optimize_all or cueq_config.optimize_linear)
):
return cuet.Linear(
cue.Irreps(cueq_config.group, irreps_in),
cue.Irreps(cueq_config.group, irreps_out),
layout=cueq_config.layout,
shared_weights=shared_weights,
use_fallback=True,
)
return o3.Linear(
irreps_in,
irreps_out,
shared_weights=shared_weights,
internal_weights=internal_weights,
)
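# Sketch: with cueq_config unset (or cuequivariance not installed) the wrapper
# simply returns the e3nn module:
#
#     linear = Linear(o3.Irreps("16x0e + 16x1o"), o3.Irreps("16x0e + 16x1o"))
#     assert isinstance(linear, o3.Linear)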
class TensorProduct:
"""Wrapper around o3.TensorProduct/cuet.ChannelwiseTensorProduct"""
def __new__(
cls,
irreps_in1: o3.Irreps,
irreps_in2: o3.Irreps,
irreps_out: o3.Irreps,
instructions: Optional[List] = None,
shared_weights: bool = False,
internal_weights: bool = False,
cueq_config: Optional[CuEquivarianceConfig] = None,
):
if (
CUET_AVAILABLE
and cueq_config is not None
and cueq_config.enabled
and (cueq_config.optimize_all or cueq_config.optimize_channelwise)
):
return cuet.ChannelWiseTensorProduct(
cue.Irreps(cueq_config.group, irreps_in1),
cue.Irreps(cueq_config.group, irreps_in2),
cue.Irreps(cueq_config.group, irreps_out),
layout=cueq_config.layout,
shared_weights=shared_weights,
internal_weights=internal_weights,
dtype=torch.get_default_dtype(),
math_dtype=torch.get_default_dtype(),
)
return o3.TensorProduct(
irreps_in1,
irreps_in2,
irreps_out,
instructions=instructions,
shared_weights=shared_weights,
internal_weights=internal_weights,
)
class FullyConnectedTensorProduct:
"""Wrapper around o3.FullyConnectedTensorProduct/cuet.FullyConnectedTensorProduct"""
def __new__(
cls,
irreps_in1: o3.Irreps,
irreps_in2: o3.Irreps,
irreps_out: o3.Irreps,
shared_weights: bool = True,
internal_weights: bool = True,
cueq_config: Optional[CuEquivarianceConfig] = None,
):
if (
CUET_AVAILABLE
and cueq_config is not None
and cueq_config.enabled
and (cueq_config.optimize_all or cueq_config.optimize_fctp)
):
return cuet.FullyConnectedTensorProduct(
cue.Irreps(cueq_config.group, irreps_in1),
cue.Irreps(cueq_config.group, irreps_in2),
cue.Irreps(cueq_config.group, irreps_out),
layout=cueq_config.layout,
shared_weights=shared_weights,
internal_weights=internal_weights,
use_fallback=True,
)
return o3.FullyConnectedTensorProduct(
irreps_in1,
irreps_in2,
irreps_out,
shared_weights=shared_weights,
internal_weights=internal_weights,
)
class SymmetricContractionWrapper:
"""Wrapper around SymmetricContraction/cuet.SymmetricContraction"""
def __new__(
cls,
irreps_in: o3.Irreps,
irreps_out: o3.Irreps,
correlation: int,
num_elements: Optional[int] = None,
cueq_config: Optional[CuEquivarianceConfig] = None,
):
if (
CUET_AVAILABLE
and cueq_config is not None
and cueq_config.enabled
and (cueq_config.optimize_all or cueq_config.optimize_symmetric)
):
return cuet.SymmetricContraction(
cue.Irreps(cueq_config.group, irreps_in),
cue.Irreps(cueq_config.group, irreps_out),
layout_in=cue.ir_mul,
layout_out=cueq_config.layout,
contraction_degree=correlation,
num_elements=num_elements,
original_mace=True,
dtype=torch.get_default_dtype(),
math_dtype=torch.get_default_dtype(),
)
return SymmetricContraction(
irreps_in=irreps_in,
irreps_out=irreps_out,
correlation=correlation,
num_elements=num_elements,
)