Unverified Commit 251f5af2 authored by zcxzcx1's avatar zcxzcx1 Committed by GitHub
Browse files

Add files via upload

parent 73ff4f3a
"""Demonstrates active learning molecular dynamics with constant temperature."""
import argparse
import os
import time
import ase.io
import numpy as np
from ase import units
from ase.md.langevin import Langevin
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
from mace.calculators.mace import MACECalculator
def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for the active-learning MD driver.

    Returns:
        argparse.Namespace with paths, MD parameters and committee-error
        settings.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--config", help="path to XYZ configurations", required=True)
    parser.add_argument(
        "--config_index", help="index of configuration", type=int, default=-1
    )
    parser.add_argument(
        "--error_threshold", help="error threshold", type=float, default=0.1
    )
    parser.add_argument("--temperature_K", help="temperature", type=float, default=300)
    parser.add_argument("--friction", help="friction", type=float, default=0.01)
    parser.add_argument("--timestep", help="timestep", type=float, default=1)
    parser.add_argument("--nsteps", help="number of steps", type=int, default=1000)
    parser.add_argument(
        "--nprint", help="number of steps between prints", type=int, default=10
    )
    parser.add_argument(
        "--nsave", help="number of steps between saves", type=int, default=10
    )
    parser.add_argument(
        # BUG FIX: help text was a copy-paste of --nsave ("between saves").
        "--ncheckerror",
        help="number of steps between error checks",
        type=int,
        default=10,
    )
    parser.add_argument(
        "--model",
        help="path to model. Use wildcards to add multiple models as committee eg "
        "(`mace_*.model` to load mace_1.model, mace_2.model) ",
        required=True,
    )
    parser.add_argument("--output", help="output path", required=True)
    parser.add_argument(
        "--device",
        help="select device",
        type=str,
        choices=["cpu", "cuda"],
        default="cuda",
    )
    parser.add_argument(
        "--default_dtype",
        help="set default dtype",
        type=str,
        choices=["float32", "float64"],
        default="float64",
    )
    parser.add_argument(
        "--compute_stress",
        help="compute stress",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--info_prefix",
        help="prefix for energy, forces and stress keys",
        type=str,
        default="MACE_",
    )
    return parser.parse_args()
def printenergy(dyn, start_time=None):  # store a reference to atoms in the definition.
    """Function to print the potential, kinetic and total energy."""
    # dyn: ASE dynamics object; dyn.atoms carries an attached calculator whose
    # results dict holds "forces_comm" (presumably per-committee-member forces
    # — verify against MACECalculator) and "energy_var" (energy variance).
    a = dyn.atoms
    epot = a.get_potential_energy() / len(a)
    ekin = a.get_kinetic_energy() / len(a)
    if start_time is None:
        elapsed_time = 0
    else:
        elapsed_time = time.time() - start_time
    # Variance of forces over the committee axis (axis 0), per atom/component.
    forces_var = np.var(a.calc.results["forces_comm"], axis=0)
    print(
        "%.1fs: Energy per atom: Epot = %.3feV Ekin = %.3feV (T=%3.0fK) "  # pylint: disable=C0209
        "Etot = %.3feV t=%.1ffs Eerr = %.3feV Ferr = %.3feV/A"
        % (
            elapsed_time,
            epot,
            ekin,
            ekin / (1.5 * units.kB),  # instantaneous temperature via equipartition
            epot + ekin,
            dyn.get_time() / units.fs,
            a.calc.results["energy_var"],
            np.max(np.linalg.norm(forces_var, axis=1)),  # worst per-atom force error
        ),
        flush=True,
    )
def save_config(dyn, fname):
    """Append the current MD frame, annotated with MLFF energy/forces and their
    committee variances, to the trajectory file ``fname``."""
    atomsi = dyn.atoms
    ens = atomsi.get_potential_energy()
    frcs = atomsi.get_forces()
    # Per-frame scalars go into .info; time is stored in fs, rounded to 5 d.p.
    atomsi.info.update(
        {
            "mlff_energy": ens,
            "time": np.round(dyn.get_time() / units.fs, 5),
            "mlff_energy_var": atomsi.calc.results["energy_var"],
        }
    )
    # Per-atom arrays go into .arrays; variance over the committee axis (0).
    atomsi.arrays.update(
        {
            "mlff_forces": frcs,
            "mlff_forces_var": np.var(atomsi.calc.results["forces_comm"], axis=0),
        }
    )
    ase.io.write(fname, atomsi, append=True)
def stop_error(dyn, threshold, reg=0.2):
    """Halt the dynamics when the relative committee force error gets too large.

    ``reg`` regularises the denominator so near-zero forces do not inflate the
    relative error.
    """
    frame = dyn.atoms
    comm_var = np.var(frame.calc.results["forces_comm"], axis=0)
    abs_err = np.sqrt(comm_var.sum(axis=1))
    rel_err = abs_err / (np.linalg.norm(frame.get_forces(), axis=1) + reg)
    worst = np.max(rel_err)
    if worst > threshold:
        print(
            "Error too large {:.3}. Stopping t={:.2} fs.".format(  # pylint: disable=C0209
                worst, dyn.get_time() / units.fs
            ),
            flush=True,
        )
        # Zeroing max_steps makes the ASE dynamics loop terminate.
        dyn.max_steps = 0
def main() -> None:
    """Entry point: parse CLI arguments and launch the MD run."""
    run(parse_args())
def run(args: argparse.Namespace) -> None:
    """Run Langevin MD with a MACE committee calculator.

    Restarts from the last frame of ``args.output`` when it already exists,
    subtracting the steps already covered by the saved trajectory.
    """
    mace_fname = args.model
    atoms_fname = args.config
    atoms_index = args.config_index
    mace_calc = MACECalculator(
        model_paths=mace_fname,
        device=args.device,
        default_dtype=args.default_dtype,
    )
    nsteps = args.nsteps
    if os.path.exists(args.output):
        # Restart: continue from the last saved frame.
        print("Trajectory exists. Continuing from last step.")
        atoms = ase.io.read(args.output, index=-1)
        len_save = len(ase.io.read(args.output, ":"))
        print("Last step: ", atoms.info["time"], "Number of configs: ", len_save)
        nsteps -= len_save * args.nsave
    else:
        atoms = ase.io.read(atoms_fname, index=atoms_index)
        MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature_K)
    atoms.calc = mace_calc
    # Constant-temperature Langevin dynamics.
    dyn = Langevin(
        atoms=atoms,
        timestep=args.timestep * units.fs,
        temperature_K=args.temperature_K,
        friction=args.friction,
    )
    # BUG FIX: printing was attached with interval=args.nsave, leaving the
    # --nprint option unused; print every nprint steps as documented.
    dyn.attach(printenergy, interval=args.nprint, dyn=dyn, start_time=time.time())
    dyn.attach(save_config, interval=args.nsave, dyn=dyn, fname=args.output)
    dyn.attach(
        stop_error, interval=args.ncheckerror, dyn=dyn, threshold=args.error_threshold
    )
    # Now run the dynamics
    dyn.run(nsteps)
if __name__ == "__main__":
main()
import argparse
import logging
import os
from typing import Dict, List, Tuple
import torch
from mace.tools.scripts_utils import extract_config_mace_model
def get_transfer_keys(num_layers: int) -> List[str]:
    """Build the ordered list of state-dict keys copied verbatim between formats."""
    keys = [
        "node_embedding.linear.weight",
        "radial_embedding.bessel_fn.bessel_weights",
        "atomic_energies_fn.atomic_energies",
        "readouts.0.linear.weight",
    ]
    keys.extend(f"readouts.{j}.linear.weight" for j in range(num_layers - 1))
    keys.extend(["scale_shift.scale", "scale_shift.shift"])
    keys.extend(f"readouts.{num_layers-1}.linear_{i}.weight" for i in range(1, 3))
    # Per-layer interaction/product weights, in the same order per layer.
    for j in range(num_layers):
        keys.append(f"interactions.{j}.linear_up.weight")
        keys.extend(
            f"interactions.{j}.conv_tp_weights.layer{i}.weight" for i in range(4)
        )
        keys.append(f"interactions.{j}.linear.weight")
        keys.append(f"interactions.{j}.skip_tp.weight")
        keys.append(f"products.{j}.linear.weight")
    return keys
def get_kmax_pairs(
    max_L: int, correlation: int, num_layers: int
) -> List[Tuple[int, int]]:
    """Determine (layer index, kmax) pairs based on max_L and correlation.

    All layers but the last use ``max_L``; the final layer is invariant
    (kmax = 0).

    Raises:
        NotImplementedError: for any correlation other than 3.
    """
    if correlation == 2:
        raise NotImplementedError("Correlation 2 not supported yet")
    if correlation == 3:
        # BUG FIX: return tuples, matching the List[Tuple[int, int]] annotation
        # (callers only unpack the pairs, so this is behavior-compatible).
        return [(i, max_L) for i in range(num_layers - 1)] + [(num_layers - 1, 0)]
    raise NotImplementedError(f"Correlation {correlation} not supported")
def transfer_symmetric_contractions(
    source_dict: Dict[str, torch.Tensor],
    target_dict: Dict[str, torch.Tensor],
    max_L: int,
    correlation: int,
    num_layers: int,
):
    """Transfer symmetric contraction weights from CuEq to E3nn format"""
    # CuEq stores one fused weight tensor per layer; it must be split back into
    # the per-contraction E3nn tensors (weights_max, weights.0, weights.1 for
    # each k up to kmax), using the target's shapes to size each split.
    kmax_pairs = get_kmax_pairs(max_L, correlation, num_layers)
    for i, kmax in kmax_pairs:
        # Get the combined weight tensor from source
        wm = source_dict[f"products.{i}.symmetric_contractions.weight"]
        # Get split sizes based on target dimensions
        splits = []
        for k in range(kmax + 1):
            for suffix in ["_max", ".0", ".1"]:
                key = f"products.{i}.symmetric_contractions.contractions.{k}.weights{suffix}"
                target_shape = target_dict[key].shape
                splits.append(target_shape[1])  # split along dim 1
        # Split the weights using the calculated sizes
        weights_split = torch.split(wm, splits, dim=1)
        # Assign back to target dictionary, consuming three chunks per k in the
        # same order they were sized above.
        idx = 0
        for k in range(kmax + 1):
            target_dict[
                f"products.{i}.symmetric_contractions.contractions.{k}.weights_max"
            ] = weights_split[idx]
            target_dict[
                f"products.{i}.symmetric_contractions.contractions.{k}.weights.0"
            ] = weights_split[idx + 1]
            target_dict[
                f"products.{i}.symmetric_contractions.contractions.{k}.weights.1"
            ] = weights_split[idx + 2]
            idx += 3
def transfer_weights(
    source_model: torch.nn.Module,
    target_model: torch.nn.Module,
    max_L: int,
    correlation: int,
    num_layers: int,
):
    """Transfer weights from CuEq to E3nn format.

    Copies directly-matching keys, splits the fused symmetric-contraction
    weights, squeezes the CuEq-specific leading dimension off linear/skip_tp
    weights, then loads the assembled state dict into ``target_model``.
    """
    # Get state dicts
    source_dict = source_model.state_dict()
    target_dict = target_model.state_dict()
    # Transfer main weights
    transfer_keys = get_transfer_keys(num_layers)
    for key in transfer_keys:
        if key in source_dict:  # Check if key exists
            target_dict[key] = source_dict[key]
        else:
            logging.warning(f"Key {key} not found in source model")
    # Transfer symmetric contractions
    transfer_symmetric_contractions(
        source_dict, target_dict, max_L, correlation, num_layers
    )
    # Squeeze the CuEq leading dim off linear and skip_tp weights.
    # (The old comment said "Unsqueeze"; squeeze(0) is what this direction needs.)
    for key in source_dict.keys():
        if any(x in key for x in ["linear", "skip_tp"]) and "weight" in key:
            target_dict[key] = target_dict[key].squeeze(0)
    # Transfer remaining matching keys (excluding symmetric contractions,
    # which were handled above)
    transferred_keys = set(transfer_keys)
    remaining_keys = (
        set(source_dict.keys()) & set(target_dict.keys()) - transferred_keys
    )
    remaining_keys = {k for k in remaining_keys if "symmetric_contraction" not in k}
    for key in remaining_keys:
        if source_dict[key].shape == target_dict[key].shape:
            logging.debug(f"Transferring additional key: {key}")
            target_dict[key] = source_dict[key]
        else:
            logging.warning(
                f"Shape mismatch for key {key}: "
                f"source {source_dict[key].shape} vs target {target_dict[key].shape}"
            )
    # Transfer avg_num_neighbors
    # BUG FIX: was hard-coded to range(2); models with a different number of
    # interaction layers would miss layers or index out of range.
    for i in range(num_layers):
        target_model.interactions[i].avg_num_neighbors = source_model.interactions[
            i
        ].avg_num_neighbors
    # Load state dict into target model
    target_model.load_state_dict(target_dict)
def run(input_model, output_model="_e3nn.model", device="cpu", return_model=True):
    """Convert a CuEq-accelerated MACE model back to the plain E3nn format.

    Args:
        input_model: path to a serialized model, or an in-memory model.
        output_model: name suffix used when saving (only when return_model=False).
        device: map_location for torch.load.
        return_model: if True, return the converted model instead of saving it.
    """
    # Load CuEq model
    if isinstance(input_model, str):
        # NOTE(review): torch.load unpickles arbitrary code — only use with
        # trusted model files.
        source_model = torch.load(input_model, map_location=device)
    else:
        source_model = input_model
    default_dtype = next(source_model.parameters()).dtype
    torch.set_default_dtype(default_dtype)
    # Extract configuration
    config = extract_config_mace_model(source_model)
    # Get max_L and correlation from config
    max_L = config["hidden_irreps"].lmax
    correlation = config["correlation"]
    # Remove CuEq config
    config.pop("cueq_config", None)
    # Create new model without CuEq config
    logging.info("Creating new model without CuEq settings")
    target_model = source_model.__class__(**config)
    # Transfer weights with proper remapping
    num_layers = config["num_interactions"]
    transfer_weights(source_model, target_model, max_L, correlation, num_layers)
    if return_model:
        return target_model
    # Save model under <input basename> + output_model suffix
    if isinstance(input_model, str):
        base = os.path.splitext(input_model)[0]
        output_model = f"{base}.{output_model}"
    logging.warning(f"Saving E3nn model to {output_model}")
    torch.save(target_model, output_model)
    return None
def main():
    """CLI wrapper for the CuEq -> E3nn conversion."""
    parser = argparse.ArgumentParser()
    parser.add_argument("input_model", help="Path to input CuEq model")
    parser.add_argument(
        "--output_model", help="Path to output E3nn model", default="e3nn_model.pt"
    )
    parser.add_argument("--device", default="cpu", help="Device to use")
    parser.add_argument(
        "--return_model",
        action="store_false",
        # BUG FIX: the old help claimed the flag returns the model, but
        # store_false sets return_model=False — i.e. the flag makes run()
        # SAVE the converted model to file.
        help="Save the converted model to file instead of returning it",
    )
    args = parser.parse_args()
    run(
        input_model=args.input_model,
        output_model=args.output_model,
        device=args.device,
        return_model=args.return_model,
    )
if __name__ == "__main__":
main()
from argparse import ArgumentParser
import torch
def main():
    """CLI: load a serialized model, move it to a target device, and re-save it."""
    parser = ArgumentParser()
    parser.add_argument(
        "--target_device",
        "-t",
        help="device to convert to, usually 'cpu' or 'cuda'",
        default="cpu",
    )
    parser.add_argument(
        "--output_file",
        "-o",
        help="name for output model, defaults to model_file.target_device",
    )
    parser.add_argument("model_file", help="input model file path")
    args = parser.parse_args()

    # Default output name: the input path with the device appended.
    destination = args.output_file
    if destination is None:
        destination = f"{args.model_file}.{args.target_device}"

    loaded = torch.load(args.model_file, weights_only=False)
    loaded.to(args.target_device)
    torch.save(loaded, destination)
if __name__ == "__main__":
main()
import argparse
import logging
import os
from typing import Dict, List, Tuple
import torch
from mace.modules.wrapper_ops import CuEquivarianceConfig
from mace.tools.scripts_utils import extract_config_mace_model
def get_transfer_keys(num_layers: int) -> List[str]:
    """Return the ordered state-dict keys that are copied verbatim between formats."""
    fixed_head = [
        "node_embedding.linear.weight",
        "radial_embedding.bessel_fn.bessel_weights",
        "atomic_energies_fn.atomic_energies",
        "readouts.0.linear.weight",
    ]
    readout_keys = [f"readouts.{j}.linear.weight" for j in range(num_layers - 1)]
    scale_keys = ["scale_shift.scale", "scale_shift.shift"]
    final_readout = [
        f"readouts.{num_layers-1}.linear_{i}.weight" for i in range(1, 3)
    ]
    # Per-layer interaction and product weights, same ordering per layer.
    per_layer = []
    for j in range(num_layers):
        per_layer += [f"interactions.{j}.linear_up.weight"]
        per_layer += [
            f"interactions.{j}.conv_tp_weights.layer{i}.weight" for i in range(4)
        ]
        per_layer += [
            f"interactions.{j}.linear.weight",
            f"interactions.{j}.skip_tp.weight",
            f"products.{j}.linear.weight",
        ]
    return fixed_head + readout_keys + scale_keys + final_readout + per_layer
def get_kmax_pairs(
    max_L: int, correlation: int, num_layers: int
) -> List[Tuple[int, int]]:
    """Determine (layer index, kmax) pairs based on max_L and correlation.

    All layers but the last use ``max_L``; the final layer is invariant
    (kmax = 0).

    Raises:
        NotImplementedError: for any correlation other than 3.
    """
    if correlation == 2:
        raise NotImplementedError("Correlation 2 not supported yet")
    if correlation == 3:
        # BUG FIX: return tuples, matching the List[Tuple[int, int]] annotation
        # (callers only unpack the pairs, so this is behavior-compatible).
        return [(i, max_L) for i in range(num_layers - 1)] + [(num_layers - 1, 0)]
    raise NotImplementedError(f"Correlation {correlation} not supported")
def transfer_symmetric_contractions(
    source_dict: Dict[str, torch.Tensor],
    target_dict: Dict[str, torch.Tensor],
    max_L: int,
    correlation: int,
    num_layers: int,
):
    """Pack the per-contraction E3nn weights into one fused CuEq weight tensor
    per layer, concatenating along dim 1 in (weights_max, .0, .1) order."""
    for layer, kmax in get_kmax_pairs(max_L, correlation, num_layers):
        prefix = f"products.{layer}.symmetric_contractions"
        chunks = []
        for k in range(kmax + 1):
            for suffix in ("_max", ".0", ".1"):
                chunks.append(source_dict[f"{prefix}.contractions.{k}.weights{suffix}"])
        target_dict[f"{prefix}.weight"] = torch.concatenate(chunks, dim=1)
def transfer_weights(
    source_model: torch.nn.Module,
    target_model: torch.nn.Module,
    max_L: int,
    correlation: int,
    num_layers: int,
):
    """Transfer weights from E3nn to CuEq format.

    Copies directly-matching keys, fuses the symmetric-contraction weights,
    unsqueezes a leading dimension onto linear/skip_tp weights, then loads the
    assembled state dict into ``target_model``.
    """
    # Get source state dict
    source_dict = source_model.state_dict()
    target_dict = target_model.state_dict()
    # Transfer main weights
    transfer_keys = get_transfer_keys(num_layers)
    for key in transfer_keys:
        if key in source_dict:  # Check if key exists
            target_dict[key] = source_dict[key]
        else:
            logging.warning(f"Key {key} not found in source model")
    # Transfer symmetric contractions
    transfer_symmetric_contractions(
        source_dict, target_dict, max_L, correlation, num_layers
    )
    # Unsqueeze linear and skip_tp layers (CuEq expects a leading dim)
    for key in source_dict.keys():
        if any(x in key for x in ["linear", "skip_tp"]) and "weight" in key:
            target_dict[key] = target_dict[key].unsqueeze(0)
    # Transfer remaining matching keys (excluding symmetric contractions,
    # which were handled above)
    transferred_keys = set(transfer_keys)
    remaining_keys = (
        set(source_dict.keys()) & set(target_dict.keys()) - transferred_keys
    )
    remaining_keys = {k for k in remaining_keys if "symmetric_contraction" not in k}
    for key in remaining_keys:
        if source_dict[key].shape == target_dict[key].shape:
            logging.debug(f"Transferring additional key: {key}")
            target_dict[key] = source_dict[key]
        else:
            logging.warning(
                f"Shape mismatch for key {key}: "
                f"source {source_dict[key].shape} vs target {target_dict[key].shape}"
            )
    # Transfer avg_num_neighbors
    # BUG FIX: was hard-coded to range(2); models with a different number of
    # interaction layers would miss layers or index out of range.
    for i in range(num_layers):
        target_model.interactions[i].avg_num_neighbors = source_model.interactions[
            i
        ].avg_num_neighbors
    # Load state dict into target model
    target_model.load_state_dict(target_dict)
def run(
    input_model,
    output_model="_cueq.model",
    device="cpu",
    return_model=True,
):
    """Convert a plain E3nn MACE model to the CuEquivariance-accelerated format.

    Args:
        input_model: path to a serialized model, or an in-memory model.
        output_model: name suffix used when saving (only when return_model=False).
        device: map_location for torch.load and target device for the new model.
        return_model: if True, return the converted model instead of saving it.
    """
    # check if input_model is a path or a model
    if isinstance(input_model, str):
        # NOTE(review): torch.load unpickles arbitrary code — only use with
        # trusted model files.
        source_model = torch.load(input_model, map_location=device)
    else:
        source_model = input_model
    default_dtype = next(source_model.parameters()).dtype
    torch.set_default_dtype(default_dtype)
    # Extract configuration
    config = extract_config_mace_model(source_model)
    # Get max_L and correlation from config
    max_L = config["hidden_irreps"].lmax
    correlation = config["correlation"]
    # Add cuequivariance config
    config["cueq_config"] = CuEquivarianceConfig(
        enabled=True,
        layout="ir_mul",
        group="O3_e3nn",
        optimize_all=True,
    )
    # Create new model with cuequivariance config
    logging.info("Creating new model with cuequivariance settings")
    target_model = source_model.__class__(**config).to(device)
    # Transfer weights with proper remapping
    num_layers = config["num_interactions"]
    transfer_weights(source_model, target_model, max_L, correlation, num_layers)
    if return_model:
        return target_model
    # Save model under <input basename> + output_model suffix
    if isinstance(input_model, str):
        base = os.path.splitext(input_model)[0]
        output_model = f"{base}.{output_model}"
    logging.warning(f"Saving CuEq model to {output_model}")
    torch.save(target_model, output_model)
    return None
def main():
    """CLI wrapper for the E3nn -> CuEq conversion."""
    parser = argparse.ArgumentParser()
    parser.add_argument("input_model", help="Path to input MACE model")
    parser.add_argument(
        "--output_model",
        help="Path to output cuequivariance model",
        default="cueq_model.pt",
    )
    parser.add_argument("--device", default="cpu", help="Device to use")
    parser.add_argument(
        "--return_model",
        action="store_false",
        # BUG FIX: the old help claimed the flag returns the model, but
        # store_false sets return_model=False — i.e. the flag makes run()
        # SAVE the converted model to file.
        help="Save the converted model to file instead of returning it",
    )
    args = parser.parse_args()
    run(
        input_model=args.input_model,
        output_model=args.output_model,
        device=args.device,
        return_model=args.return_model,
    )
if __name__ == "__main__":
main()
# pylint: disable=wrong-import-position
import argparse
import copy
import os
os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1"
import torch
from e3nn.util import jit
from mace.calculators import LAMMPS_MACE
from mace.calculators.lammps_mliap_mace import LAMMPS_MLIAP_MACE
from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq
def parse_args():
    """Parse CLI arguments for converting a MACE model to a LAMMPS-ready format."""
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    cli.add_argument(
        "model_path",
        type=str,
        help="Path to the model to be converted to LAMMPS",
    )
    cli.add_argument(
        "--head",
        type=str,
        nargs="?",
        default=None,
        help="Head of the model to be converted to LAMMPS",
    )
    cli.add_argument(
        "--dtype",
        type=str,
        nargs="?",
        default="float64",
        help="Data type of the model to be converted to LAMMPS",
    )
    cli.add_argument(
        "--format",
        type=str,
        default="libtorch",
        help="Old libtorch format, or new mliap format",
    )
    return cli.parse_args()
def select_head(model):
    """Pick a model head, prompting the user only when several are available.

    Returns the chosen head name, or None when the model has no heads / the
    user declines to choose.
    """
    heads = model.heads if hasattr(model, "heads") else [None]
    if len(heads) == 1:
        print(f"Only one head found in the model: {heads[0]}. Skipping selection.")
        return heads[0]
    print("Available heads in the model:")
    for position, head in enumerate(heads, start=1):
        print(f"{position}: {head}")
    # Ask the user to select a head interactively.
    selected = input(
        f"Select a head by number (Defaulting to head: {len(heads)}, press Enter to accept): "
    )
    if selected.isdigit() and 1 <= int(selected) <= len(heads):
        return heads[int(selected) - 1]
    if selected == "":
        print("No head selected. Proceeding without specifying a head.")
        return None
    print(f"No valid selection made. Defaulting to the last head: {heads[-1]}")
    return heads[-1]
def main():
    """Convert a trained MACE model into a LAMMPS-loadable artifact.

    Depending on --format, the model is wrapped in LAMMPS_MLIAP_MACE (saved
    via torch.save after a CuEq conversion) or LAMMPS_MACE (TorchScript
    compiled and saved).
    """
    args = parse_args()
    model_path = args.model_path  # takes model name as command-line input
    model = torch.load(
        model_path,
        map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )
    # The model is moved to CPU for serialization in either dtype branch.
    if args.dtype == "float64":
        model = model.double().to("cpu")
    elif args.dtype == "float32":
        print("Converting model to float32, this may cause loss of precision.")
        model = model.float().to("cpu")
    if args.format == "mliap":
        # Enabling cuequivariance by default. TODO: switch?
        model = run_e3nn_to_cueq(copy.deepcopy(model))
        model.lammps_mliap = True
    if args.head is None:
        head = select_head(model)
    else:
        head = args.head
        print(
            f"Selected head: {head} from command line in the list available heads: {model.heads}"
        )
    lammps_class = LAMMPS_MLIAP_MACE if args.format == "mliap" else LAMMPS_MACE
    lammps_model = (
        lammps_class(model, head=head) if head is not None else lammps_class(model)
    )
    if args.format == "mliap":
        torch.save(lammps_model, model_path + "-mliap_lammps.pt")
    else:
        lammps_model_compiled = jit.compile(lammps_model)
        lammps_model_compiled.save(model_path + "-lammps.pt")
if __name__ == "__main__":
main()
###########################################################################################
# Script for evaluating configurations contained in an xyz file with a trained model
# Authors: Ilyes Batatia, Gregor Simm
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
import argparse
import ase.data
import ase.io
import numpy as np
import torch
from mace import data
from mace.tools import torch_geometric, torch_tools, utils
def parse_args() -> argparse.Namespace:
    """Build and parse the command-line arguments for XYZ evaluation."""
    argparser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    argparser.add_argument(
        "--configs", help="path to XYZ configurations", required=True
    )
    argparser.add_argument("--model", help="path to model", required=True)
    argparser.add_argument("--output", help="output path", required=True)
    argparser.add_argument(
        "--device",
        help="select device",
        type=str,
        choices=["cpu", "cuda"],
        default="cpu",
    )
    argparser.add_argument(
        "--default_dtype",
        help="set default dtype",
        type=str,
        choices=["float32", "float64"],
        default="float64",
    )
    argparser.add_argument("--batch_size", help="batch size", type=int, default=64)
    argparser.add_argument(
        "--compute_stress",
        help="compute stress",
        action="store_true",
        default=False,
    )
    argparser.add_argument(
        "--return_contributions",
        help="model outputs energy contributions for each body order, only supported for MACE, not ScaleShiftMACE",
        action="store_true",
        default=False,
    )
    argparser.add_argument(
        "--info_prefix",
        help="prefix for energy, forces and stress keys",
        type=str,
        default="MACE_",
    )
    argparser.add_argument(
        "--head",
        help="Model head used for evaluation",
        type=str,
        required=False,
        default=None,
    )
    return argparser.parse_args()
def main() -> None:
    """Entry point: parse arguments and evaluate the configurations."""
    run(parse_args())
def run(args: argparse.Namespace) -> None:
    """Evaluate all configurations in an XYZ file with a trained MACE model.

    Writes per-config energies and forces (and, when requested, stresses and
    body-order contributions) into args.output as extxyz, with keys prefixed
    by args.info_prefix.
    """
    torch_tools.set_default_dtype(args.default_dtype)
    device = torch_tools.init_device(args.device)
    # Load model
    model = torch.load(f=args.model, map_location=args.device)
    model = model.to(
        args.device
    )  # shouldn't be necessary but seems to help with CUDA problems
    # Freeze parameters: evaluation only.
    for param in model.parameters():
        param.requires_grad = False
    # Load data and prepare input
    atoms_list = ase.io.read(args.configs, index=":")
    if args.head is not None:
        # Tag every configuration with the requested model head.
        for atoms in atoms_list:
            atoms.info["head"] = args.head
    configs = [data.config_from_atoms(atoms) for atoms in atoms_list]
    z_table = utils.AtomicNumberTable([int(z) for z in model.atomic_numbers])
    try:
        heads = model.heads
    except AttributeError:
        heads = None  # models without a heads attribute
    data_loader = torch_geometric.dataloader.DataLoader(
        dataset=[
            data.AtomicData.from_config(
                config, z_table=z_table, cutoff=float(model.r_max), heads=heads
            )
            for config in configs
        ],
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
    )
    # Collect data
    energies_list = []
    contributions_list = []
    stresses_list = []
    forces_collection = []
    for batch in data_loader:
        batch = batch.to(device)
        output = model(batch.to_dict(), compute_stress=args.compute_stress)
        energies_list.append(torch_tools.to_numpy(output["energy"]))
        if args.compute_stress:
            stresses_list.append(torch_tools.to_numpy(output["stress"]))
        if args.return_contributions:
            contributions_list.append(torch_tools.to_numpy(output["contributions"]))
        # Split the flat per-atom force array back into per-config chunks
        # using the batch pointer offsets.
        forces = np.split(
            torch_tools.to_numpy(output["forces"]),
            indices_or_sections=batch.ptr[1:],
            axis=0,
        )
        forces_collection.append(forces[:-1])  # drop last as its empty
    energies = np.concatenate(energies_list, axis=0)
    forces_list = [
        forces for forces_list in forces_collection for forces in forces_list
    ]
    assert len(atoms_list) == len(energies) == len(forces_list)
    if args.compute_stress:
        stresses = np.concatenate(stresses_list, axis=0)
        assert len(atoms_list) == stresses.shape[0]
    if args.return_contributions:
        contributions = np.concatenate(contributions_list, axis=0)
        assert len(atoms_list) == contributions.shape[0]
    # Store data in atoms objects
    for i, (atoms, energy, forces) in enumerate(zip(atoms_list, energies, forces_list)):
        atoms.calc = None  # crucial
        atoms.info[args.info_prefix + "energy"] = energy
        atoms.arrays[args.info_prefix + "forces"] = forces
        if args.compute_stress:
            atoms.info[args.info_prefix + "stress"] = stresses[i]
        if args.return_contributions:
            atoms.info[args.info_prefix + "BO_contributions"] = contributions[i]
    # Write atoms to output path
    ase.io.write(args.output, images=atoms_list, format="extxyz")
if __name__ == "__main__":
main()
###########################################################################################
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
from __future__ import annotations
import argparse
import logging
from dataclasses import dataclass
from enum import Enum
from typing import List, Tuple, Union
import ase.data
import ase.io
import numpy as np
import torch
from mace.calculators import MACECalculator, mace_mp
try:
import fpsample # type: ignore
except ImportError:
pass
class FilteringType(Enum):
    """How pretraining configurations are filtered against the finetuning elements."""

    NONE = "none"
    COMBINATIONS = "combinations"
    EXCLUSIVE = "exclusive"
    INCLUSIVE = "inclusive"
class SubselectType(Enum):
    """Strategy used to subselect configurations from the pretraining set."""

    FPS = "fps"
    RANDOM = "random"
@dataclass
class SelectionSettings:
    """Settings for subselecting pretraining configurations."""

    configs_pt: str  # path to the pretraining XYZ configurations
    output: str  # output path for the selected configurations
    configs_ft: str | None = None  # path(s) to finetuning configurations
    atomic_numbers: List[int] | None = None  # element filter when no configs_ft given
    num_samples: int | None = None  # how many pretraining configs to select
    subselect: SubselectType = SubselectType.FPS  # subselection strategy
    model: str = "small"  # MACE model (foundation name or path) for descriptors
    descriptors: str | None = None  # optional precomputed descriptors file
    device: str = "cpu"
    default_dtype: str = "float64"
    head_pt: str | None = None  # head label for the pretraining set
    head_ft: str | None = None  # head label for the finetuning set
    filtering_type: FilteringType = FilteringType.COMBINATIONS
    weight_ft: float = 1.0  # weight for the finetuning set
    weight_pt: float = 1.0  # weight for the pretraining set
    seed: int = 42  # RNG seed
def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for pretraining-set subselection.

    NOTE(review): SelectionSettings has an ``atomic_numbers`` field with no
    matching CLI flag here — confirm whether one should be exposed.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--configs_pt",
        help="path to XYZ configurations for the pretraining",
        required=True,
    )
    parser.add_argument(
        "--configs_ft",
        help="path or list of paths to XYZ configurations for the finetuning",
        required=False,
        default=None,
    )
    parser.add_argument(
        "--num_samples",
        help="number of samples to select for the pretraining",
        type=int,
        required=False,
        default=None,
    )
    parser.add_argument(
        "--subselect",
        help="method to subselect the configurations of the pretraining set",
        type=SubselectType,
        choices=list(SubselectType),
        default=SubselectType.FPS,
    )
    parser.add_argument(
        "--model", help="path to model", default="small", required=False
    )
    parser.add_argument("--output", help="output path", required=True)
    parser.add_argument(
        "--descriptors", help="path to descriptors", required=False, default=None
    )
    parser.add_argument(
        "--device",
        help="select device",
        type=str,
        choices=["cpu", "cuda"],
        default="cpu",
    )
    parser.add_argument(
        "--default_dtype",
        help="set default dtype",
        type=str,
        choices=["float32", "float64"],
        default="float64",
    )
    parser.add_argument(
        "--head_pt",
        help="level of head for the pretraining set",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--head_ft",
        help="level of head for the finetuning set",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--filtering_type",
        help="filtering type",
        type=FilteringType,
        choices=list(FilteringType),
        default=FilteringType.NONE,
    )
    parser.add_argument(
        "--weight_ft",
        help="weight for the finetuning set",
        type=float,
        default=1.0,
    )
    parser.add_argument(
        "--weight_pt",
        help="weight for the pretraining set",
        type=float,
        default=1.0,
    )
    parser.add_argument("--seed", help="random seed", type=int, default=42)
    return parser.parse_args()
def calculate_descriptors(atoms: List[ase.Atoms], calc: MACECalculator) -> None:
    """Attach per-element averaged MACE descriptors to each configuration.

    Stores {element symbol -> mean invariant descriptor over that element's
    atoms} in atoms.info["mace_descriptors"], mutating each Atoms in place.
    """
    logging.info("Calculating descriptors")
    for mol in atoms:
        descriptors = calc.get_descriptors(mol.copy(), invariants_only=True)
        # average descriptors over atoms for each element
        descriptors_dict = {
            element: np.mean(descriptors[mol.symbols == element], axis=0)
            for element in np.unique(mol.symbols)
        }
        mol.info["mace_descriptors"] = descriptors_dict
def filter_atoms(
    atoms: ase.Atoms,
    element_subset: List[str],
    filtering_type: FilteringType = FilteringType.COMBINATIONS,
) -> bool:
    """
    Filter a configuration by its chemical composition.

    Parameters:
        atoms (ase.Atoms): The configuration to test.
        element_subset (list): Elements considered during filtering.
        filtering_type (FilteringType):
            - NONE: every configuration passes.
            - COMBINATIONS: pass if the configuration contains *only* elements
              from the subset (not all subset elements are required).
            - EXCLUSIVE: pass if the configuration's element set equals the subset.
            - INCLUSIVE: pass if every subset element is present (extra elements
              are allowed).

    Returns:
        bool: True if the configuration passes the filter.
    """
    if filtering_type == FilteringType.NONE:
        return True
    subset = set(element_subset)
    present = set(atoms.symbols)
    if filtering_type == FilteringType.COMBINATIONS:
        # only elements from the subset may appear
        return present <= subset
    if filtering_type == FilteringType.EXCLUSIVE:
        return present == subset
    if filtering_type == FilteringType.INCLUSIVE:
        # all subset elements must appear
        return subset <= present
    raise ValueError(
        f"Filtering type {filtering_type} not recognised. Must be one of {list(FilteringType)}."
    )
class FPS:
    """Farthest-point sampling over per-element averaged MACE descriptors."""

    def __init__(self, atoms_list: List[ase.Atoms], n_samples: int):
        self.n_samples = n_samples
        self.atoms_list = atoms_list
        # All chemical species present anywhere in the dataset.
        self.species = np.unique([x.symbol for atoms in atoms_list for x in atoms])  # type: ignore
        self.species_dict = {x: i for i, x in enumerate(self.species)}
        # start from a random configuration
        self.list_index = [np.random.randint(0, len(atoms_list))]
        self.assemble_descriptors()

    def run(
        self,
    ) -> List[int]:
        """
        Run the farthest point sampling algorithm.
        """
        # Flatten the (config, species, descriptor) tensor to one row per config.
        descriptor_dataset_reshaped = (
            self.descriptors_dataset.reshape(  # pylint: disable=E1121
                (len(self.atoms_list), -1)
            )
        )
        logging.info(f"{descriptor_dataset_reshaped.shape}")
        logging.info(f"n_samples: {self.n_samples}")
        self.list_index = fpsample.fps_npdu_kdtree_sampling(
            descriptor_dataset_reshaped,
            self.n_samples,
        )
        return self.list_index

    def assemble_descriptors(self) -> None:
        """
        Assemble the descriptors for all the configurations.
        """
        # Rows for species absent from a config keep the large 10e10 fill,
        # presumably to push incomplete configs apart in descriptor space —
        # TODO confirm this is intentional.
        self.descriptors_dataset: np.ndarray = 10e10 * np.ones(
            (
                len(self.atoms_list),
                len(self.species),
                len(list(self.atoms_list[0].info["mace_descriptors"].values())[0]),
            ),
            dtype=np.float32,
        ).astype(np.float32)
        for i, atoms in enumerate(self.atoms_list):
            descriptors = atoms.info["mace_descriptors"]
            for z in descriptors:
                self.descriptors_dataset[i, self.species_dict[z]] = np.array(
                    descriptors[z]
                ).astype(np.float32)
def _load_calc(
    model: str, device: str, default_dtype: str, subselect: SubselectType
) -> Union[MACECalculator, None]:
    """Build the MACE calculator, or return None when descriptors are not needed."""
    if subselect == SubselectType.RANDOM:
        # Random subselection never looks at descriptors; skip the model load.
        return None
    if model in ["small", "medium", "large"]:
        # Named sizes are resolved by mace_mp to foundation models.
        return mace_mp(model, device=device, default_dtype=default_dtype)
    return MACECalculator(
        model_paths=model,
        device=device,
        default_dtype=default_dtype,
    )
def _get_finetuning_elements(
atoms: List[ase.Atoms], atomic_numbers: List[int] | None
) -> List[str]:
if atoms:
logging.debug(
"Using elements from the finetuning configurations for filtering."
)
species = np.unique([x.symbol for atoms in atoms for x in atoms]).tolist() # type: ignore
elif atomic_numbers is not None and atomic_numbers:
logging.debug("Using the supplied atomic numbers for filtering.")
species = [ase.data.chemical_symbols[z] for z in atomic_numbers]
else:
species = []
return species
def _read_finetuning_configs(
configs_ft: Union[str, list[str], None],
) -> List[ase.Atoms]:
if isinstance(configs_ft, str):
path = configs_ft
return ase.io.read(path, index=":") # type: ignore
if isinstance(configs_ft, list):
assert all(isinstance(x, str) for x in configs_ft)
atoms_list_ft = []
for path in configs_ft:
atoms_list_ft += ase.io.read(path, index=":")
return atoms_list_ft
if configs_ft is None:
return []
raise ValueError(f"Invalid type for configs_ft: {type(configs_ft)}")
def _filter_pretraining_data(
    atoms: list[ase.Atoms],
    filtering_type: FilteringType,
    all_species_ft: List[str],
) -> Tuple[List[ase.Atoms], List[ase.Atoms], list[bool]]:
    """Split pretraining configurations by whether they pass the element filter.

    Returns (kept, dropped, per-configuration pass flags).
    """
    logging.info(
        "Filtering configurations based on the finetuning set, "
        f"filtering type: {filtering_type}, elements: {all_species_ft}"
    )
    passes_filter = [filter_atoms(a, all_species_ft, filtering_type) for a in atoms]
    assert len(passes_filter) == len(atoms), "Filtering failed"
    kept: List[ase.Atoms] = []
    dropped: List[ase.Atoms] = []
    for config, ok in zip(atoms, passes_filter):
        (kept if ok else dropped).append(config)
    return kept, dropped, passes_filter
def _get_random_configs(
    num_samples: int,
    atoms: List[ase.Atoms],
) -> list[ase.Atoms]:
    """Select `num_samples` distinct configurations uniformly at random.

    Raises ValueError when more samples are requested than available.
    """
    if num_samples > len(atoms):
        raise ValueError(
            f"Requested more samples ({num_samples}) than available in the remaining set ({len(atoms)})"
        )
    # np.random.choice accepts the population size directly; no need to
    # materialize list(range(len(atoms))).
    indices = np.random.choice(len(atoms), num_samples, replace=False)
    return [atoms[i] for i in indices]
def _load_descriptors(
    atoms: List[ase.Atoms],
    passes_filter: List[bool],
    descriptors_path: str | None,
    calc: MACECalculator | None,
    full_data_length: int,
) -> None:
    """Attach MACE descriptors to `atoms`, loading from disk when possible.

    A descriptor file must cover the full (unfiltered) dataset; only the
    entries whose flag in `passes_filter` is True are attached, in order.
    Without a file, descriptors are computed with `calc`.

    Raises ValueError when the file length mismatches or `calc` is missing.
    """
    if descriptors_path is None:
        logging.info("Calculating descriptors")
        if calc is None:
            raise ValueError("MACECalculator must be provided to calculate descriptors")
        calculate_descriptors(atoms, calc)
        return
    logging.info(f"Loading descriptors from {descriptors_path}")
    descriptors = np.load(descriptors_path, allow_pickle=True)
    assert sum(passes_filter) == len(atoms)
    if len(descriptors) != full_data_length:
        # Fixed message: the original concatenated the two literals with no
        # separator, producing "...data (N)Please provide...".
        raise ValueError(
            f"Length of the descriptors ({len(descriptors)}) does not match the length of the data ({full_data_length}). "
            "Please provide descriptors for all configurations"
        )
    required_descriptors = [
        descriptors[i] for i, passes in enumerate(passes_filter) if passes
    ]
    for target, desc in zip(atoms, required_descriptors):
        target.info["mace_descriptors"] = desc
def _maybe_save_descriptors(
    atoms: List[ase.Atoms],
    output_path: str,
) -> None:
    """
    Persist descriptors to a .npy file when every configuration carries
    them, then strip the descriptors from the atoms objects so they are
    not serialized into the XYZ output.
    """
    if not all("mace_descriptors" in a.info for a in atoms):
        return
    descriptor_save_path = output_path.replace(".xyz", "_descriptors.npy")
    logging.info(f"Saving descriptors at {descriptor_save_path}")
    np.save(
        descriptor_save_path,
        [a.info["mace_descriptors"] for a in atoms],
        allow_pickle=True,
    )
    for a in atoms:
        del a.info["mace_descriptors"]
def _maybe_fps(atoms: List[ase.Atoms], num_samples: int) -> List[ase.Atoms]:
    """Farthest-point-sample `num_samples` configurations, falling back to
    a uniform random selection if FPS fails for any reason."""
    try:
        selector = FPS(atoms, num_samples)
        chosen = selector.run()
        logging.info(f"Selected {len(chosen)} configurations")
        return [atoms[i] for i in chosen]
    except Exception as e:  # pylint: disable=W0703
        logging.error(f"FPS failed, selecting random configurations instead: {e}")
        return _get_random_configs(num_samples, atoms)
def _subsample_data(
    filtered_atoms: List[ase.Atoms],
    remaining_atoms: List[ase.Atoms],
    passes_filter: List[bool],
    num_samples: int | None,
    subselect: SubselectType,
    descriptors_path: str | None,
    calc: MACECalculator | None,
) -> List[ase.Atoms]:
    """Reduce the filtered pretraining set to `num_samples` configurations.

    Keeps everything when no sampling is requested, tops up from
    `remaining_atoms` when the filter kept too few, and otherwise
    subselects by FPS or uniformly at random.
    """
    n_filtered = len(filtered_atoms)
    if num_samples is None or num_samples == n_filtered:
        logging.info(
            f"No subsampling, keeping all {n_filtered} filtered configurations"
        )
        return filtered_atoms
    if num_samples > n_filtered:
        n_random = num_samples - n_filtered
        logging.info(
            f"Number of configurations after filtering {n_filtered} "
            f"is less than the number of samples {num_samples}, "
            f"selecting {n_random} random configurations for the rest."
        )
        return filtered_atoms + _get_random_configs(n_random, remaining_atoms)
    if num_samples == 0:
        raise ValueError("Number of samples must be greater than 0")
    if subselect == SubselectType.FPS:
        # FPS needs descriptors; the file (if any) covers the full dataset.
        _load_descriptors(
            filtered_atoms,
            passes_filter,
            descriptors_path,
            calc,
            full_data_length=n_filtered + len(remaining_atoms),
        )
        logging.info("Selecting configurations using Farthest Point Sampling")
        return _maybe_fps(filtered_atoms, num_samples)
    if subselect == SubselectType.RANDOM:
        return _get_random_configs(num_samples, filtered_atoms)
    raise ValueError(f"Invalid subselect type: {subselect}")
def _write_metadata(
    atoms: list[ase.Atoms], pretrained: bool, config_weight: float, head: str | None
) -> None:
    """Tag each configuration with provenance and training-weight metadata.

    The head tag is only written when a head name is given.
    """
    for config in atoms:
        config.info["pretrained"] = pretrained
        config.info["config_weight"] = config_weight
        if head is not None:
            config.info["head"] = head
def select_samples(
    settings: SelectionSettings,
) -> None:
    """Select pretraining configurations to replay during finetuning.

    Pipeline: seed RNGs -> (optionally) load a MACE calculator -> read the
    finetuning set -> filter the pretraining set by element -> subsample ->
    save/strip descriptors -> tag metadata -> write the selected set and a
    combined (selected + finetuning) XYZ file.
    """
    # Seed numpy and torch so FPS/random subselection is reproducible.
    np.random.seed(settings.seed)
    torch.manual_seed(settings.seed)
    # calc is None for random subselection, where no descriptors are needed.
    calc = _load_calc(
        settings.model, settings.device, settings.default_dtype, settings.subselect
    )
    atoms_list_ft = _read_finetuning_configs(settings.configs_ft)
    all_species_ft = _get_finetuning_elements(atoms_list_ft, settings.atomic_numbers)
    if settings.filtering_type is not FilteringType.NONE and not all_species_ft:
        raise ValueError(
            "Filtering types other than NONE require elements for filtering. They can be specified via the `--atomic_numbers` flag."
        )
    atoms_list_pt: list[ase.Atoms] = ase.io.read(settings.configs_pt, index=":")  # type: ignore
    filtered_pt_atoms, remaining_atoms, passes_filter = _filter_pretraining_data(
        atoms_list_pt, settings.filtering_type, all_species_ft
    )
    subsampled_atoms = _subsample_data(
        filtered_pt_atoms,
        remaining_atoms,
        passes_filter,
        settings.num_samples,
        settings.subselect,
        settings.descriptors,
        calc,
    )
    # Persist descriptors (if present) and strip them before writing XYZ.
    _maybe_save_descriptors(subsampled_atoms, settings.output)
    # Pretraining configurations keep the pt head/weight...
    _write_metadata(
        subsampled_atoms,
        pretrained=True,
        config_weight=settings.weight_pt,
        head=settings.head_pt,
    )
    # ...finetuning configurations get the ft head/weight.
    _write_metadata(
        atoms_list_ft,
        pretrained=False,
        config_weight=settings.weight_ft,
        head=settings.head_ft,
    )
    logging.info("Saving the selected configurations")
    ase.io.write(settings.output, subsampled_atoms, format="extxyz")
    logging.info("Saving a combined XYZ file")
    atoms_fps_pt_ft = subsampled_atoms + atoms_list_ft
    ase.io.write(
        settings.output.replace(".xyz", "_combined.xyz"),
        atoms_fps_pt_ft,
        format="extxyz",
    )
def main():
    """Command-line entry point for the sample-selection script."""
    settings = SelectionSettings(**vars(parse_args()))
    select_samples(settings)
if __name__ == "__main__":
main()
import argparse
import dataclasses
import glob
import json
import os
import re
from typing import List
import matplotlib.pyplot as plt
import pandas as pd
# Global matplotlib styling shared by every figure this script produces.
plt.rcParams.update({"font.size": 8})
plt.style.use("seaborn-v0_8-paper")
# Categorical palette (Tableau-style hex codes) indexed throughout plot().
colors = [
    "#1f77b4",  # muted blue
    "#d62728",  # brick red
    "#ff7f0e",  # safety orange
    "#2ca02c",  # cooked asparagus green
    "#9467bd",  # muted purple
    "#8c564b",  # chestnut brown
    "#e377c2",  # raspberry yogurt pink
    "#7f7f7f",  # middle gray
    "#bcbd22",  # curry yellow-green
    "#17becf",  # blue-teal
]
@dataclasses.dataclass
class RunInfo:
    """Identity of one training run, parsed from its results filename."""

    # Experiment name: the part before "_run-<seed>_train.txt".
    name: str
    # Seed parsed from the "run-<seed>" component of the filename.
    seed: int
# Matches results filenames of the form "<name>_run-<seed>_train.txt".
name_re = re.compile(r"(?P<name>.+)_run-(?P<seed>\d+)_train.txt")
def parse_path(path: str) -> RunInfo:
    """Extract the run name and seed from a results file path.

    Raises RuntimeError when the basename does not match `name_re`.
    """
    match = name_re.match(os.path.basename(path))
    if match is None:
        raise RuntimeError(f"Cannot parse {path}")
    name, seed = match.group("name"), match.group("seed")
    return RunInfo(name=name, seed=int(seed))
def parse_training_results(path: str) -> List[dict]:
    """Load one results file (JSON lines), tagging each record with the
    run's name and seed so records from multiple runs can be pooled."""
    run_info = parse_path(path)
    with open(path, mode="r", encoding="utf-8") as f:
        records = [json.loads(line) for line in f]
    for record in records:
        record["name"] = run_info.name
        record["seed"] = run_info.seed
    return records
def parse_args() -> argparse.Namespace:
    """Build and parse the command-line arguments for the plotting script."""
    parser = argparse.ArgumentParser(
        description="Plot mace training statistics",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--path", help="Path to results file (.txt) or directory.", required=True
    )
    parser.add_argument(
        "--min_epoch", help="Minimum epoch.", default=0, type=int, required=False
    )
    # --start_swa keeps its old flag name as an alias of --start_stage_two.
    parser.add_argument(
        "--start_stage_two",
        "--start_swa",
        help="Epoch that stage two (swa) loss began. Plots dashed line on plot to indicate. If None then assumed tag not used in training.",
        default=None,
        type=int,
        required=False,
        dest="start_swa",
    )
    parser.add_argument(
        "--linear",
        help="Whether to plot linear instead of log scales.",
        default=False,
        required=False,
        action="store_true",
    )
    parser.add_argument(
        "--error_bars",
        help="Whether to plot standard deviations.",
        default=False,
        required=False,
        action="store_true",
    )
    parser.add_argument(
        "--keys",
        help="Comma-separated list of keys to plot.",
        default="rmse_e,rmse_f",
        type=str,
        required=False,
    )
    parser.add_argument(
        "--output_format",
        help="What file type to save plot as",
        default="png",
        type=str,
        required=False,
    )
    parser.add_argument(
        "--heads",
        help="Comma-separated name of the heads used for multihead training",
        default=None,
        type=str,
        required=False,
    )
    return parser.parse_args()
def plot(
    data: pd.DataFrame,
    min_epoch: int,
    output_path: str,
    output_format: str,
    linear: bool,
    start_swa: int,
    error_bars: bool,
    keys: str,
    heads: str,
) -> None:
    """
    Plot training/validation loss and selected error metrics versus epoch.

    data: per-epoch records of one run name (possibly several seeds).
    min_epoch: minimum epoch to plot.
    output_path: path prefix to save the plot; "_<head>.<format>" is appended.
    output_format: file format to save the plot.
    start_swa: epoch when stage two loss (swa) began; drawn as a dashed
        vertical line, or omitted when None.
    error_bars: whether to plot standard deviation bands across seeds.
    linear: whether to plot in linear scale instead of logscale (default).
    keys: comma-separated metric keys for the right-hand panel.
    heads: comma-separated head names for multihead training, or None.
    """
    # Axis label per metric key; metric values are scaled by 1e3 below.
    labels = {
        "mae_e": "MAE E [meV]",
        "mae_e_per_atom": "MAE E/atom [meV]",
        "rmse_e": "RMSE E [meV]",
        "rmse_e_per_atom": "RMSE E/atom [meV]",
        "q95_e": "Q95 E [meV]",
        "mae_f": "MAE F [meV / A]",
        "rel_mae_f": "Relative MAE F [meV / A]",
        "rmse_f": "RMSE F [meV / A]",
        "rel_rmse_f": "Relative RMSE F [meV / A]",
        "q95_f": "Q95 F [meV / A]",
        "mae_stress": "MAE Stress",
        "rmse_stress": "RMSE Stress [meV / A^3]",
        "rmse_virials_per_atom": " RMSE virials/atom [meV]",
        "mae_virials": "MAE Virials [meV]",
        "rmse_mu_per_atom": "RMSE MU/atom [mDebye]",
    }
    data = data[data["epoch"] > min_epoch]
    if heads is None:
        # Single-head run: aggregate across seeds per (name, mode, epoch).
        data = (
            data.groupby(["name", "mode", "epoch"]).agg(["mean", "std"]).reset_index()
        )
        valid_data = data[data["mode"] == "eval"]
        valid_data_dict = {"default": valid_data}
        train_data = data[data["mode"] == "opt"]
    else:
        heads = heads.split(",")
        # Separate eval and opt data
        valid_data = (
            data[data["mode"] == "eval"]
            .groupby(["name", "mode", "epoch", "head"])
            .agg(["mean", "std"])
            .reset_index()
        )
        train_data = (
            data[data["mode"] == "opt"]
            .groupby(["name", "mode", "epoch"])
            .agg(["mean", "std"])
            .reset_index()
        )
        # One validation frame per head; training loss is shared by all heads.
        valid_data_dict = {
            head: valid_data[valid_data["head"] == head] for head in heads
        }
    # One output figure per head ("default" when heads is None).
    for head, valid_data in valid_data_dict.items():
        fig, axes = plt.subplots(
            nrows=1, ncols=2, figsize=(10, 3), constrained_layout=True
        )
        # ---- Plot loss ----
        ax = axes[0]
        ax.plot(
            train_data["epoch"],
            train_data["loss"]["mean"],
            color=colors[1],
            linewidth=1,
        )
        ax.set_ylabel("Training Loss", color=colors[1])
        ax.set_yscale("log")
        # Validation loss on a twin y-axis sharing the epoch axis.
        ax2 = ax.twinx()
        ax2.plot(
            valid_data["epoch"],
            valid_data["loss"]["mean"],
            color=colors[0],
            linewidth=1,
        )
        ax2.set_ylabel("Validation Loss", color=colors[0])
        if not linear:
            ax.set_yscale("log")
            ax2.set_yscale("log")
        if error_bars:
            # +/- one standard deviation across seeds.
            ax.fill_between(
                train_data["epoch"],
                train_data["loss"]["mean"] - train_data["loss"]["std"],
                train_data["loss"]["mean"] + train_data["loss"]["std"],
                alpha=0.3,
                color=colors[1],
            )
            # NOTE(review): this band is drawn on the training axis `ax`
            # but uses validation data — presumably meant for ax2; confirm.
            ax.fill_between(
                valid_data["epoch"],
                valid_data["loss"]["mean"] - valid_data["loss"]["std"],
                valid_data["loss"]["mean"] + valid_data["loss"]["std"],
                alpha=0.3,
                color=colors[0],
            )
        if start_swa is not None:
            ax.axvline(
                start_swa,
                color="black",
                linestyle="dashed",
                linewidth=1,
                alpha=0.6,
                label="Stage Two Starts",
            )
        ax.set_xlabel("Epoch")
        # NOTE(review): overwrites the "Training Loss" label set above —
        # confirm this generic label is intended.
        ax.set_ylabel("Loss")
        ax.legend(loc="upper right", fontsize=4)
        ax.grid(True, linestyle="--", alpha=0.5)
        # ---- Plot selected keys ----
        ax = axes[1]
        twin_axes = []
        for i, key in enumerate(keys.split(",")):
            color = colors[(i + 3)]
            label = labels.get(key, key)
            if i == 0:
                main_ax = ax
            else:
                # Additional metrics get their own offset y-axes on the right.
                main_ax = ax.twinx()
                main_ax.spines.right.set_position(("outward", 40 * (i - 1)))
                twin_axes.append(main_ax)
            main_ax.plot(
                valid_data["epoch"],
                valid_data[key]["mean"] * 1e3,
                color=color,
                label=label,
                linewidth=1,
            )
            if error_bars:
                main_ax.fill_between(
                    valid_data["epoch"],
                    (valid_data[key]["mean"] - valid_data[key]["std"]) * 1e3,
                    (valid_data[key]["mean"] + valid_data[key]["std"]) * 1e3,
                    alpha=0.3,
                    color=color,
                )
            main_ax.set_ylabel(label, color=color)
            main_ax.tick_params(axis="y", colors=color)
        if start_swa is not None:
            ax.axvline(
                start_swa,
                color="black",
                linestyle="dashed",
                linewidth=1,
                alpha=0.6,
                label="Stage Two Starts",
            )
        ax.set_xlabel("Epoch")
        ax.set_xlim(left=min_epoch)
        ax.grid(True, linestyle="--", alpha=0.5)
        fig.savefig(
            f"{output_path}_{head}.{output_format}", dpi=300, bbox_inches="tight"
        )
        plt.close(fig)
def get_paths(path: str) -> List[str]:
    """Return result-file paths: the file itself, or every *_train.txt
    found directly inside a directory.

    Raises RuntimeError when a directory contains no result files.
    """
    if os.path.isfile(path):
        return [path]
    matches = glob.glob(os.path.join(path, "*_train.txt"))
    if not matches:
        raise RuntimeError(f"Cannot find results in '{path}'")
    return matches
def main() -> None:
    """Command-line entry point for plotting training statistics."""
    run(parse_args())
def run(args: argparse.Namespace) -> None:
    """Aggregate every results file under `args.path` and emit one plot
    per run name."""
    records = [
        record
        for path in get_paths(args.path)
        for record in parse_training_results(path)
    ]
    data = pd.DataFrame(records)
    for name, group in data.groupby("name"):
        plot(
            group,
            min_epoch=args.min_epoch,
            output_path=name,
            output_format=args.output_format,
            linear=args.linear,
            start_swa=args.start_swa,
            error_bars=args.error_bars,
            keys=args.keys,
            heads=args.heads,
        )
if __name__ == "__main__":
main()
# This file loads an XYZ dataset and prepares a
# new HDF5 file that is ready for training with on-the-fly dataloading.
import argparse
import ast
import json
import logging
import multiprocessing as mp
import os
import random
from functools import partial
from glob import glob
from typing import List, Tuple
import h5py
import numpy as np
import tqdm
from mace import data, tools
from mace.data import KeySpecification, update_keyspec_from_kwargs
from mace.data.utils import save_configurations_as_HDF5
from mace.modules import compute_statistics
from mace.tools import torch_geometric
from mace.tools.scripts_utils import get_atomic_energies, get_dataset_from_xyz
from mace.tools.utils import AtomicNumberTable
def compute_stats_target(
    file: str,
    z_table: AtomicNumberTable,
    r_max: float,
    atomic_energies: Tuple,
    batch_size: int,
):
    """Compute [avg_num_neighbors, mean, std] statistics for one HDF5 shard."""
    dataset = data.HDF5Dataset(file, z_table=z_table, r_max=r_max)
    # Deterministic, complete pass over the shard: no shuffling, keep the
    # final partial batch.
    loader = torch_geometric.dataloader.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
    )
    avg_num_neighbors, mean, std = compute_statistics(loader, atomic_energies)
    return [avg_num_neighbors, mean, std]
def pool_compute_stats(inputs: List):
    """
    Compute dataset statistics over every HDF5 shard in parallel.

    inputs: [path_to_files, z_table, r_max, atomic_energies, batch_size,
    num_process] packed into one list (to ease passing through mp helpers).
    Returns (avg_num_neighbors, mean, std) averaged across the shards.
    """
    path_to_files, z_table, r_max, atomic_energies, batch_size, num_process = inputs
    with mp.Pool(processes=num_process) as pool:
        # NOTE(review): local `re` shadows the usual regex-module name; it
        # holds the AsyncResult handles of the per-shard jobs.
        re = [
            pool.apply_async(
                compute_stats_target,
                args=(
                    file,
                    z_table,
                    r_max,
                    atomic_energies,
                    batch_size,
                ),
            )
            for file in glob(path_to_files + "/*")
        ]
        # No more work will be submitted; wait for all shards to finish.
        pool.close()
        pool.join()
    results = [r.get() for r in tqdm.tqdm(re)]
    if not results:
        raise ValueError(
            "No results were computed. Check if the input files exist and are readable."
        )
    # Separate avg_num_neighbors, mean, and std
    avg_num_neighbors = np.mean([r[0] for r in results])
    means = np.array([r[1] for r in results])
    stds = np.array([r[2] for r in results])
    # Compute averages
    # NOTE(review): .item() assumes the per-shard mean/std are scalars —
    # confirm this holds for all statistics configurations.
    mean = np.mean(means, axis=0).item()
    std = np.mean(stds, axis=0).item()
    return avg_num_neighbors, mean, std
def split_array(a: np.ndarray, max_size: int):
    """Split `a` into equal chunks whose count is the largest product of
    consecutive prime factors of len(a) not exceeding `max_size`.

    Odd-length input is padded by repeating the last element; the pad is
    signalled through the returned drop_last flag.
    """
    drop_last = False
    if len(a) % 2 == 1:
        a = np.append(a, a[-1])
        drop_last = True
    factors = get_prime_factors(len(a))
    best = 1
    # Scan every contiguous run of prime factors, keeping the largest
    # product that still fits within max_size.
    for width in range(1, len(factors) + 1):
        for start in range(0, len(factors) - width + 1):
            candidate = np.prod(factors[start : start + width])
            if candidate <= max_size:
                best = max(candidate, best)
    return np.array_split(a, best), drop_last
def get_prime_factors(n: int):
    """Return the prime factorization of `n` in ascending order, with
    multiplicity (empty list for n < 2).

    Uses integer division (the original `n = n / i` drifted into floats,
    which is inexact for large n) and trial-divides only up to sqrt(n).
    """
    factors = []
    i = 2
    while i * i <= n:
        while n % i == 0:
            factors.append(i)
            n //= i
        i += 1
    if n > 1:
        # Whatever remains above sqrt of the original value is prime.
        factors.append(n)
    return factors
# Define tasks for multiprocessing.
def multi_train_hdf5(process, args, split_train, drop_last):
    """Write one training shard to <h5_prefix>train/train_<process>.h5."""
    out_path = f"{args.h5_prefix}train/train_{process}.h5"
    with h5py.File(out_path, "w") as f:
        f.attrs["drop_last"] = drop_last
        save_configurations_as_HDF5(split_train[process], process, f)
def multi_valid_hdf5(process, args, split_valid, drop_last):
    """Write one validation shard to <h5_prefix>val/val_<process>.h5."""
    out_path = f"{args.h5_prefix}val/val_{process}.h5"
    with h5py.File(out_path, "w") as f:
        f.attrs["drop_last"] = drop_last
        save_configurations_as_HDF5(split_valid[process], process, f)
def multi_test_hdf5(process, name, args, split_test, drop_last):
    """Write one test shard to <h5_prefix>test/<name>_<process>.h5."""
    out_path = f"{args.h5_prefix}test/{name}_{process}.h5"
    with h5py.File(out_path, "w") as f:
        f.attrs["drop_last"] = drop_last
        save_configurations_as_HDF5(split_test[process], process, f)
def main() -> None:
    """
    Entry point: parse the preprocessing arguments, then convert the XYZ
    dataset into sharded HDF5 files ready for on-the-fly dataloading.
    """
    run(tools.build_preprocess_arg_parser().parse_args())
def run(args: argparse.Namespace):
    """
    Load an XYZ dataset and write sharded HDF5 train/val/test files ready
    for on-the-fly dataloading, optionally computing dataset statistics
    (atomic energies, mean/std, average neighbour count) over the shards.
    """
    # currently support only command line property_key syntax
    args.key_specification = KeySpecification()
    update_keyspec_from_kwargs(args.key_specification, vars(args))
    # Setup
    tools.set_seeds(args.seed)
    random.seed(args.seed)
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)-8s %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler()],
    )
    try:
        config_type_weights = ast.literal_eval(args.config_type_weights)
        assert isinstance(config_type_weights, dict)
    except Exception as e:  # pylint: disable=W0703
        logging.warning(
            f"Config type weights not specified correctly ({e}), using Default"
        )
        config_type_weights = {"Default": 1.0}
    # One output sub-directory per split under the h5 prefix.
    folders = ["train", "val", "test"]
    for sub_dir in folders:
        if not os.path.exists(args.h5_prefix + sub_dir):
            os.makedirs(args.h5_prefix + sub_dir)
    # Data preparation
    collections, atomic_energies_dict = get_dataset_from_xyz(
        work_dir=args.work_dir,
        train_path=args.train_file,
        valid_path=args.valid_file,
        valid_fraction=args.valid_fraction,
        config_type_weights=config_type_weights,
        test_path=args.test_file,
        seed=args.seed,
        key_specification=args.key_specification,
        head_name=None,
    )
    # Atomic number table
    # yapf: disable
    if args.atomic_numbers is None:
        # Collect every element seen in the train and validation sets.
        z_table = tools.get_atomic_number_table_from_zs(
            z
            for configs in (collections.train, collections.valid)
            for config in configs
            for z in config.atomic_numbers
        )
    else:
        logging.info("Using atomic numbers from command line argument")
        zs_list = ast.literal_eval(args.atomic_numbers)
        assert isinstance(zs_list, list)
        z_table = tools.get_atomic_number_table_from_zs(zs_list)
    logging.info("Preparing training set")
    if args.shuffle:
        random.shuffle(collections.train)
    # split collections.train into batches and save them to hdf5
    split_train = np.array_split(collections.train,args.num_process)
    drop_last = False
    if len(collections.train) % 2 == 1:
        drop_last = True
    # Write the training shards in parallel, one process per shard.
    multi_train_hdf5_ = partial(multi_train_hdf5, args=args, split_train=split_train, drop_last=drop_last)
    processes = []
    for i in range(args.num_process):
        p = mp.Process(target=multi_train_hdf5_, args=[i])
        p.start()
        processes.append(p)
    for i in processes:
        i.join()
    if args.compute_statistics:
        logging.info("Computing statistics")
        if len(atomic_energies_dict) == 0:
            atomic_energies_dict = get_atomic_energies(args.E0s, collections.train, z_table)
        # Remove atomic energies if element not in z_table
        removed_atomic_energies = {}
        for z in list(atomic_energies_dict):
            if z not in z_table.zs:
                removed_atomic_energies[z] = atomic_energies_dict.pop(z)
        if len(removed_atomic_energies) > 0:
            logging.warning("Atomic energies for elements not present in the atomic number table have been removed.")
            logging.warning(f"Removed atomic energies (eV): {str(removed_atomic_energies)}")
            logging.warning("To include these elements in the model, specify all atomic numbers explicitly using the --atomic_numbers argument.")
        atomic_energies: np.ndarray = np.array(
            [atomic_energies_dict[z] for z in z_table.zs]
        )
        logging.info(f"Atomic Energies: {atomic_energies.tolist()}")
        # Statistics are computed over the training shards written above.
        _inputs = [args.h5_prefix+'train', z_table, args.r_max, atomic_energies, args.batch_size, args.num_process]
        avg_num_neighbors, mean, std=pool_compute_stats(_inputs)
        logging.info(f"Average number of neighbors: {avg_num_neighbors}")
        logging.info(f"Mean: {mean}")
        logging.info(f"Standard deviation: {std}")
        # save the statistics as a json
        statistics = {
            "atomic_energies": str(atomic_energies_dict),
            "avg_num_neighbors": avg_num_neighbors,
            "mean": mean,
            "std": std,
            "atomic_numbers": str([int(z) for z in z_table.zs]),
            "r_max": args.r_max,
        }
        with open(args.h5_prefix + "statistics.json", "w") as f:  # pylint: disable=W1514
            json.dump(statistics, f)
    logging.info("Preparing validation set")
    if args.shuffle:
        random.shuffle(collections.valid)
    split_valid = np.array_split(collections.valid, args.num_process)
    drop_last = False
    if len(collections.valid) % 2 == 1:
        drop_last = True
    multi_valid_hdf5_ = partial(multi_valid_hdf5, args=args, split_valid=split_valid, drop_last=drop_last)
    processes = []
    for i in range(args.num_process):
        p = mp.Process(target=multi_valid_hdf5_, args=[i])
        p.start()
        processes.append(p)
    for i in processes:
        i.join()
    if args.test_file is not None:
        logging.info("Preparing test sets")
        # Each named test subset gets its own set of shards.
        for name, subset in collections.tests:
            drop_last = False
            if len(subset) % 2 == 1:
                drop_last = True
            split_test = np.array_split(subset, args.num_process)
            multi_test_hdf5_ = partial(multi_test_hdf5, args=args, split_test=split_test, drop_last=drop_last)
            processes = []
            for i in range(args.num_process):
                p = mp.Process(target=multi_test_hdf5_, args=[i, name])
                p.start()
                processes.append(p)
            for i in processes:
                i.join()
if __name__ == "__main__":
main()
###########################################################################################
# Training script for MACE
# Authors: Ilyes Batatia, Gregor Simm, David Kovacs
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
import ast
import glob
import json
import logging
import os
from copy import deepcopy
from pathlib import Path
from typing import List, Optional
import torch.distributed
import torch.nn.functional
from e3nn.util import jit
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import LBFGS
from torch.utils.data import ConcatDataset
from torch_ema import ExponentialMovingAverage
import mace
from mace import data, tools
from mace.calculators.foundations_models import mace_mp, mace_off
from mace.cli.convert_cueq_e3nn import run as run_cueq_to_e3nn
from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq
from mace.cli.visualise_train import TrainingPlotter
from mace.data import KeySpecification, update_keyspec_from_kwargs
from mace.tools import torch_geometric
from mace.tools.model_script_utils import configure_model
from mace.tools.multihead_tools import (
HeadConfig,
assemble_mp_data,
dict_head_to_dataclass,
prepare_default_head,
prepare_pt_head,
)
from mace.tools.run_train_utils import (
combine_datasets,
load_dataset_for_path,
normalize_file_paths,
)
from mace.tools.scripts_utils import (
LRScheduler,
SubsetCollection,
check_path_ase_read,
convert_to_json_format,
dict_to_array,
extract_config_mace_model,
get_atomic_energies,
get_avg_num_neighbors,
get_config_type_weights,
get_dataset_from_xyz,
get_files_with_suffix,
get_loss_fn,
get_optimizer,
get_params_options,
get_swa,
print_git_commit,
remove_pt_head,
setup_wandb,
)
from mace.tools.slurm_distributed import DistributedEnvironment
from mace.tools.tables_utils import create_error_table
from mace.tools.utils import AtomicNumberTable
def main() -> None:
    """Command-line entry point for MACE training / fine-tuning."""
    run(tools.build_default_arg_parser().parse_args())
def run(args) -> None:
"""
This script runs the training/fine tuning for mace
"""
tag = tools.get_tag(name=args.name, seed=args.seed)
args, input_log_messages = tools.check_args(args)
# default keyspec to update using heads dictionary
args.key_specification = KeySpecification()
update_keyspec_from_kwargs(args.key_specification, vars(args))
if args.device == "xpu":
try:
import intel_extension_for_pytorch as ipex
except ImportError as e:
raise ImportError(
"Error: Intel extension for PyTorch not found, but XPU device was specified"
) from e
if args.distributed:
try:
distr_env = DistributedEnvironment()
except Exception as e: # pylint: disable=W0703
logging.error(f"Failed to initialize distributed environment: {e}")
return
world_size = distr_env.world_size
local_rank = distr_env.local_rank
rank = distr_env.rank
if rank == 0:
print(distr_env)
torch.distributed.init_process_group(backend="nccl")
else:
rank = int(0)
# Setup
tools.set_seeds(args.seed)
tools.setup_logger(level=args.log_level, tag=tag, directory=args.log_dir, rank=rank)
logging.info("===========VERIFYING SETTINGS===========")
for message, loglevel in input_log_messages:
logging.log(level=loglevel, msg=message)
if args.distributed:
torch.cuda.set_device(local_rank)
logging.info(f"Process group initialized: {torch.distributed.is_initialized()}")
logging.info(f"Processes: {world_size}")
try:
logging.info(f"MACE version: {mace.__version__}")
except AttributeError:
logging.info("Cannot find MACE version, please install MACE via pip")
logging.debug(f"Configuration: {args}")
tools.set_default_dtype(args.default_dtype)
device = tools.init_device(args.device)
commit = print_git_commit()
model_foundation: Optional[torch.nn.Module] = None
foundation_model_avg_num_neighbors = 0
if args.foundation_model is not None:
if args.foundation_model in ["small", "medium", "large"]:
logging.info(
f"Using foundation model mace-mp-0 {args.foundation_model} as initial checkpoint."
)
calc = mace_mp(
model=args.foundation_model,
device=args.device,
default_dtype=args.default_dtype,
)
model_foundation = calc.models[0]
elif args.foundation_model in ["small_off", "medium_off", "large_off"]:
model_type = args.foundation_model.split("_")[0]
logging.info(
f"Using foundation model mace-off-2023 {model_type} as initial checkpoint. ASL license."
)
calc = mace_off(
model=model_type,
device=args.device,
default_dtype=args.default_dtype,
)
model_foundation = calc.models[0]
else:
model_foundation = torch.load(
args.foundation_model, map_location=args.device
)
logging.info(
f"Using foundation model {args.foundation_model} as initial checkpoint."
)
args.r_max = model_foundation.r_max.item()
foundation_model_avg_num_neighbors = model_foundation.interactions[
0
].avg_num_neighbors
if (
args.foundation_model not in ["small", "medium", "large"]
and args.pt_train_file is None
):
logging.warning(
"Using multiheads finetuning with a foundation model that is not a Materials Project model, need to provied a path to a pretraining file with --pt_train_file."
)
args.multiheads_finetuning = False
if args.multiheads_finetuning:
assert (
args.E0s != "average"
), "average atomic energies cannot be used for multiheads finetuning"
# check that the foundation model has a single head, if not, use the first head
if not args.force_mh_ft_lr:
logging.info(
"Multihead finetuning mode, setting learning rate to 0.0001 and EMA to True. To use a different learning rate, set --force_mh_ft_lr=True."
)
args.lr = 0.0001
args.ema = True
args.ema_decay = 0.99999
logging.info(
"Using multiheads finetuning mode, setting learning rate to 0.0001 and EMA to True"
)
if hasattr(model_foundation, "heads"):
if len(model_foundation.heads) > 1:
logging.warning(
"Mutlihead finetuning with models with more than one head is not supported, using the first head as foundation head."
)
model_foundation = remove_pt_head(
model_foundation, args.foundation_head
)
else:
args.multiheads_finetuning = False
if args.heads is not None:
args.heads = ast.literal_eval(args.heads)
for _, head_dict in args.heads.items():
# priority is global args < head property_key values < head info_keys+arrays_keys
head_keyspec = deepcopy(args.key_specification)
update_keyspec_from_kwargs(head_keyspec, head_dict)
head_keyspec.update(
info_keys=head_dict.get("info_keys", {}),
arrays_keys=head_dict.get("arrays_keys", {}),
)
head_dict["key_specification"] = head_keyspec
else:
args.heads = prepare_default_head(args)
if args.multiheads_finetuning:
pt_keyspec = (
args.heads["pt_head"]["key_specification"]
if "pt_head" in args.heads
else deepcopy(args.key_specification)
)
args.heads["pt_head"] = prepare_pt_head(
args, pt_keyspec, foundation_model_avg_num_neighbors
)
logging.info("===========LOADING INPUT DATA===========")
heads = list(args.heads.keys())
logging.info(f"Using heads: {heads}")
logging.info("Using the key specifications to parse data:")
for name, head_dict in args.heads.items():
head_keyspec = head_dict["key_specification"]
logging.info(f"{name}: {head_keyspec}")
head_configs: List[HeadConfig] = []
for head, head_args in args.heads.items():
logging.info(f"============= Processing head {head} ===========")
head_config = dict_head_to_dataclass(head_args, head, args)
# Handle train_file and valid_file - normalize to lists
if hasattr(head_config, "train_file") and head_config.train_file is not None:
head_config.train_file = normalize_file_paths(head_config.train_file)
if hasattr(head_config, "valid_file") and head_config.valid_file is not None:
head_config.valid_file = normalize_file_paths(head_config.valid_file)
if hasattr(head_config, "test_file") and head_config.test_file is not None:
head_config.test_file = normalize_file_paths(head_config.test_file)
if (
head_config.statistics_file is not None
and head_config.head_name != "pt_head"
):
with open(head_config.statistics_file, "r") as f: # pylint: disable=W1514
statistics = json.load(f)
logging.info("Using statistics json file")
head_config.atomic_numbers = statistics["atomic_numbers"]
head_config.mean = statistics["mean"]
head_config.std = statistics["std"]
head_config.avg_num_neighbors = statistics["avg_num_neighbors"]
head_config.compute_avg_num_neighbors = False
if isinstance(statistics["atomic_energies"], str) and statistics[
"atomic_energies"
].endswith(".json"):
with open(statistics["atomic_energies"], "r", encoding="utf-8") as f:
atomic_energies = json.load(f)
head_config.E0s = atomic_energies
head_config.atomic_energies_dict = ast.literal_eval(atomic_energies)
else:
head_config.E0s = statistics["atomic_energies"]
head_config.atomic_energies_dict = ast.literal_eval(
statistics["atomic_energies"]
)
if head_config.train_file == ["mp"]:
assert (
head_config.head_name == "pt_head"
), "Only pt_head should use mp as train_file"
logging.info(
"Using the full Materials Project data for replay. You can construct a different subset using `fine_tuning_select.py` script."
)
collections = assemble_mp_data(args, head_config, tag)
head_config.collections = collections
elif any(check_path_ase_read(f) for f in head_config.train_file):
train_files_ase_list = [
f for f in head_config.train_file if check_path_ase_read(f)
]
valid_files_ase_list = None
test_files_ase_list = None
if head_config.valid_file:
valid_files_ase_list = [
f for f in head_config.valid_file if check_path_ase_read(f)
]
if head_config.test_file:
test_files_ase_list = [
f for f in head_config.test_file if check_path_ase_read(f)
]
config_type_weights = get_config_type_weights(
head_config.config_type_weights
)
collections, atomic_energies_dict = get_dataset_from_xyz(
work_dir=args.work_dir,
train_path=train_files_ase_list,
valid_path=valid_files_ase_list,
valid_fraction=head_config.valid_fraction,
config_type_weights=config_type_weights,
test_path=test_files_ase_list,
seed=args.seed,
key_specification=head_config.key_specification,
head_name=head_config.head_name,
keep_isolated_atoms=head_config.keep_isolated_atoms,
)
head_config.collections = SubsetCollection(
train=collections.train,
valid=collections.valid,
tests=collections.tests,
)
head_config.atomic_energies_dict = atomic_energies_dict
logging.info(
f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}, "
f"tests=[{', '.join([name + ': ' + str(len(test_configs)) for name, test_configs in collections.tests])}],"
)
head_configs.append(head_config)
if all(
check_path_ase_read(head_config.train_file[0]) for head_config in head_configs
):
size_collections_train = sum(
len(head_config.collections.train) for head_config in head_configs
)
size_collections_valid = sum(
len(head_config.collections.valid) for head_config in head_configs
)
if size_collections_train < args.batch_size:
logging.error(
f"Batch size ({args.batch_size}) is larger than the number of training data ({size_collections_train})"
)
if size_collections_valid < args.valid_batch_size:
logging.warning(
f"Validation batch size ({args.valid_batch_size}) is larger than the number of validation data ({size_collections_valid})"
)
if args.multiheads_finetuning:
logging.info(
"==================Using multiheads finetuning mode=================="
)
args.loss = "universal"
all_ase_readable = all(
all(check_path_ase_read(f) for f in head_config.train_file)
for head_config in head_configs
)
head_config_pt = filter(lambda x: x.head_name == "pt_head", head_configs)
head_config_pt = next(head_config_pt, None)
assert head_config_pt is not None, "Pretraining head not found"
if all_ase_readable:
ratio_pt_ft = size_collections_train / len(head_config_pt.collections.train)
if ratio_pt_ft < 0.1:
logging.warning(
f"Ratio of the number of configurations in the training set and the in the pt_train_file is {ratio_pt_ft}, "
f"increasing the number of configurations in the fine-tuning heads by {int(0.1 / ratio_pt_ft)}"
)
for head_config in head_configs:
if head_config.head_name == "pt_head":
continue
head_config.collections.train += (
head_config.collections.train * int(0.1 / ratio_pt_ft)
)
logging.info(
f"Total number of configurations in pretraining: train={len(head_config_pt.collections.train)}, valid={len(head_config_pt.collections.valid)}"
)
else:
logging.debug(
"Using LMDB/HDF5 datasets for pretraining or fine-tuning - skipping ratio check"
)
# Atomic number table
# yapf: disable
for head_config in head_configs:
if head_config.atomic_numbers is None:
assert all(check_path_ase_read(f) for f in head_config.train_file), "Must specify atomic_numbers when using .h5 or .aselmdb train_file input"
z_table_head = tools.get_atomic_number_table_from_zs(
z
for configs in (head_config.collections.train, head_config.collections.valid)
for config in configs
for z in config.atomic_numbers
)
head_config.atomic_numbers = z_table_head.zs
head_config.z_table = z_table_head
else:
if head_config.statistics_file is None:
logging.info("Using atomic numbers from command line argument")
else:
logging.info("Using atomic numbers from statistics file")
zs_list = ast.literal_eval(head_config.atomic_numbers)
assert isinstance(zs_list, list)
z_table_head = tools.AtomicNumberTable(zs_list)
head_config.atomic_numbers = zs_list
head_config.z_table = z_table_head
# yapf: enable
all_atomic_numbers = set()
for head_config in head_configs:
all_atomic_numbers.update(head_config.atomic_numbers)
z_table = AtomicNumberTable(sorted(list(all_atomic_numbers)))
if args.foundation_model_elements and model_foundation:
z_table = AtomicNumberTable(sorted(model_foundation.atomic_numbers.tolist()))
logging.info(f"Atomic Numbers used: {z_table.zs}")
# Atomic energies
atomic_energies_dict = {}
for head_config in head_configs:
if head_config.atomic_energies_dict is None or len(head_config.atomic_energies_dict) == 0:
assert head_config.E0s is not None, "Atomic energies must be provided"
if all(check_path_ase_read(f) for f in head_config.train_file) and head_config.E0s.lower() != "foundation":
atomic_energies_dict[head_config.head_name] = get_atomic_energies(
head_config.E0s, head_config.collections.train, head_config.z_table
)
elif head_config.E0s.lower() == "foundation":
assert args.foundation_model is not None
z_table_foundation = AtomicNumberTable(
[int(z) for z in model_foundation.atomic_numbers]
)
foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies
if foundation_atomic_energies.ndim > 1:
foundation_atomic_energies = foundation_atomic_energies.squeeze()
if foundation_atomic_energies.ndim == 2:
foundation_atomic_energies = foundation_atomic_energies[0]
logging.info("Foundation model has multiple heads, using the first head as foundation E0s.")
atomic_energies_dict[head_config.head_name] = {
z: foundation_atomic_energies[
z_table_foundation.z_to_index(z)
].item()
for z in z_table.zs
}
else:
atomic_energies_dict[head_config.head_name] = get_atomic_energies(head_config.E0s, None, head_config.z_table)
else:
atomic_energies_dict[head_config.head_name] = head_config.atomic_energies_dict
# Atomic energies for multiheads finetuning
if args.multiheads_finetuning:
assert (
model_foundation is not None
), "Model foundation must be provided for multiheads finetuning"
z_table_foundation = AtomicNumberTable(
[int(z) for z in model_foundation.atomic_numbers]
)
foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies
if foundation_atomic_energies.ndim > 1:
foundation_atomic_energies = foundation_atomic_energies.squeeze()
if foundation_atomic_energies.ndim == 2:
foundation_atomic_energies = foundation_atomic_energies[0]
logging.info("Foundation model has multiple heads, using the first head as foundation E0s.")
atomic_energies_dict["pt_head"] = {
z: foundation_atomic_energies[
z_table_foundation.z_to_index(z)
].item()
for z in z_table.zs
}
heads = sorted(heads, key=lambda x: -1000 if x == "pt_head" else 0)
# Padding atomic energies if keeping all elements of the foundation model
if args.foundation_model_elements and model_foundation:
atomic_energies_dict_padded = {}
for head_name, head_energies in atomic_energies_dict.items():
energy_head_padded = {}
for z in z_table.zs:
energy_head_padded[z] = head_energies.get(z, 0.0)
atomic_energies_dict_padded[head_name] = energy_head_padded
atomic_energies_dict = atomic_energies_dict_padded
if args.model == "AtomicDipolesMACE":
atomic_energies = None
dipole_only = True
args.compute_dipole = True
args.compute_energy = False
args.compute_forces = False
args.compute_virials = False
args.compute_stress = False
else:
dipole_only = False
if args.model == "EnergyDipolesMACE":
args.compute_dipole = True
args.compute_energy = True
args.compute_forces = True
args.compute_virials = False
args.compute_stress = False
else:
args.compute_energy = True
args.compute_dipole = False
# atomic_energies: np.ndarray = np.array(
# [atomic_energies_dict[z] for z in z_table.zs]
# )
atomic_energies = dict_to_array(atomic_energies_dict, heads)
for head_config in head_configs:
try:
logging.info(f"Atomic Energies used (z: eV) for head {head_config.head_name}: " + "{" + ", ".join([f"{z}: {atomic_energies_dict[head_config.head_name][z]}" for z in head_config.z_table.zs]) + "}")
except KeyError as e:
raise KeyError(f"Atomic number {e} not found in atomic_energies_dict for head {head_config.head_name}, add E0s for this atomic number") from e
# Load datasets for each head, supporting multiple files per head
valid_sets = {head: [] for head in heads}
train_sets = {head: [] for head in heads}
for head_config in head_configs:
train_datasets = []
logging.info(f"Processing datasets for head '{head_config.head_name}'")
ase_files = [f for f in head_config.train_file if check_path_ase_read(f)]
non_ase_files = [f for f in head_config.train_file if not check_path_ase_read(f)]
if ase_files:
dataset = load_dataset_for_path(
file_path=ase_files,
r_max=args.r_max,
z_table=z_table,
head_config=head_config,
heads=heads,
collection=head_config.collections.train,
)
train_datasets.append(dataset)
logging.debug(f"Successfully loaded dataset from ASE files: {ase_files}")
for file in non_ase_files:
dataset = load_dataset_for_path(
file_path=file,
r_max=args.r_max,
z_table=z_table,
head_config=head_config,
heads=heads,
)
train_datasets.append(dataset)
logging.debug(f"Successfully loaded dataset from non-ASE file: {file}")
if not train_datasets:
raise ValueError(f"No valid training datasets found for head {head_config.head_name}")
train_sets[head_config.head_name] = combine_datasets(train_datasets, head_config.head_name)
if head_config.valid_file:
valid_datasets = []
valid_ase_files = [f for f in head_config.valid_file if check_path_ase_read(f)]
valid_non_ase_files = [f for f in head_config.valid_file if not check_path_ase_read(f)]
if valid_ase_files:
valid_dataset = load_dataset_for_path(
file_path=valid_ase_files,
r_max=args.r_max,
z_table=z_table,
head_config=head_config,
heads=heads,
collection=head_config.collections.valid,
)
valid_datasets.append(valid_dataset)
logging.debug(f"Successfully loaded validation dataset from ASE files: {valid_ase_files}")
for valid_file in valid_non_ase_files:
valid_dataset = load_dataset_for_path(
file_path=valid_file,
r_max=args.r_max,
z_table=z_table,
head_config=head_config,
heads=heads,
)
valid_datasets.append(valid_dataset)
logging.debug(f"Successfully loaded validation dataset from {valid_file}")
# Combine validation datasets
if valid_datasets:
valid_sets[head_config.head_name] = combine_datasets(valid_datasets, f"{head_config.head_name}_valid")
logging.info(f"Combined validation datasets for {head_config.head_name}")
# If no valid file is provided but collection exist, use the validation set from the collection
if head_config.valid_file is None and head_config.collections.valid:
valid_sets[head_config.head_name] = [
data.AtomicData.from_config(
config, z_table=z_table, cutoff=args.r_max, heads=heads
)
for config in head_config.collections.valid
]
if not valid_sets[head_config.head_name]:
raise ValueError(f"No valid datasets found for head {head_config.head_name}, please provide a valid_file or a valid_fraction")
# Create data loader for this head
if isinstance(train_sets[head_config.head_name], list):
dataset_size = len(train_sets[head_config.head_name])
else:
dataset_size = len(train_sets[head_config.head_name])
logging.info(f"Head '{head_config.head_name}' training dataset size: {dataset_size}")
train_loader_head = torch_geometric.dataloader.DataLoader(
dataset=train_sets[head_config.head_name],
batch_size=args.batch_size,
shuffle=True,
drop_last=(not args.lbfgs),
pin_memory=args.pin_memory,
num_workers=args.num_workers,
generator=torch.Generator().manual_seed(args.seed),
)
head_config.train_loader = train_loader_head
# concatenate all the trainsets
train_set = ConcatDataset([train_sets[head] for head in heads])
train_sampler, valid_sampler = None, None
if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(
train_set,
num_replicas=world_size,
rank=rank,
shuffle=True,
drop_last=(not args.lbfgs),
seed=args.seed,
)
valid_samplers = {}
for head, valid_set in valid_sets.items():
valid_sampler = torch.utils.data.distributed.DistributedSampler(
valid_set,
num_replicas=world_size,
rank=rank,
shuffle=True,
drop_last=True,
seed=args.seed,
)
valid_samplers[head] = valid_sampler
train_loader = torch_geometric.dataloader.DataLoader(
dataset=train_set,
batch_size=args.batch_size,
sampler=train_sampler,
shuffle=(train_sampler is None),
drop_last=(train_sampler is None and not args.lbfgs),
pin_memory=args.pin_memory,
num_workers=args.num_workers,
generator=torch.Generator().manual_seed(args.seed),
)
valid_loaders = {heads[i]: None for i in range(len(heads))}
if not isinstance(valid_sets, dict):
valid_sets = {"Default": valid_sets}
for head, valid_set in valid_sets.items():
valid_loaders[head] = torch_geometric.dataloader.DataLoader(
dataset=valid_set,
batch_size=args.valid_batch_size,
sampler=valid_samplers[head] if args.distributed else None,
shuffle=False,
drop_last=False,
pin_memory=args.pin_memory,
num_workers=args.num_workers,
generator=torch.Generator().manual_seed(args.seed),
)
loss_fn = get_loss_fn(args, dipole_only, args.compute_dipole)
args.avg_num_neighbors = get_avg_num_neighbors(head_configs, args, train_loader, device)
# Model
model, output_args = configure_model(args, train_loader, atomic_energies, model_foundation, heads, z_table, head_configs)
model.to(device)
logging.debug(model)
logging.info(f"Total number of parameters: {tools.count_parameters(model)}")
logging.info("")
logging.info("===========OPTIMIZER INFORMATION===========")
logging.info(f"Using {args.optimizer.upper()} as parameter optimizer")
logging.info(f"Batch size: {args.batch_size}")
if args.ema:
logging.info(f"Using Exponential Moving Average with decay: {args.ema_decay}")
logging.info(
f"Number of gradient updates: {int(args.max_num_epochs*len(train_set)/args.batch_size)}"
)
logging.info(f"Learning rate: {args.lr}, weight decay: {args.weight_decay}")
logging.info(loss_fn)
# Cueq
if args.enable_cueq:
logging.info("Converting model to CUEQ for accelerated training")
assert model.__class__.__name__ in ["MACE", "ScaleShiftMACE"]
model = run_e3nn_to_cueq(deepcopy(model), device=device)
# Optimizer
param_options = get_params_options(args, model)
optimizer: torch.optim.Optimizer
optimizer = get_optimizer(args, param_options)
if args.device == "xpu":
logging.info("Optimzing model and optimzier for XPU")
model, optimizer = ipex.optimize(model, optimizer=optimizer)
logger = tools.MetricsLogger(
directory=args.results_dir, tag=tag + "_train"
) # pylint: disable=E1123
lr_scheduler = LRScheduler(optimizer, args)
swa: Optional[tools.SWAContainer] = None
swas = [False]
if args.swa:
swa, swas = get_swa(args, model, optimizer, swas, dipole_only)
checkpoint_handler = tools.CheckpointHandler(
directory=args.checkpoints_dir,
tag=tag,
keep=args.keep_checkpoints,
swa_start=args.start_swa,
)
start_epoch = 0
restart_lbfgs = False
opt_start_epoch = None
if args.restart_latest:
try:
opt_start_epoch = checkpoint_handler.load_latest(
state=tools.CheckpointState(model, optimizer, lr_scheduler),
swa=True,
device=device,
)
except Exception: # pylint: disable=W0703
try:
opt_start_epoch = checkpoint_handler.load_latest(
state=tools.CheckpointState(model, optimizer, lr_scheduler),
swa=False,
device=device,
)
except Exception: # pylint: disable=W0703
restart_lbfgs = True
if opt_start_epoch is not None:
start_epoch = opt_start_epoch
ema: Optional[ExponentialMovingAverage] = None
if args.ema:
ema = ExponentialMovingAverage(model.parameters(), decay=args.ema_decay)
else:
for group in optimizer.param_groups:
group["lr"] = args.lr
if args.lbfgs:
logging.info("Switching optimizer to LBFGS")
optimizer = LBFGS(model.parameters(),
history_size=200,
max_iter=20,
line_search_fn="strong_wolfe")
if restart_lbfgs:
opt_start_epoch = checkpoint_handler.load_latest(
state=tools.CheckpointState(model, optimizer, lr_scheduler),
swa=False,
device=device,
)
if opt_start_epoch is not None:
start_epoch = opt_start_epoch
if args.wandb:
setup_wandb(args)
if args.distributed:
distributed_model = DDP(model, device_ids=[local_rank])
else:
distributed_model = None
train_valid_data_loader = {}
for head_config in head_configs:
data_loader_name = "train_" + head_config.head_name
train_valid_data_loader[data_loader_name] = head_config.train_loader
for head, valid_loader in valid_loaders.items():
data_load_name = "valid_" + head
train_valid_data_loader[data_load_name] = valid_loader
if args.plot and args.plot_frequency > 0:
try:
plotter = TrainingPlotter(
results_dir=logger.path,
heads=heads,
table_type=args.error_table,
train_valid_data=train_valid_data_loader,
test_data={},
output_args=output_args,
device=device,
plot_frequency=args.plot_frequency,
distributed=args.distributed,
swa_start=swa.start if swa else None
)
except Exception as e: # pylint: disable=W0718
logging.debug(f"Creating Plotter failed: {e}")
else:
plotter = None
if args.dry_run:
logging.info("DRY RUN mode enabled. Stopping now.")
return
tools.train(
model=model,
loss_fn=loss_fn,
train_loader=train_loader,
valid_loaders=valid_loaders,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
checkpoint_handler=checkpoint_handler,
eval_interval=args.eval_interval,
start_epoch=start_epoch,
max_num_epochs=args.max_num_epochs,
logger=logger,
patience=args.patience,
save_all_checkpoints=args.save_all_checkpoints,
output_args=output_args,
device=device,
swa=swa,
ema=ema,
max_grad_norm=args.clip_grad,
log_errors=args.error_table,
log_wandb=args.wandb,
distributed=args.distributed,
distributed_model=distributed_model,
plotter=plotter,
train_sampler=train_sampler,
rank=rank,
)
logging.info("")
logging.info("===========RESULTS===========")
train_valid_data_loader = {}
for head_config in head_configs:
data_loader_name = "train_" + head_config.head_name
train_valid_data_loader[data_loader_name] = head_config.train_loader
for head, valid_loader in valid_loaders.items():
data_load_name = "valid_" + head
train_valid_data_loader[data_load_name] = valid_loader
test_sets = {}
stop_first_test = False
test_data_loader = {}
if all(
head_config.test_file == head_configs[0].test_file
for head_config in head_configs
) and head_configs[0].test_file is not None:
stop_first_test = True
if all(
head_config.test_dir == head_configs[0].test_dir
for head_config in head_configs
) and head_configs[0].test_dir is not None:
stop_first_test = True
for head_config in head_configs:
if all(check_path_ase_read(f) for f in head_config.train_file):
for name, subset in head_config.collections.tests:
test_sets[name] = [
data.AtomicData.from_config(
config, z_table=z_table, cutoff=args.r_max, heads=heads
)
for config in subset
]
if head_config.test_dir is not None:
if not args.multi_processed_test:
test_files = get_files_with_suffix(head_config.test_dir, "_test.h5")
for test_file in test_files:
name = os.path.splitext(os.path.basename(test_file))[0]
test_sets[name] = data.HDF5Dataset(
test_file, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name
)
else:
test_folders = glob(head_config.test_dir + "/*")
for folder in test_folders:
name = os.path.splitext(os.path.basename(test_file))[0]
test_sets[name] = data.dataset_from_sharded_hdf5(
folder, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name
)
for test_name, test_set in test_sets.items():
test_sampler = None
if args.distributed:
test_sampler = torch.utils.data.distributed.DistributedSampler(
test_set,
num_replicas=world_size,
rank=rank,
shuffle=True,
drop_last=True,
seed=args.seed,
)
try:
drop_last = test_set.drop_last
except AttributeError as e: # pylint: disable=W0612
drop_last = False
test_loader = torch_geometric.dataloader.DataLoader(
test_set,
batch_size=args.valid_batch_size,
shuffle=(test_sampler is None),
drop_last=drop_last,
num_workers=args.num_workers,
pin_memory=args.pin_memory,
)
test_data_loader[test_name] = test_loader
if stop_first_test:
break
for swa_eval in swas:
epoch = checkpoint_handler.load_latest(
state=tools.CheckpointState(model, optimizer, lr_scheduler),
swa=swa_eval,
device=device,
)
model.to(device)
if args.distributed:
distributed_model = DDP(model, device_ids=[local_rank])
model_to_evaluate = model if not args.distributed else distributed_model
if swa_eval:
logging.info(f"Loaded Stage two model from epoch {epoch} for evaluation")
else:
logging.info(f"Loaded Stage one model from epoch {epoch} for evaluation")
if rank == 0:
# Save entire model
if swa_eval:
model_path = Path(args.checkpoints_dir) / (tag + "_stagetwo.model")
else:
model_path = Path(args.checkpoints_dir) / (tag + ".model")
logging.info(f"Saving model to {model_path}")
model_to_save = deepcopy(model)
if args.enable_cueq:
print("RUNING CUEQ TO E3NN")
print("swa_eval", swa_eval)
model_to_save = run_cueq_to_e3nn(deepcopy(model), device=device)
if args.save_cpu:
model_to_save = model_to_save.to("cpu")
torch.save(model_to_save, model_path)
extra_files = {
"commit.txt": commit.encode("utf-8") if commit is not None else b"",
"config.yaml": json.dumps(
convert_to_json_format(extract_config_mace_model(model))
),
}
if swa_eval:
torch.save(
model_to_save, Path(args.model_dir) / (args.name + "_stagetwo.model")
)
try:
path_complied = Path(args.model_dir) / (
args.name + "_stagetwo_compiled.model"
)
logging.info(f"Compiling model, saving metadata {path_complied}")
model_compiled = jit.compile(deepcopy(model_to_save))
torch.jit.save(
model_compiled,
path_complied,
_extra_files=extra_files,
)
except Exception as e: # pylint: disable=W0718
pass
else:
torch.save(model_to_save, Path(args.model_dir) / (args.name + ".model"))
try:
path_complied = Path(args.model_dir) / (
args.name + "_compiled.model"
)
logging.info(f"Compiling model, saving metadata to {path_complied}")
model_compiled = jit.compile(deepcopy(model_to_save))
torch.jit.save(
model_compiled,
path_complied,
_extra_files=extra_files,
)
except Exception as e: # pylint: disable=W0718
pass
logging.info("Computing metrics for training, validation, and test sets")
for param in model.parameters():
param.requires_grad = False
skip_heads = args.skip_evaluate_heads.split(",") if args.skip_evaluate_heads else []
if skip_heads:
logging.info(f"Skipping evaluation for heads: {skip_heads}")
table_train_valid = create_error_table(
table_type=args.error_table,
all_data_loaders=train_valid_data_loader,
model=model_to_evaluate,
loss_fn=loss_fn,
output_args=output_args,
log_wandb=args.wandb,
device=device,
distributed=args.distributed,
skip_heads=skip_heads,
)
logging.info("Error-table on TRAIN and VALID:\n" + str(table_train_valid))
if test_data_loader:
table_test = create_error_table(
table_type=args.error_table,
all_data_loaders=test_data_loader,
model=model_to_evaluate,
loss_fn=loss_fn,
output_args=output_args,
log_wandb=args.wandb,
device=device,
distributed=args.distributed,
)
logging.info("Error-table on TEST:\n" + str(table_test))
if args.plot:
try:
plotter = TrainingPlotter(
results_dir=logger.path,
heads=heads,
table_type=args.error_table,
train_valid_data=train_valid_data_loader,
test_data=test_data_loader,
output_args=output_args,
device=device,
plot_frequency=args.plot_frequency,
distributed=args.distributed,
swa_start=swa.start if swa else None
)
plotter.plot(epoch, model_to_evaluate, rank)
except Exception as e: # pylint: disable=W0718
logging.debug(f"Plotting failed: {e}")
if args.distributed:
torch.distributed.barrier()
logging.info("Done")
if args.distributed:
torch.distributed.destroy_process_group()
# Run the training entry point when this file is executed as a script.
if __name__ == "__main__":
    main()
from argparse import ArgumentParser
import torch
from mace.tools.scripts_utils import remove_pt_head
def main():
    """CLI tool to inspect or extract a single head from a multihead MACE model.

    With ``--list_heads``, print the names of the heads available in the model.
    Otherwise, extract the head named by ``--head_name`` (removing the
    pretraining head) and save it, optionally moved to ``--target_device``.
    """
    parser = ArgumentParser()
    grp = parser.add_mutually_exclusive_group()
    grp.add_argument(
        "--head_name",
        "-n",
        help="name of the head to extract",
        default=None,
    )
    grp.add_argument(
        "--list_heads",
        "-l",
        action="store_true",
        help="list names of the heads",
    )
    parser.add_argument(
        "--target_device",
        "-d",
        help="target device, defaults to model's current device",
    )
    parser.add_argument(
        "--output_file",
        "-o",
        help="name for output model, defaults to model.head_name, followed by .target_device if specified",
    )
    parser.add_argument("model_file", help="input model file path")
    args = parser.parse_args()
    # NOTE(review): torch.load of a full model object executes pickled code;
    # only load model files from trusted sources.
    model = torch.load(args.model_file, map_location=args.target_device)
    torch.set_default_dtype(next(model.parameters()).dtype)
    if args.list_heads:
        print("Available heads:")
        print("\n".join([" " + h for h in model.heads]))
    else:
        if args.head_name is None:
            # Fail early with a clear message instead of a TypeError when
            # concatenating None into the default output filename below.
            parser.error("--head_name is required unless --list_heads is given")
        if args.output_file is None:
            args.output_file = (
                args.model_file
                + "."
                + args.head_name
                + ("." + args.target_device if (args.target_device is not None) else "")
            )
        model_single = remove_pt_head(model, args.head_name)
        if args.target_device is not None:
            # The model was already loaded onto args.target_device via
            # map_location; move the extracted head there explicitly as well.
            model_single.to(args.target_device)
        torch.save(model_single, args.output_file)
# Run the head-extraction tool when this file is executed as a script.
if __name__ == "__main__":
    main()
import json
import logging
from typing import Dict, List, Optional
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.distributed
from torchmetrics import Metric
# Global matplotlib configuration for all training plots.
plt.rcParams.update({"font.size": 8})
# Silence matplotlib's verbose DEBUG/INFO logging.
mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING)  # Only show WARNING and above
# Fixed color cycle shared by all plots in this module.
colors = [
    "#1f77b4",  # muted blue
    "#d62728",  # brick red
    "#7f7f7f",  # middle gray
    "#2ca02c",  # cooked asparagus green
    "#ff7f0e",  # safety orange
    "#9467bd",  # muted purple
    "#8c564b",  # chestnut brown
    "#e377c2",  # raspberry yogurt pink
    "#bcbd22",  # curry yellow-green
    "#17becf",  # blue-teal
]
# Maps an error-table type (args.error_table) to a pair of lists:
#   1. (metrics-log key, axis label) pairs plotted against epoch, and
#   2. (quantity key, axis label) pairs used for the
#      reference-vs-predicted scatter plots.
error_type = {
    "TotalRMSE": (
        [("rmse_e", "RMSE E [meV]"), ("rmse_f", "RMSE F [meV / A]")],
        [("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")],
    ),
    "PerAtomRMSE": (
        [("rmse_e_per_atom", "RMSE E/atom [meV]"), ("rmse_f", "RMSE F [meV / A]")],
        [("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")],
    ),
    "PerAtomRMSEstressvirials": (
        [
            ("rmse_e_per_atom", "RMSE E/atom [meV]"),
            ("rmse_f", "RMSE F [meV / A]"),
            ("rmse_stress", "RMSE Stress [meV / A^3]"),
        ],
        [
            ("energy", "Energy per atom [eV]"),
            ("force", "Force [eV / A]"),
            ("stress", "Stress [eV / A^3]"),
        ],
    ),
    "PerAtomMAEstressvirials": (
        [
            ("mae_e_per_atom", "MAE E/atom [meV]"),
            ("mae_f", "MAE F [meV / A]"),
            ("mae_stress", "MAE Stress [meV / A^3]"),
        ],
        [
            ("energy", "Energy per atom [eV]"),
            ("force", "Force [eV / A]"),
            ("stress", "Stress [eV / A^3]"),
        ],
    ),
    "TotalMAE": (
        [("mae_e", "MAE E [meV]"), ("mae_f", "MAE F [meV / A]")],
        [("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")],
    ),
    "PerAtomMAE": (
        [("mae_e_per_atom", "MAE E/atom [meV]"), ("mae_f", "MAE F [meV / A]")],
        [("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")],
    ),
    "DipoleRMSE": (
        [
            ("rmse_mu_per_atom", "RMSE MU/atom [mDebye]"),
            ("rel_rmse_f", "Relative MU RMSE [%]"),
        ],
        [("dipole", "Dipole per atom [Debye]")],
    ),
    "DipoleMAE": (
        [("mae_mu", "MAE MU [mDebye]"), ("rel_mae_f", "Relative MU MAE [%]")],
        [("dipole", "Dipole per atom [Debye]")],
    ),
    "EnergyDipoleRMSE": (
        [
            ("rmse_e_per_atom", "RMSE E/atom [meV]"),
            ("rmse_f", "RMSE F [meV / A]"),
            ("rmse_mu_per_atom", "RMSE MU/atom [mDebye]"),
        ],
        [
            ("energy", "Energy per atom [eV]"),
            ("force", "Force [eV / A]"),
            ("dipole", "Dipole per atom [Debye]"),
        ],
    ),
}
class TrainingPlotter:
    """Creates per-head training figures.

    Each figure has two rows: loss/metric curves against epoch (top) and
    reference-vs-predicted scatter plots for train/valid/test data (bottom).
    """

    def __init__(
        self,
        results_dir: str,
        heads: List[str],
        table_type: str,
        train_valid_data: Dict,
        test_data: Dict,
        output_args: Dict,
        device: str,
        plot_frequency: int,
        distributed: bool = False,
        swa_start: Optional[int] = None,
    ):
        """Store plotting configuration.

        Args:
            results_dir: path to the metrics log file (one JSON record per line).
            heads: names of the model heads to produce figures for.
            table_type: key into ``error_type`` selecting which metrics to plot.
            train_valid_data: mapping of dataset name -> data loader for
                train/valid sets; names are expected to contain the head name.
            test_data: mapping of dataset name -> data loader for test sets.
            output_args: flags forwarded to model inference (forces/virials/stress).
            device: device string used for inference.
            plot_frequency: how often (in epochs) plots are generated by the caller.
            distributed: whether inference runs under torch.distributed.
            swa_start: epoch at which stage two (SWA) starts, if enabled.
        """
        self.results_dir = results_dir
        self.heads = heads
        self.table_type = table_type
        self.train_valid_data = train_valid_data
        self.test_data = test_data
        self.output_args = output_args
        self.device = device
        self.plot_frequency = plot_frequency
        self.distributed = distributed
        self.swa_start = swa_start

    def plot(self, model_epoch: int, model: torch.nn.Module, rank: int) -> None:
        """Render and save one figure per head for the model loaded at *model_epoch*.

        All ranks run inference (needed for distributed metric reduction);
        only rank 0 renders and saves the figures.
        """
        # All ranks process data through model_inference
        train_valid_dict = model_inference(
            self.train_valid_data,
            model,
            self.output_args,
            self.device,
            self.distributed,
        )
        test_dict = model_inference(
            self.test_data, model, self.output_args, self.device, self.distributed
        )
        # Only rank 0 creates and saves plots
        if rank != 0:
            return
        # Each log line is one JSON record; see parse_training_results.
        data = pd.DataFrame(
            results for results in parse_training_results(self.results_dir)
        )
        labels, quantities = error_type[self.table_type]
        for head in self.heads:
            fig = plt.figure(layout="constrained", figsize=(10, 6))
            fig.suptitle(
                f"Model loaded from epoch {model_epoch} ({head} head)", fontsize=16
            )
            subfigs = fig.subfigures(2, 1, height_ratios=[1, 1], hspace=0.05)
            axsTop = subfigs[0].subplots(1, 2, sharey=False)
            axsBottom = subfigs[1].subplots(1, len(quantities), sharey=False)
            plot_epoch_dependence(axsTop, data, head, model_epoch, labels)
            # Use the pre-computed results for plotting
            plot_inference_from_results(
                axsBottom, train_valid_dict, test_dict, head, quantities
            )
            if self.swa_start is not None:
                # Add vertical lines to both axes
                for ax in axsTop:
                    ax.axvline(
                        self.swa_start,
                        color="black",
                        linestyle="dashed",
                        linewidth=1,
                        alpha=0.6,
                        label="Stage Two Starts",
                    )
                # Numeric comparison: model_epoch must be an int here.
                stage = "stage_two" if self.swa_start < model_epoch else "stage_one"
            else:
                stage = "stage_one"
            axsTop[0].legend(loc="best")
            # Save the figure using the appropriate stage in the filename
            # NOTE(review): assumes results_dir ends with a 4-character
            # extension (e.g. ".txt") that is stripped here — confirm.
            filename = f"{self.results_dir[:-4]}_{head}_{stage}.png"
            fig.savefig(filename, dpi=300, bbox_inches="tight")
            plt.close(fig)
def parse_training_results(path: str) -> List[dict]:
    """Parse a metrics log file into a list of result dictionaries.

    The file is expected to contain one JSON object per line; lines that are
    not valid JSON are skipped with a warning instead of aborting the parse.

    Args:
        path: path to the metrics log file.

    Returns:
        List of parsed JSON records, in file order.
    """
    results = []
    with open(path, mode="r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            try:
                results.append(json.loads(line))  # Ensure it's valid JSON
            except json.JSONDecodeError:
                # Tolerate malformed lines (e.g. truncated writes) gracefully.
                # Use the module's logging setup rather than bare print.
                logging.warning("Skipping invalid line: %s", line)
    return results
def plot_epoch_dependence(
    axes: np.ndarray, data: pd.DataFrame, head: str, model_epoch: int, labels: List[tuple]
) -> None:
    """Plot loss curves and selected validation metrics against epoch.

    ``axes[0]`` shows training loss (left y-axis) and validation loss
    (right y-axis), both on a log scale. ``axes[1]`` shows one curve per
    ``(metric_key, axis_label)`` pair in ``labels``, each on its own y-axis.
    A solid vertical line marks the epoch the evaluated model was loaded from.
    """
    # Validation records are logged with mode == "eval"; average per epoch/head.
    valid_data = (
        data[data["mode"] == "eval"]
        .groupby(["mode", "epoch", "head"])
        .agg(["mean", "std"])
        .reset_index()
    )
    valid_data = valid_data[valid_data["head"] == head]
    # Optimisation (training) records are logged with mode == "opt".
    train_data = (
        data[data["mode"] == "opt"]
        .groupby(["mode", "epoch"])
        .agg(["mean", "std"])
        .reset_index()
    )
    # ---- Plot loss ----
    ax = axes[0]
    ax.plot(
        train_data["epoch"], train_data["loss"]["mean"], color=colors[1], linewidth=1
    )
    ax.set_ylabel("Training Loss", color=colors[1])
    ax.set_yscale("log")
    # Validation loss goes on a twin axis so both curves share the x-axis.
    ax2 = ax.twinx()
    ax2.plot(
        valid_data["epoch"], valid_data["loss"]["mean"], color=colors[0], linewidth=1
    )
    ax2.set_ylabel("Validation Loss", color=colors[0])
    ax2.set_yscale("log")
    ax.axvline(
        model_epoch,
        color="black",
        linestyle="solid",
        linewidth=1,
        alpha=0.8,
        label="Loaded Model",
    )
    ax.set_xlabel("Epoch")
    ax.grid(True, linestyle="--", alpha=0.5)
    # ---- Plot selected keys ----
    ax = axes[1]
    twin_axes = []
    for i, label in enumerate(labels):
        color = colors[(i + 3)]
        key, axis_label = label
        if i == 0:
            main_ax = ax
        else:
            # Additional metrics get their own y-axis, offset to the right.
            main_ax = ax.twinx()
            main_ax.spines.right.set_position(("outward", 60 * (i - 1)))
        twin_axes.append(main_ax)
        main_ax.plot(
            valid_data["epoch"],
            # Factor 1e3 converts eV-based metrics to the meV units in the axis label.
            valid_data[key]["mean"] * 1e3,
            color=color,
            # NOTE(review): this passes the (key, axis_label) tuple as the line
            # label; axis_label alone was probably intended — confirm.
            label=label,
            linewidth=1,
        )
        main_ax.set_yscale("log")
        main_ax.set_ylabel(axis_label, color=color)
        main_ax.tick_params(axis="y", colors=color)
    ax.axvline(
        model_epoch,
        color="black",
        linestyle="solid",
        linewidth=1,
        alpha=0.8,
        label="Loaded Model",
    )
    ax.set_xlabel("Epoch")
    ax.grid(True, linestyle="--", alpha=0.5)
# INFERENCE=========
def plot_inference_from_results(
    axes: np.ndarray,
    train_valid_dict: dict,
    test_dict: dict,
    head: str,
    quantities: List[str],
) -> None:
    """Scatter-plot predicted vs. reference values for each quantity.

    One axis per ``(quantity_key, axis_label)`` entry in ``quantities``.
    Train/valid datasets each get their own legend entry (filtered to the
    given head); all test datasets share a single "Test" legend entry.
    A dashed y = x diagonal is drawn as a guide.
    """
    # Maps a plotted quantity key to (result key, reference field, predicted field).
    field_map = {
        "energy": ("energy", "reference_per_atom", "predicted_per_atom"),
        "force": ("forces", "reference", "predicted"),
        "stress": ("stress", "reference", "predicted"),
        "virials": ("virials", "reference_per_atom", "predicted_per_atom"),
        "dipole": ("dipole", "reference_per_atom", "predicted_per_atom"),
    }
    for ax, quantity in zip(axes, quantities):
        key, label = quantity
        # Legend handles keyed by entry name, so duplicates collapse.
        legend_labels = {}

        def add_scatter(result, marker, color, legend_name):
            # Draw one scatter for this dataset if the quantity is present.
            fields = field_map.get(key)
            if fields is None:
                return
            result_key, ref_field, pred_field = fields
            if result_key not in result:
                return
            handle = ax.scatter(
                result[result_key][ref_field],
                result[result_key][pred_field],
                marker=marker,
                color=color,
                label=legend_name,
            )
            legend_labels[legend_name] = handle

        # Train/valid data: each entry keeps its own name in the legend.
        for name, result in train_valid_dict.items():
            if head not in name:
                continue
            if "train" in name:
                add_scatter(result, "x", colors[1], name)
            else:
                add_scatter(result, "+", colors[0], name)
        # Test data: single shared legend entry.
        for result in test_dict.values():
            add_scatter(result, "o", colors[2], "Test")
        # Dashed diagonal (y = x) as a visual guide.
        min_val = min(ax.get_xlim()[0], ax.get_ylim()[0])
        max_val = max(ax.get_xlim()[1], ax.get_ylim()[1])
        ax.plot(
            [min_val, max_val],
            [min_val, max_val],
            linestyle="--",
            color="black",
            alpha=0.7,
        )
        # Legend with unique entries (Test + individual train/valid names).
        if legend_labels:
            ax.legend(
                handles=legend_labels.values(), labels=legend_labels.keys(), loc="best"
            )
        ax.set_xlabel(f"Reference {label}")
        ax.set_ylabel(f"MACE {label}")
        ax.grid(True, linestyle="--", alpha=0.5)
def model_inference(
    all_data_loaders: dict,
    model: torch.nn.Module,
    output_args: Dict[str, bool],
    device: str,
    distributed: bool = False,
):
    """Run ``model`` over every data loader and collect reference/predicted
    values for scatterplot visualization.

    Args:
        all_data_loaders: mapping of dataset name -> data loader of batches.
        model: the model to evaluate.
        output_args: flags selecting which outputs to compute
            (keys used: "forces", "virials", "stress").
        device: device string ("cpu"/"cuda") batches and the metric move to.
        distributed: if True, synchronize all ranks before reducing the metric.

    Returns:
        dict mapping each dataset name to the result dict produced by
        ``InferenceMetric.compute()``.
    """
    # Freeze parameters rather than using torch.no_grad(): the forward pass
    # may still need autograd (w.r.t. positions) to produce forces/stress.
    for param in model.parameters():
        param.requires_grad = False
    results_dict = {}
    try:
        for name, data_loader in all_data_loaders.items():
            # Lazy %-formatting: the message is only built if DEBUG is enabled.
            logging.debug("Running inference on %s dataset", name)
            scatter_metric = InferenceMetric().to(device)
            for batch in data_loader:
                batch = batch.to(device)
                batch_dict = batch.to_dict()
                output = model(
                    batch_dict,
                    training=False,
                    compute_force=output_args.get("forces", False),
                    compute_virials=output_args.get("virials", False),
                    compute_stress=output_args.get("stress", False),
                )
                # Calling the metric updates its internal state; the per-batch
                # return value is not needed here.
                scatter_metric(batch, output)
            if distributed:
                torch.distributed.barrier()
            results_dict[name] = scatter_metric.compute()
            scatter_metric.reset()
    finally:
        # Re-enable gradients even if inference fails part-way through.
        for param in model.parameters():
            param.requires_grad = True
    return results_dict
def to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Convert a (possibly CUDA, possibly grad-tracking) tensor to NumPy.

    Detach *before* the device transfer so the copy to host memory is not
    recorded in the autograd graph (the recommended PyTorch ordering);
    the resulting array values are identical either way.
    """
    return tensor.detach().cpu().numpy()
class InferenceMetric(Metric):
    """Metric class for collecting reference and predicted values for scatterplot visualization."""
    def __init__(self):
        """Register one torchmetrics state per collected quantity.

        Each value state is a list of tensors appended per batch;
        ``dist_reduce_fx="cat"`` makes torchmetrics concatenate them across
        ranks when running distributed, which is why ``_process_data`` must
        accept either a list of tensors or an already-concatenated tensor.
        """
        super().__init__()
        # Raw values
        self.add_state("ref_energies", default=[], dist_reduce_fx="cat")
        self.add_state("pred_energies", default=[], dist_reduce_fx="cat")
        self.add_state("ref_forces", default=[], dist_reduce_fx="cat")
        self.add_state("pred_forces", default=[], dist_reduce_fx="cat")
        self.add_state("ref_stress", default=[], dist_reduce_fx="cat")
        self.add_state("pred_stress", default=[], dist_reduce_fx="cat")
        self.add_state("ref_virials", default=[], dist_reduce_fx="cat")
        self.add_state("pred_virials", default=[], dist_reduce_fx="cat")
        self.add_state("ref_dipole", default=[], dist_reduce_fx="cat")
        self.add_state("pred_dipole", default=[], dist_reduce_fx="cat")
        # Per-atom normalized values
        self.add_state("ref_energies_per_atom", default=[], dist_reduce_fx="cat")
        self.add_state("pred_energies_per_atom", default=[], dist_reduce_fx="cat")
        self.add_state("ref_virials_per_atom", default=[], dist_reduce_fx="cat")
        self.add_state("pred_virials_per_atom", default=[], dist_reduce_fx="cat")
        self.add_state("ref_dipole_per_atom", default=[], dist_reduce_fx="cat")
        self.add_state("pred_dipole_per_atom", default=[], dist_reduce_fx="cat")
        # Store atom counts for each configuration
        self.add_state("atom_counts", default=[], dist_reduce_fx="cat")
        # Counters
        # NOTE(review): these count *batches* that contained each quantity,
        # not configurations; compute() only uses their truthiness.
        self.add_state("n_energy", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("n_forces", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("n_stress", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("n_virials", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("n_dipole", default=torch.tensor(0.0), dist_reduce_fx="sum")
    def update(self, batch, output): # pylint: disable=arguments-differ
        """Update metric states with new batch data."""
        # Calculate number of atoms per configuration
        # (difference of consecutive ``batch.ptr`` offsets — presumably the
        # PyG-style cumulative node pointer; verify against the batch class)
        atoms_per_config = batch.ptr[1:] - batch.ptr[:-1]
        self.atom_counts.append(atoms_per_config)
        # Energy
        if output.get("energy") is not None and batch.energy is not None:
            self.n_energy += 1.0
            self.ref_energies.append(batch.energy)
            self.pred_energies.append(output["energy"])
            # Per-atom normalization
            self.ref_energies_per_atom.append(batch.energy / atoms_per_config)
            self.pred_energies_per_atom.append(output["energy"] / atoms_per_config)
        # Forces
        if output.get("forces") is not None and batch.forces is not None:
            self.n_forces += 1.0
            self.ref_forces.append(batch.forces)
            self.pred_forces.append(output["forces"])
        # Stress
        if output.get("stress") is not None and batch.stress is not None:
            self.n_stress += 1.0
            self.ref_stress.append(batch.stress)
            self.pred_stress.append(output["stress"])
        # Virials
        if output.get("virials") is not None and batch.virials is not None:
            self.n_virials += 1.0
            self.ref_virials.append(batch.virials)
            self.pred_virials.append(output["virials"])
            # Per-atom normalization
            # Reshape for broadcasting — assumes virials are (n_configs, 3, 3);
            # TODO confirm against the data pipeline.
            atoms_per_config_3d = atoms_per_config.view(-1, 1, 1)
            self.ref_virials_per_atom.append(batch.virials / atoms_per_config_3d)
            self.pred_virials_per_atom.append(output["virials"] / atoms_per_config_3d)
        # Dipole
        if output.get("dipole") is not None and batch.dipole is not None:
            self.n_dipole += 1.0
            self.ref_dipole.append(batch.dipole)
            self.pred_dipole.append(output["dipole"])
            # Reshape for broadcasting — assumes dipoles are (n_configs, 3);
            # TODO confirm against the data pipeline.
            atoms_per_config_3d = atoms_per_config.view(-1, 1)
            self.ref_dipole_per_atom.append(batch.dipole / atoms_per_config_3d)
            self.pred_dipole_per_atom.append(output["dipole"] / atoms_per_config_3d)
    def _process_data(self, ref_list, pred_list):
        """Flatten a paired reference/prediction state into 1-D NumPy arrays.

        Returns ``(None, None)`` when no data was collected or the state has
        an unexpected type.
        """
        # Handle different possible states of ref_list and pred_list in distributed mode
        # Check if this is a list type object
        if isinstance(ref_list, (list, tuple)):
            if len(ref_list) == 0:
                return None, None
            ref = torch.cat(ref_list).reshape(-1)
            pred = torch.cat(pred_list).reshape(-1)
        # Handle case where ref_list is already a tensor (happens after reset in distributed mode)
        elif isinstance(ref_list, torch.Tensor):
            ref = ref_list.reshape(-1)
            pred = pred_list.reshape(-1)
        # Handle other possible types
        else:
            return None, None
        return to_numpy(ref), to_numpy(pred)
    def compute(self):
        """Compute final results for scatterplot.

        Returns a dict keyed by quantity name ("energy", "forces", "stress",
        "virials", "dipole"); each entry holds flattened "reference" /
        "predicted" NumPy arrays, plus "*_per_atom" variants where collected.
        Quantities never seen during ``update`` are omitted entirely.
        """
        results = {}
        # Process energies
        if self.n_energy:
            ref_e, pred_e = self._process_data(self.ref_energies, self.pred_energies)
            ref_e_pa, pred_e_pa = self._process_data(
                self.ref_energies_per_atom, self.pred_energies_per_atom
            )
            results["energy"] = {
                "reference": ref_e,
                "predicted": pred_e,
                "reference_per_atom": ref_e_pa,
                "predicted_per_atom": pred_e_pa,
            }
        # Process forces
        if self.n_forces:
            ref_f, pred_f = self._process_data(self.ref_forces, self.pred_forces)
            results["forces"] = {
                "reference": ref_f,
                "predicted": pred_f,
            }
        # Process stress
        if self.n_stress:
            ref_s, pred_s = self._process_data(self.ref_stress, self.pred_stress)
            results["stress"] = {
                "reference": ref_s,
                "predicted": pred_s,
            }
        # Process virials
        if self.n_virials:
            ref_v, pred_v = self._process_data(self.ref_virials, self.pred_virials)
            ref_v_pa, pred_v_pa = self._process_data(
                self.ref_virials_per_atom, self.pred_virials_per_atom
            )
            results["virials"] = {
                "reference": ref_v,
                "predicted": pred_v,
                "reference_per_atom": ref_v_pa,
                "predicted_per_atom": pred_v_pa,
            }
        # Process dipoles
        if self.n_dipole:
            ref_d, pred_d = self._process_data(self.ref_dipole, self.pred_dipole)
            ref_d_pa, pred_d_pa = self._process_data(
                self.ref_dipole_per_atom, self.pred_dipole_per_atom
            )
            results["dipole"] = {
                "reference": ref_d,
                "predicted": pred_d,
                "reference_per_atom": ref_d_pa,
                "predicted_per_atom": pred_d_pa,
            }
        return results
from .atomic_data import AtomicData
from .hdf5_dataset import HDF5Dataset, dataset_from_sharded_hdf5
from .lmdb_dataset import LMDBDataset
from .neighborhood import get_neighborhood
from .utils import (
Configuration,
Configurations,
KeySpecification,
compute_average_E0s,
config_from_atoms,
config_from_atoms_list,
load_from_xyz,
random_train_valid_split,
save_AtomicData_to_HDF5,
save_configurations_as_HDF5,
save_dataset_as_HDF5,
test_config_types,
update_keyspec_from_kwargs,
)
# Public API of this data subpackage: the names re-exported by
# ``from ... import *`` and considered stable for external callers.
__all__ = [
    "get_neighborhood",
    "Configuration",
    "Configurations",
    "random_train_valid_split",
    "load_from_xyz",
    "test_config_types",
    "config_from_atoms",
    "config_from_atoms_list",
    "AtomicData",
    "compute_average_E0s",
    "save_dataset_as_HDF5",
    "HDF5Dataset",
    "dataset_from_sharded_hdf5",
    "save_AtomicData_to_HDF5",
    "save_configurations_as_HDF5",
    "KeySpecification",
    "update_keyspec_from_kwargs",
    "LMDBDataset",
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment