# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Helpers that coordinate microbatching and DP padding across DP ranks."""

import numpy as np
import torch
import torch.distributed as dist

from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import get_dp_group
from vllm.logger import init_logger
from vllm.v1.worker.ubatch_utils import (
    UBatchSlice,
    UBatchSlices,
    check_ubatch_thresholds,
    create_ubatch_slices,
    is_second_ubatch_empty,
)

logger = init_logger(__name__)


def _get_device_and_group(parallel_config: ParallelConfig):
    """Return the ``(device, process_group)`` used for DP synchronization.

    By default this is the device assigned to the DP group together with the
    DP group's device group. When
    ``parallel_config.disable_nccl_for_dp_synchronization`` is set, it is
    ``("cpu", <DP group's cpu_group>)`` instead, and a one-time info message
    is logged.
    """
    dp_group = get_dp_group()

    # Transferring this tensor from GPU to CPU will introduce a GPU sync
    # point that could adversely affect performance of vllm with async
    # scheduling. This option exists to quickly disable this optimization
    # if we run into this case.
    if parallel_config.disable_nccl_for_dp_synchronization:
        logger.info_once(
            "Using CPU all reduce to synchronize DP padding between ranks."
        )
        return "cpu", dp_group.cpu_group

    # Use the actual device assigned to the DP group, not just the device type
    return dp_group.device, dp_group.device_group


def _run_ar(
    should_ubatch: bool,
    should_dp_pad: bool,
    orig_num_tokens_per_ubatch: int,
    padded_num_tokens_per_ubatch: int,
    parallel_config: ParallelConfig,
) -> torch.Tensor:
    """All-reduce this rank's batch information across the DP group.

    Builds a zero-initialised ``(4, dp_size)`` int32 tensor and writes this
    rank's values into column ``dp_rank``. It then sums the tensor across the
    DP group (``dist.all_reduce`` defaults to SUM). Each rank only fills its
    own column, so the result holds every rank's values side by side:

    - row 0: unpadded token count
    - row 1: padded token count
    - row 2: 1 if the rank wants to microbatch, else 0
    - row 3: 1 if the rank wants DP padding, else 0

    NOTE: despite the ``_per_ubatch`` parameter names, the caller in this
    module (``_synchronize_dp_ranks``) passes the rank's whole-batch token
    counts.

    This is a collective operation: every DP rank must call it.

    Returns:
        The reduced tensor, on the device chosen by ``_get_device_and_group``.
    """
    dp_size = parallel_config.data_parallel_size
    dp_rank = parallel_config.data_parallel_rank
    device, group = _get_device_and_group(parallel_config)

    tensor = torch.zeros(4, dp_size, device=device, dtype=torch.int32)
    tensor[0, dp_rank] = orig_num_tokens_per_ubatch
    tensor[1, dp_rank] = padded_num_tokens_per_ubatch
    tensor[2, dp_rank] = 1 if should_ubatch else 0
    tensor[3, dp_rank] = 1 if should_dp_pad else 0

    dist.all_reduce(tensor, group=group)
    return tensor


def _post_process_ubatch(tensor: torch.Tensor) -> bool:
    """Decide from the all-reduced tensor whether all DP ranks will microbatch.

    Microbatching is only kept if every rank asked for it (row 2 is all ones)
    and ``is_second_ubatch_empty`` reports that the second microbatch is not
    empty. That check combines the smallest unpadded count (row 0) with the
    largest padded count (row 1) across all ranks.
    """
    every_rank_wants_ubatch = bool(torch.all(tensor[2] == 1).item())
    if not every_rank_wants_ubatch:
        return False

    # Guard against "empty" second ubatches on any rank.
    smallest_unpadded = int(tensor[0, :].min().item())
    largest_padded = int(tensor[1, :].max().item())
    if not is_second_ubatch_empty(smallest_unpadded, largest_padded):
        return True

    logger.debug(
        "Aborting ubatching %s %s", smallest_unpadded, largest_padded
    )
    return False


def _post_process_dp_padding(
    tensor: torch.Tensor, should_dp_pad: bool
) -> torch.Tensor:
    """Return the number of tokens each DP rank will run, as a CPU tensor.

    Reads row 1 (padded token counts) of the all-reduced tensor. When
    ``should_dp_pad`` is set, every rank is assigned the maximum count across
    ranks in a new int32 tensor. Otherwise the row is returned as-is, moved
    to the CPU.
    """
    padded_counts = tensor[1, :]
    if not should_dp_pad:
        return padded_counts.cpu()

    # DP padding: every rank processes as many tokens as the largest rank.
    largest = int(padded_counts.max().item())
    return torch.full(
        (len(padded_counts),), largest, dtype=torch.int32, device="cpu"
    )


def _pad_out_ubatch_slice(
    ubatch_slices: UBatchSlices, num_total_tokens: int
) -> UBatchSlices:
    """Extend the second ubatch's token slice to ``num_total_tokens``.

    This just pads the second ubatch slice out to the total number of tokens
    (num_tokens + padding), since ``create_ubatch_slices`` is run before DP
    padding is applied. The second slice keeps its start and request slice;
    only the end of its token slice changes.

    Args:
        ubatch_slices: The microbatch slices. Entry ``[1]`` is replaced in
            place.
        num_total_tokens: The padded total number of tokens for this rank.

    Returns:
        The same ``ubatch_slices`` object, with entry ``[1]`` updated.
    """
    second = ubatch_slices[1]
    ubatch_slices[1] = UBatchSlice(
        second.request_slice,
        slice(second.token_slice.start, num_total_tokens),
    )
    return ubatch_slices


def _synchronize_dp_ranks(
    num_tokens_unpadded: int,
    num_tokens_padded: int,
    should_attempt_ubatching: bool,
    should_attempt_dp_padding: bool,
    parallel_config: ParallelConfig,
) -> tuple[bool, torch.Tensor | None]:
    """
    1. Decides if each DP rank is going to microbatch. Either all ranks
    run with microbatching or none of them do.

    2. Determines the total number of tokens that each rank will run.
    When running microbatched or if should_attempt_dp_padding is True, all
    ranks will be padded out so that they run with the same number of tokens.

    This performs an all-reduce over the DP group (via ``_run_ar``), so every
    DP rank must call it, and all ranks must pass the same
    ``should_attempt_dp_padding``.

    Returns: tuple[
        should_ubatch: Are all DP ranks going to microbatch
        num_tokens_after_padding: A CPU tensor containing the total number of
        tokens for each DP rank including any DP padding.
    ]
    """
    assert num_tokens_padded >= num_tokens_unpadded

    # Coordinate between the DP ranks via an All Reduce to determine the
    # total number of tokens that each rank will run and whether we are
    # using ubatching or not.
    tensor = _run_ar(
        should_ubatch=should_attempt_ubatching,
        should_dp_pad=should_attempt_dp_padding,
        orig_num_tokens_per_ubatch=num_tokens_unpadded,
        padded_num_tokens_per_ubatch=num_tokens_padded,
        parallel_config=parallel_config,
    )

    # DP ranks should all have the same value for should_attempt_dp_padding,
    # i.e. row 3 must be all zeros or all ones. Checking both extremes
    # (instead of only comparing with this rank's own flag) makes a mismatch
    # fail on every rank, not just on the ranks that requested padding.
    num_ranks = tensor.shape[1]
    num_ranks_requesting_pad = int(tensor[3].sum().item())
    assert num_ranks_requesting_pad in (0, num_ranks), (
        f"DP ranks disagree on DP padding: {num_ranks_requesting_pad} of "
        f"{num_ranks} ranks requested it"
    )
    should_dp_pad = num_ranks_requesting_pad == num_ranks
    assert should_attempt_dp_padding == should_dp_pad

    # Check conditions for microbatching
    should_ubatch = _post_process_ubatch(tensor)

    if should_ubatch and not should_dp_pad:
        logger.debug_once(
            "Microbatching has been triggered and requires DP padding. "
            "Enabling DP padding even though it has been explicitly "
            "disabled.",
            scope="global",
        )
        should_dp_pad = True

    # Pad all DP ranks up to the maximum token count across ranks if
    # should_dp_pad is True
    num_tokens_after_padding = _post_process_dp_padding(
        tensor,
        should_dp_pad,
    )

    return should_ubatch, num_tokens_after_padding


def coordinate_batch_across_dp(
    num_tokens_unpadded: int,
    allow_microbatching: bool,
    allow_dp_padding: bool,
    parallel_config: ParallelConfig,
    num_tokens_padded: int | None = None,
    uniform_decode: bool | None = None,
    num_scheduled_tokens_per_request: np.ndarray | None = None,
) -> tuple[UBatchSlices | None, torch.Tensor | None]:
    """
    Coordinates amongst all DP ranks to determine if and how the full batch
    should be split into microbatches.

    When data_parallel_size > 1 this performs a collective all-reduce, so it
    must be called on every DP rank.

    Args:
        num_tokens_unpadded: Number of tokens without accounting for padding
        allow_microbatching: If microbatching should be attempted
        allow_dp_padding: If all DP ranks should be padded up to the same value
        parallel_config: The parallel config
        num_tokens_padded: Number of tokens including any non-DP padding (CUDA graphs,
            TP, etc). Defaults to num_tokens_unpadded.
        uniform_decode: Only used if allow_microbatching is True. True if the batch
            only contains single token decodes
        num_scheduled_tokens_per_request: Must be provided whenever microbatching
            ends up being used. The number of tokens per request.

    Returns: tuple[
        ubatch_slices: if this is set then all DP ranks have agreed to
        microbatch. The second slice is extended to also cover DP padding.
        num_tokens_after_padding: A CPU tensor containing the total number of
        tokens for each DP rank including padding. Will be padded up to the
        max value across all DP ranks when allow_dp_padding is True or when
        microbatching.
    ]
    Both entries are None when data_parallel_size == 1.
    """
    if parallel_config.data_parallel_size == 1:
        # Early exit: there are no other DP ranks to coordinate with.
        return None, None

    # Microbatching is only attempted when the caller explicitly enables it
    # and this rank's batch passes the ubatch thresholds.
    should_attempt_ubatching = False
    if allow_microbatching:
        # Check preconditions for microbatching
        assert uniform_decode is not None
        should_attempt_ubatching = check_ubatch_thresholds(
            parallel_config,
            num_tokens_unpadded,
            uniform_decode=uniform_decode,
        )

    if num_tokens_padded is None:
        num_tokens_padded = num_tokens_unpadded

    should_ubatch, num_tokens_after_padding = _synchronize_dp_ranks(
        num_tokens_unpadded,
        num_tokens_padded,
        should_attempt_ubatching,
        allow_dp_padding,
        parallel_config,
    )

    # Don't microbatch unless every other DP worker is also microbatching
    if not should_ubatch:
        return None, num_tokens_after_padding

    # When microbatching, _synchronize_dp_ranks forces DP padding, so every
    # entry of num_tokens_after_padding holds the same (max) value. Entry 0
    # is therefore also this rank's padded total.
    assert num_tokens_after_padding is not None
    padded_total = int(num_tokens_after_padding[0].item())

    # Split the padded total in half. The slices are built from the per-request
    # token counts before DP padding, so the second slice is then stretched
    # to end at the padded total.
    token_split_point = padded_total // 2

    assert num_scheduled_tokens_per_request is not None
    ubatch_slices = create_ubatch_slices(
        num_scheduled_tokens_per_request, token_split_point
    )
    ubatch_slices = _pad_out_ubatch_slice(ubatch_slices, padded_total)
    assert sum(s.num_tokens for s in ubatch_slices) == padded_total

    return ubatch_slices, num_tokens_after_padding