Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
99b940b2
Commit
99b940b2
authored
Dec 30, 2025
by
PanZezhong
Browse files
issue/847 paged attention prefill 一段式接口 (single-stage interface)
parent
e13ad8f9
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
176 additions
and
16 deletions
+176
-16
include/infinicore/common/hash.hpp
include/infinicore/common/hash.hpp
+10
-0
include/infinicore/ops.hpp
include/infinicore/ops.hpp
+1
-0
include/infinicore/ops/paged_attention_prefill.hpp
include/infinicore/ops/paged_attention_prefill.hpp
+18
-0
python/infinicore/__init__.py
python/infinicore/__init__.py
+2
-0
python/infinicore/ops/paged_attention_prefill.py
python/infinicore/ops/paged_attention_prefill.py
+46
-0
src/infinicore/ops/paged_attention/paged_attention_infiniop.cc
...nfinicore/ops/paged_attention/paged_attention_infiniop.cc
+1
-1
src/infinicore/ops/paged_attention_prefill/paged_attention_prefill.cc
...re/ops/paged_attention_prefill/paged_attention_prefill.cc
+28
-0
src/infinicore/ops/paged_attention_prefill/paged_attention_prefill_infiniop.cc
...ged_attention_prefill/paged_attention_prefill_infiniop.cc
+55
-0
src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh
src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh
+14
-14
xmake/nvidia.lua
xmake/nvidia.lua
+1
-1
No files found.
include/infinicore/common/hash.hpp
View file @
99b940b2
...
...
@@ -2,6 +2,7 @@
#include "../tensor.hpp"
#include <optional>
#include <type_traits>
namespace
infinicore
{
...
...
@@ -24,6 +25,15 @@ inline void hash_combine(size_t &seed, Tensor tensor) {
}
}
// Specialization for optional
template
<
typename
T
>
inline
void
hash_combine
(
size_t
&
seed
,
const
std
::
optional
<
T
>
&
opt
)
{
hash_combine
(
seed
,
opt
.
has_value
());
if
(
opt
)
{
hash_combine
(
seed
,
*
opt
);
}
}
// Specialization for std::string
inline
void
hash_combine
(
size_t
&
seed
,
const
std
::
string
&
str
)
{
hash_combine
(
seed
,
std
::
hash
<
std
::
string
>
{}(
str
));
...
...
include/infinicore/ops.hpp
View file @
99b940b2
...
...
@@ -6,6 +6,7 @@
#include "ops/matmul.hpp"
#include "ops/ones.hpp"
#include "ops/paged_attention.hpp"
#include "ops/paged_attention_prefill.hpp"
#include "ops/paged_caching.hpp"
#include "ops/random_sample.hpp"
#include "ops/rearrange.hpp"
...
...
include/infinicore/ops/paged_attention_prefill.hpp
0 → 100644
View file @
99b940b2
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

#include <optional>

namespace infinicore::op {

/// Paged-attention prefill operator (single-stage interface, issue/847).
/// Backends register an implementation per device type via dispatcher();
/// execute() routes a call to the implementation for the output's device.
class PagedAttentionPrefill {
public:
    /// Signature shared by all backend implementations:
    /// (out, q, k_cache, v_cache, block_tables, cache_lens,
    ///  seq_lens, seq_offsets, alibi_slopes, scale)
    using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor,
                            Tensor, Tensor, std::optional<Tensor>, float);

    /// Validate the arguments and run the backend registered for
    /// out->device(). (Parameter `scale` named here for consistency with
    /// the free-function declarations below; interface unchanged.)
    static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
                        Tensor block_tables, Tensor cache_lens, Tensor seq_lens,
                        Tensor seq_offsets, std::optional<Tensor> alibi_slopes,
                        float scale);

    /// Per-device-type registry of backend implementations.
    static common::OpDispatcher<schema> &dispatcher();
};

/// Out-of-place prefill: allocates and returns the output tensor
/// (same shape/dtype/device as `q`, per the .cc implementation).
Tensor paged_attention_prefill(Tensor q, Tensor k_cache, Tensor v_cache,
                               Tensor block_tables, Tensor cache_lens,
                               Tensor seq_lens, Tensor seq_offsets,
                               std::optional<Tensor> alibi_slopes, float scale);

/// In-place prefill: writes the attention result into `out`.
void paged_attention_prefill_(Tensor out, Tensor q, Tensor k_cache,
                              Tensor v_cache, Tensor block_tables,
                              Tensor cache_lens, Tensor seq_lens,
                              Tensor seq_offsets,
                              std::optional<Tensor> alibi_slopes, float scale);

} // namespace infinicore::op
python/infinicore/__init__.py
View file @
99b940b2
...
...
@@ -45,6 +45,7 @@ from infinicore.ops.matmul import matmul
from
infinicore.ops.mul
import
mul
from
infinicore.ops.narrow
import
narrow
from
infinicore.ops.paged_attention
import
paged_attention
from
infinicore.ops.paged_attention_prefill
import
paged_attention_prefill
from
infinicore.ops.paged_caching
import
paged_caching
from
infinicore.ops.rearrange
import
rearrange
from
infinicore.ops.squeeze
import
squeeze
...
...
@@ -119,6 +120,7 @@ __all__ = [
"from_torch"
,
"paged_caching"
,
"paged_attention"
,
"paged_attention_prefill"
,
"ones"
,
"strided_empty"
,
"strided_from_blob"
,
...
...
python/infinicore/ops/paged_attention_prefill.py
0 → 100644
View file @
99b940b2
from
infinicore.lib
import
_infinicore
from
infinicore.tensor
import
Tensor
def
paged_attention_prefill
(
q
:
Tensor
,
k_cache
:
Tensor
,
v_cache
:
Tensor
,
block_tables
:
Tensor
,
cache_lens
:
Tensor
,
seq_lens
:
Tensor
,
seq_offsets
:
Tensor
,
alibi_slopes
:
Tensor
|
None
=
None
,
scale
:
float
=
1.0
,
*
,
out
:
Tensor
|
None
=
None
,
):
if
out
is
None
:
return
Tensor
(
_infinicore
.
paged_attention_prefill
(
q
.
_underlying
,
k_cache
.
_underlying
,
v_cache
.
_underlying
,
block_tables
.
_underlying
,
cache_lens
.
_underlying
,
seq_lens
.
_underlying
,
seq_offsets
.
_underlying
,
alibi_slopes
.
_underlying
if
alibi_slopes
is
not
None
else
None
,
scale
,
)
)
_infinicore
.
paged_attention_prefill_
(
out
.
_underlying
,
q
.
_underlying
,
k_cache
.
_underlying
,
v_cache
.
_underlying
,
block_tables
.
_underlying
,
cache_lens
.
_underlying
,
seq_lens
.
_underlying
,
seq_offsets
.
_underlying
,
alibi_slopes
.
_underlying
if
alibi_slopes
is
not
None
else
None
,
scale
,
)
return
out
src/infinicore/ops/paged_attention/paged_attention_infiniop.cc
View file @
99b940b2
...
...
@@ -16,7 +16,7 @@ thread_local common::OpCache<size_t, infiniopPagedAttentionDescriptor_t> caches(
});
void
calculate
(
Tensor
out
,
Tensor
q
,
Tensor
k_cache
,
Tensor
v_cache
,
Tensor
block_tables
,
Tensor
cache_lens
,
std
::
optional
<
Tensor
>
alibi_slopes
,
float
scale
)
{
size_t
seed
=
hash_combine
(
out
,
q
,
k_cache
,
v_cache
,
block_tables
,
cache_lens
);
size_t
seed
=
hash_combine
(
out
,
q
,
k_cache
,
v_cache
,
block_tables
,
cache_lens
,
alibi_slopes
,
scale
);
auto
device
=
context
::
getDevice
();
auto
&
cache
=
caches
.
getCache
(
device
);
...
...
src/infinicore/ops/paged_attention_prefill/paged_attention_prefill.cc
0 → 100644
View file @
99b940b2
#include "infinicore/ops/paged_attention_prefill.hpp"
#include "../../utils.hpp"
namespace infinicore::op {

common::OpDispatcher<PagedAttentionPrefill::schema> &PagedAttentionPrefill::dispatcher() {
    // Function-local static: a single registry shared by every backend TU.
    static common::OpDispatcher<PagedAttentionPrefill::schema> dispatcher_;
    return dispatcher_;
}

void PagedAttentionPrefill::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
                                    Tensor block_tables, Tensor cache_lens, Tensor seq_lens,
                                    Tensor seq_offsets, std::optional<Tensor> alibi_slopes,
                                    float scale) {
    // Fix: seq_lens and seq_offsets are part of this op's interface but were
    // missing from the same-device check; include them so device mismatches
    // fail here rather than inside the kernel.
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables,
                                          cache_lens, seq_lens, seq_offsets);
    // alibi_slopes is optional, so it gets its own guarded check.
    if (alibi_slopes.has_value()) {
        INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, alibi_slopes.value());
    }
    infinicore::context::setDevice(out->device());
    // Route to the implementation registered for the output tensor's device type.
    dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables,
                                                 cache_lens, seq_lens, seq_offsets,
                                                 alibi_slopes, scale);
}

// Out-of-place variant: the result mirrors q's shape/dtype/device.
Tensor paged_attention_prefill(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables,
                               Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets,
                               std::optional<Tensor> alibi_slopes, float scale) {
    auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
    paged_attention_prefill_(out, q, k_cache, v_cache, block_tables, cache_lens,
                             seq_lens, seq_offsets, alibi_slopes, scale);
    return out;
}

// In-place variant: thin forwarding wrapper over PagedAttentionPrefill::execute.
void paged_attention_prefill_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
                              Tensor block_tables, Tensor cache_lens, Tensor seq_lens,
                              Tensor seq_offsets, std::optional<Tensor> alibi_slopes,
                              float scale) {
    PagedAttentionPrefill::execute(out, q, k_cache, v_cache, block_tables, cache_lens,
                                   seq_lens, seq_offsets, alibi_slopes, scale);
}

} // namespace infinicore::op
src/infinicore/ops/paged_attention_prefill/paged_attention_prefill_infiniop.cc
0 → 100644
View file @
99b940b2
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/paged_attention_prefill.hpp"
#include <infiniop.h>
namespace infinicore::op::paged_attention_prefill_impl::infiniop {

// Thread-local cache of infiniop descriptors keyed by a hash of the call's
// arguments; the cleanup lambda destroys descriptors on eviction.
thread_local common::OpCache<size_t, infiniopPagedAttentionPrefillDescriptor_t> caches(
    100, // capacity
    [](infiniopPagedAttentionPrefillDescriptor_t &desc) {
        if (desc != nullptr) {
            INFINICORE_CHECK_ERROR(infiniopDestroyPagedAttentionPrefillDescriptor(desc));
            desc = nullptr;
        }
    });

// infiniop backend: create (or reuse) a descriptor for this argument
// combination, size and allocate the workspace, then launch the prefill
// kernel on the current stream.
void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
               Tensor block_tables, Tensor cache_lens, Tensor seq_lens,
               Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
    // Cache key covers every argument that feeds descriptor creation.
    size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, cache_lens,
                               seq_lens, seq_offsets, alibi_slopes, scale);

    auto device = context::getDevice();
    auto &cache = caches.getCache(device);

    infiniopPagedAttentionPrefillDescriptor_t desc = nullptr;
    auto cached = cache.get(seed);
    if (cached) {
        desc = *cached;
    } else {
        INFINICORE_CHECK_ERROR(infiniopCreatePagedAttentionPrefillDescriptor(
            context::getInfiniopHandle(device), &desc,
            out->desc(), q->desc(), k_cache->desc(), v_cache->desc(),
            block_tables->desc(), cache_lens->desc(), seq_lens->desc(),
            seq_offsets->desc(),
            alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr,
            scale));
        cache.put(seed, desc);
    }

    size_t workspace_size = 0;
    INFINICORE_CHECK_ERROR(infiniopGetPagedAttentionPrefillWorkspaceSize(desc, &workspace_size));
    std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);

    INFINICORE_CHECK_ERROR(infiniopPagedAttentionPrefill(
        desc, workspace->data(), workspace_size,
        out->data(), q->data(), k_cache->data(), v_cache->data(),
        block_tables->data(), cache_lens->data(), seq_lens->data(),
        seq_offsets->data(),
        alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr,
        context::getStream()));
}

// Static self-registration of this backend for all device types.
// NOTE(review): second argument `false` presumably means "do not override an
// existing registration" — confirm against OpDispatcher::registerAll.
static bool registered = []() {
    PagedAttentionPrefill::dispatcher().registerAll(&calculate, false);
    return true;
}();

} // namespace infinicore::op::paged_attention_prefill_impl::infiniop
src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh
View file @
99b940b2
...
...
@@ -4,10 +4,10 @@
namespace
op
::
paged_attention_prefill
::
cuda
{
// 辅助函数:二分查找确定当前 global_token_idx 属于哪个 sequence
__device__
__forceinline__
in
t
find_seq_id
(
in
t
token_idx
,
const
int64_t
*
offset
,
in
t
num_seqs
)
{
in
t
low
=
0
,
high
=
num_seqs
-
1
;
__device__
__forceinline__
size_
t
find_seq_id
(
size_
t
token_idx
,
const
int64_t
*
offset
,
size_
t
num_seqs
)
{
size_
t
low
=
0
,
high
=
num_seqs
-
1
;
while
(
low
<=
high
)
{
in
t
mid
=
(
low
+
high
)
>>
1
;
size_
t
mid
=
(
low
+
high
)
>>
1
;
if
(
token_idx
>=
offset
[
mid
]
&&
token_idx
<
offset
[
mid
+
1
])
{
return
mid
;
}
else
if
(
token_idx
<
offset
[
mid
])
{
...
...
@@ -32,22 +32,22 @@ __global__ void pagedAttentionPrefillKernel(
const
size_t
num_seqs
)
{
// --- 使用 2D Grid 坐标 ---
const
in
t
global_token_idx
=
blockIdx
.
x
;
// 展平后的全局 token 索引
const
in
t
head_idx
=
blockIdx
.
y
;
// Head 索引
const
in
t
dim_idx
=
threadIdx
.
x
;
// Head 内部维度
const
size_
t
global_token_idx
=
blockIdx
.
x
;
// 展平后的全局 token 索引
const
size_
t
head_idx
=
blockIdx
.
y
;
// Head 索引
const
size_
t
dim_idx
=
threadIdx
.
x
;
// Head 内部维度
if
(
dim_idx
>=
head_size
)
{
return
;
}
// --- 通过二分查找 offset 找到所属的 seq_idx ---
in
t
seq_idx
=
find_seq_id
(
global_token_idx
,
offset_
,
num_seqs
);
size_
t
seq_idx
=
find_seq_id
(
global_token_idx
,
offset_
,
num_seqs
);
// --- 获取该 Sequence 本次 Prefill 的长度
const
int64_t
cur_new_len
=
seq_lens_
[
seq_idx
];
// --- 该 token 在当前序列中的相对位置
in
t
q_token_idx
=
global_token_idx
-
offset_
[
seq_idx
];
size_
t
q_token_idx
=
global_token_idx
-
offset_
[
seq_idx
];
const
Tdata
*
q_ptr_base
=
q_
+
global_token_idx
*
num_heads
*
head_size
+
head_idx
*
head_size
;
Tdata
*
out_ptr
=
out_
+
global_token_idx
*
num_heads
*
head_size
+
head_idx
*
head_size
;
...
...
@@ -65,14 +65,14 @@ __global__ void pagedAttentionPrefillKernel(
// Pass 1: 计算 Score 并找最大值
Tcompute
max_score
=
-
FLT_MAX
;
for
(
in
t
t
=
0
;
t
<=
causal_limit
;
++
t
)
{
for
(
size_
t
t
=
0
;
t
<=
causal_limit
;
++
t
)
{
const
int64_t
b_idx
=
t
/
block_size
;
const
int64_t
t_off
=
t
%
block_size
;
const
int64_t
physical_block_id
=
block_table
[
b_idx
];
const
Tdata
*
k_vec
=
k_cache_
+
physical_block_id
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
+
t_off
*
head_size
;
Tcompute
score
=
0.0
f
;
for
(
in
t
d
=
0
;
d
<
head_size
;
++
d
)
{
for
(
size_
t
d
=
0
;
d
<
head_size
;
++
d
)
{
score
+=
static_cast
<
Tcompute
>
(
q_ptr_base
[
d
])
*
static_cast
<
Tcompute
>
(
k_vec
[
d
]);
}
score
*=
static_cast
<
Tcompute
>
(
scale
);
...
...
@@ -86,14 +86,14 @@ __global__ void pagedAttentionPrefillKernel(
// Pass 2: 计算 Sum of Exp
Tcompute
sum_exp
=
0.0
f
;
for
(
in
t
t
=
0
;
t
<=
causal_limit
;
++
t
)
{
for
(
size_
t
t
=
0
;
t
<=
causal_limit
;
++
t
)
{
const
int64_t
b_idx
=
t
/
block_size
;
const
int64_t
t_off
=
t
%
block_size
;
const
int64_t
physical_block_id
=
block_table
[
b_idx
];
const
Tdata
*
k_vec
=
k_cache_
+
physical_block_id
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
+
t_off
*
head_size
;
Tcompute
score
=
0.0
f
;
for
(
in
t
d
=
0
;
d
<
head_size
;
++
d
)
{
for
(
size_
t
d
=
0
;
d
<
head_size
;
++
d
)
{
score
+=
static_cast
<
Tcompute
>
(
q_ptr_base
[
d
])
*
static_cast
<
Tcompute
>
(
k_vec
[
d
]);
}
score
*=
static_cast
<
Tcompute
>
(
scale
);
...
...
@@ -106,14 +106,14 @@ __global__ void pagedAttentionPrefillKernel(
// Pass 3: 加权求和得到输出
Tcompute
acc
=
0.0
f
;
Tcompute
inv_sum
=
1.0
f
/
(
sum_exp
+
1e-6
f
);
for
(
in
t
t
=
0
;
t
<=
causal_limit
;
++
t
)
{
for
(
size_
t
t
=
0
;
t
<=
causal_limit
;
++
t
)
{
const
int64_t
b_idx
=
t
/
block_size
;
const
int64_t
t_off
=
t
%
block_size
;
const
int64_t
physical_block_id
=
block_table
[
b_idx
];
const
Tdata
*
k_vec
=
k_cache_
+
physical_block_id
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
+
t_off
*
head_size
;
Tcompute
score
=
0.0
f
;
for
(
in
t
d
=
0
;
d
<
head_size
;
++
d
)
{
for
(
size_
t
d
=
0
;
d
<
head_size
;
++
d
)
{
score
+=
static_cast
<
Tcompute
>
(
q_ptr_base
[
d
])
*
static_cast
<
Tcompute
>
(
k_vec
[
d
]);
}
score
*=
static_cast
<
Tcompute
>
(
scale
);
...
...
xmake/nvidia.lua
View file @
99b940b2
...
...
@@ -55,7 +55,7 @@ target("infiniop-nvidia")
end
end
add_cuflags
(
"-Xcompiler=-Wno-error=deprecated-declarations"
)
add_cuflags
(
"-Xcompiler=-Wno-error=deprecated-declarations"
,
"-Xcompiler=-Wno-error=unused-function"
)
local
arch_opt
=
get_config
(
"cuda_arch"
)
if
arch_opt
and
type
(
arch_opt
)
==
"string"
then
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment