Commit a41cc8b7 authored by Ignacio Pickering's avatar Ignacio Pickering Committed by Gao, Xiang
Browse files

Replace BatchedANIDataset (#272)

parent 1888d734
......@@ -20,7 +20,7 @@ class TestIgnite(unittest.TestCase):
aev_computer = ani1x.aev_computer
nnp = copy.deepcopy(ani1x.neural_networks[0])
shift_energy = ani1x.energy_shifter
ds = torchani.data.BatchedANIDataset(
ds = torchani.data.load_ani_dataset(
path, ani1x.consts.species_to_tensor, batchsize,
transform=[shift_energy.subtract_from_dataset],
device=aev_computer.EtaR.device)
......
......@@ -47,7 +47,7 @@ container = container.to(device)
if parser.dataset_path.endswith('.h5') or \
parser.dataset_path.endswith('.hdf5') or \
os.path.isdir(parser.dataset_path):
dataset = torchani.data.BatchedANIDataset(
dataset = torchani.data.load_ani_dataset(
parser.dataset_path, consts.species_to_tensor, parser.batch_size,
device=device, transform=[shift_energy.subtract_from_dataset])
datasets = [dataset]
......
......@@ -49,7 +49,7 @@ class Flatten(torch.nn.Module):
nnp = torch.nn.Sequential(aev_computer, model, Flatten()).to(device)
dataset = torchani.data.BatchedANIDataset(
dataset = torchani.data.load_ani_dataset(
parser.dataset_path, consts.species_to_tensor,
parser.batch_size, device=device,
transform=[shift_energy.subtract_from_dataset])
......
......@@ -483,7 +483,7 @@ def _cache_aev(output, dataset_path, batchsize, device, constfile,
else:
transform = ()
dataset = BatchedANIDataset(
dataset = load_ani_dataset(
dataset_path, consts.species_to_tensor, batchsize,
device=device, transform=transform, **kwargs
)
......@@ -507,4 +507,4 @@ def cache_sparse_aev(output, dataset_path, batchsize, device=default_device,
SparseAEVCacheLoader.encode_aev, **kwargs)
__all__ = ['BatchedANIDataset', 'AEVCacheLoader', 'SparseAEVCacheLoader', 'cache_aev', 'cache_sparse_aev']
__all__ = ['load_ani_dataset', 'BatchedANIDataset', 'AEVCacheLoader', 'SparseAEVCacheLoader', 'cache_aev', 'cache_sparse_aev']
......@@ -10,7 +10,7 @@ from ignite.contrib.metrics.regression import MaximumAbsoluteError
class Container(torch.nn.ModuleDict):
"""Each minibatch is split into chunks, as explained in the docstring of
:class:`torchani.data.BatchedANIDataset`, as a result, it is impossible to
:func:`torchani.data.load_ani_dataset`, as a result, it is impossible to
use :class:`torchani.AEVComputer`, :class:`torchani.ANIModel` directly with
ignite. This class is designed to solve this issue.
......
......@@ -287,7 +287,7 @@ if sys.version_info[0] > 2:
try:
import ignite
from ..ignite import Container, MSELoss, TransformedLoss, RMSEMetric, MAEMetric, MaxAEMetric
from ..data import BatchedANIDataset # noqa: E402
from ..data import load_ani_dataset # noqa: E402
from ..data import AEVCacheLoader # noqa: E402
except ImportError:
raise RuntimeError(
......@@ -306,7 +306,7 @@ if sys.version_info[0] > 2:
self.imports.RMSEMetric = RMSEMetric
self.imports.MaxAEMetric = MaxAEMetric
self.imports.MAEMetric = MAEMetric
self.imports.BatchedANIDataset = BatchedANIDataset
self.imports.load_ani_dataset = load_ani_dataset
self.imports.AEVCacheLoader = AEVCacheLoader
self.warned = False
......@@ -596,11 +596,11 @@ if sys.version_info[0] > 2:
self.training_set = self.imports.AEVCacheLoader(training_path)
self.validation_set = self.imports.AEVCacheLoader(validation_path)
else:
self.training_set = self.imports.BatchedANIDataset(
self.training_set = self.imports.load_ani_dataset(
training_path, self.consts.species_to_tensor,
self.training_batch_size, device=self.device,
transform=[self.shift_energy.subtract_from_dataset])
self.validation_set = self.imports.BatchedANIDataset(
self.validation_set = self.imports.load_ani_dataset(
validation_path, self.consts.species_to_tensor,
self.validation_batch_size, device=self.device,
transform=[self.shift_energy.subtract_from_dataset])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment