Commit 0386844b authored by zhuwenwen's avatar zhuwenwen
Browse files

add moe_fused_gate

parent b31c7251
...@@ -1957,6 +1957,7 @@ def fused_experts( ...@@ -1957,6 +1957,7 @@ def fused_experts(
use_int8_w8a8=quant_config.use_int8_w8a8, use_int8_w8a8=quant_config.use_int8_w8a8,
use_int8_w8a16=quant_config.use_int8_w8a16, use_int8_w8a16=quant_config.use_int8_w8a16,
use_int4_w4a16=quant_config.use_int4_w4a16, use_int4_w4a16=quant_config.use_int4_w4a16,
use_int4_w4a8=quant_config.use_int4_w4a8,
ocp_mx_scheme=quant_config.ocp_mx_scheme, ocp_mx_scheme=quant_config.ocp_mx_scheme,
per_channel_quant=quant_config.per_act_token_quant, per_channel_quant=quant_config.per_act_token_quant,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
......
...@@ -539,6 +539,13 @@ class FusedMoE(CustomOp): ...@@ -539,6 +539,13 @@ class FusedMoE(CustomOp):
self.layer_id, topk_ids self.layer_id, topk_ids
) )
# moe_fused_gate kernel ensure that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion.
self.use_fused_gate = envs.VLLM_ENABLE_MOE_FUSED_GATE \
and self.e_score_correction_bias is not None \
and num_expert_group is not None \
and self.global_num_experts // num_expert_group <= 32 \
and is_power_of_two(e_score_correction_bias.shape[0])
self.router = create_fused_moe_router( self.router = create_fused_moe_router(
top_k=top_k, top_k=top_k,
global_num_experts=self.global_num_experts, global_num_experts=self.global_num_experts,
...@@ -556,6 +563,7 @@ class FusedMoE(CustomOp): ...@@ -556,6 +563,7 @@ class FusedMoE(CustomOp):
# TODO(bnell): once we can construct the MK at init time, we # TODO(bnell): once we can construct the MK at init time, we
# can make this a value. # can make this a value.
indices_type_getter=lambda: self.quant_method.topk_indices_dtype, indices_type_getter=lambda: self.quant_method.topk_indices_dtype,
use_fused_gate=self.use_fused_gate,
) )
self.routing_method_type: RoutingMethodType = self.router.routing_method_type self.routing_method_type: RoutingMethodType = self.router.routing_method_type
...@@ -652,13 +660,6 @@ class FusedMoE(CustomOp): ...@@ -652,13 +660,6 @@ class FusedMoE(CustomOp):
self.quant_method.create_weights(layer=self, **moe_quant_params) self.quant_method.create_weights(layer=self, **moe_quant_params)
# moe_fused_gate kernel ensure that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion.
self.use_fused_gate = envs.VLLM_ENABLE_MOE_FUSED_GATE \
and self.e_score_correction_bias is not None \
and num_expert_group is not None \
and self.global_num_experts // num_expert_group <= 32 \
and is_power_of_two(e_score_correction_bias.shape[0])
# Chunked all2all staging tensor # Chunked all2all staging tensor
self.batched_hidden_states: torch.Tensor | None = None self.batched_hidden_states: torch.Tensor | None = None
self.batched_router_logits: torch.Tensor | None = None self.batched_router_logits: torch.Tensor | None = None
...@@ -1942,33 +1943,11 @@ class FusedMoE(CustomOp): ...@@ -1942,33 +1943,11 @@ class FusedMoE(CustomOp):
topk_weights, topk_ids = self.router.select_experts( topk_weights, topk_ids = self.router.select_experts(
hidden_states=x_orig, hidden_states=x_orig,
router_logits=router_logits, router_logits=router_logits,
# use_fused_gate=use_fused_gate,
) )
if self.capture is not None: if self.capture is not None:
self.capture(topk_ids) self.capture(topk_ids)
# if use_fused_gate:
# # if envs.VLLM_USE_LIGHTOP:
# if False:
# topk_weights, topk_ids = op.moe_fused_gate(
# router_logits,
# self.e_score_correction_bias,
# self.num_expert_group,
# self.topk_group,
# self.top_k,
# 0,
# self.routed_scaling_factor,
# )
# else:
# topk_weights, topk_ids = ops.moe_fused_gate(
# router_logits,
# e_score_correction_bias=self.e_score_correction_bias,
# num_expert_group=self.num_expert_group,
# topk_group=self.topk_group,
# topk=self.top_k,
# routed_scaling_factor=self.routed_scaling_factor,
# n_share_experts_fusion=0,
# )
final_hidden_states = self.quant_method.apply( final_hidden_states = self.quant_method.apply(
layer=self, layer=self,
......
...@@ -178,6 +178,7 @@ class BaseRouter(FusedMoERouter): ...@@ -178,6 +178,7 @@ class BaseRouter(FusedMoERouter):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
indices_type: torch.dtype | None, indices_type: torch.dtype | None,
use_fused_gate: bool | None = False,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Compute the actual routing logic. Compute the actual routing logic.
...@@ -228,7 +229,7 @@ class BaseRouter(FusedMoERouter): ...@@ -228,7 +229,7 @@ class BaseRouter(FusedMoERouter):
# Step 3: Compute routing (delegated to subclass) # Step 3: Compute routing (delegated to subclass)
topk_weights, topk_ids = self._compute_routing( topk_weights, topk_ids = self._compute_routing(
hidden_states, router_logits, indices_type hidden_states, router_logits, indices_type,
) )
# Step 4: Apply EPLB mapping # Step 4: Apply EPLB mapping
......
...@@ -261,6 +261,7 @@ class GroupedTopKRouter(BaseRouter): ...@@ -261,6 +261,7 @@ class GroupedTopKRouter(BaseRouter):
num_fused_shared_experts: int = 0, num_fused_shared_experts: int = 0,
enable_eplb: bool = False, enable_eplb: bool = False,
indices_type_getter: Callable[[], torch.dtype | None] | None = None, indices_type_getter: Callable[[], torch.dtype | None] | None = None,
use_fused_gate: bool | None = False,
): ):
super().__init__( super().__init__(
top_k=top_k, top_k=top_k,
...@@ -276,6 +277,7 @@ class GroupedTopKRouter(BaseRouter): ...@@ -276,6 +277,7 @@ class GroupedTopKRouter(BaseRouter):
self.routed_scaling_factor = routed_scaling_factor self.routed_scaling_factor = routed_scaling_factor
self.e_score_correction_bias = e_score_correction_bias self.e_score_correction_bias = e_score_correction_bias
self.num_fused_shared_experts = num_fused_shared_experts self.num_fused_shared_experts = num_fused_shared_experts
self.use_fused_gate = use_fused_gate
if scoring_func == "sigmoid": if scoring_func == "sigmoid":
self._routing_method_type = RoutingMethodType.DeepSeekV3 self._routing_method_type = RoutingMethodType.DeepSeekV3
...@@ -336,16 +338,39 @@ class GroupedTopKRouter(BaseRouter): ...@@ -336,16 +338,39 @@ class GroupedTopKRouter(BaseRouter):
else: else:
grouped_topk_impl = grouped_topk grouped_topk_impl = grouped_topk
topk_weights, topk_ids = grouped_topk_impl( if self.use_fused_gate:
hidden_states=hidden_states, # if envs.VLLM_USE_LIGHTOP:
gating_output=router_logits, if False:
topk=self.top_k, topk_weights, topk_ids = op.moe_fused_gate(
renormalize=self.renormalize, router_logits,
num_expert_group=self.num_expert_group, self.e_score_correction_bias,
topk_group=self.topk_group, self.num_expert_group,
scoring_func=self.scoring_func, self.topk_group,
routed_scaling_factor=self.routed_scaling_factor, self.top_k,
e_score_correction_bias=self.e_score_correction_bias, 0,
) self.routed_scaling_factor,
)
else:
topk_weights, topk_ids = ops.moe_fused_gate(
router_logits,
self.e_score_correction_bias,
self.num_expert_group,
self.topk_group,
self.top_k,
routed_scaling_factor=self.routed_scaling_factor,
n_share_experts_fusion=0,
)
else:
topk_weights, topk_ids = grouped_topk_impl(
hidden_states=hidden_states,
gating_output=router_logits,
topk=self.top_k,
renormalize=self.renormalize,
num_expert_group=self.num_expert_group,
topk_group=self.topk_group,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=self.e_score_correction_bias,
)
return topk_weights, topk_ids return topk_weights, topk_ids
...@@ -48,6 +48,7 @@ def create_fused_moe_router( ...@@ -48,6 +48,7 @@ def create_fused_moe_router(
# eplb parameters # eplb parameters
enable_eplb: bool = False, enable_eplb: bool = False,
eplb_state: EplbLayerState = EMPTY_EPLB_STATE, eplb_state: EplbLayerState = EMPTY_EPLB_STATE,
use_fused_gate: bool | None = False,
) -> FusedMoERouter: ) -> FusedMoERouter:
""" """
Factory function to create the appropriate FusedMoERouter subclass based on Factory function to create the appropriate FusedMoERouter subclass based on
...@@ -119,6 +120,7 @@ def create_fused_moe_router( ...@@ -119,6 +120,7 @@ def create_fused_moe_router(
num_fused_shared_experts=num_fused_shared_experts, num_fused_shared_experts=num_fused_shared_experts,
enable_eplb=enable_eplb, enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter, indices_type_getter=indices_type_getter,
use_fused_gate=use_fused_gate,
) )
if custom_routing_function is not None: if custom_routing_function is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment