from collections import namedtuple
from typing import Any, Dict, List, Optional, Union

import torch

from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    get_tensor_model_parallel_group,
)


def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce the input tensor across the model parallel group.

    NOTE: This operation is applied in-place on the input tensor.
    """
    # Bypass the function if we are using only 1 GPU.
    if get_tensor_model_parallel_world_size() == 1:
        return input_
    # All-reduce.
    torch.distributed.all_reduce(input_,
                                 group=get_tensor_model_parallel_group())
    return input_
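

# Hypothetical usage sketch (not part of the original module): how a
# row-parallel linear layer might combine per-rank partial products with
# tensor_model_parallel_all_reduce. Assumes torch.distributed and the tensor
# model parallel group are already initialized; shapes are illustrative only.
def _example_all_reduce_usage() -> torch.Tensor:
    x = torch.randn(4, 8, device="cuda")  # activations, replicated per rank
    weight_shard = torch.randn(8, 16, device="cuda")  # this rank's shard
    partial = x @ weight_shard  # partial product on this rank
    # Summing the partials across the group yields the full matmul result.
    return tensor_model_parallel_all_reduce(partial)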


def tensor_model_parallel_all_gather(input_: torch.Tensor,
                                     dim: int = -1) -> torch.Tensor:
    """All-gather the input tensor across model parallel group."""
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()
    input_size = input_.size()
    # Allocate output tensor.
    output_tensor = torch.empty((world_size, ) + input_size,
                                dtype=input_.dtype,
                                device=input_.device)
    # All-gather.
    torch.distributed.all_gather_into_tensor(
        output_tensor, input_, group=get_tensor_model_parallel_group())
    # Reshape: move the gathered world-size axis back to `dim` and merge it
    # into that dimension.
    output_tensor = output_tensor.movedim(0, dim)
    output_tensor = output_tensor.reshape(input_size[:dim] +
                                          (world_size * input_size[dim], ) +
                                          input_size[dim + 1:])
    return output_tensor
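

# Hypothetical usage sketch (not part of the original module): a
# column-parallel layer produces a shard of the output features on each rank;
# gathering along the last dim reconstructs the full output on every rank.
# Assumes an initialized tensor model parallel group; shapes are illustrative.
def _example_all_gather_usage() -> torch.Tensor:
    local_out = torch.randn(4, 16, device="cuda")  # this rank's output shard
    # Result has shape (4, 16 * world_size) and is identical on all ranks.
    return tensor_model_parallel_all_gather(local_out, dim=-1)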


def tensor_model_parallel_gather(input_: torch.Tensor,
                                 dst: int = 0,
                                 dim: int = -1) -> Optional[torch.Tensor]:
    """Gather the input tensor across model parallel group.

    NOTE: We assume that the input tensor is on the same device across
    all the ranks.
    """
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()
    # Allocate output tensor.
    if get_tensor_model_parallel_rank() == dst:
        gather_list = [torch.empty_like(input_) for _ in range(world_size)]
    else:
        gather_list = None
    # Gather.
    torch.distributed.gather(input_,
                             gather_list,
                             dst=dst,
                             group=get_tensor_model_parallel_group())
    if get_tensor_model_parallel_rank() == dst:
        output_tensor = torch.cat(gather_list, dim=dim)
    else:
        output_tensor = None
    return output_tensor
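

# Hypothetical usage sketch (not part of the original module): gathering
# per-rank logit shards onto the driver rank, e.g. before sampling. Only the
# destination rank receives the concatenated tensor; every other rank gets
# None. Assumes an initialized tensor model parallel group; shapes are
# illustrative only.
def _example_gather_usage() -> Optional[torch.Tensor]:
    local_logits = torch.randn(4, 1000, device="cuda")  # this rank's shard
    # On dst, shape is (4, 1000 * world_size); elsewhere the result is None.
    return tensor_model_parallel_gather(local_logits, dst=0, dim=-1)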


def broadcast(input_: torch.Tensor, src: int = 0) -> torch.Tensor:
    """Broadcast the input tensor."""
    world_size = torch.distributed.get_world_size()
    assert 0 <= src < world_size, f"Invalid src rank ({src})"

    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_
    # Broadcast.
    torch.distributed.broadcast(input_, src=src)
    return input_
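

# Hypothetical usage sketch (not part of the original module): broadcast
# fills a pre-allocated tensor in-place, so every rank must allocate the same
# shape and dtype before the call. Assumes torch.distributed is initialized
# with a CUDA-capable backend.
def _example_broadcast_usage() -> torch.Tensor:
    buf = torch.empty(4, 8, device="cuda")
    if torch.distributed.get_rank() == 0:
        buf.normal_()  # only the source rank holds meaningful data
    return broadcast(buf, src=0)  # now identical on every rank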


def broadcast_object_list(obj_list: List[Any], src: int = 0) -> List[Any]:
    """Broadcast the input object list."""
    world_size = torch.distributed.get_world_size()
    assert 0 <= src < world_size, f"Invalid src rank ({src})"

    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return obj_list
    # Broadcast.
    torch.distributed.broadcast_object_list(obj_list, src=src)
    return obj_list
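

# Hypothetical usage sketch (not part of the original module): broadcasting
# small picklable Python objects (e.g. per-step metadata) from the source
# rank. The list is filled in-place on the receiving ranks, so they pass
# placeholders of the same length. Assumes torch.distributed is initialized.
def _example_broadcast_object_list_usage() -> List[Any]:
    if torch.distributed.get_rank() == 0:
        objs: List[Any] = [{"temperature": 0.7}, "request-0"]
    else:
        objs = [None, None]
    return broadcast_object_list(objs, src=0)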


TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])


def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
                                                                Any]]] = None,
                          src: int = 0) -> Dict[Any, Union[torch.Tensor, Any]]:
    """Broadcast the input tensor dictionary."""
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    assert 0 <= src < world_size, f"Invalid src rank ({src})"

    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return tensor_dict

    if rank == src:
        assert isinstance(
            tensor_dict,
            dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
        metadata_list = []
        for key, value in tensor_dict.items():
            if isinstance(value, torch.Tensor):
                assert value.is_cuda, (
                    f"Tensor {key}: {value} is not on cuda. Currently we only "
                    f"support broadcasting tensors on cuda.")
                metadata_list.append(
                    (key, TensorMetadata(value.dtype, value.size())))
            else:
                metadata_list.append((key, value))
        # Broadcast the picklable metadata first so that the receiving ranks
        # can allocate matching buffers for the tensor broadcasts below.
        torch.distributed.broadcast_object_list([metadata_list], src=src)
        for key, value in metadata_list:
            if isinstance(value, TensorMetadata):
                tensor = tensor_dict[key]
                torch.distributed.broadcast(tensor, src=src)
    else:
        recv_metadata_list = [None]
        torch.distributed.broadcast_object_list(recv_metadata_list, src=src)
        metadata_list = recv_metadata_list[0]
        tensor_dict = {}
        # Launch the tensor broadcasts asynchronously and wait on all of them
        # at the end, so the receives can overlap with one another.
        async_handles = []
        for key, value in metadata_list:
            if isinstance(value, TensorMetadata):
                tensor = torch.empty(value.size,
                                     dtype=value.dtype,
                                     device="cuda")
                async_handle = torch.distributed.broadcast(tensor,
                                                           src=src,
                                                           async_op=True)
                async_handles.append(async_handle)
                tensor_dict[key] = tensor
            else:
                tensor_dict[key] = value
        for async_handle in async_handles:
            async_handle.wait()
    return tensor_dict
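

# Hypothetical usage sketch (not part of the original module): the driver
# rank broadcasts a mixed dict of CUDA tensors and plain Python values; the
# other ranks pass None and get back a reconstructed dict with freshly
# allocated tensors. Assumes torch.distributed is initialized with a CUDA
# backend; keys and shapes are illustrative only.
def _example_broadcast_tensor_dict_usage() -> Dict[Any, Any]:
    if torch.distributed.get_rank() == 0:
        payload: Optional[Dict[Any, Any]] = {
            "input_ids": torch.ones(4, 16, dtype=torch.long, device="cuda"),
            "step": 3,
        }
    else:
        payload = None
    return broadcast_tensor_dict(payload, src=0)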