###########################################################################################
# Implementation of MACE models and other models based E(3)-Equivariant MPNNs
# Authors: Ilyes Batatia, Gregor Simm
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################

from typing import Any, Callable, Dict, List, Optional, Type, Union

import numpy as np
import torch
from e3nn import o3
from e3nn.util.jit import compile_mode

from mace.modules.radial import ZBLBasis
from mace.tools.scatter import scatter_sum

from .blocks import (
    AtomicEnergiesBlock,
    EquivariantProductBasisBlock,
    InteractionBlock,
    LinearDipoleReadoutBlock,
    LinearNodeEmbeddingBlock,
    LinearReadoutBlock,
    NonLinearDipoleReadoutBlock,
    NonLinearReadoutBlock,
    RadialEmbeddingBlock,
    ScaleShiftBlock,
)
from .utils import (
    compute_fixed_charge_dipole,
    get_atomic_virials_stresses,
    get_edge_vectors_and_lengths,
    get_outputs,
    get_symmetric_displacement,
    prepare_graph,
)

# pylint: disable=C0302


@compile_mode("script")
class MACE(torch.nn.Module):
    """E(3)-equivariant message-passing model predicting total energies.

    The model stacks ``num_interactions`` interaction + equivariant-product
    layers, each followed by a readout; per-layer readout energies are summed
    together with the isolated-atom baseline (E0) and an optional ZBL pair
    repulsion term.
    """

    def __init__(
        self,
        r_max: float,
        num_bessel: int,
        num_polynomial_cutoff: int,
        max_ell: int,
        interaction_cls: Type[InteractionBlock],
        interaction_cls_first: Type[InteractionBlock],
        num_interactions: int,
        num_elements: int,
        hidden_irreps: o3.Irreps,
        MLP_irreps: o3.Irreps,
        atomic_energies: np.ndarray,
        avg_num_neighbors: float,
        atomic_numbers: List[int],
        correlation: Union[int, List[int]],
        gate: Optional[Callable],
        pair_repulsion: bool = False,
        distance_transform: str = "None",
        radial_MLP: Optional[List[int]] = None,
        radial_type: Optional[str] = "bessel",
        heads: Optional[List[str]] = None,
        cueq_config: Optional[Dict[str, Any]] = None,
        lammps_mliap: Optional[bool] = False,
    ):
        super().__init__()
        # Buffers (not parameters) so they are saved/moved with the module.
        self.register_buffer(
            "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64)
        )
        self.register_buffer(
            "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype())
        )
        self.register_buffer(
            "num_interactions", torch.tensor(num_interactions, dtype=torch.int64)
        )
        if heads is None:
            heads = ["Default"]
        self.heads = heads
        if isinstance(correlation, int):
            # Allow a single correlation order to apply to every layer.
            correlation = [correlation] * num_interactions
        self.lammps_mliap = lammps_mliap

        # Embedding
        node_attr_irreps = o3.Irreps([(num_elements, (0, 1))])
        node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))])
        self.node_embedding = LinearNodeEmbeddingBlock(
            irreps_in=node_attr_irreps,
            irreps_out=node_feats_irreps,
            cueq_config=cueq_config,
        )
        self.radial_embedding = RadialEmbeddingBlock(
            r_max=r_max,
            num_bessel=num_bessel,
            num_polynomial_cutoff=num_polynomial_cutoff,
            radial_type=radial_type,
            distance_transform=distance_transform,
        )
        edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e")
        if pair_repulsion:
            self.pair_repulsion_fn = ZBLBasis(p=num_polynomial_cutoff)
            # Flag attribute; forward() tests its presence with hasattr().
            self.pair_repulsion = True

        sh_irreps = o3.Irreps.spherical_harmonics(max_ell)
        num_features = hidden_irreps.count(o3.Irrep(0, 1))
        interaction_irreps = (sh_irreps * num_features).sort()[0].simplify()
        self.spherical_harmonics = o3.SphericalHarmonics(
            sh_irreps, normalize=True, normalization="component"
        )
        if radial_MLP is None:
            radial_MLP = [64, 64, 64]

        # Interactions and readout
        self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies)

        inter = interaction_cls_first(
            node_attrs_irreps=node_attr_irreps,
            node_feats_irreps=node_feats_irreps,
            edge_attrs_irreps=sh_irreps,
            edge_feats_irreps=edge_feats_irreps,
            target_irreps=interaction_irreps,
            hidden_irreps=hidden_irreps,
            avg_num_neighbors=avg_num_neighbors,
            radial_MLP=radial_MLP,
            cueq_config=cueq_config,
        )
        self.interactions = torch.nn.ModuleList([inter])

        # Use the appropriate self connection at the first layer for proper E0
        use_sc_first = False
        if "Residual" in str(interaction_cls_first):
            use_sc_first = True

        node_feats_irreps_out = inter.target_irreps
        prod = EquivariantProductBasisBlock(
            node_feats_irreps=node_feats_irreps_out,
            target_irreps=hidden_irreps,
            correlation=correlation[0],
            num_elements=num_elements,
            use_sc=use_sc_first,
            cueq_config=cueq_config,
        )
        self.products = torch.nn.ModuleList([prod])

        self.readouts = torch.nn.ModuleList()
        self.readouts.append(
            LinearReadoutBlock(
                hidden_irreps, o3.Irreps(f"{len(heads)}x0e"), cueq_config
            )
        )

        for i in range(num_interactions - 1):
            if i == num_interactions - 2:
                hidden_irreps_out = str(
                    hidden_irreps[0]
                )  # Select only scalars for last layer
            else:
                hidden_irreps_out = hidden_irreps
            inter = interaction_cls(
                node_attrs_irreps=node_attr_irreps,
                node_feats_irreps=hidden_irreps,
                edge_attrs_irreps=sh_irreps,
                edge_feats_irreps=edge_feats_irreps,
                target_irreps=interaction_irreps,
                hidden_irreps=hidden_irreps_out,
                avg_num_neighbors=avg_num_neighbors,
                radial_MLP=radial_MLP,
                cueq_config=cueq_config,
            )
            self.interactions.append(inter)
            prod = EquivariantProductBasisBlock(
                node_feats_irreps=interaction_irreps,
                target_irreps=hidden_irreps_out,
                correlation=correlation[i + 1],
                num_elements=num_elements,
                use_sc=True,
                cueq_config=cueq_config,
            )
            self.products.append(prod)
            if i == num_interactions - 2:
                # Last layer: non-linear readout on the scalar channel only.
                self.readouts.append(
                    NonLinearReadoutBlock(
                        hidden_irreps_out,
                        (len(heads) * MLP_irreps).simplify(),
                        gate,
                        o3.Irreps(f"{len(heads)}x0e"),
                        len(heads),
                        cueq_config,
                    )
                )
            else:
                self.readouts.append(
                    LinearReadoutBlock(
                        hidden_irreps, o3.Irreps(f"{len(heads)}x0e"), cueq_config
                    )
                )

    def forward(
        self,
        data: Dict[str, torch.Tensor],
        training: bool = False,
        compute_force: bool = True,
        compute_virials: bool = False,
        compute_stress: bool = False,
        compute_displacement: bool = False,
        compute_hessian: bool = False,
        compute_edge_forces: bool = False,
        compute_atomic_stresses: bool = False,
        lammps_mliap: bool = False,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Run the model on a batched graph.

        Returns a dict with per-graph ``energy``/``contributions``, per-node
        ``node_energy``/``node_feats``, and (when requested) gradients of the
        energy: forces, virials, stress, hessian, edge forces.
        """
        # Setup
        ctx = prepare_graph(
            data,
            compute_virials=compute_virials,
            compute_stress=compute_stress,
            compute_displacement=compute_displacement,
            lammps_mliap=lammps_mliap,
        )
        is_lammps = ctx.is_lammps
        num_atoms_arange = ctx.num_atoms_arange
        num_graphs = ctx.num_graphs
        displacement = ctx.displacement
        positions = ctx.positions
        vectors = ctx.vectors
        lengths = ctx.lengths
        cell = ctx.cell
        node_heads = ctx.node_heads
        interaction_kwargs = ctx.interaction_kwargs
        lammps_natoms = interaction_kwargs.lammps_natoms
        lammps_class = interaction_kwargs.lammps_class

        # Atomic energies
        node_e0 = self.atomic_energies_fn(data["node_attrs"])[
            num_atoms_arange, node_heads
        ]
        e0 = scatter_sum(
            src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs
        )  # [n_graphs, n_heads]

        # Embeddings
        node_feats = self.node_embedding(data["node_attrs"])
        edge_attrs = self.spherical_harmonics(vectors)
        edge_feats = self.radial_embedding(
            lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers
        )
        if hasattr(self, "pair_repulsion"):
            pair_node_energy = self.pair_repulsion_fn(
                lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers
            )
            if is_lammps:
                # LAMMPS ghost atoms carry no pair energy of their own.
                pair_node_energy = pair_node_energy[: lammps_natoms[0]]
            pair_energy = scatter_sum(
                src=pair_node_energy, index=data["batch"], dim=-1, dim_size=num_graphs
            )  # [n_graphs,]
        else:
            pair_node_energy = torch.zeros_like(node_e0)
            pair_energy = torch.zeros_like(e0)

        # Interactions
        energies = [e0, pair_energy]
        node_energies_list = [node_e0, pair_node_energy]
        node_feats_concat: List[torch.Tensor] = []

        for i, (interaction, product, readout) in enumerate(
            zip(self.interactions, self.products, self.readouts)
        ):
            node_attrs_slice = data["node_attrs"]
            if is_lammps and i > 0:
                node_attrs_slice = node_attrs_slice[: lammps_natoms[0]]
            node_feats, sc = interaction(
                node_attrs=node_attrs_slice,
                node_feats=node_feats,
                edge_attrs=edge_attrs,
                edge_feats=edge_feats,
                edge_index=data["edge_index"],
                first_layer=(i == 0),
                lammps_class=lammps_class,
                lammps_natoms=lammps_natoms,
            )
            if is_lammps and i == 0:
                node_attrs_slice = node_attrs_slice[: lammps_natoms[0]]
            node_feats = product(
                node_feats=node_feats, sc=sc, node_attrs=node_attrs_slice
            )
            node_feats_concat.append(node_feats)
            node_es = readout(node_feats, node_heads)[num_atoms_arange, node_heads]
            energy = scatter_sum(node_es, data["batch"], dim=0, dim_size=num_graphs)
            energies.append(energy)
            node_energies_list.append(node_es)

        contributions = torch.stack(energies, dim=-1)
        total_energy = torch.sum(contributions, dim=-1)
        # node_energy sums every contribution (E0, pair repulsion, per-layer
        # readouts) so that scatter-summing it reproduces total_energy.
        # BUGFIX: a stray overwrite (node_e0 + pair_node_energy only) used to
        # clobber this value, dropping the interaction energies.
        node_energy = torch.sum(torch.stack(node_energies_list, dim=-1), dim=-1)
        node_feats_out = torch.cat(node_feats_concat, dim=-1)

        forces, virials, stress, hessian, edge_forces = get_outputs(
            energy=total_energy,
            positions=positions,
            displacement=displacement,
            vectors=vectors,
            cell=cell,
            training=training,
            compute_force=compute_force,
            compute_virials=compute_virials,
            compute_stress=compute_stress,
            compute_hessian=compute_hessian,
            # Atomic stresses need edge forces even if the caller did not ask
            # for them explicitly (matches ScaleShiftMACE.forward).
            compute_edge_forces=compute_edge_forces or compute_atomic_stresses,
        )
        atomic_virials: Optional[torch.Tensor] = None
        atomic_stresses: Optional[torch.Tensor] = None
        if compute_atomic_stresses and edge_forces is not None:
            atomic_virials, atomic_stresses = get_atomic_virials_stresses(
                edge_forces=edge_forces,
                edge_index=data["edge_index"],
                vectors=vectors,
                num_atoms=positions.shape[0],
                batch=data["batch"],
                cell=cell,
            )
        return {
            "energy": total_energy,
            "node_energy": node_energy,
            "contributions": contributions,
            "forces": forces,
            "edge_forces": edge_forces,
            "virials": virials,
            "stress": stress,
            "atomic_virials": atomic_virials,
            "atomic_stresses": atomic_stresses,
            "displacement": displacement,
            "hessian": hessian,
            "node_feats": node_feats_out,
        }


@compile_mode("script")
class ScaleShiftMACE(MACE):
    """MACE variant that scales and shifts the interaction energy.

    The learned/statistical ``atomic_inter_scale``/``atomic_inter_shift`` are
    applied to the summed per-node interaction energies before adding E0.
    """

    def __init__(
        self,
        atomic_inter_scale: float,
        atomic_inter_shift: float,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.scale_shift = ScaleShiftBlock(
            scale=atomic_inter_scale, shift=atomic_inter_shift
        )

    def forward(
        self,
        data: Dict[str, torch.Tensor],
        training: bool = False,
        compute_force: bool = True,
        compute_virials: bool = False,
        compute_stress: bool = False,
        compute_displacement: bool = False,
        compute_hessian: bool = False,
        compute_edge_forces: bool = False,
        compute_atomic_stresses: bool = False,
        lammps_mliap: bool = False,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Run the scale-shift model; derivatives are taken w.r.t. the
        interaction energy only (E0 is constant per composition)."""
        # Setup
        ctx = prepare_graph(
            data,
            compute_virials=compute_virials,
            compute_stress=compute_stress,
            compute_displacement=compute_displacement,
            lammps_mliap=lammps_mliap,
        )
        is_lammps = ctx.is_lammps
        num_atoms_arange = ctx.num_atoms_arange
        num_graphs = ctx.num_graphs
        displacement = ctx.displacement
        positions = ctx.positions
        vectors = ctx.vectors
        lengths = ctx.lengths
        cell = ctx.cell
        node_heads = ctx.node_heads
        interaction_kwargs = ctx.interaction_kwargs
        lammps_natoms = interaction_kwargs.lammps_natoms
        lammps_class = interaction_kwargs.lammps_class

        # Atomic energies
        node_e0 = self.atomic_energies_fn(data["node_attrs"])[
            num_atoms_arange, node_heads
        ]
        e0 = scatter_sum(
            src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs
        )  # [n_graphs, num_heads]

        # Embeddings
        node_feats = self.node_embedding(data["node_attrs"])
        edge_attrs = self.spherical_harmonics(vectors)
        edge_feats = self.radial_embedding(
            lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers
        )
        if hasattr(self, "pair_repulsion"):
            pair_node_energy = self.pair_repulsion_fn(
                lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers
            )
            if is_lammps:
                pair_node_energy = pair_node_energy[: lammps_natoms[0]]
        else:
            pair_node_energy = torch.zeros_like(node_e0)

        # Interactions
        node_es_list = [pair_node_energy]
        node_feats_list: List[torch.Tensor] = []
        for i, (interaction, product, readout) in enumerate(
            zip(self.interactions, self.products, self.readouts)
        ):
            node_attrs_slice = data["node_attrs"]
            if is_lammps and i > 0:
                node_attrs_slice = node_attrs_slice[: lammps_natoms[0]]
            node_feats, sc = interaction(
                node_attrs=node_attrs_slice,
                node_feats=node_feats,
                edge_attrs=edge_attrs,
                edge_feats=edge_feats,
                edge_index=data["edge_index"],
                first_layer=(i == 0),
                lammps_class=lammps_class,
                lammps_natoms=lammps_natoms,
            )
            if is_lammps and i == 0:
                node_attrs_slice = node_attrs_slice[: lammps_natoms[0]]
            node_feats = product(
                node_feats=node_feats, sc=sc, node_attrs=node_attrs_slice
            )
            node_feats_list.append(node_feats)
            node_es_list.append(
                readout(node_feats, node_heads)[num_atoms_arange, node_heads]
            )

        node_feats_out = torch.cat(node_feats_list, dim=-1)
        # Sum per-node interaction energies over layers, then scale/shift.
        node_inter_es = torch.sum(torch.stack(node_es_list, dim=0), dim=0)
        node_inter_es = self.scale_shift(node_inter_es, node_heads)
        inter_e = scatter_sum(node_inter_es, data["batch"], dim=-1, dim_size=num_graphs)

        total_energy = e0 + inter_e
        # BUGFIX: dropped redundant .clone().double() casts — the addition
        # already allocates a fresh tensor, and the float64 cast made
        # node_energy's dtype inconsistent with every other output.
        node_energy = node_e0 + node_inter_es

        forces, virials, stress, hessian, edge_forces = get_outputs(
            energy=inter_e,
            positions=positions,
            displacement=displacement,
            vectors=vectors,
            cell=cell,
            training=training,
            compute_force=compute_force,
            compute_virials=compute_virials,
            compute_stress=compute_stress,
            compute_hessian=compute_hessian,
            compute_edge_forces=compute_edge_forces or compute_atomic_stresses,
        )
        atomic_virials: Optional[torch.Tensor] = None
        atomic_stresses: Optional[torch.Tensor] = None
        if compute_atomic_stresses and edge_forces is not None:
            atomic_virials, atomic_stresses = get_atomic_virials_stresses(
                edge_forces=edge_forces,
                edge_index=data["edge_index"],
                vectors=vectors,
                num_atoms=positions.shape[0],
                batch=data["batch"],
                cell=cell,
            )
        return {
            "energy": total_energy,
            "node_energy": node_energy,
            "interaction_energy": inter_e,
            "forces": forces,
            "edge_forces": edge_forces,
            "virials": virials,
            "stress": stress,
            "atomic_virials": atomic_virials,
            "atomic_stresses": atomic_stresses,
            "hessian": hessian,
            "displacement": displacement,
            "node_feats": node_feats_out,
        }


@compile_mode("script")
class AtomicDipolesMACE(torch.nn.Module):
    """MACE variant predicting atomic and total dipole moments (no energies)."""

    def __init__(
        self,
        r_max: float,
        num_bessel: int,
        num_polynomial_cutoff: int,
        max_ell: int,
        interaction_cls: Type[InteractionBlock],
        interaction_cls_first: Type[InteractionBlock],
        num_interactions: int,
        num_elements: int,
        hidden_irreps: o3.Irreps,
        MLP_irreps: o3.Irreps,
        avg_num_neighbors: float,
        atomic_numbers: List[int],
        correlation: int,
        gate: Optional[Callable],
        atomic_energies: Optional[
            None
        ],  # Just here to make it compatible with energy models, MUST be None
        radial_type: Optional[str] = "bessel",
        radial_MLP: Optional[List[int]] = None,
        cueq_config: Optional[Dict[str, Any]] = None,  # pylint: disable=unused-argument
    ):
        super().__init__()
        self.register_buffer(
            "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64)
        )
        self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64))
        self.register_buffer(
            "num_interactions", torch.tensor(num_interactions, dtype=torch.int64)
        )
        assert atomic_energies is None

        # Embedding
        node_attr_irreps = o3.Irreps([(num_elements, (0, 1))])
        node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))])
        self.node_embedding = LinearNodeEmbeddingBlock(
            irreps_in=node_attr_irreps, irreps_out=node_feats_irreps
        )
        self.radial_embedding = RadialEmbeddingBlock(
            r_max=r_max,
            num_bessel=num_bessel,
            num_polynomial_cutoff=num_polynomial_cutoff,
            radial_type=radial_type,
        )
        edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e")

        sh_irreps = o3.Irreps.spherical_harmonics(max_ell)
        num_features = hidden_irreps.count(o3.Irrep(0, 1))
        interaction_irreps = (sh_irreps * num_features).sort()[0].simplify()
        self.spherical_harmonics = o3.SphericalHarmonics(
            sh_irreps, normalize=True, normalization="component"
        )
        if radial_MLP is None:
            radial_MLP = [64, 64, 64]

        # Interactions and readouts
        inter = interaction_cls_first(
            node_attrs_irreps=node_attr_irreps,
            node_feats_irreps=node_feats_irreps,
            edge_attrs_irreps=sh_irreps,
            edge_feats_irreps=edge_feats_irreps,
            target_irreps=interaction_irreps,
            hidden_irreps=hidden_irreps,
            avg_num_neighbors=avg_num_neighbors,
            radial_MLP=radial_MLP,
        )
        self.interactions = torch.nn.ModuleList([inter])

        # Use the appropriate self connection at the first layer
        use_sc_first = False
        if "Residual" in str(interaction_cls_first):
            use_sc_first = True

        node_feats_irreps_out = inter.target_irreps
        prod = EquivariantProductBasisBlock(
            node_feats_irreps=node_feats_irreps_out,
            target_irreps=hidden_irreps,
            correlation=correlation,
            num_elements=num_elements,
            use_sc=use_sc_first,
        )
        self.products = torch.nn.ModuleList([prod])

        self.readouts = torch.nn.ModuleList()
        self.readouts.append(LinearDipoleReadoutBlock(hidden_irreps, dipole_only=True))

        for i in range(num_interactions - 1):
            if i == num_interactions - 2:
                assert (
                    len(hidden_irreps) > 1
                ), "To predict dipoles use at least l=1 hidden_irreps"
                hidden_irreps_out = str(
                    hidden_irreps[1]
                )  # Select only l=1 vectors for last layer
            else:
                hidden_irreps_out = hidden_irreps
            inter = interaction_cls(
                node_attrs_irreps=node_attr_irreps,
                node_feats_irreps=hidden_irreps,
                edge_attrs_irreps=sh_irreps,
                edge_feats_irreps=edge_feats_irreps,
                target_irreps=interaction_irreps,
                hidden_irreps=hidden_irreps_out,
                avg_num_neighbors=avg_num_neighbors,
                radial_MLP=radial_MLP,
            )
            self.interactions.append(inter)
            prod = EquivariantProductBasisBlock(
                node_feats_irreps=interaction_irreps,
                target_irreps=hidden_irreps_out,
                correlation=correlation,
                num_elements=num_elements,
                use_sc=True,
            )
            self.products.append(prod)
            if i == num_interactions - 2:
                self.readouts.append(
                    NonLinearDipoleReadoutBlock(
                        hidden_irreps_out, MLP_irreps, gate, dipole_only=True
                    )
                )
            else:
                self.readouts.append(
                    LinearDipoleReadoutBlock(hidden_irreps, dipole_only=True)
                )

    def forward(
        self,
        data: Dict[str, torch.Tensor],
        training: bool = False,  # pylint: disable=W0613
        compute_force: bool = False,
        compute_virials: bool = False,
        compute_stress: bool = False,
        compute_displacement: bool = False,
        compute_edge_forces: bool = False,  # pylint: disable=W0613
        compute_atomic_stresses: bool = False,  # pylint: disable=W0613
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Predict per-atom and per-graph dipoles; force/stress flags are not
        supported by this model and must stay False."""
        assert compute_force is False
        assert compute_virials is False
        assert compute_stress is False
        assert compute_displacement is False
        # Setup
        data["node_attrs"].requires_grad_(True)
        data["positions"].requires_grad_(True)
        num_graphs = data["ptr"].numel() - 1

        # Embeddings
        node_feats = self.node_embedding(data["node_attrs"])
        vectors, lengths = get_edge_vectors_and_lengths(
            positions=data["positions"],
            edge_index=data["edge_index"],
            shifts=data["shifts"],
        )
        edge_attrs = self.spherical_harmonics(vectors)
        edge_feats = self.radial_embedding(
            lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers
        )

        # Interactions
        dipoles = []
        for interaction, product, readout in zip(
            self.interactions, self.products, self.readouts
        ):
            node_feats, sc = interaction(
                node_attrs=data["node_attrs"],
                node_feats=node_feats,
                edge_attrs=edge_attrs,
                edge_feats=edge_feats,
                edge_index=data["edge_index"],
            )
            node_feats = product(
                node_feats=node_feats,
                sc=sc,
                node_attrs=data["node_attrs"],
            )
            node_dipoles = readout(node_feats).squeeze(-1)  # [n_nodes,3]
            dipoles.append(node_dipoles)

        # Compute the dipoles
        contributions_dipoles = torch.stack(
            dipoles, dim=-1
        )  # [n_nodes,3,n_contributions]
        atomic_dipoles = torch.sum(contributions_dipoles, dim=-1)  # [n_nodes,3]
        total_dipole = scatter_sum(
            src=atomic_dipoles,
            index=data["batch"],
            dim=0,
            dim_size=num_graphs,
        )  # [n_graphs,3]
        baseline = compute_fixed_charge_dipole(
            charges=data["charges"],
            positions=data["positions"],
            batch=data["batch"],
            num_graphs=num_graphs,
        )  # [n_graphs,3]
        total_dipole = total_dipole + baseline

        output = {
            "dipole": total_dipole,
            "atomic_dipoles": atomic_dipoles,
        }
        return output


@compile_mode("script")
class EnergyDipolesMACE(torch.nn.Module):
    """MACE variant jointly predicting energies (with derivatives) and dipoles.

    Each readout emits a scalar energy channel plus an l=1 dipole channel.
    """

    def __init__(
        self,
        r_max: float,
        num_bessel: int,
        num_polynomial_cutoff: int,
        max_ell: int,
        interaction_cls: Type[InteractionBlock],
        interaction_cls_first: Type[InteractionBlock],
        num_interactions: int,
        num_elements: int,
        hidden_irreps: o3.Irreps,
        MLP_irreps: o3.Irreps,
        avg_num_neighbors: float,
        atomic_numbers: List[int],
        correlation: int,
        gate: Optional[Callable],
        atomic_energies: Optional[np.ndarray],
        radial_MLP: Optional[List[int]] = None,
        cueq_config: Optional[Dict[str, Any]] = None,  # pylint: disable=unused-argument
    ):
        super().__init__()
        self.register_buffer(
            "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64)
        )
        self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64))
        self.register_buffer(
            "num_interactions", torch.tensor(num_interactions, dtype=torch.int64)
        )

        # Embedding
        node_attr_irreps = o3.Irreps([(num_elements, (0, 1))])
        node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))])
        self.node_embedding = LinearNodeEmbeddingBlock(
            irreps_in=node_attr_irreps, irreps_out=node_feats_irreps
        )
        self.radial_embedding = RadialEmbeddingBlock(
            r_max=r_max,
            num_bessel=num_bessel,
            num_polynomial_cutoff=num_polynomial_cutoff,
        )
        edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e")

        sh_irreps = o3.Irreps.spherical_harmonics(max_ell)
        num_features = hidden_irreps.count(o3.Irrep(0, 1))
        interaction_irreps = (sh_irreps * num_features).sort()[0].simplify()
        self.spherical_harmonics = o3.SphericalHarmonics(
            sh_irreps, normalize=True, normalization="component"
        )
        if radial_MLP is None:
            radial_MLP = [64, 64, 64]

        # Interactions and readouts
        self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies)

        inter = interaction_cls_first(
            node_attrs_irreps=node_attr_irreps,
            node_feats_irreps=node_feats_irreps,
            edge_attrs_irreps=sh_irreps,
            edge_feats_irreps=edge_feats_irreps,
            target_irreps=interaction_irreps,
            hidden_irreps=hidden_irreps,
            avg_num_neighbors=avg_num_neighbors,
            radial_MLP=radial_MLP,
        )
        self.interactions = torch.nn.ModuleList([inter])

        # Use the appropriate self connection at the first layer
        use_sc_first = False
        if "Residual" in str(interaction_cls_first):
            use_sc_first = True

        node_feats_irreps_out = inter.target_irreps
        prod = EquivariantProductBasisBlock(
            node_feats_irreps=node_feats_irreps_out,
            target_irreps=hidden_irreps,
            correlation=correlation,
            num_elements=num_elements,
            use_sc=use_sc_first,
        )
        self.products = torch.nn.ModuleList([prod])

        self.readouts = torch.nn.ModuleList()
        self.readouts.append(LinearDipoleReadoutBlock(hidden_irreps, dipole_only=False))

        for i in range(num_interactions - 1):
            if i == num_interactions - 2:
                assert (
                    len(hidden_irreps) > 1
                ), "To predict dipoles use at least l=1 hidden_irreps"
                hidden_irreps_out = str(
                    hidden_irreps[:2]
                )  # Select scalars and l=1 vectors for last layer
            else:
                hidden_irreps_out = hidden_irreps
            inter = interaction_cls(
                node_attrs_irreps=node_attr_irreps,
                node_feats_irreps=hidden_irreps,
                edge_attrs_irreps=sh_irreps,
                edge_feats_irreps=edge_feats_irreps,
                target_irreps=interaction_irreps,
                hidden_irreps=hidden_irreps_out,
                avg_num_neighbors=avg_num_neighbors,
                radial_MLP=radial_MLP,
            )
            self.interactions.append(inter)
            prod = EquivariantProductBasisBlock(
                node_feats_irreps=interaction_irreps,
                target_irreps=hidden_irreps_out,
                correlation=correlation,
                num_elements=num_elements,
                use_sc=True,
            )
            self.products.append(prod)
            if i == num_interactions - 2:
                self.readouts.append(
                    NonLinearDipoleReadoutBlock(
                        hidden_irreps_out, MLP_irreps, gate, dipole_only=False
                    )
                )
            else:
                self.readouts.append(
                    LinearDipoleReadoutBlock(hidden_irreps, dipole_only=False)
                )

    def forward(
        self,
        data: Dict[str, torch.Tensor],
        training: bool = False,
        compute_force: bool = True,
        compute_virials: bool = False,
        compute_stress: bool = False,
        compute_displacement: bool = False,
        compute_edge_forces: bool = False,  # pylint: disable=W0613
        compute_atomic_stresses: bool = False,  # pylint: disable=W0613
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Predict energies (and requested derivatives) together with dipoles."""
        # Setup
        data["node_attrs"].requires_grad_(True)
        data["positions"].requires_grad_(True)
        num_graphs = data["ptr"].numel() - 1
        num_atoms_arange = torch.arange(data["positions"].shape[0])
        displacement = torch.zeros(
            (num_graphs, 3, 3),
            dtype=data["positions"].dtype,
            device=data["positions"].device,
        )
        if compute_virials or compute_stress or compute_displacement:
            (
                data["positions"],
                data["shifts"],
                displacement,
            ) = get_symmetric_displacement(
                positions=data["positions"],
                unit_shifts=data["unit_shifts"],
                cell=data["cell"],
                edge_index=data["edge_index"],
                num_graphs=num_graphs,
                batch=data["batch"],
            )

        # Atomic energies
        node_e0 = self.atomic_energies_fn(data["node_attrs"])[
            num_atoms_arange, data["head"][data["batch"]]
        ]
        e0 = scatter_sum(
            src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs
        )  # [n_graphs,]

        # Embeddings
        node_feats = self.node_embedding(data["node_attrs"])
        vectors, lengths = get_edge_vectors_and_lengths(
            positions=data["positions"],
            edge_index=data["edge_index"],
            shifts=data["shifts"],
        )
        edge_attrs = self.spherical_harmonics(vectors)
        edge_feats = self.radial_embedding(
            lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers
        )

        # Interactions
        energies = [e0]
        node_energies_list = [node_e0]
        dipoles = []
        for interaction, product, readout in zip(
            self.interactions, self.products, self.readouts
        ):
            node_feats, sc = interaction(
                node_attrs=data["node_attrs"],
                node_feats=node_feats,
                edge_attrs=edge_attrs,
                edge_feats=edge_feats,
                edge_index=data["edge_index"],
            )
            node_feats = product(
                node_feats=node_feats,
                sc=sc,
                node_attrs=data["node_attrs"],
            )
            node_out = readout(node_feats).squeeze(-1)  # [n_nodes, ]
            # Channel 0 is the scalar energy; channels 1: are the dipole.
            node_energies = node_out[:, 0]
            energy = scatter_sum(
                src=node_energies, index=data["batch"], dim=-1, dim_size=num_graphs
            )  # [n_graphs,]
            energies.append(energy)
            # BUGFIX: per-layer node energies were never collected, so the
            # returned node_energy excluded all interaction contributions and
            # did not sum to total_energy.
            node_energies_list.append(node_energies)
            node_dipoles = node_out[:, 1:]
            dipoles.append(node_dipoles)

        # Compute the energies and dipoles
        contributions = torch.stack(energies, dim=-1)
        total_energy = torch.sum(contributions, dim=-1)  # [n_graphs, ]
        node_energy_contributions = torch.stack(node_energies_list, dim=-1)
        node_energy = torch.sum(node_energy_contributions, dim=-1)  # [n_nodes, ]
        contributions_dipoles = torch.stack(
            dipoles, dim=-1
        )  # [n_nodes,3,n_contributions]
        atomic_dipoles = torch.sum(contributions_dipoles, dim=-1)  # [n_nodes,3]
        total_dipole = scatter_sum(
            src=atomic_dipoles,
            index=data["batch"].unsqueeze(-1),
            dim=0,
            dim_size=num_graphs,
        )  # [n_graphs,3]
        baseline = compute_fixed_charge_dipole(
            charges=data["charges"],
            positions=data["positions"],
            batch=data["batch"],
            num_graphs=num_graphs,
        )  # [n_graphs,3]
        total_dipole = total_dipole + baseline

        forces, virials, stress, _, _ = get_outputs(
            energy=total_energy,
            positions=data["positions"],
            displacement=displacement,
            vectors=vectors,  # consistency with the other energy models
            cell=data["cell"],
            training=training,
            compute_force=compute_force,
            compute_virials=compute_virials,
            compute_stress=compute_stress,
        )

        output = {
            "energy": total_energy,
            "node_energy": node_energy,
            "contributions": contributions,
            "forces": forces,
            "virials": virials,
            "stress": stress,
            "displacement": displacement,
            "dipole": total_dipole,
            "atomic_dipoles": atomic_dipoles,
        }
        return output