Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
nivren
ICT-CSP
Commits
fa84b16c
Unverified
Commit
fa84b16c
authored
Aug 24, 2025
by
zcxzcx1
Committed by
GitHub
Aug 24, 2025
Browse files
Add files via upload
parent
09624897
Changes
52
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2500 additions
and
0 deletions
+2500
-0
mace-bench/src/batchopt/__pycache__/pbc_graph_legacy.cpython-310.pyc
...src/batchopt/__pycache__/pbc_graph_legacy.cpython-310.pyc
+0
-0
mace-bench/src/batchopt/__pycache__/relaxengine.cpython-310.pyc
...ench/src/batchopt/__pycache__/relaxengine.cpython-310.pyc
+0
-0
mace-bench/src/batchopt/__pycache__/utils.cpython-310.pyc
mace-bench/src/batchopt/__pycache__/utils.cpython-310.pyc
+0
-0
mace-bench/src/batchopt/atoms_to_graphs.py
mace-bench/src/batchopt/atoms_to_graphs.py
+309
-0
mace-bench/src/batchopt/baseline.py
mace-bench/src/batchopt/baseline.py
+171
-0
mace-bench/src/batchopt/extensions/__init__.py
mace-bench/src/batchopt/extensions/__init__.py
+12
-0
mace-bench/src/batchopt/extensions/__pycache__/__init__.cpython-310.pyc
.../batchopt/extensions/__pycache__/__init__.cpython-310.pyc
+0
-0
mace-bench/src/batchopt/extensions/cuda_ops/__init__.py
mace-bench/src/batchopt/extensions/cuda_ops/__init__.py
+91
-0
mace-bench/src/batchopt/extensions/cuda_ops/__pycache__/__init__.cpython-310.pyc
.../extensions/cuda_ops/__pycache__/__init__.cpython-310.pyc
+0
-0
mace-bench/src/batchopt/extensions/cuda_ops/pbc_graph.cu
mace-bench/src/batchopt/extensions/cuda_ops/pbc_graph.cu
+286
-0
mace-bench/src/batchopt/pbc_graph.py
mace-bench/src/batchopt/pbc_graph.py
+158
-0
mace-bench/src/batchopt/pbc_graph_legacy.py
mace-bench/src/batchopt/pbc_graph_legacy.py
+563
-0
mace-bench/src/batchopt/relaxation/__init__.py
mace-bench/src/batchopt/relaxation/__init__.py
+11
-0
mace-bench/src/batchopt/relaxation/__pycache__/__init__.cpython-310.pyc
.../batchopt/relaxation/__pycache__/__init__.cpython-310.pyc
+0
-0
mace-bench/src/batchopt/relaxation/__pycache__/ase_utils.cpython-310.pyc
...batchopt/relaxation/__pycache__/ase_utils.cpython-310.pyc
+0
-0
mace-bench/src/batchopt/relaxation/__pycache__/optimizable.cpython-310.pyc
...tchopt/relaxation/__pycache__/optimizable.cpython-310.pyc
+0
-0
mace-bench/src/batchopt/relaxation/ase_utils.py
mace-bench/src/batchopt/relaxation/ase_utils.py
+95
-0
mace-bench/src/batchopt/relaxation/optimizable.py
mace-bench/src/batchopt/relaxation/optimizable.py
+791
-0
mace-bench/src/batchopt/relaxation/optimizers/__init__.py
mace-bench/src/batchopt/relaxation/optimizers/__init__.py
+13
-0
mace-bench/src/batchopt/relaxation/optimizers/__pycache__/__init__.cpython-310.pyc
...elaxation/optimizers/__pycache__/__init__.cpython-310.pyc
+0
-0
No files found.
mace-bench/src/batchopt/__pycache__/pbc_graph_legacy.cpython-310.pyc
0 → 100644
View file @
fa84b16c
File added
mace-bench/src/batchopt/__pycache__/relaxengine.cpython-310.pyc
0 → 100644
View file @
fa84b16c
File added
mace-bench/src/batchopt/__pycache__/utils.cpython-310.pyc
0 → 100644
View file @
fa84b16c
File added
mace-bench/src/batchopt/atoms_to_graphs.py
0 → 100644
View file @
fa84b16c
"""
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from
__future__
import
annotations
from
typing
import
TYPE_CHECKING
import
ase.db.sqlite
import
ase.io.trajectory
import
numpy
as
np
import
torch
from
ase.geometry
import
wrap_positions
from
torch_geometric.data
import
Data
from
batchopt.utils
import
collate
if
TYPE_CHECKING
:
from
collections.abc
import
Sequence
try
:
from
pymatgen.io.ase
import
AseAtomsAdaptor
except
ImportError
:
AseAtomsAdaptor
=
None
from
tqdm
import
tqdm
class
AtomsToGraphs
:
"""A class to help convert periodic atomic structures to graphs.
The AtomsToGraphs class takes in periodic atomic structures in form of ASE atoms objects and converts
them into graph representations for use in PyTorch. The primary purpose of this class is to determine the
nearest neighbors within some radius around each individual atom, taking into account PBC, and set the
pair index and distance between atom pairs appropriately. Lastly, atomic properties and the graph information
are put into a PyTorch geometric data object for use with PyTorch.
Args:
max_neigh (int): Maximum number of neighbors to consider.
radius (int or float): Cutoff radius in Angstroms to search for neighbors.
r_energy (bool): Return the energy with other properties. Default is False, so the energy will not be returned.
r_forces (bool): Return the forces with other properties. Default is False, so the forces will not be returned.
r_stress (bool): Return the stress with other properties. Default is False, so the stress will not be returned.
r_distances (bool): Return the distances with other properties.
Default is False, so the distances will not be returned.
r_edges (bool): Return interatomic edges with other properties. Default is True, so edges will be returned.
r_fixed (bool): Return a binary vector with flags for fixed (1) vs free (0) atoms.
Default is True, so the fixed indices will be returned.
r_pbc (bool): Return the periodic boundary conditions with other properties.
Default is False, so the periodic boundary conditions will not be returned.
r_data_keys (sequence of str, optional): Return values corresponding to given keys in atoms.info data with other
properties. Default is None, so no data will be returned as properties.
Attributes:
max_neigh (int): Maximum number of neighbors to consider.
radius (int or float): Cutoff radius in Angstoms to search for neighbors.
r_energy (bool): Return the energy with other properties. Default is False, so the energy will not be returned.
r_forces (bool): Return the forces with other properties. Default is False, so the forces will not be returned.
r_stress (bool): Return the stress with other properties. Default is False, so the stress will not be returned.
r_distances (bool): Return the distances with other properties.
Default is False, so the distances will not be returned.
r_edges (bool): Return interatomic edges with other properties. Default is True, so edges will be returned.
r_fixed (bool): Return a binary vector with flags for fixed (1) vs free (0) atoms.
Default is True, so the fixed indices will be returned.
r_pbc (bool): Return the periodic boundary conditions with other properties.
Default is False, so the periodic boundary conditions will not be returned.
r_data_keys (sequence of str, optional): Return values corresponding to given keys in atoms.info data with other
properties. Default is None, so no data will be returned as properties.
"""
def
__init__
(
self
,
max_neigh
:
int
=
200
,
radius
:
int
=
6
,
r_energy
:
bool
=
False
,
r_forces
:
bool
=
False
,
r_distances
:
bool
=
False
,
r_edges
:
bool
=
True
,
r_fixed
:
bool
=
True
,
r_pbc
:
bool
=
False
,
r_stress
:
bool
=
False
,
r_data_keys
:
Sequence
[
str
]
|
None
=
None
,
)
->
None
:
self
.
max_neigh
=
max_neigh
self
.
radius
=
radius
self
.
r_energy
=
r_energy
self
.
r_forces
=
r_forces
self
.
r_stress
=
r_stress
self
.
r_distances
=
r_distances
self
.
r_fixed
=
r_fixed
self
.
r_edges
=
r_edges
self
.
r_pbc
=
r_pbc
self
.
r_data_keys
=
r_data_keys
def
_get_neighbors_pymatgen
(
self
,
atoms
:
ase
.
Atoms
):
"""Preforms nearest neighbor search and returns edge index, distances,
and cell offsets"""
if
AseAtomsAdaptor
is
None
:
raise
RuntimeError
(
"Unable to import pymatgen.io.ase.AseAtomsAdaptor. Make sure pymatgen is properly installed."
)
struct
=
AseAtomsAdaptor
.
get_structure
(
atoms
)
_c_index
,
_n_index
,
_offsets
,
n_distance
=
struct
.
get_neighbor_list
(
r
=
self
.
radius
,
numerical_tol
=
0
,
exclude_self
=
True
)
_nonmax_idx
=
[]
for
i
in
range
(
len
(
atoms
)):
idx_i
=
(
_c_index
==
i
).
nonzero
()[
0
]
# sort neighbors by distance, remove edges larger than max_neighbors
idx_sorted
=
np
.
argsort
(
n_distance
[
idx_i
])[:
self
.
max_neigh
]
_nonmax_idx
.
append
(
idx_i
[
idx_sorted
])
_nonmax_idx
=
np
.
concatenate
(
_nonmax_idx
)
_c_index
=
_c_index
[
_nonmax_idx
]
_n_index
=
_n_index
[
_nonmax_idx
]
n_distance
=
n_distance
[
_nonmax_idx
]
_offsets
=
_offsets
[
_nonmax_idx
]
return
_c_index
,
_n_index
,
n_distance
,
_offsets
def
_reshape_features
(
self
,
c_index
,
n_index
,
n_distance
,
offsets
):
"""Stack center and neighbor index and reshapes distances,
takes in np.arrays and returns torch tensors"""
edge_index
=
torch
.
LongTensor
(
np
.
vstack
((
n_index
,
c_index
)))
edge_distances
=
torch
.
FloatTensor
(
n_distance
)
cell_offsets
=
torch
.
LongTensor
(
offsets
)
# remove distances smaller than a tolerance ~ 0. The small tolerance is
# needed to correct for pymatgen's neighbor_list returning self atoms
# in a few edge cases.
nonzero
=
torch
.
where
(
edge_distances
>=
1e-8
)[
0
]
edge_index
=
edge_index
[:,
nonzero
]
edge_distances
=
edge_distances
[
nonzero
]
cell_offsets
=
cell_offsets
[
nonzero
]
return
edge_index
,
edge_distances
,
cell_offsets
def
get_edge_distance_vec
(
self
,
pos
,
edge_index
,
cell
,
cell_offsets
,
):
row
,
col
=
edge_index
distance_vectors
=
pos
[
row
]
-
pos
[
col
]
# correct for pbc
cell
=
torch
.
repeat_interleave
(
cell
,
edge_index
.
shape
[
1
],
dim
=
0
)
offsets
=
cell_offsets
.
float
().
view
(
-
1
,
1
,
3
).
bmm
(
cell
.
float
()).
view
(
-
1
,
3
)
distance_vectors
+=
offsets
return
distance_vectors
def
convert
(
self
,
atoms
:
ase
.
Atoms
,
sid
=
None
):
"""Convert a single atomic structure to a graph.
Args:
atoms (ase.atoms.Atoms): An ASE atoms object.
sid (uniquely identifying object): An identifier that can be used to track the structure in downstream
tasks. Common sids used in OCP datasets include unique strings or integers.
Returns:
data (torch_geometric.data.Data): A torch geometic data object with positions, atomic_numbers, tags,
and optionally, energy, forces, distances, edges, and periodic boundary conditions.
Optional properties can included by setting r_property=True when constructing the class.
"""
# set the atomic numbers, positions, and cell
positions
=
np
.
array
(
atoms
.
get_positions
(),
copy
=
True
)
pbc
=
np
.
array
(
atoms
.
pbc
,
copy
=
True
)
cell
=
np
.
array
(
atoms
.
get_cell
(
complete
=
True
),
copy
=
True
)
# TODO: change this back &&& ^^^
# positions = wrap_positions(positions, cell, pbc=pbc, eps=0)
atomic_numbers
=
torch
.
tensor
(
atoms
.
get_atomic_numbers
(),
dtype
=
torch
.
uint8
)
positions
=
torch
.
from_numpy
(
positions
).
float
()
cell
=
torch
.
from_numpy
(
cell
).
view
(
1
,
3
,
3
).
float
()
natoms
=
positions
.
shape
[
0
]
# initialized to torch.zeros(natoms) if tags missing.
# https://wiki.fysik.dtu.dk/ase/_modules/ase/atoms.html#Atoms.get_tags
tags
=
torch
.
tensor
(
atoms
.
get_tags
(),
dtype
=
torch
.
int
)
# put the minimum data in torch geometric data object
data
=
Data
(
cell
=
cell
,
pos
=
positions
,
atomic_numbers
=
atomic_numbers
,
natoms
=
natoms
,
tags
=
tags
,
)
# Optionally add a systemid (sid) to the object
if
sid
is
not
None
:
data
.
sid
=
sid
# optionally include other properties
if
self
.
r_edges
:
# run internal functions to get padded indices and distances
atoms_copy
=
atoms
.
copy
()
atoms_copy
.
set_positions
(
positions
)
split_idx_dist
=
self
.
_get_neighbors_pymatgen
(
atoms_copy
)
edge_index
,
edge_distances
,
cell_offsets
=
self
.
_reshape_features
(
*
split_idx_dist
)
data
.
edge_index
=
edge_index
data
.
cell_offsets
=
cell_offsets
data
.
edge_distance_vec
=
self
.
get_edge_distance_vec
(
positions
,
edge_index
,
cell
,
cell_offsets
)
del
atoms_copy
if
self
.
r_energy
:
energy
=
atoms
.
get_potential_energy
(
apply_constraint
=
False
)
data
.
energy
=
energy
if
self
.
r_forces
:
forces
=
torch
.
tensor
(
atoms
.
get_forces
(
apply_constraint
=
False
),
dtype
=
torch
.
float32
)
data
.
forces
=
forces
if
self
.
r_stress
:
stress
=
torch
.
tensor
(
atoms
.
get_stress
(
apply_constraint
=
False
,
voigt
=
False
),
dtype
=
torch
.
float32
,
)
data
.
stress
=
stress
if
self
.
r_distances
and
self
.
r_edges
:
data
.
distances
=
edge_distances
if
self
.
r_fixed
:
fixed_idx
=
torch
.
zeros
(
natoms
,
dtype
=
torch
.
int
)
if
hasattr
(
atoms
,
"constraints"
):
from
ase.constraints
import
FixAtoms
for
constraint
in
atoms
.
constraints
:
if
isinstance
(
constraint
,
FixAtoms
):
fixed_idx
[
constraint
.
index
]
=
1
data
.
fixed
=
fixed_idx
if
self
.
r_pbc
:
data
.
pbc
=
torch
.
tensor
(
atoms
.
pbc
,
dtype
=
torch
.
bool
)
if
self
.
r_data_keys
is
not
None
:
for
data_key
in
self
.
r_data_keys
:
data
[
data_key
]
=
(
atoms
.
info
[
data_key
]
if
isinstance
(
atoms
.
info
[
data_key
],
(
int
,
float
,
str
))
else
torch
.
tensor
(
atoms
.
info
[
data_key
])
)
return
data
def
convert_all
(
self
,
atoms_collection
,
processed_file_path
:
str
|
None
=
None
,
collate_and_save
=
False
,
disable_tqdm
=
False
,
):
"""Convert all atoms objects in a list or in an ase.db to graphs.
Args:
atoms_collection (list of ase.atoms.Atoms or ase.db.sqlite.SQLite3Database):
Either a list of ASE atoms objects or an ASE database.
processed_file_path (str):
A string of the path to where the processed file will be written. Default is None.
collate_and_save (bool): A boolean to collate and save or not. Default is False, so will not write a file.
Returns:
data_list (list of torch_geometric.data.Data):
A list of torch geometric data objects containing molecular graph info and properties.
"""
# list for all data
data_list
=
[]
if
isinstance
(
atoms_collection
,
list
):
atoms_iter
=
atoms_collection
elif
isinstance
(
atoms_collection
,
ase
.
db
.
sqlite
.
SQLite3Database
):
atoms_iter
=
atoms_collection
.
select
()
elif
isinstance
(
atoms_collection
,
(
ase
.
io
.
trajectory
.
SlicedTrajectory
,
ase
.
io
.
trajectory
.
TrajectoryReader
),
):
atoms_iter
=
atoms_collection
else
:
raise
NotImplementedError
for
atoms
in
tqdm
(
atoms_iter
,
desc
=
"converting ASE atoms collection to graphs"
,
total
=
len
(
atoms_collection
),
unit
=
" systems"
,
disable
=
disable_tqdm
,
):
# check if atoms is an ASE Atoms object this for the ase.db case
data
=
self
.
convert
(
atoms
if
isinstance
(
atoms
,
ase
.
atoms
.
Atoms
)
else
atoms
.
toatoms
()
)
data_list
.
append
(
data
)
if
collate_and_save
:
data
,
slices
=
collate
(
data_list
)
torch
.
save
((
data
,
slices
),
processed_file_path
)
return
data_list
mace-bench/src/batchopt/baseline.py
0 → 100644
View file @
fa84b16c
"""
Copyright (c) 2025 {Chengxi Zhao, Zhaojia Ma, Dingrui Fan}
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from
ase.io
import
read
import
logging
from
joblib
import
Parallel
,
delayed
from
ase.optimize
import
LBFGS
as
ASE_LBFGS
from
ase.optimize
import
QuasiNewton
as
ASE_QuasiNewton
from
ase.optimize
import
BFGS
as
ASE_BFGS
import
time
import
csv
import
os
try
:
from
mace.calculators
import
mace_off
except
ImportError
:
logging
.
warning
(
"Failed to import MACE modules"
)
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
'%(asctime)s - %(levelname)s - %(message)s'
)
def
ensure_directory
(
directory
):
"""Create directory if it doesn't exist."""
if
not
os
.
path
.
exists
(
directory
):
os
.
makedirs
(
directory
)
logging
.
info
(
f
"Created directory:
{
directory
}
"
)
def
baseline_task
(
file
,
device
,
max_steps
,
filter1
=
None
,
filter2
=
None
,
skip_second_stage
=
False
,
scalar_pressure
=
0.0006
,
first_optimizer
=
"LBFGS"
,
second_optimizer
=
"LBFGS"
):
"""
Runs the baseline optimization using LBFGS from ase.optimize.
"""
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
device
.
split
(
":"
)[
-
1
]
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
'%(asctime)s - %(levelname)s - %(message)s'
)
logging
.
info
(
f
"Starting baseline optimization for file
{
file
}
on device
{
device
}
."
)
start_time
=
time
.
perf_counter
()
crystal
=
read
(
file
)
# calc = mace_off(model="small", device=device)
calc
=
mace_off
(
model
=
"small"
,
device
=
"cuda"
)
crystal
.
calc
=
calc
first_optimizer_class
=
{
"LBFGS"
:
ASE_LBFGS
,
"QuasiNewton"
:
ASE_QuasiNewton
,
"BFGS"
:
ASE_BFGS
}.
get
(
first_optimizer
,
ASE_LBFGS
)
# First optimization stage
if
filter1
==
"UnitCellFilter"
:
from
ase.filters
import
UnitCellFilter
atoms_with_filter
=
UnitCellFilter
(
crystal
,
scalar_pressure
=
scalar_pressure
)
first_optimizer_instance
=
first_optimizer_class
(
atoms_with_filter
)
elif
filter1
==
"FrechetCellFilter"
:
from
ase.filters
import
FrechetCellFilter
atoms_with_filter
=
FrechetCellFilter
(
crystal
,
scalar_pressure
=
scalar_pressure
)
first_optimizer_instance
=
first_optimizer_class
(
atoms_with_filter
)
else
:
first_optimizer_instance
=
first_optimizer_class
(
crystal
)
start_time1
=
time
.
perf_counter
()
first_optimizer_instance
.
run
(
fmax
=
0.01
,
steps
=
max_steps
)
end_time1
=
time
.
perf_counter
()
# Save intermediate result
output_dir_press
=
"./cif_result_press"
output_file_press
=
os
.
path
.
join
(
output_dir_press
,
os
.
path
.
basename
(
file
).
replace
(
".cif"
,
"_press.cif"
))
crystal
.
write
(
output_file_press
)
elapsed_time1
=
end_time1
-
start_time1
steps1
=
first_optimizer_instance
.
nsteps
if
skip_second_stage
:
ret_result
=
{
"file"
:
file
,
"stage1_time"
:
elapsed_time1
,
"stage1_steps"
:
steps1
,
"stage2_time"
:
0.0
,
"stage2_steps"
:
0
,
"total_time"
:
elapsed_time1
,
"total_steps"
:
steps1
}
else
:
# Second optimization stage
crystal
=
read
(
output_file_press
)
crystal
.
calc
=
calc
second_optimizer_class
=
{
"LBFGS"
:
ASE_LBFGS
,
"QuasiNewton"
:
ASE_QuasiNewton
,
"BFGS"
:
ASE_BFGS
}.
get
(
second_optimizer
,
ASE_LBFGS
)
if
filter2
==
"UnitCellFilter"
:
from
ase.filters
import
UnitCellFilter
atoms_with_filter2
=
UnitCellFilter
(
crystal
)
second_optimizer_instance
=
second_optimizer_class
(
atoms_with_filter2
)
elif
filter2
==
"FrechetCellFilter"
:
from
ase.filters
import
FrechetCellFilter
atoms_with_filter2
=
FrechetCellFilter
(
crystal
)
second_optimizer_instance
=
second_optimizer_class
(
atoms_with_filter2
)
else
:
second_optimizer_instance
=
second_optimizer_class
(
crystal
)
start_time2
=
time
.
perf_counter
()
second_optimizer_instance
.
run
(
fmax
=
0.01
,
steps
=
max_steps
)
end_time2
=
time
.
perf_counter
()
# Save final result
output_dir_final
=
"./cif_result_final"
output_file_final
=
os
.
path
.
join
(
output_dir_final
,
os
.
path
.
basename
(
file
).
replace
(
".cif"
,
"_opt.cif"
))
crystal
.
write
(
output_file_final
)
# Collect metrics
elapsed_time2
=
end_time2
-
start_time2
total_time
=
elapsed_time1
+
elapsed_time2
steps2
=
second_optimizer_instance
.
nsteps
ret_result
=
{
"file"
:
file
,
"stage1_time"
:
elapsed_time1
,
"stage1_steps"
:
steps1
,
"stage2_time"
:
elapsed_time2
,
"stage2_steps"
:
steps2
,
"total_time"
:
total_time
,
"total_steps"
:
steps1
+
steps2
}
logging
.
info
(
f
"Baseline optimization completed for file
{
file
}
."
)
return
ret_result
def
run_baseline
(
files
,
num_workers
,
devices
,
max_steps
,
filter1
=
None
,
filter2
=
None
,
skip_second_stage
=
False
,
scalar_pressure
=
0.0006
,
optimizer1
=
None
,
optimizer2
=
None
):
"""
Runs the baseline optimization using LBFGS from ase.optimize.
"""
logging
.
info
(
f
"Starting baseline optimization with
{
num_workers
}
workers."
)
start_time
=
time
.
perf_counter
()
results
=
Parallel
(
n_jobs
=
num_workers
)(
delayed
(
baseline_task
)(
file
,
devices
[
i
%
len
(
devices
)],
max_steps
,
filter1
,
filter2
,
skip_second_stage
,
scalar_pressure
,
optimizer1
,
optimizer2
)
for
i
,
file
in
enumerate
(
files
)
)
end_time
=
time
.
perf_counter
()
csv_file
=
"results_baseline.csv"
with
open
(
csv_file
,
mode
=
'w'
,
newline
=
''
)
as
file
:
writer
=
csv
.
DictWriter
(
file
,
fieldnames
=
[
"file"
,
"stage1_time"
,
"stage1_steps"
,
"stage2_time"
,
"stage2_steps"
,
"total_time"
,
"total_steps"
])
writer
.
writeheader
()
for
result
in
results
:
writer
.
writerow
(
result
)
logging
.
info
(
f
"Baseline optimization completed in
{
end_time
-
start_time
:.
2
f
}
seconds."
)
final_elapsed_time
=
end_time
-
start_time
summary_csv_file
=
"summary_baseline.csv"
with
open
(
summary_csv_file
,
mode
=
'w'
,
newline
=
''
)
as
file
:
writer
=
csv
.
DictWriter
(
file
,
fieldnames
=
[
"elapsed_time"
,
"num_workers"
,
"batch_size"
])
writer
.
writeheader
()
writer
.
writerow
({
"elapsed_time"
:
final_elapsed_time
,
"num_workers"
:
num_workers
,
"batch_size"
:
1
})
logging
.
info
(
f
"Summary results written to
{
summary_csv_file
}
."
)
\ No newline at end of file
mace-bench/src/batchopt/extensions/__init__.py
0 → 100644
View file @
fa84b16c
"""
Copyright (c) 2025 Ma Zhaojia
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
BatchOpt Extensions - C++ and CUDA implementations for performance-critical operations.
This module provides optimized implementations of common operations using
torch.utils.cpp_extension for JIT compilation.
"""
mace-bench/src/batchopt/extensions/__pycache__/__init__.cpython-310.pyc
0 → 100644
View file @
fa84b16c
File added
mace-bench/src/batchopt/extensions/cuda_ops/__init__.py
0 → 100644
View file @
fa84b16c
"""
Copyright (c) 2025 Ma Zhaojia
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
CUDA Extension wrapper for vector addition and PBC graph operations.
"""
import
torch
from
torch.utils.cpp_extension
import
load
import
os
def
load_cuda_extension
():
"""Load the CUDA extension for vector addition."""
# Check if CUDA is available
if
not
torch
.
cuda
.
is_available
():
raise
RuntimeError
(
"CUDA is not available. Cannot load CUDA extension."
)
# Get the directory of this file
current_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
# Path to the CUDA source file
cuda_file
=
os
.
path
.
join
(
current_dir
,
"vector_add.cu"
)
# Load the extension
return
load
(
name
=
"vector_add_cuda"
,
sources
=
[
cuda_file
],
verbose
=
True
,
extra_cflags
=
[
'-O3'
],
extra_cuda_cflags
=
[
'-O3'
,
'--use_fast_math'
],
)
def
load_pbc_graph_cuda_extension
():
"""Load the CUDA extension for PBC graph operations."""
# Check if CUDA is available
if
not
torch
.
cuda
.
is_available
():
raise
RuntimeError
(
"CUDA is not available. Cannot load CUDA extension."
)
# Get the directory of this file
current_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
# Path to the CUDA source file
cuda_file
=
os
.
path
.
join
(
current_dir
,
"pbc_graph.cu"
)
# Load the extension
return
load
(
name
=
"pbc_graph_cuda"
,
sources
=
[
cuda_file
],
verbose
=
True
,
extra_cflags
=
[
'-O3'
],
extra_cuda_cflags
=
[
'-O3'
,
'--use_fast_math'
],
)
# Global variable to store loaded extension
_cuda_extension
=
None
_pbc_graph_cuda_extension
=
None
def
get_cuda_extension
():
"""Get or load the CUDA extension."""
global
_cuda_extension
if
_cuda_extension
is
None
:
_cuda_extension
=
load_cuda_extension
()
return
_cuda_extension
def
get_pbc_graph_cuda_extension
():
"""Get or load the PBC graph CUDA extension."""
global
_pbc_graph_cuda_extension
if
_pbc_graph_cuda_extension
is
None
:
_pbc_graph_cuda_extension
=
load_pbc_graph_cuda_extension
()
return
_pbc_graph_cuda_extension
def
vector_add_cuda
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Perform vector addition using CUDA implementation.
Args:
a: First input tensor (must be on CUDA device)
b: Second input tensor (must be on CUDA device)
Returns:
Result tensor of element-wise addition
"""
if
not
torch
.
cuda
.
is_available
():
raise
RuntimeError
(
"CUDA is not available."
)
if
not
(
a
.
is_cuda
and
b
.
is_cuda
):
raise
ValueError
(
"CUDA implementation requires CUDA tensors. Use .cuda() to move tensors to GPU."
)
extension
=
get_cuda_extension
()
return
extension
.
vector_add
(
a
.
float
(),
b
.
float
())
mace-bench/src/batchopt/extensions/cuda_ops/__pycache__/__init__.cpython-310.pyc
0 → 100644
View file @
fa84b16c
File added
mace-bench/src/batchopt/extensions/cuda_ops/pbc_graph.cu
0 → 100644
View file @
fa84b16c
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#include <type_traits>
// Template function to get appropriate epsilon for different floating point types
template
<
typename
T
>
__device__
__forceinline__
T
get_epsilon
()
{
if
constexpr
(
std
::
is_same_v
<
T
,
float
>
)
{
return
static_cast
<
T
>
(
1e-8
);
}
else
if
constexpr
(
std
::
is_same_v
<
T
,
double
>
)
{
return
static_cast
<
T
>
(
1e-12
);
}
else
{
return
static_cast
<
T
>
(
1e-8
);
// fallback
}
}
// Templated CUDA kernel for computing pairwise distances with PBC offsets
// This version avoids repeat_interleave by computing offsets directly in the kernel
template
<
typename
T
>
__global__
void
pbc_distance_kernel_optimized
(
const
T
*
pos1
,
const
T
*
pos2
,
const
T
*
pbc_offsets
,
// [batch_size, 3]
const
int64_t
*
num_atoms_per_image_sqr
,
// [batch_size]
const
int64_t
*
batch_offsets
,
// [batch_size] - cumulative offsets for each batch
T
*
distances_squared
,
bool
*
valid_mask
,
int
num_pairs
,
T
radius_squared
)
{
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
idx
<
num_pairs
)
{
// Find which batch this pair belongs to
int
batch_idx
=
0
;
while
(
batch_idx
<
num_pairs
&&
idx
>=
batch_offsets
[
batch_idx
+
1
])
{
batch_idx
++
;
}
// Get PBC offset for this batch
T
offset_x
=
pbc_offsets
[
batch_idx
*
3
];
T
offset_y
=
pbc_offsets
[
batch_idx
*
3
+
1
];
T
offset_z
=
pbc_offsets
[
batch_idx
*
3
+
2
];
// Get positions for this atom pair with PBC offset
T
dx
=
pos2
[
idx
*
3
]
-
pos1
[
idx
*
3
]
+
offset_x
;
T
dy
=
pos2
[
idx
*
3
+
1
]
-
pos1
[
idx
*
3
+
1
]
+
offset_y
;
T
dz
=
pos2
[
idx
*
3
+
2
]
-
pos1
[
idx
*
3
+
2
]
+
offset_z
;
// Compute squared distance
T
dist_sq
=
dx
*
dx
+
dy
*
dy
+
dz
*
dz
;
distances_squared
[
idx
]
=
dist_sq
;
// Check if within radius
valid_mask
[
idx
]
=
(
dist_sq
<=
radius_squared
)
&&
(
dist_sq
>
get_epsilon
<
T
>
());
}
}
// Original kernel for fallback
template
<
typename
T
>
__global__
void
pbc_distance_kernel
(
const
T
*
pos1
,
const
T
*
pos2
,
const
T
*
pbc_offsets
,
T
*
distances_squared
,
bool
*
valid_mask
,
int
num_pairs
,
T
radius_squared
)
{
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
idx
<
num_pairs
)
{
// Get positions for this atom pair
T
dx
=
pos2
[
idx
*
3
]
-
pos1
[
idx
*
3
]
+
pbc_offsets
[
idx
*
3
];
T
dy
=
pos2
[
idx
*
3
+
1
]
-
pos1
[
idx
*
3
+
1
]
+
pbc_offsets
[
idx
*
3
+
1
];
T
dz
=
pos2
[
idx
*
3
+
2
]
-
pos1
[
idx
*
3
+
2
]
+
pbc_offsets
[
idx
*
3
+
2
];
// Compute squared distance
T
dist_sq
=
dx
*
dx
+
dy
*
dy
+
dz
*
dz
;
distances_squared
[
idx
]
=
dist_sq
;
// Check if within radius
valid_mask
[
idx
]
=
(
dist_sq
<=
radius_squared
)
&&
(
dist_sq
>
get_epsilon
<
T
>
());
}
}
// Template helper function to launch the appropriate optimized kernel
template
<
typename
T
>
inline
void
launch_pbc_distance_kernel_optimized
(
const
T
*
pos1
,
const
T
*
pos2
,
const
T
*
pbc_offsets
,
const
int64_t
*
num_atoms_per_image_sqr
,
const
int64_t
*
batch_offsets
,
T
*
distances_squared
,
bool
*
valid_mask
,
int
num_pairs
,
T
radius_squared
,
int
blocks
,
int
threads_per_block
)
{
pbc_distance_kernel_optimized
<
T
><<<
blocks
,
threads_per_block
>>>
(
pos1
,
pos2
,
pbc_offsets
,
num_atoms_per_image_sqr
,
batch_offsets
,
distances_squared
,
valid_mask
,
num_pairs
,
radius_squared
);
}
// Template helper function to launch the appropriate kernel (fallback)
template
<
typename
T
>
void
launch_pbc_distance_kernel
(
const
T
*
pos1
,
const
T
*
pos2
,
const
T
*
pbc_offsets
,
T
*
distances_squared
,
bool
*
valid_mask
,
int
num_pairs
,
T
radius_squared
,
int
blocks
,
int
threads_per_block
)
{
pbc_distance_kernel
<
T
><<<
blocks
,
threads_per_block
>>>
(
pos1
,
pos2
,
pbc_offsets
,
distances_squared
,
valid_mask
,
num_pairs
,
radius_squared
);
}
// CUDA function to compute distances for all unit cell offsets
std
::
vector
<
torch
::
Tensor
>
pbc_distance_cuda
(
torch
::
Tensor
pos1
,
torch
::
Tensor
pos2
,
torch
::
Tensor
data_cell
,
torch
::
Tensor
num_atoms_per_image_sqr
,
int
batch_size
,
std
::
vector
<
int
>
max_rep
,
float
radius
,
torch
::
Device
device
)
{
// Convert tensors to CUDA if not already, but preserve original dtype
pos1
=
pos1
.
to
(
device
).
contiguous
();
pos2
=
pos2
.
to
(
device
).
contiguous
();
data_cell
=
data_cell
.
to
(
device
).
contiguous
();
num_atoms_per_image_sqr
=
num_atoms_per_image_sqr
.
to
(
device
);
// Check that all position tensors have the same dtype
TORCH_CHECK
(
pos1
.
dtype
()
==
pos2
.
dtype
(),
"pos1 and pos2 must have the same dtype"
);
TORCH_CHECK
(
pos1
.
dtype
()
==
data_cell
.
dtype
(),
"pos1 and data_cell must have the same dtype"
);
// Determine if we're working with float32 or float64
bool
is_float64
=
pos1
.
dtype
()
==
torch
::
kFloat64
;
int
num_pairs
=
pos1
.
size
(
0
);
// Storage for all results across unit cells
std
::
vector
<
torch
::
Tensor
>
all_index1
,
all_index2
,
all_unit_cell
,
all_distances_sq
;
// Create base indices for original atom pairs
torch
::
Tensor
base_indices
=
torch
::
arange
(
num_pairs
,
torch
::
dtype
(
torch
::
kLong
).
device
(
device
));
// Launch parameters
int
threads_per_block
=
512
;
int
blocks
=
(
num_pairs
+
threads_per_block
-
1
)
/
threads_per_block
;
// Pre-allocate tensors outside the loop for reuse
torch
::
Tensor
distances_squared
=
torch
::
zeros
({
num_pairs
},
torch
::
dtype
(
pos1
.
dtype
()).
device
(
device
));
torch
::
Tensor
valid_mask
=
torch
::
zeros
({
num_pairs
},
torch
::
dtype
(
torch
::
kBool
).
device
(
device
));
torch
::
Tensor
unit_cell_offset
=
torch
::
zeros
({
3
},
torch
::
dtype
(
pos1
.
dtype
()).
device
(
device
));
torch
::
Tensor
unit_cell_offset_batch
=
torch
::
zeros
({
batch_size
,
3
,
1
},
torch
::
dtype
(
pos1
.
dtype
()).
device
(
device
));
// Pre-compute batch offsets for optimized kernel
torch
::
Tensor
batch_offsets
=
torch
::
zeros
({
batch_size
+
1
},
torch
::
dtype
(
torch
::
kLong
).
device
(
device
));
torch
::
Tensor
cumsum
=
torch
::
cumsum
(
num_atoms_per_image_sqr
,
0
);
batch_offsets
.
slice
(
0
,
1
,
batch_size
+
1
)
=
cumsum
;
// Iterate over unit cell offsets (triple loop)
// NOTE: for i, j, k loop can not be flatten, as we need to limit the device memory usage
#pragma unroll
for
(
int
i
=
-
max_rep
[
0
];
i
<=
max_rep
[
0
];
i
++
)
{
#pragma unroll
for
(
int
j
=
-
max_rep
[
1
];
j
<=
max_rep
[
1
];
j
++
)
{
#pragma unroll
for
(
int
k
=
-
max_rep
[
2
];
k
<=
max_rep
[
2
];
k
++
)
{
// Reuse pre-allocated unit cell offset tensor
unit_cell_offset
[
0
]
=
static_cast
<
float
>
(
i
);
unit_cell_offset
[
1
]
=
static_cast
<
float
>
(
j
);
unit_cell_offset
[
2
]
=
static_cast
<
float
>
(
k
);
// Compute PBC offsets for this unit cell
// unit_cell_offset_batch.fill_(0);
unit_cell_offset_batch
.
select
(
2
,
0
)
=
unit_cell_offset
.
unsqueeze
(
0
).
expand
({
batch_size
,
-
1
});
torch
::
Tensor
pbc_offsets
=
torch
::
bmm
(
data_cell
,
unit_cell_offset_batch
).
squeeze
(
-
1
);
// // Optimized: Use index_select instead of repeat_interleave
// // Create index tensor for selecting pbc_offsets based on atom pairs
// int64_t offset = 0;
// for (int b = 0; b < batch_size; b++) {
// int64_t num_pairs_in_batch = num_atoms_per_image_sqr[b].item<int64_t>();
// auto batch_indices = torch::full({num_pairs_in_batch}, b,
// torch::dtype(torch::kLong).device(device));
// pbc_offsets_per_atom.slice(0, offset, offset + num_pairs_in_batch) =
// pbc_offsets.index_select(0, batch_indices);
// offset += num_pairs_in_batch;
// }
// Reset output tensors for reuse
// distances_squared.fill_(0);
// valid_mask.fill_(false);
// Launch templated CUDA kernel
if
(
is_float64
)
{
double
radius_squared
=
static_cast
<
double
>
(
radius
)
*
static_cast
<
double
>
(
radius
);
launch_pbc_distance_kernel_optimized
<
double
>
(
pos1
.
data_ptr
<
double
>
(),
pos2
.
data_ptr
<
double
>
(),
// pbc_offsets_per_atom.data_ptr<double>(),
pbc_offsets
.
data_ptr
<
double
>
(),
num_atoms_per_image_sqr
.
data_ptr
<
int64_t
>
(),
batch_offsets
.
data_ptr
<
int64_t
>
(),
distances_squared
.
data_ptr
<
double
>
(),
valid_mask
.
data_ptr
<
bool
>
(),
num_pairs
,
radius_squared
,
blocks
,
threads_per_block
);
}
else
{
float
radius_squared
=
radius
*
radius
;
launch_pbc_distance_kernel_optimized
<
float
>
(
pos1
.
data_ptr
<
float
>
(),
pos2
.
data_ptr
<
float
>
(),
// pbc_offsets_per_atom.data_ptr<float>(),
pbc_offsets
.
data_ptr
<
float
>
(),
num_atoms_per_image_sqr
.
data_ptr
<
int64_t
>
(),
batch_offsets
.
data_ptr
<
int64_t
>
(),
distances_squared
.
data_ptr
<
float
>
(),
valid_mask
.
data_ptr
<
bool
>
(),
num_pairs
,
radius_squared
,
blocks
,
threads_per_block
);
}
// Filter valid pairs
torch
::
Tensor
valid_indices
=
torch
::
nonzero
(
valid_mask
).
squeeze
(
-
1
);
if
(
valid_indices
.
numel
()
>
0
)
{
torch
::
Tensor
valid_base_indices
=
base_indices
.
index_select
(
0
,
valid_indices
);
torch
::
Tensor
valid_distances
=
distances_squared
.
index_select
(
0
,
valid_indices
);
torch
::
Tensor
valid_unit_cell
=
unit_cell_offset
.
unsqueeze
(
0
).
repeat
({
valid_indices
.
size
(
0
),
1
});
all_index1
.
push_back
(
valid_base_indices
);
all_unit_cell
.
push_back
(
valid_unit_cell
);
all_distances_sq
.
push_back
(
valid_distances
);
}
}
}
}
// Single synchronization after all kernel launches
cudaDeviceSynchronize
();
// Concatenate results
torch
::
Tensor
final_indices
,
final_unit_cell
,
final_distances
;
if
(
all_index1
.
size
()
>
0
)
{
final_indices
=
torch
::
cat
(
all_index1
);
final_unit_cell
=
torch
::
cat
(
all_unit_cell
);
final_distances
=
torch
::
cat
(
all_distances_sq
);
}
else
{
final_indices
=
torch
::
empty
({
0
},
torch
::
dtype
(
torch
::
kLong
).
device
(
device
));
final_unit_cell
=
torch
::
empty
({
0
,
3
},
torch
::
dtype
(
pos1
.
dtype
()).
device
(
device
));
final_distances
=
torch
::
empty
({
0
},
torch
::
dtype
(
pos1
.
dtype
()).
device
(
device
));
}
return
{
final_indices
,
final_unit_cell
,
final_distances
};
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"pbc_distance_cuda"
,
&
pbc_distance_cuda
,
"PBC distance computation with CUDA"
);
}
mace-bench/src/batchopt/pbc_graph.py
0 → 100644
View file @
fa84b16c
"""
Copyright (c) 2025 Ma Zhaojia
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
CUDA-accelerated PBC graph operations for atomic systems.
"""
import
torch
from
typing
import
Optional
,
List
from
.pbc_graph_legacy
import
get_max_neighbors_mask
from
.extensions.cuda_ops
import
get_pbc_graph_cuda_extension
def
radius_graph_pbc_cuda
(
data
,
radius
,
max_num_neighbors_threshold
,
enforce_max_neighbors_strictly
:
bool
=
False
,
pbc
=
None
,
dtype
=
torch
.
float64
,
):
"""
Memory-efficient CUDA-accelerated version of radius_graph_pbc.
This implementation follows the memory-efficient approach with triple loops
but accelerates the distance computation using CUDA kernels.
"""
if
pbc
is
None
:
pbc
=
[
True
,
True
,
True
]
device
=
data
.
pos
.
device
batch_size
=
len
(
data
.
natoms
)
# Handle PBC settings
if
hasattr
(
data
,
"pbc"
):
data
.
pbc
=
torch
.
atleast_2d
(
data
.
pbc
)
for
i
in
range
(
3
):
if
not
torch
.
any
(
data
.
pbc
[:,
i
]).
item
():
pbc
[
i
]
=
False
elif
torch
.
all
(
data
.
pbc
[:,
i
]).
item
():
pbc
[
i
]
=
True
else
:
raise
RuntimeError
(
"Different structures in the batch have different PBC configurations."
)
# position of the atoms
atom_pos
=
data
.
pos
# Before computing the pairwise distances between atoms, first create a list of atom indices to compare for the entire batch
num_atoms_per_image
=
data
.
natoms
num_atoms_per_image_sqr
=
(
num_atoms_per_image
**
2
).
long
()
# index offset between images
index_offset
=
torch
.
cumsum
(
num_atoms_per_image
,
dim
=
0
)
-
num_atoms_per_image
index_offset_expand
=
torch
.
repeat_interleave
(
index_offset
,
num_atoms_per_image_sqr
)
num_atoms_per_image_expand
=
torch
.
repeat_interleave
(
num_atoms_per_image
,
num_atoms_per_image_sqr
)
# Compute atom pair indices
num_atom_pairs
=
torch
.
sum
(
num_atoms_per_image_sqr
)
index_sqr_offset
=
(
torch
.
cumsum
(
num_atoms_per_image_sqr
,
dim
=
0
)
-
num_atoms_per_image_sqr
)
index_sqr_offset
=
torch
.
repeat_interleave
(
index_sqr_offset
,
num_atoms_per_image_sqr
)
atom_count_sqr
=
torch
.
arange
(
num_atom_pairs
,
device
=
device
)
-
index_sqr_offset
# Compute the indices for the pairs of atoms (using division and mod)
index1
=
(
torch
.
div
(
atom_count_sqr
,
num_atoms_per_image_expand
,
rounding_mode
=
"floor"
)
)
+
index_offset_expand
index2
=
(
atom_count_sqr
%
num_atoms_per_image_expand
)
+
index_offset_expand
# Get the positions for each atom
pos1
=
torch
.
index_select
(
atom_pos
,
0
,
index1
)
pos2
=
torch
.
index_select
(
atom_pos
,
0
,
index2
)
# Calculate required number of unit cells in each direction for PBC
cross_a2a3
=
torch
.
cross
(
data
.
cell
[:,
1
],
data
.
cell
[:,
2
],
dim
=-
1
)
cell_vol
=
torch
.
sum
(
data
.
cell
[:,
0
]
*
cross_a2a3
,
dim
=-
1
,
keepdim
=
True
)
if
pbc
[
0
]:
inv_min_dist_a1
=
torch
.
norm
(
cross_a2a3
/
cell_vol
,
p
=
2
,
dim
=-
1
)
rep_a1
=
torch
.
ceil
(
radius
*
inv_min_dist_a1
)
else
:
rep_a1
=
data
.
cell
.
new_zeros
(
1
)
if
pbc
[
1
]:
cross_a3a1
=
torch
.
cross
(
data
.
cell
[:,
2
],
data
.
cell
[:,
0
],
dim
=-
1
)
inv_min_dist_a2
=
torch
.
norm
(
cross_a3a1
/
cell_vol
,
p
=
2
,
dim
=-
1
)
rep_a2
=
torch
.
ceil
(
radius
*
inv_min_dist_a2
)
else
:
rep_a2
=
data
.
cell
.
new_zeros
(
1
)
if
pbc
[
2
]:
cross_a1a2
=
torch
.
cross
(
data
.
cell
[:,
0
],
data
.
cell
[:,
1
],
dim
=-
1
)
inv_min_dist_a3
=
torch
.
norm
(
cross_a1a2
/
cell_vol
,
p
=
2
,
dim
=-
1
)
rep_a3
=
torch
.
ceil
(
radius
*
inv_min_dist_a3
)
else
:
rep_a3
=
data
.
cell
.
new_zeros
(
1
)
# Take the max over all images for uniformity
max_rep
=
[
int
(
2
*
rep_a1
.
max
().
item
()),
int
(
2
*
rep_a2
.
max
().
item
()),
int
(
2
*
rep_a3
.
max
().
item
())]
# Pre-transpose data_cell for efficiency
data_cell
=
torch
.
transpose
(
data
.
cell
,
1
,
2
)
# Use CUDA kernel for the triple loop computation
# try:
pbc_graph_cuda
=
get_pbc_graph_cuda_extension
()
# Call the CUDA implementation
valid_pair_indices
,
unit_cell
,
atom_distance_sqr
=
pbc_graph_cuda
.
pbc_distance_cuda
(
pos1
,
pos2
,
data_cell
,
num_atoms_per_image_sqr
,
batch_size
,
max_rep
,
float
(
radius
),
device
)
# Map back to original index1 and index2
if
len
(
valid_pair_indices
)
>
0
:
index1
=
index1
.
index_select
(
0
,
valid_pair_indices
.
long
())
index2
=
index2
.
index_select
(
0
,
valid_pair_indices
.
long
())
else
:
index1
=
torch
.
empty
(
0
,
dtype
=
torch
.
long
,
device
=
device
)
index2
=
torch
.
empty
(
0
,
dtype
=
torch
.
long
,
device
=
device
)
unit_cell
=
torch
.
empty
(
0
,
3
,
dtype
=
dtype
,
device
=
device
)
atom_distance_sqr
=
torch
.
empty
(
0
,
dtype
=
dtype
,
device
=
device
)
# Sort index1 in ascending order and rearrange other arrays correspondingly
if
len
(
index1
)
>
0
:
sort_indices
=
torch
.
argsort
(
index1
)
index1
=
index1
[
sort_indices
]
index2
=
index2
[
sort_indices
]
unit_cell
=
unit_cell
[
sort_indices
]
atom_distance_sqr
=
atom_distance_sqr
[
sort_indices
]
mask_num_neighbors
,
num_neighbors_image
=
get_max_neighbors_mask
(
natoms
=
data
.
natoms
,
index
=
index1
,
atom_distance
=
atom_distance_sqr
,
max_num_neighbors_threshold
=
max_num_neighbors_threshold
,
enforce_max_strictly
=
enforce_max_neighbors_strictly
,
)
if
not
torch
.
all
(
mask_num_neighbors
):
# Mask out the atoms to ensure each atom has at most max_num_neighbors_threshold neighbors
index1
=
torch
.
masked_select
(
index1
,
mask_num_neighbors
)
index2
=
torch
.
masked_select
(
index2
,
mask_num_neighbors
)
unit_cell
=
torch
.
masked_select
(
unit_cell
.
view
(
-
1
,
3
),
mask_num_neighbors
.
view
(
-
1
,
1
).
expand
(
-
1
,
3
)
)
unit_cell
=
unit_cell
.
view
(
-
1
,
3
)
edge_index
=
torch
.
stack
((
index2
,
index1
))
return
edge_index
,
unit_cell
,
num_neighbors_image
\ No newline at end of file
mace-bench/src/batchopt/pbc_graph_legacy.py
0 → 100644
View file @
fa84b16c
"""
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from
__future__
import
annotations
from
typing
import
TYPE_CHECKING
,
Any
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch_geometric
from
matplotlib.backends.backend_agg
import
FigureCanvasAgg
as
FigureCanvas
from
matplotlib.figure
import
Figure
from
torch_geometric.data
import
Data
from
torch_geometric.utils
import
remove_self_loops
from
torch_scatter
import
scatter
,
segment_coo
,
segment_csr
if
TYPE_CHECKING
:
from
collections.abc
import
Mapping
from
torch.nn.modules.module
import
_IncompatibleKeys
DEFAULT_ENV_VARS
=
{
# Expandable segments is a new cuda feature that helps with memory fragmentation during frequent allocations (ie: in the case of variable batch sizes).
# see https://pytorch.org/docs/stable/notes/cuda.html.
"PYTORCH_CUDA_ALLOC_CONF"
:
"expandable_segments:True"
,
}
def
get_pbc_distances
(
pos
,
edge_index
,
cell
,
cell_offsets
,
neighbors
,
return_offsets
:
bool
=
False
,
return_distance_vec
:
bool
=
False
,
):
row
,
col
=
edge_index
distance_vectors
=
pos
[
row
]
-
pos
[
col
]
# correct for pbc
neighbors
=
neighbors
.
to
(
cell
.
device
)
cell
=
torch
.
repeat_interleave
(
cell
,
neighbors
,
dim
=
0
)
offsets
=
cell_offsets
.
float
().
view
(
-
1
,
1
,
3
).
bmm
(
cell
.
float
()).
view
(
-
1
,
3
)
distance_vectors
+=
offsets
# compute distances
distances
=
distance_vectors
.
norm
(
dim
=-
1
)
# redundancy: remove zero distances
nonzero_idx
=
torch
.
arange
(
len
(
distances
),
device
=
distances
.
device
)[
distances
!=
0
]
edge_index
=
edge_index
[:,
nonzero_idx
]
distances
=
distances
[
nonzero_idx
]
out
=
{
"edge_index"
:
edge_index
,
"distances"
:
distances
,
}
if
return_distance_vec
:
out
[
"distance_vec"
]
=
distance_vectors
[
nonzero_idx
]
if
return_offsets
:
out
[
"offsets"
]
=
offsets
[
nonzero_idx
]
return
out
def
radius_graph_pbc_mem_effi
(
data
,
radius
,
max_num_neighbors_threshold
,
enforce_max_neighbors_strictly
:
bool
=
False
,
pbc
=
None
,
dtype
=
torch
.
float64
,
):
if
pbc
is
None
:
pbc
=
[
True
,
True
,
True
]
device
=
data
.
pos
.
device
batch_size
=
len
(
data
.
natoms
)
if
hasattr
(
data
,
"pbc"
):
data
.
pbc
=
torch
.
atleast_2d
(
data
.
pbc
)
for
i
in
range
(
3
):
if
not
torch
.
any
(
data
.
pbc
[:,
i
]).
item
():
pbc
[
i
]
=
False
elif
torch
.
all
(
data
.
pbc
[:,
i
]).
item
():
pbc
[
i
]
=
True
else
:
raise
RuntimeError
(
"Different structures in the batch have different PBC configurations. This is not currently supported."
)
# position of the atoms
atom_pos
=
data
.
pos
# Before computing the pairwise distances between atoms, first create a list of atom indices to compare for the entire batch
num_atoms_per_image
=
data
.
natoms
num_atoms_per_image_sqr
=
(
num_atoms_per_image
**
2
).
long
()
# index offset between images
index_offset
=
torch
.
cumsum
(
num_atoms_per_image
,
dim
=
0
)
-
num_atoms_per_image
index_offset_expand
=
torch
.
repeat_interleave
(
index_offset
,
num_atoms_per_image_sqr
)
num_atoms_per_image_expand
=
torch
.
repeat_interleave
(
num_atoms_per_image
,
num_atoms_per_image_sqr
)
# Compute a tensor containing sequences of numbers that range from 0 to num_atoms_per_image_sqr for each image
# that is used to compute indices for the pairs of atoms. This is a very convoluted way to implement
# the following (but 10x faster since it removes the for loop)
# for batch_idx in range(batch_size):
# batch_count = torch.cat([batch_count, torch.arange(num_atoms_per_image_sqr[batch_idx], device=device)], dim=0)
num_atom_pairs
=
torch
.
sum
(
num_atoms_per_image_sqr
)
index_sqr_offset
=
(
torch
.
cumsum
(
num_atoms_per_image_sqr
,
dim
=
0
)
-
num_atoms_per_image_sqr
)
index_sqr_offset
=
torch
.
repeat_interleave
(
index_sqr_offset
,
num_atoms_per_image_sqr
)
atom_count_sqr
=
torch
.
arange
(
num_atom_pairs
,
device
=
device
)
-
index_sqr_offset
# Compute the indices for the pairs of atoms (using division and mod)
# If the systems get too large this apporach could run into numerical precision issues
index1
=
(
torch
.
div
(
atom_count_sqr
,
num_atoms_per_image_expand
,
rounding_mode
=
"floor"
)
)
+
index_offset_expand
index2
=
(
atom_count_sqr
%
num_atoms_per_image_expand
)
+
index_offset_expand
# Get the positions for each atom
pos1
=
torch
.
index_select
(
atom_pos
,
0
,
index1
)
pos2
=
torch
.
index_select
(
atom_pos
,
0
,
index2
)
# Calculate required number of unit cells in each direction.
# Smallest distance between planes separated by a1 is
# 1 / ||(a2 x a3) / V||_2, since a2 x a3 is the area of the plane.
# Note that the unit cell volume V = a1 * (a2 x a3) and that
# (a2 x a3) / V is also the reciprocal primitive vector
# (crystallographer's definition).
cross_a2a3
=
torch
.
cross
(
data
.
cell
[:,
1
],
data
.
cell
[:,
2
],
dim
=-
1
)
cell_vol
=
torch
.
sum
(
data
.
cell
[:,
0
]
*
cross_a2a3
,
dim
=-
1
,
keepdim
=
True
)
if
pbc
[
0
]:
inv_min_dist_a1
=
torch
.
norm
(
cross_a2a3
/
cell_vol
,
p
=
2
,
dim
=-
1
)
rep_a1
=
torch
.
ceil
(
radius
*
inv_min_dist_a1
)
else
:
rep_a1
=
data
.
cell
.
new_zeros
(
1
)
if
pbc
[
1
]:
cross_a3a1
=
torch
.
cross
(
data
.
cell
[:,
2
],
data
.
cell
[:,
0
],
dim
=-
1
)
inv_min_dist_a2
=
torch
.
norm
(
cross_a3a1
/
cell_vol
,
p
=
2
,
dim
=-
1
)
rep_a2
=
torch
.
ceil
(
radius
*
inv_min_dist_a2
)
else
:
rep_a2
=
data
.
cell
.
new_zeros
(
1
)
if
pbc
[
2
]:
cross_a1a2
=
torch
.
cross
(
data
.
cell
[:,
0
],
data
.
cell
[:,
1
],
dim
=-
1
)
inv_min_dist_a3
=
torch
.
norm
(
cross_a1a2
/
cell_vol
,
p
=
2
,
dim
=-
1
)
rep_a3
=
torch
.
ceil
(
radius
*
inv_min_dist_a3
)
else
:
rep_a3
=
data
.
cell
.
new_zeros
(
1
)
# Take the max over all images for uniformity. This is essentially padding.
# Note that this can significantly increase the number of computed distances
# if the required repetitions are very different between images
# (which they usually are). Changing this to sparse (scatter) operations
# might be worth the effort if this function becomes a bottleneck.
max_rep
=
[
int
(
2
*
rep_a1
.
max
().
item
()),
int
(
2
*
rep_a2
.
max
().
item
()),
int
(
2
*
rep_a3
.
max
().
item
())]
# Memory-efficient implementation: iterate over unit cell offsets instead of expanding all at once
# This reduces memory usage by avoiding the creation of large tensor products
all_index1
=
[]
all_index2
=
[]
all_unit_cell
=
[]
all_atom_distance_sqr
=
[]
# Pre-transpose data_cell for efficiency
data_cell
=
torch
.
transpose
(
data
.
cell
,
1
,
2
)
# Iterate over each unit cell offset combination
for
i
in
range
(
-
max_rep
[
0
],
max_rep
[
0
]
+
1
):
for
j
in
range
(
-
max_rep
[
1
],
max_rep
[
1
]
+
1
):
for
k
in
range
(
-
max_rep
[
2
],
max_rep
[
2
]
+
1
):
# Create unit cell offset
unit_cell_offset
=
torch
.
tensor
([
i
,
j
,
k
],
device
=
device
,
dtype
=
dtype
)
# Compute the x, y, z positional offsets for this specific cell in each image
# unit_cell_offset_batch = unit_cell_offset.view(3, 1).expand(3, batch_size)
unit_cell_offset_batch
=
unit_cell_offset
.
view
(
1
,
3
,
1
).
expand
(
batch_size
,
-
1
,
-
1
)
pbc_offsets
=
torch
.
bmm
(
data_cell
,
unit_cell_offset_batch
).
squeeze
(
-
1
)
pbc_offsets_per_atom
=
torch
.
repeat_interleave
(
pbc_offsets
,
num_atoms_per_image_sqr
,
dim
=
0
)
# Apply PBC offsets to the second atom positions
pos2_offset
=
pos2
+
pbc_offsets_per_atom
# Compute the squared distance between atoms
atom_distance_sqr
=
torch
.
sum
((
pos1
-
pos2_offset
)
**
2
,
dim
=
1
)
# Remove pairs that are too far apart
mask_within_radius
=
torch
.
le
(
atom_distance_sqr
,
radius
*
radius
)
# Remove pairs with the same atoms (distance = 0.0)
mask_not_same
=
torch
.
gt
(
atom_distance_sqr
,
0.0001
)
mask
=
torch
.
logical_and
(
mask_within_radius
,
mask_not_same
)
# Only keep valid pairs for this unit cell offset
if
torch
.
any
(
mask
):
valid_index1
=
torch
.
masked_select
(
index1
,
mask
)
valid_index2
=
torch
.
masked_select
(
index2
,
mask
)
valid_distances
=
torch
.
masked_select
(
atom_distance_sqr
,
mask
)
valid_unit_cell
=
unit_cell_offset
.
unsqueeze
(
0
).
repeat
(
valid_index1
.
shape
[
0
],
1
)
all_index1
.
append
(
valid_index1
)
all_index2
.
append
(
valid_index2
)
all_unit_cell
.
append
(
valid_unit_cell
)
all_atom_distance_sqr
.
append
(
valid_distances
)
# Concatenate all results
if
len
(
all_index1
)
>
0
:
index1
=
torch
.
cat
(
all_index1
)
index2
=
torch
.
cat
(
all_index2
)
unit_cell
=
torch
.
cat
(
all_unit_cell
)
atom_distance_sqr
=
torch
.
cat
(
all_atom_distance_sqr
)
# Sort index1 in ascending order and rearrange other arrays correspondingly
sort_indices
=
torch
.
argsort
(
index1
)
index1
=
index1
[
sort_indices
]
index2
=
index2
[
sort_indices
]
unit_cell
=
unit_cell
[
sort_indices
]
atom_distance_sqr
=
atom_distance_sqr
[
sort_indices
]
else
:
# No valid pairs found
index1
=
torch
.
empty
(
0
,
dtype
=
torch
.
long
,
device
=
device
)
index2
=
torch
.
empty
(
0
,
dtype
=
torch
.
long
,
device
=
device
)
unit_cell
=
torch
.
empty
(
0
,
3
,
dtype
=
dtype
,
device
=
device
)
atom_distance_sqr
=
torch
.
empty
(
0
,
dtype
=
dtype
,
device
=
device
)
mask_num_neighbors
,
num_neighbors_image
=
get_max_neighbors_mask
(
natoms
=
data
.
natoms
,
index
=
index1
,
atom_distance
=
atom_distance_sqr
,
max_num_neighbors_threshold
=
max_num_neighbors_threshold
,
enforce_max_strictly
=
enforce_max_neighbors_strictly
,
)
if
not
torch
.
all
(
mask_num_neighbors
):
# Mask out the atoms to ensure each atom has at most max_num_neighbors_threshold neighbors
index1
=
torch
.
masked_select
(
index1
,
mask_num_neighbors
)
index2
=
torch
.
masked_select
(
index2
,
mask_num_neighbors
)
unit_cell
=
torch
.
masked_select
(
unit_cell
.
view
(
-
1
,
3
),
mask_num_neighbors
.
view
(
-
1
,
1
).
expand
(
-
1
,
3
)
)
unit_cell
=
unit_cell
.
view
(
-
1
,
3
)
edge_index
=
torch
.
stack
((
index2
,
index1
))
return
edge_index
,
unit_cell
,
num_neighbors_image
def
radius_graph_pbc
(
data
,
radius
,
max_num_neighbors_threshold
,
enforce_max_neighbors_strictly
:
bool
=
False
,
pbc
=
None
,
dtype
=
torch
.
float64
,
):
if
pbc
is
None
:
pbc
=
[
True
,
True
,
True
]
device
=
data
.
pos
.
device
batch_size
=
len
(
data
.
natoms
)
if
hasattr
(
data
,
"pbc"
):
data
.
pbc
=
torch
.
atleast_2d
(
data
.
pbc
)
for
i
in
range
(
3
):
if
not
torch
.
any
(
data
.
pbc
[:,
i
]).
item
():
pbc
[
i
]
=
False
elif
torch
.
all
(
data
.
pbc
[:,
i
]).
item
():
pbc
[
i
]
=
True
else
:
raise
RuntimeError
(
"Different structures in the batch have different PBC configurations. This is not currently supported."
)
# position of the atoms
atom_pos
=
data
.
pos
# Before computing the pairwise distances between atoms, first create a list of atom indices to compare for the entire batch
num_atoms_per_image
=
data
.
natoms
num_atoms_per_image_sqr
=
(
num_atoms_per_image
**
2
).
long
()
# index offset between images
index_offset
=
torch
.
cumsum
(
num_atoms_per_image
,
dim
=
0
)
-
num_atoms_per_image
index_offset_expand
=
torch
.
repeat_interleave
(
index_offset
,
num_atoms_per_image_sqr
)
num_atoms_per_image_expand
=
torch
.
repeat_interleave
(
num_atoms_per_image
,
num_atoms_per_image_sqr
)
# Compute a tensor containing sequences of numbers that range from 0 to num_atoms_per_image_sqr for each image
# that is used to compute indices for the pairs of atoms. This is a very convoluted way to implement
# the following (but 10x faster since it removes the for loop)
# for batch_idx in range(batch_size):
# batch_count = torch.cat([batch_count, torch.arange(num_atoms_per_image_sqr[batch_idx], device=device)], dim=0)
num_atom_pairs
=
torch
.
sum
(
num_atoms_per_image_sqr
)
index_sqr_offset
=
(
torch
.
cumsum
(
num_atoms_per_image_sqr
,
dim
=
0
)
-
num_atoms_per_image_sqr
)
index_sqr_offset
=
torch
.
repeat_interleave
(
index_sqr_offset
,
num_atoms_per_image_sqr
)
atom_count_sqr
=
torch
.
arange
(
num_atom_pairs
,
device
=
device
)
-
index_sqr_offset
# Compute the indices for the pairs of atoms (using division and mod)
# If the systems get too large this apporach could run into numerical precision issues
index1
=
(
torch
.
div
(
atom_count_sqr
,
num_atoms_per_image_expand
,
rounding_mode
=
"floor"
)
)
+
index_offset_expand
index2
=
(
atom_count_sqr
%
num_atoms_per_image_expand
)
+
index_offset_expand
# Get the positions for each atom
pos1
=
torch
.
index_select
(
atom_pos
,
0
,
index1
)
pos2
=
torch
.
index_select
(
atom_pos
,
0
,
index2
)
# Calculate required number of unit cells in each direction.
# Smallest distance between planes separated by a1 is
# 1 / ||(a2 x a3) / V||_2, since a2 x a3 is the area of the plane.
# Note that the unit cell volume V = a1 * (a2 x a3) and that
# (a2 x a3) / V is also the reciprocal primitive vector
# (crystallographer's definition).
cross_a2a3
=
torch
.
cross
(
data
.
cell
[:,
1
],
data
.
cell
[:,
2
],
dim
=-
1
)
cell_vol
=
torch
.
sum
(
data
.
cell
[:,
0
]
*
cross_a2a3
,
dim
=-
1
,
keepdim
=
True
)
if
pbc
[
0
]:
inv_min_dist_a1
=
torch
.
norm
(
cross_a2a3
/
cell_vol
,
p
=
2
,
dim
=-
1
)
rep_a1
=
torch
.
ceil
(
radius
*
inv_min_dist_a1
)
else
:
rep_a1
=
data
.
cell
.
new_zeros
(
1
)
if
pbc
[
1
]:
cross_a3a1
=
torch
.
cross
(
data
.
cell
[:,
2
],
data
.
cell
[:,
0
],
dim
=-
1
)
inv_min_dist_a2
=
torch
.
norm
(
cross_a3a1
/
cell_vol
,
p
=
2
,
dim
=-
1
)
rep_a2
=
torch
.
ceil
(
radius
*
inv_min_dist_a2
)
else
:
rep_a2
=
data
.
cell
.
new_zeros
(
1
)
if
pbc
[
2
]:
cross_a1a2
=
torch
.
cross
(
data
.
cell
[:,
0
],
data
.
cell
[:,
1
],
dim
=-
1
)
inv_min_dist_a3
=
torch
.
norm
(
cross_a1a2
/
cell_vol
,
p
=
2
,
dim
=-
1
)
rep_a3
=
torch
.
ceil
(
radius
*
inv_min_dist_a3
)
else
:
rep_a3
=
data
.
cell
.
new_zeros
(
1
)
# Take the max over all images for uniformity. This is essentially padding.
# Note that this can significantly increase the number of computed distances
# if the required repetitions are very different between images
# (which they usually are). Changing this to sparse (scatter) operations
# might be worth the effort if this function becomes a bottleneck.
max_rep
=
[
2
*
rep_a1
.
max
(),
2
*
rep_a2
.
max
(),
2
*
rep_a3
.
max
()]
# max_rep = [rep_a1.max(), rep_a2.max(), rep_a3.max()]
# max_rep = [torch.tensor(1, device=device)] * 3
# logging.info(f"&&& max_rep: {max_rep}")
# Tensor of unit cells
cells_per_dim
=
[
torch
.
arange
(
-
rep
.
item
(),
rep
.
item
()
+
1
,
device
=
device
,
dtype
=
dtype
)
for
rep
in
max_rep
]
unit_cell
=
torch
.
cartesian_prod
(
*
cells_per_dim
)
num_cells
=
len
(
unit_cell
)
unit_cell_per_atom
=
unit_cell
.
view
(
1
,
num_cells
,
3
).
repeat
(
len
(
index2
),
1
,
1
)
unit_cell
=
torch
.
transpose
(
unit_cell
,
0
,
1
)
unit_cell_batch
=
unit_cell
.
view
(
1
,
3
,
num_cells
).
expand
(
batch_size
,
-
1
,
-
1
)
# Compute the x, y, z positional offsets for each cell in each image
# data_cell = torch.transpose(data.cell, 1, 2)
data_cell
=
torch
.
transpose
(
data
.
cell
,
1
,
2
)
pbc_offsets
=
torch
.
bmm
(
data_cell
,
unit_cell_batch
)
pbc_offsets_per_atom
=
torch
.
repeat_interleave
(
pbc_offsets
,
num_atoms_per_image_sqr
,
dim
=
0
)
# Expand the positions and indices for the 9 cells
pos1
=
pos1
.
view
(
-
1
,
3
,
1
).
expand
(
-
1
,
-
1
,
num_cells
)
pos2
=
pos2
.
view
(
-
1
,
3
,
1
).
expand
(
-
1
,
-
1
,
num_cells
)
index1
=
index1
.
view
(
-
1
,
1
).
repeat
(
1
,
num_cells
).
view
(
-
1
)
index2
=
index2
.
view
(
-
1
,
1
).
repeat
(
1
,
num_cells
).
view
(
-
1
)
# Add the PBC offsets for the second atom
pos2
=
pos2
+
pbc_offsets_per_atom
# Compute the squared distance between atoms
atom_distance_sqr
=
torch
.
sum
((
pos1
-
pos2
)
**
2
,
dim
=
1
)
atom_distance_sqr
=
atom_distance_sqr
.
view
(
-
1
)
# Remove pairs that are too far apart
mask_within_radius
=
torch
.
le
(
atom_distance_sqr
,
radius
*
radius
)
# Remove pairs with the same atoms (distance = 0.0)
mask_not_same
=
torch
.
gt
(
atom_distance_sqr
,
0.0001
)
mask
=
torch
.
logical_and
(
mask_within_radius
,
mask_not_same
)
index1
=
torch
.
masked_select
(
index1
,
mask
)
index2
=
torch
.
masked_select
(
index2
,
mask
)
unit_cell
=
torch
.
masked_select
(
unit_cell_per_atom
.
view
(
-
1
,
3
),
mask
.
view
(
-
1
,
1
).
expand
(
-
1
,
3
)
)
unit_cell
=
unit_cell
.
view
(
-
1
,
3
)
atom_distance_sqr
=
torch
.
masked_select
(
atom_distance_sqr
,
mask
)
mask_num_neighbors
,
num_neighbors_image
=
get_max_neighbors_mask
(
natoms
=
data
.
natoms
,
index
=
index1
,
atom_distance
=
atom_distance_sqr
,
max_num_neighbors_threshold
=
max_num_neighbors_threshold
,
enforce_max_strictly
=
enforce_max_neighbors_strictly
,
)
if
not
torch
.
all
(
mask_num_neighbors
):
# Mask out the atoms to ensure each atom has at most max_num_neighbors_threshold neighbors
index1
=
torch
.
masked_select
(
index1
,
mask_num_neighbors
)
index2
=
torch
.
masked_select
(
index2
,
mask_num_neighbors
)
unit_cell
=
torch
.
masked_select
(
unit_cell
.
view
(
-
1
,
3
),
mask_num_neighbors
.
view
(
-
1
,
1
).
expand
(
-
1
,
3
)
)
unit_cell
=
unit_cell
.
view
(
-
1
,
3
)
edge_index
=
torch
.
stack
((
index2
,
index1
))
return
edge_index
,
unit_cell
,
num_neighbors_image
@
torch
.
compiler
.
disable
def
get_max_neighbors_mask
(
natoms
,
index
,
atom_distance
,
max_num_neighbors_threshold
,
degeneracy_tolerance
:
float
=
0.01
,
enforce_max_strictly
:
bool
=
False
,
):
"""
Give a mask that filters out edges so that each atom has at most
`max_num_neighbors_threshold` neighbors.
Assumes that `index` is sorted.
Enforcing the max strictly can force the arbitrary choice between
degenerate edges. This can lead to undesired behaviors; for
example, bulk formation energies which are not invariant to
unit cell choice.
A degeneracy tolerance can help prevent sudden changes in edge
existence from small changes in atom position, for example,
rounding errors, slab relaxation, temperature, etc.
"""
device
=
natoms
.
device
num_atoms
=
natoms
.
sum
()
# Get number of neighbors
# segment_coo assumes sorted index
ones
=
index
.
new_ones
(
1
).
expand_as
(
index
)
num_neighbors
=
segment_coo
(
ones
,
index
,
dim_size
=
num_atoms
)
max_num_neighbors
=
num_neighbors
.
max
()
num_neighbors_thresholded
=
num_neighbors
.
clamp
(
max
=
max_num_neighbors_threshold
)
# Get number of (thresholded) neighbors per image
image_indptr
=
torch
.
zeros
(
natoms
.
shape
[
0
]
+
1
,
device
=
device
,
dtype
=
torch
.
long
)
image_indptr
[
1
:]
=
torch
.
cumsum
(
natoms
,
dim
=
0
)
num_neighbors_image
=
segment_csr
(
num_neighbors_thresholded
,
image_indptr
)
# If max_num_neighbors is below the threshold, return early
if
(
max_num_neighbors
<=
max_num_neighbors_threshold
or
max_num_neighbors_threshold
<=
0
):
mask_num_neighbors
=
torch
.
tensor
([
True
],
dtype
=
bool
,
device
=
device
).
expand_as
(
index
)
return
mask_num_neighbors
,
num_neighbors_image
# Create a tensor of size [num_atoms, max_num_neighbors] to sort the distances of the neighbors.
# Fill with infinity so we can easily remove unused distances later.
distance_sort
=
torch
.
full
([
num_atoms
*
max_num_neighbors
],
np
.
inf
,
device
=
device
)
# Create an index map to map distances from atom_distance to distance_sort
# index_sort_map assumes index to be sorted
index_neighbor_offset
=
torch
.
cumsum
(
num_neighbors
,
dim
=
0
)
-
num_neighbors
index_neighbor_offset_expand
=
torch
.
repeat_interleave
(
index_neighbor_offset
,
num_neighbors
)
index_sort_map
=
(
index
*
max_num_neighbors
+
torch
.
arange
(
len
(
index
),
device
=
device
)
-
index_neighbor_offset_expand
)
distance_sort
.
index_copy_
(
0
,
index_sort_map
,
atom_distance
)
distance_sort
=
distance_sort
.
view
(
num_atoms
,
max_num_neighbors
)
# Sort neighboring atoms based on distance
distance_sort
,
index_sort
=
torch
.
sort
(
distance_sort
,
dim
=
1
)
# Select the max_num_neighbors_threshold neighbors that are closest
if
enforce_max_strictly
:
distance_sort
=
distance_sort
[:,
:
max_num_neighbors_threshold
]
index_sort
=
index_sort
[:,
:
max_num_neighbors_threshold
]
max_num_included
=
max_num_neighbors_threshold
else
:
effective_cutoff
=
(
distance_sort
[:,
max_num_neighbors_threshold
]
+
degeneracy_tolerance
)
is_included
=
torch
.
le
(
distance_sort
.
T
,
effective_cutoff
)
# Set all undesired edges to infinite length to be removed later
distance_sort
[
~
is_included
.
T
]
=
np
.
inf
# Subselect tensors for efficiency
num_included_per_atom
=
torch
.
sum
(
is_included
,
dim
=
0
)
max_num_included
=
torch
.
max
(
num_included_per_atom
)
distance_sort
=
distance_sort
[:,
:
max_num_included
]
index_sort
=
index_sort
[:,
:
max_num_included
]
# Recompute the number of neighbors
num_neighbors_thresholded
=
num_neighbors
.
clamp
(
max
=
num_included_per_atom
)
num_neighbors_image
=
segment_csr
(
num_neighbors_thresholded
,
image_indptr
)
# Offset index_sort so that it indexes into index
index_sort
=
index_sort
+
index_neighbor_offset
.
view
(
-
1
,
1
).
expand
(
-
1
,
max_num_included
)
# Remove "unused pairs" with infinite distances
mask_finite
=
torch
.
isfinite
(
distance_sort
)
index_sort
=
torch
.
masked_select
(
index_sort
,
mask_finite
)
# At this point index_sort contains the index into index of the
# closest max_num_neighbors_threshold neighbors per atom
# Create a mask to remove all pairs not in index_sort
mask_num_neighbors
=
torch
.
zeros
(
len
(
index
),
device
=
device
,
dtype
=
bool
)
mask_num_neighbors
.
index_fill_
(
0
,
index_sort
,
True
)
return
mask_num_neighbors
,
num_neighbors_image
def
get_pruned_edge_idx
(
edge_index
,
num_atoms
:
int
,
max_neigh
:
float
=
1e9
)
->
torch
.
Tensor
:
assert
num_atoms
is
not
None
# TODO: Shouldn't be necessary
# removes neighbors > max_neigh
# assumes neighbors are sorted in increasing distance
_nonmax_idx_list
=
[]
for
i
in
range
(
num_atoms
):
idx_i
=
torch
.
arange
(
len
(
edge_index
[
1
]))[(
edge_index
[
1
]
==
i
)][:
max_neigh
]
_nonmax_idx_list
.
append
(
idx_i
)
return
torch
.
cat
(
_nonmax_idx_list
)
mace-bench/src/batchopt/relaxation/__init__.py
0 → 100644
View file @
fa84b16c
"""
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from
__future__
import
annotations
from
.optimizable
import
OptimizableBatch
,
OptimizableUnitCellBatch
__all__
=
[
"ml_relax"
,
"OptimizableBatch"
,
"OptimizableUnitCellBatch"
]
mace-bench/src/batchopt/relaxation/__pycache__/__init__.cpython-310.pyc
0 → 100644
View file @
fa84b16c
File added
mace-bench/src/batchopt/relaxation/__pycache__/ase_utils.cpython-310.pyc
0 → 100644
View file @
fa84b16c
File added
mace-bench/src/batchopt/relaxation/__pycache__/optimizable.cpython-310.pyc
0 → 100644
View file @
fa84b16c
File added
mace-bench/src/batchopt/relaxation/ase_utils.py
0 → 100644
View file @
fa84b16c
"""
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
Utilities to interface OCP models/trainers with the Atomic Simulation
Environment (ASE)
"""
from
__future__
import
annotations
from
types
import
MappingProxyType
from
typing
import
TYPE_CHECKING
import
torch
from
ase
import
Atoms
from
ase.calculators.singlepoint
import
SinglePointCalculator
from
ase.constraints
import
FixAtoms
if
TYPE_CHECKING
:
from
torch_geometric.data
import
Batch
# system level model predictions have different shapes than expected by ASE
ASE_PROP_RESHAPE
=
MappingProxyType
(
{
"stress"
:
(
-
1
,
3
,
3
),
"dielectric_tensor"
:
(
-
1
,
3
,
3
)}
)
def
batch_to_atoms
(
batch
:
Batch
,
results
:
dict
[
str
,
torch
.
Tensor
]
|
None
=
None
,
wrap_pos
:
bool
=
True
,
eps
:
float
=
1e-7
,
)
->
list
[
Atoms
]:
"""Convert a data batch to ase Atoms
Args:
batch: data batch
results: dictionary with predicted result tensors that will be added to a SinglePointCalculator. If no results
are given no calculator will be added to the atoms objects.
wrap_pos: wrap positions back into the cell.
eps: Small number to prevent slightly negative coordinates from being wrapped.
Returns:
list of Atoms
"""
n_systems
=
batch
.
natoms
.
shape
[
0
]
natoms
=
batch
.
natoms
.
tolist
()
numbers
=
torch
.
split
(
batch
.
atomic_numbers
,
natoms
)
fixed
=
torch
.
split
(
batch
.
fixed
.
to
(
torch
.
bool
),
natoms
)
if
results
is
not
None
:
results
=
{
key
:
val
.
view
(
ASE_PROP_RESHAPE
.
get
(
key
,
-
1
)).
tolist
()
if
len
(
val
)
==
len
(
batch
)
else
[
v
.
cpu
().
detach
().
numpy
()
for
v
in
torch
.
split
(
val
,
natoms
)]
for
key
,
val
in
results
.
items
()
}
positions
=
torch
.
split
(
batch
.
pos
,
natoms
)
tags
=
torch
.
split
(
batch
.
tags
,
natoms
)
cells
=
batch
.
cell
atoms_objects
=
[]
for
idx
in
range
(
n_systems
):
pos
=
positions
[
idx
].
cpu
().
detach
().
numpy
()
cell
=
cells
[
idx
].
cpu
().
detach
().
numpy
()
# TODO take pbc from data
# TODO: &&& ^^^ change this back !!!
# if wrap_pos:
# pos = wrap_positions(pos, cell, pbc=[True, True, True], eps=eps)
atoms
=
Atoms
(
numbers
=
numbers
[
idx
].
tolist
(),
cell
=
cell
,
positions
=
pos
,
tags
=
tags
[
idx
].
tolist
(),
constraint
=
FixAtoms
(
mask
=
fixed
[
idx
].
tolist
()),
pbc
=
[
True
,
True
,
True
],
)
if
results
is
not
None
:
calc
=
SinglePointCalculator
(
atoms
=
atoms
,
**
{
key
:
val
[
idx
]
for
key
,
val
in
results
.
items
()}
)
atoms
.
set_calculator
(
calc
)
atoms_objects
.
append
(
atoms
)
return
atoms_objects
mace-bench/src/batchopt/relaxation/optimizable.py
0 → 100644
View file @
fa84b16c
"""
Copyright (c) Meta, Inc. and its affiliates.
Copyright (c) 2025 Ma Zhaojia
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
Modified from original Meta implementation.
"""
from
__future__
import
annotations
from
functools
import
cached_property
from
types
import
SimpleNamespace
from
typing
import
TYPE_CHECKING
,
ClassVar
,
Any
,
Generator
import
numpy
as
np
import
torch
import
logging
from
ase.calculators.calculator
import
PropertyNotImplementedError
from
ase.stress
import
voigt_6_to_full_3x3_stress
from
torch_scatter
import
scatter
from
batchopt.relaxation.ase_utils
import
batch_to_atoms
# Define dummy classes for when imports fail
class
_DummyCalculator
:
pass
try
:
from
mace.calculators
import
MACECalculator
except
ImportError
:
logging
.
warning
(
"Unable to import MACECalculator."
)
MACECalculator
=
_DummyCalculator
try
:
from
chgnet.model.dynamics
import
CHGNetCalculator
except
ImportError
:
logging
.
warning
(
"Unable to import CHGNetCalculator."
)
CHGNetCalculator
=
_DummyCalculator
try
:
from
sevenn.calculator
import
(
SevenNetCalculator
,
SevenNetD3Calculator
,
D3Calculator
,
)
except
ImportError
:
logging
.
warning
(
"Unable to import SevenNetCalculator."
)
SevenNetCalculator
=
_DummyCalculator
SevenNetD3Calculator
=
_DummyCalculator
D3Calculator
=
_DummyCalculator
try
:
from
fairchem.core
import
pretrained_mlip
,
FAIRChemCalculator
except
ImportError
:
logging
.
warning
(
"Unable to import FAIRChemCalculator."
)
FAIRChemCalculator
=
_DummyCalculator
# this can be removed after pinning ASE dependency >= 3.23
try
:
from
ase.optimize.optimize
import
Optimizable
except
ImportError
:
class
Optimizable
:
pass
if
TYPE_CHECKING
:
from
collections.abc
import
Sequence
from
ase
import
Atoms
from
numpy.typing
import
NDArray
from
torch_geometric.data
import
Batch
ALL_CHANGES
:
set
[
str
]
=
{
"pos"
,
"atomic_numbers"
,
"cell"
,
"pbc"
,
}
# @torch.compile
def
compare_batches
(
batch1
:
Batch
|
None
,
batch2
:
Batch
,
tol
:
float
=
1e-6
,
excluded_properties
:
set
[
str
]
|
None
=
None
,
)
->
list
[
str
]:
"""Compare properties between two batches
Args:
batch1: atoms batch
batch2: atoms batch
tol: tolerance used to compare equility of floating point properties
excluded_properties: list of properties to exclude from comparison
Returns:
list of system changes, property names that are differente between batch1 and batch2
"""
system_changes
=
[]
if
batch1
is
None
:
system_changes
=
ALL_CHANGES
else
:
properties_to_check
=
set
(
ALL_CHANGES
)
if
excluded_properties
:
properties_to_check
-=
set
(
excluded_properties
)
# Check properties that aren't
for
prop
in
ALL_CHANGES
:
if
prop
in
properties_to_check
:
properties_to_check
.
remove
(
prop
)
if
not
torch
.
allclose
(
getattr
(
batch1
,
prop
),
getattr
(
batch2
,
prop
),
atol
=
tol
):
system_changes
.
append
(
prop
)
return
system_changes
class
OptimizableBatch
(
Optimizable
):
"""A Batch version of ase Optimizable Atoms
This class can be used with ML relaxations in fairchem.core.relaxations.ml_relaxation
or in ase relaxations classes, i.e. ase.optimize.lbfgs
"""
ignored_changes
:
ClassVar
[
set
[
str
]]
=
set
()
def
__init__
(
self
,
batch
:
Batch
,
trainer
:
Any
,
# Any calculator type (MACECalculator | CHGNetCalculator | SevenNetCalculator | FAIRChemCalculator)
transform
:
torch
.
nn
.
Module
|
None
=
None
,
mask_converged
:
bool
=
True
,
numpy
:
bool
=
False
,
masked_eps
:
float
=
1e-8
,
compute_stress
:
bool
=
False
,
use_fast_predict
:
bool
=
True
,
dtype
:
torch
.
dtype
=
torch
.
float64
,
):
"""Initialize Optimizable Batch
Args:
batch: A batch of atoms graph data
model: An instance of a BaseTrainer derived class
transform: graph transform
mask_converged: if true will mask systems in batch that are already converged
numpy: whether to cast results to numpy arrays
masked_eps: masking systems that are converged when using ASE optimizers results in divisions by zero
from zero differences in masked positions at future steps, we add a small number to prevent this.
compute_stress: whether to compute stress during prediction
use_fast_predict: use fast prediction method when available
dtype: data type for tensor operations (torch.float32 or torch.float64)
"""
self
.
batch
=
batch
.
to
(
trainer
.
device
)
self
.
trainer
=
trainer
self
.
transform
=
transform
self
.
numpy
=
numpy
self
.
mask_converged
=
mask_converged
self
.
_cached_batch
=
None
self
.
_update_mask
=
None
self
.
torch_results
=
{}
self
.
results
=
{}
self
.
_eps
=
masked_eps
self
.
dtype
=
dtype
self
.
otf_graph
=
True
# trainer._unwrapped_model.otf_graph
if
not
self
.
otf_graph
and
"edge_index"
not
in
self
.
batch
:
self
.
update_graph
()
self
.
batch
.
pos
=
self
.
batch
.
pos
.
to
(
dtype
=
self
.
dtype
)
self
.
batch
.
cell
=
self
.
batch
.
cell
.
to
(
dtype
=
self
.
dtype
)
self
.
compute_stress
=
compute_stress
self
.
use_fast_predict
=
use_fast_predict
# Determine calculator type once during initialization for efficiency
self
.
_calculator_type
=
self
.
_determine_calculator_type
()
logging
.
info
(
f
"OptimizableBatch initialized with calculator type:
{
self
.
_calculator_type
}
"
)
def
_determine_calculator_type
(
self
)
->
str
:
"""Determine the type of calculator to avoid repeated isinstance checks."""
# Check against actual imported classes, not dummy classes
trainer_class_name
=
type
(
self
.
trainer
).
__name__
trainer_module
=
type
(
self
.
trainer
).
__module__
if
(
"mace"
in
trainer_module
.
lower
()
or
trainer_class_name
==
"MACECalculator"
):
return
"mace"
elif
(
"chgnet"
in
trainer_module
.
lower
()
or
trainer_class_name
==
"CHGNetCalculator"
):
return
"chgnet"
elif
"sevenn"
in
trainer_module
.
lower
()
or
trainer_class_name
in
[
"SevenNetCalculator"
,
"SevenNetD3Calculator"
,
"D3Calculator"
,
]:
return
"sevennet"
elif
(
"fairchem"
in
trainer_module
.
lower
()
or
trainer_class_name
==
"FAIRChemCalculator"
):
return
"fairchem"
else
:
return
"default"
@
property
def
device
(
self
):
return
self
.
trainer
.
device
@
property
def
batch_indices
(
self
):
"""Get the batch indices specifying which position/force corresponds to which batch."""
return
self
.
batch
.
batch
@
property
def
converged_mask
(
self
):
if
self
.
_update_mask
is
not
None
:
return
torch
.
logical_not
(
self
.
_update_mask
)
return
None
@
property
def
update_mask
(
self
):
if
self
.
_update_mask
is
None
:
return
torch
.
ones
(
len
(
self
.
batch
),
dtype
=
bool
)
return
self
.
_update_mask
@
property
def
converge_indices_list
(
self
):
return
torch
.
where
(
~
self
.
update_mask
)[
0
].
tolist
()
@
property
def
elem_per_group
(
self
):
# This return value actually represents the number of elements
# in a group within a batch. Each group corresponds to batch_indices.
# It will count the number of CELL elements in each group.
return
torch
.
bincount
(
self
.
batch_indices
)
@
property
def
batch_size
(
self
):
return
len
(
torch
.
unique
(
self
.
batch_indices
))
def
check_state
(
self
,
batch
:
Batch
,
tol
:
float
=
1e-12
)
->
bool
:
"""Check for any system changes since last calculation."""
return
compare_batches
(
self
.
_cached_batch
,
batch
,
tol
=
tol
,
excluded_properties
=
set
(
self
.
ignored_changes
),
)
def
_predict
(
self
)
->
None
:
"""Run prediction if batch has any changes."""
# TODO: Currently, the batch inference interfaces of various models are not unified and are poorly implemented.
system_changes
=
self
.
check_state
(
self
.
batch
)
if
len
(
system_changes
)
>
0
:
if
self
.
_calculator_type
==
"mace"
:
# FIXME: &&&
# for key, val in self.batch.to_dict().items():
# print(f'&&& key: {key}, val: {val}')
# self.torch_results = self.trainer.predict_debug(atoms_list, self.batch, compute_stress=self.compute_stress)
# self.torch_results = self.trainer.predict(self.config_batch)
if
self
.
use_fast_predict
:
self
.
torch_results
=
self
.
trainer
.
fast_predict
(
self
.
batch
,
compute_stress
=
self
.
compute_stress
)
self
.
batch
.
pos
=
self
.
batch
.
pos
.
to
(
self
.
dtype
)
self
.
batch
.
cell
=
self
.
batch
.
cell
.
to
(
self
.
dtype
)
else
:
atoms_list
=
batch_to_atoms
(
self
.
batch
,
results
=
None
,
wrap_pos
=
False
,
eps
=
1e-17
)
self
.
torch_results
=
self
.
trainer
.
predict
(
atoms_list
,
compute_stress
=
self
.
compute_stress
)
elif
self
.
_calculator_type
==
"fairchem"
:
# TODO: FAIRChemCalculator does not support batch prediction yet
atoms_list
=
batch_to_atoms
(
self
.
batch
,
results
=
None
,
wrap_pos
=
False
,
eps
=
1e-17
)
self
.
torch_results
=
self
.
trainer
.
predict
(
atoms_list
=
atoms_list
)
elif
self
.
_calculator_type
==
"chgnet"
:
atoms_list
=
batch_to_atoms
(
self
.
batch
,
results
=
None
,
wrap_pos
=
False
,
eps
=
1e-17
)
model_prediction
=
self
.
trainer
.
predict
(
atoms_list
=
atoms_list
,
task
=
"efs"
)
results
=
{
"energy"
:
torch
.
tensor
(
[
pred
[
"e"
].
item
()
for
pred
in
model_prediction
],
device
=
self
.
device
,
dtype
=
self
.
dtype
,
),
"forces"
:
torch
.
vstack
(
[
torch
.
from_numpy
(
pred
[
"f"
]).
to
(
device
=
self
.
device
,
dtype
=
self
.
dtype
)
for
pred
in
model_prediction
]
),
"stress"
:
torch
.
vstack
(
[
torch
.
from_numpy
(
pred
[
"s"
]).
to
(
device
=
self
.
device
,
dtype
=
self
.
dtype
)
for
pred
in
model_prediction
]
).
view
(
-
1
,
3
,
3
),
}
self
.
torch_results
=
results
elif
self
.
_calculator_type
==
"sevennet"
:
atoms_list
=
batch_to_atoms
(
self
.
batch
,
results
=
None
,
wrap_pos
=
False
,
eps
=
1e-17
)
self
.
torch_results
=
self
.
trainer
.
predict
(
atoms_list
=
atoms_list
)
else
:
# default case
self
.
torch_results
=
self
.
trainer
.
predict
(
self
.
batch
,
per_image
=
False
,
disable_tqdm
=
True
)
# save only subset of props in simple namespace instead of cloning the whole batch to save memory
changes
=
ALL_CHANGES
-
set
(
self
.
ignored_changes
)
self
.
_cached_batch
=
SimpleNamespace
(
**
{
prop
:
self
.
batch
[
prop
].
clone
()
for
prop
in
changes
}
)
def
get_property
(
self
,
name
,
no_numpy
:
bool
=
False
)
->
torch
.
Tensor
|
NDArray
:
"""Get a predicted property by name."""
self
.
_predict
()
if
self
.
numpy
:
self
.
results
=
{
key
:
pred
.
item
()
if
pred
.
numel
()
==
1
else
pred
.
cpu
().
numpy
()
for
key
,
pred
in
self
.
torch_results
.
items
()
}
else
:
self
.
results
=
self
.
torch_results
if
name
not
in
self
.
results
:
raise
PropertyNotImplementedError
(
f
"
{
name
}
not present in this calculation"
)
return
(
self
.
results
[
name
]
if
no_numpy
is
False
else
self
.
torch_results
[
name
]
)
def
get_positions
(
self
)
->
torch
.
Tensor
|
NDArray
:
"""Get the batch positions"""
pos
=
self
.
batch
.
pos
.
clone
()
if
self
.
numpy
:
if
self
.
mask_converged
:
pos
[
~
self
.
update_mask
[
self
.
batch
.
batch
]]
=
self
.
_eps
pos
=
pos
.
cpu
().
numpy
()
return
pos
def
set_positions
(
self
,
positions
:
torch
.
Tensor
|
NDArray
)
->
None
:
"""Set the atom positions in the batch."""
if
isinstance
(
positions
,
np
.
ndarray
):
positions
=
torch
.
tensor
(
positions
,
dtype
=
self
.
dtype
,
device
=
self
.
device
)
else
:
positions
=
positions
.
to
(
dtype
=
self
.
dtype
,
device
=
self
.
device
)
if
self
.
mask_converged
and
self
.
_update_mask
is
not
None
:
mask
=
self
.
update_mask
[
self
.
batch
.
batch
]
self
.
batch
.
pos
[
mask
]
=
positions
[
mask
]
else
:
self
.
batch
.
pos
=
positions
if
not
self
.
otf_graph
:
self
.
update_graph
()
def
get_forces
(
self
,
apply_constraint
:
bool
=
False
,
no_numpy
:
bool
=
False
)
->
torch
.
Tensor
|
NDArray
:
"""Get predicted batch forces."""
forces
=
self
.
get_property
(
"forces"
,
no_numpy
=
no_numpy
)
if
apply_constraint
:
fixed_idx
=
torch
.
where
(
self
.
batch
.
fixed
==
1
)[
0
]
if
isinstance
(
forces
,
np
.
ndarray
):
fixed_idx
=
fixed_idx
.
tolist
()
forces
[
fixed_idx
]
=
0.0
return
forces
.
view
(
-
1
,
3
)
def
get_potential_energy
(
self
,
**
kwargs
)
->
torch
.
Tensor
|
NDArray
:
"""Get predicted energy as the sum of all batch energies."""
# ASE 3.22.1 expects a check for force_consistent calculations
if
kwargs
.
get
(
"force_consistent"
,
False
)
is
True
:
raise
PropertyNotImplementedError
(
"force_consistent calculations are not implemented"
)
if
(
len
(
self
.
batch
)
==
1
):
# unfortunately batch size 1 returns a float, not a tensor
return
self
.
get_property
(
"energy"
)
return
self
.
get_property
(
"energy"
).
sum
()
def
get_potential_energies
(
self
)
->
torch
.
Tensor
|
NDArray
:
"""Get the predicted energy for each system in batch."""
return
self
.
get_property
(
"energy"
)
def
get_cells
(
self
)
->
torch
.
Tensor
:
"""Get batch crystallographic cells."""
return
self
.
batch
.
cell
def
set_cells
(
self
,
cells
:
torch
.
Tensor
|
NDArray
,
scale_atoms
=
False
)
->
None
:
"""Set batch cells."""
assert
self
.
batch
.
cell
.
shape
==
cells
.
shape
,
"Cell shape mismatch"
if
isinstance
(
cells
,
np
.
ndarray
):
cells
=
torch
.
tensor
(
cells
,
dtype
=
self
.
dtype
,
device
=
self
.
device
)
cells
=
cells
.
to
(
dtype
=
self
.
dtype
,
device
=
self
.
device
)
if
scale_atoms
:
from
ase.geometry.cell
import
complete_cell
# M = torch.linalg.solve(
# self.batch.cell.view(-1, 3, 3),
# cells.view(-1, 3, 3),
# )
# TODO: need to implement a sparse version.
# tmp_pos = torch.matmul(self.batch.pos, M.reshape(-1,3))
for
i
in
range
(
self
.
batch_size
):
if
not
self
.
update_mask
[
i
]:
continue
M
=
np
.
linalg
.
solve
(
complete_cell
(
self
.
batch
.
cell
[
i
].
cpu
().
detach
().
numpy
()),
complete_cell
(
cells
[
i
].
cpu
().
detach
().
numpy
()),
)
pos_update_mask
=
self
.
batch
.
batch
==
i
self
.
batch
.
pos
[
pos_update_mask
]
=
torch
.
matmul
(
self
.
batch
.
pos
[
pos_update_mask
],
torch
.
from_numpy
(
M
).
to
(
self
.
device
).
reshape
(
-
1
,
3
),
)
self
.
batch
.
cell
[
self
.
update_mask
]
=
cells
[
self
.
update_mask
]
def
get_volumes
(
self
)
->
torch
.
Tensor
:
"""Get a tensor of volumes for each cell in batch"""
cells
=
self
.
get_cells
()
return
torch
.
linalg
.
det
(
cells
)
def
iterimages
(
self
)
->
Generator
[
Batch
,
None
,
None
]:
# XXX document purpose of iterimages - this is just needed to work with ASE optimizers
yield
self
.
batch
def
get_max_forces
(
self
,
forces
:
torch
.
Tensor
|
None
=
None
,
apply_constraint
:
bool
=
False
)
->
torch
.
Tensor
:
"""Get the maximum forces per structure in batch"""
if
forces
is
None
:
forces
=
self
.
get_forces
(
apply_constraint
=
apply_constraint
,
no_numpy
=
True
)
return
scatter
(
(
forces
**
2
).
sum
(
axis
=
1
).
sqrt
(),
self
.
batch_indices
,
reduce
=
"max"
)
def
converged
(
self
,
forces
:
torch
.
Tensor
|
NDArray
|
None
,
fmax
:
float
,
max_forces
:
torch
.
Tensor
|
None
=
None
,
f_upper_limit
:
float
=
1e20
,
)
->
bool
:
"""Check if norm of all predicted forces are below fmax"""
if
forces
is
not
None
:
if
isinstance
(
forces
,
np
.
ndarray
):
forces
=
torch
.
tensor
(
forces
,
device
=
self
.
device
,
dtype
=
self
.
dtype
)
max_forces
=
self
.
get_max_forces
(
forces
)
elif
max_forces
is
None
:
max_forces
=
self
.
get_max_forces
()
# Update mask is True for forces that are greater than fmax AND less than f_upper_limit
update_mask
=
torch
.
logical_and
(
max_forces
.
ge
(
fmax
),
max_forces
.
le
(
f_upper_limit
)
)
# update cached mask
if
self
.
mask_converged
:
if
self
.
_update_mask
is
None
:
self
.
_update_mask
=
update_mask
else
:
# some models can have random noise in their predictions, so the mask is updated by
# keeping all previously converged structures masked even if new force predictions
# push it slightly above threshold
self
.
_update_mask
=
torch
.
logical_and
(
self
.
_update_mask
,
update_mask
)
update_mask
=
self
.
_update_mask
return
not
torch
.
any
(
update_mask
).
item
()
def
get_atoms_list
(
self
)
->
list
[
Atoms
]:
"""Get ase Atoms objects corresponding to the batch"""
self
.
_predict
()
# in case no predictions have been run
return
batch_to_atoms
(
self
.
batch
,
results
=
self
.
torch_results
)
def
update_graph
(
self
):
"""Update the graph if model does not use otf_graph."""
graph
=
self
.
trainer
.
_unwrapped_model
.
generate_graph
(
self
.
batch
)
self
.
batch
.
edge_index
=
graph
.
edge_index
self
.
batch
.
cell_offsets
=
graph
.
cell_offsets
self
.
batch
.
neighbors
=
graph
.
neighbors
if
self
.
transform
is
not
None
:
self
.
batch
=
self
.
transform
(
self
.
batch
)
def
__len__
(
self
)
->
int
:
# TODO: this might be changed in ASE to be 3 * len(self.atoms)
return
len
(
self
.
batch
.
pos
)
class
OptimizableUnitCellBatch
(
OptimizableBatch
):
"""Modify the supercell and the atom positions in relaxations.
Based on ase UnitCellFilter to work on data batches
"""
def
__init__
(
self
,
batch
:
Batch
,
trainer
:
Any
,
# Any calculator type (MACECalculator | CHGNetCalculator | SevenNetD3Calculator | FAIRChemCalculator)
transform
:
torch
.
nn
.
Module
|
None
=
None
,
numpy
:
bool
=
False
,
mask_converged
:
bool
=
True
,
mask
:
Sequence
[
bool
]
|
None
=
None
,
cell_factor
:
float
|
torch
.
Tensor
|
None
=
None
,
hydrostatic_strain
:
bool
=
False
,
constant_volume
:
bool
=
False
,
scalar_pressure
:
float
=
0.0
,
masked_eps
:
float
=
1e-8
,
use_fast_predict
:
bool
=
True
,
dtype
:
torch
.
dtype
=
torch
.
float64
,
):
"""Create a filter that returns the forces and unit cell stresses together, for simultaneous optimization.
For full details see:
E. B. Tadmor, G. S. Smith, N. Bernstein, and E. Kaxiras,
Phys. Rev. B 59, 235 (1999)
Args:
batch: A batch of atoms graph data
model: An instance of a BaseTrainer derived class
transform: graph transform
numpy: whether to cast results to numpy arrays
mask_converged: if true will mask systems in batch that are already converged
mask: a boolean mask specifying which strain components are allowed to relax
cell_factor:
Factor by which deformation gradient is multiplied to put
it on the same scale as the positions when assembling
the combined position/cell vector. The stress contribution to
the forces is scaled down by the same factor. This can be thought
of as a very simple preconditioner. Default is number of atoms
which gives approximately the correct scaling.
hydrostatic_strain:
Constrain the cell by only allowing hydrostatic deformation.
The virial tensor is replaced by np.diag([np.trace(virial)]*3).
constant_volume:
Project out the diagonal elements of the virial tensor to allow
relaxations at constant volume, e.g. for mapping out an
energy-volume curve. Note: this only approximately conserves
the volume and breaks energy/force consistency so can only be
used with optimizers that do require a line minimisation
(e.g. FIRE).
scalar_pressure:
Applied pressure to use for enthalpy pV term. As above, this
breaks energy/force consistency.
masked_eps: masking systems that are converged when using ASE optimizers results in divisions by zero
from zero differences in masked positions at future steps, we add a small number to prevent this.
dtype: data type for tensor operations (torch.float32 or torch.float64)
"""
super
().
__init__
(
batch
=
batch
,
trainer
=
trainer
,
transform
=
transform
,
numpy
=
numpy
,
mask_converged
=
mask_converged
,
masked_eps
=
masked_eps
,
compute_stress
=
True
,
use_fast_predict
=
use_fast_predict
,
dtype
=
dtype
,
)
self
.
orig_cells
=
self
.
get_cells
().
clone
()
self
.
stress
=
None
if
mask
is
None
:
# mask = torch.eye(3, device=self.device)
mask
=
torch
.
ones
(
6
,
device
=
self
.
device
)
# TODO make sure mask is on GPU
if
mask
.
shape
==
(
6
,):
self
.
mask
=
torch
.
tensor
(
voigt_6_to_full_3x3_stress
(
mask
.
detach
().
cpu
()),
device
=
self
.
device
,
)
elif
mask
.
shape
==
(
3
,
3
):
self
.
mask
=
mask
else
:
raise
ValueError
(
"shape of mask should be (3,3) or (6,)"
)
if
isinstance
(
cell_factor
,
float
):
cell_factor
=
cell_factor
*
torch
.
ones
(
(
3
*
len
(
batch
),
1
),
requires_grad
=
False
)
if
cell_factor
is
None
:
cell_factor
=
self
.
batch
.
natoms
.
repeat_interleave
(
3
).
unsqueeze
(
dim
=
1
)
self
.
hydrostatic_strain
=
hydrostatic_strain
self
.
constant_volume
=
constant_volume
self
.
pressure
=
scalar_pressure
*
torch
.
eye
(
3
,
device
=
self
.
device
)
self
.
cell_factor
=
cell_factor
self
.
stress
=
None
self
.
_batch_trace
=
torch
.
vmap
(
torch
.
trace
)
self
.
_batch_diag
=
torch
.
vmap
(
lambda
x
:
x
*
torch
.
eye
(
3
,
device
=
x
.
device
)
)
@
cached_property
def
batch_indices
(
self
):
"""Get the batch indices specifying which position/force corresponds to which batch.
We augment this to specify the batch indices for augmented positions and forces.
"""
augmented_batch
=
torch
.
repeat_interleave
(
torch
.
arange
(
len
(
self
.
batch
),
dtype
=
self
.
batch
.
batch
.
dtype
,
device
=
self
.
device
,
),
3
,
)
return
torch
.
cat
([
self
.
batch
.
batch
,
augmented_batch
])
def
deform_grad
(
self
):
"""Get the cell deformation matrix"""
return
torch
.
transpose
(
torch
.
linalg
.
solve
(
self
.
orig_cells
,
self
.
get_cells
()),
1
,
2
)
def
get_positions
(
self
):
"""Get positions and cell deformation gradient."""
cur_deform_grad
=
self
.
deform_grad
()
natoms
=
self
.
batch
.
num_nodes
pos
=
torch
.
zeros
(
(
natoms
+
3
*
len
(
self
.
get_cells
()),
3
),
dtype
=
self
.
batch
.
pos
.
dtype
,
device
=
self
.
device
,
)
# Augmented positions are the self.atoms.positions but without the applied deformation gradient
pos
[:
natoms
]
=
torch
.
linalg
.
solve
(
cur_deform_grad
[
self
.
batch
.
batch
,
:,
:],
self
.
batch
.
pos
.
view
(
-
1
,
3
,
1
),
).
view
(
-
1
,
3
)
# cell DOFs are the deformation gradient times a scaling factor
pos
[
natoms
:]
=
self
.
cell_factor
*
cur_deform_grad
.
view
(
-
1
,
3
)
return
pos
.
cpu
().
numpy
()
if
self
.
numpy
else
pos
def
set_positions
(
self
,
positions
:
torch
.
Tensor
|
NDArray
)
->
None
:
"""Set positions and cell.
positions has shape (natoms + ncells * 3, 3).
the first natoms rows are the positions of the atoms, the last nsystems * three rows are the deformation tensor
for each cell.
"""
if
isinstance
(
positions
,
np
.
ndarray
):
positions
=
torch
.
tensor
(
positions
,
dtype
=
self
.
dtype
,
device
=
self
.
device
)
else
:
positions
=
positions
.
to
(
dtype
=
self
.
dtype
,
device
=
self
.
device
)
natoms
=
self
.
batch
.
num_nodes
new_atom_positions
=
positions
[:
natoms
]
new_deform_grad
=
(
positions
[
natoms
:]
/
self
.
cell_factor
).
view
(
-
1
,
3
,
3
)
# TODO check that in fact symmetry is preserved setting cells and positions
# Set the new cell from the original cell and the new deformation gradient. Both current and final structures
# should preserve symmetry.
new_cells
=
torch
.
bmm
(
self
.
orig_cells
,
torch
.
transpose
(
new_deform_grad
,
1
,
2
)
)
self
.
set_cells
(
new_cells
)
# Set the positions from the ones passed in (which are without the deformation gradient applied) and the new
# deformation gradient. This should also preserve symmetry
new_atom_positions
=
torch
.
bmm
(
new_atom_positions
.
view
(
-
1
,
1
,
3
),
torch
.
transpose
(
new_deform_grad
[
self
.
batch
.
batch
,
:,
:].
view
(
-
1
,
3
,
3
),
1
,
2
),
)
super
().
set_positions
(
new_atom_positions
.
view
(
-
1
,
3
))
def
get_potential_energy
(
self
,
**
kwargs
):
"""
returns potential energy including enthalpy PV term.
"""
atoms_energy
=
super
().
get_potential_energy
(
**
kwargs
)
return
atoms_energy
+
self
.
pressure
[
0
,
0
]
*
self
.
get_volumes
().
sum
()
def
get_forces
(
self
,
apply_constraint
:
bool
=
False
,
no_numpy
:
bool
=
False
)
->
torch
.
Tensor
|
NDArray
:
"""Get forces and unit cell stress."""
stress
=
self
.
get_property
(
"stress"
,
no_numpy
=
True
).
view
(
-
1
,
3
,
3
)
atom_forces
=
self
.
get_property
(
"forces"
,
no_numpy
=
True
)
if
apply_constraint
:
fixed_idx
=
torch
.
where
(
self
.
batch
.
fixed
==
1
)[
0
]
atom_forces
[
fixed_idx
]
=
0.0
volumes
=
self
.
get_volumes
().
view
(
-
1
,
1
,
1
)
# virial = -volumes * stress + self.pressure.view(-1, 3, 3)
virial
=
-
volumes
*
(
stress
+
self
.
pressure
.
view
(
-
1
,
3
,
3
))
# print(f'&&& virial0: {virial}')
cur_deform_grad
=
self
.
deform_grad
()
atom_forces
=
torch
.
bmm
(
atom_forces
.
view
(
-
1
,
1
,
3
),
cur_deform_grad
[
self
.
batch
.
batch
,
:,
:].
view
(
-
1
,
3
,
3
),
)
virial
=
torch
.
linalg
.
solve
(
cur_deform_grad
,
torch
.
transpose
(
virial
,
dim0
=
1
,
dim1
=
2
)
)
virial
=
torch
.
transpose
(
virial
,
dim0
=
1
,
dim1
=
2
)
# print(f'&&& virial1: {virial}')
# TODO this does not work yet! maybe _batch_trace gives an issue
if
self
.
hydrostatic_strain
:
virial
=
self
.
_batch_diag
(
self
.
_batch_trace
(
virial
)
/
3.0
)
# Zero out components corresponding to fixed lattice elements
if
(
self
.
mask
!=
1.0
).
any
():
virial
*=
self
.
mask
.
view
(
-
1
,
3
,
3
)
if
self
.
constant_volume
:
virial
[:,
range
(
3
),
range
(
3
)]
-=
(
self
.
_batch_trace
(
virial
).
view
(
3
,
-
1
)
/
3.0
)
natoms
=
self
.
batch
.
num_nodes
augmented_forces
=
torch
.
zeros
(
(
natoms
+
3
*
len
(
self
.
get_cells
()),
3
),
device
=
self
.
device
,
dtype
=
atom_forces
.
dtype
,
)
# print(f'&&& atom_forces: {atom_forces}')
# print(f'&&& virial2: {virial}')
augmented_forces
[:
natoms
]
=
atom_forces
.
view
(
-
1
,
3
)
augmented_forces
[
natoms
:]
=
virial
.
view
(
-
1
,
3
)
/
self
.
cell_factor
self
.
stress
=
-
virial
.
view
(
-
1
,
9
)
/
volumes
.
view
(
-
1
,
1
)
if
self
.
numpy
and
not
no_numpy
:
augmented_forces
=
augmented_forces
.
cpu
().
numpy
()
# print(f'&&& augmented_forces: {augmented_forces}')
return
augmented_forces
def
__len__
(
self
):
return
len
(
self
.
batch
.
pos
)
+
3
*
len
(
self
.
batch
)
def
get_potential_energies
(
self
)
->
torch
.
Tensor
:
"""Get the predicted energy for each system in batch."""
return
(
self
.
get_property
(
"energy"
).
view
(
-
1
)
+
self
.
pressure
[
0
,
0
]
*
self
.
get_volumes
()
)
mace-bench/src/batchopt/relaxation/optimizers/__init__.py
0 → 100644
View file @
fa84b16c
"""
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from
__future__
import
annotations
from
.bfgs_torch
import
BFGS
from
.bfgsfusedls
import
BFGSFusedLS
__all__
=
[
"BFGS"
,
"BFGSFusedLS"
]
\ No newline at end of file
mace-bench/src/batchopt/relaxation/optimizers/__pycache__/__init__.cpython-310.pyc
0 → 100644
View file @
fa84b16c
File added
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment