# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations

from typing import Any

import torch

import vllm.envs as envs
from vllm.forward_context import get_forward_context
from vllm.platforms import current_platform
from vllm.v1.kv_cache_interface import (
    AttentionSpec,
    CrossAttentionSpec,
    UniformTypeKVCacheSpecs,
)
from vllm.v1.kv_compression.forward_context import get_kv_compression_prompt_payload


def _chunked_prefill_enabled(runner: Any) -> bool:
    """Return True if ``runner.scheduler_config.enable_chunked_prefill`` is set.

    A runner without a ``scheduler_config`` attribute counts as disabled.
    """
    scheduler_config = getattr(runner, "scheduler_config", None)
    if scheduler_config is None:
        return False
    return bool(getattr(scheduler_config, "enable_chunked_prefill", False))


def _clear_prompt_compaction_state(rs: Any) -> None:
    """Reset the three prompt-compaction stash fields on a request state."""
    rs.kv_compression_prompt_idx_sorted = None
    rs.kv_compression_prompt_keep_len = None
    rs.kv_compression_prompt_prompt_len = None


def stash_kv_compression_prompt_payload_to_requests(*, runner: Any) -> None:
    """Persist prompt-end compaction indices from the forward context.

    This is the runner-side half of chunked-prefill scheme 3:
    flash_attn -> forward_context payload -> request state stash ->
    (next step) one-shot KV compaction.

    For each payload row ``i`` whose batch index ``req_indices[i]`` maps to a
    live request in ``runner.requests``, the request state receives:

    * ``kv_compression_prompt_idx_sorted``: a copy of ``idx_sorted[i]``
    * ``kv_compression_prompt_keep_len``: ``int(keep_len[i])``
    * ``kv_compression_prompt_prompt_len``: ``int(prompt_lens[i])``

    Does nothing when KV compression or chunked prefill is disabled, when no
    payload (or an incomplete payload) was published, or when the runner lacks
    ``input_batch`` / ``input_batch.req_ids`` / ``requests``.
    """
    if not envs.VLLM_ENABLE_KV_COMPRESSION:
        return
    if not _chunked_prefill_enabled(runner):
        return

    payload = get_kv_compression_prompt_payload(get_forward_context())
    if payload is None:
        return
    req_indices = payload.get("req_indices")
    idx_sorted = payload.get("idx_sorted")
    keep_len = payload.get("keep_len")
    prompt_lens = payload.get("prompt_lens")
    if (req_indices is None or idx_sorted is None or keep_len is None
            or prompt_lens is None):
        return

    input_batch = getattr(runner, "input_batch", None)
    if input_batch is None:
        return
    req_ids = getattr(input_batch, "req_ids", None)
    if req_ids is None:
        return
    requests = getattr(runner, "requests", None)
    if requests is None:
        return

    req_indices_cpu = req_indices.to(device="cpu", dtype=torch.int64).tolist()
    keep_cpu = keep_len.to(device="cpu", dtype=torch.int64).tolist()
    prompt_cpu = prompt_lens.to(device="cpu", dtype=torch.int64).tolist()

    num_batch_rows = len(req_ids)
    for i, (b, keep, prompt_len) in enumerate(
            zip(req_indices_cpu, keep_cpu, prompt_cpu)):
        # Ignore batch indices that do not address a slot in input_batch.
        if not 0 <= b < num_batch_rows:
            continue
        req_id = req_ids[b]
        if req_id is None:
            continue
        rs = requests.get(req_id)
        if rs is None:
            continue
        # Clone so the stash neither keeps the whole payload tensor alive
        # (a row view shares its storage) nor aliases a buffer that a later
        # forward pass might overwrite before compaction runs.
        rs.kv_compression_prompt_idx_sorted = idx_sorted[i].clone()
        rs.kv_compression_prompt_keep_len = int(keep)
        rs.kv_compression_prompt_prompt_len = int(prompt_len)


def _collect_pending_states(
        input_batch: Any,
        requests: Any) -> list[tuple[str, torch.Tensor, int]]:
    """Return ``(req_id, idx_sorted_row, keep_len)`` for requests to compact.

    A request qualifies when it appears in ``input_batch.req_ids``, has a
    stashed ``kv_compression_prompt_idx_sorted`` row, and its prompt is fully
    ingested (``num_computed_tokens >= num_prompt_tokens``). Qualifying stashes
    whose ``keep_len`` is missing or ``<= 0`` are cleared here instead, since
    there is nothing to compact (a missing ``keep_len`` would otherwise be
    re-scanned on every call).
    """
    pending: list[tuple[str, torch.Tensor, int]] = []
    for req_id in input_batch.req_ids:
        if req_id is None:
            continue
        rs = requests.get(req_id)
        if rs is None or rs.kv_compression_prompt_idx_sorted is None:
            continue
        # Only apply once the prompt is fully ingested (decode stage).
        if rs.num_computed_tokens < rs.num_prompt_tokens:
            continue
        keep = rs.kv_compression_prompt_keep_len
        keep_i = 0 if keep is None else int(keep)
        if keep_i <= 0:
            # No prompt tokens kept (or incomplete stash); clear and skip.
            _clear_prompt_compaction_state(rs)
            continue
        pending.append((req_id, rs.kv_compression_prompt_idx_sorted, keep_i))
    return pending


def _build_index_batch(
        pending: list[tuple[str, torch.Tensor, int]],
        device: torch.device | str) -> tuple[torch.Tensor, torch.Tensor]:
    """Pack kept indices into tensors on ``device``.

    Returns ``(idx_batch, keep_tensor)``. ``idx_batch`` is a zero-padded int32
    tensor of shape ``(len(pending), max_keep)`` whose row ``i`` holds the
    first ``keep_len`` entries of request ``i``'s stashed row. ``keep_tensor``
    is an int32 tensor of shape ``(len(pending),)`` with the keep lengths.
    """
    keep_list = [k for _, _, k in pending]
    idx_batch = torch.zeros((len(pending), max(keep_list)),
                            device=device,
                            dtype=torch.int32)
    for i, (_, row, k) in enumerate(pending):
        idx_batch[i, :k] = row[:k].to(device=device, dtype=torch.int32)
    keep_tensor = torch.tensor(keep_list, device=device, dtype=torch.int32)
    return idx_batch, keep_tensor


def _build_group_block_table(
        pending: list[tuple[str, torch.Tensor, int]], requests: Any,
        group_id: int) -> tuple[torch.Tensor, list[int]] | None:
    """Build a zero-padded CPU int32 block table for one KV cache group.

    Returns ``(block_table_cpu, rows)``. ``rows`` lists the positions in
    ``pending`` whose request owns at least one block in ``group_id``, and
    ``block_table_cpu`` has one row per entry of ``rows``. Returns ``None``
    when no pending request has blocks in this group.
    """
    rows: list[int] = []
    block_lists: list[list[int]] = []
    for i, (req_id, _, _) in enumerate(pending):
        block_ids = requests[req_id].block_ids
        if group_id >= len(block_ids) or not block_ids[group_id]:
            continue
        rows.append(i)
        block_lists.append(block_ids[group_id])
    if not rows:
        return None

    max_blocks = max(len(ids) for ids in block_lists)
    block_table_cpu = torch.zeros((len(rows), max_blocks),
                                  dtype=torch.int32,
                                  device="cpu")
    for r, ids in enumerate(block_lists):
        block_table_cpu[r, :len(ids)] = torch.tensor(ids,
                                                     dtype=torch.int32,
                                                     device="cpu")
    return block_table_cpu, rows


def _collect_group_self_attn_caches(
        kv_cache_group_spec: Any,
        static_forward_context: Any) -> list[tuple[torch.Tensor, torch.Tensor]]:
    """Return unique ``(key_cache, value_cache)`` pairs for one cache group.

    Layers whose spec is not an ``AttentionSpec``, or is a
    ``CrossAttentionSpec``, are skipped. Layers missing from
    ``static_forward_context`` and layers without a non-empty ``kv_cache``
    list are also skipped. Layers whose first cache entry shares the same
    ``data_ptr`` as an earlier one are returned only once, so shared storage
    is compacted a single time.
    """
    is_rocm = current_platform.is_rocm()
    seen_cache_ptrs: set[int] = set()
    caches: list[tuple[torch.Tensor, torch.Tensor]] = []
    for layer_name in kv_cache_group_spec.layer_names:
        # Skip non-self-attention caches (e.g., encoder/decoder cross-attn)
        # and non-attention cache specs (e.g., Mamba).
        kv_cache_spec = kv_cache_group_spec.kv_cache_spec
        if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
            kv_cache_spec = kv_cache_spec.kv_cache_specs.get(layer_name)
        if not isinstance(kv_cache_spec, AttentionSpec):
            continue
        if isinstance(kv_cache_spec, CrossAttentionSpec):
            continue

        layer = static_forward_context.get(layer_name)
        if layer is None:
            continue
        kv_cache_list = getattr(layer, "kv_cache", None)
        if not isinstance(kv_cache_list, list) or not kv_cache_list:
            continue
        kv_cache = kv_cache_list[0]

        if is_rocm:
            # ROCm: the cache entry is a (key_cache, value_cache) pair.
            if (not isinstance(kv_cache, (tuple, list))
                    or len(kv_cache) != 2):
                continue
            key_cache, value_cache = kv_cache
            cache_ptr = int(key_cache.data_ptr())
        else:
            # Non-ROCm: one tensor split into K and V along dim 0
            # (assumes dim 0 has size 2 — TODO confirm layout).
            if not isinstance(kv_cache, torch.Tensor):
                continue
            cache_ptr = int(kv_cache.data_ptr())
            key_cache, value_cache = kv_cache.unbind(0)

        if cache_ptr in seen_cache_ptrs:
            continue
        seen_cache_ptrs.add(cache_ptr)
        caches.append((key_cache, value_cache))
    return caches


def maybe_apply_kv_compression_prompt_compaction(*, runner: Any) -> None:
    """Apply one-shot prompt KV compaction before the first decode step.

    Targets every request in ``runner.input_batch`` whose prompt is fully
    computed and that carries a stash written by
    ``stash_kv_compression_prompt_payload_to_requests``. For each such
    request, the first ``keep_len`` entries of the stashed ``idx_sorted`` row
    go to ``front_compact_inplace_fa_triton``, together with the request's
    block table. This runs for every self-attention layer of every KV cache
    group. (Presumably the kernel moves the kept positions to the front of the
    request's KV blocks; see ``kv_cache_triton``.) The stash is then cleared,
    so each request is compacted at most once.

    Returns early without touching request state in these cases:

    * KV compression or chunked prefill is disabled.
    * The platform is not CUDA-like.
    * The runner lacks ``input_batch``, ``requests`` or ``kv_cache_config``.
    * No request is ready.
    """
    if not envs.VLLM_ENABLE_KV_COMPRESSION:
        return
    if not current_platform.is_cuda_alike():
        return
    if not _chunked_prefill_enabled(runner):
        return
    input_batch = getattr(runner, "input_batch", None)
    if input_batch is None:
        return
    requests = getattr(runner, "requests", None)
    if requests is None:
        return

    pending_states = _collect_pending_states(input_batch, requests)
    if not pending_states:
        return

    kv_cache_config = getattr(runner, "kv_cache_config", None)
    if kv_cache_config is None:
        # Stash is left in place, so a later call can still compact.
        return

    from vllm.v1.kv_compression.kv_cache_triton import (
        front_compact_inplace_fa_triton, make_fa_cache_view)

    device = runner.device
    idx_batch, keep_tensor = _build_index_batch(pending_states, device)

    # Invariant across cache groups; look it up once.
    static_forward_context = getattr(
        getattr(runner, "compilation_config", None),
        "static_forward_context",
        None,
    )

    # Apply compaction to every attention layer's KV cache in-place.
    if static_forward_context is not None:
        for group_id, kv_cache_group_spec in enumerate(
                kv_cache_config.kv_cache_groups):
            group = _build_group_block_table(pending_states, requests,
                                             group_id)
            if group is None:
                continue
            block_table_cpu, rows = group
            block_table = block_table_cpu.to(device=device, non_blocking=True)

            if len(rows) == len(pending_states):
                group_idx, group_keep = idx_batch, keep_tensor
            else:
                # Launch only for requests that own blocks in this group.
                # Otherwise the kernel would get an all-zero block-table
                # row (physical block 0) with a non-zero keep length.
                sel = torch.tensor(rows, device=device, dtype=torch.long)
                group_idx = idx_batch.index_select(0, sel)
                group_keep = keep_tensor.index_select(0, sel)

            for key_cache, value_cache in _collect_group_self_attn_caches(
                    kv_cache_group_spec, static_forward_context):
                k_view, v_view = make_fa_cache_view(key_cache=key_cache,
                                                    value_cache=value_cache)
                front_compact_inplace_fa_triton(
                    k_view,
                    v_view,
                    block_table,
                    group_idx,
                    group_keep,
                )

    # Clear pending state so compaction is not repeated.
    # NOTE(review): this also clears the stash when no layer was compacted
    # (e.g. static_forward_context is missing); confirm this is intended.
    for req_id, _, _ in pending_states:
        rs = requests.get(req_id)
        if rs is not None:
            _clear_prompt_compaction_state(rs)