# pynccl.py
from typing import Optional, Union
2
3
4
5

# ===================== import region =====================
import torch
import torch.distributed as dist
6
from torch.distributed import ProcessGroup, ReduceOp
7

8
9
10
from vllm.distributed.device_communicators.pynccl_wrapper import (
    NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
    ncclRedOpTypeEnum, ncclUniqueId)
11
from vllm.distributed.utils import StatelessProcessGroup
12
from vllm.logger import init_logger
13
from vllm.utils import current_stream
14
15

# Module-level logger; stray gutter line-number literals from the web-view
# copy that used to follow this line were dead no-op statements.
logger = init_logger(__name__)


class PyNcclCommunicator:
    """Issue NCCL collectives through ``NCCLLibrary`` on one CUDA device.

    The NCCL unique id is generated on rank 0 and broadcast over ``group``
    (a non-NCCL ``torch.distributed`` group or a ``StatelessProcessGroup``),
    then every rank initializes its NCCL communicator on ``device``.

    If the group has a single rank, or the NCCL library cannot be loaded,
    the communicator is marked ``disabled`` (``available`` is False) and
    every collective below returns immediately without doing anything.
    """

    def __init__(
        self,
        group: Union[ProcessGroup, StatelessProcessGroup],
        device: Union[int, str, torch.device],
        library_path: Optional[str] = None,
    ):
        """
        Args:
            group: the process group used to exchange the NCCL unique id.
                For a ``torch.distributed`` group, ``torch.distributed``
                must already be initialized and the group's backend must
                not be NCCL (both are asserted).
            device: the device to bind the PyNcclCommunicator to. An int
                ``i`` is interpreted as ``cuda:{i}``; a str is passed to
                ``torch.device``. Any other type (including ``None``)
                fails the ``torch.device`` type assertion.
            library_path: the path to the NCCL library, forwarded to
                ``NCCLLibrary``. If None, ``NCCLLibrary`` resolves its
                default library path.

        It is the caller's responsibility to make sure each communicator
        is bound to a unique device.
        """
        if not isinstance(group, StatelessProcessGroup):
            assert dist.is_initialized()
            assert dist.get_backend(group) != dist.Backend.NCCL, (
                "PyNcclCommunicator should be attached to a non-NCCL group.")
            # note: this rank is the rank in the group
            self.rank = dist.get_rank(group)
            self.world_size = dist.get_world_size(group)
        else:
            self.rank = group.rank
            self.world_size = group.world_size

        self.group = group

        # if world_size == 1, no need to create communicator
        if self.world_size == 1:
            self.available = False
            self.disabled = True
            return
        try:
            self.nccl = NCCLLibrary(library_path)
        except Exception:
            # Deliberate best-effort: disable because of a missing NCCL
            # library, e.g. in a non-GPU environment.
            self.available = False
            self.disabled = True
            return

        self.available = True
        self.disabled = False

        logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())

        if self.rank == 0:
            # get the unique id from NCCL
            self.unique_id = self.nccl.ncclGetUniqueId()
        else:
            # construct an empty unique id; it is overwritten below
            self.unique_id = ncclUniqueId()

        if not isinstance(group, StatelessProcessGroup):
            # Ship the raw id bytes through a CPU byte tensor and copy them
            # back into this rank's `ncclUniqueId` in place.
            tensor = torch.ByteTensor(list(self.unique_id.internal))
            ranks = dist.get_process_group_ranks(group)
            # arg `src` in `broadcast` is the global rank
            dist.broadcast(tensor, src=ranks[0], group=group)
            byte_list = tensor.tolist()
            for i, byte in enumerate(byte_list):
                self.unique_id.internal[i] = byte
        else:
            self.unique_id = group.broadcast_obj(self.unique_id, src=0)

        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device
        # nccl communicator and stream will use this device
        # `torch.cuda.device` is a context manager that changes the
        # current cuda device to the specified one
        with torch.cuda.device(device):
            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                self.world_size, self.unique_id, self.rank)

            stream = current_stream()
            # A small all_reduce for warmup.
            data = torch.zeros(1, device=device)
            self.all_reduce(data)
            stream.synchronize()
            del data

    def _assert_on_device(self, tensor: torch.Tensor, role: str = "input"):
        """Assert that ``tensor`` lives on this communicator's device.

        ``role`` only labels the tensor in the assertion message.
        """
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the {role} tensor is on {tensor.device}")

    def all_reduce(self,
                   in_tensor: torch.Tensor,
                   op: ReduceOp = ReduceOp.SUM,
                   stream=None) -> Optional[torch.Tensor]:
        """Out-of-place all-reduce of ``in_tensor`` across the group.

        Args:
            in_tensor: the local contribution; passed to NCCL as the send
                buffer only.
            op: the reduction, converted with ``ncclRedOpTypeEnum``.
            stream: CUDA stream to enqueue on; defaults to
                ``current_stream()``. Nothing is synchronized here.

        Returns:
            A new tensor allocated with ``torch.empty_like(in_tensor)``
            that NCCL fills with the reduced values, or ``None`` when the
            communicator is disabled.
        """
        if self.disabled:
            return None
        self._assert_on_device(in_tensor, "input")

        out_tensor = torch.empty_like(in_tensor)

        if stream is None:
            stream = current_stream()
        self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
                                buffer_type(out_tensor.data_ptr()),
                                in_tensor.numel(),
                                ncclDataTypeEnum.from_torch(in_tensor.dtype),
                                ncclRedOpTypeEnum.from_torch(op), self.comm,
                                cudaStream_t(stream.cuda_stream))
        return out_tensor

    def all_gather(self,
                   output_tensor: torch.Tensor,
                   input_tensor: torch.Tensor,
                   stream=None):
        """All-gather ``input_tensor`` from every rank into ``output_tensor``.

        Each rank contributes ``input_tensor.numel()`` elements of
        ``input_tensor.dtype``. ``output_tensor`` presumably must hold
        ``world_size * input_tensor.numel()`` elements of that dtype (NCCL
        contract, not checked here). No-op when disabled.
        """
        if self.disabled:
            return
        self._assert_on_device(input_tensor, "input")
        # NCCL writes into the output buffer, so it must be on the same
        # device as the communicator too.
        self._assert_on_device(output_tensor, "output")
        if stream is None:
            stream = current_stream()
        self.nccl.ncclAllGather(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
            cudaStream_t(stream.cuda_stream))

    def reduce_scatter(self,
                       output_tensor: torch.Tensor,
                       input_tensor: torch.Tensor,
                       op: ReduceOp = ReduceOp.SUM,
                       stream=None):
        """Reduce ``input_tensor`` across ranks and scatter the result.

        The element count passed to NCCL is ``output_tensor.numel()``
        (the per-rank receive count); ``input_tensor`` presumably holds
        ``world_size`` times that many elements (not checked here).
        No-op when disabled.
        """
        if self.disabled:
            return
        self._assert_on_device(input_tensor, "input")
        # NCCL writes into the output buffer, so it must be on the same
        # device as the communicator too.
        self._assert_on_device(output_tensor, "output")
        if stream is None:
            stream = current_stream()
        self.nccl.ncclReduceScatter(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype),
            ncclRedOpTypeEnum.from_torch(op), self.comm,
            cudaStream_t(stream.cuda_stream))

    def send(self, tensor: torch.Tensor, dst: int, stream=None):
        """Send ``tensor`` to rank ``dst`` (a rank within this communicator).

        No-op when disabled.
        """
        if self.disabled:
            return
        self._assert_on_device(tensor, "input")
        if stream is None:
            stream = current_stream()
        self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                           ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                           self.comm, cudaStream_t(stream.cuda_stream))

    def recv(self, tensor: torch.Tensor, src: int, stream=None):
        """Receive ``tensor.numel()`` elements from rank ``src`` into ``tensor``.

        No-op when disabled.
        """
        if self.disabled:
            return
        self._assert_on_device(tensor, "input")
        if stream is None:
            stream = current_stream()
        self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                           ncclDataTypeEnum.from_torch(tensor.dtype), src,
                           self.comm, cudaStream_t(stream.cuda_stream))

    def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
        """Broadcast ``tensor`` in place from rank ``src`` to all ranks.

        No-op when disabled.
        """
        if self.disabled:
            return
        self._assert_on_device(tensor, "input")
        if stream is None:
            stream = current_stream()
        if src == self.rank:
            sendbuff = buffer_type(tensor.data_ptr())
            # NCCL requires the sender also to have a receive buffer
            recvbuff = buffer_type(tensor.data_ptr())
        else:
            # non-root ranks pass an empty (null) send buffer
            sendbuff = buffer_type()
            recvbuff = buffer_type(tensor.data_ptr())
        self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
                                ncclDataTypeEnum.from_torch(tensor.dtype), src,
                                self.comm, cudaStream_t(stream.cuda_stream))