Unverified commit 1c46dea0, authored by shyeh25 and committed by GitHub
Browse files

Revert "[Kernels][FI] Skip trtllm attention when num_kv_heads=1 (#308… (#31617)


Signed-off-by: shyeh25 <206795756+shyeh25@users.noreply.github.com>
parent 02859973
...@@ -456,38 +456,3 @@ def test_flashinfer_trtllm_prefill_with_baseline( ...@@ -456,38 +456,3 @@ def test_flashinfer_trtllm_prefill_with_baseline(
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol),
f"{torch.max(torch.abs(output - output_trtllm))}", f"{torch.max(torch.abs(output - output_trtllm))}",
) )
def test_trtllm_attention_rejects_num_kv_heads_1(default_vllm_config) -> None:
    """Check that ``can_use_trtllm_attention`` refuses num_kv_heads=1 (MQA).

    With a single KV head the KV cache strides become degenerate
    (stride_heads == stride_batch). CUDA's cuTensorMapEncodeTiled then
    fails, because TMA descriptors cannot handle degenerate 4D tensors
    with singleton dimensions.

    The test therefore expects ``can_use_trtllm_attention`` to return
    False for every num_kv_heads=1 configuration it probes.
    """
    from vllm.utils.flashinfer import can_use_trtllm_attention

    # A single KV head must be refused whatever the query-head count is.
    for qo_heads in (64, 32):
        assert not can_use_trtllm_attention(
            num_qo_heads=qo_heads, num_kv_heads=1
        ), "can_use_trtllm_attention should return False for num_kv_heads=1"

    # num_kv_heads > 1 should be accepted when the platform supports TRTLLM.
    # On non-Blackwell platforms this may legitimately be False.
    multi_kv_allowed = can_use_trtllm_attention(num_qo_heads=64, num_kv_heads=8)
    single_kv_allowed = can_use_trtllm_attention(num_qo_heads=64, num_kv_heads=1)

    # Nothing further to compare when TRTLLM is unavailable for 8 KV heads.
    if not multi_kv_allowed:
        return

    # When TRTLLM is usable for 8 KV heads, 1 KV head must still be refused.
    assert not single_kv_allowed, (
        "If TRTLLM is supported for num_kv_heads=8, "
        "it must be rejected for num_kv_heads=1"
    )
...@@ -305,18 +305,7 @@ def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool: ...@@ -305,18 +305,7 @@ def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
if force_use_trtllm_attention() is False: if force_use_trtllm_attention() is False:
return False return False
has_trtllm = supports_trtllm_attention() has_trtllm = supports_trtllm_attention()
# num_kv_heads=1 is not supported due to TMA descriptor building limitations. return has_trtllm and (num_qo_heads % num_kv_heads == 0)
# When num_kv_heads=1, the KV cache strides become degenerate (stride_heads ==
# stride_batch), which causes CUDA's cuTensorMapEncodeTiled to fail because
# TMA descriptors cannot handle degenerate 4D tensors with singleton dimensions.
# See: https://fburl.com/352mrydz
if has_trtllm and num_kv_heads == 1:
logger.warning_once(
"TRTLLM attention does not support num_kv_heads=1. "
"This configuration causes TMA descriptor building to fail due to "
"degenerate tensor strides. Falling back to FlashInfer attention."
)
return has_trtllm and (num_qo_heads % num_kv_heads == 0) and (num_kv_heads != 1)
def use_trtllm_attention( def use_trtllm_attention(
...@@ -366,15 +355,6 @@ def use_trtllm_attention( ...@@ -366,15 +355,6 @@ def use_trtllm_attention(
) )
return False return False
# num_kv_heads=1 is not supported
if num_kv_heads == 1:
if force_use_trtllm:
logger.warning_once(
"TRTLLM attention does not support num_kv_heads=1, "
"but --attention-config.use_trtllm_attention is set to 1"
)
return False
if has_spec and not is_prefill: if has_spec and not is_prefill:
# Speculative decoding requires TRTLLM attention for decodes # Speculative decoding requires TRTLLM attention for decodes
logger.info_once("Using TRTLLM attention (enabled for speculative decoding).") logger.info_once("Using TRTLLM attention (enabled for speculative decoding).")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment