import math
import unittest

import torch
import torchani
from torchani.testing import TestCase


class TestALAtomic(TestCase):
    """Per-atom energy tests for the ANI-1x ensemble.

    The fixture is a methane-like molecule: four hydrogens at alternate
    corners of a unit cube and a carbon at the cube centre. Every hydrogen
    therefore sees an identical environment and must get the same atomic
    energy.
    """

    def setUp(self):
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        # periodic_table_index=True: species below are atomic numbers
        # (1 = H, 6 = C).
        self.model = torchani.models.ANI1x(periodic_table_index=True).to(
            self.device).double()
        self.converter = torchani.nn.SpeciesConverter(['H', 'C', 'N', 'O'])
        self.aev_computer = self.model.aev_computer
        self.ani_model = self.model.neural_networks
        # Single ensemble member; subclasses compare against it.
        self.first_model = self.model[0]
        # fully symmetric methane, shape (1 molecule, 5 atoms, 3)
        self.coordinates = torch.tensor(
            [[0, 0, 0], [0, 1, 1], [1, 0, 1], [1, 1, 0], [0.5, 0.5, 0.5]],
            dtype=torch.double,
            device=self.device).unsqueeze(0)
        self.species = torch.tensor([[1, 1, 1, 1, 6]],
                                    dtype=torch.long,
                                    device=self.device)

    def testAverageAtomicEnergies(self):
        """Ensemble-averaged atomic energies: one value per atom."""
        _, energies = self.model.atomic_energies(
            (self.species, self.coordinates))
        self.assertEqual(energies.shape, self.coordinates.shape[:-1])
        # Every hydrogen (all atoms except the trailing carbon) must match
        # the same reference averaged hydrogen energy, hence be equal.
        expected = torch.tensor(-0.54853380570289400620,
                                dtype=torch.double,
                                device=self.device)
        self.assertTrue(torch.isclose(energies[:, :-1], expected).all())

    def testAtomicEnergies(self):
        """Per-member atomic energies (average=False)."""
        _, energies = self.model.atomic_energies(
            (self.species, self.coordinates), average=False)
        # Leading axis indexes ensemble members; the rest is (molecules,
        # atoms), matching the coordinates without their xyz axis.
        self.assertEqual(energies.shape[1:], self.coordinates.shape[:-1])
        self.assertEqual(energies.shape[0], len(self.model.neural_networks))
        # Reference energy of the first hydrogen under the first member.
        self.assertTrue(
            torch.isclose(
                energies[0, 0, 0],
                torch.tensor(-0.54562734428531045605,
                             device=self.device,
                             dtype=torch.double)))
        # Within each member, all hydrogens share one energy. Slice column 0
        # as ``[:, :1]`` so it stays 2-D and broadcasts per molecule, not
        # along the atom axis (``[:, 0]`` is only correct for batch size 1).
        for member_energies in energies:
            self.assertTrue(
                (member_energies[:, :-1] == member_energies[:, :1]).all())


class TestALQBC(TestALAtomic):
    """Ensemble-member energy and query-by-committee (QBC) tests.

    Inherits the methane fixture from ``TestALAtomic.setUp``. Because it
    subclasses ``TestALAtomic``, the inherited atomic-energy tests also run
    under this class.
    """

    def testMemberEnergies(self):
        """Each ensemble member's molecular energy, stacked on axis 0."""
        # fully symmetric methane
        _, energies = self.model.members_energies(
            (self.species, self.coordinates))

        # correctness of shape: (ensemble members, molecules)
        self.assertEqual(energies.shape[-1], self.coordinates.shape[0])
        self.assertEqual(energies.shape[0], len(self.model.neural_networks))
        # Member 0 must agree with evaluating that member on its own.
        self.assertEqual(
            energies[0], self.first_model((self.species,
                                           self.coordinates)).energies)
        self.assertTrue(
            torch.isclose(
                energies[0],
                torch.tensor(-40.277153758433975,
                             dtype=torch.double,
                             device=self.device)).all())

    def testQBC(self):
        """QBC equals the members' unbiased std divided by sqrt(#atoms)."""
        # fully symmetric methane
        _, _, qbc = self.model.energies_qbcs((self.species, self.coordinates))
        std = self.model.members_energies(
            (self.species, self.coordinates)).energies.std(dim=0,
                                                           unbiased=True)
        self.assertTrue(
            torch.isclose(std / math.sqrt(self.coordinates.shape[1]),
                          qbc).all())

        # Also test a batch of two molecules. The second one has a padding
        # atom (species -1), so only 4 of its 5 atoms count. Its coordinates
        # are random but seeded so the test is reproducible.
        generator = torch.Generator().manual_seed(0)
        random_coordinates = torch.randn(1, 5, 3,
                                         dtype=torch.double,
                                         generator=generator).to(self.device)
        coordinates = torch.cat((self.coordinates, random_coordinates), dim=0)
        species = torch.tensor([[1, 1, 1, 1, 6], [-1, 1, 1, 1, 1]],
                               dtype=torch.long,
                               device=self.device)
        std = self.model.members_energies(
            (species, coordinates)).energies.std(dim=0, unbiased=True)
        _, _, qbc = self.model.energies_qbcs((species, coordinates))
        # Normalize by the number of real (non-padding) atoms per molecule.
        num_atoms = (species >= 0).sum(dim=1).to(std.dtype)
        self.assertEqual(std / num_atoms.sqrt(), qbc)


if __name__ == '__main__':
    # Run this module's tests when the file is executed as a script.
    unittest.main()