Unverified commit ba3036d1, authored by Gao, Xiang, committed by GitHub

add limited python2 compatibility (#13)

parent 87c4184a
@@ -11,4 +11,5 @@ steps:
     image: '${{build-torchani}}'
     commands:
     - flake8
-    - python setup.py test
\ No newline at end of file
+    - python setup.py test
+    - python2 setup.py test
\ No newline at end of file
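
The new CI step runs the suite under Python 2 as well. It can pass only because of the module-level version guard added in the test file below: under Python 2 the guarded body never executes, so no Python-3-only imports are attempted and `unittest` collects zero tests. A minimal standalone sketch of the pattern (the class and test names here are illustrative, not part of the commit):

import sys
import unittest

if sys.version_info.major >= 3:
    # Python-3-only imports and tests live inside the guard; under
    # Python 2 this block is skipped, so nothing below is defined.
    class GuardedTest(unittest.TestCase):
        def test_runs_only_on_py3(self):
            self.assertGreaterEqual(sys.version_info.major, 3)

if __name__ == '__main__':
    unittest.main()  # Python 2 reports "Ran 0 tests"; Python 3 runs the test

Note that in the committed file below even `import unittest` sits inside the guard, so running the file directly under Python 2 would raise a NameError at the `unittest.main()` call; `python2 setup.py test` merely imports the module, finds no tests, and exits cleanly, which is what makes the compatibility "limited".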
import sys

if sys.version_info.major >= 3:
    import torchani
    import unittest
    import tempfile
    import os
    import torch
    import torchani.pyanitools as pyanitools
    import torchani.data
    from math import ceil
    from bisect import bisect
    from pickle import dump, load

    path = os.path.dirname(os.path.realpath(__file__))
    dataset_dir = os.path.join(path, 'dataset')

    class TestDataset(unittest.TestCase):

        def setUp(self, data_path=dataset_dir):
            self.data_path = data_path
            self.ds = torchani.data.load_dataset(data_path)

        def testLen(self):
            # compute data length using Dataset
            l1 = len(self.ds)
            # compute data length using pyanitools
            l2 = 0
            for f in os.listdir(self.data_path):
                f = os.path.join(self.data_path, f)
                if os.path.isfile(f) and \
                        (f.endswith('.h5') or f.endswith('.hdf5')):
                    for j in pyanitools.anidataloader(f):
                        l2 += j['energies'].shape[0]
            # compute data length using iterator
            l3 = len(list(self.ds))
            # these lengths should match
            self.assertEqual(l1, l2)
            self.assertEqual(l1, l3)

        def testNumChunks(self):
            chunksize = 64
            # compute number of chunks using batch sampler
            bs = torchani.data.BatchSampler(self.ds, chunksize, 1)
            l1 = len(bs)
            # compute number of chunks using pyanitools
            l2 = 0
            for f in os.listdir(self.data_path):
                f = os.path.join(self.data_path, f)
                if os.path.isfile(f) and \
                        (f.endswith('.h5') or f.endswith('.hdf5')):
                    for j in pyanitools.anidataloader(f):
                        conformations = j['energies'].shape[0]
                        l2 += ceil(conformations / chunksize)
            # compute number of chunks using iterator
            l3 = len(list(bs))
            # these lengths should match
            self.assertEqual(l1, l2)
            self.assertEqual(l1, l3)

        def testNumBatches(self):
            chunksize = 64
            batch_chunks = 4
            # compute number of batches using batch sampler
            bs = torchani.data.BatchSampler(self.ds, chunksize, batch_chunks)
            l1 = len(bs)
            # compute number of batches by simple math
            bs2 = torchani.data.BatchSampler(self.ds, chunksize, 1)
            l2 = ceil(len(bs2) / batch_chunks)
            # compute number of batches using iterator
            l3 = len(list(bs))
            # these lengths should match
            self.assertEqual(l1, l2)
            self.assertEqual(l1, l3)

        def testBatchSize1(self):
            bs = torchani.data.BatchSampler(self.ds, 1, 1)
            self.assertEqual(len(bs), len(self.ds))

        def testSplitSize(self):
            chunksize = 64
            bs = torchani.data.BatchSampler(self.ds, chunksize, 1)
            chunks = len(bs)
            ds1, ds2 = torchani.data.random_split(
                self.ds, [200, chunks-200], chunksize)
            bs1 = torchani.data.BatchSampler(ds1, chunksize, 1)
            bs2 = torchani.data.BatchSampler(ds2, chunksize, 1)
            self.assertEqual(len(bs1), 200)
            self.assertEqual(len(bs2), chunks-200)

        def testSplitNoOverlap(self):
            chunksize = 64
            bs = torchani.data.BatchSampler(self.ds, chunksize, 1)
            chunks = len(bs)
            ds1, ds2 = torchani.data.random_split(
                self.ds, [200, chunks-200], chunksize)
            indices1 = ds1.dataset.indices
            indices2 = ds2.dataset.indices
            self.assertEqual(len(indices1), len(ds1))
            self.assertEqual(len(indices2), len(ds2))
            self.assertEqual(len(indices1), len(set(indices1)))
            self.assertEqual(len(indices2), len(set(indices2)))
            self.assertEqual(len(self.ds), len(set(indices1+indices2)))

        def _testMolSizes(self, ds):
            for i in range(len(ds)):
                left = bisect(ds.cumulative_sizes, i)
                moli = ds[i][0].item()
                for j in range(len(ds)):
                    left2 = bisect(ds.cumulative_sizes, j)
                    molj = ds[j][0].item()
                    if left == left2:
                        self.assertEqual(moli, molj)
                    else:
                        if moli == molj:
                            print(i, j)
                        self.assertNotEqual(moli, molj)

        def testMolSizes(self):
            chunksize = 8
            bs = torchani.data.BatchSampler(self.ds, chunksize, 1)
            chunks = len(bs)
            ds1, ds2 = torchani.data.random_split(
                self.ds, [50, chunks-50], chunksize)
            self._testMolSizes(ds1)

        def testSaveLoad(self):
            chunksize = 8
            bs = torchani.data.BatchSampler(self.ds, chunksize, 1)
            chunks = len(bs)
            ds1, ds2 = torchani.data.random_split(
                self.ds, [50, chunks-50], chunksize)

            tmpdir = tempfile.TemporaryDirectory()
            tmpdirname = tmpdir.name
            filename = os.path.join(tmpdirname, 'test.obj')

            with open(filename, 'wb') as f:
                dump(ds1, f)

            with open(filename, 'rb') as f:
                ds1_loaded = load(f)

            self.assertEqual(len(ds1), len(ds1_loaded))
            self.assertListEqual(ds1.sizes, ds1_loaded.sizes)
            self.assertIsInstance(ds1_loaded, torchani.data.ANIDataset)

            for i in range(len(ds1)):
                i1 = ds1[i]
                i2 = ds1_loaded[i]
                molid1 = i1[0].item()
                molid2 = i2[0].item()
                self.assertEqual(molid1, molid2)
                xyz1 = i1[1]
                xyz2 = i2[1]
                maxdiff = torch.max(torch.abs(xyz1-xyz2)).item()
                self.assertEqual(maxdiff, 0)
                e1 = i1[2].item()
                e2 = i2[2].item()
                self.assertEqual(e1, e2)

if __name__ == '__main__':
    unittest.main()
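
One concrete reason the body cannot simply run under Python 2 is `tempfile.TemporaryDirectory` in `testSaveLoad`, which only exists since Python 3.2 (Python 2 offers `tempfile.mkdtemp` plus manual cleanup). A sketch of a two-version fallback, in case fuller Python 2 support were ever wanted; the helper name is hypothetical and not part of the commit:

import shutil
import tempfile

def make_tmpdir():
    # Hypothetical helper, not from the commit: mkdtemp() exists on both
    # Python 2 and 3; unlike TemporaryDirectory it is never cleaned up
    # automatically, so the caller must remove it afterwards.
    return tempfile.mkdtemp()

tmpdirname = make_tmpdir()
try:
    pass  # ... write and read test files under tmpdirname ...
finally:
    shutil.rmtree(tmpdirname, ignore_errors=True)  # explicit cleanup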