communication_op.py 12.1 KB
Newer Older
1
from collections import namedtuple
2
from contextlib import contextmanager, nullcontext
3
from typing import Any, Dict, List, Optional, Tuple, Union
4

5
import torch
Woosuk Kwon's avatar
Woosuk Kwon committed
6
from torch.distributed import ProcessGroup
7

8
9
from .parallel_state import (get_cpu_world_group,
                             get_tensor_model_parallel_group,
10
11
                             get_tensor_model_parallel_rank,
                             get_tensor_model_parallel_world_size,
12
                             get_tp_ca_communicator,
13
14
15
16
                             get_tp_pynccl_communicator)


@contextmanager
def graph_mode():
    """Make the tensor parallel PyNccl communicator usable in graph mode.

    If a PyNccl communicator exists for the tensor parallel group, it is
    switched on via ``change_state(enable=True, ...)`` (with the current CUDA
    stream passed as ``stream``) for the duration of the ``with`` block.
    Otherwise this context manager does nothing.
    """
    # In graph mode, we have to be very careful about the collective
    # operations. The current status is:
    #     allreduce \ Mode   |  Eager  |  Graph  |
    # --------------------------------------------
    # custom allreduce       | enabled | enabled |
    # PyNccl                 | disabled| enabled |
    # torch.distributed      | enabled | disabled|
    #
    # Custom allreduce performs a runtime check: if the tensor is too large,
    # it falls back to the next available option.
    #
    # In summary: under CUDA graph we use either the custom all-reduce kernel
    # or PyNccl; outside CUDA graph we use either the custom all-reduce kernel
    # or PyTorch NCCL. The custom kernel always has priority, with a fallback
    # to PyTorch or PyNccl when it is disabled or not supported.
    comm = get_tp_pynccl_communicator()
    if comm is None:
        # No PyNccl communicator: behave as a no-op context.
        yield
        return
    with comm.change_state(enable=True, stream=torch.cuda.current_stream()):
        yield


@contextmanager
def graph_capture():
    """Context manager that should wrap the code capturing a CUDA graph.

    When a custom all-reduce communicator exists for the tensor parallel
    group, its ``capture()`` context is entered for the duration of the
    ``with`` block; otherwise this is a no-op. Its main purpose is to ensure
    that some operations run after the graph is captured and before the graph
    is replayed.
    """
    custom_comm = get_tp_ca_communicator()
    if custom_comm is None:
        yield
        return
    with custom_comm.capture():
        yield
55
56


57
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce ``input_`` across the tensor model parallel group.

    Backends are tried in this order:
      1. the custom all-reduce kernel, if a communicator exists and it
         returns a (non-None) result for this tensor;
      2. PyNccl, if its communicator exists and is not disabled;
      3. ``torch.distributed.all_reduce`` on the tensor parallel group.

    NOTE: Paths 2 and 3 reduce ``input_`` in place. Whether the custom kernel
    is used for a particular tensor depends on the tensor size and GPU
    topology, so the operation may or may not happen in place.

    TLDR: always assume this function modifies its input, but use the return
    value as the output.
    """
    custom_comm = get_tp_ca_communicator()

    # A single-GPU tensor parallel group has nothing to reduce.
    if get_tensor_model_parallel_world_size() == 1:
        return input_

    if custom_comm is not None:
        reduced = custom_comm.custom_all_reduce(input_)
        # None means the custom kernel was not used for this tensor; fall
        # through to the NCCL-based paths below.
        if reduced is not None:
            return reduced

    nccl_comm = get_tp_pynccl_communicator()
    use_pynccl = nccl_comm is not None and not nccl_comm.disabled
    if use_pynccl:
        nccl_comm.all_reduce(input_)
    else:
        torch.distributed.all_reduce(input_,
                                     group=get_tensor_model_parallel_group())
    return input_


87
88
def tensor_model_parallel_all_gather(input_: torch.Tensor,
                                     dim: int = -1) -> torch.Tensor:
    """All-gather ``input_`` across the tensor model parallel group.

    Every rank receives the tensors of all ranks concatenated along ``dim``
    in rank order, so the result's ``dim`` is ``world_size`` times larger.
    All ranks are assumed to contribute tensors of identical shape; the
    output buffer is sized from the local ``input_``.

    Args:
        input_: The local tensor to contribute.
        dim: Dimension to concatenate along; negative values count from the
            end.

    Returns:
        The gathered tensor, or ``input_`` itself when the group has a single
        rank.
    """
    world_size = get_tensor_model_parallel_world_size()
    # Single-GPU group: nothing to gather.
    if world_size == 1:
        return input_
    ndim = input_.dim()
    assert -ndim <= dim < ndim, (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    # Normalize a negative dim to its non-negative equivalent.
    dim = dim + ndim if dim < 0 else dim
    shape = input_.size()
    # all_gather_into_tensor fills one leading slot per rank.
    stacked = torch.empty((world_size, ) + shape,
                          dtype=input_.dtype,
                          device=input_.device)
    torch.distributed.all_gather_into_tensor(
        stacked, input_, group=get_tensor_model_parallel_group())
    # Moving the rank axis next to `dim` and merging the two is equivalent to
    # concatenating the per-rank tensors along `dim`.
    merged_shape = (shape[:dim] + (world_size * shape[dim], ) +
                    shape[dim + 1:])
    return stacked.movedim(0, dim).reshape(merged_shape)
113
114


115
116
117
def tensor_model_parallel_gather(input_: torch.Tensor,
                                 dst: int = 0,
                                 dim: int = -1) -> Optional[torch.Tensor]:
    """Gather the input tensor across model parallel group.

    The tensors from all ranks of the tensor model parallel group are
    concatenated along ``dim`` (in group-rank order) on the destination rank.

    NOTE: We assume that the input tensor is on the same device across
    all the ranks.

    Args:
        input_: The local tensor to contribute. Every rank is assumed to pass
            a tensor of the same shape and dtype, because the receive buffers
            on ``dst`` are allocated with ``torch.empty_like(input_)``.
        dst: Rank *within the tensor model parallel group* that receives the
            result. It is translated to the matching global rank before
            calling ``torch.distributed.gather``, whose ``dst`` argument is a
            global rank.
        dim: Dimension to concatenate along; negative values count from the
            end.

    Returns:
        The concatenated tensor on the destination rank and ``None`` on every
        other rank. If the group has a single rank, ``input_`` is returned
        unchanged.
    """
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    assert 0 <= dst < world_size, (
        f"Invalid dst rank ({dst}) for world size {world_size}")
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()
    group = get_tensor_model_parallel_group()
    # Look up the local rank once; it decides both buffer allocation and
    # whether this rank produces the output.
    is_dst = get_tensor_model_parallel_rank() == dst
    # Only the destination rank needs receive buffers.
    gather_list = ([torch.empty_like(input_) for _ in range(world_size)]
                   if is_dst else None)
    # `dst` is a group-local rank, but torch.distributed.gather expects a
    # global rank. get_process_group_ranks lists the group's global ranks
    # (in group-rank order); the two only coincide when the group's global
    # ranks are exactly 0..world_size-1.
    global_dst = torch.distributed.get_process_group_ranks(group)[dst]
    torch.distributed.gather(input_, gather_list, dst=global_dst, group=group)
    if not is_dst:
        return None
    return torch.cat(gather_list, dim=dim)


149
150
151
def broadcast(input_: torch.Tensor,
              src: int = 0,
              group: Optional[ProcessGroup] = None):
    """Broadcast ``input_`` in place from global rank ``src`` over ``group``.

    Args:
        input_: Tensor sent by ``src`` and overwritten on the other ranks.
        src: Global rank of the sender; it must be a member of ``group``.
        group: Process group to broadcast over; defaults to the world group.

    Returns:
        ``input_`` itself. No communication happens when the group has a
        single rank.
    """
    process_group = group or torch.distributed.group.WORLD
    member_ranks = torch.distributed.get_process_group_ranks(process_group)
    assert src in member_ranks, f"Invalid src rank ({src})"

    # A single-member group has nothing to exchange.
    if torch.distributed.get_world_size(group=process_group) != 1:
        torch.distributed.broadcast(input_, src=src, group=process_group)
    return input_


166
167
168
def broadcast_object_list(obj_list: List[Any],
                          src: int = 0,
                          group: Optional[ProcessGroup] = None):
    """Broadcast a list of objects in place from global rank ``src``.

    Thin wrapper around ``torch.distributed.broadcast_object_list``. On the
    receiving ranks the entries of ``obj_list`` are overwritten with the
    objects held by ``src``.

    Args:
        obj_list: Objects to send (on ``src``) or slots to fill (elsewhere).
        src: Global rank of the sender; it must be a member of ``group``.
        group: Process group to broadcast over; defaults to the world group.

    Returns:
        ``obj_list`` itself. No communication happens when the group has a
        single rank.
    """
    process_group = group or torch.distributed.group.WORLD
    member_ranks = torch.distributed.get_process_group_ranks(process_group)
    assert src in member_ranks, f"Invalid src rank ({src})"

    # A single-member group has nothing to exchange.
    if torch.distributed.get_world_size(group=process_group) != 1:
        torch.distributed.broadcast_object_list(obj_list,
                                                src=src,
                                                group=process_group)
    return obj_list
181
182


183
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
184
185


186
187
188
189
190
191
192
193
194
195
196
197
def _split_tensor_dict(
    tensor_dict: Dict[Any, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
    """Split the tensor dictionary into two parts:
    1. A list of (key, value) pairs. If the value is a tensor, it is replaced
         by its metadata.
    2. A list of tensors.
    """
    metadata_list = []
    tensor_list = []
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
198
199
200
201
202
203
204
            # Note: we cannot use `value.device` here,
            # because it contains not only the device type but also the device
            # index (e.g. "cuda:0"). We only need the device type.
            # receiving side will set the device index.
            device = "cpu" if value.is_cpu else "cuda"
            metadata_list.append(
                (key, TensorMetadata(device, value.dtype, value.size())))
205
206
207
208
209
210
            tensor_list.append(value)
        else:
            metadata_list.append((key, value))
    return metadata_list, tensor_list


211
212
213
214
def _broadcast_tensor_async(tensor: torch.Tensor, src: int,
                            group: ProcessGroup,
                            metadata_group: ProcessGroup):
    """Start an async broadcast of ``tensor`` from ``src``; return its handle.

    CPU tensors are sent over ``metadata_group`` and all other tensors over
    ``group``. Empty tensors are not broadcast at all, and ``None`` is
    returned for them.
    """
    if tensor.numel() == 0:
        # Skip broadcasting empty tensors.
        return None
    comm_group = metadata_group if tensor.is_cpu else group
    return torch.distributed.broadcast(tensor,
                                       src=src,
                                       group=comm_group,
                                       async_op=True)


def broadcast_tensor_dict(
    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
    src: int = 0,
    group: Optional[ProcessGroup] = None,
    metadata_group: Optional[ProcessGroup] = None
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
    """Broadcast the input tensor dictionary.

    `group` is used to broadcast the non-CPU tensors, while `metadata_group`
    is used to broadcast the metadata of the dict (e.g. dict structure,
    tensor sizes, dtypes) and any CPU tensors.

    On ``src``, the dict is split into a metadata list (tensors replaced by
    ``TensorMetadata``) and the tensors themselves. The metadata is sent with
    ``broadcast_object_list`` and each non-empty tensor with an async
    ``broadcast``. Other ranks rebuild the dict, allocating an empty tensor
    for every ``TensorMetadata`` entry and filling it from the broadcast.
    Empty tensors are allocated on the receivers but never sent.

    Args:
        tensor_dict: The dict to send. It must be a dict on ``src``; it is
            ignored on the other ranks (except for the bypass below).
        src: Global rank of the sender; it must belong to ``group``.
        group: Group for tensor broadcasts; defaults to the world group.
        metadata_group: Group for metadata and CPU tensors; defaults to
            ``get_cpu_world_group()``.

    Returns:
        On ``src``, the input ``tensor_dict``; on other ranks, the rebuilt
        dict. If ``torch.distributed`` is not initialized or the group has a
        single rank, ``tensor_dict`` is returned unchanged.

    NOTE(review): when ``metadata_group`` is omitted the CPU *world* group is
    used even if ``group`` is a subgroup; presumably callers that pass a
    subgroup also pass a matching ``metadata_group`` -- confirm.
    """
    # Bypass the function if we are using only 1 GPU.
    if (not torch.distributed.is_initialized()
            or torch.distributed.get_world_size(group=group) == 1):
        return tensor_dict

    group = group or torch.distributed.group.WORLD
    metadata_group = metadata_group or get_cpu_world_group()
    ranks = torch.distributed.get_process_group_ranks(group)
    assert src in ranks, f"Invalid src rank ({src})"

    rank = torch.distributed.get_rank()
    async_handles = []
    if rank == src:
        assert isinstance(
            tensor_dict,
            dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `broadcast_object_list` involves serialization and deserialization,
        # all happening on CPU. Therefore, we can use the CPU group.
        torch.distributed.broadcast_object_list([metadata_list],
                                                src=src,
                                                group=metadata_group)
        for tensor in tensor_list:
            handle = _broadcast_tensor_async(tensor, src, group,
                                             metadata_group)
            if handle is not None:
                async_handles.append(handle)
    else:
        recv_metadata_list = [None]
        torch.distributed.broadcast_object_list(recv_metadata_list,
                                                src=src,
                                                group=metadata_group)
        assert recv_metadata_list[0] is not None
        tensor_dict = {}
        for key, value in recv_metadata_list[0]:
            if isinstance(value, TensorMetadata):
                # The device carries only the type; the local default index
                # for that type is used for the allocation.
                tensor = torch.empty(value.size,
                                     dtype=value.dtype,
                                     device=value.device)
                handle = _broadcast_tensor_async(tensor, src, group,
                                                 metadata_group)
                if handle is not None:
                    async_handles.append(handle)
                tensor_dict[key] = tensor
            else:
                tensor_dict[key] = value
    # Tensors are only valid after every outstanding broadcast completes.
    for async_handle in async_handles:
        async_handle.wait()
    return tensor_dict