dp_utils.py 8.49 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4

5
6
7
8
9
import numpy as np
import torch
import torch.distributed as dist

from vllm.config import ParallelConfig
10
from vllm.distributed.parallel_state import get_dp_group
11
12
13
from vllm.logger import init_logger
from vllm.v1.worker.ubatch_utils import (
    check_ubatch_thresholds,
14
    is_last_ubatch_empty,
15
16
17
18
19
20
)

logger = init_logger(__name__)


def _get_device_and_group(parallel_config: ParallelConfig):
    """Pick the device and process group used for the DP synchronization.

    By default the DP group's own device and device group are used. When
    ``parallel_config.disable_nccl_for_dp_synchronization`` is set, the
    synchronization is moved to the CPU: the device becomes ``"cpu"`` and
    the DP group's CPU group is returned instead.

    Args:
        parallel_config: Parallel configuration; only
            ``disable_nccl_for_dp_synchronization`` is read here.

    Returns:
        A ``(device, group)`` pair: where to allocate the synchronization
        tensor and which process group to all-reduce it over.
    """
    dp_group = get_dp_group()

    # Use the actual device assigned to the DP group, not just the device type
    device = dp_group.device
    group = dp_group.device_group

    # Transferring this tensor from GPU to CPU will introduce a GPU sync
    # point that could adversely affect performance of vllm with async
    # scheduling. This config option exists to quickly fall back to a CPU
    # all-reduce if we run into this case.
    if parallel_config.disable_nccl_for_dp_synchronization:
        logger.info_once(
            "Using CPU all reduce to synchronize DP padding between ranks.",
        )
        device = "cpu"
        group = dp_group.cpu_group
    return device, group


def _run_ar(
    should_ubatch: bool,
    orig_num_tokens_per_ubatch: int,
    padded_num_tokens_per_ubatch: int,
    cudagraph_mode: int,
    parallel_config: ParallelConfig,
) -> torch.Tensor:
    """Share this rank's batch metadata with all DP ranks via one all-reduce.

    Args:
        should_ubatch: Whether this rank wants to microbatch.
        orig_num_tokens_per_ubatch: This rank's unpadded token count.
        padded_num_tokens_per_ubatch: This rank's padded token count.
        cudagraph_mode: This rank's cudagraph mode, as an integer.
        parallel_config: Supplies ``data_parallel_size`` and
            ``data_parallel_rank``, and selects the device/group used.

    Returns:
        An int32 tensor of shape ``(4, dp_size)`` after the all-reduce.
        Column ``r`` holds the values written by DP rank ``r``:
        row 0 = unpadded token count, row 1 = padded token count,
        row 2 = 1 if that rank wants to microbatch else 0,
        row 3 = cudagraph mode.
    """
    dp_size = parallel_config.data_parallel_size
    dp_rank = parallel_config.data_parallel_rank
    device, group = _get_device_and_group(parallel_config)

    # Every rank starts from zeros and fills in only its own column, so the
    # all-reduce below (default reduction op) leaves each column holding
    # exactly the values contributed by the rank that owns it.
    tensor = torch.zeros(4, dp_size, device=device, dtype=torch.int32)
    tensor[0, dp_rank] = orig_num_tokens_per_ubatch
    tensor[1, dp_rank] = padded_num_tokens_per_ubatch
    tensor[2, dp_rank] = 1 if should_ubatch else 0
    tensor[3, dp_rank] = cudagraph_mode
    dist.all_reduce(tensor, group=group)
    return tensor


57
def _post_process_ubatch(tensor: torch.Tensor, num_ubatches: int) -> bool:
    """Decide from the all-reduced metadata whether to microbatch.

    Args:
        tensor: The synchronized metadata tensor. Row 0 is read as the
            per-rank unpadded token counts, row 1 as the per-rank padded
            token counts and row 2 as the per-rank "wants to microbatch"
            flags (1 or 0).
        num_ubatches: Number of microbatches, forwarded to
            ``is_last_ubatch_empty``.

    Returns:
        True only if every rank's flag in row 2 equals 1 and
        ``is_last_ubatch_empty`` does not report an empty last microbatch
        for (min unpadded tokens, max padded tokens); False otherwise.
    """
    # First determine if we are going to be ubatching: a single rank that
    # did not ask for it vetoes microbatching for everyone.
    if not bool(torch.all(tensor[2] == 1).item()):
        return False

    # If the DP ranks are planning to ubatch, make sure that there is no
    # "empty" last ubatch. The smallest unpadded count against the largest
    # padded count is the pair handed to the check.
    orig_min_num_tokens = int(tensor[0, :].min().item())
    padded_max_num_tokens = int(tensor[1, :].max().item())
    if is_last_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens, num_ubatches):
        logger.debug(
            "Aborting ubatching %s %s", orig_min_num_tokens, padded_max_num_tokens
        )
        return False
    return True


77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch.Tensor:
    """Produce the per-rank token counts to run with, as a CPU tensor.

    Args:
        tensor: The synchronized metadata tensor; only row 1 (per-rank
            padded token counts) is read.
        should_dp_pad: Whether all ranks must run the same number of tokens.

    Returns:
        A CPU tensor with one entry per column of ``tensor``. Without DP
        padding it is row 1 moved to the CPU; with DP padding every entry
        is the maximum of row 1, stored as int32.
    """
    padded_counts = tensor[1, :]
    if not should_dp_pad:
        return padded_counts.cpu()

    # DP padding is enabled: every rank is brought up to the largest padded
    # token count so that all ranks process the same number of tokens.
    largest = int(padded_counts.max().item())
    return torch.full(
        (len(padded_counts),),
        largest,
        dtype=torch.int32,
        device="cpu",
    )


92
93
94
95
96
97
def _post_process_cudagraph_mode(tensor: torch.Tensor) -> int:
    """
    Synchronize cudagraph_mode across DP ranks by taking the minimum.
    If any rank has NONE (0), all ranks use NONE.
    This ensures all ranks send consistent values (all padded or all unpadded).

    Args:
        tensor: The synchronized metadata tensor; only row 3 (per-rank
            cudagraph mode) is read.

    Returns:
        The smallest cudagraph mode found in row 3, as a Python int.
    """
    return int(tensor[3, :].min().item())
99
100


101
102
103
104
def _synchronize_dp_ranks(
    num_tokens_unpadded: int,
    num_tokens_padded: int,
    should_attempt_ubatching: bool,
    cudagraph_mode: int,
    parallel_config: ParallelConfig,
) -> tuple[bool, torch.Tensor | None, int]:
    """
    1. Decides if each DP rank is going to microbatch. Either all ranks
    run with microbatching or none of them do.

    2. Determines the total number of tokens that each rank will run.
    When running microbatched or if cudagraph is enabled (synced across ranks),
    all ranks will be padded out so that they run with the same number of tokens.

    3. Synchronizes cudagraph_mode across ranks by taking the minimum.

    Args:
        num_tokens_unpadded: This rank's token count before padding.
        num_tokens_padded: This rank's token count after padding; must be
            at least ``num_tokens_unpadded``.
        should_attempt_ubatching: Whether this rank wants to microbatch.
        cudagraph_mode: This rank's cudagraph mode, as an integer.
        parallel_config: The parallel config.

    Returns: tuple[
        should_ubatch: Are all DP ranks going to microbatch
        num_tokens_after_padding: A tensor containing the total number of
        tokens per-microbatch for each DP rank including any DP padding.
        synced_cudagraph_mode: The synchronized cudagraph mode (min across ranks)
    ]

    """
    assert num_tokens_padded >= num_tokens_unpadded

    # Coordinate between the DP ranks via an All Reduce
    # to determine the total number of tokens that each rank
    # will run and if we are using ubatching or not.
    tensor = _run_ar(
        should_ubatch=should_attempt_ubatching,
        orig_num_tokens_per_ubatch=num_tokens_unpadded,
        padded_num_tokens_per_ubatch=num_tokens_padded,
        cudagraph_mode=cudagraph_mode,
        parallel_config=parallel_config,
    )

    # Synchronize cudagraph_mode across ranks first (take min).
    # This is needed before DP padding decision since we use the synced
    # cudagraph mode to determine whether DP padding is needed.
    synced_cudagraph_mode = _post_process_cudagraph_mode(tensor)

    # Check conditions for microbatching
    should_ubatch = _post_process_ubatch(tensor, parallel_config.num_ubatches)

    # DP padding is needed when cudagraph is enabled (synced across ranks)
    # or when ubatching/DBO is active (ubatching requires uniform batch
    # sizes across DP ranks currently).
    # Use the synced runtime cudagraph mode rather than the compilation config
    # so we can avoid padding when cudagraph is not enabled for this step.
    should_dp_pad = synced_cudagraph_mode != 0 or should_ubatch

    # Pad all DP ranks up to the maximum token count across ranks if
    # should_dp_pad is True
    num_tokens_after_padding = _post_process_dp_padding(
        tensor,
        should_dp_pad,
    )

    return should_ubatch, num_tokens_after_padding, synced_cudagraph_mode
162
163
164
165
166


def coordinate_batch_across_dp(
    num_tokens_unpadded: int,
    allow_microbatching: bool,
    parallel_config: ParallelConfig,
    num_tokens_padded: int | None = None,
    uniform_decode: bool | None = None,
    num_scheduled_tokens_per_request: np.ndarray | None = None,
    cudagraph_mode: int = 0,
) -> tuple[bool, torch.Tensor | None, int]:
    """
    Coordinates amongst all DP ranks to determine if and how the full batch
    should be split into microbatches.

    Args:
        num_tokens_unpadded: Number of tokens without accounting for padding
        allow_microbatching: If microbatching should be attempted
        parallel_config: The parallel config
        num_tokens_padded: Number of tokens including any non-DP padding (CUDA graphs,
            TP, etc). Defaults to num_tokens_unpadded when None.
        uniform_decode: Only used if allow_microbatching is True (must not be
            None in that case). True if the batch only contains single token
            decodes
        num_scheduled_tokens_per_request: The number of tokens per request.
            Accepted as part of the interface; this function does not
            currently read it.
        cudagraph_mode: The cudagraph mode for this rank (0=NONE, 1=PIECEWISE, 2=FULL).
            DP padding is enabled when synced cudagraph mode across ranks is not NONE.

    Returns: tuple[
        should_ubatch: True only if all DP ranks have agreed to microbatch.
        Always False when data_parallel_size is 1.
        num_tokens_after_padding: A tensor containing the total number of
        tokens per-microbatch for each DP rank including padding. Will be
        padded up to the max value across all DP ranks when cudagraph is enabled
        or when microbatching. None when data_parallel_size is 1.
        synced_cudagraph_mode: The synchronized cudagraph mode (min across ranks).
        When data_parallel_size is 1 this is the cudagraph_mode argument unchanged.
    ]

    """
    if parallel_config.data_parallel_size == 1:
        # Early exit: there are no other ranks to coordinate with.
        return False, None, cudagraph_mode

    # If the caller has explicitly enabled microbatching.
    should_attempt_ubatching = False
    if allow_microbatching:
        # Check preconditions for microbatching
        assert uniform_decode is not None
        should_attempt_ubatching = check_ubatch_thresholds(
            parallel_config,
            num_tokens_unpadded,
            uniform_decode=uniform_decode,
        )

    if num_tokens_padded is None:
        num_tokens_padded = num_tokens_unpadded

    (should_ubatch, num_tokens_after_padding, synced_cudagraph_mode) = (
        _synchronize_dp_ranks(
            num_tokens_unpadded,
            num_tokens_padded,
            should_attempt_ubatching,
            cudagraph_mode,
            parallel_config,
        )
    )

    return (should_ubatch, num_tokens_after_padding, synced_cudagraph_mode)