Commit 0386844b authored by zhuwenwen's avatar zhuwenwen
Browse files

add moe_fused_gate

parent b31c7251
......@@ -1957,6 +1957,7 @@ def fused_experts(
use_int8_w8a8=quant_config.use_int8_w8a8,
use_int8_w8a16=quant_config.use_int8_w8a16,
use_int4_w4a16=quant_config.use_int4_w4a16,
use_int4_w4a8=quant_config.use_int4_w4a8,
ocp_mx_scheme=quant_config.ocp_mx_scheme,
per_channel_quant=quant_config.per_act_token_quant,
global_num_experts=global_num_experts,
......
......@@ -539,6 +539,13 @@ class FusedMoE(CustomOp):
self.layer_id, topk_ids
)
# The moe_fused_gate kernel currently requires that num_experts / num_expert_group not exceed MAX_VPT = 32. Once the kernel can handle MAX_VPT > 32, this check can be removed.
self.use_fused_gate = envs.VLLM_ENABLE_MOE_FUSED_GATE \
and self.e_score_correction_bias is not None \
and num_expert_group is not None \
and self.global_num_experts // num_expert_group <= 32 \
and is_power_of_two(e_score_correction_bias.shape[0])
self.router = create_fused_moe_router(
top_k=top_k,
global_num_experts=self.global_num_experts,
......@@ -556,6 +563,7 @@ class FusedMoE(CustomOp):
# TODO(bnell): once we can construct the MK at init time, we
# can make this a value.
indices_type_getter=lambda: self.quant_method.topk_indices_dtype,
use_fused_gate=self.use_fused_gate,
)
self.routing_method_type: RoutingMethodType = self.router.routing_method_type
......@@ -652,13 +660,6 @@ class FusedMoE(CustomOp):
self.quant_method.create_weights(layer=self, **moe_quant_params)
# The moe_fused_gate kernel currently requires that num_experts / num_expert_group not exceed MAX_VPT = 32. Once the kernel can handle MAX_VPT > 32, this check can be removed.
self.use_fused_gate = envs.VLLM_ENABLE_MOE_FUSED_GATE \
and self.e_score_correction_bias is not None \
and num_expert_group is not None \
and self.global_num_experts // num_expert_group <= 32 \
and is_power_of_two(e_score_correction_bias.shape[0])
# Chunked all2all staging tensor
self.batched_hidden_states: torch.Tensor | None = None
self.batched_router_logits: torch.Tensor | None = None
......@@ -1942,34 +1943,12 @@ class FusedMoE(CustomOp):
topk_weights, topk_ids = self.router.select_experts(
hidden_states=x_orig,
router_logits=router_logits,
# use_fused_gate=use_fused_gate,
)
if self.capture is not None:
self.capture(topk_ids)
# if use_fused_gate:
# # if envs.VLLM_USE_LIGHTOP:
# if False:
# topk_weights, topk_ids = op.moe_fused_gate(
# router_logits,
# self.e_score_correction_bias,
# self.num_expert_group,
# self.topk_group,
# self.top_k,
# 0,
# self.routed_scaling_factor,
# )
# else:
# topk_weights, topk_ids = ops.moe_fused_gate(
# router_logits,
# e_score_correction_bias=self.e_score_correction_bias,
# num_expert_group=self.num_expert_group,
# topk_group=self.topk_group,
# topk=self.top_k,
# routed_scaling_factor=self.routed_scaling_factor,
# n_share_experts_fusion=0,
# )
final_hidden_states = self.quant_method.apply(
layer=self,
x=x, # The type signature of this is wrong due to the hack.
......
......@@ -178,6 +178,7 @@ class BaseRouter(FusedMoERouter):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
indices_type: torch.dtype | None,
use_fused_gate: bool | None = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compute the actual routing logic.
......@@ -228,7 +229,7 @@ class BaseRouter(FusedMoERouter):
# Step 3: Compute routing (delegated to subclass)
topk_weights, topk_ids = self._compute_routing(
hidden_states, router_logits, indices_type
hidden_states, router_logits, indices_type,
)
# Step 4: Apply EPLB mapping
......
......@@ -261,6 +261,7 @@ class GroupedTopKRouter(BaseRouter):
num_fused_shared_experts: int = 0,
enable_eplb: bool = False,
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
use_fused_gate: bool | None = False,
):
super().__init__(
top_k=top_k,
......@@ -276,6 +277,7 @@ class GroupedTopKRouter(BaseRouter):
self.routed_scaling_factor = routed_scaling_factor
self.e_score_correction_bias = e_score_correction_bias
self.num_fused_shared_experts = num_fused_shared_experts
self.use_fused_gate = use_fused_gate
if scoring_func == "sigmoid":
self._routing_method_type = RoutingMethodType.DeepSeekV3
......@@ -336,6 +338,29 @@ class GroupedTopKRouter(BaseRouter):
else:
grouped_topk_impl = grouped_topk
if self.use_fused_gate:
# if envs.VLLM_USE_LIGHTOP:
if False:
topk_weights, topk_ids = op.moe_fused_gate(
router_logits,
self.e_score_correction_bias,
self.num_expert_group,
self.topk_group,
self.top_k,
0,
self.routed_scaling_factor,
)
else:
topk_weights, topk_ids = ops.moe_fused_gate(
router_logits,
self.e_score_correction_bias,
self.num_expert_group,
self.topk_group,
self.top_k,
routed_scaling_factor=self.routed_scaling_factor,
n_share_experts_fusion=0,
)
else:
topk_weights, topk_ids = grouped_topk_impl(
hidden_states=hidden_states,
gating_output=router_logits,
......
......@@ -48,6 +48,7 @@ def create_fused_moe_router(
# eplb parameters
enable_eplb: bool = False,
eplb_state: EplbLayerState = EMPTY_EPLB_STATE,
use_fused_gate: bool | None = False,
) -> FusedMoERouter:
"""
Factory function to create the appropriate FusedMoERouter subclass based on
......@@ -119,6 +120,7 @@ def create_fused_moe_router(
num_fused_shared_experts=num_fused_shared_experts,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
use_fused_gate=use_fused_gate,
)
if custom_routing_function is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment