Unverified Commit a076ec1a authored by b8zhong, committed by GitHub

Revert "fix llama4 kv cache layout" (#12437)

parent 72b5f3d0
@@ -21,7 +21,7 @@ The support matrix is split into two parts: MHA (standard attention) and MLA (multi-head latent attention)
| **Triton** | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ |
| **Torch Native (SDPA)** | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| **FlexAttention (PyTorch)** | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| **TRTLLM MHA** | 16, 32 or 64 | | ✅ | ❌ | ❌ | ❌ |
| **Dual Chunk FlashAttention** | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| **AITER (ROCm)** | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
| **Wave (ROCm)** | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
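For context, the backend rows in this matrix are selected at launch time. Below is a minimal sketch of picking the TRTLLM MHA backend through sglang's offline engine API; the model path, page size, and sampling parameters are chosen purely for illustration and are not part of this diff.

```python
# Sketch only: the model path and the sm100 (Blackwell) hardware assumption
# are illustrative, not taken from this commit.
import sglang as sgl

llm = sgl.Engine(
    model_path="meta-llama/Llama-4-Scout-17B-16E-Instruct",  # hypothetical example model
    attention_backend="trtllm_mha",  # one of the backend rows in the matrix above
    page_size=64,  # TRTLLM MHA supports page sizes 16, 32, or 64 per the matrix
)
print(llm.generate("The capital of France is", {"max_new_tokens": 16}))
llm.shutdown()
```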
@@ -980,13 +980,6 @@ class ServerArgs:
             logger.warning(
                 "Use trtllm_mha as attention backend on sm100 for Llama4 model"
             )
-        if is_sm100_supported() and self.attention_backend == "trtllm_mha":
-            # TODO(brayden): remove this once TRTLLM MHA kernel for FP8 w/ tileSizeKv=128 is available.
-            # This is a Llama 4 specific issue only.
-            self.kv_cache_dtype = "bfloat16"
-            logger.warning(
-                "Setting kv_cache_dtype to bfloat16 for Llama4 with trtllm_mha backend, due to a missing FlashInfer TRTLLM MHA kernel for FP8 KV Cache"
-            )
         if is_sm100_supported() and self.moe_runner_backend == "auto":
             if self.quantization in {"fp8", "modelopt_fp8"}:
                 self.moe_runner_backend = "flashinfer_trtllm"
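To make the behavioral change concrete, here is a hedged, standalone distillation of the guard this revert deletes. The function name and signature are hypothetical; only the condition and the forced "bfloat16" value come from the diff above.

```python
# Illustrative distillation, not sglang API: shows what the removed guard did
# for Llama 4 before this revert.
def resolve_kv_cache_dtype(sm100: bool, attention_backend: str, kv_cache_dtype: str) -> str:
    # Pre-revert: Llama 4 with trtllm_mha on sm100 forced a bfloat16 KV cache
    # because the FP8 TRTLLM MHA kernel (tileSizeKv=128) was unavailable.
    if sm100 and attention_backend == "trtllm_mha":
        return "bfloat16"
    # Post-revert: the user-configured KV cache dtype is left untouched.
    return kv_cache_dtype


# Example: on sm100 with trtllm_mha, the old guard overrode the default "auto".
assert resolve_kv_cache_dtype(True, "trtllm_mha", "auto") == "bfloat16"
assert resolve_kv_cache_dtype(False, "trtllm_mha", "auto") == "auto"
```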