Commit 2fde0fa2 authored by laibao's avatar laibao
Browse files

feat: kvpress新增调度层 KV 压缩逻辑

parent eef99f73
......@@ -28,12 +28,16 @@ from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
from vllm.v1.core.sched.utils import check_stop
from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
EngineCoreOutputs)
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.kv_cache_interface import KVCacheConfig, SlidingWindowSpec
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.kv_compression.budget import (compute_prompt_keep_len,
compute_topk_budget_step,
count_prompt_must_keep_in_range)
from vllm.platforms import current_platform
from vllm import envs
logger = init_logger(__name__)
......@@ -156,6 +160,53 @@ class Scheduler(SchedulerInterface):
self.compilation_config = vllm_config.compilation_config
self.full_cuda_graph = self.compilation_config.full_cuda_graph
self.use_mla = vllm_config.model_config.use_mla
# KV compression is only supported on CUDA/ROCm in this fork.
# Other backends (TPU/CPU/XPU/HPU/Neuron/...) do not plumb the
# num_kv_tokens-based slot mapping/metadata and can produce incorrect
# cache mappings if enabled.
self.kv_compression_enabled = (envs.VLLM_ENABLE_KV_COMPRESSION
and current_platform.is_cuda_alike())
if envs.VLLM_ENABLE_KV_COMPRESSION and not self.kv_compression_enabled:
logger.warning_once(
"KV compression is only supported on CUDA/ROCm; ignoring "
"VLLM_ENABLE_KV_COMPRESSION=1 on this platform.")
if self.kv_compression_enabled:
if envs.VLLM_KV_COMPRESSION_POLICY != "topk":
raise ValueError(
"VLLM_KV_COMPRESSION_POLICY must be 'topk'.")
if any(
isinstance(group.kv_cache_spec, SlidingWindowSpec)
for group in kv_cache_config.kv_cache_groups):
raise ValueError(
"KV compression is incompatible with sliding window "
"attention.")
if self.cache_config.enable_prefix_caching:
raise ValueError(
"KV compression is incompatible with prefix caching. "
"Disable prefix caching to enable KV compression.")
if self.full_cuda_graph:
raise ValueError(
"KV compression is currently incompatible with full CUDA "
"graph mode.")
if self.speculative_config is not None:
raise ValueError(
"KV compression is currently incompatible with "
"speculative decoding.")
if envs.VLLM_KV_COMPRESSION_PROMPT_BUDGET < -1:
raise ValueError(
"VLLM_KV_COMPRESSION_PROMPT_BUDGET must be >= -1.")
if not (0.0 <= envs.VLLM_KV_COMPRESSION_PROMPT_RATIO <= 1.0):
raise ValueError(
"VLLM_KV_COMPRESSION_PROMPT_RATIO must be in [0, 1].")
if envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX < 0:
raise ValueError(
"VLLM_KV_COMPRESSION_PROTECTED_PREFIX must be >= 0.")
if envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX < 0:
raise ValueError(
"VLLM_KV_COMPRESSION_PROTECTED_SUFFIX must be >= 0.")
if envs.VLLM_KV_COMPRESSION_SNAPKV_WINDOW < 1:
raise ValueError(
"VLLM_KV_COMPRESSION_SNAPKV_WINDOW must be >= 1.")
# Create the KV cache manager.
self.kv_cache_manager = KVCacheManager(
......@@ -207,6 +258,8 @@ class Scheduler(SchedulerInterface):
encoder_budget = self.max_num_encoder_input_tokens
# Spec decode-related.
scheduled_spec_decode_tokens: dict[str, list[int]] = {}
# Requests whose block IDs must be replaced (not appended) in workers.
force_replace_block_ids: set[str] = set()
# For logging.
scheduled_timestamp = time.monotonic()
......@@ -274,6 +327,13 @@ class Scheduler(SchedulerInterface):
num_new_tokens + request.num_computed_tokens -
request.num_tokens, 0)
if (self.kv_compression_enabled
and envs.VLLM_KV_COMPRESSION_FREE_TAIL_BLOCKS
and request.num_computed_tokens == request.num_prompt_tokens
and self.kv_cache_manager.truncate_to_num_tokens(
request.request_id, request.num_kv_tokens)):
force_replace_block_ids.add(request.request_id)
while True:
new_blocks = self.kv_cache_manager.allocate_slots(
request,
......@@ -295,6 +355,7 @@ class Scheduler(SchedulerInterface):
self.kv_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
preempted_req.num_kv_tokens = 0
if self.log_stats:
preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp)
......@@ -321,6 +382,10 @@ class Scheduler(SchedulerInterface):
# Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids[request.request_id] = req_index
if request.request_id in force_replace_block_ids:
req_to_new_block_ids[request.request_id] = (
self.kv_cache_manager.get_block_ids(request.request_id))
else:
req_to_new_block_ids[request.request_id] = (
new_blocks.get_block_ids())
num_scheduled_tokens[request.request_id] = num_new_tokens
......@@ -532,6 +597,8 @@ class Scheduler(SchedulerInterface):
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
if not self.kv_compression_enabled:
request.num_kv_tokens = num_computed_tokens
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
......@@ -586,6 +653,7 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens,
scheduled_spec_decode_tokens,
req_to_new_block_ids,
force_replace_block_ids=force_replace_block_ids,
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
......@@ -645,6 +713,16 @@ class Scheduler(SchedulerInterface):
encoder_budget = self.max_num_encoder_input_tokens
# Spec decode-related.
scheduled_spec_decode_tokens: dict[str, list[int]] = {}
# Requests whose block IDs must be replaced (not appended) in workers.
force_replace_block_ids: set[str] = set()
# Track the LoRAs in this step to respect max_loras when scheduling
# waiting requests first.
scheduled_loras: set[int] = set()
if self.lora_config:
scheduled_loras = set(
req.lora_request.lora_int_id for req in self.running
if req.lora_request and req.lora_request.lora_int_id > 0)
assert len(scheduled_loras) <= self.lora_config.max_loras
# For logging.
scheduled_timestamp = time.monotonic()
......@@ -826,6 +904,8 @@ class Scheduler(SchedulerInterface):
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
if not self.kv_compression_enabled:
request.num_kv_tokens = num_computed_tokens
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
......@@ -894,6 +974,14 @@ class Scheduler(SchedulerInterface):
num_new_tokens + request.num_computed_tokens -
request.num_tokens, 0)
if (self.kv_compression_enabled
and envs.VLLM_KV_COMPRESSION_FREE_TAIL_BLOCKS
and request.num_computed_tokens
== request.num_prompt_tokens
and self.kv_cache_manager.truncate_to_num_tokens(
request.request_id, request.num_kv_tokens)):
force_replace_block_ids.add(request.request_id)
while True:
new_blocks = self.kv_cache_manager.allocate_slots(
request,
......@@ -915,6 +1003,7 @@ class Scheduler(SchedulerInterface):
self.kv_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
preempted_req.num_kv_tokens = 0
if self.log_stats:
preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp)
......@@ -941,6 +1030,10 @@ class Scheduler(SchedulerInterface):
# Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids[request.request_id] = req_index
if request.request_id in force_replace_block_ids:
req_to_new_block_ids[request.request_id] = (
self.kv_cache_manager.get_block_ids(request.request_id))
else:
req_to_new_block_ids[request.request_id] = (
new_blocks.get_block_ids())
num_scheduled_tokens[request.request_id] = num_new_tokens
......@@ -1014,6 +1107,7 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens,
scheduled_spec_decode_tokens,
req_to_new_block_ids,
force_replace_block_ids=force_replace_block_ids,
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
......@@ -1076,7 +1170,80 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
for req_id, num_scheduled_token in num_scheduled_tokens.items():
request = self.requests[req_id]
start_pos = request.num_computed_tokens
request.num_computed_tokens += num_scheduled_token
if not self.kv_compression_enabled:
# Keep KV length in sync with logical length when compression
# is disabled (default vLLM behavior).
request.num_kv_tokens += num_scheduled_token
continue
# When KV compression is enabled, only keep a subset of prompt
# tokens. Decode tokens are always kept.
prompt_ratio = envs.VLLM_KV_COMPRESSION_PROMPT_RATIO
prompt_budget = envs.VLLM_KV_COMPRESSION_PROMPT_BUDGET
protected_prefix = envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX
protected_suffix = envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX
keep_last = envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN
end_pos = request.num_computed_tokens
prompt_end = request.num_prompt_tokens
# Chunked prefill: do not change the prompt KV length mid-prefill.
# Otherwise, the next prefill chunk would attend to a truncated
# history (semantic change / quality collapse). Instead, keep the
# full prompt KV until the prompt is fully ingested, then apply a
# one-shot prompt compaction before decode.
if self.scheduler_config.chunked_prefill_enabled:
if start_pos >= prompt_end:
# Decode token(s): keep all.
request.num_kv_tokens += num_scheduled_token
continue
if end_pos < prompt_end:
# Prompt is still being ingested: keep all tokens for now.
request.num_kv_tokens += num_scheduled_token
continue
# This step finishes the prompt (and may include decode tokens
# in rare cases). Apply the final prompt compression length.
kept_prompt_total = compute_prompt_keep_len(
prompt_len=prompt_end,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
kept_decode = max(0, end_pos - max(start_pos, prompt_end))
request.num_kv_tokens = kept_prompt_total + kept_decode
continue
# Decode token(s): keep all.
decode_start = max(start_pos, prompt_end)
kept_decode = max(0, end_pos - decode_start)
kept_prompt_must_keep = count_prompt_must_keep_in_range(
prompt_len=prompt_end,
start_pos=start_pos,
end_pos=end_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
)
kept_prompt_topk = compute_topk_budget_step(
prompt_len=prompt_end,
start_pos=start_pos,
end_pos=end_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
request.num_kv_tokens += (
kept_decode + kept_prompt_must_keep + kept_prompt_topk)
# Clear the finished request IDs.
......@@ -1091,11 +1258,16 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens: dict[str, int],
spec_decode_tokens: dict[str, list[int]],
req_to_new_block_ids: dict[str, tuple[list[int], ...]],
*,
force_replace_block_ids: Optional[set[str]] = None,
) -> CachedRequestData:
req_ids: list[str] = []
new_token_ids: list[list[int]] = []
new_block_ids: list[tuple[list[int], ...]] = []
num_computed_tokens: list[int] = []
num_kv_tokens: list[int] = []
resumed_from_preemption: list[bool] = []
force_replace_block_ids = force_replace_block_ids or set()
for req in itertools.chain(running_reqs, resumed_reqs):
req_id = req.request_id
......@@ -1111,10 +1283,9 @@ class Scheduler(SchedulerInterface):
new_token_ids.append(token_ids)
new_block_ids.append(req_to_new_block_ids[req_id])
num_computed_tokens.append(req.num_computed_tokens)
# Because resumed_reqs is usually empty, it is more efficient to do
# in-place appending so that we don't need to allocate a new list.
resumed_from_preemption = [False] * len(running_reqs)
resumed_from_preemption += [True] * len(resumed_reqs)
num_kv_tokens.append(req.num_kv_tokens)
resumed_from_preemption.append(
(req in resumed_reqs) or (req_id in force_replace_block_ids))
return CachedRequestData(
req_ids=req_ids,
......@@ -1122,6 +1293,7 @@ class Scheduler(SchedulerInterface):
new_token_ids=new_token_ids,
new_block_ids=new_block_ids,
num_computed_tokens=num_computed_tokens,
num_kv_tokens=num_kv_tokens,
)
def _try_schedule_encoder_inputs(
......@@ -1567,6 +1739,7 @@ class Scheduler(SchedulerInterface):
# Update the request state for scheduling.
request.num_computed_tokens = num_computed_tokens
request.num_kv_tokens = num_computed_tokens
# Return that we are ready.
self.finished_recving_kv_req_ids.remove(request.request_id)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment