# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""KV-pruning (compactor) path invoked from :meth:`vllm.entrypoints.llm.LLM.generate`."""

from __future__ import annotations

import os
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import Any

from tqdm.auto import tqdm
from transformers import AutoTokenizer

from vllm.kvprune.compression.compression_config import (
    BatchCompressionParams,
    SequenceCompressionParams,
)
from vllm.kvprune.config.sampling_params import SamplingParams as CompactorSamplingParams
from vllm.kvprune.core.compression_bridge import (
    compression_method_id_to_enum,
    compression_method_str_to_id,
)
from vllm.kvprune.core.llm_engine import LLMEngine, _infer_stop_token_ids
from vllm.kvprune.integration.compactor_shared import create_compactor_engine_with_shared_weights
from vllm.kvprune.integration.compression_params import CompressionParams
from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams

logger = init_logger(__name__)

_MP_ENV = "VLLM_ENABLE_V1_MULTIPROCESSING"
_RELEASE_V1_KV_ENV = "VLLM_KVPRUNE_RELEASE_V1_KV"

# Environment values (after strip + lowercase) that count as "enabled".
_TRUTHY_ENV_VALUES = ("1", "true", "yes")

# Completion budget used when ``SamplingParams.max_tokens`` is ``None``.
_DEFAULT_MAX_NEW_TOKENS = 16


def _env_flag(name: str, default: str) -> bool:
    """Return True when env var ``name`` (or ``default`` if unset) is truthy."""
    return os.environ.get(name, default).strip().lower() in _TRUTHY_ENV_VALUES


def _maybe_release_v1_kv_for_compactor(llm: Any) -> None:
    """Optionally discard v1's KV cache so more GPU memory is free for compactor.

    v1 reserves KV blocks at engine init; shared-weight compactor then competes
    for the same VRAM. ``sleep(level=1)`` discards v1 KV and may offload tagged
    weights per v1 sleep policy, then ``wake_up()`` reloads — compactor still
    ties the same v1 tensors after.

    **Default:** ``vllm.env_override`` sets ``VLLM_KVPRUNE_RELEASE_V1_KV=0``
    (no sleep/wake; v1 KV stays on GPU). Set ``=1`` if you need extra VRAM for
    compactor before the first compressed step (then ``llm.sleep`` /
    ``CuMemAllocator`` / ``Sleep mode freed …`` logs are expected).

    This does **not** remove v1's KV reservation at init; it only runs the
    optional sleep/wake cycle before compactor. Tests keep
    ``VLLM_KVPRUNE_RELEASE_V1_KV=0`` in ``conftest``.

    The cycle is best-effort: any exception raised by ``sleep`` / ``wake_up``
    is logged as a warning and not propagated.
    """
    if not _env_flag(_RELEASE_V1_KV_ENV, "0"):
        return
    try:
        logger.info(
            "%s=1: discarding v1 KV via sleep(level=1) then wake_up() "
            "(reloads model weights to GPU).",
            _RELEASE_V1_KV_ENV,
        )
        llm.sleep(level=1, mode="abort")
        llm.wake_up()
    except Exception as e:
        # Deliberately broad: releasing KV is an optimisation, not a requirement.
        logger.warning("%s: sleep/wake failed: %s", _RELEASE_V1_KV_ENV, e)


def ensure_inprocess_engine_for_weight_sharing() -> None:
    """Compactor must see ``worker.get_model()`` in the same process as vLLM.

    Forces ``VLLM_ENABLE_V1_MULTIPROCESSING=0`` in ``os.environ`` (and logs it)
    unless it is already exactly ``"0"``.
    """
    if os.environ.get(_MP_ENV, "1") != "0":
        os.environ[_MP_ENV] = "0"
        logger.info(
            "KV cache pruning: set %s=0 so the model stays in-process for "
            "shared-weight compactor (no manual env needed).",
            _MP_ENV,
        )


def _normalize_prompt_list(prompts: Any) -> list[Any]:
    """Return ``prompts`` as a list with one entry per prompt.

    A single ``str``, a single ``dict`` prompt, or a single non-empty
    ``list[int]`` of token ids is wrapped in a one-element list; any other
    iterable is materialised with ``list()``.
    """
    if isinstance(prompts, (str, dict)):
        return [prompts]
    # A flat list of ints is ONE token-id prompt (accepted by
    # ``_prompt_to_compactor_input``), not a batch of integer "prompts".
    if (
        isinstance(prompts, list)
        and prompts
        and all(isinstance(t, int) for t in prompts)
    ):
        return [prompts]
    return list(prompts)


def _normalize_sampling_params(
    sampling_params: SamplingParams | Sequence[SamplingParams] | None,
    n: int,
) -> list[SamplingParams]:
    """Broadcast or validate ``sampling_params`` so there is one per prompt.

    ``None`` yields ``n`` default ``SamplingParams()``; a single instance is
    repeated ``n`` times (same object); a sequence must have length ``n``.

    Raises:
        ValueError: if a sequence is given whose length differs from ``n``.
    """
    if sampling_params is None:
        return [SamplingParams() for _ in range(n)]
    if isinstance(sampling_params, SamplingParams):
        return [sampling_params] * n
    sps = list(sampling_params)
    if len(sps) != n:
        raise ValueError(
            f"sampling_params length {len(sps)} != prompts length {n}"
        )
    return sps


def _normalize_compression_params(
    compression: CompressionParams | Sequence[CompressionParams] | None,
    n: int,
) -> list[CompressionParams]:
    """Broadcast or validate ``compression`` so there is one entry per prompt.

    ``None`` yields ``n`` entries with ``compression_ratio=1.0``; a single
    instance is repeated ``n`` times (same object); a sequence must have
    length ``n``.

    Raises:
        ValueError: if a sequence is given whose length differs from ``n``.
    """
    if compression is None:
        return [CompressionParams(compression_ratio=1.0) for _ in range(n)]
    if isinstance(compression, CompressionParams):
        return [compression] * n
    comp = list(compression)
    if len(comp) != n:
        raise ValueError(f"compression length {len(comp)} != prompts length {n}")
    return comp


def _any_compactor(comps: list[CompressionParams]) -> bool:
    """Return True if at least one entry has ``compression_ratio < 1.0``."""
    return any(c.compression_ratio < 1.0 for c in comps)


_FORCE_COMPACTOR_PATH_ENV = "VLLM_KVPRUNE_FORCE_COMPACTOR_PATH"


def _should_use_kvprune_compactor_path(comps: list[CompressionParams]) -> bool:
    """Use integrated compactor when any prompt requests compression, or when forced.

    If all ``compression_ratio >= 1.0``, the default is to return ``None`` from
    :func:`try_compressed_generate` and fall back to the standard v1 engine
    (``Processed prompts`` loop). That hides TP/kvprune bugs behind a different
    code path. Set ``VLLM_KVPRUNE_FORCE_COMPACTOR_PATH=1`` to run the same
    compactor + collective RPC path as compression-on, with no KV pruning.
    """
    if _any_compactor(comps):
        return True
    return _env_flag(_FORCE_COMPACTOR_PATH_ENV, "")


def _resolve_max_new_tokens(sp: SamplingParams) -> int:
    """Return ``sp.max_tokens`` as an int, or the module default when ``None``."""
    max_tokens = sp.max_tokens
    return int(_DEFAULT_MAX_NEW_TOKENS if max_tokens is None else max_tokens)


def _to_compactor_sampling(sp: SamplingParams) -> CompactorSamplingParams:
    """Convert v1 ``SamplingParams`` to the compactor's sampling params.

    Only ``temperature`` and ``max_tokens`` are carried over; every other
    field of ``sp`` is not forwarded.
    """
    return CompactorSamplingParams(
        temperature=float(sp.temperature),
        max_new_tokens=_resolve_max_new_tokens(sp),
    )


def _to_sequence_compression(cp: CompressionParams) -> SequenceCompressionParams:
    """Copy the per-sequence fields of ``cp`` into ``SequenceCompressionParams``."""
    return SequenceCompressionParams(
        compression_ratio=float(cp.compression_ratio),
        protected_first_tokens=int(cp.protected_first_tokens),
        protected_last_tokens=int(cp.protected_last_tokens),
    )


def _batch_compression_from_comps(comps: list[CompressionParams]) -> BatchCompressionParams:
    """Build batch-level compression params from the per-prompt list.

    The compression method is taken from the FIRST entry whose
    ``compression_ratio < 1.0``; methods on later entries are not consulted.
    When no entry compresses, a default ``BatchCompressionParams()`` is returned.
    """
    for c in comps:
        if c.compression_ratio < 1.0:
            mid = compression_method_str_to_id(c.compression_method)
            return BatchCompressionParams(
                compression_method=compression_method_id_to_enum(mid)
            )
    return BatchCompressionParams()


def _kvprune_compactor_hf_tokenizer(llm: Any):
    """HF tokenizer matching :meth:`vllm.kvprune.core.llm_engine.LLMEngine.__init__`.

    Loads from the **resolved on-disk** model tree (local dir or HF cache
    snapshot), not the bare repo id, to avoid redundant Hub downloads.

    The tokenizer is cached on ``llm._kvprune_compactor_hf_tokenizer`` and
    reused on later calls. If resolving the on-disk path fails for any reason,
    the raw ``model_config.model`` string is passed to ``AutoTokenizer``.
    """
    cached = getattr(llm, "_kvprune_compactor_hf_tokenizer", None)
    if cached is not None:
        return cached
    mc = llm.llm_engine.vllm_config.model_config
    model_s = str(mc.model)
    src = model_s
    try:
        p = Path(model_s)
        if p.is_dir() and (p / "config.json").is_file():
            src = str(p.resolve())
        else:
            from huggingface_hub import snapshot_download

            src = snapshot_download(repo_id=model_s, local_files_only=False)
    except Exception:
        # Best-effort resolution: keep the fallback, but leave a trace for debugging.
        logger.debug(
            "Could not resolve an on-disk model tree for %r; "
            "loading tokenizer from the raw model id.",
            model_s,
            exc_info=True,
        )
        src = model_s
    hf_cfg = mc.hf_config
    # NOTE(review): trust_remote_code is read from hf_config here; confirm this
    # is the intended source rather than the vLLM model config.
    _trust = bool(getattr(hf_cfg, "trust_remote_code", False)) if hf_cfg is not None else False
    tok = AutoTokenizer.from_pretrained(src, use_fast=True, trust_remote_code=_trust)
    llm._kvprune_compactor_hf_tokenizer = tok
    return tok


def _prompt_to_compactor_input(prompt: Any) -> str | list[int]:
    """Convert one prompt to the ``str`` or token-id list the compactor accepts.

    Accepted shapes: ``str``; non-empty ``list[int]``; ``dict`` with
    ``"prompt_token_ids"`` (checked first) or a ``str`` under ``"prompt"``.

    Raises:
        TypeError: for an empty list or any other unsupported prompt shape.
    """
    if isinstance(prompt, str):
        return prompt
    # Decoder-only `list[int]` token ids (see `vllm.inputs.PromptType`).
    if isinstance(prompt, list):
        if not prompt:
            raise TypeError("Empty token-id prompt is not supported for compactor path.")
        if all(isinstance(t, int) for t in prompt):
            return list(prompt)
    if isinstance(prompt, dict):
        if "prompt_token_ids" in prompt:
            ids = prompt["prompt_token_ids"]
            return list(ids) if not isinstance(ids, list) else ids
        p = prompt.get("prompt")
        if isinstance(p, str):
            return p
    raise TypeError(
        f"Unsupported prompt type for compactor path: {type(prompt)}. "
        "Use str, list[int] token ids, or dict with 'prompt_token_ids' or 'prompt'."
    )


def _prompt_to_token_ids_for_tp(llm: Any, prompt: Any) -> list[int]:
    """Driver-side token ids for the TP collective path (same tokenizer as vLLM ``LLM``)."""
    comp_in = _prompt_to_compactor_input(prompt)
    if isinstance(comp_in, str):
        return llm.get_tokenizer().encode(comp_in)
    return list(comp_in)


def _compressed_generate_tp_collective(
    llm: Any,
    plist: list[Any],
    sps: list[SamplingParams],
    comps: list[CompressionParams],
) -> list[RequestOutput]:
    """TP>1: run compactor on each worker via ``collective_rpc`` (all ranks).

    Prompts are tokenized on the driver, packed with sampling/compression
    settings into a plain-dict payload, and sent to the workers' 
    ``kvprune_v1_compressed_generate`` RPC. The result reported by tensor
    parallel rank 0 is converted to ``RequestOutput`` objects.

    Raises:
        NotImplementedError: if pipeline or data parallel size is not 1.
        ValueError: if a tokenized prompt is longer than ``max_model_len``.
        RuntimeError: if the RPC is cancelled or rank 0 returns no dict.
    """
    vc = llm.llm_engine.vllm_config
    pc = vc.parallel_config
    if pc.pipeline_parallel_size != 1 or pc.data_parallel_size != 1:
        raise NotImplementedError(
            "KV-prune TP compression requires pipeline_parallel_size=1 and "
            f"data_parallel_size=1 (got PP={pc.pipeline_parallel_size}, "
            f"DP={pc.data_parallel_size})."
        )
    hf = vc.model_config.hf_config
    tok = llm.get_tokenizer()
    eos_token_ids = _infer_stop_token_ids(tok, hf)
    prompt_token_ids = [_prompt_to_token_ids_for_tp(llm, p) for p in plist]
    max_len = int(vc.model_config.max_model_len)
    for i, ids in enumerate(prompt_token_ids):
        if len(ids) > max_len:
            raise ValueError(
                f"KV-prune TP compressed generate: prompt {i} length {len(ids)} "
                f"exceeds max_model_len ({max_len}). Shorten the prompt or raise "
                "max_model_len when constructing LLM()."
            )
    # Payload must be picklable for multiproc/Ray RPC: do not pass multiprocessing
    # synchronization primitives (workers are separate processes).
    payload: dict[str, Any] = {
        "eos_token_ids": eos_token_ids,
        "prompt_token_ids": prompt_token_ids,
        "sampling_params": [
            {
                "temperature": float(sp.temperature),
                "max_new_tokens": _resolve_max_new_tokens(sp),
            }
            for sp in sps
        ],
        "compression_params": [
            {
                "compression_ratio": float(c.compression_ratio),
                "compression_method": str(c.compression_method),
                "protected_first_tokens": int(c.protected_first_tokens),
                "protected_last_tokens": int(c.protected_last_tokens),
            }
            for c in comps
        ],
    }
    _maybe_release_v1_kv_for_compactor(llm)
    try:
        results = llm.llm_engine.collective_rpc(
            "kvprune_v1_compressed_generate",
            args=(payload,),
        )
    except RuntimeError as e:
        # Re-raise "cancelled" with a hint; other RuntimeErrors pass through as-is.
        if "cancelled" in str(e).lower():
            raise RuntimeError(
                "collective_rpc was cancelled (a GPU worker likely crashed). "
                "Scroll up for the first worker traceback — often NCCL/CUDA before "
                "TCPStore/Broken pipe on the driver."
            ) from e
        raise
    master: dict[str, Any] | None = None
    for r in results:
        if isinstance(r, dict) and r.get("tensor_parallel_rank") == 0:
            master = r
            break
    if master is None:
        raise RuntimeError(
            "collective_rpc did not return a dict from tensor parallel rank 0."
        )
    return _tp_payload_to_request_outputs(llm, master)


def _build_request_output(
    tok: Any,
    request_id: str,
    prompt_token_ids: Any,
    completion_token_ids: Any,
) -> RequestOutput:
    """Decode one completion and wrap it in a finished ``RequestOutput``.

    Text is decoded with ``skip_special_tokens=True``. If every emitted id is
    "special" (e.g. EOS / chat boundary), that string is blank while the token
    list is non-empty; in that case the text is re-decoded with special tokens
    kept so users can see boundary tokens / debug.
    """
    text = tok.decode(completion_token_ids, skip_special_tokens=True)
    if not text.strip() and completion_token_ids:
        text = tok.decode(completion_token_ids, skip_special_tokens=False)
    co = CompletionOutput(
        index=0,
        text=text,
        token_ids=list(completion_token_ids),
        cumulative_logprob=None,
        logprobs=None,
        finish_reason="stop",
    )
    return RequestOutput(
        request_id=request_id,
        prompt=None,
        prompt_token_ids=list(prompt_token_ids),
        prompt_logprobs=None,
        outputs=[co],
        finished=True,
    )


def _tp_payload_to_request_outputs(llm: Any, master: dict[str, Any]) -> list[RequestOutput]:
    """Convert the rank-0 RPC result dict into ``RequestOutput`` objects.

    Reads ``master["prompt_token_ids"]`` and ``master["completion_token_ids"]``
    pairwise; request ids are ``kvprune-tp-<index>``.
    """
    tok = llm.get_tokenizer()
    pids_list = master["prompt_token_ids"]
    cids_list = master["completion_token_ids"]
    return [
        _build_request_output(tok, f"kvprune-tp-{i}", pids, cids)
        for i, (pids, cids) in enumerate(zip(pids_list, cids_list))
    ]


def _ensure_compactor_engine(llm: Any) -> LLMEngine:
    """Return the compactor engine cached on ``llm``, creating it on first use.

    Raises:
        ValueError: if the engine must be created and ``tensor_parallel_size``
            is not 1.
    """
    if getattr(llm, "_kvprune_compactor_engine", None) is None:
        pc = llm.llm_engine.vllm_config.parallel_config
        if pc.tensor_parallel_size != 1:
            raise ValueError(
                "KV-pruning compactor path requires tensor_parallel_size=1 "
                "for shared weights."
            )
        llm._kvprune_compactor_engine = create_compactor_engine_with_shared_weights(llm)
        logger.info("Initialized compactor LLMEngine with weights shared from vLLM.")
    return llm._kvprune_compactor_engine


def try_compressed_generate(
    llm: Any,
    prompts: Any,
    sampling_params: SamplingParams | Sequence[SamplingParams] | None,
    *,
    compression: CompressionParams | Sequence[CompressionParams] | None,
    use_tqdm: bool | Callable[..., tqdm] = True,
    lora_request: Any = None,
    priority: list[int] | None = None,
    tokenization_kwargs: dict[str, Any] | None = None,
) -> list[RequestOutput] | None:
    """Return completions on the compactor engine, or ``None`` to use normal v1.

    ``lora_request`` / ``priority`` / ``tokenization_kwargs`` / ``use_tqdm`` are
    accepted for API parity with :meth:`~vllm.entrypoints.llm.LLM.generate` but
    are not passed to the compactor engine yet. A warning is logged when a
    ``lora_request`` is supplied and the compactor path is taken.

    With ``tensor_parallel_size > 1`` the request is dispatched through
    :func:`_compressed_generate_tp_collective`; otherwise an in-process
    shared-weight compactor engine is created (once) and used.

    Raises:
        ValueError: if ``sampling_params`` or ``compression`` is a sequence
            whose length differs from the number of prompts.
    """
    del priority, tokenization_kwargs, use_tqdm
    plist = _normalize_prompt_list(prompts)
    sps = _normalize_sampling_params(sampling_params, len(plist))
    comps = _normalize_compression_params(compression, len(plist))
    pc = llm.llm_engine.vllm_config.parallel_config
    if not _should_use_kvprune_compactor_path(comps):
        return None
    if lora_request is not None:
        # Not forwarded below; make the drop visible instead of silent.
        logger.warning(
            "KV-prune compression: lora_request is not supported on the "
            "compactor path and will be ignored."
        )
    v1_eager = bool(
        getattr(llm.llm_engine.vllm_config.model_config, "enforce_eager", False)
    )
    if not v1_eager:
        logger.warning(
            "KV-prune compression: v1 CUDA graphs are still enabled on this LLM. "
            "The compactor does not reuse v1 graphs; capture wastes VRAM. "
            "Set enforce_eager=True on LLM() if you need to avoid the extra "
            "v1 graph capture overhead for compressed generation."
        )
    if pc.tensor_parallel_size > 1:
        return _compressed_generate_tp_collective(llm, plist, sps, comps)
    ensure_inprocess_engine_for_weight_sharing()
    # Only attempt the KV release before the compactor engine is first built.
    if getattr(llm, "_kvprune_compactor_engine", None) is None:
        _maybe_release_v1_kv_for_compactor(llm)
    engine = _ensure_compactor_engine(llm)
    comp_sp = [_to_compactor_sampling(sp) for sp in sps]
    seq_c = [_to_sequence_compression(c) for c in comps]
    batch_c = _batch_compression_from_comps(comps)
    comp_in = [_prompt_to_compactor_input(p) for p in plist]
    _, seqs = engine.generate(
        comp_in,
        sampling_params=comp_sp,
        batch_compression_params=batch_c,
        per_sequence_compression_params=seq_c,
        return_sequences=True,
    )
    return _sequences_to_request_outputs(seqs, engine)


def _sequences_to_request_outputs(seqs: list[Any], engine: LLMEngine) -> list[RequestOutput]:
    """Convert compactor sequences into ``RequestOutput`` objects.

    Uses ``engine.tokenizer`` to decode each sequence's
    ``completion_token_ids``; request ids are ``kvprune-<index>``.
    """
    tok = engine.tokenizer
    return [
        _build_request_output(
            tok, f"kvprune-{i}", seq.prompt_token_ids, seq.completion_token_ids
        )
        for i, seq in enumerate(seqs)
    ]