ubatch_utils.py 3.62 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
4
from typing import TypeAlias
5

6
import numpy as np
7

8
9
from vllm.config import ParallelConfig

10
11
12
13
14
15

@dataclass
class UBatchSlice:
    """One micro-batch: a range of requests plus the range of tokens it owns.

    Both ranges are Python ``slice`` objects used with half-open
    ``[start, stop)`` semantics.
    """

    # Range of request indices that have tokens in this micro-batch.
    request_slice: slice
    # Range of flattened token positions covered by this micro-batch.
    token_slice: slice

    def is_empty(self) -> bool:
        """Return True when this micro-batch has no requests or no tokens."""
        req, tok = self.request_slice, self.token_slice
        return req.stop == req.start or tok.stop == tok.start

    @property
    def num_tokens(self) -> int:
        """Number of token positions spanned by ``token_slice``."""
        tok = self.token_slice
        return tok.stop - tok.start

# Ordered list of per-micro-batch slices, one entry per ubatch, as built by
# `maybe_create_ubatch_slices`.
UBatchSlices: TypeAlias = list[UBatchSlice]


30
31
32
33
def is_last_ubatch_empty(
    orig_num_tokens: int, padded_num_tokens: int, num_ubatches: int
) -> bool:
    """Tell whether an even split would leave the final ubatch with no real tokens.

    The padded token count is divided into chunks of
    ``padded_num_tokens // num_ubatches`` tokens. The last ubatch therefore
    begins after ``num_ubatches - 1`` such chunks. If that start position is at
    or beyond ``orig_num_tokens``, the last ubatch would hold only padding.
    """
    chunk_size = padded_num_tokens // num_ubatches
    last_ubatch_start = chunk_size * (num_ubatches - 1)
    return last_ubatch_start >= orig_num_tokens


def check_ubatch_thresholds(
    config: ParallelConfig, num_tokens: int, uniform_decode: bool
) -> bool:
    """Decide whether a batch of ``num_tokens`` tokens should be micro-batched.

    Always False when ``config.use_ubatching`` is off. Otherwise the token count
    is compared against ``config.dbo_decode_token_threshold`` for uniform-decode
    batches, or ``config.dbo_prefill_token_threshold`` for all other batches.
    """
    if not config.use_ubatching:
        return False
    threshold = (
        config.dbo_decode_token_threshold
        if uniform_decode
        else config.dbo_prefill_token_threshold
    )
    return num_tokens >= threshold


47
# This pads the last ubatch slice out to the total number of tokens
# (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding.
def _pad_out_ubatch_slices(
    ubatch_slices: UBatchSlices, num_total_tokens: int, num_reqs_padded: int
) -> UBatchSlices:
    """Return a copy of ``ubatch_slices`` whose final slice reaches the padded end.

    Every slice except the last is reused unchanged. The last slice keeps its
    start positions, but its request range is stretched to ``num_reqs_padded``
    and its token range to ``num_total_tokens``. The input list is not mutated.
    """
    tail = ubatch_slices[-1]
    stretched_tail = UBatchSlice(
        slice(tail.request_slice.start, num_reqs_padded),
        slice(tail.token_slice.start, num_total_tokens),
    )
    head = ubatch_slices[:-1]
    return head + [stretched_tail]


def maybe_create_ubatch_slices(
    should_ubatch: bool,
    num_scheduled_tokens: np.ndarray,
    num_tokens_padded: int,
    num_reqs_padded: int,
    num_ubatches: int,
    split_point: list[int] | int | None = None,
) -> tuple[UBatchSlices | None, UBatchSlices | None]:
    """Split a scheduled batch into ``num_ubatches`` micro-batch slices.

    Args:
        should_ubatch: When False, nothing is split and ``(None, None)`` is
            returned.
        num_scheduled_tokens: Tokens scheduled for each request, in request
            order.
        num_tokens_padded: Total token count after padding; the padded copy's
            last slice ends here.
        num_reqs_padded: Total request count after padding; the padded copy's
            last request slice ends here.
        num_ubatches: Number of micro-batches to produce.
        split_point: Token boundaries between ubatches.
            - ``None`` (default): boundaries at multiples of
              ``num_tokens_padded // num_ubatches``.
            - An int ``k``: boundaries at ``k, 2k, ...``.
            - A list/tuple: the ``num_ubatches - 1`` boundary token positions,
              given explicitly.

    Returns:
        ``(ubatch_slices, ubatch_slices_padded)``. In the first list, the last
        slice ends at the real (unpadded) token count. The second is the same
        list with the last slice extended to ``num_tokens_padded`` /
        ``num_reqs_padded``.

    Raises:
        ValueError: If ``split_point`` is a list/tuple whose length is not
            ``num_ubatches - 1`` or whose entries are not non-decreasing.
    """
    if not should_ubatch:
        return None, None

    if isinstance(split_point, (list, tuple)):
        # A sequence gives the boundary positions directly. Multiplying it by
        # an int (as the scalar path does) would repeat the list, not scale it.
        token_split_points = [int(p) for p in split_point]
        if len(token_split_points) != num_ubatches - 1:
            raise ValueError(
                f"Expected {num_ubatches - 1} split points for "
                f"{num_ubatches} ubatches, got {len(token_split_points)}"
            )
        if any(b < a for a, b in zip(token_split_points, token_split_points[1:])):
            raise ValueError(
                f"Split points must be non-decreasing, got {token_split_points}"
            )
    else:
        if split_point is None:
            # Default: equal-sized ubatches over the padded token count.
            split_point = int(num_tokens_padded) // num_ubatches
        token_split_points = [split_point * i for i in range(1, num_ubatches)]

    # TODO(lucas): Refactor the gpu_model_runner.py so we can pass
    # in cu_num_tokens directly (i.e. query_start_loc)
    cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
    np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:])
    # Plain int so every slice bound is a Python int, not an np.int32 scalar.
    total_tokens = int(cu_num_tokens[-1])

    ubatch_slices = []
    start_token = 0

    # Add the end point to the split points to make iteration easier
    all_points = token_split_points + [total_tokens]

    for end_token in all_points:
        token_slice = slice(start_token, end_token)

        # Determine request slices using exclusive stop semantics
        # Ubatch includes requests whose tokens overlap [start_token, end_token)

        # Start at the request that contains the start_token
        # or the request starting exactly at start_token (if on boundary)
        req_start = int(np.searchsorted(cu_num_tokens, start_token, side="right") - 1)

        # Stop at the request that starts at or after end_token
        req_stop = int(np.searchsorted(cu_num_tokens, end_token, side="left"))

        req_slice = slice(req_start, req_stop)
        ubatch_slices.append(UBatchSlice(req_slice, token_slice))

        start_token = end_token

    ubatch_slices_padded = _pad_out_ubatch_slices(
        ubatch_slices, num_tokens_padded, num_reqs_padded
    )

    # Sanity check: the padded slices tile [0, num_tokens_padded) exactly.
    assert sum(s.num_tokens for s in ubatch_slices_padded) == num_tokens_padded

    return ubatch_slices, ubatch_slices_padded