"container/Dockerfile.trtllm" did not exist on "ad8ad66b152722c368cd58dd376d7d9e147230dd"
Unverified Commit 327a02d8 authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[MoE Refactor] Separate Router into OO Classes (#30623)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
parent 2f03035a
......@@ -21,12 +21,12 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size,
)
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx
from .mk_objects import (
......
......@@ -15,10 +15,10 @@ from tests.kernels.moe.utils import (
from tests.kernels.quant_utils import native_batched_masked_quant_matmul
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
invoke_moe_batched_triton_kernel,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform
from vllm.triton_utils import tl
from vllm.utils.torch_utils import set_random_seed
......
......@@ -11,13 +11,15 @@ from tests.kernels.quant_utils import (
)
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe import (
fused_experts,
fused_topk,
)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm_shape,
deep_gemm_moe_fp8,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk,
modular_triton_fused_moe,
)
from vllm.platforms import current_platform
......
......@@ -10,6 +10,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEQuantConfig,
......@@ -19,7 +20,6 @@ from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp8,
run_cutlass_moe_fp8,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
......
......@@ -12,6 +12,7 @@ from tests.kernels.quantization.nvfp4_utils import (
from tests.kernels.utils import torch_moe
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
is_valid_flashinfer_cutlass_fused_moe,
......@@ -19,7 +20,6 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (
create_flashinfer_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
......
......@@ -14,7 +14,7 @@ from vllm.config import (
get_cached_compilation_config,
set_current_vllm_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
from vllm.model_executor.layers.fused_moe.router.grouped_topk_router import (
GroupedTopk,
fused_grouped_topk,
)
......
......@@ -24,6 +24,9 @@ from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe import (
fused_topk,
)
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
int4_w4a16_moe_quant_config,
......@@ -34,7 +37,6 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk,
modular_triton_fused_moe,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
......
......@@ -9,7 +9,7 @@ import numpy as np
import pytest
import torch
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.layer import determine_expert_map
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
moe_permute,
......
......@@ -13,11 +13,11 @@ from tests.kernels.quantization.nvfp4_utils import (
from tests.kernels.utils import torch_moe
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp4,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
......
......@@ -8,9 +8,9 @@ import torch
from tests.kernels.utils import torch_experts
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassBatchedExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import pytest
import torch
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.router_factory import (
create_fused_moe_router,
)
from vllm.model_executor.models.llama4 import Llama4MoE
# Test parameters
MK_S = [(32, 256), (64, 512)]
TOP_KS = [2, 4, 6]
NUM_EXPERTS = [8, 16, 64]
def setup_eplb_state(enable_eplb: bool, global_num_experts: int) -> EplbLayerState:
if not enable_eplb:
return EplbLayerState()
# Initialize EPLB state with proper tensors for testing
# For testing purposes, we use a simple 1:1 mapping (no redundant experts)
# expert_load_view: tracks load on each expert (shape: num_experts)
expert_load_view = torch.zeros(global_num_experts, dtype=torch.int32, device="cuda")
# logical_to_physical_map: maps logical experts to physical experts
# Shape: (num_logical_experts, max_slots)
# For testing, use simple 1:1 mapping with single slot per expert
logical_to_physical_map = torch.arange(
global_num_experts, dtype=torch.int64, device="cuda"
).unsqueeze(-1)
# logical_replica_count: number of replicas per logical expert
# Shape: (num_logical_experts,)
# For testing, each logical expert has exactly 1 replica
logical_replica_count = torch.ones(
global_num_experts, dtype=torch.int64, device="cuda"
)
return EplbLayerState(
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def make_test_data(
m: int, k: int, num_experts: int
) -> tuple[torch.Tensor, torch.Tensor]:
hidden_states = torch.randn((m, k), device="cuda") / 10
logits = torch.randn((m, num_experts), device="cuda")
return hidden_states, logits
def make_e_score_correction_bias(
e_score_correction_bias_val: float,
num_experts: int,
) -> torch.Tensor:
# return torch.randn(num_experts, device="cuda") * e_score_correction_bias_val
return torch.full(
(num_experts,), e_score_correction_bias_val, device="cuda", dtype=torch.float32
)
def assert_routing_results_close(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
baseline_weights: torch.Tensor,
baseline_ids: torch.Tensor,
rtol: float = 1e-3,
atol: float = 1e-3,
):
"""
Compare routing results, sorting by expert ID first to handle non-deterministic
ordering from sorted=False in topk.
"""
# Sort both results by expert IDs for consistent comparison
sorted_indices_actual = torch.argsort(topk_ids, dim=-1)
sorted_indices_baseline = torch.argsort(baseline_ids.to(topk_ids.dtype), dim=-1)
# Gather the sorted values
topk_ids_sorted = torch.gather(topk_ids, 1, sorted_indices_actual)
topk_weights_sorted = torch.gather(topk_weights, 1, sorted_indices_actual)
baseline_ids_sorted = torch.gather(
baseline_ids.to(topk_ids.dtype), 1, sorted_indices_baseline
)
baseline_weights_sorted = torch.gather(baseline_weights, 1, sorted_indices_baseline)
# Compare
torch.testing.assert_close(topk_ids_sorted, baseline_ids_sorted)
torch.testing.assert_close(
topk_weights_sorted, baseline_weights_sorted, rtol=rtol, atol=atol
)
def baseline_fused_topk(
router_logits: torch.Tensor, top_k: int, renormalize: bool
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Baseline for standard fused top-k routing.
Algorithm:
1. Apply softmax to router logits
2. Select top-k experts
3. Optionally renormalize the weights
"""
scores = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
# Use sorted=False to match vllm implementation (vllm_is_batch_invariant
# defaults to False)
topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1, sorted=False)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
def baseline_fused_topk_bias(
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
e_score_correction_bias: torch.Tensor,
routed_scaling_factor: float,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Baseline for fused top-k with bias correction.
Algorithm:
1. Apply softmax to router logits
2. Add bias to scores for expert selection
3. Select top-k experts using biased scores
4. Get weights from original (unbiased) scores
5. Apply routed scaling factor
6. Optionally renormalize the weights
"""
# Apply softmax to get scores
scores = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
# Add bias for expert selection
scores_for_choice = scores + e_score_correction_bias.unsqueeze(0)
# Select top-k using biased scores (sorted=False to match implementation)
topk_ids = torch.topk(scores_for_choice, k=top_k, dim=-1, sorted=False)[1]
# Get weights from original scores (not biased)
topk_weights = scores.gather(1, topk_ids)
# Renormalize if needed (BEFORE applying scaling factor)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
# Apply scaling factor (AFTER renormalization, if applicable)
if routed_scaling_factor != 1.0:
topk_weights *= routed_scaling_factor
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
def baseline_grouped_topk(
router_logits: torch.Tensor,
top_k: int,
num_expert_group: int,
topk_group: int,
scoring_func: str,
renormalize: bool,
e_score_correction_bias: torch.Tensor | None,
routed_scaling_factor: float,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Baseline for grouped top-k routing (e.g., DeepSeek).
Algorithm:
1. Apply scoring function (softmax or sigmoid)
2. Optionally add bias
3. Select top-k groups based on max scores within each group
4. Mask scores to only include selected groups
5. Select top-k experts from masked scores
6. Apply scaling factor
7. Optionally renormalize
"""
num_token = router_logits.shape[0]
# Apply scoring function
if scoring_func == "softmax":
scores = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
elif scoring_func == "sigmoid":
scores = torch.sigmoid(router_logits.float())
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
# Handle bias correction
if e_score_correction_bias is not None:
original_scores = scores
scores = scores + e_score_correction_bias.unsqueeze(0)
# For bias case, use sum of top-2 scores in each group
group_scores = (
scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
)
else:
# Use max score in each group
group_scores = scores.view(num_token, num_expert_group, -1).max(dim=-1).values
# Select top-k groups
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
# Create mask for selected groups
group_mask = torch.zeros_like(group_scores)
group_mask.scatter_(1, group_idx, 1)
# Expand mask to all experts
score_mask = (
group_mask.unsqueeze(-1)
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
.reshape(num_token, -1)
)
# Mask scores (set non-selected to -inf)
tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))
# Select top-k experts
if e_score_correction_bias is not None:
topk_ids = torch.topk(tmp_scores, k=top_k, dim=-1, sorted=False)[1]
topk_weights = original_scores.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(tmp_scores, k=top_k, dim=-1, sorted=False)
# Renormalize if needed
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
# Apply scaling factor
if routed_scaling_factor != 1.0:
topk_weights *= routed_scaling_factor
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
def baseline_custom_llama4(
router_logits: torch.Tensor, top_k: int
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Baseline for Llama4 custom routing.
Algorithm:
1. Select top-k expert indices (without softmax)
2. Apply sigmoid to the selected scores
"""
router_scores, router_indices = torch.topk(router_logits, top_k, dim=-1)
router_scores = torch.sigmoid(router_scores.float())
return router_scores.to(torch.float32), router_indices.to(torch.int32)
@pytest.mark.parametrize("m,k", MK_S)
@pytest.mark.parametrize("top_k", TOP_KS)
@pytest.mark.parametrize("global_num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("renormalize", [False, True])
@pytest.mark.parametrize("enable_eplb", [False, True])
def test_fused_topk(
m: int,
k: int,
top_k: int,
global_num_experts: int,
renormalize: bool,
enable_eplb: bool,
):
if top_k > global_num_experts:
pytest.skip(f"top_k ({top_k}) > global_num_experts ({global_num_experts})")
eplb_state = setup_eplb_state(enable_eplb, global_num_experts)
router = create_fused_moe_router(
top_k=top_k,
global_num_experts=global_num_experts,
renormalize=renormalize,
enable_eplb=enable_eplb,
eplb_state=eplb_state,
)
hidden_states, router_logits = make_test_data(m, k, global_num_experts)
# Get router output
topk_weights, topk_ids = router.select_experts(hidden_states, router_logits)
# Compute baseline
baseline_weights, baseline_ids = baseline_fused_topk(
router_logits, top_k, renormalize
)
# Compare results
assert_routing_results_close(topk_weights, topk_ids, baseline_weights, baseline_ids)
@pytest.mark.parametrize("m,k", MK_S)
@pytest.mark.parametrize("top_k", TOP_KS)
@pytest.mark.parametrize("global_num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("renormalize", [False, True])
@pytest.mark.parametrize("enable_eplb", [False, True])
@pytest.mark.parametrize("e_score_correction_bias_val", [0.9])
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 1.1])
def test_fused_topk_bias(
m: int,
k: int,
top_k: int,
global_num_experts: int,
renormalize: bool,
enable_eplb: bool,
e_score_correction_bias_val: float,
routed_scaling_factor: float,
):
if top_k > global_num_experts:
pytest.skip(f"top_k ({top_k}) > global_num_experts ({global_num_experts})")
eplb_state = setup_eplb_state(enable_eplb, global_num_experts)
e_score_correction_bias = make_e_score_correction_bias(
e_score_correction_bias_val,
global_num_experts,
)
router = create_fused_moe_router(
e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
top_k=top_k,
global_num_experts=global_num_experts,
renormalize=renormalize,
enable_eplb=enable_eplb,
eplb_state=eplb_state,
)
hidden_states, router_logits = make_test_data(m, k, global_num_experts)
# Get router output
topk_weights, topk_ids = router.select_experts(hidden_states, router_logits)
# Compute baseline
baseline_weights, baseline_ids = baseline_fused_topk_bias(
router_logits,
top_k,
renormalize,
e_score_correction_bias,
routed_scaling_factor,
)
# Compare results
assert_routing_results_close(topk_weights, topk_ids, baseline_weights, baseline_ids)
@pytest.mark.parametrize("m,k", MK_S)
@pytest.mark.parametrize("top_k", TOP_KS)
@pytest.mark.parametrize(
"global_num_experts,num_expert_group,topk_group",
[
(64, 8, 4), # 8 groups of 8 experts, select 4 groups
(32, 4, 2), # 4 groups of 8 experts, select 2 groups
],
)
@pytest.mark.parametrize("renormalize", [False, True])
@pytest.mark.parametrize("enable_eplb", [False, True])
@pytest.mark.parametrize("e_score_correction_bias_val", [0.9])
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 1.1])
@pytest.mark.parametrize("scoring_func", ["sigmoid", "softmax"])
def test_grouped_topk(
m: int,
k: int,
top_k: int,
global_num_experts: int,
renormalize: bool,
enable_eplb: bool,
num_expert_group: int,
topk_group: int,
scoring_func: str,
e_score_correction_bias_val: float,
routed_scaling_factor: float,
):
if top_k > global_num_experts:
pytest.skip(f"top_k ({top_k}) > global_num_experts ({global_num_experts})")
eplb_state = setup_eplb_state(enable_eplb, global_num_experts)
e_score_correction_bias = make_e_score_correction_bias(
e_score_correction_bias_val,
global_num_experts,
)
routing_method_type = None
if scoring_func == "llama4":
routing_method_type = RoutingMethodType.Llama4
scoring_func = "sigmoid"
router = create_fused_moe_router(
use_grouped_topk=True,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
routing_method_type=routing_method_type,
e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
top_k=top_k,
global_num_experts=global_num_experts,
renormalize=renormalize,
enable_eplb=enable_eplb,
eplb_state=eplb_state,
)
hidden_states, router_logits = make_test_data(m, k, global_num_experts)
# Get router output
topk_weights, topk_ids = router.select_experts(hidden_states, router_logits)
# Compute baseline
baseline_weights, baseline_ids = baseline_grouped_topk(
router_logits,
top_k,
num_expert_group,
topk_group,
scoring_func,
renormalize,
e_score_correction_bias,
routed_scaling_factor,
)
# Compare results
assert_routing_results_close(topk_weights, topk_ids, baseline_weights, baseline_ids)
@pytest.mark.parametrize("m,k", MK_S)
@pytest.mark.parametrize("top_k", TOP_KS)
@pytest.mark.parametrize("global_num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("renormalize", [False, True])
@pytest.mark.parametrize("enable_eplb", [False, True])
@pytest.mark.parametrize("custom_routing_function", [Llama4MoE.custom_routing_function])
def test_custom(
m: int,
k: int,
top_k: int,
global_num_experts: int,
renormalize: bool,
enable_eplb: bool,
custom_routing_function: Callable,
):
if top_k > global_num_experts:
pytest.skip(f"top_k ({top_k}) > global_num_experts ({global_num_experts})")
eplb_state = setup_eplb_state(enable_eplb, global_num_experts)
router = create_fused_moe_router(
top_k=top_k,
global_num_experts=global_num_experts,
custom_routing_function=custom_routing_function,
renormalize=renormalize,
enable_eplb=enable_eplb,
eplb_state=eplb_state,
)
hidden_states, router_logits = make_test_data(m, k, global_num_experts)
# Get router output
topk_weights, topk_ids = router.select_experts(hidden_states, router_logits)
# Compute baseline (Llama4 uses sigmoid)
baseline_weights, baseline_ids = baseline_custom_llama4(router_logits, top_k)
# Compare results
assert_routing_results_close(topk_weights, topk_ids, baseline_weights, baseline_ids)
# TODO: is other test sufficient?
# # See tests/test_routing_simulatator.py
# @pytest.mark.parametrize("m,k", MK_S)
# @pytest.mark.parametrize("top_k", TOP_KS)
# @pytest.mark.parametrize("global_num_experts", NUM_EXPERTS)
# @pytest.mark.parametrize("renormalize", [False, True])
# @pytest.mark.parametrize("enable_eplb", [False, True])
# @pytest.mark.parameterize("strategy", ["uniform_random", "normal_routing"])
# def test_simulated(
# m: int,
# k: int,
# top_k: int,
# global_num_experts: int,
# renormalize: bool,
# enable_eplb: bool,
# strategy: str,
# monkeypatch,
# ):
# eplb_state = setup_eplb_state(enable_eplb)
# monkeypatch.setenv("VLLM_MOE_ROUTING_SIMULATION_STRATEGY", strategy)
# router = create_fused_moe_router(
# top_k=top_k,
# global_num_experts=global_num_experts,
# enable_eplb=enable_eplb,
# eplb_state=eplb_state,
# )
# hidden_states, router_logits = make_test_data(m, k, global_num_experts)
# topk_weights, topk_ids = router.select_experts(hidden_states, router_logits)
......@@ -19,7 +19,7 @@ from vllm.distributed import (
init_distributed_environment,
initialize_model_parallel,
)
from vllm.model_executor.layers.fused_moe.routing_simulator import (
from vllm.model_executor.layers.fused_moe.router.routing_simulator_router import (
DistributionBasedRouting,
RoutingSimulator,
)
......@@ -109,6 +109,8 @@ def test_routing_strategy_integration(monkeypatch, device):
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
)
for strategy in strategies:
fused_moe = FusedMoE(
num_experts=num_experts,
top_k=top_k,
......@@ -116,9 +118,9 @@ def test_routing_strategy_integration(monkeypatch, device):
intermediate_size=0,
use_grouped_topk=False,
renormalize=True,
prefix=strategy,
)
for strategy in strategies:
# Set environment variable
env_name = "VLLM_MOE_ROUTING_SIMULATION_STRATEGY"
monkeypatch.setenv(env_name, strategy)
......@@ -136,7 +138,9 @@ def test_routing_strategy_integration(monkeypatch, device):
assert topk_weights.shape == (num_tokens, top_k), (
f"Wrong weights shape for {strategy}"
)
assert topk_ids.shape == (num_tokens, top_k), f"Wrong ids shape for {strategy}"
assert topk_ids.shape == (num_tokens, top_k), (
f"Wrong ids shape for {strategy}"
)
# Verify expert IDs are valid
assert topk_ids.min() >= 0, f"Invalid expert ID (negative) for {strategy}"
......
......@@ -17,7 +17,7 @@ from vllm.model_executor.layers.activation import (
ReLUSquaredActivation,
SiluAndMul,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import (
dispatch_topk_func,
vllm_topk_softmax,
)
......
......@@ -1158,6 +1158,15 @@ class EplbState:
return self._allreduce_list(load_pass_list)
@dataclass
class EplbLayerState:
"""Runtime EPLB data stored in the MoE layer."""
expert_load_view: torch.Tensor | None = None
logical_to_physical_map: torch.Tensor | None = None
logical_replica_count: torch.Tensor | None = None
def _node_count_with_rank_mapping(
pg: ProcessGroup | StatelessProcessGroup,
rank_mapping: dict[int, int],
......
......@@ -11,9 +11,6 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoeWeightScaleSupported,
......@@ -23,6 +20,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
......@@ -83,13 +83,17 @@ if HAS_TRITON:
BatchedTritonExperts,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
GroupedTopk,
TritonExperts,
TritonWNA16Experts,
fused_experts,
fused_topk,
get_config_file_name,
)
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import (
fused_topk,
)
from vllm.model_executor.layers.fused_moe.router.grouped_topk_router import (
GroupedTopk,
)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts,
)
......
......@@ -117,8 +117,12 @@ class RoutingMethodType(IntEnum):
RenormalizeNaive = (4,)
# TopK: TopK (no softmax)
TopK = (5,)
# Custom
Custom = (6,)
# Simulated
Simulated = (7,)
# Unspecified
Unspecified = 6.0
Unspecified = 8.0
@dataclass
......
......@@ -13,9 +13,7 @@ import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
......@@ -34,9 +32,6 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_grouped_topk,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
......@@ -49,7 +44,6 @@ from vllm.model_executor.layers.fused_moe.utils import (
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme
from vllm.model_executor.utils import maybe_disable_graph_partition
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
......@@ -1318,375 +1312,6 @@ def try_get_optimal_moe_config(
return config
def vllm_topk_softmax(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
) -> tuple[torch.Tensor, ...]:
ops.topk_softmax(
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
renormalize,
)
return topk_weights, topk_indices
def dispatch_topk_func(
use_rocm_aiter: bool = False,
) -> Callable[..., tuple[torch.Tensor, ...]]:
if use_rocm_aiter:
return rocm_aiter_ops.topk_softmax
return vllm_topk_softmax
def fused_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
indices_type: torch.dtype | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"
M, _ = hidden_states.size()
topk_weights = torch.empty(
M, topk, dtype=torch.float32, device=hidden_states.device
)
topk_ids = torch.empty(
M,
topk,
dtype=torch.int32 if indices_type is None else indices_type,
device=hidden_states.device,
)
token_expert_indices = torch.empty(
M, topk, dtype=torch.int32, device=hidden_states.device
)
topk_func = dispatch_topk_func(use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled())
topk_weights, topk_ids = topk_func(
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
)
return topk_weights, topk_ids, token_expert_indices
def fused_topk_bias(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
e_score_correction_bias: torch.Tensor,
topk: int,
renormalize: bool,
):
n_routed_experts = gating_output.shape[-1]
scores = gating_output.softmax(dim=-1)
scores_for_choice = scores.view(
-1, n_routed_experts
) + e_score_correction_bias.unsqueeze(0)
# For batch invariance, use sorted=True to ensure deterministic expert selection
use_sorted = vllm_is_batch_invariant()
topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
topk_weights = scores.gather(1, topk_indices)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights.to(torch.float32), topk_indices.to(torch.int32)
# This is used by the Deepseek-V2 and Deepseek-V3 model
@torch.compile(
dynamic=True,
backend=current_platform.simple_compile_backend,
options=maybe_disable_graph_partition(current_platform.simple_compile_backend),
)
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if (
envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK
and current_platform.is_cuda()
and num_expert_group <= 32
and topk <= 32
and e_score_correction_bias is not None
):
return fused_grouped_topk(
hidden_states=hidden_states,
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
e_score_correction_bias=e_score_correction_bias,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
)
assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"
if scoring_func == "softmax":
scores = torch.softmax(gating_output, dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
num_token = scores.size(0)
if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights
original_scores = scores
scores = scores + e_score_correction_bias.unsqueeze(0)
group_scores = (
scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
)
else:
group_scores = (
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
) # [n, n_group]
# For batch invariance, use sorted=True to ensure deterministic expert selection
use_sorted = vllm_is_batch_invariant()
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
1
] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = (
group_mask.unsqueeze(-1)
.expand(num_token, num_expert_group, scores.size(-1) // num_expert_group)
.reshape(num_token, -1)
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e]
if e_score_correction_bias is not None:
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1]
# Use original unbiased scores for the routing weights
topk_weights = original_scores.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(
tmp_scores, k=topk, dim=-1, sorted=use_sorted
)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
if routed_scaling_factor != 1.0:
topk_weights = topk_weights * routed_scaling_factor
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
# --8<-- [start:grouped_topk]
@CustomOp.register("grouped_topk")
class GroupedTopk(CustomOp):
"""GroupedTopk used by the Deepseek-V2 and Deepseek-V3 model."""
# --8<-- [end:grouped_topk]
def __init__(
self,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
num_fused_shared_experts: int = 0,
) -> None:
super().__init__()
self.native_impl = grouped_topk
self.topk = topk
self.renormalize = renormalize
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.scoring_func = scoring_func
self.routed_scaling_factor = routed_scaling_factor
self.num_fused_shared_experts = num_fused_shared_experts
def forward_native(
self,
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.native_impl(
hidden_states,
gating_output,
self.topk,
self.renormalize,
self.num_expert_group,
self.topk_group,
self.scoring_func,
self.routed_scaling_factor,
e_score_correction_bias,
)
def forward_cuda(
self,
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.forward_native(
hidden_states, gating_output, e_score_correction_bias
)
def forward_hip(
self,
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if rocm_aiter_ops.is_fused_moe_enabled():
if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
assert self.num_fused_shared_experts == 0
return rocm_aiter_grouped_topk(
hidden_states,
gating_output,
self.topk,
self.renormalize,
self.num_expert_group,
self.topk_group,
self.scoring_func,
self.routed_scaling_factor,
e_score_correction_bias,
self.num_fused_shared_experts,
)
else:
return self.forward_native(
hidden_states, gating_output, e_score_correction_bias
)
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def eplb_map_to_physical_and_record(
topk_ids: torch.Tensor,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> torch.Tensor:
"""
Map the logical expert ids to physical expert ids
and record the expert load metrics.
This will select a pseudo-random replica for each logical expert.
Only used for EPLB.
Args:
topk_ids: The logical expert ids.
expert_load_view: The expert load view.
logical_to_physical_map: The logical to physical map.
logical_replica_count: The logical replica count.
Returns:
The physical expert ids.
"""
# 1. Convert the logical expert ids to physical expert ids
# Directly select a random replica for each logical expert
# In case `indices_type` is not `torch.long` or `torch.int`,
# e.g. `torch.uint32` as required by dispatch/combine kernels
topk_ids_long = topk_ids.long()
# Use (token position) modulo (replica count)
# to deterministically choose a replica
replica_count = logical_replica_count[topk_ids_long]
# Flatten-position based index, reshaped back to `topk_ids` shape
pos_indices = torch.arange(
topk_ids.numel(), device=topk_ids.device, dtype=torch.long
).reshape_as(topk_ids)
# Compute pseudo-random indices by modulo
replica_indices = (pos_indices % replica_count).unsqueeze(-1)
physical_ids = (
logical_to_physical_map[topk_ids_long].gather(-1, replica_indices).squeeze(-1)
)
topk_ids = physical_ids
# 2. Record expert load metrics.
# TODO(bowen): When using `FusedMoEModularKernel`, this
# can be done in a more unified way, since
# `FusedMoEPrepareAndFinalize` will return the expert
# token count, in some cases directly from the kernel.
# However, now there are many code paths not using
# the modular kernel, e.g. calling `fused_experts`,
# so we decide to keep the logic here.
#
# If later refactor moved all the MoE kernel calls
# to the modular kernel, we can move this logic there
# to achieve better efficiency.
# `expert_load_view`: (num_physical_experts,)
# `torch.bincount` is not compilable, so use `scatter_add_` instead.
topk_ids_flatten = topk_ids.flatten()
expert_load_view.scatter_add_(
dim=0,
index=topk_ids_flatten.long(),
src=torch.ones_like(topk_ids_flatten).to(expert_load_view),
)
return topk_ids
def fused_grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
e_score_correction_bias: torch.Tensor,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"
if scoring_func == "sigmoid":
# Fully fused kernel path for sigmoid
topk_values, topk_indices = ops.grouped_topk(
gating_output, # raw logits
num_expert_group,
topk_group,
topk,
renormalize,
routed_scaling_factor,
e_score_correction_bias,
1, # scoring_func=1 for sigmoid
)
elif scoring_func == "softmax":
# Apply softmax in Python, then use fused kernel
# TODO: Add support for softmax in kernel
scores = torch.softmax(gating_output, dim=-1)
topk_values, topk_indices = ops.grouped_topk(
scores, # pre-computed scores
num_expert_group,
topk_group,
topk,
renormalize,
routed_scaling_factor,
e_score_correction_bias,
0, # scoring_func=0 (no activation, scores already computed)
)
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
# Fused kernel outputs float32 values and int32 indices directly
return topk_values, topk_indices
def inplace_fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
......
......@@ -10,13 +10,13 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase,
)
......
......@@ -12,11 +12,13 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel,
FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
logger = init_logger(__name__)
......
......@@ -21,7 +21,7 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.eplb.eplb_state import EplbLayerState, EplbState
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
......@@ -31,14 +31,24 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
FusedMoEModularMethod,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
init_aiter_topK_meta_data,
)
from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
RoutedExpertsCapturer,
)
from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator
from vllm.model_executor.layers.fused_moe.router.router_factory import (
create_fused_moe_router,
)
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
)
......@@ -52,31 +62,6 @@ from vllm.utils.torch_utils import (
)
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
if current_platform.is_cuda_alike():
from .fused_moe import eplb_map_to_physical_and_record
else:
def _eplb_map_to_physical_and_record(
topk_ids: torch.Tensor,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> torch.Tensor:
# CPU fallback: no EPLB so just return as is
return topk_ids
eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record
from vllm.model_executor.layers.fused_moe.fused_moe import GroupedTopk
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
FusedMoEModularMethod,
)
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
)
logger = init_logger(__name__)
......@@ -288,23 +273,6 @@ def maybe_roundup_hidden_size(
return hidden_size
class FusedMoERouterImpl(FusedMoERouter):
def __init__(self, layer: "FusedMoE"):
super().__init__()
self.layer = layer
@property
def routing_method_type(self) -> RoutingMethodType:
return self.layer.routing_method_type
def select_experts(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.layer._select_experts(hidden_states, router_logits)
# --8<-- [start:fused_moe]
@CustomOp.register("fused_moe")
class FusedMoE(CustomOp):
......@@ -440,9 +408,7 @@ class FusedMoE(CustomOp):
self.layer_name = prefix
self.enable_eplb = enable_eplb
self.expert_load_view: torch.Tensor | None = None
self.logical_to_physical_map: torch.Tensor | None = None
self.logical_replica_count: torch.Tensor | None = None
self.eplb_state = EplbLayerState()
self.expert_placement_strategy: ExpertPlacementStrategy = (
vllm_config.parallel_config.expert_placement_strategy
)
......@@ -538,6 +504,8 @@ class FusedMoE(CustomOp):
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
self.renormalize = renormalize
# TODO(bnell): these attributes are only used by cpu/xpu/mxfp4
self.use_grouped_topk = use_grouped_topk
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
......@@ -547,46 +515,11 @@ class FusedMoE(CustomOp):
self.scoring_func = scoring_func
self.routed_scaling_factor = routed_scaling_factor
self.e_score_correction_bias = e_score_correction_bias
# TODO(bnell): end attributes
self.apply_router_weight_on_input = apply_router_weight_on_input
self.activation = activation
self._grouped_topk_impl: GroupedTopk | None = None
if self.use_grouped_topk:
assert self.num_expert_group is not None
assert self.topk_group is not None
self._grouped_topk_impl = GroupedTopk(
topk=self.top_k,
renormalize=self.renormalize,
num_expert_group=self.num_expert_group,
topk_group=self.topk_group,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
num_fused_shared_experts=self.num_fused_shared_experts,
)
if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError(
"Only softmax scoring function is supported for non-grouped topk."
)
# ToDo: Better logic to determine the routing method type
if routing_method_type is not None:
self.routing_method_type: RoutingMethodType = routing_method_type
else:
if scoring_func == "sigmoid":
if self.use_grouped_topk:
self.routing_method_type = RoutingMethodType.DeepSeekV3
elif self.top_k == 1:
self.routing_method_type = RoutingMethodType.Llama4
elif self.scoring_func == "softmax":
self.routing_method_type = (
RoutingMethodType.Renormalize
if not self.renormalize
else RoutingMethodType.RenormalizeNaive
)
else:
self.routing_method_type = RoutingMethodType.TopK
self.moe_config: FusedMoEConfig = FusedMoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
......@@ -637,8 +570,7 @@ class FusedMoE(CustomOp):
# If you plan to add support for more quantization methods,
# please refer to the implementation in `Fp8MoEMethod`.
raise NotImplementedError(
f"EPLB is not supported {self.quant_method.__class__.__name__}. "
"EPLB is only supported for FP8 quantization for now."
f"EPLB is not supported {self.quant_method.__class__.__name__}."
)
moe_quant_params = {
......@@ -663,7 +595,38 @@ class FusedMoE(CustomOp):
self.batched_hidden_states: torch.Tensor | None = None
self.batched_router_logits: torch.Tensor | None = None
self.router = FusedMoERouterImpl(self)
# TODO(bnell): in next PR move capture back to layer
capture: Callable[[torch.Tensor], None] | None = None
if (
self.vllm_config.model_config is not None
and self.vllm_config.model_config.enable_return_routed_experts
):
# In dummy runs, the capturer is not initialized.
capturer = RoutedExpertsCapturer.get_instance()
if capturer is not None:
capture = lambda topk_ids: capturer.capture(self.layer_id, topk_ids)
self.router = create_fused_moe_router(
top_k=top_k,
global_num_experts=self.global_num_experts,
eplb_state=self.eplb_state,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
num_fused_shared_experts=self.num_fused_shared_experts,
enable_eplb=enable_eplb,
# TODO(bnell): once we can construct the MK at init time, we
# can make this a value.
indices_type_getter=lambda: self.quant_method.topk_indices_dtype,
routing_method_type=routing_method_type,
capture=capture,
)
self.routing_method_type: RoutingMethodType = self.router.routing_method_type
# Note: maybe_init_modular_kernel should only be called by
# prepare_communication_buffer_for_model.
......@@ -1492,9 +1455,9 @@ class FusedMoE(CustomOp):
This is used later in forward pass, where we get the expert mapping
and record the load metrics in `expert_load_view`.
"""
self.expert_load_view = expert_load_view[moe_layer_idx]
self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
self.logical_replica_count = logical_replica_count[moe_layer_idx]
self.eplb_state.expert_load_view = expert_load_view[moe_layer_idx]
self.eplb_state.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
self.eplb_state.logical_replica_count = logical_replica_count[moe_layer_idx]
def ensure_moe_quant_config_init(self):
if self.quant_method.moe_quant_config is None:
......@@ -1535,130 +1498,6 @@ class FusedMoE(CustomOp):
device=torch.cuda.current_device(),
)
def _select_experts(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Route the input hidden states to the top-k experts based on the
router logits.
Returns:
(topk_weights, topk_ids)
(tuple[torch.Tensor, torch.Tensor]):
The weights and expert ids.
**Compatibility**: When EPLB is not enabled, the returned ids are
equivalent to global logical ids, so should be compatible with
plain MoE implementations without redundant experts.
"""
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk,
fused_topk_bias,
)
if self.enable_eplb:
if self.quant_method.supports_eplb:
if self.expert_load_view is None:
raise ValueError(
"enable_eplb=True requiere expert_load_view != None"
)
if self.logical_to_physical_map is None:
raise ValueError(
"enable_eplb=True requiere logical_to_physical_map != None"
)
if self.logical_replica_count is None:
raise ValueError(
"enable_eplb=True requiere logical_replica_count != None"
)
else:
raise NotImplementedError(
f"EPLB is not supported for {self.quant_method.method_name}."
)
def valid_grouping() -> bool:
# Check if num_experts is greater than num_expert_group
# and is divisible by num_expert_group
num_experts = router_logits.shape[-1]
if num_experts <= self.num_expert_group:
return False
return num_experts % self.num_expert_group == 0
indices_type = self.quant_method.topk_indices_dtype
# Check if we should use a routing simulation strategy
routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
if routing_strategy != "":
topk_weights, topk_ids = RoutingSimulator.simulate_routing(
hidden_states=hidden_states,
router_logits=router_logits,
strategy_name=routing_strategy,
top_k=self.top_k,
indices_type=indices_type,
)
# DeepSeekv2 uses grouped_top_k
elif self.use_grouped_topk and valid_grouping():
assert self._grouped_topk_impl is not None
topk_weights, topk_ids = self._grouped_topk_impl(
hidden_states=hidden_states,
gating_output=router_logits,
e_score_correction_bias=self.e_score_correction_bias,
)
elif self.e_score_correction_bias is not None:
topk_weights, topk_ids = fused_topk_bias(
hidden_states=hidden_states,
gating_output=router_logits,
e_score_correction_bias=self.e_score_correction_bias.data,
topk=self.top_k,
renormalize=self.renormalize,
)
if self.routed_scaling_factor != 1.0:
topk_weights *= self.routed_scaling_factor
elif self.custom_routing_function is None:
topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=self.top_k,
renormalize=self.renormalize,
indices_type=indices_type,
)
else:
topk_weights, topk_ids = self.custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
topk=self.top_k,
renormalize=self.renormalize,
)
if self.enable_eplb:
topk_ids = eplb_map_to_physical_and_record(
topk_ids=topk_ids,
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
)
if (indices_type is not None) and topk_ids.dtype != indices_type:
topk_ids = topk_ids.to(dtype=indices_type)
assert topk_ids.dtype == indices_type or indices_type is None
if (
self.vllm_config.model_config is not None
and self.vllm_config.model_config.enable_return_routed_experts
):
# In dummy runs, the capturer is not initialized.
capturer = RoutedExpertsCapturer.get_instance()
if capturer is not None: # in dummmy_run may be None
capturer.capture( # noqa
layer_id=self.layer_id,
topk_ids=topk_ids,
)
return topk_weights, topk_ids
def must_reduce_shared_expert_outputs(self) -> bool:
"""
The shared_experts are typically computed using the RowParallelLinear
......@@ -1761,8 +1600,12 @@ class FusedMoE(CustomOp):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None
assert self.batched_hidden_states.dtype == full_hidden_states.dtype
assert self.batched_router_logits.dtype == full_router_logits.dtype
assert self.batched_hidden_states.dtype == full_hidden_states.dtype, (
f"{self.batched_hidden_states.dtype} == {full_hidden_states.dtype}"
)
assert self.batched_router_logits.dtype == full_router_logits.dtype, (
f"{self.batched_router_logits.dtype} == {full_router_logits.dtype}"
)
# Check size compatibility.
assert self.batched_hidden_states.size(-1) == full_hidden_states.size(-1)
assert self.batched_router_logits.size(-1) == full_router_logits.size(-1)
......@@ -2080,15 +1923,8 @@ class FusedMoE(CustomOp):
f"tp_size={self.tp_size},\n"
f"ep_size={self.ep_size}, "
f"reduce_results={self.reduce_results}, "
f"renormalize={self.renormalize}, "
f"use_grouped_topk={self.use_grouped_topk}"
)
if self.use_grouped_topk:
s += f", num_expert_group={self.num_expert_group}, topk_group={self.topk_group}" # noqa: E501
s += f", scoring_func='{self.scoring_func}', activation='{self.activation}'" # noqa: E501
return s
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment