pynccl.py 10.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# This file is a pure Python wrapper for the NCCL library.
# The main purpose is to use NCCL combined with CUDA graph.
# Before writing this script, we tried the following approaches:
# 1. We tried to use `cupy`. It calls NCCL correctly, but `cupy` itself
#  often gets stuck when initializing the NCCL communicator.
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
#  may call many other CUDA APIs that are not allowed while a CUDA graph
#  is being captured. For further details, please check
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
#
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
# doable, but we often encounter issues related to NCCL versions, and need
# to switch between different versions of NCCL. See
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
# A C/C++ binding is not flexible enough to handle this. It requires
# recompilation of the code every time we want to switch between different
# versions. This current implementation, with a **pure** Python wrapper, is
# more flexible. We can easily switch between different versions of NCCL by
# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
# variable in the code.

import ctypes
23
import platform
24
from typing import Optional, Union
25
26
27
28

# ===================== import region =====================
import torch
import torch.distributed as dist
29
from torch.distributed import ProcessGroup, ReduceOp
30

31
from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
32
from vllm.logger import init_logger
33
from vllm.utils import find_nccl_library, nccl_integrity_check
34
35

logger = init_logger(__name__)

# Path of the NCCL shared library, as located by `find_nccl_library()`.
so_file = find_nccl_library()

try:
    # load the library in another process.
    # if it core dumps, it will not crash the current process
    nccl_integrity_check(so_file)
    nccl = ctypes.CDLL(so_file)
except Exception:
    # Each fragment ends with a space so the concatenated sentences do not
    # run into each other.
    logger.error(
        f"Failed to load NCCL library from {so_file} . "
        "It is expected if you are not running on NVIDIA/AMD GPUs. "
        "Otherwise, the nccl library might not exist, be corrupted "
        f"or it does not support the current platform {platform.platform()}. "
        "One solution is to download libnccl2 version 2.18 from "
        "https://developer.download.nvidia.com/compute/cuda/repos/ "
        "and extract the libnccl.so.2 file. If you already have the "
        "library, please set the environment variable VLLM_NCCL_SO_PATH"
        " to point to the correct nccl library path.")
    # bare `raise` re-raises the original exception with its traceback intact
    raise

# === export types and functions from nccl to Python ===
# for the original nccl definition, please check
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in

# NCCL status codes are plain C ints; 0 means success (see NCCL_CHECK).
ncclResult_t = ctypes.c_int

# equivalent to c declaration:
# const char* ncclGetErrorString(ncclResult_t result);
_c_ncclGetErrorString = nccl.ncclGetErrorString
_c_ncclGetErrorString.argtypes = [ncclResult_t]
_c_ncclGetErrorString.restype = ctypes.c_char_p


def NCCL_CHECK(result: ncclResult_t) -> None:
    """Raise ``RuntimeError`` unless ``result`` is 0 (success).

    The exception message carries NCCL's own description of the failure,
    obtained from ``ncclGetErrorString`` and decoded as UTF-8.
    """
    if result == 0:
        return
    description = _c_ncclGetErrorString(result).decode("utf-8")
    raise RuntimeError(f"NCCL error: {description}")
# equivalent to c declaration:
# ncclResult_t  ncclGetVersion(int *version);
_c_ncclGetVersion = nccl.ncclGetVersion
_c_ncclGetVersion.restype = ctypes.c_int
_c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]


def ncclGetVersion() -> str:
    """Return the loaded NCCL library's version as "major.minor.patch".

    NCCL reports its version as one integer code. From NCCL 2.9 on the code
    is ``major * 10000 + minor * 100 + patch`` (e.g. 21903 -> "2.19.3",
    21800 -> "2.18.0"); earlier releases used
    ``major * 1000 + minor * 100 + patch`` (e.g. 2708 -> "2.7.8").
    Integer arithmetic is used so that zero components (such as the patch
    in 2.18.0) are kept instead of being stripped away.

    Raises:
        RuntimeError: if the NCCL call fails.
    """
    version = ctypes.c_int()
    NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version)))
    code = version.value
    # codes below 10000 come from the pre-2.9 encoding
    major, rest = divmod(code, 10000 if code >= 10000 else 1000)
    minor, patch = divmod(rest, 100)
    return f"{major}.{minor}.{patch}"


class NcclUniqueId(ctypes.Structure):
    """ctypes mirror of NCCL's ``ncclUniqueId``: an opaque 128-byte blob.

    One rank obtains it from NCCL and shares its bytes with the other ranks
    so that they all join the same communicator.
    """
    _fields_ = [
        ("internal", ctypes.c_byte * 128),
    ]


# equivalent to c declaration:
# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
_c_ncclGetUniqueId = nccl.ncclGetUniqueId
_c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)]
_c_ncclGetUniqueId.restype = ctypes.c_int


def ncclGetUniqueId() -> NcclUniqueId:
    """Create a new NCCL unique id.

    A freshly allocated ``NcclUniqueId`` is handed to NCCL by pointer and
    filled in place.

    Raises:
        RuntimeError: if NCCL reports a failure (via ``NCCL_CHECK``).
    """
    fresh_id = NcclUniqueId()
    status = _c_ncclGetUniqueId(ctypes.byref(fresh_id))
    NCCL_CHECK(status)
    return fresh_id


# equivalent to c declaration:
# ncclResult_t  ncclCommInitRank(
#   ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
# ncclComm_t is itself a pointer type, so `comm` is a pointer to a pointer;
# the unique id is passed by value as a whole struct.
_c_ncclCommInitRank = nccl.ncclCommInitRank
_c_ncclCommInitRank.argtypes = [
    ctypes.POINTER(ctypes.c_void_p),  # ncclComm_t* comm
    ctypes.c_int,  # int nranks
    NcclUniqueId,  # ncclUniqueId commId
    ctypes.c_int,  # int rank
]
_c_ncclCommInitRank.restype = ctypes.c_int
ncclDataType_t = ctypes.c_int


class ncclDataTypeEnum:
    """Integer values of NCCL's ``ncclDataType_t`` (some names are aliases)."""
    ncclInt8 = 0
    ncclChar = 0
    ncclUint8 = 1
    ncclInt32 = 2
    ncclInt = 2
    ncclUint32 = 3
    ncclInt64 = 4
    ncclUint64 = 5
    ncclFloat16 = 6
    ncclHalf = 6
    ncclFloat32 = 7
    ncclFloat = 7
    ncclFloat64 = 8
    ncclDouble = 8
    ncclBfloat16 = 9
    ncclNumTypes = 10

    @classmethod
    def from_torch(cls, dtype: torch.dtype) -> int:
        """Map a ``torch.dtype`` to the corresponding NCCL data type value.

        Raises:
            ValueError: if ``dtype`` has no mapping here.
        """
        # checked in order with `==`, first match wins
        pairs = (
            (torch.int8, cls.ncclInt8),
            (torch.uint8, cls.ncclUint8),
            (torch.int32, cls.ncclInt32),
            (torch.int64, cls.ncclInt64),
            (torch.float16, cls.ncclFloat16),
            (torch.float32, cls.ncclFloat32),
            (torch.float64, cls.ncclFloat64),
            (torch.bfloat16, cls.ncclBfloat16),
        )
        for torch_dtype, nccl_dtype in pairs:
            if dtype == torch_dtype:
                return nccl_dtype
        raise ValueError(f"Unsupported dtype: {dtype}")


ncclRedOp_t = ctypes.c_int


class ncclRedOpTypeEnum:
    """Integer values of NCCL's ``ncclRedOp_t`` built-in reductions."""
    ncclSum = 0
    ncclProd = 1
    ncclMax = 2
    ncclMin = 3
    ncclAvg = 4
    ncclNumOps = 5

    @classmethod
    def from_torch(cls, op: ReduceOp) -> int:
        """Map a ``torch.distributed.ReduceOp`` to the NCCL reduction value.

        Raises:
            ValueError: if ``op`` has no mapping here.
        """
        # checked in order with `==`, first match wins
        pairs = (
            (ReduceOp.SUM, cls.ncclSum),
            (ReduceOp.PRODUCT, cls.ncclProd),
            (ReduceOp.MAX, cls.ncclMax),
            (ReduceOp.MIN, cls.ncclMin),
            (ReduceOp.AVG, cls.ncclAvg),
        )
        for torch_op, nccl_op in pairs:
            if op == torch_op:
                return nccl_op
        raise ValueError(f"Unsupported op: {op}")


# equivalent to c declaration:
# ncclResult_t  ncclAllReduce(
#   const void* sendbuff, void* recvbuff, size_t count,
#   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
#   cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument is a pointer
_c_ncclAllReduce = nccl.ncclAllReduce
_c_ncclAllReduce.restype = ctypes.c_int
# argtypes follow the C parameter order above: `datatype` comes before `op`.
_c_ncclAllReduce.argtypes = [
    ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclDataType_t,
    ncclRedOp_t, ctypes.c_void_p, ctypes.c_void_p
]

# equivalent to c declaration:
# ncclResult_t  ncclCommDestroy(ncclComm_t comm);
# ncclComm_t is an opaque pointer, so it is passed as a void*.
_c_ncclCommDestroy = nccl.ncclCommDestroy
_c_ncclCommDestroy.argtypes = [ctypes.c_void_p]
_c_ncclCommDestroy.restype = ctypes.c_int


class NCCLCommunicator:
    """NCCL communicator created through ctypes, used for in-place all-reduce.

    The NCCL unique id is exchanged over a non-NCCL ``ProcessGroup``. The
    communicator and a dedicated CUDA stream are created on a single device.
    """

    def __init__(
        self,
        group: Optional[ProcessGroup] = None,
        device: Optional[Union[int, str, torch.device]] = None,
    ):
        """
        Args:
            group: the process group used to exchange the NCCL unique id.
                It must not use the NCCL backend. If None, the group
                returned by `get_cpu_world_group()` is used.
            device: the device to bind the NCCLCommunicator to. If None,
                it will be bound to f"cuda:{local_rank}". An int `i` means
                f"cuda:{i}"; a str is passed to `torch.device`.
        It is the caller's responsibility to make sure each communicator
        is bound to a unique device.
        """
        assert dist.is_initialized()
        group = get_cpu_world_group() if group is None else group
        assert dist.get_backend(group) != dist.Backend.NCCL, (
            "NCCLCommunicator should be attached to a non-NCCL group.")
        self.group = group
        # rank and world size are relative to `group`
        self.rank = dist.get_rank(group)
        self.world_size = dist.get_world_size(group)
        # group rank 0 asks NCCL for the unique id; every other rank starts
        # from an empty id that the broadcast below overwrites
        if self.rank == 0:
            self.unique_id = ncclGetUniqueId()
        else:
            self.unique_id = NcclUniqueId()
        tensor = torch.ByteTensor(list(self.unique_id.internal))
        # `src` of `dist.broadcast` is a *global* rank even when `group` is
        # given, so translate group rank 0 (the id's creator) to its global
        # rank; a hard-coded 0 breaks for groups without global rank 0.
        src = dist.get_process_group_ranks(group)[0]
        dist.broadcast(tensor, src=src, group=group)
        byte_list = tensor.tolist()
        for i, byte in enumerate(byte_list):
            self.unique_id.internal[i] = byte
        self.comm = ctypes.c_void_p()
        if device is None:
            local_rank = get_local_rank()
            device = torch.device(f"cuda:{local_rank}")
        elif isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device
        # nccl communicator and stream will use this device
        # `torch.cuda.device` is a context manager that changes the
        # current cuda device to the specified one
        with torch.cuda.device(device):
            NCCL_CHECK(
                _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
                                    self.unique_id, self.rank))
            self.stream = torch.cuda.Stream()

    def all_reduce(self,
                   tensor: torch.Tensor,
                   op: ReduceOp = ReduceOp.SUM,
                   stream=None):
        """All-reduce `tensor` in place across the communicator's ranks.

        Args:
            tensor: a contiguous tensor on `self.device`; it is used as both
                the send and the receive buffer.
            op: the reduction, translated with `ncclRedOpTypeEnum.from_torch`.
            stream: the CUDA stream to enqueue the collective on; defaults
                to `self.stream`.

        Raises:
            RuntimeError: if NCCL reports a failure.
            ValueError: if the dtype or op has no NCCL equivalent.
        """
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        # NCCL operates on `numel()` consecutive elements starting at
        # `data_ptr()`, which only matches the tensor's elements when the
        # tensor is contiguous; a strided view would reduce the wrong memory.
        assert tensor.is_contiguous(), (
            "NCCLCommunicator.all_reduce requires a contiguous tensor; "
            "call `.contiguous()` first")
        if stream is None:
            stream = self.stream
        NCCL_CHECK(
            _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
                             ctypes.c_void_p(tensor.data_ptr()),
                             tensor.numel(),
                             ncclDataTypeEnum.from_torch(tensor.dtype),
                             ncclRedOpTypeEnum.from_torch(op), self.comm,
                             ctypes.c_void_p(stream.cuda_stream)))

    def __del__(self):
        # Release only the NCCL communicator this object created. The
        # process group is shared with the rest of the program and must not
        # be torn down here; calling `dist.destroy_process_group()` would
        # destroy the default group for every other user.
        comm = getattr(self, "comm", None)
        if comm is None or comm.value is None:
            # __init__ did not get far enough to create a communicator
            return
        # function might have been already destroyed
        if _c_ncclCommDestroy is not None:
            _c_ncclCommDestroy(comm)