pynccl.py 14.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7

# ===================== import region =====================
import torch
import torch.distributed as dist
8
from torch.distributed import ProcessGroup, ReduceOp
9

10
import vllm.envs as envs
11
from vllm.distributed.device_communicators.pynccl_wrapper import (
12
13
14
15
16
17
18
19
    NCCLLibrary,
    buffer_type,
    cudaStream_t,
    ncclComm_t,
    ncclDataTypeEnum,
    ncclRedOpTypeEnum,
    ncclUniqueId,
)
20
from vllm.distributed.utils import StatelessProcessGroup
21
from vllm.logger import init_logger
22
from vllm.utils.torch_utils import current_stream
23
24

# Module-level logger, named after this module.
logger = init_logger(__name__)
25

26
27
28
29
30
# Process-wide flag: set to True by `register_nccl_symmetric_ops` so the
# custom op is registered at most once.
_NCCL_SYMM_OPS_REGISTERED = False


def register_nccl_symmetric_ops(pynccl_comm):
    """Register the ``all_reduce_symmetric_with_copy`` custom op, once.

    Only the first call performs the registration; the op implementation
    closes over the ``pynccl_comm`` given on that call. Later calls return
    immediately, so the communicator they pass is not used.

    Args:
        pynccl_comm: communicator handed to ``nccl_symm_mem_context`` and
            used to run ``all_reduce`` inside the op implementation.
    """
    global _NCCL_SYMM_OPS_REGISTERED

    from vllm.distributed.device_communicators.pynccl_allocator import (
        nccl_symm_mem_context,
    )
    from vllm.utils.torch_utils import direct_register_custom_op

    if _NCCL_SYMM_OPS_REGISTERED:
        return
    # The flag is flipped before the registration itself runs.
    _NCCL_SYMM_OPS_REGISTERED = True

    def all_reduce_symmetric_with_copy_impl(input_tensor: torch.Tensor) -> torch.Tensor:
        # Both staging buffers are allocated while the symmetric-memory
        # context is active; the copy and the all-reduce run outside it.
        with nccl_symm_mem_context(pynccl_comm):
            staged_in = torch.empty_like(input_tensor)
            staged_out = torch.empty_like(input_tensor)
        staged_in.copy_(input_tensor)
        return pynccl_comm.all_reduce(staged_in, staged_out)

    def all_reduce_symmetric_with_copy_fake(input_tensor: torch.Tensor) -> torch.Tensor:
        # Fake implementation: only produces a tensor shaped like the input.
        return torch.empty_like(input_tensor)

    direct_register_custom_op(
        op_name="all_reduce_symmetric_with_copy",
        op_func=all_reduce_symmetric_with_copy_impl,
        fake_impl=all_reduce_symmetric_with_copy_fake,
    )

57

58
class PyNcclCommunicator:
59
60
    def __init__(
        self,
        group: ProcessGroup | StatelessProcessGroup,
        device: int | str | torch.device,
        library_path: str | None = None,
    ):
        """Create the NCCL communicator for this rank within `group`.

        Args:
            group: the process group to work on: either a
                `torch.distributed` process group (which must not use the
                NCCL backend) or a `StatelessProcessGroup`.
            device: the device to bind the PyNcclCommunicator to. An int is
                turned into `cuda:{device}` and a str is passed to
                `torch.device`.
            library_path: the path to the NCCL library, forwarded as-is to
                `NCCLLibrary`.

        The communicator is left disabled (`available=False`,
        `disabled=True`) when the world size is 1, when
        `envs.VLLM_DISABLE_PYNCCL` is set, or when `NCCLLibrary` raises.

        It is the caller's responsibility to make sure each communicator
        is bound to a unique device.
        """
        is_stateless = isinstance(group, StatelessProcessGroup)
        if is_stateless:
            self.rank = group.rank
            self.world_size = group.world_size
        else:
            assert dist.is_initialized()
            assert dist.get_backend(group) != dist.Backend.NCCL, (
                "PyNcclCommunicator should be attached to a non-NCCL group."
            )
            # note: this rank is the rank in the group
            self.rank = dist.get_rank(group)
            self.world_size = dist.get_world_size(group)

        self.group = group

        # a single-rank group (or an explicit opt-out) needs no communicator
        if self.world_size == 1 or envs.VLLM_DISABLE_PYNCCL:
            self.available = False
            self.disabled = True
            return
        try:
            self.nccl = NCCLLibrary(library_path)
        except Exception:
            # disable because of missing NCCL library
            # e.g. in a non-GPU environment
            self.available = False
            self.disabled = True
            return

        self.available = True
        self.disabled = False

        self.nccl_version = self.nccl.ncclGetRawVersion()
        if self.rank == 0:
            # rank 0 obtains the unique id from NCCL
            self.unique_id = self.nccl.ncclGetUniqueId()
            logger.info_once("vLLM is using nccl==%s", self.nccl.ncclGetVersion())
        else:
            # the other ranks start from an empty unique id
            self.unique_id = ncclUniqueId()

        # Share the unique id with every rank of the group.
        if is_stateless:
            self.unique_id = group.broadcast_obj(self.unique_id, src=0)
        else:
            id_bytes = torch.ByteTensor(list(self.unique_id.internal))
            group_ranks = dist.get_process_group_ranks(group)
            # arg `src` in `broadcast` is the global rank
            dist.broadcast(id_bytes, src=group_ranks[0], group=group)
            for idx, value in enumerate(id_bytes.tolist()):
                self.unique_id.internal[idx] = value

        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # from here on `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device
        # nccl communicator and stream will use this device
        with torch.accelerator.device_index(device.index):
            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                self.world_size, self.unique_id, self.rank
            )

            stream = current_stream()
            # A small all_reduce for warmup.
            data = torch.zeros(1, device=device)
            self.all_reduce(data)
            stream.synchronize()
            del data
145

146
147
148
149
150
151
152
    def destroy(self):
        """Destroy the NCCL communicator and mark this object as unusable.

        Does nothing unless the communicator is available and not disabled.
        Afterwards `available` is False and `disabled` is True.
        """
        if not self.available or self.disabled:
            return
        # ncclCommDestroy is issued with the communicator's device selected.
        with torch.accelerator.device_index(self.device.index):
            self.nccl.ncclCommDestroy(self.comm)
        self.available = False
        self.disabled = True

153
154
155
156
157
158
159
    def all_reduce(
        self,
        in_tensor: torch.Tensor,
        out_tensor: torch.Tensor = None,
        op: ReduceOp = ReduceOp.SUM,
        stream=None,
    ) -> torch.Tensor:
160
        if self.disabled:
161
            return None
162
163
164
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
165
        assert in_tensor.device == self.device, (
166
            f"this nccl communicator is created to work on {self.device}, "
167
168
            f"but the input tensor is on {in_tensor.device}"
        )
169

170
171
        if out_tensor is None:
            out_tensor = torch.empty_like(in_tensor)
172

173
        if stream is None:
youkaichao's avatar
youkaichao committed
174
            stream = current_stream()
175
176
177
178
179
180
181
182
183
        self.nccl.ncclAllReduce(
            buffer_type(in_tensor.data_ptr()),
            buffer_type(out_tensor.data_ptr()),
            in_tensor.numel(),
            ncclDataTypeEnum.from_torch(in_tensor.dtype),
            ncclRedOpTypeEnum.from_torch(op),
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )
184
        return out_tensor
185

186
187
188
    def all_gather(
        self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None
    ):
        """All-gather `input_tensor` from every rank into `output_tensor`.

        Args:
            output_tensor: destination buffer passed to `ncclAllGather`.
            input_tensor: this rank's contribution; must be on `self.device`.
                Its element count is the per-rank count given to NCCL.
            stream: stream handed to NCCL; defaults to `current_stream()`.

        Returns nothing; a disabled communicator makes this a no-op.
        """
        if self.disabled:
            return
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert input_tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {input_tensor.device}"
        )
        active_stream = current_stream() if stream is None else stream
        send_buf = buffer_type(input_tensor.data_ptr())
        recv_buf = buffer_type(output_tensor.data_ptr())
        self.nccl.ncclAllGather(
            send_buf,
            recv_buf,
            input_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype),
            self.comm,
            cudaStream_t(active_stream.cuda_stream),
        )
208

209
210
211
212
213
214
215
216
217
218
219
220
221
222
    def all_gatherv(
        self,
        output_tensor: torch.Tensor,
        input_tensor: torch.Tensor,
        sizes: list[int],
        stream=None,
    ):
        """All-gather with a different first-dimension size per rank.

        Issues one `ncclBroadcast` per entry of `sizes` inside a single NCCL
        group. Entry `i` uses rank `i` as the root and writes into the next
        `sizes[i]` rows of `output_tensor`.

        Args:
            output_tensor: destination; its first dimension must equal
                `sum(sizes)`.
            input_tensor: this rank's contribution; must be on `self.device`.
            sizes: number of rows contributed by each rank, indexed by rank.
            stream: stream handed to NCCL; defaults to `current_stream()`.

        Returns nothing; a disabled communicator makes this a no-op.
        """
        if self.disabled:
            return
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert input_tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {input_tensor.device}"
        )
        if stream is None:
            stream = current_stream()
        assert output_tensor.shape[0] == sum(sizes)

        # Convert the loop-invariant arguments before the group is opened, so
        # that e.g. a failing dtype conversion cannot leave a group open.
        send_buf = buffer_type(input_tensor.data_ptr())
        nccl_dtype = ncclDataTypeEnum.from_torch(input_tensor.dtype)
        nccl_stream = cudaStream_t(stream.cuda_stream)

        split_offset = 0
        self.nccl.ncclGroupStart()
        try:
            for root, split_size in enumerate(sizes):
                dst_slice = output_tensor[split_offset : split_offset + split_size]
                self.nccl.ncclBroadcast(
                    send_buf,
                    buffer_type(dst_slice.data_ptr()),
                    dst_slice.numel(),
                    nccl_dtype,
                    root,
                    self.comm,
                    nccl_stream,
                )
                split_offset += split_size
        finally:
            # Always close the group, even when a broadcast call raised.
            self.nccl.ncclGroupEnd()

244
245
246
247
248
249
250
    def reduce_scatter(
        self,
        output_tensor: torch.Tensor,
        input_tensor: torch.Tensor,
        op: ReduceOp = ReduceOp.SUM,
        stream=None,
    ):
        """Reduce-scatter `input_tensor` into `output_tensor`.

        Args:
            output_tensor: destination; its element count is the per-rank
                count given to `ncclReduceScatter`.
            input_tensor: source data; must be on `self.device`.
            op: reduction operation, translated with
                `ncclRedOpTypeEnum.from_torch`.
            stream: stream handed to NCCL; defaults to `current_stream()`.

        Returns nothing; a disabled communicator makes this a no-op.
        """
        if self.disabled:
            return
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert input_tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {input_tensor.device}"
        )
        active_stream = current_stream() if stream is None else stream
        send_buf = buffer_type(input_tensor.data_ptr())
        recv_buf = buffer_type(output_tensor.data_ptr())
        self.nccl.ncclReduceScatter(
            send_buf,
            recv_buf,
            output_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype),
            ncclRedOpTypeEnum.from_torch(op),
            self.comm,
            cudaStream_t(active_stream.cuda_stream),
        )
271

272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
    def reduce_scatterv(
        self,
        output_tensor: torch.Tensor,
        input_tensor: torch.Tensor,
        sizes: list[int],
        op: ReduceOp = ReduceOp.SUM,
        stream=None,
    ):
        """Reduce-scatter with a different first-dimension size per rank.

        Issues one `ncclReduce` per entry of `sizes` inside a single NCCL
        group. Entry `i` reduces the next `sizes[i]` rows of `input_tensor`
        with rank `i` as the root; `output_tensor` is passed as the receive
        buffer of every call.

        Args:
            output_tensor: destination buffer for this rank's reduced chunk.
            input_tensor: source data; must be on `self.device`.
            sizes: number of rows reduced to each rank, indexed by rank.
            op: reduction operation, translated with
                `ncclRedOpTypeEnum.from_torch`.
            stream: stream handed to NCCL; defaults to `current_stream()`.

        Returns nothing; a disabled communicator makes this a no-op.
        """
        if self.disabled:
            return
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert input_tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {input_tensor.device}"
        )
        if stream is None:
            stream = current_stream()

        # Convert the loop-invariant arguments before the group is opened, so
        # that e.g. a failing dtype/op conversion cannot leave a group open.
        recv_buf = buffer_type(output_tensor.data_ptr())
        nccl_dtype = ncclDataTypeEnum.from_torch(input_tensor.dtype)
        nccl_op = ncclRedOpTypeEnum.from_torch(op)
        nccl_stream = cudaStream_t(stream.cuda_stream)

        split_offset = 0
        self.nccl.ncclGroupStart()
        try:
            for root, split_size in enumerate(sizes):
                chunk = input_tensor[split_offset : split_offset + split_size, ...]
                self.nccl.ncclReduce(
                    buffer_type(chunk.data_ptr()),
                    recv_buf,
                    chunk.numel(),
                    nccl_dtype,
                    nccl_op,
                    root,
                    self.comm,
                    nccl_stream,
                )
                split_offset += split_size
        finally:
            # Always close the group, even when a reduce call raised.
            self.nccl.ncclGroupEnd()

309
    def send(self, tensor: torch.Tensor, dst: int, stream=None):
        """Send `tensor` to rank `dst` with `ncclSend`.

        Args:
            tensor: data to send; must be on `self.device`.
            dst: destination rank passed to NCCL.
            stream: stream handed to NCCL; defaults to `current_stream()`.

        Returns nothing; a disabled communicator makes this a no-op.
        """
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}"
        )
        active_stream = current_stream() if stream is None else stream

        # fp8 tensors are described to NCCL as uint8, with the same element
        # count (presumably because the NCCL dtype mapping has no fp8 entry
        # -- confirm against `ncclDataTypeEnum`).
        fp8_dtypes = (
            torch.float8_e5m2,
            torch.float8_e4m3fn,
            torch.float8_e4m3fnuz,
            torch.float8_e5m2fnuz,
        )
        wire_dtype = torch.uint8 if tensor.dtype in fp8_dtypes else tensor.dtype
        nccl_dtype = ncclDataTypeEnum.from_torch(wire_dtype)

        self.nccl.ncclSend(
            buffer_type(tensor.data_ptr()),
            tensor.numel(),
            nccl_dtype,
            dst,
            self.comm,
            cudaStream_t(active_stream.cuda_stream),
        )
335

336
    def recv(self, tensor: torch.Tensor, src: int, stream=None):
        """Receive into `tensor` from rank `src` with `ncclRecv`.

        Args:
            tensor: destination buffer; must be on `self.device`.
            src: source rank passed to NCCL.
            stream: stream handed to NCCL; defaults to `current_stream()`.

        Returns nothing; a disabled communicator makes this a no-op.
        """
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}"
        )
        active_stream = current_stream() if stream is None else stream

        # fp8 tensors are described to NCCL as uint8, with the same element
        # count (presumably because the NCCL dtype mapping has no fp8 entry
        # -- confirm against `ncclDataTypeEnum`).
        fp8_dtypes = (
            torch.float8_e5m2,
            torch.float8_e4m3fn,
            torch.float8_e4m3fnuz,
            torch.float8_e5m2fnuz,
        )
        wire_dtype = torch.uint8 if tensor.dtype in fp8_dtypes else tensor.dtype
        nccl_dtype = ncclDataTypeEnum.from_torch(wire_dtype)

        self.nccl.ncclRecv(
            buffer_type(tensor.data_ptr()),
            tensor.numel(),
            nccl_dtype,
            src,
            self.comm,
            cudaStream_t(active_stream.cuda_stream),
        )
362

363
364
365
366
367
    def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
        """Broadcast `tensor` from rank `src` to every rank, in place.

        Args:
            tensor: data on the source rank, destination buffer elsewhere;
                must be on `self.device`.
            src: root rank passed to `ncclBroadcast`.
            stream: stream handed to NCCL; defaults to `current_stream()`.

        Returns nothing; a disabled communicator makes this a no-op.
        """
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}"
        )
        active_stream = current_stream() if stream is None else stream

        # Every rank receives into `tensor`; NCCL requires the sender also to
        # have a receive buffer. Only the source rank supplies a real send
        # buffer, the other ranks pass an empty `buffer_type()`.
        recvbuff = buffer_type(tensor.data_ptr())
        if src == self.rank:
            sendbuff = buffer_type(tensor.data_ptr())
        else:
            sendbuff = buffer_type()

        self.nccl.ncclBroadcast(
            sendbuff,
            recvbuff,
            tensor.numel(),
            ncclDataTypeEnum.from_torch(tensor.dtype),
            src,
            self.comm,
            cudaStream_t(active_stream.cuda_stream),
        )
388
389
390
391
392
393

    def group_start(self):
        """Open an NCCL group by calling `ncclGroupStart`."""
        nccl_lib = self.nccl
        nccl_lib.ncclGroupStart()

    def group_end(self):
        """Close the current NCCL group by calling `ncclGroupEnd`."""
        nccl_lib = self.nccl
        nccl_lib.ncclGroupEnd()
394
395
396
397
398
399
400
401
402
403

    def register_comm_window(self, tensor: torch.Tensor):
        """Register the memory of `tensor` as a communicator window.

        The window covers `tensor.numel() * tensor.element_size()` bytes
        starting at `tensor.data_ptr()`.

        Returns:
            Whatever `ncclCommWindowRegister` returns.
        """
        base_ptr = buffer_type(tensor.data_ptr())
        size_in_bytes = tensor.numel() * tensor.element_size()
        # NOTE(review): the trailing `1` is presumably the window-flags
        # argument -- confirm against the NCCL wrapper.
        return self.nccl.ncclCommWindowRegister(self.comm, base_ptr, size_in_bytes, 1)

    def register_comm_window_raw(self, ptr: int, size: int):
        """Register a communicator window from a raw pointer and size.

        Args:
            ptr: address wrapped in `buffer_type` before being passed on.
            size: size passed through unchanged to `ncclCommWindowRegister`.

        Returns:
            Whatever `ncclCommWindowRegister` returns.
        """
        base_ptr = buffer_type(ptr)
        # NOTE(review): the trailing `1` is presumably the window-flags
        # argument -- confirm against the NCCL wrapper.
        return self.nccl.ncclCommWindowRegister(self.comm, base_ptr, size, 1)
405
406
407

    def deregister_comm_window(self, window):
        """Deregister `window` from this communicator.

        Returns:
            Whatever `ncclCommWindowDeregister` returns.
        """
        result = self.nccl.ncclCommWindowDeregister(self.comm, window)
        return result
408
409
410
411
412
413
414
415
416
417
418
419
420
421

    def batch_isend_irecv(self, p2p_ops: list, stream=None):
        """Issue a batch of point-to-point operations inside one NCCL group.

        Args:
            p2p_ops: objects exposing `op`, `tensor` and `group_peer`.
                Entries whose `op` is `torch.distributed.isend` are sent to
                `group_peer`; entries whose `op` is
                `torch.distributed.irecv` are received from `group_peer`.
                Entries with any other `op` are skipped.
            stream: stream used for every operation; defaults to
                `current_stream()`.

        Returns nothing; a disabled communicator makes this a no-op.
        """
        if self.disabled:
            return
        if stream is None:
            stream = current_stream()
        self.group_start()
        try:
            for op in p2p_ops:
                if op.op is torch.distributed.isend:
                    self.send(op.tensor, op.group_peer, stream)
                elif op.op is torch.distributed.irecv:
                    self.recv(op.tensor, op.group_peer, stream)
        finally:
            # Always close the group: if send/recv raised (for example the
            # device assertion), an open group would otherwise be left behind.
            self.group_end()