communication_op.py 13 KB
Newer Older
1
from collections import namedtuple
2
from contextlib import contextmanager, nullcontext
3
from dataclasses import dataclass
4
from typing import Any, Dict, List, Optional, Tuple, Union
5

6
import torch
Woosuk Kwon's avatar
Woosuk Kwon committed
7
from torch.distributed import ProcessGroup
8

9
10
from .parallel_state import (get_cpu_world_group,
                             get_tensor_model_parallel_group,
11
12
                             get_tensor_model_parallel_rank,
                             get_tensor_model_parallel_world_size,
13
                             get_tp_ca_communicator,
14
15
16
                             get_tp_pynccl_communicator)


17
18
19
@dataclass
class GraphCaptureContext:
    """Data made available to code running inside a CUDA graph capture.

    Currently it only carries the stream on which the graph capture runs.
    """

    # The CUDA stream the graph capture is running on.
    stream: torch.cuda.Stream
20
21
22
23
24


@contextmanager
def graph_capture():
    """Context manager that should surround the code capturing a CUDA graph.

    Its main purpose is to ensure that some operations will be run after the
    graph is captured, before the graph is replayed. It yields a
    `GraphCaptureContext` object holding the data needed for the capture;
    currently that is only the stream the capture runs on.

    A fresh CUDA stream is created and made the current stream for the
    duration of the `with` body. On exit, `torch.cuda.stream` restores
    whichever stream was current before entry. Capturing on a separate
    stream explicitly distinguishes the kernels to capture from other
    kernels possibly launched in the background on the default stream.

    Before switching, the new stream is made to wait on the stream that was
    current at entry. Any work later submitted to the new stream therefore
    starts only after everything already queued on the previous stream has
    completed.
    """
    stream = torch.cuda.Stream()
    graph_capture_context = GraphCaptureContext(stream)
    ca_comm = get_tp_ca_communicator()
    maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture()

    # Order the new stream after the stream that is current right now, so
    # that previously enqueued work (e.g. initialization of buffers used by
    # the captured code) completes before any work on the capture stream.
    stream.wait_stream(torch.cuda.current_stream())

    with torch.cuda.stream(stream), maybe_ca_context:
        # In graph mode, we have to be very careful about the collective
        # operations. The current status is:
        #     allreduce \ Mode   |  Eager  |  Graph  |
        # --------------------------------------------
        # custom allreduce       | enabled | enabled |
        # PyNccl                 | disabled| enabled |
        # torch.distributed      | enabled | disabled|
        #
        # Note that custom allreduce will have a runtime check, if the tensor
        #  size is too large, it will fallback to the next available option.
        # In summary: When using CUDA graph, we use
        # either custom all-reduce kernel or pynccl. When not using CUDA
        # graph, we use either custom all-reduce kernel or PyTorch NCCL.
        # We always prioritize using custom all-reduce kernel but fall back
        # to PyTorch or pynccl if it is disabled or not supported.
        pynccl_comm = get_tp_pynccl_communicator()
        if pynccl_comm is None:
            maybe_pynccl_context = nullcontext()
        else:
            # Inside this `with`, current_stream() is the capture stream.
            maybe_pynccl_context = pynccl_comm.change_state(
                enable=True, stream=torch.cuda.current_stream())
        with maybe_pynccl_context:
            yield graph_capture_context
65
66


67
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """Sum-reduce `input_` across the tensor-model-parallel group.

    Backends are tried in this order:
    1. the custom all-reduce communicator, if one exists and it accepts the
       tensor (it signals refusal by returning None);
    2. pynccl, if a communicator exists and is not disabled;
    3. `torch.distributed.all_reduce` on the tensor-model-parallel group.

    NOTE: Backends 2 and 3 reduce `input_` in place. Whether the custom path
    works in place depends on the tensor size and GPU topology.

    TLDR: always assume this function modifies its input, but use the return
    value as the output.
    """
    custom_comm = get_tp_ca_communicator()

    if get_tensor_model_parallel_world_size() == 1:
        # A single rank has nothing to reduce with.
        return input_

    if custom_comm is not None:
        reduced = custom_comm.custom_all_reduce(input_)
        if reduced is not None:
            return reduced

    nccl_comm = get_tp_pynccl_communicator()
    use_pynccl = nccl_comm is not None and not nccl_comm.disabled
    if use_pynccl:
        nccl_comm.all_reduce(input_)
    else:
        tp_group = get_tensor_model_parallel_group()
        torch.distributed.all_reduce(input_, group=tp_group)
    return input_


97
98
def tensor_model_parallel_all_gather(input_: torch.Tensor,
                                     dim: int = -1) -> torch.Tensor:
    """All-gather `input_` from every rank of the tensor-model-parallel group.

    The per-rank tensors are concatenated along `dim` in rank order. The
    result therefore has the shape of `input_`, except that the size of `dim`
    is multiplied by the group's world size. With a world size of 1, `input_`
    itself is returned.
    """
    world_size = get_tensor_model_parallel_world_size()
    if world_size == 1:
        # A single rank has nothing to gather.
        return input_

    ndim = input_.dim()
    assert -ndim <= dim < ndim, (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    # Normalize a negative axis so it can be used to slice the shape below.
    if dim < 0:
        dim = dim + ndim

    shape = input_.size()
    # One slot per rank, stacked along a new leading axis.
    gathered = torch.empty((world_size, ) + shape,
                           dtype=input_.dtype,
                           device=input_.device)
    torch.distributed.all_gather_into_tensor(
        gathered, input_, group=get_tensor_model_parallel_group())

    # Moving the rank axis in front of `dim` and folding it into `dim`
    # concatenates the per-rank tensors along `dim`.
    concat_shape = (shape[:dim] + (world_size * shape[dim], ) +
                    shape[dim + 1:])
    return gathered.movedim(0, dim).reshape(concat_shape)
123
124


125
126
127
def tensor_model_parallel_gather(input_: torch.Tensor,
                                 dst: int = 0,
                                 dim: int = -1) -> Optional[torch.Tensor]:
    """Gather the input tensor across model parallel group.

    The tensors of all ranks in the tensor-model-parallel group are
    concatenated along `dim` on the destination rank.

    Args:
        input_: Tensor contributed by this rank.
        dst: Destination rank *within the tensor-model-parallel group*. It is
            compared against `get_tensor_model_parallel_rank()` and translated
            to the corresponding global rank before calling
            `torch.distributed.gather`, which expects a global rank.
        dim: Dimension to concatenate along; negative values count from the
            end.

    Returns:
        The concatenated tensor on rank `dst`, None on every other rank, or
        `input_` unchanged when the group has a single rank.

    NOTE: We assume that the input tensor is on the same device across
    all the ranks.
    """
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()
    group = get_tensor_model_parallel_group()
    # `torch.distributed.gather` interprets `dst` as a global rank regardless
    # of `group`, while the destination check below uses the group-local
    # rank. Translate so both agree when group rank 0 is not global rank 0
    # (e.g. when several tensor-parallel groups exist).
    global_dst = torch.distributed.get_global_rank(group, dst)
    is_dst = get_tensor_model_parallel_rank() == dst
    # Only the destination rank allocates receive buffers.
    gather_list = ([torch.empty_like(input_) for _ in range(world_size)]
                   if is_dst else None)
    torch.distributed.gather(input_,
                             gather_list,
                             dst=global_dst,
                             group=group)
    if not is_dst:
        return None
    return torch.cat(gather_list, dim=dim)


159
160
161
def broadcast(input_: torch.Tensor,
              src: int = 0,
              group: Optional[ProcessGroup] = None):
    """Broadcast `input_` from global rank `src` to every rank of `group`.

    `group` defaults to the global WORLD group, and `src` must be one of its
    ranks. On receiving ranks `input_` is overwritten with the data from
    `src`. The same tensor object is returned on every rank. A single-member
    group performs no communication.
    """
    process_group = group if group else torch.distributed.group.WORLD
    member_ranks = torch.distributed.get_process_group_ranks(process_group)
    assert src in member_ranks, f"Invalid src rank ({src})"

    # Only communicate when there is someone to talk to.
    if torch.distributed.get_world_size(group=process_group) != 1:
        torch.distributed.broadcast(input_, src=src, group=process_group)
    return input_


176
177
178
def broadcast_object_list(obj_list: List[Any],
                          src: int = 0,
                          group: Optional[ProcessGroup] = None):
    """Broadcast the picklable objects in `obj_list` from global rank `src`.

    `group` defaults to the global WORLD group, and `src` must be one of its
    ranks. On receiving ranks the entries of `obj_list` are replaced in place
    with the objects sent by `src`. The same list object is returned. A
    single-member group performs no communication.
    """
    process_group = group if group else torch.distributed.group.WORLD
    member_ranks = torch.distributed.get_process_group_ranks(process_group)
    assert src in member_ranks, f"Invalid src rank ({src})"

    # Only communicate when there is someone to talk to.
    if torch.distributed.get_world_size(group=process_group) != 1:
        torch.distributed.broadcast_object_list(obj_list,
                                                src=src,
                                                group=process_group)
    return obj_list
191
192


193
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])


def _split_tensor_dict(
    tensor_dict: Dict[Any, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
    """Split the tensor dictionary into two parts:
    1. A list of (key, value) pairs, in the dictionary's iteration order. If
       the value is a tensor, it is replaced by a `TensorMetadata` holding
       its device type, dtype and size.
    2. A list of the tensors themselves, in the same order as their
       metadata entries appear in the first list.
    """
    metadata_list = []
    tensor_list = []
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            # Note: we cannot use `value.device` here, because it contains
            # not only the device type but also the device index (e.g.
            # "cuda:0"). We only need the device type; the receiving side
            # will set the device index. `device.type` is "cpu"/"cuda" for
            # CPU/CUDA tensors, and names other backends correctly instead of
            # labelling every non-CPU tensor as "cuda".
            device = value.device.type
            metadata_list.append(
                (key, TensorMetadata(device, value.dtype, value.size())))
            tensor_list.append(value)
        else:
            metadata_list.append((key, value))
    return metadata_list, tensor_list


221
222
223
224
def _broadcast_tensor_async(tensor: torch.Tensor, src: int,
                            group: ProcessGroup,
                            metadata_group: ProcessGroup):
    """Start an async broadcast of `tensor` from `src`; return its handle.

    CPU tensors travel over `metadata_group` and all other tensors over
    `group`.
    """
    broadcast_group = metadata_group if tensor.is_cpu else group
    return torch.distributed.broadcast(tensor,
                                       src=src,
                                       group=broadcast_group,
                                       async_op=True)


def broadcast_tensor_dict(
    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
    src: int = 0,
    group: Optional[ProcessGroup] = None,
    metadata_group: Optional[ProcessGroup] = None
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
    """Broadcast the input tensor dictionary.

    `group` is used to broadcast the non-CPU tensors, while `metadata_group`
    is used to broadcast the metadata of the dict (e.g. dict structure,
    tensor sizes, dtypes) and any CPU tensors.

    Args:
        tensor_dict: On rank `src`, the dictionary to send (must be a dict).
            On other ranks it is ignored and replaced by the received dict.
        src: Global rank of the sender; must belong to `group`.
        group: Group for non-CPU tensors; defaults to the WORLD group.
        metadata_group: Group for metadata and CPU tensors; defaults to
            `get_cpu_world_group()`.

    Returns:
        On `src`, the input dictionary. On other ranks, a new dict with the
        same keys in the same order, where each tensor is freshly allocated
        on the sender's device type (no device index is given, so the
        current device of that type is used) and filled by the broadcast.
        Empty tensors are allocated but not transmitted. If torch.distributed
        is not initialized or the group has a single rank, `tensor_dict` is
        returned unchanged.
    """
    # Bypass the function if we are using only 1 GPU.
    if (not torch.distributed.is_initialized()
            or torch.distributed.get_world_size(group=group) == 1):
        return tensor_dict

    group = group or torch.distributed.group.WORLD
    metadata_group = metadata_group or get_cpu_world_group()
    ranks = torch.distributed.get_process_group_ranks(group)
    assert src in ranks, f"Invalid src rank ({src})"

    rank = torch.distributed.get_rank()
    async_handles = []
    if rank == src:
        assert isinstance(
            tensor_dict,
            dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `broadcast_object_list` involves serialization and deserialization,
        # all happening on CPU. Therefore, we can use the CPU group.
        torch.distributed.broadcast_object_list([metadata_list],
                                                src=src,
                                                group=metadata_group)
        for tensor in tensor_list:
            if tensor.numel() == 0:
                # Skip broadcasting empty tensors (the receiver skips too).
                continue
            async_handles.append(
                _broadcast_tensor_async(tensor, src, group, metadata_group))
    else:
        recv_metadata_list = [None]
        torch.distributed.broadcast_object_list(recv_metadata_list,
                                                src=src,
                                                group=metadata_group)
        assert recv_metadata_list[0] is not None
        tensor_dict = {}
        for key, value in recv_metadata_list[0]:
            if isinstance(value, TensorMetadata):
                tensor = torch.empty(value.size,
                                     dtype=value.dtype,
                                     device=value.device)
                # Empty tensors were not sent, so only receive non-empty ones.
                if tensor.numel() > 0:
                    async_handles.append(
                        _broadcast_tensor_async(tensor, src, group,
                                                metadata_group))
                tensor_dict[key] = tensor
            else:
                tensor_dict[key] = value
    # Complete every outstanding broadcast before handing out the tensors.
    for async_handle in async_handles:
        async_handle.wait()
    return tensor_dict