import functools import time from collections import defaultdict import torch time_maps = defaultdict(lambda :0.) count_maps = defaultdict(lambda :0.) def run_time(name): def middle(fn): def wrapper(*args, **kwargs): torch.cuda.synchronize() start = time.time() res = fn(*args, **kwargs) torch.cuda.synchronize() time_maps['%s : %s'%(name, fn.__name__) ] += time.time()-start count_maps['%s : %s'%(name, fn.__name__) ] +=1 print("%s : %s takes up %f "% (name, fn.__name__,time_maps['%s : %s'%(name, fn.__name__) ] /count_maps['%s : %s'%(name, fn.__name__) ] )) return res return wrapper return middle