Commit 99981972 authored by zhuwenwen
Browse files

remove fuse_rmsnorm_rope_quant_gfx938

parent 0ce3b670
...@@ -90,20 +90,12 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): ...@@ -90,20 +90,12 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
batch_size) batch_size)
if m.num_decode_tokens > 0: if m.num_decode_tokens > 0:
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8: m.decode_tile_scheduler_metadata, m.decode_num_splits = \
m.decode_tile_scheduler_metadata, m.decode_num_splits = \ get_mla_metadata(
get_mla_decoding_metadata_dense_fp8( m.seq_lens_tensor[m.num_prefills:],
m.seq_lens_tensor[m.num_prefills:], self.num_q_heads,
self.num_q_heads, 1, # MQA for the decode path
1, # MQA for the decode path )
)
else:
m.decode_tile_scheduler_metadata, m.decode_num_splits = \
get_mla_metadata(
m.seq_lens_tensor[m.num_prefills:],
self.num_q_heads,
1, # MQA for the decode path
)
return m return m
...@@ -118,22 +110,13 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]): ...@@ -118,22 +110,13 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]):
@contextmanager @contextmanager
def graph_capture(self, max_batch_size: int): def graph_capture(self, max_batch_size: int):
# Run a dummy `get_mla_metadata` so we can get the right shapes # Run a dummy `get_mla_metadata` so we can get the right shapes
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8: self._graph_decoder_tile_scheduler_metadata, \
self._graph_decoder_tile_scheduler_metadata, \ self._graph_decode_num_splits = get_mla_metadata(
self._graph_decode_num_splits = get_mla_decoding_metadata_dense_fp8( torch.ones(
torch.ones( max_batch_size, dtype=torch.int32, device=self.runner.device),
max_batch_size, dtype=torch.int32, device=self.runner.device), self.num_q_heads,
self.num_q_heads, 1, # MQA for the decode path
1, # MQA for the decode path )
)
else:
self._graph_decoder_tile_scheduler_metadata, \
self._graph_decode_num_splits = get_mla_metadata(
torch.ones(
max_batch_size, dtype=torch.int32, device=self.runner.device),
self.num_q_heads,
1, # MQA for the decode path
)
with super().graph_capture(max_batch_size): with super().graph_capture(max_batch_size):
yield yield
...@@ -147,18 +130,11 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]): ...@@ -147,18 +130,11 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]):
batch_size, is_encoder_decoder_model) batch_size, is_encoder_decoder_model)
assert metadata.num_decode_tokens > 0 assert metadata.num_decode_tokens > 0
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8: decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata(
decoder_tile_scheduler_metadata, decode_num_splits = get_mla_decoding_metadata_dense_fp8( self._graph_seq_lens[:batch_size],
self._graph_seq_lens[:batch_size], self.num_q_heads,
self.num_q_heads, 1, # MQA for the decode path
1, # MQA for the decode path )
)
else:
decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata(
self._graph_seq_lens[:batch_size],
self.num_q_heads,
1, # MQA for the decode path
)
self._graph_decoder_tile_scheduler_metadata.copy_( self._graph_decoder_tile_scheduler_metadata.copy_(
decoder_tile_scheduler_metadata) decoder_tile_scheduler_metadata)
......
...@@ -198,8 +198,6 @@ class Attention(nn.Module): ...@@ -198,8 +198,6 @@ class Attention(nn.Module):
# For some alternate attention backends like MLA the attention output # For some alternate attention backends like MLA the attention output
# shape does not match the query shape, so we optionally let the model # shape does not match the query shape, so we optionally let the model
# definition specify the output tensor shape. # definition specify the output tensor shape.
output_shape: Optional[torch.Size] = None,
query_nope: Optional[torch.Size] = None,
num_local_heads: Optional[int] = None, num_local_heads: Optional[int] = None,
q_ori: Optional[torch.Tensor] = None, q_ori: Optional[torch.Tensor] = None,
key_normed: Optional[torch.Tensor] = None, key_normed: Optional[torch.Tensor] = None,
...@@ -267,7 +265,7 @@ class Attention(nn.Module): ...@@ -267,7 +265,7 @@ class Attention(nn.Module):
query, key, value, output, self.layer_name) query, key, value, output, self.layer_name)
else: else:
torch.ops.vllm.unified_attention_with_output( torch.ops.vllm.unified_attention_with_output(
query, key, value, output, self.layer_name, None, query_nope, num_local_heads, q_ori, key_normed, positions, weight, cos_sin_cache) query, key, value, output, self.layer_name, None, q_ori, key_normed, positions, weight, cos_sin_cache)
return output.view(-1, hidden_size) return output.view(-1, hidden_size)
else: else:
if self.use_direct_call: if self.use_direct_call:
...@@ -508,8 +506,6 @@ def unified_attention_with_output( ...@@ -508,8 +506,6 @@ def unified_attention_with_output(
output: torch.Tensor, output: torch.Tensor,
layer_name: str, layer_name: str,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
query_nope: Optional[torch.Tensor] = None,
num_local_heads: Optional[int] = None,
q_ori: Optional[torch.Tensor] = None, q_ori: Optional[torch.Tensor] = None,
key_normed: Optional[torch.Tensor] = None, key_normed: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None, positions: Optional[torch.Tensor] = None,
...@@ -541,8 +537,6 @@ def unified_attention_with_output( ...@@ -541,8 +537,6 @@ def unified_attention_with_output(
attn_metadata, attn_metadata,
output=output, output=output,
output_scale=output_scale, output_scale=output_scale,
query_nope=query_nope,
num_local_heads=num_local_heads,
q_ori=q_ori, q_ori=q_ori,
key_normed=key_normed, key_normed=key_normed,
positions=positions, positions=positions,
...@@ -572,8 +566,6 @@ else: ...@@ -572,8 +566,6 @@ else:
output: torch.Tensor, output: torch.Tensor,
layer_name: str, layer_name: str,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
query_nope: Optional[torch.Tensor] = None,
num_local_heads: Optional[int] = None,
q_ori: Optional[torch.Tensor] = None, q_ori: Optional[torch.Tensor] = None,
key_normed: Optional[torch.Tensor] = None, key_normed: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None, positions: Optional[torch.Tensor] = None,
......
...@@ -667,8 +667,6 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -667,8 +667,6 @@ class DeepseekV2MLAAttention(nn.Module):
k_pe, k_pe,
output_shape=(hidden_states.shape[0], output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim), self.num_local_heads * self.v_head_dim),
query_nope=q[..., :self.qk_nope_head_dim],
num_local_heads=self.num_local_heads,
q_ori=q, q_ori=q,
key_normed=kv_c_normed, key_normed=kv_c_normed,
positions=positions, positions=positions,
...@@ -717,8 +715,6 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -717,8 +715,6 @@ class DeepseekV2MLAAttention(nn.Module):
k_pe, k_pe,
output_shape=(hidden_states.shape[0], output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim), self.num_local_heads * self.v_head_dim),
query_nope=q[..., :self.qk_nope_head_dim],
num_local_heads=self.num_local_heads,
q_ori=q, q_ori=q,
key_normed=kv_c_normed, key_normed=kv_c_normed,
positions=positions, positions=positions,
...@@ -778,8 +774,6 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -778,8 +774,6 @@ class DeepseekV2MLAAttention(nn.Module):
k_pe, k_pe,
output_shape=(hidden_states.shape[0], output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim), self.num_local_heads * self.v_head_dim),
query_nope=q[..., :self.qk_nope_head_dim],
num_local_heads=self.num_local_heads,
q_ori=q, q_ori=q,
key_normed=kv_c_normed, key_normed=kv_c_normed,
positions=positions, positions=positions,
......
...@@ -217,7 +217,6 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, ...@@ -217,7 +217,6 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata) CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
from lightop import fused_rms_norm_rope_contiguous, fuse_rmsnorm_rope_quant_gfx938
try: try:
from vllm.vllm_flash_attn import flash_attn_varlen_func from vllm.vllm_flash_attn import flash_attn_varlen_func
...@@ -1164,61 +1163,22 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1164,61 +1163,22 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_cache_dtype_str = "bf16" kv_cache_dtype_str = "bf16"
else: else:
kv_cache_dtype_str = self.kv_cache_dtype kv_cache_dtype_str = self.kv_cache_dtype
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype_str=="fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8: fused_rms_norm_rope_contiguous(
if has_prefill: positions[:num_actual_toks, ...],
fused_rms_norm_rope_contiguous( q,
positions[:num_actual_toks, ...], k_pe.squeeze(1),
q, k_c_normed, # not normed
k_pe.squeeze(1), key_normed[:num_actual_toks, ...], # normed
k_c_normed, # not normed weight,
key_normed[:num_actual_toks, ...], # normed cos_sin_cache,
weight, attn_metadata.slot_mapping.flatten(),
cos_sin_cache, kv_cache,
attn_metadata.slot_mapping.flatten(), kv_cache_dtype_str,
kv_cache, 1.0,
kv_cache_dtype_str, False,
1.0, 1e-6,
False, )
1e-6,
)
else:
q_tensor = torch.randn(q.shape[0], num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, dtype=q.dtype, device=q.device)
q_quant_gt = q_tensor.to(kv_cache_dtype_str)
q_quant = torch.empty_like(q_quant_gt)
fuse_rmsnorm_rope_quant_gfx938(
positions[:num_actual_toks, ...],
query_nope,
q,
q_quant,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
else:
fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...],
q,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
if has_prefill: if has_prefill:
if envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: if envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
......
...@@ -73,20 +73,12 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): ...@@ -73,20 +73,12 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
def _build_decode(self, block_table_tensor: torch.Tensor, def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor) -> FlashMLADecodeMetadata: seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8: tile_scheduler_metadata, num_splits = \
tile_scheduler_metadata, num_splits = \ get_mla_metadata(
get_mla_decoding_metadata_dense_fp8( seq_lens,
seq_lens, self.num_q_heads,
self.num_q_heads, 1, # MQA for the decode path
1, # MQA for the decode path )
)
else:
tile_scheduler_metadata, num_splits = \
get_mla_metadata(
seq_lens,
self.num_q_heads,
1, # MQA for the decode path
)
if self.runner.full_cuda_graph: if self.runner.full_cuda_graph:
# First time around (CUDAGraph capture), allocate the static buffer # First time around (CUDAGraph capture), allocate the static buffer
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment