"plugins/amoeba/vscode:/vscode.git/clone" did not exist on "2584685c6ae87498dd4237df23c615e6ba349b5e"
utils.py 2.17 KB
Newer Older
lijian6's avatar
lijian6 committed
1
2
from typing import Any, Optional, Tuple

Chenggang Zhao's avatar
Chenggang Zhao committed
3
import torch
4
import torch.distributed as dist
Chenggang Zhao's avatar
Chenggang Zhao committed
5

lijian6's avatar
lijian6 committed
6
from .deep_ep_cpp import EventHandle
Chenggang Zhao's avatar
Chenggang Zhao committed
7
8
9
10
11
12
13
14
15
16
17


class EventOverlap:
    """
    A wrapper class to manage CUDA events, also for better overlapping convenience.

    The object can be used as a context manager: leaving the `with` scope makes
    the current stream wait for the wrapped event (when one is set).

    Attributes:
        event: the CUDA event captured, or `None` if there is nothing to wait for.
        extra_tensors: an easier way to simulate PyTorch tensor `record_stream`, may be useful with CUDA graph.
    """

    def __init__(
        self,
        # Quoted so the annotation is not evaluated when the class is defined.
        event: Optional["EventHandle"] = None,
        extra_tensors: Optional[Tuple[torch.Tensor, ...]] = None,
    ) -> None:
        """
        Initialize the class.

        Arguments:
            event: the CUDA event captured.
            extra_tensors: an easier way to simulate PyTorch tensor `record_stream`, may be useful with CUDA graph.
        """
        self.event = event

        # NOTES: we use extra tensors to achieve stream recording, otherwise,
        # stream recording will be incompatible with CUDA graph.
        self.extra_tensors = extra_tensors

    def current_stream_wait(self) -> None:
        """
        The current stream `torch.cuda.current_stream()` waits for the event to be finished.

        Raises:
            AssertionError: if no event is attached to this object.
        """
        assert self.event is not None
        self.event.current_stream_wait()

    def __enter__(self) -> "EventOverlap":
        """
        Utility for overlapping and Python `with` syntax.

        You can overlap the kernels on the current stream with the following example:
        ```python
        event_overlap = event_after_all_to_all_kernels()
        with event_overlap:
            do_something_on_current_stream()
        # After exiting the `with` scope, the current stream will wait for the event to be finished.
        ```

        Returns:
            This object itself, so it can be bound with `with ... as name`.
        """
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """
        Utility for overlapping and Python `with` syntax.

        Please follow the example in the `__enter__` function.

        Unlike `current_stream_wait`, a missing event is silently ignored here.
        Nothing is returned, so an exception raised inside the `with` body is
        not suppressed.
        """
        if self.event is not None:
            self.event.current_stream_wait()
63
64
65
66
67
68
69
70
71


def check_nvlink_connections(group: dist.ProcessGroup):
    """
    Check NVLink connection between every pair of GPUs.

    Arguments:
        group: the communication group.
    """
lijian6's avatar
lijian6 committed
72
    # TODO