Unverified Commit 86500df0 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Remove AEV cacher (#361)

* Remove AEV cacher

* more

* more

* more

* flake8

* further cleanup
parent 7cdd405c
......@@ -32,6 +32,5 @@ jobs:
run: ./download.sh
- name: Run submodules
run: |
python -m torchani.data.cache_aev tmp dataset/ani1-up_to_gdb4/ani_gdb_s01.h5 256
python -m torchani.neurochem.trainer --tqdm tests/test_data/inputtrain.ipt dataset/ani1-up_to_gdb4/ani_gdb_s01.h5 dataset/ani1-up_to_gdb4/ani_gdb_s01.h5
python -m torchani.neurochem.trainer --tqdm tests/test_data/inputtrain.yaml dataset/ani1-up_to_gdb4/ani_gdb_s01.h5 dataset/ani1-up_to_gdb4/ani_gdb_s01.h5
......@@ -29,10 +29,8 @@ Datasets
.. autoclass:: torchani.data.CachedDataset
:members:
.. autofunction:: torchani.data.load_ani_dataset
.. autofunction:: torchani.data.create_aev_cache
.. autoclass:: torchani.data.BatchedANIDataset
.. autoclass:: torchani.data.AEVCacheLoader
.. automodule:: torchani.data.cache_aev
Utilities
......
......@@ -2,7 +2,6 @@ import os
import torch
import torchani
import unittest
from torchani.data.cache_aev import cache_aev, cache_sparse_aev
path = os.path.dirname(os.path.realpath(__file__))
dataset_path = os.path.join(path, '../dataset/ani1-up_to_gdb4')
......@@ -87,38 +86,6 @@ class TestData(unittest.TestCase):
non_padding = (species >= 0)[:, -1].nonzero()
self.assertGreater(non_padding.numel(), 0)
def testAEVCacheLoader(self):
    """Dense AEV cache must round-trip: AEVs loaded from disk agree with
    AEVs recomputed from the corresponding coordinates."""
    cache_dir = os.path.join(os.getcwd(), 'tmp')
    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)
    cache_aev(cache_dir, dataset_path2, 64, enable_tqdm=False)
    loader = torchani.data.AEVCacheLoader(cache_dir)
    dataset = loader.dataset
    computer = aev_computer.to(dataset.device)
    # Repeat a few epochs to make sure the cache is stable across passes.
    for _ in range(3):
        for (cached_chunks, _), (coord_chunks, _) in zip(loader, dataset):
            for (species_cached, aev_cached), (species_ref, coords) in zip(cached_chunks, coord_chunks):
                self._assertTensorEqual(species_cached, species_ref)
                species_computed, aev_computed = computer((species_ref, coords))
                self._assertTensorEqual(species_cached, species_computed)
                self._assertTensorEqual(aev_cached, aev_computed)
def testSparseAEVCacheLoader(self):
    """Sparse AEV cache must round-trip: densified AEVs loaded from disk
    agree with AEVs recomputed from the corresponding coordinates."""
    cache_dir = os.path.join(os.getcwd(), 'tmp')
    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)
    cache_sparse_aev(cache_dir, dataset_path2, 64, enable_tqdm=False)
    loader = torchani.data.SparseAEVCacheLoader(cache_dir)
    dataset = loader.dataset
    computer = aev_computer.to(dataset.device)
    # Repeat a few epochs to make sure the cache is stable across passes.
    for _ in range(3):
        for (cached_chunks, _), (coord_chunks, _) in zip(loader, dataset):
            for (species_cached, aev_cached), (species_ref, coords) in zip(cached_chunks, coord_chunks):
                self._assertTensorEqual(species_cached, species_ref)
                species_computed, aev_computed = computer((species_ref, coords))
                self._assertTensorEqual(species_cached, species_computed)
                self._assertTensorEqual(aev_cached, aev_computed)
# Allow running this test module directly (in addition to test discovery).
if __name__ == '__main__':
    unittest.main()
......@@ -6,10 +6,7 @@ from os.path import join, isfile, isdir
import os
from ._pyanitools import anidataloader
import torch
from .. import utils, neurochem, aev, models
import pickle
import numpy as np
from scipy.sparse import bsr_matrix
from .. import utils
import warnings
from .new import CachedDataset, ShuffledDataset, find_threshold
......@@ -364,153 +361,4 @@ def load_ani_dataset(path, species_tensor_converter, batch_size, shuffle=True,
return tuple(ret)
class AEVCacheLoader(Dataset):
    """Dataset wrapper that serves precomputed AEVs from a disk cache.

    Computing AEVs is the most time-consuming part of training, yet the
    AEVs never change between epochs.  With fast enough storage (usually
    the case for good SSDs, but not for HDDs), it is cheaper to read the
    cached AEVs back from disk than to recompute them every time.

    Arguments:
        disk_cache (str): directory holding the cache files.
    """

    def __init__(self, disk_cache=None):
        super(AEVCacheLoader, self).__init__()
        self.disk_cache = disk_cache
        # The pickled dataset object sits next to the per-batch AEV files.
        dataset_path = os.path.join(disk_cache, 'dataset')
        with open(dataset_path, 'rb') as f:
            self.dataset = pickle.load(f)

    def __getitem__(self, index):
        # Outputs come from the pickled dataset; inputs from the AEV cache.
        _, output = self.dataset.batches[index]
        with open(os.path.join(self.disk_cache, str(index)), 'rb') as f:
            species_aevs = pickle.load(f)
        device = self.dataset.device
        for position, encoded in enumerate(species_aevs):
            species, aevs = self.decode_aev(*encoded)
            species_aevs[position] = (species.to(device), aevs.to(device))
        return species_aevs, output

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def decode_aev(encoded_species, encoded_aev):
        # Dense cache: values were pickled as-is, so decoding is identity.
        return encoded_species, encoded_aev

    @staticmethod
    def encode_aev(species, aev):
        # Dense cache: values are pickled without transformation.
        return species, aev
class SparseAEVCacheLoader(AEVCacheLoader):
    """AEV cache loader backed by a sparse on-disk representation.

    AEVs never change during training and contain a large number of
    zeros, so storing them as sparse matrices (``cache_sparse_aev``)
    needs considerably less disk space than the dense ``cache_aev``.

    Arguments:
        disk_cache (str): directory holding the cache files.
    """

    @staticmethod
    def decode_aev(encoded_species, encoded_aev):
        # Densify the stored sparse matrices back into torch tensors.
        species = torch.from_numpy(encoded_species.todense())
        dense = np.stack([np.array(m.todense()) for m in encoded_aev], axis=0)
        return species, torch.from_numpy(dense)

    @staticmethod
    def encode_aev(species, aev):
        # Store species and each molecule's AEV matrix as sparse BSR matrices.
        encoded_species = bsr_matrix(species.cpu().numpy())
        encoded_aev = [bsr_matrix(row.cpu().numpy()) for row in aev]
        return encoded_species, encoded_aev
# Instantiated at import time so its constant/SAE file paths can serve as
# argument defaults for cache_aev/cache_sparse_aev below.
ani1x = models.ANI1x()
def create_aev_cache(dataset, aev_computer, output, progress_bar=True, encoder=lambda *x: x):
    """Precompute and store AEVs for every batch of a dataset.

    The dataset object itself is pickled to ``<output>/dataset``; the
    (optionally encoded) AEVs of batch ``i`` are pickled to ``<output>/<i>``.

    Arguments:
        dataset (:class:`torchani.data.PaddedBatchChunkDataset`): the dataset to be cached
        aev_computer (:class:`torchani.AEVComputer`): the AEV computer used to compute aev
        output (str): path to the directory where cache will be stored
        progress_bar (bool): whether to show progress bar
        encoder (:class:`collections.abc.Callable`): The callable
            (species, aev) -> (encoded_species, encoded_aev) that encode species and aev
    """
    # Persist the dataset object so a cache loader can recover the outputs
    # (and any metadata) later without re-reading the raw data.
    with open(os.path.join(output, 'dataset'), 'wb') as f:
        pickle.dump(dataset, f)

    if progress_bar:
        import tqdm
        index_iter = tqdm.trange(len(dataset))
    else:
        index_iter = range(len(dataset))

    for index in index_iter:
        input_, _ = dataset[index]
        encoded = [encoder(*aev_computer(chunk)) for chunk in input_]
        with open(os.path.join(output, str(index)), 'wb') as f:
            pickle.dump(encoded, f)
def _cache_aev(output, dataset_path, batchsize, device, constfile,
               subtract_sae, sae_file, enable_tqdm, encoder, **kwargs):
    """Shared implementation behind :func:`cache_aev` and :func:`cache_sparse_aev`."""
    # Create the output directory on demand.
    if not os.path.exists(output):
        os.makedirs(output)

    device = torch.device(device)
    consts = neurochem.Constants(constfile)
    aev_computer = aev.AEVComputer(**consts).to(device)

    # Optionally subtract self atomic energies from the dataset energies
    # before caching.
    transform = ()
    if subtract_sae:
        energy_shifter = neurochem.load_sae(sae_file)
        transform = (energy_shifter.subtract_from_dataset,)

    dataset = load_ani_dataset(
        dataset_path, consts.species_to_tensor, batchsize,
        device=device, transform=transform, **kwargs
    )
    create_aev_cache(dataset, aev_computer, output, enable_tqdm, encoder)
def cache_aev(output, dataset_path, batchsize, device=default_device,
              constfile=ani1x.const_file, subtract_sae=False,
              sae_file=ani1x.sae_file, enable_tqdm=True, **kwargs):
    """Cache dense AEVs for the dataset at ``dataset_path`` into ``output``."""
    # Dense variant: pickle species/AEV tensors unchanged.
    _cache_aev(output, dataset_path, batchsize, device, constfile,
               subtract_sae, sae_file, enable_tqdm,
               AEVCacheLoader.encode_aev, **kwargs)
def cache_sparse_aev(output, dataset_path, batchsize, device=default_device,
                     constfile=ani1x.const_file, subtract_sae=False,
                     sae_file=ani1x.sae_file, enable_tqdm=True, **kwargs):
    """Cache AEVs in sparse form for the dataset at ``dataset_path`` into ``output``."""
    # Sparse variant: encode species/AEVs as scipy BSR matrices before pickling.
    _cache_aev(output, dataset_path, batchsize, device, constfile,
               subtract_sae, sae_file, enable_tqdm,
               SparseAEVCacheLoader.encode_aev, **kwargs)
# NOTE(review): this first __all__ still lists the removed AEV-cacher names
# and is immediately overwritten by the assignment below — it is dead code
# and should be deleted once confirmed nothing relies on it.
__all__ = ['load_ani_dataset', 'BatchedANIDataset', 'AEVCacheLoader',
'SparseAEVCacheLoader', 'cache_aev', 'cache_sparse_aev',
'CachedDataset', 'ShuffledDataset', 'find_threshold']
# Effective public API of this module.
__all__ = ['load_ani_dataset', 'BatchedANIDataset', 'CachedDataset', 'ShuffledDataset', 'find_threshold']
# -*- coding: utf-8 -*-
"""AEVs for a dataset can be precomputed by invoking
``python -m torchani.data.cache_aev``, this would dump the dataset and
computed aevs. Use the ``-h`` option for help.
"""
import torch
from . import cache_aev, cache_sparse_aev, ani1x, default_device
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('output',
                        help='Path of the output directory')
    parser.add_argument('dataset',
                        help='Path of the dataset, can be a hdf5 file \
or a directory containing hdf5 files')
    parser.add_argument('batchsize', help='batch size', type=int)
    parser.add_argument('--constfile',
                        help='Path of the constant file `.params`',
                        default=ani1x.const_file)
    parser.add_argument('--properties', nargs='+',
                        help='Output properties to load.',
                        default=['energies'])
    default_dtype = str(torch.get_default_dtype()).split('.')[1]
    parser.add_argument('--dtype', help='Data type', default=default_dtype)
    parser.add_argument('-d', '--device', help='Device for training',
                        default=default_device)
    parser.add_argument('--no-shuffle', help='Whether to shuffle dataset',
                        dest='shuffle', action='store_false')
    parser.add_argument('--no-tqdm', dest='tqdm', action='store_false',
                        help='Whether to use tqdm to display progress')
    parser.add_argument('--subtract-sae', dest='subtract_sae',
                        help='Whether to subtract self atomic energies',
                        default=False, action='store_true')
    parser.add_argument('--sae-file', help='Path to SAE file',
                        default=ani1x.sae_file)
    args = parser.parse_args()

    # BUG FIX: the previous code passed args.tqdm positionally into the
    # ``subtract_sae`` parameter slot of cache_aev/cache_sparse_aev and
    # silently dropped --subtract-sae/--sae-file.  Pass all options by
    # keyword so each value lands on the parameter it was parsed for.
    common = dict(device=args.device,
                  constfile=args.constfile,
                  subtract_sae=args.subtract_sae,
                  sae_file=args.sae_file,
                  enable_tqdm=args.tqdm,
                  shuffle=args.shuffle,
                  properties=args.properties,
                  dtype=getattr(torch, args.dtype))
    cache_aev(args.output, args.dataset, args.batchsize, **common)
    cache_sparse_aev(args.output, args.dataset, args.batchsize, **common)
......@@ -274,28 +274,23 @@ if sys.version_info[0] > 2:
tqdm (bool): whether to enable tqdm
tensorboard (str): Directory to store tensorboard log file, set to
``None`` to disable tensorboard.
aev_caching (bool): Whether to use AEV caching.
checkpoint_name (str): Name of the checkpoint file, checkpoints
will be stored in the network directory with this file name.
"""
def __init__(self, filename, device=torch.device('cuda'), tqdm=False,
tensorboard=None, aev_caching=False,
checkpoint_name='model.pt'):
tensorboard=None, checkpoint_name='model.pt'):
from ..data import load_ani_dataset # noqa: E402
from ..data import AEVCacheLoader # noqa: E402
class dummy:
pass
self.imports = dummy()
self.imports.load_ani_dataset = load_ani_dataset
self.imports.AEVCacheLoader = AEVCacheLoader
self.filename = filename
self.device = device
self.aev_caching = aev_caching
self.checkpoint_name = checkpoint_name
self.weights = []
self.biases = []
......@@ -540,11 +535,7 @@ if sys.version_info[0] > 2:
# initialize weights and biases
self.nn.apply(init_params)
if self.aev_caching:
self.model = self.nn.to(self.device)
else:
self.model = Sequential(self.aev_computer, self.nn).to(self.device)
self.model = Sequential(self.aev_computer, self.nn).to(self.device)
# loss functions
self.mse_se = torch.nn.MSELoss(reduction='none')
......@@ -556,23 +547,15 @@ if sys.version_info[0] > 2:
self.best_validation_rmse = math.inf
def load_data(self, training_path, validation_path):
"""Load training and validation dataset from file.
If AEV caching is enabled, then the arguments are path to the cache
directory, otherwise it should be path to the dataset.
"""
if self.aev_caching:
self.training_set = self.imports.AEVCacheLoader(training_path)
self.validation_set = self.imports.AEVCacheLoader(validation_path)
else:
self.training_set = self.imports.load_ani_dataset(
training_path, self.consts.species_to_tensor,
self.training_batch_size, rm_outlier=True, device=self.device,
transform=[self.shift_energy.subtract_from_dataset])
self.validation_set = self.imports.load_ani_dataset(
validation_path, self.consts.species_to_tensor,
self.validation_batch_size, rm_outlier=True, device=self.device,
transform=[self.shift_energy.subtract_from_dataset])
"""Load training and validation dataset from file."""
self.training_set = self.imports.load_ani_dataset(
training_path, self.consts.species_to_tensor,
self.training_batch_size, rm_outlier=True, device=self.device,
transform=[self.shift_energy.subtract_from_dataset])
self.validation_set = self.imports.load_ani_dataset(
validation_path, self.consts.species_to_tensor,
self.validation_batch_size, rm_outlier=True, device=self.device,
transform=[self.shift_energy.subtract_from_dataset])
def evaluate(self, dataset):
"""Run the evaluation"""
......
......@@ -26,15 +26,12 @@ if __name__ == '__main__':
parser.add_argument('--tensorboard',
help='Directory to store tensorboard log files',
default=None)
parser.add_argument('--cache-aev', dest='cache_aev', action='store_true',
help='Whether to cache AEV', default=None)
parser.add_argument('--checkpoint_name',
help='Name of checkpoint file',
default='model.pt')
parser = parser.parse_args()
d = torch.device(parser.device)
trainer = Trainer(parser.config_path, d, parser.tqdm, parser.tensorboard,
parser.cache_aev, parser.checkpoint_name)
trainer = Trainer(parser.config_path, d, parser.tqdm, parser.tensorboard, parser.checkpoint_name)
trainer.load_data(parser.training_path, parser.validation_path)
trainer.run()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment