Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
9a0f2505
Commit
9a0f2505
authored
Mar 05, 2026
by
wooway777
Browse files
issue/1050 - fix paged caching and paged prefill on metax
parent
a9503148
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
73 additions
and
73 deletions
+73
-73
src/infiniop/ops/paged_attention_prefill/metax/paged_attention_prefill_metax.maca
...ttention_prefill/metax/paged_attention_prefill_metax.maca
+60
-58
src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu
...ttention_prefill/nvidia/paged_attention_prefill_nvidia.cu
+12
-14
src/infiniop/ops/paged_caching/metax/paged_caching_metax.maca
...infiniop/ops/paged_caching/metax/paged_caching_metax.maca
+1
-1
No files found.
src/infiniop/ops/paged_attention_prefill/metax/paged_attention_prefill_metax.maca
View file @
9a0f2505
...
...
@@ -46,8 +46,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128Warp(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
...
...
@@ -79,8 +79,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64Warp(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
...
...
@@ -112,8 +112,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
...
...
@@ -148,8 +148,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
...
...
@@ -184,8 +184,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
...
...
@@ -220,8 +220,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8N128(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
...
...
@@ -257,8 +257,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta8(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
...
...
@@ -293,8 +293,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8Pipe(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
...
...
@@ -329,8 +329,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8Mma(
const half *k_cache,
const half *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
...
...
@@ -364,8 +364,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta8Pipe(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
...
...
@@ -404,8 +404,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8PipeSplitKv(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
...
...
@@ -447,8 +447,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta8PipeSplitKv(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
...
...
@@ -512,8 +512,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta16(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
...
...
@@ -548,8 +548,8 @@ INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta16(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
...
...
@@ -584,8 +584,8 @@ infiniStatus_t launch_prefill_ref(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
...
...
@@ -645,8 +645,8 @@ infiniStatus_t launch_prefill_warp(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
...
...
@@ -712,8 +712,8 @@ infiniStatus_t launch_prefill(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
...
...
@@ -778,8 +778,8 @@ infiniStatus_t launch_prefill_warpcta8(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
...
...
@@ -844,8 +844,8 @@ infiniStatus_t launch_prefill_warpcta8pipe(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
...
...
@@ -910,8 +910,8 @@ infiniStatus_t launch_prefill_warpcta8mma(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
...
...
@@ -1027,8 +1027,8 @@ infiniStatus_t launch_prefill_warpcta8pipe_splitkv(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
...
...
@@ -1122,8 +1122,8 @@ infiniStatus_t launch_prefill_warpcta8n128(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
...
...
@@ -1177,8 +1177,8 @@ infiniStatus_t launch_prefill_warpcta16(
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const Tindex *total_kv_lens,
const Tindex *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
...
...
@@ -1310,8 +1310,8 @@ infiniStatus_t Descriptor::calculate(
auto stream = static_cast<hcStream_t>(stream_);
const float *alibi_ptr = (alibi_slopes == nullptr) ? nullptr : static_cast<const float *>(alibi_slopes);
const auto *total_kv_lens_i64 = static_cast<const int64_t *>(total_kv_lens);
const auto *cu_seqlens_q_i64 = static_cast<const int64_t *>(cum_seqlens_q);
const void *total_kv_lens_ptr = total_kv_lens;
const void *cu_seqlens_q_ptr = cum_seqlens_q;
bool use_splitkv = false;
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) {
...
...
@@ -1345,7 +1345,7 @@ infiniStatus_t Descriptor::calculate(
float *partial_m = partial_acc + static_cast<size_t>(num_splits) * n * _info.head_size;
float *partial_l = partial_m + static_cast<size_t>(num_splits) * n;
// Dispatch by (Tdata, Tindex). total_kv_lens + cu_seqlens_q are currently always int64.
// Dispatch by (Tdata, Tindex). total_kv_lens + cu_seqlens_q are currently either int32 or int64.
#define DISPATCH_SPLITKV(Tindex, Tdata, BT_PTR) \
return launch_prefill_warpcta8pipe_splitkv<Tindex, Tdata>( \
partial_acc, partial_m, partial_l, num_splits, \
...
...
@@ -1354,7 +1354,9 @@ infiniStatus_t Descriptor::calculate(
static_cast<const Tdata *>(k_cache), \
static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(BT_PTR), \
total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
static_cast<const Tindex *>(total_kv_lens_ptr), \
static_cast<const Tindex *>(cu_seqlens_q_ptr), \
alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
...
...
@@ -1424,7 +1426,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warp<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
...
...
@@ -1437,7 +1439,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
...
...
@@ -1450,7 +1452,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warpcta8<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
...
...
@@ -1463,7 +1465,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warpcta8pipe<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
...
...
@@ -1477,7 +1479,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warpcta8mma<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
...
...
@@ -1491,7 +1493,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warpcta8n128<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
...
...
@@ -1504,7 +1506,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warpcta16<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
...
...
@@ -1517,7 +1519,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_ref<Tindex, Tdata, Tcompute>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
...
...
src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu
View file @
9a0f2505
...
...
@@ -1311,10 +1311,8 @@ infiniStatus_t Descriptor::calculate(
auto stream = static_cast<cudaStream_t>(stream_);
const float *alibi_ptr = (alibi_slopes == nullptr) ? nullptr : static_cast<const float *>(alibi_slopes);
// const auto *total_kv_lens_i64 = static_cast<const int64_t *>(total_kv_lens);
// const auto *cu_seqlens_q_i64 = static_cast<const int64_t *>(cum_seqlens_q);
const void *total_kv_lens_i64 = total_kv_lens;
const void *cu_seqlens_q_i64 = cum_seqlens_q;
const void *total_kv_lens_ptr = total_kv_lens;
const void *cu_seqlens_q_ptr = cum_seqlens_q;
bool use_splitkv = false;
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) {
...
...
@@ -1357,8 +1355,8 @@ infiniStatus_t Descriptor::calculate(
static_cast<const Tdata *>(k_cache), \
static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(BT_PTR), \
static_cast<const Tindex *>(total_kv_lens_i64), \
static_cast<const Tindex *>(cu_seqlens_q_i64), \
static_cast<const Tindex *>(total_kv_lens_ptr), \
static_cast<const Tindex *>(cu_seqlens_q_ptr), \
alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
...
...
@@ -1428,7 +1426,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warp<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_i64), static_cast<const Tindex *>(cu_seqlens_q_i64), alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
...
...
@@ -1441,7 +1439,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_i64), static_cast<const Tindex *>(cu_seqlens_q_i64), alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
...
...
@@ -1454,7 +1452,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warpcta8<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_i64), static_cast<const Tindex *>(cu_seqlens_q_i64), alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
...
...
@@ -1467,7 +1465,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warpcta8pipe<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_i64), static_cast<const Tindex *>(cu_seqlens_q_i64), alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
...
...
@@ -1481,7 +1479,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warpcta8mma<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_i64), static_cast<const Tindex *>(cu_seqlens_q_i64), alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
...
...
@@ -1495,7 +1493,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warpcta8n128<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_i64), static_cast<const Tindex *>(cu_seqlens_q_i64), alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
...
...
@@ -1508,7 +1506,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_warpcta16<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_i64), static_cast<const Tindex *>(cu_seqlens_q_i64), alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
...
...
@@ -1521,7 +1519,7 @@ infiniStatus_t Descriptor::calculate(
return launch_prefill_ref<Tindex, Tdata, Tcompute>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_i64), static_cast<const Tindex *>(cu_seqlens_q_i64), alibi_ptr, \
static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(total_kv_lens_ptr), static_cast<const Tindex *>(cu_seqlens_q_ptr), alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
...
...
src/infiniop/ops/paged_caching/metax/paged_caching_metax.maca
View file @
9a0f2505
...
...
@@ -12,7 +12,7 @@ INFINIOP_METAX_KERNEL pagedCaching(
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride,
const ptrdiff_t k_cache_head_stride, const ptrdiff_t v_cache_head_stride,
const ptrdiff_t k_cache_slot_stride, const ptrdiff_t v_cache_slot_strid) {
const ptrdiff_t k_cache_slot_stride, const ptrdiff_t v_cache_slot_stride) {
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
k_cache, v_cache, k, v, slot_mapping, head_size,
block_size, k_src_stride, v_src_stride,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment