Commit 9a0f2505 authored by wooway777's avatar wooway777
Browse files

issue/1050 - fix paged caching and paged prefill on metax

parent a9503148
......@@ -46,8 +46,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128Warp(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
......@@ -79,8 +79,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64Warp(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
......@@ -112,8 +112,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
......@@ -148,8 +148,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
......@@ -184,8 +184,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
......@@ -220,8 +220,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8N128(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
......@@ -257,8 +257,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta8(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
......@@ -293,8 +293,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8Pipe(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
......@@ -329,8 +329,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8Mma(
const half *k_cache,
const half *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
......@@ -364,8 +364,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta8Pipe(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
......@@ -404,8 +404,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8PipeSplitKv(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
......@@ -447,8 +447,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta8PipeSplitKv(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
......@@ -512,8 +512,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta16(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
......@@ -548,8 +548,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta16(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
......@@ -584,8 +584,8 @@ infiniStatus_t launch_prefill_ref(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
......@@ -645,8 +645,8 @@ infiniStatus_t launch_prefill_warp(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
......@@ -712,8 +712,8 @@ infiniStatus_t launch_prefill(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
......@@ -778,8 +778,8 @@ infiniStatus_t launch_prefill_warpcta8(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
......@@ -844,8 +844,8 @@ infiniStatus_t launch_prefill_warpcta8pipe(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
......@@ -910,8 +910,8 @@ infiniStatus_t launch_prefill_warpcta8mma(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
......@@ -1027,8 +1027,8 @@ infiniStatus_t launch_prefill_warpcta8pipe_splitkv(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
......@@ -1122,8 +1122,8 @@ infiniStatus_t launch_prefill_warpcta8n128(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
......@@ -1177,8 +1177,8 @@ infiniStatus_t launch_prefill_warpcta16(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
......@@ -1310,8 +1310,8 @@ infiniStatus_t Descriptor::calculate(
auto stream = static_cast<hcStream_t>(stream_);
const float *alibi_ptr = (alibi_slopes == nullptr) ? nullptr : static_cast<const float *>(alibi_slopes);
const auto *total_kv_lens_i64 = static_cast<const int64_t *>(total_kv_lens);
const auto *cu_seqlens_q_i64 = static_cast<const int64_t *>(cum_seqlens_q);
const void *total_kv_lens_ptr = total_kv_lens;
const void *cu_seqlens_q_ptr = cum_seqlens_q;
bool use_splitkv = false;
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) {
......@@ -1345,7 +1345,7 @@ infiniStatus_t Descriptor::calculate(
float *partial_m = partial_acc + static_cast<size_t>(num_splits) * n * _info.head_size;
float *partial_l = partial_m + static_cast<size_t>(num_splits) * n;
// Dispatch by (Tdata, Tindex). total_kv_lens + cu_seqlens_q are currently always int64.
// Dispatch by (Tdata, Tindex). total_kv_lens + cu_seqlens_q are currently either int 32 or int64.
#define DISPATCH_SPLITKV(Tindex, Tdata, BT_PTR) \
return launch_prefill_warpcta8pipe_splitkv<Tindex, Tdata>( \
partial_acc, partial_m, partial_l, num_splits, \
......@@ -1354,7 +1354,9 @@ infiniStatus_t Descriptor::calculate(
static_cast<const Tdata *>(k_cache), \
static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(BT_PTR), \
total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
static_cast<const Tindex *>(total_kv_lens_ptr), \
static_cast<const Tindex *>(cu_seqlens_q_ptr), \
alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
......@@ -1424,7 +1426,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warp<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
......@@ -1437,7 +1439,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
......@@ -1450,7 +1452,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warpcta8<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
......@@ -1463,7 +1465,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warpcta8pipe<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
......@@ -1477,7 +1479,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warpcta8mma<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
......@@ -1491,7 +1493,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warpcta8n128<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
......@@ -1504,7 +1506,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warpcta16<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
......@@ -1517,7 +1519,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_ref<Tindex, Tdata, Tcompute>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
......
......@@ -1311,10 +1311,8 @@ infiniStatus_t Descriptor::calculate(
auto stream = static_cast<cudaStream_t>(stream_);
const float *alibi_ptr = (alibi_slopes == nullptr) ? nullptr : static_cast<const float *>(alibi_slopes);
// const auto *total_kv_lens_i64 = static_cast<const int64_t *>(total_kv_lens);
// const auto *cu_seqlens_q_i64 = static_cast<const int64_t *>(cum_seqlens_q);
const void *total_kv_lens_i64 = total_kv_lens;
const void *cu_seqlens_q_i64 = cum_seqlens_q;
const void *total_kv_lens_ptr = total_kv_lens;
const void *cu_seqlens_q_ptr = cum_seqlens_q;
bool use_splitkv = false;
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) {
......@@ -1357,8 +1355,8 @@ infiniStatus_t Descriptor::calculate(
static_cast<const Tdata *>(k_cache), \
static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(BT_PTR), \
static_cast<const Tindex *>(total_kv_lens_i64), \
static_cast<const Tindex *>(cu_seqlens_q_i64), \
static_cast<const Tindex *>(total_kv_lens_ptr), \
static_cast<const Tindex *>(cu_seqlens_q_ptr), \
alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
......@@ -1428,7 +1426,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warp<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_i64), static_cast<const Tindex *>(cu_seqlens_q_i64), alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
......@@ -1441,7 +1439,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_i64), static_cast<const Tindex *>(cu_seqlens_q_i64), alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
......@@ -1454,7 +1452,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warpcta8<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_i64), static_cast<const Tindex *>(cu_seqlens_q_i64), alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
......@@ -1467,7 +1465,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warpcta8pipe<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_i64), static_cast<const Tindex *>(cu_seqlens_q_i64), alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
......@@ -1481,7 +1479,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warpcta8mma<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_i64), static_cast<const Tindex *>(cu_seqlens_q_i64), alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
......@@ -1495,7 +1493,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warpcta8n128<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_i64), static_cast<const Tindex *>(cu_seqlens_q_i64), alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
......@@ -1508,7 +1506,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warpcta16<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_i64), static_cast<const Tindex *>(cu_seqlens_q_i64), alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
......@@ -1521,7 +1519,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_ref<Tindex, Tdata, Tcompute>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_i64), static_cast<const Tindex *>(cu_seqlens_q_i64), alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
......
......@@ -12,7 +12,7 @@ INFINIOP_METAX_KERNEL pagedCaching(
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride,
const ptrdiff_t k_cache_head_stride, const ptrdiff_t v_cache_head_stride,
const ptrdiff_t k_cache_slot_stride, const ptrdiff_t v_cache_slot_strid) {
const ptrdiff_t k_cache_slot_stride, const ptrdiff_t v_cache_slot_stride) {
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
k_cache, v_cache, k, v, slot_mapping, head_size,
block_size, k_src_stride, v_src_stride,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment