Commit 9639d716 authored by Richard Xue, committed by Gao, Xiang

Add split to new dataset API (#299)

* split

* clean

* docs

* docs

* Update new.py
parent b9e2c259
@@ -27,6 +27,7 @@ Datasets
 .. autofunction:: torchani.data.find_threshold
 .. autofunction:: torchani.data.ShuffledDataset
 .. autoclass:: torchani.data.CachedDataset
+    :members:
 .. autofunction:: torchani.data.load_ani_dataset
 .. autofunction:: torchani.data.create_aev_cache
 .. autoclass:: torchani.data.BatchedANIDataset
@@ -54,6 +54,12 @@ class TestShuffledData(unittest.TestCase):
         for i, _ in enumerate(self.ds):
             pbar.update(i)
 
+    def testSplitDataset(self):
+        print('=> test splitting dataset')
+        train_ds, val_ds = torchani.data.ShuffledDataset(dspath, batch_size=batch_size, chunk_threshold=chunk_threshold, num_workers=2, validation_split=0.1)
+        frac = len(val_ds) / (len(val_ds) + len(train_ds))
+        self.assertLess(abs(frac - 0.1), 0.05)
+
     def testNoUnnecessaryPadding(self):
         print('=> checking No Unnecessary Padding')
         for i, chunk in enumerate(self.chunks):
@@ -91,11 +97,13 @@ class TestCachedData(unittest.TestCase):
     def testLoadDataset(self):
         print('=> test loading all dataset')
-        pbar = pkbar.Pbar('loading and processing dataset into cpu memory, total '
-                          + 'batches: {}, batch_size: {}'.format(len(self.ds), batch_size),
-                          len(self.ds))
-        for i, _ in enumerate(self.ds):
-            pbar.update(i)
+        self.ds.load()
+
+    def testSplitDataset(self):
+        print('=> test splitting dataset')
+        train_dataset, val_dataset = self.ds.split(0.1)
+        frac = len(val_dataset) / len(self.ds)
+        self.assertLess(abs(frac - 0.1), 0.05)
 
     def testNoUnnecessaryPadding(self):
         print('=> checking No Unnecessary Padding')
@@ -76,8 +76,8 @@ class CachedDataset(torch.utils.data.Dataset):
         anidata = anidataloader(file_path)
         anidata_size = anidata.group_size()
-        enable_pkbar = anidata_size > 5 and PKBAR_INSTALLED
-        if enable_pkbar:
+        self.enable_pkbar = anidata_size > 5 and PKBAR_INSTALLED
+        if self.enable_pkbar:
             pbar = pkbar.Pbar('=> loading h5 dataset into cpu memory, total molecules: {}'.format(anidata_size), anidata_size)
         for i, molecule in enumerate(anidata):
@@ -92,7 +92,7 @@ class CachedDataset(torch.utils.data.Dataset):
                 self_energies = np.array(sum([self_energies_dict[x] for x in molecule['species']]))
                 self.data_self_energies += list(np.tile(self_energies, (num_conformations, 1)))
-            if enable_pkbar:
+            if self.enable_pkbar:
                 pbar.update(i)
         if subtract_self_energies:
@@ -172,6 +172,43 @@ class CachedDataset(torch.utils.data.Dataset):
     def __len__(self):
         return self.length
 
+    def split(self, validation_split):
+        """Split the dataset into training and validation sets.
+
+        Arguments:
+            validation_split (float): Float between 0 and 1. Fraction of the dataset
+                to be used as validation data.
+        """
+        val_size = int(validation_split * len(self))
+        train_size = len(self) - val_size
+        ds = []
+        if self.enable_pkbar:
+            message = ('=> processing, splitting and caching dataset into cpu memory: \n'
+                       + 'total batches: {}, train batches: {}, val batches: {}, batch_size: {}')
+            pbar = pkbar.Pbar(message.format(len(self), train_size, val_size, self.batch_size),
+                              len(self))
+        for i, _ in enumerate(self):
+            ds.append(self[i])
+            if self.enable_pkbar:
+                pbar.update(i)
+        train_dataset = ds[:train_size]
+        val_dataset = ds[train_size:]
+        return train_dataset, val_dataset
+
+    def load(self):
+        """Cache the dataset into CPU memory. If not called, the dataset will be
+        cached during the first epoch.
+        """
+        if self.enable_pkbar:
+            pbar = pkbar.Pbar('=> processing and caching dataset into cpu memory: \ntotal '
+                              + 'batches: {}, batch_size: {}'.format(len(self), self.batch_size),
+                              len(self))
+        for i, _ in enumerate(self):
+            if self.enable_pkbar:
+                pbar.update(i)
+
     @staticmethod
     def sort_list_with_index(inputs, index):
         return [inputs[i] for i in index]
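
For readers skimming the diff, here is a minimal usage sketch of the new CachedDataset.split and load methods. The file path and batch_size below are illustrative placeholders, not values taken from this commit:

    import torchani

    # hypothetical setup; CachedDataset is assumed to accept an HDF5 path and a batch_size
    ds = torchani.data.CachedDataset('path/to/ani_dataset.h5', batch_size=2560)

    ds.load()                              # optional: cache every batch into CPU memory up front
    training, validation = ds.split(0.1)   # lists of cached batches, roughly 90% / 10%
    print(len(training), len(validation))

As the progress message in split indicates, it caches batches while it iterates, so calling load first is not required; it only moves the caching cost ahead of the first epoch.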
@@ -229,6 +266,7 @@ class CachedDataset(torch.utils.data.Dataset):
 def ShuffledDataset(file_path,
                     batch_size=1000, num_workers=0, shuffle=True, chunk_threshold=20,
+                    validation_split=0.0,
                     species_order=['H', 'C', 'N', 'O'],
                     subtract_self_energies=False,
                     self_energies=[-0.600953, -38.08316, -54.707756, -75.194466]):
@@ -242,6 +280,8 @@ def ShuffledDataset(file_path,
         shuffle (bool): whether to shuffle.
         chunk_threshold (int): threshold used to split a batch into chunks. Set to ``None``
             to disable chunk splitting.
+        validation_split (float): Float between 0 and 1. Fraction of the dataset to be used
+            as validation data.
         species_order (list): a list which specifies how species are transformed to int,
             for example: ``['H', 'C', 'N', 'O']`` means ``{'H': 0, 'C': 1, 'N': 2, 'O': 3}``.
         subtract_self_energies (bool): whether to subtract self energies from ``energies``.
@@ -273,14 +313,27 @@ def ShuffledDataset(file_path,
     def my_collate_fn(data, chunk_threshold=chunk_threshold):
         return collate_fn(data, chunk_threshold)
 
-    data_loader = torch.utils.data.DataLoader(dataset=dataset,
-                                               batch_size=batch_size,
-                                               shuffle=shuffle,
-                                               num_workers=num_workers,
-                                               pin_memory=False,
-                                               collate_fn=my_collate_fn)
-
-    return data_loader
+    val_size = int(validation_split * len(dataset))
+    train_size = len(dataset) - val_size
+    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
+
+    train_data_loader = torch.utils.data.DataLoader(dataset=train_dataset,
+                                                    batch_size=batch_size,
+                                                    shuffle=shuffle,
+                                                    num_workers=num_workers,
+                                                    pin_memory=False,
+                                                    collate_fn=my_collate_fn)
+    if val_size == 0:
+        return train_data_loader
+
+    val_data_loader = torch.utils.data.DataLoader(dataset=val_dataset,
+                                                  batch_size=batch_size,
+                                                  shuffle=False,
+                                                  num_workers=num_workers,
+                                                  pin_memory=False,
+                                                  collate_fn=my_collate_fn)
+    return train_data_loader, val_data_loader
 
 class TorchData(torch.utils.data.Dataset):
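
For context, a minimal sketch of how ShuffledDataset behaves with the new argument; the dataset path and batch size below are placeholders, not values from this commit:

    import torchani

    # with validation_split > 0 the factory returns a pair of DataLoaders
    train_loader, val_loader = torchani.data.ShuffledDataset(
        'path/to/ani_dataset.h5',      # placeholder path to an ANI-style HDF5 file
        batch_size=2560,
        validation_split=0.1)

    # with the default validation_split=0.0 a single DataLoader is returned, as before
    loader = torchani.data.ShuffledDataset('path/to/ani_dataset.h5', batch_size=2560)

    for batch in train_loader:
        pass  # training step goes here

Keeping validation_split=0.0 as the default means existing callers that expect a single DataLoader are unaffected; only callers that opt in receive the (train, validation) pair.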