Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e5ed6c6c
Unverified
Commit
e5ed6c6c
authored
Mar 20, 2026
by
Kaihang Jiang
Committed by
GitHub
Mar 20, 2026
Browse files
[BugFix] Allow qk_nope_head_dim=192 in FlashInfer MLA backend checks (#37475)
Signed-off-by:
Kaihang Jiang
<
kaihangj@nvidia.com
>
parent
b3d0b379
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
8 deletions
+8
-8
vllm/v1/attention/backends/mla/flashinfer_mla.py
vllm/v1/attention/backends/mla/flashinfer_mla.py
+4
-4
vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py
vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py
+4
-4
No files found.
vllm/v1/attention/backends/mla/flashinfer_mla.py
View file @
e5ed6c6c
...
...
@@ -77,17 +77,17 @@ class FlashInferMLABackend(MLACommonBackend):
use_sparse
:
bool
,
device_capability
:
DeviceCapability
,
)
->
str
|
None
:
# FlashInfer MLA kernel requires qk_nope_head_dim in [64, 128]
# FlashInfer MLA kernel requires qk_nope_head_dim in [64, 128
, 192
]
from
vllm.config
import
get_current_vllm_config
vllm_config
=
get_current_vllm_config
()
if
vllm_config
.
model_config
is
not
None
:
hf_text_config
=
vllm_config
.
model_config
.
hf_text_config
qk_nope_head_dim
=
getattr
(
hf_text_config
,
"qk_nope_head_dim"
,
1
)
if
qk_nope_head_dim
not
in
[
64
,
128
]:
if
qk_nope_head_dim
not
in
[
64
,
128
,
192
]:
return
(
f
"FlashInfer MLA kernel requires qk_nope_head_dim
in [64, 128],
"
f
"but got
{
qk_nope_head_dim
}
"
"FlashInfer MLA kernel requires qk_nope_head_dim "
f
"
in [64, 128, 192],
but got
{
qk_nope_head_dim
}
"
)
return
None
...
...
vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py
View file @
e5ed6c6c
...
...
@@ -113,17 +113,17 @@ class FlashInferMLASparseBackend(AttentionBackend):
use_sparse
:
bool
,
device_capability
:
DeviceCapability
,
)
->
str
|
None
:
# FlashInfer MLA sparse kernel requires qk_nope_head_dim
== 128
# FlashInfer MLA sparse kernel requires qk_nope_head_dim
in [128, 192]
from
vllm.config
import
get_current_vllm_config
vllm_config
=
get_current_vllm_config
()
if
vllm_config
.
model_config
is
not
None
:
hf_text_config
=
vllm_config
.
model_config
.
hf_text_config
qk_nope_head_dim
=
getattr
(
hf_text_config
,
"qk_nope_head_dim"
,
1
)
if
qk_nope_head_dim
!=
128
:
if
qk_nope_head_dim
not
in
[
128
,
192
]
:
return
(
f
"FlashInfer MLA Sparse kernel requires qk_nope_head_dim
== 128,
"
f
"but got
{
qk_nope_head_dim
}
"
"FlashInfer MLA Sparse kernel requires qk_nope_head_dim "
f
"
in [128, 192],
but got
{
qk_nope_head_dim
}
"
)
# Check for index_topk which indicates sparse model
if
not
hasattr
(
hf_text_config
,
"index_topk"
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment