communication_op.py 7.66 KB
Newer Older
1
2
3
from collections import namedtuple
from typing import Any, Dict, List, Optional, Union

4
5
from torch.distributed import ProcessGroup

6
7
8
import torch

from vllm.model_executor.parallel_utils.parallel_state import (
9
    get_tensor_model_parallel_rank,
10
11
12
    get_tensor_model_parallel_world_size,
    get_tensor_model_parallel_group,
)
13
from vllm.model_executor.parallel_utils.custom_all_reduce import custom_all_reduce
14
15


16
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce ``input_`` across the tensor model parallel group.

    With a world size of 1 the input is returned untouched. Otherwise the
    custom all-reduce path is tried first; when it returns ``None`` the
    function falls back to ``torch.distributed.all_reduce``, which reduces
    ``input_`` in place.

    NOTE: If disable_custom_all_reduce is set to True the reduction is
    always performed in place on ``input_``. Otherwise it may or may not be
    in place, depending on whether custom all reduce is invoked for the
    given tensor, which in turn depends on the tensor size and the GPU
    topology.

    TLDR: always assume this function modifies its input, but use the
    return value as the output.

    Args:
        input_: Tensor to reduce.

    Returns:
        The reduced tensor (possibly the same object as ``input_``).
    """
    # Nothing to reduce when running on a single GPU.
    if get_tensor_model_parallel_world_size() == 1:
        return input_

    reduced = custom_all_reduce(input_)
    if reduced is None:
        # Custom path declined this tensor: use the regular in-place
        # collective instead.
        torch.distributed.all_reduce(input_,
                                     group=get_tensor_model_parallel_group())
        reduced = input_
    return reduced


39
40
def tensor_model_parallel_all_gather(input_: torch.Tensor,
                                     dim: int = -1) -> torch.Tensor:
    """All-gather ``input_`` across the tensor model parallel group.

    Every rank contributes its ``input_``; the contributions are joined
    along ``dim`` so that the returned tensor has the shape of ``input_``
    with dimension ``dim`` multiplied by the world size.

    Args:
        input_: Local tensor to contribute.
        dim: Dimension along which the per-rank tensors are joined.
            Negative values count from the last dimension.

    Returns:
        ``input_`` itself when the world size is 1, otherwise a newly
        allocated tensor holding the gathered data.

    Raises:
        AssertionError: If ``dim`` is out of range for ``input_``.
    """
    num_ranks = get_tensor_model_parallel_world_size()
    # Nothing to gather when running on a single GPU.
    if num_ranks == 1:
        return input_

    ndim = input_.dim()
    assert -ndim <= dim < ndim, (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    # Normalise a negative dim to its non-negative equivalent.
    axis = dim + ndim if dim < 0 else dim

    shape = tuple(input_.size())
    # Receive buffer with one leading slot per rank.
    stacked = torch.empty((num_ranks, *shape),
                          dtype=input_.dtype,
                          device=input_.device)
    torch.distributed.all_gather_into_tensor(
        stacked, input_, group=get_tensor_model_parallel_group())

    # Move the per-rank axis next to ``axis`` and fold the two together,
    # which is equivalent to concatenating the rank slices along ``axis``.
    merged_shape = (shape[:axis] + (num_ranks * shape[axis], ) +
                    shape[axis + 1:])
    return stacked.movedim(0, axis).reshape(merged_shape)
65
66


67
68
69
def tensor_model_parallel_gather(input_: torch.Tensor,
                                 dst: int = 0,
                                 dim: int = -1) -> torch.Tensor:
    """Gather ``input_`` from all tensor model parallel ranks onto ``dst``.

    NOTE: We assume that the input tensor is on the same device across
    all the ranks.

    NOTE(review): ``dst`` is compared with the tensor model parallel rank
    and is also forwarded to ``torch.distributed.gather``, which takes a
    global rank. The two agree only when group-local and global ranks
    coincide -- confirm against callers.

    Args:
        input_: Local tensor to contribute.
        dst: Rank that receives the gathered result.
        dim: Dimension along which the per-rank tensors are concatenated
            on the destination. Negative values count from the end.

    Returns:
        ``input_`` itself when the world size is 1. Otherwise the
        concatenated tensor on the destination rank and ``None`` on every
        other rank.

    Raises:
        AssertionError: If ``dim`` is out of range for ``input_``.
    """
    num_ranks = get_tensor_model_parallel_world_size()
    # Nothing to gather when running on a single GPU.
    if num_ranks == 1:
        return input_

    ndim = input_.dim()
    assert -ndim <= dim < ndim, (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    # Normalise a negative dim to its non-negative equivalent.
    axis = dim + ndim if dim < 0 else dim

    is_destination = get_tensor_model_parallel_rank() == dst
    # Only the destination needs receive buffers; the others pass None.
    buffers = None
    if is_destination:
        buffers = [torch.empty_like(input_) for _ in range(num_ranks)]

    torch.distributed.gather(input_,
                             buffers,
                             dst=dst,
                             group=get_tensor_model_parallel_group())

    if not is_destination:
        return None
    return torch.cat(buffers, dim=axis)


101
102
103
def broadcast(input_: torch.Tensor,
              src: int = 0,
              group: Optional[ProcessGroup] = None):
    """Broadcast ``input_`` from rank ``src`` to every rank in ``group``.

    Args:
        input_: Tensor to send on ``src`` and to be overwritten with the
            received data on the other ranks.
        src: Source rank; must be one of the ranks of ``group``.
        group: Process group to use. Falls back to the default world
            group when not given.

    Returns:
        ``input_`` (the same object that was passed in).

    Raises:
        AssertionError: If ``src`` is not a member of ``group``.
    """
    pg = group if group else torch.distributed.group.WORLD
    member_ranks = torch.distributed.get_process_group_ranks(pg)
    assert src in member_ranks, f"Invalid src rank ({src})"

    # The collective is skipped when we are using only 1 GPU.
    if torch.distributed.get_world_size(group=pg) != 1:
        torch.distributed.broadcast(input_, src=src, group=pg)
    return input_


118
119
120
def broadcast_object_list(obj_list: List[Any],
                          src: int = 0,
                          group: Optional[ProcessGroup] = None):
    """Broadcast a list of Python objects from ``src`` to ``group``.

    Args:
        obj_list: Objects to send on ``src``; on the other ranks the list
            is handed to ``torch.distributed.broadcast_object_list`` to be
            filled with the received objects.
        src: Source rank; must be one of the ranks of ``group``.
        group: Process group to use. Falls back to the default world
            group when not given.

    Returns:
        ``obj_list`` (the same list object that was passed in).

    Raises:
        AssertionError: If ``src`` is not a member of ``group``.
    """
    pg = group if group else torch.distributed.group.WORLD
    member_ranks = torch.distributed.get_process_group_ranks(pg)
    assert src in member_ranks, f"Invalid src rank ({src})"

    # The collective is skipped when we are using only 1 GPU.
    if torch.distributed.get_world_size(group=pg) != 1:
        torch.distributed.broadcast_object_list(obj_list, src=src, group=pg)
    return obj_list
133
134
135
136
137


# Lightweight stand-in for a tensor in broadcast metadata: it records only
# the dtype and size, so a receiver can allocate a matching buffer before
# the tensor data itself is broadcast.
TensorMetadata = namedtuple("TensorMetadata", ("dtype", "size"))


138
139
140
141
142
def broadcast_tensor_dict(
    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
    src: int = 0,
    group: Optional[ProcessGroup] = None,
) -> Dict[Any, Union[torch.Tensor, Any]]:
    """Broadcast a dictionary that may contain tensors from ``src``.

    The transfer happens in two phases:

    1. The source rank sends a list of ``(key, value)`` pairs with
       ``broadcast_object_list``. Tensor values are replaced by a
       ``TensorMetadata`` (dtype and size); all other values are sent
       as-is.
    2. Each tensor is then broadcast individually, in the order of the
       metadata list. Receivers allocate matching buffers on ``"cuda"``
       from the metadata and rebuild the dictionary.

    Both phases use the same ``group`` on every rank, so that sender and
    receivers take part in the same collectives.

    Args:
        tensor_dict: Dictionary to send. Only read on the source rank
            (where it must be a ``dict``); ignored on receiving ranks.
        src: Source rank; must be one of the ranks of ``group``.
        group: Process group to use. Falls back to the default world
            group when not given.

    Returns:
        ``tensor_dict`` unchanged when the group has a single rank or on
        the source rank; a newly built dictionary on receiving ranks.

    Raises:
        AssertionError: If ``src`` is not a member of ``group``; on the
            source rank, if ``tensor_dict`` is not a dict or one of its
            tensor values is not on cuda.
    """
    group = group or torch.distributed.group.WORLD
    ranks = torch.distributed.get_process_group_ranks(group)
    assert src in ranks, f"Invalid src rank ({src})"

    # Bypass the function if we are using only 1 GPU.
    world_size = torch.distributed.get_world_size(group=group)
    if world_size == 1:
        return tensor_dict

    rank = torch.distributed.get_rank()
    if rank == src:
        assert isinstance(
            tensor_dict,
            dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
        # Phase 1: describe every entry; tensors are sent as metadata only.
        metadata_list = []
        for key, value in tensor_dict.items():
            if isinstance(value, torch.Tensor):
                assert value.is_cuda, (
                    f"Tensor {key}: {value} is not on cuda. Currently we only "
                    f"support broadcasting tensors on cuda.")
                metadata_list.append(
                    (key, TensorMetadata(value.dtype, value.size())))
            else:
                metadata_list.append((key, value))
        torch.distributed.broadcast_object_list([metadata_list],
                                                src=src,
                                                group=group)
        # Phase 2: send the tensor payloads. The group must be passed
        # explicitly here: the receivers broadcast on ``group``, and
        # omitting it would run the sender's collective on the default
        # group instead.
        for key, value in metadata_list:
            if isinstance(value, TensorMetadata):
                tensor = tensor_dict[key]
                torch.distributed.broadcast(tensor, src=src, group=group)
    else:
        # Phase 1: receive the entry descriptions.
        recv_metadata_list = [None]
        torch.distributed.broadcast_object_list(recv_metadata_list,
                                                src=src,
                                                group=group)
        metadata_list = recv_metadata_list[0]
        tensor_dict = {}
        async_handles = []
        # Phase 2: allocate a buffer per tensor and start all receives
        # before waiting on any of them.
        for key, value in metadata_list:
            if isinstance(value, TensorMetadata):
                tensor = torch.empty(value.size,
                                     dtype=value.dtype,
                                     device="cuda")
                async_handle = torch.distributed.broadcast(tensor,
                                                           src=src,
                                                           async_op=True,
                                                           group=group)
                async_handles.append(async_handle)
                tensor_dict[key] = tensor
            else:
                tensor_dict[key] = value
        for async_handle in async_handles:
            async_handle.wait()
    return tensor_dict