###########################################################################################
# Elementary Block for Building O(3) Equivariant Higher Order Message Passing Neural Network
# Authors: Ilyes Batatia, Gregor Simm
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################

from abc import abstractmethod
from typing import Any, Callable, List, Optional, Tuple, Union

import numpy as np
import torch.nn.functional
from e3nn import nn, o3
from e3nn.util.jit import compile_mode

from mace.modules.wrapper_ops import (
    CuEquivarianceConfig,
    FullyConnectedTensorProduct,
    Linear,
    SymmetricContractionWrapper,
    TensorProduct,
)
from mace.tools.compile import simplify_if_compile
from mace.tools.scatter import scatter_sum
from mace.tools.utils import LAMMPS_MP

from .irreps_tools import mask_head, reshape_irreps, tp_out_irreps_with_instructions
from .radial import (
    AgnesiTransform,
    BesselBasis,
    ChebychevBasis,
    GaussianBasis,
    PolynomialCutoff,
    SoftTransform,
)


@compile_mode("script")
class LinearNodeEmbeddingBlock(torch.nn.Module):
    def __init__(
        self,
        irreps_in: o3.Irreps,
        irreps_out: o3.Irreps,
        cueq_config: Optional[CuEquivarianceConfig] = None,
    ):
        super().__init__()
        self.linear = Linear(
            irreps_in=irreps_in, irreps_out=irreps_out, cueq_config=cueq_config
        )

    def forward(
        self,
        node_attrs: torch.Tensor,
    ) -> torch.Tensor:  # [n_nodes, irreps]
        return self.linear(node_attrs)


@compile_mode("script")
class LinearReadoutBlock(torch.nn.Module):
    def __init__(
        self,
        irreps_in: o3.Irreps,
        irrep_out: o3.Irreps = o3.Irreps("0e"),
        cueq_config: Optional[CuEquivarianceConfig] = None,
    ):
        super().__init__()
        self.linear = Linear(
            irreps_in=irreps_in, irreps_out=irrep_out, cueq_config=cueq_config
        )

    def forward(
        self,
        x: torch.Tensor,
        heads: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
    ) -> torch.Tensor:  # [n_nodes, irreps]  # [..., ]
        return self.linear(x)  # [n_nodes, 1]


@simplify_if_compile
@compile_mode("script")
class NonLinearReadoutBlock(torch.nn.Module):
    def __init__(
        self,
        irreps_in: o3.Irreps,
        MLP_irreps: o3.Irreps,
        gate: Optional[Callable],
        irrep_out: o3.Irreps = o3.Irreps("0e"),
        num_heads: int = 1,
        cueq_config: Optional[CuEquivarianceConfig] = None,
    ):
        super().__init__()
        self.hidden_irreps = MLP_irreps
        self.num_heads = num_heads
        self.linear_1 = Linear(
            irreps_in=irreps_in, irreps_out=self.hidden_irreps, cueq_config=cueq_config
        )
        self.non_linearity = nn.Activation(irreps_in=self.hidden_irreps, acts=[gate])
        self.linear_2 = Linear(
            irreps_in=self.hidden_irreps, irreps_out=irrep_out, cueq_config=cueq_config
        )

    def forward(
        self, x: torch.Tensor, heads: Optional[torch.Tensor] = None
    ) -> torch.Tensor:  # [n_nodes, irreps]  # [..., ]
        x = self.non_linearity(self.linear_1(x))
        if hasattr(self, "num_heads"):
            if self.num_heads > 1 and heads is not None:
                x = mask_head(x, heads, self.num_heads)
        return self.linear_2(x)  # [n_nodes, len(heads)]


@compile_mode("script")
class LinearDipoleReadoutBlock(torch.nn.Module):
    def __init__(
        self,
        irreps_in: o3.Irreps,
        dipole_only: bool = False,
        cueq_config: Optional[CuEquivarianceConfig] = None,
    ):
        super().__init__()
        if dipole_only:
            self.irreps_out = o3.Irreps("1x1o")
        else:
            self.irreps_out = o3.Irreps("1x0e + 1x1o")
        self.linear = Linear(
            irreps_in=irreps_in, irreps_out=self.irreps_out, cueq_config=cueq_config
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # [n_nodes, irreps]  # [..., ]
        return self.linear(x)  # [n_nodes, 1]
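
# Usage sketch (editor's addition, not part of the original module; irreps and
# shapes are illustrative only): the readout blocks map per-node features to
# per-node scalars, assuming the input width matches the irreps dimension.
#
#     readout = LinearReadoutBlock(irreps_in=o3.Irreps("16x0e"))
#     node_energies = readout(torch.randn(5, 16))  # -> [5, 1]
#
#     mlp_readout = NonLinearReadoutBlock(
#         irreps_in=o3.Irreps("16x0e"),
#         MLP_irreps=o3.Irreps("16x0e"),
#         gate=torch.nn.functional.silu,
#     )
#     node_energies = mlp_readout(torch.randn(5, 16))  # -> [5, 1]
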

@compile_mode("script")
class NonLinearDipoleReadoutBlock(torch.nn.Module):
    def __init__(
        self,
        irreps_in: o3.Irreps,
        MLP_irreps: o3.Irreps,
        gate: Callable,
        dipole_only: bool = False,
        cueq_config: Optional[CuEquivarianceConfig] = None,
    ):
        super().__init__()
        self.hidden_irreps = MLP_irreps
        if dipole_only:
            self.irreps_out = o3.Irreps("1x1o")
        else:
            self.irreps_out = o3.Irreps("1x0e + 1x1o")
        irreps_scalars = o3.Irreps(
            [(mul, ir) for mul, ir in MLP_irreps if ir.l == 0 and ir in self.irreps_out]
        )
        irreps_gated = o3.Irreps(
            [(mul, ir) for mul, ir in MLP_irreps if ir.l > 0 and ir in self.irreps_out]
        )
        irreps_gates = o3.Irreps([mul, "0e"] for mul, _ in irreps_gated)
        self.equivariant_nonlin = nn.Gate(
            irreps_scalars=irreps_scalars,
            act_scalars=[gate for _, ir in irreps_scalars],
            irreps_gates=irreps_gates,
            act_gates=[gate] * len(irreps_gates),
            irreps_gated=irreps_gated,
        )
        self.irreps_nonlin = self.equivariant_nonlin.irreps_in.simplify()
        self.linear_1 = Linear(
            irreps_in=irreps_in, irreps_out=self.irreps_nonlin, cueq_config=cueq_config
        )
        self.linear_2 = Linear(
            irreps_in=self.hidden_irreps,
            irreps_out=self.irreps_out,
            cueq_config=cueq_config,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # [n_nodes, irreps]  # [..., ]
        x = self.equivariant_nonlin(self.linear_1(x))
        return self.linear_2(x)  # [n_nodes, 1]


@compile_mode("script")
class AtomicEnergiesBlock(torch.nn.Module):
    atomic_energies: torch.Tensor

    def __init__(self, atomic_energies: Union[np.ndarray, torch.Tensor]):
        super().__init__()
        # assert len(atomic_energies.shape) == 1
        self.register_buffer(
            "atomic_energies",
            torch.tensor(atomic_energies, dtype=torch.get_default_dtype()),
        )  # [n_elements, n_heads]

    def forward(
        self, x: torch.Tensor  # one-hot of elements [..., n_elements]
    ) -> torch.Tensor:  # [..., ]
        return torch.matmul(x, torch.atleast_2d(self.atomic_energies).T)

    def __repr__(self):
        formatted_energies = ", ".join(
            [
                "[" + ", ".join([f"{x:.4f}" for x in group]) + "]"
                for group in torch.atleast_2d(self.atomic_energies)
            ]
        )
        return f"{self.__class__.__name__}(energies=[{formatted_energies}])"


@compile_mode("script")
class RadialEmbeddingBlock(torch.nn.Module):
    def __init__(
        self,
        r_max: float,
        num_bessel: int,
        num_polynomial_cutoff: int,
        radial_type: str = "bessel",
        distance_transform: str = "None",
    ):
        super().__init__()
        if radial_type == "bessel":
            self.bessel_fn = BesselBasis(r_max=r_max, num_basis=num_bessel)
        elif radial_type == "gaussian":
            self.bessel_fn = GaussianBasis(r_max=r_max, num_basis=num_bessel)
        elif radial_type == "chebyshev":
            self.bessel_fn = ChebychevBasis(r_max=r_max, num_basis=num_bessel)
        if distance_transform == "Agnesi":
            self.distance_transform = AgnesiTransform()
        elif distance_transform == "Soft":
            self.distance_transform = SoftTransform()
        self.cutoff_fn = PolynomialCutoff(r_max=r_max, p=num_polynomial_cutoff)
        self.out_dim = num_bessel

    def forward(
        self,
        edge_lengths: torch.Tensor,  # [n_edges, 1]
        node_attrs: torch.Tensor,
        edge_index: torch.Tensor,
        atomic_numbers: torch.Tensor,
    ):
        cutoff = self.cutoff_fn(edge_lengths)  # [n_edges, 1]
        if hasattr(self, "distance_transform"):
            edge_lengths = self.distance_transform(
                edge_lengths, node_attrs, edge_index, atomic_numbers
            )
        radial = self.bessel_fn(edge_lengths)  # [n_edges, n_basis]
        return radial * cutoff  # [n_edges, n_basis]
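
# Usage sketch (editor's addition; values are illustrative): RadialEmbeddingBlock
# maps edge lengths to a smooth finite basis that decays to zero at r_max, e.g.
# 8 Bessel functions inside a 5 Å cutoff give a [n_edges, 8] output. Note that
# node_attrs, edge_index, and atomic_numbers are required arguments but are only
# consumed when a distance transform is configured.
#
#     radial = RadialEmbeddingBlock(r_max=5.0, num_bessel=8, num_polynomial_cutoff=5)
#     edge_feats = radial(
#         edge_lengths,      # [n_edges, 1]
#         node_attrs,        # one-hot elements
#         edge_index,        # [2, n_edges]
#         atomic_numbers,
#     )                      # -> [n_edges, 8]
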

@compile_mode("script")
class EquivariantProductBasisBlock(torch.nn.Module):
    def __init__(
        self,
        node_feats_irreps: o3.Irreps,
        target_irreps: o3.Irreps,
        correlation: int,
        use_sc: bool = True,
        num_elements: Optional[int] = None,
        cueq_config: Optional[CuEquivarianceConfig] = None,
    ) -> None:
        super().__init__()
        self.use_sc = use_sc
        self.symmetric_contractions = SymmetricContractionWrapper(
            irreps_in=node_feats_irreps,
            irreps_out=target_irreps,
            correlation=correlation,
            num_elements=num_elements,
            cueq_config=cueq_config,
        )
        # Update linear
        self.linear = Linear(
            target_irreps,
            target_irreps,
            internal_weights=True,
            shared_weights=True,
            cueq_config=cueq_config,
        )
        self.cueq_config = cueq_config

    def forward(
        self,
        node_feats: torch.Tensor,
        sc: Optional[torch.Tensor],
        node_attrs: torch.Tensor,
    ) -> torch.Tensor:
        use_cueq = False
        use_cueq_mul_ir = False
        if hasattr(self, "cueq_config"):
            if self.cueq_config is not None:
                if self.cueq_config.enabled and (
                    self.cueq_config.optimize_all
                    or self.cueq_config.optimize_symmetric
                ):
                    use_cueq = True
                    if self.cueq_config.layout_str == "mul_ir":
                        use_cueq_mul_ir = True
        if use_cueq:
            if use_cueq_mul_ir:
                node_feats = torch.transpose(node_feats, 1, 2)
            index_attrs = torch.nonzero(node_attrs)[:, 1].int()
            node_feats = self.symmetric_contractions(
                node_feats.flatten(1),
                index_attrs,
            )
        else:
            node_feats = self.symmetric_contractions(node_feats, node_attrs)
        if self.use_sc and sc is not None:
            return self.linear(node_feats) + sc
        return self.linear(node_feats)


@compile_mode("script")
class InteractionBlock(torch.nn.Module):
    def __init__(
        self,
        node_attrs_irreps: o3.Irreps,
        node_feats_irreps: o3.Irreps,
        edge_attrs_irreps: o3.Irreps,
        edge_feats_irreps: o3.Irreps,
        target_irreps: o3.Irreps,
        hidden_irreps: o3.Irreps,
        avg_num_neighbors: float,
        radial_MLP: Optional[List[int]] = None,
        cueq_config: Optional[CuEquivarianceConfig] = None,
    ) -> None:
        super().__init__()
        self.node_attrs_irreps = node_attrs_irreps
        self.node_feats_irreps = node_feats_irreps
        self.edge_attrs_irreps = edge_attrs_irreps
        self.edge_feats_irreps = edge_feats_irreps
        self.target_irreps = target_irreps
        self.hidden_irreps = hidden_irreps
        self.avg_num_neighbors = avg_num_neighbors
        if radial_MLP is None:
            radial_MLP = [64, 64, 64]
        self.radial_MLP = radial_MLP
        self.cueq_config = cueq_config
        self._setup()

    @abstractmethod
    def _setup(self) -> None:
        raise NotImplementedError

    def handle_lammps(
        self,
        node_feats: torch.Tensor,
        lammps_class: Optional[Any],
        lammps_natoms: Tuple[int, int],
        first_layer: bool,
    ) -> torch.Tensor:  # noqa: D401 - internal helper
        if lammps_class is None or first_layer or torch.jit.is_scripting():
            return node_feats
        _, n_total = lammps_natoms
        pad = torch.zeros(
            (n_total, node_feats.shape[1]),
            dtype=node_feats.dtype,
            device=node_feats.device,
        )
        node_feats = torch.cat((node_feats, pad), dim=0)
        node_feats = LAMMPS_MP.apply(node_feats, lammps_class)
        return node_feats

    def truncate_ghosts(
        self, tensor: torch.Tensor, n_real: Optional[int] = None
    ) -> torch.Tensor:
        """Truncate the tensor to keep only the real atoms when ghost atoms
        are present during multi-GPU MD simulations."""
        return tensor[:n_real] if n_real is not None else tensor

    @abstractmethod
    def forward(
        self,
        node_attrs: torch.Tensor,
        node_feats: torch.Tensor,
        edge_attrs: torch.Tensor,
        edge_feats: torch.Tensor,
        edge_index: torch.Tensor,
    ) -> torch.Tensor:
        raise NotImplementedError


nonlinearities = {1: torch.nn.functional.silu, -1: torch.tanh}
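
# Editor's note (schematic, not from the original source): every interaction
# block below implements one message-passing step of the form
#
#     m_i = (1 / Z) * sum_{j in N(i)} TP( h_j, Y(r_ij_hat); R(r_ij) ),
#
# where h_j are the sender node features (after linear_up), Y are the
# spherical-harmonic edge attributes, R is the radial MLP conv_tp_weights
# applied to the radial embedding of r_ij, TP is the weighted tensor product
# conv_tp, the sum over neighbors is realized by scatter_sum over receivers,
# and Z is either avg_num_neighbors or a learned local density (see the
# Density variants further down).
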

@compile_mode("script")
class RealAgnosticInteractionBlock(InteractionBlock):
    def _setup(self) -> None:
        if not hasattr(self, "cueq_config"):
            self.cueq_config = None
        # First linear
        self.linear_up = Linear(
            self.node_feats_irreps,
            self.node_feats_irreps,
            internal_weights=True,
            shared_weights=True,
            cueq_config=self.cueq_config,
        )
        # TensorProduct
        irreps_mid, instructions = tp_out_irreps_with_instructions(
            self.node_feats_irreps,
            self.edge_attrs_irreps,
            self.target_irreps,
        )
        self.conv_tp = TensorProduct(
            self.node_feats_irreps,
            self.edge_attrs_irreps,
            irreps_mid,
            instructions=instructions,
            shared_weights=False,
            internal_weights=False,
            cueq_config=self.cueq_config,
        )
        # Convolution weights
        input_dim = self.edge_feats_irreps.num_irreps
        self.conv_tp_weights = nn.FullyConnectedNet(
            [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel],
            torch.nn.functional.silu,
        )
        # Linear
        self.irreps_out = self.target_irreps
        self.linear = Linear(
            irreps_mid,
            self.irreps_out,
            internal_weights=True,
            shared_weights=True,
            cueq_config=self.cueq_config,
        )
        # Selector TensorProduct
        self.skip_tp = FullyConnectedTensorProduct(
            self.irreps_out,
            self.node_attrs_irreps,
            self.irreps_out,
            cueq_config=self.cueq_config,
        )
        self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config)

    def forward(
        self,
        node_attrs: torch.Tensor,
        node_feats: torch.Tensor,
        edge_attrs: torch.Tensor,
        edge_feats: torch.Tensor,
        edge_index: torch.Tensor,
        lammps_natoms: Tuple[int, int] = (0, 0),
        lammps_class: Optional[Any] = None,
        first_layer: bool = False,
    ) -> Tuple[torch.Tensor, None]:
        sender = edge_index[0]
        receiver = edge_index[1]
        num_nodes = node_feats.shape[0]
        n_real = lammps_natoms[0] if lammps_class is not None else None
        node_feats = self.linear_up(node_feats)
        node_feats = self.handle_lammps(
            node_feats,
            lammps_class=lammps_class,
            lammps_natoms=lammps_natoms,
            first_layer=first_layer,
        )
        tp_weights = self.conv_tp_weights(edge_feats)
        mji = self.conv_tp(
            node_feats[sender], edge_attrs, tp_weights
        )  # [n_edges, irreps]
        message = scatter_sum(
            src=mji, index=receiver, dim=0, dim_size=num_nodes
        )  # [n_nodes, irreps]
        message = self.truncate_ghosts(message, n_real)
        node_attrs = self.truncate_ghosts(node_attrs, n_real)
        message = self.linear(message) / self.avg_num_neighbors
        message = self.skip_tp(message, node_attrs)
        return (
            self.reshape(message),
            None,
        )  # [n_nodes, channels, (lmax + 1)**2]
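
# Construction sketch (editor's addition; the irreps and avg_num_neighbors are
# chosen only for illustration and must match the tensors passed to forward):
#
#     interaction = RealAgnosticInteractionBlock(
#         node_attrs_irreps=o3.Irreps("10x0e"),             # one-hot of 10 elements
#         node_feats_irreps=o3.Irreps("32x0e"),
#         edge_attrs_irreps=o3.Irreps.spherical_harmonics(3),
#         edge_feats_irreps=o3.Irreps("8x0e"),              # radial basis size
#         target_irreps=o3.Irreps("32x0e + 32x1o + 32x2e + 32x3o"),
#         hidden_irreps=o3.Irreps("32x0e + 32x1o"),
#         avg_num_neighbors=12.0,
#     )
#     message, _ = interaction(node_attrs, node_feats, edge_attrs, edge_feats, edge_index)
#     # message: [n_nodes, 32, (3 + 1) ** 2]; the second output is None here,
#     # whereas the residual variants return the skip connection instead.
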

@compile_mode("script")
class RealAgnosticResidualInteractionBlock(InteractionBlock):
    def _setup(self) -> None:
        if not hasattr(self, "cueq_config"):
            self.cueq_config = None
        # First linear
        self.linear_up = Linear(
            self.node_feats_irreps,
            self.node_feats_irreps,
            internal_weights=True,
            shared_weights=True,
            cueq_config=self.cueq_config,
        )
        # TensorProduct
        irreps_mid, instructions = tp_out_irreps_with_instructions(
            self.node_feats_irreps,
            self.edge_attrs_irreps,
            self.target_irreps,
        )
        self.conv_tp = TensorProduct(
            self.node_feats_irreps,
            self.edge_attrs_irreps,
            irreps_mid,
            instructions=instructions,
            shared_weights=False,
            internal_weights=False,
            cueq_config=self.cueq_config,
        )
        # Convolution weights
        input_dim = self.edge_feats_irreps.num_irreps
        self.conv_tp_weights = nn.FullyConnectedNet(
            [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel],
            torch.nn.functional.silu,  # gate
        )
        # Linear
        self.irreps_out = self.target_irreps
        self.linear = Linear(
            irreps_mid,
            self.irreps_out,
            internal_weights=True,
            shared_weights=True,
            cueq_config=self.cueq_config,
        )
        # Selector TensorProduct
        self.skip_tp = FullyConnectedTensorProduct(
            self.node_feats_irreps,
            self.node_attrs_irreps,
            self.hidden_irreps,
            cueq_config=self.cueq_config,
        )
        self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config)

    def forward(
        self,
        node_attrs: torch.Tensor,
        node_feats: torch.Tensor,
        edge_attrs: torch.Tensor,
        edge_feats: torch.Tensor,
        edge_index: torch.Tensor,
        lammps_class: Optional[Any] = None,
        lammps_natoms: Tuple[int, int] = (0, 0),
        first_layer: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        sender = edge_index[0]
        receiver = edge_index[1]
        num_nodes = node_feats.shape[0]
        n_real = lammps_natoms[0] if lammps_class is not None else None
        sc = self.skip_tp(node_feats, node_attrs)
        node_feats = self.linear_up(node_feats)
        node_feats = self.handle_lammps(
            node_feats,
            lammps_class=lammps_class,
            lammps_natoms=lammps_natoms,
            first_layer=first_layer,
        )
        tp_weights = self.conv_tp_weights(edge_feats)
        mji = self.conv_tp(
            node_feats[sender], edge_attrs, tp_weights
        )  # [n_edges, irreps]
        message = scatter_sum(
            src=mji, index=receiver, dim=0, dim_size=num_nodes
        )  # [n_nodes, irreps]
        message = self.truncate_ghosts(message, n_real)
        node_attrs = self.truncate_ghosts(node_attrs, n_real)
        sc = self.truncate_ghosts(sc, n_real)
        message = self.linear(message) / self.avg_num_neighbors
        return (
            self.reshape(message),
            sc,
        )  # [n_nodes, channels, (lmax + 1)**2]


@compile_mode("script")
class RealAgnosticDensityInteractionBlock(InteractionBlock):
    def _setup(self) -> None:
        if not hasattr(self, "cueq_config"):
            self.cueq_config = None
        # First linear
        self.linear_up = Linear(
            self.node_feats_irreps,
            self.node_feats_irreps,
            internal_weights=True,
            shared_weights=True,
            cueq_config=self.cueq_config,
        )
        # TensorProduct
        irreps_mid, instructions = tp_out_irreps_with_instructions(
            self.node_feats_irreps,
            self.edge_attrs_irreps,
            self.target_irreps,
        )
        self.conv_tp = TensorProduct(
            self.node_feats_irreps,
            self.edge_attrs_irreps,
            irreps_mid,
            instructions=instructions,
            shared_weights=False,
            internal_weights=False,
            cueq_config=self.cueq_config,
        )
        # Convolution weights
        input_dim = self.edge_feats_irreps.num_irreps
        self.conv_tp_weights = nn.FullyConnectedNet(
            [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel],
            torch.nn.functional.silu,
        )
        # Linear
        self.irreps_out = self.target_irreps
        self.linear = Linear(
            irreps_mid,
            self.irreps_out,
            internal_weights=True,
            shared_weights=True,
            cueq_config=self.cueq_config,
        )
        # Selector TensorProduct
        self.skip_tp = FullyConnectedTensorProduct(
            self.irreps_out,
            self.node_attrs_irreps,
            self.irreps_out,
            cueq_config=self.cueq_config,
        )
        # Density normalization
        self.density_fn = nn.FullyConnectedNet(
            [input_dim] + [1],
            torch.nn.functional.silu,
        )
        # Reshape
        self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config)

    def forward(
        self,
        node_attrs: torch.Tensor,
        node_feats: torch.Tensor,
        edge_attrs: torch.Tensor,
        edge_feats: torch.Tensor,
        edge_index: torch.Tensor,
        lammps_class: Optional[Any] = None,
        lammps_natoms: Tuple[int, int] = (0, 0),
        first_layer: bool = False,
    ) -> Tuple[torch.Tensor, None]:
        sender = edge_index[0]
        receiver = edge_index[1]
        num_nodes = node_feats.shape[0]
        n_real = lammps_natoms[0] if lammps_class is not None else None
        node_feats = self.linear_up(node_feats)
        node_feats = self.handle_lammps(
            node_feats,
            lammps_class=lammps_class,
            lammps_natoms=lammps_natoms,
            first_layer=first_layer,
        )
        tp_weights = self.conv_tp_weights(edge_feats)
        edge_density = torch.tanh(self.density_fn(edge_feats) ** 2)
        mji = self.conv_tp(
            node_feats[sender], edge_attrs, tp_weights
        )  # [n_edges, irreps]
        density = scatter_sum(
            src=edge_density, index=receiver, dim=0, dim_size=num_nodes
        )  # [n_nodes, 1]
        message = scatter_sum(
            src=mji, index=receiver, dim=0, dim_size=num_nodes
        )  # [n_nodes, irreps]
        message = self.truncate_ghosts(message, n_real)
        node_attrs = self.truncate_ghosts(node_attrs, n_real)
        density = self.truncate_ghosts(density, n_real)
        message = self.linear(message) / (density + 1)
        message = self.skip_tp(message, node_attrs)
        return (
            self.reshape(message),
            None,
        )  # [n_nodes, channels, (lmax + 1)**2]
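
# Note on the density variants (editor's comment): instead of dividing the
# aggregated message by the fixed avg_num_neighbors, the message is divided by
# a learned local density,
#
#     d_i = sum_{j in N(i)} tanh( f(R_ij)^2 )   with each term in [0, 1),
#     m_i <- m_i / (d_i + 1),
#
# where f is the scalar MLP self.density_fn applied to the radial edge
# features. The "+ 1" keeps the normalization finite for isolated atoms
# with d_i = 0.
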

@compile_mode("script")
class RealAgnosticDensityResidualInteractionBlock(InteractionBlock):
    def _setup(self) -> None:
        if not hasattr(self, "cueq_config"):
            self.cueq_config = None
        # First linear
        self.linear_up = Linear(
            self.node_feats_irreps,
            self.node_feats_irreps,
            internal_weights=True,
            shared_weights=True,
            cueq_config=self.cueq_config,
        )
        # TensorProduct
        irreps_mid, instructions = tp_out_irreps_with_instructions(
            self.node_feats_irreps,
            self.edge_attrs_irreps,
            self.target_irreps,
        )
        self.conv_tp = TensorProduct(
            self.node_feats_irreps,
            self.edge_attrs_irreps,
            irreps_mid,
            instructions=instructions,
            shared_weights=False,
            internal_weights=False,
            cueq_config=self.cueq_config,
        )
        # Convolution weights
        input_dim = self.edge_feats_irreps.num_irreps
        self.conv_tp_weights = nn.FullyConnectedNet(
            [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel],
            torch.nn.functional.silu,  # gate
        )
        # Linear
        self.irreps_out = self.target_irreps
        self.linear = Linear(
            irreps_mid,
            self.irreps_out,
            internal_weights=True,
            shared_weights=True,
            cueq_config=self.cueq_config,
        )
        # Selector TensorProduct
        self.skip_tp = FullyConnectedTensorProduct(
            self.node_feats_irreps,
            self.node_attrs_irreps,
            self.hidden_irreps,
            cueq_config=self.cueq_config,
        )
        # Density normalization
        self.density_fn = nn.FullyConnectedNet(
            [input_dim] + [1],
            torch.nn.functional.silu,
        )
        # Reshape
        self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config)

    def forward(
        self,
        node_attrs: torch.Tensor,
        node_feats: torch.Tensor,
        edge_attrs: torch.Tensor,
        edge_feats: torch.Tensor,
        edge_index: torch.Tensor,
        lammps_class: Optional[Any] = None,
        lammps_natoms: Tuple[int, int] = (0, 0),
        first_layer: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        sender = edge_index[0]
        receiver = edge_index[1]
        num_nodes = node_feats.shape[0]
        n_real = lammps_natoms[0] if lammps_class is not None else None
        sc = self.skip_tp(node_feats, node_attrs)
        node_feats = self.linear_up(node_feats)
        node_feats = self.handle_lammps(
            node_feats,
            lammps_class=lammps_class,
            lammps_natoms=lammps_natoms,
            first_layer=first_layer,
        )
        tp_weights = self.conv_tp_weights(edge_feats)
        edge_density = torch.tanh(self.density_fn(edge_feats) ** 2)
        mji = self.conv_tp(
            node_feats[sender], edge_attrs, tp_weights
        )  # [n_edges, irreps]
        density = scatter_sum(
            src=edge_density, index=receiver, dim=0, dim_size=num_nodes
        )  # [n_nodes, 1]
        message = scatter_sum(
            src=mji, index=receiver, dim=0, dim_size=num_nodes
        )  # [n_nodes, irreps]
        message = self.truncate_ghosts(message, n_real)
        node_attrs = self.truncate_ghosts(node_attrs, n_real)
        density = self.truncate_ghosts(density, n_real)
        sc = self.truncate_ghosts(sc, n_real)
        message = self.linear(message) / (density + 1)
        return (
            self.reshape(message),
            sc,
        )  # [n_nodes, channels, (lmax + 1)**2]
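
# Editor's note: like RealAgnosticResidualInteractionBlock, the density
# residual variant above computes the skip connection sc from the *input*
# node features and returns it alongside the message; downstream,
# EquivariantProductBasisBlock adds sc back after the symmetric contraction
# (see its forward above).
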

@compile_mode("script")
class RealAgnosticAttResidualInteractionBlock(InteractionBlock):
    def _setup(self) -> None:
        if not hasattr(self, "cueq_config"):
            self.cueq_config = None
        self.node_feats_down_irreps = o3.Irreps("64x0e")
        # First linear
        self.linear_up = Linear(
            self.node_feats_irreps,
            self.node_feats_irreps,
            internal_weights=True,
            shared_weights=True,
            cueq_config=self.cueq_config,
        )
        # TensorProduct
        irreps_mid, instructions = tp_out_irreps_with_instructions(
            self.node_feats_irreps,
            self.edge_attrs_irreps,
            self.target_irreps,
        )
        self.conv_tp = TensorProduct(
            self.node_feats_irreps,
            self.edge_attrs_irreps,
            irreps_mid,
            instructions=instructions,
            shared_weights=False,
            internal_weights=False,
            cueq_config=self.cueq_config,
        )
        # Convolution weights
        self.linear_down = Linear(
            self.node_feats_irreps,
            self.node_feats_down_irreps,
            internal_weights=True,
            shared_weights=True,
            cueq_config=self.cueq_config,
        )
        input_dim = (
            self.edge_feats_irreps.num_irreps
            + 2 * self.node_feats_down_irreps.num_irreps
        )
        self.conv_tp_weights = nn.FullyConnectedNet(
            [input_dim] + 3 * [256] + [self.conv_tp.weight_numel],
            torch.nn.functional.silu,
        )
        # Linear
        self.irreps_out = self.target_irreps
        self.linear = Linear(
            irreps_mid,
            self.irreps_out,
            internal_weights=True,
            shared_weights=True,
            cueq_config=self.cueq_config,
        )
        self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config)
        # Skip connection
        self.skip_linear = Linear(
            self.node_feats_irreps, self.hidden_irreps, cueq_config=self.cueq_config
        )

    # pylint: disable=unused-argument
    def forward(
        self,
        node_attrs: torch.Tensor,
        node_feats: torch.Tensor,
        edge_attrs: torch.Tensor,
        edge_feats: torch.Tensor,
        edge_index: torch.Tensor,
        lammps_class: Optional[Any] = None,
        lammps_natoms: Tuple[int, int] = (0, 0),
        first_layer: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        sender = edge_index[0]
        receiver = edge_index[1]
        num_nodes = node_feats.shape[0]
        sc = self.skip_linear(node_feats)
        node_feats_up = self.linear_up(node_feats)
        node_feats_down = self.linear_down(node_feats)
        augmented_edge_feats = torch.cat(
            [
                edge_feats,
                node_feats_down[sender],
                node_feats_down[receiver],
            ],
            dim=-1,
        )
        tp_weights = self.conv_tp_weights(augmented_edge_feats)
        mji = self.conv_tp(
            node_feats_up[sender], edge_attrs, tp_weights
        )  # [n_edges, irreps]
        message = scatter_sum(
            src=mji, index=receiver, dim=0, dim_size=num_nodes
        )  # [n_nodes, irreps]
        message = self.linear(message) / self.avg_num_neighbors
        return (
            self.reshape(message),
            sc,
        )  # [n_nodes, channels, (lmax + 1)**2]


@compile_mode("script")
class ScaleShiftBlock(torch.nn.Module):
    def __init__(self, scale: float, shift: float):
        super().__init__()
        self.register_buffer(
            "scale",
            torch.tensor(scale, dtype=torch.get_default_dtype()),
        )
        self.register_buffer(
            "shift",
            torch.tensor(shift, dtype=torch.get_default_dtype()),
        )

    def forward(self, x: torch.Tensor, head: torch.Tensor) -> torch.Tensor:
        return (
            torch.atleast_1d(self.scale)[head] * x + torch.atleast_1d(self.shift)[head]
        )

    def __repr__(self):
        formatted_scale = (
            ", ".join([f"{x:.4f}" for x in self.scale])
            if self.scale.numel() > 1
            else f"{self.scale.item():.4f}"
        )
        formatted_shift = (
            ", ".join([f"{x:.4f}" for x in self.shift])
            if self.shift.numel() > 1
            else f"{self.shift.item():.4f}"
        )
        return f"{self.__class__.__name__}(scale={formatted_scale}, shift={formatted_shift})"
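
# Usage sketch (editor's addition; values are illustrative): ScaleShiftBlock
# applies a per-head affine map to raw readout outputs. With scalar scale and
# shift, the head index simply selects entry 0.
#
#     scale_shift = ScaleShiftBlock(scale=0.8, shift=-1.5)
#     heads = torch.zeros(5, dtype=torch.long)     # all nodes on head 0
#     y = scale_shift(torch.randn(5), heads)       # scale[head] * x + shift[head], [5]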