"tests/compute/test_traversal.py" did not exist on "6d96a97f179ba8cdb27b40f5bf4cd42d8fa6c4b6"
test_data.py 1.41 KB
Newer Older
1
import os
2
import torchani
3
import unittest
4

5
path = os.path.dirname(os.path.realpath(__file__))
6
dataset_path = os.path.join(path, 'dataset/ani-1x/sample.h5')
7
batch_size = 256
8
9
ani1x = torchani.models.ANI1x()
consts = ani1x.consts
10
sae_dict = ani1x.sae_dict
11
aev_computer = ani1x.aev_computer
12
13
14
15
16


class TestData(unittest.TestCase):
    """Sanity checks on batches produced by the ``torchani.data`` pipeline."""

    @staticmethod
    def _load_batches():
        """Build the batched dataset pipeline shared by the tests in this class.

        Loads ``dataset_path``, subtracts the self energies in ``sae_dict``,
        converts species to indices, shuffles, collates with ``batch_size``
        and caches the result. A fresh pipeline is built on every call, so
        the tests never share iteration state.
        """
        return (
            torchani.data.load(dataset_path)
            .subtract_self_energies(sae_dict)
            .species_to_indices()
            .shuffle()
            .collate(batch_size)
            .cache()
        )

    def testTensorShape(self):
        """Each batch has mutually consistent species/coordinates/energies shapes."""
        num_batches = 0
        for batch in self._load_batches():
            num_batches += 1
            species = batch['species']
            coordinates = batch['coordinates']
            energies = batch['energies']

            # species is 2-D and its leading (batch) dim never exceeds
            # batch_size; presumably (molecules, atoms) -- confirm upstream.
            self.assertEqual(len(species.shape), 2)
            self.assertLessEqual(species.shape[0], batch_size)
            # coordinates is 3-D with a trailing axis of size 3 (presumably
            # x/y/z) and shares its first two dims with species.
            self.assertEqual(len(coordinates.shape), 3)
            self.assertEqual(coordinates.shape[2], 3)
            self.assertEqual(coordinates.shape[:2], species.shape[:2])
            # energies is 1-D with one value per batch entry.
            self.assertEqual(len(energies.shape), 1)
            self.assertEqual(coordinates.shape[0], energies.shape[0])
        # Without this the loop would pass vacuously on an empty dataset.
        self.assertGreater(num_batches, 0)

    def testNoUnnecessaryPadding(self):
        """The last column of every species batch holds at least one real entry.

        Negative species indices are treated as padding (the ``>= 0`` mask
        below). If the final column were padding in every row, the batch would
        be padded wider than its widest entry needs.
        """
        num_batches = 0
        for batch in self._load_batches():
            num_batches += 1
            last_column = batch['species'][:, -1]
            self.assertTrue(bool((last_column >= 0).any()))
        # Without this the loop would pass vacuously on an empty dataset.
        self.assertGreater(num_batches, 0)

if __name__ == '__main__':
    # Run this module's TestCase classes when the file is executed directly.
    unittest.main()