from contextlib import contextmanager
from typing import Optional, Union

# ===================== import region =====================
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp
from vllm.distributed.device_communicators.pynccl_wrapper import (
    NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
    ncclRedOpTypeEnum, ncclUniqueId)
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger

logger = init_logger(__name__)


class PyNcclCommunicator:

    def __init__(
        self,
        group: Union[ProcessGroup, StatelessProcessGroup],
        device: Union[int, str, torch.device],
        library_path: Optional[str] = None,
    ):
        """Create the NCCL communicator for this rank of ``group``.

        Args:
            group: the process group to work on. Either a
                ``torch.distributed`` group whose backend is not NCCL, or
                a ``StatelessProcessGroup``.
            device: the device to bind the PyNcclCommunicator to. An int
                is turned into ``cuda:{device}``, a str is handed to
                ``torch.device``, a ``torch.device`` is used as-is.
            library_path: the path to the NCCL library, forwarded
                unchanged to ``NCCLLibrary``.

        If the group has exactly one rank, or ``NCCLLibrary`` cannot be
        constructed, the communicator is marked unavailable and disabled,
        ``self.stream`` is ``None``, and nothing else is initialised.

        It is the caller's responsibility to make sure each communicator
        is bound to a unique device.
        """
        stateless = isinstance(group, StatelessProcessGroup)
        if stateless:
            self.rank = group.rank
            self.world_size = group.world_size
        else:
            assert dist.is_initialized()
            assert dist.get_backend(group) != dist.Backend.NCCL, (
                "PyNcclCommunicator should be attached to a non-NCCL group.")
            # note: this rank is the rank in the group
            self.rank = dist.get_rank(group)
            self.world_size = dist.get_world_size(group)

        self.group = group

        # With world_size == 1 there is nobody to talk to, so the library
        # is not even loaded in that case.
        nccl_lib = None
        if self.world_size != 1:
            try:
                nccl_lib = NCCLLibrary(library_path)
            except Exception:
                # disable because of missing NCCL library
                # e.g. in a non-GPU environment
                nccl_lib = None

        if nccl_lib is None:
            self.available = False
            self.disabled = True
            self.stream = None
            return

        self.nccl = nccl_lib
        self.available = True
        self.disabled = False

        logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())

        # Rank 0 obtains the unique id from NCCL; every other rank starts
        # with an empty one that the broadcast below overwrites.
        self.unique_id = (self.nccl.ncclGetUniqueId()
                          if self.rank == 0 else ncclUniqueId())

        if stateless:
            self.unique_id = group.broadcast_obj(self.unique_id, src=0)
        else:
            id_bytes = torch.ByteTensor(list(self.unique_id.internal))
            group_ranks = dist.get_process_group_ranks(group)
            # arg `src` in `broadcast` is the global rank
            dist.broadcast(id_bytes, src=group_ranks[0], group=group)
            for idx, value in enumerate(id_bytes.tolist()):
                self.unique_id.internal[idx] = value

        if isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device

        # nccl communicator and stream will use this device:
        # `torch.cuda.device` is a context manager that changes the
        # current cuda device to the specified one
        with torch.cuda.device(device):
            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                self.world_size, self.unique_id, self.rank)
            self.stream = torch.cuda.Stream()

            # A small all_reduce for warmup.
            warmup = torch.zeros(1, device=device)
            self.all_reduce(warmup)
            self.stream.synchronize()
            del warmup
108

109
    def all_reduce(self,
110
                   in_tensor: torch.Tensor,
111
                   op: ReduceOp = ReduceOp.SUM,
112
                   stream=None) -> torch.Tensor:
113
        if self.disabled:
114
            return None
115
116
117
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
118
        assert in_tensor.device == self.device, (
119
            f"this nccl communicator is created to work on {self.device}, "
120
121
122
123
            f"but the input tensor is on {in_tensor.device}")

        out_tensor = torch.empty_like(in_tensor)

124
125
        if stream is None:
            stream = self.stream
126
127
128
129
        self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
                                buffer_type(out_tensor.data_ptr()),
                                in_tensor.numel(),
                                ncclDataTypeEnum.from_torch(in_tensor.dtype),
130
131
                                ncclRedOpTypeEnum.from_torch(op), self.comm,
                                cudaStream_t(stream.cuda_stream))
132
        return out_tensor
    def all_gather(self,
                   output_tensor: torch.Tensor,
                   input_tensor: torch.Tensor,
                   stream=None):
        if self.disabled:
            return
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert input_tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {input_tensor.device}")
        if stream is None:
            stream = self.stream
        self.nccl.ncclAllGather(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
            cudaStream_t(stream.cuda_stream))

    def reduce_scatter(self,
                       output_tensor: torch.Tensor,
                       input_tensor: torch.Tensor,
                       op: ReduceOp = ReduceOp.SUM,
                       stream=None):
        if self.disabled:
            return
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert input_tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {input_tensor.device}")
        if stream is None:
            stream = self.stream
        self.nccl.ncclReduceScatter(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype),
            ncclRedOpTypeEnum.from_torch(op), self.comm,
            cudaStream_t(stream.cuda_stream))

    def send(self, tensor: torch.Tensor, dst: int, stream=None):
177
178
179
180
181
182
183
184
185
186
187
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = self.stream
        self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                           ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                           self.comm, cudaStream_t(stream.cuda_stream))

    def recv(self, tensor: torch.Tensor, src: int, stream=None):
189
190
191
192
193
194
195
196
197
198
199
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = self.stream
        self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                           ncclDataTypeEnum.from_torch(tensor.dtype), src,
                           self.comm, cudaStream_t(stream.cuda_stream))

    def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = self.stream
        if src == self.rank:
            sendbuff = buffer_type(tensor.data_ptr())
            # NCCL requires the sender also to have a receive buffer
            recvbuff = buffer_type(tensor.data_ptr())
        else:
            sendbuff = buffer_type()
            recvbuff = buffer_type(tensor.data_ptr())
        self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
                                ncclDataTypeEnum.from_torch(tensor.dtype), src,
                                self.comm, cudaStream_t(stream.cuda_stream))

    @contextmanager
    def change_state(self,
                     enable: Optional[bool] = None,
                     stream: Optional[torch.cuda.Stream] = None):
        """
        A context manager to change the state of the communicator.
        """
        if enable is None:
            # guess a default value when not specified
            enable = self.available

        if stream is None:
            stream = self.stream

        old_disable = self.disabled
        old_stream = self.stream

        self.stream = stream
        self.disabled = not enable
        yield

        self.disabled = old_disable
        self.stream = old_stream