Commit 6074f7b8 authored by zhushuang's avatar zhushuang
Browse files

issue/1001 - feat: add paged attention prefill for moore gpu referencing nvidia

parent 3d3a277f
#ifndef __PAGED_ATTENTION_PREFILL_KERNEL_CUH__
#define __PAGED_ATTENTION_PREFILL_KERNEL_CUH__
namespace op::paged_attention_prefill::cuda {
__device__ __forceinline__ size_t find_seq_id(size_t token_idx, const int64_t *cum_seq_lens_q, size_t num_seqs) {
size_t low = 0, high = num_seqs - 1;
while (low <= high) {
size_t mid = (low + high) >> 1;
if (token_idx >= (size_t)cum_seq_lens_q[mid] && token_idx < (size_t)cum_seq_lens_q[mid + 1]) {
return mid;
} else if (token_idx < (size_t)cum_seq_lens_q[mid]) {
high = mid - 1;
} else {
low = mid + 1;
}
}
return 0;
}
// Warp-level sum reduction with an explicit active mask (safe for partial warps).
__device__ __forceinline__ float warpReduceSum(float val, unsigned mask) {
for (int offset = 16; offset > 0; offset >>= 1) {
val += __shfl_down_sync(mask, val, offset);
}
return val;
}
// Block-level sum reduction. Returns the sum to all threads in the block.
// Supports blockDim.x up to 1024.
__device__ __forceinline__ float blockReduceSum(float val) {
__shared__ float shared[32]; // max 32 warps per block
const int lane = threadIdx.x & 31;
const int wid = threadIdx.x >> 5;
const unsigned mask = __activemask();
val = warpReduceSum(val, mask);
if (lane == 0) {
shared[wid] = val;
}
__syncthreads();
const int num_warps = (blockDim.x + 31) >> 5;
float sum = 0.0f;
if (wid == 0) {
sum = (lane < num_warps) ? shared[lane] : 0.0f;
const unsigned mask0 = (num_warps >= 32) ? 0xffffffffu : ((1u << num_warps) - 1u);
sum = warpReduceSum(sum, mask0);
if (lane == 0) {
shared[0] = sum;
}
}
__syncthreads();
return shared[0];
}
template <typename Tdata, typename Tcompute>
__global__ void pagedAttentionPrefillKernel(
Tdata *out_, const Tdata *q_, const Tdata *k_cache_, const Tdata *v_cache_,
const int64_t *block_tables_,
const int64_t *total_kv_lens_,
const int64_t *cum_seq_lens_q_,
const float *alibi_slopes_,
const size_t num_heads, const size_t num_kv_heads, const float scale,
const size_t max_num_blocks_per_seq, const size_t block_size,
const ptrdiff_t kv_block_stride, const ptrdiff_t kv_head_stride,
const ptrdiff_t q_stride, const ptrdiff_t q_head_stride,
const size_t head_size,
const size_t num_seqs) {
// Grid : x -> token, y -> head
const size_t global_token_idx = blockIdx.x;
const size_t head_idx = blockIdx.y;
const size_t dim_idx = threadIdx.x;
if (dim_idx >= head_size) {
return;
}
__shared__ size_t sh_seq_idx;
__shared__ size_t sh_causal_limit;
__shared__ size_t sh_kv_head_idx;
__shared__ float sh_scale_acc;
__shared__ float sh_w;
__shared__ float sh_inv_l;
if (dim_idx == 0) {
sh_seq_idx = find_seq_id(global_token_idx, cum_seq_lens_q_, num_seqs);
const size_t q_token_idx = global_token_idx - static_cast<size_t>(cum_seq_lens_q_[sh_seq_idx]);
const size_t total_kv_len = static_cast<size_t>(total_kv_lens_[sh_seq_idx]);
const size_t q_len = static_cast<size_t>(cum_seq_lens_q_[sh_seq_idx + 1] - cum_seq_lens_q_[sh_seq_idx]);
const size_t history_len = total_kv_len - q_len;
sh_causal_limit = history_len + q_token_idx;
const size_t num_queries_per_kv = num_heads / num_kv_heads;
sh_kv_head_idx = head_idx / num_queries_per_kv;
}
__syncthreads();
const size_t seq_idx = sh_seq_idx;
const size_t causal_limit = sh_causal_limit;
const size_t kv_head_idx = sh_kv_head_idx;
const Tdata *q_vec = q_ + global_token_idx * q_stride + head_idx * q_head_stride;
Tdata *out_ptr = out_ + global_token_idx * num_heads * head_size + head_idx * head_size;
const int64_t *block_table = block_tables_ + seq_idx * max_num_blocks_per_seq;
const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
const float qv = static_cast<float>(q_vec[dim_idx]);
Tcompute acc = 0.0f;
float m = -FLT_MAX;
float l = 0.0f;
for (size_t t = 0; t <= causal_limit; ++t) {
const size_t b_idx = t / block_size;
const size_t t_off = t % block_size;
const ptrdiff_t physical_block_id = block_table[b_idx];
const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
const float dot = blockReduceSum(qv * static_cast<float>(k_vec[dim_idx]));
if (dim_idx == 0) {
float score = dot * static_cast<float>(scale);
if (alibi_slope != 0.0f) {
score += alibi_slope * static_cast<float>(t - causal_limit);
}
const float m_new = fmaxf(m, score);
const float scale_acc = expf(m - m_new);
const float w = expf(score - m_new);
l = l * scale_acc + w;
m = m_new;
sh_scale_acc = scale_acc;
sh_w = w;
}
__syncthreads();
const float scale_acc = sh_scale_acc;
const float w = sh_w;
const Tdata *v_vec = v_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
acc = acc * static_cast<Tcompute>(scale_acc) + static_cast<Tcompute>(w) * static_cast<Tcompute>(v_vec[dim_idx]);
__syncthreads();
}
if (dim_idx == 0) {
sh_inv_l = 1.0f / (l + 1e-6f);
}
__syncthreads();
out_ptr[dim_idx] = static_cast<Tdata>(acc * static_cast<Tcompute>(sh_inv_l));
}
} // namespace op::paged_attention_prefill::cuda
#endif
#ifndef __PAGED_ATTENTION_PREFILL_MOORE_H__
#define __PAGED_ATTENTION_PREFILL_MOORE_H__
#include "../paged_attention_prefill.h"
DESCRIPTOR(moore)
#endif // __PAGED_ATTENTION_PREFILL_MOORE_H__
#include <musa_fp16.h>
#include <float.h>
#include <math.h>
#include <stdint.h>
#include "../../../devices/moore/moore_common.h"
#include "../../../devices/moore/moore_kernel_common.h"
#include "paged_attention_prefill_kernel.h"
#include "paged_attention_prefill_moore.h"
template <typename Tdata, typename Tcompute>
infiniStatus_t launchPagedAttentionPrefill(
Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
const int64_t *block_tables,
const int64_t *seq_lens,
const int64_t *cum_seq_lens_q,
const float *alibi_slopes,
const size_t num_heads,
const size_t num_seqs,
const size_t num_kv_heads,
const float scale,
const size_t max_num_blocks_per_seq,
const size_t page_block_size,
const size_t total_q_tokens,
const size_t head_size,
const ptrdiff_t k_batch_stride,
const ptrdiff_t k_head_stride,
const ptrdiff_t q_stride,
const ptrdiff_t q_head_stride,
musaStream_t stream) {
if (total_q_tokens == 0 || num_heads == 0) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
dim3 grid(total_q_tokens, num_heads);
dim3 block(head_size);
op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel<Tdata, Tcompute>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache,
block_tables, seq_lens, cum_seq_lens_q, alibi_slopes,
num_heads, num_kv_heads, scale,
max_num_blocks_per_seq, page_block_size,
k_batch_stride, k_head_stride,
q_stride, q_head_stride,
head_size,
num_seqs);
return INFINI_STATUS_SUCCESS;
}
namespace op::paged_attention_prefill::moore {
struct Descriptor::Opaque {
std::shared_ptr<device::moore::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t cum_seq_lens_q_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) {
auto info = PagedAttentionPrefillInfo::create(
out_desc, q_desc, k_cache_desc, v_cache_desc,
block_tables_desc, seq_lens_desc,
cum_seq_lens_q_desc,
alibi_slopes_desc, scale);
CHECK_RESULT(info);
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::moore::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
const void *block_tables,
const void *seq_lens,
const void *cum_seq_lens_q,
const void *alibi_slopes,
void *stream_) const {
musaStream_t stream = (musaStream_t)stream_;
#define LAUNCH_KERNEL(Tdata, Tcompute) \
launchPagedAttentionPrefill<Tdata, Tcompute>( \
(Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache, \
(const int64_t *)block_tables, (const int64_t *)seq_lens, (const int64_t *)cum_seq_lens_q, \
(const float *)alibi_slopes, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, \
_info.scale, _info.max_num_blocks_per_seq, \
_info.page_block_size, _info.total_q_tokens, \
_info.head_size, \
_info.k_batch_stride, _info.k_head_stride, \
_info.q_stride, _info.q_head_stride, \
stream)
if (_info.dtype == INFINI_DTYPE_F16) {
return LAUNCH_KERNEL(half, float);
} else if (_info.dtype == INFINI_DTYPE_BF16) {
return LAUNCH_KERNEL(__mt_bfloat16, float);
} else if (_info.dtype == INFINI_DTYPE_F32) {
return LAUNCH_KERNEL(float, float);
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} // namespace op::paged_attention_prefill::moore
......@@ -8,6 +8,9 @@
#ifdef ENABLE_METAX_API
#include "metax/paged_attention_prefill_metax.h"
#endif
#ifdef ENABLE_MOORE_API
#include "moore/paged_attention_prefill_moore.h"
#endif
__C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
infiniopHandle_t handle,
......@@ -44,6 +47,9 @@ __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
#endif
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -71,6 +77,9 @@ __C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
#endif
#ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -105,6 +114,9 @@ __C infiniStatus_t infiniopPagedAttentionPrefill(
#endif
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -131,6 +143,9 @@ __C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
#endif
#ifdef ENABLE_ILUVATAR_API
DESTROY(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_MOORE_API
DESTROY(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment