# test_data.py -- unit tests for torchani.data.ANIDataset chunk sizes.
import sys

if sys.version_info.major >= 3:
    # The tests are only defined on Python 3; on Python 2 this module
    # imports cleanly but defines no test cases.
    import os
    import unittest

    import torchani.data

    # Dataset directory: ``../dataset`` relative to the real location of this
    # file (symlinks resolved), so the tests work from any working directory.
    path = os.path.dirname(os.path.realpath(__file__))
    path = os.path.join(path, '..', 'dataset')

    class TestDataset(unittest.TestCase):
        """Check that ``torchani.data.ANIDataset`` respects its chunk size."""

        def _test_chunksize(self, chunksize):
            """Iterate the dataset built with ``chunksize`` and check sizes.

            Asserts that ``shape[0]`` of every chunk's ``'coordinates'`` entry
            is at most ``chunksize``, and that at least one chunk is yielded
            so a missing or empty dataset cannot make the test pass
            vacuously.
            """
            ds = torchani.data.ANIDataset(path, chunksize)
            n_chunks = 0
            # NOTE(review): assumes each item is a 2-tuple whose first element
            # is a mapping with a 'coordinates' array/tensor, and that dim 0
            # is the per-chunk sample axis -- confirm against ANIDataset.
            for batch, _ in ds:
                n_chunks += 1
                self.assertLessEqual(batch['coordinates'].shape[0], chunksize)
            self.assertGreater(n_chunks, 0,
                               'dataset at %s yielded no chunks' % path)

        def testChunk64(self):
            self._test_chunksize(64)

        def testChunk128(self):
            self._test_chunksize(128)

        def testChunk32(self):
            self._test_chunksize(32)

        def testChunk256(self):
            self._test_chunksize(256)

    if __name__ == '__main__':
        unittest.main()