Unverified commit a69b2190, authored by zcxzcx1, committed by GitHub

Delete modules directory

parent d47e8ba0
###########################################################################################
# Elementary tools for handling irreducible representations
# Authors: Ilyes Batatia, Gregor Simm
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
from typing import List, Optional, Tuple
import torch
from e3nn import o3
from e3nn.util.jit import compile_mode
from mace.modules.wrapper_ops import CuEquivarianceConfig
# Based on mir-group/nequip
def tp_out_irreps_with_instructions(
irreps1: o3.Irreps, irreps2: o3.Irreps, target_irreps: o3.Irreps
) -> Tuple[o3.Irreps, List]:
trainable = True
# Collect possible irreps and their instructions
irreps_out_list: List[Tuple[int, o3.Irreps]] = []
instructions = []
for i, (mul, ir_in) in enumerate(irreps1):
for j, (_, ir_edge) in enumerate(irreps2):
for ir_out in ir_in * ir_edge: # | l1 - l2 | <= l <= l1 + l2
if ir_out in target_irreps:
k = len(irreps_out_list) # instruction index
irreps_out_list.append((mul, ir_out))
instructions.append((i, j, k, "uvu", trainable))
# We sort the output irreps of the tensor product so that we can simplify them
# when they are provided to the second o3.Linear
irreps_out = o3.Irreps(irreps_out_list)
irreps_out, permut, _ = irreps_out.sort()
# Permute the output indexes of the instructions to match the sorted irreps:
instructions = [
(i_in1, i_in2, permut[i_out], mode, train)
for i_in1, i_in2, i_out, mode, train in instructions
]
instructions = sorted(instructions, key=lambda x: x[2])
return irreps_out, instructions
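# Illustrative usage sketch: the helper above is typically paired with an
# o3.TensorProduct whose per-edge weights come from a radial MLP ("uvu" instructions,
# no internal weights). The irreps below are arbitrary examples, not values prescribed
# by this module.
#
#     node_irreps = o3.Irreps("8x0e + 8x1o")
#     edge_irreps = o3.Irreps("0e + 1o + 2e")   # spherical harmonics of the edge vector
#     target_irreps = o3.Irreps("8x0e + 8x1o + 8x2e")
#     irreps_mid, instructions = tp_out_irreps_with_instructions(
#         node_irreps, edge_irreps, target_irreps
#     )
#     conv_tp = o3.TensorProduct(
#         node_irreps, edge_irreps, irreps_mid,
#         instructions=instructions,
#         shared_weights=False,      # weights supplied per edge by a radial network
#         internal_weights=False,
#     )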
def linear_out_irreps(irreps: o3.Irreps, target_irreps: o3.Irreps) -> o3.Irreps:
# Assuming simplified irreps
irreps_mid = []
for _, ir_in in irreps:
found = False
for mul, ir_out in target_irreps:
if ir_in == ir_out:
irreps_mid.append((mul, ir_out))
found = True
break
if not found:
raise RuntimeError(f"{ir_in} not in {target_irreps}")
return o3.Irreps(irreps_mid)
@compile_mode("script")
class reshape_irreps(torch.nn.Module):
def __init__(
self, irreps: o3.Irreps, cueq_config: Optional[CuEquivarianceConfig] = None
) -> None:
super().__init__()
self.irreps = o3.Irreps(irreps)
self.cueq_config = cueq_config
self.dims = []
self.muls = []
for mul, ir in self.irreps:
d = ir.dim
self.dims.append(d)
self.muls.append(mul)
def forward(self, tensor: torch.Tensor) -> torch.Tensor:
ix = 0
out = []
batch, _ = tensor.shape
for mul, d in zip(self.muls, self.dims):
            field = tensor[:, ix : ix + mul * d]  # [batch, mul * repr]
ix += mul * d
if hasattr(self, "cueq_config"):
if self.cueq_config is not None:
if self.cueq_config.layout_str == "mul_ir":
field = field.reshape(batch, mul, d)
else:
field = field.reshape(batch, d, mul)
else:
field = field.reshape(batch, mul, d)
else:
field = field.reshape(batch, mul, d)
out.append(field)
if hasattr(self, "cueq_config"):
if self.cueq_config is not None: # pylint: disable=no-else-return
if self.cueq_config.layout_str == "mul_ir":
return torch.cat(out, dim=-1)
return torch.cat(out, dim=-2)
else:
return torch.cat(out, dim=-1)
return torch.cat(out, dim=-1)
def mask_head(x: torch.Tensor, head: torch.Tensor, num_heads: int) -> torch.Tensor:
mask = torch.zeros(x.shape[0], x.shape[1] // num_heads, num_heads, device=x.device)
idx = torch.arange(mask.shape[0], device=x.device)
mask[idx, :, head] = 1
mask = mask.permute(0, 2, 1).reshape(x.shape)
return x * mask
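# Illustrative usage sketch for the two helpers above; all shapes and values are
# hypothetical. reshape_irreps turns a flat e3nn feature tensor into the
# [n_nodes, channels, sum_of_irrep_dims] layout used by the symmetric contraction,
# and mask_head zeroes the feature blocks that do not belong to a node's head.
#
#     irreps = o3.Irreps("16x0e + 16x1o")
#     reshape = reshape_irreps(irreps)
#     flat = torch.randn(5, irreps.dim)         # [n_nodes, 16*1 + 16*3]
#     per_channel = reshape(flat)               # [n_nodes, 16, 4]
#
#     x = torch.randn(5, 32)
#     head = torch.tensor([0, 1, 0, 1, 0])
#     masked = mask_head(x, head, num_heads=2)  # [5, 32], non-selected head zeroed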
###########################################################################################
# Radial basis and cutoff
# Authors: Ilyes Batatia, Gregor Simm
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
import logging
import ase
import numpy as np
import torch
from e3nn.util.jit import compile_mode
from mace.tools.scatter import scatter_sum
@compile_mode("script")
class BesselBasis(torch.nn.Module):
"""
Equation (7)
"""
def __init__(self, r_max: float, num_basis=8, trainable=False):
super().__init__()
bessel_weights = (
np.pi
/ r_max
* torch.linspace(
start=1.0,
end=num_basis,
steps=num_basis,
dtype=torch.get_default_dtype(),
)
)
if trainable:
self.bessel_weights = torch.nn.Parameter(bessel_weights)
else:
self.register_buffer("bessel_weights", bessel_weights)
self.register_buffer(
"r_max", torch.tensor(r_max, dtype=torch.get_default_dtype())
)
self.register_buffer(
"prefactor",
torch.tensor(np.sqrt(2.0 / r_max), dtype=torch.get_default_dtype()),
)
def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1]
numerator = torch.sin(self.bessel_weights * x) # [..., num_basis]
return self.prefactor * (numerator / x)
def __repr__(self):
return (
f"{self.__class__.__name__}(r_max={self.r_max}, num_basis={len(self.bessel_weights)}, "
f"trainable={self.bessel_weights.requires_grad})"
)
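# Illustrative usage sketch: the radial bases act on edge lengths of shape
# [n_edges, 1] and return [n_edges, num_basis]. Values are arbitrary.
#
#     lengths = torch.rand(10, 1) * 5.0         # hypothetical edge lengths
#     bessel = BesselBasis(r_max=5.0, num_basis=8)
#     radial_embedding = bessel(lengths)        # [10, 8]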
@compile_mode("script")
class ChebychevBasis(torch.nn.Module):
"""
Equation (7)
"""
def __init__(self, r_max: float, num_basis=8):
super().__init__()
self.register_buffer(
"n",
torch.arange(1, num_basis + 1, dtype=torch.get_default_dtype()).unsqueeze(
0
),
)
self.num_basis = num_basis
self.r_max = r_max
def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1]
x = x.repeat(1, self.num_basis)
n = self.n.repeat(len(x), 1)
return torch.special.chebyshev_polynomial_t(x, n)
def __repr__(self):
return (
f"{self.__class__.__name__}(r_max={self.r_max}, num_basis={self.num_basis},"
)
@compile_mode("script")
class GaussianBasis(torch.nn.Module):
"""
Gaussian basis functions
"""
def __init__(self, r_max: float, num_basis=128, trainable=False):
super().__init__()
gaussian_weights = torch.linspace(
start=0.0, end=r_max, steps=num_basis, dtype=torch.get_default_dtype()
)
if trainable:
self.gaussian_weights = torch.nn.Parameter(
gaussian_weights, requires_grad=True
)
else:
self.register_buffer("gaussian_weights", gaussian_weights)
self.coeff = -0.5 / (r_max / (num_basis - 1)) ** 2
def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1]
x = x - self.gaussian_weights
return torch.exp(self.coeff * torch.pow(x, 2))
@compile_mode("script")
class PolynomialCutoff(torch.nn.Module):
"""Polynomial cutoff function that goes from 1 to 0 as x goes from 0 to r_max.
Equation (8) -- TODO: from where?
"""
p: torch.Tensor
r_max: torch.Tensor
def __init__(self, r_max: float, p=6):
super().__init__()
self.register_buffer("p", torch.tensor(p, dtype=torch.int))
self.register_buffer(
"r_max", torch.tensor(r_max, dtype=torch.get_default_dtype())
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.calculate_envelope(x, self.r_max, self.p.to(torch.int))
@staticmethod
def calculate_envelope(
x: torch.Tensor, r_max: torch.Tensor, p: torch.Tensor
) -> torch.Tensor:
r_over_r_max = x / r_max
envelope = (
1.0
- ((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(r_over_r_max, p)
+ p * (p + 2.0) * torch.pow(r_over_r_max, p + 1)
- (p * (p + 1.0) / 2) * torch.pow(r_over_r_max, p + 2)
)
return envelope * (x < r_max)
def __repr__(self):
return f"{self.__class__.__name__}(p={self.p}, r_max={self.r_max})"
@compile_mode("script")
class ZBLBasis(torch.nn.Module):
"""Implementation of the Ziegler-Biersack-Littmark (ZBL) potential
with a polynomial cutoff envelope.
"""
p: torch.Tensor
def __init__(self, p=6, trainable=False, **kwargs):
super().__init__()
if "r_max" in kwargs:
logging.warning(
"r_max is deprecated. r_max is determined from the covalent radii."
)
# Pre-calculate the p coefficients for the ZBL potential
self.register_buffer(
"c",
torch.tensor(
[0.1818, 0.5099, 0.2802, 0.02817], dtype=torch.get_default_dtype()
),
)
self.register_buffer("p", torch.tensor(p, dtype=torch.int))
self.register_buffer(
"covalent_radii",
torch.tensor(
ase.data.covalent_radii,
dtype=torch.get_default_dtype(),
),
)
if trainable:
self.a_exp = torch.nn.Parameter(torch.tensor(0.300, requires_grad=True))
self.a_prefactor = torch.nn.Parameter(
torch.tensor(0.4543, requires_grad=True)
)
else:
self.register_buffer("a_exp", torch.tensor(0.300))
self.register_buffer("a_prefactor", torch.tensor(0.4543))
def forward(
self,
x: torch.Tensor,
node_attrs: torch.Tensor,
edge_index: torch.Tensor,
atomic_numbers: torch.Tensor,
) -> torch.Tensor:
sender = edge_index[0]
receiver = edge_index[1]
node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze(
-1
)
Z_u = node_atomic_numbers[sender]
Z_v = node_atomic_numbers[receiver]
a = (
self.a_prefactor
* 0.529
/ (torch.pow(Z_u, self.a_exp) + torch.pow(Z_v, self.a_exp))
)
r_over_a = x / a
phi = (
self.c[0] * torch.exp(-3.2 * r_over_a)
+ self.c[1] * torch.exp(-0.9423 * r_over_a)
+ self.c[2] * torch.exp(-0.4028 * r_over_a)
+ self.c[3] * torch.exp(-0.2016 * r_over_a)
)
v_edges = (14.3996 * Z_u * Z_v) / x * phi
r_max = self.covalent_radii[Z_u] + self.covalent_radii[Z_v]
envelope = PolynomialCutoff.calculate_envelope(x, r_max, self.p)
v_edges = 0.5 * v_edges * envelope
V_ZBL = scatter_sum(v_edges, receiver, dim=0, dim_size=node_attrs.size(0))
return V_ZBL.squeeze(-1)
def __repr__(self):
return f"{self.__class__.__name__}(c={self.c})"
@compile_mode("script")
class AgnesiTransform(torch.nn.Module):
"""Agnesi transform - see section on Radial transformations in
ACEpotentials.jl, JCP 2023 (https://doi.org/10.1063/5.0158783).
"""
def __init__(
self,
q: float = 0.9183,
p: float = 4.5791,
a: float = 1.0805,
trainable=False,
):
super().__init__()
self.register_buffer("q", torch.tensor(q, dtype=torch.get_default_dtype()))
self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype()))
self.register_buffer("a", torch.tensor(a, dtype=torch.get_default_dtype()))
self.register_buffer(
"covalent_radii",
torch.tensor(
ase.data.covalent_radii,
dtype=torch.get_default_dtype(),
),
)
if trainable:
self.a = torch.nn.Parameter(torch.tensor(1.0805, requires_grad=True))
self.q = torch.nn.Parameter(torch.tensor(0.9183, requires_grad=True))
self.p = torch.nn.Parameter(torch.tensor(4.5791, requires_grad=True))
def forward(
self,
x: torch.Tensor,
node_attrs: torch.Tensor,
edge_index: torch.Tensor,
atomic_numbers: torch.Tensor,
) -> torch.Tensor:
sender = edge_index[0]
receiver = edge_index[1]
node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze(
-1
)
Z_u = node_atomic_numbers[sender]
Z_v = node_atomic_numbers[receiver]
r_0: torch.Tensor = 0.5 * (self.covalent_radii[Z_u] + self.covalent_radii[Z_v])
r_over_r_0 = x / r_0
return (
1
+ (
self.a
* torch.pow(r_over_r_0, self.q)
/ (1 + torch.pow(r_over_r_0, self.q - self.p))
)
).reciprocal_()
def __repr__(self):
return (
f"{self.__class__.__name__}(a={self.a:.4f}, q={self.q:.4f}, p={self.p:.4f})"
)
@compile_mode("script")
class SoftTransform(torch.nn.Module):
"""
Tanh-based smooth transformation:
T(x) = p1 + (x - p1)*0.5*[1 + tanh(alpha*(x - m))],
which smoothly transitions from ~p1 for x << p1 to ~x for x >> r0.
"""
def __init__(self, alpha: float = 4.0, trainable=False):
"""
Args:
p1 (float): Lower "clamp" point.
alpha (float): Steepness; if None, defaults to ~6/(r0-p1).
trainable (bool): Whether to make parameters trainable.
"""
super().__init__()
# Initialize parameters
self.register_buffer(
"alpha", torch.tensor(alpha, dtype=torch.get_default_dtype())
)
if trainable:
self.alpha = torch.nn.Parameter(self.alpha.clone())
self.register_buffer(
"covalent_radii",
torch.tensor(
ase.data.covalent_radii,
dtype=torch.get_default_dtype(),
),
)
def compute_r_0(
self,
node_attrs: torch.Tensor,
edge_index: torch.Tensor,
atomic_numbers: torch.Tensor,
) -> torch.Tensor:
"""
Compute r_0 based on atomic information.
Args:
node_attrs (torch.Tensor): Node attributes (one-hot encoding of atomic numbers).
edge_index (torch.Tensor): Edge index indicating connections.
atomic_numbers (torch.Tensor): Atomic numbers.
Returns:
torch.Tensor: r_0 values for each edge.
"""
sender = edge_index[0]
receiver = edge_index[1]
node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze(
-1
)
Z_u = node_atomic_numbers[sender]
Z_v = node_atomic_numbers[receiver]
r_0: torch.Tensor = self.covalent_radii[Z_u] + self.covalent_radii[Z_v]
return r_0
def forward(
self,
x: torch.Tensor,
node_attrs: torch.Tensor,
edge_index: torch.Tensor,
atomic_numbers: torch.Tensor,
) -> torch.Tensor:
r_0 = self.compute_r_0(node_attrs, edge_index, atomic_numbers)
p_0 = (3 / 4) * r_0
p_1 = (4 / 3) * r_0
m = 0.5 * (p_0 + p_1)
alpha = self.alpha / (p_1 - p_0)
s_x = 0.5 * (1.0 + torch.tanh(alpha * (x - m)))
return p_0 + (x - p_0) * s_x
def __repr__(self):
return f"{self.__class__.__name__}(alpha={self.alpha.item():.4f})"
###########################################################################################
# Implementation of the symmetric contraction algorithm presented in the MACE paper
# (Batatia et al, MACE: Higher Order Equivariant Message Passing Neural Networks for Fast and Accurate Force Fields , Eq.10 and 11)
# Authors: Ilyes Batatia
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
from typing import Dict, Optional, Union
import opt_einsum_fx
import torch
import torch.fx
from e3nn import o3
from e3nn.util.codegen import CodeGenMixin
from e3nn.util.jit import compile_mode
from mace.tools.cg import U_matrix_real
BATCH_EXAMPLE = 10
ALPHABET = ["w", "x", "v", "n", "z", "r", "t", "y", "u", "o", "p", "s"]
@compile_mode("script")
class SymmetricContraction(CodeGenMixin, torch.nn.Module):
def __init__(
self,
irreps_in: o3.Irreps,
irreps_out: o3.Irreps,
correlation: Union[int, Dict[str, int]],
irrep_normalization: str = "component",
path_normalization: str = "element",
internal_weights: Optional[bool] = None,
shared_weights: Optional[bool] = None,
num_elements: Optional[int] = None,
) -> None:
super().__init__()
if irrep_normalization is None:
irrep_normalization = "component"
if path_normalization is None:
path_normalization = "element"
assert irrep_normalization in ["component", "norm", "none"]
assert path_normalization in ["element", "path", "none"]
self.irreps_in = o3.Irreps(irreps_in)
self.irreps_out = o3.Irreps(irreps_out)
del irreps_in, irreps_out
if not isinstance(correlation, tuple):
corr = correlation
correlation = {}
for irrep_out in self.irreps_out:
correlation[irrep_out] = corr
assert shared_weights or not internal_weights
if internal_weights is None:
internal_weights = True
self.internal_weights = internal_weights
self.shared_weights = shared_weights
del internal_weights, shared_weights
self.contractions = torch.nn.ModuleList()
for irrep_out in self.irreps_out:
self.contractions.append(
Contraction(
irreps_in=self.irreps_in,
irrep_out=o3.Irreps(str(irrep_out.ir)),
correlation=correlation[irrep_out],
internal_weights=self.internal_weights,
num_elements=num_elements,
weights=self.shared_weights,
)
)
def forward(self, x: torch.Tensor, y: torch.Tensor):
outs = [contraction(x, y) for contraction in self.contractions]
return torch.cat(outs, dim=-1)
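# Illustrative usage sketch: SymmetricContraction consumes channel-major features
# (as produced by reshape_irreps) together with a one-hot element encoding. Irreps
# and sizes below are arbitrary examples.
#
#     sc = SymmetricContraction(
#         irreps_in=o3.Irreps("16x0e + 16x1o"),
#         irreps_out=o3.Irreps("16x0e + 16x1o"),
#         correlation=3,
#         num_elements=2,
#     )
#     node_feats = torch.randn(5, 16, 4)                     # [n_nodes, channels, 1 + 3]
#     node_attrs = torch.nn.functional.one_hot(
#         torch.tensor([0, 1, 0, 1, 0]), 2
#     ).float()                                               # [n_nodes, n_elements]
#     out = sc(node_feats, node_attrs)                        # [n_nodes, 16 * (1 + 3)]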
@compile_mode("script")
class Contraction(torch.nn.Module):
def __init__(
self,
irreps_in: o3.Irreps,
irrep_out: o3.Irreps,
correlation: int,
internal_weights: bool = True,
num_elements: Optional[int] = None,
weights: Optional[torch.Tensor] = None,
) -> None:
super().__init__()
self.num_features = irreps_in.count((0, 1))
self.coupling_irreps = o3.Irreps([irrep.ir for irrep in irreps_in])
self.correlation = correlation
dtype = torch.get_default_dtype()
for nu in range(1, correlation + 1):
U_matrix = U_matrix_real(
irreps_in=self.coupling_irreps,
irreps_out=irrep_out,
correlation=nu,
dtype=dtype,
)[-1]
self.register_buffer(f"U_matrix_{nu}", U_matrix)
# Tensor contraction equations
self.contractions_weighting = torch.nn.ModuleList()
self.contractions_features = torch.nn.ModuleList()
# Create weight for product basis
self.weights = torch.nn.ParameterList([])
for i in range(correlation, 0, -1):
            # Shape definitions
num_params = self.U_tensors(i).size()[-1]
num_equivariance = 2 * irrep_out.lmax + 1
num_ell = self.U_tensors(i).size()[-2]
if i == correlation:
parse_subscript_main = (
[ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1) - 1)]
+ ["ik,ekc,bci,be -> bc"]
+ [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1) - 1)]
)
graph_module_main = torch.fx.symbolic_trace(
lambda x, y, w, z: torch.einsum(
"".join(parse_subscript_main), x, y, w, z
)
)
# Optimizing the contractions
self.graph_opt_main = opt_einsum_fx.optimize_einsums_full(
model=graph_module_main,
example_inputs=(
torch.randn(
[num_equivariance] + [num_ell] * i + [num_params]
).squeeze(0),
torch.randn((num_elements, num_params, self.num_features)),
torch.randn((BATCH_EXAMPLE, self.num_features, num_ell)),
torch.randn((BATCH_EXAMPLE, num_elements)),
),
)
# Parameters for the product basis
w = torch.nn.Parameter(
torch.randn((num_elements, num_params, self.num_features))
/ num_params
)
self.weights_max = w
else:
# Generate optimized contractions equations
parse_subscript_weighting = (
[ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1))]
+ ["k,ekc,be->bc"]
+ [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1))]
)
parse_subscript_features = (
["bc"]
+ [ALPHABET[j] for j in range(i - 1 + min(irrep_out.lmax, 1))]
+ ["i,bci->bc"]
+ [ALPHABET[j] for j in range(i - 1 + min(irrep_out.lmax, 1))]
)
# Symbolic tracing of contractions
graph_module_weighting = torch.fx.symbolic_trace(
lambda x, y, z: torch.einsum(
"".join(parse_subscript_weighting), x, y, z
)
)
graph_module_features = torch.fx.symbolic_trace(
lambda x, y: torch.einsum("".join(parse_subscript_features), x, y)
)
# Optimizing the contractions
graph_opt_weighting = opt_einsum_fx.optimize_einsums_full(
model=graph_module_weighting,
example_inputs=(
torch.randn(
[num_equivariance] + [num_ell] * i + [num_params]
).squeeze(0),
torch.randn((num_elements, num_params, self.num_features)),
torch.randn((BATCH_EXAMPLE, num_elements)),
),
)
graph_opt_features = opt_einsum_fx.optimize_einsums_full(
model=graph_module_features,
example_inputs=(
torch.randn(
[BATCH_EXAMPLE, self.num_features, num_equivariance]
+ [num_ell] * i
).squeeze(2),
torch.randn((BATCH_EXAMPLE, self.num_features, num_ell)),
),
)
self.contractions_weighting.append(graph_opt_weighting)
self.contractions_features.append(graph_opt_features)
# Parameters for the product basis
w = torch.nn.Parameter(
torch.randn((num_elements, num_params, self.num_features))
/ num_params
)
self.weights.append(w)
if not internal_weights:
self.weights = weights[:-1]
self.weights_max = weights[-1]
def forward(self, x: torch.Tensor, y: torch.Tensor):
out = self.graph_opt_main(
self.U_tensors(self.correlation),
self.weights_max,
x,
y,
)
for i, (weight, contract_weights, contract_features) in enumerate(
zip(self.weights, self.contractions_weighting, self.contractions_features)
):
c_tensor = contract_weights(
self.U_tensors(self.correlation - i - 1),
weight,
y,
)
c_tensor = c_tensor + out
out = contract_features(c_tensor, x)
return out.view(out.shape[0], -1)
def U_tensors(self, nu: int):
return dict(self.named_buffers())[f"U_matrix_{nu}"]
"""
Wrapper class for o3.Linear that optionally uses cuet.Linear
"""
import dataclasses
from typing import List, Optional
import torch
from e3nn import o3
from mace.modules.symmetric_contraction import SymmetricContraction
from mace.tools.cg import O3_e3nn
try:
import cuequivariance as cue
import cuequivariance_torch as cuet
CUET_AVAILABLE = True
except ImportError:
CUET_AVAILABLE = False
@dataclasses.dataclass
class CuEquivarianceConfig:
"""Configuration for cuequivariance acceleration"""
enabled: bool = False
layout: str = "mul_ir" # One of: mul_ir, ir_mul
layout_str: str = "mul_ir"
group: str = "O3"
optimize_all: bool = False # Set to True to enable all optimizations
optimize_linear: bool = False
optimize_channelwise: bool = False
optimize_symmetric: bool = False
optimize_fctp: bool = False
def __post_init__(self):
if self.enabled and CUET_AVAILABLE:
self.layout_str = self.layout
self.layout = getattr(cue, self.layout)
self.group = (
O3_e3nn if self.group == "O3_e3nn" else getattr(cue, self.group)
)
if not CUET_AVAILABLE:
self.enabled = False
class Linear:
"""Returns either a cuet.Linear or o3.Linear based on config"""
def __new__(
cls,
irreps_in: o3.Irreps,
irreps_out: o3.Irreps,
shared_weights: bool = True,
internal_weights: bool = True,
cueq_config: Optional[CuEquivarianceConfig] = None,
):
if (
CUET_AVAILABLE
and cueq_config is not None
and cueq_config.enabled
and (cueq_config.optimize_all or cueq_config.optimize_linear)
):
return cuet.Linear(
cue.Irreps(cueq_config.group, irreps_in),
cue.Irreps(cueq_config.group, irreps_out),
layout=cueq_config.layout,
shared_weights=shared_weights,
use_fallback=True,
)
return o3.Linear(
irreps_in,
irreps_out,
shared_weights=shared_weights,
internal_weights=internal_weights,
)
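# Illustrative usage sketch: the wrapper classes in this module are drop-in factories;
# when cuequivariance is unavailable or disabled they simply return the e3nn modules.
# The configuration values below are arbitrary examples.
#
#     cueq_config = CuEquivarianceConfig(enabled=True, layout="mul_ir", optimize_all=True)
#     linear = Linear(
#         o3.Irreps("16x0e + 16x1o"),
#         o3.Irreps("16x0e"),
#         cueq_config=cueq_config,
#     )   # cuet.Linear if cuequivariance is installed, otherwise o3.Linear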
class TensorProduct:
"""Wrapper around o3.TensorProduct/cuet.ChannelwiseTensorProduct"""
def __new__(
cls,
irreps_in1: o3.Irreps,
irreps_in2: o3.Irreps,
irreps_out: o3.Irreps,
instructions: Optional[List] = None,
shared_weights: bool = False,
internal_weights: bool = False,
cueq_config: Optional[CuEquivarianceConfig] = None,
):
if (
CUET_AVAILABLE
and cueq_config is not None
and cueq_config.enabled
and (cueq_config.optimize_all or cueq_config.optimize_channelwise)
):
return cuet.ChannelWiseTensorProduct(
cue.Irreps(cueq_config.group, irreps_in1),
cue.Irreps(cueq_config.group, irreps_in2),
cue.Irreps(cueq_config.group, irreps_out),
layout=cueq_config.layout,
shared_weights=shared_weights,
internal_weights=internal_weights,
dtype=torch.get_default_dtype(),
math_dtype=torch.get_default_dtype(),
)
return o3.TensorProduct(
irreps_in1,
irreps_in2,
irreps_out,
instructions=instructions,
shared_weights=shared_weights,
internal_weights=internal_weights,
)
class FullyConnectedTensorProduct:
"""Wrapper around o3.FullyConnectedTensorProduct/cuet.FullyConnectedTensorProduct"""
def __new__(
cls,
irreps_in1: o3.Irreps,
irreps_in2: o3.Irreps,
irreps_out: o3.Irreps,
shared_weights: bool = True,
internal_weights: bool = True,
cueq_config: Optional[CuEquivarianceConfig] = None,
):
if (
CUET_AVAILABLE
and cueq_config is not None
and cueq_config.enabled
and (cueq_config.optimize_all or cueq_config.optimize_fctp)
):
return cuet.FullyConnectedTensorProduct(
cue.Irreps(cueq_config.group, irreps_in1),
cue.Irreps(cueq_config.group, irreps_in2),
cue.Irreps(cueq_config.group, irreps_out),
layout=cueq_config.layout,
shared_weights=shared_weights,
internal_weights=internal_weights,
use_fallback=True,
)
return o3.FullyConnectedTensorProduct(
irreps_in1,
irreps_in2,
irreps_out,
shared_weights=shared_weights,
internal_weights=internal_weights,
)
class SymmetricContractionWrapper:
"""Wrapper around SymmetricContraction/cuet.SymmetricContraction"""
def __new__(
cls,
irreps_in: o3.Irreps,
irreps_out: o3.Irreps,
correlation: int,
num_elements: Optional[int] = None,
cueq_config: Optional[CuEquivarianceConfig] = None,
):
if (
CUET_AVAILABLE
and cueq_config is not None
and cueq_config.enabled
and (cueq_config.optimize_all or cueq_config.optimize_symmetric)
):
return cuet.SymmetricContraction(
cue.Irreps(cueq_config.group, irreps_in),
cue.Irreps(cueq_config.group, irreps_out),
layout_in=cue.ir_mul,
layout_out=cueq_config.layout,
contraction_degree=correlation,
num_elements=num_elements,
original_mace=True,
dtype=torch.get_default_dtype(),
math_dtype=torch.get_default_dtype(),
)
return SymmetricContraction(
irreps_in=irreps_in,
irreps_out=irreps_out,
correlation=correlation,
num_elements=num_elements,
)
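# Illustrative note: when the cuequivariance path is taken, the symmetric contraction
# expects its input in the ir_mul layout (layout_in=cue.ir_mul above), which is why
# reshape_irreps in irreps_tools.py switches its reshape order on cueq_config.layout_str.
# A hedged sketch of the pure-PyTorch fallback path:
#
#     wrapper = SymmetricContractionWrapper(
#         irreps_in=o3.Irreps("16x0e + 16x1o"),
#         irreps_out=o3.Irreps("16x0e + 16x1o"),
#         correlation=3,
#         num_elements=2,
#         cueq_config=None,   # falls back to mace.modules.symmetric_contraction.SymmetricContraction
#     )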