"tests/vscode:/vscode.git/clone" did not exist on "92e383c3b8c85c98050dd1d35edeffaf1c81b4ff"
Unverified Commit 965efee2 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Helper function to create a dataloader, and a helper module that handles batches (#27)

parent 194e88ff
import sys

# The data pipeline under test targets Python 3 only; do nothing on Python 2.
if sys.version_info.major >= 3:
    import os
    import itertools
    import unittest

    import torch
    import torchani
    import torchani.data

    # Directory containing the sample ANI dataset shipped with the tests.
    path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'dataset')
    # Maximum conformations per chunk, and number of chunks per batch.
    chunksize = 32
    batch_chunks = 32
    dtype = torch.float32
    device = torch.device('cpu')

    class TestBatch(unittest.TestCase):
        """End-to-end check: batched loading feeds model inference correctly."""

        def testBatchLoadAndInference(self):
            dataset = torchani.data.ANIDataset(path, chunksize)
            loader = torchani.data.dataloader(dataset, batch_chunks)
            aev_computer = torchani.SortedAEV(dtype=dtype, device=device)
            nnp = torchani.models.NeuroChemNNP(aev_computer)
            batch_nnp = torchani.models.BatchModel(nnp)
            # A handful of batches is enough to validate output shapes.
            for batch_input, batch_output in itertools.islice(loader, 10):
                predicted = batch_nnp(batch_input).squeeze()
                self.assertListEqual(list(predicted.shape),
                                     list(batch_output['energies'].shape))

    # NOTE(review): the guard is kept inside the Python-3 branch so that
    # running this file on Python 2 is a no-op — confirm against original layout.
    if __name__ == '__main__':
        unittest.main()
...@@ -13,7 +13,7 @@ if sys.version_info.major >= 3: ...@@ -13,7 +13,7 @@ if sys.version_info.major >= 3:
def _test_chunksize(self, chunksize): def _test_chunksize(self, chunksize):
ds = torchani.data.ANIDataset(path, chunksize) ds = torchani.data.ANIDataset(path, chunksize)
for i in ds: for i in ds:
self.assertLessEqual(i[0].shape[0], chunksize) self.assertLessEqual(i['coordinates'].shape[0], chunksize)
def testChunk64(self): def testChunk64(self):
self._test_chunksize(64) self._test_chunksize(64)
......
from torch.utils.data import Dataset from torch.utils.data import Dataset, DataLoader
from os.path import join, isfile, isdir from os.path import join, isfile, isdir
from os import listdir from os import listdir
from .pyanitools import anidataloader from .pyanitools import anidataloader
...@@ -7,8 +7,13 @@ import torch ...@@ -7,8 +7,13 @@ import torch
class ANIDataset(Dataset): class ANIDataset(Dataset):
def __init__(self, path, chunk_size, randomize_chunk=True): def __init__(self, path, chunk_size, shuffle=True,
properties=['energies']):
super(ANIDataset, self).__init__() super(ANIDataset, self).__init__()
self.path = path
self.chunks_size = chunk_size
self.shuffle = shuffle
self.properties = properties
# get name of files storing data # get name of files storing data
files = [] files = []
...@@ -26,11 +31,14 @@ class ANIDataset(Dataset): ...@@ -26,11 +31,14 @@ class ANIDataset(Dataset):
chunks = [] chunks = []
for f in files: for f in files:
for m in anidataloader(f): for m in anidataloader(f):
xyz = torch.from_numpy(m['coordinates']) full = {
conformations = xyz.shape[0] 'coordinates': torch.from_numpy(m['coordinates'])
energies = torch.from_numpy(m['energies']) }
conformations = full['coordinates'].shape[0]
for i in properties:
full[i] = torch.from_numpy(m[i])
species = m['species'] species = m['species']
if randomize_chunk: if shuffle:
indices = torch.randperm(conformations) indices = torch.randperm(conformations)
else: else:
indices = torch.arange(conformations, dtype=torch.int64) indices = torch.arange(conformations, dtype=torch.int64)
...@@ -39,9 +47,11 @@ class ANIDataset(Dataset): ...@@ -39,9 +47,11 @@ class ANIDataset(Dataset):
chunk_start = i * chunk_size chunk_start = i * chunk_size
chunk_end = min(chunk_start + chunk_size, conformations) chunk_end = min(chunk_start + chunk_size, conformations)
chunk_indices = indices[chunk_start:chunk_end] chunk_indices = indices[chunk_start:chunk_end]
chunk_xyz = xyz.index_select(0, chunk_indices) chunk = {}
chunk_energies = energies.index_select(0, chunk_indices) for j in full:
chunks.append((chunk_xyz, chunk_energies, species)) chunk[j] = full[j].index_select(0, chunk_indices)
chunk['species'] = species
chunks.append(chunk)
self.chunks = chunks self.chunks = chunks
def __getitem__(self, idx): def __getitem__(self, idx):
...@@ -49,3 +59,24 @@ class ANIDataset(Dataset): ...@@ -49,3 +59,24 @@ class ANIDataset(Dataset):
def __len__(self): def __len__(self):
return len(self.chunks) return len(self.chunks)
def _collate(batch):
input_keys = ['coordinates', 'species']
inputs = [{k: i[k] for k in input_keys} for i in batch]
outputs = {}
for i in batch:
for j in i:
if j in input_keys:
continue
if j not in outputs:
outputs[j] = []
outputs[j].append(i[j])
for i in outputs:
outputs[i] = torch.cat(outputs[i])
return inputs, outputs
def dataloader(dataset, batch_chunks, **kwargs):
    """Build a ``DataLoader`` over `dataset` yielding `batch_chunks` chunks per batch.

    Shuffling follows the dataset's own ``shuffle`` flag; extra keyword
    arguments are forwarded to ``DataLoader`` unchanged.
    """
    return DataLoader(dataset,
                      batch_size=batch_chunks,
                      shuffle=dataset.shuffle,
                      collate_fn=_collate,
                      **kwargs)
from .custom import CustomModel from .custom import CustomModel
from .neurochem_nnp import NeuroChemNNP from .neurochem_nnp import NeuroChemNNP
from .batch import BatchModel
__all__ = ['CustomModel', 'NeuroChemNNP'] __all__ = ['CustomModel', 'NeuroChemNNP', 'BatchModel']
import torch
class BatchModel(torch.nn.Module):
    """Apply a per-chunk model to every chunk of a batch and concatenate results.

    The wrapped `model` is called as ``model(coordinates, species)`` for each
    chunk dict in the batch; outputs are joined along dim 0.
    """

    def __init__(self, model):
        super(BatchModel, self).__init__()
        self.model = model

    def forward(self, batch):
        """Run the wrapped model on each chunk dict and concatenate the outputs."""
        outputs = [self.model(chunk['coordinates'], chunk['species'])
                   for chunk in batch]
        return torch.cat(outputs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment