Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e42af78b
Unverified
Commit
e42af78b
authored
Sep 11, 2025
by
Xiaozhu Meng
Committed by
GitHub
Sep 11, 2025
Browse files
[flashinfer] [kernel] support for fp8 kv cache for trtllm prefill attention (#24197)
Signed-off-by:
Xiaozhu
<
mxz297@gmail.com
>
parent
074854b2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
121 additions
and
14 deletions
+121
-14
vllm/envs.py
vllm/envs.py
+6
-0
vllm/utils/flashinfer.py
vllm/utils/flashinfer.py
+6
-5
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+109
-9
No files found.
vllm/envs.py
View file @
e42af78b
...
@@ -163,6 +163,7 @@ if TYPE_CHECKING:
...
@@ -163,6 +163,7 @@ if TYPE_CHECKING:
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE
:
bool
=
False
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE
:
bool
=
False
VLLM_ENABLE_RESPONSES_API_STORE
:
bool
=
False
VLLM_ENABLE_RESPONSES_API_STORE
:
bool
=
False
VLLM_USE_TRTLLM_ATTENTION
:
Optional
[
str
]
=
None
VLLM_USE_TRTLLM_ATTENTION
:
Optional
[
str
]
=
None
VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION
:
bool
=
False
VLLM_HAS_FLASHINFER_CUBIN
:
bool
=
False
VLLM_HAS_FLASHINFER_CUBIN
:
bool
=
False
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
:
bool
=
False
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
:
bool
=
False
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
:
bool
=
False
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
:
bool
=
False
...
@@ -1155,6 +1156,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1155,6 +1156,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_TRTLLM_ATTENTION"
:
"VLLM_USE_TRTLLM_ATTENTION"
:
lambda
:
os
.
getenv
(
"VLLM_USE_TRTLLM_ATTENTION"
,
None
),
lambda
:
os
.
getenv
(
"VLLM_USE_TRTLLM_ATTENTION"
,
None
),
# If set to 1, when we use fp8 kv, we do not quantize Q to fp8
"VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION"
,
"0"
))),
# If set, it means we pre-downloaded cubin files and flashinfer will
# If set, it means we pre-downloaded cubin files and flashinfer will
# read the cubin files directly.
# read the cubin files directly.
"VLLM_HAS_FLASHINFER_CUBIN"
:
"VLLM_HAS_FLASHINFER_CUBIN"
:
...
@@ -1310,6 +1315,7 @@ def compute_hash() -> str:
...
@@ -1310,6 +1315,7 @@ def compute_hash() -> str:
"VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"
,
"VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"
,
"VLLM_USE_CUDNN_PREFILL"
,
"VLLM_USE_CUDNN_PREFILL"
,
"VLLM_USE_TRTLLM_ATTENTION"
,
"VLLM_USE_TRTLLM_ATTENTION"
,
"VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION"
,
"VLLM_ROCM_USE_AITER"
,
"VLLM_ROCM_USE_AITER"
,
"VLLM_ROCM_USE_AITER_PAGED_ATTN"
,
"VLLM_ROCM_USE_AITER_PAGED_ATTN"
,
"VLLM_ROCM_USE_AITER_LINEAR"
,
"VLLM_ROCM_USE_AITER_LINEAR"
,
...
...
vllm/utils/flashinfer.py
View file @
e42af78b
...
@@ -200,11 +200,6 @@ def use_trtllm_attention(
...
@@ -200,11 +200,6 @@ def use_trtllm_attention(
logger
.
info_once
(
"Using TRTLLM attention (query is quantized)."
)
logger
.
info_once
(
"Using TRTLLM attention (query is quantized)."
)
return
True
return
True
# TRTLLM prefill attention does not support FP8 kv cache with
# non-quantized query
if
is_prefill
and
kv_cache_dtype
.
startswith
(
"fp8"
):
return
False
# If sinks are being used, we must use TRTLLM attention as it's
# If sinks are being used, we must use TRTLLM attention as it's
# the only backend that supports them
# the only backend that supports them
if
has_sinks
:
if
has_sinks
:
...
@@ -353,6 +348,12 @@ def flashinfer_scaled_fp8_mm(
...
@@ -353,6 +348,12 @@ def flashinfer_scaled_fp8_mm(
return
output
return
output
@
functools
.
cache
def
flashinfer_disable_q_quantization
()
->
bool
:
"""Cache result which only depends on the environment"""
return
envs
.
VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION
__all__
=
[
__all__
=
[
"has_flashinfer"
,
"has_flashinfer"
,
"flashinfer_trtllm_fp8_block_scale_moe"
,
"flashinfer_trtllm_fp8_block_scale_moe"
,
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
e42af78b
...
@@ -25,7 +25,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
...
@@ -25,7 +25,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils
import
cdiv
,
is_pin_memory_available
from
vllm.utils
import
cdiv
,
is_pin_memory_available
from
vllm.utils.flashinfer
import
(
supports_trtllm_attention
,
from
vllm.utils.flashinfer
import
(
flashinfer_disable_q_quantization
,
supports_trtllm_attention
,
use_trtllm_attention
)
use_trtllm_attention
)
from
vllm.v1.attention.backends.flash_attn
import
use_cascade_attention
from
vllm.v1.attention.backends.flash_attn
import
use_cascade_attention
# yapf conflicts with isort for this block
# yapf conflicts with isort for this block
...
@@ -48,8 +49,89 @@ FP4_DTYPE = torch.uint8
...
@@ -48,8 +49,89 @@ FP4_DTYPE = torch.uint8
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
class
FlashInferBackend
(
AttentionBackend
):
@
triton
.
jit
def
_trtllm_prefill_attn_kvfp8_dequant
(
kv_cache_ptr
,
block_tables_prefill_ptr
,
block_table_stride
,
mock_kv_cache_ptr
,
k_scale_ptr
,
v_scale_ptr
,
K_CACHE_STRIDE
:
tl
.
constexpr
,
KV_CACHE_STRIDE
:
tl
.
constexpr
,
):
batch_idx
=
tl
.
program_id
(
0
).
to
(
tl
.
int64
)
mock_block_table_idx
=
tl
.
program_id
(
1
).
to
(
tl
.
int64
)
orig_page_num
=
tl
.
load
(
block_tables_prefill_ptr
+
batch_idx
*
block_table_stride
+
mock_block_table_idx
).
to
(
tl
.
int64
)
if
orig_page_num
<=
0
:
return
dequant_dtype
=
mock_kv_cache_ptr
.
dtype
.
element_ty
# Dequantize K
k_scale_val
=
tl
.
load
(
k_scale_ptr
)
offset
=
orig_page_num
*
KV_CACHE_STRIDE
+
tl
.
arange
(
0
,
K_CACHE_STRIDE
)
fp8_vals
=
tl
.
load
(
kv_cache_ptr
+
offset
)
dequantized_vals
=
fp8_vals
.
to
(
tl
.
float32
)
*
k_scale_val
mock_cache_offset
=
(
batch_idx
*
block_table_stride
+
mock_block_table_idx
+
1
)
*
KV_CACHE_STRIDE
+
tl
.
arange
(
0
,
K_CACHE_STRIDE
)
dequantized_vals
=
dequantized_vals
.
to
(
dequant_dtype
)
tl
.
store
(
mock_kv_cache_ptr
+
mock_cache_offset
,
dequantized_vals
)
# Dequantize V
v_scale_val
=
tl
.
load
(
v_scale_ptr
)
offset
=
(
orig_page_num
*
KV_CACHE_STRIDE
+
K_CACHE_STRIDE
+
tl
.
arange
(
0
,
K_CACHE_STRIDE
))
fp8_vals
=
tl
.
load
(
kv_cache_ptr
+
offset
)
dequantized_vals
=
fp8_vals
.
to
(
tl
.
float32
)
*
v_scale_val
mock_cache_offset
=
(
(
batch_idx
*
block_table_stride
+
mock_block_table_idx
+
1
)
*
KV_CACHE_STRIDE
+
K_CACHE_STRIDE
+
tl
.
arange
(
0
,
K_CACHE_STRIDE
))
dequantized_vals
=
dequantized_vals
.
to
(
dequant_dtype
)
tl
.
store
(
mock_kv_cache_ptr
+
mock_cache_offset
,
dequantized_vals
)
def
trtllm_prefill_attn_kvfp8_dequant
(
kv_cache
:
torch
.
Tensor
,
block_tables_prefill
:
torch
.
Tensor
,
k_scale
:
torch
.
Tensor
,
v_scale
:
torch
.
Tensor
,
dequant_dtype
:
torch
.
dtype
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
batch_size
,
num_of_page_per_token
=
block_tables_prefill
.
shape
s
=
kv_cache
.
shape
assert
s
[
1
]
==
2
assert
dequant_dtype
in
(
torch
.
bfloat16
,
torch
.
float16
)
k_cache_stride
=
s
[
2
]
*
s
[
3
]
*
s
[
4
]
kv_cache_stride
=
k_cache_stride
*
s
[
1
]
new_s
=
(
batch_size
*
num_of_page_per_token
+
1
,
s
[
1
],
s
[
2
],
s
[
3
],
s
[
4
])
# mock kv cache contains just the pages needed by this prefill
mock_kv_cache
=
torch
.
empty
(
new_s
,
dtype
=
dequant_dtype
,
device
=
kv_cache
.
device
)
# we simply sequentially index the pages needed by this prefill
mock_block_table
=
torch
.
arange
(
start
=
1
,
end
=
batch_size
*
num_of_page_per_token
+
1
,
dtype
=
torch
.
int32
,
device
=
block_tables_prefill
.
device
,
).
reshape
(
batch_size
,
num_of_page_per_token
)
grid
=
(
batch_size
,
num_of_page_per_token
)
_trtllm_prefill_attn_kvfp8_dequant
[
grid
](
kv_cache
,
block_tables_prefill
,
num_of_page_per_token
,
mock_kv_cache
,
k_scale
,
v_scale
,
k_cache_stride
,
kv_cache_stride
,
)
return
mock_kv_cache
,
mock_block_table
class
FlashInferBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
accept_output_buffer
:
bool
=
True
@
classmethod
@
classmethod
...
@@ -122,7 +204,6 @@ class FlashInferBackend(AttentionBackend):
...
@@ -122,7 +204,6 @@ class FlashInferBackend(AttentionBackend):
@
dataclass
@
dataclass
class
FlashInferMetadata
:
class
FlashInferMetadata
:
num_actual_tokens
:
int
# Number of tokens excluding padding.
num_actual_tokens
:
int
# Number of tokens excluding padding.
# The data type of the query
# The data type of the query
...
@@ -175,8 +256,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -175,8 +256,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
kv_cache_spec
.
block_size
)
self
.
kv_cache_spec
.
block_size
)
max_num_reqs
=
vllm_config
.
scheduler_config
.
max_num_seqs
max_num_reqs
=
vllm_config
.
scheduler_config
.
max_num_seqs
max_num_pages
=
max_num_reqs
*
max_num_pages_per_req
max_num_pages
=
max_num_reqs
*
max_num_pages_per_req
self
.
enable_cuda_graph
=
self
.
compilation_config
.
cudagraph_mode
.
\
self
.
enable_cuda_graph
=
(
self
.
compilation_config
.
cudagraph_mode
.
\
decode_mode
()
==
CUDAGraphMode
.
FULL
decode_mode
()
==
CUDAGraphMode
.
FULL
)
if
self
.
enable_cuda_graph
:
if
self
.
enable_cuda_graph
:
# For full cudagraph capture, one `decode_wrapper` for each batch
# For full cudagraph capture, one `decode_wrapper` for each batch
# size is needed for FlashInfer.
# size is needed for FlashInfer.
...
@@ -201,7 +282,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -201,7 +282,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
assert
self
.
kv_cache_spec
.
dtype
==
self
.
model_config
.
dtype
assert
self
.
kv_cache_spec
.
dtype
==
self
.
model_config
.
dtype
self
.
kv_cache_dtype
=
self
.
kv_cache_spec
.
dtype
self
.
kv_cache_dtype
=
self
.
kv_cache_spec
.
dtype
if
supports_trtllm_attention
()[
0
]:
if
supports_trtllm_attention
()[
0
]
and
\
not
flashinfer_disable_q_quantization
():
self
.
q_data_type
=
self
.
kv_cache_dtype
self
.
q_data_type
=
self
.
kv_cache_dtype
else
:
else
:
self
.
q_data_type
=
self
.
model_config
.
dtype
self
.
q_data_type
=
self
.
model_config
.
dtype
...
@@ -795,11 +877,29 @@ class FlashInferImpl(AttentionImpl):
...
@@ -795,11 +877,29 @@ class FlashInferImpl(AttentionImpl):
assert
self
.
o_sf_scale
is
None
assert
self
.
o_sf_scale
is
None
out
=
output
[
num_decode_tokens
:]
out
=
output
[
num_decode_tokens
:]
if
attn_metadata
.
q_data_type
!=
FP8_DTYPE
\
and
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
# TRTLLM prefill attention does not support BF16 Q
# and fp8 kv cache. So to enable prefill attention
# with fp8 kv cache, we can construct a mock block
# and mock kv cache with BF16 KV involved in the prefill
mock_kv_cache
,
mock_block_table
=
(
trtllm_prefill_attn_kvfp8_dequant
(
kv_cache_permute
,
block_tables_prefill
,
layer
.
_k_scale
,
layer
.
_v_scale
,
attn_metadata
.
q_data_type
,
))
else
:
mock_kv_cache
=
kv_cache_permute
mock_block_table
=
block_tables_prefill
trtllm_batch_context_with_kv_cache
(
trtllm_batch_context_with_kv_cache
(
query
=
prefill_query
,
query
=
prefill_query
,
kv_cache
=
kv_cache
_permute
,
kv_cache
=
mock_
kv_cache
,
workspace_buffer
=
workspace_buffer
,
workspace_buffer
=
workspace_buffer
,
block_tables
=
block_table
s_prefill
,
block_tables
=
mock_
block_table
,
seq_lens
=
seq_lens_prefill
,
seq_lens
=
seq_lens_prefill
,
max_q_len
=
attn_metadata
.
max_q_len
,
max_q_len
=
attn_metadata
.
max_q_len
,
max_kv_len
=
attn_metadata
.
max_seq_len
,
max_kv_len
=
attn_metadata
.
max_seq_len
,
...
@@ -837,7 +937,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -837,7 +937,7 @@ class FlashInferImpl(AttentionImpl):
decode_query
=
decode_query
.
contiguous
()
decode_query
=
decode_query
.
contiguous
()
workspace_buffer
=
decode_wrapper
.
_float_workspace_buffer
workspace_buffer
=
decode_wrapper
.
_float_workspace_buffer
block_tables_decode
=
attn_metadata
.
\
block_tables_decode
=
attn_metadata
.
\
block_table_tensor
[:
num_decode_tokens
]
block_table_tensor
[:
num_decode_tokens
]
seq_lens_decode
=
attn_metadata
.
seq_lens
[:
num_decode_tokens
]
seq_lens_decode
=
attn_metadata
.
seq_lens
[:
num_decode_tokens
]
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment