Unverified Commit a69b2190 authored by zcxzcx1's avatar zcxzcx1 Committed by GitHub
Browse files

Delete modules directory

parent d47e8ba0
###########################################################################################
# Elementary tools for handling irreducible representations
# Authors: Ilyes Batatia, Gregor Simm
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
from typing import List, Optional, Tuple
import torch
from e3nn import o3
from e3nn.util.jit import compile_mode
from mace.modules.wrapper_ops import CuEquivarianceConfig
# Based on mir-group/nequip
def tp_out_irreps_with_instructions(
    irreps1: o3.Irreps, irreps2: o3.Irreps, target_irreps: o3.Irreps
) -> Tuple[o3.Irreps, List]:
    """Enumerate trainable "uvu" tensor-product paths from irreps1 x irreps2.

    Only products landing in ``target_irreps`` are kept. Returns the sorted
    output irreps together with the instruction tuples expected by
    ``o3.TensorProduct``, with instruction output slots remapped to the
    sorted order.
    """
    trainable = True
    collected: List[Tuple[int, o3.Irreps]] = []
    instructions = []
    for idx1, (mul, ir_in) in enumerate(irreps1):
        for idx2, (_, ir_edge) in enumerate(irreps2):
            # Selection rule: | l1 - l2 | <= l <= l1 + l2
            for ir_out in ir_in * ir_edge:
                if ir_out not in target_irreps:
                    continue
                # The instruction index is the current position in the list.
                instructions.append((idx1, idx2, len(collected), "uvu", trainable))
                collected.append((mul, ir_out))
    # Sort the output irreps so they can be simplified before being handed
    # to the second o3.Linear.
    irreps_out, permut, _ = o3.Irreps(collected).sort()
    # Remap the instruction output indexes onto the sorted irreps.
    remapped = [
        (in1, in2, permut[out], mode, has_weight)
        for in1, in2, out, mode, has_weight in instructions
    ]
    return irreps_out, sorted(remapped, key=lambda ins: ins[2])
def linear_out_irreps(irreps: o3.Irreps, target_irreps: o3.Irreps) -> o3.Irreps:
    """For each input irrep, select the first matching target irrep.

    The multiplicity of the *target* irrep is kept. Assumes ``irreps`` is
    already simplified. Raises ``RuntimeError`` if an input irrep has no
    counterpart in ``target_irreps``.
    """
    selected = []
    for _, ir_in in irreps:
        match = next(
            ((mul, ir_out) for mul, ir_out in target_irreps if ir_out == ir_in),
            None,
        )
        if match is None:
            raise RuntimeError(f"{ir_in} not in {target_irreps}")
        selected.append(match)
    return o3.Irreps(selected)
@compile_mode("script")
class reshape_irreps(torch.nn.Module):
    """Reshape a flat irreps tensor [batch, sum(mul_i * dim_i)] into per-irrep
    blocks, each [batch, mul, dim] (default / "mul_ir" layout) or
    [batch, d, mul] (other cuequivariance layouts), concatenated back along
    the trailing axis chosen to match the layout.

    The repeated ``hasattr(self, "cueq_config")`` guards are presumably there
    so instances serialized before ``cueq_config`` existed still deserialize
    and run — TODO confirm against checkpoint-loading code.
    """

    def __init__(
        self, irreps: o3.Irreps, cueq_config: Optional[CuEquivarianceConfig] = None
    ) -> None:
        super().__init__()
        self.irreps = o3.Irreps(irreps)
        self.cueq_config = cueq_config
        # Per-irrep dimension (ir.dim) and multiplicity, in irreps order;
        # used in forward to slice the flat feature tensor.
        self.dims = []
        self.muls = []
        for mul, ir in self.irreps:
            d = ir.dim
            self.dims.append(d)
            self.muls.append(mul)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        # Walk the flat feature axis, carving out one slice per irrep.
        ix = 0
        out = []
        batch, _ = tensor.shape
        for mul, d in zip(self.muls, self.dims):
            field = tensor[:, ix : ix + mul * d]  # [batch, sample, mul * repr]
            ix += mul * d
            if hasattr(self, "cueq_config"):
                if self.cueq_config is not None:
                    if self.cueq_config.layout_str == "mul_ir":
                        field = field.reshape(batch, mul, d)
                    else:
                        # Transposed layout used by non-"mul_ir" configs.
                        field = field.reshape(batch, d, mul)
                else:
                    field = field.reshape(batch, mul, d)
            else:
                field = field.reshape(batch, mul, d)
            out.append(field)
        # The concatenation axis mirrors the reshape layout chosen above.
        if hasattr(self, "cueq_config"):
            if self.cueq_config is not None:  # pylint: disable=no-else-return
                if self.cueq_config.layout_str == "mul_ir":
                    return torch.cat(out, dim=-1)
                return torch.cat(out, dim=-2)
            else:
                return torch.cat(out, dim=-1)
        return torch.cat(out, dim=-1)
def mask_head(x: torch.Tensor, head: torch.Tensor, num_heads: int) -> torch.Tensor:
    """Zero every channel of ``x`` except those belonging to each row's head.

    ``x`` is [n, channels] with ``channels`` divisible by ``num_heads`` and
    channels grouped contiguously per head; ``head`` is [n] with the head
    index of each row.
    """
    per_head = x.shape[1] // num_heads
    selector = torch.zeros(x.shape[0], per_head, num_heads, device=x.device)
    rows = torch.arange(selector.shape[0], device=x.device)
    selector[rows, :, head] = 1
    # Rearrange to [n, num_heads, per_head] and flatten to match x.
    selector = selector.permute(0, 2, 1).reshape(x.shape)
    return x * selector
###########################################################################################
# Implementation of different loss functions
# Authors: Ilyes Batatia, Gregor Simm
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
from typing import Optional
import torch
import torch.distributed as dist
from mace.tools import TensorDict
from mace.tools.torch_geometric import Batch
# ------------------------------------------------------------------------------
# Helper function for loss reduction that handles DDP correction
# ------------------------------------------------------------------------------
def is_ddp_enabled():
    """Return True when torch.distributed is initialized with more than one rank."""
    if not dist.is_initialized():
        return False
    return dist.get_world_size() > 1
def reduce_loss(raw_loss: torch.Tensor, ddp: Optional[bool] = None) -> torch.Tensor:
    """Reduce an element-wise loss tensor to a scalar.

    Under DDP (``ddp`` true and torch.distributed initialized) the mean is
    taken over the *global* element count: local_sum * world_size /
    global_num_elements, which compensates for DDP's later per-rank gradient
    averaging. Otherwise this is a plain mean.
    """
    if ddp is None:
        # Auto-detect DDP: an initialized process group with more than one rank.
        ddp = dist.is_initialized() and dist.get_world_size() > 1
    if not (ddp and dist.is_initialized()):
        return raw_loss.mean()
    world_size = dist.get_world_size()
    local_sum = raw_loss.sum()
    global_count = torch.tensor(
        raw_loss.numel(), device=raw_loss.device, dtype=raw_loss.dtype
    )
    dist.all_reduce(global_count, op=dist.ReduceOp.SUM)
    return local_sum * world_size / global_count
# ------------------------------------------------------------------------------
# Energy Loss Functions
# ------------------------------------------------------------------------------
def mean_squared_error_energy(
    ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
) -> torch.Tensor:
    """Unweighted MSE between reference and predicted total energies."""
    err = ref["energy"] - pred["energy"]
    return reduce_loss(err * err, ddp)
def weighted_mean_squared_error_energy(
    ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
) -> torch.Tensor:
    """Config- and energy-weighted MSE of per-atom-normalized energies."""
    # Per-graph atom counts derived from the CSR-style ptr vector.
    atoms_per_graph = ref.ptr[1:] - ref.ptr[:-1]  # shape: [n_graphs]
    per_atom_err = (ref["energy"] - pred["energy"]) / atoms_per_graph
    raw_loss = ref.weight * ref.energy_weight * torch.square(per_atom_err)
    return reduce_loss(raw_loss, ddp)
def weighted_mean_absolute_error_energy(
    ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
) -> torch.Tensor:
    """Config- and energy-weighted MAE of per-atom-normalized energies."""
    atoms_per_graph = ref.ptr[1:] - ref.ptr[:-1]
    per_atom_err = (ref["energy"] - pred["energy"]) / atoms_per_graph
    raw_loss = ref.weight * ref.energy_weight * torch.abs(per_atom_err)
    return reduce_loss(raw_loss, ddp)
# ------------------------------------------------------------------------------
# Stress and Virials Loss Functions
# ------------------------------------------------------------------------------
def weighted_mean_squared_stress(
    ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
) -> torch.Tensor:
    """Config- and stress-weighted MSE over the per-graph stress tensors."""
    # Broadcast per-graph weights over the trailing 3x3 stress dimensions.
    w_cfg = ref.weight.view(-1, 1, 1)
    w_stress = ref.stress_weight.view(-1, 1, 1)
    raw_loss = w_cfg * w_stress * torch.square(ref["stress"] - pred["stress"])
    return reduce_loss(raw_loss, ddp)
def weighted_mean_squared_virials(
    ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
) -> torch.Tensor:
    """Config- and virials-weighted MSE of per-atom-normalized virials."""
    w_cfg = ref.weight.view(-1, 1, 1)
    w_virials = ref.virials_weight.view(-1, 1, 1)
    atoms_per_graph = (ref.ptr[1:] - ref.ptr[:-1]).view(-1, 1, 1)
    per_atom_err = (ref["virials"] - pred["virials"]) / atoms_per_graph
    raw_loss = w_cfg * w_virials * torch.square(per_atom_err)
    return reduce_loss(raw_loss, ddp)
# ------------------------------------------------------------------------------
# Forces Loss Functions
# ------------------------------------------------------------------------------
def mean_squared_error_forces(
    ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
) -> torch.Tensor:
    """Config- and forces-weighted MSE over per-atom force components."""
    atoms_per_graph = ref.ptr[1:] - ref.ptr[:-1]
    # Expand per-graph weights down to one weight per atom row.
    w_cfg = torch.repeat_interleave(ref.weight, atoms_per_graph).unsqueeze(-1)
    w_forces = torch.repeat_interleave(
        ref.forces_weight, atoms_per_graph
    ).unsqueeze(-1)
    raw_loss = w_cfg * w_forces * torch.square(ref["forces"] - pred["forces"])
    return reduce_loss(raw_loss, ddp)
def mean_normed_error_forces(
    ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
) -> torch.Tensor:
    """Mean Euclidean norm of the per-atom force error vectors (L1-like in norm space)."""
    diff = ref["forces"] - pred["forces"]
    return reduce_loss(torch.linalg.vector_norm(diff, ord=2, dim=-1), ddp)
# ------------------------------------------------------------------------------
# Dipole Loss Function
# ------------------------------------------------------------------------------
def weighted_mean_squared_error_dipole(
    ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
) -> torch.Tensor:
    """MSE of per-atom-normalized dipole vectors."""
    atoms_per_graph = (ref.ptr[1:] - ref.ptr[:-1]).unsqueeze(-1)
    per_atom_err = (ref["dipole"] - pred["dipole"]) / atoms_per_graph
    return reduce_loss(torch.square(per_atom_err), ddp)
# ------------------------------------------------------------------------------
# Conditional Losses for Forces
# ------------------------------------------------------------------------------
def conditional_mse_forces(
    ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
) -> torch.Tensor:
    """Squared force error, down-weighted for atoms with large reference forces.

    Regimes by ||F_ref||: <100 -> factor 1.0, [100, 200) -> 0.7,
    [200, 300) -> 0.4, >=300 -> 0.1.
    """
    atoms_per_graph = ref.ptr[1:] - ref.ptr[:-1]
    w_cfg = torch.repeat_interleave(ref.weight, atoms_per_graph).unsqueeze(-1)
    w_forces = torch.repeat_interleave(
        ref.forces_weight, atoms_per_graph
    ).unsqueeze(-1)
    # Multiplication factors for the four force-magnitude regimes.
    factors = torch.tensor(
        [1.0, 0.7, 0.4, 0.1], device=ref["forces"].device, dtype=ref["forces"].dtype
    )
    err = ref["forces"] - pred["forces"]
    norm_forces = torch.norm(ref["forces"], dim=-1)
    low = norm_forces < 100
    mid = (norm_forces >= 100) & (norm_forces < 200)
    high = (norm_forces >= 200) & (norm_forces < 300)
    extreme = ~(low | mid | high)
    se = torch.zeros_like(err)
    for region, factor in zip((low, mid, high, extreme), factors):
        se[region] = torch.square(err[region]) * factor
    return reduce_loss(w_cfg * w_forces * se, ddp)
def conditional_huber_forces(
    ref_forces: torch.Tensor,
    pred_forces: torch.Tensor,
    huber_delta: float,
    ddp: Optional[bool] = None,
) -> torch.Tensor:
    """Huber force loss whose delta shrinks for atoms with large reference forces.

    Regimes by ||F_ref||: <100 -> delta, [100, 200) -> 0.7*delta,
    [200, 300) -> 0.4*delta, >=300 -> 0.1*delta.
    """
    deltas = huber_delta * torch.tensor(
        [1.0, 0.7, 0.4, 0.1], device=ref_forces.device, dtype=ref_forces.dtype
    )
    norm_forces = torch.norm(ref_forces, dim=-1)
    low = norm_forces < 100
    mid = (norm_forces >= 100) & (norm_forces < 200)
    high = (norm_forces >= 200) & (norm_forces < 300)
    extreme = ~(low | mid | high)
    se = torch.zeros_like(pred_forces)
    for region, delta in zip((low, mid, high, extreme), deltas):
        se[region] = torch.nn.functional.huber_loss(
            ref_forces[region], pred_forces[region], reduction="none", delta=delta
        )
    return reduce_loss(se, ddp)
# ------------------------------------------------------------------------------
# Loss Modules Combining Multiple Quantities
# ------------------------------------------------------------------------------
class WeightedEnergyForcesLoss(torch.nn.Module):
    """Weighted sum of the per-atom energy MSE and the force MSE."""

    def __init__(self, energy_weight=1.0, forces_weight=1.0) -> None:
        super().__init__()
        dtype = torch.get_default_dtype()
        # Buffers so the weights follow the module across devices/checkpoints.
        self.register_buffer("energy_weight", torch.tensor(energy_weight, dtype=dtype))
        self.register_buffer("forces_weight", torch.tensor(forces_weight, dtype=dtype))

    def forward(
        self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
    ) -> torch.Tensor:
        e_term = weighted_mean_squared_error_energy(ref, pred, ddp)
        f_term = mean_squared_error_forces(ref, pred, ddp)
        return self.energy_weight * e_term + self.forces_weight * f_term

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, "
            f"forces_weight={self.forces_weight:.3f})"
        )
class WeightedForcesLoss(torch.nn.Module):
    """Forces-only weighted MSE loss."""

    def __init__(self, forces_weight=1.0) -> None:
        super().__init__()
        self.register_buffer(
            "forces_weight",
            torch.tensor(forces_weight, dtype=torch.get_default_dtype()),
        )

    def forward(
        self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
    ) -> torch.Tensor:
        return self.forces_weight * mean_squared_error_forces(ref, pred, ddp)

    def __repr__(self):
        return f"{self.__class__.__name__}(forces_weight={self.forces_weight:.3f})"
class WeightedEnergyForcesStressLoss(torch.nn.Module):
    """Weighted sum of energy, force and stress MSE losses."""

    def __init__(self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0) -> None:
        super().__init__()
        dtype = torch.get_default_dtype()
        self.register_buffer("energy_weight", torch.tensor(energy_weight, dtype=dtype))
        self.register_buffer("forces_weight", torch.tensor(forces_weight, dtype=dtype))
        self.register_buffer("stress_weight", torch.tensor(stress_weight, dtype=dtype))

    def forward(
        self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
    ) -> torch.Tensor:
        e_term = weighted_mean_squared_error_energy(ref, pred, ddp)
        f_term = mean_squared_error_forces(ref, pred, ddp)
        s_term = weighted_mean_squared_stress(ref, pred, ddp)
        return (
            self.energy_weight * e_term
            + self.forces_weight * f_term
            + self.stress_weight * s_term
        )

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, "
            f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})"
        )
class WeightedHuberEnergyForcesStressLoss(torch.nn.Module):
    """Huber losses on per-atom energy, forces and stress, combined with weights.

    Under DDP the element-wise ("none"-reduced) Huber losses are globally
    averaged via ``reduce_loss``; otherwise PyTorch's built-in mean
    reduction is used directly.
    """

    def __init__(
        self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0, huber_delta=0.01
    ) -> None:
        super().__init__()
        # Keep the raw delta so the reduction mode can be chosen per call.
        self.huber_delta = huber_delta
        dtype = torch.get_default_dtype()
        self.register_buffer("energy_weight", torch.tensor(energy_weight, dtype=dtype))
        self.register_buffer("forces_weight", torch.tensor(forces_weight, dtype=dtype))
        self.register_buffer("stress_weight", torch.tensor(stress_weight, dtype=dtype))

    def forward(
        self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
    ) -> torch.Tensor:
        atoms_per_graph = ref.ptr[1:] - ref.ptr[:-1]
        reduction = "none" if ddp else "mean"
        loss_energy = torch.nn.functional.huber_loss(
            ref["energy"] / atoms_per_graph,
            pred["energy"] / atoms_per_graph,
            reduction=reduction,
            delta=self.huber_delta,
        )
        loss_forces = torch.nn.functional.huber_loss(
            ref["forces"], pred["forces"], reduction=reduction, delta=self.huber_delta
        )
        loss_stress = torch.nn.functional.huber_loss(
            ref["stress"], pred["stress"], reduction=reduction, delta=self.huber_delta
        )
        if ddp:
            # Global-mean reduction across ranks.
            loss_energy = reduce_loss(loss_energy, ddp)
            loss_forces = reduce_loss(loss_forces, ddp)
            loss_stress = reduce_loss(loss_stress, ddp)
        return (
            self.energy_weight * loss_energy
            + self.forces_weight * loss_forces
            + self.stress_weight * loss_stress
        )

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, "
            f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})"
        )
class UniversalLoss(torch.nn.Module):
    """Huber loss on per-atom energy and stress plus a conditional Huber force loss.

    Per-config weights are multiplied into both reference and prediction
    before the Huber loss is taken. Under DDP the energy/stress terms use
    "none" reduction followed by ``reduce_loss``; the force term delegates
    its DDP-aware reduction to ``conditional_huber_forces``.
    """

    def __init__(
        self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0, huber_delta=0.01
    ) -> None:
        super().__init__()
        self.huber_delta = huber_delta
        dtype = torch.get_default_dtype()
        self.register_buffer("energy_weight", torch.tensor(energy_weight, dtype=dtype))
        self.register_buffer("forces_weight", torch.tensor(forces_weight, dtype=dtype))
        self.register_buffer("stress_weight", torch.tensor(stress_weight, dtype=dtype))

    def forward(
        self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
    ) -> torch.Tensor:
        atoms_per_graph = ref.ptr[1:] - ref.ptr[:-1]
        w_stress = ref.stress_weight.view(-1, 1, 1)
        w_energy = ref.energy_weight
        w_forces = torch.repeat_interleave(
            ref.forces_weight, ref.ptr[1:] - ref.ptr[:-1]
        ).unsqueeze(-1)
        reduction = "none" if ddp else "mean"
        loss_energy = torch.nn.functional.huber_loss(
            w_energy * ref["energy"] / atoms_per_graph,
            w_energy * pred["energy"] / atoms_per_graph,
            reduction=reduction,
            delta=self.huber_delta,
        )
        loss_stress = torch.nn.functional.huber_loss(
            w_stress * ref["stress"],
            w_stress * pred["stress"],
            reduction=reduction,
            delta=self.huber_delta,
        )
        if ddp:
            loss_energy = reduce_loss(loss_energy, ddp)
            loss_stress = reduce_loss(loss_stress, ddp)
        # The force term always handles its own (DDP-aware) reduction.
        loss_forces = conditional_huber_forces(
            w_forces * ref["forces"],
            w_forces * pred["forces"],
            huber_delta=self.huber_delta,
            ddp=ddp,
        )
        return (
            self.energy_weight * loss_energy
            + self.forces_weight * loss_forces
            + self.stress_weight * loss_stress
        )

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, "
            f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})"
        )
class WeightedEnergyForcesVirialsLoss(torch.nn.Module):
    """Weighted sum of energy, force and virials MSE losses."""

    def __init__(
        self, energy_weight=1.0, forces_weight=1.0, virials_weight=1.0
    ) -> None:
        super().__init__()
        dtype = torch.get_default_dtype()
        self.register_buffer("energy_weight", torch.tensor(energy_weight, dtype=dtype))
        self.register_buffer("forces_weight", torch.tensor(forces_weight, dtype=dtype))
        self.register_buffer(
            "virials_weight", torch.tensor(virials_weight, dtype=dtype)
        )

    def forward(
        self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
    ) -> torch.Tensor:
        e_term = weighted_mean_squared_error_energy(ref, pred, ddp)
        f_term = mean_squared_error_forces(ref, pred, ddp)
        v_term = weighted_mean_squared_virials(ref, pred, ddp)
        return (
            self.energy_weight * e_term
            + self.forces_weight * f_term
            + self.virials_weight * v_term
        )

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, "
            f"forces_weight={self.forces_weight:.3f}, virials_weight={self.virials_weight:.3f})"
        )
class DipoleSingleLoss(torch.nn.Module):
    """Dipole-only MSE loss, scaled by 100 for magnitude adjustment."""

    def __init__(self, dipole_weight=1.0) -> None:
        super().__init__()
        self.register_buffer(
            "dipole_weight",
            torch.tensor(dipole_weight, dtype=torch.get_default_dtype()),
        )

    def forward(
        self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
    ) -> torch.Tensor:
        # Scale adjustment applied on top of the per-atom dipole MSE.
        scaled = weighted_mean_squared_error_dipole(ref, pred, ddp) * 100.0
        return self.dipole_weight * scaled

    def __repr__(self):
        return f"{self.__class__.__name__}(dipole_weight={self.dipole_weight:.3f})"
class WeightedEnergyForcesDipoleLoss(torch.nn.Module):
    """Weighted sum of energy MSE, force MSE, and the 100x-scaled dipole MSE."""

    def __init__(self, energy_weight=1.0, forces_weight=1.0, dipole_weight=1.0) -> None:
        super().__init__()
        dtype = torch.get_default_dtype()
        self.register_buffer("energy_weight", torch.tensor(energy_weight, dtype=dtype))
        self.register_buffer("forces_weight", torch.tensor(forces_weight, dtype=dtype))
        self.register_buffer("dipole_weight", torch.tensor(dipole_weight, dtype=dtype))

    def forward(
        self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
    ) -> torch.Tensor:
        e_term = weighted_mean_squared_error_energy(ref, pred, ddp)
        f_term = mean_squared_error_forces(ref, pred, ddp)
        d_term = weighted_mean_squared_error_dipole(ref, pred, ddp) * 100.0
        return (
            self.energy_weight * e_term
            + self.forces_weight * f_term
            + self.dipole_weight * d_term
        )

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, "
            f"forces_weight={self.forces_weight:.3f}, dipole_weight={self.dipole_weight:.3f})"
        )
class WeightedEnergyForcesL1L2Loss(torch.nn.Module):
    """Mixed-norm loss: L1 on per-atom energy, L2 vector norm on force errors."""

    def __init__(self, energy_weight=1.0, forces_weight=1.0) -> None:
        super().__init__()
        dtype = torch.get_default_dtype()
        self.register_buffer("energy_weight", torch.tensor(energy_weight, dtype=dtype))
        self.register_buffer("forces_weight", torch.tensor(forces_weight, dtype=dtype))

    def forward(
        self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None
    ) -> torch.Tensor:
        e_term = weighted_mean_absolute_error_energy(ref, pred, ddp)
        f_term = mean_normed_error_forces(ref, pred, ddp)
        return self.energy_weight * e_term + self.forces_weight * f_term

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, "
            f"forces_weight={self.forces_weight:.3f})"
        )
###########################################################################################
# Implementation of MACE models and other models based E(3)-Equivariant MPNNs
# Authors: Ilyes Batatia, Gregor Simm
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
from typing import Any, Callable, Dict, List, Optional, Type, Union
import numpy as np
import torch
from e3nn import o3
from e3nn.util.jit import compile_mode
from mace.modules.radial import ZBLBasis
from mace.tools.scatter import scatter_sum
from .blocks import (
AtomicEnergiesBlock,
EquivariantProductBasisBlock,
InteractionBlock,
LinearDipoleReadoutBlock,
LinearNodeEmbeddingBlock,
LinearReadoutBlock,
NonLinearDipoleReadoutBlock,
NonLinearReadoutBlock,
RadialEmbeddingBlock,
ScaleShiftBlock,
)
from .utils import (
compute_fixed_charge_dipole,
get_atomic_virials_stresses,
get_edge_vectors_and_lengths,
get_outputs,
get_symmetric_displacement,
prepare_graph,
)
# pylint: disable=C0302
@compile_mode("script")
class MACE(torch.nn.Module):
    """Higher-order equivariant message-passing interatomic potential (MACE).

    The model stacks ``num_interactions`` interaction / equivariant-product /
    readout layers. The total energy is the sum of the isolated-atom energies
    (E0), an optional ZBL pair-repulsion term, and one readout energy per
    layer. ``heads`` enables multi-head (multi-dataset) training; the
    ``lammps_*`` arguments support running under LAMMPS' ML-IAP interface.
    """

    def __init__(
        self,
        r_max: float,
        num_bessel: int,
        num_polynomial_cutoff: int,
        max_ell: int,
        interaction_cls: Type[InteractionBlock],
        interaction_cls_first: Type[InteractionBlock],
        num_interactions: int,
        num_elements: int,
        hidden_irreps: o3.Irreps,
        MLP_irreps: o3.Irreps,
        atomic_energies: np.ndarray,
        avg_num_neighbors: float,
        atomic_numbers: List[int],
        correlation: Union[int, List[int]],
        gate: Optional[Callable],
        pair_repulsion: bool = False,
        distance_transform: str = "None",
        radial_MLP: Optional[List[int]] = None,
        radial_type: Optional[str] = "bessel",
        heads: Optional[List[str]] = None,
        cueq_config: Optional[Dict[str, Any]] = None,
        lammps_mliap: Optional[bool] = False,
    ):
        super().__init__()
        # Buffers so these constants are saved in checkpoints and moved with
        # the module across devices.
        self.register_buffer(
            "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64)
        )
        self.register_buffer(
            "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype())
        )
        self.register_buffer(
            "num_interactions", torch.tensor(num_interactions, dtype=torch.int64)
        )
        if heads is None:
            heads = ["Default"]
        self.heads = heads
        # A scalar correlation applies uniformly to every layer.
        if isinstance(correlation, int):
            correlation = [correlation] * num_interactions
        self.lammps_mliap = lammps_mliap
        # Embedding
        node_attr_irreps = o3.Irreps([(num_elements, (0, 1))])
        node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))])
        self.node_embedding = LinearNodeEmbeddingBlock(
            irreps_in=node_attr_irreps,
            irreps_out=node_feats_irreps,
            cueq_config=cueq_config,
        )
        self.radial_embedding = RadialEmbeddingBlock(
            r_max=r_max,
            num_bessel=num_bessel,
            num_polynomial_cutoff=num_polynomial_cutoff,
            radial_type=radial_type,
            distance_transform=distance_transform,
        )
        edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e")
        if pair_repulsion:
            # Presence of the `pair_repulsion` attribute is what forward()
            # checks via hasattr to enable the ZBL term.
            self.pair_repulsion_fn = ZBLBasis(p=num_polynomial_cutoff)
            self.pair_repulsion = True
        sh_irreps = o3.Irreps.spherical_harmonics(max_ell)
        num_features = hidden_irreps.count(o3.Irrep(0, 1))
        interaction_irreps = (sh_irreps * num_features).sort()[0].simplify()
        self.spherical_harmonics = o3.SphericalHarmonics(
            sh_irreps, normalize=True, normalization="component"
        )
        if radial_MLP is None:
            radial_MLP = [64, 64, 64]
        # Interactions and readout
        self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies)
        inter = interaction_cls_first(
            node_attrs_irreps=node_attr_irreps,
            node_feats_irreps=node_feats_irreps,
            edge_attrs_irreps=sh_irreps,
            edge_feats_irreps=edge_feats_irreps,
            target_irreps=interaction_irreps,
            hidden_irreps=hidden_irreps,
            avg_num_neighbors=avg_num_neighbors,
            radial_MLP=radial_MLP,
            cueq_config=cueq_config,
        )
        self.interactions = torch.nn.ModuleList([inter])
        # Use the appropriate self connection at the first layer for proper E0
        use_sc_first = False
        if "Residual" in str(interaction_cls_first):
            use_sc_first = True
        node_feats_irreps_out = inter.target_irreps
        prod = EquivariantProductBasisBlock(
            node_feats_irreps=node_feats_irreps_out,
            target_irreps=hidden_irreps,
            correlation=correlation[0],
            num_elements=num_elements,
            use_sc=use_sc_first,
            cueq_config=cueq_config,
        )
        self.products = torch.nn.ModuleList([prod])
        self.readouts = torch.nn.ModuleList()
        # One scalar output per head from each layer's readout.
        self.readouts.append(
            LinearReadoutBlock(
                hidden_irreps, o3.Irreps(f"{len(heads)}x0e"), cueq_config
            )
        )
        for i in range(num_interactions - 1):
            if i == num_interactions - 2:
                hidden_irreps_out = str(
                    hidden_irreps[0]
                )  # Select only scalars for last layer
            else:
                hidden_irreps_out = hidden_irreps
            inter = interaction_cls(
                node_attrs_irreps=node_attr_irreps,
                node_feats_irreps=hidden_irreps,
                edge_attrs_irreps=sh_irreps,
                edge_feats_irreps=edge_feats_irreps,
                target_irreps=interaction_irreps,
                hidden_irreps=hidden_irreps_out,
                avg_num_neighbors=avg_num_neighbors,
                radial_MLP=radial_MLP,
                cueq_config=cueq_config,
            )
            self.interactions.append(inter)
            prod = EquivariantProductBasisBlock(
                node_feats_irreps=interaction_irreps,
                target_irreps=hidden_irreps_out,
                correlation=correlation[i + 1],
                num_elements=num_elements,
                use_sc=True,
                cueq_config=cueq_config,
            )
            self.products.append(prod)
            if i == num_interactions - 2:
                # Last layer gets a non-linear (gated MLP) readout.
                self.readouts.append(
                    NonLinearReadoutBlock(
                        hidden_irreps_out,
                        (len(heads) * MLP_irreps).simplify(),
                        gate,
                        o3.Irreps(f"{len(heads)}x0e"),
                        len(heads),
                        cueq_config,
                    )
                )
            else:
                self.readouts.append(
                    LinearReadoutBlock(
                        hidden_irreps, o3.Irreps(f"{len(heads)}x0e"), cueq_config
                    )
                )

    def forward(
        self,
        data: Dict[str, torch.Tensor],
        training: bool = False,
        compute_force: bool = True,
        compute_virials: bool = False,
        compute_stress: bool = False,
        compute_displacement: bool = False,
        compute_hessian: bool = False,
        compute_edge_forces: bool = False,
        compute_atomic_stresses: bool = False,
        lammps_mliap: bool = False,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Evaluate energies (and requested derivatives) for a batch of graphs.

        ``data`` must provide at least "node_attrs", "edge_index" and "batch"
        (read directly below); positions/cell handling is delegated to
        ``prepare_graph``. Returns a dict of energies, per-node energies,
        optional forces/virials/stress/hessian, and concatenated node
        features from all layers.
        """
        # Setup
        ctx = prepare_graph(
            data,
            compute_virials=compute_virials,
            compute_stress=compute_stress,
            compute_displacement=compute_displacement,
            lammps_mliap=lammps_mliap,
        )
        is_lammps = ctx.is_lammps
        num_atoms_arange = ctx.num_atoms_arange
        num_graphs = ctx.num_graphs
        displacement = ctx.displacement
        positions = ctx.positions
        vectors = ctx.vectors
        lengths = ctx.lengths
        cell = ctx.cell
        node_heads = ctx.node_heads
        interaction_kwargs = ctx.interaction_kwargs
        lammps_natoms = interaction_kwargs.lammps_natoms
        lammps_class = interaction_kwargs.lammps_class
        # Atomic energies: select each node's E0 for its assigned head.
        node_e0 = self.atomic_energies_fn(data["node_attrs"])[
            num_atoms_arange, node_heads
        ]
        e0 = scatter_sum(
            src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs
        )  # [n_graphs, n_heads]
        # Embeddings
        node_feats = self.node_embedding(data["node_attrs"])
        edge_attrs = self.spherical_harmonics(vectors)
        edge_feats = self.radial_embedding(
            lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers
        )
        if hasattr(self, "pair_repulsion"):
            pair_node_energy = self.pair_repulsion_fn(
                lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers
            )
            if is_lammps:
                # Keep only local atoms (ghost atoms are appended after them).
                pair_node_energy = pair_node_energy[: lammps_natoms[0]]
            pair_energy = scatter_sum(
                src=pair_node_energy, index=data["batch"], dim=-1, dim_size=num_graphs
            )  # [n_graphs,]
        else:
            pair_node_energy = torch.zeros_like(node_e0)
            pair_energy = torch.zeros_like(e0)
        # Interactions
        energies = [e0, pair_energy]
        node_energies_list = [node_e0, pair_node_energy]
        node_feats_concat: List[torch.Tensor] = []
        for i, (interaction, product, readout) in enumerate(
            zip(self.interactions, self.products, self.readouts)
        ):
            node_attrs_slice = data["node_attrs"]
            if is_lammps and i > 0:
                node_attrs_slice = node_attrs_slice[: lammps_natoms[0]]
            node_feats, sc = interaction(
                node_attrs=node_attrs_slice,
                node_feats=node_feats,
                edge_attrs=edge_attrs,
                edge_feats=edge_feats,
                edge_index=data["edge_index"],
                first_layer=(i == 0),
                lammps_class=lammps_class,
                lammps_natoms=lammps_natoms,
            )
            if is_lammps and i == 0:
                node_attrs_slice = node_attrs_slice[: lammps_natoms[0]]
            node_feats = product(
                node_feats=node_feats, sc=sc, node_attrs=node_attrs_slice
            )
            node_feats_concat.append(node_feats)
            # Per-node readout energy for each node's head.
            node_es = readout(node_feats, node_heads)[num_atoms_arange, node_heads]
            energy = scatter_sum(node_es, data["batch"], dim=0, dim_size=num_graphs)
            energies.append(energy)
            node_energies_list.append(node_es)
        contributions = torch.stack(energies, dim=-1)
        total_energy = torch.sum(contributions, dim=-1)
        node_energy = torch.sum(torch.stack(node_energies_list, dim=-1), dim=-1)
        node_feats_out = torch.cat(node_feats_concat, dim=-1)
        # NOTE(review): the next line overwrites the full per-node energy sum
        # computed two lines above, so the returned "node_energy" contains only
        # the E0 and pair-repulsion contributions — confirm this is intentional.
        node_energy = node_e0.double() + pair_node_energy.double()
        forces, virials, stress, hessian, edge_forces = get_outputs(
            energy=total_energy,
            positions=positions,
            displacement=displacement,
            vectors=vectors,
            cell=cell,
            training=training,
            compute_force=compute_force,
            compute_virials=compute_virials,
            compute_stress=compute_stress,
            compute_hessian=compute_hessian,
            compute_edge_forces=compute_edge_forces,
        )
        atomic_virials: Optional[torch.Tensor] = None
        atomic_stresses: Optional[torch.Tensor] = None
        if compute_atomic_stresses and edge_forces is not None:
            atomic_virials, atomic_stresses = get_atomic_virials_stresses(
                edge_forces=edge_forces,
                edge_index=data["edge_index"],
                vectors=vectors,
                num_atoms=positions.shape[0],
                batch=data["batch"],
                cell=cell,
            )
        return {
            "energy": total_energy,
            "node_energy": node_energy,
            "contributions": contributions,
            "forces": forces,
            "edge_forces": edge_forces,
            "virials": virials,
            "stress": stress,
            "atomic_virials": atomic_virials,
            "atomic_stresses": atomic_stresses,
            "displacement": displacement,
            "hessian": hessian,
            "node_feats": node_feats_out,
        }
@compile_mode("script")
class ScaleShiftMACE(MACE):
def __init__(
self,
atomic_inter_scale: float,
atomic_inter_shift: float,
**kwargs,
):
super().__init__(**kwargs)
self.scale_shift = ScaleShiftBlock(
scale=atomic_inter_scale, shift=atomic_inter_shift
)
def forward(
self,
data: Dict[str, torch.Tensor],
training: bool = False,
compute_force: bool = True,
compute_virials: bool = False,
compute_stress: bool = False,
compute_displacement: bool = False,
compute_hessian: bool = False,
compute_edge_forces: bool = False,
compute_atomic_stresses: bool = False,
lammps_mliap: bool = False,
) -> Dict[str, Optional[torch.Tensor]]:
# Setup
ctx = prepare_graph(
data,
compute_virials=compute_virials,
compute_stress=compute_stress,
compute_displacement=compute_displacement,
lammps_mliap=lammps_mliap,
)
is_lammps = ctx.is_lammps
num_atoms_arange = ctx.num_atoms_arange
num_graphs = ctx.num_graphs
displacement = ctx.displacement
positions = ctx.positions
vectors = ctx.vectors
lengths = ctx.lengths
cell = ctx.cell
node_heads = ctx.node_heads
interaction_kwargs = ctx.interaction_kwargs
lammps_natoms = interaction_kwargs.lammps_natoms
lammps_class = interaction_kwargs.lammps_class
# Atomic energies
node_e0 = self.atomic_energies_fn(data["node_attrs"])[
num_atoms_arange, node_heads
]
e0 = scatter_sum(
src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs
) # [n_graphs, num_heads]
# Embeddings
node_feats = self.node_embedding(data["node_attrs"])
edge_attrs = self.spherical_harmonics(vectors)
edge_feats = self.radial_embedding(
lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers
)
if hasattr(self, "pair_repulsion"):
pair_node_energy = self.pair_repulsion_fn(
lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers
)
if is_lammps:
pair_node_energy = pair_node_energy[: lammps_natoms[0]]
else:
pair_node_energy = torch.zeros_like(node_e0)
# Interactions
node_es_list = [pair_node_energy]
node_feats_list: List[torch.Tensor] = []
for i, (interaction, product, readout) in enumerate(
zip(self.interactions, self.products, self.readouts)
):
node_attrs_slice = data["node_attrs"]
if is_lammps and i > 0:
node_attrs_slice = node_attrs_slice[: lammps_natoms[0]]
node_feats, sc = interaction(
node_attrs=node_attrs_slice,
node_feats=node_feats,
edge_attrs=edge_attrs,
edge_feats=edge_feats,
edge_index=data["edge_index"],
first_layer=(i == 0),
lammps_class=lammps_class,
lammps_natoms=lammps_natoms,
)
if is_lammps and i == 0:
node_attrs_slice = node_attrs_slice[: lammps_natoms[0]]
node_feats = product(
node_feats=node_feats, sc=sc, node_attrs=node_attrs_slice
)
node_feats_list.append(node_feats)
node_es_list.append(
readout(node_feats, node_heads)[num_atoms_arange, node_heads]
)
node_feats_out = torch.cat(node_feats_list, dim=-1)
node_inter_es = torch.sum(torch.stack(node_es_list, dim=0), dim=0)
node_inter_es = self.scale_shift(node_inter_es, node_heads)
inter_e = scatter_sum(node_inter_es, data["batch"], dim=-1, dim_size=num_graphs)
total_energy = e0 + inter_e
node_energy = node_e0.clone().double() + node_inter_es.clone().double()
forces, virials, stress, hessian, edge_forces = get_outputs(
energy=inter_e,
positions=positions,
displacement=displacement,
vectors=vectors,
cell=cell,
training=training,
compute_force=compute_force,
compute_virials=compute_virials,
compute_stress=compute_stress,
compute_hessian=compute_hessian,
compute_edge_forces=compute_edge_forces or compute_atomic_stresses,
)
atomic_virials: Optional[torch.Tensor] = None
atomic_stresses: Optional[torch.Tensor] = None
if compute_atomic_stresses and edge_forces is not None:
atomic_virials, atomic_stresses = get_atomic_virials_stresses(
edge_forces=edge_forces,
edge_index=data["edge_index"],
vectors=vectors,
num_atoms=positions.shape[0],
batch=data["batch"],
cell=cell,
)
return {
"energy": total_energy,
"node_energy": node_energy,
"interaction_energy": inter_e,
"forces": forces,
"edge_forces": edge_forces,
"virials": virials,
"stress": stress,
"atomic_virials": atomic_virials,
"atomic_stresses": atomic_stresses,
"hessian": hessian,
"displacement": displacement,
"node_feats": node_feats_out,
}
@compile_mode("script")
class AtomicDipolesMACE(torch.nn.Module):
def __init__(
self,
r_max: float,
num_bessel: int,
num_polynomial_cutoff: int,
max_ell: int,
interaction_cls: Type[InteractionBlock],
interaction_cls_first: Type[InteractionBlock],
num_interactions: int,
num_elements: int,
hidden_irreps: o3.Irreps,
MLP_irreps: o3.Irreps,
avg_num_neighbors: float,
atomic_numbers: List[int],
correlation: int,
gate: Optional[Callable],
atomic_energies: Optional[
None
], # Just here to make it compatible with energy models, MUST be None
radial_type: Optional[str] = "bessel",
radial_MLP: Optional[List[int]] = None,
cueq_config: Optional[Dict[str, Any]] = None, # pylint: disable=unused-argument
):
super().__init__()
self.register_buffer(
"atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64)
)
self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64))
self.register_buffer(
"num_interactions", torch.tensor(num_interactions, dtype=torch.int64)
)
assert atomic_energies is None
# Embedding
node_attr_irreps = o3.Irreps([(num_elements, (0, 1))])
node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))])
self.node_embedding = LinearNodeEmbeddingBlock(
irreps_in=node_attr_irreps, irreps_out=node_feats_irreps
)
self.radial_embedding = RadialEmbeddingBlock(
r_max=r_max,
num_bessel=num_bessel,
num_polynomial_cutoff=num_polynomial_cutoff,
radial_type=radial_type,
)
edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e")
sh_irreps = o3.Irreps.spherical_harmonics(max_ell)
num_features = hidden_irreps.count(o3.Irrep(0, 1))
interaction_irreps = (sh_irreps * num_features).sort()[0].simplify()
self.spherical_harmonics = o3.SphericalHarmonics(
sh_irreps, normalize=True, normalization="component"
)
if radial_MLP is None:
radial_MLP = [64, 64, 64]
# Interactions and readouts
inter = interaction_cls_first(
node_attrs_irreps=node_attr_irreps,
node_feats_irreps=node_feats_irreps,
edge_attrs_irreps=sh_irreps,
edge_feats_irreps=edge_feats_irreps,
target_irreps=interaction_irreps,
hidden_irreps=hidden_irreps,
avg_num_neighbors=avg_num_neighbors,
radial_MLP=radial_MLP,
)
self.interactions = torch.nn.ModuleList([inter])
# Use the appropriate self connection at the first layer
use_sc_first = False
if "Residual" in str(interaction_cls_first):
use_sc_first = True
node_feats_irreps_out = inter.target_irreps
prod = EquivariantProductBasisBlock(
node_feats_irreps=node_feats_irreps_out,
target_irreps=hidden_irreps,
correlation=correlation,
num_elements=num_elements,
use_sc=use_sc_first,
)
self.products = torch.nn.ModuleList([prod])
self.readouts = torch.nn.ModuleList()
self.readouts.append(LinearDipoleReadoutBlock(hidden_irreps, dipole_only=True))
for i in range(num_interactions - 1):
if i == num_interactions - 2:
assert (
len(hidden_irreps) > 1
), "To predict dipoles use at least l=1 hidden_irreps"
hidden_irreps_out = str(
hidden_irreps[1]
) # Select only l=1 vectors for last layer
else:
hidden_irreps_out = hidden_irreps
inter = interaction_cls(
node_attrs_irreps=node_attr_irreps,
node_feats_irreps=hidden_irreps,
edge_attrs_irreps=sh_irreps,
edge_feats_irreps=edge_feats_irreps,
target_irreps=interaction_irreps,
hidden_irreps=hidden_irreps_out,
avg_num_neighbors=avg_num_neighbors,
radial_MLP=radial_MLP,
)
self.interactions.append(inter)
prod = EquivariantProductBasisBlock(
node_feats_irreps=interaction_irreps,
target_irreps=hidden_irreps_out,
correlation=correlation,
num_elements=num_elements,
use_sc=True,
)
self.products.append(prod)
if i == num_interactions - 2:
self.readouts.append(
NonLinearDipoleReadoutBlock(
hidden_irreps_out, MLP_irreps, gate, dipole_only=True
)
)
else:
self.readouts.append(
LinearDipoleReadoutBlock(hidden_irreps, dipole_only=True)
)
def forward(
self,
data: Dict[str, torch.Tensor],
training: bool = False, # pylint: disable=W0613
compute_force: bool = False,
compute_virials: bool = False,
compute_stress: bool = False,
compute_displacement: bool = False,
compute_edge_forces: bool = False, # pylint: disable=W0613
compute_atomic_stresses: bool = False, # pylint: disable=W0613
) -> Dict[str, Optional[torch.Tensor]]:
assert compute_force is False
assert compute_virials is False
assert compute_stress is False
assert compute_displacement is False
# Setup
data["node_attrs"].requires_grad_(True)
data["positions"].requires_grad_(True)
num_graphs = data["ptr"].numel() - 1
# Embeddings
node_feats = self.node_embedding(data["node_attrs"])
vectors, lengths = get_edge_vectors_and_lengths(
positions=data["positions"],
edge_index=data["edge_index"],
shifts=data["shifts"],
)
edge_attrs = self.spherical_harmonics(vectors)
edge_feats = self.radial_embedding(
lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers
)
# Interactions
dipoles = []
for interaction, product, readout in zip(
self.interactions, self.products, self.readouts
):
node_feats, sc = interaction(
node_attrs=data["node_attrs"],
node_feats=node_feats,
edge_attrs=edge_attrs,
edge_feats=edge_feats,
edge_index=data["edge_index"],
)
node_feats = product(
node_feats=node_feats,
sc=sc,
node_attrs=data["node_attrs"],
)
node_dipoles = readout(node_feats).squeeze(-1) # [n_nodes,3]
dipoles.append(node_dipoles)
# Compute the dipoles
contributions_dipoles = torch.stack(
dipoles, dim=-1
) # [n_nodes,3,n_contributions]
atomic_dipoles = torch.sum(contributions_dipoles, dim=-1) # [n_nodes,3]
total_dipole = scatter_sum(
src=atomic_dipoles,
index=data["batch"],
dim=0,
dim_size=num_graphs,
) # [n_graphs,3]
baseline = compute_fixed_charge_dipole(
charges=data["charges"],
positions=data["positions"],
batch=data["batch"],
num_graphs=num_graphs,
) # [n_graphs,3]
total_dipole = total_dipole + baseline
output = {
"dipole": total_dipole,
"atomic_dipoles": atomic_dipoles,
}
return output
@compile_mode("script")
class EnergyDipolesMACE(torch.nn.Module):
def __init__(
self,
r_max: float,
num_bessel: int,
num_polynomial_cutoff: int,
max_ell: int,
interaction_cls: Type[InteractionBlock],
interaction_cls_first: Type[InteractionBlock],
num_interactions: int,
num_elements: int,
hidden_irreps: o3.Irreps,
MLP_irreps: o3.Irreps,
avg_num_neighbors: float,
atomic_numbers: List[int],
correlation: int,
gate: Optional[Callable],
atomic_energies: Optional[np.ndarray],
radial_MLP: Optional[List[int]] = None,
cueq_config: Optional[Dict[str, Any]] = None, # pylint: disable=unused-argument
):
super().__init__()
self.register_buffer(
"atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64)
)
self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64))
self.register_buffer(
"num_interactions", torch.tensor(num_interactions, dtype=torch.int64)
)
# Embedding
node_attr_irreps = o3.Irreps([(num_elements, (0, 1))])
node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))])
self.node_embedding = LinearNodeEmbeddingBlock(
irreps_in=node_attr_irreps, irreps_out=node_feats_irreps
)
self.radial_embedding = RadialEmbeddingBlock(
r_max=r_max,
num_bessel=num_bessel,
num_polynomial_cutoff=num_polynomial_cutoff,
)
edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e")
sh_irreps = o3.Irreps.spherical_harmonics(max_ell)
num_features = hidden_irreps.count(o3.Irrep(0, 1))
interaction_irreps = (sh_irreps * num_features).sort()[0].simplify()
self.spherical_harmonics = o3.SphericalHarmonics(
sh_irreps, normalize=True, normalization="component"
)
if radial_MLP is None:
radial_MLP = [64, 64, 64]
# Interactions and readouts
self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies)
inter = interaction_cls_first(
node_attrs_irreps=node_attr_irreps,
node_feats_irreps=node_feats_irreps,
edge_attrs_irreps=sh_irreps,
edge_feats_irreps=edge_feats_irreps,
target_irreps=interaction_irreps,
hidden_irreps=hidden_irreps,
avg_num_neighbors=avg_num_neighbors,
radial_MLP=radial_MLP,
)
self.interactions = torch.nn.ModuleList([inter])
# Use the appropriate self connection at the first layer
use_sc_first = False
if "Residual" in str(interaction_cls_first):
use_sc_first = True
node_feats_irreps_out = inter.target_irreps
prod = EquivariantProductBasisBlock(
node_feats_irreps=node_feats_irreps_out,
target_irreps=hidden_irreps,
correlation=correlation,
num_elements=num_elements,
use_sc=use_sc_first,
)
self.products = torch.nn.ModuleList([prod])
self.readouts = torch.nn.ModuleList()
self.readouts.append(LinearDipoleReadoutBlock(hidden_irreps, dipole_only=False))
for i in range(num_interactions - 1):
if i == num_interactions - 2:
assert (
len(hidden_irreps) > 1
), "To predict dipoles use at least l=1 hidden_irreps"
hidden_irreps_out = str(
hidden_irreps[:2]
) # Select scalars and l=1 vectors for last layer
else:
hidden_irreps_out = hidden_irreps
inter = interaction_cls(
node_attrs_irreps=node_attr_irreps,
node_feats_irreps=hidden_irreps,
edge_attrs_irreps=sh_irreps,
edge_feats_irreps=edge_feats_irreps,
target_irreps=interaction_irreps,
hidden_irreps=hidden_irreps_out,
avg_num_neighbors=avg_num_neighbors,
radial_MLP=radial_MLP,
)
self.interactions.append(inter)
prod = EquivariantProductBasisBlock(
node_feats_irreps=interaction_irreps,
target_irreps=hidden_irreps_out,
correlation=correlation,
num_elements=num_elements,
use_sc=True,
)
self.products.append(prod)
if i == num_interactions - 2:
self.readouts.append(
NonLinearDipoleReadoutBlock(
hidden_irreps_out, MLP_irreps, gate, dipole_only=False
)
)
else:
self.readouts.append(
LinearDipoleReadoutBlock(hidden_irreps, dipole_only=False)
)
def forward(
self,
data: Dict[str, torch.Tensor],
training: bool = False,
compute_force: bool = True,
compute_virials: bool = False,
compute_stress: bool = False,
compute_displacement: bool = False,
compute_edge_forces: bool = False, # pylint: disable=W0613
compute_atomic_stresses: bool = False, # pylint: disable=W0613
) -> Dict[str, Optional[torch.Tensor]]:
# Setup
data["node_attrs"].requires_grad_(True)
data["positions"].requires_grad_(True)
num_graphs = data["ptr"].numel() - 1
num_atoms_arange = torch.arange(data["positions"].shape[0])
displacement = torch.zeros(
(num_graphs, 3, 3),
dtype=data["positions"].dtype,
device=data["positions"].device,
)
if compute_virials or compute_stress or compute_displacement:
(
data["positions"],
data["shifts"],
displacement,
) = get_symmetric_displacement(
positions=data["positions"],
unit_shifts=data["unit_shifts"],
cell=data["cell"],
edge_index=data["edge_index"],
num_graphs=num_graphs,
batch=data["batch"],
)
# Atomic energies
node_e0 = self.atomic_energies_fn(data["node_attrs"])[
num_atoms_arange, data["head"][data["batch"]]
]
e0 = scatter_sum(
src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs
) # [n_graphs,]
# Embeddings
node_feats = self.node_embedding(data["node_attrs"])
vectors, lengths = get_edge_vectors_and_lengths(
positions=data["positions"],
edge_index=data["edge_index"],
shifts=data["shifts"],
)
edge_attrs = self.spherical_harmonics(vectors)
edge_feats = self.radial_embedding(
lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers
)
# Interactions
energies = [e0]
node_energies_list = [node_e0]
dipoles = []
for interaction, product, readout in zip(
self.interactions, self.products, self.readouts
):
node_feats, sc = interaction(
node_attrs=data["node_attrs"],
node_feats=node_feats,
edge_attrs=edge_attrs,
edge_feats=edge_feats,
edge_index=data["edge_index"],
)
node_feats = product(
node_feats=node_feats,
sc=sc,
node_attrs=data["node_attrs"],
)
node_out = readout(node_feats).squeeze(-1) # [n_nodes, ]
# node_energies = readout(node_feats).squeeze(-1) # [n_nodes, ]
node_energies = node_out[:, 0]
energy = scatter_sum(
src=node_energies, index=data["batch"], dim=-1, dim_size=num_graphs
) # [n_graphs,]
energies.append(energy)
node_dipoles = node_out[:, 1:]
dipoles.append(node_dipoles)
# Compute the energies and dipoles
contributions = torch.stack(energies, dim=-1)
total_energy = torch.sum(contributions, dim=-1) # [n_graphs, ]
node_energy_contributions = torch.stack(node_energies_list, dim=-1)
node_energy = torch.sum(node_energy_contributions, dim=-1) # [n_nodes, ]
contributions_dipoles = torch.stack(
dipoles, dim=-1
) # [n_nodes,3,n_contributions]
atomic_dipoles = torch.sum(contributions_dipoles, dim=-1) # [n_nodes,3]
total_dipole = scatter_sum(
src=atomic_dipoles,
index=data["batch"].unsqueeze(-1),
dim=0,
dim_size=num_graphs,
) # [n_graphs,3]
baseline = compute_fixed_charge_dipole(
charges=data["charges"],
positions=data["positions"],
batch=data["batch"],
num_graphs=num_graphs,
) # [n_graphs,3]
total_dipole = total_dipole + baseline
forces, virials, stress, _, _ = get_outputs(
energy=total_energy,
positions=data["positions"],
displacement=displacement,
cell=data["cell"],
training=training,
compute_force=compute_force,
compute_virials=compute_virials,
compute_stress=compute_stress,
)
output = {
"energy": total_energy,
"node_energy": node_energy,
"contributions": contributions,
"forces": forces,
"virials": virials,
"stress": stress,
"displacement": displacement,
"dipole": total_dipole,
"atomic_dipoles": atomic_dipoles,
}
return output
###########################################################################################
# Radial basis and cutoff
# Authors: Ilyes Batatia, Gregor Simm
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
import logging
import ase
import numpy as np
import torch
from e3nn.util.jit import compile_mode
from mace.tools.scatter import scatter_sum
@compile_mode("script")
class BesselBasis(torch.nn.Module):
"""
Equation (7)
"""
def __init__(self, r_max: float, num_basis=8, trainable=False):
super().__init__()
bessel_weights = (
np.pi
/ r_max
* torch.linspace(
start=1.0,
end=num_basis,
steps=num_basis,
dtype=torch.get_default_dtype(),
)
)
if trainable:
self.bessel_weights = torch.nn.Parameter(bessel_weights)
else:
self.register_buffer("bessel_weights", bessel_weights)
self.register_buffer(
"r_max", torch.tensor(r_max, dtype=torch.get_default_dtype())
)
self.register_buffer(
"prefactor",
torch.tensor(np.sqrt(2.0 / r_max), dtype=torch.get_default_dtype()),
)
def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1]
numerator = torch.sin(self.bessel_weights * x) # [..., num_basis]
return self.prefactor * (numerator / x)
def __repr__(self):
return (
f"{self.__class__.__name__}(r_max={self.r_max}, num_basis={len(self.bessel_weights)}, "
f"trainable={self.bessel_weights.requires_grad})"
)
@compile_mode("script")
class ChebychevBasis(torch.nn.Module):
"""
Equation (7)
"""
def __init__(self, r_max: float, num_basis=8):
super().__init__()
self.register_buffer(
"n",
torch.arange(1, num_basis + 1, dtype=torch.get_default_dtype()).unsqueeze(
0
),
)
self.num_basis = num_basis
self.r_max = r_max
def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1]
x = x.repeat(1, self.num_basis)
n = self.n.repeat(len(x), 1)
return torch.special.chebyshev_polynomial_t(x, n)
def __repr__(self):
return (
f"{self.__class__.__name__}(r_max={self.r_max}, num_basis={self.num_basis},"
)
@compile_mode("script")
class GaussianBasis(torch.nn.Module):
"""
Gaussian basis functions
"""
def __init__(self, r_max: float, num_basis=128, trainable=False):
super().__init__()
gaussian_weights = torch.linspace(
start=0.0, end=r_max, steps=num_basis, dtype=torch.get_default_dtype()
)
if trainable:
self.gaussian_weights = torch.nn.Parameter(
gaussian_weights, requires_grad=True
)
else:
self.register_buffer("gaussian_weights", gaussian_weights)
self.coeff = -0.5 / (r_max / (num_basis - 1)) ** 2
def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1]
x = x - self.gaussian_weights
return torch.exp(self.coeff * torch.pow(x, 2))
@compile_mode("script")
class PolynomialCutoff(torch.nn.Module):
"""Polynomial cutoff function that goes from 1 to 0 as x goes from 0 to r_max.
Equation (8) -- TODO: from where?
"""
p: torch.Tensor
r_max: torch.Tensor
def __init__(self, r_max: float, p=6):
super().__init__()
self.register_buffer("p", torch.tensor(p, dtype=torch.int))
self.register_buffer(
"r_max", torch.tensor(r_max, dtype=torch.get_default_dtype())
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.calculate_envelope(x, self.r_max, self.p.to(torch.int))
@staticmethod
def calculate_envelope(
x: torch.Tensor, r_max: torch.Tensor, p: torch.Tensor
) -> torch.Tensor:
r_over_r_max = x / r_max
envelope = (
1.0
- ((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(r_over_r_max, p)
+ p * (p + 2.0) * torch.pow(r_over_r_max, p + 1)
- (p * (p + 1.0) / 2) * torch.pow(r_over_r_max, p + 2)
)
return envelope * (x < r_max)
def __repr__(self):
return f"{self.__class__.__name__}(p={self.p}, r_max={self.r_max})"
@compile_mode("script")
class ZBLBasis(torch.nn.Module):
"""Implementation of the Ziegler-Biersack-Littmark (ZBL) potential
with a polynomial cutoff envelope.
"""
p: torch.Tensor
def __init__(self, p=6, trainable=False, **kwargs):
super().__init__()
if "r_max" in kwargs:
logging.warning(
"r_max is deprecated. r_max is determined from the covalent radii."
)
# Pre-calculate the p coefficients for the ZBL potential
self.register_buffer(
"c",
torch.tensor(
[0.1818, 0.5099, 0.2802, 0.02817], dtype=torch.get_default_dtype()
),
)
self.register_buffer("p", torch.tensor(p, dtype=torch.int))
self.register_buffer(
"covalent_radii",
torch.tensor(
ase.data.covalent_radii,
dtype=torch.get_default_dtype(),
),
)
if trainable:
self.a_exp = torch.nn.Parameter(torch.tensor(0.300, requires_grad=True))
self.a_prefactor = torch.nn.Parameter(
torch.tensor(0.4543, requires_grad=True)
)
else:
self.register_buffer("a_exp", torch.tensor(0.300))
self.register_buffer("a_prefactor", torch.tensor(0.4543))
def forward(
self,
x: torch.Tensor,
node_attrs: torch.Tensor,
edge_index: torch.Tensor,
atomic_numbers: torch.Tensor,
) -> torch.Tensor:
sender = edge_index[0]
receiver = edge_index[1]
node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze(
-1
)
Z_u = node_atomic_numbers[sender]
Z_v = node_atomic_numbers[receiver]
a = (
self.a_prefactor
* 0.529
/ (torch.pow(Z_u, self.a_exp) + torch.pow(Z_v, self.a_exp))
)
r_over_a = x / a
phi = (
self.c[0] * torch.exp(-3.2 * r_over_a)
+ self.c[1] * torch.exp(-0.9423 * r_over_a)
+ self.c[2] * torch.exp(-0.4028 * r_over_a)
+ self.c[3] * torch.exp(-0.2016 * r_over_a)
)
v_edges = (14.3996 * Z_u * Z_v) / x * phi
r_max = self.covalent_radii[Z_u] + self.covalent_radii[Z_v]
envelope = PolynomialCutoff.calculate_envelope(x, r_max, self.p)
v_edges = 0.5 * v_edges * envelope
V_ZBL = scatter_sum(v_edges, receiver, dim=0, dim_size=node_attrs.size(0))
return V_ZBL.squeeze(-1)
def __repr__(self):
return f"{self.__class__.__name__}(c={self.c})"
@compile_mode("script")
class AgnesiTransform(torch.nn.Module):
"""Agnesi transform - see section on Radial transformations in
ACEpotentials.jl, JCP 2023 (https://doi.org/10.1063/5.0158783).
"""
def __init__(
self,
q: float = 0.9183,
p: float = 4.5791,
a: float = 1.0805,
trainable=False,
):
super().__init__()
self.register_buffer("q", torch.tensor(q, dtype=torch.get_default_dtype()))
self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype()))
self.register_buffer("a", torch.tensor(a, dtype=torch.get_default_dtype()))
self.register_buffer(
"covalent_radii",
torch.tensor(
ase.data.covalent_radii,
dtype=torch.get_default_dtype(),
),
)
if trainable:
self.a = torch.nn.Parameter(torch.tensor(1.0805, requires_grad=True))
self.q = torch.nn.Parameter(torch.tensor(0.9183, requires_grad=True))
self.p = torch.nn.Parameter(torch.tensor(4.5791, requires_grad=True))
def forward(
self,
x: torch.Tensor,
node_attrs: torch.Tensor,
edge_index: torch.Tensor,
atomic_numbers: torch.Tensor,
) -> torch.Tensor:
sender = edge_index[0]
receiver = edge_index[1]
node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze(
-1
)
Z_u = node_atomic_numbers[sender]
Z_v = node_atomic_numbers[receiver]
r_0: torch.Tensor = 0.5 * (self.covalent_radii[Z_u] + self.covalent_radii[Z_v])
r_over_r_0 = x / r_0
return (
1
+ (
self.a
* torch.pow(r_over_r_0, self.q)
/ (1 + torch.pow(r_over_r_0, self.q - self.p))
)
).reciprocal_()
def __repr__(self):
return (
f"{self.__class__.__name__}(a={self.a:.4f}, q={self.q:.4f}, p={self.p:.4f})"
)
@compile_mode("script")
class SoftTransform(torch.nn.Module):
"""
Tanh-based smooth transformation:
T(x) = p1 + (x - p1)*0.5*[1 + tanh(alpha*(x - m))],
which smoothly transitions from ~p1 for x << p1 to ~x for x >> r0.
"""
def __init__(self, alpha: float = 4.0, trainable=False):
"""
Args:
p1 (float): Lower "clamp" point.
alpha (float): Steepness; if None, defaults to ~6/(r0-p1).
trainable (bool): Whether to make parameters trainable.
"""
super().__init__()
# Initialize parameters
self.register_buffer(
"alpha", torch.tensor(alpha, dtype=torch.get_default_dtype())
)
if trainable:
self.alpha = torch.nn.Parameter(self.alpha.clone())
self.register_buffer(
"covalent_radii",
torch.tensor(
ase.data.covalent_radii,
dtype=torch.get_default_dtype(),
),
)
def compute_r_0(
self,
node_attrs: torch.Tensor,
edge_index: torch.Tensor,
atomic_numbers: torch.Tensor,
) -> torch.Tensor:
"""
Compute r_0 based on atomic information.
Args:
node_attrs (torch.Tensor): Node attributes (one-hot encoding of atomic numbers).
edge_index (torch.Tensor): Edge index indicating connections.
atomic_numbers (torch.Tensor): Atomic numbers.
Returns:
torch.Tensor: r_0 values for each edge.
"""
sender = edge_index[0]
receiver = edge_index[1]
node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze(
-1
)
Z_u = node_atomic_numbers[sender]
Z_v = node_atomic_numbers[receiver]
r_0: torch.Tensor = self.covalent_radii[Z_u] + self.covalent_radii[Z_v]
return r_0
def forward(
self,
x: torch.Tensor,
node_attrs: torch.Tensor,
edge_index: torch.Tensor,
atomic_numbers: torch.Tensor,
) -> torch.Tensor:
r_0 = self.compute_r_0(node_attrs, edge_index, atomic_numbers)
p_0 = (3 / 4) * r_0
p_1 = (4 / 3) * r_0
m = 0.5 * (p_0 + p_1)
alpha = self.alpha / (p_1 - p_0)
s_x = 0.5 * (1.0 + torch.tanh(alpha * (x - m)))
return p_0 + (x - p_0) * s_x
def __repr__(self):
return f"{self.__class__.__name__}(alpha={self.alpha.item():.4f})"
###########################################################################################
# Implementation of the symmetric contraction algorithm presented in the MACE paper
# (Batatia et al, MACE: Higher Order Equivariant Message Passing Neural Networks for Fast and Accurate Force Fields , Eq.10 and 11)
# Authors: Ilyes Batatia
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
from typing import Dict, Optional, Union
import opt_einsum_fx
import torch
import torch.fx
from e3nn import o3
from e3nn.util.codegen import CodeGenMixin
from e3nn.util.jit import compile_mode
from mace.tools.cg import U_matrix_real
# Example batch size used only to build representative tensors when
# opt_einsum_fx optimizes the contraction paths (not a runtime constraint).
BATCH_EXAMPLE = 10
# Spare einsum subscript letters for the variadic contraction dimensions;
# "b", "c", "e", "i", "k" are reserved by the hard-coded equation parts below.
ALPHABET = ["w", "x", "v", "n", "z", "r", "t", "y", "u", "o", "p", "s"]
@compile_mode("script")
class SymmetricContraction(CodeGenMixin, torch.nn.Module):
def __init__(
self,
irreps_in: o3.Irreps,
irreps_out: o3.Irreps,
correlation: Union[int, Dict[str, int]],
irrep_normalization: str = "component",
path_normalization: str = "element",
internal_weights: Optional[bool] = None,
shared_weights: Optional[bool] = None,
num_elements: Optional[int] = None,
) -> None:
super().__init__()
if irrep_normalization is None:
irrep_normalization = "component"
if path_normalization is None:
path_normalization = "element"
assert irrep_normalization in ["component", "norm", "none"]
assert path_normalization in ["element", "path", "none"]
self.irreps_in = o3.Irreps(irreps_in)
self.irreps_out = o3.Irreps(irreps_out)
del irreps_in, irreps_out
if not isinstance(correlation, tuple):
corr = correlation
correlation = {}
for irrep_out in self.irreps_out:
correlation[irrep_out] = corr
assert shared_weights or not internal_weights
if internal_weights is None:
internal_weights = True
self.internal_weights = internal_weights
self.shared_weights = shared_weights
del internal_weights, shared_weights
self.contractions = torch.nn.ModuleList()
for irrep_out in self.irreps_out:
self.contractions.append(
Contraction(
irreps_in=self.irreps_in,
irrep_out=o3.Irreps(str(irrep_out.ir)),
correlation=correlation[irrep_out],
internal_weights=self.internal_weights,
num_elements=num_elements,
weights=self.shared_weights,
)
)
def forward(self, x: torch.Tensor, y: torch.Tensor):
outs = [contraction(x, y) for contraction in self.contractions]
return torch.cat(outs, dim=-1)
@compile_mode("script")
class Contraction(torch.nn.Module):
def __init__(
self,
irreps_in: o3.Irreps,
irrep_out: o3.Irreps,
correlation: int,
internal_weights: bool = True,
num_elements: Optional[int] = None,
weights: Optional[torch.Tensor] = None,
) -> None:
super().__init__()
self.num_features = irreps_in.count((0, 1))
self.coupling_irreps = o3.Irreps([irrep.ir for irrep in irreps_in])
self.correlation = correlation
dtype = torch.get_default_dtype()
for nu in range(1, correlation + 1):
U_matrix = U_matrix_real(
irreps_in=self.coupling_irreps,
irreps_out=irrep_out,
correlation=nu,
dtype=dtype,
)[-1]
self.register_buffer(f"U_matrix_{nu}", U_matrix)
# Tensor contraction equations
self.contractions_weighting = torch.nn.ModuleList()
self.contractions_features = torch.nn.ModuleList()
# Create weight for product basis
self.weights = torch.nn.ParameterList([])
for i in range(correlation, 0, -1):
# Shapes definying
num_params = self.U_tensors(i).size()[-1]
num_equivariance = 2 * irrep_out.lmax + 1
num_ell = self.U_tensors(i).size()[-2]
if i == correlation:
parse_subscript_main = (
[ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1) - 1)]
+ ["ik,ekc,bci,be -> bc"]
+ [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1) - 1)]
)
graph_module_main = torch.fx.symbolic_trace(
lambda x, y, w, z: torch.einsum(
"".join(parse_subscript_main), x, y, w, z
)
)
# Optimizing the contractions
self.graph_opt_main = opt_einsum_fx.optimize_einsums_full(
model=graph_module_main,
example_inputs=(
torch.randn(
[num_equivariance] + [num_ell] * i + [num_params]
).squeeze(0),
torch.randn((num_elements, num_params, self.num_features)),
torch.randn((BATCH_EXAMPLE, self.num_features, num_ell)),
torch.randn((BATCH_EXAMPLE, num_elements)),
),
)
# Parameters for the product basis
w = torch.nn.Parameter(
torch.randn((num_elements, num_params, self.num_features))
/ num_params
)
self.weights_max = w
else:
# Generate optimized contractions equations
parse_subscript_weighting = (
[ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1))]
+ ["k,ekc,be->bc"]
+ [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1))]
)
parse_subscript_features = (
["bc"]
+ [ALPHABET[j] for j in range(i - 1 + min(irrep_out.lmax, 1))]
+ ["i,bci->bc"]
+ [ALPHABET[j] for j in range(i - 1 + min(irrep_out.lmax, 1))]
)
# Symbolic tracing of contractions
graph_module_weighting = torch.fx.symbolic_trace(
lambda x, y, z: torch.einsum(
"".join(parse_subscript_weighting), x, y, z
)
)
graph_module_features = torch.fx.symbolic_trace(
lambda x, y: torch.einsum("".join(parse_subscript_features), x, y)
)
# Optimizing the contractions
graph_opt_weighting = opt_einsum_fx.optimize_einsums_full(
model=graph_module_weighting,
example_inputs=(
torch.randn(
[num_equivariance] + [num_ell] * i + [num_params]
).squeeze(0),
torch.randn((num_elements, num_params, self.num_features)),
torch.randn((BATCH_EXAMPLE, num_elements)),
),
)
graph_opt_features = opt_einsum_fx.optimize_einsums_full(
model=graph_module_features,
example_inputs=(
torch.randn(
[BATCH_EXAMPLE, self.num_features, num_equivariance]
+ [num_ell] * i
).squeeze(2),
torch.randn((BATCH_EXAMPLE, self.num_features, num_ell)),
),
)
self.contractions_weighting.append(graph_opt_weighting)
self.contractions_features.append(graph_opt_features)
# Parameters for the product basis
w = torch.nn.Parameter(
torch.randn((num_elements, num_params, self.num_features))
/ num_params
)
self.weights.append(w)
if not internal_weights:
self.weights = weights[:-1]
self.weights_max = weights[-1]
    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """Evaluate the contraction over all correlation orders.

        Args:
            x: feature tensor contracted against the U matrices
               (assumed [batch, features, ...] — TODO confirm against callers).
            y: per-node attribute tensor used to select element weights
               (presumably a one-hot over elements — verify against caller).
        Returns:
            Output features flattened to [batch, -1].
        """
        # Start from the highest correlation order using the pre-optimized
        # main einsum graph and its dedicated weights.
        out = self.graph_opt_main(
            self.U_tensors(self.correlation),
            self.weights_max,
            x,
            y,
        )
        # Walk down the remaining correlation orders: each step contracts the
        # U matrix with weights and attributes, adds the running result, then
        # contracts against the features.
        for i, (weight, contract_weights, contract_features) in enumerate(
            zip(self.weights, self.contractions_weighting, self.contractions_features)
        ):
            c_tensor = contract_weights(
                self.U_tensors(self.correlation - i - 1),
                weight,
                y,
            )
            c_tensor = c_tensor + out
            out = contract_features(c_tensor, x)
        return out.view(out.shape[0], -1)
    def U_tensors(self, nu: int):
        """Return the registered buffer ``U_matrix_{nu}`` for correlation order ``nu``.

        Lookup goes through ``named_buffers()`` rather than ``getattr`` —
        presumably for TorchScript compatibility; confirm before changing.
        """
        return dict(self.named_buffers())[f"U_matrix_{nu}"]
###########################################################################################
# Utilities
# Authors: Ilyes Batatia, Gregor Simm and David Kovacs
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
import logging
from typing import Dict, List, NamedTuple, Optional, Tuple
import numpy as np
import torch
import torch.utils.data
from scipy.constants import c, e
from mace.tools import to_numpy
from mace.tools.scatter import scatter_mean, scatter_std, scatter_sum
from mace.tools.torch_geometric.batch import Batch
from .blocks import AtomicEnergiesBlock
def compute_forces(
    energy: torch.Tensor, positions: torch.Tensor, training: bool = True
) -> torch.Tensor:
    """Return forces as the negative gradient of ``energy`` w.r.t. ``positions``.

    Args:
        energy: total energies, shape [n_graphs].
        positions: atomic positions, shape [n_nodes, 3]; must require grad.
        training: keep/create the autograd graph so higher derivatives work.
    Returns:
        Forces of shape [n_nodes, 3]; zeros if positions are unused.
    """
    grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)]
    grads = torch.autograd.grad(
        outputs=[energy],  # [n_graphs, ]
        inputs=[positions],  # [n_nodes, 3]
        grad_outputs=grad_outputs,
        retain_graph=training,  # keep graph alive during training
        create_graph=training,  # needed for second derivatives
        allow_unused=True,  # positions may be disconnected (full dissociation)
    )
    dE_dr = grads[0]
    if dE_dr is None:
        return torch.zeros_like(positions)
    return -1 * dE_dr
def compute_forces_virials(
    energy: torch.Tensor,
    positions: torch.Tensor,
    displacement: torch.Tensor,
    cell: torch.Tensor,
    training: bool = True,
    compute_stress: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Differentiate ``energy`` w.r.t. positions and strain displacement.

    Returns:
        (forces, virials, stress): forces [n_nodes, 3], virials [n_graphs, 3, 3]
        (negated gradients), and stress (zeros unless ``compute_stress``).
    """
    grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)]
    dE_dr, dE_deps = torch.autograd.grad(
        outputs=[energy],  # [n_graphs, ]
        inputs=[positions, displacement],  # [n_nodes, 3], [n_graphs, 3, 3]
        grad_outputs=grad_outputs,
        retain_graph=training,  # keep graph alive during training
        create_graph=training,  # needed for second derivatives
        allow_unused=True,
    )
    stress = torch.zeros_like(displacement)
    if compute_stress and dE_deps is not None:
        cell = cell.view(-1, 3, 3)
        volume = torch.linalg.det(cell).abs().unsqueeze(-1)
        stress = dE_deps / volume.view(-1, 1, 1)
        # Guard against absurd values from near-zero volumes.
        stress = torch.where(torch.abs(stress) < 1e10, stress, torch.zeros_like(stress))
    if dE_dr is None:
        dE_dr = torch.zeros_like(positions)
    if dE_deps is None:
        dE_deps = torch.zeros((1, 3, 3))
    return -1 * dE_dr, -1 * dE_deps, stress
def get_symmetric_displacement(
    positions: torch.Tensor,
    unit_shifts: torch.Tensor,
    cell: Optional[torch.Tensor],
    edge_index: torch.Tensor,
    num_graphs: int,
    batch: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Insert a zero, grad-tracked strain displacement into positions and cell.

    Returns (strained positions, edge shifts, displacement). Differentiating
    the energy w.r.t. the returned displacement yields virials.
    """
    if cell is None:
        # No periodic cell supplied: use an all-zero one.
        cell = torch.zeros(
            num_graphs * 3,
            3,
            dtype=positions.dtype,
            device=positions.device,
        )
    sender = edge_index[0]
    displacement = torch.zeros(
        (num_graphs, 3, 3),
        dtype=positions.dtype,
        device=positions.device,
    )
    displacement.requires_grad_(True)
    # Only the symmetric (strain) part enters the geometry.
    # From https://github.com/mir-group/nequip
    strain = 0.5 * (displacement + displacement.transpose(-1, -2))
    positions = positions + torch.einsum("be,bec->bc", positions, strain[batch])
    cell = cell.view(-1, 3, 3)
    cell = cell + torch.matmul(cell, strain)
    shifts = torch.einsum(
        "be,bec->bc",
        unit_shifts,
        cell[batch[sender]],
    )
    return positions, shifts, displacement
@torch.jit.unused
def compute_hessians_vmap(
    forces: torch.Tensor,
    positions: torch.Tensor,
) -> torch.Tensor:
    """Hessian of the energy via vmapped VJPs over d(-forces)/d(positions).

    Falls back to the explicit row loop if vmap raises, and to zeros if no
    gradient is available. Returns shape [n_force_components, n_atoms, 3].
    """
    flat_forces = forces.view(-1)
    n = flat_forces.shape[0]

    def vjp_fn(basis_row):
        # One Hessian row: VJP of -forces against a standard basis vector.
        return torch.autograd.grad(
            -1 * flat_forces,
            positions,
            basis_row,
            retain_graph=True,
            create_graph=False,
            allow_unused=False,
        )

    identity = torch.eye(n).to(forces.device)
    try:
        # Chunk small systems one row at a time, larger ones in blocks of 16.
        chunk = 1 if n < 64 else 16
        hessian = torch.vmap(vjp_fn, in_dims=0, out_dims=0, chunk_size=chunk)(
            identity
        )[0]
    except RuntimeError:
        hessian = compute_hessians_loop(forces, positions)
    if hessian is None:
        return torch.zeros((positions.shape[0], forces.shape[0], 3, 3))
    return hessian
@torch.jit.unused
def compute_hessians_loop(
    forces: torch.Tensor,
    positions: torch.Tensor,
) -> torch.Tensor:
    """Row-by-row Hessian computation; memory-light fallback for the vmap path.

    Returns shape [n_force_components, n_atoms, 3].
    """
    hessian = []
    for grad_elem in forces.view(-1):
        hess_row = torch.autograd.grad(
            outputs=[-1 * grad_elem],
            inputs=[positions],
            grad_outputs=torch.ones_like(grad_elem),
            retain_graph=True,
            create_graph=False,
            allow_unused=False,
        )[0]
        # Bug fix: check for None BEFORE detaching. The original called
        # .detach() first, so a None gradient would raise AttributeError and
        # the zero-row fallback branch was unreachable.
        if hess_row is None:
            hessian.append(torch.zeros_like(positions))
        else:
            # Detach to save memory (slower, but needs less memory).
            hessian.append(hess_row.detach())
    return torch.stack(hessian)
def get_outputs(
    energy: torch.Tensor,
    positions: torch.Tensor,
    cell: torch.Tensor,
    displacement: Optional[torch.Tensor],
    vectors: Optional[torch.Tensor] = None,
    training: bool = False,
    compute_force: bool = True,
    compute_virials: bool = True,
    compute_stress: bool = True,
    compute_hessian: bool = False,
    compute_edge_forces: bool = False,
) -> Tuple[
    Optional[torch.Tensor],
    Optional[torch.Tensor],
    Optional[torch.Tensor],
    Optional[torch.Tensor],
    Optional[torch.Tensor],
]:
    """Dispatch the requested derivative computations for one forward pass.

    Returns (forces, virials, stress, hessian, edge_forces); entries that
    were not requested (or could not be computed) are None.
    """
    forces: Optional[torch.Tensor] = None
    virials: Optional[torch.Tensor] = None
    stress: Optional[torch.Tensor] = None
    # Keep the autograd graph alive whenever a later derivative needs it.
    keep_graph = training or compute_hessian or compute_edge_forces
    if (compute_virials or compute_stress) and displacement is not None:
        forces, virials, stress = compute_forces_virials(
            energy=energy,
            positions=positions,
            displacement=displacement,
            cell=cell,
            compute_stress=compute_stress,
            training=keep_graph,
        )
    elif compute_force:
        forces = compute_forces(
            energy=energy,
            positions=positions,
            training=keep_graph,
        )
    hessian: Optional[torch.Tensor] = None
    if compute_hessian:
        assert forces is not None, "Forces must be computed to get the hessian"
        hessian = compute_hessians_vmap(forces, positions)
    edge_forces: Optional[torch.Tensor] = None
    if compute_edge_forces and vectors is not None:
        edge_forces = compute_forces(
            energy=energy,
            positions=vectors,
            training=(training or compute_hessian),
        )
        if edge_forces is not None:
            edge_forces = -1 * edge_forces  # Match LAMMPS sign convention
    return forces, virials, stress, hessian, edge_forces
def get_atomic_virials_stresses(
    edge_forces: torch.Tensor,  # [n_edges, 3]
    edge_index: torch.Tensor,  # [2, n_edges]
    vectors: torch.Tensor,  # [n_edges, 3]
    num_atoms: int,
    batch: torch.Tensor,
    cell: torch.Tensor,  # [n_graphs, 3, 3]
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Per-atom virials and stresses from per-edge forces and vectors.

    From pobo95 PR #528.

    Returns:
        Tuple of atomic virials [num_atoms, 3, 3] and atomic stresses
        [num_atoms, 3, 3].
    """
    # Outer product of each edge force with its edge vector.
    outer = torch.einsum("zi,zj->zij", edge_forces, vectors)
    # Split each edge contribution evenly between its two endpoint atoms.
    per_atom = (
        scatter_sum(src=outer, index=edge_index[0], dim=0, dim_size=num_atoms)
        + scatter_sum(src=outer, index=edge_index[1], dim=0, dim_size=num_atoms)
    ) / 2
    # Symmetrize the per-atom virial tensor.
    per_atom = (per_atom + per_atom.transpose(-1, -2)) / 2
    cells = cell.view(-1, 3, 3)
    volumes = torch.linalg.det(cells).abs().unsqueeze(-1)
    per_atom_stress = per_atom / volumes[batch].view(-1, 1, 1)
    # Guard against absurd values from near-zero volumes.
    per_atom_stress = torch.where(
        torch.abs(per_atom_stress) < 1e10,
        per_atom_stress,
        torch.zeros_like(per_atom_stress),
    )
    return -1 * per_atom, per_atom_stress
def get_edge_vectors_and_lengths(
    positions: torch.Tensor,  # [n_nodes, 3]
    edge_index: torch.Tensor,  # [2, n_edges]
    shifts: torch.Tensor,  # [n_edges, 3]
    normalize: bool = False,
    eps: float = 1e-9,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Edge displacement vectors (receiver - sender + shift) and their norms.

    Returns (vectors [n_edges, 3], lengths [n_edges, 1]); vectors are unit
    vectors when ``normalize`` (eps avoids division by zero).
    """
    src, dst = edge_index[0], edge_index[1]
    vectors = positions[dst] - positions[src] + shifts  # [n_edges, 3]
    lengths = torch.linalg.norm(vectors, dim=-1, keepdim=True)  # [n_edges, 1]
    if not normalize:
        return vectors, lengths
    return vectors / (lengths + eps), lengths
def _check_non_zero(std):
if np.any(std == 0):
logging.warning(
"Standard deviation of the scaling is zero, Changing to no scaling"
)
std[std == 0] = 1
return std
def extract_invariant(x: torch.Tensor, num_layers: int, num_features: int, l_max: int):
    """Gather the invariant (l=0) channels of every layer from a concatenated
    feature tensor and concatenate them along the last dimension."""
    # Width of one layer's full feature block: (l_max + 1)^2 spherical
    # components times num_features channels.
    block = (l_max + 1) ** 2 * num_features
    pieces = [x[:, :num_features]]
    pieces.extend(
        x[:, layer * block : layer * block + num_features]
        for layer in range(1, num_layers)
    )
    return torch.cat(pieces, dim=-1)
def compute_mean_std_atomic_inter_energy(
    data_loader: torch.utils.data.DataLoader,
    atomic_energies: np.ndarray,
) -> Tuple[float, float]:
    """Per-head mean and std of the interaction energy per atom over a dataset.

    Returns (mean, std), each a numpy array with one entry per head; zero
    stds are replaced by 1.
    """
    e0_block = AtomicEnergiesBlock(atomic_energies=atomic_energies)
    per_atom_energies = []
    heads = []
    for batch in data_loader:
        node_e0 = e0_block(batch.node_attrs)
        # E0 of each graph, selected for that graph's head.
        e0_per_graph = scatter_sum(
            src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs
        )[torch.arange(batch.num_graphs), batch.head]
        n_atoms = batch.ptr[1:] - batch.ptr[:-1]
        per_atom_energies.append((batch.energy - e0_per_graph) / n_atoms)  # [n_graphs]
        heads.append(batch.head)
    energies = torch.cat(per_atom_energies)  # [total_n_graphs]
    head_idx = torch.cat(heads, dim=0)  # [total_n_graphs]
    mean = to_numpy(scatter_mean(src=energies, index=head_idx, dim=0).squeeze(-1))
    std = to_numpy(scatter_std(src=energies, index=head_idx, dim=0).squeeze(-1))
    return mean, _check_non_zero(std)
def _compute_mean_std_atomic_inter_energy(
    batch: Batch,
    atomic_energies_fn: AtomicEnergiesBlock,
) -> torch.Tensor:
    """Per-graph interaction energy per atom for a single batch.

    The return annotation previously claimed ``Tuple[Tensor, Tensor]`` but a
    single tensor is returned; the annotation is corrected here.
    """
    head = batch.head
    node_e0 = atomic_energies_fn(batch.node_attrs)
    # E0 of each graph, selected for that graph's head.
    graph_e0s = scatter_sum(
        src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs
    )[torch.arange(batch.num_graphs), head]
    graph_sizes = batch.ptr[1:] - batch.ptr[:-1]
    # Interaction energy per atom: (E - E0) / N_atoms for each graph.
    return (batch.energy - graph_e0s) / graph_sizes
def compute_mean_rms_energy_forces(
    data_loader: torch.utils.data.DataLoader,
    atomic_energies: np.ndarray,
) -> Tuple[float, float]:
    """Per-head mean interaction energy per atom and RMS force over a dataset.

    Returns (mean, rms), each a numpy array with one entry per head; zero
    RMS values are replaced by 1.
    """
    e0_block = AtomicEnergiesBlock(atomic_energies=atomic_energies)
    per_atom_energies = []
    all_forces = []
    graph_heads = []
    node_heads = []
    for batch in data_loader:
        head = batch.head
        node_e0 = e0_block(batch.node_attrs)
        # E0 of each graph, selected for that graph's head.
        e0_per_graph = scatter_sum(
            src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs
        )[torch.arange(batch.num_graphs), head]
        n_atoms = batch.ptr[1:] - batch.ptr[:-1]
        per_atom_energies.append((batch.energy - e0_per_graph) / n_atoms)  # [n_graphs]
        all_forces.append(batch.forces)  # [n_graphs * n_atoms, 3]
        graph_heads.append(head)  # [n_graphs]
        node_heads.append(head[batch.batch])  # [n_nodes]
    energies = torch.cat(per_atom_energies, dim=0)  # [total_n_graphs]
    forces = torch.cat(all_forces, dim=0)  # [total_n_atoms, 3]
    head_per_graph = torch.cat(graph_heads, dim=0)  # [total_n_graphs]
    head_per_node = torch.cat(node_heads, dim=0)  # [total_n_atoms]
    mean = to_numpy(scatter_mean(src=energies, index=head_per_graph, dim=0).squeeze(-1))
    rms = to_numpy(
        torch.sqrt(
            scatter_mean(src=torch.square(forces), index=head_per_node, dim=0).mean(-1)
        )
    )
    return mean, _check_non_zero(rms)
def _compute_mean_rms_energy_forces(
    batch: Batch,
    atomic_energies_fn: AtomicEnergiesBlock,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Single-batch version: per-graph interaction energy per atom plus the
    raw force tensor."""
    # E0 of each graph, selected for that graph's head.
    e0_per_graph = scatter_sum(
        src=atomic_energies_fn(batch.node_attrs),
        index=batch.batch,
        dim=0,
        dim_size=batch.num_graphs,
    )[torch.arange(batch.num_graphs), batch.head]
    n_atoms = batch.ptr[1:] - batch.ptr[:-1]
    per_atom_energy = (batch.energy - e0_per_graph) / n_atoms  # [n_graphs]
    return per_atom_energy, batch.forces  # forces: [n_graphs * n_atoms, 3]
def compute_avg_num_neighbors(data_loader: torch.utils.data.DataLoader) -> float:
    """Average receiver degree (neighbors per atom) over the whole dataset."""
    degree_counts = [
        torch.unique(batch.edge_index[1], return_counts=True)[1]
        for batch in data_loader
    ]
    avg = torch.mean(
        torch.cat(degree_counts, dim=0).type(torch.get_default_dtype())
    )
    return to_numpy(avg).item()
def compute_statistics(
    data_loader: torch.utils.data.DataLoader,
    atomic_energies: np.ndarray,
) -> Tuple[float, np.ndarray, np.ndarray]:
    """One-pass dataset statistics: average neighbor count, per-head mean
    interaction energy per atom, and per-head RMS force.

    The return annotation previously claimed four floats; the function returns
    three values (a float plus two per-head numpy arrays), so it is corrected.
    """
    atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies)
    atom_energy_list = []
    forces_list = []
    num_neighbors = []
    head_list = []
    head_batch = []
    for batch in data_loader:
        head = batch.head
        node_e0 = atomic_energies_fn(batch.node_attrs)
        # E0 of each graph, selected for that graph's head.
        graph_e0s = scatter_sum(
            src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs
        )[torch.arange(batch.num_graphs), head]
        graph_sizes = batch.ptr[1:] - batch.ptr[:-1]
        atom_energy_list.append(
            (batch.energy - graph_e0s) / graph_sizes
        )  # [n_graphs]
        forces_list.append(batch.forces)  # [n_graphs * n_atoms, 3]
        head_list.append(head)  # [n_graphs]
        head_batch.append(head[batch.batch])  # [n_nodes]
        _, receivers = batch.edge_index
        _, counts = torch.unique(receivers, return_counts=True)
        num_neighbors.append(counts)
    atom_energies = torch.cat(atom_energy_list, dim=0)  # [total_n_graphs]
    forces = torch.cat(forces_list, dim=0)  # [total_n_atoms, 3]
    head = torch.cat(head_list, dim=0)  # [total_n_graphs]
    head_batch = torch.cat(head_batch, dim=0)  # [total_n_atoms]
    mean = to_numpy(scatter_mean(src=atom_energies, index=head, dim=0).squeeze(-1))
    rms = to_numpy(
        torch.sqrt(
            scatter_mean(src=torch.square(forces), index=head_batch, dim=0).mean(-1)
        )
    )
    avg_num_neighbors = torch.mean(
        torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype())
    )
    return to_numpy(avg_num_neighbors).item(), mean, rms
def compute_rms_dipoles(
    data_loader: torch.utils.data.DataLoader,
) -> float:
    """RMS over all dipole components in the dataset.

    The return annotation previously claimed ``Tuple[float, float]`` but a
    single scalar is returned; the annotation is corrected here.
    """
    dipoles_list = []
    for batch in data_loader:
        dipoles_list.append(batch.dipole)  # [n_graphs, 3]
    dipoles = torch.cat(dipoles_list, dim=0)  # [total_n_graphs, 3]
    rms = to_numpy(torch.sqrt(torch.mean(torch.square(dipoles)))).item()
    # Guard against an all-zero RMS (would otherwise yield no scaling).
    rms = _check_non_zero(rms)
    return rms
def compute_fixed_charge_dipole(
    charges: torch.Tensor,
    positions: torch.Tensor,
    batch: torch.Tensor,
    num_graphs: int,
) -> torch.Tensor:
    """Per-graph dipole moment from fixed point charges.

    Returns a [N_graphs, 3] tensor.
    """
    # assumes 1e-11 / c / e converts from e*position units to the target
    # dipole units — TODO confirm the intended unit system
    conversion = 1e-11 / c / e
    atomic_dipoles = positions * charges.unsqueeze(-1) / conversion  # [N_atoms, 3]
    return scatter_sum(
        src=atomic_dipoles, index=batch.unsqueeze(-1), dim=0, dim_size=num_graphs
    )  # [N_graphs, 3]
class InteractionKwargs(NamedTuple):
    """Optional extra arguments forwarded to interaction blocks.

    Populated only on the LAMMPS ML-IAP path (see ``prepare_graph``); the
    defaults are inert otherwise.
    """

    lammps_class: Optional[torch.Tensor]
    lammps_natoms: Tuple[int, int] = (0, 0)  # (n_real, n_total) atom counts
class GraphContext(NamedTuple):
    """Bundle of graph-level quantities produced by ``prepare_graph``."""

    is_lammps: bool  # True when running on the LAMMPS ML-IAP path
    num_graphs: int
    num_atoms_arange: torch.Tensor  # arange over the (real) atoms
    displacement: Optional[torch.Tensor]  # strain tensor; None on LAMMPS path
    positions: torch.Tensor
    vectors: torch.Tensor  # per-edge vectors (grad-tracked on LAMMPS path)
    lengths: torch.Tensor  # per-edge vector norms
    cell: torch.Tensor
    node_heads: torch.Tensor  # per-node head index
    interaction_kwargs: InteractionKwargs
def prepare_graph(
    data: Dict[str, torch.Tensor],
    compute_virials: bool = False,
    compute_stress: bool = False,
    compute_displacement: bool = False,
    lammps_mliap: bool = False,
) -> GraphContext:
    """Prepare geometric inputs for a model forward pass.

    Enables gradient tracking on positions (or on the precomputed edge
    vectors in the LAMMPS ML-IAP path), optionally inserts a symmetric strain
    displacement for virial/stress computation, and packages everything into
    a GraphContext. NOTE: mutates ``data`` in place on the standard path
    ("positions"/"shifts" are replaced when a displacement is inserted).
    """
    if torch.jit.is_scripting():
        # The LAMMPS branch is not scriptable; force the standard path.
        lammps_mliap = False
    # Per-node head index; defaults to head 0 when no "head" field is present.
    node_heads = (
        data["head"][data["batch"]]
        if "head" in data
        else torch.zeros_like(data["batch"])
    )
    if lammps_mliap:
        # LAMMPS ML-IAP path: edge vectors are supplied directly by LAMMPS.
        n_real, n_total = data["natoms"][0], data["natoms"][1]
        num_graphs = 2  # NOTE(review): hard-coded — presumably real/ghost partitions; confirm
        num_atoms_arange = torch.arange(n_real, device=data["node_attrs"].device)
        displacement = None
        # Placeholder positions: gradients are taken w.r.t. vectors instead.
        positions = torch.zeros(
            (int(n_real), 3),
            dtype=data["vectors"].dtype,
            device=data["vectors"].device,
        )
        cell = torch.zeros(
            (num_graphs, 3, 3),
            dtype=data["vectors"].dtype,
            device=data["vectors"].device,
        )
        vectors = data["vectors"].requires_grad_(True)
        lengths = torch.linalg.vector_norm(vectors, dim=1, keepdim=True)
        ikw = InteractionKwargs(data["lammps_class"], (n_real, n_total))
    else:
        # Standard path: differentiate w.r.t. atomic positions.
        data["positions"].requires_grad_(True)
        positions = data["positions"]
        cell = data["cell"]
        num_atoms_arange = torch.arange(positions.shape[0], device=positions.device)
        num_graphs = int(data["ptr"].numel() - 1)
        displacement = torch.zeros(
            (num_graphs, 3, 3), dtype=positions.dtype, device=positions.device
        )
        if compute_virials or compute_stress or compute_displacement:
            # Insert a grad-tracked symmetric strain into positions/shifts so
            # virials/stress can be obtained by differentiation.
            p, s, displacement = get_symmetric_displacement(
                positions=positions,
                unit_shifts=data["unit_shifts"],
                cell=cell,
                edge_index=data["edge_index"],
                num_graphs=num_graphs,
                batch=data["batch"],
            )
            data["positions"], data["shifts"] = p, s
        vectors, lengths = get_edge_vectors_and_lengths(
            positions=data["positions"],
            edge_index=data["edge_index"],
            shifts=data["shifts"],
        )
        ikw = InteractionKwargs(None, (0, 0))
    return GraphContext(
        is_lammps=lammps_mliap,
        num_graphs=num_graphs,
        num_atoms_arange=num_atoms_arange,
        displacement=displacement,
        positions=positions,
        vectors=vectors,
        lengths=lengths,
        cell=cell,
        node_heads=node_heads,
        interaction_kwargs=ikw,
    )
"""
Wrapper class for o3.Linear that optionally uses cuet.Linear
"""
import dataclasses
from typing import List, Optional
import torch
from e3nn import o3
from mace.modules.symmetric_contraction import SymmetricContraction
from mace.tools.cg import O3_e3nn
try:
import cuequivariance as cue
import cuequivariance_torch as cuet
CUET_AVAILABLE = True
except ImportError:
CUET_AVAILABLE = False
@dataclasses.dataclass
class CuEquivarianceConfig:
    """Configuration for cuequivariance acceleration"""

    enabled: bool = False
    layout: str = "mul_ir"  # One of: mul_ir, ir_mul
    layout_str: str = "mul_ir"  # keeps the original layout string after resolution
    group: str = "O3"
    optimize_all: bool = False  # Set to True to enable all optimizations
    optimize_linear: bool = False
    optimize_channelwise: bool = False
    optimize_symmetric: bool = False
    optimize_fctp: bool = False

    def __post_init__(self):
        # When active, resolve the string fields into live cuequivariance
        # objects: `layout` becomes a cue layout object (its original string
        # is preserved in `layout_str`) and `group` becomes a cue group, or
        # the e3nn-compatible O3 wrapper for "O3_e3nn".
        if self.enabled and CUET_AVAILABLE:
            self.layout_str = self.layout
            self.layout = getattr(cue, self.layout)
            self.group = (
                O3_e3nn if self.group == "O3_e3nn" else getattr(cue, self.group)
            )
        # Silently disable acceleration when cuequivariance is not installed.
        if not CUET_AVAILABLE:
            self.enabled = False
class Linear:
    """Factory: returns cuet.Linear when cuequivariance is enabled for linear
    layers, otherwise a plain o3.Linear."""

    def __new__(
        cls,
        irreps_in: o3.Irreps,
        irreps_out: o3.Irreps,
        shared_weights: bool = True,
        internal_weights: bool = True,
        cueq_config: Optional[CuEquivarianceConfig] = None,
    ):
        use_cueq = (
            CUET_AVAILABLE
            and cueq_config is not None
            and cueq_config.enabled
            and (cueq_config.optimize_all or cueq_config.optimize_linear)
        )
        if use_cueq:
            return cuet.Linear(
                cue.Irreps(cueq_config.group, irreps_in),
                cue.Irreps(cueq_config.group, irreps_out),
                layout=cueq_config.layout,
                shared_weights=shared_weights,
                use_fallback=True,
            )
        return o3.Linear(
            irreps_in,
            irreps_out,
            shared_weights=shared_weights,
            internal_weights=internal_weights,
        )
class TensorProduct:
    """Wrapper around o3.TensorProduct/cuet.ChannelwiseTensorProduct"""

    def __new__(
        cls,
        irreps_in1: o3.Irreps,
        irreps_in2: o3.Irreps,
        irreps_out: o3.Irreps,
        instructions: Optional[List] = None,
        shared_weights: bool = False,
        internal_weights: bool = False,
        cueq_config: Optional[CuEquivarianceConfig] = None,
    ):
        # Use the accelerated channel-wise product only when cuequivariance is
        # installed and explicitly enabled for channelwise products.
        if (
            CUET_AVAILABLE
            and cueq_config is not None
            and cueq_config.enabled
            and (cueq_config.optimize_all or cueq_config.optimize_channelwise)
        ):
            # NOTE(review): unlike Linear and FullyConnectedTensorProduct,
            # no use_fallback=True is passed here — confirm this is intentional.
            return cuet.ChannelWiseTensorProduct(
                cue.Irreps(cueq_config.group, irreps_in1),
                cue.Irreps(cueq_config.group, irreps_in2),
                cue.Irreps(cueq_config.group, irreps_out),
                layout=cueq_config.layout,
                shared_weights=shared_weights,
                internal_weights=internal_weights,
                dtype=torch.get_default_dtype(),
                math_dtype=torch.get_default_dtype(),
            )
        # Fallback: plain e3nn tensor product (instructions only apply here).
        return o3.TensorProduct(
            irreps_in1,
            irreps_in2,
            irreps_out,
            instructions=instructions,
            shared_weights=shared_weights,
            internal_weights=internal_weights,
        )
class FullyConnectedTensorProduct:
    """Factory: returns cuet.FullyConnectedTensorProduct when cuequivariance
    is enabled for fully connected products, otherwise o3's implementation."""

    def __new__(
        cls,
        irreps_in1: o3.Irreps,
        irreps_in2: o3.Irreps,
        irreps_out: o3.Irreps,
        shared_weights: bool = True,
        internal_weights: bool = True,
        cueq_config: Optional[CuEquivarianceConfig] = None,
    ):
        use_cueq = (
            CUET_AVAILABLE
            and cueq_config is not None
            and cueq_config.enabled
            and (cueq_config.optimize_all or cueq_config.optimize_fctp)
        )
        if not use_cueq:
            return o3.FullyConnectedTensorProduct(
                irreps_in1,
                irreps_in2,
                irreps_out,
                shared_weights=shared_weights,
                internal_weights=internal_weights,
            )
        return cuet.FullyConnectedTensorProduct(
            cue.Irreps(cueq_config.group, irreps_in1),
            cue.Irreps(cueq_config.group, irreps_in2),
            cue.Irreps(cueq_config.group, irreps_out),
            layout=cueq_config.layout,
            shared_weights=shared_weights,
            internal_weights=internal_weights,
            use_fallback=True,
        )
class SymmetricContractionWrapper:
    """Wrapper around SymmetricContraction/cuet.SymmetricContraction"""

    def __new__(
        cls,
        irreps_in: o3.Irreps,
        irreps_out: o3.Irreps,
        correlation: int,
        num_elements: Optional[int] = None,
        cueq_config: Optional[CuEquivarianceConfig] = None,
    ):
        # Use the accelerated contraction only when cuequivariance is
        # installed and explicitly enabled for symmetric contractions.
        if (
            CUET_AVAILABLE
            and cueq_config is not None
            and cueq_config.enabled
            and (cueq_config.optimize_all or cueq_config.optimize_symmetric)
        ):
            return cuet.SymmetricContraction(
                cue.Irreps(cueq_config.group, irreps_in),
                cue.Irreps(cueq_config.group, irreps_out),
                # NOTE(review): input layout is fixed to ir_mul while the
                # output follows the configured layout — presumably to match
                # the MACE feature layout; confirm before changing.
                layout_in=cue.ir_mul,
                layout_out=cueq_config.layout,
                contraction_degree=correlation,
                num_elements=num_elements,
                original_mace=True,  # reproduce original MACE weight layout
                dtype=torch.get_default_dtype(),
                math_dtype=torch.get_default_dtype(),
            )
        # Fallback: the pure e3nn/PyTorch implementation.
        return SymmetricContraction(
            irreps_in=irreps_in,
            irreps_out=irreps_out,
            correlation=correlation,
            num_elements=num_elements,
        )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment