Unverified Commit 51cab350 authored by Gao, Xiang, committed by GitHub

Make Container able to handle aev cache, add benchmark for aev cache (#90)

parent b546adb8
@@ -22,4 +22,4 @@ benchmark_xyz
/*.dat
/tmp
*_cache
datacache
\ No newline at end of file
@@ -23,7 +23,8 @@ Utilities
=========
.. automodule:: torchani.utils
.. autofunction:: torchani.utils.pad_and_batch
.. autofunction:: torchani.utils.pad
.. autofunction:: torchani.utils.pad_coordinates
.. autofunction:: torchani.utils.present_species
.. autofunction:: torchani.utils.strip_redundant_padding
......
@@ -44,7 +44,7 @@ class TestAEV(unittest.TestCase):
coordinates, species, radial, angular, _, _ = pickle.load(f)
species_coordinates.append((species, coordinates))
radial_angular.append((radial, angular))
species, coordinates = torchani.utils.pad_and_batch(
species, coordinates = torchani.utils.pad_coordinates(
species_coordinates)
_, aev = self.aev_computer((species, coordinates))
start = 0
......
@@ -30,7 +30,7 @@ class TestData(unittest.TestCase):
coordinates2 = torch.randn(2, 8, 3)
species3 = torch.randint(4, (10, 20), dtype=torch.long)
coordinates3 = torch.randn(10, 20, 3)
species, coordinates = torchani.utils.pad_and_batch([
species, coordinates = torchani.utils.pad_coordinates([
(species1, coordinates1),
(species2, coordinates2),
(species3, coordinates3),
@@ -52,14 +52,14 @@ class TestData(unittest.TestCase):
self._assertTensorEqual(c, c_)
start += conformations
s, c = torchani.utils.pad_and_batch(chunks)
s, c = torchani.utils.pad_coordinates(chunks)
self._assertTensorEqual(s, species)
self._assertTensorEqual(c, coordinates)
def testTensorShape(self):
for i in self.ds:
input, output = i
species, coordinates = torchani.utils.pad_and_batch(input)
species, coordinates = torchani.utils.pad_coordinates(input)
energies = output['energies']
self.assertEqual(len(species.shape), 2)
self.assertLessEqual(species.shape[0], batch_size)
......
@@ -37,7 +37,7 @@ class TestEnergies(unittest.TestCase):
coordinates, species, _, _, e, _ = pickle.load(f)
species_coordinates.append((species, coordinates))
energies.append(e)
species, coordinates = torchani.utils.pad_and_batch(
species, coordinates = torchani.utils.pad_coordinates(
species_coordinates)
energies = torch.cat(energies)
_, energies_ = self.model((species, coordinates))
......
@@ -39,7 +39,7 @@ class TestForce(unittest.TestCase):
coordinates = torch.tensor(coordinates, requires_grad=True)
species_coordinates.append((species, coordinates))
coordinates_forces.append((coordinates, forces))
species, coordinates = torchani.utils.pad_and_batch(
species, coordinates = torchani.utils.pad_coordinates(
species_coordinates)
_, energies = self.model((species, coordinates))
energies = energies.sum()
......
@@ -3,14 +3,14 @@ import torch
import torchani
class TestPadAndBatch(unittest.TestCase):
class TestPaddings(unittest.TestCase):
def testVectorSpecies(self):
species1 = torch.LongTensor([0, 2, 3, 1])
coordinates1 = torch.zeros(5, 4, 3)
species2 = torch.LongTensor([3, 2, 0, 1, 0])
coordinates2 = torch.zeros(2, 5, 3)
species, coordinates = torchani.utils.pad_and_batch([
species, coordinates = torchani.utils.pad_coordinates([
(species1, coordinates1),
(species2, coordinates2),
])
@@ -33,7 +33,7 @@ class TestPadAndBatch(unittest.TestCase):
coordinates1 = torch.zeros(5, 4, 3)
species2 = torch.LongTensor([3, 2, 0, 1, 0])
coordinates2 = torch.zeros(2, 5, 3)
species, coordinates = torchani.utils.pad_and_batch([
species, coordinates = torchani.utils.pad_coordinates([
(species1, coordinates1),
(species2, coordinates2),
])
@@ -62,7 +62,7 @@ class TestPadAndBatch(unittest.TestCase):
coordinates1 = torch.zeros(5, 4, 3)
species2 = torch.LongTensor([3, 2, 0, 1, 0])
coordinates2 = torch.zeros(2, 5, 3)
species, coordinates = torchani.utils.pad_and_batch([
species, coordinates = torchani.utils.pad_coordinates([
(species1, coordinates1),
(species2, coordinates2),
])
@@ -80,6 +80,29 @@ class TestPadAndBatch(unittest.TestCase):
self.assertEqual((species - expected_species).abs().max().item(), 0)
self.assertEqual(coordinates.abs().max().item(), 0)
def testPadSpecies(self):
species1 = torch.LongTensor([
[0, 2, 3, 1],
[0, 2, 3, 1],
[0, 2, 3, 1],
[0, 2, 3, 1],
[0, 2, 3, 1],
])
species2 = torch.LongTensor([3, 2, 0, 1, 0]).expand(2, 5)
species = torchani.utils.pad([species1, species2])
self.assertEqual(species.shape[0], 7)
self.assertEqual(species.shape[1], 5)
expected_species = torch.LongTensor([
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[3, 2, 0, 1, 0],
[3, 2, 0, 1, 0],
])
self.assertEqual((species - expected_species).abs().max().item(), 0)
def testPresentSpecies(self):
species = torch.LongTensor([0, 1, 1, 0, 3, 7, -1, -1])
present_species = torchani.utils.present_species(species)
@@ -97,13 +120,13 @@ class TestStripRedundantPadding(unittest.TestCase):
coordinates1 = torch.randn(5, 4, 3)
species2 = torch.randint(4, (2, 5), dtype=torch.long)
coordinates2 = torch.randn(2, 5, 3)
species12, coordinates12 = torchani.utils.pad_and_batch([
species12, coordinates12 = torchani.utils.pad_coordinates([
(species1, coordinates1),
(species2, coordinates2),
])
species3 = torch.randint(4, (2, 10), dtype=torch.long)
coordinates3 = torch.randn(2, 10, 3)
species123, coordinates123 = torchani.utils.pad_and_batch([
species123, coordinates123 = torchani.utils.pad_coordinates([
(species1, coordinates1),
(species2, coordinates2),
(species3, coordinates3),
......
import torch
import ignite
import torchani
import timeit
import tqdm
import argparse
# parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('cache_path',
                    help='Path of the aev cache')
parser.add_argument('-d', '--device',
                    help='Device of modules and tensors',
                    default=('cuda' if torch.cuda.is_available() else 'cpu'))
parser = parser.parse_args()
# set up benchmark
device = torch.device(parser.device)
builtins = torchani.neurochem.Builtins()
consts = builtins.consts
aev_computer = builtins.aev_computer
shift_energy = builtins.energy_shifter
def atomic():
    model = torch.nn.Sequential(
        torch.nn.Linear(384, 128),
        torch.nn.CELU(0.1),
        torch.nn.Linear(128, 128),
        torch.nn.CELU(0.1),
        torch.nn.Linear(128, 64),
        torch.nn.CELU(0.1),
        torch.nn.Linear(64, 1)
    )
    return model
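

# Note (comment added for clarity): the four copies of atomic() below give
# torchani.ANIModel one network per element of the builtin ANI model (H, C, N, O);
# each atom's AEV is dispatched to the network matching its species index, and the
# 384 input features match the builtin AEV length.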
model = torchani.ANIModel([atomic() for _ in range(4)])
class Flatten(torch.nn.Module):
    def forward(self, x):
        return x[0], x[1].flatten()
nnp = torch.nn.Sequential(model, Flatten()).to(device)
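# The Flatten module above flattens the energies returned by ANIModel, so nnp as a
# whole maps a (species, aev) chunk to (species, energies); the flatten is
# presumably just shape bookkeeping so the predictions line up with the 1-D targets.
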
dataset = torchani.data.AEVCacheLoader(parser.cache_path)
container = torchani.ignite.Container({'energies': nnp})
optimizer = torch.optim.Adam(nnp.parameters())
trainer = ignite.engine.create_supervised_trainer(
    container, optimizer, torchani.ignite.MSELoss('energies'))
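
# Note: ignite treats the Container as the model here, so each batch from the AEV
# cache loader is expected to be an (input, target) pair whose input is a sequence
# of (species, aev) chunks; MSELoss('energies') then presumably compares the
# concatenated predicted energies with the target 'energies' entry.
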
@trainer.on(ignite.engine.Events.EPOCH_STARTED)
def init_tqdm(trainer):
    trainer.state.tqdm = tqdm.tqdm(total=len(dataset), desc='epoch')


@trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
def update_tqdm(trainer):
    trainer.state.tqdm.update(1)


@trainer.on(ignite.engine.Events.EPOCH_COMPLETED)
def finalize_tqdm(trainer):
    trainer.state.tqdm.close()
timers = {}
def time_func(key, func):
    timers[key] = 0

    def wrapper(*args, **kwargs):
        start = timeit.default_timer()
        ret = func(*args, **kwargs)
        end = timeit.default_timer()
        timers[key] += end - start
        return ret

    return wrapper
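

# Illustration (not part of the original script): time_func wraps any callable and
# accumulates its wall-clock time in `timers` under the given key, e.g.
#
#     sleep = time_func('sleep', time.sleep)   # would need `import time`
#     sleep(0.1); sleep(0.2)
#     timers['sleep']   # ~0.3
#
# Below, it is used to monkey-patch the forward() of the network module inside nnp.
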
# enable timers
nnp[0].forward = time_func('forward', nnp[0].forward)
# run it!
start = timeit.default_timer()
trainer.run(dataset, max_epochs=1)
elapsed = round(timeit.default_timer() - start, 2)
print('NN:', timers['forward'])
print('Epoch time:', elapsed)
dataset.__del__()
@@ -161,7 +161,7 @@ class BatchedANIDataset(Dataset):
for i in properties:
p = torch.from_numpy(m[i]).to(torch.double)
properties[i].append(p)
species, coordinates = utils.pad_and_batch(species_coordinates)
species, coordinates = utils.pad_coordinates(species_coordinates)
for i in properties:
properties[i] = torch.cat(properties[i])
......
@@ -22,22 +22,19 @@ class Container(torch.nn.ModuleDict):
def __init__(self, modules):
super(Container, self).__init__(modules)
def forward(self, species_coordinates):
def forward(self, species_x):
"""Takes sequence of species, coordinates pair as input, and returns
computed properties as a dictionary. Same property from different
chunks will be concatenated to form a single tensor for a batch. The
input, i.e. species and coordinates of chunks, will also be batched by
:func:`torchani.utils.pad_and_batch` and copied to output.
chunks will be concatenated to form a single tensor for a batch.
"""
results = {k: [] for k in self}
for sc in species_coordinates:
for sx in species_x:
for k in self:
_, result = self[k](sc)
_, result = self[k](sx)
results[k].append(result)
for k in self:
results[k] = torch.cat(results[k])
results['species'], results['coordinates'] = \
utils.pad_and_batch(species_coordinates)
results['species'] = utils.pad([s for s, _ in species_x])
return results
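
To make the updated Container behavior concrete, here is a rough usage sketch. The ToyModel, the tensor shapes, and the 384-wide AEV are illustrative assumptions; any module that maps a (species, x) chunk to a (species, result) pair can be registered:

import torch
import torchani

class ToyModel(torch.nn.Module):
    # Hypothetical stand-in for the real network: consumes a (species, aev)
    # chunk and returns (species, per-molecule energies).
    def forward(self, species_aev):
        species, aev = species_aev
        return species, aev.sum(dim=(1, 2))

container = torchani.ignite.Container({'energies': ToyModel()})

# Two chunks with different atom counts, as a cached-AEV batch would provide them.
species1 = torch.randint(4, (3, 5), dtype=torch.long)   # 3 molecules, 5 atoms each
aev1 = torch.randn(3, 5, 384)
species2 = torch.randint(4, (2, 7), dtype=torch.long)   # 2 molecules, 7 atoms each
aev2 = torch.randn(2, 7, 384)

out = container([(species1, aev1), (species2, aev2)])
print(out['energies'].shape)  # torch.Size([5]): per-chunk results concatenated
print(out['species'].shape)   # torch.Size([5, 7]): species padded with -1 ghost atoms
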
......
import torch
def pad_and_batch(species_coordinates):
def pad(species):
"""Put different species together into single tensor.
If the species are from molecules of different number of total atoms, then
ghost atoms with atom type -1 will be added to make it fit into the same
shape.
Arguments:
species (:class:`collections.abc.Sequence`): sequence of species.
Species must be of shape ``(N, A)``, where ``N`` is the number of
3D structures, ``A`` is the number of atoms.
Returns:
:class:`torch.Tensor`: species batched together.
"""
max_atoms = max([s.shape[1] for s in species])
padded_species = []
for s in species:
natoms = s.shape[1]
if natoms < max_atoms:
padding = torch.full((s.shape[0], max_atoms - natoms), -1,
dtype=torch.long, device=s.device)
s = torch.cat([s, padding], dim=1)
padded_species.append(s)
return torch.cat(padded_species)
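
# Example for illustration (not part of the module): two species batches with
# different atom counts; the shorter one is padded with -1 ghost atoms.
#
#     s1 = torch.LongTensor([[0, 2, 3, 1]])       # shape (1, 4)
#     s2 = torch.LongTensor([[3, 2, 0, 1, 0]])    # shape (1, 5)
#     pad([s1, s2])
#     # tensor([[ 0,  2,  3,  1, -1],
#     #         [ 3,  2,  0,  1,  0]])
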
def pad_coordinates(species_coordinates):
"""Put different species and coordinates together into single tensor.
If the species and coordinates are from molecules of different number of
@@ -124,4 +151,5 @@ class EnergyShifter(torch.nn.Module):
return species, energies + sae
__all__ = ['pad_and_batch', 'present_species', 'strip_redundant_padding']
__all__ = ['pad', 'pad_coordinates', 'present_species',
'strip_redundant_padding']
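
For reference, a small usage sketch of the renamed padding helper. The inputs mirror testVectorSpecies above, so the shapes and the -1/zero padding shown in the comments are exactly what the tests assert:

import torch
import torchani

species1 = torch.LongTensor([0, 2, 3, 1])      # one species vector shared by 5 conformations
coordinates1 = torch.zeros(5, 4, 3)
species2 = torch.LongTensor([3, 2, 0, 1, 0])
coordinates2 = torch.zeros(2, 5, 3)

species, coordinates = torchani.utils.pad_coordinates([
    (species1, coordinates1),
    (species2, coordinates2),
])
print(species.shape)      # torch.Size([7, 5]): 5 + 2 conformations, padded to 5 atoms with -1
print(coordinates.shape)  # torch.Size([7, 5, 3]): padded coordinate entries are zero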