communication_op.py 10.5 KB
Newer Older
1
from collections import namedtuple
2
from typing import Any, Dict, List, Optional, Tuple, Union
3

4
import torch
Woosuk Kwon's avatar
Woosuk Kwon committed
5
from torch.distributed import ProcessGroup
6

7
8
from .parallel_state import (get_cpu_world_group,
                             get_tensor_model_parallel_group,
9
10
11
                             get_tensor_model_parallel_rank,
                             get_tensor_model_parallel_world_size,
                             is_pynccl_enabled_for_all_reduce)
12
13


14
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """Sum-reduce `input_` across the tensor model parallel group.

    Whether the reduction happens in place depends on the path taken:
    - If custom all-reduce is disabled (disable_custom_all_reduce=True), the
      reduction is always applied in place on `input_`.
    - Otherwise, a custom all-reduce may produce a new tensor. Whether it is
      invoked depends on the tensor size and GPU topology.

    Callers should therefore treat `input_` as clobbered and always use the
    returned tensor as the result.
    """
    # Deferred imports (presumably to avoid an import cycle at module load
    # time -- confirm before hoisting them to the top of the file).
    from vllm.distributed.device_communicators import pynccl_utils
    from vllm.distributed.device_communicators.custom_all_reduce import (
        custom_all_reduce)

    # A single-rank tensor parallel group has nothing to reduce.
    if get_tensor_model_parallel_world_size() == 1:
        return input_

    # Try the custom all-reduce first. A non-None result is the reduced
    # tensor and is handed back directly.
    custom_result = custom_all_reduce(input_)
    if custom_result is not None:
        return custom_result

    # Both fallback paths are expected to leave the result in `input_`
    # itself; their return values are not used.
    if not is_pynccl_enabled_for_all_reduce():
        torch.distributed.all_reduce(input_,
                                     group=get_tensor_model_parallel_group())
    else:
        pynccl_utils.all_reduce(input_)
    return input_


44
45
def tensor_model_parallel_all_gather(input_: torch.Tensor,
                                     dim: int = -1) -> torch.Tensor:
    """All-gather `input_` across the tensor model parallel group.

    The per-rank tensors are concatenated along `dim` in rank order. The
    result is therefore `world_size` times larger than `input_` along that
    dimension. When the group has a single rank, `input_` itself is returned.
    """
    world_size = get_tensor_model_parallel_world_size()
    if world_size == 1:
        return input_

    ndim = input_.dim()
    assert -ndim <= dim < ndim, (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    # Normalize a negative dim so it can be used for shape slicing below.
    if dim < 0:
        dim = dim + ndim

    shape = input_.size()
    # Stacked receive buffer: entry i along the new leading axis holds
    # rank i's tensor.
    gathered = torch.empty((world_size, ) + shape,
                           dtype=input_.dtype,
                           device=input_.device)
    torch.distributed.all_gather_into_tensor(
        gathered, input_, group=get_tensor_model_parallel_group())

    # Move the rank axis in front of `dim` and merge the two axes. This
    # concatenates the per-rank chunks along `dim` in rank order.
    merged_shape = (shape[:dim] + (world_size * shape[dim], ) +
                    shape[dim + 1:])
    return gathered.movedim(0, dim).reshape(merged_shape)
70
71


72
73
74
def tensor_model_parallel_gather(input_: torch.Tensor,
                                 dst: int = 0,
                                 dim: int = -1) -> Optional[torch.Tensor]:
    """Gather the input tensor across model parallel group onto one rank.

    NOTE: We assume that the input tensor is on the same device across
    all the ranks.

    Args:
        input_: The tensor contributed by this rank.
        dst: Rank *within the tensor model parallel group* that receives
            the result.
        dim: Dimension along which the gathered tensors are concatenated.

    Returns:
        On rank `dst`: the per-rank tensors concatenated along `dim` in rank
        order. On every other rank: None. With a single-rank group: `input_`
        itself.
    """
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    assert 0 <= dst < world_size, f"Invalid dst rank ({dst})"
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()

    group = get_tensor_model_parallel_group()
    is_dst = get_tensor_model_parallel_rank() == dst
    # `dst` is a group-local rank (it is compared with the TP rank above).
    # torch.distributed.gather, however, expects a *global* rank. Translate
    # it so the call is also correct when the TP group does not start at
    # global rank 0.
    dst_global = torch.distributed.get_process_group_ranks(group)[dst]

    # Only the destination rank allocates receive buffers.
    if is_dst:
        gather_list = [torch.empty_like(input_) for _ in range(world_size)]
    else:
        gather_list = None
    torch.distributed.gather(input_,
                             gather_list,
                             dst=dst_global,
                             group=group)
    if not is_dst:
        return None
    return torch.cat(gather_list, dim=dim)


106
107
108
def broadcast(input_: torch.Tensor,
              src: int = 0,
              group: Optional[ProcessGroup] = None):
    """Broadcast `input_` in place from global rank `src` to all of `group`.

    `group` defaults to the WORLD process group, and `src` must be one of its
    members. The same tensor object is returned on every rank. For a
    one-member group no communication happens.
    """
    if not group:
        group = torch.distributed.group.WORLD
    member_ranks = torch.distributed.get_process_group_ranks(group)
    assert src in member_ranks, f"Invalid src rank ({src})"

    # Only launch the collective when there is more than one participant.
    if torch.distributed.get_world_size(group=group) > 1:
        torch.distributed.broadcast(input_, src=src, group=group)
    return input_


123
124
125
def broadcast_object_list(obj_list: List[Any],
                          src: int = 0,
                          group: Optional[ProcessGroup] = None):
    """Broadcast the picklable objects in `obj_list` from global rank `src`.

    `obj_list` is filled in place on the receiving ranks and is also returned.
    `group` defaults to the WORLD process group, and `src` must belong to it.
    For a one-member group the list is returned untouched.
    """
    if not group:
        group = torch.distributed.group.WORLD
    assert src in torch.distributed.get_process_group_ranks(group), (
        f"Invalid src rank ({src})")

    # The object broadcast serializes and deserializes the objects; skip it
    # when nobody else would receive them.
    if torch.distributed.get_world_size(group=group) != 1:
        torch.distributed.broadcast_object_list(obj_list, src=src, group=group)
    return obj_list
138
139


140
# Description of a tensor (device type, dtype, size) that is sent in place of
# the tensor itself. The receiver uses it to allocate a matching buffer.
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])


def _split_tensor_dict(
    tensor_dict: Dict[Any, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
    """Split the tensor dictionary into two parts:
    1. A list of (key, value) pairs, in the dict's iteration order. If the
       value is a tensor, it is replaced by its `TensorMetadata`.
    2. A list of the tensors themselves, in the same relative order.
    """
    metadata_list: List[Tuple[str, Any]] = []
    tensor_list: List[torch.Tensor] = []
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            # Record only the device *type* ("cpu", "cuda", ...), not the full
            # `value.device`, which also carries an index such as "cuda:0".
            # The receiving side sets the device index. Using `device.type`
            # (instead of a cpu/cuda binary choice) also keeps tensors on
            # other device types from being mislabeled as "cuda".
            metadata = TensorMetadata(value.device.type, value.dtype,
                                      value.size())
            metadata_list.append((key, metadata))
            tensor_list.append(value)
        else:
            metadata_list.append((key, value))
    return metadata_list, tensor_list


168
169
170
171
def _broadcast_tensor_async(tensor: torch.Tensor, src: int,
                            group: ProcessGroup,
                            metadata_group: ProcessGroup):
    """Start a non-blocking broadcast of `tensor` and return its work handle.

    CPU tensors go over `metadata_group`; all other tensors go over `group`.
    """
    comm_group = metadata_group if tensor.is_cpu else group
    return torch.distributed.broadcast(tensor,
                                       src=src,
                                       group=comm_group,
                                       async_op=True)


def broadcast_tensor_dict(
    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
    src: int = 0,
    group: Optional[ProcessGroup] = None,
    metadata_group: Optional[ProcessGroup] = None
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
    """Broadcast a dictionary of tensors and plain values from rank `src`.

    Two groups are used:
    - `group` broadcasts the non-CPU tensors.
    - `metadata_group` broadcasts the metadata of the dict (e.g. dict
      structure, tensor sizes, dtypes) and any CPU tensors.

    Args:
        tensor_dict: The dictionary to send. It must be a dict on rank `src`;
            on other ranks it is replaced by the received dictionary.
        src: Global rank that owns `tensor_dict`. It must belong to `group`.
        group: Group for non-CPU tensor data. Defaults to the WORLD group.
        metadata_group: Group for the metadata and CPU tensors. Defaults to
            `get_cpu_world_group()`.

    Returns:
        - On `src`: `tensor_dict` itself.
        - On other ranks: a new dict with the same keys, whose tensors are
          filled by the broadcast.
        - If `group` has a single member: `tensor_dict` unchanged (possibly
          None).
    """
    group = group or torch.distributed.group.WORLD
    # NOTE(review): the default metadata group is presumably the full CPU
    # world group. If `group` is a strict subgroup, every rank of that
    # default group must still call this function. Confirm that callers pass
    # a matching `metadata_group` in that case.
    metadata_group = metadata_group or get_cpu_world_group()
    ranks = torch.distributed.get_process_group_ranks(group)
    assert src in ranks, f"Invalid src rank ({src})"

    # Bypass the function if we are using only 1 GPU.
    world_size = torch.distributed.get_world_size(group=group)
    if world_size == 1:
        return tensor_dict

    rank = torch.distributed.get_rank()
    async_handles = []
    if rank == src:
        assert isinstance(
            tensor_dict,
            dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `broadcast_object_list` involves serialization and deserialization,
        # all happening on CPU. Therefore, we can use the CPU group.
        torch.distributed.broadcast_object_list([metadata_list],
                                                src=src,
                                                group=metadata_group)
        for tensor in tensor_list:
            # Both sides skip empty tensors, so the sequence of broadcast
            # calls stays matched across ranks.
            if tensor.numel() == 0:
                continue
            async_handles.append(
                _broadcast_tensor_async(tensor, src, group, metadata_group))
    else:
        recv_metadata_list: List[Any] = [None]
        torch.distributed.broadcast_object_list(recv_metadata_list,
                                                src=src,
                                                group=metadata_group)
        assert recv_metadata_list[0] is not None
        tensor_dict = {}
        for key, value in recv_metadata_list[0]:
            if not isinstance(value, TensorMetadata):
                tensor_dict[key] = value
                continue
            # Allocate a receive buffer matching the sender's tensor. Only
            # the device type is transmitted; the index comes from this rank.
            tensor = torch.empty(value.size,
                                 dtype=value.dtype,
                                 device=value.device)
            tensor_dict[key] = tensor
            # Mirror the sender: empty tensors are not broadcast.
            if tensor.numel() == 0:
                continue
            async_handles.append(
                _broadcast_tensor_async(tensor, src, group, metadata_group))

    # Wait for all outstanding broadcasts before returning the tensors.
    for async_handle in async_handles:
        async_handle.wait()
    return tensor_dict