Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a1001522
Unverified
Commit
a1001522
authored
Dec 17, 2025
by
Ye (Charlotte) Qi
Committed by
GitHub
Dec 17, 2025
Browse files
[Kernels][FI] Skip trtllm attention when num_kv_heads=1 (#30842)
Signed-off-by:
Ye (Charlotte) Qi
<
yeq@meta.com
>
parent
4c054d89
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
56 additions
and
1 deletion
+56
-1
tests/kernels/attention/test_flashinfer_trtllm_attention.py
tests/kernels/attention/test_flashinfer_trtllm_attention.py
+35
-0
vllm/utils/flashinfer.py
vllm/utils/flashinfer.py
+21
-1
No files found.
tests/kernels/attention/test_flashinfer_trtllm_attention.py
View file @
a1001522
...
@@ -455,3 +455,38 @@ def test_flashinfer_trtllm_prefill_with_baseline(
...
@@ -455,3 +455,38 @@ def test_flashinfer_trtllm_prefill_with_baseline(
torch
.
testing
.
assert_close
(
output
,
output_trtllm
,
atol
=
atol
,
rtol
=
rtol
),
torch
.
testing
.
assert_close
(
output
,
output_trtllm
,
atol
=
atol
,
rtol
=
rtol
),
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
output_trtllm
))
}
"
,
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
output_trtllm
))
}
"
,
)
)
def
test_trtllm_attention_rejects_num_kv_heads_1
()
->
None
:
"""Test that TRTLLM attention correctly rejects num_kv_heads=1.
When num_kv_heads=1 (MQA), the KV cache strides become degenerate
(stride_heads == stride_batch), which causes CUDA's cuTensorMapEncodeTiled
to fail because TMA descriptors cannot handle degenerate 4D tensors with
singleton dimensions.
This test verifies that can_use_trtllm_attention returns False for
num_kv_heads=1 configurations.
"""
from
vllm.utils.flashinfer
import
can_use_trtllm_attention
# num_kv_heads=1 should be rejected
assert
not
can_use_trtllm_attention
(
num_qo_heads
=
64
,
num_kv_heads
=
1
),
(
"can_use_trtllm_attention should return False for num_kv_heads=1"
)
assert
not
can_use_trtllm_attention
(
num_qo_heads
=
32
,
num_kv_heads
=
1
),
(
"can_use_trtllm_attention should return False for num_kv_heads=1"
)
# num_kv_heads > 1 should be accepted (if platform supports it)
# Note: This may return False on non-Blackwell platforms, which is fine
result_kv8
=
can_use_trtllm_attention
(
num_qo_heads
=
64
,
num_kv_heads
=
8
)
result_kv1
=
can_use_trtllm_attention
(
num_qo_heads
=
64
,
num_kv_heads
=
1
)
# Even if platform doesn't support TRTLLM, num_kv_heads=1 should never
# return True when num_kv_heads > 1 returns True
if
result_kv8
:
assert
not
result_kv1
,
(
"If TRTLLM is supported for num_kv_heads=8, "
"it must be rejected for num_kv_heads=1"
)
vllm/utils/flashinfer.py
View file @
a1001522
...
@@ -305,7 +305,18 @@ def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
...
@@ -305,7 +305,18 @@ def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
if
force_use_trtllm_attention
()
is
False
:
if
force_use_trtllm_attention
()
is
False
:
return
False
return
False
has_trtllm
=
supports_trtllm_attention
()
has_trtllm
=
supports_trtllm_attention
()
return
has_trtllm
and
(
num_qo_heads
%
num_kv_heads
==
0
)
# num_kv_heads=1 is not supported due to TMA descriptor building limitations.
# When num_kv_heads=1, the KV cache strides become degenerate (stride_heads ==
# stride_batch), which causes CUDA's cuTensorMapEncodeTiled to fail because
# TMA descriptors cannot handle degenerate 4D tensors with singleton dimensions.
# See: https://fburl.com/352mrydz
if
has_trtllm
and
num_kv_heads
==
1
:
logger
.
warning_once
(
"TRTLLM attention does not support num_kv_heads=1. "
"This configuration causes TMA descriptor building to fail due to "
"degenerate tensor strides. Falling back to FlashInfer attention."
)
return
has_trtllm
and
(
num_qo_heads
%
num_kv_heads
==
0
)
and
(
num_kv_heads
!=
1
)
def
use_trtllm_attention
(
def
use_trtllm_attention
(
...
@@ -355,6 +366,15 @@ def use_trtllm_attention(
...
@@ -355,6 +366,15 @@ def use_trtllm_attention(
)
)
return
False
return
False
# num_kv_heads=1 is not supported
if
num_kv_heads
==
1
:
if
force_use_trtllm
:
logger
.
warning_once
(
"TRTLLM attention does not support num_kv_heads=1, "
"but --attention-config.use_trtllm_attention is set to 1"
)
return
False
if
has_spec
and
not
is_prefill
:
if
has_spec
and
not
is_prefill
:
# Speculative decoding requires TRTLLM attention for decodes
# Speculative decoding requires TRTLLM attention for decodes
logger
.
info_once
(
"Using TRTLLM attention (enabled for speculative decoding)."
)
logger
.
info_once
(
"Using TRTLLM attention (enabled for speculative decoding)."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment