Commit 251f5af2 by zcxzcx1, committed via GitHub (parent 73ff4f3a)

Add files via upload
###########################################################################################
# Elementary blocks for building O(3)-equivariant higher-order message passing neural networks
# Authors: Ilyes Batatia, Gregor Simm
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
from abc import abstractmethod
from typing import Any, Callable, List, Optional, Tuple, Union
import numpy as np
import torch.nn.functional
from e3nn import nn, o3
from e3nn.util.jit import compile_mode
from mace.modules.wrapper_ops import (
CuEquivarianceConfig,
FullyConnectedTensorProduct,
Linear,
SymmetricContractionWrapper,
TensorProduct,
)
from mace.tools.compile import simplify_if_compile
from mace.tools.scatter import scatter_sum
from mace.tools.utils import LAMMPS_MP
from .irreps_tools import mask_head, reshape_irreps, tp_out_irreps_with_instructions
from .radial import (
AgnesiTransform,
BesselBasis,
ChebychevBasis,
GaussianBasis,
PolynomialCutoff,
SoftTransform,
)
@compile_mode("script")
class LinearNodeEmbeddingBlock(torch.nn.Module):
def __init__(
self,
irreps_in: o3.Irreps,
irreps_out: o3.Irreps,
cueq_config: Optional[CuEquivarianceConfig] = None,
):
super().__init__()
self.linear = Linear(
irreps_in=irreps_in, irreps_out=irreps_out, cueq_config=cueq_config
)
def forward(
self,
node_attrs: torch.Tensor,
) -> torch.Tensor: # [n_nodes, irreps]
return self.linear(node_attrs)
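# Usage sketch (illustrative, not part of the source above; irreps and sizes
# are assumptions): embed one-hot element attributes into scalar node features.
#     emb = LinearNodeEmbeddingBlock(o3.Irreps("3x0e"), o3.Irreps("8x0e"))
#     node_attrs = torch.eye(3)      # one node per element, one-hot [3, 3]
#     feats = emb(node_attrs)        # [3, 8]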
@compile_mode("script")
class LinearReadoutBlock(torch.nn.Module):
def __init__(
self,
irreps_in: o3.Irreps,
irrep_out: o3.Irreps = o3.Irreps("0e"),
cueq_config: Optional[CuEquivarianceConfig] = None,
):
super().__init__()
self.linear = Linear(
irreps_in=irreps_in, irreps_out=irrep_out, cueq_config=cueq_config
)
def forward(
self,
x: torch.Tensor,
heads: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
) -> torch.Tensor: # [n_nodes, irreps] # [..., ]
return self.linear(x) # [n_nodes, 1]
@simplify_if_compile
@compile_mode("script")
class NonLinearReadoutBlock(torch.nn.Module):
def __init__(
self,
irreps_in: o3.Irreps,
MLP_irreps: o3.Irreps,
gate: Optional[Callable],
irrep_out: o3.Irreps = o3.Irreps("0e"),
num_heads: int = 1,
cueq_config: Optional[CuEquivarianceConfig] = None,
):
super().__init__()
self.hidden_irreps = MLP_irreps
self.num_heads = num_heads
self.linear_1 = Linear(
irreps_in=irreps_in, irreps_out=self.hidden_irreps, cueq_config=cueq_config
)
self.non_linearity = nn.Activation(irreps_in=self.hidden_irreps, acts=[gate])
self.linear_2 = Linear(
irreps_in=self.hidden_irreps, irreps_out=irrep_out, cueq_config=cueq_config
)
def forward(
self, x: torch.Tensor, heads: Optional[torch.Tensor] = None
) -> torch.Tensor: # [n_nodes, irreps] # [..., ]
x = self.non_linearity(self.linear_1(x))
if hasattr(self, "num_heads"):
if self.num_heads > 1 and heads is not None:
x = mask_head(x, heads, self.num_heads)
return self.linear_2(x) # [n_nodes, len(heads)]
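# Usage sketch (assumed shapes): with num_heads > 1, mask_head zeroes the
# hidden channels belonging to the other heads before the final linear layer.
#     readout = NonLinearReadoutBlock(
#         irreps_in=o3.Irreps("16x0e"),
#         MLP_irreps=o3.Irreps("8x0e"),
#         gate=torch.nn.functional.silu,
#         irrep_out=o3.Irreps("2x0e"),
#         num_heads=2,
#     )
#     x = torch.randn(5, 16)
#     heads = torch.tensor([0, 1, 0, 1, 0])  # head index per node
#     out = readout(x, heads)                # [5, 2]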
@compile_mode("script")
class LinearDipoleReadoutBlock(torch.nn.Module):
def __init__(
self,
irreps_in: o3.Irreps,
dipole_only: bool = False,
cueq_config: Optional[CuEquivarianceConfig] = None,
):
super().__init__()
if dipole_only:
self.irreps_out = o3.Irreps("1x1o")
else:
self.irreps_out = o3.Irreps("1x0e + 1x1o")
self.linear = Linear(
irreps_in=irreps_in, irreps_out=self.irreps_out, cueq_config=cueq_config
)
def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ]
return self.linear(x) # [n_nodes, 1]
@compile_mode("script")
class NonLinearDipoleReadoutBlock(torch.nn.Module):
def __init__(
self,
irreps_in: o3.Irreps,
MLP_irreps: o3.Irreps,
gate: Callable,
dipole_only: bool = False,
cueq_config: Optional[CuEquivarianceConfig] = None,
):
super().__init__()
self.hidden_irreps = MLP_irreps
if dipole_only:
self.irreps_out = o3.Irreps("1x1o")
else:
self.irreps_out = o3.Irreps("1x0e + 1x1o")
irreps_scalars = o3.Irreps(
[(mul, ir) for mul, ir in MLP_irreps if ir.l == 0 and ir in self.irreps_out]
)
irreps_gated = o3.Irreps(
[(mul, ir) for mul, ir in MLP_irreps if ir.l > 0 and ir in self.irreps_out]
)
irreps_gates = o3.Irreps([mul, "0e"] for mul, _ in irreps_gated)
self.equivariant_nonlin = nn.Gate(
irreps_scalars=irreps_scalars,
act_scalars=[gate for _, ir in irreps_scalars],
irreps_gates=irreps_gates,
act_gates=[gate] * len(irreps_gates),
irreps_gated=irreps_gated,
)
self.irreps_nonlin = self.equivariant_nonlin.irreps_in.simplify()
self.linear_1 = Linear(
irreps_in=irreps_in, irreps_out=self.irreps_nonlin, cueq_config=cueq_config
)
self.linear_2 = Linear(
irreps_in=self.hidden_irreps,
irreps_out=self.irreps_out,
cueq_config=cueq_config,
)
def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ]
x = self.equivariant_nonlin(self.linear_1(x))
return self.linear_2(x) # [n_nodes, 1]
@compile_mode("script")
class AtomicEnergiesBlock(torch.nn.Module):
atomic_energies: torch.Tensor
def __init__(self, atomic_energies: Union[np.ndarray, torch.Tensor]):
super().__init__()
# assert len(atomic_energies.shape) == 1
self.register_buffer(
"atomic_energies",
torch.tensor(atomic_energies, dtype=torch.get_default_dtype()),
) # [n_elements, n_heads]
def forward(
self, x: torch.Tensor # one-hot of elements [..., n_elements]
) -> torch.Tensor: # [..., ]
return torch.matmul(x, torch.atleast_2d(self.atomic_energies).T)
def __repr__(self):
formatted_energies = ", ".join(
[
"[" + ", ".join([f"{x:.4f}" for x in group]) + "]"
for group in torch.atleast_2d(self.atomic_energies)
]
)
return f"{self.__class__.__name__}(energies=[{formatted_energies}])"
@compile_mode("script")
class RadialEmbeddingBlock(torch.nn.Module):
def __init__(
self,
r_max: float,
num_bessel: int,
num_polynomial_cutoff: int,
radial_type: str = "bessel",
distance_transform: str = "None",
):
super().__init__()
if radial_type == "bessel":
self.bessel_fn = BesselBasis(r_max=r_max, num_basis=num_bessel)
elif radial_type == "gaussian":
self.bessel_fn = GaussianBasis(r_max=r_max, num_basis=num_bessel)
elif radial_type == "chebyshev":
self.bessel_fn = ChebychevBasis(r_max=r_max, num_basis=num_bessel)
if distance_transform == "Agnesi":
self.distance_transform = AgnesiTransform()
elif distance_transform == "Soft":
self.distance_transform = SoftTransform()
self.cutoff_fn = PolynomialCutoff(r_max=r_max, p=num_polynomial_cutoff)
self.out_dim = num_bessel
def forward(
self,
edge_lengths: torch.Tensor, # [n_edges, 1]
node_attrs: torch.Tensor,
edge_index: torch.Tensor,
atomic_numbers: torch.Tensor,
):
cutoff = self.cutoff_fn(edge_lengths) # [n_edges, 1]
if hasattr(self, "distance_transform"):
edge_lengths = self.distance_transform(
edge_lengths, node_attrs, edge_index, atomic_numbers
)
radial = self.bessel_fn(edge_lengths) # [n_edges, n_basis]
return radial * cutoff # [n_edges, n_basis]
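# Usage sketch (assumed parameters): the polynomial cutoff multiplies every
# radial basis channel, so edge features vanish smoothly at r = r_max.
#     embed = RadialEmbeddingBlock(r_max=5.0, num_bessel=8, num_polynomial_cutoff=6)
#     lengths = torch.tensor([[1.2], [4.9], [5.0]])
#     # The remaining arguments only feed the optional distance transforms,
#     # so placeholders suffice when distance_transform="None".
#     feats = embed(
#         lengths,
#         node_attrs=torch.empty(0),
#         edge_index=torch.empty(2, 0, dtype=torch.long),
#         atomic_numbers=torch.empty(0, dtype=torch.long),
#     )  # [3, 8]; the r = r_max row is exactly zero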
@compile_mode("script")
class EquivariantProductBasisBlock(torch.nn.Module):
def __init__(
self,
node_feats_irreps: o3.Irreps,
target_irreps: o3.Irreps,
correlation: int,
use_sc: bool = True,
num_elements: Optional[int] = None,
cueq_config: Optional[CuEquivarianceConfig] = None,
) -> None:
super().__init__()
self.use_sc = use_sc
self.symmetric_contractions = SymmetricContractionWrapper(
irreps_in=node_feats_irreps,
irreps_out=target_irreps,
correlation=correlation,
num_elements=num_elements,
cueq_config=cueq_config,
)
# Update linear
self.linear = Linear(
target_irreps,
target_irreps,
internal_weights=True,
shared_weights=True,
cueq_config=cueq_config,
)
self.cueq_config = cueq_config
def forward(
self,
node_feats: torch.Tensor,
sc: Optional[torch.Tensor],
node_attrs: torch.Tensor,
) -> torch.Tensor:
use_cueq = False
use_cueq_mul_ir = False
if hasattr(self, "cueq_config"):
if self.cueq_config is not None:
if self.cueq_config.enabled and (
self.cueq_config.optimize_all or self.cueq_config.optimize_symmetric
):
use_cueq = True
if self.cueq_config.layout_str == "mul_ir":
use_cueq_mul_ir = True
if use_cueq:
if use_cueq_mul_ir:
node_feats = torch.transpose(node_feats, 1, 2)
index_attrs = torch.nonzero(node_attrs)[:, 1].int()
node_feats = self.symmetric_contractions(
node_feats.flatten(1),
index_attrs,
)
else:
node_feats = self.symmetric_contractions(node_feats, node_attrs)
if self.use_sc and sc is not None:
return self.linear(node_feats) + sc
return self.linear(node_feats)
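# Schematic of this block (indices over irreps suppressed): the symmetric
# contraction builds the order-nu product basis
#     B_i = sum_{k1..knu} C^{z_i}_{k1..knu} prod_{t=1..nu} A_{i,kt}
# from the per-node features A_i, with weights conditioned on the element z_i
# via node_attrs, followed by a channel-mixing Linear and the optional
# self-connection sc.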
@compile_mode("script")
class InteractionBlock(torch.nn.Module):
def __init__(
self,
node_attrs_irreps: o3.Irreps,
node_feats_irreps: o3.Irreps,
edge_attrs_irreps: o3.Irreps,
edge_feats_irreps: o3.Irreps,
target_irreps: o3.Irreps,
hidden_irreps: o3.Irreps,
avg_num_neighbors: float,
radial_MLP: Optional[List[int]] = None,
cueq_config: Optional[CuEquivarianceConfig] = None,
) -> None:
super().__init__()
self.node_attrs_irreps = node_attrs_irreps
self.node_feats_irreps = node_feats_irreps
self.edge_attrs_irreps = edge_attrs_irreps
self.edge_feats_irreps = edge_feats_irreps
self.target_irreps = target_irreps
self.hidden_irreps = hidden_irreps
self.avg_num_neighbors = avg_num_neighbors
if radial_MLP is None:
radial_MLP = [64, 64, 64]
self.radial_MLP = radial_MLP
self.cueq_config = cueq_config
self._setup()
@abstractmethod
def _setup(self) -> None:
raise NotImplementedError
def handle_lammps(
self,
node_feats: torch.Tensor,
lammps_class: Optional[Any],
lammps_natoms: Tuple[int, int],
first_layer: bool,
) -> torch.Tensor: # noqa: D401 – internal helper
if lammps_class is None or first_layer or torch.jit.is_scripting():
return node_feats
_, n_total = lammps_natoms
pad = torch.zeros(
(n_total, node_feats.shape[1]),
dtype=node_feats.dtype,
device=node_feats.device,
)
node_feats = torch.cat((node_feats, pad), dim=0)
node_feats = LAMMPS_MP.apply(node_feats, lammps_class)
return node_feats
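# Hedged reading of the LAMMPS path above: for multi-GPU MD, node features are
# zero-padded by lammps_natoms[1] rows and passed through LAMMPS_MP, which
# fills those rows with ghost-atom features from neighboring ranks;
# truncate_ghosts drops them again after message passing.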
def truncate_ghosts(
self, tensor: torch.Tensor, n_real: Optional[int] = None
) -> torch.Tensor:
"""Truncate the tensor to only keep the real atoms in case of presence of ghost atoms during multi-GPU MD simulations."""
return tensor[:n_real] if n_real is not None else tensor
@abstractmethod
def forward(
self,
node_attrs: torch.Tensor,
node_feats: torch.Tensor,
edge_attrs: torch.Tensor,
edge_feats: torch.Tensor,
edge_index: torch.Tensor,
) -> torch.Tensor:
raise NotImplementedError
nonlinearities = {1: torch.nn.functional.silu, -1: torch.tanh}
@compile_mode("script")
class RealAgnosticInteractionBlock(InteractionBlock):
def _setup(self) -> None:
if not hasattr(self, "cueq_config"):
self.cueq_config = None
# First linear
self.linear_up = Linear(
self.node_feats_irreps,
self.node_feats_irreps,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
# TensorProduct
irreps_mid, instructions = tp_out_irreps_with_instructions(
self.node_feats_irreps,
self.edge_attrs_irreps,
self.target_irreps,
)
self.conv_tp = TensorProduct(
self.node_feats_irreps,
self.edge_attrs_irreps,
irreps_mid,
instructions=instructions,
shared_weights=False,
internal_weights=False,
cueq_config=self.cueq_config,
)
# Convolution weights
input_dim = self.edge_feats_irreps.num_irreps
self.conv_tp_weights = nn.FullyConnectedNet(
[input_dim] + self.radial_MLP + [self.conv_tp.weight_numel],
torch.nn.functional.silu,
)
# Linear
self.irreps_out = self.target_irreps
self.linear = Linear(
irreps_mid,
self.irreps_out,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
# Selector TensorProduct
self.skip_tp = FullyConnectedTensorProduct(
self.irreps_out,
self.node_attrs_irreps,
self.irreps_out,
cueq_config=self.cueq_config,
)
self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config)
def forward(
self,
node_attrs: torch.Tensor,
node_feats: torch.Tensor,
edge_attrs: torch.Tensor,
edge_feats: torch.Tensor,
edge_index: torch.Tensor,
lammps_natoms: Tuple[int, int] = (0, 0),
lammps_class: Optional[Any] = None,
first_layer: bool = False,
) -> Tuple[torch.Tensor, None]:
sender = edge_index[0]
receiver = edge_index[1]
num_nodes = node_feats.shape[0]
n_real = lammps_natoms[0] if lammps_class is not None else None
node_feats = self.linear_up(node_feats)
node_feats = self.handle_lammps(
node_feats,
lammps_class=lammps_class,
lammps_natoms=lammps_natoms,
first_layer=first_layer,
)
tp_weights = self.conv_tp_weights(edge_feats)
mji = self.conv_tp(
node_feats[sender], edge_attrs, tp_weights
) # [n_edges, irreps]
message = scatter_sum(
src=mji, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, irreps]
message = self.truncate_ghosts(message, n_real)
node_attrs = self.truncate_ghosts(node_attrs, n_real)
message = self.linear(message) / self.avg_num_neighbors
message = self.skip_tp(message, node_attrs)
return (
self.reshape(message),
None,
) # [n_nodes, channels, (lmax + 1)**2]
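# Schematic of the message computed above (one "uvu" tensor-product path per
# kept irrep combination):
#     m_i = skip_tp( Linear( sum_{j in N(i)} TP(h_j, Y(r_ij); R(|r_ij|)) )
#                    / avg_num_neighbors, z_i )
# where Y are the spherical-harmonic edge attributes and R is the radial MLP;
# the result is reshaped to [n_nodes, channels, (lmax + 1)**2].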
@compile_mode("script")
class RealAgnosticResidualInteractionBlock(InteractionBlock):
def _setup(self) -> None:
if not hasattr(self, "cueq_config"):
self.cueq_config = None
# First linear
self.linear_up = Linear(
self.node_feats_irreps,
self.node_feats_irreps,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
# TensorProduct
irreps_mid, instructions = tp_out_irreps_with_instructions(
self.node_feats_irreps,
self.edge_attrs_irreps,
self.target_irreps,
)
self.conv_tp = TensorProduct(
self.node_feats_irreps,
self.edge_attrs_irreps,
irreps_mid,
instructions=instructions,
shared_weights=False,
internal_weights=False,
cueq_config=self.cueq_config,
)
# Convolution weights
input_dim = self.edge_feats_irreps.num_irreps
self.conv_tp_weights = nn.FullyConnectedNet(
[input_dim] + self.radial_MLP + [self.conv_tp.weight_numel],
torch.nn.functional.silu, # gate
)
# Linear
self.irreps_out = self.target_irreps
self.linear = Linear(
irreps_mid,
self.irreps_out,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
# Selector TensorProduct
self.skip_tp = FullyConnectedTensorProduct(
self.node_feats_irreps,
self.node_attrs_irreps,
self.hidden_irreps,
cueq_config=self.cueq_config,
)
self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config)
def forward(
self,
node_attrs: torch.Tensor,
node_feats: torch.Tensor,
edge_attrs: torch.Tensor,
edge_feats: torch.Tensor,
edge_index: torch.Tensor,
lammps_class: Optional[Any] = None,
lammps_natoms: Tuple[int, int] = (0, 0),
first_layer: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
sender = edge_index[0]
receiver = edge_index[1]
num_nodes = node_feats.shape[0]
n_real = lammps_natoms[0] if lammps_class is not None else None
sc = self.skip_tp(node_feats, node_attrs)
node_feats = self.linear_up(node_feats)
node_feats = self.handle_lammps(
node_feats,
lammps_class=lammps_class,
lammps_natoms=lammps_natoms,
first_layer=first_layer,
)
tp_weights = self.conv_tp_weights(edge_feats)
mji = self.conv_tp(
node_feats[sender], edge_attrs, tp_weights
) # [n_edges, irreps]
message = scatter_sum(
src=mji, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, irreps]
message = self.truncate_ghosts(message, n_real)
node_attrs = self.truncate_ghosts(node_attrs, n_real)
sc = self.truncate_ghosts(sc, n_real)
message = self.linear(message) / self.avg_num_neighbors
return (
self.reshape(message),
sc,
) # [n_nodes, channels, (lmax + 1)**2]
@compile_mode("script")
class RealAgnosticDensityInteractionBlock(InteractionBlock):
def _setup(self) -> None:
if not hasattr(self, "cueq_config"):
self.cueq_config = None
# First linear
self.linear_up = Linear(
self.node_feats_irreps,
self.node_feats_irreps,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
# TensorProduct
irreps_mid, instructions = tp_out_irreps_with_instructions(
self.node_feats_irreps,
self.edge_attrs_irreps,
self.target_irreps,
)
self.conv_tp = TensorProduct(
self.node_feats_irreps,
self.edge_attrs_irreps,
irreps_mid,
instructions=instructions,
shared_weights=False,
internal_weights=False,
cueq_config=self.cueq_config,
)
# Convolution weights
input_dim = self.edge_feats_irreps.num_irreps
self.conv_tp_weights = nn.FullyConnectedNet(
[input_dim] + self.radial_MLP + [self.conv_tp.weight_numel],
torch.nn.functional.silu,
)
# Linear
self.irreps_out = self.target_irreps
self.linear = Linear(
irreps_mid,
self.irreps_out,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
# Selector TensorProduct
self.skip_tp = FullyConnectedTensorProduct(
self.irreps_out,
self.node_attrs_irreps,
self.irreps_out,
cueq_config=self.cueq_config,
)
# Density normalization
self.density_fn = nn.FullyConnectedNet(
[input_dim, 1],
torch.nn.functional.silu,
)
# Reshape
self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config)
def forward(
self,
node_attrs: torch.Tensor,
node_feats: torch.Tensor,
edge_attrs: torch.Tensor,
edge_feats: torch.Tensor,
edge_index: torch.Tensor,
lammps_class: Optional[Any] = None,
lammps_natoms: Tuple[int, int] = (0, 0),
first_layer: bool = False,
) -> Tuple[torch.Tensor, None]:
sender = edge_index[0]
receiver = edge_index[1]
num_nodes = node_feats.shape[0]
n_real = lammps_natoms[0] if lammps_class is not None else None
node_feats = self.linear_up(node_feats)
node_feats = self.handle_lammps(
node_feats,
lammps_class=lammps_class,
lammps_natoms=lammps_natoms,
first_layer=first_layer,
)
tp_weights = self.conv_tp_weights(edge_feats)
edge_density = torch.tanh(self.density_fn(edge_feats) ** 2)
mji = self.conv_tp(
node_feats[sender], edge_attrs, tp_weights
) # [n_edges, irreps]
density = scatter_sum(
src=edge_density, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, 1]
message = scatter_sum(
src=mji, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, irreps]
message = self.truncate_ghosts(message, n_real)
node_attrs = self.truncate_ghosts(node_attrs, n_real)
density = self.truncate_ghosts(density, n_real)
message = self.linear(message) / (density + 1)
message = self.skip_tp(message, node_attrs)
return (
self.reshape(message),
None,
) # [n_nodes, channels, (lmax + 1)**2]
@compile_mode("script")
class RealAgnosticDensityResidualInteractionBlock(InteractionBlock):
def _setup(self) -> None:
if not hasattr(self, "cueq_config"):
self.cueq_config = None
# First linear
self.linear_up = Linear(
self.node_feats_irreps,
self.node_feats_irreps,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
# TensorProduct
irreps_mid, instructions = tp_out_irreps_with_instructions(
self.node_feats_irreps,
self.edge_attrs_irreps,
self.target_irreps,
)
self.conv_tp = TensorProduct(
self.node_feats_irreps,
self.edge_attrs_irreps,
irreps_mid,
instructions=instructions,
shared_weights=False,
internal_weights=False,
cueq_config=self.cueq_config,
)
# Convolution weights
input_dim = self.edge_feats_irreps.num_irreps
self.conv_tp_weights = nn.FullyConnectedNet(
[input_dim] + self.radial_MLP + [self.conv_tp.weight_numel],
torch.nn.functional.silu, # gate
)
# Linear
self.irreps_out = self.target_irreps
self.linear = Linear(
irreps_mid,
self.irreps_out,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
# Selector TensorProduct
self.skip_tp = FullyConnectedTensorProduct(
self.node_feats_irreps,
self.node_attrs_irreps,
self.hidden_irreps,
cueq_config=self.cueq_config,
)
# Density normalization
self.density_fn = nn.FullyConnectedNet(
[input_dim, 1],
torch.nn.functional.silu,
)
# Reshape
self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config)
def forward(
self,
node_attrs: torch.Tensor,
node_feats: torch.Tensor,
edge_attrs: torch.Tensor,
edge_feats: torch.Tensor,
edge_index: torch.Tensor,
lammps_class: Optional[Any] = None,
lammps_natoms: Tuple[int, int] = (0, 0),
first_layer: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
sender = edge_index[0]
receiver = edge_index[1]
num_nodes = node_feats.shape[0]
n_real = lammps_natoms[0] if lammps_class is not None else None
sc = self.skip_tp(node_feats, node_attrs)
node_feats = self.linear_up(node_feats)
node_feats = self.handle_lammps(
node_feats,
lammps_class=lammps_class,
lammps_natoms=lammps_natoms,
first_layer=first_layer,
)
tp_weights = self.conv_tp_weights(edge_feats)
edge_density = torch.tanh(self.density_fn(edge_feats) ** 2)
mji = self.conv_tp(
node_feats[sender], edge_attrs, tp_weights
) # [n_edges, irreps]
density = scatter_sum(
src=edge_density, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, 1]
message = scatter_sum(
src=mji, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, irreps]
message = self.truncate_ghosts(message, n_real)
node_attrs = self.truncate_ghosts(node_attrs, n_real)
density = self.truncate_ghosts(density, n_real)
sc = self.truncate_ghosts(sc, n_real)
message = self.linear(message) / (density + 1)
return (
self.reshape(message),
sc,
) # [n_nodes, channels, (lmax + 1)**2]
@compile_mode("script")
class RealAgnosticAttResidualInteractionBlock(InteractionBlock):
def _setup(self) -> None:
if not hasattr(self, "cueq_config"):
self.cueq_config = None
self.node_feats_down_irreps = o3.Irreps("64x0e")
# First linear
self.linear_up = Linear(
self.node_feats_irreps,
self.node_feats_irreps,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
# TensorProduct
irreps_mid, instructions = tp_out_irreps_with_instructions(
self.node_feats_irreps,
self.edge_attrs_irreps,
self.target_irreps,
)
self.conv_tp = TensorProduct(
self.node_feats_irreps,
self.edge_attrs_irreps,
irreps_mid,
instructions=instructions,
shared_weights=False,
internal_weights=False,
cueq_config=self.cueq_config,
)
# Convolution weights
self.linear_down = Linear(
self.node_feats_irreps,
self.node_feats_down_irreps,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
input_dim = (
self.edge_feats_irreps.num_irreps
+ 2 * self.node_feats_down_irreps.num_irreps
)
self.conv_tp_weights = nn.FullyConnectedNet(
[input_dim] + 3 * [256] + [self.conv_tp.weight_numel],
torch.nn.functional.silu,
)
# Linear
self.irreps_out = self.target_irreps
self.linear = Linear(
irreps_mid,
self.irreps_out,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config)
# Skip connection.
self.skip_linear = Linear(
self.node_feats_irreps, self.hidden_irreps, cueq_config=self.cueq_config
)
# pylint: disable=unused-argument
def forward(
self,
node_attrs: torch.Tensor,
node_feats: torch.Tensor,
edge_attrs: torch.Tensor,
edge_feats: torch.Tensor,
edge_index: torch.Tensor,
lammps_class: Optional[Any] = None,
lammps_natoms: Tuple[int, int] = (0, 0),
first_layer: bool = False,
) -> Tuple[torch.Tensor, None]:
sender = edge_index[0]
receiver = edge_index[1]
num_nodes = node_feats.shape[0]
sc = self.skip_linear(node_feats)
node_feats_up = self.linear_up(node_feats)
node_feats_down = self.linear_down(node_feats)
augmented_edge_feats = torch.cat(
[
edge_feats,
node_feats_down[sender],
node_feats_down[receiver],
],
dim=-1,
)
tp_weights = self.conv_tp_weights(augmented_edge_feats)
mji = self.conv_tp(
node_feats_up[sender], edge_attrs, tp_weights
) # [n_edges, irreps]
message = scatter_sum(
src=mji, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, irreps]
message = self.linear(message) / self.avg_num_neighbors
return (
self.reshape(message),
sc,
) # [n_nodes, channels, (lmax + 1)**2]
@compile_mode("script")
class ScaleShiftBlock(torch.nn.Module):
def __init__(self, scale: float, shift: float):
super().__init__()
self.register_buffer(
"scale",
torch.tensor(scale, dtype=torch.get_default_dtype()),
)
self.register_buffer(
"shift",
torch.tensor(shift, dtype=torch.get_default_dtype()),
)
def forward(self, x: torch.Tensor, head: torch.Tensor) -> torch.Tensor:
return (
torch.atleast_1d(self.scale)[head] * x + torch.atleast_1d(self.shift)[head]
)
def __repr__(self):
formatted_scale = (
", ".join([f"{x:.4f}" for x in self.scale])
if self.scale.numel() > 1
else f"{self.scale.item():.4f}"
)
formatted_shift = (
", ".join([f"{x:.4f}" for x in self.shift])
if self.shift.numel() > 1
else f"{self.shift.item():.4f}"
)
return f"{self.__class__.__name__}(scale={formatted_scale}, shift={formatted_shift})"
###########################################################################################
# Elementary tools for handling irreducible representations
# Authors: Ilyes Batatia, Gregor Simm
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
from typing import List, Optional, Tuple
import torch
from e3nn import o3
from e3nn.util.jit import compile_mode
from mace.modules.wrapper_ops import CuEquivarianceConfig
# Based on mir-group/nequip
def tp_out_irreps_with_instructions(
irreps1: o3.Irreps, irreps2: o3.Irreps, target_irreps: o3.Irreps
) -> Tuple[o3.Irreps, List]:
trainable = True
# Collect possible irreps and their instructions
irreps_out_list: List[Tuple[int, o3.Irreps]] = []
instructions = []
for i, (mul, ir_in) in enumerate(irreps1):
for j, (_, ir_edge) in enumerate(irreps2):
for ir_out in ir_in * ir_edge: # | l1 - l2 | <= l <= l1 + l2
if ir_out in target_irreps:
k = len(irreps_out_list) # instruction index
irreps_out_list.append((mul, ir_out))
instructions.append((i, j, k, "uvu", trainable))
# We sort the output irreps of the tensor product so that we can simplify them
# when they are provided to the second o3.Linear
irreps_out = o3.Irreps(irreps_out_list)
irreps_out, permut, _ = irreps_out.sort()
# Permute the output indexes of the instructions to match the sorted irreps:
instructions = [
(i_in1, i_in2, permut[i_out], mode, train)
for i_in1, i_in2, i_out, mode, train in instructions
]
instructions = sorted(instructions, key=lambda x: x[2])
return irreps_out, instructions
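# Worked example (assumed inputs): for irreps1 = "1x0e + 1x1o",
# irreps2 = "1x0e + 1x1o" and target_irreps = "0e + 1o", the selection rule
# |l1 - l2| <= l <= l1 + l2 keeps 0e x 0e -> 0e, 0e x 1o -> 1o,
# 1o x 0e -> 1o and 1o x 1o -> 0e, while the 1e and 2e products of
# 1o x 1o are dropped because they do not appear in the target. Each kept
# path becomes one trainable "uvu" instruction.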
def linear_out_irreps(irreps: o3.Irreps, target_irreps: o3.Irreps) -> o3.Irreps:
# Assuming simplified irreps
irreps_mid = []
for _, ir_in in irreps:
found = False
for mul, ir_out in target_irreps:
if ir_in == ir_out:
irreps_mid.append((mul, ir_out))
found = True
break
if not found:
raise RuntimeError(f"{ir_in} not in {target_irreps}")
return o3.Irreps(irreps_mid)
@compile_mode("script")
class reshape_irreps(torch.nn.Module):
def __init__(
self, irreps: o3.Irreps, cueq_config: Optional[CuEquivarianceConfig] = None
) -> None:
super().__init__()
self.irreps = o3.Irreps(irreps)
self.cueq_config = cueq_config
self.dims = []
self.muls = []
for mul, ir in self.irreps:
d = ir.dim
self.dims.append(d)
self.muls.append(mul)
def forward(self, tensor: torch.Tensor) -> torch.Tensor:
ix = 0
out = []
batch, _ = tensor.shape
for mul, d in zip(self.muls, self.dims):
field = tensor[:, ix : ix + mul * d] # [batch, sample, mul * repr]
ix += mul * d
if hasattr(self, "cueq_config"):
if self.cueq_config is not None:
if self.cueq_config.layout_str == "mul_ir":
field = field.reshape(batch, mul, d)
else:
field = field.reshape(batch, d, mul)
else:
field = field.reshape(batch, mul, d)
else:
field = field.reshape(batch, mul, d)
out.append(field)
if hasattr(self, "cueq_config"):
if self.cueq_config is not None: # pylint: disable=no-else-return
if self.cueq_config.layout_str == "mul_ir":
return torch.cat(out, dim=-1)
return torch.cat(out, dim=-2)
else:
return torch.cat(out, dim=-1)
return torch.cat(out, dim=-1)
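# Shape sketch (assumed irreps, default layout without cuEquivariance): for
# irreps = "2x0e + 2x1o" a flat [batch, 2*1 + 2*3] tensor is split per irrep,
# reshaped into [batch, mul, dim] blocks, and concatenated to [batch, 2, 4].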
def mask_head(x: torch.Tensor, head: torch.Tensor, num_heads: int) -> torch.Tensor:
mask = torch.zeros(x.shape[0], x.shape[1] // num_heads, num_heads, device=x.device)
idx = torch.arange(mask.shape[0], device=x.device)
mask[idx, :, head] = 1
mask = mask.permute(0, 2, 1).reshape(x.shape)
return x * mask
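# Worked example: with 2 heads and 4 features per node, each head keeps its
# own contiguous block of channels and zeroes the rest.
#     mask_head(torch.ones(2, 4), torch.tensor([0, 1]), num_heads=2)
#     # tensor([[1., 1., 0., 0.],
#     #         [0., 0., 1., 1.]])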
###########################################################################################
# Implementation of different loss functions
# Authors: Ilyes Batatia, Gregor Simm
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
from typing import Optional
import torch
import torch.distributed as dist
from mace.tools import TensorDict
from mace.tools.torch_geometric import Batch
# ------------------------------------------------------------------------------
# Helper function for loss reduction that handles DDP correction
# ------------------------------------------------------------------------------
def is_ddp_enabled():
return dist.is_initialized() and dist.get_world_size() > 1
def reduce_loss(raw_loss: torch.Tensor, ddp: Optional[bool] = None) -> torch.Tensor:
"""
Reduces an element-wise loss tensor.
If ddp is True and distributed is initialized, the function computes:
loss = (local_sum * world_size) / global_num_elements
Otherwise, it returns the regular mean.
"""
ddp = is_ddp_enabled() if ddp is None else ddp
if ddp and dist.is_initialized():
world_size = dist.get_world_size()
n_local = raw_loss.numel()
loss_sum = raw_loss.sum()
total_samples = torch.tensor(
n_local, device=raw_loss.device, dtype=raw_loss.dtype
)
dist.all_reduce(total_samples, op=dist.ReduceOp.SUM)
return loss_sum * world_size / total_samples
return raw_loss.mean()
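# Why sum * world_size / global_N: DDP averages gradients across ranks, so
# each rank must return (local_sum * world_size) / global_num_elements for
# the averaged gradient to equal that of the true global mean, even when
# ranks hold different numbers of elements. On a single process it reduces
# to the plain mean:
#     reduce_loss(torch.tensor([1.0, 2.0, 3.0]))  # tensor(2.)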
# ------------------------------------------------------------------------------
# Energy Loss Functions
# ------------------------------------------------------------------------------
def mean_squared_error_energy(
ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
) -> torch.Tensor:
raw_loss = torch.square(ref["energy"] - pred["energy"])
return reduce_loss(raw_loss, ddp)
def weighted_mean_squared_error_energy(
ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
) -> torch.Tensor:
# Calculate per-graph number of atoms.
num_atoms = ref.ptr[1:] - ref.ptr[:-1] # shape: [n_graphs]
raw_loss = (
ref.weight
* ref.energy_weight
* torch.square((ref["energy"] - pred["energy"]) / num_atoms)
)
return reduce_loss(raw_loss, ddp)
def weighted_mean_absolute_error_energy(
ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
) -> torch.Tensor:
num_atoms = ref.ptr[1:] - ref.ptr[:-1]
raw_loss = (
ref.weight
* ref.energy_weight
* torch.abs((ref["energy"] - pred["energy"]) / num_atoms)
)
return reduce_loss(raw_loss, ddp)
# ------------------------------------------------------------------------------
# Stress and Virials Loss Functions
# ------------------------------------------------------------------------------
def weighted_mean_squared_stress(
ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
) -> torch.Tensor:
configs_weight = ref.weight.view(-1, 1, 1)
configs_stress_weight = ref.stress_weight.view(-1, 1, 1)
raw_loss = (
configs_weight
* configs_stress_weight
* torch.square(ref["stress"] - pred["stress"])
)
return reduce_loss(raw_loss, ddp)
def weighted_mean_squared_virials(
ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
) -> torch.Tensor:
configs_weight = ref.weight.view(-1, 1, 1)
configs_virials_weight = ref.virials_weight.view(-1, 1, 1)
num_atoms = (ref.ptr[1:] - ref.ptr[:-1]).view(-1, 1, 1)
raw_loss = (
configs_weight
* configs_virials_weight
* torch.square((ref["virials"] - pred["virials"]) / num_atoms)
)
return reduce_loss(raw_loss, ddp)
# ------------------------------------------------------------------------------
# Forces Loss Functions
# ------------------------------------------------------------------------------
def mean_squared_error_forces(
ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
) -> torch.Tensor:
# Repeat per-graph weights to per-atom level.
configs_weight = torch.repeat_interleave(
ref.weight, ref.ptr[1:] - ref.ptr[:-1]
).unsqueeze(-1)
configs_forces_weight = torch.repeat_interleave(
ref.forces_weight, ref.ptr[1:] - ref.ptr[:-1]
).unsqueeze(-1)
raw_loss = (
configs_weight
* configs_forces_weight
* torch.square(ref["forces"] - pred["forces"])
)
return reduce_loss(raw_loss, ddp)
def mean_normed_error_forces(
ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
) -> torch.Tensor:
raw_loss = torch.linalg.vector_norm(ref["forces"] - pred["forces"], ord=2, dim=-1)
return reduce_loss(raw_loss, ddp)
# ------------------------------------------------------------------------------
# Dipole Loss Function
# ------------------------------------------------------------------------------
def weighted_mean_squared_error_dipole(
ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
) -> torch.Tensor:
num_atoms = (ref.ptr[1:] - ref.ptr[:-1]).unsqueeze(-1)
raw_loss = torch.square((ref["dipole"] - pred["dipole"]) / num_atoms)
return reduce_loss(raw_loss, ddp)
# ------------------------------------------------------------------------------
# Conditional Losses for Forces
# ------------------------------------------------------------------------------
def conditional_mse_forces(
ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
) -> torch.Tensor:
configs_weight = torch.repeat_interleave(
ref.weight, ref.ptr[1:] - ref.ptr[:-1]
).unsqueeze(-1)
configs_forces_weight = torch.repeat_interleave(
ref.forces_weight, ref.ptr[1:] - ref.ptr[:-1]
).unsqueeze(-1)
# Define multiplication factors for different regimes.
factors = torch.tensor(
[1.0, 0.7, 0.4, 0.1], device=ref["forces"].device, dtype=ref["forces"].dtype
)
err = ref["forces"] - pred["forces"]
se = torch.zeros_like(err)
norm_forces = torch.norm(ref["forces"], dim=-1)
c1 = norm_forces < 100
c2 = (norm_forces >= 100) & (norm_forces < 200)
c3 = (norm_forces >= 200) & (norm_forces < 300)
se[c1] = torch.square(err[c1]) * factors[0]
se[c2] = torch.square(err[c2]) * factors[1]
se[c3] = torch.square(err[c3]) * factors[2]
se[~(c1 | c2 | c3)] = torch.square(err[~(c1 | c2 | c3)]) * factors[3]
raw_loss = configs_weight * configs_forces_weight * se
return reduce_loss(raw_loss, ddp)
def conditional_huber_forces(
ref_forces: torch.Tensor,
pred_forces: torch.Tensor,
huber_delta: float,
ddp: Optional[bool] = None,
) -> torch.Tensor:
factors = huber_delta * torch.tensor(
[1.0, 0.7, 0.4, 0.1], device=ref_forces.device, dtype=ref_forces.dtype
)
norm_forces = torch.norm(ref_forces, dim=-1)
c1 = norm_forces < 100
c2 = (norm_forces >= 100) & (norm_forces < 200)
c3 = (norm_forces >= 200) & (norm_forces < 300)
c4 = ~(c1 | c2 | c3)
se = torch.zeros_like(pred_forces)
se[c1] = torch.nn.functional.huber_loss(
ref_forces[c1], pred_forces[c1], reduction="none", delta=factors[0]
)
se[c2] = torch.nn.functional.huber_loss(
ref_forces[c2], pred_forces[c2], reduction="none", delta=factors[1]
)
se[c3] = torch.nn.functional.huber_loss(
ref_forces[c3], pred_forces[c3], reduction="none", delta=factors[2]
)
se[c4] = torch.nn.functional.huber_loss(
ref_forces[c4], pred_forces[c4], reduction="none", delta=factors[3]
)
return reduce_loss(se, ddp)
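# Regime table shared by both conditional force losses (force norms in the
# units of the training data, typically eV/A):
#     |F| <  100        -> factor 1.0
#     100 <= |F| < 200  -> factor 0.7
#     200 <= |F| < 300  -> factor 0.4
#     |F| >= 300        -> factor 0.1
# conditional_mse_forces multiplies the squared error by the factor, while
# conditional_huber_forces shrinks the Huber delta instead, progressively
# down-weighting atoms with very large reference forces.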
# ------------------------------------------------------------------------------
# Loss Modules Combining Multiple Quantities
# ------------------------------------------------------------------------------
class WeightedEnergyForcesLoss(torch.nn.Module):
def __init__(self, energy_weight=1.0, forces_weight=1.0) -> None:
super().__init__()
self.register_buffer(
"energy_weight",
torch.tensor(energy_weight, dtype=torch.get_default_dtype()),
)
self.register_buffer(
"forces_weight",
torch.tensor(forces_weight, dtype=torch.get_default_dtype()),
)
def forward(
self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
) -> torch.Tensor:
loss_energy = weighted_mean_squared_error_energy(ref, pred, ddp)
loss_forces = mean_squared_error_forces(ref, pred, ddp)
return self.energy_weight * loss_energy + self.forces_weight * loss_forces
def __repr__(self):
return (
f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, "
f"forces_weight={self.forces_weight:.3f})"
)
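# Usage sketch (field names follow the Batch/TensorDict conventions above):
#     loss_fn = WeightedEnergyForcesLoss(energy_weight=1.0, forces_weight=100.0)
#     loss = loss_fn(batch, model_output)  # scalar tensor
# where `batch` carries reference energies/forces, per-config weights and
# `ptr`, and `model_output` is the dict returned by a model's forward pass.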
class WeightedForcesLoss(torch.nn.Module):
def __init__(self, forces_weight=1.0) -> None:
super().__init__()
self.register_buffer(
"forces_weight",
torch.tensor(forces_weight, dtype=torch.get_default_dtype()),
)
def forward(
self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
) -> torch.Tensor:
loss_forces = mean_squared_error_forces(ref, pred, ddp)
return self.forces_weight * loss_forces
def __repr__(self):
return f"{self.__class__.__name__}(forces_weight={self.forces_weight:.3f})"
class WeightedEnergyForcesStressLoss(torch.nn.Module):
def __init__(self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0) -> None:
super().__init__()
self.register_buffer(
"energy_weight",
torch.tensor(energy_weight, dtype=torch.get_default_dtype()),
)
self.register_buffer(
"forces_weight",
torch.tensor(forces_weight, dtype=torch.get_default_dtype()),
)
self.register_buffer(
"stress_weight",
torch.tensor(stress_weight, dtype=torch.get_default_dtype()),
)
def forward(
self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
) -> torch.Tensor:
loss_energy = weighted_mean_squared_error_energy(ref, pred, ddp)
loss_forces = mean_squared_error_forces(ref, pred, ddp)
loss_stress = weighted_mean_squared_stress(ref, pred, ddp)
return (
self.energy_weight * loss_energy
+ self.forces_weight * loss_forces
+ self.stress_weight * loss_stress
)
def __repr__(self):
return (
f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, "
f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})"
)
class WeightedHuberEnergyForcesStressLoss(torch.nn.Module):
def __init__(
self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0, huber_delta=0.01
) -> None:
super().__init__()
# We store the huber_delta rather than a loss with fixed reduction.
self.huber_delta = huber_delta
self.register_buffer(
"energy_weight",
torch.tensor(energy_weight, dtype=torch.get_default_dtype()),
)
self.register_buffer(
"forces_weight",
torch.tensor(forces_weight, dtype=torch.get_default_dtype()),
)
self.register_buffer(
"stress_weight",
torch.tensor(stress_weight, dtype=torch.get_default_dtype()),
)
def forward(
self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
) -> torch.Tensor:
num_atoms = ref.ptr[1:] - ref.ptr[:-1]
# Resolve ddp=None via auto-detection, matching reduce_loss, so the DDP
# correction is not silently skipped.
ddp = is_ddp_enabled() if ddp is None else ddp
if ddp:
loss_energy = torch.nn.functional.huber_loss(
ref["energy"] / num_atoms,
pred["energy"] / num_atoms,
reduction="none",
delta=self.huber_delta,
)
loss_energy = reduce_loss(loss_energy, ddp)
loss_forces = torch.nn.functional.huber_loss(
ref["forces"], pred["forces"], reduction="none", delta=self.huber_delta
)
loss_forces = reduce_loss(loss_forces, ddp)
loss_stress = torch.nn.functional.huber_loss(
ref["stress"], pred["stress"], reduction="none", delta=self.huber_delta
)
loss_stress = reduce_loss(loss_stress, ddp)
else:
loss_energy = torch.nn.functional.huber_loss(
ref["energy"] / num_atoms,
pred["energy"] / num_atoms,
reduction="mean",
delta=self.huber_delta,
)
loss_forces = torch.nn.functional.huber_loss(
ref["forces"], pred["forces"], reduction="mean", delta=self.huber_delta
)
loss_stress = torch.nn.functional.huber_loss(
ref["stress"], pred["stress"], reduction="mean", delta=self.huber_delta
)
return (
self.energy_weight * loss_energy
+ self.forces_weight * loss_forces
+ self.stress_weight * loss_stress
)
def __repr__(self):
return (
f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, "
f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})"
)
class UniversalLoss(torch.nn.Module):
def __init__(
self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0, huber_delta=0.01
) -> None:
super().__init__()
self.huber_delta = huber_delta
self.register_buffer(
"energy_weight",
torch.tensor(energy_weight, dtype=torch.get_default_dtype()),
)
self.register_buffer(
"forces_weight",
torch.tensor(forces_weight, dtype=torch.get_default_dtype()),
)
self.register_buffer(
"stress_weight",
torch.tensor(stress_weight, dtype=torch.get_default_dtype()),
)
def forward(
self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
) -> torch.Tensor:
num_atoms = ref.ptr[1:] - ref.ptr[:-1]
configs_stress_weight = ref.stress_weight.view(-1, 1, 1)
configs_energy_weight = ref.energy_weight
configs_forces_weight = torch.repeat_interleave(
ref.forces_weight, ref.ptr[1:] - ref.ptr[:-1]
).unsqueeze(-1)
ddp = is_ddp_enabled() if ddp is None else ddp  # match reduce_loss's auto-detection
if ddp:
loss_energy = torch.nn.functional.huber_loss(
configs_energy_weight * ref["energy"] / num_atoms,
configs_energy_weight * pred["energy"] / num_atoms,
reduction="none",
delta=self.huber_delta,
)
loss_energy = reduce_loss(loss_energy, ddp)
loss_forces = conditional_huber_forces(
configs_forces_weight * ref["forces"],
configs_forces_weight * pred["forces"],
huber_delta=self.huber_delta,
ddp=ddp,
)
loss_stress = torch.nn.functional.huber_loss(
configs_stress_weight * ref["stress"],
configs_stress_weight * pred["stress"],
reduction="none",
delta=self.huber_delta,
)
loss_stress = reduce_loss(loss_stress, ddp)
else:
loss_energy = torch.nn.functional.huber_loss(
configs_energy_weight * ref["energy"] / num_atoms,
configs_energy_weight * pred["energy"] / num_atoms,
reduction="mean",
delta=self.huber_delta,
)
loss_forces = conditional_huber_forces(
configs_forces_weight * ref["forces"],
configs_forces_weight * pred["forces"],
huber_delta=self.huber_delta,
ddp=ddp,
)
loss_stress = torch.nn.functional.huber_loss(
configs_stress_weight * ref["stress"],
configs_stress_weight * pred["stress"],
reduction="mean",
delta=self.huber_delta,
)
return (
self.energy_weight * loss_energy
+ self.forces_weight * loss_forces
+ self.stress_weight * loss_stress
)
def __repr__(self):
return (
f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, "
f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})"
)
class WeightedEnergyForcesVirialsLoss(torch.nn.Module):
def __init__(
self, energy_weight=1.0, forces_weight=1.0, virials_weight=1.0
) -> None:
super().__init__()
self.register_buffer(
"energy_weight",
torch.tensor(energy_weight, dtype=torch.get_default_dtype()),
)
self.register_buffer(
"forces_weight",
torch.tensor(forces_weight, dtype=torch.get_default_dtype()),
)
self.register_buffer(
"virials_weight",
torch.tensor(virials_weight, dtype=torch.get_default_dtype()),
)
def forward(
self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
) -> torch.Tensor:
loss_energy = weighted_mean_squared_error_energy(ref, pred, ddp)
loss_forces = mean_squared_error_forces(ref, pred, ddp)
loss_virials = weighted_mean_squared_virials(ref, pred, ddp)
return (
self.energy_weight * loss_energy
+ self.forces_weight * loss_forces
+ self.virials_weight * loss_virials
)
def __repr__(self):
return (
f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, "
f"forces_weight={self.forces_weight:.3f}, virials_weight={self.virials_weight:.3f})"
)
class DipoleSingleLoss(torch.nn.Module):
def __init__(self, dipole_weight=1.0) -> None:
super().__init__()
self.register_buffer(
"dipole_weight",
torch.tensor(dipole_weight, dtype=torch.get_default_dtype()),
)
def forward(
self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
) -> torch.Tensor:
loss = (
weighted_mean_squared_error_dipole(ref, pred, ddp) * 100.0
) # scale adjustment
return self.dipole_weight * loss
def __repr__(self):
return f"{self.__class__.__name__}(dipole_weight={self.dipole_weight:.3f})"
class WeightedEnergyForcesDipoleLoss(torch.nn.Module):
def __init__(self, energy_weight=1.0, forces_weight=1.0, dipole_weight=1.0) -> None:
super().__init__()
self.register_buffer(
"energy_weight",
torch.tensor(energy_weight, dtype=torch.get_default_dtype()),
)
self.register_buffer(
"forces_weight",
torch.tensor(forces_weight, dtype=torch.get_default_dtype()),
)
self.register_buffer(
"dipole_weight",
torch.tensor(dipole_weight, dtype=torch.get_default_dtype()),
)
def forward(
self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
) -> torch.Tensor:
loss_energy = weighted_mean_squared_error_energy(ref, pred, ddp)
loss_forces = mean_squared_error_forces(ref, pred, ddp)
loss_dipole = weighted_mean_squared_error_dipole(ref, pred, ddp) * 100.0
return (
self.energy_weight * loss_energy
+ self.forces_weight * loss_forces
+ self.dipole_weight * loss_dipole
)
def __repr__(self):
return (
f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, "
f"forces_weight={self.forces_weight:.3f}, dipole_weight={self.dipole_weight:.3f})"
)
class WeightedEnergyForcesL1L2Loss(torch.nn.Module):
def __init__(self, energy_weight=1.0, forces_weight=1.0) -> None:
super().__init__()
self.register_buffer(
"energy_weight",
torch.tensor(energy_weight, dtype=torch.get_default_dtype()),
)
self.register_buffer(
"forces_weight",
torch.tensor(forces_weight, dtype=torch.get_default_dtype()),
)
def forward(
self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
) -> torch.Tensor:
loss_energy = weighted_mean_absolute_error_energy(ref, pred, ddp)
loss_forces = mean_normed_error_forces(ref, pred, ddp)
return self.energy_weight * loss_energy + self.forces_weight * loss_forces
def __repr__(self):
return (
f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, "
f"forces_weight={self.forces_weight:.3f})"
)
###########################################################################################
# Implementation of MACE models and other models based on E(3)-equivariant MPNNs
# Authors: Ilyes Batatia, Gregor Simm
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
from typing import Any, Callable, Dict, List, Optional, Type, Union
import numpy as np
import torch
from e3nn import o3
from e3nn.util.jit import compile_mode
from mace.modules.radial import ZBLBasis
from mace.tools.scatter import scatter_sum
from .blocks import (
AtomicEnergiesBlock,
EquivariantProductBasisBlock,
InteractionBlock,
LinearDipoleReadoutBlock,
LinearNodeEmbeddingBlock,
LinearReadoutBlock,
NonLinearDipoleReadoutBlock,
NonLinearReadoutBlock,
RadialEmbeddingBlock,
ScaleShiftBlock,
)
from .utils import (
compute_fixed_charge_dipole,
get_atomic_virials_stresses,
get_edge_vectors_and_lengths,
get_outputs,
get_symmetric_displacement,
prepare_graph,
)
# pylint: disable=C0302
@compile_mode("script")
class MACE(torch.nn.Module):
def __init__(
self,
r_max: float,
num_bessel: int,
num_polynomial_cutoff: int,
max_ell: int,
interaction_cls: Type[InteractionBlock],
interaction_cls_first: Type[InteractionBlock],
num_interactions: int,
num_elements: int,
hidden_irreps: o3.Irreps,
MLP_irreps: o3.Irreps,
atomic_energies: np.ndarray,
avg_num_neighbors: float,
atomic_numbers: List[int],
correlation: Union[int, List[int]],
gate: Optional[Callable],
pair_repulsion: bool = False,
distance_transform: str = "None",
radial_MLP: Optional[List[int]] = None,
radial_type: Optional[str] = "bessel",
heads: Optional[List[str]] = None,
cueq_config: Optional[Dict[str, Any]] = None,
lammps_mliap: Optional[bool] = False,
):
super().__init__()
self.register_buffer(
"atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64)
)
self.register_buffer(
"r_max", torch.tensor(r_max, dtype=torch.get_default_dtype())
)
self.register_buffer(
"num_interactions", torch.tensor(num_interactions, dtype=torch.int64)
)
if heads is None:
heads = ["Default"]
self.heads = heads
if isinstance(correlation, int):
correlation = [correlation] * num_interactions
self.lammps_mliap = lammps_mliap
# Embedding
node_attr_irreps = o3.Irreps([(num_elements, (0, 1))])
node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))])
self.node_embedding = LinearNodeEmbeddingBlock(
irreps_in=node_attr_irreps,
irreps_out=node_feats_irreps,
cueq_config=cueq_config,
)
self.radial_embedding = RadialEmbeddingBlock(
r_max=r_max,
num_bessel=num_bessel,
num_polynomial_cutoff=num_polynomial_cutoff,
radial_type=radial_type,
distance_transform=distance_transform,
)
edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e")
if pair_repulsion:
self.pair_repulsion_fn = ZBLBasis(p=num_polynomial_cutoff)
self.pair_repulsion = True
sh_irreps = o3.Irreps.spherical_harmonics(max_ell)
num_features = hidden_irreps.count(o3.Irrep(0, 1))
interaction_irreps = (sh_irreps * num_features).sort()[0].simplify()
self.spherical_harmonics = o3.SphericalHarmonics(
sh_irreps, normalize=True, normalization="component"
)
if radial_MLP is None:
radial_MLP = [64, 64, 64]
# Interactions and readout
self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies)
inter = interaction_cls_first(
node_attrs_irreps=node_attr_irreps,
node_feats_irreps=node_feats_irreps,
edge_attrs_irreps=sh_irreps,
edge_feats_irreps=edge_feats_irreps,
target_irreps=interaction_irreps,
hidden_irreps=hidden_irreps,
avg_num_neighbors=avg_num_neighbors,
radial_MLP=radial_MLP,
cueq_config=cueq_config,
)
self.interactions = torch.nn.ModuleList([inter])
# Use the appropriate self connection at the first layer for proper E0
use_sc_first = False
if "Residual" in str(interaction_cls_first):
use_sc_first = True
node_feats_irreps_out = inter.target_irreps
prod = EquivariantProductBasisBlock(
node_feats_irreps=node_feats_irreps_out,
target_irreps=hidden_irreps,
correlation=correlation[0],
num_elements=num_elements,
use_sc=use_sc_first,
cueq_config=cueq_config,
)
self.products = torch.nn.ModuleList([prod])
self.readouts = torch.nn.ModuleList()
self.readouts.append(
LinearReadoutBlock(
hidden_irreps, o3.Irreps(f"{len(heads)}x0e"), cueq_config
)
)
for i in range(num_interactions - 1):
if i == num_interactions - 2:
hidden_irreps_out = str(
hidden_irreps[0]
) # Select only scalars for last layer
else:
hidden_irreps_out = hidden_irreps
inter = interaction_cls(
node_attrs_irreps=node_attr_irreps,
node_feats_irreps=hidden_irreps,
edge_attrs_irreps=sh_irreps,
edge_feats_irreps=edge_feats_irreps,
target_irreps=interaction_irreps,
hidden_irreps=hidden_irreps_out,
avg_num_neighbors=avg_num_neighbors,
radial_MLP=radial_MLP,
cueq_config=cueq_config,
)
self.interactions.append(inter)
prod = EquivariantProductBasisBlock(
node_feats_irreps=interaction_irreps,
target_irreps=hidden_irreps_out,
correlation=correlation[i + 1],
num_elements=num_elements,
use_sc=True,
cueq_config=cueq_config,
)
self.products.append(prod)
if i == num_interactions - 2:
self.readouts.append(
NonLinearReadoutBlock(
hidden_irreps_out,
(len(heads) * MLP_irreps).simplify(),
gate,
o3.Irreps(f"{len(heads)}x0e"),
len(heads),
cueq_config,
)
)
else:
self.readouts.append(
LinearReadoutBlock(
hidden_irreps, o3.Irreps(f"{len(heads)}x0e"), cueq_config
)
)
def forward(
self,
data: Dict[str, torch.Tensor],
training: bool = False,
compute_force: bool = True,
compute_virials: bool = False,
compute_stress: bool = False,
compute_displacement: bool = False,
compute_hessian: bool = False,
compute_edge_forces: bool = False,
compute_atomic_stresses: bool = False,
lammps_mliap: bool = False,
) -> Dict[str, Optional[torch.Tensor]]:
# Setup
ctx = prepare_graph(
data,
compute_virials=compute_virials,
compute_stress=compute_stress,
compute_displacement=compute_displacement,
lammps_mliap=lammps_mliap,
)
is_lammps = ctx.is_lammps
num_atoms_arange = ctx.num_atoms_arange
num_graphs = ctx.num_graphs
displacement = ctx.displacement
positions = ctx.positions
vectors = ctx.vectors
lengths = ctx.lengths
cell = ctx.cell
node_heads = ctx.node_heads
interaction_kwargs = ctx.interaction_kwargs
lammps_natoms = interaction_kwargs.lammps_natoms
lammps_class = interaction_kwargs.lammps_class
# Atomic energies
node_e0 = self.atomic_energies_fn(data["node_attrs"])[
num_atoms_arange, node_heads
]
e0 = scatter_sum(
src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs
) # [n_graphs, n_heads]
# Embeddings
node_feats = self.node_embedding(data["node_attrs"])
edge_attrs = self.spherical_harmonics(vectors)
edge_feats = self.radial_embedding(
lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers
)
if hasattr(self, "pair_repulsion"):
pair_node_energy = self.pair_repulsion_fn(
lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers
)
if is_lammps:
pair_node_energy = pair_node_energy[: lammps_natoms[0]]
pair_energy = scatter_sum(
src=pair_node_energy, index=data["batch"], dim=-1, dim_size=num_graphs
) # [n_graphs,]
else:
pair_node_energy = torch.zeros_like(node_e0)
pair_energy = torch.zeros_like(e0)
# Interactions
energies = [e0, pair_energy]
node_energies_list = [node_e0, pair_node_energy]
node_feats_concat: List[torch.Tensor] = []
for i, (interaction, product, readout) in enumerate(
zip(self.interactions, self.products, self.readouts)
):
node_attrs_slice = data["node_attrs"]
if is_lammps and i > 0:
node_attrs_slice = node_attrs_slice[: lammps_natoms[0]]
node_feats, sc = interaction(
node_attrs=node_attrs_slice,
node_feats=node_feats,
edge_attrs=edge_attrs,
edge_feats=edge_feats,
edge_index=data["edge_index"],
first_layer=(i == 0),
lammps_class=lammps_class,
lammps_natoms=lammps_natoms,
)
if is_lammps and i == 0:
node_attrs_slice = node_attrs_slice[: lammps_natoms[0]]
node_feats = product(
node_feats=node_feats, sc=sc, node_attrs=node_attrs_slice
)
node_feats_concat.append(node_feats)
node_es = readout(node_feats, node_heads)[num_atoms_arange, node_heads]
energy = scatter_sum(node_es, data["batch"], dim=0, dim_size=num_graphs)
energies.append(energy)
node_energies_list.append(node_es)
contributions = torch.stack(energies, dim=-1)
total_energy = torch.sum(contributions, dim=-1)
node_energy = torch.sum(torch.stack(node_energies_list, dim=-1), dim=-1)
node_feats_out = torch.cat(node_feats_concat, dim=-1)
forces, virials, stress, hessian, edge_forces = get_outputs(
energy=total_energy,
positions=positions,
displacement=displacement,
vectors=vectors,
cell=cell,
training=training,
compute_force=compute_force,
compute_virials=compute_virials,
compute_stress=compute_stress,
compute_hessian=compute_hessian,
compute_edge_forces=compute_edge_forces,
)
atomic_virials: Optional[torch.Tensor] = None
atomic_stresses: Optional[torch.Tensor] = None
if compute_atomic_stresses and edge_forces is not None:
atomic_virials, atomic_stresses = get_atomic_virials_stresses(
edge_forces=edge_forces,
edge_index=data["edge_index"],
vectors=vectors,
num_atoms=positions.shape[0],
batch=data["batch"],
cell=cell,
)
return {
"energy": total_energy,
"node_energy": node_energy,
"contributions": contributions,
"forces": forces,
"edge_forces": edge_forces,
"virials": virials,
"stress": stress,
"atomic_virials": atomic_virials,
"atomic_stresses": atomic_stresses,
"displacement": displacement,
"hessian": hessian,
"node_feats": node_feats_out,
}
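# Energy bookkeeping in the forward pass above (per graph):
#     E_total = E0 + E_pair + sum_layers E_readout(layer)
# with "contributions" stacking these terms along the last dimension so each
# can be inspected separately.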
@compile_mode("script")
class ScaleShiftMACE(MACE):
def __init__(
self,
atomic_inter_scale: float,
atomic_inter_shift: float,
**kwargs,
):
super().__init__(**kwargs)
self.scale_shift = ScaleShiftBlock(
scale=atomic_inter_scale, shift=atomic_inter_shift
)
def forward(
self,
data: Dict[str, torch.Tensor],
training: bool = False,
compute_force: bool = True,
compute_virials: bool = False,
compute_stress: bool = False,
compute_displacement: bool = False,
compute_hessian: bool = False,
compute_edge_forces: bool = False,
compute_atomic_stresses: bool = False,
lammps_mliap: bool = False,
) -> Dict[str, Optional[torch.Tensor]]:
# Setup
ctx = prepare_graph(
data,
compute_virials=compute_virials,
compute_stress=compute_stress,
compute_displacement=compute_displacement,
lammps_mliap=lammps_mliap,
)
is_lammps = ctx.is_lammps
num_atoms_arange = ctx.num_atoms_arange
num_graphs = ctx.num_graphs
displacement = ctx.displacement
positions = ctx.positions
vectors = ctx.vectors
lengths = ctx.lengths
cell = ctx.cell
node_heads = ctx.node_heads
interaction_kwargs = ctx.interaction_kwargs
lammps_natoms = interaction_kwargs.lammps_natoms
lammps_class = interaction_kwargs.lammps_class
# Atomic energies
node_e0 = self.atomic_energies_fn(data["node_attrs"])[
num_atoms_arange, node_heads
]
e0 = scatter_sum(
src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs
) # [n_graphs, num_heads]
# Embeddings
node_feats = self.node_embedding(data["node_attrs"])
edge_attrs = self.spherical_harmonics(vectors)
edge_feats = self.radial_embedding(
lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers
)
if hasattr(self, "pair_repulsion"):
pair_node_energy = self.pair_repulsion_fn(
lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers
)
if is_lammps:
pair_node_energy = pair_node_energy[: lammps_natoms[0]]
else:
pair_node_energy = torch.zeros_like(node_e0)
# Interactions
node_es_list = [pair_node_energy]
node_feats_list: List[torch.Tensor] = []
for i, (interaction, product, readout) in enumerate(
zip(self.interactions, self.products, self.readouts)
):
node_attrs_slice = data["node_attrs"]
if is_lammps and i > 0:
node_attrs_slice = node_attrs_slice[: lammps_natoms[0]]
node_feats, sc = interaction(
node_attrs=node_attrs_slice,
node_feats=node_feats,
edge_attrs=edge_attrs,
edge_feats=edge_feats,
edge_index=data["edge_index"],
first_layer=(i == 0),
lammps_class=lammps_class,
lammps_natoms=lammps_natoms,
)
if is_lammps and i == 0:
node_attrs_slice = node_attrs_slice[: lammps_natoms[0]]
node_feats = product(
node_feats=node_feats, sc=sc, node_attrs=node_attrs_slice
)
node_feats_list.append(node_feats)
node_es_list.append(
readout(node_feats, node_heads)[num_atoms_arange, node_heads]
)
node_feats_out = torch.cat(node_feats_list, dim=-1)
node_inter_es = torch.sum(torch.stack(node_es_list, dim=0), dim=0)
node_inter_es = self.scale_shift(node_inter_es, node_heads)
inter_e = scatter_sum(node_inter_es, data["batch"], dim=-1, dim_size=num_graphs)
total_energy = e0 + inter_e
node_energy = node_e0.clone().double() + node_inter_es.clone().double()
forces, virials, stress, hessian, edge_forces = get_outputs(
energy=inter_e,
positions=positions,
displacement=displacement,
vectors=vectors,
cell=cell,
training=training,
compute_force=compute_force,
compute_virials=compute_virials,
compute_stress=compute_stress,
compute_hessian=compute_hessian,
compute_edge_forces=compute_edge_forces or compute_atomic_stresses,
)
atomic_virials: Optional[torch.Tensor] = None
atomic_stresses: Optional[torch.Tensor] = None
if compute_atomic_stresses and edge_forces is not None:
atomic_virials, atomic_stresses = get_atomic_virials_stresses(
edge_forces=edge_forces,
edge_index=data["edge_index"],
vectors=vectors,
num_atoms=positions.shape[0],
batch=data["batch"],
cell=cell,
)
return {
"energy": total_energy,
"node_energy": node_energy,
"interaction_energy": inter_e,
"forces": forces,
"edge_forces": edge_forces,
"virials": virials,
"stress": stress,
"atomic_virials": atomic_virials,
"atomic_stresses": atomic_stresses,
"hessian": hessian,
"displacement": displacement,
"node_feats": node_feats_out,
}


@compile_mode("script")
class AtomicDipolesMACE(torch.nn.Module):
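    """MACE model predicting per-atom dipole vectors only (readouts emit
    "1x1o"). The total dipole per graph is the sum of atomic contributions
    plus a fixed-charge baseline computed from the input charges.
    """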
def __init__(
self,
r_max: float,
num_bessel: int,
num_polynomial_cutoff: int,
max_ell: int,
interaction_cls: Type[InteractionBlock],
interaction_cls_first: Type[InteractionBlock],
num_interactions: int,
num_elements: int,
hidden_irreps: o3.Irreps,
MLP_irreps: o3.Irreps,
avg_num_neighbors: float,
atomic_numbers: List[int],
correlation: int,
gate: Optional[Callable],
atomic_energies: Optional[
None
], # Just here to make it compatible with energy models, MUST be None
radial_type: Optional[str] = "bessel",
radial_MLP: Optional[List[int]] = None,
cueq_config: Optional[Dict[str, Any]] = None, # pylint: disable=unused-argument
):
super().__init__()
self.register_buffer(
"atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64)
)
self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64))
self.register_buffer(
"num_interactions", torch.tensor(num_interactions, dtype=torch.int64)
)
assert atomic_energies is None
# Embedding
node_attr_irreps = o3.Irreps([(num_elements, (0, 1))])
node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))])
self.node_embedding = LinearNodeEmbeddingBlock(
irreps_in=node_attr_irreps, irreps_out=node_feats_irreps
)
self.radial_embedding = RadialEmbeddingBlock(
r_max=r_max,
num_bessel=num_bessel,
num_polynomial_cutoff=num_polynomial_cutoff,
radial_type=radial_type,
)
edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e")
sh_irreps = o3.Irreps.spherical_harmonics(max_ell)
num_features = hidden_irreps.count(o3.Irrep(0, 1))
interaction_irreps = (sh_irreps * num_features).sort()[0].simplify()
self.spherical_harmonics = o3.SphericalHarmonics(
sh_irreps, normalize=True, normalization="component"
)
if radial_MLP is None:
radial_MLP = [64, 64, 64]
# Interactions and readouts
inter = interaction_cls_first(
node_attrs_irreps=node_attr_irreps,
node_feats_irreps=node_feats_irreps,
edge_attrs_irreps=sh_irreps,
edge_feats_irreps=edge_feats_irreps,
target_irreps=interaction_irreps,
hidden_irreps=hidden_irreps,
avg_num_neighbors=avg_num_neighbors,
radial_MLP=radial_MLP,
)
self.interactions = torch.nn.ModuleList([inter])
# Use the appropriate self connection at the first layer
use_sc_first = False
if "Residual" in str(interaction_cls_first):
use_sc_first = True
node_feats_irreps_out = inter.target_irreps
prod = EquivariantProductBasisBlock(
node_feats_irreps=node_feats_irreps_out,
target_irreps=hidden_irreps,
correlation=correlation,
num_elements=num_elements,
use_sc=use_sc_first,
)
self.products = torch.nn.ModuleList([prod])
self.readouts = torch.nn.ModuleList()
self.readouts.append(LinearDipoleReadoutBlock(hidden_irreps, dipole_only=True))
for i in range(num_interactions - 1):
if i == num_interactions - 2:
assert (
len(hidden_irreps) > 1
), "To predict dipoles use at least l=1 hidden_irreps"
hidden_irreps_out = str(
hidden_irreps[1]
) # Select only l=1 vectors for last layer
else:
hidden_irreps_out = hidden_irreps
inter = interaction_cls(
node_attrs_irreps=node_attr_irreps,
node_feats_irreps=hidden_irreps,
edge_attrs_irreps=sh_irreps,
edge_feats_irreps=edge_feats_irreps,
target_irreps=interaction_irreps,
hidden_irreps=hidden_irreps_out,
avg_num_neighbors=avg_num_neighbors,
radial_MLP=radial_MLP,
)
self.interactions.append(inter)
prod = EquivariantProductBasisBlock(
node_feats_irreps=interaction_irreps,
target_irreps=hidden_irreps_out,
correlation=correlation,
num_elements=num_elements,
use_sc=True,
)
self.products.append(prod)
if i == num_interactions - 2:
self.readouts.append(
NonLinearDipoleReadoutBlock(
hidden_irreps_out, MLP_irreps, gate, dipole_only=True
)
)
else:
self.readouts.append(
LinearDipoleReadoutBlock(hidden_irreps, dipole_only=True)
)

    def forward(
self,
data: Dict[str, torch.Tensor],
training: bool = False, # pylint: disable=W0613
compute_force: bool = False,
compute_virials: bool = False,
compute_stress: bool = False,
compute_displacement: bool = False,
compute_edge_forces: bool = False, # pylint: disable=W0613
compute_atomic_stresses: bool = False, # pylint: disable=W0613
) -> Dict[str, Optional[torch.Tensor]]:
assert compute_force is False
assert compute_virials is False
assert compute_stress is False
assert compute_displacement is False
# Setup
data["node_attrs"].requires_grad_(True)
data["positions"].requires_grad_(True)
num_graphs = data["ptr"].numel() - 1
# Embeddings
node_feats = self.node_embedding(data["node_attrs"])
vectors, lengths = get_edge_vectors_and_lengths(
positions=data["positions"],
edge_index=data["edge_index"],
shifts=data["shifts"],
)
edge_attrs = self.spherical_harmonics(vectors)
edge_feats = self.radial_embedding(
lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers
)
# Interactions
dipoles = []
for interaction, product, readout in zip(
self.interactions, self.products, self.readouts
):
node_feats, sc = interaction(
node_attrs=data["node_attrs"],
node_feats=node_feats,
edge_attrs=edge_attrs,
edge_feats=edge_feats,
edge_index=data["edge_index"],
)
node_feats = product(
node_feats=node_feats,
sc=sc,
node_attrs=data["node_attrs"],
)
node_dipoles = readout(node_feats).squeeze(-1) # [n_nodes,3]
dipoles.append(node_dipoles)
# Compute the dipoles
contributions_dipoles = torch.stack(
dipoles, dim=-1
) # [n_nodes,3,n_contributions]
atomic_dipoles = torch.sum(contributions_dipoles, dim=-1) # [n_nodes,3]
total_dipole = scatter_sum(
src=atomic_dipoles,
index=data["batch"],
dim=0,
dim_size=num_graphs,
) # [n_graphs,3]
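        # Fixed-charge baseline: mu_0 = sum_i q_i * r_i per graph, computed
        # from the reference charges supplied in the data.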
baseline = compute_fixed_charge_dipole(
charges=data["charges"],
positions=data["positions"],
batch=data["batch"],
num_graphs=num_graphs,
) # [n_graphs,3]
total_dipole = total_dipole + baseline
output = {
"dipole": total_dipole,
"atomic_dipoles": atomic_dipoles,
}
return output


@compile_mode("script")
class EnergyDipolesMACE(torch.nn.Module):
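    """MACE model jointly predicting energies (with optional forces, virials
    and stress via autograd) and dipoles. Each readout emits "1x0e + 1x1o";
    the scalar channel carries the node energy and the vector channels the
    node dipole.
    """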
def __init__(
self,
r_max: float,
num_bessel: int,
num_polynomial_cutoff: int,
max_ell: int,
interaction_cls: Type[InteractionBlock],
interaction_cls_first: Type[InteractionBlock],
num_interactions: int,
num_elements: int,
hidden_irreps: o3.Irreps,
MLP_irreps: o3.Irreps,
avg_num_neighbors: float,
atomic_numbers: List[int],
correlation: int,
gate: Optional[Callable],
atomic_energies: Optional[np.ndarray],
radial_MLP: Optional[List[int]] = None,
cueq_config: Optional[Dict[str, Any]] = None, # pylint: disable=unused-argument
):
super().__init__()
self.register_buffer(
"atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64)
)
self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64))
self.register_buffer(
"num_interactions", torch.tensor(num_interactions, dtype=torch.int64)
)
# Embedding
node_attr_irreps = o3.Irreps([(num_elements, (0, 1))])
node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))])
self.node_embedding = LinearNodeEmbeddingBlock(
irreps_in=node_attr_irreps, irreps_out=node_feats_irreps
)
self.radial_embedding = RadialEmbeddingBlock(
r_max=r_max,
num_bessel=num_bessel,
num_polynomial_cutoff=num_polynomial_cutoff,
)
edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e")
sh_irreps = o3.Irreps.spherical_harmonics(max_ell)
num_features = hidden_irreps.count(o3.Irrep(0, 1))
interaction_irreps = (sh_irreps * num_features).sort()[0].simplify()
self.spherical_harmonics = o3.SphericalHarmonics(
sh_irreps, normalize=True, normalization="component"
)
if radial_MLP is None:
radial_MLP = [64, 64, 64]
# Interactions and readouts
self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies)
inter = interaction_cls_first(
node_attrs_irreps=node_attr_irreps,
node_feats_irreps=node_feats_irreps,
edge_attrs_irreps=sh_irreps,
edge_feats_irreps=edge_feats_irreps,
target_irreps=interaction_irreps,
hidden_irreps=hidden_irreps,
avg_num_neighbors=avg_num_neighbors,
radial_MLP=radial_MLP,
)
self.interactions = torch.nn.ModuleList([inter])
# Use the appropriate self connection at the first layer
use_sc_first = False
if "Residual" in str(interaction_cls_first):
use_sc_first = True
node_feats_irreps_out = inter.target_irreps
prod = EquivariantProductBasisBlock(
node_feats_irreps=node_feats_irreps_out,
target_irreps=hidden_irreps,
correlation=correlation,
num_elements=num_elements,
use_sc=use_sc_first,
)
self.products = torch.nn.ModuleList([prod])
self.readouts = torch.nn.ModuleList()
self.readouts.append(LinearDipoleReadoutBlock(hidden_irreps, dipole_only=False))
for i in range(num_interactions - 1):
if i == num_interactions - 2:
assert (
len(hidden_irreps) > 1
), "To predict dipoles use at least l=1 hidden_irreps"
hidden_irreps_out = str(
hidden_irreps[:2]
) # Select scalars and l=1 vectors for last layer
else:
hidden_irreps_out = hidden_irreps
inter = interaction_cls(
node_attrs_irreps=node_attr_irreps,
node_feats_irreps=hidden_irreps,
edge_attrs_irreps=sh_irreps,
edge_feats_irreps=edge_feats_irreps,
target_irreps=interaction_irreps,
hidden_irreps=hidden_irreps_out,
avg_num_neighbors=avg_num_neighbors,
radial_MLP=radial_MLP,
)
self.interactions.append(inter)
prod = EquivariantProductBasisBlock(
node_feats_irreps=interaction_irreps,
target_irreps=hidden_irreps_out,
correlation=correlation,
num_elements=num_elements,
use_sc=True,
)
self.products.append(prod)
if i == num_interactions - 2:
self.readouts.append(
NonLinearDipoleReadoutBlock(
hidden_irreps_out, MLP_irreps, gate, dipole_only=False
)
)
else:
self.readouts.append(
LinearDipoleReadoutBlock(hidden_irreps, dipole_only=False)
)

    def forward(
self,
data: Dict[str, torch.Tensor],
training: bool = False,
compute_force: bool = True,
compute_virials: bool = False,
compute_stress: bool = False,
compute_displacement: bool = False,
compute_edge_forces: bool = False, # pylint: disable=W0613
compute_atomic_stresses: bool = False, # pylint: disable=W0613
) -> Dict[str, Optional[torch.Tensor]]:
# Setup
data["node_attrs"].requires_grad_(True)
data["positions"].requires_grad_(True)
num_graphs = data["ptr"].numel() - 1
        num_atoms_arange = torch.arange(
            data["positions"].shape[0], device=data["positions"].device
        )
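        # "displacement" is a symmetric strain tensor attached to positions and
        # cell; virials and stress are then obtained by autograd w.r.t. it.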
displacement = torch.zeros(
(num_graphs, 3, 3),
dtype=data["positions"].dtype,
device=data["positions"].device,
)
if compute_virials or compute_stress or compute_displacement:
(
data["positions"],
data["shifts"],
displacement,
) = get_symmetric_displacement(
positions=data["positions"],
unit_shifts=data["unit_shifts"],
cell=data["cell"],
edge_index=data["edge_index"],
num_graphs=num_graphs,
batch=data["batch"],
)
# Atomic energies
node_e0 = self.atomic_energies_fn(data["node_attrs"])[
num_atoms_arange, data["head"][data["batch"]]
]
e0 = scatter_sum(
src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs
) # [n_graphs,]
# Embeddings
node_feats = self.node_embedding(data["node_attrs"])
vectors, lengths = get_edge_vectors_and_lengths(
positions=data["positions"],
edge_index=data["edge_index"],
shifts=data["shifts"],
)
edge_attrs = self.spherical_harmonics(vectors)
edge_feats = self.radial_embedding(
lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers
)
# Interactions
energies = [e0]
node_energies_list = [node_e0]
dipoles = []
for interaction, product, readout in zip(
self.interactions, self.products, self.readouts
):
node_feats, sc = interaction(
node_attrs=data["node_attrs"],
node_feats=node_feats,
edge_attrs=edge_attrs,
edge_feats=edge_feats,
edge_index=data["edge_index"],
)
node_feats = product(
node_feats=node_feats,
sc=sc,
node_attrs=data["node_attrs"],
)
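            # Readout irreps are "1x0e + 1x1o": column 0 is the node energy,
            # columns 1:4 the node dipole vector.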
            node_out = readout(node_feats).squeeze(-1)  # [n_nodes, 4]
            node_energies = node_out[:, 0]
energy = scatter_sum(
src=node_energies, index=data["batch"], dim=-1, dim_size=num_graphs
) # [n_graphs,]
energies.append(energy)
node_dipoles = node_out[:, 1:]
dipoles.append(node_dipoles)
# Compute the energies and dipoles
contributions = torch.stack(energies, dim=-1)
total_energy = torch.sum(contributions, dim=-1) # [n_graphs, ]
node_energy_contributions = torch.stack(node_energies_list, dim=-1)
node_energy = torch.sum(node_energy_contributions, dim=-1) # [n_nodes, ]
contributions_dipoles = torch.stack(
dipoles, dim=-1
) # [n_nodes,3,n_contributions]
atomic_dipoles = torch.sum(contributions_dipoles, dim=-1) # [n_nodes,3]
total_dipole = scatter_sum(
src=atomic_dipoles,
index=data["batch"].unsqueeze(-1),
dim=0,
dim_size=num_graphs,
) # [n_graphs,3]
baseline = compute_fixed_charge_dipole(
charges=data["charges"],
positions=data["positions"],
batch=data["batch"],
num_graphs=num_graphs,
) # [n_graphs,3]
total_dipole = total_dipole + baseline
forces, virials, stress, _, _ = get_outputs(
energy=total_energy,
positions=data["positions"],
displacement=displacement,
cell=data["cell"],
training=training,
compute_force=compute_force,
compute_virials=compute_virials,
compute_stress=compute_stress,
)
output = {
"energy": total_energy,
"node_energy": node_energy,
"contributions": contributions,
"forces": forces,
"virials": virials,
"stress": stress,
"displacement": displacement,
"dipole": total_dipole,
"atomic_dipoles": atomic_dipoles,
}
return output
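

# Minimal usage sketch for EnergyDipolesMACE (illustrative only; key names and
# shapes are inferred from the accesses in forward above):
#
#   data = {
#       "positions":   [n_nodes, 3] (requires grad for force computation),
#       "node_attrs":  [n_nodes, n_elements] one-hot element encoding,
#       "edge_index":  [2, n_edges] source/target indices,
#       "shifts":      [n_edges, 3] periodic shift vectors,
#       "unit_shifts": [n_edges, 3] integer unit shifts,
#       "cell":        per-graph lattice vectors,
#       "batch":       [n_nodes] graph index of each node,
#       "ptr":         [n_graphs + 1] node offsets per graph,
#       "head":        [n_graphs] head index per graph,
#       "charges":     [n_nodes] fixed reference charges,
#   }
#   out = model(data, compute_force=True)  # out["energy"], out["dipole"], ...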