"src/libtorchaudio/sox/utils.h" did not exist on "ec13a815b13ec6be3eeb8c3eb9ccb725dc322233"
Unverified commit 51cab350, authored by Gao, Xiang, committed by GitHub

Make Container able to handle aev cache, add benchmark for aev cache (#90)

parent b546adb8
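
In short, the commit splits the old torchani.utils.pad_and_batch into two helpers: pad, which batches species tensors alone (needed when AEVs come from a cache and no coordinates are available), and pad_coordinates, the renamed original that batches (species, coordinates) pairs. A minimal sketch of the resulting API, with shapes taken from the tests in this diff:

    import torch
    import torchani

    # Two batches of conformations with different atom counts; -1 marks ghost atoms.
    species1 = torch.LongTensor([[0, 2, 3, 1]])         # (1 conformation, 4 atoms)
    species2 = torch.LongTensor([[3, 2, 0, 1, 0]])      # (1 conformation, 5 atoms)

    # New in this commit: batch species alone, e.g. when AEVs are precomputed.
    species = torchani.utils.pad([species1, species2])  # shape (2, 5), -1 padded

    # Renamed from pad_and_batch: batch (species, coordinates) pairs together.
    coordinates1 = torch.zeros(1, 4, 3)
    coordinates2 = torch.zeros(1, 5, 3)
    species, coordinates = torchani.utils.pad_coordinates([
        (species1, coordinates1),
        (species2, coordinates2),
    ])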
@@ -22,4 +22,4 @@ benchmark_xyz
 /*.dat
 /tmp
-*_cache
+datacache
\ No newline at end of file
@@ -23,7 +23,8 @@ Utilities
 =========
 .. automodule:: torchani.utils
-.. autofunction:: torchani.utils.pad_and_batch
+.. autofunction:: torchani.utils.pad
+.. autofunction:: torchani.utils.pad_coordinates
 .. autofunction:: torchani.utils.present_species
 .. autofunction:: torchani.utils.strip_redundant_padding
......
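
The two existing helpers listed above operate on the same padded representation. Judging from the tests in this commit, present_species returns the distinct atom types occurring in a padded species tensor, with the -1 ghost atoms ignored. A quick sketch (expected output inferred from the tests, not verified against the library):

    import torch
    import torchani

    species = torch.LongTensor([0, 1, 1, 0, 3, 7, -1, -1])
    # Expected, based on testPresentSpecies below: tensor([0, 1, 3, 7])
    print(torchani.utils.present_species(species))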
@@ -44,7 +44,7 @@ class TestAEV(unittest.TestCase):
                 coordinates, species, radial, angular, _, _ = pickle.load(f)
                 species_coordinates.append((species, coordinates))
                 radial_angular.append((radial, angular))
-        species, coordinates = torchani.utils.pad_and_batch(
+        species, coordinates = torchani.utils.pad_coordinates(
             species_coordinates)
         _, aev = self.aev_computer((species, coordinates))
         start = 0
......
@@ -30,7 +30,7 @@ class TestData(unittest.TestCase):
         coordinates2 = torch.randn(2, 8, 3)
         species3 = torch.randint(4, (10, 20), dtype=torch.long)
         coordinates3 = torch.randn(10, 20, 3)
-        species, coordinates = torchani.utils.pad_and_batch([
+        species, coordinates = torchani.utils.pad_coordinates([
            (species1, coordinates1),
            (species2, coordinates2),
            (species3, coordinates3),
@@ -52,14 +52,14 @@ class TestData(unittest.TestCase):
             self._assertTensorEqual(c, c_)
             start += conformations
-        s, c = torchani.utils.pad_and_batch(chunks)
+        s, c = torchani.utils.pad_coordinates(chunks)
         self._assertTensorEqual(s, species)
         self._assertTensorEqual(c, coordinates)

     def testTensorShape(self):
         for i in self.ds:
             input, output = i
-            species, coordinates = torchani.utils.pad_and_batch(input)
+            species, coordinates = torchani.utils.pad_coordinates(input)
             energies = output['energies']
             self.assertEqual(len(species.shape), 2)
             self.assertLessEqual(species.shape[0], batch_size)
......
@@ -37,7 +37,7 @@ class TestEnergies(unittest.TestCase):
                 coordinates, species, _, _, e, _ = pickle.load(f)
                 species_coordinates.append((species, coordinates))
                 energies.append(e)
-        species, coordinates = torchani.utils.pad_and_batch(
+        species, coordinates = torchani.utils.pad_coordinates(
             species_coordinates)
         energies = torch.cat(energies)
         _, energies_ = self.model((species, coordinates))
......
@@ -39,7 +39,7 @@ class TestForce(unittest.TestCase):
                 coordinates = torch.tensor(coordinates, requires_grad=True)
                 species_coordinates.append((species, coordinates))
                 coordinates_forces.append((coordinates, forces))
-        species, coordinates = torchani.utils.pad_and_batch(
+        species, coordinates = torchani.utils.pad_coordinates(
             species_coordinates)
         _, energies = self.model((species, coordinates))
         energies = energies.sum()
......
@@ -3,14 +3,14 @@ import torch
 import torchani


-class TestPadAndBatch(unittest.TestCase):
+class TestPaddings(unittest.TestCase):

     def testVectorSpecies(self):
         species1 = torch.LongTensor([0, 2, 3, 1])
         coordinates1 = torch.zeros(5, 4, 3)
         species2 = torch.LongTensor([3, 2, 0, 1, 0])
         coordinates2 = torch.zeros(2, 5, 3)
-        species, coordinates = torchani.utils.pad_and_batch([
+        species, coordinates = torchani.utils.pad_coordinates([
             (species1, coordinates1),
             (species2, coordinates2),
         ])
@@ -33,7 +33,7 @@ class TestPadAndBatch(unittest.TestCase):
         coordinates1 = torch.zeros(5, 4, 3)
         species2 = torch.LongTensor([3, 2, 0, 1, 0])
         coordinates2 = torch.zeros(2, 5, 3)
-        species, coordinates = torchani.utils.pad_and_batch([
+        species, coordinates = torchani.utils.pad_coordinates([
             (species1, coordinates1),
             (species2, coordinates2),
         ])
@@ -62,7 +62,7 @@ class TestPadAndBatch(unittest.TestCase):
         coordinates1 = torch.zeros(5, 4, 3)
         species2 = torch.LongTensor([3, 2, 0, 1, 0])
         coordinates2 = torch.zeros(2, 5, 3)
-        species, coordinates = torchani.utils.pad_and_batch([
+        species, coordinates = torchani.utils.pad_coordinates([
             (species1, coordinates1),
             (species2, coordinates2),
         ])
@@ -80,6 +80,29 @@ class TestPadAndBatch(unittest.TestCase):
         self.assertEqual((species - expected_species).abs().max().item(), 0)
         self.assertEqual(coordinates.abs().max().item(), 0)

+    def testPadSpecies(self):
+        species1 = torch.LongTensor([
+            [0, 2, 3, 1],
+            [0, 2, 3, 1],
+            [0, 2, 3, 1],
+            [0, 2, 3, 1],
+            [0, 2, 3, 1],
+        ])
+        species2 = torch.LongTensor([3, 2, 0, 1, 0]).expand(2, 5)
+        species = torchani.utils.pad([species1, species2])
+        self.assertEqual(species.shape[0], 7)
+        self.assertEqual(species.shape[1], 5)
+        expected_species = torch.LongTensor([
+            [0, 2, 3, 1, -1],
+            [0, 2, 3, 1, -1],
+            [0, 2, 3, 1, -1],
+            [0, 2, 3, 1, -1],
+            [0, 2, 3, 1, -1],
+            [3, 2, 0, 1, 0],
+            [3, 2, 0, 1, 0],
+        ])
+        self.assertEqual((species - expected_species).abs().max().item(), 0)
+
     def testPresentSpecies(self):
         species = torch.LongTensor([0, 1, 1, 0, 3, 7, -1, -1])
         present_species = torchani.utils.present_species(species)
@@ -97,13 +120,13 @@ class TestStripRedundantPadding(unittest.TestCase):
         coordinates1 = torch.randn(5, 4, 3)
         species2 = torch.randint(4, (2, 5), dtype=torch.long)
         coordinates2 = torch.randn(2, 5, 3)
-        species12, coordinates12 = torchani.utils.pad_and_batch([
+        species12, coordinates12 = torchani.utils.pad_coordinates([
             (species1, coordinates1),
             (species2, coordinates2),
         ])
         species3 = torch.randint(4, (2, 10), dtype=torch.long)
         coordinates3 = torch.randn(2, 10, 3)
-        species123, coordinates123 = torchani.utils.pad_and_batch([
+        species123, coordinates123 = torchani.utils.pad_coordinates([
             (species1, coordinates1),
             (species2, coordinates2),
             (species3, coordinates3),
......
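
The commit also adds a benchmark script (its file name is not visible in this view) that trains for one epoch directly from a precomputed AEV cache, timing the neural-network forward pass separately from the total epoch time: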
import torch
import ignite
import torchani
import timeit
import tqdm
import argparse

# parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('cache_path',
                    help='Path of the aev cache')
parser.add_argument('-d', '--device',
                    help='Device of modules and tensors',
                    default=('cuda' if torch.cuda.is_available() else 'cpu'))
parser = parser.parse_args()

# set up benchmark
device = torch.device(parser.device)
builtins = torchani.neurochem.Builtins()
consts = builtins.consts
aev_computer = builtins.aev_computer
shift_energy = builtins.energy_shifter


def atomic():
    model = torch.nn.Sequential(
        torch.nn.Linear(384, 128),
        torch.nn.CELU(0.1),
        torch.nn.Linear(128, 128),
        torch.nn.CELU(0.1),
        torch.nn.Linear(128, 64),
        torch.nn.CELU(0.1),
        torch.nn.Linear(64, 1)
    )
    return model


model = torchani.ANIModel([atomic() for _ in range(4)])


class Flatten(torch.nn.Module):
    def forward(self, x):
        return x[0], x[1].flatten()


nnp = torch.nn.Sequential(model, Flatten()).to(device)

dataset = torchani.data.AEVCacheLoader(parser.cache_path)
container = torchani.ignite.Container({'energies': nnp})
optimizer = torch.optim.Adam(nnp.parameters())
trainer = ignite.engine.create_supervised_trainer(
    container, optimizer, torchani.ignite.MSELoss('energies'))


@trainer.on(ignite.engine.Events.EPOCH_STARTED)
def init_tqdm(trainer):
    trainer.state.tqdm = tqdm.tqdm(total=len(dataset), desc='epoch')


@trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
def update_tqdm(trainer):
    trainer.state.tqdm.update(1)


@trainer.on(ignite.engine.Events.EPOCH_COMPLETED)
def finalize_tqdm(trainer):
    trainer.state.tqdm.close()


timers = {}


def time_func(key, func):
    timers[key] = 0

    def wrapper(*args, **kwargs):
        start = timeit.default_timer()
        ret = func(*args, **kwargs)
        end = timeit.default_timer()
        timers[key] += end - start
        return ret

    return wrapper


# enable timers
nnp[0].forward = time_func('forward', nnp[0].forward)

# run it!
start = timeit.default_timer()
trainer.run(dataset, max_epochs=1)
elapsed = round(timeit.default_timer() - start, 2)
print('NN:', timers['forward'])
print('Epoch time:', elapsed)
dataset.__del__()
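
Assuming the script were saved as aev_benchmark.py (a hypothetical name; the actual path is not shown in this view) and a cache had already been generated, it would be invoked along the lines of python aev_benchmark.py /path/to/datacache -d cuda. Note that time_func monkey-patches nnp[0].forward, so the 'NN:' figure counts only time spent inside the ANIModel forward pass; the gap between it and the epoch time is dominated by loading precomputed AEVs from the cache, which is the cost the cache approach trades for skipping AEV computation.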
@@ -161,7 +161,7 @@ class BatchedANIDataset(Dataset):
                 for i in properties:
                     p = torch.from_numpy(m[i]).to(torch.double)
                     properties[i].append(p)
-            species, coordinates = utils.pad_and_batch(species_coordinates)
+            species, coordinates = utils.pad_coordinates(species_coordinates)
             for i in properties:
                 properties[i] = torch.cat(properties[i])
......
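
Here BatchedANIDataset simply picks up the rename: each batch's (species, coordinates) chunks are still padded and batched together, now via pad_coordinates, while per-molecule properties such as energies are concatenated unchanged.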
@@ -22,22 +22,19 @@ class Container(torch.nn.ModuleDict):
     def __init__(self, modules):
         super(Container, self).__init__(modules)

-    def forward(self, species_coordinates):
+    def forward(self, species_x):
         """Takes sequence of species, coordinates pair as input, and returns
         computed properties as a dictionary. Same property from different
-        chunks will be concatenated to form a single tensor for a batch. The
-        input, i.e. species and coordinates of chunks, will also be batched by
-        :func:`torchani.utils.pad_and_batch` and copied to output.
+        chunks will be concatenated to form a single tensor for a batch.
         """
         results = {k: [] for k in self}
-        for sc in species_coordinates:
+        for sx in species_x:
             for k in self:
-                _, result = self[k](sc)
+                _, result = self[k](sx)
                 results[k].append(result)
         for k in self:
             results[k] = torch.cat(results[k])
-        results['species'], results['coordinates'] = \
-            utils.pad_and_batch(species_coordinates)
+        results['species'] = utils.pad([s for s, _ in species_x])
         return results
......
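
The rename from species_coordinates to species_x reflects the new flexibility: each chunk may now be a (species, coordinates) pair fed to a full model that computes AEVs on the fly, or a (species, aev) pair coming from an AEVCacheLoader, as in the benchmark above. Since coordinates are no longer guaranteed to exist, only the padded species are copied to the output. A minimal sketch of the cached call pattern, assuming the torchani version from this commit; ToyModel is a made-up stand-in for a real network:

    import torch
    import torchani

    # Made-up stand-in for the benchmark's nnp: like ANIModel, it must accept a
    # (species, x) pair and return (species, per-molecule output).
    class ToyModel(torch.nn.Module):
        def forward(self, species_x):
            species, aev = species_x
            return species, aev.sum(dim=(1, 2))  # one scalar per molecule

    container = torchani.ignite.Container({'energies': ToyModel()})

    # Chunks as an AEV cache would supply them: (species, aev) instead of
    # (species, coordinates). Shapes are illustrative; 384 matches the built-in AEV.
    chunk1 = (torch.LongTensor([[0, 2, 3, 1]]), torch.randn(1, 4, 384))
    chunk2 = (torch.LongTensor([[3, 2, 0, 1, 0]]), torch.randn(1, 5, 384))

    results = container([chunk1, chunk2])
    # results['energies']: predictions from both chunks, concatenated;
    # results['species']: utils.pad applied to the species of every chunk.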
 import torch


-def pad_and_batch(species_coordinates):
+def pad(species):
+    """Put different species together into single tensor.
+
+    If the species are from molecules of different number of total atoms, then
+    ghost atoms with atom type -1 will be added to make it fit into the same
+    shape.
+
+    Arguments:
+        species (:class:`collections.abc.Sequence`): sequence of species.
+            Species must be of shape ``(N, A)``, where ``N`` is the number of
+            3D structures, ``A`` is the number of atoms.
+
+    Returns:
+        :class:`torch.Tensor`: species batched together.
+    """
+    max_atoms = max([s.shape[1] for s in species])
+    padded_species = []
+    for s in species:
+        natoms = s.shape[1]
+        if natoms < max_atoms:
+            padding = torch.full((s.shape[0], max_atoms - natoms), -1,
+                                 dtype=torch.long, device=s.device)
+            s = torch.cat([s, padding], dim=1)
+        padded_species.append(s)
+    return torch.cat(padded_species)
+
+
+def pad_coordinates(species_coordinates):
     """Put different species and coordinates together into single tensor.

     If the species and coordinates are from molecules of different number of
@@ -124,4 +151,5 @@ class EnergyShifter(torch.nn.Module):
         return species, energies + sae


-__all__ = ['pad_and_batch', 'present_species', 'strip_redundant_padding']
+__all__ = ['pad', 'pad_coordinates', 'present_species',
+           'strip_redundant_padding']
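
Taken together, the helpers compose around the -1 ghost-atom convention. A final sketch, with expected values inferred from the tests above (present_species is assumed to also accept the 2-D padded tensor by flattening it):

    import torch
    import torchani

    species = torchani.utils.pad([
        torch.LongTensor([[0, 2, 3, 1]]),
        torch.LongTensor([[3, 2, 0, 1, 0]]),
    ])
    # species should now be
    #   [[0, 2, 3, 1, -1],
    #    [3, 2, 0, 1,  0]]
    # and present_species should ignore the -1 entries:
    print(torchani.utils.present_species(species))  # expected: tensor([0, 1, 2, 3])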