Commit 499b1dc6 (jerrrrry/infinicore)
issue/867 pass total kv lens as paged attn args
Authored Jan 09, 2026 by PanZezhong
Parent: 0a2839a2
Showing 14 changed files with 136 additions and 122 deletions (+136, -122).
include/infinicore/ops/paged_attention_prefill.hpp (+6, -6)
include/infiniop/ops/paged_attention_prefill.h (+4, -4)
src/infinicore/ops/paged_attention/paged_attention.cc (+7, -7)
src/infinicore/ops/paged_attention/paged_attention_infiniop.cc (+4, -4)
src/infinicore/ops/paged_attention_prefill/paged_attention_prefill.cc (+7, -8)
src/infinicore/ops/paged_attention_prefill/paged_attention_prefill_infiniop.cc (+6, -7)
src/infinicore/pybind11/ops/paged_attention_prefill.hpp (+2, -1)
src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh (+19, -16)
src/infiniop/ops/paged_attention_prefill/info.h (+10, -7)
src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu (+19, -15)
src/infiniop/ops/paged_attention_prefill/operator.cc (+12, -8)
src/infiniop/ops/paged_attention_prefill/paged_attention_prefill.h (+2, -2)
test/infinicore/ops/paged_attention_prefill.py (+25, -25)
test/infiniop/paged_attention_prefill.py (+13, -12)
include/infinicore/ops/paged_attention_prefill.hpp

@@ -16,7 +16,7 @@ public:
  * 3. k_cache: Physical Key cache (Paged format)
  * 4. v_cache: Physical Value cache (Paged format)
  * 5. block_tables: Mapping table from logical blocks to physical blocks
- * 6. history_lens: Historical KV lengths (existing length of each sequence in cache)
+ * 6. total_kv_lens: lengths of the complete Key/Value for each request
  * 7. cu_seqlens_q: Cumulative sequence lengths of Query (prefix sum for variable-length batch)
  * 8. alibi_slopes: ALiBi bias slopes (optional)
  * 9. scale: Scaling factor (typically 1/sqrt(head_size))
@@ -24,7 +24,7 @@ public:
     using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>, float);
-    static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor history_lens, Tensor cu_seqlens_q, std::optional<Tensor> alibi_slopes, float scale);
+    static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor total_kv_lens, Tensor cum_seqlens_q, std::optional<Tensor> alibi_slopes, float scale);
     static common::OpDispatcher<schema> &dispatcher();
@@ -34,8 +34,8 @@ Tensor paged_attention_prefill(Tensor q,
-                               Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor history_lens, Tensor cu_seqlens_q,
+                               Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor total_kv_lens, Tensor cum_seqlens_q,
                                std::optional<Tensor> alibi_slopes, float scale);
@@ -44,8 +44,8 @@ void paged_attention_prefill_(Tensor out,
-                              Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor history_lens, Tensor cu_seqlens_q,
+                              Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor total_kv_lens, Tensor cum_seqlens_q,
                               std::optional<Tensor> alibi_slopes, float scale);
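The interface now takes the total KV length per request instead of the pre-existing history length. As a quick orientation, the following torch sketch (toy values, not part of the commit) shows how a caller would assemble total_kv_lens and the cumulative query lengths, and how the old history value can be recovered from them:

    import torch

    # Toy batch: 3 requests, each with some tokens already in the paged KV cache
    # and a new chunk of query tokens being prefilled this step.
    history_lens = torch.tensor([5, 0, 12], dtype=torch.int64)   # what the old argument carried
    q_lens = torch.tensor([3, 7, 1], dtype=torch.int64)          # new query tokens per request

    # The two arguments the op now receives:
    total_kv_lens = history_lens + q_lens                          # [8, 7, 13]
    cum_seqlens_q = torch.zeros(len(q_lens) + 1, dtype=torch.int64)
    cum_seqlens_q[1:] = torch.cumsum(q_lens, dim=0)                # [0, 3, 10, 11]

    # The kernel can recover the history length it needs for causal masking:
    q_lens_from_cu = cum_seqlens_q[1:] - cum_seqlens_q[:-1]
    recovered_history = total_kv_lens - q_lens_from_cu
    assert torch.equal(recovered_history, history_lens)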
include/infiniop/ops/paged_attention_prefill.h

@@ -20,7 +20,7 @@ typedef struct InfiniopDescriptor *infiniopPagedAttentionPrefillDescriptor_t;
  *        Shape: [max_num_blocks, num_kv_heads, block_size, head_size]
  * @param block_tables_desc Descriptor for the block tables mapping logic to physical blocks.
  *        Shape: [batch_size, max_blocks_per_seq]
- * @param history_lens_desc Descriptor for the KV history lengths of each sequence.
+ * @param seq_lens_desc Descriptor for the total KV lengths of each sequence.
  *        Shape: [batch_size]
  * @param cum_seq_lens_q_desc Descriptor for the cumulative start position (prefix sum) of each Q sequence.
  *        Shape: [batch_size + 1]
@@ -37,7 +37,7 @@ __C __export infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
     infiniopTensorDescriptor_t k_cache_desc,
     infiniopTensorDescriptor_t v_cache_desc,
     infiniopTensorDescriptor_t block_tables_desc,
-    infiniopTensorDescriptor_t history_lens_desc,
+    infiniopTensorDescriptor_t seq_lens_desc,
     infiniopTensorDescriptor_t cum_seq_lens_q_desc,
     infiniopTensorDescriptor_t alibi_slopes_desc,
     float scale);
@@ -58,7 +58,7 @@ __C __export infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
  * @param k_cache Pointer to the global key cache data.
  * @param v_cache Pointer to the global value cache data.
  * @param block_tables Pointer to the block tables data.
- * @param history_lens Pointer to the KV history lengths data.
+ * @param seq_lens Pointer to the KV lengths data.
  * @param cum_seq_lens_q Pointer to the Q cumulative sequence lengths data (prefix sum).
  * @param alibi_slopes Pointer to the ALiBi slopes data. Can be NULL.
  * @param stream The device stream (e.g., cudaStream_t) for the operation.
@@ -73,7 +73,7 @@ __C __export infiniStatus_t infiniopPagedAttentionPrefill(
     const void *k_cache,
     const void *v_cache,
     const void *block_tables,
-    const void *history_lens,
+    const void *seq_lens,
     const void *cum_seq_lens_q,
     const void *alibi_slopes,
     void *stream);
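The header fixes the expected shapes: seq_lens is [batch_size], cum_seq_lens_q is [batch_size + 1], and block_tables is [batch_size, max_blocks_per_seq]. A small illustrative sketch (toy numbers, standard ceiling division; not part of the commit) of how the block-table width relates to the total KV lengths and the cache block_size:

    import math

    block_size = 16
    total_kv_lens = [8, 7, 13, 40]        # per-request total KV length (the seq_lens argument)

    # Each request occupies ceil(total_kv_len / block_size) physical blocks;
    # block_tables is padded out to the longest request in the batch.
    blocks_per_seq = [math.ceil(n / block_size) for n in total_kv_lens]   # [1, 1, 1, 3]
    max_blocks_per_seq = max(blocks_per_seq)                              # 3
    block_tables_shape = (len(total_kv_lens), max_blocks_per_seq)         # [batch_size, max_blocks_per_seq]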
src/infinicore/ops/paged_attention/paged_attention.cc

@@ -9,20 +9,20 @@ common::OpDispatcher<PagedAttention::schema> &PagedAttention::dispatcher() {
     return dispatcher_;
 };

-void PagedAttention::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
-    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, cache_lens);
+void PagedAttention::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional<Tensor> alibi_slopes, float scale) {
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, kv_lens);
     infinicore::context::setDevice(out->device());
-    dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
+    dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale);
 }

-Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
+Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional<Tensor> alibi_slopes, float scale) {
     auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
-    paged_attention_(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
+    paged_attention_(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale);
     return out;
 }

-void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
-    PagedAttention::execute(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
+void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional<Tensor> alibi_slopes, float scale) {
+    PagedAttention::execute(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale);
 }
 } // namespace infinicore::op
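In this decode-path file the length tensor is only renamed (cache_lens to kv_lens), but the surrounding structure is worth noting: the functional wrapper allocates the output, the in-place wrapper forwards to execute(), and execute() resolves the per-device implementation through OpDispatcher. A rough Python sketch of that dispatch pattern (illustrative names only, not the library's API):

    # Minimal sketch of the dispatcher pattern used by PagedAttention::execute:
    # implementations register themselves per device type, and execute() picks the
    # right one from the output tensor's device.
    _dispatch_table = {}

    def register_impl(device_type, fn):
        _dispatch_table[device_type] = fn

    def paged_attention_(out, q, k_cache, v_cache, block_tables, kv_lens,
                         alibi_slopes=None, scale=1.0):
        impl = _dispatch_table[out.device.type]   # dispatcher().lookup(out->device().getType())
        impl(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale)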
src/infinicore/ops/paged_attention/paged_attention_infiniop.cc

@@ -15,8 +15,8 @@ thread_local common::OpCache<size_t, infiniopPagedAttentionDescriptor_t> caches(
     }
 });

-void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
-    size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
+void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional<Tensor> alibi_slopes, float scale) {
+    size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale);
     auto device = context::getDevice();
     auto &cache = caches.getCache(device);
@@ -27,7 +27,7 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc
     if (!desc_opt) {
         INFINICORE_CHECK_ERROR(infiniopCreatePagedAttentionDescriptor(
             context::getInfiniopHandle(device), &desc,
-            out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(), cache_lens->desc(),
+            out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(), kv_lens->desc(),
             alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr, scale));
         cache.put(seed, desc);
@@ -41,7 +41,7 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc
     INFINICORE_CHECK_ERROR(infiniopPagedAttention(
         desc, workspace->data(), workspace_size,
-        out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), cache_lens->data(),
+        out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), kv_lens->data(),
         alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr,
         context::getStream()));
 }
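The infiniop glue caches one descriptor per distinct argument combination, keyed by a combined hash, so repeated calls with the same shapes skip descriptor creation. A minimal sketch of that caching pattern (plain Python, hypothetical names, not the library's API):

    # Stand-in for the thread_local OpCache keyed per device in paged_attention_infiniop.cc:
    # descriptors are keyed by a hash of the call arguments, so repeated calls with
    # identical shapes/dtypes reuse the same descriptor.
    descriptor_cache = {}

    def get_descriptor(device, args, create_fn):
        key = (device, hash(args))        # stand-in for hash_combine(out, q, ..., kv_lens, ...)
        desc = descriptor_cache.get(key)
        if desc is None:
            desc = create_fn(*args)       # infiniopCreatePagedAttentionDescriptor in the real code
            descriptor_cache[key] = desc  # cache.put(seed, desc)
        return desc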
src/infinicore/ops/paged_attention_prefill/paged_attention_prefill.cc

@@ -10,31 +10,30 @@ common::OpDispatcher<PagedAttentionPrefill::schema> &PagedAttentionPrefill::disp
 };

-void PagedAttentionPrefill::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor history_lens, Tensor cu_seqlens_q,
+void PagedAttentionPrefill::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, Tensor cum_seqlens_q,
                                     std::optional<Tensor> alibi_slopes, float scale) {
-    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q);
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q);
     infinicore::context::setDevice(out->device());
     dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables,
-                                                 history_lens, cu_seqlens_q, alibi_slopes, scale);
+                                                 kv_lens, cum_seqlens_q, alibi_slopes, scale);
 }

-Tensor paged_attention_prefill(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor history_lens, Tensor cu_seqlens_q,
+Tensor paged_attention_prefill(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, Tensor cum_seqlens_q,
                                std::optional<Tensor> alibi_slopes, float scale) {
     auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
-    paged_attention_prefill_(out, q, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, alibi_slopes, scale);
+    paged_attention_prefill_(out, q, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q, alibi_slopes, scale);
     return out;
 }

-void paged_attention_prefill_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor history_lens, Tensor cu_seqlens_q,
+void paged_attention_prefill_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, Tensor cum_seqlens_q,
                               std::optional<Tensor> alibi_slopes, float scale) {
-    PagedAttentionPrefill::execute(out, q, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, alibi_slopes, scale);
+    PagedAttentionPrefill::execute(out, q, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q, alibi_slopes, scale);
 }
 } // namespace infinicore::op
src/infinicore/ops/paged_attention_prefill/paged_attention_prefill_infiniop.cc

@@ -16,10 +16,9 @@ thread_local common::OpCache<size_t, infiniopPagedAttentionPrefillDescriptor_t>
 });

-void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor history_lens, Tensor cu_seqlens_q,
+void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, Tensor cum_seqlens_q,
                std::optional<Tensor> alibi_slopes, float scale) {
-    size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, alibi_slopes, scale);
+    size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q, alibi_slopes, scale);
     auto device = context::getDevice();
     auto &cache = caches.getCache(device);
@@ -35,8 +34,8 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
             k_cache->desc(), v_cache->desc(), block_tables->desc(),
-            history_lens->desc(), cu_seqlens_q->desc(),
+            kv_lens->desc(), cum_seqlens_q->desc(),
             alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr, scale));
         cache.put(seed, desc);
@@ -57,8 +56,8 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
         k_cache->data(), v_cache->data(), block_tables->data(),
-        history_lens->data(), cu_seqlens_q->data(),
+        kv_lens->data(), cum_seqlens_q->data(),
         alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr,
         context::getStream()));
 }
src/infinicore/pybind11/ops/paged_attention_prefill.hpp

@@ -19,7 +19,8 @@ Tensor py_paged_attention_prefill(Tensor q,
     if (!alibi_slopes.is_none()) {
         alibi_slopes_tensor = alibi_slopes.cast<Tensor>();
     }
-    return op::paged_attention_prefill(q, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, alibi_slopes_tensor, scale);
+    return op::paged_attention_prefill(q, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q,
+                                       alibi_slopes_tensor, scale);
 }

 void py_paged_attention_prefill_(Tensor out,
src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh

@@ -22,12 +22,13 @@ template <typename Tdata, typename Tcompute>
 __global__ void pagedAttentionPrefillKernel(
     Tdata *out_, const Tdata *q_,
     const Tdata *k_cache_, const Tdata *v_cache_,
     const int64_t *block_tables_,
-    const int64_t *history_lens_,
+    const int64_t *total_kv_lens_,
     const int64_t *cum_seq_lens_q_,
     const float *alibi_slopes_,
     const size_t num_heads, const size_t num_kv_heads, const float scale,
     const size_t max_num_blocks_per_seq, const size_t block_size,
     const ptrdiff_t kv_block_stride, const ptrdiff_t kv_head_stride,
+    const ptrdiff_t q_stride, const ptrdiff_t q_head_stride,
     const size_t head_size, const size_t num_seqs) {
@@ -44,10 +45,12 @@ __global__ void pagedAttentionPrefillKernel(
     size_t q_token_idx = global_token_idx - cum_seq_lens_q_[seq_idx];
-    const int64_t history_len = history_lens_[seq_idx];
-    const int64_t causal_limit = history_len + q_token_idx;
+    const size_t total_kv_len = total_kv_lens_[seq_idx];
+    const size_t q_len = cum_seq_lens_q_[seq_idx + 1] - cum_seq_lens_q_[seq_idx];
+    const size_t history_len = total_kv_len - q_len;
+    const size_t causal_limit = history_len + q_token_idx;

-    const Tdata *q_vec = q_ + global_token_idx * num_heads * head_size + head_idx * head_size;
+    const Tdata *q_vec = q_ + global_token_idx * q_stride + head_idx * q_head_stride;
     Tdata *out_ptr = out_ + global_token_idx * num_heads * head_size + head_idx * head_size;
     const size_t num_queries_per_kv = num_heads / num_kv_heads;
@@ -57,10 +60,10 @@ __global__ void pagedAttentionPrefillKernel(
     const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
     Tcompute max_score = -FLT_MAX;
-    for (int64_t t = 0; t <= causal_limit; ++t) {
-        const int64_t b_idx = t / block_size;
-        const int64_t t_off = t % block_size;
-        const int64_t physical_block_id = block_table[b_idx];
+    for (size_t t = 0; t <= causal_limit; ++t) {
+        const size_t b_idx = t / block_size;
+        const size_t t_off = t % block_size;
+        const ptrdiff_t physical_block_id = block_table[b_idx];
         const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
         Tcompute score = 0.0f;
@@ -77,10 +80,10 @@ __global__ void pagedAttentionPrefillKernel(
     }
     Tcompute sum_exp = 0.0f;
-    for (int64_t t = 0; t <= causal_limit; ++t) {
-        const int64_t b_idx = t / block_size;
-        const int64_t t_off = t % block_size;
-        const int64_t physical_block_id = block_table[b_idx];
+    for (size_t t = 0; t <= causal_limit; ++t) {
+        const size_t b_idx = t / block_size;
+        const size_t t_off = t % block_size;
+        const ptrdiff_t physical_block_id = block_table[b_idx];
         const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
         Tcompute score = 0.0f;
@@ -96,10 +99,10 @@ __global__ void pagedAttentionPrefillKernel(
     Tcompute acc = 0.0f;
     Tcompute inv_sum = 1.0f / (sum_exp + 1e-6f);
-    for (int64_t t = 0; t <= causal_limit; ++t) {
-        const int64_t b_idx = t / block_size;
-        const int64_t t_off = t % block_size;
-        const int64_t physical_block_id = block_table[b_idx];
+    for (size_t t = 0; t <= causal_limit; ++t) {
+        const size_t b_idx = t / block_size;
+        const size_t t_off = t % block_size;
+        const ptrdiff_t physical_block_id = block_table[b_idx];
         const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
         Tcompute score = 0.0f;
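The kernel changes above replace the precomputed history length with one derived from total_kv_lens_ and the query prefix sums, and they index q through q_stride/q_head_stride. For orientation, here is a small CPU reference of the per-token loop the kernel performs, written as a hedged Python/NumPy sketch (single KV head, simplified layout, no ALiBi; it is not the CUDA code itself):

    import numpy as np

    def prefill_one_token(q_vec, k_cache, v_cache, block_table,
                          total_kv_len, q_len, q_token_idx, block_size, scale):
        """Reference for one query token of one head, mirroring the kernel's loop.

        k_cache/v_cache: [num_blocks, block_size, head_size] for a single KV head.
        block_table: logical block index -> physical block id for this sequence.
        """
        history_len = total_kv_len - q_len                # derived, as in the new kernel
        causal_limit = history_len + q_token_idx          # last KV position this token may attend to

        scores = np.empty(causal_limit + 1, dtype=np.float32)
        for t in range(causal_limit + 1):
            b_idx, t_off = divmod(t, block_size)          # logical block and offset within it
            phys = block_table[b_idx]                     # physical block id from the block table
            scores[t] = scale * np.dot(q_vec, k_cache[phys, t_off])

        weights = np.exp(scores - scores.max())           # numerically stable softmax
        weights /= weights.sum()

        out = np.zeros_like(q_vec, dtype=np.float32)
        for t in range(causal_limit + 1):
            b_idx, t_off = divmod(t, block_size)
            phys = block_table[b_idx]
            out += weights[t] * v_cache[phys, t_off]
        return out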
src/infiniop/ops/paged_attention_prefill/info.h

@@ -25,6 +25,7 @@ public:
     size_t total_q_tokens;
+    ptrdiff_t q_stride;
+    ptrdiff_t q_head_stride;
     ptrdiff_t kv_block_stride;
     ptrdiff_t kv_head_stride;
     ptrdiff_t o_stride;
@@ -35,7 +36,7 @@ public:
         infiniopTensorDescriptor_t k_cache_desc,
         infiniopTensorDescriptor_t v_cache_desc,
         infiniopTensorDescriptor_t block_tables_desc,
-        infiniopTensorDescriptor_t history_lens_desc,
+        infiniopTensorDescriptor_t seq_lens_desc,
         infiniopTensorDescriptor_t cum_seq_lens_q_desc,
         const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
         float scale) {
@@ -47,7 +48,7 @@ public:
             return INFINI_STATUS_BAD_TENSOR_DTYPE;
         }
-        if (cum_seq_lens_q_desc->dtype() != INFINI_DTYPE_I64 || history_lens_desc->dtype() != INFINI_DTYPE_I64) {
+        if (cum_seq_lens_q_desc->dtype() != INFINI_DTYPE_I64 || seq_lens_desc->dtype() != INFINI_DTYPE_I64) {
             return INFINI_STATUS_BAD_TENSOR_DTYPE;
         }
@@ -57,7 +58,7 @@ public:
         auto k_shape = k_cache_desc->shape();
         auto v_shape = v_cache_desc->shape();
         auto block_tables_shape = block_tables_desc->shape();
-        auto history_lens_shape = history_lens_desc->shape();
+        auto seq_lens_shape = seq_lens_desc->shape();
         auto cum_seq_lens_q_shape = cum_seq_lens_q_desc->shape();
         if (k_shape.size() != 4 || v_shape.size() != 4) {
@@ -68,10 +69,11 @@ public:
             return INFINI_STATUS_BAD_TENSOR_SHAPE;
         }
-        if (history_lens_shape.size() != 1 || cum_seq_lens_q_shape.size() != 1) {
+        if (seq_lens_shape.size() != 1 || cum_seq_lens_q_shape.size() != 1) {
             return INFINI_STATUS_BAD_TENSOR_SHAPE;
         }
-        if (cum_seq_lens_q_shape[0] != history_lens_shape[0] + 1) {
+        if (cum_seq_lens_q_shape[0] != seq_lens_shape[0] + 1) {
             return INFINI_STATUS_BAD_PARAM;
         }
@@ -88,13 +90,13 @@ public:
             return INFINI_STATUS_BAD_PARAM;
         }
-        size_t num_seqs = history_lens_shape[0];
+        size_t num_seqs = seq_lens_shape[0];
         size_t num_kv_heads = k_shape[1];
         size_t block_size = k_shape[2];
         size_t max_num_blocks_per_seq = block_tables_shape[1];
+        ptrdiff_t q_stride = q_desc->stride(0);
+        ptrdiff_t q_head_stride = q_desc->stride(1);
         ptrdiff_t kv_block_stride = k_cache_desc->stride(0);
         ptrdiff_t kv_head_stride = k_cache_desc->stride(1);
         ptrdiff_t o_stride = out_desc->stride(0);
@@ -110,6 +112,7 @@ public:
             max_num_blocks_per_seq,
             total_q_tokens,
+            q_stride,
+            q_head_stride,
             kv_block_stride,
             kv_head_stride,
             o_stride});
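info.h now also captures q_stride and q_head_stride from the query descriptor, so the operator no longer assumes the query is densely packed as [total_q_tokens, num_heads, head_size]. A short torch illustration (torch is used here only to demonstrate strides; it is not part of this file) of when those strides differ from the packed values:

    import torch

    total_q_tokens, num_heads, head_size = 11, 8, 64

    # Contiguous query: stride(0) == num_heads * head_size, stride(1) == head_size.
    q = torch.randn(total_q_tokens, num_heads, head_size)
    assert q.stride(0) == num_heads * head_size and q.stride(1) == head_size

    # A non-contiguous view (e.g., q sliced out of a fused QKV buffer) keeps the same
    # shape but larger strides, which is exactly what q_stride / q_head_stride capture.
    qkv = torch.randn(total_q_tokens, 3 * num_heads, head_size)
    q_view = qkv[:, :num_heads, :]
    assert q_view.stride(0) == 3 * num_heads * head_size
    assert q_view.stride(1) == head_size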
src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu

@@ -12,7 +12,7 @@ template <typename Tdata, typename Tcompute>
 infiniStatus_t launchPagedAttentionPrefill(
     Tdata *out, const Tdata *q,
     const Tdata *k_cache, const Tdata *v_cache,
     const int64_t *block_tables,
-    const int64_t *history_lens,
+    const int64_t *seq_lens,
     const int64_t *cum_seq_lens_q,
     const float *alibi_slopes,
     const size_t num_heads,
@@ -25,6 +25,8 @@ infiniStatus_t launchPagedAttentionPrefill(
     const size_t head_size,
     const ptrdiff_t kv_block_stride,
     const ptrdiff_t kv_head_stride,
+    const ptrdiff_t q_stride,
+    const ptrdiff_t q_head_stride,
     cudaStream_t stream) {
     if (total_q_tokens == 0 || num_heads == 0) {
@@ -37,10 +39,11 @@ infiniStatus_t launchPagedAttentionPrefill(
     op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel<Tdata, Tcompute>
         <<<grid, block, 0, stream>>>(
             out, q, k_cache, v_cache,
-            block_tables, history_lens, cum_seq_lens_q, alibi_slopes,
+            block_tables, seq_lens, cum_seq_lens_q, alibi_slopes,
             num_heads, num_kv_heads, scale, max_num_blocks_per_seq, block_size,
-            kv_block_stride, kv_head_stride, head_size, num_seqs);
+            kv_block_stride, kv_head_stride,
+            q_stride, q_head_stride, head_size, num_seqs);
@@ -65,14 +68,14 @@ infiniStatus_t Descriptor::create(
     infiniopTensorDescriptor_t k_cache_desc,
     infiniopTensorDescriptor_t v_cache_desc,
     infiniopTensorDescriptor_t block_tables_desc,
-    infiniopTensorDescriptor_t history_lens_desc,
+    infiniopTensorDescriptor_t seq_lens_desc,
     infiniopTensorDescriptor_t cum_seq_lens_q_desc,
     const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
     float scale) {
     auto info = PagedAttentionPrefillInfo::create(
         out_desc, q_desc, k_cache_desc, v_cache_desc,
-        block_tables_desc, history_lens_desc, cum_seq_lens_q_desc, alibi_slopes_desc, scale);
+        block_tables_desc, seq_lens_desc, cum_seq_lens_q_desc, alibi_slopes_desc, scale);
@@ -89,23 +92,24 @@ infiniStatus_t Descriptor::calculate(
     void *workspace, size_t workspace_size,
     void *out, const void *q,
     const void *k_cache, const void *v_cache,
     const void *block_tables,
-    const void *history_lens,
+    const void *seq_lens,
     const void *cum_seq_lens_q,
     const void *alibi_slopes,
     void *stream_) const {
     cudaStream_t stream = (cudaStream_t)stream_;

-#define LAUNCH_KERNEL(Tdata, Tcompute) \
-    launchPagedAttentionPrefill<Tdata, Tcompute>( \
-        (Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache, \
-        (const int64_t *)block_tables, (const int64_t *)history_lens, (const int64_t *)cum_seq_lens_q, \
-        (const float *)alibi_slopes, \
-        _info.num_heads, _info.num_seqs, _info.num_kv_heads, \
-        _info.scale, _info.max_num_blocks_per_seq, \
-        _info.block_size, _info.total_q_tokens, \
-        _info.head_size, \
-        _info.kv_block_stride, _info.kv_head_stride, \
-        stream)
+#define LAUNCH_KERNEL(Tdata, Tcompute) \
+    launchPagedAttentionPrefill<Tdata, Tcompute>( \
+        (Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache, \
+        (const int64_t *)block_tables, (const int64_t *)seq_lens, (const int64_t *)cum_seq_lens_q, \
+        (const float *)alibi_slopes, \
+        _info.num_heads, _info.num_seqs, _info.num_kv_heads, \
+        _info.scale, _info.max_num_blocks_per_seq, \
+        _info.block_size, _info.total_q_tokens, \
+        _info.head_size, \
+        _info.kv_block_stride, _info.kv_head_stride, \
+        _info.q_stride, _info.q_head_stride, \
+        stream)

     if (_info.dtype == INFINI_DTYPE_F16) {
src/infiniop/ops/paged_attention_prefill/operator.cc

@@ -14,7 +14,7 @@ __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
     infiniopTensorDescriptor_t k_cache_desc,
     infiniopTensorDescriptor_t v_cache_desc,
     infiniopTensorDescriptor_t block_tables_desc,
-    infiniopTensorDescriptor_t history_lens_desc,
+    infiniopTensorDescriptor_t seq_lens_desc,
     infiniopTensorDescriptor_t cum_seq_lens_q_desc,
     infiniopTensorDescriptor_t alibi_slopes_desc,
     float scale) {
@@ -27,14 +27,15 @@ __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
         handle, \
         reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor **>(desc_ptr), \
         out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, \
-        history_lens_desc, cum_seq_lens_q_desc, alibi_opt, scale);
+        seq_lens_desc, cum_seq_lens_q_desc, alibi_opt, scale);

     switch (handle->device) {
 #ifdef ENABLE_NVIDIA_API
         CREATE(INFINI_DEVICE_NVIDIA, nvidia)
 #endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
-    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
 }

 __C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
@@ -50,8 +51,9 @@ __C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
 #ifdef ENABLE_NVIDIA_API
         GET(INFINI_DEVICE_NVIDIA, nvidia)
 #endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
-    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
 }

 __C infiniStatus_t infiniopPagedAttentionPrefill(
@@ -59,7 +61,7 @@ __C infiniStatus_t infiniopPagedAttentionPrefill(
     void *workspace, size_t workspace_size,
     void *out, const void *q,
     const void *k_cache, const void *v_cache,
     const void *block_tables,
-    const void *history_lens,
+    const void *seq_lens,
     const void *cum_seq_lens_q,
     const void *alibi_slopes,
     void *stream) {
@@ -68,14 +70,15 @@ __C infiniStatus_t infiniopPagedAttentionPrefill(
     case CASE: \
         return reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor *>(desc)->calculate( \
             workspace, workspace_size, out, q, k_cache, v_cache, block_tables, \
-            history_lens, cum_seq_lens_q, alibi_slopes, stream);
+            seq_lens, cum_seq_lens_q, alibi_slopes, stream);

     switch (desc->device_type) {
 #ifdef ENABLE_NVIDIA_API
         CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
 #endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
-    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
 }

 __C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
@@ -90,6 +93,7 @@ __C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
 #ifdef ENABLE_NVIDIA_API
         DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
 #endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
-    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
 }
src/infiniop/ops/paged_attention_prefill/paged_attention_prefill.h

@@ -37,7 +37,7 @@
         infiniopTensorDescriptor_t k_cache_desc, \
         infiniopTensorDescriptor_t v_cache_desc, \
         infiniopTensorDescriptor_t block_tables_desc, \
-        infiniopTensorDescriptor_t history_lens_desc, \
+        infiniopTensorDescriptor_t seq_lens_desc, \
         infiniopTensorDescriptor_t cum_seq_lens_q_desc, \
         const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc, \
         float scale); \
@@ -46,7 +46,7 @@
         void *workspace, size_t workspace_size, \
         void *out, const void *q, const void *k_cache, const void *v_cache, \
         const void *block_tables, \
-        const void *history_lens, \
+        const void *seq_lens, \
         const void *cum_seq_lens_q, \
         const void *alibi_slopes, \
         void *stream) const; \
test/infinicore/ops/paged_attention_prefill.py

-import sys
 import os
+import sys
+
 import torch
+
 import infinicore

 sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
 from framework import (
     BaseOperatorTest,
-    TensorSpec,
-    TestCase,
     GenericTestRunner,
     TensorInitializer,
+    TensorSpec,
+    TestCase,
 )

 # Test Cases: (num_seqs, num_heads, num_kv_heads, head_size, block_size, max_step_len, num_rounds)

@@ -71,16 +73,17 @@ def parse_test_cases():
         scale = head_size**-0.5
         num_blocks = 8192
         manager = SimpleCacheManager(num_blocks, block_size)
-        current_history_lens = torch.zeros(num_seqs, dtype=torch.int64)
+        kv_lens = torch.zeros(num_seqs, dtype=torch.int64)
         persistent_k = torch.zeros((num_blocks, num_kv_heads, block_size, head_size))
         persistent_v = torch.zeros((num_blocks, num_kv_heads, block_size, head_size))

         for r in range(num_rounds):
             q_lens = torch.randint(1, max_step_len + 1, (num_seqs,), dtype=torch.int64)
+            kv_lens = kv_lens + q_lens
             total_q_tokens = q_lens.sum().item()
-            cu_seqlens_q = torch.zeros(num_seqs + 1, dtype=torch.int64)
-            cu_seqlens_q[1:] = torch.cumsum(q_lens, dim=0)
+            cum_seqlens_q = torch.zeros(num_seqs + 1, dtype=torch.int64)
+            cum_seqlens_q[1:] = torch.cumsum(q_lens, dim=0)
             query_base = torch.randn((total_q_tokens, num_heads, head_size))
@@ -89,8 +92,7 @@ def parse_test_cases():
                 p_blocks, total_len = manager.allocate_slots(i, q_lens[i].item())
                 round_block_tables_list.append(p_blocks)
-                h_len = current_history_lens[i].item()
-                q_start = cu_seqlens_q[i].item()
+                h_len = kv_lens[i].item() - q_lens[i].item()
                 for t in range(q_lens[i].item()):
                     logical_pos = h_len + t
@@ -135,15 +137,15 @@ def parse_test_cases():
                     dtype=infinicore.int64,
                 ),
                 TensorSpec.from_tensor(
-                    current_history_lens.shape,
+                    kv_lens.shape,
                     init_mode=TensorInitializer.MANUAL,
-                    set_tensor=current_history_lens.clone(),
+                    set_tensor=kv_lens.clone(),
                     dtype=infinicore.int64,
                 ),
                 TensorSpec.from_tensor(
-                    cu_seqlens_q.shape,
+                    cum_seqlens_q.shape,
                     init_mode=TensorInitializer.MANUAL,
-                    set_tensor=cu_seqlens_q.clone(),
+                    set_tensor=cum_seqlens_q.clone(),
                     dtype=infinicore.int64,
                 ),
             ],
@@ -153,23 +155,21 @@ def parse_test_cases():
             )
         )
-        current_history_lens += q_lens
     return test_cases


 def ref_paged_attention_multi_turn(
-    query, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, scale
+    query, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q, scale
 ):
     output = torch.zeros_like(query)
-    num_seqs = len(history_lens)
+    num_seqs = len(kv_lens)
     block_size = k_cache.shape[2]

     for i in range(num_seqs):
-        q_start, q_end = cu_seqlens_q[i].item(), cu_seqlens_q[i + 1].item()
+        q_start, q_end = cum_seqlens_q[i].item(), cum_seqlens_q[i + 1].item()
         cur_q = query[q_start:q_end]
-        h_len = history_lens[i].item()
         q_len = q_end - q_start
+        h_len = kv_lens[i].item() - q_len
         total_len = h_len + q_len
         table = block_tables[i]
@@ -206,12 +206,12 @@ class OpTest(BaseOperatorTest):
         k_cache,
         v_cache,
         block_tables,
-        history_lens,
-        cu_seqlens_q,
+        kv_lens,
+        cum_seqlens_q,
         scale=1.0,
     ):
         return ref_paged_attention_multi_turn(
-            query, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, scale
+            query, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q, scale
         )

     def infinicore_operator(
@@ -220,8 +220,8 @@ class OpTest(BaseOperatorTest):
         k_cache,
         v_cache,
         block_tables,
-        history_lens,
-        cu_seqlens_q,
+        kv_lens,
+        cum_seqlens_q,
         scale=1.0,
     ):
         out = infinicore.paged_attention_prefill(
@@ -229,8 +229,8 @@ class OpTest(BaseOperatorTest):
             k_cache,
             v_cache,
             block_tables,
-            history_lens,
-            cu_seqlens_q,
+            kv_lens,
+            cum_seqlens_q,
             alibi_slopes=None,
             scale=scale,
         )
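The updated test keeps a running kv_lens across rounds and passes it, together with the cumulative query lengths, straight to the operator. A minimal standalone sketch of that bookkeeping (same logic as parse_test_cases, toy sizes, seeded for reproducibility; not part of the test file itself):

    import torch

    torch.manual_seed(0)
    num_seqs, max_step_len, num_rounds = 3, 4, 3

    # Multi-round bookkeeping: kv_lens accumulates the new query tokens each round,
    # and the history for a round is kv_lens - q_lens (what history_lens used to carry).
    kv_lens = torch.zeros(num_seqs, dtype=torch.int64)
    for r in range(num_rounds):
        q_lens = torch.randint(1, max_step_len + 1, (num_seqs,), dtype=torch.int64)
        kv_lens = kv_lens + q_lens
        cum_seqlens_q = torch.zeros(num_seqs + 1, dtype=torch.int64)
        cum_seqlens_q[1:] = torch.cumsum(q_lens, dim=0)
        h_lens = kv_lens - q_lens
        print(f"round {r}: q_lens={q_lens.tolist()} kv_lens={kv_lens.tolist()} h_lens={h_lens.tolist()}")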
test/infiniop/paged_attention_prefill.py

-import torch
 import ctypes
 from ctypes import c_uint64
+
+import torch
 from libinfiniop import (
     LIBINFINIOP,
+    InfiniDeviceNames,
+    InfiniDtype,
+    InfiniDtypeNames,
     TestTensor,
-    get_test_devices,
+    TestWorkspace,
     check_error,
-    test_operator,
-    get_args,
     debug,
+    get_args,
+    get_test_devices,
     get_tolerance,
-    profile_operation,
-    InfiniDtype,
-    InfiniDtypeNames,
-    InfiniDeviceNames,
     infiniopOperatorDescriptor_t,
-    TestWorkspace,
+    profile_operation,
+    test_operator,
 )

 # ==============================================================================

@@ -81,8 +82,8 @@ def ref_paged_attention_multi_turn(
     num_seqs = len(cum_seq_lens_q) - 1
     for i in range(num_seqs):
         num_new = cum_seq_lens_q[i + 1].item() - cum_seq_lens_q[i].item()
-        cache_len = seq_lens[i].item()
-        total_len = seq_lens[i].item() + num_new
+        total_len = seq_lens[i].item()
+        cache_len = seq_lens[i].item() - num_new
         table = block_tables[i]
         keys_all, values_all = [], []
@@ -166,7 +167,7 @@ def test(
         cur_q_len = query_lens_cpu[i].item()
         table, total_len = manager.allocate_slots(i, cur_q_len)
         cur_seq_lens = total_len - cur_q_len
-        seq_lens_list.append(cur_seq_lens)
+        seq_lens_list.append(total_len)
         all_block_tables.append(table)

     # Simulated KV insertion