Unverified Commit f6ef4ebb authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

add unit test for model ensemble (#14)

parent ba3036d1
import unittest
import pickle
import os
import torch
import torchani
path = os.path.dirname(os.path.realpath(__file__))
N = 97
class TestEnsemble(unittest.TestCase):
def setUp(self):
self.tol = 1e-5
self.conformations = 20
def _test_molecule(self, coordinates, species):
prefix = torchani.buildin_model_prefix
n = torchani.buildin_ensembles
aev = torchani.SortedAEV(device=torch.device('cpu'))
coordinates, species = aev.sort_by_species(coordinates, species)
ensemble = torchani.ModelOnAEV(aev, derivative=True,
from_nc=prefix,
ensemble=n)
models = [torchani.ModelOnAEV(aev, derivative=True,
from_nc=prefix + '{}/networks/'.format(i)) for i in range(n)]
energy1, force1 = ensemble(coordinates, species)
energy2, force2 = zip(*[m(coordinates, species) for m in models])
energy2 = sum(energy2) / n
force2 = sum(force2) / n
energy_diff = (energy1 - energy2).abs().max().item()
force_diff = (force1 - force2).abs().max().item()
self.assertLess(energy_diff, self.tol)
self.assertLess(force_diff, self.tol)
def testGDB(self):
for i in range(N):
datafile = os.path.join(path, 'test_data/{}'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, _, _, _, _ = pickle.load(f)
self._test_molecule(coordinates, species)
if __name__ == '__main__':
unittest.main()
...@@ -2,9 +2,10 @@ from .energyshifter import EnergyShifter ...@@ -2,9 +2,10 @@ from .energyshifter import EnergyShifter
from .nn import ModelOnAEV, PerSpeciesFromNeuroChem from .nn import ModelOnAEV, PerSpeciesFromNeuroChem
from .aev import SortedAEV from .aev import SortedAEV
from .env import buildin_const_file, buildin_sae_file, buildin_network_dir, \ from .env import buildin_const_file, buildin_sae_file, buildin_network_dir, \
buildin_model_prefix, default_dtype, default_device buildin_model_prefix, buildin_ensembles, default_dtype, default_device
__all__ = ['SortedAEV', 'EnergyShifter', 'ModelOnAEV', __all__ = ['SortedAEV', 'EnergyShifter', 'ModelOnAEV',
'PerSpeciesFromNeuroChem', 'data', 'buildin_const_file', 'PerSpeciesFromNeuroChem', 'data', 'buildin_const_file',
'buildin_sae_file', 'buildin_network_dir', 'buildin_dataset_dir', 'buildin_sae_file', 'buildin_network_dir', 'buildin_dataset_dir',
'buildin_model_prefix', 'default_dtype', 'default_device'] 'buildin_model_prefix', 'buildin_ensembles', 'default_dtype',
'default_device']
...@@ -14,5 +14,7 @@ buildin_network_dir = pkg_resources.resource_filename( ...@@ -14,5 +14,7 @@ buildin_network_dir = pkg_resources.resource_filename(
buildin_model_prefix = pkg_resources.resource_filename( buildin_model_prefix = pkg_resources.resource_filename(
__name__, 'resources/ani-1x_dft_x8ens/train') __name__, 'resources/ani-1x_dft_x8ens/train')
buildin_ensembles = 8
default_dtype = torch.float32 default_dtype = torch.float32
default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment