Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
63227acc
Unverified
Commit
63227acc
authored
Jan 21, 2026
by
Xin Yang
Committed by
GitHub
Jan 21, 2026
Browse files
[Kernel] Add topk_sigmoid kernel (#31246)
Signed-off-by:
Xin Yang
<
xyangx@amazon.com
>
parent
e675dda6
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
725 additions
and
126 deletions
+725
-126
benchmarks/kernels/benchmark_fused_topk.py
benchmarks/kernels/benchmark_fused_topk.py
+99
-0
csrc/moe/moe_ops.h
csrc/moe/moe_ops.h
+7
-1
csrc/moe/topk_softmax_kernels.cu
csrc/moe/topk_softmax_kernels.cu
+242
-101
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+9
-1
tests/kernels/moe/test_fused_topk.py
tests/kernels/moe/test_fused_topk.py
+137
-0
tests/model_executor/test_enabled_custom_ops.py
tests/model_executor/test_enabled_custom_ops.py
+17
-3
vllm/_aiter_ops.py
vllm/_aiter_ops.py
+39
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+25
-1
vllm/model_executor/layers/fused_moe/config.py
vllm/model_executor/layers/fused_moe/config.py
+2
-2
vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py
...xecutor/layers/fused_moe/router/fused_topk_bias_router.py
+97
-1
vllm/model_executor/layers/fused_moe/router/fused_topk_router.py
...del_executor/layers/fused_moe/router/fused_topk_router.py
+50
-8
vllm/model_executor/layers/fused_moe/router/router_factory.py
.../model_executor/layers/fused_moe/router/router_factory.py
+1
-5
vllm/model_executor/models/minimax_m2.py
vllm/model_executor/models/minimax_m2.py
+0
-3
No files found.
benchmarks/kernels/benchmark_fused_topk.py
0 → 100644
View file @
63227acc
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
itertools
import
torch
from
vllm.model_executor.layers.fused_moe.router.fused_topk_router
import
fused_topk
from
vllm.triton_utils
import
triton
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
# Sweep axes for the benchmark.
# Token counts are powers of four: 1, 4, 16 and 64.
num_tokens_range = [4**exponent for exponent in range(4)]
# Expert counts are the powers of two from 16 up to 512.
num_experts_range = [2**exponent for exponent in range(4, 10)]
# Number of experts selected per token.
topk_range = list(range(3, 5))
# Every (num_tokens, num_experts, topk) combination, with topk varying fastest.
configs = [*itertools.product(num_tokens_range, num_experts_range, topk_range)]
def torch_topk(
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    scoring_func: str = "softmax",
) -> tuple[torch.Tensor, torch.Tensor]:
    """Reference top-k routing written with plain PyTorch ops.

    Args:
        gating_output: Router logits; experts are indexed by the last
            dimension.
        topk: Number of experts to keep per row.
        renormalize: If True, rescale the kept weights so that each row
            sums to one.
        scoring_func: ``"softmax"`` or ``"sigmoid"``.

    Returns:
        ``(topk_weights, topk_ids)`` as returned by ``torch.topk`` over the
        float32 scores (weights optionally renormalized).

    Raises:
        ValueError: If ``scoring_func`` is neither ``"softmax"`` nor
            ``"sigmoid"``.
    """
    # Score in float32 regardless of the input dtype.
    logits = gating_output.float()
    if scoring_func == "softmax":
        scores = torch.softmax(logits, dim=-1)
    elif scoring_func == "sigmoid":
        scores = torch.sigmoid(logits)
    else:
        # Fail loudly instead of silently benchmarking sigmoid for a typo.
        raise ValueError(
            f"Unsupported scoring_func {scoring_func!r}; "
            "expected 'softmax' or 'sigmoid'."
        )

    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
def get_benchmark(scoring_func):
    """Build the Triton perf-report benchmark for one scoring function.

    Args:
        scoring_func: Scoring function name forwarded to both providers
            and embedded in the plot name.

    Returns:
        The ``triton.testing.perf_report``-decorated benchmark; call its
        ``run`` method to execute the sweep over ``configs``.
    """

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["num_tokens", "num_experts", "topk"],
            x_vals=[list(cfg) for cfg in configs],
            line_arg="provider",
            line_vals=["torch", "vllm"],
            line_names=["Torch", "vLLM"],
            styles=[("blue", "-"), ("red", "-")],
            ylabel="us",
            plot_name=f"fused-topk-perf-{scoring_func}",
            args={},
        )
    )
    def benchmark(num_tokens, num_experts, topk, provider):
        """Time one (num_tokens, num_experts, topk) point for ``provider``.

        Returns the median, 20th-percentile and 80th-percentile latency
        in microseconds.
        """
        dtype = torch.bfloat16
        hidden_size = 1024
        renormalize = True

        hidden_states = torch.randn(
            (num_tokens, hidden_size), dtype=dtype, device="cuda"
        )
        gating_output = torch.randn(
            (num_tokens, num_experts), dtype=dtype, device="cuda"
        )

        # Median first, then the low and high quantiles.
        quantiles = [0.5, 0.2, 0.8]

        if provider == "torch":

            def run():
                return torch_topk(
                    gating_output=gating_output,
                    topk=topk,
                    renormalize=renormalize,
                    scoring_func=scoring_func,
                )

        else:

            def run():
                return fused_topk(
                    hidden_states=hidden_states,
                    gating_output=gating_output,
                    topk=topk,
                    renormalize=renormalize,
                    scoring_func=scoring_func,
                )

        ms, min_ms, max_ms = triton.testing.do_bench(run, quantiles=quantiles)

        # perf_report unpacks the result as (value, min, max). The y-axis is a
        # latency, so the low quantile is the minimum and the high quantile is
        # the maximum; the order must not be swapped as it is for throughput.
        return 1000 * ms, 1000 * min_ms, 1000 * max_ms

    return benchmark
if __name__ == "__main__":
    # Command-line entry point: pick the scoring function and output
    # directory, then run the sweep and print the results.
    cli = FlexibleArgumentParser(description="Benchmark the MoE topk kernel.")
    cli.add_argument("--scoring-func", type=str, default="softmax")
    cli.add_argument("--save-path", type=str, default="./configs/fused_topk/")
    options = cli.parse_args()

    # Build the benchmark for the requested scoring function and execute it.
    get_benchmark(options.scoring_func).run(
        print_data=True, save_path=options.save_path
    )
csrc/moe/moe_ops.h
View file @
63227acc
...
...
@@ -4,7 +4,13 @@
void
topk_softmax
(
torch
::
Tensor
&
topk_weights
,
torch
::
Tensor
&
topk_indices
,
torch
::
Tensor
&
token_expert_indices
,
torch
::
Tensor
&
gating_output
,
bool
renormalize
);
torch
::
Tensor
&
gating_output
,
bool
renormalize
,
std
::
optional
<
torch
::
Tensor
>
bias
);
void
topk_sigmoid
(
torch
::
Tensor
&
topk_weights
,
torch
::
Tensor
&
topk_indices
,
torch
::
Tensor
&
token_expert_indices
,
torch
::
Tensor
&
gating_output
,
bool
renormalize
,
std
::
optional
<
torch
::
Tensor
>
bias
);
void
moe_sum
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
output
);
...
...
csrc/moe/topk_softmax_kernels.cu
View file @
63227acc
...
...
@@ -62,6 +62,12 @@ __device__ __forceinline__ float toFloat(T value) {
}
}
// Scoring function applied to the raw gating logits before top-k selection.
// It is passed to the kernels and launchers as a compile-time template
// argument (`SF`) so the scoring branch is chosen with `if constexpr`.
enum ScoringFunc {
  SCORING_SOFTMAX = 0,  // normalize each row with softmax
  SCORING_SIGMOID = 1   // apply an element-wise sigmoid
};
// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing the output
// in the softmax kernel when we extend this module to support expert-choice routing.
...
...
@@ -125,6 +131,27 @@ __launch_bounds__(TPB) __global__
}
}
// Element-wise sigmoid over a row-major [num_rows, num_cols] matrix.
//
// Each block handles the row given by blockIdx.x; its threads walk the
// row's columns with a stride of TPB. Rows flagged in `finished` (when the
// pointer is non-null) are left untouched. Results are written as float.
template <int TPB, typename InputType>
__launch_bounds__(TPB) __global__
    void moeSigmoid(const InputType* input, const bool* finished, float* output, const int num_cols)
{
    // Skip rows that have already finished.
    const bool row_done = (finished != nullptr) && finished[blockIdx.x];
    if (row_done)
    {
        return;
    }

    // Flat offset of the first element of this block's row.
    const int row_base = blockIdx.x * num_cols;

    for (int col = threadIdx.x; col < num_cols; col += TPB)
    {
        const int flat = row_base + col;
        const float x = toFloat(input[flat]);
        // sigmoid(x) = 1 / (1 + exp(-x)), using the __expf intrinsic.
        output[flat] = 1.0f / (1.0f + __expf(-x));
    }
}
template
<
int
TPB
,
typename
IndType
>
__launch_bounds__
(
TPB
)
__global__
void
moeTopK
(
const
float
*
inputs_after_softmax
,
...
...
@@ -136,7 +163,8 @@ __launch_bounds__(TPB) __global__ void moeTopK(
const
int
k
,
const
int
start_expert
,
const
int
end_expert
,
const
bool
renormalize
)
const
bool
renormalize
,
const
float
*
bias
)
{
using
cub_kvp
=
cub
::
KeyValuePair
<
int
,
float
>
;
...
...
@@ -162,7 +190,13 @@ __launch_bounds__(TPB) __global__ void moeTopK(
{
const
int
idx
=
thread_read_offset
+
expert
;
inp_kvp
.
key
=
expert
;
inp_kvp
.
value
=
inputs_after_softmax
[
idx
];
// Apply correction bias if provided
if
(
bias
!=
nullptr
)
{
inp_kvp
.
value
=
inputs_after_softmax
[
idx
]
+
bias
[
expert
];
}
else
{
inp_kvp
.
value
=
inputs_after_softmax
[
idx
];
}
for
(
int
prior_k
=
0
;
prior_k
<
k_idx
;
++
prior_k
)
{
...
...
@@ -186,12 +220,13 @@ __launch_bounds__(TPB) __global__ void moeTopK(
const
bool
should_process_row
=
row_is_active
&&
node_uses_expert
;
const
int
idx
=
k
*
block_row
+
k_idx
;
output
[
idx
]
=
result_kvp
.
value
;
// Return the unbiased scores for output weights
output
[
idx
]
=
inputs_after_softmax
[
thread_read_offset
+
expert
];
indices
[
idx
]
=
should_process_row
?
(
expert
-
start_expert
)
:
num_experts
;
assert
(
indices
[
idx
]
>=
0
);
source_rows
[
idx
]
=
k_idx
*
num_rows
+
block_row
;
if
(
renormalize
)
{
selected_sum
+=
result_kvp
.
value
;
selected_sum
+=
inputs_after_softmax
[
thread_read_offset
+
expert
]
;
}
}
__syncthreads
();
...
...
@@ -225,10 +260,12 @@ __launch_bounds__(TPB) __global__ void moeTopK(
2) This implementation assumes k is small, but will work for any k.
*/
template
<
int
VPT
,
int
NUM_EXPERTS
,
int
WARPS_PER_CTA
,
int
BYTES_PER_LDG
,
int
WARP_SIZE_PARAM
,
typename
IndType
,
typename
InputType
=
float
>
template
<
int
VPT
,
int
NUM_EXPERTS
,
int
WARPS_PER_CTA
,
int
BYTES_PER_LDG
,
int
WARP_SIZE_PARAM
,
typename
IndType
,
typename
InputType
=
float
,
ScoringFunc
SF
>
__launch_bounds__
(
WARPS_PER_CTA
*
WARP_SIZE_PARAM
)
__global__
void
topkGatingSoftmax
(
const
InputType
*
input
,
const
bool
*
finished
,
float
*
output
,
const
int
num_rows
,
IndType
*
indices
,
int
*
source_rows
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
,
const
bool
renormalize
)
void
topkGating
(
const
InputType
*
input
,
const
bool
*
finished
,
float
*
output
,
const
int
num_rows
,
IndType
*
indices
,
int
*
source_rows
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
,
const
bool
renormalize
,
const
float
*
bias
)
{
static_assert
(
std
::
is_same_v
<
InputType
,
float
>
||
std
::
is_same_v
<
InputType
,
__nv_bfloat16
>
||
std
::
is_same_v
<
InputType
,
__half
>
,
...
...
@@ -353,61 +390,89 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
}
}
// First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
//
convert to float afterwards for the exp + sum reduction
.
float
thread_max
=
row_chunk
[
0
];
if
constexpr
(
SF
==
SCORING_SOFTMAX
)
{
//
First, we perform a max reduce within the thread
.
float
thread_max
=
row_chunk
[
0
];
#pragma unroll
for
(
int
ii
=
1
;
ii
<
VPT
;
++
ii
)
{
for
(
int
ii
=
1
;
ii
<
VPT
;
++
ii
)
{
thread_max
=
max
(
thread_max
,
row_chunk
[
ii
]);
}
}
// Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce.
#pragma unroll
for
(
int
mask
=
THREADS_PER_ROW
/
2
;
mask
>
0
;
mask
/=
2
)
{
for
(
int
mask
=
THREADS_PER_ROW
/
2
;
mask
>
0
;
mask
/=
2
)
{
thread_max
=
max
(
thread_max
,
VLLM_SHFL_XOR_SYNC_WIDTH
(
thread_max
,
mask
,
THREADS_PER_ROW
));
}
}
// From this point, thread max in all the threads have the max within the row.
// Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum.
float
row_sum
=
0
;
// From this point, thread max in all the threads have the max within the row.
// Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum.
float
row_sum
=
0
;
#pragma unroll
for
(
int
ii
=
0
;
ii
<
VPT
;
++
ii
)
{
for
(
int
ii
=
0
;
ii
<
VPT
;
++
ii
)
{
row_chunk
[
ii
]
=
expf
(
row_chunk
[
ii
]
-
thread_max
);
row_sum
+=
row_chunk
[
ii
];
}
}
// Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a butterfly pattern.
#pragma unroll
for
(
int
mask
=
THREADS_PER_ROW
/
2
;
mask
>
0
;
mask
/=
2
)
{
for
(
int
mask
=
THREADS_PER_ROW
/
2
;
mask
>
0
;
mask
/=
2
)
{
row_sum
+=
VLLM_SHFL_XOR_SYNC_WIDTH
(
row_sum
,
mask
,
THREADS_PER_ROW
);
}
}
// From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables
// respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to
// compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row.
// However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the
// argmax after computing the softmax.
const
float
reciprocal_row_sum
=
1.
f
/
row_sum
;
// From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables
// respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to
// compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row.
// However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the
// argmax after computing the softmax.
const
float
reciprocal_row_sum
=
1.
f
/
row_sum
;
#pragma unroll
for
(
int
ii
=
0
;
ii
<
VPT
;
++
ii
)
{
for
(
int
ii
=
0
;
ii
<
VPT
;
++
ii
)
{
row_chunk
[
ii
]
=
row_chunk
[
ii
]
*
reciprocal_row_sum
;
}
}
else
if
constexpr
(
SF
==
SCORING_SIGMOID
)
{
#pragma unroll
for
(
int
ii
=
0
;
ii
<
VPT
;
++
ii
)
{
row_chunk
[
ii
]
=
1.0
f
/
(
1.0
f
+
__expf
(
-
row_chunk
[
ii
]));
}
}
static
constexpr
int
COLS_PER_GROUP_LDG
=
ELTS_PER_LDG
*
THREADS_PER_ROW
;
// If bias is not null, use biased value for selection
float
row_chunk_for_choice
[
VPT
];
// Apply correction bias
if
(
bias
!=
nullptr
)
{
#pragma unroll
for
(
int
ldg
=
0
;
ldg
<
LDG_PER_THREAD
;
++
ldg
)
{
#pragma unroll
for
(
int
ii
=
0
;
ii
<
ELTS_PER_LDG
;
++
ii
)
{
const
int
expert
=
first_elt_read_by_thread
+
ldg
*
COLS_PER_GROUP_LDG
+
ii
;
float
bias_val
=
expert
<
NUM_EXPERTS
?
bias
[
expert
]
:
0.0
f
;
row_chunk_for_choice
[
ldg
*
ELTS_PER_LDG
+
ii
]
=
row_chunk
[
ldg
*
ELTS_PER_LDG
+
ii
]
+
bias_val
;
}
}
}
else
{
#pragma unroll
for
(
int
ii
=
0
;
ii
<
VPT
;
++
ii
)
{
row_chunk_for_choice
[
ii
]
=
row_chunk
[
ii
];
}
}
// Now,
softmax_res
contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along
// Now,
row_chunk
contains the softmax
/ sigmoid
of the row chunk. Now, I want to find the topk elements in each row, along
// with the max index.
int
start_col
=
first_elt_read_by_thread
;
static
constexpr
int
COLS_PER_GROUP_LDG
=
ELTS_PER_LDG
*
THREADS_PER_ROW
;
float
selected_sum
=
0.
f
;
for
(
int
k_idx
=
0
;
k_idx
<
k
;
++
k_idx
)
{
// First, each thread does the local argmax
float
max_val_for_choice
=
row_chunk_for_choice
[
0
];
float
max_val
=
row_chunk
[
0
];
int
expert
=
start_col
;
#pragma unroll
...
...
@@ -416,12 +481,14 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
#pragma unroll
for
(
int
ii
=
0
;
ii
<
ELTS_PER_LDG
;
++
ii
)
{
float
val_for_choice
=
row_chunk_for_choice
[
ldg
*
ELTS_PER_LDG
+
ii
];
float
val
=
row_chunk
[
ldg
*
ELTS_PER_LDG
+
ii
];
// No check on the experts here since columns with the smallest index are processed first and only
// updated if > (not >=)
if
(
val
>
max_val
)
if
(
val
_for_choice
>
max_val_for_choice
)
{
max_val_for_choice
=
val_for_choice
;
max_val
=
val
;
expert
=
col
+
ii
;
}
...
...
@@ -434,12 +501,14 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
#pragma unroll
for
(
int
mask
=
THREADS_PER_ROW
/
2
;
mask
>
0
;
mask
/=
2
)
{
float
other_max_for_choice
=
VLLM_SHFL_XOR_SYNC_WIDTH
(
max_val_for_choice
,
mask
,
THREADS_PER_ROW
);
float
other_max
=
VLLM_SHFL_XOR_SYNC_WIDTH
(
max_val
,
mask
,
THREADS_PER_ROW
);
int
other_expert
=
VLLM_SHFL_XOR_SYNC_WIDTH
(
expert
,
mask
,
THREADS_PER_ROW
);
// We want lower indices to "win" in every thread so we break ties this way
if
(
other_max
>
max_val
||
(
other_max
==
max_val
&&
other_expert
<
expert
))
if
(
other_max
_for_choice
>
max_val_for_choice
||
(
other_max_for_choice
==
max_val_for_choice
&&
other_expert
<
expert
))
{
max_val_for_choice
=
other_max_for_choice
;
max_val
=
other_max
;
expert
=
other_expert
;
}
...
...
@@ -474,7 +543,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
{
const
int
offset_for_expert
=
expert
%
ELTS_PER_LDG
;
// Safe to set to any negative value since row_chunk values must be between 0 and 1.
row_chunk
[
ldg_group_for_expert
*
ELTS_PER_LDG
+
offset_for_expert
]
=
-
10000.
f
;
row_chunk
_for_choice
[
ldg_group_for_expert
*
ELTS_PER_LDG
+
offset_for_expert
]
=
-
10000.
f
;
}
}
}
...
...
@@ -508,10 +577,10 @@ struct TopkConstants
};
}
// namespace detail
template
<
int
EXPERTS
,
int
WARPS_PER_TB
,
int
WARP_SIZE_PARAM
,
int
MAX_BYTES_PER_LDG
,
typename
IndType
,
typename
InputType
>
void
topkGating
Softmax
LauncherHelper
(
const
InputType
*
input
,
const
bool
*
finished
,
float
*
output
,
IndType
*
indices
,
template
<
int
EXPERTS
,
int
WARPS_PER_TB
,
int
WARP_SIZE_PARAM
,
int
MAX_BYTES_PER_LDG
,
typename
IndType
,
typename
InputType
,
ScoringFunc
SF
>
void
topkGatingLauncherHelper
(
const
InputType
*
input
,
const
bool
*
finished
,
float
*
output
,
IndType
*
indices
,
int
*
source_row
,
const
int
num_rows
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
,
const
bool
renormalize
,
cudaStream_t
stream
)
const
float
*
bias
,
cudaStream_t
stream
)
{
static
constexpr
int
BYTES_PER_LDG
=
MIN
(
MAX_BYTES_PER_LDG
,
sizeof
(
InputType
)
*
EXPERTS
);
using
Constants
=
detail
::
TopkConstants
<
EXPERTS
,
BYTES_PER_LDG
,
WARP_SIZE_PARAM
,
InputType
>
;
...
...
@@ -521,43 +590,51 @@ void topkGatingSoftmaxLauncherHelper(const InputType* input, const bool* finishe
const
int
num_blocks
=
(
num_warps
+
WARPS_PER_TB
-
1
)
/
WARPS_PER_TB
;
dim3
block_dim
(
WARP_SIZE_PARAM
,
WARPS_PER_TB
);
topkGating
Softmax
<
VPT
,
EXPERTS
,
WARPS_PER_TB
,
BYTES_PER_LDG
,
WARP_SIZE_PARAM
,
IndType
,
InputType
><<<
num_blocks
,
block_dim
,
0
,
stream
>>>
(
input
,
finished
,
output
,
num_rows
,
indices
,
source_row
,
k
,
start_expert
,
end_expert
,
renormalize
);
topkGating
<
VPT
,
EXPERTS
,
WARPS_PER_TB
,
BYTES_PER_LDG
,
WARP_SIZE_PARAM
,
IndType
,
InputType
,
SF
><<<
num_blocks
,
block_dim
,
0
,
stream
>>>
(
input
,
finished
,
output
,
num_rows
,
indices
,
source_row
,
k
,
start_expert
,
end_expert
,
renormalize
,
bias
);
}
#ifndef USE_ROCM
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \
static_assert(WARP_SIZE == 32, \
"Unsupported warp size. Only 32 is supported for CUDA"); \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, WARP_SIZE, MAX_BYTES>( \
gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \
num_tokens, topk, 0, num_experts, renormalize, stream);
#define LAUNCH_TOPK(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \
static_assert(WARP_SIZE == 32, \
"Unsupported warp size. Only 32 is supported for CUDA"); \
topkGatingLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, WARP_SIZE, MAX_BYTES, \
IndType, InputType, SF>( \
gating_output, nullptr, topk_weights, topk_indices, \
token_expert_indices, num_tokens, topk, 0, num_experts, renormalize, \
bias, stream);
#else
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \
if (WARP_SIZE == 64) { \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64, MAX_BYTES>( \
gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \
num_tokens, topk, 0, num_experts, renormalize, stream); \
} else if (WARP_SIZE == 32) { \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32, MAX_BYTES>( \
gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \
num_tokens, topk, 0, num_experts, renormalize, stream); \
} else { \
assert(false && "Unsupported warp size. Only 32 and 64 are supported for ROCm"); \
#define LAUNCH_TOPK(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \
if (WARP_SIZE == 64) { \
topkGatingLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64, MAX_BYTES, \
IndType, InputType, SF>( \
gating_output, nullptr, topk_weights, topk_indices, \
token_expert_indices, num_tokens, topk, 0, num_experts, renormalize, \
bias, stream); \
} else if (WARP_SIZE == 32) { \
topkGatingLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32, MAX_BYTES, \
IndType, InputType, SF>( \
gating_output, nullptr, topk_weights, topk_indices, \
token_expert_indices, num_tokens, topk, 0, num_experts, renormalize, \
bias, stream); \
} else { \
assert(false && \
"Unsupported warp size. Only 32 and 64 are supported for ROCm"); \
}
#endif
template
<
typename
IndType
,
typename
InputType
>
void
topkGating
Softmax
KernelLauncher
(
template
<
typename
IndType
,
typename
InputType
,
ScoringFunc
SF
>
void
topkGatingKernelLauncher
(
const
InputType
*
gating_output
,
float
*
topk_weights
,
IndType
*
topk_indices
,
int
*
token_expert_indices
,
float
*
softmax_
workspace
,
float
*
workspace
,
const
int
num_tokens
,
const
int
num_experts
,
const
int
topk
,
const
bool
renormalize
,
const
float
*
bias
,
cudaStream_t
stream
)
{
static
constexpr
int
WARPS_PER_TB
=
4
;
static
constexpr
int
BYTES_PER_LDG_POWER_OF_2
=
16
;
...
...
@@ -569,64 +646,71 @@ void topkGatingSoftmaxKernelLauncher(
#endif
switch
(
num_experts
)
{
case
1
:
LAUNCH_
SOFTMAX
(
1
,
WARPS_PER_TB
,
BYTES_PER_LDG_POWER_OF_2
);
LAUNCH_
TOPK
(
1
,
WARPS_PER_TB
,
BYTES_PER_LDG_POWER_OF_2
);
break
;
case
2
:
LAUNCH_
SOFTMAX
(
2
,
WARPS_PER_TB
,
BYTES_PER_LDG_POWER_OF_2
);
LAUNCH_
TOPK
(
2
,
WARPS_PER_TB
,
BYTES_PER_LDG_POWER_OF_2
);
break
;
case
4
:
LAUNCH_
SOFTMAX
(
4
,
WARPS_PER_TB
,
BYTES_PER_LDG_POWER_OF_2
);
LAUNCH_
TOPK
(
4
,
WARPS_PER_TB
,
BYTES_PER_LDG_POWER_OF_2
);
break
;
case
8
:
LAUNCH_
SOFTMAX
(
8
,
WARPS_PER_TB
,
BYTES_PER_LDG_POWER_OF_2
);
LAUNCH_
TOPK
(
8
,
WARPS_PER_TB
,
BYTES_PER_LDG_POWER_OF_2
);
break
;
case
16
:
LAUNCH_
SOFTMAX
(
16
,
WARPS_PER_TB
,
BYTES_PER_LDG_POWER_OF_2
);
LAUNCH_
TOPK
(
16
,
WARPS_PER_TB
,
BYTES_PER_LDG_POWER_OF_2
);
break
;
case
32
:
LAUNCH_
SOFTMAX
(
32
,
WARPS_PER_TB
,
BYTES_PER_LDG_POWER_OF_2
);
LAUNCH_
TOPK
(
32
,
WARPS_PER_TB
,
BYTES_PER_LDG_POWER_OF_2
);
break
;
case
64
:
LAUNCH_
SOFTMAX
(
64
,
WARPS_PER_TB
,
BYTES_PER_LDG_POWER_OF_2
);
LAUNCH_
TOPK
(
64
,
WARPS_PER_TB
,
BYTES_PER_LDG_POWER_OF_2
);
break
;
case
128
:
LAUNCH_
SOFTMAX
(
128
,
WARPS_PER_TB
,
BYTES_PER_LDG_POWER_OF_2
);
LAUNCH_
TOPK
(
128
,
WARPS_PER_TB
,
BYTES_PER_LDG_POWER_OF_2
);
break
;
case
256
:
LAUNCH_
SOFTMAX
(
256
,
WARPS_PER_TB
,
BYTES_PER_LDG_POWER_OF_2
);
LAUNCH_
TOPK
(
256
,
WARPS_PER_TB
,
BYTES_PER_LDG_POWER_OF_2
);
break
;
case
512
:
LAUNCH_
SOFTMAX
(
512
,
WARPS_PER_TB
,
BYTES_PER_LDG_POWER_OF_2
);
LAUNCH_
TOPK
(
512
,
WARPS_PER_TB
,
BYTES_PER_LDG_POWER_OF_2
);
break
;
// (CUDA only) support multiples of 64 when num_experts is not power of 2.
// ROCm uses WARP_SIZE 64 so 8 bytes loading won't fit for some of num_experts,
// alternatively we can test 4 bytes loading and enable it in future.
#ifndef USE_ROCM
case
192
:
LAUNCH_
SOFTMAX
(
192
,
WARPS_PER_TB
,
BYTES_PER_LDG_MULTIPLE_64
);
LAUNCH_
TOPK
(
192
,
WARPS_PER_TB
,
BYTES_PER_LDG_MULTIPLE_64
);
break
;
case
320
:
LAUNCH_
SOFTMAX
(
320
,
WARPS_PER_TB
,
BYTES_PER_LDG_MULTIPLE_64
);
LAUNCH_
TOPK
(
320
,
WARPS_PER_TB
,
BYTES_PER_LDG_MULTIPLE_64
);
break
;
case
384
:
LAUNCH_
SOFTMAX
(
384
,
WARPS_PER_TB
,
BYTES_PER_LDG_MULTIPLE_64
);
LAUNCH_
TOPK
(
384
,
WARPS_PER_TB
,
BYTES_PER_LDG_MULTIPLE_64
);
break
;
case
448
:
LAUNCH_
SOFTMAX
(
448
,
WARPS_PER_TB
,
BYTES_PER_LDG_MULTIPLE_64
);
LAUNCH_
TOPK
(
448
,
WARPS_PER_TB
,
BYTES_PER_LDG_MULTIPLE_64
);
break
;
case
576
:
LAUNCH_
SOFTMAX
(
576
,
WARPS_PER_TB
,
BYTES_PER_LDG_MULTIPLE_64
);
LAUNCH_
TOPK
(
576
,
WARPS_PER_TB
,
BYTES_PER_LDG_MULTIPLE_64
);
break
;
#endif
default:
{
TORCH_CHECK
(
softmax_
workspace
!=
nullptr
,
"
softmax_
workspace must be provided for num_experts that are not a power of 2 or multiple of 64."
);
TORCH_CHECK
(
workspace
!=
nullptr
,
"workspace must be provided for num_experts that are not a power of 2 or multiple of 64."
);
static
constexpr
int
TPB
=
256
;
moeSoftmax
<
TPB
,
InputType
><<<
num_tokens
,
TPB
,
0
,
stream
>>>
(
gating_output
,
nullptr
,
softmax_workspace
,
num_experts
);
if
constexpr
(
SF
==
SCORING_SOFTMAX
)
{
moeSoftmax
<
TPB
,
InputType
><<<
num_tokens
,
TPB
,
0
,
stream
>>>
(
gating_output
,
nullptr
,
workspace
,
num_experts
);
}
else
if
constexpr
(
SF
==
SCORING_SIGMOID
)
{
moeSigmoid
<
TPB
,
InputType
><<<
num_tokens
,
TPB
,
0
,
stream
>>>
(
gating_output
,
nullptr
,
workspace
,
num_experts
);
}
else
{
TORCH_CHECK
(
false
,
"Unsupported scoring func"
);
}
moeTopK
<
TPB
><<<
num_tokens
,
TPB
,
0
,
stream
>>>
(
softmax_
workspace
,
nullptr
,
topk_weights
,
topk_indices
,
token_expert_indices
,
num_experts
,
topk
,
0
,
num_experts
,
renormalize
);
workspace
,
nullptr
,
topk_weights
,
topk_indices
,
token_expert_indices
,
num_experts
,
topk
,
0
,
num_experts
,
renormalize
,
bias
);
}
}
}
...
...
@@ -635,40 +719,55 @@ void topkGatingSoftmaxKernelLauncher(
}
// namespace vllm
template
<
typename
ComputeType
>
void
dispatch_topk_
softmax_
launch
(
template
<
typename
ComputeType
,
vllm
::
moe
::
ScoringFunc
SF
>
void
dispatch_topk_launch
(
torch
::
Tensor
&
gating_output
,
torch
::
Tensor
&
topk_weights
,
torch
::
Tensor
&
topk_indices
,
torch
::
Tensor
&
token_expert_indices
,
torch
::
Tensor
&
softmax_workspace
,
int
num_tokens
,
int
num_experts
,
int
topk
,
bool
renormalize
,
cudaStream_t
stream
)
{
int
num_tokens
,
int
num_experts
,
int
topk
,
bool
renormalize
,
std
::
optional
<
torch
::
Tensor
>
bias
,
cudaStream_t
stream
)
{
const
float
*
bias_ptr
=
nullptr
;
if
(
bias
.
has_value
())
{
const
torch
::
Tensor
&
bias_tensor
=
bias
.
value
();
TORCH_CHECK
(
bias_tensor
.
scalar_type
()
==
at
::
ScalarType
::
Float
,
"bias tensor must be float32"
);
TORCH_CHECK
(
bias_tensor
.
dim
()
==
1
,
"bias tensor must be 1D"
);
TORCH_CHECK
(
bias_tensor
.
size
(
0
)
==
num_experts
,
"bias size mismatch, expected: "
,
num_experts
);
TORCH_CHECK
(
bias_tensor
.
is_contiguous
(),
"bias tensor must be contiguous"
);
bias_ptr
=
bias_tensor
.
data_ptr
<
float
>
();
}
if
(
topk_indices
.
scalar_type
()
==
at
::
ScalarType
::
Int
)
{
vllm
::
moe
::
topkGating
Softmax
KernelLauncher
<
int
,
ComputeType
>
(
vllm
::
moe
::
topkGatingKernelLauncher
<
int
,
ComputeType
,
SF
>
(
reinterpret_cast
<
const
ComputeType
*>
(
gating_output
.
data_ptr
()),
topk_weights
.
data_ptr
<
float
>
(),
topk_indices
.
data_ptr
<
int
>
(),
token_expert_indices
.
data_ptr
<
int
>
(),
softmax_workspace
.
data_ptr
<
float
>
(),
num_tokens
,
num_experts
,
topk
,
renormalize
,
stream
);
num_tokens
,
num_experts
,
topk
,
renormalize
,
bias_ptr
,
stream
);
}
else
if
(
topk_indices
.
scalar_type
()
==
at
::
ScalarType
::
UInt32
)
{
vllm
::
moe
::
topkGating
Softmax
KernelLauncher
<
uint32_t
,
ComputeType
>
(
vllm
::
moe
::
topkGatingKernelLauncher
<
uint32_t
,
ComputeType
,
SF
>
(
reinterpret_cast
<
const
ComputeType
*>
(
gating_output
.
data_ptr
()),
topk_weights
.
data_ptr
<
float
>
(),
topk_indices
.
data_ptr
<
uint32_t
>
(),
token_expert_indices
.
data_ptr
<
int
>
(),
softmax_workspace
.
data_ptr
<
float
>
(),
num_tokens
,
num_experts
,
topk
,
renormalize
,
stream
);
num_tokens
,
num_experts
,
topk
,
renormalize
,
bias_ptr
,
stream
);
}
else
{
TORCH_CHECK
(
topk_indices
.
scalar_type
()
==
at
::
ScalarType
::
Long
);
vllm
::
moe
::
topkGating
Softmax
KernelLauncher
<
int64_t
,
ComputeType
>
(
vllm
::
moe
::
topkGatingKernelLauncher
<
int64_t
,
ComputeType
,
SF
>
(
reinterpret_cast
<
const
ComputeType
*>
(
gating_output
.
data_ptr
()),
topk_weights
.
data_ptr
<
float
>
(),
topk_indices
.
data_ptr
<
int64_t
>
(),
token_expert_indices
.
data_ptr
<
int
>
(),
softmax_workspace
.
data_ptr
<
float
>
(),
num_tokens
,
num_experts
,
topk
,
renormalize
,
stream
);
num_tokens
,
num_experts
,
topk
,
renormalize
,
bias_ptr
,
stream
);
}
}
...
...
@@ -677,7 +776,8 @@ void topk_softmax(
torch
::
Tensor
&
topk_indices
,
// [num_tokens, topk]
torch
::
Tensor
&
token_expert_indices
,
// [num_tokens, topk]
torch
::
Tensor
&
gating_output
,
// [num_tokens, num_experts]
bool
renormalize
)
bool
renormalize
,
std
::
optional
<
torch
::
Tensor
>
bias
)
{
const
int
num_experts
=
gating_output
.
size
(
-
1
);
const
auto
num_tokens
=
gating_output
.
numel
()
/
num_experts
;
...
...
@@ -693,14 +793,55 @@ void topk_softmax(
torch
::
Tensor
softmax_workspace
=
torch
::
empty
({
workspace_size
},
workspace_options
);
if
(
gating_output
.
scalar_type
()
==
at
::
ScalarType
::
Float
)
{
dispatch_topk_softmax_launch
<
float
>
(
gating_output
,
topk_weights
,
topk_indices
,
token_expert_indices
,
softmax_workspace
,
num_tokens
,
num_experts
,
topk
,
renormalize
,
stream
);
dispatch_topk_launch
<
float
,
vllm
::
moe
::
SCORING_SOFTMAX
>
(
gating_output
,
topk_weights
,
topk_indices
,
token_expert_indices
,
softmax_workspace
,
num_tokens
,
num_experts
,
topk
,
renormalize
,
bias
,
stream
);
}
else
if
(
gating_output
.
scalar_type
()
==
at
::
ScalarType
::
Half
)
{
dispatch_topk_launch
<
__half
,
vllm
::
moe
::
SCORING_SOFTMAX
>
(
gating_output
,
topk_weights
,
topk_indices
,
token_expert_indices
,
softmax_workspace
,
num_tokens
,
num_experts
,
topk
,
renormalize
,
bias
,
stream
);
}
else
if
(
gating_output
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
{
dispatch_topk_launch
<
__nv_bfloat16
,
vllm
::
moe
::
SCORING_SOFTMAX
>
(
gating_output
,
topk_weights
,
topk_indices
,
token_expert_indices
,
softmax_workspace
,
num_tokens
,
num_experts
,
topk
,
renormalize
,
bias
,
stream
);
}
else
{
TORCH_CHECK
(
false
,
"Unsupported gating_output data type: "
,
gating_output
.
scalar_type
());
}
}
void
topk_sigmoid
(
torch
::
Tensor
&
topk_weights
,
// [num_tokens, topk]
torch
::
Tensor
&
topk_indices
,
// [num_tokens, topk]
torch
::
Tensor
&
token_expert_indices
,
// [num_tokens, topk]
torch
::
Tensor
&
gating_output
,
// [num_tokens, num_experts]
bool
renormalize
,
std
::
optional
<
torch
::
Tensor
>
bias
)
{
const
int
num_experts
=
gating_output
.
size
(
-
1
);
const
auto
num_tokens
=
gating_output
.
numel
()
/
num_experts
;
const
int
topk
=
topk_weights
.
size
(
-
1
);
const
bool
is_pow_2
=
(
num_experts
!=
0
)
&&
((
num_experts
&
(
num_experts
-
1
))
==
0
);
const
bool
needs_workspace
=
!
is_pow_2
||
num_experts
>
256
;
const
int64_t
workspace_size
=
needs_workspace
?
num_tokens
*
num_experts
:
0
;
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
gating_output
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
auto
workspace_options
=
gating_output
.
options
().
dtype
(
at
::
ScalarType
::
Float
);
torch
::
Tensor
workspace
=
torch
::
empty
({
workspace_size
},
workspace_options
);
if
(
gating_output
.
scalar_type
()
==
at
::
ScalarType
::
Float
)
{
dispatch_topk_launch
<
float
,
vllm
::
moe
::
SCORING_SIGMOID
>
(
gating_output
,
topk_weights
,
topk_indices
,
token_expert_indices
,
workspace
,
num_tokens
,
num_experts
,
topk
,
renormalize
,
bias
,
stream
);
}
else
if
(
gating_output
.
scalar_type
()
==
at
::
ScalarType
::
Half
)
{
dispatch_topk_softmax_launch
<
__half
>
(
gating_output
,
topk_weights
,
topk_indices
,
token_expert_indices
,
softmax_workspace
,
num_tokens
,
num_experts
,
topk
,
renormalize
,
stream
);
dispatch_topk_launch
<
__half
,
vllm
::
moe
::
SCORING_SIGMOID
>
(
gating_output
,
topk_weights
,
topk_indices
,
token_expert_indices
,
workspace
,
num_tokens
,
num_experts
,
topk
,
renormalize
,
bias
,
stream
);
}
else
if
(
gating_output
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
{
dispatch_topk_softmax_launch
<
__nv_bfloat16
>
(
gating_output
,
topk_weights
,
topk_indices
,
token_expert_indices
,
softmax_workspace
,
num_tokens
,
num_experts
,
topk
,
renormalize
,
stream
);
dispatch_topk_launch
<
__nv_bfloat16
,
vllm
::
moe
::
SCORING_SIGMOID
>
(
gating_output
,
topk_weights
,
topk_indices
,
token_expert_indices
,
workspace
,
num_tokens
,
num_experts
,
topk
,
renormalize
,
bias
,
stream
);
}
else
{
TORCH_CHECK
(
false
,
"Unsupported gating_output data type: "
,
gating_output
.
scalar_type
());
}
...
...
csrc/moe/torch_bindings.cpp
View file @
63227acc
...
...
@@ -5,9 +5,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Apply topk softmax to the gating outputs.
m
.
def
(
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output, bool renormalize) -> ()"
);
"token_expert_indices, Tensor gating_output, bool renormalize, Tensor? "
"bias) -> ()"
);
m
.
impl
(
"topk_softmax"
,
torch
::
kCUDA
,
&
topk_softmax
);
// Apply topk sigmoid to the gating outputs.
m
.
def
(
"topk_sigmoid(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output, bool renormalize, Tensor? "
"bias) -> ()"
);
m
.
impl
(
"topk_sigmoid"
,
torch
::
kCUDA
,
&
topk_sigmoid
);
// Calculate the result of moe by summing up the partial results
// from all selected experts.
m
.
def
(
"moe_sum(Tensor input, Tensor! output) -> ()"
);
...
...
tests/kernels/moe/test_fused_topk.py
0 → 100644
View file @
63227acc
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the MoE fused topk kernel
Run `pytest tests/kernels/moe/test_fused_topk.py`.
"""
import
pytest
import
torch
from
vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router
import
(
fused_topk_bias
,
)
from
vllm.model_executor.layers.fused_moe.router.fused_topk_router
import
fused_topk
from
vllm.platforms
import
current_platform
def
torch_topk
(
gating_output
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
e_score_correction_bias
:
torch
.
Tensor
=
None
,
scoring_func
:
str
=
"softmax"
,
):
if
scoring_func
==
"softmax"
:
scores
=
torch
.
softmax
(
gating_output
.
float
(),
dim
=-
1
)
else
:
assert
scoring_func
==
"sigmoid"
scores
=
torch
.
sigmoid
(
gating_output
.
float
())
if
e_score_correction_bias
is
not
None
:
num_experts
=
gating_output
.
shape
[
-
1
]
scores_for_choice
=
scores
.
view
(
-
1
,
num_experts
)
+
e_score_correction_bias
.
unsqueeze
(
0
)
_
,
topk_ids
=
torch
.
topk
(
scores_for_choice
,
k
=
topk
,
dim
=-
1
)
topk_weights
=
scores
.
gather
(
1
,
topk_ids
)
else
:
topk_weights
,
topk_ids
=
torch
.
topk
(
scores
,
k
=
topk
,
dim
=-
1
)
if
renormalize
:
topk_weights
=
topk_weights
/
topk_weights
.
sum
(
dim
=-
1
,
keepdim
=
True
)
return
topk_weights
,
topk_ids
@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
@pytest.mark.parametrize("num_tokens", [1, 33, 56])
@pytest.mark.parametrize("hidden_size", [1024, 2048])
@pytest.mark.parametrize("num_experts", [6, 16])
@pytest.mark.parametrize("topk", [3, 4])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32])
def test_fused_topk(
    num_tokens: int,
    hidden_size: int,
    num_experts: int,
    topk: int,
    renormalize: bool,
    scoring_func: str,
    dtype: torch.dtype,
):
    """Compare `fused_topk` against the PyTorch reference `torch_topk`."""
    torch.manual_seed(0)

    # Random activations and router logits on the GPU.
    hidden_states = torch.randn(
        (num_tokens, hidden_size), dtype=dtype, device="cuda"
    )
    gating_output = torch.randn(
        (num_tokens, num_experts), dtype=dtype, device="cuda"
    )

    expected_weights, expected_ids = torch_topk(
        gating_output=gating_output,
        topk=topk,
        renormalize=renormalize,
        scoring_func=scoring_func,
    )

    # The third returned value (token_expert_indices) is not checked here.
    actual_weights, actual_ids, _ = fused_topk(
        hidden_states=hidden_states,
        gating_output=gating_output,
        topk=topk,
        renormalize=renormalize,
        scoring_func=scoring_func,
    )

    # Weights are compared with a loose tolerance; ids must match exactly.
    torch.testing.assert_close(
        expected_weights.to(torch.float32), actual_weights, atol=1e-2, rtol=1e-2
    )
    torch.testing.assert_close(
        expected_ids.to(torch.int32), actual_ids, atol=0, rtol=0
    )
@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
@pytest.mark.parametrize("num_tokens", [1, 33, 56])
@pytest.mark.parametrize("hidden_size", [1024, 2048])
@pytest.mark.parametrize("num_experts", [6, 16])
@pytest.mark.parametrize("topk", [3, 4])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32])
def test_fused_topk_bias(
    num_tokens: int,
    hidden_size: int,
    num_experts: int,
    topk: int,
    renormalize: bool,
    scoring_func: str,
    dtype: torch.dtype,
):
    """Compare `fused_topk_bias` against the PyTorch reference `torch_topk`."""
    torch.manual_seed(0)

    # Random activations, router logits and a float32 per-expert bias.
    hidden_states = torch.randn(
        (num_tokens, hidden_size), dtype=dtype, device="cuda"
    )
    gating_output = torch.randn(
        (num_tokens, num_experts), dtype=dtype, device="cuda"
    )
    e_score_correction_bias = torch.randn(
        (num_experts,), dtype=torch.float32, device="cuda"
    )

    expected_weights, expected_ids = torch_topk(
        gating_output=gating_output,
        topk=topk,
        renormalize=renormalize,
        e_score_correction_bias=e_score_correction_bias,
        scoring_func=scoring_func,
    )

    actual_weights, actual_ids = fused_topk_bias(
        hidden_states=hidden_states,
        gating_output=gating_output,
        e_score_correction_bias=e_score_correction_bias,
        topk=topk,
        renormalize=renormalize,
        scoring_func=scoring_func,
    )

    # Weights are compared with a loose tolerance; ids must match exactly.
    torch.testing.assert_close(
        expected_weights.to(torch.float32), actual_weights, atol=1e-2, rtol=1e-2
    )
    torch.testing.assert_close(
        expected_ids.to(torch.int32), actual_ids, atol=0, rtol=0
    )
tests/model_executor/test_enabled_custom_ops.py
View file @
63227acc
...
...
@@ -18,7 +18,9 @@ from vllm.model_executor.layers.activation import (
SiluAndMul
,
)
from
vllm.model_executor.layers.fused_moe.router.fused_topk_router
import
(
dispatch_topk_func
,
dispatch_topk_sigmoid_func
,
dispatch_topk_softmax_func
,
vllm_topk_sigmoid
,
vllm_topk_softmax
,
)
from
vllm.model_executor.layers.layernorm
import
(
...
...
@@ -133,8 +135,8 @@ def test_enabled_ops_invalid(env: str):
@
pytest
.
mark
.
parametrize
(
"use_rocm_aiter"
,
[
True
,
False
]
if
current_platform
.
is_rocm
()
else
[
False
]
)
def
test_topk_dispatch
(
use_rocm_aiter
:
bool
):
topk_func
=
dispatch_topk_func
(
use_rocm_aiter
)
def
test_topk_
softmax_
dispatch
(
use_rocm_aiter
:
bool
):
topk_func
=
dispatch_topk_
softmax_
func
(
use_rocm_aiter
)
if
current_platform
.
is_rocm
()
and
use_rocm_aiter
:
assert
topk_func
==
rocm_aiter_ops
.
topk_softmax
...
...
@@ -142,6 +144,18 @@ def test_topk_dispatch(use_rocm_aiter: bool):
assert
topk_func
==
vllm_topk_softmax
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
def test_topk_sigmoid_dispatch(use_rocm_aiter: bool):
    """Check which top-k sigmoid implementation the dispatcher returns."""
    selected = dispatch_topk_sigmoid_func(use_rocm_aiter)

    # The AITER implementation is only expected on ROCm with the flag set.
    expect_aiter = current_platform.is_rocm() and use_rocm_aiter
    expected = rocm_aiter_ops.topk_sigmoid if expect_aiter else vllm_topk_sigmoid

    assert selected == expected
@
pytest
.
mark
.
parametrize
(
"add_residual"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"use_rocm_aiter"
,
[
True
,
False
])
...
...
vllm/_aiter_ops.py
View file @
63227acc
...
...
@@ -200,6 +200,24 @@ def _rocm_aiter_topk_softmax_fake(
pass
def _rocm_aiter_topk_sigmoid_impl(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    gating_output: torch.Tensor,
) -> None:
    """Forward the three tensors to AITER's `topk_sigmoid`.

    Returns nothing; `topk_weights` and `topk_indices` are the arguments
    declared as mutated when this function is registered as a custom op.
    """
    # Imported at call time rather than at module import
    # (presumably so this module loads without aiter installed - confirm).
    from aiter import topk_sigmoid as aiter_topk_sigmoid

    aiter_topk_sigmoid(topk_weights, topk_indices, gating_output)
def
_rocm_aiter_topk_sigmoid_fake
(
topk_weights
:
torch
.
Tensor
,
topk_indices
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
)
->
None
:
pass
def
_rocm_aiter_biased_grouped_topk_impl
(
gating_output
:
torch
.
Tensor
,
correction_bias
:
torch
.
Tensor
,
...
...
@@ -985,6 +1003,14 @@ class rocm_aiter_ops:
dispatch_key
=
current_platform
.
dispatch_key
,
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_topk_sigmoid"
,
op_func
=
_rocm_aiter_topk_sigmoid_impl
,
mutates_args
=
[
"topk_weights"
,
"topk_indices"
],
fake_impl
=
_rocm_aiter_topk_sigmoid_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_biased_grouped_topk"
,
op_func
=
_rocm_aiter_biased_grouped_topk_impl
,
...
...
@@ -1272,6 +1298,19 @@ class rocm_aiter_ops:
)
return
topk_weights
,
topk_indices
@staticmethod
def topk_sigmoid(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool,
) -> tuple[torch.Tensor, ...]:
    """Run the AITER top-k sigmoid custom op and return its outputs.

    Args:
        topk_weights: Output buffer for the selected weights (written in
            place by the op).
        topk_indices: Output buffer for the selected expert ids (written
            in place by the op).
        token_expert_indices: Accepted so the signature matches the other
            top-k implementations; not forwarded to the op.
        gating_output: Router logits.
        renormalize: If True, rescale each row of ``topk_weights`` so it
            sums to one.

    Returns:
        ``(topk_weights, topk_indices)`` - the same tensors passed in.
    """
    torch.ops.vllm.rocm_aiter_topk_sigmoid(
        topk_weights, topk_indices, gating_output
    )
    if renormalize:
        # The registered op takes no renormalize flag, so the request was
        # previously dropped. Normalizing here is a no-op (up to rounding)
        # if the rows already sum to one, and sigmoid weights are strictly
        # positive so the divisor is never zero.
        topk_weights.div_(topk_weights.sum(dim=-1, keepdim=True))
    return topk_weights, topk_indices
@
staticmethod
def
biased_grouped_topk
(
gating_output
:
torch
.
Tensor
,
...
...
vllm/_custom_ops.py
View file @
63227acc
...
...
@@ -2177,9 +2177,33 @@ def topk_softmax(
token_expert_indices
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
renormalize
:
bool
=
False
,
e_score_correction_bias
:
torch
.
Tensor
|
None
=
None
,
)
->
None
:
torch
.
ops
.
_moe_C
.
topk_softmax
(
topk_weights
,
topk_ids
,
token_expert_indices
,
gating_output
,
renormalize
topk_weights
,
topk_ids
,
token_expert_indices
,
gating_output
,
renormalize
,
e_score_correction_bias
,
)
def topk_sigmoid(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
    e_score_correction_bias: torch.Tensor | None = None,
) -> None:
    """Call the `_moe_C.topk_sigmoid` custom op.

    All arguments are forwarded positionally in declaration order.
    Nothing is returned.
    """
    moe_topk_sigmoid = torch.ops._moe_C.topk_sigmoid
    moe_topk_sigmoid(
        topk_weights,
        topk_ids,
        token_expert_indices,
        gating_output,
        renormalize,
        e_score_correction_bias,
    )
...
...
vllm/model_executor/layers/fused_moe/config.py
View file @
63227acc
...
...
@@ -106,14 +106,14 @@ def _quant_flags_to_group_shape(
class
RoutingMethodType
(
IntEnum
):
# Default: Softmax -> TopK
Default
=
(
0
,)
# Renormalize: TopK -> Softmax
# Renormalize: TopK -> Softmax
/Sigmoid
Renormalize
=
(
1
,)
# DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups
# -> Top8 experts from the Top4 groups
DeepSeekV3
=
(
2
,)
# Llama4: Top1 -> Sigmoid
Llama4
=
(
3
,)
# RenormalizeNaive: Softmax -> TopK -> Renormalize
# RenormalizeNaive: Softmax
/Sigmoid
-> TopK -> Renormalize
RenormalizeNaive
=
(
4
,)
# TopK: TopK (no softmax)
TopK
=
(
5
,)
...
...
vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py
View file @
63227acc
...
...
@@ -4,6 +4,8 @@ from collections.abc import Callable
import
torch
import
vllm._custom_ops
as
ops
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.distributed.eplb.eplb_state
import
EplbLayerState
from
vllm.model_executor.layers.batch_invariant
import
(
vllm_is_batch_invariant
,
...
...
@@ -12,15 +14,106 @@ from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from
vllm.model_executor.layers.fused_moe.router.base_router
import
BaseRouter
def vllm_topk_softmax(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
    e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, ...]:
    """Run `ops.topk_softmax` and hand back the weight/index buffers.

    Returns:
        ``(topk_weights, topk_indices)`` - the same tensor objects that
        were passed in.
    """
    kernel_args = (
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        renormalize,
        e_score_correction_bias,
    )
    ops.topk_softmax(*kernel_args)
    return topk_weights, topk_indices
def vllm_topk_sigmoid(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
    e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, ...]:
    """Run `ops.topk_sigmoid` and hand back the weight/index buffers.

    Returns:
        ``(topk_weights, topk_indices)`` - the same tensor objects that
        were passed in.
    """
    kernel_args = (
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        renormalize,
        e_score_correction_bias,
    )
    ops.topk_sigmoid(*kernel_args)
    return topk_weights, topk_indices
def
fused_topk_bias
(
hidden_states
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
e_score_correction_bias
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
scoring_func
:
str
=
"softmax"
,
indices_type
:
torch
.
dtype
|
None
=
None
,
):
if
not
rocm_aiter_ops
.
is_fused_moe_enabled
():
assert
hidden_states
.
size
(
0
)
==
gating_output
.
size
(
0
),
(
"Number of tokens mismatch"
)
M
,
_
=
hidden_states
.
size
()
topk_weights
=
torch
.
empty
(
M
,
topk
,
dtype
=
torch
.
float32
,
device
=
hidden_states
.
device
)
topk_ids
=
torch
.
empty
(
M
,
topk
,
dtype
=
torch
.
int32
if
indices_type
is
None
else
indices_type
,
device
=
hidden_states
.
device
,
)
token_expert_indices
=
torch
.
empty
(
M
,
topk
,
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
)
if
scoring_func
==
"softmax"
:
topk_weights
,
topk_ids
=
vllm_topk_softmax
(
topk_weights
,
topk_ids
,
token_expert_indices
,
gating_output
,
renormalize
,
e_score_correction_bias
,
)
return
topk_weights
,
topk_ids
elif
scoring_func
==
"sigmoid"
:
topk_weights
,
topk_ids
=
vllm_topk_sigmoid
(
topk_weights
,
topk_ids
,
token_expert_indices
,
gating_output
,
renormalize
,
e_score_correction_bias
,
)
return
topk_weights
,
topk_ids
else
:
raise
ValueError
(
f
"Unsupported scoring function:
{
scoring_func
}
"
)
n_routed_experts
=
gating_output
.
shape
[
-
1
]
scores
=
gating_output
.
softmax
(
dim
=-
1
)
if
scoring_func
==
"softmax"
:
scores
=
gating_output
.
softmax
(
dim
=-
1
)
elif
scoring_func
==
"sigmoid"
:
scores
=
gating_output
.
sigmoid
()
else
:
raise
ValueError
(
f
"Unsupported scoring function:
{
scoring_func
}
"
)
scores_for_choice
=
scores
.
view
(
-
1
,
n_routed_experts
)
+
e_score_correction_bias
.
unsqueeze
(
0
)
...
...
@@ -43,6 +136,7 @@ class FusedTopKBiasRouter(BaseRouter):
global_num_experts
:
int
,
eplb_state
:
EplbLayerState
,
e_score_correction_bias
:
torch
.
Tensor
,
scoring_func
:
str
,
renormalize
:
bool
=
True
,
routed_scaling_factor
:
float
=
1.0
,
enable_eplb
:
bool
=
False
,
...
...
@@ -57,6 +151,7 @@ class FusedTopKBiasRouter(BaseRouter):
)
self
.
e_score_correction_bias
=
e_score_correction_bias
self
.
renormalize
=
renormalize
self
.
scoring_func
=
scoring_func
self
.
routed_scaling_factor
=
routed_scaling_factor
@
property
...
...
@@ -80,6 +175,7 @@ class FusedTopKBiasRouter(BaseRouter):
e_score_correction_bias
=
self
.
e_score_correction_bias
.
data
,
topk
=
self
.
top_k
,
renormalize
=
self
.
renormalize
,
scoring_func
=
self
.
scoring_func
,
)
if
self
.
routed_scaling_factor
!=
1.0
:
...
...
vllm/model_executor/layers/fused_moe/router/fused_topk_router.py
View file @
63227acc
...
...
@@ -16,7 +16,7 @@ def vllm_topk_softmax(
topk_indices
:
torch
.
Tensor
,
token_expert_indices
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
renormalize
:
bool
,
renormalize
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
...]:
ops
.
topk_softmax
(
topk_weights
,
...
...
@@ -29,7 +29,25 @@ def vllm_topk_softmax(
return
topk_weights
,
topk_indices
def
dispatch_topk_func
(
def vllm_topk_sigmoid(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
) -> tuple[torch.Tensor, ...]:
    """Run `ops.topk_sigmoid` (no bias) and return the output buffers.

    Returns:
        ``(topk_weights, topk_indices)`` - the same tensor objects that
        were passed in.
    """
    kernel_args = (
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        renormalize,
    )
    ops.topk_sigmoid(*kernel_args)
    return topk_weights, topk_indices
def
dispatch_topk_softmax_func
(
use_rocm_aiter
:
bool
=
False
,
)
->
Callable
[...,
tuple
[
torch
.
Tensor
,
...]]:
if
use_rocm_aiter
:
...
...
@@ -37,12 +55,21 @@ def dispatch_topk_func(
return
vllm_topk_softmax
def dispatch_topk_sigmoid_func(
    use_rocm_aiter: bool = False,
) -> Callable[..., tuple[torch.Tensor, ...]]:
    """Select the top-k sigmoid implementation.

    Returns `rocm_aiter_ops.topk_sigmoid` when `use_rocm_aiter` is true,
    otherwise `vllm_topk_sigmoid`.
    """
    return rocm_aiter_ops.topk_sigmoid if use_rocm_aiter else vllm_topk_sigmoid
def
fused_topk
(
hidden_states
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
indices_type
:
torch
.
dtype
|
None
=
None
,
scoring_func
:
str
=
"softmax"
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
assert
hidden_states
.
size
(
0
)
==
gating_output
.
size
(
0
),
"Number of tokens mismatch"
...
...
@@ -61,12 +88,26 @@ def fused_topk(
M
,
topk
,
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
)
topk_func
=
dispatch_topk_func
(
use_rocm_aiter
=
rocm_aiter_ops
.
is_fused_moe_enabled
())
topk_weights
,
topk_ids
=
topk_func
(
topk_weights
,
topk_ids
,
token_expert_indices
,
gating_output
,
renormalize
)
if
scoring_func
==
"softmax"
:
topk_func
=
dispatch_topk_softmax_func
(
use_rocm_aiter
=
rocm_aiter_ops
.
is_fused_moe_enabled
()
)
topk_weights
,
topk_ids
=
topk_func
(
topk_weights
,
topk_ids
,
token_expert_indices
,
gating_output
,
renormalize
)
return
topk_weights
,
topk_ids
,
token_expert_indices
elif
scoring_func
==
"sigmoid"
:
topk_func
=
dispatch_topk_sigmoid_func
(
use_rocm_aiter
=
rocm_aiter_ops
.
is_fused_moe_enabled
()
)
topk_weights
,
topk_ids
=
topk_func
(
topk_weights
,
topk_ids
,
token_expert_indices
,
gating_output
,
renormalize
)
return
topk_weights
,
topk_ids
,
token_expert_indices
return
topk_weights
,
topk_ids
,
token_expert_indices
else
:
raise
ValueError
(
f
"Unsupported scoring function:
{
scoring_func
}
"
)
class
FusedTopKRouter
(
BaseRouter
):
...
...
@@ -82,7 +123,6 @@ class FusedTopKRouter(BaseRouter):
enable_eplb
:
bool
=
False
,
indices_type_getter
:
Callable
[[],
torch
.
dtype
|
None
]
|
None
=
None
,
):
assert
scoring_func
==
"softmax"
,
"FusedTopKRouter only supports softmax."
super
().
__init__
(
top_k
=
top_k
,
global_num_experts
=
global_num_experts
,
...
...
@@ -91,6 +131,7 @@ class FusedTopKRouter(BaseRouter):
indices_type_getter
=
indices_type_getter
,
)
self
.
renormalize
=
renormalize
self
.
scoring_func
=
scoring_func
@
property
def
routing_method_type
(
self
)
->
RoutingMethodType
:
...
...
@@ -113,6 +154,7 @@ class FusedTopKRouter(BaseRouter):
topk
=
self
.
top_k
,
renormalize
=
self
.
renormalize
,
indices_type
=
indices_type
,
scoring_func
=
self
.
scoring_func
,
)
return
topk_weights
,
topk_ids
vllm/model_executor/layers/fused_moe/router/router_factory.py
View file @
63227acc
...
...
@@ -143,17 +143,13 @@ def create_fused_moe_router(
router
.
capture
=
capture
return
router
if
scoring_func
!=
"softmax"
:
raise
ValueError
(
"Only softmax scoring function is supported for non-grouped topk."
)
if
e_score_correction_bias
is
not
None
:
router
=
FusedTopKBiasRouter
(
top_k
=
top_k
,
global_num_experts
=
global_num_experts
,
eplb_state
=
eplb_state
,
e_score_correction_bias
=
e_score_correction_bias
,
scoring_func
=
scoring_func
,
renormalize
=
renormalize
,
routed_scaling_factor
=
routed_scaling_factor
,
enable_eplb
=
enable_eplb
,
...
...
vllm/model_executor/models/minimax_m2.py
View file @
63227acc
...
...
@@ -100,9 +100,6 @@ class MiniMaxM2MoE(nn.Module):
num_experts
=
config
.
num_local_experts
,
top_k
=
config
.
num_experts_per_tok
,
scoring_func
=
config
.
scoring_func
,
use_grouped_topk
=
True
,
num_expert_group
=
1
,
topk_group
=
1
,
e_score_correction_bias
=
self
.
e_score_correction_bias
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment