nivren / ICT-CSP / Commits

Commit 73ff4f3a (unverified): "Add files via upload"
Authored Aug 24, 2025 by zcxzcx1; committed via GitHub on Aug 24, 2025
Parent: fb246ae0

Showing 9 changed files with 2124 additions and 0 deletions (+2124, -0)
mace-bench/3rdparty/mace/mace/tools/torch_geometric/batch.py        +257  -0
mace-bench/3rdparty/mace/mace/tools/torch_geometric/data.py         +441  -0
mace-bench/3rdparty/mace/mace/tools/torch_geometric/dataloader.py   +87   -0
mace-bench/3rdparty/mace/mace/tools/torch_geometric/dataset.py      +280  -0
mace-bench/3rdparty/mace/mace/tools/torch_geometric/seed.py         +17   -0
mace-bench/3rdparty/mace/mace/tools/torch_geometric/utils.py        +54   -0
mace-bench/3rdparty/mace/mace/tools/torch_tools.py                  +153  -0
mace-bench/3rdparty/mace/mace/tools/train.py                        +669  -0
mace-bench/3rdparty/mace/mace/tools/utils.py                        +166  -0
mace-bench/3rdparty/mace/mace/tools/torch_geometric/batch.py  (new file, mode 100644)

from collections.abc import Sequence
from typing import List

import numpy as np
import torch
from torch import Tensor

from .data import Data
from .dataset import IndexType


class Batch(Data):
    r"""A plain old python object modeling a batch of graphs as one big
    (disconnected) graph. With :class:`torch_geometric.data.Data` being the
    base class, all its methods can also be used here.
    In addition, single graphs can be reconstructed via the assignment vector
    :obj:`batch`, which maps each node to its respective graph identifier.
    """

    def __init__(self, batch=None, ptr=None, **kwargs):
        super(Batch, self).__init__(**kwargs)

        for key, item in kwargs.items():
            if key == "num_nodes":
                self.__num_nodes__ = item
            else:
                self[key] = item

        self.batch = batch
        self.ptr = ptr
        self.__data_class__ = Data
        self.__slices__ = None
        self.__cumsum__ = None
        self.__cat_dims__ = None
        self.__num_nodes_list__ = None
        self.__num_graphs__ = None

    @classmethod
    def from_data_list(cls, data_list, follow_batch=[], exclude_keys=[]):
        r"""Constructs a batch object from a python list holding
        :class:`torch_geometric.data.Data` objects.
        The assignment vector :obj:`batch` is created on the fly.
        Additionally, creates assignment batch vectors for each key in
        :obj:`follow_batch`.
        Will exclude any keys given in :obj:`exclude_keys`."""

        keys = list(set(data_list[0].keys) - set(exclude_keys))
        assert "batch" not in keys and "ptr" not in keys

        batch = cls()
        for key in data_list[0].__dict__.keys():
            if key[:2] != "__" and key[-2:] != "__":
                batch[key] = None

        batch.__num_graphs__ = len(data_list)
        batch.__data_class__ = data_list[0].__class__
        for key in keys + ["batch"]:
            batch[key] = []
        batch["ptr"] = [0]

        device = None
        slices = {key: [0] for key in keys}
        cumsum = {key: [0] for key in keys}
        cat_dims = {}
        num_nodes_list = []
        for i, data in enumerate(data_list):
            for key in keys:
                item = data[key]

                # Increase values by `cumsum` value.
                cum = cumsum[key][-1]
                if isinstance(item, Tensor) and item.dtype != torch.bool:
                    if not isinstance(cum, int) or cum != 0:
                        item = item + cum
                elif isinstance(item, (int, float)):
                    item = item + cum

                # Gather the size of the `cat` dimension.
                size = 1
                cat_dim = data.__cat_dim__(key, data[key])
                # 0-dimensional tensors have no dimension along which to
                # concatenate, so we set `cat_dim` to `None`.
                if isinstance(item, Tensor) and item.dim() == 0:
                    cat_dim = None
                cat_dims[key] = cat_dim

                # Add a batch dimension to items whose `cat_dim` is `None`:
                if isinstance(item, Tensor) and cat_dim is None:
                    cat_dim = 0  # Concatenate along this new batch dimension.
                    item = item.unsqueeze(0)
                    device = item.device
                elif isinstance(item, Tensor):
                    size = item.size(cat_dim)
                    device = item.device

                batch[key].append(item)  # Append item to the attribute list.

                slices[key].append(size + slices[key][-1])
                inc = data.__inc__(key, item)
                if isinstance(inc, (tuple, list)):
                    inc = torch.tensor(inc)
                cumsum[key].append(inc + cumsum[key][-1])

                if key in follow_batch:
                    if isinstance(size, Tensor):
                        for j, size in enumerate(size.tolist()):
                            tmp = f"{key}_{j}_batch"
                            batch[tmp] = [] if i == 0 else batch[tmp]
                            batch[tmp].append(
                                torch.full((size,), i, dtype=torch.long, device=device)
                            )
                    else:
                        tmp = f"{key}_batch"
                        batch[tmp] = [] if i == 0 else batch[tmp]
                        batch[tmp].append(
                            torch.full((size,), i, dtype=torch.long, device=device)
                        )

            if hasattr(data, "__num_nodes__"):
                num_nodes_list.append(data.__num_nodes__)
            else:
                num_nodes_list.append(None)

            num_nodes = data.num_nodes
            if num_nodes is not None:
                item = torch.full((num_nodes,), i, dtype=torch.long, device=device)
                batch.batch.append(item)
                batch.ptr.append(batch.ptr[-1] + num_nodes)

        batch.batch = None if len(batch.batch) == 0 else batch.batch
        batch.ptr = None if len(batch.ptr) == 1 else batch.ptr
        batch.__slices__ = slices
        batch.__cumsum__ = cumsum
        batch.__cat_dims__ = cat_dims
        batch.__num_nodes_list__ = num_nodes_list

        ref_data = data_list[0]
        for key in batch.keys:
            items = batch[key]
            item = items[0]
            cat_dim = ref_data.__cat_dim__(key, item)
            cat_dim = 0 if cat_dim is None else cat_dim
            if isinstance(item, Tensor):
                batch[key] = torch.cat(items, cat_dim)
            elif isinstance(item, (int, float)):
                batch[key] = torch.tensor(items)

        # if torch_geometric.is_debug_enabled():
        #     batch.debug()

        return batch.contiguous()

    def get_example(self, idx: int) -> Data:
        r"""Reconstructs the :class:`torch_geometric.data.Data` object at index
        :obj:`idx` from the batch object.
        The batch object must have been created via :meth:`from_data_list` in
        order to be able to reconstruct the initial objects."""

        if self.__slices__ is None:
            raise RuntimeError(
                (
                    "Cannot reconstruct data list from batch because the batch "
                    "object was not created using `Batch.from_data_list()`."
                )
            )

        data = self.__data_class__()
        idx = self.num_graphs + idx if idx < 0 else idx

        for key in self.__slices__.keys():
            item = self[key]
            if self.__cat_dims__[key] is None:
                # The item was concatenated along a new batch dimension,
                # so just index in that dimension:
                item = item[idx]
            else:
                # Narrow the item based on the values in `__slices__`.
                if isinstance(item, Tensor):
                    dim = self.__cat_dims__[key]
                    start = self.__slices__[key][idx]
                    end = self.__slices__[key][idx + 1]
                    item = item.narrow(dim, start, end - start)
                else:
                    start = self.__slices__[key][idx]
                    end = self.__slices__[key][idx + 1]
                    item = item[start:end]
                    item = item[0] if len(item) == 1 else item

            # Decrease its value by `cumsum` value:
            cum = self.__cumsum__[key][idx]
            if isinstance(item, Tensor):
                if not isinstance(cum, int) or cum != 0:
                    item = item - cum
            elif isinstance(item, (int, float)):
                item = item - cum

            data[key] = item

        if self.__num_nodes_list__[idx] is not None:
            data.num_nodes = self.__num_nodes_list__[idx]

        return data

    def index_select(self, idx: IndexType) -> List[Data]:
        if isinstance(idx, slice):
            idx = list(range(self.num_graphs)[idx])

        elif isinstance(idx, Tensor) and idx.dtype == torch.long:
            idx = idx.flatten().tolist()

        elif isinstance(idx, Tensor) and idx.dtype == torch.bool:
            idx = idx.flatten().nonzero(as_tuple=False).flatten().tolist()

        elif isinstance(idx, np.ndarray) and idx.dtype == np.int64:
            idx = idx.flatten().tolist()

        elif isinstance(idx, np.ndarray) and idx.dtype == np.bool_:
            # (The upload used `np.bool`, which was removed in NumPy >= 1.24;
            # `np.bool_` is the boolean dtype.)
            idx = idx.flatten().nonzero()[0].flatten().tolist()

        elif isinstance(idx, Sequence) and not isinstance(idx, str):
            pass

        else:
            raise IndexError(
                f"Only integers, slices (':'), list, tuples, torch.tensor and "
                f"np.ndarray of dtype long or bool are valid indices (got "
                f"'{type(idx).__name__}')"
            )

        return [self.get_example(i) for i in idx]

    def __getitem__(self, idx):
        if isinstance(idx, str):
            return super(Batch, self).__getitem__(idx)
        elif isinstance(idx, (int, np.integer)):
            return self.get_example(idx)
        else:
            return self.index_select(idx)

    def to_data_list(self) -> List[Data]:
        r"""Reconstructs the list of :class:`torch_geometric.data.Data` objects
        from the batch object.
        The batch object must have been created via :meth:`from_data_list` in
        order to be able to reconstruct the initial objects."""
        return [self.get_example(i) for i in range(self.num_graphs)]

    @property
    def num_graphs(self) -> int:
        """Returns the number of graphs in the batch."""
        if self.__num_graphs__ is not None:
            return self.__num_graphs__
        elif self.ptr is not None:
            return self.ptr.numel() - 1
        elif self.batch is not None:
            return int(self.batch.max()) + 1
        else:
            raise ValueError
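For context, a minimal usage sketch of the vendored Batch class above. This is not part of the commit; the toy graphs are invented, and the import paths assume the mace/tools/torch_geometric package layout shown in this diff.

import torch
from mace.tools.torch_geometric.data import Data
from mace.tools.torch_geometric.batch import Batch

# Two toy graphs: node features `x` and COO connectivity `edge_index`.
g1 = Data(x=torch.randn(2, 4), edge_index=torch.tensor([[0, 1], [1, 0]]))
g2 = Data(x=torch.randn(3, 4), edge_index=torch.tensor([[0, 1, 2], [1, 2, 0]]))

batch = Batch.from_data_list([g1, g2])
print(batch.num_graphs)  # 2
print(batch.batch)       # tensor([0, 0, 1, 1, 1]) -- node -> graph assignment
print(batch.ptr)         # tensor([0, 2, 5])       -- graph boundaries

# g2's edge_index is shifted by g1's node count inside the batch (via the
# cumulative `__inc__` offsets) and shifted back on reconstruction:
recovered = batch.to_data_list()
assert torch.equal(recovered[1].edge_index, g2.edge_index)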
mace-bench/3rdparty/mace/mace/tools/torch_geometric/data.py  (new file, mode 100644)

import collections
import copy
import re

import torch

# from ..utils.num_nodes import maybe_num_nodes

__num_nodes_warn_msg__ = (
    "The number of nodes in your data object can only be inferred by its {} "
    "indices, and hence may result in unexpected batch-wise behavior, e.g., "
    "in case there exist isolated nodes. Please consider explicitly setting "
    "the number of nodes for this data object by assigning it to "
    "data.num_nodes."
)


def size_repr(key, item, indent=0):
    indent_str = " " * indent
    if torch.is_tensor(item) and item.dim() == 0:
        out = item.item()
    elif torch.is_tensor(item):
        out = str(list(item.size()))
    elif isinstance(item, list) or isinstance(item, tuple):
        out = str([len(item)])
    elif isinstance(item, dict):
        lines = [indent_str + size_repr(k, v, 2) for k, v in item.items()]
        out = "{\n" + ",\n".join(lines) + "\n" + indent_str + "}"
    elif isinstance(item, str):
        out = f'"{item}"'
    else:
        out = str(item)

    return f"{indent_str}{key}={out}"


class Data(object):
    r"""A plain old python object modeling a single graph with various
    (optional) attributes:

    Args:
        x (Tensor, optional): Node feature matrix with shape :obj:`[num_nodes,
            num_node_features]`. (default: :obj:`None`)
        edge_index (LongTensor, optional): Graph connectivity in COO format
            with shape :obj:`[2, num_edges]`. (default: :obj:`None`)
        edge_attr (Tensor, optional): Edge feature matrix with shape
            :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`)
        y (Tensor, optional): Graph or node targets with arbitrary shape.
            (default: :obj:`None`)
        pos (Tensor, optional): Node position matrix with shape
            :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`)
        normal (Tensor, optional): Normal vector matrix with shape
            :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`)
        face (LongTensor, optional): Face adjacency matrix with shape
            :obj:`[3, num_faces]`. (default: :obj:`None`)

    The data object is not restricted to these attributes and can be extended
    by any other additional data.

    Example::

        data = Data(x=x, edge_index=edge_index)
        data.train_idx = torch.tensor([...], dtype=torch.long)
        data.test_mask = torch.tensor([...], dtype=torch.bool)
    """

    def __init__(
        self,
        x=None,
        edge_index=None,
        edge_attr=None,
        y=None,
        pos=None,
        normal=None,
        face=None,
        **kwargs,
    ):
        self.x = x
        self.edge_index = edge_index
        self.edge_attr = edge_attr
        self.y = y
        self.pos = pos
        self.normal = normal
        self.face = face
        for key, item in kwargs.items():
            if key == "num_nodes":
                self.__num_nodes__ = item
            else:
                self[key] = item

        if edge_index is not None and edge_index.dtype != torch.long:
            raise ValueError(
                (
                    f"Argument `edge_index` needs to be of type `torch.long` but "
                    f"found type `{edge_index.dtype}`."
                )
            )

        if face is not None and face.dtype != torch.long:
            raise ValueError(
                (
                    f"Argument `face` needs to be of type `torch.long` but found "
                    f"type `{face.dtype}`."
                )
            )

    @classmethod
    def from_dict(cls, dictionary):
        r"""Creates a data object from a python dictionary."""
        data = cls()
        for key, item in dictionary.items():
            data[key] = item
        return data

    def to_dict(self):
        return {key: item for key, item in self}

    def to_namedtuple(self):
        keys = self.keys
        DataTuple = collections.namedtuple("DataTuple", keys)
        return DataTuple(*[self[key] for key in keys])

    def __getitem__(self, key):
        r"""Gets the data of the attribute :obj:`key`."""
        return getattr(self, key, None)

    def __setitem__(self, key, value):
        """Sets the attribute :obj:`key` to :obj:`value`."""
        setattr(self, key, value)

    def __delitem__(self, key):
        r"""Delete the data of the attribute :obj:`key`."""
        return delattr(self, key)

    @property
    def keys(self):
        r"""Returns all names of graph attributes."""
        keys = [key for key in self.__dict__.keys() if self[key] is not None]
        keys = [key for key in keys if key[:2] != "__" and key[-2:] != "__"]
        return keys

    def __len__(self):
        r"""Returns the number of all present attributes."""
        return len(self.keys)

    def __contains__(self, key):
        r"""Returns :obj:`True`, if the attribute :obj:`key` is present in the
        data."""
        return key in self.keys

    def __iter__(self):
        r"""Iterates over all present attributes in the data, yielding their
        attribute names and content."""
        for key in sorted(self.keys):
            yield key, self[key]

    def __call__(self, *keys):
        r"""Iterates over all attributes :obj:`*keys` in the data, yielding
        their attribute names and content.
        If :obj:`*keys` is not given, this method will iterate over all
        present attributes."""
        for key in sorted(self.keys) if not keys else keys:
            if key in self:
                yield key, self[key]

    def __cat_dim__(self, key, value):
        r"""Returns the dimension for which :obj:`value` of attribute
        :obj:`key` will get concatenated when creating batches.

        .. note::
            This method is for internal use only, and should only be overridden
            if the batch concatenation process is corrupted for a specific data
            attribute.
        """
        if bool(re.search("(index|face)", key)):
            return -1
        return 0

    def __inc__(self, key, value):
        r"""Returns the incremental count to cumulatively increase the value
        of the next attribute of :obj:`key` when creating batches.

        .. note::
            This method is for internal use only, and should only be overridden
            if the batch concatenation process is corrupted for a specific data
            attribute.
        """
        # Only `*index*` and `*face*` attributes should be cumulatively summed
        # up when creating batches.
        return self.num_nodes if bool(re.search("(index|face)", key)) else 0

    @property
    def num_nodes(self):
        r"""Returns or sets the number of nodes in the graph.

        .. note::
            The number of nodes in your data object is typically automatically
            inferred, *e.g.*, when node features :obj:`x` are present.
            In some cases, however, a graph may only be given by its edge
            indices :obj:`edge_index`.
            PyTorch Geometric then *guesses* the number of nodes
            according to :obj:`edge_index.max().item() + 1`, but in case there
            exist isolated nodes, this number may not be correct and can
            therefore result in unexpected batch-wise behavior.
            Thus, we recommend setting the number of nodes in your data object
            explicitly via :obj:`data.num_nodes = ...`.
            You will be given a warning that requests you to do so.
        """
        if hasattr(self, "__num_nodes__"):
            return self.__num_nodes__
        for key, item in self("x", "pos", "normal", "batch"):
            return item.size(self.__cat_dim__(key, item))
        if hasattr(self, "adj"):
            return self.adj.size(0)
        if hasattr(self, "adj_t"):
            return self.adj_t.size(1)
        # if self.face is not None:
        #     logging.warning(__num_nodes_warn_msg__.format("face"))
        #     return maybe_num_nodes(self.face)
        # if self.edge_index is not None:
        #     logging.warning(__num_nodes_warn_msg__.format("edge"))
        #     return maybe_num_nodes(self.edge_index)
        return None

    @num_nodes.setter
    def num_nodes(self, num_nodes):
        self.__num_nodes__ = num_nodes

    @property
    def num_edges(self):
        """
        Returns the number of edges in the graph.
        For undirected graphs, this will return the number of bi-directional
        edges, which is double the amount of unique edges.
        """
        for key, item in self("edge_index", "edge_attr"):
            return item.size(self.__cat_dim__(key, item))
        for key, item in self("adj", "adj_t"):
            return item.nnz()
        return None

    @property
    def num_faces(self):
        r"""Returns the number of faces in the mesh."""
        if self.face is not None:
            return self.face.size(self.__cat_dim__("face", self.face))
        return None

    @property
    def num_node_features(self):
        r"""Returns the number of features per node in the graph."""
        if self.x is None:
            return 0
        return 1 if self.x.dim() == 1 else self.x.size(1)

    @property
    def num_features(self):
        r"""Alias for :py:attr:`~num_node_features`."""
        return self.num_node_features

    @property
    def num_edge_features(self):
        r"""Returns the number of features per edge in the graph."""
        if self.edge_attr is None:
            return 0
        return 1 if self.edge_attr.dim() == 1 else self.edge_attr.size(1)

    def __apply__(self, item, func):
        if torch.is_tensor(item):
            return func(item)
        elif isinstance(item, (tuple, list)):
            return [self.__apply__(v, func) for v in item]
        elif isinstance(item, dict):
            return {k: self.__apply__(v, func) for k, v in item.items()}
        else:
            return item

    def apply(self, func, *keys):
        r"""Applies the function :obj:`func` to all tensor attributes
        :obj:`*keys`. If :obj:`*keys` is not given, :obj:`func` is applied to
        all present attributes.
        """
        for key, item in self(*keys):
            self[key] = self.__apply__(item, func)
        return self

    def contiguous(self, *keys):
        r"""Ensures a contiguous memory layout for all attributes :obj:`*keys`.
        If :obj:`*keys` is not given, all present attributes are ensured to
        have a contiguous memory layout."""
        return self.apply(lambda x: x.contiguous(), *keys)

    def to(self, device, *keys, **kwargs):
        r"""Performs tensor dtype and/or device conversion to all attributes
        :obj:`*keys`.
        If :obj:`*keys` is not given, the conversion is applied to all present
        attributes."""
        return self.apply(lambda x: x.to(device, **kwargs), *keys)

    def cpu(self, *keys):
        r"""Copies all attributes :obj:`*keys` to CPU memory.
        If :obj:`*keys` is not given, the conversion is applied to all present
        attributes."""
        return self.apply(lambda x: x.cpu(), *keys)

    def cuda(self, device=None, non_blocking=False, *keys):
        r"""Copies all attributes :obj:`*keys` to CUDA memory.
        If :obj:`*keys` is not given, the conversion is applied to all present
        attributes."""
        return self.apply(
            lambda x: x.cuda(device=device, non_blocking=non_blocking), *keys
        )

    def clone(self):
        r"""Performs a deep-copy of the data object."""
        return self.__class__.from_dict(
            {
                k: v.clone() if torch.is_tensor(v) else copy.deepcopy(v)
                for k, v in self.__dict__.items()
            }
        )

    def pin_memory(self, *keys):
        r"""Copies all attributes :obj:`*keys` to pinned memory.
        If :obj:`*keys` is not given, the conversion is applied to all present
        attributes."""
        return self.apply(lambda x: x.pin_memory(), *keys)

    def debug(self):
        if self.edge_index is not None:
            if self.edge_index.dtype != torch.long:
                raise RuntimeError(
                    (
                        "Expected edge indices of dtype {}, but found dtype " " {}"
                    ).format(torch.long, self.edge_index.dtype)
                )

        if self.face is not None:
            if self.face.dtype != torch.long:
                raise RuntimeError(
                    (
                        "Expected face indices of dtype {}, but found dtype " " {}"
                    ).format(torch.long, self.face.dtype)
                )

        if self.edge_index is not None:
            if self.edge_index.dim() != 2 or self.edge_index.size(0) != 2:
                raise RuntimeError(
                    (
                        "Edge indices should have shape [2, num_edges] but found"
                        " shape {}"
                    ).format(self.edge_index.size())
                )

        if self.edge_index is not None and self.num_nodes is not None:
            if self.edge_index.numel() > 0:
                min_index = self.edge_index.min()
                max_index = self.edge_index.max()
            else:
                min_index = max_index = 0
            if min_index < 0 or max_index > self.num_nodes - 1:
                raise RuntimeError(
                    (
                        "Edge indices must lie in the interval [0, {}]"
                        " but found them in the interval [{}, {}]"
                    ).format(self.num_nodes - 1, min_index, max_index)
                )

        if self.face is not None:
            if self.face.dim() != 2 or self.face.size(0) != 3:
                raise RuntimeError(
                    (
                        "Face indices should have shape [3, num_faces] but found"
                        " shape {}"
                    ).format(self.face.size())
                )

        if self.face is not None and self.num_nodes is not None:
            if self.face.numel() > 0:
                min_index = self.face.min()
                max_index = self.face.max()
            else:
                min_index = max_index = 0
            if min_index < 0 or max_index > self.num_nodes - 1:
                raise RuntimeError(
                    (
                        "Face indices must lie in the interval [0, {}]"
                        " but found them in the interval [{}, {}]"
                    ).format(self.num_nodes - 1, min_index, max_index)
                )

        if self.edge_index is not None and self.edge_attr is not None:
            if self.edge_index.size(1) != self.edge_attr.size(0):
                raise RuntimeError(
                    (
                        "Edge indices and edge attributes hold a differing "
                        "number of edges, found {} and {}"
                    ).format(self.edge_index.size(), self.edge_attr.size())
                )

        if self.x is not None and self.num_nodes is not None:
            if self.x.size(0) != self.num_nodes:
                raise RuntimeError(
                    (
                        "Node features should hold {} elements in the first "
                        "dimension but found {}"
                    ).format(self.num_nodes, self.x.size(0))
                )

        if self.pos is not None and self.num_nodes is not None:
            if self.pos.size(0) != self.num_nodes:
                raise RuntimeError(
                    (
                        "Node positions should hold {} elements in the first "
                        "dimension but found {}"
                    ).format(self.num_nodes, self.pos.size(0))
                )

        if self.normal is not None and self.num_nodes is not None:
            if self.normal.size(0) != self.num_nodes:
                raise RuntimeError(
                    (
                        "Node normals should hold {} elements in the first "
                        "dimension but found {}"
                    ).format(self.num_nodes, self.normal.size(0))
                )

    def __repr__(self):
        cls = str(self.__class__.__name__)
        has_dict = any([isinstance(item, dict) for _, item in self])
        if not has_dict:
            info = [size_repr(key, item) for key, item in self]
            return "{}({})".format(cls, ", ".join(info))
        else:
            info = [size_repr(key, item, indent=2) for key, item in self]
            return "{}(\n{}\n)".format(cls, ",\n".join(info))
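A small sketch (again not part of the commit; toy values, assumed import path) of the two hooks in Data that drive the batching above: __cat_dim__ selects the concatenation axis and __inc__ the per-graph offset, both keyed on the "(index|face)" regex in the code.

import torch
from mace.tools.torch_geometric.data import Data  # assumed package layout

data = Data(
    x=torch.randn(3, 8),                              # 3 nodes, 8 features
    edge_index=torch.tensor([[0, 1, 2], [1, 2, 0]]),  # 3 directed edges
)

print(data.num_nodes)                                   # 3 (inferred from x)
print(data.__cat_dim__("x", data.x))                    # 0  -> stack nodes row-wise
print(data.__cat_dim__("edge_index", data.edge_index))  # -1 -> concat edges column-wise
print(data.__inc__("edge_index", data.edge_index))      # 3  -> offset by num_nodes
print(data.__inc__("x", data.x))                        # 0  -> features are not shifted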
mace-bench/3rdparty/mace/mace/tools/torch_geometric/dataloader.py  (new file, mode 100644)

from collections.abc import Mapping, Sequence
from typing import List, Optional, Union

import torch.utils.data
from torch.utils.data.dataloader import default_collate

from .batch import Batch
from .data import Data
from .dataset import Dataset


class Collater:
    def __init__(self, follow_batch, exclude_keys):
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

    def __call__(self, batch):
        elem = batch[0]
        if isinstance(elem, Data):
            return Batch.from_data_list(
                batch,
                follow_batch=self.follow_batch,
                exclude_keys=self.exclude_keys,
            )
        elif isinstance(elem, torch.Tensor):
            return default_collate(batch)
        elif isinstance(elem, float):
            return torch.tensor(batch, dtype=torch.float)
        elif isinstance(elem, int):
            return torch.tensor(batch)
        elif isinstance(elem, str):
            return batch
        elif isinstance(elem, Mapping):
            return {key: self([data[key] for data in batch]) for key in elem}
        elif isinstance(elem, tuple) and hasattr(elem, "_fields"):
            return type(elem)(*(self(s) for s in zip(*batch)))
        elif isinstance(elem, Sequence) and not isinstance(elem, str):
            return [self(s) for s in zip(*batch)]

        raise TypeError(f"DataLoader found invalid type: {type(elem)}")

    def collate(self, batch):
        # Deprecated...
        return self(batch)


class DataLoader(torch.utils.data.DataLoader):
    r"""A data loader which merges data objects from a
    :class:`torch_geometric.data.Dataset` to a mini-batch.
    Data objects can be either of type :class:`~torch_geometric.data.Data` or
    :class:`~torch_geometric.data.HeteroData`.

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch. (default: :obj:`False`)
        follow_batch (List[str], optional): Creates assignment batch
            vectors for each key in the list. (default: :obj:`None`)
        exclude_keys (List[str], optional): Will exclude each key in the
            list. (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`.
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int = 1,
        shuffle: bool = False,
        follow_batch: Optional[List[str]] = [None],
        exclude_keys: Optional[List[str]] = [None],
        **kwargs,
    ):
        if "collate_fn" in kwargs:
            del kwargs["collate_fn"]

        # Save for PyTorch Lightning < 1.6:
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

        super().__init__(
            dataset,
            batch_size,
            shuffle,
            collate_fn=Collater(follow_batch, exclude_keys),
            **kwargs,
        )
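A hedged end-to-end sketch of the loader (not part of the commit; toy graphs, assumed import paths). A plain Python list of Data objects already satisfies the map-style dataset contract, so Collater is the only piece doing real work here.

import torch
from mace.tools.torch_geometric.data import Data
from mace.tools.torch_geometric.dataloader import DataLoader

# Four toy graphs with 2, 3, 4, and 5 nodes and a single edge each.
graphs = [
    Data(x=torch.randn(n, 4), edge_index=torch.tensor([[0], [n - 1]]))
    for n in (2, 3, 4, 5)
]

loader = DataLoader(graphs, batch_size=2, shuffle=False)
for batch in loader:
    # Each yielded object is a Batch; `ptr` marks graph boundaries.
    print(batch.num_graphs, batch.ptr.tolist())
# 2 [0, 2, 5]
# 2 [0, 4, 9]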
mace-bench/3rdparty/mace/mace/tools/torch_geometric/dataset.py  (new file, mode 100644)

import copy
import os.path as osp
import re
import warnings
from collections.abc import Sequence
from typing import Any, Callable, List, Optional, Tuple, Union

import numpy as np
import torch.utils.data
from torch import Tensor

from .data import Data
from .utils import makedirs

IndexType = Union[slice, Tensor, np.ndarray, Sequence]


class Dataset(torch.utils.data.Dataset):
    r"""Dataset base class for creating graph datasets.
    See `here <https://pytorch-geometric.readthedocs.io/en/latest/notes/
    create_dataset.html>`__ for the accompanying tutorial.

    Args:
        root (string, optional): Root directory where the dataset should be
            saved. (optional: :obj:`None`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
    """

    @property
    def raw_file_names(self) -> Union[str, List[str], Tuple]:
        r"""The name of the files to find in the :obj:`self.raw_dir` folder in
        order to skip the download."""
        raise NotImplementedError

    @property
    def processed_file_names(self) -> Union[str, List[str], Tuple]:
        r"""The name of the files to find in the :obj:`self.processed_dir`
        folder in order to skip the processing."""
        raise NotImplementedError

    def download(self):
        r"""Downloads the dataset to the :obj:`self.raw_dir` folder."""
        raise NotImplementedError

    def process(self):
        r"""Processes the dataset to the :obj:`self.processed_dir` folder."""
        raise NotImplementedError

    def len(self) -> int:
        raise NotImplementedError

    def get(self, idx: int) -> Data:
        r"""Gets the data object at index :obj:`idx`."""
        raise NotImplementedError

    def __init__(
        self,
        root: Optional[str] = None,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
    ):
        super().__init__()

        if isinstance(root, str):
            root = osp.expanduser(osp.normpath(root))

        self.root = root
        self.transform = transform
        self.pre_transform = pre_transform
        self.pre_filter = pre_filter
        self._indices: Optional[Sequence] = None

        if "download" in self.__class__.__dict__.keys():
            self._download()

        if "process" in self.__class__.__dict__.keys():
            self._process()

    def indices(self) -> Sequence:
        return range(self.len()) if self._indices is None else self._indices

    @property
    def raw_dir(self) -> str:
        return osp.join(self.root, "raw")

    @property
    def processed_dir(self) -> str:
        return osp.join(self.root, "processed")

    @property
    def num_node_features(self) -> int:
        r"""Returns the number of features per node in the dataset."""
        data = self[0]
        if hasattr(data, "num_node_features"):
            return data.num_node_features
        raise AttributeError(
            f"'{data.__class__.__name__}' object has no "
            f"attribute 'num_node_features'"
        )

    @property
    def num_features(self) -> int:
        r"""Alias for :py:attr:`~num_node_features`."""
        return self.num_node_features

    @property
    def num_edge_features(self) -> int:
        r"""Returns the number of features per edge in the dataset."""
        data = self[0]
        if hasattr(data, "num_edge_features"):
            return data.num_edge_features
        raise AttributeError(
            f"'{data.__class__.__name__}' object has no "
            f"attribute 'num_edge_features'"
        )

    @property
    def raw_paths(self) -> List[str]:
        r"""The filepaths to find in order to skip the download."""
        files = to_list(self.raw_file_names)
        return [osp.join(self.raw_dir, f) for f in files]

    @property
    def processed_paths(self) -> List[str]:
        r"""The filepaths to find in the :obj:`self.processed_dir`
        folder in order to skip the processing."""
        files = to_list(self.processed_file_names)
        return [osp.join(self.processed_dir, f) for f in files]

    def _download(self):
        if files_exist(self.raw_paths):  # pragma: no cover
            return

        makedirs(self.raw_dir)
        self.download()

    def _process(self):
        f = osp.join(self.processed_dir, "pre_transform.pt")
        if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):
            warnings.warn(
                f"The `pre_transform` argument differs from the one used in "
                f"the pre-processed version of this dataset. If you want to "
                f"make use of another pre-processing technique, make sure to "
                f"delete '{self.processed_dir}' first"
            )
        f = osp.join(self.processed_dir, "pre_filter.pt")
        if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):
            warnings.warn(
                f"The `pre_filter` argument differs from the one used in the "
                f"pre-processed version of this dataset. If you want to make "
                f"use of another pre-filtering technique, make sure to delete "
                f"'{self.processed_dir}' first"
                # (The upload used a plain string here, so the directory name
                # was never interpolated; made an f-string to match intent.)
            )

        if files_exist(self.processed_paths):  # pragma: no cover
            return

        print("Processing...")

        makedirs(self.processed_dir)
        self.process()

        path = osp.join(self.processed_dir, "pre_transform.pt")
        torch.save(_repr(self.pre_transform), path)
        path = osp.join(self.processed_dir, "pre_filter.pt")
        torch.save(_repr(self.pre_filter), path)

        print("Done!")

    def __len__(self) -> int:
        r"""The number of examples in the dataset."""
        return len(self.indices())

    def __getitem__(
        self,
        idx: Union[int, np.integer, IndexType],
    ) -> Union["Dataset", Data]:
        r"""In case :obj:`idx` is of type integer, will return the data object
        at index :obj:`idx` (and transforms it in case :obj:`transform` is
        present).
        In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a
        tuple, a PyTorch :obj:`LongTensor` or a :obj:`BoolTensor`, or a numpy
        :obj:`np.array`, will return a subset of the dataset at the specified
        indices."""
        if (
            isinstance(idx, (int, np.integer))
            or (isinstance(idx, Tensor) and idx.dim() == 0)
            or (isinstance(idx, np.ndarray) and np.isscalar(idx))
        ):
            data = self.get(self.indices()[idx])
            data = data if self.transform is None else self.transform(data)
            return data
        else:
            return self.index_select(idx)

    def index_select(self, idx: IndexType) -> "Dataset":
        indices = self.indices()

        if isinstance(idx, slice):
            indices = indices[idx]

        elif isinstance(idx, Tensor) and idx.dtype == torch.long:
            return self.index_select(idx.flatten().tolist())

        elif isinstance(idx, Tensor) and idx.dtype == torch.bool:
            idx = idx.flatten().nonzero(as_tuple=False)
            return self.index_select(idx.flatten().tolist())

        elif isinstance(idx, np.ndarray) and idx.dtype == np.int64:
            return self.index_select(idx.flatten().tolist())

        elif isinstance(idx, np.ndarray) and idx.dtype == np.bool_:
            # (The upload used `np.bool`, removed in NumPy >= 1.24.)
            idx = idx.flatten().nonzero()[0]
            return self.index_select(idx.flatten().tolist())

        elif isinstance(idx, Sequence) and not isinstance(idx, str):
            indices = [indices[i] for i in idx]

        else:
            raise IndexError(
                f"Only integers, slices (':'), list, tuples, torch.tensor and "
                f"np.ndarray of dtype long or bool are valid indices (got "
                f"'{type(idx).__name__}')"
            )

        dataset = copy.copy(self)
        dataset._indices = indices
        return dataset

    def shuffle(
        self,
        return_perm: bool = False,
    ) -> Union["Dataset", Tuple["Dataset", Tensor]]:
        r"""Randomly shuffles the examples in the dataset.

        Args:
            return_perm (bool, optional): If set to :obj:`True`, will return
                the random permutation used to shuffle the dataset in addition.
                (default: :obj:`False`)
        """
        perm = torch.randperm(len(self))
        dataset = self.index_select(perm)
        return (dataset, perm) if return_perm is True else dataset

    def __repr__(self) -> str:
        arg_repr = str(len(self)) if len(self) > 1 else ""
        return f"{self.__class__.__name__}({arg_repr})"


def to_list(value: Any) -> Sequence:
    if isinstance(value, Sequence) and not isinstance(value, str):
        return value
    else:
        return [value]


def files_exist(files: List[str]) -> bool:
    # NOTE: We return `False` in case `files` is empty, leading to a
    # re-processing of files on every instantiation.
    return len(files) != 0 and all([osp.exists(f) for f in files])


def _repr(obj: Any) -> str:
    if obj is None:
        return "None"
    return re.sub("(<.*?)\\s.*(>)", r"\1\2", obj.__repr__())
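A minimal subclass sketch (not part of the commit; MyToyDataset and its contents are hypothetical) showing the members a concrete in-memory dataset must provide. Because the subclass defines neither download nor process, `__init__` skips `_download()` and `_process()` entirely.

import torch
from mace.tools.torch_geometric.data import Data       # assumed layout
from mace.tools.torch_geometric.dataset import Dataset


class MyToyDataset(Dataset):
    """Hypothetical dataset holding a fixed list of graphs in memory."""

    def __init__(self, graphs, transform=None):
        self.graphs = list(graphs)
        super().__init__(root=None, transform=transform)

    def len(self) -> int:
        return len(self.graphs)

    def get(self, idx: int) -> Data:
        return self.graphs[idx]


ds = MyToyDataset(Data(x=torch.randn(n, 4)) for n in range(1, 6))
print(len(ds))                          # 5
subset = ds[torch.tensor([0, 2, 4])]    # index_select -> shallow-copied view
print(len(subset))                      # 3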
mace-bench/3rdparty/mace/mace/tools/torch_geometric/seed.py  (new file, mode 100644)

import random

import numpy as np
import torch


def seed_everything(seed: int):
    r"""Sets the seed for generating random numbers in :pytorch:`PyTorch`,
    :obj:`numpy` and Python.

    Args:
        seed (int): The desired seed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
mace-bench/3rdparty/mace/mace/tools/torch_geometric/utils.py  (new file, mode 100644)

import os
import os.path as osp
import ssl
import urllib.request  # (the upload imported bare `urllib`, which does not expose `.request`)
import zipfile


def makedirs(dir):
    os.makedirs(dir, exist_ok=True)


def download_url(url, folder, log=True):
    r"""Downloads the content of an URL to a specific folder.

    Args:
        url (string): The url.
        folder (string): The folder.
        log (bool, optional): If :obj:`False`, will not print anything to the
            console. (default: :obj:`True`)
    """
    filename = url.rpartition("/")[2].split("?")[0]
    path = osp.join(folder, filename)

    if osp.exists(path):  # pragma: no cover
        if log:
            print("Using existing file", filename)
        return path

    if log:
        print("Downloading", url)

    makedirs(folder)

    context = ssl._create_unverified_context()
    data = urllib.request.urlopen(url, context=context)

    with open(path, "wb") as f:
        f.write(data.read())

    return path


def extract_zip(path, folder, log=True):
    r"""Extracts a zip archive to a specific folder.

    Args:
        path (string): The path to the zip archive.
        folder (string): The folder.
        log (bool, optional): If :obj:`False`, will not print anything to the
            console. (default: :obj:`True`)
    """
    with zipfile.ZipFile(path, "r") as f:
        f.extractall(folder)
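A short sketch of the download-and-extract flow these helpers support (not part of the commit; the URL and folder are placeholders):

from mace.tools.torch_geometric.utils import download_url, extract_zip  # assumed layout

# Hypothetical archive location -- substitute a real dataset URL.
url = "https://example.com/datasets/toy_graphs.zip"
path = download_url(url, folder="raw")  # skips the fetch if the file already exists
extract_zip(path, folder="raw")         # unpacks next to the archive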
mace-bench/3rdparty/mace/mace/tools/torch_tools.py  (new file, mode 100644)

###########################################################################################
# Tools for torch
# Authors: Ilyes Batatia, Gregor Simm
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################

import logging
from contextlib import contextmanager
from typing import Dict, Union

import numpy as np
import torch
from e3nn.io import CartesianTensor

TensorDict = Dict[str, torch.Tensor]


def to_one_hot(indices: torch.Tensor, num_classes: int) -> torch.Tensor:
    """
    Generates one-hot encoding with <num_classes> classes from <indices>

    :param indices: (N x 1) tensor
    :param num_classes: number of classes
    :return: (N x num_classes) tensor
    """
    shape = indices.shape[:-1] + (num_classes,)
    oh = torch.zeros(shape, device=indices.device).view(shape)

    # scatter_ is the in-place version of scatter
    oh.scatter_(dim=-1, index=indices, value=1)

    return oh.view(*shape)


def count_parameters(module: torch.nn.Module) -> int:
    return int(sum(np.prod(p.shape) for p in module.parameters()))


def tensor_dict_to_device(td: TensorDict, device: torch.device) -> TensorDict:
    return {k: v.to(device) if v is not None else None for k, v in td.items()}


def set_seeds(seed: int) -> None:
    np.random.seed(seed)
    torch.manual_seed(seed)


def to_numpy(t: torch.Tensor) -> np.ndarray:
    return t.cpu().detach().numpy()


def init_device(device_str: str) -> torch.device:
    if "cuda" in device_str:
        assert torch.cuda.is_available(), "No CUDA device available!"
        if ":" in device_str:
            # Check if the desired device is available
            assert int(device_str.split(":")[-1]) < torch.cuda.device_count()
        logging.info(
            f"CUDA version: {torch.version.cuda}, CUDA device: {torch.cuda.current_device()}"
        )
        torch.cuda.init()
        return torch.device(device_str)
    if device_str == "mps":
        assert torch.backends.mps.is_available(), "No MPS backend is available!"
        logging.info("Using MPS GPU acceleration")
        return torch.device("mps")
    if device_str == "xpu":
        torch.xpu.is_available()
        return torch.device("xpu")

    logging.info("Using CPU")
    return torch.device("cpu")


dtype_dict = {"float32": torch.float32, "float64": torch.float64}


def set_default_dtype(dtype: str) -> None:
    torch.set_default_dtype(dtype_dict[dtype])


def spherical_to_cartesian(t: torch.Tensor):
    """
    Convert spherical notation to cartesian notation
    """
    stress_cart_tensor = CartesianTensor("ij=ji")
    stress_rtp = stress_cart_tensor.reduced_tensor_products()
    return stress_cart_tensor.to_cartesian(t, rtp=stress_rtp)


def cartesian_to_spherical(t: torch.Tensor):
    """
    Convert cartesian notation to spherical notation
    """
    stress_cart_tensor = CartesianTensor("ij=ji")
    stress_rtp = stress_cart_tensor.reduced_tensor_products()
    # The upload called `to_cartesian` here as well, which converts in the
    # wrong direction; `from_cartesian` is e3nn's cartesian-to-spherical map.
    return stress_cart_tensor.from_cartesian(t, rtp=stress_rtp)


def voigt_to_matrix(t: torch.Tensor):
    """
    Convert voigt notation to matrix notation
    :param t: (6,) tensor or (3, 3) tensor or (9,) tensor
    :return: (3, 3) tensor
    """
    if t.shape == (3, 3):
        return t
    if t.shape == (6,):
        return torch.tensor(
            [
                [t[0], t[5], t[4]],
                [t[5], t[1], t[3]],
                [t[4], t[3], t[2]],
            ],
            dtype=t.dtype,
        )
    if t.shape == (9,):
        return t.view(3, 3)

    raise ValueError(
        f"Stress tensor must be of shape (6,), (3, 3), or (9,), but has shape {t.shape}"
    )


def init_wandb(project: str, entity: str, name: str, config: dict, directory: str):
    import wandb

    wandb.init(
        project=project,
        entity=entity,
        name=name,
        config=config,
        dir=directory,
        resume="allow",
    )


@contextmanager
def default_dtype(dtype: Union[torch.dtype, str]):
    """Context manager for configuring the default_dtype used by torch

    Args:
        dtype (torch.dtype|str): the default dtype to use within this context manager
    """
    init = torch.get_default_dtype()
    if isinstance(dtype, str):
        set_default_dtype(dtype)
    else:
        torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(init)
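A brief sketch (not part of the commit; the index and stress values are made up) exercising the two shape-handling helpers above:

import torch
from mace.tools.torch_tools import to_one_hot, voigt_to_matrix  # assumed layout

# One-hot: indices must carry a trailing singleton dim, per the docstring.
idx = torch.tensor([[0], [2], [1]])  # (N x 1)
print(to_one_hot(idx, num_classes=3))
# tensor([[1., 0., 0.],
#         [0., 0., 1.],
#         [0., 1., 0.]])

# Voigt order (xx, yy, zz, yz, xz, xy) -> symmetric 3x3 stress matrix.
voigt = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
print(voigt_to_matrix(voigt))
# tensor([[1., 6., 5.],
#         [6., 2., 4.],
#         [5., 4., 3.]])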
mace-bench/3rdparty/mace/mace/tools/train.py
0 → 100644
View file @
73ff4f3a
###########################################################################################
# Training script
# Authors: Ilyes Batatia, Gregor Simm, David Kovacs
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
import
dataclasses
import
logging
import
time
from
contextlib
import
nullcontext
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
torch
import
torch.distributed
from
torch.nn.parallel
import
DistributedDataParallel
from
torch.optim
import
LBFGS
from
torch.optim.swa_utils
import
SWALR
,
AveragedModel
from
torch.utils.data
import
DataLoader
from
torch.utils.data.distributed
import
DistributedSampler
from
torch_ema
import
ExponentialMovingAverage
from
torchmetrics
import
Metric
from
mace.cli.visualise_train
import
TrainingPlotter
from
.
import
torch_geometric
from
.checkpoint
import
CheckpointHandler
,
CheckpointState
from
.torch_tools
import
to_numpy
from
.utils
import
(
MetricsLogger
,
compute_mae
,
compute_q95
,
compute_rel_mae
,
compute_rel_rmse
,
compute_rmse
,
)
@
dataclasses
.
dataclass
class
SWAContainer
:
model
:
AveragedModel
scheduler
:
SWALR
start
:
int
loss_fn
:
torch
.
nn
.
Module
def
valid_err_log
(
valid_loss
,
eval_metrics
,
logger
,
log_errors
,
epoch
=
None
,
valid_loader_name
=
"Default"
,
):
eval_metrics
[
"mode"
]
=
"eval"
eval_metrics
[
"epoch"
]
=
epoch
eval_metrics
[
"head"
]
=
valid_loader_name
logger
.
log
(
eval_metrics
)
if
epoch
is
None
:
inintial_phrase
=
"Initial"
else
:
inintial_phrase
=
f
"Epoch
{
epoch
}
"
if
log_errors
==
"PerAtomRMSE"
:
error_e
=
eval_metrics
[
"rmse_e_per_atom"
]
*
1e3
error_f
=
eval_metrics
[
"rmse_f"
]
*
1e3
logging
.
info
(
f
"
{
inintial_phrase
}
: head:
{
valid_loader_name
}
, loss=
{
valid_loss
:
8.8
f
}
, RMSE_E_per_atom=
{
error_e
:
8.2
f
}
meV, RMSE_F=
{
error_f
:
8.2
f
}
meV / A"
)
elif
(
log_errors
==
"PerAtomRMSEstressvirials"
and
eval_metrics
[
"rmse_stress"
]
is
not
None
):
error_e
=
eval_metrics
[
"rmse_e_per_atom"
]
*
1e3
error_f
=
eval_metrics
[
"rmse_f"
]
*
1e3
error_stress
=
eval_metrics
[
"rmse_stress"
]
*
1e3
logging
.
info
(
f
"
{
inintial_phrase
}
: head:
{
valid_loader_name
}
, loss=
{
valid_loss
:
8.8
f
}
, RMSE_E_per_atom=
{
error_e
:
8.2
f
}
meV, RMSE_F=
{
error_f
:
8.2
f
}
meV / A, RMSE_stress=
{
error_stress
:
8.2
f
}
meV / A^3"
,
)
elif
(
log_errors
==
"PerAtomRMSEstressvirials"
and
eval_metrics
[
"rmse_virials_per_atom"
]
is
not
None
):
error_e
=
eval_metrics
[
"rmse_e_per_atom"
]
*
1e3
error_f
=
eval_metrics
[
"rmse_f"
]
*
1e3
error_virials
=
eval_metrics
[
"rmse_virials_per_atom"
]
*
1e3
logging
.
info
(
f
"
{
inintial_phrase
}
: head:
{
valid_loader_name
}
, loss=
{
valid_loss
:
8.8
f
}
, RMSE_E_per_atom=
{
error_e
:
8.2
f
}
meV, RMSE_F=
{
error_f
:
8.2
f
}
meV / A, RMSE_virials_per_atom=
{
error_virials
:
8.2
f
}
meV"
,
)
elif
(
log_errors
==
"PerAtomMAEstressvirials"
and
eval_metrics
[
"mae_stress_per_atom"
]
is
not
None
):
error_e
=
eval_metrics
[
"mae_e_per_atom"
]
*
1e3
error_f
=
eval_metrics
[
"mae_f"
]
*
1e3
error_stress
=
eval_metrics
[
"mae_stress"
]
*
1e3
logging
.
info
(
f
"
{
inintial_phrase
}
: loss=
{
valid_loss
:
8.8
f
}
, MAE_E_per_atom=
{
error_e
:
8.2
f
}
meV, MAE_F=
{
error_f
:
8.2
f
}
meV / A, MAE_stress=
{
error_stress
:
8.2
f
}
meV / A^3"
)
elif
(
log_errors
==
"PerAtomMAEstressvirials"
and
eval_metrics
[
"mae_virials_per_atom"
]
is
not
None
):
error_e
=
eval_metrics
[
"mae_e_per_atom"
]
*
1e3
error_f
=
eval_metrics
[
"mae_f"
]
*
1e3
error_virials
=
eval_metrics
[
"mae_virials"
]
*
1e3
logging
.
info
(
f
"
{
inintial_phrase
}
: loss=
{
valid_loss
:
8.8
f
}
, MAE_E_per_atom=
{
error_e
:
8.2
f
}
meV, MAE_F=
{
error_f
:
8.2
f
}
meV / A, MAE_virials=
{
error_virials
:
8.2
f
}
meV"
)
elif
log_errors
==
"TotalRMSE"
:
error_e
=
eval_metrics
[
"rmse_e"
]
*
1e3
error_f
=
eval_metrics
[
"rmse_f"
]
*
1e3
logging
.
info
(
f
"
{
inintial_phrase
}
: head:
{
valid_loader_name
}
, loss=
{
valid_loss
:
8.8
f
}
, RMSE_E=
{
error_e
:
8.2
f
}
meV, RMSE_F=
{
error_f
:
8.2
f
}
meV / A"
,
)
elif
log_errors
==
"PerAtomMAE"
:
error_e
=
eval_metrics
[
"mae_e_per_atom"
]
*
1e3
error_f
=
eval_metrics
[
"mae_f"
]
*
1e3
logging
.
info
(
f
"
{
inintial_phrase
}
: head:
{
valid_loader_name
}
, loss=
{
valid_loss
:
8.8
f
}
, MAE_E_per_atom=
{
error_e
:
8.2
f
}
meV, MAE_F=
{
error_f
:
8.2
f
}
meV / A"
,
)
elif
log_errors
==
"TotalMAE"
:
error_e
=
eval_metrics
[
"mae_e"
]
*
1e3
error_f
=
eval_metrics
[
"mae_f"
]
*
1e3
logging
.
info
(
f
"
{
inintial_phrase
}
: head:
{
valid_loader_name
}
, loss=
{
valid_loss
:
8.8
f
}
, MAE_E=
{
error_e
:
8.2
f
}
meV, MAE_F=
{
error_f
:
8.2
f
}
meV / A"
,
)
elif
log_errors
==
"DipoleRMSE"
:
error_mu
=
eval_metrics
[
"rmse_mu_per_atom"
]
*
1e3
logging
.
info
(
f
"
{
inintial_phrase
}
: head:
{
valid_loader_name
}
, loss=
{
valid_loss
:
8.8
f
}
, RMSE_MU_per_atom=
{
error_mu
:
8.2
f
}
mDebye"
,
)
elif
log_errors
==
"EnergyDipoleRMSE"
:
error_e
=
eval_metrics
[
"rmse_e_per_atom"
]
*
1e3
error_f
=
eval_metrics
[
"rmse_f"
]
*
1e3
error_mu
=
eval_metrics
[
"rmse_mu_per_atom"
]
*
1e3
logging
.
info
(
f
"
{
inintial_phrase
}
: head:
{
valid_loader_name
}
, loss=
{
valid_loss
:
8.8
f
}
, RMSE_E_per_atom=
{
error_e
:
8.2
f
}
meV, RMSE_F=
{
error_f
:
8.2
f
}
meV / A, RMSE_Mu_per_atom=
{
error_mu
:
8.2
f
}
mDebye"
,
)
def
train
(
model
:
torch
.
nn
.
Module
,
loss_fn
:
torch
.
nn
.
Module
,
train_loader
:
DataLoader
,
valid_loaders
:
Dict
[
str
,
DataLoader
],
optimizer
:
torch
.
optim
.
Optimizer
,
lr_scheduler
:
torch
.
optim
.
lr_scheduler
.
ExponentialLR
,
start_epoch
:
int
,
max_num_epochs
:
int
,
patience
:
int
,
checkpoint_handler
:
CheckpointHandler
,
logger
:
MetricsLogger
,
eval_interval
:
int
,
output_args
:
Dict
[
str
,
bool
],
device
:
torch
.
device
,
log_errors
:
str
,
swa
:
Optional
[
SWAContainer
]
=
None
,
ema
:
Optional
[
ExponentialMovingAverage
]
=
None
,
max_grad_norm
:
Optional
[
float
]
=
10.0
,
log_wandb
:
bool
=
False
,
distributed
:
bool
=
False
,
save_all_checkpoints
:
bool
=
False
,
plotter
:
TrainingPlotter
=
None
,
distributed_model
:
Optional
[
DistributedDataParallel
]
=
None
,
train_sampler
:
Optional
[
DistributedSampler
]
=
None
,
rank
:
Optional
[
int
]
=
0
,
):
lowest_loss
=
np
.
inf
valid_loss
=
np
.
inf
patience_counter
=
0
swa_start
=
True
keep_last
=
False
if
log_wandb
:
import
wandb
if
max_grad_norm
is
not
None
:
logging
.
info
(
f
"Using gradient clipping with tolerance=
{
max_grad_norm
:.
3
f
}
"
)
logging
.
info
(
""
)
logging
.
info
(
"===========TRAINING==========="
)
logging
.
info
(
"Started training, reporting errors on validation set"
)
logging
.
info
(
"Loss metrics on validation set"
)
epoch
=
start_epoch
# log validation loss before _any_ training
for
valid_loader_name
,
valid_loader
in
valid_loaders
.
items
():
valid_loss_head
,
eval_metrics
=
evaluate
(
model
=
model
,
loss_fn
=
loss_fn
,
data_loader
=
valid_loader
,
output_args
=
output_args
,
device
=
device
,
)
valid_err_log
(
valid_loss_head
,
eval_metrics
,
logger
,
log_errors
,
None
,
valid_loader_name
)
valid_loss
=
valid_loss_head
# consider only the last head for the checkpoint
while
epoch
<
max_num_epochs
:
# LR scheduler and SWA update
if
swa
is
None
or
epoch
<
swa
.
start
:
if
epoch
>
start_epoch
:
lr_scheduler
.
step
(
metrics
=
valid_loss
)
# Can break if exponential LR, TODO fix that!
else
:
if
swa_start
:
logging
.
info
(
"Changing loss based on Stage Two Weights"
)
lowest_loss
=
np
.
inf
swa_start
=
False
keep_last
=
True
loss_fn
=
swa
.
loss_fn
swa
.
model
.
update_parameters
(
model
)
if
epoch
>
start_epoch
:
swa
.
scheduler
.
step
()
# Train
if
distributed
:
train_sampler
.
set_epoch
(
epoch
)
if
"ScheduleFree"
in
type
(
optimizer
).
__name__
:
optimizer
.
train
()
train_one_epoch
(
model
=
model
,
loss_fn
=
loss_fn
,
data_loader
=
train_loader
,
optimizer
=
optimizer
,
epoch
=
epoch
,
output_args
=
output_args
,
max_grad_norm
=
max_grad_norm
,
ema
=
ema
,
logger
=
logger
,
device
=
device
,
distributed
=
distributed
,
distributed_model
=
distributed_model
,
rank
=
rank
,
)
if
distributed
:
torch
.
distributed
.
barrier
()
# Validate
if
epoch
%
eval_interval
==
0
:
model_to_evaluate
=
(
model
if
distributed_model
is
None
else
distributed_model
)
param_context
=
(
ema
.
average_parameters
()
if
ema
is
not
None
else
nullcontext
()
)
if
"ScheduleFree"
in
type
(
optimizer
).
__name__
:
optimizer
.
eval
()
with
param_context
:
wandb_log_dict
=
{}
for
valid_loader_name
,
valid_loader
in
valid_loaders
.
items
():
valid_loss_head
,
eval_metrics
=
evaluate
(
model
=
model_to_evaluate
,
loss_fn
=
loss_fn
,
data_loader
=
valid_loader
,
output_args
=
output_args
,
device
=
device
,
)
if
rank
==
0
:
valid_err_log
(
valid_loss_head
,
eval_metrics
,
logger
,
log_errors
,
epoch
,
valid_loader_name
,
)
if
log_wandb
:
wandb_log_dict
[
valid_loader_name
]
=
{
"epoch"
:
epoch
,
"valid_loss"
:
valid_loss_head
,
"valid_rmse_e_per_atom"
:
eval_metrics
[
"rmse_e_per_atom"
],
"valid_rmse_f"
:
eval_metrics
[
"rmse_f"
],
}
if
plotter
and
epoch
%
plotter
.
plot_frequency
==
0
:
try
:
plotter
.
plot
(
epoch
,
model_to_evaluate
,
rank
)
except
Exception
as
e
:
# pylint: disable=broad-except
logging
.
debug
(
f
"Plotting failed:
{
e
}
"
)
valid_loss
=
(
valid_loss_head
# consider only the last head for the checkpoint
)
if
log_wandb
:
wandb
.
log
(
wandb_log_dict
)
if
rank
==
0
:
if
valid_loss
>=
lowest_loss
:
patience_counter
+=
1
if
patience_counter
>=
patience
:
if
swa
is
not
None
and
epoch
<
swa
.
start
:
logging
.
info
(
f
"Stopping optimization after
{
patience_counter
}
epochs without improvement and starting Stage Two"
)
epoch
=
swa
.
start
else
:
logging
.
info
(
f
"Stopping optimization after
{
patience_counter
}
epochs without improvement"
)
break
if
save_all_checkpoints
:
param_context
=
(
ema
.
average_parameters
()
if
ema
is
not
None
else
nullcontext
()
)
with
param_context
:
checkpoint_handler
.
save
(
state
=
CheckpointState
(
model
,
optimizer
,
lr_scheduler
),
epochs
=
epoch
,
keep_last
=
True
,
)
else
:
lowest_loss
=
valid_loss
patience_counter
=
0
param_context
=
(
ema
.
average_parameters
()
if
ema
is
not
None
else
nullcontext
()
)
with
param_context
:
checkpoint_handler
.
save
(
state
=
CheckpointState
(
model
,
optimizer
,
lr_scheduler
),
epochs
=
epoch
,
keep_last
=
keep_last
,
)
keep_last
=
False
or
save_all_checkpoints
if
distributed
:
torch
.
distributed
.
barrier
()
epoch
+=
1
logging
.
info
(
"Training complete"
)
def
train_one_epoch
(
model
:
torch
.
nn
.
Module
,
loss_fn
:
torch
.
nn
.
Module
,
data_loader
:
DataLoader
,
optimizer
:
torch
.
optim
.
Optimizer
,
epoch
:
int
,
output_args
:
Dict
[
str
,
bool
],
max_grad_norm
:
Optional
[
float
],
ema
:
Optional
[
ExponentialMovingAverage
],
logger
:
MetricsLogger
,
device
:
torch
.
device
,
distributed
:
bool
,
distributed_model
:
Optional
[
DistributedDataParallel
]
=
None
,
rank
:
Optional
[
int
]
=
0
,
)
->
None
:
model_to_train
=
model
if
distributed_model
is
None
else
distributed_model
if
isinstance
(
optimizer
,
LBFGS
):
_
,
opt_metrics
=
take_step_lbfgs
(
model
=
model_to_train
,
loss_fn
=
loss_fn
,
data_loader
=
data_loader
,
optimizer
=
optimizer
,
ema
=
ema
,
output_args
=
output_args
,
max_grad_norm
=
max_grad_norm
,
device
=
device
,
distributed
=
distributed
,
rank
=
rank
,
)
opt_metrics
[
"mode"
]
=
"opt"
opt_metrics
[
"epoch"
]
=
epoch
if
rank
==
0
:
logger
.
log
(
opt_metrics
)
else
:
for
batch
in
data_loader
:
_
,
opt_metrics
=
take_step
(
model
=
model_to_train
,
loss_fn
=
loss_fn
,
batch
=
batch
,
optimizer
=
optimizer
,
ema
=
ema
,
output_args
=
output_args
,
max_grad_norm
=
max_grad_norm
,
device
=
device
,
)
opt_metrics
[
"mode"
]
=
"opt"
opt_metrics
[
"epoch"
]
=
epoch
if
rank
==
0
:
logger
.
log
(
opt_metrics
)
def
take_step
(
model
:
torch
.
nn
.
Module
,
loss_fn
:
torch
.
nn
.
Module
,
batch
:
torch_geometric
.
batch
.
Batch
,
optimizer
:
torch
.
optim
.
Optimizer
,
ema
:
Optional
[
ExponentialMovingAverage
],
output_args
:
Dict
[
str
,
bool
],
max_grad_norm
:
Optional
[
float
],
device
:
torch
.
device
,
)
->
Tuple
[
float
,
Dict
[
str
,
Any
]]:
start_time
=
time
.
time
()
batch
=
batch
.
to
(
device
)
batch_dict
=
batch
.
to_dict
()
def
closure
():
optimizer
.
zero_grad
(
set_to_none
=
True
)
output
=
model
(
batch_dict
,
training
=
True
,
compute_force
=
output_args
[
"forces"
],
compute_virials
=
output_args
[
"virials"
],
compute_stress
=
output_args
[
"stress"
],
)
loss
=
loss_fn
(
pred
=
output
,
ref
=
batch
)
loss
.
backward
()
if
max_grad_norm
is
not
None
:
torch
.
nn
.
utils
.
clip_grad_norm_
(
model
.
parameters
(),
max_norm
=
max_grad_norm
)
return
loss
loss
=
closure
()
optimizer
.
step
()
if
ema
is
not
None
:
ema
.
update
()
loss_dict
=
{
"loss"
:
to_numpy
(
loss
),
"time"
:
time
.
time
()
-
start_time
,
}
return
loss
,
loss_dict
def take_step_lbfgs(
    model: torch.nn.Module,
    loss_fn: torch.nn.Module,
    data_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    ema: Optional[ExponentialMovingAverage],
    output_args: Dict[str, bool],
    max_grad_norm: Optional[float],
    device: torch.device,
    distributed: bool,
    rank: int,
) -> Tuple[float, Dict[str, Any]]:
    start_time = time.time()
    logging.debug(
        f"Max Allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB"
    )
    total_sample_count = 0
    for batch in data_loader:
        total_sample_count += batch.num_graphs
    if distributed:
        global_sample_count = torch.tensor(total_sample_count, device=device)
        torch.distributed.all_reduce(
            global_sample_count, op=torch.distributed.ReduceOp.SUM
        )
        total_sample_count = global_sample_count.item()

    # Rank 0 drives the L-BFGS line search; the signal tensor tells the other
    # ranks when to run the closure (1) and when the step is finished (0).
    signal = torch.zeros(1, device=device) if distributed else None

    def closure():
        if distributed:
            if rank == 0:
                signal.fill_(1)
                torch.distributed.broadcast(signal, src=0)
            for param in model.parameters():
                torch.distributed.broadcast(param.data, src=0)
        optimizer.zero_grad(set_to_none=True)
        total_loss = torch.tensor(0.0, device=device)
        # Process each batch and then collect the results we pass to the optimizer
        for batch in data_loader:
            batch = batch.to(device)
            batch_dict = batch.to_dict()
            output = model(
                batch_dict,
                training=True,
                compute_force=output_args["forces"],
                compute_virials=output_args["virials"],
                compute_stress=output_args["stress"],
            )
            batch_loss = loss_fn(pred=output, ref=batch)
            batch_loss = batch_loss * (batch.num_graphs / total_sample_count)
            batch_loss.backward()
            total_loss += batch_loss

        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        if distributed:
            torch.distributed.all_reduce(total_loss, op=torch.distributed.ReduceOp.SUM)
        return total_loss

    if distributed:
        if rank == 0:
            loss = optimizer.step(closure)
            signal.fill_(0)
            torch.distributed.broadcast(signal, src=0)
        else:
            while True:
                # Other ranks wait for signals from rank 0
                torch.distributed.broadcast(signal, src=0)
                if signal.item() == 0:
                    break
                if signal.item() == 1:
                    loss = closure()
        for param in model.parameters():
            torch.distributed.broadcast(param.data, src=0)
    else:
        loss = optimizer.step(closure)

    if ema is not None:
        ema.update()

    loss_dict = {
        "loss": to_numpy(loss),
        "time": time.time() - start_time,
    }

    return loss, loss_dict
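
# A minimal usage sketch (illustrative only): `take_step_lbfgs` expects a
# closure-driven optimizer such as torch.optim.LBFGS, which re-evaluates the
# whole-dataset loss several times per step. All arguments below are
# hypothetical caller-supplied objects.
def _example_lbfgs_step(model, loss_fn, train_loader, device) -> None:
    optimizer = torch.optim.LBFGS(
        model.parameters(), history_size=100, line_search_fn="strong_wolfe"
    )
    _, info = take_step_lbfgs(
        model=model,
        loss_fn=loss_fn,
        data_loader=train_loader,
        optimizer=optimizer,
        ema=None,
        output_args={"forces": True, "virials": False, "stress": False},
        max_grad_norm=None,
        device=device,
        distributed=False,
        rank=0,
    )
    logging.debug("L-BFGS step loss=%s (%.3f s)", info["loss"], info["time"])
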
def evaluate(
    model: torch.nn.Module,
    loss_fn: torch.nn.Module,
    data_loader: DataLoader,
    output_args: Dict[str, bool],
    device: torch.device,
) -> Tuple[float, Dict[str, Any]]:
    for param in model.parameters():
        param.requires_grad = False

    metrics = MACELoss(loss_fn=loss_fn).to(device)

    start_time = time.time()
    for batch in data_loader:
        batch = batch.to(device)
        batch_dict = batch.to_dict()
        output = model(
            batch_dict,
            training=False,
            compute_force=output_args["forces"],
            compute_virials=output_args["virials"],
            compute_stress=output_args["stress"],
        )
        # The metric call accumulates per-batch state; the final values come
        # from compute() after the loop.
        avg_loss, aux = metrics(batch, output)

    avg_loss, aux = metrics.compute()
    aux["time"] = time.time() - start_time
    metrics.reset()

    for param in model.parameters():
        param.requires_grad = True

    return avg_loss, aux
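
# A minimal usage sketch (illustrative only): a validation pass over a
# hypothetical `valid_loader`; `aux` carries the per-quantity error metrics
# assembled by MACELoss below (mae_*, rmse_*, q95_*).
def _example_validation(model, loss_fn, valid_loader, device) -> None:
    avg_loss, aux = evaluate(
        model=model,
        loss_fn=loss_fn,
        data_loader=valid_loader,
        output_args={"forces": True, "virials": False, "stress": False},
        device=device,
    )
    logging.info("validation loss=%.6f rmse_f=%s", avg_loss, aux.get("rmse_f"))
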
class MACELoss(Metric):
    def __init__(self, loss_fn: torch.nn.Module):
        super().__init__()
        self.loss_fn = loss_fn
        self.add_state("total_loss", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("num_data", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("E_computed", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("delta_es", default=[], dist_reduce_fx="cat")
        self.add_state("delta_es_per_atom", default=[], dist_reduce_fx="cat")
        self.add_state("Fs_computed", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("fs", default=[], dist_reduce_fx="cat")
        self.add_state("delta_fs", default=[], dist_reduce_fx="cat")
        self.add_state(
            "stress_computed", default=torch.tensor(0.0), dist_reduce_fx="sum"
        )
        self.add_state("delta_stress", default=[], dist_reduce_fx="cat")
        self.add_state(
            "virials_computed", default=torch.tensor(0.0), dist_reduce_fx="sum"
        )
        self.add_state("delta_virials", default=[], dist_reduce_fx="cat")
        self.add_state("delta_virials_per_atom", default=[], dist_reduce_fx="cat")
        self.add_state("Mus_computed", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("mus", default=[], dist_reduce_fx="cat")
        self.add_state("delta_mus", default=[], dist_reduce_fx="cat")
        self.add_state("delta_mus_per_atom", default=[], dist_reduce_fx="cat")

    def update(self, batch, output):  # pylint: disable=arguments-differ
        loss = self.loss_fn(pred=output, ref=batch)
        self.total_loss += loss
        self.num_data += batch.num_graphs

        if output.get("energy") is not None and batch.energy is not None:
            self.E_computed += 1.0
            self.delta_es.append(batch.energy - output["energy"])
            self.delta_es_per_atom.append(
                (batch.energy - output["energy"]) / (batch.ptr[1:] - batch.ptr[:-1])
            )
        if output.get("forces") is not None and batch.forces is not None:
            self.Fs_computed += 1.0
            self.fs.append(batch.forces)
            self.delta_fs.append(batch.forces - output["forces"])
        if output.get("stress") is not None and batch.stress is not None:
            self.stress_computed += 1.0
            self.delta_stress.append(batch.stress - output["stress"])
        if output.get("virials") is not None and batch.virials is not None:
            self.virials_computed += 1.0
            self.delta_virials.append(batch.virials - output["virials"])
            self.delta_virials_per_atom.append(
                (batch.virials - output["virials"])
                / (batch.ptr[1:] - batch.ptr[:-1]).view(-1, 1, 1)
            )
        if output.get("dipole") is not None and batch.dipole is not None:
            self.Mus_computed += 1.0
            self.mus.append(batch.dipole)
            self.delta_mus.append(batch.dipole - output["dipole"])
            self.delta_mus_per_atom.append(
                (batch.dipole - output["dipole"])
                / (batch.ptr[1:] - batch.ptr[:-1]).unsqueeze(-1)
            )

    def convert(self, delta: Union[torch.Tensor, List[torch.Tensor]]) -> np.ndarray:
        if isinstance(delta, list):
            delta = torch.cat(delta)
        return to_numpy(delta)

    def compute(self):
        aux = {}
        aux["loss"] = to_numpy(self.total_loss / self.num_data).item()
        if self.E_computed:
            delta_es = self.convert(self.delta_es)
            delta_es_per_atom = self.convert(self.delta_es_per_atom)
            aux["mae_e"] = compute_mae(delta_es)
            aux["mae_e_per_atom"] = compute_mae(delta_es_per_atom)
            aux["rmse_e"] = compute_rmse(delta_es)
            aux["rmse_e_per_atom"] = compute_rmse(delta_es_per_atom)
            aux["q95_e"] = compute_q95(delta_es)
        if self.Fs_computed:
            fs = self.convert(self.fs)
            delta_fs = self.convert(self.delta_fs)
            aux["mae_f"] = compute_mae(delta_fs)
            aux["rel_mae_f"] = compute_rel_mae(delta_fs, fs)
            aux["rmse_f"] = compute_rmse(delta_fs)
            aux["rel_rmse_f"] = compute_rel_rmse(delta_fs, fs)
            aux["q95_f"] = compute_q95(delta_fs)
        if self.stress_computed:
            delta_stress = self.convert(self.delta_stress)
            aux["mae_stress"] = compute_mae(delta_stress)
            aux["rmse_stress"] = compute_rmse(delta_stress)
            aux["q95_stress"] = compute_q95(delta_stress)
        if self.virials_computed:
            delta_virials = self.convert(self.delta_virials)
            delta_virials_per_atom = self.convert(self.delta_virials_per_atom)
            aux["mae_virials"] = compute_mae(delta_virials)
            aux["rmse_virials"] = compute_rmse(delta_virials)
            aux["rmse_virials_per_atom"] = compute_rmse(delta_virials_per_atom)
            aux["q95_virials"] = compute_q95(delta_virials)
        if self.Mus_computed:
            mus = self.convert(self.mus)
            delta_mus = self.convert(self.delta_mus)
            delta_mus_per_atom = self.convert(self.delta_mus_per_atom)
            aux["mae_mu"] = compute_mae(delta_mus)
            aux["mae_mu_per_atom"] = compute_mae(delta_mus_per_atom)
            aux["rel_mae_mu"] = compute_rel_mae(delta_mus, mus)
            aux["rmse_mu"] = compute_rmse(delta_mus)
            aux["rmse_mu_per_atom"] = compute_rmse(delta_mus_per_atom)
            aux["rel_rmse_mu"] = compute_rel_rmse(delta_mus, mus)
            aux["q95_mu"] = compute_q95(delta_mus)

        return aux["loss"], aux
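
# A minimal usage sketch (illustrative, assuming `Metric` is
# torchmetrics.Metric, as the add_state/dist_reduce_fx calls suggest):
# MACELoss follows the accumulate/compute/reset pattern, so it can also be
# used directly without `evaluate`. `model`, `loss_fn` and `loader` are
# hypothetical caller-supplied objects.
def _example_mace_loss(model, loss_fn, loader, device) -> None:
    metrics = MACELoss(loss_fn=loss_fn).to(device)
    for batch in loader:
        batch = batch.to(device)
        output = model(
            batch.to_dict(),
            training=False,
            compute_force=True,
            compute_virials=False,
            compute_stress=False,
        )
        metrics(batch, output)  # forward() accumulates per-batch state
    avg_loss, aux = metrics.compute()
    metrics.reset()
    logging.debug("loss=%.6f metrics=%s", avg_loss, aux)
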
mace-bench/3rdparty/mace/mace/tools/utils.py
0 → 100644
View file @
73ff4f3a
###########################################################################################
# Statistics utilities
# Authors: Ilyes Batatia, Gregor Simm, David Kovacs
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
import json
import logging
import os
import sys
from typing import Any, Dict, Iterable, Optional, Sequence, Union

import numpy as np
import torch

from .torch_tools import to_numpy
def compute_mae(delta: np.ndarray) -> float:
    return np.mean(np.abs(delta)).item()


def compute_rel_mae(delta: np.ndarray, target_val: np.ndarray) -> float:
    target_norm = np.mean(np.abs(target_val))
    return np.mean(np.abs(delta)).item() / (target_norm + 1e-9) * 100


def compute_rmse(delta: np.ndarray) -> float:
    return np.sqrt(np.mean(np.square(delta))).item()


def compute_rel_rmse(delta: np.ndarray, target_val: np.ndarray) -> float:
    target_norm = np.sqrt(np.mean(np.square(target_val))).item()
    return np.sqrt(np.mean(np.square(delta))).item() / (target_norm + 1e-9) * 100


def compute_q95(delta: np.ndarray) -> float:
    return np.percentile(np.abs(delta), q=95)


def compute_c(delta: np.ndarray, eta: float) -> float:
    return np.mean(np.abs(delta) < eta).item()
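
# A small worked example (illustrative only) of the statistics above on a
# hand-made residual vector; the expected values in the comments follow
# directly from the definitions.
def _example_error_statistics() -> None:
    delta = np.array([1.0, -2.0, 3.0])
    target = np.array([10.0, 20.0, 30.0])
    print(compute_mae(delta))              # mean(|delta|) = 2.0
    print(compute_rmse(delta))             # sqrt(mean(delta^2)) ~= 2.160
    print(compute_rel_mae(delta, target))  # 2.0 / 20.0 * 100 ~= 10.0 (percent)
    print(compute_q95(delta))              # 95th percentile of |delta| ~= 2.9
    print(compute_c(delta, eta=2.5))       # fraction of |delta| < eta = 2/3
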
def get_tag(name: str, seed: int) -> str:
    return f"{name}_run-{seed}"
def setup_logger(
    level: Union[int, str] = logging.INFO,
    tag: Optional[str] = None,
    directory: Optional[str] = None,
    rank: Optional[int] = 0,
):
    # Create a logger
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)  # Set to DEBUG to capture all levels

    # Create formatters
    formatter = logging.Formatter(
        "%(asctime)s.%(msecs)03d %(levelname)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Add filter for rank
    logger.addFilter(lambda _: rank == 0)

    # Create console handler
    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(level)
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    if directory is not None and tag is not None:
        os.makedirs(name=directory, exist_ok=True)

        # Create file handler for non-debug logs
        main_log_path = os.path.join(directory, f"{tag}.log")
        fh_main = logging.FileHandler(main_log_path)
        fh_main.setLevel(level)
        fh_main.setFormatter(formatter)
        logger.addHandler(fh_main)

        # Create file handler for debug logs
        debug_log_path = os.path.join(directory, f"{tag}_debug.log")
        fh_debug = logging.FileHandler(debug_log_path)
        fh_debug.setLevel(logging.DEBUG)
        fh_debug.setFormatter(formatter)
        fh_debug.addFilter(lambda record: record.levelno >= logging.DEBUG)
        logger.addHandler(fh_debug)
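
# A minimal usage sketch (illustrative only): with a directory and tag,
# `setup_logger` writes INFO-and-above records to the console and to
# `results/example_run-1.log`, while `results/example_run-1_debug.log`
# additionally captures DEBUG records. The paths and tag are hypothetical.
def _example_setup_logger() -> None:
    setup_logger(level=logging.INFO, tag="example_run-1", directory="results", rank=0)
    logging.info("goes to the console and the main log file")
    logging.debug("captured only by the _debug.log handler")
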
class AtomicNumberTable:
    def __init__(self, zs: Sequence[int]):
        self.zs = zs

    def __len__(self) -> int:
        return len(self.zs)

    def __str__(self):
        return f"AtomicNumberTable: {tuple(s for s in self.zs)}"

    def index_to_z(self, index: int) -> int:
        return self.zs[index]

    def z_to_index(self, atomic_number: int) -> int:
        return self.zs.index(atomic_number)
def get_atomic_number_table_from_zs(zs: Iterable[int]) -> AtomicNumberTable:
    z_set = set()
    for z in zs:
        z_set.add(z)
    return AtomicNumberTable(sorted(list(z_set)))


def atomic_numbers_to_indices(
    atomic_numbers: np.ndarray, z_table: AtomicNumberTable
) -> np.ndarray:
    to_index_fn = np.vectorize(z_table.z_to_index)
    return to_index_fn(atomic_numbers)
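
# A small worked example (illustrative only): build a table for H (Z=1) and
# O (Z=8) and map raw atomic numbers to contiguous model indices.
def _example_z_table() -> None:
    z_table = get_atomic_number_table_from_zs([8, 1, 1, 8])  # zs == [1, 8]
    indices = atomic_numbers_to_indices(np.array([8, 1, 1]), z_table)
    print(indices)                # [1 0 0]
    print(z_table.index_to_z(0))  # 1
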
class UniversalEncoder(json.JSONEncoder):
    def default(self, o):
        if isinstance(o, np.integer):
            return int(o)
        if isinstance(o, np.floating):
            return float(o)
        if isinstance(o, np.ndarray):
            return o.tolist()
        if isinstance(o, torch.Tensor):
            return to_numpy(o)
        return json.JSONEncoder.default(self, o)
class MetricsLogger:
    def __init__(self, directory: str, tag: str) -> None:
        self.directory = directory
        self.filename = tag + ".txt"
        self.path = os.path.join(self.directory, self.filename)

    def log(self, d: Dict[str, Any]) -> None:
        os.makedirs(name=self.directory, exist_ok=True)
        with open(self.path, mode="a", encoding="utf-8") as f:
            f.write(json.dumps(d, cls=UniversalEncoder))
            f.write("\n")
# pylint: disable=abstract-method, arguments-differ
class LAMMPS_MP(torch.autograd.Function):
    @staticmethod
    def forward(ctx, *args):
        feats, data = args  # unpack
        ctx.vec_len = feats.shape[-1]
        ctx.data = data
        out = torch.empty_like(feats)
        data.forward_exchange(feats, out, ctx.vec_len)
        return out

    @staticmethod
    def backward(ctx, *grad_outputs):
        (grad,) = grad_outputs  # unpack
        gout = torch.empty_like(grad)
        ctx.data.reverse_exchange(grad, gout, ctx.vec_len)
        return gout, None
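
# A minimal usage sketch (illustrative only): custom autograd Functions are
# invoked through `.apply`. `mp_data` stands for a hypothetical LAMMPS-side
# object exposing the forward_exchange/reverse_exchange methods that
# LAMMPS_MP assumes.
def _example_lammps_mp(feats: torch.Tensor, mp_data) -> torch.Tensor:
    return LAMMPS_MP.apply(feats, mp_data)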