Unverified Commit d10c7841 authored by Ignacio Pickering's avatar Ignacio Pickering Committed by GitHub
Browse files

Remove deprecated api (#428)

* Remove deprecated API and add a docstring to PaddedBatchChunkDataset

* Remove reference to deprecated API

* Added PaddedBatchChunkDataset to docs

* remove unused warnings
parent 8292fa97
...@@ -31,7 +31,7 @@ Datasets ...@@ -31,7 +31,7 @@ Datasets
.. autoclass:: torchani.data.CachedDataset .. autoclass:: torchani.data.CachedDataset
:members: :members:
.. autofunction:: torchani.data.load_ani_dataset .. autofunction:: torchani.data.load_ani_dataset
.. autoclass:: torchani.data.BatchedANIDataset .. autoclass:: torchani.data.PaddedBatchChunkDataset
......
...@@ -7,7 +7,6 @@ import os ...@@ -7,7 +7,6 @@ import os
from ._pyanitools import anidataloader from ._pyanitools import anidataloader
import torch import torch
from .. import utils from .. import utils
import warnings
from .new import CachedDataset, ShuffledDataset, find_threshold from .new import CachedDataset, ShuffledDataset, find_threshold
default_device = 'cuda' if torch.cuda.is_available() else 'cpu' default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
...@@ -159,6 +158,13 @@ def split_whole_into_batches_and_chunks(atomic_properties, properties, batch_siz ...@@ -159,6 +158,13 @@ def split_whole_into_batches_and_chunks(atomic_properties, properties, batch_siz
class PaddedBatchChunkDataset(Dataset): class PaddedBatchChunkDataset(Dataset):
r""" Dataset that contains batches in 'chunks', with padded structures
This dataset acts as a container of batches to be used when training. Each
of the batches is broken up into 'chunks', each of which is a tensor has
molecules with a smiliar number of atoms, but which have been padded with
dummy atoms in order for them to have the same tensor dimensions.
"""
def __init__(self, atomic_properties, properties, batch_size, def __init__(self, atomic_properties, properties, batch_size,
dtype=torch.get_default_dtype(), device=default_device): dtype=torch.get_default_dtype(), device=default_device):
...@@ -193,26 +199,6 @@ class PaddedBatchChunkDataset(Dataset): ...@@ -193,26 +199,6 @@ class PaddedBatchChunkDataset(Dataset):
return len(self.batches) return len(self.batches)
class BatchedANIDataset(PaddedBatchChunkDataset):
    """Same as :func:`torchani.data.load_ani_dataset`. This API has been deprecated."""

    def __init__(self, path, species_tensor_converter, batch_size,
                 shuffle=True, properties=('energies',), atomic_properties=(), transform=(),
                 dtype=torch.get_default_dtype(), device=default_device):
        # Keep the originally requested property names; the local names are
        # rebound below once the dataset has actually been loaded.
        self.properties = properties
        self.atomic_properties = atomic_properties

        warnings.warn("BatchedANIDataset is deprecated; use load_ani_dataset()", DeprecationWarning)
        # Load and pad the whole dataset from disk, then run every
        # user-supplied transform over it, in the order given.
        loaded_atomic, loaded_props = load_and_pad_whole_dataset(
            path, species_tensor_converter, shuffle, properties, atomic_properties)
        for transform_fn in transform:
            loaded_atomic, loaded_props = transform_fn(loaded_atomic, loaded_props)
        # Delegate batching/chunking to the parent container.
        super().__init__(loaded_atomic, loaded_props, batch_size, dtype, device)
def load_ani_dataset(path, species_tensor_converter, batch_size, shuffle=True, def load_ani_dataset(path, species_tensor_converter, batch_size, shuffle=True,
rm_outlier=False, properties=('energies',), atomic_properties=(), rm_outlier=False, properties=('energies',), atomic_properties=(),
transform=(), dtype=torch.get_default_dtype(), device=default_device, transform=(), dtype=torch.get_default_dtype(), device=default_device,
...@@ -361,4 +347,4 @@ def load_ani_dataset(path, species_tensor_converter, batch_size, shuffle=True, ...@@ -361,4 +347,4 @@ def load_ani_dataset(path, species_tensor_converter, batch_size, shuffle=True,
return tuple(ret) return tuple(ret)
__all__ = ['load_ani_dataset', 'BatchedANIDataset', 'CachedDataset', 'ShuffledDataset', 'find_threshold'] __all__ = ['load_ani_dataset', 'PaddedBatchChunkDataset', 'CachedDataset', 'ShuffledDataset', 'find_threshold']
...@@ -172,9 +172,7 @@ class EnergyShifter(torch.nn.Module): ...@@ -172,9 +172,7 @@ class EnergyShifter(torch.nn.Module):
return self_energies.sum(dim=1) + intercept return self_energies.sum(dim=1) + intercept
def subtract_from_dataset(self, atomic_properties, properties): def subtract_from_dataset(self, atomic_properties, properties):
"""Transformer for :class:`torchani.data.BatchedANIDataset` that """Transformer that subtracts self energies from a dataset"""
subtract self energies.
"""
if self.self_energies is None: if self.self_energies is None:
self_energies = self.sae_from_dataset(atomic_properties, properties) self_energies = self.sae_from_dataset(atomic_properties, properties)
self.self_energies = torch.tensor(self_energies, dtype=torch.double) self.self_energies = torch.tensor(self_energies, dtype=torch.double)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment