Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
81b16a2b
Unverified
Commit
81b16a2b
authored
Sep 18, 2025
by
Lumina
Committed by
GitHub
Sep 18, 2025
Browse files
[Kernel] Better inf handling for grouped topk cu (#24886)
Signed-off-by:
lumina37
<
starry.qvq@gmail.com
>
parent
e111d5b0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
20 deletions
+24
-20
csrc/moe/grouped_topk_kernels.cu
csrc/moe/grouped_topk_kernels.cu
+24
-20
No files found.
csrc/moe/grouped_topk_kernels.cu
View file @
81b16a2b
...
...
@@ -21,6 +21,7 @@
#include <torch/all.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda/std/limits>
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace
cg
=
cooperative_groups
;
...
...
@@ -28,7 +29,6 @@ namespace cg = cooperative_groups;
namespace
vllm
{
namespace
moe
{
// Negative infinity as an fp32 value.
constexpr float kNegInfinity = -INFINITY;
// Participation mask with all 32 bits set; passed to __ballot_sync in this
// file.
constexpr unsigned FULL_WARP_MASK = 0xFFFFFFFFu;
// Lane stride used by the per-warp loops; equals the 32-wide tile width used
// by the kernels in this file.
constexpr int32_t WARP_SIZE = 32;
// NOTE(review): presumably the threads-per-block launch size; its use is not
// visible in this excerpt - confirm against the launcher.
constexpr int32_t BLOCK_SIZE = 512;
...
...
@@ -411,14 +411,21 @@ __device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val) {
return
__bfloat162float
(
val
);
}
// Returns negative infinity expressed in type T.
//
// The value is built in fp32 and then converted with cuda_cast<T, float>,
// because cuda::std::numeric_limits<T>::infinity() returns `0` when T is bf16
// or fp16.
template <typename T>
__device__ inline T neg_inf() {
  const float fp32_inf = cuda::std::numeric_limits<float>::infinity();
  return cuda_cast<T, float>(-fp32_inf);
}
template
<
typename
T
>
__device__
void
topk_with_k2
(
T
*
output
,
T
const
*
input
,
cg
::
thread_block_tile
<
32
>
const
&
tile
,
int32_t
const
lane_id
,
int
const
num_experts_per_group
)
{
// Get the top2 per thread
T
largest
=
-
INFINITY
;
T
second_largest
=
-
INFINITY
;
T
largest
=
neg_inf
<
T
>
()
;
T
second_largest
=
neg_inf
<
T
>
()
;
if
(
num_experts_per_group
>
WARP_SIZE
)
{
for
(
int
i
=
lane_id
;
i
<
num_experts_per_group
;
i
+=
WARP_SIZE
)
{
...
...
@@ -513,8 +520,8 @@ __global__ void group_idx_and_topk_idx_kernel(
warp_id
*
topk
;
s_topk_idx
+=
warp_id
*
topk
;
T
value
=
kNegInfinity
;
T
topk_group_value
=
kNegInfinity
;
T
value
=
neg_inf
<
T
>
()
;
T
topk_group_value
=
neg_inf
<
T
>
()
;
int32_t
num_equalto_topkth_group
;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
...
...
@@ -525,11 +532,8 @@ __global__ void group_idx_and_topk_idx_kernel(
if
(
case_id
<
num_tokens
)
{
// calculate group_idx
int32_t
target_num_min
=
WARP_SIZE
-
n_group
+
topk_group
;
if
(
lane_id
<
n_group
&&
(
isfinite
(
cuda_cast
<
float
,
T
>
(
group_scores
[
lane_id
]))))
// The check is necessary to avoid
// abnormal input
{
// The check is necessary to avoid abnormal input
if
(
lane_id
<
n_group
&&
cuda
::
std
::
isfinite
(
group_scores
[
lane_id
]))
{
value
=
group_scores
[
lane_id
];
}
...
...
@@ -540,11 +544,11 @@ __global__ void group_idx_and_topk_idx_kernel(
__syncwarp
();
// Ensure all threads have valid data before reduction
topk_group_value
=
cg
::
reduce
(
tile
,
value
,
cg
::
greater
<
T
>
());
if
(
value
==
topk_group_value
)
{
value
=
kNegInfinity
;
value
=
neg_inf
<
T
>
()
;
}
pre_count_equal_to_top_value
=
count_equal_to_top_value
;
count_equal_to_top_value
=
__popc
(
__ballot_sync
(
FULL_WARP_MASK
,
(
value
==
cuda_cast
<
T
,
float
>
(
kNegInfinity
))));
count_equal_to_top_value
=
__popc
(
__ballot_sync
(
FULL_WARP_MASK
,
(
value
==
neg_inf
<
T
>
(
))));
}
num_equalto_topkth_group
=
target_num_min
-
pre_count_equal_to_top_value
;
}
...
...
@@ -552,11 +556,10 @@ __global__ void group_idx_and_topk_idx_kernel(
warp_topk
::
WarpSelect
<
/*capability*/
WARP_SIZE
,
/*greater*/
true
,
T
,
int32_t
,
/* is_stable */
true
>
queue
((
int32_t
)
topk
,
-
INFINITY
);
queue
((
int32_t
)
topk
,
neg_inf
<
T
>
()
);
int
count_equalto_topkth_group
=
0
;
bool
if_proceed_next_topk
=
(
topk_group_value
!=
cuda_cast
<
T
,
float
>
(
kNegInfinity
));
bool
if_proceed_next_topk
=
topk_group_value
!=
neg_inf
<
T
>
();
if
(
case_id
<
num_tokens
&&
if_proceed_next_topk
)
{
for
(
int
i_group
=
0
;
i_group
<
n_group
;
i_group
++
)
{
if
((
group_scores
[
i_group
]
>
topk_group_value
)
||
...
...
@@ -566,10 +569,10 @@ __global__ void group_idx_and_topk_idx_kernel(
for
(
int32_t
i
=
lane_id
;
i
<
align_num_experts_per_group
;
i
+=
WARP_SIZE
)
{
T
candidates
=
(
i
<
num_experts_per_group
)
&&
isfinite
(
cuda_cast
<
float
,
T
>
(
scores_with_bias
[
offset
+
i
])
)
(
i
<
num_experts_per_group
)
&&
cuda
::
std
::
isfinite
(
scores_with_bias
[
offset
+
i
])
?
scores_with_bias
[
offset
+
i
]
:
cuda_cast
<
T
,
float
>
(
kNegInfinity
);
:
neg_inf
<
T
>
(
);
queue
.
add
(
candidates
,
offset
+
i
);
}
if
(
group_scores
[
i_group
]
==
topk_group_value
)
{
...
...
@@ -598,7 +601,8 @@ __global__ void group_idx_and_topk_idx_kernel(
if
(
i
<
topk
)
{
s_topk_value
[
i
]
=
value
;
}
topk_sum
+=
reduce
(
tile
,
cuda_cast
<
float
,
T
>
(
value
),
cg
::
plus
<
float
>
());
topk_sum
+=
cg
::
reduce
(
tile
,
cuda_cast
<
float
,
T
>
(
value
),
cg
::
plus
<
float
>
());
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment