Commit a91e90d9 in sglang (unverified), authored Aug 20, 2025 by Trevor Morris, committed by GitHub on Aug 20, 2025

[2/2] Fuse routed scaling factor into select_experts (#8690)

Parent: f96413c4
Showing 6 changed files with 55 additions and 25 deletions (+55, -25)
Files changed:
  python/sglang/srt/layers/moe/fused_moe_triton/layer.py   +7   -0
  python/sglang/srt/layers/moe/topk.py                      +20  -0
  python/sglang/srt/layers/quantization/fp8.py              +8   -9
  python/sglang/srt/layers/quantization/modelopt_quant.py   +2   -4
  python/sglang/srt/models/deepseek_v2.py                   +12  -11
  sgl-kernel/tests/test_moe_fused_gate.py                   +6   -1
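Why the fusion is numerically safe: the MoE combine step is a weighted sum of expert outputs, so multiplying the routing weights by routed_scaling_factor is equivalent to multiplying the combined output afterwards. A minimal standalone PyTorch sketch of that equivalence (toy shapes, not sglang code):

import torch

torch.manual_seed(0)
expert_out = torch.randn(4, 2, 8)   # [tokens, topk, hidden], toy sizes
weights = torch.rand(4, 2)
weights = weights / weights.sum(dim=-1, keepdim=True)
s = 2.5                             # routed_scaling_factor

# Old behaviour: combine expert outputs first, then scale the MoE output.
scale_after = s * (weights.unsqueeze(-1) * expert_out).sum(dim=1)
# New behaviour: fold the scale into the top-k weights before combining.
scale_in_weights = ((s * weights).unsqueeze(-1) * expert_out).sum(dim=1)

assert torch.allclose(scale_after, scale_in_weights, atol=1e-6)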
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -28,6 +28,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
 from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
 from sglang.srt.managers.schedule_batch import global_server_args_dict

@@ -923,6 +924,12 @@ class FusedMoE(torch.nn.Module):
             for shard_id in ["w1", "w2", "w3"]
         ]

+    def should_fuse_routed_scaling_factor_in_topk(self):
+        return isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod) or (
+            isinstance(self.quant_method, Fp8MoEMethod)
+            and self.quant_method.use_cutlass_fused_experts_fp8
+        )
+

 class FlashInferFusedMoE(FusedMoE):
     def __init__(self, *args, **kwargs):
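The new predicate is meant to be queried when a model builds its TopK module, so that only quant methods whose kernels already fold the scale request the fused behaviour (the deepseek_v2.py change below does exactly this). A hedged sketch of such a caller; build_topk_for is a hypothetical helper, not sglang API:

from sglang.srt.layers.moe.topk import TopK

# Hypothetical wiring sketch: ask the experts layer whether its quant method
# fuses the routed scaling factor, and pass the answer through to TopK.
def build_topk_for(experts, config, routed_scaling_factor):
    return TopK(
        top_k=config.num_experts_per_tok,
        renormalize=config.norm_topk_prob,
        routed_scaling_factor=routed_scaling_factor,
        apply_routed_scaling_factor_on_output=(
            experts.should_fuse_routed_scaling_factor_in_topk()
        ),
    )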
python/sglang/srt/layers/moe/topk.py

@@ -197,6 +197,7 @@ class TopK(CustomOp):
         scoring_func: str = "softmax",
         correction_bias: Optional[torch.Tensor] = None,
         routed_scaling_factor: Optional[float] = None,
+        apply_routed_scaling_factor_on_output: Optional[bool] = False,
     ):
         # NOTE: scoring_func is not used for now, but we keep it for future use
         # see https://github.com/sgl-project/sglang/pull/4505 for more details

@@ -215,6 +216,7 @@ class TopK(CustomOp):
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
             routed_scaling_factor=routed_scaling_factor,
+            apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
         )
         self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()

@@ -433,6 +435,7 @@ def grouped_topk_gpu(
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

@@ -480,6 +483,8 @@ def grouped_topk_gpu(
             else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
         )
         topk_weights = topk_weights / topk_weights_sum
+    if apply_routed_scaling_factor_on_output:
+        topk_weights *= routed_scaling_factor

     topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)

@@ -528,6 +533,7 @@ def biased_grouped_topk_impl(
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

@@ -579,6 +585,8 @@ def biased_grouped_topk_impl(
             else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
         )
         topk_weights = topk_weights / topk_weights_sum
+    if apply_routed_scaling_factor_on_output:
+        topk_weights *= routed_scaling_factor

     topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)

@@ -621,6 +629,7 @@ def biased_grouped_topk_gpu(
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
     assert (
         routed_scaling_factor is not None

@@ -640,6 +649,7 @@ def biased_grouped_topk_gpu(
             topk,
             num_fused_shared_experts,
             routed_scaling_factor,
+            apply_routed_scaling_factor_on_output,
         )
         # TODO merge into kernel
         if (expert_location_dispatch_info is not None) or (

@@ -650,6 +660,7 @@ def biased_grouped_topk_gpu(
             )
         return topk_weights, topk_ids
     elif _use_aiter:
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         token = gating_output.shape[0]
         device = gating_output.device
         assert (

@@ -681,6 +692,7 @@ def biased_grouped_topk_gpu(
            routed_scaling_factor=routed_scaling_factor,
            num_token_non_padded=num_token_non_padded,
            expert_location_dispatch_info=expert_location_dispatch_info,
+           apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
        )

@@ -743,6 +755,9 @@ def select_experts(
     correction_bias = topk_config.correction_bias
     torch_native = topk_config.torch_native
     routed_scaling_factor = topk_config.routed_scaling_factor
+    apply_routed_scaling_factor_on_output = (
+        topk_config.apply_routed_scaling_factor_on_output
+    )
     router_logits, correction_bias = (
         expert_location_dispatch.transform_select_experts_inputs(

@@ -768,6 +783,7 @@ def select_experts(
                 routed_scaling_factor=routed_scaling_factor,
                 num_token_non_padded=num_token_non_padded,
                 expert_location_dispatch_info=expert_location_dispatch_info,
+                apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
             )
         else:
             topk_weights, topk_ids = biased_grouped_topk(

@@ -782,12 +798,14 @@ def select_experts(
                 routed_scaling_factor=routed_scaling_factor,
                 num_token_non_padded=num_token_non_padded,
                 expert_location_dispatch_info=expert_location_dispatch_info,
+                apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
             )
     elif torch_native and custom_routing_function is None:
         assert (
             num_token_non_padded is None
         ), "num_token_non_padded is not yet supported in fused_topk_native"
         assert expert_location_dispatch_info is None
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         topk_weights, topk_ids = fused_topk_native(
             hidden_states=hidden_states,
             gating_output=router_logits,

@@ -795,6 +813,7 @@ def select_experts(
             renormalize=renormalize,
         )
     elif custom_routing_function is None:
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         # Qwen3MOE uses fused_topk
         topk_weights, topk_ids = fused_topk(
             hidden_states=hidden_states,

@@ -809,6 +828,7 @@ def select_experts(
             num_token_non_padded is None
         ), "num_token_non_padded is not yet supported in custom_routing_function"
         assert expert_location_dispatch_info is None
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         topk_weights, topk_ids = custom_routing_function(
             hidden_states=hidden_states,
             gating_output=router_logits,
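In the pure-PyTorch paths above, the flag simply scales the renormalized top-k weights in place, and the paths that have not implemented the fusion guard with assert not apply_routed_scaling_factor_on_output. A minimal sketch of the weight math, assuming softmax scoring (standalone illustration, not the sglang implementation):

import torch

def topk_weights_sketch(
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    routed_scaling_factor: float,
    apply_routed_scaling_factor_on_output: bool = False,
):
    scores = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    if apply_routed_scaling_factor_on_output:
        # The scale is folded into the weights here, so the expert-combine step
        # (and the quant method's apply()) must not multiply it in again.
        topk_weights = topk_weights * routed_scaling_factor
    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)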
python/sglang/srt/layers/quantization/fp8.py

@@ -514,6 +514,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         self.quant_config = quant_config
         self.block_quant = self.quant_config.weight_block_size is not None
         self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.use_cutlass_fused_experts_fp8 = (
+            get_bool_env_var("SGLANG_CUTLASS_MOE")
+            and self.cutlass_fp8_supported
+            and self.block_quant
+            and (is_sm100_supported() or is_sm90_supported())
+        )

     def create_weights(
         self,

@@ -1021,12 +1027,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         if ret is not None:
             return ret

-        if (
-            get_bool_env_var("SGLANG_CUTLASS_MOE")
-            and self.cutlass_fp8_supported
-            and self.block_quant
-            and (is_sm100_supported() or is_sm90_supported())
-        ):
+        if self.use_cutlass_fused_experts_fp8:
             from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8

             topk_weights, topk_ids, _ = topk_output

@@ -1053,9 +1054,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 self.problem_sizes2,
                 use_fp8_blockscale=True,
             )
-            # TODO: Fuse into select_experts
-            if moe_runner_config.routed_scaling_factor is not None:
-                output *= moe_runner_config.routed_scaling_factor
+            # Scale by routed_scaling_factor is fused into select_experts.
             return output
         # Expert fusion with FP8 quantization
         return fused_experts(
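Note that the cutlass gate is now evaluated once in __init__ and cached as use_cutlass_fused_experts_fp8, instead of re-checking the environment variable and SM support on every apply() call; this cached flag is also what FusedMoE.should_fuse_routed_scaling_factor_in_topk() inspects. A small illustration of the same hoist pattern, with stand-in names rather than the sglang API:

import os

def _env_flag(name: str) -> bool:
    # Stand-in for get_bool_env_var().
    return os.environ.get(name, "").lower() in ("1", "true", "yes")

class CachedGateExample:
    def __init__(self, block_quant: bool, cutlass_fp8_supported: bool, sm_ok: bool):
        # Evaluate the dispatch condition once at construction time.
        self.use_cutlass_fused_experts_fp8 = (
            _env_flag("SGLANG_CUTLASS_MOE")
            and cutlass_fp8_supported
            and block_quant
            and sm_ok
        )

    def apply(self, x):
        # The hot path just branches on the cached attribute.
        if self.use_cutlass_fused_experts_fp8:
            return "cutlass_fused_experts_fp8 path"
        return "fused_experts path"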
python/sglang/srt/layers/quantization/modelopt_quant.py

@@ -1305,8 +1305,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 tp_rank=layer.moe_tp_rank,
                 tune_max_num_tokens=next_power_of_2(x.shape[0]),
             )[0]
-            if moe_runner_config.routed_scaling_factor is not None:
-                output *= moe_runner_config.routed_scaling_factor
+            # Scale by routed_scaling_factor is fused into select_experts.
             if should_use_flashinfer_cutlass_moe_fp4_allgather():
                 output, global_output = get_local_dp_buffer(), output
                 get_tp_group().reduce_scatterv(

@@ -1332,6 +1331,5 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             params=layer.cutlass_moe_params,
             apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
         ).to(x.dtype)
-        if moe_runner_config.routed_scaling_factor is not None:
-            output *= moe_runner_config.routed_scaling_factor
+        # Scale by routed_scaling_factor is fused into select_experts.
         return output
python/sglang/srt/models/deepseek_v2.py

@@ -319,17 +319,6 @@ class DeepseekV2MoE(nn.Module):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )

-        self.topk = TopK(
-            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
-            renormalize=config.norm_topk_prob,
-            use_grouped_topk=True,
-            num_expert_group=config.n_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            topk_group=config.topk_group,
-            correction_bias=self.gate.e_score_correction_bias,
-            routed_scaling_factor=self.routed_scaling_factor,
-        )
-
         self.experts = get_moe_impl_class()(
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts

@@ -344,6 +333,18 @@ class DeepseekV2MoE(nn.Module):
             prefix=add_prefix("experts", prefix),
         )

+        self.topk = TopK(
+            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+            renormalize=config.norm_topk_prob,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
+            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
+        )
+
         self.shared_experts_is_int8 = False
         self.shared_experts_is_fp8 = False
         self.shared_experts_weight_block_size = None
sgl-kernel/tests/test_moe_fused_gate.py

@@ -19,7 +19,10 @@ from sglang.srt.layers.moe.topk import biased_grouped_topk
     ],
 )
 @pytest.mark.parametrize("num_fused_shared_experts", [0, 1, 2])
-def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
+@pytest.mark.parametrize("apply_routed_scaling_factor_on_output", [False, True])
+def test_moe_fused_gate_combined(
+    seq_length, params, num_fused_shared_experts, apply_routed_scaling_factor_on_output
+):
     num_experts, num_expert_group, topk_group, topk = params
     dtype = torch.float32

@@ -37,6 +40,7 @@ def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
         topk=topk,
         num_fused_shared_experts=num_fused_shared_experts,
         routed_scaling_factor=2.5,
+        apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
     )
     ref_output, ref_indices = biased_grouped_topk(
         scores,

@@ -48,6 +52,7 @@ def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
         topk_group=topk_group,
         num_fused_shared_experts=num_fused_shared_experts,
         routed_scaling_factor=2.5,
+        apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
     )
     # When num_fused_shared_experts > 0, ignore the comparison of the last topk dimension
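To exercise the new parametrization locally, a standard pytest invocation such as `pytest sgl-kernel/tests/test_moe_fused_gate.py -k test_moe_fused_gate_combined` should cover both apply_routed_scaling_factor_on_output=False and =True, assuming a CUDA-capable environment since the fused gate kernel under test runs on GPU.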