pynccl.py 13.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import Optional, Union
5
6
7
8

# ===================== import region =====================
import torch
import torch.distributed as dist
9
from torch.distributed import ProcessGroup, ReduceOp
10

11
import vllm.envs as envs
12
13
14
from vllm.distributed.device_communicators.pynccl_wrapper import (
    NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
    ncclRedOpTypeEnum, ncclUniqueId)
15
from vllm.distributed.utils import StatelessProcessGroup
16
from vllm.logger import init_logger
youkaichao's avatar
youkaichao committed
17
from vllm.utils import current_stream
18
19

# Module-level logger, created through vLLM's `init_logger` helper and keyed
# by this module's name.
logger = init_logger(__name__)
20

21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
_NCCL_SYMM_OPS_REGISTERED = False


def register_nccl_symmetric_ops(pynccl_comm):
    """Register the ``all_reduce_symmetric_with_copy`` custom op, once.

    The registered op copies its input into a buffer allocated under
    ``nccl_symm_mem_context(pynccl_comm)`` and all-reduces it through
    ``pynccl_comm.all_reduce``. Only the first call registers anything;
    later calls return immediately, so the op keeps using the communicator
    that was passed to that first call.

    Args:
        pynccl_comm: communicator captured by the op implementation.
    """
    from vllm.distributed.device_communicators.pynccl_allocator import (
        nccl_symm_mem_context)
    from vllm.utils import direct_register_custom_op

    global _NCCL_SYMM_OPS_REGISTERED
    if _NCCL_SYMM_OPS_REGISTERED:
        return
    # The flag is flipped before the registration call itself.
    _NCCL_SYMM_OPS_REGISTERED = True

    def all_reduce_symmetric_with_copy_impl(
            input_tensor: torch.Tensor) -> torch.Tensor:
        # Both the staging buffer and the result buffer are allocated while
        # the symmetric-memory context is active; the copy and the
        # collective run after the context has exited.
        with nccl_symm_mem_context(pynccl_comm):
            staged = torch.empty_like(input_tensor)
            reduced = torch.empty_like(input_tensor)
        staged.copy_(input_tensor)
        reduced = pynccl_comm.all_reduce(staged, reduced)
        return reduced

    def all_reduce_symmetric_with_copy_fake(
            input_tensor: torch.Tensor) -> torch.Tensor:
        # Shape/dtype-only stand-in: no communication is performed.
        return torch.empty_like(input_tensor)

    direct_register_custom_op(
        op_name="all_reduce_symmetric_with_copy",
        op_func=all_reduce_symmetric_with_copy_impl,
        fake_impl=all_reduce_symmetric_with_copy_fake,
    )

53

54
class PyNcclCommunicator:
55
56
57

    def __init__(
        self,
        group: Union[ProcessGroup, StatelessProcessGroup],
        device: Union[int, str, torch.device],
        library_path: Optional[str] = None,
    ):
        """Create an NCCL communicator for ``group`` bound to ``device``.

        Args:
            group: the process group to work on. Either a
                ``torch.distributed`` group, which must not use the NCCL
                backend, or a ``StatelessProcessGroup``.
            device: the device to bind the PyNcclCommunicator to. An int
                is turned into ``cuda:{device}``; a str is parsed by
                ``torch.device``.
            library_path: the path to the NCCL library, forwarded as-is
                to ``NCCLLibrary`` (``None`` by default).

        The communicator ends up disabled (``available`` False,
        ``disabled`` True) with no NCCL state when the group has a single
        rank, when ``envs.VLLM_DISABLE_PYNCCL`` is truthy, or when
        constructing ``NCCLLibrary`` raises.

        It is the caller's responsibility to make sure each communicator
        is bound to a unique device.
        """
        if isinstance(group, StatelessProcessGroup):
            self.rank = group.rank
            self.world_size = group.world_size
        else:
            assert dist.is_initialized()
            assert dist.get_backend(group) != dist.Backend.NCCL, (
                "PyNcclCommunicator should be attached to a non-NCCL group.")
            # note: this rank is the rank in the group
            self.rank = dist.get_rank(group)
            self.world_size = dist.get_world_size(group)

        self.group = group

        # if world_size == 1, no need to create communicator
        if self.world_size == 1 or envs.VLLM_DISABLE_PYNCCL:
            self.available, self.disabled = False, True
            return
        try:
            self.nccl = NCCLLibrary(library_path)
        except Exception:
            # disable because of missing NCCL library
            # e.g. in a non-GPU environment
            self.available, self.disabled = False, True
            return

        self.available, self.disabled = True, False

        self.nccl_version = self.nccl.ncclGetRawVersion()
        logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())

        # Rank 0 obtains the unique id from NCCL; every other rank starts
        # with an empty one that the broadcast below fills in.
        is_root = self.rank == 0
        self.unique_id = (self.nccl.ncclGetUniqueId()
                          if is_root else ncclUniqueId())

        if isinstance(group, StatelessProcessGroup):
            self.unique_id = group.broadcast_obj(self.unique_id, src=0)
        else:
            id_bytes = torch.ByteTensor(list(self.unique_id.internal))
            group_ranks = dist.get_process_group_ranks(group)
            # arg `src` in `broadcast` is the global rank
            dist.broadcast(id_bytes, src=group_ranks[0], group=group)
            for idx, value in enumerate(id_bytes.tolist()):
                self.unique_id.internal[idx] = value

        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        # nccl communicator and stream will use this device
        self.device = device

        # `torch.cuda.device` is a context manager that changes the
        # current cuda device to the specified one
        with torch.cuda.device(device):
            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                self.world_size, self.unique_id, self.rank)

            stream = current_stream()
            # A small all_reduce for warmup.
            warmup = torch.zeros(1, device=device)
            self.all_reduce(warmup)
            stream.synchronize()
            del warmup
143

144
    def all_reduce(self,
145
                   in_tensor: torch.Tensor,
146
                   out_tensor: torch.Tensor = None,
147
                   op: ReduceOp = ReduceOp.SUM,
148
                   stream=None) -> torch.Tensor:
149
        if self.disabled:
150
            return None
151
152
153
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
154
        assert in_tensor.device == self.device, (
155
            f"this nccl communicator is created to work on {self.device}, "
156
157
            f"but the input tensor is on {in_tensor.device}")

158
159
        if out_tensor is None:
            out_tensor = torch.empty_like(in_tensor)
160

161
        if stream is None:
youkaichao's avatar
youkaichao committed
162
            stream = current_stream()
163
164
165
166
        self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
                                buffer_type(out_tensor.data_ptr()),
                                in_tensor.numel(),
                                ncclDataTypeEnum.from_torch(in_tensor.dtype),
167
168
                                ncclRedOpTypeEnum.from_torch(op), self.comm,
                                cudaStream_t(stream.cuda_stream))
169
        return out_tensor
170

171
172
173
174
175
176
177
178
179
180
181
182
183
    def all_gather(self,
                   output_tensor: torch.Tensor,
                   input_tensor: torch.Tensor,
                   stream=None):
        if self.disabled:
            return
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert input_tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {input_tensor.device}")
        if stream is None:
youkaichao's avatar
youkaichao committed
184
            stream = current_stream()
185
186
187
188
189
190
        self.nccl.ncclAllGather(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
            cudaStream_t(stream.cuda_stream))

191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
    def all_gatherv(
        self,
        output_tensor: torch.Tensor,
        input_tensor: torch.Tensor,
        sizes: list[int],
        stream=None,
    ):
        if self.disabled:
            return
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert input_tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {input_tensor.device}")
        if stream is None:
            stream = current_stream()
        assert output_tensor.shape[0] == sum(sizes)
        split_offset = 0
        self.nccl.ncclGroupStart()
        for root, split_size in enumerate(sizes):
            dst_slice = output_tensor[split_offset:split_offset + split_size]
            self.nccl.ncclBroadcast(
                buffer_type(input_tensor.data_ptr()),
                buffer_type(dst_slice.data_ptr()),
                dst_slice.numel(),
                ncclDataTypeEnum.from_torch(input_tensor.dtype),
                root,
                self.comm,
                cudaStream_t(stream.cuda_stream),
            )
            split_offset += split_size
        self.nccl.ncclGroupEnd()

225
226
227
228
229
230
231
232
233
234
235
236
237
238
    def reduce_scatter(self,
                       output_tensor: torch.Tensor,
                       input_tensor: torch.Tensor,
                       op: ReduceOp = ReduceOp.SUM,
                       stream=None):
        if self.disabled:
            return
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert input_tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {input_tensor.device}")
        if stream is None:
youkaichao's avatar
youkaichao committed
239
            stream = current_stream()
240
241
242
243
244
245
246
        self.nccl.ncclReduceScatter(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype),
            ncclRedOpTypeEnum.from_torch(op), self.comm,
            cudaStream_t(stream.cuda_stream))

247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
    def reduce_scatterv(
        self,
        output_tensor: torch.Tensor,
        input_tensor: torch.Tensor,
        sizes: list[int],
        op: ReduceOp = ReduceOp.SUM,
        stream=None,
    ):
        if self.disabled:
            return
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert input_tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {input_tensor.device}")
        if stream is None:
            stream = current_stream()

        split_offset = 0
        self.nccl.ncclGroupStart()
        for root, split_size in enumerate(sizes):
            chunk = input_tensor[split_offset:split_offset + split_size, ...]
            self.nccl.ncclReduce(
                buffer_type(chunk.data_ptr()),
                buffer_type(output_tensor.data_ptr()), chunk.numel(),
                ncclDataTypeEnum.from_torch(input_tensor.dtype),
                ncclRedOpTypeEnum.from_torch(op), root, self.comm,
                cudaStream_t(stream.cuda_stream))
            split_offset += split_size
        self.nccl.ncclGroupEnd()

279
    def send(self, tensor: torch.Tensor, dst: int, stream=None):
280
281
282
283
284
285
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
youkaichao's avatar
youkaichao committed
286
            stream = current_stream()
287
288
289
290
        self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                           ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                           self.comm, cudaStream_t(stream.cuda_stream))

291
    def recv(self, tensor: torch.Tensor, src: int, stream=None):
292
293
294
295
296
297
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
youkaichao's avatar
youkaichao committed
298
            stream = current_stream()
299
300
301
302
        self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                           ncclDataTypeEnum.from_torch(tensor.dtype), src,
                           self.comm, cudaStream_t(stream.cuda_stream))

303
304
305
306
307
308
309
    def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
youkaichao's avatar
youkaichao committed
310
            stream = current_stream()
311
312
313
314
315
316
317
318
319
320
        if src == self.rank:
            sendbuff = buffer_type(tensor.data_ptr())
            # NCCL requires the sender also to have a receive buffer
            recvbuff = buffer_type(tensor.data_ptr())
        else:
            sendbuff = buffer_type()
            recvbuff = buffer_type(tensor.data_ptr())
        self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
                                ncclDataTypeEnum.from_torch(tensor.dtype), src,
                                self.comm, cudaStream_t(stream.cuda_stream))
321
322
323
324
325
326

    def group_start(self):
        self.nccl.ncclGroupStart()

    def group_end(self):
        self.nccl.ncclGroupEnd()
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341

    def register_comm_window(self, tensor: torch.Tensor):
        """Register ``tensor``'s storage as a window on this communicator.

        The window covers ``tensor.numel() * tensor.element_size()`` bytes
        starting at ``tensor.data_ptr()``.

        Returns:
            Whatever ``ncclCommWindowRegister`` returns (the window handle,
            presumably — verify against the wrapper).
        """
        base_ptr = buffer_type(tensor.data_ptr())
        num_bytes = tensor.numel() * tensor.element_size()
        # NOTE(review): the trailing 1 is passed through unchanged; it looks
        # like the window-flags argument — confirm against the NCCL wrapper.
        return self.nccl.ncclCommWindowRegister(self.comm, base_ptr,
                                                num_bytes, 1)

    def register_comm_window_raw(self, ptr: int, size: int):
        """Register a raw memory range as a window on this communicator.

        Args:
            ptr: integer address, wrapped in ``buffer_type``.
            size: size forwarded to ``ncclCommWindowRegister`` (presumably
                in bytes, as in ``register_comm_window`` — TODO confirm).

        Returns:
            Whatever ``ncclCommWindowRegister`` returns.
        """
        base_ptr = buffer_type(ptr)
        # NOTE(review): the trailing 1 is passed through unchanged; it looks
        # like the window-flags argument — confirm against the NCCL wrapper.
        return self.nccl.ncclCommWindowRegister(self.comm, base_ptr, size, 1)

    def deregister_comm_window(self, window):
        return self.nccl.ncclCommWindowDeregister(self.comm, window)