Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
327a02d8
Unverified
Commit
327a02d8
authored
Jan 18, 2026
by
bnellnm
Committed by
GitHub
Jan 18, 2026
Browse files
[MoE Refactor] Separate Router into OO Classes (#30623)
Signed-off-by:
Bill Nell
<
bnell@redhat.com
>
parent
2f03035a
Changes
45
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
642 additions
and
655 deletions
+642
-655
tests/kernels/moe/modular_kernel_tools/common.py
tests/kernels/moe/modular_kernel_tools/common.py
+1
-1
tests/kernels/moe/test_batched_moe.py
tests/kernels/moe/test_batched_moe.py
+1
-1
tests/kernels/moe/test_block_fp8.py
tests/kernels/moe/test_block_fp8.py
+4
-2
tests/kernels/moe/test_cutlass_moe.py
tests/kernels/moe/test_cutlass_moe.py
+1
-1
tests/kernels/moe/test_flashinfer_moe.py
tests/kernels/moe/test_flashinfer_moe.py
+1
-1
tests/kernels/moe/test_grouped_topk.py
tests/kernels/moe/test_grouped_topk.py
+1
-1
tests/kernels/moe/test_moe.py
tests/kernels/moe/test_moe.py
+3
-1
tests/kernels/moe/test_moe_permute_unpermute.py
tests/kernels/moe/test_moe_permute_unpermute.py
+1
-1
tests/kernels/moe/test_nvfp4_moe.py
tests/kernels/moe/test_nvfp4_moe.py
+1
-1
tests/kernels/moe/test_pplx_cutlass_moe.py
tests/kernels/moe/test_pplx_cutlass_moe.py
+1
-1
tests/kernels/moe/test_routing.py
tests/kernels/moe/test_routing.py
+499
-0
tests/kernels/moe/test_routing_simulator.py
tests/kernels/moe/test_routing_simulator.py
+38
-34
tests/model_executor/test_enabled_custom_ops.py
tests/model_executor/test_enabled_custom_ops.py
+1
-1
vllm/distributed/eplb/eplb_state.py
vllm/distributed/eplb/eplb_state.py
+9
-0
vllm/model_executor/layers/fused_moe/__init__.py
vllm/model_executor/layers/fused_moe/__init__.py
+9
-5
vllm/model_executor/layers/fused_moe/config.py
vllm/model_executor/layers/fused_moe/config.py
+5
-1
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+0
-375
vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
.../model_executor/layers/fused_moe/fused_moe_method_base.py
+3
-3
vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
...del_executor/layers/fused_moe/fused_moe_modular_method.py
+3
-1
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+60
-224
No files found.
tests/kernels/moe/modular_kernel_tools/common.py
View file @
327a02d8
...
@@ -21,12 +21,12 @@ from vllm.distributed import (
...
@@ -21,12 +21,12 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
)
)
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.model_executor.layers.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.config
import
(
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEConfig
,
FusedMoEParallelConfig
,
FusedMoEParallelConfig
,
FusedMoEQuantConfig
,
FusedMoEQuantConfig
,
)
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.utils.import_utils
import
has_deep_ep
,
has_deep_gemm
,
has_pplx
from
vllm.utils.import_utils
import
has_deep_ep
,
has_deep_gemm
,
has_pplx
from
.mk_objects
import
(
from
.mk_objects
import
(
...
...
tests/kernels/moe/test_batched_moe.py
View file @
327a02d8
...
@@ -15,10 +15,10 @@ from tests.kernels.moe.utils import (
...
@@ -15,10 +15,10 @@ from tests.kernels.moe.utils import (
from
tests.kernels.quant_utils
import
native_batched_masked_quant_matmul
from
tests.kernels.quant_utils
import
native_batched_masked_quant_matmul
from
tests.kernels.utils
import
torch_experts
from
tests.kernels.utils
import
torch_experts
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.fused_batched_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_batched_moe
import
(
invoke_moe_batched_triton_kernel
,
invoke_moe_batched_triton_kernel
,
)
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
from
vllm.triton_utils
import
tl
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.utils.torch_utils
import
set_random_seed
...
...
tests/kernels/moe/test_block_fp8.py
View file @
327a02d8
...
@@ -11,13 +11,15 @@ from tests.kernels.quant_utils import (
...
@@ -11,13 +11,15 @@ from tests.kernels.quant_utils import (
)
)
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
(
fused_experts
,
fused_topk
,
)
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
_valid_deep_gemm_shape
,
_valid_deep_gemm_shape
,
deep_gemm_moe_fp8
,
deep_gemm_moe_fp8
,
)
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
modular_triton_fused_moe
,
modular_triton_fused_moe
,
)
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
...
tests/kernels/moe/test_cutlass_moe.py
View file @
327a02d8
...
@@ -10,6 +10,7 @@ import torch
...
@@ -10,6 +10,7 @@ import torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.fused_moe
import
fused_experts
,
fused_topk
from
vllm.model_executor.layers.fused_moe.config
import
(
from
vllm.model_executor.layers.fused_moe.config
import
(
FUSED_MOE_UNQUANTIZED_CONFIG
,
FUSED_MOE_UNQUANTIZED_CONFIG
,
FusedMoEQuantConfig
,
FusedMoEQuantConfig
,
...
@@ -19,7 +20,6 @@ from vllm.model_executor.layers.fused_moe.cutlass_moe import (
...
@@ -19,7 +20,6 @@ from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp8
,
CutlassExpertsFp8
,
run_cutlass_moe_fp8
,
run_cutlass_moe_fp8
,
)
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_experts
,
fused_topk
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
,
MoEPrepareAndFinalizeNoEP
,
)
)
...
...
tests/kernels/moe/test_flashinfer_moe.py
View file @
327a02d8
...
@@ -12,6 +12,7 @@ from tests.kernels.quantization.nvfp4_utils import (
...
@@ -12,6 +12,7 @@ from tests.kernels.quantization.nvfp4_utils import (
from
tests.kernels.utils
import
torch_moe
from
tests.kernels.utils
import
torch_moe
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe
import
(
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe
import
(
FlashInferExperts
,
FlashInferExperts
,
is_valid_flashinfer_cutlass_fused_moe
,
is_valid_flashinfer_cutlass_fused_moe
,
...
@@ -19,7 +20,6 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
...
@@ -19,7 +20,6 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize
import
(
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize
import
(
create_flashinfer_prepare_finalize
,
create_flashinfer_prepare_finalize
,
)
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
FusedMoEModularKernel
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
FusedMoEModularKernel
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.flashinfer
import
has_flashinfer_cutlass_fused_moe
from
vllm.utils.flashinfer
import
has_flashinfer_cutlass_fused_moe
...
...
tests/kernels/moe/test_grouped_topk.py
View file @
327a02d8
...
@@ -14,7 +14,7 @@ from vllm.config import (
...
@@ -14,7 +14,7 @@ from vllm.config import (
get_cached_compilation_config
,
get_cached_compilation_config
,
set_current_vllm_config
,
set_current_vllm_config
,
)
)
from
vllm.model_executor.layers.fused_moe.
fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.
router.grouped_topk_router
import
(
GroupedTopk
,
GroupedTopk
,
fused_grouped_topk
,
fused_grouped_topk
,
)
)
...
...
tests/kernels/moe/test_moe.py
View file @
327a02d8
...
@@ -24,6 +24,9 @@ from vllm._aiter_ops import rocm_aiter_ops
...
@@ -24,6 +24,9 @@ from vllm._aiter_ops import rocm_aiter_ops
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.distributed.parallel_state
import
init_distributed_environment
from
vllm.distributed.parallel_state
import
init_distributed_environment
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.model_executor.layers.fused_moe
import
(
fused_topk
,
)
from
vllm.model_executor.layers.fused_moe.config
import
(
from
vllm.model_executor.layers.fused_moe.config
import
(
FUSED_MOE_UNQUANTIZED_CONFIG
,
FUSED_MOE_UNQUANTIZED_CONFIG
,
int4_w4a16_moe_quant_config
,
int4_w4a16_moe_quant_config
,
...
@@ -34,7 +37,6 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
...
@@ -34,7 +37,6 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe
,
fused_marlin_moe
,
)
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
modular_triton_fused_moe
,
modular_triton_fused_moe
,
)
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
...
...
tests/kernels/moe/test_moe_permute_unpermute.py
View file @
327a02d8
...
@@ -9,7 +9,7 @@ import numpy as np
...
@@ -9,7 +9,7 @@ import numpy as np
import
pytest
import
pytest
import
torch
import
torch
from
vllm.model_executor.layers.fused_moe
.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.layer
import
determine_expert_map
from
vllm.model_executor.layers.fused_moe.layer
import
determine_expert_map
from
vllm.model_executor.layers.fused_moe.moe_permute_unpermute
import
(
from
vllm.model_executor.layers.fused_moe.moe_permute_unpermute
import
(
moe_permute
,
moe_permute
,
...
...
tests/kernels/moe/test_nvfp4_moe.py
View file @
327a02d8
...
@@ -13,11 +13,11 @@ from tests.kernels.quantization.nvfp4_utils import (
...
@@ -13,11 +13,11 @@ from tests.kernels.quantization.nvfp4_utils import (
from
tests.kernels.utils
import
torch_moe
from
tests.kernels.utils
import
torch_moe
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.config
import
nvfp4_moe_quant_config
from
vllm.model_executor.layers.fused_moe.config
import
nvfp4_moe_quant_config
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
(
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
(
CutlassExpertsFp4
,
CutlassExpertsFp4
,
)
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
,
MoEPrepareAndFinalizeNoEP
,
)
)
...
...
tests/kernels/moe/test_pplx_cutlass_moe.py
View file @
327a02d8
...
@@ -8,9 +8,9 @@ import torch
...
@@ -8,9 +8,9 @@ import torch
from
tests.kernels.utils
import
torch_experts
from
tests.kernels.utils
import
torch_experts
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.config
import
fp8_w8a8_moe_quant_config
from
vllm.model_executor.layers.fused_moe.config
import
fp8_w8a8_moe_quant_config
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
CutlassBatchedExpertsFp8
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
CutlassBatchedExpertsFp8
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
FusedMoEModularKernel
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
FusedMoEModularKernel
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.math_utils
import
cdiv
...
...
tests/kernels/moe/test_routing.py
0 → 100644
View file @
327a02d8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
import
pytest
import
torch
from
vllm.distributed.eplb.eplb_state
import
EplbLayerState
from
vllm.model_executor.layers.fused_moe.config
import
RoutingMethodType
from
vllm.model_executor.layers.fused_moe.router.router_factory
import
(
create_fused_moe_router
,
)
from
vllm.model_executor.models.llama4
import
Llama4MoE
# Test parameters
# (m, k) pairs: make_test_data() builds hidden_states of shape (m, k) and
# router logits with m rows.
MK_S = [(32, 256), (64, 512)]
# Number of experts selected per row of router logits.
TOP_KS = [2, 4, 6]
# Total expert counts; also the width of the router logits.
NUM_EXPERTS = [8, 16, 64]
def setup_eplb_state(
    enable_eplb: bool,
    global_num_experts: int,
    device: str | torch.device = "cuda",
) -> EplbLayerState:
    """
    Build the EPLB layer state passed to the router under test.

    Args:
        enable_eplb: When False an empty ``EplbLayerState()`` is returned
            (all fields left at their defaults).
        global_num_experts: Number of experts; sizes every tensor below.
        device: Device the state tensors are allocated on. Defaults to
            "cuda", which is what the tests in this file use.

    Returns:
        An ``EplbLayerState`` describing a trivial 1:1 logical-to-physical
        mapping (no redundant experts) when EPLB is enabled.
    """
    if not enable_eplb:
        return EplbLayerState()

    # Load on each expert, shape (num_experts,), starting at zero.
    expert_load_view = torch.zeros(
        global_num_experts, dtype=torch.int32, device=device
    )

    # Logical -> physical expert map, shape (num_logical_experts, max_slots).
    # For testing: identity mapping with a single slot per expert.
    logical_to_physical_map = torch.arange(
        global_num_experts, dtype=torch.int64, device=device
    ).unsqueeze(-1)

    # Replicas per logical expert, shape (num_logical_experts,).
    # For testing: exactly one replica each.
    logical_replica_count = torch.ones(
        global_num_experts, dtype=torch.int64, device=device
    )

    return EplbLayerState(
        expert_load_view=expert_load_view,
        logical_to_physical_map=logical_to_physical_map,
        logical_replica_count=logical_replica_count,
    )
def
make_test_data
(
m
:
int
,
k
:
int
,
num_experts
:
int
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
hidden_states
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
)
/
10
logits
=
torch
.
randn
((
m
,
num_experts
),
device
=
"cuda"
)
return
hidden_states
,
logits
def
make_e_score_correction_bias
(
e_score_correction_bias_val
:
float
,
num_experts
:
int
,
)
->
torch
.
Tensor
:
# return torch.randn(num_experts, device="cuda") * e_score_correction_bias_val
return
torch
.
full
(
(
num_experts
,),
e_score_correction_bias_val
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
def assert_routing_results_close(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    baseline_weights: torch.Tensor,
    baseline_ids: torch.Tensor,
    rtol: float = 1e-3,
    atol: float = 1e-3,
) -> None:
    """
    Compare routing results, sorting by expert ID first to handle
    non-deterministic ordering from sorted=False in topk.

    Args:
        topk_weights: Weights produced by the router under test.
        topk_ids: Expert ids produced by the router under test.
        baseline_weights: Reference weights.
        baseline_ids: Reference expert ids; converted to the dtype of
            ``topk_ids`` before comparison.
        rtol: Relative tolerance for the weight comparison.
        atol: Absolute tolerance for the weight comparison.

    Raises:
        AssertionError: If the sorted ids differ, or the weights differ
            beyond the given tolerances.
    """
    # Convert once so ids are compared in a single dtype.
    baseline_ids = baseline_ids.to(topk_ids.dtype)

    # Order each row by expert id so both sides line up.
    actual_order = torch.argsort(topk_ids, dim=-1)
    baseline_order = torch.argsort(baseline_ids, dim=-1)

    # Gather along the same (last) dimension that was sorted, so the helper
    # is not tied to 2-D inputs.
    topk_ids_sorted = torch.gather(topk_ids, -1, actual_order)
    topk_weights_sorted = torch.gather(topk_weights, -1, actual_order)
    baseline_ids_sorted = torch.gather(baseline_ids, -1, baseline_order)
    baseline_weights_sorted = torch.gather(baseline_weights, -1, baseline_order)

    # Ids must match exactly; weights only within tolerance.
    torch.testing.assert_close(topk_ids_sorted, baseline_ids_sorted)
    torch.testing.assert_close(
        topk_weights_sorted, baseline_weights_sorted, rtol=rtol, atol=atol
    )
def baseline_fused_topk(
    router_logits: torch.Tensor, top_k: int, renormalize: bool
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Reference implementation of plain fused top-k routing.

    Steps:
    1. Softmax over the router logits (computed in float32)
    2. Pick the top-k experts per row
    3. If requested, rescale the picked weights so each row sums to 1

    Returns (weights as float32, expert ids as int32).
    """
    probs = router_logits.softmax(dim=-1, dtype=torch.float32)

    # sorted=False matches the vllm implementation (vllm_is_batch_invariant
    # defaults to False), so callers must not rely on the order within a row.
    picked = torch.topk(probs, top_k, dim=-1, sorted=False)
    weights, expert_ids = picked.values, picked.indices

    if renormalize:
        row_totals = weights.sum(dim=-1, keepdim=True)
        weights = weights / row_totals

    return weights.to(torch.float32), expert_ids.to(torch.int32)
def baseline_fused_topk_bias(
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    e_score_correction_bias: torch.Tensor,
    routed_scaling_factor: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Baseline for fused top-k with bias correction.

    Algorithm:
    1. Apply softmax to router logits
    2. Add bias to scores for expert selection
    3. Select top-k experts using biased scores
    4. Get weights from original (unbiased) scores
    5. Optionally renormalize the weights
    6. Apply routed scaling factor (after renormalization)

    Args:
        router_logits: 2-D router logits, one row per token.
        top_k: Number of experts to select per row.
        renormalize: Whether the selected weights are rescaled to sum to 1.
        e_score_correction_bias: 1-D per-expert bias, broadcast over rows.
        routed_scaling_factor: Multiplier applied to the final weights.

    Returns:
        ``(topk_weights, topk_ids)`` as float32 and int32 tensors.
    """
    # Apply softmax to get scores
    scores = torch.softmax(router_logits, dim=-1, dtype=torch.float32)

    # The bias only influences which experts are chosen, not their weights.
    scores_for_choice = scores + e_score_correction_bias.unsqueeze(0)

    # Select top-k using biased scores (sorted=False to match implementation)
    topk_ids = torch.topk(scores_for_choice, k=top_k, dim=-1, sorted=False)[1]

    # Get weights from original scores (not biased)
    topk_weights = scores.gather(1, topk_ids)

    # Renormalize if needed (BEFORE applying scaling factor)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    # Apply scaling factor (AFTER renormalization, if applicable)
    if routed_scaling_factor != 1.0:
        topk_weights *= routed_scaling_factor

    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
def
baseline_grouped_topk
(
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
num_expert_group
:
int
,
topk_group
:
int
,
scoring_func
:
str
,
renormalize
:
bool
,
e_score_correction_bias
:
torch
.
Tensor
|
None
,
routed_scaling_factor
:
float
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Baseline for grouped top-k routing (e.g., DeepSeek).
Algorithm:
1. Apply scoring function (softmax or sigmoid)
2. Optionally add bias
3. Select top-k groups based on max scores within each group
4. Mask scores to only include selected groups
5. Select top-k experts from masked scores
6. Apply scaling factor
7. Optionally renormalize
"""
num_token
=
router_logits
.
shape
[
0
]
# Apply scoring function
if
scoring_func
==
"softmax"
:
scores
=
torch
.
softmax
(
router_logits
,
dim
=-
1
,
dtype
=
torch
.
float32
)
elif
scoring_func
==
"sigmoid"
:
scores
=
torch
.
sigmoid
(
router_logits
.
float
())
else
:
raise
ValueError
(
f
"Unsupported scoring function:
{
scoring_func
}
"
)
# Handle bias correction
if
e_score_correction_bias
is
not
None
:
original_scores
=
scores
scores
=
scores
+
e_score_correction_bias
.
unsqueeze
(
0
)
# For bias case, use sum of top-2 scores in each group
group_scores
=
(
scores
.
view
(
num_token
,
num_expert_group
,
-
1
).
topk
(
2
,
dim
=-
1
)[
0
].
sum
(
dim
=-
1
)
)
else
:
# Use max score in each group
group_scores
=
scores
.
view
(
num_token
,
num_expert_group
,
-
1
).
max
(
dim
=-
1
).
values
# Select top-k groups
group_idx
=
torch
.
topk
(
group_scores
,
k
=
topk_group
,
dim
=-
1
,
sorted
=
False
)[
1
]
# Create mask for selected groups
group_mask
=
torch
.
zeros_like
(
group_scores
)
group_mask
.
scatter_
(
1
,
group_idx
,
1
)
# Expand mask to all experts
score_mask
=
(
group_mask
.
unsqueeze
(
-
1
)
.
expand
(
num_token
,
num_expert_group
,
scores
.
shape
[
-
1
]
//
num_expert_group
)
.
reshape
(
num_token
,
-
1
)
)
# Mask scores (set non-selected to -inf)
tmp_scores
=
scores
.
masked_fill
(
~
score_mask
.
bool
(),
float
(
"-inf"
))
# Select top-k experts
if
e_score_correction_bias
is
not
None
:
topk_ids
=
torch
.
topk
(
tmp_scores
,
k
=
top_k
,
dim
=-
1
,
sorted
=
False
)[
1
]
topk_weights
=
original_scores
.
gather
(
1
,
topk_ids
)
else
:
topk_weights
,
topk_ids
=
torch
.
topk
(
tmp_scores
,
k
=
top_k
,
dim
=-
1
,
sorted
=
False
)
# Renormalize if needed
if
renormalize
:
topk_weights
=
topk_weights
/
topk_weights
.
sum
(
dim
=-
1
,
keepdim
=
True
)
# Apply scaling factor
if
routed_scaling_factor
!=
1.0
:
topk_weights
*=
routed_scaling_factor
return
topk_weights
.
to
(
torch
.
float32
),
topk_ids
.
to
(
torch
.
int32
)
def baseline_custom_llama4(
    router_logits: torch.Tensor, top_k: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Reference implementation of the Llama4 custom routing function.

    Steps:
    1. Take the top-k raw logits per row (no softmax beforehand)
    2. Squash only those selected logits with a sigmoid

    Returns (weights as float32, expert ids as int32).
    """
    selected = torch.topk(router_logits, top_k, dim=-1)
    weights = selected.values.float().sigmoid()
    expert_ids = selected.indices
    return weights.to(torch.float32), expert_ids.to(torch.int32)
@pytest.mark.parametrize("m,k", MK_S)
@pytest.mark.parametrize("top_k", TOP_KS)
@pytest.mark.parametrize("global_num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("renormalize", [False, True])
@pytest.mark.parametrize("enable_eplb", [False, True])
def test_fused_topk(
    m: int,
    k: int,
    top_k: int,
    global_num_experts: int,
    renormalize: bool,
    enable_eplb: bool,
):
    """Check the default router against the plain softmax + top-k baseline."""
    if top_k > global_num_experts:
        pytest.skip(f"top_k ({top_k}) > global_num_experts ({global_num_experts})")

    router_kwargs = dict(
        top_k=top_k,
        global_num_experts=global_num_experts,
        renormalize=renormalize,
        enable_eplb=enable_eplb,
        eplb_state=setup_eplb_state(enable_eplb, global_num_experts),
    )
    router = create_fused_moe_router(**router_kwargs)

    hidden_states, router_logits = make_test_data(m, k, global_num_experts)

    # Router under test
    actual_weights, actual_ids = router.select_experts(hidden_states, router_logits)

    # Pure-PyTorch reference
    expected_weights, expected_ids = baseline_fused_topk(
        router_logits, top_k, renormalize
    )

    assert_routing_results_close(
        actual_weights, actual_ids, expected_weights, expected_ids
    )
@pytest.mark.parametrize("m,k", MK_S)
@pytest.mark.parametrize("top_k", TOP_KS)
@pytest.mark.parametrize("global_num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("renormalize", [False, True])
@pytest.mark.parametrize("enable_eplb", [False, True])
@pytest.mark.parametrize("e_score_correction_bias_val", [0.9])
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 1.1])
def test_fused_topk_bias(
    m: int,
    k: int,
    top_k: int,
    global_num_experts: int,
    renormalize: bool,
    enable_eplb: bool,
    e_score_correction_bias_val: float,
    routed_scaling_factor: float,
):
    """Check the bias-corrected top-k router against its baseline."""
    if top_k > global_num_experts:
        pytest.skip(f"top_k ({top_k}) > global_num_experts ({global_num_experts})")

    eplb_state = setup_eplb_state(enable_eplb, global_num_experts)
    bias = make_e_score_correction_bias(
        e_score_correction_bias_val, global_num_experts
    )

    router_kwargs = dict(
        e_score_correction_bias=bias,
        routed_scaling_factor=routed_scaling_factor,
        top_k=top_k,
        global_num_experts=global_num_experts,
        renormalize=renormalize,
        enable_eplb=enable_eplb,
        eplb_state=eplb_state,
    )
    router = create_fused_moe_router(**router_kwargs)

    hidden_states, router_logits = make_test_data(m, k, global_num_experts)

    # Router under test
    actual_weights, actual_ids = router.select_experts(hidden_states, router_logits)

    # Pure-PyTorch reference
    expected_weights, expected_ids = baseline_fused_topk_bias(
        router_logits, top_k, renormalize, bias, routed_scaling_factor
    )

    assert_routing_results_close(
        actual_weights, actual_ids, expected_weights, expected_ids
    )
@pytest.mark.parametrize("m,k", MK_S)
@pytest.mark.parametrize("top_k", TOP_KS)
@pytest.mark.parametrize(
    "global_num_experts,num_expert_group,topk_group",
    [
        (64, 8, 4),  # 8 groups of 8 experts, select 4 groups
        (32, 4, 2),  # 4 groups of 8 experts, select 2 groups
    ],
)
@pytest.mark.parametrize("renormalize", [False, True])
@pytest.mark.parametrize("enable_eplb", [False, True])
@pytest.mark.parametrize("e_score_correction_bias_val", [0.9])
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 1.1])
@pytest.mark.parametrize("scoring_func", ["sigmoid", "softmax"])
def test_grouped_topk(
    m: int,
    k: int,
    top_k: int,
    global_num_experts: int,
    renormalize: bool,
    enable_eplb: bool,
    num_expert_group: int,
    topk_group: int,
    scoring_func: str,
    e_score_correction_bias_val: float,
    routed_scaling_factor: float,
):
    """Check the grouped top-k router against its baseline."""
    if top_k > global_num_experts:
        pytest.skip(f"top_k ({top_k}) > global_num_experts ({global_num_experts})")

    eplb_state = setup_eplb_state(enable_eplb, global_num_experts)
    bias = make_e_score_correction_bias(
        e_score_correction_bias_val, global_num_experts
    )

    # NOTE(review): "llama4" is not among the parametrized scoring_func
    # values above, so this remapping is never taken at the moment.
    is_llama4 = scoring_func == "llama4"
    routing_method_type = RoutingMethodType.Llama4 if is_llama4 else None
    if is_llama4:
        scoring_func = "sigmoid"

    router_kwargs = dict(
        use_grouped_topk=True,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        scoring_func=scoring_func,
        routing_method_type=routing_method_type,
        e_score_correction_bias=bias,
        routed_scaling_factor=routed_scaling_factor,
        top_k=top_k,
        global_num_experts=global_num_experts,
        renormalize=renormalize,
        enable_eplb=enable_eplb,
        eplb_state=eplb_state,
    )
    router = create_fused_moe_router(**router_kwargs)

    hidden_states, router_logits = make_test_data(m, k, global_num_experts)

    # Router under test
    actual_weights, actual_ids = router.select_experts(hidden_states, router_logits)

    # Pure-PyTorch reference
    expected_weights, expected_ids = baseline_grouped_topk(
        router_logits,
        top_k,
        num_expert_group,
        topk_group,
        scoring_func,
        renormalize,
        bias,
        routed_scaling_factor,
    )

    assert_routing_results_close(
        actual_weights, actual_ids, expected_weights, expected_ids
    )
@pytest.mark.parametrize("m,k", MK_S)
@pytest.mark.parametrize("top_k", TOP_KS)
@pytest.mark.parametrize("global_num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("renormalize", [False, True])
@pytest.mark.parametrize("enable_eplb", [False, True])
@pytest.mark.parametrize(
    "custom_routing_function", [Llama4MoE.custom_routing_function]
)
def test_custom(
    m: int,
    k: int,
    top_k: int,
    global_num_experts: int,
    renormalize: bool,
    enable_eplb: bool,
    custom_routing_function: Callable,
):
    """Check a router built around a custom routing function (Llama4)."""
    if top_k > global_num_experts:
        pytest.skip(f"top_k ({top_k}) > global_num_experts ({global_num_experts})")

    router_kwargs = dict(
        top_k=top_k,
        global_num_experts=global_num_experts,
        custom_routing_function=custom_routing_function,
        renormalize=renormalize,
        enable_eplb=enable_eplb,
        eplb_state=setup_eplb_state(enable_eplb, global_num_experts),
    )
    router = create_fused_moe_router(**router_kwargs)

    hidden_states, router_logits = make_test_data(m, k, global_num_experts)

    # Router under test
    actual_weights, actual_ids = router.select_experts(hidden_states, router_logits)

    # Reference: top-k on raw logits, then sigmoid. The baseline takes no
    # renormalize argument, so the same expectation is used for both values.
    expected_weights, expected_ids = baseline_custom_llama4(router_logits, top_k)

    assert_routing_results_close(
        actual_weights, actual_ids, expected_weights, expected_ids
    )
# TODO: is other test sufficient?
# # See tests/kernels/moe/test_routing_simulator.py
# @pytest.mark.parametrize("m,k", MK_S)
# @pytest.mark.parametrize("top_k", TOP_KS)
# @pytest.mark.parametrize("global_num_experts", NUM_EXPERTS)
# @pytest.mark.parametrize("renormalize", [False, True])
# @pytest.mark.parametrize("enable_eplb", [False, True])
# @pytest.mark.parametrize("strategy", ["uniform_random", "normal_routing"])
# def test_simulated(
# m: int,
# k: int,
# top_k: int,
# global_num_experts: int,
# renormalize: bool,
# enable_eplb: bool,
# strategy: str,
# monkeypatch,
# ):
# eplb_state = setup_eplb_state(enable_eplb, global_num_experts)
# monkeypatch.setenv("VLLM_MOE_ROUTING_SIMULATION_STRATEGY", strategy)
# router = create_fused_moe_router(
# top_k=top_k,
# global_num_experts=global_num_experts,
# enable_eplb=enable_eplb,
# eplb_state=eplb_state,
# )
# hidden_states, router_logits = make_test_data(m, k, global_num_experts)
# topk_weights, topk_ids = router.select_experts(hidden_states, router_logits)
tests/test_routing_simulator.py
→
tests/
kernels/moe/
test_routing_simulator.py
View file @
327a02d8
...
@@ -19,7 +19,7 @@ from vllm.distributed import (
...
@@ -19,7 +19,7 @@ from vllm.distributed import (
init_distributed_environment
,
init_distributed_environment
,
initialize_model_parallel
,
initialize_model_parallel
,
)
)
from
vllm.model_executor.layers.fused_moe.routing_simulator
import
(
from
vllm.model_executor.layers.fused_moe.
router.
routing_simulator
_router
import
(
DistributionBasedRouting
,
DistributionBasedRouting
,
RoutingSimulator
,
RoutingSimulator
,
)
)
...
@@ -109,6 +109,8 @@ def test_routing_strategy_integration(monkeypatch, device):
...
@@ -109,6 +109,8 @@ def test_routing_strategy_integration(monkeypatch, device):
tensor_model_parallel_size
=
1
,
tensor_model_parallel_size
=
1
,
pipeline_model_parallel_size
=
1
,
pipeline_model_parallel_size
=
1
,
)
)
for
strategy
in
strategies
:
fused_moe
=
FusedMoE
(
fused_moe
=
FusedMoE
(
num_experts
=
num_experts
,
num_experts
=
num_experts
,
top_k
=
top_k
,
top_k
=
top_k
,
...
@@ -116,9 +118,9 @@ def test_routing_strategy_integration(monkeypatch, device):
...
@@ -116,9 +118,9 @@ def test_routing_strategy_integration(monkeypatch, device):
intermediate_size
=
0
,
intermediate_size
=
0
,
use_grouped_topk
=
False
,
use_grouped_topk
=
False
,
renormalize
=
True
,
renormalize
=
True
,
prefix
=
strategy
,
)
)
for
strategy
in
strategies
:
# Set environment variable
# Set environment variable
env_name
=
"VLLM_MOE_ROUTING_SIMULATION_STRATEGY"
env_name
=
"VLLM_MOE_ROUTING_SIMULATION_STRATEGY"
monkeypatch
.
setenv
(
env_name
,
strategy
)
monkeypatch
.
setenv
(
env_name
,
strategy
)
...
@@ -136,7 +138,9 @@ def test_routing_strategy_integration(monkeypatch, device):
...
@@ -136,7 +138,9 @@ def test_routing_strategy_integration(monkeypatch, device):
assert
topk_weights
.
shape
==
(
num_tokens
,
top_k
),
(
assert
topk_weights
.
shape
==
(
num_tokens
,
top_k
),
(
f
"Wrong weights shape for
{
strategy
}
"
f
"Wrong weights shape for
{
strategy
}
"
)
)
assert
topk_ids
.
shape
==
(
num_tokens
,
top_k
),
f
"Wrong ids shape for
{
strategy
}
"
assert
topk_ids
.
shape
==
(
num_tokens
,
top_k
),
(
f
"Wrong ids shape for
{
strategy
}
"
)
# Verify expert IDs are valid
# Verify expert IDs are valid
assert
topk_ids
.
min
()
>=
0
,
f
"Invalid expert ID (negative) for
{
strategy
}
"
assert
topk_ids
.
min
()
>=
0
,
f
"Invalid expert ID (negative) for
{
strategy
}
"
...
...
tests/model_executor/test_enabled_custom_ops.py
View file @
327a02d8
...
@@ -17,7 +17,7 @@ from vllm.model_executor.layers.activation import (
...
@@ -17,7 +17,7 @@ from vllm.model_executor.layers.activation import (
ReLUSquaredActivation
,
ReLUSquaredActivation
,
SiluAndMul
,
SiluAndMul
,
)
)
from
vllm.model_executor.layers.fused_moe.
fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.
router.fused_topk_router
import
(
dispatch_topk_func
,
dispatch_topk_func
,
vllm_topk_softmax
,
vllm_topk_softmax
,
)
)
...
...
vllm/distributed/eplb/eplb_state.py
View file @
327a02d8
...
@@ -1158,6 +1158,15 @@ class EplbState:
...
@@ -1158,6 +1158,15 @@ class EplbState:
return
self
.
_allreduce_list
(
load_pass_list
)
return
self
.
_allreduce_list
(
load_pass_list
)
@
dataclass
class
EplbLayerState
:
"""Runtime EPLB data stored in the MoE layer."""
expert_load_view
:
torch
.
Tensor
|
None
=
None
logical_to_physical_map
:
torch
.
Tensor
|
None
=
None
logical_replica_count
:
torch
.
Tensor
|
None
=
None
def
_node_count_with_rank_mapping
(
def
_node_count_with_rank_mapping
(
pg
:
ProcessGroup
|
StatelessProcessGroup
,
pg
:
ProcessGroup
|
StatelessProcessGroup
,
rank_mapping
:
dict
[
int
,
int
],
rank_mapping
:
dict
[
int
,
int
],
...
...
vllm/model_executor/layers/fused_moe/__init__.py
View file @
327a02d8
...
@@ -11,9 +11,6 @@ from vllm.model_executor.layers.fused_moe.config import (
...
@@ -11,9 +11,6 @@ from vllm.model_executor.layers.fused_moe.config import (
from
vllm.model_executor.layers.fused_moe.fused_moe_method_base
import
(
from
vllm.model_executor.layers.fused_moe.fused_moe_method_base
import
(
FusedMoEMethodBase
,
FusedMoEMethodBase
,
)
)
from
vllm.model_executor.layers.fused_moe.fused_moe_router
import
(
FusedMoERouter
,
)
from
vllm.model_executor.layers.fused_moe.layer
import
(
from
vllm.model_executor.layers.fused_moe.layer
import
(
FusedMoE
,
FusedMoE
,
FusedMoeWeightScaleSupported
,
FusedMoeWeightScaleSupported
,
...
@@ -23,6 +20,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
...
@@ -23,6 +20,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPermuteExpertsUnpermute
,
FusedMoEPermuteExpertsUnpermute
,
FusedMoEPrepareAndFinalize
,
FusedMoEPrepareAndFinalize
,
)
)
from
vllm.model_executor.layers.fused_moe.router.fused_moe_router
import
(
FusedMoERouter
,
)
from
vllm.model_executor.layers.fused_moe.shared_fused_moe
import
SharedFusedMoE
from
vllm.model_executor.layers.fused_moe.shared_fused_moe
import
SharedFusedMoE
from
vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method
import
(
from
vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method
import
(
UnquantizedFusedMoEMethod
,
UnquantizedFusedMoEMethod
,
...
@@ -83,13 +83,17 @@ if HAS_TRITON:
...
@@ -83,13 +83,17 @@ if HAS_TRITON:
BatchedTritonExperts
,
BatchedTritonExperts
,
)
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
GroupedTopk
,
TritonExperts
,
TritonExperts
,
TritonWNA16Experts
,
TritonWNA16Experts
,
fused_experts
,
fused_experts
,
fused_topk
,
get_config_file_name
,
get_config_file_name
,
)
)
from
vllm.model_executor.layers.fused_moe.router.fused_topk_router
import
(
fused_topk
,
)
from
vllm.model_executor.layers.fused_moe.router.grouped_topk_router
import
(
GroupedTopk
,
)
from
vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe
import
(
from
vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe
import
(
TritonOrDeepGemmExperts
,
TritonOrDeepGemmExperts
,
)
)
...
...
vllm/model_executor/layers/fused_moe/config.py
View file @
327a02d8
...
@@ -117,8 +117,12 @@ class RoutingMethodType(IntEnum):
...
@@ -117,8 +117,12 @@ class RoutingMethodType(IntEnum):
RenormalizeNaive
=
(
4
,)
RenormalizeNaive
=
(
4
,)
# TopK: TopK (no softmax)
# TopK: TopK (no softmax)
TopK
=
(
5
,)
TopK
=
(
5
,)
# Custom
Custom
=
(
6
,)
# Simulated
Simulated
=
(
7
,)
# Unspecified
# Unspecified
Unspecified
=
6
.0
Unspecified
=
8
.0
@
dataclass
@
dataclass
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
327a02d8
...
@@ -13,9 +13,7 @@ import torch
...
@@ -13,9 +13,7 @@ import torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.batch_invariant
import
(
from
vllm.model_executor.layers.batch_invariant
import
(
vllm_is_batch_invariant
,
vllm_is_batch_invariant
,
)
)
...
@@ -34,9 +32,6 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
...
@@ -34,9 +32,6 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
,
MoEPrepareAndFinalizeNoEP
,
)
)
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
# noqa: E501
rocm_aiter_grouped_topk
,
)
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
TopKWeightAndReduceNoOP
,
TopKWeightAndReduceNoOP
,
)
)
...
@@ -49,7 +44,6 @@ from vllm.model_executor.layers.fused_moe.utils import (
...
@@ -49,7 +44,6 @@ from vllm.model_executor.layers.fused_moe.utils import (
from
vllm.model_executor.layers.quantization.utils.mxfp4_utils
import
dequant_mxfp4
from
vllm.model_executor.layers.quantization.utils.mxfp4_utils
import
dequant_mxfp4
from
vllm.model_executor.layers.quantization.utils.mxfp6_utils
import
dequant_mxfp6
from
vllm.model_executor.layers.quantization.utils.mxfp6_utils
import
dequant_mxfp6
from
vllm.model_executor.layers.quantization.utils.ocp_mx_utils
import
OCP_MX_Scheme
from
vllm.model_executor.layers.quantization.utils.ocp_mx_utils
import
OCP_MX_Scheme
from
vllm.model_executor.utils
import
maybe_disable_graph_partition
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.deep_gemm
import
is_deep_gemm_e8m0_used
from
vllm.utils.deep_gemm
import
is_deep_gemm_e8m0_used
...
@@ -1318,375 +1312,6 @@ def try_get_optimal_moe_config(
...
@@ -1318,375 +1312,6 @@ def try_get_optimal_moe_config(
return
config
return
config
def vllm_topk_softmax(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool,
) -> tuple[torch.Tensor, ...]:
    """Invoke the ``ops.topk_softmax`` custom op and hand back its outputs.

    Args:
        topk_weights: Buffer passed to the op as its first argument.
        topk_indices: Buffer passed to the op as its second argument.
        token_expert_indices: Buffer passed to the op as its third argument.
        gating_output: Router logits forwarded to the op.
        renormalize: Flag forwarded to the op.

    Returns:
        The very same ``(topk_weights, topk_indices)`` tensor objects that
        were passed in (presumably filled in place by the op -- confirm
        against the kernel implementation).
    """
    outputs = (topk_weights, topk_indices)
    ops.topk_softmax(
        *outputs,
        token_expert_indices,
        gating_output,
        renormalize,
    )
    return outputs
def dispatch_topk_func(
    use_rocm_aiter: bool = False,
) -> Callable[..., tuple[torch.Tensor, ...]]:
    """Pick the top-k softmax implementation to call.

    Args:
        use_rocm_aiter: When true, select ``rocm_aiter_ops.topk_softmax``.

    Returns:
        ``rocm_aiter_ops.topk_softmax`` if ``use_rocm_aiter`` is truthy,
        otherwise ``vllm_topk_softmax``.
    """
    return rocm_aiter_ops.topk_softmax if use_rocm_aiter else vllm_topk_softmax
def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    indices_type: torch.dtype | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Select the top-k experts per token with the dispatched softmax kernel.

    Args:
        hidden_states: 2-D tensor; only its first dimension (token count)
            and its device are used here.
        gating_output: Router logits; must have the same first dimension
            as ``hidden_states``.
        topk: Number of experts to select per token.
        renormalize: Forwarded unchanged to the top-k kernel.
        indices_type: dtype of the expert-id buffer; ``torch.int32`` when
            ``None``.

    Returns:
        ``(topk_weights, topk_ids, token_expert_indices)`` where the first
        two are whatever the dispatched kernel returns and the third is the
        int32 buffer allocated here and passed to the kernel.

    Raises:
        AssertionError: If the token counts of the two inputs differ.
    """
    assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"

    # Two-way unpack mirrors the original: hidden_states must be 2-D.
    num_tokens, _ = hidden_states.size()
    out_shape = (num_tokens, topk)
    device = hidden_states.device
    ids_dtype = indices_type if indices_type is not None else torch.int32

    weights_buf = torch.empty(out_shape, dtype=torch.float32, device=device)
    ids_buf = torch.empty(out_shape, dtype=ids_dtype, device=device)
    token_expert_indices = torch.empty(out_shape, dtype=torch.int32, device=device)

    topk_func = dispatch_topk_func(
        use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()
    )
    topk_weights, topk_ids = topk_func(
        weights_buf, ids_buf, token_expert_indices, gating_output, renormalize
    )

    return topk_weights, topk_ids, token_expert_indices
def fused_topk_bias(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    e_score_correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    """Top-k routing where a correction bias influences selection only.

    Softmax scores plus ``e_score_correction_bias`` decide which experts
    are chosen; the returned weights are gathered from the un-biased
    softmax scores.

    Args:
        hidden_states: Not read by this function.
        gating_output: Router logits; the last dimension is the expert
            count.
        e_score_correction_bias: Bias added (after ``unsqueeze(0)``) to the
            scores used for selection.
        topk: Number of experts to select per token.
        renormalize: If true, divide the selected weights by their sum
            along the last dimension.

    Returns:
        ``(topk_weights, topk_indices)`` cast to float32 and int32.
    """
    num_experts = gating_output.shape[-1]
    scores = gating_output.softmax(dim=-1)
    biased_scores = scores.view(-1, num_experts) + e_score_correction_bias.unsqueeze(0)

    # For batch invariance, use sorted=True to ensure deterministic expert selection
    use_sorted = vllm_is_batch_invariant()
    _, topk_indices = torch.topk(biased_scores, k=topk, dim=-1, sorted=use_sorted)

    # Weights come from the un-biased scores.
    topk_weights = scores.gather(1, topk_indices)
    if renormalize:
        denom = topk_weights.sum(dim=-1, keepdim=True)
        topk_weights = topk_weights / denom

    return topk_weights.to(torch.float32), topk_indices.to(torch.int32)
# This is used by the Deepseek-V2 and Deepseek-V3 model
@torch.compile(
    dynamic=True,
    backend=current_platform.simple_compile_backend,
    options=maybe_disable_graph_partition(current_platform.simple_compile_backend),
)
def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Grouped top-k expert selection.

    Experts are partitioned into ``num_expert_group`` groups; the best
    ``topk_group`` groups are kept per token and the final ``topk`` experts
    are chosen only among experts of those groups.

    Args:
        hidden_states: Only its token count (dim 0) is checked here; it is
            also forwarded to ``fused_grouped_topk`` on the fused path.
        gating_output: Router logits.
        topk: Number of experts to select per token.
        renormalize: If true, divide the weights by their per-token sum.
        num_expert_group: Number of expert groups.
        topk_group: Number of groups kept per token.
        scoring_func: ``"softmax"`` or ``"sigmoid"``.
        routed_scaling_factor: Multiplier applied to the final weights when
            it differs from 1.0.
        e_score_correction_bias: Optional bias added to the scores for
            selection; the returned weights use the un-biased scores.

    Returns:
        ``(topk_weights, topk_ids)`` cast to float32 / int32 (or the result
        of ``fused_grouped_topk`` when the fused path is taken).

    Raises:
        ValueError: If ``scoring_func`` is not supported.
        AssertionError: If token counts of the inputs differ.
    """
    has_bias = e_score_correction_bias is not None

    use_fused_kernel = (
        envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK
        and current_platform.is_cuda()
        and num_expert_group <= 32
        and topk <= 32
        and has_bias
    )
    if use_fused_kernel:
        return fused_grouped_topk(
            hidden_states=hidden_states,
            gating_output=gating_output,
            topk=topk,
            renormalize=renormalize,
            e_score_correction_bias=e_score_correction_bias,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
        )

    assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"

    if scoring_func == "softmax":
        scores = torch.softmax(gating_output, dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    num_token = scores.size(0)

    # Keep the un-biased scores around. Biased scores drive expert
    # selection, un-biased scores become the routing weights.
    unbiased_scores = scores
    if has_bias:
        scores = scores + e_score_correction_bias.unsqueeze(0)

    per_group = scores.view(num_token, num_expert_group, -1)
    if has_bias:
        # Group score = sum of the two best (biased) experts in the group.
        group_scores = per_group.topk(2, dim=-1)[0].sum(dim=-1)
    else:
        # Group score = best expert in the group.
        group_scores = per_group.max(dim=-1).values  # [n, n_group]

    # For batch invariance, use sorted=True to ensure deterministic expert selection
    use_sorted = vllm_is_batch_invariant()

    _, group_idx = torch.topk(
        group_scores, k=topk_group, dim=-1, sorted=use_sorted
    )  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]

    experts_per_group = scores.size(-1) // num_expert_group
    expanded_mask = group_mask.unsqueeze(-1).expand(
        num_token, num_expert_group, experts_per_group
    )
    score_mask = expanded_mask.reshape(num_token, -1)  # [n, e]

    # Experts outside the selected groups can never win the final top-k.
    tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))  # [n, e]

    if has_bias:
        _, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)
        # Use original unbiased scores for the routing weights
        topk_weights = unbiased_scores.gather(1, topk_ids)
    else:
        topk_weights, topk_ids = torch.topk(
            tmp_scores, k=topk, dim=-1, sorted=use_sorted
        )

    if renormalize:
        weight_sum = topk_weights.sum(dim=-1, keepdim=True)
        topk_weights = topk_weights / weight_sum

    if routed_scaling_factor != 1.0:
        topk_weights = topk_weights * routed_scaling_factor

    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
# --8<-- [start:grouped_topk]
@CustomOp.register("grouped_topk")
class GroupedTopk(CustomOp):
    """Grouped top-k routing op used by the Deepseek-V2 and Deepseek-V3 model."""

    # --8<-- [end:grouped_topk]

    def __init__(
        self,
        topk: int,
        renormalize: bool,
        num_expert_group: int = 0,
        topk_group: int = 0,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        num_fused_shared_experts: int = 0,
    ) -> None:
        """Store the routing configuration.

        Args:
            topk: Number of experts to select per token.
            renormalize: Forwarded to the routing implementation.
            num_expert_group: Number of expert groups.
            topk_group: Number of groups kept per token.
            scoring_func: Name of the scoring function.
            routed_scaling_factor: Scaling factor for routing weights.
            num_fused_shared_experts: Only consulted on the ROCm AITER
                path in ``forward_hip``.
        """
        super().__init__()
        # Native implementation that forward_native delegates to.
        self.native_impl = grouped_topk

        self.topk = topk
        self.renormalize = renormalize
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.scoring_func = scoring_func
        self.routed_scaling_factor = routed_scaling_factor
        self.num_fused_shared_experts = num_fused_shared_experts

    def forward_native(
        self,
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        e_score_correction_bias: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Call ``self.native_impl`` with the stored configuration."""
        routing_config = (
            self.topk,
            self.renormalize,
            self.num_expert_group,
            self.topk_group,
            self.scoring_func,
            self.routed_scaling_factor,
        )
        return self.native_impl(
            hidden_states,
            gating_output,
            *routing_config,
            e_score_correction_bias,
        )

    def forward_cuda(
        self,
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        e_score_correction_bias: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """CUDA entry point; delegates to ``forward_native``."""
        return self.forward_native(
            hidden_states, gating_output, e_score_correction_bias
        )

    def forward_hip(
        self,
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        e_score_correction_bias: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """ROCm entry point.

        Uses ``rocm_aiter_grouped_topk`` when the AITER fused-MoE path is
        enabled, otherwise falls back to ``forward_native``.
        """
        if not rocm_aiter_ops.is_fused_moe_enabled():
            return self.forward_native(
                hidden_states, gating_output, e_score_correction_bias
            )

        # Fused shared experts are only allowed when AITER enables them.
        if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
            assert self.num_fused_shared_experts == 0

        return rocm_aiter_grouped_topk(
            hidden_states,
            gating_output,
            self.topk,
            self.renormalize,
            self.num_expert_group,
            self.topk_group,
            self.scoring_func,
            self.routed_scaling_factor,
            e_score_correction_bias,
            self.num_fused_shared_experts,
        )
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def eplb_map_to_physical_and_record(
    topk_ids: torch.Tensor,
    expert_load_view: torch.Tensor,
    logical_to_physical_map: torch.Tensor,
    logical_replica_count: torch.Tensor,
) -> torch.Tensor:
    """
    Translate logical expert ids into physical expert ids and update the
    expert load counters.

    A replica is picked for each logical expert from the flattened
    position of the entry modulo the replica count of that expert.

    Only used for EPLB.

    Args:
        topk_ids: The logical expert ids.
        expert_load_view: The expert load view; updated in place.
        logical_to_physical_map: The logical to physical map.
        logical_replica_count: The logical replica count.

    Returns:
        The physical expert ids.
    """
    # 1. Convert the logical expert ids to physical expert ids

    # `indices_type` may be neither `torch.long` nor `torch.int`
    # (e.g. `torch.uint32` as required by dispatch/combine kernels),
    # so index with a long copy.
    logical_ids = topk_ids.long()

    # Use (token position) modulo (replica count)
    # to deterministically choose a replica
    replicas_per_id = logical_replica_count[logical_ids]
    # Flatten-position based index, reshaped back to `topk_ids` shape
    flat_positions = torch.arange(
        topk_ids.numel(), device=topk_ids.device, dtype=torch.long
    )
    positions = flat_positions.reshape_as(topk_ids)

    # Compute pseudo-random indices by modulo
    replica_choice = (positions % replicas_per_id).unsqueeze(-1)
    candidates = logical_to_physical_map[logical_ids]
    physical_ids = candidates.gather(-1, replica_choice).squeeze(-1)

    # 2. Record expert load metrics.

    # TODO(bowen): When using `FusedMoEModularKernel`, this
    # can be done in a more unified way, since
    # `FusedMoEPrepareAndFinalize` will return the expert
    # token count, in some cases directly from the kernel.
    # However, now there are many code paths not using
    # the modular kernel, e.g. calling `fused_experts`,
    # so we decide to keep the logic here.
    #
    # If later refactor moved all the MoE kernel calls
    # to the modular kernel, we can move this logic there
    # to achieve better efficiency.

    # `expert_load_view`: (num_physical_experts,)
    # `torch.bincount` is not compilable, so use `scatter_add_` instead.
    flat_physical = physical_ids.flatten()
    increments = torch.ones_like(flat_physical).to(expert_load_view)
    expert_load_view.scatter_add_(
        dim=0,
        index=flat_physical.long(),
        src=increments,
    )
    return physical_ids
def fused_grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    e_score_correction_bias: torch.Tensor,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Grouped top-k routing through the fused ``ops.grouped_topk`` kernel.

    Args:
        hidden_states: Only its token count (dim 0) is checked.
        gating_output: Router logits.
        topk: Number of experts to select per token.
        renormalize: Forwarded to the kernel.
        e_score_correction_bias: Correction bias forwarded to the kernel.
        num_expert_group: Number of expert groups.
        topk_group: Number of groups kept per token.
        scoring_func: ``"sigmoid"`` (applied inside the kernel) or
            ``"softmax"`` (applied here before the kernel call).
        routed_scaling_factor: Forwarded to the kernel.

    Returns:
        ``(topk_values, topk_indices)`` as produced by the kernel.

    Raises:
        ValueError: If ``scoring_func`` is not supported.
        AssertionError: If token counts of the inputs differ.
    """
    assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"

    if scoring_func == "sigmoid":
        # Fully fused kernel path for sigmoid: pass raw logits,
        # scoring_func=1 for sigmoid
        kernel_input, scoring_code = gating_output, 1
    elif scoring_func == "softmax":
        # Apply softmax in Python, then use fused kernel
        # TODO: Add support for softmax in kernel
        # scoring_func=0 (no activation, scores already computed)
        kernel_input, scoring_code = torch.softmax(gating_output, dim=-1), 0
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    topk_values, topk_indices = ops.grouped_topk(
        kernel_input,
        num_expert_group,
        topk_group,
        topk,
        renormalize,
        routed_scaling_factor,
        e_score_correction_bias,
        scoring_code,
    )

    # Fused kernel outputs float32 values and int32 indices directly
    return topk_values, topk_indices
def
inplace_fused_experts
(
def
inplace_fused_experts
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
View file @
327a02d8
...
@@ -10,13 +10,13 @@ from vllm.model_executor.layers.fused_moe.config import (
...
@@ -10,13 +10,13 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig
,
FusedMoEConfig
,
FusedMoEQuantConfig
,
FusedMoEQuantConfig
,
)
)
from
vllm.model_executor.layers.fused_moe.fused_moe_router
import
(
FusedMoERouter
,
)
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
FusedMoEPermuteExpertsUnpermute
,
FusedMoEPermuteExpertsUnpermute
,
FusedMoEPrepareAndFinalize
,
FusedMoEPrepareAndFinalize
,
)
)
from
vllm.model_executor.layers.fused_moe.router.fused_moe_router
import
(
FusedMoERouter
,
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizeMethodBase
,
QuantizeMethodBase
,
)
)
...
...
vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
View file @
327a02d8
...
@@ -12,11 +12,13 @@ from vllm.model_executor.layers.fused_moe.config import (
...
@@ -12,11 +12,13 @@ from vllm.model_executor.layers.fused_moe.config import (
from
vllm.model_executor.layers.fused_moe.fused_moe_method_base
import
(
from
vllm.model_executor.layers.fused_moe.fused_moe_method_base
import
(
FusedMoEMethodBase
,
FusedMoEMethodBase
,
)
)
from
vllm.model_executor.layers.fused_moe.fused_moe_router
import
FusedMoERouter
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
FusedMoEModularKernel
,
FusedMoEModularKernel
,
FusedMoEPrepareAndFinalize
,
FusedMoEPrepareAndFinalize
,
)
)
from
vllm.model_executor.layers.fused_moe.router.fused_moe_router
import
(
FusedMoERouter
,
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
327a02d8
...
@@ -21,7 +21,7 @@ from vllm.distributed import (
...
@@ -21,7 +21,7 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
,
tensor_model_parallel_all_reduce
,
)
)
from
vllm.distributed.eplb.eplb_state
import
EplbState
from
vllm.distributed.eplb.eplb_state
import
EplbLayerState
,
EplbState
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
CustomOp
...
@@ -31,14 +31,24 @@ from vllm.model_executor.layers.fused_moe.config import (
...
@@ -31,14 +31,24 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig
,
FusedMoEQuantConfig
,
RoutingMethodType
,
RoutingMethodType
,
)
)
from
vllm.model_executor.layers.fused_moe.fused_moe_router
import
FusedMoERouter
from
vllm.model_executor.layers.fused_moe.fused_moe_method_base
import
(
FusedMoEMethodBase
,
)
from
vllm.model_executor.layers.fused_moe.fused_moe_modular_method
import
(
FusedMoEModularMethod
,
)
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
init_aiter_topK_meta_data
,
init_aiter_topK_meta_data
,
)
)
from
vllm.model_executor.layers.fused_moe.routed_experts_capturer
import
(
from
vllm.model_executor.layers.fused_moe.routed_experts_capturer
import
(
RoutedExpertsCapturer
,
RoutedExpertsCapturer
,
)
)
from
vllm.model_executor.layers.fused_moe.routing_simulator
import
RoutingSimulator
from
vllm.model_executor.layers.fused_moe.router.router_factory
import
(
create_fused_moe_router
,
)
from
vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method
import
(
UnquantizedFusedMoEMethod
,
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizationConfig
,
)
)
...
@@ -52,31 +62,6 @@ from vllm.utils.torch_utils import (
...
@@ -52,31 +62,6 @@ from vllm.utils.torch_utils import (
)
)
from
vllm.v1.worker.ubatching
import
dbo_current_ubatch_id
from
vllm.v1.worker.ubatching
import
dbo_current_ubatch_id
# Choose the EPLB id-mapping routine by platform: the real implementation
# is imported on CUDA-alike platforms, a pass-through is used elsewhere.
if not current_platform.is_cuda_alike():

    def _eplb_map_to_physical_and_record(
        topk_ids: torch.Tensor,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ) -> torch.Tensor:
        """Pass-through fallback: return ``topk_ids`` untouched.

        The other three arguments are accepted only to keep the same
        signature as the real implementation and are not read.
        """
        # CPU fallback: no EPLB so just return as is
        return topk_ids

    eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record
else:
    from .fused_moe import eplb_map_to_physical_and_record
from
vllm.model_executor.layers.fused_moe.fused_moe
import
GroupedTopk
from
vllm.model_executor.layers.fused_moe.fused_moe_method_base
import
(
FusedMoEMethodBase
,
)
from
vllm.model_executor.layers.fused_moe.fused_moe_modular_method
import
(
FusedMoEModularMethod
,
)
from
vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method
import
(
UnquantizedFusedMoEMethod
,
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -288,23 +273,6 @@ def maybe_roundup_hidden_size(
...
@@ -288,23 +273,6 @@ def maybe_roundup_hidden_size(
return
hidden_size
return
hidden_size
class FusedMoERouterImpl(FusedMoERouter):
    """Router that forwards every request to the ``FusedMoE`` layer it wraps."""

    def __init__(self, layer: "FusedMoE"):
        """Remember the layer whose routing logic is used."""
        super().__init__()
        self.layer = layer

    @property
    def routing_method_type(self) -> RoutingMethodType:
        """Routing method type reported by the wrapped layer."""
        owner = self.layer
        return owner.routing_method_type

    def select_experts(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return whatever the wrapped layer's ``_select_experts`` returns."""
        owner = self.layer
        return owner._select_experts(hidden_states, router_logits)
# --8<-- [start:fused_moe]
# --8<-- [start:fused_moe]
@
CustomOp
.
register
(
"fused_moe"
)
@
CustomOp
.
register
(
"fused_moe"
)
class
FusedMoE
(
CustomOp
):
class
FusedMoE
(
CustomOp
):
...
@@ -440,9 +408,7 @@ class FusedMoE(CustomOp):
...
@@ -440,9 +408,7 @@ class FusedMoE(CustomOp):
self
.
layer_name
=
prefix
self
.
layer_name
=
prefix
self
.
enable_eplb
=
enable_eplb
self
.
enable_eplb
=
enable_eplb
self
.
expert_load_view
:
torch
.
Tensor
|
None
=
None
self
.
eplb_state
=
EplbLayerState
()
self
.
logical_to_physical_map
:
torch
.
Tensor
|
None
=
None
self
.
logical_replica_count
:
torch
.
Tensor
|
None
=
None
self
.
expert_placement_strategy
:
ExpertPlacementStrategy
=
(
self
.
expert_placement_strategy
:
ExpertPlacementStrategy
=
(
vllm_config
.
parallel_config
.
expert_placement_strategy
vllm_config
.
parallel_config
.
expert_placement_strategy
)
)
...
@@ -538,6 +504,8 @@ class FusedMoE(CustomOp):
...
@@ -538,6 +504,8 @@ class FusedMoE(CustomOp):
self
.
intermediate_size_per_partition
=
intermediate_size
//
self
.
tp_size
self
.
intermediate_size_per_partition
=
intermediate_size
//
self
.
tp_size
self
.
reduce_results
=
reduce_results
self
.
reduce_results
=
reduce_results
self
.
renormalize
=
renormalize
self
.
renormalize
=
renormalize
# TODO(bnell): these attributes are only used by cpu/xpu/mxfp4
self
.
use_grouped_topk
=
use_grouped_topk
self
.
use_grouped_topk
=
use_grouped_topk
if
self
.
use_grouped_topk
:
if
self
.
use_grouped_topk
:
assert
num_expert_group
is
not
None
and
topk_group
is
not
None
assert
num_expert_group
is
not
None
and
topk_group
is
not
None
...
@@ -547,46 +515,11 @@ class FusedMoE(CustomOp):
...
@@ -547,46 +515,11 @@ class FusedMoE(CustomOp):
self
.
scoring_func
=
scoring_func
self
.
scoring_func
=
scoring_func
self
.
routed_scaling_factor
=
routed_scaling_factor
self
.
routed_scaling_factor
=
routed_scaling_factor
self
.
e_score_correction_bias
=
e_score_correction_bias
self
.
e_score_correction_bias
=
e_score_correction_bias
# TODO(bnell): end attributes
self
.
apply_router_weight_on_input
=
apply_router_weight_on_input
self
.
apply_router_weight_on_input
=
apply_router_weight_on_input
self
.
activation
=
activation
self
.
activation
=
activation
self
.
_grouped_topk_impl
:
GroupedTopk
|
None
=
None
if
self
.
use_grouped_topk
:
assert
self
.
num_expert_group
is
not
None
assert
self
.
topk_group
is
not
None
self
.
_grouped_topk_impl
=
GroupedTopk
(
topk
=
self
.
top_k
,
renormalize
=
self
.
renormalize
,
num_expert_group
=
self
.
num_expert_group
,
topk_group
=
self
.
topk_group
,
scoring_func
=
self
.
scoring_func
,
routed_scaling_factor
=
self
.
routed_scaling_factor
,
num_fused_shared_experts
=
self
.
num_fused_shared_experts
,
)
if
self
.
scoring_func
!=
"softmax"
and
not
self
.
use_grouped_topk
:
raise
ValueError
(
"Only softmax scoring function is supported for non-grouped topk."
)
# ToDo: Better logic to determine the routing method type
if
routing_method_type
is
not
None
:
self
.
routing_method_type
:
RoutingMethodType
=
routing_method_type
else
:
if
scoring_func
==
"sigmoid"
:
if
self
.
use_grouped_topk
:
self
.
routing_method_type
=
RoutingMethodType
.
DeepSeekV3
elif
self
.
top_k
==
1
:
self
.
routing_method_type
=
RoutingMethodType
.
Llama4
elif
self
.
scoring_func
==
"softmax"
:
self
.
routing_method_type
=
(
RoutingMethodType
.
Renormalize
if
not
self
.
renormalize
else
RoutingMethodType
.
RenormalizeNaive
)
else
:
self
.
routing_method_type
=
RoutingMethodType
.
TopK
self
.
moe_config
:
FusedMoEConfig
=
FusedMoEConfig
(
self
.
moe_config
:
FusedMoEConfig
=
FusedMoEConfig
(
num_experts
=
self
.
global_num_experts
,
num_experts
=
self
.
global_num_experts
,
experts_per_token
=
top_k
,
experts_per_token
=
top_k
,
...
@@ -637,8 +570,7 @@ class FusedMoE(CustomOp):
...
@@ -637,8 +570,7 @@ class FusedMoE(CustomOp):
# If you plan to add support for more quantization methods,
# If you plan to add support for more quantization methods,
# please refer to the implementation in `Fp8MoEMethod`.
# please refer to the implementation in `Fp8MoEMethod`.
raise
NotImplementedError
(
raise
NotImplementedError
(
f
"EPLB is not supported
{
self
.
quant_method
.
__class__
.
__name__
}
. "
f
"EPLB is not supported
{
self
.
quant_method
.
__class__
.
__name__
}
."
"EPLB is only supported for FP8 quantization for now."
)
)
moe_quant_params
=
{
moe_quant_params
=
{
...
@@ -663,7 +595,38 @@ class FusedMoE(CustomOp):
...
@@ -663,7 +595,38 @@ class FusedMoE(CustomOp):
self
.
batched_hidden_states
:
torch
.
Tensor
|
None
=
None
self
.
batched_hidden_states
:
torch
.
Tensor
|
None
=
None
self
.
batched_router_logits
:
torch
.
Tensor
|
None
=
None
self
.
batched_router_logits
:
torch
.
Tensor
|
None
=
None
self
.
router
=
FusedMoERouterImpl
(
self
)
# TODO(bnell): in next PR move capture back to layer
capture
:
Callable
[[
torch
.
Tensor
],
None
]
|
None
=
None
if
(
self
.
vllm_config
.
model_config
is
not
None
and
self
.
vllm_config
.
model_config
.
enable_return_routed_experts
):
# In dummy runs, the capturer is not initialized.
capturer
=
RoutedExpertsCapturer
.
get_instance
()
if
capturer
is
not
None
:
capture
=
lambda
topk_ids
:
capturer
.
capture
(
self
.
layer_id
,
topk_ids
)
self
.
router
=
create_fused_moe_router
(
top_k
=
top_k
,
global_num_experts
=
self
.
global_num_experts
,
eplb_state
=
self
.
eplb_state
,
renormalize
=
renormalize
,
use_grouped_topk
=
use_grouped_topk
,
num_expert_group
=
num_expert_group
,
topk_group
=
topk_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
routed_scaling_factor
=
routed_scaling_factor
,
e_score_correction_bias
=
e_score_correction_bias
,
num_fused_shared_experts
=
self
.
num_fused_shared_experts
,
enable_eplb
=
enable_eplb
,
# TODO(bnell): once we can construct the MK at init time, we
# can make this a value.
indices_type_getter
=
lambda
:
self
.
quant_method
.
topk_indices_dtype
,
routing_method_type
=
routing_method_type
,
capture
=
capture
,
)
self
.
routing_method_type
:
RoutingMethodType
=
self
.
router
.
routing_method_type
# Note: maybe_init_modular_kernel should only be called by
# Note: maybe_init_modular_kernel should only be called by
# prepare_communication_buffer_for_model.
# prepare_communication_buffer_for_model.
...
@@ -1492,9 +1455,9 @@ class FusedMoE(CustomOp):
...
@@ -1492,9 +1455,9 @@ class FusedMoE(CustomOp):
This is used later in forward pass, where we get the expert mapping
This is used later in forward pass, where we get the expert mapping
and record the load metrics in `expert_load_view`.
and record the load metrics in `expert_load_view`.
"""
"""
self
.
expert_load_view
=
expert_load_view
[
moe_layer_idx
]
self
.
eplb_state
.
expert_load_view
=
expert_load_view
[
moe_layer_idx
]
self
.
logical_to_physical_map
=
logical_to_physical_map
[
moe_layer_idx
]
self
.
eplb_state
.
logical_to_physical_map
=
logical_to_physical_map
[
moe_layer_idx
]
self
.
logical_replica_count
=
logical_replica_count
[
moe_layer_idx
]
self
.
eplb_state
.
logical_replica_count
=
logical_replica_count
[
moe_layer_idx
]
def
ensure_moe_quant_config_init
(
self
):
def
ensure_moe_quant_config_init
(
self
):
if
self
.
quant_method
.
moe_quant_config
is
None
:
if
self
.
quant_method
.
moe_quant_config
is
None
:
...
@@ -1535,130 +1498,6 @@ class FusedMoE(CustomOp):
...
@@ -1535,130 +1498,6 @@ class FusedMoE(CustomOp):
device
=
torch
.
cuda
.
current_device
(),
device
=
torch
.
cuda
.
current_device
(),
)
)
def
_select_experts
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Route the input hidden states to the top-k experts based on the
router logits.
Returns:
(topk_weights, topk_ids)
(tuple[torch.Tensor, torch.Tensor]):
The weights and expert ids.
**Compatibility**: When EPLB is not enabled, the returned ids are
equivalent to global logical ids, so should be compatible with
plain MoE implementations without redundant experts.
"""
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
fused_topk_bias
,
)
if
self
.
enable_eplb
:
if
self
.
quant_method
.
supports_eplb
:
if
self
.
expert_load_view
is
None
:
raise
ValueError
(
"enable_eplb=True requiere expert_load_view != None"
)
if
self
.
logical_to_physical_map
is
None
:
raise
ValueError
(
"enable_eplb=True requiere logical_to_physical_map != None"
)
if
self
.
logical_replica_count
is
None
:
raise
ValueError
(
"enable_eplb=True requiere logical_replica_count != None"
)
else
:
raise
NotImplementedError
(
f
"EPLB is not supported for
{
self
.
quant_method
.
method_name
}
."
)
def
valid_grouping
()
->
bool
:
# Check if num_experts is greater than num_expert_group
# and is divisible by num_expert_group
num_experts
=
router_logits
.
shape
[
-
1
]
if
num_experts
<=
self
.
num_expert_group
:
return
False
return
num_experts
%
self
.
num_expert_group
==
0
indices_type
=
self
.
quant_method
.
topk_indices_dtype
# Check if we should use a routing simulation strategy
routing_strategy
=
envs
.
VLLM_MOE_ROUTING_SIMULATION_STRATEGY
if
routing_strategy
!=
""
:
topk_weights
,
topk_ids
=
RoutingSimulator
.
simulate_routing
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
strategy_name
=
routing_strategy
,
top_k
=
self
.
top_k
,
indices_type
=
indices_type
,
)
# DeepSeekv2 uses grouped_top_k
elif
self
.
use_grouped_topk
and
valid_grouping
():
assert
self
.
_grouped_topk_impl
is
not
None
topk_weights
,
topk_ids
=
self
.
_grouped_topk_impl
(
hidden_states
=
hidden_states
,
gating_output
=
router_logits
,
e_score_correction_bias
=
self
.
e_score_correction_bias
,
)
elif
self
.
e_score_correction_bias
is
not
None
:
topk_weights
,
topk_ids
=
fused_topk_bias
(
hidden_states
=
hidden_states
,
gating_output
=
router_logits
,
e_score_correction_bias
=
self
.
e_score_correction_bias
.
data
,
topk
=
self
.
top_k
,
renormalize
=
self
.
renormalize
,
)
if
self
.
routed_scaling_factor
!=
1.0
:
topk_weights
*=
self
.
routed_scaling_factor
elif
self
.
custom_routing_function
is
None
:
topk_weights
,
topk_ids
,
token_expert_indices
=
fused_topk
(
hidden_states
=
hidden_states
,
gating_output
=
router_logits
,
topk
=
self
.
top_k
,
renormalize
=
self
.
renormalize
,
indices_type
=
indices_type
,
)
else
:
topk_weights
,
topk_ids
=
self
.
custom_routing_function
(
hidden_states
=
hidden_states
,
gating_output
=
router_logits
,
topk
=
self
.
top_k
,
renormalize
=
self
.
renormalize
,
)
if
self
.
enable_eplb
:
topk_ids
=
eplb_map_to_physical_and_record
(
topk_ids
=
topk_ids
,
expert_load_view
=
self
.
expert_load_view
,
logical_to_physical_map
=
self
.
logical_to_physical_map
,
logical_replica_count
=
self
.
logical_replica_count
,
)
if
(
indices_type
is
not
None
)
and
topk_ids
.
dtype
!=
indices_type
:
topk_ids
=
topk_ids
.
to
(
dtype
=
indices_type
)
assert
topk_ids
.
dtype
==
indices_type
or
indices_type
is
None
if
(
self
.
vllm_config
.
model_config
is
not
None
and
self
.
vllm_config
.
model_config
.
enable_return_routed_experts
):
# In dummy runs, the capturer is not initialized.
capturer
=
RoutedExpertsCapturer
.
get_instance
()
if
capturer
is
not
None
:
# in dummmy_run may be None
capturer
.
capture
(
# noqa
layer_id
=
self
.
layer_id
,
topk_ids
=
topk_ids
,
)
return
topk_weights
,
topk_ids
def
must_reduce_shared_expert_outputs
(
self
)
->
bool
:
def
must_reduce_shared_expert_outputs
(
self
)
->
bool
:
"""
"""
The shared_experts are typically computed using the RowParallelLinear
The shared_experts are typically computed using the RowParallelLinear
...
@@ -1761,8 +1600,12 @@ class FusedMoE(CustomOp):
...
@@ -1761,8 +1600,12 @@ class FusedMoE(CustomOp):
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
batched_hidden_states
is
not
None
assert
self
.
batched_hidden_states
is
not
None
assert
self
.
batched_router_logits
is
not
None
assert
self
.
batched_router_logits
is
not
None
assert
self
.
batched_hidden_states
.
dtype
==
full_hidden_states
.
dtype
assert
self
.
batched_hidden_states
.
dtype
==
full_hidden_states
.
dtype
,
(
assert
self
.
batched_router_logits
.
dtype
==
full_router_logits
.
dtype
f
"
{
self
.
batched_hidden_states
.
dtype
}
==
{
full_hidden_states
.
dtype
}
"
)
assert
self
.
batched_router_logits
.
dtype
==
full_router_logits
.
dtype
,
(
f
"
{
self
.
batched_router_logits
.
dtype
}
==
{
full_router_logits
.
dtype
}
"
)
# Check size compatibility.
# Check size compatibility.
assert
self
.
batched_hidden_states
.
size
(
-
1
)
==
full_hidden_states
.
size
(
-
1
)
assert
self
.
batched_hidden_states
.
size
(
-
1
)
==
full_hidden_states
.
size
(
-
1
)
assert
self
.
batched_router_logits
.
size
(
-
1
)
==
full_router_logits
.
size
(
-
1
)
assert
self
.
batched_router_logits
.
size
(
-
1
)
==
full_router_logits
.
size
(
-
1
)
...
@@ -2080,15 +1923,8 @@ class FusedMoE(CustomOp):
...
@@ -2080,15 +1923,8 @@ class FusedMoE(CustomOp):
f
"tp_size=
{
self
.
tp_size
}
,
\n
"
f
"tp_size=
{
self
.
tp_size
}
,
\n
"
f
"ep_size=
{
self
.
ep_size
}
, "
f
"ep_size=
{
self
.
ep_size
}
, "
f
"reduce_results=
{
self
.
reduce_results
}
, "
f
"reduce_results=
{
self
.
reduce_results
}
, "
f
"renormalize=
{
self
.
renormalize
}
, "
f
"use_grouped_topk=
{
self
.
use_grouped_topk
}
"
)
)
if
self
.
use_grouped_topk
:
s
+=
f
", num_expert_group=
{
self
.
num_expert_group
}
, topk_group=
{
self
.
topk_group
}
"
# noqa: E501
s
+=
f
", scoring_func='
{
self
.
scoring_func
}
', activation='
{
self
.
activation
}
'"
# noqa: E501
return
s
return
s
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment