# test_data.py
import sys

if sys.version_info.major >= 3:
    import torchani
    import unittest
    import tempfile
    import os
    import torch
    import torchani.pyanitools as pyanitools
    import torchani.data
    from math import ceil
    from bisect import bisect
    from pickle import dump, load

    # Absolute directory holding this test module (symlinks resolved); the
    # test data lives in its 'dataset' subdirectory, which TestDataset.setUp
    # uses as its default data path.
    path = os.path.split(os.path.realpath(__file__))[0]
    dataset_dir = os.path.join(path, 'dataset')

    class TestDataset(unittest.TestCase):
        """Cross-check ``torchani.data`` against the bundled HDF5 dataset.

        Lengths and chunk/batch counts reported by ``BatchSampler`` are
        compared with each other and with an independent count read through
        ``pyanitools``. The tests also cover chunk-wise ``random_split`` and
        pickling of split datasets.
        """

        def setUp(self, data_path=dataset_dir):
            """Load the dataset stored in ``data_path`` into ``self.ds``."""
            self.data_path = data_path
            self.ds = torchani.data.load_dataset(data_path)

        # -- helpers (not collected: names do not start with "test") -------

        def _h5_files(self):
            """Yield the paths of regular ``.h5``/``.hdf5`` files found
            directly inside ``self.data_path`` (non-recursive)."""
            for name in os.listdir(self.data_path):
                full_path = os.path.join(self.data_path, name)
                if os.path.isfile(full_path) and \
                   full_path.endswith(('.h5', '.hdf5')):
                    yield full_path

        def _conformation_counts(self):
            """Yield ``entry['energies'].shape[0]`` for every entry that
            ``pyanitools.anidataloader`` reads from the HDF5 files.

            This count bypasses ``torchani.data`` entirely, so it serves as
            an independent reference.
            """
            for filename in self._h5_files():
                for entry in pyanitools.anidataloader(filename):
                    yield entry['energies'].shape[0]

        def _random_split(self, chunksize, first_chunks):
            """Split ``self.ds`` chunk-wise into two random parts.

            Args:
                chunksize: chunk size passed to ``BatchSampler`` and
                    ``random_split``.
                first_chunks: number of chunks given to the first part; all
                    remaining chunks go to the second part.

            Returns:
                ``(ds1, ds2, chunks)``, where ``chunks`` is the total chunk
                count reported by ``BatchSampler(self.ds, chunksize, 1)``.
            """
            chunks = len(torchani.data.BatchSampler(self.ds, chunksize, 1))
            # make sure the requested split is feasible before handing a
            # possibly negative size to random_split
            self.assertGreaterEqual(
                chunks, first_chunks,
                'dataset has only {} chunks of size {}, need at least {}'
                .format(chunks, chunksize, first_chunks))
            ds1, ds2 = torchani.data.random_split(
                self.ds, [first_chunks, chunks - first_chunks], chunksize)
            return ds1, ds2, chunks

        # -- tests ---------------------------------------------------------

        def testLen(self):
            """``len(ds)``, the pyanitools count and iteration all agree."""
            # length reported by the dataset
            l1 = len(self.ds)
            # independent length computed with pyanitools
            l2 = sum(self._conformation_counts())
            # length obtained by exhausting the iterator
            l3 = sum(1 for _ in self.ds)
            self.assertEqual(l1, l2)
            self.assertEqual(l1, l3)

        def testNumChunks(self):
            """The chunk count of ``BatchSampler`` matches per-entry math."""
            chunksize = 64
            bs = torchani.data.BatchSampler(self.ds, chunksize, 1)
            # number of chunks reported by the batch sampler
            l1 = len(bs)
            # the test expects chunks not to span pyanitools entries, so
            # each entry contributes ceil(conformations / chunksize) chunks
            l2 = sum(ceil(n / chunksize) for n in self._conformation_counts())
            # number of chunks actually produced by iteration
            l3 = sum(1 for _ in bs)
            self.assertEqual(l1, l2)
            self.assertEqual(l1, l3)

        def testNumBatches(self):
            """Batches of ``batch_chunks`` chunks: ceil(chunks/batch_chunks)."""
            chunksize = 64
            batch_chunks = 4
            bs = torchani.data.BatchSampler(self.ds, chunksize, batch_chunks)
            # number of batches reported by the batch sampler
            l1 = len(bs)
            # expected number: all chunks grouped batch_chunks at a time
            chunks = len(torchani.data.BatchSampler(self.ds, chunksize, 1))
            l2 = ceil(chunks / batch_chunks)
            # number of batches actually produced by iteration
            l3 = sum(1 for _ in bs)
            self.assertEqual(l1, l2)
            self.assertEqual(l1, l3)

        def testBatchSize1(self):
            """With chunk size 1 there is one batch per dataset item."""
            bs = torchani.data.BatchSampler(self.ds, 1, 1)
            self.assertEqual(len(bs), len(self.ds))

        def testSplitSize(self):
            """Each part of a split holds the requested number of chunks."""
            chunksize = 64
            first = 200
            ds1, ds2, chunks = self._random_split(chunksize, first)
            bs1 = torchani.data.BatchSampler(ds1, chunksize, 1)
            bs2 = torchani.data.BatchSampler(ds2, chunksize, 1)
            self.assertEqual(len(bs1), first)
            self.assertEqual(len(bs2), chunks - first)

        def testSplitNoOverlap(self):
            """Split parts have unique, disjoint indices covering the data."""
            ds1, ds2, _ = self._random_split(64, 200)
            indices1 = ds1.dataset.indices
            indices2 = ds2.dataset.indices
            self.assertEqual(len(indices1), len(ds1))
            self.assertEqual(len(indices2), len(ds2))
            # no index is repeated within a part
            self.assertEqual(len(indices1), len(set(indices1)))
            self.assertEqual(len(indices2), len(set(indices2)))
            # no index is shared between the two parts
            self.assertTrue(set(indices1).isdisjoint(indices2),
                            'the two parts of the split overlap')
            # together the parts contain as many distinct indices as the
            # full dataset has items
            self.assertEqual(len(self.ds),
                             len(set(indices1).union(indices2)))

        def _testMolSizes(self, ds):
            """Check that molecule ids line up with ``ds.cumulative_sizes``.

            Each index ``i`` is assigned to a group by bisecting
            ``ds.cumulative_sizes``. All items of one group must share the
            same molecule id (``ds[i][0]``), and two different groups must
            never share a molecule id. This is equivalent to comparing every
            pair of items, but needs only one ``ds[i]`` lookup per item
            instead of O(len(ds)**2).
            """
            group_to_mol = {}
            for i in range(len(ds)):
                group = bisect(ds.cumulative_sizes, i)
                mol = ds[i][0].item()
                expected = group_to_mol.setdefault(group, mol)
                self.assertEqual(
                    mol, expected,
                    'item {} of group {} has molecule id {}, expected {}'
                    .format(i, group, mol, expected))
            mol_ids = list(group_to_mol.values())
            self.assertEqual(
                len(mol_ids), len(set(mol_ids)),
                'different groups share a molecule id: {}'
                .format(group_to_mol))

        def testMolSizes(self):
            """Molecule ids are consistent with the groups of a split part."""
            ds1, _, _ = self._random_split(8, 50)
            self._testMolSizes(ds1)

        def testSaveLoad(self):
            """A split dataset survives a pickle round trip unchanged."""
            ds1, _, _ = self._random_split(8, 50)

            # the context manager removes the directory deterministically
            # instead of leaving it to garbage collection
            with tempfile.TemporaryDirectory() as tmpdirname:
                filename = os.path.join(tmpdirname, 'test.obj')
                with open(filename, 'wb') as f:
                    dump(ds1, f)
                with open(filename, 'rb') as f:
                    ds1_loaded = load(f)

            self.assertEqual(len(ds1), len(ds1_loaded))
            self.assertListEqual(ds1.sizes, ds1_loaded.sizes)
            self.assertIsInstance(ds1_loaded, torchani.data.ANIDataset)

            for i in range(len(ds1)):
                item1, item2 = ds1[i], ds1_loaded[i]
                # element 0: molecule id
                self.assertEqual(item1[0].item(), item2[0].item())
                # element 1: coordinates, which must be identical
                self.assertTrue(torch.equal(item1[1], item2[1]),
                                'coordinates differ at index {}'.format(i))
                # element 2: energy
                self.assertEqual(item1[2].item(), item2[2].item())

    # Run this module's test cases when the file is executed directly.
    if __name__ == '__main__':
        unittest.main()