pynccl.py 9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import Optional, Union
5
6
7
8

# ===================== import region =====================
import torch
import torch.distributed as dist
9
from torch.distributed import ProcessGroup, ReduceOp
10

11
12
13
from vllm.distributed.device_communicators.pynccl_wrapper import (
    NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
    ncclRedOpTypeEnum, ncclUniqueId)
14
from vllm.distributed.utils import StatelessProcessGroup
15
from vllm.logger import init_logger
youkaichao's avatar
youkaichao committed
16
from vllm.utils import current_stream
17
18

# Module-level logger, created with vLLM's `init_logger` and keyed by this
# module's name.
logger = init_logger(__name__)
19
20


21
class PyNcclCommunicator:
    """NCCL communicator bound to one device, driven through `NCCLLibrary`.

    Attributes set on every instance:
        rank: rank of this process inside `group`.
        world_size: number of ranks in `group`.
        group: the process group passed to the constructor.
        available: False when the group has a single rank or the NCCL
            library could not be loaded, True otherwise.
        disabled: True when `available` is False. Every communication
            method returns immediately (with `None`) while this is True.

    Attributes set only when the communicator is enabled:
        nccl: the loaded `NCCLLibrary`.
        unique_id: the NCCL unique id shared by all ranks of the group.
        device: the `torch.device` the communicator is bound to.
        comm: the NCCL communicator handle.
    """

    def __init__(
        self,
        group: Union[ProcessGroup, StatelessProcessGroup],
        device: Union[int, str, torch.device],
        library_path: Optional[str] = None,
    ):
        """
        Args:
            group: the process group to work on. A torch `ProcessGroup`
                must use a non-NCCL backend; it is only used to exchange
                the NCCL unique id.
            device: the device to bind the PyNcclCommunicator to. An int
                is interpreted as f"cuda:{device}", a str is passed to
                `torch.device`.
            library_path: the path to the NCCL library, forwarded to
                `NCCLLibrary`. Defaults to None.

        It is the caller's responsibility to make sure each communicator
        is bound to a unique device.
        """
        if not isinstance(group, StatelessProcessGroup):
            assert dist.is_initialized()
            assert dist.get_backend(group) != dist.Backend.NCCL, (
                "PyNcclCommunicator should be attached to a non-NCCL group.")
            # note: this rank is the rank in the group
            self.rank = dist.get_rank(group)
            self.world_size = dist.get_world_size(group)
        else:
            self.rank = group.rank
            self.world_size = group.world_size

        self.group = group

        # if world_size == 1, no need to create communicator
        if self.world_size == 1:
            self.available = False
            self.disabled = True
            return
        try:
            self.nccl = NCCLLibrary(library_path)
        except Exception:
            # disable because of missing NCCL library
            # e.g. in a non-GPU environment
            self.available = False
            self.disabled = True
            return

        self.available = True
        self.disabled = False

        logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())

        if self.rank == 0:
            # get the unique id from NCCL
            self.unique_id = self.nccl.ncclGetUniqueId()
        else:
            # construct an empty unique id, filled in by the broadcast below
            self.unique_id = ncclUniqueId()

        if not isinstance(group, StatelessProcessGroup):
            # ship the id as a byte tensor through the (non-NCCL) torch group
            tensor = torch.ByteTensor(list(self.unique_id.internal))
            ranks = dist.get_process_group_ranks(group)
            # arg `src` in `broadcast` is the global rank
            dist.broadcast(tensor, src=ranks[0], group=group)
            byte_list = tensor.tolist()
            for i, byte in enumerate(byte_list):
                self.unique_id.internal[i] = byte
        else:
            self.unique_id = group.broadcast_obj(self.unique_id, src=0)

        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device
        # nccl communicator and stream will use this device
        # `torch.cuda.device` is a context manager that changes the
        # current cuda device to the specified one
        with torch.cuda.device(device):
            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                self.world_size, self.unique_id, self.rank)

            stream = current_stream()
            # A small all_reduce for warmup.
            data = torch.zeros(1, device=device)
            self.all_reduce(data)
            stream.synchronize()
            del data

    def _check_device(self, tensor: torch.Tensor, role: str = "input") -> None:
        """Assert that `tensor` lives on the device this communicator uses.

        nccl communicator created on a specific device will only work on
        tensors on the same device, otherwise it will cause
        "illegal memory access".

        Args:
            tensor: the tensor about to be handed to NCCL.
            role: word used in the error message ("input" or "output").

        Raises:
            AssertionError: if `tensor.device` differs from `self.device`.
        """
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the {role} tensor is on {tensor.device}")

    @staticmethod
    def _resolve_stream(stream):
        """Return `stream`, or `current_stream()` when `stream` is None."""
        return current_stream() if stream is None else stream

    def all_reduce(self,
                   in_tensor: torch.Tensor,
                   op: ReduceOp = ReduceOp.SUM,
                   stream=None) -> Optional[torch.Tensor]:
        """All-reduce `in_tensor` into a newly allocated tensor.

        Args:
            in_tensor: the tensor to reduce; must be on `self.device`.
            op: the reduction to apply, SUM by default.
            stream: the stream to issue the NCCL call on. If None,
                `current_stream()` is used.

        Returns:
            A new tensor created with `torch.empty_like(in_tensor)` that is
            passed to NCCL as the receive buffer, or None when the
            communicator is disabled.

        Raises:
            AssertionError: if `in_tensor` is on another device.
        """
        if self.disabled:
            return None
        self._check_device(in_tensor)

        out_tensor = torch.empty_like(in_tensor)

        stream = self._resolve_stream(stream)
        self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
                                buffer_type(out_tensor.data_ptr()),
                                in_tensor.numel(),
                                ncclDataTypeEnum.from_torch(in_tensor.dtype),
                                ncclRedOpTypeEnum.from_torch(op), self.comm,
                                cudaStream_t(stream.cuda_stream))
        return out_tensor

    def all_gather(self,
                   output_tensor: torch.Tensor,
                   input_tensor: torch.Tensor,
                   stream=None):
        """All-gather `input_tensor` into `output_tensor`.

        The element count passed to NCCL is `input_tensor.numel()`.

        Args:
            output_tensor: the receive buffer; must be on `self.device`.
            input_tensor: the send buffer; must be on `self.device`.
            stream: the stream to issue the NCCL call on. If None,
                `current_stream()` is used.

        Raises:
            AssertionError: if either tensor is on another device.
        """
        if self.disabled:
            return
        self._check_device(input_tensor)
        # the output pointer is handed to NCCL as well, so it needs the
        # same device guarantee as the input
        self._check_device(output_tensor, "output")
        stream = self._resolve_stream(stream)
        self.nccl.ncclAllGather(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
            cudaStream_t(stream.cuda_stream))

    def reduce_scatter(self,
                       output_tensor: torch.Tensor,
                       input_tensor: torch.Tensor,
                       op: ReduceOp = ReduceOp.SUM,
                       stream=None):
        """Reduce-scatter `input_tensor` into `output_tensor`.

        The element count passed to NCCL is `output_tensor.numel()`.

        Args:
            output_tensor: the receive buffer; must be on `self.device`.
            input_tensor: the send buffer; must be on `self.device`.
            op: the reduction to apply, SUM by default.
            stream: the stream to issue the NCCL call on. If None,
                `current_stream()` is used.

        Raises:
            AssertionError: if either tensor is on another device.
        """
        if self.disabled:
            return
        self._check_device(input_tensor)
        # the output pointer is handed to NCCL as well, so it needs the
        # same device guarantee as the input
        self._check_device(output_tensor, "output")
        stream = self._resolve_stream(stream)
        self.nccl.ncclReduceScatter(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype),
            ncclRedOpTypeEnum.from_torch(op), self.comm,
            cudaStream_t(stream.cuda_stream))

    def send(self, tensor: torch.Tensor, dst: int, stream=None):
        """Send `tensor` to rank `dst` of this communicator.

        Args:
            tensor: the tensor to send; must be on `self.device`.
            dst: the destination rank, forwarded to `ncclSend`.
            stream: the stream to issue the NCCL call on. If None,
                `current_stream()` is used.

        Raises:
            AssertionError: if `tensor` is on another device.
        """
        if self.disabled:
            return
        self._check_device(tensor)
        stream = self._resolve_stream(stream)
        self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                           ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                           self.comm, cudaStream_t(stream.cuda_stream))

    def recv(self, tensor: torch.Tensor, src: int, stream=None):
        """Receive into `tensor` from rank `src` of this communicator.

        Args:
            tensor: the tensor to receive into; must be on `self.device`.
            src: the source rank, forwarded to `ncclRecv`.
            stream: the stream to issue the NCCL call on. If None,
                `current_stream()` is used.

        Raises:
            AssertionError: if `tensor` is on another device.
        """
        if self.disabled:
            return
        self._check_device(tensor)
        stream = self._resolve_stream(stream)
        self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                           ncclDataTypeEnum.from_torch(tensor.dtype), src,
                           self.comm, cudaStream_t(stream.cuda_stream))

    def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
        """Broadcast `tensor` from rank `src` to every rank, in place.

        Args:
            tensor: the send buffer on rank `src` and the receive buffer on
                every rank; must be on `self.device`.
            src: the rank whose data is broadcast.
            stream: the stream to issue the NCCL call on. If None,
                `current_stream()` is used.

        Raises:
            AssertionError: if `tensor` is on another device.
        """
        if self.disabled:
            return
        self._check_device(tensor)
        stream = self._resolve_stream(stream)
        # every rank receives into `tensor`
        recvbuff = buffer_type(tensor.data_ptr())
        if src == self.rank:
            # NCCL requires the sender also to have a receive buffer,
            # so the sender uses `tensor` for both roles
            sendbuff = buffer_type(tensor.data_ptr())
        else:
            # receivers pass an empty send buffer
            sendbuff = buffer_type()
        self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
                                ncclDataTypeEnum.from_torch(tensor.dtype), src,
                                self.comm, cudaStream_t(stream.cuda_stream))