Commit 3bc7eb74 authored by laibao's avatar laibao
Browse files

refactor: 抽取调度器 KV 长度账本逻辑;修复 num_kv_tokens 初始化

parent 4634cbcf
...@@ -34,9 +34,10 @@ from vllm.v1.outputs import ModelRunnerOutput ...@@ -34,9 +34,10 @@ from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.kv_compression.budget import (compute_prompt_keep_len, from vllm.v1.kv_compression.scheduler_accounting import (
compute_topk_budget_step, maybe_init_num_kv_tokens_on_running_transition,
count_prompt_must_keep_in_range) update_num_kv_tokens_after_schedule,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm import envs from vllm import envs
...@@ -597,8 +598,11 @@ class Scheduler(SchedulerInterface): ...@@ -597,8 +598,11 @@ class Scheduler(SchedulerInterface):
token_budget -= num_new_tokens token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens request.num_computed_tokens = num_computed_tokens
if not self.kv_compression_enabled: maybe_init_num_kv_tokens_on_running_transition(
request.num_kv_tokens = num_computed_tokens request=request,
num_computed_tokens=num_computed_tokens,
kv_compression_enabled=self.kv_compression_enabled,
)
# Count the number of prefix cached tokens. # Count the number of prefix cached tokens.
if request.num_cached_tokens < 0: if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens request.num_cached_tokens = num_computed_tokens
...@@ -904,8 +908,11 @@ class Scheduler(SchedulerInterface): ...@@ -904,8 +908,11 @@ class Scheduler(SchedulerInterface):
token_budget -= num_new_tokens token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens request.num_computed_tokens = num_computed_tokens
if not self.kv_compression_enabled: maybe_init_num_kv_tokens_on_running_transition(
request.num_kv_tokens = num_computed_tokens request=request,
num_computed_tokens=num_computed_tokens,
kv_compression_enabled=self.kv_compression_enabled,
)
# Count the number of prefix cached tokens. # Count the number of prefix cached tokens.
if request.num_cached_tokens < 0: if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens request.num_cached_tokens = num_computed_tokens
...@@ -1172,78 +1179,13 @@ class Scheduler(SchedulerInterface): ...@@ -1172,78 +1179,13 @@ class Scheduler(SchedulerInterface):
request = self.requests[req_id] request = self.requests[req_id]
start_pos = request.num_computed_tokens start_pos = request.num_computed_tokens
request.num_computed_tokens += num_scheduled_token request.num_computed_tokens += num_scheduled_token
if not self.kv_compression_enabled: update_num_kv_tokens_after_schedule(
# Keep KV length in sync with logical length when compression request=request,
# is disabled (default vLLM behavior).
request.num_kv_tokens += num_scheduled_token
continue
# When KV compression is enabled, only keep a subset of prompt
# tokens. Decode tokens are always kept.
prompt_ratio = envs.VLLM_KV_COMPRESSION_PROMPT_RATIO
prompt_budget = envs.VLLM_KV_COMPRESSION_PROMPT_BUDGET
protected_prefix = envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX
protected_suffix = envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX
keep_last = envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN
end_pos = request.num_computed_tokens
prompt_end = request.num_prompt_tokens
# Chunked prefill: do not change the prompt KV length mid-prefill.
# Otherwise, the next prefill chunk would attend to a truncated
# history (semantic change / quality collapse). Instead, keep the
# full prompt KV until the prompt is fully ingested, then apply a
# one-shot prompt compaction before decode.
if self.scheduler_config.chunked_prefill_enabled:
if start_pos >= prompt_end:
# Decode token(s): keep all.
request.num_kv_tokens += num_scheduled_token
continue
if end_pos < prompt_end:
# Prompt is still being ingested: keep all tokens for now.
request.num_kv_tokens += num_scheduled_token
continue
# This step finishes the prompt (and may include decode tokens
# in rare cases). Apply the final prompt compression length.
kept_prompt_total = compute_prompt_keep_len(
prompt_len=prompt_end,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
kept_decode = max(0, end_pos - max(start_pos, prompt_end))
request.num_kv_tokens = kept_prompt_total + kept_decode
continue
# Decode token(s): keep all.
decode_start = max(start_pos, prompt_end)
kept_decode = max(0, end_pos - decode_start)
kept_prompt_must_keep = count_prompt_must_keep_in_range(
prompt_len=prompt_end,
start_pos=start_pos, start_pos=start_pos,
end_pos=end_pos, num_scheduled_token=num_scheduled_token,
protected_prefix=protected_prefix, chunked_prefill_enabled=self.scheduler_config.chunked_prefill_enabled,
protected_suffix=protected_suffix, kv_compression_enabled=self.kv_compression_enabled,
keep_last_token=keep_last,
) )
kept_prompt_topk = compute_topk_budget_step(
prompt_len=prompt_end,
start_pos=start_pos,
end_pos=end_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
request.num_kv_tokens += (
kept_decode + kept_prompt_must_keep + kept_prompt_topk)
# Clear the finished request IDs. # Clear the finished request IDs.
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from typing import Any
import vllm.envs as envs
from vllm.v1.kv_compression.budget import (compute_prompt_keep_len,
compute_topk_budget_step,
count_prompt_must_keep_in_range)
def maybe_init_num_kv_tokens_on_running_transition(
    *,
    request: Any,
    num_computed_tokens: int,
    kv_compression_enabled: bool,
) -> None:
    """Seed ``request.num_kv_tokens`` when a request transitions to RUNNING.

    Compression disabled:
        ``request.num_kv_tokens`` is unconditionally overwritten with
        ``num_computed_tokens``.

    Compression enabled:
        The value is written only if the request has no KV length recorded
        yet (attribute missing or equal to 0) *and* ``num_computed_tokens``
        is positive. Any existing non-zero KV length is left untouched.
        Rationale carried over from the original author's note: a request
        can enter RUNNING with tokens already computed (connector /
        cache-hit paths), and an unseeded KV length would then yield wrong
        KV write offsets.

    Args:
        request: Object whose ``num_kv_tokens`` attribute is updated in
            place.
        num_computed_tokens: Number of tokens already computed for the
            request at the transition.
        kv_compression_enabled: Whether KV compression is active.
    """
    if kv_compression_enabled:
        # A missing attribute is treated the same as a zero KV length.
        kv_len_unset = getattr(request, "num_kv_tokens", 0) == 0
        if not (kv_len_unset and num_computed_tokens > 0):
            return

    request.num_kv_tokens = num_computed_tokens
def update_num_kv_tokens_after_schedule(
    *,
    request: Any,
    start_pos: int,
    num_scheduled_token: int,
    chunked_prefill_enabled: bool,
    kv_compression_enabled: bool,
) -> None:
    """Advance ``request.num_kv_tokens`` after a scheduling step.

    The step covers logical positions ``[start_pos, start_pos +
    num_scheduled_token)``. The function mutates ``request`` in place and
    returns nothing.

    * Non-positive ``num_scheduled_token``: no-op.
    * Compression disabled: the KV length grows by ``num_scheduled_token``.
    * Compression enabled, chunked prefill enabled: every scheduled token is
      kept unless this step is the one that crosses the end of the prompt;
      in that step the KV length is *reset* to the final kept-prompt length
      plus any decode tokens in the step.
    * Compression enabled, chunked prefill disabled: the KV length grows by
      the decode tokens in the step plus the two prompt contributions
      returned by ``count_prompt_must_keep_in_range`` and
      ``compute_topk_budget_step`` for the step's range.

    Args:
        request: Object carrying ``num_kv_tokens`` (updated in place) and
            ``num_prompt_tokens`` (treated as 0 when absent).
        start_pos: Logical position at which this step begins.
        num_scheduled_token: Number of tokens scheduled in this step.
        chunked_prefill_enabled: Whether chunked prefill is active.
        kv_compression_enabled: Whether KV compression is active.
    """
    if num_scheduled_token <= 0:
        return

    if not kv_compression_enabled:
        request.num_kv_tokens += num_scheduled_token
        return

    # Compression settings are read from the environment-backed config on
    # every call, before any branch on the prefill mode.
    ratio = envs.VLLM_KV_COMPRESSION_PROMPT_RATIO
    budget = envs.VLLM_KV_COMPRESSION_PROMPT_BUDGET
    prefix = envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX
    suffix = envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX
    keep_last = envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN

    end_pos = start_pos + int(num_scheduled_token)
    prompt_end = int(getattr(request, "num_prompt_tokens", 0))

    # Positions at or beyond the prompt end are decode tokens; all are kept.
    decode_kept = max(0, end_pos - max(start_pos, prompt_end))

    # Keyword groups shared by the budget helpers.
    protected = {
        "prompt_len": prompt_end,
        "protected_prefix": prefix,
        "protected_suffix": suffix,
        "keep_last_token": keep_last,
    }
    budgeted = {"prompt_ratio": ratio, "prompt_budget": budget}
    window = {"start_pos": start_pos, "end_pos": end_pos}

    if chunked_prefill_enabled:
        # With chunked prefill the prompt KV length must not shrink while
        # the prompt is still being ingested: a later chunk would otherwise
        # attend to a truncated history. The full prompt KV is therefore
        # retained until the step that completes the prompt, where a
        # one-shot compaction is applied before decode.
        finishes_prompt = start_pos < prompt_end <= end_pos
        if not finishes_prompt:
            # Pure decode step, or a prefill chunk that stops short of the
            # prompt end: keep every scheduled token.
            request.num_kv_tokens += num_scheduled_token
            return

        prompt_kept = compute_prompt_keep_len(**protected, **budgeted)
        request.num_kv_tokens = prompt_kept + decode_kept
        return

    # Non-chunked path: account incrementally for whatever part of the
    # prompt falls inside this step, in addition to the decode tokens.
    must_keep = count_prompt_must_keep_in_range(**window, **protected)
    topk_kept = compute_topk_budget_step(**window, **protected, **budgeted)
    request.num_kv_tokens += decode_kept + must_keep + topk_kept
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment