pynccl.py 8.93 KB
Newer Older
1
from typing import Optional, Union
2
3
4
5

# ===================== import region =====================
import torch
import torch.distributed as dist
6
from torch.distributed import ProcessGroup, ReduceOp
7

8
9
10
from vllm.distributed.device_communicators.pynccl_wrapper import (
    NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
    ncclRedOpTypeEnum, ncclUniqueId)
11
from vllm.distributed.utils import StatelessProcessGroup
12
13
14
from vllm.logger import init_logger

# Module-level logger, created with vLLM's `init_logger` under this
# module's name.
logger = init_logger(__name__)
15
16


17
class PyNcclCommunicator:
    """Communicator that issues NCCL calls through `NCCLLibrary`.

    The communicator is bootstrapped over an existing process group: rank 0
    of the group asks NCCL for a unique id, the id is broadcast through the
    group, and every rank then calls `ncclCommInitRank` while its target
    CUDA device is current.

    Attributes set by `__init__`:
        rank, world_size: position and size within `group`.
        group: the process group used for bootstrapping.
        available, disabled: `available` is False and `disabled` is True
            when the group has a single rank or `NCCLLibrary` could not be
            constructed. In that state every communication method returns
            immediately without calling into NCCL.
        nccl, unique_id, device, comm: only set when the communicator is
            enabled.

    NOTE(review): all methods hand `tensor.data_ptr()` and `tensor.numel()`
    to NCCL, which presumably requires contiguous tensors - confirm with
    callers.
    """

    def __init__(
        self,
        group: Union[ProcessGroup, StatelessProcessGroup],
        device: Union[int, str, torch.device],
        library_path: Optional[str] = None,
    ):
        """
        Args:
            group: the process group to bootstrap on. Either a
                `torch.distributed` process group, which must not use the
                NCCL backend, or a `StatelessProcessGroup`.
            device: the device to bind the PyNcclCommunicator to. An int
                is interpreted as a CUDA device index (`cuda:{device}`)
                and a str is passed to `torch.device`.
            library_path: the path to the NCCL library. If None, it will
                use the default library path.

        It is the caller's responsibility to make sure each communicator
        is bound to a unique device.
        """
        if not isinstance(group, StatelessProcessGroup):
            assert dist.is_initialized()
            assert dist.get_backend(group) != dist.Backend.NCCL, (
                "PyNcclCommunicator should be attached to a non-NCCL group.")
            # note: this rank is the rank in the group
            self.rank = dist.get_rank(group)
            self.world_size = dist.get_world_size(group)
        else:
            self.rank = group.rank
            self.world_size = group.world_size

        self.group = group

        # Start out disabled; the flags are flipped only once the NCCL
        # library has been loaded for a group with more than one rank.
        self.available = False
        self.disabled = True

        # if world_size == 1, no need to create communicator
        if self.world_size == 1:
            return
        try:
            self.nccl = NCCLLibrary(library_path)
        except Exception:
            # disable because of missing NCCL library
            # e.g. in a non-GPU environment
            return

        self.available = True
        self.disabled = False

        logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())

        if self.rank == 0:
            # get the unique id from NCCL
            self.unique_id = self.nccl.ncclGetUniqueId()
        else:
            # construct an empty unique id
            self.unique_id = ncclUniqueId()

        # Share rank 0's unique id with every other rank of the group.
        if not isinstance(group, StatelessProcessGroup):
            tensor = torch.ByteTensor(list(self.unique_id.internal))
            ranks = dist.get_process_group_ranks(group)
            # arg `src` in `broadcast` is the global rank
            dist.broadcast(tensor, src=ranks[0], group=group)
            byte_list = tensor.tolist()
            for i, byte in enumerate(byte_list):
                self.unique_id.internal[i] = byte
        else:
            self.unique_id = group.broadcast_obj(self.unique_id, src=0)

        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device
        # nccl communicator and stream will use this device
        # `torch.cuda.device` is a context manager that changes the
        # current cuda device to the specified one
        with torch.cuda.device(device):
            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                self.world_size, self.unique_id, self.rank)

            stream = torch.cuda.current_stream()
            # A small all_reduce for warmup.
            data = torch.zeros(1, device=device)
            self.all_reduce(data)
            stream.synchronize()
            del data

    def _check_device(self, tensor: torch.Tensor, role: str = "input"):
        """Assert that `tensor` lives on the communicator's device.

        `role` ("input" or "output") only selects the wording of the
        assertion message.
        """
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the {role} tensor is on {tensor.device}")

    @staticmethod
    def _to_cuda_stream(stream):
        """Wrap `stream` (the current CUDA stream if None) as cudaStream_t."""
        if stream is None:
            stream = torch.cuda.current_stream()
        return cudaStream_t(stream.cuda_stream)

    def all_reduce(self,
                   in_tensor: torch.Tensor,
                   op: ReduceOp = ReduceOp.SUM,
                   stream=None) -> Optional[torch.Tensor]:
        """Reduce `in_tensor` across ranks with `op` via `ncclAllReduce`.

        The operation is out-of-place: the result is written to a freshly
        allocated tensor created with `torch.empty_like(in_tensor)`.

        Args:
            in_tensor: tensor to reduce; must be on `self.device`.
            op: reduction operation, translated with
                `ncclRedOpTypeEnum.from_torch`.
            stream: stream to enqueue the call on. If None, the current
                CUDA stream is used.

        Returns:
            The output tensor, or None when the communicator is disabled.
        """
        if self.disabled:
            return None
        self._check_device(in_tensor)

        out_tensor = torch.empty_like(in_tensor)

        cuda_stream = self._to_cuda_stream(stream)
        self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
                                buffer_type(out_tensor.data_ptr()),
                                in_tensor.numel(),
                                ncclDataTypeEnum.from_torch(in_tensor.dtype),
                                ncclRedOpTypeEnum.from_torch(op), self.comm,
                                cuda_stream)
        return out_tensor

    def all_gather(self,
                   output_tensor: torch.Tensor,
                   input_tensor: torch.Tensor,
                   stream=None):
        """Gather `input_tensor` from every rank via `ncclAllGather`.

        The per-rank element count passed to NCCL is `input_tensor.numel()`
        and the data type is taken from `input_tensor`.

        Args:
            output_tensor: destination buffer; must be on `self.device`.
                Presumably expected to hold `world_size` times the number
                of elements of `input_tensor` - its size is not checked
                here.
            input_tensor: this rank's contribution; must be on
                `self.device`.
            stream: stream to enqueue the call on. If None, the current
                CUDA stream is used.

        Does nothing when the communicator is disabled.
        """
        if self.disabled:
            return
        self._check_device(input_tensor)
        # The output buffer is handed to NCCL as a raw pointer as well, so
        # it needs the same device check as the input.
        self._check_device(output_tensor, "output")
        cuda_stream = self._to_cuda_stream(stream)
        self.nccl.ncclAllGather(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
            cuda_stream)

    def reduce_scatter(self,
                       output_tensor: torch.Tensor,
                       input_tensor: torch.Tensor,
                       op: ReduceOp = ReduceOp.SUM,
                       stream=None):
        """Reduce `input_tensor` with `op` via `ncclReduceScatter`.

        The element count passed to NCCL is `output_tensor.numel()` and the
        data type is taken from `input_tensor`.

        Args:
            output_tensor: destination buffer for this rank's share of the
                result; must be on `self.device`.
            input_tensor: this rank's contribution; must be on
                `self.device`.
            op: reduction operation, translated with
                `ncclRedOpTypeEnum.from_torch`.
            stream: stream to enqueue the call on. If None, the current
                CUDA stream is used.

        Does nothing when the communicator is disabled.
        """
        if self.disabled:
            return
        self._check_device(input_tensor)
        # The output buffer is handed to NCCL as a raw pointer as well, so
        # it needs the same device check as the input.
        self._check_device(output_tensor, "output")
        cuda_stream = self._to_cuda_stream(stream)
        self.nccl.ncclReduceScatter(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype),
            ncclRedOpTypeEnum.from_torch(op), self.comm,
            cuda_stream)

    def send(self, tensor: torch.Tensor, dst: int, stream=None):
        """Send `tensor` to peer `dst` via `ncclSend`.

        Args:
            tensor: data to send; must be on `self.device`.
            dst: peer rank, passed to NCCL unchanged (presumably numbered
                like `self.rank`, i.e. the rank within the group).
            stream: stream to enqueue the call on. If None, the current
                CUDA stream is used.

        Does nothing when the communicator is disabled.
        """
        if self.disabled:
            return
        self._check_device(tensor)
        cuda_stream = self._to_cuda_stream(stream)
        self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                           ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                           self.comm, cuda_stream)

    def recv(self, tensor: torch.Tensor, src: int, stream=None):
        """Receive into `tensor` from peer `src` via `ncclRecv`.

        Args:
            tensor: buffer to receive into; must be on `self.device`.
            src: peer rank, passed to NCCL unchanged (presumably numbered
                like `self.rank`, i.e. the rank within the group).
            stream: stream to enqueue the call on. If None, the current
                CUDA stream is used.

        Does nothing when the communicator is disabled.
        """
        if self.disabled:
            return
        self._check_device(tensor)
        cuda_stream = self._to_cuda_stream(stream)
        self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                           ncclDataTypeEnum.from_torch(tensor.dtype), src,
                           self.comm, cuda_stream)

    def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
        """Broadcast `tensor` from rank `src` in place via `ncclBroadcast`.

        Args:
            tensor: on rank `src` the data to send, on every other rank
                the buffer that receives it; must be on `self.device`.
            src: rank of the sender, compared against `self.rank`.
            stream: stream to enqueue the call on. If None, the current
                CUDA stream is used.

        Does nothing when the communicator is disabled.
        """
        if self.disabled:
            return
        self._check_device(tensor)
        cuda_stream = self._to_cuda_stream(stream)
        # Every rank receives into `tensor`.
        # NCCL requires the sender also to have a receive buffer
        recvbuff = buffer_type(tensor.data_ptr())
        if src == self.rank:
            sendbuff = buffer_type(tensor.data_ptr())
        else:
            # Receivers pass an empty send buffer.
            sendbuff = buffer_type()
        self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
                                ncclDataTypeEnum.from_torch(tensor.dtype), src,
                                self.comm, cuda_stream)