Unverified commit 22975fa7, authored by Gao, Xiang and committed by GitHub
Browse files

fix padding strip when batching (#65)

parent 7c9402d0
......@@ -11,10 +11,13 @@ aev = torchani.AEVComputer()
class TestData(unittest.TestCase):
def setUp(self):
    """Create the batched dataset that the test methods iterate over."""
    dataset = torchani.training.BatchedANIDataset(
        dataset_path, aev.species, batch_size)
    self.ds = dataset
def testTensorShape(self):
ds = torchani.training.BatchedANIDataset(dataset_path, aev.species,
batch_size)
for i in ds:
for i in self.ds:
input, output = i
species, coordinates = input
energies = output['energies']
......@@ -27,6 +30,13 @@ class TestData(unittest.TestCase):
self.assertEqual(len(energies.shape), 1)
self.assertEqual(coordinates.shape[0], energies.shape[0])
def testNoUnnecessaryPadding(self):
for i in self.ds:
input, _ = i
species, _ = input
non_padding = (species >= 0)[:, -1].nonzero()
self.assertGreater(non_padding.numel(), 0)
# Run the test cases in this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()
......@@ -87,5 +87,36 @@ class TestPadAndBatch(unittest.TestCase):
self.assertEqual((expected - present_species).abs().max().item(), 0)
class TestStripRedundantPadding(unittest.TestCase):
def _assertTensorEqual(self, t1, t2):
self.assertEqual((t1 - t2).abs().max().item(), 0)
def testStripRestore(self):
species1 = torch.randint(4, (5, 4), dtype=torch.long)
coordinates1 = torch.randn(5, 4, 3)
species2 = torch.randint(4, (2, 5), dtype=torch.long)
coordinates2 = torch.randn(2, 5, 3)
species12, coordinates12 = torchani.padding.pad_and_batch([
(species1, coordinates1),
(species2, coordinates2),
])
species3 = torch.randint(4, (2, 10), dtype=torch.long)
coordinates3 = torch.randn(2, 10, 3)
species123, coordinates123 = torchani.padding.pad_and_batch([
(species1, coordinates1),
(species2, coordinates2),
(species3, coordinates3),
])
species1_, coordinates1_ = torchani.padding.strip_redundant_padding(
species123[:5, ...], coordinates123[:5, ...])
self._assertTensorEqual(species1_, species1)
self._assertTensorEqual(coordinates1_, coordinates1)
species12_, coordinates12_ = torchani.padding.strip_redundant_padding(
species123[:7, ...], coordinates123[:7, ...])
self._assertTensorEqual(species12_, species12)
self._assertTensorEqual(coordinates12_, coordinates12)
# Run the test cases in this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()
......@@ -30,7 +30,7 @@ def present_species(species):
def strip_redundant_padding(species, coordinates):
    """Drop the columns (dim 1) that hold padding in every row.

    A column is kept when at least one row has a non-negative species value
    in it; columns where every species value is negative are removed from
    both tensors.

    Args:
        species: integer tensor indexed as ``[row, column]``; negative values
            are treated as padding.
        coordinates: tensor whose dim 1 lines up with dim 1 of ``species``
            (the tests use shape ``(rows, columns, 3)``).

    Returns:
        Tuple ``(species, coordinates)`` restricted to the kept columns.
    """
    # nonzero() yields shape (k, 1).  Squeeze only dim 1 so the index stays
    # 1-D even when k == 1; a bare squeeze() would collapse it to 0-d.
    keep = (species >= 0).any(dim=0).nonzero().squeeze(1)
    species = species.index_select(1, keep)
    coordinates = coordinates.index_select(1, keep)
    return species, coordinates
......@@ -81,7 +81,8 @@ class BatchedANIDataset(Dataset):
properties_batch = {
k: properties[k][start:end, ...] for k in properties
}
batches.append(((species_batch, coordinates_batch),
batches.append((padding.strip_redundant_padding(species_batch,
coordinates_batch),
properties_batch))
self.batches = batches
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment