Unverified Commit 6e50a99e authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

benchmark each function (#201)

parent 6f2a3d5f
...@@ -92,6 +92,16 @@ def time_func(key, func): ...@@ -92,6 +92,16 @@ def time_func(key, func):
# enable timers # enable timers
torchani.aev.cutoff_cosine = time_func('torchani.aev.cutoff_cosine', torchani.aev.cutoff_cosine)
torchani.aev.radial_terms = time_func('torchani.aev.radial_terms', torchani.aev.radial_terms)
torchani.aev.angular_terms = time_func('torchani.aev.angular_terms', torchani.aev.angular_terms)
torchani.aev.compute_shifts = time_func('torchani.aev.compute_shifts', torchani.aev.compute_shifts)
torchani.aev.neighbor_pairs = time_func('torchani.aev.neighbor_pairs', torchani.aev.neighbor_pairs)
torchani.aev.triu_index = time_func('torchani.aev.triu_index', torchani.aev.triu_index)
torchani.aev.convert_pair_index = time_func('torchani.aev.convert_pair_index', torchani.aev.convert_pair_index)
torchani.aev.cumsum_from_zero = time_func('torchani.aev.cumsum_from_zero', torchani.aev.cumsum_from_zero)
torchani.aev.triple_by_molecule = time_func('torchani.aev.triple_by_molecule', torchani.aev.triple_by_molecule)
torchani.aev.compute_aev = time_func('torchani.aev.compute_aev', torchani.aev.compute_aev)
nnp[0].forward = time_func('total', nnp[0].forward) nnp[0].forward = time_func('total', nnp[0].forward)
nnp[1].forward = time_func('forward', nnp[1].forward) nnp[1].forward = time_func('forward', nnp[1].forward)
...@@ -99,6 +109,9 @@ nnp[1].forward = time_func('forward', nnp[1].forward) ...@@ -99,6 +109,9 @@ nnp[1].forward = time_func('forward', nnp[1].forward)
start = timeit.default_timer() start = timeit.default_timer()
trainer.run(dataset, max_epochs=1) trainer.run(dataset, max_epochs=1)
elapsed = round(timeit.default_timer() - start, 2) elapsed = round(timeit.default_timer() - start, 2)
for k in timers:
if k.startswith('torchani.'):
print(k, timers[k])
print('Total AEV:', timers['total']) print('Total AEV:', timers['total'])
print('NN:', timers['forward']) print('NN:', timers['forward'])
print('Epoch time:', elapsed) print('Epoch time:', elapsed)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment