test_aev.py 15.8 KB
Newer Older
1
2
3
4
5
import torch
import torchani
import unittest
import os
import pickle
6
import copy
7
8
import itertools
import ase
9
import ase.io
10
import math
11
import traceback
Gao, Xiang's avatar
Gao, Xiang committed
12
from common_aev_test import _TestAEVBase
13

14
15
16
17
18

# Directory holding this module (symlinks resolved); fixture paths such as
# test_data/ANI1_subset/<i> are joined onto it.
path = os.path.split(os.path.realpath(__file__))[0]
# Number of numbered fixture files (0 .. N - 1) in test_data/ANI1_subset.
N = 97


class TestIsolated(unittest.TestCase):
    """AEV computation must not fail for atoms without neighbors.

    Atoms separated from all other atoms by more than the cutoff radius, or a
    single lone atom, can make the AEV computer raise ``IndexError``. Each
    test reports the captured traceback through ``self.fail`` if that happens;
    any other exception propagates unchanged.
    """

    def setUp(self):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        ani1x = torchani.models.ANI1x().to(self.device)
        self.aev_computer = ani1x.aev_computer
        self.species_to_tensor = ani1x.species_to_tensor
        # Radial and angular cutoff radii, both read from the same AEV
        # computer that the tests exercise.
        self.rcr = self.aev_computer.Rcr
        self.rca = self.aev_computer.Rca

    def _distances(self):
        """Separations to probe: 1.0, exactly at each cutoff, just past each
        cutoff, and twice the radial cutoff."""
        return [1.0, self.rca,
                self.rca + 1e-4, self.rcr,
                self.rcr + 1e-4, 2 * self.rcr]

    def _cutoff_report(self, dist):
        """Failure context naming the separation and both cutoff radii."""
        return (f'Failure at distance: {dist}\n'
                f'Radial r_cut of aev_computer: {self.rcr}\n'
                f'Angular r_cut of aev_computer: {self.rca}')

    def _assert_no_index_error(self, species, coordinates, report):
        """Run the AEV computer once; on ``IndexError`` fail with the
        formatted traceback followed by ``report``."""
        try:
            _, _ = self.aev_computer((species, coordinates))
        except IndexError:
            # Must be captured while the exception is being handled.
            tb = traceback.format_exc()
        else:
            return
        self.fail(f'\n\n{tb}\n{report}')

    def testCO2(self):
        """O, C, O with the oxygens on the -x and +z axes around carbon."""
        species = self.species_to_tensor(['O', 'C', 'O']).to(self.device).unsqueeze(0)
        for dist in self._distances():
            coordinates = torch.tensor(
                [[[-dist, 0., 0.], [0., 0., 0.], [0., 0., dist]]],
                requires_grad=True, device=self.device)
            self._assert_no_index_error(species, coordinates, self._cutoff_report(dist))

    def testH2(self):
        """Two hydrogens separated along z."""
        species = self.species_to_tensor(['H', 'H']).to(self.device).unsqueeze(0)
        for dist in self._distances():
            coordinates = torch.tensor(
                [[[0., 0., 0.], [0., 0., dist]]],
                requires_grad=True, device=self.device)
            self._assert_no_index_error(species, coordinates, self._cutoff_report(dist))

    def testH(self):
        """Tests for failure on a single atom."""
        species = self.species_to_tensor(['H']).to(self.device).unsqueeze(0)
        coordinates = torch.tensor(
            [[[0., 0., 0.]]],
            requires_grad=True, device=self.device)
        self._assert_no_index_error(species, coordinates, 'Failure on lone atom\n')


class TestAEV(_TestAEVBase):
    """Compare computed AEVs with reference values from pickled test data.

    ``self.aev_computer``, ``self.transform`` and ``self.assertAEVEqual`` come
    from ``_TestAEVBase`` (defined in common_aev_test, not shown here).
    """

    def _load_subset(self, i):
        """Load fixture ``test_data/ANI1_subset/<i>`` as transformed tensors.

        Returns ``(coordinates, species, radial, angular)``. The last two
        entries of the pickled 6-tuple are not used by these tests.
        """
        datafile = os.path.join(path, f'test_data/ANI1_subset/{i}')
        # pickle is acceptable only because these are trusted repository
        # fixtures; never unpickle untrusted data.
        with open(datafile, 'rb') as f:
            coordinates, species, radial, angular, _, _ = pickle.load(f)
        return tuple(self.transform(torch.from_numpy(array))
                     for array in (coordinates, species, radial, angular))

    def testIsomers(self):
        """Each molecule's AEV matches its stored radial/angular reference."""
        for i in range(N):
            coordinates, species, expected_radial, expected_angular = \
                self._load_subset(i)
            _, aev = self.aev_computer((species, coordinates))
            self.assertAEVEqual(expected_radial, expected_angular, aev)

    def testPadding(self):
        """All molecules padded into one batch must still match their
        per-molecule references."""
        species_coordinates = []
        radial_angular = []
        for i in range(N):
            coordinates, species, radial, angular = self._load_subset(i)
            species_coordinates.append(torchani.utils.broadcast_first_dim(
                {'species': species, 'coordinates': coordinates}))
            radial_angular.append((radial, angular))
        species_coordinates = torchani.utils.pad_atomic_properties(
            species_coordinates)
        _, aev = self.aev_computer((species_coordinates['species'],
                                    species_coordinates['coordinates']))
        # As sliced below, each molecule's conformations are assumed to occupy
        # consecutive rows in load order; columns beyond its own atom count
        # (padding) are excluded from the comparison.
        start = 0
        for expected_radial, expected_angular in radial_angular:
            conformations, atoms = expected_radial.shape[:2]
            aev_ = aev[start:start + conformations, :atoms]
            start += conformations
            self.assertAEVEqual(expected_radial, expected_angular, aev_)

    @unittest.skipIf(not torch.cuda.is_available(), "Too slow on CPU")
    def testGradient(self):
        """Test validity of autodiff by comparing analytical and numerical
        gradients.
        """
        datafile = os.path.join(path, 'test_data/NIST/all')
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Work on a float64 copy so the shared aev_computer used by other
        # tests is left untouched; gradcheck is meant for double inputs.
        aev_computer = copy.deepcopy(self.aev_computer).to(device).to(torch.float64)
        with open(datafile, 'rb') as f:
            data = pickle.load(f)
        for coordinates, species, _, _, _, _ in data:
            coordinates = torch.from_numpy(coordinates).to(device).to(torch.float64)
            coordinates.requires_grad_(True)
            species = torch.from_numpy(species).to(device)

            # gradcheck numerically perturbs every tensor input, so the
            # integer species tensor must not be one: wrap the call so only
            # coordinates are differentiated. Binding species as a default
            # pins this iteration's value inside the closure.
            def aev_forward_wrapper(coords, species=species):
                # Return only the AEV part of the (species, aev) output.
                return aev_computer((species, coords))[1]

            # Sanity check: the wrapper runs forward without error.
            aev_forward_wrapper(coordinates)
            torch.autograd.gradcheck(aev_forward_wrapper, coordinates)


class TestAEVJIT(TestAEV):
    """Run every TestAEV check against a ``torch.jit.script``-compiled AEV computer."""

    def setUp(self):
        TestAEV.setUp(self)
        # Replace the eager module with its TorchScript compilation.
        scripted = torch.jit.script(self.aev_computer)
        self.aev_computer = scripted


class TestPBCSeeEachOther(unittest.TestCase):
    """Atoms that are close only through a periodic image must be found as
    neighbors, and PBC AEVs must be invariant under translation."""

    def setUp(self):
        self.ani1x = torchani.models.ANI1x()
        self.aev_computer = self.ani1x.aev_computer.to(torch.double)

    def _assertOnlyPair01(self, species, coordinates, cell, allshifts):
        """Assert ``neighbor_pairs`` returns exactly one pair: atoms (0, 1).

        ``species == -1`` is passed as the padding mask. The final ``1`` is
        presumably the cutoff, matching ``compute_shifts(cell, pbc, 1)`` in the
        callers. NOTE(review): confirm against torchani.aev.
        """
        atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(
            species == -1, coordinates, cell, allshifts, 1)
        self.assertEqual(atom_index1.tolist(), [0])
        self.assertEqual(atom_index2.tolist(), [1])

    def testTranslationalInvariancePBC(self):
        """Random translations of all atoms leave the PBC AEV unchanged."""
        coordinates = torch.tensor(
            [[[0, 0, 0],
              [1, 0, 0],
              [0, 1, 0],
              [0, 0, 1],
              [0, 1, 1]]],
            dtype=torch.double, requires_grad=True)
        cell = torch.eye(3, dtype=torch.double) * 2
        species = torch.tensor([[1, 0, 0, 0, 0]], dtype=torch.long)
        pbc = torch.ones(3, dtype=torch.bool)

        _, aev = self.aev_computer((species, coordinates), cell=cell, pbc=pbc)

        for _ in range(100):
            translation = torch.randn(3, dtype=torch.double)
            _, aev2 = self.aev_computer((species, coordinates + translation), cell=cell, pbc=pbc)
            self.assertTrue(torch.allclose(aev, aev2))

    def testPBCConnersSeeEachOther(self):
        """An atom near the origin corner sees one near each other corner."""
        species = torch.tensor([[0, 0]])
        cell = torch.eye(3, dtype=torch.double) * 10
        pbc = torch.ones(3, dtype=torch.bool)
        allshifts = torchani.aev.compute_shifts(cell, pbc, 1)

        xyz1 = torch.tensor([0.1, 0.1, 0.1])
        xyz2s = [
            torch.tensor([9.9, 0.0, 0.0]),
            torch.tensor([0.0, 9.9, 0.0]),
            torch.tensor([0.0, 0.0, 9.9]),
            torch.tensor([9.9, 9.9, 0.0]),
            torch.tensor([0.0, 9.9, 9.9]),
            torch.tensor([9.9, 0.0, 9.9]),
            torch.tensor([9.9, 9.9, 9.9]),
        ]

        for xyz2 in xyz2s:
            coordinates = torch.stack([xyz1, xyz2]).to(torch.double).unsqueeze(0)
            self._assertOnlyPair01(species, coordinates, cell, allshifts)

    def testPBCSurfaceSeeEachOther(self):
        """Atoms just inside opposite faces see each other."""
        cell = torch.eye(3, dtype=torch.double) * 10
        pbc = torch.ones(3, dtype=torch.bool)
        allshifts = torchani.aev.compute_shifts(cell, pbc, 1)
        species = torch.tensor([[0, 0]])

        for i in range(3):
            xyz1 = torch.tensor([5.0, 5.0, 5.0], dtype=torch.double)
            xyz1[i] = 0.1
            xyz2 = xyz1.clone()
            xyz2[i] = 9.9

            coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
            self._assertOnlyPair01(species, coordinates, cell, allshifts)

    def testPBCEdgesSeeEachOther(self):
        """Atoms near an edge see partners near every parallel edge."""
        cell = torch.eye(3, dtype=torch.double) * 10
        pbc = torch.ones(3, dtype=torch.bool)
        allshifts = torchani.aev.compute_shifts(cell, pbc, 1)
        species = torch.tensor([[0, 0]])

        for i, j in itertools.combinations(range(3), 2):
            xyz1 = torch.tensor([5.0, 5.0, 5.0], dtype=torch.double)
            xyz1[i] = 0.1
            xyz1[j] = 0.1
            for new_i, new_j in [[0.1, 9.9], [9.9, 0.1], [9.9, 9.9]]:
                xyz2 = xyz1.clone()
                xyz2[i] = new_i
                xyz2[j] = new_j
                # Check inside the inner loop so every image offset is tested,
                # not only the last one.
                coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
                self._assertOnlyPair01(species, coordinates, cell, allshifts)

    def testNonRectangularPBCConnersSeeEachOther(self):
        """Corner atoms see each other in a non-orthogonal cell."""
        species = torch.tensor([[0, 0]])
        cell = ase.geometry.cellpar_to_cell([10, 10, 10 * math.sqrt(2), 90, 45, 90])
        cell = torch.tensor(ase.geometry.complete_cell(cell), dtype=torch.double)
        pbc = torch.ones(3, dtype=torch.bool)
        allshifts = torchani.aev.compute_shifts(cell, pbc, 1)

        xyz1 = torch.tensor([0.1, 0.1, 0.05], dtype=torch.double)
        xyz2 = torch.tensor([10.0, 0.1, 0.1], dtype=torch.double)

        coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
        self._assertOnlyPair01(species, coordinates, cell, allshifts)


class TestAEVOnBoundary(unittest.TestCase):
    """AEVs must not change when a molecule sits on cell corners, edges or faces.

    A 5-atom molecule placed at the centre of a large non-orthogonal cell is
    the reference. Copies placed on the cell boundary are wrapped back with
    ``torchani.utils.map2central`` and must give the same AEV.
    """

    def setUp(self):
        # Tolerance on fractional coordinates when deciding "inside the cell".
        self.eps = 1e-9
        cell = ase.geometry.cellpar_to_cell([100, 100, 100 * math.sqrt(2), 90, 45, 90])
        self.cell = torch.tensor(ase.geometry.complete_cell(cell), dtype=torch.double)
        self.inv_cell = torch.inverse(self.cell)
        self.coordinates = torch.tensor([[[0.0, 0.0, 0.0],
                                          [1.0, -0.1, -0.1],
                                          [-0.1, 1.0, -0.1],
                                          [-0.1, -0.1, 1.0],
                                          [-1.0, -1.0, -1.0]]], dtype=torch.double)
        self.species = torch.tensor([[1, 0, 0, 0, 0]])
        self.pbc = torch.ones(3, dtype=torch.bool)
        self.v1, self.v2, self.v3 = self.cell
        # Reference geometry: the molecule translated to the cell centre.
        self.center_coordinates = self.coordinates + 0.5 * (self.v1 + self.v2 + self.v3)
        ani1x = torchani.models.ANI1x()
        self.aev_computer = ani1x.aev_computer.to(torch.double)
        _, self.aev = self.aev_computer((self.species, self.center_coordinates), cell=self.cell, pbc=self.pbc)

    def _fractional_in_cell(self, coordinates):
        """Return a boolean mask of fractional coordinates within [0, 1] (+/- eps).

        Also asserts that the Cartesian -> fractional -> Cartesian round trip
        reproduces ``coordinates``.
        """
        coordinates_cell = coordinates @ self.inv_cell
        self.assertTrue(torch.allclose(coordinates, coordinates_cell @ self.cell))
        return (coordinates_cell >= -self.eps) & (coordinates_cell <= 1 + self.eps)

    def assertInCell(self, coordinates):
        """Assert every coordinate lies inside the closed unit cell."""
        self.assertTrue(self._fractional_in_cell(coordinates).all())

    def assertNotInCell(self, coordinates):
        """Assert at least one coordinate lies outside the unit cell."""
        self.assertFalse(self._fractional_in_cell(coordinates).all())

    def testCornerSurfaceAndEdge(self):
        """Every non-central point of the {0, 0.5, 1}^3 lattice (corners,
        edge midpoints, face centres) reproduces the centred AEV."""
        for i, j, k in itertools.product([0, 0.5, 1], repeat=3):
            if i == j == k == 0.5:
                continue  # the cell centre is the reference itself
            coordinates = self.coordinates + i * self.v1 + j * self.v2 + k * self.v3
            # Some atoms start outside the cell...
            self.assertNotInCell(coordinates)
            # ...and map2central must wrap all of them back inside.
            coordinates = torchani.utils.map2central(self.cell, coordinates, self.pbc)
            self.assertInCell(coordinates)
            _, aev = self.aev_computer((self.species, coordinates), cell=self.cell, pbc=self.pbc)
            # Guard against an all-zero AEV matching trivially.
            self.assertGreater(aev.abs().max().item(), 0)
            self.assertTrue(torch.allclose(aev, self.aev))


class TestAEVOnBenzenePBC(unittest.TestCase):
    """PBC AEVs of the Benzene.cif crystal under supercell tiling and explicit images."""

    def setUp(self):
        ani1x = torchani.models.ANI1x()
        self.aev_computer = ani1x.aev_computer
        filename = os.path.join(path, '../tools/generate-unit-test-expect/others/Benzene.cif')
        benzene = ase.io.read(filename)
        self.cell = torch.tensor(benzene.get_cell(complete=True)).float()
        self.pbc = torch.tensor(benzene.get_pbc(), dtype=torch.bool)
        species_to_tensor = torchani.utils.ChemicalSymbolsToInts(['H', 'C', 'N', 'O'])
        self.species = species_to_tensor(benzene.get_chemical_symbols()).unsqueeze(0)
        self.coordinates = torch.tensor(benzene.get_positions()).unsqueeze(0).float()
        # Reference AEV of the unit cell computed with periodic boundaries.
        _, self.aev = self.aev_computer((self.species, self.coordinates), cell=self.cell, pbc=self.pbc)
        self.natoms = self.aev.shape[1]

    def testRepeat(self):
        """Tiling the cell along its first vector leaves the AEV of every copy unchanged."""
        tolerance = 5e-6
        repeats = 4
        c1, c2, c3 = self.cell
        species2 = self.species.repeat(1, repeats)
        coordinates2 = torch.cat(
            [self.coordinates + k * c1 for k in range(repeats)], dim=1)
        cell2 = torch.stack([repeats * c1, c2, c3])
        _, aev2 = self.aev_computer((species2, coordinates2), cell=cell2, pbc=self.pbc)
        # Check every copy, including the last one.
        for k in range(repeats):
            aev3 = aev2[:, k * self.natoms: (k + 1) * self.natoms, :]
            self.assertTrue(torch.allclose(self.aev, aev3, atol=tolerance))

    def testManualMirror(self):
        """Explicit 3x3x3 images without PBC reproduce the PBC AEV of the
        central copy. This presumably relies on the cutoff being smaller
        than the cell. NOTE(review): confirm for Benzene.cif."""
        c1, c2, c3 = self.cell
        species2 = self.species.repeat(1, 3 ** 3)
        # product() yields (0, 0, 0) first, so the unshifted copy comes first.
        coordinates2 = torch.cat([
            self.coordinates + i * c1 + j * c2 + k * c3
            for i, j, k in itertools.product([0, -1, 1], repeat=3)
        ], dim=1)
        _, aev2 = self.aev_computer((species2, coordinates2))
        aev2 = aev2[:, :self.natoms, :]
        self.assertTrue(torch.allclose(self.aev, aev2))


if __name__ == '__main__':
    # Running this file directly loads and runs the test cases defined in it.
    unittest.main(module='__main__')