Unverified commit 5031cd5d, authored by Wentao Ye, committed by GitHub
Browse files

[Refactor] Optimize `select_experts` (#28069)


Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 3aaa94ac
...@@ -1246,7 +1246,6 @@ def eplb_map_to_physical_and_record( ...@@ -1246,7 +1246,6 @@ def eplb_map_to_physical_and_record(
expert_load_view: torch.Tensor, expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor, logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor, logical_replica_count: torch.Tensor,
indices_type: torch.dtype | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Map the logical expert ids to physical expert ids Map the logical expert ids to physical expert ids
...@@ -1260,7 +1259,6 @@ def eplb_map_to_physical_and_record( ...@@ -1260,7 +1259,6 @@ def eplb_map_to_physical_and_record(
expert_load_view: The expert load view. expert_load_view: The expert load view.
logical_to_physical_map: The logical to physical map. logical_to_physical_map: The logical to physical map.
logical_replica_count: The logical replica count. logical_replica_count: The logical replica count.
indices_type: The indices type.
Returns: Returns:
The physical expert ids. The physical expert ids.
...@@ -1310,9 +1308,6 @@ def eplb_map_to_physical_and_record( ...@@ -1310,9 +1308,6 @@ def eplb_map_to_physical_and_record(
index=topk_ids_flatten.long(), index=topk_ids_flatten.long(),
src=torch.ones_like(topk_ids_flatten).to(expert_load_view), src=torch.ones_like(topk_ids_flatten).to(expert_load_view),
) )
if indices_type is not None:
topk_ids = topk_ids.to(dtype=indices_type)
return topk_ids return topk_ids
......
...@@ -68,7 +68,6 @@ else: ...@@ -68,7 +68,6 @@ else:
expert_load_view: torch.Tensor, expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor, logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor, logical_replica_count: torch.Tensor,
indices_type: torch.dtype | None,
) -> torch.Tensor: ) -> torch.Tensor:
# CPU fallback: no EPLB so just return as is # CPU fallback: no EPLB so just return as is
return topk_ids return topk_ids
...@@ -1509,8 +1508,6 @@ class FusedMoE(CustomOp): ...@@ -1509,8 +1508,6 @@ class FusedMoE(CustomOp):
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
) )
if indices_type is not None:
topk_ids = topk_ids.to(dtype=indices_type)
elif e_score_correction_bias is not None: elif e_score_correction_bias is not None:
topk_weights, topk_ids = fused_topk_bias( topk_weights, topk_ids = fused_topk_bias(
hidden_states=hidden_states, hidden_states=hidden_states,
...@@ -1519,7 +1516,7 @@ class FusedMoE(CustomOp): ...@@ -1519,7 +1516,7 @@ class FusedMoE(CustomOp):
topk=top_k, topk=top_k,
renormalize=renormalize, renormalize=renormalize,
) )
if routed_scaling_factor is not None: if routed_scaling_factor != 1.0:
topk_weights *= routed_scaling_factor topk_weights *= routed_scaling_factor
elif custom_routing_function is None: elif custom_routing_function is None:
topk_weights, topk_ids, token_expert_indices = fused_topk( topk_weights, topk_ids, token_expert_indices = fused_topk(
...@@ -1536,8 +1533,6 @@ class FusedMoE(CustomOp): ...@@ -1536,8 +1533,6 @@ class FusedMoE(CustomOp):
topk=top_k, topk=top_k,
renormalize=renormalize, renormalize=renormalize,
) )
if indices_type is not None:
topk_ids = topk_ids.to(dtype=indices_type)
if enable_eplb: if enable_eplb:
assert expert_load_view is not None assert expert_load_view is not None
...@@ -1549,9 +1544,11 @@ class FusedMoE(CustomOp): ...@@ -1549,9 +1544,11 @@ class FusedMoE(CustomOp):
expert_load_view=expert_load_view, expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map, logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count, logical_replica_count=logical_replica_count,
indices_type=indices_type,
) )
if (indices_type is not None) and topk_ids.dtype != indices_type:
topk_ids = topk_ids.to(dtype=indices_type)
assert topk_ids.dtype == indices_type or indices_type is None assert topk_ids.dtype == indices_type or indices_type is None
# Compute zero expert result if needed # Compute zero expert result if needed
......
...@@ -1706,7 +1706,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1706,7 +1706,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
intermediate_size=layer.intermediate_size_per_partition, intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.ep_rank * layer.local_num_experts, local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts, local_num_experts=layer.local_num_experts,
routed_scaling_factor=None, routed_scaling_factor=1.0,
tile_tokens_dim=None, tile_tokens_dim=None,
routing_method_type=routing_method_type, routing_method_type=routing_method_type,
do_finalize=True, do_finalize=True,
......
...@@ -118,7 +118,7 @@ class FlashConfig(PretrainedConfig): ...@@ -118,7 +118,7 @@ class FlashConfig(PretrainedConfig):
router_dtype="float32", router_dtype="float32",
router_bias=False, router_bias=False,
topk_method=None, topk_method=None,
routed_scaling_factor=None, routed_scaling_factor=1.0,
zero_expert_num=0, zero_expert_num=0,
zero_expert_type=None, zero_expert_type=None,
nextn_use_scmoe=False, nextn_use_scmoe=False,
......
...@@ -625,7 +625,7 @@ class OpenPanguDecoderLayer(nn.Module): ...@@ -625,7 +625,7 @@ class OpenPanguDecoderLayer(nn.Module):
bias=getattr(config, "mlp_bias", False), bias=getattr(config, "mlp_bias", False),
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
self.routed_scaling_factor = getattr(config, "routed_scaling_factor", None) self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)
self.num_hidden_layers = config.num_hidden_layers self.num_hidden_layers = config.num_hidden_layers
self.first_k_dense_replace = getattr( self.first_k_dense_replace = getattr(
config, "first_k_dense_replace", self.num_hidden_layers config, "first_k_dense_replace", self.num_hidden_layers
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment