import traceback
from _thread import start_new_thread
from functools import wraps

import torch
import torch.multiprocessing as mp


def thread_wrapped_func(func):
    """Wrap ``func`` so that each call runs in a freshly started thread.

    Intended for functions used as the target of
    ``torch.multiprocessing.Process``.  With this wrapper OMP threads can be
    used in subprocesses; otherwise ``OMP_NUM_THREADS=1`` is mandatory.

    How to use::

        @thread_wrapped_func
        def func_to_wrap(args ...):
            ...

    Args:
        func: The callable to wrap.

    Returns:
        A callable with the same signature as ``func``.  It blocks until the
        worker thread has finished and returns ``func``'s result.

    Raises:
        Exception: If ``func`` raises an ``Exception``, an exception of the
            same class is raised in the calling thread, with the worker's
            formatted traceback as its message.  If that class cannot be
            built from a single string, the original exception is re-raised.
    """

    @wraps(func)
    def decorated_function(*args, **kwargs):
        # The queue hands the outcome from the worker thread to the caller.
        queue = mp.Queue()

        def _queue_result():
            exception, trace, res = None, None, None
            try:
                res = func(*args, **kwargs)
            except Exception as e:
                exception = e
                # The traceback must be formatted here, while the exception
                # is still being handled in the worker thread.
                trace = traceback.format_exc()
            queue.put((res, exception, trace))

        start_new_thread(_queue_result, ())
        # NOTE(review): only ``Exception`` subclasses are captured above; if
        # the worker ends any other way nothing is queued and this blocks.
        result, exception, trace = queue.get()
        if exception is None:
            return result

        # Rebuild the exception with the worker's traceback as its message so
        # the caller can see where the failure actually happened.
        try:
            wrapped = type(exception)(trace)
        except Exception:
            # Some exception classes do not accept a single string argument.
            wrapped = None
        if wrapped is None:
            raise exception
        raise wrapped from exception

    return decorated_function


def shuffle_walks(walks):
    """Return ``walks`` with its entries along dimension 0 randomly permuted.

    Args:
        walks: A tensor; it is indexed along its first dimension.

    Returns:
        ``walks`` indexed by a random permutation of ``range(walks.size(0))``.
    """
    seeds = torch.randperm(walks.size(0))
    return walks[seeds]