Commit 54b1bf44 authored by laibao
Browse files

feat: kvpress runner 支持 chunked prefill prompt-end 一次性 KV compaction

parent faf55520
...@@ -875,7 +875,6 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -875,7 +875,6 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# where M is the max_model_len. # where M is the max_model_len.
token_indices = (positions_np + token_indices = (positions_np +
req_indices * self.input_batch.token_ids_cpu.shape[1]) req_indices * self.input_batch.token_ids_cpu.shape[1])
# NOTE(woosuk): We use torch.index_select instead of np.take here # NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large # because torch.index_select is much faster than np.take for large
# tensors. # tensors.
...@@ -1571,6 +1570,154 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1571,6 +1570,154 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
finished_recving=finished_recving, finished_recving=finished_recving,
) )
def _stash_kv_compression_prompt_payload(self) -> None:
    """Copy the prompt-end compaction payload onto per-request state.

    Reads the ``_kv_compression_prompt_payload`` attribute of the current
    forward context (if any) and, for every row it describes, records the
    sorted keep indices, the keep length and the prompt length on the
    matching entry of ``self.requests``.

    Does nothing unless both ``VLLM_ENABLE_KV_COMPRESSION`` and chunked
    prefill are enabled, or when the payload (or any of its four fields)
    is missing.
    """
    if not (envs.VLLM_ENABLE_KV_COMPRESSION
            and self.scheduler_config.chunked_prefill_enabled):
        return

    payload = getattr(get_forward_context(),
                      "_kv_compression_prompt_payload", None)
    if payload is None:
        return

    # All four fields are required; a partially filled payload is ignored.
    fields = tuple(
        payload.get(key)
        for key in ("req_indices", "idx_sorted", "keep_len", "prompt_lens"))
    if any(field is None for field in fields):
        return
    batch_rows, sorted_indices, keep_lens, prompt_lens = fields

    # Pull the small per-row scalars to host once, as plain Python ints.
    # NOTE(review): presumably each field has one entry per payload row --
    # confirm against the code that produces the payload.
    batch_rows_host = batch_rows.to(device="cpu",
                                    dtype=torch.int64).tolist()
    keep_host = keep_lens.to(device="cpu", dtype=torch.int64).tolist()
    prompt_host = prompt_lens.to(device="cpu", dtype=torch.int64).tolist()

    req_ids = self.input_batch.req_ids
    for row, batch_idx in enumerate(batch_rows_host):
        # Skip rows that point outside the current input batch.
        if not 0 <= batch_idx < len(req_ids):
            continue
        req_id = req_ids[batch_idx]
        if req_id is None:
            continue
        state = self.requests.get(req_id)
        if state is None:
            continue
        # The index row is stored as-is (no host copy is made here).
        state.kv_compression_prompt_idx_sorted = sorted_indices[row]
        state.kv_compression_prompt_keep_len = int(keep_host[row])
        state.kv_compression_prompt_prompt_len = int(prompt_host[row])
def _maybe_apply_kv_compression_prompt_compaction(self) -> None:
    """Run the pending one-shot prompt KV compaction, if any is due.

    A request is due when it carries stashed compaction indices and its
    computed-token count has reached its prompt length. For all due
    requests with a positive keep length, a padded index batch and a
    per-group block table are built and ``front_compact_inplace_fa_triton``
    is invoked on each attention layer's key/value cache. The stashed
    per-request fields are reset to ``None`` afterwards.

    Does nothing unless both ``VLLM_ENABLE_KV_COMPRESSION`` and chunked
    prefill are enabled.
    """
    if not (envs.VLLM_ENABLE_KV_COMPRESSION
            and self.scheduler_config.chunked_prefill_enabled):
        return

    def clear_pending(state) -> None:
        # Drop the stashed payload so the compaction is never re-applied.
        state.kv_compression_prompt_idx_sorted = None
        state.kv_compression_prompt_keep_len = None
        state.kv_compression_prompt_prompt_len = None

    # Pass 1: requests in the batch with a stashed payload whose prompt
    # has been fully ingested.
    ready_ids: list[str] = []
    for req_id in self.input_batch.req_ids:
        if req_id is None:
            continue
        state = self.requests.get(req_id)
        if state is None or state.kv_compression_prompt_idx_sorted is None:
            continue
        if state.num_computed_tokens >= state.num_prompt_tokens:
            ready_ids.append(req_id)
    if not ready_ids:
        return

    device = self.device

    # Pass 2: keep only requests that actually retain prompt tokens.
    work: list[tuple[str, torch.Tensor, int]] = []
    for req_id in ready_ids:
        state = self.requests[req_id]
        keep_len = state.kv_compression_prompt_keep_len
        sorted_idx = state.kv_compression_prompt_idx_sorted
        if keep_len is None or sorted_idx is None:
            continue
        keep_len = int(keep_len)
        if keep_len > 0:
            work.append((req_id, sorted_idx, keep_len))
        else:
            # Nothing is kept for this prompt; just forget the payload.
            clear_pending(state)
    if not work:
        return

    num_reqs = len(work)
    keep_lens = [keep_len for _, _, keep_len in work]

    # Zero-padded (num_reqs, max keep length) index matrix; only the first
    # keep_len entries of each row are filled in.
    idx_batch = torch.zeros((num_reqs, max(keep_lens)),
                            device=device,
                            dtype=torch.int32)
    for row, (_, sorted_idx, keep_len) in enumerate(work):
        idx_batch[row, :keep_len] = sorted_idx[:keep_len].to(
            device=device, dtype=torch.int32)
    keep_tensor = torch.tensor(keep_lens, device=device, dtype=torch.int32)

    from vllm.v1.attention.kv_compression.kv_cache_triton import (
        front_compact_inplace_fa_triton, make_fa_cache_view)

    # Compact every attention layer's KV cache in place, one KV-cache
    # group at a time (each group gets its own block table).
    for group_id, group_spec in enumerate(
            self.kv_cache_config.kv_cache_groups):
        # Block ids of each pending request for this group, or None when
        # the request has no entry for the group.
        per_req_blocks = []
        for req_id, _, _ in work:
            all_block_ids = self.requests[req_id].block_ids
            per_req_blocks.append(all_block_ids[group_id]
                                  if group_id < len(all_block_ids) else None)

        widest = max(
            (len(ids) for ids in per_req_blocks if ids is not None),
            default=0)
        if widest == 0:
            continue

        # Zero-padded block table assembled on host, then moved to device.
        block_table_host = torch.zeros((num_reqs, widest),
                                       dtype=torch.int32,
                                       device="cpu")
        for row, ids in enumerate(per_req_blocks):
            if ids:
                block_table_host[row, :len(ids)] = torch.tensor(
                    ids, dtype=torch.int32, device="cpu")
        block_table = block_table_host.to(device=device, non_blocking=True)

        for layer_name in group_spec.layer_names:
            layer_index = self._extract_layer_index(layer_name)
            if layer_index >= len(self.kv_caches):
                continue
            kv_cache = self.kv_caches[layer_index]
            if current_platform.is_rocm():
                # On ROCm the per-layer cache must be a (key, value) pair.
                if not (isinstance(kv_cache, (tuple, list))
                        and len(kv_cache) == 2):
                    continue
                key_cache, value_cache = kv_cache
            else:
                # Elsewhere it must be one tensor; dim 0 unbinds into
                # the key cache and the value cache.
                if not isinstance(kv_cache, torch.Tensor):
                    continue
                key_cache, value_cache = kv_cache.unbind(0)
            k_view, v_view = make_fa_cache_view(key_cache=key_cache,
                                                value_cache=value_cache)
            front_compact_inplace_fa_triton(
                k_view,
                v_view,
                block_table,
                idx_batch,
                keep_tensor,
            )

    # All groups processed: reset the stashed payload of each request.
    for req_id, _, _ in work:
        state = self.requests.get(req_id)
        if state is not None:
            clear_pending(state)
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
...@@ -1667,7 +1814,23 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1667,7 +1814,23 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# If attention doesn't support CUDA Graphs for this batch, but we # If attention doesn't support CUDA Graphs for this batch, but we
# compiled with full CUDA graphs, we have to skip them entirely. # compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
if envs.VLLM_ENABLE_TBO and scheduler_output.total_num_scheduled_tokens >= envs.VLLM_TBO_MIN_TOKENS: # Chunked prefill (scheme 3): apply one-shot prompt KV compaction before
# the first decode step writes/reads KV at the compressed positions.
self._maybe_apply_kv_compression_prompt_compaction()
use_tbo = (envs.VLLM_ENABLE_TBO and scheduler_output.total_num_scheduled_tokens
>= envs.VLLM_TBO_MIN_TOKENS)
if (use_tbo and envs.VLLM_ENABLE_KV_COMPRESSION
and self.scheduler_config.chunked_prefill_enabled):
# NOTE: the TBO path does not call `_stash_kv_compression_prompt_payload`
# inside its `set_forward_context`, so scheme-3 prompt-end payloads
# would be dropped and the next-step compaction would never run.
logger.warning_once(
"TBO is currently incompatible with chunked prefill KV "
"compression (scheme 3); running without TBO.")
use_tbo = False
if use_tbo:
model_output, finished_sending, finished_recving = \ model_output, finished_sending, finished_recving = \
tbo_split_and_execute_model(self, attn_metadata, num_input_tokens, tbo_split_and_execute_model(self, attn_metadata, num_input_tokens,
num_tokens_across_dp, input_ids, positions, num_tokens_across_dp, input_ids, positions,
...@@ -1694,6 +1857,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1694,6 +1857,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
self._stash_kv_compression_prompt_payload()
self.maybe_wait_for_kv_save() self.maybe_wait_for_kv_save()
finished_sending, finished_recving = ( finished_sending, finished_recving = (
...@@ -1719,6 +1883,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1719,6 +1883,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
self._stash_kv_compression_prompt_payload()
self.maybe_wait_for_kv_save() self.maybe_wait_for_kv_save()
finished_sending, finished_recving = ( finished_sending, finished_recving = (
...@@ -3686,7 +3851,21 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3686,7 +3851,21 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
# If attention doesn't support CUDA Graphs for this batch, but we # If attention doesn't support CUDA Graphs for this batch, but we
# compiled with full CUDA graphs, we have to skip them entirely. # compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
if envs.VLLM_ENABLE_TBO and scheduler_output.total_num_scheduled_tokens >= envs.VLLM_TBO_MIN_TOKENS:
# Chunked prefill (scheme 3): apply one-shot prompt KV compaction before
# the first decode step writes/reads KV at the compressed positions.
self._maybe_apply_kv_compression_prompt_compaction()
use_tbo = (envs.VLLM_ENABLE_TBO and scheduler_output.total_num_scheduled_tokens
>= envs.VLLM_TBO_MIN_TOKENS)
if (use_tbo and envs.VLLM_ENABLE_KV_COMPRESSION
and self.scheduler_config.chunked_prefill_enabled):
logger.warning_once(
"TBO is currently incompatible with chunked prefill KV "
"compression (scheme 3); running without TBO.")
use_tbo = False
if use_tbo:
model_output, finished_sending, finished_recving = \ model_output, finished_sending, finished_recving = \
tbo_split_and_execute_model(self, attn_metadata, num_input_tokens, tbo_split_and_execute_model(self, attn_metadata, num_input_tokens,
num_tokens_across_dp, input_ids, positions, num_tokens_across_dp, input_ids, positions,
...@@ -3713,6 +3892,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3713,6 +3892,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
self._stash_kv_compression_prompt_payload()
self.maybe_wait_for_kv_save() self.maybe_wait_for_kv_save()
finished_sending, finished_recving = ( finished_sending, finished_recving = (
...@@ -3738,6 +3918,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3738,6 +3918,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
self._stash_kv_compression_prompt_payload()
self.maybe_wait_for_kv_save() self.maybe_wait_for_kv_save()
finished_sending, finished_recving = ( finished_sending, finished_recving = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment