Unverified Commit a6d7fba3 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Add an argument for inserting synchronization (#430)



* Add an argument for inserting synchronization

* Update training-benchmark.py

* Update training-benchmark.py
Co-authored-by: default avatarFarhad Ramezanghorbani <farhadrgh@users.noreply.github.com>
parent c18f4a5e
......@@ -6,6 +6,8 @@ import argparse
import pkbar
from torchani.units import hartree2kcalmol
synchronize = False
def atomic():
model = torch.nn.Sequential(
......@@ -26,6 +28,8 @@ def time_func(key, func):
def wrapper(*args, **kwargs):
start = timeit.default_timer()
ret = func(*args, **kwargs)
if synchronize:
torch.cuda.synchronize()
end = timeit.default_timer()
timers[key] += end - start
return ret
......@@ -60,12 +64,18 @@ if __name__ == "__main__":
dest='dataset',
action='store_const',
const='cache')
parser.add_argument('-y', '--synchronize',
action='store_true',
help='whether to insert torch.cuda.synchronize() at the end of each function')
parser.set_defaults(dataset='shuffle')
parser.add_argument('-n', '--num_epochs',
help='epochs',
default=1, type=int)
parser = parser.parse_args()
if parser.synchronize:
synchronize = True
Rcr = 5.2000e+00
Rca = 3.5000e+00
EtaR = torch.tensor([1.6000000e+01], device=parser.device)
......@@ -89,6 +99,7 @@ if __name__ == "__main__":
torchani.aev.angular_terms = time_func('torchani.aev.angular_terms', torchani.aev.angular_terms)
torchani.aev.compute_shifts = time_func('torchani.aev.compute_shifts', torchani.aev.compute_shifts)
torchani.aev.neighbor_pairs = time_func('torchani.aev.neighbor_pairs', torchani.aev.neighbor_pairs)
torchani.aev.neighbor_pairs_nopbc = time_func('torchani.aev.neighbor_pairs_nopbc', torchani.aev.neighbor_pairs_nopbc)
torchani.aev.triu_index = time_func('torchani.aev.triu_index', torchani.aev.triu_index)
torchani.aev.cumsum_from_zero = time_func('torchani.aev.cumsum_from_zero', torchani.aev.cumsum_from_zero)
torchani.aev.triple_by_molecule = time_func('torchani.aev.triple_by_molecule', torchani.aev.triple_by_molecule)
......@@ -168,6 +179,8 @@ if __name__ == "__main__":
optimizer.step()
progbar.update(i, values=[("rmse", rmse)])
if synchronize:
torch.cuda.synchronize()
stop = time.time()
print('=> more detail about benchmark')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment