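###########################################################################################
# Argument checking and normalization for training scripts
###########################################################################################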
import logging
import os
from e3nn import o3
def check_args(args):
"""
Check input arguments, update them if necessary for valid and consistent inputs, and return a tuple containing
the (potentially) modified args and a list of log messages.
"""
log_messages = []
# Directories
# Use work_dir for all other directories as well, unless they were specified by the user
if args.log_dir is None:
args.log_dir = os.path.join(args.work_dir, "logs")
if args.model_dir is None:
args.model_dir = args.work_dir
if args.checkpoints_dir is None:
args.checkpoints_dir = os.path.join(args.work_dir, "checkpoints")
if args.results_dir is None:
args.results_dir = os.path.join(args.work_dir, "results")
if args.downloads_dir is None:
args.downloads_dir = os.path.join(args.work_dir, "downloads")
# Model
# Check if hidden_irreps, num_channels and max_L are consistent
if args.hidden_irreps is None and args.num_channels is None and args.max_L is None:
args.hidden_irreps, args.num_channels, args.max_L = "128x0e + 128x1o", 128, 1
elif (
args.hidden_irreps is not None
and args.num_channels is not None
and args.max_L is not None
):
args.hidden_irreps = o3.Irreps(
(args.num_channels * o3.Irreps.spherical_harmonics(args.max_L))
.sort()
.irreps.simplify()
)
log_messages.append(
(
"All of hidden_irreps, num_channels and max_L are specified",
logging.WARNING,
)
)
log_messages.append(
(
f"Using num_channels and max_L to create hidden_irreps: {args.hidden_irreps}.",
logging.WARNING,
)
)
        assert (
            len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1
        ), "All channels must have the same dimension; use the num_channels and max_L keywords to specify the number of channels and the maximum L"
    elif args.num_channels is not None and args.max_L is not None:
        assert args.num_channels > 0, "num_channels must be a positive integer"
        assert args.max_L >= 0, "max_L must be a non-negative integer"
args.hidden_irreps = o3.Irreps(
(args.num_channels * o3.Irreps.spherical_harmonics(args.max_L))
.sort()
.irreps.simplify()
)
        assert (
            len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1
        ), "All channels must have the same dimension; use the num_channels and max_L keywords to specify the number of channels and the maximum L"
elif args.hidden_irreps is not None:
        assert (
            len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1
        ), "All channels must have the same dimension; use the num_channels and max_L keywords to specify the number of channels and the maximum L"
args.num_channels = list(
{irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}
)[0]
args.max_L = o3.Irreps(args.hidden_irreps).lmax
elif args.max_L is not None and args.num_channels is None:
        assert args.max_L >= 0, "max_L must be a non-negative integer"
args.num_channels = 128
args.hidden_irreps = o3.Irreps(
(args.num_channels * o3.Irreps.spherical_harmonics(args.max_L))
.sort()
.irreps.simplify()
)
elif args.max_L is None and args.num_channels is not None:
        assert args.num_channels > 0, "num_channels must be a positive integer"
args.max_L = 1
args.hidden_irreps = o3.Irreps(
(args.num_channels * o3.Irreps.spherical_harmonics(args.max_L))
.sort()
.irreps.simplify()
)
# Loss and optimization
# Check Stage Two loss start
if args.start_swa is not None:
args.swa = True
log_messages.append(
(
"Stage Two is activated as start_stage_two was defined",
logging.INFO,
)
)
if args.swa:
if args.start_swa is None:
args.start_swa = max(1, args.max_num_epochs // 4 * 3)
if args.start_swa > args.max_num_epochs:
log_messages.append(
(
f"start_stage_two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}",
logging.WARNING,
)
)
log_messages.append(
(
"Stage Two will not start, as start_stage_two > max_num_epochs",
logging.WARNING,
)
)
args.swa = False
return args, log_messages
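# Usage sketch (illustrative only; the Namespace fields below mirror the attributes
# check_args touches and are an assumption about the full CLI parser):
if __name__ == "__main__":
    from argparse import Namespace

    demo_args = Namespace(
        work_dir=".", log_dir=None, model_dir=None, checkpoints_dir=None,
        results_dir=None, downloads_dir=None,
        hidden_irreps=None, num_channels=64, max_L=2,
        start_swa=None, swa=False, max_num_epochs=100,
    )
    demo_args, demo_messages = check_args(demo_args)
    print(demo_args.hidden_irreps)  # 64x0e+64x1o+64x2e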
###########################################################################################
# Higher Order Real Clebsch Gordan (based on e3nn by Mario Geiger)
# Authors: Ilyes Batatia
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
import collections
import itertools
import os
from typing import Iterator, List, Union
import numpy as np
import torch
from e3nn import o3
try:
import cuequivariance as cue
CUET_AVAILABLE = True
except ImportError:
CUET_AVAILABLE = False
USE_CUEQ_CG = os.environ.get("MACE_USE_CUEQ_CG", "0").lower() in (
"1",
"true",
"yes",
"y",
)
_TP = collections.namedtuple("_TP", "op, args")
_INPUT = collections.namedtuple("_INPUT", "tensor, start, stop")
def _wigner_nj(
irrepss: List[o3.Irreps],
normalization: str = "component",
filter_ir_mid=None,
dtype=None,
):
irrepss = [o3.Irreps(irreps) for irreps in irrepss]
if filter_ir_mid is not None:
filter_ir_mid = [o3.Irrep(ir) for ir in filter_ir_mid]
if len(irrepss) == 1:
(irreps,) = irrepss
ret = []
e = torch.eye(irreps.dim, dtype=dtype)
i = 0
for mul, ir in irreps:
for _ in range(mul):
sl = slice(i, i + ir.dim)
ret += [(ir, _INPUT(0, sl.start, sl.stop), e[sl])]
i += ir.dim
return ret
*irrepss_left, irreps_right = irrepss
ret = []
for ir_left, path_left, C_left in _wigner_nj(
irrepss_left,
normalization=normalization,
filter_ir_mid=filter_ir_mid,
dtype=dtype,
):
i = 0
for mul, ir in irreps_right:
for ir_out in ir_left * ir:
if filter_ir_mid is not None and ir_out not in filter_ir_mid:
continue
C = o3.wigner_3j(ir_out.l, ir_left.l, ir.l, dtype=dtype)
if normalization == "component":
C *= ir_out.dim**0.5
if normalization == "norm":
C *= ir_left.dim**0.5 * ir.dim**0.5
C = torch.einsum("jk,ijl->ikl", C_left.flatten(1), C)
C = C.reshape(
ir_out.dim, *(irreps.dim for irreps in irrepss_left), ir.dim
)
for u in range(mul):
E = torch.zeros(
ir_out.dim,
*(irreps.dim for irreps in irrepss_left),
irreps_right.dim,
dtype=dtype,
)
sl = slice(i + u * ir.dim, i + (u + 1) * ir.dim)
E[..., sl] = C
ret += [
(
ir_out,
_TP(
op=(ir_left, ir, ir_out),
args=(
path_left,
_INPUT(len(irrepss_left), sl.start, sl.stop),
),
),
E,
)
]
i += mul * ir.dim
return sorted(ret, key=lambda x: x[0])
def U_matrix_real(
irreps_in: Union[str, o3.Irreps],
irreps_out: Union[str, o3.Irreps],
correlation: int,
normalization: str = "component",
filter_ir_mid=None,
dtype=None,
use_cueq_cg=None,
):
irreps_out = o3.Irreps(irreps_out)
irrepss = [o3.Irreps(irreps_in)] * correlation
if correlation == 4:
filter_ir_mid = [(i, 1 if i % 2 == 0 else -1) for i in range(12)]
if use_cueq_cg is None:
use_cueq_cg = USE_CUEQ_CG
if use_cueq_cg and CUET_AVAILABLE:
return compute_U_cueq(irreps_in, irreps_out=irreps_out, correlation=correlation)
try:
wigners = _wigner_nj(irrepss, normalization, filter_ir_mid, dtype)
except NotImplementedError as e:
if CUET_AVAILABLE:
return compute_U_cueq(
irreps_in, irreps_out=irreps_out, correlation=correlation
)
raise NotImplementedError(
"The requested Clebsch-Gordan coefficients are not implemented, please install cuequivariance; pip install cuequivariance"
) from e
current_ir = wigners[0][0]
out = []
stack = torch.tensor([])
for ir, _, base_o3 in wigners:
if ir in irreps_out and ir == current_ir:
stack = torch.cat((stack, base_o3.squeeze().unsqueeze(-1)), dim=-1)
last_ir = current_ir
elif ir in irreps_out and ir != current_ir:
if len(stack) != 0:
out += [last_ir, stack]
stack = base_o3.squeeze().unsqueeze(-1)
current_ir, last_ir = ir, ir
else:
current_ir = ir
out += [last_ir, stack]
return out
if CUET_AVAILABLE:
def compute_U_cueq(irreps_in, irreps_out, correlation=2):
U = []
irreps_in = cue.Irreps(O3_e3nn, str(irreps_in))
irreps_out = cue.Irreps(O3_e3nn, str(irreps_out))
for _, ir in irreps_out:
ir_str = str(ir)
U.append(ir_str)
U_matrix = cue.reduced_symmetric_tensor_product_basis(
irreps_in, correlation, keep_ir=ir, layout=cue.ir_mul
).array
U_matrix = U_matrix.reshape(ir.dim, *([irreps_in.dim] * correlation), -1)
if ir.dim == 1:
U_matrix = U_matrix[0]
U.append(torch.tensor(U_matrix))
return U
class O3_e3nn(cue.O3):
def __mul__( # pylint: disable=no-self-argument
rep1: "O3_e3nn", rep2: "O3_e3nn"
) -> Iterator["O3_e3nn"]:
return [O3_e3nn(l=ir.l, p=ir.p) for ir in cue.O3.__mul__(rep1, rep2)]
@classmethod
def clebsch_gordan(
cls, rep1: "O3_e3nn", rep2: "O3_e3nn", rep3: "O3_e3nn"
) -> np.ndarray:
rep1, rep2, rep3 = cls._from(rep1), cls._from(rep2), cls._from(rep3)
if rep1.p * rep2.p == rep3.p:
return o3.wigner_3j(rep1.l, rep2.l, rep3.l).numpy()[None] * np.sqrt(
rep3.dim
)
return np.zeros((0, rep1.dim, rep2.dim, rep3.dim))
def __lt__( # pylint: disable=no-self-argument
rep1: "O3_e3nn", rep2: "O3_e3nn"
) -> bool:
rep2 = rep1._from(rep2)
return (rep1.l, rep1.p) < (rep2.l, rep2.p)
@classmethod
def iterator(cls) -> Iterator["O3_e3nn"]:
for l in itertools.count(0):
yield O3_e3nn(l=l, p=1 * (-1) ** l)
yield O3_e3nn(l=l, p=-1 * (-1) ** l)
else:
class O3_e3nn:
pass
print(
"cuequivariance or cuequivariance_torch is not available. Cuequivariance acceleration will be disabled."
)
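# Usage sketch (illustrative only): build the generalized Clebsch-Gordan basis for
# correlation order 2. The returned list alternates irrep / tensor entries, one pair
# for each output irrep requested in irreps_out.
if __name__ == "__main__":
    U = U_matrix_real(irreps_in="1x0e + 1x1o", irreps_out="1x0e", correlation=2)
    for ir, u in zip(U[::2], U[1::2]):
        print(ir, tuple(u.shape))  # e.g. 0e (4, 4, 2)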
###########################################################################################
# Checkpointing
# Authors: Gregor Simm
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
import dataclasses
import logging
import os
import re
from typing import Dict, List, Optional, Tuple
import torch
from .torch_tools import TensorDict
Checkpoint = Dict[str, TensorDict]
@dataclasses.dataclass
class CheckpointState:
model: torch.nn.Module
optimizer: torch.optim.Optimizer
lr_scheduler: torch.optim.lr_scheduler.ExponentialLR
class CheckpointBuilder:
@staticmethod
def create_checkpoint(state: CheckpointState) -> Checkpoint:
return {
"model": state.model.state_dict(),
"optimizer": state.optimizer.state_dict(),
"lr_scheduler": state.lr_scheduler.state_dict(),
}
@staticmethod
def load_checkpoint(
state: CheckpointState, checkpoint: Checkpoint, strict: bool
) -> None:
state.model.load_state_dict(checkpoint["model"], strict=strict) # type: ignore
state.optimizer.load_state_dict(checkpoint["optimizer"])
state.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
@dataclasses.dataclass
class CheckpointPathInfo:
path: str
tag: str
epochs: int
swa: bool
class CheckpointIO:
def __init__(
        self, directory: str, tag: str, keep: bool = False, swa_start: Optional[int] = None
) -> None:
self.directory = directory
self.tag = tag
self.keep = keep
self.old_path: Optional[str] = None
self.swa_start = swa_start
self._epochs_string = "_epoch-"
self._filename_extension = "pt"
def _get_checkpoint_filename(self, epochs: int, swa_start=None) -> str:
if swa_start is not None and epochs >= swa_start:
return (
self.tag
+ self._epochs_string
+ str(epochs)
+ "_swa"
+ "."
+ self._filename_extension
)
return (
self.tag
+ self._epochs_string
+ str(epochs)
+ "."
+ self._filename_extension
)
def _list_file_paths(self) -> List[str]:
if not os.path.isdir(self.directory):
return []
all_paths = [
os.path.join(self.directory, f) for f in os.listdir(self.directory)
]
return [path for path in all_paths if os.path.isfile(path)]
def _parse_checkpoint_path(self, path: str) -> Optional[CheckpointPathInfo]:
filename = os.path.basename(path)
regex = re.compile(
rf"^(?P<tag>.+){self._epochs_string}(?P<epochs>\d+)\.{self._filename_extension}$"
)
regex2 = re.compile(
rf"^(?P<tag>.+){self._epochs_string}(?P<epochs>\d+)_swa\.{self._filename_extension}$"
)
match = regex.match(filename)
match2 = regex2.match(filename)
swa = False
if not match:
if not match2:
return None
match = match2
swa = True
return CheckpointPathInfo(
path=path,
tag=match.group("tag"),
epochs=int(match.group("epochs")),
swa=swa,
)
def _get_latest_checkpoint_path(self, swa) -> Optional[str]:
all_file_paths = self._list_file_paths()
checkpoint_info_list = [
self._parse_checkpoint_path(path) for path in all_file_paths
]
selected_checkpoint_info_list = [
info for info in checkpoint_info_list if info and info.tag == self.tag
]
if len(selected_checkpoint_info_list) == 0:
logging.warning(
f"Cannot find checkpoint with tag '{self.tag}' in '{self.directory}'"
)
return None
selected_checkpoint_info_list_swa = []
selected_checkpoint_info_list_no_swa = []
for ckp in selected_checkpoint_info_list:
if ckp.swa:
selected_checkpoint_info_list_swa.append(ckp)
else:
selected_checkpoint_info_list_no_swa.append(ckp)
if swa:
try:
latest_checkpoint_info = max(
selected_checkpoint_info_list_swa, key=lambda info: info.epochs
)
            except ValueError:
                logging.warning(
                    "No SWA checkpoint found while SWA is enabled; compare the swa_start parameter with the latest checkpoint."
                )
                return None
else:
latest_checkpoint_info = max(
selected_checkpoint_info_list_no_swa, key=lambda info: info.epochs
)
return latest_checkpoint_info.path
def save(
self, checkpoint: Checkpoint, epochs: int, keep_last: bool = False
) -> None:
if not self.keep and self.old_path and not keep_last:
logging.debug(f"Deleting old checkpoint file: {self.old_path}")
os.remove(self.old_path)
filename = self._get_checkpoint_filename(epochs, self.swa_start)
path = os.path.join(self.directory, filename)
logging.debug(f"Saving checkpoint: {path}")
os.makedirs(self.directory, exist_ok=True)
torch.save(obj=checkpoint, f=path)
self.old_path = path
def load_latest(
self, swa: Optional[bool] = False, device: Optional[torch.device] = None
) -> Optional[Tuple[Checkpoint, int]]:
path = self._get_latest_checkpoint_path(swa=swa)
if path is None:
return None
return self.load(path, device=device)
def load(
self, path: str, device: Optional[torch.device] = None
) -> Tuple[Checkpoint, int]:
checkpoint_info = self._parse_checkpoint_path(path)
if checkpoint_info is None:
raise RuntimeError(f"Cannot find path '{path}'")
logging.info(f"Loading checkpoint: {checkpoint_info.path}")
return (
torch.load(f=checkpoint_info.path, map_location=device),
checkpoint_info.epochs,
)
class CheckpointHandler:
def __init__(self, *args, **kwargs) -> None:
self.io = CheckpointIO(*args, **kwargs)
self.builder = CheckpointBuilder()
def save(
self, state: CheckpointState, epochs: int, keep_last: bool = False
) -> None:
checkpoint = self.builder.create_checkpoint(state)
self.io.save(checkpoint, epochs, keep_last)
def load_latest(
self,
state: CheckpointState,
swa: Optional[bool] = False,
device: Optional[torch.device] = None,
strict=False,
) -> Optional[int]:
result = self.io.load_latest(swa=swa, device=device)
if result is None:
return None
checkpoint, epochs = result
self.builder.load_checkpoint(state=state, checkpoint=checkpoint, strict=strict)
return epochs
def load(
self,
state: CheckpointState,
path: str,
strict=False,
device: Optional[torch.device] = None,
) -> int:
checkpoint, epochs = self.io.load(path, device=device)
self.builder.load_checkpoint(state=state, checkpoint=checkpoint, strict=strict)
return epochs
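# Usage sketch (illustrative only; directory and tag names are arbitrary):
#
#     model = torch.nn.Linear(4, 4)
#     optimizer = torch.optim.Adam(model.parameters())
#     lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
#     state = CheckpointState(model=model, optimizer=optimizer, lr_scheduler=lr_scheduler)
#     handler = CheckpointHandler(directory="checkpoints_demo", tag="demo", keep=True)
#     handler.save(state, epochs=10)   # writes checkpoints_demo/demo_epoch-10.pt
#     handler.load_latest(state)       # -> 10
###########################################################################################
# torch.compile preparation utilities
###########################################################################################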
from contextlib import contextmanager
from functools import wraps
from typing import Callable, Tuple
try:
import torch._dynamo as dynamo
except ImportError:
dynamo = None
from e3nn import get_optimization_defaults, set_optimization_defaults
from torch import autograd, nn
from torch.fx import symbolic_trace
ModuleFactory = Callable[..., nn.Module]
TypeTuple = Tuple[type, ...]
@contextmanager
def disable_e3nn_codegen():
    """Context manager that disables the legacy PyTorch code generation used in e3nn."""
    init_val = get_optimization_defaults()["jit_script_fx"]
    set_optimization_defaults(jit_script_fx=False)
    try:
        yield
    finally:
        set_optimization_defaults(jit_script_fx=init_val)
def prepare(func: ModuleFactory, allow_autograd: bool = True) -> ModuleFactory:
"""Function transform that prepares a MACE module for torch.compile
Args:
func (ModuleFactory): A function that creates an nn.Module
allow_autograd (bool, optional): Force inductor compiler to inline call to
`torch.autograd.grad`. Defaults to True.
Returns:
ModuleFactory: Decorated function that creates a torch.compile compatible module
"""
if allow_autograd:
dynamo.allow_in_graph(autograd.grad)
else:
dynamo.disallow_in_graph(autograd.grad)
@wraps(func)
def wrapper(*args, **kwargs):
with disable_e3nn_codegen():
model = func(*args, **kwargs)
model = simplify(model)
return model
return wrapper
_SIMPLIFY_REGISTRY = set()
def simplify_if_compile(module: nn.Module) -> nn.Module:
"""Decorator to register a module for symbolic simplification
    The decorated module will be simplified using `torch.fx.symbolic_trace`.
This constrains the module to not have any dynamic control flow, see:
https://pytorch.org/docs/stable/fx.html#limitations-of-symbolic-tracing
Args:
module (nn.Module): the module to register
Returns:
nn.Module: registered module
"""
_SIMPLIFY_REGISTRY.add(module)
return module
def simplify(module: nn.Module) -> nn.Module:
"""Recursively searches for registered modules to simplify with
`torch.fx.symbolic_trace` to support compiling with the PyTorch Dynamo compiler.
    Modules are registered with the `simplify_if_compile` decorator.
Args:
module (nn.Module): the module to simplify
Returns:
nn.Module: the simplified module
"""
simplify_types = tuple(_SIMPLIFY_REGISTRY)
for name, child in module.named_children():
if isinstance(child, simplify_types):
traced = symbolic_trace(child)
setattr(module, name, traced)
else:
simplify(child)
return module
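# Usage sketch (illustrative only): register a module without dynamic control flow
# and let `simplify` swap it for its symbolically traced torch.fx GraphModule.
if __name__ == "__main__":
    @simplify_if_compile
    class Square(nn.Module):
        def forward(self, x):
            return x * x

    parent = simplify(nn.Sequential(Square(), nn.ReLU()))
    print(parent)  # the Square child is now a GraphModule
###########################################################################################
# Default keys for reference properties in training data
###########################################################################################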
from __future__ import annotations
from enum import Enum
class DefaultKeys(Enum):
ENERGY = "REF_energy"
FORCES = "REF_forces"
STRESS = "REF_stress"
VIRIALS = "REF_virials"
DIPOLE = "dipole"
HEAD = "head"
CHARGES = "REF_charges"
@staticmethod
def keydict() -> dict[str, str]:
key_dict = {}
for member in DefaultKeys:
key_name = f"{member.name.lower()}_key"
key_dict[key_name] = member.value
return key_dict
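# Illustrative output (not part of the module):
#     DefaultKeys.keydict()
#     -> {"energy_key": "REF_energy", "forces_key": "REF_forces", "stress_key": "REF_stress",
#         "virials_key": "REF_virials", "dipole_key": "dipole", "head_key": "head",
#         "charges_key": "REF_charges"}
###########################################################################################
# AseDBDataset package exports
###########################################################################################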
from .lmdb_dataset_tools import AseDBDataset
__all__ = ["AseDBDataset"]
# AseDBDataset Library
This library provides a standalone implementation of the AseDBDataset class extracted from the FairChem codebase. AseDBDataset lets you connect to ASE databases using various backends, including JSON, SQLite, and LMDB.
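## Example Usage

A minimal sketch (the import path and the `src` value are placeholders; adjust them to your installation and database file):

```python
from lmdb_dataset_tools import AseDBDataset

# The backend (JSON / SQLite / LMDB) is selected based on the database file.
dataset = AseDBDataset(config={"src": "path/to/database.aselmdb"})
atoms = dataset.get_atoms(0)  # returns an ase.Atoms object
print(len(dataset), atoms)
```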
## License Information
The code in this repository contains components from multiple sources with different licenses:
1. **Main Code (AseDBDataset, AseAtomsDataset, BaseDataset, etc.)**:
- Original Source: Meta's FairChem codebase
- License: MIT License
- Copyright: Meta, Inc. and its affiliates
2. **LMDBDatabase Component**:
- Original Source: Modified from ASE database JSON backend
- License: LGPL 2.1
- The ASE notice for the LGPL 2.1 license is available at: https://gitlab.com/ase/ase/-/blob/master/LICENSE
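###########################################################################################
# Fine-tuning utilities: transferring foundation-model weights
###########################################################################################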
import torch
from mace.tools.utils import AtomicNumberTable
def load_foundations_elements(
model: torch.nn.Module,
model_foundations: torch.nn.Module,
table: AtomicNumberTable,
load_readout=False,
use_shift=True,
use_scale=True,
max_L=2,
):
"""
Load the foundations of a model into a model for fine-tuning.
"""
assert model_foundations.r_max == model.r_max
z_table = AtomicNumberTable([int(z) for z in model_foundations.atomic_numbers])
model_heads = model.heads
new_z_table = table
num_species_foundations = len(z_table.zs)
num_channels_foundation = (
model_foundations.node_embedding.linear.weight.shape[0]
// num_species_foundations
)
indices_weights = [z_table.z_to_index(z) for z in new_z_table.zs]
num_radial = model.radial_embedding.out_dim
num_species = len(indices_weights)
max_ell = model.spherical_harmonics._lmax # pylint: disable=protected-access
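    # Copy only the embedding rows for the retained elements; the sqrt factor rescales
    # for the change in species count (assumed to compensate the embedding normalization).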
model.node_embedding.linear.weight = torch.nn.Parameter(
model_foundations.node_embedding.linear.weight.view(
num_species_foundations, -1
)[indices_weights, :]
.flatten()
.clone()
/ (num_species_foundations / num_species) ** 0.5
)
if model.radial_embedding.bessel_fn.__class__.__name__ == "BesselBasis":
model.radial_embedding.bessel_fn.bessel_weights = torch.nn.Parameter(
model_foundations.radial_embedding.bessel_fn.bessel_weights.clone()
)
for i in range(int(model.num_interactions)):
model.interactions[i].linear_up.weight = torch.nn.Parameter(
model_foundations.interactions[i].linear_up.weight.clone()
)
model.interactions[i].avg_num_neighbors = model_foundations.interactions[
i
].avg_num_neighbors
        for j in range(4):  # Assuming 4 layers in conv_tp_weights
layer_name = f"layer{j}"
if j == 0:
getattr(model.interactions[i].conv_tp_weights, layer_name).weight = (
torch.nn.Parameter(
getattr(
model_foundations.interactions[i].conv_tp_weights,
layer_name,
)
.weight[:num_radial, :]
.clone()
)
)
else:
getattr(model.interactions[i].conv_tp_weights, layer_name).weight = (
torch.nn.Parameter(
getattr(
model_foundations.interactions[i].conv_tp_weights,
layer_name,
).weight.clone()
)
)
model.interactions[i].linear.weight = torch.nn.Parameter(
model_foundations.interactions[i].linear.weight.clone()
)
if model.interactions[i].__class__.__name__ in [
"RealAgnosticResidualInteractionBlock",
"RealAgnosticDensityResidualInteractionBlock",
]:
model.interactions[i].skip_tp.weight = torch.nn.Parameter(
model_foundations.interactions[i]
.skip_tp.weight.reshape(
num_channels_foundation,
num_species_foundations,
num_channels_foundation,
)[:, indices_weights, :]
.flatten()
.clone()
/ (num_species_foundations / num_species) ** 0.5
)
else:
model.interactions[i].skip_tp.weight = torch.nn.Parameter(
model_foundations.interactions[i]
.skip_tp.weight.reshape(
num_channels_foundation,
(max_ell + 1),
num_species_foundations,
num_channels_foundation,
)[:, :, indices_weights, :]
.flatten()
.clone()
/ (num_species_foundations / num_species) ** 0.5
)
if model.interactions[i].__class__.__name__ in [
"RealAgnosticDensityInteractionBlock",
"RealAgnosticDensityResidualInteractionBlock",
]:
# Assuming only 1 layer in density_fn
getattr(model.interactions[i].density_fn, "layer0").weight = (
torch.nn.Parameter(
getattr(
model_foundations.interactions[i].density_fn,
"layer0",
).weight.clone()
)
)
# Transferring products
for i in range(2): # Assuming 2 products modules
        max_range = max_L + 1 if i == 0 else 1
        for j in range(max_range):  # max_L + 1 contractions in the first product block, 1 in the second
model.products[i].symmetric_contractions.contractions[j].weights_max = (
torch.nn.Parameter(
model_foundations.products[i]
.symmetric_contractions.contractions[j]
.weights_max[indices_weights, :, :]
.clone()
)
)
for k in range(2): # Assuming 2 weights in each contraction
model.products[i].symmetric_contractions.contractions[j].weights[k] = (
torch.nn.Parameter(
model_foundations.products[i]
.symmetric_contractions.contractions[j]
.weights[k][indices_weights, :, :]
.clone()
)
)
model.products[i].linear.weight = torch.nn.Parameter(
model_foundations.products[i].linear.weight.clone()
)
if load_readout:
# Transferring readouts
        model_readouts_zero_linear_weight = (
            model_foundations.readouts[0]
            .linear.weight.view(num_channels_foundation, -1)
            .repeat(1, len(model_heads))
            .flatten()
            .clone()
        )
model.readouts[0].linear.weight = torch.nn.Parameter(
model_readouts_zero_linear_weight
)
shape_input_1 = (
model_foundations.readouts[1].linear_1.__dict__["irreps_out"].num_irreps
)
shape_output_1 = model.readouts[1].linear_1.__dict__["irreps_out"].num_irreps
        model_readouts_one_linear_1_weight = (
            model_foundations.readouts[1]
            .linear_1.weight.view(num_channels_foundation, -1)
            .repeat(1, len(model_heads))
            .flatten()
            .clone()
        )
model.readouts[1].linear_1.weight = torch.nn.Parameter(
model_readouts_one_linear_1_weight
)
        model_readouts_one_linear_2_weight = (
            model_foundations.readouts[1]
            .linear_2.weight.view(shape_input_1, -1)
            .repeat(len(model_heads), len(model_heads))
            .flatten()
            .clone()
            / (shape_input_1 / shape_output_1) ** 0.5
        )
model.readouts[1].linear_2.weight = torch.nn.Parameter(
model_readouts_one_linear_2_weight
)
if model_foundations.scale_shift is not None:
if use_scale:
model.scale_shift.scale = model_foundations.scale_shift.scale.repeat(
len(model_heads)
).clone()
if use_shift:
model.scale_shift.shift = model_foundations.scale_shift.shift.repeat(
len(model_heads)
).clone()
return model
def load_foundations(
model,
model_foundations,
):
for name, param in model_foundations.named_parameters():
if name in model.state_dict().keys():
if "readouts" not in name:
model.state_dict()[name].copy_(param)
return model
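###########################################################################################
# Model configuration and construction
###########################################################################################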
import ast
import logging
import numpy as np
from e3nn import o3
from mace import modules
from mace.tools.finetuning_utils import load_foundations_elements
from mace.tools.scripts_utils import extract_config_mace_model
from mace.tools.utils import AtomicNumberTable
def configure_model(
args,
train_loader,
atomic_energies,
model_foundation=None,
heads=None,
z_table=None,
head_configs=None,
):
# Selecting outputs
compute_virials = args.loss == "virials"
compute_stress = args.loss in ("stress", "huber", "universal")
if compute_virials:
args.compute_virials = True
args.error_table = "PerAtomRMSEstressvirials"
elif compute_stress:
args.compute_stress = True
args.error_table = "PerAtomRMSEstressvirials"
output_args = {
"energy": args.compute_energy,
"forces": args.compute_forces,
"virials": compute_virials,
"stress": compute_stress,
"dipoles": args.compute_dipole,
}
logging.info(
f"During training the following quantities will be reported: {', '.join([f'{report}' for report, value in output_args.items() if value])}"
)
logging.info("===========MODEL DETAILS===========")
if args.scaling == "no_scaling":
args.std = 1.0
if head_configs is not None:
for head_config in head_configs:
head_config.std = 1.0
logging.info("No scaling selected")
if (
head_configs is not None
and args.std is not None
and not isinstance(args.std, list)
):
atomic_inter_scale = []
for head_config in head_configs:
if hasattr(head_config, "std") and head_config.std is not None:
atomic_inter_scale.append(head_config.std)
elif args.std is not None:
atomic_inter_scale.append(
args.std if isinstance(args.std, float) else 1.0
)
args.std = atomic_inter_scale
elif (args.mean is None or args.std is None) and args.model != "AtomicDipolesMACE":
args.mean, args.std = modules.scaling_classes[args.scaling](
train_loader, atomic_energies
)
# Build model
if model_foundation is not None and args.model in ["MACE", "ScaleShiftMACE"]:
logging.info("Loading FOUNDATION model")
model_config_foundation = extract_config_mace_model(model_foundation)
model_config_foundation["atomic_energies"] = atomic_energies
if args.foundation_model_elements:
foundation_z_table = AtomicNumberTable(
[int(z) for z in model_foundation.atomic_numbers]
)
model_config_foundation["atomic_numbers"] = foundation_z_table.zs
model_config_foundation["num_elements"] = len(foundation_z_table)
z_table = foundation_z_table
logging.info(
f"Using all elements from foundation model: {foundation_z_table.zs}"
)
else:
model_config_foundation["atomic_numbers"] = z_table.zs
model_config_foundation["num_elements"] = len(z_table)
logging.info(f"Using filtered elements: {z_table.zs}")
args.max_L = model_config_foundation["hidden_irreps"].lmax
if args.model == "MACE" and model_foundation.__class__.__name__ == "MACE":
model_config_foundation["atomic_inter_shift"] = [0.0] * len(heads)
else:
model_config_foundation["atomic_inter_shift"] = (
_determine_atomic_inter_shift(args.mean, heads)
)
model_config_foundation["atomic_inter_scale"] = [1.0] * len(heads)
args.avg_num_neighbors = model_config_foundation["avg_num_neighbors"]
args.model = "FoundationMACE"
model_config_foundation["heads"] = heads
model_config = model_config_foundation
logging.info("Model configuration extracted from foundation model")
logging.info("Using universal loss function for fine-tuning")
logging.info(
f"Message passing with hidden irreps {model_config_foundation['hidden_irreps']})"
)
logging.info(
f"{model_config_foundation['num_interactions']} layers, each with correlation order: {model_config_foundation['correlation']} (body order: {model_config_foundation['correlation']+1}) and spherical harmonics up to: l={model_config_foundation['max_ell']}"
)
logging.info(
f"Radial cutoff: {model_config_foundation['r_max']} A (total receptive field for each atom: {model_config_foundation['r_max'] * model_config_foundation['num_interactions']} A)"
)
logging.info(
f"Distance transform for radial basis functions: {model_config_foundation['distance_transform']}"
)
else:
logging.info("Building model")
logging.info(
f"Message passing with {args.num_channels} channels and max_L={args.max_L} ({args.hidden_irreps})"
)
logging.info(
f"{args.num_interactions} layers, each with correlation order: {args.correlation} (body order: {args.correlation+1}) and spherical harmonics up to: l={args.max_ell}"
)
logging.info(
f"{args.num_radial_basis} radial and {args.num_cutoff_basis} basis functions"
)
logging.info(
f"Radial cutoff: {args.r_max} A (total receptive field for each atom: {args.r_max * args.num_interactions} A)"
)
logging.info(
f"Distance transform for radial basis functions: {args.distance_transform}"
)
        assert (
            len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1
        ), "All channels must have the same dimension; use the num_channels and max_L keywords to specify the number of channels and the maximum L"
logging.info(f"Hidden irreps: {args.hidden_irreps}")
model_config = dict(
r_max=args.r_max,
num_bessel=args.num_radial_basis,
num_polynomial_cutoff=args.num_cutoff_basis,
max_ell=args.max_ell,
interaction_cls=modules.interaction_classes[args.interaction],
num_interactions=args.num_interactions,
num_elements=len(z_table),
hidden_irreps=o3.Irreps(args.hidden_irreps),
atomic_energies=atomic_energies,
avg_num_neighbors=args.avg_num_neighbors,
atomic_numbers=z_table.zs,
)
model_config_foundation = None
model = _build_model(args, model_config, model_config_foundation, heads)
if model_foundation is not None:
model = load_foundations_elements(
model,
model_foundation,
z_table,
load_readout=args.foundation_filter_elements,
max_L=args.max_L,
)
return model, output_args
def _determine_atomic_inter_shift(mean, heads):
if isinstance(mean, np.ndarray):
if mean.size == 1:
return mean.item()
if mean.size == len(heads):
return mean.tolist()
logging.info("Mean not in correct format, using default value of 0.0")
return [0.0] * len(heads)
if isinstance(mean, list) and len(mean) == len(heads):
return mean
if isinstance(mean, float):
return [mean] * len(heads)
logging.info("Mean not in correct format, using default value of 0.0")
return [0.0] * len(heads)
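# Illustrative behaviour (not part of the module), assuming heads = ["h1", "h2"]:
#     _determine_atomic_inter_shift(0.5, heads)              -> [0.5, 0.5]
#     _determine_atomic_inter_shift([0.1, 0.2], heads)       -> [0.1, 0.2]
#     _determine_atomic_inter_shift(np.array([0.3]), heads)  -> 0.3
#     _determine_atomic_inter_shift("bad", heads)            -> [0.0, 0.0] (after logging a note)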
def _build_model(
args, model_config, model_config_foundation, heads
): # pylint: disable=too-many-return-statements
if args.model == "MACE":
if args.interaction_first not in [
"RealAgnosticInteractionBlock",
"RealAgnosticDensityInteractionBlock",
]:
args.interaction_first = "RealAgnosticInteractionBlock"
return modules.ScaleShiftMACE(
**model_config,
pair_repulsion=args.pair_repulsion,
distance_transform=args.distance_transform,
correlation=args.correlation,
gate=modules.gate_dict[args.gate],
interaction_cls_first=modules.interaction_classes[args.interaction_first],
MLP_irreps=o3.Irreps(args.MLP_irreps),
atomic_inter_scale=args.std,
atomic_inter_shift=[0.0] * len(heads),
radial_MLP=ast.literal_eval(args.radial_MLP),
radial_type=args.radial_type,
heads=heads,
)
if args.model == "ScaleShiftMACE":
return modules.ScaleShiftMACE(
**model_config,
pair_repulsion=args.pair_repulsion,
distance_transform=args.distance_transform,
correlation=args.correlation,
gate=modules.gate_dict[args.gate],
interaction_cls_first=modules.interaction_classes[args.interaction_first],
MLP_irreps=o3.Irreps(args.MLP_irreps),
atomic_inter_scale=args.std,
atomic_inter_shift=args.mean,
radial_MLP=ast.literal_eval(args.radial_MLP),
radial_type=args.radial_type,
heads=heads,
)
if args.model == "FoundationMACE":
return modules.ScaleShiftMACE(**model_config_foundation)
if args.model == "ScaleShiftBOTNet":
# say it is deprecated
raise RuntimeError("ScaleShiftBOTNet is deprecated, use MACE instead")
if args.model == "BOTNet":
raise RuntimeError("BOTNet is deprecated, use MACE instead")
if args.model == "AtomicDipolesMACE":
assert args.loss == "dipole", "Use dipole loss with AtomicDipolesMACE model"
assert (
args.error_table == "DipoleRMSE"
), "Use error_table DipoleRMSE with AtomicDipolesMACE model"
return modules.AtomicDipolesMACE(
**model_config,
correlation=args.correlation,
gate=modules.gate_dict[args.gate],
interaction_cls_first=modules.interaction_classes[
"RealAgnosticInteractionBlock"
],
MLP_irreps=o3.Irreps(args.MLP_irreps),
)
if args.model == "EnergyDipolesMACE":
assert (
args.loss == "energy_forces_dipole"
), "Use energy_forces_dipole loss with EnergyDipolesMACE model"
        assert (
            args.error_table == "EnergyDipoleRMSE"
        ), "Use error_table EnergyDipoleRMSE with EnergyDipolesMACE model"
return modules.EnergyDipolesMACE(
**model_config,
correlation=args.correlation,
gate=modules.gate_dict[args.gate],
interaction_cls_first=modules.interaction_classes[
"RealAgnosticInteractionBlock"
],
MLP_irreps=o3.Irreps(args.MLP_irreps),
)
raise RuntimeError(f"Unknown model: '{args.model}'")
import argparse
import ast
import dataclasses
import logging
import os
import urllib.request
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import torch
from mace.cli.fine_tuning_select import (
FilteringType,
SelectionSettings,
SubselectType,
select_samples,
)
from mace.data import KeySpecification
from mace.tools.scripts_utils import SubsetCollection, get_dataset_from_xyz
@dataclasses.dataclass
class HeadConfig:
head_name: str
key_specification: KeySpecification
train_file: Optional[Union[str, List[str]]] = None
valid_file: Optional[Union[str, List[str]]] = None
test_file: Optional[str] = None
test_dir: Optional[str] = None
E0s: Optional[Any] = None
statistics_file: Optional[str] = None
valid_fraction: Optional[float] = None
config_type_weights: Optional[Dict[str, float]] = None
keep_isolated_atoms: Optional[bool] = None
atomic_numbers: Optional[Union[List[int], List[str]]] = None
mean: Optional[float] = None
std: Optional[float] = None
avg_num_neighbors: Optional[float] = None
compute_avg_num_neighbors: Optional[bool] = None
collections: Optional[SubsetCollection] = None
train_loader: Optional[torch.utils.data.DataLoader] = None
z_table: Optional[Any] = None
atomic_energies_dict: Optional[Dict[str, float]] = None
def dict_head_to_dataclass(
head: Dict[str, Any], head_name: str, args: argparse.Namespace
) -> HeadConfig:
"""Convert head dictionary to HeadConfig dataclass."""
# parser+head args that have no defaults but are required
    if (args.train_file is None) and (head.get("train_file", None) is None):
        raise ValueError(
            "train_file is not set in the head config YAML or via the command-line args"
        )
return HeadConfig(
head_name=head_name,
train_file=head.get("train_file", args.train_file),
valid_file=head.get("valid_file", args.valid_file),
test_file=head.get("test_file", None),
test_dir=head.get("test_dir", None),
E0s=head.get("E0s", args.E0s),
statistics_file=head.get("statistics_file", args.statistics_file),
valid_fraction=head.get("valid_fraction", args.valid_fraction),
config_type_weights=head.get("config_type_weights", args.config_type_weights),
compute_avg_num_neighbors=head.get(
"compute_avg_num_neighbors", args.compute_avg_num_neighbors
),
atomic_numbers=head.get("atomic_numbers", args.atomic_numbers),
mean=head.get("mean", args.mean),
std=head.get("std", args.std),
avg_num_neighbors=head.get("avg_num_neighbors", args.avg_num_neighbors),
key_specification=head["key_specification"],
keep_isolated_atoms=head.get("keep_isolated_atoms", args.keep_isolated_atoms),
)
def prepare_default_head(args: argparse.Namespace) -> Dict[str, Any]:
"""Prepare a default head from args."""
return {
"Default": {
"train_file": args.train_file,
"valid_file": args.valid_file,
"test_file": args.test_file,
"test_dir": args.test_dir,
"E0s": args.E0s,
"statistics_file": args.statistics_file,
"key_specification": args.key_specification,
"valid_fraction": args.valid_fraction,
"config_type_weights": args.config_type_weights,
"keep_isolated_atoms": args.keep_isolated_atoms,
}
}
def prepare_pt_head(
args: argparse.Namespace,
pt_keyspec: KeySpecification,
foundation_model_num_neighbours: float,
) -> Dict[str, Any]:
"""Prepare a pretraining head from args."""
if (
args.foundation_model in ["small", "medium", "large"]
or args.pt_train_file == "mp"
):
logging.info(
"Using foundation model for multiheads finetuning with Materials Project data"
)
pt_keyspec.update(
info_keys={"energy": "energy", "stress": "stress"},
arrays_keys={"forces": "forces"},
)
pt_head = {
"train_file": "mp",
"E0s": "foundation",
"statistics_file": None,
"key_specification": pt_keyspec,
"avg_num_neighbors": foundation_model_num_neighbours,
"compute_avg_num_neighbors": False,
}
else:
pt_head = {
"train_file": args.pt_train_file,
"valid_file": args.pt_valid_file,
"E0s": "foundation",
"statistics_file": args.statistics_file,
"valid_fraction": args.valid_fraction,
"key_specification": pt_keyspec,
"avg_num_neighbors": foundation_model_num_neighbours,
"keep_isolated_atoms": args.keep_isolated_atoms,
"compute_avg_num_neighbors": False,
}
return pt_head
def assemble_mp_data(
args: argparse.Namespace,
head_config_pt: HeadConfig,
tag: str,
) -> SubsetCollection:
"""Assemble Materials Project data for fine-tuning."""
try:
checkpoint_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mp_traj_combined.xyz"
cache_dir = (
Path(os.environ.get("XDG_CACHE_HOME", "~/")).expanduser() / ".cache/mace"
)
checkpoint_url_name = "".join(
c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_"
)
cached_dataset_path = f"{cache_dir}/{checkpoint_url_name}"
if not os.path.isfile(cached_dataset_path):
os.makedirs(cache_dir, exist_ok=True)
# download and save to disk
logging.info("Downloading MP structures for finetuning")
_, http_msg = urllib.request.urlretrieve(
checkpoint_url, cached_dataset_path
)
if "Content-Type: text/html" in http_msg:
raise RuntimeError(
f"Dataset download failed, please check the URL {checkpoint_url}"
)
logging.info(f"Materials Project dataset to {cached_dataset_path}")
output = f"mp_finetuning-{tag}.xyz"
atomic_numbers = (
ast.literal_eval(args.atomic_numbers)
if args.atomic_numbers is not None
else None
)
settings = SelectionSettings(
configs_pt=cached_dataset_path,
output=f"mp_finetuning-{tag}.xyz",
atomic_numbers=atomic_numbers,
num_samples=args.num_samples_pt,
seed=args.seed,
head_pt="pbe_mp",
weight_pt=args.weight_pt_head,
filtering_type=FilteringType(args.filter_type_pt),
subselect=SubselectType(args.subselect_pt),
default_dtype=args.default_dtype,
)
select_samples(settings)
head_config_pt.train_file = [output]
collections_mp, _ = get_dataset_from_xyz(
work_dir=args.work_dir,
train_path=output,
valid_path=None,
valid_fraction=args.valid_fraction,
config_type_weights=None,
test_path=None,
seed=args.seed,
key_specification=head_config_pt.key_specification,
head_name="pt_head",
keep_isolated_atoms=args.keep_isolated_atoms,
)
return collections_mp
except Exception as exc:
        raise RuntimeError(
            "Materials Project dataset download or preparation failed, and no local copy was found"
        ) from exc