Unverified commit d2740faf, authored by Cyrus Leung and committed by GitHub
Browse files

[Chore] Separate out `vllm.utils.collections` (#26990)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 17838e50
......@@ -19,7 +19,8 @@ import numpy as np
import torch
from typing_extensions import assert_never
from vllm.utils import LazyLoader, is_list_of
from vllm.utils import LazyLoader
from vllm.utils.collections import is_list_of
from .audio import AudioResampler
from .inputs import (
......@@ -364,7 +365,7 @@ class MultiModalDataParser:
if isinstance(data, torch.Tensor):
return data.ndim == 3
if is_list_of(data, torch.Tensor):
return data[0].ndim == 2
return data[0].ndim == 2 # type: ignore[index]
return False
......@@ -422,6 +423,7 @@ class MultiModalDataParser:
if self._is_embeddings(data):
return AudioEmbeddingItems(data)
data_items: list[AudioItem]
if (
is_list_of(data, float)
or isinstance(data, (np.ndarray, torch.Tensor))
......@@ -432,7 +434,7 @@ class MultiModalDataParser:
elif isinstance(data, (np.ndarray, torch.Tensor)):
data_items = [elem for elem in data]
else:
data_items = data
data_items = data # type: ignore[assignment]
new_audios = list[np.ndarray]()
for data_item in data_items:
......@@ -485,6 +487,7 @@ class MultiModalDataParser:
if self._is_embeddings(data):
return VideoEmbeddingItems(data)
data_items: list[VideoItem]
if (
is_list_of(data, PILImage.Image)
or isinstance(data, (np.ndarray, torch.Tensor))
......@@ -496,7 +499,7 @@ class MultiModalDataParser:
elif isinstance(data, tuple) and len(data) == 2:
data_items = [data]
else:
data_items = data
data_items = data # type: ignore[assignment]
new_videos = list[tuple[np.ndarray, dict[str, Any] | None]]()
metadata_lst: list[dict[str, Any] | None] = []
......
......@@ -25,7 +25,7 @@ from typing_extensions import TypeVar, assert_never
from vllm.logger import init_logger
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.transformers_utils.tokenizer import AnyTokenizer, decode_tokens, encode_tokens
from vllm.utils import flatten_2d_lists, full_groupby
from vllm.utils.collections import flatten_2d_lists, full_groupby
from vllm.utils.functools import get_allowed_kwarg_only_overrides
from vllm.utils.jsontree import JSONTree, json_map_leaves
......@@ -484,8 +484,11 @@ _M = TypeVar("_M", bound=_HasModalityAttr | _HasModalityProp)
def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
"""Convenience function to apply [`full_groupby`][vllm.utils.full_groupby]
based on modality."""
"""
Convenience function to apply
[`full_groupby`][vllm.utils.collections.full_groupby]
based on modality.
"""
return full_groupby(values, key=lambda x: x.modality)
......
......@@ -9,7 +9,7 @@ import torch.nn as nn
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, cached_tokenizer_from_config
from vllm.utils import ClassRegistry
from vllm.utils.collections import ClassRegistry
from .cache import BaseMultiModalProcessorCache
from .processing import (
......
......@@ -8,7 +8,8 @@ from functools import cached_property
from typing import TYPE_CHECKING, Any
from vllm.logger import init_logger
from vllm.utils import import_from_path, is_list_of
from vllm.utils import import_from_path
from vllm.utils.collections import is_list_of
if TYPE_CHECKING:
from vllm.entrypoints.openai.protocol import (
......
......@@ -37,29 +37,19 @@ from argparse import (
RawDescriptionHelpFormatter,
_ArgumentGroup,
)
from collections import UserDict, defaultdict
from collections import defaultdict
from collections.abc import (
Callable,
Collection,
Generator,
Hashable,
Iterable,
Iterator,
Mapping,
Sequence,
)
from concurrent.futures.process import ProcessPoolExecutor
from dataclasses import dataclass, field
from functools import cache, lru_cache, partial, wraps
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Generic,
Literal,
TextIO,
TypeVar,
)
from typing import TYPE_CHECKING, Any, TextIO, TypeVar
from urllib.parse import urlparse
from uuid import uuid4
......@@ -78,7 +68,7 @@ import zmq.asyncio
from packaging import version
from packaging.version import Version
from torch.library import Library
from typing_extensions import Never, TypeIs, assert_never
from typing_extensions import Never
import vllm.envs as envs
from vllm.logger import enable_trace_function_call, init_logger
......@@ -170,9 +160,6 @@ def set_default_torch_num_threads(num_threads: int):
T = TypeVar("T")
U = TypeVar("U")
_K = TypeVar("_K", bound=Hashable)
_V = TypeVar("_V")
class Device(enum.Enum):
GPU = enum.auto()
......@@ -421,12 +408,6 @@ def update_environment_variables(envs: dict[str, str]):
os.environ[k] = v
def chunk_list(lst: list[T], chunk_size: int):
"""Yield successive chunk_size chunks from lst."""
for i in range(0, len(lst), chunk_size):
yield lst[i : i + chunk_size]
def cdiv(a: int, b: int) -> int:
"""Ceiling division."""
return -(a // -b)
......@@ -743,53 +724,6 @@ def common_broadcastable_dtype(dtypes: Collection[torch.dtype]):
)
def as_list(maybe_list: Iterable[T]) -> list[T]:
"""Convert iterable to list, unless it's already a list."""
return maybe_list if isinstance(maybe_list, list) else list(maybe_list)
def as_iter(obj: T | Iterable[T]) -> Iterable[T]:
if isinstance(obj, str) or not isinstance(obj, Iterable):
return [obj] # type: ignore[list-item]
return obj
# `collections` helpers
def is_list_of(
value: object,
typ: type[T] | tuple[type[T], ...],
*,
check: Literal["first", "all"] = "first",
) -> TypeIs[list[T]]:
if not isinstance(value, list):
return False
if check == "first":
return len(value) == 0 or isinstance(value[0], typ)
elif check == "all":
return all(isinstance(v, typ) for v in value)
assert_never(check)
def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]:
"""Flatten a list of lists to a single list."""
return [item for sublist in lists for item in sublist]
def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
"""
Unlike [`itertools.groupby`][], groups are not broken by
non-contiguous data.
"""
groups = defaultdict[_K, list[_V]](list)
for value in values:
groups[key(value)].append(value)
return groups.items()
# TODO: This function can be removed if transformer_modules classes are
# serialized by value when communicating between processes
def init_cached_hf_modules() -> None:
......@@ -1578,50 +1512,6 @@ class AtomicCounter:
return self._value
# Adapted from: https://stackoverflow.com/a/47212782/5082708
class LazyDict(Mapping[str, T], Generic[T]):
def __init__(self, factory: dict[str, Callable[[], T]]):
self._factory = factory
self._dict: dict[str, T] = {}
def __getitem__(self, key: str) -> T:
if key not in self._dict:
if key not in self._factory:
raise KeyError(key)
self._dict[key] = self._factory[key]()
return self._dict[key]
def __setitem__(self, key: str, value: Callable[[], T]):
self._factory[key] = value
def __iter__(self):
return iter(self._factory)
def __len__(self):
return len(self._factory)
class ClassRegistry(UserDict[type[T], _V]):
def __getitem__(self, key: type[T]) -> _V:
for cls in key.mro():
if cls in self.data:
return self.data[cls]
raise KeyError(key)
def __contains__(self, key: object) -> bool:
return self.contains(key)
def contains(self, key: object, *, strict: bool = False) -> bool:
if not isinstance(key, type):
return False
if strict:
return key in self.data
return any(cls in self.data for cls in key.mro())
def weak_ref_tensor(tensor: Any) -> Any:
"""
Create a weak reference to a tensor.
......@@ -2588,22 +2478,6 @@ class LazyLoader(types.ModuleType):
return dir(self._module)
def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None:
"""
Helper function to swap values for two keys
"""
v1 = obj.get(key1)
v2 = obj.get(key2)
if v1 is not None:
obj[key2] = v1
else:
obj.pop(key2, None)
if v2 is not None:
obj[key1] = v2
else:
obj.pop(key1, None)
@contextlib.contextmanager
def cprofile_context(save_file: str | None = None):
"""Run a cprofile
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Contains helpers that are applied to collections.
This is similar in concept to the `collections` module.
"""
from collections import UserDict, defaultdict
from collections.abc import Callable, Generator, Hashable, Iterable, Mapping
from typing import Generic, Literal, TypeVar
from typing_extensions import TypeIs, assert_never
T = TypeVar("T")
U = TypeVar("U")
_K = TypeVar("_K", bound=Hashable)
_V = TypeVar("_V")
class ClassRegistry(UserDict[type[T], _V]):
    """
    A registry that acts like a dictionary but searches for other classes
    in the MRO if the original class is not found.

    Lookups walk the key's method resolution order from the class itself
    towards its bases, so a value registered for a base class is also
    returned for any subclass that has no closer entry.
    """

    def __getitem__(self, key: type[T]) -> _V:
        """
        Return the value registered for `key` or for its nearest base class.

        Raises:
            KeyError: If no class in the MRO of `key` is registered.
        """
        # Read `__mro__` rather than calling `mro()`: when `key` is a
        # metaclass such as `type`, `key.mro()` is an unbound method call
        # and raises `TypeError`.
        for cls in key.__mro__:
            if cls in self.data:
                return self.data[cls]

        raise KeyError(key)

    def __contains__(self, key: object) -> bool:
        """Support `key in registry` using the non-strict MRO lookup."""
        return self.contains(key)

    def contains(self, key: object, *, strict: bool = False) -> bool:
        """
        Check whether `key` resolves to a registered value.

        Args:
            key: The object to look up; anything that is not a class is
                never contained.
            strict: If `True`, only an entry registered for `key` itself
                counts; otherwise an entry for any class in its MRO does.
        """
        if not isinstance(key, type):
            return False

        if strict:
            return key in self.data

        return any(cls in self.data for cls in key.__mro__)
class LazyDict(Mapping[str, T], Generic[T]):
    """
    Evaluates dictionary items only when they are accessed.

    Each key is associated with a zero-argument factory. The factory is
    called the first time the key is read and its result is cached for
    subsequent reads.

    Adapted from: https://stackoverflow.com/a/47212782/5082708
    """

    def __init__(self, factory: dict[str, Callable[[], T]]):
        # The factory mapping is kept by reference, not copied.
        self._factory = factory
        # Values that have already been produced by their factory.
        self._dict: dict[str, T] = {}

    def __getitem__(self, key: str) -> T:
        """
        Return the value for `key`, building it on first access.

        Raises:
            KeyError: If no factory is registered for `key`.
        """
        if key not in self._dict:
            if key not in self._factory:
                raise KeyError(key)

            self._dict[key] = self._factory[key]()

        return self._dict[key]

    def __setitem__(self, key: str, value: Callable[[], T]):
        """Register (or replace) the factory used to build `key`."""
        self._factory[key] = value
        # Drop any value built by the previous factory; otherwise a key
        # that was already accessed would keep returning the stale value.
        self._dict.pop(key, None)

    def __iter__(self):
        return iter(self._factory)

    def __len__(self):
        return len(self._factory)
def as_list(maybe_list: Iterable[T]) -> list[T]:
    """
    Convert iterable to list, unless it's already a list.

    A `list` argument is returned as the very same object (no copy is
    made); any other iterable is consumed into a new list.
    """
    if isinstance(maybe_list, list):
        return maybe_list

    return list(maybe_list)
def as_iter(obj: T | Iterable[T]) -> Iterable[T]:
    """
    Return `obj` as an iterable of items.

    A non-iterable object is wrapped in a single-element list. A string is
    also wrapped, so it is handled as one item instead of being iterated
    character by character. Any other iterable is returned unchanged.
    """
    if isinstance(obj, str) or not isinstance(obj, Iterable):
        return [obj]  # type: ignore[list-item]

    return obj
def is_list_of(
    value: object,
    typ: type[T] | tuple[type[T], ...],
    *,
    check: Literal["first", "all"] = "first",
) -> TypeIs[list[T]]:
    """
    Check whether `value` is a list whose items are instances of `typ`.

    Args:
        value: The object to inspect.
        typ: A type, or tuple of types, accepted by `isinstance`.
        check: `"first"` inspects only the first item (an empty list
            passes); `"all"` inspects every item.

    Returns:
        `False` if `value` is not a `list`; otherwise the result of the
        selected check. With `check="first"`, items after the first one
        are not inspected.
    """
    if not isinstance(value, list):
        return False

    if check == "first":
        return len(value) == 0 or isinstance(value[0], typ)
    elif check == "all":
        return all(isinstance(v, typ) for v in value)

    # Every valid `check` value returned above.
    assert_never(check)
def chunk_list(lst: list[T], chunk_size: int) -> Generator[list[T]]:
    """
    Yield successive chunk_size chunks from lst.

    The last chunk is shorter when `len(lst)` is not a multiple of
    `chunk_size`.

    Raises:
        ValueError: If `chunk_size` is not positive. As this is a
            generator, the error surfaces when iteration starts.
    """
    # A negative step would make `range` empty and silently drop every
    # item; a zero step fails inside `range` with an unrelated message.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    for i in range(0, len(lst), chunk_size):
        yield lst[i : i + chunk_size]
def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]:
    """
    Flatten a list of lists to a single list.

    Only one level of nesting is removed; items keep their original order.
    """
    flat: list[T] = []
    for sublist in lists:
        flat.extend(sublist)

    return flat
def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
    """
    Unlike [`itertools.groupby`][], groups are not broken by
    non-contiguous data.

    Args:
        values: The items to group.
        key: Maps each item to the hashable key of its group.

    Returns:
        An items view of `(group_key, items)` pairs. Groups appear in the
        order their key was first seen, and items keep their input order
        within each group.
    """
    groups: defaultdict[_K, list[_V]] = defaultdict(list)
    for value in values:
        groups[key(value)].append(value)

    return groups.items()
def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None:
    """
    Swap values between two keys.

    The dictionary is modified in place. If only one of the keys is
    present, its value is moved to the other key; if neither is present,
    nothing changes.
    """
    # Test membership instead of comparing `.get()` with `None`, so that a
    # stored `None` value is swapped like any other value rather than
    # being mistaken for a missing key and dropped.
    has_key1 = key1 in obj
    has_key2 = key2 in obj

    if has_key1 and has_key2:
        obj[key1], obj[key2] = obj[key2], obj[key1]
    elif has_key1:
        obj[key2] = obj.pop(key1)
    elif has_key2:
        obj[key1] = obj.pop(key2)
......@@ -29,8 +29,9 @@ from vllm.tracing import init_tracer
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device, as_list, cdiv
from vllm.utils import Device, cdiv
from vllm.utils.asyncio import cancel_task_threadsafe
from vllm.utils.collections import as_list
from vllm.utils.functools import deprecate_kwargs
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient
......
......@@ -12,7 +12,8 @@ from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.collections import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import (
......
......@@ -9,7 +9,8 @@ import torch
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingType
from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.collections import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.worker.block_table import MultiGroupBlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment