from typing import Callable, Dict, Optional, Type
import torch
from .blocks import (
AtomicEnergiesBlock,
EquivariantProductBasisBlock,
InteractionBlock,
LinearDipoleReadoutBlock,
LinearNodeEmbeddingBlock,
LinearReadoutBlock,
NonLinearDipoleReadoutBlock,
NonLinearReadoutBlock,
RadialEmbeddingBlock,
RealAgnosticAttResidualInteractionBlock,
RealAgnosticDensityInteractionBlock,
RealAgnosticDensityResidualInteractionBlock,
RealAgnosticInteractionBlock,
RealAgnosticResidualInteractionBlock,
ScaleShiftBlock,
)
from .loss import (
DipoleSingleLoss,
UniversalLoss,
WeightedEnergyForcesDipoleLoss,
WeightedEnergyForcesL1L2Loss,
WeightedEnergyForcesLoss,
WeightedEnergyForcesStressLoss,
WeightedEnergyForcesVirialsLoss,
WeightedForcesLoss,
WeightedHuberEnergyForcesStressLoss,
)
from .models import MACE, AtomicDipolesMACE, EnergyDipolesMACE, ScaleShiftMACE
from .radial import BesselBasis, GaussianBasis, PolynomialCutoff, ZBLBasis
from .symmetric_contraction import SymmetricContraction
from .utils import (
compute_avg_num_neighbors,
compute_fixed_charge_dipole,
compute_mean_rms_energy_forces,
compute_mean_std_atomic_inter_energy,
compute_rms_dipoles,
compute_statistics,
)
interaction_classes: Dict[str, Type[InteractionBlock]] = {
"RealAgnosticResidualInteractionBlock": RealAgnosticResidualInteractionBlock,
"RealAgnosticAttResidualInteractionBlock": RealAgnosticAttResidualInteractionBlock,
"RealAgnosticInteractionBlock": RealAgnosticInteractionBlock,
"RealAgnosticDensityInteractionBlock": RealAgnosticDensityInteractionBlock,
"RealAgnosticDensityResidualInteractionBlock": RealAgnosticDensityResidualInteractionBlock,
}
scaling_classes: Dict[str, Callable] = {
"std_scaling": compute_mean_std_atomic_inter_energy,
"rms_forces_scaling": compute_mean_rms_energy_forces,
"rms_dipoles_scaling": compute_rms_dipoles,
}
gate_dict: Dict[str, Optional[Callable]] = {
"abs": torch.abs,
"tanh": torch.tanh,
"silu": torch.nn.functional.silu,
"None": None,
}
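# Usage sketch (illustrative, not part of this module's API): training scripts
# typically resolve string-valued config options through the registries above
# instead of importing classes directly, e.g.
#
#   interaction_cls = interaction_classes["RealAgnosticResidualInteractionBlock"]
#   scaling_fn = scaling_classes["rms_forces_scaling"]
#   gate = gate_dict["silu"]  # the key "None" maps to None for a purely linear readout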
__all__ = [
"AtomicEnergiesBlock",
"RadialEmbeddingBlock",
"ZBLBasis",
"LinearNodeEmbeddingBlock",
"LinearReadoutBlock",
"EquivariantProductBasisBlock",
"ScaleShiftBlock",
"LinearDipoleReadoutBlock",
"NonLinearDipoleReadoutBlock",
"InteractionBlock",
"NonLinearReadoutBlock",
"PolynomialCutoff",
"BesselBasis",
"GaussianBasis",
"MACE",
"ScaleShiftMACE",
"AtomicDipolesMACE",
"EnergyDipolesMACE",
"WeightedEnergyForcesLoss",
"WeightedForcesLoss",
"WeightedEnergyForcesVirialsLoss",
"WeightedEnergyForcesStressLoss",
"DipoleSingleLoss",
"WeightedEnergyForcesDipoleLoss",
"WeightedHuberEnergyForcesStressLoss",
"UniversalLoss",
"WeightedEnergyForcesL1L2Loss",
"SymmetricContraction",
"interaction_classes",
"compute_mean_std_atomic_inter_energy",
"compute_avg_num_neighbors",
"compute_statistics",
"compute_fixed_charge_dipole",
]
###########################################################################################
# Elementary Blocks for Building O(3) Equivariant Higher Order Message Passing Neural Networks
# Authors: Ilyes Batatia, Gregor Simm
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
from abc import abstractmethod
from typing import Any, Callable, List, Optional, Tuple, Union
import numpy as np
import torch.nn.functional
from e3nn import nn, o3
from e3nn.util.jit import compile_mode
from mace.modules.wrapper_ops import (
CuEquivarianceConfig,
FullyConnectedTensorProduct,
Linear,
SymmetricContractionWrapper,
TensorProduct,
)
from mace.tools.compile import simplify_if_compile
from mace.tools.scatter import scatter_sum
from mace.tools.utils import LAMMPS_MP
from .irreps_tools import mask_head, reshape_irreps, tp_out_irreps_with_instructions
from .radial import (
AgnesiTransform,
BesselBasis,
ChebychevBasis,
GaussianBasis,
PolynomialCutoff,
SoftTransform,
)
@compile_mode("script")
class LinearNodeEmbeddingBlock(torch.nn.Module):
def __init__(
self,
irreps_in: o3.Irreps,
irreps_out: o3.Irreps,
cueq_config: Optional[CuEquivarianceConfig] = None,
):
super().__init__()
self.linear = Linear(
irreps_in=irreps_in, irreps_out=irreps_out, cueq_config=cueq_config
)
def forward(
self,
node_attrs: torch.Tensor,
) -> torch.Tensor: # [n_nodes, irreps]
return self.linear(node_attrs)
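# Usage sketch (shapes are illustrative): embed one-hot element attributes into a
# learned scalar feature space. With 3 chemical elements and 32 output channels:
#
#   embed = LinearNodeEmbeddingBlock(o3.Irreps("3x0e"), o3.Irreps("32x0e"))
#   node_attrs = torch.nn.functional.one_hot(
#       torch.tensor([0, 2, 1]), 3
#   ).to(torch.get_default_dtype())
#   node_feats = embed(node_attrs)  # [3 nodes, 32]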
@compile_mode("script")
class LinearReadoutBlock(torch.nn.Module):
def __init__(
self,
irreps_in: o3.Irreps,
irrep_out: o3.Irreps = o3.Irreps("0e"),
cueq_config: Optional[CuEquivarianceConfig] = None,
):
super().__init__()
self.linear = Linear(
irreps_in=irreps_in, irreps_out=irrep_out, cueq_config=cueq_config
)
def forward(
self,
x: torch.Tensor,
heads: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
) -> torch.Tensor: # [n_nodes, irreps] # [..., ]
return self.linear(x) # [n_nodes, 1]
@simplify_if_compile
@compile_mode("script")
class NonLinearReadoutBlock(torch.nn.Module):
def __init__(
self,
irreps_in: o3.Irreps,
MLP_irreps: o3.Irreps,
gate: Optional[Callable],
irrep_out: o3.Irreps = o3.Irreps("0e"),
num_heads: int = 1,
cueq_config: Optional[CuEquivarianceConfig] = None,
):
super().__init__()
self.hidden_irreps = MLP_irreps
self.num_heads = num_heads
self.linear_1 = Linear(
irreps_in=irreps_in, irreps_out=self.hidden_irreps, cueq_config=cueq_config
)
self.non_linearity = nn.Activation(irreps_in=self.hidden_irreps, acts=[gate])
self.linear_2 = Linear(
irreps_in=self.hidden_irreps, irreps_out=irrep_out, cueq_config=cueq_config
)
def forward(
self, x: torch.Tensor, heads: Optional[torch.Tensor] = None
) -> torch.Tensor: # [n_nodes, irreps] # [..., ]
x = self.non_linearity(self.linear_1(x))
if hasattr(self, "num_heads"):
if self.num_heads > 1 and heads is not None:
x = mask_head(x, heads, self.num_heads)
return self.linear_2(x) # [n_nodes, len(heads)]
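# Usage sketch (illustrative irreps): a one-hidden-layer gated readout. With a
# single head (the default), the `heads` argument can be omitted:
#
#   readout = NonLinearReadoutBlock(
#       irreps_in=o3.Irreps("32x0e"),
#       MLP_irreps=o3.Irreps("16x0e"),
#       gate=torch.nn.functional.silu,
#   )
#   energies = readout(torch.randn(5, 32))  # [5 nodes, 1]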
@compile_mode("script")
class LinearDipoleReadoutBlock(torch.nn.Module):
def __init__(
self,
irreps_in: o3.Irreps,
dipole_only: bool = False,
cueq_config: Optional[CuEquivarianceConfig] = None,
):
super().__init__()
if dipole_only:
self.irreps_out = o3.Irreps("1x1o")
else:
self.irreps_out = o3.Irreps("1x0e + 1x1o")
self.linear = Linear(
irreps_in=irreps_in, irreps_out=self.irreps_out, cueq_config=cueq_config
)
def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ]
return self.linear(x) # [n_nodes, 1]
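# Usage sketch: with dipole_only=False the block emits one scalar (0e) and one
# vector (1o) per node, i.e. 4 numbers; with dipole_only=True only the 3-vector:
#
#   readout = LinearDipoleReadoutBlock(irreps_in=o3.Irreps("32x0e + 32x1o"))
#   out = readout(torch.randn(5, 32 + 32 * 3))  # [5 nodes, 4]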
@compile_mode("script")
class NonLinearDipoleReadoutBlock(torch.nn.Module):
def __init__(
self,
irreps_in: o3.Irreps,
MLP_irreps: o3.Irreps,
gate: Callable,
dipole_only: bool = False,
cueq_config: Optional[CuEquivarianceConfig] = None,
):
super().__init__()
self.hidden_irreps = MLP_irreps
if dipole_only:
self.irreps_out = o3.Irreps("1x1o")
else:
self.irreps_out = o3.Irreps("1x0e + 1x1o")
irreps_scalars = o3.Irreps(
[(mul, ir) for mul, ir in MLP_irreps if ir.l == 0 and ir in self.irreps_out]
)
irreps_gated = o3.Irreps(
[(mul, ir) for mul, ir in MLP_irreps if ir.l > 0 and ir in self.irreps_out]
)
irreps_gates = o3.Irreps([mul, "0e"] for mul, _ in irreps_gated)
self.equivariant_nonlin = nn.Gate(
irreps_scalars=irreps_scalars,
act_scalars=[gate for _, ir in irreps_scalars],
irreps_gates=irreps_gates,
act_gates=[gate] * len(irreps_gates),
irreps_gated=irreps_gated,
)
self.irreps_nonlin = self.equivariant_nonlin.irreps_in.simplify()
self.linear_1 = Linear(
irreps_in=irreps_in, irreps_out=self.irreps_nonlin, cueq_config=cueq_config
)
self.linear_2 = Linear(
irreps_in=self.hidden_irreps,
irreps_out=self.irreps_out,
cueq_config=cueq_config,
)
def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ]
x = self.equivariant_nonlin(self.linear_1(x))
return self.linear_2(x) # [n_nodes, 1]
@compile_mode("script")
class AtomicEnergiesBlock(torch.nn.Module):
atomic_energies: torch.Tensor
def __init__(self, atomic_energies: Union[np.ndarray, torch.Tensor]):
super().__init__()
# assert len(atomic_energies.shape) == 1
self.register_buffer(
"atomic_energies",
torch.tensor(atomic_energies, dtype=torch.get_default_dtype()),
) # [n_elements, n_heads]
def forward(
self, x: torch.Tensor # one-hot of elements [..., n_elements]
) -> torch.Tensor: # [..., ]
return torch.matmul(x, torch.atleast_2d(self.atomic_energies).T)
def __repr__(self):
formatted_energies = ", ".join(
[
"[" + ", ".join([f"{x:.4f}" for x in group]) + "]"
for group in torch.atleast_2d(self.atomic_energies)
]
)
return f"{self.__class__.__name__}(energies=[{formatted_energies}])"
@compile_mode("script")
class RadialEmbeddingBlock(torch.nn.Module):
def __init__(
self,
r_max: float,
num_bessel: int,
num_polynomial_cutoff: int,
radial_type: str = "bessel",
distance_transform: str = "None",
):
super().__init__()
if radial_type == "bessel":
self.bessel_fn = BesselBasis(r_max=r_max, num_basis=num_bessel)
elif radial_type == "gaussian":
self.bessel_fn = GaussianBasis(r_max=r_max, num_basis=num_bessel)
elif radial_type == "chebyshev":
self.bessel_fn = ChebychevBasis(r_max=r_max, num_basis=num_bessel)
if distance_transform == "Agnesi":
self.distance_transform = AgnesiTransform()
elif distance_transform == "Soft":
self.distance_transform = SoftTransform()
self.cutoff_fn = PolynomialCutoff(r_max=r_max, p=num_polynomial_cutoff)
self.out_dim = num_bessel
def forward(
self,
edge_lengths: torch.Tensor, # [n_edges, 1]
node_attrs: torch.Tensor,
edge_index: torch.Tensor,
atomic_numbers: torch.Tensor,
):
cutoff = self.cutoff_fn(edge_lengths) # [n_edges, 1]
if hasattr(self, "distance_transform"):
edge_lengths = self.distance_transform(
edge_lengths, node_attrs, edge_index, atomic_numbers
)
radial = self.bessel_fn(edge_lengths) # [n_edges, n_basis]
return radial * cutoff # [n_edges, n_basis]
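# Usage sketch: expand edge lengths in a smooth radial basis, multiplied by a
# polynomial cutoff that decays to zero at r_max. node_attrs / edge_index /
# atomic_numbers are only consumed when a distance transform is configured, so
# dummies suffice here (shapes are illustrative):
#
#   radial = RadialEmbeddingBlock(r_max=5.0, num_bessel=8, num_polynomial_cutoff=5)
#   lengths = torch.tensor([[1.0], [4.9]])  # [n_edges, 1]
#   feats = radial(
#       lengths,
#       node_attrs=torch.zeros(2, 1),
#       edge_index=torch.zeros(2, 2, dtype=torch.long),
#       atomic_numbers=torch.zeros(1, dtype=torch.long),
#   )  # [2 edges, 8 basis functions]; the 4.9 row is strongly suppressed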
@compile_mode("script")
class EquivariantProductBasisBlock(torch.nn.Module):
def __init__(
self,
node_feats_irreps: o3.Irreps,
target_irreps: o3.Irreps,
correlation: int,
use_sc: bool = True,
num_elements: Optional[int] = None,
cueq_config: Optional[CuEquivarianceConfig] = None,
) -> None:
super().__init__()
self.use_sc = use_sc
self.symmetric_contractions = SymmetricContractionWrapper(
irreps_in=node_feats_irreps,
irreps_out=target_irreps,
correlation=correlation,
num_elements=num_elements,
cueq_config=cueq_config,
)
# Update linear
self.linear = Linear(
target_irreps,
target_irreps,
internal_weights=True,
shared_weights=True,
cueq_config=cueq_config,
)
self.cueq_config = cueq_config
def forward(
self,
node_feats: torch.Tensor,
sc: Optional[torch.Tensor],
node_attrs: torch.Tensor,
) -> torch.Tensor:
use_cueq = False
use_cueq_mul_ir = False
if hasattr(self, "cueq_config"):
if self.cueq_config is not None:
if self.cueq_config.enabled and (
self.cueq_config.optimize_all or self.cueq_config.optimize_symmetric
):
use_cueq = True
if self.cueq_config.layout_str == "mul_ir":
use_cueq_mul_ir = True
if use_cueq:
if use_cueq_mul_ir:
node_feats = torch.transpose(node_feats, 1, 2)
index_attrs = torch.nonzero(node_attrs)[:, 1].int()
node_feats = self.symmetric_contractions(
node_feats.flatten(1),
index_attrs,
)
else:
node_feats = self.symmetric_contractions(node_feats, node_attrs)
if self.use_sc and sc is not None:
return self.linear(node_feats) + sc
return self.linear(node_feats)
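# In the MACE models, `node_feats` entering this block is the reshaped
# [n_nodes, channels, (lmax + 1)**2] output of an interaction block, and `sc`
# is that block's optional residual (skip) connection, re-added after the
# final linear layer here.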
@compile_mode("script")
class InteractionBlock(torch.nn.Module):
def __init__(
self,
node_attrs_irreps: o3.Irreps,
node_feats_irreps: o3.Irreps,
edge_attrs_irreps: o3.Irreps,
edge_feats_irreps: o3.Irreps,
target_irreps: o3.Irreps,
hidden_irreps: o3.Irreps,
avg_num_neighbors: float,
radial_MLP: Optional[List[int]] = None,
cueq_config: Optional[CuEquivarianceConfig] = None,
) -> None:
super().__init__()
self.node_attrs_irreps = node_attrs_irreps
self.node_feats_irreps = node_feats_irreps
self.edge_attrs_irreps = edge_attrs_irreps
self.edge_feats_irreps = edge_feats_irreps
self.target_irreps = target_irreps
self.hidden_irreps = hidden_irreps
self.avg_num_neighbors = avg_num_neighbors
if radial_MLP is None:
radial_MLP = [64, 64, 64]
self.radial_MLP = radial_MLP
self.cueq_config = cueq_config
self._setup()
@abstractmethod
def _setup(self) -> None:
raise NotImplementedError
def handle_lammps(
self,
node_feats: torch.Tensor,
lammps_class: Optional[Any],
lammps_natoms: Tuple[int, int],
first_layer: bool,
) -> torch.Tensor: # noqa: D401 – internal helper
if lammps_class is None or first_layer or torch.jit.is_scripting():
return node_feats
_, n_total = lammps_natoms
pad = torch.zeros(
(n_total, node_feats.shape[1]),
dtype=node_feats.dtype,
device=node_feats.device,
)
node_feats = torch.cat((node_feats, pad), dim=0)
node_feats = LAMMPS_MP.apply(node_feats, lammps_class)
return node_feats
def truncate_ghosts(
self, tensor: torch.Tensor, n_real: Optional[int] = None
) -> torch.Tensor:
"""Truncate the tensor to only keep the real atoms in case of presence of ghost atoms during multi-GPU MD simulations."""
return tensor[:n_real] if n_real is not None else tensor
@abstractmethod
def forward(
self,
node_attrs: torch.Tensor,
node_feats: torch.Tensor,
edge_attrs: torch.Tensor,
edge_feats: torch.Tensor,
edge_index: torch.Tensor,
) -> torch.Tensor:
raise NotImplementedError
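# Parity -> activation lookup used when assembling gated nonlinearities:
# even scalars (p = +1) get SiLU, odd scalars (p = -1) get tanh.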
nonlinearities = {1: torch.nn.functional.silu, -1: torch.tanh}
@compile_mode("script")
class RealAgnosticInteractionBlock(InteractionBlock):
def _setup(self) -> None:
if not hasattr(self, "cueq_config"):
self.cueq_config = None
# First linear
self.linear_up = Linear(
self.node_feats_irreps,
self.node_feats_irreps,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
# TensorProduct
irreps_mid, instructions = tp_out_irreps_with_instructions(
self.node_feats_irreps,
self.edge_attrs_irreps,
self.target_irreps,
)
self.conv_tp = TensorProduct(
self.node_feats_irreps,
self.edge_attrs_irreps,
irreps_mid,
instructions=instructions,
shared_weights=False,
internal_weights=False,
cueq_config=self.cueq_config,
)
# Convolution weights
input_dim = self.edge_feats_irreps.num_irreps
self.conv_tp_weights = nn.FullyConnectedNet(
[input_dim] + self.radial_MLP + [self.conv_tp.weight_numel],
torch.nn.functional.silu,
)
# Linear
self.irreps_out = self.target_irreps
self.linear = Linear(
irreps_mid,
self.irreps_out,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
# Selector TensorProduct
self.skip_tp = FullyConnectedTensorProduct(
self.irreps_out,
self.node_attrs_irreps,
self.irreps_out,
cueq_config=self.cueq_config,
)
self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config)
def forward(
self,
node_attrs: torch.Tensor,
node_feats: torch.Tensor,
edge_attrs: torch.Tensor,
edge_feats: torch.Tensor,
edge_index: torch.Tensor,
lammps_natoms: Tuple[int, int] = (0, 0),
lammps_class: Optional[Any] = None,
first_layer: bool = False,
) -> Tuple[torch.Tensor, None]:
sender = edge_index[0]
receiver = edge_index[1]
num_nodes = node_feats.shape[0]
n_real = lammps_natoms[0] if lammps_class is not None else None
node_feats = self.linear_up(node_feats)
node_feats = self.handle_lammps(
node_feats,
lammps_class=lammps_class,
lammps_natoms=lammps_natoms,
first_layer=first_layer,
)
tp_weights = self.conv_tp_weights(edge_feats)
mji = self.conv_tp(
node_feats[sender], edge_attrs, tp_weights
) # [n_edges, irreps]
message = scatter_sum(
src=mji, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, irreps]
message = self.truncate_ghosts(message, n_real)
node_attrs = self.truncate_ghosts(node_attrs, n_real)
message = self.linear(message) / self.avg_num_neighbors
message = self.skip_tp(message, node_attrs)
return (
self.reshape(message),
None,
) # [n_nodes, channels, (lmax + 1)**2]
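# The core message-passing step above: per-edge tensor-product messages `mji`
# are summed onto their receiving nodes. A minimal pure-torch equivalent of
# that `scatter_sum` call (toy shapes, for illustration only):
#
#   mji = torch.randn(4, 8)                # 4 edges, 8 features
#   receiver = torch.tensor([0, 0, 1, 2])  # edge -> destination node
#   message = torch.zeros(3, 8).index_add_(0, receiver, mji)  # [3 nodes, 8]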
@compile_mode("script")
class RealAgnosticResidualInteractionBlock(InteractionBlock):
def _setup(self) -> None:
if not hasattr(self, "cueq_config"):
self.cueq_config = None
# First linear
self.linear_up = Linear(
self.node_feats_irreps,
self.node_feats_irreps,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
# TensorProduct
irreps_mid, instructions = tp_out_irreps_with_instructions(
self.node_feats_irreps,
self.edge_attrs_irreps,
self.target_irreps,
)
self.conv_tp = TensorProduct(
self.node_feats_irreps,
self.edge_attrs_irreps,
irreps_mid,
instructions=instructions,
shared_weights=False,
internal_weights=False,
cueq_config=self.cueq_config,
)
# Convolution weights
input_dim = self.edge_feats_irreps.num_irreps
self.conv_tp_weights = nn.FullyConnectedNet(
[input_dim] + self.radial_MLP + [self.conv_tp.weight_numel],
torch.nn.functional.silu, # gate
)
# Linear
self.irreps_out = self.target_irreps
self.linear = Linear(
irreps_mid,
self.irreps_out,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
# Selector TensorProduct
self.skip_tp = FullyConnectedTensorProduct(
self.node_feats_irreps,
self.node_attrs_irreps,
self.hidden_irreps,
cueq_config=self.cueq_config,
)
self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config)
def forward(
self,
node_attrs: torch.Tensor,
node_feats: torch.Tensor,
edge_attrs: torch.Tensor,
edge_feats: torch.Tensor,
edge_index: torch.Tensor,
lammps_class: Optional[Any] = None,
lammps_natoms: Tuple[int, int] = (0, 0),
first_layer: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
sender = edge_index[0]
receiver = edge_index[1]
num_nodes = node_feats.shape[0]
n_real = lammps_natoms[0] if lammps_class is not None else None
sc = self.skip_tp(node_feats, node_attrs)
node_feats = self.linear_up(node_feats)
node_feats = self.handle_lammps(
node_feats,
lammps_class=lammps_class,
lammps_natoms=lammps_natoms,
first_layer=first_layer,
)
tp_weights = self.conv_tp_weights(edge_feats)
mji = self.conv_tp(
node_feats[sender], edge_attrs, tp_weights
) # [n_edges, irreps]
message = scatter_sum(
src=mji, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, irreps]
message = self.truncate_ghosts(message, n_real)
node_attrs = self.truncate_ghosts(node_attrs, n_real)
sc = self.truncate_ghosts(sc, n_real)
message = self.linear(message) / self.avg_num_neighbors
return (
self.reshape(message),
sc,
) # [n_nodes, channels, (lmax + 1)**2]
@compile_mode("script")
class RealAgnosticDensityInteractionBlock(InteractionBlock):
def _setup(self) -> None:
if not hasattr(self, "cueq_config"):
self.cueq_config = None
# First linear
self.linear_up = Linear(
self.node_feats_irreps,
self.node_feats_irreps,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
# TensorProduct
irreps_mid, instructions = tp_out_irreps_with_instructions(
self.node_feats_irreps,
self.edge_attrs_irreps,
self.target_irreps,
)
self.conv_tp = TensorProduct(
self.node_feats_irreps,
self.edge_attrs_irreps,
irreps_mid,
instructions=instructions,
shared_weights=False,
internal_weights=False,
cueq_config=self.cueq_config,
)
# Convolution weights
input_dim = self.edge_feats_irreps.num_irreps
self.conv_tp_weights = nn.FullyConnectedNet(
[input_dim] + self.radial_MLP + [self.conv_tp.weight_numel],
torch.nn.functional.silu,
)
# Linear
self.irreps_out = self.target_irreps
self.linear = Linear(
irreps_mid,
self.irreps_out,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
# Selector TensorProduct
self.skip_tp = FullyConnectedTensorProduct(
self.irreps_out,
self.node_attrs_irreps,
self.irreps_out,
cueq_config=self.cueq_config,
)
# Density normalization
self.density_fn = nn.FullyConnectedNet(
[input_dim]
+ [
1,
],
torch.nn.functional.silu,
)
# Reshape
self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config)
def forward(
self,
node_attrs: torch.Tensor,
node_feats: torch.Tensor,
edge_attrs: torch.Tensor,
edge_feats: torch.Tensor,
edge_index: torch.Tensor,
lammps_class: Optional[Any] = None,
lammps_natoms: Tuple[int, int] = (0, 0),
first_layer: bool = False,
) -> Tuple[torch.Tensor, None]:
sender = edge_index[0]
receiver = edge_index[1]
num_nodes = node_feats.shape[0]
n_real = lammps_natoms[0] if lammps_class is not None else None
node_feats = self.linear_up(node_feats)
node_feats = self.handle_lammps(
node_feats,
lammps_class=lammps_class,
lammps_natoms=lammps_natoms,
first_layer=first_layer,
)
tp_weights = self.conv_tp_weights(edge_feats)
edge_density = torch.tanh(self.density_fn(edge_feats) ** 2)
mji = self.conv_tp(
node_feats[sender], edge_attrs, tp_weights
) # [n_edges, irreps]
density = scatter_sum(
src=edge_density, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, 1]
message = scatter_sum(
src=mji, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, irreps]
message = self.truncate_ghosts(message, n_real)
node_attrs = self.truncate_ghosts(node_attrs, n_real)
density = self.truncate_ghosts(density, n_real)
message = self.linear(message) / (density + 1)
message = self.skip_tp(message, node_attrs)
return (
self.reshape(message),
None,
) # [n_nodes, channels, (lmax + 1)**2]
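# The density trick above replaces the fixed 1/avg_num_neighbors normalization
# with a learned per-node factor: each edge contributes
# tanh(density_fn(edge_feats) ** 2) in [0, 1), contributions are scatter-summed
# per receiving node, and messages are divided by (density + 1). A numeric
# sketch (hypothetical density_fn outputs):
#
#   w = torch.tensor([[0.5], [2.0], [-1.0]])
#   edge_density = torch.tanh(w ** 2)  # ~[[0.245], [0.999], [0.762]]
#   # three edges into one node -> density ~ 2.006, so message /= ~3.006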
@compile_mode("script")
class RealAgnosticDensityResidualInteractionBlock(InteractionBlock):
def _setup(self) -> None:
if not hasattr(self, "cueq_config"):
self.cueq_config = None
# First linear
self.linear_up = Linear(
self.node_feats_irreps,
self.node_feats_irreps,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
# TensorProduct
irreps_mid, instructions = tp_out_irreps_with_instructions(
self.node_feats_irreps,
self.edge_attrs_irreps,
self.target_irreps,
)
self.conv_tp = TensorProduct(
self.node_feats_irreps,
self.edge_attrs_irreps,
irreps_mid,
instructions=instructions,
shared_weights=False,
internal_weights=False,
cueq_config=self.cueq_config,
)
# Convolution weights
input_dim = self.edge_feats_irreps.num_irreps
self.conv_tp_weights = nn.FullyConnectedNet(
[input_dim] + self.radial_MLP + [self.conv_tp.weight_numel],
torch.nn.functional.silu, # gate
)
# Linear
self.irreps_out = self.target_irreps
self.linear = Linear(
irreps_mid,
self.irreps_out,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
# Selector TensorProduct
self.skip_tp = FullyConnectedTensorProduct(
self.node_feats_irreps,
self.node_attrs_irreps,
self.hidden_irreps,
cueq_config=self.cueq_config,
)
# Density normalization
self.density_fn = nn.FullyConnectedNet(
[input_dim]
+ [
1,
],
torch.nn.functional.silu,
)
# Reshape
self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config)
def forward(
self,
node_attrs: torch.Tensor,
node_feats: torch.Tensor,
edge_attrs: torch.Tensor,
edge_feats: torch.Tensor,
edge_index: torch.Tensor,
lammps_class: Optional[Any] = None,
lammps_natoms: Tuple[int, int] = (0, 0),
first_layer: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
sender = edge_index[0]
receiver = edge_index[1]
num_nodes = node_feats.shape[0]
n_real = lammps_natoms[0] if lammps_class is not None else None
sc = self.skip_tp(node_feats, node_attrs)
node_feats = self.linear_up(node_feats)
node_feats = self.handle_lammps(
node_feats,
lammps_class=lammps_class,
lammps_natoms=lammps_natoms,
first_layer=first_layer,
)
tp_weights = self.conv_tp_weights(edge_feats)
edge_density = torch.tanh(self.density_fn(edge_feats) ** 2)
mji = self.conv_tp(
node_feats[sender], edge_attrs, tp_weights
) # [n_edges, irreps]
density = scatter_sum(
src=edge_density, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, 1]
message = scatter_sum(
src=mji, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, irreps]
message = self.truncate_ghosts(message, n_real)
node_attrs = self.truncate_ghosts(node_attrs, n_real)
density = self.truncate_ghosts(density, n_real)
sc = self.truncate_ghosts(sc, n_real)
message = self.linear(message) / (density + 1)
return (
self.reshape(message),
sc,
) # [n_nodes, channels, (lmax + 1)**2]
@compile_mode("script")
class RealAgnosticAttResidualInteractionBlock(InteractionBlock):
def _setup(self) -> None:
if not hasattr(self, "cueq_config"):
self.cueq_config = None
self.node_feats_down_irreps = o3.Irreps("64x0e")
# First linear
self.linear_up = Linear(
self.node_feats_irreps,
self.node_feats_irreps,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
# TensorProduct
irreps_mid, instructions = tp_out_irreps_with_instructions(
self.node_feats_irreps,
self.edge_attrs_irreps,
self.target_irreps,
)
self.conv_tp = TensorProduct(
self.node_feats_irreps,
self.edge_attrs_irreps,
irreps_mid,
instructions=instructions,
shared_weights=False,
internal_weights=False,
cueq_config=self.cueq_config,
)
# Convolution weights
self.linear_down = Linear(
self.node_feats_irreps,
self.node_feats_down_irreps,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
input_dim = (
self.edge_feats_irreps.num_irreps
+ 2 * self.node_feats_down_irreps.num_irreps
)
self.conv_tp_weights = nn.FullyConnectedNet(
[input_dim] + 3 * [256] + [self.conv_tp.weight_numel],
torch.nn.functional.silu,
)
# Linear
self.irreps_out = self.target_irreps
self.linear = Linear(
irreps_mid,
self.irreps_out,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config)
# Skip connection.
self.skip_linear = Linear(
self.node_feats_irreps, self.hidden_irreps, cueq_config=self.cueq_config
)
# pylint: disable=unused-argument
def forward(
self,
node_attrs: torch.Tensor,
node_feats: torch.Tensor,
edge_attrs: torch.Tensor,
edge_feats: torch.Tensor,
edge_index: torch.Tensor,
lammps_class: Optional[Any] = None,
lammps_natoms: Tuple[int, int] = (0, 0),
first_layer: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
sender = edge_index[0]
receiver = edge_index[1]
num_nodes = node_feats.shape[0]
sc = self.skip_linear(node_feats)
node_feats_up = self.linear_up(node_feats)
node_feats_down = self.linear_down(node_feats)
augmented_edge_feats = torch.cat(
[
edge_feats,
node_feats_down[sender],
node_feats_down[receiver],
],
dim=-1,
)
tp_weights = self.conv_tp_weights(augmented_edge_feats)
mji = self.conv_tp(
node_feats_up[sender], edge_attrs, tp_weights
) # [n_edges, irreps]
message = scatter_sum(
src=mji, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, irreps]
message = self.linear(message) / self.avg_num_neighbors
return (
self.reshape(message),
sc,
) # [n_nodes, channels, (lmax + 1)**2]
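# Unlike the other interaction blocks, this attention-style variant conditions
# the tensor-product weights on both edge endpoints: the radial edge features
# are concatenated with down-projected scalar features of the sender and
# receiver before the weight MLP, so the convolution weights are pair-dependent.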
@compile_mode("script")
class ScaleShiftBlock(torch.nn.Module):
def __init__(self, scale: float, shift: float):
super().__init__()
self.register_buffer(
"scale",
torch.tensor(scale, dtype=torch.get_default_dtype()),
)
self.register_buffer(
"shift",
torch.tensor(shift, dtype=torch.get_default_dtype()),
)
def forward(self, x: torch.Tensor, head: torch.Tensor) -> torch.Tensor:
return (
torch.atleast_1d(self.scale)[head] * x + torch.atleast_1d(self.shift)[head]
)
def __repr__(self):
formatted_scale = (
", ".join([f"{x:.4f}" for x in self.scale])
if self.scale.numel() > 1
else f"{self.scale.item():.4f}"
)
formatted_shift = (
", ".join([f"{x:.4f}" for x in self.shift])
if self.shift.numel() > 1
else f"{self.shift.item():.4f}"
)
return f"{self.__class__.__name__}(scale={formatted_scale}, shift={formatted_shift})"