Unverified Commit 22975fa7 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

fix padding strip when batching (#65)

parent 7c9402d0
...@@ -11,10 +11,13 @@ aev = torchani.AEVComputer() ...@@ -11,10 +11,13 @@ aev = torchani.AEVComputer()
class TestData(unittest.TestCase): class TestData(unittest.TestCase):
def setUp(self):
self.ds = torchani.training.BatchedANIDataset(dataset_path,
aev.species,
batch_size)
def testTensorShape(self): def testTensorShape(self):
ds = torchani.training.BatchedANIDataset(dataset_path, aev.species, for i in self.ds:
batch_size)
for i in ds:
input, output = i input, output = i
species, coordinates = input species, coordinates = input
energies = output['energies'] energies = output['energies']
...@@ -27,6 +30,13 @@ class TestData(unittest.TestCase): ...@@ -27,6 +30,13 @@ class TestData(unittest.TestCase):
self.assertEqual(len(energies.shape), 1) self.assertEqual(len(energies.shape), 1)
self.assertEqual(coordinates.shape[0], energies.shape[0]) self.assertEqual(coordinates.shape[0], energies.shape[0])
def testNoUnnecessaryPadding(self):
for i in self.ds:
input, _ = i
species, _ = input
non_padding = (species >= 0)[:, -1].nonzero()
self.assertGreater(non_padding.numel(), 0)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -87,5 +87,36 @@ class TestPadAndBatch(unittest.TestCase): ...@@ -87,5 +87,36 @@ class TestPadAndBatch(unittest.TestCase):
self.assertEqual((expected - present_species).abs().max().item(), 0) self.assertEqual((expected - present_species).abs().max().item(), 0)
class TestStripRedundantPadding(unittest.TestCase):
def _assertTensorEqual(self, t1, t2):
self.assertEqual((t1 - t2).abs().max().item(), 0)
def testStripRestore(self):
species1 = torch.randint(4, (5, 4), dtype=torch.long)
coordinates1 = torch.randn(5, 4, 3)
species2 = torch.randint(4, (2, 5), dtype=torch.long)
coordinates2 = torch.randn(2, 5, 3)
species12, coordinates12 = torchani.padding.pad_and_batch([
(species1, coordinates1),
(species2, coordinates2),
])
species3 = torch.randint(4, (2, 10), dtype=torch.long)
coordinates3 = torch.randn(2, 10, 3)
species123, coordinates123 = torchani.padding.pad_and_batch([
(species1, coordinates1),
(species2, coordinates2),
(species3, coordinates3),
])
species1_, coordinates1_ = torchani.padding.strip_redundant_padding(
species123[:5, ...], coordinates123[:5, ...])
self._assertTensorEqual(species1_, species1)
self._assertTensorEqual(coordinates1_, coordinates1)
species12_, coordinates12_ = torchani.padding.strip_redundant_padding(
species123[:7, ...], coordinates123[:7, ...])
self._assertTensorEqual(species12_, species12)
self._assertTensorEqual(coordinates12_, coordinates12)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -30,7 +30,7 @@ def present_species(species): ...@@ -30,7 +30,7 @@ def present_species(species):
def strip_redundant_padding(species, coordinates): def strip_redundant_padding(species, coordinates):
non_padding = (species >= 0).any(dim=0) non_padding = (species >= 0).any(dim=0).nonzero().squeeze()
species = species.masked_select(non_padding, dim=1) species = species.index_select(1, non_padding)
coordinates = coordinates.masked_select(non_padding, dim=1) coordinates = coordinates.index_select(1, non_padding)
return species, coordinates return species, coordinates
...@@ -81,7 +81,8 @@ class BatchedANIDataset(Dataset): ...@@ -81,7 +81,8 @@ class BatchedANIDataset(Dataset):
properties_batch = { properties_batch = {
k: properties[k][start:end, ...] for k in properties k: properties[k][start:end, ...] for k in properties
} }
batches.append(((species_batch, coordinates_batch), batches.append((padding.strip_redundant_padding(species_batch,
coordinates_batch),
properties_batch)) properties_batch))
self.batches = batches self.batches = batches
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment