Unverified Commit d47e8ba0 authored by zcxzcx1's avatar zcxzcx1 Committed by GitHub
Browse files

Delete data directory

parent 92733cc7
from .atomic_data import AtomicData
from .hdf5_dataset import HDF5Dataset, dataset_from_sharded_hdf5
from .lmdb_dataset import LMDBDataset
from .neighborhood import get_neighborhood
from .utils import (
Configuration,
Configurations,
KeySpecification,
compute_average_E0s,
config_from_atoms,
config_from_atoms_list,
load_from_xyz,
random_train_valid_split,
save_AtomicData_to_HDF5,
save_configurations_as_HDF5,
save_dataset_as_HDF5,
test_config_types,
update_keyspec_from_kwargs,
)
# Public API of the mace.data package.
__all__ = [
    "get_neighborhood",
    "Configuration",
    "Configurations",
    "random_train_valid_split",
    "load_from_xyz",
    "test_config_types",
    "config_from_atoms",
    "config_from_atoms_list",
    "AtomicData",
    "compute_average_E0s",
    "save_dataset_as_HDF5",
    "HDF5Dataset",
    "dataset_from_sharded_hdf5",
    "save_AtomicData_to_HDF5",
    "save_configurations_as_HDF5",
    "KeySpecification",
    "update_keyspec_from_kwargs",
    "LMDBDataset",
]
###########################################################################################
# Atomic Data Class for handling molecules as graphs
# Authors: Ilyes Batatia, Gregor Simm
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
from copy import deepcopy
from typing import Optional, Sequence
import torch.utils.data
from mace.tools import (
AtomicNumberTable,
atomic_numbers_to_indices,
to_one_hot,
torch_geometric,
voigt_to_matrix,
)
from .neighborhood import get_neighborhood
from .utils import Configuration
class AtomicData(torch_geometric.data.Data):
    """A single atomic configuration stored as a torch_geometric graph.

    Nodes are atoms (one-hot element encoding in ``node_attrs``); edges connect
    neighbor pairs within the cutoff, with ``shifts``/``unit_shifts`` carrying
    the periodic-image offsets of each edge.
    """

    num_graphs: torch.Tensor
    batch: torch.Tensor
    edge_index: torch.Tensor
    node_attrs: torch.Tensor
    edge_vectors: torch.Tensor
    edge_lengths: torch.Tensor
    positions: torch.Tensor
    shifts: torch.Tensor
    unit_shifts: torch.Tensor
    cell: torch.Tensor
    forces: torch.Tensor
    energy: torch.Tensor
    stress: torch.Tensor
    virials: torch.Tensor
    dipole: torch.Tensor
    charges: torch.Tensor
    weight: torch.Tensor
    energy_weight: torch.Tensor
    forces_weight: torch.Tensor
    stress_weight: torch.Tensor
    virials_weight: torch.Tensor
    dipole_weight: torch.Tensor
    charges_weight: torch.Tensor
    head: torch.Tensor  # index into the list of heads (annotation was missing)

    def __init__(
        self,
        edge_index: torch.Tensor,  # [2, n_edges]
        node_attrs: torch.Tensor,  # [n_nodes, n_node_feats]
        positions: torch.Tensor,  # [n_nodes, 3]
        shifts: torch.Tensor,  # [n_edges, 3],
        unit_shifts: torch.Tensor,  # [n_edges, 3]
        cell: Optional[torch.Tensor],  # [3,3]
        weight: Optional[torch.Tensor],  # [,]
        head: Optional[torch.Tensor],  # [,]
        energy_weight: Optional[torch.Tensor],  # [,]
        forces_weight: Optional[torch.Tensor],  # [,]
        stress_weight: Optional[torch.Tensor],  # [,]
        virials_weight: Optional[torch.Tensor],  # [,]
        dipole_weight: Optional[torch.Tensor],  # [,]
        charges_weight: Optional[torch.Tensor],  # [,]
        forces: Optional[torch.Tensor],  # [n_nodes, 3]
        energy: Optional[torch.Tensor],  # [, ]
        stress: Optional[torch.Tensor],  # [1,3,3]
        virials: Optional[torch.Tensor],  # [1,3,3]
        dipole: Optional[torch.Tensor],  # [, 3]
        charges: Optional[torch.Tensor],  # [n_nodes, ]
    ):
        """Validate tensor shapes and forward everything to the Data base class."""
        # Check shapes
        num_nodes = node_attrs.shape[0]
        assert edge_index.shape[0] == 2 and len(edge_index.shape) == 2
        assert positions.shape == (num_nodes, 3)
        assert shifts.shape[1] == 3
        assert unit_shifts.shape[1] == 3
        assert len(node_attrs.shape) == 2
        assert weight is None or len(weight.shape) == 0
        assert head is None or len(head.shape) == 0
        assert energy_weight is None or len(energy_weight.shape) == 0
        assert forces_weight is None or len(forces_weight.shape) == 0
        assert stress_weight is None or len(stress_weight.shape) == 0
        assert virials_weight is None or len(virials_weight.shape) == 0
        assert dipole_weight is None or dipole_weight.shape == (1, 3), dipole_weight
        assert charges_weight is None or len(charges_weight.shape) == 0
        assert cell is None or cell.shape == (3, 3)
        assert forces is None or forces.shape == (num_nodes, 3)
        assert energy is None or len(energy.shape) == 0
        assert stress is None or stress.shape == (1, 3, 3)
        assert virials is None or virials.shape == (1, 3, 3)
        assert dipole is None or dipole.shape[-1] == 3
        assert charges is None or charges.shape == (num_nodes,)
        # Aggregate data
        data = {
            "num_nodes": num_nodes,
            "edge_index": edge_index,
            "positions": positions,
            "shifts": shifts,
            "unit_shifts": unit_shifts,
            "cell": cell,
            "node_attrs": node_attrs,
            "weight": weight,
            "head": head,
            "energy_weight": energy_weight,
            "forces_weight": forces_weight,
            "stress_weight": stress_weight,
            "virials_weight": virials_weight,
            "dipole_weight": dipole_weight,
            "charges_weight": charges_weight,
            "forces": forces,
            "energy": energy,
            "stress": stress,
            "virials": virials,
            "dipole": dipole,
            "charges": charges,
        }
        super().__init__(**data)

    @classmethod
    def from_config(
        cls,
        config: Configuration,
        z_table: AtomicNumberTable,
        cutoff: float,
        heads: Optional[list] = None,
        **kwargs,  # pylint: disable=unused-argument
    ) -> "AtomicData":
        """Build an AtomicData graph from a Configuration.

        Missing properties fall back to zeros; missing per-property weights
        fall back to 1.0.  ``config.head`` is mapped to its index in ``heads``,
        defaulting to the last head when not found.
        """
        if heads is None:
            heads = ["Default"]
        dtype = torch.get_default_dtype()
        edge_index, shifts, unit_shifts, cell = get_neighborhood(
            positions=config.positions,
            cutoff=cutoff,
            pbc=deepcopy(config.pbc),
            cell=deepcopy(config.cell),
        )
        indices = atomic_numbers_to_indices(config.atomic_numbers, z_table=z_table)
        one_hot = to_one_hot(
            torch.tensor(indices, dtype=torch.long).unsqueeze(-1),
            num_classes=len(z_table),
        )
        try:
            head = torch.tensor(heads.index(config.head), dtype=torch.long)
        except ValueError:
            # Unknown head name: fall back to the last head.
            head = torch.tensor(len(heads) - 1, dtype=torch.long)
        cell = (
            torch.tensor(cell, dtype=dtype)
            if cell is not None
            else torch.zeros(3, 3, dtype=dtype)
        )
        num_atoms = len(config.atomic_numbers)

        def scalar_weight(name: str) -> torch.Tensor:
            # Per-property loss weight; defaults to 1.0 when not given.
            value = config.property_weights.get(name)
            if value is None:
                return torch.tensor(1.0, dtype=dtype)
            return torch.tensor(value, dtype=dtype)

        def property_or(name: str, default: torch.Tensor) -> torch.Tensor:
            # Reference property as tensor; the given default when absent.
            value = config.properties.get(name)
            if value is None:
                return default
            return torch.tensor(value, dtype=dtype)

        weight = (
            torch.tensor(config.weight, dtype=dtype)
            if config.weight is not None
            else torch.tensor(1.0, dtype=dtype)
        )
        energy_weight = scalar_weight("energy")
        forces_weight = scalar_weight("forces")
        stress_weight = scalar_weight("stress")
        virials_weight = scalar_weight("virials")
        charges_weight = scalar_weight("charges")

        dipole_weight_value = config.property_weights.get("dipole")
        dipole_weight = (
            torch.tensor(dipole_weight_value, dtype=dtype)
            if dipole_weight_value is not None
            else torch.tensor([[1.0, 1.0, 1.0]], dtype=dtype)
        )
        # Normalize dipole weight to shape [1, 3], as asserted in __init__.
        if len(dipole_weight.shape) == 0:
            dipole_weight = dipole_weight * torch.tensor(
                [[1.0, 1.0, 1.0]], dtype=dtype
            )
        elif len(dipole_weight.shape) == 1:
            dipole_weight = dipole_weight.unsqueeze(0)

        forces = property_or("forces", torch.zeros(num_atoms, 3, dtype=dtype))
        energy = property_or("energy", torch.tensor(0.0, dtype=dtype))
        # Stress/virials may be stored in Voigt form; expand to a 3x3 matrix.
        stress_value = config.properties.get("stress")
        stress = (
            voigt_to_matrix(torch.tensor(stress_value, dtype=dtype)).unsqueeze(0)
            if stress_value is not None
            else torch.zeros(1, 3, 3, dtype=dtype)
        )
        virials_value = config.properties.get("virials")
        virials = (
            voigt_to_matrix(torch.tensor(virials_value, dtype=dtype)).unsqueeze(0)
            if virials_value is not None
            else torch.zeros(1, 3, 3, dtype=dtype)
        )
        dipole_value = config.properties.get("dipole")
        dipole = (
            torch.tensor(dipole_value, dtype=dtype).unsqueeze(0)
            if dipole_value is not None
            else torch.zeros(1, 3, dtype=dtype)
        )
        charges = property_or("charges", torch.zeros(num_atoms, dtype=dtype))

        return cls(
            edge_index=torch.tensor(edge_index, dtype=torch.long),
            positions=torch.tensor(config.positions, dtype=dtype),
            shifts=torch.tensor(shifts, dtype=dtype),
            unit_shifts=torch.tensor(unit_shifts, dtype=dtype),
            cell=cell,
            node_attrs=one_hot,
            weight=weight,
            head=head,
            energy_weight=energy_weight,
            forces_weight=forces_weight,
            stress_weight=stress_weight,
            virials_weight=virials_weight,
            dipole_weight=dipole_weight,
            charges_weight=charges_weight,
            forces=forces,
            energy=energy,
            stress=stress,
            virials=virials,
            dipole=dipole,
            charges=charges,
        )
def get_data_loader(
    dataset: Sequence[AtomicData],
    batch_size: int,
    shuffle=True,
    drop_last=False,
) -> torch.utils.data.DataLoader:
    """Wrap a sequence of AtomicData objects in a torch_geometric DataLoader."""
    loader_options = dict(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
    )
    return torch_geometric.dataloader.DataLoader(**loader_options)
from glob import glob
from typing import List
import h5py
from torch.utils.data import ConcatDataset, Dataset
from mace.data.atomic_data import AtomicData
from mace.data.utils import Configuration
from mace.tools.utils import AtomicNumberTable
class HDF5Dataset(Dataset):
    """Map-style dataset over configurations stored in one batched HDF5 file.

    The file is expected to contain groups ``config_batch_<b>`` each holding
    ``config_<c>`` subgroups; items are converted lazily to AtomicData.
    """

    def __init__(
        self, file_path, r_max, z_table, atomic_dataclass=AtomicData, **kwargs
    ):
        super().__init__()
        self.file_path = file_path
        self._file = None
        first_batch_key = next(iter(self.file.keys()))
        self.batch_size = len(self.file[first_batch_key].keys())
        self.length = len(self.file.keys()) * self.batch_size
        self.r_max = r_max
        self.z_table = z_table
        self.atomic_dataclass = atomic_dataclass
        try:
            self.drop_last = bool(self.file.attrs["drop_last"])
        except KeyError:
            self.drop_last = False
        self.kwargs = kwargs

    @property
    def file(self):
        # Lazily open the HDF5 file on first access (also after unpickling).
        if self._file is None:
            self._file = h5py.File(self.file_path, "r")
        return self._file

    def __getstate__(self):
        state = dict(self.__dict__)
        # An opened h5py.File cannot be pickled; drop the handle from the state.
        state["_file"] = None
        return state

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # Map the flat index onto (batch group, config subgroup).
        batch_index, config_index = divmod(index, self.batch_size)
        grp = self.file["config_batch_" + str(batch_index)]
        subgrp = grp["config_" + str(config_index)]
        properties = {
            key: unpack_value(subgrp["properties"][key][()])
            for key in subgrp["properties"]
        }
        property_weights = {
            key: unpack_value(subgrp["property_weights"][key][()])
            for key in subgrp["property_weights"]
        }
        config = Configuration(
            atomic_numbers=subgrp["atomic_numbers"][()],
            positions=subgrp["positions"][()],
            properties=properties,
            weight=unpack_value(subgrp["weight"][()]),
            property_weights=property_weights,
            config_type=unpack_value(subgrp["config_type"][()]),
            pbc=unpack_value(subgrp["pbc"][()]),
            cell=unpack_value(subgrp["cell"][()]),
        )
        if config.head is None:
            config.head = self.kwargs.get("head")
        extra_kwargs = {k: v for k, v in self.kwargs.items() if k != "heads"}
        return self.atomic_dataclass.from_config(
            config,
            z_table=self.z_table,
            cutoff=self.r_max,
            heads=self.kwargs.get("heads", ["Default"]),
            **extra_kwargs,
        )
def dataset_from_sharded_hdf5(
    files: str, z_table: AtomicNumberTable, r_max: float, **kwargs
):
    """Build one ConcatDataset from all HDF5 shards under the directory *files*.

    Note: despite the old ``List`` annotation, *files* is a directory path
    string — it is globbed with ``files + "/*"`` to discover the shards.
    """
    shard_paths = glob(files + "/*")
    shard_datasets = [
        HDF5Dataset(path, z_table=z_table, r_max=r_max, **kwargs)
        for path in shard_paths
    ]
    return ConcatDataset(shard_datasets)
def unpack_value(value):
    """Decode bytes to str and map the sentinel string "None" back to None."""
    if isinstance(value, bytes):
        value = value.decode("utf-8")
    if str(value) == "None":
        return None
    return value
import os
import numpy as np
from torch.utils.data import Dataset
from mace.data.atomic_data import AtomicData
from mace.data.utils import KeySpecification, config_from_atoms
from mace.tools.default_keys import DefaultKeys
from mace.tools.fairchem_dataset import AseDBDataset
class LMDBDataset(Dataset):
    """Map-style dataset over one or more ASE LMDB databases.

    *file_path* may contain several database paths separated by ':'.
    Items are converted to AtomicData; on a read/conversion error the item
    resolves to None instead of raising.
    """

    def __init__(self, file_path, r_max, z_table, **kwargs):
        super().__init__()
        dataset_paths = file_path.split(":")  # using : split multiple paths
        # make sure each of the paths exists
        for path in dataset_paths:
            assert os.path.exists(path), f"Path does not exist: {path}"
        config_kwargs = {}
        self.AseDB = AseDBDataset(config=dict(src=dataset_paths, **config_kwargs))
        self.r_max = r_max
        self.z_table = z_table
        self.kwargs = kwargs
        self.transform = kwargs.get("transform")

    def __len__(self):
        return len(self.AseDB)

    def __getitem__(self, index):
        try:
            atoms = self.AseDB.get_atoms(self.AseDB.ids[index])
        except Exception as e:  # pylint: disable=broad-except
            print(f"Error in index {index}")
            print(e)
            return None
        assert np.sum(atoms.get_cell() == atoms.cell) == 9
        # Copy calculator results into the default info/arrays keys so that
        # config_from_atoms picks them up.
        if hasattr(atoms, "calc") and hasattr(atoms.calc, "results"):
            results = atoms.calc.results
            if "energy" in results:
                atoms.info[DefaultKeys.ENERGY.value] = results["energy"]
            if "forces" in results:
                atoms.arrays[DefaultKeys.FORCES.value] = results["forces"]
            if "stress" in results:
                atoms.info[DefaultKeys.STRESS.value] = results["stress"]
        config = config_from_atoms(
            atoms,
            key_specification=KeySpecification.from_defaults(),
        )
        # Set head if not already set
        if config.head == "Default":
            config.head = self.kwargs.get("head", "Default")
        try:
            atomic_data = AtomicData.from_config(
                config,
                z_table=self.z_table,
                cutoff=self.r_max,
                heads=self.kwargs.get("heads", ["Default"]),
            )
        except Exception as e:  # pylint: disable=broad-except
            print(f"Error in index {index}")
            print(e)
            # BUG FIX: previously fell through with `atomic_data` unbound,
            # raising NameError below; return None like the read-error path.
            return None
        if self.transform:
            atomic_data = self.transform(atomic_data)
        return atomic_data
from typing import Optional, Tuple
import numpy as np
from matscipy.neighbours import neighbour_list
def get_neighborhood(
    positions: np.ndarray,  # [num_positions, 3]
    cutoff: float,
    pbc: Optional[Tuple[bool, bool, bool]] = None,
    cell: Optional[np.ndarray] = None,  # [3, 3]
    true_self_interaction=False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Compute the neighbor list within *cutoff*.

    Returns:
        edge_index [2, n_edges], shifts [n_edges, 3] (Cartesian),
        unit_shifts [n_edges, 3] (integer periodic-image offsets), and the
        (possibly extended) cell actually used for the search.
    """
    if pbc is None:
        pbc = (False, False, False)
    if cell is None or not cell.any():
        # No cell, or an all-zero cell: use the identity as a placeholder.
        cell = np.identity(3, dtype=float)
    else:
        # BUG FIX: work on a copy so the caller's cell array is never
        # mutated by the non-periodic extension below.
        cell = np.copy(cell)
    assert len(pbc) == 3 and all(isinstance(i, (bool, np.bool_)) for i in pbc)
    assert cell.shape == (3, 3)
    pbc_x, pbc_y, pbc_z = pbc
    identity = np.identity(3, dtype=float)
    max_positions = np.max(np.absolute(positions)) + 1
    # Extend cell in non-periodic directions
    # For models with more than 5 layers, the multiplicative constant needs to be increased.
    if not pbc_x:
        cell[0, :] = max_positions * 5 * cutoff * identity[0, :]
    if not pbc_y:
        cell[1, :] = max_positions * 5 * cutoff * identity[1, :]
    if not pbc_z:
        cell[2, :] = max_positions * 5 * cutoff * identity[2, :]
    sender, receiver, unit_shifts = neighbour_list(
        quantities="ijS",
        pbc=pbc,
        cell=cell,
        positions=positions,
        cutoff=cutoff,
        # self_interaction=True, # we want edges from atom to itself in different periodic images
        # use_scaled_positions=False, # positions are not scaled positions
    )
    if not true_self_interaction:
        # Eliminate self-edges that don't cross periodic boundaries
        true_self_edge = sender == receiver
        true_self_edge &= np.all(unit_shifts == 0, axis=1)
        keep_edge = ~true_self_edge
        # Note: after eliminating self-edges, it can be that no edges remain in this system
        sender = sender[keep_edge]
        receiver = receiver[keep_edge]
        unit_shifts = unit_shifts[keep_edge]
    # Build output
    edge_index = np.stack((sender, receiver))  # [2, n_edges]
    # From the docs: With the shift vector S, the distances D between atoms can be computed from
    # D = positions[j]-positions[i]+S.dot(cell)
    shifts = np.dot(unit_shifts, cell)  # [n_edges, 3]
    return edge_index, shifts, unit_shifts, cell
###########################################################################################
# Data parsing utilities
# Authors: Ilyes Batatia, Gregor Simm and David Kovacs
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence, Tuple
import ase.data
import ase.io
import h5py
import numpy as np
from mace.tools import AtomicNumberTable, DefaultKeys
# Type aliases for the array shapes used throughout this module.
Positions = np.ndarray  # [..., 3]
Cell = np.ndarray  # [3,3]
Pbc = tuple  # (3,)

# Fallback config_type label and its loss weight when none is specified.
DEFAULT_CONFIG_TYPE = "Default"
DEFAULT_CONFIG_TYPE_WEIGHTS = {DEFAULT_CONFIG_TYPE: 1.0}
@dataclass
class KeySpecification:
    """Maps canonical property names to the ase.Atoms info/arrays keys that hold them."""

    info_keys: Dict[str, str] = field(default_factory=dict)
    arrays_keys: Dict[str, str] = field(default_factory=dict)

    def update(
        self,
        info_keys: Optional[Dict[str, str]] = None,
        arrays_keys: Optional[Dict[str, str]] = None,
    ):
        """Merge the given mappings into this spec and return self for chaining."""
        pairs = ((self.info_keys, info_keys), (self.arrays_keys, arrays_keys))
        for target, updates in pairs:
            if updates is not None:
                target.update(updates)
        return self

    @classmethod
    def from_defaults(cls):
        """Build a specification populated from the package's DefaultKeys."""
        return update_keyspec_from_kwargs(cls(), DefaultKeys.keydict())
def update_keyspec_from_kwargs(
    keyspec: KeySpecification, keydict: Dict[str, str]
) -> KeySpecification:
    """Translate command-line style ``<name>_key`` entries into keyspec updates."""
    info_args = ("energy_key", "stress_key", "virials_key", "dipole_key", "head_key")
    array_args = ("forces_key", "charges_key")
    # Strip the trailing "_key" suffix to recover the canonical property name.
    info_keys = {arg[:-4]: keydict[arg] for arg in info_args if arg in keydict}
    arrays_keys = {arg[:-4]: keydict[arg] for arg in array_args if arg in keydict}
    keyspec.update(info_keys=info_keys, arrays_keys=arrays_keys)
    return keyspec
@dataclass
class Configuration:
    """A single atomic structure plus its reference properties and loss weights."""

    atomic_numbers: np.ndarray
    positions: Positions  # Angstrom
    properties: Dict[str, Any]  # reference values keyed by canonical property name
    property_weights: Dict[str, float]  # per-property weight in the loss
    cell: Optional[Cell] = None
    pbc: Optional[Pbc] = None
    weight: float = 1.0  # weight of config in loss
    config_type: str = DEFAULT_CONFIG_TYPE  # config_type of config
    head: str = "Default"  # head used to compute the config


# A list of configurations, e.g. a whole training or validation set.
Configurations = List[Configuration]
def random_train_valid_split(
    items: Sequence, valid_fraction: float, seed: int, work_dir: str
) -> Tuple[List, List]:
    """Randomly split *items* into (train, valid) lists.

    The split is deterministic for a given *seed*.  When the validation set
    has at least 10 entries, its indices are written to
    ``<work_dir>/valid_indices_<seed>.txt``; otherwise they are only logged.
    """
    assert 0.0 < valid_fraction < 1.0
    size = len(items)
    train_size = size - int(valid_fraction * size)
    indices = list(range(size))
    rng = np.random.default_rng(seed)
    rng.shuffle(indices)
    # Compute both halves once instead of re-slicing repeatedly.
    train_indices = indices[:train_size]
    valid_indices = indices[train_size:]
    if len(valid_indices) < 10:
        logging.info(
            f"Using random {100 * valid_fraction:.0f}% of training set for validation with following indices: {valid_indices}"
        )
    else:
        # Save indices to file
        with open(work_dir + f"/valid_indices_{seed}.txt", "w", encoding="utf-8") as f:
            f.writelines(f"{index}\n" for index in valid_indices)
        logging.info(
            f"Using random {100 * valid_fraction:.0f}% of training set for validation with indices saved in: {work_dir}/valid_indices_{seed}.txt"
        )
    return (
        [items[i] for i in train_indices],
        [items[i] for i in valid_indices],
    )
def config_from_atoms_list(
    atoms_list: List[ase.Atoms],
    key_specification: KeySpecification,
    config_type_weights: Optional[Dict[str, float]] = None,
    head_name: str = "Default",
) -> Configurations:
    """Convert list of ase.Atoms into Configurations"""
    weights = (
        config_type_weights
        if config_type_weights is not None
        else DEFAULT_CONFIG_TYPE_WEIGHTS
    )
    return [
        config_from_atoms(
            atoms,
            key_specification=key_specification,
            config_type_weights=weights,
            head_name=head_name,
        )
        for atoms in atoms_list
    ]
def config_from_atoms(
    atoms: ase.Atoms,
    key_specification: Optional[KeySpecification] = None,
    config_type_weights: Optional[Dict[str, float]] = None,
    head_name: str = "Default",
) -> Configuration:
    """Convert ase.Atoms to Configuration.

    Properties listed in *key_specification* but absent from the Atoms object
    are stored as None with a zero loss weight.
    """
    # BUG FIX: the default used to be a shared, mutable KeySpecification()
    # instance created once at definition time; any in-place update leaked
    # into later calls. A fresh instance is now created per call.
    if key_specification is None:
        key_specification = KeySpecification()
    if config_type_weights is None:
        config_type_weights = DEFAULT_CONFIG_TYPE_WEIGHTS
    atomic_numbers = np.array(
        [ase.data.atomic_numbers[symbol] for symbol in atoms.symbols]
    )
    pbc = tuple(atoms.get_pbc())
    cell = np.array(atoms.get_cell())
    config_type = atoms.info.get("config_type", "Default")
    weight = atoms.info.get("config_weight", 1.0) * config_type_weights.get(
        config_type, 1.0
    )
    properties = {}
    property_weights = {}
    for name in list(key_specification.arrays_keys) + list(key_specification.info_keys):
        property_weights[name] = atoms.info.get(f"config_{name}_weight", 1.0)
    for name, atoms_key in key_specification.info_keys.items():
        properties[name] = atoms.info.get(atoms_key, None)
        if atoms_key not in atoms.info:
            # Missing property: zero its weight so it does not enter the loss.
            property_weights[name] = 0.0
    for name, atoms_key in key_specification.arrays_keys.items():
        properties[name] = atoms.arrays.get(atoms_key, None)
        if atoms_key not in atoms.arrays:
            property_weights[name] = 0.0
    return Configuration(
        atomic_numbers=atomic_numbers,
        positions=atoms.get_positions(),
        properties=properties,
        weight=weight,
        property_weights=property_weights,
        head=head_name,
        config_type=config_type,
        pbc=pbc,
        cell=cell,
    )
def test_config_types(
test_configs: Configurations,
) -> List[Tuple[str, List[Configuration]]]:
"""Split test set based on config_type-s"""
test_by_ct = []
all_cts = []
for conf in test_configs:
config_type_name = conf.config_type + "_" + conf.head
if config_type_name not in all_cts:
all_cts.append(config_type_name)
test_by_ct.append((config_type_name, [conf]))
else:
ind = all_cts.index(config_type_name)
test_by_ct[ind][1].append(conf)
return test_by_ct
def load_from_xyz(
    file_path: str,
    key_specification: KeySpecification,
    head_name: str = "Default",
    config_type_weights: Optional[Dict] = None,
    extract_atomic_energies: bool = False,
    keep_isolated_atoms: bool = False,
) -> Tuple[Dict[int, float], Configurations]:
    """Read an (ext)xyz file and convert its frames into Configurations.

    When *extract_atomic_energies* is True, single-atom frames whose
    ``config_type`` is "IsolatedAtom" are used to collect per-element E0s and
    (unless *keep_isolated_atoms*) removed from the returned configurations.

    Returns a tuple ``(atomic_energies_dict, configs)`` where the dict maps
    atomic number -> isolated-atom energy (empty unless extraction is on).
    """
    atoms_list = ase.io.read(file_path, index=":")
    energy_key = key_specification.info_keys["energy"]
    forces_key = key_specification.arrays_keys["forces"]
    stress_key = key_specification.info_keys["stress"]
    head_key = key_specification.info_keys["head"]
    # The plain "energy"/"forces"/"stress" keys are intercepted by newer ASE
    # versions, so the values are re-written under "REF_*" keys and the key
    # specification is updated in place to point at them.
    if energy_key == "energy":
        logging.warning(
            "Since ASE version 3.23.0b1, using energy_key 'energy' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'energy' to 'REF_energy'. You need to use --energy_key='REF_energy' to specify the chosen key name."
        )
        key_specification.info_keys["energy"] = "REF_energy"
        for atoms in atoms_list:
            try:
                atoms.info["REF_energy"] = atoms.get_potential_energy()
            except Exception as e:  # pylint: disable=W0703
                logging.error(f"Failed to extract energy: {e}")
                atoms.info["REF_energy"] = None
    if forces_key == "forces":
        logging.warning(
            "Since ASE version 3.23.0b1, using forces_key 'forces' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'forces' to 'REF_forces'. You need to use --forces_key='REF_forces' to specify the chosen key name."
        )
        key_specification.arrays_keys["forces"] = "REF_forces"
        for atoms in atoms_list:
            try:
                atoms.arrays["REF_forces"] = atoms.get_forces()
            except Exception as e:  # pylint: disable=W0703
                logging.error(f"Failed to extract forces: {e}")
                atoms.arrays["REF_forces"] = None
    if stress_key == "stress":
        logging.warning(
            "Since ASE version 3.23.0b1, using stress_key 'stress' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'stress' to 'REF_stress'. You need to use --stress_key='REF_stress' to specify the chosen key name."
        )
        key_specification.info_keys["stress"] = "REF_stress"
        for atoms in atoms_list:
            try:
                atoms.info["REF_stress"] = atoms.get_stress()
            except Exception as e:  # pylint: disable=W0703
                atoms.info["REF_stress"] = None
    # ase.io.read returns a single Atoms object for single-frame files.
    if not isinstance(atoms_list, list):
        atoms_list = [atoms_list]
    atomic_energies_dict = {}
    if extract_atomic_energies:
        atoms_without_iso_atoms = []
        for idx, atoms in enumerate(atoms_list):
            atoms.info[head_key] = head_name
            isolated_atom_config = (
                len(atoms) == 1 and atoms.info.get("config_type") == "IsolatedAtom"
            )
            if isolated_atom_config:
                atomic_number = int(atoms.get_atomic_numbers()[0])
                # NOTE(review): this still checks the *original* energy key,
                # even if the key spec was rewritten to "REF_energy" above —
                # confirm this is the intended lookup.
                if energy_key in atoms.info.keys():
                    atomic_energies_dict[atomic_number] = float(atoms.info[energy_key])
                else:
                    logging.warning(
                        f"Configuration '{idx}' is marked as 'IsolatedAtom' "
                        "but does not contain an energy. Zero energy will be used."
                    )
                    atomic_energies_dict[atomic_number] = 0.0
            else:
                atoms_without_iso_atoms.append(atoms)
        if len(atomic_energies_dict) > 0:
            logging.info("Using isolated atom energies from training file")
        if not keep_isolated_atoms:
            atoms_list = atoms_without_iso_atoms
    # Tag every remaining frame with the head it belongs to.
    for atoms in atoms_list:
        atoms.info[head_key] = head_name
    configs = config_from_atoms_list(
        atoms_list,
        config_type_weights=config_type_weights,
        key_specification=key_specification,
        head_name=head_name,
    )
    return atomic_energies_dict, configs
def compute_average_E0s(
    collections_train: "Configurations", z_table: "AtomicNumberTable"
) -> Dict[int, float]:
    """
    Function to compute the average interaction energy of each chemical element
    returns dictionary of E0s
    """
    len_train = len(collections_train)
    len_zs = len(z_table)
    # Least-squares fit: total energy ~= sum_z count(z) * E0(z) for each config.
    A = np.zeros((len_train, len_zs))
    B = np.zeros(len_train)
    for i, config in enumerate(collections_train):
        B[i] = config.properties["energy"]
        for j, z in enumerate(z_table.zs):
            A[i, j] = np.count_nonzero(config.atomic_numbers == z)
    try:
        E0s = np.linalg.lstsq(A, B, rcond=None)[0]
        atomic_energies_dict = {z: E0s[i] for i, z in enumerate(z_table.zs)}
    except np.linalg.LinAlgError:
        logging.error(
            "Failed to compute E0s using least squares regression, using the same for all atoms"
        )
        # Fallback: zero reference energy for every element.
        atomic_energies_dict = {z: 0.0 for z in z_table.zs}
    return atomic_energies_dict
def save_dataset_as_HDF5(dataset: List, out_name: str) -> None:
    """Write every AtomicData in *dataset* to a freshly created HDF5 file *out_name*."""
    with h5py.File(out_name, "w") as h5_file:
        for index, sample in enumerate(dataset):
            save_AtomicData_to_HDF5(sample, index, h5_file)
def save_AtomicData_to_HDF5(data, i, h5_file) -> None:
    """Serialize one AtomicData object into group ``config_<i>`` of *h5_file*."""
    grp = h5_file.create_group(f"config_{i}")
    # One HDF5 entry per attribute, written in a fixed order.
    attribute_names = (
        "num_nodes",
        "edge_index",
        "positions",
        "shifts",
        "unit_shifts",
        "cell",
        "node_attrs",
        "weight",
        "energy_weight",
        "forces_weight",
        "stress_weight",
        "virials_weight",
        "forces",
        "energy",
        "stress",
        "virials",
        "dipole",
        "charges",
        "head",
    )
    for attribute_name in attribute_names:
        grp[attribute_name] = getattr(data, attribute_name)
def save_configurations_as_HDF5(configurations: Configurations, _, h5_file) -> None:
    """Write a list of Configurations into group ``config_batch_0`` of *h5_file*."""
    grp = h5_file.create_group("config_batch_0")
    for index, config in enumerate(configurations):
        subgroup = grp.create_group(f"config_{index}")
        subgroup["atomic_numbers"] = write_value(config.atomic_numbers)
        subgroup["positions"] = write_value(config.positions)
        properties_subgrp = subgroup.create_group("properties")
        for name, value in config.properties.items():
            properties_subgrp[name] = write_value(value)
        subgroup["cell"] = write_value(config.cell)
        subgroup["pbc"] = write_value(config.pbc)
        subgroup["weight"] = write_value(config.weight)
        weights_subgrp = subgroup.create_group("property_weights")
        for name, value in config.property_weights.items():
            weights_subgrp[name] = write_value(value)
        subgroup["config_type"] = write_value(config.config_type)
def write_value(value):
    """Map None to the sentinel string "None" (HDF5 cannot store None directly)."""
    if value is None:
        return "None"
    return value
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment