test_data.py 4.13 KB
Newer Older
1
import os
2
import torch
3
import torchani
4
import unittest
Jinze Xue's avatar
Jinze Xue committed
5
from torchani.testing import TestCase
6

7
# Directory that contains this test module; the sample data is located
# relative to it so the tests do not depend on the working directory.
path = os.path.dirname(os.path.realpath(__file__))
dataset_path = os.path.join(path, '../dataset/ani-1x/sample.h5')

# Batch size handed to ``collate`` / ``DataLoader`` in the tests below.
batch_size = 256

# Per-element energies passed to ``subtract_self_energies``.
# NOTE(review): presumably the ANI-1x self atomic energies in Hartree --
# confirm against the model's parameter files.
ani1x_sae_dict = {
    'H': -0.60095298,
    'C': -38.08316124,
    'N': -54.7077577,
    'O': -75.19446356,
}
11
12


Jinze Xue's avatar
Jinze Xue committed
13
class TestData(TestCase):
14
15

    def testTensorShape(self):
        """Check the rank and mutual consistency of the tensors in each batch.

        The sample dataset is run through the full pipeline (self-energy
        subtraction, species indexing, shuffling, collation with
        ``batch_size``, caching).  For every batch the test asserts that
        ``species`` is 2-D with at most ``batch_size`` rows, ``coordinates``
        is 3-D with a last dimension of 3 and the same leading two dimensions
        as ``species``, and ``energies`` is 1-D with one entry per row.
        """
        pipeline = torchani.data.load(dataset_path)
        pipeline = pipeline.subtract_self_energies(ani1x_sae_dict)
        pipeline = pipeline.species_to_indices().shuffle()
        batches = pipeline.collate(batch_size).cache()

        for batch in batches:
            species_shape = batch['species'].shape
            coordinates_shape = batch['coordinates'].shape
            energies_shape = batch['energies'].shape

            # species: one row per batch entry, never more than batch_size rows
            self.assertEqual(len(species_shape), 2)
            self.assertLessEqual(species_shape[0], batch_size)

            # coordinates: same leading dimensions as species, plus x/y/z
            self.assertEqual(len(coordinates_shape), 3)
            self.assertEqual(coordinates_shape[2], 3)
            self.assertEqual(coordinates_shape[:2], species_shape[:2])

            # energies: one scalar per batch entry
            self.assertEqual(len(energies_shape), 1)
            self.assertEqual(coordinates_shape[0], energies_shape[0])
28

29
    def testNoUnnecessaryPadding(self):
        """Check that no batch is padded wider than it needs to be.

        For every collated batch, at least one row must have a non-negative
        value in the last ``species`` column.
        NOTE(review): negative species indices are presumably the padding
        value used by ``collate`` -- confirm against ``torchani.data``.
        """
        pipeline = torchani.data.load(dataset_path)
        pipeline = pipeline.subtract_self_energies(ani1x_sae_dict)
        pipeline = pipeline.species_to_indices().shuffle()
        batches = pipeline.collate(batch_size).cache()

        for batch in batches:
            last_column_filled = batch['species'][:, -1] >= 0
            # If every row were padded in the last column, the batch could
            # have been one column narrower.
            filled_rows = last_column_filled.nonzero()
            self.assertGreater(filled_rows.numel(), 0)
35

36
37
38
    def testReEnter(self):
        """Check that a dataset can be iterated multiple times at every stage.

        Starting from the freshly loaded sample dataset, each transformation
        (``subtract_self_energies``, ``species_to_indices``, ``shuffle``,
        ``collate``, ``cache``) is applied in turn.  After each step the
        result is fully iterated twice, and both passes must yield at least
        one item, so an object that is exhausted by its first iteration
        fails the test.
        """

        def assert_iterable_twice(dataset, stage):
            # Each pass consumes the dataset completely; the second pass is
            # the one that catches single-use iterators.
            for attempt in (1, 2):
                yielded = sum(1 for _ in dataset)
                self.assertGreater(
                    yielded, 0,
                    '{}: pass {} yielded no items'.format(stage, attempt))

        # make sure that a dataset can be iterated multiple times
        ds = torchani.data.load(dataset_path)
        assert_iterable_twice(ds, 'load')

        ds = ds.subtract_self_energies(ani1x_sae_dict)
        assert_iterable_twice(ds, 'subtract_self_energies')

        ds = ds.species_to_indices()
        assert_iterable_twice(ds, 'species_to_indices')

        ds = ds.shuffle()
        assert_iterable_twice(ds, 'shuffle')

        ds = ds.collate(batch_size)
        assert_iterable_twice(ds, 'collate')

        ds = ds.cache()
        assert_iterable_twice(ds, 'cache')

    def testShapeInference(self):
        """Check that ``len()`` works at each stage of the pipeline.

        No particular length is asserted; the test fails only if one of the
        ``len()`` calls raises.  The self energies are supplied through an
        ``EnergyShifter`` constructed with ``None`` rather than a dict.
        """
        sae_fitter = torchani.EnergyShifter(None)
        stage = torchani.data.load(dataset_path).subtract_self_energies(sae_fitter)
        len(stage)

        # Apply the remaining transformations one at a time, asking for the
        # length after each of them.
        remaining_steps = (
            ('species_to_indices', ()),
            ('shuffle', ()),
            ('collate', (batch_size,)),
        )
        for method_name, arguments in remaining_steps:
            stage = getattr(stage, method_name)(*arguments)
            len(stage)

110
111
112
113
114
115
116
    def testSAE(self):
        """Check the self energies obtained from the sample dataset.

        An ``EnergyShifter`` constructed with ``None`` is passed to
        ``subtract_self_energies``; afterwards its ``self_energies``
        attribute is compared against stored reference values.
        """
        # NOTE(review): the first two reference values differ only around
        # the 13th decimal place -- confirm this is expected for the sample
        # file and not a copy/paste slip.
        expected_self_energies = torch.tensor(
            [
                -19.354171758844188,
                -19.354171758844046,
                -54.712238523648587,
                -75.162829556770987,
            ],
            dtype=torch.float64,
        )

        sae_fitter = torchani.EnergyShifter(None)
        # The returned dataset is discarded: only the effect of the call on
        # ``sae_fitter`` is under test.
        torchani.data.load(dataset_path).subtract_self_energies(sae_fitter)

        self.assertEqual(expected_self_energies, sae_fitter.self_energies)
118

119
120
121
122
    def testDataloader(self):
        """Iterate the sample data through a multi-worker ``DataLoader``.

        The dataset is materialised into a list, wrapped in a
        ``torch.utils.data.DataLoader`` that uses ``torchani.data.collate_fn``,
        and consumed once.  Nothing is asserted about the batches; the test
        fails only if loading or collation raises.
        """
        sae_fitter = torchani.EnergyShifter(None)
        pipeline = torchani.data.load(dataset_path).subtract_self_energies(sae_fitter)
        # The pipeline output is turned into a list before it is handed to
        # the DataLoader.
        samples = list(pipeline.species_to_indices().shuffle())

        # NOTE(review): 64 worker processes is a lot for a unit test and may
        # exceed the core count of CI machines -- confirm it is intentional.
        loader = torch.utils.data.DataLoader(
            samples,
            batch_size=batch_size,
            collate_fn=torchani.data.collate_fn,
            num_workers=64,
        )
        for _ in loader:
            pass

126
127
128

# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()