"""Unit tests for dataset loading, batching, splitting and pickling in
``torchani.data``."""
import copy
import os
import tempfile
import unittest
from bisect import bisect
from math import ceil
from pickle import dump, load

import torch
import torchani
import torchani.data
import torchani.pyanitools as pyanitools

path = os.path.dirname(os.path.realpath(__file__))
dataset_dir = os.path.join(path, 'dataset')

# File name suffixes of the HDF5 files that the reference counts are read from.
_HDF5_SUFFIXES = ('.h5', '.hdf5')


class TestDataset(unittest.TestCase):
    """Cross-check ``torchani.data`` against counts obtained via pyanitools."""

    def setUp(self, data_path=dataset_dir):
        """Load the dataset stored under ``data_path`` for every test."""
        self.data_path = data_path
        self.ds = torchani.data.load_dataset(data_path)

    def _conformation_counts(self):
        """Yield ``entry['energies'].shape[0]`` for every pyanitools entry.

        Only regular files directly inside ``self.data_path`` whose name ends
        in ``.h5`` or ``.hdf5`` are read.
        """
        for name in os.listdir(self.data_path):
            filename = os.path.join(self.data_path, name)
            if not os.path.isfile(filename):
                continue
            if not filename.endswith(_HDF5_SUFFIXES):
                continue
            for entry in pyanitools.anidataloader(filename):
                yield entry['energies'].shape[0]

    def testLen(self):
        """``len(ds)``, the pyanitools count and iteration must all agree."""
        # compute data length using Dataset
        len_dataset = len(self.ds)
        # compute data length using pyanitools
        len_pyanitools = sum(self._conformation_counts())
        # compute data length using iterator
        len_iterated = len(list(self.ds))
        # these lengths should match
        self.assertEqual(len_dataset, len_pyanitools)
        self.assertEqual(len_dataset, len_iterated)

    def testNumChunks(self):
        """The sampler's chunk count must match a per-entry ceil division."""
        chunksize = 64
        # compute number of chunks using batch sampler
        bs = torchani.data.BatchSampler(self.ds, chunksize, 1)
        chunks_sampler = len(bs)
        # compute number of chunks using pyanitools: every entry is chunked
        # on its own, so the ceil is applied per entry, not to the total
        chunks_pyanitools = sum(
            ceil(conformations / chunksize)
            for conformations in self._conformation_counts())
        # compute number of chunks using iterator
        chunks_iterated = len(list(bs))
        # these lengths should match
        self.assertEqual(chunks_sampler, chunks_pyanitools)
        self.assertEqual(chunks_sampler, chunks_iterated)

    def testNumBatches(self):
        """The batch count must equal ``ceil(chunks / batch_chunks)``."""
        chunksize = 64
        batch_chunks = 4
        # compute number of batches using batch sampler
        bs = torchani.data.BatchSampler(self.ds, chunksize, batch_chunks)
        batches_sampler = len(bs)
        # compute number of batches by simple math
        bs2 = torchani.data.BatchSampler(self.ds, chunksize, 1)
        batches_expected = ceil(len(bs2) / batch_chunks)
        # compute number of batches using iterator
        batches_iterated = len(list(bs))
        # these lengths should match
        self.assertEqual(batches_sampler, batches_expected)
        self.assertEqual(batches_sampler, batches_iterated)

    def testBatchSize1(self):
        """With chunk size 1 and 1 chunk per batch, one batch per item."""
        bs = torchani.data.BatchSampler(self.ds, 1, 1)
        self.assertEqual(len(bs), len(self.ds))

    def testSplitSize(self):
        """``random_split`` must yield parts with the requested chunk counts."""
        chunksize = 64
        bs = torchani.data.BatchSampler(self.ds, chunksize, 1)
        chunks = len(bs)
        ds1, ds2 = torchani.data.random_split(
            self.ds, [200, chunks - 200], chunksize)
        bs1 = torchani.data.BatchSampler(ds1, chunksize, 1)
        bs2 = torchani.data.BatchSampler(ds2, chunksize, 1)
        self.assertEqual(len(bs1), 200)
        self.assertEqual(len(bs2), chunks - 200)

    def testSplitNoOverlap(self):
        """The two parts of a split must be disjoint and cover the dataset."""
        chunksize = 64
        bs = torchani.data.BatchSampler(self.ds, chunksize, 1)
        chunks = len(bs)
        ds1, ds2 = torchani.data.random_split(
            self.ds, [200, chunks - 200], chunksize)
        indices1 = ds1.dataset.indices
        indices2 = ds2.dataset.indices
        self.assertEqual(len(indices1), len(ds1))
        self.assertEqual(len(indices2), len(ds2))
        # no index is repeated inside either part
        self.assertEqual(len(indices1), len(set(indices1)))
        self.assertEqual(len(indices2), len(set(indices2)))
        # together the parts contain as many distinct indices as the dataset
        self.assertEqual(len(self.ds), len(set(indices1 + indices2)))

    def _testMolSizes(self, ds):
        """Check that segments of ``ds`` and first-field values correspond 1:1.

        The segment of item ``i`` is ``bisect(ds.cumulative_sizes, i)`` and its
        id is ``ds[i][0].item()`` (presumably a molecule id -- confirm against
        ``torchani.data``).  Two items must share an id exactly when they
        share a segment.  Checking the mapping in both directions in a single
        pass is equivalent to comparing every pair of items.
        """
        id_of_segment = {}
        segment_of_id = {}
        for i in range(len(ds)):
            segment = bisect(ds.cumulative_sizes, i)
            mol = ds[i][0].item()
            # same segment => same id
            self.assertEqual(
                id_of_segment.setdefault(segment, mol), mol,
                'item {} differs from an item of its own segment'.format(i))
            # same id => same segment (i.e. other segment => other id)
            self.assertEqual(
                segment_of_id.setdefault(mol, segment), segment,
                'item {} shares its id with another segment'.format(i))

    def testMolSizes(self):
        """Run the segment/id consistency check on the first split part."""
        chunksize = 8
        bs = torchani.data.BatchSampler(self.ds, chunksize, 1)
        chunks = len(bs)
        ds1, _ = torchani.data.random_split(
            self.ds, [50, chunks - 50], chunksize)
        self._testMolSizes(ds1)

    def testSaveLoad(self):
        """A split dataset must survive a pickle round trip unchanged."""
        chunksize = 8
        bs = torchani.data.BatchSampler(self.ds, chunksize, 1)
        chunks = len(bs)
        ds1, _ = torchani.data.random_split(
            self.ds, [50, chunks - 50], chunksize)

        tmpdir = tempfile.TemporaryDirectory()
        # remove the directory when the test ends, even if an assertion fails
        self.addCleanup(tmpdir.cleanup)
        filename = os.path.join(tmpdir.name, 'test.obj')

        with open(filename, 'wb') as f:
            dump(ds1, f)

        # the pickle was written by this test a moment ago, so it is trusted
        with open(filename, 'rb') as f:
            ds1_loaded = load(f)

        self.assertEqual(len(ds1), len(ds1_loaded))
        self.assertListEqual(ds1.sizes, ds1_loaded.sizes)
        self.assertIsInstance(ds1_loaded, torchani.data.ANIDataset)

        for i in range(len(ds1)):
            i1 = ds1[i]
            i2 = ds1_loaded[i]
            molid1 = i1[0].item()
            molid2 = i2[0].item()
            self.assertEqual(molid1, molid2)
            xyz1 = i1[1]
            xyz2 = i2[1]
            # the second field must be element-wise identical
            maxdiff = torch.max(torch.abs(xyz1 - xyz2)).item()
            self.assertEqual(maxdiff, 0)
            e1 = i1[2].item()
            e2 = i2[2].item()
            self.assertEqual(e1, e2)


if __name__ == '__main__':
    unittest.main()