from collections import namedtuple
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch.distributed import ProcessGroup

from .parallel_state import (get_cpu_world_group,
                             get_tensor_model_parallel_group,
                             get_tensor_model_parallel_rank,
                             get_tensor_model_parallel_world_size,
                             get_tp_pynccl_communicator)


@contextmanager
def graph_capture_mode():
    """Enable the tensor-parallel PyNccl communicator for CUDA graph capture.

    While the context is active, the communicator is switched to the enabled
    state and bound to the CUDA stream that is current when the context is
    entered.
    """
    # Collective operations need care during graph capture. Current status:
    #     allreduce \ Mode   |  Eager  |  Graph  |
    # --------------------------------------------
    # custom allreduce       | enabled | enabled |
    # PyNccl                 | disabled| enabled |
    # torch.distributed      | enabled | disabled|
    #
    # Custom allreduce performs a runtime size check; tensors that are too
    # large fall back to the next available option.
    comm = get_tp_pynccl_communicator()
    assert comm is not None
    capture_stream = torch.cuda.current_stream()
    with comm.change_state(enable=True, stream=capture_stream):
        yield
32
33


34
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce ``input_`` across the tensor-parallel group.

    Backends are tried in this order:
      1. custom all-reduce, whose result is used if it is not ``None``;
      2. PyNccl, if a communicator exists and is not disabled;
      3. ``torch.distributed.all_reduce`` over the tensor-parallel group.

    NOTE: When custom all-reduce is disabled, the reduction happens in place
    on the input tensor. Otherwise it may or may not be in place. That
    depends on whether custom all-reduce is used for this tensor, which in
    turn depends on the tensor size and GPU topology.

    TLDR: always assume this function modifies its input, but use the return
    value as the output.
    """
    # Local import kept as in the original (presumably to avoid an import
    # cycle or an eager heavy import; TODO confirm).
    from vllm.distributed.device_communicators.custom_all_reduce import (
        custom_all_reduce)

    # A single GPU has nothing to reduce with.
    if get_tensor_model_parallel_world_size() == 1:
        return input_

    # A None result means custom all-reduce did not handle this tensor.
    reduced = custom_all_reduce(input_)
    if reduced is not None:
        return reduced

    pynccl_comm = get_tp_pynccl_communicator()
    pynccl_usable = pynccl_comm is not None and not pynccl_comm.disabled
    if not pynccl_usable:
        torch.distributed.all_reduce(input_,
                                     group=get_tensor_model_parallel_group())
    else:
        pynccl_comm.all_reduce(input_)
    # Both fallback paths write the result into `input_` itself.
    return input_


64
65
def tensor_model_parallel_all_gather(input_: torch.Tensor,
                                     dim: int = -1) -> torch.Tensor:
    """All-gather ``input_`` across the tensor-parallel group.

    Each rank's tensor is gathered, and the pieces are concatenated along
    ``dim`` in rank order. The result's size along ``dim`` is therefore
    ``world_size`` times the input's. With a single GPU the input is
    returned unchanged.
    """
    world_size = get_tensor_model_parallel_world_size()
    if world_size == 1:
        return input_

    ndim = input_.dim()
    assert -ndim <= dim < ndim, (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    # Normalise a negative dim to its non-negative equivalent.
    dim = dim + ndim if dim < 0 else dim

    shape = input_.size()
    # all_gather_into_tensor fills gathered[r] with rank r's tensor.
    gathered = torch.empty((world_size, *shape),
                           dtype=input_.dtype,
                           device=input_.device)
    torch.distributed.all_gather_into_tensor(
        gathered, input_, group=get_tensor_model_parallel_group())

    # Move the rank axis to just before `dim`, then fuse the two axes. The
    # per-rank chunks then end up back-to-back along `dim`.
    concat_shape = (*shape[:dim], world_size * shape[dim], *shape[dim + 1:])
    return gathered.movedim(0, dim).reshape(concat_shape)
90
91


92
93
94
def tensor_model_parallel_gather(input_: torch.Tensor,
                                 dst: int = 0,
                                 dim: int = -1) -> Optional[torch.Tensor]:
    """Gather the input tensor across model parallel group.

    Args:
        input_: This rank's contribution. The destination allocates its
            receive buffers with ``torch.empty_like(input_)``, so every rank
            is expected to pass a tensor of the same shape and dtype.
        dst: Rank *within the tensor-parallel group* that receives the
            result. It must satisfy ``0 <= dst < world_size``.
        dim: Dimension along which the gathered tensors are concatenated.
            Negative values count from the end.

    Returns:
        On the destination rank, every rank's ``input_`` concatenated along
        ``dim`` in rank order. On all other ranks, ``None``. With a single
        GPU the input is returned unchanged.

    NOTE: We assume that the input tensor is on the same device across
    all the ranks.
    """
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    assert 0 <= dst < world_size, f"Invalid dst rank ({dst})"
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()

    group = get_tensor_model_parallel_group()
    is_dst = get_tensor_model_parallel_rank() == dst
    # `dst` is a rank inside the tensor-parallel group, but
    # torch.distributed.gather expects a *global* rank. Without this
    # translation the gather goes to the wrong process, or fails, whenever
    # the group's first member is not global rank 0.
    global_dst = torch.distributed.get_process_group_ranks(group)[dst]

    # Only the destination needs receive buffers.
    gather_list = ([torch.empty_like(input_) for _ in range(world_size)]
                   if is_dst else None)
    torch.distributed.gather(input_,
                             gather_list,
                             dst=global_dst,
                             group=group)
    if not is_dst:
        return None
    return torch.cat(gather_list, dim=dim)


126
127
128
def broadcast(input_: torch.Tensor,
              src: int = 0,
              group: Optional[ProcessGroup] = None):
    """Broadcast ``input_`` from rank ``src`` across ``group``.

    Args:
        input_: Tensor to send (on ``src``) or receive into (on other ranks).
        src: Global rank of the sender. It must be a member of ``group``.
        group: Process group to use. Defaults to the world group.

    Returns:
        ``input_`` itself.
    """
    pg = group if group else torch.distributed.group.WORLD
    member_ranks = torch.distributed.get_process_group_ranks(pg)
    assert src in member_ranks, f"Invalid src rank ({src})"
    # A world size of exactly 1 means there is nobody to send to.
    if torch.distributed.get_world_size(group=pg) != 1:
        torch.distributed.broadcast(input_, src=src, group=pg)
    return input_


143
144
145
def broadcast_object_list(obj_list: List[Any],
                          src: int = 0,
                          group: Optional[ProcessGroup] = None):
    """Broadcast the picklable objects in ``obj_list`` from rank ``src``.

    Args:
        obj_list: Objects to send (on ``src``) or a list to fill (on other
            ranks). ``torch.distributed.broadcast_object_list`` updates it in
            place.
        src: Global rank of the sender. It must be a member of ``group``.
        group: Process group to use. Defaults to the world group.

    Returns:
        ``obj_list`` itself.
    """
    pg = group if group else torch.distributed.group.WORLD
    assert src in torch.distributed.get_process_group_ranks(pg), (
        f"Invalid src rank ({src})")
    # A world size of exactly 1 means there is nobody to send to.
    if torch.distributed.get_world_size(group=pg) != 1:
        torch.distributed.broadcast_object_list(obj_list, src=src, group=pg)
    return obj_list
158
159


160
# Description of a tensor (device type, dtype, size). _split_tensor_dict
# stores it in place of the tensor itself. The typename must stay
# "TensorMetadata" so that pickled instances resolve back to this class.
TensorMetadata = namedtuple("TensorMetadata", "device dtype size")
161
162


163
164
165
166
167
168
169
170
171
172
173
174
def _split_tensor_dict(
    tensor_dict: Dict[Any, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
    """Split the tensor dictionary into two parts:
    1. A list of (key, value) pairs. If the value is a tensor, it is replaced
         by its metadata.
    2. A list of tensors.
    """
    metadata_list = []
    tensor_list = []
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
175
176
177
178
179
180
181
            # Note: we cannot use `value.device` here,
            # because it contains not only the device type but also the device
            # index (e.g. "cuda:0"). We only need the device type.
            # receiving side will set the device index.
            device = "cpu" if value.is_cpu else "cuda"
            metadata_list.append(
                (key, TensorMetadata(device, value.dtype, value.size())))
182
183
184
185
186
187
            tensor_list.append(value)
        else:
            metadata_list.append((key, value))
    return metadata_list, tensor_list


188
189
190
191
def _broadcast_tensor_async(tensor: torch.Tensor, src: int,
                            group: ProcessGroup,
                            metadata_group: ProcessGroup):
    """Start an async broadcast of ``tensor`` and return its work handle.

    CPU tensors go over ``metadata_group``; all other tensors go over
    ``group``.
    """
    chosen_group = metadata_group if tensor.is_cpu else group
    return torch.distributed.broadcast(tensor,
                                       src=src,
                                       group=chosen_group,
                                       async_op=True)


def broadcast_tensor_dict(
    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
    src: int = 0,
    group: Optional[ProcessGroup] = None,
    metadata_group: Optional[ProcessGroup] = None
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
    """Broadcast a dictionary that may contain tensors from rank ``src``.

    ``group`` carries the (non-CPU) tensors. ``metadata_group`` carries the
    dict's metadata (structure, tensor sizes, dtypes) and any CPU tensors.

    Args:
        tensor_dict: The dictionary to send. Only read on ``src``.
        src: Global rank of the sender. It must be a member of ``group``.
        group: Defaults to the world group.
        metadata_group: Defaults to ``get_cpu_world_group()``.

    Returns:
        On ``src`` (or when the group has size 1), ``tensor_dict`` itself.
        On other ranks, a newly built dict with the same keys, whose tensor
        values were received from ``src``.
    """
    if not group:
        group = torch.distributed.group.WORLD
    if not metadata_group:
        metadata_group = get_cpu_world_group()
    assert src in torch.distributed.get_process_group_ranks(group), (
        f"Invalid src rank ({src})")

    # Bypass the function if we are using only 1 GPU.
    if torch.distributed.get_world_size(group=group) == 1:
        return tensor_dict

    if torch.distributed.get_rank() == src:
        assert isinstance(
            tensor_dict,
            dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # The metadata lives in CPU memory, and broadcast_object_list
        # serializes and deserializes it on the CPU, so metadata_group is
        # used for it.
        torch.distributed.broadcast_object_list([metadata_list],
                                                src=src,
                                                group=metadata_group)
        # Empty tensors are not broadcast; receivers skip them the same way.
        pending = [
            _broadcast_tensor_async(t, src, group, metadata_group)
            for t in tensor_list if t.numel() != 0
        ]
        for work in pending:
            work.wait()
        return tensor_dict

    received = [None]
    torch.distributed.broadcast_object_list(received,
                                            src=src,
                                            group=metadata_group)
    assert received[0] is not None

    result = {}
    pending = []
    for key, value in received[0]:
        if not isinstance(value, TensorMetadata):
            result[key] = value
            continue
        buffer = torch.empty(value.size,
                             dtype=value.dtype,
                             device=value.device)
        # Empty tensors were not broadcast by the sender; keep them as-is.
        if buffer.numel() != 0:
            pending.append(
                _broadcast_tensor_async(buffer, src, group, metadata_group))
        result[key] = buffer
    for work in pending:
        work.wait()
    return result