Commit 4f63c32d authored by Farhad Ramezanghorbani's avatar Farhad Ramezanghorbani Committed by Gao, Xiang
Browse files

Support caching sparse representation of AEVs into disk (#230)

parent a86ea658
......@@ -2,7 +2,7 @@ import os
import torch
import torchani
import unittest
from torchani.data.cache_aev import cache_aev
from torchani.data.cache_aev import cache_aev, cache_sparse_aev
path = os.path.dirname(os.path.realpath(__file__))
dataset_path = os.path.join(path, '../dataset/ani1-up_to_gdb4')
......@@ -93,6 +93,22 @@ class TestData(unittest.TestCase):
self._assertTensorEqual(s1, s2)
self._assertTensorEqual(a, a2)
def testSparseAEVCacheLoader(self):
    """Round-trip check: sparse-cached AEVs must equal freshly computed ones."""
    cache_dir = os.path.join(os.getcwd(), 'tmp')
    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)
    # Write the sparse AEV cache, then read it back through the loader.
    cache_sparse_aev(cache_dir, dataset_path2, 64, enable_tqdm=False)
    loader = torchani.data.SparseAEVCacheLoader(cache_dir)
    dataset = loader.dataset
    computer = aev_computer.to(dataset.device)
    # Iterate a few times to make sure repeated loads stay consistent.
    for _ in range(3):
        for cached_batch, raw_batch in zip(loader, dataset):
            cached_inputs, _ = cached_batch
            raw_inputs, _ = raw_batch
            for (s_cached, a_cached), (s_raw, coords) in zip(cached_inputs, raw_inputs):
                self._assertTensorEqual(s_cached, s_raw)
                # Recompute the AEVs from coordinates and compare.
                s_recomputed, a_recomputed = computer((s_raw, coords))
                self._assertTensorEqual(s_cached, s_recomputed)
                self._assertTensorEqual(a_cached, a_recomputed)
if __name__ == '__main__':
unittest.main()
......@@ -8,6 +8,8 @@ from ._pyanitools import anidataloader
import torch
from .. import utils, neurochem, aev
import pickle
import numpy as np
from scipy.sparse import bsr_matrix
default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
......@@ -252,6 +254,48 @@ class AEVCacheLoader(Dataset):
return len(self.dataset)
class SparseAEVCacheLoader(Dataset):
    """Dataset that loads precomputed AEVs stored on disk in sparse form.

    Computing AEVs is the most time-consuming part of training, yet the AEVs
    never change during training and contain a large number of zeros.
    :func:`cache_sparse_aev` therefore stores them as ``scipy.sparse``
    matrices; this loader reads them back and densifies them into
    :class:`torch.Tensor`. The storage requirement is considerably less than
    the dense cache written by :func:`cache_aev`.

    Arguments:
        disk_cache (str): Directory holding the disk cache produced by
            :func:`cache_sparse_aev`. Must not be ``None``; the default
            exists only for signature compatibility with ``AEVCacheLoader``.
        device (:class:`torch.device`): device to put tensors on when
            iterating.
    """

    def __init__(self, disk_cache=None, device=torch.device('cpu')):
        super(SparseAEVCacheLoader, self).__init__()
        self.disk_cache = disk_cache
        self.device = device
        # Load the pickled dataset that cache_sparse_aev dumped alongside the
        # per-batch AEV files.
        # NOTE(review): pickle.load is only safe on caches you created
        # yourself; never point this at an untrusted directory.
        dataset_path = os.path.join(disk_cache, 'dataset')
        with open(dataset_path, 'rb') as f:
            self.dataset = pickle.load(f)

    def __getitem__(self, index):
        # Properties (e.g. energies) come from the in-memory dataset; the
        # species/AEV inputs come from the per-batch cache file named after
        # the batch index.
        _, output = self.dataset.batches[index]
        aev_path = os.path.join(self.disk_cache, str(index))
        with open(aev_path, 'rb') as f:
            species_aevs = pickle.load(f)
        batch_X = []
        for species_, aev_ in species_aevs:
            # Densify the sparse matrices back into tensors on the target
            # device. np.asarray avoids the extra copy np.array would make.
            species_np = np.asarray(species_.todense())
            species = torch.from_numpy(species_np).to(self.device)
            aevs_np = np.stack([np.asarray(m.todense()) for m in aev_], axis=0)
            aevs = torch.from_numpy(aevs_np).to(self.device)
            batch_X.append((species, aevs))
        return batch_X, output

    def __len__(self):
        return len(self.dataset)
builtin = neurochem.Builtins()
......@@ -295,4 +339,49 @@ def cache_aev(output, dataset_path, batchsize, device=default_device,
pickle.dump(aevs, f)
__all__ = ['BatchedANIDataset', 'AEVCacheLoader', 'cache_aev']
def cache_sparse_aev(output, dataset_path, batchsize, device=default_device,
                     constfile=builtin.const_file, subtract_sae=False,
                     sae_file=builtin.sae_file, enable_tqdm=True, **kwargs):
    """Compute AEVs for a dataset and cache them on disk in sparse form.

    Arguments mirror :func:`cache_aev`; the only difference is that each
    computed AEV is converted to a :class:`scipy.sparse.bsr_matrix` before
    being pickled, which greatly reduces the storage requirement.
    """
    # Create the output directory on first use.
    if not os.path.exists(output):
        os.makedirs(output)
    device = torch.device(device)
    consts = neurochem.Constants(constfile)
    computer = aev.AEVComputer(**consts).to(device)
    # Optionally subtract self atomic energies while loading the dataset.
    if subtract_sae:
        shifter = neurochem.load_sae(sae_file)
        transform = (shifter.subtract_from_dataset,)
    else:
        transform = ()
    dataset = BatchedANIDataset(
        dataset_path, consts.species_to_tensor, batchsize,
        device=device, transform=transform, **kwargs
    )
    # The pickled dataset is stored alongside the per-batch AEV files so the
    # loader can recover batch outputs (e.g. energies) later.
    with open(os.path.join(output, 'dataset'), 'wb') as f:
        pickle.dump(dataset, f)
    if enable_tqdm:
        import tqdm
        batch_indices = tqdm.trange(len(dataset))
    else:
        batch_indices = range(len(dataset))
    for batch_index in batch_indices:
        input_, _ = dataset[batch_index]
        sparse_batch = []
        for chunk in input_:
            species_, aev_ = computer(chunk)
            # Store species and each molecule's AEV as sparse BSR matrices.
            sparse_species = bsr_matrix(species_.cpu().numpy())
            sparse_aevs = [bsr_matrix(mol.cpu().numpy()) for mol in aev_]
            sparse_batch.append((sparse_species, sparse_aevs))
        # One pickle file per batch, named after the batch index.
        with open(os.path.join(output, str(batch_index)), 'wb') as f:
            pickle.dump(sparse_batch, f)
__all__ = ['BatchedANIDataset', 'AEVCacheLoader', 'SparseAEVCacheLoader', 'cache_aev', 'cache_sparse_aev']
......@@ -5,7 +5,7 @@ computed aevs. Use the ``-h`` option for help.
"""
import torch
from . import cache_aev, builtin, default_device
from . import cache_aev, cache_sparse_aev, builtin, default_device
if __name__ == '__main__':
......@@ -41,3 +41,7 @@ if __name__ == '__main__':
cache_aev(parser.output, parser.dataset, parser.batchsize, parser.device,
parser.constfile, parser.tqdm, shuffle=parser.shuffle,
properties=parser.properties, dtype=getattr(torch, parser.dtype))
cache_sparse_aev(parser.output, parser.dataset, parser.batchsize, parser.device,
parser.constfile, parser.tqdm, shuffle=parser.shuffle,
properties=parser.properties, dtype=getattr(torch, parser.dtype))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment