pynccl.py 13.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import Optional, Union
5
6
7
8

# ===================== import region =====================
import torch
import torch.distributed as dist
9
from torch.distributed import ProcessGroup, ReduceOp
10

11
12
13
from vllm.distributed.device_communicators.pynccl_wrapper import (
    NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
    ncclRedOpTypeEnum, ncclUniqueId)
14
from vllm.distributed.utils import StatelessProcessGroup
15
from vllm.logger import init_logger
youkaichao's avatar
youkaichao committed
16
from vllm.utils import current_stream
17
18

logger = init_logger(__name__)

# Set once `register_nccl_symmetric_ops` has successfully registered its
# custom op, so registration happens at most once per process.
_NCCL_SYMM_OPS_REGISTERED = False


def register_nccl_symmetric_ops(pynccl_comm):
    """Register the `all_reduce_symmetric_with_copy` custom op (once).

    The registered op allocates an input and an output buffer inside
    `nccl_symm_mem_context(pynccl_comm)`, copies the caller's tensor into
    the input buffer and all-reduces it with `pynccl_comm.all_reduce`.

    Registration happens only on the first successful call. The op's
    implementation closes over the `pynccl_comm` passed on that call; later
    calls (even with a different communicator) return immediately.

    Args:
        pynccl_comm: communicator the op runs on (presumably a
            `PyNcclCommunicator`; it must provide `all_reduce(in, out)`).
    """
    global _NCCL_SYMM_OPS_REGISTERED
    if _NCCL_SYMM_OPS_REGISTERED:
        return

    from vllm.distributed.device_communicators.pynccl_allocator import (
        nccl_symm_mem_context)
    from vllm.utils import direct_register_custom_op

    def all_reduce_symmetric_with_copy_impl(
            input_tensor: torch.Tensor) -> torch.Tensor:
        # Both buffers are allocated under the allocator context tied to
        # `pynccl_comm`; the caller's tensor is then copied in.
        with nccl_symm_mem_context(pynccl_comm):
            symm_input = torch.empty_like(input_tensor)
            symm_output = torch.empty_like(input_tensor)
        symm_input.copy_(input_tensor)
        # NOTE(review): `all_reduce` returns None when the communicator is
        # disabled, in which case this op returns None too.
        symm_output = pynccl_comm.all_reduce(symm_input, symm_output)
        return symm_output

    def all_reduce_symmetric_with_copy_fake(
            input_tensor: torch.Tensor) -> torch.Tensor:
        # Fake implementation: only allocates an output like the input.
        return torch.empty_like(input_tensor)

    direct_register_custom_op(
        op_name="all_reduce_symmetric_with_copy",
        op_func=all_reduce_symmetric_with_copy_impl,
        fake_impl=all_reduce_symmetric_with_copy_fake,
    )
    # Mark as registered only after registration succeeded, so a failed
    # attempt is not silently treated as done by later calls.
    _NCCL_SYMM_OPS_REGISTERED = True
class PyNcclCommunicator:
    """Python wrapper around an NCCL communicator driven via `NCCLLibrary`.

    When `self.disabled` is True (single-rank group, or the NCCL library
    could not be loaded) every collective and group bracket is a no-op.
    """

    def __init__(
        self,
        group: Union[ProcessGroup, StatelessProcessGroup],
        device: Union[int, str, torch.device],
        library_path: Optional[str] = None,
    ):
        """
        Args:
            group: the process group to work on: either a
                `torch.distributed` ProcessGroup whose backend is not NCCL,
                or a `StatelessProcessGroup`. It is used to determine this
                rank and the world size, and to broadcast the NCCL unique
                id from the group's first rank.
            device: the device to bind the PyNcclCommunicator to. An int
                is mapped to `cuda:{device}`, a str is parsed with
                `torch.device`, and a `torch.device` is used as-is. Any
                other value (including None) fails the assertion below
                when the communicator is enabled.
            library_path: the path to the NCCL library. If None, it will
                use the default library path.

        It is the caller's responsibility to make sure each communicator
        is bound to a unique device.

        If the group has a single rank, or the NCCL library cannot be
        loaded, the communicator is created disabled (`self.disabled` is
        True) and no NCCL communicator is initialized.
        """
        if not isinstance(group, StatelessProcessGroup):
            assert dist.is_initialized()
            assert dist.get_backend(group) != dist.Backend.NCCL, (
                "PyNcclCommunicator should be attached to a non-NCCL group.")
            # note: this rank is the rank in the group
            self.rank = dist.get_rank(group)
            self.world_size = dist.get_world_size(group)
        else:
            self.rank = group.rank
            self.world_size = group.world_size

        self.group = group

        # if world_size == 1, no need to create communicator
        if self.world_size == 1:
            self.available = False
            self.disabled = True
            return
        try:
            self.nccl = NCCLLibrary(library_path)
        except Exception:
            # disable because of missing NCCL library
            # e.g. in a non-GPU environment
            self.available = False
            self.disabled = True
            return

        self.available = True
        self.disabled = False

        self.nccl_version = self.nccl.ncclGetRawVersion()
        logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())

        # Group rank 0 creates the unique id; the other ranks start from an
        # empty id and overwrite it with rank 0's id via the broadcast below.
        if self.rank == 0:
            # get the unique id from NCCL
            self.unique_id = self.nccl.ncclGetUniqueId()
        else:
            # construct an empty unique id
            self.unique_id = ncclUniqueId()

        if not isinstance(group, StatelessProcessGroup):
            tensor = torch.ByteTensor(list(self.unique_id.internal))
            ranks = dist.get_process_group_ranks(group)
            # arg `src` in `broadcast` is the global rank
            dist.broadcast(tensor, src=ranks[0], group=group)
            byte_list = tensor.tolist()
            for i, byte in enumerate(byte_list):
                self.unique_id.internal[i] = byte
        else:
            self.unique_id = group.broadcast_obj(self.unique_id, src=0)

        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device
        # nccl communicator and stream will use this device
        # `torch.cuda.device` is a context manager that changes the
        # current cuda device to the specified one
        with torch.cuda.device(device):
            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                self.world_size, self.unique_id, self.rank)

            stream = current_stream()
            # A small all_reduce for warmup.
            data = torch.zeros(1, device=device)
            self.all_reduce(data)
            stream.synchronize()
            del data

    def _check_device(self, tensor: torch.Tensor) -> None:
        """Assert that `tensor` lives on the device this communicator uses."""
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")

    @staticmethod
    def _resolve_stream(stream):
        """Return `stream`, or `current_stream()` when it is None."""
        return current_stream() if stream is None else stream

    def all_reduce(self,
                   in_tensor: torch.Tensor,
                   out_tensor: Optional[torch.Tensor] = None,
                   op: ReduceOp = ReduceOp.SUM,
                   stream=None) -> Optional[torch.Tensor]:
        """All-reduce `in_tensor` across the group with `op`.

        Args:
            in_tensor: local contribution; must be on `self.device`.
            out_tensor: buffer receiving `in_tensor.numel()` reduced
                elements. If None, one is allocated with
                `torch.empty_like(in_tensor)`.
            op: reduction operator.
            stream: stream to enqueue the NCCL call on; defaults to
                `current_stream()`. This method does not synchronize it.

        Returns:
            The output buffer, or None when the communicator is disabled.
        """
        if self.disabled:
            return None
        self._check_device(in_tensor)
        if out_tensor is None:
            out_tensor = torch.empty_like(in_tensor)
        stream = self._resolve_stream(stream)
        self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
                                buffer_type(out_tensor.data_ptr()),
                                in_tensor.numel(),
                                ncclDataTypeEnum.from_torch(in_tensor.dtype),
                                ncclRedOpTypeEnum.from_torch(op), self.comm,
                                cudaStream_t(stream.cuda_stream))
        return out_tensor

    def all_gather(self,
                   output_tensor: torch.Tensor,
                   input_tensor: torch.Tensor,
                   stream=None):
        """Gather `input_tensor` from every rank into `output_tensor`.

        The per-rank send count is `input_tensor.numel()`; per the NCCL
        `ncclAllGather` contract `output_tensor` must hold `world_size`
        times that many elements (not checked here). No-op when disabled.
        """
        if self.disabled:
            return
        self._check_device(input_tensor)
        stream = self._resolve_stream(stream)
        self.nccl.ncclAllGather(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
            cudaStream_t(stream.cuda_stream))

    def all_gatherv(
        self,
        output_tensor: torch.Tensor,
        input_tensor: torch.Tensor,
        sizes: list[int],
        stream=None,
    ):
        """Variable-size all-gather along dim 0.

        Rank `i`'s data is written into rows
        `[sum(sizes[:i]), sum(sizes[:i + 1]))` of `output_tensor`, whose
        first dimension must equal `sum(sizes)`. Implemented as one
        `ncclBroadcast` per rank inside a single NCCL group. Each rank's
        `input_tensor` presumably has `sizes[self.rank]` rows (not checked
        here). No-op when disabled.
        """
        if self.disabled:
            return
        self._check_device(input_tensor)
        stream = self._resolve_stream(stream)
        assert output_tensor.shape[0] == sum(sizes)
        split_offset = 0
        self.nccl.ncclGroupStart()
        for root, split_size in enumerate(sizes):
            dst_slice = output_tensor[split_offset:split_offset + split_size]
            self.nccl.ncclBroadcast(
                buffer_type(input_tensor.data_ptr()),
                buffer_type(dst_slice.data_ptr()),
                dst_slice.numel(),
                ncclDataTypeEnum.from_torch(input_tensor.dtype),
                root,
                self.comm,
                cudaStream_t(stream.cuda_stream),
            )
            split_offset += split_size
        self.nccl.ncclGroupEnd()

    def reduce_scatter(self,
                       output_tensor: torch.Tensor,
                       input_tensor: torch.Tensor,
                       op: ReduceOp = ReduceOp.SUM,
                       stream=None):
        """Reduce `input_tensor` with `op` and scatter equal-size pieces.

        The per-rank receive count is `output_tensor.numel()`. No-op when
        disabled.
        """
        if self.disabled:
            return
        self._check_device(input_tensor)
        stream = self._resolve_stream(stream)
        self.nccl.ncclReduceScatter(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype),
            ncclRedOpTypeEnum.from_torch(op), self.comm,
            cudaStream_t(stream.cuda_stream))

    def reduce_scatterv(
        self,
        output_tensor: torch.Tensor,
        input_tensor: torch.Tensor,
        sizes: list[int],
        op: ReduceOp = ReduceOp.SUM,
        stream=None,
    ):
        """Variable-size reduce-scatter along dim 0.

        `input_tensor` is split along dim 0 into consecutive chunks of
        `sizes[i]` rows. Chunk `i` is reduced with `op` across ranks into
        `output_tensor` on rank `i`. Implemented as one `ncclReduce` per
        rank inside a single NCCL group. No-op when disabled.
        """
        if self.disabled:
            return
        self._check_device(input_tensor)
        stream = self._resolve_stream(stream)
        # Mirrors the check in `all_gatherv`: the chunks must exactly tile
        # dim 0, otherwise a chunk is truncated or trailing rows are ignored.
        assert input_tensor.shape[0] == sum(sizes)

        split_offset = 0
        self.nccl.ncclGroupStart()
        for root, split_size in enumerate(sizes):
            chunk = input_tensor[split_offset:split_offset + split_size, ...]
            self.nccl.ncclReduce(
                buffer_type(chunk.data_ptr()),
                buffer_type(output_tensor.data_ptr()), chunk.numel(),
                ncclDataTypeEnum.from_torch(input_tensor.dtype),
                ncclRedOpTypeEnum.from_torch(op), root, self.comm,
                cudaStream_t(stream.cuda_stream))
            split_offset += split_size
        self.nccl.ncclGroupEnd()

    def send(self, tensor: torch.Tensor, dst: int, stream=None):
        """Send `tensor` to group rank `dst`. No-op when disabled."""
        if self.disabled:
            return
        self._check_device(tensor)
        stream = self._resolve_stream(stream)
        self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                           ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                           self.comm, cudaStream_t(stream.cuda_stream))

    def recv(self, tensor: torch.Tensor, src: int, stream=None):
        """Receive into `tensor` from group rank `src`. No-op when disabled."""
        if self.disabled:
            return
        self._check_device(tensor)
        stream = self._resolve_stream(stream)
        self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                           ncclDataTypeEnum.from_torch(tensor.dtype), src,
                           self.comm, cudaStream_t(stream.cuda_stream))

    def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
        """Broadcast `tensor` in place from group rank `src`.

        On `src`, `tensor` is the data sent; on every other rank it is
        overwritten with that data. No-op when disabled.
        """
        if self.disabled:
            return
        self._check_device(tensor)
        stream = self._resolve_stream(stream)
        recvbuff = buffer_type(tensor.data_ptr())
        if src == self.rank:
            # NCCL requires the sender also to have a receive buffer; the
            # sender uses the same tensor for both.
            sendbuff = buffer_type(tensor.data_ptr())
        else:
            sendbuff = buffer_type()
        self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
                                ncclDataTypeEnum.from_torch(tensor.dtype), src,
                                self.comm, cudaStream_t(stream.cuda_stream))

    def group_start(self):
        """Open an NCCL group (`ncclGroupStart`).

        No-op when disabled, matching the collectives, which are all no-ops
        in that state (`self.nccl` is never set on a disabled instance).
        """
        if self.disabled:
            return
        self.nccl.ncclGroupStart()

    def group_end(self):
        """Close an NCCL group (`ncclGroupEnd`). No-op when disabled."""
        if self.disabled:
            return
        self.nccl.ncclGroupEnd()

    def register_comm_window(self, tensor: torch.Tensor):
        """Register `tensor`'s storage as an NCCL communication window.

        The registered size is `numel() * element_size()` bytes starting at
        `tensor.data_ptr()`. Not guarded by `self.disabled`: on a disabled
        communicator this raises AttributeError (`self.nccl` is unset).
        """
        return self.nccl.ncclCommWindowRegister(
            self.comm,
            buffer_type(tensor.data_ptr()),
            tensor.numel() * tensor.element_size(),
            # NOTE(review): presumably the window flags (e.g.
            # NCCL_WIN_COLL_SYMMETRIC) — confirm against pynccl_wrapper.
            1,
        )

    def register_comm_window_raw(self, ptr: int, size: int):
        """Register `size` bytes at raw address `ptr` as an NCCL window."""
        # Same flags value as `register_comm_window`.
        return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr),
                                                size, 1)

    def deregister_comm_window(self, window):
        """Deregister a window returned by one of the register methods."""
        return self.nccl.ncclCommWindowDeregister(self.comm, window)