test_data.py 3.87 KB
Newer Older
1
import os
2
import torch
3
import torchani
4
import unittest
5
from torchani.data.cache_aev import cache_aev
6

7
# Directory that holds this test module; dataset locations are resolved
# relative to it.
path = os.path.dirname(os.path.realpath(__file__))

# The dataset directory and a single HDF5 file inside that directory.
dataset_path, dataset_path2 = (
    os.path.join(path, relative)
    for relative in (
        '../dataset/ani1-up_to_gdb4',
        '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5',
    )
)

# Batch size handed to the batched dataset; the tests also use it as the
# upper bound for the number of conformations in one batch.
batch_size = 256

# Shared torchani builtins, created once at import time for all tests.
builtins = torchani.neurochem.Builtins()
consts, aev_computer = builtins.consts, builtins.aev_computer
14
15
16
17


class TestData(unittest.TestCase):

18
    def setUp(self):
19
20
21
        self.ds = torchani.data.BatchedANIDataset(dataset_path,
                                                  consts.species_to_tensor,
                                                  batch_size)
22

23
    def _assertTensorEqual(self, t1, t2):
24
        self.assertLess((t1 - t2).abs().max().item(), 1e-6)
25
26
27
28
29
30
31
32

    def testSplitBatch(self):
        species1 = torch.randint(4, (5, 4), dtype=torch.long)
        coordinates1 = torch.randn(5, 4, 3)
        species2 = torch.randint(4, (2, 8), dtype=torch.long)
        coordinates2 = torch.randn(2, 8, 3)
        species3 = torch.randint(4, (10, 20), dtype=torch.long)
        coordinates3 = torch.randn(10, 20, 3)
33
        species, coordinates = torchani.utils.pad_coordinates([
34
35
36
37
38
            (species1, coordinates1),
            (species2, coordinates2),
            (species3, coordinates3),
        ])
        natoms = (species >= 0).to(torch.long).sum(1)
39
        chunks = torchani.data.split_batch(natoms, species, coordinates)
40
41
42
43
44
45
46
47
        start = 0
        last = None
        for s, c in chunks:
            n = (s >= 0).to(torch.long).sum(1)
            if last is not None:
                self.assertNotEqual(last[-1], n[0])
            conformations = s.shape[0]
            self.assertGreater(conformations, 0)
48
49
            s_ = species[start:(start + conformations), ...]
            c_ = coordinates[start:(start + conformations), ...]
Gao, Xiang's avatar
Gao, Xiang committed
50
            s_, c_ = torchani.utils.strip_redundant_padding(s_, c_)
51
52
53
54
            self._assertTensorEqual(s, s_)
            self._assertTensorEqual(c, c_)
            start += conformations

55
        s, c = torchani.utils.pad_coordinates(chunks)
56
57
58
        self._assertTensorEqual(s, species)
        self._assertTensorEqual(c, coordinates)

59
    def testTensorShape(self):
60
        for i in self.ds:
Gao, Xiang's avatar
Gao, Xiang committed
61
62
            input_, output = i
            species, coordinates = torchani.utils.pad_coordinates(input_)
63
64
65
66
67
68
69
70
71
            energies = output['energies']
            self.assertEqual(len(species.shape), 2)
            self.assertLessEqual(species.shape[0], batch_size)
            self.assertEqual(len(coordinates.shape), 3)
            self.assertEqual(coordinates.shape[2], 3)
            self.assertEqual(coordinates.shape[1], coordinates.shape[1])
            self.assertEqual(coordinates.shape[0], coordinates.shape[0])
            self.assertEqual(len(energies.shape), 1)
            self.assertEqual(coordinates.shape[0], energies.shape[0])
72

73
74
    def testNoUnnecessaryPadding(self):
        for i in self.ds:
Gao, Xiang's avatar
Gao, Xiang committed
75
76
            for input_ in i[0]:
                species, _ = input_
77
78
                non_padding = (species >= 0)[:, -1].nonzero()
                self.assertGreater(non_padding.numel(), 0)
79

80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
    def testAEVCacheLoader(self):
        tmpdir = os.path.join(os.getcwd(), 'tmp')
        if not os.path.exists(tmpdir):
            os.makedirs(tmpdir)
        cache_aev(tmpdir, dataset_path2, 64, enable_tqdm=False)
        loader = torchani.data.AEVCacheLoader(tmpdir)
        ds = loader.dataset
        aev_computer_dev = aev_computer.to(loader.dataset.device)
        for _ in range(3):
            for (species_aevs, _), (species_coordinates, _) in zip(loader, ds):
                for (s1, a), (s2, c) in zip(species_aevs, species_coordinates):
                    self._assertTensorEqual(s1, s2)
                    s2, a2 = aev_computer_dev((s2, c))
                    self._assertTensorEqual(s1, s2)
                    self._assertTensorEqual(a, a2)

96
97
98

# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()