Unverified Commit 8d99a8f5 authored by pengcheng888's avatar pengcheng888 Committed by GitHub
Browse files

Merge pull request #1051 from InfiniTensor/issue/1050

issue/1050 - fix paged caching and paged prefill on metax
parents 3e5cad10 9a0f2505
...@@ -46,8 +46,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128Warp( ...@@ -46,8 +46,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128Warp(
const Tdata *k_cache, const Tdata *k_cache,
const Tdata *v_cache, const Tdata *v_cache,
const Tindex *block_tables, const Tindex *block_tables,
const int64_t *total_kv_lens, const Tindex *total_kv_lens,
const int64_t *cu_seqlens_q, const Tindex *cu_seqlens_q,
const float *alibi_slopes, const float *alibi_slopes,
size_t num_kv_heads, size_t num_kv_heads,
float scale, float scale,
...@@ -79,8 +79,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64Warp( ...@@ -79,8 +79,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64Warp(
const Tdata *k_cache, const Tdata *k_cache,
const Tdata *v_cache, const Tdata *v_cache,
const Tindex *block_tables, const Tindex *block_tables,
const int64_t *total_kv_lens, const Tindex *total_kv_lens,
const int64_t *cu_seqlens_q, const Tindex *cu_seqlens_q,
const float *alibi_slopes, const float *alibi_slopes,
size_t num_kv_heads, size_t num_kv_heads,
float scale, float scale,
...@@ -112,8 +112,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta( ...@@ -112,8 +112,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta(
const Tdata *k_cache, const Tdata *k_cache,
const Tdata *v_cache, const Tdata *v_cache,
const Tindex *block_tables, const Tindex *block_tables,
const int64_t *total_kv_lens, const Tindex *total_kv_lens,
const int64_t *cu_seqlens_q, const Tindex *cu_seqlens_q,
const float *alibi_slopes, const float *alibi_slopes,
size_t num_kv_heads, size_t num_kv_heads,
float scale, float scale,
...@@ -148,8 +148,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta( ...@@ -148,8 +148,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta(
const Tdata *k_cache, const Tdata *k_cache,
const Tdata *v_cache, const Tdata *v_cache,
const Tindex *block_tables, const Tindex *block_tables,
const int64_t *total_kv_lens, const Tindex *total_kv_lens,
const int64_t *cu_seqlens_q, const Tindex *cu_seqlens_q,
const float *alibi_slopes, const float *alibi_slopes,
size_t num_kv_heads, size_t num_kv_heads,
float scale, float scale,
...@@ -184,8 +184,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8( ...@@ -184,8 +184,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8(
const Tdata *k_cache, const Tdata *k_cache,
const Tdata *v_cache, const Tdata *v_cache,
const Tindex *block_tables, const Tindex *block_tables,
const int64_t *total_kv_lens, const Tindex *total_kv_lens,
const int64_t *cu_seqlens_q, const Tindex *cu_seqlens_q,
const float *alibi_slopes, const float *alibi_slopes,
size_t num_kv_heads, size_t num_kv_heads,
float scale, float scale,
...@@ -220,8 +220,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8N128( ...@@ -220,8 +220,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8N128(
const Tdata *k_cache, const Tdata *k_cache,
const Tdata *v_cache, const Tdata *v_cache,
const Tindex *block_tables, const Tindex *block_tables,
const int64_t *total_kv_lens, const Tindex *total_kv_lens,
const int64_t *cu_seqlens_q, const Tindex *cu_seqlens_q,
const float *alibi_slopes, const float *alibi_slopes,
size_t num_kv_heads, size_t num_kv_heads,
float scale, float scale,
...@@ -257,8 +257,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta8( ...@@ -257,8 +257,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta8(
const Tdata *k_cache, const Tdata *k_cache,
const Tdata *v_cache, const Tdata *v_cache,
const Tindex *block_tables, const Tindex *block_tables,
const int64_t *total_kv_lens, const Tindex *total_kv_lens,
const int64_t *cu_seqlens_q, const Tindex *cu_seqlens_q,
const float *alibi_slopes, const float *alibi_slopes,
size_t num_kv_heads, size_t num_kv_heads,
float scale, float scale,
...@@ -293,8 +293,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8Pipe( ...@@ -293,8 +293,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8Pipe(
const Tdata *k_cache, const Tdata *k_cache,
const Tdata *v_cache, const Tdata *v_cache,
const Tindex *block_tables, const Tindex *block_tables,
const int64_t *total_kv_lens, const Tindex *total_kv_lens,
const int64_t *cu_seqlens_q, const Tindex *cu_seqlens_q,
const float *alibi_slopes, const float *alibi_slopes,
size_t num_kv_heads, size_t num_kv_heads,
float scale, float scale,
...@@ -329,8 +329,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8Mma( ...@@ -329,8 +329,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8Mma(
const half *k_cache, const half *k_cache,
const half *v_cache, const half *v_cache,
const Tindex *block_tables, const Tindex *block_tables,
const int64_t *total_kv_lens, const Tindex *total_kv_lens,
const int64_t *cu_seqlens_q, const Tindex *cu_seqlens_q,
const float *alibi_slopes, const float *alibi_slopes,
size_t num_kv_heads, size_t num_kv_heads,
float scale, float scale,
...@@ -364,8 +364,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta8Pipe( ...@@ -364,8 +364,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta8Pipe(
const Tdata *k_cache, const Tdata *k_cache,
const Tdata *v_cache, const Tdata *v_cache,
const Tindex *block_tables, const Tindex *block_tables,
const int64_t *total_kv_lens, const Tindex *total_kv_lens,
const int64_t *cu_seqlens_q, const Tindex *cu_seqlens_q,
const float *alibi_slopes, const float *alibi_slopes,
size_t num_kv_heads, size_t num_kv_heads,
float scale, float scale,
...@@ -404,8 +404,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8PipeSplitKv( ...@@ -404,8 +404,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8PipeSplitKv(
const Tdata *k_cache, const Tdata *k_cache,
const Tdata *v_cache, const Tdata *v_cache,
const Tindex *block_tables, const Tindex *block_tables,
const int64_t *total_kv_lens, const Tindex *total_kv_lens,
const int64_t *cu_seqlens_q, const Tindex *cu_seqlens_q,
const float *alibi_slopes, const float *alibi_slopes,
size_t num_kv_heads, size_t num_kv_heads,
float scale, float scale,
...@@ -447,8 +447,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta8PipeSplitKv( ...@@ -447,8 +447,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta8PipeSplitKv(
const Tdata *k_cache, const Tdata *k_cache,
const Tdata *v_cache, const Tdata *v_cache,
const Tindex *block_tables, const Tindex *block_tables,
const int64_t *total_kv_lens, const Tindex *total_kv_lens,
const int64_t *cu_seqlens_q, const Tindex *cu_seqlens_q,
const float *alibi_slopes, const float *alibi_slopes,
size_t num_kv_heads, size_t num_kv_heads,
float scale, float scale,
...@@ -512,8 +512,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta16( ...@@ -512,8 +512,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta16(
const Tdata *k_cache, const Tdata *k_cache,
const Tdata *v_cache, const Tdata *v_cache,
const Tindex *block_tables, const Tindex *block_tables,
const int64_t *total_kv_lens, const Tindex *total_kv_lens,
const int64_t *cu_seqlens_q, const Tindex *cu_seqlens_q,
const float *alibi_slopes, const float *alibi_slopes,
size_t num_kv_heads, size_t num_kv_heads,
float scale, float scale,
...@@ -548,8 +548,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta16( ...@@ -548,8 +548,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta16(
const Tdata *k_cache, const Tdata *k_cache,
const Tdata *v_cache, const Tdata *v_cache,
const Tindex *block_tables, const Tindex *block_tables,
const int64_t *total_kv_lens, const Tindex *total_kv_lens,
const int64_t *cu_seqlens_q, const Tindex *cu_seqlens_q,
const float *alibi_slopes, const float *alibi_slopes,
size_t num_kv_heads, size_t num_kv_heads,
float scale, float scale,
...@@ -584,8 +584,8 @@ infiniStatus_t launch_prefill_ref( ...@@ -584,8 +584,8 @@ infiniStatus_t launch_prefill_ref(
const Tdata *k_cache, const Tdata *k_cache,
const Tdata *v_cache, const Tdata *v_cache,
const Tindex *block_tables, const Tindex *block_tables,
const int64_t *total_kv_lens, const Tindex *total_kv_lens,
const int64_t *cu_seqlens_q, const Tindex *cu_seqlens_q,
const float *alibi_slopes, const float *alibi_slopes,
size_t num_heads, size_t num_heads,
size_t num_seqs, size_t num_seqs,
...@@ -645,8 +645,8 @@ infiniStatus_t launch_prefill_warp( ...@@ -645,8 +645,8 @@ infiniStatus_t launch_prefill_warp(
const Tdata *k_cache, const Tdata *k_cache,
const Tdata *v_cache, const Tdata *v_cache,
const Tindex *block_tables, const Tindex *block_tables,
const int64_t *total_kv_lens, const Tindex *total_kv_lens,
const int64_t *cu_seqlens_q, const Tindex *cu_seqlens_q,
const float *alibi_slopes, const float *alibi_slopes,
size_t num_heads, size_t num_heads,
size_t num_seqs, size_t num_seqs,
...@@ -712,8 +712,8 @@ infiniStatus_t launch_prefill( ...@@ -712,8 +712,8 @@ infiniStatus_t launch_prefill(
const Tdata *k_cache, const Tdata *k_cache,
const Tdata *v_cache, const Tdata *v_cache,
const Tindex *block_tables, const Tindex *block_tables,
const int64_t *total_kv_lens, const Tindex *total_kv_lens,
const int64_t *cu_seqlens_q, const Tindex *cu_seqlens_q,
const float *alibi_slopes, const float *alibi_slopes,
size_t num_heads, size_t num_heads,
size_t num_seqs, size_t num_seqs,
...@@ -778,8 +778,8 @@ infiniStatus_t launch_prefill_warpcta8( ...@@ -778,8 +778,8 @@ infiniStatus_t launch_prefill_warpcta8(
const Tdata *k_cache, const Tdata *k_cache,
const Tdata *v_cache, const Tdata *v_cache,
const Tindex *block_tables, const Tindex *block_tables,
const int64_t *total_kv_lens, const Tindex *total_kv_lens,
const int64_t *cu_seqlens_q, const Tindex *cu_seqlens_q,
const float *alibi_slopes, const float *alibi_slopes,
size_t num_heads, size_t num_heads,
size_t num_seqs, size_t num_seqs,
...@@ -844,8 +844,8 @@ infiniStatus_t launch_prefill_warpcta8pipe( ...@@ -844,8 +844,8 @@ infiniStatus_t launch_prefill_warpcta8pipe(
const Tdata *k_cache, const Tdata *k_cache,
const Tdata *v_cache, const Tdata *v_cache,
const Tindex *block_tables, const Tindex *block_tables,
const int64_t *total_kv_lens, const Tindex *total_kv_lens,
const int64_t *cu_seqlens_q, const Tindex *cu_seqlens_q,
const float *alibi_slopes, const float *alibi_slopes,
size_t num_heads, size_t num_heads,
size_t num_seqs, size_t num_seqs,
...@@ -910,8 +910,8 @@ infiniStatus_t launch_prefill_warpcta8mma( ...@@ -910,8 +910,8 @@ infiniStatus_t launch_prefill_warpcta8mma(
const Tdata *k_cache, const Tdata *k_cache,
const Tdata *v_cache, const Tdata *v_cache,
const Tindex *block_tables, const Tindex *block_tables,
const int64_t *total_kv_lens, const Tindex *total_kv_lens,
const int64_t *cu_seqlens_q, const Tindex *cu_seqlens_q,
const float *alibi_slopes, const float *alibi_slopes,
size_t num_heads, size_t num_heads,
size_t num_seqs, size_t num_seqs,
...@@ -1027,8 +1027,8 @@ infiniStatus_t launch_prefill_warpcta8pipe_splitkv( ...@@ -1027,8 +1027,8 @@ infiniStatus_t launch_prefill_warpcta8pipe_splitkv(
const Tdata *k_cache, const Tdata *k_cache,
const Tdata *v_cache, const Tdata *v_cache,
const Tindex *block_tables, const Tindex *block_tables,
const int64_t *total_kv_lens, const Tindex *total_kv_lens,
const int64_t *cu_seqlens_q, const Tindex *cu_seqlens_q,
const float *alibi_slopes, const float *alibi_slopes,
size_t num_heads, size_t num_heads,
size_t num_seqs, size_t num_seqs,
...@@ -1122,8 +1122,8 @@ infiniStatus_t launch_prefill_warpcta8n128( ...@@ -1122,8 +1122,8 @@ infiniStatus_t launch_prefill_warpcta8n128(
const Tdata *k_cache, const Tdata *k_cache,
const Tdata *v_cache, const Tdata *v_cache,
const Tindex *block_tables, const Tindex *block_tables,
const int64_t *total_kv_lens, const Tindex *total_kv_lens,
const int64_t *cu_seqlens_q, const Tindex *cu_seqlens_q,
const float *alibi_slopes, const float *alibi_slopes,
size_t num_heads, size_t num_heads,
size_t num_seqs, size_t num_seqs,
...@@ -1177,8 +1177,8 @@ infiniStatus_t launch_prefill_warpcta16( ...@@ -1177,8 +1177,8 @@ infiniStatus_t launch_prefill_warpcta16(
const Tdata *k_cache, const Tdata *k_cache,
const Tdata *v_cache, const Tdata *v_cache,
const Tindex *block_tables, const Tindex *block_tables,
const int64_t *total_kv_lens, const Tindex *total_kv_lens,
const int64_t *cu_seqlens_q, const Tindex *cu_seqlens_q,
const float *alibi_slopes, const float *alibi_slopes,
size_t num_heads, size_t num_heads,
size_t num_seqs, size_t num_seqs,
...@@ -1310,8 +1310,8 @@ infiniStatus_t Descriptor::calculate( ...@@ -1310,8 +1310,8 @@ infiniStatus_t Descriptor::calculate(
auto stream = static_cast<hcStream_t>(stream_); auto stream = static_cast<hcStream_t>(stream_);
const float *alibi_ptr = (alibi_slopes == nullptr) ? nullptr : static_cast<const float *>(alibi_slopes); const float *alibi_ptr = (alibi_slopes == nullptr) ? nullptr : static_cast<const float *>(alibi_slopes);
const auto *total_kv_lens_i64 = static_cast<const int64_t *>(total_kv_lens); const void *total_kv_lens_ptr = total_kv_lens;
const auto *cu_seqlens_q_i64 = static_cast<const int64_t *>(cum_seqlens_q); const void *cu_seqlens_q_ptr = cum_seqlens_q;
bool use_splitkv = false; bool use_splitkv = false;
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) { if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) {
...@@ -1345,7 +1345,7 @@ infiniStatus_t Descriptor::calculate( ...@@ -1345,7 +1345,7 @@ infiniStatus_t Descriptor::calculate(
float *partial_m = partial_acc + static_cast<size_t>(num_splits) * n * _info.head_size; float *partial_m = partial_acc + static_cast<size_t>(num_splits) * n * _info.head_size;
float *partial_l = partial_m + static_cast<size_t>(num_splits) * n; float *partial_l = partial_m + static_cast<size_t>(num_splits) * n;
// Dispatch by (Tdata, Tindex). total_kv_lens + cu_seqlens_q are currently always int64. // Dispatch by (Tdata, Tindex). total_kv_lens + cu_seqlens_q are currently either int 32 or int64.
#define DISPATCH_SPLITKV(Tindex, Tdata, BT_PTR) \ #define DISPATCH_SPLITKV(Tindex, Tdata, BT_PTR) \
return launch_prefill_warpcta8pipe_splitkv<Tindex, Tdata>( \ return launch_prefill_warpcta8pipe_splitkv<Tindex, Tdata>( \
partial_acc, partial_m, partial_l, num_splits, \ partial_acc, partial_m, partial_l, num_splits, \
...@@ -1354,7 +1354,9 @@ infiniStatus_t Descriptor::calculate( ...@@ -1354,7 +1354,9 @@ infiniStatus_t Descriptor::calculate(
static_cast<const Tdata *>(k_cache), \ static_cast<const Tdata *>(k_cache), \
static_cast<const Tdata *>(v_cache), \ static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(BT_PTR), \ static_cast<const Tindex *>(BT_PTR), \
total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ static_cast<const Tindex *>(total_kv_lens_ptr), \
static_cast<const Tindex *>(cu_seqlens_q_ptr), \
alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \ _info.block_table_batch_stride, \
...@@ -1424,7 +1426,7 @@ infiniStatus_t Descriptor::calculate( ...@@ -1424,7 +1426,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warp<Tindex, Tdata>( \ return launch_prefill_warp<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \ static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \ static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \ _info.block_table_batch_stride, \
...@@ -1437,7 +1439,7 @@ infiniStatus_t Descriptor::calculate( ...@@ -1437,7 +1439,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill<Tindex, Tdata>( \ return launch_prefill<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \ static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \ static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \ _info.block_table_batch_stride, \
...@@ -1450,7 +1452,7 @@ infiniStatus_t Descriptor::calculate( ...@@ -1450,7 +1452,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warpcta8<Tindex, Tdata>( \ return launch_prefill_warpcta8<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \ static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \ static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \ _info.block_table_batch_stride, \
...@@ -1463,7 +1465,7 @@ infiniStatus_t Descriptor::calculate( ...@@ -1463,7 +1465,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warpcta8pipe<Tindex, Tdata>( \ return launch_prefill_warpcta8pipe<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \ static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \ static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \ _info.block_table_batch_stride, \
...@@ -1477,7 +1479,7 @@ infiniStatus_t Descriptor::calculate( ...@@ -1477,7 +1479,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warpcta8mma<Tindex, Tdata>( \ return launch_prefill_warpcta8mma<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \ static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \ static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \ _info.block_table_batch_stride, \
...@@ -1491,7 +1493,7 @@ infiniStatus_t Descriptor::calculate( ...@@ -1491,7 +1493,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warpcta8n128<Tindex, Tdata>( \ return launch_prefill_warpcta8n128<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \ static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \ static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \ _info.block_table_batch_stride, \
...@@ -1504,7 +1506,7 @@ infiniStatus_t Descriptor::calculate( ...@@ -1504,7 +1506,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warpcta16<Tindex, Tdata>( \ return launch_prefill_warpcta16<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \ static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \ static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \ _info.block_table_batch_stride, \
...@@ -1517,7 +1519,7 @@ infiniStatus_t Descriptor::calculate( ...@@ -1517,7 +1519,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_ref<Tindex, Tdata, Tcompute>( \ return launch_prefill_ref<Tindex, Tdata, Tcompute>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \ static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \ static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \ _info.block_table_batch_stride, \
......
...@@ -1312,10 +1312,8 @@ infiniStatus_t Descriptor::calculate( ...@@ -1312,10 +1312,8 @@ infiniStatus_t Descriptor::calculate(
auto stream = static_cast<cudaStream_t>(stream_); auto stream = static_cast<cudaStream_t>(stream_);
const float *alibi_ptr = (alibi_slopes == nullptr) ? nullptr : static_cast<const float *>(alibi_slopes); const float *alibi_ptr = (alibi_slopes == nullptr) ? nullptr : static_cast<const float *>(alibi_slopes);
// const auto *total_kv_lens_i64 = static_cast<const int64_t *>(total_kv_lens); const void *total_kv_lens_ptr = total_kv_lens;
// const auto *cu_seqlens_q_i64 = static_cast<const int64_t *>(cum_seqlens_q); const void *cu_seqlens_q_ptr = cum_seqlens_q;
const void *total_kv_lens_i64 = total_kv_lens;
const void *cu_seqlens_q_i64 = cum_seqlens_q;
bool use_splitkv = false; bool use_splitkv = false;
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) { if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) {
...@@ -1358,8 +1356,8 @@ infiniStatus_t Descriptor::calculate( ...@@ -1358,8 +1356,8 @@ infiniStatus_t Descriptor::calculate(
static_cast<const Tdata *>(k_cache), \ static_cast<const Tdata *>(k_cache), \
static_cast<const Tdata *>(v_cache), \ static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(BT_PTR), \ static_cast<const Tindex *>(BT_PTR), \
static_cast<const Tindex *>(total_kv_lens_i64), \ static_cast<const Tindex *>(total_kv_lens_ptr), \
static_cast<const Tindex *>(cu_seqlens_q_i64), \ static_cast<const Tindex *>(cu_seqlens_q_ptr), \
alibi_ptr, \ alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
...@@ -1429,7 +1427,7 @@ infiniStatus_t Descriptor::calculate( ...@@ -1429,7 +1427,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warp<Tindex, Tdata>( \ return launch_prefill_warp<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \ static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \ static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_i64), static_cast<const Tindex *>(cu_seqlens_q_i64), alibi_ptr, \ static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \ _info.block_table_batch_stride, \
...@@ -1442,7 +1440,7 @@ infiniStatus_t Descriptor::calculate( ...@@ -1442,7 +1440,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill<Tindex, Tdata>( \ return launch_prefill<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \ static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \ static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_i64), static_cast<const Tindex *>(cu_seqlens_q_i64), alibi_ptr, \ static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \ _info.block_table_batch_stride, \
...@@ -1455,7 +1453,7 @@ infiniStatus_t Descriptor::calculate( ...@@ -1455,7 +1453,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warpcta8<Tindex, Tdata>( \ return launch_prefill_warpcta8<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \ static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \ static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_i64), static_cast<const Tindex *>(cu_seqlens_q_i64), alibi_ptr, \ static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \ _info.block_table_batch_stride, \
...@@ -1468,7 +1466,7 @@ infiniStatus_t Descriptor::calculate( ...@@ -1468,7 +1466,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warpcta8pipe<Tindex, Tdata>( \ return launch_prefill_warpcta8pipe<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \ static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \ static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_i64), static_cast<const Tindex *>(cu_seqlens_q_i64), alibi_ptr, \ static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \ _info.block_table_batch_stride, \
...@@ -1482,7 +1480,7 @@ infiniStatus_t Descriptor::calculate( ...@@ -1482,7 +1480,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warpcta8mma<Tindex, Tdata>( \ return launch_prefill_warpcta8mma<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \ static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \ static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_i64), static_cast<const Tindex *>(cu_seqlens_q_i64), alibi_ptr, \ static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \ _info.block_table_batch_stride, \
...@@ -1496,7 +1494,7 @@ infiniStatus_t Descriptor::calculate( ...@@ -1496,7 +1494,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warpcta8n128<Tindex, Tdata>( \ return launch_prefill_warpcta8n128<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \ static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \ static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_i64), static_cast<const Tindex *>(cu_seqlens_q_i64), alibi_ptr, \ static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \ _info.block_table_batch_stride, \
...@@ -1509,7 +1507,7 @@ infiniStatus_t Descriptor::calculate( ...@@ -1509,7 +1507,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warpcta16<Tindex, Tdata>( \ return launch_prefill_warpcta16<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \ static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \ static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_i64), static_cast<const Tindex *>(cu_seqlens_q_i64), alibi_ptr, \ static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \ _info.block_table_batch_stride, \
...@@ -1522,7 +1520,7 @@ infiniStatus_t Descriptor::calculate( ...@@ -1522,7 +1520,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_ref<Tindex, Tdata, Tcompute>( \ return launch_prefill_ref<Tindex, Tdata, Tcompute>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \ static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \ static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_i64), static_cast<const Tindex *>(cu_seqlens_q_i64), alibi_ptr, \ static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \ _info.block_table_batch_stride, \
......
...@@ -12,7 +12,7 @@ INFINIOP_METAX_KERNEL pagedCaching( ...@@ -12,7 +12,7 @@ INFINIOP_METAX_KERNEL pagedCaching(
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride, const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride, const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride,
const ptrdiff_t k_cache_head_stride, const ptrdiff_t v_cache_head_stride, const ptrdiff_t k_cache_head_stride, const ptrdiff_t v_cache_head_stride,
const ptrdiff_t k_cache_slot_stride, const ptrdiff_t v_cache_slot_strid) { const ptrdiff_t k_cache_slot_stride, const ptrdiff_t v_cache_slot_stride) {
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>( op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
k_cache, v_cache, k, v, slot_mapping, head_size, k_cache, v_cache, k, v, slot_mapping, head_size,
block_size, k_src_stride, v_src_stride, block_size, k_src_stride, v_src_stride,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment