Commit da85feb7 authored by zhuwenwen
Browse files

Convert q to float8_e4m3fn before passing it to flash_mla_with_kvcache_fp8

parent 99981972
...@@ -238,7 +238,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -238,7 +238,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8: if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8:
o, _ = flash_mla_with_kvcache_fp8( o, _ = flash_mla_with_kvcache_fp8(
q=q, q=q.to(torch.float8_e4m3fn),
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=decode_meta.block_tables, block_table=decode_meta.block_tables,
cache_seqlens=decode_meta.seq_lens_tensor, cache_seqlens=decode_meta.seq_lens_tensor,
......
...@@ -198,7 +198,7 @@ class Attention(nn.Module): ...@@ -198,7 +198,7 @@ class Attention(nn.Module):
# For some alternate attention backends like MLA the attention output # For some alternate attention backends like MLA the attention output
# shape does not match the query shape, so we optionally let the model # shape does not match the query shape, so we optionally let the model
# definition specify the output tensor shape. # definition specify the output tensor shape.
num_local_heads: Optional[int] = None, output_shape: Optional[torch.Size] = None,
q_ori: Optional[torch.Tensor] = None, q_ori: Optional[torch.Tensor] = None,
key_normed: Optional[torch.Tensor] = None, key_normed: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None, positions: Optional[torch.Tensor] = None,
......
...@@ -1163,7 +1163,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1163,7 +1163,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_cache_dtype_str = "bf16" kv_cache_dtype_str = "bf16"
else: else:
kv_cache_dtype_str = self.kv_cache_dtype kv_cache_dtype_str = self.kv_cache_dtype
from lightop import fused_rms_norm_rope_contiguous
fused_rms_norm_rope_contiguous( fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...], positions[:num_actual_toks, ...],
q, q,
......
...@@ -185,7 +185,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -185,7 +185,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
.unsqueeze(1) # Add seqlen dim of 1 (decode) .unsqueeze(1) # Add seqlen dim of 1 (decode)
o, _ = flash_mla_with_kvcache_fp8( o, _ = flash_mla_with_kvcache_fp8(
q=q, q=q.to(torch.float8_e4m3fn),
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table, block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens, cache_seqlens=attn_metadata.decode.seq_lens,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment