Commit b6a27380 authored by laibao
Browse files

feat: kvpress flash_attn 透传 KV 压缩元数据

parent 54b1bf44
......@@ -161,6 +161,16 @@ class FlashAttentionMetadata:
cu_prefix_query_lens: Optional[torch.Tensor]
prefix_kv_lens: Optional[torch.Tensor]
suffix_kv_lens: Optional[torch.Tensor]
# KV compression metadata for token-shared selection.
kv_compression_must_keep: Optional[torch.Tensor] = None
kv_compression_topk_budget: Optional[torch.Tensor] = None
# CPU-known max Top-K budget for this step (avoids device->host sync).
kv_compression_topk_budget_max: Optional[int] = None
# Chunked prefill: prompt-end one-shot scoring/Top-K (scheme 3).
kv_compression_prompt_end: Optional[torch.Tensor] = None # [B] bool
kv_compression_prompt_lens: Optional[torch.Tensor] = None # [B] int32
kv_compression_prompt_topk_keep: Optional[torch.Tensor] = None # [B] int32
kv_compression_prompt_topk_keep_max: Optional[int] = None
# Optional aot scheduling
scheduler_metadata: Optional[torch.Tensor] = None
......@@ -268,6 +278,37 @@ class FlashAttentionMetadataBuilder(
slot_mapping = block_table.slot_mapping[:num_actual_tokens]
kv_compression_must_keep = None
kv_compression_topk_budget = None
kv_compression_topk_budget_max: Optional[int] = None
kv_compression_prompt_end = None
kv_compression_prompt_lens = None
kv_compression_prompt_topk_keep = None
kv_compression_prompt_topk_keep_max: Optional[int] = None
if (envs.VLLM_ENABLE_KV_COMPRESSION
and self.runner.kv_compression_needs_compaction):
kv_compression_must_keep = self.runner.kv_compression_must_keep[:
num_actual_tokens]
kv_compression_topk_budget = self.runner.kv_compression_topk_budget[:
num_reqs]
# Avoid device->host sync by reading from the CPU staging buffer.
if num_reqs > 0:
kv_compression_topk_budget_max = int(
self.runner.kv_compression_topk_budget_np[:num_reqs].max())
else:
kv_compression_topk_budget_max = 0
elif (envs.VLLM_ENABLE_KV_COMPRESSION
and self.runner.scheduler_config.chunked_prefill_enabled):
# Scheme 3: compute global prompt indices only on the last prefill
# chunk (per request), and perform the actual cache compaction
# before the first decode step.
if num_reqs > 0 and self.runner.kv_compression_prompt_end_np[:num_reqs].any():
kv_compression_prompt_end = self.runner.kv_compression_prompt_end[:num_reqs]
kv_compression_prompt_lens = self.runner.kv_compression_prompt_lens[:num_reqs]
kv_compression_prompt_topk_keep = self.runner.kv_compression_prompt_topk_keep[:num_reqs]
kv_compression_prompt_topk_keep_max = int(
self.runner.kv_compression_prompt_topk_keep_max or 0)
if self.aot_sliding_window is None:
self.aot_sliding_window = (-1, -1)
# For the AOT scheduler we need the sliding window value to be
......@@ -426,6 +467,13 @@ class FlashAttentionMetadataBuilder(
cu_prefix_query_lens=cu_prefix_query_lens,
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
kv_compression_must_keep=kv_compression_must_keep,
kv_compression_topk_budget=kv_compression_topk_budget,
kv_compression_topk_budget_max=kv_compression_topk_budget_max,
kv_compression_prompt_end=kv_compression_prompt_end,
kv_compression_prompt_lens=kv_compression_prompt_lens,
kv_compression_prompt_topk_keep=kv_compression_prompt_topk_keep,
kv_compression_prompt_topk_keep_max=kv_compression_prompt_topk_keep_max,
local_attn_metadata=local_attn_metadata,
prefix_scheduler_metadata=prefix_scheduler_metadata,
max_num_splits=max_num_splits,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment