Unverified Commit 251f5af2 authored by zcxzcx1, committed by GitHub

Add files via upload

parent 73ff4f3a
"""Demonstrates active learning molecular dynamics with constant temperature."""
import argparse
import os
import time
import ase.io
import numpy as np
from ase import units
from ase.md.langevin import Langevin
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
from mace.calculators.mace import MACECalculator
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--config", help="path to XYZ configurations", required=True)
parser.add_argument(
"--config_index", help="index of configuration", type=int, default=-1
)
parser.add_argument(
"--error_threshold", help="error threshold", type=float, default=0.1
)
parser.add_argument("--temperature_K", help="temperature", type=float, default=300)
parser.add_argument("--friction", help="friction", type=float, default=0.01)
parser.add_argument("--timestep", help="timestep", type=float, default=1)
parser.add_argument("--nsteps", help="number of steps", type=int, default=1000)
parser.add_argument(
"--nprint", help="number of steps between prints", type=int, default=10
)
parser.add_argument(
"--nsave", help="number of steps between saves", type=int, default=10
)
parser.add_argument(
"--ncheckerror", help="number of steps between saves", type=int, default=10
)
parser.add_argument(
"--model",
help="path to model. Use wildcards to add multiple models as committee eg "
"(`mace_*.model` to load mace_1.model, mace_2.model) ",
required=True,
)
parser.add_argument("--output", help="output path", required=True)
parser.add_argument(
"--device",
help="select device",
type=str,
choices=["cpu", "cuda"],
default="cuda",
)
parser.add_argument(
"--default_dtype",
help="set default dtype",
type=str,
choices=["float32", "float64"],
default="float64",
)
parser.add_argument(
"--compute_stress",
help="compute stress",
action="store_true",
default=False,
)
parser.add_argument(
"--info_prefix",
help="prefix for energy, forces and stress keys",
type=str,
default="MACE_",
)
return parser.parse_args()
def printenergy(dyn, start_time=None):
"""Function to print the potential, kinetic and total energy."""
a = dyn.atoms
epot = a.get_potential_energy() / len(a)
ekin = a.get_kinetic_energy() / len(a)
if start_time is None:
elapsed_time = 0
else:
elapsed_time = time.time() - start_time
forces_var = np.var(a.calc.results["forces_comm"], axis=0)
print(
"%.1fs: Energy per atom: Epot = %.3feV Ekin = %.3feV (T=%3.0fK) " # pylint: disable=C0209
"Etot = %.3feV t=%.1ffs Eerr = %.3feV Ferr = %.3feV/A"
% (
elapsed_time,
epot,
ekin,
ekin / (1.5 * units.kB),
epot + ekin,
dyn.get_time() / units.fs,
a.calc.results["energy_var"],
np.max(np.linalg.norm(forces_var, axis=1)),
),
flush=True,
)
def save_config(dyn, fname):
atomsi = dyn.atoms
ens = atomsi.get_potential_energy()
frcs = atomsi.get_forces()
atomsi.info.update(
{
"mlff_energy": ens,
"time": np.round(dyn.get_time() / units.fs, 5),
"mlff_energy_var": atomsi.calc.results["energy_var"],
}
)
atomsi.arrays.update(
{
"mlff_forces": frcs,
"mlff_forces_var": np.var(atomsi.calc.results["forces_comm"], axis=0),
}
)
ase.io.write(fname, atomsi, append=True)
def stop_error(dyn, threshold, reg=0.2):
atomsi = dyn.atoms
force_var = np.var(atomsi.calc.results["forces_comm"], axis=0)
force = atomsi.get_forces()
ferr = np.sqrt(np.sum(force_var, axis=1))
ferr_rel = ferr / (np.linalg.norm(force, axis=1) + reg)
if np.max(ferr_rel) > threshold:
print(
"Error too large {:.3}. Stopping t={:.2} fs.".format( # pylint: disable=C0209
np.max(ferr_rel), dyn.get_time() / units.fs
),
flush=True,
)
dyn.max_steps = 0  # force dyn.run() to exit at the next step
def main() -> None:
args = parse_args()
run(args)
def run(args: argparse.Namespace) -> None:
mace_fname = args.model
atoms_fname = args.config
atoms_index = args.config_index
mace_calc = MACECalculator(
model_paths=mace_fname,
device=args.device,
default_dtype=args.default_dtype,
)
NSTEPS = args.nsteps
if os.path.exists(args.output):
print("Trajectory exists. Continuing from last step.")
atoms = ase.io.read(args.output, index=-1)
len_save = len(ase.io.read(args.output, ":"))
print("Last step: ", atoms.info["time"], "Number of configs: ", len_save)
NSTEPS -= len_save * args.nsave
else:
atoms = ase.io.read(atoms_fname, index=atoms_index)
MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature_K)
atoms.calc = mace_calc
# Run constant-temperature MD with the Langevin thermostat, using the
# timestep, temperature and friction coefficient from the command-line
# arguments.
dyn = Langevin(
atoms=atoms,
timestep=args.timestep * units.fs,
temperature_K=args.temperature_K,
friction=args.friction,
)
dyn.attach(printenergy, interval=args.nprint, dyn=dyn, start_time=time.time())
dyn.attach(save_config, interval=args.nsave, dyn=dyn, fname=args.output)
dyn.attach(
stop_error, interval=args.ncheckerror, dyn=dyn, threshold=args.error_threshold
)
# Now run the dynamics
dyn.run(NSTEPS)
if __name__ == "__main__":
main()
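
For reference, a minimal usage sketch (not part of the uploaded file): it drives run() directly with an argparse.Namespace instead of the command line. All paths and the model wildcard are placeholders; the wildcard loads several models as a committee, so that MACECalculator populates the "energy_var" and "forces_comm" results used above.

# --- Hypothetical usage sketch; paths and model names are placeholders ---
import argparse

example_args = argparse.Namespace(
    config="start.xyz",          # placeholder starting structure
    config_index=-1,
    error_threshold=0.1,
    temperature_K=300.0,
    friction=0.01,
    timestep=1.0,
    nsteps=1000,
    nprint=10,
    nsave=10,
    ncheckerror=10,
    model="mace_*.model",        # wildcard -> committee of models
    output="al_md_traj.xyz",
    device="cpu",
    default_dtype="float64",
    compute_stress=False,
    info_prefix="MACE_",
)
# run(example_args)  # uncomment to launch the active-learning MD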
import argparse
import logging
import os
from typing import Dict, List, Tuple
import torch
from mace.tools.scripts_utils import extract_config_mace_model
def get_transfer_keys(num_layers: int) -> List[str]:
"""Get list of keys that need to be transferred"""
return [
"node_embedding.linear.weight",
"radial_embedding.bessel_fn.bessel_weights",
"atomic_energies_fn.atomic_energies",
"readouts.0.linear.weight",
*[f"readouts.{j}.linear.weight" for j in range(num_layers - 1)],
"scale_shift.scale",
"scale_shift.shift",
*[f"readouts.{num_layers-1}.linear_{i}.weight" for i in range(1, 3)],
] + [
s
for j in range(num_layers)
for s in [
f"interactions.{j}.linear_up.weight",
*[f"interactions.{j}.conv_tp_weights.layer{i}.weight" for i in range(4)],
f"interactions.{j}.linear.weight",
f"interactions.{j}.skip_tp.weight",
f"products.{j}.linear.weight",
]
]
def get_kmax_pairs(
max_L: int, correlation: int, num_layers: int
) -> List[Tuple[int, int]]:
"""Determine kmax pairs based on max_L and correlation"""
if correlation == 2:
raise NotImplementedError("Correlation 2 not supported yet")
if correlation == 3:
kmax_pairs = [[i, max_L] for i in range(num_layers - 1)]
kmax_pairs = kmax_pairs + [[num_layers - 1, 0]]
return kmax_pairs
raise NotImplementedError(f"Correlation {correlation} not supported")
def transfer_symmetric_contractions(
source_dict: Dict[str, torch.Tensor],
target_dict: Dict[str, torch.Tensor],
max_L: int,
correlation: int,
num_layers: int,
):
"""Transfer symmetric contraction weights from CuEq to E3nn format"""
kmax_pairs = get_kmax_pairs(max_L, correlation, num_layers)
for i, kmax in kmax_pairs:
# Get the combined weight tensor from source
wm = source_dict[f"products.{i}.symmetric_contractions.weight"]
# Get split sizes based on target dimensions
splits = []
for k in range(kmax + 1):
for suffix in ["_max", ".0", ".1"]:
key = f"products.{i}.symmetric_contractions.contractions.{k}.weights{suffix}"
target_shape = target_dict[key].shape
splits.append(target_shape[1])
# Split the weights using the calculated sizes
weights_split = torch.split(wm, splits, dim=1)
# Assign back to target dictionary
idx = 0
for k in range(kmax + 1):
target_dict[
f"products.{i}.symmetric_contractions.contractions.{k}.weights_max"
] = weights_split[idx]
target_dict[
f"products.{i}.symmetric_contractions.contractions.{k}.weights.0"
] = weights_split[idx + 1]
target_dict[
f"products.{i}.symmetric_contractions.contractions.{k}.weights.1"
] = weights_split[idx + 2]
idx += 3
def transfer_weights(
source_model: torch.nn.Module,
target_model: torch.nn.Module,
max_L: int,
correlation: int,
num_layers: int,
):
"""Transfer weights from CuEq to E3nn format"""
# Get state dicts
source_dict = source_model.state_dict()
target_dict = target_model.state_dict()
# Transfer main weights
transfer_keys = get_transfer_keys(num_layers)
for key in transfer_keys:
if key in source_dict: # Check if key exists
target_dict[key] = source_dict[key]
else:
logging.warning(f"Key {key} not found in source model")
# Transfer symmetric contractions
transfer_symmetric_contractions(
source_dict, target_dict, max_L, correlation, num_layers
)
# Squeeze linear and skip_tp weights (drop the leading CuEq dimension)
for key in source_dict.keys():
if any(x in key for x in ["linear", "skip_tp"]) and "weight" in key:
target_dict[key] = target_dict[key].squeeze(0)
# Transfer remaining matching keys
transferred_keys = set(transfer_keys)
remaining_keys = (
set(source_dict.keys()) & set(target_dict.keys()) - transferred_keys
)
remaining_keys = {k for k in remaining_keys if "symmetric_contraction" not in k}
if remaining_keys:
for key in remaining_keys:
if source_dict[key].shape == target_dict[key].shape:
logging.debug(f"Transferring additional key: {key}")
target_dict[key] = source_dict[key]
else:
logging.warning(
f"Shape mismatch for key {key}: "
f"source {source_dict[key].shape} vs target {target_dict[key].shape}"
)
# Transfer avg_num_neighbors
for i in range(num_layers):
target_model.interactions[i].avg_num_neighbors = source_model.interactions[i].avg_num_neighbors
# Load state dict into target model
target_model.load_state_dict(target_dict)
def run(input_model, output_model="_e3nn.model", device="cpu", return_model=True):
# Load CuEq model
if isinstance(input_model, str):
source_model = torch.load(input_model, map_location=device)
else:
source_model = input_model
default_dtype = next(source_model.parameters()).dtype
torch.set_default_dtype(default_dtype)
# Extract configuration
config = extract_config_mace_model(source_model)
# Get max_L and correlation from config
max_L = config["hidden_irreps"].lmax
correlation = config["correlation"]
# Remove CuEq config
config.pop("cueq_config", None)
# Create new model without CuEq config
logging.info("Creating new model without CuEq settings")
target_model = source_model.__class__(**config)
# Transfer weights with proper remapping
num_layers = config["num_interactions"]
transfer_weights(source_model, target_model, max_L, correlation, num_layers)
if return_model:
return target_model
# Save model
if isinstance(input_model, str):
base = os.path.splitext(input_model)[0]
output_model = f"{base}.{output_model}"
logging.warning(f"Saving E3nn model to {output_model}")
torch.save(target_model, output_model)
return None
def main():
parser = argparse.ArgumentParser()
parser.add_argument("input_model", help="Path to input CuEq model")
parser.add_argument(
"--output_model", help="Path to output E3nn model", default="e3nn_model.pt"
)
parser.add_argument("--device", default="cpu", help="Device to use")
parser.add_argument(
"--return_model",
action="store_false",
help="Return model instead of saving to file",
)
args = parser.parse_args()
run(
input_model=args.input_model,
output_model=args.output_model,
device=args.device,
return_model=args.return_model,
)
if __name__ == "__main__":
main()
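
A minimal usage sketch for the converter above, assuming a CuEq-trained checkpoint at a placeholder path; with return_model=True (the default) run() returns the rebuilt e3nn model instead of saving it.

# --- Hypothetical usage sketch; the checkpoint path is a placeholder ---
import torch

cueq_model = torch.load("mace_cueq.model", map_location="cpu")
e3nn_model = run(cueq_model, device="cpu", return_model=True)
# The converted model should keep the dtype of the source parameters.
assert next(e3nn_model.parameters()).dtype == next(cueq_model.parameters()).dtype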
from argparse import ArgumentParser
import torch
def main():
parser = ArgumentParser()
parser.add_argument(
"--target_device",
"-t",
help="device to convert to, usually 'cpu' or 'cuda'",
default="cpu",
)
parser.add_argument(
"--output_file",
"-o",
help="name for output model, defaults to model_file.target_device",
)
parser.add_argument("model_file", help="input model file path")
args = parser.parse_args()
if args.output_file is None:
args.output_file = args.model_file + "." + args.target_device
model = torch.load(args.model_file, weights_only=False)
model.to(args.target_device)
torch.save(model, args.output_file)
if __name__ == "__main__":
main()
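
A hedged check of the script's output, assuming it was invoked as `python <this script> -t cpu my_model.pt` (placeholder file name); by default the converted model is written to `my_model.pt.cpu`.

# --- Hypothetical verification sketch; file names are placeholders ---
import torch

converted = torch.load("my_model.pt.cpu", weights_only=False)
assert next(converted.parameters()).device.type == "cpu"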
import argparse
import logging
import os
from typing import Dict, List, Tuple
import torch
from mace.modules.wrapper_ops import CuEquivarianceConfig
from mace.tools.scripts_utils import extract_config_mace_model
def get_transfer_keys(num_layers: int) -> List[str]:
"""Get list of keys that need to be transferred"""
return [
"node_embedding.linear.weight",
"radial_embedding.bessel_fn.bessel_weights",
"atomic_energies_fn.atomic_energies",
"readouts.0.linear.weight",
*[f"readouts.{j}.linear.weight" for j in range(num_layers - 1)],
"scale_shift.scale",
"scale_shift.shift",
*[f"readouts.{num_layers-1}.linear_{i}.weight" for i in range(1, 3)],
] + [
s
for j in range(num_layers)
for s in [
f"interactions.{j}.linear_up.weight",
*[f"interactions.{j}.conv_tp_weights.layer{i}.weight" for i in range(4)],
f"interactions.{j}.linear.weight",
f"interactions.{j}.skip_tp.weight",
f"products.{j}.linear.weight",
]
]
def get_kmax_pairs(
max_L: int, correlation: int, num_layers: int
) -> List[Tuple[int, int]]:
"""Determine kmax pairs based on max_L and correlation"""
if correlation == 2:
raise NotImplementedError("Correlation 2 not supported yet")
if correlation == 3:
kmax_pairs = [[i, max_L] for i in range(num_layers - 1)]
kmax_pairs = kmax_pairs + [[num_layers - 1, 0]]
return kmax_pairs
raise NotImplementedError(f"Correlation {correlation} not supported")
def transfer_symmetric_contractions(
source_dict: Dict[str, torch.Tensor],
target_dict: Dict[str, torch.Tensor],
max_L: int,
correlation: int,
num_layers: int,
):
"""Transfer symmetric contraction weights"""
kmax_pairs = get_kmax_pairs(max_L, correlation, num_layers)
for i, kmax in kmax_pairs:
wm = torch.concatenate(
[
source_dict[
f"products.{i}.symmetric_contractions.contractions.{k}.weights{j}"
]
for k in range(kmax + 1)
for j in ["_max", ".0", ".1"]
],
dim=1,
)
target_dict[f"products.{i}.symmetric_contractions.weight"] = wm
def transfer_weights(
source_model: torch.nn.Module,
target_model: torch.nn.Module,
max_L: int,
correlation: int,
num_layers: int,
):
"""Transfer weights with proper remapping"""
# Get source state dict
source_dict = source_model.state_dict()
target_dict = target_model.state_dict()
# Transfer main weights
transfer_keys = get_transfer_keys(num_layers)
for key in transfer_keys:
if key in source_dict: # Check if key exists
target_dict[key] = source_dict[key]
else:
logging.warning(f"Key {key} not found in source model")
# Transfer symmetric contractions
transfer_symmetric_contractions(
source_dict, target_dict, max_L, correlation, num_layers
)
# Unsqueeze linear and skip_tp layers
for key in source_dict.keys():
if any(x in key for x in ["linear", "skip_tp"]) and "weight" in key:
target_dict[key] = target_dict[key].unsqueeze(0)
transferred_keys = set(transfer_keys)
remaining_keys = (
set(source_dict.keys()) & set(target_dict.keys()) - transferred_keys
)
remaining_keys = {k for k in remaining_keys if "symmetric_contraction" not in k}
if remaining_keys:
for key in remaining_keys:
if source_dict[key].shape == target_dict[key].shape:
logging.debug(f"Transferring additional key: {key}")
target_dict[key] = source_dict[key]
else:
logging.warning(
f"Shape mismatch for key {key}: "
f"source {source_dict[key].shape} vs target {target_dict[key].shape}"
)
# Transfer avg_num_neighbors
for i in range(num_layers):
target_model.interactions[i].avg_num_neighbors = source_model.interactions[i].avg_num_neighbors
# Load state dict into target model
target_model.load_state_dict(target_dict)
def run(
input_model,
output_model="_cueq.model",
device="cpu",
return_model=True,
):
# Load the original model; input_model may be a path or an in-memory model
if isinstance(input_model, str):
source_model = torch.load(input_model, map_location=device)
else:
source_model = input_model
default_dtype = next(source_model.parameters()).dtype
torch.set_default_dtype(default_dtype)
# Extract configuration
config = extract_config_mace_model(source_model)
# Get max_L and correlation from config
max_L = config["hidden_irreps"].lmax
correlation = config["correlation"]
# Add cuequivariance config
config["cueq_config"] = CuEquivarianceConfig(
enabled=True,
layout="ir_mul",
group="O3_e3nn",
optimize_all=True,
)
# Create new model with cuequivariance config
logging.info("Creating new model with cuequivariance settings")
target_model = source_model.__class__(**config).to(device)
# Transfer weights with proper remapping
num_layers = config["num_interactions"]
transfer_weights(source_model, target_model, max_L, correlation, num_layers)
if return_model:
return target_model
if isinstance(input_model, str):
base = os.path.splitext(input_model)[0]
output_model = f"{base}.{output_model}"
logging.warning(f"Saving CuEq model to {output_model}")
torch.save(target_model, output_model)
return None
def main():
parser = argparse.ArgumentParser()
parser.add_argument("input_model", help="Path to input MACE model")
parser.add_argument(
"--output_model",
help="Path to output cuequivariance model",
default="cueq_model.pt",
)
parser.add_argument("--device", default="cpu", help="Device to use")
parser.add_argument(
"--return_model",
action="store_false",
help="Return model instead of saving to file",
)
args = parser.parse_args()
run(
input_model=args.input_model,
output_model=args.output_model,
device=args.device,
return_model=args.return_model,
)
if __name__ == "__main__":
main()
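
A usage sketch mirroring the earlier converter, with a placeholder path; it assumes the cuequivariance package is installed. Setting return_model=False would instead write `<base>._cueq.model` next to the input.

# --- Hypothetical usage sketch; the checkpoint path is a placeholder ---
import torch

e3nn_model = torch.load("mace.model", map_location="cpu")
cueq_model = run(e3nn_model, device="cpu", return_model=True)
# A round trip back to e3nn is possible with mace.cli.convert_cueq_e3nn.run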
# pylint: disable=wrong-import-position
import argparse
import copy
import os
os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1"
import torch
from e3nn.util import jit
from mace.calculators import LAMMPS_MACE
from mace.calculators.lammps_mliap_mace import LAMMPS_MLIAP_MACE
from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq
def parse_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"model_path",
type=str,
help="Path to the model to be converted to LAMMPS",
)
parser.add_argument(
"--head",
type=str,
nargs="?",
help="Head of the model to be converted to LAMMPS",
default=None,
)
parser.add_argument(
"--dtype",
type=str,
nargs="?",
help="Data type of the model to be converted to LAMMPS",
default="float64",
)
parser.add_argument(
"--format",
type=str,
help="Old libtorch format, or new mliap format",
default="libtorch",
)
return parser.parse_args()
def select_head(model):
if hasattr(model, "heads"):
heads = model.heads
else:
heads = [None]
if len(heads) == 1:
print(f"Only one head found in the model: {heads[0]}. Skipping selection.")
return heads[0]
print("Available heads in the model:")
for i, head in enumerate(heads):
print(f"{i + 1}: {head}")
# Ask the user to select a head
selected = input(
f"Select a head by number (Defaulting to head: {len(heads)}, press Enter to accept): "
)
if selected.isdigit() and 1 <= int(selected) <= len(heads):
return heads[int(selected) - 1]
if selected == "":
print("No head selected. Proceeding without specifying a head.")
return None
print(f"No valid selection made. Defaulting to the last head: {heads[-1]}")
return heads[-1]
def main():
args = parse_args()
model_path = args.model_path # takes model name as command-line input
model = torch.load(
model_path,
map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
if args.dtype == "float64":
model = model.double().to("cpu")
elif args.dtype == "float32":
print("Converting model to float32, this may cause loss of precision.")
model = model.float().to("cpu")
if args.format == "mliap":
# Enabling cuequivariance by default. TODO: switch?
model = run_e3nn_to_cueq(copy.deepcopy(model))
model.lammps_mliap = True
if args.head is None:
head = select_head(model)
else:
head = args.head
print(
f"Selected head {head} from the command line; available heads: {model.heads}"
)
lammps_class = LAMMPS_MLIAP_MACE if args.format == "mliap" else LAMMPS_MACE
lammps_model = (
lammps_class(model, head=head) if head is not None else lammps_class(model)
)
if args.format == "mliap":
torch.save(lammps_model, model_path + "-mliap_lammps.pt")
else:
lammps_model_compiled = jit.compile(lammps_model)
lammps_model_compiled.save(model_path + "-lammps.pt")
if __name__ == "__main__":
main()
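
A condensed sketch of the libtorch path in main() above for a single-head model, using a placeholder model file.

# --- Hypothetical usage sketch; the model path is a placeholder ---
import torch
from e3nn.util import jit
from mace.calculators import LAMMPS_MACE

model = torch.load("my_mace.model", map_location="cpu").double()
lammps_model = LAMMPS_MACE(model)  # no head kwarg needed for a single-head model
jit.compile(lammps_model).save("my_mace.model-lammps.pt")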
###########################################################################################
# Script for evaluating configurations contained in an xyz file with a trained model
# Authors: Ilyes Batatia, Gregor Simm
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
import argparse
import ase.data
import ase.io
import numpy as np
import torch
from mace import data
from mace.tools import torch_geometric, torch_tools, utils
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--configs", help="path to XYZ configurations", required=True)
parser.add_argument("--model", help="path to model", required=True)
parser.add_argument("--output", help="output path", required=True)
parser.add_argument(
"--device",
help="select device",
type=str,
choices=["cpu", "cuda"],
default="cpu",
)
parser.add_argument(
"--default_dtype",
help="set default dtype",
type=str,
choices=["float32", "float64"],
default="float64",
)
parser.add_argument("--batch_size", help="batch size", type=int, default=64)
parser.add_argument(
"--compute_stress",
help="compute stress",
action="store_true",
default=False,
)
parser.add_argument(
"--return_contributions",
help="model outputs energy contributions for each body order, only supported for MACE, not ScaleShiftMACE",
action="store_true",
default=False,
)
parser.add_argument(
"--info_prefix",
help="prefix for energy, forces and stress keys",
type=str,
default="MACE_",
)
parser.add_argument(
"--head",
help="Model head used for evaluation",
type=str,
required=False,
default=None,
)
return parser.parse_args()
def main() -> None:
args = parse_args()
run(args)
def run(args: argparse.Namespace) -> None:
torch_tools.set_default_dtype(args.default_dtype)
device = torch_tools.init_device(args.device)
# Load model
model = torch.load(f=args.model, map_location=args.device)
model = model.to(
args.device
) # shouldn't be necessary but seems to help with CUDA problems
for param in model.parameters():
param.requires_grad = False
# Load data and prepare input
atoms_list = ase.io.read(args.configs, index=":")
if args.head is not None:
for atoms in atoms_list:
atoms.info["head"] = args.head
configs = [data.config_from_atoms(atoms) for atoms in atoms_list]
z_table = utils.AtomicNumberTable([int(z) for z in model.atomic_numbers])
try:
heads = model.heads
except AttributeError:
heads = None
data_loader = torch_geometric.dataloader.DataLoader(
dataset=[
data.AtomicData.from_config(
config, z_table=z_table, cutoff=float(model.r_max), heads=heads
)
for config in configs
],
batch_size=args.batch_size,
shuffle=False,
drop_last=False,
)
# Collect data
energies_list = []
contributions_list = []
stresses_list = []
forces_collection = []
for batch in data_loader:
batch = batch.to(device)
output = model(batch.to_dict(), compute_stress=args.compute_stress)
energies_list.append(torch_tools.to_numpy(output["energy"]))
if args.compute_stress:
stresses_list.append(torch_tools.to_numpy(output["stress"]))
if args.return_contributions:
contributions_list.append(torch_tools.to_numpy(output["contributions"]))
forces = np.split(
torch_tools.to_numpy(output["forces"]),
indices_or_sections=batch.ptr[1:],
axis=0,
)
forces_collection.append(forces[:-1])  # drop the last split, which is empty
energies = np.concatenate(energies_list, axis=0)
forces_list = [
forces for forces_list in forces_collection for forces in forces_list
]
assert len(atoms_list) == len(energies) == len(forces_list)
if args.compute_stress:
stresses = np.concatenate(stresses_list, axis=0)
assert len(atoms_list) == stresses.shape[0]
if args.return_contributions:
contributions = np.concatenate(contributions_list, axis=0)
assert len(atoms_list) == contributions.shape[0]
# Store data in atoms objects
for i, (atoms, energy, forces) in enumerate(zip(atoms_list, energies, forces_list)):
atoms.calc = None # crucial
atoms.info[args.info_prefix + "energy"] = energy
atoms.arrays[args.info_prefix + "forces"] = forces
if args.compute_stress:
atoms.info[args.info_prefix + "stress"] = stresses[i]
if args.return_contributions:
atoms.info[args.info_prefix + "BO_contributions"] = contributions[i]
# Write atoms to output path
ase.io.write(args.output, images=atoms_list, format="extxyz")
if __name__ == "__main__":
main()
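
A minimal programmatic sketch of run() above, equivalent to invoking the script with --configs/--model/--output; all paths are placeholders.

# --- Hypothetical usage sketch; paths are placeholders ---
import argparse

eval_args = argparse.Namespace(
    configs="test_configs.xyz",
    model="my_mace.model",
    output="evaluated.xyz",
    device="cpu",
    default_dtype="float64",
    batch_size=64,
    compute_stress=False,
    return_contributions=False,
    info_prefix="MACE_",
    head=None,
)
# run(eval_args)  # writes MACE_energy / MACE_forces into evaluated.xyz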
###########################################################################################
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
from __future__ import annotations
import argparse
import logging
from dataclasses import dataclass
from enum import Enum
from typing import List, Tuple, Union
import ase.data
import ase.io
import numpy as np
import torch
from mace.calculators import MACECalculator, mace_mp
try:
import fpsample # type: ignore
except ImportError:
pass
class FilteringType(Enum):
NONE = "none"
COMBINATIONS = "combinations"
EXCLUSIVE = "exclusive"
INCLUSIVE = "inclusive"
class SubselectType(Enum):
FPS = "fps"
RANDOM = "random"
@dataclass
class SelectionSettings:
configs_pt: str
output: str
configs_ft: str | None = None
atomic_numbers: List[int] | None = None
num_samples: int | None = None
subselect: SubselectType = SubselectType.FPS
model: str = "small"
descriptors: str | None = None
device: str = "cpu"
default_dtype: str = "float64"
head_pt: str | None = None
head_ft: str | None = None
filtering_type: FilteringType = FilteringType.COMBINATIONS
weight_ft: float = 1.0
weight_pt: float = 1.0
seed: int = 42
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--configs_pt",
help="path to XYZ configurations for the pretraining",
required=True,
)
parser.add_argument(
"--configs_ft",
help="path or list of paths to XYZ configurations for the finetuning",
required=False,
default=None,
)
parser.add_argument(
"--num_samples",
help="number of samples to select for the pretraining",
type=int,
required=False,
default=None,
)
parser.add_argument(
"--subselect",
help="method to subselect the configurations of the pretraining set",
type=SubselectType,
choices=list(SubselectType),
default=SubselectType.FPS,
)
parser.add_argument(
"--model", help="path to model", default="small", required=False
)
parser.add_argument("--output", help="output path", required=True)
parser.add_argument(
"--descriptors", help="path to descriptors", required=False, default=None
)
parser.add_argument(
"--device",
help="select device",
type=str,
choices=["cpu", "cuda"],
default="cpu",
)
parser.add_argument(
"--default_dtype",
help="set default dtype",
type=str,
choices=["float32", "float64"],
default="float64",
)
parser.add_argument(
"--head_pt",
help="level of head for the pretraining set",
type=str,
default=None,
)
parser.add_argument(
"--head_ft",
help="level of head for the finetuning set",
type=str,
default=None,
)
parser.add_argument(
"--filtering_type",
help="filtering type",
type=FilteringType,
choices=list(FilteringType),
default=FilteringType.NONE,
)
parser.add_argument(
"--weight_ft",
help="weight for the finetuning set",
type=float,
default=1.0,
)
parser.add_argument(
"--weight_pt",
help="weight for the pretraining set",
type=float,
default=1.0,
)
parser.add_argument("--seed", help="random seed", type=int, default=42)
return parser.parse_args()
def calculate_descriptors(atoms: List[ase.Atoms], calc: MACECalculator) -> None:
logging.info("Calculating descriptors")
for mol in atoms:
descriptors = calc.get_descriptors(mol.copy(), invariants_only=True)
# average descriptors over atoms for each element
descriptors_dict = {
element: np.mean(descriptors[mol.symbols == element], axis=0)
for element in np.unique(mol.symbols)
}
mol.info["mace_descriptors"] = descriptors_dict
def filter_atoms(
atoms: ase.Atoms,
element_subset: List[str],
filtering_type: FilteringType = FilteringType.COMBINATIONS,
) -> bool:
"""
Filters atoms based on the provided filtering type and element subset.
Parameters:
atoms (ase.Atoms): The atoms object to filter.
element_subset (list): The list of elements to consider during filtering.
filtering_type (FilteringType): The type of filtering to apply.
Can be one of the following `FilteringType` enum members:
- `FilteringType.NONE`: No filtering is applied.
- `FilteringType.COMBINATIONS`: Return true if `atoms` is composed of combinations of elements in the subset, false otherwise. I.e. does not require all of the specified elements to be present.
- `FilteringType.EXCLUSIVE`: Return true if `atoms` contains *only* elements in the subset, false otherwise.
- `FilteringType.INCLUSIVE`: Return true if `atoms` contains all elements in the subset, false otherwise. I.e. allows additional elements.
Returns:
bool: True if the atoms pass the filter, False otherwise.
"""
if filtering_type == FilteringType.NONE:
return True
if filtering_type == FilteringType.COMBINATIONS:
atom_symbols = np.unique(atoms.symbols)
return all(
x in element_subset for x in atom_symbols
) # atoms must *only* contain elements in the subset
if filtering_type == FilteringType.EXCLUSIVE:
atom_symbols = set(list(atoms.symbols))
return atom_symbols == set(element_subset)
if filtering_type == FilteringType.INCLUSIVE:
atom_symbols = np.unique(atoms.symbols)
return all(
x in atom_symbols for x in element_subset
) # atoms must *at least* contain elements in the subset
raise ValueError(
f"Filtering type {filtering_type} not recognised. Must be one of {list(FilteringType)}."
)
class FPS:
def __init__(self, atoms_list: List[ase.Atoms], n_samples: int):
self.n_samples = n_samples
self.atoms_list = atoms_list
self.species = np.unique([x.symbol for atoms in atoms_list for x in atoms]) # type: ignore
self.species_dict = {x: i for i, x in enumerate(self.species)}
# start from a random configuration
self.list_index = [np.random.randint(0, len(atoms_list))]
self.assemble_descriptors()
def run(
self,
) -> List[int]:
"""
Run the farthest point sampling algorithm.
"""
descriptor_dataset_reshaped = (
self.descriptors_dataset.reshape( # pylint: disable=E1121
(len(self.atoms_list), -1)
)
)
logging.info(f"{descriptor_dataset_reshaped.shape}")
logging.info(f"n_samples: {self.n_samples}")
self.list_index = fpsample.fps_npdu_kdtree_sampling(
descriptor_dataset_reshaped,
self.n_samples,
)
return self.list_index
def assemble_descriptors(self) -> None:
"""
Assemble the descriptors for all the configurations.
"""
self.descriptors_dataset: np.ndarray = 10e10 * np.ones(
(
len(self.atoms_list),
len(self.species),
len(list(self.atoms_list[0].info["mace_descriptors"].values())[0]),
),
dtype=np.float32,
)
for i, atoms in enumerate(self.atoms_list):
descriptors = atoms.info["mace_descriptors"]
for z in descriptors:
self.descriptors_dataset[i, self.species_dict[z]] = np.array(
descriptors[z]
).astype(np.float32)
def _load_calc(
model: str, device: str, default_dtype: str, subselect: SubselectType
) -> Union[MACECalculator, None]:
if subselect == SubselectType.RANDOM:
return None
if model in ["small", "medium", "large"]:
calc = mace_mp(model, device=device, default_dtype=default_dtype)
else:
calc = MACECalculator(
model_paths=model,
device=device,
default_dtype=default_dtype,
)
return calc
def _get_finetuning_elements(
atoms: List[ase.Atoms], atomic_numbers: List[int] | None
) -> List[str]:
if atoms:
logging.debug(
"Using elements from the finetuning configurations for filtering."
)
species = np.unique([x.symbol for atoms in atoms for x in atoms]).tolist() # type: ignore
elif atomic_numbers is not None and atomic_numbers:
logging.debug("Using the supplied atomic numbers for filtering.")
species = [ase.data.chemical_symbols[z] for z in atomic_numbers]
else:
species = []
return species
def _read_finetuning_configs(
configs_ft: Union[str, list[str], None],
) -> List[ase.Atoms]:
if isinstance(configs_ft, str):
path = configs_ft
return ase.io.read(path, index=":") # type: ignore
if isinstance(configs_ft, list):
assert all(isinstance(x, str) for x in configs_ft)
atoms_list_ft = []
for path in configs_ft:
atoms_list_ft += ase.io.read(path, index=":")
return atoms_list_ft
if configs_ft is None:
return []
raise ValueError(f"Invalid type for configs_ft: {type(configs_ft)}")
def _filter_pretraining_data(
atoms: list[ase.Atoms],
filtering_type: FilteringType,
all_species_ft: List[str],
) -> Tuple[List[ase.Atoms], List[ase.Atoms], list[bool]]:
logging.info(
"Filtering configurations based on the finetuning set, "
f"filtering type: {filtering_type}, elements: {all_species_ft}"
)
passes_filter = [filter_atoms(x, all_species_ft, filtering_type) for x in atoms]
assert len(passes_filter) == len(atoms), "Filtering failed"
filtered_atoms = [x for x, passes in zip(atoms, passes_filter) if passes]
remaining_atoms = [x for x, passes in zip(atoms, passes_filter) if not passes]
return filtered_atoms, remaining_atoms, passes_filter
def _get_random_configs(
num_samples: int,
atoms: List[ase.Atoms],
) -> list[ase.Atoms]:
if num_samples > len(atoms):
raise ValueError(
f"Requested more samples ({num_samples}) than available in the remaining set ({len(atoms)})"
)
indices = np.random.choice(list(range(len(atoms))), num_samples, replace=False)
return [atoms[i] for i in indices]
def _load_descriptors(
atoms: List[ase.Atoms],
passes_filter: List[bool],
descriptors_path: str | None,
calc: MACECalculator | None,
full_data_length: int,
) -> None:
if descriptors_path is not None:
logging.info(f"Loading descriptors from {descriptors_path}")
descriptors = np.load(descriptors_path, allow_pickle=True)
assert sum(passes_filter) == len(atoms)
if len(descriptors) != full_data_length:
raise ValueError(
f"Length of the descriptors ({len(descriptors)}) does not match the length of the data ({full_data_length})"
"Please provide descriptors for all configurations"
)
required_descriptors = [
descriptors[i] for i, passes in enumerate(passes_filter) if passes
]
for i, atoms_ in enumerate(atoms):
atoms_.info["mace_descriptors"] = required_descriptors[i]
else:
logging.info("Calculating descriptors")
if calc is None:
raise ValueError("MACECalculator must be provided to calculate descriptors")
calculate_descriptors(atoms, calc)
def _maybe_save_descriptors(
atoms: List[ase.Atoms],
output_path: str,
) -> None:
"""
Save the descriptors if they are present in the atoms objects.
Also, delete the descriptors from the atoms objects.
"""
if all("mace_descriptors" in x.info for x in atoms):
descriptor_save_path = output_path.replace(".xyz", "_descriptors.npy")
logging.info(f"Saving descriptors at {descriptor_save_path}")
descriptors_list = [x.info["mace_descriptors"] for x in atoms]
np.save(descriptor_save_path, descriptors_list, allow_pickle=True)
for x in atoms:
del x.info["mace_descriptors"]
def _maybe_fps(atoms: List[ase.Atoms], num_samples: int) -> List[ase.Atoms]:
try:
fps_pt = FPS(atoms, num_samples)
idx_pt = fps_pt.run()
logging.info(f"Selected {len(idx_pt)} configurations")
return [atoms[i] for i in idx_pt]
except Exception as e: # pylint: disable=W0703
logging.error(f"FPS failed, selecting random configurations instead: {e}")
return _get_random_configs(num_samples, atoms)
def _subsample_data(
filtered_atoms: List[ase.Atoms],
remaining_atoms: List[ase.Atoms],
passes_filter: List[bool],
num_samples: int | None,
subselect: SubselectType,
descriptors_path: str | None,
calc: MACECalculator | None,
) -> List[ase.Atoms]:
if num_samples is None or num_samples == len(filtered_atoms):
logging.info(
f"No subsampling, keeping all {len(filtered_atoms)} filtered configurations"
)
return filtered_atoms
if num_samples > len(filtered_atoms):
num_sample_randomly = num_samples - len(filtered_atoms)
logging.info(
f"Number of configurations after filtering {len(filtered_atoms)} "
f"is less than the number of samples {num_samples}, "
f"selecting {num_sample_randomly} random configurations for the rest."
)
return filtered_atoms + _get_random_configs(
num_sample_randomly, remaining_atoms
)
if num_samples == 0:
raise ValueError("Number of samples must be greater than 0")
if subselect == SubselectType.FPS:
_load_descriptors(
filtered_atoms,
passes_filter,
descriptors_path,
calc,
full_data_length=len(filtered_atoms) + len(remaining_atoms),
)
logging.info("Selecting configurations using Farthest Point Sampling")
return _maybe_fps(filtered_atoms, num_samples)
if subselect == SubselectType.RANDOM:
return _get_random_configs(num_samples, filtered_atoms)
raise ValueError(f"Invalid subselect type: {subselect}")
def _write_metadata(
atoms: list[ase.Atoms], pretrained: bool, config_weight: float, head: str | None
) -> None:
for a in atoms:
a.info["pretrained"] = pretrained
a.info["config_weight"] = config_weight
if head is not None:
a.info["head"] = head
def select_samples(
settings: SelectionSettings,
) -> None:
np.random.seed(settings.seed)
torch.manual_seed(settings.seed)
calc = _load_calc(
settings.model, settings.device, settings.default_dtype, settings.subselect
)
atoms_list_ft = _read_finetuning_configs(settings.configs_ft)
all_species_ft = _get_finetuning_elements(atoms_list_ft, settings.atomic_numbers)
if settings.filtering_type is not FilteringType.NONE and not all_species_ft:
raise ValueError(
"Filtering types other than NONE require elements for filtering. They can be specified via the `--atomic_numbers` flag."
)
atoms_list_pt: list[ase.Atoms] = ase.io.read(settings.configs_pt, index=":") # type: ignore
filtered_pt_atoms, remaining_atoms, passes_filter = _filter_pretraining_data(
atoms_list_pt, settings.filtering_type, all_species_ft
)
subsampled_atoms = _subsample_data(
filtered_pt_atoms,
remaining_atoms,
passes_filter,
settings.num_samples,
settings.subselect,
settings.descriptors,
calc,
)
_maybe_save_descriptors(subsampled_atoms, settings.output)
_write_metadata(
subsampled_atoms,
pretrained=True,
config_weight=settings.weight_pt,
head=settings.head_pt,
)
_write_metadata(
atoms_list_ft,
pretrained=False,
config_weight=settings.weight_ft,
head=settings.head_ft,
)
logging.info("Saving the selected configurations")
ase.io.write(settings.output, subsampled_atoms, format="extxyz")
logging.info("Saving a combined XYZ file")
atoms_fps_pt_ft = subsampled_atoms + atoms_list_ft
ase.io.write(
settings.output.replace(".xyz", "_combined.xyz"),
atoms_fps_pt_ft,
format="extxyz",
)
def main():
args = parse_args()
settings = SelectionSettings(**vars(args))
select_samples(settings)
if __name__ == "__main__":
main()
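
A minimal sketch of calling select_samples() directly through the SelectionSettings dataclass defined above; file names are placeholders, and model="small" downloads the corresponding mace_mp foundation model.

# --- Hypothetical usage sketch; paths are placeholders ---
example_settings = SelectionSettings(
    configs_pt="pretraining_pool.xyz",
    configs_ft="finetuning_set.xyz",
    output="selected_pt.xyz",
    num_samples=100,
    subselect=SubselectType.FPS,
    filtering_type=FilteringType.COMBINATIONS,
    model="small",
    device="cpu",
)
# select_samples(example_settings)  # also writes selected_pt_combined.xyz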
import argparse
import dataclasses
import glob
import json
import os
import re
from typing import List
import matplotlib.pyplot as plt
import pandas as pd
plt.rcParams.update({"font.size": 8})
plt.style.use("seaborn-v0_8-paper")
colors = [
"#1f77b4", # muted blue
"#d62728", # brick red
"#ff7f0e", # safety orange
"#2ca02c", # cooked asparagus green
"#9467bd", # muted purple
"#8c564b", # chestnut brown
"#e377c2", # raspberry yogurt pink
"#7f7f7f", # middle gray
"#bcbd22", # curry yellow-green
"#17becf", # blue-teal
]
@dataclasses.dataclass
class RunInfo:
name: str
seed: int
name_re = re.compile(r"(?P<name>.+)_run-(?P<seed>\d+)_train.txt")
def parse_path(path: str) -> RunInfo:
match = name_re.match(os.path.basename(path))
if not match:
raise RuntimeError(f"Cannot parse {path}")
return RunInfo(name=match.group("name"), seed=int(match.group("seed")))
def parse_training_results(path: str) -> List[dict]:
run_info = parse_path(path)
results = []
with open(path, mode="r", encoding="utf-8") as f:
for line in f:
d = json.loads(line)
d["name"] = run_info.name
d["seed"] = run_info.seed
results.append(d)
return results
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Plot mace training statistics",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--path", help="Path to results file (.txt) or directory.", required=True
)
parser.add_argument(
"--min_epoch", help="Minimum epoch.", default=0, type=int, required=False
)
parser.add_argument(
"--start_stage_two",
"--start_swa",
help="Epoch that stage two (swa) loss began. Plots dashed line on plot to indicate. If None then assumed tag not used in training.",
default=None,
type=int,
required=False,
dest="start_swa",
)
parser.add_argument(
"--linear",
help="Whether to plot linear instead of log scales.",
default=False,
required=False,
action="store_true",
)
parser.add_argument(
"--error_bars",
help="Whether to plot standard deviations.",
default=False,
required=False,
action="store_true",
)
parser.add_argument(
"--keys",
help="Comma-separated list of keys to plot.",
default="rmse_e,rmse_f",
type=str,
required=False,
)
parser.add_argument(
"--output_format",
help="What file type to save plot as",
default="png",
type=str,
required=False,
)
parser.add_argument(
"--heads",
help="Comma-separated name of the heads used for multihead training",
default=None,
type=str,
required=False,
)
return parser.parse_args()
def plot(
data: pd.DataFrame,
min_epoch: int,
output_path: str,
output_format: str,
linear: bool,
start_swa: int,
error_bars: bool,
keys: str,
heads: str,
) -> None:
"""
Plots training and validation loss and errors as a function of epoch.
min_epoch: minimum epoch to plot.
output_path: path to save the plot.
output_format: format to save the plot.
start_swa: whether to plot a dashed line to show epoch when stage two loss (swa) begins.
error_bars: whether to plot standard deviation of loss.
linear: whether to plot in linear scale or logscale (default).
keys: Values to plot.
heads: Heads used for multihead training.
"""
labels = {
"mae_e": "MAE E [meV]",
"mae_e_per_atom": "MAE E/atom [meV]",
"rmse_e": "RMSE E [meV]",
"rmse_e_per_atom": "RMSE E/atom [meV]",
"q95_e": "Q95 E [meV]",
"mae_f": "MAE F [meV / A]",
"rel_mae_f": "Relative MAE F [meV / A]",
"rmse_f": "RMSE F [meV / A]",
"rel_rmse_f": "Relative RMSE F [meV / A]",
"q95_f": "Q95 F [meV / A]",
"mae_stress": "MAE Stress",
"rmse_stress": "RMSE Stress [meV / A^3]",
"rmse_virials_per_atom": " RMSE virials/atom [meV]",
"mae_virials": "MAE Virials [meV]",
"rmse_mu_per_atom": "RMSE MU/atom [mDebye]",
}
data = data[data["epoch"] > min_epoch]
if heads is None:
data = (
data.groupby(["name", "mode", "epoch"]).agg(["mean", "std"]).reset_index()
)
valid_data = data[data["mode"] == "eval"]
valid_data_dict = {"default": valid_data}
train_data = data[data["mode"] == "opt"]
else:
heads = heads.split(",")
# Separate eval and opt data
valid_data = (
data[data["mode"] == "eval"]
.groupby(["name", "mode", "epoch", "head"])
.agg(["mean", "std"])
.reset_index()
)
train_data = (
data[data["mode"] == "opt"]
.groupby(["name", "mode", "epoch"])
.agg(["mean", "std"])
.reset_index()
)
valid_data_dict = {
head: valid_data[valid_data["head"] == head] for head in heads
}
for head, valid_data in valid_data_dict.items():
fig, axes = plt.subplots(
nrows=1, ncols=2, figsize=(10, 3), constrained_layout=True
)
# ---- Plot loss ----
ax = axes[0]
ax.plot(
train_data["epoch"],
train_data["loss"]["mean"],
color=colors[1],
linewidth=1,
)
ax.set_ylabel("Training Loss", color=colors[1])
ax.set_yscale("log")
ax2 = ax.twinx()
ax2.plot(
valid_data["epoch"],
valid_data["loss"]["mean"],
color=colors[0],
linewidth=1,
)
ax2.set_ylabel("Validation Loss", color=colors[0])
if not linear:
ax.set_yscale("log")
ax2.set_yscale("log")
if error_bars:
ax.fill_between(
train_data["epoch"],
train_data["loss"]["mean"] - train_data["loss"]["std"],
train_data["loss"]["mean"] + train_data["loss"]["std"],
alpha=0.3,
color=colors[1],
)
ax.fill_between(
valid_data["epoch"],
valid_data["loss"]["mean"] - valid_data["loss"]["std"],
valid_data["loss"]["mean"] + valid_data["loss"]["std"],
alpha=0.3,
color=colors[0],
)
if start_swa is not None:
ax.axvline(
start_swa,
color="black",
linestyle="dashed",
linewidth=1,
alpha=0.6,
label="Stage Two Starts",
)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend(loc="upper right", fontsize=4)
ax.grid(True, linestyle="--", alpha=0.5)
# ---- Plot selected keys ----
ax = axes[1]
twin_axes = []
for i, key in enumerate(keys.split(",")):
color = colors[(i + 3)]
label = labels.get(key, key)
if i == 0:
main_ax = ax
else:
main_ax = ax.twinx()
main_ax.spines.right.set_position(("outward", 40 * (i - 1)))
twin_axes.append(main_ax)
main_ax.plot(
valid_data["epoch"],
valid_data[key]["mean"] * 1e3,
color=color,
label=label,
linewidth=1,
)
if error_bars:
main_ax.fill_between(
valid_data["epoch"],
(valid_data[key]["mean"] - valid_data[key]["std"]) * 1e3,
(valid_data[key]["mean"] + valid_data[key]["std"]) * 1e3,
alpha=0.3,
color=color,
)
main_ax.set_ylabel(label, color=color)
main_ax.tick_params(axis="y", colors=color)
if start_swa is not None:
ax.axvline(
start_swa,
color="black",
linestyle="dashed",
linewidth=1,
alpha=0.6,
label="Stage Two Starts",
)
ax.set_xlabel("Epoch")
ax.set_xlim(left=min_epoch)
ax.grid(True, linestyle="--", alpha=0.5)
fig.savefig(
f"{output_path}_{head}.{output_format}", dpi=300, bbox_inches="tight"
)
plt.close(fig)
def get_paths(path: str) -> List[str]:
if os.path.isfile(path):
return [path]
paths = glob.glob(os.path.join(path, "*_train.txt"))
if len(paths) == 0:
raise RuntimeError(f"Cannot find results in '{path}'")
return paths
def main() -> None:
args = parse_args()
run(args)
def run(args: argparse.Namespace) -> None:
data = pd.DataFrame(
results
for path in get_paths(args.path)
for results in parse_training_results(path)
)
for name, group in data.groupby("name"):
plot(
group,
min_epoch=args.min_epoch,
output_path=name,
output_format=args.output_format,
linear=args.linear,
start_swa=args.start_swa,
error_bars=args.error_bars,
keys=args.keys,
heads=args.heads,
)
if __name__ == "__main__":
main()
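
A minimal sketch of invoking run() above on a directory of training logs; the path is a placeholder and the remaining fields mirror the defaults in parse_args().

# --- Hypothetical usage sketch; the results directory is a placeholder ---
import argparse

plot_args = argparse.Namespace(
    path="results/",          # directory containing *_train.txt logs
    min_epoch=0,
    start_swa=None,
    linear=False,
    error_bars=False,
    keys="rmse_e,rmse_f",
    output_format="png",
    heads=None,
)
# run(plot_args)  # writes one <name>_<head>.png per training run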
# This file loads an xyz dataset and prepares a
# new HDF5 file that is ready for training with on-the-fly dataloading
import argparse
import ast
import json
import logging
import multiprocessing as mp
import os
import random
from functools import partial
from glob import glob
from typing import List, Tuple
import h5py
import numpy as np
import tqdm
from mace import data, tools
from mace.data import KeySpecification, update_keyspec_from_kwargs
from mace.data.utils import save_configurations_as_HDF5
from mace.modules import compute_statistics
from mace.tools import torch_geometric
from mace.tools.scripts_utils import get_atomic_energies, get_dataset_from_xyz
from mace.tools.utils import AtomicNumberTable
def compute_stats_target(
file: str,
z_table: AtomicNumberTable,
r_max: float,
atomic_energies: Tuple,
batch_size: int,
):
train_dataset = data.HDF5Dataset(file, z_table=z_table, r_max=r_max)
train_loader = torch_geometric.dataloader.DataLoader(
dataset=train_dataset,
batch_size=batch_size,
shuffle=False,
drop_last=False,
)
avg_num_neighbors, mean, std = compute_statistics(train_loader, atomic_energies)
output = [avg_num_neighbors, mean, std]
return output
def pool_compute_stats(inputs: List):
path_to_files, z_table, r_max, atomic_energies, batch_size, num_process = inputs
with mp.Pool(processes=num_process) as pool:
async_results = [
pool.apply_async(
compute_stats_target,
args=(
file,
z_table,
r_max,
atomic_energies,
batch_size,
),
)
for file in glob(path_to_files + "/*")
]
pool.close()
pool.join()
results = [r.get() for r in tqdm.tqdm(async_results)]
if not results:
raise ValueError(
"No results were computed. Check if the input files exist and are readable."
)
# Separate avg_num_neighbors, mean, and std
avg_num_neighbors = np.mean([r[0] for r in results])
means = np.array([r[1] for r in results])
stds = np.array([r[2] for r in results])
# Compute averages
mean = np.mean(means, axis=0).item()
std = np.mean(stds, axis=0).item()
return avg_num_neighbors, mean, std
def split_array(a: np.ndarray, max_size: int):
drop_last = False
if len(a) % 2 == 1:
a = np.append(a, a[-1])
drop_last = True
factors = get_prime_factors(len(a))
max_factor = 1
for i in range(1, len(factors) + 1):
for j in range(0, len(factors) - i + 1):
if np.prod(factors[j : j + i]) <= max_size:
test = np.prod(factors[j : j + i])
max_factor = max(test, max_factor)
return np.array_split(a, max_factor), drop_last
def get_prime_factors(n: int):
factors = []
for i in range(2, n + 1):
while n % i == 0:
factors.append(i)
n = n // i
return factors
# Define tasks for multiprocessing
def multi_train_hdf5(process, args, split_train, drop_last):
with h5py.File(args.h5_prefix + "train/train_" + str(process) + ".h5", "w") as f:
f.attrs["drop_last"] = drop_last
save_configurations_as_HDF5(split_train[process], process, f)
def multi_valid_hdf5(process, args, split_valid, drop_last):
with h5py.File(args.h5_prefix + "val/val_" + str(process) + ".h5", "w") as f:
f.attrs["drop_last"] = drop_last
save_configurations_as_HDF5(split_valid[process], process, f)
def multi_test_hdf5(process, name, args, split_test, drop_last):
with h5py.File(
args.h5_prefix + "test/" + name + "_" + str(process) + ".h5", "w"
) as f:
f.attrs["drop_last"] = drop_last
save_configurations_as_HDF5(split_test[process], process, f)
def main() -> None:
"""
This script loads an xyz dataset and prepares a
new HDF5 file that is ready for training with on-the-fly dataloading
"""
args = tools.build_preprocess_arg_parser().parse_args()
run(args)
def run(args: argparse.Namespace):
"""
This script loads an xyz dataset and prepares a
new HDF5 file that is ready for training with on-the-fly dataloading
"""
# currently supports only the command-line property_key syntax
args.key_specification = KeySpecification()
update_keyspec_from_kwargs(args.key_specification, vars(args))
# Setup
tools.set_seeds(args.seed)
random.seed(args.seed)
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)-8s %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
handlers=[logging.StreamHandler()],
)
try:
config_type_weights = ast.literal_eval(args.config_type_weights)
assert isinstance(config_type_weights, dict)
except Exception as e: # pylint: disable=W0703
logging.warning(
f"Config type weights not specified correctly ({e}), using Default"
)
config_type_weights = {"Default": 1.0}
folders = ["train", "val", "test"]
for sub_dir in folders:
if not os.path.exists(args.h5_prefix + sub_dir):
os.makedirs(args.h5_prefix + sub_dir)
# Data preparation
collections, atomic_energies_dict = get_dataset_from_xyz(
work_dir=args.work_dir,
train_path=args.train_file,
valid_path=args.valid_file,
valid_fraction=args.valid_fraction,
config_type_weights=config_type_weights,
test_path=args.test_file,
seed=args.seed,
key_specification=args.key_specification,
head_name=None,
)
# Atomic number table
# yapf: disable
if args.atomic_numbers is None:
z_table = tools.get_atomic_number_table_from_zs(
z
for configs in (collections.train, collections.valid)
for config in configs
for z in config.atomic_numbers
)
else:
logging.info("Using atomic numbers from command line argument")
zs_list = ast.literal_eval(args.atomic_numbers)
assert isinstance(zs_list, list)
z_table = tools.get_atomic_number_table_from_zs(zs_list)
logging.info("Preparing training set")
if args.shuffle:
random.shuffle(collections.train)
# split collections.train into batches and save them to hdf5
split_train = np.array_split(collections.train, args.num_process)
drop_last = False
if len(collections.train) % 2 == 1:
drop_last = True
multi_train_hdf5_ = partial(multi_train_hdf5, args=args, split_train=split_train, drop_last=drop_last)
processes = []
for i in range(args.num_process):
p = mp.Process(target=multi_train_hdf5_, args=[i])
p.start()
processes.append(p)
for i in processes:
i.join()
if args.compute_statistics:
logging.info("Computing statistics")
if len(atomic_energies_dict) == 0:
atomic_energies_dict = get_atomic_energies(args.E0s, collections.train, z_table)
# Remove atomic energies if element not in z_table
removed_atomic_energies = {}
for z in list(atomic_energies_dict):
if z not in z_table.zs:
removed_atomic_energies[z] = atomic_energies_dict.pop(z)
if len(removed_atomic_energies) > 0:
logging.warning("Atomic energies for elements not present in the atomic number table have been removed.")
logging.warning(f"Removed atomic energies (eV): {str(removed_atomic_energies)}")
logging.warning("To include these elements in the model, specify all atomic numbers explicitly using the --atomic_numbers argument.")
atomic_energies: np.ndarray = np.array(
[atomic_energies_dict[z] for z in z_table.zs]
)
logging.info(f"Atomic Energies: {atomic_energies.tolist()}")
_inputs = [args.h5_prefix + "train", z_table, args.r_max, atomic_energies, args.batch_size, args.num_process]
avg_num_neighbors, mean, std = pool_compute_stats(_inputs)
logging.info(f"Average number of neighbors: {avg_num_neighbors}")
logging.info(f"Mean: {mean}")
logging.info(f"Standard deviation: {std}")
# save the statistics as a json
statistics = {
"atomic_energies": str(atomic_energies_dict),
"avg_num_neighbors": avg_num_neighbors,
"mean": mean,
"std": std,
"atomic_numbers": str([int(z) for z in z_table.zs]),
"r_max": args.r_max,
}
with open(args.h5_prefix + "statistics.json", "w") as f: # pylint: disable=W1514
json.dump(statistics, f)
logging.info("Preparing validation set")
if args.shuffle:
random.shuffle(collections.valid)
split_valid = np.array_split(collections.valid, args.num_process)
drop_last = False
if len(collections.valid) % 2 == 1:
drop_last = True
multi_valid_hdf5_ = partial(multi_valid_hdf5, args=args, split_valid=split_valid, drop_last=drop_last)
processes = []
for i in range(args.num_process):
p = mp.Process(target=multi_valid_hdf5_, args=[i])
p.start()
processes.append(p)
for i in processes:
i.join()
if args.test_file is not None:
logging.info("Preparing test sets")
for name, subset in collections.tests:
drop_last = False
if len(subset) % 2 == 1:
drop_last = True
split_test = np.array_split(subset, args.num_process)
multi_test_hdf5_ = partial(multi_test_hdf5, args=args, split_test=split_test, drop_last=drop_last)
processes = []
for i in range(args.num_process):
p = mp.Process(target=multi_test_hdf5_, args=[i, name])
p.start()
processes.append(p)
for i in processes:
i.join()
if __name__ == "__main__":
main()
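
A hedged sketch of driving run() above through the same argument parser the script itself uses; the flag names are inferred from the attributes referenced in the code (train_file, h5_prefix, num_process, r_max, E0s, ...), and the file names are placeholders.

# --- Hypothetical usage sketch; flags inferred from the code above ---
from mace import tools

example_args = tools.build_preprocess_arg_parser().parse_args(
    [
        "--train_file", "train.xyz",
        "--h5_prefix", "processed_",
        "--num_process", "2",
        "--r_max", "5.0",
        "--compute_statistics",
        "--E0s", "average",
    ]
)
# run(example_args)  # writes processed_train/, processed_val/, processed_test/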
from argparse import ArgumentParser
import torch
from mace.tools.scripts_utils import remove_pt_head
def main():
parser = ArgumentParser()
grp = parser.add_mutually_exclusive_group()
grp.add_argument(
"--head_name",
"-n",
help="name of the head to extract",
default=None,
)
grp.add_argument(
"--list_heads",
"-l",
action="store_true",
help="list names of the heads",
)
parser.add_argument(
"--target_device",
"-d",
help="target device, defaults to model's current device",
)
parser.add_argument(
"--output_file",
"-o",
help="name for output model, defaults to model.head_name, followed by .target_device if specified",
)
parser.add_argument("model_file", help="input model file path")
args = parser.parse_args()
model = torch.load(args.model_file, map_location=args.target_device)
torch.set_default_dtype(next(model.parameters()).dtype)
if args.list_heads:
print("Available heads:")
print("\n".join([" " + h for h in model.heads]))
else:
if args.output_file is None:
args.output_file = (
args.model_file
+ "."
+ args.head_name
+ ("." + args.target_device if (args.target_device is not None) else "")
)
model_single = remove_pt_head(model, args.head_name)
if args.target_device is not None:
model_single.to(args.target_device)
torch.save(model_single, args.output_file)
if __name__ == "__main__":
main()
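
A programmatic sketch of the same extraction, with a placeholder model path and a hypothetical head name; remove_pt_head is the helper imported above from mace.tools.scripts_utils.

# --- Hypothetical usage sketch; path and head name are placeholders ---
import torch
from mace.tools.scripts_utils import remove_pt_head

model = torch.load("multihead.model", map_location="cpu")
print(model.heads)  # inspect the available heads first
single_head = remove_pt_head(model, "example_head")  # hypothetical head name
torch.save(single_head, "multihead.model.example_head")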
from .atomic_data import AtomicData
from .hdf5_dataset import HDF5Dataset, dataset_from_sharded_hdf5
from .lmdb_dataset import LMDBDataset
from .neighborhood import get_neighborhood
from .utils import (
Configuration,
Configurations,
KeySpecification,
compute_average_E0s,
config_from_atoms,
config_from_atoms_list,
load_from_xyz,
random_train_valid_split,
save_AtomicData_to_HDF5,
save_configurations_as_HDF5,
save_dataset_as_HDF5,
test_config_types,
update_keyspec_from_kwargs,
)
__all__ = [
"get_neighborhood",
"Configuration",
"Configurations",
"random_train_valid_split",
"load_from_xyz",
"test_config_types",
"config_from_atoms",
"config_from_atoms_list",
"AtomicData",
"compute_average_E0s",
"save_dataset_as_HDF5",
"HDF5Dataset",
"dataset_from_sharded_hdf5",
"save_AtomicData_to_HDF5",
"save_configurations_as_HDF5",
"KeySpecification",
"update_keyspec_from_kwargs",
"LMDBDataset",
]