Commit 2b7160c6 authored by chenzk's avatar chenzk
Browse files

vllm kvprune:v1.0.0

parent fa718036
import argparse
import inspect
import logging
import os
import sys
from pathlib import Path
def _maybe_add_src_to_path() -> None:
# Allow running without `pip install -e .` by pointing to `compactor-vllm/src`.
here = Path(__file__).resolve()
repo_root = here.parents[1]
src_dir = repo_root / "src"
if src_dir.is_dir() and str(src_dir) not in sys.path:
sys.path.insert(0, str(src_dir))
_maybe_add_src_to_path()
from compactor_vllm import LLM, LLMConfig, SamplingParams # noqa: E402
from compactor_vllm.compression import ( # noqa: E402
BatchCompressionParams,
CompressionMethod,
SequenceCompressionParams,
)
from compactor_vllm.config.engine_config import AttentionBackend # noqa: E402
def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Minimal smoke test for compactor-vllm (no speculative decoding)."
)
parser.add_argument(
"--model",
type=str,
default=os.environ.get("MODEL", "/mnt/data/llm-models/Qwen3-8B"),
help="Local model directory or HF id. In the container this is usually a local dir.",
)
parser.add_argument(
"--tp",
type=int,
default=int(os.environ.get("TP", "1")),
help="Tensor parallel size (world size).",
)
parser.add_argument(
"--nccl-port",
type=int,
default=int(os.environ.get("NCCL_PORT", "1218")),
help="TCP port for torch.distributed init (only used for NCCL init_method=tcp://localhost:<port>).",
)
parser.add_argument("--max-model-len", type=int, default=2048)
parser.add_argument("--max-num-seqs", type=int, default=2)
parser.add_argument(
"--gpu-memory-utilization",
type=float,
default=float(os.environ.get("GPU_MEMORY_UTILIZATION", "0.9")),
help="Fraction of total GPU memory used for KV cache + activations.",
)
parser.add_argument(
"--attention-backend",
type=str,
default="compactor_triton",
choices=[b.name.lower() for b in AttentionBackend],
)
parser.add_argument(
"--compression-method",
type=str,
default="compactor",
choices=[m.name.lower() for m in CompressionMethod],
)
parser.add_argument(
"--compression-ratio",
type=float,
default=0.8,
help="Sequence-level compression ratio (e.g. 0.8 keeps 80%% of tokens).",
)
parser.add_argument("--chunk-size", type=int, default=512)
parser.add_argument(
"--no-chunked-compression",
dest="do_chunked_compression",
action="store_false",
)
parser.set_defaults(do_chunked_compression=True)
parser.add_argument("--prompt", type=str, default="用一句话介绍你自己,给我讲一个故事,200字左右。")
parser.add_argument("--max-new-tokens", type=int, default=64)
parser.add_argument(
"--temperature",
type=float,
default=0.0,
help="0.0 = greedy decoding (recommended for smoke tests).",
)
parser.add_argument(
"--tokenizer-enable-thinking",
dest="tokenizer_enable_thinking",
action="store_true",
help="Pass enable_thinking=True to tokenizer.apply_chat_template (if supported).",
)
parser.add_argument(
"--no-tokenizer-enable-thinking",
dest="tokenizer_enable_thinking",
action="store_false",
help="Pass enable_thinking=False to tokenizer.apply_chat_template (if supported).",
)
parser.set_defaults(tokenizer_enable_thinking=False)
parser.add_argument(
"--tokenizer-add-generation-prompt",
dest="tokenizer_add_generation_prompt",
action="store_true",
help="Pass add_generation_prompt=True to tokenizer.apply_chat_template (if supported).",
)
parser.add_argument(
"--no-tokenizer-add-generation-prompt",
dest="tokenizer_add_generation_prompt",
action="store_false",
help="Pass add_generation_prompt=False to tokenizer.apply_chat_template (if supported).",
)
parser.set_defaults(tokenizer_add_generation_prompt=True)
parser.add_argument(
"--tokenizer-continue-final-message",
dest="tokenizer_continue_final_message",
action="store_true",
help="Pass continue_final_message=True to tokenizer.apply_chat_template (if supported).",
)
parser.add_argument(
"--no-tokenizer-continue-final-message",
dest="tokenizer_continue_final_message",
action="store_false",
help="Pass continue_final_message=False to tokenizer.apply_chat_template (if supported).",
)
parser.set_defaults(tokenizer_continue_final_message=False)
parser.add_argument(
"--skip-special-tokens",
dest="skip_special_tokens",
action="store_true",
help="Skip special tokens in output decoding (recommended).",
)
parser.add_argument(
"--no-skip-special-tokens",
dest="skip_special_tokens",
action="store_false",
help="Keep special tokens in output decoding (e.g. <|im_end|>).",
)
parser.set_defaults(skip_special_tokens=True)
parser.add_argument(
"--log-level",
type=str,
default="INFO",
choices=["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"],
)
return parser.parse_args()
def main() -> None:
args = _parse_args()
logging.basicConfig(
level=getattr(logging, args.log_level.upper()),
format="%(asctime)s - %(levelname)s - %(message)s",
)
attention_backend = AttentionBackend[args.attention_backend.upper()]
compression_method = CompressionMethod[args.compression_method.upper()]
model = args.model
cfg = LLMConfig(
model=model,
path=model,
tensor_parallel_size=int(args.tp),
nccl_port=int(args.nccl_port),
max_model_len=int(args.max_model_len),
max_num_seqs=int(args.max_num_seqs),
gpu_memory_utilization=float(args.gpu_memory_utilization),
enforce_eager=True,
attention_backend=attention_backend,
show_progress_bar=False,
)
llm = LLM(cfg)
tokenizer_kwargs = {
"add_generation_prompt": bool(args.tokenizer_add_generation_prompt),
"enable_thinking": bool(args.tokenizer_enable_thinking),
"continue_final_message": bool(args.tokenizer_continue_final_message),
}
if tokenizer_kwargs.get("add_generation_prompt") and tokenizer_kwargs.get(
"continue_final_message"
):
# HF tokenizer API rejects these being simultaneously True.
tokenizer_kwargs["continue_final_message"] = False
# Be defensive: only pass kwargs supported by this tokenizer build.
try:
supported = set(inspect.signature(llm.tokenizer.apply_chat_template).parameters)
tokenizer_kwargs = {k: v for k, v in tokenizer_kwargs.items() if k in supported}
except (TypeError, ValueError):
pass
outs = llm.generate_chat(
[[{"role": "user", "content": args.prompt}]],
sampling_params=SamplingParams(
temperature=float(args.temperature),
max_new_tokens=int(args.max_new_tokens),
),
batch_compression_params=BatchCompressionParams(
compression_method=compression_method,
do_chunked_compression=bool(args.do_chunked_compression),
chunk_size=int(args.chunk_size),
),
per_sequence_compression_params=SequenceCompressionParams(
compression_ratio=float(args.compression_ratio),
),
tokenizer_kwargs=tokenizer_kwargs,
detokenizer_kwargs={"skip_special_tokens": bool(args.skip_special_tokens)},
)
print(outs[0])
llm.exit()
if __name__ == "__main__":
main()
Package Version
---------------------------------- ------------------------------------------
accelerate 1.12.0
addict 2.4.0
aiofiles 25.1.0
aiohappyeyeballs 2.6.1
aiohttp 3.13.2
aiohttp-cors 0.8.1
aiosignal 1.4.0
airportsdata 20250909
amdsmi 24.5.3+02cbffb.dirty
annotated-doc 0.0.4
annotated-types 0.7.0
anyio 4.12.0
apex 1.5.0+das.opt1.dtk25042
astor 0.8.1
async-timeout 5.0.1
attrs 25.4.0
backports.asyncio.runner 1.2.0
blake3 1.0.8
blinker 1.9.0
boto3 1.42.10
botocore 1.42.10
cachetools 6.2.4
certifi 2025.11.12
charset-normalizer 3.4.4
click 8.2.1
cloudpickle 3.1.2
cmake 3.29.0
coloredlogs 15.0.1
colorful 0.5.8
compressed-tensors 0.10.2
contourpy 1.3.2
cryptography 3.4.8
cupy 12.3.0
cycler 0.12.1
datasets 4.4.1
dbus-python 1.2.18
dcu-megatron 0.13.0+das.opt1.dtk25042
deepspeed 0.15.4+das.opt1.dtk25042
depyf 0.18.0
dgl 2.2.1+das.opt1.dtk25042
dill 0.4.0
diskcache 5.6.3
distlib 0.4.0
distro 1.7.0
dnspython 2.8.0
dropout_layer_norm 2.6.1+das.opt1.dtk2504
eft 0.0.7
einops 0.8.1
email-validator 2.3.0
exceptiongroup 1.3.1
fastapi 0.124.4
fastapi-cli 0.0.16
fastapi-cloud-cli 0.6.0
fastar 0.8.0
fastpt 2.1.1+das.dtk25042
fastrlock 0.8.3
filelock 3.20.1
flash_attn 2.6.1+das.opt1.dtk2504.20251216.gbd5c0f0c
flash_mla 1.0.0+das.opt1.dtk2504.20251210.g124c5ef1
Flask 3.1.2
flatbuffers 25.9.23
fonttools 4.61.1
frozenlist 1.8.0
fsspec 2025.12.0
fused_dense_lib 2.6.1+das.opt1.dtk2504
future 1.0.0
gguf 0.17.1
google-api-core 2.28.1
google-auth 2.45.0
googleapis-common-protos 1.72.0
greenlet 3.3.0
grouped-gemm 0.5.0+das.dtk2504
grouped-gemm-int4 0.5.0+das.dtk2504
grpcio 1.76.0
h11 0.16.0
h2 4.3.0
hf-xet 1.2.0
hiredis 3.3.0
hjson 3.1.0
hpack 4.1.0
httpcore 1.0.9
httplib2 0.20.2
httptools 0.7.1
httpx 0.28.1
huggingface-hub 0.36.0
humanfriendly 10.0
humanize 4.14.0
Hypercorn 0.18.0
hyperframe 6.1.0
hypothesis 5.35.1
idna 3.11
importlib_metadata 8.7.0
iniconfig 2.3.0
interegular 0.3.3
itsdangerous 2.2.0
jeepney 0.7.1
Jinja2 3.1.6
jiter 0.12.0
jmespath 1.0.1
jsonschema 4.25.1
jsonschema-specifications 2025.9.1
keyring 23.5.0
kiwisolver 1.4.9
lark 1.2.2
launchpadlib 1.10.16
lazr.restfulclient 0.14.4
lazr.uri 1.0.6
libnacl 2.1.0
lightop 0.6.0+das.dtk25042.20251216.g3830d4e2
llguidance 0.7.30
llvmlite 0.44.0
lm-format-enforcer 0.10.12
lmslim 0.3.1+das.opt1.dtk25042.20251202.g07a5af3e
markdown-it-py 4.0.0
MarkupSafe 3.0.3
matplotlib 3.10.8
mdurl 0.1.2
megatron-core 0.13.2
mistral_common 1.8.6
mmcv 2.2.0+das.opt1.dtk25042
mmengine 0.10.7
moe-w8a8 0.0.1+das.dtk2504
moe-w8a8-prefill-gemm 0.0.1+das.dtk2504
more-itertools 8.10.0
mpmath 1.3.0
msgpack 1.1.2
msgspec 0.20.0
multidict 6.7.0
multiprocess 0.70.18
nest-asyncio 1.6.0
networkx 3.4.2
ninja 1.11.1
numa 1.4.6
numba 0.61.2
numpy 1.25.0
nvidia-cublas-cu12 12.4.5.8
nvidia-cuda-cupti-cu12 12.4.127
nvidia-cuda-nvrtc-cu12 12.4.127
nvidia-cuda-runtime-cu12 12.4.127
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.2.1.3
nvidia-curand-cu12 10.3.5.147
nvidia-cusolver-cu12 11.6.1.9
nvidia-cusparse-cu12 12.3.1.170
nvidia-nccl-cu12 2.21.5
nvidia-nvjitlink-cu12 12.4.127
nvidia-nvtx-cu12 12.4.127
oauthlib 3.2.0
onnxruntime 1.19.2+das.opt1.dtk25042
openai 1.90.0
opencensus 0.11.4
opencensus-context 0.1.3
opencv-python 4.12.0.88
opencv-python-headless 4.12.0.88
opentelemetry-api 1.39.1
opentelemetry-exporter-prometheus 0.60b1
opentelemetry-proto 1.39.1
opentelemetry-sdk 1.39.1
opentelemetry-semantic-conventions 0.60b1
outlines 0.1.11
outlines_core 0.1.26
packaging 25.0
pandas 2.3.3
partial-json-parser 0.2.1.1.post7
peft 0.18.0
pillow 12.0.0
pip 25.3
platformdirs 4.5.1
pluggy 1.6.0
priority 2.0.0
prometheus_client 0.23.1
prometheus-fastapi-instrumentator 7.1.0
propcache 0.4.1
proto-plus 1.26.1
protobuf 6.33.2
psutil 7.1.3
py-cpuinfo 9.0.0
py-spy 0.4.1
pyarrow 22.0.0
pyasn1 0.6.1
pyasn1_modules 0.4.2
pybase64 1.4.3
pycountry 24.6.1
pydantic 2.12.5
pydantic_core 2.41.5
pydantic-extra-types 2.10.6
Pygments 2.19.2
PyGObject 3.42.1
PyHive 0.7.0
PyJWT 2.3.0
PyMySQL 1.1.2
pyparsing 3.2.5
pytest 9.0.2
pytest-asyncio 1.3.0
python-apt 2.4.0+ubuntu4
python-dateutil 2.9.0.post0
python-dotenv 1.2.1
python-json-logger 4.0.0
python-multipart 0.0.20
PyTrie 0.4.0
pytz 2025.2
PyYAML 6.0.3
pyzmq 27.1.0
Quart 0.20.0
ray 2.48.0
redis 7.1.0
referencing 0.37.0
regex 2025.11.3
requests 2.32.5
rich 14.2.0
rich-toolkit 0.17.0
rignore 0.7.6
rotary_emb 2.6.1+das.opt1.dtk2504
rpds-py 0.30.0
rsa 4.9.1
runai-model-streamer 0.11.0
runai-model-streamer-s3 0.11.0
s3transfer 0.16.0
safetensors 0.7.0
scipy 1.15.3
SecretStorage 3.3.1
sentencepiece 0.2.1
sentry-sdk 2.47.0
setuptools 80.8.0
setuptools-scm 9.2.2
shellingham 1.5.4
six 1.16.0
smart_open 7.5.0
sniffio 1.3.1
sortedcontainers 2.4.0
SQLAlchemy 2.0.45
starlette 0.50.0
sympy 1.13.1
taskgroup 0.2.2
tensorboardX 2.6.4
tensorizer 2.12.0
termcolor 3.2.0
threadpoolctl 3.6.0
tiktoken 0.12.0
tokenizers 0.22.1
tomli 2.3.0
torch 2.5.1+das.opt1.dtk25042
torchaudio 2.5.1+das.opt1.dtk25042
torchdata 0.8.0
torchvision 0.20.1+das.opt1.dtk25042
tqdm 4.67.1
transformer_engine 2.5.0+das.opt1.dtk25042
transformers 4.57.3
triton 3.1+das.opt1.dtk25042
typer 0.20.0
typer-slim 0.20.0
typing_extensions 4.15.0
typing-inspection 0.4.2
tzdata 2025.3
urllib3 2.6.2
uvicorn 0.38.0
uvloop 0.22.1
virtualenv 20.35.4
vllm 0.9.2+das.opt2.ffcc47b.dtk25042
wadllib 1.3.6
watchfiles 1.1.1
websockets 15.0.1
Werkzeug 3.1.4
wheel 0.37.1
wrapt 2.0.1
wsproto 1.3.2
xentropy_cuda_lib 2.6.1+das.opt1.dtk2504
xgrammar 0.1.19
xxhash 3.6.0
yapf 0.43.0
yarl 1.22.0
zipp 3.23.0
[project]
name = "compactor-vllm"
description = "Fast KV Cache Compression for LLMs"
version = "0.0.1"
dependencies = [
# "triton>=3.5.0",
"transformers",
# "torch>=2.9.0",
"safetensors",
"tqdm",
"flash-attn",
"pytest"
]
requires-python = ">= 3.8"
authors = [
{name = "Vivek Chari", email = "viveknchari@gmail.com"},
]
[project.optional-dependencies]
evaluate = ["rouge", "pandas", "fuzzywuzzy"]
[tool.ruff]
exclude = [
"triton_kernels"
]
from compactor_vllm.compression import CompressionMethod
from compactor_vllm.config.engine_config import AttentionBackend, LLMConfig
from compactor_vllm.config.sampling_params import SamplingParams
from compactor_vllm.core.llm_engine import LLMEngine as _LLMEngine
class LLM(_LLMEngine):
pass
__all__ = [
"LLMConfig",
"LLM",
"SamplingParams",
"AttentionBackend",
"CompressionMethod",
]
import argparse
import logging
import math
import torch
from compactor_vllm.attention.sparse_varlen_kernel import (
causal_sparse_varlen_with_cache,
)
logger = logging.getLogger(__name__)
def build_mock_paged_cache_from_lengths(
L_cache_per_b: torch.Tensor,
HKV: int,
D: int,
PAGE_SIZE: int,
N_LOGICAL_PAGES_MAX: int,
device,
dtype,
):
B = len(L_cache_per_b)
max_len = PAGE_SIZE * N_LOGICAL_PAGES_MAX
assert (L_cache_per_b <= max_len).all()
seq_lens_bh = torch.empty((B, HKV), dtype=torch.int32, device=device)
for b in range(B):
seq_lens_bh[b, :].fill_(L_cache_per_b[b])
num_phys_pages = B * HKV * N_LOGICAL_PAGES_MAX
CACHE_SIZE = num_phys_pages * PAGE_SIZE
K_cache = torch.zeros((CACHE_SIZE, D), device=device, dtype=dtype)
V_cache = torch.zeros((CACHE_SIZE, D), device=device, dtype=dtype)
page_table = torch.empty(
(B, HKV, N_LOGICAL_PAGES_MAX), device=device, dtype=torch.int32
)
# assign unique physical pages per (b, h, lp)
phys_page = 0
for b in range(B):
for h in range(HKV):
for lp in range(N_LOGICAL_PAGES_MAX):
page_table[b, h, lp] = phys_page
phys_page += 1
for b in range(B):
Lc = int(L_cache_per_b[b].item())
for h in range(HKV):
for i in range(Lc):
lp = i // PAGE_SIZE
off = i % PAGE_SIZE
phys = int(page_table[b, h, lp].item())
idx = phys * PAGE_SIZE + off
K_cache[idx] = torch.randn(D, device=device, dtype=dtype)
V_cache[idx] = torch.randn(D, device=device, dtype=dtype)
return K_cache, V_cache, page_table, seq_lens_bh, CACHE_SIZE
def autotune_causal_sparse_varlen_with_cache(
*,
max_length: int = 16384,
HKV: int = 8,
HQ: int = 32,
D: int = 128,
PAGE_SIZE: int = 128,
device: str = "cuda",
dtype=torch.float16,
):
"""
Autotune causal_sparse_varlen_with_cache over a sweep of cache/append lengths.
"""
import itertools
import tqdm
N_LOGICAL_PAGES_MAX = ((max_length + PAGE_SIZE - 1) // PAGE_SIZE) * PAGE_SIZE
B = 4
# D must be a power of two (kernel requirement).
assert (D & (D - 1)) == 0
lengths_to_sweep = [0, 256]
i = 9
while (v := (1 << i)) < max_length:
lengths_to_sweep.append(v)
i += 1
combos = list(itertools.product(lengths_to_sweep, repeat=2))
logger.info(
"tuning kernels. this may take a few minutes, "
"but only needs to be run once per LLMConfig"
)
for cache_l, append_l in tqdm.tqdm(combos):
if cache_l + append_l == 0:
continue
L_cache_per_b = torch.tensor(
[cache_l] * B,
device=device,
dtype=torch.int32,
)
assert (L_cache_per_b <= PAGE_SIZE * N_LOGICAL_PAGES_MAX).all()
K_cache, V_cache, page_table, seq_lens_bh, CACHE_SIZE = (
build_mock_paged_cache_from_lengths(
L_cache_per_b=L_cache_per_b,
HKV=HKV,
D=D,
PAGE_SIZE=PAGE_SIZE,
N_LOGICAL_PAGES_MAX=N_LOGICAL_PAGES_MAX,
device=device,
dtype=dtype,
)
)
L_app_list = [append_l] * B
cu = [0]
for L in L_app_list:
cu.append(cu[-1] + L)
cu_seqlens_qk = torch.tensor(cu, dtype=torch.int32, device=device)
N = int(cu_seqlens_qk[-1].item())
max_seqlen_q = int((cu_seqlens_qk[1:] - cu_seqlens_qk[:-1]).max().item())
max_seqlen_k = seq_lens_bh.max().item()
q_raw = torch.randn(N, HQ, D, device=device, dtype=dtype)
k_append_raw = torch.randn(N, HKV, D, device=device, dtype=dtype)
v_append_raw = torch.randn(N, HKV, D, device=device, dtype=dtype)
# Identity batch mapping (local batch index == global)
batch_mapping = torch.arange(B, device=device, dtype=torch.int32)
sm_scale = 1.0 / math.sqrt(D)
causal_sparse_varlen_with_cache(
q=q_raw,
k_cache=K_cache,
v_cache=V_cache,
k=k_append_raw,
v=v_append_raw,
seq_lens_bh=seq_lens_bh,
global_page_table=page_table,
batch_mapping=batch_mapping,
cu_seqlens_q=cu_seqlens_qk,
HKV=HKV,
PAGE_SIZE=PAGE_SIZE,
sm_scale=sm_scale,
max_seqlen_q=max_seqlen_q,
max_seqlen_k_cache=max_seqlen_k,
)
def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Autotune Triton kernels. "
"Results are cached, so this should only need to be run once per configuration."
"This script doesn't need to be run, as the kernels will be autotuned at runtime"
"if no cached autotuning data exists. Running this before hand will prevent run-time"
"autotuning, which will accelerate compactor-vllm at inference time."
)
parser.add_argument(
"--max-length",
type=int,
default=16384,
help="Maximum total sequence length to consider.",
)
parser.add_argument(
"--HKV",
type=int,
default=8,
help="Number of KV heads.",
)
parser.add_argument(
"--HQ",
type=int,
default=32,
help="Number of query heads.",
)
parser.add_argument(
"--D",
type=int,
default=128,
help="Per-head hidden dimension (must be power of 2).",
)
parser.add_argument(
"--page-size",
type=int,
default=128,
help="Page size (tokens per physical page).",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Torch device to run on (e.g. 'cuda', 'cuda:0', 'cpu').",
)
parser.add_argument(
"--dtype",
type=str,
default="float16",
help="Dtype for tensors: one of {float16, fp16, bfloat16, bf16, float32, fp32}.",
)
parser.add_argument(
"--log-level",
type=str,
default="INFO",
choices=["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"],
help="Logging level.",
)
return parser.parse_args()
def _resolve_dtype(dtype_str: str):
s = dtype_str.lower()
if s in ("float16", "fp16", "half"):
return torch.float16
if s in ("bfloat16", "bf16"):
return torch.bfloat16
if s in ("float32", "fp32"):
return torch.float32
raise ValueError(f"Unsupported dtype: {dtype_str}")
def main():
args = _parse_args()
logging.basicConfig(
level=getattr(logging, args.log_level.upper()),
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
dtype = _resolve_dtype(args.dtype)
logger.info(
"Starting autotune with max_length=%d, HKV=%d, HQ=%d, D=%d, page_size=%d, "
"device=%s, dtype=%s",
args.max_length,
args.HKV,
args.HQ,
args.D,
args.page_size,
args.device,
dtype,
)
autotune_causal_sparse_varlen_with_cache(
max_length=args.max_length,
HKV=args.HKV,
HQ=args.HQ,
D=args.D,
PAGE_SIZE=args.page_size,
device=args.device,
dtype=dtype,
)
if __name__ == "__main__":
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s: %(message)s",
)
main()
import functools
import math
import torch
import triton
import triton.language as tl
from compactor_vllm.utils.triton_compat import (
autotune as triton_autotune,
maybe_set_allocator,
)
def head_sparse_decode_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seq_lens_bh: torch.Tensor,
global_page_table: torch.Tensor,
batch_mapping: torch.Tensor,
HKV: int,
PAGE_SIZE: int,
sm_scale: float = None,
key_split: int = None,
):
"""
Decode-time head-sparse attention over a paged KV cache.
This is a wrapper around the Triton decode kernel used during incremental
generation. For each batch, we read the cached keys
and values from a global paged KV buffer, apply causal attention with one
new query token, and return the attention output.
The KV cache is stored in a single global K/V tensor of shape
``[CACHE_SIZE, D]`` and indexed via a per-layer page table. Each logical
(batch, kv_head, token_idx) is mapped to a physical row in the cache by:
1. Looking up the logical page index in ``global_page_table[b, h, lp]``,
2. Computing ``phys_row = page_id * PAGE_SIZE + (token_idx % PAGE_SIZE)``.
Grouped-query attention (GQA / MQA) is supported by passing more query
heads than KV heads (``HQ`` must be a multiple of ``HKV``).
Args:
:param q: Query tensor of shape ``[B, HQ, D]`` or `[B, 1, HQ, D]``
containing the new decode tokens for each sequence in the launch batch.
:param k: Global key cache of shape ``[CACHE_SIZE, D]``. This is the shared
backing buffer for all (batch, head) KV pages.
:param v: Global value cache of shape ``[CACHE_SIZE, D]``.
:param seq_lens_bh: Tensor of shape ``[B, HKV]`` (int32) giving, for each
local batch index and KV head, the number of valid cached tokens
in the paged KV cache.
:param global_page_table: Tensor of shape
``[MAX_NUM_BATCHES, HKV, N_LOGICAL_PAGES_MAX]`` (int32) mapping
``(true_batch_idx, kv_head, logical_page)`` to a physical page id
in the global cache.
:param batch_mapping: Tensor of shape ``[B]`` (int32) mapping the launch-batch
index used by this call to the true batch row used to index
``global_page_table``.
:param HKV: Number of KV heads.
:param PAGE_SIZE: Number of tokens stored per physical KV page.
:param sm_scale: Optional scaling factor applied to the attention logits
before softmax. If ``None``, ``1 / sqrt(D)`` is used.
:param key_split: Optional number of splits along the key sequence length.
If > 1, the kernel will process the KV sequence in ``key_split``
chunks to reduce on-chip memory usage. If ``None`` or 0, a
heuristic is used.
Returns:
:return torch.Tensor: Attention output of shape ``[B, HQ, D]`` on the same
device and dtype as ``q``.
"""
with torch.cuda.device(q.device):
if q.ndim != 3:
assert q.ndim == 4
B, HQ, S, D = q.shape
assert S == 1, "head_sparse_decode_attention only supports q_len=1"
q = q.squeeze(-2)
elif q.ndim == 3:
B, HQ, D = q.shape
CACHE_SIZE = k.shape[0]
assert PAGE_SIZE % 32 == 0, "PAGE_SIZE must be divisible by 128"
GROUP_M = HQ // HKV
assert GROUP_M * HKV == HQ, "HQ must be divisible by H_kv"
FP8 = hasattr(torch, "float8_e5m2") and q.dtype == torch.float8_e5m2
seq_lens_bh = seq_lens_bh.to(torch.int32)
assert B <= 32767, "too many batches"
assert global_page_table.shape[1] == HKV
assert q.is_contiguous()
assert (D & (D - 1)) == 0, "D must be a power of 2"
N_LOGICAL_PAGES_MAX = global_page_table.shape[-1]
sm_scale = 1 / math.sqrt(D) if sm_scale is None else sm_scale
if key_split is None:
# round max_seq_len to the next power of two to maximize cache hits
key_split = num_splits_heuristic(
B * HKV,
max_seq_len=1 << int(seq_lens_bh.max()).bit_length(),
num_sms=torch.cuda.get_device_properties(
q.device
).multi_processor_count,
max_splits=12,
)
maybe_set_allocator(
lambda size, align, _: torch.empty(size, dtype=torch.int8, device=q.device)
)
# stage 1 scratch
mid_o = torch.empty((B, key_split, HQ, D), device=q.device, dtype=q.dtype)
mid_lse = torch.empty((B, key_split, HQ), device=q.device, dtype=torch.float32)
# processes all queries for a KV head together
# pointers are lowercase, CONSTANTS are upper
grid1 = (B, HKV, key_split)
_varkv_stage1_groupM[grid1](
q=q,
k=k,
v=v,
mid_o=mid_o,
mid_lse=mid_lse,
page_table_bhl=global_page_table,
batch_mapping=batch_mapping,
seq_lens_bh=seq_lens_bh.contiguous(),
SM_SCALE=sm_scale,
B=B,
HKV=HKV,
HQ=HQ,
CACHE_SIZE=CACHE_SIZE,
STRIDE_LBS=mid_lse.stride(0),
STRIDE_LS=mid_lse.stride(1),
STRIDE_LH=mid_lse.stride(2),
N_LOGICAL_PAGES_MAX=N_LOGICAL_PAGES_MAX,
D=D,
KEY_SPLIT=key_split,
GROUP_M=GROUP_M,
DTYPE=tl.float8e5
if FP8
else (tl.bfloat16 if q.dtype == torch.bfloat16 else tl.float16),
PAGE_SIZE=PAGE_SIZE,
)
if key_split == 1:
return mid_o.squeeze(1).contiguous()
# reduce partial results across splits
output = torch.empty_like(q)
grid2 = (B, HQ)
_varkv_stage2_reduce[grid2](
mid_o=mid_o,
mid_lse=mid_lse,
output=output,
STRIDE_LBS=mid_lse.stride(0),
STRIDE_LS=mid_lse.stride(1),
STRIDE_LH=mid_lse.stride(2),
STRIDE_OBS=output.stride(0),
STRIDE_OH=output.stride(1),
B=B,
HQ=HQ,
D=D, # type: ignore
KEY_SPLIT=key_split, # type: ignore
DTYPE=tl.float8e5
if FP8
else (tl.bfloat16 if q.dtype == torch.bfloat16 else tl.float16),
)
return output
# similar to flash attention split heuristic
@functools.lru_cache(maxsize=128)
def num_splits_heuristic(
total_mblocks: int,
max_seq_len: int,
num_sms: int,
max_splits: int,
) -> int:
# If we nearly fill SMs already, prefer 1 split
if total_mblocks >= 0.8 * num_sms or max_seq_len <= 1024:
return 1
eff = []
max_eff = 0.0
for s in range(1, min(max_splits, num_sms) + 1):
if (max_seq_len / s) <= 512:
break
n_waves = float(total_mblocks * s) / float(num_sms)
e = n_waves / math.ceil(n_waves) if n_waves > 0 else 0.0
eff.append(e)
max_eff = max(max_eff, e)
threshold = 0.75 * max_eff # if not split_min_hit else 0.9 * max_eff
for i, e in enumerate(eff, start=1):
if e >= threshold:
return i
return 1
def prune_invalid_configs(configs, _, **kwargs):
PAGE_SIZE = kwargs["PAGE_SIZE"]
return [conf for conf in configs if conf.kwargs.get("BLOCK_N", 0) <= PAGE_SIZE]
@triton_autotune(
configs=[
triton.Config(
{"BLOCK_N": BLOCK_N, "MIN_BLOCK_KV": MIN_BLOCK_KV, "WARPSPEC": ws},
num_warps=w,
num_stages=s,
)
for BLOCK_N in [32, 64, 128]
for MIN_BLOCK_KV in [8]
for s in [2, 3, 4]
for w in [4, 8]
for ws in [True, False]
],
key=[
"HKV",
"GROUP_M",
"D",
"PAGE_SIZE", # "B"
],
cache_results=True,
prune_configs_by={"early_config_prune": prune_invalid_configs},
)
@triton.jit
def _varkv_stage1_groupM(
q, # [B, HQ, D] contiguous
k, # GLOBAL cache: [CACHE_SIZE, D], contiguous
v, # GLOBAL cache: [CACHE_SIZE, D], contiguous
mid_o,
mid_lse,
page_table_bhl, # int32 [B*H_kv*N_LOGICAL_PAGES_MAX] (flattened)
batch_mapping, # int32 [B] maps local pid_b -> true batch index
seq_lens_bh, # int32 [B*H_kv] valid tokens per (b,h)
SM_SCALE,
B,
HKV,
HQ,
CACHE_SIZE, # CACHE_SIZE = N_PAGES * PAGE_SIZE
STRIDE_LBS,
STRIDE_LS,
STRIDE_LH,
# constexprs
N_LOGICAL_PAGES_MAX: tl.constexpr, # page table width per (b,h)
D: tl.constexpr,
KEY_SPLIT: tl.constexpr,
GROUP_M: tl.constexpr,
DTYPE: tl.constexpr,
BLOCK_N: tl.constexpr,
MIN_BLOCK_KV: tl.constexpr,
WARPSPEC: tl.constexpr,
PAGE_SIZE: tl.constexpr,
):
pid_b = tl.program_id(0) # batch
pid_kvh = tl.program_id(1) # kv head
pid_s = tl.program_id(2) # split
# valid length L for this (b,h)
bh_stride = HKV
L = tl.load(seq_lens_bh + pid_b * bh_stride + pid_kvh)
if L == 0:
return
tl.assume(L > 0)
# split sizing on logical token axis [0..L)
base = tl.cdiv(L, KEY_SPLIT)
per_split_len = tl.cdiv(base, MIN_BLOCK_KV) * MIN_BLOCK_KV
split_start = pid_s * per_split_len
split_end = tl.minimum(split_start + per_split_len, L)
# query heads mapped to this kv head
base_qh = pid_kvh * GROUP_M
GROUP_M_PAD: tl.constexpr = 16 if GROUP_M < 16 else GROUP_M
offs_m = tl.arange(0, GROUP_M_PAD)
mask_m = offs_m < GROUP_M
offs_d = tl.arange(0, D)
# load Q tile [M, D]
q_ptrs = q + (pid_b * HQ + base_qh + offs_m)[:, None] * D + offs_d[None, :]
q = tl.load(q_ptrs, mask=mask_m[:, None], other=0.0).to(DTYPE) # [M, D]
# streaming softmax state per query
e_max = tl.zeros([GROUP_M_PAD], dtype=tl.float32) - float("inf")
e_sum = tl.zeros([GROUP_M_PAD], dtype=tl.float32)
acc = tl.zeros([GROUP_M_PAD, D], dtype=tl.float32)
if split_end > split_start:
# logical pages covering [split_start, split_end)
lp0 = split_start // PAGE_SIZE
lp1 = tl.cdiv(split_end, PAGE_SIZE) # exclusive
mapped_b = tl.load(batch_mapping + pid_b)
tl.assume(mapped_b >= 0)
# page table base for this (b,h)
pt_stride = N_LOGICAL_PAGES_MAX
pt_base = (mapped_b * HKV + pid_kvh) * pt_stride
for lp in tl.range(lp0, lp1):
phys = tl.load(
page_table_bhl + pt_base + lp, cache_modifier=".cg"
) # physical page id
# bounds within the logical page
local_start = tl.where(lp == lp0, split_start - lp * PAGE_SIZE, 0)
local_end = tl.where(lp == (lp1 - 1), split_end - lp * PAGE_SIZE, PAGE_SIZE)
page_base = phys * PAGE_SIZE
page_base = tl.multiple_of(page_base, BLOCK_N)
for s in tl.range(local_start, local_end, BLOCK_N):
s = tl.multiple_of(s, MIN_BLOCK_KV)
offs_bn = tl.arange(0, BLOCK_N)
key_idx = page_base + s + offs_bn
k_ptrs = k + key_idx[:, None] * D + offs_d[None, :]
k_blk = tl.load(k_ptrs, mask=(key_idx < CACHE_SIZE)[:, None], other=0.0)
qk = tl.dot(q, k_blk.T) * SM_SCALE # [M, BN]
offs_n = s + tl.arange(0, BLOCK_N)
mask_n = offs_n < local_end
qk = tl.where(mask_n[None, :], qk, -float("inf"))
n_e_max = tl.maximum(tl.max(qk, 1), e_max) # [M]
re_scale = tl.exp(e_max - n_e_max) # [M]
acc = acc * re_scale[:, None] # [M, D]
v_ptrs = v + key_idx[:, None] * D + offs_d[None, :]
v_blk = tl.load(v_ptrs, mask=(key_idx < CACHE_SIZE)[:, None], other=0.0)
p = tl.exp(qk - n_e_max[:, None]) # [M, BN]
acc = tl.dot(p.to(DTYPE), v_blk, acc)
e_sum = e_sum * re_scale + tl.sum(p, 1)
e_max = n_e_max
# write mid outputs [M, D] for this split
tmp = (acc / e_sum[:, None]).to(DTYPE)
row_mid = pid_b * (KEY_SPLIT * HQ) + pid_s * HQ + base_qh + offs_m
mid_ptrs = mid_o + row_mid[:, None] * D + offs_d[None, :]
tl.store(mid_ptrs, tmp, mask=mask_m[:, None])
ml_ptrs = (
mid_lse
+ pid_b * STRIDE_LBS
+ pid_s * STRIDE_LS
+ (base_qh + offs_m) * STRIDE_LH
)
safe_sum = tl.where(mask_m, e_sum, 1.0)
tl.store(ml_ptrs, e_max + tl.log(safe_sum), mask=mask_m)
else:
# empty split
zero_md = tl.zeros([GROUP_M_PAD, D], dtype=DTYPE)
row_mid = pid_b * (KEY_SPLIT * HQ) + pid_s * HQ + base_qh + offs_m
mid_ptrs = mid_o + row_mid[:, None] * D + offs_d[None, :]
tl.store(mid_ptrs, zero_md, mask=mask_m[:, None])
ml_ptrs = (
mid_lse
+ pid_b * STRIDE_LBS
+ pid_s * STRIDE_LS
+ (base_qh + offs_m) * STRIDE_LH
)
tl.store(ml_ptrs, -float("inf"), mask=mask_m)
@triton.jit
def _varkv_stage2_reduce(
mid_o,
mid_lse,
output,
STRIDE_LBS,
STRIDE_LS,
STRIDE_LH,
STRIDE_OBS,
STRIDE_OH,
B,
HQ,
D: tl.constexpr,
KEY_SPLIT: tl.constexpr,
DTYPE: tl.constexpr,
):
pid_b = tl.program_id(0)
pid_h = tl.program_id(1)
offs_d = tl.arange(0, D)
# across split LSE combine
e_sum = 0.0
e_max = -float("inf")
acc = tl.zeros([D], dtype=tl.float32)
for s in tl.range(KEY_SPLIT):
row_mid = pid_b * (KEY_SPLIT * HQ) + s * HQ + pid_h
tv = tl.load(mid_o + row_mid * D + offs_d).to(DTYPE)
tl_ptr = mid_lse + pid_b * STRIDE_LBS + s * STRIDE_LS + pid_h * STRIDE_LH
tlogic = tl.load(tl_ptr)
n_e_max = tl.maximum(e_max, tlogic)
old_scale = tl.exp(e_max - n_e_max)
acc = acc * old_scale + tl.exp(tlogic - n_e_max) * tv.to(tl.float32)
e_sum = e_sum * old_scale + tl.exp(tlogic - n_e_max)
e_max = n_e_max
o = (acc / e_sum).to(DTYPE)
o_ptr = output + pid_b * STRIDE_OBS + pid_h * STRIDE_OH + offs_d
tl.store(o_ptr, o)
from compactor_vllm.compression.common import (
BaseCompressionMethod,
NoCompression,
)
from compactor_vllm.compression.criticalkv import CriticalAdaKVCompression
from compactor_vllm.compression.compactor import CompactorCompression
from compactor_vllm.compression.compression_config import (
BatchCompressionParams,
CompressionMethod,
SequenceCompressionParams,
)
from compactor_vllm.compression.snapkv import SnapKVCompression
COMPRESSION_REGISTRY: dict[CompressionMethod, type[BaseCompressionMethod]] = {
CompressionMethod.CRITICALADAKV: CriticalAdaKVCompression,
CompressionMethod.COMPACTOR: CompactorCompression,
CompressionMethod.SNAPKV: SnapKVCompression,
CompressionMethod.NONE: NoCompression,
}
def apply_prerope_compression(q, k, v, context):
method = context.compression_context.compression_method
return COMPRESSION_REGISTRY[method].pre_rope_scoring(q, k, v, context=context)
def apply_postrope_compression(q, k, v, prerope_scores, context):
method = context.compression_context.compression_method
return COMPRESSION_REGISTRY[method].post_rope_scoring(
q, k, v, prerope_scores, context=context
)
__all__ = [
"apply_prerope_compression",
"apply_postrope_compression",
"CompressionMethod",
"BatchCompressionParams",
"SequenceCompressionParams",
"COMPRESSION_REGISTRY"
]
This diff is collapsed.
import logging
from dataclasses import dataclass
from enum import Enum, auto
logger = logging.getLogger(__name__)
class CompressionMethod(Enum):
CRITICALADAKV = auto()
COMPACTOR = auto()
SNAPKV = auto()
NONE = auto()
# class CachingPolicy(Enum):
# CACHE_PROMPT = auto()
# DONT_CACHE = auto()
# class CompressionType(Enum):
# QUERY_AWARE = auto()
# QUERY_AGNOSTIC = auto()
@dataclass
class SequenceCompressionParams:
compression_ratio: float = 1.0
protected_first_tokens: int = 16
protected_last_tokens: int = 64
@dataclass
class BatchCompressionParams:
# compression_type: CompressionType = CompressionType.QUERY_AGNOSTIC
compression_method: CompressionMethod = CompressionMethod.COMPACTOR
do_chunked_compression: bool = True
chunk_size: int = 512
def __post_init__(self):
if self.compression_method == CompressionMethod.SNAPKV:
self.do_chunked_compression = False
logger.warning(
"CompressionMethod.SNAPKV is not compatible with chunked compression. Disabling it."
)
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment