Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
0a2839a2
"vscode:/vscode.git/clone" did not exist on "2d85781ba81e8abd54a71d7ad8b227d28e7cd73e"
Commit
0a2839a2
authored
Jan 07, 2026
by
zhushuang
Browse files
issue/867 - feat: adjust paged_attention_prefill interface naming
parent
3b5afffe
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
601 additions
and
231 deletions
+601
-231
include/infinicore/ops/paged_attention_prefill.hpp
include/infinicore/ops/paged_attention_prefill.hpp
+38
-4
include/infiniop/ops/paged_attention_prefill.h
include/infiniop/ops/paged_attention_prefill.h
+18
-14
python/infinicore/ops/paged_attention_prefill.py
python/infinicore/ops/paged_attention_prefill.py
+10
-11
src/infinicore/ops/paged_attention_prefill/paged_attention_prefill.cc
...re/ops/paged_attention_prefill/paged_attention_prefill.cc
+19
-7
src/infinicore/ops/paged_attention_prefill/paged_attention_prefill_infiniop.cc
...ged_attention_prefill/paged_attention_prefill_infiniop.cc
+22
-6
src/infinicore/pybind11/ops.hpp
src/infinicore/pybind11/ops.hpp
+2
-0
src/infinicore/pybind11/ops/paged_attention_prefill.hpp
src/infinicore/pybind11/ops/paged_attention_prefill.hpp
+68
-0
src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh
src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh
+21
-31
src/infiniop/ops/paged_attention_prefill/info.h
src/infiniop/ops/paged_attention_prefill/info.h
+36
-22
src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu
...ttention_prefill/nvidia/paged_attention_prefill_nvidia.cu
+27
-41
src/infiniop/ops/paged_attention_prefill/operator.cc
src/infiniop/ops/paged_attention_prefill/operator.cc
+8
-8
src/infiniop/ops/paged_attention_prefill/paged_attention_prefill.h
...iop/ops/paged_attention_prefill/paged_attention_prefill.h
+47
-47
test/infinicore/ops/paged_attention_prefill.py
test/infinicore/ops/paged_attention_prefill.py
+248
-0
test/infiniop/libinfiniop/op_register.py
test/infiniop/libinfiniop/op_register.py
+0
-2
test/infiniop/paged_attention_prefill.py
test/infiniop/paged_attention_prefill.py
+37
-38
No files found.
include/infinicore/ops/paged_attention_prefill.hpp
View file @
0a2839a2
...
@@ -8,11 +8,45 @@ namespace infinicore::op {
...
@@ -8,11 +8,45 @@ namespace infinicore::op {
class
PagedAttentionPrefill
{
class
PagedAttentionPrefill
{
public:
public:
using
schema
=
void
(
*
)(
Tensor
,
Tensor
,
Tensor
,
Tensor
,
Tensor
,
Tensor
,
Tensor
,
Tensor
,
std
::
optional
<
Tensor
>
,
float
);
/**
static
void
execute
(
Tensor
out
,
Tensor
q
,
Tensor
k_cache
,
Tensor
v_cache
,
Tensor
block_tables
,
Tensor
cache_lens
,
Tensor
seq_lens
,
Tensor
seq_offsets
,
std
::
optional
<
Tensor
>
alibi_slopes
,
float
);
* @brief PagedAttentionPrefill operator signature
* * Argument order:
* 1. out: Output tensor (Packed format)
* 2. q: Current Query tensor (Packed format)
* 3. k_cache: Physical Key cache (Paged format)
* 4. v_cache: Physical Value cache (Paged format)
* 5. block_tables: Mapping table from logical blocks to physical blocks
* 6. history_lens: Historical KV lengths (existing length of each sequence in cache)
* 7. cu_seqlens_q: Cumulative sequence lengths of Query (prefix sum for variable-length batch)
* 8. alibi_slopes: ALiBi bias slopes (optional)
* 9. scale: Scaling factor (typically 1/sqrt(head_size))
*/
using
schema
=
void
(
*
)(
Tensor
,
Tensor
,
Tensor
,
Tensor
,
Tensor
,
Tensor
,
Tensor
,
std
::
optional
<
Tensor
>
,
float
);
static
void
execute
(
Tensor
out
,
Tensor
q
,
Tensor
k_cache
,
Tensor
v_cache
,
Tensor
block_tables
,
Tensor
history_lens
,
Tensor
cu_seqlens_q
,
std
::
optional
<
Tensor
>
alibi_slopes
,
float
scale
);
static
common
::
OpDispatcher
<
schema
>
&
dispatcher
();
static
common
::
OpDispatcher
<
schema
>
&
dispatcher
();
};
};
Tensor
paged_attention_prefill
(
Tensor
q
,
Tensor
k_cache
,
Tensor
v_cache
,
Tensor
block_tables
,
Tensor
cache_lens
,
Tensor
seq_lens
,
Tensor
seq_offsets
,
std
::
optional
<
Tensor
>
alibi_slopes
,
float
scale
);
Tensor
paged_attention_prefill
(
Tensor
q
,
void
paged_attention_prefill_
(
Tensor
out
,
Tensor
q
,
Tensor
k_cache
,
Tensor
v_cache
,
Tensor
block_tables
,
Tensor
cache_lens
,
Tensor
seq_lens
,
Tensor
seq_offsets
,
std
::
optional
<
Tensor
>
alibi_slopes
,
float
scale
);
Tensor
k_cache
,
Tensor
v_cache
,
Tensor
block_tables
,
Tensor
history_lens
,
Tensor
cu_seqlens_q
,
std
::
optional
<
Tensor
>
alibi_slopes
,
float
scale
);
void
paged_attention_prefill_
(
Tensor
out
,
Tensor
q
,
Tensor
k_cache
,
Tensor
v_cache
,
Tensor
block_tables
,
Tensor
history_lens
,
Tensor
cu_seqlens_q
,
std
::
optional
<
Tensor
>
alibi_slopes
,
float
scale
);
}
// namespace infinicore::op
}
// namespace infinicore::op
include/infiniop/ops/paged_attention_prefill.h
View file @
0a2839a2
...
@@ -11,15 +11,22 @@ typedef struct InfiniopDescriptor *infiniopPagedAttentionPrefillDescriptor_t;
...
@@ -11,15 +11,22 @@ typedef struct InfiniopDescriptor *infiniopPagedAttentionPrefillDescriptor_t;
* @param handle The handle to the InfiniOP library context.
* @param handle The handle to the InfiniOP library context.
* @param desc_ptr A pointer to store the created descriptor.
* @param desc_ptr A pointer to store the created descriptor.
* @param out_desc Descriptor for the output tensor.
* @param out_desc Descriptor for the output tensor.
* Shape: [total_q_tokens, num_heads, head_size]
* @param q_desc Descriptor for the query tensor (packed/flattened).
* @param q_desc Descriptor for the query tensor (packed/flattened).
* Shape: [total_q_tokens, num_heads, head_size]
* @param k_cache_desc Descriptor for the global physical key cache.
* @param k_cache_desc Descriptor for the global physical key cache.
* Shape: [max_num_blocks, num_kv_heads, block_size, head_size]
* @param v_cache_desc Descriptor for the global physical value cache.
* @param v_cache_desc Descriptor for the global physical value cache.
* Shape: [max_num_blocks, num_kv_heads, block_size, head_size]
* @param block_tables_desc Descriptor for the block tables mapping logic to physical blocks.
* @param block_tables_desc Descriptor for the block tables mapping logic to physical blocks.
* @param cache_lens_desc Descriptor for the total sequence lengths (history + current).
* Shape: [batch_size, max_blocks_per_seq]
* @param seq_lens_desc Descriptor for the current prefill sequence lengths.
* @param history_lens_desc Descriptor for the KV history lengths of each sequence.
* @param offset_desc Descriptor for the start position of each sequence in the packed Q tensor.
* Shape: [batch_size]
* @param cum_seq_lens_q_desc Descriptor for the cumulative start position (prefix sum) of each Q sequence.
* Shape: [batch_size + 1]
* @param alibi_slopes_desc Optional descriptor for the ALiBi slopes tensor. Can be NULL.
* @param alibi_slopes_desc Optional descriptor for the ALiBi slopes tensor. Can be NULL.
* @param scale The attention scaling factor.
* Shape: [num_heads]
* @param scale The attention scaling factor (typically 1.0 / sqrt(head_size)).
* @return infiniStatus_t Status code of the operation.
* @return infiniStatus_t Status code of the operation.
*/
*/
__C
__export
infiniStatus_t
infiniopCreatePagedAttentionPrefillDescriptor
(
__C
__export
infiniStatus_t
infiniopCreatePagedAttentionPrefillDescriptor
(
...
@@ -30,9 +37,8 @@ __C __export infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
...
@@ -30,9 +37,8 @@ __C __export infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
infiniopTensorDescriptor_t
k_cache_desc
,
infiniopTensorDescriptor_t
k_cache_desc
,
infiniopTensorDescriptor_t
v_cache_desc
,
infiniopTensorDescriptor_t
v_cache_desc
,
infiniopTensorDescriptor_t
block_tables_desc
,
infiniopTensorDescriptor_t
block_tables_desc
,
infiniopTensorDescriptor_t
cache_lens_desc
,
infiniopTensorDescriptor_t
history_lens_desc
,
infiniopTensorDescriptor_t
seq_lens_desc
,
infiniopTensorDescriptor_t
cum_seq_lens_q_desc
,
infiniopTensorDescriptor_t
offset_desc
,
infiniopTensorDescriptor_t
alibi_slopes_desc
,
infiniopTensorDescriptor_t
alibi_slopes_desc
,
float
scale
);
float
scale
);
...
@@ -52,11 +58,10 @@ __C __export infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
...
@@ -52,11 +58,10 @@ __C __export infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
* @param k_cache Pointer to the global key cache data.
* @param k_cache Pointer to the global key cache data.
* @param v_cache Pointer to the global value cache data.
* @param v_cache Pointer to the global value cache data.
* @param block_tables Pointer to the block tables data.
* @param block_tables Pointer to the block tables data.
* @param cache_lens Pointer to the total sequence lengths data.
* @param history_lens Pointer to the KV history lengths data.
* @param seq_lens Pointer to the current prefill sequence lengths data.
* @param cum_seq_lens_q Pointer to the Q cumulative sequence lengths data (prefix sum).
* @param offset Pointer to the sequence start offsets data.
* @param alibi_slopes Pointer to the ALiBi slopes data. Can be NULL.
* @param alibi_slopes Pointer to the ALiBi slopes data. Can be NULL.
* @param stream The
CUDA/
device stream for the operation.
* @param stream The device stream
(e.g., cudaStream_t)
for the operation.
* @return infiniStatus_t Status code of the operation.
* @return infiniStatus_t Status code of the operation.
*/
*/
__C
__export
infiniStatus_t
infiniopPagedAttentionPrefill
(
__C
__export
infiniStatus_t
infiniopPagedAttentionPrefill
(
...
@@ -68,9 +73,8 @@ __C __export infiniStatus_t infiniopPagedAttentionPrefill(
...
@@ -68,9 +73,8 @@ __C __export infiniStatus_t infiniopPagedAttentionPrefill(
const
void
*
k_cache
,
const
void
*
k_cache
,
const
void
*
v_cache
,
const
void
*
v_cache
,
const
void
*
block_tables
,
const
void
*
block_tables
,
const
void
*
cache_lens
,
const
void
*
history_lens
,
const
void
*
seq_lens
,
const
void
*
cum_seq_lens_q
,
const
void
*
offset
,
const
void
*
alibi_slopes
,
const
void
*
alibi_slopes
,
void
*
stream
);
void
*
stream
);
...
...
python/infinicore/ops/paged_attention_prefill.py
View file @
0a2839a2
...
@@ -7,14 +7,15 @@ def paged_attention_prefill(
...
@@ -7,14 +7,15 @@ def paged_attention_prefill(
k_cache
:
Tensor
,
k_cache
:
Tensor
,
v_cache
:
Tensor
,
v_cache
:
Tensor
,
block_tables
:
Tensor
,
block_tables
:
Tensor
,
cache_lens
:
Tensor
,
history_lens
:
Tensor
,
seq_lens
:
Tensor
,
cu_seqlens_q
:
Tensor
,
seq_offsets
:
Tensor
,
alibi_slopes
:
Tensor
|
None
=
None
,
alibi_slopes
:
Tensor
|
None
=
None
,
scale
:
float
=
1.0
,
scale
:
float
=
1.0
,
*
,
*
,
out
:
Tensor
|
None
=
None
,
out
:
Tensor
|
None
=
None
,
):
):
alibi_ptr
=
alibi_slopes
.
_underlying
if
alibi_slopes
is
not
None
else
None
if
out
is
None
:
if
out
is
None
:
return
Tensor
(
return
Tensor
(
_infinicore
.
paged_attention_prefill
(
_infinicore
.
paged_attention_prefill
(
...
@@ -22,10 +23,9 @@ def paged_attention_prefill(
...
@@ -22,10 +23,9 @@ def paged_attention_prefill(
k_cache
.
_underlying
,
k_cache
.
_underlying
,
v_cache
.
_underlying
,
v_cache
.
_underlying
,
block_tables
.
_underlying
,
block_tables
.
_underlying
,
cache_lens
.
_underlying
,
history_lens
.
_underlying
,
seq_lens
.
_underlying
,
cu_seqlens_q
.
_underlying
,
seq_offsets
.
_underlying
,
alibi_ptr
,
alibi_slopes
.
_underlying
if
alibi_slopes
is
not
None
else
None
,
scale
,
scale
,
)
)
)
)
...
@@ -36,10 +36,9 @@ def paged_attention_prefill(
...
@@ -36,10 +36,9 @@ def paged_attention_prefill(
k_cache
.
_underlying
,
k_cache
.
_underlying
,
v_cache
.
_underlying
,
v_cache
.
_underlying
,
block_tables
.
_underlying
,
block_tables
.
_underlying
,
cache_lens
.
_underlying
,
history_lens
.
_underlying
,
seq_lens
.
_underlying
,
cu_seqlens_q
.
_underlying
,
seq_offsets
.
_underlying
,
alibi_ptr
,
alibi_slopes
.
_underlying
if
alibi_slopes
is
not
None
else
None
,
scale
,
scale
,
)
)
...
...
src/infinicore/ops/paged_attention_prefill/paged_attention_prefill.cc
View file @
0a2839a2
...
@@ -9,20 +9,32 @@ common::OpDispatcher<PagedAttentionPrefill::schema> &PagedAttentionPrefill::disp
...
@@ -9,20 +9,32 @@ common::OpDispatcher<PagedAttentionPrefill::schema> &PagedAttentionPrefill::disp
return
dispatcher_
;
return
dispatcher_
;
};
};
void
PagedAttentionPrefill
::
execute
(
Tensor
out
,
Tensor
q
,
Tensor
k_cache
,
Tensor
v_cache
,
Tensor
block_tables
,
Tensor
cache_lens
,
Tensor
seq_lens
,
Tensor
seq_offsets
,
std
::
optional
<
Tensor
>
alibi_slopes
,
float
scale
)
{
void
PagedAttentionPrefill
::
execute
(
Tensor
out
,
Tensor
q
,
Tensor
k_cache
,
Tensor
v_cache
,
INFINICORE_ASSERT_TENSORS_SAME_DEVICE
(
out
,
q
,
k_cache
,
v_cache
,
block_tables
,
cache_lens
);
Tensor
block_tables
,
Tensor
history_lens
,
Tensor
cu_seqlens_q
,
std
::
optional
<
Tensor
>
alibi_slopes
,
float
scale
)
{
INFINICORE_ASSERT_TENSORS_SAME_DEVICE
(
out
,
q
,
k_cache
,
v_cache
,
block_tables
,
history_lens
,
cu_seqlens_q
);
infinicore
::
context
::
setDevice
(
out
->
device
());
infinicore
::
context
::
setDevice
(
out
->
device
());
dispatcher
().
lookup
(
out
->
device
().
getType
())(
out
,
q
,
k_cache
,
v_cache
,
block_tables
,
cache_lens
,
seq_lens
,
seq_offsets
,
alibi_slopes
,
scale
);
dispatcher
().
lookup
(
out
->
device
().
getType
())(
out
,
q
,
k_cache
,
v_cache
,
block_tables
,
history_lens
,
cu_seqlens_q
,
alibi_slopes
,
scale
);
}
}
Tensor
paged_attention_prefill
(
Tensor
q
,
Tensor
k_cache
,
Tensor
v_cache
,
Tensor
block_tables
,
Tensor
cache_lens
,
Tensor
seq_lens
,
Tensor
seq_offsets
,
std
::
optional
<
Tensor
>
alibi_slopes
,
float
scale
)
{
Tensor
paged_attention_prefill
(
Tensor
q
,
Tensor
k_cache
,
Tensor
v_cache
,
Tensor
block_tables
,
Tensor
history_lens
,
Tensor
cu_seqlens_q
,
std
::
optional
<
Tensor
>
alibi_slopes
,
float
scale
)
{
auto
out
=
Tensor
::
empty
(
q
->
shape
(),
q
->
dtype
(),
q
->
device
());
auto
out
=
Tensor
::
empty
(
q
->
shape
(),
q
->
dtype
(),
q
->
device
());
paged_attention_prefill_
(
out
,
q
,
k_cache
,
v_cache
,
block_tables
,
cache
_lens
,
seq
_
lens
,
seq_offsets
,
alibi_slopes
,
scale
);
paged_attention_prefill_
(
out
,
q
,
k_cache
,
v_cache
,
block_tables
,
history
_lens
,
cu_
seqlens
_q
,
alibi_slopes
,
scale
);
return
out
;
return
out
;
}
}
void
paged_attention_prefill_
(
Tensor
out
,
Tensor
q
,
Tensor
k_cache
,
Tensor
v_cache
,
Tensor
block_tables
,
Tensor
cache_lens
,
Tensor
seq_lens
,
Tensor
seq_offsets
,
std
::
optional
<
Tensor
>
alibi_slopes
,
float
scale
)
{
void
paged_attention_prefill_
(
Tensor
out
,
Tensor
q
,
Tensor
k_cache
,
Tensor
v_cache
,
PagedAttentionPrefill
::
execute
(
out
,
q
,
k_cache
,
v_cache
,
block_tables
,
cache_lens
,
seq_lens
,
seq_offsets
,
alibi_slopes
,
scale
);
Tensor
block_tables
,
Tensor
history_lens
,
Tensor
cu_seqlens_q
,
std
::
optional
<
Tensor
>
alibi_slopes
,
float
scale
)
{
PagedAttentionPrefill
::
execute
(
out
,
q
,
k_cache
,
v_cache
,
block_tables
,
history_lens
,
cu_seqlens_q
,
alibi_slopes
,
scale
);
}
}
}
// namespace infinicore::op
}
// namespace infinicore::op
src/infinicore/ops/paged_attention_prefill/paged_attention_prefill_infiniop.cc
View file @
0a2839a2
...
@@ -15,8 +15,11 @@ thread_local common::OpCache<size_t, infiniopPagedAttentionPrefillDescriptor_t>
...
@@ -15,8 +15,11 @@ thread_local common::OpCache<size_t, infiniopPagedAttentionPrefillDescriptor_t>
}
}
});
});
void
calculate
(
Tensor
out
,
Tensor
q
,
Tensor
k_cache
,
Tensor
v_cache
,
Tensor
block_tables
,
Tensor
cache_lens
,
Tensor
seq_lens
,
Tensor
seq_offsets
,
std
::
optional
<
Tensor
>
alibi_slopes
,
float
scale
)
{
void
calculate
(
Tensor
out
,
Tensor
q
,
Tensor
k_cache
,
Tensor
v_cache
,
size_t
seed
=
hash_combine
(
out
,
q
,
k_cache
,
v_cache
,
block_tables
,
cache_lens
,
seq_lens
,
seq_offsets
,
alibi_slopes
,
scale
);
Tensor
block_tables
,
Tensor
history_lens
,
Tensor
cu_seqlens_q
,
std
::
optional
<
Tensor
>
alibi_slopes
,
float
scale
)
{
size_t
seed
=
hash_combine
(
out
,
q
,
k_cache
,
v_cache
,
block_tables
,
history_lens
,
cu_seqlens_q
,
alibi_slopes
,
scale
);
auto
device
=
context
::
getDevice
();
auto
device
=
context
::
getDevice
();
auto
&
cache
=
caches
.
getCache
(
device
);
auto
&
cache
=
caches
.
getCache
(
device
);
...
@@ -27,8 +30,13 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc
...
@@ -27,8 +30,13 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc
if
(
!
desc_opt
)
{
if
(
!
desc_opt
)
{
INFINICORE_CHECK_ERROR
(
infiniopCreatePagedAttentionPrefillDescriptor
(
INFINICORE_CHECK_ERROR
(
infiniopCreatePagedAttentionPrefillDescriptor
(
context
::
getInfiniopHandle
(
device
),
&
desc
,
context
::
getInfiniopHandle
(
device
),
&
desc
,
out
->
desc
(),
q
->
desc
(),
k_cache
->
desc
(),
v_cache
->
desc
(),
block_tables
->
desc
(),
out
->
desc
(),
cache_lens
->
desc
(),
seq_lens
->
desc
(),
seq_offsets
->
desc
(),
q
->
desc
(),
k_cache
->
desc
(),
v_cache
->
desc
(),
block_tables
->
desc
(),
history_lens
->
desc
(),
cu_seqlens_q
->
desc
(),
alibi_slopes
.
has_value
()
?
alibi_slopes
.
value
()
->
desc
()
:
nullptr
,
alibi_slopes
.
has_value
()
?
alibi_slopes
.
value
()
->
desc
()
:
nullptr
,
scale
));
scale
));
cache
.
put
(
seed
,
desc
);
cache
.
put
(
seed
,
desc
);
...
@@ -41,8 +49,16 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc
...
@@ -41,8 +49,16 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc
std
::
shared_ptr
<
Memory
>
workspace
=
context
::
allocateMemory
(
workspace_size
);
std
::
shared_ptr
<
Memory
>
workspace
=
context
::
allocateMemory
(
workspace_size
);
INFINICORE_CHECK_ERROR
(
infiniopPagedAttentionPrefill
(
INFINICORE_CHECK_ERROR
(
infiniopPagedAttentionPrefill
(
desc
,
workspace
->
data
(),
workspace_size
,
desc
,
out
->
data
(),
q
->
data
(),
k_cache
->
data
(),
v_cache
->
data
(),
block_tables
->
data
(),
cache_lens
->
data
(),
seq_lens
->
data
(),
seq_offsets
->
data
(),
workspace
->
data
(),
workspace_size
,
out
->
data
(),
q
->
data
(),
k_cache
->
data
(),
v_cache
->
data
(),
block_tables
->
data
(),
history_lens
->
data
(),
cu_seqlens_q
->
data
(),
alibi_slopes
.
has_value
()
?
alibi_slopes
.
value
()
->
data
()
:
nullptr
,
alibi_slopes
.
has_value
()
?
alibi_slopes
.
value
()
->
data
()
:
nullptr
,
context
::
getStream
()));
context
::
getStream
()));
}
}
...
...
src/infinicore/pybind11/ops.hpp
View file @
0a2839a2
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
#include "ops/matmul.hpp"
#include "ops/matmul.hpp"
#include "ops/mul.hpp"
#include "ops/mul.hpp"
#include "ops/paged_attention.hpp"
#include "ops/paged_attention.hpp"
#include "ops/paged_attention_prefill.hpp"
#include "ops/paged_caching.hpp"
#include "ops/paged_caching.hpp"
#include "ops/random_sample.hpp"
#include "ops/random_sample.hpp"
#include "ops/rearrange.hpp"
#include "ops/rearrange.hpp"
...
@@ -33,6 +34,7 @@ inline void bind(py::module &m) {
...
@@ -33,6 +34,7 @@ inline void bind(py::module &m) {
bind_matmul
(
m
);
bind_matmul
(
m
);
bind_mul
(
m
);
bind_mul
(
m
);
bind_paged_attention
(
m
);
bind_paged_attention
(
m
);
bind_paged_attention_prefill
(
m
);
bind_paged_caching
(
m
);
bind_paged_caching
(
m
);
bind_rearrange
(
m
);
bind_rearrange
(
m
);
bind_rms_norm
(
m
);
bind_rms_norm
(
m
);
...
...
src/infinicore/pybind11/ops/paged_attention_prefill.hpp
0 → 100644
View file @
0a2839a2
#pragma once
#include "infinicore/ops/paged_attention_prefill.hpp"
#include <pybind11/pybind11.h>
namespace
py
=
pybind11
;
namespace
infinicore
::
ops
{
Tensor
py_paged_attention_prefill
(
Tensor
q
,
Tensor
k_cache
,
Tensor
v_cache
,
Tensor
block_tables
,
Tensor
history_lens
,
Tensor
cu_seqlens_q
,
py
::
object
alibi_slopes
,
float
scale
)
{
std
::
optional
<
Tensor
>
alibi_slopes_tensor
=
std
::
nullopt
;
if
(
!
alibi_slopes
.
is_none
())
{
alibi_slopes_tensor
=
alibi_slopes
.
cast
<
Tensor
>
();
}
return
op
::
paged_attention_prefill
(
q
,
k_cache
,
v_cache
,
block_tables
,
history_lens
,
cu_seqlens_q
,
alibi_slopes_tensor
,
scale
);
}
void
py_paged_attention_prefill_
(
Tensor
out
,
Tensor
q
,
Tensor
k_cache
,
Tensor
v_cache
,
Tensor
block_tables
,
Tensor
history_lens
,
Tensor
cu_seqlens_q
,
py
::
object
alibi_slopes
,
float
scale
)
{
std
::
optional
<
Tensor
>
alibi_slopes_tensor
=
std
::
nullopt
;
if
(
!
alibi_slopes
.
is_none
())
{
alibi_slopes_tensor
=
alibi_slopes
.
cast
<
Tensor
>
();
}
op
::
paged_attention_prefill_
(
out
,
q
,
k_cache
,
v_cache
,
block_tables
,
history_lens
,
cu_seqlens_q
,
alibi_slopes_tensor
,
scale
);
}
inline
void
bind_paged_attention_prefill
(
py
::
module
&
m
)
{
m
.
def
(
"paged_attention_prefill"
,
&
ops
::
py_paged_attention_prefill
,
py
::
arg
(
"q"
),
py
::
arg
(
"k_cache"
),
py
::
arg
(
"v_cache"
),
py
::
arg
(
"block_tables"
),
py
::
arg
(
"history_lens"
),
py
::
arg
(
"cu_seqlens_q"
),
py
::
arg
(
"alibi_slopes"
)
=
py
::
none
(),
py
::
arg
(
"scale"
)
=
1.0
,
R"doc(Paged attention prefill for packed variable-length queries.)doc"
);
m
.
def
(
"paged_attention_prefill_"
,
&
ops
::
py_paged_attention_prefill_
,
py
::
arg
(
"out"
),
py
::
arg
(
"q"
),
py
::
arg
(
"k_cache"
),
py
::
arg
(
"v_cache"
),
py
::
arg
(
"block_tables"
),
py
::
arg
(
"history_lens"
),
py
::
arg
(
"cu_seqlens_q"
),
py
::
arg
(
"alibi_slopes"
)
=
py
::
none
(),
py
::
arg
(
"scale"
)
=
1.0
,
R"doc(In-place paged attention prefill for packed variable-length queries.)doc"
);
}
}
// namespace infinicore::ops
src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh
View file @
0a2839a2
...
@@ -3,14 +3,13 @@
...
@@ -3,14 +3,13 @@
namespace
op
::
paged_attention_prefill
::
cuda
{
namespace
op
::
paged_attention_prefill
::
cuda
{
// 辅助函数:二分查找确定当前 global_token_idx 属于哪个 sequence
__device__
__forceinline__
size_t
find_seq_id
(
size_t
token_idx
,
const
int64_t
*
cum_seq_lens_q
,
size_t
num_seqs
)
{
__device__
__forceinline__
size_t
find_seq_id
(
size_t
token_idx
,
const
int64_t
*
offset
,
size_t
num_seqs
)
{
size_t
low
=
0
,
high
=
num_seqs
-
1
;
size_t
low
=
0
,
high
=
num_seqs
-
1
;
while
(
low
<=
high
)
{
while
(
low
<=
high
)
{
size_t
mid
=
(
low
+
high
)
>>
1
;
size_t
mid
=
(
low
+
high
)
>>
1
;
if
(
token_idx
>=
offset
[
mid
]
&&
token_idx
<
offset
[
mid
+
1
])
{
if
(
token_idx
>=
(
size_t
)
cum_seq_lens_q
[
mid
]
&&
token_idx
<
(
size_t
)
cum_seq_lens_q
[
mid
+
1
])
{
return
mid
;
return
mid
;
}
else
if
(
token_idx
<
offset
[
mid
])
{
}
else
if
(
token_idx
<
(
size_t
)
cum_seq_lens_q
[
mid
])
{
high
=
mid
-
1
;
high
=
mid
-
1
;
}
else
{
}
else
{
low
=
mid
+
1
;
low
=
mid
+
1
;
...
@@ -22,50 +21,43 @@ __device__ __forceinline__ size_t find_seq_id(size_t token_idx, const int64_t *o
...
@@ -22,50 +21,43 @@ __device__ __forceinline__ size_t find_seq_id(size_t token_idx, const int64_t *o
template
<
typename
Tdata
,
typename
Tcompute
>
template
<
typename
Tdata
,
typename
Tcompute
>
__global__
void
pagedAttentionPrefillKernel
(
__global__
void
pagedAttentionPrefillKernel
(
Tdata
*
out_
,
const
Tdata
*
q_
,
const
Tdata
*
k_cache_
,
const
Tdata
*
v_cache_
,
Tdata
*
out_
,
const
Tdata
*
q_
,
const
Tdata
*
k_cache_
,
const
Tdata
*
v_cache_
,
const
int64_t
*
block_tables_
,
const
int64_t
*
cache_lens_
,
const
int64_t
*
seq_lens_
,
const
int64_t
*
block_tables_
,
const
int64_t
*
history_lens_
,
const
int64_t
*
cum_seq_lens_q_
,
const
float
*
alibi_slopes_
,
const
float
*
alibi_slopes_
,
const
size_t
num_heads
,
const
size_t
num_kv_heads
,
const
float
scale
,
const
size_t
num_heads
,
const
size_t
num_kv_heads
,
const
float
scale
,
const
size_t
max_num_blocks_per_seq
,
const
size_t
block_size
,
const
size_t
max_num_blocks_per_seq
,
const
size_t
block_size
,
const
ptrdiff_t
kv_block_stride
,
const
ptrdiff_t
kv_head_stride
,
const
ptrdiff_t
kv_block_stride
,
const
ptrdiff_t
kv_head_stride
,
const
size_t
head_size
,
const
size_t
head_size
,
const
int64_t
*
offset_
,
const
size_t
num_seqs
)
{
const
size_t
num_seqs
)
{
//
--- 使用 2D Grid 坐标 ---
//
Grid : x -> token, y -> head
const
size_t
global_token_idx
=
blockIdx
.
x
;
// 展平后的全局 token 索引
const
size_t
global_token_idx
=
blockIdx
.
x
;
const
size_t
head_idx
=
blockIdx
.
y
;
// Head 索引
const
size_t
head_idx
=
blockIdx
.
y
;
const
size_t
dim_idx
=
threadIdx
.
x
;
// Head 内部维度
const
size_t
dim_idx
=
threadIdx
.
x
;
if
(
dim_idx
>=
head_size
)
{
if
(
dim_idx
>=
head_size
)
{
return
;
return
;
}
}
// --- 通过二分查找 offset 找到所属的 seq_idx ---
size_t
seq_idx
=
find_seq_id
(
global_token_idx
,
cum_seq_lens_q_
,
num_seqs
);
size_t
seq_idx
=
find_seq_id
(
global_token_idx
,
offset_
,
num_seqs
);
// --- 获取该 Sequence 本次 Prefill 的长度
size_t
q_token_idx
=
global_token_idx
-
cum_seq_lens_q_
[
seq_idx
];
const
int64_t
cur_new_len
=
seq_lens_
[
seq_idx
];
// --- 该 token 在当前序列中的相对位置
const
int64_t
history_len
=
history_lens_
[
seq_idx
];
size_t
q_token_idx
=
global_token_idx
-
offset_
[
seq
_idx
]
;
const
int64_t
causal_limit
=
history_len
+
q_token
_idx
;
const
Tdata
*
q_
ptr_base
=
q_
+
global_token_idx
*
num_heads
*
head_size
+
head_idx
*
head_size
;
const
Tdata
*
q_
vec
=
q_
+
global_token_idx
*
num_heads
*
head_size
+
head_idx
*
head_size
;
Tdata
*
out_ptr
=
out_
+
global_token_idx
*
num_heads
*
head_size
+
head_idx
*
head_size
;
Tdata
*
out_ptr
=
out_
+
global_token_idx
*
num_heads
*
head_size
+
head_idx
*
head_size
;
// --- KV Cache 相关信息
const
int64_t
total_seq_len
=
cache_lens_
[
seq_idx
];
const
int64_t
history_len
=
total_seq_len
-
cur_new_len
;
const
int64_t
causal_limit
=
history_len
+
q_token_idx
;
const
size_t
num_queries_per_kv
=
num_heads
/
num_kv_heads
;
const
size_t
num_queries_per_kv
=
num_heads
/
num_kv_heads
;
const
size_t
kv_head_idx
=
head_idx
/
num_queries_per_kv
;
const
size_t
kv_head_idx
=
head_idx
/
num_queries_per_kv
;
const
int64_t
*
block_table
=
block_tables_
+
seq_idx
*
max_num_blocks_per_seq
;
const
int64_t
*
block_table
=
block_tables_
+
seq_idx
*
max_num_blocks_per_seq
;
const
float
alibi_slope
=
(
alibi_slopes_
==
nullptr
)
?
0.0
f
:
alibi_slopes_
[
head_idx
];
const
float
alibi_slope
=
(
alibi_slopes_
==
nullptr
)
?
0.0
f
:
alibi_slopes_
[
head_idx
];
// Pass 1: 计算 Score 并找最大值
Tcompute
max_score
=
-
FLT_MAX
;
Tcompute
max_score
=
-
FLT_MAX
;
for
(
size
_t
t
=
0
;
t
<=
causal_limit
;
++
t
)
{
for
(
int64
_t
t
=
0
;
t
<=
causal_limit
;
++
t
)
{
const
int64_t
b_idx
=
t
/
block_size
;
const
int64_t
b_idx
=
t
/
block_size
;
const
int64_t
t_off
=
t
%
block_size
;
const
int64_t
t_off
=
t
%
block_size
;
const
int64_t
physical_block_id
=
block_table
[
b_idx
];
const
int64_t
physical_block_id
=
block_table
[
b_idx
];
...
@@ -73,7 +65,7 @@ __global__ void pagedAttentionPrefillKernel(
...
@@ -73,7 +65,7 @@ __global__ void pagedAttentionPrefillKernel(
Tcompute
score
=
0.0
f
;
Tcompute
score
=
0.0
f
;
for
(
size_t
d
=
0
;
d
<
head_size
;
++
d
)
{
for
(
size_t
d
=
0
;
d
<
head_size
;
++
d
)
{
score
+=
static_cast
<
Tcompute
>
(
q_
ptr_base
[
d
])
*
static_cast
<
Tcompute
>
(
k_vec
[
d
]);
score
+=
static_cast
<
Tcompute
>
(
q_
vec
[
d
])
*
static_cast
<
Tcompute
>
(
k_vec
[
d
]);
}
}
score
*=
static_cast
<
Tcompute
>
(
scale
);
score
*=
static_cast
<
Tcompute
>
(
scale
);
if
(
alibi_slope
!=
0.0
f
)
{
if
(
alibi_slope
!=
0.0
f
)
{
...
@@ -84,9 +76,8 @@ __global__ void pagedAttentionPrefillKernel(
...
@@ -84,9 +76,8 @@ __global__ void pagedAttentionPrefillKernel(
}
}
}
}
// Pass 2: 计算 Sum of Exp
Tcompute
sum_exp
=
0.0
f
;
Tcompute
sum_exp
=
0.0
f
;
for
(
size
_t
t
=
0
;
t
<=
causal_limit
;
++
t
)
{
for
(
int64
_t
t
=
0
;
t
<=
causal_limit
;
++
t
)
{
const
int64_t
b_idx
=
t
/
block_size
;
const
int64_t
b_idx
=
t
/
block_size
;
const
int64_t
t_off
=
t
%
block_size
;
const
int64_t
t_off
=
t
%
block_size
;
const
int64_t
physical_block_id
=
block_table
[
b_idx
];
const
int64_t
physical_block_id
=
block_table
[
b_idx
];
...
@@ -94,7 +85,7 @@ __global__ void pagedAttentionPrefillKernel(
...
@@ -94,7 +85,7 @@ __global__ void pagedAttentionPrefillKernel(
Tcompute
score
=
0.0
f
;
Tcompute
score
=
0.0
f
;
for
(
size_t
d
=
0
;
d
<
head_size
;
++
d
)
{
for
(
size_t
d
=
0
;
d
<
head_size
;
++
d
)
{
score
+=
static_cast
<
Tcompute
>
(
q_
ptr_base
[
d
])
*
static_cast
<
Tcompute
>
(
k_vec
[
d
]);
score
+=
static_cast
<
Tcompute
>
(
q_
vec
[
d
])
*
static_cast
<
Tcompute
>
(
k_vec
[
d
]);
}
}
score
*=
static_cast
<
Tcompute
>
(
scale
);
score
*=
static_cast
<
Tcompute
>
(
scale
);
if
(
alibi_slope
!=
0.0
f
)
{
if
(
alibi_slope
!=
0.0
f
)
{
...
@@ -103,10 +94,9 @@ __global__ void pagedAttentionPrefillKernel(
...
@@ -103,10 +94,9 @@ __global__ void pagedAttentionPrefillKernel(
sum_exp
+=
expf
(
static_cast
<
float
>
(
score
-
max_score
));
sum_exp
+=
expf
(
static_cast
<
float
>
(
score
-
max_score
));
}
}
// Pass 3: 加权求和得到输出
Tcompute
acc
=
0.0
f
;
Tcompute
acc
=
0.0
f
;
Tcompute
inv_sum
=
1.0
f
/
(
sum_exp
+
1e-6
f
);
Tcompute
inv_sum
=
1.0
f
/
(
sum_exp
+
1e-6
f
);
for
(
size
_t
t
=
0
;
t
<=
causal_limit
;
++
t
)
{
for
(
int64
_t
t
=
0
;
t
<=
causal_limit
;
++
t
)
{
const
int64_t
b_idx
=
t
/
block_size
;
const
int64_t
b_idx
=
t
/
block_size
;
const
int64_t
t_off
=
t
%
block_size
;
const
int64_t
t_off
=
t
%
block_size
;
const
int64_t
physical_block_id
=
block_table
[
b_idx
];
const
int64_t
physical_block_id
=
block_table
[
b_idx
];
...
@@ -114,7 +104,7 @@ __global__ void pagedAttentionPrefillKernel(
...
@@ -114,7 +104,7 @@ __global__ void pagedAttentionPrefillKernel(
const
Tdata
*
k_vec
=
k_cache_
+
physical_block_id
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
+
t_off
*
head_size
;
const
Tdata
*
k_vec
=
k_cache_
+
physical_block_id
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
+
t_off
*
head_size
;
Tcompute
score
=
0.0
f
;
Tcompute
score
=
0.0
f
;
for
(
size_t
d
=
0
;
d
<
head_size
;
++
d
)
{
for
(
size_t
d
=
0
;
d
<
head_size
;
++
d
)
{
score
+=
static_cast
<
Tcompute
>
(
q_
ptr_base
[
d
])
*
static_cast
<
Tcompute
>
(
k_vec
[
d
]);
score
+=
static_cast
<
Tcompute
>
(
q_
vec
[
d
])
*
static_cast
<
Tcompute
>
(
k_vec
[
d
]);
}
}
score
*=
static_cast
<
Tcompute
>
(
scale
);
score
*=
static_cast
<
Tcompute
>
(
scale
);
if
(
alibi_slope
!=
0.0
f
)
{
if
(
alibi_slope
!=
0.0
f
)
{
...
...
src/infiniop/ops/paged_attention_prefill/info.h
View file @
0a2839a2
#ifndef __PAGED_ATTENTION_PREFILL_INFO_H__
#ifndef __
INFINIOP_
PAGED_ATTENTION_PREFILL_INFO_H__
#define __PAGED_ATTENTION_PREFILL_INFO_H__
#define __
INFINIOP_
PAGED_ATTENTION_PREFILL_INFO_H__
#include "../../../utils.h"
#include "../../../utils.h"
#include "../../tensor.h"
#include "../../tensor.h"
...
@@ -35,9 +35,8 @@ public:
...
@@ -35,9 +35,8 @@ public:
infiniopTensorDescriptor_t
k_cache_desc
,
infiniopTensorDescriptor_t
k_cache_desc
,
infiniopTensorDescriptor_t
v_cache_desc
,
infiniopTensorDescriptor_t
v_cache_desc
,
infiniopTensorDescriptor_t
block_tables_desc
,
infiniopTensorDescriptor_t
block_tables_desc
,
infiniopTensorDescriptor_t
cache_lens_desc
,
infiniopTensorDescriptor_t
history_lens_desc
,
infiniopTensorDescriptor_t
seq_lens_desc
,
infiniopTensorDescriptor_t
cum_seq_lens_q_desc
,
infiniopTensorDescriptor_t
offset_desc
,
const
std
::
optional
<
infiniopTensorDescriptor_t
>
&
alibi_slopes_desc
,
const
std
::
optional
<
infiniopTensorDescriptor_t
>
&
alibi_slopes_desc
,
float
scale
)
{
float
scale
)
{
...
@@ -47,39 +46,54 @@ public:
...
@@ -47,39 +46,54 @@ public:
if
(
out_desc
->
dtype
()
!=
dtype
||
k_cache_desc
->
dtype
()
!=
dtype
||
v_cache_desc
->
dtype
()
!=
dtype
)
{
if
(
out_desc
->
dtype
()
!=
dtype
||
k_cache_desc
->
dtype
()
!=
dtype
||
v_cache_desc
->
dtype
()
!=
dtype
)
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
}
if
(
offset_desc
->
dtype
()
!=
INFINI_DTYPE_I64
||
seq_lens_desc
->
dtype
()
!=
INFINI_DTYPE_I64
)
{
if
(
cum_seq_lens_q_desc
->
dtype
()
!=
INFINI_DTYPE_I64
||
history_lens_desc
->
dtype
()
!=
INFINI_DTYPE_I64
)
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
}
if
(
alibi_slopes_desc
.
has_value
()
&&
alibi_slopes_desc
.
value
()
!=
nullptr
)
{
if
(
alibi_slopes_desc
.
has_value
()
&&
alibi_slopes_desc
.
value
()
!=
nullptr
)
{
std
::
cerr
<<
"[Error] PagedAttentionPrefill: ALiBi slopes are not supported yet."
<<
std
::
endl
;
}
auto
k_shape
=
k_cache_desc
->
shape
();
auto
v_shape
=
v_cache_desc
->
shape
();
auto
block_tables_shape
=
block_tables_desc
->
shape
();
auto
history_lens_shape
=
history_lens_desc
->
shape
();
auto
cum_seq_lens_q_shape
=
cum_seq_lens_q_desc
->
shape
();
if
(
k_shape
.
size
()
!=
4
||
v_shape
.
size
()
!=
4
)
{
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
}
if
(
block_tables_shape
.
size
()
!=
2
)
{
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
}
if
(
history_lens_shape
.
size
()
!=
1
||
cum_seq_lens_q_shape
.
size
()
!=
1
)
{
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
}
if
(
cum_seq_lens_q_shape
[
0
]
!=
history_lens_shape
[
0
]
+
1
)
{
return
INFINI_STATUS_BAD_PARAM
;
return
INFINI_STATUS_BAD_PARAM
;
}
}
// Q shape: [total_tokens, heads, dim]
(3D)
// Q shape: [total_tokens, heads, dim]
auto
q_shape
=
q_desc
->
shape
();
auto
q_shape
=
q_desc
->
shape
();
if
(
q_shape
.
size
()
<
3
)
{
if
(
q_shape
.
size
()
!=
3
)
{
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
}
}
size_t
total_q_tokens
=
q_shape
[
0
];
size_t
total_q_tokens
=
q_shape
[
0
];
size_t
num_heads
=
q_shape
[
1
];
size_t
head_size
=
q_shape
[
2
];
size_t
num_heads
=
q_shape
[
q_shape
.
size
()
-
2
];
if
(
head_size
>
1024
)
{
size_t
head_size
=
q_shape
[
q_shape
.
size
()
-
1
];
return
INFINI_STATUS_BAD_PARAM
;
if
(
head_size
!=
128
)
{
std
::
cerr
<<
"[Error] PagedAttentionPrefill head_size = 128 supported, got "
<<
head_size
<<
std
::
endl
;
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
}
}
// 从 seq_lens 获取 num_seqs
size_t
num_seqs
=
history_lens_shape
[
0
];
size_t
num_seqs
=
seq_lens_desc
->
shape
()[
0
];
auto
k_cache_shape
=
k_cache_desc
->
shape
();
size_t
num_kv_heads
=
k_shape
[
1
];
size_t
num_kv_heads
=
k_cache_shape
[
1
];
size_t
block_size
=
k_shape
[
2
];
size_t
block_size
=
v_cache_desc
->
shape
()[
2
];
size_t
max_num_blocks_per_seq
=
block_tables_shape
[
1
];
size_t
max_num_blocks_per_seq
=
block_tables_desc
->
shape
()[
1
];
// 提取步长,需要保持多个请求的 Q 连续
ptrdiff_t
q_stride
=
q_desc
->
stride
(
0
);
ptrdiff_t
q_stride
=
q_desc
->
stride
(
0
);
ptrdiff_t
kv_block_stride
=
k_cache_desc
->
stride
(
0
);
ptrdiff_t
kv_block_stride
=
k_cache_desc
->
stride
(
0
);
ptrdiff_t
kv_head_stride
=
k_cache_desc
->
stride
(
1
);
ptrdiff_t
kv_head_stride
=
k_cache_desc
->
stride
(
1
);
...
...
src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu
View file @
0a2839a2
...
@@ -8,14 +8,12 @@
...
@@ -8,14 +8,12 @@
#include "../cuda/kernel.cuh"
#include "../cuda/kernel.cuh"
#include "paged_attention_prefill_nvidia.cuh"
#include "paged_attention_prefill_nvidia.cuh"
// ==============================================================================
// Host wrapper to launch the global kernel
// ==============================================================================
template
<
typename
Tdata
,
typename
Tcompute
>
template
<
typename
Tdata
,
typename
Tcompute
>
infiniStatus_t
launchPagedAttentionPrefill
(
infiniStatus_t
launchPagedAttentionPrefill
(
Tdata
*
out
,
const
Tdata
*
q
,
const
Tdata
*
k_cache
,
const
Tdata
*
v_cache
,
Tdata
*
out
,
const
Tdata
*
q
,
const
Tdata
*
k_cache
,
const
Tdata
*
v_cache
,
const
int64_t
*
block_tables
,
const
int64_t
*
cache_lens
,
const
int64_t
*
seq_lens
,
const
int64_t
*
block_tables
,
const
int64_t
*
offset
,
const
int64_t
*
history_lens
,
const
int64_t
*
cum_seq_lens_q
,
const
float
*
alibi_slopes
,
const
float
*
alibi_slopes
,
const
size_t
num_heads
,
const
size_t
num_heads
,
const
size_t
num_seqs
,
const
size_t
num_seqs
,
...
@@ -24,36 +22,27 @@ infiniStatus_t launchPagedAttentionPrefill(
...
@@ -24,36 +22,27 @@ infiniStatus_t launchPagedAttentionPrefill(
const
size_t
max_num_blocks_per_seq
,
const
size_t
max_num_blocks_per_seq
,
const
size_t
block_size
,
const
size_t
block_size
,
const
size_t
total_q_tokens
,
const
size_t
total_q_tokens
,
const
ptrdiff_t
q_strid
e
,
const
size_t
head_siz
e
,
const
ptrdiff_t
kv_block_stride
,
const
ptrdiff_t
kv_block_stride
,
const
ptrdiff_t
kv_head_stride
,
const
ptrdiff_t
kv_head_stride
,
const
ptrdiff_t
o_stride
,
const
size_t
head_size
,
cudaStream_t
stream
)
{
cudaStream_t
stream
)
{
if
(
total_q_tokens
==
0
||
num_heads
==
0
)
{
if
(
total_q_tokens
==
0
||
num_heads
==
0
)
{
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
}
}
// 使用 2D Grid: X轴是所有 Token,Y轴是所有 Head
dim3
grid
(
total_q_tokens
,
num_heads
);
dim3
grid
(
total_q_tokens
,
num_heads
);
dim3
block
(
head_size
);
dim3
block
(
head_size
);
op
::
paged_attention_prefill
::
cuda
::
pagedAttentionPrefillKernel
<
Tdata
,
Tcompute
>
op
::
paged_attention_prefill
::
cuda
::
pagedAttentionPrefillKernel
<
Tdata
,
Tcompute
>
<<<
grid
,
block
,
0
,
stream
>>>
(
<<<
grid
,
block
,
0
,
stream
>>>
(
out
,
q
,
k_cache
,
v_cache
,
out
,
q
,
k_cache
,
v_cache
,
block_tables
,
cache
_lens
,
seq_lens
,
alibi_slopes
,
block_tables
,
history
_lens
,
cum_
seq_lens
_q
,
alibi_slopes
,
num_heads
,
num_kv_heads
,
scale
,
num_heads
,
num_kv_heads
,
scale
,
max_num_blocks_per_seq
,
block_size
,
max_num_blocks_per_seq
,
block_size
,
kv_block_stride
,
kv_head_stride
,
kv_block_stride
,
kv_head_stride
,
head_size
,
head_size
,
offset
,
num_seqs
);
num_seqs
);
cudaError_t
err
=
cudaGetLastError
();
if
(
err
!=
cudaSuccess
)
{
std
::
cerr
<<
"CUDA Kernel Launch Failed: "
<<
cudaGetErrorString
(
err
)
<<
std
::
endl
;
return
INFINI_STATUS_INTERNAL_ERROR
;
}
return
INFINI_STATUS_SUCCESS
;
return
INFINI_STATUS_SUCCESS
;
}
}
...
@@ -76,16 +65,17 @@ infiniStatus_t Descriptor::create(
...
@@ -76,16 +65,17 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t
k_cache_desc
,
infiniopTensorDescriptor_t
k_cache_desc
,
infiniopTensorDescriptor_t
v_cache_desc
,
infiniopTensorDescriptor_t
v_cache_desc
,
infiniopTensorDescriptor_t
block_tables_desc
,
infiniopTensorDescriptor_t
block_tables_desc
,
infiniopTensorDescriptor_t
cache_lens_desc
,
infiniopTensorDescriptor_t
history_lens_desc
,
infiniopTensorDescriptor_t
seq_lens_desc
,
infiniopTensorDescriptor_t
cum_seq_lens_q_desc
,
infiniopTensorDescriptor_t
offset_desc
,
const
std
::
optional
<
infiniopTensorDescriptor_t
>
&
alibi_slopes_desc
,
const
std
::
optional
<
infiniopTensorDescriptor_t
>
&
alibi_slopes_desc
,
float
scale
)
{
float
scale
)
{
auto
info
=
PagedAttentionPrefillInfo
::
create
(
out_desc
,
q_desc
,
k_cache_desc
,
v_cache_desc
,
auto
info
=
PagedAttentionPrefillInfo
::
create
(
block_tables_desc
,
cache_lens_desc
,
seq_lens_desc
,
out_desc
,
q_desc
,
k_cache_desc
,
v_cache_desc
,
offset_desc
,
block_tables_desc
,
history_lens_desc
,
alibi_slopes_desc
,
scale
);
cum_seq_lens_q_desc
,
alibi_slopes_desc
,
scale
);
CHECK_RESULT
(
info
);
CHECK_RESULT
(
info
);
*
desc_ptr
=
new
Descriptor
(
*
desc_ptr
=
new
Descriptor
(
...
@@ -98,28 +88,24 @@ infiniStatus_t Descriptor::create(
...
@@ -98,28 +88,24 @@ infiniStatus_t Descriptor::create(
infiniStatus_t
Descriptor
::
calculate
(
infiniStatus_t
Descriptor
::
calculate
(
void
*
workspace
,
size_t
workspace_size
,
void
*
workspace
,
size_t
workspace_size
,
void
*
out
,
const
void
*
q
,
const
void
*
k_cache
,
const
void
*
v_cache
,
void
*
out
,
const
void
*
q
,
const
void
*
k_cache
,
const
void
*
v_cache
,
const
void
*
block_tables
,
const
void
*
cache_lens
,
const
void
*
seq_lens
,
const
void
*
block_tables
,
const
void
*
offset
,
const
void
*
history_lens
,
const
void
*
cum_seq_lens_q
,
const
void
*
alibi_slopes
,
const
void
*
alibi_slopes
,
void
*
stream_
)
const
{
void
*
stream_
)
const
{
cudaStream_t
stream
=
(
cudaStream_t
)
stream_
;
cudaStream_t
stream
=
(
cudaStream_t
)
stream_
;
if
(
_info
.
head_size
>
1024
)
{
#define LAUNCH_KERNEL(Tdata, Tcompute) \
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
launchPagedAttentionPrefill<Tdata, Tcompute>( \
}
(Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache, \
(const int64_t *)block_tables, (const int64_t *)history_lens, (const int64_t *)cum_seq_lens_q, \
#define LAUNCH_KERNEL(Tdata, Tcompute) \
(const float *)alibi_slopes, \
launchPagedAttentionPrefill<Tdata, Tcompute>( \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, \
(Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache, \
_info.scale, _info.max_num_blocks_per_seq, \
(const int64_t *)block_tables, (const int64_t *)cache_lens, (const int64_t *)seq_lens, \
_info.block_size, _info.total_q_tokens, \
(const int64_t *)offset, \
_info.head_size, \
(const float *)alibi_slopes, \
_info.kv_block_stride, _info.kv_head_stride, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, \
_info.scale, _info.max_num_blocks_per_seq, \
_info.block_size, _info.total_q_tokens, \
_info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride, \
_info.head_size, \
stream)
stream)
if
(
_info
.
dtype
==
INFINI_DTYPE_F16
)
{
if
(
_info
.
dtype
==
INFINI_DTYPE_F16
)
{
...
...
src/infiniop/ops/paged_attention_prefill/operator.cc
View file @
0a2839a2
...
@@ -14,9 +14,8 @@ __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
...
@@ -14,9 +14,8 @@ __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
infiniopTensorDescriptor_t
k_cache_desc
,
infiniopTensorDescriptor_t
k_cache_desc
,
infiniopTensorDescriptor_t
v_cache_desc
,
infiniopTensorDescriptor_t
v_cache_desc
,
infiniopTensorDescriptor_t
block_tables_desc
,
infiniopTensorDescriptor_t
block_tables_desc
,
infiniopTensorDescriptor_t
cache_lens_desc
,
infiniopTensorDescriptor_t
history_lens_desc
,
infiniopTensorDescriptor_t
seq_lens_desc
,
infiniopTensorDescriptor_t
cum_seq_lens_q_desc
,
infiniopTensorDescriptor_t
offset_desc
,
infiniopTensorDescriptor_t
alibi_slopes_desc
,
infiniopTensorDescriptor_t
alibi_slopes_desc
,
float
scale
)
{
float
scale
)
{
...
@@ -27,8 +26,8 @@ __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
...
@@ -27,8 +26,8 @@ __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
return op::paged_attention_prefill::NAMESPACE::Descriptor::create( \
return op::paged_attention_prefill::NAMESPACE::Descriptor::create( \
handle, \
handle, \
reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor **>(desc_ptr), \
reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor **>(desc_ptr), \
out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc,
cache_lens_desc,
\
out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc,
\
seq
_lens_desc,
offset
_desc, alibi_opt, scale);
history
_lens_desc,
cum_seq_lens_q
_desc, alibi_opt, scale);
switch
(
handle
->
device
)
{
switch
(
handle
->
device
)
{
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
...
@@ -59,8 +58,9 @@ __C infiniStatus_t infiniopPagedAttentionPrefill(
...
@@ -59,8 +58,9 @@ __C infiniStatus_t infiniopPagedAttentionPrefill(
infiniopPagedAttentionPrefillDescriptor_t
desc
,
infiniopPagedAttentionPrefillDescriptor_t
desc
,
void
*
workspace
,
size_t
workspace_size
,
void
*
workspace
,
size_t
workspace_size
,
void
*
out
,
const
void
*
q
,
const
void
*
k_cache
,
const
void
*
v_cache
,
void
*
out
,
const
void
*
q
,
const
void
*
k_cache
,
const
void
*
v_cache
,
const
void
*
block_tables
,
const
void
*
cache_lens
,
const
void
*
seq_lens
,
const
void
*
block_tables
,
const
void
*
offset
,
const
void
*
history_lens
,
const
void
*
cum_seq_lens_q
,
const
void
*
alibi_slopes
,
const
void
*
alibi_slopes
,
void
*
stream
)
{
void
*
stream
)
{
...
@@ -68,7 +68,7 @@ __C infiniStatus_t infiniopPagedAttentionPrefill(
...
@@ -68,7 +68,7 @@ __C infiniStatus_t infiniopPagedAttentionPrefill(
case CASE: \
case CASE: \
return reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor *>(desc)->calculate( \
return reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor *>(desc)->calculate( \
workspace, workspace_size, out, q, k_cache, v_cache, block_tables, \
workspace, workspace_size, out, q, k_cache, v_cache, block_tables, \
cache
_lens, seq_lens
, offset
, alibi_slopes, stream);
history
_lens,
cum_
seq_lens
_q
, alibi_slopes, stream);
switch
(
desc
->
device_type
)
{
switch
(
desc
->
device_type
)
{
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
...
...
src/infiniop/ops/paged_attention_prefill/paged_attention_prefill.h
View file @
0a2839a2
...
@@ -4,53 +4,53 @@
...
@@ -4,53 +4,53 @@
#include "../../operator.h"
#include "../../operator.h"
#include "info.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE)
\
#define DESCRIPTOR(NAMESPACE) \
\
\
namespace op::paged_attention_prefill::NAMESPACE {
\
namespace op::paged_attention_prefill::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor {
\
class Descriptor final : public InfiniopDescriptor { \
struct Opaque;
\
struct Opaque; \
Opaque *_opaque;
\
Opaque *_opaque; \
PagedAttentionPrefillInfo _info;
\
PagedAttentionPrefillInfo _info; \
size_t _workspace_size;
\
size_t _workspace_size; \
\
\
Descriptor(
\
Descriptor( \
Opaque *opaque,
\
Opaque *opaque, \
PagedAttentionPrefillInfo info,
\
PagedAttentionPrefillInfo info, \
size_t workspace_size,
\
size_t workspace_size, \
infiniDevice_t device_type,
\
infiniDevice_t device_type, \
int device_id)
\
int device_id) \
: InfiniopDescriptor{device_type, device_id},
\
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque),
\
_opaque(opaque), \
_info(info),
\
_info(info), \
_workspace_size(workspace_size) {}
\
_workspace_size(workspace_size) {} \
\
\
public:
\
public: \
~Descriptor();
\
~Descriptor(); \
\
\
size_t workspaceSize() const { return _workspace_size; }
\
size_t workspaceSize() const { return _workspace_size; } \
\
\
static infiniStatus_t create(
\
static infiniStatus_t create( \
infiniopHandle_t handle,
\
infiniopHandle_t handle, \
Descriptor **desc_ptr,
\
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t out_desc,
\
infiniopTensorDescriptor_t out_desc, \
infiniopTensorDescriptor_t q_desc,
\
infiniopTensorDescriptor_t q_desc, \
infiniopTensorDescriptor_t k_cache_desc,
\
infiniopTensorDescriptor_t k_cache_desc, \
infiniopTensorDescriptor_t v_cache_desc,
\
infiniopTensorDescriptor_t v_cache_desc, \
infiniopTensorDescriptor_t block_tables_desc,
\
infiniopTensorDescriptor_t block_tables_desc, \
infiniopTensorDescriptor_t
cache
_lens_desc,
\
infiniopTensorDescriptor_t
history
_lens_desc, \
infiniopTensorDescriptor_t seq_lens_desc,
\
infiniopTensorDescriptor_t
cum_
seq_lens_
q_
desc, \
infiniopTensorDescriptor_t
offset_desc,
\
const std::optional<
infiniopTensorDescriptor_t
> &alibi_slopes_desc,
\
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
\
float scale);
\
float scale);
\
\
\
infiniStatus_t calculate(
\
infiniStatus_t calculate(
\
void *workspace, size_t workspace_size,
\
void *
workspace, size_t workspace_size,
\
void *
out, const void *q, const void *k_cache, const void *v_cache,
\
void *out,
const void *
q, const void *k_cache, const void *v_cache,
\
const void *
block_tables,
\
const void *
block_tables, const void *cache_lens, const void *seq_lens,
\
const void *
history_lens,
\
const void *
offset,
\
const void *
cum_seq_lens_q,
\
const void *alibi_slopes,
\
const void *alibi_slopes, \
void *stream) const;
\
void *stream) const; \
};
\
}; \
}
}
#endif // PAGED_ATTENTION_PREFILL_H
#endif // PAGED_ATTENTION_PREFILL_H
test/infinicore/ops/paged_attention_prefill.py
0 → 100644
View file @
0a2839a2
import
sys
import
os
import
torch
import
infinicore
sys
.
path
.
insert
(
0
,
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
".."
))
from
framework
import
(
BaseOperatorTest
,
TensorSpec
,
TestCase
,
GenericTestRunner
,
TensorInitializer
,
)
# Test Cases: (num_seqs, num_heads, num_kv_heads, head_size, block_size, max_step_len, num_rounds)
_TEST_CASES_DATA
=
[
(
1
,
1
,
1
,
128
,
8
,
16
,
1
),
(
1
,
4
,
4
,
128
,
8
,
16
,
4
),
(
2
,
8
,
8
,
128
,
16
,
32
,
2
),
(
4
,
16
,
16
,
128
,
8
,
64
,
3
),
(
8
,
64
,
64
,
128
,
8
,
16
,
5
),
]
_TOLERANCE_MAP
=
{
infinicore
.
float16
:
{
"atol"
:
1e-2
,
"rtol"
:
1e-2
},
infinicore
.
float32
:
{
"atol"
:
1e-4
,
"rtol"
:
1e-4
},
# float32 调优容限
infinicore
.
bfloat16
:
{
"atol"
:
2e-2
,
"rtol"
:
2e-2
},
}
_TENSOR_DTYPES
=
[
infinicore
.
float16
,
infinicore
.
bfloat16
,
infinicore
.
float32
]
class
SimpleCacheManager
:
def
__init__
(
self
,
num_blocks
,
block_size
):
self
.
num_blocks
=
num_blocks
self
.
block_size
=
block_size
self
.
free_blocks
=
list
(
range
(
num_blocks
))
self
.
request_to_blocks
=
{}
self
.
request_to_len
=
{}
def
allocate_slots
(
self
,
request_id
,
num_new_tokens
):
if
request_id
not
in
self
.
request_to_len
:
self
.
request_to_len
[
request_id
]
=
0
self
.
request_to_blocks
[
request_id
]
=
[]
start_pos
=
self
.
request_to_len
[
request_id
]
new_total_len
=
start_pos
+
num_new_tokens
needed_blocks
=
(
new_total_len
+
self
.
block_size
-
1
)
//
self
.
block_size
added_blocks
=
needed_blocks
-
len
(
self
.
request_to_blocks
[
request_id
])
for
_
in
range
(
added_blocks
):
self
.
request_to_blocks
[
request_id
].
append
(
self
.
free_blocks
.
pop
(
0
))
self
.
request_to_len
[
request_id
]
=
new_total_len
return
self
.
request_to_blocks
[
request_id
],
new_total_len
def
parse_test_cases
():
test_cases
=
[]
for
(
num_seqs
,
num_heads
,
num_kv_heads
,
head_size
,
block_size
,
max_step_len
,
num_rounds
,
)
in
_TEST_CASES_DATA
:
scale
=
head_size
**-
0.5
num_blocks
=
8192
manager
=
SimpleCacheManager
(
num_blocks
,
block_size
)
current_history_lens
=
torch
.
zeros
(
num_seqs
,
dtype
=
torch
.
int64
)
persistent_k
=
torch
.
zeros
((
num_blocks
,
num_kv_heads
,
block_size
,
head_size
))
persistent_v
=
torch
.
zeros
((
num_blocks
,
num_kv_heads
,
block_size
,
head_size
))
for
r
in
range
(
num_rounds
):
q_lens
=
torch
.
randint
(
1
,
max_step_len
+
1
,
(
num_seqs
,),
dtype
=
torch
.
int64
)
total_q_tokens
=
q_lens
.
sum
().
item
()
cu_seqlens_q
=
torch
.
zeros
(
num_seqs
+
1
,
dtype
=
torch
.
int64
)
cu_seqlens_q
[
1
:]
=
torch
.
cumsum
(
q_lens
,
dim
=
0
)
query_base
=
torch
.
randn
((
total_q_tokens
,
num_heads
,
head_size
))
round_block_tables_list
=
[]
for
i
in
range
(
num_seqs
):
p_blocks
,
total_len
=
manager
.
allocate_slots
(
i
,
q_lens
[
i
].
item
())
round_block_tables_list
.
append
(
p_blocks
)
h_len
=
current_history_lens
[
i
].
item
()
q_start
=
cu_seqlens_q
[
i
].
item
()
for
t
in
range
(
q_lens
[
i
].
item
()):
logical_pos
=
h_len
+
t
b_id
=
p_blocks
[
logical_pos
//
block_size
]
off
=
logical_pos
%
block_size
persistent_k
[
b_id
,
:,
off
,
:]
=
torch
.
randn
(
num_kv_heads
,
head_size
)
persistent_v
[
b_id
,
:,
off
,
:]
=
torch
.
randn
(
num_kv_heads
,
head_size
)
max_blks
=
max
(
len
(
t
)
for
t
in
round_block_tables_list
)
padded_tables
=
torch
.
tensor
(
[
t
+
[
0
]
*
(
max_blks
-
len
(
t
))
for
t
in
round_block_tables_list
]
)
for
dtype
in
_TENSOR_DTYPES
:
tolerance
=
_TOLERANCE_MAP
.
get
(
dtype
)
test_cases
.
append
(
TestCase
(
inputs
=
[
TensorSpec
.
from_tensor
(
query_base
.
shape
,
init_mode
=
TensorInitializer
.
MANUAL
,
set_tensor
=
query_base
.
clone
(),
dtype
=
dtype
,
),
TensorSpec
.
from_tensor
(
persistent_k
.
shape
,
init_mode
=
TensorInitializer
.
MANUAL
,
set_tensor
=
persistent_k
.
clone
(),
dtype
=
dtype
,
),
TensorSpec
.
from_tensor
(
persistent_v
.
shape
,
init_mode
=
TensorInitializer
.
MANUAL
,
set_tensor
=
persistent_v
.
clone
(),
dtype
=
dtype
,
),
TensorSpec
.
from_tensor
(
padded_tables
.
shape
,
init_mode
=
TensorInitializer
.
MANUAL
,
set_tensor
=
padded_tables
.
clone
(),
dtype
=
infinicore
.
int64
,
),
TensorSpec
.
from_tensor
(
current_history_lens
.
shape
,
init_mode
=
TensorInitializer
.
MANUAL
,
set_tensor
=
current_history_lens
.
clone
(),
dtype
=
infinicore
.
int64
,
),
TensorSpec
.
from_tensor
(
cu_seqlens_q
.
shape
,
init_mode
=
TensorInitializer
.
MANUAL
,
set_tensor
=
cu_seqlens_q
.
clone
(),
dtype
=
infinicore
.
int64
,
),
],
kwargs
=
{
"scale"
:
scale
},
tolerance
=
tolerance
,
description
=
f
"PagedAttentionPrefill_Round_
{
r
}
_
{
str
(
dtype
).
split
(
'.'
)[
-
1
]
}
"
,
)
)
current_history_lens
+=
q_lens
return
test_cases
def
ref_paged_attention_multi_turn
(
query
,
k_cache
,
v_cache
,
block_tables
,
history_lens
,
cu_seqlens_q
,
scale
):
output
=
torch
.
zeros_like
(
query
)
num_seqs
=
len
(
history_lens
)
block_size
=
k_cache
.
shape
[
2
]
for
i
in
range
(
num_seqs
):
q_start
,
q_end
=
cu_seqlens_q
[
i
].
item
(),
cu_seqlens_q
[
i
+
1
].
item
()
cur_q
=
query
[
q_start
:
q_end
]
h_len
=
history_lens
[
i
].
item
()
q_len
=
q_end
-
q_start
total_len
=
h_len
+
q_len
table
=
block_tables
[
i
]
keys
,
values
=
[],
[]
for
j
in
range
(
total_len
):
b_id
=
table
[
j
//
block_size
].
item
()
off
=
j
%
block_size
keys
.
append
(
k_cache
[
b_id
,
:,
off
,
:])
values
.
append
(
v_cache
[
b_id
,
:,
off
,
:])
K
=
torch
.
stack
(
keys
,
dim
=
0
)
V
=
torch
.
stack
(
values
,
dim
=
0
)
scores
=
torch
.
einsum
(
"qhd,khd->hqk"
,
cur_q
.
float
(),
K
.
float
())
*
scale
mask
=
torch
.
full
((
q_len
,
total_len
),
float
(
"-inf"
),
device
=
query
.
device
)
for
t
in
range
(
q_len
):
mask
[
t
,
:
h_len
+
t
+
1
]
=
0.0
attn
=
torch
.
softmax
(
scores
+
mask
.
unsqueeze
(
0
),
dim
=-
1
).
to
(
query
.
dtype
)
output
[
q_start
:
q_end
]
=
torch
.
einsum
(
"hqk,khd->qhd"
,
attn
,
V
)
return
output
class
OpTest
(
BaseOperatorTest
):
def
__init__
(
self
):
super
().
__init__
(
"PagedAttentionPrefill"
)
def
get_test_cases
(
self
):
return
parse_test_cases
()
def
torch_operator
(
self
,
query
,
k_cache
,
v_cache
,
block_tables
,
history_lens
,
cu_seqlens_q
,
scale
=
1.0
,
):
return
ref_paged_attention_multi_turn
(
query
,
k_cache
,
v_cache
,
block_tables
,
history_lens
,
cu_seqlens_q
,
scale
)
def
infinicore_operator
(
self
,
query
,
k_cache
,
v_cache
,
block_tables
,
history_lens
,
cu_seqlens_q
,
scale
=
1.0
,
):
out
=
infinicore
.
paged_attention_prefill
(
query
,
k_cache
,
v_cache
,
block_tables
,
history_lens
,
cu_seqlens_q
,
alibi_slopes
=
None
,
scale
=
scale
,
)
infinicore
.
sync_stream
()
return
out
def
main
():
"""Main entry point"""
runner
=
GenericTestRunner
(
OpTest
)
runner
.
run_and_exit
()
if
__name__
==
"__main__"
:
main
()
test/infiniop/libinfiniop/op_register.py
View file @
0a2839a2
...
@@ -1115,7 +1115,6 @@ def paged_attention_prefill_(lib):
...
@@ -1115,7 +1115,6 @@ def paged_attention_prefill_(lib):
infiniopTensorDescriptor_t
,
infiniopTensorDescriptor_t
,
infiniopTensorDescriptor_t
,
infiniopTensorDescriptor_t
,
infiniopTensorDescriptor_t
,
infiniopTensorDescriptor_t
,
infiniopTensorDescriptor_t
,
c_float
,
c_float
,
]
]
...
@@ -1139,7 +1138,6 @@ def paged_attention_prefill_(lib):
...
@@ -1139,7 +1138,6 @@ def paged_attention_prefill_(lib):
c_void_p
,
c_void_p
,
c_void_p
,
c_void_p
,
c_void_p
,
c_void_p
,
c_void_p
,
]
]
lib
.
infiniopDestroyPagedAttentionPrefillDescriptor
.
restype
=
c_int32
lib
.
infiniopDestroyPagedAttentionPrefillDescriptor
.
restype
=
c_int32
...
...
test/infiniop/paged_attention_prefill.py
View file @
0a2839a2
...
@@ -74,14 +74,15 @@ class SimpleCacheManager:
...
@@ -74,14 +74,15 @@ class SimpleCacheManager:
def
ref_paged_attention_multi_turn
(
def
ref_paged_attention_multi_turn
(
query_new
,
k_cache
,
v_cache
,
block_tables
,
seq_lens
,
new_lens
,
offset
,
scale
query_new
,
k_cache
,
v_cache
,
block_tables
,
seq_lens
,
cum_seq_lens_q
,
scale
):
):
block_size
=
k_cache
.
shape
[
2
]
block_size
=
k_cache
.
shape
[
2
]
outputs
=
torch
.
zeros_like
(
query_new
)
outputs
=
torch
.
zeros_like
(
query_new
)
for
i
in
range
(
len
(
offset
)
-
1
):
num_seqs
=
len
(
cum_seq_lens_q
)
-
1
total_len
=
seq_lens
[
i
].
item
()
for
i
in
range
(
num_seqs
):
num_new
=
new_lens
[
i
].
item
()
num_new
=
cum_seq_lens_q
[
i
+
1
].
item
()
-
cum_seq_lens_q
[
i
].
item
()
history_len
=
total_len
-
num_new
cache_len
=
seq_lens
[
i
].
item
()
total_len
=
seq_lens
[
i
].
item
()
+
num_new
table
=
block_tables
[
i
]
table
=
block_tables
[
i
]
keys_all
,
values_all
=
[],
[]
keys_all
,
values_all
=
[],
[]
...
@@ -93,19 +94,19 @@ def ref_paged_attention_multi_turn(
...
@@ -93,19 +94,19 @@ def ref_paged_attention_multi_turn(
K
=
torch
.
stack
(
keys_all
,
dim
=
0
)
K
=
torch
.
stack
(
keys_all
,
dim
=
0
)
V
=
torch
.
stack
(
values_all
,
dim
=
0
)
V
=
torch
.
stack
(
values_all
,
dim
=
0
)
Q
=
query_new
[
offset
[
i
]
:
offset
[
i
]
+
num_new
,
:,
:]
Q
=
query_new
[
cum_seq_lens_q
[
i
]
:
cum_seq_lens_q
[
i
+
1
]
,
:,
:]
scores
=
torch
.
einsum
(
"qhd,khd->hqk"
,
Q
,
K
).
float
()
*
scale
scores
=
torch
.
einsum
(
"qhd,khd->hqk"
,
Q
,
K
).
float
()
*
scale
mask
=
torch
.
full
((
num_new
,
total_len
),
float
(
"-inf"
),
device
=
Q
.
device
)
mask
=
torch
.
full
((
num_new
,
total_len
),
float
(
"-inf"
),
device
=
Q
.
device
)
for
q_idx
in
range
(
num_new
):
for
q_idx
in
range
(
num_new
):
mask
[
q_idx
,
:
history
_len
+
q_idx
+
1
]
=
0.0
mask
[
q_idx
,
:
cache
_len
+
q_idx
+
1
]
=
0.0
scores
=
scores
+
mask
.
unsqueeze
(
0
)
scores
=
scores
+
mask
.
unsqueeze
(
0
)
attn_weights
=
torch
.
softmax
(
scores
,
dim
=-
1
).
to
(
Q
.
dtype
)
attn_weights
=
torch
.
softmax
(
scores
,
dim
=-
1
).
to
(
Q
.
dtype
)
out
=
torch
.
einsum
(
"hqk,khd->qhd"
,
attn_weights
,
V
)
out
=
torch
.
einsum
(
"hqk,khd->qhd"
,
attn_weights
,
V
)
outputs
[
offset
[
i
]
:
offset
[
i
]
+
num_new
,
:,
:]
=
out
outputs
[
cum_seq_lens_q
[
i
]
:
cum_seq_lens_q
[
i
+
1
]
,
:,
:]
=
out
return
outputs
return
outputs
...
@@ -147,43 +148,43 @@ def test(
...
@@ -147,43 +148,43 @@ def test(
# Multi-turn testing loop
# Multi-turn testing loop
for
r
in
range
(
num_rounds
):
for
r
in
range
(
num_rounds
):
# Prepare dynamic inputs for this round
# Prepare dynamic inputs for this round
seq
_lens_cpu
=
torch
.
randint
(
query
_lens_cpu
=
torch
.
randint
(
1
,
max_step_len
+
1
,
(
num_seqs
,),
dtype
=
torch
.
int64
1
,
max_step_len
+
1
,
(
num_seqs
,),
dtype
=
torch
.
int64
)
)
q_total_tokens
=
seq
_lens_cpu
.
sum
().
item
()
q_total_tokens
=
query
_lens_cpu
.
sum
().
item
()
q_packed_tensors
=
torch
.
zeros
(
q_total_tokens
,
num_heads
,
head_size
)
q_packed_tensors
=
torch
.
zeros
(
q_total_tokens
,
num_heads
,
head_size
)
cache
_lens_list
=
[]
seq
_lens_list
=
[]
all_block_tables
=
[]
all_block_tables
=
[]
offset
_list
=
[]
cum_seq_lens_q
_list
=
[]
cu
r_offset
=
0
cu
m_q_lens
=
0
for
i
in
range
(
num_seqs
):
for
i
in
range
(
num_seqs
):
offset
_list
.
append
(
cu
r_offset
)
cum_seq_lens_q
_list
.
append
(
cu
m_q_lens
)
cur_new_len
=
seq_lens_cpu
[
i
].
item
()
cur_q_len
=
query_lens_cpu
[
i
].
item
()
table
,
cache_len
=
manager
.
allocate_slots
(
i
,
cur_new_len
)
table
,
total_len
=
manager
.
allocate_slots
(
i
,
cur_q_len
)
cache_lens_list
.
append
(
cache_len
)
cur_seq_lens
=
total_len
-
cur_q_len
seq_lens_list
.
append
(
cur_seq_lens
)
all_block_tables
.
append
(
table
)
all_block_tables
.
append
(
table
)
# Simulated KV insertion
# Simulated KV insertion
k_new
=
torch
.
randn
(
cur_
new
_len
,
num_kv_heads
,
head_size
)
k_new
=
torch
.
randn
(
cur_
q
_len
,
num_kv_heads
,
head_size
)
v_new
=
torch
.
randn
(
cur_
new
_len
,
num_kv_heads
,
head_size
)
v_new
=
torch
.
randn
(
cur_
q
_len
,
num_kv_heads
,
head_size
)
q_val
=
torch
.
randn
(
cur_
new
_len
,
num_heads
,
head_size
)
q_val
=
torch
.
randn
(
cur_
q
_len
,
num_heads
,
head_size
)
q_packed_tensors
[
cu
r_offset
:
cur_offset
+
cur_
new
_len
]
=
q_val
q_packed_tensors
[
cu
m_q_lens
:
cum_q_lens
+
cur_
q
_len
]
=
q_val
cu
r_offset
=
cur_offset
+
cur_
new
_len
cu
m_q_lens
=
cum_q_lens
+
cur_
q
_len
history_len
=
cache_len
-
cur_new_len
for
t
in
range
(
cur_q_len
):
for
t
in
range
(
cur_new_len
):
logical_pos
=
cur_seq_lens
+
t
logical_pos
=
history_len
+
t
b_id
=
table
[
logical_pos
//
block_size
]
b_id
=
table
[
logical_pos
//
block_size
]
off
=
logical_pos
%
block_size
off
=
logical_pos
%
block_size
k_cache
.
torch_tensor
()[
b_id
,
:,
off
,
:]
=
k_new
[
t
]
k_cache
.
torch_tensor
()[
b_id
,
:,
off
,
:]
=
k_new
[
t
]
v_cache
.
torch_tensor
()[
b_id
,
:,
off
,
:]
=
v_new
[
t
]
v_cache
.
torch_tensor
()[
b_id
,
:,
off
,
:]
=
v_new
[
t
]
offset
_list
.
append
(
cu
r_offset
)
cum_seq_lens_q
_list
.
append
(
cu
m_q_lens
)
k_cache
.
actual_tensor
().
copy_
(
k_cache
.
_torch_tensor
)
k_cache
.
actual_tensor
().
copy_
(
k_cache
.
_torch_tensor
)
v_cache
.
actual_tensor
().
copy_
(
v_cache
.
_torch_tensor
)
v_cache
.
actual_tensor
().
copy_
(
v_cache
.
_torch_tensor
)
...
@@ -193,13 +194,14 @@ def test(
...
@@ -193,13 +194,14 @@ def test(
out
=
TestTensor
.
from_torch
(
q_packed_tensors
,
dtype
,
device
)
out
=
TestTensor
.
from_torch
(
q_packed_tensors
,
dtype
,
device
)
out
.
actual_tensor
().
zero_
()
out
.
actual_tensor
().
zero_
()
cache
_lens
=
TestTensor
.
from_torch
(
seq
_lens
=
TestTensor
.
from_torch
(
torch
.
tensor
(
cache
_lens_list
,
dtype
=
torch
.
int64
),
InfiniDtype
.
I64
,
device
torch
.
tensor
(
seq
_lens_list
,
dtype
=
torch
.
int64
),
InfiniDtype
.
I64
,
device
)
)
seq_lens
=
TestTensor
.
from_torch
(
seq_lens_cpu
,
InfiniDtype
.
I64
,
device
)
offset
=
TestTensor
.
from_torch
(
cum_seq_lens_q
=
TestTensor
.
from_torch
(
torch
.
tensor
(
offset_list
,
dtype
=
torch
.
int64
),
InfiniDtype
.
I64
,
device
torch
.
tensor
(
cum_seq_lens_q_list
,
dtype
=
torch
.
int64
),
InfiniDtype
.
I64
,
device
,
)
)
max_blocks
=
max
(
len
(
t
)
for
t
in
all_block_tables
)
max_blocks
=
max
(
len
(
t
)
for
t
in
all_block_tables
)
...
@@ -215,9 +217,8 @@ def test(
...
@@ -215,9 +217,8 @@ def test(
k_cache
.
torch_tensor
(),
k_cache
.
torch_tensor
(),
v_cache
.
torch_tensor
(),
v_cache
.
torch_tensor
(),
block_tables
.
torch_tensor
(),
block_tables
.
torch_tensor
(),
cache_lens
.
torch_tensor
(),
seq_lens
.
torch_tensor
(),
seq_lens
.
torch_tensor
(),
offset
.
torch_tensor
(),
cum_seq_lens_q
.
torch_tensor
(),
scale
,
scale
,
)
)
...
@@ -234,10 +235,9 @@ def test(
...
@@ -234,10 +235,9 @@ def test(
k_cache
.
descriptor
,
k_cache
.
descriptor
,
v_cache
.
descriptor
,
v_cache
.
descriptor
,
block_tables
.
descriptor
,
block_tables
.
descriptor
,
cache_lens
.
descriptor
,
seq_lens
.
descriptor
,
seq_lens
.
descriptor
,
offset
.
descriptor
,
cum_seq_lens_q
.
descriptor
,
None
,
# alibi_slopes_desc
None
,
scale
,
scale
,
)
)
)
)
...
@@ -261,9 +261,8 @@ def test(
...
@@ -261,9 +261,8 @@ def test(
k_cache
.
data
(),
k_cache
.
data
(),
v_cache
.
data
(),
v_cache
.
data
(),
block_tables
.
data
(),
block_tables
.
data
(),
cache_lens
.
data
(),
seq_lens
.
data
(),
seq_lens
.
data
(),
offset
.
data
(),
cum_seq_lens_q
.
data
(),
None
,
None
,
None
,
None
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment