pynccl.py 13.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7

# ===================== import region =====================
import torch
import torch.distributed as dist
8
from torch.distributed import ProcessGroup, ReduceOp
9

10
import vllm.envs as envs
11
from vllm.distributed.device_communicators.pynccl_wrapper import (
12
13
14
15
16
17
18
19
    NCCLLibrary,
    buffer_type,
    cudaStream_t,
    ncclComm_t,
    ncclDataTypeEnum,
    ncclRedOpTypeEnum,
    ncclUniqueId,
)
20
from vllm.distributed.utils import StatelessProcessGroup
21
from vllm.logger import init_logger
22
from vllm.utils.torch_utils import current_stream
23
24

# Module-level logger for this communicator module.
logger = init_logger(__name__)
# Process-wide guard so the custom op below is registered at most once.
_NCCL_SYMM_OPS_REGISTERED = False


def register_nccl_symmetric_ops(pynccl_comm):
    """Register the ``all_reduce_symmetric_with_copy`` custom op (once).

    The registered op copies its input into buffers allocated under
    ``nccl_symm_mem_context(pynccl_comm)`` and all-reduces between them via
    ``pynccl_comm.all_reduce``.

    Only the first successful call registers anything; later calls return
    immediately, so the ``pynccl_comm`` captured by the op is the one passed
    to the first successful call.

    Args:
        pynccl_comm: the ``PyNcclCommunicator`` the op will use.
    """
    global _NCCL_SYMM_OPS_REGISTERED
    # Early exit before the lazy imports below.
    if _NCCL_SYMM_OPS_REGISTERED:
        return

    from vllm.distributed.device_communicators.pynccl_allocator import (
        nccl_symm_mem_context,
    )
    from vllm.utils.torch_utils import direct_register_custom_op

    def all_reduce_symmetric_with_copy_impl(input_tensor: torch.Tensor) -> torch.Tensor:
        # Allocate both buffers inside the symmetric-memory context, then copy
        # the caller's data in and all-reduce from one buffer into the other.
        with nccl_symm_mem_context(pynccl_comm):
            symm_input = torch.empty_like(input_tensor)
            symm_output = torch.empty_like(input_tensor)
        symm_input.copy_(input_tensor)
        # NOTE(review): all_reduce returns None when the communicator is
        # disabled, so this op would then return None.
        symm_output = pynccl_comm.all_reduce(symm_input, symm_output)
        return symm_output

    def all_reduce_symmetric_with_copy_fake(input_tensor: torch.Tensor) -> torch.Tensor:
        # Passed as `fake_impl`: produces an output with the input's
        # shape/dtype/device without communicating.
        return torch.empty_like(input_tensor)

    direct_register_custom_op(
        op_name="all_reduce_symmetric_with_copy",
        op_func=all_reduce_symmetric_with_copy_impl,
        fake_impl=all_reduce_symmetric_with_copy_fake,
    )
    # Mark as registered only after registration succeeded, so a failed
    # attempt (e.g. an import error) can be retried later.
    _NCCL_SYMM_OPS_REGISTERED = True

57

58
class PyNcclCommunicator:
    """Issue NCCL collectives directly through ``NCCLLibrary``.

    The NCCL unique id is exchanged over ``group`` (a non-NCCL
    ``torch.distributed`` group or a ``StatelessProcessGroup``). After that,
    NCCL calls are enqueued directly on CUDA streams.

    The communicator is created disabled (``available`` and ``disabled`` set
    to False/True) in three cases:

    - ``world_size == 1``;
    - ``VLLM_DISABLE_PYNCCL`` is set;
    - the NCCL library cannot be loaded.

    In that state the collective methods below return without doing anything.
    """

    def __init__(
        self,
        group: ProcessGroup | StatelessProcessGroup,
        device: int | str | torch.device,
        library_path: str | None = None,
    ):
        """
        Args:
            group: the process group used to bootstrap the communicator.
                Either a ``torch.distributed`` ``ProcessGroup`` whose backend
                is not NCCL, or a ``StatelessProcessGroup``.
            device: the CUDA device to bind the PyNcclCommunicator to. An
                ``int`` is interpreted as ``cuda:{device}``; a ``str`` is
                parsed with ``torch.device``.
            library_path: the path to the NCCL library. If None, it will
                use the default library path.

        It is the caller's responsibility to make sure each communicator
        is bound to a unique device.
        """
        if not isinstance(group, StatelessProcessGroup):
            assert dist.is_initialized()
            assert dist.get_backend(group) != dist.Backend.NCCL, (
                "PyNcclCommunicator should be attached to a non-NCCL group."
            )
            # note: this rank is the rank in the group
            self.rank = dist.get_rank(group)
            self.world_size = dist.get_world_size(group)
        else:
            self.rank = group.rank
            self.world_size = group.world_size

        self.group = group

        # if world_size == 1, no need to create communicator
        if self.world_size == 1 or envs.VLLM_DISABLE_PYNCCL:
            self.available = False
            self.disabled = True
            return
        try:
            self.nccl = NCCLLibrary(library_path)
        except Exception:
            # Disable because of a missing/unloadable NCCL library, e.g. in a
            # non-GPU environment. This is deliberate best-effort behavior,
            # but record why so it is diagnosable.
            logger.debug(
                "Failed to load the NCCL library (library_path=%s); "
                "PyNcclCommunicator is disabled.",
                library_path,
                exc_info=True,
            )
            self.available = False
            self.disabled = True
            return

        self.available = True
        self.disabled = False

        self.nccl_version = self.nccl.ncclGetRawVersion()
        if self.rank == 0:
            # get the unique id from NCCL
            self.unique_id = self.nccl.ncclGetUniqueId()
            logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())
        else:
            # construct an empty unique id
            self.unique_id = ncclUniqueId()

        # Share rank 0's unique id with every rank of the group.
        if not isinstance(group, StatelessProcessGroup):
            tensor = torch.ByteTensor(list(self.unique_id.internal))
            ranks = dist.get_process_group_ranks(group)
            # arg `src` in `broadcast` is the global rank
            dist.broadcast(tensor, src=ranks[0], group=group)
            byte_list = tensor.tolist()
            for i, byte in enumerate(byte_list):
                self.unique_id.internal[i] = byte
        else:
            self.unique_id = group.broadcast_obj(self.unique_id, src=0)

        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device
        # nccl communicator and stream will use this device
        # `torch.cuda.device` is a context manager that changes the
        # current cuda device to the specified one
        with torch.cuda.device(device):
            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                self.world_size, self.unique_id, self.rank
            )

            stream = current_stream()
            # A small all_reduce for warmup.
            data = torch.zeros(1, device=device)
            self.all_reduce(data)
            stream.synchronize()
            del data

    def _check_device(self, tensor: torch.Tensor) -> None:
        """Assert ``tensor`` is on this communicator's device."""
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}"
        )

    @staticmethod
    def _resolve_stream(stream):
        """Return ``stream``, or the current stream when it is None."""
        return current_stream() if stream is None else stream

    def all_reduce(
        self,
        in_tensor: torch.Tensor,
        out_tensor: torch.Tensor | None = None,
        op: ReduceOp = ReduceOp.SUM,
        stream=None,
    ) -> torch.Tensor | None:
        """All-reduce ``in_tensor`` across the communicator into ``out_tensor``.

        Args:
            in_tensor: local contribution; must be on ``self.device``.
            out_tensor: destination buffer. If None, one is allocated with
                ``torch.empty_like(in_tensor)``.
            op: reduction operator.
            stream: stream to enqueue the NCCL call on (default: current).

        Returns:
            The output tensor, or None if the communicator is disabled.
        """
        if self.disabled:
            return None
        self._check_device(in_tensor)

        if out_tensor is None:
            out_tensor = torch.empty_like(in_tensor)

        stream = self._resolve_stream(stream)
        self.nccl.ncclAllReduce(
            buffer_type(in_tensor.data_ptr()),
            buffer_type(out_tensor.data_ptr()),
            in_tensor.numel(),
            ncclDataTypeEnum.from_torch(in_tensor.dtype),
            ncclRedOpTypeEnum.from_torch(op),
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )
        return out_tensor

    def all_gather(
        self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None
    ):
        """All-gather ``input_tensor`` from every rank into ``output_tensor``.

        The per-rank element count passed to NCCL is ``input_tensor.numel()``.
        ``output_tensor`` is presumably sized ``world_size`` times that
        (NCCL allgather contract); this is not checked here.
        """
        if self.disabled:
            return
        self._check_device(input_tensor)
        stream = self._resolve_stream(stream)
        self.nccl.ncclAllGather(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()),
            input_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype),
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    def all_gatherv(
        self,
        output_tensor: torch.Tensor,
        input_tensor: torch.Tensor,
        sizes: list[int],
        stream=None,
    ):
        """All-gather with per-rank sizes along dim 0.

        Implemented as one grouped batch of broadcasts. Rank ``root``
        broadcasts its ``input_tensor`` into the rows of ``output_tensor``
        reserved for it. Those rows start at ``sum(sizes[:root])`` and span
        ``sizes[root]`` rows.

        Args:
            output_tensor: destination; ``output_tensor.shape[0]`` must equal
                ``sum(sizes)``.
            input_tensor: this rank's contribution; must be on
                ``self.device``.
            sizes: number of dim-0 rows contributed by each rank.
            stream: stream to enqueue the NCCL calls on (default: current).
        """
        if self.disabled:
            return
        self._check_device(input_tensor)
        stream = self._resolve_stream(stream)
        assert output_tensor.shape[0] == sum(sizes)
        split_offset = 0
        self.nccl.ncclGroupStart()
        for root, split_size in enumerate(sizes):
            dst_slice = output_tensor[split_offset : split_offset + split_size]
            self.nccl.ncclBroadcast(
                buffer_type(input_tensor.data_ptr()),
                buffer_type(dst_slice.data_ptr()),
                dst_slice.numel(),
                ncclDataTypeEnum.from_torch(input_tensor.dtype),
                root,
                self.comm,
                cudaStream_t(stream.cuda_stream),
            )
            split_offset += split_size
        self.nccl.ncclGroupEnd()

    def reduce_scatter(
        self,
        output_tensor: torch.Tensor,
        input_tensor: torch.Tensor,
        op: ReduceOp = ReduceOp.SUM,
        stream=None,
    ):
        """Reduce ``input_tensor`` across ranks and scatter equal chunks.

        The per-rank element count passed to NCCL is
        ``output_tensor.numel()``.
        """
        if self.disabled:
            return
        self._check_device(input_tensor)
        stream = self._resolve_stream(stream)
        self.nccl.ncclReduceScatter(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()),
            output_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype),
            ncclRedOpTypeEnum.from_torch(op),
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    def reduce_scatterv(
        self,
        output_tensor: torch.Tensor,
        input_tensor: torch.Tensor,
        sizes: list[int],
        op: ReduceOp = ReduceOp.SUM,
        stream=None,
    ):
        """Reduce-scatter with per-rank sizes along dim 0.

        Implemented as one grouped batch of ``ncclReduce`` calls. The
        dim-0 chunk of ``input_tensor`` destined for rank ``root`` is reduced
        into ``output_tensor`` on that rank.

        Args:
            output_tensor: destination for this rank's reduced chunk.
            input_tensor: full input; must be on ``self.device`` and have
                ``input_tensor.shape[0] == sum(sizes)``.
            sizes: number of dim-0 rows assigned to each rank.
            op: reduction operator.
            stream: stream to enqueue the NCCL calls on (default: current).
        """
        if self.disabled:
            return
        self._check_device(input_tensor)
        stream = self._resolve_stream(stream)
        # Mirror all_gatherv: a mismatch would silently drop or misroute rows.
        assert input_tensor.shape[0] == sum(sizes)

        split_offset = 0
        self.nccl.ncclGroupStart()
        for root, split_size in enumerate(sizes):
            chunk = input_tensor[split_offset : split_offset + split_size, ...]
            self.nccl.ncclReduce(
                buffer_type(chunk.data_ptr()),
                buffer_type(output_tensor.data_ptr()),
                chunk.numel(),
                ncclDataTypeEnum.from_torch(input_tensor.dtype),
                ncclRedOpTypeEnum.from_torch(op),
                root,
                self.comm,
                cudaStream_t(stream.cuda_stream),
            )
            split_offset += split_size
        self.nccl.ncclGroupEnd()

    def send(self, tensor: torch.Tensor, dst: int, stream=None):
        """Send ``tensor`` to peer rank ``dst`` of this communicator."""
        if self.disabled:
            return
        self._check_device(tensor)
        stream = self._resolve_stream(stream)
        self.nccl.ncclSend(
            buffer_type(tensor.data_ptr()),
            tensor.numel(),
            ncclDataTypeEnum.from_torch(tensor.dtype),
            dst,
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    def recv(self, tensor: torch.Tensor, src: int, stream=None):
        """Receive into ``tensor`` from peer rank ``src`` of this communicator."""
        if self.disabled:
            return
        self._check_device(tensor)
        stream = self._resolve_stream(stream)
        self.nccl.ncclRecv(
            buffer_type(tensor.data_ptr()),
            tensor.numel(),
            ncclDataTypeEnum.from_torch(tensor.dtype),
            src,
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
        """Broadcast ``tensor`` in place from rank ``src`` to all ranks."""
        if self.disabled:
            return
        self._check_device(tensor)
        stream = self._resolve_stream(stream)
        if src == self.rank:
            sendbuff = buffer_type(tensor.data_ptr())
            # NCCL requires the sender also to have a receive buffer
            recvbuff = buffer_type(tensor.data_ptr())
        else:
            # Non-root ranks only receive; pass a null send buffer.
            sendbuff = buffer_type()
            recvbuff = buffer_type(tensor.data_ptr())
        self.nccl.ncclBroadcast(
            sendbuff,
            recvbuff,
            tensor.numel(),
            ncclDataTypeEnum.from_torch(tensor.dtype),
            src,
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    def group_start(self):
        """Begin an NCCL group (``ncclGroupStart``); no-op when disabled."""
        # When disabled, `self.nccl` may never have been created, and the
        # collectives issued inside the group are no-ops anyway.
        if self.disabled:
            return
        self.nccl.ncclGroupStart()

    def group_end(self):
        """End an NCCL group (``ncclGroupEnd``); no-op when disabled."""
        if self.disabled:
            return
        self.nccl.ncclGroupEnd()

    def register_comm_window(self, tensor: torch.Tensor):
        """Register ``tensor``'s storage as an NCCL communication window.

        The registered size is ``tensor.numel() * tensor.element_size()``
        bytes, starting at ``tensor.data_ptr()``. Returns the handle produced
        by ``ncclCommWindowRegister``.

        Requires the NCCL library to be loaded; this method does not check
        ``disabled``.
        """
        # NOTE(review): the trailing `1` is the window-flags argument,
        # presumably NCCL_WIN_COLL_SYMMETRIC — confirm in pynccl_wrapper.
        return self.nccl.ncclCommWindowRegister(
            self.comm,
            buffer_type(tensor.data_ptr()),
            tensor.numel() * tensor.element_size(),
            1,
        )

    def register_comm_window_raw(self, ptr: int, size: int):
        """Register ``size`` bytes at device address ``ptr`` as a window."""
        return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr), size, 1)

    def deregister_comm_window(self, window):
        """Deregister a window previously returned by a register call."""
        return self.nccl.ncclCommWindowDeregister(self.comm, window)