Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1d353b63
Unverified
Commit
1d353b63
authored
Aug 21, 2025
by
Pavani Majety
Committed by
GitHub
Aug 21, 2025
Browse files
[Core] Always use tensor cores for Flashinfer Decode Wrapper (#23214)
Signed-off-by:
Pavani Majety
<
pmajety@nvidia.com
>
parent
34962746
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
32 additions
and
65 deletions
+32
-65
benchmarks/kernels/benchmark_trtllm_decode_attention.py
benchmarks/kernels/benchmark_trtllm_decode_attention.py
+1
-1
tests/kernels/attention/test_flashinfer.py
tests/kernels/attention/test_flashinfer.py
+2
-4
tests/kernels/attention/test_flashinfer_trtllm_attention.py
tests/kernels/attention/test_flashinfer_trtllm_attention.py
+1
-3
vllm/envs.py
vllm/envs.py
+0
-7
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+28
-50
No files found.
benchmarks/kernels/benchmark_trtllm_decode_attention.py
View file @
1d353b63
...
@@ -110,7 +110,7 @@ def benchmark_decode(
...
@@ -110,7 +110,7 @@ def benchmark_decode(
wrapper
=
flashinfer
.
BatchDecodeWithPagedKVCacheWrapper
(
wrapper
=
flashinfer
.
BatchDecodeWithPagedKVCacheWrapper
(
workspace_buffer
,
workspace_buffer
,
kv_layout
,
kv_layout
,
use_tensor_cores
=
((
num_qo_heads
//
num_kv_heads
)
>
4
)
,
use_tensor_cores
=
True
,
)
)
wrapper
.
plan
(
wrapper
.
plan
(
kv_indptr
,
kv_indptr
,
...
...
tests/kernels/attention/test_flashinfer.py
View file @
1d353b63
...
@@ -137,9 +137,7 @@ def test_flashinfer_decode_with_paged_kv(
...
@@ -137,9 +137,7 @@ def test_flashinfer_decode_with_paged_kv(
workspace_buffer
=
torch
.
empty
(
128
*
1024
*
1024
,
dtype
=
torch
.
int8
)
workspace_buffer
=
torch
.
empty
(
128
*
1024
*
1024
,
dtype
=
torch
.
int8
)
wrapper
=
flashinfer
.
\
wrapper
=
flashinfer
.
\
BatchDecodeWithPagedKVCacheWrapper
(
workspace_buffer
,
"NHD"
,
BatchDecodeWithPagedKVCacheWrapper
(
workspace_buffer
,
"NHD"
,
use_tensor_cores
=
(
use_tensor_cores
=
True
)
(
num_query_heads
//
num_kv_heads
)
>
4
)
)
wrapper
.
plan
(
wrapper
.
plan
(
kv_indptr
,
kv_indptr
,
kv_indices
,
kv_indices
,
...
@@ -411,7 +409,7 @@ def test_flashinfer_decode_with_paged_fp8_kv(
...
@@ -411,7 +409,7 @@ def test_flashinfer_decode_with_paged_fp8_kv(
assert
num_query_heads
%
num_kv_heads
==
0
assert
num_query_heads
%
num_kv_heads
==
0
max_kv_len
=
max
(
kv_lens
)
max_kv_len
=
max
(
kv_lens
)
scale
=
head_size
**-
0.5
scale
=
head_size
**-
0.5
use_tensor_cores
=
(
num_query_heads
//
num_kv_heads
)
>
4
use_tensor_cores
=
True
kv_cache_dtype
=
torch
.
float8_e4m3fn
kv_cache_dtype
=
torch
.
float8_e4m3fn
query
=
torch
.
randn
(
num_seqs
,
num_query_heads
,
head_size
,
dtype
=
dtype
)
query
=
torch
.
randn
(
num_seqs
,
num_query_heads
,
head_size
,
dtype
=
dtype
)
...
...
tests/kernels/attention/test_flashinfer_trtllm_attention.py
View file @
1d353b63
...
@@ -136,9 +136,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
...
@@ -136,9 +136,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
# Baseline Decode
# Baseline Decode
wrapper
=
flashinfer
.
BatchDecodeWithPagedKVCacheWrapper
(
wrapper
=
flashinfer
.
BatchDecodeWithPagedKVCacheWrapper
(
workspace_buffer
,
workspace_buffer
,
kv_layout
,
use_tensor_cores
=
True
)
kv_layout
,
use_tensor_cores
=
((
num_qo_heads
//
num_kv_heads
)
>
4
))
wrapper
.
plan
(
kv_indptr
,
wrapper
.
plan
(
kv_indptr
,
kv_indices
,
kv_indices
,
kv_last_page_lens
,
kv_last_page_lens
,
...
...
vllm/envs.py
View file @
1d353b63
...
@@ -42,7 +42,6 @@ if TYPE_CHECKING:
...
@@ -42,7 +42,6 @@ if TYPE_CHECKING:
VLLM_TRACE_FUNCTION
:
int
=
0
VLLM_TRACE_FUNCTION
:
int
=
0
VLLM_ATTENTION_BACKEND
:
Optional
[
str
]
=
None
VLLM_ATTENTION_BACKEND
:
Optional
[
str
]
=
None
VLLM_USE_FLASHINFER_SAMPLER
:
Optional
[
bool
]
=
None
VLLM_USE_FLASHINFER_SAMPLER
:
Optional
[
bool
]
=
None
VLLM_FLASHINFER_FORCE_TENSOR_CORES
:
bool
=
False
VLLM_PP_LAYER_PARTITION
:
Optional
[
str
]
=
None
VLLM_PP_LAYER_PARTITION
:
Optional
[
str
]
=
None
VLLM_CPU_KVCACHE_SPACE
:
Optional
[
int
]
=
0
VLLM_CPU_KVCACHE_SPACE
:
Optional
[
int
]
=
0
VLLM_CPU_OMP_THREADS_BIND
:
str
=
""
VLLM_CPU_OMP_THREADS_BIND
:
str
=
""
...
@@ -465,11 +464,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -465,11 +464,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda
:
bool
(
int
(
os
.
environ
[
"VLLM_USE_FLASHINFER_SAMPLER"
]))
lambda
:
bool
(
int
(
os
.
environ
[
"VLLM_USE_FLASHINFER_SAMPLER"
]))
if
"VLLM_USE_FLASHINFER_SAMPLER"
in
os
.
environ
else
None
,
if
"VLLM_USE_FLASHINFER_SAMPLER"
in
os
.
environ
else
None
,
# If set, vllm will force flashinfer to use tensor cores;
# otherwise will use heuristic based on model architecture.
"VLLM_FLASHINFER_FORCE_TENSOR_CORES"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_FLASHINFER_FORCE_TENSOR_CORES"
,
"0"
))),
# Pipeline stage partition strategy
# Pipeline stage partition strategy
"VLLM_PP_LAYER_PARTITION"
:
"VLLM_PP_LAYER_PARTITION"
:
lambda
:
os
.
getenv
(
"VLLM_PP_LAYER_PARTITION"
,
None
),
lambda
:
os
.
getenv
(
"VLLM_PP_LAYER_PARTITION"
,
None
),
...
@@ -1221,7 +1215,6 @@ def compute_hash() -> str:
...
@@ -1221,7 +1215,6 @@ def compute_hash() -> str:
"VLLM_USE_AITER_UNIFIED_ATTENTION"
,
"VLLM_USE_AITER_UNIFIED_ATTENTION"
,
"VLLM_ATTENTION_BACKEND"
,
"VLLM_ATTENTION_BACKEND"
,
"VLLM_USE_FLASHINFER_SAMPLER"
,
"VLLM_USE_FLASHINFER_SAMPLER"
,
"VLLM_FLASHINFER_FORCE_TENSOR_CORES"
,
"VLLM_DISABLED_KERNELS"
,
"VLLM_DISABLED_KERNELS"
,
"VLLM_USE_DEEP_GEMM"
,
"VLLM_USE_DEEP_GEMM"
,
"VLLM_USE_TRTLLM_FP4_GEMM"
,
"VLLM_USE_TRTLLM_FP4_GEMM"
,
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
1d353b63
...
@@ -13,7 +13,6 @@ from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
...
@@ -13,7 +13,6 @@ from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
from
flashinfer.decode
import
_get_range_buf
,
trtllm_batch_decode_with_kv_cache
from
flashinfer.decode
import
_get_range_buf
,
trtllm_batch_decode_with_kv_cache
from
flashinfer.prefill
import
trtllm_batch_context_with_kv_cache
from
flashinfer.prefill
import
trtllm_batch_context_with_kv_cache
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionType
)
AttentionType
)
...
@@ -228,8 +227,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -228,8 +227,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
q_data_type
=
self
.
kv_cache_dtype
self
.
q_data_type
=
self
.
kv_cache_dtype
else
:
else
:
self
.
kv_cache_dtype
=
self
.
kv_cache_spec
.
dtype
self
.
kv_cache_dtype
=
self
.
kv_cache_spec
.
dtype
self
.
use_tensor_cores
=
(
envs
.
VLLM_FLASHINFER_FORCE_TENSOR_CORES
or
(
self
.
num_qo_heads
//
self
.
num_kv_heads
>
4
))
self
.
_cascade_wrapper
=
None
# Wrapper for cascade attention
self
.
_cascade_wrapper
=
None
# Wrapper for cascade attention
...
@@ -308,7 +305,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -308,7 +305,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_indptr_buffer
=
paged_kv_indptr
,
paged_kv_indptr_buffer
=
paged_kv_indptr
,
paged_kv_indices_buffer
=
paged_kv_indices
,
paged_kv_indices_buffer
=
paged_kv_indices
,
paged_kv_last_page_len_buffer
=
paged_kv_last_page_len
,
paged_kv_last_page_len_buffer
=
paged_kv_last_page_len
,
use_tensor_cores
=
self
.
use_tensor_cores
)
# Tensor cores are enabled by default because the perf would be
# at least as good as CUDA cores for all attention ops on the latest
# GPUs.
use_tensor_cores
=
True
,
)
# save the decode wrapper
# save the decode wrapper
if
use_cudagraph
:
if
use_cudagraph
:
...
@@ -984,52 +985,29 @@ def fast_plan_decode(
...
@@ -984,52 +985,29 @@ def fast_plan_decode(
self
.
_paged_kv_last_page_len_buf
.
copy_
(
last_page_len_cpu
,
self
.
_paged_kv_last_page_len_buf
.
copy_
(
last_page_len_cpu
,
non_blocking
=
True
)
non_blocking
=
True
)
if
self
.
use_tensor_cores
:
qo_indptr_host
=
_get_range_buf
(
batch_size
+
1
,
"cpu"
)
qo_indptr_host
=
_get_range_buf
(
batch_size
+
1
,
"cpu"
)
try
:
try
:
# Make sure we pass exactly 15 arguments for tensor core version
# Make sure we pass exactly 15 arguments for tensor core version
self
.
_plan_info
=
self
.
_cached_module
.
plan
(
self
.
_plan_info
=
self
.
_cached_module
.
plan
(
self
.
_float_workspace_buffer
,
self
.
_float_workspace_buffer
,
self
.
_int_workspace_buffer
,
self
.
_int_workspace_buffer
,
self
.
_pin_memory_int_workspace_buffer
,
self
.
_pin_memory_int_workspace_buffer
,
qo_indptr_host
,
qo_indptr_host
,
indptr_cpu
,
indptr_cpu
,
seq_lens_cpu
,
seq_lens_cpu
,
batch_size
,
# total_num_rows
batch_size
,
# total_num_rows
batch_size
,
batch_size
,
num_qo_heads
,
num_qo_heads
,
num_kv_heads
,
num_kv_heads
,
page_size
,
page_size
,
self
.
is_cuda_graph_enabled
,
self
.
is_cuda_graph_enabled
,
head_dim
,
head_dim
,
head_dim
,
head_dim
,
False
,
# causal
False
,
# causal
)
)
except
Exception
as
e
:
except
Exception
as
e
:
raise
RuntimeError
(
f
"Error in tensor core plan:
{
e
}
"
)
from
e
raise
RuntimeError
(
f
"Error in tensor core plan:
{
e
}
"
)
from
e
else
:
try
:
# Make sure we pass exactly 15 arguments for standard version
self
.
_plan_info
=
self
.
_cached_module
.
plan
(
self
.
_float_workspace_buffer
,
self
.
_int_workspace_buffer
,
self
.
_pin_memory_int_workspace_buffer
,
indptr_cpu
,
batch_size
,
num_qo_heads
,
num_kv_heads
,
page_size
,
self
.
is_cuda_graph_enabled
,
window_left
,
logits_soft_cap
,
head_dim
,
head_dim
,
torch
.
empty
(
0
,
dtype
=
q_data_type
),
torch
.
empty
(
0
,
dtype
=
kv_data_type
),
)
except
Exception
as
e
:
raise
RuntimeError
(
f
"Error in standard plan:
{
e
}
"
)
from
e
self
.
_pos_encoding_mode
=
pos_encoding_mode
self
.
_pos_encoding_mode
=
pos_encoding_mode
self
.
_window_left
=
window_left
self
.
_window_left
=
window_left
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment