Unverified Commit ed0622e3 authored by Dan Alistarh's avatar Dan Alistarh Committed by GitHub
Browse files

[Attention] TurboQuant: remove redundant random signs, add prior art attribution (#40194)


Signed-off-by: default avatarDan Alistarh <d.alistarh@gmail.com>
parent b5f6c5f8
......@@ -18,9 +18,6 @@ from vllm.model_executor.layers.quantization.turboquant.config import (
TQ_PRESETS,
TurboQuantConfig,
)
from vllm.model_executor.layers.quantization.turboquant.quantizer import (
generate_wht_signs,
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import next_power_of_2
......@@ -393,7 +390,7 @@ class TestRotationMatrix:
# ============================================================================
# WHT rotation tests (serving path: generate_wht_signs + _build_hadamard)
# Hadamard rotation tests (serving path: _build_hadamard)
# ============================================================================
......@@ -406,50 +403,26 @@ def _build_hadamard(d: int, device: str = "cpu") -> torch.Tensor:
@pytest.mark.skipif(not GPGPU_AVAILABLE, reason="GPGPU not available")
class TestWHTRotation:
"""Tests for the WHT rotation actually used in serving."""
class TestHadamardRotation:
"""Tests for the Hadamard rotation used in serving."""
@pytest.mark.parametrize("dim", [64, 128, 256])
def test_wht_orthonormal(self, dim):
"""signs * H must be orthonormal: (signs*H) @ (signs*H)^T = I."""
signs = generate_wht_signs(dim, seed=42, device=DEVICE_TYPE)
def test_hadamard_orthonormal(self, dim):
"""H must be orthonormal: H @ H^T = I."""
H = _build_hadamard(dim, DEVICE_TYPE)
PiT = (signs.unsqueeze(1) * H).contiguous()
eye = PiT @ PiT.T
eye = H @ H.T
assert torch.allclose(eye, torch.eye(dim, device=DEVICE_TYPE), atol=1e-5), (
f"WHT rotation not orthonormal for dim={dim}"
f"Hadamard not orthonormal for dim={dim}"
)
@pytest.mark.parametrize("dim", [64, 128, 256])
def test_wht_self_inverse(self, dim):
"""PiT should be self-inverse: PiT @ PiT = I (up to sign flip)."""
signs = generate_wht_signs(dim, seed=42, device=DEVICE_TYPE)
def test_hadamard_symmetric(self, dim):
"""Sylvester Hadamard must be symmetric: H = H^T."""
H = _build_hadamard(dim, DEVICE_TYPE)
PiT = (signs.unsqueeze(1) * H).contiguous()
Pi = PiT.T.contiguous()
# Pi @ PiT should be identity (rotation then inverse)
result = Pi @ PiT
assert torch.allclose(result, torch.eye(dim, device=DEVICE_TYPE), atol=1e-5), (
f"WHT rotation not self-inverse for dim={dim}"
assert torch.allclose(H, H.T, atol=1e-6), (
f"Hadamard not symmetric for dim={dim}"
)
def test_wht_signs_deterministic(self):
"""Same seed must produce identical signs."""
s1 = generate_wht_signs(128, seed=42)
s2 = generate_wht_signs(128, seed=42)
assert torch.equal(s1, s2)
def test_wht_signs_different_seeds(self):
"""Different seeds must produce different signs."""
s1 = generate_wht_signs(128, seed=42)
s2 = generate_wht_signs(128, seed=99)
assert not torch.equal(s1, s2)
def test_wht_signs_are_pm1(self):
"""All sign values must be exactly +1 or -1."""
signs = generate_wht_signs(128, seed=42)
assert torch.all(signs.abs() == 1.0)
# ============================================================================
# Store → Decode round-trip test (GPU + Triton required)
......@@ -491,11 +464,10 @@ class TestStoreDecodeRoundTrip:
device = torch.device(DEVICE_TYPE)
# Generate rotation
signs = generate_wht_signs(D, seed=42, device=device)
# Pure Hadamard rotation (symmetric: H = H^T, so Pi = PiT = H)
H = _build_hadamard(D, DEVICE_TYPE)
PiT = (signs.unsqueeze(1) * H).contiguous().float()
Pi = PiT.T.contiguous()
PiT = H
Pi = H
# Generate centroids
centroids, _ = solve_lloyd_max(D, cfg.centroid_bits)
......
......@@ -406,33 +406,16 @@ class Attention(nn.Module, AttentionLayerBase):
def _init_turboquant_buffers(
self, cache_dtype: str, head_size: int, prefix: str
) -> None:
"""Initialize TurboQuant rotation/projection matrices and centroids."""
"""Initialize TurboQuant centroids for Lloyd-Max quantization."""
from vllm.model_executor.layers.quantization.turboquant.centroids import (
get_centroids,
)
from vllm.model_executor.layers.quantization.turboquant.config import (
TurboQuantConfig,
)
from vllm.model_executor.layers.quantization.turboquant.quantizer import (
generate_wht_signs,
)
tq_config = TurboQuantConfig.from_cache_dtype(cache_dtype, head_size)
# Each layer needs a unique rotation matrix so quantization errors
# don't correlate across layers. Stride must exceed max head_dim to
# ensure non-overlapping RNG streams between adjacent layers.
_TQ_LAYER_SEED_STRIDE = 1337
from vllm.model_executor.models.utils import extract_layer_index
layer_idx = extract_layer_index(prefix)
seed = tq_config.seed + layer_idx * _TQ_LAYER_SEED_STRIDE
self.register_buffer(
"_tq_signs",
generate_wht_signs(head_size, seed=seed),
)
self.register_buffer(
"_tq_centroids",
get_centroids(head_size, tq_config.centroid_bits),
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TurboQuant: Near-optimal KV-cache quantization for vLLM.
"""TurboQuant: KV-cache quantization for vLLM.
PolarQuant compression: random rotation + per-coordinate Lloyd-Max
scalar quantization for keys, uniform quantization for values.
Hadamard rotation + per-coordinate Lloyd-Max scalar quantization for
keys, uniform quantization for values.
Reference: "TurboQuant: Online Vector Quantization with Near-optimal
Distortion Rate" (ICLR 2026), Zandieh et al.
The technique implemented here consists of the scalar case of the HIGGS
quantization method (Malinovskii et al., "Pushing the Limits of Large
Language Model Quantization via the Linearity Theorem", NAACL 2025;
preprint arXiv:2411.17525): rotation + optimized grid + optional
re-normalization, applied to KV cache compression. A first application
of this approach to KV-cache compression is in "Cache Me If You Must:
Adaptive Key-Value Quantization for Large Language Models" (Shutova
et al., ICML 2025; preprint arXiv:2501.19392). Both these references
pre-date the TurboQuant paper (Zandieh et al., ICLR 2026).
"""
from vllm.model_executor.layers.quantization.turboquant.config import TurboQuantConfig
......
......@@ -36,10 +36,22 @@ TQ_PRESETS: dict[str, dict] = {
class TurboQuantConfig:
"""Configuration for TurboQuant KV-cache quantization.
Uses PolarQuant (WHT rotation + Lloyd-Max scalar quantization) for keys
and uniform quantization for values. QJL is intentionally omitted —
community consensus (5+ independent groups) found it hurts attention
quality by amplifying variance through softmax.
Applies Hadamard rotation followed by per-coordinate Lloyd-Max scalar
quantization for keys, and uniform quantization for values.
Historical note: this is the scalar case of the HIGGS quantization
method (Malinovskii et al., "Pushing the Limits of Large Language Model
Quantization via the Linearity Theorem", NAACL 2025; preprint
arXiv:2411.17525): rotation + optimized grid + optional re-normalization,
applied to KV cache compression. A first application of this approach to
KV-cache compression is in "Cache Me If You Must: Adaptive Key-Value
Quantization for Large Language Models" (Shutova et al., ICML 2025;
preprint arXiv:2501.19392). Both these references pre-date the
TurboQuant paper.
QJL is intentionally omitted — community consensus (5+ independent
groups) found it hurts attention quality by amplifying variance through
softmax.
Named presets (use via --kv-cache-dtype):
turboquant_k8v4: FP8 keys + 4-bit values, 2.6x, +1.17% PPL
......@@ -53,8 +65,6 @@ class TurboQuantConfig:
rotation/MSE). 3-4 = Lloyd-Max MSE quantized keys.
value_quant_bits: Bits per value dimension for uniform quantization.
3 = 8 levels, 4 = 16 levels (default).
seed: Base seed for deterministic random matrix generation.
Actual seed per layer = seed + layer_idx * 1337.
norm_correction: Re-normalize centroid vectors to unit norm before
inverse rotation during dequant. Fixes quantization-induced norm
distortion, improving PPL by ~0.8% at 4-bit.
......@@ -63,7 +73,7 @@ class TurboQuantConfig:
head_dim: int = 128
key_quant_bits: int = 3 # 3-4 = MSE keys, 8 = FP8 keys
value_quant_bits: int = 4 # 3-4 = uniform quantized values
seed: int = 42
seed: int = 42 # kept for backward compatibility; no longer used internally
norm_correction: bool = False
@property
......
......@@ -2,23 +2,5 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TurboQuant quantizer utilities.
Serving path uses generate_wht_signs() for WHT rotation sign buffers.
Triton kernels handle all quantization, packing, and dequantization on GPU.
"""
import torch
_CPU = torch.device("cpu")
def generate_wht_signs(d: int, seed: int, device: torch.device = _CPU) -> torch.Tensor:
"""Generate deterministic random ±1 signs for WHT rotation.
Used with Walsh-Hadamard Transform for per-layer rotation randomization.
Same seed derivation as QR (per-layer via seed + layer_idx * stride).
"""
gen = torch.Generator(device="cpu")
gen.manual_seed(seed)
bits = torch.randint(0, 2, (d,), generator=gen, device="cpu")
signs = bits.float() * 2 - 1
return signs.to(device)
......@@ -279,29 +279,25 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
)
def _ensure_on_device(self, layer, device):
"""One-time derivation of TQ buffers (rotation matrices, midpoints).
"""One-time derivation of TQ buffers (rotation matrix, midpoints).
Registered buffers (_tq_signs, _tq_centroids) are already on the
correct device via register_buffer + model.to(device).
The Hadamard rotation is shared across all layers: random sign
flips do not improve Lloyd-Max quantization quality because the
quantizer is symmetric around zero (sign-flipping a coordinate
maps it to the mirror centroid with identical distortion).
"""
if not hasattr(layer, "_tq_cached"):
D = layer._tq_signs.shape[0]
signs = layer._tq_signs.to(device=device, dtype=torch.float32)
D = self.head_size
# WHT rotation: orthonormal + self-inverse, enabling future
# Pure Hadamard: orthonormal + symmetric (H = H^T), enabling
# in-kernel butterfly fusion and trivial inverse for continuation.
H = _build_hadamard(D, str(device))
layer._tq_PiT = (signs.unsqueeze(1) * H).contiguous()
layer._tq_Pi = layer._tq_PiT.T.contiguous()
layer._tq_PiT = H
layer._tq_Pi = H
c = layer._tq_centroids.to(device=device, dtype=torch.float32)
# Precompute midpoints for threshold-based quantization
c_sorted, _ = c.sort()
layer._tq_midpoints = (c_sorted[:-1] + c_sorted[1:]) / 2
# Decode buffers (_tq_mid_o_buf, _tq_output_buf, _tq_lse_buf)
# are pre-allocated via register_buffer in Attention.__init__
# and moved to GPU by model.to(device) — no allocation needed
# here. The memory profiler sees them before KV cache sizing.
layer._tq_cached = True
def do_kv_cache_update(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment