pynccl_wrapper.py 18.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# This file is a pure Python wrapper for the NCCL library.
# The main purpose is to use NCCL combined with CUDA graph.
# Before writing this script, we tried the following approach:
# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
#  often gets stuck when initializing the NCCL communicator.
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
#  may call other CUDA APIs that are not allowed while
#  capturing a CUDA graph. For further details, please check
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
#
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
# doable, but we often encounter issues related to NCCL versions, and need
# to switch between different versions of NCCL. See
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
# A C/C++ binding is not flexible enough to handle this. It requires
# recompilation of the code every time we want to switch between different
# versions. This current implementation, with a **pure** Python wrapper, is
# more flexible. We can easily switch between different versions of NCCL by
# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
# variable in the code.

import ctypes
26
import functools
27
28
import platform
from dataclasses import dataclass
29
from typing import Any
30
31
32
33

import torch
from torch.distributed import ReduceOp

34
from vllm import envs
35
from vllm.logger import init_logger
36
from vllm.platforms import current_platform
37
from vllm.utils.nccl import find_nccl_library
38
39
40
41
42
43
44
45
46

logger = init_logger(__name__)

# --- ctypes counterparts of the NCCL C types used below ---
# The authoritative C declarations live in
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in

# Status code returned by the NCCL API; NCCLLibrary.NCCL_CHECK treats any
# non-zero value as an error.
ncclResult_t = ctypes.c_int
# Communicator handle, represented as an untyped pointer.
ncclComm_t = ctypes.c_void_p
# Registered-window handle, also represented as an untyped pointer.
ncclWindow_t = ctypes.c_void_p


class ncclUniqueId(ctypes.Structure):
    """ctypes mirror of NCCL's ``ncclUniqueId``: an opaque 128-byte blob.

    Produced by ``ncclGetUniqueId`` (or rebuilt from raw bytes with
    ``NCCLLibrary.unique_id_from_bytes``) and passed by value to
    ``ncclCommInitRank``.
    """

    _fields_ = [("internal", ctypes.c_byte * 128)]


# Scalar / pointer aliases used in the foreign-function signatures below.
ncclDataType_t = ctypes.c_int  # an ncclDataTypeEnum value
buffer_type = ctypes.c_void_p  # address of a send / receive buffer
cudaStream_t = ctypes.c_void_p  # stream handle (a pointer type in C)


class ncclDataTypeEnum:
    ncclInt8 = 0
    ncclChar = 0
    ncclUint8 = 1
    ncclInt32 = 2
    ncclInt = 2
    ncclUint32 = 3
    ncclInt64 = 4
    ncclUint64 = 5
    ncclFloat16 = 6
    ncclHalf = 6
    ncclFloat32 = 7
    ncclFloat = 7
    ncclFloat64 = 8
    ncclDouble = 8
    ncclBfloat16 = 9
76
77
    ncclFloat8e4m3 = 10
    ncclNumTypes = 11
78

79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
    @classmethod
    @functools.lru_cache(maxsize=1)
    def _torch_to_nccl_map(cls) -> dict[torch.dtype, int]:
        return {
            torch.int8: cls.ncclInt8,
            torch.uint8: cls.ncclUint8,
            torch.int32: cls.ncclInt32,
            torch.int64: cls.ncclInt64,
            torch.float16: cls.ncclFloat16,
            torch.float32: cls.ncclFloat32,
            torch.float64: cls.ncclFloat64,
            torch.bfloat16: cls.ncclBfloat16,
            current_platform.fp8_dtype(): cls.ncclFloat8e4m3,
        }

    @classmethod
    def supports_torch_dtype(cls, dtype: torch.dtype) -> bool:
        return dtype in cls._torch_to_nccl_map()

    @classmethod
    def try_from_torch(cls, dtype: torch.dtype) -> int | None:
        return cls._torch_to_nccl_map().get(dtype)

102
103
    @classmethod
    def from_torch(cls, dtype: torch.dtype) -> int:
104
105
106
        nccl_dtype = cls.try_from_torch(dtype)
        if nccl_dtype is not None:
            return nccl_dtype
107
108
        raise ValueError(
            f"Unsupported dtype {dtype}: should be one of "
109
110
            f"int8, uint8, int32, int64, float16, float32, float64, bfloat16,"
            " float8e4m3."
111
        )
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143


# Reduction-operator enum as seen by ctypes: a plain C int.
ncclRedOp_t = ctypes.c_int


class ncclRedOpTypeEnum:
    """Integer values of NCCL's built-in ``ncclRedOp_t`` operators."""

    ncclSum = 0
    ncclProd = 1
    ncclMax = 2
    ncclMin = 3
    ncclAvg = 4
    ncclNumOps = 5

    @classmethod
    def from_torch(cls, op: ReduceOp) -> int:
        """Translate a ``torch.distributed.ReduceOp`` into its NCCL value.

        Raises:
            ValueError: if ``op`` is not one of SUM, PRODUCT, MAX, MIN, AVG.
        """
        # Compared with ``==`` in a fixed order, exactly like an if-chain.
        candidates = (
            (ReduceOp.SUM, cls.ncclSum),
            (ReduceOp.PRODUCT, cls.ncclProd),
            (ReduceOp.MAX, cls.ncclMax),
            (ReduceOp.MIN, cls.ncclMin),
            (ReduceOp.AVG, cls.ncclAvg),
        )
        for torch_op, nccl_op in candidates:
            if op == torch_op:
                return nccl_op
        raise ValueError(f"Unsupported op: {op}")


@dataclass(init=True, repr=True, eq=True)
class Function:
    """Declaration of one exported C symbol to be bound through ctypes."""

    # Symbol name as exported by the shared library.
    name: str
    # ctypes type assigned to the foreign function's ``restype``.
    restype: Any
    # ctypes types assigned to the foreign function's ``argtypes``, in order.
    argtypes: list[Any]


class NCCLLibrary:
    """ctypes binding for an NCCL-compatible shared library.

    Each entry of ``exported_functions`` is resolved in the loaded library and
    given its ctypes signature; the Python methods below call those foreign
    functions, and every call that returns an ``ncclResult_t`` is passed
    through ``NCCL_CHECK`` (non-zero -> ``RuntimeError``).

    The library handle and the table of bound functions are cached per
    library path at class level, so several instances created for the same
    ``so_file`` load and bind it only once.
    """

    exported_functions = [
        # const char* ncclGetErrorString(ncclResult_t result)
        Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]),
        # ncclResult_t  ncclGetVersion(int *version);
        Function("ncclGetVersion", ncclResult_t, [ctypes.POINTER(ctypes.c_int)]),
        # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
        Function("ncclGetUniqueId", ncclResult_t, [ctypes.POINTER(ncclUniqueId)]),
        # ncclResult_t  ncclCommInitRank(
        #   ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
        # note that ncclComm_t is a pointer type, so the first argument
        # is a pointer to a pointer
        Function(
            "ncclCommInitRank",
            ncclResult_t,
            [ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, ctypes.c_int],
        ),
        # ncclResult_t  ncclAllReduce(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
        #   cudaStream_t stream);
        # note that cudaStream_t is a pointer type, so the last argument
        # is a pointer
        Function(
            "ncclAllReduce",
            ncclResult_t,
            [
                buffer_type,
                buffer_type,
                ctypes.c_size_t,
                ncclDataType_t,
                ncclRedOp_t,
                ncclComm_t,
                cudaStream_t,
            ],
        ),
        # ncclResult_t  ncclReduce(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   ncclDataType_t datatype, ncclRedOp_t op, int root,
        #   ncclComm_t comm,  cudaStream_t stream);
        # note that cudaStream_t is a pointer type, so the last argument
        # is a pointer
        Function(
            "ncclReduce",
            ncclResult_t,
            [
                buffer_type,
                buffer_type,
                ctypes.c_size_t,
                ncclDataType_t,
                ncclRedOp_t,
                ctypes.c_int,
                ncclComm_t,
                cudaStream_t,
            ],
        ),
        # ncclResult_t  ncclAllGather(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   ncclDataType_t datatype, ncclComm_t comm,
        #   cudaStream_t stream);
        # note that cudaStream_t is a pointer type, so the last argument
        # is a pointer
        Function(
            "ncclAllGather",
            ncclResult_t,
            [
                buffer_type,
                buffer_type,
                ctypes.c_size_t,
                ncclDataType_t,
                ncclComm_t,
                cudaStream_t,
            ],
        ),
        # ncclResult_t  ncclReduceScatter(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
        #   cudaStream_t stream);
        # note that cudaStream_t is a pointer type, so the last argument
        # is a pointer
        Function(
            "ncclReduceScatter",
            ncclResult_t,
            [
                buffer_type,
                buffer_type,
                ctypes.c_size_t,
                ncclDataType_t,
                ncclRedOp_t,
                ncclComm_t,
                cudaStream_t,
            ],
        ),
        # ncclResult_t  ncclSend(
        #   const void* sendbuff, size_t count, ncclDataType_t datatype,
        #   int dest, ncclComm_t comm, cudaStream_t stream);
        Function(
            "ncclSend",
            ncclResult_t,
            [
                buffer_type,
                ctypes.c_size_t,
                ncclDataType_t,
                ctypes.c_int,
                ncclComm_t,
                cudaStream_t,
            ],
        ),
        # ncclResult_t  ncclRecv(
        #   void* recvbuff, size_t count, ncclDataType_t datatype,
        #   int src, ncclComm_t comm, cudaStream_t stream);
        Function(
            "ncclRecv",
            ncclResult_t,
            [
                buffer_type,
                ctypes.c_size_t,
                ncclDataType_t,
                ctypes.c_int,
                ncclComm_t,
                cudaStream_t,
            ],
        ),
        # ncclResult_t ncclBroadcast(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   ncclDataType_t datatype, int root, ncclComm_t comm,
        #   cudaStream_t stream);
        Function(
            "ncclBroadcast",
            ncclResult_t,
            [
                buffer_type,
                buffer_type,
                ctypes.c_size_t,
                ncclDataType_t,
                ctypes.c_int,
                ncclComm_t,
                cudaStream_t,
            ],
        ),
        # be cautious! this is a collective call, it will block until all
        # processes in the communicator have called this function.
        # because Python object destruction can happen in random order,
        # it is better not to call it at all.
        # ncclResult_t  ncclCommDestroy(ncclComm_t comm);
        Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
        # ncclResult_t ncclGroupStart();
        Function("ncclGroupStart", ncclResult_t, []),
        # ncclResult_t ncclGroupEnd();
        Function("ncclGroupEnd", ncclResult_t, []),
        # ncclResult_t ncclCommWindowRegister(
        #   ncclComm_t comm, void* buff, size_t size,
        #   ncclWindow_t* win, int winFlags);
        Function(
            "ncclCommWindowRegister",
            ncclResult_t,
            [
                ncclComm_t,
                buffer_type,
                ctypes.c_size_t,
                ctypes.POINTER(ncclWindow_t),
                ctypes.c_int,
            ],
        ),
        # ncclResult_t ncclCommWindowDeregister(
        #   ncclComm_t comm, ncclWindow_t win);
        Function("ncclCommWindowDeregister", ncclResult_t, [ncclComm_t, ncclWindow_t]),
    ]

    # class attribute to store the mapping from the path to the library
    # to avoid loading the same library multiple times
    path_to_library_cache: dict[str, Any] = {}

    # class attribute to store the mapping from library path
    #  to the corresponding dictionary
    path_to_dict_mapping: dict[str, dict[str, Any]] = {}

    def __init__(self, so_file: str | None = None):
        """Load the library and bind the functions in ``exported_functions``.

        Args:
            so_file: path of the shared library to load. When falsy, the path
                returned by ``find_nccl_library()`` is used.

        Raises:
            Exception: whatever ``ctypes.CDLL`` raises when the library cannot
                be loaded (an explanatory error is logged first).
            AttributeError: when a required symbol is missing from the library.
        """
        so_file = so_file or find_nccl_library()

        # Consult the same cache that is filled here, so a library that was
        # loaded but whose binding step failed is reused rather than reloaded.
        if so_file not in NCCLLibrary.path_to_library_cache:
            try:
                NCCLLibrary.path_to_library_cache[so_file] = ctypes.CDLL(so_file)
            except Exception:
                logger.error(
                    "Failed to load NCCL library from %s. "
                    "It is expected if you are not running on NVIDIA/AMD GPUs. "
                    "Otherwise, the nccl library might not exist, be corrupted "
                    "or it does not support the current platform %s. "
                    "If you already have the library, please set the "
                    "environment variable VLLM_NCCL_SO_PATH"
                    " to point to the correct nccl library path.",
                    so_file,
                    platform.platform(),
                )
                raise
        self.lib = NCCLLibrary.path_to_library_cache[so_file]

        if so_file not in NCCLLibrary.path_to_dict_mapping:
            NCCLLibrary.path_to_dict_mapping[so_file] = NCCLLibrary._bind_functions(
                self.lib, so_file
            )
        self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]

    @staticmethod
    def _bind_functions(lib: Any, so_file: str) -> dict[str, Any]:
        """Resolve ``exported_functions`` in ``lib``; return name -> function.

        Each resolved foreign function gets its ``restype``/``argtypes`` set
        from the matching ``Function`` declaration. The symmetric-memory
        window symbols are optional on ROCm: when absent there they are
        skipped (and so are missing from the returned mapping). Any other
        missing symbol re-raises ``AttributeError``.
        """
        funcs: dict[str, Any] = {}
        for func in NCCLLibrary.exported_functions:
            try:
                foreign = getattr(lib, func.name)
            except AttributeError:
                if func.name in (
                    "ncclCommWindowRegister",
                    "ncclCommWindowDeregister",
                ):
                    if envs.VLLM_USE_NCCL_SYMM_MEM:
                        logger.warning_once(
                            "The symbol %s is not found in the NCCL "
                            "library %s. To enable VLLM_USE_NCCL_SYMM_MEM "
                            "please update your NCCL version to >= "
                            "2.27.03.",
                            func.name,
                            so_file,
                        )
                    if current_platform.is_rocm():
                        # Having an exception here on ROCm platform is
                        # not allowed during graph capturing
                        continue
                raise
            foreign.restype = func.restype
            foreign.argtypes = func.argtypes
            funcs[func.name] = foreign
        return funcs

    def ncclGetErrorString(self, result: ncclResult_t) -> str:
        """Return NCCL's human-readable description of ``result``."""
        return self._funcs["ncclGetErrorString"](result).decode("utf-8")

    def NCCL_CHECK(self, result: ncclResult_t) -> None:
        """Raise ``RuntimeError`` if ``result`` is not 0 (success)."""
        if result != 0:
            error_str = self.ncclGetErrorString(result)
            raise RuntimeError(f"NCCL error: {error_str}")

    def ncclGetRawVersion(self) -> int:
        """Return the integer version code reported by ``ncclGetVersion``."""
        version = ctypes.c_int()
        self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
        # something like 21903
        return version.value

    def ncclGetVersion(self) -> str:
        """Return the library version formatted as ``"major.minor.patch"``.

        NCCL encodes versions >= 2.9 as ``major*10000 + minor*100 + patch``
        (21903 -> "2.19.3") and older ones as ``major*1000 + minor*100 +
        patch`` (2708 -> "2.7.8"). Decoding arithmetically keeps zero
        components as "0" (22000 -> "2.20.0").
        """
        code = self.ncclGetRawVersion()
        major, rest = divmod(code, 10000 if code >= 10000 else 1000)
        minor, patch = divmod(rest, 100)
        return f"{major}.{minor}.{patch}"

    def ncclGetUniqueId(self) -> ncclUniqueId:
        """Return a fresh ``ncclUniqueId`` filled in by ``ncclGetUniqueId``."""
        unique_id = ncclUniqueId()
        self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](ctypes.byref(unique_id)))
        return unique_id

    def unique_id_from_bytes(self, data: bytes) -> ncclUniqueId:
        """Rebuild an ``ncclUniqueId`` from its 128 raw bytes.

        Raises:
            ValueError: if ``data`` is not exactly 128 bytes long.
        """
        if len(data) != 128:
            raise ValueError(
                f"Expected 128 bytes for ncclUniqueId, got {len(data)} bytes"
            )
        unique_id = ncclUniqueId()
        ctypes.memmove(ctypes.addressof(unique_id.internal), data, 128)
        return unique_id

    def ncclCommInitRank(
        self, world_size: int, unique_id: ncclUniqueId, rank: int
    ) -> ncclComm_t:
        """Create a communicator via ``ncclCommInitRank`` and return it."""
        comm = ncclComm_t()
        self.NCCL_CHECK(
            self._funcs["ncclCommInitRank"](
                ctypes.byref(comm), world_size, unique_id, rank
            )
        )
        return comm

    def ncclAllReduce(
        self,
        sendbuff: buffer_type,
        recvbuff: buffer_type,
        count: int,
        datatype: int,
        op: int,
        comm: ncclComm_t,
        stream: cudaStream_t,
    ) -> None:
        """Call ``ncclAllReduce`` with the given arguments (checked)."""
        # `datatype` actually should be `ncclDataType_t`
        # and `op` should be `ncclRedOp_t`
        # both are aliases of `ctypes.c_int`
        # when we pass int to a function, it will be converted to `ctypes.c_int`
        # by ctypes automatically
        self.NCCL_CHECK(
            self._funcs["ncclAllReduce"](
                sendbuff, recvbuff, count, datatype, op, comm, stream
            )
        )

    def ncclReduce(
        self,
        sendbuff: buffer_type,
        recvbuff: buffer_type,
        count: int,
        datatype: int,
        op: int,
        root: int,
        comm: ncclComm_t,
        stream: cudaStream_t,
    ) -> None:
        """Call ``ncclReduce`` with the given arguments (checked)."""
        # `datatype` / `op` are plain ints; ctypes converts them to c_int
        self.NCCL_CHECK(
            self._funcs["ncclReduce"](
                sendbuff, recvbuff, count, datatype, op, root, comm, stream
            )
        )

    def ncclReduceScatter(
        self,
        sendbuff: buffer_type,
        recvbuff: buffer_type,
        count: int,
        datatype: int,
        op: int,
        comm: ncclComm_t,
        stream: cudaStream_t,
    ) -> None:
        """Call ``ncclReduceScatter`` with the given arguments (checked)."""
        # `datatype` / `op` are plain ints; ctypes converts them to c_int
        self.NCCL_CHECK(
            self._funcs["ncclReduceScatter"](
                sendbuff, recvbuff, count, datatype, op, comm, stream
            )
        )

    def ncclAllGather(
        self,
        sendbuff: buffer_type,
        recvbuff: buffer_type,
        count: int,
        datatype: int,
        comm: ncclComm_t,
        stream: cudaStream_t,
    ) -> None:
        """Call ``ncclAllGather`` with the given arguments (checked)."""
        # `datatype` is a plain int; ctypes converts it to c_int
        self.NCCL_CHECK(
            self._funcs["ncclAllGather"](
                sendbuff, recvbuff, count, datatype, comm, stream
            )
        )

    def ncclSend(
        self,
        sendbuff: buffer_type,
        count: int,
        datatype: int,
        dest: int,
        comm: ncclComm_t,
        stream: cudaStream_t,
    ) -> None:
        """Call ``ncclSend`` towards rank ``dest`` (checked)."""
        self.NCCL_CHECK(
            self._funcs["ncclSend"](sendbuff, count, datatype, dest, comm, stream)
        )

    def ncclRecv(
        self,
        recvbuff: buffer_type,
        count: int,
        datatype: int,
        src: int,
        comm: ncclComm_t,
        stream: cudaStream_t,
    ) -> None:
        """Call ``ncclRecv`` from rank ``src`` (checked)."""
        self.NCCL_CHECK(
            self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream)
        )

    def ncclBroadcast(
        self,
        sendbuff: buffer_type,
        recvbuff: buffer_type,
        count: int,
        datatype: int,
        root: int,
        comm: ncclComm_t,
        stream: cudaStream_t,
    ) -> None:
        """Call ``ncclBroadcast`` from rank ``root`` (checked)."""
        self.NCCL_CHECK(
            self._funcs["ncclBroadcast"](
                sendbuff, recvbuff, count, datatype, root, comm, stream
            )
        )

    def ncclCommDestroy(self, comm: ncclComm_t) -> None:
        """Call ``ncclCommDestroy`` (a blocking collective; see the note in
        ``exported_functions`` before using it)."""
        self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))

    def ncclGroupStart(self) -> None:
        """Call ``ncclGroupStart`` (checked)."""
        self.NCCL_CHECK(self._funcs["ncclGroupStart"]())

    def ncclGroupEnd(self) -> None:
        """Call ``ncclGroupEnd`` (checked)."""
        self.NCCL_CHECK(self._funcs["ncclGroupEnd"]())

    def ncclCommWindowRegister(
        self, comm: ncclComm_t, buff: buffer_type, size: int, win_flags: int
    ) -> ncclWindow_t:
        """Register ``size`` bytes at ``buff`` with ``comm``; return the window.

        On ROCm the symbol may be absent (see ``_bind_functions``), in which
        case the lookup below raises ``KeyError``.
        """
        window = ncclWindow_t()
        self.NCCL_CHECK(
            self._funcs["ncclCommWindowRegister"](
                comm, buff, size, ctypes.byref(window), win_flags
            )
        )
        return window

    def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None:
        """Deregister a window returned by ``ncclCommWindowRegister``."""
        self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window))

# Public names of this module. ``ncclWindow_t`` is exported because
# ``NCCLLibrary.ncclCommWindowRegister`` returns it and
# ``NCCLLibrary.ncclCommWindowDeregister`` takes it as an argument.
__all__ = [
    "NCCLLibrary",
    "ncclDataTypeEnum",
    "ncclRedOpTypeEnum",
    "ncclUniqueId",
    "ncclComm_t",
    "ncclWindow_t",
    "cudaStream_t",
    "buffer_type",
]