"vscode:/vscode.git/clone" did not exist on "0e237f00357c968a4f7ae25accd533e924baceff"
Commit 99981972 authored by zhuwenwen
Browse files

remove fuse_rmsnorm_rope_quant_gfx938

parent 0ce3b670
......@@ -90,20 +90,12 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
batch_size)
if m.num_decode_tokens > 0:
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8:
m.decode_tile_scheduler_metadata, m.decode_num_splits = \
get_mla_decoding_metadata_dense_fp8(
m.seq_lens_tensor[m.num_prefills:],
self.num_q_heads,
1, # MQA for the decode path
)
else:
m.decode_tile_scheduler_metadata, m.decode_num_splits = \
get_mla_metadata(
m.seq_lens_tensor[m.num_prefills:],
self.num_q_heads,
1, # MQA for the decode path
)
m.decode_tile_scheduler_metadata, m.decode_num_splits = \
get_mla_metadata(
m.seq_lens_tensor[m.num_prefills:],
self.num_q_heads,
1, # MQA for the decode path
)
return m
......@@ -118,22 +110,13 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]):
@contextmanager
def graph_capture(self, max_batch_size: int):
# Run a dummy `get_mla_metadata` so we can get the right shapes
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8:
self._graph_decoder_tile_scheduler_metadata, \
self._graph_decode_num_splits = get_mla_decoding_metadata_dense_fp8(
torch.ones(
max_batch_size, dtype=torch.int32, device=self.runner.device),
self.num_q_heads,
1, # MQA for the decode path
)
else:
self._graph_decoder_tile_scheduler_metadata, \
self._graph_decode_num_splits = get_mla_metadata(
torch.ones(
max_batch_size, dtype=torch.int32, device=self.runner.device),
self.num_q_heads,
1, # MQA for the decode path
)
self._graph_decoder_tile_scheduler_metadata, \
self._graph_decode_num_splits = get_mla_metadata(
torch.ones(
max_batch_size, dtype=torch.int32, device=self.runner.device),
self.num_q_heads,
1, # MQA for the decode path
)
with super().graph_capture(max_batch_size):
yield
......@@ -147,18 +130,11 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]):
batch_size, is_encoder_decoder_model)
assert metadata.num_decode_tokens > 0
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8:
decoder_tile_scheduler_metadata, decode_num_splits = get_mla_decoding_metadata_dense_fp8(
self._graph_seq_lens[:batch_size],
self.num_q_heads,
1, # MQA for the decode path
)
else:
decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata(
self._graph_seq_lens[:batch_size],
self.num_q_heads,
1, # MQA for the decode path
)
decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata(
self._graph_seq_lens[:batch_size],
self.num_q_heads,
1, # MQA for the decode path
)
self._graph_decoder_tile_scheduler_metadata.copy_(
decoder_tile_scheduler_metadata)
......
......@@ -198,8 +198,6 @@ class Attention(nn.Module):
# For some alternate attention backends like MLA the attention output
# shape does not match the query shape, so we optionally let the model
# definition specify the output tensor shape.
output_shape: Optional[torch.Size] = None,
query_nope: Optional[torch.Size] = None,
num_local_heads: Optional[int] = None,
q_ori: Optional[torch.Tensor] = None,
key_normed: Optional[torch.Tensor] = None,
......@@ -267,7 +265,7 @@ class Attention(nn.Module):
query, key, value, output, self.layer_name)
else:
torch.ops.vllm.unified_attention_with_output(
query, key, value, output, self.layer_name, None, query_nope, num_local_heads, q_ori, key_normed, positions, weight, cos_sin_cache)
query, key, value, output, self.layer_name, None, q_ori, key_normed, positions, weight, cos_sin_cache)
return output.view(-1, hidden_size)
else:
if self.use_direct_call:
......@@ -508,8 +506,6 @@ def unified_attention_with_output(
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
query_nope: Optional[torch.Tensor] = None,
num_local_heads: Optional[int] = None,
q_ori: Optional[torch.Tensor] = None,
key_normed: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None,
......@@ -541,8 +537,6 @@ def unified_attention_with_output(
attn_metadata,
output=output,
output_scale=output_scale,
query_nope=query_nope,
num_local_heads=num_local_heads,
q_ori=q_ori,
key_normed=key_normed,
positions=positions,
......@@ -572,8 +566,6 @@ else:
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
query_nope: Optional[torch.Tensor] = None,
num_local_heads: Optional[int] = None,
q_ori: Optional[torch.Tensor] = None,
key_normed: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None,
......
......@@ -667,8 +667,6 @@ class DeepseekV2MLAAttention(nn.Module):
k_pe,
output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim),
query_nope=q[..., :self.qk_nope_head_dim],
num_local_heads=self.num_local_heads,
q_ori=q,
key_normed=kv_c_normed,
positions=positions,
......@@ -717,8 +715,6 @@ class DeepseekV2MLAAttention(nn.Module):
k_pe,
output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim),
query_nope=q[..., :self.qk_nope_head_dim],
num_local_heads=self.num_local_heads,
q_ori=q,
key_normed=kv_c_normed,
positions=positions,
......@@ -778,8 +774,6 @@ class DeepseekV2MLAAttention(nn.Module):
k_pe,
output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim),
query_nope=q[..., :self.qk_nope_head_dim],
num_local_heads=self.num_local_heads,
q_ori=q,
key_normed=kv_c_normed,
positions=positions,
......
......@@ -217,7 +217,6 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from lightop import fused_rms_norm_rope_contiguous, fuse_rmsnorm_rope_quant_gfx938
try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
......@@ -1164,61 +1163,22 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_cache_dtype_str = "bf16"
else:
kv_cache_dtype_str = self.kv_cache_dtype
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype_str=="fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8:
if has_prefill:
fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...],
q,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
else:
q_tensor = torch.randn(q.shape[0], num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, dtype=q.dtype, device=q.device)
q_quant_gt = q_tensor.to(kv_cache_dtype_str)
q_quant = torch.empty_like(q_quant_gt)
fuse_rmsnorm_rope_quant_gfx938(
positions[:num_actual_toks, ...],
query_nope,
q,
q_quant,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
else:
fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...],
q,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...],
q,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
if has_prefill:
if envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
......
......@@ -73,20 +73,12 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8:
tile_scheduler_metadata, num_splits = \
get_mla_decoding_metadata_dense_fp8(
seq_lens,
self.num_q_heads,
1, # MQA for the decode path
)
else:
tile_scheduler_metadata, num_splits = \
get_mla_metadata(
seq_lens,
self.num_q_heads,
1, # MQA for the decode path
)
tile_scheduler_metadata, num_splits = \
get_mla_metadata(
seq_lens,
self.num_q_heads,
1, # MQA for the decode path
)
if self.runner.full_cuda_graph:
# First time around (CUDAGraph capture), allocate the static buffer
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment