Commit 96ae75ad authored by zhuwenwen
Browse files

Merge tag 'v0.6.6.post1' into v0.6.6.post1-dev

parents f9f4a735 2339d59f
import fnmatch
import os
import shutil
import signal
import tempfile
from pathlib import Path
from typing import Optional
from vllm.utils import PlaceholderModule
try:
import boto3
except ImportError:
boto3 = PlaceholderModule("boto3") # type: ignore[assignment]
def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]:
return [
path for path in paths if any(
fnmatch.fnmatch(path, pattern) for pattern in patterns)
]
def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]:
return [
path for path in paths
if not any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
]
def glob(s3=None,
         path: str = "",
         allow_pattern: Optional[list[str]] = None) -> list[str]:
    """
    Return the full ``s3://`` names of the objects under an S3 path.

    Args:
        s3: S3 client to use. When None, a client is created with
            ``boto3.client("s3")``.
        path: The S3 path to list from.
        allow_pattern: A list of patterns of which files to pull.

    Returns:
        list[str]: Full S3 paths (``s3://<bucket>/<key>``) allowed by the
            pattern.
    """
    client = boto3.client("s3") if s3 is None else s3
    listing = list_files(client, path=path, allow_pattern=allow_pattern)
    # list_files returns (bucket, prefix, keys); the prefix is not needed here.
    bucket_name, keys = listing[0], listing[2]
    return [f"s3://{bucket_name}/{key}" for key in keys]
def list_files(
    s3,
    path: str,
    allow_pattern: Optional[list[str]] = None,
    ignore_pattern: Optional[list[str]] = None
) -> tuple[str, str, list[str]]:
    """
    List files from S3 path and filter by pattern.

    All result pages are fetched: a single ``list_objects_v2`` call returns
    only one page of keys, so the continuation token is followed until the
    listing is no longer truncated.

    Args:
        s3: S3 client to use.
        path: The S3 path to list from, with or without the ``s3://`` prefix.
        allow_pattern: A list of patterns of which files to pull.
        ignore_pattern: A list of patterns of which files not to pull.

    Returns:
        tuple[str, str, list[str]]: A tuple where:
            - The first element is the bucket name
            - The second element is the key prefix (everything after the
              bucket name in ``path``)
            - The third element is a list of object keys allowed or
              disallowed by pattern
    """
    # "<bucket>/<prefix...>": the first path component is the bucket.
    bucket_name, _, prefix = path.removeprefix('s3://').partition('/')

    request = {"Bucket": bucket_name, "Prefix": prefix}
    paths: list[str] = []
    while True:
        response = s3.list_objects_v2(**request)
        paths.extend(obj['Key'] for obj in response.get('Contents', []))

        next_token = response.get('NextContinuationToken')
        if not response.get('IsTruncated') or not next_token:
            break
        if next_token == request.get('ContinuationToken'):
            # The token did not advance; stop rather than loop forever.
            break
        request['ContinuationToken'] = next_token

    # Keys ending in "/" are directory placeholders, not files.
    paths = _filter_ignore(paths, ["*/"])
    if allow_pattern is not None:
        paths = _filter_allow(paths, allow_pattern)

    if ignore_pattern is not None:
        paths = _filter_ignore(paths, ignore_pattern)

    return bucket_name, prefix, paths
class S3Model:
    """
    A class representing a S3 model mirrored into a temporary directory.

    The temporary directory is removed when the object is garbage collected
    and when SIGINT or SIGTERM is received; handlers that were installed
    before this object was created are still invoked afterwards.

    Attributes:
        s3: S3 client.
        dir: The temporary created directory.

    Methods:
        pull_files(): Pull model from S3 to the temporary directory.
    """

    def __init__(self) -> None:
        self.s3 = boto3.client('s3')
        # Create the directory before installing the handlers so that a
        # signal arriving right after registration finds `self.dir` set.
        self.dir = tempfile.mkdtemp()
        for sig in (signal.SIGINT, signal.SIGTERM):
            existing_handler = signal.getsignal(sig)
            signal.signal(sig, self._close_by_signal(existing_handler))

    def __del__(self):
        self._close()

    def _close(self) -> None:
        """Remove the temporary directory if it exists."""
        # `dir` is missing when __init__ failed before creating it (for
        # example when the S3 client could not be built); __del__ still runs.
        directory = getattr(self, "dir", None)
        if directory and os.path.exists(directory):
            shutil.rmtree(directory)

    def _close_by_signal(self, existing_handler=None):
        """
        Build a signal handler that cleans up and then defers to the
        previously installed handler.

        Args:
            existing_handler: The value returned by ``signal.getsignal`` for
                the signal, or None.

        Returns:
            A ``(signum, frame)`` handler suitable for ``signal.signal``.
        """

        def new_handler(signum, frame):
            self._close()
            if callable(existing_handler):
                existing_handler(signum, frame)
            elif existing_handler == signal.SIG_DFL:
                # The previous disposition was the OS default. Restore it and
                # re-deliver the signal so the process still reacts to it
                # (e.g. terminates on SIGTERM) instead of swallowing it.
                signal.signal(signum, signal.SIG_DFL)
                signal.raise_signal(signum)
            # SIG_IGN / None: the signal was ignored before, nothing to chain.

        return new_handler

    def pull_files(self,
                   s3_model_path: str = "",
                   allow_pattern: Optional[list[str]] = None,
                   ignore_pattern: Optional[list[str]] = None) -> None:
        """
        Pull files from S3 storage into the temporary directory.

        Each object is stored under ``self.dir`` at its key relative to the
        listed prefix.

        Args:
            s3_model_path: The S3 path of the model.
            allow_pattern: A list of patterns of which files to pull.
            ignore_pattern: A list of patterns of which files not to pull.
        """
        bucket_name, base_dir, files = list_files(self.s3, s3_model_path,
                                                  allow_pattern,
                                                  ignore_pattern)

        for file in files:
            # Join instead of concatenating: whether or not the prefix ends
            # with "/", the file must land inside the temporary directory.
            relative_path = file.removeprefix(base_dir).lstrip("/")
            if not relative_path:
                # The prefix is the object key itself; keep its file name.
                relative_path = os.path.basename(file)
            destination_file = os.path.join(self.dir, relative_path)
            local_dir = Path(destination_file).parent
            os.makedirs(local_dir, exist_ok=True)
            self.s3.download_file(bucket_name, file, destination_file)
...@@ -132,7 +132,7 @@ def get_tokenizer( ...@@ -132,7 +132,7 @@ def get_tokenizer(
if is_from_mistral_org and tokenizer_mode != "mistral": if is_from_mistral_org and tokenizer_mode != "mistral":
warnings.warn( warnings.warn(
'It is strongly recommended to run mistral models with ' 'It is strongly recommended to run mistral models with '
'`--tokenizer_mode "mistral"` to ensure correct ' '`--tokenizer-mode "mistral"` to ensure correct '
'encoding and decoding.', 'encoding and decoding.',
FutureWarning, FutureWarning,
stacklevel=2) stacklevel=2)
......
...@@ -22,7 +22,7 @@ class TokenizerGroup(BaseTokenizerGroup): ...@@ -22,7 +22,7 @@ class TokenizerGroup(BaseTokenizerGroup):
self.max_input_length = max_input_length self.max_input_length = max_input_length
self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config) self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
max_loras = tokenizer_config.get("max_loras", 0) max_loras = tokenizer_config.get("max_loras", 0)
self.lora_tokenizers = LRUCache[AnyTokenizer]( self.lora_tokenizers = LRUCache[int, AnyTokenizer](
capacity=max(max_loras, max_num_seqs) if enable_lora else 0) capacity=max(max_loras, max_num_seqs) if enable_lora else 0)
@classmethod @classmethod
......
...@@ -314,12 +314,15 @@ class MistralTokenizer: ...@@ -314,12 +314,15 @@ class MistralTokenizer:
if regular_tokens: if regular_tokens:
decoded_list.append( decoded_list.append(
self.decode(regular_tokens)) # type: ignore self.tokenizer.decode(regular_tokens)) # type: ignore
decoded = ''.join(decoded_list) decoded = ''.join(decoded_list)
return decoded return decoded
# WARN: Outlines logits processors can overwrite this method.
# See: guided_decoding/outlines_logits_processors.py::_adapt_tokenizer
# for more.
def decode(self, def decode(self,
ids: Union[List[int], int], ids: Union[List[int], int],
skip_special_tokens: bool = True) -> str: skip_special_tokens: bool = True) -> str:
......
...@@ -3,6 +3,10 @@ from pathlib import Path ...@@ -3,6 +3,10 @@ from pathlib import Path
from typing import Union from typing import Union
def is_s3(model_or_path: str) -> bool:
    """Return True if the string starts with ``s3://``, ignoring case."""
    scheme = 's3://'
    # Only the leading characters decide the result, so compare just those.
    leading = model_or_path[:len(scheme)]
    return leading.lower() == scheme
def check_gguf_file(model: Union[str, PathLike]) -> bool: def check_gguf_file(model: Union[str, PathLike]) -> bool:
"""Check if the file is a GGUF model.""" """Check if the file is a GGUF model."""
model = Path(model) model = Path(model)
......
...@@ -6,10 +6,13 @@ import datetime ...@@ -6,10 +6,13 @@ import datetime
import enum import enum
import gc import gc
import getpass import getpass
import importlib.metadata
import importlib.util import importlib.util
import inspect import inspect
import ipaddress import ipaddress
import os import os
import re
import resource
import signal import signal
import socket import socket
import subprocess import subprocess
...@@ -21,14 +24,13 @@ import uuid ...@@ -21,14 +24,13 @@ import uuid
import warnings import warnings
import weakref import weakref
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
from collections import UserDict, defaultdict from collections import OrderedDict, UserDict, defaultdict
from collections.abc import Iterable, Mapping from collections.abc import Iterable, Mapping
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import lru_cache, partial, wraps from functools import lru_cache, partial, wraps
from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
Dict, Generator, Generic, Hashable, List, Literal, Dict, Generator, Generic, Hashable, List, Literal,
Optional, OrderedDict, Set, Tuple, Type, TypeVar, Union, Optional, Tuple, Type, TypeVar, Union, overload)
overload)
from uuid import uuid4 from uuid import uuid4
import numpy as np import numpy as np
...@@ -52,7 +54,7 @@ logger = init_logger(__name__) ...@@ -52,7 +54,7 @@ logger = init_logger(__name__)
# Exception strings for non-implemented encoder/decoder scenarios # Exception strings for non-implemented encoder/decoder scenarios
# Reminder: Please update docs/source/usage/compatibility_matrix.rst # Reminder: Please update docs/source/usage/compatibility_matrix.md
# If the feature combo become valid # If the feature combo become valid
STR_NOT_IMPL_ENC_DEC_SWA = \ STR_NOT_IMPL_ENC_DEC_SWA = \
...@@ -154,10 +156,12 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = { ...@@ -154,10 +156,12 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = {
} }
P = ParamSpec('P') P = ParamSpec('P')
K = TypeVar("K")
T = TypeVar("T") T = TypeVar("T")
U = TypeVar("U") U = TypeVar("U")
_K = TypeVar("_K", bound=Hashable)
_V = TypeVar("_V")
class _Sentinel: class _Sentinel:
... ...
...@@ -190,50 +194,48 @@ class Counter: ...@@ -190,50 +194,48 @@ class Counter:
self.counter = 0 self.counter = 0
class LRUCache(Generic[T]): class LRUCache(Generic[_K, _V]):
def __init__(self, capacity: int): def __init__(self, capacity: int) -> None:
self.cache: OrderedDict[Hashable, T] = OrderedDict() self.cache = OrderedDict[_K, _V]()
self.pinned_items: Set[Hashable] = set() self.pinned_items = set[_K]()
self.capacity = capacity self.capacity = capacity
def __contains__(self, key: Hashable) -> bool: def __contains__(self, key: _K) -> bool:
return key in self.cache return key in self.cache
def __len__(self) -> int: def __len__(self) -> int:
return len(self.cache) return len(self.cache)
def __getitem__(self, key: Hashable) -> T: def __getitem__(self, key: _K) -> _V:
value = self.cache[key] # Raise KeyError if not exists value = self.cache[key] # Raise KeyError if not exists
self.cache.move_to_end(key) self.cache.move_to_end(key)
return value return value
def __setitem__(self, key: Hashable, value: T) -> None: def __setitem__(self, key: _K, value: _V) -> None:
self.put(key, value) self.put(key, value)
def __delitem__(self, key: Hashable) -> None: def __delitem__(self, key: _K) -> None:
self.pop(key) self.pop(key)
def touch(self, key: Hashable) -> None: def touch(self, key: _K) -> None:
self.cache.move_to_end(key) self.cache.move_to_end(key)
def get(self, def get(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
key: Hashable, value: Optional[_V]
default_value: Optional[T] = None) -> Optional[T]:
value: Optional[T]
if key in self.cache: if key in self.cache:
value = self.cache[key] value = self.cache[key]
self.cache.move_to_end(key) self.cache.move_to_end(key)
else: else:
value = default_value value = default
return value return value
def put(self, key: Hashable, value: T) -> None: def put(self, key: _K, value: _V) -> None:
self.cache[key] = value self.cache[key] = value
self.cache.move_to_end(key) self.cache.move_to_end(key)
self._remove_old_if_needed() self._remove_old_if_needed()
def pin(self, key: Hashable) -> None: def pin(self, key: _K) -> None:
""" """
Pins a key in the cache preventing it from being Pins a key in the cache preventing it from being
evicted in the LRU order. evicted in the LRU order.
...@@ -242,13 +244,13 @@ class LRUCache(Generic[T]): ...@@ -242,13 +244,13 @@ class LRUCache(Generic[T]):
raise ValueError(f"Cannot pin key: {key} not in cache.") raise ValueError(f"Cannot pin key: {key} not in cache.")
self.pinned_items.add(key) self.pinned_items.add(key)
def _unpin(self, key: Hashable) -> None: def _unpin(self, key: _K) -> None:
self.pinned_items.remove(key) self.pinned_items.remove(key)
def _on_remove(self, key: Hashable, value: Optional[T]): def _on_remove(self, key: _K, value: Optional[_V]) -> None:
pass pass
def remove_oldest(self, remove_pinned=False): def remove_oldest(self, *, remove_pinned: bool = False) -> None:
if not self.cache: if not self.cache:
return return
...@@ -262,17 +264,15 @@ class LRUCache(Generic[T]): ...@@ -262,17 +264,15 @@ class LRUCache(Generic[T]):
"cannot remove oldest from the cache.") "cannot remove oldest from the cache.")
else: else:
lru_key = next(iter(self.cache)) lru_key = next(iter(self.cache))
self.pop(lru_key) self.pop(lru_key) # type: ignore
def _remove_old_if_needed(self) -> None: def _remove_old_if_needed(self) -> None:
while len(self.cache) > self.capacity: while len(self.cache) > self.capacity:
self.remove_oldest() self.remove_oldest()
def pop(self, def pop(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
key: Hashable,
default_value: Optional[T] = None) -> Optional[T]:
run_on_remove = key in self.cache run_on_remove = key in self.cache
value: Optional[T] = self.cache.pop(key, default_value) value = self.cache.pop(key, default)
# remove from pinned items # remove from pinned items
if key in self.pinned_items: if key in self.pinned_items:
self._unpin(key) self._unpin(key)
...@@ -280,7 +280,7 @@ class LRUCache(Generic[T]): ...@@ -280,7 +280,7 @@ class LRUCache(Generic[T]):
self._on_remove(key, value) self._on_remove(key, value)
return value return value
def clear(self): def clear(self) -> None:
while len(self.cache) > 0: while len(self.cache) > 0:
self.remove_oldest(remove_pinned=True) self.remove_oldest(remove_pinned=True)
self.cache.clear() self.cache.clear()
...@@ -775,7 +775,7 @@ def get_dtype_size(dtype: torch.dtype) -> int: ...@@ -775,7 +775,7 @@ def get_dtype_size(dtype: torch.dtype) -> int:
# `collections` helpers # `collections` helpers
def is_list_of( def is_list_of(
value: object, value: object,
typ: Type[T], typ: Union[type[T], tuple[type[T], ...]],
*, *,
check: Literal["first", "all"] = "first", check: Literal["first", "all"] = "first",
) -> TypeIs[List[T]]: ) -> TypeIs[List[T]]:
...@@ -843,10 +843,6 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]: ...@@ -843,10 +843,6 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
return [item for sublist in lists for item in sublist] return [item for sublist in lists for item in sublist]
_K = TypeVar("_K", bound=Hashable)
_V = TypeVar("_V")
def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]): def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
""" """
Unlike :class:`itertools.groupby`, groups are not broken by Unlike :class:`itertools.groupby`, groups are not broken by
...@@ -1282,6 +1278,7 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, ...@@ -1282,6 +1278,7 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
def supports_kw( def supports_kw(
callable: Callable[..., object], callable: Callable[..., object],
kw_name: str, kw_name: str,
*,
requires_kw_only: bool = False, requires_kw_only: bool = False,
allow_var_kwargs: bool = True, allow_var_kwargs: bool = True,
) -> bool: ) -> bool:
...@@ -1326,6 +1323,8 @@ def resolve_mm_processor_kwargs( ...@@ -1326,6 +1323,8 @@ def resolve_mm_processor_kwargs(
init_kwargs: Optional[Mapping[str, object]], init_kwargs: Optional[Mapping[str, object]],
inference_kwargs: Optional[Mapping[str, object]], inference_kwargs: Optional[Mapping[str, object]],
callable: Callable[..., object], callable: Callable[..., object],
*,
requires_kw_only: bool = True,
allow_var_kwargs: bool = False, allow_var_kwargs: bool = False,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Applies filtering to eliminate invalid mm_processor_kwargs, i.e., """Applies filtering to eliminate invalid mm_processor_kwargs, i.e.,
...@@ -1344,11 +1343,17 @@ def resolve_mm_processor_kwargs( ...@@ -1344,11 +1343,17 @@ def resolve_mm_processor_kwargs(
runtime_mm_kwargs = get_allowed_kwarg_only_overrides( runtime_mm_kwargs = get_allowed_kwarg_only_overrides(
callable, callable,
overrides=inference_kwargs, overrides=inference_kwargs,
allow_var_kwargs=allow_var_kwargs) requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs,
)
# Filter init time multimodal processor kwargs provided # Filter init time multimodal processor kwargs provided
init_mm_kwargs = get_allowed_kwarg_only_overrides( init_mm_kwargs = get_allowed_kwarg_only_overrides(
callable, overrides=init_kwargs, allow_var_kwargs=allow_var_kwargs) callable,
overrides=init_kwargs,
requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs,
)
# Merge the final processor kwargs, prioritizing inference # Merge the final processor kwargs, prioritizing inference
# time values over the initialization time values. # time values over the initialization time values.
...@@ -1359,6 +1364,8 @@ def resolve_mm_processor_kwargs( ...@@ -1359,6 +1364,8 @@ def resolve_mm_processor_kwargs(
def get_allowed_kwarg_only_overrides( def get_allowed_kwarg_only_overrides(
callable: Callable[..., object], callable: Callable[..., object],
overrides: Optional[Mapping[str, object]], overrides: Optional[Mapping[str, object]],
*,
requires_kw_only: bool = True,
allow_var_kwargs: bool = False, allow_var_kwargs: bool = False,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
...@@ -1390,16 +1397,21 @@ def get_allowed_kwarg_only_overrides( ...@@ -1390,16 +1397,21 @@ def get_allowed_kwarg_only_overrides(
for kwarg_name, val in overrides.items() for kwarg_name, val in overrides.items()
if supports_kw(callable, if supports_kw(callable,
kwarg_name, kwarg_name,
requires_kw_only=True, requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs) allow_var_kwargs=allow_var_kwargs)
} }
# If anything is dropped, log a warning # If anything is dropped, log a warning
dropped_keys = overrides.keys() - filtered_overrides.keys() dropped_keys = overrides.keys() - filtered_overrides.keys()
if dropped_keys: if dropped_keys:
logger.warning( if requires_kw_only:
"The following intended overrides are not keyword-only args " logger.warning(
"and and will be dropped: %s", dropped_keys) "The following intended overrides are not keyword-only args "
"and and will be dropped: %s", dropped_keys)
else:
logger.warning(
"The following intended overrides are not keyword args "
"and and will be dropped: %s", dropped_keys)
return filtered_overrides return filtered_overrides
...@@ -1628,6 +1640,67 @@ def import_from_path(module_name: str, file_path: Union[str, os.PathLike]): ...@@ -1628,6 +1640,67 @@ def import_from_path(module_name: str, file_path: Union[str, os.PathLike]):
return module return module
@lru_cache(maxsize=None)
def get_vllm_optional_dependencies():
    """
    Map each extra declared by the installed ``vllm`` distribution to the
    names of the distributions that extra requires.

    The data is read from the ``Provides-Extra`` and ``Requires-Dist``
    metadata fields and cached after the first call.

    Returns:
        A dict from extra name to the list of required distribution names.
    """
    metadata = importlib.metadata.metadata("vllm")
    requirements = metadata.get_all("Requires-Dist", [])
    extras = metadata.get_all("Provides-Extra", [])

    def _distribution_name(req: str) -> str:
        # The name is the leading run before any whitespace, extras bracket,
        # version specifier, URL reference or environment marker.
        return re.split(r"[\s;\[(<>=!~@]", req.strip(), maxsplit=1)[0]

    def _is_for_extra(req: str, extra: str) -> bool:
        # Look only at the environment marker (the part after ";") and accept
        # either quoting style and flexible spacing around "==".
        marker = req.partition(";")[2]
        pattern = rf"""\bextra\s*==\s*["']{re.escape(extra)}["']"""
        return re.search(pattern, marker) is not None

    return {
        extra: [
            _distribution_name(req) for req in requirements
            if _is_for_extra(req, extra)
        ]
        for extra in extras
    }
@dataclass(frozen=True)
class PlaceholderModule:
    """
    Stand-in for a module that could not be imported.

    Any attribute access other than the ``name`` field and the methods
    defined here retries the import, so the caller gets an informative
    ``ImportError`` at the point of use.
    """
    name: str

    def placeholder_attr(self, attr_path: str):
        """Return a placeholder for the attribute path ``attr_path``."""
        return _PlaceholderModuleAttr(self, attr_path)

    def __getattr__(self, key: str):
        module_name = self.name

        try:
            importlib.import_module(module_name)
        except ImportError as exc:
            optional_deps = get_vllm_optional_dependencies()
            # First extra (in mapping order) that lists this module, if any.
            matching_extra = next(
                (extra for extra, names in optional_deps.items()
                 if module_name in names), None)
            if matching_extra is None:
                raise exc

            msg = (f"Please install vllm[{matching_extra}] "
                   f"for {matching_extra} support")
            raise ImportError(msg) from exc

        # The import succeeded, so a placeholder should never have been used.
        raise AssertionError("PlaceholderModule should not be used "
                             "when the original module can be imported")
@dataclass(frozen=True)
class _PlaceholderModuleAttr:
    """
    Placeholder for a dotted attribute path of a :class:`PlaceholderModule`.

    Attribute access is forwarded to the owning placeholder module using the
    extended dotted path as the attribute name.
    """
    module: PlaceholderModule
    attr_path: str

    def placeholder_attr(self, attr_path: str):
        """Return a placeholder one level deeper in the attribute path."""
        nested_path = f"{self.attr_path}.{attr_path}"
        return _PlaceholderModuleAttr(self.module, nested_path)

    def __getattr__(self, key: str):
        dotted_name = f"{self.attr_path}.{key}"
        # Delegate to the module placeholder's attribute lookup.
        getattr(self.module, dotted_name)

        # Reached only if the lookup above returned instead of raising.
        raise AssertionError("PlaceholderModule should not be used "
                             "when the original module can be imported")
# create a library to hold the custom op # create a library to hold the custom op
vllm_lib = Library("vllm", "FRAGMENT") # noqa vllm_lib = Library("vllm", "FRAGMENT") # noqa
...@@ -1655,8 +1728,18 @@ def direct_register_custom_op( ...@@ -1655,8 +1728,18 @@ def direct_register_custom_op(
library object. If you want to bind the operator to a different library, library object. If you want to bind the operator to a different library,
make sure the library object is alive when the operator is used. make sure the library object is alive when the operator is used.
""" """
if is_in_doc_build() or not supports_custom_op(): if is_in_doc_build():
return return
if not supports_custom_op():
assert not current_platform.is_cuda_alike(), (
"cuda platform needs torch>=2.4 to support custom op, "
"chances are you are using an old version of pytorch "
"or a custom build of pytorch. It is recommended to "
"use vLLM in a fresh new environment and let it install "
"the required dependencies.")
return
import torch.library import torch.library
if hasattr(torch.library, "infer_schema"): if hasattr(torch.library, "infer_schema"):
schema_str = torch.library.infer_schema(op_func, schema_str = torch.library.infer_schema(op_func,
...@@ -1823,3 +1906,20 @@ def memory_profiling( ...@@ -1823,3 +1906,20 @@ def memory_profiling(
result.non_torch_increase_in_bytes = current_cuda_memory_bytes - baseline_memory_in_bytes - weights_memory_in_bytes - diff.torch_memory_in_bytes # noqa result.non_torch_increase_in_bytes = current_cuda_memory_bytes - baseline_memory_in_bytes - weights_memory_in_bytes - diff.torch_memory_in_bytes # noqa
result.profile_time = diff.timestamp result.profile_time = diff.timestamp
result.non_kv_cache_memory_in_bytes = result.non_torch_increase_in_bytes + result.torch_peak_increase_in_bytes + result.weights_memory_in_bytes # noqa result.non_kv_cache_memory_in_bytes = result.non_torch_increase_in_bytes + result.torch_peak_increase_in_bytes + result.weights_memory_in_bytes # noqa
# Adapted from: https://github.com/sgl-project/sglang/blob/f46f394f4d4dbe4aae85403dec006199b34d2840/python/sglang/srt/utils.py#L630 # noqa: E501
def set_ulimit(target_soft_limit=65535):
    """
    Raise the soft limit on open file descriptors if it is below the target.

    The hard limit is left unchanged. If the limit cannot be raised, a
    warning is logged and the current limit stays in effect.

    Args:
        target_soft_limit: The minimum soft ``RLIMIT_NOFILE`` value wanted.
    """
    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    # An unlimited soft limit must never be replaced by a finite one; the
    # numeric value of RLIM_INFINITY may compare lower than the target.
    if current_soft == resource.RLIM_INFINITY:
        return

    if current_soft < target_soft_limit:
        try:
            resource.setrlimit(resource_type,
                               (target_soft_limit, current_hard))
        except ValueError as e:
            # Adjacent literals are concatenated verbatim, so each fragment
            # carries its own trailing space.
            logger.warning(
                "Found ulimit of %s and failed to automatically increase "
                "with error %s. This can cause fd limit errors like "
                "`OSError: [Errno 24] Too many open files`. Consider "
                "increasing with ulimit -n", current_soft, e)
...@@ -4,7 +4,9 @@ from typing import Dict, Iterable, List, Optional ...@@ -4,7 +4,9 @@ from typing import Dict, Iterable, List, Optional
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import cdiv from vllm.utils import cdiv
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
KVCacheBlock, hash_block_tokens, KVCacheBlock,
generate_block_hash_extra_keys,
hash_block_tokens,
hash_request_tokens) hash_request_tokens)
from vllm.v1.request import Request from vllm.v1.request import Request
...@@ -83,10 +85,12 @@ class KVCacheManager: ...@@ -83,10 +85,12 @@ class KVCacheManager:
computed_blocks = [] computed_blocks = []
# TODO(rickyx): potentially we could cache this so we don't have to # The block hashes for the request may already be computed
# recompute it every time. # if the request was preempted and resumed.
block_hashes = hash_request_tokens(self.block_size, if not request.kv_block_hashes:
request.all_token_ids) request.set_kv_block_hashes(
hash_request_tokens(self.block_size, request))
block_hashes = request.kv_block_hashes
for block_hash in block_hashes: for block_hash in block_hashes:
# block_hashes is a chain of block hashes. If a block hash is not # block_hashes is a chain of block hashes. If a block hash is not
...@@ -197,23 +201,15 @@ class KVCacheManager: ...@@ -197,23 +201,15 @@ class KVCacheManager:
f"num_tokens must be greater than 0, got {num_tokens}") f"num_tokens must be greater than 0, got {num_tokens}")
# Touch the computed blocks to make sure they won't be evicted. # Touch the computed blocks to make sure they won't be evicted.
num_evictable_computed_blocks = 0
if self.enable_caching: if self.enable_caching:
self._touch(computed_blocks) self._touch(computed_blocks)
# If a computed block of a request is an eviction candidate (in the
# free queue and ref_cnt == 0), it cannot be counted as a free block
# when allocating this request.
num_evictable_computed_blocks = len(
[blk for blk in computed_blocks if blk.ref_cnt == 0])
else: else:
assert not computed_blocks, ( assert not computed_blocks, (
"Computed blocks should be empty when " "Computed blocks should be empty when "
"prefix caching is disabled") "prefix caching is disabled")
num_required_blocks = cdiv(num_tokens, self.block_size) num_required_blocks = cdiv(num_tokens, self.block_size)
if (num_required_blocks > self.free_block_queue.num_free_blocks - if (num_required_blocks > self.free_block_queue.num_free_blocks):
num_evictable_computed_blocks):
# Cannot allocate new blocks. # Cannot allocate new blocks.
return None return None
...@@ -221,8 +217,7 @@ class KVCacheManager: ...@@ -221,8 +217,7 @@ class KVCacheManager:
# preallocated blocks. # preallocated blocks.
num_new_blocks = min( num_new_blocks = min(
num_required_blocks + self.num_preallocate_blocks, num_required_blocks + self.num_preallocate_blocks,
self.free_block_queue.num_free_blocks - self.free_block_queue.num_free_blocks,
num_evictable_computed_blocks,
# Should not exceed the maximum number of blocks per request. # Should not exceed the maximum number of blocks per request.
# This is especially because the block table has the shape # This is especially because the block table has the shape
# [..., max_num_blocks_per_req]. # [..., max_num_blocks_per_req].
...@@ -242,14 +237,16 @@ class KVCacheManager: ...@@ -242,14 +237,16 @@ class KVCacheManager:
num_computed_tokens = len(computed_blocks) * self.block_size num_computed_tokens = len(computed_blocks) * self.block_size
num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size
self._cache_full_blocks( new_full_blocks = self.req_to_blocks[
request=request, request.request_id][len(computed_blocks):num_full_blocks]
blk_start_idx=len(computed_blocks), if new_full_blocks:
# The new full blocks are the full blocks that are not computed. self._cache_full_blocks(
full_blocks=self.req_to_blocks[request.request_id] request=request,
[len(computed_blocks):num_full_blocks], blk_start_idx=len(computed_blocks),
prev_block=computed_blocks[-1] if computed_blocks else None, # The new full blocks are the full blocks that are not computed.
) full_blocks=new_full_blocks,
prev_block=computed_blocks[-1] if computed_blocks else None,
)
return new_blocks return new_blocks
...@@ -376,6 +373,8 @@ class KVCacheManager: ...@@ -376,6 +373,8 @@ class KVCacheManager:
full_blocks: The list of blocks to update hash metadata. full_blocks: The list of blocks to update hash metadata.
prev_block: The previous block in the chain. prev_block: The previous block in the chain.
""" """
num_cached_block_hashes = len(request.kv_block_hashes)
# Update the new blocks with the block hashes through the chain. # Update the new blocks with the block hashes through the chain.
prev_block_hash_value = None prev_block_hash_value = None
if prev_block is not None: if prev_block is not None:
...@@ -387,17 +386,35 @@ class KVCacheManager: ...@@ -387,17 +386,35 @@ class KVCacheManager:
for i, blk in enumerate(full_blocks): for i, blk in enumerate(full_blocks):
blk_idx = blk_start_idx + i blk_idx = blk_start_idx + i
block_tokens = request.all_token_ids[blk_idx * if blk_idx < num_cached_block_hashes:
self.block_size:(blk_idx + # The block hash may already be computed in
1) * # "get_computed_blocks" if the tokens are not generated by
self.block_size] # this request (either the prompt tokens or the previously
assert len(block_tokens) == self.block_size, ( # generated tokens with preemption). In this case we simply
f"Expected {self.block_size} tokens, got {len(block_tokens)} " # reuse the block hash.
f"at {blk_idx}th block for request " block_hash = request.kv_block_hashes[blk_idx]
f"{request.request_id}({request})") else:
# Otherwise compute the block hash and cache it in the request
# Compute the hash of the current block. # in case it will be preempted in the future.
block_hash = hash_block_tokens(prev_block_hash_value, block_tokens) start_token_idx = blk_idx * self.block_size
end_token_idx = (blk_idx + 1) * self.block_size
block_tokens = request.all_token_ids[
start_token_idx:end_token_idx]
assert len(block_tokens) == self.block_size, (
f"Expected {self.block_size} tokens, got "
f"{len(block_tokens)} at {blk_idx}th block for request "
f"{request.request_id}({request})")
# Generate extra keys for multi-modal inputs. Note that since
# we reach to this branch only when the block is completed with
# generated tokens, we only need to consider the last mm input.
extra_keys, _ = generate_block_hash_extra_keys(
request, start_token_idx, end_token_idx, -1)
# Compute the hash of the current block.
block_hash = hash_block_tokens(prev_block_hash_value,
block_tokens, extra_keys)
request.append_kv_block_hashes(block_hash)
# Update and added the full block to the cache. # Update and added the full block to the cache.
blk.block_hash = block_hash blk.block_hash = block_hash
......
"""KV-Cache Utilities.""" """KV-Cache Utilities."""
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, NamedTuple, Optional, Tuple from typing import Any, List, NamedTuple, Optional, Tuple
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.request import Request
logger = init_logger(__name__) logger = init_logger(__name__)
class BlockHashType(NamedTuple): class BlockHashType(NamedTuple):
"""Hash value of a block and the token IDs in the block. """Hash value of a block (int), the token IDs in the block, and extra keys.
The reason we keep a tuple of token IDs is to make sure no hash The reason we keep a tuple of token IDs and extra keys is to make sure
collision happens when the hash value is the same. no hash collision happens when the hash value is the same.
""" """
# Hash value of the block in an integer.
hash_value: int hash_value: int
# Token IDs in the block.
token_ids: Tuple[int, ...] token_ids: Tuple[int, ...]
# Extra keys for the block.
extra_keys: Optional[Any] = None
@dataclass @dataclass
...@@ -159,8 +164,80 @@ class FreeKVCacheBlockQueue: ...@@ -159,8 +164,80 @@ class FreeKVCacheBlockQueue:
return ret return ret
def hash_block_tokens(parent_block_hash: Optional[int], def generate_block_hash_extra_keys(
curr_block_token_ids: Sequence[int]) -> BlockHashType: request: Request, start_token_idx: int, end_token_idx: int,
start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]:
"""Generate extra keys for the block hash. The extra keys can come from
the multi-modal inputs and request specific metadata (e.g., LoRA ID).
For multi-modal inputs, the extra keys are (mm_hash, start_offset) that
indicate a mm input contained in the block and its starting offset in
the block tokens.
Args:
request: The request object.
start_token_idx: The start token index of the block.
end_token_idx: The end token index of the block.
start_mm_idx: The start multi-modal index of the block.
Returns:
A tuple of extra keys and the next multi-modal index.
"""
mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
if not mm_positions:
return None, start_mm_idx
if mm_positions and len(mm_positions) != len(mm_hashes):
raise ValueError(
"The number of multi-modal positions and hashes must match. This "
"is likely because you do not enable MM preprocessor hashing. "
"Please set disable_mm_preprocessor_cache=False.")
# Note that we assume mm_positions is sorted by offset.
# We do not need to check all mm inputs if the start token index is out of
# range. This usually happens in the late prefill phase and decoding phase.
if mm_positions[-1]["offset"] + mm_positions[-1][
"length"] < start_token_idx:
return None, start_mm_idx
# Support start_mm_idx == -1 to indicate the last mm input.
if start_mm_idx < 0:
assert -start_mm_idx <= len(mm_positions)
start_mm_idx = len(mm_positions) + start_mm_idx
extra_keys = []
curr_mm_idx = start_mm_idx
while mm_positions and curr_mm_idx < len(mm_positions):
assert mm_hashes[curr_mm_idx] is not None
offset = mm_positions[curr_mm_idx]["offset"]
length = mm_positions[curr_mm_idx]["length"]
if end_token_idx > offset:
if start_token_idx > offset + length:
# This block has passed the current mm input.
curr_mm_idx += 1
continue
# The block contains the current mm input.
mm_start = max(0, start_token_idx - offset)
extra_keys.append((mm_hashes[curr_mm_idx], mm_start))
if end_token_idx >= offset + length:
# If this block contains the end of the current mm input,
# move to the next mm input as this block may also contain
# the next mm input.
curr_mm_idx += 1
else:
# Otherwise this block is done with mm inputs.
break
else:
# This block has not reached the current mm input.
break
return tuple(extra_keys), curr_mm_idx
def hash_block_tokens(
parent_block_hash: Optional[int],
curr_block_token_ids: Sequence[int],
extra_keys: Optional[Tuple[Any, ...]] = None) -> BlockHashType:
"""Computes a hash value corresponding to the contents of a block and """Computes a hash value corresponding to the contents of a block and
the contents of the preceding block(s). The hash value is used for the contents of the preceding block(s). The hash value is used for
prefix caching. We use LRU cache for this function to avoid recomputing prefix caching. We use LRU cache for this function to avoid recomputing
...@@ -174,27 +251,39 @@ def hash_block_tokens(parent_block_hash: Optional[int], ...@@ -174,27 +251,39 @@ def hash_block_tokens(parent_block_hash: Optional[int],
if this is the first block. if this is the first block.
curr_block_token_ids: A list of token ids in the current curr_block_token_ids: A list of token ids in the current
block. The current block is assumed to be full. block. The current block is assumed to be full.
extra_keys: Extra keys for the block.
Returns: Returns:
The hash value of the block and the token ids in the block. The hash value of the block and the token ids in the block.
The entire tuple is used as the hash key of the block. The entire tuple is used as the hash key of the block.
""" """
return BlockHashType(hash((parent_block_hash, *curr_block_token_ids)), return BlockHashType(hash((parent_block_hash, *curr_block_token_ids)),
tuple(curr_block_token_ids)) tuple(curr_block_token_ids), extra_keys)
def hash_request_tokens(block_size: int, def hash_request_tokens(block_size: int,
token_ids: Sequence[int]) -> List[BlockHashType]: request: Request) -> List[BlockHashType]:
"""Computes hash values of a chain of blocks given a sequence of """Computes hash values of a chain of blocks given a sequence of
token IDs. The hash value is used for prefix caching. token IDs. The hash value is used for prefix caching.
Args: Args:
block_size: The size of each block. block_size: The size of each block.
token_ids: A sequence of token ids in the request. request: The request object.
Returns: Returns:
The list of computed hash values. The list of computed hash values.
""" """
token_ids = request.all_token_ids
mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
if mm_positions and len(mm_positions) != len(mm_hashes):
raise ValueError(
"The number of multi-modal positions and hashes must match.")
# TODO: Extend this to support other features such as LoRA.
need_extra_keys = bool(mm_positions)
extra_keys = None
curr_mm_idx = 0
ret = [] ret = []
parent_block_hash_value = None parent_block_hash_value = None
for start in range(0, len(token_ids), block_size): for start in range(0, len(token_ids), block_size):
...@@ -203,8 +292,14 @@ def hash_request_tokens(block_size: int, ...@@ -203,8 +292,14 @@ def hash_request_tokens(block_size: int,
# Do not hash the block if it is not full. # Do not hash the block if it is not full.
if len(block_token_ids) < block_size: if len(block_token_ids) < block_size:
break break
# Add extra keys if the block is a multi-modal block.
if need_extra_keys:
extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
request, start, end, curr_mm_idx)
block_hash = hash_block_tokens(parent_block_hash_value, block_hash = hash_block_tokens(parent_block_hash_value,
block_token_ids) block_token_ids, extra_keys)
ret.append(block_hash) ret.append(block_hash)
parent_block_hash_value = block_hash.hash_value parent_block_hash_value = block_hash.hash_value
return ret return ret
...@@ -516,6 +516,7 @@ class NewRequestData: ...@@ -516,6 +516,7 @@ class NewRequestData:
prompt_token_ids: List[int] prompt_token_ids: List[int]
prompt: Optional[str] prompt: Optional[str]
mm_inputs: List["MultiModalKwargs"] mm_inputs: List["MultiModalKwargs"]
mm_hashes: List[str]
mm_positions: List["PlaceholderRange"] mm_positions: List["PlaceholderRange"]
sampling_params: SamplingParams sampling_params: SamplingParams
block_ids: List[int] block_ids: List[int]
...@@ -533,6 +534,7 @@ class NewRequestData: ...@@ -533,6 +534,7 @@ class NewRequestData:
prompt_token_ids=request.prompt_token_ids, prompt_token_ids=request.prompt_token_ids,
prompt=request.prompt, prompt=request.prompt,
mm_inputs=request.mm_inputs, mm_inputs=request.mm_inputs,
mm_hashes=request.mm_hashes,
mm_positions=request.mm_positions, mm_positions=request.mm_positions,
sampling_params=request.sampling_params, sampling_params=request.sampling_params,
block_ids=block_ids, block_ids=block_ids,
......
...@@ -9,14 +9,13 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType ...@@ -9,14 +9,13 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.outputs import RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.async_stream import AsyncStream
from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.detokenizer import Detokenizer from vllm.v1.engine.detokenizer import Detokenizer
from vllm.v1.engine.processor import Processor from vllm.v1.engine.processor import Processor
...@@ -54,15 +53,17 @@ class AsyncLLM(EngineClient): ...@@ -54,15 +53,17 @@ class AsyncLLM(EngineClient):
lora_config=vllm_config.lora_config) lora_config=vllm_config.lora_config)
self.tokenizer.ping() self.tokenizer.ping()
# Request streams (map of request_id -> AsyncStream). # Request streams (map of request_id -> queue).
self.request_streams: Dict[str, AsyncStream] = {} self.rid_to_queue: Dict[str, asyncio.Queue] = {}
# List of cancelled request ids to be aborted.
self.client_aborted_requests: List[str] = []
# Processor (converts Inputs --> EngineCoreRequests). # Processor (converts Inputs --> EngineCoreRequests).
self.processor = Processor(vllm_config.model_config, self.processor = Processor(
vllm_config.lora_config, self.tokenizer, model_config=vllm_config.model_config,
input_registry) cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
tokenizer=self.tokenizer,
input_registry=input_registry,
)
# Detokenizer (converts EngineCoreOutputs --> RequestOutput). # Detokenizer (converts EngineCoreOutputs --> RequestOutput).
self.detokenizer = Detokenizer( self.detokenizer = Detokenizer(
...@@ -94,7 +95,7 @@ class AsyncLLM(EngineClient): ...@@ -94,7 +95,7 @@ class AsyncLLM(EngineClient):
start_engine_loop: bool = True, start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "AsyncLLMEngine": ) -> "AsyncLLM":
"""Create an AsyncLLM from the EngineArgs.""" """Create an AsyncLLM from the EngineArgs."""
# Create the engine configs. # Create the engine configs.
...@@ -149,14 +150,13 @@ class AsyncLLM(EngineClient): ...@@ -149,14 +150,13 @@ class AsyncLLM(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: ) -> asyncio.Queue[RequestOutput]:
"""Add new request to the AsyncLLM.""" """Add new request to the AsyncLLM."""
if self.detokenizer.is_request_active(request_id): # 1) Create a new output queue for the request.
raise ValueError(f"Request {request_id} already exists.") if request_id in self.rid_to_queue:
raise ValueError(f"Request id {request_id} already running.")
# 1) Create a new AsyncStream for the request. self.rid_to_queue[request_id] = asyncio.Queue()
stream = self._add_request_to_streams(request_id)
# 2) Convert input --> DetokenizerRequest / EngineCoreRequest. # 2) Convert input --> DetokenizerRequest / EngineCoreRequest.
detokenizer_req, engine_core_req = self.processor.process_inputs( detokenizer_req, engine_core_req = self.processor.process_inputs(
...@@ -169,8 +169,10 @@ class AsyncLLM(EngineClient): ...@@ -169,8 +169,10 @@ class AsyncLLM(EngineClient):
# 4) Add the EngineCoreRequest to EngineCore (separate process). # 4) Add the EngineCoreRequest to EngineCore (separate process).
await self.engine_core.add_request_async(engine_core_req) await self.engine_core.add_request_async(engine_core_req)
# 5) Return the generator. if self.log_requests:
return stream.generator() logger.info("Added request %s.", request_id)
return self.rid_to_queue[request_id]
# TODO: we should support multiple prompts in one call, as you # TODO: we should support multiple prompts in one call, as you
# can do with LLM.generate. So that for multi-prompt completion # can do with LLM.generate. So that for multi-prompt completion
...@@ -190,7 +192,7 @@ class AsyncLLM(EngineClient): ...@@ -190,7 +192,7 @@ class AsyncLLM(EngineClient):
""" """
Main function called by the API server to kick off a request Main function called by the API server to kick off a request
* 1) Making an AsyncStream corresponding to the Request. * 1) Making an AsyncStream corresponding to the Request.
# 2) Processing the Input. * 2) Processing the Input.
* 3) Adding the Request to the Detokenizer. * 3) Adding the Request to the Detokenizer.
* 4) Adding the Request to the EngineCore (separate process). * 4) Adding the Request to the EngineCore (separate process).
...@@ -202,14 +204,15 @@ class AsyncLLM(EngineClient): ...@@ -202,14 +204,15 @@ class AsyncLLM(EngineClient):
returning the RequestOutput back to the caller. returning the RequestOutput back to the caller.
""" """
# We start the output_handler on the first call to generate() so that try:
# we can call __init__ before the event loop starts, which enables us # We start the output_handler on the first call to generate() so
# to handle startup failure gracefully in the OpenAI server. # we can call __init__ before the event loop, which enables us
if self.output_handler is None: # to handle startup failure gracefully in the OpenAI server.
self.output_handler = asyncio.create_task( if self.output_handler is None:
self._run_output_handler()) self.output_handler = asyncio.create_task(
self._run_output_handler())
async for output in await self.add_request(
q = await self.add_request(
request_id, request_id,
prompt, prompt,
sampling_params, sampling_params,
...@@ -217,79 +220,42 @@ class AsyncLLM(EngineClient): ...@@ -217,79 +220,42 @@ class AsyncLLM(EngineClient):
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
priority=priority, priority=priority,
): )
yield output
def _finish_stream(self, request_id: str):
stream = self.request_streams.pop(request_id, None)
if stream is not None:
stream.finish()
def _add_request_to_streams(
self,
request_id: str,
) -> AsyncStream:
if request_id in self.request_streams:
raise ValueError(f"Request id {request_id} already running.")
# Avoid streams having circular ref to parent AsyncLLM object.
aborted_reqs = self.client_aborted_requests
stream = AsyncStream(request_id, aborted_reqs.append)
self.request_streams[request_id] = stream
if self.log_requests:
logger.info("Added request %s.", request_id)
return stream # The output_handler task pushes items into the queue.
# This task pulls from the queue and yields to caller.
async def _process_cancellations(self) -> None: while True:
""" # Note: drain queue without await if possible (avoids
Process requests cancelled from user disconnecting. # task switching under load which helps performance).
out = q.get_nowait() if q.qsize() > 0 else await q.get()
When a client disconnects, AsyncStream._cancel() is called.
We passed a callback to AsyncStream(), which appends to # Note: both Detokenizer and EngineCore handle their
self.client_aborted_requests. # own request cleanup based on finished.
if out.finished:
As a result, if any requests are canceled from the user side del self.rid_to_queue[request_id]
the request_id will show up in self.client_aborted_requests. yield out
""" break
# Avoid streams having circular ref to parent AsyncLLM object. yield out
if not self.client_aborted_requests:
return # If the request is disconnected by the client, the
reqs_to_abort = self.client_aborted_requests.copy() # generate() task will be canceled. So, we abort the
self.client_aborted_requests.clear() # request if we end up here.
except asyncio.CancelledError:
# Remove from Detokenizer. await self.abort(request_id)
self.detokenizer.abort_requests(reqs_to_abort) raise
# Remove from RequestStreams.
for request_id in reqs_to_abort:
if self.log_requests:
logger.info("User-cancelled request %s.", request_id)
self._finish_stream(request_id)
# Remove from EngineCore.
await self.engine_core.abort_requests_async(reqs_to_abort)
def _process_request_outputs(self, request_outputs: List[RequestOutput]): def _process_request_outputs(self, request_outputs: List[RequestOutput]):
"""Process outputs by putting them into per-request AsyncStreams.""" """Process outputs by putting them into per-request queues."""
for request_output in request_outputs: for request_output in request_outputs:
request_id = request_output.request_id request_id = request_output.request_id
assert request_id in self.request_streams
# Each request in the API server pulls from the per-request stream.
stream = self.request_streams.get(request_id)
if stream is not None:
stream.put(request_output)
# If finished, remove from the tracker. # Note: it is possible a request was aborted and removed from
if request_output.finished: # the state due to client cancellations, so if we encounter a
if self.log_requests: # request id not in the state, we skip.
logger.info("Finished request %s.", request_id) if request_id in self.rid_to_queue:
self._finish_stream(request_id) self.rid_to_queue[request_id].put_nowait(request_output)
async def _run_output_handler(self): async def _run_output_handler(self):
"""Background loop: pulls from EngineCore and pushes to AsyncStreams.""" """Background loop: pulls from EngineCore and pushes to AsyncStreams."""
...@@ -302,24 +268,27 @@ class AsyncLLM(EngineClient): ...@@ -302,24 +268,27 @@ class AsyncLLM(EngineClient):
# 2) Detokenize based on the output. # 2) Detokenize based on the output.
request_outputs, reqs_to_abort = self.detokenizer.step(outputs) request_outputs, reqs_to_abort = self.detokenizer.step(outputs)
# 3) Put the RequestOutputs into the per-request AsyncStreams. # 3) Put the RequestOutputs into the per-request queues.
self._process_request_outputs(request_outputs) self._process_request_outputs(request_outputs)
# 4) Abort any requests that finished due to stop strings. # 4) Abort any requests that finished due to stop strings.
await self.engine_core.abort_requests_async(reqs_to_abort) await self.engine_core.abort_requests_async(reqs_to_abort)
# 5) Abort any requests due to client cancellations.
await self._process_cancellations()
except BaseException as e: except BaseException as e:
logger.error(e) logger.error(e)
raise e raise e
# TODO: can we eliminate these?
async def abort(self, request_id: str) -> None: async def abort(self, request_id: str) -> None:
# Note: Who Calls this? I dont think this is actually used. """Abort RequestId in self, detokenizer, and engine core."""
raise ValueError("Not Supported on V1 yet.")
request_ids = [request_id]
await self.engine_core.abort_requests_async(request_ids)
self.detokenizer.abort_requests(request_ids)
# If a request finishes while we await then the request_id
# will be removed from the tracked queues before we get here.
if request_id in self.rid_to_queue:
del self.rid_to_queue[request_id]
def encode( def encode(
self, self,
...@@ -382,7 +351,3 @@ class AsyncLLM(EngineClient): ...@@ -382,7 +351,3 @@ class AsyncLLM(EngineClient):
@property @property
def dead_error(self) -> BaseException: def dead_error(self) -> BaseException:
return Exception() # TODO: implement return Exception() # TODO: implement
# Retain V0 name for backwards compatibility.
AsyncLLMEngine = AsyncLLM
import asyncio
from typing import Any, AsyncGenerator, Callable, Optional, Type, Union
from vllm.outputs import PoolingRequestOutput, RequestOutput
class AsyncStream:
    """A stream of RequestOutputs or PoolingRequestOutputs for a request
    that can be iterated over asynchronously via an async generator.

    Items are buffered in an ``asyncio.Queue``. Producers call ``put`` and
    ``finish`` (both enqueue without awaiting); the consumer iterates
    ``generator()``. If the generator is torn down before it has consumed
    an end-of-stream marker or a raisable item, the ``cancel`` callback
    passed to the constructor is invoked with the request id.
    """

    # Sentinel marking normal end of stream. It is recognised by identity,
    # never by equality, so no user-supplied exception can be mistaken
    # for it.
    STOP_ITERATION = Exception()

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        """Create an empty, unfinished stream.

        Args:
            request_id: Id handed to ``cancel`` when the consumer goes away
                early.
            cancel: Callback invoked with ``request_id`` from the
                generator's cleanup path.
        """
        self.request_id = request_id
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(
        self,
        item: Union["RequestOutput", "PoolingRequestOutput", Exception],
    ) -> None:
        """Enqueue ``item``; silently ignored once the stream is finished."""
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(
        self,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
    ) -> None:
        """Mark the stream finished and enqueue its terminator.

        Args:
            exception: If this is an exception instance or exception class
                it is enqueued and later raised by ``generator()``;
                otherwise the ``STOP_ITERATION`` sentinel is enqueued.

        Only the first call has any effect.
        """
        if not self._finished:
            self._finished = True
            self._queue.put_nowait(exception if self._is_raisable(exception)
                                   else AsyncStream.STOP_ITERATION)

    async def generator(
        self
    ) -> AsyncGenerator[Union["RequestOutput", "PoolingRequestOutput"], None]:
        """Yield queued items until the stream terminates.

        Returns normally when the ``STOP_ITERATION`` sentinel is dequeued
        and raises any other raisable item that is dequeued. On every exit
        path the stream is marked finished; if the exit happened before a
        raisable item was dequeued (e.g. the generator was closed or
        cancelled), the cancel callback is called with the request id.
        """
        finished = False
        try:
            while True:
                result = await self._queue.get()
                if self._is_raisable(result):
                    finished = True
                    # Identity check: ``==`` would dispatch to the queued
                    # exception's own ``__eq__``, which could claim
                    # equality with the sentinel and swallow the error.
                    if result is AsyncStream.STOP_ITERATION:
                        return
                    raise result
                yield result
        finally:
            self._finished = True
            if not finished:
                self._cancel(self.request_id)

    @staticmethod
    def _is_raisable(value: Any) -> bool:
        """Return True if ``value`` is an exception instance or class."""
        return isinstance(value, BaseException) or (
            isinstance(value, type) and issubclass(value, BaseException))
...@@ -32,7 +32,7 @@ logger = init_logger(__name__) ...@@ -32,7 +32,7 @@ logger = init_logger(__name__)
POLLING_TIMEOUT_MS = 5000 POLLING_TIMEOUT_MS = 5000
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000 POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
LOGGING_TIME_S = 5000 LOGGING_TIME_S = POLLING_TIMEOUT_S
class EngineCore: class EngineCore:
...@@ -65,7 +65,8 @@ class EngineCore: ...@@ -65,7 +65,8 @@ class EngineCore:
self._last_logging_time = time.time() self._last_logging_time = time.time()
self.mm_input_mapper_server = MMInputMapperServer() self.mm_input_mapper_server = MMInputMapperServer(
vllm_config.model_config)
def _initialize_kv_caches(self, def _initialize_kv_caches(self,
cache_config: CacheConfig) -> Tuple[int, int]: cache_config: CacheConfig) -> Tuple[int, int]:
...@@ -98,9 +99,8 @@ class EngineCore: ...@@ -98,9 +99,8 @@ class EngineCore:
# MM mapper, so anything that has a hash must have a HIT cache # MM mapper, so anything that has a hash must have a HIT cache
# entry here as well. # entry here as well.
assert request.mm_inputs is not None assert request.mm_inputs is not None
request.mm_inputs, request.mm_hashes = ( request.mm_inputs = self.mm_input_mapper_server.process_inputs(
self.mm_input_mapper_server.process_inputs( request.mm_inputs, request.mm_hashes)
request.mm_inputs, request.mm_hashes))
req = Request.from_engine_core_request(request) req = Request.from_engine_core_request(request)
......
...@@ -55,9 +55,12 @@ class LLMEngine: ...@@ -55,9 +55,12 @@ class LLMEngine:
self.tokenizer.ping() self.tokenizer.ping()
# Processor (convert Inputs --> EngineCoreRequests) # Processor (convert Inputs --> EngineCoreRequests)
self.processor = Processor(vllm_config.model_config, self.processor = Processor(model_config=vllm_config.model_config,
vllm_config.lora_config, self.tokenizer, cache_config=vllm_config.cache_config,
input_registry, mm_registry) lora_config=vllm_config.lora_config,
tokenizer=self.tokenizer,
input_registry=input_registry,
mm_registry=mm_registry)
# Detokenizer (converts EngineCoreOutputs --> RequestOutput) # Detokenizer (converts EngineCoreOutputs --> RequestOutput)
self.detokenizer = Detokenizer( self.detokenizer = Detokenizer(
...@@ -107,7 +110,10 @@ class LLMEngine: ...@@ -107,7 +110,10 @@ class LLMEngine:
executor_class: Type[Executor] executor_class: Type[Executor]
distributed_executor_backend = ( distributed_executor_backend = (
vllm_config.parallel_config.distributed_executor_backend) vllm_config.parallel_config.distributed_executor_backend)
if distributed_executor_backend == "mp": if distributed_executor_backend == "ray":
from vllm.v1.executor.ray_executor import RayExecutor
executor_class = RayExecutor
elif distributed_executor_backend == "mp":
from vllm.v1.executor.multiproc_executor import MultiprocExecutor from vllm.v1.executor.multiproc_executor import MultiprocExecutor
executor_class = MultiprocExecutor executor_class = MultiprocExecutor
else: else:
......
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional
import PIL import PIL
from blake3 import blake3 from blake3 import blake3
...@@ -8,7 +8,7 @@ from vllm.inputs import PromptType ...@@ -8,7 +8,7 @@ from vllm.inputs import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
MultiModalKwargs, MultiModalRegistry) MultiModalKwargs, MultiModalRegistry)
from vllm.v1.utils import LRUDictCache from vllm.utils import LRUCache
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -42,7 +42,9 @@ class MMInputMapperClient: ...@@ -42,7 +42,9 @@ class MMInputMapperClient:
model_config) model_config)
self.mm_registry.init_mm_limits_per_prompt(model_config) self.mm_registry.init_mm_limits_per_prompt(model_config)
self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE) # Init cache
self.use_cache = not model_config.disable_mm_preprocessor_cache
self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE)
# DEBUG: Set to None to disable # DEBUG: Set to None to disable
self.mm_debug_cache_hit_ratio_steps = None self.mm_debug_cache_hit_ratio_steps = None
...@@ -61,7 +63,7 @@ class MMInputMapperClient: ...@@ -61,7 +63,7 @@ class MMInputMapperClient:
mm_hashes: Optional[List[str]], mm_hashes: Optional[List[str]],
mm_processor_kwargs: Optional[Dict[str, Any]], mm_processor_kwargs: Optional[Dict[str, Any]],
precomputed_mm_inputs: Optional[List[MultiModalKwargs]], precomputed_mm_inputs: Optional[List[MultiModalKwargs]],
) -> Tuple[List[MultiModalKwargs], Optional[List[str]]]: ) -> List[MultiModalKwargs]:
if precomputed_mm_inputs is None: if precomputed_mm_inputs is None:
image_inputs = mm_data["image"] image_inputs = mm_data["image"]
if not isinstance(image_inputs, list): if not isinstance(image_inputs, list):
...@@ -70,26 +72,21 @@ class MMInputMapperClient: ...@@ -70,26 +72,21 @@ class MMInputMapperClient:
else: else:
num_inputs = len(precomputed_mm_inputs) num_inputs = len(precomputed_mm_inputs)
# Check if hash is enabled # Sanity
use_hash = mm_hashes is not None if self.use_cache:
if use_hash:
assert mm_hashes is not None assert mm_hashes is not None
assert num_inputs == len( assert num_inputs == len(mm_hashes)
mm_hashes), "num_inputs = {} len(mm_hashes) = {}".format(
num_inputs, len(mm_hashes))
# Process each image input separately, so that later we can schedule # Process each image input separately, so that later we can schedule
# them in a fine-grained manner. # them in a fine-grained manner.
# Apply caching (if enabled) and reuse precomputed inputs (if provided) # Apply caching (if enabled) and reuse precomputed inputs (if provided)
ret_hashes: Optional[List[str]] = [] if use_hash else None
ret_inputs: List[MultiModalKwargs] = [] ret_inputs: List[MultiModalKwargs] = []
for input_id in range(num_inputs): for input_id in range(num_inputs):
if self.mm_debug_cache_hit_ratio_steps is not None: if self.mm_debug_cache_hit_ratio_steps is not None:
self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps) self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps)
mm_hash = None
mm_input = None mm_input = None
if use_hash: if self.use_cache:
assert mm_hashes is not None assert mm_hashes is not None
mm_hash = mm_hashes[input_id] mm_hash = mm_hashes[input_id]
mm_input = self.mm_cache.get(mm_hash) mm_input = self.mm_cache.get(mm_hash)
...@@ -106,7 +103,7 @@ class MMInputMapperClient: ...@@ -106,7 +103,7 @@ class MMInputMapperClient:
mm_processor_kwargs=mm_processor_kwargs, mm_processor_kwargs=mm_processor_kwargs,
) )
if use_hash: if self.use_cache:
# Add to cache # Add to cache
assert mm_hash is not None assert mm_hash is not None
self.mm_cache.put(mm_hash, mm_input) self.mm_cache.put(mm_hash, mm_input)
...@@ -114,19 +111,16 @@ class MMInputMapperClient: ...@@ -114,19 +111,16 @@ class MMInputMapperClient:
self.mm_cache_hits += 1 self.mm_cache_hits += 1
mm_input = None # Avoids sending mm_input to Server mm_input = None # Avoids sending mm_input to Server
if use_hash:
assert mm_hash is not None
assert ret_hashes is not None
ret_hashes.append(mm_hash)
ret_inputs.append(mm_input) ret_inputs.append(mm_input)
return ret_inputs, ret_hashes return ret_inputs
class MMInputMapperServer: class MMInputMapperServer:
def __init__(self, ): def __init__(self, model_config):
self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE) self.use_cache = not model_config.disable_mm_preprocessor_cache
self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE)
def process_inputs( def process_inputs(
self, self,
...@@ -135,6 +129,9 @@ class MMInputMapperServer: ...@@ -135,6 +129,9 @@ class MMInputMapperServer:
) -> List[MultiModalKwargs]: ) -> List[MultiModalKwargs]:
assert len(mm_inputs) == len(mm_hashes) assert len(mm_inputs) == len(mm_hashes)
if not self.use_cache:
return mm_inputs
full_mm_inputs = [] full_mm_inputs = []
for mm_input, mm_hash in zip(mm_inputs, mm_hashes): for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
assert mm_hash is not None assert mm_hash is not None
...@@ -154,12 +151,45 @@ class MMHasher: ...@@ -154,12 +151,45 @@ class MMHasher:
def __init__(self): def __init__(self):
pass pass
def hash(self, prompt: PromptType) -> Optional[List[str]]: def hash_dummy_mm_data(
self,
mm_data: Optional[MultiModalDataDict]) -> Optional[List[str]]:
"""Hash user-defined dummy multimodal data used for profiling."""
if mm_data is None:
return None
image_inputs = mm_data['image']
# This is a temporary workaround for models (e.g, Molmo) that
# process multimodal data in the input processor (therefore
# image_inputs is MultiModalKwargs instead of raw input format).
# `raw_mm_data` with the original input format is expected
# in this case.
if isinstance(image_inputs, dict):
assert "raw_mm_data" in image_inputs and isinstance(
image_inputs["raw_mm_data"], PIL.Image.Image)
image_inputs = image_inputs.pop("raw_mm_data")
return self.hash_images(image_inputs)
def hash_prompt_mm_data(self, prompt: PromptType) -> Optional[List[str]]:
"""Hash multimodal data in the user input prompt if they exist."""
if "multi_modal_data" not in prompt: if "multi_modal_data" not in prompt:
return None return None
mm_data = prompt["multi_modal_data"] mm_data = prompt["multi_modal_data"]
if not mm_data:
# mm_data can be None or an empty dict.
return None
image_inputs = mm_data["image"] image_inputs = mm_data["image"]
return self.hash_images(image_inputs)
def hash_images(self, image_inputs) -> Optional[List[str]]:
"""Hash PIL image objects to strings."""
if not isinstance(image_inputs, list): if not isinstance(image_inputs, list):
image_inputs = [image_inputs] image_inputs = [image_inputs]
assert len(image_inputs) > 0 assert len(image_inputs) > 0
......
import time import time
from typing import Any, Dict, Mapping, Optional, Tuple, Union from typing import Mapping, Optional, Tuple, Union
from vllm.config import LoRAConfig, ModelConfig from vllm.config import CacheConfig, LoRAConfig, ModelConfig
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType, SingletonInputsAdapter) PromptType, SingletonInputsAdapter)
from vllm.inputs.parse import is_encoder_decoder_inputs from vllm.inputs.parse import is_encoder_decoder_inputs
...@@ -12,7 +12,6 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, ...@@ -12,7 +12,6 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest
from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient
...@@ -23,6 +22,7 @@ class Processor: ...@@ -23,6 +22,7 @@ class Processor:
def __init__( def __init__(
self, self,
model_config: ModelConfig, model_config: ModelConfig,
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
tokenizer: BaseTokenizerGroup, tokenizer: BaseTokenizerGroup,
input_registry: InputRegistry = INPUT_REGISTRY, input_registry: InputRegistry = INPUT_REGISTRY,
...@@ -33,8 +33,8 @@ class Processor: ...@@ -33,8 +33,8 @@ class Processor:
self.lora_config = lora_config self.lora_config = lora_config
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.generation_config_fields = _load_generation_config_dict( self.generation_config_fields = model_config.try_get_generation_config(
model_config) )
self.input_preprocessor = InputPreprocessor(model_config, self.input_preprocessor = InputPreprocessor(model_config,
self.tokenizer, self.tokenizer,
mm_registry) mm_registry)
...@@ -45,8 +45,9 @@ class Processor: ...@@ -45,8 +45,9 @@ class Processor:
self.mm_input_mapper_client = MMInputMapperClient(model_config) self.mm_input_mapper_client = MMInputMapperClient(model_config)
# Multi-modal hasher (for images) # Multi-modal hasher (for images)
self.mm_hasher = MMHasher( self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
) if model_config.mm_cache_preprocessor else None cache_config.enable_prefix_caching
self.mm_hasher = MMHasher()
# TODO: run in an ThreadpoolExecutor or BackgroundProcess. # TODO: run in an ThreadpoolExecutor or BackgroundProcess.
# This ideally should releases the GIL, so we should not block the # This ideally should releases the GIL, so we should not block the
...@@ -77,8 +78,8 @@ class Processor: ...@@ -77,8 +78,8 @@ class Processor:
# Compute MM hashes (if enabled) # Compute MM hashes (if enabled)
mm_hashes = None mm_hashes = None
if self.mm_hasher is not None: if self.use_hash:
mm_hashes = self.mm_hasher.hash(prompt) mm_hashes = self.mm_hasher.hash_prompt_mm_data(prompt)
# Process inputs. # Process inputs.
preprocessed_inputs = self.input_preprocessor.preprocess( preprocessed_inputs = self.input_preprocessor.preprocess(
...@@ -118,7 +119,7 @@ class Processor: ...@@ -118,7 +119,7 @@ class Processor:
# Apply MM mapper # Apply MM mapper
mm_inputs = None mm_inputs = None
if len(decoder_inputs.multi_modal_data) > 0: if len(decoder_inputs.multi_modal_data) > 0:
mm_inputs, mm_hashes = self.mm_input_mapper_client.process_inputs( mm_inputs = self.mm_input_mapper_client.process_inputs(
decoder_inputs.multi_modal_data, mm_hashes, decoder_inputs.multi_modal_data, mm_hashes,
decoder_inputs.mm_processor_kwargs, precomputed_mm_inputs) decoder_inputs.mm_processor_kwargs, precomputed_mm_inputs)
...@@ -179,16 +180,3 @@ class Processor: ...@@ -179,16 +180,3 @@ class Processor:
# TODO: Find out how many placeholder tokens are there so we can # TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them # check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens # max_batch_len = self.scheduler_config.max_num_batched_tokens
def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
    """Return the model's generation config as a dict of overrides.

    The config is looked up with ``try_get_generation_config`` using the
    model name, revision and ``trust_remote_code`` flag carried by
    ``model_config``.

    Returns:
        The result of ``to_diff_dict()`` on the loaded config, or an empty
        dict when no generation config could be obtained.
    """
    generation_config = try_get_generation_config(
        model_config.model,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.revision,
    )

    if generation_config is not None:
        return generation_config.to_diff_dict()
    return {}
import os
from collections import defaultdict
from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.executor.abstract import Executor
from vllm.v1.executor.ray_utils import (RayWorkerWrapper,
initialize_ray_cluster, ray)
from vllm.v1.outputs import ModelRunnerOutput
if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
logger = init_logger(__name__)
class RayExecutor(Executor):
    """Executor that runs the model on a set of Ray actor workers.

    One ``RayWorkerWrapper`` actor is created for every GPU bundle of the
    placement group stored in the parallel config. Control-plane calls are
    fanned out to all actors through ``_run_workers``; model execution goes
    through a lazily built Ray compiled DAG.
    """

    def __init__(self, vllm_config: VllmConfig) -> None:
        """Connect to Ray, create the placement group and start the workers.

        Args:
            vllm_config: The full vLLM configuration; it is forwarded as-is
                to every worker actor.
        """
        self.vllm_config = vllm_config
        self.parallel_config = vllm_config.parallel_config
        self.model_config = vllm_config.model_config
        self.forward_dag: Optional[ray.dag.CompiledDAG] = None

        # Disable Ray usage stats collection, unless the user explicitly
        # opted in with RAY_USAGE_STATS_ENABLED=1.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        initialize_ray_cluster(self.parallel_config)
        placement_group = self.parallel_config.placement_group

        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        """Create the worker actors, order them and bring them up.

        Args:
            placement_group: Placement group whose GPU bundles host the
                workers (one worker per bundle that requests a GPU).
            **ray_remote_kwargs: Extra options forwarded to ``ray.remote``
                when the actors are created.

        Raises:
            RuntimeError: If the number of distinct node ids differs from
                the number of distinct worker IP addresses.
        """
        # A list of workers to run a model.
        self.workers: List[RayWorkerWrapper] = []
        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs)

        # Create the workers.
        driver_ip = get_ip()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get("GPU", 0):
                # Skip bundles that don't have GPUs,
                # as each worker needs one GPU.
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            worker = ray.remote(
                num_cpus=0,
                num_gpus=1,
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(vllm_config=self.vllm_config)
            self.workers.append(worker)

        logger.debug("workers: %s", self.workers)
        worker_ips = [
            ray.get(worker.get_node_ip.remote())  # type: ignore[attr-defined]
            for worker in self.workers
        ]
        ip_counts: Dict[str, int] = {}
        for ip in worker_ips:
            ip_counts[ip] = ip_counts.get(ip, 0) + 1

        # Built before sorting, while `self.workers` and `worker_ips` are
        # still index-aligned.
        worker_to_ip = dict(zip(self.workers, worker_ips))

        def sort_by_driver_then_worker_ip(worker):
            """
            Sort the workers based on 3 properties:
            1. If the worker is on the same node as the driver (vllm engine),
                it should be placed first.
            2. Then, if the worker is on a node with fewer workers, it should
                be placed first.
            3. Finally, if the work is on a node with smaller IP address, it
                should be placed first. This is simply a tiebreaker to make
                sure the workers are sorted in a deterministic way.
            """
            ip = worker_to_ip[worker]
            return (ip != driver_ip, ip_counts[ip], ip)

        # After sorting, the workers on the same node will be
        # close to each other, and the workers on the driver
        # node will be placed first.
        self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)

        # Get the set of GPU IDs used on each node.
        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids")

        node_workers = defaultdict(list)  # node id -> list of worker ranks
        node_gpus = defaultdict(list)  # node id -> list of gpu ids

        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            # `gpu_ids` can be a list of strings or integers.
            # convert them to integers for consistency.
            # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
            # string sorting is not sufficient.
            # see https://github.com/vllm-project/vllm/issues/5590
            gpu_ids = [int(x) for x in gpu_ids]
            node_gpus[node_id].extend(gpu_ids)

        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        all_ips = set(worker_ips)
        n_ips = len(all_ips)
        n_nodes = len(node_workers)

        if n_nodes != n_ips:
            raise RuntimeError(
                f"Every node should have a unique IP address. Got {n_nodes}"
                f" nodes with node ids {list(node_workers.keys())} and "
                f"{n_ips} unique IP addresses {all_ips}. Please check your"
                " network configuration. If you set `VLLM_HOST_IP` or "
                "`HOST_IP` environment variable, make sure it is unique for"
                " each node.")

        # Set environment variables for the driver and workers.
        # One single-element args tuple per worker, in worker-rank order.
        all_args_to_update_environment_variables = [({
            "CUDA_VISIBLE_DEVICES":
            ",".join(map(str, node_gpus[node_id])),
            "VLLM_TRACE_FUNCTION":
            str(envs.VLLM_TRACE_FUNCTION),
            "VLLM_USE_V1":
            str(int(envs.VLLM_USE_V1)),
            **({
                "VLLM_ATTENTION_BACKEND": envs.VLLM_ATTENTION_BACKEND
            } if envs.VLLM_ATTENTION_BACKEND is not None else {})
        }, ) for (node_id, _) in worker_node_and_gpu_ids]

        self._env_vars_for_all_workers = (
            all_args_to_update_environment_variables)

        self._run_workers("update_environment_variables",
                          all_args=self._get_env_vars_to_be_updated())

        if len(node_gpus) == 1:
            # in single node case, we don't need to get the IP address.
            # the loopback address is sufficient
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"

        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        # Initialize the actual workers inside worker wrapper.
        init_worker_all_kwargs = [
            self._get_worker_kwargs(
                local_rank=node_workers[node_id].index(rank),
                rank=rank,
                distributed_init_method=distributed_init_method,
            ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
        ]
        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
        self._run_workers("initialize")
        self._run_workers("load_model")

    def _configure_ray_workers_use_nsight(self,
                                          ray_remote_kwargs) -> Dict[str, Any]:
        """Add the nsight profiling settings to the actors' runtime env.

        The passed-in dict is updated in place and also returned.
        """
        # If nsight profiling is enabled, we need to set the profiling
        # configuration for the ray workers as runtime env.
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
        runtime_env.update({
            "nsight": {
                "t": "cuda,cudnn,cublas",
                "o": "'worker_process_%p'",
                "cuda-graph-trace": "node",
            }
        })

        return ray_remote_kwargs

    def _get_env_vars_to_be_updated(self):
        """Return the per-worker environment variable args built at init."""
        return self._env_vars_for_all_workers

    def _get_worker_kwargs(
            self,
            local_rank: int = 0,
            rank: int = 0,
            distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
        """
        Return worker init args for a given rank.

        When ``distributed_init_method`` is None, one is derived from this
        process's IP address and a freshly opened port.
        """
        if distributed_init_method is None:
            distributed_init_method = get_distributed_init_method(
                get_ip(), get_open_port())
        return dict(
            vllm_config=self.vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
        )

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """
        Determine the number of available KV blocks.

        This invokes `determine_num_available_blocks` on each worker and takes
        the min of the results, guaranteeing that the selected cache sizes are
        compatible with all workers.

        Returns:
            - tuple[num_gpu_blocks, num_cpu_blocks]
        """
        # Get the maximum number of blocks that can be allocated on GPU and CPU.
        num_blocks = self._run_workers("determine_num_available_blocks")

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operators can be applied to all workers.
        num_gpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)

        return num_gpu_blocks, num_cpu_blocks

    def initialize(self, num_gpu_blocks: int) -> None:
        """
        Initialize the KV cache in all workers.
        """
        # NOTE: This is logged in the executor because there can be >1 worker
        # with other executors. We could log in the engine level, but work
        # remains to abstract away the device for non-GPU configurations.
        logger.info("# GPU blocks: %d", num_gpu_blocks)
        self._run_workers("initialize_cache", num_gpu_blocks)
        self._run_workers("compile_or_warm_up_model")

    def _run_workers(
        self,
        method: str,
        *args,
        all_args: Optional[List[Tuple[Any, ...]]] = None,
        all_kwargs: Optional[List[Dict[str, Any]]] = None,
        **kwargs,
    ) -> Any:
        """
        Runs the given method on all workers. Can be used in the following
        ways:

        Args:
        - args/kwargs: All workers share the same args/kwargs
        - all_args/all_kwargs: args/kwargs for each worker are specified
          individually

        Returns:
            The list of per-worker results, in ``self.workers`` order.
        """
        count = len(self.workers)
        all_worker_args = repeat(args, count) if all_args is None \
            else islice(all_args, 0, None)
        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
            else islice(all_kwargs, 0, None)

        # Launch every remote call first, then block once on all of them.
        ray_worker_refs = [
            worker.execute_method.remote(  # type: ignore[attr-defined]
                method, *worker_args, **worker_kwargs)
            for (worker, worker_args, worker_kwargs
                 ) in zip(self.workers, all_worker_args, all_worker_kwargs)
        ]
        return ray.get(ray_worker_refs)

    def execute_model(
        self,
        scheduler_output,
    ) -> ModelRunnerOutput:
        """Run one model step on all workers through the compiled DAG.

        The DAG is compiled on the first call and reused afterwards.
        """
        if self.forward_dag is None:
            self.forward_dag = self._compiled_ray_dag()
        # Only the first worker (with rank 0) returns the execution result.
        # Others return None.
        output = ray.get(self.forward_dag.execute(scheduler_output))[0]
        return output

    def profile(self, is_start=True):
        """Profiling is not supported by this executor."""
        raise NotImplementedError

    def shutdown(self):
        """Tear down the compiled DAG and kill the worker actors.

        Nothing happens when no DAG has been compiled (or the attribute was
        never set), which also keeps ``__del__`` safe on a partially
        constructed instance.
        """
        if hasattr(self, "forward_dag") and self.forward_dag is not None:
            self.forward_dag.teardown()
            import ray
            for worker in self.workers:
                ray.kill(worker)
            self.forward_dag = None

    def check_health(self) -> None:
        """Log that a health check was requested; no probing is done."""
        logger.debug("Called check_health.")

    def _check_ray_compiled_graph_installation(self):
        """Validate that Ray Compiled Graph and its dependencies are usable.

        Raises:
            ValueError: If the installed Ray is older than 2.39, if the
                compiled-graph module cannot be found, or if cupy is missing
                while ``VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL`` is set.
        """
        # importlib.metadata is the stdlib replacement for the deprecated
        # pkg_resources distribution lookup.
        import importlib.metadata
        import importlib.util

        from packaging import version

        required_version = version.parse("2.39")
        current_version = version.parse(importlib.metadata.version("ray"))
        if current_version < required_version:
            raise ValueError(f"Ray version {required_version} is "
                             f"required, but found {current_version}")

        raycg = importlib.util.find_spec("ray.experimental.compiled_dag_ref")
        if raycg is None:
            raise ValueError("Ray Compiled Graph is not installed. "
                             "Run `pip install ray[adag]` to install it.")

        cupy_spec = importlib.util.find_spec("cupy")
        if cupy_spec is None and envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL:
            raise ValueError(
                "cupy is not installed but required since "
                "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set. "
                "Run `pip install ray[adag]` and check cupy installation.")

    def _compiled_ray_dag(self):
        """Build and compile a DAG that feeds one input to every worker.

        Each worker's ``execute_model`` is bound to the same input node and
        the outputs are gathered in ``self.workers`` order.
        """
        assert self.parallel_config.use_ray
        self._check_ray_compiled_graph_installation()
        from ray.dag import InputNode, MultiOutputNode

        with InputNode() as input_batches:
            outputs = [
                worker.execute_model.bind(  # type: ignore[attr-defined]
                    input_batches) for worker in self.workers
            ]
            forward_dag = MultiOutputNode(outputs)

        return forward_dag.experimental_compile()

    def __del__(self):
        """Release Ray resources when the executor is garbage collected."""
        self.shutdown()
import time
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import get_ip
from vllm.v1.outputs import ModelRunnerOutput
from vllm.worker.worker_base import WorkerWrapperBase
if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput
logger = init_logger(__name__)
PG_WAIT_TIMEOUT = 60
# Ray is an optional dependency: when it cannot be imported, `ray` and
# `RayWorkerWrapper` are set to None and the import error is kept so that
# `assert_ray_available` can chain it later.
try:
    import ray
    from ray.util import placement_group_table
    from ray.util.placement_group import PlacementGroup
    try:
        from ray._private.state import available_resources_per_node
    except ImportError:
        # Ray 2.9.x doesn't expose `available_resources_per_node`
        from ray._private.state import state as _state
        available_resources_per_node = _state._available_resources_per_node

    class RayWorkerWrapper(WorkerWrapperBase):
        """Worker wrapper that is instantiated as a Ray actor."""

        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)
            # The compiled DAG runs its main execution in a different
            # thread, which has to call cuda.set_device itself. This flag
            # records whether that call has already happened on that
            # thread. It will be removed soon.
            self.compiled_dag_cuda_device_set = False

        def get_node_ip(self) -> str:
            """Return the IP address reported by `get_ip()` in this actor."""
            return get_ip()

        def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
            """Return this actor's Ray node id and its assigned GPU ids."""
            runtime_context = ray.get_runtime_context()
            return runtime_context.get_node_id(), ray.get_gpu_ids()

        def setup_device_if_necessary(self):
            """Set torch's current CUDA device once per wrapper instance."""
            # TODO(swang): This is needed right now because Ray CG executes
            # on a background thread, so we need to reset torch's current
            # device.
            # We can remove this API after it is fixed in compiled graph.
            import torch
            assert self.worker is not None, "Worker is not initialized"
            if self.compiled_dag_cuda_device_set:
                return
            torch.cuda.set_device(self.worker.device)
            self.compiled_dag_cuda_device_set = True

        def execute_model(
            self,
            scheduler_output: "SchedulerOutput",
        ) -> ModelRunnerOutput:
            """Run the wrapped worker's model runner on `scheduler_output`."""
            self.setup_device_if_necessary()
            assert self.worker is not None, "Worker is not initialized"
            return self.worker.model_runner.execute_model(scheduler_output)

    ray_import_err = None

except ImportError as import_error:
    ray = None  # type: ignore
    ray_import_err = import_error
    RayWorkerWrapper = None  # type: ignore
def ray_is_available() -> bool:
    """Tell whether the module-level `ray` import succeeded."""
    ray_missing = ray is None
    return not ray_missing
def assert_ray_available():
    """
    Fail fast when Ray could not be imported.

    Raises:
        ValueError: If the module-level `ray` is None. The original import
            error is attached as the cause.
    """
    if ray is not None:
        return
    raise ValueError("Failed to import Ray, please install Ray with "
                     "`pip install ray`.") from ray_import_err
def _verify_bundles(placement_group: "PlacementGroup",
                    parallel_config: ParallelConfig, device_str: str):
    """
    Check where the bundles of a placement group have been placed.

    Two rules are applied:
    - The driver node must host at least one bundle, otherwise an error is
      raised.
    - A warning is logged for every node that hosts fewer bundles than the
      tensor parallel size.

    Args:
        placement_group: The placement group to verify.
        parallel_config: The parallel configuration.
        device_str: The required device.

    Raises:
        RuntimeError: If the driver node has no bundle in the group.
    """
    assert ray.is_initialized(), (
        "Ray is not initialized although distributed-executor-backend is ray.")
    table = placement_group_table(placement_group)
    # bundle_idx -> node_id
    bundle_to_node_ids = table["bundles_to_node_id"]
    # bundle_idx -> bundle (e.g., {"GPU": 1})
    bundles = table["bundles"]

    # node_id -> List of bundle (e.g., {"GPU": 1})
    node_id_to_bundle: Dict[str, List[Dict[str, float]]] = defaultdict(list)
    for bundle_idx, node_id in bundle_to_node_ids.items():
        node_id_to_bundle[node_id].append(bundles[bundle_idx])

    driver_node_id = ray.get_runtime_context().get_node_id()
    driver_has_bundle = driver_node_id in node_id_to_bundle
    if not driver_has_bundle:
        raise RuntimeError(
            f"driver node id {driver_node_id} is not included in a placement "
            f"group {placement_group.id}. Node id -> bundles "
            f"{node_id_to_bundle}. "
            "You don't have enough GPUs available in a current node. Check "
            "`ray status` to see if you have available GPUs in a node "
            f"{driver_node_id} before starting an vLLM engine.")

    tp_size = parallel_config.tensor_parallel_size
    for node_id, node_bundles in node_id_to_bundle.items():
        reserved = len(node_bundles)
        if reserved >= tp_size:
            continue
        logger.warning(
            "tensor_parallel_size=%d "
            "is bigger than a reserved number of %ss (%d "
            "%ss) in a node %s. Tensor parallel workers can be "
            "spread out to 2+ nodes which can degrade the performance "
            "unless you have fast interconnect across nodes, like "
            "Infiniband. To resolve this issue, make sure you have more "
            "than %d GPUs available at each node.", tp_size, device_str,
            reserved, device_str, node_id, tp_size)
def _wait_until_pg_ready(current_placement_group: "PlacementGroup"):
    """Wait until a placement group is ready.

    It prints the informative log messages if the placement group is
    not created within time.

    Args:
        current_placement_group: The placement group to wait for.

    Raises:
        ValueError: If the placement group is still not ready once
            ``PG_WAIT_TIMEOUT`` seconds have elapsed.
    """
    # Wait until PG is ready - this will block until all
    # requested resources are available, and will timeout
    # if they cannot be provisioned.
    placement_group_specs = current_placement_group.bundle_specs
    # A monotonic clock keeps the elapsed-time accounting immune to
    # wall-clock adjustments.
    start = time.monotonic()
    pg_ready_ref = current_placement_group.ready()
    wait_interval = 10
    while True:
        remaining = PG_WAIT_TIMEOUT - (time.monotonic() - start)
        if remaining <= 0:
            break
        # Never wait past the overall deadline: with a doubling interval
        # the unclamped waits (10 + 20 + 40 s) would exceed the timeout.
        ready, _ = ray.wait([pg_ready_ref],
                            timeout=min(wait_interval, remaining))
        if len(ready) > 0:
            break

        # Exponential backoff for warning print.
        wait_interval *= 2
        logger.info(
            "Waiting for creating a placement group of specs for "
            "%d seconds. specs=%s. Check "
            "`ray status` to see if you have enough resources.",
            int(time.monotonic() - start), placement_group_specs)

    # A zero-timeout get turns "still pending" into an explicit error and
    # also surfaces any failure stored in the ready reference.
    try:
        ray.get(pg_ready_ref, timeout=0)
    except ray.exceptions.GetTimeoutError:
        raise ValueError(
            "Cannot provide a placement group of "
            f"{placement_group_specs=} within {PG_WAIT_TIMEOUT} seconds. See "
            "`ray status` to make sure the cluster has enough resources."
        ) from None
def initialize_ray_cluster(
    parallel_config: ParallelConfig,
    ray_address: Optional[str] = None,
):
    """Initialize the distributed cluster with Ray.

    it will connect to the Ray cluster and create a placement group
    for the workers, which includes the specification of the resources
    for each distributed worker.

    If ``parallel_config.placement_group`` is already set, only the Ray
    connection is made. Otherwise the resulting placement group (the current
    one, or a newly created one) is stored on ``parallel_config``.

    Args:
        parallel_config: The configurations for parallel execution.
        ray_address: The address of the Ray cluster. If None, uses
            the default Ray cluster address.

    Raises:
        ValueError: If Ray is not installed, if a bundle of the current
            placement group requests more than one device, if fewer devices
            are available than ``parallel_config.world_size``, if the
            current node has no device available, or if a new placement
            group does not become ready in time.
    """
    assert_ray_available()

    # Connect to a ray cluster.
    if current_platform.is_rocm() or current_platform.is_xpu():
        # Try to connect existing ray instance and create a new one if not found
        try:
            ray.init("auto")
        except ConnectionError:
            logger.warning(
                "No existing RAY instance detected. "
                "A new instance will be launched with current node resources.")
            ray.init(address=ray_address,
                     ignore_reinit_error=True,
                     num_gpus=parallel_config.world_size)
    else:
        ray.init(address=ray_address, ignore_reinit_error=True)

    if parallel_config.placement_group:
        # Placement group is already set.
        return

    device_str = "GPU" if not current_platform.is_tpu() else "TPU"
    # Create placement group for worker processes
    current_placement_group = ray.util.get_current_placement_group()
    if current_placement_group:
        # We are in a placement group
        bundles = current_placement_group.bundle_specs
        # Verify that we can use the placement group.
        device_bundles = 0
        for bundle in bundles:
            bundle_devices = bundle.get(device_str, 0)
            if bundle_devices > 1:
                raise ValueError(
                    "Placement group bundle cannot have more than 1 "
                    f"{device_str}.")
            if bundle_devices:
                device_bundles += 1
        if parallel_config.world_size > device_bundles:
            raise ValueError(
                f"The number of required {device_str}s exceeds the total "
                f"number of available {device_str}s in the placement group. "
                f"Required number of devices: {parallel_config.world_size}. "
                f"Total number of devices: {device_bundles}.")
    else:
        # No placement group yet: the check is against the whole cluster,
        # so the error message must say so.
        num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
        if parallel_config.world_size > num_devices_in_cluster:
            raise ValueError(
                f"The number of required {device_str}s exceeds the total "
                f"number of available {device_str}s in the cluster. "
                f"Required number of devices: {parallel_config.world_size}. "
                f"Total number of devices: {num_devices_in_cluster}.")
        # Create a new placement group
        placement_group_specs: List[Dict[str, float]] = ([{
            device_str: 1.0
        } for _ in range(parallel_config.world_size)])

        # vLLM engine is also a worker to execute model with an accelerator,
        # so it requires to have the device in a current node. Check if
        # the current node has at least one device.
        current_ip = get_ip()
        current_node_id = ray.get_runtime_context().get_node_id()
        current_node_resource = available_resources_per_node()[current_node_id]
        if current_node_resource.get(device_str, 0) < 1:
            raise ValueError(
                f"Current node has no {device_str} available. "
                f"{current_node_resource=}. vLLM engine cannot start without "
                f"{device_str}. Make sure you have at least 1 {device_str} "
                f"available in a node {current_node_id=} {current_ip=}.")
        # This way, at least bundle is required to be created in a current
        # node.
        placement_group_specs[0][f"node:{current_ip}"] = 0.001

        # By default, Ray packs resources as much as possible.
        current_placement_group = ray.util.placement_group(
            placement_group_specs, strategy="PACK")
        _wait_until_pg_ready(current_placement_group)

    assert current_placement_group is not None
    _verify_bundles(current_placement_group, parallel_config, device_str)
    # Set the placement group in the parallel config
    parallel_config.placement_group = current_placement_group
import enum import enum
from typing import List, Optional, Union from typing import TYPE_CHECKING, List, Optional, Union
from vllm.inputs import DecoderOnlyInputs, SingletonInputsAdapter, token_inputs from vllm.inputs import DecoderOnlyInputs, SingletonInputsAdapter, token_inputs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -9,6 +9,9 @@ from vllm.sequence import RequestMetrics ...@@ -9,6 +9,9 @@ from vllm.sequence import RequestMetrics
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.utils import ConstantList from vllm.v1.utils import ConstantList
if TYPE_CHECKING:
from vllm.v1.core.kv_cache_utils import BlockHashType
class Request: class Request:
...@@ -45,6 +48,7 @@ class Request: ...@@ -45,6 +48,7 @@ class Request:
self._all_token_ids: List[int] = self.prompt_token_ids.copy() self._all_token_ids: List[int] = self.prompt_token_ids.copy()
self.num_computed_tokens = 0 self.num_computed_tokens = 0
# Multi-modal input metadata.
mm_positions = self.inputs.multi_modal_placeholders mm_positions = self.inputs.multi_modal_placeholders
if mm_positions: if mm_positions:
# FIXME(woosuk): Support other modalities. # FIXME(woosuk): Support other modalities.
...@@ -56,6 +60,12 @@ class Request: ...@@ -56,6 +60,12 @@ class Request:
if self.inputs.multi_modal_inputs: if self.inputs.multi_modal_inputs:
self.mm_inputs = self.inputs.multi_modal_inputs self.mm_inputs = self.inputs.multi_modal_inputs
self.mm_hashes: List[str] = self.inputs.multi_modal_hashes
# Cache the computed kv block hashes of the request to avoid
# recomputing.
self._kv_block_hashes: List[BlockHashType] = []
@classmethod @classmethod
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
return cls( return cls(
...@@ -65,6 +75,7 @@ class Request: ...@@ -65,6 +75,7 @@ class Request:
prompt=request.prompt, prompt=request.prompt,
multi_modal_data=None, multi_modal_data=None,
multi_modal_inputs=request.mm_inputs, multi_modal_inputs=request.mm_inputs,
multi_modal_hashes=request.mm_hashes,
multi_modal_placeholders=request.mm_placeholders, multi_modal_placeholders=request.mm_placeholders,
mm_processor_kwargs=None, mm_processor_kwargs=None,
), ),
...@@ -121,6 +132,17 @@ class Request: ...@@ -121,6 +132,17 @@ class Request:
num_tokens = self.mm_positions[input_id]["length"] num_tokens = self.mm_positions[input_id]["length"]
return num_tokens return num_tokens
@property
def kv_block_hashes(self) -> ConstantList["BlockHashType"]:
# Prevent directly appending to the kv_block_hashes.
return ConstantList(self._kv_block_hashes)
def set_kv_block_hashes(self, value: List["BlockHashType"]) -> None:
self._kv_block_hashes = value
def append_kv_block_hashes(self, block_hash: "BlockHashType") -> None:
self._kv_block_hashes.append(block_hash)
class RequestStatus(enum.IntEnum): class RequestStatus(enum.IntEnum):
"""Status of a request.""" """Status of a request."""
......
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict from typing import Dict, List, Optional, Set
import torch import torch
...@@ -19,3 +19,13 @@ class SamplingMetadata: ...@@ -19,3 +19,13 @@ class SamplingMetadata:
generators: Dict[int, torch.Generator] generators: Dict[int, torch.Generator]
max_num_logprobs: int max_num_logprobs: int
no_penalties: bool
prompt_token_ids: Optional[torch.Tensor]
frequency_penalties: torch.Tensor
presence_penalties: torch.Tensor
repetition_penalties: torch.Tensor
output_token_ids: List[List[int]]
min_tokens: List[int]
stop_token_ids: List[Set[int]]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment