helpers.py 1.29 KB
Newer Older
chenzk's avatar
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from collections.abc import Callable

import torch


def maybe_execute_in_stream(
    fn: Callable, *args, STORE_STREAM: torch.cuda.Stream = None, **kwargs
):
    """Call ``fn(*args, **kwargs)``, on ``STORE_STREAM`` when one is given.

    With ``STORE_STREAM`` left as ``None`` this is a plain call to ``fn``.

    Otherwise the sequence is:

    1. ``STORE_STREAM`` waits on ``torch.cuda.default_stream()``.
    2. ``fn`` is invoked with ``STORE_STREAM`` active as the stream context.
    3. Every tensor found among the positional arguments, the keyword
       argument values, and ``fn.__self__`` (when ``fn`` is a method bound
       to a tensor) gets ``record_stream(STORE_STREAM)``.
    4. A tensor result, or each tensor that is a direct element of a tuple
       result, gets ``record_stream(torch.cuda.default_stream())``.

    Args:
        fn: Callable to invoke.
        *args: Positional arguments forwarded to ``fn``.
        STORE_STREAM: Optional CUDA stream to run ``fn`` on (keyword-only).
        **kwargs: Keyword arguments forwarded to ``fn``.

    Returns:
        Whatever ``fn`` returns, unchanged.
    """
    # Guard clause: no side stream requested, so just call through.
    if STORE_STREAM is None:
        return fn(*args, **kwargs)

    # Tensor inputs, positional first and then keyword values. Only
    # top-level values are inspected; containers are not searched.
    inputs = [
        candidate
        for candidate in (*args, *kwargs.values())
        if isinstance(candidate, torch.Tensor)
    ]
    # A bound tensor method (e.g. ``some_tensor.copy_``) carries its tensor
    # in ``__self__``; that tensor is an input as well.
    bound_self = getattr(fn, "__self__", None)
    if isinstance(bound_self, torch.Tensor):
        inputs.append(bound_self)

    STORE_STREAM.wait_stream(torch.cuda.default_stream())

    # Some PyTorch builds don't make `torch.cuda.Stream` a context manager.
    # The portable API is `torch.cuda.stream(stream)`.
    if hasattr(STORE_STREAM, "__enter__"):
        context = STORE_STREAM
    else:
        context = torch.cuda.stream(STORE_STREAM)
    with context:
        result = fn(*args, **kwargs)

    for tensor in inputs:
        tensor.record_stream(STORE_STREAM)

    # Treat a lone result as a one-element tuple so a single loop handles
    # both result shapes; non-tensor items are skipped.
    produced = result if isinstance(result, tuple) else (result,)
    for item in produced:
        if isinstance(item, torch.Tensor):
            item.record_stream(torch.cuda.default_stream())

    return result