test_aev.py 15.8 KB
Newer Older
1
2
3
4
5
import torch
import torchani
import unittest
import os
import pickle
6
import copy
7
8
import itertools
import ase
9
import ase.io
10
import math
11
import traceback
Gao, Xiang's avatar
Gao, Xiang committed
12
from common_aev_test import _TestAEVBase
13

14
15
16
17
18

# Directory containing this test module; fixture paths below are built
# relative to it.
path = os.path.dirname(os.path.realpath(__file__))
# Number of fixture files (named 0 .. N-1) under test_data/ANI1_subset that
# the isomer and padding tests iterate over.
N = 97


19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
class TestIsolated(unittest.TestCase):
    """Check that the AEV computer copes with atoms that have no neighbors.

    Atoms separated from all other atoms by more than the cutoff radius, or
    a single lone atom, can make the AEV computation raise an IndexError.
    Every test here fails (with the captured traceback) if that happens.
    """

    def setUp(self):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model = torchani.models.ANI1x().to(self.device)
        self.aev_computer = model.aev_computer
        self.species_to_tensor = model.species_to_tensor
        # Radial and angular cutoff radii reported in the failure messages.
        self.rcr = self.aev_computer.Rcr
        self.rca = self.aev_computer.Rca

    def _probe_distances(self):
        """Return separations at, just beyond, and well beyond both cutoffs."""
        return [1.0, self.rca,
                self.rca + 1e-4, self.rcr,
                self.rcr + 1e-4, 2 * self.rcr]

    def _index_error_trace(self, species, positions):
        """Run the AEV computer on ``positions``.

        Returns the formatted traceback when an IndexError is raised and an
        empty string otherwise. Any other exception propagates unchanged.
        """
        coordinates = torch.tensor(
            positions, requires_grad=True, device=self.device)
        try:
            _, _ = self.aev_computer((species, coordinates))
        except IndexError:
            return traceback.format_exc()
        return ''

    def _fail_at_distance(self, trace, dist):
        """Fail the test, reporting the offending distance and both cutoffs."""
        self.fail(f'\n\n{trace}\nFailure at distance: {dist}\n'
                  f'Radial r_cut of aev_computer: {self.rcr}\n'
                  f'Angular r_cut of aev_computer: {self.rca}')

    def testCO2(self):
        # Three atoms: the central C at the origin, one O on -x, one O on +z.
        species = self.species_to_tensor(['O', 'C', 'O']).to(self.device).unsqueeze(0)
        for dist in self._probe_distances():
            trace = self._index_error_trace(
                species,
                [[[-dist, 0., 0.], [0., 0., 0.], [0., 0., dist]]])
            if trace:
                self._fail_at_distance(trace, dist)

    def testH2(self):
        # Two atoms separated by ``dist`` along z.
        species = self.species_to_tensor(['H', 'H']).to(self.device).unsqueeze(0)
        for dist in self._probe_distances():
            trace = self._index_error_trace(
                species,
                [[[0., 0., 0.], [0., 0., dist]]])
            if trace:
                self._fail_at_distance(trace, dist)

    def testH(self):
        # Tests for failure on a single atom
        species = self.species_to_tensor(['H']).to(self.device).unsqueeze(0)
        trace = self._index_error_trace(species, [[[0., 0., 0.]]])
        if trace:
            self.fail(f'\n\n{trace}\nFailure on lone atom\n')


Gao, Xiang's avatar
Gao, Xiang committed
87
class TestAEV(_TestAEVBase):
    """Compare computed AEVs with pickled reference values and check gradients."""

    def _read_subset_entry(self, index):
        """Load fixture ``index`` from test_data/ANI1_subset.

        Returns ``(coordinates, species, radial, angular)``. Each item is
        converted from a NumPy array to a tensor and then passed through
        ``self.transform``. The last two items of the pickled tuple are
        ignored.
        """
        datafile = os.path.join(path, 'test_data/ANI1_subset/{}'.format(index))
        with open(datafile, 'rb') as f:
            coordinates, species, radial, angular, _, _ = pickle.load(f)
        tensors = [
            torch.from_numpy(array)
            for array in (coordinates, species, radial, angular)
        ]
        return tuple(self.transform(tensor) for tensor in tensors)

    def testIsomers(self):
        """Each fixture, evaluated on its own, must reproduce its reference AEV."""
        for index in range(N):
            coordinates, species, expected_radial, expected_angular = \
                self._read_subset_entry(index)
            _, aev = self.aev_computer((species, coordinates))
            self.assertAEVEqual(expected_radial, expected_angular, aev)

    def testPadding(self):
        """All fixtures padded into one batch must still match their references."""
        species_coordinates = []
        radial_angular = []
        for index in range(N):
            coordinates, species, radial, angular = \
                self._read_subset_entry(index)
            species_coordinates.append({'species': species, 'coordinates': coordinates})
            radial_angular.append((radial, angular))

        padded = torchani.utils.pad_atomic_properties(species_coordinates)
        _, aev = self.aev_computer((padded['species'], padded['coordinates']))

        # Walk through the batch fixture by fixture: rows are conformations,
        # and only the first ``atoms`` columns belong to the unpadded molecule.
        offset = 0
        for expected_radial, expected_angular in radial_angular:
            conformations = expected_radial.shape[0]
            atoms = expected_radial.shape[1]
            stop = offset + conformations
            aev_slice = aev[offset:stop, 0:atoms]
            offset = stop
            self.assertAEVEqual(expected_radial, expected_angular, aev_slice)

    @unittest.skipIf(not torch.cuda.is_available(), "Too slow on CPU")
    def testGradient(self):
        """Test validity of autodiff by comparing analytical and numerical
        gradients.
        """
        datafile = os.path.join(path, 'test_data/NIST/all')
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Work on a private float64 copy of the AEV computer so that other
        # tests are not affected.
        aev_computer = copy.deepcopy(self.aev_computer).to(device).to(torch.float64)

        def make_aev_function(species):
            # PyTorch gradcheck expects a function whose inputs and outputs
            # are tensors. It estimates the derivative numerically by
            # perturbing each input, so the species tensor is bound here
            # rather than passed as an input; otherwise gradcheck would try
            # to differentiate with respect to species and fail.
            def aev_of(coords):
                # Return only the aev portion of the output.
                return aev_computer((species, coords))[1]
            return aev_of

        with open(datafile, 'rb') as f:
            data = pickle.load(f)
            for coordinates, species, _, _, _, _ in data:
                coordinates = torch.from_numpy(coordinates).to(device).to(torch.float64)
                coordinates.requires_grad_(True)
                species = torch.from_numpy(species).to(device)

                aev_forward = make_aev_function(species)
                # Sanity check: the wrapper returns an aev without error.
                aev_forward(coordinates)
                torch.autograd.gradcheck(aev_forward, coordinates)

169

170
171
172
173
174
175
class TestAEVJIT(TestAEV):
    """Re-run every TestAEV test against a TorchScript-compiled AEV computer."""

    def setUp(self):
        super().setUp()
        scripted = torch.jit.script(self.aev_computer)
        self.aev_computer = scripted


176
class TestPBCSeeEachOther(unittest.TestCase):
    """Periodic-boundary tests for the AEV computer and neighbor search.

    The "see each other" tests place two atoms on opposite sides of a
    periodic cell, so that they are close only through a periodic image,
    and check that the neighbor search reports exactly the pair (0, 1).
    """

    def setUp(self):
        self.ani1x = torchani.models.ANI1x()
        self.aev_computer = self.ani1x.aev_computer.to(torch.double)

    def _assertPairSeesEachOther(self, species, coordinates, cell, allshifts):
        """Assert that the neighbor search finds exactly the pair (0, 1).

        ``species == -1`` is all-False for the species used in these tests;
        presumably it is the padding mask expected by ``neighbor_pairs``, and
        the trailing ``1`` the cutoff -- TODO confirm against torchani.aev.
        """
        atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(
            species == -1, coordinates, cell, allshifts, 1)
        self.assertEqual(atom_index1.tolist(), [0])
        self.assertEqual(atom_index2.tolist(), [1])

    def testTranslationalInvariancePBC(self):
        """Under PBC the AEV must not change when every atom is translated."""
        coordinates = torch.tensor(
            [[[0, 0, 0],
              [1, 0, 0],
              [0, 1, 0],
              [0, 0, 1],
              [0, 1, 1]]],
            dtype=torch.double, requires_grad=True)
        cell = torch.eye(3, dtype=torch.double) * 2
        species = torch.tensor([[1, 0, 0, 0, 0]], dtype=torch.long)
        pbc = torch.ones(3, dtype=torch.bool)

        _, aev = self.aev_computer((species, coordinates), cell=cell, pbc=pbc)

        # 100 random rigid translations, each compared with the reference.
        for _ in range(100):
            translation = torch.randn(3, dtype=torch.double)
            _, aev2 = self.aev_computer((species, coordinates + translation), cell=cell, pbc=pbc)
            self.assertTrue(torch.allclose(aev, aev2))

    def testPBCConnersSeeEachOther(self):
        """An atom near one corner must see atoms near every other corner."""
        species = torch.tensor([[0, 0]])
        cell = torch.eye(3, dtype=torch.double) * 10
        pbc = torch.ones(3, dtype=torch.bool)
        allshifts = torchani.aev.compute_shifts(cell, pbc, 1)

        xyz1 = torch.tensor([0.1, 0.1, 0.1])
        xyz2s = [
            torch.tensor([9.9, 0.0, 0.0]),
            torch.tensor([0.0, 9.9, 0.0]),
            torch.tensor([0.0, 0.0, 9.9]),
            torch.tensor([9.9, 9.9, 0.0]),
            torch.tensor([0.0, 9.9, 9.9]),
            torch.tensor([9.9, 0.0, 9.9]),
            torch.tensor([9.9, 9.9, 9.9]),
        ]

        for xyz2 in xyz2s:
            coordinates = torch.stack([xyz1, xyz2]).to(torch.double).unsqueeze(0)
            self._assertPairSeesEachOther(species, coordinates, cell, allshifts)

    def testPBCSurfaceSeeEachOther(self):
        """Atoms facing each other across a pair of opposite faces are neighbors."""
        cell = torch.eye(3, dtype=torch.double) * 10
        pbc = torch.ones(3, dtype=torch.bool)
        allshifts = torchani.aev.compute_shifts(cell, pbc, 1)
        species = torch.tensor([[0, 0]])

        for i in range(3):
            # Both atoms sit at the cell center except along axis i, where
            # they are placed just inside opposite faces.
            xyz1 = torch.tensor([5.0, 5.0, 5.0], dtype=torch.double)
            xyz1[i] = 0.1
            xyz2 = xyz1.clone()
            xyz2[i] = 9.9

            coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
            self._assertPairSeesEachOther(species, coordinates, cell, allshifts)

    def testPBCEdgesSeeEachOther(self):
        """Atoms near parallel edges of the cell are neighbors through PBC."""
        cell = torch.eye(3, dtype=torch.double) * 10
        pbc = torch.ones(3, dtype=torch.bool)
        allshifts = torchani.aev.compute_shifts(cell, pbc, 1)
        species = torch.tensor([[0, 0]])

        for i, j in itertools.combinations(range(3), 2):
            xyz1 = torch.tensor([5.0, 5.0, 5.0], dtype=torch.double)
            xyz1[i] = 0.1
            xyz1[j] = 0.1
            for new_i, new_j in [[0.1, 9.9], [9.9, 0.1], [9.9, 9.9]]:
                xyz2 = xyz1.clone()
                xyz2[i] = new_i
                xyz2[j] = new_j

                # The check has to run for every placement of the second
                # atom. It previously sat outside this inner loop, so only
                # the last placement (9.9, 9.9) was ever verified.
                coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
                self._assertPairSeesEachOther(species, coordinates, cell, allshifts)

    def testNonRectangularPBCConnersSeeEachOther(self):
        """Same corner check in a sheared cell with a 45 degree cell angle."""
        species = torch.tensor([[0, 0]])
        cell = ase.geometry.cellpar_to_cell([10, 10, 10 * math.sqrt(2), 90, 45, 90])
        cell = torch.tensor(ase.geometry.complete_cell(cell), dtype=torch.double)
        pbc = torch.ones(3, dtype=torch.bool)
        allshifts = torchani.aev.compute_shifts(cell, pbc, 1)

        xyz1 = torch.tensor([0.1, 0.1, 0.05], dtype=torch.double)
        xyz2 = torch.tensor([10.0, 0.1, 0.1], dtype=torch.double)

        coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
        self._assertPairSeesEachOther(species, coordinates, cell, allshifts)


class TestAEVOnBoundary(unittest.TestCase):
    """AEVs must not depend on where in the periodic cell a molecule sits.

    A five-atom arrangement is evaluated once at the cell center (the
    reference). It is then moved onto the corners, edges and faces of the
    cell, wrapped back into the cell, and evaluated again.
    """

    def setUp(self):
        # Tolerance used when deciding whether a fractional coordinate lies
        # in the closed interval [0, 1].
        self.eps = 1e-9
        cell = ase.geometry.cellpar_to_cell([100, 100, 100 * math.sqrt(2), 90, 45, 90])
        self.cell = torch.tensor(ase.geometry.complete_cell(cell), dtype=torch.double)
        self.inv_cell = torch.inverse(self.cell)
        self.coordinates = torch.tensor([[[0.0, 0.0, 0.0],
                                          [1.0, -0.1, -0.1],
                                          [-0.1, 1.0, -0.1],
                                          [-0.1, -0.1, 1.0],
                                          [-1.0, -1.0, -1.0]]], dtype=torch.double)
        self.species = torch.tensor([[1, 0, 0, 0, 0]])
        self.pbc = torch.ones(3, dtype=torch.bool)
        # Rows of the cell matrix are the three cell vectors.
        self.v1, self.v2, self.v3 = self.cell
        self.center_coordinates = self.coordinates + 0.5 * (self.v1 + self.v2 + self.v3)
        model = torchani.models.ANI1x()
        self.aev_computer = model.aev_computer.to(torch.double)
        # Reference AEV: the same arrangement shifted to the cell center.
        _, self.aev = self.aev_computer(
            (self.species, self.center_coordinates), cell=self.cell, pbc=self.pbc)

    def _inside_mask(self, coordinates):
        """Return an elementwise mask of fractional coordinates within [0, 1].

        Also asserts that converting to fractional coordinates and back
        reproduces the Cartesian input.
        """
        fractional = coordinates @ self.inv_cell
        self.assertTrue(torch.allclose(coordinates, fractional @ self.cell))
        not_below = fractional >= -self.eps
        not_above = fractional <= 1 + self.eps
        return not_below & not_above

    def assertInCell(self, coordinates):
        """Assert that every coordinate lies inside the cell (within eps)."""
        inside = self._inside_mask(coordinates)
        self.assertTrue(inside.all())

    def assertNotInCell(self, coordinates):
        """Assert that at least one coordinate lies outside the cell."""
        inside = self._inside_mask(coordinates)
        self.assertFalse(inside.all())

    def testCornerSurfaceAndEdge(self):
        """Compare AEVs at every corner, edge and face position with the reference."""
        center = (0.5, 0.5, 0.5)
        for fractions in itertools.product([0, 0.5, 1], repeat=3):
            # The cell center is the reference configuration itself.
            if fractions == center:
                continue
            i, j, k = fractions
            coordinates = self.coordinates + i * self.v1 + j * self.v2 + k * self.v3
            self.assertNotInCell(coordinates)
            coordinates = torchani.utils.map2central(self.cell, coordinates, self.pbc)
            self.assertInCell(coordinates)
            _, aev = self.aev_computer(
                (self.species, coordinates), cell=self.cell, pbc=self.pbc)
            # Guard against a trivially passing comparison of all-zero AEVs.
            self.assertGreater(aev.abs().max().item(), 0)
            self.assertTrue(torch.allclose(aev, self.aev))
319

Gao, Xiang's avatar
Gao, Xiang committed
320

321
322
323
class TestAEVOnBenzenePBC(unittest.TestCase):
    """Consistency checks of periodic AEVs on a benzene crystal structure."""

    def setUp(self):
        model = torchani.models.ANI1x()
        self.aev_computer = model.aev_computer
        cif_path = os.path.join(path, '../tools/generate-unit-test-expect/others/Benzene.cif')
        crystal = ase.io.read(cif_path)
        self.cell = torch.tensor(crystal.get_cell(complete=True)).float()
        self.pbc = torch.tensor(crystal.get_pbc(), dtype=torch.bool)
        symbols_to_ints = torchani.utils.ChemicalSymbolsToInts(['H', 'C', 'N', 'O'])
        self.species = symbols_to_ints(crystal.get_chemical_symbols()).unsqueeze(0)
        self.coordinates = torch.tensor(crystal.get_positions()).unsqueeze(0).float()
        # Reference AEV of the unit cell under its own periodic boundaries.
        _, self.aev = self.aev_computer(
            (self.species, self.coordinates), cell=self.cell, pbc=self.pbc)
        self.natoms = self.aev.shape[1]

    def testRepeat(self):
        """A supercell of four images along c1 must reproduce the unit-cell AEV."""
        tolerance = 5e-6
        c1, c2, c3 = self.cell
        species2 = self.species.repeat(1, 4)
        images = [
            self.coordinates,
            self.coordinates + c1,
            self.coordinates + 2 * c1,
            self.coordinates + 3 * c1,
        ]
        coordinates2 = torch.cat(images, dim=1)
        cell2 = torch.stack([4 * c1, c2, c3])
        _, aev2 = self.aev_computer((species2, coordinates2), cell=cell2, pbc=self.pbc)
        # NOTE(review): four images are built but only the first three are
        # compared here; confirm whether the fourth is skipped on purpose.
        for image in range(3):
            first = image * self.natoms
            last = (image + 1) * self.natoms
            aev3 = aev2[:, first:last, :]
            self.assertTrue(torch.allclose(self.aev, aev3, atol=tolerance))

    def testManualMirror(self):
        """Explicit 3x3x3 images without PBC must match the periodic AEV."""
        c1, c2, c3 = self.cell
        species2 = self.species.repeat(1, 3 ** 3)
        # The (0, 0, 0) image comes first, so the leading ``natoms`` entries
        # are the atoms of the original cell.
        shifted_images = [
            self.coordinates + i * c1 + j * c2 + k * c3
            for i, j, k in itertools.product([0, -1, 1], repeat=3)
        ]
        coordinates2 = torch.cat(shifted_images, dim=1)
        _, aev2 = self.aev_computer((species2, coordinates2))
        central = aev2[:, :self.natoms, :]
        self.assertTrue(torch.allclose(self.aev, central))


364
365
# Run every test case in this module when the file is executed directly.
if __name__ == '__main__':
    unittest.main()