# test_jit_builtin_models.py: checks that JIT-compiled builtin models match eager models.
import torch
import torchani
import unittest
import os
from torchani.testing import TestCase


# Absolute directory of this test file (symlinks resolved), so that the
# dataset location below does not depend on the current working directory.
path = os.path.split(os.path.realpath(__file__))[0]
# Sample ANI-1x dataset file, located relative to this test file.
dspath = os.path.join(path, '../dataset/ani-1x/sample.h5')


class TestBuiltinModelsJIT(TestCase):
    """Tests that JIT-compiled builtin models match their eager counterparts.

    For each builtin ensemble, and for every member model of that ensemble,
    the energies produced by ``torch.jit.script(model)`` must equal the
    energies produced by the original eager (non-JIT) model on the same
    input batch.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Load the sample dataset once for the whole class rather than once
        # per test method; the cached dataset is only read, never mutated.
        # In general self energies should be subtracted, and shuffle should
        # be performed, but for these tests this is not important.
        cls.ds = torchani.data.load(dspath).species_to_indices().collate(256).cache()

    def _test_model(self, model):
        """Assert that the scripted and eager versions of ``model`` agree.

        Only the first batch of the dataset is used as input.
        """
        properties = next(iter(self.ds))
        input_ = (properties['species'], properties['coordinates'].float())
        # Both calls return a pair; only the second element (the energies)
        # is compared.
        _, e = model(input_)
        _, e2 = torch.jit.script(model)(input_)
        self.assertEqual(e, e2)

    def _test_ensemble(self, ensemble):
        """Check the ensemble as a whole, then each member model on its own."""
        self._test_model(ensemble)
        for m in ensemble:
            self._test_model(m)

    def testANI1x(self):
        ani1x = torchani.models.ANI1x()
        self._test_ensemble(ani1x)

    def testANI1ccx(self):
        ani1ccx = torchani.models.ANI1ccx()
        self._test_ensemble(ani1ccx)


if __name__ == "__main__":
    # Executing this file directly runs the unittest test cases it defines.
    unittest.main()