Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
65e1d22d
Commit
65e1d22d
authored
Feb 06, 2026
by
zhuwenwen
Browse files
fix moe run error
parent
661623f0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
103 additions
and
4 deletions
+103
-4
vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
.../model_executor/layers/fused_moe/fused_moe_method_base.py
+102
-1
vllm/model_executor/layers/fused_moe/modular_kernel.py
vllm/model_executor/layers/fused_moe/modular_kernel.py
+1
-3
No files found.
vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
View file @
65e1d22d
...
...
@@ -25,4 +25,105 @@ class FusedMoEMethodBase(QuantizeMethodBase):
def
__init__
(
self
,
moe
:
FusedMoEConfig
):
super
().
__init__
()
self
.
moe
:
FusedMoEConfig
=
moe
self
.
moe_quant_config
:
FusedMoEQuantConfig
|
None
=
None
\ No newline at end of file
self
.
moe_quant_config
:
FusedMoEQuantConfig
|
None
=
None
@
abstractmethod
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
raise
NotImplementedError
def
uses_weight_scale_2_pattern
(
self
)
->
bool
:
"""
Returns True if this quantization method uses 'weight_scale_2' pattern
for per-tensor weight scales (e.g., FP4 variants), False otherwise.
This method should be overridden by subclasses that use the
'weight_scale_2' pattern instead of the standard 'weight_scale' pattern.
"""
return
False
def
maybe_make_prepare_finalize
(
self
,
routing_tables
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
,
)
->
FusedMoEPrepareAndFinalize
|
None
:
from
.all2all_utils
import
maybe_make_prepare_finalize
return
maybe_make_prepare_finalize
(
self
.
moe
,
self
.
moe_quant_config
,
routing_tables
)
def
select_gemm_impl
(
self
,
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
layer
:
torch
.
nn
.
Module
,
)
->
FusedMoEPermuteExpertsUnpermute
:
# based on the all2all implementation, select the appropriate
# gemm implementation
raise
NotImplementedError
(
f
"
{
self
.
__class__
.
__name__
}
must select appropriate gemm "
"implementation based on the prepare_finalize"
)
def
prepare_dp_allgather_tensor
(
self
,
layer
:
"FusedMoE"
,
# type: ignore[name-defined] # noqa: F821
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]:
"""Hook to prepare tensors and extra tensors for DP allgather + EP dispatch."""
raise
NotImplementedError
(
"Method 'prepare_dp_allgather_tensor' is not implemented in "
f
"
{
self
.
__class__
.
__name__
}
."
)
@
abstractmethod
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
)
->
FusedMoEQuantConfig
|
None
:
raise
NotImplementedError
@
property
def
topk_indices_dtype
(
self
)
->
torch
.
dtype
|
None
:
return
None
@
property
def
supports_eplb
(
self
)
->
bool
:
return
False
@
property
def
allow_inplace
(
self
)
->
bool
:
return
False
@
property
def
method_name
(
self
)
->
str
:
return
self
.
__class__
.
__name__
@
property
def
is_monolithic
(
self
)
->
bool
:
return
False
# @abstractmethod
def
apply
(
self
,
layer
:
"FusedMoE"
,
# type: ignore[name-defined] # noqa: F821
x
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
raise
NotImplementedError
# @abstractmethod
def
apply_monolithic
(
self
,
layer
:
"FusedMoE"
,
# type: ignore[name-defined] # noqa: F821
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
raise
NotImplementedError
\ No newline at end of file
vllm/model_executor/layers/fused_moe/modular_kernel.py
View file @
65e1d22d
...
...
@@ -1069,7 +1069,6 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map
,
apply_router_weight_on_input
,
self
.
fused_experts
.
quant_config
,
defer_input_quant
=
self
.
fused_experts
.
expects_unquantized_inputs
,
)
else
:
# Overlap shared expert compute with all2all dispatch.
...
...
@@ -1082,7 +1081,6 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map
,
apply_router_weight_on_input
,
self
.
fused_experts
.
quant_config
,
defer_input_quant
=
self
.
fused_experts
.
expects_unquantized_inputs
,
)
# TODO(lucas): refactor this in the alternative schedules followup
...
...
@@ -1361,4 +1359,4 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
)
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment