communication_op.py 13.3 KB
Newer Older
1
from collections import namedtuple
2
from contextlib import contextmanager, nullcontext
3
from dataclasses import dataclass
4
from typing import Any, Dict, List, Optional, Tuple, Union
5

6
import torch
Woosuk Kwon's avatar
Woosuk Kwon committed
7
from torch.distributed import ProcessGroup
8

9
from .parallel_state import (get_cpu_world_group, get_pp_pynccl_communicator,
10
                             get_tensor_model_parallel_group,
11
12
                             get_tensor_model_parallel_rank,
                             get_tensor_model_parallel_world_size,
13
                             get_tp_ca_communicator,
14
15
16
                             get_tp_pynccl_communicator)


17
18
19
@dataclass
class GraphCaptureContext:
    # CUDA stream carried by this context. `graph_capture` in this file
    # constructs the context with the stream it creates for the capture.
    stream: torch.cuda.Stream
20
21
22
23
24


@contextmanager
def graph_capture():
    """
    `graph_capture` is a context manager which should surround the code that
    is capturing the CUDA graph. Its main purpose is to ensure that some
    operations will be run after the graph is captured, before the graph
    is replayed.

    It yields a `GraphCaptureContext` object which contains the necessary
    data for the graph capture. Currently, it only contains the stream that
    the graph capture is running on. A fresh `torch.cuda.Stream` is created
    and made the current CUDA stream (via `torch.cuda.stream`) for the
    duration of the context. This is to ensure that the graph capture is
    running on a separate stream from the default stream, in order to
    explicitly distinguish the kernels to capture from other kernels possibly
    launched in the background on the default stream.

    While the context is active, the custom all-reduce communicator (if any)
    is put into its `capture()` context, and the tensor-parallel and
    pipeline-parallel pynccl communicators (if any) are switched to
    `enable=True` on the capture stream through `change_state`.
    """
    stream = torch.cuda.Stream()
    graph_capture_context = GraphCaptureContext(stream)
    ca_comm = get_tp_ca_communicator()
    maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture()
    with torch.cuda.stream(stream), maybe_ca_context:
        # In graph mode, we have to be very careful about the collective
        # operations. The current status is:
        #     allreduce \ Mode   |  Eager  |  Graph  |
        # --------------------------------------------
        # custom allreduce       | enabled | enabled |
        # PyNccl                 | disabled| enabled |
        # torch.distributed      | enabled | disabled|
        #
        # Note that custom allreduce will have a runtime check, if the tensor
        #  size is too large, it will fallback to the next available option.
        # In summary: When using CUDA graph, we use
        # either custom all-reduce kernel or pynccl. When not using CUDA
        # graph, we use either custom all-reduce kernel or PyTorch NCCL.
        # We always prioritize using custom all-reduce kernel but fall back
        # to PyTorch or pynccl if it is disabled or not supported.
        tp_pynccl_comm = get_tp_pynccl_communicator()
        pp_pynccl_comm = get_pp_pynccl_communicator()
        # A missing communicator contributes a no-op context so that the
        # final `with` statement below has the same shape in every setup.
        if not tp_pynccl_comm:
            maybe_tp_pynccl_context = nullcontext()
        else:
            maybe_tp_pynccl_context = tp_pynccl_comm.change_state(
                enable=True, stream=torch.cuda.current_stream())
        if not pp_pynccl_comm:
            maybe_pp_pynccl_context = nullcontext()
        else:
            maybe_pp_pynccl_context = pp_pynccl_comm.change_state(
                enable=True, stream=torch.cuda.current_stream())
        with maybe_tp_pynccl_context, maybe_pp_pynccl_context:
            yield graph_capture_context
71
72


73
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce the input tensor across model parallel group.

    NOTE: This operation will be applied in-place on the input tensor if
    disable_custom_all_reduce is set to True. Otherwise, this operation may or
    may not be applied in place depending on whether custom all reduce is
    invoked for a particular tensor, which further depends on the tensor size
    and GPU topology.

    TLDR: always assume this function modifies its input, but use the return
    value as the output.

    Args:
        input_: Tensor to reduce.

    Returns:
        `input_` itself when the tensor-parallel world size is 1 or when the
        pynccl / `torch.distributed` path is taken; otherwise the tensor
        returned by the custom all-reduce communicator.
    """
    ca_comm = get_tp_ca_communicator()

    # Bypass the function if we are using only 1 GPU.
    if get_tensor_model_parallel_world_size() == 1:
        return input_

    # First choice: custom all-reduce. A `None` result means it did not
    # handle this tensor, so we fall through to the next option.
    if ca_comm is not None:
        out = ca_comm.custom_all_reduce(input_)
        if out is not None:
            return out

    # Second choice: pynccl when present and not disabled; otherwise the
    # regular torch.distributed all-reduce on the tensor-parallel group.
    pynccl_comm = get_tp_pynccl_communicator()
    if (pynccl_comm is not None and not pynccl_comm.disabled):
        pynccl_comm.all_reduce(input_)
    else:
        torch.distributed.all_reduce(input_,
                                     group=get_tensor_model_parallel_group())
    return input_


103
104
def tensor_model_parallel_all_gather(input_: torch.Tensor,
                                     dim: int = -1) -> torch.Tensor:
    """All-gather the input tensor across model parallel group.

    Args:
        input_: Local tensor contributed by this rank.
        dim: Dimension along which the gathered pieces are joined. Negative
            values index from the end.

    Returns:
        `input_` unchanged when the tensor-parallel world size is 1.
        Otherwise a new tensor with the same shape as `input_` except that
        dimension `dim` is `world_size` times larger.

    Raises:
        AssertionError: If `dim` is out of range for `input_`.
    """
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()
    input_size = input_.size()
    # Allocate output tensor with an extra leading dimension of length
    # `world_size`; the collective fills one slice per rank.
    output_tensor = torch.empty((world_size, ) + input_size,
                                dtype=input_.dtype,
                                device=input_.device)
    # All-gather.
    torch.distributed.all_gather_into_tensor(
        output_tensor, input_, group=get_tensor_model_parallel_group())
    # Reshape: move the per-rank dimension next to `dim`, then merge the two
    # so that the pieces end up concatenated along `dim`.
    output_tensor = output_tensor.movedim(0, dim)
    output_tensor = output_tensor.reshape(input_size[:dim] +
                                          (world_size * input_size[dim], ) +
                                          input_size[dim + 1:])
    return output_tensor
129
130


131
132
133
def tensor_model_parallel_gather(input_: torch.Tensor,
                                 dst: int = 0,
                                 dim: int = -1) -> Optional[torch.Tensor]:
    """Gather the input tensor across model parallel group.

    NOTE: We assume that the input tensor is on the same device across
    all the ranks.

    Args:
        input_: Local tensor contributed by this rank.
        dst: Rank that receives the gathered result.
        dim: Dimension along which the gathered tensors are concatenated.
            Negative values index from the end.

    Returns:
        `input_` unchanged when the tensor-parallel world size is 1.
        Otherwise, the concatenation of all ranks' tensors along `dim` on the
        destination rank, and `None` on every other rank.

    Raises:
        AssertionError: If `dim` is out of range for `input_`.
    """
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()
    # NOTE(review): `dst` is compared against the tensor-parallel rank here
    # but is also passed straight to `torch.distributed.gather` as `dst`.
    # Confirm both are expected to be in the same rank space for all
    # parallel configurations.
    is_dst = get_tensor_model_parallel_rank() == dst
    # Only the destination rank needs receive buffers.
    if is_dst:
        gather_list = [torch.empty_like(input_) for _ in range(world_size)]
    else:
        gather_list = None
    # Gather.
    torch.distributed.gather(input_,
                             gather_list,
                             dst=dst,
                             group=get_tensor_model_parallel_group())
    if is_dst:
        return torch.cat(gather_list, dim=dim)
    return None


165
166
167
def broadcast(input_: torch.Tensor,
              src: int = 0,
              group: Optional[ProcessGroup] = None):
    """Broadcast the input tensor.

    Args:
        input_: Tensor to broadcast. It is passed to
            `torch.distributed.broadcast` as the send/receive buffer.
        src: Source rank; must be one of the ranks of `group`.
        group: Process group to broadcast over. Defaults to the world group.

    Returns:
        `input_` (the same tensor object that was passed in).

    Raises:
        AssertionError: If `src` is not a member of `group`.
    """
    group = group or torch.distributed.group.WORLD
    ranks = torch.distributed.get_process_group_ranks(group)
    assert src in ranks, f"Invalid src rank ({src})"

    # Bypass the function if we are using only 1 GPU.
    world_size = torch.distributed.get_world_size(group=group)
    if world_size == 1:
        return input_
    # Broadcast.
    torch.distributed.broadcast(input_, src=src, group=group)
    return input_


182
183
184
def broadcast_object_list(obj_list: List[Any],
                          src: int = 0,
                          group: Optional[ProcessGroup] = None):
    """Broadcast the input object list.

    Args:
        obj_list: List of Python objects, passed to
            `torch.distributed.broadcast_object_list` as the send/receive
            container.
        src: Source rank; must be one of the ranks of `group`.
        group: Process group to broadcast over. Defaults to the world group.

    Returns:
        `obj_list` (the same list object that was passed in).

    Raises:
        AssertionError: If `src` is not a member of `group`.
    """
    group = group or torch.distributed.group.WORLD
    ranks = torch.distributed.get_process_group_ranks(group)
    assert src in ranks, f"Invalid src rank ({src})"

    # Bypass the function if we are using only 1 GPU.
    world_size = torch.distributed.get_world_size(group=group)
    if world_size == 1:
        return obj_list
    # Broadcast.
    torch.distributed.broadcast_object_list(obj_list, src=src, group=group)
    return obj_list
197
198


199
# Stand-in for a tensor inside the metadata list built by
# `_split_tensor_dict`: it records the device type, dtype and size so that
# `broadcast_tensor_dict` can allocate a matching empty tensor on the
# receiving ranks before the tensor data itself is broadcast.
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
200
201


202
203
204
205
206
207
208
209
210
211
212
213
def _split_tensor_dict(
    tensor_dict: Dict[Any, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
    """Split the tensor dictionary into two parts:
    1. A list of (key, value) pairs. If the value is a tensor, it is replaced
         by its metadata.
    2. A list of tensors.
    """
    metadata_list = []
    tensor_list = []
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
214
215
216
217
218
219
220
            # Note: we cannot use `value.device` here,
            # because it contains not only the device type but also the device
            # index (e.g. "cuda:0"). We only need the device type.
            # receiving side will set the device index.
            device = "cpu" if value.is_cpu else "cuda"
            metadata_list.append(
                (key, TensorMetadata(device, value.dtype, value.size())))
221
222
223
224
225
226
            tensor_list.append(value)
        else:
            metadata_list.append((key, value))
    return metadata_list, tensor_list


227
228
229
230
def broadcast_tensor_dict(
    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
    src: int = 0,
    group: Optional[ProcessGroup] = None,
    metadata_group: Optional[ProcessGroup] = None
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
    """Broadcast the input tensor dictionary.
    `group` is used to broadcast the tensors, while `metadata_group` is used
     to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
     dtypes).

    Args:
        tensor_dict: Dictionary to send. Must be a dict on the `src` rank;
            its value is ignored on the other ranks.
        src: Source rank; must be one of the ranks of `group`.
        group: Process group used for non-CPU tensors. Defaults to the world
            group.
        metadata_group: Process group used for the metadata and for CPU
            tensors. Defaults to `get_cpu_world_group()`.

    Returns:
        `tensor_dict` unchanged when torch.distributed is not initialized,
        when the world size of `group` is 1, or on the `src` rank. On the
        other ranks, a new dict rebuilt from the received metadata and
        tensors.

    Raises:
        AssertionError: If `src` is not in `group`, if `tensor_dict` is not a
            dict on the `src` rank, or if no metadata was received.
    """
    # Bypass the function if we are using only 1 GPU.
    if (not torch.distributed.is_initialized()
            or torch.distributed.get_world_size(group=group) == 1):
        return tensor_dict

    group = group or torch.distributed.group.WORLD
    metadata_group = metadata_group or get_cpu_world_group()
    ranks = torch.distributed.get_process_group_ranks(group)
    assert src in ranks, f"Invalid src rank ({src})"

    rank = torch.distributed.get_rank()
    async_handles = []
    if rank == src:
        assert isinstance(
            tensor_dict,
            dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `broadcast_object_list` involves serialization and deserialization,
        # all happening on CPU. Therefore, we can use the CPU group.
        torch.distributed.broadcast_object_list([metadata_list],
                                                src=src,
                                                group=metadata_group)
        for tensor in tensor_list:
            if tensor.numel() == 0:
                # Skip broadcasting empty tensors.
                continue
            # Use metadata_group for CPU tensors and group for GPU tensors.
            tensor_group = metadata_group if tensor.is_cpu else group
            async_handles.append(
                torch.distributed.broadcast(tensor,
                                            src=src,
                                            group=tensor_group,
                                            async_op=True))
    else:
        recv_metadata_list = [None]
        torch.distributed.broadcast_object_list(recv_metadata_list,
                                                src=src,
                                                group=metadata_group)
        assert recv_metadata_list[0] is not None
        tensor_dict = {}
        for key, value in recv_metadata_list[0]:
            if not isinstance(value, TensorMetadata):
                # Plain Python values travel inside the metadata itself.
                tensor_dict[key] = value
                continue
            # Allocate a receive buffer matching the sender's tensor.
            tensor = torch.empty(value.size,
                                 dtype=value.dtype,
                                 device=value.device)
            tensor_dict[key] = tensor
            if tensor.numel() == 0:
                # Skip broadcasting empty tensors (the sender skips them
                # too, so the collective calls stay matched).
                continue
            # Use metadata_group for CPU tensors and group for GPU tensors.
            tensor_group = metadata_group if tensor.is_cpu else group
            async_handles.append(
                torch.distributed.broadcast(tensor,
                                            src=src,
                                            group=tensor_group,
                                            async_op=True))
    # All broadcasts were launched asynchronously; wait for every one of
    # them before handing the dict back to the caller.
    for async_handle in async_handles:
        async_handle.wait()
    return tensor_dict