Unverified commit 61249b17, authored by Wentao Ye, committed by GitHub
Browse files

[Refactor] Remove useless syncwarp (#30510)


Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent c817b141
...@@ -481,8 +481,6 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias, ...@@ -481,8 +481,6 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
largest = value; largest = value;
} }
} }
__syncwarp(); // Ensure all threads have valid data before reduction
// Get the top2 warpwise // Get the top2 warpwise
T max1 = cg::reduce(tile, largest, cg::greater<T>()); T max1 = cg::reduce(tile, largest, cg::greater<T>());
...@@ -589,7 +587,6 @@ __global__ void group_idx_and_topk_idx_kernel( ...@@ -589,7 +587,6 @@ __global__ void group_idx_and_topk_idx_kernel(
int pre_count_equal_to_top_value = 0; int pre_count_equal_to_top_value = 0;
// Use loop to find the largset top_group // Use loop to find the largset top_group
while (count_equal_to_top_value < target_num_min) { while (count_equal_to_top_value < target_num_min) {
__syncwarp(); // Ensure all threads have valid data before reduction
topk_group_value = cg::reduce(tile, value, cg::greater<T>()); topk_group_value = cg::reduce(tile, value, cg::greater<T>());
if (value == topk_group_value) { if (value == topk_group_value) {
value = neg_inf<T>(); value = neg_inf<T>();
...@@ -644,10 +641,8 @@ __global__ void group_idx_and_topk_idx_kernel( ...@@ -644,10 +641,8 @@ __global__ void group_idx_and_topk_idx_kernel(
} }
} }
queue.done(); queue.done();
__syncwarp();
// Get the topk_idx // Get the topk_idx
queue.dumpIdx(s_topk_idx); queue.dumpIdx(s_topk_idx);
__syncwarp();
} }
// Load the valid score value // Load the valid score value
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment