nivren / ICT-CSP · Commits · 251f5af2

Unverified commit 251f5af2 — "Add files via upload"
Authored Aug 24, 2025 by zcxzcx1; committed by GitHub on Aug 24, 2025
Parent: 73ff4f3a · Changes: 64

Showing 4 changed files with 1365 additions and 0 deletions (+1365 / −0)
modules/radial.py                   +358  −0
modules/symmetric_contraction.py    +233  −0
modules/utils.py                    +582  −0
modules/wrapper_ops.py              +192  −0
modules/radial.py  (new file, mode 0 → 100644)
###########################################################################################
# Radial basis and cutoff
# Authors: Ilyes Batatia, Gregor Simm
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
import logging

import ase
import numpy as np
import torch
from e3nn.util.jit import compile_mode

from mace.tools.scatter import scatter_sum
@compile_mode("script")
class BesselBasis(torch.nn.Module):
    """
    Equation (7)
    """

    def __init__(self, r_max: float, num_basis=8, trainable=False):
        super().__init__()

        bessel_weights = (
            np.pi
            / r_max
            * torch.linspace(
                start=1.0,
                end=num_basis,
                steps=num_basis,
                dtype=torch.get_default_dtype(),
            )
        )
        if trainable:
            self.bessel_weights = torch.nn.Parameter(bessel_weights)
        else:
            self.register_buffer("bessel_weights", bessel_weights)

        self.register_buffer(
            "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype())
        )
        self.register_buffer(
            "prefactor",
            torch.tensor(np.sqrt(2.0 / r_max), dtype=torch.get_default_dtype()),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # [..., 1]
        numerator = torch.sin(self.bessel_weights * x)  # [..., num_basis]
        return self.prefactor * (numerator / x)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(r_max={self.r_max}, num_basis={len(self.bessel_weights)}, "
            f"trainable={self.bessel_weights.requires_grad})"
        )
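
A minimal usage sketch (not part of the committed file), assuming this module's imports; the edge lengths below are illustrative and shaped [n_edges, 1]:

# Illustrative only: expand edge lengths into the Bessel radial basis.
bessel = BesselBasis(r_max=5.0, num_basis=8)
lengths = torch.tensor([[1.2], [2.7], [4.9]])  # hypothetical distances in Angstrom
radial_feats = bessel(lengths)                 # shape [3, 8]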
@compile_mode("script")
class ChebychevBasis(torch.nn.Module):
    """
    Equation (7)
    """

    def __init__(self, r_max: float, num_basis=8):
        super().__init__()
        self.register_buffer(
            "n",
            torch.arange(1, num_basis + 1, dtype=torch.get_default_dtype()).unsqueeze(0),
        )
        self.num_basis = num_basis
        self.r_max = r_max

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # [..., 1]
        x = x.repeat(1, self.num_basis)
        n = self.n.repeat(len(x), 1)
        return torch.special.chebyshev_polynomial_t(x, n)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(r_max={self.r_max}, num_basis={self.num_basis})"
        )
@compile_mode("script")
class GaussianBasis(torch.nn.Module):
    """
    Gaussian basis functions
    """

    def __init__(self, r_max: float, num_basis=128, trainable=False):
        super().__init__()
        gaussian_weights = torch.linspace(
            start=0.0, end=r_max, steps=num_basis, dtype=torch.get_default_dtype()
        )
        if trainable:
            self.gaussian_weights = torch.nn.Parameter(
                gaussian_weights, requires_grad=True
            )
        else:
            self.register_buffer("gaussian_weights", gaussian_weights)
        self.coeff = -0.5 / (r_max / (num_basis - 1)) ** 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # [..., 1]
        x = x - self.gaussian_weights
        return torch.exp(self.coeff * torch.pow(x, 2))
@compile_mode("script")
class PolynomialCutoff(torch.nn.Module):
    """Polynomial cutoff function that goes from 1 to 0 as x goes from 0 to r_max.
    Equation (8) -- TODO: from where?
    """

    p: torch.Tensor
    r_max: torch.Tensor

    def __init__(self, r_max: float, p=6):
        super().__init__()
        self.register_buffer("p", torch.tensor(p, dtype=torch.int))
        self.register_buffer(
            "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype())
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.calculate_envelope(x, self.r_max, self.p.to(torch.int))

    @staticmethod
    def calculate_envelope(
        x: torch.Tensor, r_max: torch.Tensor, p: torch.Tensor
    ) -> torch.Tensor:
        r_over_r_max = x / r_max
        envelope = (
            1.0
            - ((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(r_over_r_max, p)
            + p * (p + 2.0) * torch.pow(r_over_r_max, p + 1)
            - (p * (p + 1.0) / 2) * torch.pow(r_over_r_max, p + 2)
        )
        return envelope * (x < r_max)

    def __repr__(self):
        return f"{self.__class__.__name__}(p={self.p}, r_max={self.r_max})"
@compile_mode("script")
class ZBLBasis(torch.nn.Module):
    """Implementation of the Ziegler-Biersack-Littmark (ZBL) potential
    with a polynomial cutoff envelope.
    """

    p: torch.Tensor

    def __init__(self, p=6, trainable=False, **kwargs):
        super().__init__()
        if "r_max" in kwargs:
            logging.warning(
                "r_max is deprecated. r_max is determined from the covalent radii."
            )

        # Pre-calculate the p coefficients for the ZBL potential
        self.register_buffer(
            "c",
            torch.tensor(
                [0.1818, 0.5099, 0.2802, 0.02817], dtype=torch.get_default_dtype()
            ),
        )
        self.register_buffer("p", torch.tensor(p, dtype=torch.int))
        self.register_buffer(
            "covalent_radii",
            torch.tensor(
                ase.data.covalent_radii,
                dtype=torch.get_default_dtype(),
            ),
        )
        if trainable:
            self.a_exp = torch.nn.Parameter(torch.tensor(0.300, requires_grad=True))
            self.a_prefactor = torch.nn.Parameter(
                torch.tensor(0.4543, requires_grad=True)
            )
        else:
            self.register_buffer("a_exp", torch.tensor(0.300))
            self.register_buffer("a_prefactor", torch.tensor(0.4543))

    def forward(
        self,
        x: torch.Tensor,
        node_attrs: torch.Tensor,
        edge_index: torch.Tensor,
        atomic_numbers: torch.Tensor,
    ) -> torch.Tensor:
        sender = edge_index[0]
        receiver = edge_index[1]
        node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze(-1)
        Z_u = node_atomic_numbers[sender]
        Z_v = node_atomic_numbers[receiver]
        a = (
            self.a_prefactor
            * 0.529
            / (torch.pow(Z_u, self.a_exp) + torch.pow(Z_v, self.a_exp))
        )
        r_over_a = x / a
        phi = (
            self.c[0] * torch.exp(-3.2 * r_over_a)
            + self.c[1] * torch.exp(-0.9423 * r_over_a)
            + self.c[2] * torch.exp(-0.4028 * r_over_a)
            + self.c[3] * torch.exp(-0.2016 * r_over_a)
        )
        v_edges = (14.3996 * Z_u * Z_v) / x * phi
        r_max = self.covalent_radii[Z_u] + self.covalent_radii[Z_v]
        envelope = PolynomialCutoff.calculate_envelope(x, r_max, self.p)
        v_edges = 0.5 * v_edges * envelope
        V_ZBL = scatter_sum(v_edges, receiver, dim=0, dim_size=node_attrs.size(0))
        return V_ZBL.squeeze(-1)

    def __repr__(self):
        return f"{self.__class__.__name__}(c={self.c})"
@compile_mode("script")
class AgnesiTransform(torch.nn.Module):
    """Agnesi transform - see section on Radial transformations in
    ACEpotentials.jl, JCP 2023 (https://doi.org/10.1063/5.0158783).
    """

    def __init__(
        self,
        q: float = 0.9183,
        p: float = 4.5791,
        a: float = 1.0805,
        trainable=False,
    ):
        super().__init__()
        self.register_buffer("q", torch.tensor(q, dtype=torch.get_default_dtype()))
        self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype()))
        self.register_buffer("a", torch.tensor(a, dtype=torch.get_default_dtype()))
        self.register_buffer(
            "covalent_radii",
            torch.tensor(
                ase.data.covalent_radii,
                dtype=torch.get_default_dtype(),
            ),
        )
        if trainable:
            self.a = torch.nn.Parameter(torch.tensor(1.0805, requires_grad=True))
            self.q = torch.nn.Parameter(torch.tensor(0.9183, requires_grad=True))
            self.p = torch.nn.Parameter(torch.tensor(4.5791, requires_grad=True))

    def forward(
        self,
        x: torch.Tensor,
        node_attrs: torch.Tensor,
        edge_index: torch.Tensor,
        atomic_numbers: torch.Tensor,
    ) -> torch.Tensor:
        sender = edge_index[0]
        receiver = edge_index[1]
        node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze(-1)
        Z_u = node_atomic_numbers[sender]
        Z_v = node_atomic_numbers[receiver]
        r_0: torch.Tensor = 0.5 * (self.covalent_radii[Z_u] + self.covalent_radii[Z_v])
        r_over_r_0 = x / r_0
        return (
            1
            + (
                self.a
                * torch.pow(r_over_r_0, self.q)
                / (1 + torch.pow(r_over_r_0, self.q - self.p))
            )
        ).reciprocal_()

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(a={self.a:.4f}, q={self.q:.4f}, p={self.p:.4f})"
        )
@compile_mode("script")
class SoftTransform(torch.nn.Module):
    """
    Tanh-based smooth transformation:
    T(x) = p_0 + (x - p_0)*0.5*[1 + tanh(alpha*(x - m))],
    which smoothly transitions from ~p_0 for x << p_0 to ~x for x >> p_1,
    where p_0 and p_1 are set per edge from the pairwise covalent radii.
    """

    def __init__(self, alpha: float = 4.0, trainable=False):
        """
        Args:
            alpha (float): Steepness of the transition (rescaled by p_1 - p_0 in forward).
            trainable (bool): Whether to make alpha trainable.
        """
        super().__init__()
        # Initialize parameters
        self.register_buffer(
            "alpha", torch.tensor(alpha, dtype=torch.get_default_dtype())
        )
        if trainable:
            self.alpha = torch.nn.Parameter(self.alpha.clone())
        self.register_buffer(
            "covalent_radii",
            torch.tensor(
                ase.data.covalent_radii,
                dtype=torch.get_default_dtype(),
            ),
        )

    def compute_r_0(
        self,
        node_attrs: torch.Tensor,
        edge_index: torch.Tensor,
        atomic_numbers: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute r_0 based on atomic information.

        Args:
            node_attrs (torch.Tensor): Node attributes (one-hot encoding of atomic numbers).
            edge_index (torch.Tensor): Edge index indicating connections.
            atomic_numbers (torch.Tensor): Atomic numbers.

        Returns:
            torch.Tensor: r_0 values for each edge.
        """
        sender = edge_index[0]
        receiver = edge_index[1]
        node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze(-1)
        Z_u = node_atomic_numbers[sender]
        Z_v = node_atomic_numbers[receiver]
        r_0: torch.Tensor = self.covalent_radii[Z_u] + self.covalent_radii[Z_v]
        return r_0

    def forward(
        self,
        x: torch.Tensor,
        node_attrs: torch.Tensor,
        edge_index: torch.Tensor,
        atomic_numbers: torch.Tensor,
    ) -> torch.Tensor:
        r_0 = self.compute_r_0(node_attrs, edge_index, atomic_numbers)
        p_0 = (3 / 4) * r_0
        p_1 = (4 / 3) * r_0
        m = 0.5 * (p_0 + p_1)
        alpha = self.alpha / (p_1 - p_0)
        s_x = 0.5 * (1.0 + torch.tanh(alpha * (x - m)))
        return p_0 + (x - p_0) * s_x

    def __repr__(self):
        return f"{self.__class__.__name__}(alpha={self.alpha.item():.4f})"
modules/symmetric_contraction.py  (new file, mode 0 → 100644)
###########################################################################################
# Implementation of the symmetric contraction algorithm presented in the MACE paper
# (Batatia et al, MACE: Higher Order Equivariant Message Passing Neural Networks for Fast and Accurate Force Fields, Eq. 10 and 11)
# Authors: Ilyes Batatia
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
from typing import Dict, Optional, Union

import opt_einsum_fx
import torch
import torch.fx
from e3nn import o3
from e3nn.util.codegen import CodeGenMixin
from e3nn.util.jit import compile_mode

from mace.tools.cg import U_matrix_real

BATCH_EXAMPLE = 10
ALPHABET = ["w", "x", "v", "n", "z", "r", "t", "y", "u", "o", "p", "s"]
@compile_mode("script")
class SymmetricContraction(CodeGenMixin, torch.nn.Module):
    def __init__(
        self,
        irreps_in: o3.Irreps,
        irreps_out: o3.Irreps,
        correlation: Union[int, Dict[str, int]],
        irrep_normalization: str = "component",
        path_normalization: str = "element",
        internal_weights: Optional[bool] = None,
        shared_weights: Optional[bool] = None,
        num_elements: Optional[int] = None,
    ) -> None:
        super().__init__()

        if irrep_normalization is None:
            irrep_normalization = "component"

        if path_normalization is None:
            path_normalization = "element"

        assert irrep_normalization in ["component", "norm", "none"]
        assert path_normalization in ["element", "path", "none"]

        self.irreps_in = o3.Irreps(irreps_in)
        self.irreps_out = o3.Irreps(irreps_out)

        del irreps_in, irreps_out

        if not isinstance(correlation, tuple):
            corr = correlation
            correlation = {}
            for irrep_out in self.irreps_out:
                correlation[irrep_out] = corr

        assert shared_weights or not internal_weights

        if internal_weights is None:
            internal_weights = True

        self.internal_weights = internal_weights
        self.shared_weights = shared_weights

        del internal_weights, shared_weights

        self.contractions = torch.nn.ModuleList()
        for irrep_out in self.irreps_out:
            self.contractions.append(
                Contraction(
                    irreps_in=self.irreps_in,
                    irrep_out=o3.Irreps(str(irrep_out.ir)),
                    correlation=correlation[irrep_out],
                    internal_weights=self.internal_weights,
                    num_elements=num_elements,
                    weights=self.shared_weights,
                )
            )

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        outs = [contraction(x, y) for contraction in self.contractions]
        return torch.cat(outs, dim=-1)
@compile_mode("script")
class Contraction(torch.nn.Module):
    def __init__(
        self,
        irreps_in: o3.Irreps,
        irrep_out: o3.Irreps,
        correlation: int,
        internal_weights: bool = True,
        num_elements: Optional[int] = None,
        weights: Optional[torch.Tensor] = None,
    ) -> None:
        super().__init__()

        self.num_features = irreps_in.count((0, 1))
        self.coupling_irreps = o3.Irreps([irrep.ir for irrep in irreps_in])
        self.correlation = correlation
        dtype = torch.get_default_dtype()
        for nu in range(1, correlation + 1):
            U_matrix = U_matrix_real(
                irreps_in=self.coupling_irreps,
                irreps_out=irrep_out,
                correlation=nu,
                dtype=dtype,
            )[-1]
            self.register_buffer(f"U_matrix_{nu}", U_matrix)

        # Tensor contraction equations
        self.contractions_weighting = torch.nn.ModuleList()
        self.contractions_features = torch.nn.ModuleList()

        # Create weight for product basis
        self.weights = torch.nn.ParameterList([])

        for i in range(correlation, 0, -1):
            # Shape definitions
            num_params = self.U_tensors(i).size()[-1]
            num_equivariance = 2 * irrep_out.lmax + 1
            num_ell = self.U_tensors(i).size()[-2]

            if i == correlation:
                parse_subscript_main = (
                    [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1) - 1)]
                    + ["ik,ekc,bci,be -> bc"]
                    + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1) - 1)]
                )
                graph_module_main = torch.fx.symbolic_trace(
                    lambda x, y, w, z: torch.einsum(
                        "".join(parse_subscript_main), x, y, w, z
                    )
                )

                # Optimizing the contractions
                self.graph_opt_main = opt_einsum_fx.optimize_einsums_full(
                    model=graph_module_main,
                    example_inputs=(
                        torch.randn(
                            [num_equivariance] + [num_ell] * i + [num_params]
                        ).squeeze(0),
                        torch.randn((num_elements, num_params, self.num_features)),
                        torch.randn((BATCH_EXAMPLE, self.num_features, num_ell)),
                        torch.randn((BATCH_EXAMPLE, num_elements)),
                    ),
                )
                # Parameters for the product basis
                w = torch.nn.Parameter(
                    torch.randn((num_elements, num_params, self.num_features))
                    / num_params
                )
                self.weights_max = w
            else:
                # Generate optimized contractions equations
                parse_subscript_weighting = (
                    [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1))]
                    + ["k,ekc,be->bc"]
                    + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1))]
                )
                parse_subscript_features = (
                    ["bc"]
                    + [ALPHABET[j] for j in range(i - 1 + min(irrep_out.lmax, 1))]
                    + ["i,bci->bc"]
                    + [ALPHABET[j] for j in range(i - 1 + min(irrep_out.lmax, 1))]
                )

                # Symbolic tracing of contractions
                graph_module_weighting = torch.fx.symbolic_trace(
                    lambda x, y, z: torch.einsum(
                        "".join(parse_subscript_weighting), x, y, z
                    )
                )
                graph_module_features = torch.fx.symbolic_trace(
                    lambda x, y: torch.einsum("".join(parse_subscript_features), x, y)
                )

                # Optimizing the contractions
                graph_opt_weighting = opt_einsum_fx.optimize_einsums_full(
                    model=graph_module_weighting,
                    example_inputs=(
                        torch.randn(
                            [num_equivariance] + [num_ell] * i + [num_params]
                        ).squeeze(0),
                        torch.randn((num_elements, num_params, self.num_features)),
                        torch.randn((BATCH_EXAMPLE, num_elements)),
                    ),
                )
                graph_opt_features = opt_einsum_fx.optimize_einsums_full(
                    model=graph_module_features,
                    example_inputs=(
                        torch.randn(
                            [BATCH_EXAMPLE, self.num_features, num_equivariance]
                            + [num_ell] * i
                        ).squeeze(2),
                        torch.randn((BATCH_EXAMPLE, self.num_features, num_ell)),
                    ),
                )
                self.contractions_weighting.append(graph_opt_weighting)
                self.contractions_features.append(graph_opt_features)
                # Parameters for the product basis
                w = torch.nn.Parameter(
                    torch.randn((num_elements, num_params, self.num_features))
                    / num_params
                )
                self.weights.append(w)
        if not internal_weights:
            self.weights = weights[:-1]
            self.weights_max = weights[-1]

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        out = self.graph_opt_main(
            self.U_tensors(self.correlation),
            self.weights_max,
            x,
            y,
        )
        for i, (weight, contract_weights, contract_features) in enumerate(
            zip(self.weights, self.contractions_weighting, self.contractions_features)
        ):
            c_tensor = contract_weights(
                self.U_tensors(self.correlation - i - 1),
                weight,
                y,
            )
            c_tensor = c_tensor + out
            out = contract_features(c_tensor, x)

        return out.view(out.shape[0], -1)

    def U_tensors(self, nu: int):
        return dict(self.named_buffers())[f"U_matrix_{nu}"]
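
A minimal usage sketch (not part of the committed file), assuming this module's imports. The irreps, channel count, and shapes below are illustrative: x carries node features reshaped to [n_nodes, channels, num_ell] (num_ell = 4 for one l=0 plus one l=1 irrep) and y is a one-hot element encoding of shape [n_nodes, num_elements]:

# Illustrative only.
irreps_in = o3.Irreps("16x0e + 16x1o")
irreps_out = o3.Irreps("16x0e")
sc = SymmetricContraction(
    irreps_in=irreps_in,
    irreps_out=irreps_out,
    correlation=3,             # body order of the product basis
    num_elements=2,
)
x = torch.randn(10, 16, 4)     # [n_nodes, channels, num_ell]
y = torch.zeros(10, 2)
y[:, 0] = 1.0                  # all nodes are element 0, for simplicity
out = sc(x, y)                 # [10, 16] for the single 16x0e output irrep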
modules/utils.py  (new file, mode 0 → 100644)
###########################################################################################
# Utilities
# Authors: Ilyes Batatia, Gregor Simm and David Kovacs
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
import logging
from typing import Dict, List, NamedTuple, Optional, Tuple

import numpy as np
import torch
import torch.utils.data
from scipy.constants import c, e

from mace.tools import to_numpy
from mace.tools.scatter import scatter_mean, scatter_std, scatter_sum
from mace.tools.torch_geometric.batch import Batch

from .blocks import AtomicEnergiesBlock
def compute_forces(
    energy: torch.Tensor, positions: torch.Tensor, training: bool = True
) -> torch.Tensor:
    grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)]
    gradient = torch.autograd.grad(
        outputs=[energy],  # [n_graphs, ]
        inputs=[positions],  # [n_nodes, 3]
        grad_outputs=grad_outputs,
        retain_graph=training,  # Make sure the graph is not destroyed during training
        create_graph=training,  # Create graph for second derivative
        allow_unused=True,  # For complete dissociation turn to true
    )[0]  # [n_nodes, 3]
    if gradient is None:
        return torch.zeros_like(positions)
    return -1 * gradient
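
The sign convention here is forces = -dE/dpositions. A minimal self-check sketch (not part of the committed file) with a toy quadratic energy, for which the gradient is known in closed form:

# Illustrative only: E = sum(r^2) gives F = -2 r.
positions = torch.randn(4, 3, requires_grad=True)
energy = (positions ** 2).sum()
forces = compute_forces(energy=energy, positions=positions, training=False)
assert torch.allclose(forces, -2 * positions)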
def compute_forces_virials(
    energy: torch.Tensor,
    positions: torch.Tensor,
    displacement: torch.Tensor,
    cell: torch.Tensor,
    training: bool = True,
    compute_stress: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)]
    forces, virials = torch.autograd.grad(
        outputs=[energy],  # [n_graphs, ]
        inputs=[positions, displacement],  # [n_nodes, 3]
        grad_outputs=grad_outputs,
        retain_graph=training,  # Make sure the graph is not destroyed during training
        create_graph=training,  # Create graph for second derivative
        allow_unused=True,
    )
    stress = torch.zeros_like(displacement)
    if compute_stress and virials is not None:
        cell = cell.view(-1, 3, 3)
        volume = torch.linalg.det(cell).abs().unsqueeze(-1)
        stress = virials / volume.view(-1, 1, 1)
        stress = torch.where(torch.abs(stress) < 1e10, stress, torch.zeros_like(stress))
    if forces is None:
        forces = torch.zeros_like(positions)
    if virials is None:
        virials = torch.zeros((1, 3, 3))

    return -1 * forces, -1 * virials, stress
def get_symmetric_displacement(
    positions: torch.Tensor,
    unit_shifts: torch.Tensor,
    cell: Optional[torch.Tensor],
    edge_index: torch.Tensor,
    num_graphs: int,
    batch: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if cell is None:
        cell = torch.zeros(
            num_graphs * 3,
            3,
            dtype=positions.dtype,
            device=positions.device,
        )
    sender = edge_index[0]
    displacement = torch.zeros(
        (num_graphs, 3, 3),
        dtype=positions.dtype,
        device=positions.device,
    )
    displacement.requires_grad_(True)
    symmetric_displacement = 0.5 * (
        displacement + displacement.transpose(-1, -2)
    )  # From https://github.com/mir-group/nequip
    positions = positions + torch.einsum(
        "be,bec->bc", positions, symmetric_displacement[batch]
    )
    cell = cell.view(-1, 3, 3)
    cell = cell + torch.matmul(cell, symmetric_displacement)
    shifts = torch.einsum(
        "be,bec->bc",
        unit_shifts,
        cell[batch[sender]],
    )
    return positions, shifts, displacement
@torch.jit.unused
def compute_hessians_vmap(
    forces: torch.Tensor,
    positions: torch.Tensor,
) -> torch.Tensor:
    forces_flatten = forces.view(-1)
    num_elements = forces_flatten.shape[0]

    def get_vjp(v):
        return torch.autograd.grad(
            -1 * forces_flatten,
            positions,
            v,
            retain_graph=True,
            create_graph=False,
            allow_unused=False,
        )

    I_N = torch.eye(num_elements).to(forces.device)
    try:
        chunk_size = 1 if num_elements < 64 else 16
        gradient = torch.vmap(get_vjp, in_dims=0, out_dims=0, chunk_size=chunk_size)(
            I_N
        )[0]
    except RuntimeError:
        gradient = compute_hessians_loop(forces, positions)
    if gradient is None:
        return torch.zeros((positions.shape[0], forces.shape[0], 3, 3))
    return gradient
@torch.jit.unused
def compute_hessians_loop(
    forces: torch.Tensor,
    positions: torch.Tensor,
) -> torch.Tensor:
    hessian = []
    for grad_elem in forces.view(-1):
        hess_row = torch.autograd.grad(
            outputs=[-1 * grad_elem],
            inputs=[positions],
            grad_outputs=torch.ones_like(grad_elem),
            retain_graph=True,
            create_graph=False,
            allow_unused=False,
        )[0]
        hess_row = hess_row.detach()  # this makes it very slow? but needs less memory
        if hess_row is None:
            hessian.append(torch.zeros_like(positions))
        else:
            hessian.append(hess_row)
    hessian = torch.stack(hessian)
    return hessian
def get_outputs(
    energy: torch.Tensor,
    positions: torch.Tensor,
    cell: torch.Tensor,
    displacement: Optional[torch.Tensor],
    vectors: Optional[torch.Tensor] = None,
    training: bool = False,
    compute_force: bool = True,
    compute_virials: bool = True,
    compute_stress: bool = True,
    compute_hessian: bool = False,
    compute_edge_forces: bool = False,
) -> Tuple[
    Optional[torch.Tensor],
    Optional[torch.Tensor],
    Optional[torch.Tensor],
    Optional[torch.Tensor],
    Optional[torch.Tensor],
]:
    if (compute_virials or compute_stress) and displacement is not None:
        forces, virials, stress = compute_forces_virials(
            energy=energy,
            positions=positions,
            displacement=displacement,
            cell=cell,
            compute_stress=compute_stress,
            training=(training or compute_hessian or compute_edge_forces),
        )
    elif compute_force:
        forces, virials, stress = (
            compute_forces(
                energy=energy,
                positions=positions,
                training=(training or compute_hessian or compute_edge_forces),
            ),
            None,
            None,
        )
    else:
        forces, virials, stress = (None, None, None)
    if compute_hessian:
        assert forces is not None, "Forces must be computed to get the hessian"
        hessian = compute_hessians_vmap(forces, positions)
    else:
        hessian = None
    if compute_edge_forces and vectors is not None:
        edge_forces = compute_forces(
            energy=energy,
            positions=vectors,
            training=(training or compute_hessian),
        )
        if edge_forces is not None:
            edge_forces = -1 * edge_forces  # Match LAMMPS sign convention
    else:
        edge_forces = None
    return forces, virials, stress, hessian, edge_forces
def get_atomic_virials_stresses(
    edge_forces: torch.Tensor,  # [n_edges, 3]
    edge_index: torch.Tensor,  # [2, n_edges]
    vectors: torch.Tensor,  # [n_edges, 3]
    num_atoms: int,
    batch: torch.Tensor,
    cell: torch.Tensor,  # [n_graphs, 3, 3]
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Compute atomic virials and optionally atomic stresses from edge forces and vectors.
    From pobo95 PR #528.

    Returns:
        Tuple of:
        - Atomic virials [num_atoms, 3, 3]
        - Atomic stresses [num_atoms, 3, 3] (None if not computed)
    """
    edge_virial = torch.einsum("zi,zj->zij", edge_forces, vectors)
    atom_virial_sender = scatter_sum(
        src=edge_virial, index=edge_index[0], dim=0, dim_size=num_atoms
    )
    atom_virial_receiver = scatter_sum(
        src=edge_virial, index=edge_index[1], dim=0, dim_size=num_atoms
    )
    atom_virial = (atom_virial_sender + atom_virial_receiver) / 2
    atom_virial = (atom_virial + atom_virial.transpose(-1, -2)) / 2

    atom_stress = None
    cell = cell.view(-1, 3, 3)
    volume = torch.linalg.det(cell).abs().unsqueeze(-1)
    atom_volume = volume[batch].view(-1, 1, 1)
    atom_stress = atom_virial / atom_volume
    atom_stress = torch.where(
        torch.abs(atom_stress) < 1e10, atom_stress, torch.zeros_like(atom_stress)
    )
    return -1 * atom_virial, atom_stress
def get_edge_vectors_and_lengths(
    positions: torch.Tensor,  # [n_nodes, 3]
    edge_index: torch.Tensor,  # [2, n_edges]
    shifts: torch.Tensor,  # [n_edges, 3]
    normalize: bool = False,
    eps: float = 1e-9,
) -> Tuple[torch.Tensor, torch.Tensor]:
    sender = edge_index[0]
    receiver = edge_index[1]
    vectors = positions[receiver] - positions[sender] + shifts  # [n_edges, 3]
    lengths = torch.linalg.norm(vectors, dim=-1, keepdim=True)  # [n_edges, 1]
    if normalize:
        vectors_normed = vectors / (lengths + eps)
        return vectors_normed, lengths

    return vectors, lengths
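
With periodic boundary conditions, `shifts` holds the lattice translation of each neighbor image, so an edge that wraps around the cell still yields the short image vector. A minimal sketch (not part of the committed file) for a hypothetical 5 Angstrom cubic cell:

# Illustrative only: one direct edge and one edge crossing the periodic boundary.
positions = torch.tensor([[0.0, 0.0, 0.0], [4.5, 0.0, 0.0]])
edge_index = torch.tensor([[0, 1], [1, 0]])
shifts = torch.tensor([[0.0, 0.0, 0.0], [5.0, 0.0, 0.0]])
vectors, lengths = get_edge_vectors_and_lengths(positions, edge_index, shifts)
# vectors[0] = [4.5, 0, 0] (direct), vectors[1] = [0.5, 0, 0] (wrapped image)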
def _check_non_zero(std):
    if np.any(std == 0):
        logging.warning(
            "Standard deviation of the scaling is zero; changing to no scaling"
        )
        std[std == 0] = 1
    return std
def extract_invariant(x: torch.Tensor, num_layers: int, num_features: int, l_max: int):
    out = []
    out.append(x[:, :num_features])
    for i in range(1, num_layers):
        out.append(
            x[
                :,
                i * (l_max + 1) ** 2 * num_features : (i * (l_max + 1) ** 2 + 1) * num_features,
            ]
        )
    return torch.cat(out, dim=-1)
def compute_mean_std_atomic_inter_energy(
    data_loader: torch.utils.data.DataLoader,
    atomic_energies: np.ndarray,
) -> Tuple[float, float]:
    atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies)

    avg_atom_inter_es_list = []
    head_list = []

    for batch in data_loader:
        node_e0 = atomic_energies_fn(batch.node_attrs)
        graph_e0s = scatter_sum(
            src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs
        )[torch.arange(batch.num_graphs), batch.head]
        graph_sizes = batch.ptr[1:] - batch.ptr[:-1]
        avg_atom_inter_es_list.append(
            (batch.energy - graph_e0s) / graph_sizes
        )  # {[n_graphs], }
        head_list.append(batch.head)

    avg_atom_inter_es = torch.cat(avg_atom_inter_es_list)  # [total_n_graphs]
    head = torch.cat(head_list, dim=0)  # [total_n_graphs]
    # mean = to_numpy(torch.mean(avg_atom_inter_es)).item()
    # std = to_numpy(torch.std(avg_atom_inter_es)).item()
    mean = to_numpy(scatter_mean(src=avg_atom_inter_es, index=head, dim=0).squeeze(-1))
    std = to_numpy(scatter_std(src=avg_atom_inter_es, index=head, dim=0).squeeze(-1))
    std = _check_non_zero(std)

    return mean, std
def _compute_mean_std_atomic_inter_energy(
    batch: Batch,
    atomic_energies_fn: AtomicEnergiesBlock,
) -> torch.Tensor:
    head = batch.head
    node_e0 = atomic_energies_fn(batch.node_attrs)
    graph_e0s = scatter_sum(
        src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs
    )[torch.arange(batch.num_graphs), head]
    graph_sizes = batch.ptr[1:] - batch.ptr[:-1]
    atom_energies = (batch.energy - graph_e0s) / graph_sizes
    return atom_energies
def compute_mean_rms_energy_forces(
    data_loader: torch.utils.data.DataLoader,
    atomic_energies: np.ndarray,
) -> Tuple[float, float]:
    atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies)

    atom_energy_list = []
    forces_list = []
    head_list = []
    head_batch = []

    for batch in data_loader:
        head = batch.head
        node_e0 = atomic_energies_fn(batch.node_attrs)
        graph_e0s = scatter_sum(
            src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs
        )[torch.arange(batch.num_graphs), head]
        graph_sizes = batch.ptr[1:] - batch.ptr[:-1]
        atom_energy_list.append(
            (batch.energy - graph_e0s) / graph_sizes
        )  # {[n_graphs], }
        forces_list.append(batch.forces)  # {[n_graphs*n_atoms,3], }
        head_list.append(head)
        head_batch.append(head[batch.batch])

    atom_energies = torch.cat(atom_energy_list, dim=0)  # [total_n_graphs]
    forces = torch.cat(forces_list, dim=0)  # {[total_n_graphs*n_atoms,3], }
    head = torch.cat(head_list, dim=0)  # [total_n_graphs]
    head_batch = torch.cat(head_batch, dim=0)  # [total_n_graphs]

    # mean = to_numpy(torch.mean(atom_energies)).item()
    # rms = to_numpy(torch.sqrt(torch.mean(torch.square(forces)))).item()
    mean = to_numpy(scatter_mean(src=atom_energies, index=head, dim=0).squeeze(-1))
    rms = to_numpy(
        torch.sqrt(
            scatter_mean(src=torch.square(forces), index=head_batch, dim=0).mean(-1)
        )
    )
    rms = _check_non_zero(rms)

    return mean, rms
def _compute_mean_rms_energy_forces(
    batch: Batch,
    atomic_energies_fn: AtomicEnergiesBlock,
) -> Tuple[torch.Tensor, torch.Tensor]:
    head = batch.head
    node_e0 = atomic_energies_fn(batch.node_attrs)
    graph_e0s = scatter_sum(
        src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs
    )[torch.arange(batch.num_graphs), head]
    graph_sizes = batch.ptr[1:] - batch.ptr[:-1]
    atom_energies = (batch.energy - graph_e0s) / graph_sizes  # {[n_graphs], }
    forces = batch.forces  # {[n_graphs*n_atoms,3], }

    return atom_energies, forces
def compute_avg_num_neighbors(data_loader: torch.utils.data.DataLoader) -> float:
    num_neighbors = []
    for batch in data_loader:
        _, receivers = batch.edge_index
        _, counts = torch.unique(receivers, return_counts=True)
        num_neighbors.append(counts)

    avg_num_neighbors = torch.mean(
        torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype())
    )
    return to_numpy(avg_num_neighbors).item()
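
The counting relies on `torch.unique(receivers, return_counts=True)`, which returns how many edges point at each receiving atom. A minimal sketch (not part of the committed file):

# Illustrative only: three atoms receiving 2, 1, and 3 edges -> average of 2.0 neighbors.
receivers = torch.tensor([0, 0, 1, 2, 2, 2])
_, counts = torch.unique(receivers, return_counts=True)  # tensor([2, 1, 3])
avg = counts.float().mean()                              # tensor(2.)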
def compute_statistics(
    data_loader: torch.utils.data.DataLoader,
    atomic_energies: np.ndarray,
) -> Tuple[float, float, float]:
    atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies)

    atom_energy_list = []
    forces_list = []
    num_neighbors = []
    head_list = []
    head_batch = []

    for batch in data_loader:
        head = batch.head
        node_e0 = atomic_energies_fn(batch.node_attrs)
        graph_e0s = scatter_sum(
            src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs
        )[torch.arange(batch.num_graphs), head]
        graph_sizes = batch.ptr[1:] - batch.ptr[:-1]
        atom_energy_list.append(
            (batch.energy - graph_e0s) / graph_sizes
        )  # {[n_graphs], }
        forces_list.append(batch.forces)  # {[n_graphs*n_atoms,3], }
        head_list.append(head)  # {[n_graphs], }
        head_batch.append(head[batch.batch])

        _, receivers = batch.edge_index
        _, counts = torch.unique(receivers, return_counts=True)
        num_neighbors.append(counts)

    atom_energies = torch.cat(atom_energy_list, dim=0)  # [total_n_graphs]
    forces = torch.cat(forces_list, dim=0)  # {[total_n_graphs*n_atoms,3], }
    head = torch.cat(head_list, dim=0)  # [total_n_graphs]
    head_batch = torch.cat(head_batch, dim=0)  # [total_n_graphs]

    # mean = to_numpy(torch.mean(atom_energies)).item()
    mean = to_numpy(scatter_mean(src=atom_energies, index=head, dim=0).squeeze(-1))
    rms = to_numpy(
        torch.sqrt(
            scatter_mean(src=torch.square(forces), index=head_batch, dim=0).mean(-1)
        )
    )

    avg_num_neighbors = torch.mean(
        torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype())
    )

    return to_numpy(avg_num_neighbors).item(), mean, rms
def compute_rms_dipoles(
    data_loader: torch.utils.data.DataLoader,
) -> float:
    dipoles_list = []
    for batch in data_loader:
        dipoles_list.append(batch.dipole)  # {[n_graphs,3], }

    dipoles = torch.cat(dipoles_list, dim=0)  # {[total_n_graphs,3], }
    rms = to_numpy(torch.sqrt(torch.mean(torch.square(dipoles)))).item()
    rms = _check_non_zero(rms)
    return rms
def compute_fixed_charge_dipole(
    charges: torch.Tensor,
    positions: torch.Tensor,
    batch: torch.Tensor,
    num_graphs: int,
) -> torch.Tensor:
    mu = positions * charges.unsqueeze(-1) / (1e-11 / c / e)  # [N_atoms,3]
    return scatter_sum(
        src=mu, index=batch.unsqueeze(-1), dim=0, dim_size=num_graphs
    )  # [N_graphs,3]
class InteractionKwargs(NamedTuple):
    lammps_class: Optional[torch.Tensor]
    lammps_natoms: Tuple[int, int] = (0, 0)


class GraphContext(NamedTuple):
    is_lammps: bool
    num_graphs: int
    num_atoms_arange: torch.Tensor
    displacement: Optional[torch.Tensor]
    positions: torch.Tensor
    vectors: torch.Tensor
    lengths: torch.Tensor
    cell: torch.Tensor
    node_heads: torch.Tensor
    interaction_kwargs: InteractionKwargs
def prepare_graph(
    data: Dict[str, torch.Tensor],
    compute_virials: bool = False,
    compute_stress: bool = False,
    compute_displacement: bool = False,
    lammps_mliap: bool = False,
) -> GraphContext:
    if torch.jit.is_scripting():
        lammps_mliap = False

    node_heads = (
        data["head"][data["batch"]]
        if "head" in data
        else torch.zeros_like(data["batch"])
    )

    if lammps_mliap:
        n_real, n_total = data["natoms"][0], data["natoms"][1]
        num_graphs = 2
        num_atoms_arange = torch.arange(n_real, device=data["node_attrs"].device)
        displacement = None
        positions = torch.zeros(
            (int(n_real), 3),
            dtype=data["vectors"].dtype,
            device=data["vectors"].device,
        )
        cell = torch.zeros(
            (num_graphs, 3, 3),
            dtype=data["vectors"].dtype,
            device=data["vectors"].device,
        )
        vectors = data["vectors"].requires_grad_(True)
        lengths = torch.linalg.vector_norm(vectors, dim=1, keepdim=True)
        ikw = InteractionKwargs(data["lammps_class"], (n_real, n_total))
    else:
        data["positions"].requires_grad_(True)
        positions = data["positions"]
        cell = data["cell"]
        num_atoms_arange = torch.arange(positions.shape[0], device=positions.device)
        num_graphs = int(data["ptr"].numel() - 1)
        displacement = torch.zeros(
            (num_graphs, 3, 3), dtype=positions.dtype, device=positions.device
        )
        if compute_virials or compute_stress or compute_displacement:
            p, s, displacement = get_symmetric_displacement(
                positions=positions,
                unit_shifts=data["unit_shifts"],
                cell=cell,
                edge_index=data["edge_index"],
                num_graphs=num_graphs,
                batch=data["batch"],
            )
            data["positions"], data["shifts"] = p, s
        vectors, lengths = get_edge_vectors_and_lengths(
            positions=data["positions"],
            edge_index=data["edge_index"],
            shifts=data["shifts"],
        )
        ikw = InteractionKwargs(None, (0, 0))

    return GraphContext(
        is_lammps=lammps_mliap,
        num_graphs=num_graphs,
        num_atoms_arange=num_atoms_arange,
        displacement=displacement,
        positions=positions,
        vectors=vectors,
        lengths=lengths,
        cell=cell,
        node_heads=node_heads,
        interaction_kwargs=ikw,
    )
modules/wrapper_ops.py  (new file, mode 0 → 100644)
"""
Wrapper class for o3.Linear that optionally uses cuet.Linear
"""
import
dataclasses
from
typing
import
List
,
Optional
import
torch
from
e3nn
import
o3
from
mace.modules.symmetric_contraction
import
SymmetricContraction
from
mace.tools.cg
import
O3_e3nn
try
:
import
cuequivariance
as
cue
import
cuequivariance_torch
as
cuet
CUET_AVAILABLE
=
True
except
ImportError
:
CUET_AVAILABLE
=
False
@dataclasses.dataclass
class CuEquivarianceConfig:
    """Configuration for cuequivariance acceleration"""

    enabled: bool = False
    layout: str = "mul_ir"  # One of: mul_ir, ir_mul
    layout_str: str = "mul_ir"
    group: str = "O3"
    optimize_all: bool = False  # Set to True to enable all optimizations
    optimize_linear: bool = False
    optimize_channelwise: bool = False
    optimize_symmetric: bool = False
    optimize_fctp: bool = False

    def __post_init__(self):
        if self.enabled and CUET_AVAILABLE:
            self.layout_str = self.layout
            self.layout = getattr(cue, self.layout)
            self.group = (
                O3_e3nn if self.group == "O3_e3nn" else getattr(cue, self.group)
            )
        if not CUET_AVAILABLE:
            self.enabled = False
class Linear:
    """Returns either a cuet.Linear or o3.Linear based on config"""

    def __new__(
        cls,
        irreps_in: o3.Irreps,
        irreps_out: o3.Irreps,
        shared_weights: bool = True,
        internal_weights: bool = True,
        cueq_config: Optional[CuEquivarianceConfig] = None,
    ):
        if (
            CUET_AVAILABLE
            and cueq_config is not None
            and cueq_config.enabled
            and (cueq_config.optimize_all or cueq_config.optimize_linear)
        ):
            return cuet.Linear(
                cue.Irreps(cueq_config.group, irreps_in),
                cue.Irreps(cueq_config.group, irreps_out),
                layout=cueq_config.layout,
                shared_weights=shared_weights,
                use_fallback=True,
            )

        return o3.Linear(
            irreps_in,
            irreps_out,
            shared_weights=shared_weights,
            internal_weights=internal_weights,
        )
class TensorProduct:
    """Wrapper around o3.TensorProduct/cuet.ChannelwiseTensorProduct"""

    def __new__(
        cls,
        irreps_in1: o3.Irreps,
        irreps_in2: o3.Irreps,
        irreps_out: o3.Irreps,
        instructions: Optional[List] = None,
        shared_weights: bool = False,
        internal_weights: bool = False,
        cueq_config: Optional[CuEquivarianceConfig] = None,
    ):
        if (
            CUET_AVAILABLE
            and cueq_config is not None
            and cueq_config.enabled
            and (cueq_config.optimize_all or cueq_config.optimize_channelwise)
        ):
            return cuet.ChannelWiseTensorProduct(
                cue.Irreps(cueq_config.group, irreps_in1),
                cue.Irreps(cueq_config.group, irreps_in2),
                cue.Irreps(cueq_config.group, irreps_out),
                layout=cueq_config.layout,
                shared_weights=shared_weights,
                internal_weights=internal_weights,
                dtype=torch.get_default_dtype(),
                math_dtype=torch.get_default_dtype(),
            )

        return o3.TensorProduct(
            irreps_in1,
            irreps_in2,
            irreps_out,
            instructions=instructions,
            shared_weights=shared_weights,
            internal_weights=internal_weights,
        )
class FullyConnectedTensorProduct:
    """Wrapper around o3.FullyConnectedTensorProduct/cuet.FullyConnectedTensorProduct"""

    def __new__(
        cls,
        irreps_in1: o3.Irreps,
        irreps_in2: o3.Irreps,
        irreps_out: o3.Irreps,
        shared_weights: bool = True,
        internal_weights: bool = True,
        cueq_config: Optional[CuEquivarianceConfig] = None,
    ):
        if (
            CUET_AVAILABLE
            and cueq_config is not None
            and cueq_config.enabled
            and (cueq_config.optimize_all or cueq_config.optimize_fctp)
        ):
            return cuet.FullyConnectedTensorProduct(
                cue.Irreps(cueq_config.group, irreps_in1),
                cue.Irreps(cueq_config.group, irreps_in2),
                cue.Irreps(cueq_config.group, irreps_out),
                layout=cueq_config.layout,
                shared_weights=shared_weights,
                internal_weights=internal_weights,
                use_fallback=True,
            )

        return o3.FullyConnectedTensorProduct(
            irreps_in1,
            irreps_in2,
            irreps_out,
            shared_weights=shared_weights,
            internal_weights=internal_weights,
        )
class SymmetricContractionWrapper:
    """Wrapper around SymmetricContraction/cuet.SymmetricContraction"""

    def __new__(
        cls,
        irreps_in: o3.Irreps,
        irreps_out: o3.Irreps,
        correlation: int,
        num_elements: Optional[int] = None,
        cueq_config: Optional[CuEquivarianceConfig] = None,
    ):
        if (
            CUET_AVAILABLE
            and cueq_config is not None
            and cueq_config.enabled
            and (cueq_config.optimize_all or cueq_config.optimize_symmetric)
        ):
            return cuet.SymmetricContraction(
                cue.Irreps(cueq_config.group, irreps_in),
                cue.Irreps(cueq_config.group, irreps_out),
                layout_in=cue.ir_mul,
                layout_out=cueq_config.layout,
                contraction_degree=correlation,
                num_elements=num_elements,
                original_mace=True,
                dtype=torch.get_default_dtype(),
                math_dtype=torch.get_default_dtype(),
            )

        return SymmetricContraction(
            irreps_in=irreps_in,
            irreps_out=irreps_out,
            correlation=correlation,
            num_elements=num_elements,
        )
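
A minimal usage sketch (not part of the committed file): with acceleration disabled, or when cuequivariance is not installed, the wrappers simply return the plain e3nn modules, so call sites look the same either way. Irreps and sizes are illustrative:

# Illustrative only.
cueq_config = CuEquivarianceConfig(enabled=False)   # falls back to e3nn
linear = Linear(
    o3.Irreps("16x0e + 16x1o"),
    o3.Irreps("32x0e"),
    cueq_config=cueq_config,
)                                                    # an o3.Linear instance in this case
features = torch.randn(5, o3.Irreps("16x0e + 16x1o").dim)  # dim = 16*1 + 16*3 = 64
out = linear(features)                               # [5, 32]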