Unverified Commit dc917cce authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[MoE Refactor] Move `select_experts` from `FusedMoEQuantMethod` -> `FusedMoE` (#31996)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
parent fc56f4a0
...@@ -14,7 +14,6 @@ from vllm.model_executor.layers.fused_moe import ( ...@@ -14,7 +14,6 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoE,
FusedMoEConfig, FusedMoEConfig,
FusedMoEMethodBase, FusedMoEMethodBase,
FusedMoERouter,
) )
from vllm.model_executor.layers.fused_moe import modular_kernel as mk from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
...@@ -890,22 +889,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -890,22 +889,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def allow_inplace(self) -> bool: def allow_inplace(self) -> bool:
return True return True
@property
def is_monolithic(self) -> bool:
return (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
or self.mxfp4_backend == Mxfp4Backend.TRITON
)
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
if layer.enable_eplb: if layer.enable_eplb:
raise NotImplementedError("EPLB is not supported for mxfp4") raise NotImplementedError("EPLB is not supported for mxfp4")
if self.mxfp4_backend == Mxfp4Backend.MARLIN: if self.mxfp4_backend == Mxfp4Backend.MARLIN:
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
return fused_marlin_moe( return fused_marlin_moe(
x, x,
layer.w13_weight, layer.w13_weight,
...@@ -914,7 +917,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -914,7 +917,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer.w2_bias, layer.w2_bias,
layer.w13_weight_scale, layer.w13_weight_scale,
layer.w2_weight_scale, layer.w2_weight_scale,
router_logits,
topk_weights, topk_weights,
topk_ids, topk_ids,
global_scale1=None, global_scale1=None,
...@@ -942,6 +944,98 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -942,6 +944,98 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer.eplb_state.logical_replica_count, layer.eplb_state.logical_replica_count,
), "MXFP4 are not supported with this configuration." ), "MXFP4 are not supported with this configuration."
assert (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
)
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
# Backend-specific preparation
if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
from flashinfer import mxfp8_quantize
x_quant, x_scale = mxfp8_quantize(x, True, 32)
fake_input_scale = torch.ones(self.num_experts, device=x.device)
quant_scales = [
layer.w13_weight_scale.contiguous().view(torch.int32),
fake_input_scale,
layer.w2_weight_scale.contiguous().view(torch.int32),
fake_input_scale,
]
fi_input = x_quant
extra_kwargs = dict(
use_mxfp8_act_scaling=True,
input_sf=x_scale,
fc1_expert_weights=layer.w13_weight.contiguous().view(torch.long),
fc2_expert_weights=layer.w2_weight.contiguous().view(torch.long),
)
elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:
assert x.dtype == torch.bfloat16
quant_scales = [
layer.w13_weight_scale,
layer.w2_weight_scale,
]
fi_input = x
extra_kwargs = dict(
use_w4_group_scaling=True,
fc1_expert_weights=layer.w13_weight,
fc2_expert_weights=layer.w2_weight,
)
output = torch.empty_like(x, dtype=torch.bfloat16)
flashinfer_cutlass_fused_moe(
input=fi_input,
token_selected_experts=topk_ids.to(torch.int).contiguous(),
token_final_scales=topk_weights,
output_dtype=torch.bfloat16,
output=output,
quant_scales=quant_scales,
fc1_expert_biases=layer.w13_bias,
fc2_expert_biases=layer.w2_bias,
swiglu_alpha=layer.gemm1_alpha,
swiglu_beta=layer.gemm1_beta,
swiglu_limit=layer.gemm1_clamp_limit,
tp_size=self.moe.tp_size,
tp_rank=self.moe.tp_rank,
ep_size=self.moe.ep_size,
ep_rank=self.moe.ep_rank,
tune_max_num_tokens=max(self.max_capture_size, 1),
**extra_kwargs,
)
return output
def apply_monolithic(
self,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic
if layer.enable_eplb:
raise NotImplementedError("EPLB is not supported for mxfp4")
assert _can_support_mxfp4(
layer.use_grouped_topk,
layer.topk_group,
layer.num_expert_group,
layer.expert_map,
layer.custom_routing_function,
layer.e_score_correction_bias,
layer.apply_router_weight_on_input,
layer.scoring_func,
layer.activation,
layer.eplb_state.expert_load_view,
layer.eplb_state.logical_to_physical_map,
layer.eplb_state.logical_replica_count,
), "MXFP4 are not supported with this configuration."
if ( if (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
...@@ -988,75 +1082,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -988,75 +1082,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
tune_max_num_tokens=max(self.max_capture_size, 1), tune_max_num_tokens=max(self.max_capture_size, 1),
)[0] )[0]
return trtllm_gen_output return trtllm_gen_output
elif (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
):
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
# Backend-specific preparation
if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
from flashinfer import mxfp8_quantize
x_quant, x_scale = mxfp8_quantize(x, True, 32)
fake_input_scale = torch.ones(self.num_experts, device=x.device)
quant_scales = [
layer.w13_weight_scale.contiguous().view(torch.int32),
fake_input_scale,
layer.w2_weight_scale.contiguous().view(torch.int32),
fake_input_scale,
]
fi_input = x_quant
extra_kwargs = dict(
use_mxfp8_act_scaling=True,
input_sf=x_scale,
fc1_expert_weights=layer.w13_weight.contiguous().view(torch.long),
fc2_expert_weights=layer.w2_weight.contiguous().view(torch.long),
)
elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:
assert x.dtype == torch.bfloat16
quant_scales = [
layer.w13_weight_scale,
layer.w2_weight_scale,
]
fi_input = x
extra_kwargs = dict(
use_w4_group_scaling=True,
fc1_expert_weights=layer.w13_weight,
fc2_expert_weights=layer.w2_weight,
)
output = torch.empty_like(x, dtype=torch.bfloat16)
_ = flashinfer_cutlass_fused_moe(
input=fi_input,
token_selected_experts=topk_ids.to(torch.int).contiguous(),
token_final_scales=topk_weights,
output_dtype=torch.bfloat16,
output=output,
quant_scales=quant_scales,
fc1_expert_biases=layer.w13_bias,
fc2_expert_biases=layer.w2_bias,
swiglu_alpha=layer.gemm1_alpha,
swiglu_beta=layer.gemm1_beta,
swiglu_limit=layer.gemm1_clamp_limit,
tp_size=self.moe.tp_size,
tp_rank=self.moe.tp_rank,
ep_size=self.moe.ep_size,
ep_rank=self.moe.ep_rank,
tune_max_num_tokens=max(self.max_capture_size, 1),
**extra_kwargs,
)
return output
elif self.mxfp4_backend == Mxfp4Backend.TRITON: elif self.mxfp4_backend == Mxfp4Backend.TRITON:
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501
triton_kernel_moe_forward, triton_kernel_moe_forward,
...@@ -1119,10 +1144,13 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod): ...@@ -1119,10 +1144,13 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
experts_start_id=ep_rank_start, experts_start_id=ep_rank_start,
) )
def apply( @property
def is_monolithic(self) -> bool:
return True
def apply_monolithic(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
......
...@@ -13,7 +13,6 @@ from vllm.model_executor.layers.fused_moe import ( ...@@ -13,7 +13,6 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoE,
FusedMoEConfig, FusedMoEConfig,
FusedMoEMethodBase, FusedMoEMethodBase,
FusedMoERouter,
FusedMoeWeightScaleSupported, FusedMoeWeightScaleSupported,
) )
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
...@@ -351,15 +350,10 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -351,15 +350,10 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
if self.rocm_aiter_moe_enabled: if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts, rocm_aiter_fused_experts,
...@@ -388,7 +382,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -388,7 +382,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
None, None,
layer.w13_weight_scale, layer.w13_weight_scale,
layer.w2_weight_scale, layer.w2_weight_scale,
router_logits,
topk_weights, topk_weights,
topk_ids, topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id, quant_type_id=scalar_types.float8_e4m3fn.id,
...@@ -544,15 +537,10 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -544,15 +537,10 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts, rocm_aiter_fused_experts,
) )
...@@ -753,15 +741,10 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -753,15 +741,10 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
if not self.emulate: if not self.emulate:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts, rocm_aiter_fused_experts,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment