Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c3270a92
Commit
c3270a92
authored
Apr 22, 2026
by
王敏
Browse files
Merge remote-tracking branch 'origin/v0.15.1-dev' into v0.15.1-dev
parents
feced2f1
0b7cc6cf
Changes
37
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
2290 additions
and
8 deletions
+2290
-8
vllm/model_executor/layers/quantization/turboquant/__init__.py
...model_executor/layers/quantization/turboquant/__init__.py
+14
-0
vllm/model_executor/layers/quantization/turboquant/centroids.py
...odel_executor/layers/quantization/turboquant/centroids.py
+86
-0
vllm/model_executor/layers/quantization/turboquant/config.py
vllm/model_executor/layers/quantization/turboquant/config.py
+186
-0
vllm/model_executor/layers/quantization/turboquant/quantizer.py
...odel_executor/layers/quantization/turboquant/quantizer.py
+40
-0
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+5
-0
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+6
-0
vllm/platforms/xpu.py
vllm/platforms/xpu.py
+6
-0
vllm/utils/torch_utils.py
vllm/utils/torch_utils.py
+4
-0
vllm/v1/attention/backends/mla/indexer.py
vllm/v1/attention/backends/mla/indexer.py
+1
-1
vllm/v1/attention/backends/registry.py
vllm/v1/attention/backends/registry.py
+1
-0
vllm/v1/attention/backends/turboquant_attn.py
vllm/v1/attention/backends/turboquant_attn.py
+812
-0
vllm/v1/attention/ops/triton_turboquant_decode.py
vllm/v1/attention/ops/triton_turboquant_decode.py
+624
-0
vllm/v1/attention/ops/triton_turboquant_store.py
vllm/v1/attention/ops/triton_turboquant_store.py
+460
-0
vllm/v1/core/single_type_kv_cache_manager.py
vllm/v1/core/single_type_kv_cache_manager.py
+7
-0
vllm/v1/kv_cache_interface.py
vllm/v1/kv_cache_interface.py
+26
-0
vllm/v1/worker/dp_utils.py
vllm/v1/worker/dp_utils.py
+2
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+10
-6
No files found.
vllm/model_executor/layers/quantization/turboquant/__init__.py
0 → 100644
View file @
c3270a92
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TurboQuant: Near-optimal KV-cache quantization for vLLM.
PolarQuant compression: random rotation + per-coordinate Lloyd-Max
scalar quantization for keys, uniform quantization for values.
Reference: "TurboQuant: Online Vector Quantization with Near-optimal
Distortion Rate" (ICLR 2026), Zandieh et al.
"""
from
vllm.model_executor.layers.quantization.turboquant.config
import
TurboQuantConfig
__all__
=
[
"TurboQuantConfig"
]
vllm/model_executor/layers/quantization/turboquant/centroids.py
0 → 100644
View file @
c3270a92
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Lloyd-Max optimal scalar quantizer for TurboQuant.
After rotating a d-dimensional unit vector by a random orthogonal matrix,
each coordinate approximately follows N(0, 1/d) for d >= 64.
We solve the Lloyd-Max conditions to find optimal centroids.
Based on: turboquant-pytorch/lloyd_max.py (Zandieh et al.)
"""
import
math
from
functools
import
lru_cache
import
torch
def _gaussian_pdf(x: float, sigma2: float) -> float:
    """Evaluate the zero-mean normal density with variance ``sigma2`` at ``x``.

    Args:
        x: Point at which the density is evaluated.
        sigma2: Variance of the distribution (not the standard deviation).

    Returns:
        The density value ``exp(-x^2 / (2*sigma2)) / sqrt(2*pi*sigma2)``.
    """
    normalizer = 1.0 / math.sqrt(2 * math.pi * sigma2)
    exponent = -x * x / (2 * sigma2)
    return normalizer * math.exp(exponent)
def _trapz(f, a: float, b: float, n: int = 200) -> float:
    """Trapezoidal numerical integration (replaces scipy.integrate.quad).

    Args:
        f: Callable taking a float and returning the integrand value.
        a: Lower integration limit.
        b: Upper integration limit.
        n: Number of equal-width sub-intervals; must be at least 1.

    Returns:
        The composite trapezoid estimate of the integral of ``f`` over
        ``[a, b]``.

    Raises:
        ValueError: If ``n`` is smaller than 1. Without this check ``n == 0``
            fails with an opaque division by zero and a negative ``n``
            silently yields a wrong value (the interior loop is empty and the
            step has the wrong sign).
    """
    if n < 1:
        raise ValueError(f"_trapz requires n >= 1 sub-intervals, got {n}")
    h = (b - a) / n
    # End points carry half weight; interior nodes carry full weight.
    result = 0.5 * (f(a) + f(b))
    # Explicit accumulation keeps the floating-point summation order fixed.
    for i in range(1, n):
        result += f(a + i * h)
    return result * h
def solve_lloyd_max(
    d: int,
    bits: int,
    max_iter: int = 200,
    tol: float = 1e-10,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run Lloyd-Max iterations for a scalar quantizer on N(0, 1/d).

    Args:
        d: Vector dimension; the target distribution has variance 1/d.
        bits: Quantizer bit-width; the codebook has 2^bits levels.
        max_iter: Upper bound on the number of Lloyd-Max iterations.
        tol: Stop once no centroid moves by ``tol`` or more in an iteration.

    Returns:
        centroids: float32 tensor holding the 2^bits centroids.
        boundaries: float32 tensor holding the 2^bits - 1 decision
            boundaries (midpoints between neighbouring centroids).
    """
    n_levels = 2**bits
    sigma2 = 1.0 / d
    sigma = math.sqrt(sigma2)
    lo, hi = -3.5 * sigma, 3.5 * sigma

    def pdf(x):
        return _gaussian_pdf(x, sigma2)

    def first_moment(x):
        return x * pdf(x)

    def midpoints(levels):
        return [(left + right) / 2.0 for left, right in zip(levels, levels[1:])]

    # Start from evenly spaced cell centres across +/- 3.5 sigma.
    centroids = [lo + (hi - lo) * (i + 0.5) / n_levels for i in range(n_levels)]

    for _ in range(max_iter):
        # The outermost cells are extended to +/- 10.5 sigma to cover the tails.
        edges = [lo * 3, *midpoints(centroids), hi * 3]
        updated = []
        for current, a, b in zip(centroids, edges, edges[1:]):
            mass = _trapz(pdf, a, b)
            if mass > 1e-15:
                # Conditional mean of the cell.
                updated.append(_trapz(first_moment, a, b) / mass)
            else:
                # Numerically empty cell: keep the previous centroid.
                updated.append(current)
        largest_shift = max(abs(new - old) for new, old in zip(updated, centroids))
        if largest_shift < tol:
            break
        centroids = updated

    return (
        torch.tensor(centroids, dtype=torch.float32),
        torch.tensor(midpoints(centroids), dtype=torch.float32),
    )
@lru_cache(maxsize=32)
def get_centroids(d: int, bits: int) -> torch.Tensor:
    """Return the Lloyd-Max centroids for ``(d, bits)``, memoised per argument pair.

    The decision boundaries computed by ``solve_lloyd_max`` are discarded.
    Because the result is cached, repeated calls with the same arguments
    return the same tensor object.
    """
    return solve_lloyd_max(d, bits)[0]
vllm/model_executor/layers/quantization/turboquant/config.py
0 → 100644
View file @
c3270a92
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TurboQuant configuration."""
import
math
from
dataclasses
import
dataclass
# Named TQ presets, keyed by the --kv-cache-dtype string.
# Table columns: (name, key_quant_bits, value_quant_bits, norm_correction)
#   key_quant_bits:   8 = FP8 keys, 3-4 = MSE (Lloyd-Max) quantized keys.
#   value_quant_bits: 3-4 = uniform quantized values.
TQ_PRESETS: dict[str, dict] = {
    name: {
        "key_quant_bits": key_bits,
        "value_quant_bits": value_bits,
        "norm_correction": norm_correction,
    }
    for name, key_bits, value_bits, norm_correction in (
        ("turboquant_k8v4", 8, 4, False),
        ("turboquant_4bit_nc", 4, 4, True),
        ("turboquant_k3v4_nc", 3, 4, True),
        ("turboquant_3bit_nc", 3, 3, True),
    )
}
@dataclass
class TurboQuantConfig:
    """Settings for TurboQuant KV-cache quantization.

    Keys use PolarQuant (WHT rotation + Lloyd-Max scalar quantization) and
    values use uniform quantization. QJL is intentionally omitted —
    community consensus (5+ independent groups) found it hurts attention
    quality by amplifying variance through softmax.

    Named presets (use via --kv-cache-dtype):
        turboquant_k8v4:    FP8 keys + 4-bit values, 2.6x, +1.17% PPL
        turboquant_4bit_nc: 4-bit MSE keys + 4-bit values + NC, 3.8x, +2.71%
        turboquant_k3v4_nc: 3-bit MSE keys + 4-bit values + NC, ~3.5x, +10.63%
        turboquant_3bit_nc: 3-bit MSE keys + 3-bit values + NC, 4.9x, +20.59%

    Args:
        head_dim: Attention head dimension (e.g. 64, 96, 128).
        key_quant_bits: Bits for key quantization. 8 = FP8 keys (no
            rotation/MSE). 3-4 = Lloyd-Max MSE quantized keys.
        value_quant_bits: Bits per value dimension for uniform quantization.
            3 = 8 levels, 4 = 16 levels (default).
        seed: Base seed for deterministic random matrix generation.
            Actual seed per layer = seed + layer_idx * 1337.
        norm_correction: Re-normalize centroid vectors to unit norm before
            inverse rotation during dequant. Fixes quantization-induced norm
            distortion, improving PPL by ~0.8% at 4-bit.
    """

    head_dim: int = 128
    key_quant_bits: int = 3  # 3-4 = MSE keys, 8 = FP8 keys
    value_quant_bits: int = 4  # 3-4 = uniform quantized values
    seed: int = 42
    norm_correction: bool = False

    @property
    def key_fp8(self) -> bool:
        """True when keys are kept as FP8 (``key_quant_bits == 8``)."""
        return self.key_quant_bits == 8

    @property
    def mse_bits(self) -> int:
        """Bit-width of the MSE quantizer (centroid count is 2^mse_bits).

        MSE key modes use ``key_quant_bits``. FP8 key mode falls back to
        ``value_quant_bits`` because centroids are still needed for
        continuation-prefill dequant and decode kernel params.
        """
        return self.value_quant_bits if self.key_fp8 else self.key_quant_bits

    @property
    def key_mse_bits(self) -> int:
        """MSE bits actually spent on keys; 0 in FP8 key mode."""
        return 0 if self.key_fp8 else self.key_quant_bits

    @property
    def centroid_bits(self) -> int:
        """Bits used for centroid generation (alias of ``mse_bits``)."""
        return self.mse_bits

    @property
    def n_centroids(self) -> int:
        """Number of centroids, i.e. ``2 ** mse_bits``."""
        return 2**self.mse_bits

    @property
    def key_packed_size(self) -> int:
        """Bytes occupied by one packed KEY vector.

        FP8 mode (key_quant_bits=8):
            head_dim bytes (1 byte per element, no overhead).
        TQ mode:
            - MSE indices: ceil(head_dim * key_mse_bits / 8) bytes
            - vec_norm: 2 bytes (float16)
            - res_norm: 2 bytes (float16)
        """
        if self.key_fp8:
            return self.head_dim  # 1 byte per element
        index_bytes = math.ceil(self.head_dim * self.key_mse_bits / 8)
        return index_bytes + 4  # +2x float16 norms

    @property
    def effective_value_quant_bits(self) -> int:
        """Bits actually used to store each value element."""
        return self.value_quant_bits

    @property
    def value_packed_size(self) -> int:
        """Bytes occupied by one packed VALUE vector.

        Uniform quantization: ceil(head_dim * bits / 8) data bytes plus
        4 bytes of metadata (fp16 scale + fp16 zero).
        """
        payload_bytes = math.ceil(self.head_dim * self.value_quant_bits / 8)
        return payload_bytes + 4  # +2 scale(fp16) +2 zero(fp16)

    @property
    def slot_size(self) -> int:
        """Packed bytes per head per position, key and value combined.

        Layout: [key_packed | value_packed]
        """
        return self.key_packed_size + self.value_packed_size

    @property
    def slot_size_aligned(self) -> int:
        """``slot_size`` rounded up to the next even number.

        An even size keeps effective_head_size = slot_size_aligned // 2
        integral.
        """
        size = self.slot_size
        return size + 1 if size % 2 else size

    @staticmethod
    def get_boundary_skip_layers(num_layers: int, n: int = 2) -> list[str]:
        """List the layers exempt from TQ compression (boundary protection).

        Returns the first ``n`` and last ``n`` layer indices as strings,
        sorted ascending, in the form expected by kv_cache_dtype_skip_layers.
        ``n`` is capped at ``num_layers // 2`` so that at most half of the
        layers are skipped from each end.
        """
        if n <= 0 or num_layers <= 0:
            return []
        n = min(n, num_layers // 2)  # don't skip more than half
        # Set union removes any index shared by the head and tail ranges.
        skipped = set(range(n)) | set(range(num_layers - n, num_layers))
        return [str(index) for index in sorted(skipped)]

    @staticmethod
    def from_cache_dtype(cache_dtype: str, head_dim: int) -> "TurboQuantConfig":
        """Build a config from a named preset.

        Valid presets: turboquant_k8v4, turboquant_4bit_nc, etc.

        Raises:
            ValueError: If ``cache_dtype`` is not a key of ``TQ_PRESETS``.
        """
        preset = TQ_PRESETS.get(cache_dtype)
        if preset is None:
            valid = ", ".join(TQ_PRESETS.keys())
            raise ValueError(
                f"Unknown TurboQuant cache dtype: {cache_dtype!r}. "
                f"Valid presets: {valid}"
            )
        return TurboQuantConfig(
            head_dim=head_dim,
            key_quant_bits=preset["key_quant_bits"],
            value_quant_bits=preset["value_quant_bits"],
            norm_correction=preset["norm_correction"],
        )
vllm/model_executor/layers/quantization/turboquant/quantizer.py
0 → 100644
View file @
c3270a92
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TurboQuant quantizer utilities.
Serving path uses generate_wht_signs() for WHT rotation sign buffers.
generate_rotation_matrix() is retained for standalone benchmarks only.
Triton kernels handle all quantization, packing, and dequantization on GPU.
"""
import
torch
def generate_rotation_matrix(
    d: int, seed: int, device: torch.device = torch.device("cpu")
) -> torch.Tensor:
    """Build a seeded random orthogonal matrix from the QR factor of a Gaussian.

    The Gaussian matrix is always sampled and factorised on the CPU so the
    result depends only on ``seed``; it is moved to ``device`` at the end.

    Args:
        d: Matrix dimension (the result is ``d x d``).
        seed: Seed for the CPU random generator.
        device: Device the returned matrix is moved to.

    Returns:
        A ``(d, d)`` float32 orthogonal matrix on ``device``.
    """
    rng = torch.Generator(device="cpu")
    rng.manual_seed(seed)
    gaussian = torch.randn(d, d, generator=rng, device="cpu", dtype=torch.float32)
    q_factor, r_factor = torch.linalg.qr(gaussian)
    # QR is unique only up to column signs; pin them to the sign of diag(R),
    # treating an exact zero as positive.
    column_signs = torch.sign(torch.diag(r_factor))
    column_signs = torch.where(
        column_signs == 0, torch.ones_like(column_signs), column_signs
    )
    oriented = q_factor * column_signs[None, :]
    return oriented.to(device)
def generate_wht_signs(
    d: int, seed: int, device: torch.device = torch.device("cpu")
) -> torch.Tensor:
    """Draw a deterministic vector of random +1/-1 signs for WHT rotation.

    The signs randomise the Walsh-Hadamard Transform per layer. The seed is
    derived the same way as for the QR rotation (per-layer via
    seed + layer_idx * stride). Sampling always happens on the CPU; the
    result is moved to ``device`` afterwards.

    Args:
        d: Number of signs to draw.
        seed: Seed for the CPU random generator.
        device: Device the returned tensor is moved to.

    Returns:
        A float32 tensor of shape ``(d,)`` whose entries are -1.0 or 1.0.
    """
    rng = torch.Generator(device="cpu")
    rng.manual_seed(seed)
    coin_flips = torch.randint(0, 2, (d,), generator=rng, device="cpu")
    # Map {0, 1} -> {-1.0, +1.0}.
    return (coin_flips.float() * 2 - 1).to(device)
vllm/platforms/cuda.py
View file @
c3270a92
...
...
@@ -280,6 +280,11 @@ class CudaPlatformBase(Platform):
valid_backends_priorities
=
[]
invalid_reasons
=
{}
# TurboQuant KV cache: route directly to TQ backend
kv_cache_dtype
=
attn_selector_config
.
kv_cache_dtype
if
kv_cache_dtype
is
not
None
and
kv_cache_dtype
.
startswith
(
"turboquant_"
):
return
[(
AttentionBackendEnum
.
TURBOQUANT
,
0
)],
{}
backend_priorities
=
_get_backend_priorities
(
attn_selector_config
.
use_mla
,
device_capability
)
...
...
vllm/platforms/rocm.py
View file @
c3270a92
...
...
@@ -264,6 +264,12 @@ class RocmPlatform(Platform):
block_size
=
attn_selector_config
.
block_size
kv_cache_dtype
=
attn_selector_config
.
kv_cache_dtype
# TurboQuant KV cache: route directly to TQ backend
kv_cache_dtype
=
attn_selector_config
.
kv_cache_dtype
if
kv_cache_dtype
is
not
None
and
kv_cache_dtype
.
startswith
(
"turboquant_"
):
logger
.
info_once
(
"Using TurboQuant attention backend."
)
return
AttentionBackendEnum
.
TURBOQUANT
.
get_path
()
if
attn_selector_config
.
use_sparse
:
# if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
# raise ValueError(
...
...
vllm/platforms/xpu.py
View file @
c3270a92
...
...
@@ -52,6 +52,12 @@ class XPUPlatform(Platform):
"only NHD layout is supported by XPU attention kernels."
)
# TurboQuant KV cache: route directly to TQ backend
kv_cache_dtype
=
attn_selector_config
.
kv_cache_dtype
if
kv_cache_dtype
is
not
None
and
kv_cache_dtype
.
startswith
(
"turboquant_"
):
logger
.
info_once
(
"Using TurboQuant attention backend."
)
return
AttentionBackendEnum
.
TURBOQUANT
.
get_path
()
dtype
=
attn_selector_config
.
dtype
if
attn_selector_config
.
use_sparse
:
raise
NotImplementedError
(
"Sparse Attention is not supported on XPU."
)
...
...
vllm/utils/torch_utils.py
View file @
c3270a92
...
...
@@ -42,6 +42,10 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"int8"
:
torch
.
int8
,
"fp8_inc"
:
torch
.
float8_e4m3fn
,
"fp8_ds_mla"
:
torch
.
uint8
,
"turboquant_k8v4"
:
torch
.
uint8
,
"turboquant_4bit_nc"
:
torch
.
uint8
,
"turboquant_k3v4_nc"
:
torch
.
uint8
,
"turboquant_3bit_nc"
:
torch
.
uint8
,
}
TORCH_DTYPE_TO_NUMPY_DTYPE
=
{
...
...
vllm/v1/attention/backends/mla/indexer.py
View file @
c3270a92
...
...
@@ -202,7 +202,7 @@ def get_max_prefill_buffer_size(vllm_config: VllmConfig):
class
DeepseekV32IndexerMetadataBuilder
(
AttentionMetadataBuilder
):
_cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
(
AttentionCGSupport
.
UNIFORM_
SINGLE_TOKEN_DECODE
AttentionCGSupport
.
UNIFORM_
BATCH
)
reorder_batch_threshold
:
int
=
1
...
...
vllm/v1/attention/backends/registry.py
View file @
c3270a92
...
...
@@ -78,6 +78,7 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
"RocmAiterUnifiedAttentionBackend"
)
CPU_ATTN
=
"vllm.v1.attention.backends.cpu_attn.CPUAttentionBackend"
TURBOQUANT
=
"vllm.v1.attention.backends.turboquant_attn.TurboQuantAttentionBackend"
# Placeholder for third-party/custom backends - must be registered before use
# set to None to avoid alias with other backend, whose value is an empty string
CUSTOM
=
None
...
...
vllm/v1/attention/backends/turboquant_attn.py
0 → 100644
View file @
c3270a92
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TurboQuant attention backend for vLLM.
Prefill: Standard scaled dot-product attention on uncompressed K/V,
then quantize K and store K+V into combined cache slot.
Decode: Compute TQ attention scores from compressed cache,
unpack FP16 values, softmax + weighted sum.
Cache layout (no leading 2 dimension):
    (num_blocks, block_size, num_kv_heads, slot_size_aligned)
    where slot_size = key_packed_size + value_packed_size, rounded up to even
Per-head per-position slot layout:
    [key_packed (kps bytes) | value_packed (ceil(D * value_bits / 8) + 4 bytes) | padding]
For turboquant_k3v4_nc head_dim=256: [100 bytes key | 132 bytes value] = 232
"""
import
functools
import
math
import
os
from
dataclasses
import
dataclass
from
typing
import
ClassVar
import
torch
import
torch.nn.functional
as
F
from
vllm.config
import
get_current_vllm_config
from
vllm.triton_utils
import
triton
from
vllm.utils.torch_utils
import
aux_stream
from
vllm.v1.attention.ops.triton_turboquant_decode
import
(
_tq_full_dequant_kv
,
_use_fp8_e4b15
,
triton_turboquant_decode_attention
,
)
from
vllm.v1.attention.ops.triton_turboquant_store
import
triton_turboquant_store
# CUDA stream overlap: disabled by default — degrades TTFT under concurrent
# load (489ms vs 338ms). Enable via TQ_STREAM_OVERLAP=1 for experimentation.
_USE_STREAM_OVERLAP
=
os
.
environ
.
get
(
"TQ_STREAM_OVERLAP"
,
"0"
)
==
"1"
# Continuation prefill: for small continuation chunks (q_len ≤ threshold),
# use the TQ decode kernel directly instead of full-dequant + flash_attn.
# do_kv_cache_update already stored all tokens to TQ cache, so the decode
# kernel can read them efficiently. This avoids O(cached_len) dequant work
# per continuation, eliminating the O(N²/chunk_size) collapse at long context.
_CONTINUATION_DECODE_THRESHOLD
=
128
from
vllm.config.cache
import
CacheDType
from
vllm.v1.attention.backends.fa_utils
import
(
is_flash_attn_varlen_func_available
,
)
_HAS_FLASH_ATTN
=
is_flash_attn_varlen_func_available
()
if
_HAS_FLASH_ATTN
:
from
vllm.v1.attention.backends.fa_utils
import
flash_attn_varlen_func
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
AttentionCGSupport
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionType
,
CommonAttentionMetadata
,
MultipleOf
,
)
from
vllm.v1.attention.backends.utils
import
split_decodes_and_prefills
@functools.cache
def _build_hadamard(d: int, device_str: str) -> torch.Tensor:
    """Orthonormal Hadamard matrix (Sylvester construction), built on CPU.

    Precomputed D×D matrix enables matmul-based WHT — single cuBLAS GEMM
    instead of log2(D) butterfly kernel launches. 64KB for D=128.

    Args:
        d: Matrix dimension. Must be a positive power of two, because the
            Sylvester construction doubles the size at every step.
        device_str: Device string (e.g. ``"cpu"``) the result is moved to.
            A string is used so the argument is hashable for the cache.

    Returns:
        A ``(d, d)`` float32 matrix ``H / sqrt(d)`` on the requested device.
        The same tensor object is returned for repeated identical arguments.

    Raises:
        ValueError: If ``d`` is not a positive power of two. Without this
            check the doubling loop overshoots and returns a matrix of the
            next power-of-two size scaled by ``1/sqrt(d)``, which has the
            wrong shape and is not orthonormal.
    """
    if d < 1 or d & (d - 1):
        raise ValueError(
            f"Hadamard dimension must be a positive power of two, got {d}"
        )
    H = torch.tensor([[1.0]])
    # Sylvester step: H_{2n} = [[H_n, H_n], [H_n, -H_n]].
    while H.shape[0] < d:
        H = torch.cat([torch.cat([H, H], 1), torch.cat([H, -H], 1)], 0)
    return (H / math.sqrt(d)).to(torch.device(device_str))
class TurboQuantAttentionBackend(AttentionBackend):
    """Backend descriptor for attention over a TurboQuant-compressed KV cache."""

    accept_output_buffer: bool = True
    forward_includes_kv_cache_update: bool = False
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "turboquant_k8v4",
        "turboquant_4bit_nc",
        "turboquant_k3v4_nc",
        "turboquant_3bit_nc",
    ]

    @staticmethod
    def get_name() -> str:
        """Return the backend's registry name."""
        return "TURBOQUANT"

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        """Return the KV-cache block sizes this backend accepts."""
        return [16, 32, 64, 128]

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        """Only decoder attention is supported."""
        return attn_type == AttentionType.DECODER

    @classmethod
    def supports_per_head_quant_scales(cls) -> bool:
        """Per-head quantization scales are not supported."""
        return False

    @staticmethod
    def get_impl_cls() -> type["TurboQuantAttentionImpl"]:
        """Return the attention implementation class."""
        return TurboQuantAttentionImpl

    @staticmethod
    def get_builder_cls() -> type["TurboQuantMetadataBuilder"]:
        """Return the metadata builder class."""
        return TurboQuantMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "turboquant_4bit_nc",
    ) -> tuple[int, ...]:
        """Shape of the combined K+V cache — there is no leading 2 dimension.

        Standard attention backends use (2, num_blocks, block_size,
        num_kv_heads, head_dim) with a leading 2 to separate K and V.
        TurboQuant packs K+V into a single interleaved slot per head per
        position, so the cache is:

            (num_blocks, block_size, num_kv_heads, slot_size_aligned)

        Each slot = [key_packed | value_packed | padding].

        This is safe because TQ has its own get_kv_cache_shape override and
        never shares cache tensors with other backends. Layers that fall
        back to native dtype via kv_cache_dtype_skip_layers get their own
        standard-shaped cache allocation.

        head_size is the model's real head_dim. slot_size_aligned is derived
        from the TQ preset named by ``cache_dtype_str`` so the allocation is
        correct for every head dim.
        """
        from vllm.model_executor.layers.quantization.turboquant.config import (
            TurboQuantConfig,
        )

        slot_bytes = TurboQuantConfig.from_cache_dtype(
            cache_dtype_str, head_size
        ).slot_size_aligned
        return (num_blocks, block_size, num_kv_heads, slot_bytes)

    @classmethod
    def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:
        """Accept any KV-cache dtype string starting with ``turboquant_``."""
        return kv_cache_dtype is not None and kv_cache_dtype.startswith(
            "turboquant_"
        )

    @classmethod
    def supports_head_size(cls, head_size: int) -> bool:
        """Accept any positive head size.

        head_size from spec is effective_head_size (padded_slot//2), not the
        model's actual head_dim, so no fixed list of sizes is enforced.
        """
        return head_size > 0
@dataclass
class TurboQuantMetadata(AttentionMetadata):
    """Metadata for TurboQuant attention.

    Holds the per-batch tensors and scalar summaries that
    ``TurboQuantMetadataBuilder.build`` copies from the common attention
    metadata, plus the decode/prefill split it computes.
    """

    seq_lens: torch.Tensor  # (num_reqs,) — total context length per request
    slot_mapping: torch.Tensor  # (num_tokens,) — cache slot for each token
    block_table: torch.Tensor  # (num_reqs, max_num_blocks)
    query_start_loc: torch.Tensor  # (num_reqs + 1,) — cu_seqlens for queries
    num_actual_tokens: int = 0  # actual tokens (excluding padding)
    max_query_len: int = 0  # longest query in batch
    max_seq_len: int = 0  # longest context in batch
    # The builder sets this to ``max_query_len > 1``.
    is_prefill: bool = False
    num_decodes: int = 0  # number of decode requests (first in batch)
    num_decode_tokens: int = 0  # tokens from decode requests
class TurboQuantMetadataBuilder(AttentionMetadataBuilder[TurboQuantMetadata]):
    """Assembles ``TurboQuantMetadata`` from the common attention metadata."""

    _cudagraph_support: ClassVar[AttentionCGSupport] = (
        AttentionCGSupport.UNIFORM_BATCH
    )

    def __init__(self, kv_cache_spec, layer_names, vllm_config, device):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        # Threshold of 1: only single-token queries are classed as decodes.
        self._init_reorder_batch_threshold(1, supports_spec_as_decode=False)

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ) -> TurboQuantMetadata:
        """Build metadata for CUDA graph capture with every seq_len set to 1."""
        capture_metadata = self.build(0, common_attn_metadata)
        # Set seq_lens to 1 so CUDA graph capture is fast
        # (real seq_lens are filled at replay time).
        capture_metadata.seq_lens.fill_(1)
        return capture_metadata

    def build(self, common_prefix_len, common_attn_metadata, fast_build=False):
        """Translate ``common_attn_metadata`` into a ``TurboQuantMetadata``.

        ``common_prefix_len`` and ``fast_build`` are accepted for interface
        compatibility and are not used here.
        """
        source = common_attn_metadata
        # With reorder_batch_threshold=1, the model runner guarantees
        # decodes come first in the batch. split_decodes_and_prefills
        # finds the boundary (operates on CPU tensors — no GPU sync).
        num_decodes, _, num_decode_tokens, _ = split_decodes_and_prefills(
            source, decode_threshold=self.reorder_batch_threshold
        )
        longest_query = source.max_query_len
        return TurboQuantMetadata(
            seq_lens=source.seq_lens,
            slot_mapping=source.slot_mapping,
            block_table=source.block_table_tensor,
            query_start_loc=source.query_start_loc,
            num_actual_tokens=source.num_actual_tokens,
            max_query_len=longest_query,
            max_seq_len=source.max_seq_len,
            is_prefill=(longest_query > 1),
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
        )
class
TurboQuantAttentionImpl
(
AttentionImpl
[
"TurboQuantMetadata"
]):
"""TurboQuant attention implementation.
Vectorized PyTorch: batch quantize/store, vectorized bit-unpack
decode with einsum scores and value gather.
"""
supports_quant_query_input
:
bool
=
False
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int | None = None,
    alibi_slopes: list[float] | None = None,
    sliding_window: int | None = None,
    kv_cache_dtype: str = "auto",
    logits_soft_cap: float | None = None,
    attn_type: str = AttentionType.DECODER,
    kv_sharing_target_layer_name: str | None = None,
    **kwargs,
):
    """Set up head geometry and TurboQuant layout constants.

    Args:
        num_heads: Number of query heads.
        head_size: Per-head dimension.
        scale: Softmax scale applied to attention scores.
        num_kv_heads: Number of K/V heads; ``None`` means one K/V head
            per query head (``num_heads``).
        kv_cache_dtype: Cache dtype string from which the TurboQuant
            config is derived.

    ``alibi_slopes``, ``sliding_window``, ``logits_soft_cap``,
    ``attn_type``, ``kv_sharing_target_layer_name`` and ``**kwargs`` are
    accepted for interface compatibility but are not read here.
    """
    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = scale
    # The signature allows num_kv_heads=None; without this fallback the
    # integer division below would raise TypeError.
    if num_kv_heads is None:
        num_kv_heads = num_heads
    self.num_kv_heads = num_kv_heads
    self.num_kv_groups = num_heads // num_kv_heads
    self.kv_cache_dtype = kv_cache_dtype

    from vllm.model_executor.layers.quantization.turboquant.config import (
        TurboQuantConfig,
    )

    self.tq_config = TurboQuantConfig.from_cache_dtype(kv_cache_dtype, head_size)

    # Pre-compute kernel constants from config (avoid repeated arithmetic)
    cfg = self.tq_config
    self._mse_bytes = (
        math.ceil(head_size * cfg.key_mse_bits / 8)
        if not cfg.key_fp8
        else head_size
    )
    self._val_data_bytes = math.ceil(
        head_size * cfg.effective_value_quant_bits / 8
    )
    self._n_centroids = cfg.n_centroids if not cfg.key_fp8 else 1

    # Fixed NUM_KV_SPLITS (grid dims must be constant for cudagraph,
    # and benchmarks show no regression vs dynamic in eager mode).
    vllm_config = get_current_vllm_config()
    self.max_num_kv_splits = (
        vllm_config.attention_config.tq_max_kv_splits_for_cuda_graph
    )
def _ensure_on_device(self, layer, device):
    """Move the layer's TQ buffers to ``device`` and build derived tensors once.

    The sign/centroid buffers are moved whenever the sign buffer lives on
    a different device.  The rotation matrices, quantization midpoints
    and decode-buffer placeholders are created only the first time, as
    tracked by the ``_tq_cached`` attribute on ``layer``.
    """
    if layer._tq_signs.device != device:
        layer._tq_signs = layer._tq_signs.to(device)
        layer._tq_centroids = layer._tq_centroids.to(device)

    if hasattr(layer, "_tq_cached"):
        return

    head_dim = layer._tq_signs.shape[0]
    sign_column = layer._tq_signs.float().unsqueeze(1)
    # WHT rotation: orthonormal + self-inverse, enabling future
    # in-kernel butterfly fusion and trivial inverse for continuation.
    hadamard = _build_hadamard(head_dim, str(device))
    layer._tq_PiT = (sign_column * hadamard).contiguous()
    layer._tq_Pi = layer._tq_PiT.T.contiguous()

    # Midpoints between neighbouring sorted centroids serve as the
    # thresholds for threshold-based quantization.
    ordered_centroids, _ = layer._tq_centroids.float().sort()
    layer._tq_midpoints = (ordered_centroids[:-1] + ordered_centroids[1:]) / 2

    # Decode buffers are lazily allocated on first decode call.
    # With fixed NUM_KV_SPLITS (cudagraph mode), the first warmup
    # allocates them and subsequent captures reuse via buf_holder.
    for buffer_name in ("_tq_mid_o_buf", "_tq_output_buf", "_tq_lse_buf"):
        setattr(layer, buffer_name, None)
    layer._tq_cached = True
def do_kv_cache_update(
    self,
    layer: torch.nn.Module,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """Store compressed K/V into the combined TQ cache.

    Called as a separate custom op (unified_kv_cache_update) BEFORE
    the attention forward, matching FlashAttention's split pattern.
    slot_mapping is already sliced to num_actual_tokens by the caller.

    With stream overlap enabled (and outside CUDA graph capture), the
    store runs on a secondary CUDA stream so it can overlap with
    subsequent work on the main stream.

    Args:
        layer: Attention layer holding the TQ buffers (``_tq_*``).
        key: Raw keys; the first ``N`` rows are stored.
        value: Raw values; the first ``N`` rows are stored.
        kv_cache: Combined compressed K/V cache to write into.
        slot_mapping: Destination slot per token; its length is ``N``.
    """
    N = slot_mapping.shape[0]
    if N <= 0:
        return

    device = key.device
    self._ensure_on_device(layer, device)

    k = key[:N].view(N, self.num_kv_heads, self.head_size)
    v = value[:N].view(N, self.num_kv_heads, self.head_size)

    # Use stream overlap only when not capturing CUDA graphs
    stream = aux_stream() if _USE_STREAM_OVERLAP else None
    use_overlap = (
        stream is not None and not torch.cuda.is_current_stream_capturing()
    )

    if not use_overlap:
        self._store_kv(k, v, kv_cache, slot_mapping, layer._tq_centroids, layer)
        return

    main_stream = torch.cuda.current_stream(device)
    # Wait for any previous store to finish before starting new one
    main_stream.wait_stream(stream)
    # key/value were produced by kernels queued on the main stream; the
    # side stream must not start reading them before those kernels finish.
    stream.wait_stream(main_stream)
    # Launch store on secondary stream
    with torch.cuda.stream(stream):
        self._store_kv(k, v, kv_cache, slot_mapping, layer._tq_centroids, layer)
    # Tell the caching allocator these tensors are in use on the side
    # stream so their memory is not recycled while the store still reads it.
    for tensor in (key, value, slot_mapping):
        tensor.record_stream(stream)
def forward(
    self,
    layer: AttentionLayer,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: "TurboQuantMetadata",
    output: torch.Tensor | None = None,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    """Run TurboQuant attention for one layer and write into ``output``.

    The KV cache is expected to have been updated already by
    ``do_kv_cache_update``.  Pure-decode batches use the decode kernel,
    pure-prefill batches use ``_prefill_attention``, and mixed batches
    are split at ``num_decode_tokens`` (decodes first).

    Args:
        layer: Attention layer holding the TQ buffers.
        query: Query tensor; the first ``num_actual_tokens`` rows are used.
        key: Raw keys of the current batch (used by the prefill paths).
        value: Raw values of the current batch (used by the prefill paths).
        kv_cache: Combined compressed K/V cache.
        attn_metadata: Batch metadata; ``None`` yields a zero-filled output.
        output: Optional destination buffer, 2D ``(N, Hq*D)`` or 3D
            ``(N, Hq, D)``; allocated as 2D zeros when omitted.
        output_scale: Accepted for interface compatibility; not read.
        output_block_scale: Accepted for interface compatibility; not read.

    Returns:
        The ``output`` tensor.
    """
    num_tokens = query.shape[0]

    if output is None:
        output = torch.zeros(
            num_tokens,
            self.num_heads * self.head_size,
            dtype=query.dtype,
            device=query.device,
        )

    if attn_metadata is None:
        return output.fill_(0)

    # Slice to actual tokens
    N = attn_metadata.num_actual_tokens
    if N <= 0:
        return output.fill_(0)

    q = query[:N].view(N, self.num_heads, self.head_size)

    # Get TQ buffers, ensure on device (one-time migration)
    device = q.device
    self._ensure_on_device(layer, device)
    Pi = layer._tq_Pi
    PiT = layer._tq_PiT
    centroids = layer._tq_centroids

    # With reorder_batch_threshold=1, decodes come first in the batch.
    # num_decodes/num_decode_tokens from metadata give the split point.
    num_decodes = attn_metadata.num_decodes
    num_decode_tokens = attn_metadata.num_decode_tokens

    # Ensure any async store has completed before this call reads the
    # cache.  Only the flash-attn first-chunk fast path in
    # _prefill_attention works purely on raw K/V; pure decode, mixed
    # batches and continuation prefills all read cache entries that the
    # side-stream store may still be writing, so they must wait.
    if _USE_STREAM_OVERLAP and not torch.cuda.is_current_stream_capturing():
        cache_free_prefill = (
            attn_metadata.is_prefill
            and num_decodes == 0
            and _HAS_FLASH_ATTN
            and attn_metadata.max_query_len == attn_metadata.max_seq_len
        )
        if not cache_free_prefill:
            stream = aux_stream()
            if stream is not None:
                torch.cuda.current_stream(device).wait_stream(stream)

    # Compute attention (KV cache was already updated by do_kv_cache_update)
    if not attn_metadata.is_prefill:
        # Pure decode batch — fast path
        attn_out = self._decode_attention(
            q, kv_cache, attn_metadata, Pi, centroids, PiT, layer
        )
    elif num_decodes == 0:
        # Pure prefill batch
        k = key[:N].view(N, self.num_kv_heads, self.head_size)
        v = value[:N].view(N, self.num_kv_heads, self.head_size)
        attn_out = self._prefill_attention(
            q, k, v, kv_cache, attn_metadata, Pi, centroids, PiT
        )
    else:
        # Mixed batch: decodes first (guaranteed by reorder_batch).
        attn_out = torch.zeros(
            N, self.num_heads, self.head_size, device=device, dtype=q.dtype
        )

        # --- Decode portion (first num_decodes requests) ---
        # Use full-batch max_seq_len as safe upper bound (no GPU sync).
        decode_meta = TurboQuantMetadata(
            seq_lens=attn_metadata.seq_lens[:num_decodes],
            slot_mapping=attn_metadata.slot_mapping[:num_decode_tokens],
            block_table=attn_metadata.block_table[:num_decodes],
            query_start_loc=attn_metadata.query_start_loc[: num_decodes + 1],
            num_actual_tokens=num_decode_tokens,
            max_query_len=1,
            max_seq_len=attn_metadata.max_seq_len,
            is_prefill=False,
        )
        attn_out[:num_decode_tokens] = self._decode_attention(
            q[:num_decode_tokens],
            kv_cache,
            decode_meta,
            Pi,
            centroids,
            PiT,
            layer,
        )

        # --- Prefill portion (remaining requests) ---
        # CRITICAL: use prefill-specific max_seq_len so flash_attn's
        # fast path (max_query_len == max_seq_len) triggers for
        # first-chunk prefills. Using full-batch max_seq_len breaks
        # this because decode requests inflate max_seq_len.
        prefill_seq_lens = attn_metadata.seq_lens[num_decodes:]
        # NOTE: .item() forces a host sync when seq_lens lives on the GPU.
        prefill_max_seq = prefill_seq_lens.max().item()
        # Rebase cumulative query offsets so the prefill slice starts at 0.
        prefill_qsl = (
            attn_metadata.query_start_loc[num_decodes:] - num_decode_tokens
        )
        prefill_meta = TurboQuantMetadata(
            seq_lens=prefill_seq_lens,
            slot_mapping=attn_metadata.slot_mapping[num_decode_tokens:N],
            block_table=attn_metadata.block_table[num_decodes:],
            query_start_loc=prefill_qsl,
            num_actual_tokens=N - num_decode_tokens,
            max_query_len=attn_metadata.max_query_len,
            max_seq_len=prefill_max_seq,
            is_prefill=True,
        )
        k = key[:N].view(N, self.num_kv_heads, self.head_size)
        v = value[:N].view(N, self.num_kv_heads, self.head_size)
        attn_out[num_decode_tokens:] = self._prefill_attention(
            q[num_decode_tokens:],
            k[num_decode_tokens:],
            v[num_decode_tokens:],
            kv_cache,
            prefill_meta,
            Pi,
            centroids,
            PiT,
        )

    # Write into output buffer: attn_out is (N, Hq, D)
    # output may be 2D (N, Hq*D) or 3D (N, Hq, D)
    if output.ndim == 3:
        output[:N] = attn_out.to(output.dtype)
    else:
        output[:N] = attn_out.reshape(N, -1).to(output.dtype)
    return output
# ------------------------------------------------------------------ #
# Store K/V into combined cache (vectorized) #
# ------------------------------------------------------------------ #
def _store_kv(
    self,
    key: torch.Tensor,  # (N, Hk, D)
    value: torch.Tensor,  # (N, Hk, D)
    kv_cache: torch.Tensor,  # (num_blocks, block_size, Hk, slot_size)
    slot_mapping: torch.Tensor,
    centroids: torch.Tensor,
    layer: "AttentionLayer",
):
    """Quantize K/V and write them to the cache with the fused Triton store.

    The rotation matrix and quantization midpoints are read from
    ``layer``; bit widths and the FP8 flag come from ``self.tq_config``.
    """
    cfg = self.tq_config
    layout_options = dict(
        mse_bits=cfg.key_mse_bits,
        key_packed_size=cfg.key_packed_size,
        value_quant_bits=cfg.effective_value_quant_bits,
        key_fp8=cfg.key_fp8,
    )
    triton_turboquant_store(
        key,
        value,
        kv_cache,
        slot_mapping,
        layer._tq_PiT,
        centroids,
        layer._tq_midpoints,
        **layout_options,
    )
# ------------------------------------------------------------------ #
# Prefill: SDPA on raw Q/K/V with causal mask #
# ------------------------------------------------------------------ #
def _prefill_attention(
    self,
    query: torch.Tensor,  # (N, Hq, D)
    key: torch.Tensor,  # (N, Hk, D)
    value: torch.Tensor,  # (N, Hk, D)
    kv_cache: torch.Tensor,  # (num_blocks, block_size, Hk, slot_size)
    attn_metadata: TurboQuantMetadata,
    Pi: torch.Tensor,
    centroids: torch.Tensor,
    PiT: torch.Tensor | None = None,
) -> torch.Tensor:
    """Causal attention for prefill requests.

    First-chunk requests attend only to the raw K/V of the current
    batch.  Continuation chunks (``seq_len > q_len``) additionally
    attend to previously cached tokens, either through the decode
    kernel (short chunks) or ``_continuation_prefill`` (long chunks).

    Returns:
        Tensor of shape ``(N, Hq, D)`` in ``query.dtype``.
    """
    N, Hq, D = query.shape

    # Fast path: use flash_attn for first-chunk prefills (all K/V in batch).
    # Both operands are Python ints — no GPU sync.
    # NOTE(review): equality of the two maxima does not by itself prove
    # that every request is a first chunk (a long first-chunk request can
    # hide a shorter continuation) — confirm the scheduler guarantees this.
    if _HAS_FLASH_ATTN and attn_metadata.max_query_len == attn_metadata.max_seq_len:
        return flash_attn_varlen_func(
            q=query,
            k=key,
            v=value,
            cu_seqlens_q=attn_metadata.query_start_loc,
            cu_seqlens_k=attn_metadata.query_start_loc,
            max_seqlen_q=attn_metadata.max_query_len,
            max_seqlen_k=attn_metadata.max_query_len,
            softmax_scale=self.scale,
            causal=True,
        )

    # Continuation or no flash_attn: per-request attention.
    # For continuation chunks (seq_len > q_len), we must attend to
    # previously cached K/V from the TQ cache, not just the current
    # chunk's raw K/V.
    Hk = key.shape[1]
    use_gqa = Hk < Hq
    output = torch.zeros(N, Hq, D, device=query.device, dtype=query.dtype)

    # Convert to Python lists once (single CPU-GPU sync) instead of
    # per-request .item() calls that each force a sync.
    qsl = attn_metadata.query_start_loc.tolist()
    seq_lens_list = attn_metadata.seq_lens.tolist()

    for i, (q_start, q_end) in enumerate(zip(qsl, qsl[1:])):
        q_len = q_end - q_start
        if q_len <= 0:
            continue

        seq_len = seq_lens_list[i]
        q_seq = query[q_start:q_end]  # (q_len, Hq, D)
        k_seq = key[q_start:q_end]  # (q_len, Hk, D)
        v_seq = value[q_start:q_end]  # (q_len, Hk, D)

        if q_len == seq_len:
            # First-chunk prefill: all K/V are in the current batch.
            if _HAS_FLASH_ATTN:
                cu = torch.tensor(
                    [0, q_len], device=query.device, dtype=torch.int32
                )
                out = flash_attn_varlen_func(
                    q=q_seq,
                    k=k_seq,
                    v=v_seq,
                    cu_seqlens_q=cu,
                    cu_seqlens_k=cu,
                    max_seqlen_q=q_len,
                    max_seqlen_k=q_len,
                    softmax_scale=self.scale,
                    causal=True,
                )
            else:
                # SDPA expects heads ahead of the sequence dimension.
                q_t = q_seq.transpose(0, 1).contiguous()
                k_t = k_seq.transpose(0, 1).contiguous()
                v_t = v_seq.transpose(0, 1).contiguous()
                out = F.scaled_dot_product_attention(
                    q_t,
                    k_t,
                    v_t,
                    is_causal=True,
                    scale=self.scale,
                    enable_gqa=use_gqa,
                ).transpose(0, 1)
        else:
            # Continuation chunk: tokens already stored to TQ cache
            # by do_kv_cache_update. Use decode kernel directly to
            # avoid O(cached_len) full-dequant per continuation.
            # For large continuations, fall back to _continuation_prefill.
            cached_len = seq_len - q_len
            if q_len <= _CONTINUATION_DECODE_THRESHOLD:
                # Fast path: treat each query as a decode request
                # with incremental seq_lens for causal masking.
                synth_seq_lens = torch.arange(
                    cached_len + 1,
                    seq_len + 1,
                    device=query.device,
                    dtype=attn_metadata.seq_lens.dtype,
                )
                # Every synthetic decode row shares this request's block table.
                synth_bt = attn_metadata.block_table[i : i + 1].expand(q_len, -1)
                out = triton_turboquant_decode_attention(
                    query=q_seq,
                    kv_cache=kv_cache,
                    block_table=synth_bt,
                    seq_lens=synth_seq_lens,
                    Pi=Pi,
                    centroids=centroids,
                    scale=self.scale,
                    mse_bits=self.tq_config.key_mse_bits,
                    key_packed_size=self.tq_config.key_packed_size,
                    value_quant_bits=(
                        self.tq_config.effective_value_quant_bits
                    ),
                    key_fp8=self.tq_config.key_fp8,
                    norm_correction=self.tq_config.norm_correction,
                    PiT=PiT,
                )
            else:
                # Large continuation: dequant cached K/V and use
                # flash_attn for better throughput.
                out = self._continuation_prefill(
                    q_seq,
                    k_seq,
                    v_seq,
                    kv_cache,
                    attn_metadata.block_table[i : i + 1],
                    cached_len,
                    seq_len,
                    Pi,
                    centroids,
                )

        output[q_start:q_end] = out.to(query.dtype)

    return output
def _continuation_prefill(
    self,
    query: torch.Tensor,  # (q_len, Hq, D)
    key_chunk: torch.Tensor,  # (q_len, Hk, D)
    val_chunk: torch.Tensor,  # (q_len, Hk, D)
    kv_cache: torch.Tensor,  # (num_blocks, block_size, Hk, slot_size)
    block_table: torch.Tensor,  # (1, max_num_blocks)
    cached_len: int,
    seq_len: int,
    Pi: torch.Tensor,
    centroids: torch.Tensor,
) -> torch.Tensor:
    """Handle continuation chunk by dequanting cached K/V from TQ cache.

    Dequants previously cached K/V, concatenates with the current
    chunk's raw K/V, then runs causal attention (flash_attn when
    available, otherwise SDPA with an explicit mask).

    Returns:
        Tensor of shape ``(q_len, Hq, D)``.
    """
    q_len, Hq, D = query.shape
    Hk = key_chunk.shape[1]
    device = query.device
    block_size = kv_cache.shape[1]
    BLOCK_D = triton.next_power_of_2(D)

    mse_bytes = self._mse_bytes
    val_data_bytes = self._val_data_bytes
    n_centroids = self._n_centroids

    # Dequant cached K/V from TQ cache
    # Allocate slightly over to align to block_size for the grid
    alloc_len = math.ceil(cached_len / block_size) * block_size
    k_cached = torch.zeros(
        1, Hk, alloc_len, D, dtype=torch.float16, device=device
    )
    v_cached = torch.zeros(
        1, Hk, alloc_len, D, dtype=torch.float16, device=device
    )

    # One program per (token position, kv head); batch size is 1 here.
    grid = (alloc_len, 1 * Hk)
    _tq_full_dequant_kv[grid](
        kv_cache,
        block_table,
        centroids.float(),
        k_cached,
        v_cached,
        k_cached.stride(0),
        k_cached.stride(1),
        k_cached.stride(2),
        v_cached.stride(0),
        v_cached.stride(1),
        v_cached.stride(2),
        kv_cache.stride(0),
        kv_cache.stride(1),
        kv_cache.stride(2),
        block_table.stride(0),
        HEAD_DIM=D,
        BLOCK_SIZE=block_size,
        NUM_KV_HEADS=Hk,
        MSE_BYTES=mse_bytes,
        KPS=self.tq_config.key_packed_size,
        VQB=self.tq_config.effective_value_quant_bits,
        VAL_DATA_BYTES=val_data_bytes,
        MSE_BITS=self.tq_config.key_mse_bits,
        N_CENTROIDS=n_centroids,
        KEY_FP8=1 if self.tq_config.key_fp8 else 0,
        BLOCK_D=BLOCK_D,
        NORM_CORRECTION=1 if self.tq_config.norm_correction else 0,
        FP8_E4B15=_use_fp8_e4b15(device.index or 0),
        num_warps=4,
    )

    # Inverse-rotate MSE keys back to original space
    if not self.tq_config.key_fp8:
        k_flat = k_cached[0, :, :cached_len, :].reshape(-1, D).float()
        k_flat = k_flat @ Pi.float()
        k_cached_trim = (
            k_flat.to(torch.float16).reshape(Hk, cached_len, D).transpose(0, 1)
        )  # (cached_len, Hk, D)
    else:
        k_cached_trim = (
            k_cached[0, :, :cached_len, :].transpose(0, 1).contiguous()
        )  # (cached_len, Hk, D)

    v_cached_trim = (
        v_cached[0, :, :cached_len, :].transpose(0, 1).contiguous()
    )  # (cached_len, Hk, D)

    # Concatenate cached + current chunk K/V (match query dtype)
    qdtype = query.dtype
    k_full = torch.cat([k_cached_trim.to(qdtype), key_chunk], dim=0)
    v_full = torch.cat([v_cached_trim.to(qdtype), val_chunk], dim=0)

    # Attention: q_len queries attending to seq_len K/V with causal mask
    if _HAS_FLASH_ATTN:
        cu_seqlens_q = torch.tensor([0, q_len], device=device, dtype=torch.int32)
        cu_seqlens_k = torch.tensor([0, seq_len], device=device, dtype=torch.int32)
        return flash_attn_varlen_func(
            q=query,
            k=k_full,
            v=v_full,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=q_len,
            max_seqlen_k=seq_len,
            softmax_scale=self.scale,
            causal=True,
        )

    # SDPA fallback: build an explicit causal mask; GQA is handled by
    # enable_gqa rather than by expanding K/V.
    q_t = query.transpose(0, 1).unsqueeze(0)  # (1, Hq, q_len, D)
    k_t = k_full.transpose(0, 1).unsqueeze(0)  # (1, Hk, seq_len, D)
    v_t = v_full.transpose(0, 1).unsqueeze(0)  # (1, Hk, seq_len, D)
    # Build causal mask: query position p can attend to K position j
    # where j <= cached_len + p (p is 0-indexed within chunk)
    q_pos = torch.arange(q_len, device=device).unsqueeze(1) + cached_len
    k_pos = torch.arange(seq_len, device=device).unsqueeze(0)
    mask = k_pos <= q_pos  # (q_len, seq_len)
    out = F.scaled_dot_product_attention(
        q_t,
        k_t,
        v_t,
        attn_mask=mask,
        scale=self.scale,
        enable_gqa=(Hk < Hq),
    )  # (1, Hq, q_len, D)
    return out[0].transpose(0, 1)  # (q_len, Hq, D)
# ------------------------------------------------------------------ #
# Decode: Triton TQ decode attention #
# ------------------------------------------------------------------ #
def _decode_attention(
    self,
    query: torch.Tensor,  # (B, Hq, D)
    kv_cache: torch.Tensor,  # (num_blocks, block_size, Hk, slot_size)
    attn_metadata: TurboQuantMetadata,
    Pi: torch.Tensor,
    centroids: torch.Tensor,
    PiT: torch.Tensor | None = None,
    layer: torch.nn.Module | None = None,
) -> torch.Tensor:
    """Run the Triton TurboQuant decode kernel for a decode-only batch.

    Scratch buffers previously stashed on ``layer`` (if any) are handed
    to the kernel wrapper, and ``layer`` itself is passed as
    ``buf_holder``.
    """
    # Grab cached decode buffers from the layer (lazily allocated).
    if layer is None:
        mid_o_buf = output_buf = lse_buf = None
    else:
        mid_o_buf, output_buf, lse_buf = (
            getattr(layer, attr, None)
            for attr in ("_tq_mid_o_buf", "_tq_output_buf", "_tq_lse_buf")
        )

    cfg = self.tq_config
    return triton_turboquant_decode_attention(
        query=query,
        kv_cache=kv_cache,
        block_table=attn_metadata.block_table,
        seq_lens=attn_metadata.seq_lens,
        Pi=Pi,
        centroids=centroids,
        scale=self.scale,
        mse_bits=cfg.key_mse_bits,
        key_packed_size=cfg.key_packed_size,
        value_quant_bits=cfg.effective_value_quant_bits,
        key_fp8=cfg.key_fp8,
        norm_correction=cfg.norm_correction,
        PiT=PiT,
        mid_o_buf=mid_o_buf,
        output_buf=output_buf,
        lse_buf=lse_buf,
        buf_holder=layer,
        max_num_kv_splits=self.max_num_kv_splits,
    )
vllm/v1/attention/ops/triton_turboquant_decode.py
0 → 100644
View file @
c3270a92
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Triton fused TurboQuant decode attention.
Decode path: Triton stage1 (split-KV tiled attention scoring + value
accumulation) + stage2 (log-sum-exp reduction across splits).
Supports FP8 (E4M3) keys, 3-bit and 4-bit uniform quantized values.
"""
import
math
import
torch
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
# Process-wide cache of the FP8 format decision; None until first queried.
_FP8_E4B15: int | None = None


def _use_fp8_e4b15(device: int = 0) -> int:
    """Return 1 if the fp8e4b15 format should be used, else 0 (e4nv).

    On CUDA-alike platforms the answer is 1 when the device's compute
    capability is below (8, 9).  On every other platform (e.g. XPU) it
    is 0.  The result is computed once and cached in ``_FP8_E4B15``, so
    ``device`` only influences the very first call.
    """
    global _FP8_E4B15
    if _FP8_E4B15 is not None:
        return _FP8_E4B15

    if not current_platform.is_cuda_alike():
        _FP8_E4B15 = 0
    else:
        capability = torch.cuda.get_device_capability(device)
        _FP8_E4B15 = int(capability < (8, 9))
    return _FP8_E4B15
# ---------------------------------------------------------------------------
# Stage 1: Fused TQ score + value accumulation (BLOCK_KV tiled)
# ---------------------------------------------------------------------------
@triton.jit
def _tq_decode_stage1(
    # Precomputed query projection
    Q_rot_ptr,  # [B, Hq, D] float32
    # Compressed KV cache (combined K+V)
    KV_cache_ptr,  # [num_blocks, block_size, Hk, padded_slot] uint8
    # Block table and sequence info
    Block_table_ptr,  # [B, max_num_blocks] int32
    Seq_lens_ptr,  # [B] int32
    # TQ parameters
    Centroids_ptr,  # [n_centroids] float32
    # Output (intermediate for stage2)
    Mid_o_ptr,  # [B, Hq, NUM_KV_SPLITS, D+1] float32
    # Strides
    stride_qb,
    stride_qh,  # Q strides: [B, Hq, D]
    stride_cache_block,
    stride_cache_pos,
    stride_cache_head,  # KV cache
    stride_bt_b,  # block_table stride per batch
    stride_mid_b,
    stride_mid_h,
    stride_mid_s,  # mid_o strides
    # Constexpr dims
    NUM_KV_HEADS: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,  # KV cache block_size (pages)
    NUM_KV_SPLITS: tl.constexpr,
    KV_GROUP_SIZE: tl.constexpr,  # Hq // Hk
    # TQ layout constants
    MSE_BITS: tl.constexpr,  # 3 or 4
    MSE_BYTES: tl.constexpr,  # ceil(D * mse_bits / 8)
    KPS: tl.constexpr,  # key_packed_size (byte offset of the value section)
    VQB: tl.constexpr,  # value_quant_bits: 3 -> 3-bit unpack, else 4-bit path
    VAL_DATA_BYTES: tl.constexpr,  # packed value bytes; scale/zero follow
    # Score constants
    ATTN_SCALE: tl.constexpr,  # 1/sqrt(D)
    # Block tile sizes
    BLOCK_D: tl.constexpr,  # next_power_of_2(HEAD_DIM)
    BLOCK_KV: tl.constexpr,  # tokens per tile (16)
    KEY_FP8: tl.constexpr,  # 1 if K is stored as FP8
    NORM_CORRECTION: tl.constexpr = 0,  # 1 = re-normalize centroids
    FP8_E4B15: tl.constexpr = 0,  # 1 = use e4b15 (Ampere/Ada), 0 = e4nv (Hopper+)
):
    """Stage 1 of split-KV decode: per-split attention over the TQ cache.

    Each program handles one (batch, query head, KV split) triple.  It
    walks its slice of the sequence in tiles of BLOCK_KV tokens, scores
    the query against dequantized keys, dequantizes values, and keeps an
    online-softmax accumulator.  The normalized partial output
    (HEAD_DIM floats) and its log-sum-exp (one float at offset HEAD_DIM)
    are written to Mid_o for the reduction stage.
    """
    bid = tl.program_id(0)  # batch index
    hid = tl.program_id(1)  # q_head index
    sid = tl.program_id(2)  # kv_split index

    # Query heads in the same group share one KV head.
    kv_head = hid // KV_GROUP_SIZE

    # Sequence length for this batch
    seq_len = tl.load(Seq_lens_ptr + bid)

    # KV split range
    split_len = tl.cdiv(seq_len, NUM_KV_SPLITS)
    split_start = split_len * sid
    split_end = tl.minimum(split_start + split_len, seq_len)

    # Empty split: nothing is written to Mid_o for this program.
    if split_start >= split_end:
        return

    # Dimension offsets
    d_offs = tl.arange(0, BLOCK_D)
    d_mask = d_offs < HEAD_DIM
    kv_range = tl.arange(0, BLOCK_KV)

    # Load query vector: q_rot — [BLOCK_D] float32
    q_base = bid * stride_qb + hid * stride_qh
    q_rot = tl.load(Q_rot_ptr + q_base + d_offs, mask=d_mask, other=0.0).to(
        tl.float32
    )

    # Precompute byte/bit index vectors for MSE gather loads
    if not KEY_FP8:
        mse_bit_off = d_offs * MSE_BITS
        mse_byte_idx = mse_bit_off // 8
        mse_bit_shift = mse_bit_off % 8
        mse_mask = (1 << MSE_BITS) - 1

    # Precompute value bit/byte index vectors (loop-invariant)
    if VQB == 3:
        val_bit_off = d_offs * 3
        val_byte_idx = val_bit_off // 8
        val_bit_shift = val_bit_off % 8

    # Online softmax accumulators
    m_prev = -float("inf")
    l_prev = 0.0
    acc = tl.zeros([BLOCK_D], dtype=tl.float32)

    bt_base = bid * stride_bt_b

    # ================================================================
    # TILED LOOP: process BLOCK_KV tokens per iteration
    # ================================================================
    for start_n in range(split_start, split_end, BLOCK_KV):
        kv_offs = start_n + kv_range
        kv_mask = kv_offs < split_end

        # Translate logical token positions to (page, offset-in-page).
        page_idx = kv_offs // BLOCK_SIZE
        page_off = kv_offs % BLOCK_SIZE
        block_nums = tl.load(
            Block_table_ptr + bt_base + page_idx,
            mask=kv_mask,
            other=0,
        )

        # Offset of each token's slot for this KV head within the cache.
        slot_bases = (
            block_nums * stride_cache_block
            + page_off * stride_cache_pos
            + kv_head * stride_cache_head
        )

        # ============================================================
        # COMPUTE ATTENTION SCORES: [BLOCK_KV]
        # ============================================================
        if KEY_FP8:
            # One byte per key element, reinterpreted as FP8.
            k_addrs = slot_bases[:, None] + d_offs[None, :]
            k_raw = tl.load(
                KV_cache_ptr + k_addrs,
                mask=kv_mask[:, None] & d_mask[None, :],
                other=0,
            )
            if FP8_E4B15:
                k_float = k_raw.to(tl.float8e4b15, bitcast=True).to(tl.float32)
            else:
                k_float = k_raw.to(tl.float8e4nv, bitcast=True).to(tl.float32)
            scores = (
                tl.sum(
                    tl.where(d_mask[None, :], q_rot[None, :] * k_float, 0.0),
                    axis=1,
                )
                * ATTN_SCALE
            )
            scores = tl.where(kv_mask, scores, -float("inf"))
        else:
            # MSE unpack + norms
            # Two adjacent bytes are combined so an index that straddles
            # a byte boundary can be extracted with a single shift+mask.
            mse_addrs0 = slot_bases[:, None] + mse_byte_idx[None, :]
            mse_raw0 = tl.load(
                KV_cache_ptr + mse_addrs0,
                mask=kv_mask[:, None] & d_mask[None, :],
                other=0,
            ).to(tl.int32)
            mse_raw1 = tl.load(
                KV_cache_ptr + mse_addrs0 + 1,
                mask=kv_mask[:, None] & d_mask[None, :],
                other=0,
            ).to(tl.int32)
            raw16 = mse_raw0 | (mse_raw1 << 8)
            mse_idx = (raw16 >> mse_bit_shift[None, :]) & mse_mask

            # Centroid gather + dot product
            c_vals = tl.load(
                Centroids_ptr + mse_idx,
                mask=kv_mask[:, None] & d_mask[None, :],
                other=0.0,
            )

            # Norm correction: re-normalize centroid vector to unit norm
            if NORM_CORRECTION:
                c_norm_sq = tl.sum(
                    tl.where(d_mask[None, :], c_vals * c_vals, 0.0),
                    axis=1,
                )
                # Small epsilon guards against division by zero.
                c_inv_norm = 1.0 / tl.sqrt(c_norm_sq + 1e-16)
                c_vals = c_vals * c_inv_norm[:, None]

            term1 = tl.sum(
                tl.where(d_mask[None, :], q_rot[None, :] * c_vals, 0.0),
                axis=1,
            )

            # Load norms (fp16 -> fp32): norms are at MSE_BYTES offset
            # The fp16 norm is assembled from two bytes (low byte first).
            norm_bases = slot_bases + MSE_BYTES
            n_lo = tl.load(KV_cache_ptr + norm_bases, mask=kv_mask, other=0).to(
                tl.uint16
            )
            n_hi = tl.load(
                KV_cache_ptr + norm_bases + 1, mask=kv_mask, other=0
            ).to(tl.uint16)
            vec_norms = (n_lo | (n_hi << 8)).to(tl.float16, bitcast=True).to(
                tl.float32
            )

            scores = vec_norms * term1 * ATTN_SCALE
            scores = tl.where(kv_mask, scores, -float("inf"))

        # ============================================================
        # ONLINE SOFTMAX UPDATE (block-level)
        # ============================================================
        n_e_max = tl.maximum(tl.max(scores, 0), m_prev)
        re_scale = tl.exp(m_prev - n_e_max)
        p = tl.exp(scores - n_e_max)

        # ============================================================
        # VALUE LOAD + DEQUANTIZE: [BLOCK_KV, BLOCK_D]
        # ============================================================
        # Values start KPS bytes into the slot, after the packed key.
        val_bases = slot_bases + KPS

        if VQB == 3:
            val_addrs0 = val_bases[:, None] + val_byte_idx[None, :]
            val_raw0 = tl.load(
                KV_cache_ptr + val_addrs0,
                mask=kv_mask[:, None] & d_mask[None, :],
                other=0,
            ).to(tl.int32)
            val_raw1 = tl.load(
                KV_cache_ptr + val_addrs0 + 1,
                mask=kv_mask[:, None] & d_mask[None, :],
                other=0,
            ).to(tl.int32)
            raw16 = val_raw0 | (val_raw1 << 8)
            v_idx = ((raw16 >> val_bit_shift[None, :]) & 0x7).to(tl.float32)

            # fp16 scale (2 bytes) then fp16 zero-point (2 bytes) follow
            # the packed value data.
            sc_bases = val_bases + VAL_DATA_BYTES
            sc_lo = tl.load(KV_cache_ptr + sc_bases, mask=kv_mask, other=0).to(
                tl.uint16
            )
            sc_hi = tl.load(
                KV_cache_ptr + sc_bases + 1, mask=kv_mask, other=0
            ).to(tl.uint16)
            v_scales = (
                (sc_lo | (sc_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
            )
            zr_lo = tl.load(
                KV_cache_ptr + sc_bases + 2, mask=kv_mask, other=0
            ).to(tl.uint16)
            zr_hi = tl.load(
                KV_cache_ptr + sc_bases + 3, mask=kv_mask, other=0
            ).to(tl.uint16)
            v_zeros = (zr_lo | (zr_hi << 8)).to(tl.float16, bitcast=True).to(
                tl.float32
            )
            values = v_idx * v_scales[:, None] + v_zeros[:, None]
        else:  # VQB == 4
            # Two 4-bit values per byte: even index in the low nibble.
            vb_idx = d_offs // 2
            vb_shift = (d_offs % 2) * 4
            val_addrs = val_bases[:, None] + vb_idx[None, :]
            val_raw = tl.load(
                KV_cache_ptr + val_addrs,
                mask=kv_mask[:, None] & d_mask[None, :],
                other=0,
            ).to(tl.int32)
            v_idx = ((val_raw >> vb_shift[None, :]) & 0xF).to(tl.float32)

            sc_bases = val_bases + VAL_DATA_BYTES
            sc_lo = tl.load(KV_cache_ptr + sc_bases, mask=kv_mask, other=0).to(
                tl.uint16
            )
            sc_hi = tl.load(
                KV_cache_ptr + sc_bases + 1, mask=kv_mask, other=0
            ).to(tl.uint16)
            v_scales = (
                (sc_lo | (sc_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
            )
            zr_lo = tl.load(
                KV_cache_ptr + sc_bases + 2, mask=kv_mask, other=0
            ).to(tl.uint16)
            zr_hi = tl.load(
                KV_cache_ptr + sc_bases + 3, mask=kv_mask, other=0
            ).to(tl.uint16)
            v_zeros = (zr_lo | (zr_hi << 8)).to(tl.float16, bitcast=True).to(
                tl.float32
            )
            values = v_idx * v_scales[:, None] + v_zeros[:, None]

        # ============================================================
        # WEIGHTED VALUE ACCUMULATION
        # ============================================================
        # Masked tokens have score -inf, hence p == 0 and no contribution.
        acc = acc * re_scale + tl.sum(p[:, None] * values, 0)
        l_prev = l_prev * re_scale + tl.sum(p, 0)
        m_prev = n_e_max

    # Store partial result
    out_base = bid * stride_mid_b + hid * stride_mid_h + sid * stride_mid_s
    # Avoid dividing by zero / log(0) if the running sum is zero.
    safe_l = tl.where(l_prev > 0.0, l_prev, 1.0)
    tl.store(Mid_o_ptr + out_base + d_offs, acc / safe_l, mask=d_mask)
    # The split's log-sum-exp goes in the extra slot after the HEAD_DIM outputs.
    lse = m_prev + tl.log(safe_l)
    tl.store(Mid_o_ptr + out_base + HEAD_DIM, lse)
# ---------------------------------------------------------------------------
# Pre-dequant kernel: Bulk dequant K (MSE+norms) and V to fp16
# ---------------------------------------------------------------------------
@triton.jit
def _tq_full_dequant_kv(
    KV_cache_ptr,
    Block_table_ptr,
    Centroids_ptr,
    K_out_ptr,  # [B, Hk, max_seq, D] float16
    V_out_ptr,  # [B, Hk, max_seq, D] float16
    stride_ko_b,
    stride_ko_h,
    stride_ko_s,
    stride_vo_b,
    stride_vo_h,
    stride_vo_s,
    stride_cache_block,
    stride_cache_pos,
    stride_cache_head,
    stride_bt_b,
    HEAD_DIM: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    NUM_KV_HEADS: tl.constexpr,
    MSE_BYTES: tl.constexpr,
    KPS: tl.constexpr,
    VQB: tl.constexpr,
    VAL_DATA_BYTES: tl.constexpr,
    MSE_BITS: tl.constexpr,
    N_CENTROIDS: tl.constexpr,  # not referenced in the kernel body
    KEY_FP8: tl.constexpr,
    BLOCK_D: tl.constexpr,
    NORM_CORRECTION: tl.constexpr = 0,
    FP8_E4B15: tl.constexpr = 0,  # 1 = use e4b15 (Ampere/Ada), 0 = e4nv (Hopper+)
):
    """Full dequant: reconstruct K (MSE centroids * norm or FP8) and V to fp16.

    One program handles one token position (grid axis 0) for one
    (batch, kv head) pair (grid axis 1, flattened as
    ``batch * NUM_KV_HEADS + head``).  MSE keys are written in the same
    space they were stored in; no rotation is applied here.
    """
    pos = tl.program_id(0)
    bh = tl.program_id(1)
    bid = bh // NUM_KV_HEADS
    hid = bh % NUM_KV_HEADS

    # Translate the logical position into a physical cache slot.
    page_idx = pos // BLOCK_SIZE
    page_off = pos % BLOCK_SIZE
    block_num = tl.load(Block_table_ptr + bid * stride_bt_b + page_idx)
    slot_base = (
        block_num * stride_cache_block
        + page_off * stride_cache_pos
        + hid * stride_cache_head
    )

    d_offs = tl.arange(0, BLOCK_D)
    d_mask = d_offs < HEAD_DIM

    # === K dequant ===
    ko_base = bid * stride_ko_b + hid * stride_ko_h + pos * stride_ko_s

    if KEY_FP8:
        # One byte per key element, reinterpreted as FP8.
        k_raw = tl.load(KV_cache_ptr + slot_base + d_offs, mask=d_mask, other=0)
        if FP8_E4B15:
            k_recon = k_raw.to(tl.float8e4b15, bitcast=True).to(tl.float32)
        else:
            k_recon = k_raw.to(tl.float8e4nv, bitcast=True).to(tl.float32)
        tl.store(K_out_ptr + ko_base + d_offs, k_recon.to(tl.float16), mask=d_mask)
    else:
        # MSE unpack (3-bit or 4-bit) + norms
        mse_bit_off = d_offs * MSE_BITS
        mse_byte_idx = mse_bit_off // 8
        mse_bit_shift = mse_bit_off % 8
        mse_umask = (1 << MSE_BITS) - 1

        # Combine two adjacent bytes so indices that straddle a byte
        # boundary can be extracted with one shift+mask.
        mse_raw0 = tl.load(
            KV_cache_ptr + slot_base + mse_byte_idx, mask=d_mask, other=0
        ).to(tl.int32)
        mse_raw1 = tl.load(
            KV_cache_ptr + slot_base + mse_byte_idx + 1, mask=d_mask, other=0
        ).to(tl.int32)
        raw16 = mse_raw0 | (mse_raw1 << 8)
        mse_idx = (raw16 >> mse_bit_shift) & mse_umask

        k_mse = tl.load(Centroids_ptr + mse_idx, mask=d_mask, other=0.0)

        # Norm correction: re-normalize centroid vector to unit norm
        if NORM_CORRECTION:
            c_norm_sq = tl.sum(tl.where(d_mask, k_mse * k_mse, 0.0), axis=0)
            c_inv_norm = 1.0 / tl.sqrt(c_norm_sq + 1e-16)
            k_mse = k_mse * c_inv_norm

        # Norms at MSE_BYTES offset (no QJL bytes)
        # The fp16 norm is assembled from two bytes (low byte first).
        norm_base = slot_base + MSE_BYTES
        n_lo = tl.load(KV_cache_ptr + norm_base).to(tl.uint16)
        n_hi = tl.load(KV_cache_ptr + norm_base + 1).to(tl.uint16)
        vec_norm = (n_lo | (n_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)

        k_recon = vec_norm * k_mse
        tl.store(K_out_ptr + ko_base + d_offs, k_recon.to(tl.float16), mask=d_mask)

    # === V dequant ===
    # Values start KPS bytes into the slot, after the packed key.
    val_base = slot_base + KPS
    if VQB == 4:
        # Two 4-bit values per byte: even index in the low nibble.
        vb_idx = d_offs // 2
        vb_shift = (d_offs % 2) * 4
        val_raw = tl.load(
            KV_cache_ptr + val_base + vb_idx, mask=d_mask, other=0
        ).to(tl.int32)
        v_idx = ((val_raw >> vb_shift) & 0xF).to(tl.float32)

        # fp16 scale (2 bytes) then fp16 zero-point (2 bytes) follow the data.
        sc_base = val_base + VAL_DATA_BYTES
        sc_lo = tl.load(KV_cache_ptr + sc_base).to(tl.uint16)
        sc_hi = tl.load(KV_cache_ptr + sc_base + 1).to(tl.uint16)
        v_scale = (sc_lo | (sc_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
        zr_lo = tl.load(KV_cache_ptr + sc_base + 2).to(tl.uint16)
        zr_hi = tl.load(KV_cache_ptr + sc_base + 3).to(tl.uint16)
        v_zero = (zr_lo | (zr_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
        v_vals = v_idx * v_scale + v_zero
    elif VQB == 3:
        # 3-bit value unpack: 8 values per 3 bytes
        val_bit_off = d_offs * 3
        val_byte_idx = val_bit_off // 8
        val_bit_shift = val_bit_off % 8
        val_raw0 = tl.load(
            KV_cache_ptr + val_base + val_byte_idx, mask=d_mask, other=0
        ).to(tl.int32)
        val_raw1 = tl.load(
            KV_cache_ptr + val_base + val_byte_idx + 1, mask=d_mask, other=0
        ).to(tl.int32)
        raw16 = val_raw0 | (val_raw1 << 8)
        v_idx = ((raw16 >> val_bit_shift) & 0x7).to(tl.float32)

        sc_base = val_base + VAL_DATA_BYTES
        sc_lo = tl.load(KV_cache_ptr + sc_base).to(tl.uint16)
        sc_hi = tl.load(KV_cache_ptr + sc_base + 1).to(tl.uint16)
        v_scale = (sc_lo | (sc_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
        zr_lo = tl.load(KV_cache_ptr + sc_base + 2).to(tl.uint16)
        zr_hi = tl.load(KV_cache_ptr + sc_base + 3).to(tl.uint16)
        v_zero = (zr_lo | (zr_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
        v_vals = v_idx * v_scale + v_zero
    else:
        # Any other bit width is not decoded here; zeros are written.
        v_vals = tl.zeros([BLOCK_D], dtype=tl.float32)

    vo_base = bid * stride_vo_b + hid * stride_vo_h + pos * stride_vo_s
    tl.store(V_out_ptr + vo_base + d_offs, v_vals.to(tl.float16), mask=d_mask)
# ---------------------------------------------------------------------------
# Stage 2: Reuse from triton_decode_attention.py
# ---------------------------------------------------------------------------
from
vllm.v1.attention.ops.triton_decode_attention
import
(
_fwd_kernel_stage2
,
)
# ---------------------------------------------------------------------------
# Launcher — cached constants + fused GEMM
# ---------------------------------------------------------------------------
_layout_cache: dict = {}


def _get_layout(D, mse_bits, value_quant_bits, key_packed_size):
    """Return the memoized layout constants for one TurboQuant configuration.

    The result is cached in the module-level ``_layout_cache`` under the
    tuple ``(D, mse_bits, value_quant_bits, key_packed_size)``, so the same
    dict object is handed back on every later call with the same arguments.

    Args:
        D: head dimension.
        mse_bits: bits per MSE key index.
        value_quant_bits: bits per quantized value element.
        key_packed_size: only participates in the cache key; it does not
            influence any of the derived constants.

    Returns:
        A dict with the keys ``"mse_bytes"``, ``"val_data_bytes"``,
        ``"mse_bits"``, ``"n_centroids"`` and ``"BLOCK_D"``.
    """
    signature = (D, mse_bits, value_quant_bits, key_packed_size)
    cached = _layout_cache.get(signature)
    if cached is not None:
        return cached

    # Packed payload sizes are rounded up to whole bytes.
    value_payload_bytes = math.ceil(D * value_quant_bits / 8)
    key_payload_bytes = math.ceil(D * mse_bits / 8)
    layout = {
        "mse_bytes": key_payload_bytes,
        "val_data_bytes": value_payload_bytes,
        "mse_bits": mse_bits,
        "n_centroids": 2**mse_bits,
        "BLOCK_D": triton.next_power_of_2(D),
    }
    _layout_cache[signature] = layout
    return layout
def _tq_buf_fits(buf, *min_dims):
    """Return True when ``buf`` can safely be sliced down to ``min_dims``.

    ``buf`` must be non-None, have exactly ``len(min_dims)`` dimensions and be
    at least ``min_dims[i]`` long along every dimension ``i``.
    """
    if buf is None or buf.dim() != len(min_dims):
        return False
    return all(have >= need for have, need in zip(buf.shape, min_dims))


def triton_turboquant_decode_attention(
    query: torch.Tensor,  # [B, Hq, D] — original query
    kv_cache: torch.Tensor,  # [num_blocks, block_size, Hk, padded_slot] uint8
    block_table: torch.Tensor,  # [B, max_num_blocks] int32
    seq_lens: torch.Tensor,  # [B] int32
    Pi: torch.Tensor,  # [D, D] float32
    centroids: torch.Tensor,  # [n_centroids] float32
    scale: float,
    mse_bits: int,
    key_packed_size: int,
    value_quant_bits: int,
    key_fp8: bool = False,
    norm_correction: bool = False,
    PiT: torch.Tensor | None = None,  # [D, D] pre-computed Pi.T contiguous
    # Pre-allocated buffers (optional, avoids per-call allocation)
    mid_o_buf: torch.Tensor | None = None,
    output_buf: torch.Tensor | None = None,
    lse_buf: torch.Tensor | None = None,
    buf_holder: object | None = None,
    max_num_kv_splits: int = 32,  # fixed split count (must be constant for cudagraph)
) -> torch.Tensor:
    """Launch fused TQ decode attention (Triton stage1 + stage2).

    Stage 1 (``_tq_decode_stage1``) runs on a ``(B, Hq, NUM_KV_SPLITS)`` grid
    and writes per-split partial results into ``mid_o``; stage 2
    (``_fwd_kernel_stage2``) reduces them across splits into ``output``.

    Args:
        query: queries, shape ``[B, Hq, D]``.
        kv_cache: packed cache; dims 1 and 2 are read as block size and the
            number of KV heads.
        block_table: per-sequence block table.
        seq_lens: per-sequence lengths.
        Pi: rotation matrix; only used to derive ``Pi.T`` when ``PiT`` is None
            on the MSE path.
        centroids: MSE centroid table.
        scale: attention scale forwarded to stage 1.
        mse_bits: bits per MSE key index.
        key_packed_size: forwarded to stage 1 as ``KPS``.
        value_quant_bits: bits per quantized value element.
        key_fp8: when True the query is passed to the kernel unrotated;
            otherwise it is rotated with a float32 GEMM first.
        norm_correction: forwarded to stage 1 as a 0/1 flag.
        PiT: optional pre-transposed, contiguous ``Pi.T``.
        mid_o_buf: optional scratch of at least ``[B, Hq, splits, D + 1]``.
        output_buf: optional output scratch of at least ``[B, Hq, D]``.
        lse_buf: optional log-sum-exp scratch of at least ``[B, Hq]``.
        buf_holder: if given, every buffer that had to be freshly allocated is
            stored on it (``_tq_mid_o_buf``, ``_tq_output_buf``,
            ``_tq_lse_buf``).
        max_num_kv_splits: number of KV splits used by both stages.

    Returns: output tensor [B, Hq, D] in query's dtype.
    """
    B, Hq, D = query.shape
    Hk = kv_cache.shape[2]
    block_size = kv_cache.shape[1]
    kv_group_size = Hq // Hk
    device = query.device

    cfg = _get_layout(D, mse_bits, value_quant_bits, key_packed_size)

    # Compute q_rot = q @ Pi.T (rotated query for MSE key scoring)
    # FP8 path: pass query directly (float16); kernel casts inline.
    # MSE path: still needs external GEMM (cuBLAS), so q_rot is float32.
    if key_fp8:
        q_rot = query.contiguous()
    else:
        q_float = query.float()
        if PiT is None:
            PiT = Pi.T.contiguous()
        q_rot = (q_float @ PiT).contiguous()

    NUM_KV_SPLITS = max_num_kv_splits

    # Reuse a caller-provided buffer only when EVERY sliced dimension is large
    # enough. A buffer that is too small in heads or in the last dimension
    # would otherwise be sliced into a short view that the kernels, which
    # index by (batch, head, split), would read and write out of bounds.
    if _tq_buf_fits(mid_o_buf, B, Hq, NUM_KV_SPLITS, D + 1):
        mid_o = mid_o_buf[:B, :Hq, :NUM_KV_SPLITS, :]
    else:
        mid_o = torch.empty(
            B,
            Hq,
            NUM_KV_SPLITS,
            D + 1,
            dtype=torch.float32,
            device=device,
        )
        if buf_holder is not None:
            buf_holder._tq_mid_o_buf = mid_o

    # Stage 1: split-KV tiled attention scoring + value accumulation
    fp8_e4b15 = _use_fp8_e4b15(device.index or 0)
    BLOCK_KV = 4
    grid = (B, Hq, NUM_KV_SPLITS)
    _tq_decode_stage1[grid](
        q_rot,
        kv_cache,
        block_table,
        seq_lens,
        centroids,
        mid_o,
        q_rot.stride(0),
        q_rot.stride(1),
        kv_cache.stride(0),
        kv_cache.stride(1),
        kv_cache.stride(2),
        block_table.stride(0),
        mid_o.stride(0),
        mid_o.stride(1),
        mid_o.stride(2),
        NUM_KV_HEADS=Hk,
        HEAD_DIM=D,
        BLOCK_SIZE=block_size,
        NUM_KV_SPLITS=NUM_KV_SPLITS,
        KV_GROUP_SIZE=kv_group_size,
        MSE_BITS=mse_bits,
        MSE_BYTES=cfg["mse_bytes"],
        KPS=key_packed_size,
        VQB=value_quant_bits,
        VAL_DATA_BYTES=cfg["val_data_bytes"],
        ATTN_SCALE=scale,
        BLOCK_D=cfg["BLOCK_D"],
        BLOCK_KV=BLOCK_KV,
        KEY_FP8=1 if key_fp8 else 0,
        NORM_CORRECTION=1 if norm_correction else 0,
        FP8_E4B15=fp8_e4b15,
        num_warps=1,
        num_stages=1,
    )

    # Stage 2: Reduce across KV splits
    if _tq_buf_fits(output_buf, B, Hq, D):
        output = output_buf[:B, :Hq, :D]
    else:
        output = torch.empty(B, Hq, D, dtype=torch.float32, device=device)
        if buf_holder is not None:
            buf_holder._tq_output_buf = output

    if _tq_buf_fits(lse_buf, B, Hq):
        lse = lse_buf[:B, :Hq]
    else:
        lse = torch.empty(B, Hq, dtype=torch.float32, device=device)
        if buf_holder is not None:
            buf_holder._tq_lse_buf = lse

    grid2 = (B, Hq)
    _fwd_kernel_stage2[grid2](
        mid_o,
        output,
        lse,
        seq_lens,
        mid_o.stride(0),
        mid_o.stride(1),
        mid_o.stride(2),
        output.stride(0),
        output.stride(1),
        lse.stride(0),
        NUM_KV_SPLITS=NUM_KV_SPLITS,
        BLOCK_DV=cfg["BLOCK_D"],
        Lv=D,
        num_warps=4,
        num_stages=2,
    )

    return output.to(query.dtype)
vllm/v1/attention/ops/triton_turboquant_store.py
0 → 100644
View file @
c3270a92
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused Triton kernels for TurboQuant KV store.
Two kernels:
1. _tq_fused_store_fp8: FP8 key scatter + value uniform quantization.
2. _tq_fused_store_mse: Fused bucketize + centroid gather + residual norm
+ MSE index packing + value quantization (eliminates 4 PyTorch kernel
launches vs the old pack-only approach).
The launcher `triton_turboquant_store` selects the appropriate kernel.
"""
import
math
import
torch
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.attention.ops.triton_turboquant_decode
import
_use_fp8_e4b15
# ═══════════════════════════════════════════════════════════════════════
# Shared: value uniform quantization + pack + scale/zero store
# ═══════════════════════════════════════════════════════════════════════
@triton.jit
def _store_quantized_value(
    Value_ptr,
    KV_cache_ptr,
    base,  # pid * D offset into Value_ptr
    slot_base,  # byte offset into KV_cache_ptr for this slot+head
    d_offs,  # tl.arange(0, BLOCK_D)
    d_mask,  # d_offs < D
    D: tl.constexpr,
    KPS: tl.constexpr,
    VQB: tl.constexpr,
    VAL_DATA_BYTES: tl.constexpr,
    BLOCK_VAL: tl.constexpr,
    BLOCK_GRP: tl.constexpr,
):
    """Uniform quantization of values to VQB bits, pack, and store with scale/zero.

    One value vector of ``D`` elements is read from ``Value_ptr + base``,
    quantized against its own min/max, packed and written into the cache at
    ``slot_base + KPS``. Directly after the ``VAL_DATA_BYTES`` of packed data
    come four bytes: the scale as fp16 (low byte first) followed by the zero
    point (the vector minimum) as fp16 (low byte first).

    ``VQB == 3`` packs eight 3-bit codes into three bytes; every other value
    of ``VQB`` takes the 4-bit branch, which packs two codes per byte.
    """
    # Value data starts right after the packed key region of this slot.
    val_cache_offset = KPS
    if VQB == 3:
        val_vec = tl.load(Value_ptr + base + d_offs, mask=d_mask, other=0.0).to(
            tl.float32
        )
        # Padding lanes are replaced by +/-inf so they never win the reduction.
        val_min = tl.min(tl.where(d_mask, val_vec, float("inf")), axis=0)
        val_max = tl.max(tl.where(d_mask, val_vec, -float("inf")), axis=0)
        # 3 bits -> 8 levels -> 7 steps between min and max.
        v_scale = (val_max - val_min) / 7.0
        # Floor the scale so a constant vector does not divide by zero.
        v_scale = tl.where(v_scale > 1e-8, v_scale, 1e-8)
        # Add 0.5 then truncate to int, and clamp the code to [0, 7].
        q_vals = tl.minimum(
            tl.maximum(((val_vec - val_min) / v_scale + 0.5).to(tl.int32), 0), 7
        )
        grp_offs = tl.arange(0, BLOCK_GRP)
        # Only the D // 8 complete groups of eight codes are written out.
        grp_mask = grp_offs < (D // 8)
        q_grp = tl.reshape(q_vals, [BLOCK_GRP, 8])
        # Code j of a group occupies bits [3*j, 3*j + 3) of a 24-bit word;
        # the fields are disjoint, so the sum acts as a bitwise OR.
        shifts_3bit = tl.arange(0, 8) * 3
        packed_24 = tl.sum(q_grp << shifts_3bit[None, :], axis=1)
        # Split the 24-bit word into three bytes, least significant first.
        b0 = (packed_24 & 0xFF).to(tl.uint8)
        b1 = ((packed_24 >> 8) & 0xFF).to(tl.uint8)
        b2 = ((packed_24 >> 16) & 0xFF).to(tl.uint8)
        tl.store(
            KV_cache_ptr + slot_base + val_cache_offset + grp_offs * 3,
            b0,
            mask=grp_mask,
        )
        tl.store(
            KV_cache_ptr + slot_base + val_cache_offset + grp_offs * 3 + 1,
            b1,
            mask=grp_mask,
        )
        tl.store(
            KV_cache_ptr + slot_base + val_cache_offset + grp_offs * 3 + 2,
            b2,
            mask=grp_mask,
        )
        # Scale (fp16) and zero point (fp16) follow the packed data, each
        # written byte by byte with the low byte first.
        sc_offset = val_cache_offset + VAL_DATA_BYTES
        sc_f16 = v_scale.to(tl.float16)
        sc_u16 = sc_f16.to(tl.uint16, bitcast=True)
        tl.store(KV_cache_ptr + slot_base + sc_offset, (sc_u16 & 0xFF).to(tl.uint8))
        tl.store(
            KV_cache_ptr + slot_base + sc_offset + 1,
            ((sc_u16 >> 8) & 0xFF).to(tl.uint8),
        )
        zr_f16 = val_min.to(tl.float16)
        zr_u16 = zr_f16.to(tl.uint16, bitcast=True)
        tl.store(KV_cache_ptr + slot_base + sc_offset + 2, (zr_u16 & 0xFF).to(tl.uint8))
        tl.store(
            KV_cache_ptr + slot_base + sc_offset + 3,
            ((zr_u16 >> 8) & 0xFF).to(tl.uint8),
        )
    else:
        # VQB == 4
        val_vec = tl.load(Value_ptr + base + d_offs, mask=d_mask, other=0.0).to(
            tl.float32
        )
        val_min = tl.min(tl.where(d_mask, val_vec, float("inf")), axis=0)
        val_max = tl.max(tl.where(d_mask, val_vec, -float("inf")), axis=0)
        # 4 bits -> 16 levels -> 15 steps between min and max.
        v_scale = (val_max - val_min) / 15.0
        v_scale = tl.where(v_scale > 1e-8, v_scale, 1e-8)
        # One lane per output byte; each byte holds elements 2*i and 2*i + 1.
        val_offs = tl.arange(0, BLOCK_VAL)
        val_mask = val_offs < VAL_DATA_BYTES
        # Out-of-range elements read as val_min, i.e. they quantize to code 0.
        v0 = tl.load(
            Value_ptr + base + val_offs * 2,
            mask=val_mask & (val_offs * 2 < D),
            other=val_min,
        )
        v1 = tl.load(
            Value_ptr + base + val_offs * 2 + 1,
            mask=val_mask & (val_offs * 2 + 1 < D),
            other=val_min,
        )
        # Add 0.5 then truncate to int, and clamp the code to [0, 15].
        q0 = tl.minimum(
            tl.maximum(((v0 - val_min) / v_scale + 0.5).to(tl.int32), 0), 15
        )
        q1 = tl.minimum(
            tl.maximum(((v1 - val_min) / v_scale + 0.5).to(tl.int32), 0), 15
        )
        # Even element in the low nibble, odd element in the high nibble.
        packed_val = (q0 | (q1 << 4)).to(tl.uint8)
        tl.store(
            KV_cache_ptr + slot_base + val_cache_offset + val_offs,
            packed_val,
            mask=val_mask,
        )
        # Scale (fp16) and zero point (fp16) follow the packed data, each
        # written byte by byte with the low byte first.
        sc_offset = val_cache_offset + VAL_DATA_BYTES
        sc_f16 = v_scale.to(tl.float16)
        sc_u16 = sc_f16.to(tl.uint16, bitcast=True)
        tl.store(KV_cache_ptr + slot_base + sc_offset, (sc_u16 & 0xFF).to(tl.uint8))
        tl.store(
            KV_cache_ptr + slot_base + sc_offset + 1,
            ((sc_u16 >> 8) & 0xFF).to(tl.uint8),
        )
        zr_f16 = val_min.to(tl.float16)
        zr_u16 = zr_f16.to(tl.uint16, bitcast=True)
        tl.store(KV_cache_ptr + slot_base + sc_offset + 2, (zr_u16 & 0xFF).to(tl.uint8))
        tl.store(
            KV_cache_ptr + slot_base + sc_offset + 3,
            ((zr_u16 >> 8) & 0xFF).to(tl.uint8),
        )
# ═══════════════════════════════════════════════════════════════════════
# FP8 key store + value uniform quantization
# ═══════════════════════════════════════════════════════════════════════
@triton.jit
def _tq_fused_store_fp8(
    Key_ptr,  # [NH, D] float16/bfloat16 — raw keys
    Value_ptr,  # [NH, D] float16/bfloat16 — raw values
    KV_cache_ptr,  # [total_bytes] uint8 (flattened view)
    Slot_mapping_ptr,  # [N] int32 — per-token slot indices
    # Cache strides (for computing byte offsets)
    stride_cache_block: tl.constexpr,
    stride_cache_pos: tl.constexpr,
    stride_cache_head: tl.constexpr,
    # Dimensions
    D: tl.constexpr,
    H: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_D: tl.constexpr,
    # TQ layout
    KPS: tl.constexpr,
    # Value quantization
    VQB: tl.constexpr,
    VAL_DATA_BYTES: tl.constexpr,
    # Packing block sizes
    BLOCK_VAL: tl.constexpr,
    BLOCK_GRP: tl.constexpr = 16,
    FP8_E4B15: tl.constexpr = 0,  # 1 = e4b15 (Ampere/Ada), 0 = e4nv (Hopper+)
):
    """Store one (token, head) pair: FP8 key bytes plus a quantized value.

    Each program handles a single row of the flattened ``[NH, D]`` inputs.
    The key is cast to FP8 and its raw bytes are written at the start of the
    cache slot; the value is delegated to ``_store_quantized_value``.
    Programs whose slot index is negative return without writing anything.
    """
    row = tl.program_id(0)
    tok = row // H
    head = row % H

    cache_slot = tl.load(Slot_mapping_ptr + tok)
    if cache_slot < 0:
        return

    # Destination byte offset of this (block, position, head) slot.
    page = cache_slot // BLOCK_SIZE
    page_pos = cache_slot % BLOCK_SIZE
    dst_base = (
        page * stride_cache_block
        + page_pos * stride_cache_pos
        + head * stride_cache_head
    )
    # Source element offset of this row in the flattened key/value inputs.
    src_base = row * D

    lanes = tl.arange(0, BLOCK_D)
    lane_ok = lanes < D

    # ── FP8 KEY: cast in-kernel, then store the raw bytes ─────────────
    key_row = tl.load(Key_ptr + src_base + lanes, mask=lane_ok, other=0.0)
    if FP8_E4B15:
        key_fp8 = key_row.to(tl.float8e4b15)
    else:
        # The e4nv cast goes through float32 first.
        key_fp8 = key_row.to(tl.float32).to(tl.float8e4nv)
    tl.store(
        KV_cache_ptr + dst_base + lanes,
        key_fp8.to(tl.uint8, bitcast=True),
        mask=lane_ok,
    )

    # ── VALUE: quantize + pack + scale/zero ───────────────────────────
    _store_quantized_value(
        Value_ptr,
        KV_cache_ptr,
        src_base,
        dst_base,
        lanes,
        lane_ok,
        D=D,
        KPS=KPS,
        VQB=VQB,
        VAL_DATA_BYTES=VAL_DATA_BYTES,
        BLOCK_VAL=BLOCK_VAL,
        BLOCK_GRP=BLOCK_GRP,
    )
# ═══════════════════════════════════════════════════════════════════════
# Fused MSE store: bucketize + centroid gather + residual norm + pack
# (eliminates 4 PyTorch kernel launches per layer vs pack-only kernel)
# ═══════════════════════════════════════════════════════════════════════
@triton.jit
def _tq_fused_store_mse(
    # Post-rotation inputs
    Y_ptr,  # [NH, D] float32 — rotated normalized keys (x_hat @ PiT)
    Norms_ptr,  # [NH] float32 — key vector norms (||k||)
    Value_ptr,  # [NH, D] float32 — raw values
    # Quantization tables
    Centroids_ptr,  # [n_centroids] float32
    Midpoints_ptr,  # [n_centroids-1] float32
    # Cache and indexing
    KV_cache_ptr,  # [total_bytes] uint8 (flattened view)
    Slot_mapping_ptr,  # [N] int32 — per-token slot indices
    # Cache strides
    stride_cache_block: tl.constexpr,
    stride_cache_pos: tl.constexpr,
    stride_cache_head: tl.constexpr,
    # Dimensions
    D: tl.constexpr,
    H: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_D: tl.constexpr,
    # TQ layout
    MSE_BYTES: tl.constexpr,
    KPS: tl.constexpr,
    # Value quantization
    VQB: tl.constexpr,
    VAL_DATA_BYTES: tl.constexpr,
    # Packing block sizes
    BLOCK_VAL: tl.constexpr,
    # MSE params
    MSE_BITS: tl.constexpr,
    N_CENTROIDS: tl.constexpr,
    BLOCK_GRP: tl.constexpr = 16,
):
    """Fused MSE quantize + pack + store.

    Performs bucketize, centroid gather, residual norm, MSE index packing,
    and value quantization in one kernel — eliminates 4 PyTorch kernel
    launches (bucketize, gather, subtract, norm) per layer vs pack-only.

    One program handles one (token, head) row of the flattened inputs.
    Programs whose slot index is negative return without writing anything.
    Bytes written into the slot, relative to ``slot_base``:

    * ``[0, MSE_BYTES)``: packed centroid indices (only for ``MSE_BITS`` 3
      or 4; other widths write no index bytes),
    * ``MSE_BYTES + 0..1``: key norm as fp16, low byte first,
    * ``MSE_BYTES + 2..3``: residual norm ``gamma`` as fp16, low byte first,
    * from ``KPS``: the quantized value, via ``_store_quantized_value``.
    """
    pid = tl.program_id(0)
    token_idx = pid // H
    head_idx = pid % H
    slot = tl.load(Slot_mapping_ptr + token_idx)
    if slot < 0:
        return
    # Split the flat slot index into (block, position within block).
    blk = slot // BLOCK_SIZE
    off = slot % BLOCK_SIZE
    slot_base = (
        blk * stride_cache_block + off * stride_cache_pos + head_idx * stride_cache_head
    )
    # Element offset of this row in the flattened [NH, D] inputs.
    base = pid * D
    d_offs = tl.arange(0, BLOCK_D)
    d_mask = d_offs < D
    # ── 1. INLINE BUCKETIZE ──────────────────────────────────────────
    y_vec = tl.load(Y_ptr + base + d_offs, mask=d_mask, other=0.0)
    # The bucket index of an element is the number of midpoints it is >= to.
    idx = tl.zeros([BLOCK_D], dtype=tl.int32)
    for i in range(N_CENTROIDS - 1):
        mid_val = tl.load(Midpoints_ptr + i)
        idx += tl.where(y_vec >= mid_val, 1, 0)
    # ── 2. CENTROID GATHER + RESIDUAL NORM ────────────────────────────
    centroid_vals = tl.load(Centroids_ptr + idx, mask=d_mask, other=0.0)
    residual = y_vec - centroid_vals
    # L2 norm of the quantization residual over the valid lanes only.
    gamma = tl.sqrt(tl.sum(tl.where(d_mask, residual * residual, 0.0), axis=0))
    # ── 3. PACK MSE INDICES from register idx ─────────────────────────
    if MSE_BITS == 4:
        # Two 4-bit indices per byte: even element low nibble, odd high nibble.
        idx_pairs = tl.reshape(idx, [BLOCK_D // 2, 2])
        shifts_4 = tl.arange(0, 2) * 4
        packed = tl.sum((idx_pairs & 0xF) << shifts_4[None, :], axis=1).to(tl.uint8)
        mse_offs = tl.arange(0, BLOCK_D // 2)
        mse_mask = mse_offs < MSE_BYTES
        tl.store(KV_cache_ptr + slot_base + mse_offs, packed, mask=mse_mask)
    elif MSE_BITS == 3:
        # Eight 3-bit indices per 24-bit word, emitted as three bytes
        # (least significant first). Only the D // 8 complete groups are stored.
        grp_offs = tl.arange(0, BLOCK_GRP)
        grp_mask = grp_offs < (D // 8)
        idx_grp = tl.reshape(idx, [BLOCK_GRP, 8])
        shifts_3 = tl.arange(0, 8) * 3
        packed_24 = tl.sum((idx_grp & 0x7) << shifts_3[None, :], axis=1)
        b0 = (packed_24 & 0xFF).to(tl.uint8)
        b1 = ((packed_24 >> 8) & 0xFF).to(tl.uint8)
        b2 = ((packed_24 >> 16) & 0xFF).to(tl.uint8)
        tl.store(KV_cache_ptr + slot_base + grp_offs * 3, b0, mask=grp_mask)
        tl.store(KV_cache_ptr + slot_base + grp_offs * 3 + 1, b1, mask=grp_mask)
        tl.store(KV_cache_ptr + slot_base + grp_offs * 3 + 2, b2, mask=grp_mask)
    # ── 4. STORE NORMS (vec_norm + gamma as fp16) ─────────────────────
    norm_offset = MSE_BYTES
    vn_f16 = tl.load(Norms_ptr + pid).to(tl.float16)
    vn_u16 = vn_f16.to(tl.uint16, bitcast=True)
    tl.store(KV_cache_ptr + slot_base + norm_offset, (vn_u16 & 0xFF).to(tl.uint8))
    tl.store(
        KV_cache_ptr + slot_base + norm_offset + 1, ((vn_u16 >> 8) & 0xFF).to(tl.uint8)
    )
    # NOTE(review): gamma occupies bytes MSE_BYTES+2 and MSE_BYTES+3, while the
    # value region written in step 5 starts at KPS. Confirm that KPS reserves
    # these two bytes; otherwise the value store would land on them.
    gm_f16 = gamma.to(tl.float16)
    gm_u16 = gm_f16.to(tl.uint16, bitcast=True)
    tl.store(KV_cache_ptr + slot_base + norm_offset + 2, (gm_u16 & 0xFF).to(tl.uint8))
    tl.store(
        KV_cache_ptr + slot_base + norm_offset + 3, ((gm_u16 >> 8) & 0xFF).to(tl.uint8)
    )
    # ── 5. VALUE QUANTIZE + PACK ──────────────────────────────────────
    _store_quantized_value(
        Value_ptr,
        KV_cache_ptr,
        base,
        slot_base,
        d_offs,
        d_mask,
        D=D,
        KPS=KPS,
        VQB=VQB,
        VAL_DATA_BYTES=VAL_DATA_BYTES,
        BLOCK_VAL=BLOCK_VAL,
        BLOCK_GRP=BLOCK_GRP,
    )
# ═══════════════════════════════════════════════════════════════════════
# Launcher
# ═══════════════════════════════════════════════════════════════════════
def triton_turboquant_store(
    key: torch.Tensor,  # [N, H, D] — raw keys (post-RoPE)
    value: torch.Tensor,  # [N, H, D] — raw values
    kv_cache: torch.Tensor,  # [num_blocks, block_size, Hk, padded_slot] uint8
    slot_mapping: torch.Tensor,  # [N] int32
    PiT: torch.Tensor,  # [D, D] float32
    centroids: torch.Tensor,  # [n_centroids] float32
    midpoints: torch.Tensor,  # [n_centroids-1] float32
    mse_bits: int,
    key_packed_size: int,
    value_quant_bits: int,
    key_fp8: bool = False,
):
    """Launch TQ store kernel — FP8 uses _tq_fused_store_fp8, MSE uses _tq_fused_store_mse.

    Both kernels run one program per (token, head) pair and address their
    inputs as flat, row-major ``[N * H, D]`` arrays (``row * D + d``), so
    every tensor handed to them is made contiguous here. The cache is passed
    as a flat byte view; its block/position/head strides are derived from
    ``kv_cache.shape`` and therefore assume a contiguous cache.

    Args:
        key: raw keys, shape ``[N, H, D]``.
        value: raw values, shape ``[N, H, D]``.
        kv_cache: packed cache that is written in place.
        slot_mapping: per-token slot indices; the kernels skip negative slots.
        PiT: rotation applied to the normalized keys (MSE path only).
        centroids: MSE centroid table (MSE path only).
        midpoints: bucket boundaries between centroids (MSE path only).
        mse_bits: bits per MSE key index.
        key_packed_size: forwarded to the kernels as ``KPS``, the byte offset
            at which the value region of a slot starts.
        value_quant_bits: bits per quantized value element.
        key_fp8: select the FP8 key path instead of the MSE path.

    Returns:
        None; the result is written into ``kv_cache``.
    """
    N, H, D = key.shape
    NH = N * H
    block_size = kv_cache.shape[1]
    num_kv_heads = kv_cache.shape[2]
    padded_slot = kv_cache.shape[3]

    BLOCK_D = triton.next_power_of_2(D)
    mse_bytes = math.ceil(D * mse_bits / 8)
    n_centroids = 2**mse_bits
    val_data_bytes = math.ceil(D * value_quant_bits / 8)
    BLOCK_VAL = triton.next_power_of_2(val_data_bytes)

    # Cache strides
    stride_block = block_size * num_kv_heads * padded_slot
    stride_pos = num_kv_heads * padded_slot
    stride_head = padded_slot

    # Number of 8-element groups used by the 3-bit packers, rounded up to a
    # power of two as the kernels' block shapes require.
    block_grp = triton.next_power_of_2(D // 8) if D >= 8 else 1

    # ── FP8 PATH: in-kernel FP8 cast + scatter via fp8 kernel ──
    if key_fp8:
        k_flat = key.reshape(NH, D).contiguous()
        v_flat = value.reshape(NH, D).contiguous()
        fp8_e4b15 = _use_fp8_e4b15(key.device.index or 0)
        grid = (NH,)
        _tq_fused_store_fp8[grid](
            k_flat,
            v_flat,
            kv_cache.view(-1),
            slot_mapping,
            stride_cache_block=stride_block,
            stride_cache_pos=stride_pos,
            stride_cache_head=stride_head,
            D=D,
            H=H,
            BLOCK_SIZE=block_size,
            BLOCK_D=BLOCK_D,
            KPS=key_packed_size,
            VQB=value_quant_bits,
            VAL_DATA_BYTES=val_data_bytes,
            BLOCK_VAL=BLOCK_VAL,
            BLOCK_GRP=block_grp,
            FP8_E4B15=fp8_e4b15,
            num_warps=4,
            num_stages=1,
        )
        return

    # ── MSE PATH: external GEMM + fused bucketize/pack kernel ──
    # Normalize + rotation GEMM externally (cuBLAS is faster than in-kernel)
    k_flat = key.float().reshape(NH, D)
    norms = k_flat.norm(dim=1, keepdim=True)
    x_hat = k_flat / (norms + 1e-8)
    y = (x_hat @ PiT).contiguous()
    # `.float()` is a no-op on float32 input and `reshape` may return a strided
    # view, so force contiguity here (as the FP8 path does); the kernel reads
    # values at `row * D + d` and would silently mis-read a strided view.
    v_flat = value.float().reshape(NH, D).contiguous()

    # Fused kernel: bucketize + centroid gather + residual norm + pack
    grid = (NH,)
    _tq_fused_store_mse[grid](
        y,
        norms.squeeze(1),
        v_flat,
        centroids,
        midpoints,
        kv_cache.view(-1),
        slot_mapping,
        stride_cache_block=stride_block,
        stride_cache_pos=stride_pos,
        stride_cache_head=stride_head,
        D=D,
        H=H,
        BLOCK_SIZE=block_size,
        BLOCK_D=BLOCK_D,
        MSE_BYTES=mse_bytes,
        KPS=key_packed_size,
        VQB=value_quant_bits,
        VAL_DATA_BYTES=val_data_bytes,
        BLOCK_VAL=BLOCK_VAL,
        MSE_BITS=mse_bits,
        N_CENTROIDS=n_centroids,
        BLOCK_GRP=block_grp,
        num_warps=4,
        num_stages=1,
    )
vllm/v1/core/single_type_kv_cache_manager.py
View file @
c3270a92
...
...
@@ -17,6 +17,7 @@ from vllm.v1.kv_cache_interface import (
MLAAttentionSpec
,
SinkFullAttentionSpec
,
SlidingWindowSpec
,
TQFullAttentionSpec
,
)
from
vllm.v1.request
import
Request
...
...
@@ -51,6 +52,7 @@ class SingleTypeKVCacheManager(ABC):
self
.
kv_cache_spec
=
kv_cache_spec
self
.
block_pool
=
block_pool
self
.
enable_caching
=
enable_caching
# self.new_block_ids: list[int] = []
# Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request
...
...
@@ -204,6 +206,8 @@ class SingleTypeKVCacheManager(ABC):
cdiv
(
num_total_computed_tokens
,
self
.
block_size
)
-
len
(
req_blocks
)
)
req_blocks
.
extend
(
allocated_blocks
)
# if isinstance(self.kv_cache_spec, FullAttentionSpec):
# self.new_block_ids.extend(b.block_id for b in allocated_blocks)
def
allocate_new_blocks
(
self
,
request_id
:
str
,
num_tokens
:
int
,
num_tokens_main_model
:
int
...
...
@@ -230,6 +234,8 @@ class SingleTypeKVCacheManager(ABC):
else
:
new_blocks
=
self
.
block_pool
.
get_new_blocks
(
num_new_blocks
)
req_blocks
.
extend
(
new_blocks
)
# if isinstance(self.kv_cache_spec, FullAttentionSpec):
# self.new_block_ids.extend(b.block_id for b in new_blocks)
return
new_blocks
def
cache_blocks
(
self
,
request
:
Request
,
num_tokens
:
int
)
->
None
:
...
...
@@ -1048,6 +1054,7 @@ class SinkFullAttentionManager(FullAttentionManager):
spec_manager_map
:
dict
[
type
[
KVCacheSpec
],
type
[
SingleTypeKVCacheManager
]]
=
{
FullAttentionSpec
:
FullAttentionManager
,
TQFullAttentionSpec
:
FullAttentionManager
,
MLAAttentionSpec
:
FullAttentionManager
,
SlidingWindowSpec
:
SlidingWindowManager
,
ChunkedLocalAttentionSpec
:
ChunkedLocalAttentionManager
,
...
...
vllm/v1/kv_cache_interface.py
View file @
c3270a92
...
...
@@ -187,6 +187,32 @@ class FullAttentionSpec(AttentionSpec):
)
@dataclass(frozen=True, kw_only=True)
class TQFullAttentionSpec(FullAttentionSpec):
    """Full-attention KV cache spec whose page size follows the TQ slot size.

    Python equivalent of the C++ TQ4FullAttentionSpec. When ``tq_slot_size``
    is positive, a page is sized as ``block_size * num_kv_heads *
    tq_slot_size`` bytes; otherwise the parent class's page size is used.
    """

    # Bytes per (token, kv-head) slot; 0 or less selects the parent formula.
    tq_slot_size: int = 0

    @property
    def real_page_size_bytes(self) -> int:
        """Page size in bytes, TQ-aware when ``tq_slot_size`` is positive."""
        if self.tq_slot_size <= 0:
            return super().real_page_size_bytes
        slots_per_page = self.block_size * self.num_kv_heads
        return slots_per_page * self.tq_slot_size

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """Merge ``specs`` via the parent, then carry over the shared slot size.

        Asserts that every spec uses the same ``tq_slot_size``.
        """
        merged = super().merge(specs)
        mismatched = [
            s for s in specs if not s.tq_slot_size == specs[0].tq_slot_size
        ]
        assert not mismatched, (
            "All TQ layers in the same KV cache group must use the same tq_slot_size."
        )
        shared_slot_size = specs[0].tq_slot_size
        return replace(merged, tq_slot_size=shared_slot_size)
@
dataclass
(
frozen
=
True
,
kw_only
=
True
)
class
MLAAttentionSpec
(
FullAttentionSpec
):
# TODO(Lucas/Chen): less hacky way to do this
...
...
vllm/v1/worker/dp_utils.py
View file @
c3270a92
...
...
@@ -6,6 +6,7 @@ import numpy as np
import
torch
import
torch.distributed
as
dist
import
vllm.envs
as
envs
from
vllm.config
import
ParallelConfig
from
vllm.distributed.parallel_state
import
get_dp_group
from
vllm.logger
import
init_logger
...
...
@@ -208,7 +209,7 @@ def coordinate_batch_across_dp(
]
"""
if
parallel_config
.
data_parallel_size
==
1
:
if
parallel_config
.
data_parallel_size
==
1
or
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
:
# Early exit.
return
False
,
None
,
cudagraph_mode
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
c3270a92
...
...
@@ -189,6 +189,7 @@ from .utils import (
sanity_check_mm_encoder_outputs
,
)
from
vllm.v1.spec_decode.utils
import
DraftProbs
from
vllm.utils.torch_utils
import
async_tensor_h2d
if
TYPE_CHECKING
:
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
...
...
@@ -5117,9 +5118,6 @@ class GPUModelRunner(
inputs_embeds
=
self
.
inputs_embeds
.
gpu
[:
num_tokens_padded
]
model_kwargs
=
self
.
_init_model_kwargs
()
else
:
self
.
input_ids
.
gpu
[:
num_tokens_padded
]
=
torch
.
randint
(
0
,
self
.
model_config
.
get_vocab_size
(),
(
num_tokens_padded
,),
dtype
=
torch
.
int32
)
input_ids
=
self
.
input_ids
.
gpu
[:
num_tokens_padded
]
inputs_embeds
=
None
...
...
@@ -5234,9 +5232,15 @@ class GPUModelRunner(
self
.
eplb_step
(
is_dummy
=
True
,
is_profile
=
is_profile
)
logit_indices
=
np
.
cumsum
(
num_scheduled_tokens
)
-
1
logit_indices_device
=
torch
.
from_numpy
(
logit_indices
).
to
(
self
.
device
,
non_blocking
=
True
)
# logit_indices_device = torch.from_numpy(logit_indices).to(
# self.device, non_blocking=True
# )
logit_indices
=
logit_indices
.
tolist
()
logit_indices_device
=
async_tensor_h2d
(
logit_indices
,
dtype
=
torch
.
int32
,
target_device
=
self
.
device
,
pin_memory
=
True
)
return
hidden_states
,
hidden_states
[
logit_indices_device
]
@
torch
.
inference_mode
()
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment