Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
63227acc
Unverified
Commit
63227acc
authored
Jan 21, 2026
by
Xin Yang
Committed by
GitHub
Jan 21, 2026
Browse files
[Kernel] Add topk_sigmoid kernel (#31246)
Signed-off-by:
Xin Yang
<
xyangx@amazon.com
>
parent
e675dda6
Changes
13
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
725 additions
and
126 deletions
+725
-126
benchmarks/kernels/benchmark_fused_topk.py
benchmarks/kernels/benchmark_fused_topk.py
+99
-0
csrc/moe/moe_ops.h
csrc/moe/moe_ops.h
+7
-1
csrc/moe/topk_softmax_kernels.cu
csrc/moe/topk_softmax_kernels.cu
+242
-101
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+9
-1
tests/kernels/moe/test_fused_topk.py
tests/kernels/moe/test_fused_topk.py
+137
-0
tests/model_executor/test_enabled_custom_ops.py
tests/model_executor/test_enabled_custom_ops.py
+17
-3
vllm/_aiter_ops.py
vllm/_aiter_ops.py
+39
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+25
-1
vllm/model_executor/layers/fused_moe/config.py
vllm/model_executor/layers/fused_moe/config.py
+2
-2
vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py
...xecutor/layers/fused_moe/router/fused_topk_bias_router.py
+97
-1
vllm/model_executor/layers/fused_moe/router/fused_topk_router.py
...del_executor/layers/fused_moe/router/fused_topk_router.py
+50
-8
vllm/model_executor/layers/fused_moe/router/router_factory.py
.../model_executor/layers/fused_moe/router/router_factory.py
+1
-5
vllm/model_executor/models/minimax_m2.py
vllm/model_executor/models/minimax_m2.py
+0
-3
No files found.
benchmarks/kernels/benchmark_fused_topk.py
0 → 100644
View file @
63227acc
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
itertools
import
torch
from
vllm.model_executor.layers.fused_moe.router.fused_topk_router
import
fused_topk
from
vllm.triton_utils
import
triton
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
# Sweep axes for the benchmark grid.
# Token counts are 2**0, 2**2, 2**4, 2**6 -> 1, 4, 16, 64.
num_tokens_range = [1 << exponent for exponent in range(0, 8, 2)]
num_experts_range = [16, 32, 64, 128, 256, 512]
topk_range = [3, 4]

# Every (num_tokens, num_experts, topk) combination, as a list of tuples.
configs = [*itertools.product(num_tokens_range, num_experts_range, topk_range)]
def torch_topk(
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    scoring_func: str = "softmax",
):
    """Reference top-k routing built from stock PyTorch ops.

    Args:
        gating_output: Router logits; scoring and top-k selection run over
            the last dimension.
        topk: Number of experts to keep per row.
        renormalize: If True, divide the selected weights by their sum over
            the last dimension.
        scoring_func: ``"softmax"`` or ``"sigmoid"``.

    Returns:
        ``(topk_weights, topk_ids)`` as returned by ``torch.topk`` on the
        float32 scores (weights optionally renormalized).

    Raises:
        ValueError: If ``scoring_func`` is neither ``"softmax"`` nor
            ``"sigmoid"``.
    """
    # Scores are computed in float32 regardless of the input dtype.
    logits = gating_output.float()
    if scoring_func == "softmax":
        scores = torch.softmax(logits, dim=-1)
    elif scoring_func == "sigmoid":
        scores = torch.sigmoid(logits)
    else:
        # Previously any unknown name silently fell through to sigmoid, which
        # would make the baseline measure something other than what was asked.
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
def get_benchmark(scoring_func):
    """Build the Triton perf-report benchmark for one scoring function.

    Args:
        scoring_func: Scoring function name forwarded to both ``torch_topk``
            and ``fused_topk``; also embedded in the plot name.

    Returns:
        The ``perf_report``-decorated ``benchmark`` callable. It sweeps every
        entry of ``configs`` for the ``"torch"`` and ``"vllm"`` providers.
    """

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["num_tokens", "num_experts", "topk"],
            x_vals=[list(_) for _ in configs],
            line_arg="provider",
            line_vals=["torch", "vllm"],
            line_names=["Torch", "vLLM"],
            styles=[("blue", "-"), ("red", "-")],
            ylabel="us",
            plot_name=f"fused-topk-perf-{scoring_func}",
            args={},
        )
    )
    def benchmark(num_tokens, num_experts, topk, provider):
        """Time one (num_tokens, num_experts, topk) point for ``provider``."""
        dtype = torch.bfloat16
        hidden_size = 1024
        renormalize = True

        hidden_states = torch.randn(
            (num_tokens, hidden_size), dtype=dtype, device="cuda"
        )
        gating_output = torch.randn(
            (num_tokens, num_experts), dtype=dtype, device="cuda"
        )

        # do_bench returns one value per requested quantile, in this order:
        # median, 20th percentile, 80th percentile.
        quantiles = [0.5, 0.2, 0.8]

        if provider == "torch":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: torch_topk(
                    gating_output=gating_output,
                    topk=topk,
                    renormalize=renormalize,
                    scoring_func=scoring_func,
                ),
                quantiles=quantiles,
            )
        else:
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: fused_topk(
                    hidden_states=hidden_states,
                    gating_output=gating_output,
                    topk=topk,
                    renormalize=renormalize,
                    scoring_func=scoring_func,
                ),
                quantiles=quantiles,
            )

        # perf_report unpacks the result as (value, min, max). The metric is a
        # time, so the low quantile is the min and the high quantile the max;
        # the (ms, max, min) order only suits inverted metrics like throughput.
        # The factor 1000 converts do_bench's milliseconds to the "us" ylabel.
        return 1000 * ms, 1000 * min_ms, 1000 * max_ms

    return benchmark
if __name__ == "__main__":
    # Command-line entry point: both options are plain strings with defaults.
    cli = FlexibleArgumentParser(description="Benchmark the MoE topk kernel.")
    for flag, fallback in (
        ("--scoring-func", "softmax"),
        ("--save-path", "./configs/fused_topk/"),
    ):
        cli.add_argument(flag, type=str, default=fallback)
    options = cli.parse_args()

    # Build the sweep for the chosen scoring function, then print the table
    # and write the report under the save path.
    get_benchmark(options.scoring_func).run(
        print_data=True, save_path=options.save_path
    )
csrc/moe/moe_ops.h
View file @
63227acc
...
...
@@ -4,7 +4,13 @@
void
topk_softmax
(
torch
::
Tensor
&
topk_weights
,
torch
::
Tensor
&
topk_indices
,
torch
::
Tensor
&
token_expert_indices
,
torch
::
Tensor
&
gating_output
,
bool
renormalize
);
torch
::
Tensor
&
gating_output
,
bool
renormalize
,
std
::
optional
<
torch
::
Tensor
>
bias
);
void
topk_sigmoid
(
torch
::
Tensor
&
topk_weights
,
torch
::
Tensor
&
topk_indices
,
torch
::
Tensor
&
token_expert_indices
,
torch
::
Tensor
&
gating_output
,
bool
renormalize
,
std
::
optional
<
torch
::
Tensor
>
bias
);
void
moe_sum
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
output
);
...
...
csrc/moe/topk_softmax_kernels.cu
View file @
63227acc
This diff is collapsed.
Click to expand it.
csrc/moe/torch_bindings.cpp
View file @
63227acc
...
...
@@ -5,9 +5,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Apply topk softmax to the gating outputs.
m
.
def
(
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output, bool renormalize) -> ()"
);
"token_expert_indices, Tensor gating_output, bool renormalize, Tensor? "
"bias) -> ()"
);
m
.
impl
(
"topk_softmax"
,
torch
::
kCUDA
,
&
topk_softmax
);
// Apply topk sigmoid to the gating outputs.
m
.
def
(
"topk_sigmoid(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output, bool renormalize, Tensor? "
"bias) -> ()"
);
m
.
impl
(
"topk_sigmoid"
,
torch
::
kCUDA
,
&
topk_sigmoid
);
// Calculate the result of moe by summing up the partial results
// from all selected experts.
m
.
def
(
"moe_sum(Tensor input, Tensor! output) -> ()"
);
...
...
tests/kernels/moe/test_fused_topk.py
0 → 100644
View file @
63227acc
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the MoE fused topk kernel
Run `pytest tests/kernels/moe/test_fused_topk.py`.
"""
import
pytest
import
torch
from
vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router
import
(
fused_topk_bias
,
)
from
vllm.model_executor.layers.fused_moe.router.fused_topk_router
import
fused_topk
from
vllm.platforms
import
current_platform
def torch_topk(
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    e_score_correction_bias: torch.Tensor = None,
    scoring_func: str = "softmax",
):
    """Pure-PyTorch reference for top-k expert routing.

    Scores are softmax (over the last dim) or sigmoid of the float32 logits.
    Without a bias, weights and ids come straight from ``torch.topk``. With a
    bias, the ids are chosen on ``scores + bias`` while the returned weights
    are gathered from the *unbiased* scores. If ``renormalize`` is set, the
    weights are divided by their sum over the last dim.

    Returns:
        ``(topk_weights, topk_ids)``.
    """
    if scoring_func == "softmax":
        scores = gating_output.float().softmax(dim=-1)
    else:
        assert scoring_func == "sigmoid"
        scores = gating_output.float().sigmoid()

    if e_score_correction_bias is None:
        topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)
    else:
        num_experts = gating_output.shape[-1]
        # The bias only influences which experts are picked.
        biased_scores = scores.view(
            -1, num_experts
        ) + e_score_correction_bias.unsqueeze(0)
        topk_ids = torch.topk(biased_scores, k=topk, dim=-1)[1]
        topk_weights = scores.gather(1, topk_ids)

    if not renormalize:
        return topk_weights, topk_ids
    normalized = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return normalized, topk_ids
@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
@pytest.mark.parametrize("num_tokens", [1, 33, 56])
@pytest.mark.parametrize("hidden_size", [1024, 2048])
@pytest.mark.parametrize("num_experts", [6, 16])
@pytest.mark.parametrize("topk", [3, 4])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32])
def test_fused_topk(
    num_tokens: int,
    hidden_size: int,
    num_experts: int,
    topk: int,
    renormalize: bool,
    scoring_func: str,
    dtype: torch.dtype,
):
    """Check ``fused_topk`` against the PyTorch reference ``torch_topk``.

    Weights must agree within 1e-2 (absolute and relative); expert ids must
    match exactly.
    """
    torch.manual_seed(0)
    # Keep this creation order: it fixes which random values each tensor gets.
    hidden_states = torch.randn((num_tokens, hidden_size), dtype=dtype, device="cuda")
    gating_output = torch.randn((num_tokens, num_experts), dtype=dtype, device="cuda")

    routing_kwargs = dict(
        gating_output=gating_output,
        topk=topk,
        renormalize=renormalize,
        scoring_func=scoring_func,
    )
    expected_weights, expected_ids = torch_topk(**routing_kwargs)
    actual_weights, actual_ids, _ = fused_topk(
        hidden_states=hidden_states, **routing_kwargs
    )

    torch.testing.assert_close(
        expected_weights.to(torch.float32), actual_weights, atol=1e-2, rtol=1e-2
    )
    torch.testing.assert_close(
        expected_ids.to(torch.int32), actual_ids, atol=0, rtol=0
    )
@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
@pytest.mark.parametrize("num_tokens", [1, 33, 56])
@pytest.mark.parametrize("hidden_size", [1024, 2048])
@pytest.mark.parametrize("num_experts", [6, 16])
@pytest.mark.parametrize("topk", [3, 4])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32])
def test_fused_topk_bias(
    num_tokens: int,
    hidden_size: int,
    num_experts: int,
    topk: int,
    renormalize: bool,
    scoring_func: str,
    dtype: torch.dtype,
):
    """Check ``fused_topk_bias`` against the PyTorch reference ``torch_topk``.

    A random float32 per-expert correction bias is supplied to both. Weights
    must agree within 1e-2 (absolute and relative); expert ids must match
    exactly.
    """
    torch.manual_seed(0)
    # Keep this creation order: it fixes which random values each tensor gets.
    hidden_states = torch.randn((num_tokens, hidden_size), dtype=dtype, device="cuda")
    gating_output = torch.randn((num_tokens, num_experts), dtype=dtype, device="cuda")
    e_score_correction_bias = torch.randn(
        (num_experts,), dtype=torch.float32, device="cuda"
    )

    expected_weights, expected_ids = torch_topk(
        gating_output=gating_output,
        topk=topk,
        renormalize=renormalize,
        e_score_correction_bias=e_score_correction_bias,
        scoring_func=scoring_func,
    )
    actual_weights, actual_ids = fused_topk_bias(
        hidden_states=hidden_states,
        gating_output=gating_output,
        e_score_correction_bias=e_score_correction_bias,
        topk=topk,
        renormalize=renormalize,
        scoring_func=scoring_func,
    )

    torch.testing.assert_close(
        expected_weights.to(torch.float32), actual_weights, atol=1e-2, rtol=1e-2
    )
    torch.testing.assert_close(
        expected_ids.to(torch.int32), actual_ids, atol=0, rtol=0
    )
tests/model_executor/test_enabled_custom_ops.py
View file @
63227acc
...
...
@@ -18,7 +18,9 @@ from vllm.model_executor.layers.activation import (
SiluAndMul
,
)
from
vllm.model_executor.layers.fused_moe.router.fused_topk_router
import
(
dispatch_topk_func
,
dispatch_topk_sigmoid_func
,
dispatch_topk_softmax_func
,
vllm_topk_sigmoid
,
vllm_topk_softmax
,
)
from
vllm.model_executor.layers.layernorm
import
(
...
...
@@ -133,8 +135,8 @@ def test_enabled_ops_invalid(env: str):
@
pytest
.
mark
.
parametrize
(
"use_rocm_aiter"
,
[
True
,
False
]
if
current_platform
.
is_rocm
()
else
[
False
]
)
def
test_topk_dispatch
(
use_rocm_aiter
:
bool
):
topk_func
=
dispatch_topk_func
(
use_rocm_aiter
)
def
test_topk_
softmax_
dispatch
(
use_rocm_aiter
:
bool
):
topk_func
=
dispatch_topk_
softmax_
func
(
use_rocm_aiter
)
if
current_platform
.
is_rocm
()
and
use_rocm_aiter
:
assert
topk_func
==
rocm_aiter_ops
.
topk_softmax
...
...
@@ -142,6 +144,18 @@ def test_topk_dispatch(use_rocm_aiter: bool):
assert
topk_func
==
vllm_topk_softmax
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
def test_topk_sigmoid_dispatch(use_rocm_aiter: bool):
    """Verify which top-k sigmoid implementation the dispatcher selects.

    On ROCm with AITER requested the AITER op is expected; in every other
    case the vLLM kernel wrapper is expected.
    """
    selected = dispatch_topk_sigmoid_func(use_rocm_aiter)
    wants_aiter = current_platform.is_rocm() and use_rocm_aiter
    expected = rocm_aiter_ops.topk_sigmoid if wants_aiter else vllm_topk_sigmoid
    assert selected == expected
@
pytest
.
mark
.
parametrize
(
"add_residual"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"use_rocm_aiter"
,
[
True
,
False
])
...
...
vllm/_aiter_ops.py
View file @
63227acc
...
...
@@ -200,6 +200,24 @@ def _rocm_aiter_topk_softmax_fake(
pass
def _rocm_aiter_topk_sigmoid_impl(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    gating_output: torch.Tensor,
) -> None:
    """Forward the three tensors to AITER's ``topk_sigmoid``.

    Returns nothing. This function is registered as a custom op that declares
    ``topk_weights`` and ``topk_indices`` as mutated arguments, so results are
    presumably written into those tensors in place (confirm against AITER).
    """
    # Function-scope import: ``aiter`` is only imported when the op runs.
    from aiter import topk_sigmoid

    topk_sigmoid(topk_weights, topk_indices, gating_output)
def _rocm_aiter_topk_sigmoid_fake(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    gating_output: torch.Tensor,
) -> None:
    """No-op stand-in used as the ``fake_impl`` of the top-k sigmoid op.

    Takes the same arguments as the real implementation and does nothing.
    """
    pass
def
_rocm_aiter_biased_grouped_topk_impl
(
gating_output
:
torch
.
Tensor
,
correction_bias
:
torch
.
Tensor
,
...
...
@@ -985,6 +1003,14 @@ class rocm_aiter_ops:
dispatch_key
=
current_platform
.
dispatch_key
,
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_topk_sigmoid"
,
op_func
=
_rocm_aiter_topk_sigmoid_impl
,
mutates_args
=
[
"topk_weights"
,
"topk_indices"
],
fake_impl
=
_rocm_aiter_topk_sigmoid_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_biased_grouped_topk"
,
op_func
=
_rocm_aiter_biased_grouped_topk_impl
,
...
...
@@ -1272,6 +1298,19 @@ class rocm_aiter_ops:
)
return
topk_weights
,
topk_indices
@staticmethod
def topk_sigmoid(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool,
) -> tuple[torch.Tensor, ...]:
    """Run the registered ``rocm_aiter_topk_sigmoid`` op.

    Only ``topk_weights``, ``topk_indices`` and ``gating_output`` are passed
    to the op; the same ``topk_weights`` and ``topk_indices`` objects are
    then returned as a tuple.

    NOTE(review): ``token_expert_indices`` and ``renormalize`` are accepted
    but never forwarded. Confirm whether the AITER kernel handles
    renormalization itself, or whether ``renormalize=True`` is silently
    dropped on this path.
    """
    torch.ops.vllm.rocm_aiter_topk_sigmoid(topk_weights, topk_indices, gating_output)
    return topk_weights, topk_indices
@
staticmethod
def
biased_grouped_topk
(
gating_output
:
torch
.
Tensor
,
...
...
vllm/_custom_ops.py
View file @
63227acc
...
...
@@ -2177,9 +2177,33 @@ def topk_softmax(
token_expert_indices
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
renormalize
:
bool
=
False
,
e_score_correction_bias
:
torch
.
Tensor
|
None
=
None
,
)
->
None
:
torch
.
ops
.
_moe_C
.
topk_softmax
(
topk_weights
,
topk_ids
,
token_expert_indices
,
gating_output
,
renormalize
topk_weights
,
topk_ids
,
token_expert_indices
,
gating_output
,
renormalize
,
e_score_correction_bias
,
)
def topk_sigmoid(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
    e_score_correction_bias: torch.Tensor | None = None,
) -> None:
    """Invoke the ``_moe_C.topk_sigmoid`` custom op.

    All arguments are forwarded positionally in declaration order;
    ``e_score_correction_bias`` fills the op's optional trailing ``bias``
    slot. Returns nothing: the op schema marks ``topk_weights``, ``topk_ids``
    and ``token_expert_indices`` as tensors it writes to.
    """
    torch.ops._moe_C.topk_sigmoid(
        topk_weights,
        topk_ids,
        token_expert_indices,
        gating_output,
        renormalize,
        e_score_correction_bias,
    )
...
...
vllm/model_executor/layers/fused_moe/config.py
View file @
63227acc
...
...
@@ -106,14 +106,14 @@ def _quant_flags_to_group_shape(
class
RoutingMethodType
(
IntEnum
):
# Default: Softmax -> TopK
Default
=
(
0
,)
# Renormalize: TopK -> Softmax
# Renormalize: TopK -> Softmax
/Sigmoid
Renormalize
=
(
1
,)
# DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups
# -> Top8 experts from the Top4 groups
DeepSeekV3
=
(
2
,)
# Llama4: Top1 -> Sigmoid
Llama4
=
(
3
,)
# RenormalizeNaive: Softmax -> TopK -> Renormalize
# RenormalizeNaive: Softmax
/Sigmoid
-> TopK -> Renormalize
RenormalizeNaive
=
(
4
,)
# TopK: TopK (no softmax)
TopK
=
(
5
,)
...
...
vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py
View file @
63227acc
...
...
@@ -4,6 +4,8 @@ from collections.abc import Callable
import
torch
import
vllm._custom_ops
as
ops
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.distributed.eplb.eplb_state
import
EplbLayerState
from
vllm.model_executor.layers.batch_invariant
import
(
vllm_is_batch_invariant
,
...
...
@@ -12,15 +14,106 @@ from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from
vllm.model_executor.layers.fused_moe.router.base_router
import
BaseRouter
def vllm_topk_softmax(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
    e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, ...]:
    """Call ``ops.topk_softmax`` and return the output buffers.

    Every argument, including the optional correction bias, is forwarded
    positionally. The returned tuple holds the same ``topk_weights`` and
    ``topk_indices`` objects that were passed in.
    """
    ops.topk_softmax(
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        renormalize,
        e_score_correction_bias,
    )
    return topk_weights, topk_indices
def vllm_topk_sigmoid(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
    e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, ...]:
    """Call ``ops.topk_sigmoid`` and return the output buffers.

    Every argument, including the optional correction bias, is forwarded
    positionally. The returned tuple holds the same ``topk_weights`` and
    ``topk_indices`` objects that were passed in.
    """
    ops.topk_sigmoid(
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        renormalize,
        e_score_correction_bias,
    )
    return topk_weights, topk_indices
def
fused_topk_bias
(
hidden_states
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
e_score_correction_bias
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
scoring_func
:
str
=
"softmax"
,
indices_type
:
torch
.
dtype
|
None
=
None
,
):
if
not
rocm_aiter_ops
.
is_fused_moe_enabled
():
assert
hidden_states
.
size
(
0
)
==
gating_output
.
size
(
0
),
(
"Number of tokens mismatch"
)
M
,
_
=
hidden_states
.
size
()
topk_weights
=
torch
.
empty
(
M
,
topk
,
dtype
=
torch
.
float32
,
device
=
hidden_states
.
device
)
topk_ids
=
torch
.
empty
(
M
,
topk
,
dtype
=
torch
.
int32
if
indices_type
is
None
else
indices_type
,
device
=
hidden_states
.
device
,
)
token_expert_indices
=
torch
.
empty
(
M
,
topk
,
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
)
if
scoring_func
==
"softmax"
:
topk_weights
,
topk_ids
=
vllm_topk_softmax
(
topk_weights
,
topk_ids
,
token_expert_indices
,
gating_output
,
renormalize
,
e_score_correction_bias
,
)
return
topk_weights
,
topk_ids
elif
scoring_func
==
"sigmoid"
:
topk_weights
,
topk_ids
=
vllm_topk_sigmoid
(
topk_weights
,
topk_ids
,
token_expert_indices
,
gating_output
,
renormalize
,
e_score_correction_bias
,
)
return
topk_weights
,
topk_ids
else
:
raise
ValueError
(
f
"Unsupported scoring function:
{
scoring_func
}
"
)
n_routed_experts
=
gating_output
.
shape
[
-
1
]
scores
=
gating_output
.
softmax
(
dim
=-
1
)
if
scoring_func
==
"softmax"
:
scores
=
gating_output
.
softmax
(
dim
=-
1
)
elif
scoring_func
==
"sigmoid"
:
scores
=
gating_output
.
sigmoid
()
else
:
raise
ValueError
(
f
"Unsupported scoring function:
{
scoring_func
}
"
)
scores_for_choice
=
scores
.
view
(
-
1
,
n_routed_experts
)
+
e_score_correction_bias
.
unsqueeze
(
0
)
...
...
@@ -43,6 +136,7 @@ class FusedTopKBiasRouter(BaseRouter):
global_num_experts
:
int
,
eplb_state
:
EplbLayerState
,
e_score_correction_bias
:
torch
.
Tensor
,
scoring_func
:
str
,
renormalize
:
bool
=
True
,
routed_scaling_factor
:
float
=
1.0
,
enable_eplb
:
bool
=
False
,
...
...
@@ -57,6 +151,7 @@ class FusedTopKBiasRouter(BaseRouter):
)
self
.
e_score_correction_bias
=
e_score_correction_bias
self
.
renormalize
=
renormalize
self
.
scoring_func
=
scoring_func
self
.
routed_scaling_factor
=
routed_scaling_factor
@
property
...
...
@@ -80,6 +175,7 @@ class FusedTopKBiasRouter(BaseRouter):
e_score_correction_bias
=
self
.
e_score_correction_bias
.
data
,
topk
=
self
.
top_k
,
renormalize
=
self
.
renormalize
,
scoring_func
=
self
.
scoring_func
,
)
if
self
.
routed_scaling_factor
!=
1.0
:
...
...
vllm/model_executor/layers/fused_moe/router/fused_topk_router.py
View file @
63227acc
...
...
@@ -16,7 +16,7 @@ def vllm_topk_softmax(
topk_indices
:
torch
.
Tensor
,
token_expert_indices
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
renormalize
:
bool
,
renormalize
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
...]:
ops
.
topk_softmax
(
topk_weights
,
...
...
@@ -29,7 +29,25 @@ def vllm_topk_softmax(
return
topk_weights
,
topk_indices
def
dispatch_topk_func
(
def vllm_topk_sigmoid(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
) -> tuple[torch.Tensor, ...]:
    """Call ``ops.topk_sigmoid`` (no correction bias) and return the buffers.

    The five arguments are forwarded positionally. The returned tuple holds
    the same ``topk_weights`` and ``topk_indices`` objects that were passed in.
    """
    ops.topk_sigmoid(
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        renormalize,
    )
    return topk_weights, topk_indices
def
dispatch_topk_softmax_func
(
use_rocm_aiter
:
bool
=
False
,
)
->
Callable
[...,
tuple
[
torch
.
Tensor
,
...]]:
if
use_rocm_aiter
:
...
...
@@ -37,12 +55,21 @@ def dispatch_topk_func(
return
vllm_topk_softmax
def dispatch_topk_sigmoid_func(
    use_rocm_aiter: bool = False,
) -> Callable[..., tuple[torch.Tensor, ...]]:
    """Pick the top-k sigmoid implementation.

    Returns ``rocm_aiter_ops.topk_sigmoid`` when ``use_rocm_aiter`` is truthy,
    otherwise ``vllm_topk_sigmoid``.
    """
    return rocm_aiter_ops.topk_sigmoid if use_rocm_aiter else vllm_topk_sigmoid
def
fused_topk
(
hidden_states
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
indices_type
:
torch
.
dtype
|
None
=
None
,
scoring_func
:
str
=
"softmax"
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
assert
hidden_states
.
size
(
0
)
==
gating_output
.
size
(
0
),
"Number of tokens mismatch"
...
...
@@ -61,12 +88,26 @@ def fused_topk(
M
,
topk
,
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
)
topk_func
=
dispatch_topk_func
(
use_rocm_aiter
=
rocm_aiter_ops
.
is_fused_moe_enabled
())
topk_weights
,
topk_ids
=
topk_func
(
topk_weights
,
topk_ids
,
token_expert_indices
,
gating_output
,
renormalize
)
if
scoring_func
==
"softmax"
:
topk_func
=
dispatch_topk_softmax_func
(
use_rocm_aiter
=
rocm_aiter_ops
.
is_fused_moe_enabled
()
)
topk_weights
,
topk_ids
=
topk_func
(
topk_weights
,
topk_ids
,
token_expert_indices
,
gating_output
,
renormalize
)
return
topk_weights
,
topk_ids
,
token_expert_indices
elif
scoring_func
==
"sigmoid"
:
topk_func
=
dispatch_topk_sigmoid_func
(
use_rocm_aiter
=
rocm_aiter_ops
.
is_fused_moe_enabled
()
)
topk_weights
,
topk_ids
=
topk_func
(
topk_weights
,
topk_ids
,
token_expert_indices
,
gating_output
,
renormalize
)
return
topk_weights
,
topk_ids
,
token_expert_indices
return
topk_weights
,
topk_ids
,
token_expert_indices
else
:
raise
ValueError
(
f
"Unsupported scoring function:
{
scoring_func
}
"
)
class
FusedTopKRouter
(
BaseRouter
):
...
...
@@ -82,7 +123,6 @@ class FusedTopKRouter(BaseRouter):
enable_eplb
:
bool
=
False
,
indices_type_getter
:
Callable
[[],
torch
.
dtype
|
None
]
|
None
=
None
,
):
assert
scoring_func
==
"softmax"
,
"FusedTopKRouter only supports softmax."
super
().
__init__
(
top_k
=
top_k
,
global_num_experts
=
global_num_experts
,
...
...
@@ -91,6 +131,7 @@ class FusedTopKRouter(BaseRouter):
indices_type_getter
=
indices_type_getter
,
)
self
.
renormalize
=
renormalize
self
.
scoring_func
=
scoring_func
@
property
def
routing_method_type
(
self
)
->
RoutingMethodType
:
...
...
@@ -113,6 +154,7 @@ class FusedTopKRouter(BaseRouter):
topk
=
self
.
top_k
,
renormalize
=
self
.
renormalize
,
indices_type
=
indices_type
,
scoring_func
=
self
.
scoring_func
,
)
return
topk_weights
,
topk_ids
vllm/model_executor/layers/fused_moe/router/router_factory.py
View file @
63227acc
...
...
@@ -143,17 +143,13 @@ def create_fused_moe_router(
router
.
capture
=
capture
return
router
if
scoring_func
!=
"softmax"
:
raise
ValueError
(
"Only softmax scoring function is supported for non-grouped topk."
)
if
e_score_correction_bias
is
not
None
:
router
=
FusedTopKBiasRouter
(
top_k
=
top_k
,
global_num_experts
=
global_num_experts
,
eplb_state
=
eplb_state
,
e_score_correction_bias
=
e_score_correction_bias
,
scoring_func
=
scoring_func
,
renormalize
=
renormalize
,
routed_scaling_factor
=
routed_scaling_factor
,
enable_eplb
=
enable_eplb
,
...
...
vllm/model_executor/models/minimax_m2.py
View file @
63227acc
...
...
@@ -100,9 +100,6 @@ class MiniMaxM2MoE(nn.Module):
num_experts
=
config
.
num_local_experts
,
top_k
=
config
.
num_experts_per_tok
,
scoring_func
=
config
.
scoring_func
,
use_grouped_topk
=
True
,
num_expert_group
=
1
,
topk_group
=
1
,
e_score_correction_bias
=
self
.
e_score_correction_bias
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment