Commit fcfc474d authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.3' into v0.8.3-dev

parents bb94d2e5 296c6572
......@@ -5,10 +5,12 @@ from collections.abc import Iterable
from typing import Optional
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.utils import cdiv, sha256
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
hash_request_tokens)
from vllm.v1.core.specialized_manager import get_specialized_manager
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request, RequestStatus
......@@ -19,20 +21,24 @@ class KVCacheManager:
def __init__(
self,
block_size: int,
num_gpu_blocks: int,
kv_cache_config: KVCacheConfig,
max_model_len: int,
sliding_window: Optional[int] = None,
enable_caching: bool = True,
caching_hash_algo: str = "builtin",
num_preallocate_tokens: int = 64,
log_stats: bool = False,
) -> None:
self.block_size = block_size
self.num_gpu_blocks = num_gpu_blocks
assert len(kv_cache_config.kv_cache_groups) == 1, (
"KVCacheManager does not support hybrid models with more than 1 "
"kv cache group")
kv_cache_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec
self.block_size = kv_cache_spec.block_size
self.num_gpu_blocks = kv_cache_config.num_blocks
self.max_model_len = max_model_len
self.max_num_blocks_per_req = cdiv(max_model_len, block_size)
self.sliding_window = sliding_window
self.max_num_blocks_per_req = cdiv(max_model_len, self.block_size)
self.enable_caching = enable_caching
self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
# FIXME: make prefix cache stats conditional on log_stats
self.log_stats = log_stats
# NOTE(woosuk): To avoid frequent block allocation, we preallocate some
......@@ -46,9 +52,15 @@ class KVCacheManager:
# further allocation. When it uses up all the N empty blocks, it gets
# N new empty blocks.
self.num_preallocate_tokens = num_preallocate_tokens
self.num_preallocate_blocks = cdiv(num_preallocate_tokens, block_size)
self.num_preallocate_blocks = cdiv(num_preallocate_tokens,
self.block_size)
self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching)
self.block_pool = BlockPool(num_gpu_blocks, enable_caching)
self.specialized_manager = get_specialized_manager(
kv_cache_spec=kv_cache_spec,
block_pool=self.block_pool,
)
# Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request
......@@ -109,22 +121,31 @@ class KVCacheManager:
# if the scheduler has tried to schedule the request before.
block_hashes = self.req_to_block_hashes[request.request_id]
if not block_hashes:
block_hashes = hash_request_tokens(self.block_size, request)
block_hashes = hash_request_tokens(self.caching_hash_fn,
self.block_size, request)
self.req_to_block_hashes[request.request_id] = block_hashes
self.prefix_cache_stats.requests += 1
if request.sampling_params.prompt_logprobs is None:
# Check for cache hits
computed_blocks = []
for block_hash in block_hashes:
# block_hashes is a chain of block hashes. If a block hash
# is not in the cached_block_hash_to_id, the following
# block hashes are not computed yet for sure.
if cached_block := self.block_pool.get_cached_block(
block_hash):
computed_blocks.append(cached_block)
if len(block_hashes) * self.block_size == request.num_tokens:
# When prompt length is divisible by the block size and all
# blocks are cached, we need to recompute the last token. This
# have to be achieved by re-computing an entire block because
# allocate_slots() assumes num_computed_tokens is always a
# multiple of the block size. To achieve this, remove the last
# block hash from the block_hashes for find_longest_cache_hit
# This limitation can potentially be removed in the future to
# slightly improve the performance.
last_block_hash = block_hashes.pop()
else:
break
last_block_hash = None
computed_blocks = (
self.specialized_manager.find_longest_cache_hit(block_hashes))
if last_block_hash is not None:
# Add back the last block hash if it was removed.
block_hashes.append(last_block_hash)
self.prefix_cache_stats.queries += len(block_hashes)
self.prefix_cache_stats.hits += len(computed_blocks)
......@@ -173,13 +194,24 @@ class KVCacheManager:
new_computed_blocks = new_computed_blocks or []
req_blocks = self.req_to_blocks[request.request_id]
# Free the blocks that are skipped during the attention computation
# (e.g., tokens outside the sliding window).
# We can do this even if we cannot schedule this request due to
# insufficient free blocks.
# Should call this function before allocating new blocks to reduce
# the number of evicted blocks.
removed_blocks = self.specialized_manager.remove_skipped_blocks(
req_blocks, request.num_computed_tokens)
self.block_pool.free_blocks(removed_blocks)
# The number of computed tokens is the number of computed tokens plus
# the new prefix caching hits
num_computed_tokens = (request.num_computed_tokens +
len(new_computed_blocks) * self.block_size)
num_required_blocks = cdiv(num_computed_tokens + num_tokens,
self.block_size)
req_blocks = self.req_to_blocks[request.request_id]
num_new_blocks = (num_required_blocks - len(req_blocks) -
len(new_computed_blocks))
......@@ -247,6 +279,7 @@ class KVCacheManager:
num_cached_blocks=num_cached_blocks,
num_full_blocks=num_full_blocks_after_append,
block_size=self.block_size,
hash_fn=self.caching_hash_fn,
)
self.num_cached_block[
......
# SPDX-License-Identifier: Apache-2.0
"""KV-Cache Utilities."""
import os
from collections import deque
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, NamedTuple, Optional
from typing import Any, Callable, NamedTuple, Optional
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheGroupSpec,
KVCacheSpec, KVCacheTensor)
from vllm.utils import sha256
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec,
KVCacheTensor, SlidingWindowSpec)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request
......@@ -18,9 +21,8 @@ logger = init_logger(__name__)
class BlockHashType(NamedTuple):
"""Hash value of a block (int), the token IDs in the block, and extra keys.
We keep a tuple of token IDs and extra keys to reduce the likelihood of
hash collisions when the hash value is the same. But please note that
hash collisions can still theoretically occur, albeit with an extremely
low probability.
hash collisions when the hash value is the same. By using SHA256 however,
hash collisions are practically impossible.
"""
# Hash value of the block in an integer.
hash_value: int
......@@ -30,6 +32,20 @@ class BlockHashType(NamedTuple):
extra_keys: Optional[Any] = None
# The hash seed for the first block of the prefix block sequence.
#
# Even if the hash function is the builtin hash(), we use sha256 to generate
# the initial hash to simplify the code. This is not performance critical
# as it is done one per process.
#
# We use a random value to avoid hash collisions or PYTHONHASHSEED environment
# variable if set such that processes can share the seed if needed.
# This aligns with the behavior of Python's hash() function, which also uses
# a random seed if PYTHONHASHSEED is not set.
NONE_HASH = int.from_bytes(os.urandom(32), byteorder="big") if os.getenv(
'PYTHONHASHSEED') is not None else sha256(os.getenv('PYTHONHASHSEED'))
class PrefixCachingMetrics:
"""Metrics for prefix caching with a hit rate of the most recent N requests.
......@@ -375,6 +391,7 @@ def generate_block_hash_extra_keys(
def hash_block_tokens(
hash_function: Callable,
parent_block_hash: Optional[int],
curr_block_token_ids: Sequence[int],
extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHashType:
......@@ -395,21 +412,16 @@ def hash_block_tokens(
The entire tuple is used as the hash key of the block.
"""
if not parent_block_hash:
# Note that we use 'None' as a string here instead of None because
# as of Python 3.12, hash(None) returns a constant predictable value.
# This could possibly make it easier to find and exploit hash
# collisions. 'None' as a string will be hashed differently per process,
# but consistently within the same process. This is the same as the
# behavior of None prior to Python 3.12.
parent_block_hash = hash('None')
parent_block_hash = NONE_HASH
curr_block_token_ids_tuple = tuple(curr_block_token_ids)
return BlockHashType(
hash((parent_block_hash, curr_block_token_ids_tuple, extra_keys)),
hash_function(
(parent_block_hash, curr_block_token_ids_tuple, extra_keys)),
curr_block_token_ids_tuple, extra_keys)
def hash_request_tokens(block_size: int,
def hash_request_tokens(hash_function: Any, block_size: int,
request: Request) -> list[BlockHashType]:
"""Computes hash values of a chain of blocks given a sequence of
token IDs. The hash value is used for prefix caching.
......@@ -441,7 +453,7 @@ def hash_request_tokens(block_size: int,
req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
request, start, end, curr_mm_idx)
block_hash = hash_block_tokens(parent_block_hash_value,
block_hash = hash_block_tokens(hash_function, parent_block_hash_value,
block_token_ids, req_extra_keys)
ret.append(block_hash)
parent_block_hash_value = block_hash.hash_value
......@@ -472,14 +484,14 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
max_model_len = vllm_config.model_config.max_model_len
needed_memory = 0
for layer_spec in kv_cache_spec.values():
needed_memory += layer_spec.bytes_for_tokens(max_model_len)
needed_memory += layer_spec.max_memory_usage_bytes(vllm_config)
if needed_memory > available_memory:
raise ValueError(
f"To serve at least one request with the models's max seq len "
f"({max_model_len}), ({needed_memory/1024/1024/1024:.2f} GB KV "
f"({max_model_len}), ({needed_memory/1024/1024/1024:.2f} GiB KV "
f"cache is needed, which is larger than the available KV cache "
f"memory ({available_memory/1024/1024/1024:.2f} GB). Try "
f"memory ({available_memory/1024/1024/1024:.2f} GiB). Try "
f"increasing `gpu_memory_utilization` or decreasing "
f"`max_model_len` when initializing the engine.")
......@@ -586,6 +598,33 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
return kv_cache_config
def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
"""
Only models with one type of KV cache are supported yet. This function tries
to convert the KV cache specs to one type if the model is a hybrid model
with multiple type of KV cache. It will convert all SlidingWindowSpec to
FullAttentionSpec if both types are present.
Args:
kv_cache_spec: The kv cache spec of each attention layer in the model
"""
has_full_attention = any(
isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values())
has_sliding_window = any(
isinstance(spec, SlidingWindowSpec) for spec in kv_cache_spec.values())
if has_full_attention and has_sliding_window:
for layer_name, spec in kv_cache_spec.items():
if isinstance(spec, SlidingWindowSpec):
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=spec.block_size,
num_kv_heads=spec.num_kv_heads,
head_size=spec.head_size,
dtype=spec.dtype,
use_mla=spec.use_mla,
)
def get_kv_cache_config(vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int) -> KVCacheConfig:
......@@ -602,6 +641,7 @@ def get_kv_cache_config(vllm_config: VllmConfig,
The generated KVCacheConfigs
"""
check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
unify_hybrid_kv_cache_specs(kv_cache_spec)
if is_kv_cache_type_uniform(kv_cache_spec):
# KV cache of all layers are the same, which is true for
# most models. Allocate the same amount of memory for
......
......@@ -10,8 +10,7 @@ if TYPE_CHECKING:
import numpy.typing as npt
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.base import PlaceholderRange
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.request import Request
......
......@@ -7,9 +7,9 @@ from collections import deque
from collections.abc import Iterable
from typing import Optional, Union
from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
compute_encoder_budget)
from vllm.v1.core.kv_cache_manager import KVCacheManager
......@@ -19,9 +19,11 @@ from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
from vllm.v1.core.sched.utils import check_stop
from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
EngineCoreOutputs)
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager
logger = init_logger(__name__)
......@@ -35,32 +37,37 @@ class Scheduler(SchedulerInterface):
model_config: ModelConfig,
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
speculative_config: Optional[SpeculativeConfig],
log_stats: bool,
kv_cache_config: KVCacheConfig,
structured_output_manager: StructuredOutputManager,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
include_finished_set: bool = False,
log_stats: bool = False,
) -> None:
self.scheduler_config = scheduler_config
self.cache_config = cache_config
self.lora_config = lora_config
self.speculative_config = speculative_config
self.kv_cache_config = kv_cache_config
self.log_stats = log_stats
self.structured_output_manager = structured_output_manager
# include_finished_set controls whether a separate set of finished
# request ids should be included in the EngineCoreOutputs returned
# by update_from_outputs(). This is currently used in the multi-engine
# case to track request lifetimes efficiently.
self.include_finished_set = include_finished_set
# Scheduling constraints.
self.max_num_running_reqs = self.scheduler_config.max_num_seqs
self.max_num_scheduled_tokens = \
self.scheduler_config.max_num_batched_tokens
self.max_model_len = self.scheduler_config.max_model_len
num_gpu_blocks = cache_config.num_gpu_blocks
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
# Create the KV cache manager.
self.kv_cache_manager = KVCacheManager(
block_size=self.cache_config.block_size,
num_gpu_blocks=num_gpu_blocks,
kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len,
sliding_window=self.cache_config.sliding_window,
enable_caching=self.cache_config.enable_prefix_caching,
enable_caching=cache_config.enable_prefix_caching,
caching_hash_algo=self.cache_config.prefix_caching_hash_algo,
log_stats=self.log_stats)
self.block_size = self.cache_config.block_size
......@@ -92,6 +99,7 @@ class Scheduler(SchedulerInterface):
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
model_config=model_config,
scheduler_config=scheduler_config,
mm_registry=mm_registry,
)
# NOTE(woosuk): Here, "encoder" includes the vision encoder (and
......@@ -152,23 +160,31 @@ class Scheduler(SchedulerInterface):
num_new_tokens = (request.num_tokens_with_spec -
request.num_computed_tokens)
if (0 < self.scheduler_config.long_prefill_token_threshold <
num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
# Schedule encoder inputs.
encoder_inputs_to_schedule, num_new_tokens, new_encoder_budget = (
self._try_schedule_encoder_inputs(request,
request.num_computed_tokens,
num_new_tokens,
encoder_budget))
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_budget) = self._try_schedule_encoder_inputs(
request, request.num_computed_tokens, num_new_tokens,
encoder_budget)
if num_new_tokens == 0:
# The request cannot be scheduled because the encoder budget
# or the encoder cache is exhausted.
# NOTE(woosuk): Here, by doing `continue` instead of `break`,
# we do not strictly follow the FCFS scheduling policy and
# allow the lower-priority requests to be scheduled.
# NOTE(woosuk): By using `continue` instead of `break` here,
# we intentionally relax the strict FCFS scheduling policy
# to allow lower-priority requests to be scheduled when a
# higher-priority request is blocked by encoder constraints.
req_index += 1
continue
else:
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
while True:
new_blocks = self.kv_cache_manager.allocate_slots(
......@@ -235,16 +251,16 @@ class Scheduler(SchedulerInterface):
encoder_budget = new_encoder_budget
# Record the LoRAs in scheduled_running_reqs
requested_loras: set[int] = set()
scheduled_loras: set[int] = set()
if self.lora_config:
requested_loras = set(
scheduled_loras = set(
req.lora_request.lora_int_id for req in scheduled_running_reqs
if req.lora_request and req.lora_request.lora_int_id > 0)
assert len(requested_loras) <= self.lora_config.max_loras
assert len(scheduled_loras) <= self.lora_config.max_loras
# Use a temporary deque to collect requests that need to be skipped
# and put back at the head of the waiting queue later
waiting_for_fsm: deque[Request] = deque()
skipped_waiting_requests: deque[Request] = deque()
# Next, schedule the WAITING requests.
if not preempted_reqs:
......@@ -254,31 +270,27 @@ class Scheduler(SchedulerInterface):
request = self.waiting[0]
# Skip request if the structured output request is still waiting
# for FSM compilation.
if request.status == RequestStatus.WAITING_FOR_FSM:
structured_output_req = request.structured_output_request
if structured_output_req and structured_output_req.grammar:
request.status = RequestStatus.WAITING
else:
waiting_structured_output_req = self.waiting.popleft()
waiting_for_fsm.appendleft(
waiting_structured_output_req)
self.waiting.popleft()
skipped_waiting_requests.appendleft(request)
continue
# Check that adding the request still respects the max_loras
# constraint.
if self.lora_config and request.lora_request:
req_lora_id = request.lora_request.lora_int_id
if len(requested_loras) == self.lora_config.max_loras and (
req_lora_id not in requested_loras):
# Cannot schedule.
# TODO (varun): This means all the other requests in
# the WAITING queue will be blocked by this request,
# even if,
# 1. these other requests do not use LoRA, or,
# 2. these other requests use the already requested
# LoRAs.
# This is too conservative and could be optimized.
break
if self.lora_config and request.lora_request and (
len(scheduled_loras) == self.lora_config.max_loras
and request.lora_request.lora_int_id
not in scheduled_loras):
# Scheduling would exceed max_loras, skip.
self.waiting.popleft()
skipped_waiting_requests.appendleft(request)
continue
# Get already-cached tokens.
computed_blocks, num_computed_tokens = \
......@@ -288,21 +300,15 @@ class Scheduler(SchedulerInterface):
# `request.num_prompt_tokens` to consider the resumed requests,
# which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens
if num_new_tokens == 0:
# This happens when prompt length is divisible by the block
# size and all blocks are cached. Now we force to recompute
# the last block. Note that we have to re-compute an entire
# block because allocate_slots() assumes num_computed_tokens
# is always a multiple of the block size. This limitation
# can potentially be removed in the future to slightly
# improve the performance.
num_computed_tokens -= self.block_size
num_new_tokens = self.block_size
computed_blocks.pop()
if (0 < self.scheduler_config.long_prefill_token_threshold <
num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
# Schedule encoder inputs.
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_budget) = self._try_schedule_encoder_inputs(
request, num_computed_tokens, num_new_tokens,
......@@ -310,6 +316,9 @@ class Scheduler(SchedulerInterface):
if num_new_tokens == 0:
# The request cannot be scheduled.
break
else:
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
new_blocks = self.kv_cache_manager.allocate_slots(
request, num_new_tokens, computed_blocks)
......@@ -336,7 +345,7 @@ class Scheduler(SchedulerInterface):
f"Invalid request status: {request.status}")
if self.lora_config and request.lora_request:
requested_loras.add(request.lora_request.lora_int_id)
scheduled_loras.add(request.lora_request.lora_int_id)
req_to_new_block_ids[request.request_id] = [
b.block_id for b in computed_blocks + new_blocks
]
......@@ -355,8 +364,8 @@ class Scheduler(SchedulerInterface):
encoder_budget = new_encoder_budget
# Put back any skipped requests at the head of the waiting queue
if waiting_for_fsm:
self.waiting.extendleft(waiting_for_fsm)
if skipped_waiting_requests:
self.waiting.extendleft(skipped_waiting_requests)
# Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
......@@ -425,6 +434,18 @@ class Scheduler(SchedulerInterface):
grammar_bitmask=grammar_bitmask,
)
# Advance the number of computed tokens for the request AFTER
# the request is scheduled.
# 1. The scheduler_output of the current step has to include the
# original number of scheduled tokens to determine input IDs.
# 2. Advance the number of computed tokens here allowing us to
# schedule the prefill request again immediately in the next
# scheduling step.
# 3. If some tokens (e.g. spec tokens) are rejected later, the number of
# computed tokens will be adjusted in update_from_output.
for req_id, num_scheduled_token in num_scheduled_tokens.items():
self.requests[req_id].num_computed_tokens += num_scheduled_token
self.finished_req_ids = set()
return scheduler_output
......@@ -479,9 +500,6 @@ class Scheduler(SchedulerInterface):
limitations, the method adjusts `num_new_tokens` to schedule only the
decoder tokens up to just before the unschedulable encoder input.
"""
if not request.has_encoder_inputs():
return [], num_new_tokens, encoder_budget
encoder_inputs_to_schedule: list[int] = []
mm_positions = request.mm_positions
assert mm_positions is not None
......@@ -539,6 +557,7 @@ class Scheduler(SchedulerInterface):
new_running: list[Request] = []
outputs: list[EngineCoreOutput] = []
spec_decoding_stats: Optional[SpecDecodingStats] = None
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
# loop can be a performance bottleneck. We should do our best to avoid
......@@ -553,36 +572,32 @@ class Scheduler(SchedulerInterface):
req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = sampled_token_ids[req_index]
if req_id not in scheduler_output.scheduled_spec_decode_tokens:
# When the request's num_computed_tokens catches up
# its num_tokens, the request generates output tokens.
# Otherwise, we ignore the sampler output for the request.
request.num_computed_tokens += num_tokens_scheduled
assert request.num_computed_tokens <= request.num_tokens
else:
# num_computed_tokens_step represents the number of tokens
scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
if scheduled_spec_token_ids:
# num_computed_tokens represents the number of tokens
# processed in the current step, considering scheduled
# tokens and rejections.
# It is calculated as:
# num_computed_tokens_step = num_scheduled_tokens -
# num_tokens_rejected,
# where num_tokens_rejected is given by:
# tokens and rejections. If some tokens are rejected,
# num_computed_tokens is decreased by the number of rejected
# tokens, where is given by:
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens[req_id])
num_computed_tokens_step = num_scheduled_tokens[req_id] - (
len(scheduled_spec_token_ids) + 1 -
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
len(generated_token_ids))
request.num_computed_tokens += num_computed_tokens_step
request.num_computed_tokens -= num_tokens_rejected
spec_decoding_stats = self.make_spec_decoding_stats(
spec_decoding_stats,
num_draft_tokens=len(scheduled_spec_token_ids),
num_accepted_tokens=len(generated_token_ids) - 1)
cached_encoder_input_ids = (
self.encoder_cache_manager.get_cached_input_ids(request))
# OPTIMIZATION: Avoid list(set) if the set is empty.
if cached_encoder_input_ids:
for input_id in list(cached_encoder_input_ids):
start_pos = request.mm_positions[input_id]["offset"]
num_tokens = request.mm_positions[input_id]["length"]
mm_positions = request.mm_positions[input_id]
start_pos = mm_positions["offset"]
num_tokens = mm_positions["length"]
if start_pos + num_tokens <= request.num_computed_tokens:
# The encoder output is already processed and stored
# in the decoder's KV cache.
......@@ -595,23 +610,24 @@ class Scheduler(SchedulerInterface):
stopped = False
new_logprobs = None
new_token_ids: list[int] = []
new_token_ids = generated_token_ids
if request.num_computed_tokens >= request.num_tokens:
for output_token_id in generated_token_ids:
# Append generated tokens and check for stop. Note that if
# a request is still being prefilled, we expect the model runner
# to return empty token ids for the request.
for num_new, output_token_id in enumerate(new_token_ids, 1):
request.append_output_token_ids(output_token_id)
new_token_ids.append(output_token_id)
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
stopped = check_stop(request, self.max_model_len)
if stopped:
self._free_request(request)
del new_token_ids[num_new:] # Trim new tokens if needed.
break
# Extract sample logprobs if needed.
if request.sampling_params.logprobs is not None:
assert logprobs is not None
if request.sampling_params.logprobs is not None and logprobs:
# NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1.
new_logprobs = logprobs.slice(req_index, req_index + 1)
......@@ -621,9 +637,7 @@ class Scheduler(SchedulerInterface):
# should not be None if use_structured_output, we have
# check above, so safe to ignore type warning
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
request.request_id,
new_token_ids,
)
req_id, new_token_ids)
# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
......@@ -642,15 +656,21 @@ class Scheduler(SchedulerInterface):
# Invariant: EngineCore returns no partial prefill outputs.
assert not prompt_logprobs_tensors
self.scheduled_req_ids.remove(request.request_id)
self.scheduled_req_ids.remove(req_id)
if not stopped:
new_running.append(request)
self.running = new_running
return EngineCoreOutputs(
engine_core_outputs = EngineCoreOutputs(
outputs=outputs,
scheduler_stats=self.make_stats(),
scheduler_stats=self.make_stats(spec_decoding_stats),
)
if self.include_finished_set:
#TODO currently sending duplicates here, improve this
engine_core_outputs.finished_requests = (
scheduler_output.finished_req_ids | self.finished_req_ids)
return engine_core_outputs
def add_request(self, request: Request) -> None:
self.waiting.append(request)
......@@ -710,7 +730,10 @@ class Scheduler(SchedulerInterface):
def reset_prefix_cache(self) -> bool:
return self.kv_cache_manager.reset_prefix_cache()
def make_stats(self) -> Optional[SchedulerStats]:
def make_stats(
self,
spec_decoding_stats: Optional[SpecDecodingStats] = None,
) -> Optional[SchedulerStats]:
if not self.log_stats:
return None
return SchedulerStats(
......@@ -718,4 +741,19 @@ class Scheduler(SchedulerInterface):
num_waiting_reqs=len(self.waiting),
gpu_cache_usage=self.kv_cache_manager.usage,
prefix_cache_stats=self.kv_cache_manager.make_prefix_cache_stats(),
spec_decoding_stats=spec_decoding_stats,
)
def make_spec_decoding_stats(
self,
spec_decoding_stats: Optional[SpecDecodingStats],
num_draft_tokens: int,
num_accepted_tokens: int,
) -> Optional[SpecDecodingStats]:
if not self.log_stats:
return None
if spec_decoding_stats is None:
spec_decoding_stats = SpecDecodingStats()
spec_decoding_stats.observe(num_draft_tokens=num_draft_tokens,
num_accepted_tokens=num_accepted_tokens)
return spec_decoding_stats
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from vllm.utils import cdiv
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
SlidingWindowSpec)
class SpecializedManager(ABC):
"""
An abstract base class for specialized managers that handle the kv
cache management logic of different attention layers.
"""
def __init__(
self,
kv_cache_spec: KVCacheSpec,
block_pool: BlockPool,
) -> None:
"""
Initializes the SpecializedManager.
Args:
kv_cache_spec: The kv_cache_spec for this manager.
block_pool: The block pool.
"""
self.block_size = kv_cache_spec.block_size
self.kv_cache_spec = kv_cache_spec
self.block_pool = block_pool
@abstractmethod
def find_longest_cache_hit(
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
"""
Get the longest cache hit prefix of the blocks. If no cache hit is
found, return an empty list.
Args:
block_hashes: The block hashes of the request.
Returns:
A list of cached blocks with skipped blocks replaced by null block.
For example, sliding window manager should return a list like
[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)] for block size 4 and
sliding window 8.
"""
raise NotImplementedError
@abstractmethod
def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
num_computed_tokens: int) -> list[KVCacheBlock]:
"""
Remove the blocks that are no longer needed from `blocks`. The removed
blocks should be replaced by null_block. Return the removed blocks in
eviction order, where the first returned block should be evicted first.
Don't free the removed blocks in this function.
Args:
blocks: The list of blocks to be updated.
num_computed_tokens: The number of tokens that have been computed.
Returns:
The removed blocks in eviction order.
"""
raise NotImplementedError
class FullAttentionManager(SpecializedManager):
def find_longest_cache_hit(
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
computed_blocks: list[KVCacheBlock] = []
for block_hash in block_hashes:
# block_hashes is a chain of block hashes. If a block hash is not
# in the cached_block_hash_to_id, the following block hashes are
# not computed yet for sure.
if cached_block := self.block_pool.get_cached_block(block_hash):
computed_blocks.append(cached_block)
else:
break
return computed_blocks
def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
num_computed_tokens: int) -> list[KVCacheBlock]:
# No need to remove blocks for full attention.
return []
class SlidingWindowManager(SpecializedManager):
def __init__(self, kv_cache_spec: SlidingWindowSpec,
block_pool: BlockPool):
super().__init__(kv_cache_spec, block_pool)
self.sliding_window = kv_cache_spec.sliding_window
# The number of contiguous blocks needed for prefix cache hit.
# -1 since the input token itself is also included in the window
self.sliding_window_contiguous_blocks = cdiv(
(kv_cache_spec.sliding_window - 1), self.block_size)
self._null_block = block_pool.null_block
def find_longest_cache_hit(
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
# TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
# optimize the time complexity from O(len(block_hashes)) to
# O(len(block_hashes) / sliding_window_contiguous_blocks +
# sliding_window_contiguous_blocks),
# which is good for low cache hit rate scenarios.
computed_blocks = [self._null_block] * len(block_hashes)
num_contiguous_blocks = 0
# Search from right to left and early stop when a match is found.
for i in range(len(block_hashes) - 1, -1, -1):
if cached_block := self.block_pool.get_cached_block(
block_hashes[i]):
computed_blocks[i] = cached_block
num_contiguous_blocks += 1
if (num_contiguous_blocks
>= self.sliding_window_contiguous_blocks):
# Trim the trailing blocks.
# E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
# when sliding_window_contiguous_blocks=2.
del computed_blocks[i + num_contiguous_blocks:]
return computed_blocks
else:
num_contiguous_blocks = 0
# The first `num_contiguous_blocks` is a cache hit even if
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
del computed_blocks[num_contiguous_blocks:]
return computed_blocks
def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
num_computed_tokens: int) -> list[KVCacheBlock]:
# Remove the blocks that are no longer be in the sliding window and
# skipped during the attention computation.
last_useful_token = num_computed_tokens - self.sliding_window + 1
last_useful_block = last_useful_token // self.block_size
removed_blocks: list[KVCacheBlock] = []
for i in range(last_useful_block - 1, -1, -1):
if blocks[i] == self._null_block:
# If the block is already a null block, the blocks before it
# should also have been set to null blocks by the previous calls
# to this function.
break
removed_blocks.append(blocks[i])
blocks[i] = self._null_block
return removed_blocks
spec_manager_map: dict[type[KVCacheSpec], type[SpecializedManager]] = {
FullAttentionSpec: FullAttentionManager,
SlidingWindowSpec: SlidingWindowManager,
}
def get_specialized_manager(kv_cache_spec: KVCacheSpec,
block_pool: BlockPool) -> SpecializedManager:
manager_class = spec_manager_map[type(kv_cache_spec)]
manager = manager_class(kv_cache_spec, block_pool)
return manager
......@@ -128,12 +128,18 @@ class EngineCoreOutputs(
#NOTE(Nick): We could consider ways to make this more compact,
# e.g. columnwise layout
engine_index: int = 0
# [num_reqs]
outputs: list[EngineCoreOutput] = []
scheduler_stats: Optional[SchedulerStats] = None
timestamp: float = 0.0
utility_output: Optional[UtilityOutput] = None
finished_requests: Optional[set[str]] = None
# In DP case, used to signal that the engine is paused.
engine_paused: bool = False
def __post_init__(self):
if self.timestamp == 0.0:
......@@ -147,4 +153,5 @@ class EngineCoreRequestType(enum.Enum):
"""
ADD = b'\x00'
ABORT = b'\x01'
UTILITY = b'\x02'
START_DP = b'\x02'
UTILITY = b'\x03'
......@@ -14,10 +14,11 @@ from vllm.config import ModelConfig, VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
......@@ -48,7 +49,7 @@ class AsyncLLM(EngineClient):
executor_class: type[Executor],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
use_cached_outputs: bool = False,
log_requests: bool = True,
start_engine_loop: bool = True,
......@@ -66,11 +67,17 @@ class AsyncLLM(EngineClient):
self.log_requests = log_requests
self.log_stats = log_stats
self.stat_loggers: list[StatLoggerBase] = []
# Set up stat loggers; independent set for each DP rank.
self.stat_loggers: list[list[StatLoggerBase]] = []
if self.log_stats:
for i in range(vllm_config.parallel_config.data_parallel_size):
loggers: list[StatLoggerBase] = []
if logger.isEnabledFor(logging.INFO):
self.stat_loggers.append(LoggingStatLogger())
self.stat_loggers.append(PrometheusStatLogger(vllm_config))
loggers.append(LoggingStatLogger(engine_index=i))
loggers.append(
PrometheusStatLogger(vllm_config, engine_index=i))
self.stat_loggers.append(loggers)
# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
......@@ -84,7 +91,7 @@ class AsyncLLM(EngineClient):
self.processor = Processor(
vllm_config=vllm_config,
tokenizer=self.tokenizer,
input_registry=input_registry,
mm_registry=mm_registry,
)
# OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
......@@ -329,6 +336,7 @@ class AsyncLLM(EngineClient):
# TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial.
self._record_stats(
engine_index=outputs.engine_index,
scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
)
......@@ -350,12 +358,13 @@ class AsyncLLM(EngineClient):
self,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
engine_index: int = 0,
):
if not self.log_stats:
return
assert scheduler_stats is not None
for stat_logger in self.stat_loggers:
for stat_logger in self.stat_loggers[engine_index]:
stat_logger.record(scheduler_stats=scheduler_stats,
iteration_stats=iteration_stats)
......@@ -393,7 +402,8 @@ class AsyncLLM(EngineClient):
scheduler_outputs=None,
model_output=None,
) -> None:
for stat_logger in self.stat_loggers:
for loggers in self.stat_loggers:
for stat_logger in loggers:
stat_logger.log()
async def check_health(self) -> None:
......@@ -414,8 +424,8 @@ class AsyncLLM(EngineClient):
async def sleep(self, level: int = 1) -> None:
await self.engine_core.sleep_async(level)
async def wake_up(self) -> None:
await self.engine_core.wake_up_async()
async def wake_up(self, tags: Optional[list[str]] = None) -> None:
await self.engine_core.wake_up_async(tags)
async def is_sleeping(self) -> bool:
return await self.engine_core.is_sleeping_async()
......
# SPDX-License-Identifier: Apache-2.0
import os
import queue
import signal
import sys
import threading
import time
from concurrent.futures import Future
from inspect import isclass, signature
from multiprocessing.connection import Connection
from typing import Any, Optional
from logging import DEBUG
from typing import Any, Callable, Optional, TypeVar, Union
import msgspec
import psutil
import zmq
import zmq.asyncio
from vllm.config import VllmConfig
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.executor.multiproc_worker_utils import _add_prefix
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.config import (
......@@ -23,12 +26,14 @@ from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname,
zmq_socket_ctx)
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
unify_kv_cache_configs)
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.mm_input_cache import MMInputCacheServer
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
......@@ -39,6 +44,8 @@ logger = init_logger(__name__)
POLLING_TIMEOUT_S = 2.5
_R = TypeVar('_R') # Return type for collective_rpc
class EngineCore:
"""Inner loop of vLLM's Engine."""
......@@ -60,8 +67,9 @@ class EngineCore:
self.model_executor = executor_class(vllm_config)
# Setup KV Caches and update CacheConfig after profiling.
num_gpu_blocks, num_cpu_blocks = self._initialize_kv_caches(
vllm_config)
num_gpu_blocks, num_cpu_blocks, kv_cache_config = \
self._initialize_kv_caches(vllm_config)
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
......@@ -84,14 +92,16 @@ class EngineCore:
"compatibility may not be maintained.",
vllm_config.scheduler_config.scheduler_cls)
self.scheduler = Scheduler(
self.scheduler: SchedulerInterface = Scheduler(
scheduler_config=vllm_config.scheduler_config,
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
speculative_config=vllm_config.speculative_config,
log_stats=self.log_stats,
kv_cache_config=kv_cache_config,
structured_output_manager=self.structured_output_manager,
include_finished_set=vllm_config.parallel_config.data_parallel_size
> 1,
log_stats=self.log_stats,
)
# Setup MM Input Mapper.
......@@ -110,8 +120,8 @@ class EngineCore:
self.batch_queue_size)
self.batch_queue = queue.Queue(self.batch_queue_size)
def _initialize_kv_caches(self,
vllm_config: VllmConfig) -> tuple[int, int]:
def _initialize_kv_caches(
self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
start = time.time()
# Get all kv cache needed by the model
......@@ -136,13 +146,14 @@ class EngineCore:
unify_kv_cache_configs(kv_cache_configs)
# All workers have the same kv_cache_config except layer names, so use
# an arbitrary one to get the number of blocks.
# an arbitrary one to initialize the scheduler.
assert all([
cfg.num_blocks == kv_cache_configs[0].num_blocks
for cfg in kv_cache_configs
])
num_gpu_blocks = kv_cache_configs[0].num_blocks
num_cpu_blocks = 0
scheduler_kv_cache_config = kv_cache_configs[0]
# Initialize kv cache and warmup the execution
self.model_executor.initialize_from_config(kv_cache_configs)
......@@ -150,7 +161,7 @@ class EngineCore:
elapsed = time.time() - start
logger.info(("init engine (profile, create kv cache, "
"warmup model) took %.2f seconds"), elapsed)
return num_gpu_blocks, num_cpu_blocks
return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config
def add_request(self, request: EngineCoreRequest):
"""Add request to the scheduler."""
......@@ -253,8 +264,8 @@ class EngineCore:
def sleep(self, level: int = 1):
self.model_executor.sleep(level)
def wake_up(self):
self.model_executor.wake_up()
def wake_up(self, tags: Optional[list[str]] = None):
self.model_executor.wake_up(tags)
def is_sleeping(self) -> bool:
return self.model_executor.is_sleeping
......@@ -274,6 +285,24 @@ class EngineCore:
def pin_lora(self, lora_id: int) -> bool:
return self.model_executor.pin_lora(lora_id)
def save_sharded_state(
self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
self.model_executor.save_sharded_state(path=path,
pattern=pattern,
max_size=max_size)
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return self.model_executor.collective_rpc(method, timeout, args,
kwargs)
class EngineCoreProc(EngineCore):
"""ZMQ-wrapper for running EngineCore in background process."""
......@@ -282,10 +311,10 @@ class EngineCoreProc(EngineCore):
self,
input_path: str,
output_path: str,
ready_pipe: Connection,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
engine_index: int = 0,
):
super().__init__(vllm_config, executor_class, log_stats)
......@@ -301,14 +330,20 @@ class EngineCoreProc(EngineCore):
args=(input_path, ),
daemon=True).start()
threading.Thread(target=self.process_output_socket,
args=(output_path, ),
args=(output_path, engine_index),
daemon=True).start()
# Send Readiness signal to EngineClient.
ready_pipe.send({"status": "READY"})
self.global_unfinished_reqs = False
self.step_fn = (self.step if self.batch_queue is None else
self.step_with_batch_queue)
@staticmethod
def run_engine_core(*args, **kwargs):
def run_engine_core(*args,
dp_rank: int = 0,
local_dp_rank: int = 0,
ready_pipe,
**kwargs):
"""Launch EngineCore busy loop in background process."""
# Signal handler used for graceful termination.
......@@ -330,9 +365,21 @@ class EngineCoreProc(EngineCore):
signal.signal(signal.SIGINT, signal_handler)
parent_process = psutil.Process().parent()
engine_core = None
engine_core: Optional[EngineCoreProc] = None
try:
parallel_config: ParallelConfig = kwargs[
"vllm_config"].parallel_config
if parallel_config.data_parallel_size > 1:
# Set data parallel rank for this engine process.
parallel_config.data_parallel_rank = dp_rank
parallel_config.data_parallel_rank_local = local_dp_rank
engine_core = DPEngineCoreProc(*args, **kwargs)
else:
engine_core = EngineCoreProc(*args, **kwargs)
# Send Readiness signal to EngineClient.
ready_pipe.send({"status": "READY"})
engine_core.run_busy_loop()
except SystemExit:
......@@ -350,26 +397,42 @@ class EngineCoreProc(EngineCore):
def run_busy_loop(self):
"""Core busy loop of the EngineCore."""
step_fn = (self.step
if self.batch_queue is None else self.step_with_batch_queue)
# Loop until process is sent a SIGINT or SIGTERM
while True:
# 1) Poll the input queue until there is work to do.
while not self.scheduler.has_requests():
logger.debug("EngineCore busy loop waiting.")
self._process_input_queue()
# 2) Step the engine core and return the outputs.
self._process_engine_step()
def _process_input_queue(self):
"""Exits when an engine step needs to be performed."""
waited = False
while not self.global_unfinished_reqs and not (
self.scheduler.has_requests()):
if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
logger.debug("EngineCore waiting for work.")
waited = True
req = self.input_queue.get()
self._handle_client_request(*req)
# 2) Handle any new client requests.
if waited:
logger.debug(
"EngineCore loop active - local unfinished: %s, finished: %s.",
self.scheduler.has_unfinished_requests(),
self.scheduler.has_finished_requests())
# Handle any more client requests.
while not self.input_queue.empty():
req = self.input_queue.get_nowait()
self._handle_client_request(*req)
# 3) Step the engine core.
outputs = step_fn()
def _process_engine_step(self):
"""Called only when there are unfinished local requests."""
# 4) Put EngineCoreOutputs into the output queue.
# Step the engine core.
outputs = self.step_fn()
# Put EngineCoreOutputs into the output queue.
if outputs is not None:
self.output_queue.put_nowait(outputs)
......@@ -381,6 +444,10 @@ class EngineCoreProc(EngineCore):
self.add_request(request)
elif request_type == EngineCoreRequestType.ABORT:
self.abort_requests(request)
elif request_type == EngineCoreRequestType.START_DP:
if not self.global_unfinished_reqs:
logger.debug("EngineCore starting idle loop.")
self.global_unfinished_reqs = True
elif request_type == EngineCoreRequestType.UTILITY:
call_id, method_name, args = request
output = UtilityOutput(call_id)
......@@ -431,7 +498,7 @@ class EngineCoreProc(EngineCore):
# Push to input queue for core busy loop.
self.input_queue.put_nowait((request_type, request))
def process_output_socket(self, output_path: str):
def process_output_socket(self, output_path: str, engine_index: int):
"""Output socket IO thread."""
# Msgpack serialization encoding.
......@@ -442,5 +509,114 @@ class EngineCoreProc(EngineCore):
with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket:
while True:
outputs = self.output_queue.get()
outputs.engine_index = engine_index
encoder.encode_into(outputs, buffer)
socket.send_multipart((buffer, ), copy=False)
socket.send(buffer, copy=False)
ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(engine_paused=True)
class DPEngineCoreProc(EngineCoreProc):
"""ZMQ-wrapper for running EngineCore in background process
in a data parallel context."""
def __init__(
self,
input_path: str,
output_path: str,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
):
# Add process-specific prefix to stdout and stderr before
# we initialize the engine.
from multiprocessing import current_process
process_name = current_process().name
pid = os.getpid()
_add_prefix(sys.stdout, process_name, pid)
_add_prefix(sys.stderr, process_name, pid)
dp_size = vllm_config.parallel_config.data_parallel_size
dp_rank = vllm_config.parallel_config.data_parallel_rank
local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
assert dp_size > 1
assert 0 <= local_dp_rank <= dp_rank < dp_size
from vllm.platforms import current_platform
if current_platform.is_cuda_alike():
from vllm.platforms.cuda import device_id_to_physical_device_id
tp_size = vllm_config.parallel_config.tensor_parallel_size
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
str(device_id_to_physical_device_id(i))
for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) *
tp_size))
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
# Initialize the engine after setting up environment.
super().__init__(input_path, output_path, vllm_config, executor_class,
log_stats, dp_rank)
# Counts forward-passes of the model so that we can synchronize
# finished with DP peers every N steps.
self.counter = 0
def shutdown(self):
super().shutdown()
if dp_group := getattr(self, "dp_group", None):
stateless_destroy_torch_distributed_process_group(dp_group)
def run_busy_loop(self):
"""Core busy loop of the EngineCore for data parallel case."""
# Loop until process is sent a SIGINT or SIGTERM
while True:
# 1) Poll the input queue until there is work to do.
self._process_input_queue()
local_unfinished_reqs = self.scheduler.has_unfinished_requests()
if local_unfinished_reqs:
# 2) Step the engine core.
self._process_engine_step()
# Check if we have now finished all requests.
local_unfinished_reqs = (
self.scheduler.has_unfinished_requests())
else:
if self.scheduler.has_finished_requests():
# There are no unfinished requests, but there are some
# finished requests remaining to be removed from the
# batch state. This engine step won't perform a forward
# pass but will flush the finished requests to ensure
# up-to-date state is returned in the engine outputs.
self._process_engine_step()
if not self.global_unfinished_reqs:
# All engines are idle.
continue
# There must be unfinished requests in DP peers, run a
# dummy forward pass.
self.execute_dummy_batch()
# 3) All-reduce operation to determine global unfinished reqs.
self.global_unfinished_reqs = self._has_global_unfinished_reqs(
local_unfinished_reqs)
if not self.global_unfinished_reqs:
# Notify client that we are pausing the loop.
self.output_queue.put_nowait(ENGINE_PAUSED_OUTPUTS)
def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
# Optimization - only perform finish-sync all-reduce every 16 steps.
self.counter += 1
if self.counter != 16:
return True
self.counter = 0
return ParallelConfig.has_unfinished_dp(self.dp_group,
local_unfinished)
......@@ -8,10 +8,11 @@ import threading
import uuid
import weakref
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Sequence
from concurrent.futures import Future
from dataclasses import dataclass
from dataclasses import dataclass, field
from threading import Thread
from typing import Any, Optional, Union
from typing import Any, Callable, Optional, TypeVar, Union
import zmq
import zmq.asyncio
......@@ -32,6 +33,8 @@ logger = init_logger(__name__)
AnyFuture = Union[asyncio.Future[Any], Future[Any]]
_R = TypeVar('_R') # Return type for collective_rpc
class EngineCoreClient(ABC):
"""
......@@ -60,6 +63,9 @@ class EngineCoreClient(ABC):
"is not currently supported.")
if multiprocess_mode and asyncio_mode:
if vllm_config.parallel_config.data_parallel_size > 1:
return DPAsyncMPClient(vllm_config, executor_class, log_stats)
return AsyncMPClient(vllm_config, executor_class, log_stats)
if multiprocess_mode and not asyncio_mode:
......@@ -86,7 +92,7 @@ class EngineCoreClient(ABC):
def sleep(self, level: int = 1) -> None:
raise NotImplementedError
def wake_up(self) -> None:
def wake_up(self, tags: Optional[list[str]] = None) -> None:
raise NotImplementedError
def is_sleeping(self) -> bool:
......@@ -113,6 +119,19 @@ class EngineCoreClient(ABC):
def pin_lora(self, lora_id: int) -> bool:
raise NotImplementedError
def save_sharded_state(self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None) -> None:
raise NotImplementedError
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
raise NotImplementedError
async def get_output_async(self) -> EngineCoreOutputs:
raise NotImplementedError
......@@ -128,7 +147,7 @@ class EngineCoreClient(ABC):
async def sleep_async(self, level: int = 1) -> None:
raise NotImplementedError
async def wake_up_async(self) -> None:
async def wake_up_async(self, tags: Optional[list[str]] = None) -> None:
raise NotImplementedError
async def is_sleeping_async(self) -> bool:
......@@ -149,6 +168,20 @@ class EngineCoreClient(ABC):
async def pin_lora_async(self, lora_id: int) -> bool:
raise NotImplementedError
async def save_sharded_state_async(self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None) -> None:
raise NotImplementedError
async def collective_rpc_async(
self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
raise NotImplementedError
class InprocClient(EngineCoreClient):
"""
......@@ -185,8 +218,8 @@ class InprocClient(EngineCoreClient):
def sleep(self, level: int = 1) -> None:
self.engine_core.sleep(level)
def wake_up(self) -> None:
self.engine_core.wake_up()
def wake_up(self, tags: Optional[list[str]] = None) -> None:
self.engine_core.wake_up(tags)
def is_sleeping(self) -> bool:
return self.engine_core.is_sleeping()
......@@ -206,29 +239,88 @@ class InprocClient(EngineCoreClient):
def pin_lora(self, lora_id: int) -> bool:
return self.engine_core.pin_lora(lora_id)
def save_sharded_state(self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None) -> None:
self.engine_core.save_sharded_state(path, pattern, max_size)
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return self.engine_core.collective_rpc(method, timeout, args, kwargs)
class CoreEngine:
"""One per data parallel rank."""
def __init__(
self,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
ctx: Union[zmq.Context, zmq.asyncio.Context],
output_path: str,
index: int = 0,
local_dp_rank: int = 0,
):
# Paths and sockets for IPC.
input_path = get_open_zmq_ipc_path()
self.input_socket = make_zmq_socket(ctx, input_path,
zmq.constants.PUSH)
try:
# Start EngineCore in background process.
self.proc_handle = BackgroundProcHandle(
input_path=input_path,
output_path=output_path,
process_name=f"EngineCore_{index}",
target_fn=EngineCoreProc.run_engine_core,
process_kwargs={
"vllm_config": vllm_config,
"dp_rank": index,
"local_dp_rank": local_dp_rank,
"executor_class": executor_class,
"log_stats": log_stats,
})
self.num_reqs_in_flight = 0
finally:
if not hasattr(self, "num_reqs_in_flight"):
# Ensure socket is closed if process fails to start.
self.close()
def send_multipart(self, msg_parts: Sequence):
return self.input_socket.send_multipart(msg_parts, copy=False)
def close(self):
if proc_handle := getattr(self, "proc_handle", None):
proc_handle.shutdown()
if socket := getattr(self, "input_socket", None):
socket.close(linger=0)
@dataclass
class BackgroundResources:
"""Used as a finalizer for clean shutdown, avoiding
circular reference back to the client object."""
ctx: zmq.Context
ctx: Union[zmq.Context]
core_engines: list[CoreEngine] = field(default_factory=list)
output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
proc_handle: Optional[BackgroundProcHandle] = None
shutdown_path: Optional[str] = None
def __call__(self):
"""Clean up background resources."""
if self.proc_handle is not None:
self.proc_handle.shutdown()
for core_engine in self.core_engines:
core_engine.close()
# ZMQ context termination can hang if the sockets
# aren't explicitly closed first.
if self.output_socket is not None:
self.output_socket.close(linger=0)
if self.input_socket is not None:
self.input_socket.close(linger=0)
if self.shutdown_path is not None:
# We must ensure that the sync output socket is
# closed cleanly in its own thread.
......@@ -284,7 +376,7 @@ class MPClient(EngineCoreClient):
self.decoder = MsgpackDecoder(EngineCoreOutputs)
# ZMQ setup.
sync_ctx = zmq.Context()
sync_ctx = zmq.Context(io_threads=2)
self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx
# This will ensure resources created so far are closed
......@@ -293,28 +385,38 @@ class MPClient(EngineCoreClient):
self.resources = BackgroundResources(ctx=sync_ctx)
self._finalizer = weakref.finalize(self, self.resources)
# Paths for IPC.
# Paths and sockets for IPC.
self.output_path = get_open_zmq_ipc_path()
input_path = get_open_zmq_ipc_path()
# Start EngineCore in background process.
self.resources.proc_handle = BackgroundProcHandle(
input_path=input_path,
output_path=self.output_path,
process_name="EngineCore",
target_fn=EngineCoreProc.run_engine_core,
process_kwargs={
"vllm_config": vllm_config,
"executor_class": executor_class,
"log_stats": log_stats,
})
new_core_engine = lambda index, local_dp_rank=None: CoreEngine(
vllm_config, executor_class, log_stats, self.ctx, self.output_path,
index, local_dp_rank)
# Start engine core process(es).
self._init_core_engines(vllm_config, new_core_engine,
self.resources.core_engines)
# Wait for engine core process(es) to start.
for engine in self.resources.core_engines:
engine.proc_handle.wait_for_startup()
# Create input socket.
self.resources.input_socket = make_zmq_socket(self.ctx, input_path,
zmq.constants.PUSH)
self.input_socket = self.resources.input_socket
self.utility_results: dict[int, AnyFuture] = {}
def _init_core_engines(
self,
vllm_config: VllmConfig,
new_core_engine: Callable[[int, Optional[int]], CoreEngine],
core_engines: list[CoreEngine],
) -> None:
# Default case - single core engine.
dp_rank = vllm_config.parallel_config.data_parallel_rank
local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
core_engine = new_core_engine(
dp_rank, local_dp_rank if local_dp_rank is not None else dp_rank)
core_engines.append(core_engine)
self.core_engine = core_engine
def shutdown(self):
self._finalizer()
......@@ -356,9 +458,9 @@ class SyncMPClient(MPClient):
def process_outputs_socket():
shutdown_socket = ctx.socket(zmq.PAIR)
shutdown_socket.bind(shutdown_path)
out_socket = make_zmq_socket(ctx, output_path, zmq.constants.PULL)
try:
shutdown_socket.bind(shutdown_path)
poller = zmq.Poller()
poller.register(shutdown_socket)
poller.register(out_socket)
......@@ -370,7 +472,7 @@ class SyncMPClient(MPClient):
# shutdown signal, exit thread.
break
(frame, ) = out_socket.recv_multipart(copy=False)
frame = out_socket.recv(copy=False)
outputs = decoder.decode(frame.buffer)
if outputs.utility_output:
_process_utility_output(outputs.utility_output,
......@@ -391,18 +493,15 @@ class SyncMPClient(MPClient):
def get_output(self) -> EngineCoreOutputs:
return self.outputs_queue.get()
def _send_input(self, request_type: EngineCoreRequestType,
request: Any) -> None:
def _send_input(self, request_type: EngineCoreRequestType, request: Any):
# (RequestType, SerializedRequest)
msg = (request_type.value, self.encoder.encode(request))
self.input_socket.send_multipart(msg, copy=False)
self.core_engine.send_multipart(msg)
def _call_utility(self, method: str, *args) -> Any:
def call_utility(self, method: str, *args) -> Any:
call_id = uuid.uuid1().int >> 64
future: Future[Any] = Future()
self.utility_results[call_id] = future
self._send_input(EngineCoreRequestType.UTILITY,
(call_id, method, args))
......@@ -419,34 +518,48 @@ class SyncMPClient(MPClient):
self._send_input(EngineCoreRequestType.ABORT, request_ids)
def profile(self, is_start: bool = True) -> None:
self._call_utility("profile", is_start)
self.call_utility("profile", is_start)
def reset_prefix_cache(self) -> None:
self._call_utility("reset_prefix_cache")
self.call_utility("reset_prefix_cache")
def add_lora(self, lora_request: LoRARequest) -> bool:
return self._call_utility("add_lora", lora_request)
return self.call_utility("add_lora", lora_request)
def remove_lora(self, lora_id: int) -> bool:
return self._call_utility("remove_lora", lora_id)
return self.call_utility("remove_lora", lora_id)
def list_loras(self) -> set[int]:
return self._call_utility("list_loras")
return self.call_utility("list_loras")
def pin_lora(self, lora_id: int) -> bool:
return self._call_utility("pin_lora", lora_id)
return self.call_utility("pin_lora", lora_id)
def sleep(self, level: int = 1) -> None:
self._call_utility("sleep", level)
self.call_utility("sleep", level)
def wake_up(self) -> None:
self._call_utility("wake_up")
def wake_up(self, tags: Optional[list[str]] = None) -> None:
self.call_utility("wake_up", tags)
def is_sleeping(self) -> bool:
return self._call_utility("is_sleeping")
return self.call_utility("is_sleeping")
def execute_dummy_batch(self) -> None:
self._call_utility("execute_dummy_batch")
self.call_utility("execute_dummy_batch")
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return self.call_utility("collective_rpc", method, timeout, args,
kwargs)
def save_sharded_state(self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None) -> None:
self.call_utility("save_sharded_state", path, pattern, max_size)
class AsyncMPClient(MPClient):
......@@ -464,13 +577,21 @@ class AsyncMPClient(MPClient):
self.outputs_queue: Optional[asyncio.Queue[EngineCoreOutputs]] = None
self.queue_task: Optional[asyncio.Task] = None
async def _start_output_queue_task(self):
self.outputs_handler: Optional[Callable[
[AsyncMPClient, EngineCoreOutputs], Awaitable[None]]] = None
def _ensure_output_queue_task(self):
if self.outputs_queue is not None:
return
# Perform IO in separate task to parallelize as much as possible.
# Avoid task having direct reference back to the client.
self.outputs_queue = asyncio.Queue()
decoder = self.decoder
utility_results = self.utility_results
outputs_queue = self.outputs_queue
output_handler = self.outputs_handler
_self_ref = weakref.ref(self) if output_handler else None
output_path = self.output_path
output_socket = make_zmq_socket(self.ctx, output_path,
zmq.constants.PULL)
......@@ -483,34 +604,52 @@ class AsyncMPClient(MPClient):
if outputs.utility_output:
_process_utility_output(outputs.utility_output,
utility_results)
else:
continue
if output_handler is not None:
assert _self_ref is not None
_self = _self_ref()
if not _self:
# Client has been garbage collected, abort.
return
await output_handler(_self, outputs)
if outputs.outputs or outputs.scheduler_stats:
outputs_queue.put_nowait(outputs)
self.queue_task = asyncio.create_task(process_outputs_socket(),
name="EngineCoreOutputQueueTask")
async def get_output_async(self) -> EngineCoreOutputs:
if self.outputs_queue is None:
await self._start_output_queue_task()
self._ensure_output_queue_task()
assert self.outputs_queue is not None
return await self.outputs_queue.get()
async def _send_input(self, request_type: EngineCoreRequestType,
request: Any) -> None:
await self.core_engine.send_multipart(
(request_type.value, self.encoder.encode(request)))
msg = (request_type.value, self.encoder.encode(request))
await self.input_socket.send_multipart(msg, copy=False)
self._ensure_output_queue_task()
if self.outputs_queue is None:
await self._start_output_queue_task()
async def call_utility_async(self, method: str, *args) -> Any:
return await self._call_utility_async(method,
*args,
engine=self.core_engine)
async def _call_utility_async(self, method: str, *args) -> Any:
async def _call_utility_async(
self,
method: str,
*args,
engine: CoreEngine,
) -> Any:
call_id = uuid.uuid1().int >> 64
future = asyncio.get_running_loop().create_future()
self.utility_results[call_id] = future
await self._send_input(EngineCoreRequestType.UTILITY,
(call_id, method, args))
message = (EngineCoreRequestType.UTILITY.value,
self.encoder.encode((call_id, method, args)))
await engine.send_multipart(message)
self._ensure_output_queue_task()
return await future
async def add_request_async(self, request: EngineCoreRequest) -> None:
......@@ -524,31 +663,162 @@ class AsyncMPClient(MPClient):
await self._send_input(EngineCoreRequestType.ABORT, request_ids)
async def profile_async(self, is_start: bool = True) -> None:
await self._call_utility_async("profile", is_start)
await self.call_utility_async("profile", is_start)
async def reset_prefix_cache_async(self) -> None:
await self._call_utility_async("reset_prefix_cache")
await self.call_utility_async("reset_prefix_cache")
async def sleep_async(self, level: int = 1) -> None:
await self._call_utility_async("sleep", level)
await self.call_utility_async("sleep", level)
async def wake_up_async(self) -> None:
await self._call_utility_async("wake_up")
async def wake_up_async(self, tags: Optional[list[str]] = None) -> None:
await self.call_utility_async("wake_up", tags)
async def is_sleeping_async(self) -> bool:
return await self._call_utility_async("is_sleeping")
return await self.call_utility_async("is_sleeping")
async def execute_dummy_batch_async(self) -> None:
await self._call_utility_async("execute_dummy_batch")
await self.call_utility_async("execute_dummy_batch")
async def add_lora_async(self, lora_request: LoRARequest) -> bool:
return await self._call_utility_async("add_lora", lora_request)
return await self.call_utility_async("add_lora", lora_request)
async def remove_lora_async(self, lora_id: int) -> bool:
return await self._call_utility_async("remove_lora", lora_id)
return await self.call_utility_async("remove_lora", lora_id)
async def list_loras_async(self) -> set[int]:
return await self._call_utility_async("list_loras")
return await self.call_utility_async("list_loras")
async def pin_lora_async(self, lora_id: int) -> bool:
return await self._call_utility_async("pin_lora", lora_id)
return await self.call_utility_async("pin_lora", lora_id)
async def save_sharded_state_async(self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None) -> None:
await self.call_utility_async("save_sharded_state", path, pattern,
max_size)
async def collective_rpc_async(
self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return await self.call_utility_async("collective_rpc", method, timeout,
args, kwargs)
class DPAsyncMPClient(AsyncMPClient):
"""Asyncio-compatible client for multi-proc, multi-engine (data parallel)
EngineCore."""
def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
log_stats: bool):
super().__init__(vllm_config, executor_class, log_stats)
assert len(self.core_engines) > 1
# Control message used for triggering dp idle mode loop.
self.start_dp_msg = (EngineCoreRequestType.START_DP.value,
self.encoder.encode(None))
self.num_engines_running = 0
self.reqs_in_flight: dict[str, CoreEngine] = {}
self.outputs_handler = DPAsyncMPClient.process_engine_outputs # type: ignore[assignment]
def _init_core_engines(
self,
vllm_config: VllmConfig,
new_core_engine: Callable[[int, Optional[int]], CoreEngine],
core_engines: list[CoreEngine],
) -> None:
# Launch a core engine for each data parallel rank.
dp_size = vllm_config.parallel_config.data_parallel_size
for i in range(dp_size):
# Multi-node not yet supported so local_dp_rank == dp_rank.
core_engines.append(new_core_engine(i, i))
self.core_engines = core_engines
async def call_utility_async(self, method: str, *args) -> Any:
# Only the result from the first engine is returned.
return (await asyncio.gather(*[
self._call_utility_async(method, *args, engine=engine)
for engine in self.core_engines
]))[0]
async def add_request_async(self, request: EngineCoreRequest) -> None:
# NOTE: text prompt is not needed in the core engine as it has been
# tokenized.
request.prompt = None
msg = (EngineCoreRequestType.ADD.value, self.encoder.encode(request))
chosen_engine = self.get_core_engine_for_request()
self.reqs_in_flight[request.request_id] = chosen_engine
chosen_engine.num_reqs_in_flight += 1
if self.num_engines_running >= len(self.core_engines):
await chosen_engine.send_multipart(msg)
else:
# Send request to chosen engine and dp start loop
# control message to all other engines.
self.num_engines_running += len(self.core_engines)
await asyncio.gather(*[
engine.send_multipart(msg if engine is
chosen_engine else self.start_dp_msg)
for engine in self.core_engines
])
self._ensure_output_queue_task()
def get_core_engine_for_request(self) -> CoreEngine:
return min(self.core_engines, key=lambda e: e.num_reqs_in_flight)
@staticmethod
async def process_engine_outputs(self: "DPAsyncMPClient",
outputs: EngineCoreOutputs):
if self.reqs_in_flight:
for req_id in outputs.finished_requests or ():
if engine := self.reqs_in_flight.pop(req_id, None):
engine.num_reqs_in_flight -= 1
if outputs.engine_paused:
assert self.num_engines_running >= 1
self.num_engines_running -= 1
if not self.num_engines_running and self.reqs_in_flight:
# If there are requests in flight here, they must have
# been sent after the engines paused. We must make
# sure to start the other engines:
self.num_engines_running = len(self.core_engines)
coros = [
engine.send_multipart(self.start_dp_msg)
for engine in self.core_engines
if not engine.num_reqs_in_flight
]
if coros:
await asyncio.gather(*coros)
async def abort_requests_async(self, request_ids: list[str]) -> None:
if not request_ids:
return
if len(request_ids) == 1:
# Fast-path common case.
if engine := self.reqs_in_flight.get(request_ids[0]):
await self._abort_requests(request_ids, engine)
return
by_engine: dict[CoreEngine, list[str]] = {}
for req_id in request_ids:
if engine := self.reqs_in_flight.get(req_id):
by_engine.setdefault(engine, []).append(req_id)
for engine, req_ids in by_engine.items():
await self._abort_requests(req_ids, engine)
async def _abort_requests(self, request_ids: list[str],
engine: CoreEngine) -> None:
await engine.send_multipart((EngineCoreRequestType.ABORT.value,
self.encoder.encode(request_ids)))
......@@ -2,15 +2,16 @@
from collections.abc import Mapping
from copy import copy
from typing import Optional, Union
from typing import Any, Callable, Optional, Union
from typing_extensions import TypeVar
import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics_types import StatLoggerBase
from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
......@@ -31,6 +32,7 @@ from vllm.v1.executor.abstract import Executor
logger = init_logger(__name__)
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
_R = TypeVar("_R", default=Any)
class LLMEngine:
......@@ -43,7 +45,6 @@ class LLMEngine:
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
use_cached_outputs: bool = False,
multiprocess_mode: bool = False,
......@@ -60,11 +61,13 @@ class LLMEngine:
self.cache_config = vllm_config.cache_config
# important: init dp group before init the engine_core
self.parallel_config = vllm_config.parallel_config
self.dp_enabled = self.parallel_config.data_parallel_size > 1 # noqa
# In the decoupled engine case this is handled in EngineCoreProc.
parallel_config = vllm_config.parallel_config
if not multiprocess_mode and parallel_config.data_parallel_size > 1:
self.dp_group = parallel_config.stateless_init_dp_group()
else:
self.dp_group = None
self.should_execute_dummy_batch = False
if self.dp_enabled:
self.dp_group = self.parallel_config.stateless_init_dp_group()
# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
......@@ -77,7 +80,6 @@ class LLMEngine:
# Processor (convert Inputs --> EngineCoreRequests)
self.processor = Processor(vllm_config=vllm_config,
tokenizer=self.tokenizer,
input_registry=input_registry,
mm_registry=mm_registry)
# OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
......@@ -148,7 +150,7 @@ class LLMEngine:
def has_unfinished_requests(self) -> bool:
has_unfinished = self.output_processor.has_unfinished_requests()
if not self.dp_enabled:
if self.dp_group is None:
return has_unfinished
return self.has_unfinished_requests_dp(has_unfinished)
......@@ -243,8 +245,8 @@ class LLMEngine:
def sleep(self, level: int = 1):
self.engine_core.sleep(level)
def wake_up(self):
self.engine_core.wake_up()
def wake_up(self, tags: Optional[list[str]] = None):
self.engine_core.wake_up(tags)
def is_sleeping(self) -> bool:
return self.engine_core.is_sleeping()
......@@ -280,3 +282,14 @@ class LLMEngine:
def pin_lora(self, lora_id: int) -> bool:
"""Prevent an adapter from being evicted."""
return self.engine_core.pin_lora(lora_id)
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return self.engine_core.collective_rpc(method, timeout, args, kwargs)
def __del__(self):
if dp_group := getattr(self, "dp_group", None):
stateless_destroy_torch_distributed_process_group(dp_group)
......@@ -328,7 +328,7 @@ class OutputProcessor:
# 2) Detokenize the token ids into text and perform stop checks.
stop_string = req_state.detokenizer.update(
new_token_ids, finish_reason == FinishReason.STOP)
if stop_string and finish_reason != FinishReason.STOP:
if stop_string:
finish_reason = FinishReason.STOP
stop_reason = stop_string
......
......@@ -5,9 +5,8 @@ from collections.abc import Mapping
from typing import Optional, Union
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType, SingletonInputsAdapter)
from vllm.inputs.parse import is_encoder_decoder_inputs
from vllm.inputs import ProcessorInputs, PromptType
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
......@@ -31,7 +30,6 @@ class Processor:
self,
vllm_config: VllmConfig,
tokenizer: BaseTokenizerGroup,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
......@@ -123,7 +121,8 @@ class Processor:
return
supported_backends = [
"xgrammar", "xgrammar:disable-any-whitespace", "guidance", "auto"
"xgrammar", "xgrammar:disable-any-whitespace", "guidance",
"guidance:disable-any-whitespace", "auto"
]
engine_level_backend = self.decoding_config.guided_decoding_backend
if engine_level_backend not in supported_backends:
......@@ -137,13 +136,15 @@ class Processor:
f" != {engine_level_backend}")
else:
params.guided_decoding.backend = engine_level_backend
import vllm.platforms
if vllm.platforms.current_platform.is_tpu():
raise ValueError("Structured output is not supported on TPU.")
# Request content validation
if engine_level_backend == "xgrammar":
if engine_level_backend.startswith("xgrammar"):
# xgrammar with no fallback
validate_structured_output_request_xgrammar(params)
params.guided_decoding.backend = "xgrammar"
params.guided_decoding.backend = engine_level_backend
elif engine_level_backend == "auto":
# "auto" is an opt-in to opinionated behavior where we try to
# choose a backend based on request contents. This is not the
......@@ -157,12 +158,13 @@ class Processor:
# are not supported in xgrammar. Fall back to guidance.
params.guided_decoding.backend = "guidance"
if params.guided_decoding.backend == "guidance":
if engine_level_backend.startswith("guidance"):
# TODO ideally we would have the LLTokenizer here as Lark syntax
# allows <|special_token|> and similar, see
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars.
validate_guidance_grammar(params, tokenizer=None)
params.guided_decoding.backend = engine_level_backend
def process_inputs(
self,
......@@ -206,14 +208,7 @@ class Processor:
self._validate_model_inputs(processed_inputs, lora_request)
if is_encoder_decoder_inputs(processed_inputs):
decoder_inputs = SingletonInputsAdapter(
processed_inputs["decoder"])
encoder_inputs = SingletonInputsAdapter(
processed_inputs["encoder"])
else:
decoder_inputs = SingletonInputsAdapter(processed_inputs)
encoder_inputs = None
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
# TODO: Impl encoder-decoder
if encoder_inputs is not None:
......@@ -224,8 +219,9 @@ class Processor:
sampling_params = params.clone()
# If unset max tokens, then generate up to the max_model_len.
if sampling_params.max_tokens is None:
sampling_params.max_tokens = (self.model_config.max_model_len -
len(decoder_inputs.prompt_token_ids))
sampling_params.max_tokens = (
self.model_config.max_model_len -
len(decoder_inputs["prompt_token_ids"]))
sampling_params.update_from_generation_config(
self.generation_config_fields, eos_token_id)
sampling_params.update_from_tokenizer(
......@@ -235,57 +231,46 @@ class Processor:
sorted_mm_inputs: Optional[list[MultiModalKwargs]] = None
sorted_mm_positions: Optional[list[PlaceholderRange]] = None
sorted_mm_hashes: Optional[list[str]] = None
if (decoder_mm_inputs := decoder_inputs.multi_modal_data):
assert isinstance(decoder_mm_inputs, MultiModalKwargs)
# The output of merged multi-modal processor (`decoder_mm_inputs`)
# contains the kwargs for all items from all modalities.
# This code separates them so that there is one set of kwargs
# per item per modality.
individual_mm_inputs = [
MultiModalKwargs.from_items([item])
for modality in decoder_mm_inputs.modalities
for item in decoder_mm_inputs.get_items(modality)
]
if decoder_inputs["type"] == "multimodal":
decoder_mm_inputs = decoder_inputs["mm_kwargs"]
# Merge and flatten multimodal placeholders, hashes and inputs
# from dictionaries to lists, and sort them by each item's position
# in the input sequence.
# NOTE: interleaved modalities are not supported.
(
sorted_modalities,
sorted_item_modalities,
sorted_mm_positions,
sorted_mm_hashes,
) = merge_and_sort_multimodal_metadata(
decoder_inputs.multi_modal_placeholders,
decoder_inputs.multi_modal_hashes if self.use_hash else None,
decoder_inputs["mm_placeholders"],
decoder_inputs["mm_hashes"] if self.use_hash else None,
)
# NOTE: Sort multimodal inputs/kwargs ONLY IF there are multiple
# modalities involved.
if len(sorted_modalities) > 1:
modality_order_dict = {
modality: order
for order, modality in enumerate(sorted_modalities)
}
# Sanity check to make sure each multimodal input has only one
# modality key.
for mm_input in individual_mm_inputs:
assert len(mm_input.modalities) == 1
# Sort MultiModalKwargs to match sorted_mm_positions
sorted_mm_inputs = sorted(
individual_mm_inputs,
key=lambda mm_input: modality_order_dict[list(
mm_input.modalities)[0]])
# The output of merged multi-modal processor (`decoder_mm_inputs`)
# is a single MultiModalKwargs for all items from all modalities.
# This code flattens kwargs for individual items in a list and
# sorts them by each item's position in the input sequence if there
# are multiple modalities.
unique_modalities = set(sorted_item_modalities)
if len(unique_modalities) > 1:
sorted_mm_inputs = []
used_indices = {modality: 0 for modality in unique_modalities}
for modality in sorted_item_modalities:
items = decoder_mm_inputs.get_items(modality)
item = items[used_indices[modality]]
sorted_mm_inputs.append(MultiModalKwargs.from_items([item
]))
used_indices[modality] += 1
else:
sorted_mm_inputs = individual_mm_inputs
sorted_mm_inputs = [
MultiModalKwargs.from_items([item]) for item in
decoder_mm_inputs.get_items(sorted_item_modalities[0])
]
return EngineCoreRequest(
request_id=request_id,
prompt=decoder_inputs.prompt,
prompt_token_ids=decoder_inputs.prompt_token_ids,
prompt=decoder_inputs.get("prompt"),
prompt_token_ids=decoder_inputs["prompt_token_ids"],
mm_inputs=sorted_mm_inputs,
mm_hashes=sorted_mm_hashes,
mm_placeholders=sorted_mm_positions,
......@@ -298,15 +283,16 @@ class Processor:
def _validate_model_inputs(self,
inputs: ProcessorInputs,
lora_request: Optional[LoRARequest] = None):
if is_encoder_decoder_inputs(inputs):
encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
prompt_inputs = inputs["decoder" if self.model_config.
is_multimodal_model else "encoder"]
if self.model_config.is_multimodal_model:
prompt_inputs = decoder_inputs
else:
prompt_inputs = inputs
prompt_inputs = encoder_inputs or decoder_inputs
prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids
prompt_ids = prompt_inputs["prompt_token_ids"]
if prompt_ids is None or len(prompt_ids) == 0:
raise ValueError("Prompt cannot be empty")
......
......@@ -235,7 +235,10 @@ class WorkerProc:
worker_response_mq_handle = self.worker_response_mq.export_handle()
# Send Readiness signal to EngineCore process.
with zmq_socket_ctx(ready_path, zmq.constants.PUSH) as ready_socket:
# Set linger here because we want to ensure the message has
# been sent before the context is closed.
with zmq_socket_ctx(ready_path, zmq.constants.PUSH,
linger=10000) as ready_socket:
payload = pickle.dumps(worker_response_mq_handle,
protocol=pickle.HIGHEST_PROTOCOL)
ready_socket.send_string(WorkerProc.READY_STR)
......@@ -270,11 +273,13 @@ class WorkerProc:
proc = context.Process(target=WorkerProc.worker_main,
kwargs=process_kwargs,
daemon=True)
with zmq_socket_ctx(ready_path, zmq.constants.PULL) as ready_socket:
proc.start()
# Wait for startup
worker_response_mq_handle = WorkerProc.wait_for_startup(
proc, ready_path)
proc, ready_socket)
worker_response_mq = MessageQueue.create_from_handle(
worker_response_mq_handle, 0)
......@@ -337,21 +342,20 @@ class WorkerProc:
@staticmethod
def wait_for_startup(
proc: BaseProcess,
ready_path: str,
ready_socket: zmq.Socket,
) -> Optional[Handle]:
"""Wait until the Worker is ready."""
with zmq_socket_ctx(ready_path, zmq.constants.PULL) as socket:
# Wait for Worker to send READY.
while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
while ready_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
logger.debug("Waiting for WorkerProc to startup.")
if not proc.is_alive():
raise RuntimeError("WorkerProc failed to start.")
message = socket.recv_string()
message = ready_socket.recv_string()
assert message == WorkerProc.READY_STR
handle_frame = socket.recv(copy=False)
handle_frame = ready_socket.recv(copy=False)
handle = pickle.loads(handle_frame.buffer)
return handle
......
......@@ -4,6 +4,7 @@ from dataclasses import dataclass
import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import cdiv, get_dtype_size
......@@ -43,28 +44,23 @@ class KVCacheSpec:
"""
raise NotImplementedError
def bytes_for_tokens(self, num_tokens: int) -> int:
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
"""
The KV cache size for `num_tokens` tokens in bytes. Returns the real
memory size after padding `num_tokens` to full blocks.
The maximum possible memory usage of this KV cache in bytes.
Returns:
The KV cache size
The KV cache size in bytes
"""
raise NotImplementedError
@dataclass
class FullAttentionSpec(KVCacheSpec):
class AttentionSpec(KVCacheSpec):
num_kv_heads: int
head_size: int
dtype: torch.dtype
use_mla: bool
@property
def type_id(self) -> str:
return f"full_attention_{self.block_size}_{self.page_size_bytes}"
@property
def page_size_bytes(self) -> int:
# For MLA we only store a single latent vector
......@@ -72,8 +68,47 @@ class FullAttentionSpec(KVCacheSpec):
return coef * self.block_size * self.num_kv_heads * self.head_size \
* get_dtype_size(self.dtype)
def bytes_for_tokens(self, num_tokens: int) -> int:
return cdiv(num_tokens, self.block_size) * self.page_size_bytes
@dataclass
class FullAttentionSpec(AttentionSpec):
@property
def type_id(self) -> str:
return f"full_attention_{self.block_size}_{self.page_size_bytes}"
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
max_model_len = vllm_config.model_config.max_model_len
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
@dataclass
class SlidingWindowSpec(AttentionSpec):
sliding_window: int
def __post_init__(self):
assert not self.use_mla, "MLA is not supported for sliding window"
@property
def type_id(self) -> str:
return f"sliding_window_{self.sliding_window}_{self.block_size}_{self.page_size_bytes}" # noqa
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
max_model_len = vllm_config.model_config.max_model_len
max_num_batched_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens)
# During chunked prefill, we allocate KV cache for the last
# `self.sliding_window-1` computed tokens plus the newly scheduled
# tokens. And we won't allocate KV cache for more than `max_model_len`
# tokens.
num_tokens = min(self.sliding_window - 1 + max_num_batched_tokens,
max_model_len)
# +1 here because the sliding window may not start from the beginning
# of the block. For example, if the block size is 4 and num_token
# is 4, we need two blocks [XXCD] [EF] to store the sliding
# window [CDEF] of 6 tokens.
return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes
@dataclass
......
......@@ -12,6 +12,7 @@ from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
from vllm.v1.engine import FinishReason
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from vllm.v1.spec_decode.metrics import SpecDecodingMetrics
logger = init_logger(__name__)
......@@ -31,12 +32,14 @@ class StatLoggerBase(ABC):
class LoggingStatLogger(StatLoggerBase):
def __init__(self):
def __init__(self, engine_index: int = 0):
self.engine_index = engine_index
self._reset(time.monotonic())
self.last_scheduler_stats = SchedulerStats()
# Prefix cache metrics. This cannot be reset.
# TODO: Make the interval configurable.
self.prefix_caching_metrics = PrefixCachingMetrics()
self.spec_decoding_metrics = SpecDecodingMetrics()
def _reset(self, now):
self.last_log_time = now
......@@ -64,6 +67,10 @@ class LoggingStatLogger(StatLoggerBase):
self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)
if scheduler_stats.spec_decoding_stats is not None:
self.spec_decoding_metrics.observe(
scheduler_stats.spec_decoding_stats)
self.last_scheduler_stats = scheduler_stats
def log(self):
......@@ -78,11 +85,13 @@ class LoggingStatLogger(StatLoggerBase):
# Format and print output.
logger.info(
"Engine %03d: "
"Avg prompt throughput: %.1f tokens/s, "
"Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Waiting: %d reqs, "
"GPU KV cache usage: %.1f%%, "
"Prefix cache hit rate: %.1f%%",
self.engine_index,
prompt_throughput,
generation_throughput,
scheduler_stats.num_running_reqs,
......@@ -91,10 +100,13 @@ class LoggingStatLogger(StatLoggerBase):
self.prefix_caching_metrics.hit_rate * 100,
)
if scheduler_stats.spec_decoding_stats is not None:
self.spec_decoding_metrics.log()
class PrometheusStatLogger(StatLoggerBase):
def __init__(self, vllm_config: VllmConfig):
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
self._unregister_vllm_metrics()
# Use this flag to hide metrics that were deprecated in
......@@ -102,8 +114,11 @@ class PrometheusStatLogger(StatLoggerBase):
self.show_hidden_metrics = \
vllm_config.observability_config.show_hidden_metrics
labelnames = ["model_name"]
labelvalues = [vllm_config.model_config.served_model_name]
labelnames = ["model_name", "engine"]
labelvalues = [
vllm_config.model_config.served_model_name,
str(engine_index)
]
max_model_len = vllm_config.model_config.max_model_len
......@@ -296,6 +311,24 @@ class PrometheusStatLogger(StatLoggerBase):
self.labelname_running_lora_adapters,
])
#
# Speculative Decoding metrics
# The acceptance rate can be calculated using a PromQL query:
#
# rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
# rate(vllm:spec_decode_num_draft_tokens_total[$interval])
#
self.counter_spec_decode_num_draft_tokens = \
prometheus_client.Counter(
name="vllm:spec_decode_num_draft_tokens_total",
documentation="Number of draft tokens.",
labelnames=labelnames).labels(*labelvalues)
self.counter_spec_decode_num_accepted_tokens = \
prometheus_client.Counter(
name="vllm:spec_decode_num_accepted_tokens_total",
documentation="Number of accepted tokens.",
labelnames=labelnames).labels(*labelvalues)
#
# Cache config info metric
#
......@@ -332,6 +365,12 @@ class PrometheusStatLogger(StatLoggerBase):
self.counter_gpu_prefix_cache_hits.inc(
scheduler_stats.prefix_cache_stats.hits)
if scheduler_stats.spec_decoding_stats is not None:
self.counter_spec_decode_num_draft_tokens.inc(
scheduler_stats.spec_decoding_stats.num_draft_tokens)
self.counter_spec_decode_num_accepted_tokens.inc(
scheduler_stats.spec_decoding_stats.num_accepted_tokens)
if iteration_stats is None:
return
......
......@@ -4,6 +4,8 @@ import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional
from vllm.v1.spec_decode.metrics import SpecDecodingStats
if TYPE_CHECKING:
from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
from vllm.v1.engine.output_processor import RequestState
......@@ -35,6 +37,8 @@ class SchedulerStats:
prefix_cache_stats: PrefixCacheStats = field(
default_factory=PrefixCacheStats)
spec_decoding_stats: Optional[SpecDecodingStats] = None
@dataclass
class LoRAStats:
......
......@@ -59,6 +59,8 @@ class Request:
self.mm_positions = multi_modal_placeholders or []
self.mm_inputs = multi_modal_inputs or []
self.mm_hashes: list[str] = multi_modal_hashes or []
self.num_encoder_inputs = len(self.mm_inputs)
self.has_encoder_inputs = self.num_encoder_inputs > 0
# Sanity check
assert len(self.mm_inputs) == len(self.mm_positions)
......@@ -93,7 +95,9 @@ class Request:
token_ids: Union[int, list[int]],
) -> None:
if isinstance(token_ids, int):
token_ids = [token_ids]
self._output_token_ids.append(token_ids)
self._all_token_ids.append(token_ids)
else:
self._output_token_ids.extend(token_ids)
self._all_token_ids.extend(token_ids)
......@@ -115,13 +119,6 @@ class Request:
def get_finished_reason(self) -> Union[FinishReason, None]:
return RequestStatus.get_finished_reason(self.status)
def has_encoder_inputs(self) -> bool:
return len(self.mm_inputs) > 0
@property
def num_encoder_inputs(self) -> int:
return len(self.mm_positions)
def get_num_encoder_tokens(self, input_id: int) -> int:
assert input_id < len(self.mm_positions)
num_tokens = self.mm_positions[input_id]["length"]
......
......@@ -19,6 +19,12 @@ except ImportError:
class TopKTopPSampler(nn.Module):
"""
Module that performs optional top-k and top-p filtering followed by
weighted random sampling of logits.
Implementations may update the logits tensor in-place.
"""
def __init__(self):
super().__init__()
......@@ -84,7 +90,11 @@ class TopKTopPSampler(nn.Module):
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
) -> torch.Tensor:
"""PyTorch-native implementation of top-k and top-p sampling."""
"""
PyTorch-native implementation of top-k and top-p sampling.
The logits tensor may be updated in-place.
"""
logits = apply_top_k_top_p(logits, k, p)
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators)
......@@ -112,23 +122,48 @@ class TopKTopPSampler(nn.Module):
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
) -> torch.Tensor:
# If only top-k is specified, use pytorch's builtin topk op. This leads
# to significant speed up on TPU compared to using apply_top_k_top_p.
if k is not None and p is None:
topk_values, topk_indices = torch.topk(logits, k, dim=-1)
mask = torch.ones_like(logits, dtype=torch.bool)
mask.scatter_(-1, topk_indices, False)
logits.masked_fill_(mask, float('-inf'))
else:
# TODO Placeholder for TPU optimized topp kernel
# logits = apply_top_k_top_p(logits, k, p)
pass
logits = apply_top_k_top_p_tpu(logits, k, p)
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators)
def apply_top_k_top_p_tpu(
logits: torch.Tensor,
k: torch.Tensor,
p: torch.Tensor,
) -> torch.Tensor:
"""
Apply top-k and top-p optimized for TPU.
This algorithm avoids using torch.scatter which is extremely slow on TPU.
This is achieved by finding a "cut-off" element in the original logit, and
after thresholding the logit using this cut-off, the remaining elements
shall constitute the top-p set.
Note: in the case of tie (i.e. multipple cut-off elements present in the
logit), all tie elements are included in the top-p set. In other words,
this function does not break ties. Instead, these tie tokens have equal
chance of being chosen during final sampling, so we can consider the tie
being broken then.
"""
if k is not None:
logits = apply_top_k_only(logits, k)
if p is not None:
probs = logits.softmax(dim=-1)
probs_sort, _ = probs.sort(dim=-1, descending=False)
cumprob = torch.cumsum(probs_sort, dim=-1)
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
top_p_mask[:, -1] = False # at least one
top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
top_p_cutoff = probs_sort.gather(-1, top_p_count)
elements_to_discard = probs < top_p_cutoff
logits.masked_fill_(elements_to_discard, -float("inf"))
return logits
def apply_top_k_top_p(
logits: torch.Tensor,
k: Optional[torch.Tensor],
......@@ -136,10 +171,18 @@ def apply_top_k_top_p(
) -> torch.Tensor:
"""Apply top-k and top-p masks to the logits.
This function sorts the logits tensor, which can be slow for large batches.
If a top-p is used, this function will sort the logits tensor,
which can be slow for large batches.
The logits tensor may be updated in-place.
"""
if k is None and p is None:
if p is None:
if k is None:
return logits
# Avoid sorting vocab for top-k only case.
return apply_top_k_only(logits, k)
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
if k is not None:
......@@ -153,7 +196,7 @@ def apply_top_k_top_p(
if p is not None:
# Apply top-p.
probs_sort = logits_sort.softmax(dim=-1)
probs_sum = probs_sort.cumsum(dim=-1)
probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
# at least one
top_p_mask[:, -1] = False
......@@ -164,6 +207,31 @@ def apply_top_k_top_p(
return logits
def apply_top_k_only(
logits: torch.Tensor,
k: torch.Tensor,
) -> torch.Tensor:
"""
Apply top-k mask to the logits.
This implementation doesn't involve sorting the entire vocab.
The logits tensor may be updated in-place.
"""
no_top_k_mask = k == logits.shape[1]
# Set non-top-k rows to 1 so that we can gather.
k = k.masked_fill(no_top_k_mask, 1)
max_top_k = k.max()
# topk.values tensor has shape [batch_size, max_top_k].
# Convert top k to 0-based index in range [0, max_top_k).
k_index = k.sub_(1).unsqueeze(1).expand(logits.shape[0], 1)
top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
# Handle non-topk rows.
top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
logits.masked_fill_(logits < top_k_mask, -float("inf"))
return logits
def random_sample(
probs: torch.Tensor,
generators: dict[int, torch.Generator],
......
......@@ -109,6 +109,18 @@ class RejectionSampler(nn.Module):
output_token_ids: torch.Tensor,
vocab_size: int,
) -> list[list[int]]:
"""Parse the output of the rejection sampler.
Args:
output_token_ids: The sampled token IDs in shape
[batch_size, max_spec_len + 1]. The rejected tokens are
replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
and will be filtered out in this function.
vocab_size: The size of the vocabulary.
Returns:
A list of lists of token IDs.
"""
output_token_ids_np = output_token_ids.cpu().numpy()
# Create mask for valid tokens.
valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
......
......@@ -87,6 +87,12 @@ class Sampler(nn.Module):
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
"""Sample logits based on sampling metadata.
The various logits processing functions called in this method
may update the logits tensor in-place.
"""
assert not (sampling_metadata.all_greedy
and sampling_metadata.all_random)
if sampling_metadata.all_random:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment