Commit 68026bd1 authored by zhangyue's avatar zhangyue
Browse files

issue/1008: use warpBroadcast api

parent 3d54ce8c
......@@ -194,13 +194,8 @@ __device__ void PagedAttentionPrefillWarpKernel(
l = l * alpha + beta;
m = m_new;
}
#ifdef ENABLE_ILUVATAR_API
alpha = op::paged_attention::cuda::warpBroadcast(alpha, 0);
beta = op::paged_attention::cuda::warpBroadcast(beta, 0);
#else
alpha = __shfl_sync(0xffffffff, alpha, 0);
beta = __shfl_sync(0xffffffff, beta, 0);
#endif
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
......@@ -238,11 +233,7 @@ __device__ void PagedAttentionPrefillWarpKernel(
if (lane == 0) {
inv_l = 1.0f / (l + 1e-6f);
}
#ifdef ENABLE_ILUVATAR_API
inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0);
#else
inv_l = __shfl_sync(0xffffffff, inv_l, 0);
#endif
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
......@@ -420,13 +411,8 @@ __global__ void PagedAttentionPrefillWarpGlobalKernel(
l = l * alpha + beta;
m = m_new;
}
#ifdef ENABLE_ILUVATAR_API
alpha = op::paged_attention::cuda::warpBroadcast(alpha, 0);
beta = op::paged_attention::cuda::warpBroadcast(beta, 0);
#else
alpha = __shfl_sync(0xffffffff, alpha, 0);
beta = __shfl_sync(0xffffffff, beta, 0);
#endif
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
......@@ -803,13 +789,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernel(
l = l * alpha + beta;
m = m_new;
}
#ifdef ENABLE_ILUVATAR_API
alpha = op::paged_attention::cuda::warpBroadcast(alpha, 0);
beta = op::paged_attention::cuda::warpBroadcast(beta, 0);
#else
alpha = __shfl_sync(0xffffffff, alpha, 0);
beta = __shfl_sync(0xffffffff, beta, 0);
#endif
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
......@@ -849,11 +830,7 @@ __device__ void PagedAttentionPrefillWarpCtaKernel(
if (lane == 0) {
inv_l = 1.0f / (l + 1e-6f);
}
#ifdef ENABLE_ILUVATAR_API
inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0);
#else
inv_l = __shfl_sync(0xffffffff, inv_l, 0);
#endif
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
......@@ -1297,11 +1274,7 @@ __device__ void PagedAttentionPrefillWarpCtaKernelPipelined(
if (lane == 0) {
inv_l = 1.0f / (l + 1e-6f);
}
#ifdef ENABLE_ILUVATAR_API
inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0);
#else
inv_l = __shfl_sync(0xffffffff, inv_l, 0);
#endif
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
......@@ -1992,13 +1965,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelKOnly(
l = l * alpha + beta;
m = m_new;
}
#ifdef ENABLE_ILUVATAR_API
alpha = op::paged_attention::cuda::warpBroadcast(alpha, 0);
beta = op::paged_attention::cuda::warpBroadcast(beta, 0);
#else
alpha = __shfl_sync(0xffffffff, alpha, 0);
beta = __shfl_sync(0xffffffff, beta, 0);
#endif
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
......@@ -2038,11 +2006,7 @@ __device__ void PagedAttentionPrefillWarpCtaKernelKOnly(
if (lane == 0) {
inv_l = 1.0f / (l + 1e-6f);
}
#ifdef ENABLE_ILUVATAR_API
inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0);
#else
inv_l = __shfl_sync(0xffffffff, inv_l, 0);
#endif
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
......@@ -2171,11 +2135,7 @@ __device__ __forceinline__ void PagedAttentionPrefillMmaScoreWriteRow(
if (lane == 0) {
inv_l = 1.0f / (l + 1e-6f);
}
#ifdef ENABLE_ILUVATAR_API
inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0);
#else
inv_l = __shfl_sync(0xffffffff, inv_l, 0);
#endif
const int64_t q_token = q_start + static_cast<int64_t>(q_token_local);
half *out_ptr = out_ + q_token * o_stride + static_cast<int64_t>(head_idx) * o_head_stride;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment