Commit fcfc474d authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.3' into v0.8.3-dev

parents bb94d2e5 296c6572
...@@ -5,10 +5,12 @@ from collections.abc import Iterable ...@@ -5,10 +5,12 @@ from collections.abc import Iterable
from typing import Optional from typing import Optional
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import cdiv from vllm.utils import cdiv, sha256
from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
hash_request_tokens) hash_request_tokens)
from vllm.v1.core.specialized_manager import get_specialized_manager
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
...@@ -19,20 +21,24 @@ class KVCacheManager: ...@@ -19,20 +21,24 @@ class KVCacheManager:
def __init__( def __init__(
self, self,
block_size: int, kv_cache_config: KVCacheConfig,
num_gpu_blocks: int,
max_model_len: int, max_model_len: int,
sliding_window: Optional[int] = None,
enable_caching: bool = True, enable_caching: bool = True,
caching_hash_algo: str = "builtin",
num_preallocate_tokens: int = 64, num_preallocate_tokens: int = 64,
log_stats: bool = False, log_stats: bool = False,
) -> None: ) -> None:
self.block_size = block_size assert len(kv_cache_config.kv_cache_groups) == 1, (
self.num_gpu_blocks = num_gpu_blocks "KVCacheManager does not support hybrid models with more than 1 "
"kv cache group")
kv_cache_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec
self.block_size = kv_cache_spec.block_size
self.num_gpu_blocks = kv_cache_config.num_blocks
self.max_model_len = max_model_len self.max_model_len = max_model_len
self.max_num_blocks_per_req = cdiv(max_model_len, block_size) self.max_num_blocks_per_req = cdiv(max_model_len, self.block_size)
self.sliding_window = sliding_window
self.enable_caching = enable_caching self.enable_caching = enable_caching
self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
# FIXME: make prefix cache stats conditional on log_stats # FIXME: make prefix cache stats conditional on log_stats
self.log_stats = log_stats self.log_stats = log_stats
# NOTE(woosuk): To avoid frequent block allocation, we preallocate some # NOTE(woosuk): To avoid frequent block allocation, we preallocate some
...@@ -46,9 +52,15 @@ class KVCacheManager: ...@@ -46,9 +52,15 @@ class KVCacheManager:
# further allocation. When it uses up all the N empty blocks, it gets # further allocation. When it uses up all the N empty blocks, it gets
# N new empty blocks. # N new empty blocks.
self.num_preallocate_tokens = num_preallocate_tokens self.num_preallocate_tokens = num_preallocate_tokens
self.num_preallocate_blocks = cdiv(num_preallocate_tokens, block_size) self.num_preallocate_blocks = cdiv(num_preallocate_tokens,
self.block_size)
self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching)
self.block_pool = BlockPool(num_gpu_blocks, enable_caching) self.specialized_manager = get_specialized_manager(
kv_cache_spec=kv_cache_spec,
block_pool=self.block_pool,
)
# Mapping from request ID to blocks to track the blocks allocated # Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request # for each request, so that we can free the blocks when the request
...@@ -109,22 +121,31 @@ class KVCacheManager: ...@@ -109,22 +121,31 @@ class KVCacheManager:
# if the scheduler has tried to schedule the request before. # if the scheduler has tried to schedule the request before.
block_hashes = self.req_to_block_hashes[request.request_id] block_hashes = self.req_to_block_hashes[request.request_id]
if not block_hashes: if not block_hashes:
block_hashes = hash_request_tokens(self.block_size, request) block_hashes = hash_request_tokens(self.caching_hash_fn,
self.block_size, request)
self.req_to_block_hashes[request.request_id] = block_hashes self.req_to_block_hashes[request.request_id] = block_hashes
self.prefix_cache_stats.requests += 1 self.prefix_cache_stats.requests += 1
if request.sampling_params.prompt_logprobs is None: if request.sampling_params.prompt_logprobs is None:
# Check for cache hits if len(block_hashes) * self.block_size == request.num_tokens:
computed_blocks = [] # When prompt length is divisible by the block size and all
for block_hash in block_hashes: # blocks are cached, we need to recompute the last token. This
# block_hashes is a chain of block hashes. If a block hash # have to be achieved by re-computing an entire block because
# is not in the cached_block_hash_to_id, the following # allocate_slots() assumes num_computed_tokens is always a
# block hashes are not computed yet for sure. # multiple of the block size. To achieve this, remove the last
if cached_block := self.block_pool.get_cached_block( # block hash from the block_hashes for find_longest_cache_hit
block_hash): # This limitation can potentially be removed in the future to
computed_blocks.append(cached_block) # slightly improve the performance.
else: last_block_hash = block_hashes.pop()
break else:
last_block_hash = None
computed_blocks = (
self.specialized_manager.find_longest_cache_hit(block_hashes))
if last_block_hash is not None:
# Add back the last block hash if it was removed.
block_hashes.append(last_block_hash)
self.prefix_cache_stats.queries += len(block_hashes) self.prefix_cache_stats.queries += len(block_hashes)
self.prefix_cache_stats.hits += len(computed_blocks) self.prefix_cache_stats.hits += len(computed_blocks)
...@@ -173,13 +194,24 @@ class KVCacheManager: ...@@ -173,13 +194,24 @@ class KVCacheManager:
new_computed_blocks = new_computed_blocks or [] new_computed_blocks = new_computed_blocks or []
req_blocks = self.req_to_blocks[request.request_id]
# Free the blocks that are skipped during the attention computation
# (e.g., tokens outside the sliding window).
# We can do this even if we cannot schedule this request due to
# insufficient free blocks.
# Should call this function before allocating new blocks to reduce
# the number of evicted blocks.
removed_blocks = self.specialized_manager.remove_skipped_blocks(
req_blocks, request.num_computed_tokens)
self.block_pool.free_blocks(removed_blocks)
# The number of computed tokens is the number of computed tokens plus # The number of computed tokens is the number of computed tokens plus
# the new prefix caching hits # the new prefix caching hits
num_computed_tokens = (request.num_computed_tokens + num_computed_tokens = (request.num_computed_tokens +
len(new_computed_blocks) * self.block_size) len(new_computed_blocks) * self.block_size)
num_required_blocks = cdiv(num_computed_tokens + num_tokens, num_required_blocks = cdiv(num_computed_tokens + num_tokens,
self.block_size) self.block_size)
req_blocks = self.req_to_blocks[request.request_id]
num_new_blocks = (num_required_blocks - len(req_blocks) - num_new_blocks = (num_required_blocks - len(req_blocks) -
len(new_computed_blocks)) len(new_computed_blocks))
...@@ -247,6 +279,7 @@ class KVCacheManager: ...@@ -247,6 +279,7 @@ class KVCacheManager:
num_cached_blocks=num_cached_blocks, num_cached_blocks=num_cached_blocks,
num_full_blocks=num_full_blocks_after_append, num_full_blocks=num_full_blocks_after_append,
block_size=self.block_size, block_size=self.block_size,
hash_fn=self.caching_hash_fn,
) )
self.num_cached_block[ self.num_cached_block[
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""KV-Cache Utilities.""" """KV-Cache Utilities."""
import os
from collections import deque from collections import deque
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, NamedTuple, Optional from typing import Any, Callable, NamedTuple, Optional
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheGroupSpec, from vllm.utils import sha256
KVCacheSpec, KVCacheTensor) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec,
KVCacheTensor, SlidingWindowSpec)
from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request from vllm.v1.request import Request
...@@ -18,9 +21,8 @@ logger = init_logger(__name__) ...@@ -18,9 +21,8 @@ logger = init_logger(__name__)
class BlockHashType(NamedTuple): class BlockHashType(NamedTuple):
"""Hash value of a block (int), the token IDs in the block, and extra keys. """Hash value of a block (int), the token IDs in the block, and extra keys.
We keep a tuple of token IDs and extra keys to reduce the likelihood of We keep a tuple of token IDs and extra keys to reduce the likelihood of
hash collisions when the hash value is the same. But please note that hash collisions when the hash value is the same. By using SHA256 however,
hash collisions can still theoretically occur, albeit with an extremely hash collisions are practically impossible.
low probability.
""" """
# Hash value of the block in an integer. # Hash value of the block in an integer.
hash_value: int hash_value: int
...@@ -30,6 +32,20 @@ class BlockHashType(NamedTuple): ...@@ -30,6 +32,20 @@ class BlockHashType(NamedTuple):
extra_keys: Optional[Any] = None extra_keys: Optional[Any] = None
# The hash seed for the first block of the prefix block sequence.
#
# Even if the hash function is the builtin hash(), we use sha256 to generate
# the initial hash to simplify the code. This is not performance critical
# as it is done one per process.
#
# We use a random value to avoid hash collisions or PYTHONHASHSEED environment
# variable if set such that processes can share the seed if needed.
# This aligns with the behavior of Python's hash() function, which also uses
# a random seed if PYTHONHASHSEED is not set.
NONE_HASH = int.from_bytes(os.urandom(32), byteorder="big") if os.getenv(
'PYTHONHASHSEED') is not None else sha256(os.getenv('PYTHONHASHSEED'))
class PrefixCachingMetrics: class PrefixCachingMetrics:
"""Metrics for prefix caching with a hit rate of the most recent N requests. """Metrics for prefix caching with a hit rate of the most recent N requests.
...@@ -148,7 +164,7 @@ class FreeKVCacheBlockQueue: ...@@ -148,7 +164,7 @@ class FreeKVCacheBlockQueue:
builtin deque to support removing a block in the middle of the queue builtin deque to support removing a block in the middle of the queue
in O(1) time. To close the performance gap to the builtin deque which is in O(1) time. To close the performance gap to the builtin deque which is
implemented in C++, this class does not allocate any Python objects when implemented in C++, this class does not allocate any Python objects when
manipulating the linked list. Instead, this class manipulates the manipulating the linked list. Instead, this class manipulates the
prev_free_block and next_free_block attributes of the given blocks. prev_free_block and next_free_block attributes of the given blocks.
The queue is ordered by block ID in the beginning. When a block is allocated The queue is ordered by block ID in the beginning. When a block is allocated
...@@ -178,7 +194,7 @@ class FreeKVCacheBlockQueue: ...@@ -178,7 +194,7 @@ class FreeKVCacheBlockQueue:
def popleft(self) -> KVCacheBlock: def popleft(self) -> KVCacheBlock:
"""Pop the first free block and reduce num_free_blocks by 1. """Pop the first free block and reduce num_free_blocks by 1.
Returns: Returns:
The first free block. The first free block.
""" """
...@@ -191,7 +207,7 @@ class FreeKVCacheBlockQueue: ...@@ -191,7 +207,7 @@ class FreeKVCacheBlockQueue:
def remove(self, block: KVCacheBlock) -> None: def remove(self, block: KVCacheBlock) -> None:
"""Remove a block in the free list and reduce num_free_blocks by 1. """Remove a block in the free list and reduce num_free_blocks by 1.
Args: Args:
block: The block to remove. block: The block to remove.
""" """
...@@ -235,7 +251,7 @@ class FreeKVCacheBlockQueue: ...@@ -235,7 +251,7 @@ class FreeKVCacheBlockQueue:
def get_all_free_blocks(self) -> list[KVCacheBlock]: def get_all_free_blocks(self) -> list[KVCacheBlock]:
"""Get all free blocks in the free list. Mainly used for testing. """Get all free blocks in the free list. Mainly used for testing.
Returns: Returns:
A list of free blocks. A list of free blocks.
""" """
...@@ -251,10 +267,10 @@ def need_extra_keys(request: Request) -> bool: ...@@ -251,10 +267,10 @@ def need_extra_keys(request: Request) -> bool:
"""Check whether the blocks allocated to this request need extra hash keys. """Check whether the blocks allocated to this request need extra hash keys.
Args: Args:
request (Request): The request. request (Request): The request.
Returns: Returns:
bool: Whether blocks allocated to this request need extra hash keys. bool: Whether blocks allocated to this request need extra hash keys.
""" """
# Multimodal requests need to include the MM hash. # Multimodal requests need to include the MM hash.
...@@ -269,13 +285,13 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, ...@@ -269,13 +285,13 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
computation. For multi-modal inputs, the extra keys are computation. For multi-modal inputs, the extra keys are
(mm_hash, start_offset) that indicate a mm input contained in the (mm_hash, start_offset) that indicate a mm input contained in the
block and its starting offset in the block tokens. block and its starting offset in the block tokens.
Args: Args:
request: The request object. request: The request object.
start_token_idx: The start token index of the block. start_token_idx: The start token index of the block.
end_token_idx: The end token index of the block. end_token_idx: The end token index of the block.
start_mm_idx: The start multi-modal index of the block. start_mm_idx: The start multi-modal index of the block.
Returns: Returns:
A tuple of extra keys and the next multi-modal index. A tuple of extra keys and the next multi-modal index.
""" """
...@@ -333,10 +349,10 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, ...@@ -333,10 +349,10 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
def _gen_lora_extra_hash_keys(request: Request) -> list[int]: def _gen_lora_extra_hash_keys(request: Request) -> list[int]:
"""Generate extra keys related to LoRA for block hash computation. """Generate extra keys related to LoRA for block hash computation.
Args: Args:
request: The request object. request: The request object.
Returns: Returns:
Return LoRA id of the request if it is a LoRA request. Return empty Return LoRA id of the request if it is a LoRA request. Return empty
list otherwise. list otherwise.
...@@ -351,13 +367,13 @@ def generate_block_hash_extra_keys( ...@@ -351,13 +367,13 @@ def generate_block_hash_extra_keys(
start_mm_idx: int) -> tuple[Optional[tuple[Any, ...]], int]: start_mm_idx: int) -> tuple[Optional[tuple[Any, ...]], int]:
"""Generate extra keys for the block hash. The extra keys can come from """Generate extra keys for the block hash. The extra keys can come from
the multi-modal inputs and request specific metadata (e.g., LoRA ID). the multi-modal inputs and request specific metadata (e.g., LoRA ID).
Args: Args:
request: The request object. request: The request object.
start_token_idx: The start token index of the block. start_token_idx: The start token index of the block.
end_token_idx: The end token index of the block. end_token_idx: The end token index of the block.
start_mm_idx: The start multi-modal index of the block. start_mm_idx: The start multi-modal index of the block.
Returns: Returns:
A tuple of extra keys and the next multi-modal index. A tuple of extra keys and the next multi-modal index.
""" """
...@@ -375,6 +391,7 @@ def generate_block_hash_extra_keys( ...@@ -375,6 +391,7 @@ def generate_block_hash_extra_keys(
def hash_block_tokens( def hash_block_tokens(
hash_function: Callable,
parent_block_hash: Optional[int], parent_block_hash: Optional[int],
curr_block_token_ids: Sequence[int], curr_block_token_ids: Sequence[int],
extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHashType: extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHashType:
...@@ -395,21 +412,16 @@ def hash_block_tokens( ...@@ -395,21 +412,16 @@ def hash_block_tokens(
The entire tuple is used as the hash key of the block. The entire tuple is used as the hash key of the block.
""" """
if not parent_block_hash: if not parent_block_hash:
# Note that we use 'None' as a string here instead of None because parent_block_hash = NONE_HASH
# as of Python 3.12, hash(None) returns a constant predictable value.
# This could possibly make it easier to find and exploit hash
# collisions. 'None' as a string will be hashed differently per process,
# but consistently within the same process. This is the same as the
# behavior of None prior to Python 3.12.
parent_block_hash = hash('None')
curr_block_token_ids_tuple = tuple(curr_block_token_ids) curr_block_token_ids_tuple = tuple(curr_block_token_ids)
return BlockHashType( return BlockHashType(
hash((parent_block_hash, curr_block_token_ids_tuple, extra_keys)), hash_function(
(parent_block_hash, curr_block_token_ids_tuple, extra_keys)),
curr_block_token_ids_tuple, extra_keys) curr_block_token_ids_tuple, extra_keys)
def hash_request_tokens(block_size: int, def hash_request_tokens(hash_function: Any, block_size: int,
request: Request) -> list[BlockHashType]: request: Request) -> list[BlockHashType]:
"""Computes hash values of a chain of blocks given a sequence of """Computes hash values of a chain of blocks given a sequence of
token IDs. The hash value is used for prefix caching. token IDs. The hash value is used for prefix caching.
...@@ -441,7 +453,7 @@ def hash_request_tokens(block_size: int, ...@@ -441,7 +453,7 @@ def hash_request_tokens(block_size: int,
req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys( req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
request, start, end, curr_mm_idx) request, start, end, curr_mm_idx)
block_hash = hash_block_tokens(parent_block_hash_value, block_hash = hash_block_tokens(hash_function, parent_block_hash_value,
block_token_ids, req_extra_keys) block_token_ids, req_extra_keys)
ret.append(block_hash) ret.append(block_hash)
parent_block_hash_value = block_hash.hash_value parent_block_hash_value = block_hash.hash_value
...@@ -452,7 +464,7 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig, ...@@ -452,7 +464,7 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec], kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int): available_memory: int):
""" """
Checks whether `available_memory` is enough for the KV cache to hold at Checks whether `available_memory` is enough for the KV cache to hold at
least one request with the model's max_model_len. least one request with the model's max_model_len.
Args: Args:
...@@ -472,14 +484,14 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig, ...@@ -472,14 +484,14 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
max_model_len = vllm_config.model_config.max_model_len max_model_len = vllm_config.model_config.max_model_len
needed_memory = 0 needed_memory = 0
for layer_spec in kv_cache_spec.values(): for layer_spec in kv_cache_spec.values():
needed_memory += layer_spec.bytes_for_tokens(max_model_len) needed_memory += layer_spec.max_memory_usage_bytes(vllm_config)
if needed_memory > available_memory: if needed_memory > available_memory:
raise ValueError( raise ValueError(
f"To serve at least one request with the models's max seq len " f"To serve at least one request with the models's max seq len "
f"({max_model_len}), ({needed_memory/1024/1024/1024:.2f} GB KV " f"({max_model_len}), ({needed_memory/1024/1024/1024:.2f} GiB KV "
f"cache is needed, which is larger than the available KV cache " f"cache is needed, which is larger than the available KV cache "
f"memory ({available_memory/1024/1024/1024:.2f} GB). Try " f"memory ({available_memory/1024/1024/1024:.2f} GiB). Try "
f"increasing `gpu_memory_utilization` or decreasing " f"increasing `gpu_memory_utilization` or decreasing "
f"`max_model_len` when initializing the engine.") f"`max_model_len` when initializing the engine.")
...@@ -489,15 +501,15 @@ def create_kv_cache_group_specs( ...@@ -489,15 +501,15 @@ def create_kv_cache_group_specs(
grouped_layer_names: list[list[str]]) -> list[KVCacheGroupSpec]: grouped_layer_names: list[list[str]]) -> list[KVCacheGroupSpec]:
""" """
Create KVCacheGroupSpec object for each kv cache group layer. Create KVCacheGroupSpec object for each kv cache group layer.
The layers in the same group should share the same The layers in the same group should share the same
KVCacheSpec. KVCacheSpec.
Args: Args:
kv_cache_spec: kv_cache_spec:
A mapping from each layer name to its corresponding KVCacheSpec. A mapping from each layer name to its corresponding KVCacheSpec.
grouped_layer_names: grouped_layer_names:
A list of kv cache groups, where each element is a list of layer A list of kv cache groups, where each element is a list of layer
names that belong to the same group and should share the same names that belong to the same group and should share the same
KVCacheSpec. KVCacheSpec.
Returns: Returns:
A list of KVCacheGroupSpec objects, one for each group. A list of KVCacheGroupSpec objects, one for each group.
...@@ -586,6 +598,33 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, ...@@ -586,6 +598,33 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
return kv_cache_config return kv_cache_config
def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
"""
Only models with one type of KV cache are supported yet. This function tries
to convert the KV cache specs to one type if the model is a hybrid model
with multiple type of KV cache. It will convert all SlidingWindowSpec to
FullAttentionSpec if both types are present.
Args:
kv_cache_spec: The kv cache spec of each attention layer in the model
"""
has_full_attention = any(
isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values())
has_sliding_window = any(
isinstance(spec, SlidingWindowSpec) for spec in kv_cache_spec.values())
if has_full_attention and has_sliding_window:
for layer_name, spec in kv_cache_spec.items():
if isinstance(spec, SlidingWindowSpec):
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=spec.block_size,
num_kv_heads=spec.num_kv_heads,
head_size=spec.head_size,
dtype=spec.dtype,
use_mla=spec.use_mla,
)
def get_kv_cache_config(vllm_config: VllmConfig, def get_kv_cache_config(vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec], kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int) -> KVCacheConfig: available_memory: int) -> KVCacheConfig:
...@@ -602,6 +641,7 @@ def get_kv_cache_config(vllm_config: VllmConfig, ...@@ -602,6 +641,7 @@ def get_kv_cache_config(vllm_config: VllmConfig,
The generated KVCacheConfigs The generated KVCacheConfigs
""" """
check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory) check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
unify_hybrid_kv_cache_specs(kv_cache_spec)
if is_kv_cache_type_uniform(kv_cache_spec): if is_kv_cache_type_uniform(kv_cache_spec):
# KV cache of all layers are the same, which is true for # KV cache of all layers are the same, which is true for
# most models. Allocate the same amount of memory for # most models. Allocate the same amount of memory for
...@@ -614,11 +654,11 @@ def get_kv_cache_config(vllm_config: VllmConfig, ...@@ -614,11 +654,11 @@ def get_kv_cache_config(vllm_config: VllmConfig,
def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]): def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]):
""" """
Make the KV cache configurations for each worker consistent, so that all Make the KV cache configurations for each worker consistent, so that all
workers can be controlled by the same KVCacheManager. workers can be controlled by the same KVCacheManager.
This function verifies that the layer group of each worker are the same, This function verifies that the layer group of each worker are the same,
and changes the num_blocks of each worker to the smallest among all workers. and changes the num_blocks of each worker to the smallest among all workers.
Args: Args:
kv_cache_configs: The KV cache configurations for each worker. Will be kv_cache_configs: The KV cache configurations for each worker. Will be
in-place modified to make them consistent. in-place modified to make them consistent.
......
...@@ -10,8 +10,7 @@ if TYPE_CHECKING: ...@@ -10,8 +10,7 @@ if TYPE_CHECKING:
import numpy.typing as npt import numpy.typing as npt
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.base import PlaceholderRange
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.request import Request from vllm.v1.request import Request
......
...@@ -7,9 +7,9 @@ from collections import deque ...@@ -7,9 +7,9 @@ from collections import deque
from collections.abc import Iterable from collections.abc import Iterable
from typing import Optional, Union from typing import Optional, Union
from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig, from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig
SpeculativeConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
compute_encoder_budget) compute_encoder_budget)
from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.core.kv_cache_manager import KVCacheManager
...@@ -19,9 +19,11 @@ from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, ...@@ -19,9 +19,11 @@ from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
from vllm.v1.core.sched.utils import check_stop from vllm.v1.core.sched.utils import check_stop
from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
EngineCoreOutputs) EngineCoreOutputs)
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.structured_output import StructuredOutputManager
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -35,32 +37,37 @@ class Scheduler(SchedulerInterface): ...@@ -35,32 +37,37 @@ class Scheduler(SchedulerInterface):
model_config: ModelConfig, model_config: ModelConfig,
cache_config: CacheConfig, cache_config: CacheConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
speculative_config: Optional[SpeculativeConfig], kv_cache_config: KVCacheConfig,
log_stats: bool,
structured_output_manager: StructuredOutputManager, structured_output_manager: StructuredOutputManager,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
include_finished_set: bool = False,
log_stats: bool = False,
) -> None: ) -> None:
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.cache_config = cache_config self.cache_config = cache_config
self.lora_config = lora_config self.lora_config = lora_config
self.speculative_config = speculative_config self.kv_cache_config = kv_cache_config
self.log_stats = log_stats self.log_stats = log_stats
self.structured_output_manager = structured_output_manager self.structured_output_manager = structured_output_manager
# include_finished_set controls whether a separate set of finished
# request ids should be included in the EngineCoreOutputs returned
# by update_from_outputs(). This is currently used in the multi-engine
# case to track request lifetimes efficiently.
self.include_finished_set = include_finished_set
# Scheduling constraints. # Scheduling constraints.
self.max_num_running_reqs = self.scheduler_config.max_num_seqs self.max_num_running_reqs = self.scheduler_config.max_num_seqs
self.max_num_scheduled_tokens = \ self.max_num_scheduled_tokens = \
self.scheduler_config.max_num_batched_tokens self.scheduler_config.max_num_batched_tokens
self.max_model_len = self.scheduler_config.max_model_len self.max_model_len = self.scheduler_config.max_model_len
num_gpu_blocks = cache_config.num_gpu_blocks
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
# Create the KV cache manager. # Create the KV cache manager.
self.kv_cache_manager = KVCacheManager( self.kv_cache_manager = KVCacheManager(
block_size=self.cache_config.block_size, kv_cache_config=kv_cache_config,
num_gpu_blocks=num_gpu_blocks,
max_model_len=self.max_model_len, max_model_len=self.max_model_len,
sliding_window=self.cache_config.sliding_window, enable_caching=cache_config.enable_prefix_caching,
enable_caching=self.cache_config.enable_prefix_caching, caching_hash_algo=self.cache_config.prefix_caching_hash_algo,
log_stats=self.log_stats) log_stats=self.log_stats)
self.block_size = self.cache_config.block_size self.block_size = self.cache_config.block_size
...@@ -92,6 +99,7 @@ class Scheduler(SchedulerInterface): ...@@ -92,6 +99,7 @@ class Scheduler(SchedulerInterface):
encoder_compute_budget, encoder_cache_size = compute_encoder_budget( encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
model_config=model_config, model_config=model_config,
scheduler_config=scheduler_config, scheduler_config=scheduler_config,
mm_registry=mm_registry,
) )
# NOTE(woosuk): Here, "encoder" includes the vision encoder (and # NOTE(woosuk): Here, "encoder" includes the vision encoder (and
...@@ -152,23 +160,31 @@ class Scheduler(SchedulerInterface): ...@@ -152,23 +160,31 @@ class Scheduler(SchedulerInterface):
num_new_tokens = (request.num_tokens_with_spec - num_new_tokens = (request.num_tokens_with_spec -
request.num_computed_tokens) request.num_computed_tokens)
if (0 < self.scheduler_config.long_prefill_token_threshold <
num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
num_new_tokens = min(num_new_tokens, token_budget) num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0 assert num_new_tokens > 0
# Schedule encoder inputs. # Schedule encoder inputs.
encoder_inputs_to_schedule, num_new_tokens, new_encoder_budget = ( if request.has_encoder_inputs:
self._try_schedule_encoder_inputs(request, (encoder_inputs_to_schedule, num_new_tokens,
request.num_computed_tokens, new_encoder_budget) = self._try_schedule_encoder_inputs(
num_new_tokens, request, request.num_computed_tokens, num_new_tokens,
encoder_budget)) encoder_budget)
if num_new_tokens == 0: if num_new_tokens == 0:
# The request cannot be scheduled because the encoder budget # The request cannot be scheduled because the encoder budget
# or the encoder cache is exhausted. # or the encoder cache is exhausted.
# NOTE(woosuk): Here, by doing `continue` instead of `break`, # NOTE(woosuk): By using `continue` instead of `break` here,
# we do not strictly follow the FCFS scheduling policy and # we intentionally relax the strict FCFS scheduling policy
# allow the lower-priority requests to be scheduled. # to allow lower-priority requests to be scheduled when a
req_index += 1 # higher-priority request is blocked by encoder constraints.
continue req_index += 1
continue
else:
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
while True: while True:
new_blocks = self.kv_cache_manager.allocate_slots( new_blocks = self.kv_cache_manager.allocate_slots(
...@@ -235,16 +251,16 @@ class Scheduler(SchedulerInterface): ...@@ -235,16 +251,16 @@ class Scheduler(SchedulerInterface):
encoder_budget = new_encoder_budget encoder_budget = new_encoder_budget
# Record the LoRAs in scheduled_running_reqs # Record the LoRAs in scheduled_running_reqs
requested_loras: set[int] = set() scheduled_loras: set[int] = set()
if self.lora_config: if self.lora_config:
requested_loras = set( scheduled_loras = set(
req.lora_request.lora_int_id for req in scheduled_running_reqs req.lora_request.lora_int_id for req in scheduled_running_reqs
if req.lora_request and req.lora_request.lora_int_id > 0) if req.lora_request and req.lora_request.lora_int_id > 0)
assert len(requested_loras) <= self.lora_config.max_loras assert len(scheduled_loras) <= self.lora_config.max_loras
# Use a temporary deque to collect requests that need to be skipped # Use a temporary deque to collect requests that need to be skipped
# and put back at the head of the waiting queue later # and put back at the head of the waiting queue later
waiting_for_fsm: deque[Request] = deque() skipped_waiting_requests: deque[Request] = deque()
# Next, schedule the WAITING requests. # Next, schedule the WAITING requests.
if not preempted_reqs: if not preempted_reqs:
...@@ -254,31 +270,27 @@ class Scheduler(SchedulerInterface): ...@@ -254,31 +270,27 @@ class Scheduler(SchedulerInterface):
request = self.waiting[0] request = self.waiting[0]
# Skip request if the structured output request is still waiting
# for FSM compilation.
if request.status == RequestStatus.WAITING_FOR_FSM: if request.status == RequestStatus.WAITING_FOR_FSM:
structured_output_req = request.structured_output_request structured_output_req = request.structured_output_request
if structured_output_req and structured_output_req.grammar: if structured_output_req and structured_output_req.grammar:
request.status = RequestStatus.WAITING request.status = RequestStatus.WAITING
else: else:
waiting_structured_output_req = self.waiting.popleft() self.waiting.popleft()
waiting_for_fsm.appendleft( skipped_waiting_requests.appendleft(request)
waiting_structured_output_req)
continue continue
# Check that adding the request still respects the max_loras # Check that adding the request still respects the max_loras
# constraint. # constraint.
if self.lora_config and request.lora_request: if self.lora_config and request.lora_request and (
req_lora_id = request.lora_request.lora_int_id len(scheduled_loras) == self.lora_config.max_loras
if len(requested_loras) == self.lora_config.max_loras and ( and request.lora_request.lora_int_id
req_lora_id not in requested_loras): not in scheduled_loras):
# Cannot schedule. # Scheduling would exceed max_loras, skip.
# TODO (varun): This means all the other requests in self.waiting.popleft()
# the WAITING queue will be blocked by this request, skipped_waiting_requests.appendleft(request)
# even if, continue
# 1. these other requests do not use LoRA, or,
# 2. these other requests use the already requested
# LoRAs.
# This is too conservative and could be optimized.
break
# Get already-cached tokens. # Get already-cached tokens.
computed_blocks, num_computed_tokens = \ computed_blocks, num_computed_tokens = \
...@@ -288,28 +300,25 @@ class Scheduler(SchedulerInterface): ...@@ -288,28 +300,25 @@ class Scheduler(SchedulerInterface):
# `request.num_prompt_tokens` to consider the resumed requests, # `request.num_prompt_tokens` to consider the resumed requests,
# which have output tokens. # which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens num_new_tokens = request.num_tokens - num_computed_tokens
if num_new_tokens == 0: if (0 < self.scheduler_config.long_prefill_token_threshold <
# This happens when prompt length is divisible by the block num_new_tokens):
# size and all blocks are cached. Now we force to recompute num_new_tokens = (
# the last block. Note that we have to re-compute an entire self.scheduler_config.long_prefill_token_threshold)
# block because allocate_slots() assumes num_computed_tokens
# is always a multiple of the block size. This limitation
# can potentially be removed in the future to slightly
# improve the performance.
num_computed_tokens -= self.block_size
num_new_tokens = self.block_size
computed_blocks.pop()
num_new_tokens = min(num_new_tokens, token_budget) num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0 assert num_new_tokens > 0
# Schedule encoder inputs. # Schedule encoder inputs.
(encoder_inputs_to_schedule, num_new_tokens, if request.has_encoder_inputs:
new_encoder_budget) = self._try_schedule_encoder_inputs( (encoder_inputs_to_schedule, num_new_tokens,
request, num_computed_tokens, num_new_tokens, new_encoder_budget) = self._try_schedule_encoder_inputs(
encoder_budget) request, num_computed_tokens, num_new_tokens,
if num_new_tokens == 0: encoder_budget)
# The request cannot be scheduled. if num_new_tokens == 0:
break # The request cannot be scheduled.
break
else:
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
new_blocks = self.kv_cache_manager.allocate_slots( new_blocks = self.kv_cache_manager.allocate_slots(
request, num_new_tokens, computed_blocks) request, num_new_tokens, computed_blocks)
...@@ -336,7 +345,7 @@ class Scheduler(SchedulerInterface): ...@@ -336,7 +345,7 @@ class Scheduler(SchedulerInterface):
f"Invalid request status: {request.status}") f"Invalid request status: {request.status}")
if self.lora_config and request.lora_request: if self.lora_config and request.lora_request:
requested_loras.add(request.lora_request.lora_int_id) scheduled_loras.add(request.lora_request.lora_int_id)
req_to_new_block_ids[request.request_id] = [ req_to_new_block_ids[request.request_id] = [
b.block_id for b in computed_blocks + new_blocks b.block_id for b in computed_blocks + new_blocks
] ]
...@@ -355,8 +364,8 @@ class Scheduler(SchedulerInterface): ...@@ -355,8 +364,8 @@ class Scheduler(SchedulerInterface):
encoder_budget = new_encoder_budget encoder_budget = new_encoder_budget
# Put back any skipped requests at the head of the waiting queue # Put back any skipped requests at the head of the waiting queue
if waiting_for_fsm: if skipped_waiting_requests:
self.waiting.extendleft(waiting_for_fsm) self.waiting.extendleft(skipped_waiting_requests)
# Check if the scheduling constraints are satisfied. # Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
...@@ -425,6 +434,18 @@ class Scheduler(SchedulerInterface): ...@@ -425,6 +434,18 @@ class Scheduler(SchedulerInterface):
grammar_bitmask=grammar_bitmask, grammar_bitmask=grammar_bitmask,
) )
# Advance the number of computed tokens for the request AFTER
# the request is scheduled.
# 1. The scheduler_output of the current step has to include the
# original number of scheduled tokens to determine input IDs.
# 2. Advance the number of computed tokens here allowing us to
# schedule the prefill request again immediately in the next
# scheduling step.
# 3. If some tokens (e.g. spec tokens) are rejected later, the number of
# computed tokens will be adjusted in update_from_output.
for req_id, num_scheduled_token in num_scheduled_tokens.items():
self.requests[req_id].num_computed_tokens += num_scheduled_token
self.finished_req_ids = set() self.finished_req_ids = set()
return scheduler_output return scheduler_output
...@@ -479,9 +500,6 @@ class Scheduler(SchedulerInterface): ...@@ -479,9 +500,6 @@ class Scheduler(SchedulerInterface):
limitations, the method adjusts `num_new_tokens` to schedule only the limitations, the method adjusts `num_new_tokens` to schedule only the
decoder tokens up to just before the unschedulable encoder input. decoder tokens up to just before the unschedulable encoder input.
""" """
if not request.has_encoder_inputs():
return [], num_new_tokens, encoder_budget
encoder_inputs_to_schedule: list[int] = [] encoder_inputs_to_schedule: list[int] = []
mm_positions = request.mm_positions mm_positions = request.mm_positions
assert mm_positions is not None assert mm_positions is not None
...@@ -539,6 +557,7 @@ class Scheduler(SchedulerInterface): ...@@ -539,6 +557,7 @@ class Scheduler(SchedulerInterface):
new_running: list[Request] = [] new_running: list[Request] = []
outputs: list[EngineCoreOutput] = [] outputs: list[EngineCoreOutput] = []
spec_decoding_stats: Optional[SpecDecodingStats] = None
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
# loop can be a performance bottleneck. We should do our best to avoid # loop can be a performance bottleneck. We should do our best to avoid
...@@ -553,36 +572,32 @@ class Scheduler(SchedulerInterface): ...@@ -553,36 +572,32 @@ class Scheduler(SchedulerInterface):
req_index = model_runner_output.req_id_to_index[req_id] req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = sampled_token_ids[req_index] generated_token_ids = sampled_token_ids[req_index]
if req_id not in scheduler_output.scheduled_spec_decode_tokens:
# When the request's num_computed_tokens catches up scheduled_spec_token_ids = (
# its num_tokens, the request generates output tokens. scheduler_output.scheduled_spec_decode_tokens.get(req_id))
# Otherwise, we ignore the sampler output for the request. if scheduled_spec_token_ids:
request.num_computed_tokens += num_tokens_scheduled # num_computed_tokens represents the number of tokens
assert request.num_computed_tokens <= request.num_tokens
else:
# num_computed_tokens_step represents the number of tokens
# processed in the current step, considering scheduled # processed in the current step, considering scheduled
# tokens and rejections. # tokens and rejections. If some tokens are rejected,
# It is calculated as: # num_computed_tokens is decreased by the number of rejected
# num_computed_tokens_step = num_scheduled_tokens - # tokens, where is given by:
# num_tokens_rejected,
# where num_tokens_rejected is given by:
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids). # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
scheduled_spec_token_ids = ( num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
scheduler_output.scheduled_spec_decode_tokens[req_id]) len(generated_token_ids))
request.num_computed_tokens -= num_tokens_rejected
num_computed_tokens_step = num_scheduled_tokens[req_id] - ( spec_decoding_stats = self.make_spec_decoding_stats(
len(scheduled_spec_token_ids) + 1 - spec_decoding_stats,
len(generated_token_ids)) num_draft_tokens=len(scheduled_spec_token_ids),
request.num_computed_tokens += num_computed_tokens_step num_accepted_tokens=len(generated_token_ids) - 1)
cached_encoder_input_ids = ( cached_encoder_input_ids = (
self.encoder_cache_manager.get_cached_input_ids(request)) self.encoder_cache_manager.get_cached_input_ids(request))
# OPTIMIZATION: Avoid list(set) if the set is empty. # OPTIMIZATION: Avoid list(set) if the set is empty.
if cached_encoder_input_ids: if cached_encoder_input_ids:
for input_id in list(cached_encoder_input_ids): for input_id in list(cached_encoder_input_ids):
start_pos = request.mm_positions[input_id]["offset"] mm_positions = request.mm_positions[input_id]
num_tokens = request.mm_positions[input_id]["length"] start_pos = mm_positions["offset"]
num_tokens = mm_positions["length"]
if start_pos + num_tokens <= request.num_computed_tokens: if start_pos + num_tokens <= request.num_computed_tokens:
# The encoder output is already processed and stored # The encoder output is already processed and stored
# in the decoder's KV cache. # in the decoder's KV cache.
...@@ -595,35 +610,34 @@ class Scheduler(SchedulerInterface): ...@@ -595,35 +610,34 @@ class Scheduler(SchedulerInterface):
stopped = False stopped = False
new_logprobs = None new_logprobs = None
new_token_ids: list[int] = [] new_token_ids = generated_token_ids
if request.num_computed_tokens >= request.num_tokens: # Append generated tokens and check for stop. Note that if
for output_token_id in generated_token_ids: # a request is still being prefilled, we expect the model runner
request.append_output_token_ids(output_token_id) # to return empty token ids for the request.
new_token_ids.append(output_token_id) for num_new, output_token_id in enumerate(new_token_ids, 1):
request.append_output_token_ids(output_token_id)
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput. # Check for stop and update request state.
stopped = check_stop(request, self.max_model_len) # This must be called before we make the EngineCoreOutput.
if stopped: stopped = check_stop(request, self.max_model_len)
self._free_request(request) if stopped:
break self._free_request(request)
del new_token_ids[num_new:] # Trim new tokens if needed.
break
# Extract sample logprobs if needed. # Extract sample logprobs if needed.
if request.sampling_params.logprobs is not None: if request.sampling_params.logprobs is not None and logprobs:
assert logprobs is not None # NOTE: once we support N tokens per step (spec decode),
# NOTE: once we support N tokens per step (spec decode), # the outer lists can be of length > 1.
# the outer lists can be of length > 1. new_logprobs = logprobs.slice(req_index, req_index + 1)
new_logprobs = logprobs.slice(req_index, req_index + 1)
if new_token_ids and request.use_structured_output: if new_token_ids and request.use_structured_output:
# NOTE: structured_output_request # NOTE: structured_output_request
# should not be None if use_structured_output, we have # should not be None if use_structured_output, we have
# check above, so safe to ignore type warning # check above, so safe to ignore type warning
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
request.request_id, req_id, new_token_ids)
new_token_ids,
)
# Get prompt logprobs for this request. # Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
...@@ -642,15 +656,21 @@ class Scheduler(SchedulerInterface): ...@@ -642,15 +656,21 @@ class Scheduler(SchedulerInterface):
# Invariant: EngineCore returns no partial prefill outputs. # Invariant: EngineCore returns no partial prefill outputs.
assert not prompt_logprobs_tensors assert not prompt_logprobs_tensors
self.scheduled_req_ids.remove(request.request_id) self.scheduled_req_ids.remove(req_id)
if not stopped: if not stopped:
new_running.append(request) new_running.append(request)
self.running = new_running self.running = new_running
return EngineCoreOutputs( engine_core_outputs = EngineCoreOutputs(
outputs=outputs, outputs=outputs,
scheduler_stats=self.make_stats(), scheduler_stats=self.make_stats(spec_decoding_stats),
) )
if self.include_finished_set:
#TODO currently sending duplicates here, improve this
engine_core_outputs.finished_requests = (
scheduler_output.finished_req_ids | self.finished_req_ids)
return engine_core_outputs
def add_request(self, request: Request) -> None: def add_request(self, request: Request) -> None:
self.waiting.append(request) self.waiting.append(request)
...@@ -710,7 +730,10 @@ class Scheduler(SchedulerInterface): ...@@ -710,7 +730,10 @@ class Scheduler(SchedulerInterface):
def reset_prefix_cache(self) -> bool: def reset_prefix_cache(self) -> bool:
return self.kv_cache_manager.reset_prefix_cache() return self.kv_cache_manager.reset_prefix_cache()
def make_stats(self) -> Optional[SchedulerStats]: def make_stats(
self,
spec_decoding_stats: Optional[SpecDecodingStats] = None,
) -> Optional[SchedulerStats]:
if not self.log_stats: if not self.log_stats:
return None return None
return SchedulerStats( return SchedulerStats(
...@@ -718,4 +741,19 @@ class Scheduler(SchedulerInterface): ...@@ -718,4 +741,19 @@ class Scheduler(SchedulerInterface):
num_waiting_reqs=len(self.waiting), num_waiting_reqs=len(self.waiting),
gpu_cache_usage=self.kv_cache_manager.usage, gpu_cache_usage=self.kv_cache_manager.usage,
prefix_cache_stats=self.kv_cache_manager.make_prefix_cache_stats(), prefix_cache_stats=self.kv_cache_manager.make_prefix_cache_stats(),
spec_decoding_stats=spec_decoding_stats,
) )
def make_spec_decoding_stats(
self,
spec_decoding_stats: Optional[SpecDecodingStats],
num_draft_tokens: int,
num_accepted_tokens: int,
) -> Optional[SpecDecodingStats]:
if not self.log_stats:
return None
if spec_decoding_stats is None:
spec_decoding_stats = SpecDecodingStats()
spec_decoding_stats.observe(num_draft_tokens=num_draft_tokens,
num_accepted_tokens=num_accepted_tokens)
return spec_decoding_stats
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from vllm.utils import cdiv
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
SlidingWindowSpec)
class SpecializedManager(ABC):
"""
An abstract base class for specialized managers that handle the kv
cache management logic of different attention layers.
"""
def __init__(
self,
kv_cache_spec: KVCacheSpec,
block_pool: BlockPool,
) -> None:
"""
Initializes the SpecializedManager.
Args:
kv_cache_spec: The kv_cache_spec for this manager.
block_pool: The block pool.
"""
self.block_size = kv_cache_spec.block_size
self.kv_cache_spec = kv_cache_spec
self.block_pool = block_pool
@abstractmethod
def find_longest_cache_hit(
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
"""
Get the longest cache hit prefix of the blocks. If no cache hit is
found, return an empty list.
Args:
block_hashes: The block hashes of the request.
Returns:
A list of cached blocks with skipped blocks replaced by null block.
For example, sliding window manager should return a list like
[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)] for block size 4 and
sliding window 8.
"""
raise NotImplementedError
@abstractmethod
def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
num_computed_tokens: int) -> list[KVCacheBlock]:
"""
Remove the blocks that are no longer needed from `blocks`. The removed
blocks should be replaced by null_block. Return the removed blocks in
eviction order, where the first returned block should be evicted first.
Don't free the removed blocks in this function.
Args:
blocks: The list of blocks to be updated.
num_computed_tokens: The number of tokens that have been computed.
Returns:
The removed blocks in eviction order.
"""
raise NotImplementedError
class FullAttentionManager(SpecializedManager):
def find_longest_cache_hit(
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
computed_blocks: list[KVCacheBlock] = []
for block_hash in block_hashes:
# block_hashes is a chain of block hashes. If a block hash is not
# in the cached_block_hash_to_id, the following block hashes are
# not computed yet for sure.
if cached_block := self.block_pool.get_cached_block(block_hash):
computed_blocks.append(cached_block)
else:
break
return computed_blocks
def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
num_computed_tokens: int) -> list[KVCacheBlock]:
# No need to remove blocks for full attention.
return []
class SlidingWindowManager(SpecializedManager):
def __init__(self, kv_cache_spec: SlidingWindowSpec,
block_pool: BlockPool):
super().__init__(kv_cache_spec, block_pool)
self.sliding_window = kv_cache_spec.sliding_window
# The number of contiguous blocks needed for prefix cache hit.
# -1 since the input token itself is also included in the window
self.sliding_window_contiguous_blocks = cdiv(
(kv_cache_spec.sliding_window - 1), self.block_size)
self._null_block = block_pool.null_block
def find_longest_cache_hit(
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
# TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
# optimize the time complexity from O(len(block_hashes)) to
# O(len(block_hashes) / sliding_window_contiguous_blocks +
# sliding_window_contiguous_blocks),
# which is good for low cache hit rate scenarios.
computed_blocks = [self._null_block] * len(block_hashes)
num_contiguous_blocks = 0
# Search from right to left and early stop when a match is found.
for i in range(len(block_hashes) - 1, -1, -1):
if cached_block := self.block_pool.get_cached_block(
block_hashes[i]):
computed_blocks[i] = cached_block
num_contiguous_blocks += 1
if (num_contiguous_blocks
>= self.sliding_window_contiguous_blocks):
# Trim the trailing blocks.
# E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
# when sliding_window_contiguous_blocks=2.
del computed_blocks[i + num_contiguous_blocks:]
return computed_blocks
else:
num_contiguous_blocks = 0
# The first `num_contiguous_blocks` is a cache hit even if
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
del computed_blocks[num_contiguous_blocks:]
return computed_blocks
def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
num_computed_tokens: int) -> list[KVCacheBlock]:
# Remove the blocks that are no longer be in the sliding window and
# skipped during the attention computation.
last_useful_token = num_computed_tokens - self.sliding_window + 1
last_useful_block = last_useful_token // self.block_size
removed_blocks: list[KVCacheBlock] = []
for i in range(last_useful_block - 1, -1, -1):
if blocks[i] == self._null_block:
# If the block is already a null block, the blocks before it
# should also have been set to null blocks by the previous calls
# to this function.
break
removed_blocks.append(blocks[i])
blocks[i] = self._null_block
return removed_blocks
spec_manager_map: dict[type[KVCacheSpec], type[SpecializedManager]] = {
FullAttentionSpec: FullAttentionManager,
SlidingWindowSpec: SlidingWindowManager,
}
def get_specialized_manager(kv_cache_spec: KVCacheSpec,
block_pool: BlockPool) -> SpecializedManager:
manager_class = spec_manager_map[type(kv_cache_spec)]
manager = manager_class(kv_cache_spec, block_pool)
return manager
...@@ -128,12 +128,18 @@ class EngineCoreOutputs( ...@@ -128,12 +128,18 @@ class EngineCoreOutputs(
#NOTE(Nick): We could consider ways to make this more compact, #NOTE(Nick): We could consider ways to make this more compact,
# e.g. columnwise layout # e.g. columnwise layout
engine_index: int = 0
# [num_reqs] # [num_reqs]
outputs: list[EngineCoreOutput] = [] outputs: list[EngineCoreOutput] = []
scheduler_stats: Optional[SchedulerStats] = None scheduler_stats: Optional[SchedulerStats] = None
timestamp: float = 0.0 timestamp: float = 0.0
utility_output: Optional[UtilityOutput] = None utility_output: Optional[UtilityOutput] = None
finished_requests: Optional[set[str]] = None
# In DP case, used to signal that the engine is paused.
engine_paused: bool = False
def __post_init__(self): def __post_init__(self):
if self.timestamp == 0.0: if self.timestamp == 0.0:
...@@ -147,4 +153,5 @@ class EngineCoreRequestType(enum.Enum): ...@@ -147,4 +153,5 @@ class EngineCoreRequestType(enum.Enum):
""" """
ADD = b'\x00' ADD = b'\x00'
ABORT = b'\x01' ABORT = b'\x01'
UTILITY = b'\x02' START_DP = b'\x02'
UTILITY = b'\x03'
...@@ -14,10 +14,11 @@ from vllm.config import ModelConfig, VllmConfig ...@@ -14,10 +14,11 @@ from vllm.config import ModelConfig, VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
...@@ -48,7 +49,7 @@ class AsyncLLM(EngineClient): ...@@ -48,7 +49,7 @@ class AsyncLLM(EngineClient):
executor_class: type[Executor], executor_class: type[Executor],
log_stats: bool, log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
input_registry: InputRegistry = INPUT_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
use_cached_outputs: bool = False, use_cached_outputs: bool = False,
log_requests: bool = True, log_requests: bool = True,
start_engine_loop: bool = True, start_engine_loop: bool = True,
...@@ -66,11 +67,17 @@ class AsyncLLM(EngineClient): ...@@ -66,11 +67,17 @@ class AsyncLLM(EngineClient):
self.log_requests = log_requests self.log_requests = log_requests
self.log_stats = log_stats self.log_stats = log_stats
self.stat_loggers: list[StatLoggerBase] = []
# Set up stat loggers; independent set for each DP rank.
self.stat_loggers: list[list[StatLoggerBase]] = []
if self.log_stats: if self.log_stats:
if logger.isEnabledFor(logging.INFO): for i in range(vllm_config.parallel_config.data_parallel_size):
self.stat_loggers.append(LoggingStatLogger()) loggers: list[StatLoggerBase] = []
self.stat_loggers.append(PrometheusStatLogger(vllm_config)) if logger.isEnabledFor(logging.INFO):
loggers.append(LoggingStatLogger(engine_index=i))
loggers.append(
PrometheusStatLogger(vllm_config, engine_index=i))
self.stat_loggers.append(loggers)
# Tokenizer (+ ensure liveness if running in another process). # Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs( self.tokenizer = init_tokenizer_from_configs(
...@@ -84,7 +91,7 @@ class AsyncLLM(EngineClient): ...@@ -84,7 +91,7 @@ class AsyncLLM(EngineClient):
self.processor = Processor( self.processor = Processor(
vllm_config=vllm_config, vllm_config=vllm_config,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
input_registry=input_registry, mm_registry=mm_registry,
) )
# OutputProcessor (converts EngineCoreOutputs --> RequestOutput). # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
...@@ -329,6 +336,7 @@ class AsyncLLM(EngineClient): ...@@ -329,6 +336,7 @@ class AsyncLLM(EngineClient):
# TODO(rob): make into a coroutine and launch it in # TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial. # background thread once Prometheus overhead is non-trivial.
self._record_stats( self._record_stats(
engine_index=outputs.engine_index,
scheduler_stats=outputs.scheduler_stats, scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats, iteration_stats=iteration_stats,
) )
...@@ -350,12 +358,13 @@ class AsyncLLM(EngineClient): ...@@ -350,12 +358,13 @@ class AsyncLLM(EngineClient):
self, self,
scheduler_stats: Optional[SchedulerStats], scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats], iteration_stats: Optional[IterationStats],
engine_index: int = 0,
): ):
if not self.log_stats: if not self.log_stats:
return return
assert scheduler_stats is not None assert scheduler_stats is not None
for stat_logger in self.stat_loggers: for stat_logger in self.stat_loggers[engine_index]:
stat_logger.record(scheduler_stats=scheduler_stats, stat_logger.record(scheduler_stats=scheduler_stats,
iteration_stats=iteration_stats) iteration_stats=iteration_stats)
...@@ -393,8 +402,9 @@ class AsyncLLM(EngineClient): ...@@ -393,8 +402,9 @@ class AsyncLLM(EngineClient):
scheduler_outputs=None, scheduler_outputs=None,
model_output=None, model_output=None,
) -> None: ) -> None:
for stat_logger in self.stat_loggers: for loggers in self.stat_loggers:
stat_logger.log() for stat_logger in loggers:
stat_logger.log()
async def check_health(self) -> None: async def check_health(self) -> None:
logger.debug("Called check_health.") logger.debug("Called check_health.")
...@@ -414,8 +424,8 @@ class AsyncLLM(EngineClient): ...@@ -414,8 +424,8 @@ class AsyncLLM(EngineClient):
async def sleep(self, level: int = 1) -> None: async def sleep(self, level: int = 1) -> None:
await self.engine_core.sleep_async(level) await self.engine_core.sleep_async(level)
async def wake_up(self) -> None: async def wake_up(self, tags: Optional[list[str]] = None) -> None:
await self.engine_core.wake_up_async() await self.engine_core.wake_up_async(tags)
async def is_sleeping(self) -> bool: async def is_sleeping(self) -> bool:
return await self.engine_core.is_sleeping_async() return await self.engine_core.is_sleeping_async()
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os
import queue import queue
import signal import signal
import sys
import threading import threading
import time import time
from concurrent.futures import Future from concurrent.futures import Future
from inspect import isclass, signature from inspect import isclass, signature
from multiprocessing.connection import Connection from logging import DEBUG
from typing import Any, Optional from typing import Any, Callable, Optional, TypeVar, Union
import msgspec import msgspec
import psutil import psutil
import zmq import zmq
import zmq.asyncio import zmq.asyncio
from vllm.config import VllmConfig from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.executor.multiproc_worker_utils import _add_prefix
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.transformers_utils.config import ( from vllm.transformers_utils.config import (
...@@ -23,12 +26,14 @@ from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname, ...@@ -23,12 +26,14 @@ from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname,
zmq_socket_ctx) zmq_socket_ctx)
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config, from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
unify_kv_cache_configs) unify_kv_cache_configs)
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput) EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.mm_input_cache import MMInputCacheServer from vllm.v1.engine.mm_input_cache import MMInputCacheServer
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
...@@ -39,6 +44,8 @@ logger = init_logger(__name__) ...@@ -39,6 +44,8 @@ logger = init_logger(__name__)
POLLING_TIMEOUT_S = 2.5 POLLING_TIMEOUT_S = 2.5
_R = TypeVar('_R') # Return type for collective_rpc
class EngineCore: class EngineCore:
"""Inner loop of vLLM's Engine.""" """Inner loop of vLLM's Engine."""
...@@ -60,8 +67,9 @@ class EngineCore: ...@@ -60,8 +67,9 @@ class EngineCore:
self.model_executor = executor_class(vllm_config) self.model_executor = executor_class(vllm_config)
# Setup KV Caches and update CacheConfig after profiling. # Setup KV Caches and update CacheConfig after profiling.
num_gpu_blocks, num_cpu_blocks = self._initialize_kv_caches( num_gpu_blocks, num_cpu_blocks, kv_cache_config = \
vllm_config) self._initialize_kv_caches(vllm_config)
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
...@@ -84,14 +92,16 @@ class EngineCore: ...@@ -84,14 +92,16 @@ class EngineCore:
"compatibility may not be maintained.", "compatibility may not be maintained.",
vllm_config.scheduler_config.scheduler_cls) vllm_config.scheduler_config.scheduler_cls)
self.scheduler = Scheduler( self.scheduler: SchedulerInterface = Scheduler(
scheduler_config=vllm_config.scheduler_config, scheduler_config=vllm_config.scheduler_config,
model_config=vllm_config.model_config, model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config, cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config, lora_config=vllm_config.lora_config,
speculative_config=vllm_config.speculative_config, kv_cache_config=kv_cache_config,
log_stats=self.log_stats,
structured_output_manager=self.structured_output_manager, structured_output_manager=self.structured_output_manager,
include_finished_set=vllm_config.parallel_config.data_parallel_size
> 1,
log_stats=self.log_stats,
) )
# Setup MM Input Mapper. # Setup MM Input Mapper.
...@@ -110,8 +120,8 @@ class EngineCore: ...@@ -110,8 +120,8 @@ class EngineCore:
self.batch_queue_size) self.batch_queue_size)
self.batch_queue = queue.Queue(self.batch_queue_size) self.batch_queue = queue.Queue(self.batch_queue_size)
def _initialize_kv_caches(self, def _initialize_kv_caches(
vllm_config: VllmConfig) -> tuple[int, int]: self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
start = time.time() start = time.time()
# Get all kv cache needed by the model # Get all kv cache needed by the model
...@@ -136,13 +146,14 @@ class EngineCore: ...@@ -136,13 +146,14 @@ class EngineCore:
unify_kv_cache_configs(kv_cache_configs) unify_kv_cache_configs(kv_cache_configs)
# All workers have the same kv_cache_config except layer names, so use # All workers have the same kv_cache_config except layer names, so use
# an arbitrary one to get the number of blocks. # an arbitrary one to initialize the scheduler.
assert all([ assert all([
cfg.num_blocks == kv_cache_configs[0].num_blocks cfg.num_blocks == kv_cache_configs[0].num_blocks
for cfg in kv_cache_configs for cfg in kv_cache_configs
]) ])
num_gpu_blocks = kv_cache_configs[0].num_blocks num_gpu_blocks = kv_cache_configs[0].num_blocks
num_cpu_blocks = 0 num_cpu_blocks = 0
scheduler_kv_cache_config = kv_cache_configs[0]
# Initialize kv cache and warmup the execution # Initialize kv cache and warmup the execution
self.model_executor.initialize_from_config(kv_cache_configs) self.model_executor.initialize_from_config(kv_cache_configs)
...@@ -150,7 +161,7 @@ class EngineCore: ...@@ -150,7 +161,7 @@ class EngineCore:
elapsed = time.time() - start elapsed = time.time() - start
logger.info(("init engine (profile, create kv cache, " logger.info(("init engine (profile, create kv cache, "
"warmup model) took %.2f seconds"), elapsed) "warmup model) took %.2f seconds"), elapsed)
return num_gpu_blocks, num_cpu_blocks return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config
def add_request(self, request: EngineCoreRequest): def add_request(self, request: EngineCoreRequest):
"""Add request to the scheduler.""" """Add request to the scheduler."""
...@@ -253,8 +264,8 @@ class EngineCore: ...@@ -253,8 +264,8 @@ class EngineCore:
def sleep(self, level: int = 1): def sleep(self, level: int = 1):
self.model_executor.sleep(level) self.model_executor.sleep(level)
def wake_up(self): def wake_up(self, tags: Optional[list[str]] = None):
self.model_executor.wake_up() self.model_executor.wake_up(tags)
def is_sleeping(self) -> bool: def is_sleeping(self) -> bool:
return self.model_executor.is_sleeping return self.model_executor.is_sleeping
...@@ -274,6 +285,24 @@ class EngineCore: ...@@ -274,6 +285,24 @@ class EngineCore:
def pin_lora(self, lora_id: int) -> bool: def pin_lora(self, lora_id: int) -> bool:
return self.model_executor.pin_lora(lora_id) return self.model_executor.pin_lora(lora_id)
def save_sharded_state(
self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
self.model_executor.save_sharded_state(path=path,
pattern=pattern,
max_size=max_size)
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return self.model_executor.collective_rpc(method, timeout, args,
kwargs)
class EngineCoreProc(EngineCore): class EngineCoreProc(EngineCore):
"""ZMQ-wrapper for running EngineCore in background process.""" """ZMQ-wrapper for running EngineCore in background process."""
...@@ -282,10 +311,10 @@ class EngineCoreProc(EngineCore): ...@@ -282,10 +311,10 @@ class EngineCoreProc(EngineCore):
self, self,
input_path: str, input_path: str,
output_path: str, output_path: str,
ready_pipe: Connection,
vllm_config: VllmConfig, vllm_config: VllmConfig,
executor_class: type[Executor], executor_class: type[Executor],
log_stats: bool, log_stats: bool,
engine_index: int = 0,
): ):
super().__init__(vllm_config, executor_class, log_stats) super().__init__(vllm_config, executor_class, log_stats)
...@@ -301,14 +330,20 @@ class EngineCoreProc(EngineCore): ...@@ -301,14 +330,20 @@ class EngineCoreProc(EngineCore):
args=(input_path, ), args=(input_path, ),
daemon=True).start() daemon=True).start()
threading.Thread(target=self.process_output_socket, threading.Thread(target=self.process_output_socket,
args=(output_path, ), args=(output_path, engine_index),
daemon=True).start() daemon=True).start()
# Send Readiness signal to EngineClient. self.global_unfinished_reqs = False
ready_pipe.send({"status": "READY"})
self.step_fn = (self.step if self.batch_queue is None else
self.step_with_batch_queue)
@staticmethod @staticmethod
def run_engine_core(*args, **kwargs): def run_engine_core(*args,
dp_rank: int = 0,
local_dp_rank: int = 0,
ready_pipe,
**kwargs):
"""Launch EngineCore busy loop in background process.""" """Launch EngineCore busy loop in background process."""
# Signal handler used for graceful termination. # Signal handler used for graceful termination.
...@@ -330,9 +365,21 @@ class EngineCoreProc(EngineCore): ...@@ -330,9 +365,21 @@ class EngineCoreProc(EngineCore):
signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGINT, signal_handler)
parent_process = psutil.Process().parent() parent_process = psutil.Process().parent()
engine_core = None engine_core: Optional[EngineCoreProc] = None
try: try:
engine_core = EngineCoreProc(*args, **kwargs) parallel_config: ParallelConfig = kwargs[
"vllm_config"].parallel_config
if parallel_config.data_parallel_size > 1:
# Set data parallel rank for this engine process.
parallel_config.data_parallel_rank = dp_rank
parallel_config.data_parallel_rank_local = local_dp_rank
engine_core = DPEngineCoreProc(*args, **kwargs)
else:
engine_core = EngineCoreProc(*args, **kwargs)
# Send Readiness signal to EngineClient.
ready_pipe.send({"status": "READY"})
engine_core.run_busy_loop() engine_core.run_busy_loop()
except SystemExit: except SystemExit:
...@@ -350,28 +397,44 @@ class EngineCoreProc(EngineCore): ...@@ -350,28 +397,44 @@ class EngineCoreProc(EngineCore):
def run_busy_loop(self): def run_busy_loop(self):
"""Core busy loop of the EngineCore.""" """Core busy loop of the EngineCore."""
step_fn = (self.step
if self.batch_queue is None else self.step_with_batch_queue)
# Loop until process is sent a SIGINT or SIGTERM # Loop until process is sent a SIGINT or SIGTERM
while True: while True:
# 1) Poll the input queue until there is work to do. # 1) Poll the input queue until there is work to do.
while not self.scheduler.has_requests(): self._process_input_queue()
logger.debug("EngineCore busy loop waiting.") # 2) Step the engine core and return the outputs.
req = self.input_queue.get() self._process_engine_step()
self._handle_client_request(*req)
def _process_input_queue(self):
# 2) Handle any new client requests. """Exits when an engine step needs to be performed."""
while not self.input_queue.empty():
req = self.input_queue.get_nowait() waited = False
self._handle_client_request(*req) while not self.global_unfinished_reqs and not (
self.scheduler.has_requests()):
# 3) Step the engine core. if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
outputs = step_fn() logger.debug("EngineCore waiting for work.")
waited = True
# 4) Put EngineCoreOutputs into the output queue. req = self.input_queue.get()
if outputs is not None: self._handle_client_request(*req)
self.output_queue.put_nowait(outputs)
if waited:
logger.debug(
"EngineCore loop active - local unfinished: %s, finished: %s.",
self.scheduler.has_unfinished_requests(),
self.scheduler.has_finished_requests())
# Handle any more client requests.
while not self.input_queue.empty():
req = self.input_queue.get_nowait()
self._handle_client_request(*req)
def _process_engine_step(self):
"""Called only when there are unfinished local requests."""
# Step the engine core.
outputs = self.step_fn()
# Put EngineCoreOutputs into the output queue.
if outputs is not None:
self.output_queue.put_nowait(outputs)
def _handle_client_request(self, request_type: EngineCoreRequestType, def _handle_client_request(self, request_type: EngineCoreRequestType,
request: Any) -> None: request: Any) -> None:
...@@ -381,6 +444,10 @@ class EngineCoreProc(EngineCore): ...@@ -381,6 +444,10 @@ class EngineCoreProc(EngineCore):
self.add_request(request) self.add_request(request)
elif request_type == EngineCoreRequestType.ABORT: elif request_type == EngineCoreRequestType.ABORT:
self.abort_requests(request) self.abort_requests(request)
elif request_type == EngineCoreRequestType.START_DP:
if not self.global_unfinished_reqs:
logger.debug("EngineCore starting idle loop.")
self.global_unfinished_reqs = True
elif request_type == EngineCoreRequestType.UTILITY: elif request_type == EngineCoreRequestType.UTILITY:
call_id, method_name, args = request call_id, method_name, args = request
output = UtilityOutput(call_id) output = UtilityOutput(call_id)
...@@ -431,7 +498,7 @@ class EngineCoreProc(EngineCore): ...@@ -431,7 +498,7 @@ class EngineCoreProc(EngineCore):
# Push to input queue for core busy loop. # Push to input queue for core busy loop.
self.input_queue.put_nowait((request_type, request)) self.input_queue.put_nowait((request_type, request))
def process_output_socket(self, output_path: str): def process_output_socket(self, output_path: str, engine_index: int):
"""Output socket IO thread.""" """Output socket IO thread."""
# Msgpack serialization encoding. # Msgpack serialization encoding.
...@@ -442,5 +509,114 @@ class EngineCoreProc(EngineCore): ...@@ -442,5 +509,114 @@ class EngineCoreProc(EngineCore):
with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket: with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket:
while True: while True:
outputs = self.output_queue.get() outputs = self.output_queue.get()
outputs.engine_index = engine_index
encoder.encode_into(outputs, buffer) encoder.encode_into(outputs, buffer)
socket.send_multipart((buffer, ), copy=False) socket.send(buffer, copy=False)
ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(engine_paused=True)
class DPEngineCoreProc(EngineCoreProc):
"""ZMQ-wrapper for running EngineCore in background process
in a data parallel context."""
def __init__(
self,
input_path: str,
output_path: str,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
):
# Add process-specific prefix to stdout and stderr before
# we initialize the engine.
from multiprocessing import current_process
process_name = current_process().name
pid = os.getpid()
_add_prefix(sys.stdout, process_name, pid)
_add_prefix(sys.stderr, process_name, pid)
dp_size = vllm_config.parallel_config.data_parallel_size
dp_rank = vllm_config.parallel_config.data_parallel_rank
local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
assert dp_size > 1
assert 0 <= local_dp_rank <= dp_rank < dp_size
from vllm.platforms import current_platform
if current_platform.is_cuda_alike():
from vllm.platforms.cuda import device_id_to_physical_device_id
tp_size = vllm_config.parallel_config.tensor_parallel_size
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
str(device_id_to_physical_device_id(i))
for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) *
tp_size))
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
# Initialize the engine after setting up environment.
super().__init__(input_path, output_path, vllm_config, executor_class,
log_stats, dp_rank)
# Counts forward-passes of the model so that we can synchronize
# finished with DP peers every N steps.
self.counter = 0
def shutdown(self):
super().shutdown()
if dp_group := getattr(self, "dp_group", None):
stateless_destroy_torch_distributed_process_group(dp_group)
def run_busy_loop(self):
"""Core busy loop of the EngineCore for data parallel case."""
# Loop until process is sent a SIGINT or SIGTERM
while True:
# 1) Poll the input queue until there is work to do.
self._process_input_queue()
local_unfinished_reqs = self.scheduler.has_unfinished_requests()
if local_unfinished_reqs:
# 2) Step the engine core.
self._process_engine_step()
# Check if we have now finished all requests.
local_unfinished_reqs = (
self.scheduler.has_unfinished_requests())
else:
if self.scheduler.has_finished_requests():
# There are no unfinished requests, but there are some
# finished requests remaining to be removed from the
# batch state. This engine step won't perform a forward
# pass but will flush the finished requests to ensure
# up-to-date state is returned in the engine outputs.
self._process_engine_step()
if not self.global_unfinished_reqs:
# All engines are idle.
continue
# There must be unfinished requests in DP peers, run a
# dummy forward pass.
self.execute_dummy_batch()
# 3) All-reduce operation to determine global unfinished reqs.
self.global_unfinished_reqs = self._has_global_unfinished_reqs(
local_unfinished_reqs)
if not self.global_unfinished_reqs:
# Notify client that we are pausing the loop.
self.output_queue.put_nowait(ENGINE_PAUSED_OUTPUTS)
def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
# Optimization - only perform finish-sync all-reduce every 16 steps.
self.counter += 1
if self.counter != 16:
return True
self.counter = 0
return ParallelConfig.has_unfinished_dp(self.dp_group,
local_unfinished)
...@@ -8,10 +8,11 @@ import threading ...@@ -8,10 +8,11 @@ import threading
import uuid import uuid
import weakref import weakref
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Awaitable, Sequence
from concurrent.futures import Future from concurrent.futures import Future
from dataclasses import dataclass from dataclasses import dataclass, field
from threading import Thread from threading import Thread
from typing import Any, Optional, Union from typing import Any, Callable, Optional, TypeVar, Union
import zmq import zmq
import zmq.asyncio import zmq.asyncio
...@@ -32,6 +33,8 @@ logger = init_logger(__name__) ...@@ -32,6 +33,8 @@ logger = init_logger(__name__)
AnyFuture = Union[asyncio.Future[Any], Future[Any]] AnyFuture = Union[asyncio.Future[Any], Future[Any]]
_R = TypeVar('_R') # Return type for collective_rpc
class EngineCoreClient(ABC): class EngineCoreClient(ABC):
""" """
...@@ -60,6 +63,9 @@ class EngineCoreClient(ABC): ...@@ -60,6 +63,9 @@ class EngineCoreClient(ABC):
"is not currently supported.") "is not currently supported.")
if multiprocess_mode and asyncio_mode: if multiprocess_mode and asyncio_mode:
if vllm_config.parallel_config.data_parallel_size > 1:
return DPAsyncMPClient(vllm_config, executor_class, log_stats)
return AsyncMPClient(vllm_config, executor_class, log_stats) return AsyncMPClient(vllm_config, executor_class, log_stats)
if multiprocess_mode and not asyncio_mode: if multiprocess_mode and not asyncio_mode:
...@@ -86,7 +92,7 @@ class EngineCoreClient(ABC): ...@@ -86,7 +92,7 @@ class EngineCoreClient(ABC):
def sleep(self, level: int = 1) -> None: def sleep(self, level: int = 1) -> None:
raise NotImplementedError raise NotImplementedError
def wake_up(self) -> None: def wake_up(self, tags: Optional[list[str]] = None) -> None:
raise NotImplementedError raise NotImplementedError
def is_sleeping(self) -> bool: def is_sleeping(self) -> bool:
...@@ -113,6 +119,19 @@ class EngineCoreClient(ABC): ...@@ -113,6 +119,19 @@ class EngineCoreClient(ABC):
def pin_lora(self, lora_id: int) -> bool: def pin_lora(self, lora_id: int) -> bool:
raise NotImplementedError raise NotImplementedError
def save_sharded_state(self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None) -> None:
raise NotImplementedError
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
raise NotImplementedError
async def get_output_async(self) -> EngineCoreOutputs: async def get_output_async(self) -> EngineCoreOutputs:
raise NotImplementedError raise NotImplementedError
...@@ -128,7 +147,7 @@ class EngineCoreClient(ABC): ...@@ -128,7 +147,7 @@ class EngineCoreClient(ABC):
async def sleep_async(self, level: int = 1) -> None: async def sleep_async(self, level: int = 1) -> None:
raise NotImplementedError raise NotImplementedError
async def wake_up_async(self) -> None: async def wake_up_async(self, tags: Optional[list[str]] = None) -> None:
raise NotImplementedError raise NotImplementedError
async def is_sleeping_async(self) -> bool: async def is_sleeping_async(self) -> bool:
...@@ -149,6 +168,20 @@ class EngineCoreClient(ABC): ...@@ -149,6 +168,20 @@ class EngineCoreClient(ABC):
async def pin_lora_async(self, lora_id: int) -> bool: async def pin_lora_async(self, lora_id: int) -> bool:
raise NotImplementedError raise NotImplementedError
async def save_sharded_state_async(self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None) -> None:
raise NotImplementedError
async def collective_rpc_async(
self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
raise NotImplementedError
class InprocClient(EngineCoreClient): class InprocClient(EngineCoreClient):
""" """
...@@ -185,8 +218,8 @@ class InprocClient(EngineCoreClient): ...@@ -185,8 +218,8 @@ class InprocClient(EngineCoreClient):
def sleep(self, level: int = 1) -> None: def sleep(self, level: int = 1) -> None:
self.engine_core.sleep(level) self.engine_core.sleep(level)
def wake_up(self) -> None: def wake_up(self, tags: Optional[list[str]] = None) -> None:
self.engine_core.wake_up() self.engine_core.wake_up(tags)
def is_sleeping(self) -> bool: def is_sleeping(self) -> bool:
return self.engine_core.is_sleeping() return self.engine_core.is_sleeping()
...@@ -206,29 +239,88 @@ class InprocClient(EngineCoreClient): ...@@ -206,29 +239,88 @@ class InprocClient(EngineCoreClient):
def pin_lora(self, lora_id: int) -> bool: def pin_lora(self, lora_id: int) -> bool:
return self.engine_core.pin_lora(lora_id) return self.engine_core.pin_lora(lora_id)
def save_sharded_state(self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None) -> None:
self.engine_core.save_sharded_state(path, pattern, max_size)
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return self.engine_core.collective_rpc(method, timeout, args, kwargs)
class CoreEngine:
"""One per data parallel rank."""
def __init__(
self,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
ctx: Union[zmq.Context, zmq.asyncio.Context],
output_path: str,
index: int = 0,
local_dp_rank: int = 0,
):
# Paths and sockets for IPC.
input_path = get_open_zmq_ipc_path()
self.input_socket = make_zmq_socket(ctx, input_path,
zmq.constants.PUSH)
try:
# Start EngineCore in background process.
self.proc_handle = BackgroundProcHandle(
input_path=input_path,
output_path=output_path,
process_name=f"EngineCore_{index}",
target_fn=EngineCoreProc.run_engine_core,
process_kwargs={
"vllm_config": vllm_config,
"dp_rank": index,
"local_dp_rank": local_dp_rank,
"executor_class": executor_class,
"log_stats": log_stats,
})
self.num_reqs_in_flight = 0
finally:
if not hasattr(self, "num_reqs_in_flight"):
# Ensure socket is closed if process fails to start.
self.close()
def send_multipart(self, msg_parts: Sequence):
return self.input_socket.send_multipart(msg_parts, copy=False)
def close(self):
if proc_handle := getattr(self, "proc_handle", None):
proc_handle.shutdown()
if socket := getattr(self, "input_socket", None):
socket.close(linger=0)
@dataclass @dataclass
class BackgroundResources: class BackgroundResources:
"""Used as a finalizer for clean shutdown, avoiding """Used as a finalizer for clean shutdown, avoiding
circular reference back to the client object.""" circular reference back to the client object."""
ctx: zmq.Context ctx: Union[zmq.Context]
core_engines: list[CoreEngine] = field(default_factory=list)
output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
proc_handle: Optional[BackgroundProcHandle] = None
shutdown_path: Optional[str] = None shutdown_path: Optional[str] = None
def __call__(self): def __call__(self):
"""Clean up background resources.""" """Clean up background resources."""
if self.proc_handle is not None: for core_engine in self.core_engines:
self.proc_handle.shutdown() core_engine.close()
# ZMQ context termination can hang if the sockets # ZMQ context termination can hang if the sockets
# aren't explicitly closed first. # aren't explicitly closed first.
if self.output_socket is not None: if self.output_socket is not None:
self.output_socket.close(linger=0) self.output_socket.close(linger=0)
if self.input_socket is not None:
self.input_socket.close(linger=0)
if self.shutdown_path is not None: if self.shutdown_path is not None:
# We must ensure that the sync output socket is # We must ensure that the sync output socket is
# closed cleanly in its own thread. # closed cleanly in its own thread.
...@@ -284,7 +376,7 @@ class MPClient(EngineCoreClient): ...@@ -284,7 +376,7 @@ class MPClient(EngineCoreClient):
self.decoder = MsgpackDecoder(EngineCoreOutputs) self.decoder = MsgpackDecoder(EngineCoreOutputs)
# ZMQ setup. # ZMQ setup.
sync_ctx = zmq.Context() sync_ctx = zmq.Context(io_threads=2)
self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx
# This will ensure resources created so far are closed # This will ensure resources created so far are closed
...@@ -293,28 +385,38 @@ class MPClient(EngineCoreClient): ...@@ -293,28 +385,38 @@ class MPClient(EngineCoreClient):
self.resources = BackgroundResources(ctx=sync_ctx) self.resources = BackgroundResources(ctx=sync_ctx)
self._finalizer = weakref.finalize(self, self.resources) self._finalizer = weakref.finalize(self, self.resources)
# Paths for IPC. # Paths and sockets for IPC.
self.output_path = get_open_zmq_ipc_path() self.output_path = get_open_zmq_ipc_path()
input_path = get_open_zmq_ipc_path()
# Start EngineCore in background process. new_core_engine = lambda index, local_dp_rank=None: CoreEngine(
self.resources.proc_handle = BackgroundProcHandle( vllm_config, executor_class, log_stats, self.ctx, self.output_path,
input_path=input_path, index, local_dp_rank)
output_path=self.output_path,
process_name="EngineCore", # Start engine core process(es).
target_fn=EngineCoreProc.run_engine_core, self._init_core_engines(vllm_config, new_core_engine,
process_kwargs={ self.resources.core_engines)
"vllm_config": vllm_config,
"executor_class": executor_class, # Wait for engine core process(es) to start.
"log_stats": log_stats, for engine in self.resources.core_engines:
}) engine.proc_handle.wait_for_startup()
# Create input socket.
self.resources.input_socket = make_zmq_socket(self.ctx, input_path,
zmq.constants.PUSH)
self.input_socket = self.resources.input_socket
self.utility_results: dict[int, AnyFuture] = {} self.utility_results: dict[int, AnyFuture] = {}
def _init_core_engines(
self,
vllm_config: VllmConfig,
new_core_engine: Callable[[int, Optional[int]], CoreEngine],
core_engines: list[CoreEngine],
) -> None:
# Default case - single core engine.
dp_rank = vllm_config.parallel_config.data_parallel_rank
local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
core_engine = new_core_engine(
dp_rank, local_dp_rank if local_dp_rank is not None else dp_rank)
core_engines.append(core_engine)
self.core_engine = core_engine
def shutdown(self): def shutdown(self):
self._finalizer() self._finalizer()
...@@ -356,9 +458,9 @@ class SyncMPClient(MPClient): ...@@ -356,9 +458,9 @@ class SyncMPClient(MPClient):
def process_outputs_socket(): def process_outputs_socket():
shutdown_socket = ctx.socket(zmq.PAIR) shutdown_socket = ctx.socket(zmq.PAIR)
shutdown_socket.bind(shutdown_path)
out_socket = make_zmq_socket(ctx, output_path, zmq.constants.PULL) out_socket = make_zmq_socket(ctx, output_path, zmq.constants.PULL)
try: try:
shutdown_socket.bind(shutdown_path)
poller = zmq.Poller() poller = zmq.Poller()
poller.register(shutdown_socket) poller.register(shutdown_socket)
poller.register(out_socket) poller.register(out_socket)
...@@ -370,7 +472,7 @@ class SyncMPClient(MPClient): ...@@ -370,7 +472,7 @@ class SyncMPClient(MPClient):
# shutdown signal, exit thread. # shutdown signal, exit thread.
break break
(frame, ) = out_socket.recv_multipart(copy=False) frame = out_socket.recv(copy=False)
outputs = decoder.decode(frame.buffer) outputs = decoder.decode(frame.buffer)
if outputs.utility_output: if outputs.utility_output:
_process_utility_output(outputs.utility_output, _process_utility_output(outputs.utility_output,
...@@ -391,18 +493,15 @@ class SyncMPClient(MPClient): ...@@ -391,18 +493,15 @@ class SyncMPClient(MPClient):
def get_output(self) -> EngineCoreOutputs: def get_output(self) -> EngineCoreOutputs:
return self.outputs_queue.get() return self.outputs_queue.get()
def _send_input(self, request_type: EngineCoreRequestType, def _send_input(self, request_type: EngineCoreRequestType, request: Any):
request: Any) -> None:
# (RequestType, SerializedRequest) # (RequestType, SerializedRequest)
msg = (request_type.value, self.encoder.encode(request)) msg = (request_type.value, self.encoder.encode(request))
self.input_socket.send_multipart(msg, copy=False) self.core_engine.send_multipart(msg)
def _call_utility(self, method: str, *args) -> Any: def call_utility(self, method: str, *args) -> Any:
call_id = uuid.uuid1().int >> 64 call_id = uuid.uuid1().int >> 64
future: Future[Any] = Future() future: Future[Any] = Future()
self.utility_results[call_id] = future self.utility_results[call_id] = future
self._send_input(EngineCoreRequestType.UTILITY, self._send_input(EngineCoreRequestType.UTILITY,
(call_id, method, args)) (call_id, method, args))
...@@ -419,34 +518,48 @@ class SyncMPClient(MPClient): ...@@ -419,34 +518,48 @@ class SyncMPClient(MPClient):
self._send_input(EngineCoreRequestType.ABORT, request_ids) self._send_input(EngineCoreRequestType.ABORT, request_ids)
def profile(self, is_start: bool = True) -> None: def profile(self, is_start: bool = True) -> None:
self._call_utility("profile", is_start) self.call_utility("profile", is_start)
def reset_prefix_cache(self) -> None: def reset_prefix_cache(self) -> None:
self._call_utility("reset_prefix_cache") self.call_utility("reset_prefix_cache")
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
return self._call_utility("add_lora", lora_request) return self.call_utility("add_lora", lora_request)
def remove_lora(self, lora_id: int) -> bool: def remove_lora(self, lora_id: int) -> bool:
return self._call_utility("remove_lora", lora_id) return self.call_utility("remove_lora", lora_id)
def list_loras(self) -> set[int]: def list_loras(self) -> set[int]:
return self._call_utility("list_loras") return self.call_utility("list_loras")
def pin_lora(self, lora_id: int) -> bool: def pin_lora(self, lora_id: int) -> bool:
return self._call_utility("pin_lora", lora_id) return self.call_utility("pin_lora", lora_id)
def sleep(self, level: int = 1) -> None: def sleep(self, level: int = 1) -> None:
self._call_utility("sleep", level) self.call_utility("sleep", level)
def wake_up(self) -> None: def wake_up(self, tags: Optional[list[str]] = None) -> None:
self._call_utility("wake_up") self.call_utility("wake_up", tags)
def is_sleeping(self) -> bool: def is_sleeping(self) -> bool:
return self._call_utility("is_sleeping") return self.call_utility("is_sleeping")
def execute_dummy_batch(self) -> None: def execute_dummy_batch(self) -> None:
self._call_utility("execute_dummy_batch") self.call_utility("execute_dummy_batch")
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return self.call_utility("collective_rpc", method, timeout, args,
kwargs)
def save_sharded_state(self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None) -> None:
self.call_utility("save_sharded_state", path, pattern, max_size)
class AsyncMPClient(MPClient): class AsyncMPClient(MPClient):
...@@ -464,13 +577,21 @@ class AsyncMPClient(MPClient): ...@@ -464,13 +577,21 @@ class AsyncMPClient(MPClient):
self.outputs_queue: Optional[asyncio.Queue[EngineCoreOutputs]] = None self.outputs_queue: Optional[asyncio.Queue[EngineCoreOutputs]] = None
self.queue_task: Optional[asyncio.Task] = None self.queue_task: Optional[asyncio.Task] = None
async def _start_output_queue_task(self): self.outputs_handler: Optional[Callable[
[AsyncMPClient, EngineCoreOutputs], Awaitable[None]]] = None
def _ensure_output_queue_task(self):
if self.outputs_queue is not None:
return
# Perform IO in separate task to parallelize as much as possible. # Perform IO in separate task to parallelize as much as possible.
# Avoid task having direct reference back to the client. # Avoid task having direct reference back to the client.
self.outputs_queue = asyncio.Queue() self.outputs_queue = asyncio.Queue()
decoder = self.decoder decoder = self.decoder
utility_results = self.utility_results utility_results = self.utility_results
outputs_queue = self.outputs_queue outputs_queue = self.outputs_queue
output_handler = self.outputs_handler
_self_ref = weakref.ref(self) if output_handler else None
output_path = self.output_path output_path = self.output_path
output_socket = make_zmq_socket(self.ctx, output_path, output_socket = make_zmq_socket(self.ctx, output_path,
zmq.constants.PULL) zmq.constants.PULL)
...@@ -483,34 +604,52 @@ class AsyncMPClient(MPClient): ...@@ -483,34 +604,52 @@ class AsyncMPClient(MPClient):
if outputs.utility_output: if outputs.utility_output:
_process_utility_output(outputs.utility_output, _process_utility_output(outputs.utility_output,
utility_results) utility_results)
else: continue
if output_handler is not None:
assert _self_ref is not None
_self = _self_ref()
if not _self:
# Client has been garbage collected, abort.
return
await output_handler(_self, outputs)
if outputs.outputs or outputs.scheduler_stats:
outputs_queue.put_nowait(outputs) outputs_queue.put_nowait(outputs)
self.queue_task = asyncio.create_task(process_outputs_socket(), self.queue_task = asyncio.create_task(process_outputs_socket(),
name="EngineCoreOutputQueueTask") name="EngineCoreOutputQueueTask")
async def get_output_async(self) -> EngineCoreOutputs: async def get_output_async(self) -> EngineCoreOutputs:
if self.outputs_queue is None: self._ensure_output_queue_task()
await self._start_output_queue_task() assert self.outputs_queue is not None
assert self.outputs_queue is not None
return await self.outputs_queue.get() return await self.outputs_queue.get()
async def _send_input(self, request_type: EngineCoreRequestType, async def _send_input(self, request_type: EngineCoreRequestType,
request: Any) -> None: request: Any) -> None:
await self.core_engine.send_multipart(
(request_type.value, self.encoder.encode(request)))
msg = (request_type.value, self.encoder.encode(request)) self._ensure_output_queue_task()
await self.input_socket.send_multipart(msg, copy=False)
if self.outputs_queue is None: async def call_utility_async(self, method: str, *args) -> Any:
await self._start_output_queue_task() return await self._call_utility_async(method,
*args,
engine=self.core_engine)
async def _call_utility_async(self, method: str, *args) -> Any: async def _call_utility_async(
self,
method: str,
*args,
engine: CoreEngine,
) -> Any:
call_id = uuid.uuid1().int >> 64 call_id = uuid.uuid1().int >> 64
future = asyncio.get_running_loop().create_future() future = asyncio.get_running_loop().create_future()
self.utility_results[call_id] = future self.utility_results[call_id] = future
await self._send_input(EngineCoreRequestType.UTILITY, message = (EngineCoreRequestType.UTILITY.value,
(call_id, method, args)) self.encoder.encode((call_id, method, args)))
await engine.send_multipart(message)
self._ensure_output_queue_task()
return await future return await future
async def add_request_async(self, request: EngineCoreRequest) -> None: async def add_request_async(self, request: EngineCoreRequest) -> None:
...@@ -524,31 +663,162 @@ class AsyncMPClient(MPClient): ...@@ -524,31 +663,162 @@ class AsyncMPClient(MPClient):
await self._send_input(EngineCoreRequestType.ABORT, request_ids) await self._send_input(EngineCoreRequestType.ABORT, request_ids)
async def profile_async(self, is_start: bool = True) -> None: async def profile_async(self, is_start: bool = True) -> None:
await self._call_utility_async("profile", is_start) await self.call_utility_async("profile", is_start)
async def reset_prefix_cache_async(self) -> None: async def reset_prefix_cache_async(self) -> None:
await self._call_utility_async("reset_prefix_cache") await self.call_utility_async("reset_prefix_cache")
async def sleep_async(self, level: int = 1) -> None: async def sleep_async(self, level: int = 1) -> None:
await self._call_utility_async("sleep", level) await self.call_utility_async("sleep", level)
async def wake_up_async(self) -> None: async def wake_up_async(self, tags: Optional[list[str]] = None) -> None:
await self._call_utility_async("wake_up") await self.call_utility_async("wake_up", tags)
async def is_sleeping_async(self) -> bool: async def is_sleeping_async(self) -> bool:
return await self._call_utility_async("is_sleeping") return await self.call_utility_async("is_sleeping")
async def execute_dummy_batch_async(self) -> None: async def execute_dummy_batch_async(self) -> None:
await self._call_utility_async("execute_dummy_batch") await self.call_utility_async("execute_dummy_batch")
async def add_lora_async(self, lora_request: LoRARequest) -> bool: async def add_lora_async(self, lora_request: LoRARequest) -> bool:
return await self._call_utility_async("add_lora", lora_request) return await self.call_utility_async("add_lora", lora_request)
async def remove_lora_async(self, lora_id: int) -> bool: async def remove_lora_async(self, lora_id: int) -> bool:
return await self._call_utility_async("remove_lora", lora_id) return await self.call_utility_async("remove_lora", lora_id)
async def list_loras_async(self) -> set[int]: async def list_loras_async(self) -> set[int]:
return await self._call_utility_async("list_loras") return await self.call_utility_async("list_loras")
async def pin_lora_async(self, lora_id: int) -> bool: async def pin_lora_async(self, lora_id: int) -> bool:
return await self._call_utility_async("pin_lora", lora_id) return await self.call_utility_async("pin_lora", lora_id)
async def save_sharded_state_async(self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None) -> None:
await self.call_utility_async("save_sharded_state", path, pattern,
max_size)
async def collective_rpc_async(
self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return await self.call_utility_async("collective_rpc", method, timeout,
args, kwargs)
class DPAsyncMPClient(AsyncMPClient):
"""Asyncio-compatible client for multi-proc, multi-engine (data parallel)
EngineCore."""
def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
log_stats: bool):
super().__init__(vllm_config, executor_class, log_stats)
assert len(self.core_engines) > 1
# Control message used for triggering dp idle mode loop.
self.start_dp_msg = (EngineCoreRequestType.START_DP.value,
self.encoder.encode(None))
self.num_engines_running = 0
self.reqs_in_flight: dict[str, CoreEngine] = {}
self.outputs_handler = DPAsyncMPClient.process_engine_outputs # type: ignore[assignment]
def _init_core_engines(
self,
vllm_config: VllmConfig,
new_core_engine: Callable[[int, Optional[int]], CoreEngine],
core_engines: list[CoreEngine],
) -> None:
# Launch a core engine for each data parallel rank.
dp_size = vllm_config.parallel_config.data_parallel_size
for i in range(dp_size):
# Multi-node not yet supported so local_dp_rank == dp_rank.
core_engines.append(new_core_engine(i, i))
self.core_engines = core_engines
async def call_utility_async(self, method: str, *args) -> Any:
# Only the result from the first engine is returned.
return (await asyncio.gather(*[
self._call_utility_async(method, *args, engine=engine)
for engine in self.core_engines
]))[0]
async def add_request_async(self, request: EngineCoreRequest) -> None:
# NOTE: text prompt is not needed in the core engine as it has been
# tokenized.
request.prompt = None
msg = (EngineCoreRequestType.ADD.value, self.encoder.encode(request))
chosen_engine = self.get_core_engine_for_request()
self.reqs_in_flight[request.request_id] = chosen_engine
chosen_engine.num_reqs_in_flight += 1
if self.num_engines_running >= len(self.core_engines):
await chosen_engine.send_multipart(msg)
else:
# Send request to chosen engine and dp start loop
# control message to all other engines.
self.num_engines_running += len(self.core_engines)
await asyncio.gather(*[
engine.send_multipart(msg if engine is
chosen_engine else self.start_dp_msg)
for engine in self.core_engines
])
self._ensure_output_queue_task()
def get_core_engine_for_request(self) -> CoreEngine:
return min(self.core_engines, key=lambda e: e.num_reqs_in_flight)
@staticmethod
async def process_engine_outputs(self: "DPAsyncMPClient",
outputs: EngineCoreOutputs):
if self.reqs_in_flight:
for req_id in outputs.finished_requests or ():
if engine := self.reqs_in_flight.pop(req_id, None):
engine.num_reqs_in_flight -= 1
if outputs.engine_paused:
assert self.num_engines_running >= 1
self.num_engines_running -= 1
if not self.num_engines_running and self.reqs_in_flight:
# If there are requests in flight here, they must have
# been sent after the engines paused. We must make
# sure to start the other engines:
self.num_engines_running = len(self.core_engines)
coros = [
engine.send_multipart(self.start_dp_msg)
for engine in self.core_engines
if not engine.num_reqs_in_flight
]
if coros:
await asyncio.gather(*coros)
async def abort_requests_async(self, request_ids: list[str]) -> None:
if not request_ids:
return
if len(request_ids) == 1:
# Fast-path common case.
if engine := self.reqs_in_flight.get(request_ids[0]):
await self._abort_requests(request_ids, engine)
return
by_engine: dict[CoreEngine, list[str]] = {}
for req_id in request_ids:
if engine := self.reqs_in_flight.get(req_id):
by_engine.setdefault(engine, []).append(req_id)
for engine, req_ids in by_engine.items():
await self._abort_requests(req_ids, engine)
async def _abort_requests(self, request_ids: list[str],
engine: CoreEngine) -> None:
await engine.send_multipart((EngineCoreRequestType.ABORT.value,
self.encoder.encode(request_ids)))
...@@ -2,15 +2,16 @@ ...@@ -2,15 +2,16 @@
from collections.abc import Mapping from collections.abc import Mapping
from copy import copy from copy import copy
from typing import Optional, Union from typing import Any, Callable, Optional, Union
from typing_extensions import TypeVar from typing_extensions import TypeVar
import vllm.envs as envs import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics_types import StatLoggerBase from vllm.engine.metrics_types import StatLoggerBase
from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType from vllm.inputs import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
...@@ -31,6 +32,7 @@ from vllm.v1.executor.abstract import Executor ...@@ -31,6 +32,7 @@ from vllm.v1.executor.abstract import Executor
logger = init_logger(__name__) logger = init_logger(__name__)
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
_R = TypeVar("_R", default=Any)
class LLMEngine: class LLMEngine:
...@@ -43,7 +45,6 @@ class LLMEngine: ...@@ -43,7 +45,6 @@ class LLMEngine:
log_stats: bool, log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[dict[str, StatLoggerBase]] = None, stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
use_cached_outputs: bool = False, use_cached_outputs: bool = False,
multiprocess_mode: bool = False, multiprocess_mode: bool = False,
...@@ -60,11 +61,13 @@ class LLMEngine: ...@@ -60,11 +61,13 @@ class LLMEngine:
self.cache_config = vllm_config.cache_config self.cache_config = vllm_config.cache_config
# important: init dp group before init the engine_core # important: init dp group before init the engine_core
self.parallel_config = vllm_config.parallel_config # In the decoupled engine case this is handled in EngineCoreProc.
self.dp_enabled = self.parallel_config.data_parallel_size > 1 # noqa parallel_config = vllm_config.parallel_config
if not multiprocess_mode and parallel_config.data_parallel_size > 1:
self.dp_group = parallel_config.stateless_init_dp_group()
else:
self.dp_group = None
self.should_execute_dummy_batch = False self.should_execute_dummy_batch = False
if self.dp_enabled:
self.dp_group = self.parallel_config.stateless_init_dp_group()
# Tokenizer (+ ensure liveness if running in another process). # Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs( self.tokenizer = init_tokenizer_from_configs(
...@@ -77,7 +80,6 @@ class LLMEngine: ...@@ -77,7 +80,6 @@ class LLMEngine:
# Processor (convert Inputs --> EngineCoreRequests) # Processor (convert Inputs --> EngineCoreRequests)
self.processor = Processor(vllm_config=vllm_config, self.processor = Processor(vllm_config=vllm_config,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
input_registry=input_registry,
mm_registry=mm_registry) mm_registry=mm_registry)
# OutputProcessor (convert EngineCoreOutputs --> RequestOutput). # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
...@@ -148,7 +150,7 @@ class LLMEngine: ...@@ -148,7 +150,7 @@ class LLMEngine:
def has_unfinished_requests(self) -> bool: def has_unfinished_requests(self) -> bool:
has_unfinished = self.output_processor.has_unfinished_requests() has_unfinished = self.output_processor.has_unfinished_requests()
if not self.dp_enabled: if self.dp_group is None:
return has_unfinished return has_unfinished
return self.has_unfinished_requests_dp(has_unfinished) return self.has_unfinished_requests_dp(has_unfinished)
...@@ -243,8 +245,8 @@ class LLMEngine: ...@@ -243,8 +245,8 @@ class LLMEngine:
def sleep(self, level: int = 1): def sleep(self, level: int = 1):
self.engine_core.sleep(level) self.engine_core.sleep(level)
def wake_up(self): def wake_up(self, tags: Optional[list[str]] = None):
self.engine_core.wake_up() self.engine_core.wake_up(tags)
def is_sleeping(self) -> bool: def is_sleeping(self) -> bool:
return self.engine_core.is_sleeping() return self.engine_core.is_sleeping()
...@@ -280,3 +282,14 @@ class LLMEngine: ...@@ -280,3 +282,14 @@ class LLMEngine:
def pin_lora(self, lora_id: int) -> bool: def pin_lora(self, lora_id: int) -> bool:
"""Prevent an adapter from being evicted.""" """Prevent an adapter from being evicted."""
return self.engine_core.pin_lora(lora_id) return self.engine_core.pin_lora(lora_id)
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return self.engine_core.collective_rpc(method, timeout, args, kwargs)
def __del__(self):
if dp_group := getattr(self, "dp_group", None):
stateless_destroy_torch_distributed_process_group(dp_group)
...@@ -328,7 +328,7 @@ class OutputProcessor: ...@@ -328,7 +328,7 @@ class OutputProcessor:
# 2) Detokenize the token ids into text and perform stop checks. # 2) Detokenize the token ids into text and perform stop checks.
stop_string = req_state.detokenizer.update( stop_string = req_state.detokenizer.update(
new_token_ids, finish_reason == FinishReason.STOP) new_token_ids, finish_reason == FinishReason.STOP)
if stop_string and finish_reason != FinishReason.STOP: if stop_string:
finish_reason = FinishReason.STOP finish_reason = FinishReason.STOP
stop_reason = stop_string stop_reason = stop_string
......
...@@ -5,9 +5,8 @@ from collections.abc import Mapping ...@@ -5,9 +5,8 @@ from collections.abc import Mapping
from typing import Optional, Union from typing import Optional, Union
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, from vllm.inputs import ProcessorInputs, PromptType
PromptType, SingletonInputsAdapter) from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.parse import is_encoder_decoder_inputs
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
...@@ -31,7 +30,6 @@ class Processor: ...@@ -31,7 +30,6 @@ class Processor:
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
tokenizer: BaseTokenizerGroup, tokenizer: BaseTokenizerGroup,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
): ):
...@@ -123,7 +121,8 @@ class Processor: ...@@ -123,7 +121,8 @@ class Processor:
return return
supported_backends = [ supported_backends = [
"xgrammar", "xgrammar:disable-any-whitespace", "guidance", "auto" "xgrammar", "xgrammar:disable-any-whitespace", "guidance",
"guidance:disable-any-whitespace", "auto"
] ]
engine_level_backend = self.decoding_config.guided_decoding_backend engine_level_backend = self.decoding_config.guided_decoding_backend
if engine_level_backend not in supported_backends: if engine_level_backend not in supported_backends:
...@@ -137,13 +136,15 @@ class Processor: ...@@ -137,13 +136,15 @@ class Processor:
f" != {engine_level_backend}") f" != {engine_level_backend}")
else: else:
params.guided_decoding.backend = engine_level_backend params.guided_decoding.backend = engine_level_backend
import vllm.platforms
if vllm.platforms.current_platform.is_tpu():
raise ValueError("Structured output is not supported on TPU.")
# Request content validation # Request content validation
if engine_level_backend.startswith("xgrammar"):
if engine_level_backend == "xgrammar":
# xgrammar with no fallback # xgrammar with no fallback
validate_structured_output_request_xgrammar(params) validate_structured_output_request_xgrammar(params)
params.guided_decoding.backend = "xgrammar" params.guided_decoding.backend = engine_level_backend
elif engine_level_backend == "auto": elif engine_level_backend == "auto":
# "auto" is an opt-in to opinionated behavior where we try to # "auto" is an opt-in to opinionated behavior where we try to
# choose a backend based on request contents. This is not the # choose a backend based on request contents. This is not the
...@@ -157,12 +158,13 @@ class Processor: ...@@ -157,12 +158,13 @@ class Processor:
# are not supported in xgrammar. Fall back to guidance. # are not supported in xgrammar. Fall back to guidance.
params.guided_decoding.backend = "guidance" params.guided_decoding.backend = "guidance"
if params.guided_decoding.backend == "guidance": if engine_level_backend.startswith("guidance"):
# TODO ideally we would have the LLTokenizer here as Lark syntax # TODO ideally we would have the LLTokenizer here as Lark syntax
# allows <|special_token|> and similar, see # allows <|special_token|> and similar, see
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars. # Without tokenizer these are disallowed in grammars.
validate_guidance_grammar(params, tokenizer=None) validate_guidance_grammar(params, tokenizer=None)
params.guided_decoding.backend = engine_level_backend
def process_inputs( def process_inputs(
self, self,
...@@ -206,14 +208,7 @@ class Processor: ...@@ -206,14 +208,7 @@ class Processor:
self._validate_model_inputs(processed_inputs, lora_request) self._validate_model_inputs(processed_inputs, lora_request)
if is_encoder_decoder_inputs(processed_inputs): encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
decoder_inputs = SingletonInputsAdapter(
processed_inputs["decoder"])
encoder_inputs = SingletonInputsAdapter(
processed_inputs["encoder"])
else:
decoder_inputs = SingletonInputsAdapter(processed_inputs)
encoder_inputs = None
# TODO: Impl encoder-decoder # TODO: Impl encoder-decoder
if encoder_inputs is not None: if encoder_inputs is not None:
...@@ -224,8 +219,9 @@ class Processor: ...@@ -224,8 +219,9 @@ class Processor:
sampling_params = params.clone() sampling_params = params.clone()
# If unset max tokens, then generate up to the max_model_len. # If unset max tokens, then generate up to the max_model_len.
if sampling_params.max_tokens is None: if sampling_params.max_tokens is None:
sampling_params.max_tokens = (self.model_config.max_model_len - sampling_params.max_tokens = (
len(decoder_inputs.prompt_token_ids)) self.model_config.max_model_len -
len(decoder_inputs["prompt_token_ids"]))
sampling_params.update_from_generation_config( sampling_params.update_from_generation_config(
self.generation_config_fields, eos_token_id) self.generation_config_fields, eos_token_id)
sampling_params.update_from_tokenizer( sampling_params.update_from_tokenizer(
...@@ -235,57 +231,46 @@ class Processor: ...@@ -235,57 +231,46 @@ class Processor:
sorted_mm_inputs: Optional[list[MultiModalKwargs]] = None sorted_mm_inputs: Optional[list[MultiModalKwargs]] = None
sorted_mm_positions: Optional[list[PlaceholderRange]] = None sorted_mm_positions: Optional[list[PlaceholderRange]] = None
sorted_mm_hashes: Optional[list[str]] = None sorted_mm_hashes: Optional[list[str]] = None
if (decoder_mm_inputs := decoder_inputs.multi_modal_data): if decoder_inputs["type"] == "multimodal":
assert isinstance(decoder_mm_inputs, MultiModalKwargs) decoder_mm_inputs = decoder_inputs["mm_kwargs"]
# The output of merged multi-modal processor (`decoder_mm_inputs`)
# contains the kwargs for all items from all modalities.
# This code separates them so that there is one set of kwargs
# per item per modality.
individual_mm_inputs = [
MultiModalKwargs.from_items([item])
for modality in decoder_mm_inputs.modalities
for item in decoder_mm_inputs.get_items(modality)
]
# Merge and flatten multimodal placeholders, hashes and inputs # Merge and flatten multimodal placeholders, hashes and inputs
# from dictionaries to lists, and sort them by each item's position # from dictionaries to lists, and sort them by each item's position
# in the input sequence. # in the input sequence.
# NOTE: interleaved modalities are not supported.
( (
sorted_modalities, sorted_item_modalities,
sorted_mm_positions, sorted_mm_positions,
sorted_mm_hashes, sorted_mm_hashes,
) = merge_and_sort_multimodal_metadata( ) = merge_and_sort_multimodal_metadata(
decoder_inputs.multi_modal_placeholders, decoder_inputs["mm_placeholders"],
decoder_inputs.multi_modal_hashes if self.use_hash else None, decoder_inputs["mm_hashes"] if self.use_hash else None,
) )
# NOTE: Sort multimodal inputs/kwargs ONLY IF there are multiple # The output of merged multi-modal processor (`decoder_mm_inputs`)
# modalities involved. # is a single MultiModalKwargs for all items from all modalities.
if len(sorted_modalities) > 1: # This code flattens kwargs for individual items in a list and
modality_order_dict = { # sorts them by each item's position in the input sequence if there
modality: order # are multiple modalities.
for order, modality in enumerate(sorted_modalities) unique_modalities = set(sorted_item_modalities)
} if len(unique_modalities) > 1:
sorted_mm_inputs = []
# Sanity check to make sure each multimodal input has only one used_indices = {modality: 0 for modality in unique_modalities}
# modality key. for modality in sorted_item_modalities:
for mm_input in individual_mm_inputs: items = decoder_mm_inputs.get_items(modality)
assert len(mm_input.modalities) == 1 item = items[used_indices[modality]]
sorted_mm_inputs.append(MultiModalKwargs.from_items([item
# Sort MultiModalKwargs to match sorted_mm_positions ]))
sorted_mm_inputs = sorted( used_indices[modality] += 1
individual_mm_inputs,
key=lambda mm_input: modality_order_dict[list(
mm_input.modalities)[0]])
else: else:
sorted_mm_inputs = individual_mm_inputs sorted_mm_inputs = [
MultiModalKwargs.from_items([item]) for item in
decoder_mm_inputs.get_items(sorted_item_modalities[0])
]
return EngineCoreRequest( return EngineCoreRequest(
request_id=request_id, request_id=request_id,
prompt=decoder_inputs.prompt, prompt=decoder_inputs.get("prompt"),
prompt_token_ids=decoder_inputs.prompt_token_ids, prompt_token_ids=decoder_inputs["prompt_token_ids"],
mm_inputs=sorted_mm_inputs, mm_inputs=sorted_mm_inputs,
mm_hashes=sorted_mm_hashes, mm_hashes=sorted_mm_hashes,
mm_placeholders=sorted_mm_positions, mm_placeholders=sorted_mm_positions,
...@@ -298,15 +283,16 @@ class Processor: ...@@ -298,15 +283,16 @@ class Processor:
def _validate_model_inputs(self, def _validate_model_inputs(self,
inputs: ProcessorInputs, inputs: ProcessorInputs,
lora_request: Optional[LoRARequest] = None): lora_request: Optional[LoRARequest] = None):
if is_encoder_decoder_inputs(inputs): encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length # For encoder-decoder multimodal models, the max_prompt_len
prompt_inputs = inputs["decoder" if self.model_config. # restricts the decoder prompt length
is_multimodal_model else "encoder"] if self.model_config.is_multimodal_model:
prompt_inputs = decoder_inputs
else: else:
prompt_inputs = inputs prompt_inputs = encoder_inputs or decoder_inputs
prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids prompt_ids = prompt_inputs["prompt_token_ids"]
if prompt_ids is None or len(prompt_ids) == 0: if prompt_ids is None or len(prompt_ids) == 0:
raise ValueError("Prompt cannot be empty") raise ValueError("Prompt cannot be empty")
......
...@@ -235,7 +235,10 @@ class WorkerProc: ...@@ -235,7 +235,10 @@ class WorkerProc:
worker_response_mq_handle = self.worker_response_mq.export_handle() worker_response_mq_handle = self.worker_response_mq.export_handle()
# Send Readiness signal to EngineCore process. # Send Readiness signal to EngineCore process.
with zmq_socket_ctx(ready_path, zmq.constants.PUSH) as ready_socket: # Set linger here because we want to ensure the message has
# been sent before the context is closed.
with zmq_socket_ctx(ready_path, zmq.constants.PUSH,
linger=10000) as ready_socket:
payload = pickle.dumps(worker_response_mq_handle, payload = pickle.dumps(worker_response_mq_handle,
protocol=pickle.HIGHEST_PROTOCOL) protocol=pickle.HIGHEST_PROTOCOL)
ready_socket.send_string(WorkerProc.READY_STR) ready_socket.send_string(WorkerProc.READY_STR)
...@@ -270,11 +273,13 @@ class WorkerProc: ...@@ -270,11 +273,13 @@ class WorkerProc:
proc = context.Process(target=WorkerProc.worker_main, proc = context.Process(target=WorkerProc.worker_main,
kwargs=process_kwargs, kwargs=process_kwargs,
daemon=True) daemon=True)
proc.start()
# Wait for startup with zmq_socket_ctx(ready_path, zmq.constants.PULL) as ready_socket:
worker_response_mq_handle = WorkerProc.wait_for_startup( proc.start()
proc, ready_path)
# Wait for startup
worker_response_mq_handle = WorkerProc.wait_for_startup(
proc, ready_socket)
worker_response_mq = MessageQueue.create_from_handle( worker_response_mq = MessageQueue.create_from_handle(
worker_response_mq_handle, 0) worker_response_mq_handle, 0)
...@@ -337,23 +342,22 @@ class WorkerProc: ...@@ -337,23 +342,22 @@ class WorkerProc:
@staticmethod @staticmethod
def wait_for_startup( def wait_for_startup(
proc: BaseProcess, proc: BaseProcess,
ready_path: str, ready_socket: zmq.Socket,
) -> Optional[Handle]: ) -> Optional[Handle]:
"""Wait until the Worker is ready.""" """Wait until the Worker is ready."""
with zmq_socket_ctx(ready_path, zmq.constants.PULL) as socket:
# Wait for Worker to send READY. # Wait for Worker to send READY.
while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: while ready_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
logger.debug("Waiting for WorkerProc to startup.") logger.debug("Waiting for WorkerProc to startup.")
if not proc.is_alive(): if not proc.is_alive():
raise RuntimeError("WorkerProc failed to start.") raise RuntimeError("WorkerProc failed to start.")
message = socket.recv_string() message = ready_socket.recv_string()
assert message == WorkerProc.READY_STR assert message == WorkerProc.READY_STR
handle_frame = socket.recv(copy=False) handle_frame = ready_socket.recv(copy=False)
handle = pickle.loads(handle_frame.buffer) handle = pickle.loads(handle_frame.buffer)
return handle return handle
class ResponseStatus(Enum): class ResponseStatus(Enum):
SUCCESS = auto() SUCCESS = auto()
......
...@@ -4,6 +4,7 @@ from dataclasses import dataclass ...@@ -4,6 +4,7 @@ from dataclasses import dataclass
import torch import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import cdiv, get_dtype_size from vllm.utils import cdiv, get_dtype_size
...@@ -43,28 +44,23 @@ class KVCacheSpec: ...@@ -43,28 +44,23 @@ class KVCacheSpec:
""" """
raise NotImplementedError raise NotImplementedError
def bytes_for_tokens(self, num_tokens: int) -> int: def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
""" """
The KV cache size for `num_tokens` tokens in bytes. Returns the real The maximum possible memory usage of this KV cache in bytes.
memory size after padding `num_tokens` to full blocks.
Returns: Returns:
The KV cache size The KV cache size in bytes
""" """
raise NotImplementedError raise NotImplementedError
@dataclass @dataclass
class FullAttentionSpec(KVCacheSpec): class AttentionSpec(KVCacheSpec):
num_kv_heads: int num_kv_heads: int
head_size: int head_size: int
dtype: torch.dtype dtype: torch.dtype
use_mla: bool use_mla: bool
@property
def type_id(self) -> str:
return f"full_attention_{self.block_size}_{self.page_size_bytes}"
@property @property
def page_size_bytes(self) -> int: def page_size_bytes(self) -> int:
# For MLA we only store a single latent vector # For MLA we only store a single latent vector
...@@ -72,8 +68,47 @@ class FullAttentionSpec(KVCacheSpec): ...@@ -72,8 +68,47 @@ class FullAttentionSpec(KVCacheSpec):
return coef * self.block_size * self.num_kv_heads * self.head_size \ return coef * self.block_size * self.num_kv_heads * self.head_size \
* get_dtype_size(self.dtype) * get_dtype_size(self.dtype)
def bytes_for_tokens(self, num_tokens: int) -> int:
return cdiv(num_tokens, self.block_size) * self.page_size_bytes @dataclass
class FullAttentionSpec(AttentionSpec):
@property
def type_id(self) -> str:
return f"full_attention_{self.block_size}_{self.page_size_bytes}"
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
max_model_len = vllm_config.model_config.max_model_len
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
@dataclass
class SlidingWindowSpec(AttentionSpec):
sliding_window: int
def __post_init__(self):
assert not self.use_mla, "MLA is not supported for sliding window"
@property
def type_id(self) -> str:
return f"sliding_window_{self.sliding_window}_{self.block_size}_{self.page_size_bytes}" # noqa
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
max_model_len = vllm_config.model_config.max_model_len
max_num_batched_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens)
# During chunked prefill, we allocate KV cache for the last
# `self.sliding_window-1` computed tokens plus the newly scheduled
# tokens. And we won't allocate KV cache for more than `max_model_len`
# tokens.
num_tokens = min(self.sliding_window - 1 + max_num_batched_tokens,
max_model_len)
# +1 here because the sliding window may not start from the beginning
# of the block. For example, if the block size is 4 and num_token
# is 4, we need two blocks [XXCD] [EF] to store the sliding
# window [CDEF] of 6 tokens.
return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes
@dataclass @dataclass
......
...@@ -12,6 +12,7 @@ from vllm.logger import init_logger ...@@ -12,6 +12,7 @@ from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
from vllm.v1.engine import FinishReason from vllm.v1.engine import FinishReason
from vllm.v1.metrics.stats import IterationStats, SchedulerStats from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from vllm.v1.spec_decode.metrics import SpecDecodingMetrics
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -31,12 +32,14 @@ class StatLoggerBase(ABC): ...@@ -31,12 +32,14 @@ class StatLoggerBase(ABC):
class LoggingStatLogger(StatLoggerBase): class LoggingStatLogger(StatLoggerBase):
def __init__(self): def __init__(self, engine_index: int = 0):
self.engine_index = engine_index
self._reset(time.monotonic()) self._reset(time.monotonic())
self.last_scheduler_stats = SchedulerStats() self.last_scheduler_stats = SchedulerStats()
# Prefix cache metrics. This cannot be reset. # Prefix cache metrics. This cannot be reset.
# TODO: Make the interval configurable. # TODO: Make the interval configurable.
self.prefix_caching_metrics = PrefixCachingMetrics() self.prefix_caching_metrics = PrefixCachingMetrics()
self.spec_decoding_metrics = SpecDecodingMetrics()
def _reset(self, now): def _reset(self, now):
self.last_log_time = now self.last_log_time = now
...@@ -64,6 +67,10 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -64,6 +67,10 @@ class LoggingStatLogger(StatLoggerBase):
self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats) self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)
if scheduler_stats.spec_decoding_stats is not None:
self.spec_decoding_metrics.observe(
scheduler_stats.spec_decoding_stats)
self.last_scheduler_stats = scheduler_stats self.last_scheduler_stats = scheduler_stats
def log(self): def log(self):
...@@ -78,11 +85,13 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -78,11 +85,13 @@ class LoggingStatLogger(StatLoggerBase):
# Format and print output. # Format and print output.
logger.info( logger.info(
"Engine %03d: "
"Avg prompt throughput: %.1f tokens/s, " "Avg prompt throughput: %.1f tokens/s, "
"Avg generation throughput: %.1f tokens/s, " "Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Waiting: %d reqs, " "Running: %d reqs, Waiting: %d reqs, "
"GPU KV cache usage: %.1f%%, " "GPU KV cache usage: %.1f%%, "
"Prefix cache hit rate: %.1f%%", "Prefix cache hit rate: %.1f%%",
self.engine_index,
prompt_throughput, prompt_throughput,
generation_throughput, generation_throughput,
scheduler_stats.num_running_reqs, scheduler_stats.num_running_reqs,
...@@ -91,10 +100,13 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -91,10 +100,13 @@ class LoggingStatLogger(StatLoggerBase):
self.prefix_caching_metrics.hit_rate * 100, self.prefix_caching_metrics.hit_rate * 100,
) )
if scheduler_stats.spec_decoding_stats is not None:
self.spec_decoding_metrics.log()
class PrometheusStatLogger(StatLoggerBase): class PrometheusStatLogger(StatLoggerBase):
def __init__(self, vllm_config: VllmConfig): def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
self._unregister_vllm_metrics() self._unregister_vllm_metrics()
# Use this flag to hide metrics that were deprecated in # Use this flag to hide metrics that were deprecated in
...@@ -102,8 +114,11 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -102,8 +114,11 @@ class PrometheusStatLogger(StatLoggerBase):
self.show_hidden_metrics = \ self.show_hidden_metrics = \
vllm_config.observability_config.show_hidden_metrics vllm_config.observability_config.show_hidden_metrics
labelnames = ["model_name"] labelnames = ["model_name", "engine"]
labelvalues = [vllm_config.model_config.served_model_name] labelvalues = [
vllm_config.model_config.served_model_name,
str(engine_index)
]
max_model_len = vllm_config.model_config.max_model_len max_model_len = vllm_config.model_config.max_model_len
...@@ -296,6 +311,24 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -296,6 +311,24 @@ class PrometheusStatLogger(StatLoggerBase):
self.labelname_running_lora_adapters, self.labelname_running_lora_adapters,
]) ])
#
# Speculative Decoding metrics
# The acceptance rate can be calculated using a PromQL query:
#
# rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
# rate(vllm:spec_decode_num_draft_tokens_total[$interval])
#
self.counter_spec_decode_num_draft_tokens = \
prometheus_client.Counter(
name="vllm:spec_decode_num_draft_tokens_total",
documentation="Number of draft tokens.",
labelnames=labelnames).labels(*labelvalues)
self.counter_spec_decode_num_accepted_tokens = \
prometheus_client.Counter(
name="vllm:spec_decode_num_accepted_tokens_total",
documentation="Number of accepted tokens.",
labelnames=labelnames).labels(*labelvalues)
# #
# Cache config info metric # Cache config info metric
# #
...@@ -332,6 +365,12 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -332,6 +365,12 @@ class PrometheusStatLogger(StatLoggerBase):
self.counter_gpu_prefix_cache_hits.inc( self.counter_gpu_prefix_cache_hits.inc(
scheduler_stats.prefix_cache_stats.hits) scheduler_stats.prefix_cache_stats.hits)
if scheduler_stats.spec_decoding_stats is not None:
self.counter_spec_decode_num_draft_tokens.inc(
scheduler_stats.spec_decoding_stats.num_draft_tokens)
self.counter_spec_decode_num_accepted_tokens.inc(
scheduler_stats.spec_decoding_stats.num_accepted_tokens)
if iteration_stats is None: if iteration_stats is None:
return return
......
...@@ -4,6 +4,8 @@ import time ...@@ -4,6 +4,8 @@ import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from vllm.v1.spec_decode.metrics import SpecDecodingStats
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
from vllm.v1.engine.output_processor import RequestState from vllm.v1.engine.output_processor import RequestState
...@@ -35,6 +37,8 @@ class SchedulerStats: ...@@ -35,6 +37,8 @@ class SchedulerStats:
prefix_cache_stats: PrefixCacheStats = field( prefix_cache_stats: PrefixCacheStats = field(
default_factory=PrefixCacheStats) default_factory=PrefixCacheStats)
spec_decoding_stats: Optional[SpecDecodingStats] = None
@dataclass @dataclass
class LoRAStats: class LoRAStats:
......
...@@ -59,6 +59,8 @@ class Request: ...@@ -59,6 +59,8 @@ class Request:
self.mm_positions = multi_modal_placeholders or [] self.mm_positions = multi_modal_placeholders or []
self.mm_inputs = multi_modal_inputs or [] self.mm_inputs = multi_modal_inputs or []
self.mm_hashes: list[str] = multi_modal_hashes or [] self.mm_hashes: list[str] = multi_modal_hashes or []
self.num_encoder_inputs = len(self.mm_inputs)
self.has_encoder_inputs = self.num_encoder_inputs > 0
# Sanity check # Sanity check
assert len(self.mm_inputs) == len(self.mm_positions) assert len(self.mm_inputs) == len(self.mm_positions)
...@@ -93,9 +95,11 @@ class Request: ...@@ -93,9 +95,11 @@ class Request:
token_ids: Union[int, list[int]], token_ids: Union[int, list[int]],
) -> None: ) -> None:
if isinstance(token_ids, int): if isinstance(token_ids, int):
token_ids = [token_ids] self._output_token_ids.append(token_ids)
self._output_token_ids.extend(token_ids) self._all_token_ids.append(token_ids)
self._all_token_ids.extend(token_ids) else:
self._output_token_ids.extend(token_ids)
self._all_token_ids.extend(token_ids)
@property @property
def num_tokens(self) -> int: def num_tokens(self) -> int:
...@@ -115,13 +119,6 @@ class Request: ...@@ -115,13 +119,6 @@ class Request:
def get_finished_reason(self) -> Union[FinishReason, None]: def get_finished_reason(self) -> Union[FinishReason, None]:
return RequestStatus.get_finished_reason(self.status) return RequestStatus.get_finished_reason(self.status)
def has_encoder_inputs(self) -> bool:
return len(self.mm_inputs) > 0
@property
def num_encoder_inputs(self) -> int:
return len(self.mm_positions)
def get_num_encoder_tokens(self, input_id: int) -> int: def get_num_encoder_tokens(self, input_id: int) -> int:
assert input_id < len(self.mm_positions) assert input_id < len(self.mm_positions)
num_tokens = self.mm_positions[input_id]["length"] num_tokens = self.mm_positions[input_id]["length"]
......
...@@ -19,6 +19,12 @@ except ImportError: ...@@ -19,6 +19,12 @@ except ImportError:
class TopKTopPSampler(nn.Module): class TopKTopPSampler(nn.Module):
"""
Module that performs optional top-k and top-p filtering followed by
weighted random sampling of logits.
Implementations may update the logits tensor in-place.
"""
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -84,7 +90,11 @@ class TopKTopPSampler(nn.Module): ...@@ -84,7 +90,11 @@ class TopKTopPSampler(nn.Module):
k: Optional[torch.Tensor], k: Optional[torch.Tensor],
p: Optional[torch.Tensor], p: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
"""PyTorch-native implementation of top-k and top-p sampling.""" """
PyTorch-native implementation of top-k and top-p sampling.
The logits tensor may be updated in-place.
"""
logits = apply_top_k_top_p(logits, k, p) logits = apply_top_k_top_p(logits, k, p)
probs = logits.softmax(dim=-1, dtype=torch.float32) probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators) return random_sample(probs, generators)
...@@ -112,23 +122,48 @@ class TopKTopPSampler(nn.Module): ...@@ -112,23 +122,48 @@ class TopKTopPSampler(nn.Module):
k: Optional[torch.Tensor], k: Optional[torch.Tensor],
p: Optional[torch.Tensor], p: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
# If only top-k is specified, use pytorch's builtin topk op. This leads logits = apply_top_k_top_p_tpu(logits, k, p)
# to significant speed up on TPU compared to using apply_top_k_top_p.
if k is not None and p is None:
topk_values, topk_indices = torch.topk(logits, k, dim=-1)
mask = torch.ones_like(logits, dtype=torch.bool)
mask.scatter_(-1, topk_indices, False)
logits.masked_fill_(mask, float('-inf'))
else:
# TODO Placeholder for TPU optimized topp kernel
# logits = apply_top_k_top_p(logits, k, p)
pass
probs = logits.softmax(dim=-1, dtype=torch.float32) probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators) return random_sample(probs, generators)
def apply_top_k_top_p_tpu(
logits: torch.Tensor,
k: torch.Tensor,
p: torch.Tensor,
) -> torch.Tensor:
"""
Apply top-k and top-p optimized for TPU.
This algorithm avoids using torch.scatter which is extremely slow on TPU.
This is achieved by finding a "cut-off" element in the original logit, and
after thresholding the logit using this cut-off, the remaining elements
shall constitute the top-p set.
Note: in the case of tie (i.e. multipple cut-off elements present in the
logit), all tie elements are included in the top-p set. In other words,
this function does not break ties. Instead, these tie tokens have equal
chance of being chosen during final sampling, so we can consider the tie
being broken then.
"""
if k is not None:
logits = apply_top_k_only(logits, k)
if p is not None:
probs = logits.softmax(dim=-1)
probs_sort, _ = probs.sort(dim=-1, descending=False)
cumprob = torch.cumsum(probs_sort, dim=-1)
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
top_p_mask[:, -1] = False # at least one
top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
top_p_cutoff = probs_sort.gather(-1, top_p_count)
elements_to_discard = probs < top_p_cutoff
logits.masked_fill_(elements_to_discard, -float("inf"))
return logits
def apply_top_k_top_p( def apply_top_k_top_p(
logits: torch.Tensor, logits: torch.Tensor,
k: Optional[torch.Tensor], k: Optional[torch.Tensor],
...@@ -136,10 +171,18 @@ def apply_top_k_top_p( ...@@ -136,10 +171,18 @@ def apply_top_k_top_p(
) -> torch.Tensor: ) -> torch.Tensor:
"""Apply top-k and top-p masks to the logits. """Apply top-k and top-p masks to the logits.
This function sorts the logits tensor, which can be slow for large batches. If a top-p is used, this function will sort the logits tensor,
which can be slow for large batches.
The logits tensor may be updated in-place.
""" """
if k is None and p is None: if p is None:
return logits if k is None:
return logits
# Avoid sorting vocab for top-k only case.
return apply_top_k_only(logits, k)
logits_sort, logits_idx = logits.sort(dim=-1, descending=False) logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
if k is not None: if k is not None:
...@@ -153,7 +196,7 @@ def apply_top_k_top_p( ...@@ -153,7 +196,7 @@ def apply_top_k_top_p(
if p is not None: if p is not None:
# Apply top-p. # Apply top-p.
probs_sort = logits_sort.softmax(dim=-1) probs_sort = logits_sort.softmax(dim=-1)
probs_sum = probs_sort.cumsum(dim=-1) probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
# at least one # at least one
top_p_mask[:, -1] = False top_p_mask[:, -1] = False
...@@ -164,6 +207,31 @@ def apply_top_k_top_p( ...@@ -164,6 +207,31 @@ def apply_top_k_top_p(
return logits return logits
def apply_top_k_only(
logits: torch.Tensor,
k: torch.Tensor,
) -> torch.Tensor:
"""
Apply top-k mask to the logits.
This implementation doesn't involve sorting the entire vocab.
The logits tensor may be updated in-place.
"""
no_top_k_mask = k == logits.shape[1]
# Set non-top-k rows to 1 so that we can gather.
k = k.masked_fill(no_top_k_mask, 1)
max_top_k = k.max()
# topk.values tensor has shape [batch_size, max_top_k].
# Convert top k to 0-based index in range [0, max_top_k).
k_index = k.sub_(1).unsqueeze(1).expand(logits.shape[0], 1)
top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
# Handle non-topk rows.
top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
logits.masked_fill_(logits < top_k_mask, -float("inf"))
return logits
def random_sample( def random_sample(
probs: torch.Tensor, probs: torch.Tensor,
generators: dict[int, torch.Generator], generators: dict[int, torch.Generator],
......
...@@ -109,6 +109,18 @@ class RejectionSampler(nn.Module): ...@@ -109,6 +109,18 @@ class RejectionSampler(nn.Module):
output_token_ids: torch.Tensor, output_token_ids: torch.Tensor,
vocab_size: int, vocab_size: int,
) -> list[list[int]]: ) -> list[list[int]]:
"""Parse the output of the rejection sampler.
Args:
output_token_ids: The sampled token IDs in shape
[batch_size, max_spec_len + 1]. The rejected tokens are
replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
and will be filtered out in this function.
vocab_size: The size of the vocabulary.
Returns:
A list of lists of token IDs.
"""
output_token_ids_np = output_token_ids.cpu().numpy() output_token_ids_np = output_token_ids.cpu().numpy()
# Create mask for valid tokens. # Create mask for valid tokens.
valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) & valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
......
...@@ -87,6 +87,12 @@ class Sampler(nn.Module): ...@@ -87,6 +87,12 @@ class Sampler(nn.Module):
logits: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
"""Sample logits based on sampling metadata.
The various logits processing functions called in this method
may update the logits tensor in-place.
"""
assert not (sampling_metadata.all_greedy assert not (sampling_metadata.all_greedy
and sampling_metadata.all_random) and sampling_metadata.all_random)
if sampling_metadata.all_random: if sampling_metadata.all_random:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment