Unverified Commit d47f39c0 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Add sync before functions and print a warning if not synching (#17) (#542)



* Make training benchmark print a warning if not synchronizing, and add a sync before function also

* Take triu_index out since it does not run

* slightly increase default precision

* flake8
Co-authored-by: default avatarIgnacio Pickering <ign.pickering@gmail.com>
parent 86ac402c
......@@ -5,9 +5,6 @@ import timeit
import argparse
import pkbar
from torchani.units import hartree2kcalmol
synchronize = False
H_network = torch.nn.Sequential(
torch.nn.Linear(384, 160),
torch.nn.CELU(0.1),
......@@ -53,6 +50,8 @@ def time_func(key, func):
timers[key] = 0
def wrapper(*args, **kwargs):
if synchronize:
torch.cuda.synchronize()
start = timeit.default_timer()
ret = func(*args, **kwargs)
if synchronize:
......@@ -64,6 +63,13 @@ def time_func(key, func):
return wrapper
def time_functions_in_module(module, function_names_list):
# Wrap all the functions from "function_names_list" from the module
# "module" with a timer
for n in function_names_list:
setattr(module, n, time_func(f'{module.__name__}.{n}', getattr(module, n)))
if __name__ == "__main__":
# parse command line arguments
parser = argparse.ArgumentParser()
......@@ -78,7 +84,7 @@ if __name__ == "__main__":
default=2560, type=int)
parser.add_argument('-y', '--synchronize',
action='store_true',
help='whether to insert torch.cuda.synchronize() at the end of each function')
help='whether to insert torch.cuda.synchronize() at the start and end of each function')
parser.add_argument('-n', '--num_epochs',
help='epochs',
default=1, type=int)
......@@ -86,6 +92,13 @@ if __name__ == "__main__":
if parser.synchronize:
synchronize = True
else:
synchronize = False
print('WARNING: Synchronization creates some small overhead but if CUDA'
' streams are not synchronized the timings before and after a'
' function do not reflect the actual calculation load that'
' function is performing. Only run this benchmark without'
' synchronization if you know very well what you are doing')
Rcr = 5.2000e+00
Rca = 3.5000e+00
......@@ -105,22 +118,19 @@ if __name__ == "__main__":
timers = {}
# enable timers
torchani.aev.cutoff_cosine = time_func('torchani.aev.cutoff_cosine', torchani.aev.cutoff_cosine)
torchani.aev.radial_terms = time_func('torchani.aev.radial_terms', torchani.aev.radial_terms)
torchani.aev.angular_terms = time_func('torchani.aev.angular_terms', torchani.aev.angular_terms)
torchani.aev.compute_shifts = time_func('torchani.aev.compute_shifts', torchani.aev.compute_shifts)
torchani.aev.neighbor_pairs = time_func('torchani.aev.neighbor_pairs', torchani.aev.neighbor_pairs)
torchani.aev.neighbor_pairs_nopbc = time_func('torchani.aev.neighbor_pairs_nopbc', torchani.aev.neighbor_pairs_nopbc)
torchani.aev.triu_index = time_func('torchani.aev.triu_index', torchani.aev.triu_index)
torchani.aev.cumsum_from_zero = time_func('torchani.aev.cumsum_from_zero', torchani.aev.cumsum_from_zero)
torchani.aev.triple_by_molecule = time_func('torchani.aev.triple_by_molecule', torchani.aev.triple_by_molecule)
torchani.aev.compute_aev = time_func('torchani.aev.compute_aev', torchani.aev.compute_aev)
functions_to_time = ['cutoff_cosine', 'radial_terms', 'angular_terms',
'compute_shifts', 'neighbor_pairs',
'neighbor_pairs_nopbc', 'cumsum_from_zero',
'triple_by_molecule', 'compute_aev']
time_functions_in_module(torchani.aev, functions_to_time)
model[0].forward = time_func('total', model[0].forward)
model[1].forward = time_func('forward', model[1].forward)
print('=> loading dataset...')
shifter = torchani.EnergyShifter(None)
dataset = list(torchani.data.load(parser.dataset_path).subtract_self_energies(shifter).species_to_indices().shuffle().collate(parser.batch_size))
dataset = torchani.data.load(parser.dataset_path).subtract_self_energies(shifter).species_to_indices().shuffle().collate(parser.batch_size).cache()
print('=> start training')
start = time.time()
......@@ -149,7 +159,7 @@ if __name__ == "__main__":
print('=> more detail about benchmark')
for k in timers:
if k.startswith('torchani.'):
print('{} - {:.1f}s'.format(k, timers[k]))
print('Total AEV - {:.1f}s'.format(timers['total']))
print('NN - {:.1f}s'.format(timers['forward']))
print('Epoch time - {:.1f}s'.format(stop - start))
print('{} - {:.2f}s'.format(k, timers[k]))
print('Total AEV - {:.2f}s'.format(timers['total']))
print('NN - {:.2f}s'.format(timers['forward']))
print('Epoch time - {:.2f}s'.format(stop - start))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment