Commit ba0cd35c authored by zhuwenwen's avatar zhuwenwen
Browse files

If kv_cache_dtype == "fp8_e4m3", use non-fused cat + MLA

parent 92058666
......@@ -168,7 +168,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if not envs.VLLM_USE_CAT_MLA:
if not envs.VLLM_USE_CAT_MLA or kv_cache_dtype == "fp8_e4m3":
if envs.VLLM_USE_OPT_CAT:
if q_nope.shape[0] < 1024:
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
......@@ -181,7 +181,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
if not envs.VLLM_USE_CAT_MLA:
if not envs.VLLM_USE_CAT_MLA or kv_cache_dtype == "fp8_e4m3":
o, _ = flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment