Commit ad7c14d5 authored by zhuwenwen's avatar zhuwenwen
Browse files

support fuse cat + q to fp8 + mla

parent ab674544
...@@ -1318,7 +1318,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1318,7 +1318,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
False, False,
1e-6, 1e-6,
) )
else: if has_decode:
q_tensor = torch.randn(q.shape[0], num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, dtype=q.dtype, device=q.device) q_tensor = torch.randn(q.shape[0], num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, dtype=q.dtype, device=q.device)
q_quant = torch.empty_like(q_tensor, dtype=torch.float8_e4m3fn, device=q.device) q_quant = torch.empty_like(q_tensor, dtype=torch.float8_e4m3fn, device=q.device)
q_scale = torch.empty(q.shape[0], dtype=torch.float32, device=q.device) q_scale = torch.empty(q.shape[0], dtype=torch.float32, device=q.device)
......
...@@ -186,7 +186,24 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -186,7 +186,24 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
o, _ = flash_mla_with_kvcache_fp8_with_cat( o, _ = flash_mla_with_kvcache_fp8_with_cat(
q_nope=q_nope.unsqueeze(1), q_nope=q_nope.unsqueeze(1),
q_pe=q_pe.unsqueeze(1), q_pe=q_pe.unsqueeze(1),
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 k_cache=kv_c_and_k_pe_cache.unsqueeze(-2).view(torch.float8_e4m3fn), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=attn_metadata.decode.
tile_scheduler_metadata,
num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale,
causal=True,
descale_q = q_scale,
descale_k = k_scale,
)
else:
if envs.VLLM_USE_CAT_MLA:
o, _ = flash_mla_with_kvcache_fp8_with_cat(
q_nope=q_nope.unsqueeze(1),
q_pe=q_pe.unsqueeze(1),
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2).view(torch.float8_e4m3fn), # Add head dim of 1
block_table=attn_metadata.decode.block_table, block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens, cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank, head_dim_v=self.kv_lora_rank,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment