# tqdm_utils.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Iterable, Sequence
from typing import Any, TypeVar, overload

from tqdm.auto import tqdm

# Type variable used as the element type in the ``maybe_tqdm`` signatures.
# NOTE(review): ``bound=Iterable`` means a type checker only accepts element
# types that are themselves iterable -- confirm this restriction is intended.
_T = TypeVar("_T", bound=Iterable)


# Typing-only overload: when the caller passes a ``Sequence``, the result is
# annotated as a ``Sequence`` of the same element type. The ``...`` body is
# never executed; the runtime behavior comes from the non-overload definition.
@overload
def maybe_tqdm(
    it: Sequence[_T],
    *,
    use_tqdm: bool | Callable[..., tqdm],
    **tqdm_kwargs: Any,
) -> Sequence[_T]: ...


# Typing-only overload: for any other ``Iterable`` argument, the result is
# annotated as an ``Iterable`` of the same element type. The ``...`` body is
# never executed; the runtime behavior comes from the non-overload definition.
@overload
def maybe_tqdm(
    it: Iterable[_T],
    *,
    use_tqdm: bool | Callable[..., tqdm],
    **tqdm_kwargs: Any,
) -> Iterable[_T]: ...


def maybe_tqdm(
    it: Iterable[_T],
    *,
    use_tqdm: bool | Callable[..., tqdm],
    **tqdm_kwargs: Any,
) -> Iterable[_T]:
    """Optionally wrap ``it`` in a progress bar.

    Args:
        it: The iterable to (maybe) wrap.
        use_tqdm: Controls wrapping. A falsy value returns ``it`` untouched.
            A callable is used as the progress-bar factory. Any other truthy
            value selects the ``tqdm`` imported by this module.
        **tqdm_kwargs: Forwarded unchanged to the progress-bar factory;
            ignored when wrapping is disabled.

    Returns:
        ``it`` itself when wrapping is disabled, otherwise whatever the
        selected factory returns for ``it``.
    """
    if use_tqdm:
        # A user-supplied factory takes precedence over the default ``tqdm``.
        if callable(use_tqdm):
            progress_factory = use_tqdm
        else:
            progress_factory = tqdm
        return progress_factory(it, **tqdm_kwargs)

    # Progress reporting disabled: hand the iterable back as-is.
    return it