Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e9b92dcd
Unverified
Commit
e9b92dcd
authored
Sep 03, 2025
by
bnellnm
Committed by
GitHub
Sep 03, 2025
Browse files
[Kernels] Overlap shared experts with send/recv (#23273)
Signed-off-by:
Bill Nell
<
bnell@redhat.com
>
parent
fa4311d8
Changes
32
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
157 additions
and
66 deletions
+157
-66
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+1
-1
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+2
-2
vllm/model_executor/layers/quantization/moe_wna16.py
vllm/model_executor/layers/quantization/moe_wna16.py
+2
-2
vllm/model_executor/layers/quantization/mxfp4.py
vllm/model_executor/layers/quantization/mxfp4.py
+2
-2
vllm/model_executor/layers/quantization/quark/quark_moe.py
vllm/model_executor/layers/quantization/quark/quark_moe.py
+3
-3
vllm/model_executor/layers/quantization/rtn.py
vllm/model_executor/layers/quantization/rtn.py
+2
-2
vllm/model_executor/layers/shared_fused_moe/__init__.py
vllm/model_executor/layers/shared_fused_moe/__init__.py
+6
-0
vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py
...odel_executor/layers/shared_fused_moe/shared_fused_moe.py
+56
-0
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+63
-40
vllm/model_executor/models/glm4_moe.py
vllm/model_executor/models/glm4_moe.py
+2
-0
vllm/model_executor/models/llama4.py
vllm/model_executor/models/llama4.py
+16
-13
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+2
-1
No files found.
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
e9b92dcd
...
@@ -654,7 +654,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
...
@@ -654,7 +654,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
:
assert
self
.
fused_experts
is
None
assert
self
.
fused_experts
is
None
if
enable_eplb
:
if
enable_eplb
:
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
e9b92dcd
...
@@ -491,7 +491,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
...
@@ -491,7 +491,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
:
if
enable_eplb
:
if
enable_eplb
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"EPLB not supported for `ModelOptFp8MoEMethod` yet."
)
"EPLB not supported for `ModelOptFp8MoEMethod` yet."
)
...
@@ -1366,7 +1366,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
...
@@ -1366,7 +1366,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
):
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
:
if
enable_eplb
:
if
enable_eplb
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
)
"EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
)
...
...
vllm/model_executor/layers/quantization/moe_wna16.py
View file @
e9b92dcd
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
,
Callable
,
Optional
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
torch
import
torch
...
@@ -305,7 +305,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
...
@@ -305,7 +305,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
:
assert
self
.
fused_experts
is
None
assert
self
.
fused_experts
is
None
if
enable_eplb
:
if
enable_eplb
:
raise
NotImplementedError
(
raise
NotImplementedError
(
...
...
vllm/model_executor/layers/quantization/mxfp4.py
View file @
e9b92dcd
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Callable
,
Optional
from
typing
import
Callable
,
Optional
,
Union
import
torch
import
torch
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
...
@@ -554,7 +554,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
...
@@ -554,7 +554,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
:
if
enable_eplb
:
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB is not supported for mxfp4"
)
raise
NotImplementedError
(
"EPLB is not supported for mxfp4"
)
...
...
vllm/model_executor/layers/quantization/quark/quark_moe.py
View file @
e9b92dcd
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
,
Callable
,
Optional
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
torch
import
torch
...
@@ -226,7 +226,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
...
@@ -226,7 +226,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
:
assert
self
.
fused_experts
is
None
assert
self
.
fused_experts
is
None
if
enable_eplb
:
if
enable_eplb
:
...
@@ -390,7 +390,7 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
...
@@ -390,7 +390,7 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
:
assert
self
.
fused_experts
is
None
assert
self
.
fused_experts
is
None
if
enable_eplb
:
if
enable_eplb
:
...
...
vllm/model_executor/layers/quantization/rtn.py
View file @
e9b92dcd
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
# Copyright © 2025, Oracle and/or its affiliates.
# Copyright © 2025, Oracle and/or its affiliates.
import
os
import
os
from
typing
import
Any
,
Callable
,
Optional
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
...
@@ -291,7 +291,7 @@ class RTNMoEMethod(FusedMoEMethodBase):
...
@@ -291,7 +291,7 @@ class RTNMoEMethod(FusedMoEMethodBase):
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
:
assert
self
.
fused_experts
is
None
assert
self
.
fused_experts
is
None
if
enable_eplb
:
if
enable_eplb
:
...
...
vllm/model_executor/layers/shared_fused_moe/__init__.py
0 → 100644
View file @
e9b92dcd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm.model_executor.layers.shared_fused_moe.shared_fused_moe
import
(
SharedFusedMoE
)
__all__
=
[
"SharedFusedMoE"
]
vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py
0 → 100644
View file @
e9b92dcd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
import
torch
from
vllm.distributed
import
tensor_model_parallel_all_reduce
from
vllm.model_executor.layers.fused_moe.layer
import
FusedMoE
# TODO(bnell): Add shared + fused combo function? e.g. +
class SharedFusedMoE(FusedMoE):
    """
    FusedMoE layer that also produces the output of a shared-experts module.

    When an all2all communicator is in use, the shared expert computation
    can be interleaved with the fused all2all dispatch communication step.
    The ``use_overlapped`` flag selects between handing the shared experts
    to the parent layer (overlapped) and running them directly in
    ``forward`` before the fused experts (non-overlapped).
    """

    def __init__(
        self,
        shared_experts: torch.nn.Module,
        use_overlapped: bool = True,
        **kwargs,
    ):
        # All remaining keyword arguments are forwarded untouched to FusedMoE.
        super().__init__(**kwargs)
        self._shared_experts = shared_experts
        self.use_overlapped = use_overlapped

    @property
    def shared_experts(self) -> Optional[torch.nn.Module]:
        """Return the shared-experts module, or None when not overlapping."""
        if self.use_overlapped:
            return self._shared_experts
        return None

    def forward(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(shared_out, fused_out)`` for the given inputs."""
        if self.use_overlapped:
            # The parent forward is expected to yield both outputs here.
            shared_result, routed_result = super().forward(
                hidden_states=hidden_states,
                router_logits=router_logits,
            )
            return shared_result, routed_result

        # Non-overlapped path: run the shared experts first, then the
        # fused experts.
        shared_result = self._shared_experts(hidden_states)

        # Reduce outputs if necessary, since the MLP should
        # have been created with reduce_results=False.
        needs_reduce = (self.reduce_results and self.tp_size > 1
                        and self.must_reduce_shared_expert_outputs())
        if needs_reduce:
            shared_result = tensor_model_parallel_all_reduce(shared_result)

        routed_result = super().forward(
            hidden_states=hidden_states,
            router_logits=router_logits,
        )
        return shared_result, routed_result
vllm/model_executor/models/deepseek_v2.py
View file @
e9b92dcd
...
@@ -49,6 +49,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -49,6 +49,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.shared_fused_moe
import
SharedFusedMoE
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
...
@@ -147,63 +148,85 @@ class DeepseekV2MoE(nn.Module):
...
@@ -147,63 +148,85 @@ class DeepseekV2MoE(nn.Module):
self
.
physical_expert_end
=
(
self
.
physical_expert_start
+
self
.
physical_expert_end
=
(
self
.
physical_expert_start
+
self
.
n_local_physical_experts
)
self
.
n_local_physical_experts
)
self
.
experts
=
FusedMoE
(
if
config
.
n_shared_experts
is
None
:
num_experts
=
config
.
n_routed_experts
,
self
.
experts
=
FusedMoE
(
top_k
=
config
.
num_experts_per_tok
,
num_experts
=
config
.
n_routed_experts
,
hidden_size
=
config
.
hidden_size
,
top_k
=
config
.
num_experts_per_tok
,
intermediate_size
=
config
.
moe_intermediate_size
,
hidden_size
=
config
.
hidden_size
,
reduce_results
=
False
,
intermediate_size
=
config
.
moe_intermediate_size
,
renormalize
=
config
.
norm_topk_prob
,
reduce_results
=
False
,
quant_config
=
quant_config
,
renormalize
=
config
.
norm_topk_prob
,
use_grouped_topk
=
True
,
quant_config
=
quant_config
,
num_expert_group
=
config
.
n_group
,
use_grouped_topk
=
True
,
topk_group
=
config
.
topk_group
,
num_expert_group
=
config
.
n_group
,
prefix
=
f
"
{
prefix
}
.experts"
,
topk_group
=
config
.
topk_group
,
scoring_func
=
config
.
scoring_func
,
prefix
=
f
"
{
prefix
}
.experts"
,
# we do scaling outside, set factor to 1.0 to avoid double mul
scoring_func
=
config
.
scoring_func
,
routed_scaling_factor
=
1.0
,
# we do scaling outside, set factor to 1.0 to avoid double mul
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,
routed_scaling_factor
=
1.0
,
enable_eplb
=
self
.
enable_eplb
,
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,
num_redundant_experts
=
self
.
n_redundant_experts
)
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
n_redundant_experts
)
if
config
.
n_shared_experts
is
not
None
:
self
.
shared_experts
=
None
else
:
intermediate_size
=
(
config
.
moe_intermediate_size
*
intermediate_size
=
(
config
.
moe_intermediate_size
*
config
.
n_shared_experts
)
config
.
n_shared_experts
)
self
.
shared_experts
=
DeepseekV2MLP
(
self
.
shared_experts
=
DeepseekV2MLP
(
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size
,
intermediate_size
=
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
reduce_results
=
self
.
experts
.
must_reduce_shared_expert_outputs
(
reduce_results
=
False
,
),
prefix
=
f
"
{
prefix
}
.shared_experts"
,
prefix
=
f
"
{
prefix
}
.shared_experts"
,
)
)
self
.
experts
=
SharedFusedMoE
(
shared_experts
=
self
.
shared_experts
,
num_experts
=
config
.
n_routed_experts
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
reduce_results
=
False
,
renormalize
=
config
.
norm_topk_prob
,
quant_config
=
quant_config
,
use_grouped_topk
=
True
,
num_expert_group
=
config
.
n_group
,
topk_group
=
config
.
topk_group
,
prefix
=
f
"
{
prefix
}
.experts"
,
scoring_func
=
config
.
scoring_func
,
# we do scaling outside, set factor to 1.0 to avoid double mul
routed_scaling_factor
=
1.0
,
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
n_redundant_experts
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
if
self
.
n_shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
# router_logits: (num_tokens, n_experts)
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
if
hidden_states
.
dtype
!=
torch
.
float16
:
fused_moe_out
=
self
.
experts
(
hidden_states
=
hidden_states
,
final_hidden_states
=
self
.
experts
(
router_logits
=
router_logits
)
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
*
self
.
routed_scaling_factor
if
self
.
shared_experts
is
not
None
:
shared_output
,
final_hidden_states
=
fused_moe_out
else
:
else
:
# Fix FP16 overflow
shared_output
=
None
# See DeepseekV2DecoderLayer for more details.
final_hidden_states
=
fused_moe_out
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
# Fix FP16 overflow
if
shared_output
is
not
None
:
# See DeepseekV2DecoderLayer for more details.
if
hidden_states
.
dtype
!=
torch
.
float16
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
final_hidden_states
=
final_hidden_states
+
shared_output
final_hidden_states
*=
self
.
routed_scaling_factor
else
:
elif
self
.
shared_experts
is
not
None
:
# Fix FP16 overflow
assert
shared_output
is
not
None
# See DeepseekV2DecoderLayer for more details.
shared_output
*=
(
1.
/
self
.
routed_scaling_factor
)
final_hidden_states
=
final_hidden_states
+
shared_output
\
*
(
1.
/
self
.
routed_scaling_factor
)
if
self
.
shared_experts
is
not
None
:
assert
shared_output
is
not
None
final_hidden_states
+=
shared_output
if
self
.
tp_size
>
1
:
if
self
.
tp_size
>
1
:
final_hidden_states
=
(
final_hidden_states
=
(
...
...
vllm/model_executor/models/glm4_moe.py
View file @
e9b92dcd
...
@@ -184,6 +184,8 @@ class Glm4MoE(nn.Module):
...
@@ -184,6 +184,8 @@ class Glm4MoE(nn.Module):
if
self
.
n_shared_experts
is
not
None
:
if
self
.
n_shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
shared_output
=
self
.
shared_experts
(
hidden_states
)
else
:
shared_output
=
None
router_logits
=
self
.
gate
(
hidden_states
.
to
(
dtype
=
torch
.
float32
))
router_logits
=
self
.
gate
(
hidden_states
.
to
(
dtype
=
torch
.
float32
))
final_hidden_states
=
self
.
experts
(
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
...
...
vllm/model_executor/models/llama4.py
View file @
e9b92dcd
...
@@ -36,6 +36,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
...
@@ -36,6 +36,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.shared_fused_moe
import
SharedFusedMoE
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
default_weight_loader
,
maybe_remap_kv_scale_name
)
...
@@ -73,7 +74,18 @@ class Llama4MoE(nn.Module):
...
@@ -73,7 +74,18 @@ class Llama4MoE(nn.Module):
quant_config
=
None
,
quant_config
=
None
,
prefix
=
f
"
{
prefix
}
.router"
)
prefix
=
f
"
{
prefix
}
.router"
)
self
.
experts
=
FusedMoE
(
self
.
shared_expert
=
LlamaMLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size_moe
,
hidden_act
=
"silu"
,
quant_config
=
quant_config
,
bias
=
False
,
prefix
=
f
"
{
prefix
}
.shared_expert"
,
reduce_results
=
False
,
)
self
.
experts
=
SharedFusedMoE
(
shared_experts
=
self
.
shared_expert
,
num_experts
=
config
.
num_local_experts
,
num_experts
=
config
.
num_local_experts
,
top_k
=
config
.
num_experts_per_tok
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
...
@@ -83,22 +95,13 @@ class Llama4MoE(nn.Module):
...
@@ -83,22 +95,13 @@ class Llama4MoE(nn.Module):
reduce_results
=
False
,
reduce_results
=
False
,
renormalize
=
False
,
renormalize
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.experts"
)
prefix
=
f
"
{
prefix
}
.experts"
,
self
.
shared_expert
=
LlamaMLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size_moe
,
hidden_act
=
"silu"
,
quant_config
=
quant_config
,
bias
=
False
,
prefix
=
f
"
{
prefix
}
.shared_expert"
,
reduce_results
=
self
.
experts
.
must_reduce_shared_expert_outputs
(),
)
)
def
forward
(
self
,
hidden_states
):
def
forward
(
self
,
hidden_states
):
router_logits
,
_
=
self
.
router
(
hidden_states
)
router_logits
,
_
=
self
.
router
(
hidden_states
)
shared_out
=
self
.
shared_expert
(
hidden_states
)
routed_out
=
self
.
experts
(
shared_out
,
routed_out
=
self
.
experts
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
router_logits
=
router_logits
,
)
)
...
...
vllm/v1/worker/gpu_worker.py
View file @
e9b92dcd
...
@@ -500,7 +500,8 @@ class Worker(WorkerBase):
...
@@ -500,7 +500,8 @@ class Worker(WorkerBase):
parallel_config
=
self
.
vllm_config
.
parallel_config
parallel_config
=
self
.
vllm_config
.
parallel_config
moe_modules
=
[
moe_modules
=
[
module
for
module
in
self
.
model_runner
.
model
.
modules
()
module
for
module
in
self
.
model_runner
.
model
.
modules
()
if
module
.
__class__
.
__name__
==
"FusedMoE"
if
(
module
.
__class__
.
__name__
==
"FusedMoE"
or
module
.
__class__
.
__name__
==
"SharedFusedMoE"
)
]
]
num_local_experts
=
moe_modules
[
0
].
moe_config
.
num_local_experts
num_local_experts
=
moe_modules
[
0
].
moe_config
.
num_local_experts
assert
all
(
module
.
moe_config
.
num_local_experts
==
num_local_experts
assert
all
(
module
.
moe_config
.
num_local_experts
==
num_local_experts
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment