"""CuPy utilities for all-reduce.

We use CuPy all-reduce instead of torch.distributed.all_reduce when capturing
CUDA graphs, because torch.distributed.all_reduce causes errors when capturing
CUDA graphs.

NOTE: We use CuPy 12.3 since CuPy 13.0 does not support Python 3.8.
TODO: Remove this file when torch.distributed.all_reduce is fixed.
"""
import contextlib

import torch
from torch.distributed import ReduceOp

try:
    import cupy
    from cupy.cuda import nccl
    from cupyx.distributed import NCCLBackend
except ImportError as e:
    cupy = e
    nccl = None

    class NCCLBackend:
        ...


_OP_MAPPING = {
    ReduceOp.SUM: "sum",
    ReduceOp.PRODUCT: "prod",
    ReduceOp.MIN: "min",
    ReduceOp.MAX: "max",
}


class NCCLBackendWithBFloat16(NCCLBackend):
    # This is enough to add bfloat16 support for most operations, but
    # broadcast will still fail; fixing it would require changes in the
    # compiled CuPy code.
    def _get_nccl_dtype_and_count(self, array, count=None):
        nccl_dtype, count = super()._get_nccl_dtype_and_count(array, count)
        torch_dtype = getattr(array, "_torch_dtype", None)
        if torch_dtype is torch.bfloat16:
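            # bfloat16 and float16 have the same element size (2 bytes), so
            # the element count computed by the base class remains valid;
            # only the NCCL dtype needs to be overridden.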
            nccl_dtype = nccl.NCCL_BFLOAT16
        return nccl_dtype, count

    def barrier(self) -> None:
        raise RuntimeError(
            "Currently, CuPy NCCL barrier is not supported since the TCP "
            "store is immediately stopped after the initialization.")


_NCCL_BACKEND = None
_WORLD_SIZE = 0


def is_initialized() -> bool:
    """Returns whether the NCCL backend is initialized."""
    return _NCCL_BACKEND is not None


@contextlib.contextmanager
def set_cupy_stream(stream: torch.cuda.Stream):
    """Sets the CUDA stream for CuPy communication."""
    cupy_stream = cupy.cuda.ExternalStream(stream.cuda_stream,
                                           stream.device_index)
    with cupy_stream:
        yield
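

# A hedged usage sketch (not from the original file): during CUDA graph
# capture, CuPy must issue its NCCL calls on the capture stream, so the
# expected pattern is roughly:
#
#     graph = torch.cuda.CUDAGraph()
#     capture_stream = torch.cuda.Stream()
#     with torch.cuda.graph(graph, stream=capture_stream):
#         with set_cupy_stream(capture_stream):
#             all_reduce(hidden_states)
#
# Here `hidden_states` is a placeholder tensor; the capture calls are
# illustrative, not copied from vLLM.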


def init_process_group(world_size: int, rank: int, host: str,
                       port: int) -> None:
    """Initializes the CuPy NCCL backend.

    # TODO: handle NCCL timeouts.
    """
    assert not is_initialized()

    if isinstance(cupy, Exception):
        raise ImportError(
            "NCCLBackend is not available. Please install cupy.") from cupy

    # TODO(woosuk): Create TP and PP process groups for CuPy.
    global _NCCL_BACKEND
    global _WORLD_SIZE
    assert world_size > 0, f"{world_size=} should be a positive integer"
    assert 0 <= rank < world_size, (
        f"{rank=} should be an integer in [0, {world_size})")

    cupy.cuda.runtime.setDevice(torch.cuda.current_device())
    _NCCL_BACKEND = NCCLBackendWithBFloat16(world_size, rank, host, port)
    _WORLD_SIZE = world_size

    # Stop the TCP store to prevent the deadlock issues at termination time.
    # FIXME(woosuk): This is hacky. Find a more robust solution.
    if rank == 0 and hasattr(_NCCL_BACKEND, "_store"):
        _NCCL_BACKEND._store.stop()


def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
    """All-reduces the input tensor across the process group."""
    assert input_.is_cuda, f"{input_} should be a cuda tensor"
    # Hack to support bfloat16
    torch_dtype = input_.dtype
    if torch_dtype is torch.bfloat16:
        # CuPy has no bfloat16 dtype, so reinterpret the tensor as float16.
        # This changes only the dtype tag, not the underlying bits.
        input_ = input_.view(torch.float16)
    cupy_input = cupy.asarray(input_)
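    # Stash the original torch dtype on the CuPy array; the
    # NCCLBackendWithBFloat16._get_nccl_dtype_and_count override reads it
    # to pick the correct NCCL dtype.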
    cupy_input._torch_dtype = torch_dtype  # pylint: disable=protected-access
    _NCCL_BACKEND.all_reduce(in_array=cupy_input,
                             out_array=cupy_input,
                             op=_OP_MAPPING[op])


def destroy_process_group() -> None:
    """Destroys the NCCL backend."""
    global _NCCL_BACKEND
    global _WORLD_SIZE
    _NCCL_BACKEND = None
    _WORLD_SIZE = 0


def get_world_size() -> int:
    """Returns the world size."""
    return _WORLD_SIZE


def get_nccl_backend():
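    """Returns the NCCL backend instance, or None if it is not initialized."""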
    return _NCCL_BACKEND
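

# A minimal smoke test, not part of the original module: spawn one process
# per GPU, all-reduce a tensor of ones, and check the result. This is a
# sketch under assumptions: two visible GPUs, a free local port, and
# illustrative names (`_smoke_test_worker`, port 29501).
def _smoke_test_worker(rank: int, world_size: int, port: int) -> None:
    torch.cuda.set_device(rank)
    init_process_group(world_size, rank, "127.0.0.1", port)
    x = torch.ones(4, device="cuda")
    all_reduce(x)  # Sum across ranks: each element becomes world_size.
    assert torch.allclose(x, torch.full_like(x, float(world_size)))
    destroy_process_group()


if __name__ == "__main__":
    import torch.multiprocessing as mp

    # Spawn (not fork) so each worker gets its own CUDA context.
    mp.spawn(_smoke_test_worker, args=(2, 29501), nprocs=2)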