Commit 4f63c32d authored by Farhad Ramezanghorbani, committed by Gao, Xiang

Support caching sparse representations of AEVs to disk (#230)

parent a86ea658
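The storage saving in this patch comes from serializing the AEV matrices, which are mostly zeros, as scipy.sparse.bsr_matrix objects instead of dense arrays, and converting them back to dense torch tensors at load time. A minimal sketch of that round trip, using made-up shapes rather than anything taken from the diff:

import numpy as np
import torch
from scipy.sparse import bsr_matrix

# A made-up AEV-like array for one molecule: (atoms, aev_dim), mostly zeros.
dense = np.zeros((20, 384), dtype=np.float32)
dense[:, :10] = np.random.rand(20, 10).astype(np.float32)

sparse = bsr_matrix(dense)              # only the non-zero blocks are stored
restored = np.array(sparse.todense())   # back to a dense ndarray at load time
tensor = torch.from_numpy(restored)     # and to a torch tensor for training

assert np.allclose(dense, restored)

Pickling the bsr_matrix rather than the dense array is what shrinks the per-batch cache files.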
@@ -2,7 +2,7 @@ import os
import torch
import torchani
import unittest
from torchani.data.cache_aev import cache_aev, cache_sparse_aev

path = os.path.dirname(os.path.realpath(__file__))
dataset_path = os.path.join(path, '../dataset/ani1-up_to_gdb4')
@@ -93,6 +93,22 @@ class TestData(unittest.TestCase):
                    self._assertTensorEqual(s1, s2)
                    self._assertTensorEqual(a, a2)

    def testSparseAEVCacheLoader(self):
        tmpdir = os.path.join(os.getcwd(), 'tmp')
        if not os.path.exists(tmpdir):
            os.makedirs(tmpdir)
        cache_sparse_aev(tmpdir, dataset_path2, 64, enable_tqdm=False)
        loader = torchani.data.SparseAEVCacheLoader(tmpdir)
        ds = loader.dataset
        aev_computer_dev = aev_computer.to(loader.dataset.device)
        for _ in range(3):
            for (species_aevs, _), (species_coordinates, _) in zip(loader, ds):
                for (s1, a), (s2, c) in zip(species_aevs, species_coordinates):
                    self._assertTensorEqual(s1, s2)
                    s2, a2 = aev_computer_dev((s2, c))
                    self._assertTensorEqual(s1, s2)
                    self._assertTensorEqual(a, a2)

if __name__ == '__main__':
    unittest.main()
@@ -8,6 +8,8 @@ from ._pyanitools import anidataloader
import torch
from .. import utils, neurochem, aev
import pickle
import numpy as np
from scipy.sparse import bsr_matrix

default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -252,6 +254,48 @@ class AEVCacheLoader(Dataset):
        return len(self.dataset)

class SparseAEVCacheLoader(Dataset):
    """Build a factory for AEVs from a sparse disk cache.

    The computation of AEVs is the most time-consuming part of training.
    AEVs never change during training and contain a large number of zeros,
    so they can be stored on disk in a sparse representation and loaded
    during training rather than recomputed from scratch. The storage
    required by ``cache_sparse_aev`` is considerably less than that of
    ``cache_aev``.

    Arguments:
        disk_cache (str): Directory storing disk caches.
        device (:class:`torch.device`): device to put tensors on when
            iterating.
    """

    def __init__(self, disk_cache=None, device=torch.device('cpu')):
        super(SparseAEVCacheLoader, self).__init__()
        self.disk_cache = disk_cache
        self.device = device
        # load the pickled dataset (species, coordinates and properties)
        dataset_path = os.path.join(disk_cache, 'dataset')
        with open(dataset_path, 'rb') as f:
            self.dataset = pickle.load(f)

    def __getitem__(self, index):
        _, output = self.dataset.batches[index]
        aev_path = os.path.join(self.disk_cache, str(index))
        with open(aev_path, 'rb') as f:
            species_aevs = pickle.load(f)
        batch_X = []
        for species_, aev_ in species_aevs:
            # convert the sparse BSR matrices back to dense torch tensors
            species_np = np.array(species_.todense())
            species = torch.from_numpy(species_np).to(self.device)
            aevs_np = np.stack([np.array(i.todense()) for i in aev_], axis=0)
            aevs = torch.from_numpy(aevs_np).to(self.device)
            batch_X.append((species, aevs))
        return batch_X, output

    def __len__(self):
        return len(self.dataset)

builtin = neurochem.Builtins()
@@ -295,4 +339,49 @@ def cache_aev(output, dataset_path, batchsize, device=default_device,
            pickle.dump(aevs, f)

def cache_sparse_aev(output, dataset_path, batchsize, device=default_device,
                     constfile=builtin.const_file, subtract_sae=False,
                     sae_file=builtin.sae_file, enable_tqdm=True, **kwargs):
    # if the output directory does not exist, create it
    if not os.path.exists(output):
        os.makedirs(output)

    device = torch.device(device)
    consts = neurochem.Constants(constfile)
    aev_computer = aev.AEVComputer(**consts).to(device)

    if subtract_sae:
        energy_shifter = neurochem.load_sae(sae_file)
        transform = (energy_shifter.subtract_from_dataset,)
    else:
        transform = ()

    dataset = BatchedANIDataset(
        dataset_path, consts.species_to_tensor, batchsize,
        device=device, transform=transform, **kwargs
    )

    # dump the dataset itself so the loader can recover the batch outputs
    filename = os.path.join(output, 'dataset')
    with open(filename, 'wb') as f:
        pickle.dump(dataset, f)

    if enable_tqdm:
        import tqdm
        indices = tqdm.trange(len(dataset))
    else:
        indices = range(len(dataset))

    for i in indices:
        input_, _ = dataset[i]
        aevs = []
        for j in input_:
            species_, aev_ = aev_computer(j)
            # store species and per-molecule AEVs as sparse BSR matrices
            species_ = bsr_matrix(species_.cpu().numpy())
            aev_ = [bsr_matrix(i.cpu().numpy()) for i in aev_]
            aevs.append((species_, aev_))
        filename = os.path.join(output, '{}'.format(i))
        with open(filename, 'wb') as f:
            pickle.dump(aevs, f)
__all__ = ['BatchedANIDataset', 'AEVCacheLoader', 'SparseAEVCacheLoader', 'cache_aev', 'cache_sparse_aev']
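A hedged end-to-end sketch of how the new pieces fit together; the cache directory, dataset file, and batch size below are placeholders, not values from this commit:

import torchani
from torchani.data.cache_aev import cache_sparse_aev

cache_dir = './aev_cache_sparse'   # placeholder path
cache_sparse_aev(cache_dir, './dataset/ani_gdb_s01.h5', 256,  # placeholder dataset
                 enable_tqdm=False)

loader = torchani.data.SparseAEVCacheLoader(cache_dir)
for batch_x, batch_y in loader:
    # batch_x: list of (species, aevs) pairs, already densified torch tensors
    # batch_y: the batch's target properties, as stored by BatchedANIDataset
    for species, aevs in batch_x:
        pass  # feed (species, aevs) to the neural network part of the model

The expensive AEVComputer call happens once, inside cache_sparse_aev; each training epoch then only pays for unpickling and the sparse-to-dense conversion.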
@@ -5,7 +5,7 @@ computed aevs. Use the ``-h`` option for help.
"""
import torch
from . import cache_aev, cache_sparse_aev, builtin, default_device

if __name__ == '__main__':
@@ -41,3 +41,7 @@ if __name__ == '__main__':
    cache_aev(parser.output, parser.dataset, parser.batchsize, parser.device,
              parser.constfile, parser.tqdm, shuffle=parser.shuffle,
              properties=parser.properties, dtype=getattr(torch, parser.dtype))
    cache_sparse_aev(parser.output, parser.dataset, parser.batchsize, parser.device,
                     parser.constfile, parser.tqdm, shuffle=parser.shuffle,
                     properties=parser.properties, dtype=getattr(torch, parser.dtype))
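For reference, the two cache formats can also be produced programmatically. Since both functions write per-batch files named after the batch index into their output directory, separate directories (placeholders below) keep a dense and a sparse cache from overwriting each other:

from torchani.data.cache_aev import cache_aev, cache_sparse_aev

dataset = './dataset/ani_gdb_s01.h5'                  # placeholder dataset path
cache_aev('./aev_cache_dense', dataset, 256)          # dense cache, larger on disk
cache_sparse_aev('./aev_cache_sparse', dataset, 256)  # sparse cache, smaller on disk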