Unverified Commit e69e59a4 authored by Gao, Xiang, committed by GitHub

Get rid of multiprocessing (#91)

parent 51cab350
@@ -185,13 +185,3 @@ def log_loss(trainer):
 trainer.run(training, max_epochs)
-###############################################################################
-# In the end, we explicitly close the opened loaders' processes. If the loading
-# processes are not closed, they would prevent the whole program from
-# terminating. Closing the loading processes can be done automatically when a
-# :class:`torchani.data.AEVCacheLoader` object is garbage collected, but since
-# our cache loader objects here are in global scope, they won't be garbage
-# collected, so we need to terminate these processes manually.
-training.__del__()
-validation.__del__()
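The removed comment turns on one detail: a non-daemonic child process keeps the Python interpreter alive until it exits or is terminated, which is why these loader processes had to be shut down by hand. A minimal standalone sketch of that behavior (illustrative only, not part of the commit):

```python
import multiprocessing
import time


def worker():
    # Loops forever, like the old _disk_cache_loader did.
    while True:
        time.sleep(1)


if __name__ == '__main__':
    p = multiprocessing.Process(target=worker)
    p.start()
    # Without terminate(), the interpreter would block at exit waiting
    # for this non-daemonic child -- the hang the comment describes.
    p.terminate()
    p.join()
```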
@@ -93,4 +93,3 @@ trainer.run(dataset, max_epochs=1)
 elapsed = round(timeit.default_timer() - start, 2)
 print('NN:', timers['forward'])
 print('Epoch time:', elapsed)
-dataset.__del__()
@@ -200,7 +200,7 @@ class BatchedANIDataset(Dataset):
                .index_select(0, indices)
            properties_batch = {
                k: properties[k][start:end, ...].index_select(0, indices)
-               for k in properties
+               .to(self.device) for k in properties
            }
            # further split batch into chunks
            species_coordinates = split_batch(natoms_batch, species_batch,
@@ -213,24 +213,12 @@ class BatchedANIDataset(Dataset):
        species_coordinates, properties = self.batches[idx]
        species_coordinates = [(s.to(self.device), c.to(self.device))
                               for s, c in species_coordinates]
-       properties = {
-           k: properties[k].to(self.device) for k in properties
-       }
        return species_coordinates, properties

    def __len__(self):
        return len(self.batches)


-def _disk_cache_loader(index_queue, tensor_queue, disk_cache, device):
-    """Get index and load from disk cache."""
-    while True:
-        index = index_queue.get()
-        aev_path = os.path.join(disk_cache, str(index))
-        with open(aev_path, 'rb') as f:
-            tensor_queue.put(pickle.load(f))
-
-
class AEVCacheLoader:
    """Build a factory for AEV.
@@ -244,52 +232,22 @@ class AEVCacheLoader:
        disk_cache (str): Directory storing disk caches.
    """

-   def __init__(self, disk_cache=None, in_memory_size=64):
-       self.current = 0
+   def __init__(self, disk_cache=None):
        self.disk_cache = disk_cache
        # load dataset from disk cache
        dataset_path = os.path.join(disk_cache, 'dataset')
        with open(dataset_path, 'rb') as f:
            self.dataset = pickle.load(f)
-       # initialize queues and processes
-       self.tensor_queue = torch.multiprocessing.Queue()
-       self.index_queue = torch.multiprocessing.Queue()
-       self.in_memory_size = in_memory_size
-       if len(self.dataset) < in_memory_size:
-           self.in_memory_size = len(self.dataset)
-       for i in range(self.in_memory_size):
-           self.index_queue.put(i)
-       self.loader = torch.multiprocessing.Process(
-           target=_disk_cache_loader,
-           args=(self.index_queue, self.tensor_queue, disk_cache,
-                 self.dataset.device)
-       )
-       self.loader.start()
-
-   def __iter__(self):
-       if self.current != 0:
-           raise ValueError('Only one iterator of AEVCacheLoader is allowed')
-       else:
-           return self
-
-   def __next__(self):
-       if self.current < len(self.dataset):
-           new_idx = (self.current + self.in_memory_size) % len(self.dataset)
-           self.index_queue.put(new_idx)
-           species_aevs = self.tensor_queue.get()
-           species_aevs = [(x.to(self.dataset.device),
-                            y.to(self.dataset.device))
-                           for x, y in species_aevs]
-           _, output = self.dataset[self.current]
-           self.current += 1
-           return species_aevs, output
-       else:
-           self.current = 0
-           raise StopIteration
-
-   def __del__(self):
-       self.loader.terminate()
+   def __getitem__(self, index):
+       if index >= self.__len__():
+           raise IndexError()
+       aev_path = os.path.join(self.disk_cache, str(index))
+       with open(aev_path, 'rb') as f:
+           species_aevs = pickle.load(f)
+       _, output = self.dataset.batches[index]
+       return species_aevs, output

    def __len__(self):
        return len(self.dataset)
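With this change, AEVCacheLoader is a plain map-style container: __getitem__ unpickles one cached AEV batch from disk and __len__ reports the batch count, so ordinary indexing replaces the old single-use iterator. A hypothetical usage sketch (the cache path is illustrative, assuming a cache previously written by cache_aev):

```python
from torchani.data import AEVCacheLoader

loader = AEVCacheLoader(disk_cache='./aev_cache')  # illustrative path
for i in range(len(loader)):
    species_aevs, properties = loader[i]
    # ... feed this batch to the model ...
# No manual __del__() is needed afterwards: there is no worker process.
```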
@@ -333,7 +291,6 @@ def cache_aev(output, dataset_path, batchsize, device=default_device,
    for i in indices:
        input_, _ = dataset[i]
        aevs = [aev_computer(j) for j in input_]
-       aevs = [(x.cpu(), y.cpu()) for x, y in aevs]
        filename = os.path.join(output, '{}'.format(i))
        with open(filename, 'wb') as f:
            pickle.dump(aevs, f)
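Dropping the .cpu() copy relies on pickle preserving a tensor's device, so AEVs computed on the GPU are cached and later restored there directly (unpickling a CUDA tensor does require CUDA to be available). A quick round-trip check of that behavior:

```python
import pickle

import torch

t = torch.arange(3)  # would live on the GPU when caching with a CUDA device
restored = pickle.loads(pickle.dumps(t))
print(restored.device)  # same device the tensor was pickled on
```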
@@ -44,7 +44,7 @@ class ANIModel(torch.nn.ModuleList):
        for i in present_species:
            mask = (species_ == i)
            input = aev.index_select(0, mask.nonzero().squeeze())
-           output[mask] = self[i](input).squeeze()
+           output.masked_scatter_(mask, self[i](input).squeeze())
        output = output.view_as(species)
        return species, self.reducer(output, dim=1)
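Replacing the in-place advanced-indexing assignment with masked_scatter_ keeps the semantics: source values are written, in order, into the positions where the mask is True. A minimal demonstration of that contract:

```python
import torch

out = torch.zeros(4)
mask = torch.tensor([True, False, True, False])
src = torch.tensor([1.0, 2.0])
out.masked_scatter_(mask, src)
print(out)  # tensor([1., 0., 2., 0.])
```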