# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


# ===================== import region =====================
import torch
import torch.distributed as dist
8
from torch.distributed import ProcessGroup, ReduceOp
9

10
import vllm.envs as envs
11
from vllm.distributed.device_communicators.pynccl_wrapper import (
12
13
14
15
16
17
18
19
    NCCLLibrary,
    buffer_type,
    cudaStream_t,
    ncclComm_t,
    ncclDataTypeEnum,
    ncclRedOpTypeEnum,
    ncclUniqueId,
)
20
from vllm.distributed.utils import StatelessProcessGroup
21
from vllm.logger import init_logger
youkaichao's avatar
youkaichao committed
22
from vllm.utils import current_stream
23
24

# Module logger, named after this module.
logger = init_logger(__name__)

# Process-wide guard flipped to True by `register_nccl_symmetric_ops`, so the
# symmetric-memory custom op is registered at most once per process.
_NCCL_SYMM_OPS_REGISTERED: bool = False


def register_nccl_symmetric_ops(pynccl_comm):
    """Register the ``all_reduce_symmetric_with_copy`` custom op.

    The op copies its input into buffers allocated inside
    ``nccl_symm_mem_context(pynccl_comm)`` and all-reduces them with
    ``pynccl_comm.all_reduce``.

    Registration is process-global. Only the first *successful* call
    registers the op, and the ``pynccl_comm`` captured by that call is the
    communicator every later invocation of the op uses. Calls made after a
    successful registration are no-ops. If registration raises, the guard
    flag is left unset, so a later call can retry.

    Args:
        pynccl_comm: the ``PyNcclCommunicator`` whose ``all_reduce`` backs
            the op.
    """
    global _NCCL_SYMM_OPS_REGISTERED
    if _NCCL_SYMM_OPS_REGISTERED:
        return

    # Lazy imports: only needed when the op is actually being registered.
    from vllm.distributed.device_communicators.pynccl_allocator import (
        nccl_symm_mem_context,
    )
    from vllm.utils import direct_register_custom_op

    def all_reduce_symmetric_with_copy_impl(input_tensor: torch.Tensor) -> torch.Tensor:
        # Allocate both buffers inside the symmetric-memory context. The copy
        # of the caller's data and the reduction happen after leaving it.
        with nccl_symm_mem_context(pynccl_comm):
            symm_input = torch.empty_like(input_tensor)
            symm_output = torch.empty_like(input_tensor)
        symm_input.copy_(input_tensor)
        # NOTE(review): `all_reduce` returns None when the communicator is
        # disabled; callers presumably only use this op with an enabled comm.
        symm_output = pynccl_comm.all_reduce(symm_input, symm_output)
        return symm_output

    def all_reduce_symmetric_with_copy_fake(input_tensor: torch.Tensor) -> torch.Tensor:
        # Shape/dtype-only stand-in used by the custom-op machinery.
        return torch.empty_like(input_tensor)

    direct_register_custom_op(
        op_name="all_reduce_symmetric_with_copy",
        op_func=all_reduce_symmetric_with_copy_impl,
        fake_impl=all_reduce_symmetric_with_copy_fake,
    )
    # Only mark as registered once registration has succeeded. Setting the
    # flag earlier would make a failed registration permanently skipped.
    _NCCL_SYMM_OPS_REGISTERED = True

57

58
class PyNcclCommunicator:
    """Python handle to one rank's NCCL communicator, bound to one device.

    The communicator is created through ``NCCLLibrary``. Every collective and
    point-to-point method enqueues the matching NCCL call on a CUDA stream,
    which is ``current_stream()`` unless one is passed in.

    The instance is created with ``available = False`` and
    ``disabled = True``, and without an NCCL communicator, in these cases:

    * the group has a single rank,
    * ``VLLM_DISABLE_PYNCCL`` is set, or
    * the NCCL library cannot be loaded.

    While ``disabled`` is True, the communication methods return immediately
    without doing any work. ``all_reduce`` returns None in that case.
    """

    def __init__(
        self,
        group: ProcessGroup | StatelessProcessGroup,
        device: int | str | torch.device,
        library_path: str | None = None,
    ):
        """
        Create this rank's NCCL communicator and run a warmup all-reduce.

        Args:
            group: the process group to work on. It must be either a
                ``StatelessProcessGroup`` or an initialized
                ``torch.distributed`` group whose backend is not NCCL. The
                group provides this rank and the world size. It is also used
                to broadcast the NCCL unique id from the group's first rank.
            device: the device to bind the PyNcclCommunicator to. An ``int``
                becomes ``torch.device(f"cuda:{device}")``. A ``str`` is
                parsed with ``torch.device``.
            library_path: the path to the NCCL library, forwarded to
                ``NCCLLibrary``. If None, it will use the default library
                path.

        It is the caller's responsibility to make sure each communicator
        is bound to a unique device.

        Note: construction stops early, leaving ``disabled = True``, when the
        world size is 1, ``VLLM_DISABLE_PYNCCL`` is set, or the NCCL library
        fails to load. In those cases the ``nccl``, ``comm`` and ``device``
        attributes are never set.
        """
        if not isinstance(group, StatelessProcessGroup):
            assert dist.is_initialized()
            assert dist.get_backend(group) != dist.Backend.NCCL, (
                "PyNcclCommunicator should be attached to a non-NCCL group."
            )
            # note: this rank is the rank in the group
            self.rank = dist.get_rank(group)
            self.world_size = dist.get_world_size(group)
        else:
            self.rank = group.rank
            self.world_size = group.world_size

        self.group = group

        # if world_size == 1, no need to create communicator
        if self.world_size == 1 or envs.VLLM_DISABLE_PYNCCL:
            self.available = False
            self.disabled = True
            return
        try:
            self.nccl = NCCLLibrary(library_path)
        except Exception:
            # Disable because the NCCL library is missing, e.g. in a non-GPU
            # environment. Stay best-effort, but leave a trace of why.
            logger.debug(
                "Failed to load the NCCL library (library_path=%s); "
                "PyNcclCommunicator is disabled.",
                library_path,
                exc_info=True,
            )
            self.available = False
            self.disabled = True
            return

        self.available = True
        self.disabled = False

        self.nccl_version = self.nccl.ncclGetRawVersion()
        logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())

        if self.rank == 0:
            # get the unique id from NCCL
            self.unique_id = self.nccl.ncclGetUniqueId()
        else:
            # construct an empty unique id
            self.unique_id = ncclUniqueId()

        if not isinstance(group, StatelessProcessGroup):
            tensor = torch.ByteTensor(list(self.unique_id.internal))
            ranks = dist.get_process_group_ranks(group)
            # arg `src` in `broadcast` is the global rank
            dist.broadcast(tensor, src=ranks[0], group=group)
            byte_list = tensor.tolist()
            for i, byte in enumerate(byte_list):
                self.unique_id.internal[i] = byte
        else:
            self.unique_id = group.broadcast_obj(self.unique_id, src=0)

        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device
        # nccl communicator and stream will use this device.
        # `torch.cuda.device` is a context manager that changes the
        # current cuda device to the specified one.
        with torch.cuda.device(device):
            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                self.world_size, self.unique_id, self.rank
            )

            stream = current_stream()
            # A small all_reduce for warmup.
            data = torch.zeros(1, device=device)
            self.all_reduce(data)
            stream.synchronize()
            del data

    def _check_device(self, tensor: torch.Tensor, role: str = "input") -> None:
        """Assert that ``tensor`` lives on ``self.device``.

        An nccl communicator created on a specific device only works on
        tensors on that same device. Otherwise it causes an
        "illegal memory access". ``role`` names the tensor in the message.
        """
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the {role} tensor is on {tensor.device}"
        )

    @staticmethod
    def _resolve_stream(stream):
        """Return ``stream``, or ``current_stream()`` when it is None."""
        return current_stream() if stream is None else stream

    def all_reduce(
        self,
        in_tensor: torch.Tensor,
        out_tensor: torch.Tensor | None = None,
        op: ReduceOp = ReduceOp.SUM,
        stream=None,
    ) -> torch.Tensor | None:
        """All-reduce ``in_tensor`` across the group with reduction ``op``.

        Args:
            in_tensor: tensor to reduce. Must live on ``self.device``.
            out_tensor: destination buffer, also on ``self.device``. When
                None, ``torch.empty_like(in_tensor)`` is allocated.
            op: the ``torch.distributed.ReduceOp`` to apply.
            stream: CUDA stream to enqueue on. Defaults to
                ``current_stream()``.

        Returns:
            ``out_tensor``, or None if the communicator is disabled. The
            reduction is only enqueued on ``stream``, so the result is ready
            once that stream has executed it.
        """
        if self.disabled:
            return None
        self._check_device(in_tensor)
        if out_tensor is None:
            out_tensor = torch.empty_like(in_tensor)
        else:
            self._check_device(out_tensor, "output")
        stream = self._resolve_stream(stream)
        self.nccl.ncclAllReduce(
            buffer_type(in_tensor.data_ptr()),
            buffer_type(out_tensor.data_ptr()),
            in_tensor.numel(),
            ncclDataTypeEnum.from_torch(in_tensor.dtype),
            ncclRedOpTypeEnum.from_torch(op),
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )
        return out_tensor

    def all_gather(
        self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None
    ):
        """Gather ``input_tensor`` from every rank into ``output_tensor``.

        Each rank contributes ``input_tensor.numel()`` elements. The caller
        must size ``output_tensor`` for ``world_size`` such contributions.
        Both tensors must live on ``self.device``.
        """
        if self.disabled:
            return
        self._check_device(input_tensor)
        self._check_device(output_tensor, "output")
        stream = self._resolve_stream(stream)
        self.nccl.ncclAllGather(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()),
            input_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype),
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    def all_gatherv(
        self,
        output_tensor: torch.Tensor,
        input_tensor: torch.Tensor,
        sizes: list[int],
        stream=None,
    ):
        """Variable-size all-gather along dim 0.

        Rank ``r`` contributes ``sizes[r]`` leading-dim rows. The operation is
        issued as one ``ncclBroadcast`` per rank inside an NCCL group.
        Rank ``r``'s ``input_tensor`` is broadcast into
        ``output_tensor[offset_r : offset_r + sizes[r]]``, where ``offset_r``
        is the sum of the preceding sizes.

        Args:
            output_tensor: destination, with ``shape[0] == sum(sizes)``.
            input_tensor: this rank's contribution.
            sizes: one leading-dim size per rank (``len == world_size``).
            stream: CUDA stream to enqueue on. Defaults to
                ``current_stream()``.
        """
        if self.disabled:
            return
        self._check_device(input_tensor)
        self._check_device(output_tensor, "output")
        stream = self._resolve_stream(stream)
        # The enumerate index below is used as the NCCL root rank, so there
        # must be exactly one size per rank.
        assert len(sizes) == self.world_size, (
            f"expected one size per rank ({self.world_size}), got {len(sizes)}"
        )
        assert output_tensor.shape[0] == sum(sizes)
        split_offset = 0
        self.nccl.ncclGroupStart()
        for root, split_size in enumerate(sizes):
            dst_slice = output_tensor[split_offset : split_offset + split_size]
            # Every rank passes its own input as the send buffer; it is only
            # read on `root`.
            self.nccl.ncclBroadcast(
                buffer_type(input_tensor.data_ptr()),
                buffer_type(dst_slice.data_ptr()),
                dst_slice.numel(),
                ncclDataTypeEnum.from_torch(input_tensor.dtype),
                root,
                self.comm,
                cudaStream_t(stream.cuda_stream),
            )
            split_offset += split_size
        self.nccl.ncclGroupEnd()

    def reduce_scatter(
        self,
        output_tensor: torch.Tensor,
        input_tensor: torch.Tensor,
        op: ReduceOp = ReduceOp.SUM,
        stream=None,
    ):
        """Reduce ``input_tensor`` with ``op`` and scatter equal chunks.

        Each rank receives ``output_tensor.numel()`` reduced elements in
        ``output_tensor``. Both tensors must live on ``self.device``.
        """
        if self.disabled:
            return
        self._check_device(input_tensor)
        self._check_device(output_tensor, "output")
        stream = self._resolve_stream(stream)
        self.nccl.ncclReduceScatter(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()),
            output_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype),
            ncclRedOpTypeEnum.from_torch(op),
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    def reduce_scatterv(
        self,
        output_tensor: torch.Tensor,
        input_tensor: torch.Tensor,
        sizes: list[int],
        op: ReduceOp = ReduceOp.SUM,
        stream=None,
    ):
        """Variable-size reduce-scatter along dim 0.

        ``input_tensor`` is split into consecutive leading-dim chunks of
        ``sizes[r]`` rows. Chunk ``r`` is reduced with ``op`` onto rank ``r``
        (one ``ncclReduce`` per rank inside an NCCL group) and lands in that
        rank's ``output_tensor``.

        Args:
            output_tensor: destination for this rank's reduced chunk.
            input_tensor: source, with ``shape[0] == sum(sizes)``.
            sizes: one leading-dim size per rank (``len == world_size``).
            op: the ``torch.distributed.ReduceOp`` to apply.
            stream: CUDA stream to enqueue on. Defaults to
                ``current_stream()``.
        """
        if self.disabled:
            return
        self._check_device(input_tensor)
        self._check_device(output_tensor, "output")
        stream = self._resolve_stream(stream)
        # The enumerate index below is used as the NCCL root rank, so there
        # must be exactly one size per rank. The chunks must tile the input,
        # mirroring the check in `all_gatherv`.
        assert len(sizes) == self.world_size, (
            f"expected one size per rank ({self.world_size}), got {len(sizes)}"
        )
        assert input_tensor.shape[0] == sum(sizes)

        split_offset = 0
        self.nccl.ncclGroupStart()
        for root, split_size in enumerate(sizes):
            chunk = input_tensor[split_offset : split_offset + split_size, ...]
            self.nccl.ncclReduce(
                buffer_type(chunk.data_ptr()),
                buffer_type(output_tensor.data_ptr()),
                chunk.numel(),
                ncclDataTypeEnum.from_torch(input_tensor.dtype),
                ncclRedOpTypeEnum.from_torch(op),
                root,
                self.comm,
                cudaStream_t(stream.cuda_stream),
            )
            split_offset += split_size
        self.nccl.ncclGroupEnd()

    def send(self, tensor: torch.Tensor, dst: int, stream=None):
        """Send ``tensor`` to rank ``dst`` of this communicator's group."""
        if self.disabled:
            return
        self._check_device(tensor)
        stream = self._resolve_stream(stream)
        self.nccl.ncclSend(
            buffer_type(tensor.data_ptr()),
            tensor.numel(),
            ncclDataTypeEnum.from_torch(tensor.dtype),
            dst,
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    def recv(self, tensor: torch.Tensor, src: int, stream=None):
        """Receive into ``tensor`` from rank ``src`` of this communicator's group."""
        if self.disabled:
            return
        self._check_device(tensor)
        stream = self._resolve_stream(stream)
        self.nccl.ncclRecv(
            buffer_type(tensor.data_ptr()),
            tensor.numel(),
            ncclDataTypeEnum.from_torch(tensor.dtype),
            src,
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
        """Broadcast ``tensor`` in place from rank ``src`` to all ranks."""
        if self.disabled:
            return
        self._check_device(tensor)
        stream = self._resolve_stream(stream)
        if src == self.rank:
            sendbuff = buffer_type(tensor.data_ptr())
            # NCCL requires the sender also to have a receive buffer
            recvbuff = buffer_type(tensor.data_ptr())
        else:
            # Non-root ranks have nothing to send: pass an empty buffer.
            sendbuff = buffer_type()
            recvbuff = buffer_type(tensor.data_ptr())
        self.nccl.ncclBroadcast(
            sendbuff,
            recvbuff,
            tensor.numel(),
            ncclDataTypeEnum.from_torch(tensor.dtype),
            src,
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    def group_start(self):
        """Call ``ncclGroupStart``. This is a no-op while disabled.

        The guard keeps this consistent with the communication methods. It
        also avoids touching ``self.nccl``, which does not exist when
        construction stopped early.
        """
        if self.disabled:
            return
        self.nccl.ncclGroupStart()

    def group_end(self):
        """Call ``ncclGroupEnd``. This is a no-op while disabled (see ``group_start``)."""
        if self.disabled:
            return
        self.nccl.ncclGroupEnd()

    def register_comm_window(self, tensor: torch.Tensor):
        """Register ``tensor``'s memory with ``ncclCommWindowRegister``.

        The byte size is ``numel() * element_size()``. This assumes the
        tensor is contiguous, which should be confirmed with the callers.
        The return value is forwarded unchanged from the NCCL wrapper.
        """
        return self.register_comm_window_raw(
            tensor.data_ptr(), tensor.numel() * tensor.element_size()
        )

    def register_comm_window_raw(self, ptr: int, size: int):
        """Register ``size`` bytes at device address ``ptr`` as an NCCL window."""
        # NOTE(review): the trailing `1` is the window-registration flags
        # argument, presumably NCCL_WIN_COLL_SYMMETRIC. Confirm against the
        # NCCL headers.
        return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr), size, 1)

    def deregister_comm_window(self, window):
        """Deregister a window previously returned by ``register_comm_window*``."""
        return self.nccl.ncclCommWindowDeregister(self.comm, window)