Commit 4c20b890 authored by zhuwenwen's avatar zhuwenwen
Browse files

Set the default value of routed_scaling_factor to 1

parent 2c16c7a4
...@@ -1393,7 +1393,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -1393,7 +1393,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> None: routed_scaling_factor: Optional[float] = 1.0) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, apply_router_weight_on_input, use_fp8_w8a8, activation, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,use_int4_w4a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,use_int4_w4a8,
...@@ -1427,7 +1427,7 @@ def inplace_fused_experts_fake( ...@@ -1427,7 +1427,7 @@ def inplace_fused_experts_fake(
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> None: routed_scaling_factor: Optional[float] = 1.0) -> None:
pass pass
...@@ -1465,7 +1465,7 @@ def outplace_fused_experts( ...@@ -1465,7 +1465,7 @@ def outplace_fused_experts(
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> torch.Tensor: routed_scaling_factor: Optional[float] = 1.0) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, apply_router_weight_on_input, False, activation, apply_router_weight_on_input,
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
...@@ -1499,7 +1499,7 @@ def outplace_fused_experts_fake( ...@@ -1499,7 +1499,7 @@ def outplace_fused_experts_fake(
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> torch.Tensor: routed_scaling_factor: Optional[float] = 1.0) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
...@@ -1558,7 +1558,7 @@ def fused_experts( ...@@ -1558,7 +1558,7 @@ def fused_experts(
allow_cutlass_block_scaled_grouped_gemm: bool = False, allow_cutlass_block_scaled_grouped_gemm: bool = False,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> torch.Tensor: routed_scaling_factor: Optional[float] = 1.0) -> torch.Tensor:
# For now, disable DeepGemm for small N (<= 512) until better # For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available. # permute/unpermute ops are available.
N = w1.size(1) N = w1.size(1)
...@@ -1647,7 +1647,7 @@ def fused_experts_impl( ...@@ -1647,7 +1647,7 @@ def fused_experts_impl(
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
num_tokens = hidden_states.size(0) num_tokens = hidden_states.size(0)
if use_nn_moe: if use_nn_moe:
...@@ -1960,7 +1960,7 @@ def fused_moe( ...@@ -1960,7 +1960,7 @@ def fused_moe(
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets of This function computes a Mixture of Experts (MoE) layer using two sets of
......
...@@ -376,7 +376,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -376,7 +376,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = 1.0,
use_fused_gate: Optional[bool] = False, use_fused_gate: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
if enable_eplb: if enable_eplb:
...@@ -423,7 +423,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -423,7 +423,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
activation: str = "silu", activation: str = "silu",
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = 1.0,
use_fused_gate: Optional[bool] = False, use_fused_gate: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -487,7 +487,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -487,7 +487,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = 1.0,
use_fused_gate: Optional[bool] = False, use_fused_gate: Optional[bool] = False,
**kwargs, **kwargs,
): ):
...@@ -527,7 +527,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -527,7 +527,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = 1.0,
use_fused_gate: Optional[bool] = False, use_fused_gate: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
assert not use_grouped_topk assert not use_grouped_topk
...@@ -683,7 +683,7 @@ class FusedMoE(torch.nn.Module): ...@@ -683,7 +683,7 @@ class FusedMoE(torch.nn.Module):
activation: str = "silu", activation: str = "silu",
enable_eplb: bool = False, enable_eplb: bool = False,
num_redundant_experts: int = 0, num_redundant_experts: int = 0,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = 1.0,
): ):
super().__init__() super().__init__()
if params_dtype is None: if params_dtype is None:
...@@ -1269,7 +1269,7 @@ class FusedMoE(torch.nn.Module): ...@@ -1269,7 +1269,7 @@ class FusedMoE(torch.nn.Module):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = 1.0,
use_fused_gate: Optional[bool] = False use_fused_gate: Optional[bool] = False
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment