"""CuPy utilities for all-reduce.

We use CuPy all-reduce instead of torch.distributed.all_reduce when capturing
CUDA graphs, because torch.distributed.all_reduce causes errors when capturing
CUDA graphs.

TODO: Remove this file when torch.distributed.all_reduce is fixed.
"""
import contextlib
from typing import Iterator

import torch
from torch.distributed import ReduceOp

try:
    import cupy
    from cupyx.distributed import NCCLBackend
    from cupy.cuda import nccl
except ImportError as e:
    # Stash the ImportError itself in `cupy` so callers can detect the
    # missing dependency (`isinstance(cupy, Exception)`) and re-raise it
    # with a helpful message in init_process_group.
    cupy = e
    nccl = None

    # Placeholder base class so NCCLBackendWithBFloat16 below can still be
    # defined when cupy is not installed.
    class NCCLBackend:
        ...


# Maps torch.distributed reduce ops to the string op names expected by
# cupyx.distributed.NCCLBackend.all_reduce.
_OP_MAPPING = {
    ReduceOp.SUM: "sum",
    ReduceOp.PRODUCT: "prod",
    ReduceOp.MIN: "min",
    ReduceOp.MAX: "max",
}


class NCCLBackendWithBFloat16(NCCLBackend):
    """NCCL backend subclass that patches in bfloat16 support.

    This is enough to add bfloat16 support for most operations, but
    broadcast will fail (that would require changes in compiled cupy code).
    """

    def _get_nccl_dtype_and_count(self, array, count=None):
        # Defer to the parent for the default dtype/count resolution, then
        # override the dtype when the array was tagged as bfloat16 by
        # all_reduce (the tensor itself was viewed as float16).
        dtype, n = super()._get_nccl_dtype_and_count(array, count)
        if getattr(array, "_torch_dtype", None) is torch.bfloat16:
            dtype = nccl.NCCL_BFLOAT16
        return dtype, n


# Module-level singleton state: the active NCCL backend (None while
# uninitialized) and the size of the current process group.
_NCCL_BACKEND = None
_WORLD_SIZE = 0


def is_initialized() -> bool:
    """Returns whether the NCCL backend is initialized."""
    backend_ready = _NCCL_BACKEND is not None
    return backend_ready


@contextlib.contextmanager
def set_cupy_stream(stream: torch.cuda.Stream) -> Iterator[None]:
    """Context manager that sets the CUDA stream for CuPy communication.

    Wraps the given torch CUDA stream as a CuPy external stream and makes
    it current for the duration of the ``with`` block, so NCCL collectives
    issued via CuPy run on the same stream as the surrounding torch work.

    Args:
        stream: The torch CUDA stream to run CuPy communication on.

    Note:
        The return annotation was ``None``; a generator decorated with
        ``@contextlib.contextmanager`` should be annotated ``Iterator[None]``.
    """
    cupy_stream = cupy.cuda.ExternalStream(stream.cuda_stream,
                                           stream.device_index)
    with cupy_stream:
        yield


def init_process_group(world_size: int, rank: int, host: str,
                       port: int) -> None:
    """Initializes the CuPy NCCL backend.

    Args:
        world_size: Total number of processes in the group (must be > 0).
        rank: Rank of this process; must satisfy 0 <= rank < world_size.
        host: Hostname or IP address used for the NCCL rendezvous.
        port: TCP port used for the NCCL rendezvous.

    Raises:
        RuntimeError: If the backend is already initialized.
        ImportError: If cupy is not installed.
        ValueError: If world_size or rank is out of range.

    # TODO: handle NCCL timeouts.
    """
    # Explicit raises instead of asserts: asserts are stripped under
    # `python -O`, and these are real input/state validation checks.
    if is_initialized():
        raise RuntimeError("CuPy NCCL backend is already initialized.")

    if isinstance(cupy, Exception):
        raise ImportError(
            "NCCLBackend is not available. Please install cupy.") from cupy

    # TODO(woosuk): Create TP and PP process groups for CuPy.
    global _NCCL_BACKEND
    global _WORLD_SIZE
    if world_size <= 0:
        raise ValueError(f"{world_size=} should be a positive integer")
    if not 0 <= rank < world_size:
        raise ValueError(
            f"{rank=} should be a integer between [0, {world_size})")

    # Bind CuPy to the device torch is currently using before creating the
    # communicator, so both libraries talk to the same GPU.
    cupy.cuda.runtime.setDevice(torch.cuda.current_device())
    _NCCL_BACKEND = NCCLBackendWithBFloat16(world_size, rank, host, port)
    _WORLD_SIZE = world_size


def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
    """All-reduces the input tensor across the process group."""
    assert input_.is_cuda, f"{input_} should be a cuda tensor"
    original_dtype = input_.dtype
    # Hack to support bfloat16: cupy cannot handle it directly, so
    # reinterpret the buffer as float16 (same bit width — the underlying
    # data is unchanged) and tag the array with the real dtype so the
    # backend can pick NCCL_BFLOAT16.
    if original_dtype is torch.bfloat16:
        input_ = input_.view(torch.float16)
    cupy_array = cupy.asarray(input_)
    cupy_array._torch_dtype = original_dtype  # pylint: disable=protected-access
    _NCCL_BACKEND.all_reduce(in_array=cupy_array,
                             out_array=cupy_array,
                             op=_OP_MAPPING[op])


def destroy_process_group() -> None:
    """Destroys the NCCL backend."""
    global _NCCL_BACKEND, _WORLD_SIZE
    _NCCL_BACKEND = None
    _WORLD_SIZE = 0


def get_world_size() -> int:
    """Returns the world size."""
    size = _WORLD_SIZE
    return size