"...text-generation-inference.git" did not exist on "8b182eb98662ea781990a4a2e869eb8859e26073"
Unverified Commit 7526da82 authored by Gao, Xiang, committed by GitHub

Move core function of BatchedANIDataset.__init__ outside (#236)

parent d9c0130f
@@ -85,6 +85,80 @@ def split_batch(natoms, atomic_properties):
    return chunks
+
+def load_and_pad_whole_dataset(path, species_tensor_converter, shuffle=True,
+                               properties=('energies',), atomic_properties=()):
+    # get name of files storing data
+    files = []
+    if isdir(path):
+        for f in os.listdir(path):
+            f = join(path, f)
+            if isfile(f) and (f.endswith('.h5') or f.endswith('.hdf5')):
+                files.append(f)
+    elif isfile(path):
+        files = [path]
+    else:
+        raise ValueError('Bad path')
+
+    # load full dataset
+    atomic_properties_ = []
+    properties = {k: [] for k in properties}
+    for f in files:
+        for m in anidataloader(f):
+            atomic_properties_.append(dict(
+                species=species_tensor_converter(m['species']).unsqueeze(0),
+                **{
+                    k: torch.from_numpy(m[k]).to(torch.double)
+                    for k in ['coordinates'] + list(atomic_properties)
+                }
+            ))
+            for i in properties:
+                p = torch.from_numpy(m[i]).to(torch.double)
+                properties[i].append(p)
+    atomic_properties = utils.pad_atomic_properties(atomic_properties_)
+    for i in properties:
+        properties[i] = torch.cat(properties[i])
+
+    # shuffle if required
+    molecules = atomic_properties['species'].shape[0]
+    if shuffle:
+        indices = torch.randperm(molecules)
+        for i in properties:
+            properties[i] = properties[i].index_select(0, indices)
+        for i in atomic_properties:
+            atomic_properties[i] = atomic_properties[i].index_select(0, indices)
+    return atomic_properties, properties
+
+
+def split_whole_into_batches_and_chunks(atomic_properties, properties, batch_size):
+    molecules = atomic_properties['species'].shape[0]
+
+    # split into minibatches
+    for k in properties:
+        properties[k] = properties[k].split(batch_size)
+    for k in atomic_properties:
+        atomic_properties[k] = atomic_properties[k].split(batch_size)
+
+    # further split batch into chunks and strip redundant padding
+    batches = []
+    num_batches = (molecules + batch_size - 1) // batch_size
+    for i in range(num_batches):
+        batch_properties = {k: v[i] for k, v in properties.items()}
+        batch_atomic_properties = {k: v[i] for k, v in atomic_properties.items()}
+        species = batch_atomic_properties['species']
+        natoms = (species >= 0).to(torch.long).sum(1)
+        # sort batch by number of atoms to prepare for splitting
+        natoms, indices = natoms.sort()
+        for k in batch_properties:
+            batch_properties[k] = batch_properties[k].index_select(0, indices)
+        for k in batch_atomic_properties:
+            batch_atomic_properties[k] = batch_atomic_properties[k].index_select(0, indices)
+        batch_atomic_properties = split_batch(natoms, batch_atomic_properties)
+        batches.append((batch_atomic_properties, batch_properties))
+    return batches
class BatchedANIDataset(Dataset):
    """Load data from hdf5 files, create minibatches, and convert to tensors.
@@ -153,47 +227,12 @@ class BatchedANIDataset(Dataset):
                 dtype=torch.get_default_dtype(), device=default_device):
        super(BatchedANIDataset, self).__init__()
        self.properties = properties
+        self.atomic_properties = atomic_properties
        self.device = device
+        self.dtype = dtype
+
+        atomic_properties, properties = load_and_pad_whole_dataset(
+            path, species_tensor_converter, shuffle, properties, atomic_properties)
-        # get name of files storing data
-        files = []
-        if isdir(path):
-            for f in os.listdir(path):
-                f = join(path, f)
-                if isfile(f) and (f.endswith('.h5') or f.endswith('.hdf5')):
-                    files.append(f)
-        elif isfile(path):
-            files = [path]
-        else:
-            raise ValueError('Bad path')
-
-        # load full dataset
-        atomic_properties_ = []
-        properties = {k: [] for k in self.properties}
-        for f in files:
-            for m in anidataloader(f):
-                atomic_properties_.append(dict(
-                    species=species_tensor_converter(m['species']).unsqueeze(0),
-                    **{
-                        k: torch.from_numpy(m[k]).to(torch.double)
-                        for k in ['coordinates'] + list(atomic_properties)
-                    }
-                ))
-                for i in properties:
-                    p = torch.from_numpy(m[i]).to(torch.double)
-                    properties[i].append(p)
-        atomic_properties = utils.pad_atomic_properties(atomic_properties_)
-        for i in properties:
-            properties[i] = torch.cat(properties[i])
-
-        # shuffle if required
-        molecules = atomic_properties['species'].shape[0]
-        if shuffle:
-            indices = torch.randperm(molecules)
-            for i in properties:
-                properties[i] = properties[i].index_select(0, indices)
-            for i in atomic_properties:
-                atomic_properties[i] = atomic_properties[i].index_select(0, indices)
        # do transformations on data
        for t in transform:
@@ -207,30 +246,7 @@ class BatchedANIDataset(Dataset):
                continue
            atomic_properties[k] = atomic_properties[k].to(dtype)
+        self.batches = split_whole_into_batches_and_chunks(atomic_properties, properties, batch_size)
-        # split into minibatches
-        for k in properties:
-            properties[k] = properties[k].split(batch_size)
-        for k in atomic_properties:
-            atomic_properties[k] = atomic_properties[k].split(batch_size)
-
-        # further split batch into chunks and strip redundant padding
-        self.batches = []
-        num_batches = (molecules + batch_size - 1) // batch_size
-        for i in range(num_batches):
-            batch_properties = {k: v[i] for k, v in properties.items()}
-            batch_atomic_properties = {k: v[i] for k, v in atomic_properties.items()}
-            species = batch_atomic_properties['species']
-            natoms = (species >= 0).to(torch.long).sum(1)
-            # sort batch by number of atoms to prepare for splitting
-            natoms, indices = natoms.sort()
-            for k in batch_properties:
-                batch_properties[k] = batch_properties[k].index_select(0, indices)
-            for k in batch_atomic_properties:
-                batch_atomic_properties[k] = batch_atomic_properties[k].index_select(0, indices)
-            batch_atomic_properties = split_batch(natoms, batch_atomic_properties)
-            self.batches.append((batch_atomic_properties, batch_properties))
    def __getitem__(self, idx):
        atomic_properties, properties = self.batches[idx]
...
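To illustrate the point of the refactoring, here is a minimal sketch of how the two extracted helpers could be driven on their own, outside of `BatchedANIDataset`. It is not part of the commit: the import path `torchani.data`, the file name `my_dataset.h5`, the species order, and the hand-rolled converter below are all assumptions made for the example.

```python
import torch

# assumption: the helpers live in torchani/data/__init__.py (the file in this diff)
from torchani.data import load_and_pad_whole_dataset, split_whole_into_batches_and_chunks

# a minimal species converter: maps chemical symbols to integer indices
species_order = ['H', 'C', 'N', 'O']

def species_to_tensor(symbols):
    return torch.tensor([species_order.index(s) for s in symbols], dtype=torch.long)

# load, pad, and shuffle the whole dataset in one call
atomic_properties, properties = load_and_pad_whole_dataset(
    'my_dataset.h5',          # hypothetical HDF5 file path
    species_to_tensor,
    shuffle=True,
    properties=('energies',))

# then cut the padded whole into minibatches and padding-stripped chunks
batches = split_whole_into_batches_and_chunks(
    atomic_properties, properties, batch_size=256)
print(len(batches), 'minibatches')
```

Each entry of `batches` is a `(chunked atomic properties, per-molecule properties)` pair, the same structure `BatchedANIDataset.__init__` now builds by calling these two functions in sequence.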