Commit 0885aa25 authored by wanglong3's avatar wanglong3 Committed by zhangzbb
Browse files

[feature][Attention Backend] TurboQuant: 2-bit KV cache compression with 4x capacity #38479

parent 4fca01b8
model_name: "Qwen/Qwen3-4B"
accuracy_threshold: 0.78
num_questions: 1319
num_fewshot: 5
server_args: "--kv-cache-dtype turboquant_k3v4_nc --enforce-eager --max-model-len 4096"
model_name: "Qwen/Qwen3-4B"
accuracy_threshold: 0.80
num_questions: 1319
num_fewshot: 5
server_args: "--kv-cache-dtype turboquant_k8v4 --enforce-eager --max-model-len 4096"
model_name: "Qwen/Qwen3-4B"
accuracy_threshold: 0.75
num_questions: 1319
num_fewshot: 5
server_args: "--kv-cache-dtype turboquant_3bit_nc --enforce-eager --max-model-len 4096"
model_name: "Qwen/Qwen3-4B"
accuracy_threshold: 0.80
num_questions: 1319
num_fewshot: 5
server_args: "--kv-cache-dtype turboquant_4bit_nc --enforce-eager --max-model-len 4096"
Qwen3-4B-TQ-k8v4.yaml
Qwen3-4B-TQ-t4nc.yaml
Qwen3-4B-TQ-k3v4nc.yaml
Qwen3-4B-TQ-t3nc.yaml
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for TurboQuant KV-cache quantization.
Run: .venv/bin/python -m pytest tests/quantization/test_turboquant.py -v
"""
import math
import pytest
import torch
from vllm.model_executor.layers.quantization.turboquant.config import (
TQ_PRESETS,
TurboQuantConfig,
)
from vllm.utils.math_utils import next_power_of_2
# ============================================================================
# Helpers
# ============================================================================
ALL_PRESETS = list(TQ_PRESETS.keys())
def _assert_strictly_sorted(seq, name="sequence"):
for i in range(len(seq) - 1):
assert seq[i] < seq[i + 1], f"{name} not sorted at index {i}"
def _is_power_of_2(n: int) -> bool:
return n > 0 and next_power_of_2(n) == n
# Expected concrete values for each preset at head_dim=128.
# fmt: off
PRESET_EXPECTED = {
"turboquant_k8v4": dict(
key_fp8=True, key_quant_bits=8,
key_mse_bits=0, value_quant_bits=4,
mse_bits=4, n_centroids=16, centroid_bits=4,
norm_correction=False,
key_packed_size=128, value_packed_size=68,
slot_size=196, slot_size_aligned=196,
),
"turboquant_4bit_nc": dict(
key_fp8=False, key_quant_bits=4,
key_mse_bits=4, value_quant_bits=4,
mse_bits=4, n_centroids=16, centroid_bits=4,
norm_correction=True,
key_packed_size=68, value_packed_size=68,
slot_size=136, slot_size_aligned=136,
),
"turboquant_k3v4_nc": dict(
key_fp8=False, key_quant_bits=3,
key_mse_bits=3, value_quant_bits=4,
mse_bits=3, n_centroids=8, centroid_bits=3,
norm_correction=True,
key_packed_size=52, value_packed_size=68,
slot_size=120, slot_size_aligned=120,
),
"turboquant_3bit_nc": dict(
key_fp8=False, key_quant_bits=3,
key_mse_bits=3, value_quant_bits=3,
mse_bits=3, n_centroids=8, centroid_bits=3,
norm_correction=True,
key_packed_size=52, value_packed_size=52,
slot_size=104, slot_size_aligned=104,
),
}
# fmt: on
# ============================================================================
# Config tests (CPU-only, no dependencies beyond config.py)
# ============================================================================
class TestTurboQuantConfig:
@pytest.mark.parametrize("preset", ALL_PRESETS)
def test_preset_parses(self, preset):
cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128)
assert isinstance(cfg, TurboQuantConfig)
def test_invalid_preset_raises(self):
with pytest.raises(ValueError, match="Unknown TurboQuant"):
TurboQuantConfig.from_cache_dtype("turboquant_invalid", head_dim=128)
# ---- Per-preset concrete value checks (table-driven) ----
@pytest.mark.parametrize("preset", ALL_PRESETS)
def test_key_mode(self, preset):
cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128)
exp = PRESET_EXPECTED[preset]
assert cfg.key_fp8 is exp["key_fp8"]
assert cfg.key_quant_bits == exp["key_quant_bits"]
assert cfg.key_mse_bits == exp["key_mse_bits"]
@pytest.mark.parametrize("preset", ALL_PRESETS)
def test_value_mode(self, preset):
cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128)
exp = PRESET_EXPECTED[preset]
assert cfg.value_quant_bits == exp["value_quant_bits"]
@pytest.mark.parametrize("preset", ALL_PRESETS)
def test_bits_and_centroids(self, preset):
cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128)
exp = PRESET_EXPECTED[preset]
assert cfg.mse_bits == exp["mse_bits"]
assert cfg.n_centroids == exp["n_centroids"]
assert cfg.centroid_bits == exp["centroid_bits"]
@pytest.mark.parametrize("preset", ALL_PRESETS)
def test_norm_correction(self, preset):
cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128)
assert cfg.norm_correction is PRESET_EXPECTED[preset]["norm_correction"]
@pytest.mark.parametrize("preset", ALL_PRESETS)
def test_packed_sizes(self, preset):
cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128)
exp = PRESET_EXPECTED[preset]
assert cfg.key_packed_size == exp["key_packed_size"]
assert cfg.value_packed_size == exp["value_packed_size"]
assert cfg.slot_size == exp["slot_size"]
assert cfg.slot_size_aligned == exp["slot_size_aligned"]
# ---- Cross-preset structural invariants ----
@pytest.mark.parametrize("preset", ALL_PRESETS)
def test_slot_equals_key_plus_value(self, preset):
cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128)
assert cfg.slot_size == cfg.key_packed_size + cfg.value_packed_size
@pytest.mark.parametrize("preset", ALL_PRESETS)
def test_padded_slot_is_even(self, preset):
cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128)
assert cfg.slot_size_aligned >= cfg.slot_size
assert cfg.slot_size_aligned % 2 == 0, (
f"slot_size_aligned={cfg.slot_size_aligned} is not even"
)
@pytest.mark.parametrize("preset", ALL_PRESETS)
def test_key_value_packed_sizes_positive(self, preset):
cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128)
assert cfg.key_packed_size > 0
assert cfg.value_packed_size > 0
@pytest.mark.parametrize("preset", ALL_PRESETS)
def test_n_centroids_is_2_to_mse_bits(self, preset):
cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128)
assert cfg.n_centroids == 2**cfg.mse_bits
@pytest.mark.parametrize("preset", ALL_PRESETS)
def test_centroid_bits_always_positive(self, preset):
cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128)
assert cfg.centroid_bits > 0
@pytest.mark.parametrize("preset", ALL_PRESETS)
def test_mse_key_or_fp8_exclusive(self, preset):
"""Each preset is either FP8 keys or MSE keys, never both."""
cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128)
if cfg.key_fp8:
assert cfg.key_mse_bits == 0
assert cfg.key_quant_bits == 8
else:
assert cfg.key_mse_bits > 0
assert cfg.key_quant_bits in (3, 4)
@pytest.mark.parametrize("preset", ALL_PRESETS)
@pytest.mark.parametrize("head_dim", [64, 96, 128, 256])
def test_all_presets_all_head_dims(self, preset, head_dim):
cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=head_dim)
assert cfg.head_dim == head_dim
assert cfg.slot_size == cfg.key_packed_size + cfg.value_packed_size
assert cfg.slot_size_aligned >= cfg.slot_size
assert cfg.slot_size_aligned % 2 == 0
# ---- Boundary skip layers ----
def test_boundary_skip_layers_basic(self):
layers = TurboQuantConfig.get_boundary_skip_layers(32)
assert layers == ["0", "1", "30", "31"]
def test_boundary_skip_layers_zero(self):
assert TurboQuantConfig.get_boundary_skip_layers(32, 0) == []
def test_boundary_skip_layers_small_model(self):
layers = TurboQuantConfig.get_boundary_skip_layers(4)
assert layers == ["0", "1", "2", "3"]
def test_boundary_skip_layers_cap_at_half(self):
layers = TurboQuantConfig.get_boundary_skip_layers(8, 10)
assert len(layers) == 8
# ============================================================================
# Centroids tests (CPU-only)
# ============================================================================
from vllm.model_executor.layers.quantization.turboquant.centroids import (
get_centroids,
solve_lloyd_max,
)
class TestCentroids:
@pytest.mark.parametrize("bits,expected_n", [(2, 4), (3, 8), (4, 16)])
def test_centroids_shape(self, bits, expected_n):
c = get_centroids(128, bits)
assert c.shape == (expected_n,)
@pytest.mark.parametrize("bits", [2, 3, 4])
def test_centroids_sorted(self, bits):
_assert_strictly_sorted(get_centroids(128, bits), "centroids")
def test_centroids_cached(self):
c1 = get_centroids(128, 3)
c2 = get_centroids(128, 3)
assert c1 is c2, "get_centroids should return cached object"
def test_centroids_different_dims_not_identical(self):
c64 = get_centroids(64, 3)
c128 = get_centroids(128, 3)
assert not torch.equal(c64, c128)
@pytest.mark.parametrize("bits", [2, 3, 4])
def test_centroids_symmetric_around_zero(self, bits):
"""N(0, 1/d) is symmetric, so centroids should be ~symmetric."""
c = get_centroids(128, bits)
assert abs(c.mean().item()) < 0.01, "Centroids not centered near 0"
assert abs(c[0].item() + c[-1].item()) < 0.01
@pytest.mark.parametrize("bits", [2, 3, 4])
def test_centroids_within_4sigma(self, bits):
"""All centroids should be within ~4 sigma of N(0, 1/d)."""
sigma = math.sqrt(1.0 / 128)
c = get_centroids(128, bits)
for i, val in enumerate(c):
assert abs(val.item()) < 4 * sigma, (
f"Centroid {i}={val:.6f} outside 4*sigma={4 * sigma:.6f}"
)
class TestLloydMax:
@pytest.mark.parametrize("bits,expected_n", [(2, 4), (3, 8), (4, 16)])
def test_solve_shapes(self, bits, expected_n):
centroids, boundaries = solve_lloyd_max(128, bits)
assert centroids.shape == (expected_n,)
assert boundaries.shape == (expected_n - 1,)
@pytest.mark.parametrize("bits", [2, 3, 4])
def test_centroids_sorted(self, bits):
centroids, _ = solve_lloyd_max(128, bits)
_assert_strictly_sorted(centroids, "centroids")
@pytest.mark.parametrize("bits", [2, 3, 4])
def test_boundaries_sorted(self, bits):
_, boundaries = solve_lloyd_max(128, bits)
_assert_strictly_sorted(boundaries, "boundaries")
@pytest.mark.parametrize("bits", [2, 3, 4])
def test_boundaries_between_centroids(self, bits):
"""Each boundary must lie between its adjacent centroids."""
centroids, boundaries = solve_lloyd_max(128, bits)
for i in range(len(boundaries)):
assert centroids[i] < boundaries[i] < centroids[i + 1], (
f"Boundary {i}={boundaries[i]:.6f} not between "
f"c[{i}]={centroids[i]:.6f} and c[{i + 1}]={centroids[i + 1]:.6f}"
)
@pytest.mark.parametrize("bits", [2, 3, 4])
def test_boundaries_are_midpoints(self, bits):
"""Lloyd-Max boundaries are midpoints of adjacent centroids."""
centroids, boundaries = solve_lloyd_max(128, bits)
for i in range(len(boundaries)):
expected = (centroids[i] + centroids[i + 1]) / 2.0
assert abs(boundaries[i].item() - expected.item()) < 1e-6
def test_solve_deterministic(self):
c1, b1 = solve_lloyd_max(128, 3)
c2, b2 = solve_lloyd_max(128, 3)
assert torch.equal(c1, c2)
assert torch.equal(b1, b2)
def test_solve_dtype_float32(self):
centroids, boundaries = solve_lloyd_max(128, 3)
assert centroids.dtype == torch.float32
assert boundaries.dtype == torch.float32
# ============================================================================
# Rotation matrix tests (GPU required)
# ============================================================================
CUDA_AVAILABLE = torch.cuda.is_available()
from vllm.model_executor.layers.quantization.turboquant.quantizer import (
generate_rotation_matrix,
)
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA not available")
class TestRotationMatrix:
@pytest.mark.parametrize("dim", [64, 96, 128, 256])
def test_rotation_matrix_shape_and_orthogonal(self, dim):
Pi = generate_rotation_matrix(dim, seed=42, device="cuda")
assert Pi.shape == (dim, dim)
eye = Pi @ Pi.T
assert torch.allclose(eye, torch.eye(dim, device="cuda"), atol=1e-5), (
f"Pi not orthogonal for dim={dim}"
)
def test_rotation_matrix_deterministic(self):
Pi1 = generate_rotation_matrix(128, seed=42)
Pi2 = generate_rotation_matrix(128, seed=42)
assert torch.equal(Pi1, Pi2)
def test_rotation_matrix_different_seeds(self):
Pi1 = generate_rotation_matrix(128, seed=42)
Pi2 = generate_rotation_matrix(128, seed=99)
assert not torch.equal(Pi1, Pi2)
def test_rotation_matrix_det_is_pm1(self):
"""Orthogonal matrix determinant must be +1 or -1."""
Pi = generate_rotation_matrix(128, seed=42, device="cuda")
det = torch.linalg.det(Pi)
assert abs(abs(det.item()) - 1.0) < 1e-4
......@@ -372,3 +372,26 @@ def test_fp8_reloading(
weight_loader(param, torch.zeros(shape)) # cannot use empty
method.process_weights_after_loading(layer)
@pytest.mark.skipif(
not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.",
)
def test_kv_cache_dtype_skip_layers(vllm_runner, monkeypatch):
"""Test that kv_cache_dtype_skip_layers skips quantization for specified layers."""
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
with vllm_runner(
"facebook/opt-125m",
kv_cache_dtype="fp8",
kv_cache_dtype_skip_layers=["0", "2"],
enforce_eager=True,
) as llm:
def check_layers(model):
for i, layer in enumerate(model.model.decoder.layers):
expected = "auto" if str(i) in ["0", "2"] else "fp8"
assert layer.self_attn.attn.kv_cache_dtype == expected
llm.apply_model(check_layers)
#!/usr/bin/env bash
# num_stages sweep for TQ Triton decode kernel _tq_decode_stage1
# Tests num_stages=1,2,3 for k8v4 (GPU2) and turboquant_4bit_nc (GPU3)
# Usage: bash tools/num_stages_sweep.sh
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "$0")/.." && pwd)"
KERNEL_FILE="$SCRIPT_DIR/vllm/v1/attention/ops/triton_turboquant_decode.py"
PYTHON="$SCRIPT_DIR/.venv/bin/python"
MODEL="Qwen/Qwen3-4B"
RESULTS_FILE="$SCRIPT_DIR/tools/num_stages_results.txt"
echo "=== TQ Triton num_stages sweep ===" | tee "$RESULTS_FILE"
echo "Date: $(date)" | tee -a "$RESULTS_FILE"
echo "Kernel: $KERNEL_FILE" | tee -a "$RESULTS_FILE"
echo "" | tee -a "$RESULTS_FILE"
patch_num_stages() {
local ns=$1
# Replace num_stages=N in the _tq_decode_stage1 launch (line ~548)
sed -i "s/^\( num_stages=\)[0-9]\+,$/\1${ns},/" "$KERNEL_FILE"
echo " patched num_stages=$ns in $KERNEL_FILE"
}
run_bench() {
local gpu=$1
local port=$2
local preset=$3
local ns=$4
echo ""
echo "--- preset=$preset num_stages=$ns GPU=$gpu ---" | tee -a "$RESULTS_FILE"
# Start server
echo " Starting server on GPU $gpu port $port..."
CUDA_VISIBLE_DEVICES=$gpu $PYTHON -m vllm.entrypoints.openai.api_server \
--model "$MODEL" \
--kv-cache-dtype "$preset" \
--port "$port" \
--max-model-len 32768 \
--disable-log-requests \
> /tmp/vllm_gpu${gpu}_ns${ns}.log 2>&1 &
SERVER_PID=$!
# Wait for server to be ready
echo " Waiting for server (pid=$SERVER_PID)..."
local max_wait=120
local elapsed=0
while ! curl -sf "http://localhost:${port}/health" > /dev/null 2>&1; do
sleep 3
elapsed=$((elapsed + 3))
if [ $elapsed -ge $max_wait ]; then
echo " ERROR: server did not start after ${max_wait}s" | tee -a "$RESULTS_FILE"
kill $SERVER_PID 2>/dev/null || true
return 1
fi
done
echo " Server ready after ${elapsed}s"
# Run benchmark
echo " Running bench..."
BENCH_OUT=$(
$PYTHON -m sglang.bench_serving \
--backend vllm \
--port "$port" \
--model "$MODEL" \
--dataset-name random \
--random-input-len 64 \
--random-output-len 1024 \
--num-prompts 200 \
--request-rate inf 2>&1
)
echo "$BENCH_OUT" >> "$RESULTS_FILE"
# Extract key metrics
THROUGHPUT=$(echo "$BENCH_OUT" | grep -oP 'Output token throughput.*?:\s*\K[\d.]+' | head -1 || echo "N/A")
MEDIAN_TTFT=$(echo "$BENCH_OUT" | grep -oP 'Median TTFT.*?:\s*\K[\d.]+' | head -1 || echo "N/A")
echo " output_tok/s=$THROUGHPUT median_ttft_ms=$MEDIAN_TTFT" | tee -a "$RESULTS_FILE"
# Kill server
kill $SERVER_PID 2>/dev/null || true
wait $SERVER_PID 2>/dev/null || true
sleep 2
}
# ===== Sweep k8v4 on GPU 2 =====
PRESET="turboquant_k8v4"
GPU=2
PORT=8502
echo "### PRESET: $PRESET GPU: $GPU ###" | tee -a "$RESULTS_FILE"
for NS in 1 2 3; do
patch_num_stages $NS
run_bench $GPU $PORT "$PRESET" $NS
done
# Restore to 2 (default)
patch_num_stages 2
# ===== Sweep turboquant_4bit_nc on GPU 3 =====
PRESET="turboquant_4bit_nc"
GPU=3
PORT=8503
echo "" | tee -a "$RESULTS_FILE"
echo "### PRESET: $PRESET GPU: $GPU ###" | tee -a "$RESULTS_FILE"
for NS in 1 2 3; do
patch_num_stages $NS
run_bench $GPU $PORT "$PRESET" $NS
done
# Restore to 2
patch_num_stages 2
echo "" | tee -a "$RESULTS_FILE"
echo "=== Sweep complete. Results in $RESULTS_FILE ===" | tee -a "$RESULTS_FILE"
#!/usr/bin/env bash
# Run a single benchmark: start server, bench, kill server, print results
# Usage: bash tools/run_single_bench.sh <GPU> <PORT> <PRESET> <NS_LABEL>
# NS_LABEL is just for logging (the kernel file must already be patched)
set -euo pipefail
GPU=$1
PORT=$2
PRESET=$3
NS_LABEL=$4
PYTHON=/home/vibhav.agarwal/vllm-tq/.venv/bin/python
MODEL=Qwen/Qwen3-4B
LOG_DIR=/home/vibhav.agarwal/vllm-tq/tools/sweep_logs
echo ">>> START preset=$PRESET num_stages=$NS_LABEL gpu=$GPU port=$PORT"
# Start server
CUDA_VISIBLE_DEVICES=$GPU $PYTHON -m vllm.entrypoints.openai.api_server \
--model "$MODEL" \
--kv-cache-dtype "$PRESET" \
--port "$PORT" \
--max-model-len 32768 \
--disable-log-stats \
> "$LOG_DIR/gpu${GPU}_${PRESET}_ns${NS_LABEL}.log" 2>&1 &
SERVER_PID=$!
# Wait for server ready (max 150s)
for i in $(seq 1 50); do
if curl -sf "http://localhost:${PORT}/health" > /dev/null 2>&1; then
echo " server ready at t=${i}*3s"
break
fi
sleep 3
if [ $i -eq 50 ]; then
echo "ERROR: server did not start"
kill -9 $SERVER_PID 2>/dev/null || true
exit 1
fi
done
# Run benchmark
BENCH_LOG="$LOG_DIR/bench_${PRESET}_ns${NS_LABEL}.log"
$PYTHON -m sglang.bench_serving \
--backend vllm \
--port "$PORT" \
--model "$MODEL" \
--dataset-name random \
--random-input-len 64 \
--random-output-len 1024 \
--num-prompts 200 \
--request-rate inf \
> "$BENCH_LOG" 2>&1
# Extract metrics
OUT_TPS=$(grep -oP 'Output token throughput.*?:\s*\K[\d.]+' "$BENCH_LOG" | head -1 || echo "N/A")
MEDIAN_ITL=$(grep -oP 'Median ITL.*?:\s*\K[\d.]+' "$BENCH_LOG" | head -1 || echo "N/A")
MEDIAN_TPOT=$(grep -oP 'Median TPOT.*?:\s*\K[\d.]+' "$BENCH_LOG" | head -1 || echo "N/A")
echo " RESULT: output_tok/s=$OUT_TPS median_itl_ms=$MEDIAN_ITL median_tpot_ms=$MEDIAN_TPOT"
# Kill server and release GPU memory
kill -9 $SERVER_PID 2>/dev/null || true
# Also kill EngineCore child processes
pkill -9 -f "VLLM::EngineCore" 2>/dev/null || true
sleep 5
......@@ -204,6 +204,31 @@ class Attention(nn.Module, AttentionLayerBase):
if cache_config is not None:
cache_config.cache_dtype = "fp8"
cache_config.calculate_kv_scales = False
# Skip quantization for specified layers
if cache_config is not None and cache_config.kv_cache_dtype_skip_layers:
from vllm.model_executor.models.utils import extract_layer_index
skip = False
# Check attention type
if (
sliding_window is not None
and "sliding_window" in cache_config.kv_cache_dtype_skip_layers
):
skip = True
# Check layer index
layer_idx = extract_layer_index(prefix)
if str(layer_idx) in cache_config.kv_cache_dtype_skip_layers:
skip = True
if skip:
kv_cache_dtype = "auto"
calculate_kv_scales = False
logger.info(
"Layer %s: kv_cache_dtype=%s, sliding_window=%s",
prefix,
kv_cache_dtype,
sliding_window,
)
self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
kv_cache_dtype, vllm_config.model_config
......@@ -326,6 +351,10 @@ class Attention(nn.Module, AttentionLayerBase):
# Initialize KV cache quantization attributes
_init_kv_cache_quant(self, quant_config, prefix)
# Initialize TurboQuant buffers (Pi, S, centroids) if tq cache dtype
if kv_cache_dtype.startswith("turboquant_"):
self._init_turboquant_buffers(kv_cache_dtype, head_size, prefix)
# for attn backends supporting query quantization
self.query_quant = None
# @TODO
......@@ -344,6 +373,42 @@ class Attention(nn.Module, AttentionLayerBase):
else GroupShape.PER_TENSOR,
)
def _init_turboquant_buffers(
self, cache_dtype: str, head_size: int, prefix: str
) -> None:
"""Initialize TurboQuant rotation/projection matrices and centroids."""
from vllm.model_executor.layers.quantization.turboquant.centroids import (
get_centroids,
)
from vllm.model_executor.layers.quantization.turboquant.config import (
TurboQuantConfig,
)
from vllm.model_executor.layers.quantization.turboquant.quantizer import (
generate_wht_signs,
)
tq_config = TurboQuantConfig.from_cache_dtype(cache_dtype, head_size)
# Each layer needs a unique rotation matrix so quantization errors
# don't correlate across layers. Stride must exceed max head_dim to
# ensure non-overlapping RNG streams between adjacent layers.
_TQ_LAYER_SEED_STRIDE = 1337
from vllm.model_executor.models.utils import extract_layer_index
layer_idx = extract_layer_index(prefix)
seed = tq_config.seed + layer_idx * _TQ_LAYER_SEED_STRIDE
self.register_buffer(
"_tq_signs",
generate_wht_signs(head_size, seed=seed),
)
self.register_buffer(
"_tq_centroids",
get_centroids(head_size, tq_config.centroid_bits),
)
self._tq_config = tq_config
def forward(
self,
query: torch.Tensor,
......@@ -499,6 +564,23 @@ class Attention(nn.Module, AttentionLayerBase):
dtype=self.kv_cache_torch_dtype,
sliding_window=self.sliding_window,
)
elif self.kv_cache_dtype.startswith("turboquant_"):
from vllm.model_executor.layers.quantization.turboquant.config import (
TurboQuantConfig,
)
from vllm.v1.kv_cache_interface import TQFullAttentionSpec
tq_config = TurboQuantConfig.from_cache_dtype(
self.kv_cache_dtype, self.head_size
)
return TQFullAttentionSpec(
block_size=block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
head_size_v=self.head_size,
dtype=self.kv_cache_torch_dtype,
tq_slot_size=tq_config.slot_size_aligned,
)
else:
return FullAttentionSpec(
block_size=block_size,
......
......@@ -29,6 +29,11 @@ class AttentionConfig:
flash_attn_max_num_splits_for_cuda_graph: int = 32
"""Flash Attention max number splits for cuda graph decode."""
tq_max_kv_splits_for_cuda_graph: int = 32
"""TurboQuant max NUM_KV_SPLITS for cuda graph decode.
Fixes the split count so grid dimensions are constant across captures,
and buffers can be pre-allocated to avoid inflating the memory estimate."""
use_cudnn_prefill: bool = False
"""Whether to use cudnn prefill."""
......
......@@ -31,6 +31,10 @@ CacheDType = Literal[
"fp8_e5m2",
"fp8_inc",
"fp8_ds_mla",
"turboquant_k8v4",
"turboquant_4bit_nc",
"turboquant_k3v4_nc",
"turboquant_3bit_nc",
"int8",
]
MambaDType = Literal["auto", "float32", "float16"]
......@@ -111,6 +115,9 @@ class CacheConfig:
"""This enables dynamic calculation of `k_scale` and `v_scale` when
kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model
checkpoint if available. Otherwise, the scales will default to 1.0."""
kv_cache_dtype_skip_layers: list[str] = field(default_factory=list)
"""Layer patterns to skip KV cache quantization. Accepts layer indices
(e.g., '0', '2', '4') or attention type names (e.g., 'sliding_window')."""
cpu_kvcache_space_bytes: int | None = None
"""(CPU backend only) CPU key-value cache space."""
mamba_page_size_padded: int | None = None
......
......@@ -556,6 +556,9 @@ class EngineArgs:
attention_backend: AttentionBackendEnum | None = AttentionConfig.backend
calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
kv_cache_dtype_skip_layers: list[str] = get_field(
CacheConfig, "kv_cache_dtype_skip_layers"
)
mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size")
......@@ -931,6 +934,9 @@ class EngineArgs:
cache_group.add_argument(
"--calculate-kv-scales", **cache_kwargs["calculate_kv_scales"]
)
cache_group.add_argument(
"--kv-cache-dtype-skip-layers", **cache_kwargs["kv_cache_dtype_skip_layers"]
)
cache_group.add_argument(
"--kv-sharing-fast-prefill", **cache_kwargs["kv_sharing_fast_prefill"]
)
......@@ -1420,6 +1426,7 @@ class EngineArgs:
prefix_caching_hash_algo=self.prefix_caching_hash_algo,
cpu_offload_gb=self.cpu_offload_gb,
calculate_kv_scales=self.calculate_kv_scales,
kv_cache_dtype_skip_layers=self.kv_cache_dtype_skip_layers,
kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
mamba_cache_dtype=self.mamba_cache_dtype,
mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
......@@ -1429,6 +1436,30 @@ class EngineArgs:
kv_offloading_backend=self.kv_offloading_backend,
)
# TurboQuant: auto-skip first/last 2 layers (boundary protection).
# These layers are most sensitive to quantization error.
# Users can add extra layers via --kv-cache-dtype-skip-layers.
# Disabled for hybrid models (attn+mamba) — mixed page sizes break
# the required page size unification.
if (
resolved_cache_dtype.startswith("turboquant_")
and not model_config.is_hybrid
):
from vllm.model_executor.layers.quantization.turboquant.config import (
TurboQuantConfig,
)
num_layers = model_config.hf_text_config.num_hidden_layers
boundary = TurboQuantConfig.get_boundary_skip_layers(num_layers)
existing = set(cache_config.kv_cache_dtype_skip_layers)
merged = sorted(existing | set(boundary), key=lambda x: int(x))
cache_config.kv_cache_dtype_skip_layers = merged
logger.info(
"TQ: skipping layers %s for boundary protection (num_layers=%d)",
merged,
num_layers,
)
ray_runtime_env = None
if is_ray_initialized():
# Ray Serve LLM calls `create_engine_config` in the context
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TurboQuant: Near-optimal KV-cache quantization for vLLM.
PolarQuant compression: random rotation + per-coordinate Lloyd-Max
scalar quantization for keys, uniform quantization for values.
Reference: "TurboQuant: Online Vector Quantization with Near-optimal
Distortion Rate" (ICLR 2026), Zandieh et al.
"""
from vllm.model_executor.layers.quantization.turboquant.config import TurboQuantConfig
__all__ = ["TurboQuantConfig"]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Lloyd-Max optimal scalar quantizer for TurboQuant.
After rotating a d-dimensional unit vector by a random orthogonal matrix,
each coordinate approximately follows N(0, 1/d) for d >= 64.
We solve the Lloyd-Max conditions to find optimal centroids.
Based on: turboquant-pytorch/lloyd_max.py (Zandieh et al.)
"""
import math
from functools import lru_cache
import torch
def _gaussian_pdf(x: float, sigma2: float) -> float:
return (1.0 / math.sqrt(2 * math.pi * sigma2)) * math.exp(-x * x / (2 * sigma2))
def _trapz(f, a: float, b: float, n: int = 200) -> float:
"""Trapezoidal numerical integration (replaces scipy.integrate.quad)."""
h = (b - a) / n
result = 0.5 * (f(a) + f(b))
for i in range(1, n):
result += f(a + i * h)
return result * h
def solve_lloyd_max(
d: int,
bits: int,
max_iter: int = 200,
tol: float = 1e-10,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Solve Lloyd-Max optimal quantizer for N(0, 1/d) distribution.
Args:
d: Vector dimension (determines variance = 1/d).
bits: Number of quantization bits.
max_iter: Maximum Lloyd-Max iterations.
tol: Convergence tolerance.
Returns:
centroids: Sorted tensor of 2^bits optimal centroids.
boundaries: Sorted tensor of 2^bits - 1 decision boundaries.
"""
n_levels = 2**bits
sigma2 = 1.0 / d
sigma = math.sqrt(sigma2)
def pdf(x):
return _gaussian_pdf(x, sigma2)
lo, hi = -3.5 * sigma, 3.5 * sigma
centroids = [lo + (hi - lo) * (i + 0.5) / n_levels for i in range(n_levels)]
for _ in range(max_iter):
boundaries = [
(centroids[i] + centroids[i + 1]) / 2.0 for i in range(n_levels - 1)
]
edges = [lo * 3] + boundaries + [hi * 3]
new_centroids = []
for i in range(n_levels):
a, b = edges[i], edges[i + 1]
num = _trapz(lambda x: x * pdf(x), a, b)
den = _trapz(pdf, a, b)
new_centroids.append(num / den if den > 1e-15 else centroids[i])
if max(abs(new_centroids[i] - centroids[i]) for i in range(n_levels)) < tol:
break
centroids = new_centroids
boundaries = [(centroids[i] + centroids[i + 1]) / 2.0 for i in range(n_levels - 1)]
return (
torch.tensor(centroids, dtype=torch.float32),
torch.tensor(boundaries, dtype=torch.float32),
)
@lru_cache(maxsize=32)
def get_centroids(d: int, bits: int) -> torch.Tensor:
"""Get precomputed Lloyd-Max centroids (cached)."""
centroids, _ = solve_lloyd_max(d, bits)
return centroids
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TurboQuant configuration."""
import math
from dataclasses import dataclass
# Named TQ presets: each maps to frozen config parameters.
# key_quant_bits: 8 = FP8 keys, 3-4 = MSE (Lloyd-Max) quantized keys.
# value_quant_bits: 3-4 = uniform quantized values.
TQ_PRESETS: dict[str, dict] = {
"turboquant_k8v4": {
"key_quant_bits": 8,
"value_quant_bits": 4,
"norm_correction": False,
},
"turboquant_4bit_nc": {
"key_quant_bits": 4,
"value_quant_bits": 4,
"norm_correction": True,
},
"turboquant_k3v4_nc": {
"key_quant_bits": 3,
"value_quant_bits": 4,
"norm_correction": True,
},
"turboquant_3bit_nc": {
"key_quant_bits": 3,
"value_quant_bits": 3,
"norm_correction": True,
},
}
@dataclass
class TurboQuantConfig:
"""Configuration for TurboQuant KV-cache quantization.
Uses PolarQuant (WHT rotation + Lloyd-Max scalar quantization) for keys
and uniform quantization for values. QJL is intentionally omitted —
community consensus (5+ independent groups) found it hurts attention
quality by amplifying variance through softmax.
Named presets (use via --kv-cache-dtype):
turboquant_k8v4: FP8 keys + 4-bit values, 2.6x, +1.17% PPL
turboquant_4bit_nc: 4-bit MSE keys + 4-bit values + NC, 3.8x, +2.71%
turboquant_k3v4_nc: 3-bit MSE keys + 4-bit values + NC, ~3.5x, +10.63%
turboquant_3bit_nc: 3-bit MSE keys + 3-bit values + NC, 4.9x, +20.59%
Args:
head_dim: Attention head dimension (e.g. 64, 96, 128).
key_quant_bits: Bits for key quantization. 8 = FP8 keys (no
rotation/MSE). 3-4 = Lloyd-Max MSE quantized keys.
value_quant_bits: Bits per value dimension for uniform quantization.
3 = 8 levels, 4 = 16 levels (default).
seed: Base seed for deterministic random matrix generation.
Actual seed per layer = seed + layer_idx * 1337.
norm_correction: Re-normalize centroid vectors to unit norm before
inverse rotation during dequant. Fixes quantization-induced norm
distortion, improving PPL by ~0.8% at 4-bit.
"""
head_dim: int = 128
key_quant_bits: int = 3 # 3-4 = MSE keys, 8 = FP8 keys
value_quant_bits: int = 4 # 3-4 = uniform quantized values
seed: int = 42
norm_correction: bool = False
@property
def key_fp8(self) -> bool:
"""Whether keys are stored as FP8 — no rotation/quantization needed."""
return self.key_quant_bits == 8
@property
def mse_bits(self) -> int:
"""MSE quantizer bit-width (determines centroid count: 2^mse_bits).
For MSE key modes, equals key_quant_bits.
For FP8 key mode, falls back to value_quant_bits (centroids are still
needed for continuation-prefill dequant and decode kernel params).
"""
if self.key_fp8:
return self.value_quant_bits
return self.key_quant_bits
@property
def key_mse_bits(self) -> int:
"""MSE bits actually used for key quantization (0 if FP8 keys)."""
if self.key_fp8:
return 0
return self.key_quant_bits
@property
def centroid_bits(self) -> int:
"""Bits for centroid generation — always non-zero."""
return self.mse_bits
@property
def n_centroids(self) -> int:
return 2**self.mse_bits
@property
def key_packed_size(self) -> int:
"""Packed bytes for a single KEY vector.
FP8 mode (key_quant_bits=8):
head_dim bytes (1 byte per element, no overhead).
TQ mode:
- MSE indices: ceil(head_dim * key_mse_bits / 8) bytes
- vec_norm: 2 bytes (float16)
- res_norm: 2 bytes (float16)
"""
if self.key_fp8:
return self.head_dim # 1 byte per element
mse_bytes = math.ceil(self.head_dim * self.key_mse_bits / 8)
norm_bytes = 4 # 2x float16
return mse_bytes + norm_bytes
@property
def effective_value_quant_bits(self) -> int:
"""Actual bits used for value storage."""
return self.value_quant_bits
@property
def value_packed_size(self) -> int:
"""Packed bytes for a single VALUE vector.
Uniform quantization: ceil(head_dim * bits / 8) + 4 bytes (scale + zero fp16).
"""
data_bytes = math.ceil(self.head_dim * self.value_quant_bits / 8)
return data_bytes + 4 # +2 scale(fp16) +2 zero(fp16)
@property
def slot_size(self) -> int:
"""Total packed bytes per head per position (key + value combined).
Layout: [key_packed | value_packed]
"""
return self.key_packed_size + self.value_packed_size
@property
def slot_size_aligned(self) -> int:
"""Slot size rounded up to next even number.
Even-number is required so effective_head_size = slot_size_aligned // 2
is integral.
"""
s = self.slot_size
return s + (s % 2) # round up to even
@staticmethod
def get_boundary_skip_layers(num_layers: int, n: int = 2) -> list[str]:
"""Get layer indices to skip TQ compression (boundary protection).
Returns first N and last N layer indices as strings, suitable for
kv_cache_dtype_skip_layers.
"""
if n <= 0 or num_layers <= 0:
return []
n = min(n, num_layers // 2) # don't skip more than half
first = list(range(n))
last = list(range(num_layers - n, num_layers))
# Deduplicate (if num_layers <= 2*n)
indices = sorted(set(first + last))
return [str(i) for i in indices]
@staticmethod
def from_cache_dtype(cache_dtype: str, head_dim: int) -> "TurboQuantConfig":
"""Create config from a named preset.
Valid presets: turboquant_k8v4, turboquant_4bit_nc, etc.
"""
if cache_dtype not in TQ_PRESETS:
valid = ", ".join(TQ_PRESETS.keys())
raise ValueError(
f"Unknown TurboQuant cache dtype: {cache_dtype!r}. "
f"Valid presets: {valid}"
)
preset = TQ_PRESETS[cache_dtype]
return TurboQuantConfig(
head_dim=head_dim,
key_quant_bits=preset["key_quant_bits"],
value_quant_bits=preset["value_quant_bits"],
norm_correction=preset["norm_correction"],
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TurboQuant quantizer utilities.
Serving path uses generate_wht_signs() for WHT rotation sign buffers.
generate_rotation_matrix() is retained for standalone benchmarks only.
Triton kernels handle all quantization, packing, and dequantization on GPU.
"""
import torch
def generate_rotation_matrix(
d: int, seed: int, device: torch.device = torch.device("cpu")
) -> torch.Tensor:
"""Generate Haar-distributed random orthogonal matrix via QR decomposition."""
gen = torch.Generator(device="cpu")
gen.manual_seed(seed)
G = torch.randn(d, d, generator=gen, device="cpu", dtype=torch.float32)
Q, R = torch.linalg.qr(G)
# Fix sign ambiguity for determinism
diag_sign = torch.sign(torch.diag(R))
diag_sign[diag_sign == 0] = 1.0
Q = Q * diag_sign.unsqueeze(0)
return Q.to(device)
def generate_wht_signs(
d: int, seed: int, device: torch.device = torch.device("cpu")
) -> torch.Tensor:
"""Generate deterministic random ±1 signs for WHT rotation.
Used with Walsh-Hadamard Transform for per-layer rotation randomization.
Same seed derivation as QR (per-layer via seed + layer_idx * stride).
"""
gen = torch.Generator(device="cpu")
gen.manual_seed(seed)
bits = torch.randint(0, 2, (d,), generator=gen, device="cpu")
signs = bits.float() * 2 - 1
return signs.to(device)
......@@ -280,6 +280,11 @@ class CudaPlatformBase(Platform):
valid_backends_priorities = []
invalid_reasons = {}
# TurboQuant KV cache: route directly to TQ backend
kv_cache_dtype = attn_selector_config.kv_cache_dtype
if kv_cache_dtype is not None and kv_cache_dtype.startswith("turboquant_"):
return [(AttentionBackendEnum.TURBOQUANT, 0)], {}
backend_priorities = _get_backend_priorities(
attn_selector_config.use_mla, device_capability
)
......
......@@ -264,6 +264,12 @@ class RocmPlatform(Platform):
block_size = attn_selector_config.block_size
kv_cache_dtype = attn_selector_config.kv_cache_dtype
# TurboQuant KV cache: route directly to TQ backend
kv_cache_dtype = attn_selector_config.kv_cache_dtype
if kv_cache_dtype is not None and kv_cache_dtype.startswith("turboquant_"):
logger.info_once("Using TurboQuant attention backend.")
return AttentionBackendEnum.TURBOQUANT.get_path()
if attn_selector_config.use_sparse:
# if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
# raise ValueError(
......
......@@ -52,6 +52,12 @@ class XPUPlatform(Platform):
"only NHD layout is supported by XPU attention kernels."
)
# TurboQuant KV cache: route directly to TQ backend
kv_cache_dtype = attn_selector_config.kv_cache_dtype
if kv_cache_dtype is not None and kv_cache_dtype.startswith("turboquant_"):
logger.info_once("Using TurboQuant attention backend.")
return AttentionBackendEnum.TURBOQUANT.get_path()
dtype = attn_selector_config.dtype
if attn_selector_config.use_sparse:
raise NotImplementedError("Sparse Attention is not supported on XPU.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment