# test_benchmark.py (committed by Xiang Gao)
import torch
import torchani
import unittest
import copy


class TestBenchmark(unittest.TestCase):
    """Tests for the ``benchmark=True`` timers of torchani modules.

    Each test builds a module with benchmarking enabled and hands it to
    ``_testModule`` together with a list of timer names and ordering
    constraints that must hold after every forward call.
    """

    # Comparison operators accepted in ``_testModule`` assertion strings,
    # paired with the TestCase method that checks them. Order matters: the
    # two-character operators must be tried before '>', '<' and '=' so that
    # 'a>=b' is not parsed as 'a' > '=b' or as 'a>' = 'b'.
    _OPERATORS = (
        ('>=', 'assertGreaterEqual'),
        ('<=', 'assertLessEqual'),
        ('>', 'assertGreater'),
        ('<', 'assertLess'),
        ('=', 'assertEqual'),
    )

    def setUp(self, dtype=torchani.default_dtype,
              device=torchani.default_device):
        """Create the random input shared by every test.

        unittest calls ``setUp()`` without arguments, so in practice the
        torchani default dtype and device are used.
        """
        self.dtype = dtype
        self.device = device
        self.conformations = 100
        self.species = list('HHCCNNOO')
        # One random 3-vector per species entry, for each conformation.
        self.coordinates = torch.randn(
            self.conformations, len(self.species), 3,
            dtype=dtype, device=device)
        # Number of forward calls made by _testModule.
        self.count = 100

    @classmethod
    def _parse_assert(cls, expression):
        """Parse one ``_testModule`` assertion string.

        Returns ``(method_name, key0, key1)``. ``method_name`` is the name of
        the TestCase assertion to apply to ``(timers[key0], timers[key1])``.
        A string containing no operator is a bare timer name and yields
        ``(None, name, None)``.

        Raises:
            ValueError: if the matched operator does not split the string
                into exactly two timer names.
        """
        for symbol, method in cls._OPERATORS:
            if symbol in expression:
                parts = expression.split(symbol)
                if len(parts) != 2:
                    raise ValueError('expected exactly one %r in %r'
                                     % (symbol, expression))
                return method, parts[0].strip(), parts[1].strip()
        return None, expression.strip(), None

    def _testModule(self, module, asserts):
        """Check the benchmark timers of ``module`` against ``asserts``.

        ``asserts`` is a list of strings. Each string is either a bare timer
        name or a comparison ``'<key0> OP <key1>'``, where OP is one of
        ``>=``, ``<=``, ``>``, ``<`` or ``=``.

        The following are verified:
          * ``module.timers`` holds exactly the mentioned names, all at 0;
          * every timer strictly increases on each of ``self.count`` calls
            of ``module(self.coordinates, self.species)``;
          * every comparison holds after each of those calls;
          * after ``module.reset_timers()`` all timers are back to 0.
        """
        parsed = [self._parse_assert(expression) for expression in asserts]

        # Every timer name mentioned, deduplicated in first-seen order.
        keys = []
        for _, key0, key1 in parsed:
            for key in (key0, key1):
                if key is not None and key not in keys:
                    keys.append(key)

        self.assertEqual(set(module.timers.keys()), set(keys))
        for key in keys:
            self.assertEqual(module.timers[key], 0)

        old_timers = copy.copy(module.timers)
        for _ in range(self.count):
            module(self.coordinates, self.species)
            for key in keys:
                # Each timer must have accumulated something on this call.
                self.assertLess(old_timers[key], module.timers[key])
            for method, key0, key1 in parsed:
                if method is None:
                    continue
                check = getattr(self, method)
                check(module.timers[key0], module.timers[key1])
            old_timers = copy.copy(module.timers)

        module.reset_timers()
        self.assertEqual(set(module.timers.keys()), set(keys))
        for key in keys:
            self.assertEqual(module.timers[key], 0)

    def testAEV(self):
        """SortedAEV timers: 'total' bounds each stage it contains."""
        aev_computer = torchani.SortedAEV(
            benchmark=True, dtype=self.dtype, device=self.device)
        self._testModule(aev_computer, [
                         'terms and indices>radial terms',
                         'terms and indices>angular terms',
                         'total>terms and indices',
                         'total>combinations', 'total>assemble',
                         'total>mask_r', 'total>mask_a'
                         ])

    def testModelOnAEV(self):
        """ModelOnAEV timers: 'forward' bounds 'aev', 'nn' (and 'derivative')."""
        aev_computer = torchani.SortedAEV(
            dtype=self.dtype, device=self.device)
        model = torchani.ModelOnAEV(
            aev_computer, benchmark=True, from_nc=None)
        self._testModule(model, ['forward>aev', 'forward>nn'])
        model = torchani.ModelOnAEV(
            aev_computer, benchmark=True, derivative=True, from_nc=None)
        self._testModule(
            model, ['forward>aev', 'forward>nn', 'forward>derivative'])


if __name__ == "__main__":
    # Run this file's test cases when it is executed directly as a script.
    unittest.main()