Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0ce3b670
Commit
0ce3b670
authored
Dec 16, 2025
by
zhuwenwen
Browse files
add fuse_rmsnorm_rope_quant_gfx938 to support use fp8_e4m3 mla
parent
a9f57e73
Changes
6
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
314 additions
and
192 deletions
+314
-192
vllm/attention/backends/flashmla.py
vllm/attention/backends/flashmla.py
+1
-1
vllm/attention/layer.py
vllm/attention/layer.py
+9
-1
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+6
-0
vllm/model_executor/models/qwen3_moe.py
vllm/model_executor/models/qwen3_moe.py
+239
-172
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+58
-17
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+1
-1
No files found.
vllm/attention/backends/flashmla.py
View file @
0ce3b670
...
@@ -260,7 +260,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -260,7 +260,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
kv_cache_dtype
==
"fp8_e4m3"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
o
,
_
=
flash_mla_with_kvcache_fp8
(
o
,
_
=
flash_mla_with_kvcache_fp8
(
q
=
q
,
q
=
q
,
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
),
# Add head dim of 1
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
),
# Add head dim of 1
...
...
vllm/attention/layer.py
View file @
0ce3b670
...
@@ -199,6 +199,8 @@ class Attention(nn.Module):
...
@@ -199,6 +199,8 @@ class Attention(nn.Module):
# shape does not match the query shape, so we optionally let the model
# shape does not match the query shape, so we optionally let the model
# definition specify the output tensor shape.
# definition specify the output tensor shape.
output_shape
:
Optional
[
torch
.
Size
]
=
None
,
output_shape
:
Optional
[
torch
.
Size
]
=
None
,
query_nope
:
Optional
[
torch
.
Size
]
=
None
,
num_local_heads
:
Optional
[
int
]
=
None
,
q_ori
:
Optional
[
torch
.
Tensor
]
=
None
,
q_ori
:
Optional
[
torch
.
Tensor
]
=
None
,
key_normed
:
Optional
[
torch
.
Tensor
]
=
None
,
key_normed
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -265,7 +267,7 @@ class Attention(nn.Module):
...
@@ -265,7 +267,7 @@ class Attention(nn.Module):
query
,
key
,
value
,
output
,
self
.
layer_name
)
query
,
key
,
value
,
output
,
self
.
layer_name
)
else
:
else
:
torch
.
ops
.
vllm
.
unified_attention_with_output
(
torch
.
ops
.
vllm
.
unified_attention_with_output
(
query
,
key
,
value
,
output
,
self
.
layer_name
,
None
,
q_ori
,
key_normed
,
positions
,
weight
,
cos_sin_cache
)
query
,
key
,
value
,
output
,
self
.
layer_name
,
None
,
query_nope
,
num_local_heads
,
q_ori
,
key_normed
,
positions
,
weight
,
cos_sin_cache
)
return
output
.
view
(
-
1
,
hidden_size
)
return
output
.
view
(
-
1
,
hidden_size
)
else
:
else
:
if
self
.
use_direct_call
:
if
self
.
use_direct_call
:
...
@@ -506,6 +508,8 @@ def unified_attention_with_output(
...
@@ -506,6 +508,8 @@ def unified_attention_with_output(
output
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
layer_name
:
str
,
layer_name
:
str
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
query_nope
:
Optional
[
torch
.
Tensor
]
=
None
,
num_local_heads
:
Optional
[
int
]
=
None
,
q_ori
:
Optional
[
torch
.
Tensor
]
=
None
,
q_ori
:
Optional
[
torch
.
Tensor
]
=
None
,
key_normed
:
Optional
[
torch
.
Tensor
]
=
None
,
key_normed
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -537,6 +541,8 @@ def unified_attention_with_output(
...
@@ -537,6 +541,8 @@ def unified_attention_with_output(
attn_metadata
,
attn_metadata
,
output
=
output
,
output
=
output
,
output_scale
=
output_scale
,
output_scale
=
output_scale
,
query_nope
=
query_nope
,
num_local_heads
=
num_local_heads
,
q_ori
=
q_ori
,
q_ori
=
q_ori
,
key_normed
=
key_normed
,
key_normed
=
key_normed
,
positions
=
positions
,
positions
=
positions
,
...
@@ -566,6 +572,8 @@ else:
...
@@ -566,6 +572,8 @@ else:
output
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
layer_name
:
str
,
layer_name
:
str
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
query_nope
:
Optional
[
torch
.
Tensor
]
=
None
,
num_local_heads
:
Optional
[
int
]
=
None
,
q_ori
:
Optional
[
torch
.
Tensor
]
=
None
,
q_ori
:
Optional
[
torch
.
Tensor
]
=
None
,
key_normed
:
Optional
[
torch
.
Tensor
]
=
None
,
key_normed
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
0ce3b670
...
@@ -667,6 +667,8 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -667,6 +667,8 @@ class DeepseekV2MLAAttention(nn.Module):
k_pe
,
k_pe
,
output_shape
=
(
hidden_states
.
shape
[
0
],
output_shape
=
(
hidden_states
.
shape
[
0
],
self
.
num_local_heads
*
self
.
v_head_dim
),
self
.
num_local_heads
*
self
.
v_head_dim
),
query_nope
=
q
[...,
:
self
.
qk_nope_head_dim
],
num_local_heads
=
self
.
num_local_heads
,
q_ori
=
q
,
q_ori
=
q
,
key_normed
=
kv_c_normed
,
key_normed
=
kv_c_normed
,
positions
=
positions
,
positions
=
positions
,
...
@@ -715,6 +717,8 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -715,6 +717,8 @@ class DeepseekV2MLAAttention(nn.Module):
k_pe
,
k_pe
,
output_shape
=
(
hidden_states
.
shape
[
0
],
output_shape
=
(
hidden_states
.
shape
[
0
],
self
.
num_local_heads
*
self
.
v_head_dim
),
self
.
num_local_heads
*
self
.
v_head_dim
),
query_nope
=
q
[...,
:
self
.
qk_nope_head_dim
],
num_local_heads
=
self
.
num_local_heads
,
q_ori
=
q
,
q_ori
=
q
,
key_normed
=
kv_c_normed
,
key_normed
=
kv_c_normed
,
positions
=
positions
,
positions
=
positions
,
...
@@ -774,6 +778,8 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -774,6 +778,8 @@ class DeepseekV2MLAAttention(nn.Module):
k_pe
,
k_pe
,
output_shape
=
(
hidden_states
.
shape
[
0
],
output_shape
=
(
hidden_states
.
shape
[
0
],
self
.
num_local_heads
*
self
.
v_head_dim
),
self
.
num_local_heads
*
self
.
v_head_dim
),
query_nope
=
q
[...,
:
self
.
qk_nope_head_dim
],
num_local_heads
=
self
.
num_local_heads
,
q_ori
=
q
,
q_ori
=
q
,
key_normed
=
kv_c_normed
,
key_normed
=
kv_c_normed
,
positions
=
positions
,
positions
=
positions
,
...
...
vllm/model_executor/models/qwen3_moe.py
View file @
0ce3b670
This diff is collapsed.
Click to expand it.
vllm/v1/attention/backends/mla/common.py
View file @
0ce3b670
...
@@ -217,6 +217,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
...
@@ -217,6 +217,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata
)
CommonAttentionMetadata
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.block_table
import
BlockTable
from
lightop
import
fused_rms_norm_rope_contiguous
,
fuse_rmsnorm_rope_quant_gfx938
try
:
try
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
...
@@ -1095,6 +1096,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1095,6 +1096,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
attn_metadata
:
M
,
attn_metadata
:
M
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
query_nope
:
Optional
[
torch
.
Tensor
]
=
None
,
num_local_heads
:
Optional
[
int
]
=
None
,
q_ori
:
Optional
[
torch
.
Tensor
]
=
None
,
q_ori
:
Optional
[
torch
.
Tensor
]
=
None
,
key_normed
:
Optional
[
torch
.
Tensor
]
=
None
,
key_normed
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1154,7 +1157,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1154,7 +1157,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
scale
=
layer
.
_k_scale
,
scale
=
layer
.
_k_scale
,
)
)
else
:
else
:
from
lightop
import
fused_rms_norm_rope_contiguous
if
self
.
kv_cache_dtype
==
"auto"
:
if
self
.
kv_cache_dtype
==
"auto"
:
if
q
.
dtype
==
torch
.
float16
:
if
q
.
dtype
==
torch
.
float16
:
kv_cache_dtype_str
=
"fp16"
kv_cache_dtype_str
=
"fp16"
...
@@ -1163,6 +1165,45 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1163,6 +1165,45 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
else
:
else
:
kv_cache_dtype_str
=
self
.
kv_cache_dtype
kv_cache_dtype_str
=
self
.
kv_cache_dtype
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
kv_cache_dtype_str
==
"fp8_e4m3"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
if
has_prefill
:
fused_rms_norm_rope_contiguous
(
positions
[:
num_actual_toks
,
...],
q
,
k_pe
.
squeeze
(
1
),
k_c_normed
,
# not normed
key_normed
[:
num_actual_toks
,
...],
# normed
weight
,
cos_sin_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
kv_cache
,
kv_cache_dtype_str
,
1.0
,
False
,
1e-6
,
)
else
:
q_tensor
=
torch
.
randn
(
q
.
shape
[
0
],
num_local_heads
,
self
.
qk_nope_head_dim
+
self
.
qk_rope_head_dim
,
dtype
=
q
.
dtype
,
device
=
q
.
device
)
q_quant_gt
=
q_tensor
.
to
(
kv_cache_dtype_str
)
q_quant
=
torch
.
empty_like
(
q_quant_gt
)
fuse_rmsnorm_rope_quant_gfx938
(
positions
[:
num_actual_toks
,
...],
query_nope
,
q
,
q_quant
,
k_pe
.
squeeze
(
1
),
k_c_normed
,
# not normed
key_normed
[:
num_actual_toks
,
...],
# normed
weight
,
cos_sin_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
kv_cache
,
kv_cache_dtype_str
,
1.0
,
False
,
1e-6
,
)
else
:
fused_rms_norm_rope_contiguous
(
fused_rms_norm_rope_contiguous
(
positions
[:
num_actual_toks
,
...],
positions
[:
num_actual_toks
,
...],
q
,
q
,
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
0ce3b670
...
@@ -179,7 +179,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -179,7 +179,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
attn_metadata
.
decode
is
not
None
assert
attn_metadata
.
decode
is
not
None
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
kv_cache_dtype
==
"fp8_e4m3"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
if
envs
.
VLLM_USE_OPT_CAT
:
if
envs
.
VLLM_USE_OPT_CAT
:
if
q_nope
.
shape
[
0
]
<
1024
:
if
q_nope
.
shape
[
0
]
<
1024
:
from
vllm.v1.attention.backends.mla.test_concat
import
concat_helper_decode
from
vllm.v1.attention.backends.mla.test_concat
import
concat_helper_decode
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment