Commit 2409a22f authored by fanding2000's avatar fanding2000
Browse files

Format fix. More options in readme

parent ce29afea
########################################################################################### ###########################################################################################
# Atomic Data Class for handling molecules as graphs # Atomic Data Class for handling molecules as graphs
# Authors: Ilyes Batatia, Gregor Simm # Authors: Ilyes Batatia, Gregor Simm
# This program is distributed under the MIT License (see MIT.md) # This program is distributed under the MIT License (see MIT.md)
########################################################################################### ###########################################################################################
from copy import deepcopy from copy import deepcopy
from typing import Optional, Sequence from typing import Optional, Sequence
import torch.utils.data import torch.utils.data
from mace.tools import ( from mace.tools import (
AtomicNumberTable, AtomicNumberTable,
atomic_numbers_to_indices, atomic_numbers_to_indices,
to_one_hot, to_one_hot,
torch_geometric, torch_geometric,
voigt_to_matrix, voigt_to_matrix,
) )
from .neighborhood import get_neighborhood from .neighborhood import get_neighborhood
from .utils import Configuration from .utils import Configuration
class AtomicData(torch_geometric.data.Data):
    """A single atomic configuration stored as a torch_geometric graph.

    Nodes are atoms (one-hot element encoding in ``node_attrs``); edges are
    neighbor pairs within a cutoff.  ``shifts`` holds the Cartesian
    periodic-image offsets so edge vectors can be recomputed from positions.
    Missing target properties are filled with zeros and missing weights with
    ones, keeping batch shapes stable for downstream loss masking.
    """

    num_graphs: torch.Tensor
    batch: torch.Tensor
    edge_index: torch.Tensor
    node_attrs: torch.Tensor
    edge_vectors: torch.Tensor
    edge_lengths: torch.Tensor
    positions: torch.Tensor
    shifts: torch.Tensor
    unit_shifts: torch.Tensor
    cell: torch.Tensor
    forces: torch.Tensor
    energy: torch.Tensor
    stress: torch.Tensor
    virials: torch.Tensor
    dipole: torch.Tensor
    charges: torch.Tensor
    weight: torch.Tensor
    energy_weight: torch.Tensor
    forces_weight: torch.Tensor
    stress_weight: torch.Tensor
    virials_weight: torch.Tensor
    dipole_weight: torch.Tensor
    charges_weight: torch.Tensor

    def __init__(
        self,
        edge_index: torch.Tensor,  # [2, n_edges]
        node_attrs: torch.Tensor,  # [n_nodes, n_node_feats]
        positions: torch.Tensor,  # [n_nodes, 3]
        shifts: torch.Tensor,  # [n_edges, 3]
        unit_shifts: torch.Tensor,  # [n_edges, 3]
        cell: Optional[torch.Tensor],  # [3,3]
        weight: Optional[torch.Tensor],  # scalar
        head: Optional[torch.Tensor],  # scalar index into the heads list
        energy_weight: Optional[torch.Tensor],  # scalar
        forces_weight: Optional[torch.Tensor],  # scalar
        stress_weight: Optional[torch.Tensor],  # scalar
        virials_weight: Optional[torch.Tensor],  # scalar
        dipole_weight: Optional[torch.Tensor],  # [1, 3]
        charges_weight: Optional[torch.Tensor],  # scalar
        forces: Optional[torch.Tensor],  # [n_nodes, 3]
        energy: Optional[torch.Tensor],  # scalar
        stress: Optional[torch.Tensor],  # [1,3,3]
        virials: Optional[torch.Tensor],  # [1,3,3]
        dipole: Optional[torch.Tensor],  # [1, 3]
        charges: Optional[torch.Tensor],  # [n_nodes]
    ):
        """Validate tensor shapes and hand everything to the PyG Data base."""
        # Check shapes up front so bad data fails loudly at construction time.
        num_nodes = node_attrs.shape[0]
        assert edge_index.shape[0] == 2 and len(edge_index.shape) == 2
        assert positions.shape == (num_nodes, 3)
        assert shifts.shape[1] == 3
        assert unit_shifts.shape[1] == 3
        assert len(node_attrs.shape) == 2
        assert weight is None or len(weight.shape) == 0
        assert head is None or len(head.shape) == 0
        assert energy_weight is None or len(energy_weight.shape) == 0
        assert forces_weight is None or len(forces_weight.shape) == 0
        assert stress_weight is None or len(stress_weight.shape) == 0
        assert virials_weight is None or len(virials_weight.shape) == 0
        assert dipole_weight is None or dipole_weight.shape == (1, 3), dipole_weight
        assert charges_weight is None or len(charges_weight.shape) == 0
        assert cell is None or cell.shape == (3, 3)
        assert forces is None or forces.shape == (num_nodes, 3)
        assert energy is None or len(energy.shape) == 0
        assert stress is None or stress.shape == (1, 3, 3)
        assert virials is None or virials.shape == (1, 3, 3)
        assert dipole is None or dipole.shape[-1] == 3
        assert charges is None or charges.shape == (num_nodes,)
        # Aggregate data
        data = {
            "num_nodes": num_nodes,
            "edge_index": edge_index,
            "positions": positions,
            "shifts": shifts,
            "unit_shifts": unit_shifts,
            "cell": cell,
            "node_attrs": node_attrs,
            "weight": weight,
            "head": head,
            "energy_weight": energy_weight,
            "forces_weight": forces_weight,
            "stress_weight": stress_weight,
            "virials_weight": virials_weight,
            "dipole_weight": dipole_weight,
            "charges_weight": charges_weight,
            "forces": forces,
            "energy": energy,
            "stress": stress,
            "virials": virials,
            "dipole": dipole,
            "charges": charges,
        }
        super().__init__(**data)

    @staticmethod
    def _tensor_or_default(value, default: torch.Tensor) -> torch.Tensor:
        """Return ``value`` as a default-dtype tensor, or ``default`` if None.

        Factors out the "tensor if present, fallback otherwise" pattern that
        was previously copy-pasted for every property and weight.
        """
        if value is None:
            return default
        return torch.tensor(value, dtype=torch.get_default_dtype())

    @classmethod
    def from_config(
        cls,
        config: Configuration,
        z_table: AtomicNumberTable,
        cutoff: float,
        heads: Optional[list] = None,
        **kwargs,  # pylint: disable=unused-argument
    ) -> "AtomicData":
        """Build an AtomicData graph from a parsed Configuration.

        :param config: configuration with positions, numbers and properties.
        :param z_table: maps atomic numbers to one-hot indices.
        :param cutoff: neighbor-list cutoff radius.
        :param heads: known head names; an unknown ``config.head`` falls back
            to the last entry.
        """
        if heads is None:
            heads = ["Default"]
        edge_index, shifts, unit_shifts, cell = get_neighborhood(
            positions=config.positions,
            cutoff=cutoff,
            pbc=deepcopy(config.pbc),
            cell=deepcopy(config.cell),
        )
        indices = atomic_numbers_to_indices(config.atomic_numbers, z_table=z_table)
        one_hot = to_one_hot(
            torch.tensor(indices, dtype=torch.long).unsqueeze(-1),
            num_classes=len(z_table),
        )
        try:
            head = torch.tensor(heads.index(config.head), dtype=torch.long)
        except ValueError:
            # Unknown head name: fall back to the last configured head.
            head = torch.tensor(len(heads) - 1, dtype=torch.long)
        dtype = torch.get_default_dtype()
        cell = (
            torch.tensor(cell, dtype=dtype)
            if cell is not None
            else torch.zeros(3, 3, dtype=dtype)
        )
        num_atoms = len(config.atomic_numbers)
        pw = config.property_weights
        props = config.properties
        # Weights default to 1.0 (i.e. "use this target at full strength").
        weight = cls._tensor_or_default(config.weight, torch.tensor(1.0, dtype=dtype))
        energy_weight = cls._tensor_or_default(
            pw.get("energy"), torch.tensor(1.0, dtype=dtype)
        )
        forces_weight = cls._tensor_or_default(
            pw.get("forces"), torch.tensor(1.0, dtype=dtype)
        )
        stress_weight = cls._tensor_or_default(
            pw.get("stress"), torch.tensor(1.0, dtype=dtype)
        )
        virials_weight = cls._tensor_or_default(
            pw.get("virials"), torch.tensor(1.0, dtype=dtype)
        )
        dipole_weight = cls._tensor_or_default(
            pw.get("dipole"), torch.tensor([[1.0, 1.0, 1.0]], dtype=dtype)
        )
        # Broadcast scalar / 1-D dipole weights to the expected [1, 3] shape.
        if len(dipole_weight.shape) == 0:
            dipole_weight = dipole_weight * torch.tensor(
                [[1.0, 1.0, 1.0]], dtype=dtype
            )
        elif len(dipole_weight.shape) == 1:
            dipole_weight = dipole_weight.unsqueeze(0)
        charges_weight = cls._tensor_or_default(
            pw.get("charges"), torch.tensor(1.0, dtype=dtype)
        )
        # Properties default to zeros of the right shape.
        forces = cls._tensor_or_default(
            props.get("forces"), torch.zeros(num_atoms, 3, dtype=dtype)
        )
        energy = cls._tensor_or_default(
            props.get("energy"), torch.tensor(0.0, dtype=dtype)
        )
        # Stress/virials may be stored in Voigt form; expand to [1, 3, 3].
        stress = (
            voigt_to_matrix(
                torch.tensor(props.get("stress"), dtype=dtype)
            ).unsqueeze(0)
            if props.get("stress") is not None
            else torch.zeros(1, 3, 3, dtype=dtype)
        )
        virials = (
            voigt_to_matrix(
                torch.tensor(props.get("virials"), dtype=dtype)
            ).unsqueeze(0)
            if props.get("virials") is not None
            else torch.zeros(1, 3, 3, dtype=dtype)
        )
        dipole = (
            torch.tensor(props.get("dipole"), dtype=dtype).unsqueeze(0)
            if props.get("dipole") is not None
            else torch.zeros(1, 3, dtype=dtype)
        )
        charges = cls._tensor_or_default(
            props.get("charges"), torch.zeros(num_atoms, dtype=dtype)
        )
        return cls(
            edge_index=torch.tensor(edge_index, dtype=torch.long),
            positions=torch.tensor(config.positions, dtype=dtype),
            shifts=torch.tensor(shifts, dtype=dtype),
            unit_shifts=torch.tensor(unit_shifts, dtype=dtype),
            cell=cell,
            node_attrs=one_hot,
            weight=weight,
            head=head,
            energy_weight=energy_weight,
            forces_weight=forces_weight,
            stress_weight=stress_weight,
            virials_weight=virials_weight,
            dipole_weight=dipole_weight,
            charges_weight=charges_weight,
            forces=forces,
            energy=energy,
            stress=stress,
            virials=virials,
            dipole=dipole,
            charges=charges,
        )
def get_data_loader(
    dataset: Sequence[AtomicData],
    batch_size: int,
    shuffle=True,
    drop_last=False,
) -> torch.utils.data.DataLoader:
    """Wrap a sequence of AtomicData graphs in a torch_geometric DataLoader."""
    loader_options = {
        "dataset": dataset,
        "batch_size": batch_size,
        "shuffle": shuffle,
        "drop_last": drop_last,
    }
    return torch_geometric.dataloader.DataLoader(**loader_options)
from glob import glob from glob import glob
from typing import List from typing import List
import h5py import h5py
from torch.utils.data import ConcatDataset, Dataset from torch.utils.data import ConcatDataset, Dataset
from mace.data.atomic_data import AtomicData from mace.data.atomic_data import AtomicData
from mace.data.utils import Configuration from mace.data.utils import Configuration
from mace.tools.utils import AtomicNumberTable from mace.tools.utils import AtomicNumberTable
class HDF5Dataset(Dataset):
    """Map-style dataset reading configurations from a batched HDF5 file.

    The file layout is ``config_batch_<i>/config_<j>``; a flat sample index
    is decomposed into batch index ``i`` and in-batch index ``j``.  The h5py
    handle is opened lazily via the ``file`` property so instances can be
    pickled into DataLoader worker processes (see ``__getstate__``).
    """

    def __init__(
        self, file_path, r_max, z_table, atomic_dataclass=AtomicData, **kwargs
    ):
        """Record path and conversion settings; probe group sizes from the file.

        NOTE(review): ``length`` assumes every batch group holds as many
        configs as the first one — confirm behavior for a ragged final batch
        (the stored ``drop_last`` attribute suggests the writer controls this).
        """
        super(HDF5Dataset, self).__init__()  # pylint: disable=super-with-arguments
        self.file_path = file_path
        self._file = None
        # Probe the first batch group to learn the per-batch config count.
        batch_key = list(self.file.keys())[0]
        self.batch_size = len(self.file[batch_key].keys())
        self.length = len(self.file.keys()) * self.batch_size
        self.r_max = r_max
        self.z_table = z_table
        self.atomic_dataclass = atomic_dataclass
        try:
            self.drop_last = bool(self.file.attrs["drop_last"])
        except KeyError:
            self.drop_last = False
        self.kwargs = kwargs

    @property
    def file(self):
        """Lazily open (and then cache) the underlying h5py.File handle."""
        if self._file is None:
            # If a file has not already been opened, open one here
            self._file = h5py.File(self.file_path, "r")
        return self._file

    def __getstate__(self):
        """Return picklable state; each worker reopens its own file handle."""
        _d = dict(self.__dict__)
        # An opened h5py.File cannot be pickled, so we must exclude it from the state
        _d["_file"] = None
        return _d

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Load one configuration and convert it with ``atomic_dataclass``."""
        # compute the index of the batch
        batch_index = index // self.batch_size
        config_index = index % self.batch_size
        grp = self.file["config_batch_" + str(batch_index)]
        subgrp = grp["config_" + str(config_index)]
        properties = {}
        property_weights = {}
        # Scalar datasets come back via [()]; unpack_value also maps the
        # sentinel string "None" back to Python None.
        for key in subgrp["properties"]:
            properties[key] = unpack_value(subgrp["properties"][key][()])
        for key in subgrp["property_weights"]:
            property_weights[key] = unpack_value(subgrp["property_weights"][key][()])
        config = Configuration(
            atomic_numbers=subgrp["atomic_numbers"][()],
            positions=subgrp["positions"][()],
            properties=properties,
            weight=unpack_value(subgrp["weight"][()]),
            property_weights=property_weights,
            config_type=unpack_value(subgrp["config_type"][()]),
            pbc=unpack_value(subgrp["pbc"][()]),
            cell=unpack_value(subgrp["cell"][()]),
        )
        if config.head is None:
            config.head = self.kwargs.get("head")
        # "heads" is passed explicitly, so it is filtered out of the kwargs
        # forwarded to from_config to avoid a duplicate-keyword error.
        atomic_data = self.atomic_dataclass.from_config(
            config,
            z_table=self.z_table,
            cutoff=self.r_max,
            heads=self.kwargs.get("heads", ["Default"]),
            **{k: v for k, v in self.kwargs.items() if k != "heads"},
        )
        return atomic_data
def dataset_from_sharded_hdf5(
    files: str, z_table: AtomicNumberTable, r_max: float, **kwargs
):
    """Concatenate one HDF5Dataset per shard found inside a directory.

    :param files: despite the historical name, this is a *directory path*;
        every entry inside it is opened as one HDF5 shard.  (The previous
        ``List`` annotation was wrong: ``files + "/*"`` is string
        concatenation and would raise TypeError for an actual list.)
    :param z_table: atomic-number table forwarded to each shard.
    :param r_max: neighbor cutoff forwarded to each shard.
    :return: a torch ConcatDataset over all shards.
    """
    shard_paths = glob(files + "/*")
    datasets = [
        HDF5Dataset(path, z_table=z_table, r_max=r_max, **kwargs)
        for path in shard_paths
    ]
    return ConcatDataset(datasets)
def unpack_value(value):
    """Decode HDF5 byte strings and map the sentinel string "None" to None."""
    if isinstance(value, bytes):
        value = value.decode("utf-8")
    if str(value) == "None":
        return None
    return value
import os import os
import numpy as np import numpy as np
from torch.utils.data import Dataset from torch.utils.data import Dataset
from mace.data.atomic_data import AtomicData from mace.data.atomic_data import AtomicData
from mace.data.utils import KeySpecification, config_from_atoms from mace.data.utils import KeySpecification, config_from_atoms
from mace.tools.default_keys import DefaultKeys from mace.tools.default_keys import DefaultKeys
from mace.tools.fairchem_dataset import AseDBDataset from mace.tools.fairchem_dataset import AseDBDataset
class LMDBDataset(Dataset):
    """Map-style dataset over one or more ASE LMDB databases.

    ``file_path`` may contain several paths joined by ``:``; all of them are
    opened through a single AseDBDataset.  Failed loads return ``None`` so a
    collate function can filter them out.
    """

    def __init__(self, file_path, r_max, z_table, **kwargs):
        dataset_paths = file_path.split(":")  # ":" separates multiple paths
        # make sure each of the paths exists
        for path in dataset_paths:
            assert os.path.exists(path), f"Dataset path does not exist: {path}"
        config_kwargs = {}
        super(LMDBDataset, self).__init__()  # pylint: disable=super-with-arguments
        self.AseDB = AseDBDataset(config=dict(src=dataset_paths, **config_kwargs))
        self.r_max = r_max
        self.z_table = z_table
        self.kwargs = kwargs
        self.transform = kwargs.get("transform")

    def __len__(self):
        return len(self.AseDB)

    def __getitem__(self, index):
        """Return the AtomicData for ``index``, or None if loading fails."""
        try:
            atoms = self.AseDB.get_atoms(self.AseDB.ids[index])
        except Exception as e:  # pylint: disable=broad-except
            print(f"Error in index {index}")
            print(e)
            return None
        assert np.sum(atoms.get_cell() == atoms.cell) == 9
        # Mirror calculator results into info/arrays so config_from_atoms
        # picks them up under the project's default keys.
        if hasattr(atoms, "calc") and hasattr(atoms.calc, "results"):
            if "energy" in atoms.calc.results:
                atoms.info[DefaultKeys.ENERGY.value] = atoms.calc.results["energy"]
            if "forces" in atoms.calc.results:
                atoms.arrays[DefaultKeys.FORCES.value] = atoms.calc.results["forces"]
            if "stress" in atoms.calc.results:
                atoms.info[DefaultKeys.STRESS.value] = atoms.calc.results["stress"]
        config = config_from_atoms(
            atoms,
            key_specification=KeySpecification.from_defaults(),
        )
        # Set head if not already set
        if config.head == "Default":
            config.head = self.kwargs.get("head", "Default")
        try:
            atomic_data = AtomicData.from_config(
                config,
                z_table=self.z_table,
                cutoff=self.r_max,
                heads=self.kwargs.get("heads", ["Default"]),
            )
        except Exception as e:  # pylint: disable=broad-except
            print(f"Error in index {index}")
            print(e)
            # BUG FIX: previously this handler fell through with
            # ``atomic_data`` unbound, raising NameError below; fail the
            # same way as the load path above.
            return None
        if self.transform:
            atomic_data = self.transform(atomic_data)
        return atomic_data
from typing import Optional, Tuple from typing import Optional, Tuple
import numpy as np import numpy as np
from matscipy.neighbours import neighbour_list from matscipy.neighbours import neighbour_list
def get_neighborhood(
    positions: np.ndarray,  # [num_positions, 3]
    cutoff: float,
    pbc: Optional[Tuple[bool, bool, bool]] = None,
    cell: Optional[np.ndarray] = None,  # [3, 3]
    true_self_interaction=False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Build the neighbor list for ``positions`` within ``cutoff``.

    :return: ``(edge_index, shifts, unit_shifts, cell)`` where
        ``edge_index`` is [2, n_edges] sender/receiver indices,
        ``shifts`` is the Cartesian offsets ``unit_shifts @ cell`` [n_edges, 3],
        ``unit_shifts`` are integer periodic-image shifts [n_edges, 3], and
        ``cell`` is the (possibly extended) cell actually used for the search.
    """
    if pbc is None:
        pbc = (False, False, False)
    # Treat a missing or all-zero cell as "no real lattice".  (The previous
    # test compared ``cell.any()`` to ``np.zeros((3, 3)).any()`` — i.e. to
    # False — which is just an obfuscated "cell has no nonzero entry".)
    if cell is None or not cell.any():
        cell = np.identity(3, dtype=float)
    else:
        # Work on a float copy so the caller's array is not mutated by the
        # non-periodic-direction extension below.
        cell = np.array(cell, dtype=float)
    assert len(pbc) == 3 and all(isinstance(i, (bool, np.bool_)) for i in pbc)
    assert cell.shape == (3, 3)
    pbc_x = pbc[0]
    pbc_y = pbc[1]
    pbc_z = pbc[2]
    identity = np.identity(3, dtype=float)
    max_positions = np.max(np.absolute(positions)) + 1
    # Extend cell in non-periodic directions so the periodic neighbor search
    # never produces wrap-around edges there.
    # For models with more than 5 layers, the multiplicative constant needs to be increased.
    if not pbc_x:
        cell[0, :] = max_positions * 5 * cutoff * identity[0, :]
    if not pbc_y:
        cell[1, :] = max_positions * 5 * cutoff * identity[1, :]
    if not pbc_z:
        cell[2, :] = max_positions * 5 * cutoff * identity[2, :]
    sender, receiver, unit_shifts = neighbour_list(
        quantities="ijS",
        pbc=pbc,
        cell=cell,
        positions=positions,
        cutoff=cutoff,
        # self_interaction=True,  # we want edges from atom to itself in different periodic images
        # use_scaled_positions=False,  # positions are not scaled positions
    )
    if not true_self_interaction:
        # Eliminate self-edges that don't cross periodic boundaries
        true_self_edge = sender == receiver
        true_self_edge &= np.all(unit_shifts == 0, axis=1)
        keep_edge = ~true_self_edge
        # Note: after eliminating self-edges, it can be that no edges remain in this system
        sender = sender[keep_edge]
        receiver = receiver[keep_edge]
        unit_shifts = unit_shifts[keep_edge]
    # Build output
    edge_index = np.stack((sender, receiver))  # [2, n_edges]
    # From the matscipy docs: with the shift vector S, the distances D between
    # atoms can be computed from D = positions[j] - positions[i] + S.dot(cell)
    shifts = np.dot(unit_shifts, cell)  # [n_edges, 3]
    return edge_index, shifts, unit_shifts, cell
########################################################################################### ###########################################################################################
# Data parsing utilities # Data parsing utilities
# Authors: Ilyes Batatia, Gregor Simm and David Kovacs # Authors: Ilyes Batatia, Gregor Simm and David Kovacs
# This program is distributed under the MIT License (see MIT.md) # This program is distributed under the MIT License (see MIT.md)
########################################################################################### ###########################################################################################
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence, Tuple from typing import Any, Dict, List, Optional, Sequence, Tuple
import ase.data import ase.data
import ase.io import ase.io
import h5py import h5py
import numpy as np import numpy as np
from mace.tools import AtomicNumberTable, DefaultKeys from mace.tools import AtomicNumberTable, DefaultKeys
# Type aliases for geometry data parsed from ase.Atoms objects.
Positions = np.ndarray  # [..., 3]
Cell = np.ndarray  # [3,3]
Pbc = tuple  # (3,)

# Configurations without an explicit type fall back to this label/weight.
DEFAULT_CONFIG_TYPE = "Default"
DEFAULT_CONFIG_TYPE_WEIGHTS = {DEFAULT_CONFIG_TYPE: 1.0}
@dataclass
class KeySpecification:
    """Mapping from canonical property names to the keys used in ase.Atoms.

    ``info_keys`` names per-configuration properties (read from ``atoms.info``)
    and ``arrays_keys`` names per-atom properties (read from ``atoms.arrays``).
    """

    info_keys: Dict[str, str] = field(default_factory=dict)
    arrays_keys: Dict[str, str] = field(default_factory=dict)

    def update(
        self,
        info_keys: Optional[Dict[str, str]] = None,
        arrays_keys: Optional[Dict[str, str]] = None,
    ):
        """Merge the given mappings into this specification; returns self."""
        for target, new_entries in (
            (self.info_keys, info_keys),
            (self.arrays_keys, arrays_keys),
        ):
            if new_entries is not None:
                target.update(new_entries)
        return self

    @classmethod
    def from_defaults(cls):
        """Build a specification pre-populated with the package default keys."""
        return update_keyspec_from_kwargs(cls(), DefaultKeys.keydict())
def update_keyspec_from_kwargs(
    keyspec: KeySpecification, keydict: Dict[str, str]
) -> KeySpecification:
    """Translate command-line style ``<property>_key`` arguments into keyspec updates.

    Only the recognized argument names below are consumed; other entries of
    ``keydict`` are ignored.  Returns the (mutated) ``keyspec``.
    """
    info_args = ("energy_key", "stress_key", "virials_key", "dipole_key", "head_key")
    arrays_args = ("forces_key", "charges_key")
    # Strip the trailing "_key" suffix to recover the canonical property name.
    info_updates = {arg[:-4]: keydict[arg] for arg in info_args if arg in keydict}
    arrays_updates = {arg[:-4]: keydict[arg] for arg in arrays_args if arg in keydict}
    return keyspec.update(info_keys=info_updates, arrays_keys=arrays_updates)
@dataclass
class Configuration:
    """A single atomic structure with its reference properties and loss weights."""
    atomic_numbers: np.ndarray  # [n_atoms] atomic numbers (Z)
    positions: Positions  # Angstrom
    properties: Dict[str, Any]  # property name -> reference value (None if absent in source)
    property_weights: Dict[str, float]  # property name -> loss weight (0.0 if absent in source)
    cell: Optional[Cell] = None  # [3, 3] cell matrix, optional
    pbc: Optional[Pbc] = None  # (3,) periodic boundary flags, optional
    weight: float = 1.0  # weight of config in loss
    config_type: str = DEFAULT_CONFIG_TYPE  # config_type of config
    head: str = "Default"  # head used to compute the config
# Convenience alias for a list of configurations.
Configurations = List[Configuration]
def random_train_valid_split(
    items: Sequence, valid_fraction: float, seed: int, work_dir: str
) -> Tuple[List, List]:
    """Randomly split *items* into a training and a validation subset.

    :param items: sequence to split (any indexable collection).
    :param valid_fraction: fraction in (0, 1) assigned to validation.
    :param seed: seed for the NumPy RNG, making the split reproducible.
    :param work_dir: directory where the validation indices are saved when the
        validation set is too large to log inline.
    :return: (train_items, valid_items) tuple of lists.
    """
    assert 0.0 < valid_fraction < 1.0
    size = len(items)
    train_size = size - int(valid_fraction * size)
    indices = list(range(size))
    rng = np.random.default_rng(seed)
    rng.shuffle(indices)
    train_indices, valid_indices = indices[:train_size], indices[train_size:]
    if len(valid_indices) < 10:
        # Small validation sets are logged directly for convenience.
        logging.info(
            f"Using random {100 * valid_fraction:.0f}% of training set for validation with following indices: {valid_indices}"
        )
    else:
        # Persist the validation indices so the split can be reproduced later.
        # Build the path once (os.path.join instead of string concatenation) so
        # the logged path and the written file can never disagree.
        indices_file = os.path.join(work_dir, f"valid_indices_{seed}.txt")
        with open(indices_file, "w", encoding="utf-8") as f:
            f.writelines(f"{index}\n" for index in valid_indices)
        logging.info(
            f"Using random {100 * valid_fraction:.0f}% of training set for validation with indices saved in: {indices_file}"
        )
    return (
        [items[i] for i in train_indices],
        [items[i] for i in valid_indices],
    )
def config_from_atoms_list(
    atoms_list: List[ase.Atoms],
    key_specification: KeySpecification,
    config_type_weights: Optional[Dict[str, float]] = None,
    head_name: str = "Default",
) -> Configurations:
    """Convert list of ase.Atoms into Configurations"""
    weights = (
        DEFAULT_CONFIG_TYPE_WEIGHTS
        if config_type_weights is None
        else config_type_weights
    )
    return [
        config_from_atoms(
            atoms,
            key_specification=key_specification,
            config_type_weights=weights,
            head_name=head_name,
        )
        for atoms in atoms_list
    ]
def config_from_atoms(
    atoms: ase.Atoms,
    key_specification: Optional[KeySpecification] = None,
    config_type_weights: Optional[Dict[str, float]] = None,
    head_name: str = "Default",
) -> Configuration:
    """Convert ase.Atoms to Configuration.

    :param atoms: structure to convert.
    :param key_specification: maps canonical property names to atoms.info /
        atoms.arrays keys.  Defaults to an empty specification.  (Previously the
        default was a single module-level ``KeySpecification()`` instance shared
        across all calls — a mutable default argument; a fresh instance is now
        created per call.)
    :param config_type_weights: per-config_type loss-weight multipliers.
    :param head_name: head assigned to the resulting configuration.
    :return: the populated Configuration.
    """
    if key_specification is None:
        key_specification = KeySpecification()
    if config_type_weights is None:
        config_type_weights = DEFAULT_CONFIG_TYPE_WEIGHTS
    atomic_numbers = np.array(
        [ase.data.atomic_numbers[symbol] for symbol in atoms.symbols]
    )
    pbc = tuple(atoms.get_pbc())
    cell = np.array(atoms.get_cell())
    config_type = atoms.info.get("config_type", "Default")
    # Overall config weight combines the per-structure weight with the
    # config_type multiplier (both default to 1.0).
    weight = atoms.info.get("config_weight", 1.0) * config_type_weights.get(
        config_type, 1.0
    )
    properties = {}
    property_weights = {}
    # Per-property weights may be overridden via "config_<name>_weight" in info.
    for name in list(key_specification.arrays_keys) + list(key_specification.info_keys):
        property_weights[name] = atoms.info.get(f"config_{name}_weight", 1.0)
    # Per-configuration properties come from atoms.info; a missing key zeroes
    # that property's loss weight so it does not contribute to training.
    for name, atoms_key in key_specification.info_keys.items():
        properties[name] = atoms.info.get(atoms_key, None)
        if atoms_key not in atoms.info:
            property_weights[name] = 0.0
    # Per-atom properties come from atoms.arrays, same missing-key rule.
    for name, atoms_key in key_specification.arrays_keys.items():
        properties[name] = atoms.arrays.get(atoms_key, None)
        if atoms_key not in atoms.arrays:
            property_weights[name] = 0.0
    return Configuration(
        atomic_numbers=atomic_numbers,
        positions=atoms.get_positions(),
        properties=properties,
        weight=weight,
        property_weights=property_weights,
        head=head_name,
        config_type=config_type,
        pbc=pbc,
        cell=cell,
    )
def test_config_types(
    test_configs: Configurations,
) -> List[Tuple[str, List[Configuration]]]:
    """Split test set based on config_type-s"""
    # Group configurations under the composite key "<config_type>_<head>".
    # Python dicts preserve insertion order, so groups appear in order of
    # first occurrence, matching the original list-based implementation.
    grouped: Dict[str, List[Configuration]] = {}
    for conf in test_configs:
        group_key = conf.config_type + "_" + conf.head
        grouped.setdefault(group_key, []).append(conf)
    return list(grouped.items())
def _backup_unsafe_ase_key(atoms_list, prop: str, getter) -> None:
    """Warn about an ASE-managed key and copy calculator results to 'REF_<prop>'.

    Since ASE 3.23.0b1 the plain 'energy'/'forces'/'stress' keys are managed by
    ASE calculators; values are extracted via *getter* and stored under
    'REF_<prop>' instead.  Extraction failures are logged and stored as None.
    """
    logging.warning(
        f"Since ASE version 3.23.0b1, using {prop}_key '{prop}' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting '{prop}' to 'REF_{prop}'. You need to use --{prop}_key='REF_{prop}' to specify the chosen key name."
    )
    for atoms in atoms_list:
        # Forces are per-atom (atoms.arrays); energy/stress are per-config (atoms.info).
        target = atoms.arrays if prop == "forces" else atoms.info
        try:
            target[f"REF_{prop}"] = getter(atoms)
        except Exception as e:  # pylint: disable=W0703
            # NOTE: the stress branch previously swallowed this error silently;
            # all properties now log extraction failures uniformly.
            logging.error(f"Failed to extract {prop}: {e}")
            target[f"REF_{prop}"] = None


def load_from_xyz(
    file_path: str,
    key_specification: KeySpecification,
    head_name: str = "Default",
    config_type_weights: Optional[Dict] = None,
    extract_atomic_energies: bool = False,
    keep_isolated_atoms: bool = False,
) -> Tuple[Dict[int, float], Configurations]:
    """Read an (ext)xyz file into Configurations.

    :param file_path: path passed to ase.io.read (all frames are read).
    :param key_specification: property-name -> atoms key mapping; NOTE this is
        mutated in place when unsafe ASE keys ('energy'/'forces'/'stress') are
        remapped to their 'REF_*' equivalents.
    :param head_name: head label written into every configuration.
    :param config_type_weights: per-config_type loss-weight multipliers.
    :param extract_atomic_energies: if True, single-atom frames marked
        config_type == 'IsolatedAtom' are used as per-element E0s.
    :param keep_isolated_atoms: if True, isolated-atom frames are also kept in
        the returned configurations.
    :return: (atomic_energies_dict, configurations).
    """
    atoms_list = ase.io.read(file_path, index=":")
    # Normalize to a list *before* any iteration below (ase.io.read can return
    # a single Atoms object for some index arguments).
    if not isinstance(atoms_list, list):
        atoms_list = [atoms_list]
    energy_key = key_specification.info_keys["energy"]
    forces_key = key_specification.arrays_keys["forces"]
    stress_key = key_specification.info_keys["stress"]
    head_key = key_specification.info_keys["head"]
    # Remap ASE-managed keys to 'REF_*' and back up the calculator values.
    if energy_key == "energy":
        key_specification.info_keys["energy"] = "REF_energy"
        _backup_unsafe_ase_key(
            atoms_list, "energy", lambda atoms: atoms.get_potential_energy()
        )
    if forces_key == "forces":
        key_specification.arrays_keys["forces"] = "REF_forces"
        _backup_unsafe_ase_key(atoms_list, "forces", lambda atoms: atoms.get_forces())
    if stress_key == "stress":
        key_specification.info_keys["stress"] = "REF_stress"
        _backup_unsafe_ase_key(atoms_list, "stress", lambda atoms: atoms.get_stress())
    atomic_energies_dict: Dict[int, float] = {}
    if extract_atomic_energies:
        atoms_without_iso_atoms = []
        for idx, atoms in enumerate(atoms_list):
            atoms.info[head_key] = head_name
            isolated_atom_config = (
                len(atoms) == 1 and atoms.info.get("config_type") == "IsolatedAtom"
            )
            if isolated_atom_config:
                atomic_number = int(atoms.get_atomic_numbers()[0])
                # The original (pre-remap) energy key is checked here, matching
                # previous behavior.
                if energy_key in atoms.info:
                    atomic_energies_dict[atomic_number] = float(atoms.info[energy_key])
                else:
                    logging.warning(
                        f"Configuration '{idx}' is marked as 'IsolatedAtom' "
                        "but does not contain an energy. Zero energy will be used."
                    )
                    atomic_energies_dict[atomic_number] = 0.0
            else:
                atoms_without_iso_atoms.append(atoms)
        if len(atomic_energies_dict) > 0:
            logging.info("Using isolated atom energies from training file")
        if not keep_isolated_atoms:
            atoms_list = atoms_without_iso_atoms
    # Tag every remaining frame with the head name (idempotent for frames
    # already tagged in the extraction loop above).
    for atoms in atoms_list:
        atoms.info[head_key] = head_name
    configs = config_from_atoms_list(
        atoms_list,
        config_type_weights=config_type_weights,
        key_specification=key_specification,
        head_name=head_name,
    )
    return atomic_energies_dict, configs
def compute_average_E0s(
    collections_train: "Configurations", z_table: "AtomicNumberTable"
) -> Dict[int, float]:
    """
    Function to compute the average interaction energy of each chemical element
    returns dictionary of E0s

    Solves the linear least-squares problem A @ E0 = B, where A[i, j] counts
    atoms of element z_table.zs[j] in configuration i and B[i] is that
    configuration's total energy.  Falls back to all-zero E0s if the solve
    fails.  (Annotations are string literals so the function can be imported
    without the alias definitions; resolved types are unchanged.)
    """
    len_train = len(collections_train)
    len_zs = len(z_table)
    A = np.zeros((len_train, len_zs))
    B = np.zeros(len_train)
    for i, config in enumerate(collections_train):
        B[i] = config.properties["energy"]
        for j, z in enumerate(z_table.zs):
            A[i, j] = np.count_nonzero(config.atomic_numbers == z)
    try:
        E0s = np.linalg.lstsq(A, B, rcond=None)[0]
        return {z: E0s[i] for i, z in enumerate(z_table.zs)}
    except np.linalg.LinAlgError:
        logging.error(
            "Failed to compute E0s using least squares regression, using the same for all atoms"
        )
        # Zero reference energy for every element as a safe fallback.
        return {z: 0.0 for z in z_table.zs}
def save_dataset_as_HDF5(dataset: List, out_name: str) -> None:
    """Write each entry of *dataset* into its own group of a new HDF5 file."""
    with h5py.File(out_name, "w") as h5f:
        for index, datum in enumerate(dataset):
            save_AtomicData_to_HDF5(datum, index, h5f)
def save_AtomicData_to_HDF5(data, i, h5_file) -> None:
    """Serialize one AtomicData object into group 'config_<i>' of an HDF5 file.

    Each stored dataset name matches the attribute name it is read from.
    """
    group = h5_file.create_group(f"config_{i}")
    field_names = (
        "num_nodes",
        "edge_index",
        "positions",
        "shifts",
        "unit_shifts",
        "cell",
        "node_attrs",
        "weight",
        "energy_weight",
        "forces_weight",
        "stress_weight",
        "virials_weight",
        "forces",
        "energy",
        "stress",
        "virials",
        "dipole",
        "charges",
        "head",
    )
    for name in field_names:
        group[name] = getattr(data, name)
def save_configurations_as_HDF5(configurations: Configurations, _, h5_file) -> None:
    """Write a batch of Configuration objects under group 'config_batch_0'.

    None-valued fields are stored as the string "None" via write_value.
    """
    batch_grp = h5_file.create_group("config_batch_0")
    for index, config in enumerate(configurations):
        cfg_grp = batch_grp.create_group(f"config_{index}")
        cfg_grp["atomic_numbers"] = write_value(config.atomic_numbers)
        cfg_grp["positions"] = write_value(config.positions)
        props_grp = cfg_grp.create_group("properties")
        for prop_name, prop_value in config.properties.items():
            props_grp[prop_name] = write_value(prop_value)
        cfg_grp["cell"] = write_value(config.cell)
        cfg_grp["pbc"] = write_value(config.pbc)
        cfg_grp["weight"] = write_value(config.weight)
        weights_grp = cfg_grp.create_group("property_weights")
        for prop_name, prop_weight in config.property_weights.items():
            weights_grp[prop_name] = write_value(prop_weight)
        cfg_grp["config_type"] = write_value(config.config_type)
def write_value(value):
    """Return *value* unchanged, substituting the string "None" for None so it can be stored in HDF5."""
    return "None" if value is None else value
from typing import Callable, Dict, Optional, Type from typing import Callable, Dict, Optional, Type
import torch import torch
from .blocks import ( from .blocks import (
AtomicEnergiesBlock, AtomicEnergiesBlock,
EquivariantProductBasisBlock, EquivariantProductBasisBlock,
InteractionBlock, InteractionBlock,
LinearDipoleReadoutBlock, LinearDipoleReadoutBlock,
LinearNodeEmbeddingBlock, LinearNodeEmbeddingBlock,
LinearReadoutBlock, LinearReadoutBlock,
NonLinearDipoleReadoutBlock, NonLinearDipoleReadoutBlock,
NonLinearReadoutBlock, NonLinearReadoutBlock,
RadialEmbeddingBlock, RadialEmbeddingBlock,
RealAgnosticAttResidualInteractionBlock, RealAgnosticAttResidualInteractionBlock,
RealAgnosticDensityInteractionBlock, RealAgnosticDensityInteractionBlock,
RealAgnosticDensityResidualInteractionBlock, RealAgnosticDensityResidualInteractionBlock,
RealAgnosticInteractionBlock, RealAgnosticInteractionBlock,
RealAgnosticResidualInteractionBlock, RealAgnosticResidualInteractionBlock,
ScaleShiftBlock, ScaleShiftBlock,
) )
from .loss import ( from .loss import (
DipoleSingleLoss, DipoleSingleLoss,
UniversalLoss, UniversalLoss,
WeightedEnergyForcesDipoleLoss, WeightedEnergyForcesDipoleLoss,
WeightedEnergyForcesL1L2Loss, WeightedEnergyForcesL1L2Loss,
WeightedEnergyForcesLoss, WeightedEnergyForcesLoss,
WeightedEnergyForcesStressLoss, WeightedEnergyForcesStressLoss,
WeightedEnergyForcesVirialsLoss, WeightedEnergyForcesVirialsLoss,
WeightedForcesLoss, WeightedForcesLoss,
WeightedHuberEnergyForcesStressLoss, WeightedHuberEnergyForcesStressLoss,
) )
from .models import MACE, AtomicDipolesMACE, EnergyDipolesMACE, ScaleShiftMACE from .models import MACE, AtomicDipolesMACE, EnergyDipolesMACE, ScaleShiftMACE
from .radial import BesselBasis, GaussianBasis, PolynomialCutoff, ZBLBasis from .radial import BesselBasis, GaussianBasis, PolynomialCutoff, ZBLBasis
from .symmetric_contraction import SymmetricContraction from .symmetric_contraction import SymmetricContraction
from .utils import ( from .utils import (
compute_avg_num_neighbors, compute_avg_num_neighbors,
compute_fixed_charge_dipole, compute_fixed_charge_dipole,
compute_mean_rms_energy_forces, compute_mean_rms_energy_forces,
compute_mean_std_atomic_inter_energy, compute_mean_std_atomic_inter_energy,
compute_rms_dipoles, compute_rms_dipoles,
compute_statistics, compute_statistics,
) )
# Registry mapping string names to interaction block classes.
interaction_classes: Dict[str, Type[InteractionBlock]] = {
    "RealAgnosticResidualInteractionBlock": RealAgnosticResidualInteractionBlock,
    "RealAgnosticAttResidualInteractionBlock": RealAgnosticAttResidualInteractionBlock,
    "RealAgnosticInteractionBlock": RealAgnosticInteractionBlock,
    "RealAgnosticDensityInteractionBlock": RealAgnosticDensityInteractionBlock,
    "RealAgnosticDensityResidualInteractionBlock": RealAgnosticDensityResidualInteractionBlock,
}
# Registry mapping string names to the statistics routines used for scaling.
scaling_classes: Dict[str, Callable] = {
    "std_scaling": compute_mean_std_atomic_inter_energy,
    "rms_forces_scaling": compute_mean_rms_energy_forces,
    "rms_dipoles_scaling": compute_rms_dipoles,
}
# Gate non-linearities selectable by name; the string "None" maps to no gate.
gate_dict: Dict[str, Optional[Callable]] = {
    "abs": torch.abs,
    "tanh": torch.tanh,
    "silu": torch.nn.functional.silu,
    "None": None,
}
# Explicit public API of this module.
__all__ = [
    "AtomicEnergiesBlock",
    "RadialEmbeddingBlock",
    "ZBLBasis",
    "LinearNodeEmbeddingBlock",
    "LinearReadoutBlock",
    "EquivariantProductBasisBlock",
    "ScaleShiftBlock",
    "LinearDipoleReadoutBlock",
    "NonLinearDipoleReadoutBlock",
    "InteractionBlock",
    "NonLinearReadoutBlock",
    "PolynomialCutoff",
    "BesselBasis",
    "GaussianBasis",
    "MACE",
    "ScaleShiftMACE",
    "AtomicDipolesMACE",
    "EnergyDipolesMACE",
    "WeightedEnergyForcesLoss",
    "WeightedForcesLoss",
    "WeightedEnergyForcesVirialsLoss",
    "WeightedEnergyForcesStressLoss",
    "DipoleSingleLoss",
    "WeightedEnergyForcesDipoleLoss",
    "WeightedHuberEnergyForcesStressLoss",
    "UniversalLoss",
    "WeightedEnergyForcesL1L2Loss",
    "SymmetricContraction",
    "interaction_classes",
    "compute_mean_std_atomic_inter_energy",
    "compute_avg_num_neighbors",
    "compute_statistics",
    "compute_fixed_charge_dipole",
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment