# test_data.py
1
import os
2
import torchani
3
import unittest
4

5
# Directory containing this test file. The dataset location is resolved
# relative to it so the tests do not depend on the current working directory.
path = os.path.dirname(os.path.realpath(__file__))
dataset_path = os.path.join(path, '..', 'dataset')
# Batch size handed to BatchedANIDataset; no batch may exceed it.
batch_size = 256
# Constructed only so its ``species`` attribute can be passed to
# BatchedANIDataset as the species argument.
aev = torchani.AEVComputer()


class TestData(unittest.TestCase):
    """Sanity checks on the batches yielded by ``BatchedANIDataset``."""

    def setUp(self):
        # Build a fresh dataset for every test from the module-level dataset
        # directory, using the species list of the module-level ``aev``.
        self.ds = torchani.training.BatchedANIDataset(dataset_path,
                                                      aev.species,
                                                      batch_size)

    def testTensorShape(self):
        """Check that species, coordinates and energies have consistent shapes.

        Each batch is ``((species, coordinates), output)``, where ``output``
        is a mapping holding an ``'energies'`` tensor.
        """
        for batch_input, batch_output in self.ds:
            species, coordinates = batch_input
            energies = batch_output['energies']
            # species is 2-D; dim 0 indexes the batch, which must not exceed
            # the requested batch size.
            self.assertEqual(len(species.shape), 2)
            self.assertLessEqual(species.shape[0], batch_size)
            # coordinates is 3-D with a trailing xyz dimension of size 3.
            self.assertEqual(len(coordinates.shape), 3)
            self.assertEqual(coordinates.shape[2], 3)
            # coordinates must line up with species on the first two
            # dimensions (batch, and presumably atoms — TODO confirm).
            self.assertEqual(coordinates.shape[1], species.shape[1])
            self.assertEqual(coordinates.shape[0], species.shape[0])
            # energies is 1-D with one entry per batch element.
            self.assertEqual(len(energies.shape), 1)
            self.assertEqual(coordinates.shape[0], energies.shape[0])

    def testNoUnnecessaryPadding(self):
        """Check that no batch is padded wider than it needs to be.

        If the last column of ``species`` held no entry ``>= 0`` in any row,
        the whole column would be padding and the batch would be one column
        wider than necessary.
        """
        for (species, _), _ in self.ds:
            # Entries >= 0 are counted as real atoms; negative values are
            # presumably the padding marker — TODO confirm against the loader.
            non_padding = (species >= 0)[:, -1].nonzero()
            self.assertGreater(non_padding.numel(), 0)

if __name__ == '__main__':
    # When executed as a script, discover and run the TestCase classes
    # defined in this module (``module='__main__'`` is unittest's default).
    unittest.main(module='__main__')