"""Unit tests for torchani dataset batching, padding and batch splitting."""
import os
import unittest

import torch
import torchani

# Directory containing this test file; the dataset path is resolved
# relative to it so the tests do not depend on the working directory.
path = os.path.dirname(os.path.realpath(__file__))
dataset_path = os.path.join(path, '../dataset')
# Batch size handed to the dataset loader and used as the upper bound
# checked by the shape test.
batch_size = 256
# Built-in torchani constants; the tests read ``consts.species`` from it.
consts = torchani.buildins.consts


class TestData(unittest.TestCase):
    """Tests for dataset batching, padding and batch splitting."""

    def setUp(self):
        """Create the batched dataset iterated by the dataset tests."""
        self.ds = torchani.training.BatchedANIDataset(dataset_path,
                                                      consts.species,
                                                      batch_size)

    def _assertTensorEqual(self, t1, t2):
        """Assert that two tensors have the same shape and equal elements."""
        # Subtraction broadcasts, so without the explicit shape check two
        # tensors of different shapes could compare as equal.
        self.assertEqual(t1.shape, t2.shape)
        self.assertEqual((t1 - t2).abs().max(), 0)

    def testSplitBatch(self):
        """Check that split_batch yields chunks that reassemble the batch."""
        # Three groups of conformations with 4, 8 and 20 atoms each.
        species1 = torch.randint(4, (5, 4), dtype=torch.long)
        coordinates1 = torch.randn(5, 4, 3)
        species2 = torch.randint(4, (2, 8), dtype=torch.long)
        coordinates2 = torch.randn(2, 8, 3)
        species3 = torch.randint(4, (10, 20), dtype=torch.long)
        coordinates3 = torch.randn(10, 20, 3)
        species, coordinates = torchani.utils.pad_and_batch([
            (species1, coordinates1),
            (species2, coordinates2),
            (species3, coordinates3),
        ])
        # Non-negative species entries are counted as atoms; padding is
        # presumably negative, as implied by this mask -- TODO confirm.
        natoms = (species >= 0).to(torch.long).sum(1)
        chunks = torchani.training.data.split_batch(natoms, species,
                                                    coordinates)
        start = 0
        last = None
        for s, c in chunks:
            n = (s >= 0).to(torch.long).sum(1)
            if last is not None:
                # Adjacent chunks must not share the atom count at their
                # boundary.
                self.assertNotEqual(last[-1], n[0])
            conformations = s.shape[0]
            self.assertGreater(conformations, 0)
            # Each chunk must match the corresponding slice of the batch
            # once that slice's redundant padding has been stripped.
            s_ = species[start:start + conformations, ...]
            c_ = coordinates[start:start + conformations, ...]
            s_, c_ = torchani.utils.strip_redundant_padding(s_, c_)
            self._assertTensorEqual(s, s_)
            self._assertTensorEqual(c, c_)
            start += conformations
            # Remember this chunk's atom counts; previously `last` was never
            # updated, so the boundary check above could never execute.
            last = n

        # Re-batching the chunks must reproduce the original padded batch.
        s, c = torchani.utils.pad_and_batch(chunks)
        self._assertTensorEqual(s, species)
        self._assertTensorEqual(c, coordinates)

    def testTensorShape(self):
        """Check the shapes of species, coordinates and energies per batch."""
        for inputs, output in self.ds:
            species, coordinates = torchani.utils.pad_and_batch(inputs)
            energies = output['energies']
            self.assertEqual(len(species.shape), 2)
            self.assertLessEqual(species.shape[0], batch_size)
            self.assertEqual(len(coordinates.shape), 3)
            self.assertEqual(coordinates.shape[2], 3)
            # Coordinates must agree with species on both leading dimensions
            # (these checks used to compare coordinates with themselves).
            self.assertEqual(coordinates.shape[1], species.shape[1])
            self.assertEqual(coordinates.shape[0], species.shape[0])
            self.assertEqual(len(energies.shape), 1)
            self.assertEqual(coordinates.shape[0], energies.shape[0])

    def testNoUnnecessaryPadding(self):
        """Check every chunk has a non-padding entry in its last column."""
        for batch in self.ds:
            for species, _ in batch[0]:
                # If the last column held no non-negative entry at all, the
                # chunk would be wider than it needs to be.
                non_padding = (species >= 0)[:, -1].nonzero()
                self.assertGreater(non_padding.numel(), 0)

# Run the tests in this module when the file is executed directly.
if __name__ == '__main__':
    unittest.main()