Commit ffb075e6 authored by Gao, Xiang, committed by Farhad Ramezanghorbani

Simplify triple_by_molecule (#368)

* Simplify triple_by_molecule

* fix

* fix

* fix
parent 89ff3b46
@@ -43,7 +43,6 @@ def enable_timers(model):
     torchani.aev.compute_shifts = time_func('compute_shifts', torchani.aev.compute_shifts)
     torchani.aev.neighbor_pairs = time_func('neighbor_pairs', torchani.aev.neighbor_pairs)
     torchani.aev.triu_index = time_func('triu_index', torchani.aev.triu_index)
-    torchani.aev.convert_pair_index = time_func('convert_pair_index', torchani.aev.convert_pair_index)
     torchani.aev.cumsum_from_zero = time_func('cumsum_from_zero', torchani.aev.cumsum_from_zero)
     torchani.aev.triple_by_molecule = time_func('triple_by_molecule', torchani.aev.triple_by_molecule)
     torchani.aev.compute_aev = time_func('compute_aev', torchani.aev.compute_aev)
import torch
import ignite
import torchani
import timeit
import tqdm
import argparse

# parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('cache_path',
                    help='Path of the aev cache')
parser.add_argument('-d', '--device',
                    help='Device of modules and tensors',
                    default=('cuda' if torch.cuda.is_available() else 'cpu'))
parser = parser.parse_args()  # the parsed namespace replaces the parser object

# set up benchmark
device = torch.device(parser.device)
ani1x = torchani.models.ANI1x()
consts = ani1x.consts
aev_computer = ani1x.aev_computer
shift_energy = ani1x.energy_shifter


def atomic():
    # per-species atomic network: 384 AEV features -> scalar atomic energy
    model = torch.nn.Sequential(
        torch.nn.Linear(384, 128),
        torch.nn.CELU(0.1),
        torch.nn.Linear(128, 128),
        torch.nn.CELU(0.1),
        torch.nn.Linear(128, 64),
        torch.nn.CELU(0.1),
        torch.nn.Linear(64, 1)
    )
    return model


model = torchani.ANIModel([atomic() for _ in range(4)])


class Flatten(torch.nn.Module):
    # flatten the energies in the (species, energies) output tuple
    def forward(self, x):
        return x[0], x[1].flatten()


nnp = torch.nn.Sequential(model, Flatten()).to(device)
dataset = torchani.data.AEVCacheLoader(parser.cache_path)
container = torchani.ignite.Container({'energies': nnp})
optimizer = torch.optim.Adam(nnp.parameters())
trainer = ignite.engine.create_supervised_trainer(
    container, optimizer, torchani.ignite.MSELoss('energies'))


# progress bar via ignite event handlers
@trainer.on(ignite.engine.Events.EPOCH_STARTED)
def init_tqdm(trainer):
    trainer.state.tqdm = tqdm.tqdm(total=len(dataset), desc='epoch')


@trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
def update_tqdm(trainer):
    trainer.state.tqdm.update(1)


@trainer.on(ignite.engine.Events.EPOCH_COMPLETED)
def finalize_tqdm(trainer):
    trainer.state.tqdm.close()


# simple wall-clock timers, attached by monkey-patching
timers = {}


def time_func(key, func):
    timers[key] = 0

    def wrapper(*args, **kwargs):
        start = timeit.default_timer()
        ret = func(*args, **kwargs)
        end = timeit.default_timer()
        timers[key] += end - start
        return ret

    return wrapper


# enable timers
nnp[0].forward = time_func('forward', nnp[0].forward)

# run it!
start = timeit.default_timer()
trainer.run(dataset, max_epochs=1)
elapsed = round(timeit.default_timer() - start, 2)
print('NN:', timers['forward'])
print('Epoch time:', elapsed)
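The timing here works purely by monkey-patching: time_func registers a key in the timers dict and returns a wrapper that adds each call's wall-clock duration under that key, and the wrapper is then assigned back over the original attribute (here nnp[0].forward; in the hunks above, functions in torchani.aev). A minimal standalone sketch of the same pattern, using a hypothetical slow_sum workload instead of the torchani internals:

import timeit

timers = {}

def time_func(key, func):
    # register the key and return a wrapper that accumulates elapsed time
    timers[key] = 0

    def wrapper(*args, **kwargs):
        start = timeit.default_timer()
        ret = func(*args, **kwargs)
        timers[key] += timeit.default_timer() - start
        return ret

    return wrapper

def slow_sum(n):
    # hypothetical workload standing in for e.g. an AEV computation
    return sum(range(n))

slow_sum = time_func('slow_sum', slow_sum)
for _ in range(10):
    slow_sum(1_000_000)
print('slow_sum total seconds:', timers['slow_sum'])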
@@ -93,7 +93,6 @@ if __name__ == "__main__":
     torchani.aev.compute_shifts = time_func('torchani.aev.compute_shifts', torchani.aev.compute_shifts)
     torchani.aev.neighbor_pairs = time_func('torchani.aev.neighbor_pairs', torchani.aev.neighbor_pairs)
     torchani.aev.triu_index = time_func('torchani.aev.triu_index', torchani.aev.triu_index)
-    torchani.aev.convert_pair_index = time_func('torchani.aev.convert_pair_index', torchani.aev.convert_pair_index)
     torchani.aev.cumsum_from_zero = time_func('torchani.aev.cumsum_from_zero', torchani.aev.cumsum_from_zero)
     torchani.aev.triple_by_molecule = time_func('torchani.aev.triple_by_molecule', torchani.aev.triple_by_molecule)
     torchani.aev.compute_aev = time_func('torchani.aev.compute_aev', torchani.aev.compute_aev)
@@ -174,31 +174,6 @@ def triu_index(num_species):
     return ret
 
 
-def convert_pair_index(index):
-    # type: (torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
-    """Suppose we enumerate pairs as follows:
-    index: 0 1 2 3 4 5 6 7 8 9 ...
-    elem1: 0 0 1 0 1 2 0 1 2 3 ...
-    elem2: 1 2 2 3 3 3 4 4 4 4 ...
-    This function converts an index back to elem1 and elem2.
-    To implement this, divide the indices into groups: the first group
-    contains 1 element, the second contains 2 elements, ..., the nth group
-    contains n elements.
-    Say we want to compute elem1 and elem2 for index i. We first find the
-    number of complete groups contained in indices 0, 1, ..., i - 1
-    (not including i); then i falls in the next group. If there are N
-    complete groups, these N groups contain N * (N + 1) / 2 elements, so
-    solving for the largest N satisfying N * (N + 1) / 2 <= i gives the N
-    we want.
-    """
-    n = (torch.sqrt(1.0 + 8.0 * index.to(torch.float)) - 1.0) / 2.0
-    n = torch.floor(n).to(torch.long)
-    num_elems = n * (n + 1) / 2
-    return index - num_elems, n + 1
-
-
 def cumsum_from_zero(input_):
     # type: (torch.Tensor) -> torch.Tensor
     cumsum = torch.cumsum(input_, dim=0)
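For context on what is being deleted: convert_pair_index inverts the triangular enumeration described in its docstring in closed form, by solving N * (N + 1) / 2 <= i for the largest integer N. The small standalone check below (not part of the commit) exercises that formula against the table from the docstring; the only change is // 2 instead of / 2 so the result stays an integer tensor on current PyTorch, where / on integer tensors returns floats.

import torch

def convert_pair_index(index):
    # N = number of complete groups before index; elem2 = N + 1,
    # elem1 = offset of index inside its group
    n = (torch.sqrt(1.0 + 8.0 * index.to(torch.float)) - 1.0) / 2.0
    n = torch.floor(n).to(torch.long)
    num_elems = n * (n + 1) // 2
    return index - num_elems, n + 1

elem1, elem2 = convert_pair_index(torch.arange(10))
assert elem1.tolist() == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]
assert elem2.tolist() == [1, 2, 2, 3, 3, 3, 4, 4, 4, 4]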
@@ -219,7 +194,6 @@ def triple_by_molecule(atom_index1, atom_index2):
     are (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)
     """
     # convert representation from pair to central-others
-    n = atom_index1.shape[0]
     ai1 = torch.cat([atom_index1, atom_index2])
     sorted_ai1, rev_indices = ai1.sort()
 
@@ -228,17 +202,18 @@ def triple_by_molecule(atom_index1, atom_index2):
     uniqued_central_atom_index = unique_results[0]
     counts = unique_results[-1]
 
-    # do local combinations within unique key, assuming sorted
+    # compute central_atom_index
     pair_sizes = (counts * (counts - 1) / 2).long()
-    total_size = pair_sizes.sum()
     pair_indices = torch.repeat_interleave(pair_sizes)
     central_atom_index = uniqued_central_atom_index.index_select(0, pair_indices)
-    cumsum = cumsum_from_zero(pair_sizes)
-    cumsum = cumsum.index_select(0, pair_indices)
-    sorted_local_pair_index = torch.arange(total_size, device=cumsum.device, dtype=torch.long) - cumsum
-    sorted_local_index1, sorted_local_index2 = convert_pair_index(sorted_local_pair_index)
-    cumsum = cumsum_from_zero(counts)
-    cumsum = cumsum.index_select(0, pair_indices)
+
+    # do local combinations within unique key, assuming sorted
+    m = counts.max().item() if counts.numel() > 0 else 0
+    n = pair_sizes.shape[0]
+    intra_pair_indices = torch.tril_indices(m, m, -1, device=ai1.device).t().unsqueeze(0).expand(n, -1, -1)
+    mask = (torch.arange(intra_pair_indices.shape[1], device=ai1.device) < pair_sizes.unsqueeze(1)).flatten()
+    sorted_local_index1, sorted_local_index2 = intra_pair_indices.flatten(0, 1)[mask, :].unbind(-1)
+    cumsum = cumsum_from_zero(counts).index_select(0, pair_indices)
     sorted_local_index1 += cumsum
     sorted_local_index2 += cumsum
 
@@ -247,6 +222,7 @@ def triple_by_molecule(atom_index1, atom_index2):
     local_index2 = rev_indices[sorted_local_index2]
 
     # compute mapping between representation of central-other to pair
+    n = atom_index1.shape[0]
     sign1 = ((local_index1 < n).to(torch.long) * 2) - 1
     sign2 = ((local_index2 < n).to(torch.long) * 2) - 1
     return central_atom_index, local_index1 % n, local_index2 % n, sign1, sign2
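The replacement avoids that index arithmetic entirely: for each central atom with c neighbors there are c * (c - 1) / 2 neighbor pairs, and because torch.tril_indices enumerates the strict lower triangle row by row, the first c * (c - 1) / 2 entries of the enumeration for the largest group are exactly the pairs among the first c neighbors. The new code therefore builds one enumeration of size max(counts), broadcasts it across groups, and masks each group down to its own pair count. A standalone sketch of that trick on a toy counts tensor (not part of the commit, values traced by hand):

import torch

# toy neighbor counts for three central atoms
counts = torch.tensor([3, 2, 4])
pair_sizes = counts * (counts - 1) // 2  # pairs per central atom: 3, 1, 6

m = int(counts.max())   # size of the largest group
n = pair_sizes.shape[0]
# strict lower-triangle pairs of the largest group, one copy per group
intra_pair_indices = torch.tril_indices(m, m, -1).t().unsqueeze(0).expand(n, -1, -1)
# keep only the first pair_sizes[i] pairs of group i
mask = (torch.arange(intra_pair_indices.shape[1]) < pair_sizes.unsqueeze(1)).flatten()
local1, local2 = intra_pair_indices.flatten(0, 1)[mask, :].unbind(-1)

print(local1.tolist())  # [1, 2, 2, 1, 1, 2, 2, 3, 3, 3]
print(local2.tolist())  # [0, 0, 1, 0, 0, 0, 1, 0, 1, 2]

In the real function these local indices are then shifted by cumsum_from_zero(counts) so they address positions in the sorted pair list rather than positions within each group.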