pynccl.py 9.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# This file is a pure Python wrapper for the NCCL library.
# The main purpose is to use NCCL combined with CUDA graph.
# Before writing this script, we tried the following approaches:
# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
#  often gets stuck when initializing the NCCL communicator.
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
#  may call other CUDA APIs that are not allowed while a CUDA graph is
#  being captured. For further details, please check
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
#
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
# doable, but we often encounter issues related to NCCL versions, and need
# to switch between different versions of NCCL. See
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
# A C/C++ binding is not flexible enough to handle this. It requires
# recompilation of the code every time we want to switch between different
# versions. This current implementation, with a **pure** Python wrapper, is
# more flexible. We can easily switch between different versions of NCCL by
# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
# variable in the code.

import ctypes
import datetime
24
import glob
25
26
27
28
29
30
31
import os

# ===================== import region =====================
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp

32
33
34
from vllm.logger import init_logger

logger = init_logger(__name__)

# Explicit override for the NCCL shared library; an empty string means
# "choose a library automatically".
so_file = os.environ.get("VLLM_NCCL_SO_PATH", "")

# check if we have vllm-managed nccl
vllm_nccl_path = None
if torch.version.cuda is not None:
    cuda_major = torch.version.cuda.split(".")[0]
    path = os.path.expanduser(
        f"~/.config/vllm/nccl/cu{cuda_major}/libnccl.so.*")
    # glob() yields matches in arbitrary, filesystem-dependent order. Sort
    # them so every process makes the same choice when several files match.
    files = sorted(glob.glob(path))
    vllm_nccl_path = files[0] if files else None

47
48
49
50
51
52
# manually load the nccl library
# Resolution order:
#   1. the explicit VLLM_NCCL_SO_PATH override;
#   2. on CUDA, the vLLM-managed copy if one was found, else the system
#      libnccl.so.2;
#   3. on ROCm, the system librccl.so.1.
if so_file:
    logger.info(
        "Loading nccl from environment variable VLLM_NCCL_SO_PATH=%s",
        so_file)
else:
    if torch.version.cuda is not None:
        so_file = vllm_nccl_path or "libnccl.so.2"
    elif torch.version.hip is not None:
        so_file = "librccl.so.1"
    else:
        raise ValueError("NCCL only supports CUDA and ROCm backends.")
    logger.info("Loading nccl from library %s", so_file)

try:
    nccl = ctypes.CDLL(so_file)
except OSError:
    # ctypes.CDLL reports a failed dlopen as OSError. Log a hint for the
    # user, then re-raise unchanged so the original traceback is kept.
    logger.error(
        "Failed to load NCCL library from %s. "
        "It is expected if you are not running on NVIDIA/AMD GPUs. "
        "Otherwise please set the environment variable VLLM_NCCL_SO_PATH"
        " to point to the correct nccl library path.", so_file)
    raise

# === export types and functions from nccl to Python ===
# for the original nccl definition, please check
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in

# Every NCCL entry point returns an ncclResult_t, which is a plain C int;
# the wrappers in this file treat 0 as success.
ncclResult_t = ctypes.c_int

# C prototype:
#   ncclResult_t ncclGetVersion(int *version);
_c_ncclGetVersion = nccl.ncclGetVersion
_c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
_c_ncclGetVersion.restype = ncclResult_t


def ncclGetVersion() -> str:
    version = ctypes.c_int()
    result = _c_ncclGetVersion(ctypes.byref(version))
    assert result == 0
    # something like 21903 --> "2.19.3"
    version_str = str(version.value)
    major = version_str[0].lstrip("0")
    minor = version_str[1:3].lstrip("0")
    patch = version_str[3:].lstrip("0")
    return f"{major}.{minor}.{patch}"


class NcclUniqueId(ctypes.Structure):
    """Python mirror of NCCL's ``ncclUniqueId``.

    In C this is a struct holding a single 128-byte ``internal`` array. The
    bytes are opaque to this wrapper.
    """
    _fields_ = [
        ("internal", ctypes.c_byte * 128),  # opaque id bytes
    ]


# C prototype:
#   ncclResult_t ncclGetUniqueId(ncclUniqueId *uniqueId);
# The id is written through the pointer argument.
_c_ncclGetUniqueId = nccl.ncclGetUniqueId
_c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)]
_c_ncclGetUniqueId.restype = ncclResult_t


def ncclGetUniqueId() -> NcclUniqueId:
    """Ask NCCL for a fresh unique id used to set up a communicator.

    Asserts that the underlying ``ncclGetUniqueId`` call returned 0.
    """
    uid = NcclUniqueId()
    status = _c_ncclGetUniqueId(ctypes.byref(uid))
    assert status == 0
    return uid


# C prototype:
#   ncclResult_t ncclCommInitRank(
#       ncclComm_t *comm, int nranks, ncclUniqueId commId, int rank);
# ncclComm_t is itself a pointer type, so the output argument `comm` is a
# pointer to a pointer (a c_void_p passed by reference). The unique id is
# passed by value.
_c_ncclCommInitRank = nccl.ncclCommInitRank
_c_ncclCommInitRank.argtypes = [
    ctypes.POINTER(ctypes.c_void_p),  # ncclComm_t *comm (output)
    ctypes.c_int,  # int nranks
    NcclUniqueId,  # ncclUniqueId commId (by value)
    ctypes.c_int,  # int rank
]
_c_ncclCommInitRank.restype = ncclResult_t


# enums
class ncclDataType_t(ctypes.c_int):
    """ctypes stand-in for NCCL's ``ncclDataType_t`` enum.

    The integer class attributes mirror the C enumerators. Several names are
    aliases sharing one value (e.g. ``ncclHalf`` and ``ncclFloat16``).
    """
    ncclInt8 = 0
    ncclChar = 0
    ncclUint8 = 1
    ncclInt32 = 2
    ncclInt = 2
    ncclUint32 = 3
    ncclInt64 = 4
    ncclUint64 = 5
    ncclFloat16 = 6
    ncclHalf = 6
    ncclFloat32 = 7
    ncclFloat = 7
    ncclFloat64 = 8
    ncclDouble = 8
    ncclBfloat16 = 9
    ncclNumTypes = 10

    @classmethod
    def from_torch(cls, dtype: torch.dtype) -> 'ncclDataType_t':
        """Return the NCCL enumerator value for a torch dtype.

        Raises:
            ValueError: if ``dtype`` has no mapping below.
        """
        # Checked in order with ``==``.
        table = (
            (torch.int8, cls.ncclInt8),
            (torch.uint8, cls.ncclUint8),
            (torch.int32, cls.ncclInt32),
            (torch.int64, cls.ncclInt64),
            (torch.float16, cls.ncclFloat16),
            (torch.float32, cls.ncclFloat32),
            (torch.float64, cls.ncclFloat64),
            (torch.bfloat16, cls.ncclBfloat16),
        )
        for torch_dtype, nccl_value in table:
            if dtype == torch_dtype:
                return nccl_value
        raise ValueError(f"Unsupported dtype: {dtype}")


class ncclRedOp_t(ctypes.c_int):
    """ctypes stand-in for NCCL's ``ncclRedOp_t`` reduction-op enum."""
    ncclSum = 0
    ncclProd = 1
    ncclMax = 2
    ncclMin = 3
    ncclAvg = 4
    ncclNumOps = 5

    @classmethod
    def from_torch(cls, op: ReduceOp) -> 'ncclRedOp_t':
        """Return the NCCL enumerator value for a ``torch.distributed`` op.

        Raises:
            ValueError: if ``op`` has no mapping below.
        """
        # Checked in order with ``==``.
        table = (
            (ReduceOp.SUM, cls.ncclSum),
            (ReduceOp.PRODUCT, cls.ncclProd),
            (ReduceOp.MAX, cls.ncclMax),
            (ReduceOp.MIN, cls.ncclMin),
            (ReduceOp.AVG, cls.ncclAvg),
        )
        for torch_op, nccl_value in table:
            if op == torch_op:
                return nccl_value
        raise ValueError(f"Unsupported op: {op}")


# C prototype:
#   ncclResult_t ncclAllReduce(
#       const void *sendbuff, void *recvbuff, size_t count,
#       ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
#       cudaStream_t stream);
# ncclComm_t and cudaStream_t are both pointer types, hence c_void_p.
_c_ncclAllReduce = nccl.ncclAllReduce
_c_ncclAllReduce.argtypes = [
    ctypes.c_void_p,  # sendbuff
    ctypes.c_void_p,  # recvbuff
    ctypes.c_size_t,  # count
    ncclDataType_t,  # datatype
    ncclRedOp_t,  # op
    ctypes.c_void_p,  # comm
    ctypes.c_void_p,  # stream
]
_c_ncclAllReduce.restype = ncclResult_t

# C prototype:
#   ncclResult_t ncclCommDestroy(ncclComm_t comm);
_c_ncclCommDestroy = nccl.ncclCommDestroy
_c_ncclCommDestroy.argtypes = [ctypes.c_void_p]
_c_ncclCommDestroy.restype = ncclResult_t


class NCCLCommunicator:
    """One NCCL communicator driven directly through ctypes.

    Construction works in three steps:

    1. Join or reuse the default ``torch.distributed`` process group.
    2. Have rank 0 create an NCCL unique id and broadcast it to every rank.
    3. Call ``ncclCommInitRank``.

    Collectives are then issued straight to NCCL, bypassing
    ``torch.distributed``.
    """

    def __init__(
        self,
        backend=None,
        init_method=None,
        timeout=datetime.timedelta(seconds=10),
        world_size: int = -1,
        rank: int = -1,
        store=None,
        group_name: str = "",
        pg_options=None,
        local_rank: int = -1,
    ):
        """Create the communicator.

        Args:
            backend, init_method, timeout, world_size, rank, store,
            group_name, pg_options: forwarded to
                ``torch.distributed.init_process_group``, but only when no
                default process group exists yet. In that case ``backend``
                must be ``"nccl"``; ``None`` is treated as ``"nccl"``.
            local_rank: CUDA device index used by this process. ``-1`` means
                "use the global rank".
        """
        if not dist.is_initialized():
            backend = backend or "nccl"
            assert backend == 'nccl', (
                "only use nccl backend for starting the NCCL communicator")
            dist.init_process_group(backend=backend,
                                    init_method=init_method,
                                    timeout=timeout,
                                    world_size=world_size,
                                    rank=rank,
                                    store=store,
                                    group_name=group_name,
                                    pg_options=pg_options)
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        if local_rank == -1:
            local_rank = self.rank
        self.local_rank = local_rank
        # don't use these args, as they can be -1
        # use `self.rank`, `self.local_rank` and `self.world_size` instead
        del world_size, rank, local_rank
        torch.cuda.set_device(self.local_rank)
        # Rank 0 obtains the real id. Other ranks start from an all-zero
        # placeholder that the broadcast below overwrites.
        if self.rank == 0:
            self.unique_id = ncclGetUniqueId()
        else:
            self.unique_id = NcclUniqueId()
        # `internal` is a signed c_byte array. bytes() yields the same raw
        # bytes as unsigned values 0..255, which fit a uint8 tensor.
        tensor = torch.ByteTensor(list(bytes(self.unique_id.internal))).cuda(
            self.local_rank)
        dist.broadcast(tensor, src=0)
        byte_list = tensor.cpu().tolist()
        # ctypes does no overflow checking, so values 128..255 wrap back to
        # the same signed bytes.
        for i, byte in enumerate(byte_list):
            self.unique_id.internal[i] = byte
        self.comm = ctypes.c_void_p()
        result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
                                     self.unique_id, self.rank)
        assert result == 0
        self.stream = torch.cuda.Stream(device=f"cuda:{self.local_rank}")

    def all_reduce(self,
                   tensor: torch.Tensor,
                   op: ReduceOp = ReduceOp.SUM,
                   stream=None):
        """In-place all-reduce of ``tensor`` across every rank.

        Args:
            tensor: buffer passed to NCCL as both send and receive buffer
                (``data_ptr()`` with ``numel()`` elements).
            op: reduction, mapped via ``ncclRedOp_t.from_torch``.
            stream: CUDA stream to enqueue on; defaults to ``self.stream``.

        Raises:
            ValueError: if the dtype or op has no NCCL equivalent.
        """
        if stream is None:
            stream = self.stream
        result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
                                  ctypes.c_void_p(tensor.data_ptr()),
                                  tensor.numel(),
                                  ncclDataType_t.from_torch(tensor.dtype),
                                  ncclRedOp_t.from_torch(op), self.comm,
                                  ctypes.c_void_p(stream.cuda_stream))
        assert result == 0

    def __del__(self):
        """Best-effort teardown of the default process group and NCCL comm.

        This may run on a partially constructed instance or during interpreter
        shutdown, so every step is guarded and the method never raises for
        missing state. It is also safe to call more than once.
        """
        # `dist` module might have been already destroyed
        destroy_pg = getattr(dist, "destroy_process_group", None)
        is_initialized = getattr(dist, "is_initialized", None)
        if (destroy_pg is not None and is_initialized is not None
                and is_initialized()):
            # NOTE(review): this tears down the *default* process group even
            # when it was created by the caller rather than by __init__.
            destroy_pg()
        # `comm` is missing if __init__ failed early, and NULL if
        # ncclCommInitRank never filled it in.
        comm = getattr(self, "comm", None)
        if comm is None or not comm.value:
            return
        # function might have been already destroyed
        if _c_ncclCommDestroy is not None:
            _c_ncclCommDestroy(comm)
            # Clear the handle so a repeated __del__ cannot destroy it twice.
            comm.value = None