import torch
import torchani
import unittest
import os
import pickle
import itertools
import ase
import ase.io
import math
import traceback
from common_aev_test import _TestAEVBase

# Directory containing this test file; test-data paths are built relative to it.
path = os.path.dirname(os.path.realpath(__file__))
# Number of pickled molecules read from test_data/ANI1_subset (files named 0 .. N-1).
N = 97


class TestIsolated(unittest.TestCase):
    """Check that AEVs can be computed for atoms without neighbours.

    Atoms separated from all other atoms by more than the cutoff radius, or a
    lone atom, can throw an IndexError for large distances or lone atoms.
    """

    def setUp(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        ani1x = torchani.models.ANI1x().to(self.device)
        self.aev_computer = ani1x.aev_computer
        self.species_to_tensor = ani1x.species_to_tensor
        # Radial and angular cutoff radii of the AEV computer.
        self.rcr = self.aev_computer.Rcr
        self.rca = self.aev_computer.Rca

    def _separations(self):
        """Return separations inside, at, just past and well beyond both cutoffs."""
        return [1.0, self.rca,
                self.rca + 1e-4, self.rcr,
                self.rcr + 1e-4, 2 * self.rcr]

    def _cutoffContext(self, dist):
        """Describe the failing distance together with both cutoff radii."""
        return (f'Failure at distance: {dist}\n'
                f'Radial r_cut of aev_computer: {self.rcr}\n'
                f'Angular r_cut of aev_computer: {self.rca}')

    def _assertNoIndexError(self, species, coordinates, context):
        """Run the AEV computer; fail with the traceback and *context* on IndexError."""
        try:
            self.aev_computer((species, coordinates))
        except IndexError:
            self.fail(f'\n\n{traceback.format_exc()}\n{context}')

    def testCO2(self):
        species = self.species_to_tensor(['O', 'C', 'O']).to(self.device).unsqueeze(0)
        for dist in self._separations():
            # subTest reports every failing distance instead of stopping at the first.
            with self.subTest(dist=dist):
                # Each O sits `dist` from the central C (one along -x, one along +z).
                coordinates = torch.tensor(
                    [[[-dist, 0., 0.], [0., 0., 0.], [0., 0., dist]]],
                    requires_grad=True, device=self.device)
                self._assertNoIndexError(species, coordinates, self._cutoffContext(dist))

    def testH2(self):
        species = self.species_to_tensor(['H', 'H']).to(self.device).unsqueeze(0)
        for dist in self._separations():
            with self.subTest(dist=dist):
                coordinates = torch.tensor(
                    [[[0., 0., 0.], [0., 0., dist]]],
                    requires_grad=True, device=self.device)
                self._assertNoIndexError(species, coordinates, self._cutoffContext(dist))

    def testH(self):
        # Tests for failure on a single atom
        species = self.species_to_tensor(['H']).to(self.device).unsqueeze(0)
        coordinates = torch.tensor(
            [[[0., 0., 0.]]],
            requires_grad=True, device=self.device)
        self._assertNoIndexError(species, coordinates, 'Failure on lone atom\n')


class TestAEV(_TestAEVBase):
    """Compare computed AEVs against reference values from test_data/ANI1_subset."""

    @staticmethod
    def _loadIsomer(i):
        """Load reference molecule *i* as torch tensors.

        Returns ``(coordinates, species, radial, angular)``. The last two
        pickled fields are not used by these tests.
        """
        datafile = os.path.join(path, 'test_data/ANI1_subset/{}'.format(i))
        # Trusted fixture shipped with the tests; never unpickle untrusted files.
        # The file is closed before any tensors are built or AEVs computed.
        with open(datafile, 'rb') as f:
            coordinates, species, radial, angular, _, _ = pickle.load(f)
        return (torch.from_numpy(coordinates), torch.from_numpy(species),
                torch.from_numpy(radial), torch.from_numpy(angular))

    def testIsomers(self):
        for i in range(N):
            with self.subTest(isomer=i):
                coordinates, species, expected_radial, expected_angular = self._loadIsomer(i)
                _, aev = self.aev_computer((species, coordinates))
                self.assertAEVEqual(expected_radial, expected_angular, aev)

    def testPadding(self):
        species_coordinates = []
        radial_angular = []
        for i in range(N):
            coordinates, species, radial, angular = self._loadIsomer(i)
            species_coordinates.append(torchani.utils.broadcast_first_dim(
                {'species': species, 'coordinates': coordinates}))
            radial_angular.append((radial, angular))
        species_coordinates = torchani.utils.pad_atomic_properties(
            species_coordinates)
        _, aev = self.aev_computer((species_coordinates['species'], species_coordinates['coordinates']))

        # The padded batch is sliced back in load order: each molecule's
        # conformations are consecutive rows, and its real atoms come first
        # (presumably followed by padding) — see the slicing below.
        start = 0
        for expected_radial, expected_angular in radial_angular:
            conformations = expected_radial.shape[0]
            atoms = expected_radial.shape[1]
            aev_ = aev[start:(start + conformations), 0:atoms]
            start += conformations
            self.assertAEVEqual(expected_radial, expected_angular, aev_)


class TestAEVJIT(TestAEV):
    """Re-run every TestAEV check against a TorchScript-compiled AEV computer."""

    def setUp(self):
        super(TestAEVJIT, self).setUp()
        # Swap the eager module for its scripted counterpart.
        scripted_computer = torch.jit.script(self.aev_computer)
        self.aev_computer = scripted_computer


class TestPBCSeeEachOther(unittest.TestCase):
    """Periodic-boundary tests.

    Atoms close to opposite faces, edges or corners of the cell must be
    found as neighbours through periodic images.
    """

    def setUp(self):
        self.ani1x = torchani.models.ANI1x()
        self.aev_computer = self.ani1x.aev_computer.to(torch.double)

    def _assertSingleNeighborPair(self, species, coordinates, cell, allshifts):
        """Assert that the only pair found within a cutoff of 1 is (atom 0, atom 1)."""
        # ``species == -1`` is all False for these inputs (presumably the padding mask).
        atom_index12, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
        atom_index1, atom_index2 = atom_index12.unbind(0)
        self.assertEqual(atom_index1.tolist(), [0])
        self.assertEqual(atom_index2.tolist(), [1])

    def testTranslationalInvariancePBC(self):
        coordinates = torch.tensor(
            [[[0, 0, 0],
              [1, 0, 0],
              [0, 1, 0],
              [0, 0, 1],
              [0, 1, 1]]],
            dtype=torch.double, requires_grad=True)
        cell = torch.eye(3, dtype=torch.double) * 2
        species = torch.tensor([[1, 0, 0, 0, 0]], dtype=torch.long)
        pbc = torch.ones(3, dtype=torch.bool)

        _, aev = self.aev_computer((species, coordinates), cell=cell, pbc=pbc)

        # A dedicated seeded generator makes failures reproducible and leaves
        # the global RNG state untouched.
        generator = torch.Generator().manual_seed(0)
        for _ in range(100):
            translation = torch.randn(3, dtype=torch.double, generator=generator)
            _, aev2 = self.aev_computer((species, coordinates + translation), cell=cell, pbc=pbc)
            self.assertTrue(torch.allclose(aev, aev2))

    def testPBCConnersSeeEachOther(self):
        species = torch.tensor([[0, 0]])
        cell = torch.eye(3, dtype=torch.double) * 10
        pbc = torch.ones(3, dtype=torch.bool)
        allshifts = torchani.aev.compute_shifts(cell, pbc, 1)

        xyz1 = torch.tensor([0.1, 0.1, 0.1])
        xyz2s = [
            torch.tensor([9.9, 0.0, 0.0]),
            torch.tensor([0.0, 9.9, 0.0]),
            torch.tensor([0.0, 0.0, 9.9]),
            torch.tensor([9.9, 9.9, 0.0]),
            torch.tensor([0.0, 9.9, 9.9]),
            torch.tensor([9.9, 0.0, 9.9]),
            torch.tensor([9.9, 9.9, 9.9]),
        ]

        for xyz2 in xyz2s:
            with self.subTest(xyz2=xyz2.tolist()):
                coordinates = torch.stack([xyz1, xyz2]).to(torch.double).unsqueeze(0)
                self._assertSingleNeighborPair(species, coordinates, cell, allshifts)

    def testPBCSurfaceSeeEachOther(self):
        cell = torch.eye(3, dtype=torch.double) * 10
        pbc = torch.ones(3, dtype=torch.bool)
        allshifts = torchani.aev.compute_shifts(cell, pbc, 1)
        species = torch.tensor([[0, 0]])

        for i in range(3):
            with self.subTest(axis=i):
                xyz1 = torch.tensor([5.0, 5.0, 5.0], dtype=torch.double)
                xyz1[i] = 0.1
                xyz2 = xyz1.clone()
                xyz2[i] = 9.9

                coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
                self._assertSingleNeighborPair(species, coordinates, cell, allshifts)

    def testPBCEdgesSeeEachOther(self):
        cell = torch.eye(3, dtype=torch.double) * 10
        pbc = torch.ones(3, dtype=torch.bool)
        allshifts = torchani.aev.compute_shifts(cell, pbc, 1)
        species = torch.tensor([[0, 0]])

        for i, j in itertools.combinations(range(3), 2):
            xyz1 = torch.tensor([5.0, 5.0, 5.0], dtype=torch.double)
            xyz1[i] = 0.1
            xyz1[j] = 0.1
            for new_i, new_j in [[0.1, 9.9], [9.9, 0.1], [9.9, 9.9]]:
                # The check must run for every image combination, not only
                # for whichever xyz2 the loop left behind last.
                with self.subTest(axes=(i, j), xyz2=(new_i, new_j)):
                    xyz2 = xyz1.clone()
                    xyz2[i] = new_i
                    xyz2[j] = new_j

                    coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
                    self._assertSingleNeighborPair(species, coordinates, cell, allshifts)

    def testNonRectangularPBCConnersSeeEachOther(self):
        species = torch.tensor([[0, 0]])
        cell = ase.geometry.cellpar_to_cell([10, 10, 10 * math.sqrt(2), 90, 45, 90])
        cell = torch.tensor(ase.geometry.complete_cell(cell), dtype=torch.double)
        pbc = torch.ones(3, dtype=torch.bool)
        allshifts = torchani.aev.compute_shifts(cell, pbc, 1)

        xyz1 = torch.tensor([0.1, 0.1, 0.05], dtype=torch.double)
        xyz2 = torch.tensor([10.0, 0.1, 0.1], dtype=torch.double)

        coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
        self._assertSingleNeighborPair(species, coordinates, cell, allshifts)


class TestAEVOnBoundary(unittest.TestCase):
    """Check AEV invariance when a molecule is wrapped into the central cell.

    A molecule straddling a cell corner, edge or face must give the same
    AEV as the one placed at the centre of the cell.
    """

    def setUp(self):
        self.eps = 1e-9
        raw_cell = ase.geometry.cellpar_to_cell([100, 100, 100 * math.sqrt(2), 90, 45, 90])
        self.cell = torch.tensor(ase.geometry.complete_cell(raw_cell), dtype=torch.double)
        self.inv_cell = torch.inverse(self.cell)
        self.coordinates = torch.tensor([[[0.0, 0.0, 0.0],
                                          [1.0, -0.1, -0.1],
                                          [-0.1, 1.0, -0.1],
                                          [-0.1, -0.1, 1.0],
                                          [-1.0, -1.0, -1.0]]], dtype=torch.double)
        self.species = torch.tensor([[1, 0, 0, 0, 0]])
        self.pbc = torch.ones(3, dtype=torch.bool)

        # Rows of the cell matrix are the lattice vectors.
        self.v1, self.v2, self.v3 = self.cell
        # Reference geometry: the molecule shifted to the middle of the cell.
        self.center_coordinates = self.coordinates + 0.5 * (self.v1 + self.v2 + self.v3)

        self.aev_computer = torchani.models.ANI1x().aev_computer.to(torch.double)
        _, self.aev = self.aev_computer((self.species, self.center_coordinates), cell=self.cell, pbc=self.pbc)

    def _fractional_in_cell(self, coordinates):
        """Return a mask of fractional components lying in [0, 1] (within eps).

        Also asserts that converting to fractional coordinates and back
        reproduces the Cartesian input.
        """
        fractional = coordinates @ self.inv_cell
        self.assertTrue(torch.allclose(coordinates, fractional @ self.cell))
        above_lower = fractional >= -self.eps
        below_upper = fractional <= 1 + self.eps
        return above_lower & below_upper

    def assertInCell(self, coordinates):
        self.assertTrue(self._fractional_in_cell(coordinates).all())

    def assertNotInCell(self, coordinates):
        self.assertFalse(self._fractional_in_cell(coordinates).all())

    def testCornerSurfaceAndEdge(self):
        for i, j, k in itertools.product([0, 0.5, 1], repeat=3):
            # (0.5, 0.5, 0.5) is the in-cell reference geometry itself.
            if (i, j, k) == (0.5, 0.5, 0.5):
                continue
            shifted = self.coordinates + i * self.v1 + j * self.v2 + k * self.v3
            self.assertNotInCell(shifted)
            wrapped = torchani.utils.map2central(self.cell, shifted, self.pbc)
            self.assertInCell(wrapped)
            _, aev = self.aev_computer((self.species, wrapped), cell=self.cell, pbc=self.pbc)
            self.assertGreater(aev.abs().max().item(), 0)
            self.assertTrue(torch.allclose(aev, self.aev))


class TestAEVOnBenzenePBC(unittest.TestCase):
    """Periodic AEVs of the benzene crystal read from Benzene.cif."""

    def setUp(self):
        ani1x = torchani.models.ANI1x()
        self.aev_computer = ani1x.aev_computer
        filename = os.path.join(path, '../tools/generate-unit-test-expect/others/Benzene.cif')
        benzene = ase.io.read(filename)
        self.cell = torch.tensor(benzene.get_cell(complete=True)).float()
        self.pbc = torch.tensor(benzene.get_pbc(), dtype=torch.bool)
        species_to_tensor = torchani.utils.ChemicalSymbolsToInts(['H', 'C', 'N', 'O'])
        self.species = species_to_tensor(benzene.get_chemical_symbols()).unsqueeze(0)
        self.coordinates = torch.tensor(benzene.get_positions()).unsqueeze(0).float()
        _, self.aev = self.aev_computer((self.species, self.coordinates), cell=self.cell, pbc=self.pbc)
        self.natoms = self.aev.shape[1]

    def testRepeat(self):
        """A supercell along the first lattice vector must reproduce the
        primitive-cell AEVs for every copy of the molecule."""
        tolerance = 5e-6
        repeats = 4
        c1, c2, c3 = self.cell
        species2 = self.species.repeat(1, repeats)
        coordinates2 = torch.cat(
            [self.coordinates + n * c1 for n in range(repeats)], dim=1)
        cell2 = torch.stack([repeats * c1, c2, c3])
        _, aev2 = self.aev_computer((species2, coordinates2), cell=cell2, pbc=self.pbc)
        # Compare all copies, not just the first repeats - 1 of them.
        for n in range(repeats):
            with self.subTest(copy=n):
                aev3 = aev2[:, n * self.natoms: (n + 1) * self.natoms, :]
                self.assertTrue(torch.allclose(self.aev, aev3, atol=tolerance))

    def testManualMirror(self):
        """Explicit images (no PBC) must reproduce the periodic AEVs of the
        central copy."""
        c1, c2, c3 = self.cell
        species2 = self.species.repeat(1, 3 ** 3)
        # product([0, -1, 1], ...) yields (0, 0, 0) first, so the untranslated
        # copy occupies the first natoms slots.
        coordinates2 = torch.cat([
            self.coordinates + i * c1 + j * c2 + k * c3
            for i, j, k in itertools.product([0, -1, 1], repeat=3)
        ], dim=1)
        _, aev2 = self.aev_computer((species2, coordinates2))
        aev2 = aev2[:, :self.natoms, :]
        self.assertTrue(torch.allclose(self.aev, aev2))


if __name__ == '__main__':
    # Run all test cases when this file is executed directly.
    unittest.main()