Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e919d6f5
Unverified
Commit
e919d6f5
authored
Sep 03, 2025
by
Qiming Zhang
Committed by
GitHub
Sep 04, 2025
Browse files
[Kernel][Bugfix] Fix grouped topk cu (#24146)
Signed-off-by:
mayuyuace
<
qiming1.zhang@intel.com
>
parent
a38f8bd5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
6 deletions
+7
-6
csrc/moe/grouped_topk_kernels.cu
csrc/moe/grouped_topk_kernels.cu
+7
-6
No files found.
csrc/moe/grouped_topk_kernels.cu
View file @
e919d6f5
...
@@ -28,6 +28,7 @@ namespace cg = cooperative_groups;
...
@@ -28,6 +28,7 @@ namespace cg = cooperative_groups;
namespace
vllm
{
namespace
vllm
{
namespace
moe
{
namespace
moe
{
constexpr
float
kNegInfinity
=
INFINITY
*
-
1
;
constexpr
unsigned
FULL_WARP_MASK
=
0xffffffff
;
constexpr
unsigned
FULL_WARP_MASK
=
0xffffffff
;
constexpr
int32_t
WARP_SIZE
=
32
;
constexpr
int32_t
WARP_SIZE
=
32
;
constexpr
int32_t
BLOCK_SIZE
=
512
;
constexpr
int32_t
BLOCK_SIZE
=
512
;
...
@@ -512,8 +513,8 @@ __global__ void group_idx_and_topk_idx_kernel(
...
@@ -512,8 +513,8 @@ __global__ void group_idx_and_topk_idx_kernel(
warp_id
*
topk
;
warp_id
*
topk
;
s_topk_idx
+=
warp_id
*
topk
;
s_topk_idx
+=
warp_id
*
topk
;
T
value
=
cuda
::
std
::
numeric_limits
<
T
>::
min
()
;
T
value
=
kNegInfinity
;
T
topk_group_value
=
cuda
::
std
::
numeric_limits
<
T
>::
min
()
;
T
topk_group_value
=
kNegInfinity
;
int32_t
num_equalto_topkth_group
;
int32_t
num_equalto_topkth_group
;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
...
@@ -539,11 +540,11 @@ __global__ void group_idx_and_topk_idx_kernel(
...
@@ -539,11 +540,11 @@ __global__ void group_idx_and_topk_idx_kernel(
__syncwarp
();
// Ensure all threads have valid data before reduction
__syncwarp
();
// Ensure all threads have valid data before reduction
topk_group_value
=
cg
::
reduce
(
tile
,
value
,
cg
::
greater
<
T
>
());
topk_group_value
=
cg
::
reduce
(
tile
,
value
,
cg
::
greater
<
T
>
());
if
(
value
==
topk_group_value
)
{
if
(
value
==
topk_group_value
)
{
value
=
cuda
::
std
::
numeric_limits
<
T
>::
min
()
;
value
=
kNegInfinity
;
}
}
pre_count_equal_to_top_value
=
count_equal_to_top_value
;
pre_count_equal_to_top_value
=
count_equal_to_top_value
;
count_equal_to_top_value
=
__popc
(
__ballot_sync
(
count_equal_to_top_value
=
__popc
(
__ballot_sync
(
FULL_WARP_MASK
,
(
value
==
cuda
::
std
::
numeric_limits
<
T
>::
min
(
))));
FULL_WARP_MASK
,
(
value
==
cuda
_cast
<
T
,
float
>
(
kNegInfinity
))));
}
}
num_equalto_topkth_group
=
target_num_min
-
pre_count_equal_to_top_value
;
num_equalto_topkth_group
=
target_num_min
-
pre_count_equal_to_top_value
;
}
}
...
@@ -555,7 +556,7 @@ __global__ void group_idx_and_topk_idx_kernel(
...
@@ -555,7 +556,7 @@ __global__ void group_idx_and_topk_idx_kernel(
int
count_equalto_topkth_group
=
0
;
int
count_equalto_topkth_group
=
0
;
bool
if_proceed_next_topk
=
bool
if_proceed_next_topk
=
(
topk_group_value
!=
cuda
::
std
::
numeric_limits
<
T
>::
min
(
));
(
topk_group_value
!=
cuda
_cast
<
T
,
float
>
(
kNegInfinity
));
if
(
case_id
<
num_tokens
&&
if_proceed_next_topk
)
{
if
(
case_id
<
num_tokens
&&
if_proceed_next_topk
)
{
for
(
int
i_group
=
0
;
i_group
<
n_group
;
i_group
++
)
{
for
(
int
i_group
=
0
;
i_group
<
n_group
;
i_group
++
)
{
if
((
group_scores
[
i_group
]
>
topk_group_value
)
||
if
((
group_scores
[
i_group
]
>
topk_group_value
)
||
...
@@ -568,7 +569,7 @@ __global__ void group_idx_and_topk_idx_kernel(
...
@@ -568,7 +569,7 @@ __global__ void group_idx_and_topk_idx_kernel(
(
i
<
num_experts_per_group
)
&&
isfinite
(
cuda_cast
<
float
,
T
>
(
(
i
<
num_experts_per_group
)
&&
isfinite
(
cuda_cast
<
float
,
T
>
(
scores_with_bias
[
offset
+
i
]))
scores_with_bias
[
offset
+
i
]))
?
scores_with_bias
[
offset
+
i
]
?
scores_with_bias
[
offset
+
i
]
:
cuda
::
std
::
numeric_limits
<
T
>::
min
(
);
:
cuda
_cast
<
T
,
float
>
(
kNegInfinity
);
queue
.
add
(
candidates
,
offset
+
i
);
queue
.
add
(
candidates
,
offset
+
i
);
}
}
if
(
group_scores
[
i_group
]
==
topk_group_value
)
{
if
(
group_scores
[
i_group
]
==
topk_group_value
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment