"""basic scatter_sum operations from torch_scatter from
https://github.com/mir-group/pytorch_runstats/blob/main/torch_runstats/scatter_sum.py
Using code from https://github.com/rusty1s/pytorch_scatter, but cut down to avoid a dependency.
PyTorch plans to move these features into the main repo, but until then,
to make installation simpler, we need this pure python set of wrappers
that don't require installing PyTorch C++ extensions.
See https://github.com/pytorch/pytorch/issues/63780.
"""
from typing import Optional
import torch
def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
if dim < 0:
dim = other.dim() + dim
if src.dim() == 1:
for _ in range(0, dim):
src = src.unsqueeze(0)
for _ in range(src.dim(), other.dim()):
src = src.unsqueeze(-1)
src = src.expand_as(other)
return src
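# A hedged usage sketch (values are illustrative): `_broadcast` lifts a 1-D
# tensor to the shape of `other` so it can be used for scattering along `dim`.
#
#   src = torch.tensor([10.0, 20.0])        # shape [2]
#   other = torch.zeros(2, 3)
#   _broadcast(src, other, dim=0)
#   # -> tensor([[10., 10., 10.],
#   #            [20., 20., 20.]])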
def scatter_sum(
src: torch.Tensor,
index: torch.Tensor,
dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None,
reduce: str = "sum",
) -> torch.Tensor:
assert reduce == "sum" # for now, TODO
index = _broadcast(index, src, dim)
    if out is None:
        size = list(src.size())
        if dim_size is not None:
            size[dim] = dim_size
        elif index.numel() == 0:
            size[dim] = 0
        else:
            size[dim] = int(index.max()) + 1
        out = torch.zeros(size, dtype=src.dtype, device=src.device)
    return out.scatter_add_(dim, index, src)
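# A hedged example (values are illustrative): entries of `src` that share an
# `index` value are summed together along `dim`.
#
#   src = torch.tensor([1.0, 2.0, 3.0, 4.0])
#   index = torch.tensor([0, 1, 0, 1])
#   scatter_sum(src, index, dim=0)   # -> tensor([4., 6.])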
def scatter_std(
src: torch.Tensor,
index: torch.Tensor,
dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None,
unbiased: bool = True,
) -> torch.Tensor:
if out is not None:
dim_size = out.size(dim)
if dim < 0:
dim = src.dim() + dim
count_dim = dim
if index.dim() <= dim:
count_dim = index.dim() - 1
ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
count = scatter_sum(ones, index, count_dim, dim_size=dim_size)
index = _broadcast(index, src, dim)
tmp = scatter_sum(src, index, dim, dim_size=dim_size)
count = _broadcast(count, tmp, dim).clamp(1)
mean = tmp.div(count)
var = src - mean.gather(dim, index)
var = var * var
out = scatter_sum(var, index, dim, out, dim_size)
if unbiased:
count = count.sub(1).clamp_(1)
out = out.div(count + 1e-6).sqrt()
return out
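# A hedged example (values are illustrative): per-group unbiased standard
# deviation of `src`, grouped by `index`.
#
#   src = torch.tensor([1.0, 3.0, 2.0, 4.0])
#   index = torch.tensor([0, 0, 1, 1])
#   scatter_std(src, index, dim=0)   # -> approximately tensor([1.4142, 1.4142])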
def scatter_mean(
src: torch.Tensor,
index: torch.Tensor,
dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None,
) -> torch.Tensor:
out = scatter_sum(src, index, dim, out, dim_size)
dim_size = out.size(dim)
index_dim = dim
if index_dim < 0:
index_dim = index_dim + src.dim()
if index.dim() <= index_dim:
index_dim = index.dim() - 1
ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
count = scatter_sum(ones, index, index_dim, None, dim_size)
count[count < 1] = 1
count = _broadcast(count, out, dim)
if out.is_floating_point():
out.true_divide_(count)
else:
out.div_(count, rounding_mode="floor")
return out
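# A hedged example (values are illustrative): per-group mean, i.e. the
# scattered sum divided by the per-group counts.
#
#   src = torch.tensor([1.0, 2.0, 3.0, 5.0])
#   index = torch.tensor([0, 0, 1, 1])
#   scatter_mean(src, index, dim=0)  # -> tensor([1.5000, 4.0000])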
###########################################################################################
# Training utils
# Authors: David Kovacs, Ilyes Batatia
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
import argparse
import ast
import dataclasses
import json
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.distributed
from e3nn import o3
from torch.optim.swa_utils import SWALR, AveragedModel
from mace import data, modules, tools
from mace.data import KeySpecification
from mace.tools.train import SWAContainer
@dataclasses.dataclass
class SubsetCollection:
train: data.Configurations
valid: data.Configurations
tests: List[Tuple[str, data.Configurations]]
def log_dataset_contents(dataset: data.Configurations, dataset_name: str) -> None:
log_string = f"{dataset_name} ["
for prop_name in dataset[0].properties.keys():
if prop_name == "dipole":
log_string += f"{prop_name} components: {int(np.sum([np.sum(config.property_weights[prop_name]) for config in dataset]))}, "
else:
log_string += f"{prop_name}: {int(np.sum([config.property_weights[prop_name] for config in dataset]))}, "
log_string = log_string[:-2] + "]"
logging.info(log_string)
def get_dataset_from_xyz(
work_dir: str,
train_path: Union[str, List[str]],
valid_path: Optional[Union[str, List[str]]],
valid_fraction: float,
key_specification: KeySpecification,
config_type_weights: Optional[Dict] = None,
test_path: Optional[Union[str, List[str]]] = None,
seed: int = 1234,
keep_isolated_atoms: bool = False,
head_name: str = "Default",
) -> Tuple[SubsetCollection, Optional[Dict[int, float]]]:
"""
Load training, validation, and test datasets from xyz files.
Args:
work_dir: Working directory for saving split information
train_path: Path or list of paths to training xyz files
valid_path: Path or list of paths to validation xyz files
valid_fraction: Fraction of training data to use for validation if valid_path is None
config_type_weights: Dictionary of weights for each configuration type
key_specification: KeySpecification object for loading data
test_path: Path or list of paths to test xyz files
seed: Random seed for train/validation split
keep_isolated_atoms: Whether to keep isolated atoms in the dataset
head_name: Name of the head for multi-head models
Returns:
Tuple containing:
- SubsetCollection with train, valid, and test configurations
- Dictionary of atomic energies (or None if not available)
"""
# Convert input paths to lists if they're not already
train_paths = [train_path] if isinstance(train_path, str) else train_path
    valid_paths = [valid_path] if isinstance(valid_path, str) else valid_path
    test_paths = [test_path] if isinstance(test_path, str) else test_path
# Initialize collections and atomic energies tracking
all_train_configs = []
all_valid_configs = []
all_test_configs = []
# For tracking atomic energies across files
atomic_energies_values = {} # Element Z -> list of energy values
atomic_energies_counts = {} # Element Z -> count of files with this element
# Process training files
for i, path in enumerate(train_paths):
logging.debug(f"Loading training file: {path}")
ae_dict, train_configs = data.load_from_xyz(
file_path=path,
config_type_weights=config_type_weights,
key_specification=key_specification,
extract_atomic_energies=True, # Extract from all files to average
keep_isolated_atoms=keep_isolated_atoms,
head_name=head_name,
)
all_train_configs.extend(train_configs)
# Track atomic energies from each file for averaging
if ae_dict:
for element, energy in ae_dict.items():
if element not in atomic_energies_values:
atomic_energies_values[element] = []
atomic_energies_counts[element] = 0
atomic_energies_values[element].append(energy)
atomic_energies_counts[element] += 1
log_dataset_contents(train_configs, f"Training set {i+1}/{len(train_paths)}")
# Log total training set info
log_dataset_contents(all_train_configs, "Total Training set")
# Process validation files if provided
if valid_paths:
for i, path in enumerate(valid_paths):
_, valid_configs = data.load_from_xyz(
file_path=path,
config_type_weights=config_type_weights,
key_specification=key_specification,
extract_atomic_energies=False,
head_name=head_name,
)
all_valid_configs.extend(valid_configs)
log_dataset_contents(
valid_configs, f"Validation set {i+1}/{len(valid_paths)}"
)
# Log total validation set info
log_dataset_contents(all_valid_configs, "Total Validation set")
train_configs = all_train_configs
valid_configs = all_valid_configs
else:
# Split training data if no validation files are provided
logging.info("No validation set provided, splitting training data instead.")
train_configs, valid_configs = data.random_train_valid_split(
all_train_configs, valid_fraction, seed, work_dir
)
log_dataset_contents(train_configs, "Random Split Training set")
log_dataset_contents(valid_configs, "Random Split Validation set")
test_configs_by_type = []
if test_paths:
for i, path in enumerate(test_paths):
_, test_configs = data.load_from_xyz(
file_path=path,
config_type_weights=config_type_weights,
key_specification=key_specification,
extract_atomic_energies=False,
head_name=head_name,
)
all_test_configs.extend(test_configs)
log_dataset_contents(test_configs, f"Test set {i+1}/{len(test_paths)}")
# Create list of tuples (config_type, list(Atoms))
test_configs_by_type = data.test_config_types(all_test_configs)
log_dataset_contents(all_test_configs, "Total Test set")
atomic_energies_dict = {}
for element, values in atomic_energies_values.items():
if atomic_energies_counts[element] > 1:
atomic_energies_dict[element] = sum(values) / len(values)
logging.debug(
f"Element {element} found in {atomic_energies_counts[element]} files. Using average E0: {atomic_energies_dict[element]:.6f} eV"
)
else:
atomic_energies_dict[element] = values[0]
logging.debug(
f"Element {element} found in 1 file. Using E0: {atomic_energies_dict[element]:.6f} eV"
)
return (
SubsetCollection(
train=train_configs, valid=valid_configs, tests=test_configs_by_type
),
atomic_energies_dict if atomic_energies_dict else None,
)
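# A hedged usage sketch; the file names and the default-constructed
# KeySpecification below are illustrative assumptions, not values from this repo.
#
#   collections, atomic_energies_dict = get_dataset_from_xyz(
#       work_dir="results",
#       train_path=["train_1.xyz", "train_2.xyz"],
#       valid_path=None,             # triggers a random train/valid split
#       valid_fraction=0.1,
#       key_specification=KeySpecification(),
#   )
#   print(len(collections.train), len(collections.valid))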
def get_config_type_weights(ct_weights):
"""
Get config type weights from command line argument
"""
try:
config_type_weights = ast.literal_eval(ct_weights)
assert isinstance(config_type_weights, dict)
except Exception as e: # pylint: disable=W0703
logging.warning(
f"Config type weights not specified correctly ({e}), using Default"
)
config_type_weights = {"Default": 1.0}
return config_type_weights
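# A hedged example (config-type names are illustrative): the CLI passes the
# weights as a Python dict literal.
#
#   get_config_type_weights('{"Default": 1.0, "surface": 5.0}')
#   # -> {"Default": 1.0, "surface": 5.0}
#   get_config_type_weights("not-a-dict")
#   # -> {"Default": 1.0} (after logging a warning)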
def print_git_commit():
try:
import git
repo = git.Repo(search_parent_directories=True)
commit = repo.head.commit.hexsha
logging.debug(f"Current Git commit: {commit}")
return commit
except Exception as e: # pylint: disable=W0703
logging.debug(f"Error accessing Git repository: {e}")
return "None"
def extract_config_mace_model(model: torch.nn.Module) -> Dict[str, Any]:
if model.__class__.__name__ != "ScaleShiftMACE":
return {"error": "Model is not a ScaleShiftMACE model"}
def radial_to_name(radial_type):
if radial_type == "BesselBasis":
return "bessel"
if radial_type == "GaussianBasis":
return "gaussian"
if radial_type == "ChebychevBasis":
return "chebyshev"
return radial_type
def radial_to_transform(radial):
if not hasattr(radial, "distance_transform"):
return None
if radial.distance_transform.__class__.__name__ == "AgnesiTransform":
return "Agnesi"
if radial.distance_transform.__class__.__name__ == "SoftTransform":
return "Soft"
return radial.distance_transform.__class__.__name__
scale = model.scale_shift.scale
shift = model.scale_shift.shift
heads = model.heads if hasattr(model, "heads") else ["default"]
    model_mlp_irreps = (
        o3.Irreps(str(model.readouts[-1].hidden_irreps))
        if model.num_interactions.item() > 1
        else o3.Irreps("1x0e")  # placeholder so .count() below is valid; the config falls back to 1 for single-interaction models
    )
    mlp_irreps = o3.Irreps(f"{model_mlp_irreps.count((0, 1)) // len(heads)}x0e")
try:
correlation = (
len(model.products[0].symmetric_contractions.contractions[0].weights) + 1
)
except AttributeError:
correlation = model.products[0].symmetric_contractions.contraction_degree
config = {
"r_max": model.r_max.item(),
"num_bessel": len(model.radial_embedding.bessel_fn.bessel_weights),
"num_polynomial_cutoff": model.radial_embedding.cutoff_fn.p.item(),
"max_ell": model.spherical_harmonics._lmax, # pylint: disable=protected-access
"interaction_cls": model.interactions[-1].__class__,
"interaction_cls_first": model.interactions[0].__class__,
"num_interactions": model.num_interactions.item(),
"num_elements": len(model.atomic_numbers),
"hidden_irreps": o3.Irreps(str(model.products[0].linear.irreps_out)),
"MLP_irreps": (mlp_irreps if model.num_interactions.item() > 1 else 1),
"gate": (
model.readouts[-1] # pylint: disable=protected-access
.non_linearity._modules["acts"][0]
.f
if model.num_interactions.item() > 1
else None
),
"atomic_energies": model.atomic_energies_fn.atomic_energies.cpu().numpy(),
"avg_num_neighbors": model.interactions[0].avg_num_neighbors,
"atomic_numbers": model.atomic_numbers,
"correlation": correlation,
"radial_type": radial_to_name(
model.radial_embedding.bessel_fn.__class__.__name__
),
"radial_MLP": model.interactions[0].conv_tp_weights.hs[1:-1],
"pair_repulsion": hasattr(model, "pair_repulsion_fn"),
"distance_transform": radial_to_transform(model.radial_embedding),
"atomic_inter_scale": scale.cpu().numpy(),
"atomic_inter_shift": shift.cpu().numpy(),
"heads": heads,
}
return config
def extract_load(f: str, map_location: str = "cpu") -> torch.nn.Module:
return extract_model(
torch.load(f=f, map_location=map_location), map_location=map_location
)
def remove_pt_head(
model: torch.nn.Module, head_to_keep: Optional[str] = None
) -> torch.nn.Module:
"""Converts a multihead MACE model to a single head model by removing the pretraining head.
Args:
model (ScaleShiftMACE): The multihead MACE model to convert
head_to_keep (Optional[str]): The name of the head to keep. If None, keeps the first non-PT head.
Returns:
ScaleShiftMACE: A new MACE model with only the specified head
Raises:
ValueError: If the model is not a multihead model or if the specified head is not found
"""
if not hasattr(model, "heads") or len(model.heads) <= 1:
raise ValueError("Model must be a multihead model with more than one head")
# Get index of head to keep
if head_to_keep is None:
# Find first non-PT head
try:
head_idx = next(i for i, h in enumerate(model.heads) if h != "pt_head")
except StopIteration as e:
raise ValueError("No non-PT head found in model") from e
else:
try:
head_idx = model.heads.index(head_to_keep)
except ValueError as e:
raise ValueError(f"Head {head_to_keep} not found in model") from e
# Extract config and modify for single head
model_config = extract_config_mace_model(model)
model_config["heads"] = [model.heads[head_idx]]
model_config["atomic_energies"] = (
model.atomic_energies_fn.atomic_energies[head_idx]
.unsqueeze(0)
.detach()
.cpu()
.numpy()
)
model_config["atomic_inter_scale"] = model.scale_shift.scale[head_idx].item()
model_config["atomic_inter_shift"] = model.scale_shift.shift[head_idx].item()
mlp_count_irreps = model_config["MLP_irreps"].count((0, 1))
# model_config["MLP_irreps"] = o3.Irreps(f"{mlp_count_irreps}x0e")
new_model = model.__class__(**model_config)
state_dict = model.state_dict()
new_state_dict = {}
for name, param in state_dict.items():
if "atomic_energies" in name:
new_state_dict[name] = param[head_idx : head_idx + 1]
elif "scale" in name or "shift" in name:
new_state_dict[name] = param[head_idx : head_idx + 1]
elif "readouts" in name:
channels_per_head = param.shape[0] // len(model.heads)
start_idx = head_idx * channels_per_head
end_idx = start_idx + channels_per_head
if "linear_2.weight" in name:
end_idx = start_idx + channels_per_head // 2
# if (
# "readouts.0.linear.weight" in name
# or "readouts.1.linear_2.weight" in name
# ):
# new_state_dict[name] = param[start_idx:end_idx] / (
# len(model.heads) ** 0.5
# )
if "readouts.0.linear.weight" in name:
new_state_dict[name] = param.reshape(-1, len(model.heads))[
:, head_idx
].flatten()
elif "readouts.1.linear_1.weight" in name:
new_state_dict[name] = param.reshape(
-1, len(model.heads), mlp_count_irreps
)[:, head_idx, :].flatten()
elif "readouts.1.linear_2.weight" in name:
new_state_dict[name] = param.reshape(
len(model.heads), -1, len(model.heads)
)[head_idx, :, head_idx].flatten() / (len(model.heads) ** 0.5)
else:
new_state_dict[name] = param[start_idx:end_idx]
else:
new_state_dict[name] = param
# Load state dict into new model
new_model.load_state_dict(new_state_dict)
return new_model
def extract_model(model: torch.nn.Module, map_location: str = "cpu") -> torch.nn.Module:
model_copy = model.__class__(**extract_config_mace_model(model))
model_copy.load_state_dict(model.state_dict())
return model_copy.to(map_location)
def convert_to_json_format(dict_input):
for key, value in dict_input.items():
if isinstance(value, (np.ndarray, torch.Tensor)):
dict_input[key] = value.tolist()
        # check if the value is a class and convert it to a string
elif hasattr(value, "__class__"):
dict_input[key] = str(value)
return dict_input
def convert_from_json_format(dict_input):
dict_output = dict_input.copy()
if (
dict_input["interaction_cls"]
== "<class 'mace.modules.blocks.RealAgnosticResidualInteractionBlock'>"
):
dict_output["interaction_cls"] = (
modules.blocks.RealAgnosticResidualInteractionBlock
)
if (
dict_input["interaction_cls"]
== "<class 'mace.modules.blocks.RealAgnosticInteractionBlock'>"
):
dict_output["interaction_cls"] = modules.blocks.RealAgnosticInteractionBlock
if (
dict_input["interaction_cls_first"]
== "<class 'mace.modules.blocks.RealAgnosticResidualInteractionBlock'>"
):
dict_output["interaction_cls_first"] = (
modules.blocks.RealAgnosticResidualInteractionBlock
)
if (
dict_input["interaction_cls_first"]
== "<class 'mace.modules.blocks.RealAgnosticInteractionBlock'>"
):
dict_output["interaction_cls_first"] = (
modules.blocks.RealAgnosticInteractionBlock
)
dict_output["r_max"] = float(dict_input["r_max"])
dict_output["num_bessel"] = int(dict_input["num_bessel"])
dict_output["num_polynomial_cutoff"] = float(dict_input["num_polynomial_cutoff"])
dict_output["max_ell"] = int(dict_input["max_ell"])
dict_output["num_interactions"] = int(dict_input["num_interactions"])
dict_output["num_elements"] = int(dict_input["num_elements"])
dict_output["hidden_irreps"] = o3.Irreps(dict_input["hidden_irreps"])
dict_output["MLP_irreps"] = o3.Irreps(dict_input["MLP_irreps"])
dict_output["avg_num_neighbors"] = float(dict_input["avg_num_neighbors"])
dict_output["gate"] = torch.nn.functional.silu
dict_output["atomic_energies"] = np.array(dict_input["atomic_energies"])
dict_output["atomic_numbers"] = dict_input["atomic_numbers"]
dict_output["correlation"] = int(dict_input["correlation"])
dict_output["radial_type"] = dict_input["radial_type"]
dict_output["radial_MLP"] = ast.literal_eval(dict_input["radial_MLP"])
dict_output["pair_repulsion"] = ast.literal_eval(dict_input["pair_repulsion"])
dict_output["distance_transform"] = dict_input["distance_transform"]
dict_output["atomic_inter_scale"] = float(dict_input["atomic_inter_scale"])
dict_output["atomic_inter_shift"] = float(dict_input["atomic_inter_shift"])
return dict_output
def load_from_json(f: str, map_location: str = "cpu") -> torch.nn.Module:
extra_files_extract = {"commit.txt": None, "config.json": None}
model_jit_load = torch.jit.load(
f, _extra_files=extra_files_extract, map_location=map_location
)
model_load_yaml = modules.ScaleShiftMACE(
**convert_from_json_format(json.loads(extra_files_extract["config.json"]))
)
model_load_yaml.load_state_dict(model_jit_load.state_dict())
return model_load_yaml.to(map_location)
def get_atomic_energies(E0s, train_collection, z_table) -> dict:
if E0s is not None:
logging.info(
"Isolated Atomic Energies (E0s) not in training file, using command line argument"
)
if E0s.lower() == "average":
logging.info(
"Computing average Atomic Energies using least squares regression"
)
            # catch the case where collections.train was not defined above
try:
assert train_collection is not None
atomic_energies_dict = data.compute_average_E0s(
train_collection, z_table
)
except Exception as e:
raise RuntimeError(
f"Could not compute average E0s if no training xyz given, error {e} occured"
) from e
else:
if E0s.endswith(".json"):
logging.info(f"Loading atomic energies from {E0s}")
with open(E0s, "r", encoding="utf-8") as f:
atomic_energies_dict = json.load(f)
atomic_energies_dict = {
int(key): value for key, value in atomic_energies_dict.items()
}
else:
try:
                    atomic_energies_dict = ast.literal_eval(E0s)
                    assert isinstance(atomic_energies_dict, dict)
except Exception as e:
raise RuntimeError(
f"E0s specified invalidly, error {e} occured"
) from e
else:
raise RuntimeError(
"E0s not found in training file and not specified in command line"
)
return atomic_energies_dict
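# Hedged examples of the three accepted E0s formats (file name, atomic
# numbers, and energies are illustrative):
#
#   get_atomic_energies("average", collections.train, z_table)    # least-squares fit
#   get_atomic_energies("e0s.json", None, z_table)                # JSON file mapping Z -> E0
#   get_atomic_energies("{1: -13.6, 8: -2041.4}", None, z_table)  # dict literal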
def get_avg_num_neighbors(head_configs, args, train_loader, device):
if all(head_config.compute_avg_num_neighbors for head_config in head_configs):
logging.info("Computing average number of neighbors")
avg_num_neighbors = modules.compute_avg_num_neighbors(train_loader)
if args.distributed:
num_graphs = torch.tensor(len(train_loader.dataset)).to(device)
num_neighbors = num_graphs * torch.tensor(avg_num_neighbors).to(device)
torch.distributed.all_reduce(num_graphs, op=torch.distributed.ReduceOp.SUM)
torch.distributed.all_reduce(
num_neighbors, op=torch.distributed.ReduceOp.SUM
)
avg_num_neighbors_out = (num_neighbors / num_graphs).item()
else:
avg_num_neighbors_out = avg_num_neighbors
else:
assert any(
head_config.avg_num_neighbors is not None for head_config in head_configs
), "Average number of neighbors must be provided in the configuration"
avg_num_neighbors_out = max(
head_config.avg_num_neighbors
for head_config in head_configs
if head_config.avg_num_neighbors is not None
)
if avg_num_neighbors_out < 2 or avg_num_neighbors_out > 100:
logging.warning(
f"Unusual average number of neighbors: {avg_num_neighbors_out:.1f}"
)
else:
logging.info(f"Average number of neighbors: {avg_num_neighbors_out}")
return avg_num_neighbors_out
def get_loss_fn(
args: argparse.Namespace,
dipole_only: bool,
compute_dipole: bool,
) -> torch.nn.Module:
if args.loss == "weighted":
loss_fn = modules.WeightedEnergyForcesLoss(
energy_weight=args.energy_weight, forces_weight=args.forces_weight
)
elif args.loss == "forces_only":
loss_fn = modules.WeightedForcesLoss(forces_weight=args.forces_weight)
elif args.loss == "virials":
loss_fn = modules.WeightedEnergyForcesVirialsLoss(
energy_weight=args.energy_weight,
forces_weight=args.forces_weight,
virials_weight=args.virials_weight,
)
elif args.loss == "stress":
loss_fn = modules.WeightedEnergyForcesStressLoss(
energy_weight=args.energy_weight,
forces_weight=args.forces_weight,
stress_weight=args.stress_weight,
)
elif args.loss == "huber":
loss_fn = modules.WeightedHuberEnergyForcesStressLoss(
energy_weight=args.energy_weight,
forces_weight=args.forces_weight,
stress_weight=args.stress_weight,
huber_delta=args.huber_delta,
)
elif args.loss == "universal":
loss_fn = modules.UniversalLoss(
energy_weight=args.energy_weight,
forces_weight=args.forces_weight,
stress_weight=args.stress_weight,
huber_delta=args.huber_delta,
)
elif args.loss == "l1l2energyforces":
loss_fn = modules.WeightedEnergyForcesL1L2Loss(
energy_weight=args.energy_weight,
forces_weight=args.forces_weight,
)
elif args.loss == "dipole":
assert (
dipole_only is True
), "dipole loss can only be used with AtomicDipolesMACE model"
loss_fn = modules.DipoleSingleLoss(
dipole_weight=args.dipole_weight,
)
elif args.loss == "energy_forces_dipole":
assert dipole_only is False and compute_dipole is True
loss_fn = modules.WeightedEnergyForcesDipoleLoss(
energy_weight=args.energy_weight,
forces_weight=args.forces_weight,
dipole_weight=args.dipole_weight,
)
else:
loss_fn = modules.WeightedEnergyForcesLoss(energy_weight=1.0, forces_weight=1.0)
return loss_fn
def get_swa(
args: argparse.Namespace,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
swas: List[bool],
dipole_only: bool = False,
):
assert dipole_only is False, "Stage Two for dipole fitting not implemented"
swas.append(True)
if args.start_swa is None:
args.start_swa = max(1, args.max_num_epochs // 4 * 3)
else:
if args.start_swa >= args.max_num_epochs:
            logging.warning(
                f"Start of Stage Two must be less than max_num_epochs, got {args.start_swa} >= {args.max_num_epochs}"
            )
swas[-1] = False
if args.loss == "forces_only":
raise ValueError("Can not select Stage Two with forces only loss.")
if args.loss == "virials":
loss_fn_energy = modules.WeightedEnergyForcesVirialsLoss(
energy_weight=args.swa_energy_weight,
forces_weight=args.swa_forces_weight,
virials_weight=args.swa_virials_weight,
)
logging.info(
f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, virials_weight: {args.swa_virials_weight} and learning rate : {args.swa_lr}"
)
elif args.loss == "stress":
loss_fn_energy = modules.WeightedEnergyForcesStressLoss(
energy_weight=args.swa_energy_weight,
forces_weight=args.swa_forces_weight,
stress_weight=args.swa_stress_weight,
)
logging.info(
f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, stress weight : {args.swa_stress_weight} and learning rate : {args.swa_lr}"
)
elif args.loss == "energy_forces_dipole":
loss_fn_energy = modules.WeightedEnergyForcesDipoleLoss(
args.swa_energy_weight,
forces_weight=args.swa_forces_weight,
dipole_weight=args.swa_dipole_weight,
)
logging.info(
f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, dipole weight : {args.swa_dipole_weight} and learning rate : {args.swa_lr}"
)
elif args.loss == "universal":
loss_fn_energy = modules.UniversalLoss(
energy_weight=args.swa_energy_weight,
forces_weight=args.swa_forces_weight,
stress_weight=args.swa_stress_weight,
huber_delta=args.huber_delta,
)
logging.info(
f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, stress weight : {args.swa_stress_weight} and learning rate : {args.swa_lr}"
)
else:
loss_fn_energy = modules.WeightedEnergyForcesLoss(
energy_weight=args.swa_energy_weight,
forces_weight=args.swa_forces_weight,
)
logging.info(
f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight} and learning rate : {args.swa_lr}"
)
swa = SWAContainer(
model=AveragedModel(model),
scheduler=SWALR(
optimizer=optimizer,
swa_lr=args.swa_lr,
anneal_epochs=1,
anneal_strategy="linear",
),
start=args.start_swa,
loss_fn=loss_fn_energy,
)
return swa, swas
def get_params_options(
args: argparse.Namespace, model: torch.nn.Module
) -> Dict[str, Any]:
decay_interactions = {}
no_decay_interactions = {}
for name, param in model.interactions.named_parameters():
if "linear.weight" in name or "skip_tp_full.weight" in name:
decay_interactions[name] = param
else:
no_decay_interactions[name] = param
param_options = dict(
params=[
{
"name": "embedding",
"params": model.node_embedding.parameters(),
"weight_decay": 0.0,
},
{
"name": "interactions_decay",
"params": list(decay_interactions.values()),
"weight_decay": args.weight_decay,
},
{
"name": "interactions_no_decay",
"params": list(no_decay_interactions.values()),
"weight_decay": 0.0,
},
{
"name": "products",
"params": model.products.parameters(),
"weight_decay": args.weight_decay,
},
{
"name": "readouts",
"params": model.readouts.parameters(),
"weight_decay": 0.0,
},
],
lr=args.lr,
amsgrad=args.amsgrad,
betas=(args.beta, 0.999),
)
return param_options
def get_optimizer(
args: argparse.Namespace, param_options: Dict[str, Any]
) -> torch.optim.Optimizer:
if args.optimizer == "adamw":
optimizer = torch.optim.AdamW(**param_options)
elif args.optimizer == "schedulefree":
try:
from schedulefree import adamw_schedulefree
except ImportError as exc:
raise ImportError(
"`schedulefree` is not installed. Please install it via `pip install schedulefree` or `pip install mace-torch[schedulefree]`"
) from exc
_param_options = {k: v for k, v in param_options.items() if k != "amsgrad"}
optimizer = adamw_schedulefree.AdamWScheduleFree(**_param_options)
else:
optimizer = torch.optim.Adam(**param_options)
return optimizer
def setup_wandb(args: argparse.Namespace):
logging.info("Using Weights and Biases for logging")
import wandb
wandb_config = {}
args_dict = vars(args)
for key, value in args_dict.items():
if isinstance(value, np.ndarray):
args_dict[key] = value.tolist()
class CustomEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, KeySpecification):
return o.__dict__
return super().default(o)
args_dict_json = json.dumps(args_dict, cls=CustomEncoder)
for key in args.wandb_log_hypers:
wandb_config[key] = args_dict[key]
tools.init_wandb(
project=args.wandb_project,
entity=args.wandb_entity,
name=args.wandb_name,
config=wandb_config,
directory=args.wandb_dir,
)
wandb.run.summary["params"] = args_dict_json
def get_files_with_suffix(dir_path: str, suffix: str) -> List[str]:
return [
os.path.join(dir_path, f) for f in os.listdir(dir_path) if f.endswith(suffix)
]
def dict_to_array(input_data, heads):
if all(isinstance(value, np.ndarray) for value in input_data.values()):
return np.array([input_data[head] for head in heads])
if not all(isinstance(value, dict) for value in input_data.values()):
return np.array([[input_data[head]] for head in heads])
unique_keys = set()
for inner_dict in input_data.values():
unique_keys.update(inner_dict.keys())
unique_keys = list(unique_keys)
sorted_keys = sorted([int(key) for key in unique_keys])
result_array = np.zeros((len(input_data), len(sorted_keys)))
    for head_name, inner_dict in input_data.items():
for key, value in inner_dict.items():
key_index = sorted_keys.index(int(key))
head_index = heads.index(head_name)
result_array[head_index][key_index] = value
return result_array
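# A hedged example (values are illustrative): nested {head: {Z: E0}} dicts are
# flattened into a (num_heads, num_elements) array with columns ordered by
# atomic number; missing entries stay zero.
#
#   dict_to_array(
#       {"head_a": {"1": -13.6, "8": -2041.4}, "head_b": {"1": -13.5}},
#       heads=["head_a", "head_b"],
#   )
#   # -> array([[  -13.6, -2041.4],
#   #           [  -13.5,     0. ]])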
class LRScheduler:
def __init__(self, optimizer, args) -> None:
self.scheduler = args.scheduler
        self._optimizer_type = (
            args.optimizer
        )  # Schedulefree does not need a scheduler, but the checkpoint handler expects one.
if args.scheduler == "ExponentialLR":
self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
optimizer=optimizer, gamma=args.lr_scheduler_gamma
)
elif args.scheduler == "ReduceLROnPlateau":
self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer=optimizer,
factor=args.lr_factor,
patience=args.scheduler_patience,
)
else:
raise RuntimeError(f"Unknown scheduler: '{args.scheduler}'")
def step(self, metrics=None, epoch=None): # pylint: disable=E1123
if self._optimizer_type == "schedulefree":
return # In principle, schedulefree optimizer can be used with a scheduler but the paper suggests it's not necessary
if self.scheduler == "ExponentialLR":
self.lr_scheduler.step(epoch=epoch)
elif self.scheduler == "ReduceLROnPlateau":
self.lr_scheduler.step( # pylint: disable=E1123
metrics=metrics, epoch=epoch
)
def __getattr__(self, name):
if name == "step":
return self.step
return getattr(self.lr_scheduler, name)
def check_folder_subfolder(folder_path):
entries = os.listdir(folder_path)
for entry in entries:
full_path = os.path.join(folder_path, entry)
if os.path.isdir(full_path):
return True
return False
def check_path_ase_read(filename: Optional[str]) -> bool:
if filename is None:
return False
filepath = Path(filename)
if filepath.is_dir():
num_h5_files = len(list(filepath.glob("*.h5")))
num_hdf5_files = len(list(filepath.glob("*.hdf5")))
num_ldb_files = len(list(filepath.glob("*.lmdb")))
num_aselmbd_files = len(list(filepath.glob("*.aselmdb")))
num_mdb_files = len(list(filepath.glob("*.mdb")))
if (
num_h5_files
+ num_hdf5_files
+ num_ldb_files
+ num_aselmbd_files
+ num_mdb_files
== 0
):
            # print all the files in the directory for debugging
for file in os.listdir(filepath):
print(file)
raise RuntimeError(f"No supported files found in directory '{filename}'")
return False
if filepath.suffix in (".h5", ".hdf5", ".lmdb", ".aselmdb", ".mdb"):
return False
return True
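# Hedged examples (paths are illustrative): plain xyz files take the ASE
# reading path, while h5/lmdb-style inputs do not.
#
#   check_path_ase_read("data/train.xyz")  # -> True
#   check_path_ase_read("data/train.h5")   # -> False
#   check_path_ase_read("data/shards/")    # -> False if the directory holds *.h5 / *.lmdb files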
def dict_to_namespace(dictionary):
# Convert the dictionary into an argparse.Namespace
namespace = argparse.Namespace()
for key, value in dictionary.items():
setattr(namespace, key, value)
return namespace
###########################################################################################
# Slurm environment setup for distributed training.
# This code is refactored from rsarm's contribution at:
# https://github.com/Lumi-supercomputer/lumi-reframe-tests/blob/main/checks/apps/deeplearning/pytorch/src/pt_distr_env.py
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
import os
import hostlist
class DistributedEnvironment:
def __init__(self):
self._setup_distr_env()
self.master_addr = os.environ["MASTER_ADDR"]
self.master_port = os.environ["MASTER_PORT"]
self.world_size = int(os.environ["WORLD_SIZE"])
self.local_rank = int(os.environ["LOCAL_RANK"])
self.rank = int(os.environ["RANK"])
def _setup_distr_env(self):
hostname = hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"])[0]
os.environ["MASTER_ADDR"] = hostname
os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "33333")
os.environ["WORLD_SIZE"] = os.environ.get(
"SLURM_NTASKS",
str(
int(os.environ["SLURM_NTASKS_PER_NODE"])
* int(os.environ["SLURM_NNODES"])
),
)
os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"]
os.environ["RANK"] = os.environ["SLURM_PROCID"]
def __repr__(self):
return (
f"DistributedEnvironment(master_addr={self.master_addr}, master_port={self.master_port}, "
f"world_size={self.world_size}, local_rank={self.local_rank}, rank={self.rank})"
)
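# A hedged usage sketch: inside a Slurm job, the variables set by
# _setup_distr_env are what torch.distributed's default env:// init method
# reads; the backend choice below is an assumption.
#
#   distr_env = DistributedEnvironment()
#   torch.distributed.init_process_group(
#       backend="nccl",
#       world_size=distr_env.world_size,
#       rank=distr_env.rank,
#   )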
import logging
from typing import Dict, List, Optional
import torch
from prettytable import PrettyTable
from mace.tools import evaluate
def custom_key(key):
"""
Helper function to sort the keys of the data loader dictionary
to ensure that the training set, and validation set
are evaluated first
"""
if key == "train":
return (0, key)
if key == "valid":
return (1, key)
return (2, key)
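# A hedged example of the resulting evaluation order:
#
#   sorted(["test_bulk", "valid", "train"], key=custom_key)
#   # -> ["train", "valid", "test_bulk"]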
def create_error_table(
table_type: str,
all_data_loaders: dict,
model: torch.nn.Module,
loss_fn: torch.nn.Module,
output_args: Dict[str, bool],
log_wandb: bool,
device: str,
distributed: bool = False,
skip_heads: Optional[List[str]] = None,
) -> PrettyTable:
if log_wandb:
import wandb
skip_heads = skip_heads or []
table = PrettyTable()
if table_type == "TotalRMSE":
table.field_names = [
"config_type",
"RMSE E / meV",
"RMSE F / meV / A",
"relative F RMSE %",
]
elif table_type == "PerAtomRMSE":
table.field_names = [
"config_type",
"RMSE E / meV / atom",
"RMSE F / meV / A",
"relative F RMSE %",
]
elif table_type == "PerAtomRMSEstressvirials":
table.field_names = [
"config_type",
"RMSE E / meV / atom",
"RMSE F / meV / A",
"relative F RMSE %",
"RMSE Stress (Virials) / meV / A (A^3)",
]
elif table_type == "PerAtomMAEstressvirials":
table.field_names = [
"config_type",
"MAE E / meV / atom",
"MAE F / meV / A",
"relative F MAE %",
"MAE Stress (Virials) / meV / A (A^3)",
]
elif table_type == "TotalMAE":
table.field_names = [
"config_type",
"MAE E / meV",
"MAE F / meV / A",
"relative F MAE %",
]
elif table_type == "PerAtomMAE":
table.field_names = [
"config_type",
"MAE E / meV / atom",
"MAE F / meV / A",
"relative F MAE %",
]
elif table_type == "DipoleRMSE":
table.field_names = [
"config_type",
"RMSE MU / mDebye / atom",
"relative MU RMSE %",
]
elif table_type == "DipoleMAE":
table.field_names = [
"config_type",
"MAE MU / mDebye / atom",
"relative MU MAE %",
]
elif table_type == "EnergyDipoleRMSE":
table.field_names = [
"config_type",
"RMSE E / meV / atom",
"RMSE F / meV / A",
"rel F RMSE %",
"RMSE MU / mDebye / atom",
"rel MU RMSE %",
]
for name in sorted(all_data_loaders, key=custom_key):
if any(skip_head in name for skip_head in skip_heads):
logging.info(f"Skipping evaluation of {name} (in skip_heads list)")
continue
data_loader = all_data_loaders[name]
logging.info(f"Evaluating {name} ...")
_, metrics = evaluate(
model,
loss_fn=loss_fn,
data_loader=data_loader,
output_args=output_args,
device=device,
)
if distributed:
torch.distributed.barrier()
del data_loader
torch.cuda.empty_cache()
if log_wandb:
wandb_log_dict = {
name
+ "_final_rmse_e_per_atom": metrics["rmse_e_per_atom"]
* 1e3, # meV / atom
name + "_final_rmse_f": metrics["rmse_f"] * 1e3, # meV / A
name + "_final_rel_rmse_f": metrics["rel_rmse_f"],
}
wandb.log(wandb_log_dict)
if table_type == "TotalRMSE":
table.add_row(
[
name,
f"{metrics['rmse_e'] * 1000:8.1f}",
f"{metrics['rmse_f'] * 1000:8.1f}",
f"{metrics['rel_rmse_f']:8.2f}",
]
)
elif table_type == "PerAtomRMSE":
table.add_row(
[
name,
f"{metrics['rmse_e_per_atom'] * 1000:8.1f}",
f"{metrics['rmse_f'] * 1000:8.1f}",
f"{metrics['rel_rmse_f']:8.2f}",
]
)
elif (
table_type == "PerAtomRMSEstressvirials"
and metrics["rmse_stress"] is not None
):
table.add_row(
[
name,
f"{metrics['rmse_e_per_atom'] * 1000:8.1f}",
f"{metrics['rmse_f'] * 1000:8.1f}",
f"{metrics['rel_rmse_f']:8.2f}",
f"{metrics['rmse_stress'] * 1000:8.1f}",
]
)
elif (
table_type == "PerAtomRMSEstressvirials"
and metrics["rmse_virials"] is not None
):
table.add_row(
[
name,
f"{metrics['rmse_e_per_atom'] * 1000:8.1f}",
f"{metrics['rmse_f'] * 1000:8.1f}",
f"{metrics['rel_rmse_f']:8.2f}",
f"{metrics['rmse_virials'] * 1000:8.1f}",
]
)
elif (
table_type == "PerAtomMAEstressvirials"
and metrics["mae_stress"] is not None
):
table.add_row(
[
name,
f"{metrics['mae_e_per_atom'] * 1000:8.1f}",
f"{metrics['mae_f'] * 1000:8.1f}",
f"{metrics['rel_mae_f']:8.2f}",
f"{metrics['mae_stress'] * 1000:8.1f}",
]
)
elif (
table_type == "PerAtomMAEstressvirials"
and metrics["mae_virials"] is not None
):
table.add_row(
[
name,
f"{metrics['mae_e_per_atom'] * 1000:8.1f}",
f"{metrics['mae_f'] * 1000:8.1f}",
f"{metrics['rel_mae_f']:8.2f}",
f"{metrics['mae_virials'] * 1000:8.1f}",
]
)
elif table_type == "TotalMAE":
table.add_row(
[
name,
f"{metrics['mae_e'] * 1000:8.1f}",
f"{metrics['mae_f'] * 1000:8.1f}",
f"{metrics['rel_mae_f']:8.2f}",
]
)
elif table_type == "PerAtomMAE":
table.add_row(
[
name,
f"{metrics['mae_e_per_atom'] * 1000:8.1f}",
f"{metrics['mae_f'] * 1000:8.1f}",
f"{metrics['rel_mae_f']:8.2f}",
]
)
elif table_type == "DipoleRMSE":
table.add_row(
[
name,
f"{metrics['rmse_mu_per_atom'] * 1000:8.2f}",
f"{metrics['rel_rmse_mu']:8.1f}",
]
)
elif table_type == "DipoleMAE":
table.add_row(
[
name,
f"{metrics['mae_mu_per_atom'] * 1000:8.2f}",
f"{metrics['rel_mae_mu']:8.1f}",
]
)
elif table_type == "EnergyDipoleRMSE":
table.add_row(
[
name,
f"{metrics['rmse_e_per_atom'] * 1000:8.1f}",
f"{metrics['rmse_f'] * 1000:8.1f}",
f"{metrics['rel_rmse_f']:8.1f}",
f"{metrics['rmse_mu_per_atom'] * 1000:8.1f}",
f"{metrics['rel_rmse_mu']:8.1f}",
]
)
return table
# Trimmed-down `pytorch_geometric`
MACE uses the [`pytorch_geometric`](https://pytorch-geometric.readthedocs.io/en/latest/) [1, 2] framework. However, we only use a very limited subset of that library: the most basic graph data structures.
We follow the same approach as NequIP (https://github.com/mir-group/nequip/tree/main/nequip) and copy their code here.
To avoid adding a large number of unnecessary second-degree dependencies, and to simplify installation, we include and modify here the small subset of `torch_geometric` that is necessary for our code.
We are grateful to the developers of PyTorch Geometric for their ongoing and very useful work on graph learning with PyTorch.
[1] Fey, M., & Lenssen, J. E. (2019). Fast Graph Representation Learning with PyTorch Geometric (Version 2.0.1) [Computer software]. https://github.com/pyg-team/pytorch_geometric <br>
[2] Fey, M., & Lenssen, J. E. (2019). Fast Graph Representation Learning with PyTorch Geometric. arXiv:1903.02428. https://arxiv.org/abs/1903.02428
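A minimal usage sketch, assuming the trimmed classes keep the standard `pytorch_geometric` keyword interface and are importable as shown (the import path is an assumption about this repository's layout):

```python
import torch
from mace.tools.torch_geometric import Data, DataLoader

# Two toy graphs in the standard PyG layout: node features `x` and an
# edge_index of shape [2, num_edges].
graphs = [
    Data(x=torch.randn(3, 4), edge_index=torch.tensor([[0, 1], [1, 2]])),
    Data(x=torch.randn(2, 4), edge_index=torch.tensor([[0], [1]])),
]

# The DataLoader collates the graphs into a single disjoint-union Batch.
loader = DataLoader(graphs, batch_size=2)
for batch in loader:
    print(batch.num_graphs)  # -> 2
```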
from .batch import Batch
from .data import Data
from .dataloader import DataLoader
from .dataset import Dataset
from .seed import seed_everything
__all__ = ["Batch", "Data", "Dataset", "DataLoader", "seed_everything"]