pynccl.py 13.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import Optional, Union
5
6
7
8

# ===================== import region =====================
import torch
import torch.distributed as dist
9
from torch.distributed import ProcessGroup, ReduceOp
10

11
import vllm.envs as envs
12
from vllm.distributed.device_communicators.pynccl_wrapper import (
13
14
15
16
17
18
19
20
    NCCLLibrary,
    buffer_type,
    cudaStream_t,
    ncclComm_t,
    ncclDataTypeEnum,
    ncclRedOpTypeEnum,
    ncclUniqueId,
)
21
from vllm.distributed.utils import StatelessProcessGroup
22
from vllm.logger import init_logger
youkaichao's avatar
youkaichao committed
23
from vllm.utils import current_stream
24
25

# Module-level logger obtained from vLLM's logging helper; used below to
# report the NCCL version picked up by PyNcclCommunicator.
logger = init_logger(__name__)
26

27
28
29
30
31
# Process-wide flag: True once the symmetric-memory custom op was registered.
_NCCL_SYMM_OPS_REGISTERED = False


def register_nccl_symmetric_ops(pynccl_comm):
    """Register the ``all_reduce_symmetric_with_copy`` custom op once per process.

    The registered op allocates an input and an output buffer inside
    ``nccl_symm_mem_context(pynccl_comm)``, copies its argument into the
    input buffer and all-reduces it with ``pynccl_comm.all_reduce``.

    Registration happens at most once per process. The op closes over the
    communicator passed to the first *successful* call. Any later call
    returns immediately, whichever communicator it is given.

    Args:
        pynccl_comm: the ``PyNcclCommunicator`` the op will use.
            NOTE(review): presumably it must be enabled, since a disabled
            communicator's ``all_reduce`` returns None — confirm at call sites.
    """
    global _NCCL_SYMM_OPS_REGISTERED
    # Check the flag first so repeated calls do not pay for the imports below.
    if _NCCL_SYMM_OPS_REGISTERED:
        return

    # Imported lazily (function scope), as in the original implementation.
    from vllm.distributed.device_communicators.pynccl_allocator import (
        nccl_symm_mem_context,
    )
    from vllm.utils import direct_register_custom_op

    def all_reduce_symmetric_with_copy_impl(input_tensor: torch.Tensor) -> torch.Tensor:
        # Only the allocations happen inside the context; presumably the
        # context routes them to NCCL symmetric memory (see pynccl_allocator).
        with nccl_symm_mem_context(pynccl_comm):
            symm_input = torch.empty_like(input_tensor)
            symm_output = torch.empty_like(input_tensor)
        symm_input.copy_(input_tensor)
        symm_output = pynccl_comm.all_reduce(symm_input, symm_output)
        return symm_output

    def all_reduce_symmetric_with_copy_fake(input_tensor: torch.Tensor) -> torch.Tensor:
        # Fake implementation: only reproduces the output's shape/dtype/device.
        return torch.empty_like(input_tensor)

    direct_register_custom_op(
        op_name="all_reduce_symmetric_with_copy",
        op_func=all_reduce_symmetric_with_copy_impl,
        fake_impl=all_reduce_symmetric_with_copy_fake,
    )
    # Mark as registered only after registration succeeded. If
    # direct_register_custom_op raises, a later call can retry instead of
    # silently skipping registration forever.
    _NCCL_SYMM_OPS_REGISTERED = True

58

59
class PyNcclCommunicator:
    """Python wrapper around an NCCL communicator bound to a single device.

    The communicator is *disabled* in three cases:

    - the group has a single rank;
    - ``envs.VLLM_DISABLE_PYNCCL`` is set;
    - the NCCL library cannot be loaded.

    While disabled, the communication methods (``all_reduce``, ``all_gather``,
    ``all_gatherv``, ``reduce_scatter``, ``reduce_scatterv``, ``send``,
    ``recv``, ``broadcast``, ``group_start``, ``group_end``) return None and
    do nothing.
    """

    def __init__(
        self,
        group: Union[ProcessGroup, StatelessProcessGroup],
        device: Union[int, str, torch.device],
        library_path: Optional[str] = None,
    ):
        """
        Args:
            group: the process group to work on. A torch ``ProcessGroup`` must
                not use the NCCL backend (e.g. use a gloo group). If None,
                torch.distributed's default process group is used.
                A ``StatelessProcessGroup`` is also accepted. Here the group
                provides the rank / world size and is used to broadcast the
                NCCL unique id from rank 0.
            device: the device to bind the PyNcclCommunicator to. An ``int``
                is turned into ``torch.device(f"cuda:{device}")``, a ``str``
                into ``torch.device(device)``. It must not be None.
            library_path: the path to the NCCL library. If None, it will
                use the default library path.

        It is the caller's responsibility to make sure each communicator
        is bound to a unique device.
        """
        if not isinstance(group, StatelessProcessGroup):
            assert dist.is_initialized()
            assert dist.get_backend(group) != dist.Backend.NCCL, (
                "PyNcclCommunicator should be attached to a non-NCCL group."
            )
            # note: this rank is the rank in the group
            self.rank = dist.get_rank(group)
            self.world_size = dist.get_world_size(group)
        else:
            self.rank = group.rank
            self.world_size = group.world_size

        self.group = group

        # if world_size == 1, no need to create communicator
        if self.world_size == 1 or envs.VLLM_DISABLE_PYNCCL:
            self.available = False
            self.disabled = True
            return
        try:
            self.nccl = NCCLLibrary(library_path)
        except Exception:
            # Deliberate best-effort: disable because of a missing NCCL
            # library (e.g. in a non-GPU environment). Log it so the cause
            # is not silently lost.
            logger.debug(
                "Failed to load the NCCL library (library_path=%s); "
                "PyNcclCommunicator is disabled.",
                library_path,
                exc_info=True,
            )
            self.available = False
            self.disabled = True
            return

        self.available = True
        self.disabled = False

        self.nccl_version = self.nccl.ncclGetRawVersion()
        logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())

        if self.rank == 0:
            # get the unique id from NCCL
            self.unique_id = self.nccl.ncclGetUniqueId()
        else:
            # construct an empty unique id, filled in by the broadcast below
            self.unique_id = ncclUniqueId()

        if not isinstance(group, StatelessProcessGroup):
            tensor = torch.ByteTensor(list(self.unique_id.internal))
            ranks = dist.get_process_group_ranks(group)
            # arg `src` in `broadcast` is the global rank
            dist.broadcast(tensor, src=ranks[0], group=group)
            # Copy the received bytes back into the ctypes id element-wise.
            byte_list = tensor.tolist()
            for i, byte in enumerate(byte_list):
                self.unique_id.internal[i] = byte
        else:
            self.unique_id = group.broadcast_obj(self.unique_id, src=0)

        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device
        # nccl communicator and stream will use this device
        # `torch.cuda.device` is a context manager that changes the
        # current cuda device to the specified one
        with torch.cuda.device(device):
            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                self.world_size, self.unique_id, self.rank
            )

            stream = current_stream()
            # A small all_reduce for warmup.
            data = torch.zeros(1, device=device)
            self.all_reduce(data)
            stream.synchronize()
            del data

    def _check_device(self, tensor: torch.Tensor) -> None:
        """Assert that ``tensor`` lives on the device this communicator uses."""
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}"
        )

    def all_reduce(
        self,
        in_tensor: torch.Tensor,
        out_tensor: Optional[torch.Tensor] = None,
        op: ReduceOp = ReduceOp.SUM,
        stream=None,
    ) -> Optional[torch.Tensor]:
        """All-reduce ``in_tensor`` across the group with ``ncclAllReduce``.

        Args:
            in_tensor: tensor to reduce; must be on ``self.device``.
            out_tensor: destination buffer. If None, a new tensor is allocated
                with ``torch.empty_like(in_tensor)``.
            op: reduction operation.
            stream: stream to enqueue the operation on; defaults to
                ``current_stream()``.

        Returns:
            ``out_tensor`` once the operation has been enqueued, or None if
            the communicator is disabled.
        """
        if self.disabled:
            return None
        self._check_device(in_tensor)

        if out_tensor is None:
            out_tensor = torch.empty_like(in_tensor)

        if stream is None:
            stream = current_stream()
        self.nccl.ncclAllReduce(
            buffer_type(in_tensor.data_ptr()),
            buffer_type(out_tensor.data_ptr()),
            in_tensor.numel(),
            ncclDataTypeEnum.from_torch(in_tensor.dtype),
            ncclRedOpTypeEnum.from_torch(op),
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )
        return out_tensor

    def all_gather(
        self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None
    ):
        """Gather ``input_tensor`` from every rank into ``output_tensor``.

        Each rank contributes ``input_tensor.numel()`` elements. Per the NCCL
        allgather contract, ``output_tensor`` must hold world_size times that
        many elements. Does nothing when the communicator is disabled.
        """
        if self.disabled:
            return
        self._check_device(input_tensor)
        if stream is None:
            stream = current_stream()
        self.nccl.ncclAllGather(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()),
            input_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype),
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    def all_gatherv(
        self,
        output_tensor: torch.Tensor,
        input_tensor: torch.Tensor,
        sizes: list[int],
        stream=None,
    ):
        """All-gather with a per-rank number of rows.

        ``sizes[r]`` is the number of dim-0 rows contributed by rank ``r``.
        Rank ``r``'s ``input_tensor`` is written into rows
        ``[sum(sizes[:r]), sum(sizes[:r + 1]))`` of ``output_tensor``.
        This is implemented as one ``ncclBroadcast`` per root, grouped
        between ``ncclGroupStart`` and ``ncclGroupEnd``. Does nothing when
        the communicator is disabled.
        """
        if self.disabled:
            return
        self._check_device(input_tensor)
        if stream is None:
            stream = current_stream()
        assert output_tensor.shape[0] == sum(sizes)
        split_offset = 0
        self.nccl.ncclGroupStart()
        for root, split_size in enumerate(sizes):
            dst_slice = output_tensor[split_offset : split_offset + split_size]
            # On non-root ranks NCCL only reads recvbuff; input_tensor is
            # passed as sendbuff on every rank for simplicity.
            self.nccl.ncclBroadcast(
                buffer_type(input_tensor.data_ptr()),
                buffer_type(dst_slice.data_ptr()),
                dst_slice.numel(),
                ncclDataTypeEnum.from_torch(input_tensor.dtype),
                root,
                self.comm,
                cudaStream_t(stream.cuda_stream),
            )
            split_offset += split_size
        self.nccl.ncclGroupEnd()

    def reduce_scatter(
        self,
        output_tensor: torch.Tensor,
        input_tensor: torch.Tensor,
        op: ReduceOp = ReduceOp.SUM,
        stream=None,
    ):
        """Reduce ``input_tensor`` across ranks and scatter equal shares.

        Each rank receives ``output_tensor.numel()`` reduced elements. Does
        nothing when the communicator is disabled.
        """
        if self.disabled:
            return
        self._check_device(input_tensor)
        if stream is None:
            stream = current_stream()
        self.nccl.ncclReduceScatter(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()),
            output_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype),
            ncclRedOpTypeEnum.from_torch(op),
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    def reduce_scatterv(
        self,
        output_tensor: torch.Tensor,
        input_tensor: torch.Tensor,
        sizes: list[int],
        op: ReduceOp = ReduceOp.SUM,
        stream=None,
    ):
        """Reduce-scatter with a per-rank number of rows.

        ``sizes[r]`` is the number of dim-0 rows of ``input_tensor`` that are
        reduced onto rank ``r``. The reduction uses one grouped
        ``ncclReduce`` per root, with ``output_tensor`` as the receive
        buffer. Does nothing when the communicator is disabled.
        """
        if self.disabled:
            return
        self._check_device(input_tensor)
        if stream is None:
            stream = current_stream()

        split_offset = 0
        self.nccl.ncclGroupStart()
        for root, split_size in enumerate(sizes):
            chunk = input_tensor[split_offset : split_offset + split_size, ...]
            # recvbuff is only meaningful on `root`; NCCL ignores it elsewhere.
            self.nccl.ncclReduce(
                buffer_type(chunk.data_ptr()),
                buffer_type(output_tensor.data_ptr()),
                chunk.numel(),
                ncclDataTypeEnum.from_torch(input_tensor.dtype),
                ncclRedOpTypeEnum.from_torch(op),
                root,
                self.comm,
                cudaStream_t(stream.cuda_stream),
            )
            split_offset += split_size
        self.nccl.ncclGroupEnd()

    def send(self, tensor: torch.Tensor, dst: int, stream=None):
        """Send ``tensor`` to rank ``dst`` (no-op when disabled)."""
        if self.disabled:
            return
        self._check_device(tensor)
        if stream is None:
            stream = current_stream()
        self.nccl.ncclSend(
            buffer_type(tensor.data_ptr()),
            tensor.numel(),
            ncclDataTypeEnum.from_torch(tensor.dtype),
            dst,
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    def recv(self, tensor: torch.Tensor, src: int, stream=None):
        """Receive into ``tensor`` from rank ``src`` (no-op when disabled)."""
        if self.disabled:
            return
        self._check_device(tensor)
        if stream is None:
            stream = current_stream()
        self.nccl.ncclRecv(
            buffer_type(tensor.data_ptr()),
            tensor.numel(),
            ncclDataTypeEnum.from_torch(tensor.dtype),
            src,
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
        """Broadcast ``tensor`` in place from rank ``src`` (no-op when disabled)."""
        if self.disabled:
            return
        self._check_device(tensor)
        if stream is None:
            stream = current_stream()
        if src == self.rank:
            sendbuff = buffer_type(tensor.data_ptr())
            # NCCL requires the sender also to have a receive buffer
            recvbuff = buffer_type(tensor.data_ptr())
        else:
            # Non-root ranks pass a null send buffer.
            sendbuff = buffer_type()
            recvbuff = buffer_type(tensor.data_ptr())
        self.nccl.ncclBroadcast(
            sendbuff,
            recvbuff,
            tensor.numel(),
            ncclDataTypeEnum.from_torch(tensor.dtype),
            src,
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    def group_start(self):
        """Open an NCCL group (``ncclGroupStart``); no-op when disabled.

        When disabled, ``self.nccl`` is never created, so the guard keeps this
        consistent with the other communication methods.
        """
        if self.disabled:
            return
        self.nccl.ncclGroupStart()

    def group_end(self):
        """Close an NCCL group (``ncclGroupEnd``); no-op when disabled."""
        if self.disabled:
            return
        self.nccl.ncclGroupEnd()

    def register_comm_window(self, tensor: torch.Tensor):
        """Register ``tensor``'s storage as an NCCL communication window.

        Registers ``tensor.numel() * tensor.element_size()`` bytes starting at
        ``tensor.data_ptr()`` and returns the window handle from
        ``ncclCommWindowRegister``. Requires an enabled communicator.
        NOTE(review): the trailing ``1`` is passed as the window-flags
        argument; presumably NCCL_WIN_COLL_SYMMETRIC — confirm in
        pynccl_wrapper.
        """
        return self.nccl.ncclCommWindowRegister(
            self.comm,
            buffer_type(tensor.data_ptr()),
            tensor.numel() * tensor.element_size(),
            1,
        )

    def register_comm_window_raw(self, ptr: int, size: int):
        """Register ``size`` bytes at raw address ``ptr`` as an NCCL window."""
        return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr), size, 1)

    def deregister_comm_window(self, window):
        """Deregister a window handle returned by a ``register_comm_window*`` call."""
        return self.nccl.ncclCommWindowDeregister(self.comm, window)