# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# ===================== import region =====================
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp

import vllm.envs as envs
from vllm.distributed.device_communicators.pynccl_wrapper import (
    NCCLLibrary,
    buffer_type,
    cudaStream_t,
    ncclComm_t,
    ncclDataTypeEnum,
    ncclRedOpTypeEnum,
    ncclUniqueId,
)
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.utils.torch_utils import current_stream

# Process-wide guard: becomes True once register_nccl_symmetric_ops has
# registered its custom op, so later calls skip re-registration.
_NCCL_SYMM_OPS_REGISTERED: bool = False

# Module logger (vLLM's logger factory keyed by this module's name).
logger = init_logger(__name__)


def register_nccl_symmetric_ops(pynccl_comm):
    """Register the ``all_reduce_symmetric_with_copy`` custom op, at most once.

    The registered op copies its input into two buffers allocated under
    ``nccl_symm_mem_context(pynccl_comm)`` and all-reduces them with
    ``pynccl_comm.all_reduce``. Registration is process-wide: the op stays
    bound to the communicator passed on the first *successful* call, and
    every later call (with any communicator) returns immediately.

    Args:
        pynccl_comm: the ``PyNcclCommunicator`` the registered op will use.
    """
    global _NCCL_SYMM_OPS_REGISTERED
    # Check the guard first so repeated calls skip the imports entirely.
    if _NCCL_SYMM_OPS_REGISTERED:
        return

    from vllm.distributed.device_communicators.pynccl_allocator import (
        nccl_symm_mem_context,
    )
    from vllm.utils.torch_utils import direct_register_custom_op

    def all_reduce_symmetric_with_copy_impl(input_tensor: torch.Tensor) -> torch.Tensor:
        # Only the allocations happen inside the context; the copy and the
        # collective itself run after leaving it.
        with nccl_symm_mem_context(pynccl_comm):
            symm_input = torch.empty_like(input_tensor)
            symm_output = torch.empty_like(input_tensor)
        symm_input.copy_(input_tensor)
        symm_output = pynccl_comm.all_reduce(symm_input, symm_output)
        return symm_output

    def all_reduce_symmetric_with_copy_fake(input_tensor: torch.Tensor) -> torch.Tensor:
        # Metadata-only counterpart (same shape/dtype/device as the input);
        # presumably used by the custom-op machinery for tracing.
        return torch.empty_like(input_tensor)

    direct_register_custom_op(
        op_name="all_reduce_symmetric_with_copy",
        op_func=all_reduce_symmetric_with_copy_impl,
        fake_impl=all_reduce_symmetric_with_copy_fake,
    )
    # Mark as registered only after registration succeeded, so a failed
    # attempt raises again on retry instead of being silently skipped.
    _NCCL_SYMM_OPS_REGISTERED = True


class PyNcclCommunicator:
    """Issue NCCL collectives directly through ``NCCLLibrary``.

    The wrapped ``group`` (a torch ``ProcessGroup`` whose backend is not
    NCCL, or a ``StatelessProcessGroup``) supplies this process's rank and
    the world size, and is used once to broadcast the NCCL unique id created
    by rank 0. After that every operation goes straight to the NCCL library
    on a CUDA stream: the caller-supplied one, else ``current_stream()``.

    If ``world_size == 1``, ``envs.VLLM_DISABLE_PYNCCL`` is set, or the NCCL
    library fails to load, the instance is marked ``disabled``. The
    communication methods (``all_reduce``, ``all_gather``, ``all_gatherv``,
    ``reduce_scatter``, ``reduce_scatterv``, ``send``, ``recv``,
    ``broadcast``, ``batch_isend_irecv``) then return immediately.
    """

    # float8 dtypes are moved as raw bytes (uint8) by operations that only
    # copy data (no reduction arithmetic). Every float8 variant is one byte
    # wide, so ``numel()`` is identical under either interpretation.
    _FP8_DTYPES = (
        torch.float8_e5m2,
        torch.float8_e4m3fn,
        torch.float8_e4m3fnuz,
        torch.float8_e5m2fnuz,
    )

    def __init__(
        self,
        group: ProcessGroup | StatelessProcessGroup,
        device: int | str | torch.device,
        library_path: str | None = None,
    ):
        """
        Args:
            group: the process group used to learn this process's rank and
                the world size and to broadcast the NCCL unique id. A torch
                ``ProcessGroup`` must not use the NCCL backend.
            device: the device to bind the PyNcclCommunicator to. An int
                ``i`` means ``cuda:i``; a str is parsed with ``torch.device``.
            library_path: the path to the NCCL library. If None, it will
                use the default library path.

        It is the caller's responsibility to make sure each communicator
        is bound to a unique device.
        """
        if not isinstance(group, StatelessProcessGroup):
            assert dist.is_initialized()
            assert dist.get_backend(group) != dist.Backend.NCCL, (
                "PyNcclCommunicator should be attached to a non-NCCL group."
            )
            # note: this rank is the rank in the group
            self.rank = dist.get_rank(group)
            self.world_size = dist.get_world_size(group)
        else:
            self.rank = group.rank
            self.world_size = group.world_size

        self.group = group

        # if world_size == 1, no need to create communicator
        if self.world_size == 1 or envs.VLLM_DISABLE_PYNCCL:
            self.available = False
            self.disabled = True
            return
        try:
            self.nccl = NCCLLibrary(library_path)
        except Exception:
            # disable because of missing NCCL library
            # e.g. in a non-GPU environment; keep the reason for debugging.
            logger.debug(
                "Failed to load the NCCL library; disabling PyNcclCommunicator.",
                exc_info=True,
            )
            self.available = False
            self.disabled = True
            return

        self.available = True
        self.disabled = False

        self.nccl_version = self.nccl.ncclGetRawVersion()
        if self.rank == 0:
            # get the unique id from NCCL
            self.unique_id = self.nccl.ncclGetUniqueId()
            logger.info_once(
                "vLLM is using nccl==%s", self.nccl.ncclGetVersion(), scope="local"
            )
        else:
            # construct an empty unique id
            self.unique_id = ncclUniqueId()

        # Share rank 0's unique id with every rank.
        if not isinstance(group, StatelessProcessGroup):
            tensor = torch.ByteTensor(list(self.unique_id.internal))
            ranks = dist.get_process_group_ranks(group)
            # arg `src` in `broadcast` is the global rank
            dist.broadcast(tensor, src=ranks[0], group=group)
            byte_list = tensor.tolist()
            for i, byte in enumerate(byte_list):
                self.unique_id.internal[i] = byte
        else:
            self.unique_id = group.broadcast_obj(self.unique_id, src=0)

        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device
        # nccl communicator and stream will use this device
        with torch.accelerator.device_index(device.index):
            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                self.world_size, self.unique_id, self.rank
            )

            stream = current_stream()
            # A small all_reduce for warmup.
            data = torch.zeros(1, device=device)
            self.all_reduce(data)
            stream.synchronize()
            del data

    def _assert_on_device(self, tensor: torch.Tensor) -> None:
        """Assert that ``tensor`` lives on the communicator's device."""
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}"
        )

    @classmethod
    def _data_movement_dtype(cls, dtype: torch.dtype):
        """Return the NCCL dtype for an operation that only copies data.

        float8 dtypes (``_FP8_DTYPES``) are mapped to uint8; every other
        dtype goes through ``ncclDataTypeEnum.from_torch``. Not for
        reductions, where the element type determines the arithmetic.
        """
        if dtype in cls._FP8_DTYPES:
            return ncclDataTypeEnum.from_torch(torch.uint8)
        return ncclDataTypeEnum.from_torch(dtype)

    def destroy(self):
        """Destroy the NCCL communicator; a no-op if it is not active.

        Marks the instance unavailable/disabled afterwards, so repeated
        calls are harmless.
        """
        if self.available and not self.disabled:
            with torch.accelerator.device_index(self.device.index):
                self.nccl.ncclCommDestroy(self.comm)
            self.available = False
            self.disabled = True

    def all_reduce(
        self,
        in_tensor: torch.Tensor,
        out_tensor: torch.Tensor | None = None,
        op: ReduceOp = ReduceOp.SUM,
        stream=None,
    ) -> torch.Tensor | None:
        """Enqueue an all-reduce of ``in_tensor`` on ``stream``.

        Returns ``out_tensor`` (freshly allocated with ``empty_like`` when
        not given), or None when the communicator is disabled.
        """
        if self.disabled:
            return None
        self._assert_on_device(in_tensor)

        if out_tensor is None:
            out_tensor = torch.empty_like(in_tensor)

        if stream is None:
            stream = current_stream()
        self.nccl.ncclAllReduce(
            buffer_type(in_tensor.data_ptr()),
            buffer_type(out_tensor.data_ptr()),
            in_tensor.numel(),
            ncclDataTypeEnum.from_torch(in_tensor.dtype),
            ncclRedOpTypeEnum.from_torch(op),
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )
        return out_tensor

    def all_gather(
        self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None
    ):
        """Enqueue an all-gather of equally sized ``input_tensor`` shards."""
        if self.disabled:
            return
        self._assert_on_device(input_tensor)
        if stream is None:
            stream = current_stream()
        self.nccl.ncclAllGather(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()),
            input_tensor.numel(),
            self._data_movement_dtype(input_tensor.dtype),
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    def all_gatherv(
        self,
        output_tensor: torch.Tensor,
        input_tensor: torch.Tensor,
        sizes: list[int],
        stream=None,
    ):
        """All-gather with per-rank sizes along the first dimension.

        Rank ``r`` contributes ``sizes[r]`` rows, written to the matching
        slice of ``output_tensor`` (``output_tensor.shape[0]`` must equal
        ``sum(sizes)``). Implemented as one broadcast per rank inside a
        single NCCL group.
        """
        if self.disabled:
            return
        self._assert_on_device(input_tensor)
        if stream is None:
            stream = current_stream()
        assert output_tensor.shape[0] == sum(sizes)
        # Convert before opening the group so a conversion error cannot
        # leave the NCCL group unterminated.
        nccl_dtype = self._data_movement_dtype(input_tensor.dtype)
        split_offset = 0
        self.nccl.ncclGroupStart()
        try:
            for root, split_size in enumerate(sizes):
                dst_slice = output_tensor[split_offset : split_offset + split_size]
                self.nccl.ncclBroadcast(
                    buffer_type(input_tensor.data_ptr()),
                    buffer_type(dst_slice.data_ptr()),
                    dst_slice.numel(),
                    nccl_dtype,
                    root,
                    self.comm,
                    cudaStream_t(stream.cuda_stream),
                )
                split_offset += split_size
        finally:
            self.nccl.ncclGroupEnd()

    def reduce_scatter(
        self,
        output_tensor: torch.Tensor,
        input_tensor: torch.Tensor,
        op: ReduceOp = ReduceOp.SUM,
        stream=None,
    ):
        """Enqueue a reduce-scatter; each rank receives ``output_tensor.numel()``
        reduced elements."""
        if self.disabled:
            return
        self._assert_on_device(input_tensor)
        if stream is None:
            stream = current_stream()
        self.nccl.ncclReduceScatter(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()),
            output_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype),
            ncclRedOpTypeEnum.from_torch(op),
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    def reduce_scatterv(
        self,
        output_tensor: torch.Tensor,
        input_tensor: torch.Tensor,
        sizes: list[int],
        op: ReduceOp = ReduceOp.SUM,
        stream=None,
    ):
        """Reduce-scatter with per-rank sizes along the first dimension.

        Rows ``[offset_r, offset_r + sizes[r])`` of ``input_tensor`` are
        reduced onto rank ``r``'s ``output_tensor`` (one ``ncclReduce`` per
        rank inside a single NCCL group). Presumably ``output_tensor`` on
        rank ``r`` holds ``sizes[r]`` rows — NOTE(review): not checked here.
        """
        if self.disabled:
            return
        self._assert_on_device(input_tensor)
        if stream is None:
            stream = current_stream()

        # Convert before opening the group so a conversion error cannot
        # leave the NCCL group unterminated.
        nccl_dtype = ncclDataTypeEnum.from_torch(input_tensor.dtype)
        nccl_op = ncclRedOpTypeEnum.from_torch(op)
        split_offset = 0
        self.nccl.ncclGroupStart()
        try:
            for root, split_size in enumerate(sizes):
                chunk = input_tensor[split_offset : split_offset + split_size, ...]
                self.nccl.ncclReduce(
                    buffer_type(chunk.data_ptr()),
                    buffer_type(output_tensor.data_ptr()),
                    chunk.numel(),
                    nccl_dtype,
                    nccl_op,
                    root,
                    self.comm,
                    cudaStream_t(stream.cuda_stream),
                )
                split_offset += split_size
        finally:
            self.nccl.ncclGroupEnd()

    def send(self, tensor: torch.Tensor, dst: int, stream=None):
        """Enqueue a point-to-point send of ``tensor`` to rank ``dst``."""
        if self.disabled:
            return
        self._assert_on_device(tensor)
        if stream is None:
            stream = current_stream()
        self.nccl.ncclSend(
            buffer_type(tensor.data_ptr()),
            tensor.numel(),
            self._data_movement_dtype(tensor.dtype),
            dst,
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    def recv(self, tensor: torch.Tensor, src: int, stream=None):
        """Enqueue a point-to-point receive into ``tensor`` from rank ``src``."""
        if self.disabled:
            return
        self._assert_on_device(tensor)
        if stream is None:
            stream = current_stream()
        self.nccl.ncclRecv(
            buffer_type(tensor.data_ptr()),
            tensor.numel(),
            self._data_movement_dtype(tensor.dtype),
            src,
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
        """Enqueue an in-place broadcast of ``tensor`` from rank ``src``."""
        if self.disabled:
            return
        self._assert_on_device(tensor)
        if stream is None:
            stream = current_stream()
        # Every rank receives into `tensor`; NCCL requires the sender also to
        # have a receive buffer. Only the root supplies a send buffer.
        recvbuff = buffer_type(tensor.data_ptr())
        if src == self.rank:
            sendbuff = buffer_type(tensor.data_ptr())
        else:
            sendbuff = buffer_type()
        self.nccl.ncclBroadcast(
            sendbuff,
            recvbuff,
            tensor.numel(),
            self._data_movement_dtype(tensor.dtype),
            src,
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    def group_start(self):
        """Open an NCCL group (``ncclGroupStart``)."""
        self.nccl.ncclGroupStart()

    def group_end(self):
        """Close an NCCL group (``ncclGroupEnd``)."""
        self.nccl.ncclGroupEnd()

    def register_comm_window(self, tensor: torch.Tensor):
        """Register ``tensor``'s storage as an NCCL communication window."""
        return self.nccl.ncclCommWindowRegister(
            self.comm,
            buffer_type(tensor.data_ptr()),
            tensor.numel() * tensor.element_size(),
            1,
        )

    def register_comm_window_raw(self, ptr: int, size: int):
        """Register ``size`` bytes at device address ``ptr`` as a window."""
        return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr), size, 1)

    def deregister_comm_window(self, window):
        """Deregister a window returned by one of the register methods."""
        return self.nccl.ncclCommWindowDeregister(self.comm, window)

    def batch_isend_irecv(self, p2p_ops: list, stream=None):
        """Issue a batch of sends/receives inside one NCCL group.

        Each element needs ``.op`` (``torch.distributed.isend`` or
        ``torch.distributed.irecv``), ``.tensor`` and ``.group_peer`` —
        presumably ``torch.distributed.P2POp`` objects. Elements with any
        other ``.op`` are skipped.
        """
        if self.disabled:
            return
        if stream is None:
            stream = current_stream()
        self.group_start()
        # Always close the group, even if a send/recv raises (e.g. device
        # assertion), so NCCL is not left inside an open group.
        try:
            for op in p2p_ops:
                if op.op is torch.distributed.isend:
                    self.send(op.tensor, op.group_peer, stream)
                elif op.op is torch.distributed.irecv:
                    self.recv(op.tensor, op.group_peer, stream)
        finally:
            self.group_end()