"benchmarks/vscode:/vscode.git/clone" did not exist on "7e63ef827a6da01d510694777dba5ea5712af837"
Commit d29c39ca authored by chenzk's avatar chenzk
Browse files

vllm kvprune wo:v1.1.0

parent f81ce56b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""KV-pruning integration: compactor ``LLMEngine`` sharing weights with :class:`~vllm.LLM`."""
from vllm.kvprune.integration.compression_params import CompressionParams
__all__ = ["CompressionParams"]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Construct compactor :class:`LLMEngine` sharing weight tensors with an in-process vLLM ``LLM``."""
from __future__ import annotations
import os
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.kvprune.config.engine_config import LLMConfig
from vllm.kvprune.core.llm_engine import LLMEngine
from vllm.kvprune.integration.config_adapter import vllm_config_to_llm_config
from vllm.kvprune.integration.vllm_model_access import extract_vllm_causal_lm
from vllm.kvprune.integration.weight_tie import (
delegate_kvprune_compute_logits_to_vllm,
delegate_kvprune_embed_tokens_to_vllm,
tie_kvprune_rope_buffers_from_vllm,
tie_kvprune_weights_from_vllm,
)
from vllm.kvprune.models import MODEL_REGISTRY
from vllm.logger import init_logger
logger = init_logger(__name__)
def build_llm_config_for_compactor(vc: VllmConfig) -> LLMConfig:
"""Public helper: vLLM config → compactor :class:`LLMConfig`."""
return vllm_config_to_llm_config(vc)
def create_compactor_engine_with_shared_weights(llm: object) -> LLMEngine:
"""Single GPU, TP=1: compactor ``LLMEngine`` whose weights alias vLLM tensors.
Call after the vLLM ``LLM`` has loaded weights. Requires in-process executor
(``VLLM_ENABLE_V1_MULTIPROCESSING=0``).
"""
llm_engine = getattr(llm, "llm_engine", None)
if llm_engine is None:
raise RuntimeError("Expected ``llm.llm_engine``.")
vc: VllmConfig = llm_engine.vllm_config
if vc.parallel_config.tensor_parallel_size != 1:
raise ValueError(
"Shared-weight compactor backend requires tensor_parallel_size=1"
)
cfg = vllm_config_to_llm_config(vc)
# ``cfg.enforce_eager`` is for the compactor ``ModelRunner`` only (decode CUDA
# graphs), not v1. v1 graph capture is controlled solely by ``LLM(...,
# enforce_eager=...)`` / ``kvprune_compression=True`` on the entrypoint ``LLM``.
# Large vLLM max_num_seqs blows up compactor page-table GPU memory; sharing the GPU
# with v1 leaves little room for metadata + KV tensors. Default cap 32 so physical
# KV pages stay usable; set VLLM_KVPRUNE_COMPACTOR_MAX_NUM_SEQS=0 to disable cap,
# or raise (e.g. 128) if you have VRAM headroom.
_cap = os.environ.get("VLLM_KVPRUNE_COMPACTOR_MAX_NUM_SEQS", "32").strip()
if _cap:
lim = int(_cap)
if lim > 0:
cfg.max_num_seqs = min(cfg.max_num_seqs, lim)
# Compactor decode graphs (``enforce_eager=False``): honored for non-shared-weight
# engines. **Shared-weight** path (below) forces ``enforce_eager=True`` after
# delegating ``compute_logits`` to vLLM unless ``VLLM_KVPRUNE_SHARED_WEIGHT_GRAPH=1``.
# Opt out of graphs for non-shared runs: ``VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER=1`` or
# ``VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH=0``.
_ce = os.environ.get("VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER", "").strip().lower()
if _ce in ("1", "true", "yes"):
cfg.enforce_eager = True
logger.info(
"KV-prune compactor: VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER=1 → "
"enforce_eager=True (skip compactor decode CUDA graphs)."
)
elif _ce in ("0", "false", "no"):
cfg.enforce_eager = False
logger.info(
"KV-prune compactor: VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER=0 → "
"enforce_eager=False (try compactor CUDA graph capture)."
)
else:
_dg = os.environ.get(
"VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH", "1"
).strip().lower()
if _dg in ("0", "false", "no"):
cfg.enforce_eager = True
logger.info(
"KV-prune compactor: VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH=0 → "
"enforce_eager=True (skip compactor decode CUDA graphs)."
)
else:
cfg.enforce_eager = False
logger.info(
"KV-prune compactor: default try decode CUDA graphs; ModelRunner "
"falls back to eager if capture yields none. Set "
"VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER=1 or "
"VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH=0 to skip capture."
)
hf = cfg.hf_config
assert hf is not None
model_type = hf.model_type
if model_type not in MODEL_REGISTRY:
raise ValueError(
f"Compactor MODEL_REGISTRY has no entry for model_type={model_type!r}; "
f"supported: {sorted(MODEL_REGISTRY)}"
)
vllm_model = extract_vllm_causal_lm(llm)
device = next(vllm_model.parameters()).device
dtype = next(vllm_model.parameters()).dtype
# Build compactor shell on CPU first. **Do not** call ``.to(device)`` before tying:
# that allocates a full second copy of weights on GPU; tying then frees the
# duplicate but peak memory can OOM on large models. Tie first so parameters
# alias vLLM tensors directly (no extra weight VRAM).
kv_model: nn.Module = MODEL_REGISTRY[model_type](hf)
tie_kvprune_weights_from_vllm(vllm_model, kv_model)
# Buffers (e.g. RoPE tables) not in ``named_parameters`` may still be on CPU.
kv_model.to(device=device, dtype=dtype)
tie_kvprune_rope_buffers_from_vllm(vllm_model, kv_model)
delegate_kvprune_embed_tokens_to_vllm(vllm_model, kv_model)
delegate_kvprune_compute_logits_to_vllm(vllm_model, kv_model)
# Compactor decode CUDA graphs capture ``model.forward`` + ``compute_logits`` in one
# graph. Here ``compute_logits`` is delegated to vLLM's LM head / LogitsProcessor
# (cublas GEMM, padded vocab, etc.). Embedding that in a nested capture commonly
# fails with ``CUBLAS_STATUS_EXECUTION_FAILED`` and invalidates stream capture
# (``cudaErrorStreamCaptureInvalidated``). Default: skip graphs for this integration.
_sw_graph = os.environ.get(
"VLLM_KVPRUNE_SHARED_WEIGHT_GRAPH", "0"
).strip().lower() in ("1", "true", "yes")
if not _sw_graph:
cfg.enforce_eager = True
logger.info(
"KV-prune shared-weight compactor: enforce_eager=True (skip compactor "
"decode CUDA graphs; logits delegated to vLLM). Set "
"VLLM_KVPRUNE_SHARED_WEIGHT_GRAPH=1 only to attempt capture (often fails)."
)
return LLMEngine(cfg, external_model=kv_model)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""KV-pruning (compactor) path invoked from :meth:`vllm.entrypoints.llm.LLM.generate`."""
from __future__ import annotations
import os
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import Any
from tqdm.auto import tqdm
from transformers import AutoTokenizer
from vllm.kvprune.compression.compression_config import (
BatchCompressionParams,
SequenceCompressionParams,
)
from vllm.kvprune.config.sampling_params import SamplingParams as CompactorSamplingParams
from vllm.kvprune.core.compression_bridge import (
compression_method_id_to_enum,
compression_method_str_to_id,
)
from vllm.kvprune.core.llm_engine import LLMEngine, _infer_stop_token_ids
from vllm.kvprune.integration.compactor_shared import create_compactor_engine_with_shared_weights
from vllm.kvprune.integration.compression_params import CompressionParams
from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams
logger = init_logger(__name__)
_MP_ENV = "VLLM_ENABLE_V1_MULTIPROCESSING"
_RELEASE_V1_KV_ENV = "VLLM_KVPRUNE_RELEASE_V1_KV"
def _maybe_release_v1_kv_for_compactor(llm: Any) -> None:
"""Optionally discard v1's KV cache so more GPU memory is free for compactor.
v1 reserves KV blocks at engine init; shared-weight compactor then competes for
the same VRAM. ``sleep(level=1)`` discards v1 KV and may offload tagged weights
per v1 sleep policy, then ``wake_up()`` reloads — compactor still ties the same
v1 tensors after.
**Default:** ``vllm.env_override`` sets ``VLLM_KVPRUNE_RELEASE_V1_KV=0`` (no
sleep/wake; v1 KV stays on GPU). Set ``=1`` if you need extra VRAM for compactor
before the first compressed step (then ``llm.sleep`` / ``CuMemAllocator`` /
``Sleep mode freed …`` logs are expected). This does **not** remove v1's KV
reservation at init; it only runs the optional sleep/wake cycle before compactor.
Tests keep ``VLLM_KVPRUNE_RELEASE_V1_KV=0`` in ``conftest``.
"""
if os.environ.get(_RELEASE_V1_KV_ENV, "0").strip().lower() not in (
"1",
"true",
"yes",
):
return
try:
logger.info(
"%s=1: discarding v1 KV via sleep(level=1) then wake_up() "
"(reloads model weights to GPU).",
_RELEASE_V1_KV_ENV,
)
llm.sleep(level=1, mode="abort")
llm.wake_up()
except Exception as e:
logger.warning("%s: sleep/wake failed: %s", _RELEASE_V1_KV_ENV, e)
def ensure_inprocess_engine_for_weight_sharing() -> None:
"""Compactor must see ``worker.get_model()`` in the same process as vLLM."""
if os.environ.get(_MP_ENV, "1") != "0":
os.environ[_MP_ENV] = "0"
logger.info(
"KV cache pruning: set %s=0 so the model stays in-process for "
"shared-weight compactor (no manual env needed).",
_MP_ENV,
)
def _normalize_prompt_list(prompts: Any) -> list[Any]:
if isinstance(prompts, str):
return [prompts]
if isinstance(prompts, dict):
return [prompts]
return list(prompts)
def _normalize_sampling_params(
sampling_params: SamplingParams | Sequence[SamplingParams] | None,
n: int,
) -> list[SamplingParams]:
if sampling_params is None:
return [SamplingParams() for _ in range(n)]
if isinstance(sampling_params, SamplingParams):
return [sampling_params] * n
sps = list(sampling_params)
if len(sps) != n:
raise ValueError(
f"sampling_params length {len(sps)} != prompts length {n}"
)
return sps
def _normalize_compression_params(
compression: CompressionParams | Sequence[CompressionParams] | None,
n: int,
) -> list[CompressionParams]:
if compression is None:
return [CompressionParams(compression_ratio=1.0) for _ in range(n)]
if isinstance(compression, CompressionParams):
return [compression] * n
comp = list(compression)
if len(comp) != n:
raise ValueError(f"compression length {len(comp)} != prompts length {n}")
return comp
def _any_compactor(comps: list[CompressionParams]) -> bool:
return any(c.compression_ratio < 1.0 for c in comps)
_FORCE_COMPACTOR_PATH_ENV = "VLLM_KVPRUNE_FORCE_COMPACTOR_PATH"
def _should_use_kvprune_compactor_path(comps: list[CompressionParams]) -> bool:
"""Use integrated compactor when any prompt requests compression, or when forced.
If all ``compression_ratio >= 1.0``, the default is to return ``None`` from
:func:`try_compressed_generate` and fall back to the standard v1 engine
(``Processed prompts`` loop). That hides TP/kvprune bugs behind a different
code path. Set ``VLLM_KVPRUNE_FORCE_COMPACTOR_PATH=1`` to run the same
compactor + collective RPC path as compression-on, with no KV pruning.
"""
if _any_compactor(comps):
return True
return os.environ.get(_FORCE_COMPACTOR_PATH_ENV, "").strip().lower() in (
"1",
"true",
"yes",
)
def _to_compactor_sampling(sp: SamplingParams) -> CompactorSamplingParams:
mt = sp.max_tokens
if mt is None:
mt = 16
return CompactorSamplingParams(
temperature=float(sp.temperature),
max_new_tokens=int(mt),
)
def _to_sequence_compression(cp: CompressionParams) -> SequenceCompressionParams:
return SequenceCompressionParams(
compression_ratio=float(cp.compression_ratio),
protected_first_tokens=int(cp.protected_first_tokens),
protected_last_tokens=int(cp.protected_last_tokens),
)
def _batch_compression_from_comps(comps: list[CompressionParams]) -> BatchCompressionParams:
for c in comps:
if c.compression_ratio < 1.0:
mid = compression_method_str_to_id(c.compression_method)
return BatchCompressionParams(
compression_method=compression_method_id_to_enum(mid)
)
return BatchCompressionParams()
def _kvprune_compactor_hf_tokenizer(llm: Any):
"""HF tokenizer matching :meth:`vllm.kvprune.core.llm_engine.LLMEngine.__init__`.
Loads from the **resolved on-disk** model tree (local dir or HF cache snapshot), not
the bare repo id, to avoid redundant Hub downloads.
"""
cached = getattr(llm, "_kvprune_compactor_hf_tokenizer", None)
if cached is not None:
return cached
mc = llm.llm_engine.vllm_config.model_config
model_s = str(mc.model)
src = model_s
try:
p = Path(model_s)
if p.is_dir() and (p / "config.json").is_file():
src = str(p.resolve())
else:
from huggingface_hub import snapshot_download
src = snapshot_download(repo_id=model_s, local_files_only=False)
except Exception:
src = model_s
hf_cfg = mc.hf_config
_trust = bool(getattr(hf_cfg, "trust_remote_code", False)) if hf_cfg is not None else False
tok = AutoTokenizer.from_pretrained(src, use_fast=True, trust_remote_code=_trust)
llm._kvprune_compactor_hf_tokenizer = tok
return tok
def _prompt_to_compactor_input(prompt: Any) -> str | list[int]:
if isinstance(prompt, str):
return prompt
# Decoder-only `list[int]` token ids (see `vllm.inputs.PromptType`).
if isinstance(prompt, list):
if not prompt:
raise TypeError("Empty token-id prompt is not supported for compactor path.")
if all(isinstance(t, int) for t in prompt):
return list(prompt)
if isinstance(prompt, dict):
if "prompt_token_ids" in prompt:
ids = prompt["prompt_token_ids"]
return list(ids) if not isinstance(ids, list) else ids
p = prompt.get("prompt")
if isinstance(p, str):
return p
raise TypeError(
f"Unsupported prompt type for compactor path: {type(prompt)}. "
"Use str, list[int] token ids, or dict with 'prompt_token_ids' or 'prompt'."
)
def _prompt_to_token_ids_for_tp(llm: Any, prompt: Any) -> list[int]:
"""Driver-side token ids for the TP collective path (same tokenizer as vLLM ``LLM``)."""
comp_in = _prompt_to_compactor_input(prompt)
if isinstance(comp_in, str):
return llm.get_tokenizer().encode(comp_in)
return list(comp_in)
def _compressed_generate_tp_collective(
llm: Any,
plist: list[Any],
sps: list[SamplingParams],
comps: list[CompressionParams],
) -> list[RequestOutput]:
"""TP>1: run compactor on each worker via ``collective_rpc`` (all ranks)."""
vc = llm.llm_engine.vllm_config
pc = vc.parallel_config
if pc.pipeline_parallel_size != 1 or pc.data_parallel_size != 1:
raise NotImplementedError(
"KV-prune TP compression requires pipeline_parallel_size=1 and "
f"data_parallel_size=1 (got PP={pc.pipeline_parallel_size}, "
f"DP={pc.data_parallel_size})."
)
hf = vc.model_config.hf_config
tok = llm.get_tokenizer()
eos_token_ids = _infer_stop_token_ids(tok, hf)
prompt_token_ids = [_prompt_to_token_ids_for_tp(llm, p) for p in plist]
max_len = int(vc.model_config.max_model_len)
for i, ids in enumerate(prompt_token_ids):
if len(ids) > max_len:
raise ValueError(
f"KV-prune TP compressed generate: prompt {i} length {len(ids)} "
f"exceeds max_model_len ({max_len}). Shorten the prompt or raise "
"max_model_len when constructing LLM()."
)
# Payload must be picklable for multiproc/Ray RPC: do not pass multiprocessing
# synchronization primitives (workers are separate processes).
payload: dict[str, Any] = {
"eos_token_ids": eos_token_ids,
"prompt_token_ids": prompt_token_ids,
"sampling_params": [
{
"temperature": float(sp.temperature),
"max_new_tokens": int(sp.max_tokens if sp.max_tokens is not None else 16),
}
for sp in sps
],
"compression_params": [
{
"compression_ratio": float(c.compression_ratio),
"compression_method": str(c.compression_method),
"protected_first_tokens": int(c.protected_first_tokens),
"protected_last_tokens": int(c.protected_last_tokens),
}
for c in comps
],
}
_maybe_release_v1_kv_for_compactor(llm)
try:
results = llm.llm_engine.collective_rpc(
"kvprune_v1_compressed_generate",
args=(payload,),
)
except RuntimeError as e:
if "cancelled" in str(e).lower():
raise RuntimeError(
"collective_rpc was cancelled (a GPU worker likely crashed). "
"Scroll up for the first worker traceback — often NCCL/CUDA before "
"TCPStore/Broken pipe on the driver."
) from e
raise
master: dict[str, Any] | None = None
for r in results:
if isinstance(r, dict) and r.get("tensor_parallel_rank") == 0:
master = r
break
if master is None:
raise RuntimeError(
"collective_rpc did not return a dict from tensor parallel rank 0."
)
return _tp_payload_to_request_outputs(llm, master)
def _tp_payload_to_request_outputs(llm: Any, master: dict[str, Any]) -> list[RequestOutput]:
tok = llm.get_tokenizer()
out: list[RequestOutput] = []
pids_list = master["prompt_token_ids"]
cids_list = master["completion_token_ids"]
for i, (pids, cids) in enumerate(zip(pids_list, cids_list)):
text = tok.decode(cids, skip_special_tokens=True)
co = CompletionOutput(
index=0,
text=text,
token_ids=list(cids),
cumulative_logprob=None,
logprobs=None,
finish_reason="stop",
)
ro = RequestOutput(
request_id=f"kvprune-tp-{i}",
prompt=None,
prompt_token_ids=list(pids),
prompt_logprobs=None,
outputs=[co],
finished=True,
)
out.append(ro)
return out
def _ensure_compactor_engine(llm: Any) -> LLMEngine:
if llm._kvprune_compactor_engine is None:
pc = llm.llm_engine.vllm_config.parallel_config
if pc.tensor_parallel_size != 1:
raise ValueError(
"KV-pruning compactor path requires tensor_parallel_size=1 "
"for shared weights."
)
llm._kvprune_compactor_engine = create_compactor_engine_with_shared_weights(llm)
logger.info("Initialized compactor LLMEngine with weights shared from vLLM.")
return llm._kvprune_compactor_engine
def try_compressed_generate(
llm: Any,
prompts: Any,
sampling_params: SamplingParams | Sequence[SamplingParams] | None,
*,
compression: CompressionParams | Sequence[CompressionParams] | None,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: Any = None,
priority: list[int] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
) -> list[RequestOutput] | None:
"""Return completions on the compactor engine, or ``None`` to use normal v1.
``lora_request`` / ``priority`` / ``tokenization_kwargs`` are accepted for API
parity with :meth:`~vllm.entrypoints.llm.LLM.generate` but are not passed to the
compactor engine yet.
"""
del lora_request, priority, tokenization_kwargs, use_tqdm
plist = _normalize_prompt_list(prompts)
sps = _normalize_sampling_params(sampling_params, len(plist))
comps = _normalize_compression_params(compression, len(plist))
pc = llm.llm_engine.vllm_config.parallel_config
# TP>1: every worker must run the same collective_rpc session. If all
# compression_ratio >= 1, the old code returned None and only the driver ran
# v1 _run_engine — other ranks never joined a matching collective, which can
# deadlock NCCL / leave workers unsynchronized (hang at "Processed prompts:").
if pc.tensor_parallel_size > 1:
if not _should_use_kvprune_compactor_path(comps):
comps = [CompressionParams(compression_ratio=1.0) for _ in plist]
elif not _should_use_kvprune_compactor_path(comps):
return None
v1_eager = bool(
getattr(llm.llm_engine.vllm_config.model_config, "enforce_eager", False)
)
if not v1_eager:
logger.warning(
"KV-prune compression: v1 CUDA graphs are still enabled on this LLM. "
"The compactor does not reuse v1 graphs; capture wastes VRAM. "
"Set kvprune_compression=True, enforce_eager=True, or "
"VLLM_KVPRUNE_COMPRESSION_DEFAULT=1 before import vllm."
)
if pc.tensor_parallel_size > 1:
return _compressed_generate_tp_collective(llm, plist, sps, comps)
ensure_inprocess_engine_for_weight_sharing()
if llm._kvprune_compactor_engine is None:
_maybe_release_v1_kv_for_compactor(llm)
engine = _ensure_compactor_engine(llm)
comp_sp = [_to_compactor_sampling(sp) for sp in sps]
seq_c = [_to_sequence_compression(c) for c in comps]
batch_c = _batch_compression_from_comps(comps)
comp_in = [_prompt_to_compactor_input(p) for p in plist]
_, seqs = engine.generate(
comp_in,
sampling_params=comp_sp,
batch_compression_params=batch_c,
per_sequence_compression_params=seq_c,
return_sequences=True,
)
return _sequences_to_request_outputs(seqs, engine)
def _sequences_to_request_outputs(seqs: list[Any], engine: LLMEngine) -> list[RequestOutput]:
tok = engine.tokenizer
out: list[RequestOutput] = []
for i, seq in enumerate(seqs):
text = tok.decode(seq.completion_token_ids, skip_special_tokens=True)
# If every emitted id is “special” (e.g. EOS / chat boundary), the stripped
# string is empty while ``completion_token_ids`` is non-empty — avoid
# presenting a blank answer so users can see boundary tokens / debug.
if not text.strip() and seq.completion_token_ids:
text = tok.decode(seq.completion_token_ids, skip_special_tokens=False)
co = CompletionOutput(
index=0,
text=text,
token_ids=list(seq.completion_token_ids),
cumulative_logprob=None,
logprobs=None,
finish_reason="stop",
)
ro = RequestOutput(
request_id=f"kvprune-{i}",
prompt=None,
prompt_token_ids=list(seq.prompt_token_ids),
prompt_logprobs=None,
outputs=[co],
finished=True,
)
out.append(ro)
return out
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Per-request KV compression for :meth:`vllm.LLM.generate` (``compression=`` kwarg)."""
from __future__ import annotations
from dataclasses import dataclass
@dataclass
class CompressionParams:
"""Per-prompt compression intent for :meth:`vllm.LLM.generate`.
If **any** prompt in the batch has ``compression_ratio < 1.0``, the **whole** batch
is run on the compactor ``LLMEngine`` (same stack as standalone compactor-vllm:
``PagedKVCache`` + pruning kernels). If all prompts have ``compression_ratio >= 1.0``,
the batch stays on standard vLLM.
``compression_method`` follows :mod:`vllm.kvprune.core.compression_bridge` aliases:
``none``, ``criticaladakv``, ``compactor``, ``snapkv`` (ignored when
``compression_ratio`` is effectively 1).
``protected_*`` map to compactor :class:`~vllm.kvprune.compression.compression_config.SequenceCompressionParams`
(defaults match standalone compactor-vllm-style usage).
"""
compression_ratio: float = 1.0
compression_method: str = "compactor"
protected_first_tokens: int = 16
protected_last_tokens: int = 64
def __post_init__(self) -> None:
if not 0.0 < self.compression_ratio <= 1.0:
raise ValueError(
f"compression_ratio must be in (0, 1], got {self.compression_ratio}"
)
self.compression_method = (
self.compression_method or "compactor"
).strip().lower()
from vllm.kvprune.core.compression_bridge import VALID_ALIASES_FOR_SAMPLING
if self.compression_method not in VALID_ALIASES_FOR_SAMPLING:
raise ValueError(
f"compression_method must be one of {sorted(VALID_ALIASES_FOR_SAMPLING)}, "
f"got {self.compression_method!r}"
)
if self.compression_ratio >= 1.0 - 1e-9:
self.compression_method = "none"
elif self.compression_method == "none":
raise ValueError(
"When compression_ratio < 1.0, compression_method cannot be 'none'."
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Build :class:`vllm.kvprune.config.engine_config.LLMConfig` from :class:`VllmConfig`."""
from __future__ import annotations
import os
from pathlib import Path
from vllm.config import VllmConfig
from vllm.kvprune.config.engine_config import LLMConfig, KvpruneAttentionSchedule
from vllm.logger import init_logger
logger = init_logger(__name__)
def _attention_schedule_from_env() -> KvpruneAttentionSchedule:
"""Resolve :class:`KvpruneAttentionSchedule` from env.
Primary (``VLLM_KVPRUNE_ATTENTION_SCHEDULE``):
- ``fa_triton`` — FA prefill, Triton decode (default). Aliases: ``fa_prefill``,
``default``, empty.
- ``pdtriton`` — Triton prefill + Triton decode. Aliases: ``triton``,
``triton_prefill``, ``compactor_prefill``, ``pd_triton``.
- ``pdfa`` — FA prefill + FA decode (KV stores still Triton). Aliases:
``fa_full``, ``fa_both``.
Legacy: ``VLLM_KVPRUNE_ATTENTION_BACKEND`` maps ``flash``/``fa`` → ``fa_triton``,
``compactor``/``triton`` → ``pdtriton``.
"""
s = os.environ.get("VLLM_KVPRUNE_ATTENTION_SCHEDULE", "").strip().lower()
if s in ("fa_triton", "fa_prefill", "default", ""):
return KvpruneAttentionSchedule.FA_PREFILL_TRITON_DECODE
if s in ("pdtriton", "pd_triton", "triton", "triton_prefill", "compactor_prefill"):
return KvpruneAttentionSchedule.TRITON_PREFILL_TRITON_DECODE
if s in ("pdfa", "fa_full", "fa_both"):
return KvpruneAttentionSchedule.PDFA
if s:
logger.warning(
"Unknown VLLM_KVPRUNE_ATTENTION_SCHEDULE=%r; using FA_PREFILL_TRITON_DECODE",
s,
)
return KvpruneAttentionSchedule.FA_PREFILL_TRITON_DECODE
v = os.environ.get("VLLM_KVPRUNE_ATTENTION_BACKEND", "").strip().lower()
if v in ("flash", "fa", "flash_attention", "flashattention"):
return KvpruneAttentionSchedule.FA_PREFILL_TRITON_DECODE
if v in ("compactor", "triton", "compactor_triton", ""):
return KvpruneAttentionSchedule.TRITON_PREFILL_TRITON_DECODE
logger.warning(
"Unknown VLLM_KVPRUNE_ATTENTION_BACKEND=%r; using FA_PREFILL_TRITON_DECODE", v
)
return KvpruneAttentionSchedule.FA_PREFILL_TRITON_DECODE
def _compactor_kvcache_page_size(vllm_block_size: int | None) -> int:
"""Tokens per physical KV page for compactor :class:`LLMConfig`.
vLLM ``block_size`` is often 16; compactor ``head_sparse_decode_attention`` requires
``PAGE_SIZE % 32 == 0`` (see ``kvprune/attention/sparse_decode_kernel.py``). Standalone
compactor-vllm defaults to 128. Round up to the next multiple of 32 when needed.
"""
if vllm_block_size is None:
return 128
bs = int(vllm_block_size)
if bs <= 0:
return 128
if bs % 32 == 0:
return bs
return ((bs + 31) // 32) * 32
def vllm_config_to_llm_config(vc: VllmConfig) -> LLMConfig:
"""Map vLLM engine config to compactor :class:`LLMConfig`."""
mc = vc.model_config
cc = vc.cache_config
pc = vc.parallel_config
sc = vc.scheduler_config
block_size = cc.block_size
if block_size is None:
block_size = 16
max_num_seqs = getattr(sc, "max_num_seqs", 256)
# Do **not** forward ``model_config.enforce_eager`` (v1) into compactor
# :class:`LLMConfig`. They are independent flags: v1 uses it only to skip
# *v1* ``capture_model()``; kvprune :class:`~vllm.kvprune.core.model_runner.ModelRunner`
# uses :attr:`LLMConfig.enforce_eager` only for *compactor* decode CUDA graphs.
# Shared-weight setup in ``compactor_shared`` defaults compactor to eager decode;
# see ``VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH`` (default try graphs) /
# ``VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER``.
# Local checkpoint directory: forward so compactor skips redundant Hub fetches.
_model_s = str(mc.model)
_path: str | None = None
try:
if _model_s and Path(_model_s).is_dir() and (Path(_model_s) / "config.json").is_file():
_path = str(Path(_model_s).resolve())
except OSError:
pass
return LLMConfig(
model=_model_s,
path=_path,
nccl_port=1218,
max_num_seqs=max_num_seqs,
max_model_len=mc.max_model_len,
gpu_memory_utilization=cc.gpu_memory_utilization,
tensor_parallel_size=pc.tensor_parallel_size,
enforce_eager=False,
hf_config=mc.hf_config,
eos=-1,
eos_token_ids=None,
kvcache_page_size=_compactor_kvcache_page_size(block_size),
leverage_sketch_size=48,
attention_schedule=_attention_schedule_from_env(),
attention_backend=None,
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TP>1: one kvprune :class:`~vllm.kvprune.core.model_runner.ModelRunner` per vLLM worker.
Invoked via v1 ``collective_rpc("kvprune_v1_compressed_generate", ...)`` so every tensor-
parallel rank participates in the same compactor forward/broadcast sequence as the
standalone multi-process compactor.
Compactor decode CUDA graphs (when not ``enforce_eager``) capture the full decode step
including ``compute_logits``. To force eager on embedded TP workers, set
``VLLM_KVPRUNE_TP_EMBEDDED_GRAPH=0`` or ``VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER=1``.
Peer/master session boundaries use TP-group ``broadcast``/``barrier`` (see
``ModelRunner.maybe_release_peers``), not ``multiprocessing.Event`` — RPC payloads must
be picklable across worker processes.
"""
from __future__ import annotations
import os
from typing import Any
import torch
import torch.nn as nn
from vllm.kvprune.compression.compression_config import (
BatchCompressionParams,
SequenceCompressionParams,
)
from vllm.kvprune.config.sampling_params import SamplingParams as CompactorSamplingParams
from vllm.kvprune.core.compression_bridge import (
compression_method_id_to_enum,
compression_method_str_to_id,
)
from vllm.kvprune.core.model_runner import ModelRunner
from vllm.kvprune.integration.config_adapter import vllm_config_to_llm_config
from vllm.kvprune.utils.kv_dist import barrier_sync
from vllm.kvprune.integration.weight_tie import (
delegate_kvprune_compute_logits_to_vllm,
delegate_kvprune_embed_tokens_to_vllm,
tie_kvprune_rope_buffers_from_vllm,
tie_kvprune_weights_from_vllm,
)
from vllm.kvprune.models import MODEL_REGISTRY
from vllm.kvprune.utils.sequence import Sequence
_ATTR = "_kvprune_tp_embedded_runner"
def _apply_compactor_env_overrides(cfg: Any) -> None:
"""Match :func:`~vllm.kvprune.integration.compactor_shared.create_compactor_engine_with_shared_weights` caps."""
_cap = os.environ.get("VLLM_KVPRUNE_COMPACTOR_MAX_NUM_SEQS", "32").strip()
if _cap:
lim = int(_cap)
if lim > 0:
cfg.max_num_seqs = min(cfg.max_num_seqs, lim)
_ce = os.environ.get("VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER", "").strip().lower()
if _ce in ("1", "true", "yes"):
cfg.enforce_eager = True
elif _ce in ("0", "false", "no"):
cfg.enforce_eager = False
else:
_dg = os.environ.get("VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH", "1").strip().lower()
cfg.enforce_eager = _dg in ("0", "false", "no")
def _build_sequences(payload: dict[str, Any]) -> list[Sequence]:
prompt_ids: list[list[int]] = payload["prompt_token_ids"]
sps: list[dict[str, Any]] = payload["sampling_params"]
cps: list[dict[str, Any]] = payload["compression_params"]
seqs: list[Sequence] = []
for i, ids in enumerate(prompt_ids):
sp = CompactorSamplingParams(
temperature=float(sps[i]["temperature"]),
max_new_tokens=int(sps[i]["max_new_tokens"]),
)
cp = SequenceCompressionParams(
compression_ratio=float(cps[i]["compression_ratio"]),
protected_first_tokens=int(cps[i].get("protected_first_tokens", 16)),
protected_last_tokens=int(cps[i].get("protected_last_tokens", 64)),
)
if cp.protected_first_tokens + cp.protected_last_tokens >= len(ids):
cp.compression_ratio = 1.0
seqs.append(
Sequence(
prompt_token_ids=list(ids),
sampling_params=sp,
compression_params=cp,
)
)
return seqs
def _batch_compression_from_payload(payload: dict[str, Any]) -> BatchCompressionParams:
cps = payload["compression_params"]
for c in cps:
if float(c["compression_ratio"]) < 1.0:
mid = compression_method_str_to_id(str(c.get("compression_method", "none")))
return BatchCompressionParams(
compression_method=compression_method_id_to_enum(mid)
)
return BatchCompressionParams()
def _get_or_create_runner(worker: Any, payload: dict[str, Any]) -> ModelRunner:
existing = getattr(worker, _ATTR, None)
if existing is not None:
return existing
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
vc = worker.vllm_config
pc = vc.parallel_config
if pc.pipeline_parallel_size != 1 or pc.data_parallel_size != 1:
raise NotImplementedError(
"KV-prune TP compressed generate requires pipeline_parallel_size=1 and "
f"data_parallel_size=1; got PP={pc.pipeline_parallel_size}, "
f"DP={pc.data_parallel_size}."
)
tp_ws = get_tensor_model_parallel_world_size()
if tp_ws != pc.tensor_parallel_size:
raise RuntimeError(
f"parallel_state TP world size {tp_ws} != config.tensor_parallel_size "
f"{pc.tensor_parallel_size}"
)
hf = vc.model_config.hf_config
model_type = getattr(hf, "model_type", None)
if model_type not in MODEL_REGISTRY:
raise ValueError(
f"KV-prune TP path: unsupported model_type={model_type!r}; "
f"registry has {sorted(MODEL_REGISTRY)}"
)
cfg = vllm_config_to_llm_config(vc)
eos_ids = payload["eos_token_ids"]
cfg.eos_token_ids = sorted({int(x) for x in eos_ids})
cfg.eos = int(cfg.eos_token_ids[0])
_apply_compactor_env_overrides(cfg)
vllm_model = worker.get_model()
kv_model: nn.Module = MODEL_REGISTRY[model_type](hf)
tie_kvprune_weights_from_vllm(vllm_model, kv_model)
dev = next(vllm_model.parameters()).device
dtype = next(vllm_model.parameters()).dtype
kv_model.to(device=dev, dtype=dtype)
tie_kvprune_rope_buffers_from_vllm(vllm_model, kv_model)
delegate_kvprune_embed_tokens_to_vllm(vllm_model, kv_model)
delegate_kvprune_compute_logits_to_vllm(vllm_model, kv_model)
tp_rank = get_tensor_model_parallel_rank()
device = torch.device(f"cuda:{torch.cuda.current_device()}")
if tp_rank == 0:
runner = ModelRunner(
cfg,
rank=0,
peer_events=[],
external_model=kv_model,
embedded_in_vllm_worker=True,
device=device,
)
else:
runner = ModelRunner(
cfg,
rank=tp_rank,
batch_ready=None,
external_model=kv_model,
embedded_in_vllm_worker=True,
device=device,
)
setattr(worker, _ATTR, runner)
return runner
def run_kvprune_tp_compressed_generate(worker: Any, payload: dict[str, Any]) -> dict[str, Any]:
"""Execute one compressed generation session on this worker (all TP ranks)."""
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
tp_rank = get_tensor_model_parallel_rank()
runner = _get_or_create_runner(worker, payload)
sequences = _build_sequences(payload)
batch_c = _batch_compression_from_payload(payload)
barrier_sync(use_tp_group=True)
if tp_rank == 0:
runner.generate(sequences, batch_c)
return {
"tensor_parallel_rank": 0,
"prompt_token_ids": [list(s.prompt_token_ids) for s in sequences],
"completion_token_ids": [list(s.completion_token_ids) for s in sequences],
}
runner.run_peer_session()
return {"tensor_parallel_rank": int(tp_rank), "ok": True}
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Access the in-process vLLM model weights for compactor weight sharing."""
from __future__ import annotations
import torch.nn as nn
from vllm.logger import init_logger
logger = init_logger(__name__)
def extract_vllm_causal_lm(llm: object) -> nn.Module:
"""Return the root ``nn.Module`` holding transformer + lm_head from a v1 ``LLM``.
Requires ``LLMEngine`` to have been constructed with ``multiprocess_mode=False``
so ``model_executor`` lives in-process (set ``VLLM_ENABLE_V1_MULTIPROCESSING=0``).
"""
llm_engine = getattr(llm, "llm_engine", None)
if llm_engine is None:
raise RuntimeError("Expected an object with a ``llm_engine`` attribute (e.g. ``vllm.LLM``).")
ex = getattr(llm_engine, "model_executor", None)
if ex is None:
raise RuntimeError(
"model_executor is unavailable (multiprocess engine mode). "
"Set environment variable VLLM_ENABLE_V1_MULTIPROCESSING=0 for "
"in-process weight sharing."
)
driver = getattr(ex, "driver_worker", None)
if driver is None:
raise RuntimeError(
"Executor has no driver_worker (unexpected executor type for weight sharing)."
)
worker = getattr(driver, "worker", None)
if worker is None:
raise RuntimeError("Worker wrapper has no worker loaded.")
get_model = getattr(worker, "get_model", None)
if not callable(get_model):
raise RuntimeError("Worker does not expose get_model().")
return get_model()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Share vLLM parameter storage with compactor ``MODEL_REGISTRY`` models (TP=1)."""
from __future__ import annotations
import types
import torch
import torch.nn as nn
from vllm.kvprune.utils.context import get_context
from vllm.logger import init_logger
logger = init_logger(__name__)
def tie_kvprune_weights_from_vllm(
vllm_model: nn.Module,
kvprune_model: nn.Module,
*,
strict: bool = True,
) -> int:
"""Point compactor parameters to the same tensors as vLLM where names match.
Returns the number of parameters tied. Requires identical parameter names
and shapes for overlapping weights (typical when both stacks mirror HF
naming for the same architecture).
Args:
vllm_model: Model returned by ``worker.get_model()`` (e.g. ``Qwen3ForCausalLM``).
kvprune_model: Instance from ``vllm.kvprune.models.MODEL_REGISTRY``.
strict: If True, raise when any ``kvprune`` parameter name is missing from
``vllm_model`` or shapes differ.
"""
vd = dict(vllm_model.named_parameters())
kd = dict(kvprune_model.named_parameters())
tied = 0
for name, kp in kd.items():
if name not in vd:
if strict:
raise ValueError(
f"kvprune parameter {name!r} not found in vLLM model; "
"architecture/layout may differ (disable strict tying only "
"for expert debugging)."
)
continue
vp = vd[name]
if vp.shape != kp.shape:
raise ValueError(
f"Shape mismatch for {name}: vllm {vp.shape} vs kvprune {kp.shape}"
)
kp.data = vp.data
tied += 1
if tied == 0:
raise ValueError(
"No parameters were tied — check that vLLM and kvprune model types match "
"and use the same state_dict names."
)
logger.info("Tied %d parameters from vLLM into compactor model (shared storage).", tied)
return tied
def tie_kvprune_rope_buffers_from_vllm(
vllm_model: nn.Module,
kvprune_model: nn.Module,
) -> int:
"""Copy RoPE ``cos_sin_cache`` buffers from vLLM into kvprune.
:func:`tie_kvprune_weights_from_vllm` only aliases :class:`~torch.nn.Parameter`
tensors. RoPE tables live in buffers; kvprune's simplified ``RotaryEmbedding``
can disagree with vLLM's ``rope_parameters`` (YaRN, etc.). Copying
``cos_sin_cache`` after ``.to(device, dtype)`` keeps Q/K rotation aligned with
the main model.
kvprune uses layout ``[max_len, 1, rotary_dim]``; vLLM uses ``[max_len,
rotary_dim]``. The singleton dim is filled via ``unsqueeze(1)`` on the vLLM
tensor when copying.
"""
vd = dict(vllm_model.named_buffers())
copied = 0
for name, kb in kvprune_model.named_buffers():
if "cos_sin_cache" not in name:
continue
if name not in vd:
logger.warning(
"kvprune RoPE buffer %r not found in vLLM; leaving kvprune cache",
name,
)
continue
vb = vd[name]
if vb.shape == kb.shape:
kb.copy_(vb)
copied += 1
elif kb.dim() == 3 and vb.dim() == 2:
if (
kb.shape[0] != vb.shape[0]
or kb.shape[2] != vb.shape[1]
or kb.shape[1] != 1
):
raise ValueError(
f"cos_sin_cache shape mismatch for {name!r}: "
f"vLLM {tuple(vb.shape)} vs kvprune {tuple(kb.shape)}"
)
kb.copy_(vb.unsqueeze(1))
copied += 1
else:
raise ValueError(
f"Unsupported cos_sin_cache layout for {name!r}: "
f"vLLM {tuple(vb.shape)} vs kvprune {tuple(kb.shape)}"
)
if copied:
logger.info(
"Copied %d RoPE cos_sin_cache buffer(s) from vLLM into kvprune model.",
copied,
)
return copied
def delegate_kvprune_embed_tokens_to_vllm(
vllm_model: nn.Module,
kvprune_model: nn.Module,
) -> bool:
"""Use vLLM's ``model.embed_tokens`` forward for kvprune (TP-safe token→shard mapping).
Even with tied weights, kvprune's simplified contiguous
``VocabParallelEmbedding`` (``vocab_start = rank * partition``) can disagree with
vLLM's padded vocabulary and org/added shard ranges, producing invalid indices for
``F.embedding`` on non-zero TP ranks (``index_copy_`` / device-side assert).
Delegating the forward to vLLM's embedding module keeps masks and indices aligned
with the main model while parameters remain shared storage.
"""
if not hasattr(vllm_model, "model") or not hasattr(kvprune_model, "model"):
return False
vm = getattr(vllm_model.model, "embed_tokens", None)
km = getattr(kvprune_model.model, "embed_tokens", None)
if vm is None or km is None:
logger.warning(
"delegate_kvprune_embed_tokens_to_vllm: embed_tokens missing; skipped"
)
return False
def _forward(_self_unused: nn.Module, x):
return vm(x)
km.forward = types.MethodType(_forward, km)
logger.info(
"kvprune model.embed_tokens forward delegated to vLLM (correct vocab-parallel masks)."
)
return True
def delegate_kvprune_compute_logits_to_vllm(
vllm_model: nn.Module,
kvprune_model: nn.Module,
) -> bool:
"""Route ``kvprune_model.compute_logits`` through vLLM's ``compute_logits``.
Standalone compactor used :class:`~vllm.kvprune.layers.embed_head.ParallelLMHead`
with ``F.linear`` + TP gather. vLLM applies :class:`~vllm.model_executor.layers.logits_processor.LogitsProcessor`
(gather/all-gather, padded-vocab trim, quant hooks). Mismatch here commonly
produces garbage token distributions while the rest of the stack looks fine.
After weight tying, ``vllm_model.compute_logits(hidden)`` uses the same lm_head
storage as kvprune; only the *application* path matches production vLLM.
"""
if not callable(getattr(vllm_model, "compute_logits", None)):
logger.warning(
"delegate_kvprune_compute_logits_to_vllm: vLLM model has no compute_logits; skipped"
)
return False
def _compute_logits(_self: nn.Module, hidden_states):
# Match kvprune :class:`~vllm.kvprune.layers.embed_head.ParallelLMHead`:
# prefill logits are for the **last** token of each packed sequence only.
context = get_context()
if context.is_prefill and context.cu_seqlens_q is not None:
cuq = context.cu_seqlens_q
last_indices = (cuq[1:] - 1).to(torch.long)
n_tok = hidden_states.shape[0]
if n_tok > 0:
last_indices = last_indices.clamp(min=0, max=n_tok - 1)
hidden_states = hidden_states[last_indices].contiguous()
# vLLM lm_head + gather expect contiguous activations; non-contiguous views have
# caused garbage logits under TP in edge cases.
hidden_states = hidden_states.contiguous()
logits = vllm_model.compute_logits(hidden_states)
return logits
kvprune_model.compute_logits = types.MethodType(_compute_logits, kvprune_model)
return True
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Paged KV cache helpers and Triton KV store."""
from vllm.kvprune.kv_cache.store_kv_cache import (
decode_store_kv,
prefill_store_all_kv,
prefill_store_topk_kv,
)
__all__ = [
"decode_store_kv",
"prefill_store_all_kv",
"prefill_store_topk_kv",
]
import heapq
import logging
from enum import Enum, auto
from typing import List, Optional, Union
import torch
from vllm.kvprune.config.constants import RESERVED_BATCH
from vllm.kvprune.kv_cache.write_page_table import scatter_to_page_table
logger = logging.getLogger(__name__)
def cdiv(a, b):
return (a + b - 1) // b
def next_multiple(a, b):
return cdiv(a, b) * b
class KVAllocationStatus(Enum):
EXCEEDS_MAX_SEQUENCE_LENGTH = auto()
EXCEEDS_CURRENTLY_AVAILABLE_PAGES = auto()
EXCEEDS_MAX_NUM_BATCHES = auto()
SUCCESS = auto()
class PagedKVCache(torch.nn.Module):
"""
Global paged KV cache.
This module manages:
* A global K/V backing buffer for all layers:
``kv_cache[2, num_layers, n_pages * page_size, head_dim]``,
where the first dimension indexes K vs V.
* A per-layer page table:
``page_table[num_layers, max_num_seqs, H_kv, max_pages_per_head]``,
mapping logical (batch, kv-head, logical_page) to a physical page ID
in the global K/V buffer.
* Per-layer, per-(batch, kv-head) logical sequence lengths
``bh_seq_lens[num_layers, max_num_seqs, H_kv]`` (in tokens), and
the number of allocated pages ``bh_num_pages`` for each (layer, batch,
head).
* A page allocator implemented as a min-heap of free physical pages
per layer, plus free batch indices.
Pages are of fixed size ``page_size`` tokens.
Args:
:param num_layers:
Number of transformer layers that will use this cache.
:param max_logical_pages_per_head:
Maximum number of logical pages that can be assigned to a single
(batch, kv-head) pair.
:param num_pages:
Total number of physical pages available in the global cache per
layer. The global K/V buffers are of length
``num_pages * page_size`` along the token dimension.
:param page_size:
Number of tokens stored per page.
:param H_kv:
Number of KV heads per layer.
:param head_dim:
Head dimension for K/V.
:param max_num_batches:
Maximum number of concurrent batches / sequences supported. One
batch index is reserved for internal use (``RESERVED_BATCH``).
:param dtype:
Data type of K/V entries (e.g. ``torch.float16`` or ``torch.bfloat16``).
:param device:
Device on which to allocate the cache (string, torch.device, or
int; defaults to ``"cuda"``).
"""
def __init__(
self,
num_layers: int,
max_logical_pages_per_head: int,
num_pages: int,
page_size: int, # tokens per page
H_kv: int,
head_dim: int,
max_num_batches: int,
dtype: torch.dtype,
device: Union[str, torch.device, int] = "cuda",
):
super().__init__()
self.n_pages = num_pages
self.num_layers = num_layers
self.page_size: int = int(page_size)
self.H_kv = int(H_kv)
self.max_pages_per_head = max_logical_pages_per_head
max_num_batches += 1
self.max_num_batches = max_num_batches
self.head_dim = head_dim
cache_shape = (2, num_layers, num_pages * page_size, head_dim)
self.kv_cache = torch.empty(cache_shape, dtype=dtype, device=device)
self.page_table = torch.empty(
(num_layers, max_num_batches, H_kv, self.max_pages_per_head),
device=device,
dtype=torch.int32,
)
# Per-(batch, head) logical seq length (tokens)
self.bh_seq_lens = torch.zeros(
(num_layers, max_num_batches, H_kv), device=device, dtype=torch.int32
)
# self._bh_seq_lens_cpu_buffer = torch.zeros((num_layers, H_kv), device="cpu", dtype=torch.int32)
self.bh_num_pages = torch.zeros(
(num_layers, max_num_batches, H_kv), device=device, dtype=torch.int32
)
# Page allocator (min-heap of free physical pages)
self.free_pages: List[List[int]] = [
list(range(num_pages)) for _ in range(num_layers)
]
for free_pages in self.free_pages:
heapq.heapify(free_pages)
# batch zero is reserved
self.free_batches: List[int] = list(reversed(range(max_num_batches)))
self.free_batches.remove(RESERVED_BATCH)
# Record of physical page ids owned by a batch (for freeing)
self.pages_indices_per_batch: List[List[set[int]]] = [
[set() for _ in range(num_layers)] for _ in range(max_num_batches)
]
def new_batch(self) -> Optional[int]:
"""
Reserve a new batch slot.
A batch slot corresponds to a row in ``bh_seq_lens`` /
``bh_num_pages`` and a slice in ``page_table`` for all layers and KV
heads. This method checks whether a free batch index is available, and
whether each layer has at least ``H_kv`` free pages remaining.
If both checks pass, it returns a batch index and removes it from
``free_batches``. Otherwise, it returns ``None``.
Returns:
:return Optional[int]:
Newly reserved batch index, or ``None`` if no capacity is
available.
"""
if self.free_batches and all([self.H_kv <= len(fp) for fp in self.free_pages]):
return self.free_batches.pop()
return None
def reserve_tokens(self, batch_index: int, add_tokens: int) -> KVAllocationStatus:
"""
Ensure enough pages are allocated to handle ``add_tokens`` new tokens.
Args:
:param batch_index:
Batch index to reserve space for.
:param add_tokens:
Number of additional tokens to reserve capacity for.
All heads in this batch and all layers reserve
the same number of extra tokens.
Returns:
:return bool:
``True`` if the reservation succeeds; ``False`` otherwise .
"""
cur_bh_lens = self.bh_seq_lens[:, batch_index] # [L, H]
curr_pages = self.bh_num_pages[:, batch_index] # [L, H]
curr_cap_tokens = curr_pages * self.page_size # [L, H]
need_tokens = cur_bh_lens + add_tokens # [L, H]
if (need_tokens <= curr_cap_tokens).all():
return KVAllocationStatus.SUCCESS
missing_tokens = need_tokens - curr_cap_tokens
add_pages = cdiv(missing_tokens, self.page_size)
new_total_pages = curr_pages + add_pages
if (new_total_pages > self.max_pages_per_head).any():
return KVAllocationStatus.EXCEEDS_MAX_SEQUENCE_LENGTH
# CPU work
pages_per_layer_cpu = add_pages.sum(dim=-1).tolist()
new_phys_pages = []
for layer_index in range(self.num_layers):
if pages_per_layer_cpu[layer_index] > len(self.free_pages[layer_index]):
return KVAllocationStatus.EXCEEDS_CURRENTLY_AVAILABLE_PAGES
for layer_index in range(self.num_layers):
this_layer_pages = [
heapq.heappop(self.free_pages[layer_index])
for _ in range(pages_per_layer_cpu[layer_index])
]
self.pages_indices_per_batch[batch_index][layer_index] |= set(
this_layer_pages
)
new_phys_pages.extend(this_layer_pages)
new_phys_pages = torch.tensor(new_phys_pages, dtype=torch.int32, device="cuda")
scatter_to_page_table(
add_pages=add_pages,
new_phys_pages=new_phys_pages,
curr_pages=curr_pages,
page_table=self.page_table[:, batch_index],
max_pages_per_head=self.max_pages_per_head,
)
self.bh_num_pages[:, batch_index, :] = new_total_pages.to(
self.bh_num_pages.dtype
)
return KVAllocationStatus.SUCCESS
def reclaim_pages(
self,
batch_index: int,
future_reserve_tokens: int = 0,
):
"""
Reclaim unused pages for a single batch index. This shrinks the KV
allocation for the batch down to the minimum number of pages needed
to hold the current (plus optional future) sequence length.
Args:
:param batch_index:
Batch index whose pages should be compacted.
:param future_reserve_tokens:
Optional number of extra tokens to keep capacity for, beyond
the current sequence length. This can reduce churn when
sequences are expected to grow slightly in the near future.
Returns:
:return int:
Approximate number of bytes freed across both K and V.
"""
device = self.bh_seq_lens.device
L, B, H = self.bh_seq_lens.shape
assert 0 <= batch_index < B
seq = self.bh_seq_lens[:, batch_index, :] + future_reserve_tokens # [L, H]
alloc = self.bh_num_pages[:, batch_index, :] # [L, H]
pt = self.page_table[:, batch_index, :, :].reshape(-1) # [L, H, P]
# Compute used pages: ceil_div(seq, page_size), clamped into [0, alloc]
used_pages = cdiv(seq, self.page_size)
used_pages = torch.minimum(used_pages, alloc)
# page indices [0..P-1], broadcasted over [L, H, P]
p = torch.arange(
self.max_pages_per_head, device=device, dtype=torch.int32
).view(1, 1, self.max_pages_per_head)
# allocated: p < alloc
alloc_mask = p < alloc.unsqueeze(-1) # [L, H, P]
# to free: allocated and p in [used_pages, alloc)
free_mask = alloc_mask & (p >= used_pages.unsqueeze(-1))
free_mask_flat = free_mask.view(-1) # [L*H*P]
if not free_mask_flat.any():
return 0
idx = free_mask_flat.nonzero(as_tuple=False).squeeze(
-1
) # indices of freed slots
# Freed physical page ids
freed_pages = pt[idx]
# Compute layer index for each freed slot:
# layout is [L, H, P] → flat index = ((l * H) + h) * P + p
freed_layers = (idx // (H * self.max_pages_per_head)).to(torch.int32)
freed_pages = freed_pages.tolist()
layer_mapping = freed_layers.tolist()
self.bh_num_pages[:, batch_index, :] = used_pages
for page, layer in zip(freed_pages, layer_mapping):
self.pages_indices_per_batch[batch_index][layer].remove(page)
heapq.heappush(self.free_pages[layer], page)
approximate_bytes_freed = (
len(freed_pages)
* (self.page_size * self.head_dim * self.kv_cache.element_size())
* 2
) # multiply for two for K + V
return approximate_bytes_freed
def _free_batch_layer(self, layer_index: int, batch_index: int) -> None:
"""
Free all pages belonging to batch_index and reset its metadata.
"""
# Return pages to the global heap
for phys in self.pages_indices_per_batch[batch_index][layer_index]:
heapq.heappush(self.free_pages[layer_index], int(phys))
self.pages_indices_per_batch[batch_index][layer_index] = set()
def free_batch(self, batch_index: int) -> None:
"""
Free all resources associated with a batch index.
Args:
:param batch_index:
Batch index to release. Must have been previously allocated
via :meth:`new_batch`.
"""
for layer in range(self.num_layers):
self._free_batch_layer(layer, batch_index)
self.bh_seq_lens[:, batch_index].zero_()
self.bh_num_pages[:, batch_index].zero_()
self.free_batches.append(batch_index)
def layer_slices(self, layer: int):
"""
Return layer-local views needed by the attention module.
For a given ``layer`` index, this method returns the slices of the
global K/V cache, page table, and per-(batch, head) sequence lengths
corresponding to that layer.
Args:
:param layer:
Layer index ``l`` in ``[0, num_layers)``.
Returns:
:return Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
``(k, v, pt, bh)`` as described above.
"""
assert 0 <= layer < self.num_layers
k = self.kv_cache[0, layer]
v = self.kv_cache[1, layer]
pt = self.page_table[layer]
bh = self.bh_seq_lens[layer]
return k, v, pt, bh
import torch
import triton
import triton.language as tl
from vllm.kvprune.config.constants import (
TRITON_RESERVED_BATCH as _TRITON_RESERVED_BATCH,
)
@triton.jit
def _prefill_store_topk_kv_kernel(
key,
value, # [N_total, H, D] (D stride assumed 1)
batch_mapping, # [B] int32 (local b -> true batch)
num_tokens_to_retain, # [B] int32
indices_topk, # [B, MAX_SEL] int32 (across all heads)
# Lengths & page table:
bh_lens, # [B, H] int32 (contiguous)
page_table, # [B_total * H * N_LOGICAL_PAGES_MAX] int32 (flattened), read-only
k_cache,
v_cache, # [N_PAGES * PAGE_SIZE, D]
sk_n,
sk_h, # strides for key,value. D stride assumed 1
sv_n,
sv_h,
# Runtime ints
MAX_SEL, # num tokens that are ranked in indices for each batch (might be bigger than num_tokens_to_retain)
HKV: tl.constexpr,
N_LOGICAL_PAGES_MAX: tl.constexpr,
D: tl.constexpr,
PAGE_SIZE: tl.constexpr,
K_TILE: tl.constexpr, # how many selected tokens each program processes
TRITON_RESERVED_BATCH: tl.constexpr,
):
b_local = tl.program_id(0)
tile_id = tl.program_id(1)
offs = tl.arange(0, D)
# how many tokens we actually keep for this batch
k_total = tl.load(num_tokens_to_retain + b_local)
if k_total == 0:
return
# map to true batch row in the page table
b_true = tl.load(batch_mapping + b_local)
if b_true == TRITON_RESERVED_BATCH:
return
base = tile_id * K_TILE
# process up to K_TILE tokens
for j in tl.range(0, K_TILE):
sel_idx = base + j
if sel_idx < k_total and sel_idx < MAX_SEL:
# flattened selection: sel = token * H + head
sel = tl.load(indices_topk + b_local * MAX_SEL + sel_idx)
tok = sel // HKV
head = sel - (tok * HKV)
# atomically reserve one position in (b_local, hed)
# i.e the KV cache is scrambled when storing
len_ptr = bh_lens + b_local * HKV + head
pos = tl.atomic_add(len_ptr, 1) # old length (int32)
lp = pos // PAGE_SIZE
off = pos - lp * PAGE_SIZE
# translate logical page to physical page
pt_base = (b_true * HKV + head) * N_LOGICAL_PAGES_MAX
phys = tl.load(page_table + pt_base + lp).to(tl.int64)
# destination row and element offset
dst_row = phys * PAGE_SIZE + off
dst_off = dst_row * D + offs
# load one vector from [N_total, H, D]
k_src = key + tok * sk_n + head * sk_h + offs
v_src = value + tok * sv_n + head * sv_h + offs
tl.store(
k_cache + dst_off,
tl.load(k_src, cache_modifier=".cv", eviction_policy="evict_first"),
eviction_policy="evict_first",
)
tl.store(
v_cache + dst_off,
tl.load(v_src, cache_modifier=".cv", eviction_policy="evict_first"),
eviction_policy="evict_first",
)
def prefill_store_topk_kv(
*,
new_keys: torch.Tensor, # [N_total, H, D]
new_vals: torch.Tensor, # [N_total, H, D]
indices_topk: torch.Tensor, # [B, MAX_SEL] int32 (global flattened token*H + head)
num_tokens_to_retain: torch.Tensor, # [B] int32
page_table: torch.Tensor, # [B_total, H, N_LOGICAL_PAGES_MAX] int32
batch_mapping: torch.Tensor, # [B] int32 (local -> true batch rows)
bh_lens: torch.Tensor, # [B, H] int32 (contiguous), UPDATED atomically
k_cache: torch.Tensor, # [N_PAGES * PAGE_SIZE, D]
v_cache: torch.Tensor, # [N_PAGES * PAGE_SIZE, D]
PAGE_SIZE: int,
PAD_TO_PAGE_SIZE: bool = True,
cu_seqlens_k: torch.Tensor | None = None,
K_TILE: int = 16,
TRITON_RESERVED_BATCH: int = None,
):
assert new_keys.shape == new_vals.shape
N_total, H, D = new_keys.shape
B = indices_topk.shape[0]
assert page_table.shape[1] == H
assert bh_lens.shape == (B, H)
assert new_keys.device == k_cache.device == v_cache.device
assert page_table.is_contiguous(), "page table must be contiguous."
assert bh_lens.is_contiguous(), "bh_lens must be contiguous."
assert batch_mapping.is_contiguous(), "batch mapping must be contiguous."
assert k_cache.is_contiguous() and v_cache.is_contiguous()
assert new_keys.stride(-1) == 1 and new_vals.stride(-1) == 1, (
"new_keys/new_vals last dim must be contiguous."
)
assert (D & (D - 1)) == 0, "D must be a power of 2"
page_table = page_table.to(torch.int32)
bh_lens = bh_lens.to(torch.int32)
batch_mapping = batch_mapping.to(torch.int32)
indices_topk = indices_topk.to(torch.int32)
num_tokens_to_retain = num_tokens_to_retain.to(torch.int32)
# strides (elements) for [N_total, H, D]
sk_n, sk_h, _ = new_keys.stride()
sv_n, sv_h, _ = new_vals.stride()
# tile second grid dim
MAX_SEL = indices_topk.shape[-1]
N_TILES = (MAX_SEL + K_TILE - 1) // K_TILE
grid = (B, max(1, N_TILES))
if TRITON_RESERVED_BATCH is None:
TRITON_RESERVED_BATCH = _TRITON_RESERVED_BATCH
_prefill_store_topk_kv_kernel[grid](
key=new_keys,
value=new_vals,
batch_mapping=batch_mapping,
num_tokens_to_retain=num_tokens_to_retain,
indices_topk=indices_topk,
bh_lens=bh_lens,
page_table=page_table,
k_cache=k_cache,
v_cache=v_cache,
sk_n=sk_n,
sk_h=sk_h,
sv_n=sv_n,
sv_h=sv_h,
MAX_SEL=int(MAX_SEL),
HKV=H,
N_LOGICAL_PAGES_MAX=page_table.shape[2],
D=D,
PAGE_SIZE=PAGE_SIZE,
K_TILE=K_TILE,
TRITON_RESERVED_BATCH=TRITON_RESERVED_BATCH,
)
if PAD_TO_PAGE_SIZE:
assert cu_seqlens_k is not None
assert indices_topk.is_contiguous()
assert page_table.is_contiguous()
_prefill_store_topk_pad_kernel[(B, H)](
key=new_keys,
value=new_vals,
batch_mapping=batch_mapping,
num_tokens_to_retain=num_tokens_to_retain,
indices=indices_topk,
local_lens=bh_lens,
page_table_flat=page_table,
k_cache=k_cache,
v_cache=v_cache,
cu_seqlens_k=cu_seqlens_k,
sk_n=sk_n,
sk_h=sk_h,
sv_n=sv_n,
sv_h=sv_h,
MAX_SEL=int(MAX_SEL),
H=H, # type: ignore
N_LOGICAL_PAGES_MAX=page_table.shape[2], # type: ignore
D=D, # type: ignore
PAGE_SIZE=PAGE_SIZE, # type: ignore
TRITON_RESERVED_BATCH=TRITON_RESERVED_BATCH,
)
@triton.jit
def _prefill_store_topk_pad_kernel(
key, # [N_total, H, D]
value, # [N_total, H, D]
batch_mapping, # [B] int32 (local b -> true batch)
num_tokens_to_retain, # [B] int32
indices, # [B, MAX_SEL] int32 (across all heads)
local_lens, # [B, H] int32 (contiguous)
page_table_flat, # [B_total*H*N_LOGICAL_PAGES_MAX] int32
k_cache,
v_cache, # [N_PAGES*PAGE_SIZE, D]
cu_seqlens_k,
sk_n,
sk_h,
sv_n,
sv_h,
MAX_SEL,
# Constexprs
H: tl.constexpr, # number of KV heads
N_LOGICAL_PAGES_MAX: tl.constexpr,
D: tl.constexpr,
PAGE_SIZE: tl.constexpr,
TRITON_RESERVED_BATCH: tl.constexpr,
):
b_local = tl.program_id(0)
h = tl.program_id(1)
offs_d = tl.arange(0, D)
L = tl.load(local_lens + b_local * H + h)
modulo_page_size = L - (L // PAGE_SIZE) * PAGE_SIZE
if modulo_page_size == 0:
return
need = PAGE_SIZE - modulo_page_size
b_true = tl.load(batch_mapping + b_local)
if b_true == TRITON_RESERVED_BATCH:
return
pt_base = (b_true * H + h) * N_LOGICAL_PAGES_MAX
written_tokens = 0
idx = tl.load(num_tokens_to_retain + b_local)
this_batch_ctx_len = tl.load(cu_seqlens_k + b_local + 1) - tl.load(
cu_seqlens_k + b_local
)
max_additional = this_batch_ctx_len - L
while (written_tokens < need and idx < MAX_SEL) and (
written_tokens < max_additional
):
# candidate head
cand_idx = tl.load(indices + b_local * MAX_SEL + idx)
cand_h = cand_idx % H
if cand_h == h:
tok = cand_idx // H
pos = L + written_tokens
lp = pos // PAGE_SIZE
off = pos - lp * PAGE_SIZE
phys = tl.load(page_table_flat + pt_base + lp).to(tl.int32)
dst_row = phys * PAGE_SIZE + off
dst_off = dst_row.to(tl.int64) * D + offs_d
k_src = key + tok * sk_n + h * sk_h + offs_d
v_src = value + tok * sv_n + h * sv_h + offs_d
tl.store(
k_cache + dst_off,
tl.load(k_src),
)
tl.store(
v_cache + dst_off,
tl.load(v_src),
)
written_tokens += 1
idx += 1
tl.store(local_lens + b_local * H + h, L + written_tokens)
@triton.jit
def _prefill_store_all_kv_kernel(
key,
value, # [N, H, D] (D contiguous)
cu_seqlens_k, # [B + 1] int32
batch_mapping, # [B] int32 (local b -> true batch index)
bh_lens, # [B * HKV] int32 (UPDATED)
pt_flat, # [B_total * HKV * N_LOGICAL_PAGES_MAX] int32 (flattened)
k_cache,
v_cache, # [N_PAGES * PAGE_SIZE, D]
# source strides (elements)
sk_n,
sk_h,
sv_n,
sv_h,
# constexpr
HKV: tl.constexpr,
N_LOGICAL_PAGES_MAX: tl.constexpr,
D: tl.constexpr,
PAGE_SIZE: tl.constexpr,
K_TILE: tl.constexpr, # number of (token, head) pairs processed per program
):
pid_b = tl.program_id(0)
pid_blk = tl.program_id(1)
start = tl.load(cu_seqlens_k + pid_b)
end = tl.load(cu_seqlens_k + pid_b + 1)
num_toks_this_batch = end - start
if num_toks_this_batch <= 0:
return
total_elems = num_toks_this_batch * HKV
# base linear index in (token, head) grid for this program
base = pid_blk * K_TILE
offs_d = tl.arange(0, D)
# Iterate K_TILE elements in this tile
for i in tl.range(0, K_TILE):
idx = base + i
if idx < total_elems:
# map linear idx -> (t, h)
t = idx // HKV
h = idx - t * HKV
len_idx = pid_b * HKV + h
L0 = tl.load(bh_lens + len_idx)
token_idx_in_cache = L0 + t
lp = token_idx_in_cache // PAGE_SIZE # logical page
off_in_pg = token_idx_in_cache - lp * PAGE_SIZE # pos in page
# physical page
b_true = tl.load(batch_mapping + pid_b).to(tl.int32)
pt_base = (b_true * HKV + h) * N_LOGICAL_PAGES_MAX
phys = tl.load(pt_flat + pt_base + lp).to(tl.int64)
row = phys * PAGE_SIZE + off_in_pg
dst_off = row * D + offs_d
n_global = (start + t).to(tl.int64)
# Use strides for non-contiguous [N, H, D] (D stride == 1)
k_src = key + n_global * sk_n + h * sk_h + offs_d
v_src = value + n_global * sv_n + h * sv_h + offs_d
tl.store(k_cache + dst_off, tl.load(k_src))
tl.store(v_cache + dst_off, tl.load(v_src))
def prefill_store_all_kv(
*,
new_keys: torch.Tensor,
new_values: torch.Tensor, # [N, H_kv, D]
cu_seqlens_k: torch.Tensor, # [B + 1] int32
max_seqlen_k: int,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
page_table: torch.Tensor, # [B_total, H_kv, N_LOGICAL_PAGES_MAX] int32
bh_lens: torch.Tensor, # [B, H_kv] int32 (UPDATED)
batch_mapping: torch.Tensor, # [B] int32 (local->true)
PAGE_SIZE: int,
K_TILE: int = 32, # how many (token, head) pairs per program
):
assert new_keys.stride(-1) == 1 and new_values.stride(-1) == 1, (
"last dim must be contiguous"
)
assert page_table.is_contiguous(), "page table must be contiguous"
assert bh_lens.is_contiguous(), "bh_lens must be contiguous"
assert batch_mapping.is_contiguous(), "batch mapping must be contiguous"
assert k_cache.is_contiguous() and v_cache.is_contiguous()
N, HKV, D = new_keys.shape
B = batch_mapping.shape[0]
assert (D & (D - 1)) == 0, "D must be a power of 2"
sk_n, sk_h, _ = new_keys.stride()
sv_n, sv_h, _ = new_values.stride()
n_tiles = (max_seqlen_k * HKV + K_TILE - 1) // K_TILE
grid = (B, n_tiles)
_prefill_store_all_kv_kernel[grid](
new_keys,
new_values,
cu_seqlens_k,
batch_mapping,
bh_lens,
page_table,
k_cache,
v_cache,
sk_n=sk_n,
sk_h=sk_h,
sv_n=sv_n,
sv_h=sv_h,
HKV=HKV,
N_LOGICAL_PAGES_MAX=page_table.shape[-1],
D=D,
PAGE_SIZE=PAGE_SIZE,
K_TILE=K_TILE,
)
bh_lens += cu_seqlens_k.diff()[:, None]
@triton.jit
def _decode_store_kv_kernel(
key,
value,
batch_mapping, # [B] int32
bh_lens, # [B*HKV] int32
page_table, # [B_total*HKV*N_LOGICAL_PAGES_MAX]
k_cache,
v_cache, # [N_PAGES*PAGE_SIZE, D]
sk_b,
sk_h,
sv_b,
sv_h,
HKV: tl.constexpr,
N_LOGICAL_PAGES_MAX: tl.constexpr,
D: tl.constexpr,
PAGE_SIZE: tl.constexpr,
TRITON_RESERVED_BATCH: tl.constexpr,
):
pid_b = tl.program_id(0)
h = tl.program_id(1)
mapped_b = tl.load(batch_mapping + pid_b)
if mapped_b == TRITON_RESERVED_BATCH:
return
offs_d = tl.arange(0, D)
length = tl.load(bh_lens + pid_b * HKV + h)
logical_page = length // PAGE_SIZE
internal_offset = length - logical_page * PAGE_SIZE
pt_base = (mapped_b * HKV + h) * N_LOGICAL_PAGES_MAX
physical_page = tl.load(page_table + pt_base + logical_page).to(tl.int64)
dst_row = physical_page * PAGE_SIZE + internal_offset
# Source addressing using strides (D stride == 1)
k_src = key + pid_b * sk_b + h * sk_h + offs_d
v_src = value + pid_b * sv_b + h * sv_h + offs_d
dst_off = dst_row * D + offs_d
tl.store(k_cache + dst_off, tl.load(k_src))
tl.store(v_cache + dst_off, tl.load(v_src))
tl.store(bh_lens + pid_b * HKV + h, length + 1)
def decode_store_kv(
*,
key: torch.Tensor, # [B, HKV, D]
value: torch.Tensor, # [B, HKV, D]
batch_mapping: torch.Tensor, # [B] int32
bh_lens: torch.Tensor, # [B, HKV] or flattened [B*HKV] int32
page_table: torch.Tensor, # [B_total, HKV, N_LOGICAL_PAGES_MAX] int32
k_cache: torch.Tensor,
v_cache: torch.Tensor, # [N_PAGES*PAGE_SIZE, D]
PAGE_SIZE: int,
TRITON_RESERVED_BATCH: int = None,
):
assert key.shape == value.shape and key.ndim == 3, "key/value must be [B, HKV, D]"
B, HKV, D = key.shape
assert key.stride(-1) == 1 and value.stride(-1) == 1, (
"key/value last dim must be contiguous."
)
assert page_table.is_contiguous(), "page table must be contiguous."
assert bh_lens.is_contiguous(), "bh_lens must be contiguous."
assert batch_mapping.is_contiguous(), "batch mapping must be contiguous."
assert k_cache.is_contiguous() and v_cache.is_contiguous()
assert (D & (D - 1)) == 0, "D must be a power of 2"
sk_b, sk_h, _ = key.stride()
sv_b, sv_h, _ = value.stride()
grid = (
int(batch_mapping.shape[0]),
HKV,
)
_decode_store_kv_kernel[grid](
key=key,
value=value,
batch_mapping=batch_mapping,
bh_lens=bh_lens,
page_table=page_table,
k_cache=k_cache,
v_cache=v_cache,
sk_b=sk_b,
sk_h=sk_h,
sv_b=sv_b,
sv_h=sv_h,
HKV=HKV,
N_LOGICAL_PAGES_MAX=page_table.shape[2],
D=D,
PAGE_SIZE=PAGE_SIZE,
TRITON_RESERVED_BATCH=TRITON_RESERVED_BATCH
if TRITON_RESERVED_BATCH is not None
else _TRITON_RESERVED_BATCH,
)
import torch
import triton
import triton.language as tl
def scatter_to_page_table(
add_pages: torch.Tensor, # [L, H] int32
new_phys_pages: torch.Tensor, # [N]
curr_pages: torch.Tensor, # [L, H] int32
page_table: torch.Tensor, # [L, H, max_pages_per_head] int32, NOT assumed contiguous globally
max_pages_per_head: int,
):
"""
Append newly allocated physical pages into a layered page table via Triton.
For each (layer ``l``, head ``h``):
Args:
:param add_pages:
Tensor of shape ``[L, H]`` (int32) indicating how many pages to
append for each (layer, head).
:param new_phys_pages:
1D tensor of shape ``[N]`` (int32) containing physical page IDs
for all (layer, head) pairs, concatenated in row-major (L, H)
order. ``N`` must equal ``add_pages.sum()``.
:param curr_pages:
Tensor of shape ``[L, H]`` (int32) with the current logical page
counts per (layer, head) before this update.
:param page_table:
Tensor of shape ``[L, H, max_pages_per_head]`` (int32) holding
the logical to physical page mapping. The last dimension is
logically indexed as logical_page ∈ [0, max_pages_per_head).
:param max_pages_per_head:
Maximum number of logical pages permitted per (layer, head). The
kernel skips writes beyond this bound.
Returns:
None. The function updates ``page_table`` in-place.
"""
L, H = add_pages.shape
if L == 0 or H == 0:
return
add_flat = add_pages.to(torch.int32).contiguous().view(-1)
curr_flat = curr_pages.to(torch.int32).contiguous().view(-1)
cum_page_heads = torch.empty(L * H + 1, device="cuda", dtype=torch.int32)
cum_page_heads[0] = 0
torch.cumsum(add_flat, 0, out=cum_page_heads[1:])
stride_pl, stride_ph, stride_pp = page_table.stride()
grid = (L, H)
_scatter_pages_kernel_lh[grid](
add_flat,
cum_page_heads,
new_phys_pages,
curr_flat,
page_table,
stride_pl,
stride_ph,
stride_pp,
L=L,
H=H,
max_pages_per_head=max_pages_per_head,
)
@triton.jit
def _scatter_pages_kernel_lh(
add_pages, # int32 [L*H]
cum_page_heads, # int32 [L*H], base offset in flat_new_phys per (l,h)
flat_new_phys, # int32 [total_pages]
curr_pages, # int32 [L*H], existing logical pages per (l,h)
page_table_ptr, # int32* base pointer to page_table
stride_pl, # int, stride for layer dim
stride_ph, # int, stride for head dim
stride_pp, # int, stride for page dim
L: tl.constexpr,
H: tl.constexpr,
max_pages_per_head: tl.constexpr,
):
layer_idx = tl.program_id(0)
h = tl.program_id(1)
if layer_idx >= L or h >= H:
return
lh = layer_idx * H + h
ap = tl.load(add_pages + lh)
if ap <= 0:
return
base = tl.load(cum_page_heads + lh)
cp = tl.load(curr_pages + lh)
# Append ap pages: logical pages [cp .. cp+ap)
for i in tl.range(0, ap):
phys = tl.load(flat_new_phys + base + i)
lp = cp + i
if lp < max_pages_per_head:
offset = layer_idx * stride_pl + h * stride_ph + lp * stride_pp
tl.store(page_table_ptr + offset, phys)
# TODO: write reclaim kernel
@triton.jit
def reclaim_page_kernel():
pass
def reclaim_pages(
batch_index: int,
bh_seq_lens: torch.Tensor,
bh_num_pages: torch.Tensor,
page_table: torch.Tensor,
):
pass
# KV-prune 与上游 vLLM 的集成说明
本文说明:**剪枝/压缩(Compactor)功能**在「官网 vLLM 主仓库」里改动了哪些位置、是否只有少量文件、以及随 vLLM 版本升级时如何预期合并成本。
## 1. 是否「仅仅」改了少数几个脚本?
**核心运行时接线**确实集中在少数几个**非** `vllm/kvprune/` 下的文件;功能主体在 `vllm/kvprune/` 包内独立维护。
| 路径 | 作用简述 |
|------|-----------|
| `vllm/env_override.py` | 在 `import vllm` 最早阶段设置与 kvprune 相关的默认环境变量(如 v1 多进程默认、压缩默认开关、可选释放 v1 KV 等)。 |
| `vllm/__init__.py` | 对外导出 `CompressionParams`(懒加载至 `vllm.kvprune.integration.compression_params`)。 |
| `vllm/entrypoints/llm.py` | `kvprune_compression` 参数、`generate(..., compression=...)`、v1 `enforce_eager` / `num_gpu_blocks_override` 策略、懒加载 compactor、委托 `compressed_generate`。 |
| `vllm/v1/worker/gpu_worker.py` | `kvprune_v1_compressed_generate`:供 `collective_rpc` 调用的 TP 多卡压缩生成入口。 |
| `tests/conftest.py` | 测试在导入 vLLM 前覆盖部分 `VLLM_KVPRUNE_*` 默认值,避免全量测试默认走压缩路径。 |
| `vllm\vllm\envs.py` | envs.py 中对 VLLM_KVPRUNE_* 的集中注册 |
**此外(可选/示例,非引擎必需):**
- `examples/offline_inference/` 下若干 `*kvprune*` 示例脚本:演示用法,不参与核心引擎加载。
**结论:**
- **「官网 vLLM 主包」里与 kvprune 强相关的改动,主要就是上表 4 个文件 + 测试根配置**(若把测试也算进「集成面」,共 5 处常见提法)。
- **算法、Compactor、TP 内嵌 runner 等**均在 `vllm/kvprune/`(及该目录下的 `integration/`)中,与上游 diff 相对隔离。
## 2. 随 vLLM 版本更新,是否「很容易」同步剪枝压缩功能?
**相对容易的部分:**
- **集成面小**:合并冲突主要出现在上述少数文件,而不是遍布整个 executor / attention / model 层。
- **逻辑内聚**:大量代码在 `vllm/kvprune/`,可整体移植或 `git` 三方合并时以子树为主处理。
**仍需人工跟进的点(不能假设「自动无痛」):**
- **`entrypoints/llm.py` 属于高频变更文件**:上游每次大版本可能重构 `LLM` 构造参数、`generate` 签名或引擎初始化;需要**逐次解决冲突**并回归压缩路径。
- **`v1/worker/gpu_worker.py`** 同样会随 executor / RPC 接口变动;`collective_rpc` 方法名或 worker 基类若有变化,需对齐。
- **`env_override.py`** 若上游调整导入顺序或新增全局默认环境变量,需避免覆盖冲突或行为打架。
- **vLLM v1 内部 API**(如 `worker.get_model()``vllm_config` 结构)若变更,`vllm/kvprune/integration/*` 也可能要跟着改——这类改动**不在**「仅 5 个文件」里,但仍是**集成层**维护成本。
**建议同步流程(简版):**
1. 在新上游 tag 上先合并/应用 `vllm/kvprune/` 目录。
2. 再手动合并上述 4 个主包文件 + `tests/conftest.py`
3. 跑与 kvprune 相关的测试与至少一条离线 `compression` 示例。
4. 关注发行说明中 `LLM``EngineArgs``gpu_worker`、多进程默认的破坏性变更。
## 3. 与「深度改内核」方案的区别
当前设计**没有**`model_executor` 的统一注意力路径上大规模插入 kvprune 钩子(相关辅助逻辑主要在 `vllm/kvprune` 内部)。因此:
- **上游同步时**,通常不必与 FlashAttention / 每层模型代码逐文件对打;
- **代价是**:功能边界以「共享权重 + compactor 引擎 + 可选 TP RPC」为主,与「原生 KV 算子级一体化」的改动面不同。
---
*文档随仓库维护;若集成文件列表有增删,请同步更新本节表格。*
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Layers from upstream compactor (attention, linear, MoE, …).
Prefer importing concrete modules, e.g. ``from vllm.kvprune.layers.attention import ...``.
"""
__all__: list[str] = []
import torch
import torch.nn.functional as F
from torch import nn
class SiluAndMul(nn.Module):
def __init__(self):
super().__init__()
# @torch.compile
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, y = x.chunk(2, -1)
return F.silu(x) * y
from typing import Optional
import torch
from flash_attn.flash_attn_interface import flash_attn_varlen_func
from torch import nn
from vllm.kvprune.attention.fa_paged_bridge import (
flash_decode_from_paged,
flash_prefill_from_paged,
)
from vllm.kvprune.attention.sparse_decode_kernel import head_sparse_decode_attention
from vllm.kvprune.attention.sparse_varlen_kernel import (
causal_sparse_varlen_with_cache,
)
from vllm.kvprune.compression.common import extract_and_store_top_kv
from vllm.kvprune.config.engine_config import KvpruneAttentionSchedule
from vllm.kvprune.kv_cache.store_kv_cache import decode_store_kv, prefill_store_all_kv
from vllm.kvprune.utils.context import Context, get_context
from vllm.kvprune.utils.helpers import maybe_execute_in_stream
class Attention(nn.Module):
def __init__(
self,
num_heads,
head_dim,
scale,
num_kv_heads,
):
super().__init__()
self.num_heads: int = num_heads
self.head_dim = head_dim
self.scale: float = scale
self.num_kv_heads = int(num_kv_heads)
self.k_cache: Optional[torch.Tensor] = None
self.v_cache: Optional[torch.Tensor] = None
self.page_table: Optional[torch.Tensor] = None
self.bh_seq_lens: Optional[torch.Tensor] = None
self.page_size: Optional[int] = None
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
scores: Optional[torch.Tensor] = None,
):
context: Context = get_context()
batch_mapping = context.batch_mapping
seq_lens = (
None
if self.bh_seq_lens is None
else self.bh_seq_lens.index_select(0, batch_mapping).contiguous()
)
sched = context.attention_schedule
use_triton_prefill_attn = (
sched == KvpruneAttentionSchedule.TRITON_PREFILL_TRITON_DECODE
)
use_fa_decode = sched == KvpruneAttentionSchedule.PDFA
if context.is_prefill:
seq_lens_copy = seq_lens.clone() if seq_lens is not None else None
if (
self.k_cache is not None
and context.do_compression
and scores is not None
):
compression_context = context.compression_context
assert scores is not None
assert compression_context is not None
maybe_execute_in_stream(
extract_and_store_top_kv,
scores=scores,
cu_seqlens_k=context.cu_seqlens_k,
max_k_len=context.max_seqlen_k,
top_k=compression_context.max_tokens_to_retain,
H=int(self.num_kv_heads),
new_keys=k,
new_vals=v,
num_tokens_to_retain=compression_context.batch_tokens_to_retain,
page_table=self.page_table,
batch_mapping=batch_mapping,
bh_lens=seq_lens,
k_cache=self.k_cache,
v_cache=self.v_cache,
PAGE_SIZE=self.page_size,
PAD_TO_PAGE_SIZE=True,
STORE_STREAM=context.STORE_STREAM,
)
elif self.k_cache is not None:
maybe_execute_in_stream(
prefill_store_all_kv,
new_keys=k,
new_values=v,
cu_seqlens_k=context.cu_seqlens_k,
max_seqlen_k=context.max_seqlen_k,
k_cache=self.k_cache,
v_cache=self.v_cache,
page_table=self.page_table,
bh_lens=seq_lens,
batch_mapping=batch_mapping,
PAGE_SIZE=self.page_size,
STORE_STREAM=context.STORE_STREAM,
)
if use_triton_prefill_attn:
if context.do_compression and context.STORE_STREAM is not None:
torch.cuda.current_stream().wait_stream(context.STORE_STREAM)
assert seq_lens_copy is not None
o = causal_sparse_varlen_with_cache(
q,
k,
v,
self.k_cache,
self.v_cache,
seq_lens_bh=seq_lens_copy,
global_page_table=self.page_table,
batch_mapping=batch_mapping,
cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_q=context.max_seqlen_q,
max_seqlen_k_cache=context.max_bh_len,
HKV=int(self.num_kv_heads),
PAGE_SIZE=self.page_size,
sm_scale=self.scale,
)
elif context.do_compression:
if context.STORE_STREAM is not None:
torch.cuda.current_stream().wait_stream(context.STORE_STREAM)
assert seq_lens_copy is not None
o = flash_prefill_from_paged(
q,
k,
v,
self.k_cache,
self.v_cache,
seq_lens_bh_before=seq_lens_copy,
global_page_table=self.page_table,
batch_mapping=batch_mapping,
cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_q=context.max_seqlen_q,
PAGE_SIZE=self.page_size,
HKV=int(self.num_kv_heads),
sm_scale=self.scale,
)
else:
o = flash_attn_varlen_func(
q,
k,
v,
max_seqlen_q=context.max_seqlen_q,
cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_k=context.max_seqlen_k,
cu_seqlens_k=context.cu_seqlens_k,
softmax_scale=self.scale,
causal=True,
)
else:
assert self.k_cache is not None, "KV Cache must be initialized for decoding"
decode_store_kv(
key=k,
value=v,
batch_mapping=batch_mapping,
bh_lens=seq_lens,
page_table=self.page_table,
k_cache=self.k_cache,
v_cache=self.v_cache,
PAGE_SIZE=self.page_size,
)
if use_fa_decode:
assert seq_lens is not None
o = flash_decode_from_paged(
q,
self.k_cache,
self.v_cache,
seq_lens_bh=seq_lens,
global_page_table=self.page_table,
batch_mapping=batch_mapping,
PAGE_SIZE=self.page_size,
HKV=int(self.num_kv_heads),
sm_scale=self.scale,
)
else:
o = head_sparse_decode_attention(
q,
self.k_cache,
self.v_cache,
seq_lens,
self.page_table,
batch_mapping,
int(self.num_kv_heads),
self.page_size,
self.scale,
key_split=context.key_split,
)
# Match compactor_vllm ``Attention``: ``index_copy_`` into the global
# ``bh_seq_lens`` table. The Triton masked copy was a CUDA fast path but
# disagreed with decode_store_kv / paged attention bookkeeping in edge
# cases and could leave lengths stale → garbage logits / immediate EOS.
if self.bh_seq_lens is not None:
longbm = batch_mapping.to(
device=self.bh_seq_lens.device, dtype=torch.long
)
maybe_execute_in_stream(
self.bh_seq_lens.index_copy_,
0,
longbm,
seq_lens,
STORE_STREAM=context.STORE_STREAM if context.is_prefill else None,
)
return o
import torch
import torch.distributed as dist
import torch.nn.functional as F
from vllm.kvprune.utils.context import get_context
from vllm.kvprune.utils.tp_collectives import tensor_parallel_all_reduce
from vllm.kvprune.utils.tp_utils import (
tensor_parallel_rank_for_sharding,
tensor_parallel_world_size_for_sharding,
)
from torch import nn
class VocabParallelEmbedding(nn.Module):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
):
super().__init__()
self.tp_rank = tensor_parallel_rank_for_sharding()
self.tp_size = tensor_parallel_world_size_for_sharding()
assert num_embeddings % self.tp_size == 0
self.num_embeddings = num_embeddings
self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
self.weight = nn.Parameter(
torch.empty(self.num_embeddings_per_partition, embedding_dim)
)
self.weight.weight_loader = self.weight_loader
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param_data = param.data
shard_size = param_data.size(0)
start_idx = self.tp_rank * shard_size
loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
param_data.copy_(loaded_weight)
def forward(self, x: torch.Tensor):
if self.tp_size > 1:
mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
x = mask * (x - self.vocab_start_idx)
y = F.embedding(x, self.weight)
if self.tp_size > 1:
y = mask.unsqueeze(1) * y
tensor_parallel_all_reduce(y)
return y
class ParallelLMHead(VocabParallelEmbedding):
"""LM head with TP vocab sharding.
When embedded in a vLLM worker, logits must be gathered on the **tensor-
parallel** process group (see :func:`~vllm.distributed.communication_op.tensor_model_parallel_gather`),
not the default :func:`torch.distributed.gather` — otherwise shard order / group
mismatch yields garbage logits and decoded gibberish.
After gather, logits are truncated to ``org_vocab_size`` (HF tokenizer vocab),
matching :class:`~vllm.model_executor.layers.logits_processor.LogitsProcessor`
removal of padded vocabulary columns.
"""
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
bias: bool = False,
*,
org_vocab_size: int | None = None,
):
assert not bias
super().__init__(num_embeddings, embedding_dim)
# Original (unpadded) vocab size for logits truncation; defaults to num_embeddings.
self.org_vocab_size = (
int(org_vocab_size) if org_vocab_size is not None else num_embeddings
)
def forward(self, x: torch.Tensor):
context = get_context()
if context.is_prefill:
cu = context.cu_seqlens_q
last_indices = (cu[1:] - 1).to(torch.long)
n_tok = x.shape[0]
if n_tok > 0:
last_indices = last_indices.clamp(min=0, max=n_tok - 1)
x = x[last_indices].contiguous()
logits = F.linear(x, self.weight)
if self.tp_size > 1:
logits = self._gather_logits_tp(logits)
if logits is not None and logits.shape[-1] > self.org_vocab_size:
logits = logits[..., : self.org_vocab_size]
return logits
def _gather_logits_tp(self, logits: torch.Tensor) -> torch.Tensor | None:
try:
from vllm.distributed.parallel_state import model_parallel_is_initialized
from vllm.distributed.communication_op import (
tensor_model_parallel_gather,
)
if model_parallel_is_initialized():
return tensor_model_parallel_gather(logits, dst=0, dim=-1)
except Exception:
pass
all_logits = (
[torch.empty_like(logits) for _ in range(self.tp_size)]
if self.tp_rank == 0
else None
)
dist.gather(logits, all_logits, 0)
return torch.cat(all_logits, -1) if self.tp_rank == 0 else None
import torch
from torch import nn
class RMSNorm(nn.Module):
def __init__(
self,
hidden_size: int,
eps: float = 1e-6,
) -> None:
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(hidden_size))
# @torch.compile
def rms_forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
orig_dtype = x.dtype
x = x.float()
var = x.pow(2).mean(dim=-1, keepdim=True)
x.mul_(torch.rsqrt(var + self.eps))
x = x.to(orig_dtype).mul_(self.weight)
return x
# @torch.compile
def add_rms_forward(
self,
x: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
orig_dtype = x.dtype
x = x.float().add_(residual.float())
residual = x.to(orig_dtype)
var = x.pow(2).mean(dim=-1, keepdim=True)
x.mul_(torch.rsqrt(var + self.eps))
x = x.to(orig_dtype).mul_(self.weight)
return x, residual
def forward(
self,
x: torch.Tensor,
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if residual is None:
return self.rms_forward(x)
else:
return self.add_rms_forward(x, residual)
import torch
import torch.distributed as dist
import torch.nn.functional as F
from vllm.kvprune.utils.tp_collectives import tensor_parallel_all_reduce
from vllm.kvprune.utils.tp_utils import (
tensor_parallel_rank_for_sharding,
tensor_parallel_world_size_for_sharding,
)
from torch import nn
def divide(numerator, denominator):
assert numerator % denominator == 0
return numerator // denominator
class LinearBase(nn.Module):
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = False,
tp_dim: int | None = None,
):
super().__init__()
self.tp_dim = tp_dim
self.tp_rank = tensor_parallel_rank_for_sharding()
self.tp_size = tensor_parallel_world_size_for_sharding()
self.weight = nn.Parameter(torch.empty(output_size, input_size))
self.weight.weight_loader = self.weight_loader
if bias:
self.bias = nn.Parameter(torch.empty(output_size))
self.bias.weight_loader = self.weight_loader
else:
self.register_parameter("bias", None)
def forward(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
class ReplicatedLinear(LinearBase):
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = False,
):
super().__init__(input_size, output_size, bias)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param.data.copy_(loaded_weight)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.linear(x, self.weight, self.bias)
class ColumnParallelLinear(LinearBase):
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = False,
):
tp_size = tensor_parallel_world_size_for_sharding()
super().__init__(input_size, divide(output_size, tp_size), bias, 0)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param_data = param.data
shard_size = param_data.size(self.tp_dim)
start_idx = self.tp_rank * shard_size
loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
param_data.copy_(loaded_weight)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.linear(x, self.weight, self.bias)
class MergedColumnParallelLinear(ColumnParallelLinear):
def __init__(
self,
input_size: int,
output_sizes: list[int],
bias: bool = False,
):
self.output_sizes = output_sizes
super().__init__(input_size, sum(output_sizes), bias)
def weight_loader(
self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int
):
param_data = param.data
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
param_data.copy_(loaded_weight)
class QKVParallelLinear(ColumnParallelLinear):
def __init__(
self,
hidden_size: int,
head_size: int,
total_num_heads: int,
total_num_kv_heads: int | None = None,
bias: bool = False,
):
tp_size = tensor_parallel_world_size_for_sharding()
total_num_kv_heads = total_num_kv_heads or total_num_heads
self.head_size = head_size
self.num_heads = divide(total_num_heads, tp_size)
self.num_kv_heads = divide(total_num_kv_heads, tp_size)
output_size = (total_num_heads + 2 * total_num_kv_heads) * self.head_size
super().__init__(hidden_size, output_size, bias)
def weight_loader(
self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str
):
param_data = param.data
assert loaded_shard_id in ["q", "k", "v"]
if loaded_shard_id == "q":
shard_size = self.num_heads * self.head_size
shard_offset = 0
elif loaded_shard_id == "k":
shard_size = self.num_kv_heads * self.head_size
shard_offset = self.num_heads * self.head_size
else:
shard_size = self.num_kv_heads * self.head_size
shard_offset = (
self.num_heads * self.head_size + self.num_kv_heads * self.head_size
)
param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
param_data.copy_(loaded_weight)
class RowParallelLinear(LinearBase):
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = False,
):
tp_size = tensor_parallel_world_size_for_sharding()
super().__init__(divide(input_size, tp_size), output_size, bias, 1)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param_data = param.data
shard_size = param_data.size(self.tp_dim)
start_idx = self.tp_rank * shard_size
loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
param_data.copy_(loaded_weight)
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
if self.tp_size > 1:
tensor_parallel_all_reduce(y)
return y
import torch
import torch.distributed as dist
from vllm.kvprune.triton_kernels.matmul_ogs import matmul_ogs
from vllm.kvprune.utils.tp_collectives import tensor_parallel_all_reduce
from vllm.kvprune.utils.tp_utils import (
tensor_parallel_rank_for_sharding,
tensor_parallel_world_size_for_sharding,
)
from torch import nn
def divide(numerator, denominator):
assert numerator % denominator == 0
return numerator // denominator
class TritonFusedMoeLinearBase(nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
num_experts: int,
bias: bool = False,
tp_dim: int | None = None,
) -> None:
super().__init__()
self.tp_dim = tp_dim
self.tp_rank = tensor_parallel_rank_for_sharding()
self.tp_size = tensor_parallel_world_size_for_sharding()
self.in_features = in_features
self.out_features = out_features
self.num_experts = num_experts
self.weight = nn.Parameter(
torch.empty((num_experts, in_features, out_features)).transpose(-1, -2)
)
self.weight.weight_loader = self.weight_loader
if bias:
self.bias = nn.Parameter(torch.empty((num_experts, out_features)))
self.bias.weight_loader = self.weight_loader
else:
self.register_parameter("bias", None)
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
raise NotImplementedError
class ReplicatedTritonFusedMoeLinear(TritonFusedMoeLinearBase):
def __init__(
self,
in_features: int,
out_features: int,
num_experts: int,
bias: bool = False,
) -> None:
super().__init__(in_features, out_features, num_experts, bias)
def weight_loader(
self, param: nn.Parameter, loaded_weight: torch.Tensor, expert_idx: int
):
param.data[expert_idx].copy_(loaded_weight, non_blocking=True)
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
w = self.weight.transpose(-1, -2)
assert w.is_contiguous()
return matmul_ogs(
x,
self.weight,
self.bias,
**kwargs,
)
class RowParallelTritonFusedMoeLinear(TritonFusedMoeLinearBase):
def __init__(
self,
in_features: int,
out_features: int,
num_experts: int,
bias: bool = False,
) -> None:
tp_size = (
tensor_parallel_world_size_for_sharding()
if dist.is_initialized()
else 1
)
super().__init__(
divide(in_features, tp_size), out_features, num_experts, bias, 2
)
def weight_loader(
self, param: nn.Parameter, loaded_weight: torch.Tensor, expert_idx: int
):
shard_size = param.size(2)
start_idx = self.tp_rank * shard_size
local_shard = loaded_weight[:, start_idx : start_idx + shard_size]
param.data[expert_idx].copy_(local_shard, non_blocking=True)
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
w = self.weight.transpose(-1, -2)
assert w.is_contiguous()
y = matmul_ogs(
x,
w,
self.bias,
**kwargs,
)
if self.tp_size > 1:
tensor_parallel_all_reduce(y)
return y
class ColumnParallelTritonFusedMoeLinear(TritonFusedMoeLinearBase):
def __init__(
self,
in_features: int,
out_features: int,
num_experts: int,
bias: bool = False,
) -> None:
tp_size = (
tensor_parallel_world_size_for_sharding()
if dist.is_initialized()
else 1
)
super().__init__(
in_features, divide(out_features, tp_size), num_experts, bias, 1
)
def weight_loader(
self, param: nn.Parameter, loaded_weight: torch.Tensor, expert_idx: int
):
shard_size = param.size(1)
start_idx = self.tp_rank * shard_size
local_shard = loaded_weight[start_idx : start_idx + shard_size, :]
param.data[expert_idx].copy_(local_shard, non_blocking=True)
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
w = self.weight.transpose(-1, -2)
assert w.is_contiguous()
y = matmul_ogs(
x,
w,
self.bias,
**kwargs,
)
return y
class MergedColumnParallelTritonFusedMoeLinear(ColumnParallelTritonFusedMoeLinear):
def __init__(
self,
in_features: int,
out_feature_list: list[int],
num_experts: int,
bias: bool = False,
):
self.out_feature_list = out_feature_list
super().__init__(in_features, sum(out_feature_list), num_experts, bias)
def weight_loader(
self,
param: nn.Parameter,
loaded_weight: torch.Tensor,
expert_idx: int,
shard_id: int,
):
param_data = param.data
shard_offset = sum(self.out_feature_list[:shard_id]) // self.tp_size
shard_size = self.out_feature_list[shard_id] // self.tp_size
param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
local_weight = loaded_weight.chunk(self.tp_size, dim=self.tp_dim - 1)[
self.tp_rank
]
param_data[expert_idx].copy_(local_weight, non_blocking=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment