Unverified Commit 273418ee authored by zcxzcx1, committed by GitHub

Merge pull request #1 from hjhk258/main

Fix
parents 7fb73825 9d7b4f63
"""Demonstrates active learning molecular dynamics with constant temperature.""" """Demonstrates active learning molecular dynamics with constant temperature."""
import argparse import argparse
import os import os
import time import time
import ase.io import ase.io
import numpy as np import numpy as np
from ase import units from ase import units
from ase.md.langevin import Langevin from ase.md.langevin import Langevin
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
from mace.calculators.mace import MACECalculator from mace.calculators.mace import MACECalculator
def parse_args() -> argparse.Namespace: def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter, formatter_class=argparse.ArgumentDefaultsHelpFormatter,
) )
parser.add_argument("--config", help="path to XYZ configurations", required=True) parser.add_argument("--config", help="path to XYZ configurations", required=True)
parser.add_argument( parser.add_argument(
"--config_index", help="index of configuration", type=int, default=-1 "--config_index", help="index of configuration", type=int, default=-1
) )
parser.add_argument( parser.add_argument(
"--error_threshold", help="error threshold", type=float, default=0.1 "--error_threshold", help="error threshold", type=float, default=0.1
) )
parser.add_argument("--temperature_K", help="temperature", type=float, default=300) parser.add_argument("--temperature_K", help="temperature", type=float, default=300)
parser.add_argument("--friction", help="friction", type=float, default=0.01) parser.add_argument("--friction", help="friction", type=float, default=0.01)
parser.add_argument("--timestep", help="timestep", type=float, default=1) parser.add_argument("--timestep", help="timestep", type=float, default=1)
parser.add_argument("--nsteps", help="number of steps", type=int, default=1000) parser.add_argument("--nsteps", help="number of steps", type=int, default=1000)
parser.add_argument( parser.add_argument(
"--nprint", help="number of steps between prints", type=int, default=10 "--nprint", help="number of steps between prints", type=int, default=10
) )
parser.add_argument( parser.add_argument(
"--nsave", help="number of steps between saves", type=int, default=10 "--nsave", help="number of steps between saves", type=int, default=10
) )
    parser.add_argument(
        "--ncheckerror", help="number of steps between error checks", type=int, default=10
    )
    parser.add_argument(
        "--model",
        help="path to model. Use wildcards to add multiple models as committee eg "
        "(`mace_*.model` to load mace_1.model, mace_2.model) ",
        required=True,
    )
    parser.add_argument("--output", help="output path", required=True)
    parser.add_argument(
        "--device",
        help="select device",
        type=str,
        choices=["cpu", "cuda"],
        default="cuda",
    )
    parser.add_argument(
        "--default_dtype",
        help="set default dtype",
        type=str,
        choices=["float32", "float64"],
        default="float64",
    )
    parser.add_argument(
        "--compute_stress",
        help="compute stress",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--info_prefix",
        help="prefix for energy, forces and stress keys",
        type=str,
        default="MACE_",
    )
    return parser.parse_args()


def printenergy(dyn, start_time=None):  # store a reference to atoms in the definition.
    """Function to print the potential, kinetic and total energy."""
    a = dyn.atoms
    epot = a.get_potential_energy() / len(a)
    ekin = a.get_kinetic_energy() / len(a)
    if start_time is None:
        elapsed_time = 0
    else:
        elapsed_time = time.time() - start_time
    forces_var = np.var(a.calc.results["forces_comm"], axis=0)
    print(
        "%.1fs: Energy per atom: Epot = %.3feV Ekin = %.3feV (T=%3.0fK) "  # pylint: disable=C0209
        "Etot = %.3feV t=%.1ffs Eerr = %.3feV Ferr = %.3feV/A"
        % (
            elapsed_time,
            epot,
            ekin,
            ekin / (1.5 * units.kB),
            epot + ekin,
            dyn.get_time() / units.fs,
            a.calc.results["energy_var"],
            np.max(np.linalg.norm(forces_var, axis=1)),
        ),
        flush=True,
    )


def save_config(dyn, fname):
    atomsi = dyn.atoms
    ens = atomsi.get_potential_energy()
    frcs = atomsi.get_forces()
    atomsi.info.update(
        {
            "mlff_energy": ens,
            "time": np.round(dyn.get_time() / units.fs, 5),
            "mlff_energy_var": atomsi.calc.results["energy_var"],
        }
    )
    atomsi.arrays.update(
        {
            "mlff_forces": frcs,
            "mlff_forces_var": np.var(atomsi.calc.results["forces_comm"], axis=0),
        }
    )
    ase.io.write(fname, atomsi, append=True)


def stop_error(dyn, threshold, reg=0.2):
    atomsi = dyn.atoms
    force_var = np.var(atomsi.calc.results["forces_comm"], axis=0)
    force = atomsi.get_forces()
    ferr = np.sqrt(np.sum(force_var, axis=1))
    ferr_rel = ferr / (np.linalg.norm(force, axis=1) + reg)
    if np.max(ferr_rel) > threshold:
        print(
            "Error too large {:.3}. Stopping t={:.2} fs.".format(  # pylint: disable=C0209
                np.max(ferr_rel), dyn.get_time() / units.fs
            ),
            flush=True,
        )
        dyn.max_steps = 0


def main() -> None:
    args = parse_args()
    run(args)


def run(args: argparse.Namespace) -> None:
    mace_fname = args.model
    atoms_fname = args.config
    atoms_index = args.config_index
    mace_calc = MACECalculator(
        model_paths=mace_fname,
        device=args.device,
        default_dtype=args.default_dtype,
    )
    NSTEPS = args.nsteps
    if os.path.exists(args.output):
        print("Trajectory exists. Continuing from last step.")
        atoms = ase.io.read(args.output, index=-1)
        len_save = len(ase.io.read(args.output, ":"))
        print("Last step: ", atoms.info["time"], "Number of configs: ", len_save)
        NSTEPS -= len_save * args.nsave
    else:
        atoms = ase.io.read(atoms_fname, index=atoms_index)
        MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature_K)
    atoms.calc = mace_calc
    # Run constant-temperature MD with the Langevin thermostat, using the
    # user-supplied timestep, temperature and friction coefficient.
    dyn = Langevin(
        atoms=atoms,
        timestep=args.timestep * units.fs,
        temperature_K=args.temperature_K,
        friction=args.friction,
    )
    dyn.attach(printenergy, interval=args.nprint, dyn=dyn, start_time=time.time())
    dyn.attach(save_config, interval=args.nsave, dyn=dyn, fname=args.output)
    dyn.attach(
        stop_error, interval=args.ncheckerror, dyn=dyn, threshold=args.error_threshold
    )
    # Now run the dynamics
    dyn.run(NSTEPS)


if __name__ == "__main__":
    main()
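# --- Usage note (added; not part of the original script) ---
# A minimal sketch of how this active-learning MD script might be invoked,
# assuming a committee saved as mace_1.model, mace_2.model, ... and a starting
# structure start.xyz (all file names below are placeholders):
#
#   python active_learning_md.py --config start.xyz --model "mace_*.model" \
#       --output al_traj.xyz --temperature_K 300 --nsteps 5000
#
# The trajectory is appended to --output every --nsave steps, and stop_error()
# halts the run once the relative committee force disagreement exceeds
# --error_threshold, flagging that configuration for ab initio labelling.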
import argparse
import logging
import os
from typing import Dict, List, Tuple

import torch

from mace.tools.scripts_utils import extract_config_mace_model


def get_transfer_keys(num_layers: int) -> List[str]:
    """Get list of keys that need to be transferred"""
    return [
        "node_embedding.linear.weight",
        "radial_embedding.bessel_fn.bessel_weights",
        "atomic_energies_fn.atomic_energies",
        "readouts.0.linear.weight",
        *[f"readouts.{j}.linear.weight" for j in range(num_layers - 1)],
        "scale_shift.scale",
        "scale_shift.shift",
        *[f"readouts.{num_layers-1}.linear_{i}.weight" for i in range(1, 3)],
    ] + [
        s
        for j in range(num_layers)
        for s in [
            f"interactions.{j}.linear_up.weight",
            *[f"interactions.{j}.conv_tp_weights.layer{i}.weight" for i in range(4)],
            f"interactions.{j}.linear.weight",
            f"interactions.{j}.skip_tp.weight",
            f"products.{j}.linear.weight",
        ]
    ]


def get_kmax_pairs(
    max_L: int, correlation: int, num_layers: int
) -> List[Tuple[int, int]]:
    """Determine kmax pairs based on max_L and correlation"""
    if correlation == 2:
        raise NotImplementedError("Correlation 2 not supported yet")
    if correlation == 3:
        kmax_pairs = [[i, max_L] for i in range(num_layers - 1)]
        kmax_pairs = kmax_pairs + [[num_layers - 1, 0]]
        return kmax_pairs
    raise NotImplementedError(f"Correlation {correlation} not supported")
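# Illustrative trace (added for clarity; the values are an assumed example): for
# a two-layer model with max_L=2 and correlation=3,
#   get_kmax_pairs(max_L=2, correlation=3, num_layers=2) -> [[0, 2], [1, 0]]
# i.e. every layer keeps contractions up to max_L except the last, which only
# keeps the invariant (L=0) part.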
def transfer_symmetric_contractions(
    source_dict: Dict[str, torch.Tensor],
    target_dict: Dict[str, torch.Tensor],
    max_L: int,
    correlation: int,
    num_layers: int,
):
    """Transfer symmetric contraction weights from CuEq to E3nn format"""
    kmax_pairs = get_kmax_pairs(max_L, correlation, num_layers)
    for i, kmax in kmax_pairs:
        # Get the combined weight tensor from source
        wm = source_dict[f"products.{i}.symmetric_contractions.weight"]
        # Get split sizes based on target dimensions
        splits = []
        for k in range(kmax + 1):
            for suffix in ["_max", ".0", ".1"]:
                key = f"products.{i}.symmetric_contractions.contractions.{k}.weights{suffix}"
                target_shape = target_dict[key].shape
                splits.append(target_shape[1])
        # Split the weights using the calculated sizes
        weights_split = torch.split(wm, splits, dim=1)
        # Assign back to target dictionary
        idx = 0
        for k in range(kmax + 1):
            target_dict[
                f"products.{i}.symmetric_contractions.contractions.{k}.weights_max"
            ] = weights_split[idx]
            target_dict[
                f"products.{i}.symmetric_contractions.contractions.{k}.weights.0"
            ] = weights_split[idx + 1]
            target_dict[
                f"products.{i}.symmetric_contractions.contractions.{k}.weights.1"
            ] = weights_split[idx + 2]
            idx += 3


def transfer_weights(
    source_model: torch.nn.Module,
    target_model: torch.nn.Module,
    max_L: int,
    correlation: int,
    num_layers: int,
):
    """Transfer weights from CuEq to E3nn format"""
    # Get state dicts
    source_dict = source_model.state_dict()
    target_dict = target_model.state_dict()
    # Transfer main weights
    transfer_keys = get_transfer_keys(num_layers)
    for key in transfer_keys:
        if key in source_dict:  # Check if key exists
            target_dict[key] = source_dict[key]
        else:
            logging.warning(f"Key {key} not found in source model")
    # Transfer symmetric contractions
    transfer_symmetric_contractions(
        source_dict, target_dict, max_L, correlation, num_layers
    )
    # Squeeze linear and skip_tp layers
    for key in source_dict.keys():
        if any(x in key for x in ["linear", "skip_tp"]) and "weight" in key:
            target_dict[key] = target_dict[key].squeeze(0)
    # Transfer remaining matching keys
    transferred_keys = set(transfer_keys)
    remaining_keys = (
        set(source_dict.keys()) & set(target_dict.keys()) - transferred_keys
    )
    remaining_keys = {k for k in remaining_keys if "symmetric_contraction" not in k}
    if remaining_keys:
        for key in remaining_keys:
            if source_dict[key].shape == target_dict[key].shape:
                logging.debug(f"Transferring additional key: {key}")
                target_dict[key] = source_dict[key]
            else:
                logging.warning(
                    f"Shape mismatch for key {key}: "
                    f"source {source_dict[key].shape} vs target {target_dict[key].shape}"
                )
    # Transfer avg_num_neighbors
    for i in range(2):
        target_model.interactions[i].avg_num_neighbors = source_model.interactions[
            i
        ].avg_num_neighbors
    # Load state dict into target model
    target_model.load_state_dict(target_dict)


def run(input_model, output_model="_e3nn.model", device="cpu", return_model=True):
    # Load CuEq model
    if isinstance(input_model, str):
        source_model = torch.load(input_model, map_location=device)
    else:
        source_model = input_model
    default_dtype = next(source_model.parameters()).dtype
    torch.set_default_dtype(default_dtype)
    # Extract configuration
    config = extract_config_mace_model(source_model)
    # Get max_L and correlation from config
    max_L = config["hidden_irreps"].lmax
    correlation = config["correlation"]
    # Remove CuEq config
    config.pop("cueq_config", None)
    # Create new model without CuEq config
    logging.info("Creating new model without CuEq settings")
    target_model = source_model.__class__(**config)
    # Transfer weights with proper remapping
    num_layers = config["num_interactions"]
    transfer_weights(source_model, target_model, max_L, correlation, num_layers)
    if return_model:
        return target_model
    # Save model
    if isinstance(input_model, str):
        base = os.path.splitext(input_model)[0]
        output_model = f"{base}.{output_model}"
    logging.warning(f"Saving E3nn model to {output_model}")
    torch.save(target_model, output_model)
    return None


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("input_model", help="Path to input CuEq model")
    parser.add_argument(
        "--output_model", help="Path to output E3nn model", default="e3nn_model.pt"
    )
    parser.add_argument("--device", default="cpu", help="Device to use")
    parser.add_argument(
        "--return_model",
        action="store_false",
        help="Save the converted model to file instead of returning it "
        "(passing this flag sets return_model to False)",
    )
    args = parser.parse_args()
    run(
        input_model=args.input_model,
        output_model=args.output_model,
        device=args.device,
        return_model=args.return_model,
    )


if __name__ == "__main__":
    main()
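# --- Usage sketch (added; not part of the original file) ---
# run() can also be called programmatically; a minimal sketch, assuming this
# module is importable as mace.cli.convert_cueq_e3nn and that "mace_cueq.model"
# is a CuEq-enabled MACE model saved with torch.save:
#
#   from mace.cli.convert_cueq_e3nn import run as cueq_to_e3nn
#   e3nn_model = cueq_to_e3nn("mace_cueq.model", device="cpu", return_model=True)
#
# With return_model=False the converted model is written to a path derived from
# the input file name and the output_model suffix instead of being returned.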
from argparse import ArgumentParser

import torch


def main():
    parser = ArgumentParser()
    parser.add_argument(
        "--target_device",
        "-t",
        help="device to convert to, usually 'cpu' or 'cuda'",
        default="cpu",
    )
    parser.add_argument(
        "--output_file",
        "-o",
        help="name for output model, defaults to model_file.target_device",
    )
    parser.add_argument("model_file", help="input model file path")
    args = parser.parse_args()
    if args.output_file is None:
        args.output_file = args.model_file + "." + args.target_device
    model = torch.load(args.model_file, weights_only=False)
    model.to(args.target_device)
    torch.save(model, args.output_file)


if __name__ == "__main__":
    main()
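# --- Usage sketch (added; not part of the original file) ---
# Example invocation of this device-conversion helper (script and model file
# names are placeholders):
#
#   python convert_device.py -t cpu my_model.model
#
# which by default writes the CPU copy to "my_model.model.cpu".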
import argparse
import logging
import os
from typing import Dict, List, Tuple

import torch

from mace.modules.wrapper_ops import CuEquivarianceConfig
from mace.tools.scripts_utils import extract_config_mace_model


def get_transfer_keys(num_layers: int) -> List[str]:
    """Get list of keys that need to be transferred"""
    return [
        "node_embedding.linear.weight",
        "radial_embedding.bessel_fn.bessel_weights",
        "atomic_energies_fn.atomic_energies",
        "readouts.0.linear.weight",
        *[f"readouts.{j}.linear.weight" for j in range(num_layers - 1)],
        "scale_shift.scale",
        "scale_shift.shift",
        *[f"readouts.{num_layers-1}.linear_{i}.weight" for i in range(1, 3)],
    ] + [
        s
        for j in range(num_layers)
        for s in [
            f"interactions.{j}.linear_up.weight",
            *[f"interactions.{j}.conv_tp_weights.layer{i}.weight" for i in range(4)],
            f"interactions.{j}.linear.weight",
            f"interactions.{j}.skip_tp.weight",
            f"products.{j}.linear.weight",
        ]
    ]


def get_kmax_pairs(
    max_L: int, correlation: int, num_layers: int
) -> List[Tuple[int, int]]:
    """Determine kmax pairs based on max_L and correlation"""
    if correlation == 2:
        raise NotImplementedError("Correlation 2 not supported yet")
    if correlation == 3:
        kmax_pairs = [[i, max_L] for i in range(num_layers - 1)]
        kmax_pairs = kmax_pairs + [[num_layers - 1, 0]]
        return kmax_pairs
    raise NotImplementedError(f"Correlation {correlation} not supported")


def transfer_symmetric_contractions(
    source_dict: Dict[str, torch.Tensor],
    target_dict: Dict[str, torch.Tensor],
    max_L: int,
    correlation: int,
    num_layers: int,
):
    """Transfer symmetric contraction weights"""
    kmax_pairs = get_kmax_pairs(max_L, correlation, num_layers)
    for i, kmax in kmax_pairs:
        wm = torch.concatenate(
            [
                source_dict[
                    f"products.{i}.symmetric_contractions.contractions.{k}.weights{j}"
                ]
                for k in range(kmax + 1)
                for j in ["_max", ".0", ".1"]
            ],
            dim=1,
        )
        target_dict[f"products.{i}.symmetric_contractions.weight"] = wm


def transfer_weights(
    source_model: torch.nn.Module,
    target_model: torch.nn.Module,
    max_L: int,
    correlation: int,
    num_layers: int,
):
    """Transfer weights with proper remapping"""
    # Get source state dict
    source_dict = source_model.state_dict()
    target_dict = target_model.state_dict()
    # Transfer main weights
    transfer_keys = get_transfer_keys(num_layers)
    for key in transfer_keys:
        if key in source_dict:  # Check if key exists
            target_dict[key] = source_dict[key]
        else:
            logging.warning(f"Key {key} not found in source model")
    # Transfer symmetric contractions
    transfer_symmetric_contractions(
        source_dict, target_dict, max_L, correlation, num_layers
    )
    # Unsqueeze linear and skip_tp layers
    for key in source_dict.keys():
        if any(x in key for x in ["linear", "skip_tp"]) and "weight" in key:
            target_dict[key] = target_dict[key].unsqueeze(0)
    transferred_keys = set(transfer_keys)
    remaining_keys = (
        set(source_dict.keys()) & set(target_dict.keys()) - transferred_keys
    )
    remaining_keys = {k for k in remaining_keys if "symmetric_contraction" not in k}
    if remaining_keys:
        for key in remaining_keys:
            if source_dict[key].shape == target_dict[key].shape:
                logging.debug(f"Transferring additional key: {key}")
                target_dict[key] = source_dict[key]
            else:
                logging.warning(
                    f"Shape mismatch for key {key}: "
                    f"source {source_dict[key].shape} vs target {target_dict[key].shape}"
                )
    # Transfer avg_num_neighbors
    for i in range(2):
        target_model.interactions[i].avg_num_neighbors = source_model.interactions[
            i
        ].avg_num_neighbors
    # Load state dict into target model
    target_model.load_state_dict(target_dict)


def run(
    input_model,
    output_model="_cueq.model",
    device="cpu",
    return_model=True,
):
    # Setup logging
    # Load original model
    # logging.warning(f"Loading model")
    # check if input_model is a path or a model
    if isinstance(input_model, str):
        source_model = torch.load(input_model, map_location=device)
    else:
        source_model = input_model
    default_dtype = next(source_model.parameters()).dtype
    torch.set_default_dtype(default_dtype)
    # Extract configuration
    config = extract_config_mace_model(source_model)
    # Get max_L and correlation from config
    max_L = config["hidden_irreps"].lmax
    correlation = config["correlation"]
    # Add cuequivariance config
    config["cueq_config"] = CuEquivarianceConfig(
        enabled=True,
        layout="ir_mul",
        group="O3_e3nn",
        optimize_all=True,
    )
    # Create new model with cuequivariance config
    logging.info("Creating new model with cuequivariance settings")
    target_model = source_model.__class__(**config).to(device)
    # Transfer weights with proper remapping
    num_layers = config["num_interactions"]
    transfer_weights(source_model, target_model, max_L, correlation, num_layers)
    if return_model:
        return target_model
    if isinstance(input_model, str):
        base = os.path.splitext(input_model)[0]
        output_model = f"{base}.{output_model}"
    logging.warning(f"Saving CuEq model to {output_model}")
    torch.save(target_model, output_model)
    return None


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("input_model", help="Path to input MACE model")
    parser.add_argument(
        "--output_model",
        help="Path to output cuequivariance model",
        default="cueq_model.pt",
    )
    parser.add_argument("--device", default="cpu", help="Device to use")
    parser.add_argument(
        "--return_model",
        action="store_false",
        help="Save the converted model to file instead of returning it "
        "(passing this flag sets return_model to False)",
    )
    args = parser.parse_args()
    run(
        input_model=args.input_model,
        output_model=args.output_model,
        device=args.device,
        return_model=args.return_model,
    )


if __name__ == "__main__":
    main()
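# --- Usage sketch (added; not part of the original file) ---
# This is the converter imported by the LAMMPS export script below; a minimal
# programmatic sketch, assuming "mace.model" is an e3nn MACE model saved with
# torch.save:
#
#   from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq
#   cueq_model = run_e3nn_to_cueq("mace.model", device="cuda", return_model=True)
#
# The conversion rebuilds the model with CuEquivarianceConfig(enabled=True) and
# copies the weights across, concatenating the per-contraction weights_max/.0/.1
# tensors into the single fused symmetric_contractions.weight tensor.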
# pylint: disable=wrong-import-position
import argparse
import copy
import os

os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1"

import torch
from e3nn.util import jit

from mace.calculators import LAMMPS_MACE
from mace.calculators.lammps_mliap_mace import LAMMPS_MLIAP_MACE
from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq


def parse_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "model_path",
        type=str,
        help="Path to the model to be converted to LAMMPS",
    )
    parser.add_argument(
        "--head",
        type=str,
        nargs="?",
        help="Head of the model to be converted to LAMMPS",
        default=None,
    )
    parser.add_argument(
        "--dtype",
        type=str,
        nargs="?",
        help="Data type of the model to be converted to LAMMPS",
        default="float64",
    )
    parser.add_argument(
        "--format",
        type=str,
        help="Old libtorch format, or new mliap format",
        default="libtorch",
    )
    return parser.parse_args()


def select_head(model):
    if hasattr(model, "heads"):
        heads = model.heads
    else:
        heads = [None]
    if len(heads) == 1:
        print(f"Only one head found in the model: {heads[0]}. Skipping selection.")
        return heads[0]
    print("Available heads in the model:")
    for i, head in enumerate(heads):
        print(f"{i + 1}: {head}")
    # Ask the user to select a head
    selected = input(
        f"Select a head by number (Defaulting to head: {len(heads)}, press Enter to accept): "
    )
    if selected.isdigit() and 1 <= int(selected) <= len(heads):
        return heads[int(selected) - 1]
    if selected == "":
        print("No head selected. Proceeding without specifying a head.")
        return None
    print(f"No valid selection made. Defaulting to the last head: {heads[-1]}")
    return heads[-1]


def main():
    args = parse_args()
    model_path = args.model_path  # takes model name as command-line input
    model = torch.load(
        model_path,
        map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )
    if args.dtype == "float64":
        model = model.double().to("cpu")
    elif args.dtype == "float32":
        print("Converting model to float32, this may cause loss of precision.")
        model = model.float().to("cpu")
    if args.format == "mliap":
        # Enabling cuequivariance by default. TODO: switch?
        model = run_e3nn_to_cueq(copy.deepcopy(model))
        model.lammps_mliap = True
    if args.head is None:
        head = select_head(model)
    else:
        head = args.head
        print(
            f"Selected head: {head} from the command line. Available heads: {model.heads}"
        )
    lammps_class = LAMMPS_MLIAP_MACE if args.format == "mliap" else LAMMPS_MACE
    lammps_model = (
        lammps_class(model, head=head) if head is not None else lammps_class(model)
    )
    if args.format == "mliap":
        torch.save(lammps_model, model_path + "-mliap_lammps.pt")
    else:
        lammps_model_compiled = jit.compile(lammps_model)
        lammps_model_compiled.save(model_path + "-lammps.pt")


if __name__ == "__main__":
    main()
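# --- Usage sketch (added; not part of the original file) ---
# Typical invocations, assuming this script is exposed as the
# mace_create_lammps_model entry point (model file name is a placeholder):
#
#   mace_create_lammps_model my_mace.model                  # TorchScript export
#   mace_create_lammps_model my_mace.model --format mliap   # ML-IAP export
#
# Per the code above, the libtorch path writes "my_mace.model-lammps.pt" and the
# mliap path writes "my_mace.model-mliap_lammps.pt".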
###########################################################################################
# Script for evaluating configurations contained in an xyz file with a trained model
# Authors: Ilyes Batatia, Gregor Simm
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################

import argparse

import ase.data
import ase.io
import numpy as np
import torch

from mace import data
from mace.tools import torch_geometric, torch_tools, utils


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--configs", help="path to XYZ configurations", required=True)
    parser.add_argument("--model", help="path to model", required=True)
    parser.add_argument("--output", help="output path", required=True)
    parser.add_argument(
        "--device",
        help="select device",
        type=str,
        choices=["cpu", "cuda"],
        default="cpu",
    )
    parser.add_argument(
        "--default_dtype",
        help="set default dtype",
        type=str,
        choices=["float32", "float64"],
        default="float64",
    )
    parser.add_argument("--batch_size", help="batch size", type=int, default=64)
    parser.add_argument(
        "--compute_stress",
        help="compute stress",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--return_contributions",
        help="model outputs energy contributions for each body order, only supported for MACE, not ScaleShiftMACE",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--info_prefix",
        help="prefix for energy, forces and stress keys",
        type=str,
        default="MACE_",
    )
    parser.add_argument(
        "--head",
        help="Model head used for evaluation",
        type=str,
        required=False,
        default=None,
    )
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    run(args)


def run(args: argparse.Namespace) -> None:
    torch_tools.set_default_dtype(args.default_dtype)
    device = torch_tools.init_device(args.device)
    # Load model
    model = torch.load(f=args.model, map_location=args.device)
    model = model.to(
        args.device
    )  # shouldn't be necessary but seems to help with CUDA problems
    for param in model.parameters():
        param.requires_grad = False
    # Load data and prepare input
    atoms_list = ase.io.read(args.configs, index=":")
    if args.head is not None:
        for atoms in atoms_list:
            atoms.info["head"] = args.head
    configs = [data.config_from_atoms(atoms) for atoms in atoms_list]
    z_table = utils.AtomicNumberTable([int(z) for z in model.atomic_numbers])
    try:
        heads = model.heads
    except AttributeError:
        heads = None
    data_loader = torch_geometric.dataloader.DataLoader(
        dataset=[
            data.AtomicData.from_config(
                config, z_table=z_table, cutoff=float(model.r_max), heads=heads
            )
            for config in configs
        ],
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
    )
    # Collect data
    energies_list = []
    contributions_list = []
    stresses_list = []
    forces_collection = []
    for batch in data_loader:
        batch = batch.to(device)
        output = model(batch.to_dict(), compute_stress=args.compute_stress)
        energies_list.append(torch_tools.to_numpy(output["energy"]))
        if args.compute_stress:
            stresses_list.append(torch_tools.to_numpy(output["stress"]))
        if args.return_contributions:
            contributions_list.append(torch_tools.to_numpy(output["contributions"]))
        forces = np.split(
            torch_tools.to_numpy(output["forces"]),
            indices_or_sections=batch.ptr[1:],
            axis=0,
        )
        forces_collection.append(forces[:-1])  # drop last as it's empty
    energies = np.concatenate(energies_list, axis=0)
    forces_list = [
        forces for forces_list in forces_collection for forces in forces_list
    ]
    assert len(atoms_list) == len(energies) == len(forces_list)
    if args.compute_stress:
        stresses = np.concatenate(stresses_list, axis=0)
        assert len(atoms_list) == stresses.shape[0]
    if args.return_contributions:
        contributions = np.concatenate(contributions_list, axis=0)
        assert len(atoms_list) == contributions.shape[0]
    # Store data in atoms objects
    for i, (atoms, energy, forces) in enumerate(zip(atoms_list, energies, forces_list)):
        atoms.calc = None  # crucial
        atoms.info[args.info_prefix + "energy"] = energy
        atoms.arrays[args.info_prefix + "forces"] = forces
        if args.compute_stress:
            atoms.info[args.info_prefix + "stress"] = stresses[i]
        if args.return_contributions:
            atoms.info[args.info_prefix + "BO_contributions"] = contributions[i]
    # Write atoms to output path
    ase.io.write(args.output, images=atoms_list, format="extxyz")


if __name__ == "__main__":
    main()
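# --- Usage sketch (added; not part of the original file) ---
# Example invocation of this evaluation script (file names are placeholders):
#
#   python eval_configs.py --configs test.xyz --model mace.model \
#       --output test_mace.xyz --device cuda --batch_size 32
#
# The output extxyz file stores the predictions under the --info_prefix keys,
# e.g. MACE_energy in atoms.info and MACE_forces in atoms.arrays.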
########################################################################################### ###########################################################################################
# This program is distributed under the MIT License (see MIT.md) # This program is distributed under the MIT License (see MIT.md)
########################################################################################### ###########################################################################################
from __future__ import annotations from __future__ import annotations
import argparse import argparse
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import List, Tuple, Union from typing import List, Tuple, Union
import ase.data import ase.data
import ase.io import ase.io
import numpy as np import numpy as np
import torch import torch
from mace.calculators import MACECalculator, mace_mp from mace.calculators import MACECalculator, mace_mp
try: try:
import fpsample # type: ignore import fpsample # type: ignore
except ImportError: except ImportError:
pass pass
class FilteringType(Enum): class FilteringType(Enum):
NONE = "none" NONE = "none"
COMBINATIONS = "combinations" COMBINATIONS = "combinations"
EXCLUSIVE = "exclusive" EXCLUSIVE = "exclusive"
INCLUSIVE = "inclusive" INCLUSIVE = "inclusive"
class SubselectType(Enum): class SubselectType(Enum):
FPS = "fps" FPS = "fps"
RANDOM = "random" RANDOM = "random"
@dataclass @dataclass
class SelectionSettings: class SelectionSettings:
configs_pt: str configs_pt: str
output: str output: str
configs_ft: str | None = None configs_ft: str | None = None
atomic_numbers: List[int] | None = None atomic_numbers: List[int] | None = None
num_samples: int | None = None num_samples: int | None = None
subselect: SubselectType = SubselectType.FPS subselect: SubselectType = SubselectType.FPS
model: str = "small" model: str = "small"
descriptors: str | None = None descriptors: str | None = None
device: str = "cpu" device: str = "cpu"
default_dtype: str = "float64" default_dtype: str = "float64"
head_pt: str | None = None head_pt: str | None = None
head_ft: str | None = None head_ft: str | None = None
filtering_type: FilteringType = FilteringType.COMBINATIONS filtering_type: FilteringType = FilteringType.COMBINATIONS
weight_ft: float = 1.0 weight_ft: float = 1.0
weight_pt: float = 1.0 weight_pt: float = 1.0
seed: int = 42 seed: int = 42
def parse_args() -> argparse.Namespace: def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter, formatter_class=argparse.ArgumentDefaultsHelpFormatter,
) )
parser.add_argument( parser.add_argument(
"--configs_pt", "--configs_pt",
help="path to XYZ configurations for the pretraining", help="path to XYZ configurations for the pretraining",
required=True, required=True,
) )
parser.add_argument( parser.add_argument(
"--configs_ft", "--configs_ft",
help="path or list of paths to XYZ configurations for the finetuning", help="path or list of paths to XYZ configurations for the finetuning",
required=False, required=False,
default=None, default=None,
) )
parser.add_argument( parser.add_argument(
"--num_samples", "--num_samples",
help="number of samples to select for the pretraining", help="number of samples to select for the pretraining",
type=int, type=int,
required=False, required=False,
default=None, default=None,
) )
parser.add_argument( parser.add_argument(
"--subselect", "--subselect",
help="method to subselect the configurations of the pretraining set", help="method to subselect the configurations of the pretraining set",
type=SubselectType, type=SubselectType,
choices=list(SubselectType), choices=list(SubselectType),
default=SubselectType.FPS, default=SubselectType.FPS,
) )
parser.add_argument( parser.add_argument(
"--model", help="path to model", default="small", required=False "--model", help="path to model", default="small", required=False
) )
parser.add_argument("--output", help="output path", required=True) parser.add_argument("--output", help="output path", required=True)
parser.add_argument( parser.add_argument(
"--descriptors", help="path to descriptors", required=False, default=None "--descriptors", help="path to descriptors", required=False, default=None
) )
parser.add_argument( parser.add_argument(
"--device", "--device",
help="select device", help="select device",
type=str, type=str,
choices=["cpu", "cuda"], choices=["cpu", "cuda"],
default="cpu", default="cpu",
) )
parser.add_argument( parser.add_argument(
"--default_dtype", "--default_dtype",
help="set default dtype", help="set default dtype",
type=str, type=str,
choices=["float32", "float64"], choices=["float32", "float64"],
default="float64", default="float64",
) )
parser.add_argument( parser.add_argument(
"--head_pt", "--head_pt",
help="level of head for the pretraining set", help="level of head for the pretraining set",
type=str, type=str,
default=None, default=None,
) )
parser.add_argument( parser.add_argument(
"--head_ft", "--head_ft",
help="level of head for the finetuning set", help="level of head for the finetuning set",
type=str, type=str,
default=None, default=None,
) )
parser.add_argument( parser.add_argument(
"--filtering_type", "--filtering_type",
help="filtering type", help="filtering type",
type=FilteringType, type=FilteringType,
choices=list(FilteringType), choices=list(FilteringType),
default=FilteringType.NONE, default=FilteringType.NONE,
) )
parser.add_argument( parser.add_argument(
"--weight_ft", "--weight_ft",
help="weight for the finetuning set", help="weight for the finetuning set",
type=float, type=float,
default=1.0, default=1.0,
) )
parser.add_argument( parser.add_argument(
"--weight_pt", "--weight_pt",
help="weight for the pretraining set", help="weight for the pretraining set",
type=float, type=float,
default=1.0, default=1.0,
) )
parser.add_argument("--seed", help="random seed", type=int, default=42) parser.add_argument("--seed", help="random seed", type=int, default=42)
return parser.parse_args() return parser.parse_args()
def calculate_descriptors(atoms: List[ase.Atoms], calc: MACECalculator) -> None: def calculate_descriptors(atoms: List[ase.Atoms], calc: MACECalculator) -> None:
logging.info("Calculating descriptors") logging.info("Calculating descriptors")
for mol in atoms: for mol in atoms:
descriptors = calc.get_descriptors(mol.copy(), invariants_only=True) descriptors = calc.get_descriptors(mol.copy(), invariants_only=True)
# average descriptors over atoms for each element # average descriptors over atoms for each element
descriptors_dict = { descriptors_dict = {
element: np.mean(descriptors[mol.symbols == element], axis=0) element: np.mean(descriptors[mol.symbols == element], axis=0)
for element in np.unique(mol.symbols) for element in np.unique(mol.symbols)
} }
mol.info["mace_descriptors"] = descriptors_dict mol.info["mace_descriptors"] = descriptors_dict
def filter_atoms( def filter_atoms(
atoms: ase.Atoms, atoms: ase.Atoms,
element_subset: List[str], element_subset: List[str],
filtering_type: FilteringType = FilteringType.COMBINATIONS, filtering_type: FilteringType = FilteringType.COMBINATIONS,
) -> bool: ) -> bool:
""" """
Filters atoms based on the provided filtering type and element subset. Filters atoms based on the provided filtering type and element subset.
Parameters: Parameters:
atoms (ase.Atoms): The atoms object to filter. atoms (ase.Atoms): The atoms object to filter.
element_subset (list): The list of elements to consider during filtering. element_subset (list): The list of elements to consider during filtering.
filtering_type (FilteringType): The type of filtering to apply. filtering_type (FilteringType): The type of filtering to apply.
Can be one of the following `FilteringType` enum members: Can be one of the following `FilteringType` enum members:
- `FilteringType.NONE`: No filtering is applied. - `FilteringType.NONE`: No filtering is applied.
- `FilteringType.COMBINATIONS`: Return true if `atoms` is composed of combinations of elements in the subset, false otherwise. I.e. does not require all of the specified elements to be present. - `FilteringType.COMBINATIONS`: Return true if `atoms` is composed of combinations of elements in the subset, false otherwise. I.e. does not require all of the specified elements to be present.
- `FilteringType.EXCLUSIVE`: Return true if `atoms` contains *only* elements in the subset, false otherwise. - `FilteringType.EXCLUSIVE`: Return true if `atoms` contains *only* elements in the subset, false otherwise.
- `FilteringType.INCLUSIVE`: Return true if `atoms` contains all elements in the subset, false otherwise. I.e. allows additional elements. - `FilteringType.INCLUSIVE`: Return true if `atoms` contains all elements in the subset, false otherwise. I.e. allows additional elements.
Returns: Returns:
bool: True if the atoms pass the filter, False otherwise. bool: True if the atoms pass the filter, False otherwise.
""" """
if filtering_type == FilteringType.NONE: if filtering_type == FilteringType.NONE:
return True return True
if filtering_type == FilteringType.COMBINATIONS: if filtering_type == FilteringType.COMBINATIONS:
atom_symbols = np.unique(atoms.symbols) atom_symbols = np.unique(atoms.symbols)
return all( return all(
x in element_subset for x in atom_symbols x in element_subset for x in atom_symbols
) # atoms must *only* contain elements in the subset ) # atoms must *only* contain elements in the subset
if filtering_type == FilteringType.EXCLUSIVE: if filtering_type == FilteringType.EXCLUSIVE:
atom_symbols = set(list(atoms.symbols)) atom_symbols = set(list(atoms.symbols))
return atom_symbols == set(element_subset) return atom_symbols == set(element_subset)
if filtering_type == FilteringType.INCLUSIVE: if filtering_type == FilteringType.INCLUSIVE:
atom_symbols = np.unique(atoms.symbols) atom_symbols = np.unique(atoms.symbols)
return all( return all(
x in atom_symbols for x in element_subset x in atom_symbols for x in element_subset
) # atoms must *at least* contain elements in the subset ) # atoms must *at least* contain elements in the subset
raise ValueError( raise ValueError(
f"Filtering type {filtering_type} not recognised. Must be one of {list(FilteringType)}." f"Filtering type {filtering_type} not recognised. Must be one of {list(FilteringType)}."
) )
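# Illustrative usage sketch (not part of the original script): how the three filtering
# modes above differ for a toy configuration. Only the ase.Atoms API already used in
# this file is assumed; the function is never called by the script itself.
def _example_filter_atoms_usage() -> None:
    water = ase.Atoms("H2O")  # symbols: H, H, O
    # COMBINATIONS: every element in `water` must be in the subset; unused subset
    # elements (here C) are allowed.
    assert filter_atoms(water, ["H", "O", "C"], FilteringType.COMBINATIONS)
    # EXCLUSIVE: the element sets must match exactly.
    assert filter_atoms(water, ["H", "O"], FilteringType.EXCLUSIVE)
    assert not filter_atoms(water, ["H", "O", "C"], FilteringType.EXCLUSIVE)
    # INCLUSIVE: all subset elements must be present; extra elements in `water`
    # would be allowed.
    assert filter_atoms(water, ["O"], FilteringType.INCLUSIVE)
    assert not filter_atoms(water, ["O", "C"], FilteringType.INCLUSIVE)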
class FPS: class FPS:
def __init__(self, atoms_list: List[ase.Atoms], n_samples: int): def __init__(self, atoms_list: List[ase.Atoms], n_samples: int):
self.n_samples = n_samples self.n_samples = n_samples
self.atoms_list = atoms_list self.atoms_list = atoms_list
self.species = np.unique([x.symbol for atoms in atoms_list for x in atoms]) # type: ignore self.species = np.unique([x.symbol for atoms in atoms_list for x in atoms]) # type: ignore
self.species_dict = {x: i for i, x in enumerate(self.species)} self.species_dict = {x: i for i, x in enumerate(self.species)}
# start from a random configuration # start from a random configuration
self.list_index = [np.random.randint(0, len(atoms_list))] self.list_index = [np.random.randint(0, len(atoms_list))]
self.assemble_descriptors() self.assemble_descriptors()
def run( def run(
self, self,
) -> List[int]: ) -> List[int]:
""" """
Run the farthest point sampling algorithm. Run the farthest point sampling algorithm.
""" """
descriptor_dataset_reshaped = ( descriptor_dataset_reshaped = (
self.descriptors_dataset.reshape( # pylint: disable=E1121 self.descriptors_dataset.reshape( # pylint: disable=E1121
(len(self.atoms_list), -1) (len(self.atoms_list), -1)
) )
) )
logging.info(f"{descriptor_dataset_reshaped.shape}") logging.info(f"{descriptor_dataset_reshaped.shape}")
logging.info(f"n_samples: {self.n_samples}") logging.info(f"n_samples: {self.n_samples}")
self.list_index = fpsample.fps_npdu_kdtree_sampling( self.list_index = fpsample.fps_npdu_kdtree_sampling(
descriptor_dataset_reshaped, descriptor_dataset_reshaped,
self.n_samples, self.n_samples,
) )
return self.list_index return self.list_index
def assemble_descriptors(self) -> None: def assemble_descriptors(self) -> None:
""" """
Assemble the descriptors for all the configurations. Assemble the descriptors for all the configurations.
""" """
self.descriptors_dataset: np.ndarray = 10e10 * np.ones( self.descriptors_dataset: np.ndarray = 10e10 * np.ones(
( (
len(self.atoms_list), len(self.atoms_list),
len(self.species), len(self.species),
len(list(self.atoms_list[0].info["mace_descriptors"].values())[0]), len(list(self.atoms_list[0].info["mace_descriptors"].values())[0]),
), ),
dtype=np.float32, dtype=np.float32,
).astype(np.float32) ).astype(np.float32)
for i, atoms in enumerate(self.atoms_list): for i, atoms in enumerate(self.atoms_list):
descriptors = atoms.info["mace_descriptors"] descriptors = atoms.info["mace_descriptors"]
for z in descriptors: for z in descriptors:
self.descriptors_dataset[i, self.species_dict[z]] = np.array( self.descriptors_dataset[i, self.species_dict[z]] = np.array(
descriptors[z] descriptors[z]
).astype(np.float32) ).astype(np.float32)
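# Conceptual sketch (not part of the original script): the selection delegated to
# `fpsample.fps_npdu_kdtree_sampling` above is farthest point sampling. A minimal
# NumPy version of the same idea, for illustration only:
def _example_farthest_point_sampling(x: np.ndarray, n_samples: int) -> List[int]:
    """Greedy farthest point sampling over the rows of x (shape: n_points x n_features)."""
    chosen = [0]  # start from an arbitrary point
    dist = np.linalg.norm(x - x[0], axis=1)  # distance of each point to the chosen set
    for _ in range(n_samples - 1):
        nxt = int(np.argmax(dist))  # pick the point farthest from the current selection
        chosen.append(nxt)
        dist = np.minimum(dist, np.linalg.norm(x - x[nxt], axis=1))
    return chosen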
def _load_calc( def _load_calc(
model: str, device: str, default_dtype: str, subselect: SubselectType model: str, device: str, default_dtype: str, subselect: SubselectType
) -> Union[MACECalculator, None]: ) -> Union[MACECalculator, None]:
if subselect == SubselectType.RANDOM: if subselect == SubselectType.RANDOM:
return None return None
if model in ["small", "medium", "large"]: if model in ["small", "medium", "large"]:
calc = mace_mp(model, device=device, default_dtype=default_dtype) calc = mace_mp(model, device=device, default_dtype=default_dtype)
else: else:
calc = MACECalculator( calc = MACECalculator(
model_paths=model, model_paths=model,
device=device, device=device,
default_dtype=default_dtype, default_dtype=default_dtype,
) )
return calc return calc
def _get_finetuning_elements( def _get_finetuning_elements(
atoms: List[ase.Atoms], atomic_numbers: List[int] | None atoms: List[ase.Atoms], atomic_numbers: List[int] | None
) -> List[str]: ) -> List[str]:
if atoms: if atoms:
logging.debug( logging.debug(
"Using elements from the finetuning configurations for filtering." "Using elements from the finetuning configurations for filtering."
) )
species = np.unique([x.symbol for atoms in atoms for x in atoms]).tolist() # type: ignore species = np.unique([x.symbol for atoms in atoms for x in atoms]).tolist() # type: ignore
elif atomic_numbers is not None and atomic_numbers: elif atomic_numbers is not None and atomic_numbers:
logging.debug("Using the supplied atomic numbers for filtering.") logging.debug("Using the supplied atomic numbers for filtering.")
species = [ase.data.chemical_symbols[z] for z in atomic_numbers] species = [ase.data.chemical_symbols[z] for z in atomic_numbers]
else: else:
species = [] species = []
return species return species
def _read_finetuning_configs( def _read_finetuning_configs(
configs_ft: Union[str, list[str], None], configs_ft: Union[str, list[str], None],
) -> List[ase.Atoms]: ) -> List[ase.Atoms]:
if isinstance(configs_ft, str): if isinstance(configs_ft, str):
path = configs_ft path = configs_ft
return ase.io.read(path, index=":") # type: ignore return ase.io.read(path, index=":") # type: ignore
if isinstance(configs_ft, list): if isinstance(configs_ft, list):
assert all(isinstance(x, str) for x in configs_ft) assert all(isinstance(x, str) for x in configs_ft)
atoms_list_ft = [] atoms_list_ft = []
for path in configs_ft: for path in configs_ft:
atoms_list_ft += ase.io.read(path, index=":") atoms_list_ft += ase.io.read(path, index=":")
return atoms_list_ft return atoms_list_ft
if configs_ft is None: if configs_ft is None:
return [] return []
raise ValueError(f"Invalid type for configs_ft: {type(configs_ft)}") raise ValueError(f"Invalid type for configs_ft: {type(configs_ft)}")
def _filter_pretraining_data( def _filter_pretraining_data(
atoms: list[ase.Atoms], atoms: list[ase.Atoms],
filtering_type: FilteringType, filtering_type: FilteringType,
all_species_ft: List[str], all_species_ft: List[str],
) -> Tuple[List[ase.Atoms], List[ase.Atoms], list[bool]]: ) -> Tuple[List[ase.Atoms], List[ase.Atoms], list[bool]]:
logging.info( logging.info(
"Filtering configurations based on the finetuning set, " "Filtering configurations based on the finetuning set, "
f"filtering type: {filtering_type}, elements: {all_species_ft}" f"filtering type: {filtering_type}, elements: {all_species_ft}"
) )
passes_filter = [filter_atoms(x, all_species_ft, filtering_type) for x in atoms] passes_filter = [filter_atoms(x, all_species_ft, filtering_type) for x in atoms]
assert len(passes_filter) == len(atoms), "Filtering failed" assert len(passes_filter) == len(atoms), "Filtering failed"
filtered_atoms = [x for x, passes in zip(atoms, passes_filter) if passes] filtered_atoms = [x for x, passes in zip(atoms, passes_filter) if passes]
remaining_atoms = [x for x, passes in zip(atoms, passes_filter) if not passes] remaining_atoms = [x for x, passes in zip(atoms, passes_filter) if not passes]
return filtered_atoms, remaining_atoms, passes_filter return filtered_atoms, remaining_atoms, passes_filter
def _get_random_configs( def _get_random_configs(
num_samples: int, num_samples: int,
atoms: List[ase.Atoms], atoms: List[ase.Atoms],
) -> list[ase.Atoms]: ) -> list[ase.Atoms]:
if num_samples > len(atoms): if num_samples > len(atoms):
raise ValueError( raise ValueError(
f"Requested more samples ({num_samples}) than available in the remaining set ({len(atoms)})" f"Requested more samples ({num_samples}) than available in the remaining set ({len(atoms)})"
) )
indices = np.random.choice(list(range(len(atoms))), num_samples, replace=False) indices = np.random.choice(list(range(len(atoms))), num_samples, replace=False)
return [atoms[i] for i in indices] return [atoms[i] for i in indices]
def _load_descriptors( def _load_descriptors(
atoms: List[ase.Atoms], atoms: List[ase.Atoms],
passes_filter: List[bool], passes_filter: List[bool],
descriptors_path: str | None, descriptors_path: str | None,
calc: MACECalculator | None, calc: MACECalculator | None,
full_data_length: int, full_data_length: int,
) -> None: ) -> None:
if descriptors_path is not None: if descriptors_path is not None:
logging.info(f"Loading descriptors from {descriptors_path}") logging.info(f"Loading descriptors from {descriptors_path}")
descriptors = np.load(descriptors_path, allow_pickle=True) descriptors = np.load(descriptors_path, allow_pickle=True)
assert sum(passes_filter) == len(atoms) assert sum(passes_filter) == len(atoms)
if len(descriptors) != full_data_length: if len(descriptors) != full_data_length:
raise ValueError( raise ValueError(
f"Length of the descriptors ({len(descriptors)}) does not match the length of the data ({full_data_length})" f"Length of the descriptors ({len(descriptors)}) does not match the length of the data ({full_data_length})"
"Please provide descriptors for all configurations" "Please provide descriptors for all configurations"
) )
required_descriptors = [ required_descriptors = [
descriptors[i] for i, passes in enumerate(passes_filter) if passes descriptors[i] for i, passes in enumerate(passes_filter) if passes
] ]
for i, atoms_ in enumerate(atoms): for i, atoms_ in enumerate(atoms):
atoms_.info["mace_descriptors"] = required_descriptors[i] atoms_.info["mace_descriptors"] = required_descriptors[i]
else: else:
logging.info("Calculating descriptors") logging.info("Calculating descriptors")
if calc is None: if calc is None:
raise ValueError("MACECalculator must be provided to calculate descriptors") raise ValueError("MACECalculator must be provided to calculate descriptors")
calculate_descriptors(atoms, calc) calculate_descriptors(atoms, calc)
def _maybe_save_descriptors( def _maybe_save_descriptors(
atoms: List[ase.Atoms], atoms: List[ase.Atoms],
output_path: str, output_path: str,
) -> None: ) -> None:
""" """
Save the descriptors if they are present in the atoms objects. Save the descriptors if they are present in the atoms objects.
Also, delete the descriptors from the atoms objects. Also, delete the descriptors from the atoms objects.
""" """
if all("mace_descriptors" in x.info for x in atoms): if all("mace_descriptors" in x.info for x in atoms):
descriptor_save_path = output_path.replace(".xyz", "_descriptors.npy") descriptor_save_path = output_path.replace(".xyz", "_descriptors.npy")
logging.info(f"Saving descriptors at {descriptor_save_path}") logging.info(f"Saving descriptors at {descriptor_save_path}")
descriptors_list = [x.info["mace_descriptors"] for x in atoms] descriptors_list = [x.info["mace_descriptors"] for x in atoms]
np.save(descriptor_save_path, descriptors_list, allow_pickle=True) np.save(descriptor_save_path, descriptors_list, allow_pickle=True)
for x in atoms: for x in atoms:
del x.info["mace_descriptors"] del x.info["mace_descriptors"]
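# Illustrative note (not part of the original script): for an output path such as
# "selected.xyz", the descriptors are written alongside it as "selected_descriptors.npy";
# that file can later be passed back in as `descriptors_path` to skip recomputation.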
def _maybe_fps(atoms: List[ase.Atoms], num_samples: int) -> List[ase.Atoms]: def _maybe_fps(atoms: List[ase.Atoms], num_samples: int) -> List[ase.Atoms]:
try: try:
fps_pt = FPS(atoms, num_samples) fps_pt = FPS(atoms, num_samples)
idx_pt = fps_pt.run() idx_pt = fps_pt.run()
logging.info(f"Selected {len(idx_pt)} configurations") logging.info(f"Selected {len(idx_pt)} configurations")
return [atoms[i] for i in idx_pt] return [atoms[i] for i in idx_pt]
except Exception as e: # pylint: disable=W0703 except Exception as e: # pylint: disable=W0703
logging.error(f"FPS failed, selecting random configurations instead: {e}") logging.error(f"FPS failed, selecting random configurations instead: {e}")
return _get_random_configs(num_samples, atoms) return _get_random_configs(num_samples, atoms)
def _subsample_data( def _subsample_data(
filtered_atoms: List[ase.Atoms], filtered_atoms: List[ase.Atoms],
remaining_atoms: List[ase.Atoms], remaining_atoms: List[ase.Atoms],
passes_filter: List[bool], passes_filter: List[bool],
num_samples: int | None, num_samples: int | None,
subselect: SubselectType, subselect: SubselectType,
descriptors_path: str | None, descriptors_path: str | None,
calc: MACECalculator | None, calc: MACECalculator | None,
) -> List[ase.Atoms]: ) -> List[ase.Atoms]:
if num_samples is None or num_samples == len(filtered_atoms): if num_samples is None or num_samples == len(filtered_atoms):
logging.info( logging.info(
f"No subsampling, keeping all {len(filtered_atoms)} filtered configurations" f"No subsampling, keeping all {len(filtered_atoms)} filtered configurations"
) )
return filtered_atoms return filtered_atoms
if num_samples > len(filtered_atoms): if num_samples > len(filtered_atoms):
num_sample_randomly = num_samples - len(filtered_atoms) num_sample_randomly = num_samples - len(filtered_atoms)
logging.info( logging.info(
f"Number of configurations after filtering {len(filtered_atoms)} " f"Number of configurations after filtering {len(filtered_atoms)} "
f"is less than the number of samples {num_samples}, " f"is less than the number of samples {num_samples}, "
f"selecting {num_sample_randomly} random configurations for the rest." f"selecting {num_sample_randomly} random configurations for the rest."
) )
return filtered_atoms + _get_random_configs( return filtered_atoms + _get_random_configs(
num_sample_randomly, remaining_atoms num_sample_randomly, remaining_atoms
) )
if num_samples == 0: if num_samples == 0:
raise ValueError("Number of samples must be greater than 0") raise ValueError("Number of samples must be greater than 0")
if subselect == SubselectType.FPS: if subselect == SubselectType.FPS:
_load_descriptors( _load_descriptors(
filtered_atoms, filtered_atoms,
passes_filter, passes_filter,
descriptors_path, descriptors_path,
calc, calc,
full_data_length=len(filtered_atoms) + len(remaining_atoms), full_data_length=len(filtered_atoms) + len(remaining_atoms),
) )
logging.info("Selecting configurations using Farthest Point Sampling") logging.info("Selecting configurations using Farthest Point Sampling")
return _maybe_fps(filtered_atoms, num_samples) return _maybe_fps(filtered_atoms, num_samples)
if subselect == SubselectType.RANDOM: if subselect == SubselectType.RANDOM:
return _get_random_configs(num_samples, filtered_atoms) return _get_random_configs(num_samples, filtered_atoms)
raise ValueError(f"Invalid subselect type: {subselect}") raise ValueError(f"Invalid subselect type: {subselect}")
def _write_metadata( def _write_metadata(
atoms: list[ase.Atoms], pretrained: bool, config_weight: float, head: str | None atoms: list[ase.Atoms], pretrained: bool, config_weight: float, head: str | None
) -> None: ) -> None:
for a in atoms: for a in atoms:
a.info["pretrained"] = pretrained a.info["pretrained"] = pretrained
a.info["config_weight"] = config_weight a.info["config_weight"] = config_weight
if head is not None: if head is not None:
a.info["head"] = head a.info["head"] = head
def select_samples( def select_samples(
settings: SelectionSettings, settings: SelectionSettings,
) -> None: ) -> None:
np.random.seed(settings.seed) np.random.seed(settings.seed)
torch.manual_seed(settings.seed) torch.manual_seed(settings.seed)
calc = _load_calc( calc = _load_calc(
settings.model, settings.device, settings.default_dtype, settings.subselect settings.model, settings.device, settings.default_dtype, settings.subselect
) )
atoms_list_ft = _read_finetuning_configs(settings.configs_ft) atoms_list_ft = _read_finetuning_configs(settings.configs_ft)
all_species_ft = _get_finetuning_elements(atoms_list_ft, settings.atomic_numbers) all_species_ft = _get_finetuning_elements(atoms_list_ft, settings.atomic_numbers)
if settings.filtering_type is not FilteringType.NONE and not all_species_ft: if settings.filtering_type is not FilteringType.NONE and not all_species_ft:
raise ValueError( raise ValueError(
"Filtering types other than NONE require elements for filtering. They can be specified via the `--atomic_numbers` flag." "Filtering types other than NONE require elements for filtering. They can be specified via the `--atomic_numbers` flag."
) )
atoms_list_pt: list[ase.Atoms] = ase.io.read(settings.configs_pt, index=":") # type: ignore atoms_list_pt: list[ase.Atoms] = ase.io.read(settings.configs_pt, index=":") # type: ignore
filtered_pt_atoms, remaining_atoms, passes_filter = _filter_pretraining_data( filtered_pt_atoms, remaining_atoms, passes_filter = _filter_pretraining_data(
atoms_list_pt, settings.filtering_type, all_species_ft atoms_list_pt, settings.filtering_type, all_species_ft
) )
subsampled_atoms = _subsample_data( subsampled_atoms = _subsample_data(
filtered_pt_atoms, filtered_pt_atoms,
remaining_atoms, remaining_atoms,
passes_filter, passes_filter,
settings.num_samples, settings.num_samples,
settings.subselect, settings.subselect,
settings.descriptors, settings.descriptors,
calc, calc,
) )
_maybe_save_descriptors(subsampled_atoms, settings.output) _maybe_save_descriptors(subsampled_atoms, settings.output)
_write_metadata( _write_metadata(
subsampled_atoms, subsampled_atoms,
pretrained=True, pretrained=True,
config_weight=settings.weight_pt, config_weight=settings.weight_pt,
head=settings.head_pt, head=settings.head_pt,
) )
_write_metadata( _write_metadata(
atoms_list_ft, atoms_list_ft,
pretrained=False, pretrained=False,
config_weight=settings.weight_ft, config_weight=settings.weight_ft,
head=settings.head_ft, head=settings.head_ft,
) )
logging.info("Saving the selected configurations") logging.info("Saving the selected configurations")
ase.io.write(settings.output, subsampled_atoms, format="extxyz") ase.io.write(settings.output, subsampled_atoms, format="extxyz")
logging.info("Saving a combined XYZ file") logging.info("Saving a combined XYZ file")
atoms_fps_pt_ft = subsampled_atoms + atoms_list_ft atoms_fps_pt_ft = subsampled_atoms + atoms_list_ft
ase.io.write( ase.io.write(
settings.output.replace(".xyz", "_combined.xyz"), settings.output.replace(".xyz", "_combined.xyz"),
atoms_fps_pt_ft, atoms_fps_pt_ft,
format="extxyz", format="extxyz",
) )
def main(): def main():
args = parse_args() args = parse_args()
settings = SelectionSettings(**vars(args)) settings = SelectionSettings(**vars(args))
select_samples(settings) select_samples(settings)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
import argparse import argparse
import dataclasses import dataclasses
import glob import glob
import json import json
import os import os
import re import re
from typing import List from typing import List
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import pandas as pd import pandas as pd
plt.rcParams.update({"font.size": 8}) plt.rcParams.update({"font.size": 8})
plt.style.use("seaborn-v0_8-paper") plt.style.use("seaborn-v0_8-paper")
colors = [ colors = [
"#1f77b4", # muted blue "#1f77b4", # muted blue
"#d62728", # brick red "#d62728", # brick red
"#ff7f0e", # safety orange "#ff7f0e", # safety orange
"#2ca02c", # cooked asparagus green "#2ca02c", # cooked asparagus green
"#9467bd", # muted purple "#9467bd", # muted purple
"#8c564b", # chestnut brown "#8c564b", # chestnut brown
"#e377c2", # raspberry yogurt pink "#e377c2", # raspberry yogurt pink
"#7f7f7f", # middle gray "#7f7f7f", # middle gray
"#bcbd22", # curry yellow-green "#bcbd22", # curry yellow-green
"#17becf", # blue-teal "#17becf", # blue-teal
] ]
@dataclasses.dataclass @dataclasses.dataclass
class RunInfo: class RunInfo:
name: str name: str
seed: int seed: int
name_re = re.compile(r"(?P<name>.+)_run-(?P<seed>\d+)_train\.txt")
def parse_path(path: str) -> RunInfo: def parse_path(path: str) -> RunInfo:
match = name_re.match(os.path.basename(path)) match = name_re.match(os.path.basename(path))
if not match: if not match:
raise RuntimeError(f"Cannot parse {path}") raise RuntimeError(f"Cannot parse {path}")
return RunInfo(name=match.group("name"), seed=int(match.group("seed"))) return RunInfo(name=match.group("name"), seed=int(match.group("seed")))
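# Illustrative sketch (not part of the original script): result files are expected to
# follow the "<name>_run-<seed>_train.txt" pattern matched by `name_re`, e.g.
def _example_parse_path() -> None:
    info = parse_path("results/mace_water_run-3_train.txt")
    assert info.name == "mace_water"
    assert info.seed == 3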
def parse_training_results(path: str) -> List[dict]: def parse_training_results(path: str) -> List[dict]:
run_info = parse_path(path) run_info = parse_path(path)
results = [] results = []
with open(path, mode="r", encoding="utf-8") as f: with open(path, mode="r", encoding="utf-8") as f:
for line in f: for line in f:
d = json.loads(line) d = json.loads(line)
d["name"] = run_info.name d["name"] = run_info.name
d["seed"] = run_info.seed d["seed"] = run_info.seed
results.append(d) results.append(d)
return results return results
def parse_args() -> argparse.Namespace: def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Plot mace training statistics", description="Plot mace training statistics",
formatter_class=argparse.ArgumentDefaultsHelpFormatter, formatter_class=argparse.ArgumentDefaultsHelpFormatter,
) )
parser.add_argument( parser.add_argument(
"--path", help="Path to results file (.txt) or directory.", required=True "--path", help="Path to results file (.txt) or directory.", required=True
) )
parser.add_argument( parser.add_argument(
"--min_epoch", help="Minimum epoch.", default=0, type=int, required=False "--min_epoch", help="Minimum epoch.", default=0, type=int, required=False
) )
parser.add_argument( parser.add_argument(
"--start_stage_two", "--start_stage_two",
"--start_swa", "--start_swa",
help="Epoch that stage two (swa) loss began. Plots dashed line on plot to indicate. If None then assumed tag not used in training.", help="Epoch that stage two (swa) loss began. Plots dashed line on plot to indicate. If None then assumed tag not used in training.",
default=None, default=None,
type=int, type=int,
required=False, required=False,
dest="start_swa", dest="start_swa",
) )
parser.add_argument( parser.add_argument(
"--linear", "--linear",
help="Whether to plot linear instead of log scales.", help="Whether to plot linear instead of log scales.",
default=False, default=False,
required=False, required=False,
action="store_true", action="store_true",
) )
parser.add_argument( parser.add_argument(
"--error_bars", "--error_bars",
help="Whether to plot standard deviations.", help="Whether to plot standard deviations.",
default=False, default=False,
required=False, required=False,
action="store_true", action="store_true",
) )
parser.add_argument( parser.add_argument(
"--keys", "--keys",
help="Comma-separated list of keys to plot.", help="Comma-separated list of keys to plot.",
default="rmse_e,rmse_f", default="rmse_e,rmse_f",
type=str, type=str,
required=False, required=False,
) )
parser.add_argument( parser.add_argument(
"--output_format", "--output_format",
help="What file type to save plot as", help="What file type to save plot as",
default="png", default="png",
type=str, type=str,
required=False, required=False,
) )
parser.add_argument( parser.add_argument(
"--heads", "--heads",
help="Comma-separated name of the heads used for multihead training", help="Comma-separated name of the heads used for multihead training",
default=None, default=None,
type=str, type=str,
required=False, required=False,
) )
return parser.parse_args() return parser.parse_args()
def plot( def plot(
data: pd.DataFrame, data: pd.DataFrame,
min_epoch: int, min_epoch: int,
output_path: str, output_path: str,
output_format: str, output_format: str,
linear: bool, linear: bool,
start_swa: int, start_swa: int,
error_bars: bool, error_bars: bool,
keys: str, keys: str,
heads: str, heads: str,
) -> None: ) -> None:
""" """
Plots train and validation loss and errors as a function of epoch.
min_epoch: minimum epoch to plot. min_epoch: minimum epoch to plot.
output_path: path to save the plot. output_path: path to save the plot.
output_format: format to save the plot. output_format: format to save the plot.
start_swa: epoch at which stage two (swa) loss begins; a dashed line is plotted there if provided.
error_bars: whether to plot standard deviation of loss. error_bars: whether to plot standard deviation of loss.
linear: whether to plot in linear scale or logscale (default). linear: whether to plot in linear scale or logscale (default).
keys: Values to plot. keys: Values to plot.
heads: Heads used for multihead training. heads: Heads used for multihead training.
""" """
labels = { labels = {
"mae_e": "MAE E [meV]", "mae_e": "MAE E [meV]",
"mae_e_per_atom": "MAE E/atom [meV]", "mae_e_per_atom": "MAE E/atom [meV]",
"rmse_e": "RMSE E [meV]", "rmse_e": "RMSE E [meV]",
"rmse_e_per_atom": "RMSE E/atom [meV]", "rmse_e_per_atom": "RMSE E/atom [meV]",
"q95_e": "Q95 E [meV]", "q95_e": "Q95 E [meV]",
"mae_f": "MAE F [meV / A]", "mae_f": "MAE F [meV / A]",
"rel_mae_f": "Relative MAE F [meV / A]", "rel_mae_f": "Relative MAE F [meV / A]",
"rmse_f": "RMSE F [meV / A]", "rmse_f": "RMSE F [meV / A]",
"rel_rmse_f": "Relative RMSE F [meV / A]", "rel_rmse_f": "Relative RMSE F [meV / A]",
"q95_f": "Q95 F [meV / A]", "q95_f": "Q95 F [meV / A]",
"mae_stress": "MAE Stress", "mae_stress": "MAE Stress",
"rmse_stress": "RMSE Stress [meV / A^3]", "rmse_stress": "RMSE Stress [meV / A^3]",
"rmse_virials_per_atom": " RMSE virials/atom [meV]", "rmse_virials_per_atom": " RMSE virials/atom [meV]",
"mae_virials": "MAE Virials [meV]", "mae_virials": "MAE Virials [meV]",
"rmse_mu_per_atom": "RMSE MU/atom [mDebye]", "rmse_mu_per_atom": "RMSE MU/atom [mDebye]",
} }
data = data[data["epoch"] > min_epoch] data = data[data["epoch"] > min_epoch]
if heads is None: if heads is None:
data = ( data = (
data.groupby(["name", "mode", "epoch"]).agg(["mean", "std"]).reset_index() data.groupby(["name", "mode", "epoch"]).agg(["mean", "std"]).reset_index()
) )
valid_data = data[data["mode"] == "eval"] valid_data = data[data["mode"] == "eval"]
valid_data_dict = {"default": valid_data} valid_data_dict = {"default": valid_data}
train_data = data[data["mode"] == "opt"] train_data = data[data["mode"] == "opt"]
else: else:
heads = heads.split(",") heads = heads.split(",")
# Separate eval and opt data # Separate eval and opt data
valid_data = ( valid_data = (
data[data["mode"] == "eval"] data[data["mode"] == "eval"]
.groupby(["name", "mode", "epoch", "head"]) .groupby(["name", "mode", "epoch", "head"])
.agg(["mean", "std"]) .agg(["mean", "std"])
.reset_index() .reset_index()
) )
train_data = ( train_data = (
data[data["mode"] == "opt"] data[data["mode"] == "opt"]
.groupby(["name", "mode", "epoch"]) .groupby(["name", "mode", "epoch"])
.agg(["mean", "std"]) .agg(["mean", "std"])
.reset_index() .reset_index()
) )
valid_data_dict = { valid_data_dict = {
head: valid_data[valid_data["head"] == head] for head in heads head: valid_data[valid_data["head"] == head] for head in heads
} }
for head, valid_data in valid_data_dict.items(): for head, valid_data in valid_data_dict.items():
fig, axes = plt.subplots( fig, axes = plt.subplots(
nrows=1, ncols=2, figsize=(10, 3), constrained_layout=True nrows=1, ncols=2, figsize=(10, 3), constrained_layout=True
) )
# ---- Plot loss ---- # ---- Plot loss ----
ax = axes[0] ax = axes[0]
ax.plot( ax.plot(
train_data["epoch"], train_data["epoch"],
train_data["loss"]["mean"], train_data["loss"]["mean"],
color=colors[1], color=colors[1],
linewidth=1, linewidth=1,
) )
ax.set_ylabel("Training Loss", color=colors[1]) ax.set_ylabel("Training Loss", color=colors[1])
ax.set_yscale("log") ax.set_yscale("log")
ax2 = ax.twinx() ax2 = ax.twinx()
ax2.plot( ax2.plot(
valid_data["epoch"], valid_data["epoch"],
valid_data["loss"]["mean"], valid_data["loss"]["mean"],
color=colors[0], color=colors[0],
linewidth=1, linewidth=1,
) )
ax2.set_ylabel("Validation Loss", color=colors[0]) ax2.set_ylabel("Validation Loss", color=colors[0])
if not linear: if not linear:
ax.set_yscale("log") ax.set_yscale("log")
ax2.set_yscale("log") ax2.set_yscale("log")
if error_bars: if error_bars:
ax.fill_between( ax.fill_between(
train_data["epoch"], train_data["epoch"],
train_data["loss"]["mean"] - train_data["loss"]["std"], train_data["loss"]["mean"] - train_data["loss"]["std"],
train_data["loss"]["mean"] + train_data["loss"]["std"], train_data["loss"]["mean"] + train_data["loss"]["std"],
alpha=0.3, alpha=0.3,
color=colors[1], color=colors[1],
) )
ax.fill_between( ax.fill_between(
valid_data["epoch"], valid_data["epoch"],
valid_data["loss"]["mean"] - valid_data["loss"]["std"], valid_data["loss"]["mean"] - valid_data["loss"]["std"],
valid_data["loss"]["mean"] + valid_data["loss"]["std"], valid_data["loss"]["mean"] + valid_data["loss"]["std"],
alpha=0.3, alpha=0.3,
color=colors[0], color=colors[0],
) )
if start_swa is not None: if start_swa is not None:
ax.axvline( ax.axvline(
start_swa, start_swa,
color="black", color="black",
linestyle="dashed", linestyle="dashed",
linewidth=1, linewidth=1,
alpha=0.6, alpha=0.6,
label="Stage Two Starts", label="Stage Two Starts",
) )
ax.set_xlabel("Epoch") ax.set_xlabel("Epoch")
ax.set_ylabel("Loss") ax.set_ylabel("Loss")
ax.legend(loc="upper right", fontsize=4) ax.legend(loc="upper right", fontsize=4)
ax.grid(True, linestyle="--", alpha=0.5) ax.grid(True, linestyle="--", alpha=0.5)
# ---- Plot selected keys ---- # ---- Plot selected keys ----
ax = axes[1] ax = axes[1]
twin_axes = [] twin_axes = []
for i, key in enumerate(keys.split(",")): for i, key in enumerate(keys.split(",")):
color = colors[(i + 3)] color = colors[(i + 3)]
label = labels.get(key, key) label = labels.get(key, key)
if i == 0: if i == 0:
main_ax = ax main_ax = ax
else: else:
main_ax = ax.twinx() main_ax = ax.twinx()
main_ax.spines.right.set_position(("outward", 40 * (i - 1))) main_ax.spines.right.set_position(("outward", 40 * (i - 1)))
twin_axes.append(main_ax) twin_axes.append(main_ax)
main_ax.plot( main_ax.plot(
valid_data["epoch"], valid_data["epoch"],
valid_data[key]["mean"] * 1e3, valid_data[key]["mean"] * 1e3,
color=color, color=color,
label=label, label=label,
linewidth=1, linewidth=1,
) )
if error_bars: if error_bars:
main_ax.fill_between( main_ax.fill_between(
valid_data["epoch"], valid_data["epoch"],
(valid_data[key]["mean"] - valid_data[key]["std"]) * 1e3, (valid_data[key]["mean"] - valid_data[key]["std"]) * 1e3,
(valid_data[key]["mean"] + valid_data[key]["std"]) * 1e3, (valid_data[key]["mean"] + valid_data[key]["std"]) * 1e3,
alpha=0.3, alpha=0.3,
color=color, color=color,
) )
main_ax.set_ylabel(label, color=color) main_ax.set_ylabel(label, color=color)
main_ax.tick_params(axis="y", colors=color) main_ax.tick_params(axis="y", colors=color)
if start_swa is not None: if start_swa is not None:
ax.axvline( ax.axvline(
start_swa, start_swa,
color="black", color="black",
linestyle="dashed", linestyle="dashed",
linewidth=1, linewidth=1,
alpha=0.6, alpha=0.6,
label="Stage Two Starts", label="Stage Two Starts",
) )
ax.set_xlabel("Epoch") ax.set_xlabel("Epoch")
ax.set_xlim(left=min_epoch) ax.set_xlim(left=min_epoch)
ax.grid(True, linestyle="--", alpha=0.5) ax.grid(True, linestyle="--", alpha=0.5)
fig.savefig( fig.savefig(
f"{output_path}_{head}.{output_format}", dpi=300, bbox_inches="tight" f"{output_path}_{head}.{output_format}", dpi=300, bbox_inches="tight"
) )
plt.close(fig) plt.close(fig)
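# Illustrative note (not part of the original script): one figure is written per head,
# e.g. "<output_path>_default.png" when --heads is not given, since the single set of
# validation curves is stored under the "default" key above.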
def get_paths(path: str) -> List[str]: def get_paths(path: str) -> List[str]:
if os.path.isfile(path): if os.path.isfile(path):
return [path] return [path]
paths = glob.glob(os.path.join(path, "*_train.txt")) paths = glob.glob(os.path.join(path, "*_train.txt"))
if len(paths) == 0: if len(paths) == 0:
raise RuntimeError(f"Cannot find results in '{path}'") raise RuntimeError(f"Cannot find results in '{path}'")
return paths return paths
def main() -> None: def main() -> None:
args = parse_args() args = parse_args()
run(args) run(args)
def run(args: argparse.Namespace) -> None: def run(args: argparse.Namespace) -> None:
data = pd.DataFrame( data = pd.DataFrame(
results results
for path in get_paths(args.path) for path in get_paths(args.path)
for results in parse_training_results(path) for results in parse_training_results(path)
) )
for name, group in data.groupby("name"): for name, group in data.groupby("name"):
plot( plot(
group, group,
min_epoch=args.min_epoch, min_epoch=args.min_epoch,
output_path=name, output_path=name,
output_format=args.output_format, output_format=args.output_format,
linear=args.linear, linear=args.linear,
start_swa=args.start_swa, start_swa=args.start_swa,
error_bars=args.error_bars, error_bars=args.error_bars,
keys=args.keys, keys=args.keys,
heads=args.heads, heads=args.heads,
) )
if __name__ == "__main__": if __name__ == "__main__":
main() main()
# This file loads an XYZ dataset and prepares a
# new HDF5 file that is ready for training with on-the-fly dataloading.
import argparse import argparse
import ast import ast
import json import json
import logging import logging
import multiprocessing as mp import multiprocessing as mp
import os import os
import random import random
from functools import partial from functools import partial
from glob import glob from glob import glob
from typing import List, Tuple from typing import List, Tuple
import h5py import h5py
import numpy as np import numpy as np
import tqdm import tqdm
from mace import data, tools from mace import data, tools
from mace.data import KeySpecification, update_keyspec_from_kwargs from mace.data import KeySpecification, update_keyspec_from_kwargs
from mace.data.utils import save_configurations_as_HDF5 from mace.data.utils import save_configurations_as_HDF5
from mace.modules import compute_statistics from mace.modules import compute_statistics
from mace.tools import torch_geometric from mace.tools import torch_geometric
from mace.tools.scripts_utils import get_atomic_energies, get_dataset_from_xyz from mace.tools.scripts_utils import get_atomic_energies, get_dataset_from_xyz
from mace.tools.utils import AtomicNumberTable from mace.tools.utils import AtomicNumberTable
def compute_stats_target( def compute_stats_target(
file: str, file: str,
z_table: AtomicNumberTable, z_table: AtomicNumberTable,
r_max: float, r_max: float,
atomic_energies: Tuple, atomic_energies: Tuple,
batch_size: int, batch_size: int,
): ):
train_dataset = data.HDF5Dataset(file, z_table=z_table, r_max=r_max) train_dataset = data.HDF5Dataset(file, z_table=z_table, r_max=r_max)
train_loader = torch_geometric.dataloader.DataLoader( train_loader = torch_geometric.dataloader.DataLoader(
dataset=train_dataset, dataset=train_dataset,
batch_size=batch_size, batch_size=batch_size,
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
) )
avg_num_neighbors, mean, std = compute_statistics(train_loader, atomic_energies) avg_num_neighbors, mean, std = compute_statistics(train_loader, atomic_energies)
output = [avg_num_neighbors, mean, std] output = [avg_num_neighbors, mean, std]
return output return output
def pool_compute_stats(inputs: List): def pool_compute_stats(inputs: List):
path_to_files, z_table, r_max, atomic_energies, batch_size, num_process = inputs path_to_files, z_table, r_max, atomic_energies, batch_size, num_process = inputs
with mp.Pool(processes=num_process) as pool: with mp.Pool(processes=num_process) as pool:
re = [ re = [
pool.apply_async( pool.apply_async(
compute_stats_target, compute_stats_target,
args=( args=(
file, file,
z_table, z_table,
r_max, r_max,
atomic_energies, atomic_energies,
batch_size, batch_size,
), ),
) )
for file in glob(path_to_files + "/*") for file in glob(path_to_files + "/*")
] ]
pool.close() pool.close()
pool.join() pool.join()
results = [r.get() for r in tqdm.tqdm(re)] results = [r.get() for r in tqdm.tqdm(re)]
if not results: if not results:
raise ValueError( raise ValueError(
"No results were computed. Check if the input files exist and are readable." "No results were computed. Check if the input files exist and are readable."
) )
# Separate avg_num_neighbors, mean, and std # Separate avg_num_neighbors, mean, and std
avg_num_neighbors = np.mean([r[0] for r in results]) avg_num_neighbors = np.mean([r[0] for r in results])
means = np.array([r[1] for r in results]) means = np.array([r[1] for r in results])
stds = np.array([r[2] for r in results]) stds = np.array([r[2] for r in results])
# Compute averages # Compute averages
mean = np.mean(means, axis=0).item() mean = np.mean(means, axis=0).item()
std = np.mean(stds, axis=0).item() std = np.mean(stds, axis=0).item()
return avg_num_neighbors, mean, std return avg_num_neighbors, mean, std
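# Illustrative note (not part of the original script): the per-file statistics are
# combined by plain unweighted averaging above, e.g. avg_num_neighbors values of
# [22.0, 24.0] across two shards become 23.0; this is exact only when the HDF5
# shards hold similar numbers of configurations.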
def split_array(a: np.ndarray, max_size: int): def split_array(a: np.ndarray, max_size: int):
drop_last = False drop_last = False
if len(a) % 2 == 1: if len(a) % 2 == 1:
a = np.append(a, a[-1]) a = np.append(a, a[-1])
drop_last = True drop_last = True
factors = get_prime_factors(len(a)) factors = get_prime_factors(len(a))
max_factor = 1 max_factor = 1
for i in range(1, len(factors) + 1): for i in range(1, len(factors) + 1):
for j in range(0, len(factors) - i + 1): for j in range(0, len(factors) - i + 1):
if np.prod(factors[j : j + i]) <= max_size: if np.prod(factors[j : j + i]) <= max_size:
test = np.prod(factors[j : j + i]) test = np.prod(factors[j : j + i])
max_factor = max(test, max_factor) max_factor = max(test, max_factor)
return np.array_split(a, max_factor), drop_last return np.array_split(a, max_factor), drop_last
def get_prime_factors(n: int): def get_prime_factors(n: int):
factors = [] factors = []
for i in range(2, n + 1): for i in range(2, n + 1):
while n % i == 0: while n % i == 0:
factors.append(i) factors.append(i)
n = n // i
return factors return factors
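# Worked example (not part of the original script): split_array picks the largest chunk
# count that divides the (possibly padded) array evenly and does not exceed `max_size`,
# using contiguous runs of prime factors. For 12 configurations and max_size=5:
# factors = [2, 2, 3]; contiguous-run products are 2, 2, 3, 4, 6, 12; the largest one
# not exceeding 5 is 4, so the data is split into 4 chunks of 3. Odd-length inputs are
# first padded by repeating the last element, and drop_last is returned as True to flag
# that duplicate.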
# Define task for multiprocessing
def multi_train_hdf5(process, args, split_train, drop_last): def multi_train_hdf5(process, args, split_train, drop_last):
with h5py.File(args.h5_prefix + "train/train_" + str(process) + ".h5", "w") as f: with h5py.File(args.h5_prefix + "train/train_" + str(process) + ".h5", "w") as f:
f.attrs["drop_last"] = drop_last f.attrs["drop_last"] = drop_last
save_configurations_as_HDF5(split_train[process], process, f) save_configurations_as_HDF5(split_train[process], process, f)
def multi_valid_hdf5(process, args, split_valid, drop_last): def multi_valid_hdf5(process, args, split_valid, drop_last):
with h5py.File(args.h5_prefix + "val/val_" + str(process) + ".h5", "w") as f: with h5py.File(args.h5_prefix + "val/val_" + str(process) + ".h5", "w") as f:
f.attrs["drop_last"] = drop_last f.attrs["drop_last"] = drop_last
save_configurations_as_HDF5(split_valid[process], process, f) save_configurations_as_HDF5(split_valid[process], process, f)
def multi_test_hdf5(process, name, args, split_test, drop_last): def multi_test_hdf5(process, name, args, split_test, drop_last):
with h5py.File( with h5py.File(
args.h5_prefix + "test/" + name + "_" + str(process) + ".h5", "w" args.h5_prefix + "test/" + name + "_" + str(process) + ".h5", "w"
) as f: ) as f:
f.attrs["drop_last"] = drop_last f.attrs["drop_last"] = drop_last
save_configurations_as_HDF5(split_test[process], process, f) save_configurations_as_HDF5(split_test[process], process, f)
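# Illustrative note (not part of the original script): with num_process=2 and an
# h5_prefix of "data/", the writers above produce data/train/train_0.h5,
# data/train/train_1.h5, data/val/val_0.h5, data/val/val_1.h5 and, per test set name,
# data/test/<name>_0.h5 and data/test/<name>_1.h5, each carrying a "drop_last" attribute.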
def main() -> None: def main() -> None:
""" """
This script loads an XYZ dataset and prepares a
new HDF5 file that is ready for training with on-the-fly dataloading.
""" """
args = tools.build_preprocess_arg_parser().parse_args() args = tools.build_preprocess_arg_parser().parse_args()
run(args) run(args)
def run(args: argparse.Namespace): def run(args: argparse.Namespace):
""" """
This script loads an XYZ dataset and prepares a
new HDF5 file that is ready for training with on-the-fly dataloading.
""" """
# currently support only command line property_key syntax # currently support only command line property_key syntax
args.key_specification = KeySpecification() args.key_specification = KeySpecification()
update_keyspec_from_kwargs(args.key_specification, vars(args)) update_keyspec_from_kwargs(args.key_specification, vars(args))
# Setup # Setup
tools.set_seeds(args.seed) tools.set_seeds(args.seed)
random.seed(args.seed) random.seed(args.seed)
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
format="%(asctime)s %(levelname)-8s %(message)s", format="%(asctime)s %(levelname)-8s %(message)s",
datefmt="%Y-%m-%d %H:%M:%S", datefmt="%Y-%m-%d %H:%M:%S",
handlers=[logging.StreamHandler()], handlers=[logging.StreamHandler()],
) )
try: try:
config_type_weights = ast.literal_eval(args.config_type_weights) config_type_weights = ast.literal_eval(args.config_type_weights)
assert isinstance(config_type_weights, dict) assert isinstance(config_type_weights, dict)
except Exception as e: # pylint: disable=W0703 except Exception as e: # pylint: disable=W0703
logging.warning( logging.warning(
f"Config type weights not specified correctly ({e}), using Default" f"Config type weights not specified correctly ({e}), using Default"
) )
config_type_weights = {"Default": 1.0} config_type_weights = {"Default": 1.0}
folders = ["train", "val", "test"] folders = ["train", "val", "test"]
for sub_dir in folders: for sub_dir in folders:
if not os.path.exists(args.h5_prefix + sub_dir): if not os.path.exists(args.h5_prefix + sub_dir):
os.makedirs(args.h5_prefix + sub_dir) os.makedirs(args.h5_prefix + sub_dir)
# Data preparation # Data preparation
collections, atomic_energies_dict = get_dataset_from_xyz( collections, atomic_energies_dict = get_dataset_from_xyz(
work_dir=args.work_dir, work_dir=args.work_dir,
train_path=args.train_file, train_path=args.train_file,
valid_path=args.valid_file, valid_path=args.valid_file,
valid_fraction=args.valid_fraction, valid_fraction=args.valid_fraction,
config_type_weights=config_type_weights, config_type_weights=config_type_weights,
test_path=args.test_file, test_path=args.test_file,
seed=args.seed, seed=args.seed,
key_specification=args.key_specification, key_specification=args.key_specification,
head_name=None, head_name=None,
) )
# Atomic number table # Atomic number table
# yapf: disable # yapf: disable
if args.atomic_numbers is None: if args.atomic_numbers is None:
z_table = tools.get_atomic_number_table_from_zs( z_table = tools.get_atomic_number_table_from_zs(
z z
for configs in (collections.train, collections.valid) for configs in (collections.train, collections.valid)
for config in configs for config in configs
for z in config.atomic_numbers for z in config.atomic_numbers
) )
else: else:
logging.info("Using atomic numbers from command line argument") logging.info("Using atomic numbers from command line argument")
zs_list = ast.literal_eval(args.atomic_numbers) zs_list = ast.literal_eval(args.atomic_numbers)
assert isinstance(zs_list, list) assert isinstance(zs_list, list)
z_table = tools.get_atomic_number_table_from_zs(zs_list) z_table = tools.get_atomic_number_table_from_zs(zs_list)
logging.info("Preparing training set") logging.info("Preparing training set")
if args.shuffle: if args.shuffle:
random.shuffle(collections.train) random.shuffle(collections.train)
# split collections.train into batches and save them to hdf5 # split collections.train into batches and save them to hdf5
split_train = np.array_split(collections.train, args.num_process)
drop_last = False drop_last = False
if len(collections.train) % 2 == 1: if len(collections.train) % 2 == 1:
drop_last = True drop_last = True
multi_train_hdf5_ = partial(multi_train_hdf5, args=args, split_train=split_train, drop_last=drop_last) multi_train_hdf5_ = partial(multi_train_hdf5, args=args, split_train=split_train, drop_last=drop_last)
processes = [] processes = []
for i in range(args.num_process): for i in range(args.num_process):
p = mp.Process(target=multi_train_hdf5_, args=[i]) p = mp.Process(target=multi_train_hdf5_, args=[i])
p.start() p.start()
processes.append(p) processes.append(p)
for i in processes: for i in processes:
i.join() i.join()
if args.compute_statistics: if args.compute_statistics:
logging.info("Computing statistics") logging.info("Computing statistics")
if len(atomic_energies_dict) == 0: if len(atomic_energies_dict) == 0:
atomic_energies_dict = get_atomic_energies(args.E0s, collections.train, z_table) atomic_energies_dict = get_atomic_energies(args.E0s, collections.train, z_table)
# Remove atomic energies if element not in z_table # Remove atomic energies if element not in z_table
removed_atomic_energies = {} removed_atomic_energies = {}
for z in list(atomic_energies_dict): for z in list(atomic_energies_dict):
if z not in z_table.zs: if z not in z_table.zs:
removed_atomic_energies[z] = atomic_energies_dict.pop(z) removed_atomic_energies[z] = atomic_energies_dict.pop(z)
if len(removed_atomic_energies) > 0: if len(removed_atomic_energies) > 0:
logging.warning("Atomic energies for elements not present in the atomic number table have been removed.") logging.warning("Atomic energies for elements not present in the atomic number table have been removed.")
logging.warning(f"Removed atomic energies (eV): {str(removed_atomic_energies)}") logging.warning(f"Removed atomic energies (eV): {str(removed_atomic_energies)}")
logging.warning("To include these elements in the model, specify all atomic numbers explicitly using the --atomic_numbers argument.") logging.warning("To include these elements in the model, specify all atomic numbers explicitly using the --atomic_numbers argument.")
atomic_energies: np.ndarray = np.array( atomic_energies: np.ndarray = np.array(
[atomic_energies_dict[z] for z in z_table.zs] [atomic_energies_dict[z] for z in z_table.zs]
) )
logging.info(f"Atomic Energies: {atomic_energies.tolist()}") logging.info(f"Atomic Energies: {atomic_energies.tolist()}")
_inputs = [args.h5_prefix + "train", z_table, args.r_max, atomic_energies, args.batch_size, args.num_process]
avg_num_neighbors, mean, std = pool_compute_stats(_inputs)
logging.info(f"Average number of neighbors: {avg_num_neighbors}") logging.info(f"Average number of neighbors: {avg_num_neighbors}")
logging.info(f"Mean: {mean}") logging.info(f"Mean: {mean}")
logging.info(f"Standard deviation: {std}") logging.info(f"Standard deviation: {std}")
# save the statistics as a json # save the statistics as a json
statistics = { statistics = {
"atomic_energies": str(atomic_energies_dict), "atomic_energies": str(atomic_energies_dict),
"avg_num_neighbors": avg_num_neighbors, "avg_num_neighbors": avg_num_neighbors,
"mean": mean, "mean": mean,
"std": std, "std": std,
"atomic_numbers": str([int(z) for z in z_table.zs]), "atomic_numbers": str([int(z) for z in z_table.zs]),
"r_max": args.r_max, "r_max": args.r_max,
} }
with open(args.h5_prefix + "statistics.json", "w") as f: # pylint: disable=W1514 with open(args.h5_prefix + "statistics.json", "w") as f: # pylint: disable=W1514
json.dump(statistics, f) json.dump(statistics, f)
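# Illustrative sketch (not part of the original script) of the resulting
# statistics.json written above (values are made up):
#   {"atomic_energies": "{1: -13.6, 8: -432.1}", "avg_num_neighbors": 23.4,
#    "mean": 0.12, "std": 1.7, "atomic_numbers": "[1, 8]", "r_max": 5.0}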
logging.info("Preparing validation set") logging.info("Preparing validation set")
if args.shuffle: if args.shuffle:
random.shuffle(collections.valid) random.shuffle(collections.valid)
split_valid = np.array_split(collections.valid, args.num_process) split_valid = np.array_split(collections.valid, args.num_process)
drop_last = False drop_last = False
if len(collections.valid) % 2 == 1: if len(collections.valid) % 2 == 1:
drop_last = True drop_last = True
multi_valid_hdf5_ = partial(multi_valid_hdf5, args=args, split_valid=split_valid, drop_last=drop_last) multi_valid_hdf5_ = partial(multi_valid_hdf5, args=args, split_valid=split_valid, drop_last=drop_last)
processes = [] processes = []
for i in range(args.num_process): for i in range(args.num_process):
p = mp.Process(target=multi_valid_hdf5_, args=[i]) p = mp.Process(target=multi_valid_hdf5_, args=[i])
p.start() p.start()
processes.append(p) processes.append(p)
for i in processes: for i in processes:
i.join() i.join()
if args.test_file is not None: if args.test_file is not None:
logging.info("Preparing test sets") logging.info("Preparing test sets")
for name, subset in collections.tests: for name, subset in collections.tests:
drop_last = False drop_last = False
if len(subset) % 2 == 1: if len(subset) % 2 == 1:
drop_last = True drop_last = True
split_test = np.array_split(subset, args.num_process) split_test = np.array_split(subset, args.num_process)
multi_test_hdf5_ = partial(multi_test_hdf5, args=args, split_test=split_test, drop_last=drop_last) multi_test_hdf5_ = partial(multi_test_hdf5, args=args, split_test=split_test, drop_last=drop_last)
processes = [] processes = []
for i in range(args.num_process): for i in range(args.num_process):
p = mp.Process(target=multi_test_hdf5_, args=[i, name]) p = mp.Process(target=multi_test_hdf5_, args=[i, name])
p.start() p.start()
processes.append(p) processes.append(p)
for i in processes: for i in processes:
i.join() i.join()
if __name__ == "__main__": if __name__ == "__main__":
main() main()
########################################################################################### ###########################################################################################
# Training script for MACE # Training script for MACE
# Authors: Ilyes Batatia, Gregor Simm, David Kovacs # Authors: Ilyes Batatia, Gregor Simm, David Kovacs
# This program is distributed under the MIT License (see MIT.md) # This program is distributed under the MIT License (see MIT.md)
########################################################################################### ###########################################################################################
import ast import ast
import glob import glob
import json import json
import logging import logging
import os import os
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
from typing import List, Optional from typing import List, Optional
import torch.distributed import torch.distributed
import torch.nn.functional import torch.nn.functional
from e3nn.util import jit from e3nn.util import jit
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import LBFGS from torch.optim import LBFGS
from torch.utils.data import ConcatDataset from torch.utils.data import ConcatDataset
from torch_ema import ExponentialMovingAverage from torch_ema import ExponentialMovingAverage
import mace import mace
from mace import data, tools from mace import data, tools
from mace.calculators.foundations_models import mace_mp, mace_off from mace.calculators.foundations_models import mace_mp, mace_off
from mace.cli.convert_cueq_e3nn import run as run_cueq_to_e3nn from mace.cli.convert_cueq_e3nn import run as run_cueq_to_e3nn
from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq
from mace.cli.visualise_train import TrainingPlotter from mace.cli.visualise_train import TrainingPlotter
from mace.data import KeySpecification, update_keyspec_from_kwargs from mace.data import KeySpecification, update_keyspec_from_kwargs
from mace.tools import torch_geometric from mace.tools import torch_geometric
from mace.tools.model_script_utils import configure_model from mace.tools.model_script_utils import configure_model
from mace.tools.multihead_tools import ( from mace.tools.multihead_tools import (
HeadConfig, HeadConfig,
assemble_mp_data, assemble_mp_data,
dict_head_to_dataclass, dict_head_to_dataclass,
prepare_default_head, prepare_default_head,
prepare_pt_head, prepare_pt_head,
) )
from mace.tools.run_train_utils import ( from mace.tools.run_train_utils import (
combine_datasets, combine_datasets,
load_dataset_for_path, load_dataset_for_path,
normalize_file_paths, normalize_file_paths,
) )
from mace.tools.scripts_utils import ( from mace.tools.scripts_utils import (
LRScheduler, LRScheduler,
SubsetCollection, SubsetCollection,
check_path_ase_read, check_path_ase_read,
convert_to_json_format, convert_to_json_format,
dict_to_array, dict_to_array,
extract_config_mace_model, extract_config_mace_model,
get_atomic_energies, get_atomic_energies,
get_avg_num_neighbors, get_avg_num_neighbors,
get_config_type_weights, get_config_type_weights,
get_dataset_from_xyz, get_dataset_from_xyz,
get_files_with_suffix, get_files_with_suffix,
get_loss_fn, get_loss_fn,
get_optimizer, get_optimizer,
get_params_options, get_params_options,
get_swa, get_swa,
print_git_commit, print_git_commit,
remove_pt_head, remove_pt_head,
setup_wandb, setup_wandb,
) )
from mace.tools.slurm_distributed import DistributedEnvironment from mace.tools.slurm_distributed import DistributedEnvironment
from mace.tools.tables_utils import create_error_table from mace.tools.tables_utils import create_error_table
from mace.tools.utils import AtomicNumberTable from mace.tools.utils import AtomicNumberTable
def main() -> None: def main() -> None:
""" """
This script runs the training/fine tuning for mace This script runs the training/fine tuning for mace
""" """
args = tools.build_default_arg_parser().parse_args() args = tools.build_default_arg_parser().parse_args()
run(args) run(args)
def run(args) -> None: def run(args) -> None:
""" """
This script runs the training/fine tuning for mace This script runs the training/fine tuning for mace
""" """
tag = tools.get_tag(name=args.name, seed=args.seed) tag = tools.get_tag(name=args.name, seed=args.seed)
args, input_log_messages = tools.check_args(args) args, input_log_messages = tools.check_args(args)
# default keyspec to update using heads dictionary # default keyspec to update using heads dictionary
args.key_specification = KeySpecification() args.key_specification = KeySpecification()
update_keyspec_from_kwargs(args.key_specification, vars(args)) update_keyspec_from_kwargs(args.key_specification, vars(args))
if args.device == "xpu": if args.device == "xpu":
try: try:
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
except ImportError as e: except ImportError as e:
raise ImportError( raise ImportError(
"Error: Intel extension for PyTorch not found, but XPU device was specified" "Error: Intel extension for PyTorch not found, but XPU device was specified"
) from e ) from e
if args.distributed: if args.distributed:
try: try:
distr_env = DistributedEnvironment() distr_env = DistributedEnvironment()
except Exception as e: # pylint: disable=W0703 except Exception as e: # pylint: disable=W0703
logging.error(f"Failed to initialize distributed environment: {e}") logging.error(f"Failed to initialize distributed environment: {e}")
return return
world_size = distr_env.world_size world_size = distr_env.world_size
local_rank = distr_env.local_rank local_rank = distr_env.local_rank
rank = distr_env.rank rank = distr_env.rank
if rank == 0: if rank == 0:
print(distr_env) print(distr_env)
torch.distributed.init_process_group(backend="nccl") torch.distributed.init_process_group(backend="nccl")
else: else:
rank = int(0) rank = int(0)
# Setup # Setup
tools.set_seeds(args.seed) tools.set_seeds(args.seed)
tools.setup_logger(level=args.log_level, tag=tag, directory=args.log_dir, rank=rank) tools.setup_logger(level=args.log_level, tag=tag, directory=args.log_dir, rank=rank)
logging.info("===========VERIFYING SETTINGS===========") logging.info("===========VERIFYING SETTINGS===========")
for message, loglevel in input_log_messages: for message, loglevel in input_log_messages:
logging.log(level=loglevel, msg=message) logging.log(level=loglevel, msg=message)
if args.distributed: if args.distributed:
torch.cuda.set_device(local_rank) torch.cuda.set_device(local_rank)
logging.info(f"Process group initialized: {torch.distributed.is_initialized()}") logging.info(f"Process group initialized: {torch.distributed.is_initialized()}")
logging.info(f"Processes: {world_size}") logging.info(f"Processes: {world_size}")
try: try:
logging.info(f"MACE version: {mace.__version__}") logging.info(f"MACE version: {mace.__version__}")
except AttributeError: except AttributeError:
logging.info("Cannot find MACE version, please install MACE via pip") logging.info("Cannot find MACE version, please install MACE via pip")
logging.debug(f"Configuration: {args}") logging.debug(f"Configuration: {args}")
tools.set_default_dtype(args.default_dtype) tools.set_default_dtype(args.default_dtype)
device = tools.init_device(args.device) device = tools.init_device(args.device)
commit = print_git_commit() commit = print_git_commit()
model_foundation: Optional[torch.nn.Module] = None model_foundation: Optional[torch.nn.Module] = None
foundation_model_avg_num_neighbors = 0 foundation_model_avg_num_neighbors = 0
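# Optionally load a foundation model to fine-tune from (bundled mace-mp-0 / mace-off-2023 checkpoints or a local model file); its cutoff r_max and average neighbour count are reused below.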
if args.foundation_model is not None: if args.foundation_model is not None:
if args.foundation_model in ["small", "medium", "large"]: if args.foundation_model in ["small", "medium", "large"]:
logging.info( logging.info(
f"Using foundation model mace-mp-0 {args.foundation_model} as initial checkpoint." f"Using foundation model mace-mp-0 {args.foundation_model} as initial checkpoint."
) )
calc = mace_mp( calc = mace_mp(
model=args.foundation_model, model=args.foundation_model,
device=args.device, device=args.device,
default_dtype=args.default_dtype, default_dtype=args.default_dtype,
) )
model_foundation = calc.models[0] model_foundation = calc.models[0]
elif args.foundation_model in ["small_off", "medium_off", "large_off"]: elif args.foundation_model in ["small_off", "medium_off", "large_off"]:
model_type = args.foundation_model.split("_")[0] model_type = args.foundation_model.split("_")[0]
logging.info( logging.info(
f"Using foundation model mace-off-2023 {model_type} as initial checkpoint. ASL license." f"Using foundation model mace-off-2023 {model_type} as initial checkpoint. ASL license."
) )
calc = mace_off( calc = mace_off(
model=model_type, model=model_type,
device=args.device, device=args.device,
default_dtype=args.default_dtype, default_dtype=args.default_dtype,
) )
model_foundation = calc.models[0] model_foundation = calc.models[0]
else: else:
model_foundation = torch.load( model_foundation = torch.load(
args.foundation_model, map_location=args.device args.foundation_model, map_location=args.device
) )
logging.info( logging.info(
f"Using foundation model {args.foundation_model} as initial checkpoint." f"Using foundation model {args.foundation_model} as initial checkpoint."
) )
args.r_max = model_foundation.r_max.item() args.r_max = model_foundation.r_max.item()
foundation_model_avg_num_neighbors = model_foundation.interactions[ foundation_model_avg_num_neighbors = model_foundation.interactions[
0 0
].avg_num_neighbors ].avg_num_neighbors
if ( if (
args.foundation_model not in ["small", "medium", "large"] args.foundation_model not in ["small", "medium", "large"]
and args.pt_train_file is None and args.pt_train_file is None
): ):
logging.warning(
"Multiheads finetuning with a foundation model that is not a Materials Project model requires a pretraining file passed via --pt_train_file. Disabling multiheads finetuning."
)
args.multiheads_finetuning = False args.multiheads_finetuning = False
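# Multihead fine-tuning defaults: unless --force_mh_ft_lr is set, lower the learning rate, enable EMA, and keep only a single head of the foundation model.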
if args.multiheads_finetuning: if args.multiheads_finetuning:
assert ( assert (
args.E0s != "average" args.E0s != "average"
), "average atomic energies cannot be used for multiheads finetuning" ), "average atomic energies cannot be used for multiheads finetuning"
# check that the foundation model has a single head, if not, use the first head # check that the foundation model has a single head, if not, use the first head
if not args.force_mh_ft_lr: if not args.force_mh_ft_lr:
logging.info( logging.info(
"Multihead finetuning mode, setting learning rate to 0.0001 and EMA to True. To use a different learning rate, set --force_mh_ft_lr=True." "Multihead finetuning mode, setting learning rate to 0.0001 and EMA to True. To use a different learning rate, set --force_mh_ft_lr=True."
) )
args.lr = 0.0001 args.lr = 0.0001
args.ema = True args.ema = True
args.ema_decay = 0.99999 args.ema_decay = 0.99999
logging.info( logging.info(
"Using multiheads finetuning mode, setting learning rate to 0.0001 and EMA to True" "Using multiheads finetuning mode, setting learning rate to 0.0001 and EMA to True"
) )
if hasattr(model_foundation, "heads"): if hasattr(model_foundation, "heads"):
if len(model_foundation.heads) > 1: if len(model_foundation.heads) > 1:
logging.warning(
"Multihead finetuning with models that have more than one head is not supported; using the first head as the foundation head."
)
model_foundation = remove_pt_head( model_foundation = remove_pt_head(
model_foundation, args.foundation_head model_foundation, args.foundation_head
) )
else: else:
args.multiheads_finetuning = False args.multiheads_finetuning = False
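# Build the per-head configuration dictionary: each head gets its own key specification, falling back to a single default head when --heads is not given.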
if args.heads is not None: if args.heads is not None:
args.heads = ast.literal_eval(args.heads) args.heads = ast.literal_eval(args.heads)
for _, head_dict in args.heads.items(): for _, head_dict in args.heads.items():
# priority is global args < head property_key values < head info_keys+arrays_keys # priority is global args < head property_key values < head info_keys+arrays_keys
head_keyspec = deepcopy(args.key_specification) head_keyspec = deepcopy(args.key_specification)
update_keyspec_from_kwargs(head_keyspec, head_dict) update_keyspec_from_kwargs(head_keyspec, head_dict)
head_keyspec.update( head_keyspec.update(
info_keys=head_dict.get("info_keys", {}), info_keys=head_dict.get("info_keys", {}),
arrays_keys=head_dict.get("arrays_keys", {}), arrays_keys=head_dict.get("arrays_keys", {}),
) )
head_dict["key_specification"] = head_keyspec head_dict["key_specification"] = head_keyspec
else: else:
args.heads = prepare_default_head(args) args.heads = prepare_default_head(args)
if args.multiheads_finetuning: if args.multiheads_finetuning:
pt_keyspec = ( pt_keyspec = (
args.heads["pt_head"]["key_specification"] args.heads["pt_head"]["key_specification"]
if "pt_head" in args.heads if "pt_head" in args.heads
else deepcopy(args.key_specification) else deepcopy(args.key_specification)
) )
args.heads["pt_head"] = prepare_pt_head( args.heads["pt_head"] = prepare_pt_head(
args, pt_keyspec, foundation_model_avg_num_neighbors args, pt_keyspec, foundation_model_avg_num_neighbors
) )
logging.info("===========LOADING INPUT DATA===========") logging.info("===========LOADING INPUT DATA===========")
heads = list(args.heads.keys()) heads = list(args.heads.keys())
logging.info(f"Using heads: {heads}") logging.info(f"Using heads: {heads}")
logging.info("Using the key specifications to parse data:") logging.info("Using the key specifications to parse data:")
for name, head_dict in args.heads.items(): for name, head_dict in args.heads.items():
head_keyspec = head_dict["key_specification"] head_keyspec = head_dict["key_specification"]
logging.info(f"{name}: {head_keyspec}") logging.info(f"{name}: {head_keyspec}")
head_configs: List[HeadConfig] = [] head_configs: List[HeadConfig] = []
for head, head_args in args.heads.items(): for head, head_args in args.heads.items():
logging.info(f"============= Processing head {head} ===========") logging.info(f"============= Processing head {head} ===========")
head_config = dict_head_to_dataclass(head_args, head, args) head_config = dict_head_to_dataclass(head_args, head, args)
# Handle train_file and valid_file - normalize to lists # Handle train_file and valid_file - normalize to lists
if hasattr(head_config, "train_file") and head_config.train_file is not None: if hasattr(head_config, "train_file") and head_config.train_file is not None:
head_config.train_file = normalize_file_paths(head_config.train_file) head_config.train_file = normalize_file_paths(head_config.train_file)
if hasattr(head_config, "valid_file") and head_config.valid_file is not None: if hasattr(head_config, "valid_file") and head_config.valid_file is not None:
head_config.valid_file = normalize_file_paths(head_config.valid_file) head_config.valid_file = normalize_file_paths(head_config.valid_file)
if hasattr(head_config, "test_file") and head_config.test_file is not None: if hasattr(head_config, "test_file") and head_config.test_file is not None:
head_config.test_file = normalize_file_paths(head_config.test_file) head_config.test_file = normalize_file_paths(head_config.test_file)
if ( if (
head_config.statistics_file is not None head_config.statistics_file is not None
and head_config.head_name != "pt_head" and head_config.head_name != "pt_head"
): ):
with open(head_config.statistics_file, "r") as f: # pylint: disable=W1514 with open(head_config.statistics_file, "r") as f: # pylint: disable=W1514
statistics = json.load(f) statistics = json.load(f)
logging.info("Using statistics json file") logging.info("Using statistics json file")
head_config.atomic_numbers = statistics["atomic_numbers"] head_config.atomic_numbers = statistics["atomic_numbers"]
head_config.mean = statistics["mean"] head_config.mean = statistics["mean"]
head_config.std = statistics["std"] head_config.std = statistics["std"]
head_config.avg_num_neighbors = statistics["avg_num_neighbors"] head_config.avg_num_neighbors = statistics["avg_num_neighbors"]
head_config.compute_avg_num_neighbors = False head_config.compute_avg_num_neighbors = False
if isinstance(statistics["atomic_energies"], str) and statistics[ if isinstance(statistics["atomic_energies"], str) and statistics[
"atomic_energies" "atomic_energies"
].endswith(".json"): ].endswith(".json"):
with open(statistics["atomic_energies"], "r", encoding="utf-8") as f: with open(statistics["atomic_energies"], "r", encoding="utf-8") as f:
atomic_energies = json.load(f) atomic_energies = json.load(f)
head_config.E0s = atomic_energies head_config.E0s = atomic_energies
head_config.atomic_energies_dict = ast.literal_eval(atomic_energies) head_config.atomic_energies_dict = ast.literal_eval(atomic_energies)
else: else:
head_config.E0s = statistics["atomic_energies"] head_config.E0s = statistics["atomic_energies"]
head_config.atomic_energies_dict = ast.literal_eval( head_config.atomic_energies_dict = ast.literal_eval(
statistics["atomic_energies"] statistics["atomic_energies"]
) )
if head_config.train_file == ["mp"]: if head_config.train_file == ["mp"]:
assert ( assert (
head_config.head_name == "pt_head" head_config.head_name == "pt_head"
), "Only pt_head should use mp as train_file" ), "Only pt_head should use mp as train_file"
logging.info( logging.info(
"Using the full Materials Project data for replay. You can construct a different subset using `fine_tuning_select.py` script." "Using the full Materials Project data for replay. You can construct a different subset using `fine_tuning_select.py` script."
) )
collections = assemble_mp_data(args, head_config, tag) collections = assemble_mp_data(args, head_config, tag)
head_config.collections = collections head_config.collections = collections
elif any(check_path_ase_read(f) for f in head_config.train_file): elif any(check_path_ase_read(f) for f in head_config.train_file):
train_files_ase_list = [ train_files_ase_list = [
f for f in head_config.train_file if check_path_ase_read(f) f for f in head_config.train_file if check_path_ase_read(f)
] ]
valid_files_ase_list = None valid_files_ase_list = None
test_files_ase_list = None test_files_ase_list = None
if head_config.valid_file: if head_config.valid_file:
valid_files_ase_list = [ valid_files_ase_list = [
f for f in head_config.valid_file if check_path_ase_read(f) f for f in head_config.valid_file if check_path_ase_read(f)
] ]
if head_config.test_file: if head_config.test_file:
test_files_ase_list = [ test_files_ase_list = [
f for f in head_config.test_file if check_path_ase_read(f) f for f in head_config.test_file if check_path_ase_read(f)
] ]
config_type_weights = get_config_type_weights( config_type_weights = get_config_type_weights(
head_config.config_type_weights head_config.config_type_weights
) )
collections, atomic_energies_dict = get_dataset_from_xyz( collections, atomic_energies_dict = get_dataset_from_xyz(
work_dir=args.work_dir, work_dir=args.work_dir,
train_path=train_files_ase_list, train_path=train_files_ase_list,
valid_path=valid_files_ase_list, valid_path=valid_files_ase_list,
valid_fraction=head_config.valid_fraction, valid_fraction=head_config.valid_fraction,
config_type_weights=config_type_weights, config_type_weights=config_type_weights,
test_path=test_files_ase_list, test_path=test_files_ase_list,
seed=args.seed, seed=args.seed,
key_specification=head_config.key_specification, key_specification=head_config.key_specification,
head_name=head_config.head_name, head_name=head_config.head_name,
keep_isolated_atoms=head_config.keep_isolated_atoms, keep_isolated_atoms=head_config.keep_isolated_atoms,
) )
head_config.collections = SubsetCollection( head_config.collections = SubsetCollection(
train=collections.train, train=collections.train,
valid=collections.valid, valid=collections.valid,
tests=collections.tests, tests=collections.tests,
) )
head_config.atomic_energies_dict = atomic_energies_dict head_config.atomic_energies_dict = atomic_energies_dict
logging.info( logging.info(
f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}, " f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}, "
f"tests=[{', '.join([name + ': ' + str(len(test_configs)) for name, test_configs in collections.tests])}]," f"tests=[{', '.join([name + ': ' + str(len(test_configs)) for name, test_configs in collections.tests])}],"
) )
head_configs.append(head_config) head_configs.append(head_config)
if all( if all(
check_path_ase_read(head_config.train_file[0]) for head_config in head_configs check_path_ase_read(head_config.train_file[0]) for head_config in head_configs
): ):
size_collections_train = sum( size_collections_train = sum(
len(head_config.collections.train) for head_config in head_configs len(head_config.collections.train) for head_config in head_configs
) )
size_collections_valid = sum( size_collections_valid = sum(
len(head_config.collections.valid) for head_config in head_configs len(head_config.collections.valid) for head_config in head_configs
) )
if size_collections_train < args.batch_size: if size_collections_train < args.batch_size:
logging.error( logging.error(
f"Batch size ({args.batch_size}) is larger than the number of training data ({size_collections_train})" f"Batch size ({args.batch_size}) is larger than the number of training data ({size_collections_train})"
) )
if size_collections_valid < args.valid_batch_size: if size_collections_valid < args.valid_batch_size:
logging.warning( logging.warning(
f"Validation batch size ({args.valid_batch_size}) is larger than the number of validation data ({size_collections_valid})" f"Validation batch size ({args.valid_batch_size}) is larger than the number of validation data ({size_collections_valid})"
) )
if args.multiheads_finetuning: if args.multiheads_finetuning:
logging.info( logging.info(
"==================Using multiheads finetuning mode==================" "==================Using multiheads finetuning mode=================="
) )
args.loss = "universal" args.loss = "universal"
all_ase_readable = all( all_ase_readable = all(
all(check_path_ase_read(f) for f in head_config.train_file) all(check_path_ase_read(f) for f in head_config.train_file)
for head_config in head_configs for head_config in head_configs
) )
head_config_pt = filter(lambda x: x.head_name == "pt_head", head_configs) head_config_pt = filter(lambda x: x.head_name == "pt_head", head_configs)
head_config_pt = next(head_config_pt, None) head_config_pt = next(head_config_pt, None)
assert head_config_pt is not None, "Pretraining head not found" assert head_config_pt is not None, "Pretraining head not found"
if all_ase_readable: if all_ase_readable:
ratio_pt_ft = size_collections_train / len(head_config_pt.collections.train) ratio_pt_ft = size_collections_train / len(head_config_pt.collections.train)
if ratio_pt_ft < 0.1: if ratio_pt_ft < 0.1:
logging.warning(
f"Ratio of the number of configurations in the training set to that in the pt_train_file is {ratio_pt_ft}, "
f"replicating the fine-tuning configurations {int(0.1 / ratio_pt_ft)} additional times to rebalance"
)
for head_config in head_configs: for head_config in head_configs:
if head_config.head_name == "pt_head": if head_config.head_name == "pt_head":
continue continue
head_config.collections.train += ( head_config.collections.train += (
head_config.collections.train * int(0.1 / ratio_pt_ft) head_config.collections.train * int(0.1 / ratio_pt_ft)
) )
logging.info( logging.info(
f"Total number of configurations in pretraining: train={len(head_config_pt.collections.train)}, valid={len(head_config_pt.collections.valid)}" f"Total number of configurations in pretraining: train={len(head_config_pt.collections.train)}, valid={len(head_config_pt.collections.valid)}"
) )
else: else:
logging.debug( logging.debug(
"Using LMDB/HDF5 datasets for pretraining or fine-tuning - skipping ratio check" "Using LMDB/HDF5 datasets for pretraining or fine-tuning - skipping ratio check"
) )
# Atomic number table # Atomic number table
# yapf: disable # yapf: disable
for head_config in head_configs: for head_config in head_configs:
if head_config.atomic_numbers is None: if head_config.atomic_numbers is None:
assert all(check_path_ase_read(f) for f in head_config.train_file), "Must specify atomic_numbers when using .h5 or .aselmdb train_file input" assert all(check_path_ase_read(f) for f in head_config.train_file), "Must specify atomic_numbers when using .h5 or .aselmdb train_file input"
z_table_head = tools.get_atomic_number_table_from_zs( z_table_head = tools.get_atomic_number_table_from_zs(
z z
for configs in (head_config.collections.train, head_config.collections.valid) for configs in (head_config.collections.train, head_config.collections.valid)
for config in configs for config in configs
for z in config.atomic_numbers for z in config.atomic_numbers
) )
head_config.atomic_numbers = z_table_head.zs head_config.atomic_numbers = z_table_head.zs
head_config.z_table = z_table_head head_config.z_table = z_table_head
else: else:
if head_config.statistics_file is None: if head_config.statistics_file is None:
logging.info("Using atomic numbers from command line argument") logging.info("Using atomic numbers from command line argument")
else: else:
logging.info("Using atomic numbers from statistics file") logging.info("Using atomic numbers from statistics file")
zs_list = ast.literal_eval(head_config.atomic_numbers) zs_list = ast.literal_eval(head_config.atomic_numbers)
assert isinstance(zs_list, list) assert isinstance(zs_list, list)
z_table_head = tools.AtomicNumberTable(zs_list) z_table_head = tools.AtomicNumberTable(zs_list)
head_config.atomic_numbers = zs_list head_config.atomic_numbers = zs_list
head_config.z_table = z_table_head head_config.z_table = z_table_head
# yapf: enable # yapf: enable
all_atomic_numbers = set() all_atomic_numbers = set()
for head_config in head_configs: for head_config in head_configs:
all_atomic_numbers.update(head_config.atomic_numbers) all_atomic_numbers.update(head_config.atomic_numbers)
z_table = AtomicNumberTable(sorted(list(all_atomic_numbers))) z_table = AtomicNumberTable(sorted(list(all_atomic_numbers)))
if args.foundation_model_elements and model_foundation: if args.foundation_model_elements and model_foundation:
z_table = AtomicNumberTable(sorted(model_foundation.atomic_numbers.tolist())) z_table = AtomicNumberTable(sorted(model_foundation.atomic_numbers.tolist()))
logging.info(f"Atomic Numbers used: {z_table.zs}") logging.info(f"Atomic Numbers used: {z_table.zs}")
# Atomic energies # Atomic energies
atomic_energies_dict = {} atomic_energies_dict = {}
for head_config in head_configs: for head_config in head_configs:
if head_config.atomic_energies_dict is None or len(head_config.atomic_energies_dict) == 0: if head_config.atomic_energies_dict is None or len(head_config.atomic_energies_dict) == 0:
assert head_config.E0s is not None, "Atomic energies must be provided" assert head_config.E0s is not None, "Atomic energies must be provided"
if all(check_path_ase_read(f) for f in head_config.train_file) and head_config.E0s.lower() != "foundation": if all(check_path_ase_read(f) for f in head_config.train_file) and head_config.E0s.lower() != "foundation":
atomic_energies_dict[head_config.head_name] = get_atomic_energies( atomic_energies_dict[head_config.head_name] = get_atomic_energies(
head_config.E0s, head_config.collections.train, head_config.z_table head_config.E0s, head_config.collections.train, head_config.z_table
) )
elif head_config.E0s.lower() == "foundation": elif head_config.E0s.lower() == "foundation":
assert args.foundation_model is not None assert args.foundation_model is not None
z_table_foundation = AtomicNumberTable( z_table_foundation = AtomicNumberTable(
[int(z) for z in model_foundation.atomic_numbers] [int(z) for z in model_foundation.atomic_numbers]
) )
foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies
if foundation_atomic_energies.ndim > 1: if foundation_atomic_energies.ndim > 1:
foundation_atomic_energies = foundation_atomic_energies.squeeze() foundation_atomic_energies = foundation_atomic_energies.squeeze()
if foundation_atomic_energies.ndim == 2: if foundation_atomic_energies.ndim == 2:
foundation_atomic_energies = foundation_atomic_energies[0] foundation_atomic_energies = foundation_atomic_energies[0]
logging.info("Foundation model has multiple heads, using the first head as foundation E0s.") logging.info("Foundation model has multiple heads, using the first head as foundation E0s.")
atomic_energies_dict[head_config.head_name] = { atomic_energies_dict[head_config.head_name] = {
z: foundation_atomic_energies[ z: foundation_atomic_energies[
z_table_foundation.z_to_index(z) z_table_foundation.z_to_index(z)
].item() ].item()
for z in z_table.zs for z in z_table.zs
} }
else: else:
atomic_energies_dict[head_config.head_name] = get_atomic_energies(head_config.E0s, None, head_config.z_table) atomic_energies_dict[head_config.head_name] = get_atomic_energies(head_config.E0s, None, head_config.z_table)
else: else:
atomic_energies_dict[head_config.head_name] = head_config.atomic_energies_dict atomic_energies_dict[head_config.head_name] = head_config.atomic_energies_dict
# Atomic energies for multiheads finetuning # Atomic energies for multiheads finetuning
if args.multiheads_finetuning: if args.multiheads_finetuning:
assert ( assert (
model_foundation is not None model_foundation is not None
), "Model foundation must be provided for multiheads finetuning" ), "Model foundation must be provided for multiheads finetuning"
z_table_foundation = AtomicNumberTable( z_table_foundation = AtomicNumberTable(
[int(z) for z in model_foundation.atomic_numbers] [int(z) for z in model_foundation.atomic_numbers]
) )
foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies
if foundation_atomic_energies.ndim > 1: if foundation_atomic_energies.ndim > 1:
foundation_atomic_energies = foundation_atomic_energies.squeeze() foundation_atomic_energies = foundation_atomic_energies.squeeze()
if foundation_atomic_energies.ndim == 2: if foundation_atomic_energies.ndim == 2:
foundation_atomic_energies = foundation_atomic_energies[0] foundation_atomic_energies = foundation_atomic_energies[0]
logging.info("Foundation model has multiple heads, using the first head as foundation E0s.") logging.info("Foundation model has multiple heads, using the first head as foundation E0s.")
atomic_energies_dict["pt_head"] = { atomic_energies_dict["pt_head"] = {
z: foundation_atomic_energies[ z: foundation_atomic_energies[
z_table_foundation.z_to_index(z) z_table_foundation.z_to_index(z)
].item() ].item()
for z in z_table.zs for z in z_table.zs
} }
heads = sorted(heads, key=lambda x: -1000 if x == "pt_head" else 0) heads = sorted(heads, key=lambda x: -1000 if x == "pt_head" else 0)
# Padding atomic energies if keeping all elements of the foundation model # Padding atomic energies if keeping all elements of the foundation model
if args.foundation_model_elements and model_foundation: if args.foundation_model_elements and model_foundation:
atomic_energies_dict_padded = {} atomic_energies_dict_padded = {}
for head_name, head_energies in atomic_energies_dict.items(): for head_name, head_energies in atomic_energies_dict.items():
energy_head_padded = {} energy_head_padded = {}
for z in z_table.zs: for z in z_table.zs:
energy_head_padded[z] = head_energies.get(z, 0.0) energy_head_padded[z] = head_energies.get(z, 0.0)
atomic_energies_dict_padded[head_name] = energy_head_padded atomic_energies_dict_padded[head_name] = energy_head_padded
atomic_energies_dict = atomic_energies_dict_padded atomic_energies_dict = atomic_energies_dict_padded
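# Choose which targets the selected architecture predicts (energy, forces, dipoles, virials, stress) before building the model.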
if args.model == "AtomicDipolesMACE": if args.model == "AtomicDipolesMACE":
atomic_energies = None atomic_energies = None
dipole_only = True dipole_only = True
args.compute_dipole = True args.compute_dipole = True
args.compute_energy = False args.compute_energy = False
args.compute_forces = False args.compute_forces = False
args.compute_virials = False args.compute_virials = False
args.compute_stress = False args.compute_stress = False
else: else:
dipole_only = False dipole_only = False
if args.model == "EnergyDipolesMACE": if args.model == "EnergyDipolesMACE":
args.compute_dipole = True args.compute_dipole = True
args.compute_energy = True args.compute_energy = True
args.compute_forces = True args.compute_forces = True
args.compute_virials = False args.compute_virials = False
args.compute_stress = False args.compute_stress = False
else: else:
args.compute_energy = True args.compute_energy = True
args.compute_dipole = False args.compute_dipole = False
# atomic_energies: np.ndarray = np.array( # atomic_energies: np.ndarray = np.array(
# [atomic_energies_dict[z] for z in z_table.zs] # [atomic_energies_dict[z] for z in z_table.zs]
# ) # )
atomic_energies = dict_to_array(atomic_energies_dict, heads) atomic_energies = dict_to_array(atomic_energies_dict, heads)
for head_config in head_configs: for head_config in head_configs:
try: try:
logging.info(f"Atomic Energies used (z: eV) for head {head_config.head_name}: " + "{" + ", ".join([f"{z}: {atomic_energies_dict[head_config.head_name][z]}" for z in head_config.z_table.zs]) + "}") logging.info(f"Atomic Energies used (z: eV) for head {head_config.head_name}: " + "{" + ", ".join([f"{z}: {atomic_energies_dict[head_config.head_name][z]}" for z in head_config.z_table.zs]) + "}")
except KeyError as e: except KeyError as e:
raise KeyError(f"Atomic number {e} not found in atomic_energies_dict for head {head_config.head_name}, add E0s for this atomic number") from e raise KeyError(f"Atomic number {e} not found in atomic_energies_dict for head {head_config.head_name}, add E0s for this atomic number") from e
# Load datasets for each head, supporting multiple files per head # Load datasets for each head, supporting multiple files per head
valid_sets = {head: [] for head in heads} valid_sets = {head: [] for head in heads}
train_sets = {head: [] for head in heads} train_sets = {head: [] for head in heads}
for head_config in head_configs: for head_config in head_configs:
train_datasets = [] train_datasets = []
logging.info(f"Processing datasets for head '{head_config.head_name}'") logging.info(f"Processing datasets for head '{head_config.head_name}'")
ase_files = [f for f in head_config.train_file if check_path_ase_read(f)] ase_files = [f for f in head_config.train_file if check_path_ase_read(f)]
non_ase_files = [f for f in head_config.train_file if not check_path_ase_read(f)] non_ase_files = [f for f in head_config.train_file if not check_path_ase_read(f)]
if ase_files: if ase_files:
dataset = load_dataset_for_path( dataset = load_dataset_for_path(
file_path=ase_files, file_path=ase_files,
r_max=args.r_max, r_max=args.r_max,
z_table=z_table, z_table=z_table,
head_config=head_config, head_config=head_config,
heads=heads, heads=heads,
collection=head_config.collections.train, collection=head_config.collections.train,
) )
train_datasets.append(dataset) train_datasets.append(dataset)
logging.debug(f"Successfully loaded dataset from ASE files: {ase_files}") logging.debug(f"Successfully loaded dataset from ASE files: {ase_files}")
for file in non_ase_files: for file in non_ase_files:
dataset = load_dataset_for_path( dataset = load_dataset_for_path(
file_path=file, file_path=file,
r_max=args.r_max, r_max=args.r_max,
z_table=z_table, z_table=z_table,
head_config=head_config, head_config=head_config,
heads=heads, heads=heads,
) )
train_datasets.append(dataset) train_datasets.append(dataset)
logging.debug(f"Successfully loaded dataset from non-ASE file: {file}") logging.debug(f"Successfully loaded dataset from non-ASE file: {file}")
if not train_datasets: if not train_datasets:
raise ValueError(f"No valid training datasets found for head {head_config.head_name}") raise ValueError(f"No valid training datasets found for head {head_config.head_name}")
train_sets[head_config.head_name] = combine_datasets(train_datasets, head_config.head_name) train_sets[head_config.head_name] = combine_datasets(train_datasets, head_config.head_name)
if head_config.valid_file: if head_config.valid_file:
valid_datasets = [] valid_datasets = []
valid_ase_files = [f for f in head_config.valid_file if check_path_ase_read(f)] valid_ase_files = [f for f in head_config.valid_file if check_path_ase_read(f)]
valid_non_ase_files = [f for f in head_config.valid_file if not check_path_ase_read(f)] valid_non_ase_files = [f for f in head_config.valid_file if not check_path_ase_read(f)]
if valid_ase_files: if valid_ase_files:
valid_dataset = load_dataset_for_path( valid_dataset = load_dataset_for_path(
file_path=valid_ase_files, file_path=valid_ase_files,
r_max=args.r_max, r_max=args.r_max,
z_table=z_table, z_table=z_table,
head_config=head_config, head_config=head_config,
heads=heads, heads=heads,
collection=head_config.collections.valid, collection=head_config.collections.valid,
) )
valid_datasets.append(valid_dataset) valid_datasets.append(valid_dataset)
logging.debug(f"Successfully loaded validation dataset from ASE files: {valid_ase_files}") logging.debug(f"Successfully loaded validation dataset from ASE files: {valid_ase_files}")
for valid_file in valid_non_ase_files: for valid_file in valid_non_ase_files:
valid_dataset = load_dataset_for_path( valid_dataset = load_dataset_for_path(
file_path=valid_file, file_path=valid_file,
r_max=args.r_max, r_max=args.r_max,
z_table=z_table, z_table=z_table,
head_config=head_config, head_config=head_config,
heads=heads, heads=heads,
) )
valid_datasets.append(valid_dataset) valid_datasets.append(valid_dataset)
logging.debug(f"Successfully loaded validation dataset from {valid_file}") logging.debug(f"Successfully loaded validation dataset from {valid_file}")
# Combine validation datasets # Combine validation datasets
if valid_datasets: if valid_datasets:
valid_sets[head_config.head_name] = combine_datasets(valid_datasets, f"{head_config.head_name}_valid") valid_sets[head_config.head_name] = combine_datasets(valid_datasets, f"{head_config.head_name}_valid")
logging.info(f"Combined validation datasets for {head_config.head_name}") logging.info(f"Combined validation datasets for {head_config.head_name}")
# If no valid_file is provided but the collection exists, use the validation set from the collection
if head_config.valid_file is None and head_config.collections.valid: if head_config.valid_file is None and head_config.collections.valid:
valid_sets[head_config.head_name] = [ valid_sets[head_config.head_name] = [
data.AtomicData.from_config( data.AtomicData.from_config(
config, z_table=z_table, cutoff=args.r_max, heads=heads config, z_table=z_table, cutoff=args.r_max, heads=heads
) )
for config in head_config.collections.valid for config in head_config.collections.valid
] ]
if not valid_sets[head_config.head_name]: if not valid_sets[head_config.head_name]:
raise ValueError(f"No valid datasets found for head {head_config.head_name}, please provide a valid_file or a valid_fraction") raise ValueError(f"No valid datasets found for head {head_config.head_name}, please provide a valid_file or a valid_fraction")
# Create data loader for this head # Create data loader for this head
dataset_size = len(train_sets[head_config.head_name])
logging.info(f"Head '{head_config.head_name}' training dataset size: {dataset_size}") logging.info(f"Head '{head_config.head_name}' training dataset size: {dataset_size}")
train_loader_head = torch_geometric.dataloader.DataLoader( train_loader_head = torch_geometric.dataloader.DataLoader(
dataset=train_sets[head_config.head_name], dataset=train_sets[head_config.head_name],
batch_size=args.batch_size, batch_size=args.batch_size,
shuffle=True, shuffle=True,
drop_last=(not args.lbfgs), drop_last=(not args.lbfgs),
pin_memory=args.pin_memory, pin_memory=args.pin_memory,
num_workers=args.num_workers, num_workers=args.num_workers,
generator=torch.Generator().manual_seed(args.seed), generator=torch.Generator().manual_seed(args.seed),
) )
head_config.train_loader = train_loader_head head_config.train_loader = train_loader_head
# concatenate all the trainsets # concatenate all the trainsets
train_set = ConcatDataset([train_sets[head] for head in heads]) train_set = ConcatDataset([train_sets[head] for head in heads])
train_sampler, valid_sampler = None, None train_sampler, valid_sampler = None, None
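# With --distributed, wrap the datasets in DistributedSampler so every rank draws a disjoint shard of the training and validation data.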
if args.distributed: if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler( train_sampler = torch.utils.data.distributed.DistributedSampler(
train_set, train_set,
num_replicas=world_size, num_replicas=world_size,
rank=rank, rank=rank,
shuffle=True, shuffle=True,
drop_last=(not args.lbfgs), drop_last=(not args.lbfgs),
seed=args.seed, seed=args.seed,
) )
valid_samplers = {} valid_samplers = {}
for head, valid_set in valid_sets.items(): for head, valid_set in valid_sets.items():
valid_sampler = torch.utils.data.distributed.DistributedSampler( valid_sampler = torch.utils.data.distributed.DistributedSampler(
valid_set, valid_set,
num_replicas=world_size, num_replicas=world_size,
rank=rank, rank=rank,
shuffle=True, shuffle=True,
drop_last=True, drop_last=True,
seed=args.seed, seed=args.seed,
) )
valid_samplers[head] = valid_sampler valid_samplers[head] = valid_sampler
train_loader = torch_geometric.dataloader.DataLoader( train_loader = torch_geometric.dataloader.DataLoader(
dataset=train_set, dataset=train_set,
batch_size=args.batch_size, batch_size=args.batch_size,
sampler=train_sampler, sampler=train_sampler,
shuffle=(train_sampler is None), shuffle=(train_sampler is None),
drop_last=(train_sampler is None and not args.lbfgs), drop_last=(train_sampler is None and not args.lbfgs),
pin_memory=args.pin_memory, pin_memory=args.pin_memory,
num_workers=args.num_workers, num_workers=args.num_workers,
generator=torch.Generator().manual_seed(args.seed), generator=torch.Generator().manual_seed(args.seed),
) )
valid_loaders = {head: None for head in heads}
if not isinstance(valid_sets, dict): if not isinstance(valid_sets, dict):
valid_sets = {"Default": valid_sets} valid_sets = {"Default": valid_sets}
for head, valid_set in valid_sets.items(): for head, valid_set in valid_sets.items():
valid_loaders[head] = torch_geometric.dataloader.DataLoader( valid_loaders[head] = torch_geometric.dataloader.DataLoader(
dataset=valid_set, dataset=valid_set,
batch_size=args.valid_batch_size, batch_size=args.valid_batch_size,
sampler=valid_samplers[head] if args.distributed else None, sampler=valid_samplers[head] if args.distributed else None,
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
pin_memory=args.pin_memory, pin_memory=args.pin_memory,
num_workers=args.num_workers, num_workers=args.num_workers,
generator=torch.Generator().manual_seed(args.seed), generator=torch.Generator().manual_seed(args.seed),
) )
loss_fn = get_loss_fn(args, dipole_only, args.compute_dipole) loss_fn = get_loss_fn(args, dipole_only, args.compute_dipole)
args.avg_num_neighbors = get_avg_num_neighbors(head_configs, args, train_loader, device) args.avg_num_neighbors = get_avg_num_neighbors(head_configs, args, train_loader, device)
# Model # Model
model, output_args = configure_model(args, train_loader, atomic_energies, model_foundation, heads, z_table, head_configs) model, output_args = configure_model(args, train_loader, atomic_energies, model_foundation, heads, z_table, head_configs)
model.to(device) model.to(device)
logging.debug(model) logging.debug(model)
logging.info(f"Total number of parameters: {tools.count_parameters(model)}") logging.info(f"Total number of parameters: {tools.count_parameters(model)}")
logging.info("") logging.info("")
logging.info("===========OPTIMIZER INFORMATION===========") logging.info("===========OPTIMIZER INFORMATION===========")
logging.info(f"Using {args.optimizer.upper()} as parameter optimizer") logging.info(f"Using {args.optimizer.upper()} as parameter optimizer")
logging.info(f"Batch size: {args.batch_size}") logging.info(f"Batch size: {args.batch_size}")
if args.ema: if args.ema:
logging.info(f"Using Exponential Moving Average with decay: {args.ema_decay}") logging.info(f"Using Exponential Moving Average with decay: {args.ema_decay}")
logging.info( logging.info(
f"Number of gradient updates: {int(args.max_num_epochs*len(train_set)/args.batch_size)}" f"Number of gradient updates: {int(args.max_num_epochs*len(train_set)/args.batch_size)}"
) )
logging.info(f"Learning rate: {args.lr}, weight decay: {args.weight_decay}") logging.info(f"Learning rate: {args.lr}, weight decay: {args.weight_decay}")
logging.info(loss_fn) logging.info(loss_fn)
# Cueq # Cueq
if args.enable_cueq: if args.enable_cueq:
logging.info("Converting model to CUEQ for accelerated training") logging.info("Converting model to CUEQ for accelerated training")
assert model.__class__.__name__ in ["MACE", "ScaleShiftMACE"] assert model.__class__.__name__ in ["MACE", "ScaleShiftMACE"]
model = run_e3nn_to_cueq(deepcopy(model), device=device) model = run_e3nn_to_cueq(deepcopy(model), device=device)
# Optimizer # Optimizer
param_options = get_params_options(args, model) param_options = get_params_options(args, model)
optimizer: torch.optim.Optimizer optimizer: torch.optim.Optimizer
optimizer = get_optimizer(args, param_options) optimizer = get_optimizer(args, param_options)
if args.device == "xpu": if args.device == "xpu":
logging.info("Optimzing model and optimzier for XPU") logging.info("Optimzing model and optimzier for XPU")
model, optimizer = ipex.optimize(model, optimizer=optimizer) model, optimizer = ipex.optimize(model, optimizer=optimizer)
logger = tools.MetricsLogger( logger = tools.MetricsLogger(
directory=args.results_dir, tag=tag + "_train" directory=args.results_dir, tag=tag + "_train"
) # pylint: disable=E1123 ) # pylint: disable=E1123
lr_scheduler = LRScheduler(optimizer, args) lr_scheduler = LRScheduler(optimizer, args)
swa: Optional[tools.SWAContainer] = None swa: Optional[tools.SWAContainer] = None
swas = [False] swas = [False]
if args.swa: if args.swa:
swa, swas = get_swa(args, model, optimizer, swas, dipole_only) swa, swas = get_swa(args, model, optimizer, swas, dipole_only)
checkpoint_handler = tools.CheckpointHandler( checkpoint_handler = tools.CheckpointHandler(
directory=args.checkpoints_dir, directory=args.checkpoints_dir,
tag=tag, tag=tag,
keep=args.keep_checkpoints, keep=args.keep_checkpoints,
swa_start=args.start_swa, swa_start=args.start_swa,
) )
start_epoch = 0 start_epoch = 0
restart_lbfgs = False restart_lbfgs = False
opt_start_epoch = None opt_start_epoch = None
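# Optionally resume from the most recent checkpoint, trying the SWA (stage two) state first and falling back to the stage one state.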
if args.restart_latest: if args.restart_latest:
try: try:
opt_start_epoch = checkpoint_handler.load_latest( opt_start_epoch = checkpoint_handler.load_latest(
state=tools.CheckpointState(model, optimizer, lr_scheduler), state=tools.CheckpointState(model, optimizer, lr_scheduler),
swa=True, swa=True,
device=device, device=device,
) )
except Exception: # pylint: disable=W0703 except Exception: # pylint: disable=W0703
try: try:
opt_start_epoch = checkpoint_handler.load_latest( opt_start_epoch = checkpoint_handler.load_latest(
state=tools.CheckpointState(model, optimizer, lr_scheduler), state=tools.CheckpointState(model, optimizer, lr_scheduler),
swa=False, swa=False,
device=device, device=device,
) )
except Exception: # pylint: disable=W0703 except Exception: # pylint: disable=W0703
restart_lbfgs = True restart_lbfgs = True
if opt_start_epoch is not None: if opt_start_epoch is not None:
start_epoch = opt_start_epoch start_epoch = opt_start_epoch
ema: Optional[ExponentialMovingAverage] = None ema: Optional[ExponentialMovingAverage] = None
if args.ema: if args.ema:
ema = ExponentialMovingAverage(model.parameters(), decay=args.ema_decay) ema = ExponentialMovingAverage(model.parameters(), decay=args.ema_decay)
else: else:
for group in optimizer.param_groups: for group in optimizer.param_groups:
group["lr"] = args.lr group["lr"] = args.lr
if args.lbfgs: if args.lbfgs:
logging.info("Switching optimizer to LBFGS") logging.info("Switching optimizer to LBFGS")
optimizer = LBFGS(model.parameters(), optimizer = LBFGS(model.parameters(),
history_size=200, history_size=200,
max_iter=20, max_iter=20,
line_search_fn="strong_wolfe") line_search_fn="strong_wolfe")
if restart_lbfgs: if restart_lbfgs:
opt_start_epoch = checkpoint_handler.load_latest( opt_start_epoch = checkpoint_handler.load_latest(
state=tools.CheckpointState(model, optimizer, lr_scheduler), state=tools.CheckpointState(model, optimizer, lr_scheduler),
swa=False, swa=False,
device=device, device=device,
) )
if opt_start_epoch is not None: if opt_start_epoch is not None:
start_epoch = opt_start_epoch start_epoch = opt_start_epoch
if args.wandb: if args.wandb:
setup_wandb(args) setup_wandb(args)
if args.distributed: if args.distributed:
distributed_model = DDP(model, device_ids=[local_rank]) distributed_model = DDP(model, device_ids=[local_rank])
else: else:
distributed_model = None distributed_model = None
train_valid_data_loader = {} train_valid_data_loader = {}
for head_config in head_configs: for head_config in head_configs:
data_loader_name = "train_" + head_config.head_name data_loader_name = "train_" + head_config.head_name
train_valid_data_loader[data_loader_name] = head_config.train_loader train_valid_data_loader[data_loader_name] = head_config.train_loader
for head, valid_loader in valid_loaders.items(): for head, valid_loader in valid_loaders.items():
data_load_name = "valid_" + head data_load_name = "valid_" + head
train_valid_data_loader[data_load_name] = valid_loader train_valid_data_loader[data_load_name] = valid_loader
if args.plot and args.plot_frequency > 0: if args.plot and args.plot_frequency > 0:
try: try:
plotter = TrainingPlotter( plotter = TrainingPlotter(
results_dir=logger.path, results_dir=logger.path,
heads=heads, heads=heads,
table_type=args.error_table, table_type=args.error_table,
train_valid_data=train_valid_data_loader, train_valid_data=train_valid_data_loader,
test_data={}, test_data={},
output_args=output_args, output_args=output_args,
device=device, device=device,
plot_frequency=args.plot_frequency, plot_frequency=args.plot_frequency,
distributed=args.distributed, distributed=args.distributed,
swa_start=swa.start if swa else None swa_start=swa.start if swa else None
) )
except Exception as e: # pylint: disable=W0718 except Exception as e: # pylint: disable=W0718
logging.debug(f"Creating Plotter failed: {e}") logging.debug(f"Creating Plotter failed: {e}")
else: else:
plotter = None plotter = None
if args.dry_run: if args.dry_run:
logging.info("DRY RUN mode enabled. Stopping now.") logging.info("DRY RUN mode enabled. Stopping now.")
return return
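# Run the training loop: optimisation, evaluation, checkpointing and EMA/SWA handling are delegated to tools.train.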
tools.train( tools.train(
model=model, model=model,
loss_fn=loss_fn, loss_fn=loss_fn,
train_loader=train_loader, train_loader=train_loader,
valid_loaders=valid_loaders, valid_loaders=valid_loaders,
optimizer=optimizer, optimizer=optimizer,
lr_scheduler=lr_scheduler, lr_scheduler=lr_scheduler,
checkpoint_handler=checkpoint_handler, checkpoint_handler=checkpoint_handler,
eval_interval=args.eval_interval, eval_interval=args.eval_interval,
start_epoch=start_epoch, start_epoch=start_epoch,
max_num_epochs=args.max_num_epochs, max_num_epochs=args.max_num_epochs,
logger=logger, logger=logger,
patience=args.patience, patience=args.patience,
save_all_checkpoints=args.save_all_checkpoints, save_all_checkpoints=args.save_all_checkpoints,
output_args=output_args, output_args=output_args,
device=device, device=device,
swa=swa, swa=swa,
ema=ema, ema=ema,
max_grad_norm=args.clip_grad, max_grad_norm=args.clip_grad,
log_errors=args.error_table, log_errors=args.error_table,
log_wandb=args.wandb, log_wandb=args.wandb,
distributed=args.distributed, distributed=args.distributed,
distributed_model=distributed_model, distributed_model=distributed_model,
plotter=plotter, plotter=plotter,
train_sampler=train_sampler, train_sampler=train_sampler,
rank=rank, rank=rank,
) )
logging.info("") logging.info("")
logging.info("===========RESULTS===========") logging.info("===========RESULTS===========")
train_valid_data_loader = {} train_valid_data_loader = {}
for head_config in head_configs: for head_config in head_configs:
data_loader_name = "train_" + head_config.head_name data_loader_name = "train_" + head_config.head_name
train_valid_data_loader[data_loader_name] = head_config.train_loader train_valid_data_loader[data_loader_name] = head_config.train_loader
for head, valid_loader in valid_loaders.items(): for head, valid_loader in valid_loaders.items():
data_load_name = "valid_" + head data_load_name = "valid_" + head
train_valid_data_loader[data_load_name] = valid_loader train_valid_data_loader[data_load_name] = valid_loader
test_sets = {} test_sets = {}
stop_first_test = False stop_first_test = False
test_data_loader = {} test_data_loader = {}
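# Build the test loaders; if every head points at the same test_file/test_dir, stop after the first head to avoid loading the same data repeatedly.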
if all( if all(
head_config.test_file == head_configs[0].test_file head_config.test_file == head_configs[0].test_file
for head_config in head_configs for head_config in head_configs
) and head_configs[0].test_file is not None: ) and head_configs[0].test_file is not None:
stop_first_test = True stop_first_test = True
if all( if all(
head_config.test_dir == head_configs[0].test_dir head_config.test_dir == head_configs[0].test_dir
for head_config in head_configs for head_config in head_configs
) and head_configs[0].test_dir is not None: ) and head_configs[0].test_dir is not None:
stop_first_test = True stop_first_test = True
for head_config in head_configs: for head_config in head_configs:
if all(check_path_ase_read(f) for f in head_config.train_file): if all(check_path_ase_read(f) for f in head_config.train_file):
for name, subset in head_config.collections.tests: for name, subset in head_config.collections.tests:
test_sets[name] = [ test_sets[name] = [
data.AtomicData.from_config( data.AtomicData.from_config(
config, z_table=z_table, cutoff=args.r_max, heads=heads config, z_table=z_table, cutoff=args.r_max, heads=heads
) )
for config in subset for config in subset
] ]
if head_config.test_dir is not None: if head_config.test_dir is not None:
if not args.multi_processed_test: if not args.multi_processed_test:
test_files = get_files_with_suffix(head_config.test_dir, "_test.h5") test_files = get_files_with_suffix(head_config.test_dir, "_test.h5")
for test_file in test_files: for test_file in test_files:
name = os.path.splitext(os.path.basename(test_file))[0] name = os.path.splitext(os.path.basename(test_file))[0]
test_sets[name] = data.HDF5Dataset( test_sets[name] = data.HDF5Dataset(
test_file, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name test_file, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name
) )
else: else:
test_folders = glob(head_config.test_dir + "/*") test_folders = glob(head_config.test_dir + "/*")
for folder in test_folders: for folder in test_folders:
name = os.path.splitext(os.path.basename(folder))[0]
test_sets[name] = data.dataset_from_sharded_hdf5( test_sets[name] = data.dataset_from_sharded_hdf5(
folder, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name folder, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name
) )
for test_name, test_set in test_sets.items(): for test_name, test_set in test_sets.items():
test_sampler = None test_sampler = None
if args.distributed: if args.distributed:
test_sampler = torch.utils.data.distributed.DistributedSampler( test_sampler = torch.utils.data.distributed.DistributedSampler(
test_set, test_set,
num_replicas=world_size, num_replicas=world_size,
rank=rank, rank=rank,
shuffle=True, shuffle=True,
drop_last=True, drop_last=True,
seed=args.seed, seed=args.seed,
) )
try: try:
drop_last = test_set.drop_last drop_last = test_set.drop_last
except AttributeError:
drop_last = False drop_last = False
test_loader = torch_geometric.dataloader.DataLoader( test_loader = torch_geometric.dataloader.DataLoader(
test_set, test_set,
batch_size=args.valid_batch_size, batch_size=args.valid_batch_size,
shuffle=(test_sampler is None), shuffle=(test_sampler is None),
drop_last=drop_last, drop_last=drop_last,
num_workers=args.num_workers, num_workers=args.num_workers,
pin_memory=args.pin_memory, pin_memory=args.pin_memory,
) )
test_data_loader[test_name] = test_loader test_data_loader[test_name] = test_loader
if stop_first_test: if stop_first_test:
break break
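# Evaluate and export the final model(s): the stage one model and, when SWA is enabled, also the stage two model.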
for swa_eval in swas: for swa_eval in swas:
epoch = checkpoint_handler.load_latest( epoch = checkpoint_handler.load_latest(
state=tools.CheckpointState(model, optimizer, lr_scheduler), state=tools.CheckpointState(model, optimizer, lr_scheduler),
swa=swa_eval, swa=swa_eval,
device=device, device=device,
) )
model.to(device) model.to(device)
if args.distributed: if args.distributed:
distributed_model = DDP(model, device_ids=[local_rank]) distributed_model = DDP(model, device_ids=[local_rank])
model_to_evaluate = model if not args.distributed else distributed_model model_to_evaluate = model if not args.distributed else distributed_model
if swa_eval: if swa_eval:
logging.info(f"Loaded Stage two model from epoch {epoch} for evaluation") logging.info(f"Loaded Stage two model from epoch {epoch} for evaluation")
else: else:
logging.info(f"Loaded Stage one model from epoch {epoch} for evaluation") logging.info(f"Loaded Stage one model from epoch {epoch} for evaluation")
if rank == 0: if rank == 0:
# Save entire model # Save entire model
if swa_eval: if swa_eval:
model_path = Path(args.checkpoints_dir) / (tag + "_stagetwo.model") model_path = Path(args.checkpoints_dir) / (tag + "_stagetwo.model")
else: else:
model_path = Path(args.checkpoints_dir) / (tag + ".model") model_path = Path(args.checkpoints_dir) / (tag + ".model")
logging.info(f"Saving model to {model_path}") logging.info(f"Saving model to {model_path}")
model_to_save = deepcopy(model) model_to_save = deepcopy(model)
if args.enable_cueq: if args.enable_cueq:
print("RUNING CUEQ TO E3NN") print("RUNING CUEQ TO E3NN")
print("swa_eval", swa_eval) print("swa_eval", swa_eval)
model_to_save = run_cueq_to_e3nn(deepcopy(model), device=device) model_to_save = run_cueq_to_e3nn(deepcopy(model), device=device)
if args.save_cpu: if args.save_cpu:
model_to_save = model_to_save.to("cpu") model_to_save = model_to_save.to("cpu")
torch.save(model_to_save, model_path) torch.save(model_to_save, model_path)
extra_files = { extra_files = {
"commit.txt": commit.encode("utf-8") if commit is not None else b"", "commit.txt": commit.encode("utf-8") if commit is not None else b"",
"config.yaml": json.dumps( "config.yaml": json.dumps(
convert_to_json_format(extract_config_mace_model(model)) convert_to_json_format(extract_config_mace_model(model))
), ),
} }
if swa_eval: if swa_eval:
torch.save( torch.save(
model_to_save, Path(args.model_dir) / (args.name + "_stagetwo.model") model_to_save, Path(args.model_dir) / (args.name + "_stagetwo.model")
) )
try:
path_compiled = Path(args.model_dir) / (
args.name + "_stagetwo_compiled.model"
)
logging.info(f"Compiling model, saving metadata to {path_compiled}")
model_compiled = jit.compile(deepcopy(model_to_save))
torch.jit.save(
model_compiled,
path_compiled,
_extra_files=extra_files,
)
except Exception as e:  # pylint: disable=W0718
logging.debug(f"Model compilation failed: {e}")
else: else:
torch.save(model_to_save, Path(args.model_dir) / (args.name + ".model")) torch.save(model_to_save, Path(args.model_dir) / (args.name + ".model"))
try:
path_compiled = Path(args.model_dir) / (
args.name + "_compiled.model"
)
logging.info(f"Compiling model, saving metadata to {path_compiled}")
model_compiled = jit.compile(deepcopy(model_to_save))
torch.jit.save(
model_compiled,
path_compiled,
_extra_files=extra_files,
)
except Exception as e:  # pylint: disable=W0718
logging.debug(f"Model compilation failed: {e}")
logging.info("Computing metrics for training, validation, and test sets") logging.info("Computing metrics for training, validation, and test sets")
for param in model.parameters(): for param in model.parameters():
param.requires_grad = False param.requires_grad = False
skip_heads = args.skip_evaluate_heads.split(",") if args.skip_evaluate_heads else [] skip_heads = args.skip_evaluate_heads.split(",") if args.skip_evaluate_heads else []
if skip_heads: if skip_heads:
logging.info(f"Skipping evaluation for heads: {skip_heads}") logging.info(f"Skipping evaluation for heads: {skip_heads}")
table_train_valid = create_error_table( table_train_valid = create_error_table(
table_type=args.error_table, table_type=args.error_table,
all_data_loaders=train_valid_data_loader, all_data_loaders=train_valid_data_loader,
model=model_to_evaluate, model=model_to_evaluate,
loss_fn=loss_fn, loss_fn=loss_fn,
output_args=output_args, output_args=output_args,
log_wandb=args.wandb, log_wandb=args.wandb,
device=device, device=device,
distributed=args.distributed, distributed=args.distributed,
skip_heads=skip_heads, skip_heads=skip_heads,
) )
logging.info("Error-table on TRAIN and VALID:\n" + str(table_train_valid)) logging.info("Error-table on TRAIN and VALID:\n" + str(table_train_valid))
if test_data_loader: if test_data_loader:
table_test = create_error_table( table_test = create_error_table(
table_type=args.error_table, table_type=args.error_table,
all_data_loaders=test_data_loader, all_data_loaders=test_data_loader,
model=model_to_evaluate, model=model_to_evaluate,
loss_fn=loss_fn, loss_fn=loss_fn,
output_args=output_args, output_args=output_args,
log_wandb=args.wandb, log_wandb=args.wandb,
device=device, device=device,
distributed=args.distributed, distributed=args.distributed,
) )
logging.info("Error-table on TEST:\n" + str(table_test)) logging.info("Error-table on TEST:\n" + str(table_test))
if args.plot: if args.plot:
try: try:
plotter = TrainingPlotter( plotter = TrainingPlotter(
results_dir=logger.path, results_dir=logger.path,
heads=heads, heads=heads,
table_type=args.error_table, table_type=args.error_table,
train_valid_data=train_valid_data_loader, train_valid_data=train_valid_data_loader,
test_data=test_data_loader, test_data=test_data_loader,
output_args=output_args, output_args=output_args,
device=device, device=device,
plot_frequency=args.plot_frequency, plot_frequency=args.plot_frequency,
distributed=args.distributed, distributed=args.distributed,
swa_start=swa.start if swa else None swa_start=swa.start if swa else None
) )
plotter.plot(epoch, model_to_evaluate, rank) plotter.plot(epoch, model_to_evaluate, rank)
except Exception as e: # pylint: disable=W0718 except Exception as e: # pylint: disable=W0718
logging.debug(f"Plotting failed: {e}") logging.debug(f"Plotting failed: {e}")
if args.distributed: if args.distributed:
torch.distributed.barrier() torch.distributed.barrier()
logging.info("Done") logging.info("Done")
if args.distributed: if args.distributed:
torch.distributed.destroy_process_group() torch.distributed.destroy_process_group()
if __name__ == "__main__": if __name__ == "__main__":
main() main()
from argparse import ArgumentParser from argparse import ArgumentParser
import torch import torch
from mace.tools.scripts_utils import remove_pt_head from mace.tools.scripts_utils import remove_pt_head
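# Usage sketch for this head-extraction utility (the script and model file names
# below are illustrative, not part of the repository):
#   python extract_head.py --list_heads multihead.model
#   python extract_head.py --head_name default --target_device cpu multihead.model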
def main(): def main():
parser = ArgumentParser() parser = ArgumentParser()
grp = parser.add_mutually_exclusive_group() grp = parser.add_mutually_exclusive_group()
grp.add_argument( grp.add_argument(
"--head_name", "--head_name",
"-n", "-n",
help="name of the head to extract", help="name of the head to extract",
default=None, default=None,
) )
grp.add_argument( grp.add_argument(
"--list_heads", "--list_heads",
"-l", "-l",
action="store_true", action="store_true",
help="list names of the heads", help="list names of the heads",
) )
parser.add_argument( parser.add_argument(
"--target_device", "--target_device",
"-d", "-d",
help="target device, defaults to model's current device", help="target device, defaults to model's current device",
) )
parser.add_argument( parser.add_argument(
"--output_file", "--output_file",
"-o", "-o",
help="name for output model, defaults to model.head_name, followed by .target_device if specified", help="name for output model, defaults to model.head_name, followed by .target_device if specified",
) )
parser.add_argument("model_file", help="input model file path") parser.add_argument("model_file", help="input model file path")
args = parser.parse_args() args = parser.parse_args()
model = torch.load(args.model_file, map_location=args.target_device) model = torch.load(args.model_file, map_location=args.target_device)
torch.set_default_dtype(next(model.parameters()).dtype) torch.set_default_dtype(next(model.parameters()).dtype)
if args.list_heads: if args.list_heads:
print("Available heads:") print("Available heads:")
print("\n".join([" " + h for h in model.heads])) print("\n".join([" " + h for h in model.heads]))
else: else:
if args.output_file is None: if args.output_file is None:
args.output_file = ( args.output_file = (
args.model_file args.model_file
+ "." + "."
+ args.head_name + args.head_name
+ ("." + args.target_device if (args.target_device is not None) else "") + ("." + args.target_device if (args.target_device is not None) else "")
) )
model_single = remove_pt_head(model, args.head_name) model_single = remove_pt_head(model, args.head_name)
        if args.target_device is not None:
            # The model was already loaded onto args.target_device via map_location;
            # ensure the extracted single-head model is on that device as well.
            model_single.to(args.target_device)
torch.save(model_single, args.output_file) torch.save(model_single, args.output_file)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
import json import json
import logging import logging
from typing import Dict, List, Optional, Tuple
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import torch import torch
import torch.distributed import torch.distributed
from torchmetrics import Metric from torchmetrics import Metric
plt.rcParams.update({"font.size": 8}) plt.rcParams.update({"font.size": 8})
mpl_logger = logging.getLogger("matplotlib") mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING) # Only show WARNING and above mpl_logger.setLevel(logging.WARNING) # Only show WARNING and above
colors = [ colors = [
"#1f77b4", # muted blue "#1f77b4", # muted blue
"#d62728", # brick red "#d62728", # brick red
"#7f7f7f", # middle gray "#7f7f7f", # middle gray
"#2ca02c", # cooked asparagus green "#2ca02c", # cooked asparagus green
"#ff7f0e", # safety orange "#ff7f0e", # safety orange
"#9467bd", # muted purple "#9467bd", # muted purple
"#8c564b", # chestnut brown "#8c564b", # chestnut brown
"#e377c2", # raspberry yogurt pink "#e377c2", # raspberry yogurt pink
"#bcbd22", # curry yellow-green "#bcbd22", # curry yellow-green
"#17becf", # blue-teal "#17becf", # blue-teal
] ]
error_type = { error_type = {
"TotalRMSE": ( "TotalRMSE": (
[("rmse_e", "RMSE E [meV]"), ("rmse_f", "RMSE F [meV / A]")], [("rmse_e", "RMSE E [meV]"), ("rmse_f", "RMSE F [meV / A]")],
[("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")], [("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")],
), ),
"PerAtomRMSE": ( "PerAtomRMSE": (
[("rmse_e_per_atom", "RMSE E/atom [meV]"), ("rmse_f", "RMSE F [meV / A]")], [("rmse_e_per_atom", "RMSE E/atom [meV]"), ("rmse_f", "RMSE F [meV / A]")],
[("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")], [("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")],
), ),
"PerAtomRMSEstressvirials": ( "PerAtomRMSEstressvirials": (
[ [
("rmse_e_per_atom", "RMSE E/atom [meV]"), ("rmse_e_per_atom", "RMSE E/atom [meV]"),
("rmse_f", "RMSE F [meV / A]"), ("rmse_f", "RMSE F [meV / A]"),
("rmse_stress", "RMSE Stress [meV / A^3]"), ("rmse_stress", "RMSE Stress [meV / A^3]"),
], ],
[ [
("energy", "Energy per atom [eV]"), ("energy", "Energy per atom [eV]"),
("force", "Force [eV / A]"), ("force", "Force [eV / A]"),
("stress", "Stress [eV / A^3]"), ("stress", "Stress [eV / A^3]"),
], ],
), ),
"PerAtomMAEstressvirials": ( "PerAtomMAEstressvirials": (
[ [
("mae_e_per_atom", "MAE E/atom [meV]"), ("mae_e_per_atom", "MAE E/atom [meV]"),
("mae_f", "MAE F [meV / A]"), ("mae_f", "MAE F [meV / A]"),
("mae_stress", "MAE Stress [meV / A^3]"), ("mae_stress", "MAE Stress [meV / A^3]"),
], ],
[ [
("energy", "Energy per atom [eV]"), ("energy", "Energy per atom [eV]"),
("force", "Force [eV / A]"), ("force", "Force [eV / A]"),
("stress", "Stress [eV / A^3]"), ("stress", "Stress [eV / A^3]"),
], ],
), ),
"TotalMAE": ( "TotalMAE": (
[("mae_e", "MAE E [meV]"), ("mae_f", "MAE F [meV / A]")], [("mae_e", "MAE E [meV]"), ("mae_f", "MAE F [meV / A]")],
[("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")], [("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")],
), ),
"PerAtomMAE": ( "PerAtomMAE": (
[("mae_e_per_atom", "MAE E/atom [meV]"), ("mae_f", "MAE F [meV / A]")], [("mae_e_per_atom", "MAE E/atom [meV]"), ("mae_f", "MAE F [meV / A]")],
[("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")], [("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")],
), ),
"DipoleRMSE": ( "DipoleRMSE": (
[ [
("rmse_mu_per_atom", "RMSE MU/atom [mDebye]"), ("rmse_mu_per_atom", "RMSE MU/atom [mDebye]"),
("rel_rmse_f", "Relative MU RMSE [%]"), ("rel_rmse_f", "Relative MU RMSE [%]"),
], ],
[("dipole", "Dipole per atom [Debye]")], [("dipole", "Dipole per atom [Debye]")],
), ),
"DipoleMAE": ( "DipoleMAE": (
[("mae_mu", "MAE MU [mDebye]"), ("rel_mae_f", "Relative MU MAE [%]")], [("mae_mu", "MAE MU [mDebye]"), ("rel_mae_f", "Relative MU MAE [%]")],
[("dipole", "Dipole per atom [Debye]")], [("dipole", "Dipole per atom [Debye]")],
), ),
"EnergyDipoleRMSE": ( "EnergyDipoleRMSE": (
[ [
("rmse_e_per_atom", "RMSE E/atom [meV]"), ("rmse_e_per_atom", "RMSE E/atom [meV]"),
("rmse_f", "RMSE F [meV / A]"), ("rmse_f", "RMSE F [meV / A]"),
("rmse_mu_per_atom", "RMSE MU/atom [mDebye]"), ("rmse_mu_per_atom", "RMSE MU/atom [mDebye]"),
], ],
[ [
("energy", "Energy per atom [eV]"), ("energy", "Energy per atom [eV]"),
("force", "Force [eV / A]"), ("force", "Force [eV / A]"),
("dipole", "Dipole per atom [Debye]"), ("dipole", "Dipole per atom [Debye]"),
], ],
), ),
} }
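# Each entry above is a pair of lists: metric keys and labels plotted against the
# epoch in the top row of the summary figure, and the physical quantities used for
# the reference-vs-predicted parity plots in the bottom row.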
class TrainingPlotter: class TrainingPlotter:
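    """Builds per-head summary figures for a loaded checkpoint: training and
    validation curves on top, reference-vs-predicted parity plots below."""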
def __init__( def __init__(
self, self,
results_dir: str, results_dir: str,
heads: List[str], heads: List[str],
table_type: str, table_type: str,
train_valid_data: Dict, train_valid_data: Dict,
test_data: Dict, test_data: Dict,
output_args: str, output_args: str,
device: str, device: str,
plot_frequency: int, plot_frequency: int,
distributed: bool = False, distributed: bool = False,
swa_start: Optional[int] = None, swa_start: Optional[int] = None,
): ):
self.results_dir = results_dir self.results_dir = results_dir
self.heads = heads self.heads = heads
self.table_type = table_type self.table_type = table_type
self.train_valid_data = train_valid_data self.train_valid_data = train_valid_data
self.test_data = test_data self.test_data = test_data
self.output_args = output_args self.output_args = output_args
self.device = device self.device = device
self.plot_frequency = plot_frequency self.plot_frequency = plot_frequency
self.distributed = distributed self.distributed = distributed
self.swa_start = swa_start self.swa_start = swa_start
    def plot(self, model_epoch: int, model: torch.nn.Module, rank: int) -> None:
# All ranks process data through model_inference # All ranks process data through model_inference
train_valid_dict = model_inference( train_valid_dict = model_inference(
self.train_valid_data, self.train_valid_data,
model, model,
self.output_args, self.output_args,
self.device, self.device,
self.distributed, self.distributed,
) )
test_dict = model_inference( test_dict = model_inference(
self.test_data, model, self.output_args, self.device, self.distributed self.test_data, model, self.output_args, self.device, self.distributed
) )
# Only rank 0 creates and saves plots # Only rank 0 creates and saves plots
if rank != 0: if rank != 0:
return return
data = pd.DataFrame( data = pd.DataFrame(
results for results in parse_training_results(self.results_dir) results for results in parse_training_results(self.results_dir)
) )
labels, quantities = error_type[self.table_type] labels, quantities = error_type[self.table_type]
for head in self.heads: for head in self.heads:
fig = plt.figure(layout="constrained", figsize=(10, 6)) fig = plt.figure(layout="constrained", figsize=(10, 6))
fig.suptitle( fig.suptitle(
f"Model loaded from epoch {model_epoch} ({head} head)", fontsize=16 f"Model loaded from epoch {model_epoch} ({head} head)", fontsize=16
) )
subfigs = fig.subfigures(2, 1, height_ratios=[1, 1], hspace=0.05) subfigs = fig.subfigures(2, 1, height_ratios=[1, 1], hspace=0.05)
axsTop = subfigs[0].subplots(1, 2, sharey=False) axsTop = subfigs[0].subplots(1, 2, sharey=False)
axsBottom = subfigs[1].subplots(1, len(quantities), sharey=False) axsBottom = subfigs[1].subplots(1, len(quantities), sharey=False)
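            # Top row: loss and error metrics vs. epoch; bottom row: one parity
            # plot per quantity of the selected error table.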
plot_epoch_dependence(axsTop, data, head, model_epoch, labels) plot_epoch_dependence(axsTop, data, head, model_epoch, labels)
# Use the pre-computed results for plotting # Use the pre-computed results for plotting
plot_inference_from_results( plot_inference_from_results(
axsBottom, train_valid_dict, test_dict, head, quantities axsBottom, train_valid_dict, test_dict, head, quantities
) )
if self.swa_start is not None: if self.swa_start is not None:
# Add vertical lines to both axes # Add vertical lines to both axes
for ax in axsTop: for ax in axsTop:
ax.axvline( ax.axvline(
self.swa_start, self.swa_start,
color="black", color="black",
linestyle="dashed", linestyle="dashed",
linewidth=1, linewidth=1,
alpha=0.6, alpha=0.6,
label="Stage Two Starts", label="Stage Two Starts",
) )
stage = "stage_two" if self.swa_start < model_epoch else "stage_one" stage = "stage_two" if self.swa_start < model_epoch else "stage_one"
else: else:
stage = "stage_one" stage = "stage_one"
axsTop[0].legend(loc="best") axsTop[0].legend(loc="best")
# Save the figure using the appropriate stage in the filename # Save the figure using the appropriate stage in the filename
filename = f"{self.results_dir[:-4]}_{head}_{stage}.png" filename = f"{self.results_dir[:-4]}_{head}_{stage}.png"
fig.savefig(filename, dpi=300, bbox_inches="tight") fig.savefig(filename, dpi=300, bbox_inches="tight")
plt.close(fig) plt.close(fig)
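# The training logger writes results as JSON lines: one JSON object per line with
# keys such as "mode", "epoch", "head", "loss" and the per-metric errors; lines
# that fail to parse are skipped.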
def parse_training_results(path: str) -> List[dict]: def parse_training_results(path: str) -> List[dict]:
results = [] results = []
with open(path, mode="r", encoding="utf-8") as f: with open(path, mode="r", encoding="utf-8") as f:
for line in f: for line in f:
try: try:
d = json.loads(line.strip()) # Ensure it's valid JSON d = json.loads(line.strip()) # Ensure it's valid JSON
results.append(d) results.append(d)
except json.JSONDecodeError: except json.JSONDecodeError:
                logging.warning(
                    f"Skipping invalid line: {line.strip()}"
                )  # Handle non-JSON lines gracefully
return results return results
def plot_epoch_dependence(
    axes: np.ndarray, data: pd.DataFrame, head: str, model_epoch: int, labels: List[Tuple[str, str]]
) -> None:
valid_data = ( valid_data = (
data[data["mode"] == "eval"] data[data["mode"] == "eval"]
.groupby(["mode", "epoch", "head"]) .groupby(["mode", "epoch", "head"])
.agg(["mean", "std"]) .agg(["mean", "std"])
.reset_index() .reset_index()
) )
valid_data = valid_data[valid_data["head"] == head] valid_data = valid_data[valid_data["head"] == head]
train_data = ( train_data = (
data[data["mode"] == "opt"] data[data["mode"] == "opt"]
.groupby(["mode", "epoch"]) .groupby(["mode", "epoch"])
.agg(["mean", "std"]) .agg(["mean", "std"])
.reset_index() .reset_index()
) )
# ---- Plot loss ---- # ---- Plot loss ----
ax = axes[0] ax = axes[0]
ax.plot( ax.plot(
train_data["epoch"], train_data["loss"]["mean"], color=colors[1], linewidth=1 train_data["epoch"], train_data["loss"]["mean"], color=colors[1], linewidth=1
) )
ax.set_ylabel("Training Loss", color=colors[1]) ax.set_ylabel("Training Loss", color=colors[1])
ax.set_yscale("log") ax.set_yscale("log")
ax2 = ax.twinx() ax2 = ax.twinx()
ax2.plot( ax2.plot(
valid_data["epoch"], valid_data["loss"]["mean"], color=colors[0], linewidth=1 valid_data["epoch"], valid_data["loss"]["mean"], color=colors[0], linewidth=1
) )
ax2.set_ylabel("Validation Loss", color=colors[0]) ax2.set_ylabel("Validation Loss", color=colors[0])
ax2.set_yscale("log") ax2.set_yscale("log")
ax.axvline( ax.axvline(
model_epoch, model_epoch,
color="black", color="black",
linestyle="solid", linestyle="solid",
linewidth=1, linewidth=1,
alpha=0.8, alpha=0.8,
label="Loaded Model", label="Loaded Model",
) )
ax.set_xlabel("Epoch") ax.set_xlabel("Epoch")
ax.grid(True, linestyle="--", alpha=0.5) ax.grid(True, linestyle="--", alpha=0.5)
# ---- Plot selected keys ---- # ---- Plot selected keys ----
ax = axes[1] ax = axes[1]
twin_axes = [] twin_axes = []
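    # The first metric uses the main y-axis; each additional metric gets its own
    # twin y-axis, offset further to the right so the labels do not overlap.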
for i, label in enumerate(labels): for i, label in enumerate(labels):
color = colors[(i + 3)] color = colors[(i + 3)]
key, axis_label = label key, axis_label = label
if i == 0: if i == 0:
main_ax = ax main_ax = ax
else: else:
main_ax = ax.twinx() main_ax = ax.twinx()
main_ax.spines.right.set_position(("outward", 60 * (i - 1))) main_ax.spines.right.set_position(("outward", 60 * (i - 1)))
twin_axes.append(main_ax) twin_axes.append(main_ax)
main_ax.plot( main_ax.plot(
valid_data["epoch"], valid_data["epoch"],
valid_data[key]["mean"] * 1e3, valid_data[key]["mean"] * 1e3,
color=color, color=color,
label=label, label=label,
linewidth=1, linewidth=1,
) )
main_ax.set_yscale("log") main_ax.set_yscale("log")
main_ax.set_ylabel(axis_label, color=color) main_ax.set_ylabel(axis_label, color=color)
main_ax.tick_params(axis="y", colors=color) main_ax.tick_params(axis="y", colors=color)
ax.axvline( ax.axvline(
model_epoch, model_epoch,
color="black", color="black",
linestyle="solid", linestyle="solid",
linewidth=1, linewidth=1,
alpha=0.8, alpha=0.8,
label="Loaded Model", label="Loaded Model",
) )
ax.set_xlabel("Epoch") ax.set_xlabel("Epoch")
ax.grid(True, linestyle="--", alpha=0.5) ax.grid(True, linestyle="--", alpha=0.5)
# ---- Inference parity plots ----
def plot_inference_from_results( def plot_inference_from_results(
axes: np.ndarray, axes: np.ndarray,
train_valid_dict: dict, train_valid_dict: dict,
test_dict: dict, test_dict: dict,
head: str, head: str,
    quantities: List[Tuple[str, str]],
) -> None: ) -> None:
for ax, quantity in zip(axes, quantities): for ax, quantity in zip(axes, quantities):
key, label = quantity key, label = quantity
# Store legend handles to avoid duplicates # Store legend handles to avoid duplicates
legend_labels = {} legend_labels = {}
# Plot train/valid data (each entry keeps its own name) # Plot train/valid data (each entry keeps its own name)
for name, result in train_valid_dict.items(): for name, result in train_valid_dict.items():
if "train" in name: if "train" in name:
fixed_color_train_valid = colors[1] fixed_color_train_valid = colors[1]
marker = "x" marker = "x"
else: else:
fixed_color_train_valid = colors[0] fixed_color_train_valid = colors[0]
marker = "+" marker = "+"
if head not in name: if head not in name:
continue continue
# Initialize scatter to None # Initialize scatter to None
scatter = None scatter = None
if key == "energy" and "energy" in result: if key == "energy" and "energy" in result:
scatter = ax.scatter( scatter = ax.scatter(
result["energy"]["reference_per_atom"], result["energy"]["reference_per_atom"],
result["energy"]["predicted_per_atom"], result["energy"]["predicted_per_atom"],
marker=marker, marker=marker,
color=fixed_color_train_valid, color=fixed_color_train_valid,
label=name, label=name,
) )
elif key == "force" and "forces" in result: elif key == "force" and "forces" in result:
scatter = ax.scatter( scatter = ax.scatter(
result["forces"]["reference"], result["forces"]["reference"],
result["forces"]["predicted"], result["forces"]["predicted"],
marker=marker, marker=marker,
color=fixed_color_train_valid, color=fixed_color_train_valid,
label=name, label=name,
) )
elif key == "stress" and "stress" in result: elif key == "stress" and "stress" in result:
scatter = ax.scatter( scatter = ax.scatter(
result["stress"]["reference"], result["stress"]["reference"],
result["stress"]["predicted"], result["stress"]["predicted"],
marker=marker, marker=marker,
color=fixed_color_train_valid, color=fixed_color_train_valid,
label=name, label=name,
) )
elif key == "virials" and "virials" in result: elif key == "virials" and "virials" in result:
scatter = ax.scatter( scatter = ax.scatter(
result["virials"]["reference_per_atom"], result["virials"]["reference_per_atom"],
result["virials"]["predicted_per_atom"], result["virials"]["predicted_per_atom"],
marker=marker, marker=marker,
color=fixed_color_train_valid, color=fixed_color_train_valid,
label=name, label=name,
) )
elif key == "dipole" and "dipole" in result: elif key == "dipole" and "dipole" in result:
scatter = ax.scatter( scatter = ax.scatter(
result["dipole"]["reference_per_atom"], result["dipole"]["reference_per_atom"],
result["dipole"]["predicted_per_atom"], result["dipole"]["predicted_per_atom"],
marker=marker, marker=marker,
color=fixed_color_train_valid, color=fixed_color_train_valid,
label=name, label=name,
) )
# Add each train/valid dataset's name to the legend if scatter was assigned # Add each train/valid dataset's name to the legend if scatter was assigned
if scatter is not None: if scatter is not None:
legend_labels[name] = scatter legend_labels[name] = scatter
fixed_color_test = colors[2] # Color for test dataset fixed_color_test = colors[2] # Color for test dataset
# Plot test data (single legend entry) # Plot test data (single legend entry)
for name, result in test_dict.items(): for name, result in test_dict.items():
# Initialize scatter to None to avoid possibly used before assignment # Initialize scatter to None to avoid possibly used before assignment
scatter = None scatter = None
if key == "energy" and "energy" in result: if key == "energy" and "energy" in result:
scatter = ax.scatter( scatter = ax.scatter(
result["energy"]["reference_per_atom"], result["energy"]["reference_per_atom"],
result["energy"]["predicted_per_atom"], result["energy"]["predicted_per_atom"],
marker="o", marker="o",
color=fixed_color_test, color=fixed_color_test,
label="Test", label="Test",
) )
elif key == "force" and "forces" in result: elif key == "force" and "forces" in result:
scatter = ax.scatter( scatter = ax.scatter(
result["forces"]["reference"], result["forces"]["reference"],
result["forces"]["predicted"], result["forces"]["predicted"],
marker="o", marker="o",
color=fixed_color_test, color=fixed_color_test,
label="Test", label="Test",
) )
elif key == "stress" and "stress" in result: elif key == "stress" and "stress" in result:
scatter = ax.scatter( scatter = ax.scatter(
result["stress"]["reference"], result["stress"]["reference"],
result["stress"]["predicted"], result["stress"]["predicted"],
marker="o", marker="o",
color=fixed_color_test, color=fixed_color_test,
label="Test", label="Test",
) )
elif key == "virials" and "virials" in result: elif key == "virials" and "virials" in result:
scatter = ax.scatter( scatter = ax.scatter(
result["virials"]["reference_per_atom"], result["virials"]["reference_per_atom"],
result["virials"]["predicted_per_atom"], result["virials"]["predicted_per_atom"],
marker="o", marker="o",
color=fixed_color_test, color=fixed_color_test,
label="Test", label="Test",
) )
elif key == "dipole" and "dipole" in result: elif key == "dipole" and "dipole" in result:
scatter = ax.scatter( scatter = ax.scatter(
result["dipole"]["reference_per_atom"], result["dipole"]["reference_per_atom"],
result["dipole"]["predicted_per_atom"], result["dipole"]["predicted_per_atom"],
marker="o", marker="o",
color=fixed_color_test, color=fixed_color_test,
label="Test", label="Test",
) )
# Only add to legend_labels if scatter was assigned # Only add to legend_labels if scatter was assigned
if scatter is not None: if scatter is not None:
legend_labels["Test"] = scatter legend_labels["Test"] = scatter
# Add diagonal line for guide # Add diagonal line for guide
min_val = min(ax.get_xlim()[0], ax.get_ylim()[0]) min_val = min(ax.get_xlim()[0], ax.get_ylim()[0])
max_val = max(ax.get_xlim()[1], ax.get_ylim()[1]) max_val = max(ax.get_xlim()[1], ax.get_ylim()[1])
ax.plot( ax.plot(
[min_val, max_val], [min_val, max_val],
[min_val, max_val], [min_val, max_val],
linestyle="--", linestyle="--",
color="black", color="black",
alpha=0.7, alpha=0.7,
) )
# Set legend with unique entries (Test + individual train/valid names) # Set legend with unique entries (Test + individual train/valid names)
if legend_labels: if legend_labels:
ax.legend( ax.legend(
handles=legend_labels.values(), labels=legend_labels.keys(), loc="best" handles=legend_labels.values(), labels=legend_labels.keys(), loc="best"
) )
ax.set_xlabel(f"Reference {label}") ax.set_xlabel(f"Reference {label}")
ax.set_ylabel(f"MACE {label}") ax.set_ylabel(f"MACE {label}")
ax.grid(True, linestyle="--", alpha=0.5) ax.grid(True, linestyle="--", alpha=0.5)
def model_inference( def model_inference(
all_data_loaders: dict, all_data_loaders: dict,
model: torch.nn.Module, model: torch.nn.Module,
output_args: Dict[str, bool], output_args: Dict[str, bool],
device: str, device: str,
distributed: bool = False, distributed: bool = False,
): ):
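    """Run the model over every data loader and collect reference vs. predicted
    values with InferenceMetric, returning one results dict per loader name."""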
for param in model.parameters(): for param in model.parameters():
param.requires_grad = False param.requires_grad = False
results_dict = {} results_dict = {}
for name in all_data_loaders: for name in all_data_loaders:
data_loader = all_data_loaders[name] data_loader = all_data_loaders[name]
logging.debug(f"Running inference on {name} dataset") logging.debug(f"Running inference on {name} dataset")
scatter_metric = InferenceMetric().to(device) scatter_metric = InferenceMetric().to(device)
for batch in data_loader: for batch in data_loader:
batch = batch.to(device) batch = batch.to(device)
batch_dict = batch.to_dict() batch_dict = batch.to_dict()
output = model( output = model(
batch_dict, batch_dict,
training=False, training=False,
compute_force=output_args.get("forces", False), compute_force=output_args.get("forces", False),
compute_virials=output_args.get("virials", False), compute_virials=output_args.get("virials", False),
compute_stress=output_args.get("stress", False), compute_stress=output_args.get("stress", False),
) )
            scatter_metric.update(batch, output)  # accumulate reference/predicted values
if distributed: if distributed:
torch.distributed.barrier() torch.distributed.barrier()
results = scatter_metric.compute() results = scatter_metric.compute()
results_dict[name] = results results_dict[name] = results
scatter_metric.reset() scatter_metric.reset()
del data_loader del data_loader
for param in model.parameters(): for param in model.parameters():
param.requires_grad = True param.requires_grad = True
return results_dict return results_dict
def to_numpy(tensor: torch.Tensor) -> np.ndarray: def to_numpy(tensor: torch.Tensor) -> np.ndarray:
return tensor.cpu().detach().numpy() return tensor.cpu().detach().numpy()
class InferenceMetric(Metric): class InferenceMetric(Metric):
"""Metric class for collecting reference and predicted values for scatterplot visualization.""" """Metric class for collecting reference and predicted values for scatterplot visualization."""
def __init__(self): def __init__(self):
super().__init__() super().__init__()
# Raw values # Raw values
self.add_state("ref_energies", default=[], dist_reduce_fx="cat") self.add_state("ref_energies", default=[], dist_reduce_fx="cat")
self.add_state("pred_energies", default=[], dist_reduce_fx="cat") self.add_state("pred_energies", default=[], dist_reduce_fx="cat")
self.add_state("ref_forces", default=[], dist_reduce_fx="cat") self.add_state("ref_forces", default=[], dist_reduce_fx="cat")
self.add_state("pred_forces", default=[], dist_reduce_fx="cat") self.add_state("pred_forces", default=[], dist_reduce_fx="cat")
self.add_state("ref_stress", default=[], dist_reduce_fx="cat") self.add_state("ref_stress", default=[], dist_reduce_fx="cat")
self.add_state("pred_stress", default=[], dist_reduce_fx="cat") self.add_state("pred_stress", default=[], dist_reduce_fx="cat")
self.add_state("ref_virials", default=[], dist_reduce_fx="cat") self.add_state("ref_virials", default=[], dist_reduce_fx="cat")
self.add_state("pred_virials", default=[], dist_reduce_fx="cat") self.add_state("pred_virials", default=[], dist_reduce_fx="cat")
self.add_state("ref_dipole", default=[], dist_reduce_fx="cat") self.add_state("ref_dipole", default=[], dist_reduce_fx="cat")
self.add_state("pred_dipole", default=[], dist_reduce_fx="cat") self.add_state("pred_dipole", default=[], dist_reduce_fx="cat")
# Per-atom normalized values # Per-atom normalized values
self.add_state("ref_energies_per_atom", default=[], dist_reduce_fx="cat") self.add_state("ref_energies_per_atom", default=[], dist_reduce_fx="cat")
self.add_state("pred_energies_per_atom", default=[], dist_reduce_fx="cat") self.add_state("pred_energies_per_atom", default=[], dist_reduce_fx="cat")
self.add_state("ref_virials_per_atom", default=[], dist_reduce_fx="cat") self.add_state("ref_virials_per_atom", default=[], dist_reduce_fx="cat")
self.add_state("pred_virials_per_atom", default=[], dist_reduce_fx="cat") self.add_state("pred_virials_per_atom", default=[], dist_reduce_fx="cat")
self.add_state("ref_dipole_per_atom", default=[], dist_reduce_fx="cat") self.add_state("ref_dipole_per_atom", default=[], dist_reduce_fx="cat")
self.add_state("pred_dipole_per_atom", default=[], dist_reduce_fx="cat") self.add_state("pred_dipole_per_atom", default=[], dist_reduce_fx="cat")
# Store atom counts for each configuration # Store atom counts for each configuration
self.add_state("atom_counts", default=[], dist_reduce_fx="cat") self.add_state("atom_counts", default=[], dist_reduce_fx="cat")
# Counters # Counters
self.add_state("n_energy", default=torch.tensor(0.0), dist_reduce_fx="sum") self.add_state("n_energy", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("n_forces", default=torch.tensor(0.0), dist_reduce_fx="sum") self.add_state("n_forces", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("n_stress", default=torch.tensor(0.0), dist_reduce_fx="sum") self.add_state("n_stress", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("n_virials", default=torch.tensor(0.0), dist_reduce_fx="sum") self.add_state("n_virials", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("n_dipole", default=torch.tensor(0.0), dist_reduce_fx="sum") self.add_state("n_dipole", default=torch.tensor(0.0), dist_reduce_fx="sum")
def update(self, batch, output): # pylint: disable=arguments-differ def update(self, batch, output): # pylint: disable=arguments-differ
"""Update metric states with new batch data.""" """Update metric states with new batch data."""
# Calculate number of atoms per configuration # Calculate number of atoms per configuration
atoms_per_config = batch.ptr[1:] - batch.ptr[:-1] atoms_per_config = batch.ptr[1:] - batch.ptr[:-1]
self.atom_counts.append(atoms_per_config) self.atom_counts.append(atoms_per_config)
# Energy # Energy
if output.get("energy") is not None and batch.energy is not None: if output.get("energy") is not None and batch.energy is not None:
self.n_energy += 1.0 self.n_energy += 1.0
self.ref_energies.append(batch.energy) self.ref_energies.append(batch.energy)
self.pred_energies.append(output["energy"]) self.pred_energies.append(output["energy"])
# Per-atom normalization # Per-atom normalization
self.ref_energies_per_atom.append(batch.energy / atoms_per_config) self.ref_energies_per_atom.append(batch.energy / atoms_per_config)
self.pred_energies_per_atom.append(output["energy"] / atoms_per_config) self.pred_energies_per_atom.append(output["energy"] / atoms_per_config)
# Forces # Forces
if output.get("forces") is not None and batch.forces is not None: if output.get("forces") is not None and batch.forces is not None:
self.n_forces += 1.0 self.n_forces += 1.0
self.ref_forces.append(batch.forces) self.ref_forces.append(batch.forces)
self.pred_forces.append(output["forces"]) self.pred_forces.append(output["forces"])
# Stress # Stress
if output.get("stress") is not None and batch.stress is not None: if output.get("stress") is not None and batch.stress is not None:
self.n_stress += 1.0 self.n_stress += 1.0
self.ref_stress.append(batch.stress) self.ref_stress.append(batch.stress)
self.pred_stress.append(output["stress"]) self.pred_stress.append(output["stress"])
# Virials # Virials
if output.get("virials") is not None and batch.virials is not None: if output.get("virials") is not None and batch.virials is not None:
self.n_virials += 1.0 self.n_virials += 1.0
self.ref_virials.append(batch.virials) self.ref_virials.append(batch.virials)
self.pred_virials.append(output["virials"]) self.pred_virials.append(output["virials"])
# Per-atom normalization # Per-atom normalization
atoms_per_config_3d = atoms_per_config.view(-1, 1, 1) atoms_per_config_3d = atoms_per_config.view(-1, 1, 1)
self.ref_virials_per_atom.append(batch.virials / atoms_per_config_3d) self.ref_virials_per_atom.append(batch.virials / atoms_per_config_3d)
self.pred_virials_per_atom.append(output["virials"] / atoms_per_config_3d) self.pred_virials_per_atom.append(output["virials"] / atoms_per_config_3d)
# Dipole # Dipole
if output.get("dipole") is not None and batch.dipole is not None: if output.get("dipole") is not None and batch.dipole is not None:
self.n_dipole += 1.0 self.n_dipole += 1.0
self.ref_dipole.append(batch.dipole) self.ref_dipole.append(batch.dipole)
self.pred_dipole.append(output["dipole"]) self.pred_dipole.append(output["dipole"])
            atoms_per_config_2d = atoms_per_config.view(-1, 1)
            self.ref_dipole_per_atom.append(batch.dipole / atoms_per_config_2d)
            self.pred_dipole_per_atom.append(output["dipole"] / atoms_per_config_2d)
def _process_data(self, ref_list, pred_list): def _process_data(self, ref_list, pred_list):
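        """Concatenate the collected reference/prediction tensors and return them as
        flat NumPy arrays, or (None, None) if the state is empty or of an unexpected type."""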
# Handle different possible states of ref_list and pred_list in distributed mode # Handle different possible states of ref_list and pred_list in distributed mode
# Check if this is a list type object # Check if this is a list type object
if isinstance(ref_list, (list, tuple)): if isinstance(ref_list, (list, tuple)):
if len(ref_list) == 0: if len(ref_list) == 0:
return None, None return None, None
ref = torch.cat(ref_list).reshape(-1) ref = torch.cat(ref_list).reshape(-1)
pred = torch.cat(pred_list).reshape(-1) pred = torch.cat(pred_list).reshape(-1)
# Handle case where ref_list is already a tensor (happens after reset in distributed mode) # Handle case where ref_list is already a tensor (happens after reset in distributed mode)
elif isinstance(ref_list, torch.Tensor): elif isinstance(ref_list, torch.Tensor):
ref = ref_list.reshape(-1) ref = ref_list.reshape(-1)
pred = pred_list.reshape(-1) pred = pred_list.reshape(-1)
# Handle other possible types # Handle other possible types
else: else:
return None, None return None, None
return to_numpy(ref), to_numpy(pred) return to_numpy(ref), to_numpy(pred)
def compute(self): def compute(self):
"""Compute final results for scatterplot.""" """Compute final results for scatterplot."""
results = {} results = {}
# Process energies # Process energies
if self.n_energy: if self.n_energy:
ref_e, pred_e = self._process_data(self.ref_energies, self.pred_energies) ref_e, pred_e = self._process_data(self.ref_energies, self.pred_energies)
ref_e_pa, pred_e_pa = self._process_data( ref_e_pa, pred_e_pa = self._process_data(
self.ref_energies_per_atom, self.pred_energies_per_atom self.ref_energies_per_atom, self.pred_energies_per_atom
) )
results["energy"] = { results["energy"] = {
"reference": ref_e, "reference": ref_e,
"predicted": pred_e, "predicted": pred_e,
"reference_per_atom": ref_e_pa, "reference_per_atom": ref_e_pa,
"predicted_per_atom": pred_e_pa, "predicted_per_atom": pred_e_pa,
} }
# Process forces # Process forces
if self.n_forces: if self.n_forces:
ref_f, pred_f = self._process_data(self.ref_forces, self.pred_forces) ref_f, pred_f = self._process_data(self.ref_forces, self.pred_forces)
results["forces"] = { results["forces"] = {
"reference": ref_f, "reference": ref_f,
"predicted": pred_f, "predicted": pred_f,
} }
# Process stress # Process stress
if self.n_stress: if self.n_stress:
ref_s, pred_s = self._process_data(self.ref_stress, self.pred_stress) ref_s, pred_s = self._process_data(self.ref_stress, self.pred_stress)
results["stress"] = { results["stress"] = {
"reference": ref_s, "reference": ref_s,
"predicted": pred_s, "predicted": pred_s,
} }
# Process virials # Process virials
if self.n_virials: if self.n_virials:
ref_v, pred_v = self._process_data(self.ref_virials, self.pred_virials) ref_v, pred_v = self._process_data(self.ref_virials, self.pred_virials)
ref_v_pa, pred_v_pa = self._process_data( ref_v_pa, pred_v_pa = self._process_data(
self.ref_virials_per_atom, self.pred_virials_per_atom self.ref_virials_per_atom, self.pred_virials_per_atom
) )
results["virials"] = { results["virials"] = {
"reference": ref_v, "reference": ref_v,
"predicted": pred_v, "predicted": pred_v,
"reference_per_atom": ref_v_pa, "reference_per_atom": ref_v_pa,
"predicted_per_atom": pred_v_pa, "predicted_per_atom": pred_v_pa,
} }
# Process dipoles # Process dipoles
if self.n_dipole: if self.n_dipole:
ref_d, pred_d = self._process_data(self.ref_dipole, self.pred_dipole) ref_d, pred_d = self._process_data(self.ref_dipole, self.pred_dipole)
ref_d_pa, pred_d_pa = self._process_data( ref_d_pa, pred_d_pa = self._process_data(
self.ref_dipole_per_atom, self.pred_dipole_per_atom self.ref_dipole_per_atom, self.pred_dipole_per_atom
) )
results["dipole"] = { results["dipole"] = {
"reference": ref_d, "reference": ref_d,
"predicted": pred_d, "predicted": pred_d,
"reference_per_atom": ref_d_pa, "reference_per_atom": ref_d_pa,
"predicted_per_atom": pred_d_pa, "predicted_per_atom": pred_d_pa,
} }
return results return results
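# Minimal usage sketch for InferenceMetric (illustrative only; `loader`, `model`
# and `device` are assumed to exist and the model to support these keyword arguments):
#   metric = InferenceMetric().to(device)
#   for batch in loader:
#       batch = batch.to(device)
#       out = model(batch.to_dict(), training=False, compute_force=True)
#       metric.update(batch, out)
#   parity = metric.compute()  # e.g. {"energy": {...}, "forces": {...}}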
from .atomic_data import AtomicData from .atomic_data import AtomicData
from .hdf5_dataset import HDF5Dataset, dataset_from_sharded_hdf5 from .hdf5_dataset import HDF5Dataset, dataset_from_sharded_hdf5
from .lmdb_dataset import LMDBDataset from .lmdb_dataset import LMDBDataset
from .neighborhood import get_neighborhood from .neighborhood import get_neighborhood
from .utils import ( from .utils import (
Configuration, Configuration,
Configurations, Configurations,
KeySpecification, KeySpecification,
compute_average_E0s, compute_average_E0s,
config_from_atoms, config_from_atoms,
config_from_atoms_list, config_from_atoms_list,
load_from_xyz, load_from_xyz,
random_train_valid_split, random_train_valid_split,
save_AtomicData_to_HDF5, save_AtomicData_to_HDF5,
save_configurations_as_HDF5, save_configurations_as_HDF5,
save_dataset_as_HDF5, save_dataset_as_HDF5,
test_config_types, test_config_types,
update_keyspec_from_kwargs, update_keyspec_from_kwargs,
) )
__all__ = [ __all__ = [
"get_neighborhood", "get_neighborhood",
"Configuration", "Configuration",
"Configurations", "Configurations",
"random_train_valid_split", "random_train_valid_split",
"load_from_xyz", "load_from_xyz",
"test_config_types", "test_config_types",
"config_from_atoms", "config_from_atoms",
"config_from_atoms_list", "config_from_atoms_list",
"AtomicData", "AtomicData",
"compute_average_E0s", "compute_average_E0s",
"save_dataset_as_HDF5", "save_dataset_as_HDF5",
"HDF5Dataset", "HDF5Dataset",
"dataset_from_sharded_hdf5", "dataset_from_sharded_hdf5",
"save_AtomicData_to_HDF5", "save_AtomicData_to_HDF5",
"save_configurations_as_HDF5", "save_configurations_as_HDF5",
"KeySpecification", "KeySpecification",
"update_keyspec_from_kwargs", "update_keyspec_from_kwargs",
"LMDBDataset", "LMDBDataset",
] ]