Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0885aa25
Commit
0885aa25
authored
Apr 18, 2026
by
wanglong3
Committed by
zhangzbb
Apr 18, 2026
Browse files
[feature][Attention Backend] TurboQuant: 2-bit KV cache compression with 4x capacity #38479
parent
4fca01b8
Changes
27
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
1934 additions
and
0 deletions
+1934
-0
vllm/utils/torch_utils.py
vllm/utils/torch_utils.py
+4
-0
vllm/v1/attention/backends/registry.py
vllm/v1/attention/backends/registry.py
+1
-0
vllm/v1/attention/backends/turboquant_attn.py
vllm/v1/attention/backends/turboquant_attn.py
+812
-0
vllm/v1/attention/ops/triton_turboquant_decode.py
vllm/v1/attention/ops/triton_turboquant_decode.py
+624
-0
vllm/v1/attention/ops/triton_turboquant_store.py
vllm/v1/attention/ops/triton_turboquant_store.py
+460
-0
vllm/v1/core/single_type_kv_cache_manager.py
vllm/v1/core/single_type_kv_cache_manager.py
+7
-0
vllm/v1/kv_cache_interface.py
vllm/v1/kv_cache_interface.py
+26
-0
No files found.
vllm/utils/torch_utils.py
View file @
0885aa25
...
...
@@ -42,6 +42,10 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"int8"
:
torch
.
int8
,
"fp8_inc"
:
torch
.
float8_e4m3fn
,
"fp8_ds_mla"
:
torch
.
uint8
,
"turboquant_k8v4"
:
torch
.
uint8
,
"turboquant_4bit_nc"
:
torch
.
uint8
,
"turboquant_k3v4_nc"
:
torch
.
uint8
,
"turboquant_3bit_nc"
:
torch
.
uint8
,
}
TORCH_DTYPE_TO_NUMPY_DTYPE
=
{
...
...
vllm/v1/attention/backends/registry.py
View file @
0885aa25
...
...
@@ -78,6 +78,7 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
"RocmAiterUnifiedAttentionBackend"
)
CPU_ATTN
=
"vllm.v1.attention.backends.cpu_attn.CPUAttentionBackend"
TURBOQUANT
=
"vllm.v1.attention.backends.turboquant_attn.TurboQuantAttentionBackend"
# Placeholder for third-party/custom backends - must be registered before use
# set to None to avoid alias with other backend, whose value is an empty string
CUSTOM
=
None
...
...
vllm/v1/attention/backends/turboquant_attn.py
0 → 100644
View file @
0885aa25
This diff is collapsed.
Click to expand it.
vllm/v1/attention/ops/triton_turboquant_decode.py
0 → 100644
View file @
0885aa25
This diff is collapsed.
Click to expand it.
vllm/v1/attention/ops/triton_turboquant_store.py
0 → 100644
View file @
0885aa25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused Triton kernels for TurboQuant KV store.
Two kernels:
1. _tq_fused_store_fp8: FP8 key scatter + value uniform quantization.
2. _tq_fused_store_mse: Fused bucketize + centroid gather + residual norm
+ MSE index packing + value quantization (eliminates 4 PyTorch kernel
launches vs the old pack-only approach).
The launcher `triton_turboquant_store` selects the appropriate kernel.
"""
import
math
import
torch
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.attention.ops.triton_turboquant_decode
import
_use_fp8_e4b15
# ═══════════════════════════════════════════════════════════════════════
# Shared: value uniform quantization + pack + scale/zero store
# ═══════════════════════════════════════════════════════════════════════
@triton.jit
def _store_f16_le(dst_ptr, x):
    """Write scalar ``x`` as fp16 into two consecutive bytes, low byte first."""
    bits = x.to(tl.float16).to(tl.uint16, bitcast=True)
    tl.store(dst_ptr, (bits & 0xFF).to(tl.uint8))
    tl.store(dst_ptr + 1, ((bits >> 8) & 0xFF).to(tl.uint8))


@triton.jit
def _store_quantized_value(
    Value_ptr,
    KV_cache_ptr,
    base,  # pid * D offset into Value_ptr
    slot_base,  # byte offset into KV_cache_ptr for this slot+head
    d_offs,  # tl.arange(0, BLOCK_D)
    d_mask,  # d_offs < D
    D: tl.constexpr,
    KPS: tl.constexpr,
    VQB: tl.constexpr,
    VAL_DATA_BYTES: tl.constexpr,
    BLOCK_VAL: tl.constexpr,
    BLOCK_GRP: tl.constexpr,
):
    """Min/max-quantize one value row and write it into its cache slot.

    Bytes written, relative to ``KV_cache_ptr + slot_base``:

    * ``[KPS, KPS + VAL_DATA_BYTES)``: packed integer codes,
    * next 2 bytes: the quantization step (scale) as fp16, low byte first,
    * next 2 bytes: the row minimum (zero point) as fp16, low byte first.

    ``VQB == 3`` packs eight 3-bit codes into three bytes (code ``j`` of a
    group sits at bit ``3 * j``).  Every other ``VQB`` value takes the
    4-bit path: two codes per byte, even element in the low nibble.
    """
    # Row statistics are shared by both bit widths.  Masked lanes load as 0
    # and are replaced by +/-inf so they cannot influence min/max.
    row = tl.load(Value_ptr + base + d_offs, mask=d_mask, other=0.0).to(tl.float32)
    lo = tl.min(tl.where(d_mask, row, float("inf")), axis=0)
    hi = tl.max(tl.where(d_mask, row, -float("inf")), axis=0)

    # Start of the value region inside this slot.
    val_dst = KV_cache_ptr + slot_base + KPS

    if VQB == 3:
        step = (hi - lo) / 7.0
        # Floor the step so a constant row does not divide by zero.
        step = tl.where(step > 1e-8, step, 1e-8)
        # +0.5 then int cast rounds to nearest; clamp to the 3-bit range.
        codes = tl.minimum(
            tl.maximum(((row - lo) / step + 0.5).to(tl.int32), 0), 7
        )

        grp = tl.arange(0, BLOCK_GRP)
        grp_ok = grp < (D // 8)
        bit_pos = tl.arange(0, 8) * 3
        # Eight 3-bit codes -> one 24-bit word per group.
        word = tl.sum(
            tl.reshape(codes, [BLOCK_GRP, 8]) << bit_pos[None, :], axis=1
        )
        grp_dst = val_dst + grp * 3
        tl.store(grp_dst, (word & 0xFF).to(tl.uint8), mask=grp_ok)
        tl.store(grp_dst + 1, ((word >> 8) & 0xFF).to(tl.uint8), mask=grp_ok)
        tl.store(grp_dst + 2, ((word >> 16) & 0xFF).to(tl.uint8), mask=grp_ok)
    else:  # VQB == 4
        step = (hi - lo) / 15.0
        step = tl.where(step > 1e-8, step, 1e-8)

        byte_idx = tl.arange(0, BLOCK_VAL)
        byte_ok = byte_idx < VAL_DATA_BYTES
        even_pos = byte_idx * 2
        odd_pos = even_pos + 1
        # Out-of-range lanes read as the row minimum, i.e. they encode as 0.
        even_v = tl.load(
            Value_ptr + base + even_pos,
            mask=byte_ok & (even_pos < D),
            other=lo,
        )
        odd_v = tl.load(
            Value_ptr + base + odd_pos,
            mask=byte_ok & (odd_pos < D),
            other=lo,
        )
        even_q = tl.minimum(
            tl.maximum(((even_v - lo) / step + 0.5).to(tl.int32), 0), 15
        )
        odd_q = tl.minimum(
            tl.maximum(((odd_v - lo) / step + 0.5).to(tl.int32), 0), 15
        )
        nibbles = (even_q | (odd_q << 4)).to(tl.uint8)
        tl.store(val_dst + byte_idx, nibbles, mask=byte_ok)

    # 4-byte trailer right after the packed codes: fp16 scale, fp16 zero.
    hdr = val_dst + VAL_DATA_BYTES
    _store_f16_le(hdr, step)
    _store_f16_le(hdr + 2, lo)
# ═══════════════════════════════════════════════════════════════════════
# FP8 key store + value uniform quantization
# ═══════════════════════════════════════════════════════════════════════
@triton.jit
def _tq_fused_store_fp8(
    Key_ptr,  # [NH, D] float16/bfloat16 — raw keys
    Value_ptr,  # [NH, D] float16/bfloat16 — raw values
    KV_cache_ptr,  # [total_bytes] uint8 (flattened view)
    Slot_mapping_ptr,  # [N] int32 — per-token slot indices
    # Cache strides (for computing byte offsets)
    stride_cache_block: tl.constexpr,
    stride_cache_pos: tl.constexpr,
    stride_cache_head: tl.constexpr,
    # Dimensions
    D: tl.constexpr,
    H: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_D: tl.constexpr,
    # TQ layout
    KPS: tl.constexpr,
    # Value quantization
    VQB: tl.constexpr,
    VAL_DATA_BYTES: tl.constexpr,
    # Packing block sizes
    BLOCK_VAL: tl.constexpr,
    BLOCK_GRP: tl.constexpr = 16,
    FP8_E4B15: tl.constexpr = 0,  # 1 = e4b15 (Ampere/Ada), 0 = e4nv (Hopper+)
):
    """FP8 key cast+scatter + value uniform quantization.

    One program handles one (token, head) pair: ``pid // H`` is the token and
    ``pid % H`` the head.  The key row is cast to FP8 and written as raw bytes
    at the start of the slot; the value row is then quantized and written by
    ``_store_quantized_value`` at byte offset ``KPS`` within the same slot.

    Tokens whose slot index is negative are skipped without writing anything.
    """
    pid = tl.program_id(0)
    token_idx = pid // H
    head_idx = pid % H

    # Widen to int64 before any multiplication: the cache is addressed as a
    # flat byte array, so ``blk * stride_cache_block`` overflows 32 bits as
    # soon as the cache is larger than 2 GiB when the mapping is int32.
    slot = tl.load(Slot_mapping_ptr + token_idx).to(tl.int64)
    if slot < 0:
        return
    blk = slot // BLOCK_SIZE
    off = slot % BLOCK_SIZE
    slot_base = (
        blk * stride_cache_block + off * stride_cache_pos + head_idx * stride_cache_head
    )
    # Element offset of this (token, head) row in the [NH, D] inputs.
    base = pid * D

    # ── FP8 KEY: cast to FP8 in-kernel and store ─────────────────
    d_offs = tl.arange(0, BLOCK_D)
    d_mask = d_offs < D
    k_vals = tl.load(Key_ptr + base + d_offs, mask=d_mask, other=0.0)
    if FP8_E4B15:
        k_fp8 = k_vals.to(tl.float8e4b15)
    else:
        # The e4nv path goes through float32 before the FP8 cast.
        x_f32 = k_vals.to(tl.float32)
        k_fp8 = x_f32.to(tl.float8e4nv)
    # Reinterpret the FP8 values as bytes; the cache is a uint8 buffer.
    k_bytes = k_fp8.to(tl.uint8, bitcast=True)
    tl.store(KV_cache_ptr + slot_base + d_offs, k_bytes, mask=d_mask)

    # ── VALUE QUANTIZE + PACK ───────────────────────────────────────
    _store_quantized_value(
        Value_ptr,
        KV_cache_ptr,
        base,
        slot_base,
        d_offs,
        d_mask,
        D=D,
        KPS=KPS,
        VQB=VQB,
        VAL_DATA_BYTES=VAL_DATA_BYTES,
        BLOCK_VAL=BLOCK_VAL,
        BLOCK_GRP=BLOCK_GRP,
    )
# ═══════════════════════════════════════════════════════════════════════
# Fused MSE store: bucketize + centroid gather + residual norm + pack
# (eliminates 4 PyTorch kernel launches per layer vs pack-only kernel)
# ═══════════════════════════════════════════════════════════════════════
@triton.jit
def _tq_fused_store_mse(
    # Post-rotation inputs
    Y_ptr,  # [NH, D] float32 — rotated normalized keys (x_hat @ PiT)
    Norms_ptr,  # [NH] float32 — key vector norms (||k||)
    Value_ptr,  # [NH, D] float32 — raw values
    # Quantization tables
    Centroids_ptr,  # [n_centroids] float32
    Midpoints_ptr,  # [n_centroids-1] float32
    # Cache and indexing
    KV_cache_ptr,  # [total_bytes] uint8 (flattened view)
    Slot_mapping_ptr,  # [N] int32 — per-token slot indices
    # Cache strides
    stride_cache_block: tl.constexpr,
    stride_cache_pos: tl.constexpr,
    stride_cache_head: tl.constexpr,
    # Dimensions
    D: tl.constexpr,
    H: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_D: tl.constexpr,
    # TQ layout
    MSE_BYTES: tl.constexpr,
    KPS: tl.constexpr,
    # Value quantization
    VQB: tl.constexpr,
    VAL_DATA_BYTES: tl.constexpr,
    # Packing block sizes
    BLOCK_VAL: tl.constexpr,
    # MSE params
    MSE_BITS: tl.constexpr,
    N_CENTROIDS: tl.constexpr,
    BLOCK_GRP: tl.constexpr = 16,
):
    """Fused MSE quantize + pack + store.

    Performs bucketize, centroid gather, residual norm, MSE index packing,
    and value quantization in one kernel — eliminates 4 PyTorch kernel
    launches (bucketize, gather, subtract, norm) per layer vs pack-only.

    One program handles one (token, head) pair.  Bytes written, relative to
    the start of the slot:

    * ``[0, MSE_BYTES)``: packed centroid indices,
    * ``MSE_BYTES + 0..1``: key norm as fp16, low byte first,
    * ``MSE_BYTES + 2..3``: residual norm (gamma) as fp16, low byte first,
    * from ``KPS``: the quantized value row (see ``_store_quantized_value``).

    Tokens whose slot index is negative are skipped.

    NOTE: index packing is implemented only for ``MSE_BITS`` of 3 or 4; for
    any other value no index bytes are written.
    """
    pid = tl.program_id(0)
    token_idx = pid // H
    head_idx = pid % H

    # Widen to int64 before any multiplication: the cache is addressed as a
    # flat byte array, so ``blk * stride_cache_block`` overflows 32 bits as
    # soon as the cache is larger than 2 GiB when the mapping is int32.
    slot = tl.load(Slot_mapping_ptr + token_idx).to(tl.int64)
    if slot < 0:
        return
    blk = slot // BLOCK_SIZE
    off = slot % BLOCK_SIZE
    slot_base = (
        blk * stride_cache_block + off * stride_cache_pos + head_idx * stride_cache_head
    )
    # Element offset of this (token, head) row in the [NH, D] inputs.
    base = pid * D

    d_offs = tl.arange(0, BLOCK_D)
    d_mask = d_offs < D

    # ── 1. INLINE BUCKETIZE ──────────────────────────────────────────
    # The index of each element is the number of midpoints it is >= to.
    y_vec = tl.load(Y_ptr + base + d_offs, mask=d_mask, other=0.0)
    idx = tl.zeros([BLOCK_D], dtype=tl.int32)
    for i in range(N_CENTROIDS - 1):
        mid_val = tl.load(Midpoints_ptr + i)
        idx += tl.where(y_vec >= mid_val, 1, 0)

    # ── 2. CENTROID GATHER + RESIDUAL NORM ────────────────────────────
    centroid_vals = tl.load(Centroids_ptr + idx, mask=d_mask, other=0.0)
    residual = y_vec - centroid_vals
    # Padding lanes are zeroed so they do not contribute to the norm.
    gamma = tl.sqrt(tl.sum(tl.where(d_mask, residual * residual, 0.0), axis=0))

    # ── 3. PACK MSE INDICES from register idx ─────────────────────────
    if MSE_BITS == 4:
        # Two 4-bit indices per byte, even element in the low nibble.
        idx_pairs = tl.reshape(idx, [BLOCK_D // 2, 2])
        shifts_4 = tl.arange(0, 2) * 4
        packed = tl.sum((idx_pairs & 0xF) << shifts_4[None, :], axis=1).to(tl.uint8)
        mse_offs = tl.arange(0, BLOCK_D // 2)
        mse_mask = mse_offs < MSE_BYTES
        tl.store(KV_cache_ptr + slot_base + mse_offs, packed, mask=mse_mask)
    elif MSE_BITS == 3:
        # Eight 3-bit indices per 24-bit word, emitted as three bytes.
        grp_offs = tl.arange(0, BLOCK_GRP)
        grp_mask = grp_offs < (D // 8)
        idx_grp = tl.reshape(idx, [BLOCK_GRP, 8])
        shifts_3 = tl.arange(0, 8) * 3
        packed_24 = tl.sum((idx_grp & 0x7) << shifts_3[None, :], axis=1)
        b0 = (packed_24 & 0xFF).to(tl.uint8)
        b1 = ((packed_24 >> 8) & 0xFF).to(tl.uint8)
        b2 = ((packed_24 >> 16) & 0xFF).to(tl.uint8)
        tl.store(KV_cache_ptr + slot_base + grp_offs * 3, b0, mask=grp_mask)
        tl.store(KV_cache_ptr + slot_base + grp_offs * 3 + 1, b1, mask=grp_mask)
        tl.store(KV_cache_ptr + slot_base + grp_offs * 3 + 2, b2, mask=grp_mask)

    # ── 4. STORE NORMS (vec_norm + gamma as fp16) ─────────────────────
    # Each fp16 is bit-cast to uint16 and split into two bytes because the
    # cache is a uint8 buffer.
    norm_offset = MSE_BYTES
    vn_f16 = tl.load(Norms_ptr + pid).to(tl.float16)
    vn_u16 = vn_f16.to(tl.uint16, bitcast=True)
    tl.store(KV_cache_ptr + slot_base + norm_offset, (vn_u16 & 0xFF).to(tl.uint8))
    tl.store(
        KV_cache_ptr + slot_base + norm_offset + 1,
        ((vn_u16 >> 8) & 0xFF).to(tl.uint8),
    )
    gm_f16 = gamma.to(tl.float16)
    gm_u16 = gm_f16.to(tl.uint16, bitcast=True)
    tl.store(KV_cache_ptr + slot_base + norm_offset + 2, (gm_u16 & 0xFF).to(tl.uint8))
    tl.store(
        KV_cache_ptr + slot_base + norm_offset + 3,
        ((gm_u16 >> 8) & 0xFF).to(tl.uint8),
    )

    # ── 5. VALUE QUANTIZE + PACK ──────────────────────────────────────
    _store_quantized_value(
        Value_ptr,
        KV_cache_ptr,
        base,
        slot_base,
        d_offs,
        d_mask,
        D=D,
        KPS=KPS,
        VQB=VQB,
        VAL_DATA_BYTES=VAL_DATA_BYTES,
        BLOCK_VAL=BLOCK_VAL,
        BLOCK_GRP=BLOCK_GRP,
    )
# ═══════════════════════════════════════════════════════════════════════
# Launcher
# ═══════════════════════════════════════════════════════════════════════
def triton_turboquant_store(
    key: torch.Tensor,  # [N, H, D] — raw keys (post-RoPE)
    value: torch.Tensor,  # [N, H, D] — raw values
    kv_cache: torch.Tensor,  # [num_blocks, block_size, Hk, padded_slot] uint8
    slot_mapping: torch.Tensor,  # [N] int32
    PiT: torch.Tensor,  # [D, D] float32
    centroids: torch.Tensor,  # [n_centroids] float32
    midpoints: torch.Tensor,  # [n_centroids-1] float32
    mse_bits: int,
    key_packed_size: int,
    value_quant_bits: int,
    key_fp8: bool = False,
) -> None:
    """Launch TQ store kernel — FP8 uses _tq_fused_store_fp8, MSE uses _tq_fused_store_mse.

    Writes the compressed keys and values for ``N`` tokens into ``kv_cache``
    in place; nothing is returned.  One Triton program is launched per
    (token, head) pair.

    Args:
        key: Keys, shape ``[N, H, D]``.
        value: Values, shape ``[N, H, D]``.
        kv_cache: Destination cache, indexed as
            ``[block, position, head, slot_byte]``.  It is passed to the
            kernels through ``view(-1)``, and the byte strides are derived
            from its shape, so it must be viewable as a flat buffer.
        slot_mapping: Per-token slot index; the kernels skip negative slots.
        PiT: Matrix the normalized keys are multiplied by (MSE path only).
        centroids: Centroid table (MSE path only).
        midpoints: Bucket boundaries between centroids (MSE path only).
        mse_bits: Bits per key index; the kernel is told there are
            ``2 ** mse_bits`` centroids.
        key_packed_size: Byte offset of the value region within a slot.
        value_quant_bits: Bits per quantized value element.
        key_fp8: If true, store keys as FP8 instead of MSE indices.
    """
    N, H, D = key.shape
    NH = N * H
    block_size = kv_cache.shape[1]
    num_kv_heads = kv_cache.shape[2]
    padded_slot = kv_cache.shape[3]

    BLOCK_D = triton.next_power_of_2(D)
    mse_bytes = math.ceil(D * mse_bits / 8)
    n_centroids = 2**mse_bits

    val_data_bytes = math.ceil(D * value_quant_bits / 8)
    BLOCK_VAL = triton.next_power_of_2(val_data_bytes)

    # Cache strides (in bytes of the flattened uint8 view)
    stride_block = block_size * num_kv_heads * padded_slot
    stride_pos = num_kv_heads * padded_slot
    stride_head = padded_slot

    # Number of 8-element groups used by the 3-bit packers, padded to a
    # power of two.
    block_grp = triton.next_power_of_2(D // 8) if D >= 8 else 1

    # ── FP8 PATH: in-kernel FP8 cast + scatter via fp8 kernel ──
    if key_fp8:
        k_flat = key.reshape(NH, D).contiguous()
        v_flat = value.reshape(NH, D).contiguous()
        fp8_e4b15 = _use_fp8_e4b15(key.device.index or 0)
        grid = (NH,)
        _tq_fused_store_fp8[grid](
            k_flat,
            v_flat,
            kv_cache.view(-1),
            slot_mapping,
            stride_cache_block=stride_block,
            stride_cache_pos=stride_pos,
            stride_cache_head=stride_head,
            D=D,
            H=H,
            BLOCK_SIZE=block_size,
            BLOCK_D=BLOCK_D,
            KPS=key_packed_size,
            VQB=value_quant_bits,
            VAL_DATA_BYTES=val_data_bytes,
            BLOCK_VAL=BLOCK_VAL,
            BLOCK_GRP=block_grp,
            FP8_E4B15=fp8_e4b15,
            num_warps=4,
            num_stages=1,
        )
        return

    # ── MSE PATH: external GEMM + fused bucketize/pack kernel ──
    # Normalize + rotation GEMM externally (cuBLAS is faster than in-kernel)
    k_flat = key.float().reshape(NH, D)
    norms = k_flat.norm(dim=1, keepdim=True)
    # Epsilon keeps an all-zero key from dividing by zero.
    x_hat = k_flat / (norms + 1e-8)
    y = (x_hat @ PiT).contiguous()
    # The kernel addresses rows as ``pid * D``, so the buffer must be dense.
    # ``float()``/``reshape`` can return a strided view when ``value`` is
    # already float32, hence the explicit ``contiguous()`` (a no-op when the
    # tensor is already dense), matching the FP8 path.
    v_flat = value.float().reshape(NH, D).contiguous()

    # Fused kernel: bucketize + centroid gather + residual norm + pack
    grid = (NH,)
    _tq_fused_store_mse[grid](
        y,
        norms.squeeze(1),
        v_flat,
        centroids,
        midpoints,
        kv_cache.view(-1),
        slot_mapping,
        stride_cache_block=stride_block,
        stride_cache_pos=stride_pos,
        stride_cache_head=stride_head,
        D=D,
        H=H,
        BLOCK_SIZE=block_size,
        BLOCK_D=BLOCK_D,
        MSE_BYTES=mse_bytes,
        KPS=key_packed_size,
        VQB=value_quant_bits,
        VAL_DATA_BYTES=val_data_bytes,
        BLOCK_VAL=BLOCK_VAL,
        MSE_BITS=mse_bits,
        N_CENTROIDS=n_centroids,
        BLOCK_GRP=block_grp,
        num_warps=4,
        num_stages=1,
    )
vllm/v1/core/single_type_kv_cache_manager.py
View file @
0885aa25
...
...
@@ -17,6 +17,7 @@ from vllm.v1.kv_cache_interface import (
MLAAttentionSpec
,
SinkFullAttentionSpec
,
SlidingWindowSpec
,
TQFullAttentionSpec
,
)
from
vllm.v1.request
import
Request
...
...
@@ -51,6 +52,7 @@ class SingleTypeKVCacheManager(ABC):
self
.
kv_cache_spec
=
kv_cache_spec
self
.
block_pool
=
block_pool
self
.
enable_caching
=
enable_caching
# self.new_block_ids: list[int] = []
# Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request
...
...
@@ -204,6 +206,8 @@ class SingleTypeKVCacheManager(ABC):
cdiv
(
num_total_computed_tokens
,
self
.
block_size
)
-
len
(
req_blocks
)
)
req_blocks
.
extend
(
allocated_blocks
)
# if isinstance(self.kv_cache_spec, FullAttentionSpec):
# self.new_block_ids.extend(b.block_id for b in allocated_blocks)
def
allocate_new_blocks
(
self
,
request_id
:
str
,
num_tokens
:
int
,
num_tokens_main_model
:
int
...
...
@@ -230,6 +234,8 @@ class SingleTypeKVCacheManager(ABC):
else
:
new_blocks
=
self
.
block_pool
.
get_new_blocks
(
num_new_blocks
)
req_blocks
.
extend
(
new_blocks
)
# if isinstance(self.kv_cache_spec, FullAttentionSpec):
# self.new_block_ids.extend(b.block_id for b in new_blocks)
return
new_blocks
def
cache_blocks
(
self
,
request
:
Request
,
num_tokens
:
int
)
->
None
:
...
...
@@ -1048,6 +1054,7 @@ class SinkFullAttentionManager(FullAttentionManager):
spec_manager_map
:
dict
[
type
[
KVCacheSpec
],
type
[
SingleTypeKVCacheManager
]]
=
{
FullAttentionSpec
:
FullAttentionManager
,
TQFullAttentionSpec
:
FullAttentionManager
,
MLAAttentionSpec
:
FullAttentionManager
,
SlidingWindowSpec
:
SlidingWindowManager
,
ChunkedLocalAttentionSpec
:
ChunkedLocalAttentionManager
,
...
...
vllm/v1/kv_cache_interface.py
View file @
0885aa25
...
...
@@ -187,6 +187,32 @@ class FullAttentionSpec(AttentionSpec):
)
@dataclass(frozen=True, kw_only=True)
class TQFullAttentionSpec(FullAttentionSpec):
    """Full-attention KV cache spec whose page size follows the TQ slot size.

    Described upstream as the Python counterpart of the C++
    ``TQ4FullAttentionSpec``.  When ``tq_slot_size`` is positive, the page
    size is computed from it rather than from the parent class's formula.
    """

    # Multiplied by block_size and num_kv_heads to get the page size in
    # bytes; a value <= 0 defers to the parent class.
    tq_slot_size: int = 0

    @property
    def real_page_size_bytes(self) -> int:
        """Page size in bytes, using the TQ slot size when one is set."""
        if self.tq_slot_size <= 0:
            return super().real_page_size_bytes
        return self.block_size * self.num_kv_heads * self.tq_slot_size

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """Merge ``specs`` via the parent class, then restore ``tq_slot_size``.

        All specs must share one ``tq_slot_size``; that value is written
        onto the merged spec.
        """
        merged = super().merge(specs)
        reference = specs[0].tq_slot_size
        assert not any(s.tq_slot_size != reference for s in specs), (
            "All TQ layers in the same KV cache group must use the same tq_slot_size."
        )
        return replace(merged, tq_slot_size=reference)
@
dataclass
(
frozen
=
True
,
kw_only
=
True
)
class
MLAAttentionSpec
(
FullAttentionSpec
):
# TODO(Lucas/Chen): less hacky way to do this
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment