Commit 2fde0fa2 authored by laibao's avatar laibao
Browse files

feat: kvpress新增调度层 KV 压缩逻辑

parent eef99f73
...@@ -28,12 +28,16 @@ from vllm.v1.core.sched.request_queue import (SchedulingPolicy, ...@@ -28,12 +28,16 @@ from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
from vllm.v1.core.sched.utils import check_stop from vllm.v1.core.sched.utils import check_stop
from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
EngineCoreOutputs) EngineCoreOutputs)
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig, SlidingWindowSpec
from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.kv_compression.budget import (compute_prompt_keep_len,
compute_topk_budget_step,
count_prompt_must_keep_in_range)
from vllm.platforms import current_platform
from vllm import envs from vllm import envs
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -156,6 +160,53 @@ class Scheduler(SchedulerInterface): ...@@ -156,6 +160,53 @@ class Scheduler(SchedulerInterface):
self.compilation_config = vllm_config.compilation_config self.compilation_config = vllm_config.compilation_config
self.full_cuda_graph = self.compilation_config.full_cuda_graph self.full_cuda_graph = self.compilation_config.full_cuda_graph
self.use_mla = vllm_config.model_config.use_mla self.use_mla = vllm_config.model_config.use_mla
# KV compression is only supported on CUDA/ROCm in this fork.
# Other backends (TPU/CPU/XPU/HPU/Neuron/...) do not plumb the
# num_kv_tokens-based slot mapping/metadata and can produce incorrect
# cache mappings if enabled.
self.kv_compression_enabled = (envs.VLLM_ENABLE_KV_COMPRESSION
and current_platform.is_cuda_alike())
if envs.VLLM_ENABLE_KV_COMPRESSION and not self.kv_compression_enabled:
logger.warning_once(
"KV compression is only supported on CUDA/ROCm; ignoring "
"VLLM_ENABLE_KV_COMPRESSION=1 on this platform.")
if self.kv_compression_enabled:
if envs.VLLM_KV_COMPRESSION_POLICY != "topk":
raise ValueError(
"VLLM_KV_COMPRESSION_POLICY must be 'topk'.")
if any(
isinstance(group.kv_cache_spec, SlidingWindowSpec)
for group in kv_cache_config.kv_cache_groups):
raise ValueError(
"KV compression is incompatible with sliding window "
"attention.")
if self.cache_config.enable_prefix_caching:
raise ValueError(
"KV compression is incompatible with prefix caching. "
"Disable prefix caching to enable KV compression.")
if self.full_cuda_graph:
raise ValueError(
"KV compression is currently incompatible with full CUDA "
"graph mode.")
if self.speculative_config is not None:
raise ValueError(
"KV compression is currently incompatible with "
"speculative decoding.")
if envs.VLLM_KV_COMPRESSION_PROMPT_BUDGET < -1:
raise ValueError(
"VLLM_KV_COMPRESSION_PROMPT_BUDGET must be >= -1.")
if not (0.0 <= envs.VLLM_KV_COMPRESSION_PROMPT_RATIO <= 1.0):
raise ValueError(
"VLLM_KV_COMPRESSION_PROMPT_RATIO must be in [0, 1].")
if envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX < 0:
raise ValueError(
"VLLM_KV_COMPRESSION_PROTECTED_PREFIX must be >= 0.")
if envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX < 0:
raise ValueError(
"VLLM_KV_COMPRESSION_PROTECTED_SUFFIX must be >= 0.")
if envs.VLLM_KV_COMPRESSION_SNAPKV_WINDOW < 1:
raise ValueError(
"VLLM_KV_COMPRESSION_SNAPKV_WINDOW must be >= 1.")
# Create the KV cache manager. # Create the KV cache manager.
self.kv_cache_manager = KVCacheManager( self.kv_cache_manager = KVCacheManager(
...@@ -207,6 +258,8 @@ class Scheduler(SchedulerInterface): ...@@ -207,6 +258,8 @@ class Scheduler(SchedulerInterface):
encoder_budget = self.max_num_encoder_input_tokens encoder_budget = self.max_num_encoder_input_tokens
# Spec decode-related. # Spec decode-related.
scheduled_spec_decode_tokens: dict[str, list[int]] = {} scheduled_spec_decode_tokens: dict[str, list[int]] = {}
# Requests whose block IDs must be replaced (not appended) in workers.
force_replace_block_ids: set[str] = set()
# For logging. # For logging.
scheduled_timestamp = time.monotonic() scheduled_timestamp = time.monotonic()
...@@ -274,6 +327,13 @@ class Scheduler(SchedulerInterface): ...@@ -274,6 +327,13 @@ class Scheduler(SchedulerInterface):
num_new_tokens + request.num_computed_tokens - num_new_tokens + request.num_computed_tokens -
request.num_tokens, 0) request.num_tokens, 0)
if (self.kv_compression_enabled
and envs.VLLM_KV_COMPRESSION_FREE_TAIL_BLOCKS
and request.num_computed_tokens == request.num_prompt_tokens
and self.kv_cache_manager.truncate_to_num_tokens(
request.request_id, request.num_kv_tokens)):
force_replace_block_ids.add(request.request_id)
while True: while True:
new_blocks = self.kv_cache_manager.allocate_slots( new_blocks = self.kv_cache_manager.allocate_slots(
request, request,
...@@ -295,6 +355,7 @@ class Scheduler(SchedulerInterface): ...@@ -295,6 +355,7 @@ class Scheduler(SchedulerInterface):
self.kv_cache_manager.free(preempted_req) self.kv_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0 preempted_req.num_computed_tokens = 0
preempted_req.num_kv_tokens = 0
if self.log_stats: if self.log_stats:
preempted_req.record_event( preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp) EngineCoreEventType.PREEMPTED, scheduled_timestamp)
...@@ -321,8 +382,12 @@ class Scheduler(SchedulerInterface): ...@@ -321,8 +382,12 @@ class Scheduler(SchedulerInterface):
# Therefore, we might introduce some additional # Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op. # cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids[request.request_id] = req_index structured_output_request_ids[request.request_id] = req_index
req_to_new_block_ids[request.request_id] = ( if request.request_id in force_replace_block_ids:
new_blocks.get_block_ids()) req_to_new_block_ids[request.request_id] = (
self.kv_cache_manager.get_block_ids(request.request_id))
else:
req_to_new_block_ids[request.request_id] = (
new_blocks.get_block_ids())
num_scheduled_tokens[request.request_id] = num_new_tokens num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens token_budget -= num_new_tokens
req_index += 1 req_index += 1
...@@ -532,6 +597,8 @@ class Scheduler(SchedulerInterface): ...@@ -532,6 +597,8 @@ class Scheduler(SchedulerInterface):
token_budget -= num_new_tokens token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens request.num_computed_tokens = num_computed_tokens
if not self.kv_compression_enabled:
request.num_kv_tokens = num_computed_tokens
# Count the number of prefix cached tokens. # Count the number of prefix cached tokens.
if request.num_cached_tokens < 0: if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens request.num_cached_tokens = num_computed_tokens
...@@ -586,6 +653,7 @@ class Scheduler(SchedulerInterface): ...@@ -586,6 +653,7 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens, num_scheduled_tokens,
scheduled_spec_decode_tokens, scheduled_spec_decode_tokens,
req_to_new_block_ids, req_to_new_block_ids,
force_replace_block_ids=force_replace_block_ids,
) )
scheduler_output = SchedulerOutput( scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data, scheduled_new_reqs=new_reqs_data,
...@@ -645,6 +713,16 @@ class Scheduler(SchedulerInterface): ...@@ -645,6 +713,16 @@ class Scheduler(SchedulerInterface):
encoder_budget = self.max_num_encoder_input_tokens encoder_budget = self.max_num_encoder_input_tokens
# Spec decode-related. # Spec decode-related.
scheduled_spec_decode_tokens: dict[str, list[int]] = {} scheduled_spec_decode_tokens: dict[str, list[int]] = {}
# Requests whose block IDs must be replaced (not appended) in workers.
force_replace_block_ids: set[str] = set()
# Track the LoRAs in this step to respect max_loras when scheduling
# waiting requests first.
scheduled_loras: set[int] = set()
if self.lora_config:
scheduled_loras = set(
req.lora_request.lora_int_id for req in self.running
if req.lora_request and req.lora_request.lora_int_id > 0)
assert len(scheduled_loras) <= self.lora_config.max_loras
# For logging. # For logging.
scheduled_timestamp = time.monotonic() scheduled_timestamp = time.monotonic()
...@@ -826,6 +904,8 @@ class Scheduler(SchedulerInterface): ...@@ -826,6 +904,8 @@ class Scheduler(SchedulerInterface):
token_budget -= num_new_tokens token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens request.num_computed_tokens = num_computed_tokens
if not self.kv_compression_enabled:
request.num_kv_tokens = num_computed_tokens
# Count the number of prefix cached tokens. # Count the number of prefix cached tokens.
if request.num_cached_tokens < 0: if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens request.num_cached_tokens = num_computed_tokens
...@@ -894,6 +974,14 @@ class Scheduler(SchedulerInterface): ...@@ -894,6 +974,14 @@ class Scheduler(SchedulerInterface):
num_new_tokens + request.num_computed_tokens - num_new_tokens + request.num_computed_tokens -
request.num_tokens, 0) request.num_tokens, 0)
if (self.kv_compression_enabled
and envs.VLLM_KV_COMPRESSION_FREE_TAIL_BLOCKS
and request.num_computed_tokens
== request.num_prompt_tokens
and self.kv_cache_manager.truncate_to_num_tokens(
request.request_id, request.num_kv_tokens)):
force_replace_block_ids.add(request.request_id)
while True: while True:
new_blocks = self.kv_cache_manager.allocate_slots( new_blocks = self.kv_cache_manager.allocate_slots(
request, request,
...@@ -915,6 +1003,7 @@ class Scheduler(SchedulerInterface): ...@@ -915,6 +1003,7 @@ class Scheduler(SchedulerInterface):
self.kv_cache_manager.free(preempted_req) self.kv_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0 preempted_req.num_computed_tokens = 0
preempted_req.num_kv_tokens = 0
if self.log_stats: if self.log_stats:
preempted_req.record_event( preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp) EngineCoreEventType.PREEMPTED, scheduled_timestamp)
...@@ -941,8 +1030,12 @@ class Scheduler(SchedulerInterface): ...@@ -941,8 +1030,12 @@ class Scheduler(SchedulerInterface):
# Therefore, we might introduce some additional # Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op. # cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids[request.request_id] = req_index structured_output_request_ids[request.request_id] = req_index
req_to_new_block_ids[request.request_id] = ( if request.request_id in force_replace_block_ids:
new_blocks.get_block_ids()) req_to_new_block_ids[request.request_id] = (
self.kv_cache_manager.get_block_ids(request.request_id))
else:
req_to_new_block_ids[request.request_id] = (
new_blocks.get_block_ids())
num_scheduled_tokens[request.request_id] = num_new_tokens num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens token_budget -= num_new_tokens
req_index += 1 req_index += 1
...@@ -1014,6 +1107,7 @@ class Scheduler(SchedulerInterface): ...@@ -1014,6 +1107,7 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens, num_scheduled_tokens,
scheduled_spec_decode_tokens, scheduled_spec_decode_tokens,
req_to_new_block_ids, req_to_new_block_ids,
force_replace_block_ids=force_replace_block_ids,
) )
scheduler_output = SchedulerOutput( scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data, scheduled_new_reqs=new_reqs_data,
...@@ -1076,8 +1170,81 @@ class Scheduler(SchedulerInterface): ...@@ -1076,8 +1170,81 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens = scheduler_output.num_scheduled_tokens num_scheduled_tokens = scheduler_output.num_scheduled_tokens
for req_id, num_scheduled_token in num_scheduled_tokens.items(): for req_id, num_scheduled_token in num_scheduled_tokens.items():
request = self.requests[req_id] request = self.requests[req_id]
start_pos = request.num_computed_tokens
request.num_computed_tokens += num_scheduled_token request.num_computed_tokens += num_scheduled_token
if not self.kv_compression_enabled:
# Keep KV length in sync with logical length when compression
# is disabled (default vLLM behavior).
request.num_kv_tokens += num_scheduled_token
continue
# When KV compression is enabled, only keep a subset of prompt
# tokens. Decode tokens are always kept.
prompt_ratio = envs.VLLM_KV_COMPRESSION_PROMPT_RATIO
prompt_budget = envs.VLLM_KV_COMPRESSION_PROMPT_BUDGET
protected_prefix = envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX
protected_suffix = envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX
keep_last = envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN
end_pos = request.num_computed_tokens
prompt_end = request.num_prompt_tokens
# Chunked prefill: do not change the prompt KV length mid-prefill.
# Otherwise, the next prefill chunk would attend to a truncated
# history (semantic change / quality collapse). Instead, keep the
# full prompt KV until the prompt is fully ingested, then apply a
# one-shot prompt compaction before decode.
if self.scheduler_config.chunked_prefill_enabled:
if start_pos >= prompt_end:
# Decode token(s): keep all.
request.num_kv_tokens += num_scheduled_token
continue
if end_pos < prompt_end:
# Prompt is still being ingested: keep all tokens for now.
request.num_kv_tokens += num_scheduled_token
continue
# This step finishes the prompt (and may include decode tokens
# in rare cases). Apply the final prompt compression length.
kept_prompt_total = compute_prompt_keep_len(
prompt_len=prompt_end,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
kept_decode = max(0, end_pos - max(start_pos, prompt_end))
request.num_kv_tokens = kept_prompt_total + kept_decode
continue
# Decode token(s): keep all.
decode_start = max(start_pos, prompt_end)
kept_decode = max(0, end_pos - decode_start)
kept_prompt_must_keep = count_prompt_must_keep_in_range(
prompt_len=prompt_end,
start_pos=start_pos,
end_pos=end_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
)
kept_prompt_topk = compute_topk_budget_step(
prompt_len=prompt_end,
start_pos=start_pos,
end_pos=end_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
request.num_kv_tokens += (
kept_decode + kept_prompt_must_keep + kept_prompt_topk)
# Clear the finished request IDs. # Clear the finished request IDs.
# NOTE: We shouldn't do self.finished_req_ids.clear() here because # NOTE: We shouldn't do self.finished_req_ids.clear() here because
...@@ -1091,11 +1258,16 @@ class Scheduler(SchedulerInterface): ...@@ -1091,11 +1258,16 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens: dict[str, int], num_scheduled_tokens: dict[str, int],
spec_decode_tokens: dict[str, list[int]], spec_decode_tokens: dict[str, list[int]],
req_to_new_block_ids: dict[str, tuple[list[int], ...]], req_to_new_block_ids: dict[str, tuple[list[int], ...]],
*,
force_replace_block_ids: Optional[set[str]] = None,
) -> CachedRequestData: ) -> CachedRequestData:
req_ids: list[str] = [] req_ids: list[str] = []
new_token_ids: list[list[int]] = [] new_token_ids: list[list[int]] = []
new_block_ids: list[tuple[list[int], ...]] = [] new_block_ids: list[tuple[list[int], ...]] = []
num_computed_tokens: list[int] = [] num_computed_tokens: list[int] = []
num_kv_tokens: list[int] = []
resumed_from_preemption: list[bool] = []
force_replace_block_ids = force_replace_block_ids or set()
for req in itertools.chain(running_reqs, resumed_reqs): for req in itertools.chain(running_reqs, resumed_reqs):
req_id = req.request_id req_id = req.request_id
...@@ -1111,10 +1283,9 @@ class Scheduler(SchedulerInterface): ...@@ -1111,10 +1283,9 @@ class Scheduler(SchedulerInterface):
new_token_ids.append(token_ids) new_token_ids.append(token_ids)
new_block_ids.append(req_to_new_block_ids[req_id]) new_block_ids.append(req_to_new_block_ids[req_id])
num_computed_tokens.append(req.num_computed_tokens) num_computed_tokens.append(req.num_computed_tokens)
# Because resumed_reqs is usually empty, it is more efficient to do num_kv_tokens.append(req.num_kv_tokens)
# in-place appending so that we don't need to allocate a new list. resumed_from_preemption.append(
resumed_from_preemption = [False] * len(running_reqs) (req in resumed_reqs) or (req_id in force_replace_block_ids))
resumed_from_preemption += [True] * len(resumed_reqs)
return CachedRequestData( return CachedRequestData(
req_ids=req_ids, req_ids=req_ids,
...@@ -1122,6 +1293,7 @@ class Scheduler(SchedulerInterface): ...@@ -1122,6 +1293,7 @@ class Scheduler(SchedulerInterface):
new_token_ids=new_token_ids, new_token_ids=new_token_ids,
new_block_ids=new_block_ids, new_block_ids=new_block_ids,
num_computed_tokens=num_computed_tokens, num_computed_tokens=num_computed_tokens,
num_kv_tokens=num_kv_tokens,
) )
def _try_schedule_encoder_inputs( def _try_schedule_encoder_inputs(
...@@ -1567,6 +1739,7 @@ class Scheduler(SchedulerInterface): ...@@ -1567,6 +1739,7 @@ class Scheduler(SchedulerInterface):
# Update the request state for scheduling. # Update the request state for scheduling.
request.num_computed_tokens = num_computed_tokens request.num_computed_tokens = num_computed_tokens
request.num_kv_tokens = num_computed_tokens
# Return that we are ready. # Return that we are ready.
self.finished_recving_kv_req_ids.remove(request.request_id) self.finished_recving_kv_req_ids.remove(request.request_id)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment