Unverified Commit 273418ee authored by zcxzcx1, committed by GitHub

Merge pull request #1 from hjhk258/main

Fix
parents 7fb73825 9d7b4f63
__version__ = "0.3.13"
__all__ = ["__version__"]
from .foundations_models import mace_anicc, mace_mp, mace_off
from .lammps_mace import LAMMPS_MACE
from .mace import MACECalculator
__all__ = [
"MACECalculator",
"LAMMPS_MACE",
"mace_mp",
"mace_off",
"mace_anicc",
]
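
# Example (illustrative, not part of the file): the re-exports above make the
# calculators importable directly from the package, e.g.
#   from mace.calculators import MACECalculator, mace_mp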
import os
import urllib.request
from pathlib import Path
from typing import Union
import torch
from ase import units
from ase.calculators.mixing import SumCalculator
from .mace import MACECalculator
module_dir = os.path.dirname(__file__)
local_model_path = os.path.join(
module_dir, "foundations_models/mace-mpa-0-medium.model"
)
def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str:
"""
Downloads or locates the MACE-MP checkpoint file.
Args:
model (str, optional): Path to the model or size specification.
Defaults to None which uses the medium model.
Returns:
str: Path to the downloaded (or cached, if previously loaded) checkpoint file.
"""
if model in (None, "medium-mpa-0") and os.path.isfile(local_model_path):
return local_model_path
urls = {
"small": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model",
"medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model",
"large": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/MACE_MPtrj_2022.9.model",
"small-0b": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mace_agnesi_small.model",
"medium-0b": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mace_agnesi_medium.model",
"small-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-small-density-agnesi-stress.model",
"medium-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-medium-density-agnesi-stress.model",
"large-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-large-density-agnesi-stress.model",
"medium-0b3": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b3/mace-mp-0b3-medium.model",
"medium-mpa-0": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model",
"medium-omat-0": "https://github.com/ACEsuit/mace-mp/releases/download/mace_omat_0/mace-omat-0-medium.model",
"mace-matpes-pbe-0": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-pbe-omat-ft.model",
"mace-matpes-r2scan-0": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-r2scan-omat-ft.model",
}
checkpoint_url = (
urls.get(model, urls["medium-mpa-0"])
if model
in (
None,
"small",
"medium",
"large",
"small-0b",
"medium-0b",
"small-0b2",
"medium-0b2",
"large-0b2",
"medium-0b3",
"medium-mpa-0",
"medium-omat-0",
)
else model
)
if checkpoint_url == urls["medium-mpa-0"]:
        print(
            "Using the medium MPA-0 model as the default MACE-MP model. To use the previous default (before v0.3.10), please specify 'medium' as the model argument."
        )
ASL_checkpoint_urls = {
urls["medium-omat-0"],
urls["mace-matpes-pbe-0"],
urls["mace-matpes-r2scan-0"],
}
if checkpoint_url in ASL_checkpoint_urls:
        print(
            "Using a model distributed under the Academic Software License (ASL), see https://github.com/gabor1/ASL\n"
            "By using this model you accept the terms of the license."
        )
cache_dir = os.path.expanduser("~/.cache/mace")
checkpoint_url_name = "".join(
c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_"
)
cached_model_path = f"{cache_dir}/{checkpoint_url_name}"
if not os.path.isfile(cached_model_path):
os.makedirs(cache_dir, exist_ok=True)
print(f"Downloading MACE model from {checkpoint_url!r}")
        _, http_msg = urllib.request.urlretrieve(checkpoint_url, cached_model_path)
        # An HTML response indicates an error page rather than a model file
        if http_msg.get_content_type() == "text/html":
            raise RuntimeError(
                f"Model download failed, please check the URL {checkpoint_url}"
            )
print(f"Cached MACE model to {cached_model_path}")
return cached_model_path
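
# Example (sketch, not part of the module): fetching a checkpoint by size tag
# using only the function defined above. The returned path points at the
# cached file under ~/.cache/mace (or the bundled local model).
def _example_download_checkpoint():  # pragma: no cover - illustrative only
    path = download_mace_mp_checkpoint("small")
    print("Checkpoint cached at:", path)
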
def mace_mp(
model: Union[str, Path] = None,
device: str = "",
default_dtype: str = "float32",
dispersion: bool = False,
damping: str = "bj", # choices: ["zero", "bj", "zerom", "bjm"]
dispersion_xc: str = "pbe",
dispersion_cutoff: float = 40.0 * units.Bohr,
return_raw_model: bool = False,
**kwargs,
) -> MACECalculator:
"""
Constructs a MACECalculator with a pretrained model based on the Materials Project (89 elements).
The model is released under the MIT license. See https://github.com/ACEsuit/mace-mp for all models.
Note:
If you are using this function, please cite the relevant paper for the Materials Project,
any paper associated with the MACE model, and also the following:
- MACE-MP by Ilyes Batatia, Philipp Benner, Yuan Chiang, Alin M. Elena,
Dávid P. Kovács, Janosh Riebesell, et al., 2023, arXiv:2401.00096
- MACE-Universal by Yuan Chiang, 2023, Hugging Face, Revision e5ebd9b,
DOI: 10.57967/hf/1202, URL: https://huggingface.co/cyrusyc/mace-universal
- Matbench Discovery by Janosh Riebesell, Rhys EA Goodall, Philipp Benner, Yuan Chiang,
Alpha A Lee, Anubhav Jain, Kristin A Persson, 2023, arXiv:2308.14920
Args:
        model (str, optional): Path to the model. Defaults to None, which first checks for
            a local model and then downloads the default model from the MACE GitHub releases.
            Specify "small", "medium" or "large" to download a smaller or larger model.
device (str, optional): Device to use for the model. Defaults to "cuda" if available.
default_dtype (str, optional): Default dtype for the model. Defaults to "float32".
dispersion (bool, optional): Whether to use D3 dispersion corrections. Defaults to False.
damping (str): The damping function associated with the D3 correction. Defaults to "bj" for D3(BJ).
dispersion_xc (str, optional): Exchange-correlation functional for D3 dispersion corrections.
dispersion_cutoff (float, optional): Cutoff radius in Bohr for D3 dispersion corrections.
return_raw_model (bool, optional): Whether to return the raw model or an ASE calculator. Defaults to False.
**kwargs: Passed to MACECalculator and TorchDFTD3Calculator.
Returns:
        MACECalculator: trained on the MPtrj dataset (unless another model is specified).
"""
try:
if model in (
None,
"small",
"medium",
"large",
"medium-mpa-0",
"small-0b",
"medium-0b",
"small-0b2",
"medium-0b2",
"medium-0b3",
"large-0b2",
"medium-omat-0",
) or str(model).startswith("https:"):
model_path = download_mace_mp_checkpoint(model)
print(f"Using Materials Project MACE for MACECalculator with {model_path}")
else:
if not Path(model).exists():
raise FileNotFoundError(f"{model} not found locally")
model_path = model
except Exception as exc:
raise RuntimeError("Model download failed and no local model found") from exc
device = device or ("cuda" if torch.cuda.is_available() else "cpu")
if default_dtype == "float64":
print(
"Using float64 for MACECalculator, which is slower but more accurate. Recommended for geometry optimization."
)
if default_dtype == "float32":
print(
"Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization."
)
if return_raw_model:
return torch.load(model_path, map_location=device)
mace_calc = MACECalculator(
model_paths=model_path, device=device, default_dtype=default_dtype, **kwargs
)
if not dispersion:
return mace_calc
try:
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
except ImportError as exc:
raise RuntimeError(
"Please install torch-dftd to use dispersion corrections (see https://github.com/pfnet-research/torch-dftd)"
) from exc
print("Using TorchDFTD3Calculator for D3 dispersion corrections")
dtype = torch.float32 if default_dtype == "float32" else torch.float64
d3_calc = TorchDFTD3Calculator(
device=device,
damping=damping,
dtype=dtype,
xc=dispersion_xc,
cutoff=dispersion_cutoff,
**kwargs,
)
return SumCalculator([mace_calc, d3_calc])
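
# Example (sketch, not part of the module): constructing the calculator and
# attaching it to an ASE Atoms object. `ase.build.bulk` is standard ASE; the
# model/device choices are illustrative and trigger a download on first use.
def _example_mace_mp_usage():  # pragma: no cover - illustrative only
    from ase.build import bulk

    calc = mace_mp(model="medium", device="cpu", default_dtype="float64")
    atoms = bulk("Cu", "fcc", a=3.6)
    atoms.calc = calc
    print("Energy (eV):", atoms.get_potential_energy())
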
def mace_off(
model: Union[str, Path] = None,
device: str = "",
default_dtype: str = "float64",
return_raw_model: bool = False,
**kwargs,
) -> MACECalculator:
"""
Constructs a MACECalculator with a pretrained model based on the MACE-OFF23 models.
The model is released under the ASL license.
Note:
        If you are using this function, please cite the relevant paper by Kovacs et al., arXiv:2312.15211
Args:
model (str, optional): Path to the model. Defaults to None which first checks for
a local model and then downloads the default medium model from https://github.com/ACEsuit/mace-off.
Specify "small", "medium" or "large" to download a smaller or larger model.
        device (str, optional): Device to use for the model. Defaults to "cuda" if available.
default_dtype (str, optional): Default dtype for the model. Defaults to "float64".
return_raw_model (bool, optional): Whether to return the raw model or an ASE calculator. Defaults to False.
**kwargs: Passed to MACECalculator.
Returns:
MACECalculator: trained on the MACE-OFF23 dataset
"""
try:
if model in (None, "small", "medium", "large") or str(model).startswith(
"https:"
):
urls = dict(
small="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_small.model?raw=true",
medium="https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model?raw=true",
large="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_large.model?raw=true",
)
checkpoint_url = (
urls.get(model, urls["medium"])
if model in (None, "small", "medium", "large")
else model
)
cache_dir = os.path.expanduser("~/.cache/mace")
checkpoint_url_name = os.path.basename(checkpoint_url).split("?")[0]
cached_model_path = f"{cache_dir}/{checkpoint_url_name}"
if not os.path.isfile(cached_model_path):
os.makedirs(cache_dir, exist_ok=True)
# download and save to disk
print(f"Downloading MACE model from {checkpoint_url!r}")
print(
"The model is distributed under the Academic Software License (ASL) license, see https://github.com/gabor1/ASL \n To use the model you accept the terms of the license."
)
print(
"ASL is based on the Gnu Public License, but does not permit commercial use"
)
urllib.request.urlretrieve(checkpoint_url, cached_model_path)
print(f"Cached MACE model to {cached_model_path}")
model = cached_model_path
msg = f"Using MACE-OFF23 MODEL for MACECalculator with {model}"
print(msg)
else:
if not Path(model).exists():
raise FileNotFoundError(f"{model} not found locally")
except Exception as exc:
raise RuntimeError("Model download failed and no local model found") from exc
device = device or ("cuda" if torch.cuda.is_available() else "cpu")
if return_raw_model:
return torch.load(model, map_location=device)
if default_dtype == "float64":
print(
"Using float64 for MACECalculator, which is slower but more accurate. Recommended for geometry optimization."
)
if default_dtype == "float32":
print(
"Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization."
)
mace_calc = MACECalculator(
model_paths=model, device=device, default_dtype=default_dtype, **kwargs
)
return mace_calc
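
# Example (sketch, not part of the module): MACE-OFF targets organic molecules,
# so a small molecule is the natural test case. `ase.build.molecule` is
# standard ASE; note the ASL license notice printed on first download.
def _example_mace_off_usage():  # pragma: no cover - illustrative only
    from ase.build import molecule

    atoms = molecule("H2O")
    atoms.calc = mace_off(model="small", device="cpu")
    print("Forces (eV/A):", atoms.get_forces())
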
def mace_anicc(
device: str = "cuda",
model_path: str = None,
return_raw_model: bool = False,
) -> MACECalculator:
"""
    Constructs a MACECalculator with a pretrained model based on the ANI dataset (H, C, N, O).
The model is released under the MIT license.
Note:
        If you are using this function, please cite the relevant paper associated with the MACE model, the ANI dataset, and also the following:
        - "Evaluation of the MACE Force Field Architecture" by Dávid Péter Kovács, Ilyes Batatia, Eszter Sára Arany, and Gábor Csányi, The Journal of Chemical Physics, 2023, URL: https://doi.org/10.1063/5.0155322
"""
if model_path is None:
model_path = os.path.join(
module_dir, "foundations_models/ani500k_large_CC.model"
)
        print(
            "Using ANI coupled-cluster model for MACECalculator, see https://doi.org/10.1063/5.0155322"
        )
if not os.path.exists(model_path):
model_dir = os.path.dirname(model_path)
os.makedirs(model_dir, exist_ok=True)
# Download the model
print(f"Model not found at {model_path}. Downloading...")
model_url = "https://github.com/ACEsuit/mace/raw/main/mace/calculators/foundations_models/ani500k_large_CC.model"
try:
            def report_progress(block_num, block_size, total_size):
                downloaded = block_num * block_size
                # guard against total_size <= 0 to avoid division by zero
                if total_size > 0:
                    percent = min(100, downloaded * 100 / total_size)
                    print(
                        f"\rDownloading model: {percent:.1f}% ({downloaded / 1024 / 1024:.1f} MB / {total_size / 1024 / 1024:.1f} MB)",
                        end="",
                    )
urllib.request.urlretrieve(
model_url, model_path, reporthook=report_progress
)
print("\nDownload complete!")
except Exception as e:
raise RuntimeError(f"Failed to download model: {e}") from e
if return_raw_model:
return torch.load(model_path, map_location=device)
return MACECalculator(
model_paths=model_path, device=device, default_dtype="float64"
)
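
# Example (sketch, not part of the module): the ANI coupled-cluster model
# covers H, C, N and O only, so a small organic molecule is an appropriate input.
def _example_mace_anicc_usage():  # pragma: no cover - illustrative only
    from ase.build import molecule

    atoms = molecule("CH4")
    atoms.calc = mace_anicc(device="cpu")
    print("Energy (eV):", atoms.get_potential_energy())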
from typing import Dict, List, Optional
import torch
from e3nn.util.jit import compile_mode
from mace.tools.scatter import scatter_sum
@compile_mode("script")
class LAMMPS_MACE(torch.nn.Module):
def __init__(self, model, **kwargs):
super().__init__()
self.model = model
self.register_buffer("atomic_numbers", model.atomic_numbers)
self.register_buffer("r_max", model.r_max)
self.register_buffer("num_interactions", model.num_interactions)
if not hasattr(model, "heads"):
model.heads = [None]
self.register_buffer(
"head",
torch.tensor(
self.model.heads.index(kwargs.get("head", self.model.heads[-1])),
dtype=torch.long,
).unsqueeze(0),
)
for param in self.model.parameters():
param.requires_grad = False
def forward(
self,
data: Dict[str, torch.Tensor],
local_or_ghost: torch.Tensor,
compute_virials: bool = False,
) -> Dict[str, Optional[torch.Tensor]]:
num_graphs = data["ptr"].numel() - 1
compute_displacement = False
if compute_virials:
compute_displacement = True
data["head"] = self.head
out = self.model(
data,
training=False,
compute_force=False,
compute_virials=False,
compute_stress=False,
compute_displacement=compute_displacement,
)
node_energy = out["node_energy"]
if node_energy is None:
return {
"total_energy_local": None,
"node_energy": None,
"forces": None,
"virials": None,
}
positions = data["positions"]
displacement = out["displacement"]
forces: Optional[torch.Tensor] = torch.zeros_like(positions)
virials: Optional[torch.Tensor] = torch.zeros_like(data["cell"])
# accumulate energies of local atoms
node_energy_local = node_energy * local_or_ghost
total_energy_local = scatter_sum(
src=node_energy_local, index=data["batch"], dim=-1, dim_size=num_graphs
)
# compute partial forces and (possibly) partial virials
grad_outputs: List[Optional[torch.Tensor]] = [
torch.ones_like(total_energy_local)
]
if compute_virials and displacement is not None:
forces, virials = torch.autograd.grad(
outputs=[total_energy_local],
inputs=[positions, displacement],
grad_outputs=grad_outputs,
retain_graph=False,
create_graph=False,
allow_unused=True,
)
if forces is not None:
forces = -1 * forces
else:
forces = torch.zeros_like(positions)
if virials is not None:
virials = -1 * virials
else:
virials = torch.zeros_like(displacement)
else:
forces = torch.autograd.grad(
outputs=[total_energy_local],
inputs=[positions],
grad_outputs=grad_outputs,
retain_graph=False,
create_graph=False,
allow_unused=True,
)[0]
if forces is not None:
forces = -1 * forces
else:
forces = torch.zeros_like(positions)
return {
"total_energy_local": total_energy_local,
"node_energy": node_energy,
"forces": forces,
"virials": virials,
}
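
# Example (sketch, assumptions flagged): a minimal export flow in the spirit of
# the repo's `mace_create_lammps_model` CLI. The filenames are hypothetical,
# and whether plain torch.jit.script suffices for a given model is an
# assumption; the CLI may apply extra preparation steps.
def _example_export_for_lammps():  # pragma: no cover - illustrative only
    model = torch.load("MACE.model", map_location="cpu")  # hypothetical path
    lammps_model = LAMMPS_MACE(model.double())
    scripted = torch.jit.script(lammps_model)
    scripted.save("MACE-lammps.pt")  # consumed by the LAMMPS MACE pair style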
import logging
import os
import sys
import time
from contextlib import contextmanager
from typing import Dict, Tuple
import torch
from ase.data import chemical_symbols
from e3nn.util.jit import compile_mode
try:
from lammps.mliap.mliap_unified_abc import MLIAPUnified
except ImportError:
class MLIAPUnified:
def __init__(self):
pass
class MACELammpsConfig:
"""Configuration settings for MACE-LAMMPS integration."""
def __init__(self):
self.debug_time = self._get_env_bool("MACE_TIME", False)
self.debug_profile = self._get_env_bool("MACE_PROFILE", False)
self.profile_start_step = int(os.environ.get("MACE_PROFILE_START", "5"))
self.profile_end_step = int(os.environ.get("MACE_PROFILE_END", "10"))
self.allow_cpu = self._get_env_bool("MACE_ALLOW_CPU", False)
self.force_cpu = self._get_env_bool("MACE_FORCE_CPU", False)
@staticmethod
def _get_env_bool(var_name: str, default: bool) -> bool:
return os.environ.get(var_name, str(default)).lower() in (
"true",
"1",
"t",
"yes",
)
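
# Example (sketch): the environment variables read above must be set before a
# MACELammpsConfig is constructed, e.g. at the top of the driver script that
# loads the model.
def _example_configure_env():  # pragma: no cover - illustrative only
    os.environ["MACE_TIME"] = "true"  # log per-phase timings
    os.environ["MACE_PROFILE_START"] = "2"  # first profiled MD step
    config = MACELammpsConfig()
    assert config.debug_time and config.profile_start_step == 2
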
@contextmanager
def timer(name: str, enabled: bool = True):
"""Context manager for timing code blocks."""
if not enabled:
yield
return
start = time.perf_counter()
try:
yield
finally:
elapsed = time.perf_counter() - start
logging.info(f"Timer - {name}: {elapsed*1000:.3f} ms")
@compile_mode("script")
class MACEEdgeForcesWrapper(torch.nn.Module):
"""Wrapper that adds per-pair force computation to a MACE model."""
def __init__(self, model: torch.nn.Module, **kwargs):
super().__init__()
self.model = model
self.register_buffer("atomic_numbers", model.atomic_numbers)
self.register_buffer("r_max", model.r_max)
self.register_buffer("num_interactions", model.num_interactions)
if not hasattr(model, "heads"):
model.heads = ["Default"]
head_name = kwargs.get("head", model.heads[-1])
head_idx = model.heads.index(head_name)
self.register_buffer("head", torch.tensor([head_idx], dtype=torch.long))
for p in self.model.parameters():
p.requires_grad = False
def forward(
self, data: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute energies and per-pair forces."""
data["head"] = self.head
out = self.model(
data,
training=False,
compute_force=False,
compute_virials=False,
compute_stress=False,
compute_displacement=False,
compute_edge_forces=True,
lammps_mliap=True,
)
node_energy = out["node_energy"]
pair_forces = out["edge_forces"]
total_energy = out["energy"][0]
if pair_forces is None:
pair_forces = torch.zeros_like(data["vectors"])
return total_energy, node_energy, pair_forces
class LAMMPS_MLIAP_MACE(MLIAPUnified):
"""MACE integration for LAMMPS using the MLIAP interface."""
def __init__(self, model, **kwargs):
super().__init__()
self.config = MACELammpsConfig()
self.model = MACEEdgeForcesWrapper(model, **kwargs)
self.element_types = [chemical_symbols[s] for s in model.atomic_numbers]
self.num_species = len(self.element_types)
self.rcutfac = 0.5 * float(model.r_max)
self.ndescriptors = 1
self.nparams = 1
self.dtype = model.r_max.dtype
self.device = "cpu"
self.initialized = False
self.step = 0
def _initialize_device(self, data):
using_kokkos = "kokkos" in data.__class__.__module__.lower()
if using_kokkos and not self.config.force_cpu:
device = torch.as_tensor(data.elems).device
if device.type == "cpu" and not self.config.allow_cpu:
raise ValueError(
"GPU requested but tensor is on CPU. Set MACE_ALLOW_CPU=true to allow CPU computation."
)
else:
device = torch.device("cpu")
self.device = device
self.model = self.model.to(device)
logging.info(f"MACE model initialized on device: {device}")
self.initialized = True
def compute_forces(self, data):
natoms = data.nlocal
ntotal = data.ntotal
nghosts = ntotal - natoms
npairs = data.npairs
species = torch.as_tensor(data.elems, dtype=torch.int64)
if not self.initialized:
self._initialize_device(data)
self.step += 1
self._manage_profiling()
if natoms == 0 or npairs <= 1:
return
with timer("total_step", enabled=self.config.debug_time):
with timer("prepare_batch", enabled=self.config.debug_time):
batch = self._prepare_batch(data, natoms, nghosts, species)
with timer("model_forward", enabled=self.config.debug_time):
_, atom_energies, pair_forces = self.model(batch)
if self.device.type != "cpu":
torch.cuda.synchronize()
with timer("update_lammps", enabled=self.config.debug_time):
self._update_lammps_data(data, atom_energies, pair_forces, natoms)
def _prepare_batch(self, data, natoms, nghosts, species):
"""Prepare the input batch for the MACE model."""
return {
"vectors": torch.as_tensor(data.rij).to(self.dtype).to(self.device),
"node_attrs": torch.nn.functional.one_hot(
species.to(self.device), num_classes=self.num_species
).to(self.dtype),
"edge_index": torch.stack(
[
torch.as_tensor(data.pair_j, dtype=torch.int64).to(self.device),
torch.as_tensor(data.pair_i, dtype=torch.int64).to(self.device),
],
dim=0,
),
"batch": torch.zeros(natoms, dtype=torch.int64, device=self.device),
"lammps_class": data,
"natoms": (natoms, nghosts),
}
def _update_lammps_data(self, data, atom_energies, pair_forces, natoms):
"""Update LAMMPS data structures with computed energies and forces."""
if self.dtype == torch.float32:
pair_forces = pair_forces.double()
eatoms = torch.as_tensor(data.eatoms)
eatoms.copy_(atom_energies[:natoms])
data.energy = torch.sum(atom_energies[:natoms])
data.update_pair_forces_gpu(pair_forces)
def _manage_profiling(self):
if not self.config.debug_profile:
return
if self.step == self.config.profile_start_step:
logging.info(f"Starting CUDA profiler at step {self.step}")
torch.cuda.profiler.start()
if self.step == self.config.profile_end_step:
logging.info(f"Stopping CUDA profiler at step {self.step}")
torch.cuda.profiler.stop()
logging.info("Profiling complete. Exiting.")
sys.exit()
def compute_descriptors(self, data):
pass
def compute_gradients(self, data):
pass
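
# Example (sketch, assumption): LAMMPS loads a "unified" MLIAP model from a
# pickled Python object. The pickling convention and filenames below are
# assumptions for illustration, not the repo's documented export path.
def _example_build_mliap_model():  # pragma: no cover - illustrative only
    import pickle

    model = torch.load("MACE.model", map_location="cpu")  # hypothetical path
    unified = LAMMPS_MLIAP_MACE(model)
    with open("mace_mliap.pkl", "wb") as fh:
        pickle.dump(unified, fh)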
###########################################################################################
# The ASE Calculator for MACE
# Authors: Ilyes Batatia, David Kovacs
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
import logging
# pylint: disable=wrong-import-position
import os
from glob import glob
from pathlib import Path
from typing import List, Union
os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1"
import numpy as np
import torch
from ase.calculators.calculator import Calculator, all_changes
from ase.stress import full_3x3_to_voigt_6_stress
from e3nn import o3
from mace import data
from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq
from mace.modules.utils import extract_invariant
from mace.tools import torch_geometric, torch_tools, utils
from mace.tools.compile import prepare
from mace.tools.scripts_utils import extract_model
import random
import time

from mace.tools import (
    atomic_numbers_to_indices,
    to_one_hot,
)
from mace.tools.torch_geometric.batch import Batch
def get_model_dtype(model: torch.nn.Module) -> str:
    """Return the dtype of the model as a string ("float32" or "float64")."""
    model_dtype = next(model.parameters()).dtype
    if model_dtype == torch.float64:
        return "float64"
    if model_dtype == torch.float32:
        return "float32"
    raise ValueError(f"Unknown dtype {model_dtype}")
class MACECalculator(Calculator):
"""MACE ASE Calculator
args:
        model_paths: str, path to model or models if a committee is used;
            to make a committee, use wildcard notation like mace_*.model
device: str, device to run on (cuda or cpu)
energy_units_to_eV: float, conversion factor from model energy units to eV
length_units_to_A: float, conversion factor from model length units to Angstroms
default_dtype: str, default dtype of model
charges_key: str, Array field of atoms object where atomic charges are stored
model_type: str, type of model to load
Options: [MACE, DipoleMACE, EnergyDipoleMACE]
Dipoles are returned in units of Debye
"""
def __init__(
self,
model_paths: Union[list, str, None] = None,
models: Union[List[torch.nn.Module], torch.nn.Module, None] = None,
device: str = "cpu",
energy_units_to_eV: float = 1.0,
length_units_to_A: float = 1.0,
default_dtype="",
charges_key="Qs",
model_type="MACE",
compile_mode=None,
fullgraph=True,
enable_cueq=False,
**kwargs,
):
Calculator.__init__(self, **kwargs)
self.device = device
        self.dtype = None
if enable_cueq:
assert model_type == "MACE", "CuEq only supports MACE models"
compile_mode = None
if "model_path" in kwargs:
deprecation_message = (
"'model_path' argument is deprecated, please use 'model_paths'"
)
if model_paths is None:
logging.warning(f"{deprecation_message} in the future.")
model_paths = kwargs["model_path"]
else:
raise ValueError(
f"both 'model_path' and 'model_paths' given, {deprecation_message} only."
)
if (model_paths is None) == (models is None):
raise ValueError(
"Exactly one of 'model_paths' or 'models' must be provided"
)
self.results = {}
self.model_type = model_type
self.compute_atomic_stresses = False
if model_type == "MACE":
self.implemented_properties = [
"energy",
"free_energy",
"node_energy",
"forces",
"stress",
]
if kwargs.get("compute_atomic_stresses", False):
self.implemented_properties.extend(["stresses", "virials"])
self.compute_atomic_stresses = True
elif model_type == "DipoleMACE":
self.implemented_properties = ["dipole"]
elif model_type == "EnergyDipoleMACE":
self.implemented_properties = [
"energy",
"free_energy",
"node_energy",
"forces",
"stress",
"dipole",
]
else:
raise ValueError(
f"Give a valid model_type: [MACE, DipoleMACE, EnergyDipoleMACE], {model_type} not supported"
)
if model_paths is not None:
if isinstance(model_paths, str):
# Find all models that satisfy the wildcard (e.g. mace_model_*.pt)
model_paths_glob = glob(model_paths)
if len(model_paths_glob) == 0:
raise ValueError(f"Couldn't find MACE model files: {model_paths}")
model_paths = model_paths_glob
elif isinstance(model_paths, Path):
model_paths = [model_paths]
if len(model_paths) == 0:
raise ValueError("No mace file names supplied")
self.num_models = len(model_paths)
# Load models from files
self.models = [
torch.load(f=model_path, map_location=device)
for model_path in model_paths
]
elif models is not None:
if not isinstance(models, list):
models = [models]
if len(models) == 0:
raise ValueError("No models supplied")
self.models = models
self.num_models = len(models)
if self.num_models > 1:
print(f"Running committee mace with {self.num_models} models")
if model_type in ["MACE", "EnergyDipoleMACE"]:
self.implemented_properties.extend(
["energies", "energy_var", "forces_comm", "stress_var"]
)
elif model_type == "DipoleMACE":
self.implemented_properties.extend(["dipole_var"])
if compile_mode is not None:
print(f"Torch compile is enabled with mode: {compile_mode}")
self.models = [
torch.compile(
prepare(extract_model)(model=model, map_location=device),
mode=compile_mode,
fullgraph=fullgraph,
)
for model in self.models
]
self.use_compile = True
else:
self.use_compile = False
# Ensure all models are on the same device
for model in self.models:
model.to(device)
r_maxs = [model.r_max.cpu() for model in self.models]
r_maxs = np.array(r_maxs)
if not np.all(r_maxs == r_maxs[0]):
raise ValueError(f"committee r_max are not all the same {' '.join(r_maxs)}")
self.r_max = float(r_maxs[0])
self.device = torch_tools.init_device(device)
self.energy_units_to_eV = energy_units_to_eV
self.length_units_to_A = length_units_to_A
self.z_table = utils.AtomicNumberTable(
[int(z) for z in self.models[0].atomic_numbers]
)
self.charges_key = charges_key
try:
self.available_heads: List[str] = self.models[0].heads # type: ignore
except AttributeError:
self.available_heads = ["Default"]
kwarg_head = kwargs.get("head", None)
if kwarg_head is not None:
self.head = kwarg_head
else:
self.head = [head for head in self.available_heads if head.lower() == "default"]
if len(self.head) == 0:
raise ValueError(
"Head keyword was not provided, and no head in the model is 'default'. "
"Please provide a head keyword to specify the head you want to use. "
f"Available heads are: {self.available_heads}"
)
self.head = self.head[0]
print("Using head", self.head, "out of", self.available_heads)
model_dtype = get_model_dtype(self.models[0])
if default_dtype == "":
print(
f"No dtype selected, switching to {model_dtype} to match model dtype."
)
default_dtype = model_dtype
if model_dtype != default_dtype:
print(
f"Default dtype {default_dtype} does not match model dtype {model_dtype}, converting models to {default_dtype}."
)
if default_dtype == "float64":
self.models = [model.double() for model in self.models]
elif default_dtype == "float32":
self.models = [model.float() for model in self.models]
torch_tools.set_default_dtype(default_dtype)
if enable_cueq:
print("Converting models to CuEq for acceleration")
self.models = [
run_e3nn_to_cueq(model, device=device).to(device)
for model in self.models
]
for model in self.models:
for param in model.parameters():
param.requires_grad = False
self.dtype = torch.float64 if default_dtype == "float64" else torch.float32
self.model_time = 0.0
self.calc_time = 0.0
def _create_result_tensors(
self, model_type: str, num_models: int, num_atoms: int
) -> dict:
"""
Create tensors to store the results of the committee
:param model_type: str, type of model to load
Options: [MACE, DipoleMACE, EnergyDipoleMACE]
:param num_models: int, number of models in the committee
:return: tuple of torch tensors
"""
dict_of_tensors = {}
if model_type in ["MACE", "EnergyDipoleMACE"]:
energies = torch.zeros(num_models, device=self.device)
node_energy = torch.zeros(num_models, num_atoms, device=self.device)
forces = torch.zeros(num_models, num_atoms, 3, device=self.device)
stress = torch.zeros(num_models, 3, 3, device=self.device)
dict_of_tensors.update(
{
"energies": energies,
"node_energy": node_energy,
"forces": forces,
"stress": stress,
}
)
if model_type in ["EnergyDipoleMACE", "DipoleMACE"]:
dipole = torch.zeros(num_models, 3, device=self.device)
dict_of_tensors.update({"dipole": dipole})
return dict_of_tensors
def _atoms_to_batch(self, atoms):
keyspec = data.KeySpecification(
info_keys={}, arrays_keys={"charges": self.charges_key}
)
config = data.config_from_atoms(
atoms, key_specification=keyspec, head_name=self.head
)
data_loader = torch_geometric.dataloader.DataLoader(
dataset=[
data.AtomicData.from_config(
config,
z_table=self.z_table,
cutoff=self.r_max,
heads=self.available_heads,
)
],
batch_size=1,
shuffle=False,
drop_last=False,
)
batch = next(iter(data_loader)).to(self.device)
return batch
def _clone_batch(self, batch):
batch_clone = batch.clone()
if self.use_compile:
batch_clone["node_attrs"].requires_grad_(True)
batch_clone["positions"].requires_grad_(True)
return batch_clone
# pylint: disable=dangerous-default-value
def calculate(self, atoms=None, properties=None, system_changes=all_changes):
"""
Calculate properties.
:param atoms: ase.Atoms object
:param properties: [str], properties to be computed, used by ASE internally
:param system_changes: [str], system changes since last calculation, used by ASE internally
:return:
"""
        calc_start_t = time.perf_counter()
        # call to base-class to set atoms attribute
        Calculator.calculate(self, atoms)
batch_base = self._atoms_to_batch(atoms)
if self.model_type in ["MACE", "EnergyDipoleMACE"]:
batch = self._clone_batch(batch_base)
node_heads = batch["head"][batch["batch"]]
num_atoms_arange = torch.arange(batch["positions"].shape[0])
node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"])[
num_atoms_arange, node_heads
]
compute_stress = not self.use_compile
else:
compute_stress = False
ret_tensors = self._create_result_tensors(
self.model_type, self.num_models, len(atoms)
)
for i, model in enumerate(self.models):
batch = self._clone_batch(batch_base)
model_start_t = time.perf_counter()
out = model(
batch.to_dict(),
compute_stress=compute_stress,
training=self.use_compile,
compute_edge_forces=self.compute_atomic_stresses,
compute_atomic_stresses=self.compute_atomic_stresses,
)
model_end_t = time.perf_counter()
self.model_time += (model_end_t - model_start_t)
# print(f'&&& batch.positions: {batch["positions"]}')
# print(f'&&& batch.stress: {batch["stress"]}')
# print(f'compute_stress: {compute_stress}')
# for k,v in batch.to_dict().items():
# print(f'&&& batch.to_dict(): {k} {v}')
# print("=======")
# print(f'&&& out["forces"]: {out["forces"]}')
# print(f'&&& training: {self.use_compile}')
# print(f'@@@File: {__file__}, out: {out}')
if self.model_type in ["MACE", "EnergyDipoleMACE"]:
ret_tensors["energies"][i] = out["energy"].detach()
ret_tensors["node_energy"][i] = (out["node_energy"] - node_e0).detach()
ret_tensors["forces"][i] = out["forces"].detach()
if out["stress"] is not None:
ret_tensors["stress"][i] = out["stress"].detach()
if self.model_type in ["DipoleMACE", "EnergyDipoleMACE"]:
ret_tensors["dipole"][i] = out["dipole"].detach()
if self.model_type in ["MACE"]:
if out["atomic_stresses"] is not None:
ret_tensors.setdefault("atomic_stresses", []).append(
out["atomic_stresses"].detach()
)
if out["atomic_virials"] is not None:
ret_tensors.setdefault("atomic_virials", []).append(
out["atomic_virials"].detach()
)
self.results = {}
if self.model_type in ["MACE", "EnergyDipoleMACE"]:
self.results["energy"] = (
torch.mean(ret_tensors["energies"], dim=0).cpu().item()
* self.energy_units_to_eV
)
self.results["free_energy"] = self.results["energy"]
self.results["node_energy"] = (
torch.mean(ret_tensors["node_energy"], dim=0).cpu().numpy()
)
self.results["forces"] = (
torch.mean(ret_tensors["forces"], dim=0).cpu().numpy()
* self.energy_units_to_eV
/ self.length_units_to_A
)
if self.num_models > 1:
self.results["energies"] = (
ret_tensors["energies"].cpu().numpy() * self.energy_units_to_eV
)
self.results["energy_var"] = (
torch.var(ret_tensors["energies"], dim=0, unbiased=False)
.cpu()
.item()
* self.energy_units_to_eV
)
self.results["forces_comm"] = (
ret_tensors["forces"].cpu().numpy()
* self.energy_units_to_eV
/ self.length_units_to_A
)
if out["stress"] is not None:
self.results["stress"] = full_3x3_to_voigt_6_stress(
torch.mean(ret_tensors["stress"], dim=0).cpu().numpy()
* self.energy_units_to_eV
/ self.length_units_to_A**3
)
if self.num_models > 1:
self.results["stress_var"] = full_3x3_to_voigt_6_stress(
torch.var(ret_tensors["stress"], dim=0, unbiased=False)
.cpu()
.numpy()
* self.energy_units_to_eV
/ self.length_units_to_A**3
)
if "atomic_stresses" in ret_tensors:
self.results["stresses"] = (
torch.mean(torch.stack(ret_tensors["atomic_stresses"]), dim=0)
.cpu()
.numpy()
* self.energy_units_to_eV
/ self.length_units_to_A**3
)
if "atomic_virials" in ret_tensors:
self.results["virials"] = (
torch.mean(torch.stack(ret_tensors["atomic_virials"]), dim=0)
.cpu()
.numpy()
* self.energy_units_to_eV
)
if self.model_type in ["DipoleMACE", "EnergyDipoleMACE"]:
self.results["dipole"] = (
torch.mean(ret_tensors["dipole"], dim=0).cpu().numpy()
)
if self.num_models > 1:
self.results["dipole_var"] = (
torch.var(ret_tensors["dipole"], dim=0, unbiased=False)
.cpu()
.numpy()
)
calc_end_t = time.perf_counter()
self.calc_time += (calc_end_t - calc_start_t)
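    # Example (illustrative): after attaching this calculator to an ASE Atoms
    # object, atoms.get_potential_energy() triggers calculate(), and
    # self.results then holds "energy", "free_energy", "forces" (plus, for
    # committees, "energy_var"/"forces_comm") in ASE units.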
def get_hessian(self, atoms=None):
if atoms is None and self.atoms is None:
raise ValueError("atoms not set")
if atoms is None:
atoms = self.atoms
if self.model_type != "MACE":
raise NotImplementedError("Only implemented for MACE models")
batch = self._atoms_to_batch(atoms)
hessians = [
model(
self._clone_batch(batch).to_dict(),
compute_hessian=True,
compute_stress=False,
training=self.use_compile,
)["hessian"]
for model in self.models
]
hessians = [hessian.detach().cpu().numpy() for hessian in hessians]
if self.num_models == 1:
return hessians[0]
return hessians
def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1):
"""Extracts the descriptors from MACE model.
:param atoms: ase.Atoms object
:param invariants_only: bool, if True only the invariant descriptors are returned
:param num_layers: int, number of layers to extract descriptors from, if -1 all layers are used
:return: np.ndarray (num_atoms, num_interactions, invariant_features) of invariant descriptors if num_models is 1 or list[np.ndarray] otherwise
"""
if atoms is None and self.atoms is None:
raise ValueError("atoms not set")
if atoms is None:
atoms = self.atoms
if self.model_type != "MACE":
raise NotImplementedError("Only implemented for MACE models")
num_interactions = int(self.models[0].num_interactions)
if num_layers == -1:
num_layers = num_interactions
batch = self._atoms_to_batch(atoms)
descriptors = [model(batch.to_dict())["node_feats"] for model in self.models]
irreps_out = o3.Irreps(str(self.models[0].products[0].linear.irreps_out))
l_max = irreps_out.lmax
num_invariant_features = irreps_out.dim // (l_max + 1) ** 2
per_layer_features = [irreps_out.dim for _ in range(num_interactions)]
per_layer_features[-1] = (
num_invariant_features # Equivariant features not created for the last layer
)
if invariants_only:
descriptors = [
extract_invariant(
descriptor,
num_layers=num_layers,
num_features=num_invariant_features,
l_max=l_max,
)
for descriptor in descriptors
]
to_keep = np.sum(per_layer_features[:num_layers])
descriptors = [
descriptor[:, :to_keep].detach().cpu().numpy() for descriptor in descriptors
]
if self.num_models == 1:
return descriptors[0]
return descriptors
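# Usage sketch: per-atom MACE descriptors, e.g. as inputs to a similarity or
# clustering analysis. Keyword values below are illustrative.
#
#   feats = calc.get_descriptors(atoms)                          # invariants, all layers
#   feats_eq = calc.get_descriptors(atoms, invariants_only=False)
#   feats_l1 = calc.get_descriptors(atoms, num_layers=1)         # first layer only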
def predict(self, atoms_list, compute_stress=False):
"""Batched single-model inference over a list of ase.Atoms; returns raw model outputs (no unit conversion is applied here)."""
predictions = {}
configs = [data.config_from_atoms(atoms, charges_key=self.charges_key) for atoms in atoms_list]
data_loader = torch_geometric.dataloader.DataLoader(
dataset=[
data.AtomicData.from_config(
config, z_table=self.z_table, cutoff=self.r_max, heads=self.available_heads
)
for config in configs
],
batch_size=len(atoms_list),
shuffle=False,
drop_last=False,
)
# The single batch holds every structure in atoms_list (batch_size=len(atoms_list))
batch_base = next(iter(data_loader)).to(self.device)
out = self.models[0](
batch_base.to_dict(),
compute_stress=compute_stress,  # TODO: decide whether stress is needed here
training=self.use_compile,
)
predictions["energy"] = out["energy"].unsqueeze(-1).detach()
predictions["forces"] = out["forces"].detach()
if compute_stress:
predictions["stress"] = out["stress"].detach()
return predictions
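# Usage sketch: one forward pass over several structures at once.
#
#   preds = calc.predict(list_of_atoms)
#   preds["energy"]   # shape (n_structures, 1), model units
#   preds["forces"]   # shape (total_n_atoms, 3), model units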
def fast_predict(self, gbatch, compute_stress=False):
"""Single-model inference on a pre-assembled torch_geometric batch, bypassing the ASE/AtomicData conversion; outputs are cast to float64."""
gbatch.pos = gbatch.pos.to(self.dtype)
gbatch.cell = gbatch.cell.to(self.dtype)
predictions = {}
batch_base = self.convert_batch(gbatch)
out = self.models[0](
batch_base.to_dict(),
compute_stress=compute_stress,
training=self.use_compile,
)
predictions["energy"] = out["energy"].unsqueeze(-1).detach().to(torch.float64)
predictions["forces"] = out["forces"].detach().to(torch.float64)
if compute_stress:
predictions["stress"] = out["stress"].detach().to(torch.float64)
return predictions
def convert_batch(self, gbatch):
"""Convert a torch_geometric batch into the Batch format expected by MACE."""
# Local import: batchopt is an optional dependency providing the CUDA periodic neighbor list.
from batchopt import radius_graph_pbc_cuda
edge_indices, cell_offsets, num_neighbors = radius_graph_pbc_cuda(
gbatch,
radius=self.r_max,  # neighbor cutoff must match the model cutoff
max_num_neighbors_threshold=float('inf'),
pbc=[True, True, True],
dtype=self.dtype
)
# Swap the two rows of edge_index to match the edge-direction convention
# expected by the MACE model.
edge_indices = edge_indices.flip(0)
# Create a one-hot matrix with one column per element in the calculator's z_table
indices = atomic_numbers_to_indices(gbatch["atomic_numbers"].to("cpu"), z_table=self.z_table)
one_hot = to_one_hot(
torch.tensor(indices, dtype=torch.long).unsqueeze(-1),
num_classes=len(self.z_table),
).to(self.device)
cbatch = Batch(
positions=gbatch["pos"].clone(),
cell=gbatch["cell"].view(-1, 3),
batch=gbatch["batch"],
ptr=gbatch["ptr"],
edge_index=edge_indices,
unit_shifts=cell_offsets,
node_attrs=one_hot,
)
return cbatch
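# Illustration (hypothetical z_table covering [1, 6, 8]): atomic number 8
# maps to table index 2, so its one-hot node_attrs row becomes [0., 0., 1.].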
def predict_debug(self, atoms_list, gbatch, compute_stress=False):
"""Debug variant of predict(): runs the model on a batch assembled by convert_batch, while also building the reference ASE batch for side-by-side comparison."""
predictions = {}
configs = [data.config_from_atoms(atoms, charges_key=self.charges_key) for atoms in atoms_list]
data_loader = torch_geometric.dataloader.DataLoader(
dataset=[
data.AtomicData.from_config(
config, z_table=self.z_table, cutoff=self.r_max, heads=self.available_heads
)
for config in configs
],
batch_size=len(atoms_list),
shuffle=False,
drop_last=False,
)
# Reference batch built through the standard ASE path (kept for comparison)
batch_base_tmp = next(iter(data_loader)).to(self.device)
batch2 = self.convert_batch(gbatch)
batch_base = Batch(
positions=batch2["positions"],
node_attrs=batch2["node_attrs"],
cell=batch2["cell"],
edge_index=batch2["edge_index"],
unit_shifts=batch2["unit_shifts"],
batch=batch2["batch"],
ptr=batch2["ptr"],
)
out = self.models[0](
batch_base.to_dict(),
compute_stress=compute_stress,  # TODO: decide whether stress is needed here
training=self.use_compile,
)
predictions["energy"] = out["energy"].unsqueeze(-1).detach()
predictions["forces"] = out["forces"].detach()
if compute_stress:
predictions["stress"] = out["stress"].detach()
return predictions
###########################################################################################
# The ASE Calculator for MACE
# Authors: Ilyes Batatia, David Kovacs
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
import logging
import os
import time
from glob import glob
from pathlib import Path
from typing import List, Union
os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1"
# pylint: disable=wrong-import-position
import numpy as np
import torch
from ase.calculators.calculator import Calculator, all_changes
from ase.stress import full_3x3_to_voigt_6_stress
from e3nn import o3
from mace import data
from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq
from mace.modules.utils import extract_invariant
from mace.tools import (
atomic_numbers_to_indices,
to_one_hot,
torch_geometric,
torch_tools,
utils,
)
from mace.tools.compile import prepare
from mace.tools.scripts_utils import extract_model
from mace.tools.torch_geometric.batch import Batch
def get_model_dtype(model: torch.nn.Module) -> str:
"""Return the model's parameter dtype as a string ("float64" or "float32")."""
model_dtype = next(model.parameters()).dtype
if model_dtype == torch.float64:
return "float64"
if model_dtype == torch.float32:
return "float32"
raise ValueError(f"Unknown dtype {model_dtype}")
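# Usage sketch: match the calculator's default dtype to a loaded model
# (the model path is hypothetical).
#
#   model = torch.load("mace.model", map_location="cpu")
#   calc = MACECalculator(models=model, default_dtype=get_model_dtype(model))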
class MACECalculator(Calculator):
"""MACE ASE Calculator
args:
model_paths: str, path to a model, or to several models forming a committee;
to make a committee, use a wildcard pattern like mace_*.model
device: str, device to run on (cuda or cpu)
energy_units_to_eV: float, conversion factor from model energy units to eV
length_units_to_A: float, conversion factor from model length units to Angstroms
default_dtype: str, default dtype of model
charges_key: str, Array field of atoms object where atomic charges are stored
model_type: str, type of model to load
Options: [MACE, DipoleMACE, EnergyDipoleMACE]
Dipoles are returned in units of Debye
"""
def __init__(
self,
model_paths: Union[list, str, None] = None,
models: Union[List[torch.nn.Module], torch.nn.Module, None] = None,
device: str = "cpu",
energy_units_to_eV: float = 1.0,
length_units_to_A: float = 1.0,
default_dtype="",
charges_key="Qs",
model_type="MACE",
compile_mode=None,
fullgraph=True,
enable_cueq=False,
**kwargs,
):
Calculator.__init__(self, **kwargs)
self.device = device
self.dtype = None
if enable_cueq:
assert model_type == "MACE", "CuEq only supports MACE models"
compile_mode = None
if "model_path" in kwargs:
deprecation_message = (
"'model_path' argument is deprecated, please use 'model_paths'"
)
if model_paths is None:
logging.warning(f"{deprecation_message} in the future.")
model_paths = kwargs["model_path"]
else:
raise ValueError(
f"both 'model_path' and 'model_paths' given, {deprecation_message} only."
)
if (model_paths is None) == (models is None):
raise ValueError(
"Exactly one of 'model_paths' or 'models' must be provided"
)
self.results = {}
self.model_type = model_type
self.compute_atomic_stresses = False
if model_type == "MACE":
self.implemented_properties = [
"energy",
"free_energy",
"node_energy",
"forces",
"stress",
]
if kwargs.get("compute_atomic_stresses", False):
self.implemented_properties.extend(["stresses", "virials"])
self.compute_atomic_stresses = True
elif model_type == "DipoleMACE":
self.implemented_properties = ["dipole"]
elif model_type == "EnergyDipoleMACE":
self.implemented_properties = [
"energy",
"free_energy",
"node_energy",
"forces",
"stress",
"dipole",
]
else:
raise ValueError(
f"Give a valid model_type: [MACE, DipoleMACE, EnergyDipoleMACE], {model_type} not supported"
)
if model_paths is not None:
if isinstance(model_paths, str):
# Find all models that satisfy the wildcard (e.g. mace_model_*.pt)
model_paths_glob = glob(model_paths)
if len(model_paths_glob) == 0:
raise ValueError(f"Couldn't find MACE model files: {model_paths}")
model_paths = model_paths_glob
elif isinstance(model_paths, Path):
model_paths = [model_paths]
if len(model_paths) == 0:
raise ValueError("No mace file names supplied")
self.num_models = len(model_paths)
# Load models from files
self.models = [
torch.load(f=model_path, map_location=device)
for model_path in model_paths
]
elif models is not None:
if not isinstance(models, list):
models = [models]
if len(models) == 0:
raise ValueError("No models supplied")
self.models = models
self.num_models = len(models)
if self.num_models > 1:
print(f"Running committee mace with {self.num_models} models")
if model_type in ["MACE", "EnergyDipoleMACE"]:
self.implemented_properties.extend(
["energies", "energy_var", "forces_comm", "stress_var"]
)
elif model_type == "DipoleMACE":
self.implemented_properties.extend(["dipole_var"])
if compile_mode is not None:
print(f"Torch compile is enabled with mode: {compile_mode}")
self.models = [
torch.compile(
prepare(extract_model)(model=model, map_location=device),
mode=compile_mode,
fullgraph=fullgraph,
)
for model in self.models
]
self.use_compile = True
else:
self.use_compile = False
# Ensure all models are on the same device
for model in self.models:
model.to(device)
r_maxs = [model.r_max.cpu() for model in self.models]
r_maxs = np.array(r_maxs)
if not np.all(r_maxs == r_maxs[0]):
raise ValueError(f"committee r_max are not all the same {' '.join(r_maxs)}")
self.r_max = float(r_maxs[0])
self.device = torch_tools.init_device(device)
self.energy_units_to_eV = energy_units_to_eV
self.length_units_to_A = length_units_to_A
self.z_table = utils.AtomicNumberTable(
[int(z) for z in self.models[0].atomic_numbers]
)
self.charges_key = charges_key
try:
self.available_heads: List[str] = self.models[0].heads # type: ignore
except AttributeError:
self.available_heads = ["Default"]
kwarg_head = kwargs.get("head", None)
if kwarg_head is not None:
self.head = kwarg_head
else:
self.head = [head for head in self.available_heads if head.lower() == "default"]
if len(self.head) == 0:
raise ValueError(
"Head keyword was not provided, and no head in the model is 'default'. "
"Please provide a head keyword to specify the head you want to use. "
f"Available heads are: {self.available_heads}"
)
self.head = self.head[0]
print("Using head", self.head, "out of", self.available_heads)
model_dtype = get_model_dtype(self.models[0])
if default_dtype == "":
print(
f"No dtype selected, switching to {model_dtype} to match model dtype."
)
default_dtype = model_dtype
if model_dtype != default_dtype:
print(
f"Default dtype {default_dtype} does not match model dtype {model_dtype}, converting models to {default_dtype}."
)
if default_dtype == "float64":
self.models = [model.double() for model in self.models]
elif default_dtype == "float32":
self.models = [model.float() for model in self.models]
torch_tools.set_default_dtype(default_dtype)
if enable_cueq:
print("Converting models to CuEq for acceleration")
self.models = [
run_e3nn_to_cueq(model, device=device).to(device)
for model in self.models
]
for model in self.models:
for param in model.parameters():
param.requires_grad = False
self.dtype = torch.float64 if default_dtype == "float64" else torch.float32
self.model_time = 0.0
self.calc_time = 0.0
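# Construction sketches (paths hypothetical):
#
#   calc = MACECalculator(model_paths="/path/to/mace.model", device="cuda")
#   calc = MACECalculator(model_paths="committee_*.model", device="cuda")  # committee
#   calc = MACECalculator(models=loaded_model, device="cpu")  # pre-loaded torch.nn.Module
#   calc = MACECalculator(model_paths="mace.model", enable_cueq=True)  # CuEq acceleration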
def _create_result_tensors(
self, model_type: str, num_models: int, num_atoms: int
) -> dict:
"""
Create tensors to store the results of the committee
:param model_type: str, type of model to load
Options: [MACE, DipoleMACE, EnergyDipoleMACE]
:param num_models: int, number of models in the committee
:return: tuple of torch tensors
"""
dict_of_tensors = {}
if model_type in ["MACE", "EnergyDipoleMACE"]:
energies = torch.zeros(num_models, device=self.device)
node_energy = torch.zeros(num_models, num_atoms, device=self.device)
forces = torch.zeros(num_models, num_atoms, 3, device=self.device)
stress = torch.zeros(num_models, 3, 3, device=self.device)
dict_of_tensors.update(
{
"energies": energies,
"node_energy": node_energy,
"forces": forces,
"stress": stress,
}
)
if model_type in ["EnergyDipoleMACE", "DipoleMACE"]:
dipole = torch.zeros(num_models, 3, device=self.device)
dict_of_tensors.update({"dipole": dipole})
return dict_of_tensors
def _atoms_to_batch(self, atoms):
keyspec = data.KeySpecification(
info_keys={}, arrays_keys={"charges": self.charges_key}
)
config = data.config_from_atoms(
atoms, key_specification=keyspec, head_name=self.head
)
data_loader = torch_geometric.dataloader.DataLoader(
dataset=[
data.AtomicData.from_config(
config,
z_table=self.z_table,
cutoff=self.r_max,
heads=self.available_heads,
)
],
batch_size=1,
shuffle=False,
drop_last=False,
)
batch = next(iter(data_loader)).to(self.device)
return batch
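# Sketch of how the batch is consumed (mirrors the calculate() loop below):
#
#   batch = calc._atoms_to_batch(atoms)
#   out = calc.models[0](calc._clone_batch(batch).to_dict(), compute_stress=True)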
def _clone_batch(self, batch):
batch_clone = batch.clone()
if self.use_compile:
batch_clone["node_attrs"].requires_grad_(True)
batch_clone["positions"].requires_grad_(True)
return batch_clone
# pylint: disable=dangerous-default-value
def calculate(self, atoms=None, properties=None, system_changes=all_changes):
"""
Calculate properties.
:param atoms: ase.Atoms object
:param properties: [str], properties to be computed, used by ASE internally
:param system_changes: [str], system changes since last calculation, used by ASE internally
:return:
"""
# call to base-class to set atoms attribute
calc_start_t = time.perf_counter()
Calculator.calculate(self, atoms)
batch_base = self._atoms_to_batch(atoms)
if self.model_type in ["MACE", "EnergyDipoleMACE"]:
batch = self._clone_batch(batch_base)
node_heads = batch["head"][batch["batch"]]
num_atoms_arange = torch.arange(batch["positions"].shape[0])
node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"])[
num_atoms_arange, node_heads
]
compute_stress = not self.use_compile
else:
compute_stress = False
ret_tensors = self._create_result_tensors(
self.model_type, self.num_models, len(atoms)
)
for i, model in enumerate(self.models):
batch = self._clone_batch(batch_base)
model_start_t = time.perf_counter()
out = model(
batch.to_dict(),
compute_stress=compute_stress,
training=self.use_compile,
compute_edge_forces=self.compute_atomic_stresses,
compute_atomic_stresses=self.compute_atomic_stresses,
)
model_end_t = time.perf_counter()
self.model_time += (model_end_t - model_start_t)
if self.model_type in ["MACE", "EnergyDipoleMACE"]:
ret_tensors["energies"][i] = out["energy"].detach()
ret_tensors["node_energy"][i] = (out["node_energy"] - node_e0).detach()
ret_tensors["forces"][i] = out["forces"].detach()
if out["stress"] is not None:
ret_tensors["stress"][i] = out["stress"].detach()
if self.model_type in ["DipoleMACE", "EnergyDipoleMACE"]:
ret_tensors["dipole"][i] = out["dipole"].detach()
if self.model_type in ["MACE"]:
if out["atomic_stresses"] is not None:
ret_tensors.setdefault("atomic_stresses", []).append(
out["atomic_stresses"].detach()
)
if out["atomic_virials"] is not None:
ret_tensors.setdefault("atomic_virials", []).append(
out["atomic_virials"].detach()
)
self.results = {}
if self.model_type in ["MACE", "EnergyDipoleMACE"]:
self.results["energy"] = (
torch.mean(ret_tensors["energies"], dim=0).cpu().item()
* self.energy_units_to_eV
)
self.results["free_energy"] = self.results["energy"]
self.results["node_energy"] = (
torch.mean(ret_tensors["node_energy"], dim=0).cpu().numpy()
)
self.results["forces"] = (
torch.mean(ret_tensors["forces"], dim=0).cpu().numpy()
* self.energy_units_to_eV
/ self.length_units_to_A
)
if self.num_models > 1:
self.results["energies"] = (
ret_tensors["energies"].cpu().numpy() * self.energy_units_to_eV
)
self.results["energy_var"] = (
torch.var(ret_tensors["energies"], dim=0, unbiased=False)
.cpu()
.item()
* self.energy_units_to_eV
)
self.results["forces_comm"] = (
ret_tensors["forces"].cpu().numpy()
* self.energy_units_to_eV
/ self.length_units_to_A
)
if out["stress"] is not None:
self.results["stress"] = full_3x3_to_voigt_6_stress(
torch.mean(ret_tensors["stress"], dim=0).cpu().numpy()
* self.energy_units_to_eV
/ self.length_units_to_A**3
)
if self.num_models > 1:
self.results["stress_var"] = full_3x3_to_voigt_6_stress(
torch.var(ret_tensors["stress"], dim=0, unbiased=False)
.cpu()
.numpy()
* self.energy_units_to_eV
/ self.length_units_to_A**3
)
if "atomic_stresses" in ret_tensors:
self.results["stresses"] = (
torch.mean(torch.stack(ret_tensors["atomic_stresses"]), dim=0)
.cpu()
.numpy()
* self.energy_units_to_eV
/ self.length_units_to_A**3
)
if "atomic_virials" in ret_tensors:
self.results["virials"] = (
torch.mean(torch.stack(ret_tensors["atomic_virials"]), dim=0)
.cpu()
.numpy()
* self.energy_units_to_eV
)
if self.model_type in ["DipoleMACE", "EnergyDipoleMACE"]:
self.results["dipole"] = (
torch.mean(ret_tensors["dipole"], dim=0).cpu().numpy()
)
if self.num_models > 1:
self.results["dipole_var"] = (
torch.var(ret_tensors["dipole"], dim=0, unbiased=False)
.cpu()
.numpy()
)
calc_end_t = time.perf_counter()
self.calc_time += (calc_end_t - calc_start_t)
def get_hessian(self, atoms=None):
if atoms is None and self.atoms is None:
raise ValueError("atoms not set")
if atoms is None:
atoms = self.atoms
if self.model_type != "MACE":
raise NotImplementedError("Only implemented for MACE models")
batch = self._atoms_to_batch(atoms)
hessians = [
model(
self._clone_batch(batch).to_dict(),
compute_hessian=True,
compute_stress=False,
training=self.use_compile,
)["hessian"]
for model in self.models
]
hessians = [hessian.detach().cpu().numpy() for hessian in hessians]
if self.num_models == 1:
return hessians[0]
return hessians
def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1):
"""Extracts the descriptors from MACE model.
:param atoms: ase.Atoms object
:param invariants_only: bool, if True only the invariant descriptors are returned
:param num_layers: int, number of layers to extract descriptors from, if -1 all layers are used
:return: np.ndarray of shape (num_atoms, num_features) with the kept layers' descriptors concatenated along the last axis if num_models is 1, list[np.ndarray] otherwise
"""
if atoms is None and self.atoms is None:
raise ValueError("atoms not set")
if atoms is None:
atoms = self.atoms
if self.model_type != "MACE":
raise NotImplementedError("Only implemented for MACE models")
num_interactions = int(self.models[0].num_interactions)
if num_layers == -1:
num_layers = num_interactions
batch = self._atoms_to_batch(atoms)
descriptors = [model(batch.to_dict())["node_feats"] for model in self.models]
irreps_out = o3.Irreps(str(self.models[0].products[0].linear.irreps_out))
l_max = irreps_out.lmax
num_invariant_features = irreps_out.dim // (l_max + 1) ** 2
per_layer_features = [irreps_out.dim for _ in range(num_interactions)]
per_layer_features[-1] = (
num_invariant_features # Equivariant features not created for the last layer
)
if invariants_only:
descriptors = [
extract_invariant(
descriptor,
num_layers=num_layers,
num_features=num_invariant_features,
l_max=l_max,
)
for descriptor in descriptors
]
to_keep = np.sum(per_layer_features[:num_layers])
descriptors = [
descriptor[:, :to_keep].detach().cpu().numpy() for descriptor in descriptors
]
if self.num_models == 1:
return descriptors[0]
return descriptors
def predict(self, atoms_list, compute_stress=False):
"""Batched single-model inference over a list of ase.Atoms; returns raw model outputs (no unit conversion is applied here)."""
predictions = {}
configs = [data.config_from_atoms(atoms, charges_key=self.charges_key) for atoms in atoms_list]
data_loader = torch_geometric.dataloader.DataLoader(
dataset=[
data.AtomicData.from_config(
config, z_table=self.z_table, cutoff=self.r_max, heads=self.available_heads
)
for config in configs
],
batch_size=len(atoms_list),
shuffle=False,
drop_last=False,
)
# The single batch holds every structure in atoms_list (batch_size=len(atoms_list))
batch_base = next(iter(data_loader)).to(self.device)
out = self.models[0](
batch_base.to_dict(),
compute_stress=compute_stress,  # TODO: decide whether stress is needed here
training=self.use_compile,
)
predictions["energy"] = out["energy"].unsqueeze(-1).detach()
predictions["forces"] = out["forces"].detach()
if compute_stress:
predictions["stress"] = out["stress"].detach()
return predictions
def fast_predict(self, gbatch, compute_stress=False):
"""Single-model inference on a pre-assembled torch_geometric batch, bypassing the ASE/AtomicData conversion; outputs are cast to float64."""
gbatch.pos = gbatch.pos.to(self.dtype)
gbatch.cell = gbatch.cell.to(self.dtype)
predictions = {}
batch_base = self.convert_batch(gbatch)
out = self.models[0](
batch_base.to_dict(),
compute_stress=compute_stress,
training=self.use_compile,
)
predictions["energy"] = out["energy"].unsqueeze(-1).detach().to(torch.float64)
predictions["forces"] = out["forces"].detach().to(torch.float64)
if compute_stress:
predictions["stress"] = out["stress"].detach().to(torch.float64)
return predictions
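# Usage sketch (assumes a torch_geometric batch carrying pos, cell,
# atomic_numbers, batch and ptr fields, and the optional batchopt package
# for the CUDA neighbor list):
#
#   preds = calc.fast_predict(gbatch)
#   preds["energy"], preds["forces"]   # float64 tensors, model units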
def convert_batch(self, gbatch):
"""Convert a torch_geometric batch into the Batch format expected by MACE."""
# Local import: batchopt is an optional dependency providing the CUDA periodic neighbor list.
from batchopt import radius_graph_pbc_cuda
edge_indices, cell_offsets, num_neighbors = radius_graph_pbc_cuda(
gbatch,
radius=self.r_max,  # neighbor cutoff must match the model cutoff
max_num_neighbors_threshold=float('inf'),
pbc=[True, True, True],
dtype=self.dtype
)
# Swap the two rows of edge_index to match the edge-direction convention
# expected by the MACE model.
edge_indices = edge_indices.flip(0)
# Create a one-hot matrix with one column per element in the calculator's z_table
indices = atomic_numbers_to_indices(gbatch["atomic_numbers"].to("cpu"), z_table=self.z_table)
one_hot = to_one_hot(
torch.tensor(indices, dtype=torch.long).unsqueeze(-1),
num_classes=len(self.z_table),
).to(self.device)
cbatch = Batch(
positions=gbatch["pos"].clone(),
cell=gbatch["cell"].view(-1, 3),
batch=gbatch["batch"],
ptr=gbatch["ptr"],
edge_index=edge_indices,
unit_shifts=cell_offsets,
node_attrs=one_hot,
)
return cbatch
def predict_debug(self, atoms_list, gbatch, compute_stress=False):
"""Debug variant of predict(): runs the model on a batch assembled by convert_batch, while also building the reference ASE batch for side-by-side comparison."""
predictions = {}
configs = [data.config_from_atoms(atoms, charges_key=self.charges_key) for atoms in atoms_list]
data_loader = torch_geometric.dataloader.DataLoader(
dataset=[
data.AtomicData.from_config(
config, z_table=self.z_table, cutoff=self.r_max, heads=self.available_heads
)
for config in configs
],
batch_size=len(atoms_list),
shuffle=False,
drop_last=False,
)
# Reference batch built through the standard ASE path (kept for comparison)
batch_base_tmp = next(iter(data_loader)).to(self.device)
batch2 = self.convert_batch(gbatch)
batch_base = Batch(
positions=batch2["positions"],
node_attrs=batch2["node_attrs"],
cell=batch2["cell"],
edge_index=batch2["edge_index"],
unit_shifts=batch2["unit_shifts"],
batch=batch2["batch"],
ptr=batch2["ptr"],
)
out = self.models[0](
batch_base.to_dict(),
compute_stress=compute_stress,  # TODO: decide whether stress is needed here
training=self.use_compile,
)
predictions["energy"] = out["energy"].unsqueeze(-1).detach()
predictions["forces"] = out["forces"].detach()
if compute_stress:
predictions["stress"] = out["stress"].detach()
return predictions
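# Usage sketch for the timing counters accumulated in calculate():
#
#   calc.model_time = calc.calc_time = 0.0
#   atoms.get_potential_energy()
#   print(f"model forward: {calc.model_time:.3f} s, calculate(): {calc.calc_time:.3f} s")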