Unverified commit 273418ee authored by zcxzcx1, committed by GitHub

Merge pull request #1 from hjhk258/main

Fix
parents 7fb73825 9d7b4f63
__version__ = "0.3.13"
__all__ = ["__version__"]
from .foundations_models import mace_anicc, mace_mp, mace_off
from .lammps_mace import LAMMPS_MACE
from .mace import MACECalculator

__all__ = [
    "MACECalculator",
    "LAMMPS_MACE",
    "mace_mp",
    "mace_off",
    "mace_anicc",
]
import os
import urllib.request
from pathlib import Path
from typing import Union

import torch
from ase import units
from ase.calculators.mixing import SumCalculator

from .mace import MACECalculator

module_dir = os.path.dirname(__file__)
local_model_path = os.path.join(
    module_dir, "foundations_models/mace-mpa-0-medium.model"
)


def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str:
    """
    Downloads or locates the MACE-MP checkpoint file.

    Args:
        model (str, optional): Path to the model or a size specification.
            Defaults to None, which uses the medium MPA-0 model.

    Returns:
        str: Path to the downloaded (or cached, if previously downloaded) checkpoint file.
    """
    if model in (None, "medium-mpa-0") and os.path.isfile(local_model_path):
        return local_model_path
    urls = {
        "small": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model",
        "medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model",
        "large": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/MACE_MPtrj_2022.9.model",
        "small-0b": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mace_agnesi_small.model",
        "medium-0b": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mace_agnesi_medium.model",
        "small-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-small-density-agnesi-stress.model",
        "medium-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-medium-density-agnesi-stress.model",
        "large-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-large-density-agnesi-stress.model",
        "medium-0b3": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b3/mace-mp-0b3-medium.model",
        "medium-mpa-0": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model",
        "medium-omat-0": "https://github.com/ACEsuit/mace-mp/releases/download/mace_omat_0/mace-omat-0-medium.model",
        "mace-matpes-pbe-0": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-pbe-omat-ft.model",
        "mace-matpes-r2scan-0": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-r2scan-omat-ft.model",
    }
    checkpoint_url = (
        urls.get(model, urls["medium-mpa-0"])
        if model
        in (
            None,
            "small",
            "medium",
            "large",
            "small-0b",
            "medium-0b",
            "small-0b2",
            "medium-0b2",
            "large-0b2",
            "medium-0b3",
            "medium-mpa-0",
            "medium-omat-0",
            # without these two keys, the matpes names would fall through to
            # the else branch and be treated as a literal URL or path
            "mace-matpes-pbe-0",
            "mace-matpes-r2scan-0",
        )
        else model
    )
    if checkpoint_url == urls["medium-mpa-0"]:
        print(
            "Using the medium MPA-0 model as the default MACE-MP model; to use the previous (pre-0.3.10) default model, specify 'medium' as the model argument"
        )
    ASL_checkpoint_urls = {
        urls["medium-omat-0"],
        urls["mace-matpes-pbe-0"],
        urls["mace-matpes-r2scan-0"],
    }
    if checkpoint_url in ASL_checkpoint_urls:
        print(
            "Using a model released under the Academic Software License (ASL), see https://github.com/gabor1/ASL \n By using this model you accept the terms of the license."
        )
    cache_dir = os.path.expanduser("~/.cache/mace")
    checkpoint_url_name = "".join(
        c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_"
    )
    cached_model_path = f"{cache_dir}/{checkpoint_url_name}"
    if not os.path.isfile(cached_model_path):
        os.makedirs(cache_dir, exist_ok=True)
        print(f"Downloading MACE model from {checkpoint_url!r}")
        _, http_msg = urllib.request.urlretrieve(checkpoint_url, cached_model_path)
        # urlretrieve returns an email.message.Message of response headers;
        # an HTML content type indicates an error page rather than a model file
        if http_msg.get_content_type() == "text/html":
            raise RuntimeError(
                f"Model download failed, please check the URL {checkpoint_url}"
            )
        print(f"Cached MACE model to {cached_model_path}")
    return cached_model_path
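

# Hedged usage sketch (an editorial addition, not part of the upstream
# module): the first call downloads the checkpoint, later calls return the
# cached path under ~/.cache/mace. Assumes network access.
def _example_download_checkpoint() -> None:
    path = download_mace_mp_checkpoint("medium")
    print(f"Checkpoint resolved to {path}")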


def mace_mp(
    model: Union[str, Path] = None,
    device: str = "",
    default_dtype: str = "float32",
    dispersion: bool = False,
    damping: str = "bj",  # choices: ["zero", "bj", "zerom", "bjm"]
    dispersion_xc: str = "pbe",
    dispersion_cutoff: float = 40.0 * units.Bohr,
    return_raw_model: bool = False,
    **kwargs,
) -> MACECalculator:
    """
    Constructs a MACECalculator with a pretrained model based on the Materials Project (89 elements).
    The model is released under the MIT license. See https://github.com/ACEsuit/mace-mp for all models.

    Note:
        If you are using this function, please cite the relevant paper for the Materials Project,
        any paper associated with the MACE model, and also the following:
        - MACE-MP by Ilyes Batatia, Philipp Benner, Yuan Chiang, Alin M. Elena,
          Dávid P. Kovács, Janosh Riebesell, et al., 2023, arXiv:2401.00096
        - MACE-Universal by Yuan Chiang, 2023, Hugging Face, Revision e5ebd9b,
          DOI: 10.57967/hf/1202, URL: https://huggingface.co/cyrusyc/mace-universal
        - Matbench Discovery by Janosh Riebesell, Rhys EA Goodall, Philipp Benner, Yuan Chiang,
          Alpha A Lee, Anubhav Jain, Kristin A Persson, 2023, arXiv:2308.14920

    Args:
        model (str, optional): Path to the model. Defaults to None, which first checks for
            a local model and then downloads the default model from the GitHub release assets.
            Specify "small", "medium" or "large" to download a smaller or larger model.
        device (str, optional): Device to use for the model. Defaults to "cuda" if available.
        default_dtype (str, optional): Default dtype for the model. Defaults to "float32".
        dispersion (bool, optional): Whether to use D3 dispersion corrections. Defaults to False.
        damping (str): The damping function associated with the D3 correction. Defaults to "bj" for D3(BJ).
        dispersion_xc (str, optional): Exchange-correlation functional for D3 dispersion corrections.
        dispersion_cutoff (float, optional): Cutoff radius in Bohr for D3 dispersion corrections.
        return_raw_model (bool, optional): Whether to return the raw model or an ASE calculator. Defaults to False.
        **kwargs: Passed to MACECalculator and TorchDFTD3Calculator.

    Returns:
        MACECalculator: trained on the MPtrj dataset (unless another model is specified).
    """
    try:
        if model in (
            None,
            "small",
            "medium",
            "large",
            "medium-mpa-0",
            "small-0b",
            "medium-0b",
            "small-0b2",
            "medium-0b2",
            "medium-0b3",
            "large-0b2",
            "medium-omat-0",
            # without these two keys, the matpes models could never be
            # resolved by download_mace_mp_checkpoint
            "mace-matpes-pbe-0",
            "mace-matpes-r2scan-0",
        ) or str(model).startswith("https:"):
            model_path = download_mace_mp_checkpoint(model)
            print(f"Using Materials Project MACE for MACECalculator with {model_path}")
        else:
            if not Path(model).exists():
                raise FileNotFoundError(f"{model} not found locally")
            model_path = model
    except Exception as exc:
        raise RuntimeError("Model download failed and no local model found") from exc
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    if default_dtype == "float64":
        print(
            "Using float64 for MACECalculator, which is slower but more accurate. Recommended for geometry optimization."
        )
    if default_dtype == "float32":
        print(
            "Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization."
        )
    if return_raw_model:
        return torch.load(model_path, map_location=device)

    mace_calc = MACECalculator(
        model_paths=model_path, device=device, default_dtype=default_dtype, **kwargs
    )
    if not dispersion:
        return mace_calc

    try:
        from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
    except ImportError as exc:
        raise RuntimeError(
            "Please install torch-dftd to use dispersion corrections (see https://github.com/pfnet-research/torch-dftd)"
        ) from exc
    print("Using TorchDFTD3Calculator for D3 dispersion corrections")
    dtype = torch.float32 if default_dtype == "float32" else torch.float64
    d3_calc = TorchDFTD3Calculator(
        device=device,
        damping=damping,
        dtype=dtype,
        xc=dispersion_xc,
        cutoff=dispersion_cutoff,
        **kwargs,
    )
    return SumCalculator([mace_calc, d3_calc])
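

# Hedged usage sketch (an editorial addition, not upstream code; requires
# ASE and triggers a checkpoint download on first use): attach the
# calculator to an ASE Atoms object and evaluate the energy.
def _example_mace_mp() -> None:
    from ase.build import molecule

    atoms = molecule("H2O")
    atoms.calc = mace_mp(model="medium", device="cpu", default_dtype="float64")
    print("E =", atoms.get_potential_energy(), "eV")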


def mace_off(
    model: Union[str, Path] = None,
    device: str = "",
    default_dtype: str = "float64",
    return_raw_model: bool = False,
    **kwargs,
) -> MACECalculator:
    """
    Constructs a MACECalculator with a pretrained model based on the MACE-OFF23 models.
    The model is released under the ASL license.

    Note:
        If you are using this function, please cite the relevant paper by Kovács et al., arXiv:2312.15211

    Args:
        model (str, optional): Path to the model. Defaults to None, which first checks for
            a local model and then downloads the default medium model from https://github.com/ACEsuit/mace-off.
            Specify "small", "medium" or "large" to download a smaller or larger model.
        device (str, optional): Device to use for the model. Defaults to "cuda" if available.
        default_dtype (str, optional): Default dtype for the model. Defaults to "float64".
        return_raw_model (bool, optional): Whether to return the raw model or an ASE calculator. Defaults to False.
        **kwargs: Passed to MACECalculator.

    Returns:
        MACECalculator: trained on the MACE-OFF23 dataset
    """
    try:
        if model in (None, "small", "medium", "large") or str(model).startswith(
            "https:"
        ):
            urls = dict(
                small="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_small.model?raw=true",
                medium="https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model?raw=true",
                large="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_large.model?raw=true",
            )
            checkpoint_url = (
                urls.get(model, urls["medium"])
                if model in (None, "small", "medium", "large")
                else model
            )
            cache_dir = os.path.expanduser("~/.cache/mace")
            checkpoint_url_name = os.path.basename(checkpoint_url).split("?")[0]
            cached_model_path = f"{cache_dir}/{checkpoint_url_name}"
            if not os.path.isfile(cached_model_path):
                os.makedirs(cache_dir, exist_ok=True)
                # download and save to disk
                print(f"Downloading MACE model from {checkpoint_url!r}")
                print(
                    "The model is distributed under the Academic Software License (ASL), see https://github.com/gabor1/ASL \n By using the model you accept the terms of the license."
                )
                print(
                    "The ASL is based on the GNU General Public License, but it does not permit commercial use"
                )
                urllib.request.urlretrieve(checkpoint_url, cached_model_path)
                print(f"Cached MACE model to {cached_model_path}")
            model = cached_model_path
            msg = f"Using MACE-OFF23 model for MACECalculator with {model}"
            print(msg)
        else:
            if not Path(model).exists():
                raise FileNotFoundError(f"{model} not found locally")
    except Exception as exc:
        raise RuntimeError("Model download failed and no local model found") from exc
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    if return_raw_model:
        return torch.load(model, map_location=device)
    if default_dtype == "float64":
        print(
            "Using float64 for MACECalculator, which is slower but more accurate. Recommended for geometry optimization."
        )
    if default_dtype == "float32":
        print(
            "Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization."
        )
    mace_calc = MACECalculator(
        model_paths=model, device=device, default_dtype=default_dtype, **kwargs
    )
    return mace_calc
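

# Hedged usage sketch (an editorial addition, not upstream code): MACE-OFF
# targets organic molecules, so ethanol from ase.build is a natural test case.
def _example_mace_off() -> None:
    from ase.build import molecule

    atoms = molecule("CH3CH2OH")
    atoms.calc = mace_off(model="small", device="cpu")
    print("Forces shape:", atoms.get_forces().shape)  # (n_atoms, 3)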


def mace_anicc(
    device: str = "cuda",
    model_path: str = None,
    return_raw_model: bool = False,
) -> MACECalculator:
    """
    Constructs a MACECalculator with a pretrained model based on the ANI dataset (H, C, N, O).
    The model is released under the MIT license.

    Note:
        If you are using this function, please cite the relevant paper associated with the MACE model, the ANI dataset, and also the following:
        - "Evaluation of the MACE Force Field Architecture" by Dávid Péter Kovács, Ilyes Batatia, Eszter Sára Arany, and Gábor Csányi, The Journal of Chemical Physics, 2023, URL: https://doi.org/10.1063/5.0155322
    """
    if model_path is None:
        model_path = os.path.join(
            module_dir, "foundations_models/ani500k_large_CC.model"
        )
        print(
            "Using ANI coupled cluster model for MACECalculator, see https://doi.org/10.1063/5.0155322"
        )
    if not os.path.exists(model_path):
        model_dir = os.path.dirname(model_path)
        os.makedirs(model_dir, exist_ok=True)
        # Download the model
        print(f"Model not found at {model_path}. Downloading...")
        model_url = "https://github.com/ACEsuit/mace/raw/main/mace/calculators/foundations_models/ani500k_large_CC.model"
        try:

            def report_progress(block_num, block_size, total_size):
                # guard against a missing Content-Length (total_size <= 0)
                # before dividing by it
                if total_size > 0:
                    downloaded = block_num * block_size
                    percent = min(100, downloaded * 100 / total_size)
                    print(
                        f"\rDownloading model: {percent:.1f}% ({downloaded / 1024 / 1024:.1f} MB / {total_size / 1024 / 1024:.1f} MB)",
                        end="",
                    )

            urllib.request.urlretrieve(
                model_url, model_path, reporthook=report_progress
            )
            print("\nDownload complete!")
        except Exception as e:
            raise RuntimeError(f"Failed to download model: {e}") from e
    if return_raw_model:
        return torch.load(model_path, map_location=device)
    return MACECalculator(
        model_paths=model_path, device=device, default_dtype="float64"
    )
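

# Hedged usage sketch (an editorial addition, not upstream code): the ANI-cc
# model covers H, C, N, O only, so methane is a safe example.
def _example_mace_anicc() -> None:
    from ase.build import molecule

    atoms = molecule("CH4")
    atoms.calc = mace_anicc(device="cpu")
    print("E =", atoms.get_potential_energy(), "eV")
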
from typing import Dict, List, Optional

import torch
from e3nn.util.jit import compile_mode

from mace.tools.scatter import scatter_sum


@compile_mode("script")
class LAMMPS_MACE(torch.nn.Module):
    def __init__(self, model, **kwargs):
        super().__init__()
        self.model = model
        self.register_buffer("atomic_numbers", model.atomic_numbers)
        self.register_buffer("r_max", model.r_max)
        self.register_buffer("num_interactions", model.num_interactions)
        if not hasattr(model, "heads"):
            model.heads = [None]
        self.register_buffer(
            "head",
            torch.tensor(
                self.model.heads.index(kwargs.get("head", self.model.heads[-1])),
                dtype=torch.long,
            ).unsqueeze(0),
        )
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(
        self,
        data: Dict[str, torch.Tensor],
        local_or_ghost: torch.Tensor,
        compute_virials: bool = False,
    ) -> Dict[str, Optional[torch.Tensor]]:
        num_graphs = data["ptr"].numel() - 1
        compute_displacement = False
        if compute_virials:
            compute_displacement = True
        data["head"] = self.head
        out = self.model(
            data,
            training=False,
            compute_force=False,
            compute_virials=False,
            compute_stress=False,
            compute_displacement=compute_displacement,
        )
        node_energy = out["node_energy"]
        if node_energy is None:
            return {
                "total_energy_local": None,
                "node_energy": None,
                "forces": None,
                "virials": None,
            }
        positions = data["positions"]
        displacement = out["displacement"]
        forces: Optional[torch.Tensor] = torch.zeros_like(positions)
        virials: Optional[torch.Tensor] = torch.zeros_like(data["cell"])
        # accumulate energies of local atoms
        node_energy_local = node_energy * local_or_ghost
        total_energy_local = scatter_sum(
            src=node_energy_local, index=data["batch"], dim=-1, dim_size=num_graphs
        )
        # compute partial forces and (possibly) partial virials
        grad_outputs: List[Optional[torch.Tensor]] = [
            torch.ones_like(total_energy_local)
        ]
        if compute_virials and displacement is not None:
            forces, virials = torch.autograd.grad(
                outputs=[total_energy_local],
                inputs=[positions, displacement],
                grad_outputs=grad_outputs,
                retain_graph=False,
                create_graph=False,
                allow_unused=True,
            )
            if forces is not None:
                forces = -1 * forces
            else:
                forces = torch.zeros_like(positions)
            if virials is not None:
                virials = -1 * virials
            else:
                virials = torch.zeros_like(displacement)
        else:
            forces = torch.autograd.grad(
                outputs=[total_energy_local],
                inputs=[positions],
                grad_outputs=grad_outputs,
                retain_graph=False,
                create_graph=False,
                allow_unused=True,
            )[0]
            if forces is not None:
                forces = -1 * forces
            else:
                forces = torch.zeros_like(positions)
        return {
            "total_energy_local": total_energy_local,
            "node_energy": node_energy,
            "forces": forces,
            "virials": virials,
        }
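

# Hedged construction sketch (an editorial addition; "mace.model" is a
# hypothetical path to a serialized MACE model). The wrapper is written for
# TorchScript, so a simplified conversion can script and save it for the
# LAMMPS MACE pair style; the repository also provides a create_lammps_model
# entry point that performs the full conversion.
def _example_compile_for_lammps() -> None:
    model = torch.load("mace.model", map_location="cpu")
    lammps_model = LAMMPS_MACE(model)
    torch.jit.script(lammps_model).save("mace-lammps.pt")
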
import logging
import os
import sys
import time
from contextlib import contextmanager
from typing import Dict, Tuple

import torch
from ase.data import chemical_symbols
from e3nn.util.jit import compile_mode

try:
    from lammps.mliap.mliap_unified_abc import MLIAPUnified
except ImportError:

    class MLIAPUnified:
        def __init__(self):
            pass


class MACELammpsConfig:
    """Configuration settings for MACE-LAMMPS integration."""

    def __init__(self):
        self.debug_time = self._get_env_bool("MACE_TIME", False)
        self.debug_profile = self._get_env_bool("MACE_PROFILE", False)
        self.profile_start_step = int(os.environ.get("MACE_PROFILE_START", "5"))
        self.profile_end_step = int(os.environ.get("MACE_PROFILE_END", "10"))
        self.allow_cpu = self._get_env_bool("MACE_ALLOW_CPU", False)
        self.force_cpu = self._get_env_bool("MACE_FORCE_CPU", False)

    @staticmethod
    def _get_env_bool(var_name: str, default: bool) -> bool:
        return os.environ.get(var_name, str(default)).lower() in (
            "true",
            "1",
            "t",
            "yes",
        )
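

# Hedged usage sketch (an editorial addition, not upstream code): the config
# is read entirely from environment variables at construction time.
def _example_config_from_env() -> None:
    os.environ["MACE_TIME"] = "true"
    os.environ["MACE_PROFILE_START"] = "2"
    config = MACELammpsConfig()
    print(config.debug_time, config.profile_start_step)  # True 2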


@contextmanager
def timer(name: str, enabled: bool = True):
    """Context manager for timing code blocks."""
    if not enabled:
        yield
        return
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        logging.info(f"Timer - {name}: {elapsed*1000:.3f} ms")
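

# Hedged usage sketch (an editorial addition, not upstream code): the timer
# logs at INFO level, so logging must be configured for the message to appear.
def _example_timer() -> None:
    logging.basicConfig(level=logging.INFO)
    with timer("sleep", enabled=True):
        time.sleep(0.01)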


@compile_mode("script")
class MACEEdgeForcesWrapper(torch.nn.Module):
    """Wrapper that adds per-pair force computation to a MACE model."""

    def __init__(self, model: torch.nn.Module, **kwargs):
        super().__init__()
        self.model = model
        self.register_buffer("atomic_numbers", model.atomic_numbers)
        self.register_buffer("r_max", model.r_max)
        self.register_buffer("num_interactions", model.num_interactions)
        if not hasattr(model, "heads"):
            model.heads = ["Default"]
        head_name = kwargs.get("head", model.heads[-1])
        head_idx = model.heads.index(head_name)
        self.register_buffer("head", torch.tensor([head_idx], dtype=torch.long))
        for p in self.model.parameters():
            p.requires_grad = False

    def forward(
        self, data: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute energies and per-pair forces."""
        data["head"] = self.head
        out = self.model(
            data,
            training=False,
            compute_force=False,
            compute_virials=False,
            compute_stress=False,
            compute_displacement=False,
            compute_edge_forces=True,
            lammps_mliap=True,
        )
        node_energy = out["node_energy"]
        pair_forces = out["edge_forces"]
        total_energy = out["energy"][0]
        if pair_forces is None:
            pair_forces = torch.zeros_like(data["vectors"])
        return total_energy, node_energy, pair_forces


class LAMMPS_MLIAP_MACE(MLIAPUnified):
    """MACE integration for LAMMPS using the MLIAP interface."""

    def __init__(self, model, **kwargs):
        super().__init__()
        self.config = MACELammpsConfig()
        self.model = MACEEdgeForcesWrapper(model, **kwargs)
        self.element_types = [chemical_symbols[s] for s in model.atomic_numbers]
        self.num_species = len(self.element_types)
        self.rcutfac = 0.5 * float(model.r_max)
        self.ndescriptors = 1
        self.nparams = 1
        self.dtype = model.r_max.dtype
        self.device = "cpu"
        self.initialized = False
        self.step = 0

    def _initialize_device(self, data):
        using_kokkos = "kokkos" in data.__class__.__module__.lower()
        if using_kokkos and not self.config.force_cpu:
            device = torch.as_tensor(data.elems).device
            if device.type == "cpu" and not self.config.allow_cpu:
                raise ValueError(
                    "GPU requested but tensor is on CPU. Set MACE_ALLOW_CPU=true to allow CPU computation."
                )
        else:
            device = torch.device("cpu")
        self.device = device
        self.model = self.model.to(device)
        logging.info(f"MACE model initialized on device: {device}")
        self.initialized = True

    def compute_forces(self, data):
        natoms = data.nlocal
        ntotal = data.ntotal
        nghosts = ntotal - natoms
        npairs = data.npairs
        species = torch.as_tensor(data.elems, dtype=torch.int64)
        if not self.initialized:
            self._initialize_device(data)
        self.step += 1
        self._manage_profiling()
        if natoms == 0 or npairs <= 1:
            return
        with timer("total_step", enabled=self.config.debug_time):
            with timer("prepare_batch", enabled=self.config.debug_time):
                batch = self._prepare_batch(data, natoms, nghosts, species)
            with timer("model_forward", enabled=self.config.debug_time):
                _, atom_energies, pair_forces = self.model(batch)
                if self.device.type != "cpu":
                    torch.cuda.synchronize()
            with timer("update_lammps", enabled=self.config.debug_time):
                self._update_lammps_data(data, atom_energies, pair_forces, natoms)

    def _prepare_batch(self, data, natoms, nghosts, species):
        """Prepare the input batch for the MACE model."""
        return {
            "vectors": torch.as_tensor(data.rij).to(self.dtype).to(self.device),
            "node_attrs": torch.nn.functional.one_hot(
                species.to(self.device), num_classes=self.num_species
            ).to(self.dtype),
            "edge_index": torch.stack(
                [
                    torch.as_tensor(data.pair_j, dtype=torch.int64).to(self.device),
                    torch.as_tensor(data.pair_i, dtype=torch.int64).to(self.device),
                ],
                dim=0,
            ),
            "batch": torch.zeros(natoms, dtype=torch.int64, device=self.device),
            "lammps_class": data,
            "natoms": (natoms, nghosts),
        }

    def _update_lammps_data(self, data, atom_energies, pair_forces, natoms):
        """Update LAMMPS data structures with computed energies and forces."""
        if self.dtype == torch.float32:
            pair_forces = pair_forces.double()
        eatoms = torch.as_tensor(data.eatoms)
        eatoms.copy_(atom_energies[:natoms])
        data.energy = torch.sum(atom_energies[:natoms])
        data.update_pair_forces_gpu(pair_forces)

    def _manage_profiling(self):
        if not self.config.debug_profile:
            return
        if self.step == self.config.profile_start_step:
            logging.info(f"Starting CUDA profiler at step {self.step}")
            torch.cuda.profiler.start()
        if self.step == self.config.profile_end_step:
            logging.info(f"Stopping CUDA profiler at step {self.step}")
            torch.cuda.profiler.stop()
            logging.info("Profiling complete. Exiting.")
            sys.exit()

    def compute_descriptors(self, data):
        pass

    def compute_gradients(self, data):
        pass
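

# Hedged construction sketch (an editorial addition; "mace.model" is a
# hypothetical path). LAMMPS queries element_types and rcutfac from the
# unified object and then drives compute_forces itself.
def _example_mliap_wrapper() -> None:
    model = torch.load("mace.model", map_location="cpu")
    unified = LAMMPS_MLIAP_MACE(model)
    print(unified.element_types, unified.rcutfac)
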
###########################################################################################
# The ASE Calculator for MACE
# Authors: Ilyes Batatia, David Kovacs
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################

import logging

# pylint: disable=wrong-import-position
import os
import random
import time
from glob import glob
from pathlib import Path
from typing import List, Union

os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1"

import numpy as np
import torch
from ase.calculators.calculator import Calculator, all_changes
from ase.stress import full_3x3_to_voigt_6_stress
from e3nn import o3

from mace import data
from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq
from mace.modules.utils import extract_invariant
from mace.tools import (
    atomic_numbers_to_indices,
    to_one_hot,
    torch_geometric,
    torch_tools,
    utils,
)
from mace.tools.compile import prepare
from mace.tools.scripts_utils import extract_model
from mace.tools.torch_geometric.batch import Batch


def get_model_dtype(model: torch.nn.Module) -> str:
    """Get the dtype of the model as a string."""
    model_dtype = next(model.parameters()).dtype
    if model_dtype == torch.float64:
        return "float64"
    if model_dtype == torch.float32:
        return "float32"
    raise ValueError(f"Unknown dtype {model_dtype}")
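

# Hedged usage sketch for the committee feature documented in the class
# below (an editorial addition; "mace_*.model" is a hypothetical wildcard):
# several models loaded together enable energy_var and forces_comm outputs.
def _example_committee() -> None:
    calc = MACECalculator(model_paths="mace_*.model", device="cpu")
    print(calc.num_models)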


class MACECalculator(Calculator):
    """MACE ASE Calculator.

    args:
        model_paths: str, path to model or models; to build a committee,
            use wildcard notation such as mace_*.model
        device: str, device to run on (cuda or cpu)
        energy_units_to_eV: float, conversion factor from model energy units to eV
        length_units_to_A: float, conversion factor from model length units to Angstroms
        default_dtype: str, default dtype of model
        charges_key: str, arrays field of the Atoms object where atomic charges are stored
        model_type: str, type of model to load
            Options: [MACE, DipoleMACE, EnergyDipoleMACE]

    Dipoles are returned in units of Debye.
    """

    def __init__(
        self,
        model_paths: Union[list, str, None] = None,
        models: Union[List[torch.nn.Module], torch.nn.Module, None] = None,
        device: str = "cpu",
        energy_units_to_eV: float = 1.0,
        length_units_to_A: float = 1.0,
        default_dtype="",
        charges_key="Qs",
        model_type="MACE",
        compile_mode=None,
        fullgraph=True,
        enable_cueq=False,
        **kwargs,
    ):
        Calculator.__init__(self, **kwargs)
        self.device = device
        self.dtype = None
        if enable_cueq:
            assert model_type == "MACE", "CuEq only supports MACE models"
            compile_mode = None
        if "model_path" in kwargs:
            deprecation_message = (
                "'model_path' argument is deprecated, please use 'model_paths'"
            )
            if model_paths is None:
                logging.warning(f"{deprecation_message} in the future.")
                model_paths = kwargs["model_path"]
            else:
                raise ValueError(
                    f"both 'model_path' and 'model_paths' given, {deprecation_message} only."
                )
        if (model_paths is None) == (models is None):
            raise ValueError(
                "Exactly one of 'model_paths' or 'models' must be provided"
            )
        self.results = {}
        self.model_type = model_type
        self.compute_atomic_stresses = False

        if model_type == "MACE":
            self.implemented_properties = [
                "energy",
                "free_energy",
                "node_energy",
                "forces",
                "stress",
            ]
            if kwargs.get("compute_atomic_stresses", False):
                self.implemented_properties.extend(["stresses", "virials"])
                self.compute_atomic_stresses = True
        elif model_type == "DipoleMACE":
            self.implemented_properties = ["dipole"]
        elif model_type == "EnergyDipoleMACE":
            self.implemented_properties = [
                "energy",
                "free_energy",
                "node_energy",
                "forces",
                "stress",
                "dipole",
            ]
        else:
            raise ValueError(
                f"Give a valid model_type: [MACE, DipoleMACE, EnergyDipoleMACE], {model_type} not supported"
            )

        if model_paths is not None:
            if isinstance(model_paths, str):
                # Find all models that satisfy the wildcard (e.g. mace_model_*.pt)
                model_paths_glob = glob(model_paths)
                if len(model_paths_glob) == 0:
                    raise ValueError(f"Couldn't find MACE model files: {model_paths}")
                model_paths = model_paths_glob
            elif isinstance(model_paths, Path):
                model_paths = [model_paths]
            if len(model_paths) == 0:
                raise ValueError("No mace file names supplied")
            self.num_models = len(model_paths)
            # Load models from files
            self.models = [
                torch.load(f=model_path, map_location=device)
                for model_path in model_paths
            ]
        elif models is not None:
            if not isinstance(models, list):
                models = [models]
            if len(models) == 0:
                raise ValueError("No models supplied")
            self.models = models
            self.num_models = len(models)

        if self.num_models > 1:
            print(f"Running committee mace with {self.num_models} models")
            if model_type in ["MACE", "EnergyDipoleMACE"]:
                self.implemented_properties.extend(
                    ["energies", "energy_var", "forces_comm", "stress_var"]
                )
            elif model_type == "DipoleMACE":
                self.implemented_properties.extend(["dipole_var"])

        if compile_mode is not None:
            print(f"Torch compile is enabled with mode: {compile_mode}")
            self.models = [
                torch.compile(
                    prepare(extract_model)(model=model, map_location=device),
                    mode=compile_mode,
                    fullgraph=fullgraph,
                )
                for model in self.models
            ]
            self.use_compile = True
        else:
            self.use_compile = False

        # Ensure all models are on the same device
        for model in self.models:
            model.to(device)

        r_maxs = [model.r_max.cpu() for model in self.models]
        r_maxs = np.array(r_maxs)
        if not np.all(r_maxs == r_maxs[0]):
            raise ValueError(
                f"committee r_max are not all the same {' '.join(map(str, r_maxs))}"
            )
        self.r_max = float(r_maxs[0])
        self.device = torch_tools.init_device(device)
        self.energy_units_to_eV = energy_units_to_eV
        self.length_units_to_A = length_units_to_A
        self.z_table = utils.AtomicNumberTable(
            [int(z) for z in self.models[0].atomic_numbers]
        )
        self.charges_key = charges_key
        try:
            self.available_heads: List[str] = self.models[0].heads  # type: ignore
        except AttributeError:
            self.available_heads = ["Default"]
        kwarg_head = kwargs.get("head", None)
        if kwarg_head is not None:
            self.head = kwarg_head
        else:
            self.head = [
                head for head in self.available_heads if head.lower() == "default"
            ]
            if len(self.head) == 0:
                raise ValueError(
                    "Head keyword was not provided, and no head in the model is 'default'. "
                    "Please provide a head keyword to specify the head you want to use. "
                    f"Available heads are: {self.available_heads}"
                )
            self.head = self.head[0]
        print("Using head", self.head, "out of", self.available_heads)

        model_dtype = get_model_dtype(self.models[0])
        if default_dtype == "":
            print(
                f"No dtype selected, switching to {model_dtype} to match model dtype."
            )
            default_dtype = model_dtype
        if model_dtype != default_dtype:
            print(
                f"Default dtype {default_dtype} does not match model dtype {model_dtype}, converting models to {default_dtype}."
            )
            if default_dtype == "float64":
                self.models = [model.double() for model in self.models]
            elif default_dtype == "float32":
                self.models = [model.float() for model in self.models]
        torch_tools.set_default_dtype(default_dtype)
        if enable_cueq:
            print("Converting models to CuEq for acceleration")
            self.models = [
                run_e3nn_to_cueq(model, device=device).to(device)
                for model in self.models
            ]
        for model in self.models:
            for param in model.parameters():
                param.requires_grad = False
        self.dtype = torch.float64 if default_dtype == "float64" else torch.float32
        self.model_time = 0.0
        self.calc_time = 0.0

    def _create_result_tensors(
        self, model_type: str, num_models: int, num_atoms: int
    ) -> dict:
        """
        Create tensors to store the results of the committee.

        :param model_type: str, type of model to load
            Options: [MACE, DipoleMACE, EnergyDipoleMACE]
        :param num_models: int, number of models in the committee
        :param num_atoms: int, number of atoms in the system
        :return: dict of torch tensors
        """
dict_of_tensors = {} dict_of_tensors = {}
if model_type in ["MACE", "EnergyDipoleMACE"]: if model_type in ["MACE", "EnergyDipoleMACE"]:
energies = torch.zeros(num_models, device=self.device) energies = torch.zeros(num_models, device=self.device)
node_energy = torch.zeros(num_models, num_atoms, device=self.device) node_energy = torch.zeros(num_models, num_atoms, device=self.device)
forces = torch.zeros(num_models, num_atoms, 3, device=self.device) forces = torch.zeros(num_models, num_atoms, 3, device=self.device)
stress = torch.zeros(num_models, 3, 3, device=self.device) stress = torch.zeros(num_models, 3, 3, device=self.device)
dict_of_tensors.update( dict_of_tensors.update(
{ {
"energies": energies, "energies": energies,
"node_energy": node_energy, "node_energy": node_energy,
"forces": forces, "forces": forces,
"stress": stress, "stress": stress,
} }
) )
if model_type in ["EnergyDipoleMACE", "DipoleMACE"]: if model_type in ["EnergyDipoleMACE", "DipoleMACE"]:
dipole = torch.zeros(num_models, 3, device=self.device) dipole = torch.zeros(num_models, 3, device=self.device)
dict_of_tensors.update({"dipole": dipole}) dict_of_tensors.update({"dipole": dipole})
return dict_of_tensors return dict_of_tensors
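
    # Illustrative shapes: for model_type="MACE" with num_models=2 and
    # num_atoms=N, the returned dict holds tensors of shapes
    # energies (2,), node_energy (2, N), forces (2, N, 3), stress (2, 3, 3).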

    def _atoms_to_batch(self, atoms):
        keyspec = data.KeySpecification(
            info_keys={}, arrays_keys={"charges": self.charges_key}
        )
        config = data.config_from_atoms(
            atoms, key_specification=keyspec, head_name=self.head
        )
        data_loader = torch_geometric.dataloader.DataLoader(
            dataset=[
                data.AtomicData.from_config(
                    config,
                    z_table=self.z_table,
                    cutoff=self.r_max,
                    heads=self.available_heads,
                )
            ],
            batch_size=1,
            shuffle=False,
            drop_last=False,
        )
        batch = next(iter(data_loader)).to(self.device)
        return batch

    def _clone_batch(self, batch):
        batch_clone = batch.clone()
        if self.use_compile:
            batch_clone["node_attrs"].requires_grad_(True)
            batch_clone["positions"].requires_grad_(True)
        return batch_clone

    # pylint: disable=dangerous-default-value
    def calculate(self, atoms=None, properties=None, system_changes=all_changes):
        """
        Calculate properties.

        :param atoms: ase.Atoms object
        :param properties: [str], properties to be computed, used by ASE internally
        :param system_changes: [str], system changes since last calculation, used by ASE internally
        :return:
        """
        calc_start_t = time.perf_counter()
        # call to base class to set the atoms attribute
        Calculator.calculate(self, atoms)

        batch_base = self._atoms_to_batch(atoms)

        if self.model_type in ["MACE", "EnergyDipoleMACE"]:
            batch = self._clone_batch(batch_base)
            node_heads = batch["head"][batch["batch"]]
            num_atoms_arange = torch.arange(batch["positions"].shape[0])
            node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"])[
                num_atoms_arange, node_heads
            ]
            compute_stress = not self.use_compile
        else:
            compute_stress = False

        ret_tensors = self._create_result_tensors(
            self.model_type, self.num_models, len(atoms)
        )
        for i, model in enumerate(self.models):
            batch = self._clone_batch(batch_base)
            model_start_t = time.perf_counter()
            out = model(
                batch.to_dict(),
                compute_stress=compute_stress,
                training=self.use_compile,
                compute_edge_forces=self.compute_atomic_stresses,
                compute_atomic_stresses=self.compute_atomic_stresses,
            )
            model_end_t = time.perf_counter()
            self.model_time += model_end_t - model_start_t
            if self.model_type in ["MACE", "EnergyDipoleMACE"]:
                ret_tensors["energies"][i] = out["energy"].detach()
                ret_tensors["node_energy"][i] = (out["node_energy"] - node_e0).detach()
                ret_tensors["forces"][i] = out["forces"].detach()
                if out["stress"] is not None:
                    ret_tensors["stress"][i] = out["stress"].detach()
            if self.model_type in ["DipoleMACE", "EnergyDipoleMACE"]:
                ret_tensors["dipole"][i] = out["dipole"].detach()
            if self.model_type in ["MACE"]:
                if out["atomic_stresses"] is not None:
                    ret_tensors.setdefault("atomic_stresses", []).append(
                        out["atomic_stresses"].detach()
                    )
                if out["atomic_virials"] is not None:
                    ret_tensors.setdefault("atomic_virials", []).append(
                        out["atomic_virials"].detach()
                    )

        self.results = {}
        if self.model_type in ["MACE", "EnergyDipoleMACE"]:
            self.results["energy"] = (
                torch.mean(ret_tensors["energies"], dim=0).cpu().item()
                * self.energy_units_to_eV
            )
            self.results["free_energy"] = self.results["energy"]
            self.results["node_energy"] = (
                torch.mean(ret_tensors["node_energy"], dim=0).cpu().numpy()
            )
            self.results["forces"] = (
                torch.mean(ret_tensors["forces"], dim=0).cpu().numpy()
                * self.energy_units_to_eV
                / self.length_units_to_A
            )
            if self.num_models > 1:
                self.results["energies"] = (
                    ret_tensors["energies"].cpu().numpy() * self.energy_units_to_eV
                )
                self.results["energy_var"] = (
                    torch.var(ret_tensors["energies"], dim=0, unbiased=False)
                    .cpu()
                    .item()
                    * self.energy_units_to_eV
                )
                self.results["forces_comm"] = (
                    ret_tensors["forces"].cpu().numpy()
                    * self.energy_units_to_eV
                    / self.length_units_to_A
                )
            if out["stress"] is not None:
                self.results["stress"] = full_3x3_to_voigt_6_stress(
                    torch.mean(ret_tensors["stress"], dim=0).cpu().numpy()
                    * self.energy_units_to_eV
                    / self.length_units_to_A**3
                )
                if self.num_models > 1:
                    self.results["stress_var"] = full_3x3_to_voigt_6_stress(
                        torch.var(ret_tensors["stress"], dim=0, unbiased=False)
                        .cpu()
                        .numpy()
                        * self.energy_units_to_eV
                        / self.length_units_to_A**3
                    )
            if "atomic_stresses" in ret_tensors:
                self.results["stresses"] = (
                    torch.mean(torch.stack(ret_tensors["atomic_stresses"]), dim=0)
                    .cpu()
                    .numpy()
                    * self.energy_units_to_eV
                    / self.length_units_to_A**3
                )
            if "atomic_virials" in ret_tensors:
                self.results["virials"] = (
                    torch.mean(torch.stack(ret_tensors["atomic_virials"]), dim=0)
                    .cpu()
                    .numpy()
                    * self.energy_units_to_eV
                )
        if self.model_type in ["DipoleMACE", "EnergyDipoleMACE"]:
            self.results["dipole"] = (
                torch.mean(ret_tensors["dipole"], dim=0).cpu().numpy()
            )
            if self.num_models > 1:
                self.results["dipole_var"] = (
                    torch.var(ret_tensors["dipole"], dim=0, unbiased=False)
                    .cpu()
                    .numpy()
                )
        calc_end_t = time.perf_counter()
        self.calc_time += calc_end_t - calc_start_t
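
    # Example (illustrative sketch): with a committee of several models, the
    # spread across the ensemble is exposed alongside the means; `calc` below
    # is assumed to wrap more than one model:
    #
    #     atoms.calc = calc
    #     atoms.get_potential_energy()
    #     e_var = calc.results["energy_var"]    # ensemble variance of the energy
    #     f_all = calc.results["forces_comm"]   # per-model forces, (num_models, N, 3)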

    def get_hessian(self, atoms=None):
        if atoms is None and self.atoms is None:
            raise ValueError("atoms not set")
        if atoms is None:
            atoms = self.atoms
        if self.model_type != "MACE":
            raise NotImplementedError("Only implemented for MACE models")
        batch = self._atoms_to_batch(atoms)
        hessians = [
            model(
                self._clone_batch(batch).to_dict(),
                compute_hessian=True,
                compute_stress=False,
                training=self.use_compile,
            )["hessian"]
            for model in self.models
        ]
        hessians = [hessian.detach().cpu().numpy() for hessian in hessians]
        if self.num_models == 1:
            return hessians[0]
        return hessians
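
    # Example (illustrative sketch): a quick vibrational analysis from the
    # Hessian, assuming a single-model calculator; the reshape to (3N, 3N) is
    # an assumption about the returned layout:
    #
    #     import numpy as np
    #     hess = calc.get_hessian(atoms)
    #     n = 3 * len(atoms)
    #     eigvals = np.linalg.eigvalsh(hess.reshape(n, n))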

    def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1):
        """Extract descriptors from a MACE model.

        :param atoms: ase.Atoms object
        :param invariants_only: bool, if True only the invariant descriptors are returned
        :param num_layers: int, number of layers to extract descriptors from; if -1, all layers are used
        :return: np.ndarray (num_atoms, num_interactions, invariant_features) of invariant descriptors if num_models is 1, or list[np.ndarray] otherwise
        """
        if atoms is None and self.atoms is None:
            raise ValueError("atoms not set")
        if atoms is None:
            atoms = self.atoms
        if self.model_type != "MACE":
            raise NotImplementedError("Only implemented for MACE models")
        num_interactions = int(self.models[0].num_interactions)
        if num_layers == -1:
            num_layers = num_interactions
        batch = self._atoms_to_batch(atoms)
        descriptors = [model(batch.to_dict())["node_feats"] for model in self.models]
        irreps_out = o3.Irreps(str(self.models[0].products[0].linear.irreps_out))
        l_max = irreps_out.lmax
        num_invariant_features = irreps_out.dim // (l_max + 1) ** 2
        per_layer_features = [irreps_out.dim for _ in range(num_interactions)]
        per_layer_features[-1] = (
            num_invariant_features  # equivariant features are not created for the last layer
        )
        if invariants_only:
            descriptors = [
                extract_invariant(
                    descriptor,
                    num_layers=num_layers,
                    num_features=num_invariant_features,
                    l_max=l_max,
                )
                for descriptor in descriptors
            ]
        to_keep = np.sum(per_layer_features[:num_layers])
        descriptors = [
            descriptor[:, :to_keep].detach().cpu().numpy() for descriptor in descriptors
        ]
        if self.num_models == 1:
            return descriptors[0]
        return descriptors
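
    # Example (illustrative sketch): a simple per-structure fingerprint from
    # the invariant per-atom descriptors, averaging over the atom axis:
    #
    #     desc = calc.get_descriptors(atoms, invariants_only=True)
    #     fingerprint = desc.mean(axis=0)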

    def predict(self, atoms_list, compute_stress=False):
        predictions = {"energy": [], "forces": []}
        keyspec = data.KeySpecification(
            info_keys={}, arrays_keys={"charges": self.charges_key}
        )
        configs = [
            data.config_from_atoms(atoms, key_specification=keyspec, head_name=self.head)
            for atoms in atoms_list
        ]
        data_loader = torch_geometric.dataloader.DataLoader(
            dataset=[
                data.AtomicData.from_config(
                    config,
                    z_table=self.z_table,
                    cutoff=self.r_max,
                    heads=self.available_heads,
                )
                for config in configs
            ],
            batch_size=len(atoms_list),
            shuffle=False,
            drop_last=False,
        )
        # batch_size == len(atoms_list), so the first batch covers every structure
        batch_base = next(iter(data_loader)).to(self.device)
        out = self.models[0](
            batch_base.to_dict(),
            compute_stress=compute_stress,  # TODO: is stress needed here?
            training=self.use_compile,
        )
        predictions["energy"] = out["energy"].unsqueeze(-1).detach()
        predictions["forces"] = out["forces"].detach()
        if compute_stress:
            predictions["stress"] = out["stress"].detach()
        return predictions
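
    # Example (illustrative sketch): batched single-model inference over a
    # list of ASE Atoms objects:
    #
    #     preds = calc.predict([atoms_a, atoms_b])
    #     energies = preds["energy"]  # shape (n_structures, 1)
    #     forces = preds["forces"]    # per-atom forces, concatenated over structures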

    def fast_predict(self, gbatch, compute_stress=False):
        # cast the incoming graph tensors to the calculator's working dtype
        gbatch.pos = gbatch.pos.to(self.dtype)
        gbatch.cell = gbatch.cell.to(self.dtype)
        predictions = {"energy": [], "forces": []}
        batch_base = self.convert_batch(gbatch)
        out = self.models[0](
            batch_base.to_dict(),
            compute_stress=compute_stress,
            training=self.use_compile,
        )
        predictions["energy"] = out["energy"].unsqueeze(-1).detach().to(torch.float64)
        predictions["forces"] = out["forces"].detach().to(torch.float64)
        if compute_stress:
            predictions["stress"] = out["stress"].detach().to(torch.float64)
        return predictions

    def convert_batch(self, gbatch):
        from batchopt import radius_graph_pbc_cuda

        # NOTE: the cutoff radius is hardcoded here rather than taken from self.r_max
        edge_indices, cell_offsets, num_neighbors = radius_graph_pbc_cuda(
            gbatch,
            radius=4.5,
            max_num_neighbors_threshold=float("inf"),
            pbc=[True, True, True],
            dtype=self.dtype,
        )
        # swap the two rows so the edge direction matches the convention
        # expected by the MACE model
        tmp = edge_indices[0].clone()
        edge_indices[0] = edge_indices[1]
        edge_indices[1] = tmp
        # one-hot encode the atomic numbers over the model's atomic-number table
        indices = atomic_numbers_to_indices(
            gbatch["atomic_numbers"].to("cpu"), z_table=self.z_table
        )
        one_hot = to_one_hot(
            torch.tensor(indices, dtype=torch.long).unsqueeze(-1),
            num_classes=len(self.z_table),
        ).to(self.device)
        cbatch = Batch(
            positions=gbatch["pos"].clone(),
            cell=gbatch["cell"].view(-1, 3),
            batch=gbatch["batch"],
            ptr=gbatch["ptr"],
            edge_index=edge_indices,
            unit_shifts=cell_offsets,
            node_attrs=one_hot,
        )
        return cbatch
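
    # Example (illustrative sketch): `fast_predict` consumes a torch_geometric
    # Batch carrying `pos`, `cell`, `atomic_numbers`, `batch`, and `ptr`
    # fields; which fields `batchopt` requires beyond these is an assumption:
    #
    #     preds = calc.fast_predict(gbatch)  # gbatch: torch_geometric Batch
    #     e = preds["energy"]                # float64, one row per structure
    #     f = preds["forces"]                # float64 per-atom forces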

    def predict_debug(self, atoms_list, gbatch, compute_stress=False):
        # Debug helper: builds the batch twice, once via the data loader and
        # once via convert_batch, so individual fields can be swapped to
        # batch_base_tmp when comparing the two graph-construction paths.
        predictions = {"energy": [], "forces": []}
        keyspec = data.KeySpecification(
            info_keys={}, arrays_keys={"charges": self.charges_key}
        )
        configs = [
            data.config_from_atoms(atoms, key_specification=keyspec, head_name=self.head)
            for atoms in atoms_list
        ]
        data_loader = torch_geometric.dataloader.DataLoader(
            dataset=[
                data.AtomicData.from_config(
                    config,
                    z_table=self.z_table,
                    cutoff=self.r_max,
                    heads=self.available_heads,
                )
                for config in configs
            ],
            batch_size=len(atoms_list),
            shuffle=False,
            drop_last=False,
        )
        batch_base_tmp = next(iter(data_loader)).to(self.device)  # reference path
        batch2 = self.convert_batch(gbatch)
        batch_base = Batch(
            positions=batch2["positions"],
            node_attrs=batch2["node_attrs"],
            cell=batch2["cell"],
            edge_index=batch2["edge_index"],
            unit_shifts=batch2["unit_shifts"],
            batch=batch2["batch"],
            ptr=batch2["ptr"],
        )
        torch.set_printoptions(threshold=float("inf"))
        out = self.models[0](
            batch_base.to_dict(),
            compute_stress=compute_stress,  # TODO: is stress needed here?
            training=self.use_compile,
        )
        predictions["energy"] = out["energy"].unsqueeze(-1).detach()
        predictions["forces"] = out["forces"].detach()
        if compute_stress:
            predictions["stress"] = out["stress"].detach()
        return predictions