Unverified Commit bfc04ac8 authored by Gao, Xiang, committed by GitHub

Support caching AEVs to disk (#88)

parent c059adbf
@@ -2,12 +2,15 @@ import os
import torch
import torchani
import unittest
from torchani.data.cache_aev import cache_aev
path = os.path.dirname(os.path.realpath(__file__))
dataset_path = os.path.join(path, '../dataset')
dataset_path2 = os.path.join(path, '../dataset/ani_gdb_s01.h5')
batch_size = 256
builtins = torchani.neurochem.Builtins()
consts = builtins.consts
aev_computer = builtins.aev_computer
class TestData(unittest.TestCase):
@@ -18,7 +21,7 @@ class TestData(unittest.TestCase):
                                                 batch_size)

    def _assertTensorEqual(self, t1, t2):
-        self.assertEqual((t1-t2).abs().max(), 0)
+        self.assertEqual((t1-t2).abs().max().item(), 0)

    def testSplitBatch(self):
        species1 = torch.randint(4, (5, 4), dtype=torch.long)
@@ -74,6 +77,22 @@ class TestData(unittest.TestCase):
        non_padding = (species >= 0)[:, -1].nonzero()
        self.assertGreater(non_padding.numel(), 0)
    def testAEVCacheLoader(self):
        # cache AEVs for a small dataset, then check that loading them back
        # matches recomputing them from species and coordinates
        tmpdir = os.path.join(os.getcwd(), 'tmp')
        if not os.path.exists(tmpdir):
            os.makedirs(tmpdir)
        cache_aev(tmpdir, dataset_path2, 64, enable_tqdm=False)
        loader = torchani.data.AEVCacheLoader(tmpdir)
        ds = loader.dataset
        aev_computer_dev = aev_computer.to(loader.dataset.device)
        for _ in range(3):
            for (species_aevs, _), (species_coordinates, _) in zip(loader, ds):
                for (s1, a), (s2, c) in zip(species_aevs, species_coordinates):
                    self._assertTensorEqual(s1, s2)
                    s2, a2 = aev_computer_dev((s2, c))
                    self._assertTensorEqual(s1, s2)
                    self._assertTensorEqual(a, a2)
if __name__ == '__main__':
    unittest.main()
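For context outside the test harness, the round trip above boils down to the following sketch; the cache directory and dataset path are placeholder assumptions, not fixed names:

# Minimal sketch of the new API: compute and pickle AEVs once, then
# iterate the cache on later runs ('aev_cache' and the h5 path are assumed).
import torchani
from torchani.data.cache_aev import cache_aev

cache_aev('aev_cache', 'dataset/ani_gdb_s01.h5', batchsize=64,
          enable_tqdm=False)
loader = torchani.data.AEVCacheLoader('aev_cache')
for species_aevs, properties in loader:
    for species, aevs in species_aevs:
        print(species.shape, aevs.shape)  # one pair per padded chunk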
@@ -7,6 +7,7 @@ import os
from ._pyanitools import anidataloader
import torch
from .. import utils
import pickle
def chunk_counts(counts, split):
@@ -217,3 +218,76 @@ class BatchedANIDataset(Dataset):
    def __len__(self):
        return len(self.batches)

def _disk_cache_loader(index_queue, tensor_queue, disk_cache, device):
    """Worker process: take batch indices from ``index_queue``, load the
    corresponding pickled AEVs from ``disk_cache``, and put them on
    ``tensor_queue``."""
    # note: ``device`` is currently unused here; tensors are moved to the
    # dataset's device by the consumer in AEVCacheLoader.__next__
    while True:
        index = index_queue.get()
        aev_path = os.path.join(disk_cache, str(index))
        with open(aev_path, 'rb') as f:
            tensor_queue.put(pickle.load(f))

class AEVCacheLoader:
    """Loader for AEVs cached on disk by ``cache_aev``.

    Computing AEVs is the most time-consuming part of training. Since the
    AEVs never change during training, if storage is fast enough (usually
    the case for a good SSD, but not for an HDD), we can cache the computed
    AEVs to disk once and load them on each epoch instead of recomputing
    them from scratch every time.

    Arguments:
        disk_cache (str): Directory storing the disk cache.
        in_memory_size (int): Number of batches to prefetch into memory.
    """
    def __init__(self, disk_cache, in_memory_size=64):
        self.current = 0
        self.disk_cache = disk_cache
        # load the pickled dataset from the disk cache
        dataset_path = os.path.join(disk_cache, 'dataset')
        with open(dataset_path, 'rb') as f:
            self.dataset = pickle.load(f)
        # initialize queues and the background loader process
        self.tensor_queue = torch.multiprocessing.Queue()
        self.index_queue = torch.multiprocessing.Queue()
        self.in_memory_size = min(in_memory_size, len(self.dataset))
        # pre-fill the index queue so the worker prefetches the first batches
        for i in range(self.in_memory_size):
            self.index_queue.put(i)
        self.loader = torch.multiprocessing.Process(
            target=_disk_cache_loader,
            args=(self.index_queue, self.tensor_queue, disk_cache,
                  self.dataset.device)
        )
        self.loader.start()

    def __iter__(self):
        if self.current != 0:
            raise ValueError('Only one iterator of AEVCacheLoader is allowed')
        return self

    def __next__(self):
        if self.current < len(self.dataset):
            # keep the prefetch window full: ask the worker for the batch
            # that is ``in_memory_size`` steps ahead, wrapping around
            new_idx = (self.current + self.in_memory_size) % len(self.dataset)
            self.index_queue.put(new_idx)
            species_aevs = self.tensor_queue.get()
            species_aevs = [(x.to(self.dataset.device),
                             y.to(self.dataset.device))
                            for x, y in species_aevs]
            _, output = self.dataset[self.current]
            self.current += 1
            return species_aevs, output
        else:
            self.current = 0
            raise StopIteration

    def __del__(self):
        self.loader.terminate()
__all__ = ['BatchedANIDataset', 'AEVCacheLoader']
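To make the intended training use concrete, here is a rough epoch loop over the cached AEVs. The linear readout is a stand-in for a real ANI network, and 'aev_cache' is an assumed directory previously produced by cache_aev; since the AEVs come from the cache, the loop never touches an AEVComputer:

# Hedged sketch of an epoch over cached AEVs; the readout, its dimension,
# and the cache path are illustrative assumptions, not part of this commit.
import torch
import torchani

loader = torchani.data.AEVCacheLoader('aev_cache')
aev_dim = 384  # assumed AEV length for the built-in constants
readout = torch.nn.Linear(aev_dim, 1).to(loader.dataset.device)

for species_aevs, properties in loader:
    energies = []
    for species, aevs in species_aevs:
        atomic = readout(aevs).squeeze(-1)                 # (molecules, atoms)
        atomic = atomic * (species >= 0).to(atomic.dtype)  # mask padding atoms
        energies.append(atomic.sum(dim=1))
    predicted = torch.cat(energies)
    # assumes energies were stored in the same dtype as the readout output
    loss = (predicted - properties['energies']).pow(2).mean()
    loss.backward()  # a real loop would also step an optimizer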
@@ -11,9 +11,54 @@ from . import BatchedANIDataset
import pickle
builtin = neurochem.Builtins()
default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
default_dtype = str(torch.get_default_dtype()).split('.')[1]

def cache_aev(output, dataset_path, batchsize, device=default_device,
              constfile=builtin.const_file, subtract_sae=False,
              sae_file=builtin.sae_file, enable_tqdm=True, **kwargs):
    # create the output directory if it does not exist
    if not os.path.exists(output):
        os.makedirs(output)

    device = torch.device(device)
    consts = neurochem.Constants(constfile)
    aev_computer = aev.AEVComputer(**consts).to(device)

    if subtract_sae:
        energy_shifter = neurochem.load_sae(sae_file)
        transform = (energy_shifter.subtract_from_dataset,)
    else:
        transform = ()

    dataset = BatchedANIDataset(
        dataset_path, consts.species_to_tensor, batchsize,
        device=device, transform=transform, **kwargs
    )

    # dump out the dataset itself
    filename = os.path.join(output, 'dataset')
    with open(filename, 'wb') as f:
        pickle.dump(dataset, f)

    if enable_tqdm:
        import tqdm
        indices = tqdm.trange(len(dataset))
    else:
        indices = range(len(dataset))

    # compute the AEVs for every chunk of each batch, move them to CPU,
    # and pickle them under the batch index
    for i in indices:
        input_, _ = dataset[i]
        aevs = [aev_computer(j) for j in input_]
        aevs = [(x.cpu(), y.cpu()) for x, y in aevs]
        filename = os.path.join(output, '{}'.format(i))
        with open(filename, 'wb') as f:
            pickle.dump(aevs, f)
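The resulting cache directory is flat: a pickled BatchedANIDataset under 'dataset', plus one pickle per batch index holding the list of per-chunk (species, aevs) CPU tensor pairs. A quick inspection sketch, where the 'tmp' path is an assumption:

# Peek into a cache directory produced by cache_aev (path is an assumption).
import os
import pickle

cache_dir = 'tmp'
print(sorted(os.listdir(cache_dir)))  # e.g. ['0', '1', ..., 'dataset']
with open(os.path.join(cache_dir, '0'), 'rb') as f:
    chunks = pickle.load(f)           # list of (species, aevs) CPU tensors
for species, aevs in chunks:
    print(species.shape, aevs.shape)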

if __name__ == '__main__':
    import argparse
    builtin = neurochem.Builtins()
    parser = argparse.ArgumentParser()
    parser.add_argument('output',
                        help='Path of the output directory')
@@ -27,44 +72,20 @@ if __name__ == '__main__':
    parser.add_argument('--properties', nargs='+',
                        help='Output properties to load.',
                        default=['energies'])
-    parser.add_argument('--dtype',
-                        help='Data type',
-                        default=str(torch.get_default_dtype()).split('.')[1])
-    default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    parser.add_argument('--dtype', help='Data type', default=default_dtype)
    parser.add_argument('-d', '--device', help='Device for training',
                        default=default_device)
    parser.add_argument('--no-shuffle', help='Do not shuffle the dataset',
                        dest='shuffle', action='store_false')
    parser.add_argument('--no-tqdm', dest='tqdm', action='store_false',
                        help='Do not use tqdm to display progress')
    parser.add_argument('--subtract-sae', dest='subtract_sae',
                        help='Whether to subtract self atomic energies',
                        default=False, action='store_true')
    parser.add_argument('--sae-file', help='Path to SAE file',
                        default=builtin.sae_file)
    parser = parser.parse_args()
-    # if output directory does not exist, then create it
-    if not os.path.exists(parser.output):
-        os.makedirs(parser.output)
-    device = torch.device(parser.device)
-    consts = neurochem.Constants(parser.constfile)
-    aev_computer = aev.AEVComputer(**consts).to(device)
-    dataset = BatchedANIDataset(parser.dataset, consts.species_to_tensor,
-                                parser.batchsize, shuffle=parser.shuffle,
-                                properties=parser.properties, device=device,
-                                dtype=getattr(torch, parser.dtype))
-    # dump out the dataset
-    filename = os.path.join(parser.output, 'dataset')
-    with open(filename, 'wb') as f:
-        pickle.dump(dataset, f)
-    if parser.tqdm:
-        import tqdm
-        indices = tqdm.trange(len(dataset))
-    else:
-        indices = range(len(dataset))
-    for i in indices:
-        input_, _ = dataset[i]
-        aevs = [aev_computer(j) for j in input_]
-        aevs = [(x.cpu(), y.cpu()) for x, y in aevs]
-        filename = os.path.join(parser.output, '{}'.format(i))
-        with open(filename, 'wb') as f:
-            pickle.dump(aevs, f)
+    cache_aev(parser.output, parser.dataset, parser.batchsize, parser.device,
+              parser.constfile, parser.subtract_sae, parser.sae_file,
+              parser.tqdm, shuffle=parser.shuffle,
+              properties=parser.properties, dtype=getattr(torch, parser.dtype))
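Because several defaulted parameters sit between device and enable_tqdm, a positional call like the one above is easy to misalign; an equivalent keyword-only form, offered here as a suggestion rather than part of the commit, is harder to misread:

    # Same call with keywords only; immune to argument-order mistakes.
    cache_aev(parser.output, parser.dataset, parser.batchsize,
              device=parser.device, constfile=parser.constfile,
              subtract_sae=parser.subtract_sae, sae_file=parser.sae_file,
              enable_tqdm=parser.tqdm, shuffle=parser.shuffle,
              properties=parser.properties, dtype=getattr(torch, parser.dtype))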
@@ -138,3 +138,7 @@ class MaxAbsoluteError(Metric):
def MAEMetric(key):
    """Create max absolute error metric on key."""
    return DictMetric(key, MaxAbsoluteError())

__all__ = ['Container', 'MSELoss', 'TransformedLoss', 'RMSEMetric',
           'MAEMetric']
@@ -673,3 +673,7 @@ class Trainer:
            decorate(trainer)
            trainer.run(self.training_set, max_epochs=math.inf)
            lr *= self.lr_decay

__all__ = ['Constants', 'load_sae', 'load_model', 'load_model_ensemble',
           'Trainer']
@@ -122,3 +122,6 @@ class EnergyShifter(torch.nn.Module):
        species, energies = species_energies
        sae = self.sae(species).to(energies.dtype).to(energies.device)
        return species, energies + sae
__all__ = ['pad_and_batch', 'present_species', 'strip_redundant_padding']