Commit 675ba75f authored by zhuwenwen
Browse files

Merge tag 'v0.8.3' into v0.8.3-ori

parents 5cc98918 296c6572
...@@ -10,8 +10,7 @@ if TYPE_CHECKING: ...@@ -10,8 +10,7 @@ if TYPE_CHECKING:
import numpy.typing as npt import numpy.typing as npt
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.base import PlaceholderRange
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.request import Request from vllm.v1.request import Request
......
...@@ -7,9 +7,9 @@ from collections import deque ...@@ -7,9 +7,9 @@ from collections import deque
from collections.abc import Iterable from collections.abc import Iterable
from typing import Optional, Union from typing import Optional, Union
from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig, from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig
SpeculativeConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
compute_encoder_budget) compute_encoder_budget)
from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.core.kv_cache_manager import KVCacheManager
...@@ -19,9 +19,11 @@ from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, ...@@ -19,9 +19,11 @@ from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
from vllm.v1.core.sched.utils import check_stop from vllm.v1.core.sched.utils import check_stop
from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
EngineCoreOutputs) EngineCoreOutputs)
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.structured_output import StructuredOutputManager
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -35,32 +37,37 @@ class Scheduler(SchedulerInterface): ...@@ -35,32 +37,37 @@ class Scheduler(SchedulerInterface):
model_config: ModelConfig, model_config: ModelConfig,
cache_config: CacheConfig, cache_config: CacheConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
speculative_config: Optional[SpeculativeConfig], kv_cache_config: KVCacheConfig,
log_stats: bool,
structured_output_manager: StructuredOutputManager, structured_output_manager: StructuredOutputManager,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
include_finished_set: bool = False,
log_stats: bool = False,
) -> None: ) -> None:
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.cache_config = cache_config self.cache_config = cache_config
self.lora_config = lora_config self.lora_config = lora_config
self.speculative_config = speculative_config self.kv_cache_config = kv_cache_config
self.log_stats = log_stats self.log_stats = log_stats
self.structured_output_manager = structured_output_manager self.structured_output_manager = structured_output_manager
# include_finished_set controls whether a separate set of finished
# request ids should be included in the EngineCoreOutputs returned
# by update_from_outputs(). This is currently used in the multi-engine
# case to track request lifetimes efficiently.
self.include_finished_set = include_finished_set
# Scheduling constraints. # Scheduling constraints.
self.max_num_running_reqs = self.scheduler_config.max_num_seqs self.max_num_running_reqs = self.scheduler_config.max_num_seqs
self.max_num_scheduled_tokens = \ self.max_num_scheduled_tokens = \
self.scheduler_config.max_num_batched_tokens self.scheduler_config.max_num_batched_tokens
self.max_model_len = self.scheduler_config.max_model_len self.max_model_len = self.scheduler_config.max_model_len
num_gpu_blocks = cache_config.num_gpu_blocks
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
# Create the KV cache manager. # Create the KV cache manager.
self.kv_cache_manager = KVCacheManager( self.kv_cache_manager = KVCacheManager(
block_size=self.cache_config.block_size, kv_cache_config=kv_cache_config,
num_gpu_blocks=num_gpu_blocks,
max_model_len=self.max_model_len, max_model_len=self.max_model_len,
sliding_window=self.cache_config.sliding_window, enable_caching=cache_config.enable_prefix_caching,
enable_caching=self.cache_config.enable_prefix_caching, caching_hash_algo=self.cache_config.prefix_caching_hash_algo,
log_stats=self.log_stats) log_stats=self.log_stats)
self.block_size = self.cache_config.block_size self.block_size = self.cache_config.block_size
...@@ -92,6 +99,7 @@ class Scheduler(SchedulerInterface): ...@@ -92,6 +99,7 @@ class Scheduler(SchedulerInterface):
encoder_compute_budget, encoder_cache_size = compute_encoder_budget( encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
model_config=model_config, model_config=model_config,
scheduler_config=scheduler_config, scheduler_config=scheduler_config,
mm_registry=mm_registry,
) )
# NOTE(woosuk): Here, "encoder" includes the vision encoder (and # NOTE(woosuk): Here, "encoder" includes the vision encoder (and
...@@ -152,23 +160,31 @@ class Scheduler(SchedulerInterface): ...@@ -152,23 +160,31 @@ class Scheduler(SchedulerInterface):
num_new_tokens = (request.num_tokens_with_spec - num_new_tokens = (request.num_tokens_with_spec -
request.num_computed_tokens) request.num_computed_tokens)
if (0 < self.scheduler_config.long_prefill_token_threshold <
num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
num_new_tokens = min(num_new_tokens, token_budget) num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0 assert num_new_tokens > 0
# Schedule encoder inputs. # Schedule encoder inputs.
encoder_inputs_to_schedule, num_new_tokens, new_encoder_budget = ( if request.has_encoder_inputs:
self._try_schedule_encoder_inputs(request, (encoder_inputs_to_schedule, num_new_tokens,
request.num_computed_tokens, new_encoder_budget) = self._try_schedule_encoder_inputs(
num_new_tokens, request, request.num_computed_tokens, num_new_tokens,
encoder_budget)) encoder_budget)
if num_new_tokens == 0: if num_new_tokens == 0:
# The request cannot be scheduled because the encoder budget # The request cannot be scheduled because the encoder budget
# or the encoder cache is exhausted. # or the encoder cache is exhausted.
# NOTE(woosuk): Here, by doing `continue` instead of `break`, # NOTE(woosuk): By using `continue` instead of `break` here,
# we do not strictly follow the FCFS scheduling policy and # we intentionally relax the strict FCFS scheduling policy
# allow the lower-priority requests to be scheduled. # to allow lower-priority requests to be scheduled when a
req_index += 1 # higher-priority request is blocked by encoder constraints.
continue req_index += 1
continue
else:
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
while True: while True:
new_blocks = self.kv_cache_manager.allocate_slots( new_blocks = self.kv_cache_manager.allocate_slots(
...@@ -235,16 +251,16 @@ class Scheduler(SchedulerInterface): ...@@ -235,16 +251,16 @@ class Scheduler(SchedulerInterface):
encoder_budget = new_encoder_budget encoder_budget = new_encoder_budget
# Record the LoRAs in scheduled_running_reqs # Record the LoRAs in scheduled_running_reqs
requested_loras: set[int] = set() scheduled_loras: set[int] = set()
if self.lora_config: if self.lora_config:
requested_loras = set( scheduled_loras = set(
req.lora_request.lora_int_id for req in scheduled_running_reqs req.lora_request.lora_int_id for req in scheduled_running_reqs
if req.lora_request and req.lora_request.lora_int_id > 0) if req.lora_request and req.lora_request.lora_int_id > 0)
assert len(requested_loras) <= self.lora_config.max_loras assert len(scheduled_loras) <= self.lora_config.max_loras
# Use a temporary deque to collect requests that need to be skipped # Use a temporary deque to collect requests that need to be skipped
# and put back at the head of the waiting queue later # and put back at the head of the waiting queue later
waiting_for_fsm: deque[Request] = deque() skipped_waiting_requests: deque[Request] = deque()
# Next, schedule the WAITING requests. # Next, schedule the WAITING requests.
if not preempted_reqs: if not preempted_reqs:
...@@ -254,31 +270,27 @@ class Scheduler(SchedulerInterface): ...@@ -254,31 +270,27 @@ class Scheduler(SchedulerInterface):
request = self.waiting[0] request = self.waiting[0]
# Skip request if the structured output request is still waiting
# for FSM compilation.
if request.status == RequestStatus.WAITING_FOR_FSM: if request.status == RequestStatus.WAITING_FOR_FSM:
structured_output_req = request.structured_output_request structured_output_req = request.structured_output_request
if structured_output_req and structured_output_req.grammar: if structured_output_req and structured_output_req.grammar:
request.status = RequestStatus.WAITING request.status = RequestStatus.WAITING
else: else:
waiting_structured_output_req = self.waiting.popleft() self.waiting.popleft()
waiting_for_fsm.appendleft( skipped_waiting_requests.appendleft(request)
waiting_structured_output_req)
continue continue
# Check that adding the request still respects the max_loras # Check that adding the request still respects the max_loras
# constraint. # constraint.
if self.lora_config and request.lora_request: if self.lora_config and request.lora_request and (
req_lora_id = request.lora_request.lora_int_id len(scheduled_loras) == self.lora_config.max_loras
if len(requested_loras) == self.lora_config.max_loras and ( and request.lora_request.lora_int_id
req_lora_id not in requested_loras): not in scheduled_loras):
# Cannot schedule. # Scheduling would exceed max_loras, skip.
# TODO (varun): This means all the other requests in self.waiting.popleft()
# the WAITING queue will be blocked by this request, skipped_waiting_requests.appendleft(request)
# even if, continue
# 1. these other requests do not use LoRA, or,
# 2. these other requests use the already requested
# LoRAs.
# This is too conservative and could be optimized.
break
# Get already-cached tokens. # Get already-cached tokens.
computed_blocks, num_computed_tokens = \ computed_blocks, num_computed_tokens = \
...@@ -288,28 +300,25 @@ class Scheduler(SchedulerInterface): ...@@ -288,28 +300,25 @@ class Scheduler(SchedulerInterface):
# `request.num_prompt_tokens` to consider the resumed requests, # `request.num_prompt_tokens` to consider the resumed requests,
# which have output tokens. # which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens num_new_tokens = request.num_tokens - num_computed_tokens
if num_new_tokens == 0: if (0 < self.scheduler_config.long_prefill_token_threshold <
# This happens when prompt length is divisible by the block num_new_tokens):
# size and all blocks are cached. Now we force to recompute num_new_tokens = (
# the last block. Note that we have to re-compute an entire self.scheduler_config.long_prefill_token_threshold)
# block because allocate_slots() assumes num_computed_tokens
# is always a multiple of the block size. This limitation
# can potentially be removed in the future to slightly
# improve the performance.
num_computed_tokens -= self.block_size
num_new_tokens = self.block_size
computed_blocks.pop()
num_new_tokens = min(num_new_tokens, token_budget) num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0 assert num_new_tokens > 0
# Schedule encoder inputs. # Schedule encoder inputs.
(encoder_inputs_to_schedule, num_new_tokens, if request.has_encoder_inputs:
new_encoder_budget) = self._try_schedule_encoder_inputs( (encoder_inputs_to_schedule, num_new_tokens,
request, num_computed_tokens, num_new_tokens, new_encoder_budget) = self._try_schedule_encoder_inputs(
encoder_budget) request, num_computed_tokens, num_new_tokens,
if num_new_tokens == 0: encoder_budget)
# The request cannot be scheduled. if num_new_tokens == 0:
break # The request cannot be scheduled.
break
else:
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
new_blocks = self.kv_cache_manager.allocate_slots( new_blocks = self.kv_cache_manager.allocate_slots(
request, num_new_tokens, computed_blocks) request, num_new_tokens, computed_blocks)
...@@ -336,7 +345,7 @@ class Scheduler(SchedulerInterface): ...@@ -336,7 +345,7 @@ class Scheduler(SchedulerInterface):
f"Invalid request status: {request.status}") f"Invalid request status: {request.status}")
if self.lora_config and request.lora_request: if self.lora_config and request.lora_request:
requested_loras.add(request.lora_request.lora_int_id) scheduled_loras.add(request.lora_request.lora_int_id)
req_to_new_block_ids[request.request_id] = [ req_to_new_block_ids[request.request_id] = [
b.block_id for b in computed_blocks + new_blocks b.block_id for b in computed_blocks + new_blocks
] ]
...@@ -355,8 +364,8 @@ class Scheduler(SchedulerInterface): ...@@ -355,8 +364,8 @@ class Scheduler(SchedulerInterface):
encoder_budget = new_encoder_budget encoder_budget = new_encoder_budget
# Put back any skipped requests at the head of the waiting queue # Put back any skipped requests at the head of the waiting queue
if waiting_for_fsm: if skipped_waiting_requests:
self.waiting.extendleft(waiting_for_fsm) self.waiting.extendleft(skipped_waiting_requests)
# Check if the scheduling constraints are satisfied. # Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
...@@ -425,6 +434,18 @@ class Scheduler(SchedulerInterface): ...@@ -425,6 +434,18 @@ class Scheduler(SchedulerInterface):
grammar_bitmask=grammar_bitmask, grammar_bitmask=grammar_bitmask,
) )
# Advance the number of computed tokens for the request AFTER
# the request is scheduled.
# 1. The scheduler_output of the current step has to include the
# original number of scheduled tokens to determine input IDs.
# 2. Advance the number of computed tokens here allowing us to
# schedule the prefill request again immediately in the next
# scheduling step.
# 3. If some tokens (e.g. spec tokens) are rejected later, the number of
# computed tokens will be adjusted in update_from_output.
for req_id, num_scheduled_token in num_scheduled_tokens.items():
self.requests[req_id].num_computed_tokens += num_scheduled_token
self.finished_req_ids = set() self.finished_req_ids = set()
return scheduler_output return scheduler_output
...@@ -479,9 +500,6 @@ class Scheduler(SchedulerInterface): ...@@ -479,9 +500,6 @@ class Scheduler(SchedulerInterface):
limitations, the method adjusts `num_new_tokens` to schedule only the limitations, the method adjusts `num_new_tokens` to schedule only the
decoder tokens up to just before the unschedulable encoder input. decoder tokens up to just before the unschedulable encoder input.
""" """
if not request.has_encoder_inputs():
return [], num_new_tokens, encoder_budget
encoder_inputs_to_schedule: list[int] = [] encoder_inputs_to_schedule: list[int] = []
mm_positions = request.mm_positions mm_positions = request.mm_positions
assert mm_positions is not None assert mm_positions is not None
...@@ -539,6 +557,7 @@ class Scheduler(SchedulerInterface): ...@@ -539,6 +557,7 @@ class Scheduler(SchedulerInterface):
new_running: list[Request] = [] new_running: list[Request] = []
outputs: list[EngineCoreOutput] = [] outputs: list[EngineCoreOutput] = []
spec_decoding_stats: Optional[SpecDecodingStats] = None
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
# loop can be a performance bottleneck. We should do our best to avoid # loop can be a performance bottleneck. We should do our best to avoid
...@@ -553,36 +572,32 @@ class Scheduler(SchedulerInterface): ...@@ -553,36 +572,32 @@ class Scheduler(SchedulerInterface):
req_index = model_runner_output.req_id_to_index[req_id] req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = sampled_token_ids[req_index] generated_token_ids = sampled_token_ids[req_index]
if req_id not in scheduler_output.scheduled_spec_decode_tokens:
# When the request's num_computed_tokens catches up scheduled_spec_token_ids = (
# its num_tokens, the request generates output tokens. scheduler_output.scheduled_spec_decode_tokens.get(req_id))
# Otherwise, we ignore the sampler output for the request. if scheduled_spec_token_ids:
request.num_computed_tokens += num_tokens_scheduled # num_computed_tokens represents the number of tokens
assert request.num_computed_tokens <= request.num_tokens
else:
# num_computed_tokens_step represents the number of tokens
# processed in the current step, considering scheduled # processed in the current step, considering scheduled
# tokens and rejections. # tokens and rejections. If some tokens are rejected,
# It is calculated as: # num_computed_tokens is decreased by the number of rejected
# num_computed_tokens_step = num_scheduled_tokens - # tokens, where is given by:
# num_tokens_rejected,
# where num_tokens_rejected is given by:
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids). # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
scheduled_spec_token_ids = ( num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
scheduler_output.scheduled_spec_decode_tokens[req_id]) len(generated_token_ids))
request.num_computed_tokens -= num_tokens_rejected
num_computed_tokens_step = num_scheduled_tokens[req_id] - ( spec_decoding_stats = self.make_spec_decoding_stats(
len(scheduled_spec_token_ids) + 1 - spec_decoding_stats,
len(generated_token_ids)) num_draft_tokens=len(scheduled_spec_token_ids),
request.num_computed_tokens += num_computed_tokens_step num_accepted_tokens=len(generated_token_ids) - 1)
cached_encoder_input_ids = ( cached_encoder_input_ids = (
self.encoder_cache_manager.get_cached_input_ids(request)) self.encoder_cache_manager.get_cached_input_ids(request))
# OPTIMIZATION: Avoid list(set) if the set is empty. # OPTIMIZATION: Avoid list(set) if the set is empty.
if cached_encoder_input_ids: if cached_encoder_input_ids:
for input_id in list(cached_encoder_input_ids): for input_id in list(cached_encoder_input_ids):
start_pos = request.mm_positions[input_id]["offset"] mm_positions = request.mm_positions[input_id]
num_tokens = request.mm_positions[input_id]["length"] start_pos = mm_positions["offset"]
num_tokens = mm_positions["length"]
if start_pos + num_tokens <= request.num_computed_tokens: if start_pos + num_tokens <= request.num_computed_tokens:
# The encoder output is already processed and stored # The encoder output is already processed and stored
# in the decoder's KV cache. # in the decoder's KV cache.
...@@ -595,35 +610,34 @@ class Scheduler(SchedulerInterface): ...@@ -595,35 +610,34 @@ class Scheduler(SchedulerInterface):
stopped = False stopped = False
new_logprobs = None new_logprobs = None
new_token_ids: list[int] = [] new_token_ids = generated_token_ids
if request.num_computed_tokens >= request.num_tokens: # Append generated tokens and check for stop. Note that if
for output_token_id in generated_token_ids: # a request is still being prefilled, we expect the model runner
request.append_output_token_ids(output_token_id) # to return empty token ids for the request.
new_token_ids.append(output_token_id) for num_new, output_token_id in enumerate(new_token_ids, 1):
request.append_output_token_ids(output_token_id)
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput. # Check for stop and update request state.
stopped = check_stop(request, self.max_model_len) # This must be called before we make the EngineCoreOutput.
if stopped: stopped = check_stop(request, self.max_model_len)
self._free_request(request) if stopped:
break self._free_request(request)
del new_token_ids[num_new:] # Trim new tokens if needed.
break
# Extract sample logprobs if needed. # Extract sample logprobs if needed.
if request.sampling_params.logprobs is not None: if request.sampling_params.logprobs is not None and logprobs:
assert logprobs is not None # NOTE: once we support N tokens per step (spec decode),
# NOTE: once we support N tokens per step (spec decode), # the outer lists can be of length > 1.
# the outer lists can be of length > 1. new_logprobs = logprobs.slice(req_index, req_index + 1)
new_logprobs = logprobs.slice(req_index, req_index + 1)
if new_token_ids and request.use_structured_output: if new_token_ids and request.use_structured_output:
# NOTE: structured_output_request # NOTE: structured_output_request
# should not be None if use_structured_output, we have # should not be None if use_structured_output, we have
# check above, so safe to ignore type warning # check above, so safe to ignore type warning
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
request.request_id, req_id, new_token_ids)
new_token_ids,
)
# Get prompt logprobs for this request. # Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
...@@ -642,15 +656,21 @@ class Scheduler(SchedulerInterface): ...@@ -642,15 +656,21 @@ class Scheduler(SchedulerInterface):
# Invariant: EngineCore returns no partial prefill outputs. # Invariant: EngineCore returns no partial prefill outputs.
assert not prompt_logprobs_tensors assert not prompt_logprobs_tensors
self.scheduled_req_ids.remove(request.request_id) self.scheduled_req_ids.remove(req_id)
if not stopped: if not stopped:
new_running.append(request) new_running.append(request)
self.running = new_running self.running = new_running
return EngineCoreOutputs( engine_core_outputs = EngineCoreOutputs(
outputs=outputs, outputs=outputs,
scheduler_stats=self.make_stats(), scheduler_stats=self.make_stats(spec_decoding_stats),
) )
if self.include_finished_set:
#TODO currently sending duplicates here, improve this
engine_core_outputs.finished_requests = (
scheduler_output.finished_req_ids | self.finished_req_ids)
return engine_core_outputs
def add_request(self, request: Request) -> None: def add_request(self, request: Request) -> None:
self.waiting.append(request) self.waiting.append(request)
...@@ -710,7 +730,10 @@ class Scheduler(SchedulerInterface): ...@@ -710,7 +730,10 @@ class Scheduler(SchedulerInterface):
def reset_prefix_cache(self) -> bool: def reset_prefix_cache(self) -> bool:
return self.kv_cache_manager.reset_prefix_cache() return self.kv_cache_manager.reset_prefix_cache()
def make_stats(self) -> Optional[SchedulerStats]: def make_stats(
self,
spec_decoding_stats: Optional[SpecDecodingStats] = None,
) -> Optional[SchedulerStats]:
if not self.log_stats: if not self.log_stats:
return None return None
return SchedulerStats( return SchedulerStats(
...@@ -718,4 +741,19 @@ class Scheduler(SchedulerInterface): ...@@ -718,4 +741,19 @@ class Scheduler(SchedulerInterface):
num_waiting_reqs=len(self.waiting), num_waiting_reqs=len(self.waiting),
gpu_cache_usage=self.kv_cache_manager.usage, gpu_cache_usage=self.kv_cache_manager.usage,
prefix_cache_stats=self.kv_cache_manager.make_prefix_cache_stats(), prefix_cache_stats=self.kv_cache_manager.make_prefix_cache_stats(),
spec_decoding_stats=spec_decoding_stats,
) )
def make_spec_decoding_stats(
self,
spec_decoding_stats: Optional[SpecDecodingStats],
num_draft_tokens: int,
num_accepted_tokens: int,
) -> Optional[SpecDecodingStats]:
if not self.log_stats:
return None
if spec_decoding_stats is None:
spec_decoding_stats = SpecDecodingStats()
spec_decoding_stats.observe(num_draft_tokens=num_draft_tokens,
num_accepted_tokens=num_accepted_tokens)
return spec_decoding_stats
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from vllm.utils import cdiv
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
SlidingWindowSpec)
class SpecializedManager(ABC):
    """
    Abstract interface for the per-attention-type KV cache managers.

    Each concrete subclass supplies the prefix-cache lookup and the
    block-skipping policy for one kind of attention layer.
    """

    def __init__(
        self,
        kv_cache_spec: KVCacheSpec,
        block_pool: BlockPool,
    ) -> None:
        """
        Record the spec, the block pool, and the spec's block size.

        Args:
            kv_cache_spec: The kv_cache_spec for this manager.
            block_pool: The block pool.
        """
        self.kv_cache_spec = kv_cache_spec
        self.block_pool = block_pool
        # Cached on the instance so subclasses can use it directly.
        self.block_size = kv_cache_spec.block_size

    @abstractmethod
    def find_longest_cache_hit(
            self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
        """
        Return the longest cached prefix for the given block hashes, or an
        empty list when nothing is cached.

        Args:
            block_hashes: The block hashes of the request.

        Returns:
            The cached blocks, with every skipped block replaced by the null
            block. A sliding window manager with block size 4 and sliding
            window 8, for instance, should return something like
            [NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)].
        """
        raise NotImplementedError

    @abstractmethod
    def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
                              num_computed_tokens: int) -> list[KVCacheBlock]:
        """
        Drop the blocks in `blocks` that are no longer needed, substituting
        null_block for each one. The removed blocks are returned in eviction
        order (the first returned block should be evicted first).
        Implementations must not free the removed blocks here.

        Args:
            blocks: The list of blocks to be updated.
            num_computed_tokens: The number of tokens that have been computed.

        Returns:
            The removed blocks in eviction order.
        """
        raise NotImplementedError
class FullAttentionManager(SpecializedManager):
    """Specialized manager for full-attention layers."""

    def find_longest_cache_hit(
            self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
        """
        Collect cached blocks from the start of `block_hashes` until the
        first hash that has no cached block.

        Args:
            block_hashes: The block hashes of the request.

        Returns:
            The cached blocks for the leading run of hits (possibly empty).
        """
        hits: list[KVCacheBlock] = []
        lookup = self.block_pool.get_cached_block
        for block_hash in block_hashes:
            cached_block = lookup(block_hash)
            if not cached_block:
                # block_hashes is a chain of hashes: once one hash misses,
                # the hashes after it are not computed yet either.
                return hits
            hits.append(cached_block)
        return hits

    def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
                              num_computed_tokens: int) -> list[KVCacheBlock]:
        """
        Full attention never skips blocks, so nothing is removed.

        Returns:
            Always an empty list; `blocks` is left untouched.
        """
        nothing_removed: list[KVCacheBlock] = []
        return nothing_removed
class SlidingWindowManager(SpecializedManager):
    """Specialized manager for sliding-window attention layers."""

    def __init__(self, kv_cache_spec: SlidingWindowSpec,
                 block_pool: BlockPool):
        """
        Args:
            kv_cache_spec: The sliding window spec for this manager.
            block_pool: The block pool; its `null_block` is used as the
                placeholder for skipped blocks.
        """
        super().__init__(kv_cache_spec, block_pool)
        window = kv_cache_spec.sliding_window
        self.sliding_window = window
        # Number of contiguous cached blocks required for a prefix cache
        # hit. The window is reduced by one because the input token itself
        # is also included in the window.
        self.sliding_window_contiguous_blocks = cdiv(window - 1,
                                                     self.block_size)
        self._null_block = block_pool.null_block

    def find_longest_cache_hit(
            self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
        """
        Scan `block_hashes` from right to left for a run of cached blocks
        that is at least `sliding_window_contiguous_blocks` long.

        Args:
            block_hashes: The block hashes of the request.

        Returns:
            A list that ends at the rightmost sufficiently long run, with
            every position before that run set to the null block. If no such
            run exists, only the cached blocks forming a run at the very
            start of `block_hashes` are returned (possibly none).
        """
        # TODO: reduce i by sliding_window_contiguous_blocks when cache miss,
        # to optimize the time complexity from O(len(block_hashes)) to
        # O(len(block_hashes) / sliding_window_contiguous_blocks +
        # sliding_window_contiguous_blocks),
        # which is good for low cache hit rate scenarios.
        num_hashes = len(block_hashes)
        required_run = self.sliding_window_contiguous_blocks
        result = [self._null_block] * num_hashes
        run_length = 0

        # Walk backwards so the first qualifying run found is the rightmost.
        for idx in reversed(range(num_hashes)):
            cached_block = self.block_pool.get_cached_block(block_hashes[idx])
            if not cached_block:
                # A miss breaks the current run.
                run_length = 0
                continue
            result[idx] = cached_block
            run_length += 1
            if run_length >= required_run:
                # Drop everything after the run.
                # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
                # when sliding_window_contiguous_blocks=2.
                del result[idx + run_length:]
                return result

        # The scan reached index 0, so `run_length` counts the cached blocks
        # at the start of the list. They are a cache hit even when the run is
        # shorter than `sliding_window_contiguous_blocks`.
        del result[run_length:]
        return result

    def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
                              num_computed_tokens: int) -> list[KVCacheBlock]:
        """
        Replace the blocks that have fallen out of the sliding window (and
        are therefore skipped during attention) with the null block.

        Args:
            blocks: The list of blocks to be updated in place.
            num_computed_tokens: The number of tokens that have been computed.

        Returns:
            The replaced blocks, collected from the highest out-of-window
            index downwards.
        """
        first_useful_token = num_computed_tokens - self.sliding_window + 1
        first_useful_block = first_useful_token // self.block_size
        evicted: list[KVCacheBlock] = []
        # A negative bound yields an empty range, so nothing is removed.
        for idx in reversed(range(first_useful_block)):
            candidate = blocks[idx]
            if candidate == self._null_block:
                # An already-null block means the earlier blocks were nulled
                # by previous calls to this function, so stop here.
                break
            evicted.append(candidate)
            blocks[idx] = self._null_block
        return evicted
# Registry from a KV cache spec class to the SpecializedManager subclass that
# handles it. get_specialized_manager() indexes this with the exact type of
# the spec instance.
spec_manager_map: dict[type[KVCacheSpec], type[SpecializedManager]] = {
    FullAttentionSpec: FullAttentionManager,
    SlidingWindowSpec: SlidingWindowManager,
}
def get_specialized_manager(kv_cache_spec: KVCacheSpec,
                            block_pool: BlockPool) -> SpecializedManager:
    """
    Build the specialized manager registered for the given spec.

    Args:
        kv_cache_spec: The spec whose exact type selects the manager class.
        block_pool: The block pool handed to the manager.

    Returns:
        A new manager instance constructed from `kv_cache_spec` and
        `block_pool`.

    Raises:
        KeyError: If `type(kv_cache_spec)` is not a key of
            `spec_manager_map`.
    """
    return spec_manager_map[type(kv_cache_spec)](kv_cache_spec, block_pool)
...@@ -128,12 +128,18 @@ class EngineCoreOutputs( ...@@ -128,12 +128,18 @@ class EngineCoreOutputs(
#NOTE(Nick): We could consider ways to make this more compact, #NOTE(Nick): We could consider ways to make this more compact,
# e.g. columnwise layout # e.g. columnwise layout
engine_index: int = 0
# [num_reqs] # [num_reqs]
outputs: list[EngineCoreOutput] = [] outputs: list[EngineCoreOutput] = []
scheduler_stats: Optional[SchedulerStats] = None scheduler_stats: Optional[SchedulerStats] = None
timestamp: float = 0.0 timestamp: float = 0.0
utility_output: Optional[UtilityOutput] = None utility_output: Optional[UtilityOutput] = None
finished_requests: Optional[set[str]] = None
# In DP case, used to signal that the engine is paused.
engine_paused: bool = False
def __post_init__(self): def __post_init__(self):
if self.timestamp == 0.0: if self.timestamp == 0.0:
...@@ -147,4 +153,5 @@ class EngineCoreRequestType(enum.Enum): ...@@ -147,4 +153,5 @@ class EngineCoreRequestType(enum.Enum):
""" """
ADD = b'\x00' ADD = b'\x00'
ABORT = b'\x01' ABORT = b'\x01'
UTILITY = b'\x02' START_DP = b'\x02'
UTILITY = b'\x03'
...@@ -14,10 +14,11 @@ from vllm.config import ModelConfig, VllmConfig ...@@ -14,10 +14,11 @@ from vllm.config import ModelConfig, VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
...@@ -48,7 +49,7 @@ class AsyncLLM(EngineClient): ...@@ -48,7 +49,7 @@ class AsyncLLM(EngineClient):
executor_class: type[Executor], executor_class: type[Executor],
log_stats: bool, log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
input_registry: InputRegistry = INPUT_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
use_cached_outputs: bool = False, use_cached_outputs: bool = False,
log_requests: bool = True, log_requests: bool = True,
start_engine_loop: bool = True, start_engine_loop: bool = True,
...@@ -66,11 +67,17 @@ class AsyncLLM(EngineClient): ...@@ -66,11 +67,17 @@ class AsyncLLM(EngineClient):
self.log_requests = log_requests self.log_requests = log_requests
self.log_stats = log_stats self.log_stats = log_stats
self.stat_loggers: list[StatLoggerBase] = []
# Set up stat loggers; independent set for each DP rank.
self.stat_loggers: list[list[StatLoggerBase]] = []
if self.log_stats: if self.log_stats:
if logger.isEnabledFor(logging.INFO): for i in range(vllm_config.parallel_config.data_parallel_size):
self.stat_loggers.append(LoggingStatLogger()) loggers: list[StatLoggerBase] = []
self.stat_loggers.append(PrometheusStatLogger(vllm_config)) if logger.isEnabledFor(logging.INFO):
loggers.append(LoggingStatLogger(engine_index=i))
loggers.append(
PrometheusStatLogger(vllm_config, engine_index=i))
self.stat_loggers.append(loggers)
# Tokenizer (+ ensure liveness if running in another process). # Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs( self.tokenizer = init_tokenizer_from_configs(
...@@ -84,7 +91,7 @@ class AsyncLLM(EngineClient): ...@@ -84,7 +91,7 @@ class AsyncLLM(EngineClient):
self.processor = Processor( self.processor = Processor(
vllm_config=vllm_config, vllm_config=vllm_config,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
input_registry=input_registry, mm_registry=mm_registry,
) )
# OutputProcessor (converts EngineCoreOutputs --> RequestOutput). # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
...@@ -329,6 +336,7 @@ class AsyncLLM(EngineClient): ...@@ -329,6 +336,7 @@ class AsyncLLM(EngineClient):
# TODO(rob): make into a coroutine and launch it in # TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial. # background thread once Prometheus overhead is non-trivial.
self._record_stats( self._record_stats(
engine_index=outputs.engine_index,
scheduler_stats=outputs.scheduler_stats, scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats, iteration_stats=iteration_stats,
) )
...@@ -350,12 +358,13 @@ class AsyncLLM(EngineClient): ...@@ -350,12 +358,13 @@ class AsyncLLM(EngineClient):
self, self,
scheduler_stats: Optional[SchedulerStats], scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats], iteration_stats: Optional[IterationStats],
engine_index: int = 0,
): ):
if not self.log_stats: if not self.log_stats:
return return
assert scheduler_stats is not None assert scheduler_stats is not None
for stat_logger in self.stat_loggers: for stat_logger in self.stat_loggers[engine_index]:
stat_logger.record(scheduler_stats=scheduler_stats, stat_logger.record(scheduler_stats=scheduler_stats,
iteration_stats=iteration_stats) iteration_stats=iteration_stats)
...@@ -393,8 +402,9 @@ class AsyncLLM(EngineClient): ...@@ -393,8 +402,9 @@ class AsyncLLM(EngineClient):
scheduler_outputs=None, scheduler_outputs=None,
model_output=None, model_output=None,
) -> None: ) -> None:
for stat_logger in self.stat_loggers: for loggers in self.stat_loggers:
stat_logger.log() for stat_logger in loggers:
stat_logger.log()
async def check_health(self) -> None: async def check_health(self) -> None:
logger.debug("Called check_health.") logger.debug("Called check_health.")
...@@ -414,8 +424,8 @@ class AsyncLLM(EngineClient): ...@@ -414,8 +424,8 @@ class AsyncLLM(EngineClient):
async def sleep(self, level: int = 1) -> None: async def sleep(self, level: int = 1) -> None:
await self.engine_core.sleep_async(level) await self.engine_core.sleep_async(level)
async def wake_up(self) -> None: async def wake_up(self, tags: Optional[list[str]] = None) -> None:
await self.engine_core.wake_up_async() await self.engine_core.wake_up_async(tags)
async def is_sleeping(self) -> bool: async def is_sleeping(self) -> bool:
return await self.engine_core.is_sleeping_async() return await self.engine_core.is_sleeping_async()
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os
import queue import queue
import signal import signal
import sys
import threading import threading
import time import time
from concurrent.futures import Future from concurrent.futures import Future
from inspect import isclass, signature from inspect import isclass, signature
from multiprocessing.connection import Connection from logging import DEBUG
from typing import Any, Optional from typing import Any, Callable, Optional, TypeVar, Union
import msgspec import msgspec
import psutil import psutil
import zmq import zmq
import zmq.asyncio import zmq.asyncio
from vllm.config import VllmConfig from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.executor.multiproc_worker_utils import _add_prefix
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.transformers_utils.config import ( from vllm.transformers_utils.config import (
...@@ -23,12 +26,14 @@ from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname, ...@@ -23,12 +26,14 @@ from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname,
zmq_socket_ctx) zmq_socket_ctx)
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config, from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
unify_kv_cache_configs) unify_kv_cache_configs)
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput) EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.mm_input_cache import MMInputCacheServer from vllm.v1.engine.mm_input_cache import MMInputCacheServer
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
...@@ -39,6 +44,8 @@ logger = init_logger(__name__) ...@@ -39,6 +44,8 @@ logger = init_logger(__name__)
POLLING_TIMEOUT_S = 2.5 POLLING_TIMEOUT_S = 2.5
_R = TypeVar('_R') # Return type for collective_rpc
class EngineCore: class EngineCore:
"""Inner loop of vLLM's Engine.""" """Inner loop of vLLM's Engine."""
...@@ -60,8 +67,9 @@ class EngineCore: ...@@ -60,8 +67,9 @@ class EngineCore:
self.model_executor = executor_class(vllm_config) self.model_executor = executor_class(vllm_config)
# Setup KV Caches and update CacheConfig after profiling. # Setup KV Caches and update CacheConfig after profiling.
num_gpu_blocks, num_cpu_blocks = self._initialize_kv_caches( num_gpu_blocks, num_cpu_blocks, kv_cache_config = \
vllm_config) self._initialize_kv_caches(vllm_config)
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
...@@ -84,14 +92,16 @@ class EngineCore: ...@@ -84,14 +92,16 @@ class EngineCore:
"compatibility may not be maintained.", "compatibility may not be maintained.",
vllm_config.scheduler_config.scheduler_cls) vllm_config.scheduler_config.scheduler_cls)
self.scheduler = Scheduler( self.scheduler: SchedulerInterface = Scheduler(
scheduler_config=vllm_config.scheduler_config, scheduler_config=vllm_config.scheduler_config,
model_config=vllm_config.model_config, model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config, cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config, lora_config=vllm_config.lora_config,
speculative_config=vllm_config.speculative_config, kv_cache_config=kv_cache_config,
log_stats=self.log_stats,
structured_output_manager=self.structured_output_manager, structured_output_manager=self.structured_output_manager,
include_finished_set=vllm_config.parallel_config.data_parallel_size
> 1,
log_stats=self.log_stats,
) )
# Setup MM Input Mapper. # Setup MM Input Mapper.
...@@ -110,8 +120,8 @@ class EngineCore: ...@@ -110,8 +120,8 @@ class EngineCore:
self.batch_queue_size) self.batch_queue_size)
self.batch_queue = queue.Queue(self.batch_queue_size) self.batch_queue = queue.Queue(self.batch_queue_size)
def _initialize_kv_caches(self, def _initialize_kv_caches(
vllm_config: VllmConfig) -> tuple[int, int]: self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
start = time.time() start = time.time()
# Get all kv cache needed by the model # Get all kv cache needed by the model
...@@ -136,13 +146,14 @@ class EngineCore: ...@@ -136,13 +146,14 @@ class EngineCore:
unify_kv_cache_configs(kv_cache_configs) unify_kv_cache_configs(kv_cache_configs)
# All workers have the same kv_cache_config except layer names, so use # All workers have the same kv_cache_config except layer names, so use
# an arbitrary one to get the number of blocks. # an arbitrary one to initialize the scheduler.
assert all([ assert all([
cfg.num_blocks == kv_cache_configs[0].num_blocks cfg.num_blocks == kv_cache_configs[0].num_blocks
for cfg in kv_cache_configs for cfg in kv_cache_configs
]) ])
num_gpu_blocks = kv_cache_configs[0].num_blocks num_gpu_blocks = kv_cache_configs[0].num_blocks
num_cpu_blocks = 0 num_cpu_blocks = 0
scheduler_kv_cache_config = kv_cache_configs[0]
# Initialize kv cache and warmup the execution # Initialize kv cache and warmup the execution
self.model_executor.initialize_from_config(kv_cache_configs) self.model_executor.initialize_from_config(kv_cache_configs)
...@@ -150,7 +161,7 @@ class EngineCore: ...@@ -150,7 +161,7 @@ class EngineCore:
elapsed = time.time() - start elapsed = time.time() - start
logger.info(("init engine (profile, create kv cache, " logger.info(("init engine (profile, create kv cache, "
"warmup model) took %.2f seconds"), elapsed) "warmup model) took %.2f seconds"), elapsed)
return num_gpu_blocks, num_cpu_blocks return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config
def add_request(self, request: EngineCoreRequest): def add_request(self, request: EngineCoreRequest):
"""Add request to the scheduler.""" """Add request to the scheduler."""
...@@ -253,8 +264,8 @@ class EngineCore: ...@@ -253,8 +264,8 @@ class EngineCore:
def sleep(self, level: int = 1): def sleep(self, level: int = 1):
self.model_executor.sleep(level) self.model_executor.sleep(level)
def wake_up(self): def wake_up(self, tags: Optional[list[str]] = None):
self.model_executor.wake_up() self.model_executor.wake_up(tags)
def is_sleeping(self) -> bool: def is_sleeping(self) -> bool:
return self.model_executor.is_sleeping return self.model_executor.is_sleeping
...@@ -274,6 +285,24 @@ class EngineCore: ...@@ -274,6 +285,24 @@ class EngineCore:
def pin_lora(self, lora_id: int) -> bool: def pin_lora(self, lora_id: int) -> bool:
return self.model_executor.pin_lora(lora_id) return self.model_executor.pin_lora(lora_id)
def save_sharded_state(
self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
self.model_executor.save_sharded_state(path=path,
pattern=pattern,
max_size=max_size)
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return self.model_executor.collective_rpc(method, timeout, args,
kwargs)
class EngineCoreProc(EngineCore): class EngineCoreProc(EngineCore):
"""ZMQ-wrapper for running EngineCore in background process.""" """ZMQ-wrapper for running EngineCore in background process."""
...@@ -282,10 +311,10 @@ class EngineCoreProc(EngineCore): ...@@ -282,10 +311,10 @@ class EngineCoreProc(EngineCore):
self, self,
input_path: str, input_path: str,
output_path: str, output_path: str,
ready_pipe: Connection,
vllm_config: VllmConfig, vllm_config: VllmConfig,
executor_class: type[Executor], executor_class: type[Executor],
log_stats: bool, log_stats: bool,
engine_index: int = 0,
): ):
super().__init__(vllm_config, executor_class, log_stats) super().__init__(vllm_config, executor_class, log_stats)
...@@ -301,14 +330,20 @@ class EngineCoreProc(EngineCore): ...@@ -301,14 +330,20 @@ class EngineCoreProc(EngineCore):
args=(input_path, ), args=(input_path, ),
daemon=True).start() daemon=True).start()
threading.Thread(target=self.process_output_socket, threading.Thread(target=self.process_output_socket,
args=(output_path, ), args=(output_path, engine_index),
daemon=True).start() daemon=True).start()
# Send Readiness signal to EngineClient. self.global_unfinished_reqs = False
ready_pipe.send({"status": "READY"})
self.step_fn = (self.step if self.batch_queue is None else
self.step_with_batch_queue)
@staticmethod @staticmethod
def run_engine_core(*args, **kwargs): def run_engine_core(*args,
dp_rank: int = 0,
local_dp_rank: int = 0,
ready_pipe,
**kwargs):
"""Launch EngineCore busy loop in background process.""" """Launch EngineCore busy loop in background process."""
# Signal handler used for graceful termination. # Signal handler used for graceful termination.
...@@ -330,9 +365,21 @@ class EngineCoreProc(EngineCore): ...@@ -330,9 +365,21 @@ class EngineCoreProc(EngineCore):
signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGINT, signal_handler)
parent_process = psutil.Process().parent() parent_process = psutil.Process().parent()
engine_core = None engine_core: Optional[EngineCoreProc] = None
try: try:
engine_core = EngineCoreProc(*args, **kwargs) parallel_config: ParallelConfig = kwargs[
"vllm_config"].parallel_config
if parallel_config.data_parallel_size > 1:
# Set data parallel rank for this engine process.
parallel_config.data_parallel_rank = dp_rank
parallel_config.data_parallel_rank_local = local_dp_rank
engine_core = DPEngineCoreProc(*args, **kwargs)
else:
engine_core = EngineCoreProc(*args, **kwargs)
# Send Readiness signal to EngineClient.
ready_pipe.send({"status": "READY"})
engine_core.run_busy_loop() engine_core.run_busy_loop()
except SystemExit: except SystemExit:
...@@ -350,28 +397,44 @@ class EngineCoreProc(EngineCore): ...@@ -350,28 +397,44 @@ class EngineCoreProc(EngineCore):
def run_busy_loop(self): def run_busy_loop(self):
"""Core busy loop of the EngineCore.""" """Core busy loop of the EngineCore."""
step_fn = (self.step
if self.batch_queue is None else self.step_with_batch_queue)
# Loop until process is sent a SIGINT or SIGTERM # Loop until process is sent a SIGINT or SIGTERM
while True: while True:
# 1) Poll the input queue until there is work to do. # 1) Poll the input queue until there is work to do.
while not self.scheduler.has_requests(): self._process_input_queue()
logger.debug("EngineCore busy loop waiting.") # 2) Step the engine core and return the outputs.
req = self.input_queue.get() self._process_engine_step()
self._handle_client_request(*req)
def _process_input_queue(self):
# 2) Handle any new client requests. """Exits when an engine step needs to be performed."""
while not self.input_queue.empty():
req = self.input_queue.get_nowait() waited = False
self._handle_client_request(*req) while not self.global_unfinished_reqs and not (
self.scheduler.has_requests()):
# 3) Step the engine core. if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
outputs = step_fn() logger.debug("EngineCore waiting for work.")
waited = True
# 4) Put EngineCoreOutputs into the output queue. req = self.input_queue.get()
if outputs is not None: self._handle_client_request(*req)
self.output_queue.put_nowait(outputs)
if waited:
logger.debug(
"EngineCore loop active - local unfinished: %s, finished: %s.",
self.scheduler.has_unfinished_requests(),
self.scheduler.has_finished_requests())
# Handle any more client requests.
while not self.input_queue.empty():
req = self.input_queue.get_nowait()
self._handle_client_request(*req)
def _process_engine_step(self):
"""Called only when there are unfinished local requests."""
# Step the engine core.
outputs = self.step_fn()
# Put EngineCoreOutputs into the output queue.
if outputs is not None:
self.output_queue.put_nowait(outputs)
def _handle_client_request(self, request_type: EngineCoreRequestType, def _handle_client_request(self, request_type: EngineCoreRequestType,
request: Any) -> None: request: Any) -> None:
...@@ -381,6 +444,10 @@ class EngineCoreProc(EngineCore): ...@@ -381,6 +444,10 @@ class EngineCoreProc(EngineCore):
self.add_request(request) self.add_request(request)
elif request_type == EngineCoreRequestType.ABORT: elif request_type == EngineCoreRequestType.ABORT:
self.abort_requests(request) self.abort_requests(request)
elif request_type == EngineCoreRequestType.START_DP:
if not self.global_unfinished_reqs:
logger.debug("EngineCore starting idle loop.")
self.global_unfinished_reqs = True
elif request_type == EngineCoreRequestType.UTILITY: elif request_type == EngineCoreRequestType.UTILITY:
call_id, method_name, args = request call_id, method_name, args = request
output = UtilityOutput(call_id) output = UtilityOutput(call_id)
...@@ -431,7 +498,7 @@ class EngineCoreProc(EngineCore): ...@@ -431,7 +498,7 @@ class EngineCoreProc(EngineCore):
# Push to input queue for core busy loop. # Push to input queue for core busy loop.
self.input_queue.put_nowait((request_type, request)) self.input_queue.put_nowait((request_type, request))
def process_output_socket(self, output_path: str): def process_output_socket(self, output_path: str, engine_index: int):
"""Output socket IO thread.""" """Output socket IO thread."""
# Msgpack serialization encoding. # Msgpack serialization encoding.
...@@ -442,5 +509,114 @@ class EngineCoreProc(EngineCore): ...@@ -442,5 +509,114 @@ class EngineCoreProc(EngineCore):
with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket: with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket:
while True: while True:
outputs = self.output_queue.get() outputs = self.output_queue.get()
outputs.engine_index = engine_index
encoder.encode_into(outputs, buffer) encoder.encode_into(outputs, buffer)
socket.send_multipart((buffer, ), copy=False) socket.send(buffer, copy=False)
ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(engine_paused=True)
class DPEngineCoreProc(EngineCoreProc):
"""ZMQ-wrapper for running EngineCore in background process
in a data parallel context."""
def __init__(
self,
input_path: str,
output_path: str,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
):
# Add process-specific prefix to stdout and stderr before
# we initialize the engine.
from multiprocessing import current_process
process_name = current_process().name
pid = os.getpid()
_add_prefix(sys.stdout, process_name, pid)
_add_prefix(sys.stderr, process_name, pid)
dp_size = vllm_config.parallel_config.data_parallel_size
dp_rank = vllm_config.parallel_config.data_parallel_rank
local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
assert dp_size > 1
assert 0 <= local_dp_rank <= dp_rank < dp_size
from vllm.platforms import current_platform
if current_platform.is_cuda_alike():
from vllm.platforms.cuda import device_id_to_physical_device_id
tp_size = vllm_config.parallel_config.tensor_parallel_size
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
str(device_id_to_physical_device_id(i))
for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) *
tp_size))
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
# Initialize the engine after setting up environment.
super().__init__(input_path, output_path, vllm_config, executor_class,
log_stats, dp_rank)
# Counts forward-passes of the model so that we can synchronize
# finished with DP peers every N steps.
self.counter = 0
def shutdown(self):
super().shutdown()
if dp_group := getattr(self, "dp_group", None):
stateless_destroy_torch_distributed_process_group(dp_group)
def run_busy_loop(self):
"""Core busy loop of the EngineCore for data parallel case."""
# Loop until process is sent a SIGINT or SIGTERM
while True:
# 1) Poll the input queue until there is work to do.
self._process_input_queue()
local_unfinished_reqs = self.scheduler.has_unfinished_requests()
if local_unfinished_reqs:
# 2) Step the engine core.
self._process_engine_step()
# Check if we have now finished all requests.
local_unfinished_reqs = (
self.scheduler.has_unfinished_requests())
else:
if self.scheduler.has_finished_requests():
# There are no unfinished requests, but there are some
# finished requests remaining to be removed from the
# batch state. This engine step won't perform a forward
# pass but will flush the finished requests to ensure
# up-to-date state is returned in the engine outputs.
self._process_engine_step()
if not self.global_unfinished_reqs:
# All engines are idle.
continue
# There must be unfinished requests in DP peers, run a
# dummy forward pass.
self.execute_dummy_batch()
# 3) All-reduce operation to determine global unfinished reqs.
self.global_unfinished_reqs = self._has_global_unfinished_reqs(
local_unfinished_reqs)
if not self.global_unfinished_reqs:
# Notify client that we are pausing the loop.
self.output_queue.put_nowait(ENGINE_PAUSED_OUTPUTS)
def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
# Optimization - only perform finish-sync all-reduce every 16 steps.
self.counter += 1
if self.counter != 16:
return True
self.counter = 0
return ParallelConfig.has_unfinished_dp(self.dp_group,
local_unfinished)
...@@ -8,10 +8,11 @@ import threading ...@@ -8,10 +8,11 @@ import threading
import uuid import uuid
import weakref import weakref
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Awaitable, Sequence
from concurrent.futures import Future from concurrent.futures import Future
from dataclasses import dataclass from dataclasses import dataclass, field
from threading import Thread from threading import Thread
from typing import Any, Optional, Union from typing import Any, Callable, Optional, TypeVar, Union
import zmq import zmq
import zmq.asyncio import zmq.asyncio
...@@ -32,6 +33,8 @@ logger = init_logger(__name__) ...@@ -32,6 +33,8 @@ logger = init_logger(__name__)
AnyFuture = Union[asyncio.Future[Any], Future[Any]] AnyFuture = Union[asyncio.Future[Any], Future[Any]]
_R = TypeVar('_R') # Return type for collective_rpc
class EngineCoreClient(ABC): class EngineCoreClient(ABC):
""" """
...@@ -60,6 +63,9 @@ class EngineCoreClient(ABC): ...@@ -60,6 +63,9 @@ class EngineCoreClient(ABC):
"is not currently supported.") "is not currently supported.")
if multiprocess_mode and asyncio_mode: if multiprocess_mode and asyncio_mode:
if vllm_config.parallel_config.data_parallel_size > 1:
return DPAsyncMPClient(vllm_config, executor_class, log_stats)
return AsyncMPClient(vllm_config, executor_class, log_stats) return AsyncMPClient(vllm_config, executor_class, log_stats)
if multiprocess_mode and not asyncio_mode: if multiprocess_mode and not asyncio_mode:
...@@ -86,7 +92,7 @@ class EngineCoreClient(ABC): ...@@ -86,7 +92,7 @@ class EngineCoreClient(ABC):
def sleep(self, level: int = 1) -> None: def sleep(self, level: int = 1) -> None:
raise NotImplementedError raise NotImplementedError
def wake_up(self) -> None: def wake_up(self, tags: Optional[list[str]] = None) -> None:
raise NotImplementedError raise NotImplementedError
def is_sleeping(self) -> bool: def is_sleeping(self) -> bool:
...@@ -113,6 +119,19 @@ class EngineCoreClient(ABC): ...@@ -113,6 +119,19 @@ class EngineCoreClient(ABC):
def pin_lora(self, lora_id: int) -> bool: def pin_lora(self, lora_id: int) -> bool:
raise NotImplementedError raise NotImplementedError
def save_sharded_state(self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None) -> None:
raise NotImplementedError
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
raise NotImplementedError
async def get_output_async(self) -> EngineCoreOutputs: async def get_output_async(self) -> EngineCoreOutputs:
raise NotImplementedError raise NotImplementedError
...@@ -128,7 +147,7 @@ class EngineCoreClient(ABC): ...@@ -128,7 +147,7 @@ class EngineCoreClient(ABC):
async def sleep_async(self, level: int = 1) -> None: async def sleep_async(self, level: int = 1) -> None:
raise NotImplementedError raise NotImplementedError
async def wake_up_async(self) -> None: async def wake_up_async(self, tags: Optional[list[str]] = None) -> None:
raise NotImplementedError raise NotImplementedError
async def is_sleeping_async(self) -> bool: async def is_sleeping_async(self) -> bool:
...@@ -149,6 +168,20 @@ class EngineCoreClient(ABC): ...@@ -149,6 +168,20 @@ class EngineCoreClient(ABC):
async def pin_lora_async(self, lora_id: int) -> bool: async def pin_lora_async(self, lora_id: int) -> bool:
raise NotImplementedError raise NotImplementedError
async def save_sharded_state_async(self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None) -> None:
raise NotImplementedError
async def collective_rpc_async(
self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
raise NotImplementedError
class InprocClient(EngineCoreClient): class InprocClient(EngineCoreClient):
""" """
...@@ -185,8 +218,8 @@ class InprocClient(EngineCoreClient): ...@@ -185,8 +218,8 @@ class InprocClient(EngineCoreClient):
def sleep(self, level: int = 1) -> None: def sleep(self, level: int = 1) -> None:
self.engine_core.sleep(level) self.engine_core.sleep(level)
def wake_up(self) -> None: def wake_up(self, tags: Optional[list[str]] = None) -> None:
self.engine_core.wake_up() self.engine_core.wake_up(tags)
def is_sleeping(self) -> bool: def is_sleeping(self) -> bool:
return self.engine_core.is_sleeping() return self.engine_core.is_sleeping()
...@@ -206,29 +239,88 @@ class InprocClient(EngineCoreClient): ...@@ -206,29 +239,88 @@ class InprocClient(EngineCoreClient):
def pin_lora(self, lora_id: int) -> bool: def pin_lora(self, lora_id: int) -> bool:
return self.engine_core.pin_lora(lora_id) return self.engine_core.pin_lora(lora_id)
def save_sharded_state(self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None) -> None:
self.engine_core.save_sharded_state(path, pattern, max_size)
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return self.engine_core.collective_rpc(method, timeout, args, kwargs)
class CoreEngine:
"""One per data parallel rank."""
def __init__(
self,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
ctx: Union[zmq.Context, zmq.asyncio.Context],
output_path: str,
index: int = 0,
local_dp_rank: int = 0,
):
# Paths and sockets for IPC.
input_path = get_open_zmq_ipc_path()
self.input_socket = make_zmq_socket(ctx, input_path,
zmq.constants.PUSH)
try:
# Start EngineCore in background process.
self.proc_handle = BackgroundProcHandle(
input_path=input_path,
output_path=output_path,
process_name=f"EngineCore_{index}",
target_fn=EngineCoreProc.run_engine_core,
process_kwargs={
"vllm_config": vllm_config,
"dp_rank": index,
"local_dp_rank": local_dp_rank,
"executor_class": executor_class,
"log_stats": log_stats,
})
self.num_reqs_in_flight = 0
finally:
if not hasattr(self, "num_reqs_in_flight"):
# Ensure socket is closed if process fails to start.
self.close()
def send_multipart(self, msg_parts: Sequence):
return self.input_socket.send_multipart(msg_parts, copy=False)
def close(self):
if proc_handle := getattr(self, "proc_handle", None):
proc_handle.shutdown()
if socket := getattr(self, "input_socket", None):
socket.close(linger=0)
@dataclass @dataclass
class BackgroundResources: class BackgroundResources:
"""Used as a finalizer for clean shutdown, avoiding """Used as a finalizer for clean shutdown, avoiding
circular reference back to the client object.""" circular reference back to the client object."""
ctx: zmq.Context ctx: Union[zmq.Context]
core_engines: list[CoreEngine] = field(default_factory=list)
output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
proc_handle: Optional[BackgroundProcHandle] = None
shutdown_path: Optional[str] = None shutdown_path: Optional[str] = None
def __call__(self): def __call__(self):
"""Clean up background resources.""" """Clean up background resources."""
if self.proc_handle is not None: for core_engine in self.core_engines:
self.proc_handle.shutdown() core_engine.close()
# ZMQ context termination can hang if the sockets # ZMQ context termination can hang if the sockets
# aren't explicitly closed first. # aren't explicitly closed first.
if self.output_socket is not None: if self.output_socket is not None:
self.output_socket.close(linger=0) self.output_socket.close(linger=0)
if self.input_socket is not None:
self.input_socket.close(linger=0)
if self.shutdown_path is not None: if self.shutdown_path is not None:
# We must ensure that the sync output socket is # We must ensure that the sync output socket is
# closed cleanly in its own thread. # closed cleanly in its own thread.
...@@ -284,7 +376,7 @@ class MPClient(EngineCoreClient): ...@@ -284,7 +376,7 @@ class MPClient(EngineCoreClient):
self.decoder = MsgpackDecoder(EngineCoreOutputs) self.decoder = MsgpackDecoder(EngineCoreOutputs)
# ZMQ setup. # ZMQ setup.
sync_ctx = zmq.Context() sync_ctx = zmq.Context(io_threads=2)
self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx
# This will ensure resources created so far are closed # This will ensure resources created so far are closed
...@@ -293,28 +385,38 @@ class MPClient(EngineCoreClient): ...@@ -293,28 +385,38 @@ class MPClient(EngineCoreClient):
self.resources = BackgroundResources(ctx=sync_ctx) self.resources = BackgroundResources(ctx=sync_ctx)
self._finalizer = weakref.finalize(self, self.resources) self._finalizer = weakref.finalize(self, self.resources)
# Paths for IPC. # Paths and sockets for IPC.
self.output_path = get_open_zmq_ipc_path() self.output_path = get_open_zmq_ipc_path()
input_path = get_open_zmq_ipc_path()
# Start EngineCore in background process. new_core_engine = lambda index, local_dp_rank=None: CoreEngine(
self.resources.proc_handle = BackgroundProcHandle( vllm_config, executor_class, log_stats, self.ctx, self.output_path,
input_path=input_path, index, local_dp_rank)
output_path=self.output_path,
process_name="EngineCore", # Start engine core process(es).
target_fn=EngineCoreProc.run_engine_core, self._init_core_engines(vllm_config, new_core_engine,
process_kwargs={ self.resources.core_engines)
"vllm_config": vllm_config,
"executor_class": executor_class, # Wait for engine core process(es) to start.
"log_stats": log_stats, for engine in self.resources.core_engines:
}) engine.proc_handle.wait_for_startup()
# Create input socket.
self.resources.input_socket = make_zmq_socket(self.ctx, input_path,
zmq.constants.PUSH)
self.input_socket = self.resources.input_socket
self.utility_results: dict[int, AnyFuture] = {} self.utility_results: dict[int, AnyFuture] = {}
def _init_core_engines(
    self,
    vllm_config: VllmConfig,
    new_core_engine: Callable[[int, Optional[int]], CoreEngine],
    core_engines: list[CoreEngine],
) -> None:
    """Create the one core engine used in the default (single-engine) case.

    Args:
        vllm_config: Provides ``parallel_config.data_parallel_rank`` and
            ``parallel_config.data_parallel_rank_local``.
        new_core_engine: Factory called as ``(index, local_dp_rank)``.
        core_engines: List that the newly created engine is appended to.
    """
    # Default case - single core engine.
    parallel_config = vllm_config.parallel_config
    rank = parallel_config.data_parallel_rank
    local_rank = parallel_config.data_parallel_rank_local
    if local_rank is None:
        # No local rank set: reuse the global data-parallel rank.
        local_rank = rank

    engine = new_core_engine(rank, local_rank)
    core_engines.append(engine)
    self.core_engine = engine
def shutdown(self): def shutdown(self):
self._finalizer() self._finalizer()
...@@ -356,9 +458,9 @@ class SyncMPClient(MPClient): ...@@ -356,9 +458,9 @@ class SyncMPClient(MPClient):
def process_outputs_socket(): def process_outputs_socket():
shutdown_socket = ctx.socket(zmq.PAIR) shutdown_socket = ctx.socket(zmq.PAIR)
shutdown_socket.bind(shutdown_path)
out_socket = make_zmq_socket(ctx, output_path, zmq.constants.PULL) out_socket = make_zmq_socket(ctx, output_path, zmq.constants.PULL)
try: try:
shutdown_socket.bind(shutdown_path)
poller = zmq.Poller() poller = zmq.Poller()
poller.register(shutdown_socket) poller.register(shutdown_socket)
poller.register(out_socket) poller.register(out_socket)
...@@ -370,7 +472,7 @@ class SyncMPClient(MPClient): ...@@ -370,7 +472,7 @@ class SyncMPClient(MPClient):
# shutdown signal, exit thread. # shutdown signal, exit thread.
break break
(frame, ) = out_socket.recv_multipart(copy=False) frame = out_socket.recv(copy=False)
outputs = decoder.decode(frame.buffer) outputs = decoder.decode(frame.buffer)
if outputs.utility_output: if outputs.utility_output:
_process_utility_output(outputs.utility_output, _process_utility_output(outputs.utility_output,
...@@ -391,18 +493,15 @@ class SyncMPClient(MPClient): ...@@ -391,18 +493,15 @@ class SyncMPClient(MPClient):
def get_output(self) -> EngineCoreOutputs: def get_output(self) -> EngineCoreOutputs:
return self.outputs_queue.get() return self.outputs_queue.get()
def _send_input(self, request_type: EngineCoreRequestType, def _send_input(self, request_type: EngineCoreRequestType, request: Any):
request: Any) -> None:
# (RequestType, SerializedRequest) # (RequestType, SerializedRequest)
msg = (request_type.value, self.encoder.encode(request)) msg = (request_type.value, self.encoder.encode(request))
self.input_socket.send_multipart(msg, copy=False) self.core_engine.send_multipart(msg)
def _call_utility(self, method: str, *args) -> Any: def call_utility(self, method: str, *args) -> Any:
call_id = uuid.uuid1().int >> 64 call_id = uuid.uuid1().int >> 64
future: Future[Any] = Future() future: Future[Any] = Future()
self.utility_results[call_id] = future self.utility_results[call_id] = future
self._send_input(EngineCoreRequestType.UTILITY, self._send_input(EngineCoreRequestType.UTILITY,
(call_id, method, args)) (call_id, method, args))
...@@ -419,34 +518,48 @@ class SyncMPClient(MPClient): ...@@ -419,34 +518,48 @@ class SyncMPClient(MPClient):
self._send_input(EngineCoreRequestType.ABORT, request_ids) self._send_input(EngineCoreRequestType.ABORT, request_ids)
def profile(self, is_start: bool = True) -> None: def profile(self, is_start: bool = True) -> None:
self._call_utility("profile", is_start) self.call_utility("profile", is_start)
def reset_prefix_cache(self) -> None: def reset_prefix_cache(self) -> None:
self._call_utility("reset_prefix_cache") self.call_utility("reset_prefix_cache")
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
return self._call_utility("add_lora", lora_request) return self.call_utility("add_lora", lora_request)
def remove_lora(self, lora_id: int) -> bool: def remove_lora(self, lora_id: int) -> bool:
return self._call_utility("remove_lora", lora_id) return self.call_utility("remove_lora", lora_id)
def list_loras(self) -> set[int]: def list_loras(self) -> set[int]:
return self._call_utility("list_loras") return self.call_utility("list_loras")
def pin_lora(self, lora_id: int) -> bool: def pin_lora(self, lora_id: int) -> bool:
return self._call_utility("pin_lora", lora_id) return self.call_utility("pin_lora", lora_id)
def sleep(self, level: int = 1) -> None: def sleep(self, level: int = 1) -> None:
self._call_utility("sleep", level) self.call_utility("sleep", level)
def wake_up(self) -> None: def wake_up(self, tags: Optional[list[str]] = None) -> None:
self._call_utility("wake_up") self.call_utility("wake_up", tags)
def is_sleeping(self) -> bool: def is_sleeping(self) -> bool:
return self._call_utility("is_sleeping") return self.call_utility("is_sleeping")
def execute_dummy_batch(self) -> None: def execute_dummy_batch(self) -> None:
self._call_utility("execute_dummy_batch") self.call_utility("execute_dummy_batch")
def collective_rpc(self,
                   method: Union[str, Callable[..., _R]],
                   timeout: Optional[float] = None,
                   args: tuple = (),
                   kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
    """Forward a ``collective_rpc`` utility call and return its result.

    ``method``, ``timeout``, ``args`` and ``kwargs`` are handed to
    ``call_utility`` unchanged, in that positional order.
    """
    rpc_args = (method, timeout, args, kwargs)
    return self.call_utility("collective_rpc", *rpc_args)
def save_sharded_state(self,
                       path: str,
                       pattern: Optional[str] = None,
                       max_size: Optional[int] = None) -> None:
    """Issue the ``save_sharded_state`` utility call.

    The three arguments are passed through positionally; whatever the
    utility call returns is discarded.
    """
    utility_args = (path, pattern, max_size)
    self.call_utility("save_sharded_state", *utility_args)
class AsyncMPClient(MPClient): class AsyncMPClient(MPClient):
...@@ -464,13 +577,21 @@ class AsyncMPClient(MPClient): ...@@ -464,13 +577,21 @@ class AsyncMPClient(MPClient):
self.outputs_queue: Optional[asyncio.Queue[EngineCoreOutputs]] = None self.outputs_queue: Optional[asyncio.Queue[EngineCoreOutputs]] = None
self.queue_task: Optional[asyncio.Task] = None self.queue_task: Optional[asyncio.Task] = None
async def _start_output_queue_task(self): self.outputs_handler: Optional[Callable[
[AsyncMPClient, EngineCoreOutputs], Awaitable[None]]] = None
def _ensure_output_queue_task(self):
if self.outputs_queue is not None:
return
# Perform IO in separate task to parallelize as much as possible. # Perform IO in separate task to parallelize as much as possible.
# Avoid task having direct reference back to the client. # Avoid task having direct reference back to the client.
self.outputs_queue = asyncio.Queue() self.outputs_queue = asyncio.Queue()
decoder = self.decoder decoder = self.decoder
utility_results = self.utility_results utility_results = self.utility_results
outputs_queue = self.outputs_queue outputs_queue = self.outputs_queue
output_handler = self.outputs_handler
_self_ref = weakref.ref(self) if output_handler else None
output_path = self.output_path output_path = self.output_path
output_socket = make_zmq_socket(self.ctx, output_path, output_socket = make_zmq_socket(self.ctx, output_path,
zmq.constants.PULL) zmq.constants.PULL)
...@@ -483,34 +604,52 @@ class AsyncMPClient(MPClient): ...@@ -483,34 +604,52 @@ class AsyncMPClient(MPClient):
if outputs.utility_output: if outputs.utility_output:
_process_utility_output(outputs.utility_output, _process_utility_output(outputs.utility_output,
utility_results) utility_results)
else: continue
if output_handler is not None:
assert _self_ref is not None
_self = _self_ref()
if not _self:
# Client has been garbage collected, abort.
return
await output_handler(_self, outputs)
if outputs.outputs or outputs.scheduler_stats:
outputs_queue.put_nowait(outputs) outputs_queue.put_nowait(outputs)
self.queue_task = asyncio.create_task(process_outputs_socket(), self.queue_task = asyncio.create_task(process_outputs_socket(),
name="EngineCoreOutputQueueTask") name="EngineCoreOutputQueueTask")
async def get_output_async(self) -> EngineCoreOutputs: async def get_output_async(self) -> EngineCoreOutputs:
if self.outputs_queue is None: self._ensure_output_queue_task()
await self._start_output_queue_task() assert self.outputs_queue is not None
assert self.outputs_queue is not None
return await self.outputs_queue.get() return await self.outputs_queue.get()
async def _send_input(self, request_type: EngineCoreRequestType, async def _send_input(self, request_type: EngineCoreRequestType,
request: Any) -> None: request: Any) -> None:
await self.core_engine.send_multipart(
(request_type.value, self.encoder.encode(request)))
msg = (request_type.value, self.encoder.encode(request)) self._ensure_output_queue_task()
await self.input_socket.send_multipart(msg, copy=False)
if self.outputs_queue is None: async def call_utility_async(self, method: str, *args) -> Any:
await self._start_output_queue_task() return await self._call_utility_async(method,
*args,
engine=self.core_engine)
async def _call_utility_async(self, method: str, *args) -> Any: async def _call_utility_async(
self,
method: str,
*args,
engine: CoreEngine,
) -> Any:
call_id = uuid.uuid1().int >> 64 call_id = uuid.uuid1().int >> 64
future = asyncio.get_running_loop().create_future() future = asyncio.get_running_loop().create_future()
self.utility_results[call_id] = future self.utility_results[call_id] = future
await self._send_input(EngineCoreRequestType.UTILITY, message = (EngineCoreRequestType.UTILITY.value,
(call_id, method, args)) self.encoder.encode((call_id, method, args)))
await engine.send_multipart(message)
self._ensure_output_queue_task()
return await future return await future
async def add_request_async(self, request: EngineCoreRequest) -> None: async def add_request_async(self, request: EngineCoreRequest) -> None:
...@@ -524,31 +663,162 @@ class AsyncMPClient(MPClient): ...@@ -524,31 +663,162 @@ class AsyncMPClient(MPClient):
await self._send_input(EngineCoreRequestType.ABORT, request_ids) await self._send_input(EngineCoreRequestType.ABORT, request_ids)
async def profile_async(self, is_start: bool = True) -> None: async def profile_async(self, is_start: bool = True) -> None:
await self._call_utility_async("profile", is_start) await self.call_utility_async("profile", is_start)
async def reset_prefix_cache_async(self) -> None: async def reset_prefix_cache_async(self) -> None:
await self._call_utility_async("reset_prefix_cache") await self.call_utility_async("reset_prefix_cache")
async def sleep_async(self, level: int = 1) -> None: async def sleep_async(self, level: int = 1) -> None:
await self._call_utility_async("sleep", level) await self.call_utility_async("sleep", level)
async def wake_up_async(self) -> None: async def wake_up_async(self, tags: Optional[list[str]] = None) -> None:
await self._call_utility_async("wake_up") await self.call_utility_async("wake_up", tags)
async def is_sleeping_async(self) -> bool: async def is_sleeping_async(self) -> bool:
return await self._call_utility_async("is_sleeping") return await self.call_utility_async("is_sleeping")
async def execute_dummy_batch_async(self) -> None: async def execute_dummy_batch_async(self) -> None:
await self._call_utility_async("execute_dummy_batch") await self.call_utility_async("execute_dummy_batch")
async def add_lora_async(self, lora_request: LoRARequest) -> bool: async def add_lora_async(self, lora_request: LoRARequest) -> bool:
return await self._call_utility_async("add_lora", lora_request) return await self.call_utility_async("add_lora", lora_request)
async def remove_lora_async(self, lora_id: int) -> bool: async def remove_lora_async(self, lora_id: int) -> bool:
return await self._call_utility_async("remove_lora", lora_id) return await self.call_utility_async("remove_lora", lora_id)
async def list_loras_async(self) -> set[int]: async def list_loras_async(self) -> set[int]:
return await self._call_utility_async("list_loras") return await self.call_utility_async("list_loras")
async def pin_lora_async(self, lora_id: int) -> bool: async def pin_lora_async(self, lora_id: int) -> bool:
return await self._call_utility_async("pin_lora", lora_id) return await self.call_utility_async("pin_lora", lora_id)
async def save_sharded_state_async(self,
                                   path: str,
                                   pattern: Optional[str] = None,
                                   max_size: Optional[int] = None) -> None:
    """Await the ``save_sharded_state`` utility call.

    The three arguments are passed through positionally; whatever the
    utility call returns is discarded.
    """
    utility_args = (path, pattern, max_size)
    await self.call_utility_async("save_sharded_state", *utility_args)
async def collective_rpc_async(
        self,
        method: Union[str, Callable[..., _R]],
        timeout: Optional[float] = None,
        args: tuple = (),
        kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
    """Await a ``collective_rpc`` utility call and return its result.

    ``method``, ``timeout``, ``args`` and ``kwargs`` are handed to
    ``call_utility_async`` unchanged, in that positional order.
    """
    rpc_args = (method, timeout, args, kwargs)
    result = await self.call_utility_async("collective_rpc", *rpc_args)
    return result
class DPAsyncMPClient(AsyncMPClient):
    """Asyncio-compatible client for multi-proc, multi-engine (data parallel)
    EngineCore."""

    def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
                 log_stats: bool):
        """Set up the base client, then the data-parallel bookkeeping."""
        super().__init__(vllm_config, executor_class, log_stats)

        assert len(self.core_engines) > 1

        # Control message used for triggering dp idle mode loop.
        start_dp_type = EngineCoreRequestType.START_DP.value
        self.start_dp_msg = (start_dp_type, self.encoder.encode(None))

        # Counter maintained by add_request_async / process_engine_outputs.
        self.num_engines_running = 0
        # Maps request id -> the engine the request was sent to.
        self.reqs_in_flight: dict[str, CoreEngine] = {}

        self.outputs_handler = DPAsyncMPClient.process_engine_outputs  # type: ignore[assignment]

    def _init_core_engines(
        self,
        vllm_config: VllmConfig,
        new_core_engine: Callable[[int, Optional[int]], CoreEngine],
        core_engines: list[CoreEngine],
    ) -> None:
        """Create one core engine per data parallel rank."""
        # Launch a core engine for each data parallel rank.
        num_ranks = vllm_config.parallel_config.data_parallel_size
        for rank in range(num_ranks):
            # Multi-node not yet supported so local_dp_rank == dp_rank.
            engine = new_core_engine(rank, rank)
            core_engines.append(engine)

        self.core_engines = core_engines

    async def call_utility_async(self, method: str, *args) -> Any:
        """Run a utility call on every engine concurrently."""
        pending = [
            self._call_utility_async(method, *args, engine=engine)
            for engine in self.core_engines
        ]
        results = await asyncio.gather(*pending)
        # Only the result from the first engine is returned.
        return results[0]

    async def add_request_async(self, request: EngineCoreRequest) -> None:
        """Send a request to the engine with the fewest requests in flight."""
        # NOTE: text prompt is not needed in the core engine as it has been
        # tokenized.
        request.prompt = None

        add_msg = (EngineCoreRequestType.ADD.value,
                   self.encoder.encode(request))

        target = self.get_core_engine_for_request()
        self.reqs_in_flight[request.request_id] = target
        target.num_reqs_in_flight += 1

        engines = self.core_engines
        if self.num_engines_running < len(engines):
            # Send request to chosen engine and dp start loop
            # control message to all other engines.
            self.num_engines_running += len(engines)
            await asyncio.gather(*[
                engine.send_multipart(
                    add_msg if engine is target else self.start_dp_msg)
                for engine in engines
            ])
        else:
            await target.send_multipart(add_msg)

        self._ensure_output_queue_task()

    def get_core_engine_for_request(self) -> CoreEngine:
        """Return the engine with the smallest ``num_reqs_in_flight``."""
        return min(self.core_engines,
                   key=lambda engine: engine.num_reqs_in_flight)

    @staticmethod
    async def process_engine_outputs(self: "DPAsyncMPClient",
                                     outputs: EngineCoreOutputs):
        """Update in-flight accounting from one batch of engine outputs.

        Declared static but takes the client explicitly as ``self``; it is
        installed on the instance as ``outputs_handler``.
        """
        if self.reqs_in_flight:
            for req_id in outputs.finished_requests or ():
                if engine := self.reqs_in_flight.pop(req_id, None):
                    engine.num_reqs_in_flight -= 1

        if not outputs.engine_paused:
            return

        assert self.num_engines_running >= 1
        self.num_engines_running -= 1
        if self.num_engines_running or not self.reqs_in_flight:
            return

        # If there are requests in flight here, they must have
        # been sent after the engines paused. We must make
        # sure to start the other engines:
        self.num_engines_running = len(self.core_engines)
        start_sends = [
            engine.send_multipart(self.start_dp_msg)
            for engine in self.core_engines
            if not engine.num_reqs_in_flight
        ]
        if start_sends:
            await asyncio.gather(*start_sends)

    async def abort_requests_async(self, request_ids: list[str]) -> None:
        """Send aborts to the engines that own the given requests.

        Ids with no entry in ``reqs_in_flight`` are skipped.
        """
        if not request_ids:
            return

        if len(request_ids) == 1:
            # Fast-path common case.
            engine = self.reqs_in_flight.get(request_ids[0])
            if engine:
                await self._abort_requests(request_ids, engine)
            return

        # Group ids per engine so each engine receives a single message.
        grouped: dict[CoreEngine, list[str]] = {}
        for req_id in request_ids:
            engine = self.reqs_in_flight.get(req_id)
            if engine:
                grouped.setdefault(engine, []).append(req_id)

        for engine, engine_req_ids in grouped.items():
            await self._abort_requests(engine_req_ids, engine)

    async def _abort_requests(self, request_ids: list[str],
                              engine: CoreEngine) -> None:
        """Send one ABORT message carrying ``request_ids`` to ``engine``."""
        abort_msg = (EngineCoreRequestType.ABORT.value,
                     self.encoder.encode(request_ids))
        await engine.send_multipart(abort_msg)
...@@ -2,15 +2,16 @@ ...@@ -2,15 +2,16 @@
from collections.abc import Mapping from collections.abc import Mapping
from copy import copy from copy import copy
from typing import Optional, Union from typing import Any, Callable, Optional, Union
from typing_extensions import TypeVar from typing_extensions import TypeVar
import vllm.envs as envs import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics_types import StatLoggerBase from vllm.engine.metrics_types import StatLoggerBase
from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType from vllm.inputs import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
...@@ -31,6 +32,7 @@ from vllm.v1.executor.abstract import Executor ...@@ -31,6 +32,7 @@ from vllm.v1.executor.abstract import Executor
logger = init_logger(__name__) logger = init_logger(__name__)
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
_R = TypeVar("_R", default=Any)
class LLMEngine: class LLMEngine:
...@@ -43,7 +45,6 @@ class LLMEngine: ...@@ -43,7 +45,6 @@ class LLMEngine:
log_stats: bool, log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[dict[str, StatLoggerBase]] = None, stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
use_cached_outputs: bool = False, use_cached_outputs: bool = False,
multiprocess_mode: bool = False, multiprocess_mode: bool = False,
...@@ -60,11 +61,13 @@ class LLMEngine: ...@@ -60,11 +61,13 @@ class LLMEngine:
self.cache_config = vllm_config.cache_config self.cache_config = vllm_config.cache_config
# important: init dp group before init the engine_core # important: init dp group before init the engine_core
self.parallel_config = vllm_config.parallel_config # In the decoupled engine case this is handled in EngineCoreProc.
self.dp_enabled = self.parallel_config.data_parallel_size > 1 # noqa parallel_config = vllm_config.parallel_config
if not multiprocess_mode and parallel_config.data_parallel_size > 1:
self.dp_group = parallel_config.stateless_init_dp_group()
else:
self.dp_group = None
self.should_execute_dummy_batch = False self.should_execute_dummy_batch = False
if self.dp_enabled:
self.dp_group = self.parallel_config.stateless_init_dp_group()
# Tokenizer (+ ensure liveness if running in another process). # Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs( self.tokenizer = init_tokenizer_from_configs(
...@@ -77,7 +80,6 @@ class LLMEngine: ...@@ -77,7 +80,6 @@ class LLMEngine:
# Processor (convert Inputs --> EngineCoreRequests) # Processor (convert Inputs --> EngineCoreRequests)
self.processor = Processor(vllm_config=vllm_config, self.processor = Processor(vllm_config=vllm_config,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
input_registry=input_registry,
mm_registry=mm_registry) mm_registry=mm_registry)
# OutputProcessor (convert EngineCoreOutputs --> RequestOutput). # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
...@@ -148,7 +150,7 @@ class LLMEngine: ...@@ -148,7 +150,7 @@ class LLMEngine:
def has_unfinished_requests(self) -> bool: def has_unfinished_requests(self) -> bool:
has_unfinished = self.output_processor.has_unfinished_requests() has_unfinished = self.output_processor.has_unfinished_requests()
if not self.dp_enabled: if self.dp_group is None:
return has_unfinished return has_unfinished
return self.has_unfinished_requests_dp(has_unfinished) return self.has_unfinished_requests_dp(has_unfinished)
...@@ -243,8 +245,8 @@ class LLMEngine: ...@@ -243,8 +245,8 @@ class LLMEngine:
def sleep(self, level: int = 1): def sleep(self, level: int = 1):
self.engine_core.sleep(level) self.engine_core.sleep(level)
def wake_up(self): def wake_up(self, tags: Optional[list[str]] = None):
self.engine_core.wake_up() self.engine_core.wake_up(tags)
def is_sleeping(self) -> bool: def is_sleeping(self) -> bool:
return self.engine_core.is_sleeping() return self.engine_core.is_sleeping()
...@@ -280,3 +282,14 @@ class LLMEngine: ...@@ -280,3 +282,14 @@ class LLMEngine:
def pin_lora(self, lora_id: int) -> bool: def pin_lora(self, lora_id: int) -> bool:
"""Prevent an adapter from being evicted.""" """Prevent an adapter from being evicted."""
return self.engine_core.pin_lora(lora_id) return self.engine_core.pin_lora(lora_id)
def collective_rpc(self,
                   method: Union[str, Callable[..., _R]],
                   timeout: Optional[float] = None,
                   args: tuple = (),
                   kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
    """Delegate to ``engine_core.collective_rpc`` and return its result.

    All four arguments are forwarded positionally and unchanged.
    """
    results = self.engine_core.collective_rpc(method, timeout, args, kwargs)
    return results
def __del__(self):
    """Destroy the data-parallel process group if this engine holds one."""
    # getattr with a default tolerates the attribute never having been set.
    dp_group = getattr(self, "dp_group", None)
    if dp_group:
        stateless_destroy_torch_distributed_process_group(dp_group)
...@@ -328,7 +328,7 @@ class OutputProcessor: ...@@ -328,7 +328,7 @@ class OutputProcessor:
# 2) Detokenize the token ids into text and perform stop checks. # 2) Detokenize the token ids into text and perform stop checks.
stop_string = req_state.detokenizer.update( stop_string = req_state.detokenizer.update(
new_token_ids, finish_reason == FinishReason.STOP) new_token_ids, finish_reason == FinishReason.STOP)
if stop_string and finish_reason != FinishReason.STOP: if stop_string:
finish_reason = FinishReason.STOP finish_reason = FinishReason.STOP
stop_reason = stop_string stop_reason = stop_string
......
...@@ -5,9 +5,8 @@ from collections.abc import Mapping ...@@ -5,9 +5,8 @@ from collections.abc import Mapping
from typing import Optional, Union from typing import Optional, Union
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, from vllm.inputs import ProcessorInputs, PromptType
PromptType, SingletonInputsAdapter) from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.parse import is_encoder_decoder_inputs
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
...@@ -31,7 +30,6 @@ class Processor: ...@@ -31,7 +30,6 @@ class Processor:
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
tokenizer: BaseTokenizerGroup, tokenizer: BaseTokenizerGroup,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
): ):
...@@ -123,7 +121,8 @@ class Processor: ...@@ -123,7 +121,8 @@ class Processor:
return return
supported_backends = [ supported_backends = [
"xgrammar", "xgrammar:disable-any-whitespace", "guidance", "auto" "xgrammar", "xgrammar:disable-any-whitespace", "guidance",
"guidance:disable-any-whitespace", "auto"
] ]
engine_level_backend = self.decoding_config.guided_decoding_backend engine_level_backend = self.decoding_config.guided_decoding_backend
if engine_level_backend not in supported_backends: if engine_level_backend not in supported_backends:
...@@ -137,13 +136,15 @@ class Processor: ...@@ -137,13 +136,15 @@ class Processor:
f" != {engine_level_backend}") f" != {engine_level_backend}")
else: else:
params.guided_decoding.backend = engine_level_backend params.guided_decoding.backend = engine_level_backend
import vllm.platforms
if vllm.platforms.current_platform.is_tpu():
raise ValueError("Structured output is not supported on TPU.")
# Request content validation # Request content validation
if engine_level_backend.startswith("xgrammar"):
if engine_level_backend == "xgrammar":
# xgrammar with no fallback # xgrammar with no fallback
validate_structured_output_request_xgrammar(params) validate_structured_output_request_xgrammar(params)
params.guided_decoding.backend = "xgrammar" params.guided_decoding.backend = engine_level_backend
elif engine_level_backend == "auto": elif engine_level_backend == "auto":
# "auto" is an opt-in to opinionated behavior where we try to # "auto" is an opt-in to opinionated behavior where we try to
# choose a backend based on request contents. This is not the # choose a backend based on request contents. This is not the
...@@ -157,12 +158,13 @@ class Processor: ...@@ -157,12 +158,13 @@ class Processor:
# are not supported in xgrammar. Fall back to guidance. # are not supported in xgrammar. Fall back to guidance.
params.guided_decoding.backend = "guidance" params.guided_decoding.backend = "guidance"
if params.guided_decoding.backend == "guidance": if engine_level_backend.startswith("guidance"):
# TODO ideally we would have the LLTokenizer here as Lark syntax # TODO ideally we would have the LLTokenizer here as Lark syntax
# allows <|special_token|> and similar, see # allows <|special_token|> and similar, see
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars. # Without tokenizer these are disallowed in grammars.
validate_guidance_grammar(params, tokenizer=None) validate_guidance_grammar(params, tokenizer=None)
params.guided_decoding.backend = engine_level_backend
def process_inputs( def process_inputs(
self, self,
...@@ -206,14 +208,7 @@ class Processor: ...@@ -206,14 +208,7 @@ class Processor:
self._validate_model_inputs(processed_inputs, lora_request) self._validate_model_inputs(processed_inputs, lora_request)
if is_encoder_decoder_inputs(processed_inputs): encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
decoder_inputs = SingletonInputsAdapter(
processed_inputs["decoder"])
encoder_inputs = SingletonInputsAdapter(
processed_inputs["encoder"])
else:
decoder_inputs = SingletonInputsAdapter(processed_inputs)
encoder_inputs = None
# TODO: Impl encoder-decoder # TODO: Impl encoder-decoder
if encoder_inputs is not None: if encoder_inputs is not None:
...@@ -224,8 +219,9 @@ class Processor: ...@@ -224,8 +219,9 @@ class Processor:
sampling_params = params.clone() sampling_params = params.clone()
# If unset max tokens, then generate up to the max_model_len. # If unset max tokens, then generate up to the max_model_len.
if sampling_params.max_tokens is None: if sampling_params.max_tokens is None:
sampling_params.max_tokens = (self.model_config.max_model_len - sampling_params.max_tokens = (
len(decoder_inputs.prompt_token_ids)) self.model_config.max_model_len -
len(decoder_inputs["prompt_token_ids"]))
sampling_params.update_from_generation_config( sampling_params.update_from_generation_config(
self.generation_config_fields, eos_token_id) self.generation_config_fields, eos_token_id)
sampling_params.update_from_tokenizer( sampling_params.update_from_tokenizer(
...@@ -235,57 +231,46 @@ class Processor: ...@@ -235,57 +231,46 @@ class Processor:
sorted_mm_inputs: Optional[list[MultiModalKwargs]] = None sorted_mm_inputs: Optional[list[MultiModalKwargs]] = None
sorted_mm_positions: Optional[list[PlaceholderRange]] = None sorted_mm_positions: Optional[list[PlaceholderRange]] = None
sorted_mm_hashes: Optional[list[str]] = None sorted_mm_hashes: Optional[list[str]] = None
if (decoder_mm_inputs := decoder_inputs.multi_modal_data): if decoder_inputs["type"] == "multimodal":
assert isinstance(decoder_mm_inputs, MultiModalKwargs) decoder_mm_inputs = decoder_inputs["mm_kwargs"]
# The output of merged multi-modal processor (`decoder_mm_inputs`)
# contains the kwargs for all items from all modalities.
# This code separates them so that there is one set of kwargs
# per item per modality.
individual_mm_inputs = [
MultiModalKwargs.from_items([item])
for modality in decoder_mm_inputs.modalities
for item in decoder_mm_inputs.get_items(modality)
]
# Merge and flatten multimodal placeholders, hashes and inputs # Merge and flatten multimodal placeholders, hashes and inputs
# from dictionaries to lists, and sort them by each item's position # from dictionaries to lists, and sort them by each item's position
# in the input sequence. # in the input sequence.
# NOTE: interleaved modalities are not supported.
( (
sorted_modalities, sorted_item_modalities,
sorted_mm_positions, sorted_mm_positions,
sorted_mm_hashes, sorted_mm_hashes,
) = merge_and_sort_multimodal_metadata( ) = merge_and_sort_multimodal_metadata(
decoder_inputs.multi_modal_placeholders, decoder_inputs["mm_placeholders"],
decoder_inputs.multi_modal_hashes if self.use_hash else None, decoder_inputs["mm_hashes"] if self.use_hash else None,
) )
# NOTE: Sort multimodal inputs/kwargs ONLY IF there are multiple # The output of merged multi-modal processor (`decoder_mm_inputs`)
# modalities involved. # is a single MultiModalKwargs for all items from all modalities.
if len(sorted_modalities) > 1: # This code flattens kwargs for individual items in a list and
modality_order_dict = { # sorts them by each item's position in the input sequence if there
modality: order # are multiple modalities.
for order, modality in enumerate(sorted_modalities) unique_modalities = set(sorted_item_modalities)
} if len(unique_modalities) > 1:
sorted_mm_inputs = []
# Sanity check to make sure each multimodal input has only one used_indices = {modality: 0 for modality in unique_modalities}
# modality key. for modality in sorted_item_modalities:
for mm_input in individual_mm_inputs: items = decoder_mm_inputs.get_items(modality)
assert len(mm_input.modalities) == 1 item = items[used_indices[modality]]
sorted_mm_inputs.append(MultiModalKwargs.from_items([item
# Sort MultiModalKwargs to match sorted_mm_positions ]))
sorted_mm_inputs = sorted( used_indices[modality] += 1
individual_mm_inputs,
key=lambda mm_input: modality_order_dict[list(
mm_input.modalities)[0]])
else: else:
sorted_mm_inputs = individual_mm_inputs sorted_mm_inputs = [
MultiModalKwargs.from_items([item]) for item in
decoder_mm_inputs.get_items(sorted_item_modalities[0])
]
return EngineCoreRequest( return EngineCoreRequest(
request_id=request_id, request_id=request_id,
prompt=decoder_inputs.prompt, prompt=decoder_inputs.get("prompt"),
prompt_token_ids=decoder_inputs.prompt_token_ids, prompt_token_ids=decoder_inputs["prompt_token_ids"],
mm_inputs=sorted_mm_inputs, mm_inputs=sorted_mm_inputs,
mm_hashes=sorted_mm_hashes, mm_hashes=sorted_mm_hashes,
mm_placeholders=sorted_mm_positions, mm_placeholders=sorted_mm_positions,
...@@ -298,15 +283,16 @@ class Processor: ...@@ -298,15 +283,16 @@ class Processor:
def _validate_model_inputs(self, def _validate_model_inputs(self,
inputs: ProcessorInputs, inputs: ProcessorInputs,
lora_request: Optional[LoRARequest] = None): lora_request: Optional[LoRARequest] = None):
if is_encoder_decoder_inputs(inputs): encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length # For encoder-decoder multimodal models, the max_prompt_len
prompt_inputs = inputs["decoder" if self.model_config. # restricts the decoder prompt length
is_multimodal_model else "encoder"] if self.model_config.is_multimodal_model:
prompt_inputs = decoder_inputs
else: else:
prompt_inputs = inputs prompt_inputs = encoder_inputs or decoder_inputs
prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids prompt_ids = prompt_inputs["prompt_token_ids"]
if prompt_ids is None or len(prompt_ids) == 0: if prompt_ids is None or len(prompt_ids) == 0:
raise ValueError("Prompt cannot be empty") raise ValueError("Prompt cannot be empty")
......
...@@ -235,7 +235,10 @@ class WorkerProc: ...@@ -235,7 +235,10 @@ class WorkerProc:
worker_response_mq_handle = self.worker_response_mq.export_handle() worker_response_mq_handle = self.worker_response_mq.export_handle()
# Send Readiness signal to EngineCore process. # Send Readiness signal to EngineCore process.
with zmq_socket_ctx(ready_path, zmq.constants.PUSH) as ready_socket: # Set linger here because we want to ensure the message has
# been sent before the context is closed.
with zmq_socket_ctx(ready_path, zmq.constants.PUSH,
linger=10000) as ready_socket:
payload = pickle.dumps(worker_response_mq_handle, payload = pickle.dumps(worker_response_mq_handle,
protocol=pickle.HIGHEST_PROTOCOL) protocol=pickle.HIGHEST_PROTOCOL)
ready_socket.send_string(WorkerProc.READY_STR) ready_socket.send_string(WorkerProc.READY_STR)
...@@ -270,11 +273,13 @@ class WorkerProc: ...@@ -270,11 +273,13 @@ class WorkerProc:
proc = context.Process(target=WorkerProc.worker_main, proc = context.Process(target=WorkerProc.worker_main,
kwargs=process_kwargs, kwargs=process_kwargs,
daemon=True) daemon=True)
proc.start()
# Wait for startup with zmq_socket_ctx(ready_path, zmq.constants.PULL) as ready_socket:
worker_response_mq_handle = WorkerProc.wait_for_startup( proc.start()
proc, ready_path)
# Wait for startup
worker_response_mq_handle = WorkerProc.wait_for_startup(
proc, ready_socket)
worker_response_mq = MessageQueue.create_from_handle( worker_response_mq = MessageQueue.create_from_handle(
worker_response_mq_handle, 0) worker_response_mq_handle, 0)
...@@ -337,23 +342,22 @@ class WorkerProc: ...@@ -337,23 +342,22 @@ class WorkerProc:
@staticmethod @staticmethod
def wait_for_startup( def wait_for_startup(
proc: BaseProcess, proc: BaseProcess,
ready_path: str, ready_socket: zmq.Socket,
) -> Optional[Handle]: ) -> Optional[Handle]:
"""Wait until the Worker is ready.""" """Wait until the Worker is ready."""
with zmq_socket_ctx(ready_path, zmq.constants.PULL) as socket:
# Wait for Worker to send READY. # Wait for Worker to send READY.
while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: while ready_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
logger.debug("Waiting for WorkerProc to startup.") logger.debug("Waiting for WorkerProc to startup.")
if not proc.is_alive(): if not proc.is_alive():
raise RuntimeError("WorkerProc failed to start.") raise RuntimeError("WorkerProc failed to start.")
message = socket.recv_string() message = ready_socket.recv_string()
assert message == WorkerProc.READY_STR assert message == WorkerProc.READY_STR
handle_frame = socket.recv(copy=False) handle_frame = ready_socket.recv(copy=False)
handle = pickle.loads(handle_frame.buffer) handle = pickle.loads(handle_frame.buffer)
return handle return handle
class ResponseStatus(Enum): class ResponseStatus(Enum):
SUCCESS = auto() SUCCESS = auto()
......
...@@ -4,6 +4,7 @@ from dataclasses import dataclass ...@@ -4,6 +4,7 @@ from dataclasses import dataclass
import torch import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import cdiv, get_dtype_size from vllm.utils import cdiv, get_dtype_size
...@@ -43,28 +44,23 @@ class KVCacheSpec: ...@@ -43,28 +44,23 @@ class KVCacheSpec:
""" """
raise NotImplementedError raise NotImplementedError
def bytes_for_tokens(self, num_tokens: int) -> int: def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
""" """
The KV cache size for `num_tokens` tokens in bytes. Returns the real The maximum possible memory usage of this KV cache in bytes.
memory size after padding `num_tokens` to full blocks.
Returns: Returns:
The KV cache size The KV cache size in bytes
""" """
raise NotImplementedError raise NotImplementedError
@dataclass @dataclass
class FullAttentionSpec(KVCacheSpec): class AttentionSpec(KVCacheSpec):
num_kv_heads: int num_kv_heads: int
head_size: int head_size: int
dtype: torch.dtype dtype: torch.dtype
use_mla: bool use_mla: bool
@property
def type_id(self) -> str:
return f"full_attention_{self.block_size}_{self.page_size_bytes}"
@property @property
def page_size_bytes(self) -> int: def page_size_bytes(self) -> int:
# For MLA we only store a single latent vector # For MLA we only store a single latent vector
...@@ -72,8 +68,47 @@ class FullAttentionSpec(KVCacheSpec): ...@@ -72,8 +68,47 @@ class FullAttentionSpec(KVCacheSpec):
return coef * self.block_size * self.num_kv_heads * self.head_size \ return coef * self.block_size * self.num_kv_heads * self.head_size \
* get_dtype_size(self.dtype) * get_dtype_size(self.dtype)
def bytes_for_tokens(self, num_tokens: int) -> int:
return cdiv(num_tokens, self.block_size) * self.page_size_bytes @dataclass
class FullAttentionSpec(AttentionSpec):
@property
def type_id(self) -> str:
return f"full_attention_{self.block_size}_{self.page_size_bytes}"
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
max_model_len = vllm_config.model_config.max_model_len
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
@dataclass
class SlidingWindowSpec(AttentionSpec):
    """KV cache spec for attention layers that use a sliding window."""

    # Window length in tokens; it is added to a token count in
    # `max_memory_usage_bytes` below.
    sliding_window: int

    def __post_init__(self):
        # The MLA layout is rejected outright for sliding-window layers.
        assert not self.use_mla, "MLA is not supported for sliding window"

    @property
    def type_id(self) -> str:
        """Identifier built from the window size, block size and page size."""
        return f"sliding_window_{self.sliding_window}_{self.block_size}_{self.page_size_bytes}"  # noqa

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        Upper bound, in bytes, on the KV cache held by this spec.

        Returns:
            The KV cache size in bytes
        """
        token_budget = vllm_config.scheduler_config.max_num_batched_tokens
        model_len_cap = vllm_config.model_config.max_model_len

        # During chunked prefill, KV cache is allocated for the last
        # `sliding_window - 1` computed tokens plus the newly scheduled
        # tokens, but never for more than `max_model_len` tokens.
        window_tokens = self.sliding_window - 1 + token_budget
        if window_tokens > model_len_cap:
            window_tokens = model_len_cap

        # One extra block because the window need not start on a block
        # boundary. With block size 4, the four tokens [CDEF] can straddle
        # two blocks, [XXCD] [EF], even though they would fit into a single
        # block if they were aligned.
        num_blocks = cdiv(window_tokens, self.block_size) + 1
        return num_blocks * self.page_size_bytes
@dataclass @dataclass
......
...@@ -12,6 +12,7 @@ from vllm.logger import init_logger ...@@ -12,6 +12,7 @@ from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
from vllm.v1.engine import FinishReason from vllm.v1.engine import FinishReason
from vllm.v1.metrics.stats import IterationStats, SchedulerStats from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from vllm.v1.spec_decode.metrics import SpecDecodingMetrics
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -31,12 +32,14 @@ class StatLoggerBase(ABC): ...@@ -31,12 +32,14 @@ class StatLoggerBase(ABC):
class LoggingStatLogger(StatLoggerBase): class LoggingStatLogger(StatLoggerBase):
def __init__(self): def __init__(self, engine_index: int = 0):
self.engine_index = engine_index
self._reset(time.monotonic()) self._reset(time.monotonic())
self.last_scheduler_stats = SchedulerStats() self.last_scheduler_stats = SchedulerStats()
# Prefix cache metrics. This cannot be reset. # Prefix cache metrics. This cannot be reset.
# TODO: Make the interval configurable. # TODO: Make the interval configurable.
self.prefix_caching_metrics = PrefixCachingMetrics() self.prefix_caching_metrics = PrefixCachingMetrics()
self.spec_decoding_metrics = SpecDecodingMetrics()
def _reset(self, now): def _reset(self, now):
self.last_log_time = now self.last_log_time = now
...@@ -64,6 +67,10 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -64,6 +67,10 @@ class LoggingStatLogger(StatLoggerBase):
self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats) self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)
if scheduler_stats.spec_decoding_stats is not None:
self.spec_decoding_metrics.observe(
scheduler_stats.spec_decoding_stats)
self.last_scheduler_stats = scheduler_stats self.last_scheduler_stats = scheduler_stats
def log(self): def log(self):
...@@ -78,11 +85,13 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -78,11 +85,13 @@ class LoggingStatLogger(StatLoggerBase):
# Format and print output. # Format and print output.
logger.info( logger.info(
"Engine %03d: "
"Avg prompt throughput: %.1f tokens/s, " "Avg prompt throughput: %.1f tokens/s, "
"Avg generation throughput: %.1f tokens/s, " "Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Waiting: %d reqs, " "Running: %d reqs, Waiting: %d reqs, "
"GPU KV cache usage: %.1f%%, " "GPU KV cache usage: %.1f%%, "
"Prefix cache hit rate: %.1f%%", "Prefix cache hit rate: %.1f%%",
self.engine_index,
prompt_throughput, prompt_throughput,
generation_throughput, generation_throughput,
scheduler_stats.num_running_reqs, scheduler_stats.num_running_reqs,
...@@ -91,10 +100,13 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -91,10 +100,13 @@ class LoggingStatLogger(StatLoggerBase):
self.prefix_caching_metrics.hit_rate * 100, self.prefix_caching_metrics.hit_rate * 100,
) )
if scheduler_stats.spec_decoding_stats is not None:
self.spec_decoding_metrics.log()
class PrometheusStatLogger(StatLoggerBase): class PrometheusStatLogger(StatLoggerBase):
def __init__(self, vllm_config: VllmConfig): def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
self._unregister_vllm_metrics() self._unregister_vllm_metrics()
# Use this flag to hide metrics that were deprecated in # Use this flag to hide metrics that were deprecated in
...@@ -102,8 +114,11 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -102,8 +114,11 @@ class PrometheusStatLogger(StatLoggerBase):
self.show_hidden_metrics = \ self.show_hidden_metrics = \
vllm_config.observability_config.show_hidden_metrics vllm_config.observability_config.show_hidden_metrics
labelnames = ["model_name"] labelnames = ["model_name", "engine"]
labelvalues = [vllm_config.model_config.served_model_name] labelvalues = [
vllm_config.model_config.served_model_name,
str(engine_index)
]
max_model_len = vllm_config.model_config.max_model_len max_model_len = vllm_config.model_config.max_model_len
...@@ -296,6 +311,24 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -296,6 +311,24 @@ class PrometheusStatLogger(StatLoggerBase):
self.labelname_running_lora_adapters, self.labelname_running_lora_adapters,
]) ])
#
# Speculative Decoding metrics
# The acceptance rate can be calculated using a PromQL query:
#
# rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
# rate(vllm:spec_decode_num_draft_tokens_total[$interval])
#
self.counter_spec_decode_num_draft_tokens = \
prometheus_client.Counter(
name="vllm:spec_decode_num_draft_tokens_total",
documentation="Number of draft tokens.",
labelnames=labelnames).labels(*labelvalues)
self.counter_spec_decode_num_accepted_tokens = \
prometheus_client.Counter(
name="vllm:spec_decode_num_accepted_tokens_total",
documentation="Number of accepted tokens.",
labelnames=labelnames).labels(*labelvalues)
# #
# Cache config info metric # Cache config info metric
# #
...@@ -332,6 +365,12 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -332,6 +365,12 @@ class PrometheusStatLogger(StatLoggerBase):
self.counter_gpu_prefix_cache_hits.inc( self.counter_gpu_prefix_cache_hits.inc(
scheduler_stats.prefix_cache_stats.hits) scheduler_stats.prefix_cache_stats.hits)
if scheduler_stats.spec_decoding_stats is not None:
self.counter_spec_decode_num_draft_tokens.inc(
scheduler_stats.spec_decoding_stats.num_draft_tokens)
self.counter_spec_decode_num_accepted_tokens.inc(
scheduler_stats.spec_decoding_stats.num_accepted_tokens)
if iteration_stats is None: if iteration_stats is None:
return return
......
...@@ -4,6 +4,8 @@ import time ...@@ -4,6 +4,8 @@ import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from vllm.v1.spec_decode.metrics import SpecDecodingStats
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
from vllm.v1.engine.output_processor import RequestState from vllm.v1.engine.output_processor import RequestState
...@@ -35,6 +37,8 @@ class SchedulerStats: ...@@ -35,6 +37,8 @@ class SchedulerStats:
prefix_cache_stats: PrefixCacheStats = field( prefix_cache_stats: PrefixCacheStats = field(
default_factory=PrefixCacheStats) default_factory=PrefixCacheStats)
spec_decoding_stats: Optional[SpecDecodingStats] = None
@dataclass @dataclass
class LoRAStats: class LoRAStats:
......
...@@ -59,6 +59,8 @@ class Request: ...@@ -59,6 +59,8 @@ class Request:
self.mm_positions = multi_modal_placeholders or [] self.mm_positions = multi_modal_placeholders or []
self.mm_inputs = multi_modal_inputs or [] self.mm_inputs = multi_modal_inputs or []
self.mm_hashes: list[str] = multi_modal_hashes or [] self.mm_hashes: list[str] = multi_modal_hashes or []
self.num_encoder_inputs = len(self.mm_inputs)
self.has_encoder_inputs = self.num_encoder_inputs > 0
# Sanity check # Sanity check
assert len(self.mm_inputs) == len(self.mm_positions) assert len(self.mm_inputs) == len(self.mm_positions)
...@@ -93,9 +95,11 @@ class Request: ...@@ -93,9 +95,11 @@ class Request:
token_ids: Union[int, list[int]], token_ids: Union[int, list[int]],
) -> None: ) -> None:
if isinstance(token_ids, int): if isinstance(token_ids, int):
token_ids = [token_ids] self._output_token_ids.append(token_ids)
self._output_token_ids.extend(token_ids) self._all_token_ids.append(token_ids)
self._all_token_ids.extend(token_ids) else:
self._output_token_ids.extend(token_ids)
self._all_token_ids.extend(token_ids)
@property @property
def num_tokens(self) -> int: def num_tokens(self) -> int:
...@@ -115,13 +119,6 @@ class Request: ...@@ -115,13 +119,6 @@ class Request:
def get_finished_reason(self) -> Union[FinishReason, None]: def get_finished_reason(self) -> Union[FinishReason, None]:
return RequestStatus.get_finished_reason(self.status) return RequestStatus.get_finished_reason(self.status)
def has_encoder_inputs(self) -> bool:
return len(self.mm_inputs) > 0
@property
def num_encoder_inputs(self) -> int:
return len(self.mm_positions)
def get_num_encoder_tokens(self, input_id: int) -> int: def get_num_encoder_tokens(self, input_id: int) -> int:
assert input_id < len(self.mm_positions) assert input_id < len(self.mm_positions)
num_tokens = self.mm_positions[input_id]["length"] num_tokens = self.mm_positions[input_id]["length"]
......
...@@ -19,6 +19,12 @@ except ImportError: ...@@ -19,6 +19,12 @@ except ImportError:
class TopKTopPSampler(nn.Module): class TopKTopPSampler(nn.Module):
"""
Module that performs optional top-k and top-p filtering followed by
weighted random sampling of logits.
Implementations may update the logits tensor in-place.
"""
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -84,7 +90,11 @@ class TopKTopPSampler(nn.Module): ...@@ -84,7 +90,11 @@ class TopKTopPSampler(nn.Module):
k: Optional[torch.Tensor], k: Optional[torch.Tensor],
p: Optional[torch.Tensor], p: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
"""PyTorch-native implementation of top-k and top-p sampling.""" """
PyTorch-native implementation of top-k and top-p sampling.
The logits tensor may be updated in-place.
"""
logits = apply_top_k_top_p(logits, k, p) logits = apply_top_k_top_p(logits, k, p)
probs = logits.softmax(dim=-1, dtype=torch.float32) probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators) return random_sample(probs, generators)
...@@ -112,23 +122,48 @@ class TopKTopPSampler(nn.Module): ...@@ -112,23 +122,48 @@ class TopKTopPSampler(nn.Module):
k: Optional[torch.Tensor], k: Optional[torch.Tensor],
p: Optional[torch.Tensor], p: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
# If only top-k is specified, use pytorch's builtin topk op. This leads logits = apply_top_k_top_p_tpu(logits, k, p)
# to significant speed up on TPU compared to using apply_top_k_top_p.
if k is not None and p is None:
topk_values, topk_indices = torch.topk(logits, k, dim=-1)
mask = torch.ones_like(logits, dtype=torch.bool)
mask.scatter_(-1, topk_indices, False)
logits.masked_fill_(mask, float('-inf'))
else:
# TODO Placeholder for TPU optimized topp kernel
# logits = apply_top_k_top_p(logits, k, p)
pass
probs = logits.softmax(dim=-1, dtype=torch.float32) probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators) return random_sample(probs, generators)
def apply_top_k_top_p_tpu(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Apply top-k and top-p optimized for TPU.

    This algorithm avoids using torch.scatter which is extremely slow on TPU.
    This is achieved by finding a "cut-off" element in the original logit, and
    after thresholding the logit using this cut-off, the remaining elements
    shall constitute the top-p set.

    Note: in the case of a tie (i.e. multiple cut-off elements present in the
    logit), all tie elements are included in the top-p set. In other words,
    this function does not break ties. Instead, these tie tokens have equal
    chance of being chosen during final sampling, so we can consider the tie
    being broken then.

    Args:
        logits: 2-D tensor indexed as [row, vocab]. May be updated in-place.
        k: Per-row top-k values, or None to skip top-k filtering.
        p: Per-row top-p values (1-D, one entry per row of `logits`), or
            None to skip top-p filtering.

    Returns:
        The filtered logits, with discarded entries set to -inf.
    """
    if k is not None:
        logits = apply_top_k_only(logits, k)

    if p is not None:
        probs = logits.softmax(dim=-1)
        # Ascending sort: the low-probability tail comes first, so the
        # cumulative sum measures how much mass can be dropped.
        probs_sort, _ = probs.sort(dim=-1, descending=False)
        cumprob = torch.cumsum(probs_sort, dim=-1)
        top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
        top_p_mask[:, -1] = False  # at least one

        # The number of droppable entries is also the index (in sorted
        # order) of the first entry to keep, i.e. the cut-off element.
        top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
        top_p_cutoff = probs_sort.gather(-1, top_p_count)
        # Strict comparison keeps every entry tied with the cut-off.
        elements_to_discard = probs < top_p_cutoff
        logits.masked_fill_(elements_to_discard, -float("inf"))

    return logits
def apply_top_k_top_p( def apply_top_k_top_p(
logits: torch.Tensor, logits: torch.Tensor,
k: Optional[torch.Tensor], k: Optional[torch.Tensor],
...@@ -136,10 +171,18 @@ def apply_top_k_top_p( ...@@ -136,10 +171,18 @@ def apply_top_k_top_p(
) -> torch.Tensor: ) -> torch.Tensor:
"""Apply top-k and top-p masks to the logits. """Apply top-k and top-p masks to the logits.
This function sorts the logits tensor, which can be slow for large batches. If a top-p is used, this function will sort the logits tensor,
which can be slow for large batches.
The logits tensor may be updated in-place.
""" """
if k is None and p is None: if p is None:
return logits if k is None:
return logits
# Avoid sorting vocab for top-k only case.
return apply_top_k_only(logits, k)
logits_sort, logits_idx = logits.sort(dim=-1, descending=False) logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
if k is not None: if k is not None:
...@@ -153,7 +196,7 @@ def apply_top_k_top_p( ...@@ -153,7 +196,7 @@ def apply_top_k_top_p(
if p is not None: if p is not None:
# Apply top-p. # Apply top-p.
probs_sort = logits_sort.softmax(dim=-1) probs_sort = logits_sort.softmax(dim=-1)
probs_sum = probs_sort.cumsum(dim=-1) probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
# at least one # at least one
top_p_mask[:, -1] = False top_p_mask[:, -1] = False
...@@ -164,6 +207,31 @@ def apply_top_k_top_p( ...@@ -164,6 +207,31 @@ def apply_top_k_top_p(
return logits return logits
def apply_top_k_only(
    logits: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    """
    Apply top-k mask to the logits.

    This implementation doesn't involve sorting the entire vocab.

    The logits tensor may be updated in-place.

    Args:
        logits: 2-D tensor indexed as [row, vocab].
        k: Per-row top-k values. A row whose value equals the vocab size
            (`logits.shape[1]`) is left unfiltered. This tensor is not
            modified.

    Returns:
        The logits with every entry strictly below the row's k-th largest
        value set to -inf. Entries tied with the k-th largest are kept.
    """
    # Nothing to filter for an empty batch; `k.max()` below would raise.
    if k.numel() == 0:
        return logits

    no_top_k_mask = k == logits.shape[1]
    # Set non-top-k rows to 1 so that we can gather. `masked_fill` is
    # out-of-place, so the in-place `sub_` below never touches the caller's k.
    k = k.masked_fill(no_top_k_mask, 1)
    # `topk` takes a plain integer; convert the 0-dim tensor explicitly.
    max_top_k = int(k.max())
    # topk.values tensor has shape [batch_size, max_top_k].
    # Convert top k to 0-based index in range [0, max_top_k).
    k_index = k.sub_(1).unsqueeze(1).expand(logits.shape[0], 1)
    top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
    # Handle non-topk rows: a -inf threshold masks nothing.
    top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
    logits.masked_fill_(logits < top_k_mask, -float("inf"))
    return logits
def random_sample( def random_sample(
probs: torch.Tensor, probs: torch.Tensor,
generators: dict[int, torch.Generator], generators: dict[int, torch.Generator],
......
...@@ -109,6 +109,18 @@ class RejectionSampler(nn.Module): ...@@ -109,6 +109,18 @@ class RejectionSampler(nn.Module):
output_token_ids: torch.Tensor, output_token_ids: torch.Tensor,
vocab_size: int, vocab_size: int,
) -> list[list[int]]: ) -> list[list[int]]:
"""Parse the output of the rejection sampler.
Args:
output_token_ids: The sampled token IDs in shape
[batch_size, max_spec_len + 1]. The rejected tokens are
replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
and will be filtered out in this function.
vocab_size: The size of the vocabulary.
Returns:
A list of lists of token IDs.
"""
output_token_ids_np = output_token_ids.cpu().numpy() output_token_ids_np = output_token_ids.cpu().numpy()
# Create mask for valid tokens. # Create mask for valid tokens.
valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) & valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
......
...@@ -87,6 +87,12 @@ class Sampler(nn.Module): ...@@ -87,6 +87,12 @@ class Sampler(nn.Module):
logits: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
"""Sample logits based on sampling metadata.
The various logits processing functions called in this method
may update the logits tensor in-place.
"""
assert not (sampling_metadata.all_greedy assert not (sampling_metadata.all_greedy
and sampling_metadata.all_random) and sampling_metadata.all_random)
if sampling_metadata.all_random: if sampling_metadata.all_random:
......
...@@ -5,7 +5,18 @@ from typing import Optional ...@@ -5,7 +5,18 @@ from typing import Optional
import torch import torch
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.gpu_input_batch import InputBatch
# Per-parameter fill values. `from_input_batch` writes them into the slots
# between `num_reqs` and `padded_num_reqs` when padding the sampling tensors.
DEFAULT_SAMPLING_PARAMS = {
    "temperature": -1.0,
    "min_p": 0.0,
    # strictly disabled for now
    # "top_k": -1,
    # "top_p": 0.0,
    # "frequency_penalties": 0.0,
    # "presence_penalties": 0.0,
    # "repetition_penalties": 0.0,
}
@dataclass @dataclass
...@@ -20,14 +31,8 @@ class TPUSupportedSamplingMetadata: ...@@ -20,14 +31,8 @@ class TPUSupportedSamplingMetadata:
top_k: torch.Tensor = None top_k: torch.Tensor = None
top_p: torch.Tensor = None top_p: torch.Tensor = None
# XLA-unfriendly control flow in Sampler
all_greedy: bool = False
all_random: bool = False
# Greedy sampling flag for compiling single xla graph. # Greedy sampling flag for compiling single xla graph.
do_argmax: torch.Tensor = None all_greedy: torch.Tensor = None
# speculation not supported
spec_token_ids = None
# Generator not supported by xla # Generator not supported by xla
generators: dict[int, generators: dict[int,
...@@ -54,106 +59,62 @@ class TPUSupportedSamplingMetadata: ...@@ -54,106 +59,62 @@ class TPUSupportedSamplingMetadata:
bad_words_token_ids = None bad_words_token_ids = None
indices_do_sample: torch.Tensor = None indices_do_sample: torch.Tensor = None
def __post_init__(self):
temp = self.temperature
if self.indices_do_sample is None:
self.indices_do_sample = torch.zeros(temp.shape[0],
device=temp.device,
dtype=torch.int32)
if self.do_argmax is None:
self.do_argmax = torch.tensor(0,
dtype=torch.bool,
device=temp.device)
@classmethod @classmethod
def from_sampling_metadata( def from_input_batch(
cls, metadata: SamplingMetadata, cls, input_batch: InputBatch,
padded_do_sample_indices: torch.Tensor, num_do_sample: int, indices_do_sample: torch.Tensor) -> "TPUSupportedSamplingMetadata":
device: torch.device) -> "TPUSupportedSamplingMetadata":
""" """
Create an XLA-frienly SamplingMetadata structure. Do so by first Copy sampling tensors slices from `input_batch` to on device tensors.
instantiating an object with fixed-sized tensors and then writing the
values in input `metadata`. Do that only for non-None values so that `InputBatch._make_sampling_metadata` causes recompilation on XLA as it
recompilation is not triggered for optional values (None/torch.Tensor). slices dynamic shapes on device tensors. This impl moves the dynamic
ops to CPU and produces tensors of fixed `padded_num_reqs` size. It
In order to handle different sizes for the params that range from 1 up also reuses the on-device persistent tensors managed in `input_batch`
to `max_num_seqs`, pad tensors to the closest pre-compiled shape. to reduce waste.
Same thing for `padded_do_sample_indices`, which contains the indices
to be fed to the Sampler, padded to the closest pre-compiled shape. `indices_do_sample` contains the indices to be fed to the Sampler,
normally one per request, here padded to the closest pre-compiled shape
Eg. pad to 4 temperature: [0.7, 0.2]=>[0.7, 0.2, 0.0, 0.0] We expect sampling params tensors to be padded to the same fixed shape.
do_sample_indices: [4, 10]=>padded_do_sample_indices: [4, 10, 0, 0]
Eg. 3 requests, tensors padded to 4
temperature: [0.7, 0.2, 0.9]=>[0.7, 0.2, 0.9, 0.0]
sample indices: [4, 10, 11]=>indices_do_sample: [4, 10, 11, 0]
""" """
metadata = cls._validate_sampling_metadata(metadata) num_reqs = input_batch.num_reqs
# NOTE we have to initialize default tensor-based params first and padded_num_reqs = len(indices_do_sample)
# skip None values altogether to produce the same xla graph.
num_samples = len(padded_do_sample_indices) def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
do_argmax = torch.tensor(metadata.all_greedy, fill_val) -> torch.Tensor:
dtype=torch.bool, # Copy slice from CPU to corresponding TPU pre-allocated tensor.
device=device) # Pad value is the default one.
new_metadata = cls.get_default_sampling_params(num_samples, device, cpu_tensor[num_reqs:padded_num_reqs] = fill_val
indices_do_sample=\ # Subtle compilation: len(tpu_tensor) must be >= `padded_num_reqs`
padded_do_sample_indices, tpu_tensor[:padded_num_reqs] = cpu_tensor[:padded_num_reqs]
do_argmax=do_argmax
) # NOTE NickLucche The sync CPU-TPU graph we produce here must be
supported_params = \ # consistent. We can't have flags to skip copies or we'll end up
TPUSupportedSamplingMetadata._get_default_params_values() # recompiling.
# Copy input non-None values into `new_metadata` fixed-sized tensors. copy_slice(input_batch.temperature_cpu_tensor, input_batch.temperature,
for p_name in supported_params: DEFAULT_SAMPLING_PARAMS["temperature"])
old_val = getattr(metadata, p_name) # TODO Temporarily disabled until sampling options are enabled
new_val = getattr(new_metadata, p_name) # copy_slice(input_batch.top_p_cpu_tensor, input_batch.top_p)
if isinstance(old_val, torch.Tensor): # copy_slice(input_batch.top_k_cpu_tensor, input_batch.top_k)
new_val[:num_do_sample] = old_val copy_slice(input_batch.min_p_cpu_tensor, input_batch.min_p,
setattr(new_metadata, p_name, new_val) DEFAULT_SAMPLING_PARAMS["min_p"])
xm.mark_step() xm.mark_step()
xm.wait_device_ops() xm.wait_device_ops()
return new_metadata
@classmethod # Slice persistent device tensors to a fixed pre-compiled padded shape.
def get_default_sampling_params( return cls(
cls, temperature=input_batch.temperature[:padded_num_reqs],
num_samples: int, # Scalar tensor for xla-friendly tracing.
device: torch.device, all_greedy=torch.tensor(input_batch.all_greedy,
indices_do_sample=None, dtype=torch.bool,
do_argmax=None) -> "TPUSupportedSamplingMetadata": device=input_batch.device),
# As sampling happens on a single traced graph, options # TODO enable more and avoid returning None values
# are "disabled" by having them evaluate to an Identity op. top_p=None, # input_batch.top_p[:padded_num_reqs],
# Note that initialization is dependent on num_samples. top_k=None, # input_batch.top_k[:padded_num_reqs],
sampling_metadata_disable_value = \ min_p=input_batch.min_p[:padded_num_reqs],
TPUSupportedSamplingMetadata._get_default_params_values() generators=input_batch.generators,
init_kwargs = dict() indices_do_sample=indices_do_sample)
for p_name, (default_val,
dtype) in sampling_metadata_disable_value.items():
default_tensor = torch.full((num_samples, ),
default_val,
dtype=dtype,
device=device)
init_kwargs[p_name] = default_tensor
return cls(**init_kwargs,
indices_do_sample=indices_do_sample,
do_argmax=do_argmax)
@staticmethod
def _validate_sampling_metadata(
sampling_metadata: SamplingMetadata) -> SamplingMetadata:
if sampling_metadata.all_greedy:
# Set to None since #13587. Make sure default isn't overruled.
assert sampling_metadata.temperature is None
return sampling_metadata
@staticmethod
def _get_default_params_values():
return dict(
# Since #13587 greedy sampling requires branching off which leads
# to separate graphs. We set temp to noop and handle argmax here.
temperature=(1.0, torch.float32),
min_p=(0.0, torch.float32),
# strictly disabled for now
# top_k=(-1, torch.int32),
# top_p=(0.0, torch.float32),
# frequency_penalties=(0.0, torch.float32),
# presence_penalties=(0.0, torch.float32),
# repetition_penalties=(0.0, torch.float32),
)
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import pickle import pickle
from types import FunctionType
from typing import Any, Optional from typing import Any, Optional
import cloudpickle
import torch import torch
from msgspec import msgpack from msgspec import msgpack
CUSTOM_TYPE_TENSOR = 1 CUSTOM_TYPE_TENSOR = 1
CUSTOM_TYPE_PICKLE = 2 CUSTOM_TYPE_PICKLE = 2
CUSTOM_TYPE_CLOUDPICKLE = 3
class MsgpackEncoder: class MsgpackEncoder:
...@@ -41,6 +44,9 @@ def custom_enc_hook(obj: Any) -> Any: ...@@ -41,6 +44,9 @@ def custom_enc_hook(obj: Any) -> Any:
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501 # https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy())) return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy()))
if isinstance(obj, FunctionType):
return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))
return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj)) return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj))
...@@ -49,5 +55,7 @@ def custom_ext_hook(code: int, data: memoryview) -> Any: ...@@ -49,5 +55,7 @@ def custom_ext_hook(code: int, data: memoryview) -> Any:
return torch.from_numpy(pickle.loads(data)) return torch.from_numpy(pickle.loads(data))
if code == CUSTOM_TYPE_PICKLE: if code == CUSTOM_TYPE_PICKLE:
return pickle.loads(data) return pickle.loads(data)
if code == CUSTOM_TYPE_CLOUDPICKLE:
return cloudpickle.loads(data)
raise NotImplementedError(f"Extension type code {code} is not supported") raise NotImplementedError(f"Extension type code {code} is not supported")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment