Unverified Commit b75ed73c authored by zcxzcx1's avatar zcxzcx1 Committed by GitHub
Browse files

Add files via upload

parent 56d3c363
# LAMMPS single-point skeleton; __TOKEN__ placeholders are substituted by the
# test harness before the script is run (see _run_lammps in the lammps tests).
units metal
boundary __BOUNDARY__
# accept strongly tilted (triclinic) cells produced by ASE's Prism transform
box tilt large
read_data __LMP_STCT__
#replicate __REPLICATE__
mass * 1.0 # do not matter since we don't run MD
pair_style __PAIR_STYLE__
pair_coeff * * __POTENTIALS__ __ELEMENT__
timestep 0.002
# per-atom potential energy, dumped as c_pa for comparison against ASE
compute pa all pe/atom
thermo 1
fix 1 all nve
thermo_style custom step tpcpu pe ke vol pxx pyy pzz pxy pxz pyz press temp
# note: dump path is hard-coded here; the other skeleton uses __FORCE_DUMP_PATH__
dump mydump all custom 1 force.dump id type element c_pa x y z fx fy fz
dump_modify mydump sort id element __ELEMENT__
# run 0 -> single energy/force evaluation, no time integration
run 0
# LAMMPS single-point skeleton; __TOKEN__ placeholders are substituted by the
# test harness before the script is run (see _run_lammps in the lammps tests).
units metal
boundary __BOUNDARY__
read_data __LMP_STCT__
mass * 1.0 # do not matter since we don't run MD
pair_style __PAIR_STYLE__
pair_coeff * * __POTENTIALS__ __ELEMENT__
timestep 0.002
# per-atom potential energy, dumped as c_pa for comparison against ASE
compute pa all pe/atom
thermo 1
fix 1 all nve
thermo_style custom step tpcpu pe ke vol pxx pyy pzz pxy pxz pyz press temp
dump mydump all custom 1 __FORCE_DUMP_PATH__ id type element c_pa x y z fx fy fz
dump_modify mydump sort id element __ELEMENT__
# run 0 -> single energy/force evaluation, no time integration
run 0
import copy
import logging
import pathlib
import subprocess
import ase.calculators.lammps
import ase.io.lammpsdata
import numpy as np
import pytest
import torch
from ase.build import bulk, surface
from ase.calculators.singlepoint import SinglePointCalculator
import sevenn
from sevenn.calculator import SevenNetCalculator
from sevenn.model_build import build_E3_equivariant_model
from sevenn.nn.cue_helper import is_cue_available
from sevenn.scripts.deploy import deploy, deploy_parallel
from sevenn.util import chemical_species_preprocess, pretrained_name_to_path
# module logger; receives stdout/stderr of the spawned LAMMPS processes
logger = logging.getLogger('test_lammps')
# cutoff radius (Angstrom) used for the toy model config below
cutoff = 4.0
# LAMMPS input template; __TOKEN__ placeholders are filled per run
lmp_script_path = str(
    (pathlib.Path(__file__).parent / 'scripts' / 'skel.lmp').resolve()
)
data_root = (pathlib.Path(__file__).parent.parent / 'data').resolve()
cp_0_path = str(data_root / 'checkpoints' / 'cp_0.pth')  # knows Hf, O
# pretrained multi-fidelity (modal) checkpoint
cp_mf_path = pretrained_name_to_path('7net-mf-0')
@pytest.fixture(scope='module')
def serial_potential_path(tmp_path_factory):
    """Deploy cp_0 as a serial TorchScript potential; return its file path."""
    out_dir = tmp_path_factory.mktemp('serial_potential')
    deployed = out_dir / 'deployed_serial.pt'
    deploy(cp_0_path, str(deployed))
    return str(deployed)
@pytest.fixture(scope='module')
def parallel_potential_path(tmp_path_factory):
    """Deploy cp_0 as parallel segments; return the '<nlayers> <path>' string."""
    out_dir = tmp_path_factory.mktemp('paralllel_potential')
    deployed = str(out_dir / 'deployed_parallel')
    deploy_parallel(cp_0_path, deployed)
    # '3' is the number of convolution layers (cf. test_cueq_parallel below)
    return f'3 {deployed}'
@pytest.fixture(scope='module')
def serial_modal_potential_path(tmp_path_factory):
    """Deploy the multi-fidelity checkpoint ('PBE' modality) as a serial potential."""
    out_dir = tmp_path_factory.mktemp('serial_modal_potential')
    deployed = out_dir / 'deployed_serial.pt'
    deploy(cp_mf_path, str(deployed), 'PBE')
    return str(deployed)
@pytest.fixture(scope='module')
def parallel_modal_potential_path(tmp_path_factory):
    """Deploy the modal checkpoint ('PBE') in parallel; return '<nlayers> <path>'."""
    out_dir = tmp_path_factory.mktemp('paralllel_modal_potential')
    deployed = str(out_dir / 'deployed_parallel')
    deploy_parallel(cp_mf_path, deployed, 'PBE')
    # '5' is the number of convolution layers of the modal model
    return f'5 {deployed}'
@pytest.fixture(scope='module')
def ref_calculator():
    """ASE-side reference calculator for the cp_0 checkpoint."""
    calc = SevenNetCalculator(cp_0_path)
    return calc
@pytest.fixture(scope='module')
def ref_modal_calculator():
    """ASE-side reference calculator for the modal checkpoint ('PBE' modality)."""
    calc = SevenNetCalculator(cp_mf_path, modal='PBE')
    return calc
def get_model_config():
    """Small Hf/O model config (channel 8, 3 layers) used for deploy tests."""
    cfg = dict(
        cutoff=cutoff,
        channel=8,
        lmax=2,
        is_parity=True,
        num_convolution_layer=3,
        self_connection_type='linear',  # not NequIP-style self connection
        interaction_type='nequip',
        radial_basis={'radial_basis_name': 'bessel'},
        cutoff_function={'cutoff_function_name': 'poly_cut'},
        weight_nn_hidden_neurons=[64, 64],
        act_radial='silu',
        act_scalar={'e': 'silu', 'o': 'tanh'},
        act_gate={'e': 'silu', 'o': 'tanh'},
        conv_denominator=30.0,
        train_denominator=False,
        shift=-10.0,
        scale=10.0,
        train_shift_scale=False,
        irreps_manual=False,
        lmax_edge=-1,
        lmax_node=-1,
        readout_as_fcn=False,
        use_bias_in_linear=False,
        _normalize_sph=True,
    )
    cfg.update(chemical_species_preprocess(['Hf', 'O']))
    return cfg
def get_model(config_overwrite=None, use_cueq=False, cueq_config=None):
    """Build a non-parallel E3 equivariant model from the base test config."""
    cfg = get_model_config()
    if config_overwrite is not None:
        cfg.update(config_overwrite)
    # explicit cueq_config wins over the simple use_cueq flag
    cfg.update(cueq_config or {'cuequivariance_config': {'use': use_cueq}})
    model = build_E3_equivariant_model(cfg, parallel=False)
    assert not isinstance(model, list)
    return model
def hfo2_bulk(replicate=(2, 2, 2), a=4.0):
    """Rattled orthorhombic rocksalt HfO supercell."""
    cell = bulk('HfO', 'rocksalt', a, orthorhombic=True) * replicate
    cell.rattle(stdev=0.10)
    return cell
def hf_surface(replicate=(3, 3, 1), layers=4, vacuum=0.5):
    """Rattled Hf (1 0 0) slab, built from an Al template with Z swapped to 72."""
    slab = surface('Al', (1, 0, 0), layers=layers, vacuum=vacuum)
    slab.set_atomic_numbers([72] * len(slab))  # Hf
    slab = slab * replicate
    slab.rattle(stdev=0.10)
    return slab
def get_system(system_name, **kwargs):
    """Build a test structure by name.

    Args:
        system_name: 'bulk' (HfO rocksalt) or 'surface' (Hf slab).
        **kwargs: forwarded to the corresponding builder.

    Raises:
        ValueError: if ``system_name`` is not a known system
            (now with an informative message instead of a bare ValueError()).
    """
    if system_name == 'bulk':
        return hfo2_bulk(**kwargs)
    if system_name == 'surface':
        return hf_surface(**kwargs)
    raise ValueError(f'unknown system_name: {system_name!r}')
def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6):
    """Assert two Atoms agree on size, cell, energy, forces and full stress.

    Forces and stress are compared with 10x looser tolerances than energy.
    """
    assert len(atoms1) == len(atoms2)
    assert np.allclose(atoms1.get_cell(), atoms2.get_cell(), rtol=rtol, atol=atol)
    assert np.allclose(
        atoms1.get_potential_energy(),
        atoms2.get_potential_energy(),
        rtol=rtol,
        atol=atol,
    )
    assert np.allclose(
        atoms1.get_forces(), atoms2.get_forces(), rtol=rtol * 10, atol=atol * 10
    )
    assert np.allclose(
        atoms1.get_stress(voigt=False),
        atoms2.get_stress(voigt=False),
        rtol=rtol * 10,
        atol=atol * 10,
    )
    # per-atom energies intentionally not compared here
    # assert acl(atoms1.get_potential_energies(), atoms2.get_potential_energies())
def _lammps_results_to_atoms(lammps_log, force_dump):
    """Parse a LAMMPS log + force dump into an Atoms with results attached.

    Reads the thermo table printed right after the 'Per MPI rank memory
    allocation' line, then attaches total energy, per-atom energies and the
    3x3 stress tensor (converted to eV/A^3) to the dumped structure.
    """
    with open(lammps_log, 'r') as f:
        lines = f.readlines()
    lmp_log = None
    for i, line in enumerate(lines):
        if not line.startswith('Per MPI rank memory allocation'):
            continue
        # next two lines are the thermo header and its single data row;
        # eval() converts each token to int/float (trusted, locally made log)
        lmp_log = {
            k: eval(v) for k, v in zip(lines[i + 1].split(), lines[i + 2].split())
        }
        break
    assert lmp_log is not None and 'PotEng' in lmp_log
    latoms_list = ase.io.read(force_dump, format='lammps-dump-text', index=':')
    assert isinstance(latoms_list, list)
    # run 0 -> a single snapshot
    latoms = latoms_list[0]
    assert latoms.calc is not None
    latoms.calc.results['energy'] = lmp_log['PotEng']
    latoms.calc.results['free_energy'] = lmp_log['PotEng']
    latoms.info = {
        'data_from': 'lammps',
        'lmp_log': lmp_log,
        'lmp_dump': force_dump,
    }
    # atomic energy read (c_pa column of the dump, see skeleton script)
    latoms.calc.results['energies'] = latoms.arrays['c_pa'][:, 0]
    # assemble the symmetric 3x3 stress tensor from the six thermo components
    stress = np.array(
        [
            [lmp_log['Pxx'], lmp_log['Pxy'], lmp_log['Pxz']],
            [lmp_log['Pxy'], lmp_log['Pyy'], lmp_log['Pyz']],
            [lmp_log['Pxz'], lmp_log['Pyz'], lmp_log['Pzz']],
        ]
    )
    stress = -1 * stress / 1602.1766208 / 1000  # convert bars to eV/A^3
    latoms.calc.results['stress'] = stress
    return latoms
def _run_lammps(atoms, pair_style, potential, wd, command, test_name):
    """Write a LAMMPS input from the skeleton, run it, return parsed results.

    Writes *atoms* as a LAMMPS data file (via ASE's Prism transform), fills
    the __TOKEN__ placeholders of the skeleton script, executes *command*,
    and rotates the resulting cell/forces/stress from the LAMMPS frame back
    into the original ASE frame.
    """
    wd = wd.resolve()
    pbc = atoms.get_pbc()
    # periodic -> 'p', free -> 'f' for the LAMMPS boundary command
    pbc_str = ' '.join(['p' if x else 'f' for x in pbc])
    chem = list(set(atoms.get_chemical_symbols()))
    # The way ASE maps a general cell onto LAMMPS' lower-triangular convention
    prism = ase.calculators.lammps.coordinatetransform.Prism(
        atoms.get_cell(), pbc=pbc
    )
    lmp_stct = wd / 'lammps_structure'
    ase.io.lammpsdata.write_lammps_data(
        lmp_stct, atoms, prismobj=prism, specorder=chem
    )
    with open(lmp_script_path, 'r') as f:
        cont = f.read()
    lammps_log = str(wd / 'log.lammps')
    force_dump = str(wd / 'force.dump')
    # substitute the __TOKEN__ placeholders of the skeleton script
    var_dct = {}
    var_dct['__ELEMENT__'] = ' '.join(chem)
    var_dct['__LMP_STCT__'] = str(lmp_stct.resolve())
    var_dct['__PAIR_STYLE__'] = pair_style
    var_dct['__POTENTIALS__'] = potential
    var_dct['__BOUNDARY__'] = pbc_str
    var_dct['__FORCE_DUMP_PATH__'] = force_dump
    for key, val in var_dct.items():
        cont = cont.replace(key, val)
    input_script_path = str(wd / 'in.lmp')
    with open(input_script_path, 'w') as f:
        f.write(cont)
    command = f'{command} -in {input_script_path} -log {lammps_log}'
    subprocess_routine(command.split(), test_name)
    lmp_atoms = _lammps_results_to_atoms(lammps_log, force_dump)
    assert lmp_atoms.calc is not None
    # rotate results from the LAMMPS (prism) frame back to the ASE frame
    rot_mat = prism.rot_mat
    results = copy.deepcopy(lmp_atoms.calc.results)
    r_force = np.dot(results['forces'], rot_mat.T)
    results['forces'] = r_force
    if 'stress' in results:
        # see ase.calculators.lammpsrun.py
        stress_tensor = results['stress']
        stress_atoms = np.dot(np.dot(rot_mat, stress_tensor), rot_mat.T)
        results['stress'] = stress_atoms
    r_cell = lmp_atoms.get_cell() @ rot_mat.T
    lmp_atoms.set_cell(r_cell, scale_atoms=True)
    lmp_atoms = SinglePointCalculator(lmp_atoms, **results).get_atoms()
    return lmp_atoms
def serial_lammps_run(atoms, potential, wd, test_name, lammps_cmd):
    """Single-process LAMMPS run using the serial 'e3gnn' pair style."""
    return _run_lammps(atoms, 'e3gnn', potential, wd, lammps_cmd, test_name)
def parallel_lammps_run(
    atoms, potential, wd, test_name, ncores, lammps_cmd, mpirun_cmd
):
    """MPI LAMMPS run using the 'e3gnn/parallel' pair style on *ncores* ranks."""
    mpi_command = f'{mpirun_cmd} -np {ncores} {lammps_cmd}'
    return _run_lammps(
        atoms, 'e3gnn/parallel', potential, wd, mpi_command, test_name
    )
def subprocess_routine(cmd, name):
    """Run *cmd* and raise RuntimeError (after logging stderr) on failure.

    Args:
        cmd: argv list handed to subprocess.run (shell=False, safe quoting).
        name: human-readable label for log messages and the error.

    Raises:
        RuntimeError: if the process exits with a nonzero return code.
        subprocess.TimeoutExpired: if the run exceeds 30 seconds.
    """
    res = subprocess.run(cmd, capture_output=True, timeout=30)
    if res.returncode != 0:
        # lazy %-args: message only formatted when the level is enabled
        logger.error('Subprocess %s failed return code: %s', name, res.returncode)
        logger.error(res.stderr.decode('utf-8'))
        raise RuntimeError(f'{name} failed')
    logger.info('stdout of %s:', name)
    logger.info(res.stdout.decode('utf-8'))
@pytest.mark.parametrize('system', ['bulk', 'surface'])
def test_serial(system, serial_potential_path, ref_calculator, lammps_cmd, tmp_path):
    """Serial e3gnn LAMMPS run must match the SevenNetCalculator reference."""
    atoms = get_system(system)
    lmp_atoms = serial_lammps_run(
        atoms,
        serial_potential_path,
        tmp_path,
        'serial lmp test',
        lammps_cmd,
    )
    atoms.calc = ref_calculator
    assert_atoms(atoms, lmp_atoms)
@pytest.mark.parametrize(
    'system,ncores',
    [
        ('bulk', 1),
        ('bulk', 2),
        ('bulk', 4),
        ('surface', 1),
        ('surface', 2),
        ('surface', 3),
        ('surface', 4),
    ],
)
def test_parallel(
    system,
    ncores,
    parallel_potential_path,
    ref_calculator,
    lammps_cmd,
    mpirun_cmd,
    tmp_path,
):
    """Parallel e3gnn run on 1-4 MPI ranks must match the ASE reference."""
    # larger supercells so every MPI rank owns a meaningful spatial domain
    if system == 'bulk':
        replicate = (6, 6, 3)
    elif system == 'surface':
        replicate = (4, 4, 1)
    else:
        assert False
    atoms = get_system(system, replicate=replicate)
    lmp_atoms = parallel_lammps_run(
        atoms,
        parallel_potential_path,
        tmp_path,
        'parallel lmp test',
        ncores,
        lammps_cmd,
        mpirun_cmd,
    )
    atoms.calc = ref_calculator
    assert_atoms(atoms, lmp_atoms)
@pytest.mark.parametrize('system', ['bulk', 'surface'])
def test_modal_serial(
    system, serial_modal_potential_path, ref_modal_calculator, lammps_cmd, tmp_path
):
    """Serial run of the deployed modal ('PBE') potential matches ASE."""
    atoms = get_system(system)
    lmp_atoms = serial_lammps_run(
        atoms,
        serial_modal_potential_path,
        tmp_path,
        'serial lmp test',
        lammps_cmd,
    )
    atoms.calc = ref_modal_calculator
    assert_atoms(atoms, lmp_atoms)
@pytest.mark.parametrize(
    'system,ncores',
    [
        ('bulk', 2),
        ('surface', 2),
    ],
)
def test_modal_parallel(
    system,
    ncores,
    parallel_modal_potential_path,
    ref_modal_calculator,
    lammps_cmd,
    mpirun_cmd,
    tmp_path,
):
    """Parallel run of the deployed modal ('PBE') potential matches ASE."""
    # larger supercells so every MPI rank owns a meaningful spatial domain
    if system == 'bulk':
        replicate = (6, 6, 3)
    elif system == 'surface':
        replicate = (4, 4, 1)
    else:
        assert False
    atoms = get_system(system, replicate=replicate)
    lmp_atoms = parallel_lammps_run(
        atoms,
        parallel_modal_potential_path,
        tmp_path,
        'parallel lmp test',
        ncores,
        lammps_cmd,
        mpirun_cmd,
    )
    atoms.calc = ref_modal_calculator
    assert_atoms(atoms, lmp_atoms)
@pytest.mark.filterwarnings('ignore:.*is not found from.*')
@pytest.mark.skipif(not is_cue_available(), reason='cueq not available')
def test_cueq_serial(lammps_cmd, tmp_path):
    """Build a cueq model, save+deploy it, and compare LAMMPS against ASE.

    TODO: Use already saved cueq enabled checkpoint after cueq becomes stable
    """
    cueq = True
    model = get_model(use_cueq=cueq)
    ref_calc = SevenNetCalculator(model, file_type='model_instance')
    atoms = get_system('bulk')
    cfg = get_model_config()
    cfg.update(
        {'cuequivariance_config': {'use': cueq}, 'version': sevenn.__version__}
    )
    # write a minimal checkpoint so deploy() can consume the cueq model
    cp_path = str(tmp_path / 'cp.pth')
    torch.save(
        {'model_state_dict': model.state_dict(), 'config': cfg},
        cp_path,
    )
    pot_path = str(tmp_path / 'deployed_from_cueq_serial.pt')
    deploy(cp_path, pot_path)
    atoms_lammps = serial_lammps_run(
        atoms=atoms,
        potential=pot_path,
        wd=tmp_path,
        test_name='cueq checkpoint serial lmp run test',
        lammps_cmd=lammps_cmd,
    )
    atoms.calc = ref_calc
    assert_atoms(atoms, atoms_lammps)
@pytest.mark.filterwarnings('ignore:.*is not found from.*')
@pytest.mark.skipif(not is_cue_available(), reason='cueq not available')
def test_cueq_parallel(lammps_cmd, mpirun_cmd, tmp_path):
    """Build a cueq model, deploy it in parallel, and compare LAMMPS vs ASE.

    TODO: Use already saved cueq enabled checkpoint after cueq becomes stable
    """
    cueq = True
    model = get_model(use_cueq=cueq)
    ref_calc = SevenNetCalculator(model, file_type='model_instance')
    atoms = get_system('surface', replicate=(4, 4, 1))
    cfg = get_model_config()
    cfg.update(
        {'cuequivariance_config': {'use': cueq}, 'version': sevenn.__version__}
    )
    # write a minimal checkpoint so deploy_parallel() can consume the model
    cp_path = str(tmp_path / 'cp.pth')
    torch.save(
        {'model_state_dict': model.state_dict(), 'config': cfg},
        cp_path,
    )
    pot_path = str(tmp_path / 'deployed_from_cueq_parallel')
    deploy_parallel(cp_path, pot_path)
    atoms_lammps = parallel_lammps_run(
        atoms=atoms,
        # pair_coeff wants '<num layers> <path prefix>'
        potential=' '.join([str(cfg['num_convolution_layer']), pot_path]),
        wd=tmp_path,
        test_name='cueq checkpoint parallel lmp run test',
        lammps_cmd=lammps_cmd,
        mpirun_cmd=mpirun_cmd,
        ncores=2,
    )
    atoms.calc = ref_calc
    assert_atoms(atoms, atoms_lammps)
import copy
import numpy as np
import pytest
from ase.build import bulk, molecule
from sevenn.calculator import D3Calculator, SevenNetCalculator
from sevenn.nn.cue_helper import is_cue_available
from sevenn.scripts.deploy import deploy
from sevenn.util import (
model_from_checkpoint,
model_from_checkpoint_with_backend,
pretrained_name_to_path,
)
@pytest.fixture
def atoms_pbc():
    """Two-atom periodic NaCl cell with a deliberately distorted lattice."""
    nacl = bulk('NaCl', 'rocksalt', a=5.63)
    nacl.set_cell([[1.0, 2.815, 2.815], [2.815, 0.0, 2.815], [2.815, 2.815, 0.0]])
    nacl.set_positions([[0.0, 0.0, 0.0], [2.815, 0.0, 0.0]])
    return nacl
@pytest.fixture
def atoms_mol():
    """Distorted H2O molecule (no periodic boundary conditions)."""
    water = molecule('H2O')
    water.set_positions([[0.0, 0.2, 0.12], [0.0, 0.76, -0.48], [0.0, -0.76, -0.48]])
    return water
@pytest.fixture(scope='module')
def sevennet_0_cal():
    """Module-wide 7net-0 (11July2024) calculator."""
    calc = SevenNetCalculator('7net-0_11July2024')
    return calc
@pytest.fixture(scope='module')
def sevennet_0_cueq_cal():
    """7net-0 calculator built through the cuEquivariance backend."""
    checkpoint = pretrained_name_to_path('7net-0_11July2024')
    cueq_model, _ = model_from_checkpoint_with_backend(checkpoint, 'cueq')
    return SevenNetCalculator(cueq_model)
@pytest.fixture(scope='module')
def d3_cal():
    """D3 dispersion calculator; skips dependents when unsupported here."""
    try:
        return D3Calculator()
    except NotImplementedError as e:
        pytest.skip(f'{e}')
def test_sevennet_0_cal_pbc(atoms_pbc, sevennet_0_cal):
    """7net-0 on the periodic NaCl fixture reproduces pinned reference values."""
    # regression references computed once with 7net-0_11July2024
    atoms1_ref = {
        'energy': -3.779199,
        'energies': [-1.8493923, -1.9298072],
        'force': [
            [12.666697, 0.04726403, 0.04775861],
            [-12.666697, -0.04726403, -0.04775861],
        ],
        'stress': [
            [
                -0.6439122,
                -0.03643947,
                -0.03643981,
                0.00599139,
                0.04544507,
                0.04543639,
            ]
        ],
    }
    atoms_pbc.calc = sevennet_0_cal
    assert np.allclose(atoms_pbc.get_potential_energy(), atoms1_ref['energy'])
    # free energy equals potential energy for this calculator
    assert np.allclose(
        atoms_pbc.get_potential_energy(force_consistent=True), atoms1_ref['energy']
    )
    assert np.allclose(atoms_pbc.get_forces(), atoms1_ref['force'])
    assert np.allclose(atoms_pbc.get_stress(), atoms1_ref['stress'])
    assert np.allclose(atoms_pbc.get_potential_energies(), atoms1_ref['energies'])
def test_sevennet_0_cal_mol(atoms_mol, sevennet_0_cal):
    """7net-0 on the H2O fixture reproduces pinned reference values."""
    # regression references computed once with 7net-0_11July2024
    atoms2_ref = {
        'energy': -12.782808303833008,
        'energies': [-6.2493525, -3.141562, -3.3918958],
        'force': [
            [0.0, -1.3619621e01, 7.5937047e00],
            [0.0, 9.3918495e00, -1.0172190e01],
            [0.0, 4.2277718e00, 2.5784855e00],
        ],
    }
    atoms_mol.calc = sevennet_0_cal
    assert np.allclose(atoms_mol.get_potential_energy(), atoms2_ref['energy'])
    assert np.allclose(
        atoms_mol.get_potential_energy(force_consistent=True), atoms2_ref['energy']
    )
    assert np.allclose(atoms_mol.get_forces(), atoms2_ref['force'])
    assert np.allclose(atoms_mol.get_potential_energies(), atoms2_ref['energies'])
def test_sevennet_0_cal_deployed_consistency(tmp_path, atoms_pbc):
    """Deployed TorchScript model must reproduce the checkpoint's results."""
    fname = str(tmp_path / '7net_0.pt')
    deploy(pretrained_name_to_path('7net-0_11July2024'), fname)
    calc_script = SevenNetCalculator(fname, file_type='torchscript')
    calc_cp = SevenNetCalculator(pretrained_name_to_path('7net-0_11July2024'))

    def _results(calc):
        # single-point evaluation; snapshot everything the calculator produced
        atoms_pbc.calc = calc
        atoms_pbc.get_potential_energy()
        return copy.copy(atoms_pbc.calc.results)

    res_cp = _results(calc_cp)
    res_script = _results(calc_script)
    for key, val in res_cp.items():
        assert np.allclose(val, res_script[key])
def test_sevennet_0_cal_as_instance_consistency(atoms_pbc):
    """Calculator from a live model instance must match the checkpoint one."""
    model, _ = model_from_checkpoint(pretrained_name_to_path('7net-0_11July2024'))
    calc_cp = SevenNetCalculator(pretrained_name_to_path('7net-0_11July2024'))
    calc_instance = SevenNetCalculator(model, file_type='model_instance')

    def _results(calc):
        # single-point evaluation; snapshot everything the calculator produced
        atoms_pbc.calc = calc
        atoms_pbc.get_potential_energy()
        return copy.copy(atoms_pbc.calc.results)

    res_cp = _results(calc_cp)
    res_instance = _results(calc_instance)
    for key, val in res_cp.items():
        assert np.allclose(val, res_instance[key])
@pytest.mark.skipif(not is_cue_available(), reason='cueq not available')
def test_sevennet_0_cal_cueq(atoms_pbc, sevennet_0_cueq_cal):
    """cueq-backed 7net-0 reproduces the same pinned references as e3nn."""
    # identical reference values to test_sevennet_0_cal_pbc
    atoms1_ref = {
        'energy': -3.779199,
        'energies': [-1.8493923, -1.9298072],
        'force': [
            [12.666697, 0.04726403, 0.04775861],
            [-12.666697, -0.04726403, -0.04775861],
        ],
        'stress': [
            [
                -0.6439122,
                -0.03643947,
                -0.03643981,
                0.00599139,
                0.04544507,
                0.04543639,
            ]
        ],
    }
    atoms_pbc.calc = sevennet_0_cueq_cal
    assert np.allclose(atoms_pbc.get_potential_energy(), atoms1_ref['energy'])
    assert np.allclose(
        atoms_pbc.get_potential_energy(force_consistent=True), atoms1_ref['energy']
    )
    assert np.allclose(atoms_pbc.get_forces(), atoms1_ref['force'])
    assert np.allclose(atoms_pbc.get_stress(), atoms1_ref['stress'])
    assert np.allclose(atoms_pbc.get_potential_energies(), atoms1_ref['energies'])
def test_d3_cal_pbc(atoms_pbc, d3_cal):
    """D3 dispersion on the periodic NaCl fixture matches pinned references."""
    atoms1_ref = {
        'energy': -0.531393751583389,
        'force': [
            [-0.00570205, 0.00107457, 0.00107459],
            [0.00570205, -0.00107457, -0.00107459],
        ],
        'stress': [
            [
                1.52403705e-02,
                1.50417333e-02,
                1.50417321e-02,
                -3.22684163e-05,
                -5.05532863e-05,
                -5.05586994e-05,
            ]
        ],
    }
    atoms_pbc.calc = d3_cal
    assert np.allclose(atoms_pbc.get_potential_energy(), atoms1_ref['energy'])
    assert np.allclose(
        atoms_pbc.get_potential_energy(force_consistent=True), atoms1_ref['energy']
    )
    assert np.allclose(atoms_pbc.get_forces(), atoms1_ref['force'])
    assert np.allclose(atoms_pbc.get_stress(), atoms1_ref['stress'])
def test_d3_cal_mol(atoms_mol, d3_cal):
    """D3 dispersion on the H2O fixture matches pinned references."""
    atoms2_ref = {
        'energy': -0.009889134535170716,
        'force': [
            [0.0, 2.04263840e-03, 1.27477674e-03],
            [0.0, -9.90038901e-05, 1.18046682e-06],
            [0.0, -1.94363451e-03, -1.27595721e-03],
        ],
    }
    atoms_mol.calc = d3_cal
    assert np.allclose(atoms_mol.get_potential_energy(), atoms2_ref['energy'])
    assert np.allclose(
        atoms_mol.get_potential_energy(force_consistent=True), atoms2_ref['energy']
    )
    assert np.allclose(atoms_mol.get_forces(), atoms2_ref['force'])
import csv
import os
import pathlib
from unittest import mock
import ase.io
import numpy as np
import pytest
import yaml
from ase.build import bulk
from sevenn.calculator import SevenNetCalculator
from sevenn.logger import Logger
from sevenn.main.sevenn import main as sevenn_main
from sevenn.main.sevenn_get_model import main as get_model_main
from sevenn.main.sevenn_graph_build import main as graph_build_main
from sevenn.main.sevenn_inference import main as inference_main
from sevenn.util import pretrained_name_to_path
# locations of the installed sevenn CLI entry modules and yaml presets
main = os.path.abspath(f'{os.path.dirname(__file__)}/../../sevenn/main/')
preset = os.path.abspath(f'{os.path.dirname(__file__)}/../../sevenn/presets/')
file_path = pathlib.Path(__file__).parent.resolve()
data_root = (pathlib.Path(__file__).parent.parent / 'data').resolve()
hfo2_path = str(data_root / 'systems' / 'hfo2.extxyz')
# pre-computed 7net-0 inference output on hfo2, used as regression reference
hfo2_7net_0_inference_path = data_root / 'inferences' / 'snet0_on_hfo2'
cp_0_path = str(data_root / 'checkpoints' / 'cp_0.pth')
# init; Logger().switch_file(...) is reused later in test_sevenn_preset
Logger()  # init
@pytest.fixture
def atoms_hfo():
    """Two-atom HfO rocksalt cell with a deliberately distorted lattice."""
    hfo = bulk('HfO', 'rocksalt', a=5.63)
    hfo.set_cell([[1.0, 2.815, 2.815], [2.815, 0.0, 2.815], [2.815, 2.815, 0.0]])
    hfo.set_positions([[0.0, 0.0, 0.0], [2.815, 0.0, 0.0]])
    return hfo
@pytest.fixture(scope='module')
def sevennet_0_cal():
    """Module-wide 7net-0 (11July2024) calculator."""
    calc = SevenNetCalculator('7net-0_11July2024')
    return calc
def test_get_model_serial(tmp_path, capsys):
    """sevenn_get_model CLI deploys a serial .pt potential file."""
    target = tmp_path / 'mypot.pt'
    checkpoint = pretrained_name_to_path('7net-0')
    argv = [f'{main}/sevenn_get_model.py', '-o', str(target), checkpoint]
    with mock.patch('sys.argv', argv):
        get_model_main()
    capsys.readouterr()  # drain captured output, not inspected
    assert target.is_file(), '.pt file is not written'
def test_get_model_parallel(tmp_path, capsys):
    """sevenn_get_model -p writes one .pt segment per interaction layer."""
    target_dir = tmp_path / 'my_parallel'
    checkpoint = pretrained_name_to_path('7net-0')
    n_segments = 5  # 5 interaction layers
    argv = [f'{main}/sevenn_get_model.py', '-o', str(target_dir), '-p', checkpoint]
    with mock.patch('sys.argv', argv):
        get_model_main()
    capsys.readouterr()  # drain captured output, not inspected
    assert target_dir.is_dir(), 'parallel model directory not exist'
    for idx in range(n_segments):
        assert (target_dir / f'deployed_parallel_{idx}.pt').is_file()
@pytest.mark.parametrize('source', [(hfo2_path)])
def test_graph_build(source, tmp_path):
    """sevenn_graph_build writes the graph .pt plus its .yaml metadata."""
    out_dir = tmp_path / 'sevenn_data'
    argv = [
        f'{main}/sevenn_graph_build.py',
        '-o',
        str(tmp_path),
        '-f',
        'my_graph.pt',
        source,
        '4.0',
    ]
    with mock.patch('sys.argv', argv):
        graph_build_main()
    assert out_dir.is_dir()
    assert (out_dir / 'my_graph.pt').is_file()
    assert (out_dir / 'my_graph.yaml').is_file()
@pytest.mark.parametrize(
    'batch,device,save_graph',
    [
        (1, 'cpu', False),
        (2, 'cpu', False),
        (1, 'cpu', True),
    ],
)
def test_inference(batch, device, save_graph, tmp_path):
    """sevenn_inference CLI output files must match the stored reference errors."""
    checkpoint = '7net-0'
    target = hfo2_path
    ref_path = hfo2_7net_0_inference_path
    output_dir = tmp_path / 'inference_results'
    # files the CLI is expected to produce
    files = ['info.csv', 'per_graph.csv', 'per_atom.csv', 'errors.txt']
    cli_args = [
        '--output',
        str(output_dir),
        '--device',
        device,
        '--batch',
        str(batch),
        checkpoint,
        target,
    ]
    if save_graph:
        cli_args.append('--save_graph')
    with mock.patch('sys.argv', [f'{main}/sevenn_inference.py'] + cli_args):
        inference_main()
    assert output_dir.is_dir()
    for f in files:
        assert (output_dir / f).is_file()
    # errors.txt lines look like 'label: value'; compare values to reference
    with open(output_dir / 'errors.txt', 'r', encoding='utf-8') as f:
        errors = [float(ll.split(':')[-1].strip()) for ll in f.readlines()]
    with open(ref_path / 'errors.txt', 'r', encoding='utf-8') as f:
        errors_ref = [float(ll.split(':')[-1].strip()) for ll in f.readlines()]
    assert np.allclose(np.array(errors), np.array(errors_ref))
    """
    # TODO: commented out as currently SevenNetGraphDataset can't do this
    with open(output_dir / 'info.csv', 'r') as f:
        reader = csv.DictReader(f)
        for dct in reader:
            assert dct['file'] == hfo2_path
        assert reader.line_num == 3
    """
    if save_graph:
        assert (output_dir / 'sevenn_data').is_dir()
        assert (output_dir / 'sevenn_data' / 'saved_graph.pt').is_file()
        assert (output_dir / 'sevenn_data' / 'saved_graph.yaml').is_file()
def test_inference_unlabeled(atoms_hfo, tmp_path):
    """--allow_unlabeled lets inference mix labeled and unlabeled structures."""
    labeled = str(hfo2_path)
    unlabeled = str(tmp_path / 'unlabeled.xyz')
    ase.io.write(unlabeled, atoms_hfo)
    out_dir = tmp_path / 'inference_results'
    argv = [
        f'{main}/sevenn_inference.py',
        '--output',
        str(out_dir),
        '--allow_unlabeled',
        cp_0_path,
        labeled,
        unlabeled,
    ]
    with mock.patch('sys.argv', argv):
        inference_main()
    with open(out_dir / 'info.csv', 'r') as f:
        reader = csv.DictReader(f)
        for row in reader:
            assert row['file'] in [labeled, unlabeled]
        # header + three data rows total
        assert reader.line_num == 4
def test_inference_labeled_w_kwargs(atoms_hfo, tmp_path):
    """--kwargs remaps custom energy/force/stress keys of the input file."""
    atoms_hfo.info['my_energy'] = 1.0
    atoms_hfo.arrays['my_force'] = np.full((len(atoms_hfo), 3), 7.7)
    # this should be considered as Voigt, xx, yy, zz, yz, zx, xy
    atoms_hfo.info['my_stress'] = np.array([1, 2, 3, 4, 5, 6])
    unlabeled = str(tmp_path / 'unlabeled.xyz')
    ase.io.write(unlabeled, atoms_hfo)
    output_dir = tmp_path / 'inference_results'
    cli_args = [
        '--output',
        str(output_dir),
        cp_0_path,
        unlabeled,
        '--kwargs',
        'energy_key=my_energy',
        'force_key=my_force',
        'stress_key=my_stress',
    ]
    with mock.patch('sys.argv', [f'{main}/sevenn_inference.py'] + cli_args):
        inference_main()
    per_graph = None
    with open(output_dir / 'per_graph.csv', 'r') as f:
        reader = csv.DictReader(f)
        for dct in reader:
            per_graph = dct
        # header + exactly one data row
        assert reader.line_num == 2
    assert per_graph is not None
    # sign+unit conversion the CLI applies to stress columns
    # (presumably eV/A^3 <-> kBar; value matches 1602.1766208 — verify in CLI)
    stress_coeff = -1602.1766208
    assert np.allclose(float(per_graph['stress_yy']), 2 * stress_coeff)
    assert np.allclose(float(per_graph['stress_yz']), 4 * stress_coeff)
    assert np.allclose(float(per_graph['stress_zx']), 5 * stress_coeff)
    assert np.allclose(float(per_graph['stress_xy']), 6 * stress_coeff)
@pytest.mark.parametrize(
    'preset_name,mode,data_path',
    [
        ('fine_tune', 'train_v2', hfo2_path),
        ('base', 'train_v2', hfo2_path),
        ('sevennet-0', 'train_v1', hfo2_path),
    ],
)
def test_sevenn_preset(preset_name, mode, data_path, tmp_path):
    """Each shipped preset must survive a 1-epoch training run end to end."""
    preset_path = os.path.join(preset, preset_name + '.yaml')
    with open(preset_path, 'r') as f:
        cfg = yaml.safe_load(f)
    # shrink the run: a single epoch is enough to smoke-test the pipeline
    cfg['train']['epoch'] = 1
    # v1 and v2 training configs use different dataset keys
    if mode == 'train_v2':
        cfg['data']['load_trainset_path'] = data_path
        cfg['data'].pop('load_testset_path', None)
    elif mode == 'train_v1':
        cfg['data']['load_dataset_path'] = data_path
    else:
        assert False
    cfg['data']['load_validset_path'] = data_path
    input_yam = str(tmp_path / 'input.yaml')
    with open(input_yam, 'w') as f:
        yaml.dump(cfg, f)
    Logger().switch_file(str(tmp_path / 'log.sevenn'))
    cli_args = ['train', '-w', str(tmp_path), '-m', mode, input_yam]
    with mock.patch('sys.argv', [f'{main}/sevenn.py'] + cli_args):
        sevenn_main()
    # learning-curve file name differs between training modes
    assert (tmp_path / 'lc.csv').is_file() or (tmp_path / 'log.csv').is_file()
    assert (tmp_path / 'log.sevenn').is_file()
    assert (tmp_path / 'checkpoint_best.pth').is_file()
# TODO: add gradient test from total loss after double precision.
# so far, it is empirically checked by seeing learning curves
import copy
import numpy as np
import pytest
import torch
from ase.build import bulk
from torch_geometric.loader.dataloader import Collater
import sevenn
import sevenn.train.dataload as dl
from sevenn.atom_graph_data import AtomGraphData
from sevenn.calculator import SevenNetCalculator
from sevenn.model_build import build_E3_equivariant_model
from sevenn.nn.cue_helper import is_cue_available
from sevenn.nn.sequential import AtomGraphSequential
from sevenn.util import (
chemical_species_preprocess,
model_from_checkpoint_with_backend,
)
# cutoff radius (Angstrom) used for graph building and the model config
cutoff = 4.0
# shared rattled NaCl supercell; all tests build graphs from this structure
_atoms = bulk('NaCl', 'rocksalt', a=4.00) * (2, 2, 2)
_avg_num_neigh = 30.0  # used as conv_denominator in get_model_config
_atoms.rattle()
# pre-built unlabeled graph; tests clone this instead of rebuilding it
_graph = AtomGraphData.from_numpy_dict(dl.unlabeled_atoms_to_graph(_atoms, cutoff))
def get_graphs(batched):
    """Two CUDA clones of the module graph; collated into a batch on request."""
    pair = [_graph.clone().to('cuda'), _graph.clone().to('cuda')]
    if batched:
        # batch size 2
        return Collater(pair)(pair)
    return pair
def get_model_config():
    """Base NaCl model config (channel 32, 3 layers, nequip self-connection)."""
    cfg = dict(
        cutoff=cutoff,
        channel=32,
        lmax=2,
        is_parity=True,
        num_convolution_layer=3,
        self_connection_type='nequip',
        interaction_type='nequip',
        radial_basis={'radial_basis_name': 'bessel'},
        cutoff_function={'cutoff_function_name': 'poly_cut'},
        weight_nn_hidden_neurons=[64, 64],
        act_radial='silu',
        act_scalar={'e': 'silu', 'o': 'tanh'},
        act_gate={'e': 'silu', 'o': 'tanh'},
        conv_denominator=_avg_num_neigh,
        train_denominator=False,
        shift=-10.0,
        scale=10.0,
        train_shift_scale=False,
        irreps_manual=False,
        lmax_edge=-1,
        lmax_node=-1,
        readout_as_fcn=False,
        use_bias_in_linear=False,
        _normalize_sph=True,
    )
    # chemical species are taken from the shared test structure
    species = list(set(_atoms.get_chemical_symbols()))
    cfg.update(**chemical_species_preprocess(species))
    return cfg
def get_model(config_overwrite=None, use_cueq=False, cueq_config=None):
    """Build an AtomGraphSequential on CUDA from the base config (+overrides)."""
    cfg = get_model_config()
    if config_overwrite is not None:
        cfg.update(config_overwrite)
    # explicit cueq_config wins over the simple use_cueq flag
    cfg.update(cueq_config or {'cuequivariance_config': {'use': use_cueq}})
    model = build_E3_equivariant_model(cfg, parallel=False)
    assert isinstance(model, AtomGraphSequential)
    model.to('cuda')
    return model
@pytest.mark.skipif(
    not is_cue_available() or not torch.cuda.is_available(),
    reason='cueq or gpu is not available',
)
@pytest.mark.parametrize(
    'cf',
    [
        ({}),
        ({'self_connection_type': 'linear'}),
        ({'is_parity': False}),
        ({'channel': 8}),
        ({'lmax': 3}),
        ({'num_interaction_layer': 2}),
        ({'num_interaction_layer': 4}),
    ],
)
def test_model_output(cf):
    """e3nn and cueq backends must agree layer-by-layer and on final outputs."""
    # same seed so both backends start from identical weights
    torch.manual_seed(777)
    model_e3nn = get_model(cf)
    torch.manual_seed(777)
    model_cueq = get_model(cf, use_cueq=True)
    model_e3nn.set_is_batch_data(True)
    model_cueq.set_is_batch_data(True)
    e3nn_out = model_e3nn._preprocess(get_graphs(batched=True))
    cueq_out = model_cueq._preprocess(get_graphs(batched=True))
    # step module-by-module so a divergence is localized to one layer
    for k, e3nn_f in model_e3nn._modules.items():
        cueq_f = model_cueq._modules[k]
        e3nn_out = e3nn_f(e3nn_out)  # type: ignore
        cueq_out = cueq_f(cueq_out)  # type: ignore
        assert torch.allclose(e3nn_out.x, cueq_out.x, atol=1e-6), (
            f'{k} \n\n {e3nn_f} \n\n {cueq_f}'
        )
    assert torch.allclose(
        e3nn_out.inferred_total_energy, cueq_out.inferred_total_energy
    )
    assert torch.allclose(e3nn_out.atomic_energy, cueq_out.atomic_energy)
    assert torch.allclose(
        e3nn_out.inferred_force, cueq_out.inferred_force, atol=1e-5
    )
    assert torch.allclose(
        e3nn_out.inferred_stress, cueq_out.inferred_stress, atol=1e-5
    )
@pytest.mark.filterwarnings('ignore:.*is not found from.*')
@pytest.mark.skipif(
    not is_cue_available() or not torch.cuda.is_available(),
    reason='cueq or gpu is not available',
)
@pytest.mark.parametrize(
    'start_from_cueq',
    [
        (True),
        (False),
    ],
)
def test_checkpoint_convert(tmp_path, start_from_cueq):
    """A checkpoint saved with one backend must load and match on the other."""
    torch.manual_seed(123)
    model_from = get_model(use_cueq=start_from_cueq)
    cfg = get_model_config()
    cfg.update(
        {
            'cuequivariance_config': {'use': start_from_cueq},
            'version': sevenn.__version__,
        }
    )
    torch.save(
        {'model_state_dict': model_from.state_dict(), 'config': cfg},
        tmp_path / 'cp_from.pth',
    )
    # reload on the opposite backend
    backend = 'e3nn' if start_from_cueq else 'cueq'
    model_to, _ = model_from_checkpoint_with_backend(
        str(tmp_path / 'cp_from.pth'), backend
    )
    model_to.to('cuda')
    model_from.set_is_batch_data(True)
    model_to.set_is_batch_data(True)
    from_out = model_from(get_graphs(batched=True))
    to_out = model_to(get_graphs(batched=True))
    assert torch.allclose(
        from_out.inferred_total_energy, to_out.inferred_total_energy
    )
    assert torch.allclose(from_out.atomic_energy, to_out.atomic_energy)
    assert torch.allclose(from_out.inferred_force, to_out.inferred_force, atol=1e-5)
    assert torch.allclose(
        from_out.inferred_stress, to_out.inferred_stress, atol=1e-5
    )
@pytest.mark.filterwarnings('ignore:.*is not found from.*')
@pytest.mark.skipif(
    not is_cue_available() or not torch.cuda.is_available(),
    reason='cueq or gpu is not available',
)
@pytest.mark.parametrize(
    'start_from_cueq',
    [
        (True),
        (False),
    ],
)
def test_checkpoint_convert_no_batch(tmp_path, start_from_cueq):
    """Same as test_checkpoint_convert, but on a single (non-batched) graph."""
    torch.manual_seed(123)
    model_from = get_model(use_cueq=start_from_cueq)
    cfg = get_model_config()
    cfg.update(
        {
            'cuequivariance_config': {'use': start_from_cueq},
            'version': sevenn.__version__,
        }
    )
    torch.save(
        {'model_state_dict': model_from.state_dict(), 'config': cfg},
        tmp_path / 'cp_from.pth',
    )
    # reload on the opposite backend
    backend = 'e3nn' if start_from_cueq else 'cueq'
    model_to, _ = model_from_checkpoint_with_backend(
        str(tmp_path / 'cp_from.pth'), backend
    )
    model_to.to('cuda')
    model_from.set_is_batch_data(False)
    model_to.set_is_batch_data(False)
    from_out = model_from(get_graphs(batched=False)[0])
    to_out = model_to(get_graphs(batched=False)[0])
    assert torch.allclose(
        from_out.inferred_total_energy, to_out.inferred_total_energy
    )
    assert torch.allclose(from_out.atomic_energy, to_out.atomic_energy)
    assert torch.allclose(from_out.inferred_force, to_out.inferred_force, atol=1e-5)
    assert torch.allclose(
        from_out.inferred_stress, to_out.inferred_stress, atol=1e-5
    )
def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6):
    """Assert two Atoms agree on size, cell, energy, forces and full stress.

    Forces and stress are compared with 10x looser tolerances than energy.
    """
    assert len(atoms1) == len(atoms2)
    assert np.allclose(atoms1.get_cell(), atoms2.get_cell(), rtol=rtol, atol=atol)
    assert np.allclose(
        atoms1.get_potential_energy(),
        atoms2.get_potential_energy(),
        rtol=rtol,
        atol=atol,
    )
    assert np.allclose(
        atoms1.get_forces(), atoms2.get_forces(), rtol=rtol * 10, atol=atol * 10
    )
    assert np.allclose(
        atoms1.get_stress(voigt=False),
        atoms2.get_stress(voigt=False),
        rtol=rtol * 10,
        atol=atol * 10,
    )
    # per-atom energies intentionally not compared here
    # assert acl(atoms1.get_potential_energies(), atoms2.get_potential_energies())
@pytest.mark.filterwarnings('ignore:.*is not found from.*')
@pytest.mark.skipif(
    not is_cue_available() or not torch.cuda.is_available(),
    reason='cueq or gpu is not available',
)
def test_calculator(tmp_path):
    """A cueq-built checkpoint evaluated without cueq must match the original."""
    cueq = True
    model = get_model(use_cueq=cueq)
    ref_calc = SevenNetCalculator(model, file_type='model_instance')
    atoms = copy.deepcopy(_atoms)
    atoms.calc = ref_calc
    cfg = get_model_config()
    cfg.update(
        {'cuequivariance_config': {'use': cueq}, 'version': sevenn.__version__}
    )
    # save the cueq model as a checkpoint, reload with enable_cueq=False
    cp_path = str(tmp_path / 'cp.pth')
    torch.save(
        {'model_state_dict': model.state_dict(), 'config': cfg},
        cp_path,
    )
    calc2 = SevenNetCalculator(cp_path, enable_cueq=False)
    atoms2 = copy.deepcopy(_atoms)
    atoms2.calc = calc2
    assert_atoms(atoms, atoms2)
import logging
import os
import os.path as osp
import uuid
from collections import Counter
from copy import deepcopy
from typing import Literal
import ase.calculators.singlepoint as singlepoint
import ase.io
import numpy as np
import pytest
import torch
from ase import Atoms
from ase.build import bulk, molecule
from torch_geometric.loader import DataLoader
import sevenn._keys as KEY
import sevenn.train.dataload as dl
import sevenn.train.graph_dataset as ds
import sevenn.train.modal_dataset as modal_dataset
from sevenn._const import NUM_UNIV_ELEMENT
from sevenn.atom_graph_data import AtomGraphData
from sevenn.util import model_from_checkpoint, pretrained_name_to_path
# cutoff radius (Angstrom) shared by the graph-building tests below
cutoff = 4.0
# lattice constant for the single-atom cubic Cu cell
lattice_constant = 3.35
# sample structures: periodic, molecular, single-atom and tiny periodic cell
_samples = {
    'bulk': bulk('NaCl', 'rocksalt', a=5.63),
    'mol': molecule('H2O'),
    'isolated': molecule('H'),
    'small_bulk': Atoms(
        symbols='Cu',
        positions=[
            (0, 0, 0),  # Atom at the corner of the cube
        ],
        cell=[
            [lattice_constant, 0, 0],
            [0, lattice_constant, 0],
            [0, 0, lattice_constant],
        ],
        pbc=True,  # Periodic boundary conditions
    ),
}
# expected number of edges for each sample at the 4.0 cutoff above
_nedges_c4 = {'bulk': 36, 'mol': 6, 'isolated': 0, 'small_bulk': 18}
def get_atoms(
    atoms_type: Literal['bulk', 'mol', 'isolated', 'small_bulk'],
    init_y_as: Literal['calc', 'info', 'none'],
):
    """Return (copy of a sample structure, its edge count at cutoff 4.0).

    Random reference labels are attached via a SinglePointCalculator
    ('calc'), via info/arrays entries ('info'), or not at all ('none').
    Stress labels are dropped for structures without full PBC.
    """
    assert atoms_type in _samples
    atoms = deepcopy(_samples[atoms_type])
    n = len(atoms)
    fully_periodic = atoms.pbc.all()
    if init_y_as == 'calc':
        # draw stress unconditionally (keeps the RNG stream identical),
        # then drop it for non-periodic structures
        labels = {
            'energy': np.random.rand(1),
            'forces': np.random.rand(n, 3),
            'stress': np.random.rand(6),
        }
        if not fully_periodic:
            labels.pop('stress')
        atoms = singlepoint.SinglePointCalculator(atoms, **labels).get_atoms()
    elif init_y_as == 'info':
        atoms.info['y_energy'] = np.random.rand(1)
        atoms.arrays['y_force'] = np.random.rand(n, 3)
        atoms.info['y_stress'] = np.random.rand(6)
        if not fully_periodic:
            del atoms.info['y_stress']
    return atoms, _nedges_c4[atoms_type]
@pytest.mark.parametrize('init_y_as', ['calc', 'info'])
@pytest.mark.parametrize('atoms_type', ['bulk', 'mol', 'isolated'])
def test_atoms_to_graph(atoms_type, init_y_as):
    """atoms_to_graph must emit numpy arrays with the documented keys,
    shapes and dtypes, for labels from a calculator or from info/arrays."""
    atoms, nedges = get_atoms(atoms_type, init_y_as)
    is_stress = atoms.pbc.all()  # stress labels only exist for periodic cells
    y_from_calc = init_y_as == 'calc'
    graph = dl.atoms_to_graph(atoms, cutoff=cutoff, y_from_calc=y_from_calc)
    # key -> (expected shape, expected dtype kind)
    essential = {
        'atomic_numbers': ((len(atoms),), int),
        'pos': ((len(atoms), 3), float),
        'edge_index': ((2, nedges), int),
        'edge_vec': ((nedges, 3), float),
        'total_energy': ((), float),
        'force_of_atoms': ((len(atoms), 3), float),
        'cell_volume': ((), float),
        'num_atoms': ((), int),
        'per_atom_energy': ((), float),
        'stress': ((1, 6), float),
    }
    for k, (shape, dtype) in essential.items():
        assert k in graph, f'{k} missing in graph'
        assert isinstance(
            graph[k], np.ndarray
        ), f'{k}: {type(graph[k])} is not np.ndarray'
        assert graph[k].shape == shape, f'{k} shape {graph[k].shape} != {shape}'
        if not is_stress and k == 'stress':
            # non-periodic structures carry NaN stress placeholders
            assert np.isnan(graph[k]).all()
        else:
            assert graph[k].dtype == dtype, f'{k} dtype {graph[k].dtype} != {dtype}'
    assert graph['per_atom_energy'] == (graph['total_energy'] / len(atoms))
    assert graph['num_atoms'] == len(atoms)
    if not is_stress:
        # volume placeholder for cell-less structures is machine epsilon
        assert graph['cell_volume'] == np.finfo(float).eps
@pytest.mark.parametrize('atoms_type', ['bulk', 'mol', 'isolated'])
def test_unlabeled_atoms_to_graph(atoms_type):
    """unlabeled_atoms_to_graph must emit only structural keys (no label
    keys) as numpy arrays with the expected shapes and dtypes."""
    atoms, nedges = get_atoms(atoms_type, 'none')
    graph = dl.unlabeled_atoms_to_graph(atoms, cutoff=cutoff)
    # key -> (expected shape, expected dtype); note: no energy/force/stress
    essential = {
        'atomic_numbers': ((len(atoms),), int),
        'pos': ((len(atoms), 3), float),
        'edge_index': ((2, nedges), int),
        'edge_vec': ((nedges, 3), float),
        'cell_volume': ((), float),
        'num_atoms': ((), int),
    }
    for k, (shape, dtype) in essential.items():
        assert k in graph, f'{k} missing in graph'
        assert isinstance(
            graph[k], np.ndarray
        ), f'{k}: {type(graph[k])} is not np.ndarray'
        assert graph[k].dtype == dtype, f'{k} dtype {graph[k].dtype} != {dtype}'
        assert graph[k].shape == shape, f'{k} shape {graph[k].shape} != {shape}'
    assert graph['num_atoms'] == len(atoms)
    if not atoms.pbc.all():
        # volume placeholder for cell-less structures is machine epsilon
        assert graph['cell_volume'] == np.finfo(float).eps
@pytest.mark.parametrize('init_y_as', ['calc', 'info'])
@pytest.mark.parametrize('atoms_type', ['bulk', 'mol', 'isolated'])
def test_atom_graph_data(atoms_type, init_y_as):
    """AtomGraphData.from_numpy_dict must convert the numpy graph into
    torch tensors, preserving shapes and float/int kinds."""
    atoms, nedges = get_atoms(atoms_type, init_y_as)
    y_from_calc = init_y_as == 'calc'
    is_stress = atoms.pbc.all()  # stress labels only exist for periodic cells
    np_graph = dl.atoms_to_graph(atoms, cutoff=cutoff, y_from_calc=y_from_calc)
    graph = AtomGraphData.from_numpy_dict(np_graph)
    # keys that must always be present
    essential = {
        'atomic_numbers': ((len(atoms),), int),
        'edge_index': ((2, nedges), int),
        'edge_vec': ((nedges, 3), float),
    }
    # auxiliary keys: checked only when present (name kept as-is)
    auxilaray = {
        'x': ((len(atoms),), int),
        'pos': ((len(atoms), 3), float),
        'num_atoms': ((), int),
        'cell_volume': ((), float),
        'total_energy': ((), float),
        'per_atom_energy': ((), float),
        'force_of_atoms': ((len(atoms), 3), float),
        'stress': ((1, 6), float),
    }
    for k, (shape, dtype) in essential.items():
        assert k in graph, f'{k} missing in graph'
        assert isinstance(
            graph[k], torch.Tensor
        ), f'{k}: {type(graph[k])} is not an tensor'
        assert graph[k].is_floating_point() == (dtype is float)
        assert graph[k].shape == shape, f'{k} shape {graph[k].shape} != {shape}'
    for k, (shape, dtype) in auxilaray.items():
        if k not in graph:
            continue
        assert isinstance(
            graph[k], torch.Tensor
        ), f'{k}: {type(graph[k])} is not an tensor'
        assert graph[k].shape == shape, f'{k} shape {graph[k].shape} != {shape}'
        if not is_stress and k == 'stress':
            # non-periodic structures carry NaN stress placeholders
            assert torch.isnan(graph[k]).all()
        else:
            assert graph[k].is_floating_point() == (dtype is float)
def test_graph_build():
    """
    Compare parallel implementation, should preserve order
    """
    atoms_list = [
        get_atoms(t, 'calc')[0]  # type: ignore
        for t in list(_samples.keys())
    ]
    # same inputs processed serially and with two worker processes
    one_core = dl.graph_build(atoms_list, cutoff, num_cores=1, y_from_calc=True)
    two_core = dl.graph_build(atoms_list, cutoff, num_cores=2, y_from_calc=True)
    assert len(one_core) == len(two_core)
    for g1, g2 in zip(one_core, two_core):
        assert set(g1.keys()) == set(g2.keys())
        for k in g1.keys():
            if not isinstance(g1[k], torch.Tensor):
                continue
            if k == 'stress':  # TODO: robust way to test it
                # stress may be NaN (molecules); accept value equality
                # or matching NaN-ness
                assert torch.allclose(g1[k], g2[k]) or (
                    torch.isnan(g1[k]).all() == torch.isnan(g2[k]).all()
                )
            else:
                assert torch.allclose(g1[k], g2[k])
@pytest.fixture(scope='module')
def graph_dataset_tuple():
    """Write sample structures to a temp extxyz file, build a
    SevenNetGraphDataset from it, and return (dataset, atoms_list)."""
    tmpdir = os.getenv('TMPDIR', '/tmp')
    randstr = uuid.uuid4().hex  # unique name so concurrent runs don't collide
    assert os.access(tmpdir, os.W_OK), f'{tmpdir} is not writable'
    root = tmpdir
    files = f'{root}/{randstr}.extxyz'
    atoms_list = [
        get_atoms(atype, 'calc')[0]  # type: ignore
        for atype in ['bulk', 'mol', 'isolated']
    ]
    ase.io.write(files, atoms_list, 'extxyz')
    dataset = ds.SevenNetGraphDataset(
        cutoff=cutoff,
        root=root,
        files=files,
        processed_name=f'{randstr}.pt',
    )
    # processing must leave a .pt file under <root>/sevenn_data/
    assert os.path.isfile(f'{root}/sevenn_data/{randstr}.pt'), 'dataset not written'
    return dataset, atoms_list
def test_sevenn_graph_dataset_properties(graph_dataset_tuple):
    """Dataset summary properties (species, atom counts, per-atom energy
    mean, force RMS) must match values recomputed from the atoms."""
    dataset, atoms_list = graph_dataset_tuple
    species = set()
    natoms = Counter()
    elist = []
    e_per_list = []
    flist = []
    slist = []
    # recompute reference statistics straight from the labeled atoms
    for at in atoms_list:
        chems = at.get_chemical_symbols()
        species.update(chems)
        natoms.update(chems)
        elist.append(at.get_potential_energy())
        e_per_list.append(at.get_potential_energy() / len(at))
        flist.extend(at.get_forces())
        try:
            slist.append(at.get_stress())
        except NotImplementedError:
            # molecules have no stress; keep row shape with NaNs
            slist.append(np.full(6, np.nan))
    elist = np.array(elist)
    e_per_list = np.array(e_per_list)
    flist = np.array(flist)
    slist = np.array(slist)
    natoms['total'] = sum([cnt for cnt in list(natoms.values())])
    assert set(dataset.species) == species
    assert dataset.natoms == natoms
    assert np.allclose(dataset.per_atom_energy_mean, e_per_list.mean())
    assert np.allclose(dataset.force_rms, np.sqrt((flist**2).mean()))
def test_sevenn_graph_dataset_elemwise_energies(graph_dataset_tuple):
    """Element-wise reference energies: one entry per universal element,
    zero for elements absent from the dataset."""
    logger = logging.getLogger(__name__)
    dataset, atoms_list = graph_dataset_tuple
    ref_e = dataset.elemwise_reference_energies
    assert len(ref_e) == NUM_UNIV_ELEMENT
    z_set = set()
    for atoms in atoms_list:
        # sum reference energies over the structure's atomic numbers
        inferred_e = 0
        atomic_numbers = atoms.get_atomic_numbers()
        z_set.update(atomic_numbers)
        for z in atomic_numbers:
            inferred_e += ref_e[z]
        # it never be same, but should be similar
        logger.info('elemwise energy should be similar:')
        logger.info(f'{inferred_e:4f} {atoms.get_potential_energy()[0]:4f}')
    # elements not present in the dataset must have zero reference energy
    for z in range(NUM_UNIV_ELEMENT):
        if z not in z_set:
            assert ref_e[z] == 0
def test_sevenn_graph_dataset_statistics(graph_dataset_tuple):
    """dataset.statistics must agree with mean/std/median/max/min computed
    directly from the labeled atoms."""
    dataset, atoms_list = graph_dataset_tuple
    elist = []
    e_per_list = []
    flist = []
    slist = []
    for at in atoms_list:
        elist.append(at.get_potential_energy())
        e_per_list.append(at.get_potential_energy() / len(at))
        flist.extend(at.get_forces())
        try:
            slist.append(at.get_stress())
        except NotImplementedError:
            # molecules have no stress; keep row shape with NaNs
            slist.append(np.full(6, np.nan))
    # quantity key -> reference values (stress skipped, may contain NaN)
    dct = {
        'total_energy': np.array(elist),
        'per_atom_energy': np.array(e_per_list),
        'force_of_atoms': np.array(flist).flatten(),
        # 'stress': np.array(slist), # TODO: it may have nan
    }
    for key in dct:
        assert np.allclose(dataset.statistics[key]['mean'], dct[key].mean()), key
        assert np.allclose(dataset.statistics[key]['std'], dct[key].std(ddof=0)), key
        assert np.allclose(
            dataset.statistics[key]['median'], np.median(dct[key])
        ), key
        assert np.allclose(dataset.statistics[key]['max'], dct[key].max()), key
        assert np.allclose(dataset.statistics[key]['min'], dct[key].min()), key
def test_sevenn_mm_dataset_statistics(tmp_path):
    """Multi-modal dataset 'total' statistics must equal those of a single
    dataset built from both modal datasets combined."""
    # modal 1: four bulk structures
    files = osp.join(tmp_path, 'gd_one.extxyz')
    atoms_list1 = [
        get_atoms(atype, 'calc')[0]  # type: ignore
        for atype in ['bulk', 'bulk', 'bulk', 'bulk']
    ]
    ase.io.write(files, atoms_list1, 'extxyz')
    gd1 = ds.SevenNetGraphDataset(
        cutoff=cutoff,
        root=tmp_path,
        files=files,
        processed_name='gd_one.pt',
    )
    # modal 2: two molecules and one bulk
    files = osp.join(tmp_path, 'gd_two.extxyz')
    atoms_list2 = [
        get_atoms(atype, 'calc')[0]  # type: ignore
        for atype in ['mol', 'mol', 'bulk']
    ]
    ase.io.write(files, atoms_list2, 'extxyz')
    gd2 = ds.SevenNetGraphDataset(
        cutoff=cutoff,
        root=tmp_path,
        files=files,
        processed_name='gd_two.pt',
    )
    # reference: one dataset over the union of both processed files
    ref = ds.SevenNetGraphDataset(
        cutoff=cutoff,
        root=tmp_path,
        files=[gd1.processed_paths[0], gd2.processed_paths[0]],
        processed_name='combined.pt',
    )
    mm = modal_dataset.SevenNetMultiModalDataset(
        {'modal1': gd1, 'modal2': gd2}
    )
    assert np.allclose(ref.per_atom_energy_mean, mm.per_atom_energy_mean['total'])
    assert np.allclose(ref.avg_num_neigh, mm.avg_num_neigh['total'])
    assert np.allclose(ref.force_rms, mm.force_rms['total'])
    assert set(ref.species) == set(mm.species['total'])
@pytest.mark.parametrize(
    'a_types,init_ys', [(['bulk', 'mol', 'isolated'], ['calc', 'calc', 'calc'])]
)
def test_7net_graph_dataset_batch_shape(a_types, init_ys, tmp_path):
    """A DataLoader batch over the whole dataset must concatenate node/edge
    tensors and stack per-structure labels with the expected shapes."""
    assert len(a_types) == len(init_ys)
    n_graph = len(a_types)
    atoms_list = []
    tot_edges = 0
    tot_atoms = 0
    # accumulate expected totals while collecting the structures
    for a_type, init_y in zip(a_types, init_ys):
        atoms, n_edge = get_atoms(a_type, init_y)
        tot_edges += n_edge
        tot_atoms += len(atoms)
        atoms_list.append(atoms)
    ase.io.write(tmp_path / 'tmp', atoms_list, format='extxyz')
    dataset = ds.SevenNetGraphDataset(cutoff, tmp_path, str(tmp_path / 'tmp'))
    loader = DataLoader(dataset, batch_size=n_graph)
    graph = next(iter(loader))  # one batch holding every structure
    # key -> (expected batched shape, expected dtype kind)
    essential = {
        'x': ((tot_atoms,), int),
        'atomic_numbers': ((tot_atoms,), int),
        'pos': ((tot_atoms, 3), float),
        'edge_index': ((2, tot_edges), int),
        'edge_vec': ((tot_edges, 3), float),
        'total_energy': ((n_graph,), float),
        'force_of_atoms': ((tot_atoms, 3), float),
        'cell_volume': ((n_graph,), float),
        'num_atoms': ((n_graph,), int),
        'per_atom_energy': ((n_graph,), float),
        'stress': ((n_graph, 6), float),
        'batch': ((tot_atoms,), int),  # from PyG
    }
    for k, (shape, dtype) in essential.items():
        assert k in graph, f'{k} missing in graph'
        assert isinstance(
            graph[k], torch.Tensor
        ), f'{k}: {type(graph[k])} is not an tensor'
        assert graph[k].is_floating_point() == (dtype is float)
        assert graph[k].shape == shape, f'{k} shape {graph[k].shape} != {shape}'
def _sort_edges(edge_src, edge_dst, edge_vec, shifts):
    """Return the edge arrays reordered canonically by (x, y, z) of edge_vec."""
    order = np.lexsort((edge_vec[:, 2], edge_vec[:, 1], edge_vec[:, 0]))
    return edge_src[order], edge_dst[order], edge_vec[order], shifts[order]
def _predict(model, atoms, pos, cell, edge_src, edge_dst, edge_vec, shifts):
    """Run `model` once on a graph assembled from raw neighbor-list arrays."""
    atomic_numbers = atoms.get_atomic_numbers()
    vol = dl._correct_scalar(atoms.cell.volume)
    if vol == 0:
        # molecules have no cell; use eps as a volume placeholder
        vol = np.array(np.finfo(float).eps)
    data = {
        KEY.NODE_FEATURE: atomic_numbers,
        KEY.ATOMIC_NUMBERS: atomic_numbers,
        KEY.POS: pos,
        KEY.EDGE_IDX: np.array([edge_src, edge_dst]),
        KEY.EDGE_VEC: edge_vec,
        KEY.CELL: cell,
        KEY.CELL_SHIFT: shifts,
        KEY.CELL_VOLUME: vol,
        KEY.NUM_ATOMS: dl._correct_scalar(len(atomic_numbers)),
        KEY.INFO: {},
    }
    return model(AtomGraphData.from_numpy_dict(data))
@pytest.mark.parametrize('atoms_type', ['bulk', 'mol', 'isolated', 'small_bulk'])
def test_graph_build_ase_and_matscipy(atoms_type):
    """ASE and matscipy neighbor-list builders must produce identical
    graphs (up to edge ordering) and identical model predictions.

    The previously duplicated graph-assembly code for the two builders is
    factored into the `_predict` helper; canonical sorting into `_sort_edges`.
    """
    atoms, _ = get_atoms(atoms_type, 'calc')
    atoms.rattle()
    pos = atoms.get_positions()
    cell = np.array(atoms.get_cell())
    pbc = atoms.get_pbc()
    # build the same neighbor graph with both backends
    ase_edges = dl._graph_build_ase(cutoff, pbc, cell, pos)
    matsci_edges = dl._graph_build_matscipy(cutoff, pbc, cell, pos)
    # edge lists must agree once brought into a canonical order
    src_a, dst_a, vec_a, shift_a = _sort_edges(*ase_edges)
    src_m, dst_m, vec_m, shift_m = _sort_edges(*matsci_edges)
    assert np.allclose(vec_a, vec_m)
    assert np.array_equal(src_a, src_m)
    assert np.array_equal(dst_a, dst_m)
    assert np.array_equal(shift_a, shift_m)
    # a pretrained model must give the same outputs on both graphs
    model, _ = model_from_checkpoint(pretrained_name_to_path('7net-0_11July2024'))
    model.eval()
    model.set_is_batch_data(False)
    out_ase = _predict(model, atoms, pos, cell, *ase_edges)
    out_matsci = _predict(model, atoms, pos, cell, *matsci_edges)
    assert torch.equal(
        out_ase[KEY.PRED_TOTAL_ENERGY], out_matsci[KEY.PRED_TOTAL_ENERGY]
    )
    assert torch.allclose(
        out_ase[KEY.PRED_FORCE], out_matsci[KEY.PRED_FORCE], atol=1e-06
    )
    assert torch.allclose(out_ase[KEY.PRED_STRESS], out_matsci[KEY.PRED_STRESS])
# test_errors: error recorder.py, loss.py
from copy import deepcopy
import numpy as np
import pytest
import torch
import torch.nn
from torch import tensor
import sevenn.error_recorder as erc
import sevenn.train.loss as loss
from sevenn.atom_graph_data import AtomGraphData
from sevenn.train.optim import loss_dict
# default train-config fragment consumed by the loss / error-recorder builders
_default_config = {
    'loss': 'mse',
    'loss_param': {},
    'error_record': [
        ('Energy', 'RMSE'),
        ('Force', 'RMSE'),
        ('Stress', 'RMSE'),
        ('Energy', 'MAE'),
        ('Force', 'MAE'),
        ('Stress', 'MAE'),
        ('TotalLoss', 'None'),
    ],
    'is_train_stress': True,
    'force_loss_weight': 1.0,
    'stress_loss_weight': 0.001,
}
# (error metric name, number of data points, atoms per structure)
_erc_test_params = [
    ('TotalEnergy', 4, 3),
    ('Energy', 4, 3),
    ('Force', 4, 3),
    ('Stress', 4, 3),
    ('Stress_GPa', 4, 3),
    ('Energy', 4, 1),
    ('Energy', 1, 1),
    ('Force', 1, 3),
    ('Stress', 1, 3),
]
def acl(lhs, rhs):
    """Shorthand for element-wise closeness at a 1e-6 absolute tolerance."""
    is_close = torch.allclose(lhs, rhs, atol=1e-6)
    return is_close
def config(**overwrite):  # to make it read-only
    """Return a deep copy of the default config with `overwrite` applied."""
    merged = deepcopy(_default_config)
    merged.update(overwrite)
    return merged
def test_per_atom_energy_loss():
    """PerAtomEnergyLoss must equal the criterion applied to energies
    divided by the per-structure atom counts."""
    loss_f = loss.PerAtomEnergyLoss(criterion=torch.nn.MSELoss())
    ref = torch.rand(2)
    pred = torch.rand(2)
    natoms = torch.randint(1, 10, (2,))
    tmp = AtomGraphData(
        total_energy=ref,
        inferred_total_energy=pred,
        num_atoms=natoms,
    ).to_dict()
    ret = loss_f.get_loss(tmp)
    assert loss_f.criterion is not None
    # reference: criterion on per-atom (energy / natoms) values
    assert torch.allclose(loss_f.criterion((ref / natoms), (pred / natoms)), ret)
def test_force_loss():
    """ForceLoss must equal the criterion applied to flattened force
    components."""
    loss_f = loss.ForceLoss(criterion=torch.nn.MSELoss())
    ref = torch.rand((4, 3))
    pred = torch.rand((4, 3))
    batch = tensor([0, 0, 0, 1])  # 4 atoms split over 2 structures
    tmp = AtomGraphData(
        force_of_atoms=ref,
        inferred_force=pred,
        batch=batch,
    ).to_dict()
    ret = loss_f.get_loss(tmp)
    assert loss_f.criterion is not None
    assert torch.allclose(loss_f.criterion(ref.reshape(-1), pred.reshape(-1)), ret)
def test_stress_loss():
    """StressLoss must equal the criterion applied to flattened stresses
    after scaling by the KB conversion factor."""
    loss_f = loss.StressLoss(criterion=torch.nn.MSELoss())
    ref = torch.rand((2, 6))
    pred = torch.rand((2, 6))
    tmp = AtomGraphData(
        stress=ref,
        inferred_stress=pred,
    ).to_dict()
    ret = loss_f.get_loss(tmp)
    # conversion factor applied inside StressLoss (presumably eV/A^3 -> kbar)
    KB = 1602.1766208
    assert loss_f.criterion is not None
    assert torch.allclose(
        loss_f.criterion(ref.reshape(-1) * KB, pred.reshape(-1) * KB), ret
    )
@pytest.mark.parametrize('conf', [config(), config(is_train_stress=False)])
def test_loss_from_config(conf):
    """Config must yield energy + force (+ stress) losses with the
    configured weights."""
    loss_functions = loss.get_loss_functions_from_config(conf)
    # stress loss is present only when stress training is enabled
    if conf['is_train_stress']:
        assert len(loss_functions) == 3
    else:
        assert len(loss_functions) == 2
    for loss_def, w in loss_functions:
        assert isinstance(loss_def, loss.LossDefinition)
        if isinstance(loss_def, loss.PerAtomEnergyLoss):
            assert w == 1.0  # energy weight is implicitly 1.0
        elif isinstance(loss_def, loss.ForceLoss):
            assert w == conf['force_loss_weight']
        elif isinstance(loss_def, loss.StressLoss):
            assert w == conf['stress_loss_weight']
        else:
            raise ValueError(f'Unexpected loss function: {loss_def}')
@pytest.mark.parametrize('err_type,ndata,natoms', _erc_test_params)
def test_rms_error(err_type, ndata, natoms):
    """RMSError must match a manual RMS normalized by ndata, and stay
    unchanged when the identical batch is recorded twice."""
    err_dct = erc.get_err_type(err_type)
    err = erc.RMSError(**err_dct)
    ref = torch.rand((ndata, err.vdim)).squeeze(1)
    pred = torch.rand((ndata, err.vdim)).squeeze(1)
    natoms = torch.tensor([natoms] * ndata)
    _data = {
        err_dct['ref_key']: ref,
        err_dct['pred_key']: pred,
        'num_atoms': natoms,
    }
    tmp = AtomGraphData(**_data)
    err.update(tmp)
    # manual reference: apply unit coefficient and optional per-atom scaling
    _ref = ref * err.coeff
    _pred = pred * err.coeff
    if 'per_atom' in err_dct and err_dct['per_atom']:
        # natoms = natoms.unsqueeze(-1)
        _ref = _ref / natoms
        _pred = _pred / natoms
    val = torch.sqrt(((_ref - _pred) ** 2).sum() / ndata)  # not ndata*natoms
    assert np.allclose(err.get(), val.item())
    # recording the same batch again must not change the running value
    err.update(tmp)
    assert np.allclose(err.get(), val.item())
@pytest.mark.parametrize('err_type,ndata,natoms', _erc_test_params)
def test_mae_error(err_type, ndata, natoms):
    """MAError must match a manual mean absolute error over all vector
    components, and stay unchanged when the same batch is recorded twice."""
    err_dct = erc.get_err_type(err_type)
    vdim = err_dct['vdim']
    err = erc.MAError(**err_dct)
    ref = torch.rand((ndata, vdim)).squeeze(1)
    pred = torch.rand((ndata, vdim)).squeeze(1)
    natoms = torch.tensor([natoms] * ndata)
    _data = {
        err_dct['ref_key']: ref,
        err_dct['pred_key']: pred,
        'num_atoms': natoms,
    }
    tmp = AtomGraphData(**_data)
    err.update(tmp)
    # manual reference: apply unit coefficient and optional per-atom scaling
    _ref = ref * err.coeff
    _pred = pred * err.coeff
    if 'per_atom' in err_dct and err_dct['per_atom']:
        _ref /= natoms
        _pred /= natoms
    val = abs(_ref - _pred).sum() / (ndata * vdim)
    assert np.allclose(err.get(), val.item())
    # recording the same batch again must not change the running value
    err.update(tmp)
    assert np.allclose(err.get(), val.item())
# TODO: test_component_rms_error
@pytest.mark.parametrize('err_type,ndata,natoms', _erc_test_params)
def test_custom_error(err_type, ndata, natoms):
    """CustomError must average a user-supplied function over the scaled
    (and optionally per-atom) reference/prediction pairs."""
    def func(a, b):
        return a * b
    err_dct = erc.get_err_type(err_type)
    vdim = err_dct['vdim']
    err = erc.CustomError(func, **err_dct)
    ref = torch.rand((ndata, vdim)).squeeze(1)
    pred = torch.rand((ndata, vdim)).squeeze(1)
    natoms = torch.tensor([natoms] * ndata)
    _data = {
        err_dct['ref_key']: ref,
        err_dct['pred_key']: pred,
        'num_atoms': natoms,
    }
    # manual reference: apply unit coefficient and optional per-atom scaling
    _ref = ref * err.coeff
    _pred = pred * err.coeff
    if 'per_atom' in err_dct and err_dct['per_atom']:
        _ref /= natoms
        _pred /= natoms
    tmp = AtomGraphData(**_data)
    err.update(tmp)
    val = func(_ref, _pred).mean()
    assert np.allclose(err.get(), val.item())
    # recording the same batch again must not change the running value
    err.update(tmp)
    assert np.allclose(err.get(), val.item())
@pytest.mark.parametrize('conf', [config(), config(is_train_stress=False)])
def test_total_loss_metric_from_config(conf):
    """The TotalLoss metric must equal the weighted sum of per-term means of
    the supplied criterion over energy, force and (optionally) stress."""
    def func(a, b):
        return a * b
    err = erc.ErrorRecorder.init_total_loss_metric(conf, func)
    ndata = 3
    natoms = 4
    e1, e2 = torch.rand(ndata), torch.rand(ndata)
    f1, f2 = torch.rand(ndata * natoms, 3), torch.rand(ndata * natoms, 3)
    s1, s2 = torch.rand((ndata, 6)), torch.rand((ndata, 6))
    _data = {
        'total_energy': e1,
        'inferred_total_energy': e2,
        'force_of_atoms': f1,
        'inferred_force': f2,
        'stress': s1,
        'inferred_stress': s2,
        'num_atoms': torch.tensor([natoms] * ndata),
    }
    tmp = AtomGraphData(**_data)
    err.update(tmp)
    # energy term is per-atom; force term weighted by force_loss_weight
    val = (func(e1 / natoms, e2 / natoms)).mean() + conf['force_loss_weight'] * func(
        f1, f2
    ).mean()
    if conf['is_train_stress']:
        # stress term is unit-converted before weighting
        KB = 1602.1766208
        val += conf['stress_loss_weight'] * func(s1 * KB, s2 * KB).mean()
    assert np.allclose(err.get(), val.item())
    # recording the same batch again must not change the running value
    err.update(tmp)
    assert np.allclose(err.get(), val.item())
@pytest.mark.parametrize(
    'conf', [config(), config(is_train_stress=False), config(loss='huber')]
)
def test_error_recorder_from_config(conf):
    """Recorder built from config must include a TotalLoss metric using the
    configured criterion, and no stress metrics when stress is disabled."""
    recorder = erc.ErrorRecorder.from_config(conf)
    total_loss_flag = False
    for metric in recorder.metrics:
        if conf['is_train_stress'] is False:
            assert 'stress' not in metric.name
        if metric.name == 'TotalLoss':
            total_loss_flag = True
            # every component of TotalLoss must use the configured loss type
            for loss_metric, _ in metric.metrics:  # type: ignore
                assert isinstance(loss_metric.func, loss_dict[conf['loss']])
    assert total_loss_flag
@pytest.mark.parametrize(
    'conf', [config(), config(is_train_stress=False), config(loss='huber')]
)
def test_error_recorder_from_config_and_loss_functions(conf):
    """Same checks as test_error_recorder_from_config, but with the loss
    functions constructed explicitly and passed to the recorder."""
    loss_functions = loss.get_loss_functions_from_config(conf)
    recorder = erc.ErrorRecorder.from_config(conf, loss_functions)
    total_loss_flag = False
    for metric in recorder.metrics:
        if conf['is_train_stress'] is False:
            assert 'stress' not in metric.name
        if metric.name == 'TotalLoss':
            total_loss_flag = True
            # every component's criterion must match the configured loss type
            for loss_metric, _ in metric.metrics:  # type: ignore
                assert isinstance(
                    loss_metric.loss_def.criterion, loss_dict[conf['loss']]
                )
    assert total_loss_flag
# # deploy is tested via lammps
# test appending a modality:
# - from a model without modality to a model with modality
# - from a modal model to a model with more modalities
# - different shift/scale settings
# test modality options (check num param)
# calculators with modality
import copy
# + modal checkpoint continue and test_train
# + sevenn_cp test things in test_cli
import pathlib
import pytest
from ase.build import bulk
import sevenn.train.graph_dataset as graph_ds
import sevenn.util as util
from sevenn.calculator import SevenNetCalculator
from sevenn.model_build import build_E3_equivariant_model
# cutoff radius (Angstrom) used when building the test graph dataset
cutoff = 5.0
data_root = (pathlib.Path(__file__).parent.parent / 'data').resolve()
hfo2_path = str(data_root / 'systems' / 'hfo2.extxyz')  # HfO2 sample structures
sevennet_0_path = util.pretrained_name_to_path('7net-0_11July2024')
@pytest.fixture(scope='module')
def graph_dataset_path(tmp_path_factory):
    """Build a graph dataset from the HfO2 sample; yield its processed path."""
    root = str(tmp_path_factory.mktemp('gd'))
    dataset = graph_ds.SevenNetGraphDataset(
        cutoff=cutoff,
        root=root,
        files=[hfo2_path],
        processed_name='tmp.pt',
    )
    return dataset.processed_paths[0]
# modality-related config overrides used to append a new modality to 7net-0
_modal_cfg = {
    'use_modal_node_embedding': False,
    'use_modal_self_inter_intro': True,
    'use_modal_self_inter_outro': True,
    'use_modal_output_block': True,
    'use_modality': True,
    'use_modal_wise_shift': True,  # T/F should be tested
    'use_modal_wise_scale': False,  # T/F should be tested
    'load_trainset_path': [
        {
            'data_modality': 'modal_new',
            'file_list': [{'file': hfo2_path}],
        }
    ],
}
@pytest.fixture(scope='module')
def snet_0_cp():
    """Loaded 7net-0 checkpoint, shared across the module."""
    return util.load_checkpoint(sevennet_0_path)
@pytest.fixture(scope='module')
def snet_0_calc():
    """Default SevenNetCalculator, shared across the module."""
    return SevenNetCalculator()
@pytest.fixture()
def bulk_atoms():
    """Rattled 3x3x3 Si supercell; fresh instance per test."""
    atoms = bulk('Si') * 3
    atoms.rattle()
    return atoms
def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6):
    """Assert two ASE Atoms agree on length, cell, energy, forces and stress.

    Forces and stress use 10x looser tolerances than cell and energy.
    """
    import numpy as np

    def close(lhs, rhs, r=rtol, a=atol):
        return np.allclose(lhs, rhs, rtol=r, atol=a)

    loose_rtol = rtol * 10
    loose_atol = atol * 10
    assert len(atoms1) == len(atoms2)
    assert close(atoms1.get_cell(), atoms2.get_cell())
    assert close(atoms1.get_potential_energy(), atoms2.get_potential_energy())
    assert close(atoms1.get_forces(), atoms2.get_forces(), loose_rtol, loose_atol)
    stress1 = atoms1.get_stress(voigt=False)
    stress2 = atoms2.get_stress(voigt=False)
    assert close(stress1, stress2, loose_rtol, loose_atol)
    # note: per-atom energies are intentionally not compared here
def get_modal_cfg(overwrite=None):
    """Return a deep copy of ``_modal_cfg``, optionally updated.

    Args:
        overwrite: optional dict of keys to override in the copy.
    """
    # deepcopy alone already yields an independent dict;
    # the previous chained `.copy()` was redundant
    modal_cfg = copy.deepcopy(_modal_cfg)
    if overwrite:
        modal_cfg.update(overwrite)
    return modal_cfg
@pytest.mark.parametrize(
    'cfg_overwrite',
    [
        ({}),
        ({'use_modal_wise_scale': True}),
        ({'use_modal_wise_shift': False}),
        ({'use_modal_self_inter_intro': False}),
    ],
)
def test_append_modal_sevennet_0(
    cfg_overwrite,
    snet_0_cp,
    snet_0_calc,
    bulk_atoms,
    graph_dataset_path,
    tmp_path,
):
    """Appending a new modality to 7net-0 must not change predictions made
    with the original ('pbe') modality."""
    modal_cfg = snet_0_cp.config
    # dataset paths stored in the checkpoint are stale; drop them and point
    # the trainset at the fixture's processed dataset instead
    modal_cfg.pop('load_dataset_path')
    modal_cfg.pop('load_validset_path')
    modal_cfg.update(get_modal_cfg(cfg_overwrite))
    modal_cfg['shift'] = 'elemwise_reference_energies'
    modal_cfg['scale'] = 'per_atom_energy_std'
    modal_cfg['load_trainset_path'][0]['file_list'] = [{'file': graph_dataset_path}]
    new_state_dict = snet_0_cp.append_modal(
        modal_cfg, original_modal_name='pbe', working_dir=tmp_path
    )
    sevennet_0_w_modal = build_E3_equivariant_model(modal_cfg)
    sevennet_0_w_modal.load_state_dict(new_state_dict, strict=True)
    # compare original calculator vs. modal model queried with 'pbe'
    atoms1 = bulk_atoms
    atoms2 = copy.deepcopy(atoms1)
    atoms1.calc = snet_0_calc
    atoms2.calc = SevenNetCalculator(
        model=sevennet_0_w_modal, file_type='model_instance', modal='pbe'
    )
    assert_atoms(atoms1, atoms2)
import pytest
import torch
from ase.build import bulk, molecule
from ase.data import chemical_symbols
from torch_geometric.loader.dataloader import Collater
import sevenn.train.dataload as dl
from sevenn.atom_graph_data import AtomGraphData
from sevenn.model_build import build_E3_equivariant_model
from sevenn.nn.sequential import AtomGraphSequential
from sevenn.util import chemical_species_preprocess
# cutoff radius (Angstrom) for the sample graphs below
cutoff = 4.0
# periodic, molecular and single-atom sample structures
_samples = {
    'bulk': bulk('NaCl', 'rocksalt', a=5.63),
    'mol': molecule('H2O'),
    'isolated': molecule('H'),
}
n_samples = len(_samples)
n_atoms_total = sum([len(at) for at in _samples.values()])
# pre-built unlabeled graphs; tests take clones so the originals stay pristine
_graph_list = [
    AtomGraphData.from_numpy_dict(dl.unlabeled_atoms_to_graph(at, cutoff))
    for at in list(_samples.values())
]
def test_chemical_species_preprocess():
    """chemical_species_preprocess must dedupe and sort species and build a
    Z -> node-index map; universal=True covers the whole periodic table."""
    chems = ['He', 'H', 'Be', 'H']
    cf = chemical_species_preprocess(chems, universal=False)
    # duplicates removed, species sorted alphabetically
    assert cf['chemical_species'] == ['Be', 'H', 'He']
    assert cf['_number_of_species'] == 3
    assert cf['_type_map'] == {4: 0, 1: 1, 2: 2}
    cf = chemical_species_preprocess(chems, universal=True)
    assert cf['chemical_species'] == chemical_symbols
    assert cf['_number_of_species'] == len(chemical_symbols)
    assert len(cf['_type_map']) == len(chemical_symbols)
    # universal mapping is the identity: Z == node index
    for z, node_idx in cf['_type_map'].items():
        assert z == node_idx
def get_graphs(batched):
    """Return fresh clones of the sample graphs, collated when `batched`."""
    copies = []
    for graph in _graph_list:
        copies.append(graph.clone())
    if batched:
        return Collater(copies)(copies)
    return copies
def get_model_config():
    """Return a minimal model config (4 channels, 3 conv layers) whose
    species list covers every element in the sample structures."""
    config = {
        'cutoff': cutoff,
        'channel': 4,
        'radial_basis': {
            'radial_basis_name': 'bessel',
        },
        'cutoff_function': {'cutoff_function_name': 'poly_cut'},
        'interaction_type': 'nequip',
        'lmax': 2,
        'is_parity': True,
        'num_convolution_layer': 3,
        'weight_nn_hidden_neurons': [64, 64],
        'act_radial': 'silu',
        'act_scalar': {'e': 'silu', 'o': 'tanh'},
        'act_gate': {'e': 'silu', 'o': 'tanh'},
        'conv_denominator': 30.0,
        'train_denominator': False,
        'self_connection_type': 'nequip',
        'shift': -10.0,
        'scale': 10.0,
        'train_shift_scale': False,
        'irreps_manual': False,
        'lmax_edge': -1,
        'lmax_node': -1,
        'readout_as_fcn': False,
        'use_bias_in_linear': False,
        '_normalize_sph': True,
    }
    # register every chemical species appearing in the samples
    chems = set()
    for at in list(_samples.values()):
        chems.update(at.get_chemical_symbols())
    config.update(**chemical_species_preprocess(list(chems)))
    return config
def get_model(config_overwrite=None):
    """Build a small E3-equivariant model from the default test config.

    Args:
        config_overwrite: optional dict of config keys to override.
            Defaults to no overrides. (Previously a mutable default ``{}``,
            which is shared across calls and a classic Python pitfall.)
    """
    cf = get_model_config()
    if config_overwrite is not None:
        cf.update(**config_overwrite)
    model = build_E3_equivariant_model(cf, parallel=False)
    assert isinstance(model, AtomGraphSequential)
    return model
@pytest.mark.parametrize('batched', [False, True])
@pytest.mark.parametrize('cf', [{}])
def test_shape(cf, batched):
    """Model outputs must have the documented shapes in both single-graph
    and batched modes."""
    model = get_model(cf)
    model.set_is_batch_data(batched)
    graph = get_graphs(batched)
    if not batched:
        # per-structure shapes, checked graph by graph
        output_shapes = {
            'inferred_total_energy': (),
            'inferred_stress': (6,),
        }
        for g in graph:
            natoms = g['num_atoms']
            output_shapes.update(
                {
                    'atomic_energy': (natoms, 1),  # intended
                    'inferred_force': (natoms, 3),
                }
            )
            output = model(g)
            for k, shape in output_shapes.items():
                assert output[k].shape == shape, f'{k}: {output[k].shape} != {shape}'
    else:
        # batched shapes: energies/stresses per structure, forces per atom
        output_shapes = {
            'inferred_total_energy': (n_samples,),
            'atomic_energy': (n_atoms_total, 1),  # intended
            'inferred_force': (n_atoms_total, 3),
            'inferred_stress': (n_samples, 6),
        }
        output = model(graph)
        for k, shape in output_shapes.items():
            assert output[k].shape == shape, f'{k}: {output[k].shape} != {shape}'
def test_batch():
    """Concatenated single-graph inference must match batched inference."""
    model = get_model()
    model.set_is_batch_data(False)
    graph_list = get_graphs(batched=False)
    output_list = [model(g) for g in graph_list]
    model.set_is_batch_data(True)
    graph_batch = get_graphs(batched=True)
    output_batched = model(graph_batch)
    # assemble per-graph outputs into the batched layout
    e_concat = torch.concat(
        [g['inferred_total_energy'].unsqueeze(-1) for g in output_list]
    )
    ae_concat = torch.concat([g['atomic_energy'].squeeze(1) for g in output_list])
    f_concat = torch.concat([g['inferred_force'] for g in output_list])
    s_concat = torch.stack([g['inferred_stress'] for g in output_list])
    assert torch.allclose(e_concat, output_batched['inferred_total_energy'])
    assert torch.allclose(ae_concat, output_batched['atomic_energy'].squeeze(1))
    # forces compared after rounding to absorb accumulation-order noise
    assert torch.allclose(
        torch.round(f_concat, decimals=5),
        torch.round(output_batched['inferred_force'], decimals=5),
        atol=1e-5,
    )
    assert torch.allclose(  # TODO, hard-coded, assumes the first structure is bulk
        torch.round(s_concat[0], decimals=5),
        torch.round(output_batched['inferred_stress'][0], decimals=5),
    )
# (model config overrides, expected trainable parameter count)
_n_param_tests = [
    ({}, 20642),  # baseline config from get_model_config()
    ({'train_denominator': True}, 20642 + 3),
    ({'train_shift_scale': True}, 20642 + 2),
    ({'shift': [1.0] * 4}, 20642),
    ({'scale': [1.0] * 4, 'train_shift_scale': True}, 20642 + 8),
    ({'num_convolution_layer': 4}, 33458),
    ({'lmax': 3}, 26866),
    ({'channel': 2}, 16883),
    ({'is_parity': False}, 20386),
    ({'self_connection_type': 'linear'}, 20114),
]
@pytest.mark.parametrize('cf,ref', _n_param_tests)
def test_num_params(cf, ref):
    """Trainable-parameter count must match the expected value per config."""
    model = get_model(cf)
    # generator expression avoids materializing an intermediate list
    param = sum(p.numel() for p in model.parameters() if p.requires_grad)
    assert param == ref, f'ref: {ref} != given: {param}'
_n_modal_param_tests = [
({}, 20642),
({'use_modal_node_embedding': True}, 20642 + 8),
({'use_modal_self_inter_intro': True}, 20642 + 2 * 4 * 3),
({'use_modal_self_inter_outro': True}, 20642 + 2 * (12 + 20 + 4)),
({'use_modal_output_block': True}, 20642 + 2 * 4 / 2),
]
@pytest.mark.parametrize('cf,ref', _n_modal_param_tests)
def test_modal_num_params(cf, ref):
    """Enabling each modal component must add the expected number of
    trainable parameters on top of the baseline."""
    # start from an all-off modal config with two modalities, then apply cf
    modal_cfg = {
        'use_modality': True,
        '_number_of_modalities': 2,
        '_modal_map': {'x1': 0, 'x2': 1},
        'use_modal_node_embedding': False,
        'use_modal_self_inter_intro': False,
        'use_modal_self_inter_outro': False,
        'use_modal_output_block': False,
        'use_modal_wise_shift': False,
        'use_modal_wise_scale': False,
    }
    modal_cfg.update(cf)
    model = get_model(modal_cfg)
    # generator expression avoids materializing an intermediate list
    param = sum(p.numel() for p in model.parameters() if p.requires_grad)
    assert param == ref, f'ref: {ref} != given: {param}'
# TODO: test_irreps, test_gard, test_equivariance
# test_pretrained: output consistency for pretrained models
import pytest
import torch
from ase.build import bulk, molecule
import sevenn._keys as KEY
from sevenn.atom_graph_data import AtomGraphData
from sevenn.train.dataload import unlabeled_atoms_to_graph
from sevenn.util import model_from_checkpoint, pretrained_name_to_path
def acl(a, b, atol=1e-6):
    """Return True when tensors *a* and *b* agree within absolute tolerance."""
    close = torch.allclose(a, b, atol=atol)
    return close
@pytest.fixture
def atoms_pbc():
    """Periodic NaCl rocksalt cell with hand-set cell and positions."""
    atoms = bulk('NaCl', 'rocksalt', a=5.63)
    cell = [[1.0, 2.815, 2.815], [2.815, 0.0, 2.815], [2.815, 2.815, 0.0]]
    atoms.set_cell(cell)
    atoms.set_positions([[0.0, 0.0, 0.0], [2.815, 0.0, 0.0]])
    return atoms
@pytest.fixture
def atoms_mol():
    """Non-periodic H2O molecule with hand-set coordinates."""
    water = molecule('H2O')
    coords = [[0.0, 0.2, 0.12], [0.0, 0.76, -0.48], [0.0, -0.76, -0.48]]
    water.set_positions(coords)
    return water
def test_7net0_22May2024(atoms_pbc, atoms_mol):
    """
    Reference from v0.9.3.post1 with SevenNetCalculator

    Regression test: the 7net-0 (22May2024) checkpoint must reproduce
    previously recorded energies, forces, and (bulk only) stress.
    """
    cp_path = pretrained_name_to_path('7net-0_22May2024')
    model, config = model_from_checkpoint(cp_path)
    cutoff = config['cutoff']
    g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff))
    g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff))
    model.set_is_batch_data(False)  # evaluate structures one by one
    g1 = model(g1)
    g2 = model(g2)
    g1_ref_e = torch.tensor([-3.4140868186950684])
    g1_ref_f = torch.tensor(
        [
            [1.2628037e01, 7.5093508e-03, 1.3480943e-02],
            [-1.2628037e01, -7.5093508e-03, -1.3480917e-02],
        ]
    )
    # -1: recorded reference presumably uses the opposite stress sign
    # convention — confirm against the recording script if touched.
    g1_ref_s = -1 * torch.tensor(
        [-0.65014917, -0.01990843, -0.02000658, 0.03286226, 0.00589222, 0.03291973]
    )
    g2_ref_e = torch.tensor([-12.808363914489746])
    g2_ref_f = torch.tensor(
        [
            [9.31322575e-10, -1.30241165e01, 6.93116236e00],
            [-1.39698386e-09, 9.28001022e00, -9.51867390e00],
            [5.23868948e-10, 3.74410582e00, 2.58751225e00],
        ]
    )
    assert acl(g1.inferred_total_energy, g1_ref_e)
    assert acl(g1.inferred_force, g1_ref_f)
    assert acl(g1.inferred_stress, g1_ref_s)
    assert acl(g2.inferred_total_energy, g2_ref_e)
    # No stress assertion for g2: it is a non-periodic molecule.
    assert acl(g2.inferred_force, g2_ref_f)
def test_7net0_11July2024(atoms_pbc, atoms_mol):
    """
    Reference from v0.9.3.post1 with SevenNetCalculator

    Regression test for the 7net-0 (11July2024) checkpoint against
    recorded energies, forces, and (bulk only) stress.
    """
    cp_path = pretrained_name_to_path('7net-0_11July2024')
    model, config = model_from_checkpoint(cp_path)
    cutoff = config['cutoff']
    g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff))
    g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff))
    model.set_is_batch_data(False)  # evaluate structures one by one
    g1 = model(g1)
    g2 = model(g2)
    model.set_is_batch_data(True)  # restore batch mode for later users
    g1_ref_e = torch.tensor([-3.779199])
    g1_ref_f = torch.tensor(
        [
            [12.666697, 0.04726403, 0.04775861],
            [-12.666697, -0.04726403, -0.04775861],
        ]
    )
    g1_ref_s = -1 * torch.tensor(
        # xx, yy, zz, xy, yz, zx
        [-0.6439122, -0.03643947, -0.03643981, 0.04543639, 0.00599139, 0.04544507]
    )
    g2_ref_e = torch.tensor([-12.782808303833008])
    g2_ref_f = torch.tensor(
        [
            [0.0, -1.3619621e01, 7.5937047e00],
            [0.0, 9.3918495e00, -1.0172190e01],
            [0.0, 4.2277718e00, 2.5784855e00],
        ]
    )
    assert acl(g1.inferred_total_energy, g1_ref_e)
    assert acl(g1.inferred_force, g1_ref_f)
    assert acl(g1.inferred_stress, g1_ref_s)
    assert acl(g2.inferred_total_energy, g2_ref_e)
    # No stress assertion for g2: it is a non-periodic molecule.
    assert acl(g2.inferred_force, g2_ref_f)
def test_7net_l3i5(atoms_pbc, atoms_mol):
    """
    Reference from v0.9.3.post1 with SevenNetCalculator

    Regression test for the 7net-l3i5 checkpoint.  Force/stress use a
    looser 1e-5 tolerance for the bulk structure.
    """
    cp_path = pretrained_name_to_path('7net-l3i5')
    model, config = model_from_checkpoint(cp_path)
    cutoff = config['cutoff']
    g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff))
    g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff))
    model.set_is_batch_data(False)  # evaluate structures one by one
    g1 = model(g1)
    g2 = model(g2)
    model.set_is_batch_data(True)  # restore batch mode for later users
    g1_ref_e = torch.tensor([-3.611131191253662])
    g1_ref_f = torch.tensor(
        [
            [13.430887, 0.08655541, 0.08754013],
            [-13.430886, -0.08655544, -0.08754011],
        ]
    )
    g1_ref_s = -1 * torch.tensor(
        # xx, yy, zz, xy, yz, zx
        [-0.6818918, -0.04104544, -0.04107663, 0.04794561, 0.00565416, 0.04793138]
    )
    g2_ref_e = torch.tensor([-12.700481414794922])
    g2_ref_f = torch.tensor(
        [
            [0.0, -1.4547814e01, 8.1347866],
            [0.0, 1.0308369e01, -1.0880318e01],
            [0.0, 4.2394452, 2.7455316],
        ]
    )
    assert acl(g1.inferred_total_energy, g1_ref_e)
    assert acl(g1.inferred_force, g1_ref_f, 1e-5)
    assert acl(g1.inferred_stress, g1_ref_s, 1e-5)
    assert acl(g2.inferred_total_energy, g2_ref_e)
    # No stress assertion for g2: it is a non-periodic molecule.
    assert acl(g2.inferred_force, g2_ref_f)
def test_7net_mf_0(atoms_pbc, atoms_mol):
    """Regression test for the multi-fidelity 7net-mf-0 checkpoint.

    Both graphs are evaluated under the 'R2SCAN' modality and compared
    against recorded reference values.
    """
    cp_path = pretrained_name_to_path('7net-mf-0')
    model, config = model_from_checkpoint(cp_path)
    cutoff = config['cutoff']
    g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff))
    g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff))
    g1[KEY.DATA_MODALITY] = 'R2SCAN'
    g2[KEY.DATA_MODALITY] = 'R2SCAN'
    model.set_is_batch_data(False)  # evaluate structures one by one
    g1 = model(g1)
    g2 = model(g2)
    model.set_is_batch_data(True)  # restore batch mode for later users
    g1_ref_e = torch.tensor([-11.607587814331055])
    g1_ref_f = torch.tensor(
        [
            [8.512259, 0.07307914, 0.06676716],
            [-8.512257, -0.07307915, -0.06676716],
        ]
    )
    g1_ref_s = -1 * torch.tensor(
        # xx, yy, zz, xy, yz, zx
        [-0.4516204, -0.02483013, -0.02485001, 0.03247492, 0.00259375, 0.03250402]
    )
    g2_ref_e = torch.tensor([-14.172412872314453])
    g2_ref_f = torch.tensor(
        [
            [4.6566129e-10, -1.3429364e01, 6.9344816e00],
            [2.3283064e-09, 8.9132404e00, -9.6807365e00],
            [-2.7939677e-09, 4.5161238e00, 2.7462559e00],
        ]
    )
    assert acl(g1.inferred_total_energy, g1_ref_e)
    assert acl(g1.inferred_force, g1_ref_f)
    assert acl(g1.inferred_stress, g1_ref_s)
    assert acl(g2.inferred_total_energy, g2_ref_e)
    # No stress assertion for g2: it is a non-periodic molecule.
    assert acl(g2.inferred_force, g2_ref_f)
def test_7net_mf_ompa_mpa(atoms_pbc, atoms_mol):
    """Regression test for 7net-mf-ompa evaluated under the 'mpa' modality."""
    cp_path = pretrained_name_to_path('7net-mf-ompa')
    model, config = model_from_checkpoint(cp_path)
    cutoff = config['cutoff']
    g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff))
    g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff))
    # mpa
    g1[KEY.DATA_MODALITY] = 'mpa'
    g2[KEY.DATA_MODALITY] = 'mpa'
    model.set_is_batch_data(False)  # evaluate structures one by one
    g1 = model(g1)
    g2 = model(g2)
    model.set_is_batch_data(True)  # restore batch mode for later users
    g1_ref_e = torch.tensor([-3.490943193435669])
    g1_ref_f = torch.tensor(
        [
            [1.2680445e01, -2.7985498e-04, -2.7979910e-04],
            [-1.2680446e01, 2.7984008e-04, 2.7981028e-04],
        ]
    )
    g1_ref_s = -1 * torch.tensor(
        # xx, yy, zz, xy, yz, zx
        [-0.6481662, -0.02462837, -0.02462837, 0.02693467, 0.00459635, 0.02693467]
    )
    g2_ref_e = torch.tensor([-12.597525596618652])
    g2_ref_f = torch.tensor(
        [
            [0.0, -12.245223, 7.26795],
            [0.0, 8.816763, -9.423925],
            [0.0, 3.4284601, 2.1559749],
        ]
    )
    assert acl(g1.inferred_total_energy, g1_ref_e)
    assert acl(g1.inferred_force, g1_ref_f)
    assert acl(g1.inferred_stress, g1_ref_s)
    assert acl(g2.inferred_total_energy, g2_ref_e)
    # No stress assertion for g2: it is a non-periodic molecule.
    assert acl(g2.inferred_force, g2_ref_f)
def test_7net_mf_ompa_omat(atoms_pbc, atoms_mol):
    """Regression test for 7net-mf-ompa evaluated under the 'omat24' modality."""
    cp_path = pretrained_name_to_path('7net-mf-ompa')
    model, config = model_from_checkpoint(cp_path)
    cutoff = config['cutoff']
    g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff))
    g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff))
    # omat24 (previous comment said "mpa" — copy/paste leftover)
    g1[KEY.DATA_MODALITY] = 'omat24'
    g2[KEY.DATA_MODALITY] = 'omat24'
    model.set_is_batch_data(False)  # evaluate structures one by one
    g1 = model(g1)
    g2 = model(g2)
    model.set_is_batch_data(True)  # restore batch mode for later users
    g1_ref_e = torch.tensor([-3.5094668865203857])
    g1_ref_f = torch.tensor(
        [
            [1.2562084e01, -1.4219694e-03, -1.4219843e-03],
            [-1.2562084e01, 1.4219508e-03, 1.4219955e-03],
        ]
    )
    g1_ref_s = -1 * torch.tensor(
        # xx, yy, zz, xy, yz, zx
        [-0.6430905, -0.0254128, -0.02541281, 0.0268343, 0.00460021, 0.0268343]
    )
    g2_ref_e = torch.tensor([-12.6202974319458])
    g2_ref_f = torch.tensor(
        [
            [0.0, -12.205926, 7.2050343],
            [0.0, 8.790399, -9.368677],
            [0.0, 3.4155273, 2.163643],
        ]
    )
    assert acl(g1.inferred_total_energy, g1_ref_e)
    assert acl(g1.inferred_force, g1_ref_f)
    assert acl(g1.inferred_stress, g1_ref_s)
    assert acl(g2.inferred_total_energy, g2_ref_e)
    # No stress assertion for g2: it is a non-periodic molecule.
    assert acl(g2.inferred_force, g2_ref_f)
def test_7net_omat(atoms_pbc, atoms_mol):
    """Regression test for the 7net-omat checkpoint (no modality needed)."""
    cp_path = pretrained_name_to_path('7net-omat')
    model, config = model_from_checkpoint(cp_path)
    cutoff = config['cutoff']
    g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff))
    g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff))
    model.set_is_batch_data(False)  # evaluate structures one by one
    g1 = model(g1)
    g2 = model(g2)
    model.set_is_batch_data(True)  # restore batch mode for later users
    g1_ref_e = torch.tensor([-3.5033323764801025])
    g1_ref_f = torch.tensor(
        [
            [12.533154, 0.02358698, 0.02358694],
            [-12.533153, -0.02358699, -0.02358697],
        ]
    )
    g1_ref_s = -1 * torch.tensor(
        # xx, yy, zz, xy, yz, zx
        [-0.6420925, -0.02781446, -0.02781446, 0.02575445, 0.00381664, 0.02575445]
    )
    g2_ref_e = torch.tensor([-12.403768539428711])
    g2_ref_f = torch.tensor(
        [
            [0, -12.848297, 7.11432],
            [0.0, 9.265477, -9.564951],
            [0.0, 3.58282, 2.4506311],
        ]
    )
    assert acl(g1.inferred_total_energy, g1_ref_e)
    assert acl(g1.inferred_force, g1_ref_f)
    assert acl(g1.inferred_stress, g1_ref_s)
    assert acl(g2.inferred_total_energy, g2_ref_e)
    # No stress assertion for g2: it is a non-periodic molecule.
    assert acl(g2.inferred_force, g2_ref_f)
import pytest
import torch
import sevenn._keys as KEY
from sevenn._const import NUM_UNIV_ELEMENT, AtomGraphDataType
from sevenn.nn.scale import (
ModalWiseRescale,
Rescale,
SpeciesWiseRescale,
get_resolved_shift_scale,
)
################################################################################
# Tests for Rescale #
################################################################################
@pytest.mark.parametrize('shift,scale', [(0.0, 1.0), (1.0, 2.0), (-5.0, 10.0)])
def test_rescale_init(shift, scale):
    """Rescale must store shift/scale and wire the default data keys."""
    rescaler = Rescale(shift=shift, scale=scale)
    assert rescaler.shift.item() == shift
    assert rescaler.scale.item() == scale
    # Default I/O keys: scaled atomic energy in, atomic energy out.
    assert rescaler.key_input == KEY.SCALED_ATOMIC_ENERGY
    assert rescaler.key_output == KEY.ATOMIC_ENERGY
def test_rescale_forward():
    """Forward pass must compute output = input * scale + shift."""
    shift, scale = 1.0, 2.0
    rescaler = Rescale(shift=shift, scale=scale)

    inputs = torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.float)
    data: AtomGraphDataType = {KEY.SCALED_ATOMIC_ENERGY: inputs.clone()}

    result = rescaler(data)

    assert torch.allclose(result[KEY.ATOMIC_ENERGY], inputs * scale + shift)
def test_rescale_get_shift_and_scale():
    """get_shift()/get_scale() must round-trip the constructor arguments."""
    rescaler = Rescale(shift=1.5, scale=3.5)
    assert pytest.approx(1.5) == rescaler.get_shift()
    assert pytest.approx(3.5) == rescaler.get_scale()
################################################################################
# Tests for SpeciesWiseRescale #
################################################################################
def test_specieswise_rescale_init_float():
    """
    Test SpeciesWiseRescale when shift is a list and scale is a float:
    the float should be broadcast so both parameters end up the same shape.
    """
    module = SpeciesWiseRescale(shift=[1.0, -1.0], scale=2.0)
    # NOTE(review): the original comments here claimed that direct init
    # with two floats raises ValueError and suggested from_mappers; this
    # test only checks that list + float init yields matching shapes.
    assert module.shift.shape == module.scale.shape
def test_specieswise_rescale_init_list():
    """Same-length list shift/scale must be stored element-for-element."""
    shifts = [1.0, 2.0, 3.0]
    scales = [2.0, 3.0, 4.0]
    rescaler = SpeciesWiseRescale(shift=shifts, scale=scales)
    assert len(rescaler.shift) == 3
    assert len(rescaler.scale) == 3
    assert torch.allclose(rescaler.shift, torch.tensor(shifts))
    assert torch.allclose(rescaler.scale, torch.tensor(scales))
def test_specieswise_rescale_forward():
    """
    Forward pass must apply per-species affine rescaling:
    output[i] = input[i] * scale[species[i]] + shift[species[i]]
    """
    # Species 0: shift=1, scale=2.  Species 1: shift=5, scale=10.
    rescaler = SpeciesWiseRescale(
        shift=[1.0, 5.0],
        scale=[2.0, 10.0],
        data_key_in='in',
        data_key_out='out',
        data_key_indices='z',
    )

    # Three atoms with species [0, 1, 0] and inputs [1, 1, 3].
    data: AtomGraphDataType = {
        'z': torch.tensor([0, 1, 0], dtype=torch.long),
        'in': torch.tensor([[1.0], [1.0], [3.0]], dtype=torch.float),
    }
    result = rescaler(data)

    # Hand-computed: 1*2+1=3, 1*10+5=15, 3*2+1=7.
    expected = torch.tensor([[3.0], [15.0], [7.0]])
    assert torch.allclose(result['out'], expected)
def test_specieswise_rescale_get_shift_scale():
    """
    Test get_shift() and get_scale() with/without type_map.
    """
    shift = [1.0, 2.0]
    scale = [3.0, 4.0]
    module = SpeciesWiseRescale(shift=shift, scale=scale)
    # Without type_map: raw parameter values come back as plain lists.
    s = module.get_shift()
    sc = module.get_scale()
    assert s == [1.0, 2.0]
    assert sc == [3.0, 4.0]
    # With a type_map (example: atomic_number 1 -> 0, 8 -> 1)
    type_map = {1: 0, 8: 1}  # hydrogen, oxygen
    s_univ = module.get_shift(type_map)
    sc_univ = module.get_scale(type_map)
    # With a type_map the values are expanded to a universal list of
    # length NUM_UNIV_ELEMENT, indexed by atomic number; unmapped slots
    # are padded with default values.
    assert len(s_univ) == NUM_UNIV_ELEMENT
    assert len(sc_univ) == NUM_UNIV_ELEMENT
    assert s_univ[1] == 1.0  # atomic_number=1 -> idx=0 -> shift=1.0
    assert s_univ[8] == 2.0  # atomic_number=8 -> idx=1 -> shift=2.0
################################################################################
# Tests for ModalWiseRescale #
################################################################################
def test_modalwise_rescale_init():
    """
    ModalWiseRescale initialized with 2 modals x 3 species must store
    shift/scale tensors of shape [2, 3].
    """
    shift_rows = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
    scale_rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
    rescaler = ModalWiseRescale(
        shift=shift_rows,
        scale=scale_rows,
        use_modal_wise_shift=True,
        use_modal_wise_scale=True,
    )
    expected_shape = torch.Size([2, 3])
    assert rescaler.shift.shape == expected_shape
    assert rescaler.scale.shape == expected_shape
def test_modalwise_rescale_forward():
    """
    Test that the forward pass of ModalWiseRescale matches
    output[i] = input[i] * scale[modal_i, atom_i] + shift[modal_i, atom_i]
    when both use_modal_wise_{shift,scale} are True.
    """
    shift = [[0.0, 10.0], [5.0, 15.0]]  # shape [2 (modals), 2 (species)]
    scale = [[1.0, 2.0], [10.0, 20.0]]
    module = ModalWiseRescale(
        shift=shift,
        scale=scale,
        data_key_in='in',
        data_key_out='out',
        data_key_modal_indices='modal_idx',
        data_key_atom_indices='atom_idx',
        use_modal_wise_shift=True,
        use_modal_wise_scale=True,
    )
    # modal_idx is per-structure; 'batch' maps each atom to its structure,
    # so atoms 0-1 use modal 0 and atoms 2-3 use modal 1.
    data: AtomGraphDataType = {
        'in': torch.tensor([[1.0], [1.0], [2.0], [2.0]]),
        'modal_idx': torch.tensor([0, 1], dtype=torch.long),
        'atom_idx': torch.tensor([0, 1, 0, 1], dtype=torch.long),
        'batch': torch.tensor([0, 0, 1, 1], dtype=torch.long),
    }
    out = module(data)
    # i=0 => modal_idx=0, atom_idx=0 => shift=0.0, scale=1.0 => out=1*1+0=1
    # i=1 => modal_idx=0, atom_idx=1 => shift=10.0, scale=2.0 => out=1*2+10=12
    # i=2 => modal_idx=1, atom_idx=0 => shift=5.0, scale=10.0 => out=2*10+5=25
    # i=3 => modal_idx=1, atom_idx=1 => shift=15.0, scale=20.0 => out=2*20+15=55
    expected = torch.tensor([[1.0], [12.0], [25.0], [55.0]])
    assert torch.allclose(out['out'], expected)
def test_modalwise_rescale_get_shift_scale():
    """
    Test get_shift() and get_scale() with type_map and modal_map.
    """
    # Setup
    shift = [[0.0, 10.0], [5.0, 15.0]]
    scale = [[1.0, 2.0], [10.0, 20.0]]
    mod = ModalWiseRescale(
        shift=shift,
        scale=scale,
        use_modal_wise_shift=True,
        use_modal_wise_scale=True,
    )
    # Suppose we have type_map and modal_map
    type_map = {1: 0, 8: 1}  # Example: H->0, O->1
    modal_map = {'a': 0, 'b': 1}
    # get_shift, get_scale
    s = mod.get_shift(type_map=type_map, modal_map=modal_map)
    sc = mod.get_scale(type_map=type_map, modal_map=modal_map)
    # Expect dicts keyed by modal name ('a', 'b'), each mapping to the
    # per-species values for that modal:
    #   s['a'] = [shift(0,0), shift(0,1)] mapped to H, O
    #   s['b'] = [shift(1,0), shift(1,1)]
    # (earlier comments said "ambient"/"pressure" — stale example names)
    assert isinstance(s, dict) and isinstance(sc, dict)
    assert set(s.keys()) == {'a', 'b'}
    assert set(sc.keys()) == {'a', 'b'}
################################################################################
# Tests for get_resolved_shift_scale function #
################################################################################
def test_get_resolved_shift_scale_rescale():
    """For a plain Rescale, resolution just returns the two scalars."""
    module = Rescale(shift=2.0, scale=5.0)
    resolved_shift, resolved_scale = get_resolved_shift_scale(module)
    assert resolved_shift == 2.0
    assert resolved_scale == 5.0
def test_get_resolved_shift_scale_specieswise():
    """
    For SpeciesWiseRescale, resolution with a type_map must return
    universal lists indexed by atomic number.
    """
    shifts = [1.0, 2.0]
    scales = [3.0, 4.0]
    module = SpeciesWiseRescale(shift=shifts, scale=scales)

    resolved_shift, resolved_scale = get_resolved_shift_scale(
        module, type_map={1: 0, 8: 1}
    )

    assert isinstance(resolved_shift, list)
    assert isinstance(resolved_scale, list)
    # Universal positions 1 (H) and 8 (O) carry the mapped values.
    assert resolved_shift[1] == shifts[0]
    assert resolved_shift[8] == shifts[1]
    assert resolved_scale[1] == scales[0]
    assert resolved_scale[8] == scales[1]
def test_get_resolved_shift_scale_modalwise():
    """
    Test get_resolved_shift_scale for a ModalWiseRescale instance.
    """
    shift = [[0.0, 10.0], [5.0, 15.0]]
    scale = [[1.0, 2.0], [10.0, 20.0]]
    mmod = ModalWiseRescale(
        shift=shift,
        scale=scale,
        use_modal_wise_shift=True,
        use_modal_wise_scale=True,
    )
    type_map = {1: 0, 8: 1}
    modal_map = {'a': 0, 'b': 1}
    s, sc = get_resolved_shift_scale(mmod, type_map=type_map, modal_map=modal_map)
    # We expect dictionaries
    assert isinstance(s, dict) and isinstance(sc, dict)
    # Keys are the modal names "a" and "b" (old comment said "pressure")
    assert 'a' in s
    assert 'b' in s
    # Values are universal lists indexed by atomic number, e.g.
    # s["a"][1] (H) and s["a"][8] (O) hold shift row 0.
    assert s['a'][1] == 0.0
    assert s['a'][8] == 10.0
    assert sc['a'][1] == 1.0
    assert sc['a'][8] == 2.0
################################################################################
# Tests for from_mappers function #
################################################################################
@pytest.mark.parametrize(
    'shift, scale, type_map, expected_shift, expected_scale',
    [
        # Both shift and scale are floats -> broadcast to each species
        (
            2.0,
            3.0,
            {1: 0, 8: 1},  # e.g., H -> index 0, O -> index 1
            [2.0, 2.0],  # broadcast
            [3.0, 3.0],
        ),
        # shift, scale are same-length lists => directly used
        (
            [0.5, 0.6],
            [1.0, 1.1],
            {1: 0, 8: 1},
            [0.5, 0.6],
            [1.0, 1.1],
        ),
        # shift, scale are entire "universal" length (NUM_UNIV_ELEMENT=118),
        # but we only map out the subset for the actual species in type_map
        (
            [0.1] * NUM_UNIV_ELEMENT,
            [1.1] * NUM_UNIV_ELEMENT,
            {1: 0, 8: 1},
            [0.1, 0.1],
            [1.1, 1.1],
        ),
        # shift is a list, scale is float => shift is used directly, scale broadcast
        (
            [1.0, 2.0],
            5.0,
            {6: 0, 14: 1},  # C -> 0, Si -> 1
            [1.0, 2.0],
            [5.0, 5.0],
        ),
    ],
)
def test_specieswise_rescale_from_mappers(
    shift, scale, type_map, expected_shift, expected_scale
):
    """
    Test SpeciesWiseRescale.from_mappers with various combinations of
    shift/scale (float, list, universal list) and a given type_map.
    """
    module = SpeciesWiseRescale.from_mappers(  # type: ignore
        shift=shift,
        scale=scale,
        type_map=type_map,
    )
    # Check that the module's internal shift and scale have the correct shape
    # The length must match number of species in type_map
    assert module.shift.shape[0] == len(type_map)
    assert module.scale.shape[0] == len(type_map)
    # Check that the content matches expected
    actual_shift = module.shift.detach().cpu().tolist()
    actual_scale = module.scale.detach().cpu().tolist()
    assert pytest.approx(actual_shift) == expected_shift
    assert pytest.approx(actual_scale) == expected_scale
@pytest.mark.parametrize(
    'shift, scale, use_modal_wise_shift, use_modal_wise_scale, '
    'type_map, modal_map, expected_shift, expected_scale',
    [
        # Example 1: single float for shift/scale,
        # broadcast over 2 modals and 2 species
        (
            1.0,
            2.0,
            True,  # shift depends on modal
            True,  # scale depends on modal
            {1: 0, 8: 1},
            {'modA': 0, 'modB': 1},
            # expect 2D => [2 modals x 2 species]
            [[1.0, 1.0], [1.0, 1.0]],
            [[2.0, 2.0], [2.0, 2.0]],
        ),
        # Example 2: shift/scale are universal element-lists => use_modal=False => 1D
        (
            [0.5] * NUM_UNIV_ELEMENT,
            [1.5] * NUM_UNIV_ELEMENT,
            False,  # shift is not modal-wise
            False,  # scale is not modal-wise
            {6: 0, 14: 1},  # e.g. C->0, Si->1
            {'modA': 0, 'modB': 1},
            # 1D => length = n_atom_types(=2)
            [0.5, 0.5],
            [1.5, 1.5],
        ),
        # Example 3: shift is dict of modals -> each is float
        # => broadcast for each species
        (
            {'modA': 0.0, 'modB': 2.0},
            {'modA': 1.0, 'modB': 3.0},
            True,
            True,
            {1: 0, 8: 1},
            {'modA': 0, 'modB': 1},
            # shift => shape [2 modals, 2 species]
            [[0.0, 0.0], [2.0, 2.0]],
            [[1.0, 1.0], [3.0, 3.0]],
        ),
        # Example 4: already in "modal-wise + species-wise" shape, direct pass
        (
            [[0.0, 10.0], [5.0, 15.0]],
            [[1.0, 2.0], [10.0, 20.0]],
            True,
            True,
            {1: 0, 8: 1},
            {'modA': 0, 'modB': 1},
            [[0.0, 10.0], [5.0, 15.0]],
            [[1.0, 2.0], [10.0, 20.0]],
        ),
        # Example 5: shift is a list of floats (one per modal),
        # but we want modal-wise => broadcast for each species
        (
            [0.0, 10.0],  # length=2 => same as #modals
            [1.0, 2.0],
            True,
            True,
            {1: 0, 8: 1},
            {'modA': 0, 'modB': 1},
            [[0.0, 0.0], [10.0, 10.0]],
            [[1.0, 1.0], [2.0, 2.0]],
        ),
    ],
)
def test_modalwise_rescale_from_mappers(
    shift,
    scale,
    use_modal_wise_shift,
    use_modal_wise_scale,
    type_map,
    modal_map,
    expected_shift,
    expected_scale,
):
    """
    Test ModalWiseRescale.from_mappers for different shapes of shift/scale,
    combined with type_map and modal_map.
    """
    module = ModalWiseRescale.from_mappers(  # type: ignore
        shift=shift,
        scale=scale,
        use_modal_wise_shift=use_modal_wise_shift,
        use_modal_wise_scale=use_modal_wise_scale,
        type_map=type_map,
        modal_map=modal_map,
    )
    # Check shape of the resulting shift, scale
    # If modal-wise, we expect a 2D shape: [n_modals, n_species]
    # Otherwise, a 1D shape: [n_species]
    if use_modal_wise_shift:
        assert module.shift.dim() == 2
        assert module.shift.shape[0] == len(modal_map)
        assert module.shift.shape[1] == len(type_map)
    else:
        assert module.shift.dim() == 1
        assert module.shift.shape[0] == len(type_map)
    # Similarly for scale
    if use_modal_wise_scale:
        assert module.scale.dim() == 2
        assert module.scale.shape[0] == len(modal_map)
        assert module.scale.shape[1] == len(type_map)
    else:
        assert module.scale.dim() == 1
        assert module.scale.shape[0] == len(type_map)
    # Verify the content matches our expectation
    actual_shift = module.shift.detach().cpu().tolist()
    actual_scale = module.scale.detach().cpu().tolist()
    assert actual_shift == expected_shift
    assert actual_scale == expected_scale
import pathlib
import ase.io
import numpy as np
import pytest
import torch
from torch_geometric.loader import DataLoader
import sevenn.train.graph_dataset as graph_ds
from sevenn._const import NUM_UNIV_ELEMENT
from sevenn.error_recorder import ErrorRecorder
from sevenn.logger import Logger
from sevenn.scripts.processing_continue import processing_continue_v2
from sevenn.scripts.processing_epoch import processing_epoch_v2
from sevenn.train.dataload import graph_build
from sevenn.train.graph_dataset import from_config as dataset_from_config
from sevenn.train.loss import get_loss_functions_from_config
from sevenn.train.trainer import Trainer
from sevenn.util import (
chemical_species_preprocess,
get_error_recorder,
pretrained_name_to_path,
)
cutoff = 4.0  # cutoff radius used for every graph build in this module
data_root = (pathlib.Path(__file__).parent.parent / 'data').resolve()
hfo2_path = str(data_root / 'systems' / 'hfo2.extxyz')  # small HfO2 fixture data
cp_0_path = str(data_root / 'checkpoints' / 'cp_0.pth')  # empty test checkpoint
sevennet_0_path = pretrained_name_to_path('7net-0_11July2024')
known_elements = ['Hf', 'O']
# Per-element reference energies keyed by atomic number (72: Hf, 8: O).
_elemwise_ref_energy_dct = {72: -17.379337, 8: -34.7499924}
Logger()  # init
@pytest.fixture()
def HfO2_atoms():
    """Single HfO2 structure read from the bundled extxyz file."""
    return ase.io.read(hfo2_path)
@pytest.fixture(scope='module')
def HfO2_loader():
    """DataLoader (batch_size=2) over graphs built from every HfO2 frame."""
    frames = ase.io.read(hfo2_path, index=':')
    assert isinstance(frames, list)
    built = graph_build(frames, cutoff, y_from_calc=True)
    return DataLoader(built, batch_size=2)
@pytest.fixture(scope='module')
def graph_dataset_path(tmp_path_factory):
    """Path of a processed SevenNetGraphDataset built in a temp directory."""
    root = tmp_path_factory.mktemp('gd')
    dataset = graph_ds.SevenNetGraphDataset(
        cutoff=cutoff,
        root=str(root),
        files=[hfo2_path],
        processed_name='tmp.pt',
    )
    return dataset.processed_paths[0]
def get_model_config():
    """Return a small default model config (Hf/O species) for these tests."""
    config = {
        'cutoff': cutoff,
        'channel': 4,
        'radial_basis': {
            'radial_basis_name': 'bessel',
        },
        'cutoff_function': {'cutoff_function_name': 'poly_cut'},
        'interaction_type': 'nequip',
        'lmax': 2,
        'is_parity': True,
        'num_convolution_layer': 3,
        'weight_nn_hidden_neurons': [64, 64],
        'act_radial': 'silu',
        'act_scalar': {'e': 'silu', 'o': 'tanh'},
        'act_gate': {'e': 'silu', 'o': 'tanh'},
        'conv_denominator': 'avg_num_neigh',
        'train_denominator': False,
        'self_connection_type': 'nequip',
        'train_shift_scale': False,
        'irreps_manual': False,
        'lmax_edge': -1,
        'lmax_node': -1,
        'readout_as_fcn': False,
        'use_bias_in_linear': False,
        '_normalize_sph': True,
    }
    # Fill in chemical-species-derived keys (_type_map etc.) for Hf and O.
    config.update(**chemical_species_preprocess(known_elements))
    return config
def get_train_config():
config = {
'random_seed': 1,
'epoch': 2,
'loss': 'mse',
'loss_param': {},
'optimizer': 'adam',
'optim_param': {},
'scheduler': 'exponentiallr',
'scheduler_param': {'gamma': 0.99},
'force_loss_weight': 1.0,
'stress_loss_weight': 0.1,
'per_epoch': 1,
'continue': {
'checkpoint': False,
'reset_optimizer': False,
'reset_scheduler': False,
'reset_epoch': False,
},
'is_train_stress': True,
'train_shuffle': True,
'best_metric': 'TotalLoss',
'error_record': [
('Energy', 'RMSE'),
('Force', 'RMSE'),
('Stress', 'RMSE'),
('TotalLoss', 'None'),
],
'use_modality': False,
'use_weight': False,
'device': 'cpu',
'is_ddp': False,
}
return config
def get_data_config():
    """Return the dataset-related configuration used by these tests."""
    return {
        'batch_size': 2,
        'shift': 'per_atom_energy_mean',
        'scale': 'force_rms',
        'preprocess_num_cores': 1,
        'data_format_args': {},
        'load_trainset_path': hfo2_path,
    }
def get_config(overwrite=None):
    """Merge model, train and data configs; *overwrite* wins on key clashes."""
    merged = {**get_model_config(), **get_train_config(), **get_data_config()}
    if overwrite:
        merged.update(overwrite)
    return merged
def test_processing_continue_v2_7net0(tmp_path):
    """Continuing from the 7net-0 checkpoint must restore shift/scale,
    conv denominator, species info and epoch, honoring scheduler reset."""
    cp = torch.load(sevennet_0_path, weights_only=False, map_location='cpu')
    cfg = get_config(
        {
            'continue': {
                'checkpoint': sevennet_0_path,
                'reset_optimizer': False,
                'reset_scheduler': True,
                'reset_epoch': False,
            }
        }
    )
    shift_ref = cp['model_state_dict']['rescale_atomic_energy.shift'].cpu().numpy()
    # Reference values recorded from the 7net-0 checkpoint: 89 species,
    # scale 1.73, conv denominator 35.989574 for its 5 layers.
    scale_ref = np.array([1.73] * 89)
    conv_denominator_ref = np.array([35.989574] * 5)
    with Logger().switch_file(str(tmp_path / 'log.sevenn')):
        state_dicts, epoch = processing_continue_v2(cfg)
    assert epoch == 601  # checkpoint was written at epoch 600
    assert np.allclose(np.array(cfg['shift']), shift_ref)
    assert np.allclose(np.array(cfg['shift'])[0], -5.062768)
    assert np.allclose(np.array(cfg['scale']), scale_ref)
    assert np.allclose(np.array(cfg['conv_denominator']), conv_denominator_ref)
    assert cfg['_number_of_species'] == 89
    assert cfg['_type_map'][89] == 0  # Ac
    assert cfg['_type_map'][40] == 88  # Zr
    assert state_dicts[2] is None  # scheduler reset
@pytest.mark.parametrize(
    'cfg_overwrite,ds_names',
    [
        ({}, ['trainset']),
        ({'load_myset_path': hfo2_path}, ['trainset', 'myset']),
    ],
)
def test_dataset_from_config(cfg_overwrite, ds_names, tmp_path):
    """Each `load_{name}set_path` key must yield a processed dataset named
    '{name}set' plus its .pt/.yaml files under sevenn_data/."""
    cfg = get_config(cfg_overwrite)
    with Logger().switch_file(str(tmp_path / 'log.sevenn')):
        datasets = dataset_from_config(cfg, tmp_path)
    assert set(ds_names) == set(datasets.keys())
    for ds_name in ds_names:
        assert (tmp_path / 'sevenn_data' / f'{ds_name}.pt').is_file()
        assert (tmp_path / 'sevenn_data' / f'{ds_name}.yaml').is_file()
def test_dataset_from_config_as_it_is_load(graph_dataset_path, tmp_path):
    """Loading an already-processed dataset must use it as-is and not
    re-process it into the new working dir (no sevenn_data/ created)."""
    cfg = get_config({'load_trainset_path': graph_dataset_path})
    new_wd = tmp_path / 'tmp_wd'
    with Logger().switch_file(str(tmp_path / 'log.sevenn')):
        _ = dataset_from_config(cfg, str(new_wd))
    # Leftover debug print removed; assert directly on the path.
    assert not (new_wd / 'sevenn_data').is_dir()
@pytest.mark.parametrize(
    'cfg_overwrite,shift,scale,conv',
    [
        # Defaults: per_atom_energy_mean shift, force_rms scale,
        # avg_num_neigh conv denominator (values from the HfO2 fixture).
        (
            {},
            -28.978,
            0.113304,
            25.333333,
        ),
        # Explicit numeric shift is taken as-is.
        (
            {
                'shift': -1.2345678,
            },
            -1.234567,
            0.113304,
            25.333333,
        ),
        # sqrt_avg_num_neigh takes the square root of the neighbor average.
        (
            {
                'conv_denominator': 'sqrt_avg_num_neigh',
            },
            -28.978,
            0.113304,
            25.333333**0.5,
        ),
        # shift may reuse the force_rms statistic.
        (
            {
                'shift': 'force_rms',
            },
            0.113304,
            0.113304,
            25.333333,
        ),
        # elemwise_reference_energies yields a per-element (universal) list.
        (
            {
                'shift': 'elemwise_reference_energies',
            },
            [
                0.0
                if z not in _elemwise_ref_energy_dct
                else _elemwise_ref_energy_dct[z]
                for z in range(NUM_UNIV_ELEMENT)
            ],
            0.113304,
            25.333333,
        ),
    ],
)
def test_dataset_from_config_statistics_init(
    cfg_overwrite, shift, scale, conv, tmp_path
):
    """dataset_from_config must resolve shift/scale/conv_denominator
    strings into the expected numeric statistics."""
    cfg = get_config(cfg_overwrite)
    with Logger().switch_file(str(tmp_path / 'log.sevenn')):
        _ = dataset_from_config(cfg, tmp_path)
    assert np.allclose(cfg['shift'], shift)
    assert np.allclose(cfg['scale'], scale)
    assert np.allclose(cfg['conv_denominator'], conv)
def test_dataset_from_config_chem_auto(tmp_path):
    """'auto' chemical-species settings must be resolved from the dataset."""
    cfg = get_config(
        {
            'chemical_species': 'auto',
            '_number_of_species': 'auto',
            '_type_map': 'auto',
        }
    )
    with Logger().switch_file(str(tmp_path / 'log.sevenn')):
        _ = dataset_from_config(cfg, tmp_path)
    # The HfO2 fixture contains exactly Hf (Z=72) and O (Z=8).
    assert cfg['chemical_species'] == ['Hf', 'O']
    assert cfg['_number_of_species'] == 2
    assert cfg['_type_map'] == {72: 0, 8: 1}
def test_run_one_epoch(HfO2_loader):
    """One training epoch from the empty checkpoint must reproduce the
    recorded before/after error metrics (deterministic on CPU)."""
    trainer_args, _, _ = Trainer.args_from_checkpoint(cp_0_path)
    trainer = Trainer(**trainer_args)
    erc = get_error_recorder()
    # Reference metrics kept as strings; compared via float() below.
    ref1 = {
        'Energy_RMSE': '28.977758',
        'Force_RMSE': '0.214107',
        'Stress_RMSE': '190.014237',
    }
    ref2 = {
        'Energy_RMSE': '28.977878',
        'Force_RMSE': '0.213105',
        'Stress_RMSE': '188.772557',
    }
    # Evaluation pass before any training step.
    trainer.run_one_epoch(HfO2_loader, is_train=False, error_recorder=erc)
    ret1 = erc.get_dct()
    erc.epoch_forward()
    for k in ref1:
        assert np.allclose(float(ret1[k]), float(ref1[k]))
    # One training epoch, then re-evaluate.
    trainer.run_one_epoch(HfO2_loader, is_train=True, error_recorder=erc)
    erc.epoch_forward()
    trainer.run_one_epoch(HfO2_loader, is_train=False, error_recorder=erc)
    ret2 = erc.get_dct()
    erc.epoch_forward()
    for k in ref2:
        assert np.allclose(float(ret2[k]), float(ref2[k]))
def test_processing_epoch_v2(HfO2_loader, tmp_path):
    """processing_epoch_v2 must write per-epoch and best checkpoints plus
    a learning-curve CSV with the expected header and final values."""
    trainer_args, _, _ = Trainer.args_from_checkpoint(cp_0_path)
    trainer = Trainer(**trainer_args)
    erc = get_error_recorder()
    start_epoch = 10
    total_epoch = 12
    per_epoch = 1  # checkpoint frequency
    best_metric = 'Energy_RMSE'
    best_metric_loader_key = 'myset'
    # Same loader under two names; 'myset' drives best-model selection.
    loaders = {'trainset': HfO2_loader, 'myset': HfO2_loader}
    with Logger().switch_file(str(tmp_path / 'log.sevenn')):
        processing_epoch_v2(
            config={},
            trainer=trainer,
            loaders=loaders,
            start_epoch=start_epoch,
            error_recorder=erc,
            total_epoch=total_epoch,
            per_epoch=per_epoch,
            best_metric_loader_key=best_metric_loader_key,
            best_metric=best_metric,
            working_dir=tmp_path,
        )
    assert (tmp_path / 'checkpoint_10.pth').is_file()
    assert (tmp_path / 'checkpoint_11.pth').is_file()
    assert (tmp_path / 'checkpoint_12.pth').is_file()
    assert (tmp_path / 'checkpoint_best.pth').is_file()
    assert (tmp_path / 'lc.csv').is_file()
    with open(tmp_path / 'lc.csv', 'r') as f:
        lines = f.readlines()
    heads = [ll.strip() for ll in lines[0].split(',')]
    assert 'epoch' in heads
    assert 'lr' in heads
    assert 'trainset_Energy_RMSE' in heads
    assert 'myset_Stress_MAE' in heads
    # Recorded deterministic values for the last epoch.
    lasts = [ll.strip() for ll in lines[-1].split(',')]
    assert lasts[0] == '12'
    assert lasts[1] == '0.000980'  # lr
    assert lasts[-2] == '0.087873'  # myset Force MAE
def test_data_weight(graph_dataset_path, tmp_path):
    """Per-dataset data_weight factors must scale the recorded losses
    exactly (energy x0.1, force x3.0, stress x1.0)."""
    cfg = get_config(
        {
            'load_trainset_path': [{
                'file_list': [{'file': graph_dataset_path}],
                'data_weight': {'energy': 0.1, 'force': 3.0, 'stress': 1.0},
            }],
            'error_record': [
                ('Energy', 'Loss'),
                ('Force', 'Loss'),
                ('Stress', 'Loss'),
                ('TotalLoss', 'None'),
            ],
            'use_weight': True
        }
    )
    trainer_args, _, _ = Trainer.args_from_checkpoint(cp_0_path)
    trainer_args['loss_functions'] = get_loss_functions_from_config(cfg)
    trainer = Trainer(**trainer_args)
    erc = ErrorRecorder.from_config(cfg, trainer.loss_functions)
    db = graph_ds.from_config(cfg, working_dir=tmp_path)['trainset']
    # One batch containing the entire dataset.
    loader_w_weight = DataLoader(db, batch_size=len(db))
    trainer.run_one_epoch(loader_w_weight, False, erc)
    loss = erc.epoch_forward()
    # Unweighted reference losses times the configured weights.
    assert np.allclose(loss['Energy_Loss'], 839.7104492 * 0.1)
    assert np.allclose(loss['Force_Loss'], 0.0152806 * 3.0)
    assert np.allclose(loss['Stress_Loss'], 6017.568847 * 1.0)
def _write_empty_checkpoint():
    """Regenerate cp_0.pth: an untrained checkpoint used as a test fixture."""
    from sevenn.model_build import build_E3_equivariant_model
    # Function I used to make empty checkpoint, to write the test
    cfg = get_config({'shift': 0.0, 'scale': 1.0, 'conv_denominator': 5.0})
    model = build_E3_equivariant_model(cfg)
    trainer = Trainer.from_config(model, cfg)  # type: ignore
    trainer.write_checkpoint('./cp_0.pth', config=cfg, epoch=0)
if __name__ == '__main__':
    _write_empty_checkpoint()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment