"projects/Task020_RibFrac/vscode:/vscode.git/clone" did not exist on "7246044d8824f7b3f6c243db054b61420212ad05"
Unverified Commit a69b2190 authored by zcxzcx1's avatar zcxzcx1 Committed by GitHub
Browse files

Delete modules directory

parent d47e8ba0
from typing import Callable, Dict, Optional, Type
import torch
from .blocks import (
AtomicEnergiesBlock,
EquivariantProductBasisBlock,
InteractionBlock,
LinearDipoleReadoutBlock,
LinearNodeEmbeddingBlock,
LinearReadoutBlock,
NonLinearDipoleReadoutBlock,
NonLinearReadoutBlock,
RadialEmbeddingBlock,
RealAgnosticAttResidualInteractionBlock,
RealAgnosticDensityInteractionBlock,
RealAgnosticDensityResidualInteractionBlock,
RealAgnosticInteractionBlock,
RealAgnosticResidualInteractionBlock,
ScaleShiftBlock,
)
from .loss import (
DipoleSingleLoss,
UniversalLoss,
WeightedEnergyForcesDipoleLoss,
WeightedEnergyForcesL1L2Loss,
WeightedEnergyForcesLoss,
WeightedEnergyForcesStressLoss,
WeightedEnergyForcesVirialsLoss,
WeightedForcesLoss,
WeightedHuberEnergyForcesStressLoss,
)
from .models import MACE, AtomicDipolesMACE, EnergyDipolesMACE, ScaleShiftMACE
from .radial import BesselBasis, GaussianBasis, PolynomialCutoff, ZBLBasis
from .symmetric_contraction import SymmetricContraction
from .utils import (
compute_avg_num_neighbors,
compute_fixed_charge_dipole,
compute_mean_rms_energy_forces,
compute_mean_std_atomic_inter_energy,
compute_rms_dipoles,
compute_statistics,
)
interaction_classes: Dict[str, Type[InteractionBlock]] = {
"RealAgnosticResidualInteractionBlock": RealAgnosticResidualInteractionBlock,
"RealAgnosticAttResidualInteractionBlock": RealAgnosticAttResidualInteractionBlock,
"RealAgnosticInteractionBlock": RealAgnosticInteractionBlock,
"RealAgnosticDensityInteractionBlock": RealAgnosticDensityInteractionBlock,
"RealAgnosticDensityResidualInteractionBlock": RealAgnosticDensityResidualInteractionBlock,
}
scaling_classes: Dict[str, Callable] = {
"std_scaling": compute_mean_std_atomic_inter_energy,
"rms_forces_scaling": compute_mean_rms_energy_forces,
"rms_dipoles_scaling": compute_rms_dipoles,
}
gate_dict: Dict[str, Optional[Callable]] = {
"abs": torch.abs,
"tanh": torch.tanh,
"silu": torch.nn.functional.silu,
"None": None,
}
__all__ = [
"AtomicEnergiesBlock",
"RadialEmbeddingBlock",
"ZBLBasis",
"LinearNodeEmbeddingBlock",
"LinearReadoutBlock",
"EquivariantProductBasisBlock",
"ScaleShiftBlock",
"LinearDipoleReadoutBlock",
"NonLinearDipoleReadoutBlock",
"InteractionBlock",
"NonLinearReadoutBlock",
"PolynomialCutoff",
"BesselBasis",
"GaussianBasis",
"MACE",
"ScaleShiftMACE",
"AtomicDipolesMACE",
"EnergyDipolesMACE",
"WeightedEnergyForcesLoss",
"WeightedForcesLoss",
"WeightedEnergyForcesVirialsLoss",
"WeightedEnergyForcesStressLoss",
"DipoleSingleLoss",
"WeightedEnergyForcesDipoleLoss",
"WeightedHuberEnergyForcesStressLoss",
"UniversalLoss",
"WeightedEnergyForcesL1L2Loss",
"SymmetricContraction",
"interaction_classes",
"compute_mean_std_atomic_inter_energy",
"compute_avg_num_neighbors",
"compute_statistics",
"compute_fixed_charge_dipole",
]
###########################################################################################
# Elementary Block for Building O(3) Equivariant Higher Order Message Passing Neural Network
# Authors: Ilyes Batatia, Gregor Simm
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
from abc import abstractmethod
from typing import Any, Callable, List, Optional, Tuple, Union
import numpy as np
import torch.nn.functional
from e3nn import nn, o3
from e3nn.util.jit import compile_mode
from mace.modules.wrapper_ops import (
CuEquivarianceConfig,
FullyConnectedTensorProduct,
Linear,
SymmetricContractionWrapper,
TensorProduct,
)
from mace.tools.compile import simplify_if_compile
from mace.tools.scatter import scatter_sum
from mace.tools.utils import LAMMPS_MP
from .irreps_tools import mask_head, reshape_irreps, tp_out_irreps_with_instructions
from .radial import (
AgnesiTransform,
BesselBasis,
ChebychevBasis,
GaussianBasis,
PolynomialCutoff,
SoftTransform,
)
@compile_mode("script")
class LinearNodeEmbeddingBlock(torch.nn.Module):
def __init__(
self,
irreps_in: o3.Irreps,
irreps_out: o3.Irreps,
cueq_config: Optional[CuEquivarianceConfig] = None,
):
super().__init__()
self.linear = Linear(
irreps_in=irreps_in, irreps_out=irreps_out, cueq_config=cueq_config
)
def forward(
self,
node_attrs: torch.Tensor,
) -> torch.Tensor: # [n_nodes, irreps]
return self.linear(node_attrs)
@compile_mode("script")
class LinearReadoutBlock(torch.nn.Module):
def __init__(
self,
irreps_in: o3.Irreps,
irrep_out: o3.Irreps = o3.Irreps("0e"),
cueq_config: Optional[CuEquivarianceConfig] = None,
):
super().__init__()
self.linear = Linear(
irreps_in=irreps_in, irreps_out=irrep_out, cueq_config=cueq_config
)
def forward(
self,
x: torch.Tensor,
heads: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
) -> torch.Tensor: # [n_nodes, irreps] # [..., ]
return self.linear(x) # [n_nodes, 1]
@simplify_if_compile
@compile_mode("script")
class NonLinearReadoutBlock(torch.nn.Module):
def __init__(
self,
irreps_in: o3.Irreps,
MLP_irreps: o3.Irreps,
gate: Optional[Callable],
irrep_out: o3.Irreps = o3.Irreps("0e"),
num_heads: int = 1,
cueq_config: Optional[CuEquivarianceConfig] = None,
):
super().__init__()
self.hidden_irreps = MLP_irreps
self.num_heads = num_heads
self.linear_1 = Linear(
irreps_in=irreps_in, irreps_out=self.hidden_irreps, cueq_config=cueq_config
)
self.non_linearity = nn.Activation(irreps_in=self.hidden_irreps, acts=[gate])
self.linear_2 = Linear(
irreps_in=self.hidden_irreps, irreps_out=irrep_out, cueq_config=cueq_config
)
def forward(
self, x: torch.Tensor, heads: Optional[torch.Tensor] = None
) -> torch.Tensor: # [n_nodes, irreps] # [..., ]
x = self.non_linearity(self.linear_1(x))
if hasattr(self, "num_heads"):
if self.num_heads > 1 and heads is not None:
x = mask_head(x, heads, self.num_heads)
return self.linear_2(x) # [n_nodes, len(heads)]
@compile_mode("script")
class LinearDipoleReadoutBlock(torch.nn.Module):
def __init__(
self,
irreps_in: o3.Irreps,
dipole_only: bool = False,
cueq_config: Optional[CuEquivarianceConfig] = None,
):
super().__init__()
if dipole_only:
self.irreps_out = o3.Irreps("1x1o")
else:
self.irreps_out = o3.Irreps("1x0e + 1x1o")
self.linear = Linear(
irreps_in=irreps_in, irreps_out=self.irreps_out, cueq_config=cueq_config
)
def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ]
return self.linear(x) # [n_nodes, 1]
@compile_mode("script")
class NonLinearDipoleReadoutBlock(torch.nn.Module):
def __init__(
self,
irreps_in: o3.Irreps,
MLP_irreps: o3.Irreps,
gate: Callable,
dipole_only: bool = False,
cueq_config: Optional[CuEquivarianceConfig] = None,
):
super().__init__()
self.hidden_irreps = MLP_irreps
if dipole_only:
self.irreps_out = o3.Irreps("1x1o")
else:
self.irreps_out = o3.Irreps("1x0e + 1x1o")
irreps_scalars = o3.Irreps(
[(mul, ir) for mul, ir in MLP_irreps if ir.l == 0 and ir in self.irreps_out]
)
irreps_gated = o3.Irreps(
[(mul, ir) for mul, ir in MLP_irreps if ir.l > 0 and ir in self.irreps_out]
)
irreps_gates = o3.Irreps([mul, "0e"] for mul, _ in irreps_gated)
self.equivariant_nonlin = nn.Gate(
irreps_scalars=irreps_scalars,
act_scalars=[gate for _, ir in irreps_scalars],
irreps_gates=irreps_gates,
act_gates=[gate] * len(irreps_gates),
irreps_gated=irreps_gated,
)
self.irreps_nonlin = self.equivariant_nonlin.irreps_in.simplify()
self.linear_1 = Linear(
irreps_in=irreps_in, irreps_out=self.irreps_nonlin, cueq_config=cueq_config
)
self.linear_2 = Linear(
irreps_in=self.hidden_irreps,
irreps_out=self.irreps_out,
cueq_config=cueq_config,
)
def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ]
x = self.equivariant_nonlin(self.linear_1(x))
return self.linear_2(x) # [n_nodes, 1]
@compile_mode("script")
class AtomicEnergiesBlock(torch.nn.Module):
atomic_energies: torch.Tensor
def __init__(self, atomic_energies: Union[np.ndarray, torch.Tensor]):
super().__init__()
# assert len(atomic_energies.shape) == 1
self.register_buffer(
"atomic_energies",
torch.tensor(atomic_energies, dtype=torch.get_default_dtype()),
) # [n_elements, n_heads]
def forward(
self, x: torch.Tensor # one-hot of elements [..., n_elements]
) -> torch.Tensor: # [..., ]
return torch.matmul(x, torch.atleast_2d(self.atomic_energies).T)
def __repr__(self):
formatted_energies = ", ".join(
[
"[" + ", ".join([f"{x:.4f}" for x in group]) + "]"
for group in torch.atleast_2d(self.atomic_energies)
]
)
return f"{self.__class__.__name__}(energies=[{formatted_energies}])"
@compile_mode("script")
class RadialEmbeddingBlock(torch.nn.Module):
def __init__(
self,
r_max: float,
num_bessel: int,
num_polynomial_cutoff: int,
radial_type: str = "bessel",
distance_transform: str = "None",
):
super().__init__()
if radial_type == "bessel":
self.bessel_fn = BesselBasis(r_max=r_max, num_basis=num_bessel)
elif radial_type == "gaussian":
self.bessel_fn = GaussianBasis(r_max=r_max, num_basis=num_bessel)
elif radial_type == "chebyshev":
self.bessel_fn = ChebychevBasis(r_max=r_max, num_basis=num_bessel)
if distance_transform == "Agnesi":
self.distance_transform = AgnesiTransform()
elif distance_transform == "Soft":
self.distance_transform = SoftTransform()
self.cutoff_fn = PolynomialCutoff(r_max=r_max, p=num_polynomial_cutoff)
self.out_dim = num_bessel
def forward(
self,
edge_lengths: torch.Tensor, # [n_edges, 1]
node_attrs: torch.Tensor,
edge_index: torch.Tensor,
atomic_numbers: torch.Tensor,
):
cutoff = self.cutoff_fn(edge_lengths) # [n_edges, 1]
if hasattr(self, "distance_transform"):
edge_lengths = self.distance_transform(
edge_lengths, node_attrs, edge_index, atomic_numbers
)
radial = self.bessel_fn(edge_lengths) # [n_edges, n_basis]
return radial * cutoff # [n_edges, n_basis]
@compile_mode("script")
class EquivariantProductBasisBlock(torch.nn.Module):
def __init__(
self,
node_feats_irreps: o3.Irreps,
target_irreps: o3.Irreps,
correlation: int,
use_sc: bool = True,
num_elements: Optional[int] = None,
cueq_config: Optional[CuEquivarianceConfig] = None,
) -> None:
super().__init__()
self.use_sc = use_sc
self.symmetric_contractions = SymmetricContractionWrapper(
irreps_in=node_feats_irreps,
irreps_out=target_irreps,
correlation=correlation,
num_elements=num_elements,
cueq_config=cueq_config,
)
# Update linear
self.linear = Linear(
target_irreps,
target_irreps,
internal_weights=True,
shared_weights=True,
cueq_config=cueq_config,
)
self.cueq_config = cueq_config
def forward(
self,
node_feats: torch.Tensor,
sc: Optional[torch.Tensor],
node_attrs: torch.Tensor,
) -> torch.Tensor:
use_cueq = False
use_cueq_mul_ir = False
if hasattr(self, "cueq_config"):
if self.cueq_config is not None:
if self.cueq_config.enabled and (
self.cueq_config.optimize_all or self.cueq_config.optimize_symmetric
):
use_cueq = True
if self.cueq_config.layout_str == "mul_ir":
use_cueq_mul_ir = True
if use_cueq:
if use_cueq_mul_ir:
node_feats = torch.transpose(node_feats, 1, 2)
index_attrs = torch.nonzero(node_attrs)[:, 1].int()
node_feats = self.symmetric_contractions(
node_feats.flatten(1),
index_attrs,
)
else:
node_feats = self.symmetric_contractions(node_feats, node_attrs)
if self.use_sc and sc is not None:
return self.linear(node_feats) + sc
return self.linear(node_feats)
@compile_mode("script")
class InteractionBlock(torch.nn.Module):
def __init__(
self,
node_attrs_irreps: o3.Irreps,
node_feats_irreps: o3.Irreps,
edge_attrs_irreps: o3.Irreps,
edge_feats_irreps: o3.Irreps,
target_irreps: o3.Irreps,
hidden_irreps: o3.Irreps,
avg_num_neighbors: float,
radial_MLP: Optional[List[int]] = None,
cueq_config: Optional[CuEquivarianceConfig] = None,
) -> None:
super().__init__()
self.node_attrs_irreps = node_attrs_irreps
self.node_feats_irreps = node_feats_irreps
self.edge_attrs_irreps = edge_attrs_irreps
self.edge_feats_irreps = edge_feats_irreps
self.target_irreps = target_irreps
self.hidden_irreps = hidden_irreps
self.avg_num_neighbors = avg_num_neighbors
if radial_MLP is None:
radial_MLP = [64, 64, 64]
self.radial_MLP = radial_MLP
self.cueq_config = cueq_config
self._setup()
@abstractmethod
def _setup(self) -> None:
raise NotImplementedError
def handle_lammps(
self,
node_feats: torch.Tensor,
lammps_class: Optional[Any],
lammps_natoms: Tuple[int, int],
first_layer: bool,
) -> torch.Tensor: # noqa: D401 – internal helper
if lammps_class is None or first_layer or torch.jit.is_scripting():
return node_feats
_, n_total = lammps_natoms
pad = torch.zeros(
(n_total, node_feats.shape[1]),
dtype=node_feats.dtype,
device=node_feats.device,
)
node_feats = torch.cat((node_feats, pad), dim=0)
node_feats = LAMMPS_MP.apply(node_feats, lammps_class)
return node_feats
def truncate_ghosts(
self, tensor: torch.Tensor, n_real: Optional[int] = None
) -> torch.Tensor:
"""Truncate the tensor to only keep the real atoms in case of presence of ghost atoms during multi-GPU MD simulations."""
return tensor[:n_real] if n_real is not None else tensor
@abstractmethod
def forward(
self,
node_attrs: torch.Tensor,
node_feats: torch.Tensor,
edge_attrs: torch.Tensor,
edge_feats: torch.Tensor,
edge_index: torch.Tensor,
) -> torch.Tensor:
raise NotImplementedError
nonlinearities = {1: torch.nn.functional.silu, -1: torch.tanh}
@compile_mode("script")
class RealAgnosticInteractionBlock(InteractionBlock):
def _setup(self) -> None:
if not hasattr(self, "cueq_config"):
self.cueq_config = None
# First linear
self.linear_up = Linear(
self.node_feats_irreps,
self.node_feats_irreps,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
# TensorProduct
irreps_mid, instructions = tp_out_irreps_with_instructions(
self.node_feats_irreps,
self.edge_attrs_irreps,
self.target_irreps,
)
self.conv_tp = TensorProduct(
self.node_feats_irreps,
self.edge_attrs_irreps,
irreps_mid,
instructions=instructions,
shared_weights=False,
internal_weights=False,
cueq_config=self.cueq_config,
)
# Convolution weights
input_dim = self.edge_feats_irreps.num_irreps
self.conv_tp_weights = nn.FullyConnectedNet(
[input_dim] + self.radial_MLP + [self.conv_tp.weight_numel],
torch.nn.functional.silu,
)
# Linear
self.irreps_out = self.target_irreps
self.linear = Linear(
irreps_mid,
self.irreps_out,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
# Selector TensorProduct
self.skip_tp = FullyConnectedTensorProduct(
self.irreps_out,
self.node_attrs_irreps,
self.irreps_out,
cueq_config=self.cueq_config,
)
self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config)
def forward(
self,
node_attrs: torch.Tensor,
node_feats: torch.Tensor,
edge_attrs: torch.Tensor,
edge_feats: torch.Tensor,
edge_index: torch.Tensor,
lammps_natoms: Tuple[int, int] = (0, 0),
lammps_class: Optional[Any] = None,
first_layer: bool = False,
) -> Tuple[torch.Tensor, None]:
sender = edge_index[0]
receiver = edge_index[1]
num_nodes = node_feats.shape[0]
n_real = lammps_natoms[0] if lammps_class is not None else None
node_feats = self.linear_up(node_feats)
node_feats = self.handle_lammps(
node_feats,
lammps_class=lammps_class,
lammps_natoms=lammps_natoms,
first_layer=first_layer,
)
tp_weights = self.conv_tp_weights(edge_feats)
mji = self.conv_tp(
node_feats[sender], edge_attrs, tp_weights
) # [n_edges, irreps]
message = scatter_sum(
src=mji, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, irreps]
message = self.truncate_ghosts(message, n_real)
node_attrs = self.truncate_ghosts(node_attrs, n_real)
message = self.linear(message) / self.avg_num_neighbors
message = self.skip_tp(message, node_attrs)
return (
self.reshape(message),
None,
) # [n_nodes, channels, (lmax + 1)**2]
@compile_mode("script")
class RealAgnosticResidualInteractionBlock(InteractionBlock):
def _setup(self) -> None:
if not hasattr(self, "cueq_config"):
self.cueq_config = None
# First linear
self.linear_up = Linear(
self.node_feats_irreps,
self.node_feats_irreps,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
# TensorProduct
irreps_mid, instructions = tp_out_irreps_with_instructions(
self.node_feats_irreps,
self.edge_attrs_irreps,
self.target_irreps,
)
self.conv_tp = TensorProduct(
self.node_feats_irreps,
self.edge_attrs_irreps,
irreps_mid,
instructions=instructions,
shared_weights=False,
internal_weights=False,
cueq_config=self.cueq_config,
)
# Convolution weights
input_dim = self.edge_feats_irreps.num_irreps
self.conv_tp_weights = nn.FullyConnectedNet(
[input_dim] + self.radial_MLP + [self.conv_tp.weight_numel],
torch.nn.functional.silu, # gate
)
# Linear
self.irreps_out = self.target_irreps
self.linear = Linear(
irreps_mid,
self.irreps_out,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
# Selector TensorProduct
self.skip_tp = FullyConnectedTensorProduct(
self.node_feats_irreps,
self.node_attrs_irreps,
self.hidden_irreps,
cueq_config=self.cueq_config,
)
self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config)
def forward(
self,
node_attrs: torch.Tensor,
node_feats: torch.Tensor,
edge_attrs: torch.Tensor,
edge_feats: torch.Tensor,
edge_index: torch.Tensor,
lammps_class: Optional[Any] = None,
lammps_natoms: Tuple[int, int] = (0, 0),
first_layer: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
sender = edge_index[0]
receiver = edge_index[1]
num_nodes = node_feats.shape[0]
n_real = lammps_natoms[0] if lammps_class is not None else None
sc = self.skip_tp(node_feats, node_attrs)
node_feats = self.linear_up(node_feats)
node_feats = self.handle_lammps(
node_feats,
lammps_class=lammps_class,
lammps_natoms=lammps_natoms,
first_layer=first_layer,
)
tp_weights = self.conv_tp_weights(edge_feats)
mji = self.conv_tp(
node_feats[sender], edge_attrs, tp_weights
) # [n_edges, irreps]
message = scatter_sum(
src=mji, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, irreps]
message = self.truncate_ghosts(message, n_real)
node_attrs = self.truncate_ghosts(node_attrs, n_real)
sc = self.truncate_ghosts(sc, n_real)
message = self.linear(message) / self.avg_num_neighbors
return (
self.reshape(message),
sc,
) # [n_nodes, channels, (lmax + 1)**2]
@compile_mode("script")
class RealAgnosticDensityInteractionBlock(InteractionBlock):
def _setup(self) -> None:
if not hasattr(self, "cueq_config"):
self.cueq_config = None
# First linear
self.linear_up = Linear(
self.node_feats_irreps,
self.node_feats_irreps,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
# TensorProduct
irreps_mid, instructions = tp_out_irreps_with_instructions(
self.node_feats_irreps,
self.edge_attrs_irreps,
self.target_irreps,
)
self.conv_tp = TensorProduct(
self.node_feats_irreps,
self.edge_attrs_irreps,
irreps_mid,
instructions=instructions,
shared_weights=False,
internal_weights=False,
cueq_config=self.cueq_config,
)
# Convolution weights
input_dim = self.edge_feats_irreps.num_irreps
self.conv_tp_weights = nn.FullyConnectedNet(
[input_dim] + self.radial_MLP + [self.conv_tp.weight_numel],
torch.nn.functional.silu,
)
# Linear
self.irreps_out = self.target_irreps
self.linear = Linear(
irreps_mid,
self.irreps_out,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
# Selector TensorProduct
self.skip_tp = FullyConnectedTensorProduct(
self.irreps_out,
self.node_attrs_irreps,
self.irreps_out,
cueq_config=self.cueq_config,
)
# Density normalization
self.density_fn = nn.FullyConnectedNet(
[input_dim]
+ [
1,
],
torch.nn.functional.silu,
)
# Reshape
self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config)
def forward(
self,
node_attrs: torch.Tensor,
node_feats: torch.Tensor,
edge_attrs: torch.Tensor,
edge_feats: torch.Tensor,
edge_index: torch.Tensor,
lammps_class: Optional[Any] = None,
lammps_natoms: Tuple[int, int] = (0, 0),
first_layer: bool = False,
) -> Tuple[torch.Tensor, None]:
sender = edge_index[0]
receiver = edge_index[1]
num_nodes = node_feats.shape[0]
n_real = lammps_natoms[0] if lammps_class is not None else None
node_feats = self.linear_up(node_feats)
node_feats = self.handle_lammps(
node_feats,
lammps_class=lammps_class,
lammps_natoms=lammps_natoms,
first_layer=first_layer,
)
tp_weights = self.conv_tp_weights(edge_feats)
edge_density = torch.tanh(self.density_fn(edge_feats) ** 2)
mji = self.conv_tp(
node_feats[sender], edge_attrs, tp_weights
) # [n_edges, irreps]
density = scatter_sum(
src=edge_density, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, 1]
message = scatter_sum(
src=mji, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, irreps]
message = self.truncate_ghosts(message, n_real)
node_attrs = self.truncate_ghosts(node_attrs, n_real)
density = self.truncate_ghosts(density, n_real)
message = self.linear(message) / (density + 1)
message = self.skip_tp(message, node_attrs)
return (
self.reshape(message),
None,
) # [n_nodes, channels, (lmax + 1)**2]
@compile_mode("script")
class RealAgnosticDensityResidualInteractionBlock(InteractionBlock):
def _setup(self) -> None:
if not hasattr(self, "cueq_config"):
self.cueq_config = None
# First linear
self.linear_up = Linear(
self.node_feats_irreps,
self.node_feats_irreps,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
# TensorProduct
irreps_mid, instructions = tp_out_irreps_with_instructions(
self.node_feats_irreps,
self.edge_attrs_irreps,
self.target_irreps,
)
self.conv_tp = TensorProduct(
self.node_feats_irreps,
self.edge_attrs_irreps,
irreps_mid,
instructions=instructions,
shared_weights=False,
internal_weights=False,
cueq_config=self.cueq_config,
)
# Convolution weights
input_dim = self.edge_feats_irreps.num_irreps
self.conv_tp_weights = nn.FullyConnectedNet(
[input_dim] + self.radial_MLP + [self.conv_tp.weight_numel],
torch.nn.functional.silu, # gate
)
# Linear
self.irreps_out = self.target_irreps
self.linear = Linear(
irreps_mid,
self.irreps_out,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
# Selector TensorProduct
self.skip_tp = FullyConnectedTensorProduct(
self.node_feats_irreps,
self.node_attrs_irreps,
self.hidden_irreps,
cueq_config=self.cueq_config,
)
# Density normalization
self.density_fn = nn.FullyConnectedNet(
[input_dim]
+ [
1,
],
torch.nn.functional.silu,
)
# Reshape
self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config)
def forward(
self,
node_attrs: torch.Tensor,
node_feats: torch.Tensor,
edge_attrs: torch.Tensor,
edge_feats: torch.Tensor,
edge_index: torch.Tensor,
lammps_class: Optional[Any] = None,
lammps_natoms: Tuple[int, int] = (0, 0),
first_layer: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
sender = edge_index[0]
receiver = edge_index[1]
num_nodes = node_feats.shape[0]
n_real = lammps_natoms[0] if lammps_class is not None else None
sc = self.skip_tp(node_feats, node_attrs)
node_feats = self.linear_up(node_feats)
node_feats = self.handle_lammps(
node_feats,
lammps_class=lammps_class,
lammps_natoms=lammps_natoms,
first_layer=first_layer,
)
tp_weights = self.conv_tp_weights(edge_feats)
edge_density = torch.tanh(self.density_fn(edge_feats) ** 2)
mji = self.conv_tp(
node_feats[sender], edge_attrs, tp_weights
) # [n_edges, irreps]
density = scatter_sum(
src=edge_density, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, 1]
message = scatter_sum(
src=mji, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, irreps]
message = self.truncate_ghosts(message, n_real)
node_attrs = self.truncate_ghosts(node_attrs, n_real)
density = self.truncate_ghosts(density, n_real)
sc = self.truncate_ghosts(sc, n_real)
message = self.linear(message) / (density + 1)
return (
self.reshape(message),
sc,
) # [n_nodes, channels, (lmax + 1)**2]
@compile_mode("script")
class RealAgnosticAttResidualInteractionBlock(InteractionBlock):
def _setup(self) -> None:
if not hasattr(self, "cueq_config"):
self.cueq_config = None
self.node_feats_down_irreps = o3.Irreps("64x0e")
# First linear
self.linear_up = Linear(
self.node_feats_irreps,
self.node_feats_irreps,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
# TensorProduct
irreps_mid, instructions = tp_out_irreps_with_instructions(
self.node_feats_irreps,
self.edge_attrs_irreps,
self.target_irreps,
)
self.conv_tp = TensorProduct(
self.node_feats_irreps,
self.edge_attrs_irreps,
irreps_mid,
instructions=instructions,
shared_weights=False,
internal_weights=False,
cueq_config=self.cueq_config,
)
# Convolution weights
self.linear_down = Linear(
self.node_feats_irreps,
self.node_feats_down_irreps,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
input_dim = (
self.edge_feats_irreps.num_irreps
+ 2 * self.node_feats_down_irreps.num_irreps
)
self.conv_tp_weights = nn.FullyConnectedNet(
[input_dim] + 3 * [256] + [self.conv_tp.weight_numel],
torch.nn.functional.silu,
)
# Linear
self.irreps_out = self.target_irreps
self.linear = Linear(
irreps_mid,
self.irreps_out,
internal_weights=True,
shared_weights=True,
cueq_config=self.cueq_config,
)
self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config)
# Skip connection.
self.skip_linear = Linear(
self.node_feats_irreps, self.hidden_irreps, cueq_config=self.cueq_config
)
# pylint: disable=unused-argument
def forward(
self,
node_attrs: torch.Tensor,
node_feats: torch.Tensor,
edge_attrs: torch.Tensor,
edge_feats: torch.Tensor,
edge_index: torch.Tensor,
lammps_class: Optional[Any] = None,
lammps_natoms: Tuple[int, int] = (0, 0),
first_layer: bool = False,
) -> Tuple[torch.Tensor, None]:
sender = edge_index[0]
receiver = edge_index[1]
num_nodes = node_feats.shape[0]
sc = self.skip_linear(node_feats)
node_feats_up = self.linear_up(node_feats)
node_feats_down = self.linear_down(node_feats)
augmented_edge_feats = torch.cat(
[
edge_feats,
node_feats_down[sender],
node_feats_down[receiver],
],
dim=-1,
)
tp_weights = self.conv_tp_weights(augmented_edge_feats)
mji = self.conv_tp(
node_feats_up[sender], edge_attrs, tp_weights
) # [n_edges, irreps]
message = scatter_sum(
src=mji, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, irreps]
message = self.linear(message) / self.avg_num_neighbors
return (
self.reshape(message),
sc,
) # [n_nodes, channels, (lmax + 1)**2]
@compile_mode("script")
class ScaleShiftBlock(torch.nn.Module):
def __init__(self, scale: float, shift: float):
super().__init__()
self.register_buffer(
"scale",
torch.tensor(scale, dtype=torch.get_default_dtype()),
)
self.register_buffer(
"shift",
torch.tensor(shift, dtype=torch.get_default_dtype()),
)
def forward(self, x: torch.Tensor, head: torch.Tensor) -> torch.Tensor:
return (
torch.atleast_1d(self.scale)[head] * x + torch.atleast_1d(self.shift)[head]
)
def __repr__(self):
formatted_scale = (
", ".join([f"{x:.4f}" for x in self.scale])
if self.scale.numel() > 1
else f"{self.scale.item():.4f}"
)
formatted_shift = (
", ".join([f"{x:.4f}" for x in self.shift])
if self.shift.numel() > 1
else f"{self.shift.item():.4f}"
)
return f"{self.__class__.__name__}(scale={formatted_scale}, shift={formatted_shift})"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment