Unverified Commit 10699bf7 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

allow energy shifter as transformations to dataset (#30)

parent 18e4867d
...@@ -18,7 +18,10 @@ if sys.version_info.major >= 3: ...@@ -18,7 +18,10 @@ if sys.version_info.major >= 3:
class TestIgnite(unittest.TestCase): class TestIgnite(unittest.TestCase):
def testIgnite(self): def testIgnite(self):
ds = torchani.data.ANIDataset(path, chunksize) shift_energy = torchani.EnergyShifter()
ds = torchani.data.ANIDataset(
path, chunksize,
transform=[shift_energy.dataset_subtract_sae])
loader = torchani.data.dataloader(ds, batch_chunks) loader = torchani.data.dataloader(ds, batch_chunks)
aev_computer = torchani.SortedAEV(dtype=dtype, device=device) aev_computer = torchani.SortedAEV(dtype=dtype, device=device)
nnp = torchani.models.NeuroChemNNP(aev_computer) nnp = torchani.models.NeuroChemNNP(aev_computer)
......
...@@ -8,8 +8,8 @@ import torch ...@@ -8,8 +8,8 @@ import torch
class ANIDataset(Dataset): class ANIDataset(Dataset):
def __init__(self, path, chunk_size, shuffle=True, def __init__(self, path, chunk_size, shuffle=True, properties=['energies'],
properties=['energies'], dtype=default_dtype): transform=(), dtype=default_dtype):
super(ANIDataset, self).__init__() super(ANIDataset, self).__init__()
self.path = path self.path = path
self.chunks_size = chunk_size self.chunks_size = chunk_size
...@@ -54,6 +54,8 @@ class ANIDataset(Dataset): ...@@ -54,6 +54,8 @@ class ANIDataset(Dataset):
for j in full: for j in full:
chunk[j] = full[j].index_select(0, chunk_indices) chunk[j] = full[j].index_select(0, chunk_indices)
chunk['species'] = species chunk['species'] = species
for t in transform:
chunk = t(chunk)
chunks.append(chunk) chunks.append(chunk)
self.chunks = chunks self.chunks = chunks
...@@ -80,6 +82,6 @@ def _collate(batch): ...@@ -80,6 +82,6 @@ def _collate(batch):
return inputs, outputs return inputs, outputs
def dataloader(dataset, batch_chunks, **kwargs): def dataloader(dataset, batch_chunks, shuffle=True, **kwargs):
return DataLoader(dataset, batch_chunks, dataset.shuffle, return DataLoader(dataset, batch_chunks, shuffle,
collate_fn=_collate, **kwargs) collate_fn=_collate, **kwargs)
...@@ -67,3 +67,10 @@ class EnergyShifter: ...@@ -67,3 +67,10 @@ class EnergyShifter:
for i in species: for i in species:
s += self.self_energies[i] s += self.self_energies[i]
return energies + s return energies + s
def dataset_subtract_sae(self, data):
"""Allow object of this class to be used as transforms of pytorch's
dataset.
"""
data['energies'] = self.subtract_sae(data['energies'], data['species'])
return data
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment