dp_utils.py 8.28 KB
Newer Older
1
2
3
4
5
6
7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import torch.distributed as dist

from vllm.config import ParallelConfig
8
from vllm.distributed.parallel_state import get_dp_group
9
10
11
from vllm.logger import init_logger
from vllm.v1.worker.ubatch_utils import (
    check_ubatch_thresholds,
12
    is_last_ubatch_empty,
13
14
15
16
17
18
)

# Module-level logger, created through init_logger with this module's name.
logger = init_logger(__name__)


def _get_device_and_group(parallel_config: ParallelConfig):
    """Pick the device and process group used for the DP all-reduce.

    Args:
        parallel_config: Parallel configuration; only
            ``disable_nccl_for_dp_synchronization`` is read here.

    Returns:
        A ``(device, group)`` pair. By default this is the DP group's own
        ``device`` and ``device_group``. When
        ``parallel_config.disable_nccl_for_dp_synchronization`` is set, it is
        ``"cpu"`` and the DP group's ``cpu_group`` instead.
    """
    dp_group = get_dp_group()

    # Transferring the synchronization tensor from GPU to CPU introduces a GPU
    # sync point that could adversely affect the performance of vLLM with
    # async scheduling. This config flag exists to quickly fall back to a CPU
    # all-reduce if we run into that case.
    if parallel_config.disable_nccl_for_dp_synchronization:
        logger.info_once(
            "Using CPU all reduce to synchronize DP padding between ranks.",
        )
        return "cpu", dp_group.cpu_group

    # Use the actual device assigned to the DP group, not just the device type.
    return dp_group.device, dp_group.device_group


def _run_ar(
    should_ubatch: bool,
    orig_num_tokens_per_ubatch: int,
    padded_num_tokens_per_ubatch: int,
    cudagraph_mode: int,
    parallel_config: ParallelConfig,
) -> torch.Tensor:
    """All-reduce this rank's batch metadata across the DP group.

    A zero-initialized ``(4, dp_size)`` int32 tensor is created and only the
    column belonging to this rank is filled in before the all-reduce, so the
    reduced tensor carries one column per DP rank.

    Row layout of the returned tensor:
        0: unpadded token count of each rank
        1: padded token count of each rank
        2: 1 if the rank wants to microbatch, else 0
        3: cudagraph mode of each rank

    Args:
        should_ubatch: Whether this rank wants to microbatch.
        orig_num_tokens_per_ubatch: This rank's unpadded token count.
        padded_num_tokens_per_ubatch: This rank's padded token count.
        cudagraph_mode: This rank's cudagraph mode, as an int.
        parallel_config: Supplies ``data_parallel_size``,
            ``data_parallel_rank`` and the device/group selection.

    Returns:
        The all-reduced tensor, on the device chosen by
        ``_get_device_and_group``.
    """
    dp_size = parallel_config.data_parallel_size
    dp_rank = parallel_config.data_parallel_rank
    device, group = _get_device_and_group(parallel_config)

    tensor = torch.zeros(4, dp_size, device=device, dtype=torch.int32)
    # Every other column stays zero on this rank, so the reduction leaves each
    # rank's contribution untouched in its own column.
    tensor[0, dp_rank] = orig_num_tokens_per_ubatch
    tensor[1, dp_rank] = padded_num_tokens_per_ubatch
    tensor[2, dp_rank] = 1 if should_ubatch else 0
    tensor[3, dp_rank] = cudagraph_mode

    dist.all_reduce(tensor, group=group)
    return tensor


55
def _post_process_ubatch(tensor: torch.Tensor, num_ubatches: int) -> bool:
56
57
58
59
60
61
62
63
64
65
66
    orig_num_tokens_tensor = tensor[0, :]
    padded_num_tokens_tensor = tensor[1, :]

    # First determine if we are going to be ubatching.
    should_ubatch: bool = bool(torch.all(tensor[2] == 1).item())
    if not should_ubatch:
        return False
    # If the DP ranks are planning to ubatch, make sure that
    # there are no "empty" second ubatches
    orig_min_num_tokens = int(orig_num_tokens_tensor.min().item())
    padded_max_num_tokens = int(padded_num_tokens_tensor.max().item())
67
    if is_last_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens, num_ubatches):
68
69
70
71
72
73
74
        logger.debug(
            "Aborting ubatching %s %s", orig_min_num_tokens, padded_max_num_tokens
        )
        should_ubatch = False
    return should_ubatch


75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch.Tensor:
    num_tokens_across_dp = tensor[1, :]
    if should_dp_pad:
        # If DP padding is enabled, ensure that each rank is processing the same number
        # of tokens
        max_num_tokens = int(num_tokens_across_dp.max().item())
        return torch.tensor(
            [max_num_tokens] * len(num_tokens_across_dp),
            device="cpu",
            dtype=torch.int32,
        )
    else:
        return num_tokens_across_dp.cpu()


90
91
92
93
94
95
def _post_process_cudagraph_mode(tensor: torch.Tensor) -> int:
    """
    Synchronize cudagraph_mode across DP ranks by taking the minimum.
    If any rank has NONE (0), all ranks use NONE.
    This ensures all ranks send consistent values (all padded or all unpadded).
    """
96
    return int(tensor[3, :].min().item())
97
98


99
100
101
102
def _synchronize_dp_ranks(
    num_tokens_unpadded: int,
    num_tokens_padded: int,
    should_attempt_ubatching: bool,
    cudagraph_mode: int,
    parallel_config: ParallelConfig,
) -> tuple[bool, torch.Tensor | None, int]:
    """Agree on microbatching, padding and cudagraph mode across DP ranks.

    1. Decides if each DP rank is going to microbatch. Either all ranks
       run with microbatching or none of them do.

    2. Determines the total number of tokens that each rank will run.
       When running microbatched or if cudagraph is enabled (synced across
       ranks), all ranks will be padded out so that they run with the same
       number of tokens.

    3. Synchronizes cudagraph_mode across ranks by taking the minimum.

    Args:
        num_tokens_unpadded: This rank's token count without padding.
        num_tokens_padded: This rank's token count including padding; must
            be at least ``num_tokens_unpadded``.
        should_attempt_ubatching: Whether this rank wants to microbatch.
        cudagraph_mode: This rank's cudagraph mode, as an int.
        parallel_config: The parallel config.

    Returns:
        A tuple ``(should_ubatch, num_tokens_after_padding,
        synced_cudagraph_mode)``:

        should_ubatch: Are all DP ranks going to microbatch.
        num_tokens_after_padding: A CPU tensor containing the total number
            of tokens per-microbatch for each DP rank including any DP
            padding. This function always returns a tensor here.
        synced_cudagraph_mode: The synchronized cudagraph mode (min across
            ranks).
    """
    assert num_tokens_padded >= num_tokens_unpadded, (
        "num_tokens_padded must be >= num_tokens_unpadded"
    )

    # Coordinate between the DP ranks via an All Reduce
    # to determine the total number of tokens that each rank
    # will run and if we are using ubatching or not.
    tensor = _run_ar(
        should_ubatch=should_attempt_ubatching,
        orig_num_tokens_per_ubatch=num_tokens_unpadded,
        padded_num_tokens_per_ubatch=num_tokens_padded,
        cudagraph_mode=cudagraph_mode,
        parallel_config=parallel_config,
    )

    # Synchronize cudagraph_mode across ranks first (take min).
    # This is needed before the DP padding decision since we use the synced
    # cudagraph mode to determine whether DP padding is needed.
    synced_cudagraph_mode = _post_process_cudagraph_mode(tensor)

    # Check conditions for microbatching
    should_ubatch = _post_process_ubatch(tensor, parallel_config.num_ubatches)

    # DP padding is needed when cudagraph is enabled (synced across ranks)
    # or when ubatching/DBO is active (ubatching requires uniform batch
    # sizes across DP ranks currently).
    # Use the synced runtime cudagraph mode rather than the compilation config
    # so we can avoid padding when cudagraph is not enabled for this step.
    should_dp_pad = synced_cudagraph_mode != 0 or should_ubatch

    # Pad all DP ranks up to the maximum token count across ranks if
    # should_dp_pad is True
    num_tokens_after_padding = _post_process_dp_padding(
        tensor,
        should_dp_pad,
    )

    return should_ubatch, num_tokens_after_padding, synced_cudagraph_mode
160
161
162
163
164


def coordinate_batch_across_dp(
    num_tokens_unpadded: int,
    allow_microbatching: bool,
    parallel_config: ParallelConfig,
    num_tokens_padded: int | None = None,
    uniform_decode: bool | None = None,
    cudagraph_mode: int = 0,
) -> tuple[bool, torch.Tensor | None, int]:
    """Coordinate batch shape, microbatching and cudagraph mode across DP ranks.

    Coordinates amongst all DP ranks to determine if and how the full batch
    should be split into microbatches.

    Args:
        num_tokens_unpadded: Number of tokens without accounting for padding
        allow_microbatching: If microbatching should be attempted
        parallel_config: The parallel config
        num_tokens_padded: Number of tokens including any non-DP padding (CUDA
            graphs, TP, etc). Defaults to ``num_tokens_unpadded`` when None.
        uniform_decode: Only used if allow_microbatching is True, in which
            case it must not be None. True if the batch only contains single
            token decodes
        cudagraph_mode: The cudagraph mode for this rank (0=NONE, 1=PIECEWISE,
            2=FULL). DP padding is enabled when synced cudagraph mode across
            ranks is not NONE.

    Returns:
        A tuple ``(should_ubatch, num_tokens_after_padding,
        synced_cudagraph_mode)``:

        should_ubatch: True only if all DP ranks have agreed to microbatch.
            Always False when ``data_parallel_size == 1``.
        num_tokens_after_padding: A tensor containing the total number of
            tokens per-microbatch for each DP rank including padding. Will be
            padded up to the max value across all DP ranks when the synced
            cudagraph mode is not NONE or when microbatching is active. None
            when ``data_parallel_size == 1``.
        synced_cudagraph_mode: The synchronized cudagraph mode (min across
            ranks); the caller's own ``cudagraph_mode`` when
            ``data_parallel_size == 1``.
    """
    if parallel_config.data_parallel_size == 1:
        # Early exit: nothing to coordinate with a single DP rank.
        return False, None, cudagraph_mode

    # Only attempt microbatching if the caller has explicitly enabled it.
    should_attempt_ubatching = False
    if allow_microbatching:
        # Check preconditions for microbatching
        assert uniform_decode is not None, (
            "uniform_decode must be provided when allow_microbatching is True"
        )
        should_attempt_ubatching = check_ubatch_thresholds(
            parallel_config,
            num_tokens_unpadded,
            uniform_decode=uniform_decode,
        )

    if num_tokens_padded is None:
        num_tokens_padded = num_tokens_unpadded

    return _synchronize_dp_ranks(
        num_tokens_unpadded,
        num_tokens_padded,
        should_attempt_ubatching,
        cudagraph_mode,
        parallel_config,
    )