Commit 68026bd1 authored by zhangyue's avatar zhangyue
Browse files

issue/1008: use warpBroadcast api

parent 3d54ce8c
...@@ -194,13 +194,8 @@ __device__ void PagedAttentionPrefillWarpKernel( ...@@ -194,13 +194,8 @@ __device__ void PagedAttentionPrefillWarpKernel(
l = l * alpha + beta; l = l * alpha + beta;
m = m_new; m = m_new;
} }
#ifdef ENABLE_ILUVATAR_API
alpha = op::paged_attention::cuda::warpBroadcast(alpha, 0); alpha = op::paged_attention::cuda::warpBroadcast(alpha, 0);
beta = op::paged_attention::cuda::warpBroadcast(beta, 0); beta = op::paged_attention::cuda::warpBroadcast(beta, 0);
#else
alpha = __shfl_sync(0xffffffff, alpha, 0);
beta = __shfl_sync(0xffffffff, beta, 0);
#endif
#if defined(__CUDA_ARCH__) #if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) { if constexpr (std::is_same_v<Tdata, half>) {
...@@ -238,11 +233,7 @@ __device__ void PagedAttentionPrefillWarpKernel( ...@@ -238,11 +233,7 @@ __device__ void PagedAttentionPrefillWarpKernel(
if (lane == 0) { if (lane == 0) {
inv_l = 1.0f / (l + 1e-6f); inv_l = 1.0f / (l + 1e-6f);
} }
#ifdef ENABLE_ILUVATAR_API
inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0); inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0);
#else
inv_l = __shfl_sync(0xffffffff, inv_l, 0);
#endif
#pragma unroll #pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) { for (int i = 0; i < DIMS_PER_THREAD; ++i) {
...@@ -420,13 +411,8 @@ __global__ void PagedAttentionPrefillWarpGlobalKernel( ...@@ -420,13 +411,8 @@ __global__ void PagedAttentionPrefillWarpGlobalKernel(
l = l * alpha + beta; l = l * alpha + beta;
m = m_new; m = m_new;
} }
#ifdef ENABLE_ILUVATAR_API
alpha = op::paged_attention::cuda::warpBroadcast(alpha, 0); alpha = op::paged_attention::cuda::warpBroadcast(alpha, 0);
beta = op::paged_attention::cuda::warpBroadcast(beta, 0); beta = op::paged_attention::cuda::warpBroadcast(beta, 0);
#else
alpha = __shfl_sync(0xffffffff, alpha, 0);
beta = __shfl_sync(0xffffffff, beta, 0);
#endif
#if defined(__CUDA_ARCH__) #if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) { if constexpr (std::is_same_v<Tdata, half>) {
...@@ -803,13 +789,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernel( ...@@ -803,13 +789,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernel(
l = l * alpha + beta; l = l * alpha + beta;
m = m_new; m = m_new;
} }
#ifdef ENABLE_ILUVATAR_API
alpha = op::paged_attention::cuda::warpBroadcast(alpha, 0); alpha = op::paged_attention::cuda::warpBroadcast(alpha, 0);
beta = op::paged_attention::cuda::warpBroadcast(beta, 0); beta = op::paged_attention::cuda::warpBroadcast(beta, 0);
#else
alpha = __shfl_sync(0xffffffff, alpha, 0);
beta = __shfl_sync(0xffffffff, beta, 0);
#endif
#if defined(__CUDA_ARCH__) #if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) { if constexpr (std::is_same_v<Tdata, half>) {
...@@ -849,11 +830,7 @@ __device__ void PagedAttentionPrefillWarpCtaKernel( ...@@ -849,11 +830,7 @@ __device__ void PagedAttentionPrefillWarpCtaKernel(
if (lane == 0) { if (lane == 0) {
inv_l = 1.0f / (l + 1e-6f); inv_l = 1.0f / (l + 1e-6f);
} }
#ifdef ENABLE_ILUVATAR_API
inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0); inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0);
#else
inv_l = __shfl_sync(0xffffffff, inv_l, 0);
#endif
#pragma unroll #pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) { for (int i = 0; i < DIMS_PER_THREAD; ++i) {
...@@ -1297,11 +1274,7 @@ __device__ void PagedAttentionPrefillWarpCtaKernelPipelined( ...@@ -1297,11 +1274,7 @@ __device__ void PagedAttentionPrefillWarpCtaKernelPipelined(
if (lane == 0) { if (lane == 0) {
inv_l = 1.0f / (l + 1e-6f); inv_l = 1.0f / (l + 1e-6f);
} }
#ifdef ENABLE_ILUVATAR_API
inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0); inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0);
#else
inv_l = __shfl_sync(0xffffffff, inv_l, 0);
#endif
#pragma unroll #pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) { for (int i = 0; i < DIMS_PER_THREAD; ++i) {
...@@ -1992,13 +1965,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelKOnly( ...@@ -1992,13 +1965,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelKOnly(
l = l * alpha + beta; l = l * alpha + beta;
m = m_new; m = m_new;
} }
#ifdef ENABLE_ILUVATAR_API
alpha = op::paged_attention::cuda::warpBroadcast(alpha, 0); alpha = op::paged_attention::cuda::warpBroadcast(alpha, 0);
beta = op::paged_attention::cuda::warpBroadcast(beta, 0); beta = op::paged_attention::cuda::warpBroadcast(beta, 0);
#else
alpha = __shfl_sync(0xffffffff, alpha, 0);
beta = __shfl_sync(0xffffffff, beta, 0);
#endif
#if defined(__CUDA_ARCH__) #if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) { if constexpr (std::is_same_v<Tdata, half>) {
...@@ -2038,11 +2006,7 @@ __device__ void PagedAttentionPrefillWarpCtaKernelKOnly( ...@@ -2038,11 +2006,7 @@ __device__ void PagedAttentionPrefillWarpCtaKernelKOnly(
if (lane == 0) { if (lane == 0) {
inv_l = 1.0f / (l + 1e-6f); inv_l = 1.0f / (l + 1e-6f);
} }
#ifdef ENABLE_ILUVATAR_API
inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0); inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0);
#else
inv_l = __shfl_sync(0xffffffff, inv_l, 0);
#endif
#pragma unroll #pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) { for (int i = 0; i < DIMS_PER_THREAD; ++i) {
...@@ -2171,11 +2135,7 @@ __device__ __forceinline__ void PagedAttentionPrefillMmaScoreWriteRow( ...@@ -2171,11 +2135,7 @@ __device__ __forceinline__ void PagedAttentionPrefillMmaScoreWriteRow(
if (lane == 0) { if (lane == 0) {
inv_l = 1.0f / (l + 1e-6f); inv_l = 1.0f / (l + 1e-6f);
} }
#ifdef ENABLE_ILUVATAR_API
inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0); inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0);
#else
inv_l = __shfl_sync(0xffffffff, inv_l, 0);
#endif
const int64_t q_token = q_start + static_cast<int64_t>(q_token_local); const int64_t q_token = q_start + static_cast<int64_t>(q_token_local);
half *out_ptr = out_ + q_token * o_stride + static_cast<int64_t>(head_idx) * o_head_stride; half *out_ptr = out_ + q_token * o_stride + static_cast<int64_t>(head_idx) * o_head_stride;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment