Unverified commit 3957d19c authored by Richard Xue, committed by GitHub

New Dataset API: add other properties (#300)

* cached

* typo and comments

* easy to read

* change some names

* fix unit test

* empty line

* fix

* fix

* add docs and an include_energies option

* docs

* other properties for shuffled dataset

* docs

* dtype for benchmark

* add properties to test

* style
parent 47c96afe
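For context, a minimal sketch of how the new `other_properties` argument is exercised by the tests in this commit. The keyword arguments and dict keys mirror the test file below; the dataset path is a placeholder, and `next(iter(ds))` is simply the plain spelling of the iteration the test performs.

```python
import torch
import torchani

batch_size = 2560

# Which extra properties to load with each batch, how to pad them, and which
# dtypes to cast them to. Entries are positional: 'dipoles' pads to
# (batch_size, 3), 'forces' to (batch_size, -1, 3), 'energies' to (batch_size,).
other_properties = {'properties': ['dipoles', 'forces', 'energies'],
                    'padding_values': [None, 0, None],
                    'padded_shapes': [(batch_size, 3), (batch_size, -1, 3), (batch_size,)],
                    'dtypes': [torch.float32, torch.float32, torch.float64]}

# 'path/to/ani_dataset.h5' is a placeholder for an ANI-style HDF5 file that
# actually contains the requested properties.
ds = torchani.data.ShuffledDataset('path/to/ani_dataset.h5',
                                   batch_size=batch_size,
                                   chunk_threshold=5,
                                   num_workers=2,
                                   other_properties=other_properties,
                                   subtract_self_energies=True)

# A batch is ([chunk1, chunk2, ...], {'dipoles': ..., 'forces': ..., 'energies': ...}),
# where each chunk is a (species, coordinates) pair.
chunks, properties = next(iter(ds))
for species, coordinates in chunks:
    print(species.shape, coordinates.shape)   # int64 and float32 tensors
print({k: (v.shape, v.dtype) for k, v in properties.items()})
```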
@@ -10,6 +10,18 @@ dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s03.h5')
 batch_size = 2560
 chunk_threshold = 5
+
+other_properties = {'properties': ['dipoles', 'forces', 'energies'],
+                    'padding_values': [None, 0, None],
+                    'padded_shapes': [(batch_size, 3), (batch_size, -1, 3), (batch_size, )],
+                    'dtypes': [torch.float32, torch.float32, torch.float64],
+                    }
+
+other_properties = {'properties': ['energies'],
+                    'padding_values': [None],
+                    'padded_shapes': [(batch_size, )],
+                    'dtypes': [torch.float64],
+                    }
 
 
 class TestFindThreshold(unittest.TestCase):
     def setUp(self):
@@ -23,15 +35,21 @@ class TestShuffledData(unittest.TestCase):
 
     def setUp(self):
         print('.. setup shuffle dataset')
-        self.ds = torchani.data.ShuffledDataset(dspath, batch_size=batch_size, chunk_threshold=chunk_threshold, num_workers=2)
+        self.ds = torchani.data.ShuffledDataset(dspath, batch_size=batch_size,
+                                                chunk_threshold=chunk_threshold,
+                                                num_workers=2,
+                                                other_properties=other_properties,
+                                                subtract_self_energies=True)
         self.chunks, self.properties = iter(self.ds).next()
 
     def testTensorShape(self):
         print('=> checking tensor shape')
         print('the first batch is ([chunk1, chunk2, ...], {"energies", "force", ...}) in which chunk1=(species, coordinates)')
         batch_len = 0
+        print('1. chunks')
         for i, chunk in enumerate(self.chunks):
-            print('chunk{}'.format(i + 1), list(chunk[0].size()), chunk[0].dtype, list(chunk[1].size()), chunk[1].dtype)
+            print('chunk{}'.format(i + 1), 'species:', list(chunk[0].size()), chunk[0].dtype,
+                  'coordinates:', list(chunk[1].size()), chunk[1].dtype)
             # check dtype
             self.assertEqual(chunk[0].dtype, torch.int64)
             self.assertEqual(chunk[1].dtype, torch.float32)
@@ -39,12 +57,15 @@ class TestShuffledData(unittest.TestCase):
             self.assertEqual(chunk[1].shape[2], 3)
             self.assertEqual(chunk[1].shape[:2], chunk[0].shape[:2])
             batch_len += chunk[0].shape[0]
-        for key, value in self.properties.items():
-            print(key, list(value.size()), value.dtype)
-            self.assertEqual(value.dtype, torch.float32)
-            self.assertEqual(len(value.shape), 1)
-            self.assertEqual(value.shape[0], batch_len)
+        print('2. properties')
+        for i, key in enumerate(other_properties['properties']):
+            print(key, list(self.properties[key].size()), self.properties[key].dtype)
+            # check dtype
+            self.assertEqual(self.properties[key].dtype, other_properties['dtypes'][i])
+            # shape[0] == batch_size
+            self.assertEqual(self.properties[key].shape[0], other_properties['padded_shapes'][i][0])
+            # check len(shape)
+            self.assertEqual(len(self.properties[key].shape), len(other_properties['padded_shapes'][i]))
 
     def testLoadDataset(self):
         print('=> test loading all dataset')
@@ -72,15 +93,20 @@ class TestCachedData(unittest.TestCase):
 
     def setUp(self):
         print('.. setup cached dataset')
-        self.ds = torchani.data.CachedDataset(dspath, batch_size=batch_size, device='cpu', chunk_threshold=chunk_threshold)
+        self.ds = torchani.data.CachedDataset(dspath, batch_size=batch_size, device='cpu',
+                                              chunk_threshold=chunk_threshold,
+                                              other_properties=other_properties,
+                                              subtract_self_energies=True)
         self.chunks, self.properties = self.ds[0]
 
     def testTensorShape(self):
         print('=> checking tensor shape')
         print('the first batch is ([chunk1, chunk2, ...], {"energies", "force", ...}) in which chunk1=(species, coordinates)')
         batch_len = 0
+        print('1. chunks')
         for i, chunk in enumerate(self.chunks):
-            print('chunk{}'.format(i + 1), list(chunk[0].size()), chunk[0].dtype, list(chunk[1].size()), chunk[1].dtype)
+            print('chunk{}'.format(i + 1), 'species:', list(chunk[0].size()), chunk[0].dtype,
+                  'coordinates:', list(chunk[1].size()), chunk[1].dtype)
             # check dtype
             self.assertEqual(chunk[0].dtype, torch.int64)
             self.assertEqual(chunk[1].dtype, torch.float32)
@@ -88,12 +114,15 @@ class TestCachedData(unittest.TestCase):
             self.assertEqual(chunk[1].shape[2], 3)
             self.assertEqual(chunk[1].shape[:2], chunk[0].shape[:2])
             batch_len += chunk[0].shape[0]
-        for key, value in self.properties.items():
-            print(key, list(value.size()), value.dtype)
-            self.assertEqual(value.dtype, torch.float32)
-            self.assertEqual(len(value.shape), 1)
-            self.assertEqual(value.shape[0], batch_len)
+        print('2. properties')
+        for i, key in enumerate(other_properties['properties']):
+            print(key, list(self.properties[key].size()), self.properties[key].dtype)
+            # check dtype
+            self.assertEqual(self.properties[key].dtype, other_properties['dtypes'][i])
+            # shape[0] == batch_size
+            self.assertEqual(self.properties[key].shape[0], other_properties['padded_shapes'][i][0])
+            # check len(shape)
+            self.assertEqual(len(self.properties[key].shape), len(other_properties['padded_shapes'][i]))
 
     def testLoadDataset(self):
         print('=> test loading all dataset')
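The CachedDataset test follows the same pattern but indexes batches instead of iterating. A short sketch, again with a placeholder path and the energies-only dict from the tests:

```python
import torch
import torchani

batch_size = 2560

# Energies-only configuration, matching the second dict in the test file.
other_properties = {'properties': ['energies'],
                    'padding_values': [None],
                    'padded_shapes': [(batch_size,)],
                    'dtypes': [torch.float64]}

ds = torchani.data.CachedDataset('path/to/ani_dataset.h5',   # placeholder path
                                 batch_size=batch_size,
                                 device='cpu',
                                 chunk_threshold=5,
                                 other_properties=other_properties,
                                 subtract_self_energies=True)

# Batches are addressed by index, e.g. the first batch:
chunks, properties = ds[0]
print(properties['energies'].shape, properties['energies'].dtype)  # (batch_size,), torch.float64
```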
@@ -165,7 +165,7 @@ if __name__ == "__main__":
                 predicted_energies.append(chunk_energies)
             num_atoms = torch.cat(num_atoms)
-            predicted_energies = torch.cat(predicted_energies)
+            predicted_energies = torch.cat(predicted_energies).to(true_energies.dtype)
             loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
             rmse = hartree2kcal((mse(predicted_energies, true_energies)).mean()).detach().cpu().numpy()
             loss.backward()
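The benchmark change above is a dtype fix: with energies now loaded as torch.float64 via `other_properties` while the network output stays float32, the predictions are cast to the target dtype before the loss. A standalone sketch of the pattern, using made-up tensors and assuming `mse` is an elementwise `torch.nn.MSELoss(reduction='none')` as in the benchmark script:

```python
import torch

mse = torch.nn.MSELoss(reduction='none')

# Stand-ins for the benchmark's variables: float32 model output vs. float64 targets.
predicted_energies = torch.randn(8)                      # torch.float32
true_energies = torch.randn(8, dtype=torch.float64)      # torch.float64
num_atoms = torch.randint(2, 20, (8,)).double()          # atoms per molecule

# Cast predictions to the target dtype so the loss sees matching dtypes.
predicted_energies = predicted_energies.to(true_energies.dtype)
loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
print(loss)
```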