Unverified Commit bfc04ac8 authored by Gao, Xiang, committed by GitHub

Support caching AEVs into disk (#88)

parent c059adbf
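
The intended workflow, pieced together from the diff below (the cache directory and dataset paths here are hypothetical), is to pay the AEV cost once up front and then train from the cache:

# A minimal sketch of the API added by this commit.
import torchani
from torchani.data.cache_aev import cache_aev

# One-off step: compute the AEVs for every batch and pickle them,
# one file per batch, into a cache directory.
cache_aev('aev_cache', 'dataset/ani_gdb_s01.h5', 256)

# Training: iterate precomputed (species, aev) pairs instead of
# recomputing AEVs on every epoch.
loader = torchani.data.AEVCacheLoader('aev_cache')
for species_aevs, properties in loader:
    for species, aevs in species_aevs:
        pass  # feed (species, aevs) into the network head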
@@ -2,12 +2,15 @@ import os
 import torch
 import torchani
 import unittest
+from torchani.data.cache_aev import cache_aev

 path = os.path.dirname(os.path.realpath(__file__))
 dataset_path = os.path.join(path, '../dataset')
+dataset_path2 = os.path.join(path, '../dataset/ani_gdb_s01.h5')
 batch_size = 256
 builtins = torchani.neurochem.Builtins()
 consts = builtins.consts
+aev_computer = builtins.aev_computer


 class TestData(unittest.TestCase):
@@ -18,7 +21,7 @@ class TestData(unittest.TestCase):
                                              batch_size)

     def _assertTensorEqual(self, t1, t2):
-        self.assertEqual((t1-t2).abs().max(), 0)
+        self.assertEqual((t1-t2).abs().max().item(), 0)

     def testSplitBatch(self):
         species1 = torch.randint(4, (5, 4), dtype=torch.long)
@@ -74,6 +77,22 @@ class TestData(unittest.TestCase):
         non_padding = (species >= 0)[:, -1].nonzero()
         self.assertGreater(non_padding.numel(), 0)

+    def testAEVCacheLoader(self):
+        tmpdir = os.path.join(os.getcwd(), 'tmp')
+        if not os.path.exists(tmpdir):
+            os.makedirs(tmpdir)
+        cache_aev(tmpdir, dataset_path2, 64, enable_tqdm=False)
+        loader = torchani.data.AEVCacheLoader(tmpdir)
+        ds = loader.dataset
+        aev_computer_dev = aev_computer.to(loader.dataset.device)
+        for _ in range(3):
+            for (species_aevs, _), (species_coordinates, _) in zip(loader, ds):
+                for (s1, a), (s2, c) in zip(species_aevs, species_coordinates):
+                    self._assertTensorEqual(s1, s2)
+                    s2, a2 = aev_computer_dev((s2, c))
+                    self._assertTensorEqual(s1, s2)
+                    self._assertTensorEqual(a, a2)
+

 if __name__ == '__main__':
     unittest.main()
@@ -7,6 +7,7 @@ import os
 from ._pyanitools import anidataloader
 import torch
 from .. import utils
+import pickle


 def chunk_counts(counts, split):
@@ -217,3 +218,76 @@ class BatchedANIDataset(Dataset):

     def __len__(self):
         return len(self.batches)
+
+
+def _disk_cache_loader(index_queue, tensor_queue, disk_cache, device):
+    """Worker process: get an index from the queue and load the pickled
+    AEVs for that batch from the disk cache."""
+    while True:
+        index = index_queue.get()
+        aev_path = os.path.join(disk_cache, str(index))
+        with open(aev_path, 'rb') as f:
+            tensor_queue.put(pickle.load(f))
+
+
+class AEVCacheLoader:
+    """Load precomputed AEVs from a disk cache.
+
+    Computing AEVs is the most time-consuming part of training. Since the
+    AEVs never change during training, if the storage is fast enough (usually
+    the case for good SSDs, but not for HDDs), it pays off to cache the
+    computed AEVs on disk and load them, rather than compute them from
+    scratch every time they are used.
+
+    Arguments:
+        disk_cache (str): Directory storing the disk cache.
+        in_memory_size (int): Number of batches to prefetch into memory.
+    """
+
+    def __init__(self, disk_cache=None, in_memory_size=64):
+        self.current = 0
+        self.disk_cache = disk_cache
+        # load the pickled dataset from the cache directory
+        dataset_path = os.path.join(disk_cache, 'dataset')
+        with open(dataset_path, 'rb') as f:
+            self.dataset = pickle.load(f)
+        # initialize the queues and the background loader process
+        self.tensor_queue = torch.multiprocessing.Queue()
+        self.index_queue = torch.multiprocessing.Queue()
+        # never prefetch more batches than the dataset contains
+        self.in_memory_size = min(in_memory_size, len(self.dataset))
+        for i in range(self.in_memory_size):
+            self.index_queue.put(i)
+        self.loader = torch.multiprocessing.Process(
+            target=_disk_cache_loader,
+            args=(self.index_queue, self.tensor_queue, disk_cache,
+                  self.dataset.device)
+        )
+        self.loader.start()
+
+    def __iter__(self):
+        if self.current != 0:
+            raise ValueError('Only one iterator of AEVCacheLoader is allowed')
+        return self
+
+    def __next__(self):
+        if self.current < len(self.dataset):
+            # request the batch that keeps the prefetch window full
+            new_idx = (self.current + self.in_memory_size) % len(self.dataset)
+            self.index_queue.put(new_idx)
+            species_aevs = self.tensor_queue.get()
+            species_aevs = [(x.to(self.dataset.device),
+                             y.to(self.dataset.device))
+                            for x, y in species_aevs]
+            _, output = self.dataset[self.current]
+            self.current += 1
+            return species_aevs, output
+        else:
+            self.current = 0
+            raise StopIteration
+
+    def __del__(self):
+        self.loader.terminate()
+
+
+__all__ = ['BatchedANIDataset', 'AEVCacheLoader']
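
The loader just added is a small producer-consumer prefetch pipeline: a background process pulls batch indices from one queue, unpickles the corresponding file, and pushes the tensors into a second queue, so the main process never waits on disk for the next in_memory_size batches. A standalone sketch of the same pattern (assuming a hypothetical cache directory aev_cache with files named 0, 1, ... already exists):

# Producer-consumer prefetch over pickled batch files.
import os
import pickle
import torch.multiprocessing as mp

def _worker(index_queue, tensor_queue, cache_dir):
    # Producer: blocks on index_queue, loads that batch, hands it over.
    while True:
        i = index_queue.get()
        with open(os.path.join(cache_dir, str(i)), 'rb') as f:
            tensor_queue.put(pickle.load(f))

if __name__ == '__main__':
    n_batches, window = 10, 4
    index_q, tensor_q = mp.Queue(), mp.Queue()
    for i in range(min(window, n_batches)):
        index_q.put(i)  # pre-fill the prefetch window
    mp.Process(target=_worker, args=(index_q, tensor_q, 'aev_cache'),
               daemon=True).start()
    for i in range(n_batches):
        # keep the window full: request the batch `window` steps ahead,
        # wrapping around so the next epoch is already being prefetched
        index_q.put((i + window) % n_batches)
        batch = tensor_q.get()  # blocks only if the worker fell behind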
@@ -11,9 +11,54 @@ from . import BatchedANIDataset
 import pickle

+builtin = neurochem.Builtins()
+default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+default_dtype = str(torch.get_default_dtype()).split('.')[1]
+
+
+def cache_aev(output, dataset_path, batchsize, device=default_device,
+              constfile=builtin.const_file, subtract_sae=False,
+              sae_file=builtin.sae_file, enable_tqdm=True, **kwargs):
+    # if the output directory does not exist, then create it
+    if not os.path.exists(output):
+        os.makedirs(output)
+
+    device = torch.device(device)
+    consts = neurochem.Constants(constfile)
+    aev_computer = aev.AEVComputer(**consts).to(device)
+
+    if subtract_sae:
+        energy_shifter = neurochem.load_sae(sae_file)
+        transform = (energy_shifter.subtract_from_dataset,)
+    else:
+        transform = ()
+
+    dataset = BatchedANIDataset(
+        dataset_path, consts.species_to_tensor, batchsize,
+        device=device, transform=transform, **kwargs
+    )
+
+    # dump out the dataset
+    filename = os.path.join(output, 'dataset')
+    with open(filename, 'wb') as f:
+        pickle.dump(dataset, f)
+
+    if enable_tqdm:
+        import tqdm
+        indices = tqdm.trange(len(dataset))
+    else:
+        indices = range(len(dataset))
+    for i in indices:
+        input_, _ = dataset[i]
+        aevs = [aev_computer(j) for j in input_]
+        aevs = [(x.cpu(), y.cpu()) for x, y in aevs]
+        filename = os.path.join(output, '{}'.format(i))
+        with open(filename, 'wb') as f:
+            pickle.dump(aevs, f)
+
+
 if __name__ == '__main__':
     import argparse
-    builtin = neurochem.Builtins()

     parser = argparse.ArgumentParser()
     parser.add_argument('output',
                         help='Path of the output directory')
@@ -27,44 +72,20 @@ if __name__ == '__main__':
     parser.add_argument('--properties', nargs='+',
                         help='Output properties to load',
                         default=['energies'])
-    parser.add_argument('--dtype',
-                        help='Data type',
-                        default=str(torch.get_default_dtype()).split('.')[1])
-
-    default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    parser.add_argument('--dtype', help='Data type', default=default_dtype)
     parser.add_argument('-d', '--device', help='Device for training',
                         default=default_device)
     parser.add_argument('--no-shuffle', help='Whether to shuffle dataset',
                         dest='shuffle', action='store_false')
     parser.add_argument('--no-tqdm', dest='tqdm', action='store_false',
                         help='Whether to use tqdm to display progress')
+    parser.add_argument('--subtract-sae', dest='subtract_sae',
+                        help='Whether to subtract self atomic energies',
+                        default=None, action='store_true')
+    parser.add_argument('--sae-file', help='Path to SAE file',
+                        default=builtin.sae_file)
     parser = parser.parse_args()

-    # if output directory does not exist, then create it
-    if not os.path.exists(parser.output):
-        os.makedirs(parser.output)
-
-    device = torch.device(parser.device)
-    consts = neurochem.Constants(parser.constfile)
-    aev_computer = aev.AEVComputer(**consts).to(device)
-
-    dataset = BatchedANIDataset(parser.dataset, consts.species_to_tensor,
-                                parser.batchsize, shuffle=parser.shuffle,
-                                properties=parser.properties, device=device,
-                                dtype=getattr(torch, parser.dtype))
-
-    # dump out the dataset
-    filename = os.path.join(parser.output, 'dataset')
-    with open(filename, 'wb') as f:
-        pickle.dump(dataset, f)
-
-    if parser.tqdm:
-        import tqdm
-        indices = tqdm.trange(len(dataset))
-    else:
-        indices = range(len(dataset))
-    for i in indices:
-        input_, _ = dataset[i]
-        aevs = [aev_computer(j) for j in input_]
-        aevs = [(x.cpu(), y.cpu()) for x, y in aevs]
-        filename = os.path.join(parser.output, '{}'.format(i))
-        with open(filename, 'wb') as f:
-            pickle.dump(aevs, f)
+    # pass options by keyword so the long argument list stays unambiguous
+    # (positionally, enable_tqdm would otherwise land in the subtract_sae slot)
+    cache_aev(parser.output, parser.dataset, parser.batchsize,
+              device=parser.device, constfile=parser.constfile,
+              subtract_sae=parser.subtract_sae, sae_file=parser.sae_file,
+              enable_tqdm=parser.tqdm, shuffle=parser.shuffle,
+              properties=parser.properties, dtype=getattr(torch, parser.dtype))
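
For reference, the on-disk layout that cache_aev produces and AEVCacheLoader consumes (directory name hypothetical) is flat: one pickle for the dataset object itself, then one pickle per batch, keyed by batch index:

# aev_cache/
# ├── dataset    # pickled BatchedANIDataset; the loader reads the outputs
# │              # (e.g. energies) and the target device from it
# ├── 0          # list of (species, aev) CPU tensor pairs for batch 0
# ├── 1
# └── ...        # one file per batch, up to len(dataset) - 1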
@@ -138,3 +138,7 @@ class MaxAbsoluteError(Metric):

 def MAEMetric(key):
     """Create max absolute error metric on key."""
     return DictMetric(key, MaxAbsoluteError())
+
+
+__all__ = ['Container', 'MSELoss', 'TransformedLoss', 'RMSEMetric',
+           'MAEMetric']
@@ -673,3 +673,7 @@ class Trainer:
             decorate(trainer)
             trainer.run(self.training_set, max_epochs=math.inf)
             lr *= self.lr_decay
+
+
+__all__ = ['Constants', 'load_sae', 'load_model', 'load_model_ensemble',
+           'Trainer']
@@ -122,3 +122,6 @@ class EnergyShifter(torch.nn.Module):
         species, energies = species_energies
         sae = self.sae(species).to(energies.dtype).to(energies.device)
         return species, energies + sae
+
+
+__all__ = ['pad_and_batch', 'present_species', 'strip_redundant_padding']