# test_cuaev.py
1
2
import os
import torch
3
4
import torchani
import unittest
5
import pickle
# Gao, Xiang committed
6
from torchani.testing import TestCase, make_tensor
7

8
9
10

# Directory holding this test file; fixture paths below are joined onto it.
path = os.path.dirname(os.path.realpath(__file__))

# Class decorators: run a TestCase only when its runtime requirements exist.
skipIfNoGPU = unittest.skipUnless(torch.cuda.is_available(),
                                  'There is no device to run this test')
skipIfNoCUAEV = unittest.skipUnless(torchani.aev.has_cuaev, "only valid when cuaev is installed")


@skipIfNoCUAEV
class TestCUAEVNoGPU(TestCase):
    """cuaev checks that need the extension to be installed but no GPU."""

    def testSimple(self):
        """Scripting a call to the raw op records ``cuaev::cuComputeAEV`` in the graph."""

        def f(coordinates, species, Rcr: float, Rca: float, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species: int):
            return torch.ops.cuaev.cuComputeAEV(coordinates, species, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)

        s = torch.jit.script(f)
        self.assertIn("cuaev::cuComputeAEV", str(s.graph))

    def testAEVComputer(self):
        """A scripted ``AEVComputer(use_cuda_extension=True)`` dispatches to the cuaev op."""
        # Reuse the module-level ``path`` rather than recomputing (and shadowing) it.
        const_file = os.path.join(path, '../torchani/resources/ani-1x_8x/rHCNO-5.2R_16-3.5A_a4-8.params')  # noqa: E501
        consts = torchani.neurochem.Constants(const_file)
        aev_computer = torchani.AEVComputer(**consts, use_cuda_extension=True)
        s = torch.jit.script(aev_computer)
        # Computing AEVs with cuaev for an empty (8, 0) batch with no atoms does not
        # require CUDA, so this can run without a GPU.
        species = make_tensor((8, 0), 'cpu', torch.int64, low=-1, high=4)
        coordinates = make_tensor((8, 0, 3), 'cpu', torch.float32, low=-5, high=5)
        self.assertIn("cuaev::cuComputeAEV", str(s.graph_for((species, coordinates))))


@skipIfNoGPU
@skipIfNoCUAEV
class TestCUAEV(TestCase):
    """Compare the cuaev CUDA extension against the reference ``AEVComputer``.

    Every test feeds identical inputs to ``self.aev_computer`` (reference) and
    ``self.cuaev_computer`` (``use_cuda_extension=True``) and asserts the AEVs
    agree. The ``*Backward`` variants also compare the gradients with respect to
    the coordinates obtained by back-propagating ``ones_like(aev)``.
    """

    # Number of tripeptide-md frames (0.dat, 1.dat, ...) each tripeptide test iterates over.
    num_tripeptide_frames = 100

    def setUp(self):
        # atol/rtol for comparisons where reference and CUDA results are not
        # expected to match at assertEqual's default precision.
        self.tolerance = 5e-5
        # Looser gradient tolerance, used only for the 10x-compressed ("very dense") geometries.
        self.dense_grad_tolerance = 5e-4
        self.device = 'cuda'
        # These appear to mirror rHCNO-5.2R_16-3.5A_a4-8.params (Rcr=5.2, 16 radial
        # shifts, Rca=3.5, 4 angular shifts, 8 zeta shifts) -- verify if that file changes.
        Rcr = 5.2000e+00
        Rca = 3.5000e+00
        EtaR = torch.tensor([1.6000000e+01], device=self.device)
        ShfR = torch.tensor([9.0000000e-01, 1.1687500e+00, 1.4375000e+00, 1.7062500e+00, 1.9750000e+00, 2.2437500e+00, 2.5125000e+00, 2.7812500e+00, 3.0500000e+00, 3.3187500e+00, 3.5875000e+00, 3.8562500e+00, 4.1250000e+00, 4.3937500e+00, 4.6625000e+00, 4.9312500e+00], device=self.device)
        Zeta = torch.tensor([3.2000000e+01], device=self.device)
        ShfZ = torch.tensor([1.9634954e-01, 5.8904862e-01, 9.8174770e-01, 1.3744468e+00, 1.7671459e+00, 2.1598449e+00, 2.5525440e+00, 2.9452431e+00], device=self.device)
        EtaA = torch.tensor([8.0000000e+00], device=self.device)
        ShfA = torch.tensor([9.0000000e-01, 1.5500000e+00, 2.2000000e+00, 2.8500000e+00], device=self.device)
        num_species = 4
        self.aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)
        self.cuaev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species, use_cuda_extension=True)

    # ----------------------------------------------------------------- helpers

    def _simple_inputs(self):
        """Return the hand-written two-molecule ``(species, coordinates)`` fixture.

        The second molecule's last atom has species index -1 (presumably padding).
        """
        coordinates = torch.tensor([
            [[0.03192167, 0.00638559, 0.01301679],
             [-0.83140486, 0.39370209, -0.26395324],
             [-0.66518241, -0.84461308, 0.20759389],
             [0.45554739, 0.54289633, 0.81170881],
             [0.66091919, -0.16799635, -0.91037834]],
            [[-4.1862600, 0.0575700, -0.0381200],
             [-3.1689400, 0.0523700, 0.0200000],
             [-4.4978600, 0.8211300, 0.5604100],
             [-4.4978700, -0.8000100, 0.4155600],
             [0.00000000, -0.00000000, -0.00000000]]
        ], device=self.device)
        species = torch.tensor([[1, 0, 0, 0, 0], [2, 0, 0, 0, -1]], device=self.device)
        return species, coordinates

    def _load_tripeptide(self, index, scale=1.0):
        """Load tripeptide-md frame ``index`` as a batch of one molecule.

        ``scale`` multiplies the coordinates; 0.1 packs the atoms 10x closer.
        """
        datafile = os.path.join(path, 'test_data/tripeptide-md/{}.dat'.format(index))
        # Trusted, repository-local fixture; do not reuse this for external pickles.
        with open(datafile, 'rb') as f:
            coordinates, species, *_ = pickle.load(f)
        coordinates = scale * torch.from_numpy(coordinates).float().unsqueeze(0).to(self.device)
        species = torch.from_numpy(species).unsqueeze(0).to(self.device)
        return species, coordinates

    def _nist_inputs(self):
        """Yield ``(species, coordinates)`` for every entry of the NIST fixture."""
        datafile = os.path.join(path, 'test_data/NIST/all')
        # Trusted, repository-local fixture; loaded fully before the file is closed.
        with open(datafile, 'rb') as f:
            data = pickle.load(f)
        for coordinates, species, _, _, _, _ in data:
            yield (torch.from_numpy(species).to(self.device),
                   torch.from_numpy(coordinates).to(torch.float).to(self.device))

    @staticmethod
    def _tol(value):
        """Map one tolerance to assertEqual kwargs; ``None`` keeps the defaults."""
        return {} if value is None else {'atol': value, 'rtol': value}

    @staticmethod
    def _aev_and_grad(computer, species, coordinates):
        """Run ``computer`` on a fresh leaf copy of ``coordinates`` and backprop ones.

        Returns ``(aev, grad)``, where ``grad`` is the gradient with respect to that copy.
        The copy leaves the caller's tensor untouched, so it can be reused for the
        second computer.
        """
        coordinates = coordinates.clone().detach().requires_grad_(True)
        _, aev = computer((species, coordinates))
        aev.backward(torch.ones_like(aev))
        return aev, coordinates.grad

    def _assert_forward_match(self, species, coordinates, tol=None):
        """Assert the reference and cuaev computers produce equal AEVs."""
        _, aev = self.aev_computer((species, coordinates))
        _, cu_aev = self.cuaev_computer((species, coordinates))
        self.assertEqual(cu_aev, aev, **self._tol(tol))

    def _assert_backward_match(self, species, coordinates, aev_tol=None, grad_tol=None):
        """Assert the reference and cuaev computers agree on AEVs and coordinate gradients."""
        aev, aev_grad = self._aev_and_grad(self.aev_computer, species, coordinates)
        cu_aev, cuaev_grad = self._aev_and_grad(self.cuaev_computer, species, coordinates)
        self.assertEqual(cu_aev, aev, **self._tol(aev_tol))
        self.assertEqual(cuaev_grad, aev_grad, **self._tol(grad_tol))

    # ------------------------------------------------------------------- tests

    def testSimple(self):
        species, coordinates = self._simple_inputs()
        self._assert_forward_match(species, coordinates)

    def testSimpleBackward(self):
        species, coordinates = self._simple_inputs()
        aev, aev_grad = self._aev_and_grad(self.aev_computer, species, coordinates)
        cu_aev, cuaev_grad = self._aev_and_grad(self.cuaev_computer, species, coordinates)
        self.assertEqual(cu_aev, aev, f'cu_aev: {cu_aev}\n aev: {aev}')
        self.assertEqual(cuaev_grad, aev_grad, f'\ncuaev_grad: {cuaev_grad}\n aev_grad: {aev_grad}')

    def testTripeptideMD(self):
        for i in range(self.num_tripeptide_frames):
            species, coordinates = self._load_tripeptide(i)
            self._assert_forward_match(species, coordinates)

    def testTripeptideMDBackward(self):
        for i in range(self.num_tripeptide_frames):
            species, coordinates = self._load_tripeptide(i)
            self._assert_backward_match(species, coordinates, grad_tol=self.tolerance)

    def testNIST(self):
        for species, coordinates in self._nist_inputs():
            self._assert_forward_match(species, coordinates)

    def testNISTBackward(self):
        for species, coordinates in self._nist_inputs():
            self._assert_backward_match(species, coordinates, grad_tol=self.tolerance)

    def testVeryDenseMolecule(self):
        """
        Test AEV correctness on very dense molecules. This mainly targets the angular
        kernel when a center atom has more than 32 atom pairs.
        See: https://github.com/aiqm/torchani/pull/555
        """
        for i in range(self.num_tripeptide_frames):
            # Shrink the Angstrom coordinates 10x to pack the atoms closely.
            species, coordinates = self._load_tripeptide(i, scale=0.1)
            self._assert_forward_match(species, coordinates, tol=self.tolerance)

    def testVeryDenseMoleculeBackward(self):
        for i in range(self.num_tripeptide_frames):
            # Shrink the Angstrom coordinates 10x to pack the atoms closely.
            species, coordinates = self._load_tripeptide(i, scale=0.1)
            self._assert_backward_match(species, coordinates,
                                        aev_tol=self.tolerance,
                                        grad_tol=self.dense_grad_tolerance)


if __name__ == "__main__":
    # Executed as a script: run every TestCase defined in this module.
    unittest.main()