"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "a6ba254fa78b063f7367d2495b9bd4b64c1eb7db"
Unverified Commit 10699bf7 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

allow energy shifter as transformations to dataset (#30)

parent 18e4867d
......@@ -18,7 +18,10 @@ if sys.version_info.major >= 3:
class TestIgnite(unittest.TestCase):
def testIgnite(self):
ds = torchani.data.ANIDataset(path, chunksize)
shift_energy = torchani.EnergyShifter()
ds = torchani.data.ANIDataset(
path, chunksize,
transform=[shift_energy.dataset_subtract_sae])
loader = torchani.data.dataloader(ds, batch_chunks)
aev_computer = torchani.SortedAEV(dtype=dtype, device=device)
nnp = torchani.models.NeuroChemNNP(aev_computer)
......
......@@ -8,8 +8,8 @@ import torch
class ANIDataset(Dataset):
def __init__(self, path, chunk_size, shuffle=True,
properties=['energies'], dtype=default_dtype):
def __init__(self, path, chunk_size, shuffle=True, properties=['energies'],
transform=(), dtype=default_dtype):
super(ANIDataset, self).__init__()
self.path = path
self.chunks_size = chunk_size
......@@ -54,6 +54,8 @@ class ANIDataset(Dataset):
for j in full:
chunk[j] = full[j].index_select(0, chunk_indices)
chunk['species'] = species
for t in transform:
chunk = t(chunk)
chunks.append(chunk)
self.chunks = chunks
......@@ -80,6 +82,6 @@ def _collate(batch):
return inputs, outputs
def dataloader(dataset, batch_chunks, **kwargs):
return DataLoader(dataset, batch_chunks, dataset.shuffle,
def dataloader(dataset, batch_chunks, shuffle=True, **kwargs):
return DataLoader(dataset, batch_chunks, shuffle,
collate_fn=_collate, **kwargs)
......@@ -67,3 +67,10 @@ class EnergyShifter:
for i in species:
s += self.self_energies[i]
return energies + s
def dataset_subtract_sae(self, data):
"""Allow object of this class to be used as transforms of pytorch's
dataset.
"""
data['energies'] = self.subtract_sae(data['energies'], data['species'])
return data
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment