test_data.py 5.25 KB
Newer Older
1
import os
2
import torch
3
import torchani
4
import unittest
5
from torchani.data.cache_aev import cache_aev, cache_sparse_aev
6

7
# Directory containing this test file; dataset paths are resolved relative to it.
path = os.path.dirname(os.path.realpath(__file__))

# Full ANI-1 subset (molecules up to GDB-4) used by BatchedANIDataset tests,
# and a single .h5 shard used by the (smaller) AEV-cache tests.
dataset_path = os.path.join(path, '../dataset/ani1-up_to_gdb4')
dataset_path2 = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')

# Number of conformations per batch when loading the dataset.
batch_size = 256

# Built-in NeuroChem model resources: `consts` supplies species_to_tensor
# for dataset loading, `aev_computer` recomputes AEVs for cache validation.
builtins = torchani.neurochem.Builtins()
consts = builtins.consts
aev_computer = builtins.aev_computer


class TestData(unittest.TestCase):
    """Tests for torchani.data: batch splitting/padding round-trips,
    batched dataset tensor shapes, and (sparse) AEV cache loaders."""

    def setUp(self):
        # Load the GDB-4 subset, batched; species strings are converted to
        # integer tensors via the built-in constants.
        self.ds = torchani.data.BatchedANIDataset(dataset_path,
                                                  consts.species_to_tensor,
                                                  batch_size)

    def _assertTensorEqual(self, t1, t2):
        # Element-wise near-equality: max absolute difference below 1e-6.
        self.assertLess((t1 - t2).abs().max().item(), 1e-6)

    def testSplitBatch(self):
        """split_batch must partition a padded batch into contiguous chunks
        that, after stripping redundant padding, match the original slices,
        and re-padding the chunks must reconstruct the original batch."""
        species1 = torch.randint(4, (5, 4), dtype=torch.long)
        coordinates1 = torch.randn(5, 4, 3)
        species2 = torch.randint(4, (2, 8), dtype=torch.long)
        coordinates2 = torch.randn(2, 8, 3)
        species3 = torch.randint(4, (10, 20), dtype=torch.long)
        coordinates3 = torch.randn(10, 20, 3)
        species_coordinates = torchani.utils.pad_atomic_properties([
            {'species': species1, 'coordinates': coordinates1},
            {'species': species2, 'coordinates': coordinates2},
            {'species': species3, 'coordinates': coordinates3},
        ])
        species = species_coordinates['species']
        coordinates = species_coordinates['coordinates']
        # Padding atoms are encoded as species < 0, so this counts real atoms.
        natoms = (species >= 0).to(torch.long).sum(1)
        chunks = torchani.data.split_batch(natoms, species_coordinates)

        start = 0
        last = None
        for chunk in chunks:
            s = chunk['species']
            c = chunk['coordinates']
            n = (s >= 0).to(torch.long).sum(1)
            # Adjacent chunks must not share an atom count, otherwise they
            # should have been merged into one chunk.
            if last is not None:
                self.assertNotEqual(last[-1], n[0])
            conformations = s.shape[0]
            self.assertGreater(conformations, 0)
            # The chunk must equal the corresponding slice of the original
            # batch once redundant padding is stripped from the slice.
            s_ = species[start:(start + conformations), ...]
            c_ = coordinates[start:(start + conformations), ...]
            sc = torchani.utils.strip_redundant_padding({'species': s_, 'coordinates': c_})
            s_ = sc['species']
            c_ = sc['coordinates']
            self._assertTensorEqual(s, s_)
            self._assertTensorEqual(c, c_)
            start += conformations

        # Re-padding all chunks must reproduce the original padded batch.
        sc = torchani.utils.pad_atomic_properties(chunks)
        s = sc['species']
        c = sc['coordinates']
        self._assertTensorEqual(s, species)
        self._assertTensorEqual(c, coordinates)

    def testTensorShape(self):
        """Every batch must yield consistently shaped species, coordinates
        and energies tensors."""
        for i in self.ds:
            input_, output = i
            input_ = [{'species': x[0], 'coordinates': x[1]} for x in input_]
            species_coordinates = torchani.utils.pad_atomic_properties(input_)
            species = species_coordinates['species']
            coordinates = species_coordinates['coordinates']
            energies = output['energies']
            self.assertEqual(len(species.shape), 2)
            self.assertLessEqual(species.shape[0], batch_size)
            self.assertEqual(len(coordinates.shape), 3)
            self.assertEqual(coordinates.shape[2], 3)
            # BUGFIX: these previously compared coordinates' own dims to
            # themselves (always true); species and coordinates must agree
            # on both the atoms and conformations dimensions.
            self.assertEqual(species.shape[1], coordinates.shape[1])
            self.assertEqual(species.shape[0], coordinates.shape[0])
            self.assertEqual(len(energies.shape), 1)
            self.assertEqual(coordinates.shape[0], energies.shape[0])

    def testNoUnnecessaryPadding(self):
        """In every chunk, the last atom column must contain at least one
        real (non-padding) atom — otherwise the column is wasted padding."""
        for i in self.ds:
            for input_ in i[0]:
                species, _ = input_
                non_padding = (species >= 0)[:, -1].nonzero()
                self.assertGreater(non_padding.numel(), 0)

    def _checkCacheLoader(self, cache_fn, loader_cls):
        """Cache AEVs with `cache_fn`, reload them with `loader_cls`, and
        verify the cached AEVs match freshly computed ones over three
        iterations (catching iterator-reuse bugs)."""
        tmpdir = os.path.join(os.getcwd(), 'tmp')
        # exist_ok avoids the exists()/makedirs() race of the original code.
        os.makedirs(tmpdir, exist_ok=True)
        cache_fn(tmpdir, dataset_path2, 64, enable_tqdm=False)
        loader = loader_cls(tmpdir)
        ds = loader.dataset
        aev_computer_dev = aev_computer.to(loader.dataset.device)
        for _ in range(3):
            for (species_aevs, _), (species_coordinates, _) in zip(loader, ds):
                for (s1, a), (s2, c) in zip(species_aevs, species_coordinates):
                    self._assertTensorEqual(s1, s2)
                    s2, a2 = aev_computer_dev((s2, c))
                    self._assertTensorEqual(s1, s2)
                    self._assertTensorEqual(a, a2)

    def testAEVCacheLoader(self):
        self._checkCacheLoader(cache_aev, torchani.data.AEVCacheLoader)

    def testSparseAEVCacheLoader(self):
        self._checkCacheLoader(cache_sparse_aev, torchani.data.SparseAEVCacheLoader)

123
124
125

# Allow running this test module directly: `python test_data.py`.
if __name__ == '__main__':
    unittest.main()