Unverified commit 14a62dc4 authored by Gao, Xiang, committed by GitHub

Refactor BatchedANIDataset and create API for splitting datasets (#237)

parent 0ce36d82
@@ -24,6 +24,8 @@ Datasets
 ========
 .. automodule:: torchani.data
+.. autofunction:: torchani.data.load_ani_dataset
+.. autofunction:: torchani.data.create_aev_cache
 .. autoclass:: torchani.data.BatchedANIDataset
 .. autoclass:: torchani.data.AEVCacheLoader
 .. automodule:: torchani.data.cache_aev
......
@@ -62,14 +62,10 @@ species_to_tensor = torchani.utils.ChemicalSymbolsToInts('HCNO')
 ###############################################################################
-# Now let's setup datasets. Note that here for our demo purpose, we set both
-# training set and validation set the ``ani_gdb_s01.h5`` in TorchANI's
-# repository. This allows this program to finish very quick, because that
-# dataset is very small. But this is wrong and should be avoided for any
-# serious training. These paths assumes the user run this script under the
-# ``examples`` directory of TorchANI's repository. If you download this script,
-# you should manually set the path of these files in your system before this
-# script can run successfully.
+# Now let's set up the datasets. These paths assume the user runs this script
+# under the ``examples`` directory of TorchANI's repository. If you download
+# this script, you should manually set the path of these files on your system
+# before this script can run successfully.
 #
 # Also note that we need to subtract from the energies the self energies of all
 # atoms for each molecule. This makes the range of energies in a reasonable
@@ -81,18 +77,13 @@ try:
     path = os.path.dirname(os.path.realpath(__file__))
 except NameError:
     path = os.getcwd()
-training_path = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
-validation_path = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
+dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
 batch_size = 2560

-training = torchani.data.BatchedANIDataset(
-    training_path, species_to_tensor, batch_size, device=device,
-    transform=[energy_shifter.subtract_from_dataset])
-validation = torchani.data.BatchedANIDataset(
-    validation_path, species_to_tensor, batch_size, device=device,
-    transform=[energy_shifter.subtract_from_dataset])
+training, validation = torchani.data.load_ani_dataset(
+    dspath, species_to_tensor, batch_size, device=device,
+    transform=[energy_shifter.subtract_from_dataset], split=[0.8, None])

 ###############################################################################
 # When iterating the dataset, we will get pairs of input and output
......
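A minimal sketch of how the two returned datasets might be consumed (assuming the ``training`` and ``validation`` variables from the example above): iterating either one yields a whole batch, namely a list of padded (species, coordinates) chunks plus a dictionary of properties.

    for species_coordinates, properties in training:
        true_energies = properties['energies']
        # each chunk groups molecules of similar size, padded to a common atom count
        for species, coordinates in species_coordinates:
            pass  # feed each padded chunk to the model here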
@@ -46,25 +46,19 @@ try:
     path = os.path.dirname(os.path.realpath(__file__))
 except NameError:
     path = os.getcwd()
-training_path = os.path.join(path, '../dataset/ani-1x/sample.h5')
-validation_path = os.path.join(path, '../dataset/ani-1x/sample.h5')
+dspath = os.path.join(path, '../dataset/ani-1x/sample.h5')
 batch_size = 2560

 ###############################################################################
 # The code to create the dataset is a bit different: we need to manually
 # specify that ``atomic_properties=['forces']`` so that forces will be read
 # from hdf5 files.
-training = torchani.data.BatchedANIDataset(
-    training_path, species_to_tensor, batch_size, device=device,
-    atomic_properties=['forces'],
-    transform=[energy_shifter.subtract_from_dataset])
-validation = torchani.data.BatchedANIDataset(
-    validation_path, species_to_tensor, batch_size, device=device,
+training, validation = torchani.data.load_ani_dataset(
+    dspath, species_to_tensor, batch_size, device=device,
     atomic_properties=['forces'],
-    transform=[energy_shifter.subtract_from_dataset])
+    transform=[energy_shifter.subtract_from_dataset], split=[0.8, None])

 ###############################################################################
 # When iterating the dataset, we will get pairs of input and output
......
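With ``atomic_properties=['forces']``, each batch additionally carries the per-chunk atomic data under ``properties['atomic']``. A minimal sketch of reading the forces back (assuming the ``training`` dataset created above):

    for species_coordinates, properties in training:
        # properties['atomic'] is a list of per-chunk dicts of padded atomic
        # properties (species, coordinates, forces), already moved to the device
        for chunk in properties['atomic']:
            forces = chunk['forces']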
@@ -32,8 +32,7 @@ try:
     path = os.path.dirname(os.path.realpath(__file__))
 except NameError:
     path = os.getcwd()
-training_path = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
-validation_path = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')  # noqa: E501
+dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')

 # checkpoint file to save model when validation RMSE improves
 model_checkpoint = 'model.pt'
@@ -102,13 +101,9 @@ writer = torch.utils.tensorboard.SummaryWriter(log_dir=log)
 ###############################################################################
 # Now load training and validation datasets into memory.
-training = torchani.data.BatchedANIDataset(
-    training_path, consts.species_to_tensor, batch_size, device=device,
-    transform=[energy_shifter.subtract_from_dataset])
-validation = torchani.data.BatchedANIDataset(
-    validation_path, consts.species_to_tensor, batch_size, device=device,
-    transform=[energy_shifter.subtract_from_dataset])
+training, validation = torchani.data.load_ani_dataset(
+    dspath, consts.species_to_tensor, batch_size, device=device,
+    transform=[energy_shifter.subtract_from_dataset], split=[0.8, None])

 ###############################################################################
 # We have tools to deal with the chunking (see :ref:`training-example`). These
......
@@ -16,9 +16,9 @@ aev_computer = builtins.aev_computer
 class TestData(unittest.TestCase):

     def setUp(self):
-        self.ds = torchani.data.BatchedANIDataset(dataset_path,
-                                                  consts.species_to_tensor,
-                                                  batch_size)
+        self.ds = torchani.data.load_ani_dataset(dataset_path,
+                                                 consts.species_to_tensor,
+                                                 batch_size)

     def _assertTensorEqual(self, t1, t2):
         self.assertLess((t1 - t2).abs().max().item(), 1e-6)
......
@@ -10,6 +10,7 @@ from .. import utils, neurochem, aev
 import pickle
 import numpy as np
 from scipy.sparse import bsr_matrix
+import warnings

 default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -159,11 +160,69 @@ def split_whole_into_batches_and_chunks(atomic_properties, properties, batch_siz
     return batches


-class BatchedANIDataset(Dataset):
-    """Load data from hdf5 files, create minibatches, and convert to tensors.
-
-    This is already a dataset of batches, so when iterated, a batch rather
-    than a single data point will be yielded.
+class PaddedBatchChunkDataset(Dataset):
+
+    def __init__(self, atomic_properties, properties, batch_size,
+                 dtype=torch.get_default_dtype(), device=default_device):
+        super().__init__()
+        self.device = device
+        self.dtype = dtype
+
+        # convert to desired dtype
+        for k in properties:
+            properties[k] = properties[k].to(dtype)
+        for k in atomic_properties:
+            if k == 'species':
+                continue
+            atomic_properties[k] = atomic_properties[k].to(dtype)
+
+        self.batches = split_whole_into_batches_and_chunks(atomic_properties, properties, batch_size)
+
+    def __getitem__(self, idx):
+        atomic_properties, properties = self.batches[idx]
+        atomic_properties, properties = atomic_properties.copy(), properties.copy()
+        species_coordinates = []
+        for chunk in atomic_properties:
+            for k in chunk:
+                chunk[k] = chunk[k].to(self.device)
+            species_coordinates.append((chunk['species'], chunk['coordinates']))
+        for k in properties:
+            properties[k] = properties[k].to(self.device)
+        properties['atomic'] = atomic_properties
+        return species_coordinates, properties
+
+    def __len__(self):
+        return len(self.batches)
+
+
+class BatchedANIDataset(PaddedBatchChunkDataset):
+    """Same as :func:`torchani.data.load_ani_dataset`. This API has been deprecated."""
+
+    def __init__(self, path, species_tensor_converter, batch_size,
+                 shuffle=True, properties=('energies',), atomic_properties=(), transform=(),
+                 dtype=torch.get_default_dtype(), device=default_device):
+        self.properties = properties
+        self.atomic_properties = atomic_properties
+        warnings.warn("BatchedANIDataset is deprecated; use load_ani_dataset()", DeprecationWarning)
+
+        atomic_properties, properties = load_and_pad_whole_dataset(
+            path, species_tensor_converter, shuffle, properties, atomic_properties)
+
+        # do transformations on data
+        for t in transform:
+            atomic_properties, properties = t(atomic_properties, properties)
+
+        super().__init__(atomic_properties, properties, batch_size, dtype, device)
+
+
+def load_ani_dataset(path, species_tensor_converter, batch_size, shuffle=True,
+                     properties=('energies',), atomic_properties=(), transform=(),
+                     dtype=torch.get_default_dtype(), device=default_device,
+                     split=(None,)):
+    """Load an ANI dataset from hdf5 files and split it into subsets.
+
+    The returned datasets are already datasets of batches, so when iterated, a
+    batch rather than a single data point will be yielded.

     Since each batch might contain molecules of very different sizes, putting
     the whole batch into a single tensor would require adding ghost atoms to
@@ -217,52 +276,67 @@ class BatchedANIDataset(Dataset):
         dtype (:class:`torch.dtype`): dtype of coordinates and properties
             to convert the dataset to.
         device (:class:`torch.device`): device to put tensors when iterating.
+        split (list): a sequence of integers, floats, or ``None``. Integers
+            are interpreted as numbers of elements, floats are interpreted as
+            percentages, and ``None`` is interpreted as the rest of the dataset
+            and can only appear as the last element of :attr:`split`. For
+            example, if the whole dataset has 10000 entries and split is
+            ``(5000, 0.1, None)``, then this function will create 3 datasets,
+            where the first dataset contains 5000 elements, the second dataset
+            contains ``int(0.1 * 10000)``, which is 1000, and the third dataset
+            contains ``10000 - 5000 - 1000`` elements. By default this creates
+            only a single dataset.
+
+    Returns:
+        An instance of :class:`torchani.data.PaddedBatchChunkDataset` if there
+        is only one element in :attr:`split`, otherwise a tuple of such
+        datasets, one for each element of :attr:`split`.

     .. _pyanitools.py:
         https://github.com/isayev/ASE_ANI/blob/master/lib/pyanitools.py
     """
-
-    def __init__(self, path, species_tensor_converter, batch_size,
-                 shuffle=True, properties=('energies',), atomic_properties=(), transform=(),
-                 dtype=torch.get_default_dtype(), device=default_device):
-        super(BatchedANIDataset, self).__init__()
-        self.properties = properties
-        self.atomic_properties = atomic_properties
-        self.device = device
-        self.dtype = dtype
-
-        atomic_properties, properties = load_and_pad_whole_dataset(
-            path, species_tensor_converter, shuffle, properties, atomic_properties)
-
-        # do transformations on data
-        for t in transform:
-            atomic_properties, properties = t(atomic_properties, properties)
-
-        # convert to desired dtype
-        for k in properties:
-            properties[k] = properties[k].to(dtype)
-        for k in atomic_properties:
-            if k == 'species':
-                continue
-            atomic_properties[k] = atomic_properties[k].to(dtype)
-
-        self.batches = split_whole_into_batches_and_chunks(atomic_properties, properties, batch_size)
-
-    def __getitem__(self, idx):
-        atomic_properties, properties = self.batches[idx]
-        atomic_properties, properties = atomic_properties.copy(), properties.copy()
-        species_coordinates = []
-        for chunk in atomic_properties:
-            for k in chunk:
-                chunk[k] = chunk[k].to(self.device)
-            species_coordinates.append((chunk['species'], chunk['coordinates']))
-        for k in properties:
-            properties[k] = properties[k].to(self.device)
-        properties['atomic'] = atomic_properties
-        return species_coordinates, properties
-
-    def __len__(self):
-        return len(self.batches)
+    atomic_properties_, properties_ = load_and_pad_whole_dataset(
+        path, species_tensor_converter, shuffle, properties, atomic_properties)
+
+    # do transformations on data
+    for t in transform:
+        atomic_properties_, properties_ = t(atomic_properties_, properties_)
+
+    molecules = atomic_properties_['species'].shape[0]
+    atomic_keys = ['species', 'coordinates', *atomic_properties]
+    keys = properties
+
+    # compute the size of each subset
+    split_ = []
+    total = 0
+    for index, size in enumerate(split):
+        if isinstance(size, float):
+            size = int(size * molecules)
+        if size is None:
+            assert index == len(split) - 1
+            size = molecules - total
+        split_.append(size)
+        total += size
+
+    # split
+    start = 0
+    splitted = []
+    for size in split_:
+        ap = {k: atomic_properties_[k][start:start + size] for k in atomic_keys}
+        p = {k: properties_[k][start:start + size] for k in keys}
+        start += size
+        splitted.append((ap, p))
+
+    # construct batched datasets
+    ret = []
+    for ap, p in splitted:
+        ds = PaddedBatchChunkDataset(ap, p, batch_size, dtype, device)
+        ds.properties = properties
+        ds.atomic_properties = atomic_properties
+        ret.append(ds)
+    if len(ret) == 1:
+        return ret[0]
+    return tuple(ret)


 class AEVCacheLoader(Dataset):
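To illustrate the ``split`` semantics documented above, here is a hypothetical three-way split (the file path and sizes are made up for illustration; ``species_to_tensor`` is the converter used in the examples):

    train, val, test = torchani.data.load_ani_dataset(
        '/path/to/dataset.h5', species_to_tensor, batch_size=2560,
        split=(5000, 0.1, None))  # 5000 molecules, then 10% of the whole set, then the rest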
@@ -343,13 +417,23 @@ class SparseAEVCacheLoader(AEVCacheLoader):
 builtin = neurochem.Builtins()


-def create_aev_cache(dataset, aev_computer, output, enable_tqdm=True, encoder=lambda x: x):
+def create_aev_cache(dataset, aev_computer, output, progress_bar=True, encoder=lambda *x: x):
+    """Cache AEVs for the given dataset.
+
+    Arguments:
+        dataset (:class:`torchani.data.PaddedBatchChunkDataset`): the dataset to be cached
+        aev_computer (:class:`torchani.AEVComputer`): the AEV computer used to compute AEVs
+        output (str): path to the directory where the cache will be stored
+        progress_bar (bool): whether to show a progress bar
+        encoder (:class:`collections.abc.Callable`): a callable
+            ``(species, aev) -> (encoded_species, encoded_aev)`` that encodes species and AEVs
+    """
     # dump out the dataset
     filename = os.path.join(output, 'dataset')
     with open(filename, 'wb') as f:
         pickle.dump(dataset, f)
-    if enable_tqdm:
+    if progress_bar:
         import tqdm
         indices = tqdm.trange(len(dataset))
     else:
......
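A minimal usage sketch for the renamed arguments of ``create_aev_cache`` (the output directory name is hypothetical; ``builtin`` is the ``neurochem.Builtins()`` instance created in this module, and ``training`` is a dataset returned by ``load_ani_dataset``):

    # precompute and cache AEVs so training can skip the AEV computation
    torchani.data.create_aev_cache(training, builtin.aev_computer, 'aev_cache/',
                                   progress_bar=True)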