"components/src/vscode:/vscode.git/clone" did not exist on "7d78fdad8dc7249b4940098ec77d0e4fbfeab1c2"
Unverified Commit 327a02d8 authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[MoE Refactor] Separate Router into OO Classes (#30623)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
parent 2f03035a
...@@ -21,12 +21,12 @@ from vllm.distributed import ( ...@@ -21,12 +21,12 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
) )
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEParallelConfig, FusedMoEParallelConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx
from .mk_objects import ( from .mk_objects import (
......
...@@ -15,10 +15,10 @@ from tests.kernels.moe.utils import ( ...@@ -15,10 +15,10 @@ from tests.kernels.moe.utils import (
from tests.kernels.quant_utils import native_batched_masked_quant_matmul from tests.kernels.quant_utils import native_batched_masked_quant_matmul
from tests.kernels.utils import torch_experts from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
invoke_moe_batched_triton_kernel, invoke_moe_batched_triton_kernel,
) )
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl from vllm.triton_utils import tl
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
......
...@@ -11,13 +11,15 @@ from tests.kernels.quant_utils import ( ...@@ -11,13 +11,15 @@ from tests.kernels.quant_utils import (
) )
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import (
fused_experts,
fused_topk,
)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm_shape, _valid_deep_gemm_shape,
deep_gemm_moe_fp8, deep_gemm_moe_fp8,
) )
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk,
modular_triton_fused_moe, modular_triton_fused_moe,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
......
...@@ -10,6 +10,7 @@ import torch ...@@ -10,6 +10,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -19,7 +20,6 @@ from vllm.model_executor.layers.fused_moe.cutlass_moe import ( ...@@ -19,7 +20,6 @@ from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp8, CutlassExpertsFp8,
run_cutlass_moe_fp8, run_cutlass_moe_fp8,
) )
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP, MoEPrepareAndFinalizeNoEP,
) )
......
...@@ -12,6 +12,7 @@ from tests.kernels.quantization.nvfp4_utils import ( ...@@ -12,6 +12,7 @@ from tests.kernels.quantization.nvfp4_utils import (
from tests.kernels.utils import torch_moe from tests.kernels.utils import torch_moe
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts, FlashInferExperts,
is_valid_flashinfer_cutlass_fused_moe, is_valid_flashinfer_cutlass_fused_moe,
...@@ -19,7 +20,6 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( ...@@ -19,7 +20,6 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (
create_flashinfer_prepare_finalize, create_flashinfer_prepare_finalize,
) )
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
......
...@@ -14,7 +14,7 @@ from vllm.config import ( ...@@ -14,7 +14,7 @@ from vllm.config import (
get_cached_compilation_config, get_cached_compilation_config,
set_current_vllm_config, set_current_vllm_config,
) )
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.router.grouped_topk_router import (
GroupedTopk, GroupedTopk,
fused_grouped_topk, fused_grouped_topk,
) )
......
...@@ -24,6 +24,9 @@ from vllm._aiter_ops import rocm_aiter_ops ...@@ -24,6 +24,9 @@ from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment from vllm.distributed.parallel_state import init_distributed_environment
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe import (
fused_topk,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, FUSED_MOE_UNQUANTIZED_CONFIG,
int4_w4a16_moe_quant_config, int4_w4a16_moe_quant_config,
...@@ -34,7 +37,6 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( ...@@ -34,7 +37,6 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, fused_marlin_moe,
) )
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk,
modular_triton_fused_moe, modular_triton_fused_moe,
) )
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
......
...@@ -9,7 +9,7 @@ import numpy as np ...@@ -9,7 +9,7 @@ import numpy as np
import pytest import pytest
import torch import torch
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.layer import determine_expert_map from vllm.model_executor.layers.fused_moe.layer import determine_expert_map
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
moe_permute, moe_permute,
......
...@@ -13,11 +13,11 @@ from tests.kernels.quantization.nvfp4_utils import ( ...@@ -13,11 +13,11 @@ from tests.kernels.quantization.nvfp4_utils import (
from tests.kernels.utils import torch_moe from tests.kernels.utils import torch_moe
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import ( from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp4, CutlassExpertsFp4,
) )
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP, MoEPrepareAndFinalizeNoEP,
) )
......
...@@ -8,9 +8,9 @@ import torch ...@@ -8,9 +8,9 @@ import torch
from tests.kernels.utils import torch_experts from tests.kernels.utils import torch_experts
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassBatchedExpertsFp8 from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassBatchedExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import pytest
import torch
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.router_factory import (
create_fused_moe_router,
)
from vllm.model_executor.models.llama4 import Llama4MoE
# Test parameters
MK_S = [(32, 256), (64, 512)]
TOP_KS = [2, 4, 6]
NUM_EXPERTS = [8, 16, 64]
def setup_eplb_state(
    enable_eplb: bool, global_num_experts: int, device: str = "cuda"
) -> EplbLayerState:
    """Build the EPLB layer state passed to the router under test.

    Args:
        enable_eplb: When False, a default-constructed ``EplbLayerState`` is
            returned and no tensors are allocated.
        global_num_experts: Number of experts; sizes all three state tensors.
        device: Device on which the state tensors are allocated. Defaults to
            ``"cuda"``, which is what this helper always used before.

    Returns:
        An ``EplbLayerState`` describing a trivial 1:1 logical-to-physical
        mapping (no redundant experts) when EPLB is enabled.
    """
    if not enable_eplb:
        return EplbLayerState()

    # expert_load_view: one int32 load counter per expert, initially zero.
    expert_load_view = torch.zeros(
        global_num_experts, dtype=torch.int32, device=device
    )

    # logical_to_physical_map: shape (num_logical_experts, max_slots).
    # A single slot per expert holding its own index gives the 1:1 mapping.
    logical_to_physical_map = torch.arange(
        global_num_experts, dtype=torch.int64, device=device
    ).unsqueeze(-1)

    # logical_replica_count: shape (num_logical_experts,), exactly one
    # replica for every logical expert.
    logical_replica_count = torch.ones(
        global_num_experts, dtype=torch.int64, device=device
    )

    return EplbLayerState(
        expert_load_view=expert_load_view,
        logical_to_physical_map=logical_to_physical_map,
        logical_replica_count=logical_replica_count,
    )
def make_test_data(
    m: int, k: int, num_experts: int, device: str = "cuda"
) -> tuple[torch.Tensor, torch.Tensor]:
    """Create random hidden states and router logits for a routing test.

    Args:
        m: Number of rows in both returned tensors.
        k: Number of columns of the hidden-states tensor.
        num_experts: Number of columns of the logits tensor.
        device: Device on which both tensors are allocated. Defaults to
            ``"cuda"``, which is what this helper always used before.

    Returns:
        ``(hidden_states, logits)`` with shapes ``(m, k)`` and
        ``(m, num_experts)``. Hidden states are normal samples divided by 10;
        logits are unscaled normal samples.
    """
    hidden_states = torch.randn((m, k), device=device) / 10
    logits = torch.randn((m, num_experts), device=device)
    return hidden_states, logits
def make_e_score_correction_bias(
    e_score_correction_bias_val: float,
    num_experts: int,
    device: str = "cuda",
) -> torch.Tensor:
    """Create a constant per-expert score-correction bias.

    Args:
        e_score_correction_bias_val: Value written to every element.
        num_experts: Length of the returned 1-D tensor.
        device: Device on which the tensor is allocated. Defaults to
            ``"cuda"``, which is what this helper always used before.

    Returns:
        A float32 tensor of shape ``(num_experts,)`` filled with
        ``e_score_correction_bias_val``.
    """
    # The bias is the same constant for every expert (not random noise).
    return torch.full(
        (num_experts,),
        e_score_correction_bias_val,
        device=device,
        dtype=torch.float32,
    )
def assert_routing_results_close(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    baseline_weights: torch.Tensor,
    baseline_ids: torch.Tensor,
    rtol: float = 1e-3,
    atol: float = 1e-3,
) -> None:
    """
    Compare routing results, sorting by expert ID first to handle
    non-deterministic ordering from sorted=False in topk.

    Args:
        topk_weights: Weights produced by the router under test.
        topk_ids: Expert IDs produced by the router under test.
        baseline_weights: Reference weights.
        baseline_ids: Reference expert IDs; cast to ``topk_ids.dtype`` before
            comparison so differing integer dtypes do not cause a failure.
        rtol: Relative tolerance for the weight comparison.
        atol: Absolute tolerance for the weight comparison.

    Raises:
        AssertionError: If the sorted IDs differ or the sorted weights are
            not close within the given tolerances.
    """
    # Cast once so both the argsort and the gather see the same tensor.
    baseline_ids = baseline_ids.to(topk_ids.dtype)

    # Sort along the last dimension (the top-k axis), and gather along that
    # same dimension so sort and gather always agree.
    order_actual = torch.argsort(topk_ids, dim=-1)
    order_baseline = torch.argsort(baseline_ids, dim=-1)

    topk_ids_sorted = torch.gather(topk_ids, -1, order_actual)
    topk_weights_sorted = torch.gather(topk_weights, -1, order_actual)
    baseline_ids_sorted = torch.gather(baseline_ids, -1, order_baseline)
    baseline_weights_sorted = torch.gather(baseline_weights, -1, order_baseline)

    # IDs must match exactly; weights only within tolerance.
    torch.testing.assert_close(topk_ids_sorted, baseline_ids_sorted)
    torch.testing.assert_close(
        topk_weights_sorted, baseline_weights_sorted, rtol=rtol, atol=atol
    )
def baseline_fused_topk(
    router_logits: torch.Tensor, top_k: int, renormalize: bool
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Reference implementation of standard fused top-k routing.

    Steps:
    1. Softmax over the router logits (computed in float32).
    2. Pick the top-k experts per row.
    3. If requested, rescale each row of weights to sum to one.

    Returns:
        ``(weights, ids)`` as float32 and int32 tensors respectively.
    """
    probabilities = torch.softmax(router_logits, dim=-1, dtype=torch.float32)

    # sorted=False mirrors the vllm implementation (vllm_is_batch_invariant
    # defaults to False), so the order within a row is unspecified.
    selection = torch.topk(probabilities, top_k, dim=-1, sorted=False)
    weights = selection.values
    expert_ids = selection.indices

    if renormalize:
        row_totals = weights.sum(dim=-1, keepdim=True)
        weights = weights / row_totals

    return weights.to(torch.float32), expert_ids.to(torch.int32)
def baseline_fused_topk_bias(
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    e_score_correction_bias: torch.Tensor,
    routed_scaling_factor: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Baseline for fused top-k with bias correction.

    Algorithm:
    1. Apply softmax to router logits
    2. Add bias to scores for expert selection
    3. Select top-k experts using biased scores
    4. Get weights from original (unbiased) scores
    5. Optionally renormalize the weights
    6. Apply routed scaling factor (after renormalization)

    Args:
        router_logits: Router logits; softmax is taken over the last dim.
        top_k: Number of experts selected per row.
        renormalize: Whether each row of weights is rescaled to sum to one
            before the scaling factor is applied.
        e_score_correction_bias: Per-expert bias, broadcast over rows.
        routed_scaling_factor: Multiplier applied to the final weights.

    Returns:
        ``(topk_weights, topk_ids)`` as float32 and int32 tensors.
    """
    scores = torch.softmax(router_logits, dim=-1, dtype=torch.float32)

    # The bias only influences WHICH experts are chosen; the returned weights
    # are always read from the unbiased scores.
    scores_for_choice = scores + e_score_correction_bias.unsqueeze(0)

    # sorted=False to match the implementation under test.
    topk_ids = torch.topk(scores_for_choice, k=top_k, dim=-1, sorted=False)[1]
    topk_weights = scores.gather(1, topk_ids)

    # Renormalize BEFORE applying the scaling factor.
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    # Apply the scaling factor AFTER renormalization, if applicable.
    if routed_scaling_factor != 1.0:
        topk_weights = topk_weights * routed_scaling_factor

    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
def baseline_grouped_topk(
router_logits: torch.Tensor,
top_k: int,
num_expert_group: int,
topk_group: int,
scoring_func: str,
renormalize: bool,
e_score_correction_bias: torch.Tensor | None,
routed_scaling_factor: float,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Baseline for grouped top-k routing (e.g., DeepSeek).
Algorithm:
1. Apply scoring function (softmax or sigmoid)
2. Optionally add bias
3. Select top-k groups based on max scores within each group
4. Mask scores to only include selected groups
5. Select top-k experts from masked scores
6. Apply scaling factor
7. Optionally renormalize
"""
num_token = router_logits.shape[0]
# Apply scoring function
if scoring_func == "softmax":
scores = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
elif scoring_func == "sigmoid":
scores = torch.sigmoid(router_logits.float())
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
# Handle bias correction
if e_score_correction_bias is not None:
original_scores = scores
scores = scores + e_score_correction_bias.unsqueeze(0)
# For bias case, use sum of top-2 scores in each group
group_scores = (
scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
)
else:
# Use max score in each group
group_scores = scores.view(num_token, num_expert_group, -1).max(dim=-1).values
# Select top-k groups
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
# Create mask for selected groups
group_mask = torch.zeros_like(group_scores)
group_mask.scatter_(1, group_idx, 1)
# Expand mask to all experts
score_mask = (
group_mask.unsqueeze(-1)
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
.reshape(num_token, -1)
)
# Mask scores (set non-selected to -inf)
tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))
# Select top-k experts
if e_score_correction_bias is not None:
topk_ids = torch.topk(tmp_scores, k=top_k, dim=-1, sorted=False)[1]
topk_weights = original_scores.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(tmp_scores, k=top_k, dim=-1, sorted=False)
# Renormalize if needed
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
# Apply scaling factor
if routed_scaling_factor != 1.0:
topk_weights *= routed_scaling_factor
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
def baseline_custom_llama4(
    router_logits: torch.Tensor, top_k: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Reference implementation of the Llama4 custom routing function.

    Steps:
    1. Pick the top-k experts directly from the raw logits (no softmax).
    2. Pass the selected logits through a sigmoid to obtain the weights.

    Returns:
        ``(weights, ids)`` as float32 and int32 tensors respectively.
    """
    picked = torch.topk(router_logits, top_k, dim=-1)
    weights = torch.sigmoid(picked.values.float())
    expert_ids = picked.indices
    return weights.to(torch.float32), expert_ids.to(torch.int32)
@pytest.mark.parametrize("m,k", MK_S)
@pytest.mark.parametrize("top_k", TOP_KS)
@pytest.mark.parametrize("global_num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("renormalize", [False, True])
@pytest.mark.parametrize("enable_eplb", [False, True])
def test_fused_topk(
    m: int,
    k: int,
    top_k: int,
    global_num_experts: int,
    renormalize: bool,
    enable_eplb: bool,
):
    """Router built without extra routing options must match the
    softmax + top-k baseline."""
    if top_k > global_num_experts:
        pytest.skip(f"top_k ({top_k}) > global_num_experts ({global_num_experts})")

    router = create_fused_moe_router(
        top_k=top_k,
        global_num_experts=global_num_experts,
        renormalize=renormalize,
        enable_eplb=enable_eplb,
        eplb_state=setup_eplb_state(enable_eplb, global_num_experts),
    )

    hidden_states, router_logits = make_test_data(m, k, global_num_experts)

    # Output of the router under test.
    routed_weights, routed_ids = router.select_experts(hidden_states, router_logits)

    # Reference output computed with plain torch ops.
    expected_weights, expected_ids = baseline_fused_topk(
        router_logits, top_k, renormalize
    )

    assert_routing_results_close(
        routed_weights, routed_ids, expected_weights, expected_ids
    )
@pytest.mark.parametrize("m,k", MK_S)
@pytest.mark.parametrize("top_k", TOP_KS)
@pytest.mark.parametrize("global_num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("renormalize", [False, True])
@pytest.mark.parametrize("enable_eplb", [False, True])
@pytest.mark.parametrize("e_score_correction_bias_val", [0.9])
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 1.1])
def test_fused_topk_bias(
    m: int,
    k: int,
    top_k: int,
    global_num_experts: int,
    renormalize: bool,
    enable_eplb: bool,
    e_score_correction_bias_val: float,
    routed_scaling_factor: float,
):
    """Router built with a score-correction bias and scaling factor must
    match the biased top-k baseline."""
    if top_k > global_num_experts:
        pytest.skip(f"top_k ({top_k}) > global_num_experts ({global_num_experts})")

    eplb_state = setup_eplb_state(enable_eplb, global_num_experts)
    bias = make_e_score_correction_bias(
        e_score_correction_bias_val,
        global_num_experts,
    )

    router = create_fused_moe_router(
        e_score_correction_bias=bias,
        routed_scaling_factor=routed_scaling_factor,
        top_k=top_k,
        global_num_experts=global_num_experts,
        renormalize=renormalize,
        enable_eplb=enable_eplb,
        eplb_state=eplb_state,
    )

    hidden_states, router_logits = make_test_data(m, k, global_num_experts)

    # Output of the router under test.
    routed_weights, routed_ids = router.select_experts(hidden_states, router_logits)

    # Reference output computed with plain torch ops.
    expected_weights, expected_ids = baseline_fused_topk_bias(
        router_logits,
        top_k,
        renormalize,
        bias,
        routed_scaling_factor,
    )

    assert_routing_results_close(
        routed_weights, routed_ids, expected_weights, expected_ids
    )
@pytest.mark.parametrize("m,k", MK_S)
@pytest.mark.parametrize("top_k", TOP_KS)
@pytest.mark.parametrize(
    "global_num_experts,num_expert_group,topk_group",
    [
        (64, 8, 4),  # 8 groups of 8 experts, select 4 groups
        (32, 4, 2),  # 4 groups of 8 experts, select 2 groups
    ],
)
@pytest.mark.parametrize("renormalize", [False, True])
@pytest.mark.parametrize("enable_eplb", [False, True])
@pytest.mark.parametrize("e_score_correction_bias_val", [0.9])
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 1.1])
@pytest.mark.parametrize("scoring_func", ["sigmoid", "softmax"])
def test_grouped_topk(
    m: int,
    k: int,
    top_k: int,
    global_num_experts: int,
    renormalize: bool,
    enable_eplb: bool,
    num_expert_group: int,
    topk_group: int,
    scoring_func: str,
    e_score_correction_bias_val: float,
    routed_scaling_factor: float,
):
    """Router built with ``use_grouped_topk=True`` must match the grouped
    top-k baseline."""
    if top_k > global_num_experts:
        pytest.skip(f"top_k ({top_k}) > global_num_experts ({global_num_experts})")

    eplb_state = setup_eplb_state(enable_eplb, global_num_experts)
    bias = make_e_score_correction_bias(
        e_score_correction_bias_val,
        global_num_experts,
    )

    # NOTE(review): "llama4" is not among the parametrized scoring functions
    # above, so this remapping is never taken with the current parameters.
    if scoring_func == "llama4":
        routing_method_type = RoutingMethodType.Llama4
        scoring_func = "sigmoid"
    else:
        routing_method_type = None

    router = create_fused_moe_router(
        use_grouped_topk=True,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        scoring_func=scoring_func,
        routing_method_type=routing_method_type,
        e_score_correction_bias=bias,
        routed_scaling_factor=routed_scaling_factor,
        top_k=top_k,
        global_num_experts=global_num_experts,
        renormalize=renormalize,
        enable_eplb=enable_eplb,
        eplb_state=eplb_state,
    )

    hidden_states, router_logits = make_test_data(m, k, global_num_experts)

    # Output of the router under test.
    routed_weights, routed_ids = router.select_experts(hidden_states, router_logits)

    # Reference output computed with plain torch ops.
    expected_weights, expected_ids = baseline_grouped_topk(
        router_logits,
        top_k,
        num_expert_group,
        topk_group,
        scoring_func,
        renormalize,
        bias,
        routed_scaling_factor,
    )

    assert_routing_results_close(
        routed_weights, routed_ids, expected_weights, expected_ids
    )
@pytest.mark.parametrize("m,k", MK_S)
@pytest.mark.parametrize("top_k", TOP_KS)
@pytest.mark.parametrize("global_num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("renormalize", [False, True])
@pytest.mark.parametrize("enable_eplb", [False, True])
@pytest.mark.parametrize("custom_routing_function", [Llama4MoE.custom_routing_function])
def test_custom(
    m: int,
    k: int,
    top_k: int,
    global_num_experts: int,
    renormalize: bool,
    enable_eplb: bool,
    custom_routing_function: Callable,
):
    """Router built with a custom routing function must match the Llama4
    baseline (top-k on raw logits, then sigmoid)."""
    if top_k > global_num_experts:
        pytest.skip(f"top_k ({top_k}) > global_num_experts ({global_num_experts})")

    router = create_fused_moe_router(
        top_k=top_k,
        global_num_experts=global_num_experts,
        custom_routing_function=custom_routing_function,
        renormalize=renormalize,
        enable_eplb=enable_eplb,
        eplb_state=setup_eplb_state(enable_eplb, global_num_experts),
    )

    hidden_states, router_logits = make_test_data(m, k, global_num_experts)

    # Output of the router under test.
    routed_weights, routed_ids = router.select_experts(hidden_states, router_logits)

    # Reference output; the baseline takes no renormalize argument.
    expected_weights, expected_ids = baseline_custom_llama4(router_logits, top_k)

    assert_routing_results_close(
        routed_weights, routed_ids, expected_weights, expected_ids
    )
# TODO: Decide whether the existing routing-simulator test is sufficient, or
# whether the disabled test below should be enabled.
# # See tests/test_routing_simulator.py
# @pytest.mark.parametrize("m,k", MK_S)
# @pytest.mark.parametrize("top_k", TOP_KS)
# @pytest.mark.parametrize("global_num_experts", NUM_EXPERTS)
# @pytest.mark.parametrize("renormalize", [False, True])
# @pytest.mark.parametrize("enable_eplb", [False, True])
# @pytest.mark.parametrize("strategy", ["uniform_random", "normal_routing"])
# def test_simulated(
# m: int,
# k: int,
# top_k: int,
# global_num_experts: int,
# renormalize: bool,
# enable_eplb: bool,
# strategy: str,
# monkeypatch,
# ):
# eplb_state = setup_eplb_state(enable_eplb, global_num_experts)
# monkeypatch.setenv("VLLM_MOE_ROUTING_SIMULATION_STRATEGY", strategy)
# router = create_fused_moe_router(
# top_k=top_k,
# global_num_experts=global_num_experts,
# enable_eplb=enable_eplb,
# eplb_state=eplb_state,
# )
# hidden_states, router_logits = make_test_data(m, k, global_num_experts)
# topk_weights, topk_ids = router.select_experts(hidden_states, router_logits)
...@@ -19,7 +19,7 @@ from vllm.distributed import ( ...@@ -19,7 +19,7 @@ from vllm.distributed import (
init_distributed_environment, init_distributed_environment,
initialize_model_parallel, initialize_model_parallel,
) )
from vllm.model_executor.layers.fused_moe.routing_simulator import ( from vllm.model_executor.layers.fused_moe.router.routing_simulator_router import (
DistributionBasedRouting, DistributionBasedRouting,
RoutingSimulator, RoutingSimulator,
) )
...@@ -109,40 +109,44 @@ def test_routing_strategy_integration(monkeypatch, device): ...@@ -109,40 +109,44 @@ def test_routing_strategy_integration(monkeypatch, device):
tensor_model_parallel_size=1, tensor_model_parallel_size=1,
pipeline_model_parallel_size=1, pipeline_model_parallel_size=1,
) )
fused_moe = FusedMoE(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=0,
use_grouped_topk=False,
renormalize=True,
)
for strategy in strategies:
# Set environment variable
env_name = "VLLM_MOE_ROUTING_SIMULATION_STRATEGY"
monkeypatch.setenv(env_name, strategy)
# Force reload of environment variable for strategy in strategies:
envs.environment_variables[env_name] = lambda s=strategy: s fused_moe = FusedMoE(
num_experts=num_experts,
# Test the select_experts method top_k=top_k,
topk_weights, topk_ids = fused_moe.router.select_experts( hidden_size=hidden_size,
hidden_states=hidden_states, intermediate_size=0,
router_logits=router_logits, use_grouped_topk=False,
) renormalize=True,
prefix=strategy,
# Verify output shapes )
assert topk_weights.shape == (num_tokens, top_k), (
f"Wrong weights shape for {strategy}" # Set environment variable
) env_name = "VLLM_MOE_ROUTING_SIMULATION_STRATEGY"
assert topk_ids.shape == (num_tokens, top_k), f"Wrong ids shape for {strategy}" monkeypatch.setenv(env_name, strategy)
# Verify expert IDs are valid # Force reload of environment variable
assert topk_ids.min() >= 0, f"Invalid expert ID (negative) for {strategy}" envs.environment_variables[env_name] = lambda s=strategy: s
assert topk_ids.max() < num_experts, (
f"Invalid expert ID (too large) for {strategy}" # Test the select_experts method
) topk_weights, topk_ids = fused_moe.router.select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
)
# Verify output shapes
assert topk_weights.shape == (num_tokens, top_k), (
f"Wrong weights shape for {strategy}"
)
assert topk_ids.shape == (num_tokens, top_k), (
f"Wrong ids shape for {strategy}"
)
# Verify expert IDs are valid
assert topk_ids.min() >= 0, f"Invalid expert ID (negative) for {strategy}"
assert topk_ids.max() < num_experts, (
f"Invalid expert ID (too large) for {strategy}"
)
def test_distribution_based_routing_with_custom_strategy(): def test_distribution_based_routing_with_custom_strategy():
......
...@@ -17,7 +17,7 @@ from vllm.model_executor.layers.activation import ( ...@@ -17,7 +17,7 @@ from vllm.model_executor.layers.activation import (
ReLUSquaredActivation, ReLUSquaredActivation,
SiluAndMul, SiluAndMul,
) )
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.router.fused_topk_router import (
dispatch_topk_func, dispatch_topk_func,
vllm_topk_softmax, vllm_topk_softmax,
) )
......
...@@ -1158,6 +1158,15 @@ class EplbState: ...@@ -1158,6 +1158,15 @@ class EplbState:
return self._allreduce_list(load_pass_list) return self._allreduce_list(load_pass_list)
@dataclass
class EplbLayerState:
"""Runtime EPLB data stored in the MoE layer."""
expert_load_view: torch.Tensor | None = None
logical_to_physical_map: torch.Tensor | None = None
logical_replica_count: torch.Tensor | None = None
def _node_count_with_rank_mapping( def _node_count_with_rank_mapping(
pg: ProcessGroup | StatelessProcessGroup, pg: ProcessGroup | StatelessProcessGroup,
rank_mapping: dict[int, int], rank_mapping: dict[int, int],
......
...@@ -11,9 +11,6 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -11,9 +11,6 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase, FusedMoEMethodBase,
) )
from vllm.model_executor.layers.fused_moe.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoE,
FusedMoeWeightScaleSupported, FusedMoeWeightScaleSupported,
...@@ -23,6 +20,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( ...@@ -23,6 +20,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPermuteExpertsUnpermute, FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize, FusedMoEPrepareAndFinalize,
) )
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import ( from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod, UnquantizedFusedMoEMethod,
...@@ -83,13 +83,17 @@ if HAS_TRITON: ...@@ -83,13 +83,17 @@ if HAS_TRITON:
BatchedTritonExperts, BatchedTritonExperts,
) )
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
GroupedTopk,
TritonExperts, TritonExperts,
TritonWNA16Experts, TritonWNA16Experts,
fused_experts, fused_experts,
fused_topk,
get_config_file_name, get_config_file_name,
) )
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import (
fused_topk,
)
from vllm.model_executor.layers.fused_moe.router.grouped_topk_router import (
GroupedTopk,
)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts, TritonOrDeepGemmExperts,
) )
......
...@@ -117,8 +117,12 @@ class RoutingMethodType(IntEnum): ...@@ -117,8 +117,12 @@ class RoutingMethodType(IntEnum):
RenormalizeNaive = (4,) RenormalizeNaive = (4,)
# TopK: TopK (no softmax) # TopK: TopK (no softmax)
TopK = (5,) TopK = (5,)
# Custom
Custom = (6,)
# Simulated
Simulated = (7,)
# Unspecified # Unspecified
Unspecified = 6.0 Unspecified = 8.0
@dataclass @dataclass
......
...@@ -13,9 +13,7 @@ import torch ...@@ -13,9 +13,7 @@ import torch
import vllm.envs as envs import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
) )
...@@ -34,9 +32,6 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( ...@@ -34,9 +32,6 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP, MoEPrepareAndFinalizeNoEP,
) )
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_grouped_topk,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP, TopKWeightAndReduceNoOP,
) )
...@@ -49,7 +44,6 @@ from vllm.model_executor.layers.fused_moe.utils import ( ...@@ -49,7 +44,6 @@ from vllm.model_executor.layers.fused_moe.utils import (
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6 from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme
from vllm.model_executor.utils import maybe_disable_graph_partition
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
...@@ -1318,375 +1312,6 @@ def try_get_optimal_moe_config( ...@@ -1318,375 +1312,6 @@ def try_get_optimal_moe_config(
return config return config
def vllm_topk_softmax(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool,
) -> tuple[torch.Tensor, ...]:
    """Invoke the `ops.topk_softmax` custom op on caller-provided buffers.

    The return value of the op is ignored here; the same `topk_weights` and
    `topk_indices` objects that were passed in are handed back, so the op
    presumably fills them in place (confirm against the op's definition).
    `token_expert_indices` is forwarded to the op but not returned.

    Returns:
        The `(topk_weights, topk_indices)` pair that was passed in.
    """
    # Positional order matches the call into the custom op.
    kernel_args = (
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        renormalize,
    )
    ops.topk_softmax(*kernel_args)
    return topk_weights, topk_indices
def dispatch_topk_func(
    use_rocm_aiter: bool = False,
) -> Callable[..., tuple[torch.Tensor, ...]]:
    """Select the top-k softmax callable to use.

    Args:
        use_rocm_aiter: When true, return `rocm_aiter_ops.topk_softmax`;
            otherwise return `vllm_topk_softmax`.

    Returns:
        The selected callable (it is returned, not invoked).
    """
    return rocm_aiter_ops.topk_softmax if use_rocm_aiter else vllm_topk_softmax
def fused_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
indices_type: torch.dtype | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"
M, _ = hidden_states.size()
topk_weights = torch.empty(
M, topk, dtype=torch.float32, device=hidden_states.device
)
topk_ids = torch.empty(
M,
topk,
dtype=torch.int32 if indices_type is None else indices_type,
device=hidden_states.device,
)
token_expert_indices = torch.empty(
M, topk, dtype=torch.int32, device=hidden_states.device
)
topk_func = dispatch_topk_func(use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled())
topk_weights, topk_ids = topk_func(
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
)
return topk_weights, topk_ids, token_expert_indices
def fused_topk_bias(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    e_score_correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    """Select top-k experts using bias-corrected softmax scores.

    Experts are chosen from `softmax(gating_output) + bias`, but the returned
    weights are gathered from the unbiased softmax scores. `hidden_states` is
    accepted but not read by this function.

    Returns:
        `(topk_weights, topk_indices)` cast to float32 and int32 respectively.
    """
    num_experts = gating_output.shape[-1]
    scores = gating_output.softmax(dim=-1)
    biased_scores = scores.view(-1, num_experts) + e_score_correction_bias.unsqueeze(0)

    # For batch invariance, use sorted=True to ensure deterministic expert selection
    _, topk_indices = torch.topk(
        biased_scores, k=topk, dim=-1, sorted=vllm_is_batch_invariant()
    )

    # Weights come from the original (unbiased) scores.
    topk_weights = scores.gather(1, topk_indices)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights.to(torch.float32), topk_indices.to(torch.int32)
# This is used by the Deepseek-V2 and Deepseek-V3 model
@torch.compile(
    dynamic=True,
    backend=current_platform.simple_compile_backend,
    options=maybe_disable_graph_partition(current_platform.simple_compile_backend),
)
def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Grouped top-k routing: pick expert groups first, then experts in them.

    Delegates to `fused_grouped_topk` when the fused path is enabled via
    `envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK`, the platform is CUDA, both
    `num_expert_group` and `topk` are at most 32, and a correction bias is
    given. Otherwise the routing is computed with plain torch ops:

    1. Score experts with softmax or sigmoid.
    2. Score each group (sum of its top-2 biased scores when a bias is
       given, otherwise its max score) and keep the best `topk_group` groups.
    3. Mask experts of non-selected groups to -inf and take the top `topk`.

    Returns:
        `(topk_weights, topk_ids)` as float32 and int32 tensors.

    Raises:
        ValueError: If `scoring_func` is neither "softmax" nor "sigmoid".
    """
    if (
        envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK
        and current_platform.is_cuda()
        and num_expert_group <= 32
        and topk <= 32
        and e_score_correction_bias is not None
    ):
        return fused_grouped_topk(
            hidden_states=hidden_states,
            gating_output=gating_output,
            topk=topk,
            renormalize=renormalize,
            e_score_correction_bias=e_score_correction_bias,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
        )

    assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"

    if scoring_func == "softmax":
        scores = torch.softmax(gating_output, dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    num_token = scores.size(0)
    if e_score_correction_bias is not None:
        # Keep the unbiased scores around: the biased scores drive expert
        # selection, while the unbiased ones become the routing weights.
        unbiased_scores = scores
        scores = scores + e_score_correction_bias.unsqueeze(0)
        per_group = scores.view(num_token, num_expert_group, -1)
        group_scores = per_group.topk(2, dim=-1)[0].sum(dim=-1)
    else:
        per_group = scores.view(num_token, num_expert_group, -1)
        group_scores = per_group.max(dim=-1).values  # [n, n_group]

    # For batch invariance, use sorted=True to ensure deterministic expert selection
    use_sorted = vllm_is_batch_invariant()
    _, group_idx = torch.topk(
        group_scores, k=topk_group, dim=-1, sorted=use_sorted
    )  # [n, top_k_group]

    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]

    # Broadcast the per-group selection flag to every expert of the group.
    experts_per_group = scores.size(-1) // num_expert_group
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, experts_per_group)
        .reshape(num_token, -1)
    )  # [n, e]
    tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))  # [n, e]

    if e_score_correction_bias is not None:
        _, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)
        # Use original unbiased scores for the routing weights
        topk_weights = unbiased_scores.gather(1, topk_ids)
    else:
        topk_weights, topk_ids = torch.topk(
            tmp_scores, k=topk, dim=-1, sorted=use_sorted
        )

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    if routed_scaling_factor != 1.0:
        topk_weights = topk_weights * routed_scaling_factor

    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
# --8<-- [start:grouped_topk]
@CustomOp.register("grouped_topk")
class GroupedTopk(CustomOp):
    """GroupedTopk used by the Deepseek-V2 and Deepseek-V3 model."""

    # --8<-- [end:grouped_topk]

    def __init__(
        self,
        topk: int,
        renormalize: bool,
        num_expert_group: int = 0,
        topk_group: int = 0,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        num_fused_shared_experts: int = 0,
    ) -> None:
        """Store the routing configuration used by every forward variant."""
        super().__init__()
        # Routing hyper-parameters, forwarded verbatim on each call.
        self.topk = topk
        self.renormalize = renormalize
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.scoring_func = scoring_func
        self.routed_scaling_factor = routed_scaling_factor
        self.num_fused_shared_experts = num_fused_shared_experts
        # Callable used by `forward_native`.
        self.native_impl = grouped_topk

    def forward_native(
        self,
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        e_score_correction_bias: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Call `self.native_impl` with the stored configuration."""
        config_args = (
            self.topk,
            self.renormalize,
            self.num_expert_group,
            self.topk_group,
            self.scoring_func,
            self.routed_scaling_factor,
        )
        return self.native_impl(
            hidden_states,
            gating_output,
            *config_args,
            e_score_correction_bias,
        )

    def forward_cuda(
        self,
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        e_score_correction_bias: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """CUDA entry point; delegates to `forward_native`."""
        return self.forward_native(
            hidden_states, gating_output, e_score_correction_bias
        )

    def forward_hip(
        self,
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        e_score_correction_bias: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """HIP entry point.

        Uses `rocm_aiter_grouped_topk` when AITER fused MoE is enabled and
        falls back to `forward_native` otherwise.
        """
        if not rocm_aiter_ops.is_fused_moe_enabled():
            return self.forward_native(
                hidden_states, gating_output, e_score_correction_bias
            )

        # Fused shared experts are only allowed when the AITER fusion for
        # shared experts is enabled.
        if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
            assert self.num_fused_shared_experts == 0
        return rocm_aiter_grouped_topk(
            hidden_states,
            gating_output,
            self.topk,
            self.renormalize,
            self.num_expert_group,
            self.topk_group,
            self.scoring_func,
            self.routed_scaling_factor,
            e_score_correction_bias,
            self.num_fused_shared_experts,
        )
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def eplb_map_to_physical_and_record(
    topk_ids: torch.Tensor,
    expert_load_view: torch.Tensor,
    logical_to_physical_map: torch.Tensor,
    logical_replica_count: torch.Tensor,
) -> torch.Tensor:
    """
    Translate logical expert ids into physical expert ids and accumulate
    the per-expert load counters.

    The replica of each logical expert is chosen by taking the flattened
    position of the entry modulo that expert's replica count.

    Only used for EPLB.

    Args:
        topk_ids: The logical expert ids.
        expert_load_view: The expert load view; updated in place.
        logical_to_physical_map: The logical to physical map.
        logical_replica_count: The logical replica count.

    Returns:
        The physical expert ids.
    """
    # 1. Convert the logical expert ids to physical expert ids.
    # Index with int64 in case `indices_type` is not `torch.long` or
    # `torch.int`, e.g. `torch.uint32` as required by dispatch/combine kernels.
    logical_ids = topk_ids.long()

    # Replica choice is (token position) modulo (replica count), which is
    # deterministic for a given input.
    replica_count = logical_replica_count[logical_ids]

    # Flatten-position based index, reshaped back to `topk_ids` shape.
    pos_indices = torch.arange(
        topk_ids.numel(), device=topk_ids.device, dtype=torch.long
    ).reshape_as(topk_ids)

    replica_indices = (pos_indices % replica_count).unsqueeze(-1)
    replica_table = logical_to_physical_map[logical_ids]
    physical_ids = replica_table.gather(-1, replica_indices).squeeze(-1)

    # 2. Record expert load metrics.

    # TODO(bowen): When using `FusedMoEModularKernel`, this
    # can be done in a more unified way, since
    # `FusedMoEPrepareAndFinalize` will return the expert
    # token count, in some cases directly from the kernel.
    # However, now there are many code paths not using
    # the modular kernel, e.g. calling `fused_experts`,
    # so we decide to keep the logic here.
    #
    # If later refactor moved all the MoE kernel calls
    # to the modular kernel, we can move this logic there
    # to achieve better efficiency.

    # `expert_load_view`: (num_physical_experts,)

    # `torch.bincount` is not compilable, so use `scatter_add_` instead.
    flat_physical_ids = physical_ids.flatten()
    expert_load_view.scatter_add_(
        dim=0,
        index=flat_physical_ids.long(),
        src=torch.ones_like(flat_physical_ids).to(expert_load_view),
    )
    return physical_ids
def fused_grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    e_score_correction_bias: torch.Tensor,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Grouped top-k routing through the `ops.grouped_topk` custom op.

    For "sigmoid" the raw logits are passed with scoring code 1. For
    "softmax" the softmax is applied here first and the result is passed
    with scoring code 0.

    Returns:
        `(topk_values, topk_indices)` exactly as produced by the op.

    Raises:
        ValueError: If `scoring_func` is neither "sigmoid" nor "softmax".
    """
    assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"

    if scoring_func == "sigmoid":
        # Fully fused kernel path for sigmoid: raw logits, scoring_func=1
        kernel_input = gating_output
        kernel_scoring_func = 1
    elif scoring_func == "softmax":
        # Apply softmax in Python, then use fused kernel
        # TODO: Add support for softmax in kernel
        # scoring_func=0 (no activation, scores already computed)
        kernel_input = torch.softmax(gating_output, dim=-1)
        kernel_scoring_func = 0
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    topk_values, topk_indices = ops.grouped_topk(
        kernel_input,
        num_expert_group,
        topk_group,
        topk,
        renormalize,
        routed_scaling_factor,
        e_score_correction_bias,
        kernel_scoring_func,
    )

    # Fused kernel outputs float32 values and int32 indices directly
    return topk_values, topk_indices
def inplace_fused_experts( def inplace_fused_experts(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
......
...@@ -10,13 +10,13 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -10,13 +10,13 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPermuteExpertsUnpermute, FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize, FusedMoEPrepareAndFinalize,
) )
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase, QuantizeMethodBase,
) )
......
...@@ -12,11 +12,13 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -12,11 +12,13 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase, FusedMoEMethodBase,
) )
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel, FusedMoEModularKernel,
FusedMoEPrepareAndFinalize, FusedMoEPrepareAndFinalize,
) )
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -21,7 +21,7 @@ from vllm.distributed import ( ...@@ -21,7 +21,7 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
) )
from vllm.distributed.eplb.eplb_state import EplbState from vllm.distributed.eplb.eplb_state import EplbLayerState, EplbState
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
...@@ -31,14 +31,24 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -31,14 +31,24 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
RoutingMethodType, RoutingMethodType,
) )
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
FusedMoEModularMethod,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
init_aiter_topK_meta_data, init_aiter_topK_meta_data,
) )
from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
RoutedExpertsCapturer, RoutedExpertsCapturer,
) )
from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator from vllm.model_executor.layers.fused_moe.router.router_factory import (
create_fused_moe_router,
)
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
) )
...@@ -52,31 +62,6 @@ from vllm.utils.torch_utils import ( ...@@ -52,31 +62,6 @@ from vllm.utils.torch_utils import (
) )
from vllm.v1.worker.ubatching import dbo_current_ubatch_id from vllm.v1.worker.ubatching import dbo_current_ubatch_id
if not current_platform.is_cuda_alike():

    def _eplb_map_to_physical_and_record(
        topk_ids: torch.Tensor,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ) -> torch.Tensor:
        """Stand-in used on platforms that are not CUDA-alike.

        Returns `topk_ids` unchanged; the other arguments are ignored.
        """
        # CPU fallback: no EPLB so just return as is
        return topk_ids

    eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record
else:
    from .fused_moe import eplb_map_to_physical_and_record
from vllm.model_executor.layers.fused_moe.fused_moe import GroupedTopk
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
FusedMoEModularMethod,
)
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -288,23 +273,6 @@ def maybe_roundup_hidden_size( ...@@ -288,23 +273,6 @@ def maybe_roundup_hidden_size(
return hidden_size return hidden_size
class FusedMoERouterImpl(FusedMoERouter):
    """Router that forwards every query to the `FusedMoE` layer it wraps."""

    def __init__(self, layer: "FusedMoE"):
        super().__init__()
        # The layer whose routing configuration and logic are reused.
        self.layer = layer

    @property
    def routing_method_type(self) -> RoutingMethodType:
        """Routing method type reported by the wrapped layer."""
        owner = self.layer
        return owner.routing_method_type

    def select_experts(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the result of the wrapped layer's `_select_experts`."""
        owner = self.layer
        return owner._select_experts(hidden_states, router_logits)
# --8<-- [start:fused_moe] # --8<-- [start:fused_moe]
@CustomOp.register("fused_moe") @CustomOp.register("fused_moe")
class FusedMoE(CustomOp): class FusedMoE(CustomOp):
...@@ -440,9 +408,7 @@ class FusedMoE(CustomOp): ...@@ -440,9 +408,7 @@ class FusedMoE(CustomOp):
self.layer_name = prefix self.layer_name = prefix
self.enable_eplb = enable_eplb self.enable_eplb = enable_eplb
self.expert_load_view: torch.Tensor | None = None self.eplb_state = EplbLayerState()
self.logical_to_physical_map: torch.Tensor | None = None
self.logical_replica_count: torch.Tensor | None = None
self.expert_placement_strategy: ExpertPlacementStrategy = ( self.expert_placement_strategy: ExpertPlacementStrategy = (
vllm_config.parallel_config.expert_placement_strategy vllm_config.parallel_config.expert_placement_strategy
) )
...@@ -538,6 +504,8 @@ class FusedMoE(CustomOp): ...@@ -538,6 +504,8 @@ class FusedMoE(CustomOp):
self.intermediate_size_per_partition = intermediate_size // self.tp_size self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results self.reduce_results = reduce_results
self.renormalize = renormalize self.renormalize = renormalize
# TODO(bnell): these attributes are only used by cpu/xpu/mxfp4
self.use_grouped_topk = use_grouped_topk self.use_grouped_topk = use_grouped_topk
if self.use_grouped_topk: if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None assert num_expert_group is not None and topk_group is not None
...@@ -547,46 +515,11 @@ class FusedMoE(CustomOp): ...@@ -547,46 +515,11 @@ class FusedMoE(CustomOp):
self.scoring_func = scoring_func self.scoring_func = scoring_func
self.routed_scaling_factor = routed_scaling_factor self.routed_scaling_factor = routed_scaling_factor
self.e_score_correction_bias = e_score_correction_bias self.e_score_correction_bias = e_score_correction_bias
# TODO(bnell): end attributes
self.apply_router_weight_on_input = apply_router_weight_on_input self.apply_router_weight_on_input = apply_router_weight_on_input
self.activation = activation self.activation = activation
self._grouped_topk_impl: GroupedTopk | None = None
if self.use_grouped_topk:
assert self.num_expert_group is not None
assert self.topk_group is not None
self._grouped_topk_impl = GroupedTopk(
topk=self.top_k,
renormalize=self.renormalize,
num_expert_group=self.num_expert_group,
topk_group=self.topk_group,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
num_fused_shared_experts=self.num_fused_shared_experts,
)
if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError(
"Only softmax scoring function is supported for non-grouped topk."
)
# ToDo: Better logic to determine the routing method type
if routing_method_type is not None:
self.routing_method_type: RoutingMethodType = routing_method_type
else:
if scoring_func == "sigmoid":
if self.use_grouped_topk:
self.routing_method_type = RoutingMethodType.DeepSeekV3
elif self.top_k == 1:
self.routing_method_type = RoutingMethodType.Llama4
elif self.scoring_func == "softmax":
self.routing_method_type = (
RoutingMethodType.Renormalize
if not self.renormalize
else RoutingMethodType.RenormalizeNaive
)
else:
self.routing_method_type = RoutingMethodType.TopK
self.moe_config: FusedMoEConfig = FusedMoEConfig( self.moe_config: FusedMoEConfig = FusedMoEConfig(
num_experts=self.global_num_experts, num_experts=self.global_num_experts,
experts_per_token=top_k, experts_per_token=top_k,
...@@ -637,8 +570,7 @@ class FusedMoE(CustomOp): ...@@ -637,8 +570,7 @@ class FusedMoE(CustomOp):
# If you plan to add support for more quantization methods, # If you plan to add support for more quantization methods,
# please refer to the implementation in `Fp8MoEMethod`. # please refer to the implementation in `Fp8MoEMethod`.
raise NotImplementedError( raise NotImplementedError(
f"EPLB is not supported {self.quant_method.__class__.__name__}. " f"EPLB is not supported {self.quant_method.__class__.__name__}."
"EPLB is only supported for FP8 quantization for now."
) )
moe_quant_params = { moe_quant_params = {
...@@ -663,7 +595,38 @@ class FusedMoE(CustomOp): ...@@ -663,7 +595,38 @@ class FusedMoE(CustomOp):
self.batched_hidden_states: torch.Tensor | None = None self.batched_hidden_states: torch.Tensor | None = None
self.batched_router_logits: torch.Tensor | None = None self.batched_router_logits: torch.Tensor | None = None
self.router = FusedMoERouterImpl(self) # TODO(bnell): in next PR move capture back to layer
capture: Callable[[torch.Tensor], None] | None = None
if (
self.vllm_config.model_config is not None
and self.vllm_config.model_config.enable_return_routed_experts
):
# In dummy runs, the capturer is not initialized.
capturer = RoutedExpertsCapturer.get_instance()
if capturer is not None:
capture = lambda topk_ids: capturer.capture(self.layer_id, topk_ids)
self.router = create_fused_moe_router(
top_k=top_k,
global_num_experts=self.global_num_experts,
eplb_state=self.eplb_state,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
num_fused_shared_experts=self.num_fused_shared_experts,
enable_eplb=enable_eplb,
# TODO(bnell): once we can construct the MK at init time, we
# can make this a value.
indices_type_getter=lambda: self.quant_method.topk_indices_dtype,
routing_method_type=routing_method_type,
capture=capture,
)
self.routing_method_type: RoutingMethodType = self.router.routing_method_type
# Note: maybe_init_modular_kernel should only be called by # Note: maybe_init_modular_kernel should only be called by
# prepare_communication_buffer_for_model. # prepare_communication_buffer_for_model.
...@@ -1492,9 +1455,9 @@ class FusedMoE(CustomOp): ...@@ -1492,9 +1455,9 @@ class FusedMoE(CustomOp):
This is used later in forward pass, where we get the expert mapping This is used later in forward pass, where we get the expert mapping
and record the load metrics in `expert_load_view`. and record the load metrics in `expert_load_view`.
""" """
self.expert_load_view = expert_load_view[moe_layer_idx] self.eplb_state.expert_load_view = expert_load_view[moe_layer_idx]
self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx] self.eplb_state.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
self.logical_replica_count = logical_replica_count[moe_layer_idx] self.eplb_state.logical_replica_count = logical_replica_count[moe_layer_idx]
def ensure_moe_quant_config_init(self): def ensure_moe_quant_config_init(self):
if self.quant_method.moe_quant_config is None: if self.quant_method.moe_quant_config is None:
...@@ -1535,130 +1498,6 @@ class FusedMoE(CustomOp): ...@@ -1535,130 +1498,6 @@ class FusedMoE(CustomOp):
device=torch.cuda.current_device(), device=torch.cuda.current_device(),
) )
def _select_experts(
    self,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Route the input hidden states to the top-k experts based on the
    router logits.

    The routing strategy is chosen in this order:

    1. Routing simulation, when `VLLM_MOE_ROUTING_SIMULATION_STRATEGY`
       is a non-empty string.
    2. Grouped top-k, when `use_grouped_topk` is set and the number of
       experts is larger than and divisible by `num_expert_group`.
    3. Bias-corrected top-k, when `e_score_correction_bias` is set.
    4. Plain fused top-k, when no custom routing function is set.
    5. The custom routing function.

    Returns:
        (topk_weights, topk_ids)
        (tuple[torch.Tensor, torch.Tensor]):
        The weights and expert ids.

        **Compatibility**: When EPLB is not enabled, the returned ids are
        equivalent to global logical ids, so should be compatible with
        plain MoE implementations without redundant experts.

    Raises:
        NotImplementedError: If EPLB is enabled but the quantization
            method does not support it.
        ValueError: If EPLB is enabled but one of the EPLB state tensors
            has not been set.
    """
    # Function-scope import, kept as in the original (presumably to avoid
    # a circular import at module load time).
    from vllm.model_executor.layers.fused_moe.fused_moe import (
        fused_topk,
        fused_topk_bias,
    )

    if self.enable_eplb:
        if not self.quant_method.supports_eplb:
            raise NotImplementedError(
                f"EPLB is not supported for {self.quant_method.method_name}."
            )
        if self.expert_load_view is None:
            raise ValueError("enable_eplb=True requires expert_load_view != None")
        if self.logical_to_physical_map is None:
            raise ValueError(
                "enable_eplb=True requires logical_to_physical_map != None"
            )
        if self.logical_replica_count is None:
            raise ValueError(
                "enable_eplb=True requires logical_replica_count != None"
            )

    def valid_grouping() -> bool:
        # Check if num_experts is greater than num_expert_group
        # and is divisible by num_expert_group
        num_experts = router_logits.shape[-1]
        if num_experts <= self.num_expert_group:
            return False
        return num_experts % self.num_expert_group == 0

    indices_type = self.quant_method.topk_indices_dtype

    # Check if we should use a routing simulation strategy
    routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
    if routing_strategy != "":
        topk_weights, topk_ids = RoutingSimulator.simulate_routing(
            hidden_states=hidden_states,
            router_logits=router_logits,
            strategy_name=routing_strategy,
            top_k=self.top_k,
            indices_type=indices_type,
        )

    # DeepSeekv2 uses grouped_top_k
    elif self.use_grouped_topk and valid_grouping():
        assert self._grouped_topk_impl is not None
        topk_weights, topk_ids = self._grouped_topk_impl(
            hidden_states=hidden_states,
            gating_output=router_logits,
            e_score_correction_bias=self.e_score_correction_bias,
        )
    elif self.e_score_correction_bias is not None:
        topk_weights, topk_ids = fused_topk_bias(
            hidden_states=hidden_states,
            gating_output=router_logits,
            e_score_correction_bias=self.e_score_correction_bias.data,
            topk=self.top_k,
            renormalize=self.renormalize,
        )
        # The scaling factor is applied here for the bias-corrected path.
        if self.routed_scaling_factor != 1.0:
            topk_weights *= self.routed_scaling_factor
    elif self.custom_routing_function is None:
        topk_weights, topk_ids, token_expert_indices = fused_topk(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=self.top_k,
            renormalize=self.renormalize,
            indices_type=indices_type,
        )
    else:
        topk_weights, topk_ids = self.custom_routing_function(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=self.top_k,
            renormalize=self.renormalize,
        )

    if self.enable_eplb:
        # Map logical ids to physical ids and record the expert load.
        topk_ids = eplb_map_to_physical_and_record(
            topk_ids=topk_ids,
            expert_load_view=self.expert_load_view,
            logical_to_physical_map=self.logical_to_physical_map,
            logical_replica_count=self.logical_replica_count,
        )

    # Make sure the ids use the dtype requested by the quantization method.
    if (indices_type is not None) and topk_ids.dtype != indices_type:
        topk_ids = topk_ids.to(dtype=indices_type)

    assert topk_ids.dtype == indices_type or indices_type is None

    if (
        self.vllm_config.model_config is not None
        and self.vllm_config.model_config.enable_return_routed_experts
    ):
        # In dummy runs, the capturer is not initialized.
        capturer = RoutedExpertsCapturer.get_instance()
        if capturer is not None:  # in dummy_run may be None
            capturer.capture(  # noqa
                layer_id=self.layer_id,
                topk_ids=topk_ids,
            )

    return topk_weights, topk_ids
def must_reduce_shared_expert_outputs(self) -> bool: def must_reduce_shared_expert_outputs(self) -> bool:
""" """
The shared_experts are typically computed using the RowParallelLinear The shared_experts are typically computed using the RowParallelLinear
...@@ -1761,8 +1600,12 @@ class FusedMoE(CustomOp): ...@@ -1761,8 +1600,12 @@ class FusedMoE(CustomOp):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.batched_hidden_states is not None assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None assert self.batched_router_logits is not None
assert self.batched_hidden_states.dtype == full_hidden_states.dtype assert self.batched_hidden_states.dtype == full_hidden_states.dtype, (
assert self.batched_router_logits.dtype == full_router_logits.dtype f"{self.batched_hidden_states.dtype} == {full_hidden_states.dtype}"
)
assert self.batched_router_logits.dtype == full_router_logits.dtype, (
f"{self.batched_router_logits.dtype} == {full_router_logits.dtype}"
)
# Check size compatibility. # Check size compatibility.
assert self.batched_hidden_states.size(-1) == full_hidden_states.size(-1) assert self.batched_hidden_states.size(-1) == full_hidden_states.size(-1)
assert self.batched_router_logits.size(-1) == full_router_logits.size(-1) assert self.batched_router_logits.size(-1) == full_router_logits.size(-1)
...@@ -2080,15 +1923,8 @@ class FusedMoE(CustomOp): ...@@ -2080,15 +1923,8 @@ class FusedMoE(CustomOp):
f"tp_size={self.tp_size},\n" f"tp_size={self.tp_size},\n"
f"ep_size={self.ep_size}, " f"ep_size={self.ep_size}, "
f"reduce_results={self.reduce_results}, " f"reduce_results={self.reduce_results}, "
f"renormalize={self.renormalize}, "
f"use_grouped_topk={self.use_grouped_topk}"
) )
if self.use_grouped_topk:
s += f", num_expert_group={self.num_expert_group}, topk_group={self.topk_group}" # noqa: E501
s += f", scoring_func='{self.scoring_func}', activation='{self.activation}'" # noqa: E501
return s return s
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment