Unverified Commit 51cab350 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Make Container able to handle aev cache, add benchmark for aev cache (#90)

parent b546adb8
...@@ -22,4 +22,4 @@ benchmark_xyz ...@@ -22,4 +22,4 @@ benchmark_xyz
/*.dat /*.dat
/tmp /tmp
*_cache *_cache
datacache
\ No newline at end of file
...@@ -23,7 +23,8 @@ Utilities ...@@ -23,7 +23,8 @@ Utilities
========= =========
.. automodule:: torchani.utils .. automodule:: torchani.utils
.. autofunction:: torchani.utils.pad_and_batch .. autofunction:: torchani.utils.pad
.. autofunction:: torchani.utils.pad_coordinates
.. autofunction:: torchani.utils.present_species .. autofunction:: torchani.utils.present_species
.. autofunction:: torchani.utils.strip_redundant_padding .. autofunction:: torchani.utils.strip_redundant_padding
......
...@@ -44,7 +44,7 @@ class TestAEV(unittest.TestCase): ...@@ -44,7 +44,7 @@ class TestAEV(unittest.TestCase):
coordinates, species, radial, angular, _, _ = pickle.load(f) coordinates, species, radial, angular, _, _ = pickle.load(f)
species_coordinates.append((species, coordinates)) species_coordinates.append((species, coordinates))
radial_angular.append((radial, angular)) radial_angular.append((radial, angular))
species, coordinates = torchani.utils.pad_and_batch( species, coordinates = torchani.utils.pad_coordinates(
species_coordinates) species_coordinates)
_, aev = self.aev_computer((species, coordinates)) _, aev = self.aev_computer((species, coordinates))
start = 0 start = 0
......
...@@ -30,7 +30,7 @@ class TestData(unittest.TestCase): ...@@ -30,7 +30,7 @@ class TestData(unittest.TestCase):
coordinates2 = torch.randn(2, 8, 3) coordinates2 = torch.randn(2, 8, 3)
species3 = torch.randint(4, (10, 20), dtype=torch.long) species3 = torch.randint(4, (10, 20), dtype=torch.long)
coordinates3 = torch.randn(10, 20, 3) coordinates3 = torch.randn(10, 20, 3)
species, coordinates = torchani.utils.pad_and_batch([ species, coordinates = torchani.utils.pad_coordinates([
(species1, coordinates1), (species1, coordinates1),
(species2, coordinates2), (species2, coordinates2),
(species3, coordinates3), (species3, coordinates3),
...@@ -52,14 +52,14 @@ class TestData(unittest.TestCase): ...@@ -52,14 +52,14 @@ class TestData(unittest.TestCase):
self._assertTensorEqual(c, c_) self._assertTensorEqual(c, c_)
start += conformations start += conformations
s, c = torchani.utils.pad_and_batch(chunks) s, c = torchani.utils.pad_coordinates(chunks)
self._assertTensorEqual(s, species) self._assertTensorEqual(s, species)
self._assertTensorEqual(c, coordinates) self._assertTensorEqual(c, coordinates)
def testTensorShape(self): def testTensorShape(self):
for i in self.ds: for i in self.ds:
input, output = i input, output = i
species, coordinates = torchani.utils.pad_and_batch(input) species, coordinates = torchani.utils.pad_coordinates(input)
energies = output['energies'] energies = output['energies']
self.assertEqual(len(species.shape), 2) self.assertEqual(len(species.shape), 2)
self.assertLessEqual(species.shape[0], batch_size) self.assertLessEqual(species.shape[0], batch_size)
......
...@@ -37,7 +37,7 @@ class TestEnergies(unittest.TestCase): ...@@ -37,7 +37,7 @@ class TestEnergies(unittest.TestCase):
coordinates, species, _, _, e, _ = pickle.load(f) coordinates, species, _, _, e, _ = pickle.load(f)
species_coordinates.append((species, coordinates)) species_coordinates.append((species, coordinates))
energies.append(e) energies.append(e)
species, coordinates = torchani.utils.pad_and_batch( species, coordinates = torchani.utils.pad_coordinates(
species_coordinates) species_coordinates)
energies = torch.cat(energies) energies = torch.cat(energies)
_, energies_ = self.model((species, coordinates)) _, energies_ = self.model((species, coordinates))
......
...@@ -39,7 +39,7 @@ class TestForce(unittest.TestCase): ...@@ -39,7 +39,7 @@ class TestForce(unittest.TestCase):
coordinates = torch.tensor(coordinates, requires_grad=True) coordinates = torch.tensor(coordinates, requires_grad=True)
species_coordinates.append((species, coordinates)) species_coordinates.append((species, coordinates))
coordinates_forces.append((coordinates, forces)) coordinates_forces.append((coordinates, forces))
species, coordinates = torchani.utils.pad_and_batch( species, coordinates = torchani.utils.pad_coordinates(
species_coordinates) species_coordinates)
_, energies = self.model((species, coordinates)) _, energies = self.model((species, coordinates))
energies = energies.sum() energies = energies.sum()
......
...@@ -3,14 +3,14 @@ import torch ...@@ -3,14 +3,14 @@ import torch
import torchani import torchani
class TestPadAndBatch(unittest.TestCase): class TestPaddings(unittest.TestCase):
def testVectorSpecies(self): def testVectorSpecies(self):
species1 = torch.LongTensor([0, 2, 3, 1]) species1 = torch.LongTensor([0, 2, 3, 1])
coordinates1 = torch.zeros(5, 4, 3) coordinates1 = torch.zeros(5, 4, 3)
species2 = torch.LongTensor([3, 2, 0, 1, 0]) species2 = torch.LongTensor([3, 2, 0, 1, 0])
coordinates2 = torch.zeros(2, 5, 3) coordinates2 = torch.zeros(2, 5, 3)
species, coordinates = torchani.utils.pad_and_batch([ species, coordinates = torchani.utils.pad_coordinates([
(species1, coordinates1), (species1, coordinates1),
(species2, coordinates2), (species2, coordinates2),
]) ])
...@@ -33,7 +33,7 @@ class TestPadAndBatch(unittest.TestCase): ...@@ -33,7 +33,7 @@ class TestPadAndBatch(unittest.TestCase):
coordinates1 = torch.zeros(5, 4, 3) coordinates1 = torch.zeros(5, 4, 3)
species2 = torch.LongTensor([3, 2, 0, 1, 0]) species2 = torch.LongTensor([3, 2, 0, 1, 0])
coordinates2 = torch.zeros(2, 5, 3) coordinates2 = torch.zeros(2, 5, 3)
species, coordinates = torchani.utils.pad_and_batch([ species, coordinates = torchani.utils.pad_coordinates([
(species1, coordinates1), (species1, coordinates1),
(species2, coordinates2), (species2, coordinates2),
]) ])
...@@ -62,7 +62,7 @@ class TestPadAndBatch(unittest.TestCase): ...@@ -62,7 +62,7 @@ class TestPadAndBatch(unittest.TestCase):
coordinates1 = torch.zeros(5, 4, 3) coordinates1 = torch.zeros(5, 4, 3)
species2 = torch.LongTensor([3, 2, 0, 1, 0]) species2 = torch.LongTensor([3, 2, 0, 1, 0])
coordinates2 = torch.zeros(2, 5, 3) coordinates2 = torch.zeros(2, 5, 3)
species, coordinates = torchani.utils.pad_and_batch([ species, coordinates = torchani.utils.pad_coordinates([
(species1, coordinates1), (species1, coordinates1),
(species2, coordinates2), (species2, coordinates2),
]) ])
...@@ -80,6 +80,29 @@ class TestPadAndBatch(unittest.TestCase): ...@@ -80,6 +80,29 @@ class TestPadAndBatch(unittest.TestCase):
self.assertEqual((species - expected_species).abs().max().item(), 0) self.assertEqual((species - expected_species).abs().max().item(), 0)
self.assertEqual(coordinates.abs().max().item(), 0) self.assertEqual(coordinates.abs().max().item(), 0)
def testPadSpecies(self):
    """Check that ``torchani.utils.pad`` pads and stacks species tensors.

    Two species tensors with 4 and 5 atoms per structure (5 and 2
    structures respectively) are padded to a common width of 5 with ``-1``
    and concatenated along the first dimension.
    """
    species1 = torch.LongTensor([
        [0, 2, 3, 1],
        [0, 2, 3, 1],
        [0, 2, 3, 1],
        [0, 2, 3, 1],
        [0, 2, 3, 1],
    ])
    species2 = torch.LongTensor([3, 2, 0, 1, 0]).expand(2, 5)
    species = torchani.utils.pad([species1, species2])
    # 5 + 2 structures, padded to the widest input (5 atoms).
    self.assertEqual(species.shape[0], 7)
    self.assertEqual(species.shape[1], 5)
    expected_species = torch.LongTensor([
        [0, 2, 3, 1, -1],
        [0, 2, 3, 1, -1],
        [0, 2, 3, 1, -1],
        [0, 2, 3, 1, -1],
        [0, 2, 3, 1, -1],
        [3, 2, 0, 1, 0],
        [3, 2, 0, 1, 0],
    ])
    # A maximum absolute difference of zero means an exact match.
    self.assertEqual((species - expected_species).abs().max().item(), 0)
def testPresentSpecies(self): def testPresentSpecies(self):
species = torch.LongTensor([0, 1, 1, 0, 3, 7, -1, -1]) species = torch.LongTensor([0, 1, 1, 0, 3, 7, -1, -1])
present_species = torchani.utils.present_species(species) present_species = torchani.utils.present_species(species)
...@@ -97,13 +120,13 @@ class TestStripRedundantPadding(unittest.TestCase): ...@@ -97,13 +120,13 @@ class TestStripRedundantPadding(unittest.TestCase):
coordinates1 = torch.randn(5, 4, 3) coordinates1 = torch.randn(5, 4, 3)
species2 = torch.randint(4, (2, 5), dtype=torch.long) species2 = torch.randint(4, (2, 5), dtype=torch.long)
coordinates2 = torch.randn(2, 5, 3) coordinates2 = torch.randn(2, 5, 3)
species12, coordinates12 = torchani.utils.pad_and_batch([ species12, coordinates12 = torchani.utils.pad_coordinates([
(species1, coordinates1), (species1, coordinates1),
(species2, coordinates2), (species2, coordinates2),
]) ])
species3 = torch.randint(4, (2, 10), dtype=torch.long) species3 = torch.randint(4, (2, 10), dtype=torch.long)
coordinates3 = torch.randn(2, 10, 3) coordinates3 = torch.randn(2, 10, 3)
species123, coordinates123 = torchani.utils.pad_and_batch([ species123, coordinates123 = torchani.utils.pad_coordinates([
(species1, coordinates1), (species1, coordinates1),
(species2, coordinates2), (species2, coordinates2),
(species3, coordinates3), (species3, coordinates3),
......
import torch
import ignite
import torchani
import timeit
import tqdm
import argparse
# Parse command line arguments: the AEV cache location (required) and the
# torch device to run on (optional).
parser = argparse.ArgumentParser()
parser.add_argument('cache_path',
help='Path of the aev cache')
parser.add_argument('-d', '--device',
help='Device of modules and tensors',
# Default to CUDA when torch reports it as available, otherwise CPU.
default=('cuda' if torch.cuda.is_available() else 'cpu'))
# NOTE: from here on `parser` is the parsed argument namespace, not the
# ArgumentParser; later code reads parser.device and parser.cache_path.
parser = parser.parse_args()
# Set up benchmark: resolve the device and pull the constants, AEV computer
# and energy shifter exposed by torchani.neurochem.Builtins.
device = torch.device(parser.device)
# NOTE(review): this module-level name shadows the stdlib `builtins` module.
builtins = torchani.neurochem.Builtins()
consts = builtins.consts
aev_computer = builtins.aev_computer
shift_energy = builtins.energy_shifter
def atomic():
    """Build one freshly initialised atomic network.

    The network is a stack of four linear layers (384 -> 128 -> 128 -> 64
    -> 1) with a ``CELU(0.1)`` activation after each of the first three.

    Returns:
        :class:`torch.nn.Sequential`: a new, independently initialised
        network on every call.
    """
    model = torch.nn.Sequential(
        torch.nn.Linear(384, 128),
        torch.nn.CELU(0.1),
        torch.nn.Linear(128, 128),
        torch.nn.CELU(0.1),
        torch.nn.Linear(128, 64),
        torch.nn.CELU(0.1),
        torch.nn.Linear(64, 1)
    )
    return model
# Four independently initialised atomic networks wrapped in an ANIModel
# (presumably one network per supported element type — confirm).
model = torchani.ANIModel([atomic() for _ in range(4)])
class Flatten(torch.nn.Module):
    """Pass the first element of a pair through and flatten the second.

    Intended to sit after a module that returns a two-element sequence; the
    second element must support ``.flatten()`` (e.g. a tensor).
    """

    def forward(self, x):
        """Return ``(x[0], x[1].flatten())``."""
        return x[0], x[1].flatten()
# Full network: the ANI model followed by Flatten, moved to the chosen device.
nnp = torch.nn.Sequential(model, Flatten()).to(device)
# Data source built from the cache path given on the command line
# (presumably yields precomputed AEV batches — confirm against AEVCacheLoader).
dataset = torchani.data.AEVCacheLoader(parser.cache_path)
# Container keyed by property name; 'energies' is produced by nnp.
container = torchani.ignite.Container({'energies': nnp})
# Adam over the network parameters with torch's default hyperparameters.
optimizer = torch.optim.Adam(nnp.parameters())
# Supervised ignite trainer using an MSE loss on the 'energies' property.
trainer = ignite.engine.create_supervised_trainer(
container, optimizer, torchani.ignite.MSELoss('energies'))
@trainer.on(ignite.engine.Events.EPOCH_STARTED)
def init_tqdm(trainer):
    """Create a progress bar sized to the dataset when an epoch starts.

    The bar is stored on ``trainer.state.tqdm`` so the other handlers can
    advance and close it.
    """
    trainer.state.tqdm = tqdm.tqdm(total=len(dataset), desc='epoch')
@trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
def update_tqdm(trainer):
    """Advance the progress bar on ``trainer.state.tqdm`` by one step."""
    trainer.state.tqdm.update(1)
@trainer.on(ignite.engine.Events.EPOCH_COMPLETED)
def finalize_tqdm(trainer):
    """Close the progress bar on ``trainer.state.tqdm`` when an epoch ends."""
    trainer.state.tqdm.close()
# Accumulated wall-clock seconds per timer key, filled in by the wrappers
# that time_func returns.
timers = {}


def time_func(key, func):
    """Wrap ``func`` so the time spent in it accumulates in ``timers[key]``.

    Calling this function resets ``timers[key]`` to 0. Every completed call
    of the returned wrapper adds its elapsed wall-clock time (measured with
    ``timeit.default_timer``) to that entry. A call that raises adds
    nothing.

    Arguments:
        key: dictionary key under which the cumulative time is stored.
        func: callable to be timed.

    Returns:
        A callable with the same signature and return value as ``func``.
    """
    # Local import keeps this block self-contained within the script.
    import functools

    timers[key] = 0

    # Preserve the wrapped callable's name and docstring on the wrapper.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = timeit.default_timer()
        ret = func(*args, **kwargs)
        end = timeit.default_timer()
        timers[key] += end - start
        return ret

    return wrapper
# Enable timers: replace the forward of the first stage of nnp with a timed
# wrapper so its cumulative runtime is recorded under timers['forward'].
nnp[0].forward = time_func('forward', nnp[0].forward)
# Run it! Measure the wall-clock time of one full epoch over the dataset.
start = timeit.default_timer()
trainer.run(dataset, max_epochs=1)
elapsed = round(timeit.default_timer() - start, 2)
# Report time spent inside the timed forward, then the whole-epoch time.
print('NN:', timers['forward'])
print('Epoch time:', elapsed)
# NOTE(review): explicit __del__ call; presumably releases resources held by
# the loader — confirm against AEVCacheLoader.
dataset.__del__()
...@@ -161,7 +161,7 @@ class BatchedANIDataset(Dataset): ...@@ -161,7 +161,7 @@ class BatchedANIDataset(Dataset):
for i in properties: for i in properties:
p = torch.from_numpy(m[i]).to(torch.double) p = torch.from_numpy(m[i]).to(torch.double)
properties[i].append(p) properties[i].append(p)
species, coordinates = utils.pad_and_batch(species_coordinates) species, coordinates = utils.pad_coordinates(species_coordinates)
for i in properties: for i in properties:
properties[i] = torch.cat(properties[i]) properties[i] = torch.cat(properties[i])
......
...@@ -22,22 +22,19 @@ class Container(torch.nn.ModuleDict): ...@@ -22,22 +22,19 @@ class Container(torch.nn.ModuleDict):
def __init__(self, modules): def __init__(self, modules):
super(Container, self).__init__(modules) super(Container, self).__init__(modules)
def forward(self, species_coordinates): def forward(self, species_x):
"""Takes sequence of species, coordinates pair as input, and returns """Takes sequence of species, coordinates pair as input, and returns
computed properties as a dictionary. Same property from different computed properties as a dictionary. Same property from different
chunks will be concatenated to form a single tensor for a batch. The chunks will be concatenated to form a single tensor for a batch.
input, i.e. species and coordinates of chunks, will also be batched by
:func:`torchani.utils.pad_and_batch` and copied to output.
""" """
results = {k: [] for k in self} results = {k: [] for k in self}
for sc in species_coordinates: for sx in species_x:
for k in self: for k in self:
_, result = self[k](sc) _, result = self[k](sx)
results[k].append(result) results[k].append(result)
for k in self: for k in self:
results[k] = torch.cat(results[k]) results[k] = torch.cat(results[k])
results['species'], results['coordinates'] = \ results['species'] = utils.pad([s for s, _ in species_x])
utils.pad_and_batch(species_coordinates)
return results return results
......
import torch import torch
def pad(species):
    """Put different species together into single tensor.

    If the species are from molecules of different number of total atoms, then
    ghost atoms with atom type -1 will be added to make it fit into the same
    shape.

    Arguments:
        species (:class:`collections.abc.Sequence`): sequence of species.
            Species must be of shape ``(N, A)``, where ``N`` is the number of
            3D structures, ``A`` is the number of atoms.

    Returns:
        :class:`torch.Tensor`: species batched together. The inputs are
        concatenated along the first dimension after being padded to the
        largest ``A`` found in ``species``.
    """
    max_atoms = max(s.shape[1] for s in species)
    padded_species = []
    for s in species:
        natoms = s.shape[1]
        if natoms < max_atoms:
            # Fill the missing atom slots with the ghost atom type -1.
            padding = torch.full((s.shape[0], max_atoms - natoms), -1,
                                 dtype=torch.long, device=s.device)
            s = torch.cat([s, padding], dim=1)
        padded_species.append(s)
    return torch.cat(padded_species)
def pad_coordinates(species_coordinates):
"""Put different species and coordinates together into single tensor. """Put different species and coordinates together into single tensor.
If the species and coordinates are from molecules of different number of If the species and coordinates are from molecules of different number of
...@@ -124,4 +151,5 @@ class EnergyShifter(torch.nn.Module): ...@@ -124,4 +151,5 @@ class EnergyShifter(torch.nn.Module):
return species, energies + sae return species, energies + sae
__all__ = ['pad_and_batch', 'present_species', 'strip_redundant_padding'] __all__ = ['pad', 'pad_coordinates', 'present_species',
'strip_redundant_padding']
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment