Unverified commit 3957d19c authored by Richard Xue, committed by GitHub

New Dataset API add other properties (#300)

* cached

* typo and comments

* easy to read

* change some names

* fix unit test

* empty line

* fix

* fix

* add docs and add whether include_energies

* docs

* other properties for shuffled dataset

* docs

* dtype for benchmark

* add properties to test

* style
parent 47c96afe
@@ -10,6 +10,18 @@ dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s03.h5')
batch_size = 2560
chunk_threshold = 5
# full spec including dipoles and forces (kept for reference; the assignment
# below rebinds other_properties, so only 'energies' is exercised by default)
other_properties = {'properties': ['dipoles', 'forces', 'energies'],
                    'padding_values': [None, 0, None],
                    'padded_shapes': [(batch_size, 3), (batch_size, -1, 3), (batch_size, )],
                    'dtypes': [torch.float32, torch.float32, torch.float64],
                    }
other_properties = {'properties': ['energies'],
                    'padding_values': [None],
                    'padded_shapes': [(batch_size, )],
                    'dtypes': [torch.float64],
                    }
class TestFindThreshold(unittest.TestCase):
    def setUp(self):
@@ -23,15 +35,21 @@ class TestShuffledData(unittest.TestCase):
    def setUp(self):
        print('.. setup shuffle dataset')
        self.ds = torchani.data.ShuffledDataset(dspath, batch_size=batch_size,
                                                chunk_threshold=chunk_threshold,
                                                num_workers=2,
                                                other_properties=other_properties,
                                                subtract_self_energies=True)
        self.chunks, self.properties = next(iter(self.ds))
    def testTensorShape(self):
        print('=> checking tensor shape')
        print('the first batch is ([chunk1, chunk2, ...], {"energies", "force", ...}) in which chunk1=(species, coordinates)')
        batch_len = 0
        print('1. chunks')
        for i, chunk in enumerate(self.chunks):
            print('chunk{}'.format(i + 1), 'species:', list(chunk[0].size()), chunk[0].dtype,
                  'coordinates:', list(chunk[1].size()), chunk[1].dtype)
            # check dtype
            self.assertEqual(chunk[0].dtype, torch.int64)
            self.assertEqual(chunk[1].dtype, torch.float32)
@@ -39,12 +57,15 @@ class TestShuffledData(unittest.TestCase):
            self.assertEqual(chunk[1].shape[2], 3)
            self.assertEqual(chunk[1].shape[:2], chunk[0].shape[:2])
            batch_len += chunk[0].shape[0]
        print('2. properties')
        for i, key in enumerate(other_properties['properties']):
            print(key, list(self.properties[key].size()), self.properties[key].dtype)
            # check dtype
            self.assertEqual(self.properties[key].dtype, other_properties['dtypes'][i])
            # shape[0] == batch_size
            self.assertEqual(self.properties[key].shape[0], other_properties['padded_shapes'][i][0])
            # check len(shape)
            self.assertEqual(len(self.properties[key].shape), len(other_properties['padded_shapes'][i]))
    def testLoadDataset(self):
        print('=> test loading all dataset')
@@ -72,15 +93,20 @@ class TestCachedData(unittest.TestCase):
    def setUp(self):
        print('.. setup cached dataset')
        self.ds = torchani.data.CachedDataset(dspath, batch_size=batch_size, device='cpu',
                                              chunk_threshold=chunk_threshold,
                                              other_properties=other_properties,
                                              subtract_self_energies=True)
        self.chunks, self.properties = self.ds[0]
    def testTensorShape(self):
        print('=> checking tensor shape')
        print('the first batch is ([chunk1, chunk2, ...], {"energies", "force", ...}) in which chunk1=(species, coordinates)')
        batch_len = 0
        print('1. chunks')
        for i, chunk in enumerate(self.chunks):
            print('chunk{}'.format(i + 1), 'species:', list(chunk[0].size()), chunk[0].dtype,
                  'coordinates:', list(chunk[1].size()), chunk[1].dtype)
            # check dtype
            self.assertEqual(chunk[0].dtype, torch.int64)
            self.assertEqual(chunk[1].dtype, torch.float32)
@@ -88,12 +114,15 @@ class TestCachedData(unittest.TestCase):
            self.assertEqual(chunk[1].shape[2], 3)
            self.assertEqual(chunk[1].shape[:2], chunk[0].shape[:2])
            batch_len += chunk[0].shape[0]
        print('2. properties')
        for i, key in enumerate(other_properties['properties']):
            print(key, list(self.properties[key].size()), self.properties[key].dtype)
            # check dtype
            self.assertEqual(self.properties[key].dtype, other_properties['dtypes'][i])
            # shape[0] == batch_size
            self.assertEqual(self.properties[key].shape[0], other_properties['padded_shapes'][i][0])
            # check len(shape)
            self.assertEqual(len(self.properties[key].shape), len(other_properties['padded_shapes'][i]))
    def testLoadDataset(self):
        print('=> test loading all dataset')
...
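Outside the test harness, consuming one batch looks like the following minimal sketch (dspath, batch_size and other_properties as defined at the top of this file; shapes in the comments are illustrative):

    ds = torchani.data.ShuffledDataset(dspath, batch_size=batch_size,
                                       other_properties=other_properties)
    chunks, properties = next(iter(ds))
    for species, coordinates in chunks:
        # species: [molecules, atoms] int64, coordinates: [molecules, atoms, 3] float32
        print(species.shape, coordinates.shape)
    print(properties['energies'].shape)  # (batch_size,), torch.float64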
@@ -165,7 +165,7 @@ if __name__ == "__main__":
            predicted_energies.append(chunk_energies)
        num_atoms = torch.cat(num_atoms)
        predicted_energies = torch.cat(predicted_energies).to(true_energies.dtype)
        loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
        rmse = hartree2kcal((mse(predicted_energies, true_energies)).mean()).detach().cpu().numpy()
        loss.backward()
...
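The new .to(true_energies.dtype) cast is needed because energies now load as torch.float64 by default while the network output is float32; mixing the two in the loss typically raises a dtype mismatch in PyTorch. A minimal sketch of the pattern:

    import torch

    predicted = torch.randn(5, requires_grad=True)   # float32, like the network output
    true = torch.randn(5, dtype=torch.float64)       # energies now load as float64
    # cast the prediction up to the target dtype before computing the loss
    loss = torch.nn.functional.mse_loss(predicted.to(true.dtype), true)
    loss.backward()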
@@ -28,32 +28,48 @@ class CachedDataset(torch.utils.data.Dataset):
    Arguments:
        file_path (str): Path to one hdf5 file.
        batch_size (int): batch size.
        device (str): ``'cuda'`` or ``'cpu'``, cache to CPU or GPU. Commonly, ``'cpu'``
            is already fast enough. Default is ``'cpu'``.
        chunk_threshold (int): threshold at which a batch is split into chunks. Set to
            ``None`` to disable splitting.
            Use ``torchani.data.find_threshold`` to find a reasonable ``chunk_threshold``.
        other_properties (dict): A dict used to extract properties other than
            ``energies`` from the dataset with the correct padding, shape and dtype.\n
            The example below will extract ``dipoles`` and ``forces``.\n
            ``padding_values``: ``None`` means this property needs no padding.

            .. code-block:: python

                other_properties = {'properties': ['dipoles', 'forces'],
                                    'padding_values': [None, 0],
                                    'padded_shapes': [(batch_size, 3), (batch_size, -1, 3)],
                                    'dtypes': [torch.float32, torch.float32]
                                    }

        include_energies (bool): Whether to include ``energies`` in the properties.
            Default is ``True``.
        species_order (list): a list specifying how species are mapped to integers,
            e.g. ``['H', 'C', 'N', 'O']`` means ``{'H': 0, 'C': 1, 'N': 2, 'O': 3}``.
        subtract_self_energies (bool): whether to subtract self energies from ``energies``.
        self_energies (list): if ``subtract_self_energies`` is True, the order should match
            ``species_order``,
            e.g. ``[-0.600953, -38.08316, -54.707756, -75.194466]`` will be converted
            to ``{'H': -0.600953, 'C': -38.08316, 'N': -54.707756, 'O': -75.194466}``.

    .. note::
        The resulting dataset will be:
        ``([chunk1, chunk2, ...], {'energies', 'force', ...})`` in which chunk1 is a
        tuple of ``(species, coordinates)``,
        e.g. the shape of\n
        chunk1: ``[[1807, 21], [1807, 21, 3]]``\n
        chunk2: ``[[193, 50], [193, 50, 3]]``\n
        'energies': ``[2000, 1]``
    """
    def __init__(self, file_path,
                 batch_size=1000,
                 device='cpu',
                 chunk_threshold=20,
                 other_properties={},
                 include_energies=True,
                 species_order=['H', 'C', 'N', 'O'],
                 subtract_self_energies=False,
                 self_energies=[-0.600953, -38.08316, -54.707756, -75.194466]):
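A minimal usage sketch (the file name and shapes are assumptions; the tests in this commit exercise the same pattern):

    import torch
    import torchani

    batch_size = 2560
    ds = torchani.data.CachedDataset('ani_gdb_s03.h5', batch_size=batch_size,
                                     other_properties={'properties': ['forces'],
                                                       'padding_values': [0],
                                                       'padded_shapes': [(batch_size, -1, 3)],
                                                       'dtypes': [torch.float32]})
    chunks, properties = ds[0]  # 'energies' is appended automatically (include_energies=True)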
@@ -69,42 +85,53 @@ class CachedDataset(torch.utils.data.Dataset):
            species_dict[s] = i
            self_energies_dict[s] = self_energies[i]
        self.batch_size = batch_size
        self.data_species = []
        self.data_coordinates = []
        data_self_energies = []
        self.data_properties = {}
        self.properties_info = other_properties
        # whether to include energies in the properties
        if include_energies:
            self.add_energies_to_properties()
        # let the user check which properties will be loaded
        self.check_properties()
        # anidataloader
        anidata = anidataloader(file_path)
        anidata_size = anidata.group_size()
        self.enable_pkbar = anidata_size > 5 and PKBAR_INSTALLED
        if self.enable_pkbar:
            pbar = pkbar.Pbar('=> loading h5 dataset into cpu memory, total molecules: {}'.format(anidata_size), anidata_size)
        # load h5 data into cpu memory as lists
        for i, molecule in enumerate(anidata):
            # conformations
            num_conformations = len(molecule['coordinates'])
            # species and coordinates
            self.data_coordinates += list(molecule['coordinates'].reshape(num_conformations, -1).astype(np.float32))
            species = np.array([species_dict[x] for x in molecule['species']])
            self.data_species += list(np.tile(species, (num_conformations, 1)))
            # if subtract_self_energies
            if subtract_self_energies:
                self_energies = np.array(sum([self_energies_dict[x] for x in molecule['species']]))
                data_self_energies += list(np.tile(self_energies, (num_conformations, 1)))
            # properties
            for key in self.data_properties:
                self.data_properties[key] += list(molecule[key].reshape(num_conformations, -1))
            # pkbar update
            if self.enable_pkbar:
                pbar.update(i)
        # subtract self energies if requested
        if subtract_self_energies and 'energies' in self.properties_info['properties']:
            self.data_properties['energies'] = np.array(self.data_properties['energies']) - np.array(data_self_energies)
            del data_self_energies
            gc.collect()
        self.length = (len(self.data_species) + self.batch_size - 1) // self.batch_size
        self.device = device
        self.shuffled_index = np.arange(len(self.data_species))
        np.random.shuffle(self.shuffled_index)
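For a concrete picture of the self-energy bookkeeping above: the self energy of a molecule is the sum of its per-atom self energies, tiled across conformations and subtracted from the raw energies. A small sketch with the default H and O values:

    import numpy as np

    self_energies_dict = {'H': -0.600953, 'O': -75.194466}
    species = ['H', 'H', 'O']                            # one water molecule
    e_self = np.array(sum(self_energies_dict[x] for x in species))
    energies = np.array([[-76.43], [-76.41]])            # two conformations (Hartree)
    shifted = energies - np.tile(e_self, (2, 1))         # what training actually sees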
@@ -112,6 +139,7 @@ class CachedDataset(torch.utils.data.Dataset):
        if not self.chunk_threshold:
            self.chunk_threshold = np.inf
        # clean up temporaries
        anidata.cleanup()
        del num_conformations
        del species
@@ -129,7 +157,6 @@ class CachedDataset(torch.utils.data.Dataset):
        batch_species = [self.data_species[i] for i in batch_indices_shuffled]
        batch_coordinates = [self.data_coordinates[i] for i in batch_indices_shuffled]
        # get sort index
        num_atoms_each_mole = [b.shape[0] for b in batch_species]
@@ -139,7 +166,6 @@ class CachedDataset(torch.utils.data.Dataset):
        # sort each batch of data
        batch_species = self.sort_list_with_index(batch_species, sorted_atoms_idx.numpy())
        batch_coordinates = self.sort_list_with_index(batch_coordinates, sorted_atoms_idx.numpy())
        # get chunk size
        output, count = torch.unique(atoms, sorted=True, return_counts=True)
@@ -150,19 +176,32 @@ class CachedDataset(torch.utils.data.Dataset):
        # split into chunks
        chunks_batch_species = self.split_list_with_size(batch_species, chunk_size_list.numpy())
        chunks_batch_coordinates = self.split_list_with_size(batch_coordinates, chunk_size_list.numpy())
        # pad each chunk
        chunks_batch_species = self.pad_and_convert_to_tensor(chunks_batch_species, padding_value=-1)
        chunks_batch_coordinates = self.pad_and_convert_to_tensor(chunks_batch_coordinates)
        # chunks
        chunks = list(zip(chunks_batch_species, chunks_batch_coordinates))
        for i, _ in enumerate(chunks):
            chunks[i] = (chunks[i][0], chunks[i][1].reshape(chunks[i][1].shape[0], -1, 3))
        # properties
        properties = {}
        for i, key in enumerate(self.properties_info['properties']):
            # get a batch of this property (j avoids shadowing the enumerate index)
            prop = [self.data_properties[key][j] for j in batch_indices_shuffled]
            # sort by number of atoms
            prop = self.sort_list_with_index(prop, sorted_atoms_idx.numpy())
            # pad and convert to tensor
            if self.properties_info['padding_values'][i] is None:
                prop = self.pad_and_convert_to_tensor([prop], no_padding=True)[0]
            else:
                prop = self.pad_and_convert_to_tensor([prop], padding_value=self.properties_info['padding_values'][i])[0]
            # set property shape and dtype
            padded_shape = list(self.properties_info['padded_shapes'][i])
            padded_shape[0] = prop.shape[0]  # the last batch may be smaller than batch_size
            properties[key] = prop.reshape(padded_shape).to(self.properties_info['dtypes'][i])
        # return: [chunk1, chunk2, ...], {"energies", "force", ...} in which chunk1=(species, coordinates)
        # e.g. chunk1 = [[1807, 21], [1807, 21, 3]], chunk2 = [[193, 50], [193, 50, 3]]
@@ -209,6 +248,32 @@ class CachedDataset(torch.utils.data.Dataset):
            if self.enable_pkbar:
                pbar.update(i)
    def add_energies_to_properties(self):
        # if the user provided properties but did not include energies
        if 'properties' in self.properties_info and 'energies' not in self.properties_info['properties']:
            # set up the energies info, so the user does not need to spell it out
            self.properties_info['properties'].append('energies')
            self.properties_info['padding_values'].append(None)
            self.properties_info['padded_shapes'].append((self.batch_size, ))
            self.properties_info['dtypes'].append(torch.float64)
        # if no properties were provided at all
        if 'properties' not in self.properties_info:
            self.properties_info = {'properties': ['energies'],
                                    'padding_values': [None],
                                    'padded_shapes': [(self.batch_size, )],
                                    'dtypes': [torch.float64],
                                    }
    def check_properties(self):
        # print properties information
        print('... The following properties will be loaded:')
        for i, prop in enumerate(self.properties_info['properties']):
            self.data_properties[prop] = []
            message = '{}: (dtype: {}, padding_value: {}, padded_shape: {})'
            print(message.format(prop, self.properties_info['dtypes'][i],
                                 self.properties_info['padding_values'][i],
                                 self.properties_info['padded_shapes'][i]))
    @staticmethod
    def sort_list_with_index(inputs, index):
        return [inputs[i] for i in index]
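Concretely, with other_properties={} and include_energies=True, the two branches in add_energies_to_properties leave the dataset with this default spec:

    import torch

    batch_size = 1000  # the constructor default
    properties_info = {'properties': ['energies'],
                       'padding_values': [None],          # energies need no padding
                       'padded_shapes': [(batch_size, )],
                       'dtypes': [torch.float64],
                       }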
@@ -265,7 +330,10 @@ class CachedDataset(torch.utils.data.Dataset):
def ShuffledDataset(file_path,
                    batch_size=1000, num_workers=0, shuffle=True,
                    chunk_threshold=20,
                    other_properties={},
                    include_energies=True,
                    validation_split=0.0,
                    species_order=['H', 'C', 'N', 'O'],
                    subtract_self_energies=False,
@@ -278,8 +346,22 @@ def ShuffledDataset(file_path,
        num_workers (int): number of worker processes that prepare the dataset in the
            background while training is running.
        shuffle (bool): whether to shuffle.
        chunk_threshold (int): threshold at which a batch is split into chunks. Set to
            ``None`` to disable splitting.
            Use ``torchani.data.find_threshold`` to find a reasonable ``chunk_threshold``.
        other_properties (dict): A dict used to extract properties other than
            ``energies`` from the dataset with the correct padding, shape and dtype.\n
            The example below will extract ``dipoles`` and ``forces``.\n
            ``padding_values``: ``None`` means this property needs no padding.

            .. code-block:: python

                other_properties = {'properties': ['dipoles', 'forces'],
                                    'padding_values': [None, 0],
                                    'padded_shapes': [(batch_size, 3), (batch_size, -1, 3)],
                                    'dtypes': [torch.float32, torch.float32]
                                    }

        include_energies (bool): Whether to include ``energies`` in the properties.
            Default is ``True``.
        validation_split (float): Float between 0 and 1. Fraction of the dataset to be used
            as validation data.
        species_order (list): a list specifying how species are mapped to integers.
@@ -294,24 +376,27 @@ def ShuffledDataset(file_path,
    Returns a dataloader that, when iterated, yields
    ``([chunk1, chunk2, ...], {'energies', 'force', ...})`` in which chunk1 is a
    tuple of ``(species, coordinates)``,\n
    e.g. the shape of\n
    chunk1: ``[[1807, 21], [1807, 21, 3]]``\n
    chunk2: ``[[193, 50], [193, 50, 3]]``\n
    'energies': ``[2000, 1]``
    """
    dataset = TorchData(file_path,
                        batch_size,
                        other_properties,
                        include_energies,
                        species_order,
                        subtract_self_energies,
                        self_energies)
    properties_info = dataset.get_properties_info()
    if not chunk_threshold:
        chunk_threshold = np.inf

    def my_collate_fn(data, chunk_threshold=chunk_threshold, properties_info=properties_info):
        return collate_fn(data, chunk_threshold, properties_info)

    val_size = int(validation_split * len(dataset))
    train_size = len(dataset) - val_size
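my_collate_fn exists because torch.utils.data.DataLoader invokes collate_fn with the batch list as its only argument; binding the extra arguments as defaults (or via a lambda) passes them through. A generic sketch of the pattern with a hypothetical collate function:

    import torch

    def collate_with_args(batch, threshold=20, info=None):
        # stand-in for collate_fn(data, chunk_threshold, properties_info)
        return batch

    loader = torch.utils.data.DataLoader(list(range(10)), batch_size=4,
                                         collate_fn=lambda b: collate_with_args(b, threshold=5))
    for batch in loader:
        print(batch)  # plain lists of up to 4 items, untouched by default collation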
@@ -338,7 +423,13 @@ def ShuffledDataset(file_path,
class TorchData(torch.utils.data.Dataset):
    def __init__(self, file_path,
                 batch_size,
                 other_properties,
                 include_energies,
                 species_order,
                 subtract_self_energies,
                 self_energies):
        super(TorchData, self).__init__()
@@ -348,38 +439,55 @@ class TorchData(torch.utils.data.Dataset):
            species_dict[s] = i
            self_energies_dict[s] = self_energies[i]
        self.batch_size = batch_size
        self.data_species = []
        self.data_coordinates = []
        data_self_energies = []
        self.data_properties = {}
        self.properties_info = other_properties
        # whether to include energies in the properties
        if include_energies:
            self.add_energies_to_properties()
        # let the user check which properties will be loaded
        self.check_properties()
        # anidataloader
        anidata = anidataloader(file_path)
        anidata_size = anidata.group_size()
        self.enable_pkbar = anidata_size > 5 and PKBAR_INSTALLED
        if self.enable_pkbar:
            pbar = pkbar.Pbar('=> loading h5 dataset into cpu memory, total molecules: {}'.format(anidata_size), anidata_size)
        # load h5 data into cpu memory as lists
        for i, molecule in enumerate(anidata):
            # conformations
            num_conformations = len(molecule['coordinates'])
            # species and coordinates
            self.data_coordinates += list(molecule['coordinates'].reshape(num_conformations, -1).astype(np.float32))
            species = np.array([species_dict[x] for x in molecule['species']])
            self.data_species += list(np.tile(species, (num_conformations, 1)))
            # if subtract_self_energies
            if subtract_self_energies:
                self_energies = np.array(sum([self_energies_dict[x] for x in molecule['species']]))
                data_self_energies += list(np.tile(self_energies, (num_conformations, 1)))
            # properties
            for key in self.data_properties:
                self.data_properties[key] += list(molecule[key].reshape(num_conformations, -1))
            # pkbar update
            if self.enable_pkbar:
                pbar.update(i)
        # subtract self energies if requested
        if subtract_self_energies and 'energies' in self.properties_info['properties']:
            self.data_properties['energies'] = np.array(self.data_properties['energies']) - np.array(data_self_energies)
            del data_self_energies
            gc.collect()
        self.length = len(self.data_species)
        # clean up temporaries
        anidata.cleanup()
        del num_conformations
        del species
        del anidata
@@ -392,20 +500,52 @@ class TorchData(torch.utils.data.Dataset):
        species = torch.from_numpy(self.data_species[index])
        coordinates = torch.from_numpy(self.data_coordinates[index]).float()
        properties = {}
        for key in self.data_properties:
            properties[key] = torch.from_numpy(self.data_properties[key][index])
        return [species, coordinates, properties]
    def __len__(self):
        return self.length
    def add_energies_to_properties(self):
        # if the user provided properties but did not include energies
        if 'properties' in self.properties_info and 'energies' not in self.properties_info['properties']:
            # set up the energies info, so the user does not need to spell it out
            self.properties_info['properties'].append('energies')
            self.properties_info['padding_values'].append(None)
            self.properties_info['padded_shapes'].append((self.batch_size, ))
            self.properties_info['dtypes'].append(torch.float64)
        # if no properties were provided at all
        if 'properties' not in self.properties_info:
            self.properties_info = {'properties': ['energies'],
                                    'padding_values': [None],
                                    'padded_shapes': [(self.batch_size, )],
                                    'dtypes': [torch.float64],
                                    }
    def check_properties(self):
        # print properties information
        print('... The following properties will be loaded:')
        for i, prop in enumerate(self.properties_info['properties']):
            self.data_properties[prop] = []
            message = '{}: (dtype: {}, padding_value: {}, padded_shape: {})'
            print(message.format(prop, self.properties_info['dtypes'][i],
                                 self.properties_info['padding_values'][i],
                                 self.properties_info['padded_shapes'][i]))
    def get_properties_info(self):
        return self.properties_info
def collate_fn(data, chunk_threshold, properties_info):
    """Creates a batch of chunked data.
    """
    # unzip a batch of molecules (each molecule is a list)
    batch_species, batch_coordinates, batch_properties = zip(*data)
    batch_size = len(batch_species)
    # padding - time: 13.2s
@@ -415,7 +555,6 @@ def collate_fn(data, chunk_threshold):
    batch_coordinates = torch.nn.utils.rnn.pad_sequence(batch_coordinates,
                                                        batch_first=True,
                                                        padding_value=0)
    # sort - time: 0.7s
    atoms = torch.sum(~(batch_species == -1), dim=-1, dtype=torch.int32)
@@ -423,7 +562,6 @@ def collate_fn(data, chunk_threshold):
    batch_species = torch.index_select(batch_species, dim=0, index=sorted_atoms_idx)
    batch_coordinates = torch.index_select(batch_coordinates, dim=0, index=sorted_atoms_idx)
    # get chunk size - time: 2.1s
    output, count = torch.unique(atoms, sorted=True, return_counts=True)
@@ -446,7 +584,24 @@ def collate_fn(data, chunk_threshold):
    for i, _ in enumerate(chunks):
        chunks[i] = (chunks[i][0], chunks[i][1])
    # properties
    properties = {}
    for i, key in enumerate(properties_info['properties']):
        # get a batch of this property
        prop = tuple(p[key] for p in batch_properties)
        # pad and convert to tensor
        if properties_info['padding_values'][i] is None:
            prop = torch.stack(prop)
        else:
            # pad the per-molecule property tensors themselves
            prop = torch.nn.utils.rnn.pad_sequence(prop,
                                                   batch_first=True,
                                                   padding_value=properties_info['padding_values'][i])
        # sort by number of atoms
        prop = torch.index_select(prop, dim=0, index=sorted_atoms_idx)
        # set property shape and dtype
        padded_shape = list(properties_info['padded_shapes'][i])
        padded_shape[0] = prop.shape[0]  # the last batch may be smaller than batch_size
        properties[key] = prop.reshape(padded_shape).to(properties_info['dtypes'][i])
    # return: [chunk1, chunk2, ...], {"energies", "force", ...} in which chunk1=(species, coordinates)
    # e.g. chunk1 = [[1807, 21], [1807, 21, 3]], chunk2 = [[193, 50], [193, 50, 3]]
...
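A worked sketch of the property path through collate_fn: per-molecule tensors of varying length are padded with pad_sequence, reordered with the same sorted_atoms_idx used for species and coordinates, then reshaped and cast per the spec (shapes here are illustrative):

    import torch

    # three molecules with 2, 3 and 1 atoms; per-atom forces of shape [atoms, 3]
    forces = [torch.zeros(2, 3), torch.ones(3, 3), torch.full((1, 3), 2.0)]
    padded = torch.nn.utils.rnn.pad_sequence(forces, batch_first=True, padding_value=0)
    print(padded.shape)                          # torch.Size([3, 3, 3]) -> [batch, max_atoms, 3]
    sorted_atoms_idx = torch.tensor([2, 0, 1])   # ascending atom count: 1, 2, 3
    padded = torch.index_select(padded, dim=0, index=sorted_atoms_idx)
    out = padded.reshape(3, -1, 3).to(torch.float32)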