test_benchmark.py 4.25 KB
Newer Older
Xiang Gao's avatar
Xiang Gao committed
1
2
3
4
5
6
7
8
import torch
import torchani
import unittest
import copy


class TestBenchmark(unittest.TestCase):
    """Check the timers exposed by torchani modules built with benchmark=True.

    Each test runs a module repeatedly on random coordinates and verifies
    that the ``timers`` dict of the benchmarked module starts at zero, grows
    on every run, respects the expected ordering between timers, and goes
    back to zero after ``reset_timers()``.
    """

    # Comparison symbols accepted in assertion strings, paired with the name
    # of the unittest method that enforces them.  The order is significant:
    # two-character symbols are tried first so that 'a>=b' is never split on
    # the bare '>' it contains.
    _COMPARISONS = (
        ('>=', 'assertGreaterEqual'),
        ('<=', 'assertLessEqual'),
        ('>', 'assertGreater'),
        ('<', 'assertLess'),
        ('=', 'assertEqual'),
    )

    def setUp(self, dtype=torchani.default_dtype,
              device=torchani.default_device):
        """Create the random input shared by all tests.

        Args:
            dtype: dtype of the random coordinates; also stored on the test
                case and forwarded to the modules under test.
            device: device of the random coordinates; also stored on the
                test case and forwarded to the modules under test.
        """
        self.dtype = dtype
        self.device = device
        self.conformations = 100
        self.species = list('HHCCNNOO')
        # One row of 8 xyz positions per conformation, matching the 8 species.
        self.coordinates = torch.randn(
            self.conformations, 8, 3, dtype=dtype, device=device)
        # Number of times each module is run by _testModule.
        self.count = 100

    @classmethod
    def _parse_assert(cls, spec):
        """Split one assertion string into ``(check, key0, key1)``.

        ``spec`` is either a bare timer name, or two timer names joined by
        one of the symbols in ``_COMPARISONS`` (e.g. ``'total>assemble'``).

        Returns:
            A tuple ``(check, key0, key1)`` where ``check`` is the name of
            the unittest assertion method to call with the two timer values.
            For a bare timer name, ``check`` and ``key1`` are ``None`` and
            ``key0`` is the stripped name.
        """
        for symbol, check in cls._COMPARISONS:
            if symbol in spec:
                # Only the first two fields are used, even if the symbol
                # occurs more than once in the string.
                fields = spec.split(symbol)
                return check, fields[0].strip(), fields[1].strip()
        return None, spec.strip(), None

    def _assertTimersReset(self, result_module, keys):
        """Assert that the timers are exactly ``keys`` and are all zero."""
        self.assertEqual(set(result_module.timers.keys()), set(keys))
        for key in keys:
            self.assertEqual(result_module.timers[key], 0)

    def _testModule(self, run_module, result_module, asserts):
        """Run ``run_module`` repeatedly and validate ``result_module.timers``.

        Args:
            run_module: callable invoked with ``(species, coordinates)`` on
                every iteration.
            result_module: object exposing a ``timers`` mapping and a
                ``reset_timers()`` method; its timers are the ones checked.
            asserts: list of assertion strings understood by
                ``_parse_assert``.  Every timer name mentioned must exist in
                ``result_module.timers``, and no other timer may exist.
        """
        # Parse every assertion string once instead of on every iteration.
        parsed = [self._parse_assert(spec) for spec in asserts]
        # Unique timer names in order of first appearance.
        keys = list(dict.fromkeys(
            key
            for _, key0, key1 in parsed
            for key in (key0, key1)
            if key is not None))

        self._assertTimersReset(result_module, keys)
        previous = copy.copy(result_module.timers)
        for _ in range(self.count):
            run_module((self.species, self.coordinates))
            # Every timer must have strictly increased since the last run.
            for key in keys:
                self.assertLess(previous[key], result_module.timers[key])
            # The requested relations between timers must hold after each run.
            for check, key0, key1 in parsed:
                if check is None:
                    continue
                getattr(self, check)(
                    result_module.timers[key0], result_module.timers[key1])
            previous = copy.copy(result_module.timers)

        result_module.reset_timers()
        self._assertTimersReset(result_module, keys)

    def testAEV(self):
        """Check the timers of a SortedAEV built with benchmark=True."""
        aev_computer = torchani.SortedAEV(
            benchmark=True, dtype=self.dtype, device=self.device)
        prepare = torchani.PrepareInput(aev_computer.species, self.device)
        run_module = torch.nn.Sequential(prepare, aev_computer)
        self._testModule(run_module, aev_computer, [
                         'terms and indices>radial terms',
                         'terms and indices>angular terms',
                         'total>terms and indices',
                         'total>combinations', 'total>assemble',
                         'total>mask_r', 'total>mask_a'
                         ])

    def testANIModel(self):
        """Check the 'forward' timer of a NeuroChemNNP built with benchmark=True."""
        # The AEV computer is not benchmarked here; only the model's timers
        # are inspected.
        aev_computer = torchani.SortedAEV(
            dtype=self.dtype, device=self.device)
        prepare = torchani.PrepareInput(aev_computer.species, self.device)
        model = torchani.models.NeuroChemNNP(
            aev_computer.species, benchmark=True).to(self.device)
        run_module = torch.nn.Sequential(prepare, aev_computer, model)
        self._testModule(run_module, model, ['forward'])
Xiang Gao's avatar
Xiang Gao committed
105
106
107
108


# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()