Commit 3883f32f (unverified), authored Jan 09, 2026 by Haojie Wang, committed via GitHub on Jan 09, 2026.

Merge pull request #883 from InfiniTensor/issue/867

Issue/867: adjust paged_attention_prefill interface naming

Parents: 3b5afffe, 499b1dc6
Showing 17 changed files with 640 additions and 256 deletions (+640, -256).
include/infinicore/ops/paged_attention_prefill.hpp (+38, -4)
include/infiniop/ops/paged_attention_prefill.h (+16, -12)
python/infinicore/ops/paged_attention_prefill.py (+10, -11)
src/infinicore/ops/paged_attention/paged_attention.cc (+7, -7)
src/infinicore/ops/paged_attention/paged_attention_infiniop.cc (+4, -4)
src/infinicore/ops/paged_attention_prefill/paged_attention_prefill.cc (+18, -7)
src/infinicore/ops/paged_attention_prefill/paged_attention_prefill_infiniop.cc (+21, -6)
src/infinicore/pybind11/ops.hpp (+2, -0)
src/infinicore/pybind11/ops/paged_attention_prefill.hpp (+69, -0)
src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh (+30, -37)
src/infiniop/ops/paged_attention_prefill/info.h (+39, -22)
src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu (+30, -40)
src/infiniop/ops/paged_attention_prefill/operator.cc (+15, -11)
src/infiniop/ops/paged_attention_prefill/paged_attention_prefill.h (+47, -47)
test/infinicore/ops/paged_attention_prefill.py (+248, -0)
test/infiniop/libinfiniop/op_register.py (+0, -2)
test/infiniop/paged_attention_prefill.py (+46, -46)
include/infinicore/ops/paged_attention_prefill.hpp

@@ -8,11 +8,45 @@ namespace infinicore::op {
 class PagedAttentionPrefill {
 public:
-    using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>, float);
-    static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float);
+    /**
+     * @brief PagedAttentionPrefill operator signature
+     *
+     * Argument order:
+     *   1. out: Output tensor (packed format)
+     *   2. q: Current Query tensor (packed format)
+     *   3. k_cache: Physical Key cache (paged format)
+     *   4. v_cache: Physical Value cache (paged format)
+     *   5. block_tables: Mapping table from logical blocks to physical blocks
+     *   6. total_kv_lens: Lengths of the complete Key/Value for each request
+     *   7. cu_seqlens_q: Cumulative sequence lengths of Query (prefix sum for variable-length batch)
+     *   8. alibi_slopes: ALiBi bias slopes (optional)
+     *   9. scale: Scaling factor (typically 1/sqrt(head_size))
+     */
+    using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>, float);
+    static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor total_kv_lens, Tensor cum_seqlens_q, std::optional<Tensor> alibi_slopes, float scale);
     static common::OpDispatcher<schema> &dispatcher();
 };
 
-Tensor paged_attention_prefill(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale);
-void paged_attention_prefill_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale);
+Tensor paged_attention_prefill(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor total_kv_lens, Tensor cum_seqlens_q, std::optional<Tensor> alibi_slopes, float scale);
+void paged_attention_prefill_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor total_kv_lens, Tensor cum_seqlens_q, std::optional<Tensor> alibi_slopes, float scale);
 } // namespace infinicore::op
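
The new cum_seqlens_q argument is the standard exclusive prefix sum over per-request query lengths, exactly as the test code later in this diff builds it. A minimal sketch (values are illustrative, not part of the API):

import torch

# Per-request query lengths for a batch of 3 prefill requests.
q_lens = torch.tensor([5, 3, 7], dtype=torch.int64)

# Prefix sum with a leading zero: entry i is the start offset of request i
# in the packed Q tensor, and entry i+1 is its end offset.
cum_seqlens_q = torch.zeros(len(q_lens) + 1, dtype=torch.int64)
cum_seqlens_q[1:] = torch.cumsum(q_lens, dim=0)
# -> tensor([0, 5, 8, 15]); request 1 occupies packed rows 5..7.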
include/infiniop/ops/paged_attention_prefill.h

@@ -11,15 +11,22 @@ typedef struct InfiniopDescriptor *infiniopPagedAttentionPrefillDescriptor_t;
  * @param handle The handle to the InfiniOP library context.
  * @param desc_ptr A pointer to store the created descriptor.
  * @param out_desc Descriptor for the output tensor.
+ *                 Shape: [total_q_tokens, num_heads, head_size]
  * @param q_desc Descriptor for the query tensor (packed/flattened).
+ *               Shape: [total_q_tokens, num_heads, head_size]
  * @param k_cache_desc Descriptor for the global physical key cache.
+ *                     Shape: [max_num_blocks, num_kv_heads, block_size, head_size]
  * @param v_cache_desc Descriptor for the global physical value cache.
+ *                     Shape: [max_num_blocks, num_kv_heads, block_size, head_size]
  * @param block_tables_desc Descriptor for the block tables mapping logical to physical blocks.
- * @param cache_lens_desc Descriptor for the total sequence lengths (history + current).
- * @param seq_lens_desc Descriptor for the current prefill sequence lengths.
- * @param offset_desc Descriptor for the start position of each sequence in the packed Q tensor.
+ *                          Shape: [batch_size, max_blocks_per_seq]
+ * @param seq_lens_desc Descriptor for the total KV lengths of each sequence.
+ *                      Shape: [batch_size]
+ * @param cum_seq_lens_q_desc Descriptor for the cumulative start position (prefix sum) of each Q sequence.
+ *                            Shape: [batch_size + 1]
  * @param alibi_slopes_desc Optional descriptor for the ALiBi slopes tensor. Can be NULL.
- * @param scale The attention scaling factor.
+ *                          Shape: [num_heads]
+ * @param scale The attention scaling factor (typically 1.0 / sqrt(head_size)).
  * @return infiniStatus_t Status code of the operation.
  */
__C __export infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
@@ -30,9 +37,8 @@ __C __export infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
     infiniopTensorDescriptor_t k_cache_desc,
     infiniopTensorDescriptor_t v_cache_desc,
     infiniopTensorDescriptor_t block_tables_desc,
-    infiniopTensorDescriptor_t cache_lens_desc,
     infiniopTensorDescriptor_t seq_lens_desc,
-    infiniopTensorDescriptor_t offset_desc,
+    infiniopTensorDescriptor_t cum_seq_lens_q_desc,
     infiniopTensorDescriptor_t alibi_slopes_desc,
     float scale);
@@ -52,11 +58,10 @@ __C __export infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
 * @param k_cache Pointer to the global key cache data.
 * @param v_cache Pointer to the global value cache data.
 * @param block_tables Pointer to the block tables data.
- * @param cache_lens Pointer to the total sequence lengths data.
- * @param seq_lens Pointer to the current prefill sequence lengths data.
- * @param offset Pointer to the sequence start offsets data.
+ * @param seq_lens Pointer to the KV lengths data.
+ * @param cum_seq_lens_q Pointer to the Q cumulative sequence lengths data (prefix sum).
 * @param alibi_slopes Pointer to the ALiBi slopes data. Can be NULL.
- * @param stream The CUDA/device stream for the operation.
+ * @param stream The device stream (e.g., cudaStream_t) for the operation.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopPagedAttentionPrefill(
@@ -68,9 +73,8 @@ __C __export infiniStatus_t infiniopPagedAttentionPrefill(
     const void *k_cache,
     const void *v_cache,
     const void *block_tables,
-    const void *cache_lens,
     const void *seq_lens,
-    const void *offset,
+    const void *cum_seq_lens_q,
     const void *alibi_slopes,
     void *stream);
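
For orientation, the C-API call order after this rename looks roughly like the following ctypes sketch. It mirrors the renamed signatures above and the test harness later in this diff; handle, descriptor, and buffer setup are elided, and the variable names are illustrative assumptions, not part of this header:

import ctypes

# Hypothetical sketch: `lib` is the loaded InfiniOP shared library and
# `check_error` raises on a non-zero infiniStatus_t, as in the test harness.
desc = ctypes.c_void_p()
check_error(lib.infiniopCreatePagedAttentionPrefillDescriptor(
    handle, ctypes.byref(desc),
    out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc,
    seq_lens_desc, cum_seq_lens_q_desc,  # was: cache_lens/seq_lens/offset
    None,                                # alibi_slopes_desc is optional
    ctypes.c_float(scale)))
check_error(lib.infiniopPagedAttentionPrefill(
    desc, workspace, workspace_size,
    out, q, k_cache, v_cache, block_tables,
    seq_lens, cum_seq_lens_q, None, stream))
check_error(lib.infiniopDestroyPagedAttentionPrefillDescriptor(desc))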
python/infinicore/ops/paged_attention_prefill.py

@@ -7,14 +7,15 @@ def paged_attention_prefill(
     k_cache: Tensor,
     v_cache: Tensor,
     block_tables: Tensor,
-    cache_lens: Tensor,
-    seq_lens: Tensor,
-    seq_offsets: Tensor,
+    history_lens: Tensor,
+    cu_seqlens_q: Tensor,
     alibi_slopes: Tensor | None = None,
     scale: float = 1.0,
     *,
     out: Tensor | None = None,
 ):
+    alibi_ptr = alibi_slopes._underlying if alibi_slopes is not None else None
     if out is None:
         return Tensor(
             _infinicore.paged_attention_prefill(
@@ -22,10 +23,9 @@ def paged_attention_prefill(
                 k_cache._underlying,
                 v_cache._underlying,
                 block_tables._underlying,
-                cache_lens._underlying,
-                seq_lens._underlying,
-                seq_offsets._underlying,
-                alibi_slopes._underlying if alibi_slopes is not None else None,
+                history_lens._underlying,
+                cu_seqlens_q._underlying,
+                alibi_ptr,
                 scale,
             )
         )
@@ -36,10 +36,9 @@ def paged_attention_prefill(
         k_cache._underlying,
         v_cache._underlying,
         block_tables._underlying,
-        cache_lens._underlying,
-        seq_lens._underlying,
-        seq_offsets._underlying,
-        alibi_slopes._underlying if alibi_slopes is not None else None,
+        history_lens._underlying,
+        cu_seqlens_q._underlying,
+        alibi_ptr,
        scale,
    )
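
Calling the renamed wrapper then looks like this sketch, assuming the tensors already live on the target device (the call shape follows the new signature above and the infinicore test later in this diff):

import infinicore

# q is packed as [total_q_tokens, num_heads, head_size];
# history_lens holds the total KV length per request, and
# cu_seqlens_q is the prefix sum of per-request query lengths.
out = infinicore.paged_attention_prefill(
    q, k_cache, v_cache, block_tables,
    history_lens, cu_seqlens_q,
    alibi_slopes=None,
    scale=head_size ** -0.5,
)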
src/infinicore/ops/paged_attention/paged_attention.cc

@@ -9,20 +9,20 @@ common::OpDispatcher<PagedAttention::schema> &PagedAttention::dispatcher() {
     return dispatcher_;
 };
 
-void PagedAttention::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
-    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, cache_lens);
+void PagedAttention::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional<Tensor> alibi_slopes, float scale) {
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, kv_lens);
     infinicore::context::setDevice(out->device());
-    dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
+    dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale);
 }
 
-Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
+Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional<Tensor> alibi_slopes, float scale) {
     auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
-    paged_attention_(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
+    paged_attention_(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale);
     return out;
 }
 
-void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
-    PagedAttention::execute(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
+void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional<Tensor> alibi_slopes, float scale) {
+    PagedAttention::execute(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale);
 }
 } // namespace infinicore::op
src/infinicore/ops/paged_attention/paged_attention_infiniop.cc

@@ -15,8 +15,8 @@ thread_local common::OpCache<size_t, infiniopPagedAttentionDescriptor_t> caches(
     }
 });
 
-void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
-    size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
+void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional<Tensor> alibi_slopes, float scale) {
+    size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale);
     auto device = context::getDevice();
     auto &cache = caches.getCache(device);
@@ -27,7 +27,7 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc
     if (!desc_opt) {
         INFINICORE_CHECK_ERROR(infiniopCreatePagedAttentionDescriptor(
             context::getInfiniopHandle(device), &desc,
-            out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(), cache_lens->desc(),
+            out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(), kv_lens->desc(),
             alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr, scale));
         cache.put(seed, desc);
@@ -41,7 +41,7 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc
     INFINICORE_CHECK_ERROR(infiniopPagedAttention(
         desc, workspace->data(), workspace_size,
-        out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), cache_lens->data(),
+        out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), kv_lens->data(),
         alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr, context::getStream()));
 }
src/infinicore/ops/paged_attention_prefill/paged_attention_prefill.cc

@@ -9,20 +9,31 @@ common::OpDispatcher<PagedAttentionPrefill::schema> &PagedAttentionPrefill::disp
     return dispatcher_;
 };
 
-void PagedAttentionPrefill::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
-    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, cache_lens);
+void PagedAttentionPrefill::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, Tensor cum_seqlens_q, std::optional<Tensor> alibi_slopes, float scale) {
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q);
     infinicore::context::setDevice(out->device());
-    dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets, alibi_slopes, scale);
+    dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q, alibi_slopes, scale);
 }
 
-Tensor paged_attention_prefill(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
+Tensor paged_attention_prefill(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, Tensor cum_seqlens_q, std::optional<Tensor> alibi_slopes, float scale) {
     auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
-    paged_attention_prefill_(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets, alibi_slopes, scale);
+    paged_attention_prefill_(out, q, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q, alibi_slopes, scale);
     return out;
 }
 
-void paged_attention_prefill_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
-    PagedAttentionPrefill::execute(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets, alibi_slopes, scale);
+void paged_attention_prefill_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, Tensor cum_seqlens_q, std::optional<Tensor> alibi_slopes, float scale) {
+    PagedAttentionPrefill::execute(out, q, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q, alibi_slopes, scale);
 }
 } // namespace infinicore::op
src/infinicore/ops/paged_attention_prefill/paged_attention_prefill_infiniop.cc

@@ -15,8 +15,10 @@ thread_local common::OpCache<size_t, infiniopPagedAttentionPrefillDescriptor_t>
     }
 });
 
-void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
-    size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets, alibi_slopes, scale);
+void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, Tensor cum_seqlens_q, std::optional<Tensor> alibi_slopes, float scale) {
+    size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q, alibi_slopes, scale);
     auto device = context::getDevice();
     auto &cache = caches.getCache(device);
@@ -27,8 +29,13 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc
     if (!desc_opt) {
         INFINICORE_CHECK_ERROR(infiniopCreatePagedAttentionPrefillDescriptor(
             context::getInfiniopHandle(device), &desc,
-            out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(), cache_lens->desc(), seq_lens->desc(), seq_offsets->desc(),
+            out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(), kv_lens->desc(), cum_seqlens_q->desc(),
             alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr, scale));
         cache.put(seed, desc);
@@ -41,8 +48,16 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc
     std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
     INFINICORE_CHECK_ERROR(infiniopPagedAttentionPrefill(
         desc, workspace->data(), workspace_size,
-        out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), cache_lens->data(), seq_lens->data(), seq_offsets->data(),
+        out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), kv_lens->data(), cum_seqlens_q->data(),
         alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr, context::getStream()));
 }
src/infinicore/pybind11/ops.hpp

@@ -11,6 +11,7 @@
 #include "ops/matmul.hpp"
 #include "ops/mul.hpp"
 #include "ops/paged_attention.hpp"
+#include "ops/paged_attention_prefill.hpp"
 #include "ops/paged_caching.hpp"
 #include "ops/random_sample.hpp"
 #include "ops/rearrange.hpp"
@@ -33,6 +34,7 @@ inline void bind(py::module &m) {
     bind_matmul(m);
     bind_mul(m);
     bind_paged_attention(m);
+    bind_paged_attention_prefill(m);
     bind_paged_caching(m);
     bind_rearrange(m);
     bind_rms_norm(m);
src/infinicore/pybind11/ops/paged_attention_prefill.hpp (new file, mode 100644)

#pragma once

#include "infinicore/ops/paged_attention_prefill.hpp"

#include <pybind11/pybind11.h>

namespace py = pybind11;

namespace infinicore::ops {

Tensor py_paged_attention_prefill(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables,
                                  Tensor history_lens, Tensor cu_seqlens_q,
                                  py::object alibi_slopes, float scale) {
    std::optional<Tensor> alibi_slopes_tensor = std::nullopt;
    if (!alibi_slopes.is_none()) {
        alibi_slopes_tensor = alibi_slopes.cast<Tensor>();
    }
    return op::paged_attention_prefill(q, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, alibi_slopes_tensor, scale);
}

void py_paged_attention_prefill_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables,
                                 Tensor history_lens, Tensor cu_seqlens_q,
                                 py::object alibi_slopes, float scale) {
    std::optional<Tensor> alibi_slopes_tensor = std::nullopt;
    if (!alibi_slopes.is_none()) {
        alibi_slopes_tensor = alibi_slopes.cast<Tensor>();
    }
    op::paged_attention_prefill_(out, q, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, alibi_slopes_tensor, scale);
}

inline void bind_paged_attention_prefill(py::module &m) {
    m.def("paged_attention_prefill",
          &ops::py_paged_attention_prefill,
          py::arg("q"), py::arg("k_cache"), py::arg("v_cache"), py::arg("block_tables"),
          py::arg("history_lens"), py::arg("cu_seqlens_q"),
          py::arg("alibi_slopes") = py::none(), py::arg("scale") = 1.0,
          R"doc(Paged attention prefill for packed variable-length queries.)doc");

    m.def("paged_attention_prefill_",
          &ops::py_paged_attention_prefill_,
          py::arg("out"), py::arg("q"), py::arg("k_cache"), py::arg("v_cache"), py::arg("block_tables"),
          py::arg("history_lens"), py::arg("cu_seqlens_q"),
          py::arg("alibi_slopes") = py::none(), py::arg("scale") = 1.0,
          R"doc(In-place paged attention prefill for packed variable-length queries.)doc");
}

} // namespace infinicore::ops
src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh

@@ -3,14 +3,13 @@
 namespace op::paged_attention_prefill::cuda {
 
-// Helper: binary search to find which sequence global_token_idx belongs to
-__device__ __forceinline__ size_t find_seq_id(size_t token_idx, const int64_t *offset, size_t num_seqs) {
+__device__ __forceinline__ size_t find_seq_id(size_t token_idx, const int64_t *cum_seq_lens_q, size_t num_seqs) {
     size_t low = 0, high = num_seqs - 1;
     while (low <= high) {
         size_t mid = (low + high) >> 1;
-        if (token_idx >= offset[mid] && token_idx < offset[mid + 1]) {
+        if (token_idx >= (size_t)cum_seq_lens_q[mid] && token_idx < (size_t)cum_seq_lens_q[mid + 1]) {
             return mid;
-        } else if (token_idx < offset[mid]) {
+        } else if (token_idx < (size_t)cum_seq_lens_q[mid]) {
             high = mid - 1;
         } else {
             low = mid + 1;
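
The same lookup in Python, to make the invariant explicit: cum_seq_lens_q has num_seqs + 1 monotone entries, and a packed token index belongs to sequence i iff cum_seq_lens_q[i] <= token_idx < cum_seq_lens_q[i + 1]. A sketch mirroring the device helper:

def find_seq_id(token_idx: int, cum_seq_lens_q: list[int], num_seqs: int) -> int:
    """Binary search over the prefix-sum array, as in the CUDA helper above."""
    low, high = 0, num_seqs - 1
    while low <= high:
        mid = (low + high) // 2
        if cum_seq_lens_q[mid] <= token_idx < cum_seq_lens_q[mid + 1]:
            return mid
        elif token_idx < cum_seq_lens_q[mid]:
            high = mid - 1
        else:
            low = mid + 1
    raise ValueError("token_idx outside the packed range")

assert find_seq_id(6, [0, 5, 8, 15], 3) == 1  # token 6 falls in [5, 8)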
@@ -22,58 +21,54 @@ __device__ __forceinline__ size_t find_seq_id(size_t token_idx, const int64_t *cum_seq_lens_q, size_t num_seqs) {
 template <typename Tdata, typename Tcompute>
 __global__ void pagedAttentionPrefillKernel(
     Tdata *out_,
     const Tdata *q_,
     const Tdata *k_cache_,
     const Tdata *v_cache_,
     const int64_t *block_tables_,
-    const int64_t *cache_lens_,
-    const int64_t *seq_lens_,
+    const int64_t *total_kv_lens_,
+    const int64_t *cum_seq_lens_q_,
     const float *alibi_slopes_,
     const size_t num_heads,
     const size_t num_kv_heads,
     const float scale,
     const size_t max_num_blocks_per_seq,
     const size_t block_size,
     const ptrdiff_t kv_block_stride,
     const ptrdiff_t kv_head_stride,
     const ptrdiff_t q_stride,
     const ptrdiff_t q_head_stride,
     const size_t head_size,
-    const int64_t *offset_,
     const size_t num_seqs) {
-    // --- Using 2D grid coordinates ---
-    const size_t global_token_idx = blockIdx.x; // flattened global token index
-    const size_t head_idx = blockIdx.y;         // head index
-    const size_t dim_idx = threadIdx.x;         // dimension within the head
+    // Grid: x -> token, y -> head
+    const size_t global_token_idx = blockIdx.x;
+    const size_t head_idx = blockIdx.y;
+    const size_t dim_idx = threadIdx.x;
     if (dim_idx >= head_size) {
         return;
     }
 
-    // --- Binary-search the offset array for the owning seq_idx ---
-    size_t seq_idx = find_seq_id(global_token_idx, offset_, num_seqs);
-    // --- Length of this sequence's current prefill chunk
-    const int64_t cur_new_len = seq_lens_[seq_idx];
-    // --- Relative position of this token within its sequence
-    size_t q_token_idx = global_token_idx - offset_[seq_idx];
+    size_t seq_idx = find_seq_id(global_token_idx, cum_seq_lens_q_, num_seqs);
+    size_t q_token_idx = global_token_idx - cum_seq_lens_q_[seq_idx];
+    const size_t total_kv_len = total_kv_lens_[seq_idx];
+    const size_t q_len = cum_seq_lens_q_[seq_idx + 1] - cum_seq_lens_q_[seq_idx];
+    const size_t history_len = total_kv_len - q_len;
+    const size_t causal_limit = history_len + q_token_idx;
 
-    const Tdata *q_ptr_base = q_ + global_token_idx * num_heads * head_size + head_idx * head_size;
+    const Tdata *q_vec = q_ + global_token_idx * q_stride + head_idx * q_head_stride;
     Tdata *out_ptr = out_ + global_token_idx * num_heads * head_size + head_idx * head_size;
 
-    // --- KV-cache bookkeeping
-    const int64_t total_seq_len = cache_lens_[seq_idx];
-    const int64_t history_len = total_seq_len - cur_new_len;
-    const int64_t causal_limit = history_len + q_token_idx;
     const size_t num_queries_per_kv = num_heads / num_kv_heads;
     const size_t kv_head_idx = head_idx / num_queries_per_kv;
     const int64_t *block_table = block_tables_ + seq_idx * max_num_blocks_per_seq;
     const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
 
     // Pass 1: compute the scores and find the maximum
     Tcompute max_score = -FLT_MAX;
     for (size_t t = 0; t <= causal_limit; ++t) {
-        const int64_t b_idx = t / block_size;
-        const int64_t t_off = t % block_size;
-        const int64_t physical_block_id = block_table[b_idx];
+        const size_t b_idx = t / block_size;
+        const size_t t_off = t % block_size;
+        const ptrdiff_t physical_block_id = block_table[b_idx];
         const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
         Tcompute score = 0.0f;
         for (size_t d = 0; d < head_size; ++d) {
-            score += static_cast<Tcompute>(q_ptr_base[d]) * static_cast<Tcompute>(k_vec[d]);
+            score += static_cast<Tcompute>(q_vec[d]) * static_cast<Tcompute>(k_vec[d]);
         }
         score *= static_cast<Tcompute>(scale);
         if (alibi_slope != 0.0f) {
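
With the renamed inputs, every per-token quantity the kernel needs is derived from just two arrays. A worked example in Python (values are illustrative):

total_kv_lens  = [10, 4]    # history + current tokens per sequence
cum_seq_lens_q = [0, 3, 7]  # 3 and 4 new query tokens, packed back to back

global_token_idx = 5        # the sixth packed query token
seq_idx      = 1                                                      # via find_seq_id
q_token_idx  = global_token_idx - cum_seq_lens_q[seq_idx]             # = 2
q_len        = cum_seq_lens_q[seq_idx + 1] - cum_seq_lens_q[seq_idx]  # = 4
history_len  = total_kv_lens[seq_idx] - q_len                         # = 0
causal_limit = history_len + q_token_idx                              # = 2: attends KV 0..2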
@@ -84,17 +79,16 @@ __global__ void pagedAttentionPrefillKernel(
         }
     }
 
     // Pass 2: compute the sum of exponentials
     Tcompute sum_exp = 0.0f;
     for (size_t t = 0; t <= causal_limit; ++t) {
-        const int64_t b_idx = t / block_size;
-        const int64_t t_off = t % block_size;
-        const int64_t physical_block_id = block_table[b_idx];
+        const size_t b_idx = t / block_size;
+        const size_t t_off = t % block_size;
+        const ptrdiff_t physical_block_id = block_table[b_idx];
         const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
         Tcompute score = 0.0f;
         for (size_t d = 0; d < head_size; ++d) {
-            score += static_cast<Tcompute>(q_ptr_base[d]) * static_cast<Tcompute>(k_vec[d]);
+            score += static_cast<Tcompute>(q_vec[d]) * static_cast<Tcompute>(k_vec[d]);
         }
         score *= static_cast<Tcompute>(scale);
         if (alibi_slope != 0.0f) {
@@ -103,18 +97,17 @@ __global__ void pagedAttentionPrefillKernel(
         sum_exp += expf(static_cast<float>(score - max_score));
     }
 
     // Pass 3: weighted sum to produce the output
     Tcompute acc = 0.0f;
     Tcompute inv_sum = 1.0f / (sum_exp + 1e-6f);
     for (size_t t = 0; t <= causal_limit; ++t) {
-        const int64_t b_idx = t / block_size;
-        const int64_t t_off = t % block_size;
-        const int64_t physical_block_id = block_table[b_idx];
+        const size_t b_idx = t / block_size;
+        const size_t t_off = t % block_size;
+        const ptrdiff_t physical_block_id = block_table[b_idx];
         const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
         Tcompute score = 0.0f;
         for (size_t d = 0; d < head_size; ++d) {
-            score += static_cast<Tcompute>(q_ptr_base[d]) * static_cast<Tcompute>(k_vec[d]);
+            score += static_cast<Tcompute>(q_vec[d]) * static_cast<Tcompute>(k_vec[d]);
         }
         score *= static_cast<Tcompute>(scale);
         if (alibi_slope != 0.0f) {
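
The kernel's three passes implement the usual max-subtracted softmax without ever materializing the full score row. A Python sketch of the same scheme (the 1e-6 guard matches the inv_sum term above):

import math

def three_pass_attention(scores: list[float], values: list[float]) -> float:
    # Pass 1: running maximum, for numerical stability.
    max_score = max(scores)
    # Pass 2: sum of exponentials relative to the maximum.
    sum_exp = sum(math.exp(s - max_score) for s in scores)
    inv_sum = 1.0 / (sum_exp + 1e-6)
    # Pass 3: weighted accumulation of the values.
    return sum(math.exp(s - max_score) * inv_sum * v
               for s, v in zip(scores, values))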
src/infiniop/ops/paged_attention_prefill/info.h

-#ifndef __PAGED_ATTENTION_PREFILL_INFO_H__
-#define __PAGED_ATTENTION_PREFILL_INFO_H__
+#ifndef __INFINIOP_PAGED_ATTENTION_PREFILL_INFO_H__
+#define __INFINIOP_PAGED_ATTENTION_PREFILL_INFO_H__
 
 #include "../../../utils.h"
 #include "../../tensor.h"
@@ -25,6 +25,7 @@ public:
     size_t total_q_tokens;
     ptrdiff_t q_stride;
+    ptrdiff_t q_head_stride;
     ptrdiff_t kv_block_stride;
     ptrdiff_t kv_head_stride;
     ptrdiff_t o_stride;
@@ -35,9 +36,8 @@ public:
         infiniopTensorDescriptor_t k_cache_desc,
         infiniopTensorDescriptor_t v_cache_desc,
         infiniopTensorDescriptor_t block_tables_desc,
-        infiniopTensorDescriptor_t cache_lens_desc,
         infiniopTensorDescriptor_t seq_lens_desc,
-        infiniopTensorDescriptor_t offset_desc,
+        infiniopTensorDescriptor_t cum_seq_lens_q_desc,
         const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
         float scale) {
@@ -47,40 +47,56 @@ public:
         if (out_desc->dtype() != dtype || k_cache_desc->dtype() != dtype || v_cache_desc->dtype() != dtype) {
             return INFINI_STATUS_BAD_TENSOR_DTYPE;
         }
-        if (offset_desc->dtype() != INFINI_DTYPE_I64 || seq_lens_desc->dtype() != INFINI_DTYPE_I64) {
+        if (cum_seq_lens_q_desc->dtype() != INFINI_DTYPE_I64 || seq_lens_desc->dtype() != INFINI_DTYPE_I64) {
             return INFINI_STATUS_BAD_TENSOR_DTYPE;
         }
+        if (alibi_slopes_desc.has_value() && alibi_slopes_desc.value() != nullptr) {
+            std::cerr << "[Error] PagedAttentionPrefill: ALiBi slopes are not supported yet." << std::endl;
+            return INFINI_STATUS_BAD_PARAM;
+        }
 
-        // Q shape: [total_tokens, heads, dim] (3D)
-        auto q_shape = q_desc->shape();
-        if (q_shape.size() < 3) {
+        auto k_shape = k_cache_desc->shape();
+        auto v_shape = v_cache_desc->shape();
+        auto block_tables_shape = block_tables_desc->shape();
+        auto seq_lens_shape = seq_lens_desc->shape();
+        auto cum_seq_lens_q_shape = cum_seq_lens_q_desc->shape();
+        if (k_shape.size() != 4 || v_shape.size() != 4) {
             return INFINI_STATUS_BAD_TENSOR_SHAPE;
         }
-        size_t total_q_tokens = q_shape[0];
-        size_t num_heads = q_shape[q_shape.size() - 2];
-        size_t head_size = q_shape[q_shape.size() - 1];
-
-        if (head_size != 128) {
-            std::cerr << "[Error] PagedAttentionPrefill head_size = 128 supported, got " << head_size << std::endl;
+        if (block_tables_shape.size() != 2) {
             return INFINI_STATUS_BAD_TENSOR_SHAPE;
         }
-
-        // Get num_seqs from seq_lens
-        size_t num_seqs = seq_lens_desc->shape()[0];
+        if (seq_lens_shape.size() != 1 || cum_seq_lens_q_shape.size() != 1) {
+            return INFINI_STATUS_BAD_TENSOR_SHAPE;
+        }
+        if (cum_seq_lens_q_shape[0] != seq_lens_shape[0] + 1) {
+            return INFINI_STATUS_BAD_PARAM;
+        }
 
+        // Q shape: [total_tokens, heads, dim]
+        auto q_shape = q_desc->shape();
+        if (q_shape.size() != 3) {
+            return INFINI_STATUS_BAD_TENSOR_SHAPE;
+        }
+        size_t total_q_tokens = q_shape[0];
+        size_t num_heads = q_shape[1];
+        size_t head_size = q_shape[2];
+        if (head_size > 1024) {
+            return INFINI_STATUS_BAD_PARAM;
+        }
 
-        auto k_cache_shape = k_cache_desc->shape();
-        size_t num_kv_heads = k_cache_shape[1];
-        size_t block_size = v_cache_desc->shape()[2];
-        size_t max_num_blocks_per_seq = block_tables_desc->shape()[1];
+        size_t num_kv_heads = k_shape[1];
+        size_t block_size = k_shape[2];
+        size_t max_num_blocks_per_seq = block_tables_shape[1];
+        size_t num_seqs = seq_lens_shape[0];
 
+        // Extract strides; the packed Q of multiple requests must stay contiguous
         ptrdiff_t q_stride = q_desc->stride(0);
+        ptrdiff_t q_head_stride = q_desc->stride(1);
         ptrdiff_t kv_block_stride = k_cache_desc->stride(0);
         ptrdiff_t kv_head_stride = k_cache_desc->stride(1);
         ptrdiff_t o_stride = out_desc->stride(0);
@@ -96,6 +112,7 @@ public:
             max_num_blocks_per_seq,
             total_q_tokens,
             q_stride,
+            q_head_stride,
             kv_block_stride,
             kv_head_stride,
             o_stride});
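
The validation added here pins down the expected layouts. The same invariants condensed into one Python assertion block (a sketch for reference, not library code):

def check_prefill_shapes(q, k_cache, v_cache, block_tables, seq_lens, cum_seq_lens_q):
    total_q_tokens, num_heads, head_size = q.shape        # packed Q must be 3-D
    assert k_cache.ndim == 4 and v_cache.ndim == 4        # [blocks, kv_heads, block, dim]
    assert block_tables.ndim == 2                         # [num_seqs, max_blocks_per_seq]
    assert seq_lens.ndim == 1 and cum_seq_lens_q.ndim == 1
    assert cum_seq_lens_q.shape[0] == seq_lens.shape[0] + 1
    assert head_size <= 1024                              # one thread per head dimension
    return total_q_tokens, num_heads, head_size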
src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu

@@ -8,14 +8,12 @@
 #include "../cuda/kernel.cuh"
 #include "paged_attention_prefill_nvidia.cuh"
 
-// ==============================================================================
-// Host wrapper to launch the global kernel
-// ==============================================================================
 template <typename Tdata, typename Tcompute>
 infiniStatus_t launchPagedAttentionPrefill(
     Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
-    const int64_t *block_tables, const int64_t *cache_lens, const int64_t *seq_lens,
-    const int64_t *offset,
+    const int64_t *block_tables, const int64_t *seq_lens, const int64_t *cum_seq_lens_q,
     const float *alibi_slopes,
     const size_t num_heads, const size_t num_seqs,
@@ -24,36 +22,30 @@ infiniStatus_t launchPagedAttentionPrefill(
     const size_t max_num_blocks_per_seq,
     const size_t block_size,
     const size_t total_q_tokens,
-    const ptrdiff_t q_stride,
+    const size_t head_size,
     const ptrdiff_t kv_block_stride,
     const ptrdiff_t kv_head_stride,
-    const ptrdiff_t o_stride,
-    const size_t head_size,
+    const ptrdiff_t q_stride,
+    const ptrdiff_t q_head_stride,
     cudaStream_t stream) {
     if (total_q_tokens == 0 || num_heads == 0) {
         return INFINI_STATUS_BAD_TENSOR_SHAPE;
     }
     // 2D grid: the x axis spans all tokens, the y axis spans all heads
     dim3 grid(total_q_tokens, num_heads);
     dim3 block(head_size);
 
     op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel<Tdata, Tcompute>
         <<<grid, block, 0, stream>>>(
-            out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, alibi_slopes,
+            out, q, k_cache, v_cache, block_tables, seq_lens, cum_seq_lens_q, alibi_slopes,
             num_heads, num_kv_heads, scale, max_num_blocks_per_seq, block_size,
             kv_block_stride, kv_head_stride, q_stride, q_head_stride, head_size,
-            offset, num_seqs);
-
-    cudaError_t err = cudaGetLastError();
-    if (err != cudaSuccess) {
-        std::cerr << "CUDA Kernel Launch Failed: " << cudaGetErrorString(err) << std::endl;
-        return INFINI_STATUS_INTERNAL_ERROR;
-    }
+            num_seqs);
     return INFINI_STATUS_SUCCESS;
 }
@@ -76,16 +68,17 @@ infiniStatus_t Descriptor::create(
     infiniopTensorDescriptor_t k_cache_desc,
     infiniopTensorDescriptor_t v_cache_desc,
     infiniopTensorDescriptor_t block_tables_desc,
-    infiniopTensorDescriptor_t cache_lens_desc,
     infiniopTensorDescriptor_t seq_lens_desc,
-    infiniopTensorDescriptor_t offset_desc,
+    infiniopTensorDescriptor_t cum_seq_lens_q_desc,
     const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
     float scale) {
-    auto info = PagedAttentionPrefillInfo::create(out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, cache_lens_desc, seq_lens_desc, offset_desc, alibi_slopes_desc, scale);
+    auto info = PagedAttentionPrefillInfo::create(out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, seq_lens_desc, cum_seq_lens_q_desc, alibi_slopes_desc, scale);
     CHECK_RESULT(info);
     *desc_ptr = new Descriptor(
@@ -98,28 +91,25 @@ infiniStatus_t Descriptor::create(
 infiniStatus_t Descriptor::calculate(
     void *workspace, size_t workspace_size,
     void *out, const void *q, const void *k_cache, const void *v_cache,
-    const void *block_tables, const void *cache_lens, const void *seq_lens, const void *offset,
+    const void *block_tables, const void *seq_lens, const void *cum_seq_lens_q,
     const void *alibi_slopes, void *stream_) const {
     cudaStream_t stream = (cudaStream_t)stream_;
     if (_info.head_size > 1024) {
         return INFINI_STATUS_BAD_TENSOR_SHAPE;
     }
 
-#define LAUNCH_KERNEL(Tdata, Tcompute)                                                          \
-    launchPagedAttentionPrefill<Tdata, Tcompute>(                                               \
-        (Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache,         \
-        (const int64_t *)block_tables, (const int64_t *)cache_lens, (const int64_t *)seq_lens,  \
-        (const int64_t *)offset,                                                                \
-        (const float *)alibi_slopes,                                                            \
-        _info.num_heads, _info.num_seqs, _info.num_kv_heads,                                    \
-        _info.scale, _info.max_num_blocks_per_seq,                                              \
-        _info.block_size, _info.total_q_tokens,                                                 \
-        _info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride,            \
-        _info.head_size,                                                                        \
-        stream)
+#define LAUNCH_KERNEL(Tdata, Tcompute)                                                              \
+    launchPagedAttentionPrefill<Tdata, Tcompute>(                                                   \
+        (Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache,             \
+        (const int64_t *)block_tables, (const int64_t *)seq_lens, (const int64_t *)cum_seq_lens_q,  \
+        (const float *)alibi_slopes,                                                                \
+        _info.num_heads, _info.num_seqs, _info.num_kv_heads,                                        \
+        _info.scale, _info.max_num_blocks_per_seq,                                                  \
+        _info.block_size, _info.total_q_tokens,                                                     \
+        _info.head_size,                                                                            \
+        _info.kv_block_stride, _info.kv_head_stride,                                                \
+        _info.q_stride, _info.q_head_stride,                                                        \
+        stream)
 
     if (_info.dtype == INFINI_DTYPE_F16) {
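
The launch configuration in this wrapper is one thread block per (token, head) pair with one thread per head dimension, which is why head_size is capped at 1024 (the CUDA per-block thread limit). A sketch of the arithmetic with illustrative numbers:

# Illustrative launch-shape arithmetic for the 2D grid used above.
total_q_tokens, num_heads, head_size = 15, 32, 128

grid_dim  = (total_q_tokens, num_heads)  # blockIdx.x -> token, blockIdx.y -> head
block_dim = head_size                    # threadIdx.x -> dimension within the head

num_blocks = grid_dim[0] * grid_dim[1]   # 480 thread blocks in flight
assert block_dim <= 1024                 # hardware limit on threads per block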
src/infiniop/ops/paged_attention_prefill/operator.cc

@@ -14,9 +14,8 @@ __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
     infiniopTensorDescriptor_t k_cache_desc,
     infiniopTensorDescriptor_t v_cache_desc,
     infiniopTensorDescriptor_t block_tables_desc,
-    infiniopTensorDescriptor_t cache_lens_desc,
     infiniopTensorDescriptor_t seq_lens_desc,
-    infiniopTensorDescriptor_t offset_desc,
+    infiniopTensorDescriptor_t cum_seq_lens_q_desc,
     infiniopTensorDescriptor_t alibi_slopes_desc,
     float scale) {
@@ -27,15 +26,16 @@ __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
     return op::paged_attention_prefill::NAMESPACE::Descriptor::create(                     \
         handle,                                                                            \
         reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor **>(desc_ptr), \
-        out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, cache_lens_desc,  \
-        seq_lens_desc, offset_desc, alibi_opt, scale);
+        out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc,                   \
+        seq_lens_desc, cum_seq_lens_q_desc, alibi_opt, scale);
 
     switch (handle->device) {
 #ifdef ENABLE_NVIDIA_API
         CREATE(INFINI_DEVICE_NVIDIA, nvidia)
 #endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
-    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
 }
 
__C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
@@ -51,16 +51,18 @@ __C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
 #ifdef ENABLE_NVIDIA_API
         GET(INFINI_DEVICE_NVIDIA, nvidia)
 #endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
-    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
 }
 
 __C infiniStatus_t infiniopPagedAttentionPrefill(
     infiniopPagedAttentionPrefillDescriptor_t desc,
     void *workspace, size_t workspace_size,
     void *out, const void *q, const void *k_cache, const void *v_cache,
-    const void *block_tables, const void *cache_lens, const void *seq_lens, const void *offset,
+    const void *block_tables, const void *seq_lens, const void *cum_seq_lens_q,
     const void *alibi_slopes, void *stream) {
@@ -68,14 +70,15 @@ __C infiniStatus_t infiniopPagedAttentionPrefill(
     case CASE:                                                                                           \
         return reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor *>(desc)->calculate( \
             workspace, workspace_size, out, q, k_cache, v_cache, block_tables,                          \
-            cache_lens, seq_lens, offset, alibi_slopes, stream);
+            seq_lens, cum_seq_lens_q, alibi_slopes, stream);
 
     switch (desc->device_type) {
 #ifdef ENABLE_NVIDIA_API
         CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
 #endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
-    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
 }
 
__C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
@@ -90,6 +93,7 @@ __C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
 #ifdef ENABLE_NVIDIA_API
         DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
 #endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
-    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
 }
src/infiniop/ops/paged_attention_prefill/paged_attention_prefill.h

@@ -4,53 +4,53 @@
 #include "../../operator.h"
 #include "info.h"
 
 #define DESCRIPTOR(NAMESPACE)                                                   \
                                                                                 \
     namespace op::paged_attention_prefill::NAMESPACE {                          \
     class Descriptor final : public InfiniopDescriptor {                        \
         struct Opaque;                                                          \
         Opaque *_opaque;                                                        \
         PagedAttentionPrefillInfo _info;                                        \
         size_t _workspace_size;                                                 \
                                                                                 \
         Descriptor(                                                             \
             Opaque *opaque,                                                     \
             PagedAttentionPrefillInfo info,                                     \
             size_t workspace_size,                                              \
             infiniDevice_t device_type,                                         \
             int device_id)                                                      \
             : InfiniopDescriptor{device_type, device_id},                       \
               _opaque(opaque),                                                  \
               _info(info),                                                      \
               _workspace_size(workspace_size) {}                                \
                                                                                 \
     public:                                                                     \
         ~Descriptor();                                                          \
                                                                                 \
         size_t workspaceSize() const { return _workspace_size; }                \
                                                                                 \
         static infiniStatus_t create(                                           \
             infiniopHandle_t handle,                                            \
             Descriptor **desc_ptr,                                              \
             infiniopTensorDescriptor_t out_desc,                                \
             infiniopTensorDescriptor_t q_desc,                                  \
             infiniopTensorDescriptor_t k_cache_desc,                            \
             infiniopTensorDescriptor_t v_cache_desc,                            \
             infiniopTensorDescriptor_t block_tables_desc,                       \
-            infiniopTensorDescriptor_t cache_lens_desc,                         \
             infiniopTensorDescriptor_t seq_lens_desc,                           \
-            infiniopTensorDescriptor_t offset_desc,                             \
+            infiniopTensorDescriptor_t cum_seq_lens_q_desc,                     \
             const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc, \
             float scale);                                                       \
                                                                                 \
         infiniStatus_t calculate(                                               \
             void *workspace, size_t workspace_size,                             \
             void *out, const void *q, const void *k_cache, const void *v_cache, \
-            const void *block_tables, const void *cache_lens,                   \
-            const void *seq_lens, const void *offset,                           \
+            const void *block_tables,                                           \
+            const void *seq_lens,                                               \
+            const void *cum_seq_lens_q,                                         \
             const void *alibi_slopes,                                           \
             void *stream) const;                                                \
     };                                                                          \
     }
#endif // PAGED_ATTENTION_PREFILL_H
test/infinicore/ops/paged_attention_prefill.py (new file, mode 100644)

import os
import sys

import torch

import infinicore

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from framework import (
    BaseOperatorTest,
    GenericTestRunner,
    TensorInitializer,
    TensorSpec,
    TestCase,
)

# Test Cases: (num_seqs, num_heads, num_kv_heads, head_size, block_size, max_step_len, num_rounds)
_TEST_CASES_DATA = [
    (1, 1, 1, 128, 8, 16, 1),
    (1, 4, 4, 128, 8, 16, 4),
    (2, 8, 8, 128, 16, 32, 2),
    (4, 16, 16, 128, 8, 64, 3),
    (8, 64, 64, 128, 8, 16, 5),
]

_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 1e-2, "rtol": 1e-2},
    infinicore.float32: {"atol": 1e-4, "rtol": 1e-4},  # tuned tolerance for float32
    infinicore.bfloat16: {"atol": 2e-2, "rtol": 2e-2},
}

_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]


class SimpleCacheManager:
    def __init__(self, num_blocks, block_size):
        self.num_blocks = num_blocks
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))
        self.request_to_blocks = {}
        self.request_to_len = {}

    def allocate_slots(self, request_id, num_new_tokens):
        if request_id not in self.request_to_len:
            self.request_to_len[request_id] = 0
            self.request_to_blocks[request_id] = []
        start_pos = self.request_to_len[request_id]
        new_total_len = start_pos + num_new_tokens
        needed_blocks = (new_total_len + self.block_size - 1) // self.block_size
        added_blocks = needed_blocks - len(self.request_to_blocks[request_id])
        for _ in range(added_blocks):
            self.request_to_blocks[request_id].append(self.free_blocks.pop(0))
        self.request_to_len[request_id] = new_total_len
        return self.request_to_blocks[request_id], new_total_len


def parse_test_cases():
    test_cases = []
    for (
        num_seqs,
        num_heads,
        num_kv_heads,
        head_size,
        block_size,
        max_step_len,
        num_rounds,
    ) in _TEST_CASES_DATA:
        scale = head_size**-0.5
        num_blocks = 8192
        manager = SimpleCacheManager(num_blocks, block_size)
        kv_lens = torch.zeros(num_seqs, dtype=torch.int64)
        persistent_k = torch.zeros((num_blocks, num_kv_heads, block_size, head_size))
        persistent_v = torch.zeros((num_blocks, num_kv_heads, block_size, head_size))

        for r in range(num_rounds):
            q_lens = torch.randint(1, max_step_len + 1, (num_seqs,), dtype=torch.int64)
            kv_lens = kv_lens + q_lens
            total_q_tokens = q_lens.sum().item()
            cum_seqlens_q = torch.zeros(num_seqs + 1, dtype=torch.int64)
            cum_seqlens_q[1:] = torch.cumsum(q_lens, dim=0)
            query_base = torch.randn((total_q_tokens, num_heads, head_size))

            round_block_tables_list = []
            for i in range(num_seqs):
                p_blocks, total_len = manager.allocate_slots(i, q_lens[i].item())
                round_block_tables_list.append(p_blocks)
                h_len = kv_lens[i].item() - q_lens[i].item()
                for t in range(q_lens[i].item()):
                    logical_pos = h_len + t
                    b_id = p_blocks[logical_pos // block_size]
                    off = logical_pos % block_size
                    persistent_k[b_id, :, off, :] = torch.randn(num_kv_heads, head_size)
                    persistent_v[b_id, :, off, :] = torch.randn(num_kv_heads, head_size)

            max_blks = max(len(t) for t in round_block_tables_list)
            padded_tables = torch.tensor(
                [t + [0] * (max_blks - len(t)) for t in round_block_tables_list]
            )

            for dtype in _TENSOR_DTYPES:
                tolerance = _TOLERANCE_MAP.get(dtype)
                test_cases.append(
                    TestCase(
                        inputs=[
                            TensorSpec.from_tensor(query_base.shape, init_mode=TensorInitializer.MANUAL, set_tensor=query_base.clone(), dtype=dtype),
                            TensorSpec.from_tensor(persistent_k.shape, init_mode=TensorInitializer.MANUAL, set_tensor=persistent_k.clone(), dtype=dtype),
                            TensorSpec.from_tensor(persistent_v.shape, init_mode=TensorInitializer.MANUAL, set_tensor=persistent_v.clone(), dtype=dtype),
                            TensorSpec.from_tensor(padded_tables.shape, init_mode=TensorInitializer.MANUAL, set_tensor=padded_tables.clone(), dtype=infinicore.int64),
                            TensorSpec.from_tensor(kv_lens.shape, init_mode=TensorInitializer.MANUAL, set_tensor=kv_lens.clone(), dtype=infinicore.int64),
                            TensorSpec.from_tensor(cum_seqlens_q.shape, init_mode=TensorInitializer.MANUAL, set_tensor=cum_seqlens_q.clone(), dtype=infinicore.int64),
                        ],
                        kwargs={"scale": scale},
                        tolerance=tolerance,
                        description=f"PagedAttentionPrefill_Round_{r}_{str(dtype).split('.')[-1]}",
                    )
                )
    return test_cases


def ref_paged_attention_multi_turn(query, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q, scale):
    output = torch.zeros_like(query)
    num_seqs = len(kv_lens)
    block_size = k_cache.shape[2]
    for i in range(num_seqs):
        q_start, q_end = cum_seqlens_q[i].item(), cum_seqlens_q[i + 1].item()
        cur_q = query[q_start:q_end]
        q_len = q_end - q_start
        h_len = kv_lens[i].item() - q_len
        total_len = h_len + q_len
        table = block_tables[i]
        keys, values = [], []
        for j in range(total_len):
            b_id = table[j // block_size].item()
            off = j % block_size
            keys.append(k_cache[b_id, :, off, :])
            values.append(v_cache[b_id, :, off, :])
        K = torch.stack(keys, dim=0)
        V = torch.stack(values, dim=0)
        scores = torch.einsum("qhd,khd->hqk", cur_q.float(), K.float()) * scale
        mask = torch.full((q_len, total_len), float("-inf"), device=query.device)
        for t in range(q_len):
            mask[t, : h_len + t + 1] = 0.0
        attn = torch.softmax(scores + mask.unsqueeze(0), dim=-1).to(query.dtype)
        output[q_start:q_end] = torch.einsum("hqk,khd->qhd", attn, V)
    return output


class OpTest(BaseOperatorTest):
    def __init__(self):
        super().__init__("PagedAttentionPrefill")

    def get_test_cases(self):
        return parse_test_cases()

    def torch_operator(self, query, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q, scale=1.0):
        return ref_paged_attention_multi_turn(query, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q, scale)

    def infinicore_operator(self, query, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q, scale=1.0):
        out = infinicore.paged_attention_prefill(
            query, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q,
            alibi_slopes=None, scale=scale,
        )
        infinicore.sync_stream()
        return out


def main():
    """Main entry point"""
    runner = GenericTestRunner(OpTest)
    runner.run_and_exit()


if __name__ == "__main__":
    main()
test/infiniop/libinfiniop/op_register.py

@@ -1115,7 +1115,6 @@ def paged_attention_prefill_(lib):
         infiniopTensorDescriptor_t,
         infiniopTensorDescriptor_t,
         infiniopTensorDescriptor_t,
-        infiniopTensorDescriptor_t,
         c_float,
     ]
@@ -1139,7 +1138,6 @@ def paged_attention_prefill_(lib):
         c_void_p,
         c_void_p,
         c_void_p,
-        c_void_p,
     ]
     lib.infiniopDestroyPagedAttentionPrefillDescriptor.restype = c_int32
test/infiniop/paged_attention_prefill.py

-import torch
 import ctypes
 from ctypes import c_uint64
+
+import torch
 from libinfiniop import (
     LIBINFINIOP,
+    InfiniDeviceNames,
+    InfiniDtype,
+    InfiniDtypeNames,
     TestTensor,
+    TestWorkspace,
     check_error,
+    debug,
     get_args,
-    debug,
     get_test_devices,
     get_tolerance,
-    profile_operation,
-    InfiniDtype,
-    InfiniDtypeNames,
-    InfiniDeviceNames,
     infiniopOperatorDescriptor_t,
-    TestWorkspace,
+    profile_operation,
     test_operator,
 )

# ==============================================================================
@@ -74,14 +75,15 @@ class SimpleCacheManager:
 def ref_paged_attention_multi_turn(
-    query_new, k_cache, v_cache, block_tables, seq_lens, new_lens, offset, scale
+    query_new, k_cache, v_cache, block_tables, seq_lens, cum_seq_lens_q, scale
 ):
     block_size = k_cache.shape[2]
     outputs = torch.zeros_like(query_new)
-    for i in range(len(offset) - 1):
+    num_seqs = len(cum_seq_lens_q) - 1
+    for i in range(num_seqs):
+        num_new = cum_seq_lens_q[i + 1].item() - cum_seq_lens_q[i].item()
         total_len = seq_lens[i].item()
-        num_new = new_lens[i].item()
-        history_len = total_len - num_new
+        cache_len = seq_lens[i].item() - num_new
         table = block_tables[i]
         keys_all, values_all = [], []
@@ -93,19 +95,19 @@ def ref_paged_attention_multi_turn(
         K = torch.stack(keys_all, dim=0)
         V = torch.stack(values_all, dim=0)
-        Q = query_new[offset[i] : offset[i] + num_new, :, :]
+        Q = query_new[cum_seq_lens_q[i] : cum_seq_lens_q[i + 1], :, :]
         scores = torch.einsum("qhd,khd->hqk", Q, K).float() * scale
         mask = torch.full((num_new, total_len), float("-inf"), device=Q.device)
         for q_idx in range(num_new):
-            mask[q_idx, : history_len + q_idx + 1] = 0.0
+            mask[q_idx, : cache_len + q_idx + 1] = 0.0
         scores = scores + mask.unsqueeze(0)
         attn_weights = torch.softmax(scores, dim=-1).to(Q.dtype)
         out = torch.einsum("hqk,khd->qhd", attn_weights, V)
-        outputs[offset[i] : offset[i] + num_new, :, :] = out
+        outputs[cum_seq_lens_q[i] : cum_seq_lens_q[i + 1], :, :] = out
     return outputs
@@ -147,43 +149,43 @@ def test(
     # Multi-turn testing loop
     for r in range(num_rounds):
         # Prepare dynamic inputs for this round
-        seq_lens_cpu = torch.randint(1, max_step_len + 1, (num_seqs,), dtype=torch.int64)
-        q_total_tokens = seq_lens_cpu.sum().item()
+        query_lens_cpu = torch.randint(1, max_step_len + 1, (num_seqs,), dtype=torch.int64)
+        q_total_tokens = query_lens_cpu.sum().item()
         q_packed_tensors = torch.zeros(q_total_tokens, num_heads, head_size)
-        cache_lens_list = []
+        seq_lens_list = []
         all_block_tables = []
-        offset_list = []
-        cur_offset = 0
+        cum_seq_lens_q_list = []
+        cum_q_lens = 0
         for i in range(num_seqs):
-            offset_list.append(cur_offset)
-            cur_new_len = seq_lens_cpu[i].item()
-            table, cache_len = manager.allocate_slots(i, cur_new_len)
-            cache_lens_list.append(cache_len)
+            cum_seq_lens_q_list.append(cum_q_lens)
+            cur_q_len = query_lens_cpu[i].item()
+            table, total_len = manager.allocate_slots(i, cur_q_len)
+            cur_seq_lens = total_len - cur_q_len
+            seq_lens_list.append(total_len)
             all_block_tables.append(table)

             # Simulated KV insertion
-            k_new = torch.randn(cur_new_len, num_kv_heads, head_size)
-            v_new = torch.randn(cur_new_len, num_kv_heads, head_size)
-            q_val = torch.randn(cur_new_len, num_heads, head_size)
-            q_packed_tensors[cur_offset : cur_offset + cur_new_len] = q_val
-            cur_offset = cur_offset + cur_new_len
-            history_len = cache_len - cur_new_len
-            for t in range(cur_new_len):
-                logical_pos = history_len + t
+            k_new = torch.randn(cur_q_len, num_kv_heads, head_size)
+            v_new = torch.randn(cur_q_len, num_kv_heads, head_size)
+            q_val = torch.randn(cur_q_len, num_heads, head_size)
+            q_packed_tensors[cum_q_lens : cum_q_lens + cur_q_len] = q_val
+            cum_q_lens = cum_q_lens + cur_q_len
+            for t in range(cur_q_len):
+                logical_pos = cur_seq_lens + t
                 b_id = table[logical_pos // block_size]
                 off = logical_pos % block_size
                 k_cache.torch_tensor()[b_id, :, off, :] = k_new[t]
                 v_cache.torch_tensor()[b_id, :, off, :] = v_new[t]
-        offset_list.append(cur_offset)
+        cum_seq_lens_q_list.append(cum_q_lens)
         k_cache.actual_tensor().copy_(k_cache._torch_tensor)
         v_cache.actual_tensor().copy_(v_cache._torch_tensor)
@@ -193,13 +195,14 @@ def test(
     out = TestTensor.from_torch(q_packed_tensors, dtype, device)
     out.actual_tensor().zero_()
-    cache_lens = TestTensor.from_torch(torch.tensor(cache_lens_list, dtype=torch.int64), InfiniDtype.I64, device)
-    seq_lens = TestTensor.from_torch(seq_lens_cpu, InfiniDtype.I64, device)
-    offset = TestTensor.from_torch(torch.tensor(offset_list, dtype=torch.int64), InfiniDtype.I64, device)
+    seq_lens = TestTensor.from_torch(torch.tensor(seq_lens_list, dtype=torch.int64), InfiniDtype.I64, device)
+    cum_seq_lens_q = TestTensor.from_torch(torch.tensor(cum_seq_lens_q_list, dtype=torch.int64), InfiniDtype.I64, device)

     max_blocks = max(len(t) for t in all_block_tables)
@@ -215,9 +218,8 @@ def test(
         k_cache.torch_tensor(),
         v_cache.torch_tensor(),
         block_tables.torch_tensor(),
-        cache_lens.torch_tensor(),
         seq_lens.torch_tensor(),
-        offset.torch_tensor(),
+        cum_seq_lens_q.torch_tensor(),
         scale,
     )
@@ -234,10 +236,9 @@ def test(
             k_cache.descriptor,
             v_cache.descriptor,
             block_tables.descriptor,
-            cache_lens.descriptor,
             seq_lens.descriptor,
-            offset.descriptor,
-            None,  # alibi_slopes_desc
+            cum_seq_lens_q.descriptor,
+            None,  # alibi_slopes_desc
             scale,
         )
    )
@@ -261,9 +262,8 @@ def test(
         k_cache.data(),
         v_cache.data(),
         block_tables.data(),
-        cache_lens.data(),
         seq_lens.data(),
-        offset.data(),
+        cum_seq_lens_q.data(),
         None,
         None,
    )