"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "1ebda734e2a9edaccd89095c4bfdacc20d693a3c"
Unverified Commit 3132928c authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Add unit test for builtin models (#327)

* Add unit test for builtin models

* flake8
parent 273c9fd6
...@@ -10,7 +10,11 @@ jobs: ...@@ -10,7 +10,11 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
python-version: [3.6, 3.7] python-version: [3.6, 3.7]
test-filenames: [test_aev.py, test_aev_benzene_md.py, test_aev_nist.py, test_aev_tripeptide_md.py, test_data_new.py, test_ignite.py, test_utils.py, test_ase.py, test_energies.py, test_neurochem.py, test_vibrational.py, test_ensemble.py, test_padding.py, test_data.py, test_forces.py, test_structure_optim.py] test-filenames: [
test_aev.py, test_aev_benzene_md.py, test_aev_nist.py, test_aev_tripeptide_md.py,
test_data_new.py, test_ignite.py, test_utils.py, test_ase.py, test_energies.py,
test_neurochem.py, test_vibrational.py, test_ensemble.py, test_padding.py,
test_data.py, test_forces.py, test_structure_optim.py, test_jit_builtin_models.py]
steps: steps:
- uses: actions/checkout@v1 - uses: actions/checkout@v1
......
import torch
import torchani
import unittest
import os
path = os.path.dirname(os.path.realpath(__file__))
dspath = os.path.join(path, '../dataset/ani-1x/sample.h5')
batch_size = 256
chunk_threshold = 5
other_properties = {'properties': ['energies'],
'padding_values': [None],
'padded_shapes': [(batch_size, )],
'dtypes': [torch.float64],
}
class TestBuiltinModelsJIT(unittest.TestCase):
def setUp(self):
self.ds = torchani.data.CachedDataset(dspath, batch_size=batch_size, device='cpu',
chunk_threshold=chunk_threshold,
other_properties=other_properties,
subtract_self_energies=True)
self.ani1ccx = torchani.models.ANI1ccx()
def _test_model(self, model):
chunk = self.ds[0][0][0]
_, e = model(chunk)
_, e2 = torch.jit.script(model)(chunk)
self.assertTrue(torch.allclose(e, e2))
def _test_ensemble(self, ensemble):
self._test_model(ensemble)
for m in ensemble:
self._test_model(m)
def testANI1x(self):
ani1x = torchani.models.ANI1x()
self._test_ensemble(ani1x)
def testANI1ccx(self):
ani1ccx = torchani.models.ANI1ccx()
self._test_ensemble(ani1ccx)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment