import os
import unittest

import pkbar
import torch

import torchani

path = os.path.dirname(os.path.realpath(__file__))
dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s03.h5')

batch_size = 2560
chunk_threshold = 5
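# each batch holds batch_size conformations; chunk_threshold controls how a
# batch is split into (species, coordinates) chunks of similarly sized
# molecules, which the padding checks below depend on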

# Example of requesting several padded properties at once; kept for reference,
# but only 'energies' is exercised by these tests.
# other_properties = {'properties': ['dipoles', 'forces', 'energies'],
#                     'padding_values': [None, 0, None],
#                     'padded_shapes': [(batch_size, 3), (batch_size, -1, 3), (batch_size, )],
#                     'dtypes': [torch.float32, torch.float32, torch.float64],
#                     }

other_properties = {'properties': ['energies'],
                    'padding_values': [None],
                    'padded_shapes': [(batch_size, )],
                    'dtypes': [torch.float64],
                    }
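
# A minimal consumption sketch (hypothetical usage, but consistent with what
# the tests below assert): each batch is a pair (chunks, properties), where
# chunks is a list of (species, coordinates) tensor pairs and properties maps
# each requested key to a padded tensor:
#
#   ds = torchani.data.ShuffledDataset(dspath, batch_size=batch_size,
#                                      other_properties=other_properties)
#   for chunks, properties in ds:
#       for species, coordinates in chunks:
#           pass  # per-chunk forward pass would go here
#       energies = properties['energies']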


class TestFindThreshold(unittest.TestCase):
    def setUp(self):
        print('.. check find_threshold for splitting chunks')

    def testFindThreshold(self):
        torchani.data.find_threshold(dspath, batch_size=batch_size, threshold_max=10)


class TestShuffledData(unittest.TestCase):

    def setUp(self):
        print('.. set up shuffled dataset')
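        # subtract_self_energies=True asks the loader to return energies with
        # the per-atom self energies already subtracted, the usual training
        # target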
        self.ds = torchani.data.ShuffledDataset(dspath, batch_size=batch_size,
                                                chunk_threshold=chunk_threshold,
                                                num_workers=2,
                                                other_properties=other_properties,
                                                subtract_self_energies=True)
        self.chunks, self.properties = next(iter(self.ds))

    def testTensorShape(self):
        print('=> checking tensor shape')
        print('the first batch is ([chunk1, chunk2, ...], {"energies": ..., "forces": ..., ...}) in which chunk1 = (species, coordinates)')
        batch_len = 0
        print('1. chunks')
        for i, chunk in enumerate(self.chunks):
            print('chunk{}'.format(i + 1), 'species:', list(chunk[0].size()), chunk[0].dtype,
                  'coordinates:', list(chunk[1].size()), chunk[1].dtype)
            # species indices must be int64 and coordinates float32
            self.assertEqual(chunk[0].dtype, torch.int64)
            self.assertEqual(chunk[1].dtype, torch.float32)
            # coordinates are (molecules, atoms, 3) and their first two
            # dimensions must line up with species
            self.assertEqual(chunk[1].shape[2], 3)
            self.assertEqual(chunk[1].shape[:2], chunk[0].shape[:2])
            # the chunks of one batch together hold batch_size conformations
            batch_len += chunk[0].shape[0]
        print('2. properties')
        for i, key in enumerate(other_properties['properties']):
            print(key, list(self.properties[key].size()), self.properties[key].dtype)
            # check dtype
            self.assertEqual(self.properties[key].dtype, other_properties['dtypes'][i])
            # shape[0] == batch_size
            self.assertEqual(self.properties[key].shape[0], other_properties['padded_shapes'][i][0])
            # check len(shape)
            self.assertEqual(len(self.properties[key].shape), len(other_properties['padded_shapes'][i]))

    def testLoadDataset(self):
        print('=> test loading the whole dataset')
        pbar = pkbar.Pbar('loading and processing dataset into CPU memory, total '
                          + 'batches: {}, batch_size: {}'.format(len(self.ds), batch_size),
                          len(self.ds))
        for i, _ in enumerate(self.ds):
            pbar.update(i)

    def testSplitDataset(self):
        print('=> test splitting dataset')
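        # validation_split=0.1 should reserve roughly 10% of the data for
        # validation; the assertion leaves slack since the split is presumably
        # made in whole batches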
        train_ds, val_ds = torchani.data.ShuffledDataset(dspath, batch_size=batch_size,
                                                         chunk_threshold=chunk_threshold,
                                                         num_workers=2,
                                                         validation_split=0.1)
        frac = len(val_ds) / (len(val_ds) + len(train_ds))
        self.assertLess(abs(frac - 0.1), 0.05)

    def testNoUnnecessaryPadding(self):
        print('=> checking for unnecessary padding')
        for i, chunk in enumerate(self.chunks):
            species, _ = chunk
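            # species tensors are padded with -1, so real atoms have
            # species >= 0; the last atom column must contain at least one
            # real atom, otherwise the chunk was padded more than necessary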
            non_padding = (species >= 0)[:, -1].nonzero()
            self.assertGreater(non_padding.numel(), 0)


class TestCachedData(unittest.TestCase):

    def setUp(self):
        print('.. set up cached dataset')
        self.ds = torchani.data.CachedDataset(dspath, batch_size=batch_size, device='cpu',
                                              chunk_threshold=chunk_threshold,
                                              other_properties=other_properties,
                                              subtract_self_energies=True)
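        # unlike ShuffledDataset, CachedDataset supports direct indexing, so
        # the first batch can be fetched without constructing an iterator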
        self.chunks, self.properties = self.ds[0]

    def testTensorShape(self):
        print('=> checking tensor shape')
        print('the first batch is ([chunk1, chunk2, ...], {"energies": ..., "forces": ..., ...}) in which chunk1 = (species, coordinates)')
        batch_len = 0
        print('1. chunks')
        for i, chunk in enumerate(self.chunks):
            print('chunk{}'.format(i + 1), 'species:', list(chunk[0].size()), chunk[0].dtype,
                  'coordinates:', list(chunk[1].size()), chunk[1].dtype)
            # species indices must be int64 and coordinates float32
            self.assertEqual(chunk[0].dtype, torch.int64)
            self.assertEqual(chunk[1].dtype, torch.float32)
            # coordinates are (molecules, atoms, 3) and their first two
            # dimensions must line up with species
            self.assertEqual(chunk[1].shape[2], 3)
            self.assertEqual(chunk[1].shape[:2], chunk[0].shape[:2])
            # the chunks of one batch together hold batch_size conformations
            batch_len += chunk[0].shape[0]
        print('2. properties')
        for i, key in enumerate(other_properties['properties']):
            print(key, list(self.properties[key].size()), self.properties[key].dtype)
            # check dtype
            self.assertEqual(self.properties[key].dtype, other_properties['dtypes'][i])
            # shape[0] == batch_size
            self.assertEqual(self.properties[key].shape[0], other_properties['padded_shapes'][i][0])
            # check len(shape)
            self.assertEqual(len(self.properties[key].shape), len(other_properties['padded_shapes'][i]))

    def testLoadDataset(self):
        print('=> test loading the whole dataset')
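        # load() is expected to process and cache every batch up front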
        self.ds.load()

    def testSplitDataset(self):
        print('=> test splitting dataset')
        train_dataset, val_dataset = self.ds.split(0.1)
        frac = len(val_dataset) / len(self.ds)
        self.assertLess(abs(frac - 0.1), 0.05)

    def testNoUnnecessaryPadding(self):
        print('=> checking for unnecessary padding')
        for i, chunk in enumerate(self.chunks):
            species, _ = chunk
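            # same invariant as in TestShuffledData: the last atom column of a
            # chunk must contain at least one real (non-padding) atom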
            non_padding = (species >= 0)[:, -1].nonzero()
            self.assertGreater(non_padding.numel(), 0)


if __name__ == "__main__":
    unittest.main()