Commit c3270a92 authored by 王敏's avatar 王敏
Browse files

Merge remote-tracking branch 'origin/v0.15.1-dev' into v0.15.1-dev

parents feced2f1 0b7cc6cf
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TurboQuant: Near-optimal KV-cache quantization for vLLM.
PolarQuant compression: random rotation + per-coordinate Lloyd-Max
scalar quantization for keys, uniform quantization for values.
Reference: "TurboQuant: Online Vector Quantization with Near-optimal
Distortion Rate" (ICLR 2026), Zandieh et al.
"""
from vllm.model_executor.layers.quantization.turboquant.config import TurboQuantConfig
__all__ = ["TurboQuantConfig"]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Lloyd-Max optimal scalar quantizer for TurboQuant.
After rotating a d-dimensional unit vector by a random orthogonal matrix,
each coordinate approximately follows N(0, 1/d) for d >= 64.
We solve the Lloyd-Max conditions to find optimal centroids.
Based on: turboquant-pytorch/lloyd_max.py (Zandieh et al.)
"""
import math
from functools import lru_cache
import torch
def _gaussian_pdf(x: float, sigma2: float) -> float:
return (1.0 / math.sqrt(2 * math.pi * sigma2)) * math.exp(-x * x / (2 * sigma2))
def _trapz(f, a: float, b: float, n: int = 200) -> float:
"""Trapezoidal numerical integration (replaces scipy.integrate.quad)."""
h = (b - a) / n
result = 0.5 * (f(a) + f(b))
for i in range(1, n):
result += f(a + i * h)
return result * h
def solve_lloyd_max(
d: int,
bits: int,
max_iter: int = 200,
tol: float = 1e-10,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Solve Lloyd-Max optimal quantizer for N(0, 1/d) distribution.
Args:
d: Vector dimension (determines variance = 1/d).
bits: Number of quantization bits.
max_iter: Maximum Lloyd-Max iterations.
tol: Convergence tolerance.
Returns:
centroids: Sorted tensor of 2^bits optimal centroids.
boundaries: Sorted tensor of 2^bits - 1 decision boundaries.
"""
n_levels = 2**bits
sigma2 = 1.0 / d
sigma = math.sqrt(sigma2)
def pdf(x):
return _gaussian_pdf(x, sigma2)
lo, hi = -3.5 * sigma, 3.5 * sigma
centroids = [lo + (hi - lo) * (i + 0.5) / n_levels for i in range(n_levels)]
for _ in range(max_iter):
boundaries = [
(centroids[i] + centroids[i + 1]) / 2.0 for i in range(n_levels - 1)
]
edges = [lo * 3] + boundaries + [hi * 3]
new_centroids = []
for i in range(n_levels):
a, b = edges[i], edges[i + 1]
num = _trapz(lambda x: x * pdf(x), a, b)
den = _trapz(pdf, a, b)
new_centroids.append(num / den if den > 1e-15 else centroids[i])
if max(abs(new_centroids[i] - centroids[i]) for i in range(n_levels)) < tol:
break
centroids = new_centroids
boundaries = [(centroids[i] + centroids[i + 1]) / 2.0 for i in range(n_levels - 1)]
return (
torch.tensor(centroids, dtype=torch.float32),
torch.tensor(boundaries, dtype=torch.float32),
)
@lru_cache(maxsize=32)
def get_centroids(d: int, bits: int) -> torch.Tensor:
"""Get precomputed Lloyd-Max centroids (cached)."""
centroids, _ = solve_lloyd_max(d, bits)
return centroids
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TurboQuant configuration."""
import math
from dataclasses import dataclass
# Named TQ presets: each maps to frozen config parameters.
# key_quant_bits: 8 = FP8 keys, 3-4 = MSE (Lloyd-Max) quantized keys.
# value_quant_bits: 3-4 = uniform quantized values.
TQ_PRESETS: dict[str, dict] = {
"turboquant_k8v4": {
"key_quant_bits": 8,
"value_quant_bits": 4,
"norm_correction": False,
},
"turboquant_4bit_nc": {
"key_quant_bits": 4,
"value_quant_bits": 4,
"norm_correction": True,
},
"turboquant_k3v4_nc": {
"key_quant_bits": 3,
"value_quant_bits": 4,
"norm_correction": True,
},
"turboquant_3bit_nc": {
"key_quant_bits": 3,
"value_quant_bits": 3,
"norm_correction": True,
},
}
@dataclass
class TurboQuantConfig:
"""Configuration for TurboQuant KV-cache quantization.
Uses PolarQuant (WHT rotation + Lloyd-Max scalar quantization) for keys
and uniform quantization for values. QJL is intentionally omitted —
community consensus (5+ independent groups) found it hurts attention
quality by amplifying variance through softmax.
Named presets (use via --kv-cache-dtype):
turboquant_k8v4: FP8 keys + 4-bit values, 2.6x, +1.17% PPL
turboquant_4bit_nc: 4-bit MSE keys + 4-bit values + NC, 3.8x, +2.71%
turboquant_k3v4_nc: 3-bit MSE keys + 4-bit values + NC, ~3.5x, +10.63%
turboquant_3bit_nc: 3-bit MSE keys + 3-bit values + NC, 4.9x, +20.59%
Args:
head_dim: Attention head dimension (e.g. 64, 96, 128).
key_quant_bits: Bits for key quantization. 8 = FP8 keys (no
rotation/MSE). 3-4 = Lloyd-Max MSE quantized keys.
value_quant_bits: Bits per value dimension for uniform quantization.
3 = 8 levels, 4 = 16 levels (default).
seed: Base seed for deterministic random matrix generation.
Actual seed per layer = seed + layer_idx * 1337.
norm_correction: Re-normalize centroid vectors to unit norm before
inverse rotation during dequant. Fixes quantization-induced norm
distortion, improving PPL by ~0.8% at 4-bit.
"""
head_dim: int = 128
key_quant_bits: int = 3 # 3-4 = MSE keys, 8 = FP8 keys
value_quant_bits: int = 4 # 3-4 = uniform quantized values
seed: int = 42
norm_correction: bool = False
@property
def key_fp8(self) -> bool:
"""Whether keys are stored as FP8 — no rotation/quantization needed."""
return self.key_quant_bits == 8
@property
def mse_bits(self) -> int:
"""MSE quantizer bit-width (determines centroid count: 2^mse_bits).
For MSE key modes, equals key_quant_bits.
For FP8 key mode, falls back to value_quant_bits (centroids are still
needed for continuation-prefill dequant and decode kernel params).
"""
if self.key_fp8:
return self.value_quant_bits
return self.key_quant_bits
@property
def key_mse_bits(self) -> int:
"""MSE bits actually used for key quantization (0 if FP8 keys)."""
if self.key_fp8:
return 0
return self.key_quant_bits
@property
def centroid_bits(self) -> int:
"""Bits for centroid generation — always non-zero."""
return self.mse_bits
@property
def n_centroids(self) -> int:
return 2**self.mse_bits
@property
def key_packed_size(self) -> int:
"""Packed bytes for a single KEY vector.
FP8 mode (key_quant_bits=8):
head_dim bytes (1 byte per element, no overhead).
TQ mode:
- MSE indices: ceil(head_dim * key_mse_bits / 8) bytes
- vec_norm: 2 bytes (float16)
- res_norm: 2 bytes (float16)
"""
if self.key_fp8:
return self.head_dim # 1 byte per element
mse_bytes = math.ceil(self.head_dim * self.key_mse_bits / 8)
norm_bytes = 4 # 2x float16
return mse_bytes + norm_bytes
@property
def effective_value_quant_bits(self) -> int:
"""Actual bits used for value storage."""
return self.value_quant_bits
@property
def value_packed_size(self) -> int:
"""Packed bytes for a single VALUE vector.
Uniform quantization: ceil(head_dim * bits / 8) + 4 bytes (scale + zero fp16).
"""
data_bytes = math.ceil(self.head_dim * self.value_quant_bits / 8)
return data_bytes + 4 # +2 scale(fp16) +2 zero(fp16)
@property
def slot_size(self) -> int:
"""Total packed bytes per head per position (key + value combined).
Layout: [key_packed | value_packed]
"""
return self.key_packed_size + self.value_packed_size
@property
def slot_size_aligned(self) -> int:
"""Slot size rounded up to next even number.
Even-number is required so effective_head_size = slot_size_aligned // 2
is integral.
"""
s = self.slot_size
return s + (s % 2) # round up to even
@staticmethod
def get_boundary_skip_layers(num_layers: int, n: int = 2) -> list[str]:
"""Get layer indices to skip TQ compression (boundary protection).
Returns first N and last N layer indices as strings, suitable for
kv_cache_dtype_skip_layers.
"""
if n <= 0 or num_layers <= 0:
return []
n = min(n, num_layers // 2) # don't skip more than half
first = list(range(n))
last = list(range(num_layers - n, num_layers))
# Deduplicate (if num_layers <= 2*n)
indices = sorted(set(first + last))
return [str(i) for i in indices]
@staticmethod
def from_cache_dtype(cache_dtype: str, head_dim: int) -> "TurboQuantConfig":
"""Create config from a named preset.
Valid presets: turboquant_k8v4, turboquant_4bit_nc, etc.
"""
if cache_dtype not in TQ_PRESETS:
valid = ", ".join(TQ_PRESETS.keys())
raise ValueError(
f"Unknown TurboQuant cache dtype: {cache_dtype!r}. "
f"Valid presets: {valid}"
)
preset = TQ_PRESETS[cache_dtype]
return TurboQuantConfig(
head_dim=head_dim,
key_quant_bits=preset["key_quant_bits"],
value_quant_bits=preset["value_quant_bits"],
norm_correction=preset["norm_correction"],
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TurboQuant quantizer utilities.
Serving path uses generate_wht_signs() for WHT rotation sign buffers.
generate_rotation_matrix() is retained for standalone benchmarks only.
Triton kernels handle all quantization, packing, and dequantization on GPU.
"""
import torch
def generate_rotation_matrix(
d: int, seed: int, device: torch.device = torch.device("cpu")
) -> torch.Tensor:
"""Generate Haar-distributed random orthogonal matrix via QR decomposition."""
gen = torch.Generator(device="cpu")
gen.manual_seed(seed)
G = torch.randn(d, d, generator=gen, device="cpu", dtype=torch.float32)
Q, R = torch.linalg.qr(G)
# Fix sign ambiguity for determinism
diag_sign = torch.sign(torch.diag(R))
diag_sign[diag_sign == 0] = 1.0
Q = Q * diag_sign.unsqueeze(0)
return Q.to(device)
def generate_wht_signs(
d: int, seed: int, device: torch.device = torch.device("cpu")
) -> torch.Tensor:
"""Generate deterministic random ±1 signs for WHT rotation.
Used with Walsh-Hadamard Transform for per-layer rotation randomization.
Same seed derivation as QR (per-layer via seed + layer_idx * stride).
"""
gen = torch.Generator(device="cpu")
gen.manual_seed(seed)
bits = torch.randint(0, 2, (d,), generator=gen, device="cpu")
signs = bits.float() * 2 - 1
return signs.to(device)
......@@ -280,6 +280,11 @@ class CudaPlatformBase(Platform):
valid_backends_priorities = []
invalid_reasons = {}
# TurboQuant KV cache: route directly to TQ backend
kv_cache_dtype = attn_selector_config.kv_cache_dtype
if kv_cache_dtype is not None and kv_cache_dtype.startswith("turboquant_"):
return [(AttentionBackendEnum.TURBOQUANT, 0)], {}
backend_priorities = _get_backend_priorities(
attn_selector_config.use_mla, device_capability
)
......
......@@ -264,6 +264,12 @@ class RocmPlatform(Platform):
block_size = attn_selector_config.block_size
kv_cache_dtype = attn_selector_config.kv_cache_dtype
# TurboQuant KV cache: route directly to TQ backend
kv_cache_dtype = attn_selector_config.kv_cache_dtype
if kv_cache_dtype is not None and kv_cache_dtype.startswith("turboquant_"):
logger.info_once("Using TurboQuant attention backend.")
return AttentionBackendEnum.TURBOQUANT.get_path()
if attn_selector_config.use_sparse:
# if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
# raise ValueError(
......
......@@ -52,6 +52,12 @@ class XPUPlatform(Platform):
"only NHD layout is supported by XPU attention kernels."
)
# TurboQuant KV cache: route directly to TQ backend
kv_cache_dtype = attn_selector_config.kv_cache_dtype
if kv_cache_dtype is not None and kv_cache_dtype.startswith("turboquant_"):
logger.info_once("Using TurboQuant attention backend.")
return AttentionBackendEnum.TURBOQUANT.get_path()
dtype = attn_selector_config.dtype
if attn_selector_config.use_sparse:
raise NotImplementedError("Sparse Attention is not supported on XPU.")
......
......@@ -42,6 +42,10 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"int8": torch.int8,
"fp8_inc": torch.float8_e4m3fn,
"fp8_ds_mla": torch.uint8,
"turboquant_k8v4": torch.uint8,
"turboquant_4bit_nc": torch.uint8,
"turboquant_k3v4_nc": torch.uint8,
"turboquant_3bit_nc": torch.uint8,
}
TORCH_DTYPE_TO_NUMPY_DTYPE = {
......
......@@ -202,7 +202,7 @@ def get_max_prefill_buffer_size(vllm_config: VllmConfig):
class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
_cudagraph_support: ClassVar[AttentionCGSupport] = (
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
AttentionCGSupport.UNIFORM_BATCH
)
reorder_batch_threshold: int = 1
......
......@@ -78,6 +78,7 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
"RocmAiterUnifiedAttentionBackend"
)
CPU_ATTN = "vllm.v1.attention.backends.cpu_attn.CPUAttentionBackend"
TURBOQUANT = "vllm.v1.attention.backends.turboquant_attn.TurboQuantAttentionBackend"
# Placeholder for third-party/custom backends - must be registered before use
# set to None to avoid alias with other backend, whose value is an empty string
CUSTOM = None
......
This diff is collapsed.
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused Triton kernels for TurboQuant KV store.
Two kernels:
1. _tq_fused_store_fp8: FP8 key scatter + value uniform quantization.
2. _tq_fused_store_mse: Fused bucketize + centroid gather + residual norm
+ MSE index packing + value quantization (eliminates 4 PyTorch kernel
launches vs the old pack-only approach).
The launcher `triton_turboquant_store` selects the appropriate kernel.
"""
import math
import torch
from vllm.triton_utils import tl, triton
from vllm.v1.attention.ops.triton_turboquant_decode import _use_fp8_e4b15
# ═══════════════════════════════════════════════════════════════════════
# Shared: value uniform quantization + pack + scale/zero store
# ═══════════════════════════════════════════════════════════════════════
@triton.jit
def _store_quantized_value(
Value_ptr,
KV_cache_ptr,
base, # pid * D offset into Value_ptr
slot_base, # byte offset into KV_cache_ptr for this slot+head
d_offs, # tl.arange(0, BLOCK_D)
d_mask, # d_offs < D
D: tl.constexpr,
KPS: tl.constexpr,
VQB: tl.constexpr,
VAL_DATA_BYTES: tl.constexpr,
BLOCK_VAL: tl.constexpr,
BLOCK_GRP: tl.constexpr,
):
"""Uniform quantization of values to VQB bits, pack, and store with scale/zero."""
val_cache_offset = KPS
if VQB == 3:
val_vec = tl.load(Value_ptr + base + d_offs, mask=d_mask, other=0.0).to(
tl.float32
)
val_min = tl.min(tl.where(d_mask, val_vec, float("inf")), axis=0)
val_max = tl.max(tl.where(d_mask, val_vec, -float("inf")), axis=0)
v_scale = (val_max - val_min) / 7.0
v_scale = tl.where(v_scale > 1e-8, v_scale, 1e-8)
q_vals = tl.minimum(
tl.maximum(((val_vec - val_min) / v_scale + 0.5).to(tl.int32), 0), 7
)
grp_offs = tl.arange(0, BLOCK_GRP)
grp_mask = grp_offs < (D // 8)
q_grp = tl.reshape(q_vals, [BLOCK_GRP, 8])
shifts_3bit = tl.arange(0, 8) * 3
packed_24 = tl.sum(q_grp << shifts_3bit[None, :], axis=1)
b0 = (packed_24 & 0xFF).to(tl.uint8)
b1 = ((packed_24 >> 8) & 0xFF).to(tl.uint8)
b2 = ((packed_24 >> 16) & 0xFF).to(tl.uint8)
tl.store(
KV_cache_ptr + slot_base + val_cache_offset + grp_offs * 3,
b0,
mask=grp_mask,
)
tl.store(
KV_cache_ptr + slot_base + val_cache_offset + grp_offs * 3 + 1,
b1,
mask=grp_mask,
)
tl.store(
KV_cache_ptr + slot_base + val_cache_offset + grp_offs * 3 + 2,
b2,
mask=grp_mask,
)
sc_offset = val_cache_offset + VAL_DATA_BYTES
sc_f16 = v_scale.to(tl.float16)
sc_u16 = sc_f16.to(tl.uint16, bitcast=True)
tl.store(KV_cache_ptr + slot_base + sc_offset, (sc_u16 & 0xFF).to(tl.uint8))
tl.store(
KV_cache_ptr + slot_base + sc_offset + 1,
((sc_u16 >> 8) & 0xFF).to(tl.uint8),
)
zr_f16 = val_min.to(tl.float16)
zr_u16 = zr_f16.to(tl.uint16, bitcast=True)
tl.store(KV_cache_ptr + slot_base + sc_offset + 2, (zr_u16 & 0xFF).to(tl.uint8))
tl.store(
KV_cache_ptr + slot_base + sc_offset + 3,
((zr_u16 >> 8) & 0xFF).to(tl.uint8),
)
else: # VQB == 4
val_vec = tl.load(Value_ptr + base + d_offs, mask=d_mask, other=0.0).to(
tl.float32
)
val_min = tl.min(tl.where(d_mask, val_vec, float("inf")), axis=0)
val_max = tl.max(tl.where(d_mask, val_vec, -float("inf")), axis=0)
v_scale = (val_max - val_min) / 15.0
v_scale = tl.where(v_scale > 1e-8, v_scale, 1e-8)
val_offs = tl.arange(0, BLOCK_VAL)
val_mask = val_offs < VAL_DATA_BYTES
v0 = tl.load(
Value_ptr + base + val_offs * 2,
mask=val_mask & (val_offs * 2 < D),
other=val_min,
)
v1 = tl.load(
Value_ptr + base + val_offs * 2 + 1,
mask=val_mask & (val_offs * 2 + 1 < D),
other=val_min,
)
q0 = tl.minimum(
tl.maximum(((v0 - val_min) / v_scale + 0.5).to(tl.int32), 0), 15
)
q1 = tl.minimum(
tl.maximum(((v1 - val_min) / v_scale + 0.5).to(tl.int32), 0), 15
)
packed_val = (q0 | (q1 << 4)).to(tl.uint8)
tl.store(
KV_cache_ptr + slot_base + val_cache_offset + val_offs,
packed_val,
mask=val_mask,
)
sc_offset = val_cache_offset + VAL_DATA_BYTES
sc_f16 = v_scale.to(tl.float16)
sc_u16 = sc_f16.to(tl.uint16, bitcast=True)
tl.store(KV_cache_ptr + slot_base + sc_offset, (sc_u16 & 0xFF).to(tl.uint8))
tl.store(
KV_cache_ptr + slot_base + sc_offset + 1,
((sc_u16 >> 8) & 0xFF).to(tl.uint8),
)
zr_f16 = val_min.to(tl.float16)
zr_u16 = zr_f16.to(tl.uint16, bitcast=True)
tl.store(KV_cache_ptr + slot_base + sc_offset + 2, (zr_u16 & 0xFF).to(tl.uint8))
tl.store(
KV_cache_ptr + slot_base + sc_offset + 3,
((zr_u16 >> 8) & 0xFF).to(tl.uint8),
)
# ═══════════════════════════════════════════════════════════════════════
# FP8 key store + value uniform quantization
# ═══════════════════════════════════════════════════════════════════════
@triton.jit
def _tq_fused_store_fp8(
Key_ptr, # [NH, D] float16/bfloat16 — raw keys
Value_ptr, # [NH, D] float16/bfloat16 — raw values
KV_cache_ptr, # [total_bytes] uint8 (flattened view)
Slot_mapping_ptr, # [N] int32 — per-token slot indices
# Cache strides (for computing byte offsets)
stride_cache_block: tl.constexpr,
stride_cache_pos: tl.constexpr,
stride_cache_head: tl.constexpr,
# Dimensions
D: tl.constexpr,
H: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
BLOCK_D: tl.constexpr,
# TQ layout
KPS: tl.constexpr,
# Value quantization
VQB: tl.constexpr,
VAL_DATA_BYTES: tl.constexpr,
# Packing block sizes
BLOCK_VAL: tl.constexpr,
BLOCK_GRP: tl.constexpr = 16,
FP8_E4B15: tl.constexpr = 0, # 1 = e4b15 (Ampere/Ada), 0 = e4nv (Hopper+)
):
"""FP8 key cast+scatter + value uniform quantization: one program per (token, head)."""
pid = tl.program_id(0)
token_idx = pid // H
head_idx = pid % H
slot = tl.load(Slot_mapping_ptr + token_idx)
if slot < 0:
return
blk = slot // BLOCK_SIZE
off = slot % BLOCK_SIZE
slot_base = (
blk * stride_cache_block + off * stride_cache_pos + head_idx * stride_cache_head
)
base = pid * D
# ── FP8 KEY: cast to FP8 in-kernel and store ─────────────────
d_offs = tl.arange(0, BLOCK_D)
d_mask = d_offs < D
k_vals = tl.load(Key_ptr + base + d_offs, mask=d_mask, other=0.0)
if FP8_E4B15:
k_fp8 = k_vals.to(tl.float8e4b15)
else:
x_f32 = k_vals.to(tl.float32)
k_fp8 = x_f32.to(tl.float8e4nv)
k_bytes = k_fp8.to(tl.uint8, bitcast=True)
tl.store(KV_cache_ptr + slot_base + d_offs, k_bytes, mask=d_mask)
# ── VALUE QUANTIZE + PACK ───────────────────────────────────────
_store_quantized_value(
Value_ptr,
KV_cache_ptr,
base,
slot_base,
d_offs,
d_mask,
D=D,
KPS=KPS,
VQB=VQB,
VAL_DATA_BYTES=VAL_DATA_BYTES,
BLOCK_VAL=BLOCK_VAL,
BLOCK_GRP=BLOCK_GRP,
)
# ═══════════════════════════════════════════════════════════════════════
# Fused MSE store: bucketize + centroid gather + residual norm + pack
# (eliminates 4 PyTorch kernel launches per layer vs pack-only kernel)
# ═══════════════════════════════════════════════════════════════════════
@triton.jit
def _tq_fused_store_mse(
# Post-rotation inputs
Y_ptr, # [NH, D] float32 — rotated normalized keys (x_hat @ PiT)
Norms_ptr, # [NH] float32 — key vector norms (||k||)
Value_ptr, # [NH, D] float32 — raw values
# Quantization tables
Centroids_ptr, # [n_centroids] float32
Midpoints_ptr, # [n_centroids-1] float32
# Cache and indexing
KV_cache_ptr, # [total_bytes] uint8 (flattened view)
Slot_mapping_ptr, # [N] int32 — per-token slot indices
# Cache strides
stride_cache_block: tl.constexpr,
stride_cache_pos: tl.constexpr,
stride_cache_head: tl.constexpr,
# Dimensions
D: tl.constexpr,
H: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
BLOCK_D: tl.constexpr,
# TQ layout
MSE_BYTES: tl.constexpr,
KPS: tl.constexpr,
# Value quantization
VQB: tl.constexpr,
VAL_DATA_BYTES: tl.constexpr,
# Packing block sizes
BLOCK_VAL: tl.constexpr,
# MSE params
MSE_BITS: tl.constexpr,
N_CENTROIDS: tl.constexpr,
BLOCK_GRP: tl.constexpr = 16,
):
"""Fused MSE quantize + pack + store.
Performs bucketize, centroid gather, residual norm, MSE index packing,
and value quantization in one kernel — eliminates 4 PyTorch kernel
launches (bucketize, gather, subtract, norm) per layer vs pack-only.
"""
pid = tl.program_id(0)
token_idx = pid // H
head_idx = pid % H
slot = tl.load(Slot_mapping_ptr + token_idx)
if slot < 0:
return
blk = slot // BLOCK_SIZE
off = slot % BLOCK_SIZE
slot_base = (
blk * stride_cache_block + off * stride_cache_pos + head_idx * stride_cache_head
)
base = pid * D
d_offs = tl.arange(0, BLOCK_D)
d_mask = d_offs < D
# ── 1. INLINE BUCKETIZE ──────────────────────────────────────────
y_vec = tl.load(Y_ptr + base + d_offs, mask=d_mask, other=0.0)
idx = tl.zeros([BLOCK_D], dtype=tl.int32)
for i in range(N_CENTROIDS - 1):
mid_val = tl.load(Midpoints_ptr + i)
idx += tl.where(y_vec >= mid_val, 1, 0)
# ── 2. CENTROID GATHER + RESIDUAL NORM ────────────────────────────
centroid_vals = tl.load(Centroids_ptr + idx, mask=d_mask, other=0.0)
residual = y_vec - centroid_vals
gamma = tl.sqrt(tl.sum(tl.where(d_mask, residual * residual, 0.0), axis=0))
# ── 3. PACK MSE INDICES from register idx ─────────────────────────
if MSE_BITS == 4:
idx_pairs = tl.reshape(idx, [BLOCK_D // 2, 2])
shifts_4 = tl.arange(0, 2) * 4
packed = tl.sum((idx_pairs & 0xF) << shifts_4[None, :], axis=1).to(tl.uint8)
mse_offs = tl.arange(0, BLOCK_D // 2)
mse_mask = mse_offs < MSE_BYTES
tl.store(KV_cache_ptr + slot_base + mse_offs, packed, mask=mse_mask)
elif MSE_BITS == 3:
grp_offs = tl.arange(0, BLOCK_GRP)
grp_mask = grp_offs < (D // 8)
idx_grp = tl.reshape(idx, [BLOCK_GRP, 8])
shifts_3 = tl.arange(0, 8) * 3
packed_24 = tl.sum((idx_grp & 0x7) << shifts_3[None, :], axis=1)
b0 = (packed_24 & 0xFF).to(tl.uint8)
b1 = ((packed_24 >> 8) & 0xFF).to(tl.uint8)
b2 = ((packed_24 >> 16) & 0xFF).to(tl.uint8)
tl.store(KV_cache_ptr + slot_base + grp_offs * 3, b0, mask=grp_mask)
tl.store(KV_cache_ptr + slot_base + grp_offs * 3 + 1, b1, mask=grp_mask)
tl.store(KV_cache_ptr + slot_base + grp_offs * 3 + 2, b2, mask=grp_mask)
# ── 4. STORE NORMS (vec_norm + gamma as fp16) ─────────────────────
norm_offset = MSE_BYTES
vn_f16 = tl.load(Norms_ptr + pid).to(tl.float16)
vn_u16 = vn_f16.to(tl.uint16, bitcast=True)
tl.store(KV_cache_ptr + slot_base + norm_offset, (vn_u16 & 0xFF).to(tl.uint8))
tl.store(
KV_cache_ptr + slot_base + norm_offset + 1, ((vn_u16 >> 8) & 0xFF).to(tl.uint8)
)
gm_f16 = gamma.to(tl.float16)
gm_u16 = gm_f16.to(tl.uint16, bitcast=True)
tl.store(KV_cache_ptr + slot_base + norm_offset + 2, (gm_u16 & 0xFF).to(tl.uint8))
tl.store(
KV_cache_ptr + slot_base + norm_offset + 3, ((gm_u16 >> 8) & 0xFF).to(tl.uint8)
)
# ── 5. VALUE QUANTIZE + PACK ──────────────────────────────────────
_store_quantized_value(
Value_ptr,
KV_cache_ptr,
base,
slot_base,
d_offs,
d_mask,
D=D,
KPS=KPS,
VQB=VQB,
VAL_DATA_BYTES=VAL_DATA_BYTES,
BLOCK_VAL=BLOCK_VAL,
BLOCK_GRP=BLOCK_GRP,
)
# ═══════════════════════════════════════════════════════════════════════
# Launcher
# ═══════════════════════════════════════════════════════════════════════
def triton_turboquant_store(
key: torch.Tensor, # [N, H, D] — raw keys (post-RoPE)
value: torch.Tensor, # [N, H, D] — raw values
kv_cache: torch.Tensor, # [num_blocks, block_size, Hk, padded_slot] uint8
slot_mapping: torch.Tensor, # [N] int32
PiT: torch.Tensor, # [D, D] float32
centroids: torch.Tensor, # [n_centroids] float32
midpoints: torch.Tensor, # [n_centroids-1] float32
mse_bits: int,
key_packed_size: int,
value_quant_bits: int,
key_fp8: bool = False,
):
"""Launch TQ store kernel — FP8 uses _tq_fused_store_fp8, MSE uses _tq_fused_store_mse."""
N, H, D = key.shape
NH = N * H
block_size = kv_cache.shape[1]
num_kv_heads = kv_cache.shape[2]
padded_slot = kv_cache.shape[3]
BLOCK_D = triton.next_power_of_2(D)
mse_bytes = math.ceil(D * mse_bits / 8)
n_centroids = 2**mse_bits
val_data_bytes = math.ceil(D * value_quant_bits / 8)
BLOCK_VAL = triton.next_power_of_2(val_data_bytes)
# Cache strides
stride_block = block_size * num_kv_heads * padded_slot
stride_pos = num_kv_heads * padded_slot
stride_head = padded_slot
block_grp = triton.next_power_of_2(D // 8) if D >= 8 else 1
# ── FP8 PATH: in-kernel FP8 cast + scatter via fp8 kernel ──
if key_fp8:
k_flat = key.reshape(NH, D).contiguous()
v_flat = value.reshape(NH, D).contiguous()
fp8_e4b15 = _use_fp8_e4b15(key.device.index or 0)
grid = (NH,)
_tq_fused_store_fp8[grid](
k_flat,
v_flat,
kv_cache.view(-1),
slot_mapping,
stride_cache_block=stride_block,
stride_cache_pos=stride_pos,
stride_cache_head=stride_head,
D=D,
H=H,
BLOCK_SIZE=block_size,
BLOCK_D=BLOCK_D,
KPS=key_packed_size,
VQB=value_quant_bits,
VAL_DATA_BYTES=val_data_bytes,
BLOCK_VAL=BLOCK_VAL,
BLOCK_GRP=block_grp,
FP8_E4B15=fp8_e4b15,
num_warps=4,
num_stages=1,
)
return
# ── MSE PATH: external GEMM + fused bucketize/pack kernel ──
# Normalize + rotation GEMM externally (cuBLAS is faster than in-kernel)
k_flat = key.float().reshape(NH, D)
norms = k_flat.norm(dim=1, keepdim=True)
x_hat = k_flat / (norms + 1e-8)
y = (x_hat @ PiT).contiguous()
v_flat = value.float().reshape(NH, D)
# Fused kernel: bucketize + centroid gather + residual norm + pack
grid = (NH,)
_tq_fused_store_mse[grid](
y,
norms.squeeze(1),
v_flat,
centroids,
midpoints,
kv_cache.view(-1),
slot_mapping,
stride_cache_block=stride_block,
stride_cache_pos=stride_pos,
stride_cache_head=stride_head,
D=D,
H=H,
BLOCK_SIZE=block_size,
BLOCK_D=BLOCK_D,
MSE_BYTES=mse_bytes,
KPS=key_packed_size,
VQB=value_quant_bits,
VAL_DATA_BYTES=val_data_bytes,
BLOCK_VAL=BLOCK_VAL,
MSE_BITS=mse_bits,
N_CENTROIDS=n_centroids,
BLOCK_GRP=block_grp,
num_warps=4,
num_stages=1,
)
......@@ -17,6 +17,7 @@ from vllm.v1.kv_cache_interface import (
MLAAttentionSpec,
SinkFullAttentionSpec,
SlidingWindowSpec,
TQFullAttentionSpec,
)
from vllm.v1.request import Request
......@@ -51,6 +52,7 @@ class SingleTypeKVCacheManager(ABC):
self.kv_cache_spec = kv_cache_spec
self.block_pool = block_pool
self.enable_caching = enable_caching
# self.new_block_ids: list[int] = []
# Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request
......@@ -204,6 +206,8 @@ class SingleTypeKVCacheManager(ABC):
cdiv(num_total_computed_tokens, self.block_size) - len(req_blocks)
)
req_blocks.extend(allocated_blocks)
# if isinstance(self.kv_cache_spec, FullAttentionSpec):
# self.new_block_ids.extend(b.block_id for b in allocated_blocks)
def allocate_new_blocks(
self, request_id: str, num_tokens: int, num_tokens_main_model: int
......@@ -230,6 +234,8 @@ class SingleTypeKVCacheManager(ABC):
else:
new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
req_blocks.extend(new_blocks)
# if isinstance(self.kv_cache_spec, FullAttentionSpec):
# self.new_block_ids.extend(b.block_id for b in new_blocks)
return new_blocks
def cache_blocks(self, request: Request, num_tokens: int) -> None:
......@@ -1048,6 +1054,7 @@ class SinkFullAttentionManager(FullAttentionManager):
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
FullAttentionSpec: FullAttentionManager,
TQFullAttentionSpec: FullAttentionManager,
MLAAttentionSpec: FullAttentionManager,
SlidingWindowSpec: SlidingWindowManager,
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
......
......@@ -187,6 +187,32 @@ class FullAttentionSpec(AttentionSpec):
)
@dataclass(frozen=True, kw_only=True)
class TQFullAttentionSpec(FullAttentionSpec):
"""FullAttentionSpec with TQ-aware page size.
Python equivalent of the C++ TQ4FullAttentionSpec. Overrides
real_page_size_bytes to use TQ slot bytes instead of the raw
head_size * dtype formula.
"""
tq_slot_size: int = 0
@property
def real_page_size_bytes(self) -> int:
if self.tq_slot_size > 0:
return self.block_size * self.num_kv_heads * self.tq_slot_size
return super().real_page_size_bytes
@classmethod
def merge(cls, specs: list[Self]) -> Self:
merged = super().merge(specs)
assert all(s.tq_slot_size == specs[0].tq_slot_size for s in specs), (
"All TQ layers in the same KV cache group must use the same tq_slot_size."
)
return replace(merged, tq_slot_size=specs[0].tq_slot_size)
@dataclass(frozen=True, kw_only=True)
class MLAAttentionSpec(FullAttentionSpec):
# TODO(Lucas/Chen): less hacky way to do this
......
......@@ -6,6 +6,7 @@ import numpy as np
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import get_dp_group
from vllm.logger import init_logger
......@@ -208,7 +209,7 @@ def coordinate_batch_across_dp(
]
"""
if parallel_config.data_parallel_size == 1:
if parallel_config.data_parallel_size == 1 or envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency":
# Early exit.
return False, None, cudagraph_mode
......
......@@ -189,6 +189,7 @@ from .utils import (
sanity_check_mm_encoder_outputs,
)
from vllm.v1.spec_decode.utils import DraftProbs
from vllm.utils.torch_utils import async_tensor_h2d
if TYPE_CHECKING:
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
......@@ -5117,9 +5118,6 @@ class GPUModelRunner(
inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded]
model_kwargs = self._init_model_kwargs()
else:
self.input_ids.gpu[:num_tokens_padded] = torch.randint(0, self.model_config.get_vocab_size(),
(num_tokens_padded,),
dtype=torch.int32)
input_ids = self.input_ids.gpu[:num_tokens_padded]
inputs_embeds = None
......@@ -5234,9 +5232,15 @@ class GPUModelRunner(
self.eplb_step(is_dummy=True, is_profile=is_profile)
logit_indices = np.cumsum(num_scheduled_tokens) - 1
logit_indices_device = torch.from_numpy(logit_indices).to(
self.device, non_blocking=True
)
# logit_indices_device = torch.from_numpy(logit_indices).to(
# self.device, non_blocking=True
# )
logit_indices = logit_indices.tolist()
logit_indices_device = async_tensor_h2d(
logit_indices,
dtype=torch.int32,
target_device=self.device,
pin_memory=True)
return hidden_states, hidden_states[logit_indices_device]
@torch.inference_mode()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment