Unverified Commit caf8b1c0 authored by Benjamin Chislett's avatar Benjamin Chislett Committed by GitHub
Browse files

[Bugfix] Fix MTP+FlashInfer crash when trtllm kernels are available but disabled (#26361)


Signed-off-by: default avatarBenjamin Chislett <bchislett@nvidia.com>
Signed-off-by: default avatarBenjamin Chislett <chislett.ben@gmail.com>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 1b86bd8e
...@@ -220,6 +220,8 @@ def force_use_trtllm_attention() -> bool | None: ...@@ -220,6 +220,8 @@ def force_use_trtllm_attention() -> bool | None:
def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool: def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
"""Check if the current configuration supports TRTLLM attention.""" """Check if the current configuration supports TRTLLM attention."""
if force_use_trtllm_attention() is False:
return False
has_trtllm = supports_trtllm_attention() has_trtllm = supports_trtllm_attention()
return has_trtllm and (num_qo_heads % num_kv_heads == 0) return has_trtllm and (num_qo_heads % num_kv_heads == 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment