Unverified Commit 99c2c3dc authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Split batch to avoid performance penalty on padding (#66)

parent 22975fa7
......@@ -16,7 +16,7 @@ parser.add_argument('-d', '--device',
default=('cuda' if torch.cuda.is_available() else 'cpu'))
parser.add_argument('--batch_size',
help='Number of conformations of each batch',
default=256, type=int)
default=1024, type=int)
parser = parser.parse_args()
# set up benchmark
......
import os
import torch
import torchani
import unittest
# Locate the dataset directory relative to this test file.
path = os.path.dirname(os.path.realpath(__file__))
dataset_path = os.path.join(path, '../dataset')
print(dataset_path)  # NOTE(review): leftover debug print — consider removing
# Batch size used by the dataset fixtures below.
batch_size = 256
aev = torchani.AEVComputer()
......@@ -16,10 +16,47 @@ class TestData(unittest.TestCase):
aev.species,
batch_size)
def _assertTensorEqual(self, t1, t2):
    """Assert that two same-shaped tensors are exactly equal element-wise."""
    max_abs_diff = (t1 - t2).abs().max()
    self.assertEqual(max_abs_diff, 0)
def testSplitBatch(self):
    """Check that split_batch partitions a padded batch correctly.

    Builds three conformation groups of different molecule sizes, pads
    and merges them into one batch, then verifies that each chunk
    returned by split_batch (a) does not share an atom count with its
    neighbor at the boundary, (b) round-trips to the matching slice of
    the full batch once redundant padding is stripped, and (c) that
    re-batching all chunks reproduces the original padded batch.
    """
    # Three groups: 5 molecules of 4 atoms, 2 of 8, 10 of 20.
    species1 = torch.randint(4, (5, 4), dtype=torch.long)
    coordinates1 = torch.randn(5, 4, 3)
    species2 = torch.randint(4, (2, 8), dtype=torch.long)
    coordinates2 = torch.randn(2, 8, 3)
    species3 = torch.randint(4, (10, 20), dtype=torch.long)
    coordinates3 = torch.randn(10, 20, 3)
    species, coordinates = torchani.padding.pad_and_batch([
        (species1, coordinates1),
        (species2, coordinates2),
        (species3, coordinates3),
    ])
    # Padding entries use negative species values; count real atoms.
    natoms = (species >= 0).to(torch.long).sum(1)
    chunks = torchani.training.data.split_batch(natoms, species,
                                                coordinates)
    start = 0
    last = None
    for s, c in chunks:
        n = (s >= 0).to(torch.long).sum(1)
        if last is not None:
            # A split between equal atom counts would be pointless.
            self.assertNotEqual(last[-1], n[0])
        conformations = s.shape[0]
        self.assertGreater(conformations, 0)
        # The chunk must equal the corresponding slice of the full
        # batch after stripping redundant padding columns.
        s_ = species[start:start+conformations, ...]
        c_ = coordinates[start:start+conformations, ...]
        s_, c_ = torchani.padding.strip_redundant_padding(s_, c_)
        self._assertTensorEqual(s, s_)
        self._assertTensorEqual(c, c_)
        start += conformations
        # BUG FIX: `last` was never updated, so the boundary check
        # above could never run; remember this chunk's atom counts.
        last = n
    # Re-batching every chunk must reproduce the original batch.
    s, c = torchani.padding.pad_and_batch(chunks)
    self._assertTensorEqual(s, species)
    self._assertTensorEqual(c, coordinates)
def testTensorShape(self):
for i in self.ds:
input, output = i
species, coordinates = input
species, coordinates = torchani.padding.pad_and_batch(input)
energies = output['energies']
self.assertEqual(len(species.shape), 2)
self.assertLessEqual(species.shape[0], batch_size)
......@@ -32,10 +69,10 @@ class TestData(unittest.TestCase):
def testNoUnnecessaryPadding(self):
    """Every chunk's widest column must contain at least one real atom.

    Each dataset item is (chunks, properties); each chunk is a
    (species, coordinates) pair. If the last atom column of a chunk
    were all padding (species < 0), strip_redundant_padding should
    have removed it.

    NOTE(review): the scraped diff interleaved the pre-patch body
    (which unpacked a single (species, coordinates) pair) with this
    post-patch loop over chunks; the stale lines are removed here.
    """
    for i in self.ds:
        for input in i[0]:
            species, _ = input
            # Indices of conformations with a real atom in the last column.
            non_padding = (species >= 0)[:, -1].nonzero()
            self.assertGreater(non_padding.numel(), 0)
if __name__ == '__main__':
......
import torch
from .. import padding
class Container(torch.nn.Module):
......@@ -10,12 +11,14 @@ class Container(torch.nn.Module):
setattr(self, 'model_' + i, models[i])
def forward(self, species_coordinates):
    """Run every contained model on every chunk and merge the results.

    Args:
        species_coordinates: list of (species, coordinates) chunk
            pairs, as produced by split_batch.

    Returns:
        dict mapping each model key to its results concatenated over
        all chunks (dim 0), plus 'species' and 'coordinates' holding
        the chunks re-padded into a single batch.

    NOTE(review): the scraped diff interleaved the pre-patch body
    (single-pair unpacking and per-key model calls on the whole
    batch) with this post-patch chunked version; the stale lines are
    removed here.
    """
    results = {k: [] for k in self.keys}
    for sc in species_coordinates:
        for k in self.keys:
            model = getattr(self, 'model_' + k)
            _, result = model(sc)
            results[k].append(result)
    # Re-pad the chunks so callers still see one (species, coordinates).
    results['species'], results['coordinates'] = \
        padding.pad_and_batch(species_coordinates)
    for k in self.keys:
        results[k] = torch.cat(results[k])
    return results
......@@ -8,6 +8,73 @@ import pickle
from .. import padding
def chunk_counts(counts, split):
    """Summarize the chunks produced by cutting ``counts`` at ``split``.

    Args:
        counts: list of ``[natoms, n_conformations]`` pairs; assumed
            sorted ascending by ``natoms`` (see split_batch).
        split: sorted indices into ``counts``; a cut is made after
            each listed index.

    Returns:
        Tuple ``(chunk_conformations, chunk_maxatoms)``: per-chunk
        total conformation counts and the largest atom count in each
        chunk.
    """
    # Convert cut positions into slice boundaries; None ends the last
    # slice at the end of the list.
    boundaries = [index + 1 for index in split] + [None]
    chunks = []
    lo = 0
    for hi in boundaries:
        chunks.append(counts[lo:hi])
        lo = hi
    totals = [sum(pair[1] for pair in chunk) for chunk in chunks]
    # counts is ordered by atom count, so each chunk's last entry
    # holds its maximum.
    maxima = [chunk[-1][0] for chunk in chunks]
    return totals, maxima
def split_cost(counts, split, split_min_cost=40000):
    """Estimate the computational cost of splitting a batch at ``split``.

    Each chunk is charged ``conformations * maxatoms ** 2`` —
    presumably because per-conformation work grows roughly
    quadratically with the padded molecule size (TODO confirm against
    the AEV computation) — floored at ``split_min_cost`` so that
    producing many tiny chunks is penalized.

    Args:
        counts: list of ``[natoms, n_conformations]`` pairs sorted by
            ``natoms`` (see chunk_counts).
        split: cut indices into ``counts``.
        split_min_cost: minimum cost charged per chunk. The default
            preserves the previously hard-coded value of 40000.

    Returns:
        Total estimated cost summed over all chunks.
    """
    cost = 0
    chunk_conformations, chunk_maxatoms = chunk_counts(counts, split)
    for conformations, maxatoms in zip(chunk_conformations, chunk_maxatoms):
        cost += max(conformations * maxatoms ** 2, split_min_cost)
    return cost
def split_batch(natoms, species, coordinates):
    """Split a batch (sorted by atom count) into chunks to cut padding waste.

    Args:
        natoms: 1-D tensor of real-atom counts per conformation; the
            batch is assumed already sorted by this value so that equal
            counts are contiguous.
        species: (conformations, atoms) padded species tensor.
        coordinates: (conformations, atoms, 3) padded coordinates.

    Returns:
        List of (species, coordinates) chunk pairs, each with its
        redundant padding columns stripped.
    """
    # Run-length encode the sorted atom counts into [natoms, count] pairs.
    natoms = natoms.tolist()
    counts = []
    for i in natoms:
        # BUG FIX (minor): folded the original's separate empty-list
        # special case into one condition; behavior is unchanged.
        if counts and i == counts[-1][0]:
            counts[-1][1] += 1
        else:
            counts.append([i, 1])
    # Greedy hill climbing: per cycle, add the single cut that lowers
    # the estimated cost the most; stop when no cut improves it.
    split = []
    cost = split_cost(counts, split)
    improved = True
    while improved:
        improved = False
        cycle_split = split
        cycle_cost = cost
        for i in range(len(counts) - 1):
            if i not in split:
                s = sorted(split + [i])
                c = split_cost(counts, s)
                if c < cycle_cost:
                    improved = True
                    cycle_cost = c
                    cycle_split = s
        if improved:
            split = cycle_split
            cost = cycle_cost
    # Slice the tensors at the chosen boundaries and strip columns
    # that are padding in every conformation of a chunk.
    # BUG FIX: removed a dead `s = species` assignment that was
    # immediately overwritten by the slice below.
    start = 0
    species_coordinates = []
    chunk_conformations, _ = chunk_counts(counts, split)
    for n in chunk_conformations:
        end = start + n
        s = species[start:end, ...]
        c = coordinates[start:end, ...]
        s, c = padding.strip_redundant_padding(s, c)
        species_coordinates.append((s, c))
        start = end
    return species_coordinates
class BatchedANIDataset(Dataset):
def __init__(self, path, species, batch_size, shuffle=True,
......@@ -71,29 +138,36 @@ class BatchedANIDataset(Dataset):
properties)
# split into minibatches, and strip reduncant padding
natoms = (species >= 0).to(torch.long).sum(1)
batches = []
num_batches = (conformations + batch_size - 1) // batch_size
for i in range(num_batches):
start = i * batch_size
end = min((i + 1) * batch_size, conformations)
species_batch = species[start:end, ...]
coordinates_batch = coordinates[start:end, ...]
natoms_batch = natoms[start:end]
natoms_batch, indices = natoms_batch.sort()
species_batch = species[start:end, ...].index_select(0, indices)
coordinates_batch = coordinates[start:end, ...] \
.index_select(0, indices)
properties_batch = {
k: properties[k][start:end, ...] for k in properties
k: properties[k][start:end, ...].index_select(0, indices)
for k in properties
}
batches.append((padding.strip_redundant_padding(species_batch,
coordinates_batch),
properties_batch))
# further split batch into chunks
species_coordinates = split_batch(natoms_batch, species_batch,
coordinates_batch)
batch = species_coordinates, properties_batch
batches.append(batch)
self.batches = batches
def __getitem__(self, idx):
    """Return batch ``idx`` with every tensor moved to ``self.device``.

    Returns:
        ``(species_coordinates, properties)`` where
        ``species_coordinates`` is a list of (species, coordinates)
        chunk pairs and ``properties`` maps names to tensors.

    NOTE(review): the scraped diff interleaved the pre-patch body
    (which unpacked a single (species, coordinates) pair) with this
    post-patch chunked version; the stale lines are removed here.
    """
    species_coordinates, properties = self.batches[idx]
    species_coordinates = [(s.to(self.device), c.to(self.device))
                           for s, c in species_coordinates]
    properties = {
        k: properties[k].to(self.device) for k in properties
    }
    return species_coordinates, properties
def __len__(self):
    # Number of prepared minibatches (built once in __init__).
    return len(self.batches)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment