# test_data.py (3.58 KB)
1
import os
2
import torch
3
import torchani
4
import unittest
5

6
# Directory containing this test module. The dataset paths below are built
# relative to it so that they do not depend on the current working directory.
path = os.path.dirname(os.path.realpath(__file__))
dataset_path = os.path.join(path, '../dataset/ani1-up_to_gdb4')
dataset_path2 = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')

# Batch size passed to the dataset loader in TestData.setUp; testTensorShape
# uses it as the upper bound on the number of conformations per batch.
batch_size = 256

# Shared model instance; only its constants (species_to_tensor) and AEV
# computer are pulled out for use by the tests.
ani1x = torchani.models.ANI1x()
consts = ani1x.consts
aev_computer = ani1x.aev_computer
13
14
15
16


class TestData(unittest.TestCase):
    """Tests for torchani dataset loading, padding and batch splitting."""

    def setUp(self):
        """Load the dataset at ``dataset_path`` in batches of ``batch_size``."""
        self.ds = torchani.data.load_ani_dataset(dataset_path,
                                                 consts.species_to_tensor,
                                                 batch_size)

    def _assertTensorEqual(self, t1, t2):
        """Assert that two tensors have the same shape and agree within 1e-6.

        Args:
            t1: First tensor.
            t2: Second tensor.
        """
        # Compare shapes first: the subtraction below broadcasts, so without
        # this check a (1, N) tensor could be accepted as equal to (M, N).
        self.assertEqual(t1.shape, t2.shape)
        if t1.numel() == 0:
            # max() raises on an empty tensor; equally shaped empty tensors
            # have no elements that could differ.
            return
        self.assertLess((t1 - t2).abs().max().item(), 1e-6)

    def testSplitBatch(self):
        """Check that ``split_batch`` chunks reproduce the padded batch.

        Three groups of random conformations with 4, 8 and 20 atoms are padded
        into one batch and split again. Each chunk must be non-empty, must
        equal the corresponding slice of the padded batch once redundant
        padding is stripped, and re-padding all chunks must give back the
        original batch.
        """
        species1 = torch.randint(4, (5, 4), dtype=torch.long)
        coordinates1 = torch.randn(5, 4, 3)
        species2 = torch.randint(4, (2, 8), dtype=torch.long)
        coordinates2 = torch.randn(2, 8, 3)
        species3 = torch.randint(4, (10, 20), dtype=torch.long)
        coordinates3 = torch.randn(10, 20, 3)
        species_coordinates = torchani.utils.pad_atomic_properties([
            {'species': species1, 'coordinates': coordinates1},
            {'species': species2, 'coordinates': coordinates2},
            {'species': species3, 'coordinates': coordinates3},
        ])
        species = species_coordinates['species']
        coordinates = species_coordinates['coordinates']
        # Number of entries with a non-negative species index per conformation
        # (negative entries are presumably padding -- confirm against
        # pad_atomic_properties).
        natoms = (species >= 0).to(torch.long).sum(1)
        chunks = torchani.data.split_batch(natoms, species_coordinates)

        start = 0
        last = None  # per-conformation atom counts of the previous chunk
        for chunk in chunks:
            s = chunk['species']
            c = chunk['coordinates']
            n = (s >= 0).to(torch.long).sum(1)
            if last is not None:
                # Neighbouring chunks are expected to differ in atom count at
                # their boundary. NOTE(review): this check was previously dead
                # because ``last`` was never updated.
                self.assertNotEqual(last[-1].item(), n[0].item())
            conformations = s.shape[0]
            self.assertGreater(conformations, 0)

            # The chunk must match the same rows of the padded batch after
            # columns that are padding for every row have been stripped.
            s_ = species[start:(start + conformations), ...]
            c_ = coordinates[start:(start + conformations), ...]
            sc = torchani.utils.strip_redundant_padding({'species': s_, 'coordinates': c_})
            s_ = sc['species']
            c_ = sc['coordinates']
            self._assertTensorEqual(s, s_)
            self._assertTensorEqual(c, c_)

            start += conformations
            last = n

        # Padding the chunks back together must reproduce the original batch.
        sc = torchani.utils.pad_atomic_properties(chunks)
        s = sc['species']
        c = sc['coordinates']
        self._assertTensorEqual(s, species)
        self._assertTensorEqual(c, coordinates)

    def testTensorShape(self):
        """Check rank and size consistency of every batch in the dataset."""
        for i in self.ds:
            input_, output = i
            # Each input entry is indexed as (species, coordinates); re-pad
            # them into a single pair of tensors for the whole batch.
            input_ = [{'species': x[0], 'coordinates': x[1]} for x in input_]
            species_coordinates = torchani.utils.pad_atomic_properties(input_)
            species = species_coordinates['species']
            coordinates = species_coordinates['coordinates']
            energies = output['energies']
            self.assertEqual(len(species.shape), 2)
            self.assertLessEqual(species.shape[0], batch_size)
            self.assertEqual(len(coordinates.shape), 3)
            self.assertEqual(coordinates.shape[2], 3)
            self.assertEqual(coordinates.shape[:2], species.shape[:2])
            self.assertEqual(len(energies.shape), 1)
            self.assertEqual(coordinates.shape[0], energies.shape[0])

    def testNoUnnecessaryPadding(self):
        """Check that the last species column of each chunk is not all padding."""
        for i in self.ds:
            for input_ in i[0]:
                species, _ = input_
                # At least one row must have a non-negative entry in the last
                # column; otherwise that column would be removable padding.
                non_padding = (species >= 0)[:, -1].nonzero()
                self.assertGreater(non_padding.numel(), 0)
88

89
90
91

# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()