Unverified Commit 14a62dc4 authored by Gao, Xiang, committed by GitHub

Refactor BatchedANIDataset and create API for splitting datasets (#237)

parent 0ce36d82
......@@ -24,6 +24,8 @@ Datasets
========
.. automodule:: torchani.data
.. autofunction:: torchani.data.load_ani_dataset
.. autofunction:: torchani.data.create_aev_cache
.. autoclass:: torchani.data.BatchedANIDataset
.. autoclass:: torchani.data.AEVCacheLoader
.. automodule:: torchani.data.cache_aev
......
......@@ -62,14 +62,10 @@ species_to_tensor = torchani.utils.ChemicalSymbolsToInts('HCNO')
###############################################################################
# Now let's set up the datasets. Note that, for demo purposes, we use the same
# file ``ani_gdb_s01.h5`` from TorchANI's repository as both the training set
# and the validation set. This lets the program finish very quickly because
# that dataset is very small, but it is wrong and should be avoided for any
# serious training. These paths assume the user runs this script under the
# ``examples`` directory of TorchANI's repository. If you download this script,
# you should manually set the paths of these files on your system before this
# script can run successfully.
# Now let's set up the datasets. These paths assume the user runs this script
# under the ``examples`` directory of TorchANI's repository. If you download
# this script, you should manually set the paths of these files on your system
# before this script can run successfully.
#
# Also note that we need to subtract the self energies of all atoms from the
# energy of each molecule. This makes the range of energies in a reasonable
......@@ -81,18 +77,13 @@ try:
path = os.path.dirname(os.path.realpath(__file__))
except NameError:
path = os.getcwd()
training_path = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
validation_path = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
batch_size = 2560
training = torchani.data.BatchedANIDataset(
training_path, species_to_tensor, batch_size, device=device,
transform=[energy_shifter.subtract_from_dataset])
validation = torchani.data.BatchedANIDataset(
validation_path, species_to_tensor, batch_size, device=device,
transform=[energy_shifter.subtract_from_dataset])
training, validation = torchani.data.load_ani_dataset(
dspath, species_to_tensor, batch_size, device=device,
transform=[energy_shifter.subtract_from_dataset], split=[0.8, None])
###############################################################################
# When iterating the dataset, we will get pairs of input and output
......
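For readers updating their own scripts, a minimal sketch of consuming the batched dataset returned above; the loop body and the ``model`` object are illustrative assumptions, not part of this commit:

# A hedged sketch: `training` comes from load_ani_dataset as above, and `model`
# is assumed to be an energy model that maps (species, coordinates) to
# (species, energies), as in TorchANI's training examples.
import torch

for species_coordinates, batch_y in training:
    true_energies = batch_y['energies']
    predicted_energies = []
    # each batch is split into chunks of similarly sized molecules
    for chunk_species, chunk_coordinates in species_coordinates:
        _, chunk_energies = model((chunk_species, chunk_coordinates))
        predicted_energies.append(chunk_energies)
    predicted_energies = torch.cat(predicted_energies)
    loss = torch.nn.functional.mse_loss(predicted_energies, true_energies)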
......@@ -46,25 +46,19 @@ try:
path = os.path.dirname(os.path.realpath(__file__))
except NameError:
path = os.getcwd()
training_path = os.path.join(path, '../dataset/ani-1x/sample.h5')
validation_path = os.path.join(path, '../dataset/ani-1x/sample.h5')
dspath = os.path.join(path, '../dataset/ani-1x/sample.h5')
batch_size = 2560
###############################################################################
# The code to create the dataset is a bit different: we need to manually
# specify that ``atomic_properties=['forces']`` so that forces will be read
# from hdf5 files.
training = torchani.data.BatchedANIDataset(
training_path, species_to_tensor, batch_size, device=device,
atomic_properties=['forces'],
transform=[energy_shifter.subtract_from_dataset])
validation = torchani.data.BatchedANIDataset(
validation_path, species_to_tensor, batch_size, device=device,
training, validation = torchani.data.load_ani_dataset(
dspath, species_to_tensor, batch_size, device=device,
atomic_properties=['forces'],
transform=[energy_shifter.subtract_from_dataset])
transform=[energy_shifter.subtract_from_dataset], split=[0.8, None])
###############################################################################
# When iterating the dataset, we will get pairs of input and output
......
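For illustration, a hedged sketch of how the forces loaded above can be read back when iterating the dataset; ``model`` and the force computation are assumptions, not part of this commit:

# batch_y['atomic'] holds one dict of atomic properties per chunk, in the same
# order as the (species, coordinates) chunks of the input.
import torch

for species_coordinates, batch_y in training:
    for chunk, (chunk_species, chunk_coordinates) in zip(batch_y['atomic'], species_coordinates):
        true_forces = chunk['forces']
        chunk_coordinates.requires_grad_(True)
        _, chunk_energies = model((chunk_species, chunk_coordinates))
        # forces are the negative gradient of the energies w.r.t. coordinates
        predicted_forces = -torch.autograd.grad(
            chunk_energies.sum(), chunk_coordinates, create_graph=True)[0]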
......@@ -32,8 +32,7 @@ try:
path = os.path.dirname(os.path.realpath(__file__))
except NameError:
path = os.getcwd()
training_path = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
validation_path = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5') # noqa: E501
dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
# checkpoint file to save model when validation RMSE improves
model_checkpoint = 'model.pt'
......@@ -102,13 +101,9 @@ writer = torch.utils.tensorboard.SummaryWriter(log_dir=log)
###############################################################################
# Now load training and validation datasets into memory.
training = torchani.data.BatchedANIDataset(
training_path, consts.species_to_tensor, batch_size, device=device,
transform=[energy_shifter.subtract_from_dataset])
validation = torchani.data.BatchedANIDataset(
validation_path, consts.species_to_tensor, batch_size, device=device,
transform=[energy_shifter.subtract_from_dataset])
training, validation = torchani.data.load_ani_dataset(
dspath, consts.species_to_tensor, batch_size, device=device,
transform=[energy_shifter.subtract_from_dataset], split=[0.8, None])
###############################################################################
# We have tools to deal with the chunking (see :ref:`training-example`). These
......
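For illustration, a hedged sketch of evaluating the validation split produced by the new call; ``model`` and the helper name ``validate`` are assumptions, not part of this commit:

import math
import torch

def validate():
    # RMSE over the validation set, in the same units as the stored energies
    mse_sum = torch.nn.MSELoss(reduction='sum')
    total_mse = 0.0
    count = 0
    for species_coordinates, batch_y in validation:
        true_energies = batch_y['energies']
        predicted_energies = []
        for chunk_species, chunk_coordinates in species_coordinates:
            _, chunk_energies = model((chunk_species, chunk_coordinates))
            predicted_energies.append(chunk_energies)
        predicted_energies = torch.cat(predicted_energies)
        total_mse += mse_sum(predicted_energies, true_energies).item()
        count += predicted_energies.shape[0]
    return math.sqrt(total_mse / count)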
......@@ -16,7 +16,7 @@ aev_computer = builtins.aev_computer
class TestData(unittest.TestCase):
def setUp(self):
self.ds = torchani.data.BatchedANIDataset(dataset_path,
self.ds = torchani.data.load_ani_dataset(dataset_path,
consts.species_to_tensor,
batch_size)
......
......@@ -10,6 +10,7 @@ from .. import utils, neurochem, aev
import pickle
import numpy as np
from scipy.sparse import bsr_matrix
import warnings
default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
......@@ -159,11 +160,69 @@ def split_whole_into_batches_and_chunks(atomic_properties, properties, batch_siz
return batches
class BatchedANIDataset(Dataset):
"""Load data from hdf5 files, create minibatches, and convert to tensors.
class PaddedBatchChunkDataset(Dataset):
This is already a dataset of batches, so when iterated, a batch rather
than a single data point will be yielded.
def __init__(self, atomic_properties, properties, batch_size,
dtype=torch.get_default_dtype(), device=default_device):
super().__init__()
self.device = device
self.dtype = dtype
# convert to desired dtype
for k in properties:
properties[k] = properties[k].to(dtype)
for k in atomic_properties:
if k == 'species':
continue
atomic_properties[k] = atomic_properties[k].to(dtype)
self.batches = split_whole_into_batches_and_chunks(atomic_properties, properties, batch_size)
def __getitem__(self, idx):
atomic_properties, properties = self.batches[idx]
atomic_properties, properties = atomic_properties.copy(), properties.copy()
species_coordinates = []
for chunk in atomic_properties:
for k in chunk:
chunk[k] = chunk[k].to(self.device)
species_coordinates.append((chunk['species'], chunk['coordinates']))
for k in properties:
properties[k] = properties[k].to(self.device)
properties['atomic'] = atomic_properties
return species_coordinates, properties
def __len__(self):
return len(self.batches)
class BatchedANIDataset(PaddedBatchChunkDataset):
"""Same as :func:`torchani.data.load_ani_dataset`. This API has been deprecated."""
def __init__(self, path, species_tensor_converter, batch_size,
shuffle=True, properties=('energies',), atomic_properties=(), transform=(),
dtype=torch.get_default_dtype(), device=default_device):
self.properties = properties
self.atomic_properties = atomic_properties
warnings.warn("BatchedANIDataset is deprecated; use load_ani_dataset()", DeprecationWarning)
atomic_properties, properties = load_and_pad_whole_dataset(
path, species_tensor_converter, shuffle, properties, atomic_properties)
# do transformations on data
for t in transform:
atomic_properties, properties = t(atomic_properties, properties)
super().__init__(atomic_properties, properties, batch_size, dtype, device)
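For illustration, the drop-in migration implied by the deprecation above; ``path``, ``species_to_tensor`` and ``batch_size`` are placeholders:

# old API, still works but now emits a DeprecationWarning:
# ds = torchani.data.BatchedANIDataset(path, species_to_tensor, batch_size)

# new API; with the default split=(None,) a single dataset is returned,
# so existing call sites can switch without further changes:
ds = torchani.data.load_ani_dataset(path, species_to_tensor, batch_size)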
def load_ani_dataset(path, species_tensor_converter, batch_size, shuffle=True,
properties=('energies',), atomic_properties=(), transform=(),
dtype=torch.get_default_dtype(), device=default_device,
split=(None,)):
"""Load ANI dataset from hdf5 files, and split into subsets.
The returned datasets are already datasets of batches, so when iterated, a
batch rather than a single data point will be yielded.
Since each batch might contain molecules of very different sizes, putting
the whole batch into a single tensor would require adding ghost atoms to
......@@ -217,52 +276,67 @@ class BatchedANIDataset(Dataset):
dtype (:class:`torch.dtype`): dtype of coordinates and properties to
convert the dataset to.
device (:class:`torch.device`): device to put tensors on when iterating.
split (list): a sequence of integers, floats, or ``None``. Integers
are interpreted as numbers of elements, floats are interpreted as
fractions of the whole dataset, and ``None`` is interpreted as the
rest of the dataset and can only appear as the last element of
:attr:`split`. For example, if the whole dataset has 10000 entries
and split is ``(5000, 0.1, None)``, then this function will create
3 datasets, where the first dataset contains 5000 elements, the
second dataset contains ``int(0.1 * 10000)``, which is 1000, and the
third dataset contains the remaining ``10000 - 5000 - 1000 = 4000``
elements. By default this creates only a single dataset.
Returns:
An instance of :class:`torchani.data.PaddedBatchChunkDataset` if there is
only one element in :attr:`split`, otherwise a tuple of such datasets,
one for each element of :attr:`split`.
.. _pyanitools.py:
https://github.com/isayev/ASE_ANI/blob/master/lib/pyanitools.py
"""
def __init__(self, path, species_tensor_converter, batch_size,
shuffle=True, properties=('energies',), atomic_properties=(), transform=(),
dtype=torch.get_default_dtype(), device=default_device):
super(BatchedANIDataset, self).__init__()
self.properties = properties
self.atomic_properties = atomic_properties
self.device = device
self.dtype = dtype
atomic_properties, properties = load_and_pad_whole_dataset(
atomic_properties_, properties_ = load_and_pad_whole_dataset(
path, species_tensor_converter, shuffle, properties, atomic_properties)
# do transformations on data
for t in transform:
atomic_properties, properties = t(atomic_properties, properties)
# convert to desired dtype
for k in properties:
properties[k] = properties[k].to(dtype)
for k in atomic_properties:
if k == 'species':
continue
atomic_properties[k] = atomic_properties[k].to(dtype)
self.batches = split_whole_into_batches_and_chunks(atomic_properties, properties, batch_size)
def __getitem__(self, idx):
atomic_properties, properties = self.batches[idx]
atomic_properties, properties = atomic_properties.copy(), properties.copy()
species_coordinates = []
for chunk in atomic_properties:
for k in chunk:
chunk[k] = chunk[k].to(self.device)
species_coordinates.append((chunk['species'], chunk['coordinates']))
for k in properties:
properties[k] = properties[k].to(self.device)
properties['atomic'] = atomic_properties
return species_coordinates, properties
def __len__(self):
return len(self.batches)
atomic_properties_, properties_ = t(atomic_properties_, properties_)
molecules = atomic_properties_['species'].shape[0]
atomic_keys = ['species', 'coordinates', *atomic_properties]
keys = properties
# compute size of each subset
split_ = []
total = 0
for index, size in enumerate(split):
if isinstance(size, float):
size = int(size * molecules)
if size is None:
assert index == len(split) - 1
size = molecules - total
split_.append(size)
total += size
# split
start = 0
splitted = []
for size in split_:
ap = {k: atomic_properties_[k][start:start + size] for k in atomic_keys}
p = {k: properties_[k][start:start + size] for k in keys}
start += size
splitted.append((ap, p))
# construct batched datasets
ret = []
for ap, p in splitted:
ds = PaddedBatchChunkDataset(ap, p, batch_size, dtype, device)
ds.properties = properties
ds.atomic_properties = atomic_properties
ret.append(ds)
if len(ret) == 1:
return ret[0]
return tuple(ret)
class AEVCacheLoader(Dataset):
......@@ -343,13 +417,23 @@ class SparseAEVCacheLoader(AEVCacheLoader):
builtin = neurochem.Builtins()
def create_aev_cache(dataset, aev_computer, output, enable_tqdm=True, encoder=lambda x: x):
def create_aev_cache(dataset, aev_computer, output, progress_bar=True, encoder=lambda *x: x):
"""Cache AEV for the given dataset.
Arguments:
dataset (:class:`torchani.data.PaddedBatchChunkDataset`): the dataset to be cached
aev_computer (:class:`torchani.AEVComputer`): the AEV computer used to compute aev
output (str): path to the directory where cache will be stored
progress_bar (bool): whether to show progress bar
encoder (:class:`collections.abc.Callable`): the callable
(species, aev) -> (encoded_species, encoded_aev) that encodes species and aev
"""
# dump out the dataset
filename = os.path.join(output, 'dataset')
with open(filename, 'wb') as f:
pickle.dump(dataset, f)
if enable_tqdm:
if progress_bar:
import tqdm
indices = tqdm.trange(len(dataset))
else:
......
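For illustration, a hedged sketch of how the documented arguments of create_aev_cache fit together; the dataset path, the cache directory, and its creation are assumptions, not part of this commit:

import os
import torchani

builtin = torchani.neurochem.Builtins()
dataset = torchani.data.load_ani_dataset(
    'dataset/ani_gdb_s01.h5', builtin.consts.species_to_tensor, 2560)

output = 'aev_cache'
os.makedirs(output, exist_ok=True)  # assuming the cache directory must already exist
torchani.data.create_aev_cache(dataset, builtin.aev_computer, output,
                               progress_bar=True)
# the cached AEVs can later be loaded by pointing torchani.data.AEVCacheLoader
# at the same directory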