Commit 7377e711 authored by zhangyue's avatar zhangyue
Browse files

issue/1008: adapt paged_attention_prefill

parent f46e9f65
...@@ -16,17 +16,66 @@ struct OnlineSoftmaxState { ...@@ -16,17 +16,66 @@ struct OnlineSoftmaxState {
} }
}; };
// Sum-reduce `x` across a 32-lane group.
// Shuffle path: the full sum ends up in lane 0 only (other lanes hold partials).
// Iluvatar path: every lane of the group observes the full sum.
// Call sites must only rely on lane 0's value so both paths agree.
__device__ __forceinline__ float warpReduceSum(float x) {
#if defined(ENABLE_ILUVATAR_API)
    // Iluvatar may use warp size 64; __shfl_sync(0xffffffff) only covers 32 threads.
    // Use a shared-memory tree reduce instead. NOTE: this still reduces over
    // 32-lane groups (matching the call sites' 32-lane "warp" convention); it
    // only avoids the shuffle intrinsics, it does not widen the reduction.
    //
    // Preconditions (NOT checked at runtime):
    //  - blockDim.x <= kMaxWarps * 32 (512) and a multiple of 32, otherwise
    //    _reduce_buf is indexed out of bounds / lanes read stale data.
    //  - ALL threads of the block reach this call: it contains __syncthreads(),
    //    which is undefined behavior under divergent control flow.
    constexpr int kMaxWarps = 16;
    __shared__ float _reduce_buf[kMaxWarps * 32];
    const int lane = threadIdx.x & 31;
    const int warp_id = threadIdx.x / 32;
    _reduce_buf[threadIdx.x] = x;
    __syncthreads();
    for (int offset = 16; offset > 0; offset >>= 1) {
        if (lane < offset) {
            _reduce_buf[warp_id * 32 + lane] += _reduce_buf[warp_id * 32 + lane + offset];
        }
        __syncthreads();
    }
    // Read the result into a register, then barrier once more: the static
    // __shared__ buffer is reused by the next invocation, and without this
    // barrier a subsequent call's store to _reduce_buf[threadIdx.x] can race
    // with a slower lane still reading slot warp_id * 32 from this call.
    const float result = _reduce_buf[warp_id * 32];
    __syncthreads();
    return result;
#else
    // Standard 32-lane tree reduce via register shuffles; full mask is valid
    // because warp size is 32 on this path.
    for (int offset = 16; offset > 0; offset >>= 1) {
        x += __shfl_down_sync(0xffffffff, x, offset);
    }
    return x;
#endif
}
// Broadcast `x` from `src_lane` (0..31) to every lane of the calling thread's
// 32-lane group. Drop-in replacement for __shfl_sync(0xffffffff, x, src_lane)
// on targets where the 32-bit shuffle mask cannot cover the hardware warp.
__device__ __forceinline__ float warpBroadcast(float x, int src_lane) {
#if defined(ENABLE_ILUVATAR_API)
    // Shared-memory broadcast: one slot per 32-lane group.
    // Preconditions (NOT checked at runtime):
    //  - blockDim.x <= 16 * 32 (512), otherwise warp_id indexes past _bcast_buf.
    //  - ALL threads of the block reach this call (contains __syncthreads()).
    __shared__ float _bcast_buf[16];
    const int warp_id = threadIdx.x / 32;
    if ((threadIdx.x & 31) == src_lane) {
        _bcast_buf[warp_id] = x;
    }
    __syncthreads();
    // Read into a register, then barrier before returning: the call sites
    // issue back-to-back broadcasts (e.g. alpha then beta), and without this
    // barrier the second call's store can overwrite the slot while slower
    // lanes are still reading the first value.
    const float result = _bcast_buf[warp_id];
    __syncthreads();
    return result;
#else
    return __shfl_sync(0xffffffff, x, src_lane);
#endif
}
// Max-reduce `x` across a 32-lane group.
// Shuffle path: the full max ends up in lane 0 only (other lanes hold partials).
// Iluvatar path: every lane of the group observes the full max.
// Call sites must only rely on lane 0's value so both paths agree.
__device__ __forceinline__ float warpReduceMax(float x) {
#if defined(ENABLE_ILUVATAR_API)
    // Shared-memory tree reduce over 32-lane groups; see warpReduceSum for
    // the rationale (shuffle mask cannot cover a 64-lane Iluvatar warp).
    //
    // Preconditions (NOT checked at runtime):
    //  - blockDim.x <= kMaxWarps * 32 (512) and a multiple of 32.
    //  - ALL threads of the block reach this call (contains __syncthreads()).
    constexpr int kMaxWarps = 16;
    __shared__ float _reduce_buf[kMaxWarps * 32];
    const int lane = threadIdx.x & 31;
    const int warp_id = threadIdx.x / 32;
    _reduce_buf[threadIdx.x] = x;
    __syncthreads();
    for (int offset = 16; offset > 0; offset >>= 1) {
        if (lane < offset) {
            const float other = _reduce_buf[warp_id * 32 + lane + offset];
            const float cur = _reduce_buf[warp_id * 32 + lane];
            _reduce_buf[warp_id * 32 + lane] = fmaxf(cur, other);
        }
        __syncthreads();
    }
    // Register copy + trailing barrier so a subsequent reduction reusing this
    // static __shared__ buffer cannot clobber it while lanes are still reading.
    const float result = _reduce_buf[warp_id * 32];
    __syncthreads();
    return result;
#else
    // Standard 32-lane tree reduce via register shuffles.
    for (int offset = 16; offset > 0; offset >>= 1) {
        x = fmaxf(x, __shfl_down_sync(0xffffffff, x, offset));
    }
    return x;
#endif
}
__device__ __forceinline__ unsigned int cvtaToShared(const void *ptr) { __device__ __forceinline__ unsigned int cvtaToShared(const void *ptr) {
......
#ifndef __PAGED_ATTENTION_PREFILL_KERNEL_V2_CUH__ #ifndef __PAGED_ATTENTION_PREFILL_KERNEL_V2_CUH__
#define __PAGED_ATTENTION_PREFILL_KERNEL_V2_CUH__ #define __PAGED_ATTENTION_PREFILL_KERNEL_V2_CUH__
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API)
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
...@@ -194,8 +194,8 @@ __device__ void PagedAttentionPrefillWarpKernel( ...@@ -194,8 +194,8 @@ __device__ void PagedAttentionPrefillWarpKernel(
l = l * alpha + beta; l = l * alpha + beta;
m = m_new; m = m_new;
} }
alpha = __shfl_sync(0xffffffff, alpha, 0); alpha = op::paged_attention::cuda::warpBroadcast(alpha, 0);
beta = __shfl_sync(0xffffffff, beta, 0); beta = op::paged_attention::cuda::warpBroadcast(beta, 0);
#if defined(__CUDA_ARCH__) #if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) { if constexpr (std::is_same_v<Tdata, half>) {
...@@ -233,7 +233,7 @@ __device__ void PagedAttentionPrefillWarpKernel( ...@@ -233,7 +233,7 @@ __device__ void PagedAttentionPrefillWarpKernel(
if (lane == 0) { if (lane == 0) {
inv_l = 1.0f / (l + 1e-6f); inv_l = 1.0f / (l + 1e-6f);
} }
inv_l = __shfl_sync(0xffffffff, inv_l, 0); inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0);
#pragma unroll #pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) { for (int i = 0; i < DIMS_PER_THREAD; ++i) {
...@@ -411,8 +411,8 @@ __global__ void PagedAttentionPrefillWarpGlobalKernel( ...@@ -411,8 +411,8 @@ __global__ void PagedAttentionPrefillWarpGlobalKernel(
l = l * alpha + beta; l = l * alpha + beta;
m = m_new; m = m_new;
} }
alpha = __shfl_sync(0xffffffff, alpha, 0); alpha = op::paged_attention::cuda::warpBroadcast(alpha, 0);
beta = __shfl_sync(0xffffffff, beta, 0); beta = op::paged_attention::cuda::warpBroadcast(beta, 0);
#if defined(__CUDA_ARCH__) #if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) { if constexpr (std::is_same_v<Tdata, half>) {
...@@ -450,7 +450,7 @@ __global__ void PagedAttentionPrefillWarpGlobalKernel( ...@@ -450,7 +450,7 @@ __global__ void PagedAttentionPrefillWarpGlobalKernel(
if (lane == 0) { if (lane == 0) {
inv_l = 1.0f / (l + 1e-6f); inv_l = 1.0f / (l + 1e-6f);
} }
inv_l = __shfl_sync(0xffffffff, inv_l, 0); inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0);
#pragma unroll #pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) { for (int i = 0; i < DIMS_PER_THREAD; ++i) {
...@@ -785,8 +785,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernel( ...@@ -785,8 +785,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernel(
l = l * alpha + beta; l = l * alpha + beta;
m = m_new; m = m_new;
} }
alpha = __shfl_sync(0xffffffff, alpha, 0); alpha = op::paged_attention::cuda::warpBroadcast(alpha, 0);
beta = __shfl_sync(0xffffffff, beta, 0); beta = op::paged_attention::cuda::warpBroadcast(beta, 0);
#if defined(__CUDA_ARCH__) #if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) { if constexpr (std::is_same_v<Tdata, half>) {
...@@ -826,7 +826,7 @@ __device__ void PagedAttentionPrefillWarpCtaKernel( ...@@ -826,7 +826,7 @@ __device__ void PagedAttentionPrefillWarpCtaKernel(
if (lane == 0) { if (lane == 0) {
inv_l = 1.0f / (l + 1e-6f); inv_l = 1.0f / (l + 1e-6f);
} }
inv_l = __shfl_sync(0xffffffff, inv_l, 0); inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0);
#pragma unroll #pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) { for (int i = 0; i < DIMS_PER_THREAD; ++i) {
...@@ -1270,7 +1270,7 @@ __device__ void PagedAttentionPrefillWarpCtaKernelPipelined( ...@@ -1270,7 +1270,7 @@ __device__ void PagedAttentionPrefillWarpCtaKernelPipelined(
if (lane == 0) { if (lane == 0) {
inv_l = 1.0f / (l + 1e-6f); inv_l = 1.0f / (l + 1e-6f);
} }
inv_l = __shfl_sync(0xffffffff, inv_l, 0); inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0);
#pragma unroll #pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) { for (int i = 0; i < DIMS_PER_THREAD; ++i) {
...@@ -1961,8 +1961,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelKOnly( ...@@ -1961,8 +1961,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelKOnly(
l = l * alpha + beta; l = l * alpha + beta;
m = m_new; m = m_new;
} }
alpha = __shfl_sync(0xffffffff, alpha, 0); alpha = op::paged_attention::cuda::warpBroadcast(alpha, 0);
beta = __shfl_sync(0xffffffff, beta, 0); beta = op::paged_attention::cuda::warpBroadcast(beta, 0);
#if defined(__CUDA_ARCH__) #if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) { if constexpr (std::is_same_v<Tdata, half>) {
...@@ -2002,7 +2002,7 @@ __device__ void PagedAttentionPrefillWarpCtaKernelKOnly( ...@@ -2002,7 +2002,7 @@ __device__ void PagedAttentionPrefillWarpCtaKernelKOnly(
if (lane == 0) { if (lane == 0) {
inv_l = 1.0f / (l + 1e-6f); inv_l = 1.0f / (l + 1e-6f);
} }
inv_l = __shfl_sync(0xffffffff, inv_l, 0); inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0);
#pragma unroll #pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) { for (int i = 0; i < DIMS_PER_THREAD; ++i) {
...@@ -2131,7 +2131,7 @@ __device__ __forceinline__ void PagedAttentionPrefillMmaScoreWriteRow( ...@@ -2131,7 +2131,7 @@ __device__ __forceinline__ void PagedAttentionPrefillMmaScoreWriteRow(
if (lane == 0) { if (lane == 0) {
inv_l = 1.0f / (l + 1e-6f); inv_l = 1.0f / (l + 1e-6f);
} }
inv_l = __shfl_sync(0xffffffff, inv_l, 0); inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0);
const int64_t q_token = q_start + static_cast<int64_t>(q_token_local); const int64_t q_token = q_start + static_cast<int64_t>(q_token_local);
half *out_ptr = out_ + q_token * o_stride + static_cast<int64_t>(head_idx) * o_head_stride; half *out_ptr = out_ + q_token * o_stride + static_cast<int64_t>(head_idx) * o_head_stride;
......
...@@ -21,6 +21,11 @@ constexpr size_t ceilDiv(size_t a, size_t b) { ...@@ -21,6 +21,11 @@ constexpr size_t ceilDiv(size_t a, size_t b) {
} }
inline const char *default_prefill_kernel(const PagedAttentionPrefillInfo &info) { inline const char *default_prefill_kernel(const PagedAttentionPrefillInfo &info) {
// Iluvatar: use warp (stable). Users can override via INFINIOP_FLASH_PREFILL_KERNEL.
#ifdef ENABLE_ILUVATAR_API
(void)info;
return "warp";
#endif
// Heuristic auto-dispatch (v0.4): // Heuristic auto-dispatch (v0.4):
// - Prefer the pipelined + tile-wise softmax kernel on FA2-compatible block_size=256. // - Prefer the pipelined + tile-wise softmax kernel on FA2-compatible block_size=256.
// - Keep a conservative fallback for other shapes / older GPUs (cp.async is a no-op below SM80). // - Keep a conservative fallback for other shapes / older GPUs (cp.async is a no-op below SM80).
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment