Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
nivren
ICT-CSP
Commits
b75ed73c
Unverified
Commit
b75ed73c
authored
Aug 24, 2025
by
zcxzcx1
Committed by
GitHub
Aug 24, 2025
Browse files
Add files via upload
parent
56d3c363
Changes
53
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
3637 additions
and
0 deletions
+3637
-0
mace-bench/3rdparty/SevenNet/tests/lammps_tests/scripts/pbc_skel.lmp
...3rdparty/SevenNet/tests/lammps_tests/scripts/pbc_skel.lmp
+23
-0
mace-bench/3rdparty/SevenNet/tests/lammps_tests/scripts/skel.lmp
...nch/3rdparty/SevenNet/tests/lammps_tests/scripts/skel.lmp
+20
-0
mace-bench/3rdparty/SevenNet/tests/lammps_tests/test_lammps.py
...bench/3rdparty/SevenNet/tests/lammps_tests/test_lammps.py
+467
-0
mace-bench/3rdparty/SevenNet/tests/unit_tests/test_calculator.py
...nch/3rdparty/SevenNet/tests/unit_tests/test_calculator.py
+217
-0
mace-bench/3rdparty/SevenNet/tests/unit_tests/test_cli.py
mace-bench/3rdparty/SevenNet/tests/unit_tests/test_cli.py
+233
-0
mace-bench/3rdparty/SevenNet/tests/unit_tests/test_cueq.py
mace-bench/3rdparty/SevenNet/tests/unit_tests/test_cueq.py
+282
-0
mace-bench/3rdparty/SevenNet/tests/unit_tests/test_data.py
mace-bench/3rdparty/SevenNet/tests/unit_tests/test_data.py
+521
-0
mace-bench/3rdparty/SevenNet/tests/unit_tests/test_errors.py
mace-bench/3rdparty/SevenNet/tests/unit_tests/test_errors.py
+285
-0
mace-bench/3rdparty/SevenNet/tests/unit_tests/test_modal.py
mace-bench/3rdparty/SevenNet/tests/unit_tests/test_modal.py
+136
-0
mace-bench/3rdparty/SevenNet/tests/unit_tests/test_model.py
mace-bench/3rdparty/SevenNet/tests/unit_tests/test_model.py
+213
-0
mace-bench/3rdparty/SevenNet/tests/unit_tests/test_pretrained.py
...nch/3rdparty/SevenNet/tests/unit_tests/test_pretrained.py
+344
-0
mace-bench/3rdparty/SevenNet/tests/unit_tests/test_shift_scale.py
...ch/3rdparty/SevenNet/tests/unit_tests/test_shift_scale.py
+494
-0
mace-bench/3rdparty/SevenNet/tests/unit_tests/test_train.py
mace-bench/3rdparty/SevenNet/tests/unit_tests/test_train.py
+402
-0
No files found.
mace-bench/3rdparty/SevenNet/tests/lammps_tests/scripts/pbc_skel.lmp
0 → 100644
View file @
b75ed73c
# LAMMPS skeleton input for PBC tests. __UPPERCASE__ tokens are
# substituted by the Python test harness before the run.
units metal
boundary __BOUNDARY__
box tilt large
read_data __LMP_STCT__
#replicate __REPLICATE__
mass * 1.0 # do not matter since we don't run MD
pair_style __PAIR_STYLE__
pair_coeff * * __POTENTIALS__ __ELEMENT__
timestep 0.002
compute pa all pe/atom
thermo 1
fix 1 all nve
thermo_style custom step tpcpu pe ke vol pxx pyy pzz pxy pxz pyz press temp
# NOTE(review): dump file is hard-coded to 'force.dump' here, while
# skel.lmp uses the __FORCE_DUMP_PATH__ placeholder — confirm intended.
dump mydump all custom 1 force.dump id type element c_pa x y z fx fy fz
dump_modify mydump sort id element __ELEMENT__
# single static evaluation; no MD steps are taken
run 0
mace-bench/3rdparty/SevenNet/tests/lammps_tests/scripts/skel.lmp
0 → 100644
View file @
b75ed73c
# LAMMPS skeleton input used by _run_lammps; __UPPERCASE__ tokens are
# substituted by the Python test harness before the run.
units metal
boundary __BOUNDARY__
read_data __LMP_STCT__
mass * 1.0 # do not matter since we don't run MD
pair_style __PAIR_STYLE__
pair_coeff * * __POTENTIALS__ __ELEMENT__
timestep 0.002
compute pa all pe/atom
thermo 1
fix 1 all nve
thermo_style custom step tpcpu pe ke vol pxx pyy pzz pxy pxz pyz press temp
dump mydump all custom 1 __FORCE_DUMP_PATH__ id type element c_pa x y z fx fy fz
dump_modify mydump sort id element __ELEMENT__
# single static evaluation; no MD steps are taken
run 0
mace-bench/3rdparty/SevenNet/tests/lammps_tests/test_lammps.py
0 → 100644
View file @
b75ed73c
import
copy
import
logging
import
pathlib
import
subprocess
import
ase.calculators.lammps
import
ase.io.lammpsdata
import
numpy
as
np
import
pytest
import
torch
from
ase.build
import
bulk
,
surface
from
ase.calculators.singlepoint
import
SinglePointCalculator
import
sevenn
from
sevenn.calculator
import
SevenNetCalculator
from
sevenn.model_build
import
build_E3_equivariant_model
from
sevenn.nn.cue_helper
import
is_cue_available
from
sevenn.scripts.deploy
import
deploy
,
deploy_parallel
from
sevenn.util
import
chemical_species_preprocess
,
pretrained_name_to_path
logger = logging.getLogger('test_lammps')

# Cutoff radius (Angstrom) shared by the toy model config below.
cutoff = 4.0
# LAMMPS input skeleton; placeholder tokens are substituted per test.
lmp_script_path = str(
    (pathlib.Path(__file__).parent / 'scripts' / 'skel.lmp').resolve()
)
data_root = (pathlib.Path(__file__).parent.parent / 'data').resolve()
cp_0_path = str(data_root / 'checkpoints' / 'cp_0.pth')  # knows Hf, O
# Multi-fidelity pretrained checkpoint resolved from its registered name.
cp_mf_path = pretrained_name_to_path('7net-mf-0')
@pytest.fixture(scope='module')
def serial_potential_path(tmp_path_factory):
    """Deploy the cp_0 checkpoint as a serial LAMMPS potential; return its path."""
    out_dir = tmp_path_factory.mktemp('serial_potential')
    deployed = str(out_dir / 'deployed_serial.pt')
    deploy(cp_0_path, deployed)
    return deployed
@pytest.fixture(scope='module')
def parallel_potential_path(tmp_path_factory):
    """Deploy the cp_0 checkpoint as a parallel LAMMPS potential.

    Returns:
        The pair_coeff argument string: the layer count ('3', matching
        cp_0's num_convolution_layer) followed by the deployed prefix path.
    """
    # fixed misspelled temp-dir label 'paralllel_potential'
    tmp = tmp_path_factory.mktemp('parallel_potential')
    pot_path = str(tmp / 'deployed_parallel')
    deploy_parallel(cp_0_path, pot_path)
    return ' '.join(['3', pot_path])
@pytest.fixture(scope='module')
def serial_modal_potential_path(tmp_path_factory):
    """Deploy the multi-fidelity checkpoint (PBE modal) as a serial potential."""
    out_dir = tmp_path_factory.mktemp('serial_modal_potential')
    deployed = str(out_dir / 'deployed_serial.pt')
    deploy(cp_mf_path, deployed, 'PBE')
    return deployed
@pytest.fixture(scope='module')
def parallel_modal_potential_path(tmp_path_factory):
    """Deploy the multi-fidelity checkpoint (PBE modal) as a parallel potential.

    Returns:
        The pair_coeff argument string: the layer count ('5' for 7net-mf-0)
        followed by the deployed prefix path.
    """
    # fixed misspelled temp-dir label 'paralllel_modal_potential'
    tmp = tmp_path_factory.mktemp('parallel_modal_potential')
    pot_path = str(tmp / 'deployed_parallel')
    deploy_parallel(cp_mf_path, pot_path, 'PBE')
    return ' '.join(['5', pot_path])
@pytest.fixture(scope='module')
def ref_calculator():
    """Reference ASE calculator built from the cp_0 checkpoint."""
    calc = SevenNetCalculator(cp_0_path)
    return calc
@pytest.fixture(scope='module')
def ref_modal_calculator():
    """Reference calculator for the multi-fidelity checkpoint, PBE modal."""
    calc = SevenNetCalculator(cp_mf_path, modal='PBE')
    return calc
def get_model_config():
    """Return a small SevenNet model config covering the Hf/O species."""
    config = {
        'cutoff': cutoff,
        'channel': 8,
        'lmax': 2,
        'is_parity': True,
        'num_convolution_layer': 3,
        'self_connection_type': 'linear',  # not NequIp
        'interaction_type': 'nequip',
        'radial_basis': {
            'radial_basis_name': 'bessel',
        },
        'cutoff_function': {'cutoff_function_name': 'poly_cut'},
        'weight_nn_hidden_neurons': [64, 64],
        'act_radial': 'silu',
        'act_scalar': {'e': 'silu', 'o': 'tanh'},
        'act_gate': {'e': 'silu', 'o': 'tanh'},
        'conv_denominator': 30.0,
        'train_denominator': False,
        'shift': -10.0,
        'scale': 10.0,
        'train_shift_scale': False,
        'irreps_manual': False,
        'lmax_edge': -1,
        'lmax_node': -1,
        'readout_as_fcn': False,
        'use_bias_in_linear': False,
        '_normalize_sph': True,
    }
    # Adds the chemical-species bookkeeping keys for Hf and O.
    config.update(chemical_species_preprocess(['Hf', 'O']))
    return config
def get_model(config_overwrite=None, use_cueq=False, cueq_config=None):
    """Build a non-parallel E3-equivariant model from the toy config.

    Overrides are merged first; an explicit (truthy) cueq_config takes
    precedence over the use_cueq flag.
    """
    conf = get_model_config()
    if config_overwrite is not None:
        conf.update(config_overwrite)
    if not cueq_config:
        cueq_config = {'cuequivariance_config': {'use': use_cueq}}
    conf.update(cueq_config)
    built = build_E3_equivariant_model(conf, parallel=False)
    assert not isinstance(built, list)
    return built
def hfo2_bulk(replicate=(2, 2, 2), a=4.0):
    """Rattled, replicated orthorhombic HfO rocksalt cell for testing."""
    cell = bulk('HfO', 'rocksalt', a, orthorhombic=True) * replicate
    cell.rattle(stdev=0.10)
    return cell
def hf_surface(replicate=(3, 3, 1), layers=4, vacuum=0.5):
    """Rattled Hf slab built from an Al(100) surface template."""
    slab = surface('Al', (1, 0, 0), layers=layers, vacuum=vacuum)
    slab.set_atomic_numbers([72] * len(slab))  # relabel Al -> Hf (Z=72)
    slab = slab * replicate
    slab.rattle(stdev=0.10)
    return slab
def get_system(system_name, **kwargs):
    """Build a test system by name.

    Args:
        system_name: 'bulk' (HfO rocksalt) or 'surface' (Hf slab).
        **kwargs: forwarded to the underlying builder.

    Raises:
        ValueError: for an unknown system name (now with a message
            naming the offending value instead of a bare ValueError()).
    """
    if system_name == 'bulk':
        return hfo2_bulk(**kwargs)
    elif system_name == 'surface':
        return hf_surface(**kwargs)
    else:
        raise ValueError(f'unknown system_name: {system_name!r}')
def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6):
    """Assert two Atoms objects' results agree: cell, energy, forces, stress.

    Forces and stress are compared with 10x looser tolerances than energy.
    """
    def acl(a, b, rtol=rtol, atol=atol):
        # np.allclose bound to this call's default tolerances
        return np.allclose(a, b, rtol=rtol, atol=atol)

    assert len(atoms1) == len(atoms2)
    assert acl(atoms1.get_cell(), atoms2.get_cell())
    assert acl(atoms1.get_potential_energy(), atoms2.get_potential_energy())
    assert acl(atoms1.get_forces(), atoms2.get_forces(), rtol * 10, atol * 10)
    assert acl(
        atoms1.get_stress(voigt=False),
        atoms2.get_stress(voigt=False),
        rtol * 10,
        atol * 10,
    )
    # assert acl(atoms1.get_potential_energies(), atoms2.get_potential_energies())
def _lammps_results_to_atoms(lammps_log, force_dump):
    """Parse a LAMMPS log + force dump into an ASE Atoms with results.

    Args:
        lammps_log: path to the LAMMPS log file (thermo output).
        force_dump: path to the custom dump with per-atom forces/energies.

    Returns:
        The first dumped frame, with energy/free_energy/energies/stress
        filled into its attached calculator results.
    """
    with open(lammps_log, 'r') as f:
        lines = f.readlines()

    lmp_log = None
    for i, line in enumerate(lines):
        if not line.startswith('Per MPI rank memory allocation'):
            continue
        # Header row (names) followed by the value row of the thermo table.
        # float() replaces the original eval(): all thermo columns here are
        # numeric, and eval() on file content is unsafe and slow.
        lmp_log = {
            k: float(v)
            for k, v in zip(lines[i + 1].split(), lines[i + 2].split())
        }
        break
    assert lmp_log is not None and 'PotEng' in lmp_log

    latoms_list = ase.io.read(force_dump, format='lammps-dump-text', index=':')
    assert isinstance(latoms_list, list)
    latoms = latoms_list[0]
    assert latoms.calc is not None
    latoms.calc.results['energy'] = lmp_log['PotEng']
    latoms.calc.results['free_energy'] = lmp_log['PotEng']
    latoms.info = {
        'data_from': 'lammps',
        'lmp_log': lmp_log,
        'lmp_dump': force_dump,
    }
    # atomic energy read
    latoms.calc.results['energies'] = latoms.arrays['c_pa'][:, 0]
    stress = np.array(
        [
            [lmp_log['Pxx'], lmp_log['Pxy'], lmp_log['Pxz']],
            [lmp_log['Pxy'], lmp_log['Pyy'], lmp_log['Pyz']],
            [lmp_log['Pxz'], lmp_log['Pyz'], lmp_log['Pzz']],
        ]
    )
    stress = -1 * stress / 1602.1766208 / 1000  # convert bars to eV/A^3
    latoms.calc.results['stress'] = stress
    return latoms
def _run_lammps(atoms, pair_style, potential, wd, command, test_name):
    """Write a LAMMPS input from the skeleton, run it, and return results.

    The returned Atoms carries a SinglePointCalculator whose forces, stress
    and cell are rotated back from LAMMPS' internal frame to ASE's frame.
    """
    wd = wd.resolve()
    pbc = atoms.get_pbc()
    # LAMMPS boundary string: 'p' periodic, 'f' fixed, per axis.
    pbc_str = ' '.join(['p' if x else 'f' for x in pbc])
    chem = list(set(atoms.get_chemical_symbols()))
    # Way to ase handle lammps structure
    prism = ase.calculators.lammps.coordinatetransform.Prism(
        atoms.get_cell(), pbc=pbc
    )
    lmp_stct = wd / 'lammps_structure'
    ase.io.lammpsdata.write_lammps_data(
        lmp_stct, atoms, prismobj=prism, specorder=chem
    )
    with open(lmp_script_path, 'r') as f:
        cont = f.read()
    lammps_log = str(wd / 'log.lammps')
    force_dump = str(wd / 'force.dump')
    # Substitute every placeholder token in the skeleton script.
    var_dct = {}
    var_dct['__ELEMENT__'] = ' '.join(chem)
    var_dct['__LMP_STCT__'] = str(lmp_stct.resolve())
    var_dct['__PAIR_STYLE__'] = pair_style
    var_dct['__POTENTIALS__'] = potential
    var_dct['__BOUNDARY__'] = pbc_str
    var_dct['__FORCE_DUMP_PATH__'] = force_dump
    for key, val in var_dct.items():
        cont = cont.replace(key, val)
    input_script_path = str(wd / 'in.lmp')
    with open(input_script_path, 'w') as f:
        f.write(cont)
    command = f' {command} -in {input_script_path} -log {lammps_log} '
    subprocess_routine(command.split(), test_name)

    lmp_atoms = _lammps_results_to_atoms(lammps_log, force_dump)
    assert lmp_atoms.calc is not None
    # Rotate forces / stress / cell from the LAMMPS prism frame back to ASE.
    rot_mat = prism.rot_mat
    results = copy.deepcopy(lmp_atoms.calc.results)
    r_force = np.dot(results['forces'], rot_mat.T)
    results['forces'] = r_force
    if 'stress' in results:
        # see ase.calculators.lammpsrun.py
        stress_tensor = results['stress']
        stress_atoms = np.dot(np.dot(rot_mat, stress_tensor), rot_mat.T)
        results['stress'] = stress_atoms
    r_cell = lmp_atoms.get_cell() @ rot_mat.T
    lmp_atoms.set_cell(r_cell, scale_atoms=True)
    lmp_atoms = SinglePointCalculator(lmp_atoms, **results).get_atoms()
    return lmp_atoms
def serial_lammps_run(atoms, potential, wd, test_name, lammps_cmd):
    """Run a serial LAMMPS calculation with the e3gnn pair style."""
    return _run_lammps(atoms, 'e3gnn', potential, wd, lammps_cmd, test_name)
def parallel_lammps_run(
    atoms, potential, wd, test_name, ncores, lammps_cmd, mpirun_cmd
):
    """Run LAMMPS under mpirun with the e3gnn/parallel pair style."""
    full_cmd = f' {mpirun_cmd} -np {ncores} {lammps_cmd} '
    return _run_lammps(
        atoms, 'e3gnn/parallel', potential, wd, full_cmd, test_name
    )
def subprocess_routine(cmd, name, timeout=30):
    """Run *cmd*, logging its output; raise on nonzero exit.

    Args:
        cmd: argument list passed to subprocess.run (no shell).
        name: human-readable label used in log and error messages.
        timeout: seconds before subprocess.TimeoutExpired is raised
            (previously hard-coded to 30; kept as the default).

    Raises:
        RuntimeError: if the process exits with a nonzero return code.
    """
    # Bind the module's logger by name so the function is self-contained.
    logger = logging.getLogger('test_lammps')
    res = subprocess.run(cmd, capture_output=True, timeout=timeout)
    if res.returncode != 0:
        logger.error(f'Subprocess {name} failed return code: {res.returncode}')
        logger.error(res.stderr.decode('utf-8'))
        raise RuntimeError(f'{name} failed')
    logger.info(f'stdout of {name}:')
    logger.info(res.stdout.decode('utf-8'))
@pytest.mark.parametrize(
    'system',
    ['bulk', 'surface'],
)
def test_serial(system, serial_potential_path, ref_calculator, lammps_cmd, tmp_path):
    """Serial LAMMPS results must match the SevenNetCalculator reference."""
    atoms = get_system(system)
    atoms_lammps = serial_lammps_run(
        atoms=atoms,
        potential=serial_potential_path,
        wd=tmp_path,
        test_name='serial lmp test',
        lammps_cmd=lammps_cmd,
    )
    atoms.calc = ref_calculator
    assert_atoms(atoms, atoms_lammps)
@pytest.mark.parametrize(
    'system,ncores',
    [
        ('bulk', 1),
        ('bulk', 2),
        ('bulk', 4),
        ('surface', 1),
        ('surface', 2),
        ('surface', 3),
        ('surface', 4),
    ],
)
def test_parallel(
    system,
    ncores,
    parallel_potential_path,
    ref_calculator,
    lammps_cmd,
    mpirun_cmd,
    tmp_path,
):
    """Parallel LAMMPS results must match the reference at several core counts."""
    # Larger replications so domain decomposition actually splits the system.
    if system == 'bulk':
        rep = (6, 6, 3)
    elif system == 'surface':
        rep = (4, 4, 1)
    else:
        assert False
    atoms = get_system(system, replicate=rep)
    atoms_lammps = parallel_lammps_run(
        atoms=atoms,
        potential=parallel_potential_path,
        wd=tmp_path,
        test_name='parallel lmp test',
        lammps_cmd=lammps_cmd,
        mpirun_cmd=mpirun_cmd,
        ncores=ncores,
    )
    atoms.calc = ref_calculator
    assert_atoms(atoms, atoms_lammps)
@pytest.mark.parametrize(
    'system',
    ['bulk', 'surface'],
)
def test_modal_serial(
    system, serial_modal_potential_path, ref_modal_calculator, lammps_cmd, tmp_path
):
    """Serial LAMMPS with the modal (PBE) potential must match its reference."""
    atoms = get_system(system)
    atoms_lammps = serial_lammps_run(
        atoms=atoms,
        potential=serial_modal_potential_path,
        wd=tmp_path,
        test_name='serial lmp test',
        lammps_cmd=lammps_cmd,
    )
    atoms.calc = ref_modal_calculator
    assert_atoms(atoms, atoms_lammps)
@pytest.mark.parametrize(
    'system,ncores',
    [
        ('bulk', 2),
        ('surface', 2),
    ],
)
def test_modal_parallel(
    system,
    ncores,
    parallel_modal_potential_path,
    ref_modal_calculator,
    lammps_cmd,
    mpirun_cmd,
    tmp_path,
):
    """Parallel LAMMPS with the modal (PBE) potential must match its reference."""
    # Larger replications so domain decomposition actually splits the system.
    if system == 'bulk':
        rep = (6, 6, 3)
    elif system == 'surface':
        rep = (4, 4, 1)
    else:
        assert False
    atoms = get_system(system, replicate=rep)
    atoms_lammps = parallel_lammps_run(
        atoms=atoms,
        potential=parallel_modal_potential_path,
        wd=tmp_path,
        test_name='parallel lmp test',
        lammps_cmd=lammps_cmd,
        mpirun_cmd=mpirun_cmd,
        ncores=ncores,
    )
    atoms.calc = ref_modal_calculator
    assert_atoms(atoms, atoms_lammps)
@pytest.mark.filterwarnings('ignore:.*is not found from.*')
@pytest.mark.skipif(not is_cue_available(), reason='cueq not available')
def test_cueq_serial(lammps_cmd, tmp_path):
    """
    TODO: Use already saved cueq enabled checkpoint after cueq becomes stable
    """
    cueq = True
    model = get_model(use_cueq=cueq)
    ref_calc = SevenNetCalculator(model, file_type='model_instance')
    atoms = get_system('bulk')
    cfg = get_model_config()
    cfg.update(
        {'cuequivariance_config': {'use': cueq}, 'version': sevenn.__version__}
    )
    # Save a checkpoint of the cueq-built model, then deploy for serial LAMMPS.
    cp_path = str(tmp_path / 'cp.pth')
    torch.save(
        {'model_state_dict': model.state_dict(), 'config': cfg},
        cp_path,
    )
    pot_path = str(tmp_path / 'deployed_from_cueq_serial.pt')
    deploy(cp_path, pot_path)
    atoms_lammps = serial_lammps_run(
        atoms=atoms,
        potential=pot_path,
        wd=tmp_path,
        test_name='cueq checkpoint serial lmp run test',
        lammps_cmd=lammps_cmd,
    )
    atoms.calc = ref_calc
    assert_atoms(atoms, atoms_lammps)
@pytest.mark.filterwarnings('ignore:.*is not found from.*')
@pytest.mark.skipif(not is_cue_available(), reason='cueq not available')
def test_cueq_parallel(lammps_cmd, mpirun_cmd, tmp_path):
    """
    TODO: Use already saved cueq enabled checkpoint after cueq becomes stable
    """
    cueq = True
    model = get_model(use_cueq=cueq)
    ref_calc = SevenNetCalculator(model, file_type='model_instance')
    atoms = get_system('surface', replicate=(4, 4, 1))
    cfg = get_model_config()
    cfg.update(
        {'cuequivariance_config': {'use': cueq}, 'version': sevenn.__version__}
    )
    # Save a checkpoint of the cueq-built model, then deploy for parallel LAMMPS.
    cp_path = str(tmp_path / 'cp.pth')
    torch.save(
        {'model_state_dict': model.state_dict(), 'config': cfg},
        cp_path,
    )
    pot_path = str(tmp_path / 'deployed_from_cueq_parallel')
    deploy_parallel(cp_path, pot_path)
    atoms_lammps = parallel_lammps_run(
        atoms=atoms,
        # pair_coeff expects "<num_layers> <prefix>" for the parallel style
        potential=' '.join([str(cfg['num_convolution_layer']), pot_path]),
        wd=tmp_path,
        test_name='cueq checkpoint parallel lmp run test',
        lammps_cmd=lammps_cmd,
        mpirun_cmd=mpirun_cmd,
        ncores=2,
    )
    atoms.calc = ref_calc
    assert_atoms(atoms, atoms_lammps)
mace-bench/3rdparty/SevenNet/tests/unit_tests/test_calculator.py
0 → 100644
View file @
b75ed73c
import
copy
import
numpy
as
np
import
pytest
from
ase.build
import
bulk
,
molecule
from
sevenn.calculator
import
D3Calculator
,
SevenNetCalculator
from
sevenn.nn.cue_helper
import
is_cue_available
from
sevenn.scripts.deploy
import
deploy
from
sevenn.util
import
(
model_from_checkpoint
,
model_from_checkpoint_with_backend
,
pretrained_name_to_path
,
)
@pytest.fixture
def atoms_pbc():
    """Periodic NaCl rocksalt cell with a deliberately distorted cell/positions."""
    crystal = bulk('NaCl', 'rocksalt', a=5.63)
    crystal.set_cell(
        [[1.0, 2.815, 2.815], [2.815, 0.0, 2.815], [2.815, 2.815, 0.0]]
    )
    crystal.set_positions([[0.0, 0.0, 0.0], [2.815, 0.0, 0.0]])
    return crystal
@pytest.fixture
def atoms_mol():
    """Non-periodic water molecule with fixed, slightly perturbed positions."""
    water = molecule('H2O')
    water.set_positions(
        [[0.0, 0.2, 0.12], [0.0, 0.76, -0.48], [0.0, -0.76, -0.48]]
    )
    return water
@pytest.fixture(scope='module')
def sevennet_0_cal():
    """Module-scoped SevenNet-0 (11 July 2024) calculator."""
    calc = SevenNetCalculator('7net-0_11July2024')
    return calc
@pytest.fixture(scope='module')
def sevennet_0_cueq_cal():
    """SevenNet-0 calculator with the model converted to the cueq backend."""
    checkpoint = pretrained_name_to_path('7net-0_11July2024')
    model, _ = model_from_checkpoint_with_backend(checkpoint, 'cueq')
    return SevenNetCalculator(model)
@pytest.fixture(scope='module')
def d3_cal():
    """D3 dispersion calculator; skips dependent tests when unavailable."""
    try:
        return D3Calculator()
    except NotImplementedError as e:
        # D3Calculator raises NotImplementedError when its backend is missing.
        pytest.skip(f'{e}')
def test_sevennet_0_cal_pbc(atoms_pbc, sevennet_0_cal):
    """Pin SevenNet-0 outputs (energy/forces/stress) on the periodic NaCl cell."""
    # Regression reference values for this exact structure and checkpoint.
    atoms1_ref = {
        'energy': -3.779199,
        'energies': [-1.8493923, -1.9298072],
        'force': [
            [12.666697, 0.04726403, 0.04775861],
            [-12.666697, -0.04726403, -0.04775861],
        ],
        'stress': [
            [
                -0.6439122,
                -0.03643947,
                -0.03643981,
                0.00599139,
                0.04544507,
                0.04543639,
            ]
        ],
    }
    atoms_pbc.calc = sevennet_0_cal
    assert np.allclose(atoms_pbc.get_potential_energy(), atoms1_ref['energy'])
    assert np.allclose(
        atoms_pbc.get_potential_energy(force_consistent=True), atoms1_ref['energy']
    )
    assert np.allclose(atoms_pbc.get_forces(), atoms1_ref['force'])
    assert np.allclose(atoms_pbc.get_stress(), atoms1_ref['stress'])
    assert np.allclose(atoms_pbc.get_potential_energies(), atoms1_ref['energies'])
def test_sevennet_0_cal_mol(atoms_mol, sevennet_0_cal):
    """Pin SevenNet-0 outputs on the non-periodic water molecule."""
    # Regression reference values; no stress for a non-periodic system.
    atoms2_ref = {
        'energy': -12.782808303833008,
        'energies': [-6.2493525, -3.141562, -3.3918958],
        'force': [
            [0.0, -1.3619621e01, 7.5937047e00],
            [0.0, 9.3918495e00, -1.0172190e01],
            [0.0, 4.2277718e00, 2.5784855e00],
        ],
    }
    atoms_mol.calc = sevennet_0_cal
    assert np.allclose(atoms_mol.get_potential_energy(), atoms2_ref['energy'])
    assert np.allclose(
        atoms_mol.get_potential_energy(force_consistent=True), atoms2_ref['energy']
    )
    assert np.allclose(atoms_mol.get_forces(), atoms2_ref['force'])
    assert np.allclose(atoms_mol.get_potential_energies(), atoms2_ref['energies'])
def test_sevennet_0_cal_deployed_consistency(tmp_path, atoms_pbc):
    """A deployed TorchScript potential must match the checkpoint calculator."""
    fname = str(tmp_path / '7net_0.pt')
    deploy(pretrained_name_to_path('7net-0_11July2024'), fname)
    calc_script = SevenNetCalculator(fname, file_type='torchscript')
    calc_cp = SevenNetCalculator(pretrained_name_to_path('7net-0_11July2024'))

    atoms_pbc.calc = calc_cp
    atoms_pbc.get_potential_energy()
    res_cp = copy.copy(atoms_pbc.calc.results)
    atoms_pbc.calc = calc_script
    atoms_pbc.get_potential_energy()
    res_script = copy.copy(atoms_pbc.calc.results)
    # Every result key (energy, forces, stress, ...) must agree.
    for k in res_cp:
        assert np.allclose(res_cp[k], res_script[k])
def test_sevennet_0_cal_as_instance_consistency(atoms_pbc):
    """A calculator built from a live model instance must match the checkpoint one."""
    model, _ = model_from_checkpoint(pretrained_name_to_path('7net-0_11July2024'))
    calc_cp = SevenNetCalculator(pretrained_name_to_path('7net-0_11July2024'))
    calc_instance = SevenNetCalculator(model, file_type='model_instance')

    atoms_pbc.calc = calc_cp
    atoms_pbc.get_potential_energy()
    res_cp = copy.copy(atoms_pbc.calc.results)
    atoms_pbc.calc = calc_instance
    atoms_pbc.get_potential_energy()
    # renamed from 'res_script' (copied from the deployed-consistency test):
    # this dict holds the model-instance calculator's results.
    res_instance = copy.copy(atoms_pbc.calc.results)
    # Every result key (energy, forces, stress, ...) must agree.
    for k in res_cp:
        assert np.allclose(res_cp[k], res_instance[k])
@pytest.mark.skipif(not is_cue_available(), reason='cueq not available')
def test_sevennet_0_cal_cueq(atoms_pbc, sevennet_0_cueq_cal):
    """The cueq-backend calculator must reproduce the e3nn reference values."""
    # Same regression references as test_sevennet_0_cal_pbc.
    atoms1_ref = {
        'energy': -3.779199,
        'energies': [-1.8493923, -1.9298072],
        'force': [
            [12.666697, 0.04726403, 0.04775861],
            [-12.666697, -0.04726403, -0.04775861],
        ],
        'stress': [
            [
                -0.6439122,
                -0.03643947,
                -0.03643981,
                0.00599139,
                0.04544507,
                0.04543639,
            ]
        ],
    }
    atoms_pbc.calc = sevennet_0_cueq_cal
    assert np.allclose(atoms_pbc.get_potential_energy(), atoms1_ref['energy'])
    assert np.allclose(
        atoms_pbc.get_potential_energy(force_consistent=True), atoms1_ref['energy']
    )
    assert np.allclose(atoms_pbc.get_forces(), atoms1_ref['force'])
    assert np.allclose(atoms_pbc.get_stress(), atoms1_ref['stress'])
    assert np.allclose(atoms_pbc.get_potential_energies(), atoms1_ref['energies'])
def test_d3_cal_pbc(atoms_pbc, d3_cal):
    """Pin the D3 dispersion correction on the periodic NaCl cell."""
    atoms1_ref = {
        'energy': -0.531393751583389,
        'force': [
            [-0.00570205, 0.00107457, 0.00107459],
            [0.00570205, -0.00107457, -0.00107459],
        ],
        'stress': [
            [
                1.52403705e-02,
                1.50417333e-02,
                1.50417321e-02,
                -3.22684163e-05,
                -5.05532863e-05,
                -5.05586994e-05,
            ]
        ],
    }
    atoms_pbc.calc = d3_cal
    assert np.allclose(atoms_pbc.get_potential_energy(), atoms1_ref['energy'])
    assert np.allclose(
        atoms_pbc.get_potential_energy(force_consistent=True), atoms1_ref['energy']
    )
    assert np.allclose(atoms_pbc.get_forces(), atoms1_ref['force'])
    assert np.allclose(atoms_pbc.get_stress(), atoms1_ref['stress'])
def test_d3_cal_mol(atoms_mol, d3_cal):
    """Pin the D3 dispersion correction on the water molecule (no stress)."""
    atoms2_ref = {
        'energy': -0.009889134535170716,
        'force': [
            [0.0, 2.04263840e-03, 1.27477674e-03],
            [0.0, -9.90038901e-05, 1.18046682e-06],
            [0.0, -1.94363451e-03, -1.27595721e-03],
        ],
    }
    atoms_mol.calc = d3_cal
    assert np.allclose(atoms_mol.get_potential_energy(), atoms2_ref['energy'])
    assert np.allclose(
        atoms_mol.get_potential_energy(force_consistent=True), atoms2_ref['energy']
    )
    assert np.allclose(atoms_mol.get_forces(), atoms2_ref['force'])
mace-bench/3rdparty/SevenNet/tests/unit_tests/test_cli.py
0 → 100644
View file @
b75ed73c
import
csv
import
os
import
pathlib
from
unittest
import
mock
import
ase.io
import
numpy
as
np
import
pytest
import
yaml
from
ase.build
import
bulk
from
sevenn.calculator
import
SevenNetCalculator
from
sevenn.logger
import
Logger
from
sevenn.main.sevenn
import
main
as
sevenn_main
from
sevenn.main.sevenn_get_model
import
main
as
get_model_main
from
sevenn.main.sevenn_graph_build
import
main
as
graph_build_main
from
sevenn.main.sevenn_inference
import
main
as
inference_main
from
sevenn.util
import
pretrained_name_to_path
# Path to sevenn's CLI entry-point modules (used to fake sys.argv[0]).
# NOTE(review): the name 'main' is easy to confuse with the imported
# `main as ..._main` entry points — it is a directory path, not a callable.
main = os.path.abspath(f'{os.path.dirname(__file__)}/../../sevenn/main/')
preset = os.path.abspath(f'{os.path.dirname(__file__)}/../../sevenn/presets/')
file_path = pathlib.Path(__file__).parent.resolve()
data_root = (pathlib.Path(__file__).parent.parent / 'data').resolve()
hfo2_path = str(data_root / 'systems' / 'hfo2.extxyz')
hfo2_7net_0_inference_path = data_root / 'inferences' / 'snet0_on_hfo2'
cp_0_path = str(data_root / 'checkpoints' / 'cp_0.pth')

Logger()  # init
@pytest.fixture
def atoms_hfo():
    """Periodic HfO rocksalt cell with a deliberately distorted cell/positions."""
    crystal = bulk('HfO', 'rocksalt', a=5.63)
    crystal.set_cell(
        [[1.0, 2.815, 2.815], [2.815, 0.0, 2.815], [2.815, 2.815, 0.0]]
    )
    crystal.set_positions([[0.0, 0.0, 0.0], [2.815, 0.0, 0.0]])
    return crystal
@pytest.fixture(scope='module')
def sevennet_0_cal():
    """Module-scoped SevenNet-0 (11 July 2024) calculator."""
    calc = SevenNetCalculator('7net-0_11July2024')
    return calc
def test_get_model_serial(tmp_path, capsys):
    """`sevenn_get_model` CLI writes a serial deployed .pt file."""
    output_file = tmp_path / 'mypot.pt'
    cp = pretrained_name_to_path('7net-0')
    cli_args = ['-o', str(output_file), cp]
    with mock.patch('sys.argv', [f'{main}/sevenn_get_model.py'] + cli_args):
        get_model_main()
    _ = capsys.readouterr()  # not used
    assert output_file.is_file(), '.pt file is not written'
def test_get_model_parallel(tmp_path, capsys):
    """`sevenn_get_model -p` writes one deployed file per interaction layer."""
    output_dir = tmp_path / 'my_parallel'
    cp = pretrained_name_to_path('7net-0')
    expected_file_cnt = 5  # 5 interaction layers
    cli_args = ['-o', str(output_dir), '-p', cp]
    with mock.patch('sys.argv', [f'{main}/sevenn_get_model.py'] + cli_args):
        # with pytest.raises(SystemExit):
        get_model_main()
    _ = capsys.readouterr()  # not used
    assert output_dir.is_dir(), 'parallel model directory not exist'
    for i in range(expected_file_cnt):
        assert (output_dir / f'deployed_parallel_{i}.pt').is_file()
@pytest.mark.parametrize('source', [(hfo2_path)])
def test_graph_build(source, tmp_path):
    """`sevenn_graph_build` writes the graph .pt and its .yaml metadata."""
    output_dir = tmp_path / 'sevenn_data'
    output_f = output_dir / 'my_graph.pt'
    output_yml = output_dir / 'my_graph.yaml'
    # '4.0' is the cutoff (positional CLI argument).
    cli_args = ['-o', str(tmp_path), '-f', 'my_graph.pt', source, '4.0']
    with mock.patch('sys.argv', [f'{main}/sevenn_graph_build.py'] + cli_args):
        graph_build_main()
    assert output_dir.is_dir()
    assert output_f.is_file()
    assert output_yml.is_file()
@pytest.mark.parametrize(
    'batch,device,save_graph',
    [
        (1, 'cpu', False),
        (2, 'cpu', False),
        (1, 'cpu', True),
    ],
)
def test_inference(batch, device, save_graph, tmp_path):
    """`sevenn_inference` outputs match the stored reference errors."""
    checkpoint = '7net-0'
    target = hfo2_path
    ref_path = hfo2_7net_0_inference_path
    output_dir = tmp_path / 'inference_results'
    files = ['info.csv', 'per_graph.csv', 'per_atom.csv', 'errors.txt']
    cli_args = [
        '--output',
        str(output_dir),
        '--device',
        device,
        '--batch',
        str(batch),
        checkpoint,
        target,
    ]
    if save_graph:
        cli_args.append('--save_graph')
    with mock.patch('sys.argv', [f'{main}/sevenn_inference.py'] + cli_args):
        inference_main()
    assert output_dir.is_dir()
    for f in files:
        assert (output_dir / f).is_file()
    # errors.txt lines look like "<label>: <value>"; compare values only.
    with open(output_dir / 'errors.txt', 'r', encoding='utf-8') as f:
        errors = [float(ll.split(':')[-1].strip()) for ll in f.readlines()]
    with open(ref_path / 'errors.txt', 'r', encoding='utf-8') as f:
        errors_ref = [float(ll.split(':')[-1].strip()) for ll in f.readlines()]
    assert np.allclose(np.array(errors), np.array(errors_ref))
    """
    # TODO: commented out as currently SevenNetGraphDataset can't do this
    with open(output_dir / 'info.csv', 'r') as f:
        reader = csv.DictReader(f)
        for dct in reader:
            assert dct['file'] == hfo2_path
        assert reader.line_num == 3
    """
    if save_graph:
        assert (output_dir / 'sevenn_data').is_dir()
        assert (output_dir / 'sevenn_data' / 'saved_graph.pt').is_file()
        assert (output_dir / 'sevenn_data' / 'saved_graph.yaml').is_file()
def test_inference_unlabeled(atoms_hfo, tmp_path):
    """Inference accepts a mix of labeled and unlabeled inputs with --allow_unlabeled."""
    labeled = str(hfo2_path)
    unlabeled = str(tmp_path / 'unlabeled.xyz')
    ase.io.write(unlabeled, atoms_hfo)
    output_dir = tmp_path / 'inference_results'
    cli_args = [
        '--output',
        str(output_dir),
        '--allow_unlabeled',
        cp_0_path,
        labeled,
        unlabeled,
    ]
    with mock.patch('sys.argv', [f'{main}/sevenn_inference.py'] + cli_args):
        inference_main()
    with open(output_dir / 'info.csv', 'r') as f:
        reader = csv.DictReader(f)
        for dct in reader:
            assert dct['file'] in [labeled, unlabeled]
        # header + 3 data rows expected (2 labeled frames + 1 unlabeled)
        # NOTE(review): exact row count depends on hfo2.extxyz — confirm.
        assert reader.line_num == 4
def test_inference_labeled_w_kwargs(atoms_hfo, tmp_path):
    """Custom energy/force/stress keys are honored; stress is read as Voigt."""
    atoms_hfo.info['my_energy'] = 1.0
    atoms_hfo.arrays['my_force'] = np.full((len(atoms_hfo), 3), 7.7)
    # this should be considered as Voigt, xx, yy, zz, yz, zx, xy
    atoms_hfo.info['my_stress'] = np.array([1, 2, 3, 4, 5, 6])
    unlabeled = str(tmp_path / 'unlabeled.xyz')
    ase.io.write(unlabeled, atoms_hfo)
    output_dir = tmp_path / 'inference_results'
    cli_args = [
        '--output',
        str(output_dir),
        cp_0_path,
        unlabeled,
        '--kwargs',
        'energy_key=my_energy',
        'force_key=my_force',
        'stress_key=my_stress',
    ]
    with mock.patch('sys.argv', [f'{main}/sevenn_inference.py'] + cli_args):
        inference_main()
    per_graph = None
    with open(output_dir / 'per_graph.csv', 'r') as f:
        reader = csv.DictReader(f)
        for dct in reader:
            per_graph = dct
        assert reader.line_num == 2
    assert per_graph is not None
    # unit conversion factor applied by sevenn when reading stress
    stress_coeff = -1602.1766208
    assert np.allclose(float(per_graph['stress_yy']), 2 * stress_coeff)
    assert np.allclose(float(per_graph['stress_yz']), 4 * stress_coeff)
    assert np.allclose(float(per_graph['stress_zx']), 5 * stress_coeff)
    assert np.allclose(float(per_graph['stress_xy']), 6 * stress_coeff)
@pytest.mark.parametrize(
    'preset_name,mode,data_path',
    [
        ('fine_tune', 'train_v2', hfo2_path),
        ('base', 'train_v2', hfo2_path),
        ('sevennet-0', 'train_v1', hfo2_path),
    ],
)
def test_sevenn_preset(preset_name, mode, data_path, tmp_path):
    """One-epoch training from each shipped preset produces the expected outputs."""
    preset_path = os.path.join(preset, preset_name + '.yaml')
    with open(preset_path, 'r') as f:
        cfg = yaml.safe_load(f)
    cfg['train']['epoch'] = 1
    # The two trainer generations use different dataset-path keys.
    if mode == 'train_v2':
        cfg['data']['load_trainset_path'] = data_path
        cfg['data'].pop('load_testset_path', None)
    elif mode == 'train_v1':
        cfg['data']['load_dataset_path'] = data_path
    else:
        assert False
    cfg['data']['load_validset_path'] = data_path
    input_yam = str(tmp_path / 'input.yaml')
    with open(input_yam, 'w') as f:
        yaml.dump(cfg, f)
    Logger().switch_file(str(tmp_path / 'log.sevenn'))
    cli_args = ['train', '-w', str(tmp_path), '-m', mode, input_yam]
    with mock.patch('sys.argv', [f'{main}/sevenn.py'] + cli_args):
        sevenn_main()
    # learning-curve file name differs between versions
    assert (tmp_path / 'lc.csv').is_file() or (tmp_path / 'log.csv').is_file()
    assert (tmp_path / 'log.sevenn').is_file()
    assert (tmp_path / 'checkpoint_best.pth').is_file()
mace-bench/3rdparty/SevenNet/tests/unit_tests/test_cueq.py
0 → 100644
View file @
b75ed73c
# TODO: add gradient test from total loss after double precision.
# so far, it is empirically checked by seeing learning curves
import
copy
import
numpy
as
np
import
pytest
import
torch
from
ase.build
import
bulk
from
torch_geometric.loader.dataloader
import
Collater
import
sevenn
import
sevenn.train.dataload
as
dl
from
sevenn.atom_graph_data
import
AtomGraphData
from
sevenn.calculator
import
SevenNetCalculator
from
sevenn.model_build
import
build_E3_equivariant_model
from
sevenn.nn.cue_helper
import
is_cue_available
from
sevenn.nn.sequential
import
AtomGraphSequential
from
sevenn.util
import
(
chemical_species_preprocess
,
model_from_checkpoint_with_backend
,
)
# Shared test fixture data: a rattled 2x2x2 NaCl supercell and its graph.
cutoff = 4.0
_atoms = bulk('NaCl', 'rocksalt', a=4.00) * (2, 2, 2)
_avg_num_neigh = 30.0
_atoms.rattle()
_graph = AtomGraphData.from_numpy_dict(dl.unlabeled_atoms_to_graph(_atoms, cutoff))
def get_graphs(batched):
    """Return two CUDA clones of the shared graph, optionally collated.

    batch size 2
    """
    pair = [_graph.clone().to('cuda') for _ in range(2)]
    if batched:
        return Collater(pair)(pair)
    return pair
def get_model_config():
    """Return a small SevenNet model config for the NaCl test system."""
    config = {
        'cutoff': cutoff,
        'channel': 32,
        'lmax': 2,
        'is_parity': True,
        'num_convolution_layer': 3,
        'self_connection_type': 'nequip',
        # not NequIp
        'interaction_type': 'nequip',
        'radial_basis': {
            'radial_basis_name': 'bessel',
        },
        'cutoff_function': {'cutoff_function_name': 'poly_cut'},
        'weight_nn_hidden_neurons': [64, 64],
        'act_radial': 'silu',
        'act_scalar': {'e': 'silu', 'o': 'tanh'},
        'act_gate': {'e': 'silu', 'o': 'tanh'},
        'conv_denominator': _avg_num_neigh,
        'train_denominator': False,
        'shift': -10.0,
        'scale': 10.0,
        'train_shift_scale': False,
        'irreps_manual': False,
        'lmax_edge': -1,
        'lmax_node': -1,
        'readout_as_fcn': False,
        'use_bias_in_linear': False,
        '_normalize_sph': True,
    }
    # Derive species bookkeeping from the shared test structure.
    chems = set()
    chems.update(_atoms.get_chemical_symbols())
    config.update(**chemical_species_preprocess(list(chems)))
    return config
def get_model(config_overwrite=None, use_cueq=False, cueq_config=None):
    """Build the toy model on CUDA; an explicit cueq_config wins over use_cueq."""
    conf = get_model_config()
    if config_overwrite is not None:
        conf.update(config_overwrite)
    if not cueq_config:
        cueq_config = {'cuequivariance_config': {'use': use_cueq}}
    conf.update(cueq_config)
    built = build_E3_equivariant_model(conf, parallel=False)
    assert isinstance(built, AtomGraphSequential)
    built.to('cuda')
    return built
@pytest.mark.skipif(
    not is_cue_available() or not torch.cuda.is_available(),
    reason='cueq or gpu is not available',
)
@pytest.mark.parametrize(
    'cf',
    [
        ({}),
        ({'self_connection_type': 'linear'}),
        ({'is_parity': False}),
        ({'channel': 8}),
        ({'lmax': 3}),
        ({'num_interaction_layer': 2}),
        ({'num_interaction_layer': 4}),
    ],
)
def test_model_output(cf):
    """e3nn and cueq backends built from the same seed agree layer-by-layer."""
    torch.manual_seed(777)
    model_e3nn = get_model(cf)
    torch.manual_seed(777)
    model_cueq = get_model(cf, use_cueq=True)
    model_e3nn.set_is_batch_data(True)
    model_cueq.set_is_batch_data(True)
    e3nn_out = model_e3nn._preprocess(get_graphs(batched=True))
    cueq_out = model_cueq._preprocess(get_graphs(batched=True))
    # Run both pipelines module-by-module and compare intermediate features.
    for k, e3nn_f in model_e3nn._modules.items():
        cueq_f = model_cueq._modules[k]
        e3nn_out = e3nn_f(e3nn_out)  # type: ignore
        cueq_out = cueq_f(cueq_out)  # type: ignore
        assert torch.allclose(e3nn_out.x, cueq_out.x, atol=1e-6), (
            f'{k}\n\n{e3nn_f}\n\n{cueq_f}'
        )
    assert torch.allclose(e3nn_out.inferred_total_energy, cueq_out.inferred_total_energy)
    assert torch.allclose(e3nn_out.atomic_energy, cueq_out.atomic_energy)
    assert torch.allclose(e3nn_out.inferred_force, cueq_out.inferred_force, atol=1e-5)
    assert torch.allclose(e3nn_out.inferred_stress, cueq_out.inferred_stress, atol=1e-5)
@pytest.mark.filterwarnings('ignore:.*is not found from.*')
@pytest.mark.skipif(
    not is_cue_available() or not torch.cuda.is_available(),
    reason='cueq or gpu is not available',
)
@pytest.mark.parametrize(
    'start_from_cueq',
    [
        (True),
        (False),
    ],
)
def test_checkpoint_convert(tmp_path, start_from_cueq):
    """A checkpoint loaded with the opposite backend reproduces batched outputs."""
    torch.manual_seed(123)
    model_from = get_model(use_cueq=start_from_cueq)
    cfg = get_model_config()
    cfg.update(
        {
            'cuequivariance_config': {'use': start_from_cueq},
            'version': sevenn.__version__,
        }
    )
    torch.save(
        {'model_state_dict': model_from.state_dict(), 'config': cfg},
        tmp_path / 'cp_from.pth',
    )
    # Convert to the other backend when reloading.
    backend = 'e3nn' if start_from_cueq else 'cueq'
    model_to, _ = model_from_checkpoint_with_backend(
        str(tmp_path / 'cp_from.pth'), backend
    )
    model_to.to('cuda')
    model_from.set_is_batch_data(True)
    model_to.set_is_batch_data(True)
    from_out = model_from(get_graphs(batched=True))
    to_out = model_to(get_graphs(batched=True))
    assert torch.allclose(from_out.inferred_total_energy, to_out.inferred_total_energy)
    assert torch.allclose(from_out.atomic_energy, to_out.atomic_energy)
    assert torch.allclose(from_out.inferred_force, to_out.inferred_force, atol=1e-5)
    assert torch.allclose(from_out.inferred_stress, to_out.inferred_stress, atol=1e-5)
@pytest.mark.filterwarnings('ignore:.*is not found from.*')
@pytest.mark.skipif(
    not is_cue_available() or not torch.cuda.is_available(),
    reason='cueq or gpu is not available',
)
@pytest.mark.parametrize(
    'start_from_cueq',
    [
        (True),
        (False),
    ],
)
def test_checkpoint_convert_no_batch(tmp_path, start_from_cueq):
    """Same as test_checkpoint_convert but with single-graph (unbatched) input."""
    torch.manual_seed(123)
    model_from = get_model(use_cueq=start_from_cueq)
    cfg = get_model_config()
    cfg.update(
        {
            'cuequivariance_config': {'use': start_from_cueq},
            'version': sevenn.__version__,
        }
    )
    torch.save(
        {'model_state_dict': model_from.state_dict(), 'config': cfg},
        tmp_path / 'cp_from.pth',
    )
    # Convert to the other backend when reloading.
    backend = 'e3nn' if start_from_cueq else 'cueq'
    model_to, _ = model_from_checkpoint_with_backend(
        str(tmp_path / 'cp_from.pth'), backend
    )
    model_to.to('cuda')
    model_from.set_is_batch_data(False)
    model_to.set_is_batch_data(False)
    # Feed a single (first) graph instead of a collated batch.
    from_out = model_from(get_graphs(batched=False)[0])
    to_out = model_to(get_graphs(batched=False)[0])
    assert torch.allclose(from_out.inferred_total_energy, to_out.inferred_total_energy)
    assert torch.allclose(from_out.atomic_energy, to_out.atomic_energy)
    assert torch.allclose(from_out.inferred_force, to_out.inferred_force, atol=1e-5)
    assert torch.allclose(from_out.inferred_stress, to_out.inferred_stress, atol=1e-5)
def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6):
    """Assert two ASE Atoms objects agree in cell, energy, forces and stress.

    Forces and stress are compared with 10x looser tolerances than energy.
    """
    def acl(a, b, rtol=rtol, atol=atol):
        return np.allclose(a, b, rtol=rtol, atol=atol)

    assert len(atoms1) == len(atoms2)
    assert acl(atoms1.get_cell(), atoms2.get_cell())
    assert acl(atoms1.get_potential_energy(), atoms2.get_potential_energy())
    assert acl(atoms1.get_forces(), atoms2.get_forces(), rtol * 10, atol * 10)
    assert acl(
        atoms1.get_stress(voigt=False),
        atoms2.get_stress(voigt=False),
        rtol * 10,
        atol * 10,
    )
    # assert acl(atoms1.get_potential_energies(), atoms2.get_potential_energies())
@pytest.mark.filterwarnings('ignore:.*is not found from.*')
@pytest.mark.skipif(
    not is_cue_available() or not torch.cuda.is_available(),
    reason='cueq or gpu is not available',
)
def test_calculator(tmp_path):
    """A cueq model instance and its e3nn-reloaded checkpoint give equal results."""
    cueq = True
    model = get_model(use_cueq=cueq)
    # Reference calculator wraps the in-memory cueq model directly.
    ref_calc = SevenNetCalculator(model, file_type='model_instance')
    atoms = copy.deepcopy(_atoms)
    atoms.calc = ref_calc
    cfg = get_model_config()
    cfg.update(
        {'cuequivariance_config': {'use': cueq}, 'version': sevenn.__version__}
    )
    cp_path = str(tmp_path / 'cp.pth')
    torch.save(
        {'model_state_dict': model.state_dict(), 'config': cfg},
        cp_path,
    )
    # Second calculator loads from the checkpoint with cueq disabled.
    calc2 = SevenNetCalculator(cp_path, enable_cueq=False)
    atoms2 = copy.deepcopy(_atoms)
    atoms2.calc = calc2
    assert_atoms(atoms, atoms2)
mace-bench/3rdparty/SevenNet/tests/unit_tests/test_data.py
0 → 100644
View file @
b75ed73c
import
logging
import
os
import
os.path
as
osp
import
uuid
from
collections
import
Counter
from
copy
import
deepcopy
from
typing
import
Literal
import
ase.calculators.singlepoint
as
singlepoint
import
ase.io
import
numpy
as
np
import
pytest
import
torch
from
ase
import
Atoms
from
ase.build
import
bulk
,
molecule
from
torch_geometric.loader
import
DataLoader
import
sevenn._keys
as
KEY
import
sevenn.train.dataload
as
dl
import
sevenn.train.graph_dataset
as
ds
import
sevenn.train.modal_dataset
as
modal_dataset
from
sevenn._const
import
NUM_UNIV_ELEMENT
from
sevenn.atom_graph_data
import
AtomGraphData
from
sevenn.util
import
model_from_checkpoint
,
pretrained_name_to_path
# Shared sample structures for graph-building tests.
cutoff = 4.0
lattice_constant = 3.35
_samples = {
    'bulk': bulk('NaCl', 'rocksalt', a=5.63),
    'mol': molecule('H2O'),
    'isolated': molecule('H'),
    'small_bulk': Atoms(
        symbols='Cu',
        positions=[
            (0, 0, 0),  # Atom at the corner of the cube
        ],
        cell=[
            [lattice_constant, 0, 0],
            [0, lattice_constant, 0],
            [0, 0, lattice_constant],
        ],
        pbc=True,  # Periodic boundary conditions
    ),
}
# Expected edge counts for each sample at cutoff 4.0.
_nedges_c4 = {'bulk': 36, 'mol': 6, 'isolated': 0, 'small_bulk': 18}
def get_atoms(
    atoms_type: Literal['bulk', 'mol', 'isolated', 'small_bulk'],
    init_y_as: Literal['calc', 'info', 'none'],
):
    """
    Return atoms w, w/o reference values with its
    # of edges for 4.0 cutoff length
    """
    assert atoms_type in _samples
    atoms = deepcopy(_samples[atoms_type])
    natoms = len(atoms)
    if init_y_as == 'calc':
        # Attach random reference labels via a SinglePointCalculator.
        results = {
            'energy': np.random.rand(1),
            'forces': np.random.rand(natoms, 3),
            'stress': np.random.rand(6),
        }
        if not atoms.pbc.all():
            # Non-periodic systems have no meaningful stress.
            del results['stress']
        calc = singlepoint.SinglePointCalculator(atoms, **results)
        atoms = calc.get_atoms()
    elif init_y_as == 'info':
        # Store random reference labels in info/arrays instead of a calculator.
        atoms.info['y_energy'] = np.random.rand(1)
        atoms.arrays['y_force'] = np.random.rand(natoms, 3)
        atoms.info['y_stress'] = np.random.rand(6)
        if not atoms.pbc.all():
            del atoms.info['y_stress']
    return atoms, _nedges_c4[atoms_type]
@pytest.mark.parametrize('init_y_as', ['calc', 'info'])
@pytest.mark.parametrize('atoms_type', ['bulk', 'mol', 'isolated'])
def test_atoms_to_graph(atoms_type, init_y_as):
    """atoms_to_graph returns np.ndarray entries with expected keys/shapes/dtypes."""
    atoms, nedges = get_atoms(atoms_type, init_y_as)
    is_stress = atoms.pbc.all()
    y_from_calc = init_y_as == 'calc'
    graph = dl.atoms_to_graph(atoms, cutoff=cutoff, y_from_calc=y_from_calc)
    # key -> (expected shape, expected dtype)
    essential = {
        'atomic_numbers': ((len(atoms),), int),
        'pos': ((len(atoms), 3), float),
        'edge_index': ((2, nedges), int),
        'edge_vec': ((nedges, 3), float),
        'total_energy': ((), float),
        'force_of_atoms': ((len(atoms), 3), float),
        'cell_volume': ((), float),
        'num_atoms': ((), int),
        'per_atom_energy': ((), float),
        'stress': ((1, 6), float),
    }
    for k, (shape, dtype) in essential.items():
        assert k in graph, f'{k} missing in graph'
        assert isinstance(graph[k], np.ndarray), f'{k}: {type(graph[k])} is not np.ndarray'
        assert graph[k].shape == shape, f'{k} shape {graph[k].shape} != {shape}'
        if not is_stress and k == 'stress':
            # Non-periodic systems carry NaN stress.
            assert np.isnan(graph[k]).all()
        else:
            assert graph[k].dtype == dtype, f'{k} dtype {graph[k].dtype} != {dtype}'
    assert graph['per_atom_energy'] == (graph['total_energy'] / len(atoms))
    assert graph['num_atoms'] == len(atoms)
    if not is_stress:
        # Volume falls back to machine epsilon for non-periodic systems.
        assert graph['cell_volume'] == np.finfo(float).eps
@pytest.mark.parametrize('atoms_type', ['bulk', 'mol', 'isolated'])
def test_unlabeled_atoms_to_graph(atoms_type):
    """unlabeled_atoms_to_graph returns np.ndarray entries without label keys."""
    atoms, nedges = get_atoms(atoms_type, 'none')
    graph = dl.unlabeled_atoms_to_graph(atoms, cutoff=cutoff)
    # key -> (expected shape, expected dtype)
    essential = {
        'atomic_numbers': ((len(atoms),), int),
        'pos': ((len(atoms), 3), float),
        'edge_index': ((2, nedges), int),
        'edge_vec': ((nedges, 3), float),
        'cell_volume': ((), float),
        'num_atoms': ((), int),
    }
    for k, (shape, dtype) in essential.items():
        assert k in graph, f'{k} missing in graph'
        assert isinstance(graph[k], np.ndarray), f'{k}: {type(graph[k])} is not np.ndarray'
        assert graph[k].dtype == dtype, f'{k} dtype {graph[k].dtype} != {dtype}'
        assert graph[k].shape == shape, f'{k} shape {graph[k].shape} != {shape}'
    assert graph['num_atoms'] == len(atoms)
    if not atoms.pbc.all():
        # Volume falls back to machine epsilon for non-periodic systems.
        assert graph['cell_volume'] == np.finfo(float).eps
@pytest.mark.parametrize('init_y_as', ['calc', 'info'])
@pytest.mark.parametrize('atoms_type', ['bulk', 'mol', 'isolated'])
def test_atom_graph_data(atoms_type, init_y_as):
    """AtomGraphData.from_numpy_dict yields tensors of expected shape/dtype.

    Essential keys must be present; auxiliary keys are checked only when
    present in the built graph.
    """
    atoms, nedges = get_atoms(atoms_type, init_y_as)
    y_from_calc = init_y_as == 'calc'
    is_stress = atoms.pbc.all()
    np_graph = dl.atoms_to_graph(atoms, cutoff=cutoff, y_from_calc=y_from_calc)
    graph = AtomGraphData.from_numpy_dict(np_graph)
    # key -> (expected shape, expected dtype)
    essential = {
        'atomic_numbers': ((len(atoms),), int),
        'edge_index': ((2, nedges), int),
        'edge_vec': ((nedges, 3), float),
    }
    # renamed from misspelled 'auxilaray'
    auxiliary = {
        'x': ((len(atoms),), int),
        'pos': ((len(atoms), 3), float),
        'num_atoms': ((), int),
        'cell_volume': ((), float),
        'total_energy': ((), float),
        'per_atom_energy': ((), float),
        'force_of_atoms': ((len(atoms), 3), float),
        'stress': ((1, 6), float),
    }
    for k, (shape, dtype) in essential.items():
        assert k in graph, f'{k} missing in graph'
        assert isinstance(graph[k], torch.Tensor), f'{k}: {type(graph[k])} is not an tensor'
        assert graph[k].is_floating_point() == (dtype is float)
        assert graph[k].shape == shape, f'{k} shape {graph[k].shape} != {shape}'
    for k, (shape, dtype) in auxiliary.items():
        if k not in graph:
            continue
        assert isinstance(graph[k], torch.Tensor), f'{k}: {type(graph[k])} is not an tensor'
        assert graph[k].shape == shape, f'{k} shape {graph[k].shape} != {shape}'
        if not is_stress and k == 'stress':
            # Non-periodic systems carry NaN stress.
            assert torch.isnan(graph[k]).all()
        else:
            assert graph[k].is_floating_point() == (dtype is float)
def test_graph_build():
    """
    Compare parallel implementation, should preserve order
    """
    atoms_list = [
        get_atoms(t, 'calc')[0]  # type: ignore
        for t in list(_samples.keys())
    ]
    # Build the same graphs serially and with two worker cores.
    one_core = dl.graph_build(atoms_list, cutoff, num_cores=1, y_from_calc=True)
    two_core = dl.graph_build(atoms_list, cutoff, num_cores=2, y_from_calc=True)
    assert len(one_core) == len(two_core)
    for g1, g2 in zip(one_core, two_core):
        assert set(g1.keys()) == set(g2.keys())
        for k in g1.keys():
            if not isinstance(g1[k], torch.Tensor):
                continue
            if k == 'stress':
                # TODO: robust way to test it
                # Stress may be NaN (non-periodic), so accept matching NaN-ness.
                assert torch.allclose(g1[k], g2[k]) or (
                    torch.isnan(g1[k]).all() == torch.isnan(g2[k]).all()
                )
            else:
                assert torch.allclose(g1[k], g2[k])
@pytest.fixture(scope='module')
def graph_dataset_tuple():
    """Module-scoped fixture: (SevenNetGraphDataset, source atoms list).

    Writes a uniquely-named extxyz file under TMPDIR and builds a dataset
    from it, asserting the processed file actually lands on disk.
    """
    tmpdir = os.getenv('TMPDIR', '/tmp')
    randstr = uuid.uuid4().hex
    assert os.access(tmpdir, os.W_OK), f'{tmpdir} is not writable'
    root = tmpdir
    files = f'{root}/{randstr}.extxyz'
    atoms_list = [
        get_atoms(atype, 'calc')[0]  # type: ignore
        for atype in ['bulk', 'mol', 'isolated']
    ]
    ase.io.write(files, atoms_list, 'extxyz')
    dataset = ds.SevenNetGraphDataset(
        cutoff=cutoff,
        root=root,
        files=files,
        processed_name=f'{randstr}.pt',
    )
    assert os.path.isfile(f'{root}/sevenn_data/{randstr}.pt'), 'dataset not written'
    return dataset, atoms_list
def test_sevenn_graph_dataset_properties(graph_dataset_tuple):
    """Dataset-level species/natoms/energy/force summaries match raw atoms.

    Cleanup vs. original: unused energy/stress accumulators removed and the
    total-atom count uses the idiomatic ``sum(natoms.values())``.
    """
    dataset, atoms_list = graph_dataset_tuple
    species = set()
    natoms = Counter()
    e_per_list = []
    flist = []
    for at in atoms_list:
        chems = at.get_chemical_symbols()
        species.update(chems)
        natoms.update(chems)
        e_per_list.append(at.get_potential_energy() / len(at))
        flist.extend(at.get_forces())
    e_per_list = np.array(e_per_list)
    flist = np.array(flist)
    # 'total' must be computed before comparing against dataset.natoms.
    natoms['total'] = sum(natoms.values())
    assert set(dataset.species) == species
    assert dataset.natoms == natoms
    assert np.allclose(dataset.per_atom_energy_mean, e_per_list.mean())
    assert np.allclose(dataset.force_rms, np.sqrt((flist**2).mean()))
def test_sevenn_graph_dataset_elemwise_energies(graph_dataset_tuple):
    """Element-wise reference energies cover all elements; unseen elements are 0."""
    logger = logging.getLogger(__name__)
    dataset, atoms_list = graph_dataset_tuple
    ref_e = dataset.elemwise_reference_energies
    assert len(ref_e) == NUM_UNIV_ELEMENT
    z_set = set()
    for atoms in atoms_list:
        # Sum per-element reference energies for this structure.
        inferred_e = 0
        atomic_numbers = atoms.get_atomic_numbers()
        z_set.update(atomic_numbers)
        for z in atomic_numbers:
            inferred_e += ref_e[z]
        # it never be same, but should be similar
        logger.info('elemwise energy should be similar:')
        logger.info(f'{inferred_e:4f} {atoms.get_potential_energy()[0]:4f}')
    for z in range(NUM_UNIV_ELEMENT):
        if z not in z_set:
            # Elements absent from the dataset get a zero reference energy.
            assert ref_e[z] == 0
def test_sevenn_graph_dataset_statistics(graph_dataset_tuple):
    """dataset.statistics mean/std/median/max/min agree with hand-computed values."""
    dataset, atoms_list = graph_dataset_tuple
    elist = []
    e_per_list = []
    flist = []
    slist = []
    for at in atoms_list:
        elist.append(at.get_potential_energy())
        e_per_list.append(at.get_potential_energy() / len(at))
        flist.extend(at.get_forces())
        try:
            slist.append(at.get_stress())
        except NotImplementedError:
            # Non-periodic atoms have no stress; pad with NaN.
            slist.append(np.full(6, np.nan))
    # key -> flat reference array to compare statistics against
    dct = {
        'total_energy': np.array(elist),
        'per_atom_energy': np.array(e_per_list),
        'force_of_atoms': np.array(flist).flatten(),
        # 'stress': np.array(slist), # TODO: it may have nan
    }
    for key in dct:
        assert np.allclose(dataset.statistics[key]['mean'], dct[key].mean()), key
        assert np.allclose(dataset.statistics[key]['std'], dct[key].std(ddof=0)), key
        assert np.allclose(
            dataset.statistics[key]['median'], np.median(dct[key])
        ), key
        assert np.allclose(dataset.statistics[key]['max'], dct[key].max()), key
        assert np.allclose(dataset.statistics[key]['min'], dct[key].min()), key
def test_sevenn_mm_dataset_statistics(tmp_path):
    """Multi-modal dataset 'total' statistics equal those of the combined dataset."""
    files = osp.join(tmp_path, 'gd_one.extxyz')
    atoms_list1 = [
        get_atoms(atype, 'calc')[0]  # type: ignore
        for atype in ['bulk', 'bulk', 'bulk', 'bulk']
    ]
    ase.io.write(files, atoms_list1, 'extxyz')
    gd1 = ds.SevenNetGraphDataset(
        cutoff=cutoff,
        root=tmp_path,
        files=files,
        processed_name='gd_one.pt',
    )
    files = osp.join(tmp_path, 'gd_two.extxyz')
    atoms_list2 = [
        get_atoms(atype, 'calc')[0]  # type: ignore
        for atype in ['mol', 'mol', 'bulk']
    ]
    ase.io.write(files, atoms_list2, 'extxyz')
    gd2 = ds.SevenNetGraphDataset(
        cutoff=cutoff,
        root=tmp_path,
        files=files,
        processed_name='gd_two.pt',
    )
    # Reference: a single dataset built from both processed files.
    ref = ds.SevenNetGraphDataset(
        cutoff=cutoff,
        root=tmp_path,
        files=[gd1.processed_paths[0], gd2.processed_paths[0]],
        processed_name='combined.pt',
    )
    mm = modal_dataset.SevenNetMultiModalDataset({'modal1': gd1, 'modal2': gd2})
    assert np.allclose(ref.per_atom_energy_mean, mm.per_atom_energy_mean['total'])
    assert np.allclose(ref.avg_num_neigh, mm.avg_num_neigh['total'])
    assert np.allclose(ref.force_rms, mm.force_rms['total'])
    assert set(ref.species) == set(mm.species['total'])
@pytest.mark.parametrize(
    'a_types,init_ys',
    [(['bulk', 'mol', 'isolated'], ['calc', 'calc', 'calc'])]
)
def test_7net_graph_dataset_batch_shape(a_types, init_ys, tmp_path):
    """A full DataLoader batch has concatenated node/edge dims and per-graph labels."""
    assert len(a_types) == len(init_ys)
    n_graph = len(a_types)
    atoms_list = []
    tot_edges = 0
    tot_atoms = 0
    for a_type, init_y in zip(a_types, init_ys):
        atoms, n_edge = get_atoms(a_type, init_y)
        tot_edges += n_edge
        tot_atoms += len(atoms)
        atoms_list.append(atoms)
    ase.io.write(tmp_path / 'tmp', atoms_list, format='extxyz')
    dataset = ds.SevenNetGraphDataset(cutoff, tmp_path, str(tmp_path / 'tmp'))
    # Single batch containing every graph.
    loader = DataLoader(dataset, batch_size=n_graph)
    graph = next(iter(loader))
    # key -> (expected shape, expected dtype)
    essential = {
        'x': ((tot_atoms,), int),
        'atomic_numbers': ((tot_atoms,), int),
        'pos': ((tot_atoms, 3), float),
        'edge_index': ((2, tot_edges), int),
        'edge_vec': ((tot_edges, 3), float),
        'total_energy': ((n_graph,), float),
        'force_of_atoms': ((tot_atoms, 3), float),
        'cell_volume': ((n_graph,), float),
        'num_atoms': ((n_graph,), int),
        'per_atom_energy': ((n_graph,), float),
        'stress': ((n_graph, 6), float),
        'batch': ((tot_atoms,), int),  # from PyG
    }
    for k, (shape, dtype) in essential.items():
        assert k in graph, f'{k} missing in graph'
        assert isinstance(graph[k], torch.Tensor), f'{k}: {type(graph[k])} is not an tensor'
        assert graph[k].is_floating_point() == (dtype is float)
        assert graph[k].shape == shape, f'{k} shape {graph[k].shape} != {shape}'
@pytest.mark.parametrize('atoms_type', ['bulk', 'mol', 'isolated', 'small_bulk'])
def test_graph_build_ase_and_matscipy(atoms_type):
    """ase and matscipy neighbor-list builders agree on edges and model outputs.

    First compares the (lexsorted) edge lists, then feeds each edge list to a
    pretrained model and compares predicted energy/forces/stress.
    """
    atoms, _ = get_atoms(atoms_type, 'calc')
    atoms.rattle()
    pos = atoms.get_positions()
    cell = np.array(atoms.get_cell())
    pbc = atoms.get_pbc()
    # graph build check
    # ase graph build
    edge_src_ase, edge_dst_ase, edge_vec_ase, shifts_ase = dl._graph_build_ase(
        cutoff, pbc, cell, pos
    )
    # matscipy graph build
    edge_src_matsci, edge_dst_matsci, edge_vec_matsci, shifts_matsci = (
        dl._graph_build_matscipy(cutoff, pbc, cell, pos)
    )
    # sort the graph
    # Edge order differs between builders; lexsort by edge vector to align them.
    sorted_indices_ase = np.lexsort(
        (edge_vec_ase[:, 2], edge_vec_ase[:, 1], edge_vec_ase[:, 0])
    )
    sorted_indices_matsci = np.lexsort(
        (edge_vec_matsci[:, 2], edge_vec_matsci[:, 1], edge_vec_matsci[:, 0])
    )
    sorted_vec_ase = edge_vec_ase[sorted_indices_ase]
    sorted_vec_matsci = edge_vec_matsci[sorted_indices_matsci]
    sorted_src_ase = edge_src_ase[sorted_indices_ase]
    sorted_dst_ase = edge_dst_ase[sorted_indices_ase]
    sorted_src_matsci = edge_src_matsci[sorted_indices_matsci]
    sorted_dst_matsci = edge_dst_matsci[sorted_indices_matsci]
    sorted_shift_ase = shifts_ase[sorted_indices_ase]
    sorted_shift_matsci = shifts_matsci[sorted_indices_matsci]
    # compare the result
    assert np.allclose(sorted_vec_ase, sorted_vec_matsci)
    assert np.array_equal(sorted_src_ase, sorted_src_matsci)
    assert np.array_equal(sorted_dst_ase, sorted_dst_matsci)
    assert np.array_equal(sorted_shift_ase, sorted_shift_matsci)
    # energy test
    model, _ = model_from_checkpoint(pretrained_name_to_path('7net-0_11July2024'))
    model.eval()
    model.set_is_batch_data(False)
    # for ase energy
    edge_idx_ase = np.array([edge_src_ase, edge_dst_ase])
    atomic_numbers = atoms.get_atomic_numbers()
    cell = np.array(cell)
    vol = dl._correct_scalar(atoms.cell.volume)
    if vol == 0:
        # Non-periodic: fall back to machine epsilon to avoid zero volume.
        vol = np.array(np.finfo(float).eps)
    data_ase = {
        KEY.NODE_FEATURE: atomic_numbers,
        KEY.ATOMIC_NUMBERS: atomic_numbers,
        KEY.POS: pos,
        KEY.EDGE_IDX: edge_idx_ase,
        KEY.EDGE_VEC: edge_vec_ase,
        KEY.CELL: cell,
        KEY.CELL_SHIFT: shifts_ase,
        KEY.CELL_VOLUME: vol,
        KEY.NUM_ATOMS: dl._correct_scalar(len(atomic_numbers)),
    }
    data_ase[KEY.INFO] = {}
    atom_graph_data_ase = AtomGraphData.from_numpy_dict(data_ase)
    output_ase = model(atom_graph_data_ase)
    ase_pred_energy = output_ase[KEY.PRED_TOTAL_ENERGY]
    ase_pred_force = output_ase[KEY.PRED_FORCE]
    ase_pred_stress = output_ase[KEY.PRED_STRESS]
    # for matsci energy
    edge_idx_matsci = np.array([edge_src_matsci, edge_dst_matsci])
    atomic_numbers = atoms.get_atomic_numbers()
    cell = np.array(cell)
    vol = dl._correct_scalar(atoms.cell.volume)
    if vol == 0:
        vol = np.array(np.finfo(float).eps)
    data_matsci = {
        KEY.NODE_FEATURE: atomic_numbers,
        KEY.ATOMIC_NUMBERS: atomic_numbers,
        KEY.POS: pos,
        KEY.EDGE_IDX: edge_idx_matsci,
        KEY.EDGE_VEC: edge_vec_matsci,
        KEY.CELL: cell,
        KEY.CELL_SHIFT: shifts_matsci,
        KEY.CELL_VOLUME: vol,
        KEY.NUM_ATOMS: dl._correct_scalar(len(atomic_numbers)),
    }
    data_matsci[KEY.INFO] = {}
    atom_graph_data_matsci = AtomGraphData.from_numpy_dict(data_matsci)
    output_matsci = model(atom_graph_data_matsci)
    matsci_pred_energy = output_matsci[KEY.PRED_TOTAL_ENERGY]
    matsci_pred_force = output_matsci[KEY.PRED_FORCE]
    matsci_pred_stress = output_matsci[KEY.PRED_STRESS]
    assert torch.equal(ase_pred_energy, matsci_pred_energy)
    assert torch.allclose(ase_pred_force, matsci_pred_force, atol=1e-06)
    assert torch.allclose(ase_pred_stress, matsci_pred_stress)
mace-bench/3rdparty/SevenNet/tests/unit_tests/test_errors.py
0 → 100644
View file @
b75ed73c
# test_errors: error recorder.py, loss.py
from
copy
import
deepcopy
import
numpy
as
np
import
pytest
import
torch
import
torch.nn
from
torch
import
tensor
import
sevenn.error_recorder
as
erc
import
sevenn.train.loss
as
loss
from
sevenn.atom_graph_data
import
AtomGraphData
from
sevenn.train.optim
import
loss_dict
# Baseline training config used by the loss / error-recorder tests.
_default_config = {
    'loss': 'mse',
    'loss_param': {},
    'error_record': [
        ('Energy', 'RMSE'),
        ('Force', 'RMSE'),
        ('Stress', 'RMSE'),
        ('Energy', 'MAE'),
        ('Force', 'MAE'),
        ('Stress', 'MAE'),
        ('TotalLoss', 'None'),
    ],
    'is_train_stress': True,
    'force_loss_weight': 1.0,
    'stress_loss_weight': 0.001,
}
# (error metric name, ndata, natoms) tuples fed to parametrized tests.
_erc_test_params = [
    ('TotalEnergy', 4, 3),
    ('Energy', 4, 3),
    ('Force', 4, 3),
    ('Stress', 4, 3),
    ('Stress_GPa', 4, 3),
    ('Energy', 4, 1),
    ('Energy', 1, 1),
    ('Force', 1, 3),
    ('Stress', 1, 3),
]
def acl(a, b):
    """Shorthand for torch.allclose with a loose absolute tolerance."""
    tolerance = 1e-6
    return torch.allclose(a, b, atol=tolerance)
def config(**overwrite):
    """Return a deep copy of the default config with *overwrite* applied.

    Copying keeps the module-level ``_default_config`` effectively
    read-only across tests.
    """
    conf = deepcopy(_default_config)
    conf.update(overwrite)
    return conf
def test_per_atom_energy_loss():
    """PerAtomEnergyLoss equals the criterion applied to energies / num_atoms."""
    loss_f = loss.PerAtomEnergyLoss(criterion=torch.nn.MSELoss())
    ref = torch.rand(2)
    pred = torch.rand(2)
    natoms = torch.randint(1, 10, (2,))
    tmp = AtomGraphData(
        total_energy=ref,
        inferred_total_energy=pred,
        num_atoms=natoms,
    ).to_dict()
    ret = loss_f.get_loss(tmp)
    assert loss_f.criterion is not None
    assert torch.allclose(loss_f.criterion((ref / natoms), (pred / natoms)), ret)
def test_force_loss():
    """ForceLoss equals the criterion applied to flattened force components."""
    loss_f = loss.ForceLoss(criterion=torch.nn.MSELoss())
    ref = torch.rand((4, 3))
    pred = torch.rand((4, 3))
    batch = tensor([0, 0, 0, 1])
    tmp = AtomGraphData(
        force_of_atoms=ref,
        inferred_force=pred,
        batch=batch,
    ).to_dict()
    ret = loss_f.get_loss(tmp)
    assert loss_f.criterion is not None
    assert torch.allclose(loss_f.criterion(ref.reshape(-1), pred.reshape(-1)), ret)
def test_stress_loss():
    """StressLoss equals the criterion applied to kbar-scaled flattened stress."""
    loss_f = loss.StressLoss(criterion=torch.nn.MSELoss())
    ref = torch.rand((2, 6))
    pred = torch.rand((2, 6))
    tmp = AtomGraphData(
        stress=ref,
        inferred_stress=pred,
    ).to_dict()
    ret = loss_f.get_loss(tmp)
    # eV/A^3 -> kbar conversion factor used by the loss.
    KB = 1602.1766208
    assert loss_f.criterion is not None
    assert torch.allclose(
        loss_f.criterion(ref.reshape(-1) * KB, pred.reshape(-1) * KB), ret
    )
@pytest.mark.parametrize('conf', [config(), config(is_train_stress=False)])
def test_loss_from_config(conf):
    """Config yields the expected loss definitions and weights (stress optional)."""
    loss_functions = loss.get_loss_functions_from_config(conf)
    if conf['is_train_stress']:
        assert len(loss_functions) == 3
    else:
        assert len(loss_functions) == 2
    for loss_def, w in loss_functions:
        assert isinstance(loss_def, loss.LossDefinition)
        if isinstance(loss_def, loss.PerAtomEnergyLoss):
            # Energy weight is fixed at 1.0.
            assert w == 1.0
        elif isinstance(loss_def, loss.ForceLoss):
            assert w == conf['force_loss_weight']
        elif isinstance(loss_def, loss.StressLoss):
            assert w == conf['stress_loss_weight']
        else:
            raise ValueError(f'Unexpected loss function: {loss_def}')
@pytest.mark.parametrize('err_type,ndata,natoms', _erc_test_params)
def test_rms_error(err_type, ndata, natoms):
    """RMSError accumulates the expected RMS value and is stable on re-update."""
    err_dct = erc.get_err_type(err_type)
    err = erc.RMSError(**err_dct)
    ref = torch.rand((ndata, err.vdim)).squeeze(1)
    pred = torch.rand((ndata, err.vdim)).squeeze(1)
    natoms = torch.tensor([natoms] * ndata)
    _data = {
        err_dct['ref_key']: ref,
        err_dct['pred_key']: pred,
        'num_atoms': natoms,
    }
    tmp = AtomGraphData(**_data)
    err.update(tmp)
    # Apply the metric's unit conversion before computing the reference value.
    _ref = ref * err.coeff
    _pred = pred * err.coeff
    if 'per_atom' in err_dct and err_dct['per_atom']:
        # natoms = natoms.unsqueeze(-1)
        _ref = _ref / natoms
        _pred = _pred / natoms
    val = torch.sqrt(((_ref - _pred) ** 2).sum() / ndata)  # not ndata*natoms
    assert np.allclose(err.get(), val.item())
    # A second identical update must not change the running value.
    err.update(tmp)
    assert np.allclose(err.get(), val.item())
@pytest.mark.parametrize('err_type,ndata,natoms', _erc_test_params)
def test_mae_error(err_type, ndata, natoms):
    """MAError accumulates the expected mean-absolute value and is stable on re-update."""
    err_dct = erc.get_err_type(err_type)
    vdim = err_dct['vdim']
    err = erc.MAError(**err_dct)
    ref = torch.rand((ndata, vdim)).squeeze(1)
    pred = torch.rand((ndata, vdim)).squeeze(1)
    natoms = torch.tensor([natoms] * ndata)
    _data = {
        err_dct['ref_key']: ref,
        err_dct['pred_key']: pred,
        'num_atoms': natoms,
    }
    tmp = AtomGraphData(**_data)
    err.update(tmp)
    # Apply the metric's unit conversion before computing the reference value.
    _ref = ref * err.coeff
    _pred = pred * err.coeff
    if 'per_atom' in err_dct and err_dct['per_atom']:
        _ref /= natoms
        _pred /= natoms
    val = abs(_ref - _pred).sum() / (ndata * vdim)
    assert np.allclose(err.get(), val.item())
    # A second identical update must not change the running value.
    err.update(tmp)
    assert np.allclose(err.get(), val.item())
# TODO: test_component_rms_error
@pytest.mark.parametrize('err_type,ndata,natoms', _erc_test_params)
def test_custom_error(err_type, ndata, natoms):
    """CustomError averages a user-supplied binary function over ref/pred."""
    def func(a, b):
        return a * b

    err_dct = erc.get_err_type(err_type)
    vdim = err_dct['vdim']
    err = erc.CustomError(func, **err_dct)
    ref = torch.rand((ndata, vdim)).squeeze(1)
    pred = torch.rand((ndata, vdim)).squeeze(1)
    natoms = torch.tensor([natoms] * ndata)
    _data = {
        err_dct['ref_key']: ref,
        err_dct['pred_key']: pred,
        'num_atoms': natoms,
    }
    # Apply the metric's unit conversion before computing the reference value.
    _ref = ref * err.coeff
    _pred = pred * err.coeff
    if 'per_atom' in err_dct and err_dct['per_atom']:
        _ref /= natoms
        _pred /= natoms
    tmp = AtomGraphData(**_data)
    err.update(tmp)
    val = func(_ref, _pred).mean()
    assert np.allclose(err.get(), val.item())
    # A second identical update must not change the running value.
    err.update(tmp)
    assert np.allclose(err.get(), val.item())
@pytest.mark.parametrize('conf', [config(), config(is_train_stress=False)])
def test_total_loss_metric_from_config(conf):
    """TotalLoss metric equals weighted per-atom energy + force (+ stress) terms."""
    def func(a, b):
        return a * b

    err = erc.ErrorRecorder.init_total_loss_metric(conf, func)
    ndata = 3
    natoms = 4
    e1, e2 = torch.rand(ndata), torch.rand(ndata)
    f1, f2 = torch.rand(ndata * natoms, 3), torch.rand(ndata * natoms, 3)
    s1, s2 = torch.rand((ndata, 6)), torch.rand((ndata, 6))
    _data = {
        'total_energy': e1,
        'inferred_total_energy': e2,
        'force_of_atoms': f1,
        'inferred_force': f2,
        'stress': s1,
        'inferred_stress': s2,
        'num_atoms': torch.tensor([natoms] * ndata),
    }
    tmp = AtomGraphData(**_data)
    err.update(tmp)
    # Reference: per-atom energy term plus weighted force term.
    val = (func(e1 / natoms, e2 / natoms)).mean() + conf[
        'force_loss_weight'
    ] * func(f1, f2).mean()
    if conf['is_train_stress']:
        # eV/A^3 -> kbar conversion factor used by the stress term.
        KB = 1602.1766208
        val += conf['stress_loss_weight'] * func(s1 * KB, s2 * KB).mean()
    assert np.allclose(err.get(), val.item())
    # A second identical update must not change the running value.
    err.update(tmp)
    assert np.allclose(err.get(), val.item())
@pytest.mark.parametrize(
    'conf', [config(), config(is_train_stress=False), config(loss='huber')]
)
def test_error_recorder_from_config(conf):
    """ErrorRecorder.from_config builds metrics matching the configured loss."""
    recorder = erc.ErrorRecorder.from_config(conf)
    total_loss_flag = False
    for metric in recorder.metrics:
        if conf['is_train_stress'] is False:
            # No stress metric should be present when stress training is off.
            assert 'stress' not in metric.name
        if metric.name == 'TotalLoss':
            total_loss_flag = True
            for loss_metric, _ in metric.metrics:  # type: ignore
                assert isinstance(loss_metric.func, loss_dict[conf['loss']])
    assert total_loss_flag
@pytest.mark.parametrize(
    'conf', [config(), config(is_train_stress=False), config(loss='huber')]
)
def test_error_recorder_from_config_and_loss_functions(conf):
    """from_config with explicit loss functions wires criteria into TotalLoss."""
    loss_functions = loss.get_loss_functions_from_config(conf)
    recorder = erc.ErrorRecorder.from_config(conf, loss_functions)
    total_loss_flag = False
    for metric in recorder.metrics:
        if conf['is_train_stress'] is False:
            # No stress metric should be present when stress training is off.
            assert 'stress' not in metric.name
        if metric.name == 'TotalLoss':
            total_loss_flag = True
            for loss_metric, _ in metric.metrics:  # type: ignore
                assert isinstance(
                    loss_metric.loss_def.criterion, loss_dict[conf['loss']]
                )
    assert total_loss_flag
mace-bench/3rdparty/SevenNet/tests/unit_tests/test_modal.py
0 → 100644
View file @
b75ed73c
# # deploy is test on lammps
# test append modality
# from no modality model to modality yes model
# from modality model to more modality model
# different shift scale settings
# test modality options (check num param)
# calculators with modality
import
copy
# + modal checkpoint continue and test_train
# + sevenn_cp test things in test_cli
import
pathlib
import
pytest
from
ase.build
import
bulk
import
sevenn.train.graph_dataset
as
graph_ds
import
sevenn.util
as
util
from
sevenn.calculator
import
SevenNetCalculator
from
sevenn.model_build
import
build_E3_equivariant_model
# Shared paths/constants for the modality tests.
cutoff = 5.0
# tests/data directory, resolved relative to this file.
data_root = (pathlib.Path(__file__).parent.parent / 'data').resolve()
hfo2_path = str(data_root / 'systems' / 'hfo2.extxyz')
sevennet_0_path = util.pretrained_name_to_path('7net-0_11July2024')
@pytest.fixture(scope='module')
def graph_dataset_path(tmp_path_factory):
    """Build a SevenNetGraphDataset from the HfO2 file; return its .pt path."""
    dataset_root = tmp_path_factory.mktemp('gd')
    dataset = graph_ds.SevenNetGraphDataset(
        cutoff=cutoff,
        root=str(dataset_root),
        files=[hfo2_path],
        processed_name='tmp.pt',
    )
    return dataset.processed_paths[0]
# Baseline multi-modal configuration consumed by get_modal_cfg().  A single
# new modality ('modal_new') is trained from the HfO2 file; individual flags
# are toggled per-test through get_modal_cfg(overwrite=...).
_modal_cfg = {
    'use_modal_node_embedding': False,
    'use_modal_self_inter_intro': True,
    'use_modal_self_inter_outro': True,
    'use_modal_output_block': True,
    'use_modality': True,
    'use_modal_wise_shift': True,  # T/F should be tested
    'use_modal_wise_scale': False,  # T/F should be tested
    'load_trainset_path': [
        {
            'data_modality': 'modal_new',
            'file_list': [{'file': hfo2_path}],
        }
    ],
}
@pytest.fixture(scope='module')
def snet_0_cp():
    """Load the pretrained SevenNet-0 checkpoint once for this module."""
    checkpoint = util.load_checkpoint(sevennet_0_path)
    return checkpoint
@pytest.fixture(scope='module')
def snet_0_calc():
    """Default SevenNetCalculator instance shared across this module."""
    calculator = SevenNetCalculator()
    return calculator
@pytest.fixture()
def bulk_atoms():
    """A rattled Si supercell, rebuilt fresh for every test."""
    supercell = bulk('Si') * 3
    supercell.rattle()
    return supercell
def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6):
    """Assert two calculator-attached Atoms agree on energetics.

    Cell and potential energy are compared at (rtol, atol); forces and the
    full 3x3 stress tensor are compared with 10x looser tolerances.
    """
    import numpy as np

    def close(a, b, rtol=rtol, atol=atol):
        return np.allclose(a, b, rtol=rtol, atol=atol)

    assert len(atoms1) == len(atoms2)
    assert close(atoms1.get_cell(), atoms2.get_cell())
    assert close(atoms1.get_potential_energy(), atoms2.get_potential_energy())
    assert close(atoms1.get_forces(), atoms2.get_forces(), rtol * 10, atol * 10)
    assert close(
        atoms1.get_stress(voigt=False),
        atoms2.get_stress(voigt=False),
        rtol * 10,
        atol * 10,
    )
    # assert acl(atoms1.get_potential_energies(), atoms2.get_potential_energies())
def get_modal_cfg(overwrite=None):
    """Return a deep copy of the base modal config, optionally updated.

    Args:
        overwrite: optional dict of keys to merge into the copied config.
    Returns:
        A fresh dict; mutating it never touches the shared ``_modal_cfg``.
    """
    # deepcopy already yields an independent dict; the original code chained a
    # redundant shallow ``.copy()`` on top of it, which is dropped here.
    modal_cfg = copy.deepcopy(_modal_cfg)
    if overwrite:
        modal_cfg.update(overwrite)
    return modal_cfg
@pytest.mark.parametrize(
    'cfg_overwrite',
    [
        ({}),
        ({'use_modal_wise_scale': True}),
        ({'use_modal_wise_shift': False}),
        ({'use_modal_self_inter_intro': False}),
    ],
)
def test_append_modal_sevennet_0(
    cfg_overwrite,
    snet_0_cp,
    snet_0_calc,
    bulk_atoms,
    graph_dataset_path,
    tmp_path,
):
    """Appending a modality to SevenNet-0 must preserve its 'pbe' predictions."""
    # Copy before mutating: snet_0_cp is a module-scoped fixture, so in-place
    # pop()/update() on its config would leak into the other parametrized
    # runs of this test (the second run's pop() could raise KeyError).
    modal_cfg = copy.deepcopy(snet_0_cp.config)
    modal_cfg.pop('load_dataset_path', None)
    modal_cfg.pop('load_validset_path', None)
    modal_cfg.update(get_modal_cfg(cfg_overwrite))
    modal_cfg['shift'] = 'elemwise_reference_energies'
    modal_cfg['scale'] = 'per_atom_energy_std'
    modal_cfg['load_trainset_path'][0]['file_list'] = [{'file': graph_dataset_path}]

    new_state_dict = snet_0_cp.append_modal(
        modal_cfg, original_modal_name='pbe', working_dir=tmp_path
    )
    sevennet_0_w_modal = build_E3_equivariant_model(modal_cfg)
    sevennet_0_w_modal.load_state_dict(new_state_dict, strict=True)

    # Same structure evaluated by the original calculator and by the
    # modality-extended model (queried with the original 'pbe' modality).
    atoms1 = bulk_atoms
    atoms2 = copy.deepcopy(atoms1)
    atoms1.calc = snet_0_calc
    atoms2.calc = SevenNetCalculator(
        model=sevennet_0_w_modal, file_type='model_instance', modal='pbe'
    )
    assert_atoms(atoms1, atoms2)
mace-bench/3rdparty/SevenNet/tests/unit_tests/test_model.py
0 → 100644
View file @
b75ed73c
import
pytest
import
torch
from
ase.build
import
bulk
,
molecule
from
ase.data
import
chemical_symbols
from
torch_geometric.loader.dataloader
import
Collater
import
sevenn.train.dataload
as
dl
from
sevenn.atom_graph_data
import
AtomGraphData
from
sevenn.model_build
import
build_E3_equivariant_model
from
sevenn.nn.sequential
import
AtomGraphSequential
from
sevenn.util
import
chemical_species_preprocess
cutoff = 4.0  # graph cutoff radius for the small test models below

# Three representative structures: a periodic bulk, a molecule, and a single
# isolated atom (no neighbors), so graph-construction edge cases are covered.
_samples = {
    'bulk': bulk('NaCl', 'rocksalt', a=5.63),
    'mol': molecule('H2O'),
    'isolated': molecule('H'),
}
n_samples = len(_samples)
n_atoms_total = sum([len(at) for at in _samples.values()])

# Pre-built unlabeled graphs; tests clone these instead of rebuilding them.
_graph_list = [
    AtomGraphData.from_numpy_dict(dl.unlabeled_atoms_to_graph(at, cutoff))
    for at in list(_samples.values())
]
def test_chemical_species_preprocess():
    """Species preprocessing: sorted/deduped list vs. the universal table."""
    symbols = ['He', 'H', 'Be', 'H']

    parsed = chemical_species_preprocess(symbols, universal=False)
    assert parsed['chemical_species'] == ['Be', 'H', 'He']
    assert parsed['_number_of_species'] == 3
    assert parsed['_type_map'] == {4: 0, 1: 1, 2: 2}

    parsed = chemical_species_preprocess(symbols, universal=True)
    assert parsed['chemical_species'] == chemical_symbols
    assert parsed['_number_of_species'] == len(chemical_symbols)
    assert len(parsed['_type_map']) == len(chemical_symbols)
    # In universal mode every atomic number maps to itself.
    for atomic_number, node_idx in parsed['_type_map'].items():
        assert atomic_number == node_idx
def get_graphs(batched):
    """Return fresh clones of the sample graphs, collated when *batched*."""
    copies = [graph.clone() for graph in _graph_list]
    return Collater(copies)(copies) if batched else copies
def get_model_config():
    """Return a minimal NequIP-style model config covering all sample species.

    The chemical-species fields are derived from the symbols present in
    ``_samples`` via chemical_species_preprocess().
    """
    config = {
        'cutoff': cutoff,
        'channel': 4,
        'radial_basis': {
            'radial_basis_name': 'bessel',
        },
        'cutoff_function': {'cutoff_function_name': 'poly_cut'},
        'interaction_type': 'nequip',
        'lmax': 2,
        'is_parity': True,
        'num_convolution_layer': 3,
        'weight_nn_hidden_neurons': [64, 64],
        'act_radial': 'silu',
        'act_scalar': {'e': 'silu', 'o': 'tanh'},
        'act_gate': {'e': 'silu', 'o': 'tanh'},
        'conv_denominator': 30.0,
        'train_denominator': False,
        'self_connection_type': 'nequip',
        'shift': -10.0,
        'scale': 10.0,
        'train_shift_scale': False,
        'irreps_manual': False,
        'lmax_edge': -1,
        'lmax_node': -1,
        'readout_as_fcn': False,
        'use_bias_in_linear': False,
        '_normalize_sph': True,
    }
    # Collect every element that appears in the sample structures.
    chems = set()
    for at in list(_samples.values()):
        chems.update(at.get_chemical_symbols())
    config.update(**chemical_species_preprocess(list(chems)))
    return config
def get_model(config_overwrite=None):
    """Build a small E3-equivariant test model.

    Args:
        config_overwrite: optional dict of config keys to override on top of
            ``get_model_config()``.  Defaults to no overrides.  (The original
            used a mutable ``{}`` default argument — an anti-pattern, replaced
            with ``None`` here; behavior is unchanged.)
    Returns:
        An ``AtomGraphSequential`` model.
    """
    cf = get_model_config()
    if config_overwrite:
        cf.update(**config_overwrite)
    model = build_E3_equivariant_model(cf, parallel=False)
    assert isinstance(model, AtomGraphSequential)
    return model
@pytest.mark.parametrize('batched', [False, True])
@pytest.mark.parametrize('cf', [{}])
def test_shape(cf, batched):
    """Output tensors have the expected shapes in single and batched modes."""
    model = get_model(cf)
    model.set_is_batch_data(batched)
    graph = get_graphs(batched)
    if not batched:
        # Per-structure shapes: scalar energy, 6-component (Voigt) stress.
        output_shapes = {
            'inferred_total_energy': (),
            'inferred_stress': (6,),
        }
        for g in graph:
            natoms = g['num_atoms']
            output_shapes.update(
                {
                    'atomic_energy': (natoms, 1),  # intended
                    'inferred_force': (natoms, 3),
                }
            )
            output = model(g)
            for k, shape in output_shapes.items():
                assert output[k].shape == shape, f'{k}: {output[k].shape} != {shape}'
    else:
        # Batched: leading dimension is n_samples (per-structure quantities)
        # or n_atoms_total (per-atom quantities).
        output_shapes = {
            'inferred_total_energy': (n_samples,),
            'atomic_energy': (n_atoms_total, 1),  # intended
            'inferred_force': (n_atoms_total, 3),
            'inferred_stress': (n_samples, 6),
        }
        output = model(graph)
        for k, shape in output_shapes.items():
            assert output[k].shape == shape, f'{k}: {output[k].shape} != {shape}'
def test_batch():
    """Batched inference matches per-structure inference on the same graphs."""
    model = get_model()
    model.set_is_batch_data(False)
    graph_list = get_graphs(batched=False)
    output_list = [model(g) for g in graph_list]
    model.set_is_batch_data(True)
    graph_batch = get_graphs(batched=True)
    output_batched = model(graph_batch)

    # Concatenate the per-structure outputs into batch-shaped tensors.
    e_concat = torch.concat(
        [g['inferred_total_energy'].unsqueeze(-1) for g in output_list]
    )
    ae_concat = torch.concat([g['atomic_energy'].squeeze(1) for g in output_list])
    f_concat = torch.concat([g['inferred_force'] for g in output_list])
    s_concat = torch.stack([g['inferred_stress'] for g in output_list])

    assert torch.allclose(e_concat, output_batched['inferred_total_energy'])
    assert torch.allclose(ae_concat, output_batched['atomic_energy'].squeeze(1))
    # Forces/stress are rounded first to tolerate tiny batching-order noise.
    assert torch.allclose(
        torch.round(f_concat, decimals=5),
        torch.round(output_batched['inferred_force'], decimals=5),
        atol=1e-5,
    )
    assert torch.allclose(
        # TODO, hard-coded, assumes the first structure is bulk
        torch.round(s_concat[0], decimals=5),
        torch.round(output_batched['inferred_stress'][0], decimals=5),
    )
# (config overrides, expected trainable-parameter count) pairs.  The deltas
# on the 20642 baseline document how many parameters each option adds.
_n_param_tests = [
    ({}, 20642),
    ({'train_denominator': True}, 20642 + 3),
    ({'train_shift_scale': True}, 20642 + 2),
    ({'shift': [1.0] * 4}, 20642),
    ({'scale': [1.0] * 4, 'train_shift_scale': True}, 20642 + 8),
    ({'num_convolution_layer': 4}, 33458),
    ({'lmax': 3}, 26866),
    ({'channel': 2}, 16883),
    ({'is_parity': False}, 20386),
    ({'self_connection_type': 'linear'}, 20114),
]
@pytest.mark.parametrize('cf,ref', _n_param_tests)
def test_num_params(cf, ref):
    """Trainable parameter count matches the reference for each config."""
    model = get_model(cf)
    trainable = [p.numel() for p in model.parameters() if p.requires_grad]
    param = sum(trainable)
    assert param == ref, f'ref: {ref} != given: {param}'
# (modal-config overrides, expected trainable-parameter count) pairs.
# The arithmetic on the 20642 baseline documents each option's contribution.
_n_modal_param_tests = [
    ({}, 20642),
    ({'use_modal_node_embedding': True}, 20642 + 8),
    ({'use_modal_self_inter_intro': True}, 20642 + 2 * 4 * 3),
    ({'use_modal_self_inter_outro': True}, 20642 + 2 * (12 + 20 + 4)),
    # Use integer division: ``/`` made this reference a float (20646.0) even
    # though parameter counts are ints; the comparison only worked via
    # int == float coercion.
    ({'use_modal_output_block': True}, 20642 + 2 * 4 // 2),
]
@pytest.mark.parametrize('cf,ref', _n_modal_param_tests)
def test_modal_num_params(cf, ref):
    """Parameter count of a two-modality model for each modal option."""
    modal_cfg = {
        'use_modality': True,
        '_number_of_modalities': 2,
        '_modal_map': {'x1': 0, 'x2': 1},
        'use_modal_node_embedding': False,
        'use_modal_self_inter_intro': False,
        'use_modal_self_inter_outro': False,
        'use_modal_output_block': False,
        'use_modal_wise_shift': False,
        'use_modal_wise_scale': False,
    }
    modal_cfg.update(cf)

    model = get_model(modal_cfg)
    param = sum(p.numel() for p in model.parameters() if p.requires_grad)
    assert param == ref, f'ref: {ref} != given: {param}'
# TODO: test_irreps, test_gard, test_equivariance
mace-bench/3rdparty/SevenNet/tests/unit_tests/test_pretrained.py
0 → 100644
View file @
b75ed73c
# test_pretrained: output consistency for pretrained models
import
pytest
import
torch
from
ase.build
import
bulk
,
molecule
import
sevenn._keys
as
KEY
from
sevenn.atom_graph_data
import
AtomGraphData
from
sevenn.train.dataload
import
unlabeled_atoms_to_graph
from
sevenn.util
import
model_from_checkpoint
,
pretrained_name_to_path
def acl(a, b, atol=1e-6):
    """Return True when tensors *a* and *b* agree within tolerance *atol*."""
    within_tolerance = torch.allclose(a, b, atol=atol)
    return within_tolerance
@pytest.fixture
def atoms_pbc():
    """Rocksalt NaCl with a distorted cell (first lattice vector shortened)."""
    crystal = bulk('NaCl', 'rocksalt', a=5.63)
    crystal.set_cell(
        [[1.0, 2.815, 2.815], [2.815, 0.0, 2.815], [2.815, 2.815, 0.0]]
    )
    crystal.set_positions([[0.0, 0.0, 0.0], [2.815, 0.0, 0.0]])
    return crystal
@pytest.fixture
def atoms_mol():
    """Water molecule with slightly perturbed atomic positions."""
    water = molecule('H2O')
    water.set_positions(
        [[0.0, 0.2, 0.12], [0.0, 0.76, -0.48], [0.0, -0.76, -0.48]]
    )
    return water
def test_7net0_22May2024(atoms_pbc, atoms_mol):
    """
    Reference from v0.9.3.post1 with SevenNetCalculator
    """
    cp_path = pretrained_name_to_path('7net-0_22May2024')
    model, config = model_from_checkpoint(cp_path)
    cutoff = config['cutoff']
    g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff))
    g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff))
    model.set_is_batch_data(False)
    g1 = model(g1)
    g2 = model(g2)
    g1_ref_e = torch.tensor([-3.4140868186950684])
    g1_ref_f = torch.tensor(
        [
            [1.2628037e01, 7.5093508e-03, 1.3480943e-02],
            [-1.2628037e01, -7.5093508e-03, -1.3480917e-02],
        ]
    )
    # Stress reference is negated; presumably recorded with the opposite sign
    # convention to the model output — matches the other tests in this file.
    g1_ref_s = -1 * torch.tensor(
        # xx, yy, zz, xy, yz, zx
        [-0.65014917, -0.01990843, -0.02000658, 0.03286226, 0.00589222, 0.03291973]
    )
    g2_ref_e = torch.tensor([-12.808363914489746])
    g2_ref_f = torch.tensor(
        [
            [9.31322575e-10, -1.30241165e01, 6.93116236e00],
            [-1.39698386e-09, 9.28001022e00, -9.51867390e00],
            [5.23868948e-10, 3.74410582e00, 2.58751225e00],
        ]
    )
    assert acl(g1.inferred_total_energy, g1_ref_e)
    assert acl(g1.inferred_force, g1_ref_f)
    assert acl(g1.inferred_stress, g1_ref_s)
    assert acl(g2.inferred_total_energy, g2_ref_e)
    # No stress check for the molecule (no periodic cell).
    assert acl(g2.inferred_force, g2_ref_f)
def test_7net0_11July2024(atoms_pbc, atoms_mol):
    """
    Reference from v0.9.3.post1 with SevenNetCalculator
    """
    cp_path = pretrained_name_to_path('7net-0_11July2024')
    model, config = model_from_checkpoint(cp_path)
    cutoff = config['cutoff']
    g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff))
    g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff))
    model.set_is_batch_data(False)
    g1 = model(g1)
    g2 = model(g2)
    model.set_is_batch_data(True)  # restore default batch mode
    g1_ref_e = torch.tensor([-3.779199])
    g1_ref_f = torch.tensor(
        [
            [12.666697, 0.04726403, 0.04775861],
            [-12.666697, -0.04726403, -0.04775861],
        ]
    )
    g1_ref_s = -1 * torch.tensor(
        # xx, yy, zz, xy, yz, zx
        [-0.6439122, -0.03643947, -0.03643981, 0.04543639, 0.00599139, 0.04544507]
    )
    g2_ref_e = torch.tensor([-12.782808303833008])
    g2_ref_f = torch.tensor(
        [
            [0.0, -1.3619621e01, 7.5937047e00],
            [0.0, 9.3918495e00, -1.0172190e01],
            [0.0, 4.2277718e00, 2.5784855e00],
        ]
    )
    assert acl(g1.inferred_total_energy, g1_ref_e)
    assert acl(g1.inferred_force, g1_ref_f)
    assert acl(g1.inferred_stress, g1_ref_s)
    assert acl(g2.inferred_total_energy, g2_ref_e)
    # No stress check for the molecule (no periodic cell).
    assert acl(g2.inferred_force, g2_ref_f)
def test_7net_l3i5(atoms_pbc, atoms_mol):
    """
    Reference from v0.9.3.post1 with SevenNetCalculator
    """
    cp_path = pretrained_name_to_path('7net-l3i5')
    model, config = model_from_checkpoint(cp_path)
    cutoff = config['cutoff']
    g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff))
    g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff))
    model.set_is_batch_data(False)
    g1 = model(g1)
    g2 = model(g2)
    model.set_is_batch_data(True)  # restore default batch mode
    g1_ref_e = torch.tensor([-3.611131191253662])
    g1_ref_f = torch.tensor(
        [
            [13.430887, 0.08655541, 0.08754013],
            [-13.430886, -0.08655544, -0.08754011],
        ]
    )
    g1_ref_s = -1 * torch.tensor(
        # xx, yy, zz, xy, yz, zx
        [-0.6818918, -0.04104544, -0.04107663, 0.04794561, 0.00565416, 0.04793138]
    )
    g2_ref_e = torch.tensor([-12.700481414794922])
    g2_ref_f = torch.tensor(
        [
            [0.0, -1.4547814e01, 8.1347866],
            [0.0, 1.0308369e01, -1.0880318e01],
            [0.0, 4.2394452, 2.7455316],
        ]
    )
    assert acl(g1.inferred_total_energy, g1_ref_e)
    # Looser tolerance (1e-5) for the periodic structure's force and stress.
    assert acl(g1.inferred_force, g1_ref_f, 1e-5)
    assert acl(g1.inferred_stress, g1_ref_s, 1e-5)
    assert acl(g2.inferred_total_energy, g2_ref_e)
    assert acl(g2.inferred_force, g2_ref_f)
def test_7net_mf_0(atoms_pbc, atoms_mol):
    """Regression references for 7net-mf-0 queried with the 'R2SCAN' modality.

    Reference values' provenance is not recorded here; presumably generated
    with an earlier SevenNet release, as in the tests above.
    """
    cp_path = pretrained_name_to_path('7net-mf-0')
    model, config = model_from_checkpoint(cp_path)
    cutoff = config['cutoff']
    g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff))
    g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff))
    g1[KEY.DATA_MODALITY] = 'R2SCAN'
    g2[KEY.DATA_MODALITY] = 'R2SCAN'
    model.set_is_batch_data(False)
    g1 = model(g1)
    g2 = model(g2)
    model.set_is_batch_data(True)  # restore default batch mode
    g1_ref_e = torch.tensor([-11.607587814331055])
    g1_ref_f = torch.tensor(
        [
            [8.512259, 0.07307914, 0.06676716],
            [-8.512257, -0.07307915, -0.06676716],
        ]
    )
    g1_ref_s = -1 * torch.tensor(
        # xx, yy, zz, xy, yz, zx
        [-0.4516204, -0.02483013, -0.02485001, 0.03247492, 0.00259375, 0.03250402]
    )
    g2_ref_e = torch.tensor([-14.172412872314453])
    g2_ref_f = torch.tensor(
        [
            [4.6566129e-10, -1.3429364e01, 6.9344816e00],
            [2.3283064e-09, 8.9132404e00, -9.6807365e00],
            [-2.7939677e-09, 4.5161238e00, 2.7462559e00],
        ]
    )
    assert acl(g1.inferred_total_energy, g1_ref_e)
    assert acl(g1.inferred_force, g1_ref_f)
    assert acl(g1.inferred_stress, g1_ref_s)
    assert acl(g2.inferred_total_energy, g2_ref_e)
    assert acl(g2.inferred_force, g2_ref_f)
def test_7net_mf_ompa_mpa(atoms_pbc, atoms_mol):
    """Regression references for 7net-mf-ompa queried with the 'mpa' modality."""
    cp_path = pretrained_name_to_path('7net-mf-ompa')
    model, config = model_from_checkpoint(cp_path)
    cutoff = config['cutoff']
    g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff))
    g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff))
    # mpa
    g1[KEY.DATA_MODALITY] = 'mpa'
    g2[KEY.DATA_MODALITY] = 'mpa'
    model.set_is_batch_data(False)
    g1 = model(g1)
    g2 = model(g2)
    model.set_is_batch_data(True)  # restore default batch mode
    g1_ref_e = torch.tensor([-3.490943193435669])
    g1_ref_f = torch.tensor(
        [
            [1.2680445e01, -2.7985498e-04, -2.7979910e-04],
            [-1.2680446e01, 2.7984008e-04, 2.7981028e-04],
        ]
    )
    g1_ref_s = -1 * torch.tensor(
        # xx, yy, zz, xy, yz, zx
        [-0.6481662, -0.02462837, -0.02462837, 0.02693467, 0.00459635, 0.02693467]
    )
    g2_ref_e = torch.tensor([-12.597525596618652])
    g2_ref_f = torch.tensor(
        [
            [0.0, -12.245223, 7.26795],
            [0.0, 8.816763, -9.423925],
            [0.0, 3.4284601, 2.1559749],
        ]
    )
    assert acl(g1.inferred_total_energy, g1_ref_e)
    assert acl(g1.inferred_force, g1_ref_f)
    assert acl(g1.inferred_stress, g1_ref_s)
    assert acl(g2.inferred_total_energy, g2_ref_e)
    assert acl(g2.inferred_force, g2_ref_f)
def test_7net_mf_ompa_omat(atoms_pbc, atoms_mol):
    """Regression references for 7net-mf-ompa queried with the 'omat24' modality."""
    cp_path = pretrained_name_to_path('7net-mf-ompa')
    model, config = model_from_checkpoint(cp_path)
    cutoff = config['cutoff']
    g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff))
    g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff))
    # omat24 (the original comment said 'mpa' — copy-paste leftover)
    g1[KEY.DATA_MODALITY] = 'omat24'
    g2[KEY.DATA_MODALITY] = 'omat24'
    model.set_is_batch_data(False)
    g1 = model(g1)
    g2 = model(g2)
    model.set_is_batch_data(True)  # restore default batch mode
    g1_ref_e = torch.tensor([-3.5094668865203857])
    g1_ref_f = torch.tensor(
        [
            [1.2562084e01, -1.4219694e-03, -1.4219843e-03],
            [-1.2562084e01, 1.4219508e-03, 1.4219955e-03],
        ]
    )
    g1_ref_s = -1 * torch.tensor(
        # xx, yy, zz, xy, yz, zx
        [-0.6430905, -0.0254128, -0.02541281, 0.0268343, 0.00460021, 0.0268343]
    )
    g2_ref_e = torch.tensor([-12.6202974319458])
    g2_ref_f = torch.tensor(
        [
            [0.0, -12.205926, 7.2050343],
            [0.0, 8.790399, -9.368677],
            [0.0, 3.4155273, 2.163643],
        ]
    )
    assert acl(g1.inferred_total_energy, g1_ref_e)
    assert acl(g1.inferred_force, g1_ref_f)
    assert acl(g1.inferred_stress, g1_ref_s)
    assert acl(g2.inferred_total_energy, g2_ref_e)
    assert acl(g2.inferred_force, g2_ref_f)
def test_7net_omat(atoms_pbc, atoms_mol):
    """Regression references for the single-modality 7net-omat model."""
    cp_path = pretrained_name_to_path('7net-omat')
    model, config = model_from_checkpoint(cp_path)
    cutoff = config['cutoff']
    g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff))
    g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff))
    model.set_is_batch_data(False)
    g1 = model(g1)
    g2 = model(g2)
    model.set_is_batch_data(True)  # restore default batch mode
    g1_ref_e = torch.tensor([-3.5033323764801025])
    g1_ref_f = torch.tensor(
        [
            [12.533154, 0.02358698, 0.02358694],
            [-12.533153, -0.02358699, -0.02358697],
        ]
    )
    g1_ref_s = -1 * torch.tensor(
        # xx, yy, zz, xy, yz, zx
        [-0.6420925, -0.02781446, -0.02781446, 0.02575445, 0.00381664, 0.02575445]
    )
    g2_ref_e = torch.tensor([-12.403768539428711])
    g2_ref_f = torch.tensor(
        [
            [0, -12.848297, 7.11432],
            [0.0, 9.265477, -9.564951],
            [0.0, 3.58282, 2.4506311],
        ]
    )
    assert acl(g1.inferred_total_energy, g1_ref_e)
    assert acl(g1.inferred_force, g1_ref_f)
    assert acl(g1.inferred_stress, g1_ref_s)
    assert acl(g2.inferred_total_energy, g2_ref_e)
    assert acl(g2.inferred_force, g2_ref_f)
mace-bench/3rdparty/SevenNet/tests/unit_tests/test_shift_scale.py
0 → 100644
View file @
b75ed73c
import
pytest
import
torch
import
sevenn._keys
as
KEY
from
sevenn._const
import
NUM_UNIV_ELEMENT
,
AtomGraphDataType
from
sevenn.nn.scale
import
(
ModalWiseRescale
,
Rescale
,
SpeciesWiseRescale
,
get_resolved_shift_scale
,
)
################################################################################
# Tests for Rescale #
################################################################################
@pytest.mark.parametrize('shift,scale', [(0.0, 1.0), (1.0, 2.0), (-5.0, 10.0)])
def test_rescale_init(shift, scale):
    """Rescale stores shift/scale and wires the default in/out data keys."""
    rescale = Rescale(shift=shift, scale=scale)

    assert rescale.shift.item() == shift
    assert rescale.scale.item() == scale
    assert rescale.key_input == KEY.SCALED_ATOMIC_ENERGY
    assert rescale.key_output == KEY.ATOMIC_ENERGY
def test_rescale_forward():
    """Forward pass computes ``input * scale + shift``."""
    shift, scale = 1.0, 2.0
    rescale = Rescale(shift=shift, scale=scale)

    values = torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.float)
    data: AtomGraphDataType = {KEY.SCALED_ATOMIC_ENERGY: values.clone()}

    result = rescale(data)

    assert torch.allclose(result[KEY.ATOMIC_ENERGY], values * scale + shift)
def test_rescale_get_shift_and_scale():
    """get_shift()/get_scale() round-trip the constructor arguments."""
    rescale = Rescale(shift=1.5, scale=3.5)

    assert rescale.get_shift() == pytest.approx(1.5)
    assert rescale.get_scale() == pytest.approx(3.5)
################################################################################
# Tests for SpeciesWiseRescale #
################################################################################
def test_specieswise_rescale_init_float():
    """A list shift paired with a scalar scale yields matching parameter shapes.

    The scalar is broadcast to the list's length; fully explicit per-species
    setup normally goes through ``from_mappers()`` instead.
    """
    module = SpeciesWiseRescale(shift=[1.0, -1.0], scale=2.0)
    assert module.shift.shape == module.scale.shape
def test_specieswise_rescale_init_list():
    """Same-length shift/scale lists are stored verbatim as parameters."""
    module = SpeciesWiseRescale(shift=[1.0, 2.0, 3.0], scale=[2.0, 3.0, 4.0])

    assert len(module.shift) == 3
    assert len(module.scale) == 3
    assert torch.allclose(module.shift, torch.tensor([1.0, 2.0, 3.0]))
    assert torch.allclose(module.scale, torch.tensor([2.0, 3.0, 4.0]))
def test_specieswise_rescale_forward():
    """Forward applies per-species scale/shift selected by the atom-type index."""
    # species 0: shift=1, scale=2;  species 1: shift=5, scale=10
    module = SpeciesWiseRescale(
        shift=[1.0, 5.0],
        scale=[2.0, 10.0],
        data_key_in='in',
        data_key_out='out',
        data_key_indices='z',
    )
    # Three atoms of species [0, 1, 0] with inputs [1, 1, 3].
    data: AtomGraphDataType = {
        'z': torch.tensor([0, 1, 0], dtype=torch.long),
        'in': torch.tensor([[1.0], [1.0], [3.0]], dtype=torch.float),
    }

    out = module(data)

    # atom 0: 1*2+1=3;  atom 1: 1*10+5=15;  atom 2: 3*2+1=7
    expected = torch.tensor([[3.0], [15.0], [7.0]])
    assert torch.allclose(out['out'], expected)
def test_specieswise_rescale_get_shift_scale():
    """get_shift()/get_scale() with and without an atomic-number type_map."""
    module = SpeciesWiseRescale(shift=[1.0, 2.0], scale=[3.0, 4.0])

    # Without a type_map the raw per-index lists come back.
    assert module.get_shift() == [1.0, 2.0]
    assert module.get_scale() == [3.0, 4.0]

    # With a type_map (H -> index 0, O -> index 1) the values are placed
    # into a universal list of length NUM_UNIV_ELEMENT, indexed by atomic
    # number; unmapped slots are padded with defaults.
    type_map = {1: 0, 8: 1}  # hydrogen, oxygen
    shift_univ = module.get_shift(type_map)
    scale_univ = module.get_scale(type_map)
    assert len(shift_univ) == NUM_UNIV_ELEMENT
    assert len(scale_univ) == NUM_UNIV_ELEMENT
    assert shift_univ[1] == 1.0  # atomic_number=1 -> idx=0 -> shift=1.0
    assert shift_univ[8] == 2.0
################################################################################
# Tests for ModalWiseRescale #
################################################################################
def test_modalwise_rescale_init():
    """Shift/scale given as [n_modal, n_species] keep that shape."""
    # 2 modals x 3 species
    module = ModalWiseRescale(
        shift=[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
        scale=[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
        use_modal_wise_shift=True,
        use_modal_wise_scale=True,
    )

    assert module.shift.shape == torch.Size([2, 3])
    assert module.scale.shape == torch.Size([2, 3])
def test_modalwise_rescale_forward():
    """Forward picks scale/shift by (structure modality, atom species)."""
    module = ModalWiseRescale(
        shift=[[0.0, 10.0], [5.0, 15.0]],  # shape [2 (modals), 2 (species)]
        scale=[[1.0, 2.0], [10.0, 20.0]],
        data_key_in='in',
        data_key_out='out',
        data_key_modal_indices='modal_idx',
        data_key_atom_indices='atom_idx',
        use_modal_wise_shift=True,
        use_modal_wise_scale=True,
    )
    # Two structures of modal 0 and 1, each with two atoms (species 0, 1).
    data: AtomGraphDataType = {
        'in': torch.tensor([[1.0], [1.0], [2.0], [2.0]]),
        'modal_idx': torch.tensor([0, 1], dtype=torch.long),
        'atom_idx': torch.tensor([0, 1, 0, 1], dtype=torch.long),
        'batch': torch.tensor([0, 0, 1, 1], dtype=torch.long),
    }

    out = module(data)

    # (modal, atom) per row: (0,0)->1*1+0=1, (0,1)->1*2+10=12,
    #                        (1,0)->2*10+5=25, (1,1)->2*20+15=55
    expected = torch.tensor([[1.0], [12.0], [25.0], [55.0]])
    assert torch.allclose(out['out'], expected)
def test_modalwise_rescale_get_shift_scale():
    """get_shift()/get_scale() return one entry per modal name."""
    module = ModalWiseRescale(
        shift=[[0.0, 10.0], [5.0, 15.0]],
        scale=[[1.0, 2.0], [10.0, 20.0]],
        use_modal_wise_shift=True,
        use_modal_wise_scale=True,
    )
    type_map = {1: 0, 8: 1}  # H -> 0, O -> 1
    modal_map = {'a': 0, 'b': 1}

    shifts = module.get_shift(type_map=type_map, modal_map=modal_map)
    scales = module.get_scale(type_map=type_map, modal_map=modal_map)

    # Dicts keyed by the modal names from modal_map ('a' and 'b').
    assert isinstance(shifts, dict) and isinstance(scales, dict)
    assert set(shifts.keys()) == {'a', 'b'}
    assert set(scales.keys()) == {'a', 'b'}
################################################################################
# Tests for get_resolved_shift_scale function #
################################################################################
def test_get_resolved_shift_scale_rescale():
    """A plain Rescale resolves to its scalar shift and scale."""
    shift, scale = get_resolved_shift_scale(Rescale(shift=2.0, scale=5.0))

    assert shift == 2.0
    assert scale == 5.0
def test_get_resolved_shift_scale_specieswise():
    """SpeciesWiseRescale resolves to universal lists indexed by atomic number."""
    shift_list = [1.0, 2.0]
    scale_list = [3.0, 4.0]
    module = SpeciesWiseRescale(shift=shift_list, scale=scale_list)

    s, sc = get_resolved_shift_scale(module, type_map={1: 0, 8: 1})

    assert isinstance(s, list)
    assert isinstance(sc, list)
    # H (Z=1) maps to parameter index 0, O (Z=8) to index 1.
    assert s[1] == shift_list[0]
    assert s[8] == shift_list[1]
    assert sc[1] == scale_list[0]
    assert sc[8] == scale_list[1]
def test_get_resolved_shift_scale_modalwise():
    """ModalWiseRescale resolves to {modal name: universal list} dicts."""
    module = ModalWiseRescale(
        shift=[[0.0, 10.0], [5.0, 15.0]],
        scale=[[1.0, 2.0], [10.0, 20.0]],
        use_modal_wise_shift=True,
        use_modal_wise_scale=True,
    )

    s, sc = get_resolved_shift_scale(
        module, type_map={1: 0, 8: 1}, modal_map={'a': 0, 'b': 1}
    )

    assert isinstance(s, dict) and isinstance(sc, dict)
    # Keyed by the modal names 'a' and 'b'.
    assert 'a' in s
    assert 'b' in s
    # Modal 'a' row, indexed by atomic number (H=1, O=8):
    assert s['a'][1] == 0.0
    assert s['a'][8] == 10.0
    assert sc['a'][1] == 1.0
    assert sc['a'][8] == 2.0
################################################################################
# Tests for from_mappers function #
################################################################################
@pytest.mark.parametrize(
    'shift, scale, type_map, expected_shift, expected_scale',
    [
        # Both shift and scale are floats -> broadcast to each species
        (
            2.0,
            3.0,
            {1: 0, 8: 1},  # e.g., H -> index 0, O -> index 1
            [2.0, 2.0],  # broadcast
            [3.0, 3.0],
        ),
        # shift, scale are same-length lists => directly used
        (
            [0.5, 0.6],
            [1.0, 1.1],
            {1: 0, 8: 1},
            [0.5, 0.6],
            [1.0, 1.1],
        ),
        # shift, scale are entire "universal" length (NUM_UNIV_ELEMENT=118),
        # but we only map out the subset for the actual species in type_map
        (
            [0.1] * NUM_UNIV_ELEMENT,
            [1.1] * NUM_UNIV_ELEMENT,
            {1: 0, 8: 1},
            [0.1, 0.1],
            [1.1, 1.1],
        ),
        # shift is a list, scale is float => shift is used directly, scale broadcast
        (
            [1.0, 2.0],
            5.0,
            {6: 0, 14: 1},  # C -> 0, Si -> 1
            [1.0, 2.0],
            [5.0, 5.0],
        ),
    ],
)
def test_specieswise_rescale_from_mappers(
    shift, scale, type_map, expected_shift, expected_scale
):
    """
    Test SpeciesWiseRescale.from_mappers with various combinations of
    shift/scale (float, list, universal list) and a given type_map.
    """
    module = SpeciesWiseRescale.from_mappers(  # type: ignore
        shift=shift,
        scale=scale,
        type_map=type_map,
    )
    # Check that the module's internal shift and scale have the correct shape
    # The length must match number of species in type_map
    assert module.shift.shape[0] == len(type_map)
    assert module.scale.shape[0] == len(type_map)
    # Check that the content matches expected
    actual_shift = module.shift.detach().cpu().tolist()
    actual_scale = module.scale.detach().cpu().tolist()
    assert pytest.approx(actual_shift) == expected_shift
    assert pytest.approx(actual_scale) == expected_scale
@pytest.mark.parametrize(
    'shift, scale, use_modal_wise_shift, use_modal_wise_scale, '
    'type_map, modal_map, expected_shift, expected_scale',
    [
        # 1: single floats, modal-wise => broadcast over 2 modals x 2 species
        (
            1.0,
            2.0,
            True,
            True,
            {1: 0, 8: 1},
            {'modA': 0, 'modB': 1},
            [[1.0, 1.0], [1.0, 1.0]],
            [[2.0, 2.0], [2.0, 2.0]],
        ),
        # 2: universal element-lists, not modal-wise => 1D over species
        (
            [0.5] * NUM_UNIV_ELEMENT,
            [1.5] * NUM_UNIV_ELEMENT,
            False,
            False,
            {6: 0, 14: 1},  # C -> 0, Si -> 1
            {'modA': 0, 'modB': 1},
            [0.5, 0.5],
            [1.5, 1.5],
        ),
        # 3: dict of modal -> float => broadcast per species within each modal
        (
            {'modA': 0.0, 'modB': 2.0},
            {'modA': 1.0, 'modB': 3.0},
            True,
            True,
            {1: 0, 8: 1},
            {'modA': 0, 'modB': 1},
            [[0.0, 0.0], [2.0, 2.0]],
            [[1.0, 1.0], [3.0, 3.0]],
        ),
        # 4: already modal-wise + species-wise shaped => passed through as-is
        (
            [[0.0, 10.0], [5.0, 15.0]],
            [[1.0, 2.0], [10.0, 20.0]],
            True,
            True,
            {1: 0, 8: 1},
            {'modA': 0, 'modB': 1},
            [[0.0, 10.0], [5.0, 15.0]],
            [[1.0, 2.0], [10.0, 20.0]],
        ),
        # 5: one float per modal, modal-wise => broadcast over species
        (
            [0.0, 10.0],
            [1.0, 2.0],
            True,
            True,
            {1: 0, 8: 1},
            {'modA': 0, 'modB': 1},
            [[0.0, 0.0], [10.0, 10.0]],
            [[1.0, 1.0], [2.0, 2.0]],
        ),
    ],
)
def test_modalwise_rescale_from_mappers(
    shift,
    scale,
    use_modal_wise_shift,
    use_modal_wise_scale,
    type_map,
    modal_map,
    expected_shift,
    expected_scale,
):
    """
    Test ModalWiseRescale.from_mappers for different shapes of shift/scale,
    combined with type_map and modal_map.
    """
    module = ModalWiseRescale.from_mappers(  # type: ignore
        shift=shift,
        scale=scale,
        use_modal_wise_shift=use_modal_wise_shift,
        use_modal_wise_scale=use_modal_wise_scale,
        type_map=type_map,
        modal_map=modal_map,
    )

    def assert_shape(param, modal_wise):
        # Modal-wise parameters are 2D [n_modals, n_species];
        # otherwise they are 1D [n_species].
        if modal_wise:
            assert param.dim() == 2
            assert param.shape[0] == len(modal_map)
            assert param.shape[1] == len(type_map)
        else:
            assert param.dim() == 1
            assert param.shape[0] == len(type_map)

    assert_shape(module.shift, use_modal_wise_shift)
    assert_shape(module.scale, use_modal_wise_scale)

    # Verify the content matches our expectation.
    assert module.shift.detach().cpu().tolist() == expected_shift
    assert module.scale.detach().cpu().tolist() == expected_scale
mace-bench/3rdparty/SevenNet/tests/unit_tests/test_train.py
0 → 100644
View file @
b75ed73c
import
pathlib
import
ase.io
import
numpy
as
np
import
pytest
import
torch
from
torch_geometric.loader
import
DataLoader
import
sevenn.train.graph_dataset
as
graph_ds
from
sevenn._const
import
NUM_UNIV_ELEMENT
from
sevenn.error_recorder
import
ErrorRecorder
from
sevenn.logger
import
Logger
from
sevenn.scripts.processing_continue
import
processing_continue_v2
from
sevenn.scripts.processing_epoch
import
processing_epoch_v2
from
sevenn.train.dataload
import
graph_build
from
sevenn.train.graph_dataset
import
from_config
as
dataset_from_config
from
sevenn.train.loss
import
get_loss_functions_from_config
from
sevenn.train.trainer
import
Trainer
from
sevenn.util
import
(
chemical_species_preprocess
,
get_error_recorder
,
pretrained_name_to_path
,
)
# Neighbor cutoff used for every graph build in this module.
cutoff = 4.0
# <tests>/data directory, resolved relative to this file.
data_root = (pathlib.Path(__file__).parent.parent / 'data').resolve()
# Small HfO2 trajectory used as the training data fixture.
hfo2_path = str(data_root / 'systems' / 'hfo2.extxyz')
# Pre-generated "empty" checkpoint (see _write_empty_checkpoint below).
cp_0_path = str(data_root / 'checkpoints' / 'cp_0.pth')
# Path to the released 7net-0 (11 July 2024) pretrained checkpoint.
sevennet_0_path = pretrained_name_to_path('7net-0_11July2024')
known_elements = ['Hf', 'O']
# Reference energies keyed by atomic number (72 = Hf, 8 = O); used by the
# 'elemwise_reference_energies' shift case below.
_elemwise_ref_energy_dct = {72: -17.379337, 8: -34.7499924}
Logger()  # init
@pytest.fixture()
def HfO2_atoms():
    """First HfO2 structure read from the test extxyz file."""
    return ase.io.read(hfo2_path)
@pytest.fixture(scope='module')
def HfO2_loader():
    """DataLoader (batch_size=2) over graphs built from all HfO2 frames."""
    frames = ase.io.read(hfo2_path, index=':')
    assert isinstance(frames, list)
    built = graph_build(frames, cutoff, y_from_calc=True)
    return DataLoader(built, batch_size=2)
@pytest.fixture(scope='module')
def graph_dataset_path(tmp_path_factory):
    """Path of a processed SevenNetGraphDataset built from hfo2.extxyz."""
    dataset_root = tmp_path_factory.mktemp('gd')
    dataset = graph_ds.SevenNetGraphDataset(
        cutoff=cutoff,
        root=str(dataset_root),
        files=[hfo2_path],
        processed_name='tmp.pt',
    )
    return dataset.processed_paths[0]
def get_model_config():
    """Minimal model hyperparameter section shared by the tests here."""
    base = dict(
        cutoff=cutoff,
        channel=4,
        radial_basis={'radial_basis_name': 'bessel'},
        cutoff_function={'cutoff_function_name': 'poly_cut'},
        interaction_type='nequip',
        lmax=2,
        is_parity=True,
        num_convolution_layer=3,
        weight_nn_hidden_neurons=[64, 64],
        act_radial='silu',
        act_scalar={'e': 'silu', 'o': 'tanh'},
        act_gate={'e': 'silu', 'o': 'tanh'},
        conv_denominator='avg_num_neigh',
        train_denominator=False,
        self_connection_type='nequip',
        train_shift_scale=False,
        irreps_manual=False,
        lmax_edge=-1,
        lmax_node=-1,
        readout_as_fcn=False,
        use_bias_in_linear=False,
        _normalize_sph=True,
    )
    # Add chemical-species-derived keys (type map etc.) for Hf and O.
    return {**base, **chemical_species_preprocess(known_elements)}
def get_train_config():
    """Training section of the test config (optimizer, loss, epoch setup)."""
    return {
        'random_seed': 1,
        'epoch': 2,
        # Plain MSE loss with weighted force/stress contributions.
        'loss': 'mse',
        'loss_param': {},
        'optimizer': 'adam',
        'optim_param': {},
        'scheduler': 'exponentiallr',
        'scheduler_param': {'gamma': 0.99},
        'force_loss_weight': 1.0,
        'stress_loss_weight': 0.1,
        'per_epoch': 1,
        # Continuation disabled by default; individual tests override this.
        'continue': {
            'checkpoint': False,
            'reset_optimizer': False,
            'reset_scheduler': False,
            'reset_epoch': False,
        },
        'is_train_stress': True,
        'train_shuffle': True,
        'best_metric': 'TotalLoss',
        'error_record': [
            ('Energy', 'RMSE'),
            ('Force', 'RMSE'),
            ('Stress', 'RMSE'),
            ('TotalLoss', 'None'),
        ],
        'use_modality': False,
        'use_weight': False,
        'device': 'cpu',
        'is_ddp': False,
    }
def get_data_config():
    """Data section of the test config (loader size, statistics, input path)."""
    return {
        'batch_size': 2,
        # Statistics keywords resolved during dataset preprocessing.
        'shift': 'per_atom_energy_mean',
        'scale': 'force_rms',
        'preprocess_num_cores': 1,
        'data_format_args': {},
        'load_trainset_path': hfo2_path,
    }
def get_config(overwrite=None):
    """Assemble the full test config (model + train + data sections).

    ``overwrite``, when truthy, is a dict applied on top of the defaults.
    """
    merged = {
        **get_model_config(),
        **get_train_config(),
        **get_data_config(),
    }
    if overwrite:
        merged.update(overwrite)
    return merged
def test_processing_continue_v2_7net0(tmp_path):
    """Continue from the pretrained 7net-0 checkpoint: the config must inherit
    shift/scale/conv_denominator and the species mapping from the checkpoint,
    and reset_scheduler=True must drop the scheduler state dict."""
    checkpoint = torch.load(
        sevennet_0_path, weights_only=False, map_location='cpu'
    )
    cfg = get_config(
        {
            'continue': {
                'checkpoint': sevennet_0_path,
                'reset_optimizer': False,
                'reset_scheduler': True,
                'reset_epoch': False,
            }
        }
    )
    expected_shift = (
        checkpoint['model_state_dict']['rescale_atomic_energy.shift']
        .cpu()
        .numpy()
    )

    with Logger().switch_file(str(tmp_path / 'log.sevenn')):
        state_dicts, epoch = processing_continue_v2(cfg)

    assert epoch == 601
    assert np.allclose(np.array(cfg['shift']), expected_shift)
    assert np.allclose(np.array(cfg['shift'])[0], -5.062768)
    assert np.allclose(np.array(cfg['scale']), np.array([1.73] * 89))
    assert np.allclose(
        np.array(cfg['conv_denominator']), np.array([35.989574] * 5)
    )
    assert cfg['_number_of_species'] == 89
    assert cfg['_type_map'][89] == 0  # Ac
    assert cfg['_type_map'][40] == 88  # Zr
    assert state_dicts[2] is None  # scheduler reset
@pytest.mark.parametrize(
    'cfg_overwrite,ds_names',
    [
        ({}, ['trainset']),
        ({'load_myset_path': hfo2_path}, ['trainset', 'myset']),
    ],
)
def test_dataset_from_config(cfg_overwrite, ds_names, tmp_path):
    """Each load_{name}set_path entry produces a dataset entry plus processed
    .pt/.yaml files under {working_dir}/sevenn_data."""
    cfg = get_config(cfg_overwrite)
    with Logger().switch_file(str(tmp_path / 'log.sevenn')):
        datasets = dataset_from_config(cfg, tmp_path)

    assert set(datasets.keys()) == set(ds_names)
    out_dir = tmp_path / 'sevenn_data'
    for name in ds_names:
        assert (out_dir / f'{name}.pt').is_file()
        assert (out_dir / f'{name}.yaml').is_file()
def test_dataset_from_config_as_it_is_load(graph_dataset_path, tmp_path):
    """Loading an already-processed graph dataset must not re-process it:
    no sevenn_data directory should be created in the new working dir.

    Fix: removed a leftover debug ``print`` of the checked path.
    """
    cfg = get_config({'load_trainset_path': graph_dataset_path})
    new_wd = tmp_path / 'tmp_wd'
    with Logger().switch_file(str(tmp_path / 'log.sevenn')):
        _ = dataset_from_config(cfg, str(new_wd))
    # Nothing should have been written under the new working directory.
    assert not (tmp_path / 'tmp_wd' / 'sevenn_data').is_dir()
@pytest.mark.parametrize(
    'cfg_overwrite,shift,scale,conv',
    [
        # Defaults: per_atom_energy_mean / force_rms / avg_num_neigh.
        ({}, -28.978, 0.113304, 25.333333),
        # Explicit float shift is kept as-is.
        ({'shift': -1.2345678}, -1.234567, 0.113304, 25.333333),
        # sqrt_avg_num_neigh is the square root of the neighbor average.
        (
            {'conv_denominator': 'sqrt_avg_num_neigh'},
            -28.978,
            0.113304,
            25.333333 ** 0.5,
        ),
        # shift may reuse the force_rms statistic.
        ({'shift': 'force_rms'}, 0.113304, 0.113304, 25.333333),
        # Element-wise reference energies: zero except for known elements.
        (
            {'shift': 'elemwise_reference_energies'},
            [
                _elemwise_ref_energy_dct.get(z, 0.0)
                for z in range(NUM_UNIV_ELEMENT)
            ],
            0.113304,
            25.333333,
        ),
    ],
)
def test_dataset_from_config_statistics_init(
    cfg_overwrite, shift, scale, conv, tmp_path
):
    """Dataset statistics must resolve shift/scale/conv_denominator keywords
    into concrete values written back into the config."""
    cfg = get_config(cfg_overwrite)
    with Logger().switch_file(str(tmp_path / 'log.sevenn')):
        _ = dataset_from_config(cfg, tmp_path)
    assert np.allclose(cfg['shift'], shift)
    assert np.allclose(cfg['scale'], scale)
    assert np.allclose(cfg['conv_denominator'], conv)
def test_dataset_from_config_chem_auto(tmp_path):
    """'auto' chemical-species settings are inferred from the dataset."""
    auto_overrides = {
        'chemical_species': 'auto',
        '_number_of_species': 'auto',
        '_type_map': 'auto',
    }
    cfg = get_config(auto_overrides)
    with Logger().switch_file(str(tmp_path / 'log.sevenn')):
        _ = dataset_from_config(cfg, tmp_path)

    assert cfg['chemical_species'] == ['Hf', 'O']
    assert cfg['_number_of_species'] == 2
    assert cfg['_type_map'] == {72: 0, 8: 1}
def test_run_one_epoch(HfO2_loader):
    """Regression-pin recorder metrics before and after one training epoch,
    starting from the cp_0 checkpoint."""
    trainer_args, _, _ = Trainer.args_from_checkpoint(cp_0_path)
    trainer = Trainer(**trainer_args)
    recorder = get_error_recorder()

    def evaluate():
        # One validation pass; grab the metrics, then advance the recorder.
        trainer.run_one_epoch(
            HfO2_loader, is_train=False, error_recorder=recorder
        )
        metrics = recorder.get_dct()
        recorder.epoch_forward()
        return metrics

    before = evaluate()
    for key, val in {
        'Energy_RMSE': '28.977758',
        'Force_RMSE': '0.214107',
        'Stress_RMSE': '190.014237',
    }.items():
        assert np.allclose(float(before[key]), float(val))

    # One epoch of actual training in between.
    trainer.run_one_epoch(HfO2_loader, is_train=True, error_recorder=recorder)
    recorder.epoch_forward()

    after = evaluate()
    for key, val in {
        'Energy_RMSE': '28.977878',
        'Force_RMSE': '0.213105',
        'Stress_RMSE': '188.772557',
    }.items():
        assert np.allclose(float(after[key]), float(val))
def test_processing_epoch_v2(HfO2_loader, tmp_path):
    """Run processing_epoch_v2 for epochs 10..12 and verify the artifacts:
    per-epoch checkpoints, best checkpoint, and the lc.csv learning curve."""
    trainer_args, _, _ = Trainer.args_from_checkpoint(cp_0_path)
    trainer = Trainer(**trainer_args)
    recorder = get_error_recorder()
    loaders = {'trainset': HfO2_loader, 'myset': HfO2_loader}

    with Logger().switch_file(str(tmp_path / 'log.sevenn')):
        processing_epoch_v2(
            config={},
            trainer=trainer,
            loaders=loaders,
            start_epoch=10,
            error_recorder=recorder,
            total_epoch=12,
            per_epoch=1,
            best_metric_loader_key='myset',
            best_metric='Energy_RMSE',
            working_dir=tmp_path,
        )

    # One checkpoint per epoch, a best checkpoint, and a learning curve file.
    for artifact in (
        'checkpoint_10.pth',
        'checkpoint_11.pth',
        'checkpoint_12.pth',
        'checkpoint_best.pth',
        'lc.csv',
    ):
        assert (tmp_path / artifact).is_file()

    with open(tmp_path / 'lc.csv', 'r') as f:
        lines = f.readlines()
    header = [col.strip() for col in lines[0].split(',')]
    for expected_col in (
        'epoch',
        'lr',
        'trainset_Energy_RMSE',
        'myset_Stress_MAE',
    ):
        assert expected_col in header

    final_row = [col.strip() for col in lines[-1].split(',')]
    assert final_row[0] == '12'
    assert final_row[1] == '0.000980'  # lr
    assert final_row[-2] == '0.087873'  # myset Force MAE
def test_data_weight(graph_dataset_path, tmp_path):
    """Per-dataset data_weight factors must scale each loss component."""
    weights = {'energy': 0.1, 'force': 3.0, 'stress': 1.0}
    cfg = get_config(
        {
            'load_trainset_path': [
                {
                    'file_list': [{'file': graph_dataset_path}],
                    'data_weight': weights,
                }
            ],
            'error_record': [
                ('Energy', 'Loss'),
                ('Force', 'Loss'),
                ('Stress', 'Loss'),
                ('TotalLoss', 'None'),
            ],
            'use_weight': True,
        }
    )
    trainer_args, _, _ = Trainer.args_from_checkpoint(cp_0_path)
    trainer_args['loss_functions'] = get_loss_functions_from_config(cfg)
    trainer = Trainer(**trainer_args)
    recorder = ErrorRecorder.from_config(cfg, trainer.loss_functions)

    dataset = graph_ds.from_config(cfg, working_dir=tmp_path)['trainset']
    weighted_loader = DataLoader(dataset, batch_size=len(dataset))
    trainer.run_one_epoch(weighted_loader, False, recorder)
    loss = recorder.epoch_forward()

    # Unweighted reference losses multiplied by the configured weights.
    assert np.allclose(loss['Energy_Loss'], 839.7104492 * 0.1)
    assert np.allclose(loss['Force_Loss'], 0.0152806 * 3.0)
    assert np.allclose(loss['Stress_Loss'], 6017.568847 * 1.0)
def _write_empty_checkpoint():
    """Helper (not a test) originally used to generate the cp_0.pth fixture."""
    from sevenn.model_build import build_E3_equivariant_model

    # Fixed shift/scale/denominator so the checkpoint is reproducible.
    cfg = get_config({'shift': 0.0, 'scale': 1.0, 'conv_denominator': 5.0})
    model = build_E3_equivariant_model(cfg)
    trainer = Trainer.from_config(model, cfg)  # type: ignore
    trainer.write_checkpoint('./cp_0.pth', config=cfg, epoch=0)
if __name__ == '__main__':
    # Only runs when executed as a script: regenerates the cp_0.pth fixture.
    _write_empty_checkpoint()
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment