Commit 96ae75ad authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.6.post1' into v0.6.6.post1-dev

parents f9f4a735 2339d59f
import fnmatch
import os
import shutil
import signal
import tempfile
from pathlib import Path
from typing import Optional
from vllm.utils import PlaceholderModule
try:
import boto3
except ImportError:
boto3 = PlaceholderModule("boto3") # type: ignore[assignment]
def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]:
return [
path for path in paths if any(
fnmatch.fnmatch(path, pattern) for pattern in patterns)
]
def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]:
return [
path for path in paths
if not any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
]
def glob(s3=None,
         path: str = "",
         allow_pattern: Optional[list[str]] = None) -> list[str]:
    """
    Expand an S3 path into the full ``s3://`` URIs of the objects below it.

    Args:
        s3: S3 client to use. A default ``boto3`` S3 client is created
            when omitted.
        path: The S3 path to list from.
        allow_pattern: A list of patterns of which files to keep.

    Returns:
        list[str]: Full S3 paths (``s3://<bucket>/<key>``) allowed by the
        pattern.
    """
    client = s3 if s3 is not None else boto3.client("s3")
    listing = list_files(client, path=path, allow_pattern=allow_pattern)
    # list_files returns (bucket, prefix, keys); the prefix is not needed
    # because every key is already relative to the bucket root.
    bucket, _, keys = listing
    return [f"s3://{bucket}/{key}" for key in keys]
def list_files(
    s3,
    path: str,
    allow_pattern: Optional[list[str]] = None,
    ignore_pattern: Optional[list[str]] = None
) -> tuple[str, str, list[str]]:
    """
    List files from S3 path and filter by pattern.

    All result pages are fetched: ``list_objects_v2`` returns a bounded
    number of keys per call, so the listing is continued with the
    continuation token until the service reports it is no longer truncated.

    Args:
        s3: S3 client to use.
        path: The S3 path to list from.
        allow_pattern: A list of patterns of which files to pull.
        ignore_pattern: A list of patterns of which files not to pull.

    Returns:
        tuple[str, str, list[str]]: A tuple where:
            - The first element is the bucket name
            - The second element is string represent the bucket
              and the prefix as a dir like string
            - The third element is a list of files allowed or
              disallowed by pattern
    """
    # "s3://bucket/some/prefix" -> ("bucket", "some/prefix")
    bucket_name, _, prefix = path.removeprefix('s3://').partition('/')

    paths: list[str] = []
    request = {"Bucket": bucket_name, "Prefix": prefix}
    while True:
        page = s3.list_objects_v2(**request)
        paths.extend(obj['Key'] for obj in page.get('Contents', []))

        next_token = page.get('NextContinuationToken')
        if not page.get('IsTruncated') or not next_token:
            break
        # Ask for the page that follows the one just received.
        request['ContinuationToken'] = next_token

    # Keys ending in "/" are directory placeholders, not files.
    paths = _filter_ignore(paths, ["*/"])
    if allow_pattern is not None:
        paths = _filter_allow(paths, allow_pattern)

    if ignore_pattern is not None:
        paths = _filter_ignore(paths, ignore_pattern)

    return bucket_name, prefix, paths
class S3Model:
    """
    A class representing a S3 model mirrored into a temporary directory.

    Attributes:
        s3: S3 client.
        dir: The temporary created directory.

    Methods:
        pull_files(): Pull model from S3 to the temporary directory.
    """

    def __init__(self) -> None:
        self.s3 = boto3.client('s3')
        # Chain onto the handlers that are already installed so the
        # temporary directory is also removed on SIGINT / SIGTERM.
        for sig in (signal.SIGINT, signal.SIGTERM):
            existing_handler = signal.getsignal(sig)
            signal.signal(sig, self._close_by_signal(existing_handler))

        self.dir = tempfile.mkdtemp()

    def __del__(self):
        self._close()

    def _close(self) -> None:
        """Remove the temporary directory if it was created and exists."""
        # ``dir`` is assigned last in ``__init__``; it is missing when
        # construction failed part-way and ``__del__`` still runs.
        directory = getattr(self, "dir", None)
        if directory is not None and os.path.exists(directory):
            shutil.rmtree(directory)

    def _close_by_signal(self, existing_handler=None):
        """Build a signal handler that cleans up, then defers to the
        previously installed handler.

        Args:
            existing_handler: The value previously returned by
                ``signal.getsignal`` for the signal being wrapped.

        Returns:
            A ``(signum, frame)`` callable suitable for ``signal.signal``.
        """

        def new_handler(signum, frame):
            self._close()
            # getsignal() may return SIG_DFL, SIG_IGN or None instead of a
            # function; only a real Python handler can be called.
            # NOTE(review): when the previous disposition was SIG_DFL the
            # default action is not re-triggered here -- confirm intended.
            if callable(existing_handler):
                existing_handler(signum, frame)

        return new_handler

    def _local_path(self, file: str, base_dir: str) -> str:
        """Map an S3 key to its destination inside the temporary directory.

        Args:
            file: The full object key.
            base_dir: The listing prefix the key was found under.

        Returns:
            The local file path for ``file`` below ``self.dir``.
        """
        # Strip the separator so the result is the same whether or not the
        # prefix ends with "/"; plain concatenation would otherwise glue
        # the file name onto the temporary directory's own name.
        relative = file.removeprefix(base_dir).lstrip("/")
        if not relative:
            # The prefix names the object itself; keep its file name.
            relative = os.path.basename(file)
        return os.path.join(self.dir, relative)

    def pull_files(self,
                   s3_model_path: str = "",
                   allow_pattern: Optional[list[str]] = None,
                   ignore_pattern: Optional[list[str]] = None) -> None:
        """
        Pull files from S3 storage into the temporary directory.

        Args:
            s3_model_path: The S3 path of the model.
            allow_pattern: A list of patterns of which files to pull.
            ignore_pattern: A list of patterns of which files not to pull.
        """
        bucket_name, base_dir, files = list_files(self.s3, s3_model_path,
                                                  allow_pattern,
                                                  ignore_pattern)
        if not files:
            return

        for file in files:
            destination_file = self._local_path(file, base_dir)
            # Recreate the key's directory structure locally.
            os.makedirs(Path(destination_file).parent, exist_ok=True)
            self.s3.download_file(bucket_name, file, destination_file)
......@@ -132,7 +132,7 @@ def get_tokenizer(
if is_from_mistral_org and tokenizer_mode != "mistral":
warnings.warn(
'It is strongly recommended to run mistral models with '
'`--tokenizer_mode "mistral"` to ensure correct '
'`--tokenizer-mode "mistral"` to ensure correct '
'encoding and decoding.',
FutureWarning,
stacklevel=2)
......
......@@ -22,7 +22,7 @@ class TokenizerGroup(BaseTokenizerGroup):
self.max_input_length = max_input_length
self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
max_loras = tokenizer_config.get("max_loras", 0)
self.lora_tokenizers = LRUCache[AnyTokenizer](
self.lora_tokenizers = LRUCache[int, AnyTokenizer](
capacity=max(max_loras, max_num_seqs) if enable_lora else 0)
@classmethod
......
......@@ -314,12 +314,15 @@ class MistralTokenizer:
if regular_tokens:
decoded_list.append(
self.decode(regular_tokens)) # type: ignore
self.tokenizer.decode(regular_tokens)) # type: ignore
decoded = ''.join(decoded_list)
return decoded
# WARN: Outlines logits processors can overwrite this method.
# See: guided_decoding/outlines_logits_processors.py::_adapt_tokenizer
# for more.
def decode(self,
ids: Union[List[int], int],
skip_special_tokens: bool = True) -> str:
......
......@@ -3,6 +3,10 @@ from pathlib import Path
from typing import Union
def is_s3(model_or_path: str) -> bool:
    """Return True when the given model name or path is an ``s3://`` URI.

    The scheme comparison is case-insensitive.
    """
    normalized = model_or_path.lower()
    return normalized.startswith('s3://')
def check_gguf_file(model: Union[str, PathLike]) -> bool:
"""Check if the file is a GGUF model."""
model = Path(model)
......
......@@ -6,10 +6,13 @@ import datetime
import enum
import gc
import getpass
import importlib.metadata
import importlib.util
import inspect
import ipaddress
import os
import re
import resource
import signal
import socket
import subprocess
......@@ -21,14 +24,13 @@ import uuid
import warnings
import weakref
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
from collections import UserDict, defaultdict
from collections import OrderedDict, UserDict, defaultdict
from collections.abc import Iterable, Mapping
from dataclasses import dataclass, field
from functools import lru_cache, partial, wraps
from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
Dict, Generator, Generic, Hashable, List, Literal,
Optional, OrderedDict, Set, Tuple, Type, TypeVar, Union,
overload)
Optional, Tuple, Type, TypeVar, Union, overload)
from uuid import uuid4
import numpy as np
......@@ -52,7 +54,7 @@ logger = init_logger(__name__)
# Exception strings for non-implemented encoder/decoder scenarios
# Reminder: Please update docs/source/usage/compatibility_matrix.rst
# Reminder: Please update docs/source/usage/compatibility_matrix.md
# If the feature combo become valid
STR_NOT_IMPL_ENC_DEC_SWA = \
......@@ -154,10 +156,12 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = {
}
P = ParamSpec('P')
K = TypeVar("K")
T = TypeVar("T")
U = TypeVar("U")
_K = TypeVar("_K", bound=Hashable)
_V = TypeVar("_V")
class _Sentinel:
...
......@@ -190,50 +194,48 @@ class Counter:
self.counter = 0
class LRUCache(Generic[T]):
class LRUCache(Generic[_K, _V]):
def __init__(self, capacity: int):
self.cache: OrderedDict[Hashable, T] = OrderedDict()
self.pinned_items: Set[Hashable] = set()
def __init__(self, capacity: int) -> None:
self.cache = OrderedDict[_K, _V]()
self.pinned_items = set[_K]()
self.capacity = capacity
def __contains__(self, key: Hashable) -> bool:
def __contains__(self, key: _K) -> bool:
return key in self.cache
def __len__(self) -> int:
return len(self.cache)
def __getitem__(self, key: Hashable) -> T:
def __getitem__(self, key: _K) -> _V:
value = self.cache[key] # Raise KeyError if not exists
self.cache.move_to_end(key)
return value
def __setitem__(self, key: Hashable, value: T) -> None:
def __setitem__(self, key: _K, value: _V) -> None:
self.put(key, value)
def __delitem__(self, key: Hashable) -> None:
def __delitem__(self, key: _K) -> None:
self.pop(key)
def touch(self, key: Hashable) -> None:
def touch(self, key: _K) -> None:
self.cache.move_to_end(key)
def get(self,
key: Hashable,
default_value: Optional[T] = None) -> Optional[T]:
value: Optional[T]
def get(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
value: Optional[_V]
if key in self.cache:
value = self.cache[key]
self.cache.move_to_end(key)
else:
value = default_value
value = default
return value
def put(self, key: Hashable, value: T) -> None:
def put(self, key: _K, value: _V) -> None:
self.cache[key] = value
self.cache.move_to_end(key)
self._remove_old_if_needed()
def pin(self, key: Hashable) -> None:
def pin(self, key: _K) -> None:
"""
Pins a key in the cache preventing it from being
evicted in the LRU order.
......@@ -242,13 +244,13 @@ class LRUCache(Generic[T]):
raise ValueError(f"Cannot pin key: {key} not in cache.")
self.pinned_items.add(key)
def _unpin(self, key: Hashable) -> None:
def _unpin(self, key: _K) -> None:
self.pinned_items.remove(key)
def _on_remove(self, key: Hashable, value: Optional[T]):
def _on_remove(self, key: _K, value: Optional[_V]) -> None:
pass
def remove_oldest(self, remove_pinned=False):
def remove_oldest(self, *, remove_pinned: bool = False) -> None:
if not self.cache:
return
......@@ -262,17 +264,15 @@ class LRUCache(Generic[T]):
"cannot remove oldest from the cache.")
else:
lru_key = next(iter(self.cache))
self.pop(lru_key)
self.pop(lru_key) # type: ignore
def _remove_old_if_needed(self) -> None:
while len(self.cache) > self.capacity:
self.remove_oldest()
def pop(self,
key: Hashable,
default_value: Optional[T] = None) -> Optional[T]:
def pop(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
run_on_remove = key in self.cache
value: Optional[T] = self.cache.pop(key, default_value)
value = self.cache.pop(key, default)
# remove from pinned items
if key in self.pinned_items:
self._unpin(key)
......@@ -280,7 +280,7 @@ class LRUCache(Generic[T]):
self._on_remove(key, value)
return value
def clear(self):
def clear(self) -> None:
while len(self.cache) > 0:
self.remove_oldest(remove_pinned=True)
self.cache.clear()
......@@ -775,7 +775,7 @@ def get_dtype_size(dtype: torch.dtype) -> int:
# `collections` helpers
def is_list_of(
value: object,
typ: Type[T],
typ: Union[type[T], tuple[type[T], ...]],
*,
check: Literal["first", "all"] = "first",
) -> TypeIs[List[T]]:
......@@ -843,10 +843,6 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
return [item for sublist in lists for item in sublist]
_K = TypeVar("_K", bound=Hashable)
_V = TypeVar("_V")
def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
"""
Unlike :class:`itertools.groupby`, groups are not broken by
......@@ -1282,6 +1278,7 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
def supports_kw(
callable: Callable[..., object],
kw_name: str,
*,
requires_kw_only: bool = False,
allow_var_kwargs: bool = True,
) -> bool:
......@@ -1326,6 +1323,8 @@ def resolve_mm_processor_kwargs(
init_kwargs: Optional[Mapping[str, object]],
inference_kwargs: Optional[Mapping[str, object]],
callable: Callable[..., object],
*,
requires_kw_only: bool = True,
allow_var_kwargs: bool = False,
) -> Dict[str, Any]:
"""Applies filtering to eliminate invalid mm_processor_kwargs, i.e.,
......@@ -1344,11 +1343,17 @@ def resolve_mm_processor_kwargs(
runtime_mm_kwargs = get_allowed_kwarg_only_overrides(
callable,
overrides=inference_kwargs,
allow_var_kwargs=allow_var_kwargs)
requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs,
)
# Filter init time multimodal processor kwargs provided
init_mm_kwargs = get_allowed_kwarg_only_overrides(
callable, overrides=init_kwargs, allow_var_kwargs=allow_var_kwargs)
callable,
overrides=init_kwargs,
requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs,
)
# Merge the final processor kwargs, prioritizing inference
# time values over the initialization time values.
......@@ -1359,6 +1364,8 @@ def resolve_mm_processor_kwargs(
def get_allowed_kwarg_only_overrides(
callable: Callable[..., object],
overrides: Optional[Mapping[str, object]],
*,
requires_kw_only: bool = True,
allow_var_kwargs: bool = False,
) -> Dict[str, Any]:
"""
......@@ -1390,16 +1397,21 @@ def get_allowed_kwarg_only_overrides(
for kwarg_name, val in overrides.items()
if supports_kw(callable,
kwarg_name,
requires_kw_only=True,
requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs)
}
# If anything is dropped, log a warning
dropped_keys = overrides.keys() - filtered_overrides.keys()
if dropped_keys:
logger.warning(
"The following intended overrides are not keyword-only args "
"and and will be dropped: %s", dropped_keys)
if requires_kw_only:
logger.warning(
"The following intended overrides are not keyword-only args "
"and and will be dropped: %s", dropped_keys)
else:
logger.warning(
"The following intended overrides are not keyword args "
"and and will be dropped: %s", dropped_keys)
return filtered_overrides
......@@ -1628,6 +1640,67 @@ def import_from_path(module_name: str, file_path: Union[str, os.PathLike]):
return module
@lru_cache(maxsize=None)
def get_vllm_optional_dependencies():
    """Map each vLLM extra to the requirement names declared for it.

    Reads the installed ``vllm`` distribution metadata: every
    ``Provides-Extra`` entry becomes a key, and its value lists the leading
    name part of each ``Requires-Dist`` entry whose text ends with the
    ``extra == "<name>"`` marker. The result is cached for the lifetime of
    the process.
    """
    vllm_metadata = importlib.metadata.metadata("vllm")
    requirements = vllm_metadata.get_all("Requires-Dist", [])
    extras = vllm_metadata.get_all("Provides-Extra", [])

    deps_by_extra = {}
    for extra in extras:
        marker_suffix = f'extra == "{extra}"'
        deps_by_extra[extra] = [
            # Cut at the first marker separator or version specifier.
            re.split(r";|>=|<=|==", requirement)[0]
            for requirement in requirements
            if requirement.endswith(marker_suffix)
        ]
    return deps_by_extra
@dataclass(frozen=True)
class PlaceholderModule:
    """
    Stand-in object bound to a module name whose import failed.

    Any attribute access re-attempts the import so that the resulting error
    is informative: when the module belongs to one of vLLM's optional
    extras, the raised ``ImportError`` names the extra to install.
    """
    # Importable name of the module this object stands in for.
    name: str

    def placeholder_attr(self, attr_path: str):
        """Return a placeholder for the attribute ``attr_path`` of the
        missing module."""
        return _PlaceholderModuleAttr(self, attr_path)

    def __getattr__(self, key: str):
        module_name = self.name

        try:
            importlib.import_module(module_name)
        except ImportError as exc:
            # Look for the first extra that declares this module.
            optional_deps = get_vllm_optional_dependencies()
            matching_extra = next(
                (extra for extra, names in optional_deps.items()
                 if module_name in names),
                None,
            )
            if matching_extra is not None:
                msg = (f"Please install vllm[{matching_extra}] "
                       f"for {matching_extra} support")
                raise ImportError(msg) from exc

            raise exc

        # The import succeeded, so a placeholder should never have been
        # created for this module in the first place.
        raise AssertionError("PlaceholderModule should not be used "
                             "when the original module can be imported")
@dataclass(frozen=True)
class _PlaceholderModuleAttr:
module: PlaceholderModule
attr_path: str
def placeholder_attr(self, attr_path: str):
return _PlaceholderModuleAttr(self.module,
f"{self.attr_path}.{attr_path}")
def __getattr__(self, key: str):
getattr(self.module, f"{self.attr_path}.{key}")
raise AssertionError("PlaceholderModule should not be used "
"when the original module can be imported")
# create a library to hold the custom op
vllm_lib = Library("vllm", "FRAGMENT") # noqa
......@@ -1655,8 +1728,18 @@ def direct_register_custom_op(
library object. If you want to bind the operator to a different library,
make sure the library object is alive when the operator is used.
"""
if is_in_doc_build() or not supports_custom_op():
if is_in_doc_build():
return
if not supports_custom_op():
assert not current_platform.is_cuda_alike(), (
"cuda platform needs torch>=2.4 to support custom op, "
"chances are you are using an old version of pytorch "
"or a custom build of pytorch. It is recommended to "
"use vLLM in a fresh new environment and let it install "
"the required dependencies.")
return
import torch.library
if hasattr(torch.library, "infer_schema"):
schema_str = torch.library.infer_schema(op_func,
......@@ -1823,3 +1906,20 @@ def memory_profiling(
result.non_torch_increase_in_bytes = current_cuda_memory_bytes - baseline_memory_in_bytes - weights_memory_in_bytes - diff.torch_memory_in_bytes # noqa
result.profile_time = diff.timestamp
result.non_kv_cache_memory_in_bytes = result.non_torch_increase_in_bytes + result.torch_peak_increase_in_bytes + result.weights_memory_in_bytes # noqa
# Adapted from: https://github.com/sgl-project/sglang/blob/f46f394f4d4dbe4aae85403dec006199b34d2840/python/sglang/srt/utils.py#L630 # noqa: E501
def set_ulimit(target_soft_limit=65535):
    """Raise the soft open-file limit to ``target_soft_limit`` if it is lower.

    The hard limit is left unchanged. When the limit cannot be raised, a
    warning is logged and the function returns normally.

    Args:
        target_soft_limit: Minimum soft ``RLIMIT_NOFILE`` value to aim for.
    """
    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    # An unlimited soft limit may be reported as a negative sentinel; it
    # must never be "raised" (i.e. actually lowered) to the target.
    if current_soft == resource.RLIM_INFINITY:
        return

    if current_soft < target_soft_limit:
        try:
            resource.setrlimit(resource_type,
                               (target_soft_limit, current_hard))
        except ValueError as e:
            # Each fragment ends with a space so the concatenated message
            # stays readable.
            logger.warning(
                "Found ulimit of %s and failed to automatically increase "
                "with error %s. This can cause fd limit errors like "
                "`OSError: [Errno 24] Too many open files`. Consider "
                "increasing with ulimit -n", current_soft, e)
......@@ -4,7 +4,9 @@ from typing import Dict, Iterable, List, Optional
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
KVCacheBlock, hash_block_tokens,
KVCacheBlock,
generate_block_hash_extra_keys,
hash_block_tokens,
hash_request_tokens)
from vllm.v1.request import Request
......@@ -83,10 +85,12 @@ class KVCacheManager:
computed_blocks = []
# TODO(rickyx): potentially we could cache this so we don't have to
# recompute it every time.
block_hashes = hash_request_tokens(self.block_size,
request.all_token_ids)
# The block hashes for the request may already be computed
# if the request was preempted and resumed.
if not request.kv_block_hashes:
request.set_kv_block_hashes(
hash_request_tokens(self.block_size, request))
block_hashes = request.kv_block_hashes
for block_hash in block_hashes:
# block_hashes is a chain of block hashes. If a block hash is not
......@@ -197,23 +201,15 @@ class KVCacheManager:
f"num_tokens must be greater than 0, got {num_tokens}")
# Touch the computed blocks to make sure they won't be evicted.
num_evictable_computed_blocks = 0
if self.enable_caching:
self._touch(computed_blocks)
# If a computed block of a request is an eviction candidate (in the
# free queue and ref_cnt == 0), it cannot be counted as a free block
# when allocating this request.
num_evictable_computed_blocks = len(
[blk for blk in computed_blocks if blk.ref_cnt == 0])
else:
assert not computed_blocks, (
"Computed blocks should be empty when "
"prefix caching is disabled")
num_required_blocks = cdiv(num_tokens, self.block_size)
if (num_required_blocks > self.free_block_queue.num_free_blocks -
num_evictable_computed_blocks):
if (num_required_blocks > self.free_block_queue.num_free_blocks):
# Cannot allocate new blocks.
return None
......@@ -221,8 +217,7 @@ class KVCacheManager:
# preallocated blocks.
num_new_blocks = min(
num_required_blocks + self.num_preallocate_blocks,
self.free_block_queue.num_free_blocks -
num_evictable_computed_blocks,
self.free_block_queue.num_free_blocks,
# Should not exceed the maximum number of blocks per request.
# This is especially because the block table has the shape
# [..., max_num_blocks_per_req].
......@@ -242,14 +237,16 @@ class KVCacheManager:
num_computed_tokens = len(computed_blocks) * self.block_size
num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size
self._cache_full_blocks(
request=request,
blk_start_idx=len(computed_blocks),
# The new full blocks are the full blocks that are not computed.
full_blocks=self.req_to_blocks[request.request_id]
[len(computed_blocks):num_full_blocks],
prev_block=computed_blocks[-1] if computed_blocks else None,
)
new_full_blocks = self.req_to_blocks[
request.request_id][len(computed_blocks):num_full_blocks]
if new_full_blocks:
self._cache_full_blocks(
request=request,
blk_start_idx=len(computed_blocks),
# The new full blocks are the full blocks that are not computed.
full_blocks=new_full_blocks,
prev_block=computed_blocks[-1] if computed_blocks else None,
)
return new_blocks
......@@ -376,6 +373,8 @@ class KVCacheManager:
full_blocks: The list of blocks to update hash metadata.
prev_block: The previous block in the chain.
"""
num_cached_block_hashes = len(request.kv_block_hashes)
# Update the new blocks with the block hashes through the chain.
prev_block_hash_value = None
if prev_block is not None:
......@@ -387,17 +386,35 @@ class KVCacheManager:
for i, blk in enumerate(full_blocks):
blk_idx = blk_start_idx + i
block_tokens = request.all_token_ids[blk_idx *
self.block_size:(blk_idx +
1) *
self.block_size]
assert len(block_tokens) == self.block_size, (
f"Expected {self.block_size} tokens, got {len(block_tokens)} "
f"at {blk_idx}th block for request "
f"{request.request_id}({request})")
# Compute the hash of the current block.
block_hash = hash_block_tokens(prev_block_hash_value, block_tokens)
if blk_idx < num_cached_block_hashes:
# The block hash may already be computed in
# "get_computed_blocks" if the tokens are not generated by
# this request (either the prompt tokens or the previously
# generated tokens with preemption). In this case we simply
# reuse the block hash.
block_hash = request.kv_block_hashes[blk_idx]
else:
# Otherwise compute the block hash and cache it in the request
# in case it will be preempted in the future.
start_token_idx = blk_idx * self.block_size
end_token_idx = (blk_idx + 1) * self.block_size
block_tokens = request.all_token_ids[
start_token_idx:end_token_idx]
assert len(block_tokens) == self.block_size, (
f"Expected {self.block_size} tokens, got "
f"{len(block_tokens)} at {blk_idx}th block for request "
f"{request.request_id}({request})")
# Generate extra keys for multi-modal inputs. Note that since
# we reach to this branch only when the block is completed with
# generated tokens, we only need to consider the last mm input.
extra_keys, _ = generate_block_hash_extra_keys(
request, start_token_idx, end_token_idx, -1)
# Compute the hash of the current block.
block_hash = hash_block_tokens(prev_block_hash_value,
block_tokens, extra_keys)
request.append_kv_block_hashes(block_hash)
# Update and added the full block to the cache.
blk.block_hash = block_hash
......
"""KV-Cache Utilities."""
from collections.abc import Sequence
from dataclasses import dataclass
from typing import List, NamedTuple, Optional, Tuple
from typing import Any, List, NamedTuple, Optional, Tuple
from vllm.logger import init_logger
from vllm.v1.request import Request
logger = init_logger(__name__)
class BlockHashType(NamedTuple):
    """Hash value of a block (int), the token IDs in the block, and extra keys.

    The reason we keep a tuple of token IDs and extra keys is to make sure
    no hash collision happens when the hash value is the same.
    """
    # Hash value of the block in an integer.
    hash_value: int
    # Token IDs in the block.
    token_ids: Tuple[int, ...]
    # Extra keys for the block.
    extra_keys: Optional[Any] = None
@dataclass
......@@ -159,8 +164,80 @@ class FreeKVCacheBlockQueue:
return ret
def hash_block_tokens(parent_block_hash: Optional[int],
curr_block_token_ids: Sequence[int]) -> BlockHashType:
def generate_block_hash_extra_keys(
        request: Request, start_token_idx: int, end_token_idx: int,
        start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]:
    """Build the extra hash keys for one block of a request.

    For multi-modal inputs each extra key is a ``(mm_hash, start_offset)``
    pair: the hash of an input that the block covers, and the position
    inside that input at which the block starts (0 when the input begins
    within the block).

    Args:
        request: The request object; its ``mm_positions`` and ``mm_hashes``
            are read.
        start_token_idx: The start token index of the block.
        end_token_idx: The end token index of the block.
        start_mm_idx: Index of the first multi-modal input to examine. A
            negative value counts from the end (``-1`` is the last input).

    Returns:
        A tuple of extra keys and the next multi-modal index. The keys are
        ``None`` when the request has no multi-modal inputs or the block
        starts beyond the last one.

    Raises:
        ValueError: If the numbers of multi-modal positions and hashes
            differ.
    """
    positions = request.mm_positions
    hashes = request.mm_hashes
    if not positions:
        return None, start_mm_idx

    num_mm_inputs = len(positions)
    if num_mm_inputs != len(hashes):
        raise ValueError(
            "The number of multi-modal positions and hashes must match. This "
            "is likely because you do not enable MM preprocessor hashing. "
            "Please set disable_mm_preprocessor_cache=False.")

    # Positions are assumed to be sorted by offset, so once the block starts
    # after the last input there is nothing left to examine.
    last_position = positions[-1]
    if last_position["offset"] + last_position["length"] < start_token_idx:
        return None, start_mm_idx

    # Translate a negative (from-the-end) index into a regular one.
    if start_mm_idx < 0:
        assert -start_mm_idx <= num_mm_inputs
        start_mm_idx += num_mm_inputs

    keys = []
    mm_idx = start_mm_idx
    while mm_idx < num_mm_inputs:
        assert hashes[mm_idx] is not None
        mm_begin = positions[mm_idx]["offset"]
        mm_end = mm_begin + positions[mm_idx]["length"]

        if end_token_idx <= mm_begin:
            # The block ends before this input starts.
            break

        if start_token_idx > mm_end:
            # The block lies past this input; try the next one.
            mm_idx += 1
            continue

        # The block covers (part of) this input.
        keys.append((hashes[mm_idx], max(0, start_token_idx - mm_begin)))

        if end_token_idx < mm_end:
            # The input continues into the next block, so this block
            # cannot reach any later input.
            break

        # The input ends inside this block, which may therefore also
        # cover the following input.
        mm_idx += 1

    return tuple(keys), mm_idx
def hash_block_tokens(
parent_block_hash: Optional[int],
curr_block_token_ids: Sequence[int],
extra_keys: Optional[Tuple[Any, ...]] = None) -> BlockHashType:
"""Computes a hash value corresponding to the contents of a block and
the contents of the preceding block(s). The hash value is used for
prefix caching. We use LRU cache for this function to avoid recomputing
......@@ -174,27 +251,39 @@ def hash_block_tokens(parent_block_hash: Optional[int],
if this is the first block.
curr_block_token_ids: A list of token ids in the current
block. The current block is assumed to be full.
extra_keys: Extra keys for the block.
Returns:
The hash value of the block and the token ids in the block.
The entire tuple is used as the hash key of the block.
"""
return BlockHashType(hash((parent_block_hash, *curr_block_token_ids)),
tuple(curr_block_token_ids))
tuple(curr_block_token_ids), extra_keys)
def hash_request_tokens(block_size: int,
token_ids: Sequence[int]) -> List[BlockHashType]:
request: Request) -> List[BlockHashType]:
"""Computes hash values of a chain of blocks given a sequence of
token IDs. The hash value is used for prefix caching.
Args:
block_size: The size of each block.
token_ids: A sequence of token ids in the request.
request: The request object.
Returns:
The list of computed hash values.
"""
token_ids = request.all_token_ids
mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
if mm_positions and len(mm_positions) != len(mm_hashes):
raise ValueError(
"The number of multi-modal positions and hashes must match.")
# TODO: Extend this to support other features such as LoRA.
need_extra_keys = bool(mm_positions)
extra_keys = None
curr_mm_idx = 0
ret = []
parent_block_hash_value = None
for start in range(0, len(token_ids), block_size):
......@@ -203,8 +292,14 @@ def hash_request_tokens(block_size: int,
# Do not hash the block if it is not full.
if len(block_token_ids) < block_size:
break
# Add extra keys if the block is a multi-modal block.
if need_extra_keys:
extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
request, start, end, curr_mm_idx)
block_hash = hash_block_tokens(parent_block_hash_value,
block_token_ids)
block_token_ids, extra_keys)
ret.append(block_hash)
parent_block_hash_value = block_hash.hash_value
return ret
......@@ -516,6 +516,7 @@ class NewRequestData:
prompt_token_ids: List[int]
prompt: Optional[str]
mm_inputs: List["MultiModalKwargs"]
mm_hashes: List[str]
mm_positions: List["PlaceholderRange"]
sampling_params: SamplingParams
block_ids: List[int]
......@@ -533,6 +534,7 @@ class NewRequestData:
prompt_token_ids=request.prompt_token_ids,
prompt=request.prompt,
mm_inputs=request.mm_inputs,
mm_hashes=request.mm_hashes,
mm_positions=request.mm_positions,
sampling_params=request.sampling_params,
block_ids=block_ids,
......
......@@ -9,14 +9,13 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.outputs import RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.async_stream import AsyncStream
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.detokenizer import Detokenizer
from vllm.v1.engine.processor import Processor
......@@ -54,15 +53,17 @@ class AsyncLLM(EngineClient):
lora_config=vllm_config.lora_config)
self.tokenizer.ping()
# Request streams (map of request_id -> AsyncStream).
self.request_streams: Dict[str, AsyncStream] = {}
# List of cancelled request ids to be aborted.
self.client_aborted_requests: List[str] = []
# Request streams (map of request_id -> queue).
self.rid_to_queue: Dict[str, asyncio.Queue] = {}
# Processor (converts Inputs --> EngineCoreRequests).
self.processor = Processor(vllm_config.model_config,
vllm_config.lora_config, self.tokenizer,
input_registry)
self.processor = Processor(
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
tokenizer=self.tokenizer,
input_registry=input_registry,
)
# Detokenizer (converts EngineCoreOutputs --> RequestOutput).
self.detokenizer = Detokenizer(
......@@ -94,7 +95,7 @@ class AsyncLLM(EngineClient):
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "AsyncLLMEngine":
) -> "AsyncLLM":
"""Create an AsyncLLM from the EngineArgs."""
# Create the engine configs.
......@@ -149,14 +150,13 @@ class AsyncLLM(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
) -> asyncio.Queue[RequestOutput]:
"""Add new request to the AsyncLLM."""
if self.detokenizer.is_request_active(request_id):
raise ValueError(f"Request {request_id} already exists.")
# 1) Create a new AsyncStream for the request.
stream = self._add_request_to_streams(request_id)
# 1) Create a new output queue for the request.
if request_id in self.rid_to_queue:
raise ValueError(f"Request id {request_id} already running.")
self.rid_to_queue[request_id] = asyncio.Queue()
# 2) Convert input --> DetokenizerRequest / EngineCoreRequest.
detokenizer_req, engine_core_req = self.processor.process_inputs(
......@@ -169,8 +169,10 @@ class AsyncLLM(EngineClient):
# 4) Add the EngineCoreRequest to EngineCore (separate process).
await self.engine_core.add_request_async(engine_core_req)
# 5) Return the generator.
return stream.generator()
if self.log_requests:
logger.info("Added request %s.", request_id)
return self.rid_to_queue[request_id]
# TODO: we should support multiple prompts in one call, as you
# can do with LLM.generate. So that for multi-prompt completion
......@@ -190,7 +192,7 @@ class AsyncLLM(EngineClient):
"""
Main function called by the API server to kick off a request
* 1) Making an AsyncStream corresponding to the Request.
# 2) Processing the Input.
* 2) Processing the Input.
* 3) Adding the Request to the Detokenizer.
* 4) Adding the Request to the EngineCore (separate process).
......@@ -202,14 +204,15 @@ class AsyncLLM(EngineClient):
returning the RequestOutput back to the caller.
"""
# We start the output_handler on the first call to generate() so that
# we can call __init__ before the event loop starts, which enables us
# to handle startup failure gracefully in the OpenAI server.
if self.output_handler is None:
self.output_handler = asyncio.create_task(
self._run_output_handler())
async for output in await self.add_request(
try:
# We start the output_handler on the first call to generate() so
# we can call __init__ before the event loop, which enables us
# to handle startup failure gracefully in the OpenAI server.
if self.output_handler is None:
self.output_handler = asyncio.create_task(
self._run_output_handler())
q = await self.add_request(
request_id,
prompt,
sampling_params,
......@@ -217,79 +220,42 @@ class AsyncLLM(EngineClient):
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
):
yield output
def _finish_stream(self, request_id: str):
stream = self.request_streams.pop(request_id, None)
if stream is not None:
stream.finish()
def _add_request_to_streams(
self,
request_id: str,
) -> AsyncStream:
if request_id in self.request_streams:
raise ValueError(f"Request id {request_id} already running.")
# Avoid streams having circular ref to parent AsyncLLM object.
aborted_reqs = self.client_aborted_requests
stream = AsyncStream(request_id, aborted_reqs.append)
self.request_streams[request_id] = stream
if self.log_requests:
logger.info("Added request %s.", request_id)
)
return stream
async def _process_cancellations(self) -> None:
"""
Process requests cancelled from user disconnecting.
When a client disconnects, AsyncStream._cancel() is called.
We passed a callback to AsyncStream(), which appends to
self.client_aborted_requests.
As a result, if any requests are canceled from the user side
the request_id will show up in self.client_aborted_requests.
"""
# Avoid streams having circular ref to parent AsyncLLM object.
if not self.client_aborted_requests:
return
reqs_to_abort = self.client_aborted_requests.copy()
self.client_aborted_requests.clear()
# Remove from Detokenizer.
self.detokenizer.abort_requests(reqs_to_abort)
# Remove from RequestStreams.
for request_id in reqs_to_abort:
if self.log_requests:
logger.info("User-cancelled request %s.", request_id)
self._finish_stream(request_id)
# Remove from EngineCore.
await self.engine_core.abort_requests_async(reqs_to_abort)
# The output_handler task pushes items into the queue.
# This task pulls from the queue and yields to caller.
while True:
# Note: drain queue without await if possible (avoids
# task switching under load which helps performance).
out = q.get_nowait() if q.qsize() > 0 else await q.get()
# Note: both Detokenizer and EngineCore handle their
# own request cleanup based on finished.
if out.finished:
del self.rid_to_queue[request_id]
yield out
break
yield out
# If the request is disconnected by the client, the
# generate() task will be canceled. So, we abort the
# request if we end up here.
except asyncio.CancelledError:
await self.abort(request_id)
raise
def _process_request_outputs(self, request_outputs: List[RequestOutput]):
"""Process outputs by putting them into per-request AsyncStreams."""
"""Process outputs by putting them into per-request queues."""
for request_output in request_outputs:
request_id = request_output.request_id
assert request_id in self.request_streams
# Each request in the API server pulls from the per-request stream.
stream = self.request_streams.get(request_id)
if stream is not None:
stream.put(request_output)
# If finished, remove from the tracker.
if request_output.finished:
if self.log_requests:
logger.info("Finished request %s.", request_id)
self._finish_stream(request_id)
# Note: it is possible a request was aborted and removed from
# the state due to client cancellations, so if we encounter a
# request id not in the state, we skip.
if request_id in self.rid_to_queue:
self.rid_to_queue[request_id].put_nowait(request_output)
async def _run_output_handler(self):
"""Background loop: pulls from EngineCore and pushes to AsyncStreams."""
......@@ -302,24 +268,27 @@ class AsyncLLM(EngineClient):
# 2) Detokenize based on the output.
request_outputs, reqs_to_abort = self.detokenizer.step(outputs)
# 3) Put the RequestOutputs into the per-request AsyncStreams.
# 3) Put the RequestOutputs into the per-request queues.
self._process_request_outputs(request_outputs)
# 4) Abort any requests that finished due to stop strings.
await self.engine_core.abort_requests_async(reqs_to_abort)
# 5) Abort any requests due to client cancellations.
await self._process_cancellations()
except BaseException as e:
logger.error(e)
raise e
# TODO: can we eliminate these?
async def abort(self, request_id: str) -> None:
# Note: Who Calls this? I dont think this is actually used.
raise ValueError("Not Supported on V1 yet.")
"""Abort RequestId in self, detokenizer, and engine core."""
request_ids = [request_id]
await self.engine_core.abort_requests_async(request_ids)
self.detokenizer.abort_requests(request_ids)
# If a request finishes while we await then the request_id
# will be removed from the tracked queues before we get here.
if request_id in self.rid_to_queue:
del self.rid_to_queue[request_id]
def encode(
self,
......@@ -382,7 +351,3 @@ class AsyncLLM(EngineClient):
@property
def dead_error(self) -> BaseException:
return Exception() # TODO: implement
# Retain V0 name for backwards compatibility.
AsyncLLMEngine = AsyncLLM
import asyncio
from typing import Any, AsyncGenerator, Callable, Optional, Type, Union
from vllm.outputs import PoolingRequestOutput, RequestOutput
class AsyncStream:
    """Per-request output stream backed by an ``asyncio.Queue``.

    Producers push items with ``put`` and terminate the stream with
    ``finish``; a consumer drains it through the ``generator`` coroutine.
    """

    # Sentinel queued by ``finish`` to signal a normal end of stream.
    STOP_ITERATION = Exception()

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        """Store the request id and the callback invoked on early exit."""
        self._finished = False
        self._queue: asyncio.Queue = asyncio.Queue()
        self._cancel = cancel
        self.request_id = request_id

    def put(
        self, item: "Union[RequestOutput, PoolingRequestOutput, Exception]"
    ) -> None:
        """Queue ``item`` unless the stream has already been finished."""
        if self._finished:
            return
        self._queue.put_nowait(item)

    def finish(
        self,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
    ) -> None:
        """Close the stream; later ``put``/``finish`` calls are ignored.

        A raisable ``exception`` is queued so the consumer re-raises it;
        anything else results in the ``STOP_ITERATION`` sentinel.
        """
        if self._finished:
            return
        self._finished = True
        if self._is_raisable(exception):
            terminal = exception
        else:
            terminal = AsyncStream.STOP_ITERATION
        self._queue.put_nowait(terminal)

    async def generator(
        self
    ) -> "AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]":
        """Yield queued items until a terminal item is dequeued.

        The sentinel ends iteration; any other raisable item is raised.
        If the generator exits before seeing a terminal item (for example
        because the consumer closed it), the cancel callback is invoked
        with the request id.
        """
        saw_terminal = False
        try:
            while True:
                item = await self._queue.get()
                if not self._is_raisable(item):
                    yield item
                    continue
                saw_terminal = True
                if item == AsyncStream.STOP_ITERATION:
                    return
                raise item
        finally:
            self._finished = True
            if not saw_terminal:
                self._cancel(self.request_id)

    @staticmethod
    def _is_raisable(value: Any):
        """Return True for exception instances and exception classes."""
        if isinstance(value, BaseException):
            return True
        return isinstance(value, type) and issubclass(value, BaseException)
......@@ -32,7 +32,7 @@ logger = init_logger(__name__)
POLLING_TIMEOUT_MS = 5000
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
LOGGING_TIME_S = 5000
LOGGING_TIME_S = POLLING_TIMEOUT_S
class EngineCore:
......@@ -65,7 +65,8 @@ class EngineCore:
self._last_logging_time = time.time()
self.mm_input_mapper_server = MMInputMapperServer()
self.mm_input_mapper_server = MMInputMapperServer(
vllm_config.model_config)
def _initialize_kv_caches(self,
cache_config: CacheConfig) -> Tuple[int, int]:
......@@ -98,9 +99,8 @@ class EngineCore:
# MM mapper, so anything that has a hash must have a HIT cache
# entry here as well.
assert request.mm_inputs is not None
request.mm_inputs, request.mm_hashes = (
self.mm_input_mapper_server.process_inputs(
request.mm_inputs, request.mm_hashes))
request.mm_inputs = self.mm_input_mapper_server.process_inputs(
request.mm_inputs, request.mm_hashes)
req = Request.from_engine_core_request(request)
......
......@@ -55,9 +55,12 @@ class LLMEngine:
self.tokenizer.ping()
# Processor (convert Inputs --> EngineCoreRequests)
self.processor = Processor(vllm_config.model_config,
vllm_config.lora_config, self.tokenizer,
input_registry, mm_registry)
self.processor = Processor(model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
tokenizer=self.tokenizer,
input_registry=input_registry,
mm_registry=mm_registry)
# Detokenizer (converts EngineCoreOutputs --> RequestOutput)
self.detokenizer = Detokenizer(
......@@ -107,7 +110,10 @@ class LLMEngine:
executor_class: Type[Executor]
distributed_executor_backend = (
vllm_config.parallel_config.distributed_executor_backend)
if distributed_executor_backend == "mp":
if distributed_executor_backend == "ray":
from vllm.v1.executor.ray_executor import RayExecutor
executor_class = RayExecutor
elif distributed_executor_backend == "mp":
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
executor_class = MultiprocExecutor
else:
......
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional
import PIL
from blake3 import blake3
......@@ -8,7 +8,7 @@ from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
MultiModalKwargs, MultiModalRegistry)
from vllm.v1.utils import LRUDictCache
from vllm.utils import LRUCache
logger = init_logger(__name__)
......@@ -42,7 +42,9 @@ class MMInputMapperClient:
model_config)
self.mm_registry.init_mm_limits_per_prompt(model_config)
self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
# Init cache
self.use_cache = not model_config.disable_mm_preprocessor_cache
self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE)
# DEBUG: Set to None to disable
self.mm_debug_cache_hit_ratio_steps = None
......@@ -61,7 +63,7 @@ class MMInputMapperClient:
mm_hashes: Optional[List[str]],
mm_processor_kwargs: Optional[Dict[str, Any]],
precomputed_mm_inputs: Optional[List[MultiModalKwargs]],
) -> Tuple[List[MultiModalKwargs], Optional[List[str]]]:
) -> List[MultiModalKwargs]:
if precomputed_mm_inputs is None:
image_inputs = mm_data["image"]
if not isinstance(image_inputs, list):
......@@ -70,26 +72,21 @@ class MMInputMapperClient:
else:
num_inputs = len(precomputed_mm_inputs)
# Check if hash is enabled
use_hash = mm_hashes is not None
if use_hash:
# Sanity
if self.use_cache:
assert mm_hashes is not None
assert num_inputs == len(
mm_hashes), "num_inputs = {} len(mm_hashes) = {}".format(
num_inputs, len(mm_hashes))
assert num_inputs == len(mm_hashes)
# Process each image input separately, so that later we can schedule
# them in a fine-grained manner.
# Apply caching (if enabled) and reuse precomputed inputs (if provided)
ret_hashes: Optional[List[str]] = [] if use_hash else None
ret_inputs: List[MultiModalKwargs] = []
for input_id in range(num_inputs):
if self.mm_debug_cache_hit_ratio_steps is not None:
self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps)
mm_hash = None
mm_input = None
if use_hash:
if self.use_cache:
assert mm_hashes is not None
mm_hash = mm_hashes[input_id]
mm_input = self.mm_cache.get(mm_hash)
......@@ -106,7 +103,7 @@ class MMInputMapperClient:
mm_processor_kwargs=mm_processor_kwargs,
)
if use_hash:
if self.use_cache:
# Add to cache
assert mm_hash is not None
self.mm_cache.put(mm_hash, mm_input)
......@@ -114,19 +111,16 @@ class MMInputMapperClient:
self.mm_cache_hits += 1
mm_input = None # Avoids sending mm_input to Server
if use_hash:
assert mm_hash is not None
assert ret_hashes is not None
ret_hashes.append(mm_hash)
ret_inputs.append(mm_input)
return ret_inputs, ret_hashes
return ret_inputs
class MMInputMapperServer:
def __init__(self, ):
self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
def __init__(self, model_config):
self.use_cache = not model_config.disable_mm_preprocessor_cache
self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE)
def process_inputs(
self,
......@@ -135,6 +129,9 @@ class MMInputMapperServer:
) -> List[MultiModalKwargs]:
assert len(mm_inputs) == len(mm_hashes)
if not self.use_cache:
return mm_inputs
full_mm_inputs = []
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
assert mm_hash is not None
......@@ -154,12 +151,45 @@ class MMHasher:
def __init__(self):
pass
def hash(self, prompt: PromptType) -> Optional[List[str]]:
def hash_dummy_mm_data(
        self,
        mm_data: Optional[MultiModalDataDict]) -> Optional[List[str]]:
    """Hash the image entry of user-defined dummy data used for profiling.

    Returns None when ``mm_data`` is None; otherwise returns whatever
    ``hash_images`` produces for ``mm_data['image']``.
    """
    if mm_data is None:
        return None

    images = mm_data['image']
    if isinstance(images, dict):
        # Temporary workaround for models (e.g. Molmo) that process
        # multimodal data in the input processor, so the image entry is
        # already MultiModalKwargs rather than the raw input format. The
        # original image is then expected under `raw_mm_data`.
        assert "raw_mm_data" in images and isinstance(
            images["raw_mm_data"], PIL.Image.Image)
        # NOTE: pop() removes `raw_mm_data` from the caller's dict.
        images = images.pop("raw_mm_data")

    return self.hash_images(images)
def hash_prompt_mm_data(self, prompt: PromptType) -> Optional[List[str]]:
    """Hash the images attached to a user prompt, if there are any.

    Returns None when the prompt has no ``multi_modal_data`` entry or that
    entry is falsy (None or an empty dict); otherwise returns the result
    of ``hash_images`` on its ``"image"`` value.
    """
    mm_data = (prompt["multi_modal_data"]
               if "multi_modal_data" in prompt else None)
    if not mm_data:
        # Covers a missing key as well as None / empty-dict values.
        return None
    return self.hash_images(mm_data["image"])
def hash_images(self, image_inputs) -> Optional[List[str]]:
"""Hash PIL image objects to strings."""
if not isinstance(image_inputs, list):
image_inputs = [image_inputs]
assert len(image_inputs) > 0
......
import time
from typing import Any, Dict, Mapping, Optional, Tuple, Union
from typing import Mapping, Optional, Tuple, Union
from vllm.config import LoRAConfig, ModelConfig
from vllm.config import CacheConfig, LoRAConfig, ModelConfig
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType, SingletonInputsAdapter)
from vllm.inputs.parse import is_encoder_decoder_inputs
......@@ -12,7 +12,6 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest
from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient
......@@ -23,6 +22,7 @@ class Processor:
def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
tokenizer: BaseTokenizerGroup,
input_registry: InputRegistry = INPUT_REGISTRY,
......@@ -33,8 +33,8 @@ class Processor:
self.lora_config = lora_config
self.tokenizer = tokenizer
self.generation_config_fields = _load_generation_config_dict(
model_config)
self.generation_config_fields = model_config.try_get_generation_config(
)
self.input_preprocessor = InputPreprocessor(model_config,
self.tokenizer,
mm_registry)
......@@ -45,8 +45,9 @@ class Processor:
self.mm_input_mapper_client = MMInputMapperClient(model_config)
# Multi-modal hasher (for images)
self.mm_hasher = MMHasher(
) if model_config.mm_cache_preprocessor else None
self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
cache_config.enable_prefix_caching
self.mm_hasher = MMHasher()
# TODO: run in an ThreadpoolExecutor or BackgroundProcess.
# This ideally should releases the GIL, so we should not block the
......@@ -77,8 +78,8 @@ class Processor:
# Compute MM hashes (if enabled)
mm_hashes = None
if self.mm_hasher is not None:
mm_hashes = self.mm_hasher.hash(prompt)
if self.use_hash:
mm_hashes = self.mm_hasher.hash_prompt_mm_data(prompt)
# Process inputs.
preprocessed_inputs = self.input_preprocessor.preprocess(
......@@ -118,7 +119,7 @@ class Processor:
# Apply MM mapper
mm_inputs = None
if len(decoder_inputs.multi_modal_data) > 0:
mm_inputs, mm_hashes = self.mm_input_mapper_client.process_inputs(
mm_inputs = self.mm_input_mapper_client.process_inputs(
decoder_inputs.multi_modal_data, mm_hashes,
decoder_inputs.mm_processor_kwargs, precomputed_mm_inputs)
......@@ -179,16 +180,3 @@ class Processor:
# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens
def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
    """Return the model's generation config as a diff dict.

    The config is looked up with ``try_get_generation_config`` using the
    model name, ``trust_remote_code`` and revision from ``model_config``.
    An empty dict is returned when no config is found.
    """
    generation_config = try_get_generation_config(
        model_config.model,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.revision,
    )
    return ({} if generation_config is None else
            generation_config.to_diff_dict())
import os
from collections import defaultdict
from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.executor.abstract import Executor
from vllm.v1.executor.ray_utils import (RayWorkerWrapper,
initialize_ray_cluster, ray)
from vllm.v1.outputs import ModelRunnerOutput
if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
logger = init_logger(__name__)
class RayExecutor(Executor):
    """Executor that drives the model through Ray actor workers.

    A ``RayWorkerWrapper`` actor is created for every GPU bundle of the
    placement group. Worker methods are fanned out with ``_run_workers``,
    and the forward pass goes through a Ray compiled DAG that is built on
    the first ``execute_model`` call.
    """

    def __init__(self, vllm_config: VllmConfig) -> None:
        """Connect to Ray and create, configure and load all workers."""
        self.vllm_config = vllm_config
        self.parallel_config = vllm_config.parallel_config
        self.model_config = vllm_config.model_config
        # Built lazily by execute_model().
        self.forward_dag: Optional[ray.dag.CompiledDAG] = None

        # Disable Ray usage stats collection unless explicitly enabled.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        initialize_ray_cluster(self.parallel_config)
        placement_group = self.parallel_config.placement_group

        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        """Create one worker per GPU bundle, then initialize all of them.

        Args:
            placement_group: Placement group whose bundles host the workers.
            **ray_remote_kwargs: Extra options forwarded to ``ray.remote``.

        Raises:
            RuntimeError: If the number of distinct node ids differs from
                the number of distinct worker IP addresses.
        """
        # A list of workers to run a model.
        self.workers: List[RayWorkerWrapper] = []
        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs)

        # Create the workers.
        driver_ip = get_ip()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get("GPU", 0):
                # Skip bundles that don't have GPUs,
                # as each worker needs one GPU.
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            worker = ray.remote(
                num_cpus=0,
                num_gpus=1,
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(vllm_config=self.vllm_config)
            self.workers.append(worker)

        logger.debug("workers: %s", self.workers)
        worker_ips = [
            ray.get(worker.get_node_ip.remote())  # type: ignore[attr-defined]
            for worker in self.workers
        ]
        ip_counts: Dict[str, int] = {}
        for ip in worker_ips:
            ip_counts[ip] = ip_counts.get(ip, 0) + 1

        worker_to_ip = dict(zip(self.workers, worker_ips))

        def sort_by_driver_then_worker_ip(worker):
            """
            Sort the workers based on 3 properties:
            1. If the worker is on the same node as the driver (vllm engine),
                it should be placed first.
            2. Then, if the worker is on a node with fewer workers, it should
                be placed first.
            3. Finally, if the worker is on a node with smaller IP address, it
                should be placed first. This is simply a tiebreaker to make
                sure the workers are sorted in a deterministic way.
            """
            ip = worker_to_ip[worker]
            return (ip != driver_ip, ip_counts[ip], ip)

        # After sorting, the workers on the same node will be
        # close to each other, and the workers on the driver
        # node will be placed first.
        self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)

        # Get the set of GPU IDs used on each node.
        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids")

        node_workers = defaultdict(list)  # node id -> list of worker ranks
        node_gpus = defaultdict(list)  # node id -> list of gpu ids

        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            # `gpu_ids` can be a list of strings or integers.
            # convert them to integers for consistency.
            # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
            # string sorting is not sufficient.
            # see https://github.com/vllm-project/vllm/issues/5590
            gpu_ids = [int(x) for x in gpu_ids]
            node_gpus[node_id].extend(gpu_ids)

        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        all_ips = set(worker_ips)
        n_ips = len(all_ips)
        n_nodes = len(node_workers)

        if n_nodes != n_ips:
            raise RuntimeError(
                f"Every node should have a unique IP address. Got {n_nodes}"
                f" nodes with node ids {list(node_workers.keys())} and "
                f"{n_ips} unique IP addresses {all_ips}. Please check your"
                " network configuration. If you set `VLLM_HOST_IP` or "
                "`HOST_IP` environment variable, make sure it is unique for"
                " each node.")

        # Set environment variables for the driver and workers.
        # One single-element args tuple per worker, in rank order.
        all_args_to_update_environment_variables = [({
            "CUDA_VISIBLE_DEVICES":
            ",".join(map(str, node_gpus[node_id])),
            "VLLM_TRACE_FUNCTION":
            str(envs.VLLM_TRACE_FUNCTION),
            "VLLM_USE_V1":
            str(int(envs.VLLM_USE_V1)),
            **({
                "VLLM_ATTENTION_BACKEND": envs.VLLM_ATTENTION_BACKEND
            } if envs.VLLM_ATTENTION_BACKEND is not None else {})
        }, ) for (node_id, _) in worker_node_and_gpu_ids]

        self._env_vars_for_all_workers = (
            all_args_to_update_environment_variables)

        self._run_workers("update_environment_variables",
                          all_args=self._get_env_vars_to_be_updated())

        if len(node_gpus) == 1:
            # in single node case, we don't need to get the IP address.
            # the loopback address is sufficient
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"

        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        # Initialize the actual workers inside worker wrapper.
        init_worker_all_kwargs = [
            self._get_worker_kwargs(
                local_rank=node_workers[node_id].index(rank),
                rank=rank,
                distributed_init_method=distributed_init_method,
            ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
        ]
        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
        self._run_workers("initialize")
        self._run_workers("load_model")

    def _configure_ray_workers_use_nsight(self,
                                          ray_remote_kwargs) -> Dict[str, Any]:
        """Add the nsight profiling settings to the Ray ``runtime_env``.

        ``ray_remote_kwargs`` is updated in place and also returned.
        """
        # If nsight profiling is enabled, we need to set the profiling
        # configuration for the ray workers as runtime env.
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
        runtime_env.update({
            "nsight": {
                "t": "cuda,cudnn,cublas",
                "o": "'worker_process_%p'",
                "cuda-graph-trace": "node",
            }
        })

        return ray_remote_kwargs

    def _get_env_vars_to_be_updated(self):
        """Return the per-worker environment args built during init."""
        return self._env_vars_for_all_workers

    def _get_worker_kwargs(
            self,
            local_rank: int = 0,
            rank: int = 0,
            distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
        """
        Return worker init args for a given rank.

        When ``distributed_init_method`` is None, one is derived from the
        local IP address and a free port.
        """
        if distributed_init_method is None:
            distributed_init_method = get_distributed_init_method(
                get_ip(), get_open_port())
        return dict(
            vllm_config=self.vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
        )

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """
        Determine the number of available KV blocks.

        This invokes `determine_num_available_blocks` on each worker and takes
        the min of the results, guaranteeing that the selected cache sizes are
        compatible with all workers.

        Returns:
            - tuple[num_gpu_blocks, num_cpu_blocks]
        """
        # Get the maximum number of blocks that can be allocated on GPU and CPU.
        num_blocks = self._run_workers("determine_num_available_blocks")

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operators can be applied to all workers.
        num_gpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)

        return num_gpu_blocks, num_cpu_blocks

    def initialize(self, num_gpu_blocks: int) -> None:
        """
        Initialize the KV cache in all workers.
        """
        # NOTE: This is logged in the executor because there can be >1 worker
        # with other executors. We could log in the engine level, but work
        # remains to abstract away the device for non-GPU configurations.
        logger.info("# GPU blocks: %d", num_gpu_blocks)
        self._run_workers("initialize_cache", num_gpu_blocks)
        self._run_workers("compile_or_warm_up_model")

    def _run_workers(
        self,
        method: str,
        *args,
        all_args: Optional[List[Tuple[Any, ...]]] = None,
        all_kwargs: Optional[List[Dict[str, Any]]] = None,
        **kwargs,
    ) -> Any:
        """
        Runs the given method on all workers. Can be used in the following
        ways:

        Args:
        - args/kwargs: All workers share the same args/kwargs
        - all_args/all_kwargs: args/kwargs for each worker are specified
            individually

        Returns:
            The list of per-worker results, in ``self.workers`` order.
        """
        count = len(self.workers)
        all_worker_args = repeat(args, count) if all_args is None \
            else islice(all_args, 0, None)
        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
            else islice(all_kwargs, 0, None)

        ray_worker_refs = [
            worker.execute_method.remote(  # type: ignore[attr-defined]
                method, *worker_args, **worker_kwargs)
            for (worker, worker_args, worker_kwargs
                 ) in zip(self.workers, all_worker_args, all_worker_kwargs)
        ]
        # Blocks until every worker has returned.
        return ray.get(ray_worker_refs)

    def execute_model(
        self,
        scheduler_output,
    ) -> ModelRunnerOutput:
        """Run one step through the compiled DAG, building it on first use."""
        if self.forward_dag is None:
            self.forward_dag = self._compiled_ray_dag()
        # Only the first worker (with rank 0) returns the execution result.
        # Others return None.
        output = ray.get(self.forward_dag.execute(scheduler_output))[0]
        return output

    def profile(self, is_start=True):
        """Profiling is not implemented for this executor."""
        raise NotImplementedError

    def shutdown(self):
        """Tear down the compiled DAG and kill the workers, if a DAG exists."""
        # NOTE(review): nesting reconstructed from statement order -- the
        # worker kill and reset are assumed to sit inside this guard.
        if hasattr(self, "forward_dag") and self.forward_dag is not None:
            self.forward_dag.teardown()
            import ray
            for worker in self.workers:
                ray.kill(worker)
            self.forward_dag = None

    def check_health(self) -> None:
        """Log that a health check was requested; no check is performed."""
        logger.debug("Called check_health.")

    def _check_ray_compiled_graph_installation(self):
        """Validate that Ray Compiled Graph can be used.

        Raises:
            ValueError: If the installed Ray is older than 2.39, if
                ``ray.experimental.compiled_dag_ref`` cannot be found, or
                if cupy is missing while
                ``VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL`` is set.
        """
        # importlib.metadata is stdlib; pkg_resources is deprecated and is
        # only importable when setuptools happens to be installed.
        import importlib.metadata
        import importlib.util

        from packaging import version

        required_version = version.parse("2.39")
        current_version = version.parse(importlib.metadata.version("ray"))
        if current_version < required_version:
            raise ValueError(f"Ray version {required_version} is "
                             f"required, but found {current_version}")

        raycg = importlib.util.find_spec("ray.experimental.compiled_dag_ref")
        if raycg is None:
            raise ValueError("Ray Compiled Graph is not installed. "
                             "Run `pip install ray[adag]` to install it.")

        cupy_spec = importlib.util.find_spec("cupy")
        if cupy_spec is None and envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL:
            raise ValueError(
                "cupy is not installed but required since "
                "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set. "
                "Run `pip install ray[adag]` and check cupy installation.")

    def _compiled_ray_dag(self):
        """Build and compile a DAG that calls ``execute_model`` on each worker.

        Every worker receives the same DAG input; the outputs are gathered
        in a ``MultiOutputNode``.
        """
        assert self.parallel_config.use_ray
        self._check_ray_compiled_graph_installation()
        from ray.dag import InputNode, MultiOutputNode

        with InputNode() as input_batches:
            outputs = [
                worker.execute_model.bind(  # type: ignore[attr-defined]
                    input_batches) for worker in self.workers
            ]
            forward_dag = MultiOutputNode(outputs)

        return forward_dag.experimental_compile()

    def __del__(self):
        self.shutdown()
import time
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import get_ip
from vllm.v1.outputs import ModelRunnerOutput
from vllm.worker.worker_base import WorkerWrapperBase
if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput
logger = init_logger(__name__)
PG_WAIT_TIMEOUT = 60
try:
    import ray
    from ray.util import placement_group_table
    from ray.util.placement_group import PlacementGroup
    try:
        from ray._private.state import available_resources_per_node
    except ImportError:
        # Ray 2.9.x doesn't expose `available_resources_per_node`
        from ray._private.state import state as _state
        available_resources_per_node = _state._available_resources_per_node

    class RayWorkerWrapper(WorkerWrapperBase):
        """Worker wrapper exposing the methods called on Ray workers."""

        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)
            # The compiled DAG runs its main execution on a different
            # thread, which has to call cuda.set_device itself. This flag
            # records whether set_device has already been called on that
            # thread. It will be removed soon.
            self.compiled_dag_cuda_device_set = False

        def get_node_ip(self) -> str:
            """Return the IP address reported by ``get_ip`` for this worker."""
            return get_ip()

        def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
            """Return the Ray node id and the GPU ids of this worker."""
            return ray.get_runtime_context().get_node_id(), ray.get_gpu_ids()

        def setup_device_if_necessary(self):
            """Set torch's CUDA device once; later calls do nothing more."""
            # TODO(swang): This is needed right now because Ray CG executes
            # on a background thread, so we need to reset torch's current
            # device.
            # We can remove this API after it is fixed in compiled graph.
            import torch
            assert self.worker is not None, "Worker is not initialized"
            if self.compiled_dag_cuda_device_set:
                return
            torch.cuda.set_device(self.worker.device)
            self.compiled_dag_cuda_device_set = True

        def execute_model(
            self,
            scheduler_output: "SchedulerOutput",
        ) -> ModelRunnerOutput:
            """Run the worker's model runner on ``scheduler_output``."""
            self.setup_device_if_necessary()
            assert self.worker is not None, "Worker is not initialized"
            return self.worker.model_runner.execute_model(scheduler_output)

    ray_import_err = None

except ImportError as e:
    # Ray is optional: remember why the import failed so that
    # assert_ray_available() can chain the original error.
    ray = None  # type: ignore
    ray_import_err = e
    RayWorkerWrapper = None  # type: ignore
def ray_is_available() -> bool:
    """Report whether the module-level ``ray`` import succeeded."""
    return not (ray is None)
def assert_ray_available():
    """
    Raise a ValueError, chained to the original import error, when Ray
    could not be imported; return silently otherwise.
    """
    if ray is not None:
        return
    raise ValueError("Failed to import Ray, please install Ray with "
                     "`pip install ray`.") from ray_import_err
def _verify_bundles(placement_group: "PlacementGroup",
                    parallel_config: ParallelConfig, device_str: str):
    """Check that a placement group's bundles are placed sensibly.

    Two rules are applied:
    - Fail if the driver node hosts none of the bundles.
    - Warn for every node holding fewer bundles than
      ``tensor_parallel_size``.

    Args:
        placement_group: The placement group to verify.
        parallel_config: The parallel configuration.
        device_str: The required device; only used in the warning text.

    Raises:
        RuntimeError: If the driver node is not part of the placement group.
    """
    assert ray.is_initialized(), (
        "Ray is not initialized although distributed-executor-backend is ray.")
    pg_data = placement_group_table(placement_group)
    # bundle_idx -> bundle (e.g., {"GPU": 1})
    bundles = pg_data["bundles"]
    # bundle_idx -> node_id
    bundle_to_node_ids = pg_data["bundles_to_node_id"]

    # node_id -> List of bundle (e.g., {"GPU": 1})
    node_id_to_bundle: Dict[str, List[Dict[str, float]]] = defaultdict(list)
    for bundle_idx, node_id in bundle_to_node_ids.items():
        node_id_to_bundle[node_id].append(bundles[bundle_idx])

    driver_node_id = ray.get_runtime_context().get_node_id()
    if driver_node_id not in node_id_to_bundle:
        raise RuntimeError(
            f"driver node id {driver_node_id} is not included in a placement "
            f"group {placement_group.id}. Node id -> bundles "
            f"{node_id_to_bundle}. "
            "You don't have enough GPUs available in a current node. Check "
            "`ray status` to see if you have available GPUs in a node "
            f"{driver_node_id} before starting an vLLM engine.")

    tp_size = parallel_config.tensor_parallel_size
    for node_id, node_bundles in node_id_to_bundle.items():
        if len(node_bundles) >= tp_size:
            continue
        logger.warning(
            "tensor_parallel_size=%d "
            "is bigger than a reserved number of %ss (%d "
            "%ss) in a node %s. Tensor parallel workers can be "
            "spread out to 2+ nodes which can degrade the performance "
            "unless you have fast interconnect across nodes, like "
            "Infiniband. To resolve this issue, make sure you have more "
            "than %d GPUs available at each node.", tp_size, device_str,
            len(node_bundles), device_str, node_id, tp_size)
def _wait_until_pg_ready(current_placement_group: "PlacementGroup"):
    """Wait until a placement group is ready.

    Blocks until the placement group's resources are available or until
    ``PG_WAIT_TIMEOUT`` seconds have elapsed. While waiting, an informative
    log message is printed with an exponentially growing interval.

    Args:
        current_placement_group: The placement group to wait for.

    Raises:
        ValueError: If the placement group is still not ready once the
            timeout has elapsed.
    """
    placement_group_specs = current_placement_group.bundle_specs
    s = time.time()
    deadline = s + PG_WAIT_TIMEOUT
    pg_ready_ref = current_placement_group.ready()
    wait_interval = 10
    while True:
        remaining = deadline - time.time()
        if remaining <= 0:
            break
        # Clamp the wait to the time left: the interval doubles on every
        # iteration, so an unclamped wait could block well beyond
        # PG_WAIT_TIMEOUT before the timeout is noticed.
        ready, _ = ray.wait([pg_ready_ref],
                            timeout=min(wait_interval, remaining))
        if ready:
            break

        # Exponential backoff for warning print.
        wait_interval *= 2
        logger.info(
            "Waiting for creating a placement group of specs for "
            "%d seconds. specs=%s. Check "
            "`ray status` to see if you have enough resources.",
            int(time.time() - s), placement_group_specs)

    # Non-blocking probe: succeeds if the group became ready, otherwise
    # converts the Ray timeout into a user-facing error.
    try:
        ray.get(pg_ready_ref, timeout=0)
    except ray.exceptions.GetTimeoutError:
        raise ValueError(
            "Cannot provide a placement group of "
            f"{placement_group_specs=} within {PG_WAIT_TIMEOUT} seconds. See "
            "`ray status` to make sure the cluster has enough resources."
        ) from None
def initialize_ray_cluster(
    parallel_config: ParallelConfig,
    ray_address: Optional[str] = None,
):
    """Initialize the distributed cluster with Ray.

    Connects to the Ray cluster and makes sure ``parallel_config`` carries a
    placement group for the workers, which includes the specification of the
    resources for each distributed worker. If the config already has a
    placement group nothing else is done. Otherwise the placement group the
    caller is running in is validated and reused, or a new one with one
    device per worker is created and awaited.

    Args:
        parallel_config: The configurations for parallel execution. Its
            ``placement_group`` attribute is set on return.
        ray_address: The address of the Ray cluster. If None, uses
            the default Ray cluster address. On ROCm/XPU an existing
            instance is first looked up via ``"auto"``.

    Raises:
        ValueError: If the available devices cannot satisfy
            ``parallel_config.world_size``, if a bundle of the current
            placement group requests more than one device, or if the
            current node has no device available.
    """
    assert_ray_available()

    # Connect to a ray cluster.
    if current_platform.is_rocm() or current_platform.is_xpu():
        # Try to connect existing ray instance and create a new one if not
        # found.
        try:
            # ignore_reinit_error keeps this consistent with the other
            # init calls below when Ray is already initialized here.
            ray.init("auto", ignore_reinit_error=True)
        except ConnectionError:
            logger.warning(
                "No existing RAY instance detected. "
                "A new instance will be launched with current node resources.")
            ray.init(address=ray_address,
                     ignore_reinit_error=True,
                     num_gpus=parallel_config.world_size)
    else:
        ray.init(address=ray_address, ignore_reinit_error=True)

    if parallel_config.placement_group:
        # Placement group is already set.
        return

    device_str = "GPU" if not current_platform.is_tpu() else "TPU"
    # Create placement group for worker processes
    current_placement_group = ray.util.get_current_placement_group()
    if current_placement_group:
        # We are in a placement group
        bundles = current_placement_group.bundle_specs
        # Verify that we can use the placement group.
        device_bundles = 0
        for bundle in bundles:
            bundle_devices = bundle.get(device_str, 0)
            if bundle_devices > 1:
                raise ValueError(
                    "Placement group bundle cannot have more than 1 "
                    f"{device_str}.")
            if bundle_devices:
                device_bundles += 1
        if parallel_config.world_size > device_bundles:
            raise ValueError(
                f"The number of required {device_str}s exceeds the total "
                f"number of available {device_str}s in the placement group. "
                f"Required number of devices: {parallel_config.world_size}. "
                f"Total number of devices: {device_bundles}.")
    else:
        num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
        if parallel_config.world_size > num_devices_in_cluster:
            # No placement group exists yet; the shortage is cluster-wide.
            raise ValueError(
                f"The number of required {device_str}s exceeds the total "
                f"number of available {device_str}s in the cluster.")
        # Create a new placement group
        placement_group_specs: List[Dict[str, float]] = ([{
            device_str: 1.0
        } for _ in range(parallel_config.world_size)])

        # vLLM engine is also a worker to execute model with an accelerator,
        # so it requires to have the device in a current node. Check if
        # the current node has at least one device.
        current_ip = get_ip()
        current_node_id = ray.get_runtime_context().get_node_id()
        current_node_resource = available_resources_per_node()[current_node_id]
        if current_node_resource.get(device_str, 0) < 1:
            raise ValueError(
                f"Current node has no {device_str} available. "
                f"{current_node_resource=}. vLLM engine cannot start without "
                f"{device_str}. Make sure you have at least 1 {device_str} "
                f"available in a node {current_node_id=} {current_ip=}.")
        # This way, at least bundle is required to be created in a current
        # node.
        placement_group_specs[0][f"node:{current_ip}"] = 0.001

        # By default, Ray packs resources as much as possible.
        current_placement_group = ray.util.placement_group(
            placement_group_specs, strategy="PACK")
        _wait_until_pg_ready(current_placement_group)

    assert current_placement_group is not None
    _verify_bundles(current_placement_group, parallel_config, device_str)
    # Set the placement group in the parallel config
    parallel_config.placement_group = current_placement_group
import enum
from typing import List, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Union
from vllm.inputs import DecoderOnlyInputs, SingletonInputsAdapter, token_inputs
from vllm.lora.request import LoRARequest
......@@ -9,6 +9,9 @@ from vllm.sequence import RequestMetrics
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.utils import ConstantList
if TYPE_CHECKING:
from vllm.v1.core.kv_cache_utils import BlockHashType
class Request:
......@@ -45,6 +48,7 @@ class Request:
self._all_token_ids: List[int] = self.prompt_token_ids.copy()
self.num_computed_tokens = 0
# Multi-modal input metadata.
mm_positions = self.inputs.multi_modal_placeholders
if mm_positions:
# FIXME(woosuk): Support other modalities.
......@@ -56,6 +60,12 @@ class Request:
if self.inputs.multi_modal_inputs:
self.mm_inputs = self.inputs.multi_modal_inputs
self.mm_hashes: List[str] = self.inputs.multi_modal_hashes
# Cache the computed kv block hashes of the request to avoid
# recomputing.
self._kv_block_hashes: List[BlockHashType] = []
@classmethod
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
return cls(
......@@ -65,6 +75,7 @@ class Request:
prompt=request.prompt,
multi_modal_data=None,
multi_modal_inputs=request.mm_inputs,
multi_modal_hashes=request.mm_hashes,
multi_modal_placeholders=request.mm_placeholders,
mm_processor_kwargs=None,
),
......@@ -121,6 +132,17 @@ class Request:
num_tokens = self.mm_positions[input_id]["length"]
return num_tokens
@property
def kv_block_hashes(self) -> ConstantList["BlockHashType"]:
    """Return the cached KV block hashes wrapped in a ``ConstantList``.

    The wrapper is handed out instead of the underlying list so that
    callers do not append to the cached hashes directly.
    """
    cached_hashes = self._kv_block_hashes
    return ConstantList(cached_hashes)
def set_kv_block_hashes(self, value: List["BlockHashType"]) -> None:
    """Replace the cached KV block hashes of this request with ``value``.

    The given list object is stored as-is; no copy is made.
    """
    self._kv_block_hashes = value
def append_kv_block_hashes(self, block_hash: "BlockHashType") -> None:
    """Append a single block hash to the cached KV block hashes."""
    self._kv_block_hashes.append(block_hash)
class RequestStatus(enum.IntEnum):
"""Status of a request."""
......
from dataclasses import dataclass
from typing import Dict
from typing import Dict, List, Optional, Set
import torch
......@@ -19,3 +19,13 @@ class SamplingMetadata:
generators: Dict[int, torch.Generator]
max_num_logprobs: int
no_penalties: bool
prompt_token_ids: Optional[torch.Tensor]
frequency_penalties: torch.Tensor
presence_penalties: torch.Tensor
repetition_penalties: torch.Tensor
output_token_ids: List[List[int]]
min_tokens: List[int]
stop_token_ids: List[Set[int]]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment