"docs/vscode:/vscode.git/clone" did not exist on "3fd02bda51ee7cf07e0375994ac1f34b6d1b981b"
Unverified Commit 527821d1 authored by Ming Yang's avatar Ming Yang Committed by GitHub
Browse files

Use macro guard CUDA functions for back compatibility in grouped_topk_kernel.cu (#25346)


Signed-off-by: default avatarMing Yang <minos.future@gmail.com>
Signed-off-by: default avatarRahul Tuli <rtuli@redhat.com>
Co-authored-by: default avatarRahul Tuli <rtuli@redhat.com>
Co-authored-by: default avatarClaude <noreply@anthropic.com>
Co-authored-by: default avatarWentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: default avatarLu Fang <30275821+houseroad@users.noreply.github.com>
Co-authored-by: default avatarYe (Charlotte) Qi <yeq@meta.com>
parent 846197f5
......@@ -418,6 +418,15 @@ __device__ inline T neg_inf() {
return cuda_cast<T, float>(-cuda::std::numeric_limits<float>::infinity());
}
template <typename T>
__device__ inline bool is_finite(const T val) {
#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
return cuda::std::isfinite(val);
#else
return isfinite(cuda_cast<float, T>(val));
#endif
}
template <typename T>
__device__ void topk_with_k2(T* output, T const* input,
cg::thread_block_tile<32> const& tile,
......@@ -533,7 +542,7 @@ __global__ void group_idx_and_topk_idx_kernel(
// calculate group_idx
int32_t target_num_min = WARP_SIZE - n_group + topk_group;
// The check is necessary to avoid abnormal input
if (lane_id < n_group && cuda::std::isfinite(group_scores[lane_id])) {
if (lane_id < n_group && is_finite(group_scores[lane_id])) {
value = group_scores[lane_id];
}
......@@ -568,11 +577,10 @@ __global__ void group_idx_and_topk_idx_kernel(
int32_t offset = i_group * num_experts_per_group;
for (int32_t i = lane_id; i < align_num_experts_per_group;
i += WARP_SIZE) {
T candidates =
(i < num_experts_per_group) &&
cuda::std::isfinite(scores_with_bias[offset + i])
? scores_with_bias[offset + i]
: neg_inf<T>();
T candidates = (i < num_experts_per_group) &&
is_finite(scores_with_bias[offset + i])
? scores_with_bias[offset + i]
: neg_inf<T>();
queue.add(candidates, offset + i);
}
if (group_scores[i_group] == topk_group_value) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment