# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

from typing_extensions import TypeAlias


@dataclass
class UBatchSlice:
    """Pair of slices describing one micro-batch.

    ``request_slice`` and ``token_slice`` are plain ``slice`` objects; this
    class only inspects their ``start`` and ``stop`` attributes (``step`` is
    never consulted).
    """

    # Slice over requests (presumably indexes the batch's request
    # dimension -- confirm against callers).
    request_slice: slice
    # Slice over tokens (presumably indexes the flattened token
    # dimension -- confirm against callers).
    token_slice: slice

    def is_empty(self) -> bool:
        """Return True if either slice has identical ``start`` and ``stop``.

        The comparison is plain equality, so a slice whose bounds are both
        ``None`` also compares as empty here.
        """
        return (
            self.request_slice.start == self.request_slice.stop
            or self.token_slice.start == self.token_slice.stop
        )

    @property
    def num_tokens(self) -> int:
        """Length of ``token_slice`` computed as ``stop - start``.

        Both bounds must be integers; a ``None`` bound raises ``TypeError``.
        """
        return self.token_slice.stop - self.token_slice.start


# Alias for a list of UBatchSlice entries.
UBatchSlices: TypeAlias = list[UBatchSlice]


def is_second_ubatch_empty(orig_num_tokens_per_ubatch: int,
                           padded_num_tokens_per_ubatch: int) -> bool:
    """Return True when the padded count is at least twice the original.

    Compares ``padded_num_tokens_per_ubatch`` against
    ``2 * orig_num_tokens_per_ubatch``.
    NOTE(review): presumably this means padding grew a micro-batch enough
    that the second one receives no tokens -- confirm with the caller.
    """
    doubled_original = orig_num_tokens_per_ubatch * 2
    return doubled_original <= padded_num_tokens_per_ubatch