from collections.abc import Callable

import torch


def _record_cuda_tensors(values, stream) -> None:
    """Call ``record_stream(stream)`` on every CUDA tensor found in *values*.

    Non-tensor items and tensors that do not live on a CUDA device are
    skipped: ``record_stream`` is a CUDA caching-allocator hook and is not
    applicable to CPU (including pinned host) tensors.
    """
    for value in values:
        if isinstance(value, torch.Tensor) and value.is_cuda:
            value.record_stream(stream)


def maybe_execute_in_stream(
    fn: Callable, *args, STORE_STREAM: torch.cuda.Stream = None, **kwargs
):
    """Call ``fn(*args, **kwargs)``, optionally inside a side CUDA stream.

    When ``STORE_STREAM`` is ``None`` this is a plain pass-through call.
    Otherwise:

    1. ``STORE_STREAM`` is made to wait on the default CUDA stream.
    2. ``fn`` is invoked with ``STORE_STREAM`` as the current stream.
    3. Every CUDA tensor among the positional arguments, the keyword
       argument values and ``fn.__self__`` (the bound tensor when ``fn`` is
       a tensor method) is recorded on ``STORE_STREAM``.
    4. Every CUDA tensor in the result (a bare tensor, or the top-level
       items of a tuple/list) is recorded on the default stream.

    This function does not make the default stream wait for
    ``STORE_STREAM``; no such synchronization is issued here.

    Args:
        fn: Callable to invoke.
        *args: Positional arguments forwarded to ``fn``.
        STORE_STREAM: Keyword-only side stream to run ``fn`` in, or ``None``
            to call ``fn`` directly.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Returns:
        Whatever ``fn`` returns, unchanged.
    """
    if STORE_STREAM is None:
        return fn(*args, **kwargs)

    # Candidates for record_stream: call arguments plus the bound object of a
    # method (``None`` when ``fn`` is not a bound method; filtered out later).
    inputs = [*args, *kwargs.values(), getattr(fn, "__self__", None)]

    STORE_STREAM.wait_stream(torch.cuda.default_stream())

    # Some PyTorch builds don't make `torch.cuda.Stream` a context manager.
    # The portable API is `torch.cuda.stream(stream)`.
    if hasattr(STORE_STREAM, "__enter__"):
        stream_ctx = STORE_STREAM
    else:
        stream_ctx = torch.cuda.stream(STORE_STREAM)

    with stream_ctx:
        output = fn(*args, **kwargs)

    # Inputs were consumed on the side stream: tell the allocator so their
    # memory is not reused before that work has finished.
    _record_cuda_tensors(inputs, STORE_STREAM)

    # Outputs were allocated on the side stream; they are recorded on the
    # default stream. Lists are handled the same way as tuples.
    outputs = output if isinstance(output, (tuple, list)) else (output,)
    _record_cuda_tensors(outputs, torch.cuda.default_stream())

    return output