# test_benchmark.py
import torch
import torchani
import unittest
import copy


class TestBenchmark(unittest.TestCase):
    """Check the per-stage timers that torchani modules keep in benchmark mode."""

    # Relations understood in the ``asserts`` strings of ``_testModule``.
    # The two-character operators come first so that ``'a>=b'`` is matched
    # as ``>=`` instead of being split on ``>`` (or ``=``).
    _OPERATORS = ('>=', '<=', '>', '<', '=')

    def setUp(self):
        self.conformations = 100
        # 100 conformations of 8 atoms each: species indices drawn from
        # range(4) and random 3-D coordinates.
        self.species = torch.randint(4, (self.conformations, 8),
                                     dtype=torch.long)
        self.coordinates = torch.randn(self.conformations, 8, 3)
        # Number of forward passes used to check that the timers accumulate.
        self.count = 100

    @classmethod
    def _parse_assert(cls, expr):
        """Split an ``asserts`` entry into ``(left, operator, right)``.

        ``'total>combinations'`` gives ``('total', '>', 'combinations')``.
        An entry without an operator, such as ``'forward'``, only names a
        timer and gives ``('forward', None, None)``. Whitespace around the
        timer names is stripped.
        """
        for op in cls._OPERATORS:
            if op in expr:
                left, _, right = expr.partition(op)
                return left.strip(), op, right.strip()
        return expr.strip(), None, None

    def _assertTimersZeroed(self, result_module, keys):
        """Assert ``result_module.timers`` has exactly ``keys``, all zero."""
        self.assertEqual(set(result_module.timers.keys()), set(keys))
        for key in keys:
            self.assertEqual(result_module.timers[key], 0, msg=key)

    def _testModule(self, run_module, result_module, asserts):
        """Run ``run_module`` repeatedly and validate ``result_module.timers``.

        Arguments:
            run_module: callable invoked ``self.count`` times with the
                tuple ``(self.species, self.coordinates)``.
            result_module: object exposing a ``timers`` dict and a
                ``reset_timers()`` method. It may be ``run_module`` itself
                or a module wrapped inside it.
            asserts: strings that either name a timer (``'forward'``) or
                relate two timers with one of ``>=``, ``<=``, ``>``, ``<``,
                ``=`` (e.g. ``'total>assemble'``).

        Checks performed:
            * The timers start at zero with exactly the named keys.
            * Every timer strictly increases on each run.
            * Every relation holds after each run.
            * ``reset_timers()`` zeroes the timers again.
        """
        relation_checks = {
            '>=': self.assertGreaterEqual,
            '<=': self.assertLessEqual,
            '>': self.assertGreater,
            '<': self.assertLess,
            '=': self.assertEqual,
        }

        keys = []
        relations = []
        for expr in asserts:
            left, op, right = self._parse_assert(expr)
            keys.append(left)
            if op is not None:
                keys.append(right)
                relations.append((expr, left, op, right))
        # The same timer is typically named by several relations (e.g.
        # 'total'), so check each one once, in first-mention order.
        keys = list(dict.fromkeys(keys))

        self._assertTimersZeroed(result_module, keys)

        old_timers = copy.copy(result_module.timers)
        for _ in range(self.count):
            run_module((self.species, self.coordinates))
            timers = result_module.timers
            for key in keys:
                # Every forward pass must add time to every timer.
                self.assertLess(old_timers[key], timers[key], msg=key)
            for expr, left, op, right in relations:
                relation_checks[op](timers[left], timers[right], msg=expr)
            old_timers = copy.copy(timers)

        result_module.reset_timers()
        self._assertTimersZeroed(result_module, keys)

    def testAEV(self):
        """Benchmark the internal stage timers of an ``AEVComputer``."""
        aev_computer = torchani.AEVComputer(benchmark=True)
        self._testModule(aev_computer, aev_computer, [
                         'terms and indices>radial terms',
                         'terms and indices>angular terms',
                         'total>terms and indices',
                         'total>combinations', 'total>assemble',
                         'total>mask_r', 'total>mask_a'
                         ])

    def testANIModel(self):
        """Benchmark the model's ``forward`` timer when fed by an AEVComputer."""
        aev_computer = torchani.AEVComputer()
        model = torchani.models.NeuroChemNNP(aev_computer.species,
                                             benchmark=True)
        run_module = torch.nn.Sequential(aev_computer, model)
        self._testModule(run_module, model, ['forward'])


if __name__ == "__main__":
    # Only when this file is executed directly (not when imported or
    # collected by another runner): hand control to unittest's CLI runner.
    unittest.main()