test_jit_builtin_models.py 1.42 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
import torchani
import unittest
import os


# Absolute directory of this test module (symlinks resolved); the sample
# dataset path below is built relative to it.
path = os.path.split(os.path.realpath(__file__))[0]
dspath = os.path.join(path, '../dataset/ani-1x/sample.h5')

# Settings forwarded to torchani.data.CachedDataset by the test fixture.
batch_size = 256
chunk_threshold = 5

# Description of the extra per-batch property ('energies') requested from the
# dataset; each list holds one entry per property.
other_properties = dict(
    properties=['energies'],
    padding_values=[None],
    padded_shapes=[(batch_size, )],
    dtypes=[torch.float64],
)


class TestBuiltinModelsJIT(unittest.TestCase):
    """Check that builtin torchani models behave the same once scripted.

    Every test runs a model both eagerly and through ``torch.jit.script`` on
    the same input chunk and requires the second element of the two outputs
    to be numerically close.
    """

    def setUp(self):
        """Load the sample dataset that provides the model input."""
        self.ds = torchani.data.CachedDataset(dspath, batch_size=batch_size, device='cpu',
                                              chunk_threshold=chunk_threshold,
                                              other_properties=other_properties,
                                              subtract_self_energies=True)
        # Models are built inside the individual tests. Building one here
        # would load it again before every test without anything using it.

    def _test_model(self, model):
        """Assert that ``model`` and its scripted form agree on one chunk.

        Args:
            model: A callable module accepted by ``torch.jit.script`` that
                returns a 2-tuple; only the second element is compared.
        """
        # NOTE(review): presumably the model input of the first chunk of the
        # first batch -- confirm against CachedDataset's indexing layout.
        chunk = self.ds[0][0][0]
        _, eager_out = model(chunk)
        _, scripted_out = torch.jit.script(model)(chunk)
        self.assertTrue(
            torch.allclose(eager_out, scripted_out),
            msg='eager and scripted outputs differ for ' + type(model).__name__,
        )

    def _test_ensemble(self, ensemble):
        """Check the ensemble as a whole, then each member it iterates over."""
        self._test_model(ensemble)
        for member in ensemble:
            self._test_model(member)

    def testANI1x(self):
        """The ANI1x ensemble and its members must survive scripting."""
        self._test_ensemble(torchani.models.ANI1x())

    def testANI1ccx(self):
        """The ANI1ccx ensemble and its members must survive scripting."""
        self._test_ensemble(torchani.models.ANI1ccx())


# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()