Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f28125d8
Unverified
Commit
f28125d8
authored
Jan 13, 2026
by
Wentao Ye
Committed by
GitHub
Jan 13, 2026
Browse files
[Perf] Optimize grouped topk kernel, 1.2%~2% E2E Throughput improvement (#32058)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
46f8c6b7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
186 additions
and
283 deletions
+186
-283
csrc/moe/grouped_topk_kernels.cu
csrc/moe/grouped_topk_kernels.cu
+181
-283
tests/models/utils.py
tests/models/utils.py
+5
-0
No files found.
csrc/moe/grouped_topk_kernels.cu
View file @
f28125d8
...
...
@@ -31,8 +31,6 @@ namespace moe {
constexpr
unsigned
FULL_WARP_MASK
=
0xffffffff
;
constexpr
int32_t
WARP_SIZE
=
32
;
constexpr
int32_t
BLOCK_SIZE
=
512
;
constexpr
int32_t
NUM_WARPS_PER_BLOCK
=
BLOCK_SIZE
/
WARP_SIZE
;
namespace
warp_topk
{
...
...
@@ -65,14 +63,6 @@ __forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index,
return
res
;
}
template
<
typename
T
,
typename
idxT
>
int
calc_smem_size_for_block_wide
(
int
num_of_warp
,
int64_t
k
)
{
int64_t
cache_topk
=
(
sizeof
(
T
)
+
sizeof
(
idxT
))
*
num_of_warp
*
k
;
int64_t
n
=
std
::
max
<
int
>
(
num_of_warp
/
2
*
k
,
num_of_warp
*
WARP_SIZE
);
return
max
(
cache_topk
,
round_up_to_multiple_of
<
256
>
(
n
*
sizeof
(
T
))
+
n
*
sizeof
(
idxT
));
}
template
<
int
size
,
bool
ascending
,
bool
reverse
,
typename
T
,
typename
idxT
,
bool
is_stable
>
struct
BitonicMerge
{
...
...
@@ -267,6 +257,15 @@ class WarpSort {
}
}
// Accessors for per-lane selected value/index.
// NOTE: For the common case `capacity == WARP_SIZE`, `max_arr_len_ == 1`
// and callers should use `i == 0`.
__device__
__forceinline__
idxT
get_idx
(
int
i
=
0
)
const
{
return
idx_arr_
[
i
];
}
__device__
__forceinline__
T
get_val
(
int
i
=
0
)
const
{
return
val_arr_
[
i
];
}
protected:
static
constexpr
int
max_arr_len_
=
capacity
/
WARP_SIZE
;
...
...
@@ -285,6 +284,7 @@ class WarpSelect : public WarpSort<capacity, greater, T, idxT, is_stable> {
__device__
WarpSelect
(
idxT
k
,
T
dummy
)
:
WarpSort
<
capacity
,
greater
,
T
,
idxT
,
is_stable
>
(
k
,
dummy
),
k_th_
(
dummy
),
k_th_idx_
(
0
),
k_th_lane_
((
k
-
1
)
%
WARP_SIZE
)
{
extern
__shared__
char
smem_buf
[];
// extern __shared__ T smem_buf[];
...
...
@@ -346,9 +346,6 @@ class WarpSelect : public WarpSort<capacity, greater, T, idxT, is_stable> {
idxT
idx
=
(
lane_
<
smem_buf_len_
)
?
idx_smem_
[
lane_
]
:
0
;
merge_buf_
(
val
,
idx
);
}
// after done(), smem is used for merging results among warps
__syncthreads
();
}
private:
...
...
@@ -503,255 +500,186 @@ __device__ void topk_with_k2(T* output, T const* input, BiasT const* bias,
}
}
template
<
typename
T
,
typename
BiasT
,
ScoringFunc
SF
>
__global__
void
topk_with_k2_kernel
(
T
*
output
,
T
*
input
,
BiasT
const
*
bias
,
int64_t
const
num_tokens
,
int64_t
const
num_cases
,
int64_t
const
n_group
,
int64_t
const
num_experts_per_group
)
{
int32_t
warp_id
=
threadIdx
.
x
/
WARP_SIZE
;
int32_t
lane_id
=
threadIdx
.
x
%
WARP_SIZE
;
int32_t
case_id
=
blockIdx
.
x
*
NUM_WARPS_PER_BLOCK
+
warp_id
;
if
(
case_id
<
num_cases
)
{
input
+=
case_id
*
num_experts_per_group
;
// bias is per expert group, offset to current group
int32_t
group_id
=
case_id
%
n_group
;
BiasT
const
*
group_bias
=
bias
+
group_id
*
num_experts_per_group
;
output
+=
case_id
;
cg
::
thread_block
block
=
cg
::
this_thread_block
();
cg
::
thread_block_tile
<
32
>
tile
=
cg
::
tiled_partition
<
32
>
(
block
);
template
<
typename
T
,
typename
BiasT
,
typename
IdxT
,
ScoringFunc
SF
>
__global__
void
grouped_topk_fused_kernel
(
T
*
scores
,
float
*
topk_values
,
IdxT
*
topk_indices
,
BiasT
const
*
bias
,
int64_t
const
num_tokens
,
int64_t
const
num_experts
,
int64_t
const
n_group
,
int64_t
const
topk_group
,
int64_t
const
topk
,
bool
renormalize
,
double
routed_scaling_factor
)
{
int32_t
const
token_id
=
static_cast
<
int32_t
>
(
blockIdx
.
x
);
if
(
token_id
>=
num_tokens
)
{
return
;
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm
volatile
(
"griddepcontrol.wait;"
);
#endif
topk_with_k2
<
T
,
BiasT
,
SF
>
(
output
,
input
,
group_bias
,
tile
,
lane_id
,
num_experts_per_group
);
int32_t
const
warp_id
=
threadIdx
.
x
/
WARP_SIZE
;
int32_t
const
lane_id
=
threadIdx
.
x
%
WARP_SIZE
;
int32_t
const
n_group_i32
=
static_cast
<
int32_t
>
(
n_group
);
int32_t
const
topk_group_i32
=
static_cast
<
int32_t
>
(
topk_group
);
int32_t
const
topk_i32
=
static_cast
<
int32_t
>
(
topk
);
int32_t
const
num_experts_i32
=
static_cast
<
int32_t
>
(
num_experts
);
int32_t
const
num_warps
=
blockDim
.
x
/
WARP_SIZE
;
if
(
warp_id
>=
n_group_i32
||
num_warps
<
n_group_i32
)
{
return
;
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm
volatile
(
"griddepcontrol.launch_dependents;"
);
#endif
}
template
<
typename
T
,
typename
BiasT
,
typename
IdxT
,
ScoringFunc
SF
,
int
NGroup
=
-
1
>
__global__
void
group_idx_and_topk_idx_kernel
(
T
*
scores
,
T
const
*
group_scores
,
float
*
topk_values
,
IdxT
*
topk_indices
,
BiasT
const
*
bias
,
int64_t
const
num_tokens
,
int64_t
const
n_group
,
int64_t
const
topk_group
,
int64_t
const
topk
,
int64_t
const
num_experts
,
int64_t
const
num_experts_per_group
,
bool
renormalize
,
double
routed_scaling_factor
)
{
int32_t
warp_id
=
threadIdx
.
x
/
WARP_SIZE
;
int32_t
lane_id
=
threadIdx
.
x
%
WARP_SIZE
;
int32_t
case_id
=
blockIdx
.
x
*
NUM_WARPS_PER_BLOCK
+
warp_id
;
// one per token
scores
+=
case_id
*
num_experts
;
group_scores
+=
case_id
*
n_group
;
topk_values
+=
case_id
*
topk
;
topk_indices
+=
case_id
*
topk
;
constexpr
bool
kUseStaticNGroup
=
(
NGroup
>
0
);
// use int32 to avoid implicit conversion
int32_t
const
n_group_i32
=
kUseStaticNGroup
?
NGroup
:
static_cast
<
int32_t
>
(
n_group
);
int32_t
align_num_experts_per_group
=
warp_topk
::
round_up_to_multiple_of
<
WARP_SIZE
>
(
num_experts_per_group
);
int32_t
const
num_experts_per_group
=
num_experts_i32
/
n_group_i32
;
T
*
scores_token
=
scores
+
static_cast
<
int64_t
>
(
token_id
)
*
num_experts
;
cg
::
thread_block
block
=
cg
::
this_thread_block
();
cg
::
thread_block_tile
<
32
>
tile
=
cg
::
tiled_partition
<
32
>
(
block
);
extern
__shared__
char
smem_buf
[];
// NOTE: reuse the shared memory here to
// store the target topk idx
int32_t
*
s_topk_idx
=
reinterpret_cast
<
int32_t
*>
(
smem_buf
);
T
*
s_topk_value
=
reinterpret_cast
<
T
*>
(
s_topk_idx
+
NUM_WARPS_PER_BLOCK
*
topk
)
+
warp_id
*
topk
;
s_topk_idx
+=
warp_id
*
topk
;
T
value
=
neg_inf
<
T
>
();
T
topk_group_value
=
neg_inf
<
T
>
();
int32_t
num_equalto_topkth_group
;
extern
__shared__
char
smem_buf
[];
// warpSelect internal staging buffer layout
size_t
const
val_bytes
=
static_cast
<
size_t
>
(
num_warps
)
*
WARP_SIZE
*
sizeof
(
T
);
size_t
const
val_bytes_aligned
=
warp_topk
::
round_up_to_multiple_of
<
256
>
(
val_bytes
);
size_t
const
idx_bytes
=
static_cast
<
size_t
>
(
num_warps
)
*
WARP_SIZE
*
sizeof
(
int32_t
);
size_t
const
internal_bytes
=
val_bytes_aligned
+
idx_bytes
;
// user-managed shared memory starts after warpSelect internal staging.
uintptr_t
ptr_u
=
reinterpret_cast
<
uintptr_t
>
(
smem_buf
+
internal_bytes
);
ptr_u
=
(
ptr_u
+
15
)
&
~
static_cast
<
uintptr_t
>
(
15
);
// align to 16B
T
*
s_group_scores
=
reinterpret_cast
<
T
*>
(
ptr_u
);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm
volatile
(
"griddepcontrol.wait;"
);
// I think all prolog can be put before
// acqbulk because it's ptr arithmetic
#endif
if
(
case_id
<
num_tokens
)
{
// calculate group_idx
int32_t
target_num_min
=
WARP_SIZE
-
n_group_i32
+
static_cast
<
int32_t
>
(
topk_group
);
// The check is necessary to avoid abnormal input
if
(
lane_id
<
n_group_i32
&&
is_finite
(
group_scores
[
lane_id
]))
{
value
=
group_scores
[
lane_id
];
}
// phase 1: per-group scan
int32_t
const
group_offset
=
warp_id
*
num_experts_per_group
;
topk_with_k2
<
T
,
BiasT
,
SF
>
(
s_group_scores
+
warp_id
,
scores_token
+
group_offset
,
bias
+
group_offset
,
tile
,
lane_id
,
num_experts_per_group
);
int
count_equal_to_top_value
=
WARP_SIZE
-
n_group_i32
;
int
pre_count_equal_to_top_value
=
0
;
// Use loop to find the largset top_group
while
(
count_equal_to_top_value
<
target_num_min
)
{
topk_group_value
=
cg
::
reduce
(
tile
,
value
,
cg
::
greater
<
T
>
());
if
(
value
==
topk_group_value
)
{
value
=
neg_inf
<
T
>
();
}
pre_count_equal_to_top_value
=
count_equal_to_top_value
;
count_equal_to_top_value
=
__popc
(
__ballot_sync
(
FULL_WARP_MASK
,
(
value
==
neg_inf
<
T
>
())));
}
num_equalto_topkth_group
=
target_num_min
-
pre_count_equal_to_top_value
;
}
__syncthreads
();
// phase 2: warp0 selects groups + merges candidates to final topk
if
(
warp_id
!=
0
)
{
return
;
}
topk_values
+=
static_cast
<
int64_t
>
(
token_id
)
*
topk
;
topk_indices
+=
static_cast
<
int64_t
>
(
token_id
)
*
topk
;
// select topk_group groups by group score
warp_topk
::
WarpSelect
<
/*capability*/
WARP_SIZE
,
/*greater*/
true
,
T
,
int32_t
,
/* is_stable */
true
>
queue
((
int32_t
)
topk
,
neg_inf
<
T
>
());
int
count_equalto_topkth_group
=
0
;
bool
if_proceed_next_topk
=
topk_group_value
!=
neg_inf
<
T
>
();
if
(
case_id
<
num_tokens
&&
if_proceed_next_topk
)
{
auto
process_group
=
[
&
](
int
i_group
)
{
if
((
group_scores
[
i_group
]
>
topk_group_value
)
||
((
group_scores
[
i_group
]
==
topk_group_value
)
&&
(
count_equalto_topkth_group
<
num_equalto_topkth_group
)))
{
int32_t
offset
=
i_group
*
num_experts_per_group
;
for
(
int32_t
i
=
lane_id
;
i
<
align_num_experts_per_group
;
i
+=
WARP_SIZE
)
{
T
candidates
=
neg_inf
<
T
>
();
if
(
i
<
num_experts_per_group
)
{
// apply scoring function (if any) and add bias
T
input
=
scores
[
offset
+
i
];
if
(
is_finite
(
input
))
{
T
score
=
apply_scoring
<
SF
>
(
input
);
candidates
=
score
+
static_cast
<
T
>
(
bias
[
offset
+
i
]);
}
}
queue
.
add
(
candidates
,
offset
+
i
);
}
if
(
group_scores
[
i_group
]
==
topk_group_value
)
{
count_equalto_topkth_group
++
;
}
}
};
group_sel
(
static_cast
<
int32_t
>
(
topk_group_i32
),
neg_inf
<
T
>
());
// all lanes must participate in WarpSelect::add().
T
gscore
=
(
lane_id
<
n_group_i32
)
?
s_group_scores
[
lane_id
]
:
neg_inf
<
T
>
();
group_sel
.
add
(
gscore
,
lane_id
);
group_sel
.
done
();
// proceed only if the k-th selected group score is not -inf
bool
proceed
=
false
;
if
(
topk_group_i32
>
0
)
{
int
const
kth_lane
=
topk_group_i32
-
1
;
// broadcast the k-th selected group score to all lanes
T
kth_val
=
__shfl_sync
(
FULL_WARP_MASK
,
group_sel
.
get_val
(
0
),
kth_lane
);
proceed
=
(
kth_val
!=
neg_inf
<
T
>
());
}
if
constexpr
(
kUseStaticNGroup
)
{
#pragma unroll
for
(
int
i_group
=
0
;
i_group
<
NGroup
;
++
i_group
)
{
process_group
(
i_group
);
}
}
else
{
for
(
int
i_group
=
0
;
i_group
<
n_group_i32
;
++
i_group
)
{
process_group
(
i_group
);
}
if
(
!
proceed
)
{
for
(
int
i
=
lane_id
;
i
<
topk_i32
;
i
+=
WARP_SIZE
)
{
topk_indices
[
i
]
=
static_cast
<
IdxT
>
(
i
);
topk_values
[
i
]
=
1.0
f
/
static_cast
<
float
>
(
topk_i32
);
}
queue
.
done
();
// Get the topk_idx
queue
.
dumpIdx
(
s_topk_idx
);
}
// Load the valid score value
// Calculate the summation
float
topk_sum
=
1e-20
;
if
(
case_id
<
num_tokens
&&
if_proceed_next_topk
)
{
for
(
int
i
=
lane_id
;
i
<
warp_topk
::
round_up_to_multiple_of
<
WARP_SIZE
>
(
topk
);
i
+=
WARP_SIZE
)
{
T
value
=
cuda_cast
<
T
,
float
>
(
0.0
f
);
if
(
i
<
topk
)
{
// Load the score value (without bias) for normalization
T
input
=
scores
[
s_topk_idx
[
i
]];
value
=
apply_scoring
<
SF
>
(
input
);
s_topk_value
[
i
]
=
value
;
}
if
(
renormalize
)
{
topk_sum
+=
cg
::
reduce
(
tile
,
cuda_cast
<
float
,
T
>
(
value
),
cg
::
plus
<
float
>
());
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm
volatile
(
"griddepcontrol.launch_dependents;"
);
#endif
return
;
}
// merge per-group topk candidates for selected groups, then select topk
warp_topk
::
WarpSelect
<
/*capability*/
WARP_SIZE
,
/*greater*/
true
,
T
,
int32_t
,
/* is_stable */
true
>
expert_sel
(
static_cast
<
int32_t
>
(
topk_i32
),
neg_inf
<
T
>
());
// selected group ids reside in lanes [0, topk_group)
int32_t
sel_gid_lane
=
(
lane_id
<
topk_group_i32
)
?
group_sel
.
get_idx
(
0
)
:
0
;
// add candidates from selected groups to expert_sel
for
(
int32_t
g
=
0
;
g
<
topk_group_i32
;
++
g
)
{
int32_t
gid
=
__shfl_sync
(
FULL_WARP_MASK
,
sel_gid_lane
,
g
);
int32_t
const
offset
=
gid
*
num_experts_per_group
;
int32_t
const
align_num_experts_per_group
=
warp_topk
::
round_up_to_multiple_of
<
WARP_SIZE
>
(
num_experts_per_group
);
for
(
int32_t
i
=
lane_id
;
i
<
align_num_experts_per_group
;
i
+=
WARP_SIZE
)
{
// all lanes must call `add()` the same number of times.
T
cand
=
neg_inf
<
T
>
();
int32_t
idx
=
0
;
if
(
i
<
num_experts_per_group
)
{
idx
=
offset
+
i
;
T
input
=
scores_token
[
idx
];
if
(
is_finite
(
input
))
{
T
score
=
apply_scoring
<
SF
>
(
input
);
cand
=
score
+
static_cast
<
T
>
(
bias
[
idx
]);
}
}
expert_sel
.
add
(
cand
,
idx
);
}
}
expert_sel
.
done
();
// compute unbiased routing weights + optional renorm.
float
lane_unbiased
=
0.0
f
;
IdxT
lane_idx
=
0
;
if
(
lane_id
<
topk_i32
)
{
lane_idx
=
static_cast
<
IdxT
>
(
expert_sel
.
get_idx
(
0
));
T
in
=
scores_token
[
static_cast
<
int32_t
>
(
lane_idx
)];
lane_unbiased
=
cuda_cast
<
float
,
T
>
(
apply_scoring
<
SF
>
(
in
));
}
__syncthreads
();
float
topk_sum
=
1e-20
f
;
if
(
renormalize
)
{
topk_sum
+=
cg
::
reduce
(
tile
,
lane_unbiased
,
cg
::
plus
<
float
>
());
}
if
(
case_id
<
num_tokens
)
{
if
(
if_proceed_next_topk
)
{
float
scale
=
routed_scaling_factor
;
if
(
renormalize
)
{
scale
/=
topk_sum
;
}
for
(
int
i
=
lane_id
;
i
<
topk
;
i
+=
WARP_SIZE
)
{
float
base
=
cuda_cast
<
float
,
T
>
(
s_topk_value
[
i
]);
float
value
=
base
*
scale
;
topk_indices
[
i
]
=
s_topk_idx
[
i
];
topk_values
[
i
]
=
value
;
}
}
else
{
for
(
int
i
=
lane_id
;
i
<
topk
;
i
+=
WARP_SIZE
)
{
topk_indices
[
i
]
=
i
;
topk_values
[
i
]
=
1.0
f
/
topk
;
}
}
// Note: when if_proceed_next_topk==false, choose the first 8 experts as the
// default result.
float
scale
=
static_cast
<
float
>
(
routed_scaling_factor
);
if
(
renormalize
)
{
scale
/=
topk_sum
;
}
if
(
lane_id
<
topk_i32
)
{
topk_indices
[
lane_id
]
=
lane_idx
;
topk_values
[
lane_id
]
=
lane_unbiased
*
scale
;
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm
volatile
(
"griddepcontrol.launch_dependents;"
);
#endif
}
template
<
typename
T
,
typename
BiasT
,
typename
IdxT
,
ScoringFunc
SF
>
inline
void
launch_group_idx_and_topk_kernel
(
cudaLaunchConfig_t
const
&
config
,
T
*
scores
,
T
*
group_scores
,
float
*
topk_values
,
IdxT
*
topk_indices
,
BiasT
const
*
bias
,
int64_t
const
num_tokens
,
int64_t
const
n_group
,
int64_t
const
topk_group
,
int64_t
const
topk
,
int64_t
const
num_experts
,
int64_t
const
num_experts_per_group
,
bool
const
renormalize
,
double
const
routed_scaling_factor
)
{
auto
launch
=
[
&
](
auto
*
kernel_instance2
)
{
cudaLaunchKernelEx
(
&
config
,
kernel_instance2
,
scores
,
group_scores
,
topk_values
,
topk_indices
,
bias
,
num_tokens
,
n_group
,
topk_group
,
topk
,
num_experts
,
num_experts_per_group
,
renormalize
,
routed_scaling_factor
);
};
switch
(
n_group
)
{
case
4
:
{
launch
(
&
group_idx_and_topk_idx_kernel
<
T
,
BiasT
,
IdxT
,
SF
,
4
>
);
break
;
}
case
8
:
{
launch
(
&
group_idx_and_topk_idx_kernel
<
T
,
BiasT
,
IdxT
,
SF
,
8
>
);
break
;
}
case
16
:
{
launch
(
&
group_idx_and_topk_idx_kernel
<
T
,
BiasT
,
IdxT
,
SF
,
16
>
);
break
;
}
case
32
:
{
launch
(
&
group_idx_and_topk_idx_kernel
<
T
,
BiasT
,
IdxT
,
SF
,
32
>
);
break
;
}
default:
{
launch
(
&
group_idx_and_topk_idx_kernel
<
T
,
BiasT
,
IdxT
,
SF
>
);
break
;
}
}
}
template
<
typename
T
,
typename
BiasT
,
typename
IdxT
>
void
invokeNoAuxTc
(
T
*
scores
,
T
*
group_scores
,
float
*
topk_values
,
IdxT
*
topk_indices
,
BiasT
const
*
bias
,
int64_t
const
num_tokens
,
int64_t
const
num_experts
,
int64_t
const
n_group
,
int64_t
const
topk_group
,
int64_t
const
topk
,
bool
const
renormalize
,
double
const
routed_scaling_factor
,
int
const
scoring_func
,
bool
enable_pdl
=
false
,
cudaStream_t
const
stream
=
0
)
{
int64_t
num_cases
=
num_tokens
*
n_group
;
int64_t
topk_with_k2_num_blocks
=
(
num_cases
-
1
)
/
NUM_WARPS_PER_BLOCK
+
1
;
void
invokeNoAuxTc
(
T
*
scores
,
float
*
topk_values
,
IdxT
*
topk_indices
,
BiasT
const
*
bias
,
int64_t
const
num_tokens
,
int64_t
const
num_experts
,
int64_t
const
n_group
,
int64_t
const
topk_group
,
int64_t
const
topk
,
bool
const
renormalize
,
double
const
routed_scaling_factor
,
int
const
scoring_func
,
bool
enable_pdl
=
false
,
cudaStream_t
const
stream
=
0
)
{
cudaLaunchConfig_t
config
;
config
.
gridDim
=
topk_with_k2_num_blocks
;
config
.
blockDim
=
BLOCK_SIZE
;
config
.
dynamicSmemBytes
=
0
;
// One block per token; one warp per group.
config
.
gridDim
=
static_cast
<
uint32_t
>
(
num_tokens
);
config
.
blockDim
=
static_cast
<
uint32_t
>
(
n_group
)
*
WARP_SIZE
;
// Dynamic shared memory: WarpSelect staging + per-group topk buffers.
int32_t
const
num_warps
=
static_cast
<
int32_t
>
(
n_group
);
size_t
const
val_bytes
=
static_cast
<
size_t
>
(
num_warps
)
*
WARP_SIZE
*
sizeof
(
T
);
size_t
const
val_bytes_aligned
=
warp_topk
::
round_up_to_multiple_of
<
256
>
(
val_bytes
);
size_t
const
idx_bytes
=
static_cast
<
size_t
>
(
num_warps
)
*
WARP_SIZE
*
sizeof
(
int32_t
);
size_t
const
internal_bytes
=
val_bytes_aligned
+
idx_bytes
;
size_t
const
extra_bytes
=
16
+
static_cast
<
size_t
>
(
n_group
)
*
sizeof
(
T
);
config
.
dynamicSmemBytes
=
internal_bytes
+
extra_bytes
;
config
.
stream
=
stream
;
cudaLaunchAttribute
attrs
[
1
];
attrs
[
0
].
id
=
cudaLaunchAttributeProgrammaticStreamSerialization
;
...
...
@@ -759,66 +687,35 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
config
.
numAttrs
=
1
;
config
.
attrs
=
attrs
;
auto
const
sf
=
static_cast
<
ScoringFunc
>
(
scoring_func
);
int64_t
const
num_experts_per_group
=
num_experts
/
n_group
;
auto
launch_topk_with_k2
=
[
&
](
auto
*
kernel_instance1
)
{
cudaLaunchKernelEx
(
&
config
,
kernel_instance1
,
group_scores
,
scores
,
bias
,
num_tokens
,
num_cases
,
n_group
,
num_experts_per_group
);
};
switch
(
sf
)
{
case
SCORING_NONE
:
{
auto
*
kernel_instance1
=
&
topk_with_k2_kernel
<
T
,
BiasT
,
SCORING_NONE
>
;
launch_topk_with_k2
(
kernel_instance1
);
break
;
auto
*
kernel_instance
=
&
grouped_topk_fused_kernel
<
T
,
BiasT
,
IdxT
,
SCORING_NONE
>
;
cudaLaunchKernelEx
(
&
config
,
kernel_instance
,
scores
,
topk_values
,
topk_indices
,
bias
,
num_tokens
,
num_experts
,
n_group
,
topk_group
,
topk
,
renormalize
,
routed_scaling_factor
);
return
;
}
case
SCORING_SIGMOID
:
{
auto
*
kernel_instance1
=
&
topk_with_k2_kernel
<
T
,
BiasT
,
SCORING_SIGMOID
>
;
launch_topk_with_k2
(
kernel_instance1
);
break
;
auto
*
kernel_instance
=
&
grouped_topk_fused_kernel
<
T
,
BiasT
,
IdxT
,
SCORING_SIGMOID
>
;
cudaLaunchKernelEx
(
&
config
,
kernel_instance
,
scores
,
topk_values
,
topk_indices
,
bias
,
num_tokens
,
num_experts
,
n_group
,
topk_group
,
topk
,
renormalize
,
routed_scaling_factor
);
return
;
}
default:
// should be guarded by higher level checks.
TORCH_CHECK
(
false
,
"Unsupported scoring_func in invokeNoAuxTc"
);
}
int64_t
topk_with_k_group_num_blocks
=
(
num_tokens
-
1
)
/
NUM_WARPS_PER_BLOCK
+
1
;
size_t
dynamic_smem_in_bytes
=
warp_topk
::
calc_smem_size_for_block_wide
<
T
,
int32_t
>
(
NUM_WARPS_PER_BLOCK
,
topk
);
config
.
gridDim
=
topk_with_k_group_num_blocks
;
config
.
blockDim
=
BLOCK_SIZE
;
config
.
dynamicSmemBytes
=
dynamic_smem_in_bytes
;
config
.
stream
=
stream
;
attrs
[
0
].
id
=
cudaLaunchAttributeProgrammaticStreamSerialization
;
attrs
[
0
].
val
.
programmaticStreamSerializationAllowed
=
enable_pdl
;
config
.
numAttrs
=
1
;
config
.
attrs
=
attrs
;
switch
(
sf
)
{
case
SCORING_NONE
:
{
launch_group_idx_and_topk_kernel
<
T
,
BiasT
,
IdxT
,
SCORING_NONE
>
(
config
,
scores
,
group_scores
,
topk_values
,
topk_indices
,
bias
,
num_tokens
,
n_group
,
topk_group
,
topk
,
num_experts
,
num_experts_per_group
,
renormalize
,
routed_scaling_factor
);
break
;
}
case
SCORING_SIGMOID
:
{
launch_group_idx_and_topk_kernel
<
T
,
BiasT
,
IdxT
,
SCORING_SIGMOID
>
(
config
,
scores
,
group_scores
,
topk_values
,
topk_indices
,
bias
,
num_tokens
,
n_group
,
topk_group
,
topk
,
num_experts
,
num_experts_per_group
,
renormalize
,
routed_scaling_factor
);
break
;
}
default:
TORCH_CHECK
(
false
,
"Unsupported scoring_func in invokeNoAuxTc"
);
}
}
#define INSTANTIATE_NOAUX_TC(T, BiasT, IdxT)
\
template void invokeNoAuxTc<T, BiasT, IdxT>(
\
T * scores,
T * group_scores,
float* topk_values, IdxT* topk_indices,
\
BiasT const* bias,
int64_t const num_tokens, int64_t const num_experts, \
int64_t const n_group, int64_t const topk_group, int64_t const topk,
\
bool const renormalize, double const routed_scaling_factor,
\
#define INSTANTIATE_NOAUX_TC(T, BiasT, IdxT) \
template void invokeNoAuxTc<T, BiasT, IdxT>( \
T * scores, float* topk_values, IdxT* topk_indices,
BiasT const* bias,
\
int64_t const num_tokens, int64_t const num_experts,
\
int64_t const n_group, int64_t const topk_group, int64_t const topk, \
bool const renormalize, double const routed_scaling_factor, \
int const scoring_func, bool enable_pdl, cudaStream_t const stream);
INSTANTIATE_NOAUX_TC
(
float
,
float
,
int32_t
);
...
...
@@ -843,17 +740,21 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
int64_t
num_tokens
=
input_size
[
0
];
int64_t
num_experts
=
input_size
[
1
];
TORCH_CHECK
(
input_size
.
size
()
==
2
,
"scores must be a 2D Tensor"
);
TORCH_CHECK
(
n_group
>
0
,
"n_group must be positive"
);
TORCH_CHECK
(
topk
>
0
,
"topk must be positive"
);
TORCH_CHECK
(
topk_group
>
0
,
"topk_group must be positive"
);
TORCH_CHECK
(
topk_group
<=
n_group
,
"topk_group must be <= n_group"
);
TORCH_CHECK
(
num_experts
%
n_group
==
0
,
"num_experts should be divisible by n_group"
);
TORCH_CHECK
(
n_group
<=
32
,
"n_group should be smaller than or equal to 32 for now"
);
TORCH_CHECK
(
topk
<=
32
,
"topk should be smaller than or equal to 32 for now"
);
TORCH_CHECK
(
topk
<=
topk_group
*
(
num_experts
/
n_group
),
"topk must be <= topk_group * (num_experts / n_group)"
);
TORCH_CHECK
(
scoring_func
==
vllm
::
moe
::
SCORING_NONE
||
scoring_func
==
vllm
::
moe
::
SCORING_SIGMOID
,
"scoring_func must be SCORING_NONE (0) or SCORING_SIGMOID (1)"
);
torch
::
Tensor
group_scores
=
torch
::
empty
(
{
num_tokens
,
n_group
},
torch
::
dtype
(
data_type
).
device
(
torch
::
kCUDA
));
// Always output float32 for topk_values (eliminates Python-side conversion)
torch
::
Tensor
topk_values
=
torch
::
empty
(
{
num_tokens
,
topk
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
));
...
...
@@ -868,7 +769,6 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
case torch::kFloat16: \
vllm::moe::invokeNoAuxTc<T, half, IdxT>( \
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
reinterpret_cast<T*>(group_scores.mutable_data_ptr()), \
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
reinterpret_cast<half const*>(bias.data_ptr()), num_tokens, \
...
...
@@ -879,7 +779,6 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
case torch::kFloat32: \
vllm::moe::invokeNoAuxTc<T, float, IdxT>( \
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
reinterpret_cast<T*>(group_scores.mutable_data_ptr()), \
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
reinterpret_cast<float const*>(bias.data_ptr()), num_tokens, \
...
...
@@ -890,7 +789,6 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
case torch::kBFloat16: \
vllm::moe::invokeNoAuxTc<T, __nv_bfloat16, IdxT>( \
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
reinterpret_cast<T*>(group_scores.mutable_data_ptr()), \
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
reinterpret_cast<__nv_bfloat16 const*>(bias.data_ptr()), \
...
...
tests/models/utils.py
View file @
f28125d8
...
...
@@ -454,6 +454,9 @@ def dummy_hf_overrides(
# Ensure at least 2 expert per group
# Since `grouped_topk` assumes top-2
n_group
=
getattr
(
text_config
,
"n_group"
,
None
)
# Kimi uses `num_expert_group` instead of `n_group`.
if
n_group
is
None
:
n_group
=
getattr
(
text_config
,
"num_expert_group"
,
None
)
num_experts
=
n_group
*
2
if
n_group
is
not
None
else
2
# we use three layers for Gemma-3n to check
...
...
@@ -487,6 +490,8 @@ def dummy_hf_overrides(
{
"num_experts"
:
num_experts
,
"num_experts_per_tok"
:
2
,
# Kimi uses `num_experts_per_token`.
"num_experts_per_token"
:
2
,
"num_local_experts"
:
num_experts
,
# Otherwise there will not be any expert layers
"first_k_dense_replace"
:
0
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment