Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0ce3b670
Commit
0ce3b670
authored
Dec 16, 2025
by
zhuwenwen
Browse files
add fuse_rmsnorm_rope_quant_gfx938 to support use fp8_e4m3 mla
parent
a9f57e73
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
314 additions
and
192 deletions
+314
-192
vllm/attention/backends/flashmla.py
vllm/attention/backends/flashmla.py
+1
-1
vllm/attention/layer.py
vllm/attention/layer.py
+9
-1
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+6
-0
vllm/model_executor/models/qwen3_moe.py
vllm/model_executor/models/qwen3_moe.py
+239
-172
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+58
-17
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+1
-1
No files found.
vllm/attention/backends/flashmla.py
View file @
0ce3b670
...
@@ -260,7 +260,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -260,7 +260,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
kv_cache_dtype
==
"fp8_e4m3"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
o
,
_
=
flash_mla_with_kvcache_fp8
(
o
,
_
=
flash_mla_with_kvcache_fp8
(
q
=
q
,
q
=
q
,
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
),
# Add head dim of 1
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
),
# Add head dim of 1
...
...
vllm/attention/layer.py
View file @
0ce3b670
...
@@ -199,6 +199,8 @@ class Attention(nn.Module):
...
@@ -199,6 +199,8 @@ class Attention(nn.Module):
# shape does not match the query shape, so we optionally let the model
# shape does not match the query shape, so we optionally let the model
# definition specify the output tensor shape.
# definition specify the output tensor shape.
output_shape
:
Optional
[
torch
.
Size
]
=
None
,
output_shape
:
Optional
[
torch
.
Size
]
=
None
,
query_nope
:
Optional
[
torch
.
Size
]
=
None
,
num_local_heads
:
Optional
[
int
]
=
None
,
q_ori
:
Optional
[
torch
.
Tensor
]
=
None
,
q_ori
:
Optional
[
torch
.
Tensor
]
=
None
,
key_normed
:
Optional
[
torch
.
Tensor
]
=
None
,
key_normed
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -265,7 +267,7 @@ class Attention(nn.Module):
...
@@ -265,7 +267,7 @@ class Attention(nn.Module):
query
,
key
,
value
,
output
,
self
.
layer_name
)
query
,
key
,
value
,
output
,
self
.
layer_name
)
else
:
else
:
torch
.
ops
.
vllm
.
unified_attention_with_output
(
torch
.
ops
.
vllm
.
unified_attention_with_output
(
query
,
key
,
value
,
output
,
self
.
layer_name
,
None
,
q_ori
,
key_normed
,
positions
,
weight
,
cos_sin_cache
)
query
,
key
,
value
,
output
,
self
.
layer_name
,
None
,
query_nope
,
num_local_heads
,
q_ori
,
key_normed
,
positions
,
weight
,
cos_sin_cache
)
return
output
.
view
(
-
1
,
hidden_size
)
return
output
.
view
(
-
1
,
hidden_size
)
else
:
else
:
if
self
.
use_direct_call
:
if
self
.
use_direct_call
:
...
@@ -506,6 +508,8 @@ def unified_attention_with_output(
...
@@ -506,6 +508,8 @@ def unified_attention_with_output(
output
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
layer_name
:
str
,
layer_name
:
str
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
query_nope
:
Optional
[
torch
.
Tensor
]
=
None
,
num_local_heads
:
Optional
[
int
]
=
None
,
q_ori
:
Optional
[
torch
.
Tensor
]
=
None
,
q_ori
:
Optional
[
torch
.
Tensor
]
=
None
,
key_normed
:
Optional
[
torch
.
Tensor
]
=
None
,
key_normed
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -537,6 +541,8 @@ def unified_attention_with_output(
...
@@ -537,6 +541,8 @@ def unified_attention_with_output(
attn_metadata
,
attn_metadata
,
output
=
output
,
output
=
output
,
output_scale
=
output_scale
,
output_scale
=
output_scale
,
query_nope
=
query_nope
,
num_local_heads
=
num_local_heads
,
q_ori
=
q_ori
,
q_ori
=
q_ori
,
key_normed
=
key_normed
,
key_normed
=
key_normed
,
positions
=
positions
,
positions
=
positions
,
...
@@ -566,6 +572,8 @@ else:
...
@@ -566,6 +572,8 @@ else:
output
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
layer_name
:
str
,
layer_name
:
str
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
query_nope
:
Optional
[
torch
.
Tensor
]
=
None
,
num_local_heads
:
Optional
[
int
]
=
None
,
q_ori
:
Optional
[
torch
.
Tensor
]
=
None
,
q_ori
:
Optional
[
torch
.
Tensor
]
=
None
,
key_normed
:
Optional
[
torch
.
Tensor
]
=
None
,
key_normed
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
0ce3b670
...
@@ -667,6 +667,8 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -667,6 +667,8 @@ class DeepseekV2MLAAttention(nn.Module):
k_pe
,
k_pe
,
output_shape
=
(
hidden_states
.
shape
[
0
],
output_shape
=
(
hidden_states
.
shape
[
0
],
self
.
num_local_heads
*
self
.
v_head_dim
),
self
.
num_local_heads
*
self
.
v_head_dim
),
query_nope
=
q
[...,
:
self
.
qk_nope_head_dim
],
num_local_heads
=
self
.
num_local_heads
,
q_ori
=
q
,
q_ori
=
q
,
key_normed
=
kv_c_normed
,
key_normed
=
kv_c_normed
,
positions
=
positions
,
positions
=
positions
,
...
@@ -715,6 +717,8 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -715,6 +717,8 @@ class DeepseekV2MLAAttention(nn.Module):
k_pe
,
k_pe
,
output_shape
=
(
hidden_states
.
shape
[
0
],
output_shape
=
(
hidden_states
.
shape
[
0
],
self
.
num_local_heads
*
self
.
v_head_dim
),
self
.
num_local_heads
*
self
.
v_head_dim
),
query_nope
=
q
[...,
:
self
.
qk_nope_head_dim
],
num_local_heads
=
self
.
num_local_heads
,
q_ori
=
q
,
q_ori
=
q
,
key_normed
=
kv_c_normed
,
key_normed
=
kv_c_normed
,
positions
=
positions
,
positions
=
positions
,
...
@@ -774,6 +778,8 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -774,6 +778,8 @@ class DeepseekV2MLAAttention(nn.Module):
k_pe
,
k_pe
,
output_shape
=
(
hidden_states
.
shape
[
0
],
output_shape
=
(
hidden_states
.
shape
[
0
],
self
.
num_local_heads
*
self
.
v_head_dim
),
self
.
num_local_heads
*
self
.
v_head_dim
),
query_nope
=
q
[...,
:
self
.
qk_nope_head_dim
],
num_local_heads
=
self
.
num_local_heads
,
q_ori
=
q
,
q_ori
=
q
,
key_normed
=
kv_c_normed
,
key_normed
=
kv_c_normed
,
positions
=
positions
,
positions
=
positions
,
...
...
vllm/model_executor/models/qwen3_moe.py
View file @
0ce3b670
...
@@ -22,19 +22,22 @@
...
@@ -22,19 +22,22 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
from
collections.abc
import
Iterable
import
typing
from
collections.abc
import
Callable
,
Iterable
from
itertools
import
islice
from
typing
import
Any
,
Optional
,
Union
from
typing
import
Any
,
Optional
,
Union
import
os
import
os
import
re
import
re
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
,
get_current_vllm_config
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
(
get_ep_group
,
get_pp_group
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_gather
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
...
@@ -48,17 +51,17 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
...
@@ -48,17 +51,17 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.models.utils
import
sequence_parallel_chunk
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsPP
from
.interfaces
import
MixtureOfExperts
,
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
extract_layer_index
,
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
extract_layer_index
,
is_pp_missing_parameter
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
maybe_prefix
)
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.utils
import
direct_register_custom_op
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
from
vllm.utils
import
W8a8GetCacheJSON
from
vllm.utils
import
W8a8GetCacheJSON
...
@@ -105,49 +108,86 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
...
@@ -105,49 +108,86 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
PretrainedConfig
,
vllm_config
:
VllmConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_text_config
parallel_config
=
vllm_config
.
parallel_config
quant_config
=
vllm_config
.
quant_config
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
ep_group
=
get_ep_group
().
device_group
self
.
ep_rank
=
self
.
ep_group
.
rank
()
self
.
ep_size
=
self
.
ep_group
.
size
()
self
.
n_routed_experts
=
config
.
num_experts
self
.
is_sequence_parallel
=
parallel_config
.
use_sequence_parallel_moe
if
self
.
tp_size
>
config
.
num_experts
:
if
self
.
tp_size
>
config
.
num_experts
:
raise
ValueError
(
raise
ValueError
(
f
"Tensor parallel size
{
self
.
tp_size
}
is greater than "
f
"Tensor parallel size
{
self
.
tp_size
}
is greater than "
f
"the number of experts
{
config
.
num_experts
}
."
)
f
"the number of experts
{
config
.
num_experts
}
."
)
self
.
experts
=
FusedMoE
(
num_experts
=
config
.
num_experts
,
# Load balancing settings.
vllm_config
=
get_current_vllm_config
()
eplb_config
=
vllm_config
.
parallel_config
.
eplb_config
self
.
enable_eplb
=
parallel_config
.
enable_eplb
self
.
n_logical_experts
=
self
.
n_routed_experts
self
.
n_redundant_experts
=
eplb_config
.
num_redundant_experts
self
.
n_physical_experts
=
(
self
.
n_logical_experts
+
self
.
n_redundant_experts
)
self
.
n_local_physical_experts
=
self
.
n_physical_experts
//
self
.
ep_size
self
.
physical_expert_start
=
(
self
.
ep_rank
*
self
.
n_local_physical_experts
)
self
.
physical_expert_end
=
(
self
.
physical_expert_start
+
self
.
n_local_physical_experts
)
self
.
experts
=
FusedMoE
(
num_experts
=
self
.
n_routed_experts
,
top_k
=
config
.
num_experts_per_tok
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
reduce_results
=
Fals
e
,
reduce_results
=
Tru
e
,
renormalize
=
config
.
norm_topk_prob
,
renormalize
=
config
.
norm_topk_prob
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.experts"
)
prefix
=
f
"
{
prefix
}
.experts"
,
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
n_redundant_experts
,
is_sequence_parallel
=
self
.
is_sequence_parallel
)
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
config
.
num_experts
,
config
.
num_experts
,
bias
=
False
,
bias
=
False
,
quant_config
=
None
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.gate"
)
prefix
=
f
"
{
prefix
}
.gate"
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# NOTE: hidden_states can have either 1D or 2D shape.
assert
hidden_states
.
dim
(
orig_shape
=
hidden_states
.
shape
)
<=
2
,
"Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
hidden_dim
=
hidden_states
.
shape
[
-
1
]
is_input_1d
=
hidden_states
.
dim
()
==
1
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
if
self
.
is_sequence_parallel
:
hidden_states
=
sequence_parallel_chunk
(
hidden_states
)
# router_logits: (num_tokens, n_experts)
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
router_logits
=
router_logits
)
if
self
.
tp_size
>
1
:
if
self
.
is_sequence_parallel
:
final_hidden_states
=
self
.
experts
.
maybe_all_reduce_tensor_model_parallel
(
# noqa E501
final_hidden_states
=
tensor_model_parallel_all_gather
(
final_hidden_states
)
final_hidden_states
,
0
)
final_hidden_states
=
final_hidden_states
[:
num_tokens
]
return
final_hidden_states
.
view
(
orig_shape
)
# return to 1d if input is 1d
return
final_hidden_states
.
squeeze
(
0
)
if
is_input_1d
else
\
final_hidden_states
class
Qwen3MoeAttention
(
nn
.
Module
):
class
Qwen3MoeAttention
(
nn
.
Module
):
...
@@ -166,6 +206,7 @@ class Qwen3MoeAttention(nn.Module):
...
@@ -166,6 +206,7 @@ class Qwen3MoeAttention(nn.Module):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
dual_chunk_attention_config
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -189,6 +230,7 @@ class Qwen3MoeAttention(nn.Module):
...
@@ -189,6 +230,7 @@ class Qwen3MoeAttention(nn.Module):
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
rope_theta
=
rope_theta
self
.
rope_theta
=
rope_theta
self
.
max_position_embeddings
=
max_position_embeddings
self
.
max_position_embeddings
=
max_position_embeddings
self
.
dual_chunk_attention_config
=
dual_chunk_attention_config
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
,
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
,
self
.
head_dim
,
self
.
head_dim
,
...
@@ -210,72 +252,25 @@ class Qwen3MoeAttention(nn.Module):
...
@@ -210,72 +252,25 @@ class Qwen3MoeAttention(nn.Module):
max_position
=
max_position_embeddings
,
max_position
=
max_position_embeddings
,
base
=
rope_theta
,
base
=
rope_theta
,
rope_scaling
=
rope_scaling
,
rope_scaling
=
rope_scaling
,
dual_chunk_attention_config
=
dual_chunk_attention_config
,
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
**
{
"layer_idx"
:
extract_layer_index
(
prefix
),
"dual_chunk_attention_config"
:
dual_chunk_attention_config
,
}
if
dual_chunk_attention_config
else
{},
)
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
)
self
.
q_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
self
.
q_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
self
.
k_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
self
.
k_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
def
rms_rotary_embedding_fuse
(
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
],
head_size
:
int
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox_style
:
bool
,
q_weight
:
torch
.
Tensor
,
k_weight
:
torch
.
Tensor
,
q_bias
:
Optional
[
torch
.
Tensor
],
k_bias
:
Optional
[
torch
.
Tensor
],
epsilon
:
float
,
)
->
None
:
from
lightop
import
rms_rotary_embedding_fuse
as
fused_kernel
fused_kernel
(
positions
,
query
,
key
,
head_size
,
cos_sin_cache
,
is_neox_style
,
q_weight
,
k_weight
,
q_bias
,
k_bias
,
epsilon
,
)
def
rms_rotary_embedding_fuse_fake
(
# q_out:torch.Tensor,
# k_out:torch.Tensor,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
],
head_size
:
int
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox_style
:
bool
,
q_weight
:
torch
.
Tensor
,
k_weight
:
torch
.
Tensor
,
q_bias
:
Optional
[
torch
.
Tensor
],
k_bias
:
Optional
[
torch
.
Tensor
],
epsilon
:
float
,
)
->
None
:
# Fake impl intentionally left as no-op for graph tracing modes.
pass
direct_register_custom_op
(
op_name
=
"rms_rotary_embedding_fuse"
,
op_func
=
rms_rotary_embedding_fuse
,
mutates_args
=
[
"query"
,
"key"
],
fake_impl
=
rms_rotary_embedding_fuse_fake
,
)
def
forward
(
def
forward
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
...
@@ -283,52 +278,23 @@ class Qwen3MoeAttention(nn.Module):
...
@@ -283,52 +278,23 @@ class Qwen3MoeAttention(nn.Module):
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
if
envs
.
VLLM_USE_FUSED_RMS_ROPE
:
# Add qk-norm
# Fused RMSNorm + RoPE path through custom op.
q_by_head
=
q
.
view
(
*
q
.
shape
[:
-
1
],
q
.
shape
[
-
1
]
//
self
.
head_dim
,
cos_sin_cache
=
self
.
rotary_emb
.
cos_sin_cache
self
.
head_dim
)
if
(
cos_sin_cache
.
device
!=
q
.
device
if
envs
.
VLLM_USE_APEX_RN
:
or
cos_sin_cache
.
dtype
!=
q
.
dtype
):
q_by_head
=
self
.
q_norm
.
forward_apex
(
q_by_head
)
cos_sin_cache
=
cos_sin_cache
.
to
(
q
.
device
,
dtype
=
q
.
dtype
,
non_blocking
=
True
)
# Persist the converted cache so we don't re-copy/re-allocate
# on every forward when the original buffer starts on CPU.
self
.
rotary_emb
.
cos_sin_cache
=
cos_sin_cache
# # q, k 使用 continuous
q
=
q
.
contiguous
()
k
=
k
.
contiguous
()
torch
.
ops
.
vllm
.
rms_rotary_embedding_fuse
(
positions
,
q
,
k
,
self
.
head_dim
,
cos_sin_cache
,
self
.
rotary_emb
.
is_neox_style
,
self
.
q_norm
.
weight
,
self
.
k_norm
.
weight
,
None
,
None
,
self
.
q_norm
.
variance_epsilon
,
)
else
:
else
:
# Add qk-norm then RoPE (original path).
q_by_head
=
self
.
q_norm
.
forward_cuda
(
q_by_head
)
q_by_head
=
q
.
view
(
*
q
.
shape
[:
-
1
],
q
.
shape
[
-
1
]
//
self
.
head_dim
,
q
=
q_by_head
.
view
(
q
.
shape
)
self
.
head_dim
)
if
envs
.
VLLM_USE_APEX_RN
:
q_by_head
=
self
.
q_norm
.
forward_apex
(
q_by_head
)
else
:
q_by_head
=
self
.
q_norm
.
forward_cuda
(
q_by_head
)
q
=
q_by_head
.
view
(
q
.
shape
)
k_by_head
=
k
.
view
(
*
k
.
shape
[:
-
1
],
k
.
shape
[
-
1
]
//
self
.
head_dim
,
k_by_head
=
k
.
view
(
*
k
.
shape
[:
-
1
],
k
.
shape
[
-
1
]
//
self
.
head_dim
,
self
.
head_dim
)
self
.
head_dim
)
if
envs
.
VLLM_USE_APEX_RN
:
if
envs
.
VLLM_USE_APEX_RN
:
k_by_head
=
self
.
k_norm
.
forward_apex
(
k_by_head
)
k_by_head
=
self
.
k_norm
.
forward_apex
(
k_by_head
)
else
:
else
:
k_by_head
=
self
.
k_norm
.
forward_cuda
(
k_by_head
)
k_by_head
=
self
.
k_norm
.
forward_cuda
(
k_by_head
)
k
=
k_by_head
.
view
(
k
.
shape
)
k
=
k_by_head
.
view
(
k
.
shape
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -336,19 +302,21 @@ class Qwen3MoeAttention(nn.Module):
...
@@ -336,19 +302,21 @@ class Qwen3MoeAttention(nn.Module):
class
Qwen3MoeDecoderLayer
(
nn
.
Module
):
class
Qwen3MoeDecoderLayer
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
)
->
None
:
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_text_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
max_position_embeddings
=
getattr
(
config
,
"max_position_embeddings"
,
max_position_embeddings
=
getattr
(
config
,
"max_position_embeddings"
,
8192
)
8192
)
dual_chunk_attention_config
=
getattr
(
config
,
"dual_chunk_attention_config"
,
None
)
self
.
self_attn
=
Qwen3MoeAttention
(
self
.
self_attn
=
Qwen3MoeAttention
(
hidden_size
=
self
.
hidden_size
,
hidden_size
=
self
.
hidden_size
,
num_heads
=
config
.
num_attention_heads
,
num_heads
=
config
.
num_attention_heads
,
...
@@ -362,6 +330,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
...
@@ -362,6 +330,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
dual_chunk_attention_config
=
dual_chunk_attention_config
,
)
)
# `mlp_only_layers` in the config.
# `mlp_only_layers` in the config.
...
@@ -371,8 +340,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
...
@@ -371,8 +340,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
if
(
layer_idx
not
in
mlp_only_layers
)
and
(
if
(
layer_idx
not
in
mlp_only_layers
)
and
(
config
.
num_experts
>
0
and
config
.
num_experts
>
0
and
(
layer_idx
+
1
)
%
config
.
decoder_sparse_step
==
0
):
(
layer_idx
+
1
)
%
config
.
decoder_sparse_step
==
0
):
self
.
mlp
=
Qwen3MoeSparseMoeBlock
(
config
=
config
,
self
.
mlp
=
Qwen3MoeSparseMoeBlock
(
vllm_config
=
vllm_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
)
prefix
=
f
"
{
prefix
}
.mlp"
)
else
:
else
:
self
.
mlp
=
Qwen3MoeMLP
(
hidden_size
=
config
.
hidden_size
,
self
.
mlp
=
Qwen3MoeMLP
(
hidden_size
=
config
.
hidden_size
,
...
@@ -416,9 +384,11 @@ class Qwen3MoeModel(nn.Module):
...
@@ -416,9 +384,11 @@ class Qwen3MoeModel(nn.Module):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_text_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
parallel_config
=
vllm_config
.
parallel_config
eplb_config
=
parallel_config
.
eplb_config
self
.
num_redundant_experts
=
eplb_config
.
num_redundant_experts
self
.
padding_idx
=
config
.
pad_token_id
self
.
padding_idx
=
config
.
pad_token_id
self
.
vocab_size
=
config
.
vocab_size
self
.
vocab_size
=
config
.
vocab_size
...
@@ -433,12 +403,11 @@ class Qwen3MoeModel(nn.Module):
...
@@ -433,12 +403,11 @@ class Qwen3MoeModel(nn.Module):
self
.
embed_tokens
=
VocabParallelEmbedding
(
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.embed_tokens"
)
prefix
=
f
"
{
prefix
}
.embed_tokens"
)
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
config
.
num_hidden_layers
,
lambda
prefix
:
Qwen3MoeDecoderLayer
(
config
=
config
,
lambda
prefix
:
Qwen3MoeDecoderLayer
(
vllm_config
=
vllm_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
prefix
),
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
,
prefix
=
f
"
{
prefix
}
.layers"
,
)
)
...
@@ -475,8 +444,7 @@ class Qwen3MoeModel(nn.Module):
...
@@ -475,8 +444,7 @@ class Qwen3MoeModel(nn.Module):
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
islice
(
self
.
layers
,
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
return
IntermediateTensors
({
...
@@ -486,6 +454,16 @@ class Qwen3MoeModel(nn.Module):
...
@@ -486,6 +454,16 @@ class Qwen3MoeModel(nn.Module):
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
return
hidden_states
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
num_experts
,
num_redundant_experts
=
self
.
num_redundant_experts
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
stacked_params_mapping
=
[
...
@@ -502,16 +480,9 @@ class Qwen3MoeModel(nn.Module):
...
@@ -502,16 +480,9 @@ class Qwen3MoeModel(nn.Module):
".v_scale"
,
"_v_scale"
,
".weight_scale"
,
".v_scale"
,
"_v_scale"
,
".weight_scale"
,
"_weight_scale"
,
".input_scale"
,
"_input_scale"
)
"_weight_scale"
,
".input_scale"
,
"_input_scale"
)
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
num_experts
)
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
loaded_params
:
set
[
str
]
=
set
()
expert_params_mapping
=
self
.
get_expert_mapping
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
self
.
use_llama_nn
:
if
self
.
use_llama_nn
:
current_count
=
loaded_weight
.
current_count
current_count
=
loaded_weight
.
current_count
...
@@ -537,35 +508,68 @@ class Qwen3MoeModel(nn.Module):
...
@@ -537,35 +508,68 @@ class Qwen3MoeModel(nn.Module):
# Skip layers on other devices.
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
if
is_pp_missing_parameter
(
name
,
self
):
continue
continue
if
name
.
endswith
(
"scale"
):
# Remapping the name of FP8 kv-scale.
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
continue
if
name
not
in
params_dict
:
if
name
not
in
params_dict
:
continue
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
(
param
,
loaded_weight
,
shard_id
)
default_weight_loader
)
if
weight_loader
==
default_weight_loader
:
weight_loader
(
param
,
loaded_weight
)
else
:
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
break
else
:
else
:
is_expert_weight
=
False
for
mapping
in
expert_params_mapping
:
for
mapping
in
expert_params_mapping
:
param_name
,
weight_name
,
expert_id
,
shard_id
=
mapping
param_name
,
weight_name
,
expert_id
,
shard_id
=
mapping
if
weight_name
not
in
name
:
if
weight_name
not
in
name
:
continue
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip layers on other devices.
# Anyway, this is an expert weight and should not be
if
is_pp_missing_parameter
(
name
,
self
):
# attempted to load as other weights later
is_expert_weight
=
True
# Do not modify `name` since the loop may continue here
# Instead, create a new variable
name_mapped
=
name
.
replace
(
weight_name
,
param_name
)
if
is_pp_missing_parameter
(
name_mapped
,
self
):
continue
continue
# Skip loading extra parameters for GPTQ/modelopt models.
# Skip loading extra parameters for GPTQ/modelopt models.
if
name
.
endswith
(
if
name_mapped
.
endswith
(
ignore_suffixes
)
and
name
not
in
params_dict
:
ignore_suffixes
)
and
name_mapped
not
in
params_dict
:
continue
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
param
=
params_dict
[
name_mapped
]
weight_loader
(
param
,
# We should ask the weight loader to return success or not
loaded_weight
,
# here since otherwise we may skip experts with other
name
,
# available replicas.
shard_id
=
shard_id
,
weight_loader
=
typing
.
cast
(
Callable
[...,
bool
],
expert_id
=
expert_id
)
param
.
weight_loader
)
break
success
=
weight_loader
(
param
,
loaded_weight
,
name_mapped
,
shard_id
=
shard_id
,
expert_id
=
expert_id
,
return_success
=
True
)
if
success
:
name
=
name_mapped
break
else
:
else
:
if
is_expert_weight
:
# We've checked that this is an expert weight
# However it's not mapped locally to this rank
# So we simply skip it
continue
# Skip loading extra parameters for GPTQ/modelopt models.
# Skip loading extra parameters for GPTQ/modelopt models.
if
name
.
endswith
(
if
name
.
endswith
(
ignore_suffixes
)
and
name
not
in
params_dict
:
ignore_suffixes
)
and
name
not
in
params_dict
:
...
@@ -635,7 +639,8 @@ class Qwen3MoeModel(nn.Module):
...
@@ -635,7 +639,8 @@ class Qwen3MoeModel(nn.Module):
return
loaded_params
return
loaded_params
class
Qwen3MoeForCausalLM
(
nn
.
Module
,
SupportsPP
):
class
Qwen3MoeForCausalLM
(
nn
.
Module
,
SupportsPP
,
SupportsLoRA
,
MixtureOfExperts
):
packed_modules_mapping
=
{
packed_modules_mapping
=
{
"qkv_proj"
:
[
"qkv_proj"
:
[
"q_proj"
,
"q_proj"
,
...
@@ -652,7 +657,7 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
...
@@ -652,7 +657,7 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_
text_
config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
...
@@ -660,13 +665,74 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
...
@@ -660,13 +665,74 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
prefix
=
maybe_prefix
(
prefix
,
"model"
))
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
))
if
self
.
config
.
tie_word_embeddings
:
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
self
.
model
.
make_empty_intermediate_tensors
)
# Set MoE hyperparameters
self
.
expert_weights
=
[]
self
.
moe_layers
:
list
[
FusedMoE
]
=
[]
example_layer
=
None
for
layer
in
self
.
model
.
layers
:
if
isinstance
(
layer
,
PPMissingLayer
):
continue
assert
isinstance
(
layer
,
Qwen3MoeDecoderLayer
)
if
isinstance
(
layer
.
mlp
,
Qwen3MoeSparseMoeBlock
):
example_layer
=
layer
.
mlp
self
.
moe_layers
.
append
(
layer
.
mlp
.
experts
)
if
example_layer
is
None
:
raise
RuntimeError
(
"No Qwen3MoE layer found in the model.layers."
)
self
.
num_moe_layers
=
len
(
self
.
moe_layers
)
self
.
num_expert_groups
=
1
self
.
num_shared_experts
=
0
self
.
num_logical_experts
=
example_layer
.
n_logical_experts
self
.
num_physical_experts
=
example_layer
.
n_physical_experts
self
.
num_local_physical_experts
=
example_layer
.
n_local_physical_experts
self
.
num_routed_experts
=
example_layer
.
n_routed_experts
self
.
num_redundant_experts
=
example_layer
.
n_redundant_experts
def
set_eplb_state
(
self
,
expert_load_view
:
torch
.
Tensor
,
logical_to_physical_map
:
torch
.
Tensor
,
logical_replica_count
:
torch
.
Tensor
,
)
->
None
:
for
layer_idx
,
layer
in
enumerate
(
self
.
moe_layers
):
# Register the expert weights.
self
.
expert_weights
.
append
(
layer
.
get_expert_weights
())
layer
.
set_eplb_state
(
moe_layer_idx
=
layer_idx
,
expert_load_view
=
expert_load_view
,
logical_to_physical_map
=
logical_to_physical_map
,
logical_replica_count
=
logical_replica_count
,
)
def
update_physical_experts_metadata
(
self
,
num_physical_experts
:
int
,
num_local_physical_experts
:
int
,
)
->
None
:
assert
self
.
num_local_physical_experts
==
num_local_physical_experts
self
.
num_physical_experts
=
num_physical_experts
self
.
num_local_physical_experts
=
num_local_physical_experts
self
.
num_redundant_experts
=
(
num_physical_experts
-
self
.
num_logical_experts
)
for
layer
in
self
.
model
.
layers
:
if
isinstance
(
layer
.
mlp
,
Qwen3MoeSparseMoeBlock
):
moe
=
layer
.
mlp
moe
.
n_local_physical_experts
=
num_local_physical_experts
moe
.
n_physical_experts
=
num_physical_experts
moe
.
n_redundant_experts
=
self
.
num_redundant_experts
moe
.
experts
.
update_expert_map
()
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
return
self
.
model
.
get_input_embeddings
(
input_ids
)
...
@@ -684,13 +750,14 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
...
@@ -684,13 +750,14 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
def
compute_logits
(
def
compute_logits
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
)
sampling_metadata
)
return
logits
return
logits
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
torch
.
Tensor
]])
->
set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
return
loader
.
load_weights
(
weights
)
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
return
self
.
model
.
get_expert_mapping
()
vllm/v1/attention/backends/mla/common.py
View file @
0ce3b670
...
@@ -217,6 +217,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
...
@@ -217,6 +217,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata
)
CommonAttentionMetadata
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.block_table
import
BlockTable
from
lightop
import
fused_rms_norm_rope_contiguous
,
fuse_rmsnorm_rope_quant_gfx938
try
:
try
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
...
@@ -1095,6 +1096,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1095,6 +1096,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
attn_metadata
:
M
,
attn_metadata
:
M
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
query_nope
:
Optional
[
torch
.
Tensor
]
=
None
,
num_local_heads
:
Optional
[
int
]
=
None
,
q_ori
:
Optional
[
torch
.
Tensor
]
=
None
,
q_ori
:
Optional
[
torch
.
Tensor
]
=
None
,
key_normed
:
Optional
[
torch
.
Tensor
]
=
None
,
key_normed
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1154,7 +1157,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1154,7 +1157,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
scale
=
layer
.
_k_scale
,
scale
=
layer
.
_k_scale
,
)
)
else
:
else
:
from
lightop
import
fused_rms_norm_rope_contiguous
if
self
.
kv_cache_dtype
==
"auto"
:
if
self
.
kv_cache_dtype
==
"auto"
:
if
q
.
dtype
==
torch
.
float16
:
if
q
.
dtype
==
torch
.
float16
:
kv_cache_dtype_str
=
"fp16"
kv_cache_dtype_str
=
"fp16"
...
@@ -1162,22 +1164,61 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1162,22 +1164,61 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_cache_dtype_str
=
"bf16"
kv_cache_dtype_str
=
"bf16"
else
:
else
:
kv_cache_dtype_str
=
self
.
kv_cache_dtype
kv_cache_dtype_str
=
self
.
kv_cache_dtype
fused_rms_norm_rope_contiguous
(
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
kv_cache_dtype_str
==
"fp8_e4m3"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
positions
[:
num_actual_toks
,
...],
if
has_prefill
:
q
,
fused_rms_norm_rope_contiguous
(
k_pe
.
squeeze
(
1
),
positions
[:
num_actual_toks
,
...],
k_c_normed
,
# not normed
q
,
key_normed
[:
num_actual_toks
,
...],
# normed
k_pe
.
squeeze
(
1
),
weight
,
k_c_normed
,
# not normed
cos_sin_cache
,
key_normed
[:
num_actual_toks
,
...],
# normed
attn_metadata
.
slot_mapping
.
flatten
(),
weight
,
kv_cache
,
cos_sin_cache
,
kv_cache_dtype_str
,
attn_metadata
.
slot_mapping
.
flatten
(),
1.0
,
kv_cache
,
False
,
kv_cache_dtype_str
,
1e-6
,
1.0
,
)
False
,
1e-6
,
)
else
:
q_tensor
=
torch
.
randn
(
q
.
shape
[
0
],
num_local_heads
,
self
.
qk_nope_head_dim
+
self
.
qk_rope_head_dim
,
dtype
=
q
.
dtype
,
device
=
q
.
device
)
q_quant_gt
=
q_tensor
.
to
(
kv_cache_dtype_str
)
q_quant
=
torch
.
empty_like
(
q_quant_gt
)
fuse_rmsnorm_rope_quant_gfx938
(
positions
[:
num_actual_toks
,
...],
query_nope
,
q
,
q_quant
,
k_pe
.
squeeze
(
1
),
k_c_normed
,
# not normed
key_normed
[:
num_actual_toks
,
...],
# normed
weight
,
cos_sin_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
kv_cache
,
kv_cache_dtype_str
,
1.0
,
False
,
1e-6
,
)
else
:
fused_rms_norm_rope_contiguous
(
positions
[:
num_actual_toks
,
...],
q
,
k_pe
.
squeeze
(
1
),
k_c_normed
,
# not normed
key_normed
[:
num_actual_toks
,
...],
# normed
weight
,
cos_sin_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
kv_cache
,
kv_cache_dtype_str
,
1.0
,
False
,
1e-6
,
)
if
has_prefill
:
if
has_prefill
:
if
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
if
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
0ce3b670
...
@@ -179,7 +179,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -179,7 +179,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
attn_metadata
.
decode
is
not
None
assert
attn_metadata
.
decode
is
not
None
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
kv_cache_dtype
==
"fp8_e4m3"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
if
envs
.
VLLM_USE_OPT_CAT
:
if
envs
.
VLLM_USE_OPT_CAT
:
if
q_nope
.
shape
[
0
]
<
1024
:
if
q_nope
.
shape
[
0
]
<
1024
:
from
vllm.v1.attention.backends.mla.test_concat
import
concat_helper_decode
from
vllm.v1.attention.backends.mla.test_concat
import
concat_helper_decode
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment