Unverified Commit f4b42df0 authored by Vibhav Agarwal's avatar Vibhav Agarwal Committed by GitHub
Browse files

[Attention Backend] TurboQuant: 2-bit KV cache compression with 4x capacity (#38479)


Signed-off-by: default avatarvibhavagarwal5 <vibhavagarwal5@gmail.com>
Signed-off-by: default avatarMichael Goin <mgoin64@gmail.com>
Co-authored-by: default avatarXinyu Chen <xinyu1.chen@intel.com>
Co-authored-by: default avatarMichael Goin <mgoin64@gmail.com>
parent 3bfe55a0
......@@ -91,6 +91,16 @@ steps:
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=evals/gsm8k/configs/moe-refactor-dp-ep/config-b200.txt
- label: LM Eval TurboQuant KV Cache
timeout_in_minutes: 75
source_file_dependencies:
- vllm/model_executor/layers/quantization/turboquant/
- vllm/v1/attention/backends/turboquant_attn.py
- vllm/v1/attention/ops/triton_turboquant_decode.py
- vllm/v1/attention/ops/triton_turboquant_store.py
commands:
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=evals/gsm8k/configs/models-turboquant.txt
- label: GPQA Eval (GPT-OSS) (H100)
timeout_in_minutes: 120
device: h100
......
......@@ -178,6 +178,7 @@ Priority is **1 = highest** (tried first).
| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ❌ | ✅ | ❌ | Decoder, Encoder, Encoder Only | N/A |
| `TREE_ATTN` | | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any |
| `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2`, `int8_per_token_head`, `fp8_per_token_head` | %16 | Any | ✅ | ✅ | ❌ | All | Any |
| `TURBOQUANT` | | fp16, bf16 | `turboquant_k8v4`, `turboquant_4bit_nc`, `turboquant_k3v4_nc`, `turboquant_3bit_nc` | 16, 32, 64, 128 | Any | ❌ | ❌ | ❌ | Decoder | Any |
> **†** FlashInfer uses TRTLLM attention on Blackwell (SM100), which supports sinks. Disable via `--attention-config.use_trtllm_attention=0`.
>
......
......@@ -170,6 +170,9 @@ eles = "eles"
datas = "datas"
ser = "ser"
ure = "ure"
# Walsh-Hadamard Transform
wht = "wht"
WHT = "WHT"
[tool.uv]
no-build-isolation-package = ["torch"]
model_name: "Qwen/Qwen3-4B"
accuracy_threshold: 0.78
num_questions: 1319
num_fewshot: 5
server_args: "--kv-cache-dtype turboquant_k3v4_nc --enforce-eager --max-model-len 4096"
model_name: "Qwen/Qwen3-4B"
accuracy_threshold: 0.80
num_questions: 1319
num_fewshot: 5
server_args: "--kv-cache-dtype turboquant_k8v4 --enforce-eager --max-model-len 4096"
model_name: "Qwen/Qwen3-4B"
accuracy_threshold: 0.75
num_questions: 1319
num_fewshot: 5
server_args: "--kv-cache-dtype turboquant_3bit_nc --enforce-eager --max-model-len 4096"
model_name: "Qwen/Qwen3-4B"
accuracy_threshold: 0.80
num_questions: 1319
num_fewshot: 5
server_args: "--kv-cache-dtype turboquant_4bit_nc --enforce-eager --max-model-len 4096"
Qwen3-4B-TQ-k8v4.yaml
Qwen3-4B-TQ-t4nc.yaml
Qwen3-4B-TQ-k3v4nc.yaml
Qwen3-4B-TQ-t3nc.yaml
This diff is collapsed.
......@@ -27,6 +27,11 @@ class AttentionConfig:
flash_attn_max_num_splits_for_cuda_graph: int = 32
"""Flash Attention max number splits for cuda graph decode."""
tq_max_kv_splits_for_cuda_graph: int = 32
"""TurboQuant max NUM_KV_SPLITS for cuda graph decode.
Fixes the split count so grid dimensions are constant across captures,
and buffers can be pre-allocated to avoid inflating the memory estimate."""
use_cudnn_prefill: bool = False
"""Whether to use cudnn prefill."""
......
......@@ -24,6 +24,10 @@ CacheDType = Literal[
"fp8_e5m2",
"fp8_inc",
"fp8_ds_mla",
"turboquant_k8v4",
"turboquant_4bit_nc",
"turboquant_k3v4_nc",
"turboquant_3bit_nc",
"int8_per_token_head",
"fp8_per_token_head",
]
......
......@@ -1642,6 +1642,31 @@ class EngineArgs:
kv_offloading_backend=self.kv_offloading_backend,
)
# TurboQuant: auto-skip first/last 2 layers (boundary protection).
# These layers are most sensitive to quantization error.
# Users can add extra layers via --kv-cache-dtype-skip-layers.
if resolved_cache_dtype.startswith("turboquant_"):
if model_config.is_hybrid:
raise NotImplementedError(
"TurboQuant KV cache is not supported for hybrid "
"(attention + Mamba) models. Boundary layer protection "
"requires uniform attention layers."
)
from vllm.model_executor.layers.quantization.turboquant.config import (
TurboQuantConfig,
)
num_layers = model_config.hf_text_config.num_hidden_layers
boundary = TurboQuantConfig.get_boundary_skip_layers(num_layers)
existing = set(cache_config.kv_cache_dtype_skip_layers)
merged = sorted(existing | set(boundary), key=lambda x: int(x))
cache_config.kv_cache_dtype_skip_layers = merged
logger.info(
"TQ: skipping layers %s for boundary protection (num_layers=%d)",
merged,
num_layers,
)
ray_runtime_env = None
if is_ray_initialized():
# Ray Serve LLM calls `create_engine_config` in the context
......@@ -1948,6 +1973,19 @@ class EngineArgs:
self.attention_backend
)
# TurboQuant requires FlashAttention 2 — FA3 boundary layers assert
# FlashAttentionImpl which fails with TurboQuantAttentionImpl.
if resolved_cache_dtype.startswith("turboquant_") and (
attention_config.flash_attn_version is None
or attention_config.flash_attn_version >= 3
):
logger.warning(
"TurboQuant is not yet compatible with FlashAttention >= 3. "
"Overriding flash_attn_version to 2. To silence this "
"warning, pass --attention-config.flash_attn_version=2"
)
attention_config.flash_attn_version = 2
# Mamba config overrides
mamba_config = copy.deepcopy(self.mamba_config)
# Convert string to enum if needed (CLI parsing returns a string)
......
......@@ -379,6 +379,10 @@ class Attention(nn.Module, AttentionLayerBase):
# Initialize KV cache quantization attributes
_init_kv_cache_quant(self, quant_config, prefix)
# Initialize TurboQuant buffers (Pi, S, centroids) if tq cache dtype
if kv_cache_dtype.startswith("turboquant_"):
self._init_turboquant_buffers(kv_cache_dtype, head_size, prefix)
# for attn backends supporting query quantization
self.query_quant = None
if (
......@@ -397,6 +401,67 @@ class Attention(nn.Module, AttentionLayerBase):
else GroupShape.PER_TENSOR,
)
def _init_turboquant_buffers(
self, cache_dtype: str, head_size: int, prefix: str
) -> None:
"""Initialize TurboQuant rotation/projection matrices and centroids."""
from vllm.model_executor.layers.quantization.turboquant.centroids import (
get_centroids,
)
from vllm.model_executor.layers.quantization.turboquant.config import (
TurboQuantConfig,
)
from vllm.model_executor.layers.quantization.turboquant.quantizer import (
generate_wht_signs,
)
tq_config = TurboQuantConfig.from_cache_dtype(cache_dtype, head_size)
# Each layer needs a unique rotation matrix so quantization errors
# don't correlate across layers. Stride must exceed max head_dim to
# ensure non-overlapping RNG streams between adjacent layers.
_TQ_LAYER_SEED_STRIDE = 1337
from vllm.model_executor.models.utils import extract_layer_index
layer_idx = extract_layer_index(prefix)
seed = tq_config.seed + layer_idx * _TQ_LAYER_SEED_STRIDE
self.register_buffer(
"_tq_signs",
generate_wht_signs(head_size, seed=seed),
)
self.register_buffer(
"_tq_centroids",
get_centroids(head_size, tq_config.centroid_bits),
)
self._tq_config = tq_config
# Pre-allocate decode intermediate buffers so model.to(device) moves
# them to GPU *before* the memory profiler runs. Without this the
# profiler gives all free memory to KV cache blocks and the first
# decode OOMs when these buffers are lazily allocated.
_vllm_cfg = get_current_vllm_config()
B = _vllm_cfg.scheduler_config.max_num_seqs
Hq = self.num_heads
S = _vllm_cfg.attention_config.tq_max_kv_splits_for_cuda_graph
D = head_size
self.register_buffer(
"_tq_mid_o_buf",
torch.empty(B, Hq, S, D + 1, dtype=torch.float32),
persistent=False,
)
self.register_buffer(
"_tq_output_buf",
torch.empty(B, Hq, D, dtype=torch.float32),
persistent=False,
)
self.register_buffer(
"_tq_lse_buf",
torch.empty(B, Hq, dtype=torch.float32),
persistent=False,
)
def forward(
self,
query: torch.Tensor,
......@@ -544,6 +609,23 @@ class Attention(nn.Module, AttentionLayerBase):
kv_quant_mode=quant_mode,
sliding_window=self.sliding_window,
)
elif self.kv_cache_dtype.startswith("turboquant_"):
from vllm.model_executor.layers.quantization.turboquant.config import (
TurboQuantConfig,
)
from vllm.v1.kv_cache_interface import TQFullAttentionSpec
tq_config = TurboQuantConfig.from_cache_dtype(
self.kv_cache_dtype, self.head_size
)
return TQFullAttentionSpec(
block_size=block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
head_size_v=self.head_size,
dtype=self.kv_cache_torch_dtype,
tq_slot_size=tq_config.slot_size_aligned,
)
else:
return FullAttentionSpec(
block_size=block_size,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TurboQuant: Near-optimal KV-cache quantization for vLLM.
PolarQuant compression: random rotation + per-coordinate Lloyd-Max
scalar quantization for keys, uniform quantization for values.
Reference: "TurboQuant: Online Vector Quantization with Near-optimal
Distortion Rate" (ICLR 2026), Zandieh et al.
"""
from vllm.model_executor.layers.quantization.turboquant.config import TurboQuantConfig
__all__ = ["TurboQuantConfig"]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Lloyd-Max optimal scalar quantizer for TurboQuant.
After rotating a d-dimensional unit vector by a random orthogonal matrix,
each coordinate approximately follows N(0, 1/d) for d >= 64.
We solve the Lloyd-Max conditions to find optimal centroids.
Based on: turboquant-pytorch/lloyd_max.py (Zandieh et al.)
"""
import math
from functools import lru_cache
import torch
def _gaussian_pdf(x: float, sigma2: float) -> float:
return (1.0 / math.sqrt(2 * math.pi * sigma2)) * math.exp(-x * x / (2 * sigma2))
def _trapz(f, a: float, b: float, n: int = 200) -> float:
"""Trapezoidal numerical integration (replaces scipy.integrate.quad)."""
h = (b - a) / n
result = 0.5 * (f(a) + f(b))
for i in range(1, n):
result += f(a + i * h)
return result * h
def solve_lloyd_max(
d: int,
bits: int,
max_iter: int = 200,
tol: float = 1e-10,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Solve Lloyd-Max optimal quantizer for N(0, 1/d) distribution.
Args:
d: Vector dimension (determines variance = 1/d).
bits: Number of quantization bits.
max_iter: Maximum Lloyd-Max iterations.
tol: Convergence tolerance.
Returns:
centroids: Sorted tensor of 2^bits optimal centroids.
boundaries: Sorted tensor of 2^bits - 1 decision boundaries.
"""
n_levels = 2**bits
sigma2 = 1.0 / d
sigma = math.sqrt(sigma2)
def pdf(x):
return _gaussian_pdf(x, sigma2)
lo, hi = -3.5 * sigma, 3.5 * sigma
centroids = [lo + (hi - lo) * (i + 0.5) / n_levels for i in range(n_levels)]
for _ in range(max_iter):
boundaries = [
(centroids[i] + centroids[i + 1]) / 2.0 for i in range(n_levels - 1)
]
edges = [lo * 3] + boundaries + [hi * 3]
new_centroids = []
for i in range(n_levels):
a, b = edges[i], edges[i + 1]
num = _trapz(lambda x: x * pdf(x), a, b)
den = _trapz(pdf, a, b)
new_centroids.append(num / den if den > 1e-15 else centroids[i])
if max(abs(new_centroids[i] - centroids[i]) for i in range(n_levels)) < tol:
break
centroids = new_centroids
boundaries = [(centroids[i] + centroids[i + 1]) / 2.0 for i in range(n_levels - 1)]
return (
torch.tensor(centroids, dtype=torch.float32),
torch.tensor(boundaries, dtype=torch.float32),
)
@lru_cache(maxsize=32)
def get_centroids(d: int, bits: int) -> torch.Tensor:
"""Get precomputed Lloyd-Max centroids (cached)."""
centroids, _ = solve_lloyd_max(d, bits)
return centroids
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TurboQuant configuration."""
import math
from dataclasses import dataclass
# Named TQ presets: each maps to frozen config parameters.
# key_quant_bits: 8 = FP8 keys, 3-4 = MSE (Lloyd-Max) quantized keys.
# value_quant_bits: 3-4 = uniform quantized values.
TQ_PRESETS: dict[str, dict] = {
"turboquant_k8v4": {
"key_quant_bits": 8,
"value_quant_bits": 4,
"norm_correction": False,
},
"turboquant_4bit_nc": {
"key_quant_bits": 4,
"value_quant_bits": 4,
"norm_correction": True,
},
"turboquant_k3v4_nc": {
"key_quant_bits": 3,
"value_quant_bits": 4,
"norm_correction": True,
},
"turboquant_3bit_nc": {
"key_quant_bits": 3,
"value_quant_bits": 3,
"norm_correction": True,
},
}
@dataclass
class TurboQuantConfig:
"""Configuration for TurboQuant KV-cache quantization.
Uses PolarQuant (WHT rotation + Lloyd-Max scalar quantization) for keys
and uniform quantization for values. QJL is intentionally omitted —
community consensus (5+ independent groups) found it hurts attention
quality by amplifying variance through softmax.
Named presets (use via --kv-cache-dtype):
turboquant_k8v4: FP8 keys + 4-bit values, 2.6x, +1.17% PPL
turboquant_4bit_nc: 4-bit MSE keys + 4-bit values + NC, 3.8x, +2.71%
turboquant_k3v4_nc: 3-bit MSE keys + 4-bit values + NC, ~3.5x, +10.63%
turboquant_3bit_nc: 3-bit MSE keys + 3-bit values + NC, 4.9x, +20.59%
Args:
head_dim: Attention head dimension (e.g. 64, 96, 128).
key_quant_bits: Bits for key quantization. 8 = FP8 keys (no
rotation/MSE). 3-4 = Lloyd-Max MSE quantized keys.
value_quant_bits: Bits per value dimension for uniform quantization.
3 = 8 levels, 4 = 16 levels (default).
seed: Base seed for deterministic random matrix generation.
Actual seed per layer = seed + layer_idx * 1337.
norm_correction: Re-normalize centroid vectors to unit norm before
inverse rotation during dequant. Fixes quantization-induced norm
distortion, improving PPL by ~0.8% at 4-bit.
"""
head_dim: int = 128
key_quant_bits: int = 3 # 3-4 = MSE keys, 8 = FP8 keys
value_quant_bits: int = 4 # 3-4 = uniform quantized values
seed: int = 42
norm_correction: bool = False
@property
def key_fp8(self) -> bool:
"""Whether keys are stored as FP8 — no rotation/quantization needed."""
return self.key_quant_bits == 8
@property
def mse_bits(self) -> int:
"""MSE quantizer bit-width (determines centroid count: 2^mse_bits).
For MSE key modes, equals key_quant_bits.
For FP8 key mode, falls back to value_quant_bits (centroids are still
needed for continuation-prefill dequant and decode kernel params).
"""
if self.key_fp8:
return self.value_quant_bits
return self.key_quant_bits
@property
def key_mse_bits(self) -> int:
"""MSE bits actually used for key quantization (0 if FP8 keys)."""
if self.key_fp8:
return 0
return self.key_quant_bits
@property
def centroid_bits(self) -> int:
"""Bits for centroid generation — always non-zero."""
return self.mse_bits
@property
def n_centroids(self) -> int:
return 2**self.mse_bits
@property
def key_packed_size(self) -> int:
"""Packed bytes for a single KEY vector.
FP8 mode (key_quant_bits=8):
head_dim bytes (1 byte per element, no overhead).
TQ mode:
- MSE indices: ceil(head_dim * key_mse_bits / 8) bytes
- vec_norm: 2 bytes (float16)
"""
if self.key_fp8:
return self.head_dim # 1 byte per element
mse_bytes = math.ceil(self.head_dim * self.key_mse_bits / 8)
norm_bytes = 2 # vec_norm fp16
return mse_bytes + norm_bytes
@property
def effective_value_quant_bits(self) -> int:
"""Actual bits used for value storage."""
return self.value_quant_bits
@property
def value_packed_size(self) -> int:
"""Packed bytes for a single VALUE vector.
Uniform quantization: ceil(head_dim * bits / 8) + 4 bytes (scale + zero fp16).
"""
data_bytes = math.ceil(self.head_dim * self.value_quant_bits / 8)
return data_bytes + 4 # +2 scale(fp16) +2 zero(fp16)
@property
def slot_size(self) -> int:
"""Total packed bytes per head per position (key + value combined).
Layout: [key_packed | value_packed]
"""
return self.key_packed_size + self.value_packed_size
@property
def slot_size_aligned(self) -> int:
"""Slot size rounded up to next even number.
Even-number is required so effective_head_size = slot_size_aligned // 2
is integral.
"""
s = self.slot_size
return s + (s % 2) # round up to even
@staticmethod
def get_boundary_skip_layers(num_layers: int, n: int = 2) -> list[str]:
"""Get layer indices to skip TQ compression (boundary protection).
Returns first N and last N layer indices as strings, suitable for
kv_cache_dtype_skip_layers.
"""
if n <= 0 or num_layers <= 0:
return []
n = min(n, num_layers // 2) # don't skip more than half
first = list(range(n))
last = list(range(num_layers - n, num_layers))
# Deduplicate (if num_layers <= 2*n)
indices = sorted(set(first + last))
return [str(i) for i in indices]
@staticmethod
def from_cache_dtype(cache_dtype: str, head_dim: int) -> "TurboQuantConfig":
"""Create config from a named preset.
Valid presets: turboquant_k8v4, turboquant_4bit_nc, etc.
"""
if cache_dtype not in TQ_PRESETS:
valid = ", ".join(TQ_PRESETS.keys())
raise ValueError(
f"Unknown TurboQuant cache dtype: {cache_dtype!r}. "
f"Valid presets: {valid}"
)
preset = TQ_PRESETS[cache_dtype]
return TurboQuantConfig(
head_dim=head_dim,
key_quant_bits=preset["key_quant_bits"],
value_quant_bits=preset["value_quant_bits"],
norm_correction=preset["norm_correction"],
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TurboQuant quantizer utilities.
Serving path uses generate_wht_signs() for WHT rotation sign buffers.
Triton kernels handle all quantization, packing, and dequantization on GPU.
"""
import torch
_CPU = torch.device("cpu")
def generate_wht_signs(d: int, seed: int, device: torch.device = _CPU) -> torch.Tensor:
"""Generate deterministic random ±1 signs for WHT rotation.
Used with Walsh-Hadamard Transform for per-layer rotation randomization.
Same seed derivation as QR (per-layer via seed + layer_idx * stride).
"""
gen = torch.Generator(device="cpu")
gen.manual_seed(seed)
bits = torch.randint(0, 2, (d,), generator=gen, device="cpu")
signs = bits.float() * 2 - 1
return signs.to(device)
......@@ -255,6 +255,11 @@ class CudaPlatformBase(Platform):
valid_backends_priorities = []
invalid_reasons: dict[AttentionBackendEnum, tuple[int, list[str]]] = {}
# TurboQuant KV cache: route directly to TQ backend
kv_cache_dtype = attn_selector_config.kv_cache_dtype
if kv_cache_dtype is not None and kv_cache_dtype.startswith("turboquant_"):
return [(AttentionBackendEnum.TURBOQUANT, 0)], {}
backend_priorities = _get_backend_priorities(
attn_selector_config.use_mla,
device_capability,
......
......@@ -61,6 +61,12 @@ class XPUPlatform(Platform):
"only NHD layout is supported by XPU attention kernels."
)
# TurboQuant KV cache: route directly to TQ backend
kv_cache_dtype = attn_selector_config.kv_cache_dtype
if kv_cache_dtype is not None and kv_cache_dtype.startswith("turboquant_"):
logger.info_once("Using TurboQuant attention backend.")
return AttentionBackendEnum.TURBOQUANT.get_path()
dtype = attn_selector_config.dtype
if attn_selector_config.use_sparse:
logger.info_once("Using XPU MLA Sparse backend.")
......
......@@ -42,6 +42,10 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"fp8_per_token_head": torch.uint8,
"fp8_inc": torch.float8_e4m3fn,
"fp8_ds_mla": torch.uint8,
"turboquant_k8v4": torch.uint8,
"turboquant_4bit_nc": torch.uint8,
"turboquant_k3v4_nc": torch.uint8,
"turboquant_3bit_nc": torch.uint8,
}
TORCH_DTYPE_TO_NUMPY_DTYPE = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment