Unverified commit 965efee2, authored by Gao, Xiang and committed by GitHub

Helper function to create a dataloader, and a helper module that handles batches (#27)

parent 194e88ff
import sys

if sys.version_info.major >= 3:
    import os
    import unittest
    import torch
    import torchani
    import torchani.data
    import itertools

    path = os.path.dirname(os.path.realpath(__file__))
    path = os.path.join(path, 'dataset')
    chunksize = 32
    batch_chunks = 32
    dtype = torch.float32
    device = torch.device('cpu')

    class TestBatch(unittest.TestCase):

        def testBatchLoadAndInference(self):
            ds = torchani.data.ANIDataset(path, chunksize)
            loader = torchani.data.dataloader(ds, batch_chunks)
            aev_computer = torchani.SortedAEV(dtype=dtype, device=device)
            nnp = torchani.models.NeuroChemNNP(aev_computer)
            batch_nnp = torchani.models.BatchModel(nnp)
            for batch_input, batch_output in itertools.islice(loader, 10):
                batch_output_ = batch_nnp(batch_input).squeeze()
                self.assertListEqual(list(batch_output_.shape),
                                     list(batch_output['energies'].shape))


if __name__ == '__main__':
    unittest.main()
@@ -13,7 +13,7 @@ if sys.version_info.major >= 3:
        def _test_chunksize(self, chunksize):
            ds = torchani.data.ANIDataset(path, chunksize)
            for i in ds:
-                self.assertLessEqual(i[0].shape[0], chunksize)
+                self.assertLessEqual(i['coordinates'].shape[0], chunksize)

        def testChunk64(self):
            self._test_chunksize(64)
-from torch.utils.data import Dataset
+from torch.utils.data import Dataset, DataLoader
from os.path import join, isfile, isdir
from os import listdir
from .pyanitools import anidataloader
@@ -7,8 +7,13 @@ import torch
class ANIDataset(Dataset):

-    def __init__(self, path, chunk_size, randomize_chunk=True):
+    def __init__(self, path, chunk_size, shuffle=True,
+                 properties=['energies']):
        super(ANIDataset, self).__init__()
        self.path = path
        self.chunks_size = chunk_size
+        self.shuffle = shuffle
+        self.properties = properties

        # get name of files storing data
        files = []
@@ -26,11 +31,14 @@ class ANIDataset(Dataset):
        chunks = []
        for f in files:
            for m in anidataloader(f):
-                xyz = torch.from_numpy(m['coordinates'])
-                conformations = xyz.shape[0]
-                energies = torch.from_numpy(m['energies'])
+                full = {
+                    'coordinates': torch.from_numpy(m['coordinates'])
+                }
+                conformations = full['coordinates'].shape[0]
+                for i in properties:
+                    full[i] = torch.from_numpy(m[i])
                species = m['species']
-                if randomize_chunk:
+                if shuffle:
                    indices = torch.randperm(conformations)
                else:
                    indices = torch.arange(conformations, dtype=torch.int64)
@@ -39,9 +47,11 @@ class ANIDataset(Dataset):
                    chunk_start = i * chunk_size
                    chunk_end = min(chunk_start + chunk_size, conformations)
                    chunk_indices = indices[chunk_start:chunk_end]
-                    chunk_xyz = xyz.index_select(0, chunk_indices)
-                    chunk_energies = energies.index_select(0, chunk_indices)
-                    chunks.append((chunk_xyz, chunk_energies, species))
+                    chunk = {}
+                    for j in full:
+                        chunk[j] = full[j].index_select(0, chunk_indices)
+                    chunk['species'] = species
+                    chunks.append(chunk)
        self.chunks = chunks

    def __getitem__(self, idx):
@@ -49,3 +59,24 @@

    def __len__(self):
        return len(self.chunks)
+
+
+def _collate(batch):
+    input_keys = ['coordinates', 'species']
+    inputs = [{k: i[k] for k in input_keys} for i in batch]
+    outputs = {}
+    for i in batch:
+        for j in i:
+            if j in input_keys:
+                continue
+            if j not in outputs:
+                outputs[j] = []
+            outputs[j].append(i[j])
+    for i in outputs:
+        outputs[i] = torch.cat(outputs[i])
+    return inputs, outputs
+
+
+def dataloader(dataset, batch_chunks, **kwargs):
+    return DataLoader(dataset, batch_chunks, dataset.shuffle,
+                      collate_fn=_collate, **kwargs)
from .custom import CustomModel
from .neurochem_nnp import NeuroChemNNP
+from .batch import BatchModel

-__all__ = ['CustomModel', 'NeuroChemNNP']
+__all__ = ['CustomModel', 'NeuroChemNNP', 'BatchModel']
import torch


class BatchModel(torch.nn.Module):

    def __init__(self, model):
        super(BatchModel, self).__init__()
        self.model = model

    def forward(self, batch):
        results = []
        for i in batch:
            coordinates = i['coordinates']
            species = i['species']
            results.append(self.model(coordinates, species))
        return torch.cat(results)
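
BatchModel maps the wrapped model over the list of per-chunk inputs produced by _collate and concatenates the per-chunk results along the conformation dimension. A minimal sketch follows; DummyModel is a hypothetical stand-in for NeuroChemNNP that only mimics the (coordinates, species) to per-conformation output contract relied on by forward above.

import torch
from torchani.models import BatchModel  # assumes the module shown above


class DummyModel(torch.nn.Module):
    # Hypothetical stand-in: returns one value per conformation.
    def forward(self, coordinates, species):
        # coordinates: (conformations, atoms, 3); species is ignored here.
        return coordinates.sum(-1).sum(-1)


batch_nnp = BatchModel(DummyModel())
batch_input = [
    {'coordinates': torch.rand(3, 5, 3),
     'species': ['H', 'C', 'H', 'H', 'O']},
    {'coordinates': torch.rand(2, 8, 3),
     'species': ['C'] * 8},
]
print(batch_nnp(batch_input).shape)  # torch.Size([5]), chunk outputs concatenated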