from collections.abc import Sequence
from typing import List
import numpy as np
import torch
from torch import Tensor
from .data import Data
from .dataset import IndexType
class Batch(Data):
r"""A plain old python object modeling a batch of graphs as one big
(disconnected) graph. With :class:`torch_geometric.data.Data` being the
base class, all its methods can also be used here.
In addition, single graphs can be reconstructed via the assignment vector
:obj:`batch`, which maps each node to its respective graph identifier.
"""
def __init__(self, batch=None, ptr=None, **kwargs):
super(Batch, self).__init__(**kwargs)
for key, item in kwargs.items():
if key == "num_nodes":
self.__num_nodes__ = item
else:
self[key] = item
self.batch = batch
self.ptr = ptr
self.__data_class__ = Data
self.__slices__ = None
self.__cumsum__ = None
self.__cat_dims__ = None
self.__num_nodes_list__ = None
self.__num_graphs__ = None
@classmethod
def from_data_list(cls, data_list, follow_batch=[], exclude_keys=[]):
r"""Constructs a batch object from a python list holding
:class:`torch_geometric.data.Data` objects.
The assignment vector :obj:`batch` is created on the fly.
Additionally, creates assignment batch vectors for each key in
:obj:`follow_batch`.
Will exclude any keys given in :obj:`exclude_keys`."""
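# Illustrative sketch (not part of the original code): collating two small
# graphs, assuming `Data` objects with `x` and `edge_index`. Node indices of
# the second graph are shifted by the node count of the first via `__inc__`:
#
#   d1 = Data(x=torch.zeros(2, 3), edge_index=torch.tensor([[0, 1], [1, 0]]))
#   d2 = Data(x=torch.zeros(3, 3), edge_index=torch.tensor([[0, 2], [2, 0]]))
#   b = Batch.from_data_list([d1, d2])
#   # b.x.shape == (5, 3)
#   # b.edge_index == tensor([[0, 1, 2, 4], [1, 0, 4, 2]])
#   # b.batch == tensor([0, 0, 1, 1, 1]); b.ptr == tensor([0, 2, 5])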
keys = list(set(data_list[0].keys) - set(exclude_keys))
assert "batch" not in keys and "ptr" not in keys
batch = cls()
for key in data_list[0].__dict__.keys():
if key[:2] != "__" and key[-2:] != "__":
batch[key] = None
batch.__num_graphs__ = len(data_list)
batch.__data_class__ = data_list[0].__class__
for key in keys + ["batch"]:
batch[key] = []
batch["ptr"] = [0]
device = None
slices = {key: [0] for key in keys}
cumsum = {key: [0] for key in keys}
cat_dims = {}
num_nodes_list = []
for i, data in enumerate(data_list):
for key in keys:
item = data[key]
# Increase values by `cumsum` value.
cum = cumsum[key][-1]
if isinstance(item, Tensor) and item.dtype != torch.bool:
if not isinstance(cum, int) or cum != 0:
item = item + cum
elif isinstance(item, (int, float)):
item = item + cum
# Gather the size of the `cat` dimension.
size = 1
cat_dim = data.__cat_dim__(key, data[key])
# 0-dimensional tensors have no dimension along which to
# concatenate, so we set `cat_dim` to `None`.
if isinstance(item, Tensor) and item.dim() == 0:
cat_dim = None
cat_dims[key] = cat_dim
# Add a batch dimension to items whose `cat_dim` is `None`:
if isinstance(item, Tensor) and cat_dim is None:
cat_dim = 0 # Concatenate along this new batch dimension.
item = item.unsqueeze(0)
device = item.device
elif isinstance(item, Tensor):
size = item.size(cat_dim)
device = item.device
batch[key].append(item) # Append item to the attribute list.
slices[key].append(size + slices[key][-1])
inc = data.__inc__(key, item)
if isinstance(inc, (tuple, list)):
inc = torch.tensor(inc)
cumsum[key].append(inc + cumsum[key][-1])
if key in follow_batch:
if isinstance(size, Tensor):
for j, size in enumerate(size.tolist()):
tmp = f"{key}_{j}_batch"
batch[tmp] = [] if i == 0 else batch[tmp]
batch[tmp].append(
torch.full((size,), i, dtype=torch.long, device=device)
)
else:
tmp = f"{key}_batch"
batch[tmp] = [] if i == 0 else batch[tmp]
batch[tmp].append(
torch.full((size,), i, dtype=torch.long, device=device)
)
if hasattr(data, "__num_nodes__"):
num_nodes_list.append(data.__num_nodes__)
else:
num_nodes_list.append(None)
num_nodes = data.num_nodes
if num_nodes is not None:
item = torch.full((num_nodes,), i, dtype=torch.long, device=device)
batch.batch.append(item)
batch.ptr.append(batch.ptr[-1] + num_nodes)
batch.batch = None if len(batch.batch) == 0 else batch.batch
batch.ptr = None if len(batch.ptr) == 1 else batch.ptr
batch.__slices__ = slices
batch.__cumsum__ = cumsum
batch.__cat_dims__ = cat_dims
batch.__num_nodes_list__ = num_nodes_list
ref_data = data_list[0]
for key in batch.keys:
items = batch[key]
item = items[0]
cat_dim = ref_data.__cat_dim__(key, item)
cat_dim = 0 if cat_dim is None else cat_dim
if isinstance(item, Tensor):
batch[key] = torch.cat(items, cat_dim)
elif isinstance(item, (int, float)):
batch[key] = torch.tensor(items)
# if torch_geometric.is_debug_enabled():
# batch.debug()
return batch.contiguous()
def get_example(self, idx: int) -> Data:
r"""Reconstructs the :class:`torch_geometric.data.Data` object at index
:obj:`idx` from the batch object.
The batch object must have been created via :meth:`from_data_list` in
order to be able to reconstruct the initial objects."""
if self.__slices__ is None:
raise RuntimeError(
(
"Cannot reconstruct data list from batch because the batch "
"object was not created using `Batch.from_data_list()`."
)
)
data = self.__data_class__()
idx = self.num_graphs + idx if idx < 0 else idx
for key in self.__slices__.keys():
item = self[key]
if self.__cat_dims__[key] is None:
# The item was concatenated along a new batch dimension,
# so just index in that dimension:
item = item[idx]
else:
# Narrow the item based on the values in `__slices__`.
if isinstance(item, Tensor):
dim = self.__cat_dims__[key]
start = self.__slices__[key][idx]
end = self.__slices__[key][idx + 1]
item = item.narrow(dim, start, end - start)
else:
start = self.__slices__[key][idx]
end = self.__slices__[key][idx + 1]
item = item[start:end]
item = item[0] if len(item) == 1 else item
# Decrease its value by `cumsum` value:
cum = self.__cumsum__[key][idx]
if isinstance(item, Tensor):
if not isinstance(cum, int) or cum != 0:
item = item - cum
elif isinstance(item, (int, float)):
item = item - cum
data[key] = item
if self.__num_nodes_list__[idx] is not None:
data.num_nodes = self.__num_nodes_list__[idx]
return data
def index_select(self, idx: IndexType) -> List[Data]:
if isinstance(idx, slice):
idx = list(range(self.num_graphs)[idx])
elif isinstance(idx, Tensor) and idx.dtype == torch.long:
idx = idx.flatten().tolist()
elif isinstance(idx, Tensor) and idx.dtype == torch.bool:
idx = idx.flatten().nonzero(as_tuple=False).flatten().tolist()
elif isinstance(idx, np.ndarray) and idx.dtype == np.int64:
idx = idx.flatten().tolist()
elif isinstance(idx, np.ndarray) and idx.dtype == bool:
idx = idx.flatten().nonzero()[0].flatten().tolist()
elif isinstance(idx, Sequence) and not isinstance(idx, str):
pass
else:
raise IndexError(
f"Only integers, slices (':'), list, tuples, torch.tensor and "
f"np.ndarray of dtype long or bool are valid indices (got "
f"'{type(idx).__name__}')"
)
return [self.get_example(i) for i in idx]
def __getitem__(self, idx):
if isinstance(idx, str):
return super(Batch, self).__getitem__(idx)
elif isinstance(idx, (int, np.integer)):
return self.get_example(idx)
else:
return self.index_select(idx)
def to_data_list(self) -> List[Data]:
r"""Reconstructs the list of :class:`torch_geometric.data.Data` objects
from the batch object.
The batch object must have been created via :meth:`from_data_list` in
order to be able to reconstruct the initial objects."""
return [self.get_example(i) for i in range(self.num_graphs)]
@property
def num_graphs(self) -> int:
"""Returns the number of graphs in the batch."""
if self.__num_graphs__ is not None:
return self.__num_graphs__
elif self.ptr is not None:
return self.ptr.numel() - 1
elif self.batch is not None:
return int(self.batch.max()) + 1
else:
raise ValueError("Cannot infer the number of graphs: neither 'ptr' nor 'batch' is set.")
import collections
import copy
import re
import torch
# from ..utils.num_nodes import maybe_num_nodes
__num_nodes_warn_msg__ = (
"The number of nodes in your data object can only be inferred by its {} "
"indices, and hence may result in unexpected batch-wise behavior, e.g., "
"in case there exists isolated nodes. Please consider explicitly setting "
"the number of nodes for this data object by assigning it to "
"data.num_nodes."
)
def size_repr(key, item, indent=0):
indent_str = " " * indent
if torch.is_tensor(item) and item.dim() == 0:
out = item.item()
elif torch.is_tensor(item):
out = str(list(item.size()))
elif isinstance(item, list) or isinstance(item, tuple):
out = str([len(item)])
elif isinstance(item, dict):
lines = [indent_str + size_repr(k, v, 2) for k, v in item.items()]
out = "{\n" + ",\n".join(lines) + "\n" + indent_str + "}"
elif isinstance(item, str):
out = f'"{item}"'
else:
out = str(item)
return f"{indent_str}{key}={out}"
class Data(object):
r"""A plain old python object modeling a single graph with various
(optional) attributes:
Args:
x (Tensor, optional): Node feature matrix with shape :obj:`[num_nodes,
num_node_features]`. (default: :obj:`None`)
edge_index (LongTensor, optional): Graph connectivity in COO format
with shape :obj:`[2, num_edges]`. (default: :obj:`None`)
edge_attr (Tensor, optional): Edge feature matrix with shape
:obj:`[num_edges, num_edge_features]`. (default: :obj:`None`)
y (Tensor, optional): Graph or node targets with arbitrary shape.
(default: :obj:`None`)
pos (Tensor, optional): Node position matrix with shape
:obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`)
normal (Tensor, optional): Normal vector matrix with shape
:obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`)
face (LongTensor, optional): Face adjacency matrix with shape
:obj:`[3, num_faces]`. (default: :obj:`None`)
The data object is not restricted to these attributes and can be extended
by any other additional data.
Example::
data = Data(x=x, edge_index=edge_index)
data.train_idx = torch.tensor([...], dtype=torch.long)
data.test_mask = torch.tensor([...], dtype=torch.bool)
"""
def __init__(
self,
x=None,
edge_index=None,
edge_attr=None,
y=None,
pos=None,
normal=None,
face=None,
**kwargs,
):
self.x = x
self.edge_index = edge_index
self.edge_attr = edge_attr
self.y = y
self.pos = pos
self.normal = normal
self.face = face
for key, item in kwargs.items():
if key == "num_nodes":
self.__num_nodes__ = item
else:
self[key] = item
if edge_index is not None and edge_index.dtype != torch.long:
raise ValueError(
(
f"Argument `edge_index` needs to be of type `torch.long` but "
f"found type `{edge_index.dtype}`."
)
)
if face is not None and face.dtype != torch.long:
raise ValueError(
(
f"Argument `face` needs to be of type `torch.long` but found "
f"type `{face.dtype}`."
)
)
@classmethod
def from_dict(cls, dictionary):
r"""Creates a data object from a python dictionary."""
data = cls()
for key, item in dictionary.items():
data[key] = item
return data
def to_dict(self):
return {key: item for key, item in self}
def to_namedtuple(self):
keys = self.keys
DataTuple = collections.namedtuple("DataTuple", keys)
return DataTuple(*[self[key] for key in keys])
def __getitem__(self, key):
r"""Gets the data of the attribute :obj:`key`."""
return getattr(self, key, None)
def __setitem__(self, key, value):
"""Sets the attribute :obj:`key` to :obj:`value`."""
setattr(self, key, value)
def __delitem__(self, key):
r"""Delete the data of the attribute :obj:`key`."""
return delattr(self, key)
@property
def keys(self):
r"""Returns all names of graph attributes."""
keys = [key for key in self.__dict__.keys() if self[key] is not None]
keys = [key for key in keys if key[:2] != "__" and key[-2:] != "__"]
return keys
def __len__(self):
r"""Returns the number of all present attributes."""
return len(self.keys)
def __contains__(self, key):
r"""Returns :obj:`True`, if the attribute :obj:`key` is present in the
data."""
return key in self.keys
def __iter__(self):
r"""Iterates over all present attributes in the data, yielding their
attribute names and content."""
for key in sorted(self.keys):
yield key, self[key]
def __call__(self, *keys):
r"""Iterates over all attributes :obj:`*keys` in the data, yielding
their attribute names and content.
If :obj:`*keys` is not given this method will iterative over all
present attributes."""
for key in sorted(self.keys) if not keys else keys:
if key in self:
yield key, self[key]
def __cat_dim__(self, key, value):
r"""Returns the dimension for which :obj:`value` of attribute
:obj:`key` will get concatenated when creating batches.
.. note::
This method is for internal use only, and should only be overridden
if the batch concatenation process is corrupted for a specific data
attribute.
"""
if bool(re.search("(index|face)", key)):
return -1
return 0
def __inc__(self, key, value):
r"""Returns the incremental count to cumulatively increase the value
of the next attribute of :obj:`key` when creating batches.
.. note::
This method is for internal use only, and should only be overridden
if the batch concatenation process is corrupted for a specific data
attribute.
"""
# Only `*index*` and `*face*` attributes should be cumulatively summed
# up when creating batches.
return self.num_nodes if bool(re.search("(index|face)", key)) else 0
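# Illustrative sketch (hypothetical subclass, not part of the original code):
# attributes whose names contain neither "index" nor "face" are not shifted
# and are concatenated along dimension 0, so a custom index-like attribute
# should override both hooks, e.g.:
#
#   class PairData(Data):
#       def __inc__(self, key, value):
#           if key == "assignment":  # holds node indices
#               return self.num_nodes
#           return super().__inc__(key, value)
#       def __cat_dim__(self, key, value):
#           if key == "assignment":
#               return -1
#           return super().__cat_dim__(key, value)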
@property
def num_nodes(self):
r"""Returns or sets the number of nodes in the graph.
.. note::
The number of nodes in your data object is typically automatically
inferred, *e.g.*, when node features :obj:`x` are present.
In some cases however, a graph may only be given by its edge
indices :obj:`edge_index`.
PyTorch Geometric then *guesses* the number of nodes
according to :obj:`edge_index.max().item() + 1`, but in case there
exist isolated nodes, this number may not be correct and can
therefore result in unexpected batch-wise behavior.
Thus, we recommend setting the number of nodes in your data object
explicitly via :obj:`data.num_nodes = ...`.
You will be given a warning that requests you to do so.
"""
if hasattr(self, "__num_nodes__"):
return self.__num_nodes__
for key, item in self("x", "pos", "normal", "batch"):
return item.size(self.__cat_dim__(key, item))
if hasattr(self, "adj"):
return self.adj.size(0)
if hasattr(self, "adj_t"):
return self.adj_t.size(1)
# if self.face is not None:
# logging.warning(__num_nodes_warn_msg__.format("face"))
# return maybe_num_nodes(self.face)
# if self.edge_index is not None:
# logging.warning(__num_nodes_warn_msg__.format("edge"))
# return maybe_num_nodes(self.edge_index)
return None
@num_nodes.setter
def num_nodes(self, num_nodes):
self.__num_nodes__ = num_nodes
@property
def num_edges(self):
"""
Returns the number of edges in the graph.
For undirected graphs, this returns the number of bi-directional
edges, which is twice the number of unique edges.
"""
for key, item in self("edge_index", "edge_attr"):
return item.size(self.__cat_dim__(key, item))
for key, item in self("adj", "adj_t"):
return item.nnz()
return None
@property
def num_faces(self):
r"""Returns the number of faces in the mesh."""
if self.face is not None:
return self.face.size(self.__cat_dim__("face", self.face))
return None
@property
def num_node_features(self):
r"""Returns the number of features per node in the graph."""
if self.x is None:
return 0
return 1 if self.x.dim() == 1 else self.x.size(1)
@property
def num_features(self):
r"""Alias for :py:attr:`~num_node_features`."""
return self.num_node_features
@property
def num_edge_features(self):
r"""Returns the number of features per edge in the graph."""
if self.edge_attr is None:
return 0
return 1 if self.edge_attr.dim() == 1 else self.edge_attr.size(1)
def __apply__(self, item, func):
if torch.is_tensor(item):
return func(item)
elif isinstance(item, (tuple, list)):
return [self.__apply__(v, func) for v in item]
elif isinstance(item, dict):
return {k: self.__apply__(v, func) for k, v in item.items()}
else:
return item
def apply(self, func, *keys):
r"""Applies the function :obj:`func` to all tensor attributes
:obj:`*keys`. If :obj:`*keys` is not given, :obj:`func` is applied to
all present attributes.
"""
for key, item in self(*keys):
self[key] = self.__apply__(item, func)
return self
def contiguous(self, *keys):
r"""Ensures a contiguous memory layout for all attributes :obj:`*keys`.
If :obj:`*keys` is not given, all present attributes are ensured to
have a contiguous memory layout."""
return self.apply(lambda x: x.contiguous(), *keys)
def to(self, device, *keys, **kwargs):
r"""Performs tensor dtype and/or device conversion to all attributes
:obj:`*keys`.
If :obj:`*keys` is not given, the conversion is applied to all present
attributes."""
return self.apply(lambda x: x.to(device, **kwargs), *keys)
def cpu(self, *keys):
r"""Copies all attributes :obj:`*keys` to CPU memory.
If :obj:`*keys` is not given, the conversion is applied to all present
attributes."""
return self.apply(lambda x: x.cpu(), *keys)
def cuda(self, device=None, non_blocking=False, *keys):
r"""Copies all attributes :obj:`*keys` to CUDA memory.
If :obj:`*keys` is not given, the conversion is applied to all present
attributes."""
return self.apply(
lambda x: x.cuda(device=device, non_blocking=non_blocking), *keys
)
def clone(self):
r"""Performs a deep-copy of the data object."""
return self.__class__.from_dict(
{
k: v.clone() if torch.is_tensor(v) else copy.deepcopy(v)
for k, v in self.__dict__.items()
}
)
def pin_memory(self, *keys):
r"""Copies all attributes :obj:`*keys` to pinned memory.
If :obj:`*keys` is not given, the conversion is applied to all present
attributes."""
return self.apply(lambda x: x.pin_memory(), *keys)
def debug(self):
if self.edge_index is not None:
if self.edge_index.dtype != torch.long:
raise RuntimeError(
(
"Expected edge indices of dtype {}, but found dtype " " {}"
).format(torch.long, self.edge_index.dtype)
)
if self.face is not None:
if self.face.dtype != torch.long:
raise RuntimeError(
(
"Expected face indices of dtype {}, but found dtype " " {}"
).format(torch.long, self.face.dtype)
)
if self.edge_index is not None:
if self.edge_index.dim() != 2 or self.edge_index.size(0) != 2:
raise RuntimeError(
(
"Edge indices should have shape [2, num_edges] but found"
" shape {}"
).format(self.edge_index.size())
)
if self.edge_index is not None and self.num_nodes is not None:
if self.edge_index.numel() > 0:
min_index = self.edge_index.min()
max_index = self.edge_index.max()
else:
min_index = max_index = 0
if min_index < 0 or max_index > self.num_nodes - 1:
raise RuntimeError(
(
"Edge indices must lay in the interval [0, {}]"
" but found them in the interval [{}, {}]"
).format(self.num_nodes - 1, min_index, max_index)
)
if self.face is not None:
if self.face.dim() != 2 or self.face.size(0) != 3:
raise RuntimeError(
(
"Face indices should have shape [3, num_faces] but found"
" shape {}"
).format(self.face.size())
)
if self.face is not None and self.num_nodes is not None:
if self.face.numel() > 0:
min_index = self.face.min()
max_index = self.face.max()
else:
min_index = max_index = 0
if min_index < 0 or max_index > self.num_nodes - 1:
raise RuntimeError(
(
"Face indices must lay in the interval [0, {}]"
" but found them in the interval [{}, {}]"
).format(self.num_nodes - 1, min_index, max_index)
)
if self.edge_index is not None and self.edge_attr is not None:
if self.edge_index.size(1) != self.edge_attr.size(0):
raise RuntimeError(
(
"Edge indices and edge attributes hold a differing "
"number of edges, found {} and {}"
).format(self.edge_index.size(), self.edge_attr.size())
)
if self.x is not None and self.num_nodes is not None:
if self.x.size(0) != self.num_nodes:
raise RuntimeError(
(
"Node features should hold {} elements in the first "
"dimension but found {}"
).format(self.num_nodes, self.x.size(0))
)
if self.pos is not None and self.num_nodes is not None:
if self.pos.size(0) != self.num_nodes:
raise RuntimeError(
(
"Node positions should hold {} elements in the first "
"dimension but found {}"
).format(self.num_nodes, self.pos.size(0))
)
if self.normal is not None and self.num_nodes is not None:
if self.normal.size(0) != self.num_nodes:
raise RuntimeError(
(
"Node normals should hold {} elements in the first "
"dimension but found {}"
).format(self.num_nodes, self.normal.size(0))
)
def __repr__(self):
cls = str(self.__class__.__name__)
has_dict = any([isinstance(item, dict) for _, item in self])
if not has_dict:
info = [size_repr(key, item) for key, item in self]
return "{}({})".format(cls, ", ".join(info))
else:
info = [size_repr(key, item, indent=2) for key, item in self]
return "{}(\n{}\n)".format(cls, ",\n".join(info))
from collections.abc import Mapping, Sequence
from typing import List, Optional, Union
import torch.utils.data
from torch.utils.data.dataloader import default_collate
from .batch import Batch
from .data import Data
from .dataset import Dataset
class Collater:
def __init__(self, follow_batch, exclude_keys):
self.follow_batch = follow_batch
self.exclude_keys = exclude_keys
def __call__(self, batch):
elem = batch[0]
if isinstance(elem, Data):
return Batch.from_data_list(
batch,
follow_batch=self.follow_batch,
exclude_keys=self.exclude_keys,
)
elif isinstance(elem, torch.Tensor):
return default_collate(batch)
elif isinstance(elem, float):
return torch.tensor(batch, dtype=torch.float)
elif isinstance(elem, int):
return torch.tensor(batch)
elif isinstance(elem, str):
return batch
elif isinstance(elem, Mapping):
return {key: self([data[key] for data in batch]) for key in elem}
elif isinstance(elem, tuple) and hasattr(elem, "_fields"):
return type(elem)(*(self(s) for s in zip(*batch)))
elif isinstance(elem, Sequence) and not isinstance(elem, str):
return [self(s) for s in zip(*batch)]
raise TypeError(f"DataLoader found invalid type: {type(elem)}")
def collate(self, batch): # Deprecated...
return self(batch)
class DataLoader(torch.utils.data.DataLoader):
r"""A data loader which merges data objects from a
:class:`torch_geometric.data.Dataset` to a mini-batch.
Data objects can be either of type :class:`~torch_geometric.data.Data` or
:class:`~torch_geometric.data.HeteroData`.
Args:
dataset (Dataset): The dataset from which to load the data.
batch_size (int, optional): How many samples per batch to load.
(default: :obj:`1`)
shuffle (bool, optional): If set to :obj:`True`, the data will be
reshuffled at every epoch. (default: :obj:`False`)
follow_batch (List[str], optional): Creates assignment batch
vectors for each key in the list. (default: :obj:`None`)
exclude_keys (List[str], optional): Will exclude each key in the
list. (default: :obj:`None`)
**kwargs (optional): Additional arguments of
:class:`torch.utils.data.DataLoader`.
"""
def __init__(
self,
dataset: Dataset,
batch_size: int = 1,
shuffle: bool = False,
follow_batch: Optional[List[str]] = None,
exclude_keys: Optional[List[str]] = None,
**kwargs,
):
if "collate_fn" in kwargs:
del kwargs["collate_fn"]
# Save for PyTorch Lightning < 1.6:
self.follow_batch = follow_batch
self.exclude_keys = exclude_keys
super().__init__(
dataset,
batch_size,
shuffle,
collate_fn=Collater(follow_batch or [], exclude_keys or []),
**kwargs,
)
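# Illustrative usage sketch (assuming a `dataset` of `Data` objects that carry
# an `edge_attr` attribute):
#
#   loader = DataLoader(dataset, batch_size=32, shuffle=True,
#                       follow_batch=["edge_attr"])
#   for batch in loader:
#       # `batch` is a `Batch`; `batch.batch` maps nodes to graphs and
#       # `batch.edge_attr_batch` maps edge attributes to graphs.
#       ...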
import copy
import os.path as osp
import re
import warnings
from collections.abc import Sequence
from typing import Any, Callable, List, Optional, Tuple, Union
import numpy as np
import torch.utils.data
from torch import Tensor
from .data import Data
from .utils import makedirs
IndexType = Union[slice, Tensor, np.ndarray, Sequence]
class Dataset(torch.utils.data.Dataset):
r"""Dataset base class for creating graph datasets.
See `here <https://pytorch-geometric.readthedocs.io/en/latest/notes/
create_dataset.html>`__ for the accompanying tutorial.
Args:
root (string, optional): Root directory where the dataset should be
saved. (default: :obj:`None`)
transform (callable, optional): A function/transform that takes in an
:obj:`torch_geometric.data.Data` object and returns a transformed
version. The data object will be transformed before every access.
(default: :obj:`None`)
pre_transform (callable, optional): A function/transform that takes in
an :obj:`torch_geometric.data.Data` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
pre_filter (callable, optional): A function that takes in an
:obj:`torch_geometric.data.Data` object and returns a boolean
value, indicating whether the data object should be included in the
final dataset. (default: :obj:`None`)
"""
@property
def raw_file_names(self) -> Union[str, List[str], Tuple]:
r"""The name of the files to find in the :obj:`self.raw_dir` folder in
order to skip the download."""
raise NotImplementedError
@property
def processed_file_names(self) -> Union[str, List[str], Tuple]:
r"""The name of the files to find in the :obj:`self.processed_dir`
folder in order to skip the processing."""
raise NotImplementedError
def download(self):
r"""Downloads the dataset to the :obj:`self.raw_dir` folder."""
raise NotImplementedError
def process(self):
r"""Processes the dataset to the :obj:`self.processed_dir` folder."""
raise NotImplementedError
def len(self) -> int:
raise NotImplementedError
def get(self, idx: int) -> Data:
r"""Gets the data object at index :obj:`idx`."""
raise NotImplementedError
def __init__(
self,
root: Optional[str] = None,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
pre_filter: Optional[Callable] = None,
):
super().__init__()
if isinstance(root, str):
root = osp.expanduser(osp.normpath(root))
self.root = root
self.transform = transform
self.pre_transform = pre_transform
self.pre_filter = pre_filter
self._indices: Optional[Sequence] = None
if "download" in self.__class__.__dict__.keys():
self._download()
if "process" in self.__class__.__dict__.keys():
self._process()
def indices(self) -> Sequence:
return range(self.len()) if self._indices is None else self._indices
@property
def raw_dir(self) -> str:
return osp.join(self.root, "raw")
@property
def processed_dir(self) -> str:
return osp.join(self.root, "processed")
@property
def num_node_features(self) -> int:
r"""Returns the number of features per node in the dataset."""
data = self[0]
if hasattr(data, "num_node_features"):
return data.num_node_features
raise AttributeError(
f"'{data.__class__.__name__}' object has no "
f"attribute 'num_node_features'"
)
@property
def num_features(self) -> int:
r"""Alias for :py:attr:`~num_node_features`."""
return self.num_node_features
@property
def num_edge_features(self) -> int:
r"""Returns the number of features per edge in the dataset."""
data = self[0]
if hasattr(data, "num_edge_features"):
return data.num_edge_features
raise AttributeError(
f"'{data.__class__.__name__}' object has no "
f"attribute 'num_edge_features'"
)
@property
def raw_paths(self) -> List[str]:
r"""The filepaths to find in order to skip the download."""
files = to_list(self.raw_file_names)
return [osp.join(self.raw_dir, f) for f in files]
@property
def processed_paths(self) -> List[str]:
r"""The filepaths to find in the :obj:`self.processed_dir`
folder in order to skip the processing."""
files = to_list(self.processed_file_names)
return [osp.join(self.processed_dir, f) for f in files]
def _download(self):
if files_exist(self.raw_paths): # pragma: no cover
return
makedirs(self.raw_dir)
self.download()
def _process(self):
f = osp.join(self.processed_dir, "pre_transform.pt")
if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):
warnings.warn(
f"The `pre_transform` argument differs from the one used in "
f"the pre-processed version of this dataset. If you want to "
f"make use of another pre-processing technique, make sure to "
f"sure to delete '{self.processed_dir}' first"
)
f = osp.join(self.processed_dir, "pre_filter.pt")
if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):
warnings.warn(
"The `pre_filter` argument differs from the one used in the "
"pre-processed version of this dataset. If you want to make "
"use of another pre-fitering technique, make sure to delete "
"'{self.processed_dir}' first"
)
if files_exist(self.processed_paths): # pragma: no cover
return
print("Processing...")
makedirs(self.processed_dir)
self.process()
path = osp.join(self.processed_dir, "pre_transform.pt")
torch.save(_repr(self.pre_transform), path)
path = osp.join(self.processed_dir, "pre_filter.pt")
torch.save(_repr(self.pre_filter), path)
print("Done!")
def __len__(self) -> int:
r"""The number of examples in the dataset."""
return len(self.indices())
def __getitem__(
self,
idx: Union[int, np.integer, IndexType],
) -> Union["Dataset", Data]:
r"""In case :obj:`idx` is of type integer, will return the data object
at index :obj:`idx` (and transforms it in case :obj:`transform` is
present).
In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a
tuple, a PyTorch :obj:`LongTensor` or a :obj:`BoolTensor`, or a numpy
:obj:`np.array`, will return a subset of the dataset at the specified
indices."""
if (
isinstance(idx, (int, np.integer))
or (isinstance(idx, Tensor) and idx.dim() == 0)
or (isinstance(idx, np.ndarray) and np.isscalar(idx))
):
data = self.get(self.indices()[idx])
data = data if self.transform is None else self.transform(data)
return data
else:
return self.index_select(idx)
def index_select(self, idx: IndexType) -> "Dataset":
indices = self.indices()
if isinstance(idx, slice):
indices = indices[idx]
elif isinstance(idx, Tensor) and idx.dtype == torch.long:
return self.index_select(idx.flatten().tolist())
elif isinstance(idx, Tensor) and idx.dtype == torch.bool:
idx = idx.flatten().nonzero(as_tuple=False)
return self.index_select(idx.flatten().tolist())
elif isinstance(idx, np.ndarray) and idx.dtype == np.int64:
return self.index_select(idx.flatten().tolist())
elif isinstance(idx, np.ndarray) and idx.dtype == bool:
idx = idx.flatten().nonzero()[0]
return self.index_select(idx.flatten().tolist())
elif isinstance(idx, Sequence) and not isinstance(idx, str):
indices = [indices[i] for i in idx]
else:
raise IndexError(
f"Only integers, slices (':'), list, tuples, torch.tensor and "
f"np.ndarray of dtype long or bool are valid indices (got "
f"'{type(idx).__name__}')"
)
dataset = copy.copy(self)
dataset._indices = indices
return dataset
def shuffle(
self,
return_perm: bool = False,
) -> Union["Dataset", Tuple["Dataset", Tensor]]:
r"""Randomly shuffles the examples in the dataset.
Args:
return_perm (bool, optional): If set to :obj:`True`, will return
the random permutation used to shuffle the dataset in addition.
(default: :obj:`False`)
"""
perm = torch.randperm(len(self))
dataset = self.index_select(perm)
return (dataset, perm) if return_perm is True else dataset
def __repr__(self) -> str:
arg_repr = str(len(self)) if len(self) > 1 else ""
return f"{self.__class__.__name__}({arg_repr})"
def to_list(value: Any) -> Sequence:
if isinstance(value, Sequence) and not isinstance(value, str):
return value
else:
return [value]
def files_exist(files: List[str]) -> bool:
# NOTE: We return `False` in case `files` is empty, leading to a
# re-processing of files on every instantiation.
return len(files) != 0 and all([osp.exists(f) for f in files])
def _repr(obj: Any) -> str:
if obj is None:
return "None"
return re.sub("(<.*?)\\s.*(>)", r"\1\2", obj.__repr__())
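# `_repr` strips everything after the first whitespace inside angle brackets,
# so object reprs that embed memory addresses compare equal across runs, e.g.
# (illustrative): "<MyTransform object at 0x7f00>" -> "<MyTransform>".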
import random
import numpy as np
import torch
def seed_everything(seed: int):
r"""Sets the seed for generating random numbers in :pytorch:`PyTorch`,
:obj:`numpy` and Python.
Args:
seed (int): The desired seed.
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
import os
import os.path as osp
import ssl
import urllib
import zipfile
def makedirs(dir):
os.makedirs(dir, exist_ok=True)
def download_url(url, folder, log=True):
r"""Downloads the content of an URL to a specific folder.
Args:
url (string): The url.
folder (string): The folder.
log (bool, optional): If :obj:`False`, will not print anything to the
console. (default: :obj:`True`)
"""
filename = url.rpartition("/")[2].split("?")[0]
path = osp.join(folder, filename)
if osp.exists(path): # pragma: no cover
if log:
print("Using exist file", filename)
return path
if log:
print("Downloading", url)
makedirs(folder)
context = ssl._create_unverified_context()
data = urllib.request.urlopen(url, context=context)
with open(path, "wb") as f:
f.write(data.read())
return path
def extract_zip(path, folder, log=True):
r"""Extracts a zip archive to a specific folder.
Args:
path (string): The path to the zip archive.
folder (string): The folder.
log (bool, optional): If :obj:`False`, will not print anything to the
console. (default: :obj:`True`)
"""
with zipfile.ZipFile(path, "r") as f:
f.extractall(folder)
###########################################################################################
# Tools for torch
# Authors: Ilyes Batatia, Gregor Simm
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
import logging
from contextlib import contextmanager
from typing import Dict, Union
import numpy as np
import torch
from e3nn.io import CartesianTensor
TensorDict = Dict[str, torch.Tensor]
def to_one_hot(indices: torch.Tensor, num_classes: int) -> torch.Tensor:
"""
Generates one-hot encoding with <num_classes> classes from <indices>
:param indices: (N x 1) tensor
:param num_classes: number of classes
:return: (N x num_classes) tensor
"""
shape = indices.shape[:-1] + (num_classes,)
oh = torch.zeros(shape, device=indices.device)
# scatter_ is the in-place version of scatter
oh.scatter_(dim=-1, index=indices, value=1)
return oh
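# Illustrative sketch (assumed toy input, not from the original code): one-hot
# encoding three atoms whose element indices are 0, 2 and 1 out of 3 classes:
#
#   idx = torch.tensor([[0], [2], [1]])          # shape (3, 1)
#   to_one_hot(idx, num_classes=3)
#   # -> tensor([[1., 0., 0.],
#   #            [0., 0., 1.],
#   #            [0., 1., 0.]])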
def count_parameters(module: torch.nn.Module) -> int:
return int(sum(np.prod(p.shape) for p in module.parameters()))
def tensor_dict_to_device(td: TensorDict, device: torch.device) -> TensorDict:
return {k: v.to(device) if v is not None else None for k, v in td.items()}
def set_seeds(seed: int) -> None:
np.random.seed(seed)
torch.manual_seed(seed)
def to_numpy(t: torch.Tensor) -> np.ndarray:
return t.cpu().detach().numpy()
def init_device(device_str: str) -> torch.device:
if "cuda" in device_str:
assert torch.cuda.is_available(), "No CUDA device available!"
if ":" in device_str:
# Check if the desired device is available
assert int(device_str.split(":")[-1]) < torch.cuda.device_count()
logging.info(
f"CUDA version: {torch.version.cuda}, CUDA device: {torch.cuda.current_device()}"
)
torch.cuda.init()
return torch.device(device_str)
if device_str == "mps":
assert torch.backends.mps.is_available(), "No MPS backend is available!"
logging.info("Using MPS GPU acceleration")
return torch.device("mps")
if device_str == "xpu":
assert torch.xpu.is_available(), "No XPU device available!"
return torch.device("xpu")
logging.info("Using CPU")
return torch.device("cpu")
dtype_dict = {"float32": torch.float32, "float64": torch.float64}
def set_default_dtype(dtype: str) -> None:
torch.set_default_dtype(dtype_dict[dtype])
def spherical_to_cartesian(t: torch.Tensor):
"""
Convert spherical notation to cartesian notation
"""
stress_cart_tensor = CartesianTensor("ij=ji")
stress_rtp = stress_cart_tensor.reduced_tensor_products()
return stress_cart_tensor.to_cartesian(t, rtp=stress_rtp)
def cartesian_to_spherical(t: torch.Tensor):
"""
Convert cartesian notation to spherical notation
"""
stress_cart_tensor = CartesianTensor("ij=ji")
stress_rtp = stress_cart_tensor.reduced_tensor_products()
return stress_cart_tensor.from_cartesian(t, rtp=stress_rtp)
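# Hedged usage sketch (assuming e3nn's `CartesianTensor` API): a symmetric
# 3x3 stress tensor maps to a 6-dimensional 0e+2e irrep vector and back:
#
#   s = torch.randn(3, 3)
#   s = 0.5 * (s + s.T)                      # symmetrize
#   irreps = cartesian_to_spherical(s)       # shape (6,)
#   s_back = spherical_to_cartesian(irreps)  # ~= s up to numerical error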
def voigt_to_matrix(t: torch.Tensor):
"""
Convert voigt notation to matrix notation
:param t: (6,) tensor or (3, 3) tensor or (9,) tensor
:return: (3, 3) tensor
"""
if t.shape == (3, 3):
return t
if t.shape == (6,):
return torch.tensor(
[
[t[0], t[5], t[4]],
[t[5], t[1], t[3]],
[t[4], t[3], t[2]],
],
dtype=t.dtype,
device=t.device,
)
if t.shape == (9,):
return t.view(3, 3)
raise ValueError(
f"Stress tensor must be of shape (6,) or (3, 3), or (9,) but has shape {t.shape}"
)
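# For reference, the Voigt ordering assumed above is (xx, yy, zz, yz, xz, xy),
# e.g. (illustrative):
#
#   voigt_to_matrix(torch.tensor([1., 2., 3., 4., 5., 6.]))
#   # -> tensor([[1., 6., 5.],
#   #            [6., 2., 4.],
#   #            [5., 4., 3.]])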
def init_wandb(project: str, entity: str, name: str, config: dict, directory: str):
import wandb
wandb.init(
project=project,
entity=entity,
name=name,
config=config,
dir=directory,
resume="allow",
)
@contextmanager
def default_dtype(dtype: Union[torch.dtype, str]):
"""Context manager for configuring the default_dtype used by torch
Args:
dtype (torch.dtype|str): the default dtype to use within this context manager
"""
init = torch.get_default_dtype()
if isinstance(dtype, str):
set_default_dtype(dtype)
else:
torch.set_default_dtype(dtype)
try:
yield
finally:
torch.set_default_dtype(init)
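# Illustrative usage sketch (assuming the `dtype_dict` keys defined above):
#
#   with default_dtype("float64"):
#       x = torch.zeros(3)   # created as float64
#   y = torch.zeros(3)       # back to the previous default dtype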
###########################################################################################
# Training script
# Authors: Ilyes Batatia, Gregor Simm, David Kovacs
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
import dataclasses
import logging
import time
from contextlib import nullcontext
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.distributed
from torch.nn.parallel import DistributedDataParallel
from torch.optim import LBFGS
from torch.optim.swa_utils import SWALR, AveragedModel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch_ema import ExponentialMovingAverage
from torchmetrics import Metric
from mace.cli.visualise_train import TrainingPlotter
from . import torch_geometric
from .checkpoint import CheckpointHandler, CheckpointState
from .torch_tools import to_numpy
from .utils import (
MetricsLogger,
compute_mae,
compute_q95,
compute_rel_mae,
compute_rel_rmse,
compute_rmse,
)
@dataclasses.dataclass
class SWAContainer:
model: AveragedModel
scheduler: SWALR
start: int
loss_fn: torch.nn.Module
def valid_err_log(
valid_loss,
eval_metrics,
logger,
log_errors,
epoch=None,
valid_loader_name="Default",
):
eval_metrics["mode"] = "eval"
eval_metrics["epoch"] = epoch
eval_metrics["head"] = valid_loader_name
logger.log(eval_metrics)
if epoch is None:
initial_phrase = "Initial"
else:
initial_phrase = f"Epoch {epoch}"
if log_errors == "PerAtomRMSE":
error_e = eval_metrics["rmse_e_per_atom"] * 1e3
error_f = eval_metrics["rmse_f"] * 1e3
logging.info(
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A"
)
elif (
log_errors == "PerAtomRMSEstressvirials"
and eval_metrics["rmse_stress"] is not None
):
error_e = eval_metrics["rmse_e_per_atom"] * 1e3
error_f = eval_metrics["rmse_f"] * 1e3
error_stress = eval_metrics["rmse_stress"] * 1e3
logging.info(
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_stress={error_stress:8.2f} meV / A^3",
)
elif (
log_errors == "PerAtomRMSEstressvirials"
and eval_metrics["rmse_virials_per_atom"] is not None
):
error_e = eval_metrics["rmse_e_per_atom"] * 1e3
error_f = eval_metrics["rmse_f"] * 1e3
error_virials = eval_metrics["rmse_virials_per_atom"] * 1e3
logging.info(
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_virials_per_atom={error_virials:8.2f} meV",
)
elif (
log_errors == "PerAtomMAEstressvirials"
and eval_metrics["mae_stress_per_atom"] is not None
):
error_e = eval_metrics["mae_e_per_atom"] * 1e3
error_f = eval_metrics["mae_f"] * 1e3
error_stress = eval_metrics["mae_stress"] * 1e3
logging.info(
f"{inintial_phrase}: loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A, MAE_stress={error_stress:8.2f} meV / A^3"
)
elif (
log_errors == "PerAtomMAEstressvirials"
and eval_metrics["mae_virials_per_atom"] is not None
):
error_e = eval_metrics["mae_e_per_atom"] * 1e3
error_f = eval_metrics["mae_f"] * 1e3
error_virials = eval_metrics["mae_virials"] * 1e3
logging.info(
f"{inintial_phrase}: loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A, MAE_virials={error_virials:8.2f} meV"
)
elif log_errors == "TotalRMSE":
error_e = eval_metrics["rmse_e"] * 1e3
error_f = eval_metrics["rmse_f"] * 1e3
logging.info(
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A",
)
elif log_errors == "PerAtomMAE":
error_e = eval_metrics["mae_e_per_atom"] * 1e3
error_f = eval_metrics["mae_f"] * 1e3
logging.info(
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A",
)
elif log_errors == "TotalMAE":
error_e = eval_metrics["mae_e"] * 1e3
error_f = eval_metrics["mae_f"] * 1e3
logging.info(
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, MAE_E={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A",
)
elif log_errors == "DipoleRMSE":
error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3
logging.info(
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_MU_per_atom={error_mu:8.2f} mDebye",
)
elif log_errors == "EnergyDipoleRMSE":
error_e = eval_metrics["rmse_e_per_atom"] * 1e3
error_f = eval_metrics["rmse_f"] * 1e3
error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3
logging.info(
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_Mu_per_atom={error_mu:8.2f} mDebye",
)
def train(
model: torch.nn.Module,
loss_fn: torch.nn.Module,
train_loader: DataLoader,
valid_loaders: Dict[str, DataLoader],
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler.ExponentialLR,
start_epoch: int,
max_num_epochs: int,
patience: int,
checkpoint_handler: CheckpointHandler,
logger: MetricsLogger,
eval_interval: int,
output_args: Dict[str, bool],
device: torch.device,
log_errors: str,
swa: Optional[SWAContainer] = None,
ema: Optional[ExponentialMovingAverage] = None,
max_grad_norm: Optional[float] = 10.0,
log_wandb: bool = False,
distributed: bool = False,
save_all_checkpoints: bool = False,
plotter: Optional[TrainingPlotter] = None,
distributed_model: Optional[DistributedDataParallel] = None,
train_sampler: Optional[DistributedSampler] = None,
rank: Optional[int] = 0,
):
lowest_loss = np.inf
valid_loss = np.inf
patience_counter = 0
swa_start = True
keep_last = False
if log_wandb:
import wandb
if max_grad_norm is not None:
logging.info(f"Using gradient clipping with tolerance={max_grad_norm:.3f}")
logging.info("")
logging.info("===========TRAINING===========")
logging.info("Started training, reporting errors on validation set")
logging.info("Loss metrics on validation set")
epoch = start_epoch
# log validation loss before _any_ training
for valid_loader_name, valid_loader in valid_loaders.items():
valid_loss_head, eval_metrics = evaluate(
model=model,
loss_fn=loss_fn,
data_loader=valid_loader,
output_args=output_args,
device=device,
)
valid_err_log(
valid_loss_head, eval_metrics, logger, log_errors, None, valid_loader_name
)
valid_loss = valid_loss_head # consider only the last head for the checkpoint
while epoch < max_num_epochs:
# LR scheduler and SWA update
if swa is None or epoch < swa.start:
if epoch > start_epoch:
lr_scheduler.step(
metrics=valid_loss
) # Can break if exponential LR, TODO fix that!
else:
if swa_start:
logging.info("Changing loss based on Stage Two Weights")
lowest_loss = np.inf
swa_start = False
keep_last = True
loss_fn = swa.loss_fn
swa.model.update_parameters(model)
if epoch > start_epoch:
swa.scheduler.step()
# Train
if distributed:
train_sampler.set_epoch(epoch)
if "ScheduleFree" in type(optimizer).__name__:
optimizer.train()
train_one_epoch(
model=model,
loss_fn=loss_fn,
data_loader=train_loader,
optimizer=optimizer,
epoch=epoch,
output_args=output_args,
max_grad_norm=max_grad_norm,
ema=ema,
logger=logger,
device=device,
distributed=distributed,
distributed_model=distributed_model,
rank=rank,
)
if distributed:
torch.distributed.barrier()
# Validate
if epoch % eval_interval == 0:
model_to_evaluate = (
model if distributed_model is None else distributed_model
)
param_context = (
ema.average_parameters() if ema is not None else nullcontext()
)
if "ScheduleFree" in type(optimizer).__name__:
optimizer.eval()
with param_context:
wandb_log_dict = {}
for valid_loader_name, valid_loader in valid_loaders.items():
valid_loss_head, eval_metrics = evaluate(
model=model_to_evaluate,
loss_fn=loss_fn,
data_loader=valid_loader,
output_args=output_args,
device=device,
)
if rank == 0:
valid_err_log(
valid_loss_head,
eval_metrics,
logger,
log_errors,
epoch,
valid_loader_name,
)
if log_wandb:
wandb_log_dict[valid_loader_name] = {
"epoch": epoch,
"valid_loss": valid_loss_head,
"valid_rmse_e_per_atom": eval_metrics[
"rmse_e_per_atom"
],
"valid_rmse_f": eval_metrics["rmse_f"],
}
if plotter and epoch % plotter.plot_frequency == 0:
try:
plotter.plot(epoch, model_to_evaluate, rank)
except Exception as e: # pylint: disable=broad-except
logging.debug(f"Plotting failed: {e}")
valid_loss = (
valid_loss_head # consider only the last head for the checkpoint
)
if log_wandb:
wandb.log(wandb_log_dict)
if rank == 0:
if valid_loss >= lowest_loss:
patience_counter += 1
if patience_counter >= patience:
if swa is not None and epoch < swa.start:
logging.info(
f"Stopping optimization after {patience_counter} epochs without improvement and starting Stage Two"
)
epoch = swa.start
else:
logging.info(
f"Stopping optimization after {patience_counter} epochs without improvement"
)
break
if save_all_checkpoints:
param_context = (
ema.average_parameters()
if ema is not None
else nullcontext()
)
with param_context:
checkpoint_handler.save(
state=CheckpointState(model, optimizer, lr_scheduler),
epochs=epoch,
keep_last=True,
)
else:
lowest_loss = valid_loss
patience_counter = 0
param_context = (
ema.average_parameters() if ema is not None else nullcontext()
)
with param_context:
checkpoint_handler.save(
state=CheckpointState(model, optimizer, lr_scheduler),
epochs=epoch,
keep_last=keep_last,
)
keep_last = False or save_all_checkpoints
if distributed:
torch.distributed.barrier()
epoch += 1
logging.info("Training complete")
def train_one_epoch(
model: torch.nn.Module,
loss_fn: torch.nn.Module,
data_loader: DataLoader,
optimizer: torch.optim.Optimizer,
epoch: int,
output_args: Dict[str, bool],
max_grad_norm: Optional[float],
ema: Optional[ExponentialMovingAverage],
logger: MetricsLogger,
device: torch.device,
distributed: bool,
distributed_model: Optional[DistributedDataParallel] = None,
rank: Optional[int] = 0,
) -> None:
model_to_train = model if distributed_model is None else distributed_model
if isinstance(optimizer, LBFGS):
_, opt_metrics = take_step_lbfgs(
model=model_to_train,
loss_fn=loss_fn,
data_loader=data_loader,
optimizer=optimizer,
ema=ema,
output_args=output_args,
max_grad_norm=max_grad_norm,
device=device,
distributed=distributed,
rank=rank,
)
opt_metrics["mode"] = "opt"
opt_metrics["epoch"] = epoch
if rank == 0:
logger.log(opt_metrics)
else:
for batch in data_loader:
_, opt_metrics = take_step(
model=model_to_train,
loss_fn=loss_fn,
batch=batch,
optimizer=optimizer,
ema=ema,
output_args=output_args,
max_grad_norm=max_grad_norm,
device=device,
)
opt_metrics["mode"] = "opt"
opt_metrics["epoch"] = epoch
if rank == 0:
logger.log(opt_metrics)
def take_step(
model: torch.nn.Module,
loss_fn: torch.nn.Module,
batch: torch_geometric.batch.Batch,
optimizer: torch.optim.Optimizer,
ema: Optional[ExponentialMovingAverage],
output_args: Dict[str, bool],
max_grad_norm: Optional[float],
device: torch.device,
) -> Tuple[float, Dict[str, Any]]:
start_time = time.time()
batch = batch.to(device)
batch_dict = batch.to_dict()
def closure():
optimizer.zero_grad(set_to_none=True)
output = model(
batch_dict,
training=True,
compute_force=output_args["forces"],
compute_virials=output_args["virials"],
compute_stress=output_args["stress"],
)
loss = loss_fn(pred=output, ref=batch)
loss.backward()
if max_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
return loss
loss = closure()
optimizer.step()
if ema is not None:
ema.update()
loss_dict = {
"loss": to_numpy(loss),
"time": time.time() - start_time,
}
return loss, loss_dict
def take_step_lbfgs(
model: torch.nn.Module,
loss_fn: torch.nn.Module,
data_loader: DataLoader,
optimizer: torch.optim.Optimizer,
ema: Optional[ExponentialMovingAverage],
output_args: Dict[str, bool],
max_grad_norm: Optional[float],
device: torch.device,
distributed: bool,
rank: int,
) -> Tuple[float, Dict[str, Any]]:
start_time = time.time()
logging.debug(
f"Max Allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB"
)
total_sample_count = 0
for batch in data_loader:
total_sample_count += batch.num_graphs
if distributed:
global_sample_count = torch.tensor(total_sample_count, device=device)
torch.distributed.all_reduce(
global_sample_count, op=torch.distributed.ReduceOp.SUM
)
total_sample_count = global_sample_count.item()
signal = torch.zeros(1, device=device) if distributed else None
def closure():
if distributed:
if rank == 0:
signal.fill_(1)
torch.distributed.broadcast(signal, src=0)
for param in model.parameters():
torch.distributed.broadcast(param.data, src=0)
optimizer.zero_grad(set_to_none=True)
total_loss = torch.tensor(0.0, device=device)
# Process each batch and then collect the results we pass to the optimizer
for batch in data_loader:
batch = batch.to(device)
batch_dict = batch.to_dict()
output = model(
batch_dict,
training=True,
compute_force=output_args["forces"],
compute_virials=output_args["virials"],
compute_stress=output_args["stress"],
)
batch_loss = loss_fn(pred=output, ref=batch)
batch_loss = batch_loss * (batch.num_graphs / total_sample_count)
batch_loss.backward()
total_loss += batch_loss
if max_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
if distributed:
torch.distributed.all_reduce(total_loss, op=torch.distributed.ReduceOp.SUM)
return total_loss
if distributed:
if rank == 0:
loss = optimizer.step(closure)
signal.fill_(0)
torch.distributed.broadcast(signal, src=0)
else:
while True:
# Other ranks wait for signals from rank 0
torch.distributed.broadcast(signal, src=0)
if signal.item() == 0:
break
if signal.item() == 1:
loss = closure()
for param in model.parameters():
torch.distributed.broadcast(param.data, src=0)
else:
loss = optimizer.step(closure)
if ema is not None:
ema.update()
loss_dict = {
"loss": to_numpy(loss),
"time": time.time() - start_time,
}
return loss, loss_dict
def evaluate(
model: torch.nn.Module,
loss_fn: torch.nn.Module,
data_loader: DataLoader,
output_args: Dict[str, bool],
device: torch.device,
) -> Tuple[float, Dict[str, Any]]:
for param in model.parameters():
param.requires_grad = False
metrics = MACELoss(loss_fn=loss_fn).to(device)
start_time = time.time()
for batch in data_loader:
batch = batch.to(device)
batch_dict = batch.to_dict()
output = model(
batch_dict,
training=False,
compute_force=output_args["forces"],
compute_virials=output_args["virials"],
compute_stress=output_args["stress"],
)
avg_loss, aux = metrics(batch, output)
avg_loss, aux = metrics.compute()
aux["time"] = time.time() - start_time
metrics.reset()
for param in model.parameters():
param.requires_grad = True
return avg_loss, aux
class MACELoss(Metric):
def __init__(self, loss_fn: torch.nn.Module):
super().__init__()
self.loss_fn = loss_fn
self.add_state("total_loss", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("num_data", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("E_computed", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("delta_es", default=[], dist_reduce_fx="cat")
self.add_state("delta_es_per_atom", default=[], dist_reduce_fx="cat")
self.add_state("Fs_computed", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("fs", default=[], dist_reduce_fx="cat")
self.add_state("delta_fs", default=[], dist_reduce_fx="cat")
self.add_state(
"stress_computed", default=torch.tensor(0.0), dist_reduce_fx="sum"
)
self.add_state("delta_stress", default=[], dist_reduce_fx="cat")
self.add_state(
"virials_computed", default=torch.tensor(0.0), dist_reduce_fx="sum"
)
self.add_state("delta_virials", default=[], dist_reduce_fx="cat")
self.add_state("delta_virials_per_atom", default=[], dist_reduce_fx="cat")
self.add_state("Mus_computed", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("mus", default=[], dist_reduce_fx="cat")
self.add_state("delta_mus", default=[], dist_reduce_fx="cat")
self.add_state("delta_mus_per_atom", default=[], dist_reduce_fx="cat")
def update(self, batch, output): # pylint: disable=arguments-differ
loss = self.loss_fn(pred=output, ref=batch)
self.total_loss += loss
self.num_data += batch.num_graphs
if output.get("energy") is not None and batch.energy is not None:
self.E_computed += 1.0
self.delta_es.append(batch.energy - output["energy"])
self.delta_es_per_atom.append(
(batch.energy - output["energy"]) / (batch.ptr[1:] - batch.ptr[:-1])
)
if output.get("forces") is not None and batch.forces is not None:
self.Fs_computed += 1.0
self.fs.append(batch.forces)
self.delta_fs.append(batch.forces - output["forces"])
if output.get("stress") is not None and batch.stress is not None:
self.stress_computed += 1.0
self.delta_stress.append(batch.stress - output["stress"])
if output.get("virials") is not None and batch.virials is not None:
self.virials_computed += 1.0
self.delta_virials.append(batch.virials - output["virials"])
self.delta_virials_per_atom.append(
(batch.virials - output["virials"])
/ (batch.ptr[1:] - batch.ptr[:-1]).view(-1, 1, 1)
)
if output.get("dipole") is not None and batch.dipole is not None:
self.Mus_computed += 1.0
self.mus.append(batch.dipole)
self.delta_mus.append(batch.dipole - output["dipole"])
self.delta_mus_per_atom.append(
(batch.dipole - output["dipole"])
/ (batch.ptr[1:] - batch.ptr[:-1]).unsqueeze(-1)
)
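# Note: `batch.ptr[1:] - batch.ptr[:-1]` gives the number of atoms in each
# graph of the batch, so the divisions above turn per-graph errors into
# per-atom errors (see the `Batch.ptr` construction in torch_geometric above).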
def convert(self, delta: Union[torch.Tensor, List[torch.Tensor]]) -> np.ndarray:
if isinstance(delta, list):
delta = torch.cat(delta)
return to_numpy(delta)
def compute(self):
aux = {}
aux["loss"] = to_numpy(self.total_loss / self.num_data).item()
if self.E_computed:
delta_es = self.convert(self.delta_es)
delta_es_per_atom = self.convert(self.delta_es_per_atom)
aux["mae_e"] = compute_mae(delta_es)
aux["mae_e_per_atom"] = compute_mae(delta_es_per_atom)
aux["rmse_e"] = compute_rmse(delta_es)
aux["rmse_e_per_atom"] = compute_rmse(delta_es_per_atom)
aux["q95_e"] = compute_q95(delta_es)
if self.Fs_computed:
fs = self.convert(self.fs)
delta_fs = self.convert(self.delta_fs)
aux["mae_f"] = compute_mae(delta_fs)
aux["rel_mae_f"] = compute_rel_mae(delta_fs, fs)
aux["rmse_f"] = compute_rmse(delta_fs)
aux["rel_rmse_f"] = compute_rel_rmse(delta_fs, fs)
aux["q95_f"] = compute_q95(delta_fs)
if self.stress_computed:
delta_stress = self.convert(self.delta_stress)
aux["mae_stress"] = compute_mae(delta_stress)
aux["rmse_stress"] = compute_rmse(delta_stress)
aux["q95_stress"] = compute_q95(delta_stress)
if self.virials_computed:
delta_virials = self.convert(self.delta_virials)
delta_virials_per_atom = self.convert(self.delta_virials_per_atom)
aux["mae_virials"] = compute_mae(delta_virials)
aux["rmse_virials"] = compute_rmse(delta_virials)
aux["rmse_virials_per_atom"] = compute_rmse(delta_virials_per_atom)
aux["q95_virials"] = compute_q95(delta_virials)
if self.Mus_computed:
mus = self.convert(self.mus)
delta_mus = self.convert(self.delta_mus)
delta_mus_per_atom = self.convert(self.delta_mus_per_atom)
aux["mae_mu"] = compute_mae(delta_mus)
aux["mae_mu_per_atom"] = compute_mae(delta_mus_per_atom)
aux["rel_mae_mu"] = compute_rel_mae(delta_mus, mus)
aux["rmse_mu"] = compute_rmse(delta_mus)
aux["rmse_mu_per_atom"] = compute_rmse(delta_mus_per_atom)
aux["rel_rmse_mu"] = compute_rel_rmse(delta_mus, mus)
aux["q95_mu"] = compute_q95(delta_mus)
return aux["loss"], aux
###########################################################################################
# Statistics utilities
# Authors: Ilyes Batatia, Gregor Simm, David Kovacs
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
import json
import logging
import os
import sys
from typing import Any, Dict, Iterable, Optional, Sequence, Union
import numpy as np
import torch
from .torch_tools import to_numpy
def compute_mae(delta: np.ndarray) -> float:
return np.mean(np.abs(delta)).item()
def compute_rel_mae(delta: np.ndarray, target_val: np.ndarray) -> float:
target_norm = np.mean(np.abs(target_val))
return np.mean(np.abs(delta)).item() / (target_norm + 1e-9) * 100
def compute_rmse(delta: np.ndarray) -> float:
return np.sqrt(np.mean(np.square(delta))).item()
def compute_rel_rmse(delta: np.ndarray, target_val: np.ndarray) -> float:
target_norm = np.sqrt(np.mean(np.square(target_val))).item()
return np.sqrt(np.mean(np.square(delta))).item() / (target_norm + 1e-9) * 100
def compute_q95(delta: np.ndarray) -> float:
return np.percentile(np.abs(delta), q=95)
def compute_c(delta: np.ndarray, eta: float) -> float:
return np.mean(np.abs(delta) < eta).item()
def get_tag(name: str, seed: int) -> str:
return f"{name}_run-{seed}"
def setup_logger(
level: Union[int, str] = logging.INFO,
tag: Optional[str] = None,
directory: Optional[str] = None,
rank: Optional[int] = 0,
):
# Create a logger
logger = logging.getLogger()
logger.setLevel(logging.DEBUG) # Set to DEBUG to capture all levels
# Create formatters
formatter = logging.Formatter(
"%(asctime)s.%(msecs)03d %(levelname)s: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
# Add filter for rank
logger.addFilter(lambda _: rank == 0)
# Create console handler
ch = logging.StreamHandler(stream=sys.stdout)
ch.setLevel(level)
ch.setFormatter(formatter)
logger.addHandler(ch)
if directory is not None and tag is not None:
os.makedirs(name=directory, exist_ok=True)
# Create file handler for non-debug logs
main_log_path = os.path.join(directory, f"{tag}.log")
fh_main = logging.FileHandler(main_log_path)
fh_main.setLevel(level)
fh_main.setFormatter(formatter)
logger.addHandler(fh_main)
# Create file handler for debug logs
debug_log_path = os.path.join(directory, f"{tag}_debug.log")
fh_debug = logging.FileHandler(debug_log_path)
fh_debug.setLevel(logging.DEBUG)
fh_debug.setFormatter(formatter)
fh_debug.addFilter(lambda record: record.levelno >= logging.DEBUG)
logger.addHandler(fh_debug)
class AtomicNumberTable:
def __init__(self, zs: Sequence[int]):
self.zs = zs
def __len__(self) -> int:
return len(self.zs)
def __str__(self):
return f"AtomicNumberTable: {tuple(s for s in self.zs)}"
def index_to_z(self, index: int) -> int:
return self.zs[index]
def z_to_index(self, atomic_number: int) -> int:
return self.zs.index(atomic_number)
def get_atomic_number_table_from_zs(zs: Iterable[int]) -> AtomicNumberTable:
z_set = set()
for z in zs:
z_set.add(z)
return AtomicNumberTable(sorted(list(z_set)))
def atomic_numbers_to_indices(
atomic_numbers: np.ndarray, z_table: AtomicNumberTable
) -> np.ndarray:
to_index_fn = np.vectorize(z_table.z_to_index)
return to_index_fn(atomic_numbers)
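# Illustrative sketch (assumed toy input): building a table for H, C, O and
# mapping atomic numbers to contiguous indices:
#
#   z_table = get_atomic_number_table_from_zs([1, 6, 8, 1, 8])
#   # z_table.zs == [1, 6, 8]
#   atomic_numbers_to_indices(np.array([8, 1, 6]), z_table)
#   # -> array([2, 0, 1])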
class UniversalEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, np.integer):
return int(o)
if isinstance(o, np.floating):
return float(o)
if isinstance(o, np.ndarray):
return o.tolist()
if isinstance(o, torch.Tensor):
return to_numpy(o)
return json.JSONEncoder.default(self, o)
class MetricsLogger:
def __init__(self, directory: str, tag: str) -> None:
self.directory = directory
self.filename = tag + ".txt"
self.path = os.path.join(self.directory, self.filename)
def log(self, d: Dict[str, Any]) -> None:
os.makedirs(name=self.directory, exist_ok=True)
with open(self.path, mode="a", encoding="utf-8") as f:
f.write(json.dumps(d, cls=UniversalEncoder))
f.write("\n")
# pylint: disable=abstract-method, arguments-differ
class LAMMPS_MP(torch.autograd.Function):
@staticmethod
def forward(ctx, *args):
feats, data = args # unpack
ctx.vec_len = feats.shape[-1]
ctx.data = data
out = torch.empty_like(feats)
data.forward_exchange(feats, out, ctx.vec_len)
return out
@staticmethod
def backward(ctx, *grad_outputs):
(grad,) = grad_outputs # unpack
gout = torch.empty_like(grad)
ctx.data.reverse_exchange(grad, gout, ctx.vec_len)
return gout, None