Commit 0ce3b670 authored by zhuwenwen
Browse files

Add fuse_rmsnorm_rope_quant_gfx938 to support fp8_e4m3 MLA

parent a9f57e73
...@@ -260,7 +260,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -260,7 +260,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q = torch.cat([q_nope, q_pe], dim=-1)\ q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode) .unsqueeze(1) # Add seqlen dim of 1 (decode)
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8: if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8:
o, _ = flash_mla_with_kvcache_fp8( o, _ = flash_mla_with_kvcache_fp8(
q=q, q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
......
...@@ -199,6 +199,8 @@ class Attention(nn.Module): ...@@ -199,6 +199,8 @@ class Attention(nn.Module):
# shape does not match the query shape, so we optionally let the model # shape does not match the query shape, so we optionally let the model
# definition specify the output tensor shape. # definition specify the output tensor shape.
output_shape: Optional[torch.Size] = None, output_shape: Optional[torch.Size] = None,
query_nope: Optional[torch.Size] = None,
num_local_heads: Optional[int] = None,
q_ori: Optional[torch.Tensor] = None, q_ori: Optional[torch.Tensor] = None,
key_normed: Optional[torch.Tensor] = None, key_normed: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None, positions: Optional[torch.Tensor] = None,
...@@ -265,7 +267,7 @@ class Attention(nn.Module): ...@@ -265,7 +267,7 @@ class Attention(nn.Module):
query, key, value, output, self.layer_name) query, key, value, output, self.layer_name)
else: else:
torch.ops.vllm.unified_attention_with_output( torch.ops.vllm.unified_attention_with_output(
query, key, value, output, self.layer_name, None, q_ori, key_normed, positions, weight, cos_sin_cache) query, key, value, output, self.layer_name, None, query_nope, num_local_heads, q_ori, key_normed, positions, weight, cos_sin_cache)
return output.view(-1, hidden_size) return output.view(-1, hidden_size)
else: else:
if self.use_direct_call: if self.use_direct_call:
...@@ -506,6 +508,8 @@ def unified_attention_with_output( ...@@ -506,6 +508,8 @@ def unified_attention_with_output(
output: torch.Tensor, output: torch.Tensor,
layer_name: str, layer_name: str,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
query_nope: Optional[torch.Tensor] = None,
num_local_heads: Optional[int] = None,
q_ori: Optional[torch.Tensor] = None, q_ori: Optional[torch.Tensor] = None,
key_normed: Optional[torch.Tensor] = None, key_normed: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None, positions: Optional[torch.Tensor] = None,
...@@ -537,6 +541,8 @@ def unified_attention_with_output( ...@@ -537,6 +541,8 @@ def unified_attention_with_output(
attn_metadata, attn_metadata,
output=output, output=output,
output_scale=output_scale, output_scale=output_scale,
query_nope=query_nope,
num_local_heads=num_local_heads,
q_ori=q_ori, q_ori=q_ori,
key_normed=key_normed, key_normed=key_normed,
positions=positions, positions=positions,
...@@ -566,6 +572,8 @@ else: ...@@ -566,6 +572,8 @@ else:
output: torch.Tensor, output: torch.Tensor,
layer_name: str, layer_name: str,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
query_nope: Optional[torch.Tensor] = None,
num_local_heads: Optional[int] = None,
q_ori: Optional[torch.Tensor] = None, q_ori: Optional[torch.Tensor] = None,
key_normed: Optional[torch.Tensor] = None, key_normed: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None, positions: Optional[torch.Tensor] = None,
......
...@@ -667,6 +667,8 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -667,6 +667,8 @@ class DeepseekV2MLAAttention(nn.Module):
k_pe, k_pe,
output_shape=(hidden_states.shape[0], output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim), self.num_local_heads * self.v_head_dim),
query_nope=q[..., :self.qk_nope_head_dim],
num_local_heads=self.num_local_heads,
q_ori=q, q_ori=q,
key_normed=kv_c_normed, key_normed=kv_c_normed,
positions=positions, positions=positions,
...@@ -715,6 +717,8 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -715,6 +717,8 @@ class DeepseekV2MLAAttention(nn.Module):
k_pe, k_pe,
output_shape=(hidden_states.shape[0], output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim), self.num_local_heads * self.v_head_dim),
query_nope=q[..., :self.qk_nope_head_dim],
num_local_heads=self.num_local_heads,
q_ori=q, q_ori=q,
key_normed=kv_c_normed, key_normed=kv_c_normed,
positions=positions, positions=positions,
...@@ -774,6 +778,8 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -774,6 +778,8 @@ class DeepseekV2MLAAttention(nn.Module):
k_pe, k_pe,
output_shape=(hidden_states.shape[0], output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim), self.num_local_heads * self.v_head_dim),
query_nope=q[..., :self.qk_nope_head_dim],
num_local_heads=self.num_local_heads,
q_ori=q, q_ori=q,
key_normed=kv_c_normed, key_normed=kv_c_normed,
positions=positions, positions=positions,
......
This diff is collapsed.
...@@ -217,6 +217,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, ...@@ -217,6 +217,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata) CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
from lightop import fused_rms_norm_rope_contiguous, fuse_rmsnorm_rope_quant_gfx938
try: try:
from vllm.vllm_flash_attn import flash_attn_varlen_func from vllm.vllm_flash_attn import flash_attn_varlen_func
...@@ -1095,6 +1096,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1095,6 +1096,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
attn_metadata: M, attn_metadata: M,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
query_nope: Optional[torch.Tensor] = None,
num_local_heads: Optional[int] = None,
q_ori: Optional[torch.Tensor] = None, q_ori: Optional[torch.Tensor] = None,
key_normed: Optional[torch.Tensor] = None, key_normed: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None, positions: Optional[torch.Tensor] = None,
...@@ -1154,7 +1157,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1154,7 +1157,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
scale=layer._k_scale, scale=layer._k_scale,
) )
else: else:
from lightop import fused_rms_norm_rope_contiguous
if self.kv_cache_dtype == "auto": if self.kv_cache_dtype == "auto":
if q.dtype == torch.float16: if q.dtype == torch.float16:
kv_cache_dtype_str = "fp16" kv_cache_dtype_str = "fp16"
...@@ -1163,6 +1165,45 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1163,6 +1165,45 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
else: else:
kv_cache_dtype_str = self.kv_cache_dtype kv_cache_dtype_str = self.kv_cache_dtype
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype_str=="fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8:
if has_prefill:
fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...],
q,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
else:
q_tensor = torch.randn(q.shape[0], num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, dtype=q.dtype, device=q.device)
q_quant_gt = q_tensor.to(kv_cache_dtype_str)
q_quant = torch.empty_like(q_quant_gt)
fuse_rmsnorm_rope_quant_gfx938(
positions[:num_actual_toks, ...],
query_nope,
q,
q_quant,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
else:
fused_rms_norm_rope_contiguous( fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...], positions[:num_actual_toks, ...],
q, q,
......
...@@ -179,7 +179,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -179,7 +179,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert kv_c_and_k_pe_cache.numel() > 0 assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None assert attn_metadata.decode is not None
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8: if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8:
if envs.VLLM_USE_OPT_CAT: if envs.VLLM_USE_OPT_CAT:
if q_nope.shape[0] < 1024: if q_nope.shape[0] < 1024:
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment