Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bbe0574d
Unverified
Commit
bbe0574d
authored
Feb 04, 2026
by
zhanqiuhu
Committed by
GitHub
Feb 05, 2026
Browse files
[Bugfix] Disable TRTLLM attention when KV transfer is enabled (#33192)
Signed-off-by:
Zhanqiu Hu
<
zh338@cornell.edu
>
parent
4d951353
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
0 deletions
+17
-0
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+17
-0
No files found.
vllm/v1/attention/backends/flashinfer.py
View file @
bbe0574d
...
@@ -573,6 +573,20 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -573,6 +573,20 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# try to use fp8 q if kv cache is fp8, and will fall back to model dtype
# try to use fp8 q if kv cache is fp8, and will fall back to model dtype
# if TRTLLM attention kernel is not used when building attn metadata
# if TRTLLM attention kernel is not used when building attn metadata
can_use_trtllm
=
can_use_trtllm_attention
(
self
.
num_qo_heads
,
self
.
num_kv_heads
)
can_use_trtllm
=
can_use_trtllm_attention
(
self
.
num_qo_heads
,
self
.
num_kv_heads
)
# TRTLLM attention requires strictly contiguous KV cache tensors.
# When KV transfer (P/D disaggregation) is enabled, the KV cache may be
# permuted into non-contiguous views, which causes assertion failures.
self
.
_kv_transfer_enabled
=
vllm_config
.
kv_transfer_config
is
not
None
if
can_use_trtllm
and
self
.
_kv_transfer_enabled
:
logger
.
info_once
(
"TRTLLM attention is disabled because KV transfer "
"(P/D disaggregation) is enabled. TRTLLM attention requires "
"strictly contiguous KV cache tensors which may not be "
"guaranteed with KV transfer."
)
can_use_trtllm
=
False
if
(
if
(
can_use_trtllm
can_use_trtllm
and
not
vllm_config
.
attention_config
.
disable_flashinfer_q_quantization
and
not
vllm_config
.
attention_config
.
disable_flashinfer_q_quantization
...
@@ -822,6 +836,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -822,6 +836,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
has_sinks
=
self
.
has_sinks
,
has_sinks
=
self
.
has_sinks
,
has_spec
=
uses_spec_reorder
,
has_spec
=
uses_spec_reorder
,
)
)
# KV transfer may permute the KV cache into non-contiguous views, which TRTLLM attention (requiring strictly contiguous KV cache tensors) cannot handle
if
self
.
_kv_transfer_enabled
:
prefill_use_trtllm
=
False
decode_use_trtllm
=
(
decode_use_trtllm
=
(
self
.
use_trtllm_decode_attention
and
self
.
dcp_world_size
<=
1
self
.
use_trtllm_decode_attention
and
self
.
dcp_world_size
<=
1
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment