Commit 18341f14 authored by zhuwenwen's avatar zhuwenwen
Browse files

support fuse cat + q to fp8 + mla

parent bac201c9
......@@ -277,6 +277,61 @@ def flash_mla_with_kvcache_fp8(
)
return out, softmax_lse
def flash_mla_with_kvcache_fp8_with_cat(
q_nope: torch.Tensor,
q_pe: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
q_nope: (batch_size, seq_len_q, num_heads_q, 512).
q_pe: (batch_size, seq_len_q, num_heads_q, 64).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
if softmax_scale is None:
softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8_with_cat(
q_nope,
q_pe,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
descale_q,
descale_k,
)
return out, softmax_lse
#
# TODO: Add fake functions
#
......
......@@ -12,6 +12,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
flash_mla_with_kvcache_q_nope_pe,
get_mla_metadata,
flash_mla_with_kvcache_fp8,
flash_mla_with_kvcache_fp8_with_cat,
get_mla_decoding_metadata_dense_fp8,
is_flashmla_supported)
from vllm.logger import init_logger
......@@ -181,32 +182,48 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert attn_metadata.decode is not None
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8:
if envs.VLLM_USE_OPT_CAT:
if q_nope.shape[0] < 1024:
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
q = concat_helper_decode(q_nope, q_pe, dim=2)\
.unsqueeze(1)
if envs.VLLM_USE_CAT_MLA:
o, _ = flash_mla_with_kvcache_fp8_with_cat(
q_nope=q_nope.unsqueeze(1),
q_pe=q_pe.unsqueeze(1),
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2).view(torch.float8_e4m3fn), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=attn_metadata.decode.
tile_scheduler_metadata,
num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale,
causal=True,
descale_q = q_scale,
descale_k = k_scale,
)
else:
if envs.VLLM_USE_OPT_CAT:
if q_nope.shape[0] < 1024:
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
q = concat_helper_decode(q_nope, q_pe, dim=2)\
.unsqueeze(1)
else:
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
else:
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
else:
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
o, _ = flash_mla_with_kvcache_fp8(
q=q.to(torch.float8_e4m3fn),
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2).view(torch.float8_e4m3fn), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=attn_metadata.decode.
tile_scheduler_metadata,
num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale,
causal=True,
descale_q=q_scale,
descale_k=k_scale,
)
.unsqueeze(1) # Add seqlen dim of 1 (decode)
o, _ = flash_mla_with_kvcache_fp8(
q=q.to(torch.float8_e4m3fn),
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2).view(torch.float8_e4m3fn), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=attn_metadata.decode.
tile_scheduler_metadata,
num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale,
causal=True,
descale_q=q_scale,
descale_k=k_scale,
)
else:
if not envs.VLLM_USE_CAT_MLA or kv_cache_dtype == "fp8_e4m3":
if envs.VLLM_USE_OPT_CAT:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment