"vllm/vscode:/vscode.git/clone" did not exist on "95befecc184778b28c5251d8a2699be8622b683f"
Commit 2b7160c6 authored by chenzk's avatar chenzk
Browse files

vllm kvprune:v1.0.0

parent fa718036
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Build :class:`vllm.kvprune.config.engine_config.LLMConfig` from :class:`VllmConfig`."""
from __future__ import annotations
import os
from pathlib import Path
from vllm.config import VllmConfig
from vllm.kvprune.config.engine_config import LLMConfig, KvpruneAttentionSchedule
from vllm.logger import init_logger
logger = init_logger(__name__)
def _attention_schedule_from_env() -> KvpruneAttentionSchedule:
    """Resolve :class:`KvpruneAttentionSchedule` from environment variables.

    Primary (``VLLM_KVPRUNE_ATTENTION_SCHEDULE``, case-insensitive, trimmed):

    - ``fa_triton`` — FA prefill, Triton decode (default). Aliases:
      ``fa_prefill``, ``default``.
    - ``pdtriton`` — Triton prefill + Triton decode. Aliases: ``triton``,
      ``triton_prefill``, ``compactor_prefill``, ``pd_triton``.
    - ``pdfa`` — FA prefill + FA decode (KV stores still Triton). Aliases:
      ``fa_full``, ``fa_both``.

    Legacy (``VLLM_KVPRUNE_ATTENTION_BACKEND``): consulted only when the primary
    variable is unset or blank. ``flash``/``fa``/``flash_attention``/
    ``flashattention`` → ``fa_triton``; ``compactor``/``triton``/
    ``compactor_triton`` → ``pdtriton``.

    Returns:
        The selected schedule. When neither variable is set, or when a value is
        not recognised (a warning is logged in that case), the result is
        ``FA_PREFILL_TRITON_DECODE``.
    """
    default_schedule = KvpruneAttentionSchedule.FA_PREFILL_TRITON_DECODE

    schedule = os.environ.get("VLLM_KVPRUNE_ATTENTION_SCHEDULE", "").strip().lower()
    if schedule:
        if schedule in ("fa_triton", "fa_prefill", "default"):
            return default_schedule
        if schedule in (
            "pdtriton",
            "pd_triton",
            "triton",
            "triton_prefill",
            "compactor_prefill",
        ):
            return KvpruneAttentionSchedule.TRITON_PREFILL_TRITON_DECODE
        if schedule in ("pdfa", "fa_full", "fa_both"):
            return KvpruneAttentionSchedule.PDFA
        logger.warning(
            "Unknown VLLM_KVPRUNE_ATTENTION_SCHEDULE=%r; using FA_PREFILL_TRITON_DECODE",
            schedule,
        )
        return default_schedule

    # The primary variable is unset/blank: only now does the legacy variable
    # get a say. Previously a blank primary returned immediately, which made
    # this legacy handling unreachable.
    backend = os.environ.get("VLLM_KVPRUNE_ATTENTION_BACKEND", "").strip().lower()
    if not backend:
        # Neither variable is set: keep the documented FA-prefill default.
        return default_schedule
    if backend in ("flash", "fa", "flash_attention", "flashattention"):
        return default_schedule
    if backend in ("compactor", "triton", "compactor_triton"):
        return KvpruneAttentionSchedule.TRITON_PREFILL_TRITON_DECODE
    logger.warning(
        "Unknown VLLM_KVPRUNE_ATTENTION_BACKEND=%r; using FA_PREFILL_TRITON_DECODE",
        backend,
    )
    return default_schedule
def _compactor_kvcache_page_size(vllm_block_size: int | None) -> int:
"""Tokens per physical KV page for compactor :class:`LLMConfig`.
vLLM ``block_size`` is often 16; compactor ``head_sparse_decode_attention`` requires
``PAGE_SIZE % 32 == 0`` (see ``kvprune/attention/sparse_decode_kernel.py``). Standalone
compactor-vllm defaults to 128. Round up to the next multiple of 32 when needed.
"""
if vllm_block_size is None:
return 128
bs = int(vllm_block_size)
if bs <= 0:
return 128
if bs % 32 == 0:
return bs
return ((bs + 31) // 32) * 32
def vllm_config_to_llm_config(vc: VllmConfig) -> LLMConfig:
    """Translate a vLLM engine config into the compactor :class:`LLMConfig`.

    Args:
        vc: The vLLM config; its model, cache, parallel and scheduler
            sub-configs are read.

    Returns:
        A new :class:`LLMConfig`. ``eos``/``eos_token_ids`` are left as
        placeholders (``-1`` / ``None``) and ``enforce_eager`` is always
        ``False`` here.
    """
    model_cfg = vc.model_config
    cache_cfg = vc.cache_config
    parallel_cfg = vc.parallel_config
    sched_cfg = vc.scheduler_config

    # An unset vLLM block size is treated as 16 before page-size rounding.
    vllm_block = cache_cfg.block_size
    tokens_per_block = 16 if vllm_block is None else vllm_block

    seq_limit = getattr(sched_cfg, "max_num_seqs", 256)

    # ``model_config.enforce_eager`` (v1) is deliberately **not** forwarded into
    # compactor :class:`LLMConfig`; the two flags are independent. v1 uses its
    # flag only to skip *v1* ``capture_model()``, whereas kvprune
    # :class:`~vllm.kvprune.core.model_runner.ModelRunner` reads
    # :attr:`LLMConfig.enforce_eager` only for *compactor* decode CUDA graphs.
    # Shared-weight setup in ``compactor_shared`` defaults compactor to eager
    # decode; see ``VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH`` (default try graphs) /
    # ``VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER``.

    # When the model string names a local checkpoint directory (one containing
    # ``config.json``), forward its resolved path so compactor skips redundant
    # Hub fetches.
    model_name = str(model_cfg.model)
    checkpoint_dir: str | None = None
    if model_name:
        try:
            candidate = Path(model_name)
            if candidate.is_dir() and (candidate / "config.json").is_file():
                checkpoint_dir = str(candidate.resolve())
        except OSError:
            # Best effort only: filesystem errors mean "no local path".
            pass

    return LLMConfig(
        model=model_name,
        path=checkpoint_dir,
        nccl_port=1218,
        max_num_seqs=seq_limit,
        max_model_len=model_cfg.max_model_len,
        gpu_memory_utilization=cache_cfg.gpu_memory_utilization,
        tensor_parallel_size=parallel_cfg.tensor_parallel_size,
        enforce_eager=False,
        hf_config=model_cfg.hf_config,
        eos=-1,
        eos_token_ids=None,
        kvcache_page_size=_compactor_kvcache_page_size(tokens_per_block),
        leverage_sketch_size=48,
        attention_schedule=_attention_schedule_from_env(),
        attention_backend=None,
    )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TP>1: one kvprune :class:`~vllm.kvprune.core.model_runner.ModelRunner` per vLLM worker.
Invoked via v1 ``collective_rpc("kvprune_v1_compressed_generate", ...)`` so every tensor-
parallel rank participates in the same compactor forward/broadcast sequence as the
standalone multi-process compactor.
Compactor decode CUDA graphs (when not ``enforce_eager``) capture the full decode step
including ``compute_logits``. To force eager on embedded TP workers, set
``VLLM_KVPRUNE_TP_EMBEDDED_GRAPH=0`` or ``VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER=1``.
Peer/master session boundaries use TP-group ``broadcast``/``barrier`` (see
``ModelRunner.maybe_release_peers``), not ``multiprocessing.Event`` — RPC payloads must
be picklable across worker processes.
"""
from __future__ import annotations
import os
from typing import Any
import torch
import torch.nn as nn
from vllm.kvprune.compression.compression_config import (
BatchCompressionParams,
SequenceCompressionParams,
)
from vllm.kvprune.config.sampling_params import SamplingParams as CompactorSamplingParams
from vllm.kvprune.core.compression_bridge import (
compression_method_id_to_enum,
compression_method_str_to_id,
)
from vllm.kvprune.core.model_runner import ModelRunner
from vllm.kvprune.integration.config_adapter import vllm_config_to_llm_config
from vllm.kvprune.utils.kv_dist import barrier_sync
from vllm.kvprune.integration.weight_tie import (
delegate_kvprune_compute_logits_to_vllm,
delegate_kvprune_embed_tokens_to_vllm,
tie_kvprune_rope_buffers_from_vllm,
tie_kvprune_weights_from_vllm,
)
from vllm.kvprune.models import MODEL_REGISTRY
from vllm.kvprune.utils.sequence import Sequence
# Name of the attribute on the worker object under which the kvprune
# ModelRunner is cached between calls.
_ATTR = "_kvprune_tp_embedded_runner"
def _apply_compactor_env_overrides(cfg: Any) -> None:
"""Match :func:`~vllm.kvprune.integration.compactor_shared.create_compactor_engine_with_shared_weights` caps."""
_cap = os.environ.get("VLLM_KVPRUNE_COMPACTOR_MAX_NUM_SEQS", "32").strip()
if _cap:
lim = int(_cap)
if lim > 0:
cfg.max_num_seqs = min(cfg.max_num_seqs, lim)
_ce = os.environ.get("VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER", "").strip().lower()
if _ce in ("1", "true", "yes"):
cfg.enforce_eager = True
elif _ce in ("0", "false", "no"):
cfg.enforce_eager = False
else:
_dg = os.environ.get("VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH", "1").strip().lower()
cfg.enforce_eager = _dg in ("0", "false", "no")
def _build_sequences(payload: dict[str, Any]) -> list[Sequence]:
    """Build one :class:`Sequence` per prompt described in ``payload``.

    Args:
        payload: Mapping with ``"prompt_token_ids"``, ``"sampling_params"`` and
            ``"compression_params"``; the latter two are indexed in step with
            the prompts.

    Returns:
        The sequences, in prompt order.
    """
    all_prompts: list[list[int]] = payload["prompt_token_ids"]
    sampling_dicts: list[dict[str, Any]] = payload["sampling_params"]
    compression_dicts: list[dict[str, Any]] = payload["compression_params"]

    def _make(index: int, token_ids: list[int]) -> Sequence:
        raw_sampling = sampling_dicts[index]
        sampling = CompactorSamplingParams(
            temperature=float(raw_sampling["temperature"]),
            max_new_tokens=int(raw_sampling["max_new_tokens"]),
        )
        raw_compression = compression_dicts[index]
        compression = SequenceCompressionParams(
            compression_ratio=float(raw_compression["compression_ratio"]),
            protected_first_tokens=int(raw_compression.get("protected_first_tokens", 16)),
            protected_last_tokens=int(raw_compression.get("protected_last_tokens", 64)),
        )
        # If the protected head and tail together span the whole prompt, the
        # ratio is reset to 1.0.
        protected_total = (
            compression.protected_first_tokens + compression.protected_last_tokens
        )
        if protected_total >= len(token_ids):
            compression.compression_ratio = 1.0
        return Sequence(
            prompt_token_ids=list(token_ids),
            sampling_params=sampling,
            compression_params=compression,
        )

    return [_make(index, token_ids) for index, token_ids in enumerate(all_prompts)]
def _batch_compression_from_payload(payload: dict[str, Any]) -> BatchCompressionParams:
    """Derive batch-level compression parameters from ``payload``.

    The first entry of ``payload["compression_params"]`` whose
    ``compression_ratio`` is below 1.0 selects the compression method (its
    ``"compression_method"`` key, defaulting to ``"none"``). If no entry
    qualifies, a default-constructed :class:`BatchCompressionParams` is returned.
    """
    first_compressing = next(
        (
            entry
            for entry in payload["compression_params"]
            if float(entry["compression_ratio"]) < 1.0
        ),
        None,
    )
    if first_compressing is None:
        return BatchCompressionParams()

    method_name = str(first_compressing.get("compression_method", "none"))
    method_id = compression_method_str_to_id(method_name)
    return BatchCompressionParams(
        compression_method=compression_method_id_to_enum(method_id)
    )
def _get_or_create_runner(worker: Any, payload: dict[str, Any]) -> ModelRunner:
    """Return the kvprune :class:`ModelRunner` cached on ``worker``, building it on first use.

    ``payload`` is only consulted when the runner is first created (for
    ``"eos_token_ids"``); later calls return the cached runner unchanged.

    Args:
        worker: vLLM worker exposing ``vllm_config`` and ``get_model()``.
        payload: RPC payload; must contain ``"eos_token_ids"`` on first use.

    Returns:
        The runner stored on ``worker``.

    Raises:
        NotImplementedError: If pipeline or data parallel size is not 1.
        RuntimeError: If the TP world size disagrees with the config.
        ValueError: If the HF ``model_type`` is not in ``MODEL_REGISTRY``.
    """
    cached = getattr(worker, _ATTR, None)
    if cached is not None:
        return cached

    from vllm.distributed.parallel_state import (
        get_tensor_model_parallel_rank,
        get_tensor_model_parallel_world_size,
    )

    vllm_cfg = worker.vllm_config
    parallel = vllm_cfg.parallel_config
    pp_size = parallel.pipeline_parallel_size
    dp_size = parallel.data_parallel_size
    if not (pp_size == 1 and dp_size == 1):
        raise NotImplementedError(
            "KV-prune TP compressed generate requires pipeline_parallel_size=1 and "
            f"data_parallel_size=1; got PP={pp_size}, "
            f"DP={dp_size}."
        )

    world_size = get_tensor_model_parallel_world_size()
    configured_tp = parallel.tensor_parallel_size
    if world_size != configured_tp:
        raise RuntimeError(
            f"parallel_state TP world size {world_size} != config.tensor_parallel_size "
            f"{configured_tp}"
        )

    hf_config = vllm_cfg.model_config.hf_config
    arch = getattr(hf_config, "model_type", None)
    if arch not in MODEL_REGISTRY:
        raise ValueError(
            f"KV-prune TP path: unsupported model_type={arch!r}; "
            f"registry has {sorted(MODEL_REGISTRY)}"
        )

    llm_cfg = vllm_config_to_llm_config(vllm_cfg)
    # De-duplicate and sort the EOS ids; the smallest one becomes ``eos``.
    llm_cfg.eos_token_ids = sorted({int(tok) for tok in payload["eos_token_ids"]})
    llm_cfg.eos = int(llm_cfg.eos_token_ids[0])
    _apply_compactor_env_overrides(llm_cfg)

    vllm_model = worker.get_model()
    kv_model: nn.Module = MODEL_REGISTRY[arch](hf_config)

    # Order matters below: alias parameters first, then move the kvprune model
    # onto the device/dtype of vLLM's first parameter, and only afterwards copy
    # the RoPE buffers and install the forward delegations.
    tie_kvprune_weights_from_vllm(vllm_model, kv_model)
    reference_param = next(vllm_model.parameters())
    kv_model.to(device=reference_param.device, dtype=reference_param.dtype)
    tie_kvprune_rope_buffers_from_vllm(vllm_model, kv_model)
    delegate_kvprune_embed_tokens_to_vllm(vllm_model, kv_model)
    delegate_kvprune_compute_logits_to_vllm(vllm_model, kv_model)

    tp_rank = get_tensor_model_parallel_rank()
    device = torch.device(f"cuda:{torch.cuda.current_device()}")

    # Rank 0 is constructed with an empty ``peer_events`` list; every other
    # rank is constructed with ``batch_ready=None``.
    if tp_rank == 0:
        role_kwargs: dict[str, Any] = {"rank": 0, "peer_events": []}
    else:
        role_kwargs = {"rank": tp_rank, "batch_ready": None}

    runner = ModelRunner(
        llm_cfg,
        external_model=kv_model,
        embedded_in_vllm_worker=True,
        device=device,
        **role_kwargs,
    )
    setattr(worker, _ATTR, runner)
    return runner
def run_kvprune_tp_compressed_generate(worker: Any, payload: dict[str, Any]) -> dict[str, Any]:
    """Execute one compressed generation session on this worker (all TP ranks).

    Every rank obtains the runner, builds the sequences and batch compression
    parameters, and then synchronises on the TP group. Rank 0 drives
    ``runner.generate``; all other ranks run ``runner.run_peer_session``.

    Returns:
        On rank 0: the rank plus the prompt and completion token ids of every
        sequence. On other ranks: the rank and ``"ok": True``.
    """
    from vllm.distributed.parallel_state import get_tensor_model_parallel_rank

    tp_rank = get_tensor_model_parallel_rank()
    runner = _get_or_create_runner(worker, payload)
    sequences = _build_sequences(payload)
    batch_params = _batch_compression_from_payload(payload)

    barrier_sync(use_tp_group=True)

    if tp_rank != 0:
        runner.run_peer_session()
        return {"tensor_parallel_rank": int(tp_rank), "ok": True}

    runner.generate(sequences, batch_params)
    prompts = [list(seq.prompt_token_ids) for seq in sequences]
    completions = [list(seq.completion_token_ids) for seq in sequences]
    return {
        "tensor_parallel_rank": 0,
        "prompt_token_ids": prompts,
        "completion_token_ids": completions,
    }
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Access the in-process vLLM model weights for compactor weight sharing."""
from __future__ import annotations
import torch.nn as nn
from vllm.logger import init_logger
logger = init_logger(__name__)
def extract_vllm_causal_lm(llm: object) -> nn.Module:
    """Return the root ``nn.Module`` holding transformer + lm_head from a v1 ``LLM``.

    Walks ``llm.llm_engine.model_executor.driver_worker.worker`` and calls its
    ``get_model()``. Requires ``LLMEngine`` to have been constructed with
    ``multiprocess_mode=False`` so ``model_executor`` lives in-process (set
    ``VLLM_ENABLE_V1_MULTIPROCESSING=0``).

    Raises:
        RuntimeError: If any attribute along the chain is missing or ``None``,
            or if the worker has no callable ``get_model``.
    """
    # (attribute to follow, error message when it is missing or None)
    hops = (
        (
            "llm_engine",
            "Expected an object with a ``llm_engine`` attribute (e.g. ``vllm.LLM``).",
        ),
        (
            "model_executor",
            "model_executor is unavailable (multiprocess engine mode). "
            "Set environment variable VLLM_ENABLE_V1_MULTIPROCESSING=0 for "
            "in-process weight sharing.",
        ),
        (
            "driver_worker",
            "Executor has no driver_worker (unexpected executor type for weight sharing).",
        ),
        (
            "worker",
            "Worker wrapper has no worker loaded.",
        ),
    )

    node = llm
    for attr_name, failure_message in hops:
        node = getattr(node, attr_name, None)
        if node is None:
            raise RuntimeError(failure_message)

    model_getter = getattr(node, "get_model", None)
    if not callable(model_getter):
        raise RuntimeError("Worker does not expose get_model().")
    return model_getter()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Share vLLM parameter storage with compactor ``MODEL_REGISTRY`` models (TP=1)."""
from __future__ import annotations
import types
import torch
import torch.nn as nn
from vllm.kvprune.utils.context import get_context
from vllm.logger import init_logger
logger = init_logger(__name__)
def tie_kvprune_weights_from_vllm(
    vllm_model: nn.Module,
    kvprune_model: nn.Module,
    *,
    strict: bool = True,
) -> int:
    """Alias compactor parameters onto vLLM's tensors wherever names match.

    For each kvprune parameter whose name also exists in ``vllm_model``, the
    kvprune parameter's ``.data`` is rebound to the vLLM parameter's ``.data``.

    Args:
        vllm_model: Model returned by ``worker.get_model()`` (e.g. ``Qwen3ForCausalLM``).
        kvprune_model: Instance from ``vllm.kvprune.models.MODEL_REGISTRY``.
        strict: If True, raise when a kvprune parameter name is missing from
            ``vllm_model``; if False, such parameters are skipped.

    Returns:
        The number of parameters tied.

    Raises:
        ValueError: On a missing name (``strict`` only), on a shape mismatch
            (regardless of ``strict``), or when nothing at all was tied.
    """
    vllm_params = dict(vllm_model.named_parameters())
    tied_count = 0

    for param_name, target in kvprune_model.named_parameters():
        if param_name not in vllm_params:
            if not strict:
                continue
            raise ValueError(
                f"kvprune parameter {param_name!r} not found in vLLM model; "
                "architecture/layout may differ (disable strict tying only "
                "for expert debugging)."
            )

        source = vllm_params[param_name]
        if source.shape != target.shape:
            raise ValueError(
                f"Shape mismatch for {param_name}: vllm {source.shape} vs kvprune {target.shape}"
            )

        # Rebind the data tensor rather than copying it.
        target.data = source.data
        tied_count += 1

    if tied_count == 0:
        raise ValueError(
            "No parameters were tied — check that vLLM and kvprune model types match "
            "and use the same state_dict names."
        )

    logger.info(
        "Tied %d parameters from vLLM into compactor model (shared storage).",
        tied_count,
    )
    return tied_count
def tie_kvprune_rope_buffers_from_vllm(
    vllm_model: nn.Module,
    kvprune_model: nn.Module,
) -> int:
    """Copy RoPE ``cos_sin_cache`` buffers from vLLM into kvprune.

    :func:`tie_kvprune_weights_from_vllm` only aliases :class:`~torch.nn.Parameter`
    tensors, whereas RoPE tables live in buffers. kvprune's simplified
    ``RotaryEmbedding`` can disagree with vLLM's ``rope_parameters`` (YaRN, etc.),
    so copying ``cos_sin_cache`` after ``.to(device, dtype)`` keeps Q/K rotation
    aligned with the main model.

    Two layouts are accepted: identical shapes, or kvprune
    ``[max_len, 1, rotary_dim]`` against vLLM ``[max_len, rotary_dim]`` (the
    vLLM tensor is ``unsqueeze(1)``-ed before copying).

    Returns:
        The number of buffers copied. kvprune buffers with no same-named vLLM
        buffer are skipped with a warning.

    Raises:
        ValueError: If a matching buffer pair has an incompatible layout.
    """
    vllm_buffers = dict(vllm_model.named_buffers())
    copied_count = 0

    for buffer_name, target in kvprune_model.named_buffers():
        if "cos_sin_cache" not in buffer_name:
            continue
        if buffer_name not in vllm_buffers:
            logger.warning(
                "kvprune RoPE buffer %r not found in vLLM; leaving kvprune cache",
                buffer_name,
            )
            continue

        source = vllm_buffers[buffer_name]
        if source.shape == target.shape:
            target.copy_(source)
        elif target.dim() == 3 and source.dim() == 2:
            # kvprune carries an extra singleton middle dimension.
            expected_shape = (source.shape[0], 1, source.shape[1])
            if tuple(target.shape) != expected_shape:
                raise ValueError(
                    f"cos_sin_cache shape mismatch for {buffer_name!r}: "
                    f"vLLM {tuple(source.shape)} vs kvprune {tuple(target.shape)}"
                )
            target.copy_(source.unsqueeze(1))
        else:
            raise ValueError(
                f"Unsupported cos_sin_cache layout for {buffer_name!r}: "
                f"vLLM {tuple(source.shape)} vs kvprune {tuple(target.shape)}"
            )
        copied_count += 1

    if copied_count:
        logger.info(
            "Copied %d RoPE cos_sin_cache buffer(s) from vLLM into kvprune model.",
            copied_count,
        )
    return copied_count
def delegate_kvprune_embed_tokens_to_vllm(
    vllm_model: nn.Module,
    kvprune_model: nn.Module,
) -> bool:
    """Route kvprune's ``model.embed_tokens`` forward through vLLM's embedding module.

    Even with tied weights, kvprune's simplified contiguous
    ``VocabParallelEmbedding`` (``vocab_start = rank * partition``) can disagree
    with vLLM's padded vocabulary and org/added shard ranges, producing invalid
    indices for ``F.embedding`` on non-zero TP ranks (``index_copy_`` /
    device-side assert). Delegating the forward keeps masks and indices aligned
    with the main model while parameters remain shared storage.

    Returns:
        True when the delegation was installed. False when either model lacks a
        ``model`` attribute (silently) or an ``embed_tokens`` submodule (with a
        warning).
    """
    if not (hasattr(vllm_model, "model") and hasattr(kvprune_model, "model")):
        return False

    vllm_embed = getattr(vllm_model.model, "embed_tokens", None)
    kv_embed = getattr(kvprune_model.model, "embed_tokens", None)
    if vllm_embed is None or kv_embed is None:
        logger.warning(
            "delegate_kvprune_embed_tokens_to_vllm: embed_tokens missing; skipped"
        )
        return False

    def _forward(_self_unused: nn.Module, x):
        # The bound instance is ignored; vLLM's module does the lookup.
        return vllm_embed(x)

    kv_embed.forward = types.MethodType(_forward, kv_embed)
    logger.info(
        "kvprune model.embed_tokens forward delegated to vLLM (correct vocab-parallel masks)."
    )
    return True
def delegate_kvprune_compute_logits_to_vllm(
    vllm_model: nn.Module,
    kvprune_model: nn.Module,
) -> bool:
    """Route ``kvprune_model.compute_logits`` through vLLM's ``compute_logits``.

    Standalone compactor used :class:`~vllm.kvprune.layers.embed_head.ParallelLMHead`
    with ``F.linear`` + TP gather, while vLLM applies
    :class:`~vllm.model_executor.layers.logits_processor.LogitsProcessor`
    (gather/all-gather, padded-vocab trim, quant hooks). A mismatch here commonly
    produces garbage token distributions while the rest of the stack looks fine.
    After weight tying, ``vllm_model.compute_logits(hidden)`` uses the same
    lm_head storage as kvprune; only the *application* path matches production vLLM.

    Returns:
        True when the replacement method was installed; False (with a warning)
        when ``vllm_model`` has no callable ``compute_logits``.
    """
    if not callable(getattr(vllm_model, "compute_logits", None)):
        logger.warning(
            "delegate_kvprune_compute_logits_to_vllm: vLLM model has no compute_logits; skipped"
        )
        return False

    def _compute_logits(_self: nn.Module, hidden_states):
        # Mirror kvprune :class:`~vllm.kvprune.layers.embed_head.ParallelLMHead`:
        # during prefill, logits are produced for the **last** token of each
        # packed sequence only.
        ctx = get_context()
        if ctx.is_prefill and ctx.cu_seqlens_q is not None:
            last_rows = (ctx.cu_seqlens_q[1:] - 1).to(torch.long)
            token_count = hidden_states.shape[0]
            if token_count > 0:
                last_rows = last_rows.clamp(min=0, max=token_count - 1)
            hidden_states = hidden_states[last_rows].contiguous()
        # vLLM lm_head + gather expect contiguous activations; non-contiguous
        # views have caused garbage logits under TP in edge cases.
        return vllm_model.compute_logits(hidden_states.contiguous())

    kvprune_model.compute_logits = types.MethodType(_compute_logits, kvprune_model)
    return True
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Paged KV cache helpers and Triton KV store."""
from vllm.kvprune.kv_cache.store_kv_cache import (
decode_store_kv,
prefill_store_all_kv,
prefill_store_topk_kv,
)
# Names re-exported from ``store_kv_cache`` as this package's public API.
__all__ = [
    "decode_store_kv",
    "prefill_store_all_kv",
    "prefill_store_topk_kv",
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment