###########################################################################################
# Parsing functionalities
# Authors: Ilyes Batatia, Gregor Simm, David Kovacs
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
import argparse
import os
from typing import Optional
from .default_keys import DefaultKeys
def build_default_arg_parser() -> argparse.ArgumentParser:
try:
import configargparse
parser = configargparse.ArgumentParser(
config_file_parser_class=configargparse.YAMLConfigFileParser,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add(
"--config",
type=str,
is_config_file=True,
help="config file to aggregate options",
)
except ImportError:
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Name and seed
parser.add_argument("--name", help="experiment name", required=True)
parser.add_argument("--seed", help="random seed", type=int, default=123)
# Directories
parser.add_argument(
"--work_dir",
help="set directory for all files and folders",
type=str,
default=".",
)
parser.add_argument(
"--log_dir", help="directory for log files", type=str, default=None
)
parser.add_argument(
"--model_dir", help="directory for final model", type=str, default=None
)
parser.add_argument(
"--checkpoints_dir",
help="directory for checkpoint files",
type=str,
default=None,
)
parser.add_argument(
"--results_dir", help="directory for results", type=str, default=None
)
parser.add_argument(
"--downloads_dir", help="directory for downloads", type=str, default=None
)
# Device and logging
parser.add_argument(
"--device",
help="select device",
type=str,
choices=["cpu", "cuda", "mps", "xpu"],
default="cpu",
)
parser.add_argument(
"--default_dtype",
help="set default dtype",
type=str,
choices=["float32", "float64"],
default="float64",
)
parser.add_argument(
"--distributed",
help="train in multi-GPU data parallel mode",
action="store_true",
default=False,
)
parser.add_argument("--log_level", help="log level", type=str, default="INFO")
parser.add_argument(
"--plot",
help="Plot results of training",
type=str2bool,
default=True,
)
parser.add_argument(
"--plot_frequency",
help="Set plotting frequency: '0' for only at the end or an integer N to plot every N epochs.",
type=int,
default="0",
)
parser.add_argument(
"--error_table",
help="Type of error table produced at the end of the training",
type=str,
choices=[
"PerAtomRMSE",
"TotalRMSE",
"PerAtomRMSEstressvirials",
"PerAtomMAEstressvirials",
"PerAtomMAE",
"TotalMAE",
"DipoleRMSE",
"DipoleMAE",
"EnergyDipoleRMSE",
],
default="PerAtomRMSE",
)
# Model
parser.add_argument(
"--model",
help="model type",
default="MACE",
choices=[
"BOTNet",
"MACE",
"ScaleShiftMACE",
"ScaleShiftBOTNet",
"AtomicDipolesMACE",
"EnergyDipolesMACE",
],
)
parser.add_argument(
"--r_max", help="distance cutoff (in Ang)", type=float, default=5.0
)
parser.add_argument(
"--radial_type",
help="type of radial basis functions",
type=str,
default="bessel",
choices=["bessel", "gaussian", "chebyshev"],
)
parser.add_argument(
"--num_radial_basis",
help="number of radial basis functions",
type=int,
default=8,
)
parser.add_argument(
"--num_cutoff_basis",
help="number of basis functions for smooth cutoff",
type=int,
default=5,
)
parser.add_argument(
"--pair_repulsion",
help="use pair repulsion term with ZBL potential",
action="store_true",
default=False,
)
parser.add_argument(
"--distance_transform",
help="use distance transform for radial basis functions",
default="None",
choices=["None", "Agnesi", "Soft"],
)
parser.add_argument(
"--interaction",
help="name of interaction block",
type=str,
default="RealAgnosticResidualInteractionBlock",
choices=[
"RealAgnosticResidualInteractionBlock",
"RealAgnosticAttResidualInteractionBlock",
"RealAgnosticInteractionBlock",
"RealAgnosticDensityInteractionBlock",
"RealAgnosticDensityResidualInteractionBlock",
],
)
parser.add_argument(
"--interaction_first",
help="name of interaction block",
type=str,
default="RealAgnosticResidualInteractionBlock",
choices=[
"RealAgnosticResidualInteractionBlock",
"RealAgnosticInteractionBlock",
"RealAgnosticDensityInteractionBlock",
"RealAgnosticDensityResidualInteractionBlock",
],
)
parser.add_argument(
"--max_ell", help=r"highest \ell of spherical harmonics", type=int, default=3
)
parser.add_argument(
"--correlation", help="correlation order at each layer", type=int, default=3
)
parser.add_argument(
"--num_interactions", help="number of interactions", type=int, default=2
)
parser.add_argument(
"--MLP_irreps",
help="hidden irreps of the MLP in last readout",
type=str,
default="16x0e",
)
parser.add_argument(
"--radial_MLP",
help="width of the radial MLP",
type=str,
default="[64, 64, 64]",
)
parser.add_argument(
"--hidden_irreps",
help="irreps for hidden node states",
type=str,
default=None,
)
# add option to specify irreps by channel number and max L
parser.add_argument(
"--num_channels",
help="number of embedding channels",
type=int,
default=None,
)
parser.add_argument(
"--max_L",
help="max L equivariance of the message",
type=int,
default=None,
)
parser.add_argument(
"--gate",
help="non linearity for last readout",
type=str,
default="silu",
choices=["silu", "tanh", "abs", "None"],
)
parser.add_argument(
"--scaling",
help="type of scaling to the output",
type=str,
default="rms_forces_scaling",
choices=["std_scaling", "rms_forces_scaling", "no_scaling"],
)
parser.add_argument(
"--avg_num_neighbors",
help="normalization factor for the message",
type=float,
default=1,
)
parser.add_argument(
"--compute_avg_num_neighbors",
help="normalization factor for the message",
type=str2bool,
default=True,
)
parser.add_argument(
"--compute_stress",
help="Select True to compute stress",
type=str2bool,
default=False,
)
parser.add_argument(
"--compute_forces",
help="Select True to compute forces",
type=str2bool,
default=True,
)
# Dataset
parser.add_argument(
"--train_file",
help="Training set file, format is .xyz or .h5",
type=str,
required=False,
)
parser.add_argument(
"--valid_file",
help="Validation set .xyz or .h5 file",
default=None,
type=str,
required=False,
)
parser.add_argument(
"--valid_fraction",
help="Fraction of training set used for validation",
type=float,
default=0.1,
required=False,
)
parser.add_argument(
"--test_file",
help="Test set .xyz pt .h5 file",
type=str,
)
parser.add_argument(
"--test_dir",
help="Path to directory with test files named as test_*.h5",
type=str,
default=None,
required=False,
)
parser.add_argument(
"--multi_processed_test",
help="Boolean value for whether the test data was multiprocessed",
type=str2bool,
default=False,
required=False,
)
parser.add_argument(
"--num_workers",
help="Number of workers for data loading",
type=int,
default=0,
)
parser.add_argument(
"--pin_memory",
help="Pin memory for data loading",
default=True,
type=str2bool,
)
parser.add_argument(
"--atomic_numbers",
help="List of atomic numbers",
type=str,
default=None,
required=False,
)
parser.add_argument(
"--mean",
help="Mean energy per atom of training set",
type=float,
default=None,
required=False,
)
parser.add_argument(
"--std",
help="Standard deviation of force components in the training set",
type=float,
default=None,
required=False,
)
parser.add_argument(
"--statistics_file",
help="json file containing statistics of training set",
type=str,
default=None,
required=False,
)
parser.add_argument(
"--E0s",
help="Dictionary of isolated atom energies",
type=str,
default=None,
required=False,
)
# Fine-tuning
parser.add_argument(
"--foundation_filter_elements",
help="Filter element during fine-tuning",
type=str2bool,
default=True,
required=False,
)
parser.add_argument(
"--heads",
help="Dict of heads: containing individual files and E0s",
type=str,
default=None,
required=False,
)
parser.add_argument(
"--multiheads_finetuning",
help="Boolean value for whether the model is multiheaded",
type=str2bool,
default=True,
)
parser.add_argument(
"--foundation_head",
help="Name of the head to use for fine-tuning",
type=str,
default=None,
required=False,
)
parser.add_argument(
"--weight_pt_head",
help="Weight of the pretrained head in the loss function",
type=float,
default=1.0,
)
parser.add_argument(
"--num_samples_pt",
help="Number of samples in the pretrained head",
type=int,
default=10000,
)
parser.add_argument(
"--force_mh_ft_lr",
help="Force the multiheaded fine-tuning to use arg_parser lr",
type=str2bool,
default=False,
)
parser.add_argument(
"--subselect_pt",
help="Method to subselect the configurations of the pretraining set",
choices=["fps", "random"],
default="random",
)
parser.add_argument(
"--filter_type_pt",
help="Filtering method for collecting the pretraining set",
choices=["none", "combinations", "inclusive", "exclusive"],
default="none",
)
parser.add_argument(
"--pt_train_file",
help="Training set file for the pretrained head",
type=str,
default=None,
)
parser.add_argument(
"--pt_valid_file",
help="Validation set file for the pretrained head",
type=str,
default=None,
)
parser.add_argument(
"--foundation_model_elements",
help="Keep all elements of the foundation model during fine-tuning",
type=str2bool,
default=False,
)
parser.add_argument(
"--keep_isolated_atoms",
help="Keep isolated atoms in the dataset, useful for transfer learning",
type=str2bool,
default=False,
)
# Keys
parser.add_argument(
"--energy_key",
help="Key of reference energies in training xyz",
type=str,
default=DefaultKeys.ENERGY.value,
)
parser.add_argument(
"--forces_key",
help="Key of reference forces in training xyz",
type=str,
default=DefaultKeys.FORCES.value,
)
parser.add_argument(
"--virials_key",
help="Key of reference virials in training xyz",
type=str,
default=DefaultKeys.VIRIALS.value,
)
parser.add_argument(
"--stress_key",
help="Key of reference stress in training xyz",
type=str,
default=DefaultKeys.STRESS.value,
)
parser.add_argument(
"--dipole_key",
help="Key of reference dipoles in training xyz",
type=str,
default=DefaultKeys.DIPOLE.value,
)
parser.add_argument(
"--head_key",
help="Key of head in training xyz",
type=str,
default=DefaultKeys.HEAD.value,
)
parser.add_argument(
"--charges_key",
help="Key of atomic charges in training xyz",
type=str,
default=DefaultKeys.CHARGES.value,
)
parser.add_argument(
"--skip_evaluate_heads",
help="Comma-separated list of heads to skip during final evaluation",
type=str,
default="pt_head",
)
# Loss and optimization
parser.add_argument(
"--loss",
help="type of loss",
default="weighted",
choices=[
"ef",
"weighted",
"forces_only",
"virials",
"stress",
"dipole",
"huber",
"universal",
"energy_forces_dipole",
"l1l2energyforces",
],
)
parser.add_argument(
"--forces_weight", help="weight of forces loss", type=float, default=100.0
)
parser.add_argument(
"--swa_forces_weight",
"--stage_two_forces_weight",
help="weight of forces loss after starting Stage Two (previously called swa)",
type=float,
default=100.0,
dest="swa_forces_weight",
)
parser.add_argument(
"--energy_weight", help="weight of energy loss", type=float, default=1.0
)
parser.add_argument(
"--swa_energy_weight",
"--stage_two_energy_weight",
help="weight of energy loss after starting Stage Two (previously called swa)",
type=float,
default=1000.0,
dest="swa_energy_weight",
)
parser.add_argument(
"--virials_weight", help="weight of virials loss", type=float, default=1.0
)
parser.add_argument(
"--swa_virials_weight",
"--stage_two_virials_weight",
help="weight of virials loss after starting Stage Two (previously called swa)",
type=float,
default=10.0,
dest="swa_virials_weight",
)
parser.add_argument(
"--stress_weight", help="weight of stress loss", type=float, default=1.0
)
parser.add_argument(
"--swa_stress_weight",
"--stage_two_stress_weight",
help="weight of stress loss after starting Stage Two (previously called swa)",
type=float,
default=10.0,
dest="swa_stress_weight",
)
parser.add_argument(
"--dipole_weight", help="weight of dipoles loss", type=float, default=1.0
)
parser.add_argument(
"--swa_dipole_weight",
"--stage_two_dipole_weight",
help="weight of dipoles after starting Stage Two (previously called swa)",
type=float,
default=1.0,
dest="swa_dipole_weight",
)
parser.add_argument(
"--config_type_weights",
help="String of dictionary containing the weights for each config type",
type=str,
default='{"Default":1.0}',
)
parser.add_argument(
"--huber_delta",
help="delta parameter for huber loss",
type=float,
default=0.01,
)
parser.add_argument(
"--optimizer",
help="Optimizer for parameter optimization",
type=str,
default="adam",
choices=["adam", "adamw", "schedulefree"],
)
parser.add_argument(
"--beta",
help="Beta parameter for the optimizer",
type=float,
default=0.9,
)
parser.add_argument("--batch_size", help="batch size", type=int, default=10)
parser.add_argument(
"--valid_batch_size", help="Validation batch size", type=int, default=10
)
parser.add_argument(
"--lr", help="Learning rate of optimizer", type=float, default=0.01
)
parser.add_argument(
"--swa_lr",
"--stage_two_lr",
help="Learning rate of optimizer in Stage Two (previously called swa)",
type=float,
default=1e-3,
dest="swa_lr",
)
parser.add_argument(
"--weight_decay", help="weight decay (L2 penalty)", type=float, default=5e-7
)
parser.add_argument(
"--amsgrad",
help="use amsgrad variant of optimizer",
action="store_true",
default=True,
)
parser.add_argument(
"--scheduler", help="Type of scheduler", type=str, default="ReduceLROnPlateau"
)
parser.add_argument(
"--lr_factor", help="Learning rate factor", type=float, default=0.8
)
parser.add_argument(
"--scheduler_patience", help="Learning rate factor", type=int, default=50
)
parser.add_argument(
"--lr_scheduler_gamma",
help="Gamma of learning rate scheduler",
type=float,
default=0.9993,
)
parser.add_argument(
"--swa",
"--stage_two",
help="use Stage Two loss weight, which decreases the learning rate and increases the energy weight at the end of the training to help converge them",
action="store_true",
default=False,
dest="swa",
)
parser.add_argument(
"--start_swa",
"--start_stage_two",
help="Number of epochs before changing to Stage Two loss weights",
type=int,
default=None,
dest="start_swa",
)
parser.add_argument(
"--lbfgs",
help="Switch to L-BFGS optimizer",
action="store_true",
default=False,
)
parser.add_argument(
"--ema",
help="use Exponential Moving Average",
action="store_true",
default=False,
)
parser.add_argument(
"--ema_decay",
help="Exponential Moving Average decay",
type=float,
default=0.99,
)
parser.add_argument(
"--max_num_epochs", help="Maximum number of epochs", type=int, default=2048
)
parser.add_argument(
"--patience",
help="Maximum number of consecutive epochs of increasing loss",
type=int,
default=2048,
)
parser.add_argument(
"--foundation_model",
help="Path to the foundation model for transfer learning",
type=str,
default=None,
)
parser.add_argument(
"--foundation_model_readout",
help="Use readout of foundation model for transfer learning",
action="store_false",
default=True,
)
parser.add_argument(
"--eval_interval", help="evaluate model every <n> epochs", type=int, default=1
)
parser.add_argument(
"--keep_checkpoints",
help="keep all checkpoints",
action="store_true",
default=False,
)
parser.add_argument(
"--save_all_checkpoints",
help="save all checkpoints",
action="store_true",
default=False,
)
parser.add_argument(
"--restart_latest",
help="restart optimizer from latest checkpoint",
action="store_true",
default=False,
)
parser.add_argument(
"--save_cpu",
help="Save a model to be loaded on cpu",
action="store_true",
default=False,
)
parser.add_argument(
"--clip_grad",
help="Gradient Clipping Value",
type=check_float_or_none,
default=10.0,
)
parser.add_argument(
"--dry_run",
help="Run all steps upto training to test settings.",
action="store_true",
default=False,
)
# option for cuequivariance acceleration
parser.add_argument(
"--enable_cueq",
help="Enable cuequivariance acceleration",
type=str2bool,
default=False,
)
# options for using Weights and Biases for experiment tracking
# to install see https://wandb.ai
parser.add_argument(
"--wandb",
help="Use Weights and Biases for experiment tracking",
action="store_true",
default=False,
)
parser.add_argument(
"--wandb_dir",
help="An absolute path to a directory where Weights and Biases metadata will be stored",
type=str,
default=None,
)
parser.add_argument(
"--wandb_project",
help="Weights and Biases project name",
type=str,
default="",
)
parser.add_argument(
"--wandb_entity",
help="Weights and Biases entity name",
type=str,
default="",
)
parser.add_argument(
"--wandb_name",
help="Weights and Biases experiment name",
type=str,
default="",
)
parser.add_argument(
"--wandb_log_hypers",
help="The hyperparameters to log in Weights and Biases",
        nargs="+",
default=[
"num_channels",
"max_L",
"correlation",
"lr",
"swa_lr",
"weight_decay",
"batch_size",
"max_num_epochs",
"start_swa",
"energy_weight",
"forces_weight",
],
)
return parser
def build_preprocess_arg_parser() -> argparse.ArgumentParser:
try:
import configargparse
parser = configargparse.ArgumentParser(
config_file_parser_class=configargparse.YAMLConfigFileParser,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add(
"--config",
type=str,
is_config_file=True,
help="config file to aggregate options",
)
except ImportError:
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--train_file",
help="Training set h5 file",
type=str,
default=None,
required=True,
)
parser.add_argument(
"--valid_file",
help="Training set xyz file",
type=str,
default=None,
required=False,
)
parser.add_argument(
"--num_process",
help="The user defined number of processes to use, as well as the number of files created.",
type=int,
default=int(os.cpu_count() / 4),
)
parser.add_argument(
"--valid_fraction",
help="Fraction of training set used for validation",
type=float,
default=0.1,
required=False,
)
parser.add_argument(
"--test_file",
help="Test set xyz file",
type=str,
default=None,
required=False,
)
parser.add_argument(
"--work_dir",
help="set directory for all files and folders",
type=str,
default=".",
)
parser.add_argument(
"--h5_prefix",
help="Prefix for h5 files when saving",
type=str,
default="",
)
parser.add_argument(
"--r_max", help="distance cutoff (in Ang)", type=float, default=5.0
)
parser.add_argument(
"--config_type_weights",
help="String of dictionary containing the weights for each config type",
type=str,
default='{"Default":1.0}',
)
parser.add_argument(
"--energy_key",
help="Key of reference energies in training xyz",
type=str,
default=DefaultKeys.ENERGY.value,
)
parser.add_argument(
"--forces_key",
help="Key of reference forces in training xyz",
type=str,
default=DefaultKeys.FORCES.value,
)
parser.add_argument(
"--virials_key",
help="Key of reference virials in training xyz",
type=str,
default=DefaultKeys.VIRIALS.value,
)
parser.add_argument(
"--stress_key",
help="Key of reference stress in training xyz",
type=str,
default=DefaultKeys.STRESS.value,
)
parser.add_argument(
"--dipole_key",
help="Key of reference dipoles in training xyz",
type=str,
default=DefaultKeys.DIPOLE.value,
)
parser.add_argument(
"--charges_key",
help="Key of atomic charges in training xyz",
type=str,
default=DefaultKeys.CHARGES.value,
)
parser.add_argument(
"--atomic_numbers",
help="List of atomic numbers",
type=str,
default=None,
required=False,
)
parser.add_argument(
"--compute_statistics",
help="Compute statistics for the dataset",
action="store_true",
default=False,
)
parser.add_argument(
"--batch_size",
help="batch size to compute average number of neighbours",
type=int,
default=16,
)
parser.add_argument(
"--scaling",
help="type of scaling to the output",
type=str,
default="rms_forces_scaling",
choices=["std_scaling", "rms_forces_scaling", "no_scaling"],
)
parser.add_argument(
"--E0s",
help="Dictionary of isolated atom energies",
type=str,
default=None,
required=False,
)
parser.add_argument(
"--shuffle",
help="Shuffle the training dataset",
type=str2bool,
default=True,
)
parser.add_argument(
"--seed",
help="Random seed for splitting training and validation sets",
type=int,
default=123,
)
parser.add_argument(
"--head_key",
help="Key of head in training xyz",
type=str,
default=DefaultKeys.HEAD.value,
)
parser.add_argument(
"--heads",
help="Dict of heads: containing individual files and E0s",
type=str,
default=None,
required=False,
)
return parser
def check_float_or_none(value: str) -> Optional[float]:
try:
return float(value)
except ValueError:
if value != "None":
raise argparse.ArgumentTypeError(
f"{value} is an invalid value (float or None)"
) from None
return None
def str2bool(value):
if isinstance(value, bool):
return value
if value.lower() in ("yes", "true", "t", "y", "1"):
return True
if value.lower() in ("no", "false", "f", "n", "0"):
return False
raise argparse.ArgumentTypeError("Boolean value expected.")
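# Illustrative usage sketch (not part of the original module): how the parser
# above might be driven programmatically. The flag values are placeholders.
def _example_build_parser() -> None:
    parser = build_default_arg_parser()
    args = parser.parse_args(["--name", "demo", "--train_file", "train.xyz"])
    assert args.name == "demo"
    assert args.r_max == 5.0  # default distance cutoff
    assert args.plot is True  # str2bool-typed flag with a True default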
import logging
import os
from e3nn import o3
def check_args(args):
"""
Check input arguments, update them if necessary for valid and consistent inputs, and return a tuple containing
the (potentially) modified args and a list of log messages.
"""
log_messages = []
# Directories
# Use work_dir for all other directories as well, unless they were specified by the user
if args.log_dir is None:
args.log_dir = os.path.join(args.work_dir, "logs")
if args.model_dir is None:
args.model_dir = args.work_dir
if args.checkpoints_dir is None:
args.checkpoints_dir = os.path.join(args.work_dir, "checkpoints")
if args.results_dir is None:
args.results_dir = os.path.join(args.work_dir, "results")
if args.downloads_dir is None:
args.downloads_dir = os.path.join(args.work_dir, "downloads")
# Model
# Check if hidden_irreps, num_channels and max_L are consistent
if args.hidden_irreps is None and args.num_channels is None and args.max_L is None:
args.hidden_irreps, args.num_channels, args.max_L = "128x0e + 128x1o", 128, 1
elif (
args.hidden_irreps is not None
and args.num_channels is not None
and args.max_L is not None
):
args.hidden_irreps = o3.Irreps(
(args.num_channels * o3.Irreps.spherical_harmonics(args.max_L))
.sort()
.irreps.simplify()
)
log_messages.append(
(
"All of hidden_irreps, num_channels and max_L are specified",
logging.WARNING,
)
)
log_messages.append(
(
f"Using num_channels and max_L to create hidden_irreps: {args.hidden_irreps}.",
logging.WARNING,
)
)
assert (
len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1
), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L"
elif args.num_channels is not None and args.max_L is not None:
assert args.num_channels > 0, "num_channels must be positive integer"
assert args.max_L >= 0, "max_L must be non-negative integer"
args.hidden_irreps = o3.Irreps(
(args.num_channels * o3.Irreps.spherical_harmonics(args.max_L))
.sort()
.irreps.simplify()
)
assert (
len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1
), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L"
elif args.hidden_irreps is not None:
assert (
len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1
), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L"
args.num_channels = list(
{irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}
)[0]
args.max_L = o3.Irreps(args.hidden_irreps).lmax
elif args.max_L is not None and args.num_channels is None:
assert args.max_L >= 0, "max_L must be non-negative integer"
args.num_channels = 128
args.hidden_irreps = o3.Irreps(
(args.num_channels * o3.Irreps.spherical_harmonics(args.max_L))
.sort()
.irreps.simplify()
)
elif args.max_L is None and args.num_channels is not None:
assert args.num_channels > 0, "num_channels must be positive integer"
args.max_L = 1
args.hidden_irreps = o3.Irreps(
(args.num_channels * o3.Irreps.spherical_harmonics(args.max_L))
.sort()
.irreps.simplify()
)
# Loss and optimization
# Check Stage Two loss start
if args.start_swa is not None:
args.swa = True
log_messages.append(
(
"Stage Two is activated as start_stage_two was defined",
logging.INFO,
)
)
if args.swa:
if args.start_swa is None:
args.start_swa = max(1, args.max_num_epochs // 4 * 3)
if args.start_swa > args.max_num_epochs:
log_messages.append(
(
f"start_stage_two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}",
logging.WARNING,
)
)
log_messages.append(
(
"Stage Two will not start, as start_stage_two > max_num_epochs",
logging.WARNING,
)
)
args.swa = False
return args, log_messages
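# Illustrative sketch (not part of the original module) of the hidden_irreps
# construction used in check_args: num_channels copies of every irrep up to max_L.
def _example_hidden_irreps() -> None:
    num_channels, max_L = 128, 1
    hidden_irreps = o3.Irreps(
        (num_channels * o3.Irreps.spherical_harmonics(max_L)).sort().irreps.simplify()
    )
    assert str(hidden_irreps) == "128x0e+128x1o"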
###########################################################################################
# Higher Order Real Clebsch Gordan (based on e3nn by Mario Geiger)
# Authors: Ilyes Batatia
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
import collections
import itertools
import os
from typing import Iterator, List, Union
import numpy as np
import torch
from e3nn import o3
try:
import cuequivariance as cue
CUET_AVAILABLE = True
except ImportError:
CUET_AVAILABLE = False
USE_CUEQ_CG = os.environ.get("MACE_USE_CUEQ_CG", "0").lower() in (
"1",
"true",
"yes",
"y",
)
_TP = collections.namedtuple("_TP", "op, args")
_INPUT = collections.namedtuple("_INPUT", "tensor, start, stop")
def _wigner_nj(
irrepss: List[o3.Irreps],
normalization: str = "component",
filter_ir_mid=None,
dtype=None,
):
irrepss = [o3.Irreps(irreps) for irreps in irrepss]
if filter_ir_mid is not None:
filter_ir_mid = [o3.Irrep(ir) for ir in filter_ir_mid]
if len(irrepss) == 1:
(irreps,) = irrepss
ret = []
e = torch.eye(irreps.dim, dtype=dtype)
i = 0
for mul, ir in irreps:
for _ in range(mul):
sl = slice(i, i + ir.dim)
ret += [(ir, _INPUT(0, sl.start, sl.stop), e[sl])]
i += ir.dim
return ret
*irrepss_left, irreps_right = irrepss
ret = []
for ir_left, path_left, C_left in _wigner_nj(
irrepss_left,
normalization=normalization,
filter_ir_mid=filter_ir_mid,
dtype=dtype,
):
i = 0
for mul, ir in irreps_right:
for ir_out in ir_left * ir:
if filter_ir_mid is not None and ir_out not in filter_ir_mid:
continue
C = o3.wigner_3j(ir_out.l, ir_left.l, ir.l, dtype=dtype)
if normalization == "component":
C *= ir_out.dim**0.5
if normalization == "norm":
C *= ir_left.dim**0.5 * ir.dim**0.5
C = torch.einsum("jk,ijl->ikl", C_left.flatten(1), C)
C = C.reshape(
ir_out.dim, *(irreps.dim for irreps in irrepss_left), ir.dim
)
for u in range(mul):
E = torch.zeros(
ir_out.dim,
*(irreps.dim for irreps in irrepss_left),
irreps_right.dim,
dtype=dtype,
)
sl = slice(i + u * ir.dim, i + (u + 1) * ir.dim)
E[..., sl] = C
ret += [
(
ir_out,
_TP(
op=(ir_left, ir, ir_out),
args=(
path_left,
_INPUT(len(irrepss_left), sl.start, sl.stop),
),
),
E,
)
]
i += mul * ir.dim
return sorted(ret, key=lambda x: x[0])
def U_matrix_real(
irreps_in: Union[str, o3.Irreps],
irreps_out: Union[str, o3.Irreps],
correlation: int,
normalization: str = "component",
filter_ir_mid=None,
dtype=None,
use_cueq_cg=None,
):
irreps_out = o3.Irreps(irreps_out)
irrepss = [o3.Irreps(irreps_in)] * correlation
if correlation == 4:
filter_ir_mid = [(i, 1 if i % 2 == 0 else -1) for i in range(12)]
if use_cueq_cg is None:
use_cueq_cg = USE_CUEQ_CG
if use_cueq_cg and CUET_AVAILABLE:
return compute_U_cueq(irreps_in, irreps_out=irreps_out, correlation=correlation)
try:
wigners = _wigner_nj(irrepss, normalization, filter_ir_mid, dtype)
except NotImplementedError as e:
if CUET_AVAILABLE:
return compute_U_cueq(
irreps_in, irreps_out=irreps_out, correlation=correlation
)
raise NotImplementedError(
"The requested Clebsch-Gordan coefficients are not implemented, please install cuequivariance; pip install cuequivariance"
) from e
current_ir = wigners[0][0]
out = []
stack = torch.tensor([])
for ir, _, base_o3 in wigners:
if ir in irreps_out and ir == current_ir:
stack = torch.cat((stack, base_o3.squeeze().unsqueeze(-1)), dim=-1)
last_ir = current_ir
elif ir in irreps_out and ir != current_ir:
if len(stack) != 0:
out += [last_ir, stack]
stack = base_o3.squeeze().unsqueeze(-1)
current_ir, last_ir = ir, ir
else:
current_ir = ir
out += [last_ir, stack]
return out
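# Illustrative usage sketch (not part of the original module), assuming the
# default e3nn code path (MACE_USE_CUEQ_CG unset): the output alternates irrep
# labels and coefficient tensors, as assembled above.
def _example_U_matrix_real() -> None:
    U = U_matrix_real(irreps_in="1x0e+1x1o", irreps_out="1x0e", correlation=2)
    ir, coeffs = U[0], U[1]
    assert str(ir) == "0e"
    # Two input axes of dimension 4 and one axis over the two coupling paths
    # (0e x 0e -> 0e and 1o x 1o -> 0e); the ir.dim == 1 axis is squeezed away.
    assert coeffs.shape == (4, 4, 2)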
if CUET_AVAILABLE:
def compute_U_cueq(irreps_in, irreps_out, correlation=2):
U = []
irreps_in = cue.Irreps(O3_e3nn, str(irreps_in))
irreps_out = cue.Irreps(O3_e3nn, str(irreps_out))
for _, ir in irreps_out:
ir_str = str(ir)
U.append(ir_str)
U_matrix = cue.reduced_symmetric_tensor_product_basis(
irreps_in, correlation, keep_ir=ir, layout=cue.ir_mul
).array
U_matrix = U_matrix.reshape(ir.dim, *([irreps_in.dim] * correlation), -1)
if ir.dim == 1:
U_matrix = U_matrix[0]
U.append(torch.tensor(U_matrix))
return U
class O3_e3nn(cue.O3):
def __mul__( # pylint: disable=no-self-argument
rep1: "O3_e3nn", rep2: "O3_e3nn"
) -> Iterator["O3_e3nn"]:
return [O3_e3nn(l=ir.l, p=ir.p) for ir in cue.O3.__mul__(rep1, rep2)]
@classmethod
def clebsch_gordan(
cls, rep1: "O3_e3nn", rep2: "O3_e3nn", rep3: "O3_e3nn"
) -> np.ndarray:
rep1, rep2, rep3 = cls._from(rep1), cls._from(rep2), cls._from(rep3)
if rep1.p * rep2.p == rep3.p:
return o3.wigner_3j(rep1.l, rep2.l, rep3.l).numpy()[None] * np.sqrt(
rep3.dim
)
return np.zeros((0, rep1.dim, rep2.dim, rep3.dim))
def __lt__( # pylint: disable=no-self-argument
rep1: "O3_e3nn", rep2: "O3_e3nn"
) -> bool:
rep2 = rep1._from(rep2)
return (rep1.l, rep1.p) < (rep2.l, rep2.p)
@classmethod
def iterator(cls) -> Iterator["O3_e3nn"]:
for l in itertools.count(0):
yield O3_e3nn(l=l, p=1 * (-1) ** l)
yield O3_e3nn(l=l, p=-1 * (-1) ** l)
else:
class O3_e3nn:
pass
print(
"cuequivariance or cuequivariance_torch is not available. Cuequivariance acceleration will be disabled."
)
###########################################################################################
# Checkpointing
# Authors: Gregor Simm
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
import dataclasses
import logging
import os
import re
from typing import Dict, List, Optional, Tuple
import torch
from .torch_tools import TensorDict
Checkpoint = Dict[str, TensorDict]
@dataclasses.dataclass
class CheckpointState:
model: torch.nn.Module
optimizer: torch.optim.Optimizer
lr_scheduler: torch.optim.lr_scheduler.ExponentialLR
class CheckpointBuilder:
@staticmethod
def create_checkpoint(state: CheckpointState) -> Checkpoint:
return {
"model": state.model.state_dict(),
"optimizer": state.optimizer.state_dict(),
"lr_scheduler": state.lr_scheduler.state_dict(),
}
@staticmethod
def load_checkpoint(
state: CheckpointState, checkpoint: Checkpoint, strict: bool
) -> None:
state.model.load_state_dict(checkpoint["model"], strict=strict) # type: ignore
state.optimizer.load_state_dict(checkpoint["optimizer"])
state.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
@dataclasses.dataclass
class CheckpointPathInfo:
path: str
tag: str
epochs: int
swa: bool
class CheckpointIO:
def __init__(
        self, directory: str, tag: str, keep: bool = False, swa_start: Optional[int] = None
) -> None:
self.directory = directory
self.tag = tag
self.keep = keep
self.old_path: Optional[str] = None
self.swa_start = swa_start
self._epochs_string = "_epoch-"
self._filename_extension = "pt"
def _get_checkpoint_filename(self, epochs: int, swa_start=None) -> str:
if swa_start is not None and epochs >= swa_start:
return (
self.tag
+ self._epochs_string
+ str(epochs)
+ "_swa"
+ "."
+ self._filename_extension
)
return (
self.tag
+ self._epochs_string
+ str(epochs)
+ "."
+ self._filename_extension
)
def _list_file_paths(self) -> List[str]:
if not os.path.isdir(self.directory):
return []
all_paths = [
os.path.join(self.directory, f) for f in os.listdir(self.directory)
]
return [path for path in all_paths if os.path.isfile(path)]
def _parse_checkpoint_path(self, path: str) -> Optional[CheckpointPathInfo]:
filename = os.path.basename(path)
regex = re.compile(
rf"^(?P<tag>.+){self._epochs_string}(?P<epochs>\d+)\.{self._filename_extension}$"
)
regex2 = re.compile(
rf"^(?P<tag>.+){self._epochs_string}(?P<epochs>\d+)_swa\.{self._filename_extension}$"
)
match = regex.match(filename)
match2 = regex2.match(filename)
swa = False
if not match:
if not match2:
return None
match = match2
swa = True
return CheckpointPathInfo(
path=path,
tag=match.group("tag"),
epochs=int(match.group("epochs")),
swa=swa,
)
def _get_latest_checkpoint_path(self, swa) -> Optional[str]:
all_file_paths = self._list_file_paths()
checkpoint_info_list = [
self._parse_checkpoint_path(path) for path in all_file_paths
]
selected_checkpoint_info_list = [
info for info in checkpoint_info_list if info and info.tag == self.tag
]
if len(selected_checkpoint_info_list) == 0:
logging.warning(
f"Cannot find checkpoint with tag '{self.tag}' in '{self.directory}'"
)
return None
selected_checkpoint_info_list_swa = []
selected_checkpoint_info_list_no_swa = []
for ckp in selected_checkpoint_info_list:
if ckp.swa:
selected_checkpoint_info_list_swa.append(ckp)
else:
selected_checkpoint_info_list_no_swa.append(ckp)
        if swa:
            try:
                latest_checkpoint_info = max(
                    selected_checkpoint_info_list_swa, key=lambda info: info.epochs
                )
            except ValueError:
                logging.warning(
                    "No SWA checkpoint found, while SWA is enabled. Compare the swa_start parameter and the latest checkpoint."
                )
                return None
        else:
            latest_checkpoint_info = max(
                selected_checkpoint_info_list_no_swa, key=lambda info: info.epochs
            )
        return latest_checkpoint_info.path
def save(
self, checkpoint: Checkpoint, epochs: int, keep_last: bool = False
) -> None:
if not self.keep and self.old_path and not keep_last:
logging.debug(f"Deleting old checkpoint file: {self.old_path}")
os.remove(self.old_path)
filename = self._get_checkpoint_filename(epochs, self.swa_start)
path = os.path.join(self.directory, filename)
logging.debug(f"Saving checkpoint: {path}")
os.makedirs(self.directory, exist_ok=True)
torch.save(obj=checkpoint, f=path)
self.old_path = path
def load_latest(
self, swa: Optional[bool] = False, device: Optional[torch.device] = None
) -> Optional[Tuple[Checkpoint, int]]:
path = self._get_latest_checkpoint_path(swa=swa)
if path is None:
return None
return self.load(path, device=device)
def load(
self, path: str, device: Optional[torch.device] = None
) -> Tuple[Checkpoint, int]:
checkpoint_info = self._parse_checkpoint_path(path)
if checkpoint_info is None:
raise RuntimeError(f"Cannot find path '{path}'")
logging.info(f"Loading checkpoint: {checkpoint_info.path}")
return (
torch.load(f=checkpoint_info.path, map_location=device),
checkpoint_info.epochs,
)
class CheckpointHandler:
def __init__(self, *args, **kwargs) -> None:
self.io = CheckpointIO(*args, **kwargs)
self.builder = CheckpointBuilder()
def save(
self, state: CheckpointState, epochs: int, keep_last: bool = False
) -> None:
checkpoint = self.builder.create_checkpoint(state)
self.io.save(checkpoint, epochs, keep_last)
def load_latest(
self,
state: CheckpointState,
swa: Optional[bool] = False,
device: Optional[torch.device] = None,
strict=False,
) -> Optional[int]:
result = self.io.load_latest(swa=swa, device=device)
if result is None:
return None
checkpoint, epochs = result
self.builder.load_checkpoint(state=state, checkpoint=checkpoint, strict=strict)
return epochs
def load(
self,
state: CheckpointState,
path: str,
strict=False,
device: Optional[torch.device] = None,
) -> int:
checkpoint, epochs = self.io.load(path, device=device)
self.builder.load_checkpoint(state=state, checkpoint=checkpoint, strict=strict)
return epochs
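# Illustrative usage sketch (not part of the original module): a save/restore
# round trip. The model, optimizer, and directory are stand-ins, not MACE objects.
def _example_checkpoint_roundtrip(directory: str = "checkpoints_demo") -> None:
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
    state = CheckpointState(model=model, optimizer=optimizer, lr_scheduler=lr_scheduler)
    handler = CheckpointHandler(directory=directory, tag="demo", keep=False)
    handler.save(state, epochs=10)
    epochs = handler.load_latest(state=state, device=torch.device("cpu"))
    assert epochs == 10  # weights and schedules are restored in place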
from contextlib import contextmanager
from functools import wraps
from typing import Callable, Tuple
try:
import torch._dynamo as dynamo
except ImportError:
dynamo = None
from e3nn import get_optimization_defaults, set_optimization_defaults
from torch import autograd, nn
from torch.fx import symbolic_trace
ModuleFactory = Callable[..., nn.Module]
TypeTuple = Tuple[type, ...]
@contextmanager
def disable_e3nn_codegen():
"""Context manager that disables the legacy PyTorch code generation used in e3nn."""
    init_val = get_optimization_defaults()["jit_script_fx"]
    set_optimization_defaults(jit_script_fx=False)
    try:
        yield
    finally:
        set_optimization_defaults(jit_script_fx=init_val)
def prepare(func: ModuleFactory, allow_autograd: bool = True) -> ModuleFactory:
"""Function transform that prepares a MACE module for torch.compile
Args:
func (ModuleFactory): A function that creates an nn.Module
allow_autograd (bool, optional): Force inductor compiler to inline call to
`torch.autograd.grad`. Defaults to True.
Returns:
ModuleFactory: Decorated function that creates a torch.compile compatible module
"""
    if dynamo is not None:
        if allow_autograd:
            dynamo.allow_in_graph(autograd.grad)
        else:
            dynamo.disallow_in_graph(autograd.grad)
@wraps(func)
def wrapper(*args, **kwargs):
with disable_e3nn_codegen():
model = func(*args, **kwargs)
model = simplify(model)
return model
return wrapper
_SIMPLIFY_REGISTRY = set()
def simplify_if_compile(module: nn.Module) -> nn.Module:
"""Decorator to register a module for symbolic simplification
The decorated module will be simplifed using `torch.fx.symbolic_trace`.
This constrains the module to not have any dynamic control flow, see:
https://pytorch.org/docs/stable/fx.html#limitations-of-symbolic-tracing
Args:
module (nn.Module): the module to register
Returns:
nn.Module: registered module
"""
_SIMPLIFY_REGISTRY.add(module)
return module
def simplify(module: nn.Module) -> nn.Module:
"""Recursively searches for registered modules to simplify with
`torch.fx.symbolic_trace` to support compiling with the PyTorch Dynamo compiler.
    Modules are registered with the `simplify_if_compile` decorator and are
    replaced in place by their symbolically traced counterparts.
Args:
module (nn.Module): the module to simplify
Returns:
nn.Module: the simplified module
"""
simplify_types = tuple(_SIMPLIFY_REGISTRY)
for name, child in module.named_children():
if isinstance(child, simplify_types):
traced = symbolic_trace(child)
setattr(module, name, traced)
else:
simplify(child)
return module
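# Illustrative usage sketch (not part of the original module): preparing a
# module factory for torch.compile. TinyModel is a stand-in, not a MACE module.
def _example_prepare() -> None:
    class TinyModel(nn.Module):
        def __init__(self, n: int = 8):
            super().__init__()
            self.linear = nn.Linear(n, n)
        def forward(self, x):
            return self.linear(x)
    factory = prepare(TinyModel)  # e3nn codegen is disabled while the module is built
    model = factory(n=8)
    assert isinstance(model, nn.Module)  # ready to pass to torch.compile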
from __future__ import annotations
from enum import Enum
class DefaultKeys(Enum):
ENERGY = "REF_energy"
FORCES = "REF_forces"
STRESS = "REF_stress"
VIRIALS = "REF_virials"
DIPOLE = "dipole"
HEAD = "head"
CHARGES = "REF_charges"
@staticmethod
def keydict() -> dict[str, str]:
key_dict = {}
for member in DefaultKeys:
key_name = f"{member.name.lower()}_key"
key_dict[key_name] = member.value
return key_dict
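# Illustrative sketch (not part of the original module): keydict() maps
# "<name>_key" argument names to their default values.
def _example_keydict() -> None:
    keys = DefaultKeys.keydict()
    assert keys["energy_key"] == "REF_energy"
    assert keys["dipole_key"] == "dipole"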
from .lmdb_dataset_tools import AseDBDataset
__all__ = ["AseDBDataset"]
# AseDBDataset Library
This library provides a standalone implementation of the AseDBDataset class extracted from the FairChem codebase. The AseDBDataset allows you to connect to ASE databases with various backends including JSON, SQLite, and LMDB.
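A minimal usage sketch (the import path and `src` file below are placeholders; adjust them to your installation):

```python
from lmdb_dataset_tools import AseDBDataset  # import path depends on packaging

dataset = AseDBDataset(config={"src": "path/to/data.aselmdb"})
print(len(dataset))  # number of stored structures
sample = dataset[0]  # graph-style data object with pos, atomic_numbers, pbc, ...
```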
## License Information
The code in this repository contains components from multiple sources with different licenses:
1. **Main Code (AseDBDataset, AseAtomsDataset, BaseDataset, etc.)**:
- Original Source: Meta's FairChem codebase
- License: MIT License
- Copyright: Meta, Inc. and its affiliates
2. **LMDBDatabase Component**:
- Original Source: Modified from ASE database JSON backend
- License: LGPL 2.1
- The ASE notice for the LGPL 2.1 license is available at: https://gitlab.com/ase/ase/-/blob/master/LICENSE
"""
This module contains the AseDBDataset class and its dependencies.
It is extracted from the fairchem codebase and adapted to remove dependencies on fairchem.
Original code copyright:
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from __future__ import annotations
import bisect
import logging
import os
import zlib
from abc import ABC, abstractmethod
try:
from functools import cache, cached_property
except ImportError:
from functools import cached_property, lru_cache
cache = lru_cache(maxsize=None)
from glob import glob
from pathlib import Path
from typing import Any, Callable, TypeVar
import ase
import ase.db.core
import ase.db.row
import ase.io
import lmdb
import numpy as np
import orjson
import torch
# Type variable for generic dataset return type
T_co = TypeVar("T_co", covariant=True)
def rename_data_object_keys(data_object, key_mapping: dict[str, str | list[str]]):
"""Rename data object keys
Args:
data_object: data object
key_mapping: dictionary specifying keys to rename and new names {prev_key: new_key}
new_key can be a list of new keys, for example,
prev_key: energy
new_key: [common_energy, oc20_energy]
This is currently required when we use a single target/label for multiple tasks
"""
for _property in key_mapping:
# catch for test data not containing labels
if _property in data_object:
list_of_new_keys = key_mapping[_property]
if isinstance(list_of_new_keys, str):
list_of_new_keys = [list_of_new_keys]
for new_property in list_of_new_keys:
if new_property == _property:
continue
assert new_property not in data_object
data_object[new_property] = data_object[_property]
if _property not in list_of_new_keys:
del data_object[_property]
return data_object
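# Illustrative sketch (not part of the original module): renaming keys on a
# plain dict, which stands in for a torch_geometric-style data object.
def _example_rename_keys() -> None:
    data = {"energy": 1.23, "forces": [0.0, 0.1]}
    out = rename_data_object_keys(data, {"energy": ["common_energy", "oc20_energy"]})
    assert sorted(out) == ["common_energy", "forces", "oc20_energy"]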
def apply_one_tags(
atoms: ase.Atoms, skip_if_nonzero: bool = True, skip_always: bool = False
):
"""
This function will apply tags of 1 to an ASE atoms object.
It is used as an atoms_transform in the datasets contained in this file.
Certain models will treat atoms differently depending on their tags.
For example, GemNet-OC by default will only compute triplet and quadruplet interactions
for atoms with non-zero tags. This model throws an error if there are no tagged atoms.
For this reason, the default behavior is to tag atoms in structures with no tags.
args:
skip_if_nonzero (bool): If at least one atom has a nonzero tag, do not tag any atoms
skip_always (bool): Do not apply any tags. This arg exists so that this function can be disabled
without needing to pass a callable (which is currently difficult to do with main.py)
"""
if skip_always:
return atoms
if np.all(atoms.get_tags() == 0) or not skip_if_nonzero:
atoms.set_tags(np.ones(len(atoms)))
return atoms
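# Illustrative sketch (not part of the original module): untagged structures
# receive all-ones tags, while pre-tagged structures are left alone by default.
def _example_apply_one_tags() -> None:
    atoms = ase.Atoms("H2", positions=[[0, 0, 0], [0, 0, 0.74]])
    tagged = apply_one_tags(atoms)
    assert (tagged.get_tags() == 1).all()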
class UnsupportedDatasetError(ValueError):
pass
class BaseDataset(ABC):
"""Base Dataset class for all ASE datasets."""
def __init__(self, config: dict):
"""Initialize
Args:
config (dict): dataset configuration
"""
self.config = config
self.paths = []
if "src" in self.config:
if isinstance(config["src"], str):
self.paths = [Path(self.config["src"])]
else:
self.paths = tuple(Path(path) for path in sorted(config["src"]))
self.lin_ref = None
if self.config.get("lin_ref", False):
lin_ref = torch.tensor(
np.load(self.config["lin_ref"], allow_pickle=True)["coeff"]
)
self.lin_ref = torch.nn.Parameter(lin_ref, requires_grad=False)
def __len__(self) -> int:
return self.num_samples
def metadata_hasattr(self, attr) -> bool:
return attr in self._metadata
@cached_property
def indices(self):
return np.arange(self.num_samples, dtype=int)
@cached_property
def _metadata(self) -> dict[str, np.ndarray]:
# logic to read metadata file here
metadata_npzs = []
if self.config.get("metadata_path", None) is not None:
metadata_npzs.append(
np.load(self.config["metadata_path"], allow_pickle=True)
)
else:
for path in self.paths:
if path.is_file():
metadata_file = path.parent / "metadata.npz"
else:
metadata_file = path / "metadata.npz"
if metadata_file.is_file():
metadata_npzs.append(np.load(metadata_file, allow_pickle=True))
if len(metadata_npzs) == 0:
logging.warning(
f"Could not find dataset metadata.npz files in '{self.paths}'"
)
return {}
metadata = {
field: np.concatenate([metadata[field] for metadata in metadata_npzs])
for field in metadata_npzs[0]
}
assert np.issubdtype(
metadata["natoms"].dtype, np.integer
), f"Metadata natoms must be an integer type! not {metadata['natoms'].dtype}"
assert metadata["natoms"].shape[0] == len(
self
), "Loaded metadata and dataset size mismatch."
return metadata
def get_metadata(self, attr, idx):
if attr in self._metadata:
metadata_attr = self._metadata[attr]
if isinstance(idx, list):
return [metadata_attr[_idx] for _idx in idx]
return metadata_attr[idx]
return None
class Subset(BaseDataset):
"""A subset that also takes metadata if given."""
def __init__(
self,
dataset: BaseDataset,
indices: list[int],
metadata: dict[str, np.ndarray],
) -> None:
super().__init__(dataset.config)
self.dataset = dataset
self.metadata = metadata
self.indices = indices
self.num_samples = len(indices)
self.config = dataset.config
@cached_property
def _metadata(self) -> dict[str, np.ndarray]:
return self.dataset._metadata # pylint: disable=protected-access
def get_metadata(self, attr, idx):
if isinstance(idx, list):
            return self.dataset.get_metadata(attr, [self.indices[i] for i in idx])
return self.dataset.get_metadata(attr, self.indices[idx])
class LMDBDatabase(ase.db.core.Database):
"""
This module is modified from the ASE db json backend
and is thus licensed under the corresponding LGPL2.1 license.
The ASE notice for the LGPL2.1 license is available here:
https://gitlab.com/ase/ase/-/blob/master/LICENSE
"""
def __init__( # pylint: disable=keyword-arg-before-vararg
self,
filename: str | Path | None = None,
create_indices: bool = True,
use_lock_file: bool = False,
serial: bool = False,
        readonly: bool = False,
*args,
**kwargs,
) -> None:
"""
For the most part, this is identical to the standard ase db initiation
arguments, except that we add a readonly flag.
"""
super().__init__(
Path(filename),
create_indices,
use_lock_file,
serial,
*args,
**kwargs,
)
# Add a readonly mode for when we're only training
# to make sure there's no parallel locks
self.readonly = readonly
if self.readonly:
# Open a new env
self.env = lmdb.open(
str(self.filename),
subdir=False,
meminit=False,
map_async=True,
readonly=True,
lock=False,
)
# Open a transaction and keep it open for fast read/writes!
self.txn = self.env.begin(write=False)
else:
# Open a new env with write access
self.env = lmdb.open(
str(self.filename),
map_size=1099511627776 * 2,
subdir=False,
meminit=False,
map_async=True,
)
self.txn = self.env.begin(write=True)
# Load all ids based on keys in the DB.
self.ids = []
self.deleted_ids = []
self._load_ids()
def __enter__(self) -> "LMDBDatabase":
return self
def __exit__(self, exc_type, exc_value, tb) -> None:
self.close()
def close(self) -> None:
# Close the lmdb environment and transaction
self.txn.commit()
self.env.close()
def _write(
self,
atoms: ase.Atoms | ase.db.row.AtomsRow,
key_value_pairs: dict,
data: dict | None,
id: int | None = None, # pylint: disable=redefined-builtin
) -> None:
# Call parent method with the original parameter name
super()._write(atoms, key_value_pairs, data)
mtime = ase.db.core.now()
if isinstance(atoms, ase.db.row.AtomsRow):
row = atoms
else:
row = ase.db.row.AtomsRow(atoms)
row.ctime = mtime
row.user = os.getenv("USER")
dct = {}
for key in row.__dict__:
# Use getattr to avoid accessing protected member directly
if key[0] == "_" or key == "id" or key in getattr(row, "_keys", []):
continue
dct[key] = row[key]
dct["mtime"] = mtime
if key_value_pairs:
dct["key_value_pairs"] = key_value_pairs
if data:
dct["data"] = data
constraints = row.get("constraints")
if constraints:
dct["constraints"] = [constraint.todict() for constraint in constraints]
# json doesn't like Cell objects, so make it an array
dct["cell"] = np.asarray(dct["cell"])
if id is None:
id = self._nextid
nextid = id + 1
        else:
            data = self.txn.get(f"{id}".encode("ascii"))
            assert data is not None
            # Overwriting an existing id leaves the stored counter unchanged.
            nextid = self._nextid
# Add the new entry
self.txn.put(
f"{id}".encode("ascii"),
zlib.compress(orjson.dumps(dct, option=orjson.OPT_SERIALIZE_NUMPY)),
)
        # only append if id is not in ids
if id not in self.ids:
self.ids.append(id)
self.txn.put(
"nextid".encode("ascii"),
zlib.compress(orjson.dumps(nextid, option=orjson.OPT_SERIALIZE_NUMPY)),
)
# check if id is in removed ids and remove accordingly
if id in self.deleted_ids:
self.deleted_ids.remove(id)
self._write_deleted_ids()
return id
def _update(
self,
idx: int,
key_value_pairs: dict | None = None,
data: dict | None = None,
):
# hack this to play nicely with ASE code
row = self._get_row(idx, include_data=True)
if data is not None or key_value_pairs is not None:
self._write(
atoms=row, key_value_pairs=key_value_pairs, data=data, id=idx
            )
def _write_deleted_ids(self):
self.txn.put(
"deleted_ids".encode("ascii"),
zlib.compress(
orjson.dumps(self.deleted_ids, option=orjson.OPT_SERIALIZE_NUMPY)
),
)
def delete(self, ids: list[int]) -> None:
for idx in ids:
self.txn.delete(f"{idx}".encode("ascii"))
self.ids.remove(idx)
self.deleted_ids += ids
self._write_deleted_ids()
def _get_row(self, idx: int, include_data: bool = True):
if idx is None:
assert len(self.ids) == 1
idx = self.ids[0]
data = self.txn.get(f"{idx}".encode("ascii"))
if data is not None:
dct = orjson.loads(zlib.decompress(data))
else:
raise KeyError(f"Id {idx} missing from the database!")
if not include_data:
dct.pop("data", None)
dct["id"] = idx
return ase.db.row.AtomsRow(dct)
    def _get_row_by_index(self, index: int, include_data: bool = True):
        """Auxiliary function to get the ith entry, rather than a specific id"""
        row_id = self.ids[index]
        data = self.txn.get(f"{row_id}".encode("ascii"))
        if data is not None:
            dct = orjson.loads(zlib.decompress(data))
        else:
            raise KeyError(f"Id {row_id} missing from the database!")
        if not include_data:
            dct.pop("data", None)
        dct["id"] = row_id
        return ase.db.row.AtomsRow(dct)
def _select(
self,
keys,
cmps: list[tuple[str, str, str]],
explain: bool = False,
        verbosity: int = 0,  # unused here, but ase.db.core.Database.select passes it by name
limit: int | None = None,
offset: int = 0,
sort: str | None = None,
include_data: bool = True,
_columns: str = "all", # Unused parameter marked with underscore
):
if explain:
yield {"explain": (0, 0, 0, "scan table")}
return
if sort is not None:
if sort[0] == "-":
reverse = True
sort = sort[1:]
else:
reverse = False
rows = []
missing = []
for row in self._select(keys, cmps):
key = row.get(sort)
if key is None:
missing.append((0, row))
else:
rows.append((key, row))
rows.sort(reverse=reverse, key=lambda x: x[0])
rows += missing
if limit:
rows = rows[offset : offset + limit]
for _, row in rows:
yield row
return
if not limit:
limit = -offset - 1
cmps = [(key, ase.db.core.ops[op], val) for key, op, val in cmps]
n = 0
for idx in self.ids:
if n - offset == limit:
return
row = self._get_row(idx, include_data=include_data)
for key in keys:
if key not in row:
break
else:
for key, op, val in cmps:
if isinstance(key, int):
value = np.equal(row.numbers, key).sum()
else:
value = row.get(key)
if key == "pbc":
assert op in [ase.db.core.ops["="], ase.db.core.ops["!="]]
value = "".join("FT"[x] for x in value)
if value is None or not op(value, val):
break
else:
if n >= offset:
yield row
n += 1
@property
def metadata(self):
"""Override abstract metadata method from Database class."""
return self.db_metadata
@property
def db_metadata(self):
"""Load the metadata from the DB if present"""
if self._metadata is None:
metadata = self.txn.get("metadata".encode("ascii"))
if metadata is None:
self._metadata = {}
else:
self._metadata = orjson.loads(zlib.decompress(metadata))
return self._metadata.copy()
@db_metadata.setter
def db_metadata(self, dct):
self._metadata = dct
# Put the updated metadata dictionary
self.txn.put(
"metadata".encode("ascii"),
zlib.compress(orjson.dumps(dct, option=orjson.OPT_SERIALIZE_NUMPY)),
)
@property
def _nextid(self):
"""Get the id of the next row to be written"""
# Get the nextid
nextid_data = self.txn.get("nextid".encode("ascii"))
if nextid_data:
return orjson.loads(zlib.decompress(nextid_data))
        return 1  # empty database: ids start at 1
def count(self, selection=None, **kwargs) -> int:
"""Count rows.
See the select() method for the selection syntax. Use db.count() or
len(db) to count all rows.
"""
if selection is not None:
n = 0
for _row in self.select(selection, **kwargs):
n += 1
return n
return len(self.ids)
def _load_ids(self) -> None:
"""Load ids from the DB
        ASE db ids are mostly the contiguous integers 1..N, but entries can be
        missing if ids have been deleted. To save space, and assuming that most
        OCP datasets will see few deletions, we store only the deleted ids.
        """
# Load the deleted ids
deleted_ids_data = self.txn.get("deleted_ids".encode("ascii"))
if deleted_ids_data is not None:
self.deleted_ids = orjson.loads(zlib.decompress(deleted_ids_data))
# Reconstruct the full id list
self.ids = [i for i in range(1, self._nextid) if i not in set(self.deleted_ids)]
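# Illustrative usage sketch (not part of the original module): write one
# structure to an LMDB-backed database and read it back. The filename is a
# placeholder; the lmdb and orjson packages must be installed.
def _example_lmdb_roundtrip(filename: str = "demo.aselmdb") -> None:
    water = ase.Atoms("H2O", positions=[[0, 0, 0], [0, 0, 1], [0, 1, 0]])
    with LMDBDatabase(filename=filename) as db:
        db.write(water)
    with LMDBDatabase(filename=filename, readonly=True) as db:
        assert db.count() == 1
        row = db._get_row(db.ids[0])
        assert len(row.toatoms()) == 3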
# Placeholder for AtomsToGraphs class
# This is a minimal implementation without the full functionality
class AtomsToGraphs:
"""Enhanced AtomsToGraphs implementation with proper property handling."""
def __init__(
self,
r_edges=False,
r_pbc=True,
r_energy=False,
r_forces=False,
r_stress=False,
r_data_keys=None,
**kwargs,
):
self.r_edges = r_edges
self.r_pbc = r_pbc
self.r_energy = r_energy
self.r_forces = r_forces
self.r_stress = r_stress
self.r_data_keys = r_data_keys or {}
self.kwargs = kwargs
def convert(self, atoms, sid=None):
"""
Convert ASE atoms to graph data format with proper property handling.
"""
from mace.tools.torch_geometric.data import Data
# Create a minimal data object with required properties
data = Data()
# Set positions
data.pos = torch.tensor(atoms.get_positions(), dtype=torch.float)
# Set atomic numbers
data.atomic_numbers = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.long)
# Set cell if available
if atoms.cell is not None:
            data.cell = torch.tensor(np.asarray(atoms.get_cell()), dtype=torch.float)
# Set PBC if requested
if self.r_pbc:
data.pbc = torch.tensor(atoms.get_pbc(), dtype=torch.bool)
# Set energy if requested
if self.r_energy:
energy = self._get_property(atoms, "energy")
if energy is not None:
data.energy = torch.tensor(energy, dtype=torch.float)
# Set forces if requested
if self.r_forces:
forces = self._get_property(atoms, "forces")
if forces is not None:
data.forces = torch.tensor(forces, dtype=torch.float)
# Set stress if requested
if self.r_stress:
stress = self._get_property(atoms, "stress")
if stress is not None:
data.stress = torch.tensor(stress, dtype=torch.float)
# Set sid if provided
if sid is not None:
data.sid = sid
return data
def _get_property(self, atoms, prop_name):
"""Get property from atoms, checking custom names first then standard methods."""
# Check if we have a custom name for this property
custom_name = self.r_data_keys.get(prop_name)
# Try custom name in info dict
if custom_name and custom_name in atoms.info:
return atoms.info[custom_name]
# Try custom name in arrays dict
if custom_name and custom_name in atoms.arrays:
return atoms.arrays[custom_name]
# Try standard name in info dict
if prop_name in atoms.info:
return atoms.info[prop_name]
# Try standard name in arrays dict
if prop_name in atoms.arrays:
return atoms.arrays[prop_name]
# Try standard ASE methods
method_map = {
"energy": "get_potential_energy",
"forces": "get_forces",
"stress": "get_stress",
}
if prop_name in method_map and hasattr(atoms, method_map[prop_name]):
try:
method = getattr(atoms, method_map[prop_name])
return method()
except (
AttributeError,
RuntimeError,
            ) as exc:
logging.debug(f"Error getting property {prop_name}: {exc}")
return None
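# Illustrative usage sketch (not part of the original module): converting an
# ASE Atoms object with a custom energy key in atoms.info. The Data class
# imported inside convert() requires the MACE package to be installed.
def _example_atoms_to_graphs() -> None:
    atoms = ase.Atoms("H2", positions=[[0, 0, 0], [0, 0, 0.74]])
    atoms.info["REF_energy"] = -1.17
    a2g = AtomsToGraphs(r_energy=True, r_data_keys={"energy": "REF_energy"})
    data = a2g.convert(atoms, sid=0)
    assert tuple(data.pos.shape) == (2, 3)
    assert abs(float(data.energy) + 1.17) < 1e-6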
# Placeholder for DataTransforms class
class DataTransforms:
"""Minimal implementation of DataTransforms to satisfy dependencies."""
def __init__(self, transforms_config=None):
self.transforms_config = transforms_config or {}
def __call__(self, data):
"""Apply transforms to data"""
# No transforms applied in this minimal implementation
return data
class AseAtomsDataset(BaseDataset, ABC):
"""
This is an abstract Dataset that includes helpful utilities for turning
ASE atoms objects into OCP-usable data objects. This should not be instantiated directly
as get_atoms_object and load_dataset_get_ids are not implemented in this base class.
Derived classes must add at least two things:
self.get_atoms_object(id): a function that takes an identifier and returns a corresponding atoms object
self.load_dataset_get_ids(config: dict): This function is responsible for any initialization/loads
of the dataset and importantly must return a list of all possible identifiers that can be passed into
self.get_atoms_object(id)
Identifiers need not be any particular type.
"""
def __init__(
self,
config: dict,
atoms_transform: Callable[[ase.Atoms, Any], ase.Atoms] = apply_one_tags,
) -> None:
super().__init__(config)
a2g_args = config.get("a2g_args", {}) or {}
# set default to False if not set by user, assuming otf_graph will be used
if "r_edges" not in a2g_args:
a2g_args["r_edges"] = False
# Make sure we always include PBC info in the resulting atoms objects
a2g_args["r_pbc"] = True
self.a2g = AtomsToGraphs(**a2g_args)
self.key_mapping = self.config.get("key_mapping", None)
self.transforms = DataTransforms(self.config.get("transforms", {}))
self.atoms_transform = atoms_transform
if self.config.get("keep_in_memory", False):
self.__getitem__ = cache(self.__getitem__)
self.ids = self._load_dataset_get_ids(config)
self.num_samples = len(self.ids)
if len(self.ids) == 0:
raise ValueError(
rf"No valid ase data found! \n"
f"Double check that the src path and/or glob search pattern gives ASE compatible data: {config['src']}"
)
def __getitem__(self, idx): # pylint: disable=method-hidden
# Handle slicing
if isinstance(idx, slice):
return [self[i] for i in range(*idx.indices(len(self)))]
# Get atoms object via derived class method
atoms = self.get_atoms(self.ids[idx])
# Transform atoms object
if self.atoms_transform is not None:
atoms = self.atoms_transform(
atoms, **self.config.get("atoms_transform_args", {})
)
sid = atoms.info.get("sid", self.ids[idx])
fid = atoms.info.get("fid", torch.tensor([0]))
# Convert to data object
data_object = self.a2g.convert(atoms, sid)
data_object.fid = fid
data_object.natoms = len(atoms)
# apply linear reference
if self.a2g.r_energy is True and self.lin_ref is not None:
data_object.energy -= sum(self.lin_ref[data_object.atomic_numbers.long()])
# Transform data object
data_object = self.transforms(data_object)
if self.key_mapping is not None:
data_object = rename_data_object_keys(data_object, self.key_mapping)
if self.config.get("include_relaxed_energy", False):
data_object.energy_relaxed = self.get_relaxed_energy(self.ids[idx])
return data_object
@abstractmethod
def get_atoms(self, idx: str | int) -> ase.Atoms:
# This function should return an ASE atoms object.
raise NotImplementedError(
"Returns an ASE atoms object. Derived classes should implement this function."
)
@abstractmethod
def _load_dataset_get_ids(self, config):
# This function should return a list of ids that can be used to index into the database
raise NotImplementedError(
"Every ASE dataset needs to declare a function to load the dataset and return a list of ids."
)
def get_relaxed_energy(self, identifier):
raise NotImplementedError(
"Reading relaxed energy from trajectory or file is not implemented with this dataset. "
"If relaxed energies are saved with the atoms info dictionary, they can be used by passing the keys in "
"the r_data_keys argument under a2g_args."
)
def get_metadata(self, attr, idx):
# try the parent method
metadata = super().get_metadata(attr, idx)
if metadata is not None:
return metadata
# try to resolve it here
if attr != "natoms":
return None
if isinstance(idx, (list, np.ndarray)):
return np.array([self.get_metadata(attr, i) for i in idx])
return len(self.get_atoms(idx))
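# Minimal concrete subclass sketch (hypothetical, for illustration only): the
# two abstract methods are the whole contract. _load_dataset_get_ids performs
# any loading and returns every valid identifier; get_atoms resolves one
# identifier to an ase.Atoms object. Whether BaseDataset.__init__ requires
# further config keys is left to the real base class.
class _InMemoryAseDataset(AseAtomsDataset):
    """Serves a fixed list of ase.Atoms objects, e.g. for tests."""

    def _load_dataset_get_ids(self, config):
        self._atoms_list = list(config["atoms_list"])
        return list(range(len(self._atoms_list)))

    def get_atoms(self, idx: int) -> ase.Atoms:
        return self._atoms_list[idx]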
class AseDBDataset(AseAtomsDataset):
"""
This Dataset connects to an ASE Database, allowing the storage of atoms objects
with a variety of backends including JSON, SQLite, and database server options.
"""
def _load_dataset_get_ids(self, config: dict) -> list[int]:
if isinstance(config["src"], list):
filepaths = []
for path in sorted(config["src"]):
if os.path.isdir(path):
filepaths.extend(sorted(glob(f"{path}/*")))
elif os.path.isfile(path):
filepaths.append(path)
else:
raise RuntimeError(f"Error reading dataset in {path}!")
elif os.path.isfile(config["src"]):
filepaths = [config["src"]]
elif os.path.isdir(config["src"]):
filepaths = sorted(glob(f'{config["src"]}/*'))
else:
filepaths = sorted(glob(config["src"]))
self.dbs = []
for path in filepaths:
try:
self.dbs.append(self.connect_db(path, config.get("connect_args", {})))
except ValueError:
logging.debug(
f"Tried to connect to {path} but it's not an ASE database!"
)
self.select_args = config.get("select_args", {})
if self.select_args is None:
self.select_args = {}
# Get all unique IDs from the databases
self.db_ids = []
for db in self.dbs:
if hasattr(db, "ids") and self.select_args == {}:
self.db_ids.append(db.ids)
else:
# this is the slow alternative
self.db_ids.append([row.id for row in db.select(**self.select_args)])
idlens = [len(ids) for ids in self.db_ids]
self._idlen_cumulative = np.cumsum(idlens).tolist()
return list(range(sum(idlens)))
def get_atoms(self, idx: int) -> ase.Atoms:
"""Get atoms object corresponding to datapoint idx.
Args:
idx (int): index in dataset
Returns:
atoms: ASE atoms corresponding to datapoint idx
"""
# Figure out which db this should be indexed from
db_idx = bisect.bisect(self._idlen_cumulative, idx)
# Extract index of element within that db
el_idx = idx
if db_idx != 0:
el_idx = idx - self._idlen_cumulative[db_idx - 1]
assert el_idx >= 0
# Use a wrapper method to avoid protected access warning
atoms_row = self.get_row_from_db(db_idx, el_idx)
# Convert to atoms object
atoms = atoms_row.toatoms()
# Put data back into atoms info
if isinstance(atoms_row.data, dict):
atoms.info.update(atoms_row.data)
# Add key-value pairs directly to atoms.info
if hasattr(atoms_row, "key_value_pairs") and atoms_row.key_value_pairs:
atoms.info.update(atoms_row.key_value_pairs)
# Create a SinglePointCalculator to attach energy, forces and stress to atoms
calc_kwargs = {}
# Check for energy, forces, stress in atoms_row and store in info & calc_kwargs
for prop in ["energy", "forces", "stress", "free_energy"]:
if hasattr(atoms_row, prop) and getattr(atoms_row, prop) is not None:
value = getattr(atoms_row, prop)
calc_kwargs[prop] = value
atoms.info[prop] = value
# If we have custom data mappings, copy the standard properties to the custom names
a2g_args = self.config.get("a2g_args", {}) or {}
r_data_keys = a2g_args.get("r_data_keys", {})
if r_data_keys:
# Map from standard names to custom names (in reverse of how they'll be used)
for custom_key, standard_key in r_data_keys.items():
if standard_key in atoms.info:
atoms.info[custom_key] = atoms.info[standard_key]
elif standard_key in atoms.arrays:
atoms.arrays[custom_key] = atoms.arrays[standard_key]
# Create calculator if we have any properties
if calc_kwargs:
from ase.calculators.singlepoint import SinglePointCalculator
calc = SinglePointCalculator(atoms, **calc_kwargs)
atoms.calc = calc
return atoms
def get_row_from_db(self, db_idx, el_idx):
"""Get a row from the database at the given indices."""
db = self.dbs[db_idx]
row_id = self.db_ids[db_idx][el_idx]
if isinstance(db, LMDBDatabase):
return db._get_row(row_id) # pylint: disable=protected-access
return db.get(row_id)
@staticmethod
def connect_db(
address: str | Path, connect_args: dict | None = None
) -> ase.db.core.Database:
if connect_args is None:
connect_args = {}
db_type = connect_args.get("type", "extract_from_name")
if db_type in ("lmdb", "aselmdb") or (
db_type == "extract_from_name"
and str(address).rsplit(".", maxsplit=1)[-1] in ("lmdb", "aselmdb")
):
return LMDBDatabase(address, readonly=True, **connect_args)
return ase.db.connect(address, **connect_args)
def __del__(self):
for db in self.dbs:
if hasattr(db, "close"):
db.close()
    def sample_property_metadata(self) -> dict:
        """Sample property metadata from the first connected database."""
logging.warning(
"You specified a folder of ASE dbs, so it's impossible to know which metadata to use. Using the first!"
)
if self.dbs[0].metadata == {}:
return {}
        return dict(self.dbs[0].metadata.items())
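# Usage sketch (illustrative; the path is a placeholder): pointing AseDBDataset
# at a single ASE SQLite database. "src" may equally be a directory of
# databases or a glob pattern; a2g_args are forwarded to AtomsToGraphs and
# select_args to the database's select().
def _demo_ase_db_dataset():
    dataset = AseDBDataset(
        config={
            "src": "/path/to/structures.db",
            "a2g_args": {"r_energy": True, "r_forces": True},
            "select_args": {},
        }
    )
    return dataset[0]  # converted data object for the first row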
import torch
from mace.tools.utils import AtomicNumberTable
def load_foundations_elements(
model: torch.nn.Module,
model_foundations: torch.nn.Module,
table: AtomicNumberTable,
load_readout=False,
use_shift=True,
use_scale=True,
max_L=2,
):
"""
Load the foundations of a model into a model for fine-tuning.
"""
assert model_foundations.r_max == model.r_max
z_table = AtomicNumberTable([int(z) for z in model_foundations.atomic_numbers])
model_heads = model.heads
new_z_table = table
num_species_foundations = len(z_table.zs)
num_channels_foundation = (
model_foundations.node_embedding.linear.weight.shape[0]
// num_species_foundations
)
indices_weights = [z_table.z_to_index(z) for z in new_z_table.zs]
num_radial = model.radial_embedding.out_dim
num_species = len(indices_weights)
max_ell = model.spherical_harmonics._lmax # pylint: disable=protected-access
model.node_embedding.linear.weight = torch.nn.Parameter(
model_foundations.node_embedding.linear.weight.view(
num_species_foundations, -1
)[indices_weights, :]
.flatten()
.clone()
/ (num_species_foundations / num_species) ** 0.5
)
if model.radial_embedding.bessel_fn.__class__.__name__ == "BesselBasis":
model.radial_embedding.bessel_fn.bessel_weights = torch.nn.Parameter(
model_foundations.radial_embedding.bessel_fn.bessel_weights.clone()
)
for i in range(int(model.num_interactions)):
model.interactions[i].linear_up.weight = torch.nn.Parameter(
model_foundations.interactions[i].linear_up.weight.clone()
)
model.interactions[i].avg_num_neighbors = model_foundations.interactions[
i
].avg_num_neighbors
        for j in range(4):  # assuming 4 layers in conv_tp_weights
layer_name = f"layer{j}"
if j == 0:
getattr(model.interactions[i].conv_tp_weights, layer_name).weight = (
torch.nn.Parameter(
getattr(
model_foundations.interactions[i].conv_tp_weights,
layer_name,
)
.weight[:num_radial, :]
.clone()
)
)
else:
getattr(model.interactions[i].conv_tp_weights, layer_name).weight = (
torch.nn.Parameter(
getattr(
model_foundations.interactions[i].conv_tp_weights,
layer_name,
).weight.clone()
)
)
model.interactions[i].linear.weight = torch.nn.Parameter(
model_foundations.interactions[i].linear.weight.clone()
)
if model.interactions[i].__class__.__name__ in [
"RealAgnosticResidualInteractionBlock",
"RealAgnosticDensityResidualInteractionBlock",
]:
model.interactions[i].skip_tp.weight = torch.nn.Parameter(
model_foundations.interactions[i]
.skip_tp.weight.reshape(
num_channels_foundation,
num_species_foundations,
num_channels_foundation,
)[:, indices_weights, :]
.flatten()
.clone()
/ (num_species_foundations / num_species) ** 0.5
)
else:
model.interactions[i].skip_tp.weight = torch.nn.Parameter(
model_foundations.interactions[i]
.skip_tp.weight.reshape(
num_channels_foundation,
(max_ell + 1),
num_species_foundations,
num_channels_foundation,
)[:, :, indices_weights, :]
.flatten()
.clone()
/ (num_species_foundations / num_species) ** 0.5
)
if model.interactions[i].__class__.__name__ in [
"RealAgnosticDensityInteractionBlock",
"RealAgnosticDensityResidualInteractionBlock",
]:
# Assuming only 1 layer in density_fn
getattr(model.interactions[i].density_fn, "layer0").weight = (
torch.nn.Parameter(
getattr(
model_foundations.interactions[i].density_fn,
"layer0",
).weight.clone()
)
)
# Transferring products
for i in range(2): # Assuming 2 products modules
max_range = max_L + 1 if i == 0 else 1
        for j in range(max_range):  # max_L + 1 contractions for the first product block, 1 for the second
model.products[i].symmetric_contractions.contractions[j].weights_max = (
torch.nn.Parameter(
model_foundations.products[i]
.symmetric_contractions.contractions[j]
.weights_max[indices_weights, :, :]
.clone()
)
)
for k in range(2): # Assuming 2 weights in each contraction
model.products[i].symmetric_contractions.contractions[j].weights[k] = (
torch.nn.Parameter(
model_foundations.products[i]
.symmetric_contractions.contractions[j]
.weights[k][indices_weights, :, :]
.clone()
)
)
model.products[i].linear.weight = torch.nn.Parameter(
model_foundations.products[i].linear.weight.clone()
)
if load_readout:
# Transferring readouts
        model_readouts_zero_linear_weight = (
model_foundations.readouts[0]
.linear.weight.view(num_channels_foundation, -1)
.repeat(1, len(model_heads))
.flatten()
.clone()
)
model.readouts[0].linear.weight = torch.nn.Parameter(
model_readouts_zero_linear_weight
)
shape_input_1 = (
model_foundations.readouts[1].linear_1.__dict__["irreps_out"].num_irreps
)
shape_output_1 = model.readouts[1].linear_1.__dict__["irreps_out"].num_irreps
        model_readouts_one_linear_1_weight = (
model_foundations.readouts[1]
.linear_1.weight.view(num_channels_foundation, -1)
.repeat(1, len(model_heads))
.flatten()
.clone()
)
model.readouts[1].linear_1.weight = torch.nn.Parameter(
model_readouts_one_linear_1_weight
)
        model_readouts_one_linear_2_weight = (
            model_foundations.readouts[1]
            .linear_2.weight.view(shape_input_1, -1)
            .repeat(len(model_heads), len(model_heads))
            .flatten()
            .clone()
            / (shape_input_1 / shape_output_1) ** 0.5
        )
model.readouts[1].linear_2.weight = torch.nn.Parameter(
model_readouts_one_linear_2_weight
)
if model_foundations.scale_shift is not None:
if use_scale:
model.scale_shift.scale = model_foundations.scale_shift.scale.repeat(
len(model_heads)
).clone()
if use_shift:
model.scale_shift.shift = model_foundations.scale_shift.shift.repeat(
len(model_heads)
).clone()
return model
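# Usage sketch (illustrative): transplanting foundation weights into a freshly
# built fine-tuning model restricted to a subset of elements. The element list
# (H, C, O) is a placeholder; max_L must match the target model's hidden irreps.
def _demo_load_foundations_elements(model, model_foundations):
    table = AtomicNumberTable([1, 6, 8])
    return load_foundations_elements(
        model,
        model_foundations,
        table,
        load_readout=True,
        max_L=2,
    )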
def load_foundations(
    model,
    model_foundations,
):
    state_dict = model.state_dict()
    for name, param in model_foundations.named_parameters():
        if name in state_dict and "readouts" not in name:
            state_dict[name].copy_(param)
    return model
import ast
import logging
import numpy as np
from e3nn import o3
from mace import modules
from mace.tools.finetuning_utils import load_foundations_elements
from mace.tools.scripts_utils import extract_config_mace_model
from mace.tools.utils import AtomicNumberTable
def configure_model(
args,
train_loader,
atomic_energies,
model_foundation=None,
heads=None,
z_table=None,
head_configs=None,
):
# Selecting outputs
compute_virials = args.loss == "virials"
compute_stress = args.loss in ("stress", "huber", "universal")
if compute_virials:
args.compute_virials = True
args.error_table = "PerAtomRMSEstressvirials"
elif compute_stress:
args.compute_stress = True
args.error_table = "PerAtomRMSEstressvirials"
output_args = {
"energy": args.compute_energy,
"forces": args.compute_forces,
"virials": compute_virials,
"stress": compute_stress,
"dipoles": args.compute_dipole,
}
logging.info(
f"During training the following quantities will be reported: {', '.join([f'{report}' for report, value in output_args.items() if value])}"
)
logging.info("===========MODEL DETAILS===========")
if args.scaling == "no_scaling":
args.std = 1.0
if head_configs is not None:
for head_config in head_configs:
head_config.std = 1.0
logging.info("No scaling selected")
if (
head_configs is not None
and args.std is not None
and not isinstance(args.std, list)
):
atomic_inter_scale = []
for head_config in head_configs:
if hasattr(head_config, "std") and head_config.std is not None:
atomic_inter_scale.append(head_config.std)
elif args.std is not None:
atomic_inter_scale.append(
args.std if isinstance(args.std, float) else 1.0
)
args.std = atomic_inter_scale
elif (args.mean is None or args.std is None) and args.model != "AtomicDipolesMACE":
args.mean, args.std = modules.scaling_classes[args.scaling](
train_loader, atomic_energies
)
# Build model
if model_foundation is not None and args.model in ["MACE", "ScaleShiftMACE"]:
logging.info("Loading FOUNDATION model")
model_config_foundation = extract_config_mace_model(model_foundation)
model_config_foundation["atomic_energies"] = atomic_energies
if args.foundation_model_elements:
foundation_z_table = AtomicNumberTable(
[int(z) for z in model_foundation.atomic_numbers]
)
model_config_foundation["atomic_numbers"] = foundation_z_table.zs
model_config_foundation["num_elements"] = len(foundation_z_table)
z_table = foundation_z_table
logging.info(
f"Using all elements from foundation model: {foundation_z_table.zs}"
)
else:
model_config_foundation["atomic_numbers"] = z_table.zs
model_config_foundation["num_elements"] = len(z_table)
logging.info(f"Using filtered elements: {z_table.zs}")
args.max_L = model_config_foundation["hidden_irreps"].lmax
if args.model == "MACE" and model_foundation.__class__.__name__ == "MACE":
model_config_foundation["atomic_inter_shift"] = [0.0] * len(heads)
else:
model_config_foundation["atomic_inter_shift"] = (
_determine_atomic_inter_shift(args.mean, heads)
)
model_config_foundation["atomic_inter_scale"] = [1.0] * len(heads)
args.avg_num_neighbors = model_config_foundation["avg_num_neighbors"]
args.model = "FoundationMACE"
model_config_foundation["heads"] = heads
model_config = model_config_foundation
logging.info("Model configuration extracted from foundation model")
logging.info("Using universal loss function for fine-tuning")
logging.info(
f"Message passing with hidden irreps {model_config_foundation['hidden_irreps']})"
)
logging.info(
f"{model_config_foundation['num_interactions']} layers, each with correlation order: {model_config_foundation['correlation']} (body order: {model_config_foundation['correlation']+1}) and spherical harmonics up to: l={model_config_foundation['max_ell']}"
)
logging.info(
f"Radial cutoff: {model_config_foundation['r_max']} A (total receptive field for each atom: {model_config_foundation['r_max'] * model_config_foundation['num_interactions']} A)"
)
logging.info(
f"Distance transform for radial basis functions: {model_config_foundation['distance_transform']}"
)
else:
logging.info("Building model")
logging.info(
f"Message passing with {args.num_channels} channels and max_L={args.max_L} ({args.hidden_irreps})"
)
logging.info(
f"{args.num_interactions} layers, each with correlation order: {args.correlation} (body order: {args.correlation+1}) and spherical harmonics up to: l={args.max_ell}"
)
logging.info(
f"{args.num_radial_basis} radial and {args.num_cutoff_basis} basis functions"
)
logging.info(
f"Radial cutoff: {args.r_max} A (total receptive field for each atom: {args.r_max * args.num_interactions} A)"
)
logging.info(
f"Distance transform for radial basis functions: {args.distance_transform}"
)
assert (
len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1
), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L"
logging.info(f"Hidden irreps: {args.hidden_irreps}")
model_config = dict(
r_max=args.r_max,
num_bessel=args.num_radial_basis,
num_polynomial_cutoff=args.num_cutoff_basis,
max_ell=args.max_ell,
interaction_cls=modules.interaction_classes[args.interaction],
num_interactions=args.num_interactions,
num_elements=len(z_table),
hidden_irreps=o3.Irreps(args.hidden_irreps),
atomic_energies=atomic_energies,
avg_num_neighbors=args.avg_num_neighbors,
atomic_numbers=z_table.zs,
)
model_config_foundation = None
model = _build_model(args, model_config, model_config_foundation, heads)
if model_foundation is not None:
model = load_foundations_elements(
model,
model_foundation,
z_table,
load_readout=args.foundation_filter_elements,
max_L=args.max_L,
)
return model, output_args
def _determine_atomic_inter_shift(mean, heads):
if isinstance(mean, np.ndarray):
if mean.size == 1:
return mean.item()
if mean.size == len(heads):
return mean.tolist()
logging.info("Mean not in correct format, using default value of 0.0")
return [0.0] * len(heads)
if isinstance(mean, list) and len(mean) == len(heads):
return mean
if isinstance(mean, float):
return [mean] * len(heads)
logging.info("Mean not in correct format, using default value of 0.0")
return [0.0] * len(heads)
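# Behaviour sketch (illustrative): a float mean is broadcast over the heads, a
# matching-length list is kept as-is, and anything unrecognised falls back to
# zeros (with an info log).
def _demo_determine_atomic_inter_shift():
    heads = ["head_a", "head_b"]
    assert _determine_atomic_inter_shift(0.5, heads) == [0.5, 0.5]
    assert _determine_atomic_inter_shift([0.1, 0.2], heads) == [0.1, 0.2]
    assert _determine_atomic_inter_shift("invalid", heads) == [0.0, 0.0]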
def _build_model(
args, model_config, model_config_foundation, heads
): # pylint: disable=too-many-return-statements
if args.model == "MACE":
if args.interaction_first not in [
"RealAgnosticInteractionBlock",
"RealAgnosticDensityInteractionBlock",
]:
args.interaction_first = "RealAgnosticInteractionBlock"
return modules.ScaleShiftMACE(
**model_config,
pair_repulsion=args.pair_repulsion,
distance_transform=args.distance_transform,
correlation=args.correlation,
gate=modules.gate_dict[args.gate],
interaction_cls_first=modules.interaction_classes[args.interaction_first],
MLP_irreps=o3.Irreps(args.MLP_irreps),
atomic_inter_scale=args.std,
atomic_inter_shift=[0.0] * len(heads),
radial_MLP=ast.literal_eval(args.radial_MLP),
radial_type=args.radial_type,
heads=heads,
)
if args.model == "ScaleShiftMACE":
return modules.ScaleShiftMACE(
**model_config,
pair_repulsion=args.pair_repulsion,
distance_transform=args.distance_transform,
correlation=args.correlation,
gate=modules.gate_dict[args.gate],
interaction_cls_first=modules.interaction_classes[args.interaction_first],
MLP_irreps=o3.Irreps(args.MLP_irreps),
atomic_inter_scale=args.std,
atomic_inter_shift=args.mean,
radial_MLP=ast.literal_eval(args.radial_MLP),
radial_type=args.radial_type,
heads=heads,
)
if args.model == "FoundationMACE":
return modules.ScaleShiftMACE(**model_config_foundation)
if args.model == "ScaleShiftBOTNet":
        raise RuntimeError("ScaleShiftBOTNet is deprecated, use MACE instead")
if args.model == "BOTNet":
raise RuntimeError("BOTNet is deprecated, use MACE instead")
if args.model == "AtomicDipolesMACE":
assert args.loss == "dipole", "Use dipole loss with AtomicDipolesMACE model"
assert (
args.error_table == "DipoleRMSE"
), "Use error_table DipoleRMSE with AtomicDipolesMACE model"
return modules.AtomicDipolesMACE(
**model_config,
correlation=args.correlation,
gate=modules.gate_dict[args.gate],
interaction_cls_first=modules.interaction_classes[
"RealAgnosticInteractionBlock"
],
MLP_irreps=o3.Irreps(args.MLP_irreps),
)
if args.model == "EnergyDipolesMACE":
assert (
args.loss == "energy_forces_dipole"
), "Use energy_forces_dipole loss with EnergyDipolesMACE model"
assert (
args.error_table == "EnergyDipoleRMSE"
), "Use error_table EnergyDipoleRMSE with AtomicDipolesMACE model"
return modules.EnergyDipolesMACE(
**model_config,
correlation=args.correlation,
gate=modules.gate_dict[args.gate],
interaction_cls_first=modules.interaction_classes[
"RealAgnosticInteractionBlock"
],
MLP_irreps=o3.Irreps(args.MLP_irreps),
)
raise RuntimeError(f"Unknown model: '{args.model}'")
import argparse
import ast
import dataclasses
import logging
import os
import urllib.request
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import torch
from mace.cli.fine_tuning_select import (
FilteringType,
SelectionSettings,
SubselectType,
select_samples,
)
from mace.data import KeySpecification
from mace.tools.scripts_utils import SubsetCollection, get_dataset_from_xyz
@dataclasses.dataclass
class HeadConfig:
head_name: str
key_specification: KeySpecification
train_file: Optional[Union[str, List[str]]] = None
valid_file: Optional[Union[str, List[str]]] = None
test_file: Optional[str] = None
test_dir: Optional[str] = None
E0s: Optional[Any] = None
statistics_file: Optional[str] = None
valid_fraction: Optional[float] = None
config_type_weights: Optional[Dict[str, float]] = None
keep_isolated_atoms: Optional[bool] = None
atomic_numbers: Optional[Union[List[int], List[str]]] = None
mean: Optional[float] = None
std: Optional[float] = None
avg_num_neighbors: Optional[float] = None
compute_avg_num_neighbors: Optional[bool] = None
collections: Optional[SubsetCollection] = None
train_loader: Optional[torch.utils.data.DataLoader] = None
z_table: Optional[Any] = None
atomic_energies_dict: Optional[Dict[str, float]] = None
def dict_head_to_dataclass(
head: Dict[str, Any], head_name: str, args: argparse.Namespace
) -> HeadConfig:
"""Convert head dictionary to HeadConfig dataclass."""
    # required fields without parser defaults must come from either the CLI or the head config
if (args.train_file is None) and (head.get("train_file", None) is None):
raise ValueError(
"train file is not set in the head config yaml or via command line args"
)
return HeadConfig(
head_name=head_name,
train_file=head.get("train_file", args.train_file),
valid_file=head.get("valid_file", args.valid_file),
test_file=head.get("test_file", None),
test_dir=head.get("test_dir", None),
E0s=head.get("E0s", args.E0s),
statistics_file=head.get("statistics_file", args.statistics_file),
valid_fraction=head.get("valid_fraction", args.valid_fraction),
config_type_weights=head.get("config_type_weights", args.config_type_weights),
compute_avg_num_neighbors=head.get(
"compute_avg_num_neighbors", args.compute_avg_num_neighbors
),
atomic_numbers=head.get("atomic_numbers", args.atomic_numbers),
mean=head.get("mean", args.mean),
std=head.get("std", args.std),
avg_num_neighbors=head.get("avg_num_neighbors", args.avg_num_neighbors),
key_specification=head["key_specification"],
keep_isolated_atoms=head.get("keep_isolated_atoms", args.keep_isolated_atoms),
)
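# Configuration sketch (illustrative): one parsed entry of a multihead YAML
# converted to a HeadConfig. Fields absent from the head dict fall back to the
# CLI args namespace; the file names below are placeholders.
def _demo_dict_head_to_dataclass(args: argparse.Namespace) -> HeadConfig:
    head = {
        "train_file": "dft_train.xyz",
        "valid_file": "dft_valid.xyz",
        "E0s": "average",
        "key_specification": args.key_specification,
    }
    return dict_head_to_dataclass(head, "dft_head", args)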
def prepare_default_head(args: argparse.Namespace) -> Dict[str, Any]:
"""Prepare a default head from args."""
return {
"Default": {
"train_file": args.train_file,
"valid_file": args.valid_file,
"test_file": args.test_file,
"test_dir": args.test_dir,
"E0s": args.E0s,
"statistics_file": args.statistics_file,
"key_specification": args.key_specification,
"valid_fraction": args.valid_fraction,
"config_type_weights": args.config_type_weights,
"keep_isolated_atoms": args.keep_isolated_atoms,
}
}
def prepare_pt_head(
args: argparse.Namespace,
pt_keyspec: KeySpecification,
foundation_model_num_neighbours: float,
) -> Dict[str, Any]:
"""Prepare a pretraining head from args."""
if (
args.foundation_model in ["small", "medium", "large"]
or args.pt_train_file == "mp"
):
logging.info(
"Using foundation model for multiheads finetuning with Materials Project data"
)
pt_keyspec.update(
info_keys={"energy": "energy", "stress": "stress"},
arrays_keys={"forces": "forces"},
)
pt_head = {
"train_file": "mp",
"E0s": "foundation",
"statistics_file": None,
"key_specification": pt_keyspec,
"avg_num_neighbors": foundation_model_num_neighbours,
"compute_avg_num_neighbors": False,
}
else:
pt_head = {
"train_file": args.pt_train_file,
"valid_file": args.pt_valid_file,
"E0s": "foundation",
"statistics_file": args.statistics_file,
"valid_fraction": args.valid_fraction,
"key_specification": pt_keyspec,
"avg_num_neighbors": foundation_model_num_neighbours,
"keep_isolated_atoms": args.keep_isolated_atoms,
"compute_avg_num_neighbors": False,
}
return pt_head
def assemble_mp_data(
args: argparse.Namespace,
head_config_pt: HeadConfig,
tag: str,
) -> SubsetCollection:
"""Assemble Materials Project data for fine-tuning."""
try:
checkpoint_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mp_traj_combined.xyz"
cache_dir = (
Path(os.environ.get("XDG_CACHE_HOME", "~/")).expanduser() / ".cache/mace"
)
checkpoint_url_name = "".join(
c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_"
)
cached_dataset_path = f"{cache_dir}/{checkpoint_url_name}"
if not os.path.isfile(cached_dataset_path):
os.makedirs(cache_dir, exist_ok=True)
# download and save to disk
logging.info("Downloading MP structures for finetuning")
_, http_msg = urllib.request.urlretrieve(
checkpoint_url, cached_dataset_path
)
if "Content-Type: text/html" in http_msg:
raise RuntimeError(
f"Dataset download failed, please check the URL {checkpoint_url}"
)
logging.info(f"Materials Project dataset to {cached_dataset_path}")
output = f"mp_finetuning-{tag}.xyz"
atomic_numbers = (
ast.literal_eval(args.atomic_numbers)
if args.atomic_numbers is not None
else None
)
settings = SelectionSettings(
configs_pt=cached_dataset_path,
output=f"mp_finetuning-{tag}.xyz",
atomic_numbers=atomic_numbers,
num_samples=args.num_samples_pt,
seed=args.seed,
head_pt="pbe_mp",
weight_pt=args.weight_pt_head,
filtering_type=FilteringType(args.filter_type_pt),
subselect=SubselectType(args.subselect_pt),
default_dtype=args.default_dtype,
)
select_samples(settings)
head_config_pt.train_file = [output]
collections_mp, _ = get_dataset_from_xyz(
work_dir=args.work_dir,
train_path=output,
valid_path=None,
valid_fraction=args.valid_fraction,
config_type_weights=None,
test_path=None,
seed=args.seed,
key_specification=head_config_pt.key_specification,
head_name="pt_head",
keep_isolated_atoms=args.keep_isolated_atoms,
)
return collections_mp
except Exception as exc:
        raise RuntimeError(
            "Materials Project dataset download or selection failed and no local copy was found"
        ) from exc
import logging
import os
from pathlib import Path
from typing import Any, List, Optional, Union
import torch
from torch.utils.data import ConcatDataset
from mace import data
from mace.tools.scripts_utils import check_path_ase_read
from mace.tools.torch_geometric.dataset import Dataset
from mace.tools.utils import AtomicNumberTable
def normalize_file_paths(file_paths: Union[str, List[str]]) -> List[str]:
"""
Normalize file paths to a list format.
Args:
file_paths: Either a string or a list of strings representing file paths
Returns:
A list of file paths
"""
if isinstance(file_paths, str):
return [file_paths]
if isinstance(file_paths, list):
return file_paths
raise ValueError(f"Unexpected file paths format: {type(file_paths)}")
def load_dataset_for_path(
file_path: Union[str, Path, List[str]],
r_max: float,
z_table: AtomicNumberTable,
heads: List[str],
head_config: Any,
collection: Optional[Any] = None,
) -> Union[Dataset, List]:
"""
Load a dataset from a file path based on its format.
Args:
file_path: Path to the dataset file
r_max: Cutoff radius
z_table: Atomic number table
heads: List of head names
head_name: Current head name
**kwargs: Additional arguments
Returns:
Loaded dataset
"""
if isinstance(file_path, list):
if len(file_path) == 1:
file_path = file_path[0]
if isinstance(file_path, list):
is_ase_readable = all(check_path_ase_read(p) for p in file_path)
if not is_ase_readable:
raise ValueError(
"Not all paths in the list are ASE readable, not supported"
)
if isinstance(file_path, str):
is_ase_readable = check_path_ase_read(file_path)
if is_ase_readable:
assert (
collection is not None
), "Collection must be provided for ASE readable files"
return [
data.AtomicData.from_config(
config, z_table=z_table, cutoff=r_max, heads=heads
)
for config in collection
]
filepath = Path(file_path)
if filepath.is_dir():
if filepath.name.endswith("_lmdb") or any(
f.endswith(".lmdb") or f.endswith(".aselmdb") for f in os.listdir(filepath)
):
logging.info(f"Loading LMDB dataset from {file_path}")
return data.LMDBDataset(
file_path,
r_max=r_max,
z_table=z_table,
heads=heads,
head=head_config.head_name,
)
h5_files = list(filepath.glob("*.h5")) + list(filepath.glob("*.hdf5"))
if h5_files:
logging.info(f"Loading HDF5 dataset from directory {file_path}")
try:
return data.dataset_from_sharded_hdf5(
file_path,
r_max=r_max,
z_table=z_table,
heads=heads,
head=head_config.head_name,
)
except Exception as e:
logging.error(f"Error loading sharded HDF5 dataset: {e}")
raise
if "lmdb" in str(filepath).lower() or "aselmdb" in str(filepath).lower():
logging.info(f"Loading LMDB dataset based on path name: {file_path}")
return data.LMDBDataset(
file_path,
r_max=r_max,
z_table=z_table,
heads=heads,
head=head_config.head_name,
)
logging.info(f"Attempting to load directory as HDF5 dataset: {file_path}")
try:
return data.dataset_from_sharded_hdf5(
file_path,
r_max=r_max,
z_table=z_table,
heads=heads,
head=head_config.head_name,
)
except Exception as e:
logging.error(f"Error loading as sharded HDF5: {e}")
raise
suffix = filepath.suffix.lower()
if suffix in (".h5", ".hdf5"):
logging.info(f"Loading single HDF5 file: {file_path}")
return data.HDF5Dataset(
file_path,
r_max=r_max,
z_table=z_table,
heads=heads,
head=head_config.head_name,
)
if suffix in (".lmdb", ".aselmdb", ".db"):
logging.info(f"Loading single LMDB file: {file_path}")
return data.LMDBDataset(
file_path,
r_max=r_max,
z_table=z_table,
heads=heads,
head=head_config.head_name,
)
logging.info(f"Attempting to load as LMDB: {file_path}")
return data.LMDBDataset(
file_path,
r_max=r_max,
z_table=z_table,
heads=heads,
head=head_config.head_name,
)
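# Routing sketch (illustrative; paths are placeholders): ASE-readable files
# need a pre-parsed collection and come back as a list of AtomicData objects;
# directories are probed for LMDB shards first and HDF5 shards second; single
# files dispatch on suffix, with LMDB as the final fallback.
def _demo_load_dataset_for_path(z_table, head_config, collection):
    return load_dataset_for_path(
        file_path="train.xyz",  # "shards_lmdb/" or "single.h5" would route differently
        r_max=5.0,
        z_table=z_table,
        heads=["Default"],
        head_config=head_config,
        collection=collection,
    )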
def combine_datasets(datasets, head_name):
"""
Combine multiple datasets which might be of different types.
Args:
datasets: List of datasets (can be mixed types)
head_name: Name of the current head
Returns:
Combined dataset
"""
if not datasets:
return []
if all(isinstance(ds, list) for ds in datasets):
logging.info(f"Combining {len(datasets)} list datasets for head '{head_name}'")
return [item for sublist in datasets for item in sublist]
if all(not isinstance(ds, list) for ds in datasets):
logging.info(
f"Combining {len(datasets)} Dataset objects for head '{head_name}'"
)
return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
logging.info(f"Converting mixed dataset types for head '{head_name}'")
try:
all_items = []
for ds in datasets:
if isinstance(ds, list):
all_items.extend(ds)
else:
all_items.extend([ds[i] for i in range(len(ds))])
return all_items
except Exception as e: # pylint: disable=W0703
logging.warning(f"Failed to convert mixed datasets to list: {e}")
try:
dataset_objects = []
for ds in datasets:
if isinstance(ds, list):
from torch.utils.data import TensorDataset
# Convert list to a Dataset
dataset_objects.append(
TensorDataset(*[torch.tensor([i]) for i in range(len(ds))])
)
else:
dataset_objects.append(ds)
return ConcatDataset(dataset_objects)
except Exception as e: # pylint: disable=W0703
logging.warning(f"Failed to convert mixed datasets to ConcatDataset: {e}")
logging.warning(
"Could not combine datasets of different types. Using only the first dataset."
)
return datasets[0]
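# Quick behaviour check (illustrative): plain lists are flattened, while
# Dataset objects are wrapped in a ConcatDataset.
def _demo_combine_datasets():
    assert combine_datasets([[1, 2], [3]], "demo") == [1, 2, 3]
    ds_a = torch.utils.data.TensorDataset(torch.zeros(2, 1))
    ds_b = torch.utils.data.TensorDataset(torch.ones(3, 1))
    combined = combine_datasets([ds_a, ds_b], "demo")
    assert isinstance(combined, ConcatDataset) and len(combined) == 5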