Commit 99981972 authored by zhuwenwen
Browse files

remove fuse_rmsnorm_rope_quant_gfx938

parent 0ce3b670
...@@ -90,14 +90,6 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): ...@@ -90,14 +90,6 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
batch_size) batch_size)
if m.num_decode_tokens > 0: if m.num_decode_tokens > 0:
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8:
m.decode_tile_scheduler_metadata, m.decode_num_splits = \
get_mla_decoding_metadata_dense_fp8(
m.seq_lens_tensor[m.num_prefills:],
self.num_q_heads,
1, # MQA for the decode path
)
else:
m.decode_tile_scheduler_metadata, m.decode_num_splits = \ m.decode_tile_scheduler_metadata, m.decode_num_splits = \
get_mla_metadata( get_mla_metadata(
m.seq_lens_tensor[m.num_prefills:], m.seq_lens_tensor[m.num_prefills:],
...@@ -118,15 +110,6 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]): ...@@ -118,15 +110,6 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]):
@contextmanager @contextmanager
def graph_capture(self, max_batch_size: int): def graph_capture(self, max_batch_size: int):
# Run a dummy `get_mla_metadata` so we can get the right shapes # Run a dummy `get_mla_metadata` so we can get the right shapes
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8:
self._graph_decoder_tile_scheduler_metadata, \
self._graph_decode_num_splits = get_mla_decoding_metadata_dense_fp8(
torch.ones(
max_batch_size, dtype=torch.int32, device=self.runner.device),
self.num_q_heads,
1, # MQA for the decode path
)
else:
self._graph_decoder_tile_scheduler_metadata, \ self._graph_decoder_tile_scheduler_metadata, \
self._graph_decode_num_splits = get_mla_metadata( self._graph_decode_num_splits = get_mla_metadata(
torch.ones( torch.ones(
...@@ -147,13 +130,6 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]): ...@@ -147,13 +130,6 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]):
batch_size, is_encoder_decoder_model) batch_size, is_encoder_decoder_model)
assert metadata.num_decode_tokens > 0 assert metadata.num_decode_tokens > 0
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8:
decoder_tile_scheduler_metadata, decode_num_splits = get_mla_decoding_metadata_dense_fp8(
self._graph_seq_lens[:batch_size],
self.num_q_heads,
1, # MQA for the decode path
)
else:
decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata( decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata(
self._graph_seq_lens[:batch_size], self._graph_seq_lens[:batch_size],
self.num_q_heads, self.num_q_heads,
......
...@@ -198,8 +198,6 @@ class Attention(nn.Module): ...@@ -198,8 +198,6 @@ class Attention(nn.Module):
# For some alternate attention backends like MLA the attention output # For some alternate attention backends like MLA the attention output
# shape does not match the query shape, so we optionally let the model # shape does not match the query shape, so we optionally let the model
# definition specify the output tensor shape. # definition specify the output tensor shape.
output_shape: Optional[torch.Size] = None,
query_nope: Optional[torch.Size] = None,
num_local_heads: Optional[int] = None, num_local_heads: Optional[int] = None,
q_ori: Optional[torch.Tensor] = None, q_ori: Optional[torch.Tensor] = None,
key_normed: Optional[torch.Tensor] = None, key_normed: Optional[torch.Tensor] = None,
...@@ -267,7 +265,7 @@ class Attention(nn.Module): ...@@ -267,7 +265,7 @@ class Attention(nn.Module):
query, key, value, output, self.layer_name) query, key, value, output, self.layer_name)
else: else:
torch.ops.vllm.unified_attention_with_output( torch.ops.vllm.unified_attention_with_output(
query, key, value, output, self.layer_name, None, query_nope, num_local_heads, q_ori, key_normed, positions, weight, cos_sin_cache) query, key, value, output, self.layer_name, None, q_ori, key_normed, positions, weight, cos_sin_cache)
return output.view(-1, hidden_size) return output.view(-1, hidden_size)
else: else:
if self.use_direct_call: if self.use_direct_call:
...@@ -508,8 +506,6 @@ def unified_attention_with_output( ...@@ -508,8 +506,6 @@ def unified_attention_with_output(
output: torch.Tensor, output: torch.Tensor,
layer_name: str, layer_name: str,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
query_nope: Optional[torch.Tensor] = None,
num_local_heads: Optional[int] = None,
q_ori: Optional[torch.Tensor] = None, q_ori: Optional[torch.Tensor] = None,
key_normed: Optional[torch.Tensor] = None, key_normed: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None, positions: Optional[torch.Tensor] = None,
...@@ -541,8 +537,6 @@ def unified_attention_with_output( ...@@ -541,8 +537,6 @@ def unified_attention_with_output(
attn_metadata, attn_metadata,
output=output, output=output,
output_scale=output_scale, output_scale=output_scale,
query_nope=query_nope,
num_local_heads=num_local_heads,
q_ori=q_ori, q_ori=q_ori,
key_normed=key_normed, key_normed=key_normed,
positions=positions, positions=positions,
...@@ -572,8 +566,6 @@ else: ...@@ -572,8 +566,6 @@ else:
output: torch.Tensor, output: torch.Tensor,
layer_name: str, layer_name: str,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
query_nope: Optional[torch.Tensor] = None,
num_local_heads: Optional[int] = None,
q_ori: Optional[torch.Tensor] = None, q_ori: Optional[torch.Tensor] = None,
key_normed: Optional[torch.Tensor] = None, key_normed: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None, positions: Optional[torch.Tensor] = None,
......
...@@ -667,8 +667,6 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -667,8 +667,6 @@ class DeepseekV2MLAAttention(nn.Module):
k_pe, k_pe,
output_shape=(hidden_states.shape[0], output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim), self.num_local_heads * self.v_head_dim),
query_nope=q[..., :self.qk_nope_head_dim],
num_local_heads=self.num_local_heads,
q_ori=q, q_ori=q,
key_normed=kv_c_normed, key_normed=kv_c_normed,
positions=positions, positions=positions,
...@@ -717,8 +715,6 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -717,8 +715,6 @@ class DeepseekV2MLAAttention(nn.Module):
k_pe, k_pe,
output_shape=(hidden_states.shape[0], output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim), self.num_local_heads * self.v_head_dim),
query_nope=q[..., :self.qk_nope_head_dim],
num_local_heads=self.num_local_heads,
q_ori=q, q_ori=q,
key_normed=kv_c_normed, key_normed=kv_c_normed,
positions=positions, positions=positions,
...@@ -778,8 +774,6 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -778,8 +774,6 @@ class DeepseekV2MLAAttention(nn.Module):
k_pe, k_pe,
output_shape=(hidden_states.shape[0], output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim), self.num_local_heads * self.v_head_dim),
query_nope=q[..., :self.qk_nope_head_dim],
num_local_heads=self.num_local_heads,
q_ori=q, q_ori=q,
key_normed=kv_c_normed, key_normed=kv_c_normed,
positions=positions, positions=positions,
......
...@@ -217,7 +217,6 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, ...@@ -217,7 +217,6 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata) CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
from lightop import fused_rms_norm_rope_contiguous, fuse_rmsnorm_rope_quant_gfx938
try: try:
from vllm.vllm_flash_attn import flash_attn_varlen_func from vllm.vllm_flash_attn import flash_attn_varlen_func
...@@ -1165,45 +1164,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1165,45 +1164,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
else: else:
kv_cache_dtype_str = self.kv_cache_dtype kv_cache_dtype_str = self.kv_cache_dtype
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype_str=="fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8:
if has_prefill:
fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...],
q,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
else:
q_tensor = torch.randn(q.shape[0], num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, dtype=q.dtype, device=q.device)
q_quant_gt = q_tensor.to(kv_cache_dtype_str)
q_quant = torch.empty_like(q_quant_gt)
fuse_rmsnorm_rope_quant_gfx938(
positions[:num_actual_toks, ...],
query_nope,
q,
q_quant,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
else:
fused_rms_norm_rope_contiguous( fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...], positions[:num_actual_toks, ...],
q, q,
......
...@@ -73,14 +73,6 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): ...@@ -73,14 +73,6 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
def _build_decode(self, block_table_tensor: torch.Tensor, def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor) -> FlashMLADecodeMetadata: seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8:
tile_scheduler_metadata, num_splits = \
get_mla_decoding_metadata_dense_fp8(
seq_lens,
self.num_q_heads,
1, # MQA for the decode path
)
else:
tile_scheduler_metadata, num_splits = \ tile_scheduler_metadata, num_splits = \
get_mla_metadata( get_mla_metadata(
seq_lens, seq_lens,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment