Commit faf55520 authored by laibao's avatar laibao
Browse files

feat: kvpress runner 侧生成 Top-K 压缩元数据

parent 2df94aa9
......@@ -55,6 +55,8 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
KVCacheConfig, KVCacheSpec, MambaSpec,
SlidingWindowSpec)
from vllm.v1.kv_compression.budget import (compute_prompt_topk_keep_total,
compute_topk_budget_step)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
......@@ -338,6 +340,77 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
device="cpu",
pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy()
# KV compression metadata buffers (used by the "topk" policy).
# Per-token: whether this scheduled token must be kept in KV cache.
self.kv_compression_must_keep_cpu = torch.zeros(
self.max_num_tokens,
dtype=torch.bool,
device="cpu",
pin_memory=self.pin_memory,
)
self.kv_compression_must_keep_np = self.kv_compression_must_keep_cpu.numpy()
self.kv_compression_must_keep = torch.zeros(
self.max_num_tokens,
dtype=torch.bool,
device=self.device,
)
# Per-request: how many additional prompt tokens to keep among
# non-protected candidates (budget from env; selection uses scores).
self.kv_compression_topk_budget_cpu = torch.zeros(
self.max_num_reqs,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory,
)
self.kv_compression_topk_budget_np = self.kv_compression_topk_budget_cpu.numpy()
self.kv_compression_topk_budget = torch.zeros(
self.max_num_reqs,
dtype=torch.int32,
device=self.device,
)
# Chunked-prefill prompt-end KV compression metadata (scheme 3).
# Per-request: whether this step finishes the prompt and should compute
# global prompt indices (score/topk) for a one-shot compaction.
self.kv_compression_prompt_end_cpu = torch.zeros(
self.max_num_reqs,
dtype=torch.bool,
device="cpu",
pin_memory=self.pin_memory,
)
self.kv_compression_prompt_end_np = self.kv_compression_prompt_end_cpu.numpy()
self.kv_compression_prompt_end = torch.zeros(
self.max_num_reqs,
dtype=torch.bool,
device=self.device,
)
# Per-request: prompt length (tokens) and Top-K keep count among prompt
# candidates (excluding protected prefix/suffix).
self.kv_compression_prompt_lens_cpu = torch.zeros(
self.max_num_reqs,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory,
)
self.kv_compression_prompt_lens_np = self.kv_compression_prompt_lens_cpu.numpy()
self.kv_compression_prompt_lens = torch.zeros(
self.max_num_reqs,
dtype=torch.int32,
device=self.device,
)
self.kv_compression_prompt_topk_keep_cpu = torch.zeros(
self.max_num_reqs,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory,
)
self.kv_compression_prompt_topk_keep_np = self.kv_compression_prompt_topk_keep_cpu.numpy()
self.kv_compression_prompt_topk_keep = torch.zeros(
self.max_num_reqs,
dtype=torch.int32,
device=self.device,
)
self.kv_compression_prompt_topk_keep_max: Optional[int] = None
# Layer pairings for cross-layer KV sharing.
# If an Attention layer `layer_name` is in the keys of this dict, it
......@@ -692,6 +765,105 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
else:
kv_positions_np = None
if use_kv_compression:
prompt_ratio = envs.VLLM_KV_COMPRESSION_PROMPT_RATIO
prompt_budget = envs.VLLM_KV_COMPRESSION_PROMPT_BUDGET
protected_prefix = envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX
protected_suffix = envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX
keep_last = envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN
if self.scheduler_config.chunked_prefill_enabled:
# Scheme 3: with chunked prefill, defer compaction until after
# the full prompt is ingested. Otherwise, the next prefill chunk
# would attend to a truncated history and quality can collapse.
prompt_end_np = self.kv_compression_prompt_end_np[:num_reqs]
prompt_end_np.fill(False)
prompt_lens_np = self.kv_compression_prompt_lens_np[:num_reqs]
prompt_lens_np.fill(0)
topk_keep_np = self.kv_compression_prompt_topk_keep_np[:num_reqs]
topk_keep_np.fill(0)
for req_idx in range(num_reqs):
qlen = int(num_scheduled_tokens[req_idx])
if qlen <= 0:
continue
base_pos = int(self.input_batch.num_computed_tokens_cpu[req_idx])
prompt_len = int(self.input_batch.num_prompt_tokens[req_idx])
end_pos = base_pos + qlen
ends_prompt = (base_pos < prompt_len) and (end_pos >= prompt_len)
if not ends_prompt:
continue
prompt_end_np[req_idx] = True
prompt_lens_np[req_idx] = prompt_len
topk_keep_np[req_idx] = compute_prompt_topk_keep_total(
prompt_len=prompt_len,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
self.kv_compression_prompt_topk_keep_max = int(
topk_keep_np[:num_reqs].max()) if num_reqs > 0 else 0
self.kv_compression_needs_compaction = False
else:
must_keep_np = self.kv_compression_must_keep_np[
:total_num_scheduled_tokens]
must_keep_np.fill(False)
topk_budget_np = self.kv_compression_topk_budget_np[:num_reqs]
topk_budget_np.fill(0)
for req_idx in range(num_reqs):
qlen = int(num_scheduled_tokens[req_idx])
if qlen <= 0:
continue
start = 0 if req_idx == 0 else int(
cu_num_tokens[req_idx - 1])
end = int(cu_num_tokens[req_idx])
assert end - start == qlen
base_pos = int(self.input_batch.num_computed_tokens_cpu[req_idx])
prompt_len = int(self.input_batch.num_prompt_tokens[req_idx])
end_pos = base_pos + qlen
pos = base_pos + np.arange(qlen, dtype=np.int64)
prompt_mask = pos < prompt_len
# Decode tokens are always kept.
must_keep = ~prompt_mask
if np.any(prompt_mask):
suffix_start = max(prompt_len - protected_suffix, 0)
must_keep |= prompt_mask & (pos < protected_prefix)
must_keep |= prompt_mask & (pos >= suffix_start)
if keep_last:
last = prompt_len - 1
if base_pos <= last < end_pos:
must_keep[last - base_pos] = True
topk_budget_np[req_idx] = compute_topk_budget_step(
prompt_len=prompt_len,
start_pos=base_pos,
end_pos=end_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
must_keep_np[start:end] = must_keep
# Decode-only fast path: if all scheduled tokens are
# unconditionally kept and there is no Top-K budget, KV
# compaction is a no-op and we can skip score/topk/dst entirely
# in the attention backend.
self.kv_compression_needs_compaction = (
(not must_keep_np.all()) or (topk_budget_np > 0).any())
else:
self.kv_compression_needs_compaction = False
# Calculate M-RoPE positions.
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope:
......@@ -769,6 +941,29 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
non_blocking=True)
if use_kv_compression:
if self.scheduler_config.chunked_prefill_enabled:
self.kv_compression_prompt_end[:num_reqs].copy_(
self.kv_compression_prompt_end_cpu[:num_reqs],
non_blocking=True,
)
self.kv_compression_prompt_lens[:num_reqs].copy_(
self.kv_compression_prompt_lens_cpu[:num_reqs],
non_blocking=True,
)
self.kv_compression_prompt_topk_keep[:num_reqs].copy_(
self.kv_compression_prompt_topk_keep_cpu[:num_reqs],
non_blocking=True,
)
elif self.kv_compression_needs_compaction:
self.kv_compression_must_keep[:total_num_scheduled_tokens].copy_(
self.kv_compression_must_keep_cpu[:total_num_scheduled_tokens],
non_blocking=True,
)
self.kv_compression_topk_budget[:num_reqs].copy_(
self.kv_compression_topk_budget_cpu[:num_reqs],
non_blocking=True,
)
# Fill unused with -1. Needed for reshape_and_cache
self.seq_lens[num_reqs:].fill_(0)
......@@ -3096,6 +3291,120 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
arange,
out=positions_np)
# KV positions (where the KV for each scheduled token is temporarily
# written). When KV compression is enabled, KV positions are decoupled
# from logical positions.
use_kv_compression = envs.VLLM_ENABLE_KV_COMPRESSION
if use_kv_compression:
kv_positions_np = self.kv_positions_np[:total_num_scheduled_tokens]
np.add(self.input_batch.num_kv_tokens_cpu[req_indices],
arange,
out=kv_positions_np)
else:
kv_positions_np = None
if use_kv_compression:
prompt_ratio = envs.VLLM_KV_COMPRESSION_PROMPT_RATIO
prompt_budget = envs.VLLM_KV_COMPRESSION_PROMPT_BUDGET
protected_prefix = envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX
protected_suffix = envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX
keep_last = envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN
if self.scheduler_config.chunked_prefill_enabled:
# Scheme 3: with chunked prefill, defer compaction until after
# the full prompt is ingested. Otherwise, the next prefill chunk
# would attend to a truncated history and quality can collapse.
prompt_end_np = self.kv_compression_prompt_end_np[:num_reqs]
prompt_end_np.fill(False)
prompt_lens_np = self.kv_compression_prompt_lens_np[:num_reqs]
prompt_lens_np.fill(0)
topk_keep_np = self.kv_compression_prompt_topk_keep_np[:num_reqs]
topk_keep_np.fill(0)
for req_idx in range(num_reqs):
qlen = int(num_scheduled_tokens[req_idx])
if qlen <= 0:
continue
base_pos = int(
self.input_batch.num_computed_tokens_cpu[req_idx])
prompt_len = int(self.input_batch.num_prompt_tokens[req_idx])
end_pos = base_pos + qlen
ends_prompt = (base_pos < prompt_len) and (end_pos >=
prompt_len)
if not ends_prompt:
continue
prompt_end_np[req_idx] = True
prompt_lens_np[req_idx] = prompt_len
topk_keep_np[req_idx] = compute_prompt_topk_keep_total(
prompt_len=prompt_len,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
self.kv_compression_prompt_topk_keep_max = int(
topk_keep_np[:num_reqs].max()) if num_reqs > 0 else 0
self.kv_compression_needs_compaction = False
else:
must_keep_np = self.kv_compression_must_keep_np[
:total_num_scheduled_tokens]
must_keep_np.fill(False)
topk_budget_np = self.kv_compression_topk_budget_np[:num_reqs]
topk_budget_np.fill(0)
for req_idx in range(num_reqs):
qlen = int(num_scheduled_tokens[req_idx])
if qlen <= 0:
continue
start = 0 if req_idx == 0 else int(
cu_num_tokens[req_idx - 1])
end = int(cu_num_tokens[req_idx])
assert end - start == qlen
base_pos = int(
self.input_batch.num_computed_tokens_cpu[req_idx])
prompt_len = int(self.input_batch.num_prompt_tokens[req_idx])
end_pos = base_pos + qlen
pos = base_pos + np.arange(qlen, dtype=np.int64)
prompt_mask = pos < prompt_len
# Decode tokens are always kept.
must_keep = ~prompt_mask
if np.any(prompt_mask):
suffix_start = max(prompt_len - protected_suffix, 0)
must_keep |= prompt_mask & (pos < protected_prefix)
must_keep |= prompt_mask & (pos >= suffix_start)
if keep_last:
last = prompt_len - 1
if base_pos <= last < end_pos:
must_keep[last - base_pos] = True
topk_budget_np[req_idx] = compute_topk_budget_step(
prompt_len=prompt_len,
start_pos=base_pos,
end_pos=end_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
must_keep_np[start:end] = must_keep
# Decode-only fast path: if all scheduled tokens are
# unconditionally kept and there is no Top-K budget, KV
# compaction is a no-op and we can skip score/topk/dst entirely
# in the attention backend.
self.kv_compression_needs_compaction = (
(not must_keep_np.all()) or (topk_budget_np > 0).any())
else:
self.kv_compression_needs_compaction = False
# Calculate M-RoPE positions.
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope:
......@@ -3122,6 +3431,8 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
block_size = kv_cache_group_spec.kv_cache_spec.block_size
block_table: BlockTable = self.input_batch.block_table[
kv_cache_group_id]
slot_positions_np = (kv_positions_np
if use_kv_compression else positions_np)
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
......@@ -3130,11 +3441,11 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
# block_size.
block_table_indices = (
req_indices * block_table.max_num_blocks_per_req +
positions_np // block_size)
slot_positions_np // block_size)
block_table_cpu = block_table.get_cpu_tensor()
block_numbers = block_table_cpu.flatten(
)[block_table_indices].numpy()
block_offsets = positions_np % block_size
block_offsets = slot_positions_np % block_size
np.add(
block_numbers * block_size,
block_offsets,
......@@ -3144,7 +3455,12 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
self.query_start_loc_np[0] = 0
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
self.seq_lens_np[:num_reqs] = (
if use_kv_compression:
self.seq_lens_np[:num_reqs] = (
self.input_batch.num_kv_tokens_cpu[:num_reqs] +
num_scheduled_tokens)
else:
self.seq_lens_np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
num_scheduled_tokens)
......@@ -3166,6 +3482,29 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
non_blocking=True)
if use_kv_compression:
if self.scheduler_config.chunked_prefill_enabled:
self.kv_compression_prompt_end[:num_reqs].copy_(
self.kv_compression_prompt_end_cpu[:num_reqs],
non_blocking=True,
)
self.kv_compression_prompt_lens[:num_reqs].copy_(
self.kv_compression_prompt_lens_cpu[:num_reqs],
non_blocking=True,
)
self.kv_compression_prompt_topk_keep[:num_reqs].copy_(
self.kv_compression_prompt_topk_keep_cpu[:num_reqs],
non_blocking=True,
)
elif self.kv_compression_needs_compaction:
self.kv_compression_must_keep[:total_num_scheduled_tokens].copy_(
self.kv_compression_must_keep_cpu[:total_num_scheduled_tokens],
non_blocking=True,
)
self.kv_compression_topk_budget[:num_reqs].copy_(
self.kv_compression_topk_budget_cpu[:num_reqs],
non_blocking=True,
)
# Fill unused with -1. Needed for reshape_and_cache
self.seq_lens[num_reqs:].fill_(0)
......@@ -3742,4 +4081,4 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
if envs.VLLM_USE_ZERO_MTP:
GPUModelRunner=GPUModelRunnerMTP
else:
GPUModelRunner=GPUModelRunnerBase
\ No newline at end of file
GPUModelRunner=GPUModelRunnerBase
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment