Commit ad7c14d5 authored by zhuwenwen
Browse files

support fuse cat + q to fp8 + mla

parent ab674544
...@@ -1318,7 +1318,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1318,7 +1318,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
False, False,
1e-6, 1e-6,
) )
else: if has_decode:
q_tensor = torch.randn(q.shape[0], num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, dtype=q.dtype, device=q.device) q_tensor = torch.randn(q.shape[0], num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, dtype=q.dtype, device=q.device)
q_quant = torch.empty_like(q_tensor, dtype=torch.float8_e4m3fn, device=q.device) q_quant = torch.empty_like(q_tensor, dtype=torch.float8_e4m3fn, device=q.device)
q_scale = torch.empty(q.shape[0], dtype=torch.float32, device=q.device) q_scale = torch.empty(q.shape[0], dtype=torch.float32, device=q.device)
......
...@@ -186,7 +186,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -186,7 +186,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
o, _ = flash_mla_with_kvcache_fp8_with_cat( o, _ = flash_mla_with_kvcache_fp8_with_cat(
q_nope=q_nope.unsqueeze(1), q_nope=q_nope.unsqueeze(1),
q_pe=q_pe.unsqueeze(1), q_pe=q_pe.unsqueeze(1),
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 k_cache=kv_c_and_k_pe_cache.unsqueeze(-2).view(torch.float8_e4m3fn), # Add head dim of 1
block_table=attn_metadata.decode.block_table, block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens, cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank, head_dim_v=self.kv_lora_rank,
...@@ -199,32 +199,49 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -199,32 +199,49 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
descale_k = k_scale, descale_k = k_scale,
) )
else: else:
if envs.VLLM_USE_OPT_CAT: if envs.VLLM_USE_CAT_MLA:
if q_nope.shape[0] < 1024: o, _ = flash_mla_with_kvcache_fp8_with_cat(
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode q_nope=q_nope.unsqueeze(1),
q = concat_helper_decode(q_nope, q_pe, dim=2)\ q_pe=q_pe.unsqueeze(1),
.unsqueeze(1) k_cache=kv_c_and_k_pe_cache.unsqueeze(-2).view(torch.float8_e4m3fn), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=attn_metadata.decode.
tile_scheduler_metadata,
num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale,
causal=True,
descale_q = q_scale,
descale_k = k_scale,
)
else:
if envs.VLLM_USE_OPT_CAT:
if q_nope.shape[0] < 1024:
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
q = concat_helper_decode(q_nope, q_pe, dim=2)\
.unsqueeze(1)
else:
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
else: else:
q = torch.cat([q_nope, q_pe], dim=-1)\ q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode) .unsqueeze(1) # Add seqlen dim of 1 (decode)
else: o, _ = flash_mla_with_kvcache_fp8(
q = torch.cat([q_nope, q_pe], dim=-1)\ q=q.to(torch.float8_e4m3fn),
.unsqueeze(1) # Add seqlen dim of 1 (decode) k_cache=kv_c_and_k_pe_cache.unsqueeze(-2).view(torch.float8_e4m3fn), # Add head dim of 1
o, _ = flash_mla_with_kvcache_fp8( block_table=attn_metadata.decode.block_table,
q=q.to(torch.float8_e4m3fn), cache_seqlens=attn_metadata.decode.seq_lens,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2).view(torch.float8_e4m3fn), # Add head dim of 1 head_dim_v=self.kv_lora_rank,
block_table=attn_metadata.decode.block_table, tile_scheduler_metadata=attn_metadata.decode.
cache_seqlens=attn_metadata.decode.seq_lens, tile_scheduler_metadata,
head_dim_v=self.kv_lora_rank, num_splits=attn_metadata.decode.num_splits,
tile_scheduler_metadata=attn_metadata.decode. softmax_scale=self.scale,
tile_scheduler_metadata, causal=True,
num_splits=attn_metadata.decode.num_splits, descale_q=q_scale,
softmax_scale=self.scale, descale_k=k_scale,
causal=True, )
descale_q=q_scale,
descale_k=k_scale,
)
else: else:
if not envs.VLLM_USE_CAT_MLA or kv_cache_dtype == "fp8_e4m3": if not envs.VLLM_USE_CAT_MLA or kv_cache_dtype == "fp8_e4m3":
if envs.VLLM_USE_OPT_CAT: if envs.VLLM_USE_OPT_CAT:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment