Unverified Commit 99a802dd authored by thatPepe's avatar thatPepe Committed by GitHub
Browse files

Merge pull request #1063 from InfiniTensor/issue/1061

issue/1061 - feat: use template to replace int64_t in paged_attention_prefill kernel for Moore GPU
parents b2660e66 7f426b26
#ifndef __PAGED_ATTENTION_PREFILL_KERNEL_CUH__
#define __PAGED_ATTENTION_PREFILL_KERNEL_CUH__
namespace op::paged_attention_prefill::cuda {
__device__ __forceinline__ size_t find_seq_id(size_t token_idx, const int64_t *cum_seq_lens_q, size_t num_seqs) {
template <typename Tindex>
__device__ __forceinline__ size_t find_seq_id(size_t token_idx, const Tindex *cum_seq_lens_q, size_t num_seqs) {
size_t low = 0, high = num_seqs - 1;
while (low <= high) {
size_t mid = (low + high) >> 1;
......@@ -48,12 +50,12 @@ __device__ __forceinline__ float blockReduceSum(float val) {
return shared[0];
}
template <typename Tdata, typename Tcompute>
template <typename Tindex, typename Tdata, typename Tcompute>
__global__ void pagedAttentionPrefillKernel(
Tdata *out_, const Tdata *q_, const Tdata *k_cache_, const Tdata *v_cache_,
const int64_t *block_tables_,
const int64_t *total_kv_lens_,
const int64_t *cum_seq_lens_q_,
const Tindex *block_tables_,
const Tindex *total_kv_lens_,
const Tindex *cum_seq_lens_q_,
const float *alibi_slopes_,
const size_t num_heads, const size_t num_kv_heads, const float scale,
const size_t max_num_blocks_per_seq, const size_t block_size,
......@@ -75,7 +77,7 @@ __global__ void pagedAttentionPrefillKernel(
__shared__ float sh_w;
__shared__ float sh_inv_l;
if (dim_idx == 0) {
sh_seq_idx = find_seq_id(global_token_idx, cum_seq_lens_q_, num_seqs);
sh_seq_idx = find_seq_id<Tindex>(global_token_idx, cum_seq_lens_q_, num_seqs);
const size_t q_token_idx = global_token_idx - static_cast<size_t>(cum_seq_lens_q_[sh_seq_idx]);
const size_t total_kv_len = static_cast<size_t>(total_kv_lens_[sh_seq_idx]);
const size_t q_len = static_cast<size_t>(cum_seq_lens_q_[sh_seq_idx + 1] - cum_seq_lens_q_[sh_seq_idx]);
......@@ -90,7 +92,7 @@ __global__ void pagedAttentionPrefillKernel(
const size_t kv_head_idx = sh_kv_head_idx;
const Tdata *q_vec = q_ + global_token_idx * q_stride + head_idx * q_head_stride;
Tdata *out_ptr = out_ + global_token_idx * num_heads * head_size + head_idx * head_size;
const int64_t *block_table = block_tables_ + seq_idx * max_num_blocks_per_seq;
const Tindex *block_table = block_tables_ + seq_idx * max_num_blocks_per_seq;
const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
const float qv = static_cast<float>(q_vec[dim_idx]);
Tcompute acc = 0.0f;
......
......@@ -8,12 +8,12 @@
#include "paged_attention_prefill_kernel.h"
#include "paged_attention_prefill_moore.h"
template <typename Tdata, typename Tcompute>
template <typename Tindex, typename Tdata, typename Tcompute>
infiniStatus_t launchPagedAttentionPrefill(
Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
const int64_t *block_tables,
const int64_t *seq_lens,
const int64_t *cum_seq_lens_q,
const Tindex *block_tables,
const Tindex *seq_lens,
const Tindex *cum_seq_lens_q,
const float *alibi_slopes,
const size_t num_heads,
const size_t num_seqs,
......@@ -36,7 +36,7 @@ infiniStatus_t launchPagedAttentionPrefill(
dim3 grid(total_q_tokens, num_heads);
dim3 block(head_size);
op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel<Tdata, Tcompute>
op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel<Tindex, Tdata, Tcompute>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache,
block_tables, seq_lens, cum_seq_lens_q, alibi_slopes,
......@@ -99,10 +99,10 @@ infiniStatus_t Descriptor::calculate(
musaStream_t stream = (musaStream_t)stream_;
#define LAUNCH_KERNEL(Tdata, Tcompute) \
launchPagedAttentionPrefill<Tdata, Tcompute>( \
#define DISPATCH_KERNEL(Tindex, Tdata, Tcompute) \
return launchPagedAttentionPrefill<Tindex, Tdata, Tcompute>( \
(Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache, \
(const int64_t *)block_tables, (const int64_t *)seq_lens, (const int64_t *)cum_seq_lens_q, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(seq_lens), static_cast<const Tindex *>(cum_seq_lens_q), \
(const float *)alibi_slopes, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, \
_info.scale, _info.max_num_blocks_per_seq, \
......@@ -112,12 +112,23 @@ infiniStatus_t Descriptor::calculate(
_info.q_stride, _info.q_head_stride, \
stream)
if (_info.dtype == INFINI_DTYPE_F16) {
return LAUNCH_KERNEL(half, float);
} else if (_info.dtype == INFINI_DTYPE_BF16) {
return LAUNCH_KERNEL(__mt_bfloat16, float);
} else if (_info.dtype == INFINI_DTYPE_F32) {
return LAUNCH_KERNEL(float, float);
#define DISPATCH_INDEX(Tindex) \
do { \
if (_info.dtype == INFINI_DTYPE_F16) { \
DISPATCH_KERNEL(Tindex, half, float); \
} \
if (_info.dtype == INFINI_DTYPE_BF16) { \
DISPATCH_KERNEL(Tindex, __nv_bfloat16, float); \
} \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
} while (false)
if (_info.index_dtype == INFINI_DTYPE_I64){
DISPATCH_INDEX(int64_t);
} else if (_info.index_dtype == INFINI_DTYPE_I32){
DISPATCH_INDEX(int32_t);
} else if (_info.index_dtype == INFINI_DTYPE_U32){
DISPATCH_INDEX(uint32_t);
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment