Unverified commit e5ed6c6c, authored by Kaihang Jiang, committed by GitHub
Browse files

[BugFix] Allow qk_nope_head_dim=192 in FlashInfer MLA backend checks (#37475)


Signed-off-by: Kaihang Jiang <kaihangj@nvidia.com>
parent b3d0b379
...@@ -77,17 +77,17 @@ class FlashInferMLABackend(MLACommonBackend): ...@@ -77,17 +77,17 @@ class FlashInferMLABackend(MLACommonBackend):
use_sparse: bool, use_sparse: bool,
device_capability: DeviceCapability, device_capability: DeviceCapability,
) -> str | None: ) -> str | None:
# FlashInfer MLA kernel requires qk_nope_head_dim in [64, 128] # FlashInfer MLA kernel requires qk_nope_head_dim in [64, 128, 192]
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
if vllm_config.model_config is not None: if vllm_config.model_config is not None:
hf_text_config = vllm_config.model_config.hf_text_config hf_text_config = vllm_config.model_config.hf_text_config
qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1) qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
if qk_nope_head_dim not in [64, 128]: if qk_nope_head_dim not in [64, 128, 192]:
return ( return (
f"FlashInfer MLA kernel requires qk_nope_head_dim in [64, 128], " "FlashInfer MLA kernel requires qk_nope_head_dim "
f"but got {qk_nope_head_dim}" f"in [64, 128, 192], but got {qk_nope_head_dim}"
) )
return None return None
......
...@@ -113,17 +113,17 @@ class FlashInferMLASparseBackend(AttentionBackend): ...@@ -113,17 +113,17 @@ class FlashInferMLASparseBackend(AttentionBackend):
use_sparse: bool, use_sparse: bool,
device_capability: DeviceCapability, device_capability: DeviceCapability,
) -> str | None: ) -> str | None:
# FlashInfer MLA sparse kernel requires qk_nope_head_dim == 128 # FlashInfer MLA sparse kernel requires qk_nope_head_dim in [128, 192]
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
if vllm_config.model_config is not None: if vllm_config.model_config is not None:
hf_text_config = vllm_config.model_config.hf_text_config hf_text_config = vllm_config.model_config.hf_text_config
qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1) qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
if qk_nope_head_dim != 128: if qk_nope_head_dim not in [128, 192]:
return ( return (
f"FlashInfer MLA Sparse kernel requires qk_nope_head_dim == 128, " "FlashInfer MLA Sparse kernel requires qk_nope_head_dim "
f"but got {qk_nope_head_dim}" f"in [128, 192], but got {qk_nope_head_dim}"
) )
# Check for index_topk which indicates sparse model # Check for index_topk which indicates sparse model
if not hasattr(hf_text_config, "index_topk"): if not hasattr(hf_text_config, "index_topk"):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment