Commit 7f426b26 authored by zhushuang

issue/1061 - feat: use a template parameter to replace int64_t in the paged_attention_prefill kernel for the Moore GPU

parent e60985dc
 #ifndef __PAGED_ATTENTION_PREFILL_KERNEL_CUH__
 #define __PAGED_ATTENTION_PREFILL_KERNEL_CUH__
 namespace op::paged_attention_prefill::cuda {
-__device__ __forceinline__ size_t find_seq_id(size_t token_idx, const int64_t *cum_seq_lens_q, size_t num_seqs) {
+template <typename Tindex>
+__device__ __forceinline__ size_t find_seq_id(size_t token_idx, const Tindex *cum_seq_lens_q, size_t num_seqs) {
     size_t low = 0, high = num_seqs - 1;
     while (low <= high) {
         size_t mid = (low + high) >> 1;
@@ -48,12 +50,12 @@ __device__ __forceinline__ float blockReduceSum(float val) {
     return shared[0];
 }
 
-template <typename Tdata, typename Tcompute>
+template <typename Tindex, typename Tdata, typename Tcompute>
 __global__ void pagedAttentionPrefillKernel(
     Tdata *out_, const Tdata *q_, const Tdata *k_cache_, const Tdata *v_cache_,
-    const int64_t *block_tables_,
-    const int64_t *total_kv_lens_,
-    const int64_t *cum_seq_lens_q_,
+    const Tindex *block_tables_,
+    const Tindex *total_kv_lens_,
+    const Tindex *cum_seq_lens_q_,
     const float *alibi_slopes_,
     const size_t num_heads, const size_t num_kv_heads, const float scale,
     const size_t max_num_blocks_per_seq, const size_t block_size,
@@ -75,7 +77,7 @@ __global__ void pagedAttentionPrefillKernel(
     __shared__ float sh_w;
     __shared__ float sh_inv_l;
     if (dim_idx == 0) {
-        sh_seq_idx = find_seq_id(global_token_idx, cum_seq_lens_q_, num_seqs);
+        sh_seq_idx = find_seq_id<Tindex>(global_token_idx, cum_seq_lens_q_, num_seqs);
         const size_t q_token_idx = global_token_idx - static_cast<size_t>(cum_seq_lens_q_[sh_seq_idx]);
         const size_t total_kv_len = static_cast<size_t>(total_kv_lens_[sh_seq_idx]);
         const size_t q_len = static_cast<size_t>(cum_seq_lens_q_[sh_seq_idx + 1] - cum_seq_lens_q_[sh_seq_idx]);
@@ -90,7 +92,7 @@ __global__ void pagedAttentionPrefillKernel(
     const size_t kv_head_idx = sh_kv_head_idx;
     const Tdata *q_vec = q_ + global_token_idx * q_stride + head_idx * q_head_stride;
     Tdata *out_ptr = out_ + global_token_idx * num_heads * head_size + head_idx * head_size;
-    const int64_t *block_table = block_tables_ + seq_idx * max_num_blocks_per_seq;
+    const Tindex *block_table = block_tables_ + seq_idx * max_num_blocks_per_seq;
     const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
     const float qv = static_cast<float>(q_vec[dim_idx]);
     Tcompute acc = 0.0f;
...
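For context on the first hunk: find_seq_id does a binary search over the cumulative query lengths (cum_seq_lens_q) to map a flattened token index back to the sequence it belongs to, and templating it on Tindex lets the same search walk int64_t, int32_t, or uint32_t tables. Below is a minimal host-side model of that lookup; the kernel body is truncated in the diff, so the loop shown here is an illustrative "largest i with cum[i] <= token_idx" variant, not the exact device code.

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Host-side model of the templated lookup (illustrative, not the device code):
// returns the largest i in [0, num_seqs) with cum_seq_lens_q[i] <= token_idx.
template <typename Tindex>
size_t find_seq_id_ref(size_t token_idx, const Tindex *cum_seq_lens_q, size_t num_seqs) {
    size_t low = 0, high = num_seqs - 1;
    while (low < high) {
        size_t mid = (low + high + 1) >> 1;
        if (static_cast<size_t>(cum_seq_lens_q[mid]) <= token_idx) {
            low = mid; // token starts at or after sequence mid
        } else {
            high = mid - 1;
        }
    }
    return low;
}

int main() {
    // Three sequences with query lengths 3, 4, 2 -> exclusive prefix sums
    // (num_seqs + 1 entries, matching the kernel's cum_seq_lens_q_[i + 1] access).
    const int64_t cum64[] = {0, 3, 7, 9};
    const int32_t cum32[] = {0, 3, 7, 9};
    // Flattened token 5 lies in [3, 7), i.e. sequence 1, for either index type.
    std::printf("%zu %zu\n",
                find_seq_id_ref(5, cum64, 3),
                find_seq_id_ref(5, cum32, 3)); // prints: 1 1
    return 0;
}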
@@ -8,12 +8,12 @@
 #include "paged_attention_prefill_kernel.h"
 #include "paged_attention_prefill_moore.h"
 
-template <typename Tdata, typename Tcompute>
+template <typename Tindex, typename Tdata, typename Tcompute>
 infiniStatus_t launchPagedAttentionPrefill(
     Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
-    const int64_t *block_tables,
-    const int64_t *seq_lens,
-    const int64_t *cum_seq_lens_q,
+    const Tindex *block_tables,
+    const Tindex *seq_lens,
+    const Tindex *cum_seq_lens_q,
     const float *alibi_slopes,
     const size_t num_heads,
     const size_t num_seqs,
@@ -36,7 +36,7 @@ infiniStatus_t launchPagedAttentionPrefill(
     dim3 grid(total_q_tokens, num_heads);
     dim3 block(head_size);
-    op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel<Tdata, Tcompute>
+    op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel<Tindex, Tdata, Tcompute>
         <<<grid, block, 0, stream>>>(
             out, q, k_cache, v_cache,
             block_tables, seq_lens, cum_seq_lens_q, alibi_slopes,
@@ -99,10 +99,10 @@ infiniStatus_t Descriptor::calculate(
     musaStream_t stream = (musaStream_t)stream_;
-#define LAUNCH_KERNEL(Tdata, Tcompute) \
-    launchPagedAttentionPrefill<Tdata, Tcompute>( \
+#define DISPATCH_KERNEL(Tindex, Tdata, Tcompute) \
+    return launchPagedAttentionPrefill<Tindex, Tdata, Tcompute>( \
         (Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache, \
-        (const int64_t *)block_tables, (const int64_t *)seq_lens, (const int64_t *)cum_seq_lens_q, \
+        static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(seq_lens), static_cast<const Tindex *>(cum_seq_lens_q), \
         (const float *)alibi_slopes, \
         _info.num_heads, _info.num_seqs, _info.num_kv_heads, \
         _info.scale, _info.max_num_blocks_per_seq, \
@@ -112,12 +112,26 @@ infiniStatus_t Descriptor::calculate(
         _info.q_stride, _info.q_head_stride, \
         stream)
-    if (_info.dtype == INFINI_DTYPE_F16) {
-        return LAUNCH_KERNEL(half, float);
-    } else if (_info.dtype == INFINI_DTYPE_BF16) {
-        return LAUNCH_KERNEL(__mt_bfloat16, float);
-    } else if (_info.dtype == INFINI_DTYPE_F32) {
-        return LAUNCH_KERNEL(float, float);
-    }
+#define DISPATCH_INDEX(Tindex) \
+    do { \
+        if (_info.dtype == INFINI_DTYPE_F16) { \
+            DISPATCH_KERNEL(Tindex, half, float); \
+        } \
+        if (_info.dtype == INFINI_DTYPE_BF16) { \
+            DISPATCH_KERNEL(Tindex, __mt_bfloat16, float); \
+        } \
+        if (_info.dtype == INFINI_DTYPE_F32) { \
+            DISPATCH_KERNEL(Tindex, float, float); \
+        } \
+        return INFINI_STATUS_BAD_TENSOR_DTYPE; \
+    } while (false)
+    if (_info.index_dtype == INFINI_DTYPE_I64) {
+        DISPATCH_INDEX(int64_t);
+    } else if (_info.index_dtype == INFINI_DTYPE_I32) {
+        DISPATCH_INDEX(int32_t);
+    } else if (_info.index_dtype == INFINI_DTYPE_U32) {
+        DISPATCH_INDEX(uint32_t);
+    }
     return INFINI_STATUS_BAD_TENSOR_DTYPE;
...
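Taken together, the two macros form a small two-level dispatch: the outer if/else chain picks the index type from _info.index_dtype, and DISPATCH_INDEX then picks the data type from _info.dtype. As an illustrative expansion (names as in the diff, trailing arguments elided), DISPATCH_INDEX(int32_t) with an F16 tensor boils down to:

// Illustrative expansion of DISPATCH_INDEX(int32_t), F16 branch only.
do {
    if (_info.dtype == INFINI_DTYPE_F16) {
        return launchPagedAttentionPrefill<int32_t, half, float>(
            (half *)out, (const half *)q, (const half *)k_cache, (const half *)v_cache,
            static_cast<const int32_t *>(block_tables),
            static_cast<const int32_t *>(seq_lens),
            static_cast<const int32_t *>(cum_seq_lens_q),
            /* ... remaining _info fields and stream ... */);
    }
    // ... BF16 and F32 branches ...
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
} while (false);

The do { ... } while (false) wrapper is the standard trick for making a multi-statement macro behave like a single statement after an if; since every dtype branch returns via DISPATCH_KERNEL, the final return INFINI_STATUS_BAD_TENSOR_DTYPE in calculate() is reached only when _info.index_dtype matches none of the supported index types.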