Unverified Commit ed0622e3 authored by Dan Alistarh's avatar Dan Alistarh Committed by GitHub
Browse files

[Attention] TurboQuant: remove redundant random signs, add prior art attribution (#40194)


Signed-off-by: default avatarDan Alistarh <d.alistarh@gmail.com>
parent b5f6c5f8
...@@ -18,9 +18,6 @@ from vllm.model_executor.layers.quantization.turboquant.config import ( ...@@ -18,9 +18,6 @@ from vllm.model_executor.layers.quantization.turboquant.config import (
TQ_PRESETS, TQ_PRESETS,
TurboQuantConfig, TurboQuantConfig,
) )
from vllm.model_executor.layers.quantization.turboquant.quantizer import (
generate_wht_signs,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import next_power_of_2 from vllm.utils.math_utils import next_power_of_2
...@@ -393,7 +390,7 @@ class TestRotationMatrix: ...@@ -393,7 +390,7 @@ class TestRotationMatrix:
# ============================================================================ # ============================================================================
# WHT rotation tests (serving path: generate_wht_signs + _build_hadamard) # Hadamard rotation tests (serving path: _build_hadamard)
# ============================================================================ # ============================================================================
...@@ -406,50 +403,26 @@ def _build_hadamard(d: int, device: str = "cpu") -> torch.Tensor: ...@@ -406,50 +403,26 @@ def _build_hadamard(d: int, device: str = "cpu") -> torch.Tensor:
@pytest.mark.skipif(not GPGPU_AVAILABLE, reason="GPGPU not available") @pytest.mark.skipif(not GPGPU_AVAILABLE, reason="GPGPU not available")
class TestWHTRotation: class TestHadamardRotation:
"""Tests for the WHT rotation actually used in serving.""" """Tests for the Hadamard rotation used in serving."""
@pytest.mark.parametrize("dim", [64, 128, 256]) @pytest.mark.parametrize("dim", [64, 128, 256])
def test_wht_orthonormal(self, dim): def test_hadamard_orthonormal(self, dim):
"""signs * H must be orthonormal: (signs*H) @ (signs*H)^T = I.""" """H must be orthonormal: H @ H^T = I."""
signs = generate_wht_signs(dim, seed=42, device=DEVICE_TYPE)
H = _build_hadamard(dim, DEVICE_TYPE) H = _build_hadamard(dim, DEVICE_TYPE)
PiT = (signs.unsqueeze(1) * H).contiguous() eye = H @ H.T
eye = PiT @ PiT.T
assert torch.allclose(eye, torch.eye(dim, device=DEVICE_TYPE), atol=1e-5), ( assert torch.allclose(eye, torch.eye(dim, device=DEVICE_TYPE), atol=1e-5), (
f"WHT rotation not orthonormal for dim={dim}" f"Hadamard not orthonormal for dim={dim}"
) )
@pytest.mark.parametrize("dim", [64, 128, 256]) @pytest.mark.parametrize("dim", [64, 128, 256])
def test_wht_self_inverse(self, dim): def test_hadamard_symmetric(self, dim):
"""PiT should be self-inverse: PiT @ PiT = I (up to sign flip).""" """Sylvester Hadamard must be symmetric: H = H^T."""
signs = generate_wht_signs(dim, seed=42, device=DEVICE_TYPE)
H = _build_hadamard(dim, DEVICE_TYPE) H = _build_hadamard(dim, DEVICE_TYPE)
PiT = (signs.unsqueeze(1) * H).contiguous() assert torch.allclose(H, H.T, atol=1e-6), (
Pi = PiT.T.contiguous() f"Hadamard not symmetric for dim={dim}"
# Pi @ PiT should be identity (rotation then inverse)
result = Pi @ PiT
assert torch.allclose(result, torch.eye(dim, device=DEVICE_TYPE), atol=1e-5), (
f"WHT rotation not self-inverse for dim={dim}"
) )
def test_wht_signs_deterministic(self):
"""Same seed must produce identical signs."""
s1 = generate_wht_signs(128, seed=42)
s2 = generate_wht_signs(128, seed=42)
assert torch.equal(s1, s2)
def test_wht_signs_different_seeds(self):
"""Different seeds must produce different signs."""
s1 = generate_wht_signs(128, seed=42)
s2 = generate_wht_signs(128, seed=99)
assert not torch.equal(s1, s2)
def test_wht_signs_are_pm1(self):
"""All sign values must be exactly +1 or -1."""
signs = generate_wht_signs(128, seed=42)
assert torch.all(signs.abs() == 1.0)
# ============================================================================ # ============================================================================
# Store → Decode round-trip test (GPU + Triton required) # Store → Decode round-trip test (GPU + Triton required)
...@@ -491,11 +464,10 @@ class TestStoreDecodeRoundTrip: ...@@ -491,11 +464,10 @@ class TestStoreDecodeRoundTrip:
device = torch.device(DEVICE_TYPE) device = torch.device(DEVICE_TYPE)
# Generate rotation # Pure Hadamard rotation (symmetric: H = H^T, so Pi = PiT = H)
signs = generate_wht_signs(D, seed=42, device=device)
H = _build_hadamard(D, DEVICE_TYPE) H = _build_hadamard(D, DEVICE_TYPE)
PiT = (signs.unsqueeze(1) * H).contiguous().float() PiT = H
Pi = PiT.T.contiguous() Pi = H
# Generate centroids # Generate centroids
centroids, _ = solve_lloyd_max(D, cfg.centroid_bits) centroids, _ = solve_lloyd_max(D, cfg.centroid_bits)
......
...@@ -406,33 +406,16 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -406,33 +406,16 @@ class Attention(nn.Module, AttentionLayerBase):
def _init_turboquant_buffers( def _init_turboquant_buffers(
self, cache_dtype: str, head_size: int, prefix: str self, cache_dtype: str, head_size: int, prefix: str
) -> None: ) -> None:
"""Initialize TurboQuant rotation/projection matrices and centroids.""" """Initialize TurboQuant centroids for Lloyd-Max quantization."""
from vllm.model_executor.layers.quantization.turboquant.centroids import ( from vllm.model_executor.layers.quantization.turboquant.centroids import (
get_centroids, get_centroids,
) )
from vllm.model_executor.layers.quantization.turboquant.config import ( from vllm.model_executor.layers.quantization.turboquant.config import (
TurboQuantConfig, TurboQuantConfig,
) )
from vllm.model_executor.layers.quantization.turboquant.quantizer import (
generate_wht_signs,
)
tq_config = TurboQuantConfig.from_cache_dtype(cache_dtype, head_size) tq_config = TurboQuantConfig.from_cache_dtype(cache_dtype, head_size)
# Each layer needs a unique rotation matrix so quantization errors
# don't correlate across layers. Stride must exceed max head_dim to
# ensure non-overlapping RNG streams between adjacent layers.
_TQ_LAYER_SEED_STRIDE = 1337
from vllm.model_executor.models.utils import extract_layer_index
layer_idx = extract_layer_index(prefix)
seed = tq_config.seed + layer_idx * _TQ_LAYER_SEED_STRIDE
self.register_buffer(
"_tq_signs",
generate_wht_signs(head_size, seed=seed),
)
self.register_buffer( self.register_buffer(
"_tq_centroids", "_tq_centroids",
get_centroids(head_size, tq_config.centroid_bits), get_centroids(head_size, tq_config.centroid_bits),
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TurboQuant: Near-optimal KV-cache quantization for vLLM. """TurboQuant: KV-cache quantization for vLLM.
PolarQuant compression: random rotation + per-coordinate Lloyd-Max Hadamard rotation + per-coordinate Lloyd-Max scalar quantization for
scalar quantization for keys, uniform quantization for values. keys, uniform quantization for values.
Reference: "TurboQuant: Online Vector Quantization with Near-optimal The technique implemented here consists of the scalar case of the HIGGS
Distortion Rate" (ICLR 2026), Zandieh et al. quantization method (Malinovskii et al., "Pushing the Limits of Large
Language Model Quantization via the Linearity Theorem", NAACL 2025;
preprint arXiv:2411.17525): rotation + optimized grid + optional
re-normalization, applied to KV cache compression. A first application
of this approach to KV-cache compression is in "Cache Me If You Must:
Adaptive Key-Value Quantization for Large Language Models" (Shutova
et al., ICML 2025; preprint arXiv:2501.19392). Both these references
pre-date the TurboQuant paper (Zandieh et al., ICLR 2026).
""" """
from vllm.model_executor.layers.quantization.turboquant.config import TurboQuantConfig from vllm.model_executor.layers.quantization.turboquant.config import TurboQuantConfig
......
...@@ -36,10 +36,22 @@ TQ_PRESETS: dict[str, dict] = { ...@@ -36,10 +36,22 @@ TQ_PRESETS: dict[str, dict] = {
class TurboQuantConfig: class TurboQuantConfig:
"""Configuration for TurboQuant KV-cache quantization. """Configuration for TurboQuant KV-cache quantization.
Uses PolarQuant (WHT rotation + Lloyd-Max scalar quantization) for keys Applies Hadamard rotation followed by per-coordinate Lloyd-Max scalar
and uniform quantization for values. QJL is intentionally omitted — quantization for keys, and uniform quantization for values.
community consensus (5+ independent groups) found it hurts attention
quality by amplifying variance through softmax. Historical note: this is the scalar case of the HIGGS quantization
method (Malinovskii et al., "Pushing the Limits of Large Language Model
Quantization via the Linearity Theorem", NAACL 2025; preprint
arXiv:2411.17525): rotation + optimized grid + optional re-normalization,
applied to KV cache compression. A first application of this approach to
KV-cache compression is in "Cache Me If You Must: Adaptive Key-Value
Quantization for Large Language Models" (Shutova et al., ICML 2025;
preprint arXiv:2501.19392). Both these references pre-date the
TurboQuant paper.
QJL is intentionally omitted — community consensus (5+ independent
groups) found it hurts attention quality by amplifying variance through
softmax.
Named presets (use via --kv-cache-dtype): Named presets (use via --kv-cache-dtype):
turboquant_k8v4: FP8 keys + 4-bit values, 2.6x, +1.17% PPL turboquant_k8v4: FP8 keys + 4-bit values, 2.6x, +1.17% PPL
...@@ -53,8 +65,6 @@ class TurboQuantConfig: ...@@ -53,8 +65,6 @@ class TurboQuantConfig:
rotation/MSE). 3-4 = Lloyd-Max MSE quantized keys. rotation/MSE). 3-4 = Lloyd-Max MSE quantized keys.
value_quant_bits: Bits per value dimension for uniform quantization. value_quant_bits: Bits per value dimension for uniform quantization.
3 = 8 levels, 4 = 16 levels (default). 3 = 8 levels, 4 = 16 levels (default).
seed: Base seed for deterministic random matrix generation.
Actual seed per layer = seed + layer_idx * 1337.
norm_correction: Re-normalize centroid vectors to unit norm before norm_correction: Re-normalize centroid vectors to unit norm before
inverse rotation during dequant. Fixes quantization-induced norm inverse rotation during dequant. Fixes quantization-induced norm
distortion, improving PPL by ~0.8% at 4-bit. distortion, improving PPL by ~0.8% at 4-bit.
...@@ -63,7 +73,7 @@ class TurboQuantConfig: ...@@ -63,7 +73,7 @@ class TurboQuantConfig:
head_dim: int = 128 head_dim: int = 128
key_quant_bits: int = 3 # 3-4 = MSE keys, 8 = FP8 keys key_quant_bits: int = 3 # 3-4 = MSE keys, 8 = FP8 keys
value_quant_bits: int = 4 # 3-4 = uniform quantized values value_quant_bits: int = 4 # 3-4 = uniform quantized values
seed: int = 42 seed: int = 42 # kept for backward compatibility; no longer used internally
norm_correction: bool = False norm_correction: bool = False
@property @property
......
...@@ -2,23 +2,5 @@ ...@@ -2,23 +2,5 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TurboQuant quantizer utilities. """TurboQuant quantizer utilities.
Serving path uses generate_wht_signs() for WHT rotation sign buffers.
Triton kernels handle all quantization, packing, and dequantization on GPU. Triton kernels handle all quantization, packing, and dequantization on GPU.
""" """
import torch
_CPU = torch.device("cpu")
def generate_wht_signs(d: int, seed: int, device: torch.device = _CPU) -> torch.Tensor:
"""Generate deterministic random ±1 signs for WHT rotation.
Used with Walsh-Hadamard Transform for per-layer rotation randomization.
Same seed derivation as QR (per-layer via seed + layer_idx * stride).
"""
gen = torch.Generator(device="cpu")
gen.manual_seed(seed)
bits = torch.randint(0, 2, (d,), generator=gen, device="cpu")
signs = bits.float() * 2 - 1
return signs.to(device)
...@@ -279,29 +279,25 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]): ...@@ -279,29 +279,25 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
) )
def _ensure_on_device(self, layer, device): def _ensure_on_device(self, layer, device):
"""One-time derivation of TQ buffers (rotation matrices, midpoints). """One-time derivation of TQ buffers (rotation matrix, midpoints).
Registered buffers (_tq_signs, _tq_centroids) are already on the The Hadamard rotation is shared across all layers: random sign
correct device via register_buffer + model.to(device). flips do not improve Lloyd-Max quantization quality because the
quantizer is symmetric around zero (sign-flipping a coordinate
maps it to the mirror centroid with identical distortion).
""" """
if not hasattr(layer, "_tq_cached"): if not hasattr(layer, "_tq_cached"):
D = layer._tq_signs.shape[0] D = self.head_size
signs = layer._tq_signs.to(device=device, dtype=torch.float32)
# WHT rotation: orthonormal + self-inverse, enabling future # Pure Hadamard: orthonormal + symmetric (H = H^T), enabling
# in-kernel butterfly fusion and trivial inverse for continuation. # in-kernel butterfly fusion and trivial inverse for continuation.
H = _build_hadamard(D, str(device)) H = _build_hadamard(D, str(device))
layer._tq_PiT = (signs.unsqueeze(1) * H).contiguous() layer._tq_PiT = H
layer._tq_Pi = layer._tq_PiT.T.contiguous() layer._tq_Pi = H
c = layer._tq_centroids.to(device=device, dtype=torch.float32) c = layer._tq_centroids.to(device=device, dtype=torch.float32)
# Precompute midpoints for threshold-based quantization
c_sorted, _ = c.sort() c_sorted, _ = c.sort()
layer._tq_midpoints = (c_sorted[:-1] + c_sorted[1:]) / 2 layer._tq_midpoints = (c_sorted[:-1] + c_sorted[1:]) / 2
# Decode buffers (_tq_mid_o_buf, _tq_output_buf, _tq_lse_buf)
# are pre-allocated via register_buffer in Attention.__init__
# and moved to GPU by model.to(device) — no allocation needed
# here. The memory profiler sees them before KV cache sizing.
layer._tq_cached = True layer._tq_cached = True
def do_kv_cache_update( def do_kv_cache_update(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment