"""Unit tests for ``torchani.data.ANIDataset`` chunk sizes."""
import os
import unittest
import torchani.data

# Dataset directory handed to ANIDataset: '../dataset' relative to this file.
path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../dataset')


class TestDataset(unittest.TestCase):
    """Check that ``ANIDataset`` never yields a chunk larger than requested."""

    def _test_chunksize(self, chunksize):
        """Iterate the dataset at ``path`` and bound the size of every chunk.

        Args:
            chunksize: value passed to ``torchani.data.ANIDataset`` as the
                chunk size; every yielded chunk must not exceed it.
        """
        ds = torchani.data.ANIDataset(path, chunksize)
        n_chunks = 0
        # Each item is unpacked as a pair; only the first element (indexed by
        # 'coordinates') is inspected here.
        for batch, _ in ds:
            # shape[0] of 'coordinates' is taken as the size of this chunk.
            self.assertLessEqual(batch['coordinates'].shape[0], chunksize)
            n_chunks += 1
        # Without this, an empty or missing dataset would make the size check
        # above pass vacuously.
        self.assertGreater(n_chunks, 0, 'dataset yielded no chunks')

    def testChunk64(self):
        self._test_chunksize(64)

    def testChunk128(self):
        self._test_chunksize(128)

    def testChunk32(self):
        self._test_chunksize(32)

    def testChunk256(self):
        self._test_chunksize(256)


if __name__ == '__main__':
    # Run the test suite when this file is executed directly.
    unittest.main()