# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

from typing_extensions import TypeAlias


@dataclass
class UBatchSlice:
    """A pair of slices: one over requests and one over tokens.

    NOTE(review): presumably each instance describes the portion of a batch
    assigned to a single ubatch (micro-batch) -- confirm against callers.

    The helpers below read the raw ``start``/``stop`` attributes of the
    slices. ``step`` is ignored and ``None`` bounds are not normalized, so
    both slices are expected to carry explicit integer bounds.
    """

    request_slice: slice
    token_slice: slice

    def is_empty(self) -> bool:
        """Return True if either slice has ``start == stop``."""
        return (
            self.request_slice.start == self.request_slice.stop
            or self.token_slice.start == self.token_slice.stop
        )

    @property
    def num_tokens(self) -> int:
        """Return the width of ``token_slice`` (``stop - start``)."""
        return self.token_slice.stop - self.token_slice.start

# Alias for a list of UBatchSlice entries.
UBatchSlices: TypeAlias = list[UBatchSlice]


def is_second_ubatch_empty(
    orig_num_tokens_per_ubatch: int, padded_num_tokens_per_ubatch: int
) -> bool:
    """Return True if the padded per-ubatch size is at least twice the original.

    NOTE(review): presumably the batch is split into two ubatches, so a padded
    size of ``2 * orig_num_tokens_per_ubatch`` or more lets the first ubatch
    hold every token and leaves nothing for the second -- confirm with callers.

    Args:
        orig_num_tokens_per_ubatch: Token count per ubatch before padding.
        padded_num_tokens_per_ubatch: Token count per ubatch after padding.

    Returns:
        ``padded_num_tokens_per_ubatch >= 2 * orig_num_tokens_per_ubatch``.
    """
    return padded_num_tokens_per_ubatch >= 2 * orig_num_tokens_per_ubatch