Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
99981972
Commit
99981972
authored
Dec 17, 2025
by
zhuwenwen
Browse files
remove fuse_rmsnorm_rope_quant_gfx938
parent
0ce3b670
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
41 additions
and
127 deletions
+41
-127
vllm/attention/backends/flashmla.py
vllm/attention/backends/flashmla.py
+18
-42
vllm/attention/layer.py
vllm/attention/layer.py
+1
-9
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+0
-6
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+16
-56
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+6
-14
No files found.
vllm/attention/backends/flashmla.py
View file @
99981972
...
...
@@ -90,20 +90,12 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
batch_size
)
if
m
.
num_decode_tokens
>
0
:
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
m
.
decode_tile_scheduler_metadata
,
m
.
decode_num_splits
=
\
get_mla_decoding_metadata_dense_fp8
(
m
.
seq_lens_tensor
[
m
.
num_prefills
:],
self
.
num_q_heads
,
1
,
# MQA for the decode path
)
else
:
m
.
decode_tile_scheduler_metadata
,
m
.
decode_num_splits
=
\
get_mla_metadata
(
m
.
seq_lens_tensor
[
m
.
num_prefills
:],
self
.
num_q_heads
,
1
,
# MQA for the decode path
)
m
.
decode_tile_scheduler_metadata
,
m
.
decode_num_splits
=
\
get_mla_metadata
(
m
.
seq_lens_tensor
[
m
.
num_prefills
:],
self
.
num_q_heads
,
1
,
# MQA for the decode path
)
return
m
...
...
@@ -118,22 +110,13 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]):
@
contextmanager
def
graph_capture
(
self
,
max_batch_size
:
int
):
# Run a dummy `get_mla_metadata` so we can get the right shapes
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
self
.
_graph_decoder_tile_scheduler_metadata
,
\
self
.
_graph_decode_num_splits
=
get_mla_decoding_metadata_dense_fp8
(
torch
.
ones
(
max_batch_size
,
dtype
=
torch
.
int32
,
device
=
self
.
runner
.
device
),
self
.
num_q_heads
,
1
,
# MQA for the decode path
)
else
:
self
.
_graph_decoder_tile_scheduler_metadata
,
\
self
.
_graph_decode_num_splits
=
get_mla_metadata
(
torch
.
ones
(
max_batch_size
,
dtype
=
torch
.
int32
,
device
=
self
.
runner
.
device
),
self
.
num_q_heads
,
1
,
# MQA for the decode path
)
self
.
_graph_decoder_tile_scheduler_metadata
,
\
self
.
_graph_decode_num_splits
=
get_mla_metadata
(
torch
.
ones
(
max_batch_size
,
dtype
=
torch
.
int32
,
device
=
self
.
runner
.
device
),
self
.
num_q_heads
,
1
,
# MQA for the decode path
)
with
super
().
graph_capture
(
max_batch_size
):
yield
...
...
@@ -147,18 +130,11 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]):
batch_size
,
is_encoder_decoder_model
)
assert
metadata
.
num_decode_tokens
>
0
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
decoder_tile_scheduler_metadata
,
decode_num_splits
=
get_mla_decoding_metadata_dense_fp8
(
self
.
_graph_seq_lens
[:
batch_size
],
self
.
num_q_heads
,
1
,
# MQA for the decode path
)
else
:
decoder_tile_scheduler_metadata
,
decode_num_splits
=
get_mla_metadata
(
self
.
_graph_seq_lens
[:
batch_size
],
self
.
num_q_heads
,
1
,
# MQA for the decode path
)
decoder_tile_scheduler_metadata
,
decode_num_splits
=
get_mla_metadata
(
self
.
_graph_seq_lens
[:
batch_size
],
self
.
num_q_heads
,
1
,
# MQA for the decode path
)
self
.
_graph_decoder_tile_scheduler_metadata
.
copy_
(
decoder_tile_scheduler_metadata
)
...
...
vllm/attention/layer.py
View file @
99981972
...
...
@@ -198,8 +198,6 @@ class Attention(nn.Module):
# For some alternate attention backends like MLA the attention output
# shape does not match the query shape, so we optionally let the model
# definition specify the output tensor shape.
output_shape
:
Optional
[
torch
.
Size
]
=
None
,
query_nope
:
Optional
[
torch
.
Size
]
=
None
,
num_local_heads
:
Optional
[
int
]
=
None
,
q_ori
:
Optional
[
torch
.
Tensor
]
=
None
,
key_normed
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -267,7 +265,7 @@ class Attention(nn.Module):
query
,
key
,
value
,
output
,
self
.
layer_name
)
else
:
torch
.
ops
.
vllm
.
unified_attention_with_output
(
query
,
key
,
value
,
output
,
self
.
layer_name
,
None
,
query_nope
,
num_local_heads
,
q_ori
,
key_normed
,
positions
,
weight
,
cos_sin_cache
)
query
,
key
,
value
,
output
,
self
.
layer_name
,
None
,
q_ori
,
key_normed
,
positions
,
weight
,
cos_sin_cache
)
return
output
.
view
(
-
1
,
hidden_size
)
else
:
if
self
.
use_direct_call
:
...
...
@@ -508,8 +506,6 @@ def unified_attention_with_output(
output
:
torch
.
Tensor
,
layer_name
:
str
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
query_nope
:
Optional
[
torch
.
Tensor
]
=
None
,
num_local_heads
:
Optional
[
int
]
=
None
,
q_ori
:
Optional
[
torch
.
Tensor
]
=
None
,
key_normed
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -541,8 +537,6 @@ def unified_attention_with_output(
attn_metadata
,
output
=
output
,
output_scale
=
output_scale
,
query_nope
=
query_nope
,
num_local_heads
=
num_local_heads
,
q_ori
=
q_ori
,
key_normed
=
key_normed
,
positions
=
positions
,
...
...
@@ -572,8 +566,6 @@ else:
output
:
torch
.
Tensor
,
layer_name
:
str
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
query_nope
:
Optional
[
torch
.
Tensor
]
=
None
,
num_local_heads
:
Optional
[
int
]
=
None
,
q_ori
:
Optional
[
torch
.
Tensor
]
=
None
,
key_normed
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
99981972
...
...
@@ -667,8 +667,6 @@ class DeepseekV2MLAAttention(nn.Module):
k_pe
,
output_shape
=
(
hidden_states
.
shape
[
0
],
self
.
num_local_heads
*
self
.
v_head_dim
),
query_nope
=
q
[...,
:
self
.
qk_nope_head_dim
],
num_local_heads
=
self
.
num_local_heads
,
q_ori
=
q
,
key_normed
=
kv_c_normed
,
positions
=
positions
,
...
...
@@ -717,8 +715,6 @@ class DeepseekV2MLAAttention(nn.Module):
k_pe
,
output_shape
=
(
hidden_states
.
shape
[
0
],
self
.
num_local_heads
*
self
.
v_head_dim
),
query_nope
=
q
[...,
:
self
.
qk_nope_head_dim
],
num_local_heads
=
self
.
num_local_heads
,
q_ori
=
q
,
key_normed
=
kv_c_normed
,
positions
=
positions
,
...
...
@@ -778,8 +774,6 @@ class DeepseekV2MLAAttention(nn.Module):
k_pe
,
output_shape
=
(
hidden_states
.
shape
[
0
],
self
.
num_local_heads
*
self
.
v_head_dim
),
query_nope
=
q
[...,
:
self
.
qk_nope_head_dim
],
num_local_heads
=
self
.
num_local_heads
,
q_ori
=
q
,
key_normed
=
kv_c_normed
,
positions
=
positions
,
...
...
vllm/v1/attention/backends/mla/common.py
View file @
99981972
...
...
@@ -217,7 +217,6 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
from
lightop
import
fused_rms_norm_rope_contiguous
,
fuse_rmsnorm_rope_quant_gfx938
try
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
...
...
@@ -1164,61 +1163,22 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_cache_dtype_str
=
"bf16"
else
:
kv_cache_dtype_str
=
self
.
kv_cache_dtype
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
kv_cache_dtype_str
==
"fp8_e4m3"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
if
has_prefill
:
fused_rms_norm_rope_contiguous
(
positions
[:
num_actual_toks
,
...],
q
,
k_pe
.
squeeze
(
1
),
k_c_normed
,
# not normed
key_normed
[:
num_actual_toks
,
...],
# normed
weight
,
cos_sin_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
kv_cache
,
kv_cache_dtype_str
,
1.0
,
False
,
1e-6
,
)
else
:
q_tensor
=
torch
.
randn
(
q
.
shape
[
0
],
num_local_heads
,
self
.
qk_nope_head_dim
+
self
.
qk_rope_head_dim
,
dtype
=
q
.
dtype
,
device
=
q
.
device
)
q_quant_gt
=
q_tensor
.
to
(
kv_cache_dtype_str
)
q_quant
=
torch
.
empty_like
(
q_quant_gt
)
fuse_rmsnorm_rope_quant_gfx938
(
positions
[:
num_actual_toks
,
...],
query_nope
,
q
,
q_quant
,
k_pe
.
squeeze
(
1
),
k_c_normed
,
# not normed
key_normed
[:
num_actual_toks
,
...],
# normed
weight
,
cos_sin_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
kv_cache
,
kv_cache_dtype_str
,
1.0
,
False
,
1e-6
,
)
else
:
fused_rms_norm_rope_contiguous
(
positions
[:
num_actual_toks
,
...],
q
,
k_pe
.
squeeze
(
1
),
k_c_normed
,
# not normed
key_normed
[:
num_actual_toks
,
...],
# normed
weight
,
cos_sin_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
kv_cache
,
kv_cache_dtype_str
,
1.0
,
False
,
1e-6
,
)
fused_rms_norm_rope_contiguous
(
positions
[:
num_actual_toks
,
...],
q
,
k_pe
.
squeeze
(
1
),
k_c_normed
,
# not normed
key_normed
[:
num_actual_toks
,
...],
# normed
weight
,
cos_sin_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
kv_cache
,
kv_cache_dtype_str
,
1.0
,
False
,
1e-6
,
)
if
has_prefill
:
if
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
99981972
...
...
@@ -73,20 +73,12 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
def
_build_decode
(
self
,
block_table_tensor
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
)
->
FlashMLADecodeMetadata
:
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
tile_scheduler_metadata
,
num_splits
=
\
get_mla_decoding_metadata_dense_fp8
(
seq_lens
,
self
.
num_q_heads
,
1
,
# MQA for the decode path
)
else
:
tile_scheduler_metadata
,
num_splits
=
\
get_mla_metadata
(
seq_lens
,
self
.
num_q_heads
,
1
,
# MQA for the decode path
)
tile_scheduler_metadata
,
num_splits
=
\
get_mla_metadata
(
seq_lens
,
self
.
num_q_heads
,
1
,
# MQA for the decode path
)
if
self
.
runner
.
full_cuda_graph
:
# First time around (CUDAGraph capture), allocate the static buffer
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment