# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Contains helpers that are applied to functions.

This is similar in concept to the `functools` module.
"""

import asyncio
import concurrent.futures
import inspect
import threading
import warnings
from collections.abc import Awaitable, Callable, Mapping
from functools import lru_cache, partial, wraps
from typing import Any, TypeVar

from typing_extensions import ParamSpec

from vllm.logger import init_logger

logger = init_logger(__name__)

P = ParamSpec("P")
T = TypeVar("T")
F = TypeVar("F", bound=Callable[..., Any])


def identity(value: T, **kwargs) -> T:
    """Returns the first provided value."""
    return value


def make_async(
    func: Callable[P, T],
    executor: concurrent.futures.Executor | None = None,
) -> Callable[P, Awaitable[T]]:
    """
    Take a blocking function, and run it in an executor thread.

    This function prevents the blocking function from blocking the
    asyncio event loop.
    The code in this function needs to be thread safe.

    Args:
        func: The blocking callable to wrap.
        executor: Executor handed to `loop.run_in_executor`. `None` selects
            the event loop's default executor.

    Returns:
        A wrapper that schedules `func(*args, **kwargs)` on the executor
        and returns the resulting awaitable future.
    """

    def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future[T]:
        loop = asyncio.get_event_loop()
        # run_in_executor only forwards positional arguments, so bind the
        # full call up front.
        p_func = partial(func, *args, **kwargs)
        return loop.run_in_executor(executor=executor, func=p_func)

    return _async_wrapper


def run_once(f: Callable[P, None]) -> Callable[P, None]:
    """
    Wrap `f` so that its body is executed at most once.

    The first call marks the wrapper as having run and then invokes `f`;
    every later call returns `None` without invoking `f`. The flag is set
    before `f` is called, so a call that raises is not retried.

    The state is stored on the wrapper itself as the attributes `has_run`
    and `lock`.
    """

    def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
        # Cheap check that avoids taking the lock once `f` has run.
        if wrapper.has_run:  # type: ignore[attr-defined]
            return

        with wrapper.lock:  # type: ignore[attr-defined]
            # Re-check under the lock: another thread may have won the race.
            if not wrapper.has_run:  # type: ignore[attr-defined]
                wrapper.has_run = True  # type: ignore[attr-defined]
                return f(*args, **kwargs)

    wrapper.has_run = False  # type: ignore[attr-defined]
    wrapper.lock = threading.Lock()  # type: ignore[attr-defined]
    return wrapper


def deprecate_args(
    start_index: int,
    is_deprecated: bool | Callable[[], bool] = True,
    additional_message: str | None = None,
) -> Callable[[F], F]:
    """
    Build a decorator that warns when deprecated positional arguments are
    passed positionally.

    Args:
        start_index: Index of the first positional parameter that is
            deprecated; arguments passed positionally at this index or later
            trigger the warning.
        is_deprecated: Either a flag or a zero-argument callable evaluated on
            every call; the warning is only emitted when it is truthy.
        additional_message: Optional text appended to the warning message.

    Returns:
        A decorator that preserves the wrapped function's behavior and emits
        a `DeprecationWarning` when applicable.
    """
    # Normalize the flag to a zero-argument callable so `inner` has a single
    # code path.
    if not callable(is_deprecated):
        is_deprecated = partial(identity, is_deprecated)

    def wrapper(fn: F) -> F:
        params = inspect.signature(fn).parameters
        pos_types = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        # Names of the parameters that can be filled positionally, in order.
        pos_kws = [kw for kw, param in params.items() if param.kind in pos_types]

        @wraps(fn)
        def inner(*args, **kwargs):
            if is_deprecated():
                deprecated_args = pos_kws[start_index : len(args)]
                if deprecated_args:
                    msg = (
                        f"The positional arguments {deprecated_args} are "
                        "deprecated and will be removed in a future update."
                    )
                    if additional_message is not None:
                        msg += f" {additional_message}"

                    warnings.warn(
                        DeprecationWarning(msg),
                        # Level 1 is this wrapper; level 2 is the code that
                        # called the decorated function, which is where the
                        # warning should be attributed.
                        stacklevel=2,
                    )

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper


def deprecate_kwargs(
    *kws: str,
    is_deprecated: bool | Callable[[], bool] = True,
    additional_message: str | None = None,
) -> Callable[[F], F]:
    """
    Build a decorator that warns when deprecated keyword arguments are used.

    Args:
        *kws: Names of the deprecated keyword arguments.
        is_deprecated: Either a flag or a zero-argument callable evaluated on
            every call; the warning is only emitted when it is truthy.
        additional_message: Optional text appended to the warning message.

    Returns:
        A decorator that preserves the wrapped function's behavior and emits
        a `DeprecationWarning` when any of `kws` is passed by keyword.
    """
    deprecated_kws = set(kws)

    # Normalize the flag to a zero-argument callable so `inner` has a single
    # code path.
    if not callable(is_deprecated):
        is_deprecated = partial(identity, is_deprecated)

    def wrapper(fn: F) -> F:
        @wraps(fn)
        def inner(*args, **kwargs):
            if is_deprecated():
                deprecated_kwargs = kwargs.keys() & deprecated_kws
                if deprecated_kwargs:
                    msg = (
                        f"The keyword arguments {deprecated_kwargs} are "
                        "deprecated and will be removed in a future update."
                    )
                    if additional_message is not None:
                        msg += f" {additional_message}"

                    warnings.warn(
                        DeprecationWarning(msg),
                        # Level 1 is this wrapper; level 2 is the code that
                        # called the decorated function, which is where the
                        # warning should be attributed.
                        stacklevel=2,
                    )

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper


@lru_cache
def supports_kw(
    callable: Callable[..., object],
    kw_name: str,
    *,
    requires_kw_only: bool = False,
    allow_var_kwargs: bool = True,
) -> bool:
    """Check if a keyword is a valid kwarg for a callable; if requires_kw_only
    disallows kwargs names that can also be positional arguments.

    Args:
        callable: The callable whose signature is inspected.
        kw_name: The keyword name to look for.
        requires_kw_only: If set, only keyword-only parameters count as a
            match; names of parameters that can be passed positionally are
            rejected.
        allow_var_kwargs: If set, a name that is not matched by an explicit
            parameter is still accepted when the callable ends with a
            `**kwargs` parameter (whose own name is never accepted).

    Returns:
        Whether `kw_name` can be passed to `callable` as a keyword argument
        under the given constraints. Results are memoized by `lru_cache`, so
        all arguments must be hashable.
    """
    params = inspect.signature(callable).parameters
    if not params:
        return False

    param = params.get(kw_name)
    if param is not None:
        kind = param.kind
        if requires_kw_only:
            if kind == inspect.Parameter.KEYWORD_ONLY:
                return True
            # We want kwargs only, but this name belongs to a parameter that
            # is passable as a positional arg.
            if kind in (
                inspect.Parameter.POSITIONAL_ONLY,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
            ):
                return False
        elif kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            # A positional-only parameter is deliberately excluded here: it
            # cannot be bound by keyword, so its name is only usable if a
            # `**kwargs` parameter can absorb it (checked below).
            return True

    # If we're okay with var-kwargs, it's supported as long as
    # the kw_name isn't something like *args, **kwargs
    if allow_var_kwargs:
        # Get the last param; type is ignored here because params is a proxy
        # mapping, but it wraps an ordered dict, and they appear in order.
        # Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters
        last_param = params[next(reversed(params))]  # type: ignore
        return (
            last_param.kind == inspect.Parameter.VAR_KEYWORD
            and last_param.name != kw_name
        )

    return False


def get_allowed_kwarg_only_overrides(
    callable: Callable[..., object],
    overrides: Mapping[str, object] | None,
    *,
    requires_kw_only: bool = True,
    allow_var_kwargs: bool = False,
) -> dict[str, Any]:
    """
    Given a callable which has one or more keyword only params and a dict
    mapping param names to values, drop values that cannot be kwarg
    expanded to overwrite one or more keyword-only args. This is used in a
    few places to handle custom processor overrides for multimodal models,
    e.g., for profiling when processor options provided by the user
    may affect the number of mm tokens per instance.

    Args:
        callable: Callable which takes 0 or more keyword only arguments.
            Its signature is inspected via `supports_kw`.
        overrides: Potential overrides to be used when invoking the callable.
            `None` or an empty mapping yields an empty result.
        requires_kw_only: If set, only overrides matching keyword-only
            parameters are kept.
        allow_var_kwargs: Allows overrides that are expandable for var kwargs.

    Returns:
        Dictionary containing the kwargs to be leveraged which may be used
        to overwrite one or more keyword only arguments when invoking the
        callable.
    """
    if not overrides:
        return {}

    # Drop any mm_processor_kwargs provided by the user that
    # are not kwargs, unless it can fit it var_kwargs param
    filtered_overrides = {
        kwarg_name: val
        for kwarg_name, val in overrides.items()
        if supports_kw(
            callable,
            kwarg_name,
            requires_kw_only=requires_kw_only,
            allow_var_kwargs=allow_var_kwargs,
        )
    }

    # If anything is dropped, log a warning
    dropped_keys = overrides.keys() - filtered_overrides.keys()
    if dropped_keys:
        if requires_kw_only:
            logger.warning(
                "The following intended overrides are not keyword-only args "
                "and will be dropped: %s",
                dropped_keys,
            )
        else:
            logger.warning(
                "The following intended overrides are not keyword args "
                "and will be dropped: %s",
                dropped_keys,
            )

    return filtered_overrides