# test_data.py — unit tests for the torchani.data loading pipeline.
import os
2
import torch
3
import torchani
4
import unittest
5

6
# Directory containing this test file; the dataset path is resolved relative to it.
path = os.path.dirname(os.path.realpath(__file__))
# ANI-1x sample HDF5 file used by every test in this module.
dataset_path = os.path.join(path, '../dataset/ani-1x/sample.h5')
# Maximum number of molecules per collated batch (see ``collate(batch_size)``).
batch_size = 256
# Built once at import time and shared by all tests; the model provides the
# self-energy dictionary passed to ``subtract_self_energies``.
ani1x = torchani.models.ANI1x()
sae_dict = ani1x.sae_dict
# NOTE(review): not used by the tests in this file; kept in case other
# modules import it from here.
aev_computer = ani1x.aev_computer


class TestData(unittest.TestCase):
    """Tests for the ``torchani.data`` pipeline on the ANI-1x sample dataset.

    Relies on the module-level ``dataset_path``, ``batch_size`` and
    ``sae_dict`` defined at the top of this file.
    """

    def _full_pipeline(self):
        """Return load -> subtract SAE -> species_to_indices -> shuffle -> collate -> cache."""
        return (torchani.data.load(dataset_path)
                .subtract_self_energies(sae_dict)
                .species_to_indices()
                .shuffle()
                .collate(batch_size)
                .cache())

    def _assert_reenterable(self, ds):
        """Assert that ``ds`` can be iterated twice, yielding items on both passes.

        Each pass is run to exhaustion (not short-circuited), so an iterable
        that only works once would yield nothing on the second pass.
        """
        first = sum(1 for _ in ds)
        second = sum(1 for _ in ds)
        self.assertGreater(first, 0)
        self.assertGreater(second, 0)
        self.assertEqual(first, second)

    def testTensorShape(self):
        """Every collated batch has consistent species/coordinates/energies shapes."""
        num_batches = 0
        for d in self._full_pipeline():
            num_batches += 1
            species = d['species']
            coordinates = d['coordinates']
            energies = d['energies']
            # species is rank 2 with at most batch_size rows.
            self.assertEqual(len(species.shape), 2)
            self.assertLessEqual(species.shape[0], batch_size)
            # coordinates is rank 3, last dim 3, leading dims aligned with species.
            self.assertEqual(len(coordinates.shape), 3)
            self.assertEqual(coordinates.shape[2], 3)
            self.assertEqual(coordinates.shape[:2], species.shape[:2])
            # energies is rank 1, one entry per row of the batch.
            self.assertEqual(len(energies.shape), 1)
            self.assertEqual(coordinates.shape[0], energies.shape[0])
        # Without this, the loop above would pass vacuously on an empty pipeline.
        self.assertGreater(num_batches, 0)

    def testNoUnnecessaryPadding(self):
        """The last column of each batch's ``species`` holds at least one real entry.

        Entries with ``species >= 0`` are treated as non-padding; if the final
        column were all padding, the batch would be padded wider than needed.
        """
        for d in self._full_pipeline():
            species = d['species']
            non_padding = (species >= 0)[:, -1].nonzero()
            self.assertGreater(non_padding.numel(), 0)

    def testReEnter(self):
        """Make sure every stage of the pipeline can be iterated multiple times."""
        ds = torchani.data.load(dataset_path)
        self._assert_reenterable(ds)

        ds = ds.subtract_self_energies(sae_dict)
        self._assert_reenterable(ds)

        ds = ds.species_to_indices()
        self._assert_reenterable(ds)

        ds = ds.shuffle()
        self._assert_reenterable(ds)

        ds = ds.collate(batch_size)
        self._assert_reenterable(ds)

        ds = ds.cache()
        self._assert_reenterable(ds)

    def testShapeInference(self):
        """``len()`` must be available at each stage without iterating the data."""
        shifter = torchani.EnergyShifter(None)
        ds = torchani.data.load(dataset_path).subtract_self_energies(shifter)
        len(ds)
        ds = ds.species_to_indices()
        len(ds)
        ds = ds.shuffle()
        len(ds)
        ds = ds.collate(batch_size)
        len(ds)

    def testSAE(self):
        """Self energies fitted into ``EnergyShifter(None)`` match reference values."""
        tolerance = 1e-5
        shifter = torchani.EnergyShifter(None)
        # Called for its side effect on ``shifter``, whose ``self_energies``
        # are checked below; the returned dataset is discarded.
        torchani.data.load(dataset_path).subtract_self_energies(shifter)
        true_self_energies = torch.tensor([-19.354171758844188,
                                           -19.354171758844046,
                                           -54.712238523648587,
                                           -75.162829556770987], dtype=torch.float64)
        # Guard against broadcasting hiding a result of the wrong length.
        self.assertEqual(tuple(shifter.self_energies.shape),
                         tuple(true_self_energies.shape))
        diff = torch.abs(true_self_energies - shifter.self_energies)
        self.assertLess(diff.max().item(), tolerance)

    def testDataloader(self):
        """Pre-processed samples can be batched by a multi-worker ``DataLoader``."""
        shifter = torchani.EnergyShifter(None)
        dataset = list(torchani.data.load(dataset_path)
                       .subtract_self_energies(shifter)
                       .species_to_indices()
                       .shuffle())
        # NOTE(review): 64 worker processes is heavy for CI machines; consider
        # lowering it if this test is slow or runs out of resources.
        loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                             collate_fn=torchani.data.collate_fn,
                                             num_workers=64)
        num_batches = sum(1 for _ in loader)
        self.assertGreater(num_batches, 0)

# Run the test suite when this file is executed directly.
if __name__ == '__main__':
    unittest.main()