test_data.py 1.07 KB
Newer Older
1
import os
2
import torchani
3
import unittest
4

5
# Absolute directory containing this test file (symlinks resolved by
# os.path.realpath); used as the anchor for locating the dataset.
path = os.path.dirname(os.path.realpath(__file__))
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# Batch size handed to the dataset; the test below also uses it as the upper
# bound on the leading dimension of each batch.
batch_size = 256

# Dataset location, resolved relative to this test file's directory.
dataset_path = os.path.join(path, '../dataset')
print(dataset_path)

# Shared AEV computer; the test below reads its ``species`` attribute.
aev = torchani.AEVComputer()


class TestData(unittest.TestCase):
    """Shape checks for batches produced by ``BatchedANIDataset``."""

    def testTensorShape(self):
        """Check that every batch has mutually consistent tensor shapes.

        Each item yielded by the dataset is unpacked as
        ``((species, coordinates), output)``, with the energies read from
        ``output['energies']``. For every batch the test asserts that:

        * ``species`` is rank 2 and its leading dimension does not exceed
          ``batch_size``;
        * ``coordinates`` is rank 3 with a last dimension of size 3;
        * the leading two dimensions of ``coordinates`` match ``species``;
        * ``energies`` is rank 1 with one entry per row of ``coordinates``.
        """
        ds = torchani.training.BatchedANIDataset(dataset_path, aev.species,
                                                 batch_size)
        # Unpack directly in the loop header; avoids shadowing the builtin
        # ``input`` with a local name.
        for batch_input, output in ds:
            species, coordinates = batch_input
            energies = output['energies']

            self.assertEqual(len(species.shape), 2)
            self.assertLessEqual(species.shape[0], batch_size)

            self.assertEqual(len(coordinates.shape), 3)
            self.assertEqual(coordinates.shape[2], 3)
            # The original compared coordinates.shape[k] with itself, which
            # can never fail. Compare against ``species`` instead.
            # NOTE(review): assumes species is (molecules, atoms) and
            # coordinates is (molecules, atoms, 3) -- confirm against
            # BatchedANIDataset.
            self.assertEqual(coordinates.shape[1], species.shape[1])
            self.assertEqual(coordinates.shape[0], species.shape[0])

            self.assertEqual(len(energies.shape), 1)
            self.assertEqual(coordinates.shape[0], energies.shape[0])
29

30
31
32

# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()