jerrrrry / infinicore · Commit caa61e9e

Merge pull request #868 from InfiniTensor/issue/847

Issue/847 paged attention prefill one-stage interface

Authored Dec 30, 2025 by PanZezhong1725, committed by GitHub on Dec 30, 2025
Parents: 31c0af3f, 99b940b2

Showing 15 changed files with 206 additions and 46 deletions (+206 / -46)
include/infinicore/common/hash.hpp (+10 / -0)
include/infinicore/ops.hpp (+1 / -0)
include/infinicore/ops/paged_attention.hpp (+3 / -3)
include/infinicore/ops/paged_attention_prefill.hpp (+18 / -0)
python/infinicore/__init__.py (+2 / -0)
python/infinicore/ops/paged_attention.py (+3 / -3)
python/infinicore/ops/paged_attention_prefill.py (+46 / -0)
src/infinicore/ops/paged_attention/paged_attention.cc (+7 / -7)
src/infinicore/ops/paged_attention/paged_attention_infiniop.cc (+4 / -4)
src/infinicore/ops/paged_attention_prefill/paged_attention_prefill.cc (+28 / -0)
src/infinicore/ops/paged_attention_prefill/paged_attention_prefill_infiniop.cc (+55 / -0)
src/infinicore/pybind11/ops/paged_attention.hpp (+6 / -6)
src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh (+14 / -14)
test/infinicore/ops/paged_attention.py (+8 / -8)
xmake/nvidia.lua (+1 / -1)
include/infinicore/common/hash.hpp
@@ -2,6 +2,7 @@
 #include "../tensor.hpp"

+#include <optional>
 #include <type_traits>

 namespace infinicore {
@@ -24,6 +25,15 @@ inline void hash_combine(size_t &seed, Tensor tensor) {
     }
 }

+// Specialization for optional
+template <typename T>
+inline void hash_combine(size_t &seed, const std::optional<T> &opt) {
+    hash_combine(seed, opt.has_value());
+    if (opt) {
+        hash_combine(seed, *opt);
+    }
+}
+
 // Specialization for std::string
 inline void hash_combine(size_t &seed, const std::string &str) {
     hash_combine(seed, std::hash<std::string>{}(str));

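The new optional overload hashes the engagement flag before the payload, so a disengaged optional and an engaged one cannot accidentally fold to the same seed; this is what lets the infiniop backends below include `alibi_slopes` and `scale` in their cache keys. A minimal self-contained sketch of the same pattern (the boost-style `hash_combine` primitive here is a local stand-in, not the one defined in hash.hpp):

#include <cstddef>
#include <functional>
#include <iostream>
#include <optional>

// Local stand-in for the hash_combine primitive (boost-style seed mixing).
inline void hash_combine(size_t &seed, size_t value) {
    seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

// Same shape as the new specialization: hash has_value() first, then the payload.
template <typename T>
inline void hash_combine(size_t &seed, const std::optional<T> &opt) {
    hash_combine(seed, static_cast<size_t>(opt.has_value()));
    if (opt) {
        hash_combine(seed, std::hash<T>{}(*opt));
    }
}

int main() {
    size_t a = 0, b = 0;
    hash_combine(a, std::optional<int>{});   // disengaged
    hash_combine(b, std::optional<int>{42}); // engaged
    std::cout << (a != b) << "\n"; // seeds diverge: prints 1
}
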
include/infinicore/ops.hpp
@@ -6,6 +6,7 @@
 #include "ops/matmul.hpp"
 #include "ops/ones.hpp"
 #include "ops/paged_attention.hpp"
+#include "ops/paged_attention_prefill.hpp"
 #include "ops/paged_caching.hpp"
 #include "ops/random_sample.hpp"
 #include "ops/rearrange.hpp"

include/infinicore/ops/paged_attention.hpp
@@ -9,10 +9,10 @@ namespace infinicore::op {
 class PagedAttention {
 public:
     using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>, float);

-    static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float);
+    static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float);

     static common::OpDispatcher<schema> &dispatcher();
 };

-Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale);
-void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale);
+Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale);
+void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale);

 } // namespace infinicore::op

include/infinicore/ops/paged_attention_prefill.hpp (new file, mode 100644)

#pragma once

#include "../device.hpp"
#include "common/op.hpp"

#include <optional>

namespace infinicore::op {

class PagedAttentionPrefill {
public:
    using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>, float);

    static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float);

    static common::OpDispatcher<schema> &dispatcher();
};

Tensor paged_attention_prefill(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale);
void paged_attention_prefill_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale);

} // namespace infinicore::op

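For orientation, a hedged sketch of how the one-stage entry point declared above would be called from C++. Only the function signature comes from this header; the tensors are assumed to be prepared elsewhere (q flattened over all new tokens, paged k/v caches, int64 length and offset tensors, following the test file):

// Hypothetical call site for the one-stage prefill interface.
#include "infinicore/ops.hpp"

using namespace infinicore;

Tensor run_prefill(Tensor q, Tensor k_cache, Tensor v_cache,
                   Tensor block_tables, Tensor cache_lens,
                   Tensor seq_lens, Tensor seq_offsets, float scale) {
    // Allocates the output and dispatches to the registered backend in one call.
    return op::paged_attention_prefill(q, k_cache, v_cache, block_tables,
                                       cache_lens, seq_lens, seq_offsets,
                                       /*alibi_slopes=*/std::nullopt, scale);
}
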
python/infinicore/__init__.py
@@ -45,6 +45,7 @@ from infinicore.ops.matmul import matmul
 from infinicore.ops.mul import mul
 from infinicore.ops.narrow import narrow
 from infinicore.ops.paged_attention import paged_attention
+from infinicore.ops.paged_attention_prefill import paged_attention_prefill
 from infinicore.ops.paged_caching import paged_caching
 from infinicore.ops.rearrange import rearrange
 from infinicore.ops.squeeze import squeeze
@@ -119,6 +120,7 @@ __all__ = [
     "from_torch",
     "paged_caching",
     "paged_attention",
+    "paged_attention_prefill",
     "ones",
     "strided_empty",
     "strided_from_blob",

python/infinicore/ops/paged_attention.py
@@ -7,7 +7,7 @@ def paged_attention(
     k_cache: Tensor,
     v_cache: Tensor,
     block_tables: Tensor,
-    seq_lens: Tensor,
+    cache_lens: Tensor,
     alibi_slopes: Tensor | None = None,
     scale: float = 1.0,
     *,
@@ -20,7 +20,7 @@ def paged_attention(
             k_cache._underlying,
             v_cache._underlying,
             block_tables._underlying,
-            seq_lens._underlying,
+            cache_lens._underlying,
             alibi_slopes._underlying if alibi_slopes is not None else None,
             scale,
         )
@@ -32,7 +32,7 @@ def paged_attention(
         k_cache._underlying,
         v_cache._underlying,
         block_tables._underlying,
-        seq_lens._underlying,
+        cache_lens._underlying,
         alibi_slopes._underlying if alibi_slopes is not None else None,
         scale,
     )

python/infinicore/ops/paged_attention_prefill.py (new file, mode 100644)

from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def paged_attention_prefill(
    q: Tensor,
    k_cache: Tensor,
    v_cache: Tensor,
    block_tables: Tensor,
    cache_lens: Tensor,
    seq_lens: Tensor,
    seq_offsets: Tensor,
    alibi_slopes: Tensor | None = None,
    scale: float = 1.0,
    *,
    out: Tensor | None = None,
):
    if out is None:
        return Tensor(
            _infinicore.paged_attention_prefill(
                q._underlying,
                k_cache._underlying,
                v_cache._underlying,
                block_tables._underlying,
                cache_lens._underlying,
                seq_lens._underlying,
                seq_offsets._underlying,
                alibi_slopes._underlying if alibi_slopes is not None else None,
                scale,
            )
        )

    _infinicore.paged_attention_prefill_(
        out._underlying,
        q._underlying,
        k_cache._underlying,
        v_cache._underlying,
        block_tables._underlying,
        cache_lens._underlying,
        seq_lens._underlying,
        seq_offsets._underlying,
        alibi_slopes._underlying if alibi_slopes is not None else None,
        scale,
    )
    return out

src/infinicore/ops/paged_attention/paged_attention.cc
@@ -9,20 +9,20 @@ common::OpDispatcher<PagedAttention::schema> &PagedAttention::dispatcher() {
     return dispatcher_;
 };

-void PagedAttention::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
-    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, seq_lens);
+void PagedAttention::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, cache_lens);
     infinicore::context::setDevice(out->device());
-    dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale);
+    dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
 }

-Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
+Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
     auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
-    paged_attention_(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale);
+    paged_attention_(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
     return out;
 }

-void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
-    PagedAttention::execute(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale);
+void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
+    PagedAttention::execute(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
 }

 } // namespace infinicore::op

src/infinicore/ops/paged_attention/paged_attention_infiniop.cc
@@ -15,8 +15,8 @@ thread_local common::OpCache<size_t, infiniopPagedAttentionDescriptor_t> caches(
     }
 });

-void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
-    size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, seq_lens);
+void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
+    size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
     auto device = context::getDevice();
     auto &cache = caches.getCache(device);
@@ -27,7 +27,7 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc
     if (!desc_opt) {
         INFINICORE_CHECK_ERROR(infiniopCreatePagedAttentionDescriptor(
-            context::getInfiniopHandle(device), &desc, out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(), seq_lens->desc(),
+            context::getInfiniopHandle(device), &desc, out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(), cache_lens->desc(),
             alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr, scale));
         cache.put(seed, desc);
@@ -41,7 +41,7 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc
     INFINICORE_CHECK_ERROR(infiniopPagedAttention(
-        desc, workspace->data(), workspace_size, out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), seq_lens->data(),
+        desc, workspace->data(), workspace_size, out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), cache_lens->data(),
         alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr, context::getStream()));
 }

src/infinicore/ops/paged_attention_prefill/paged_attention_prefill.cc (new file, mode 100644)

#include "infinicore/ops/paged_attention_prefill.hpp"

#include "../../utils.hpp"

namespace infinicore::op {

common::OpDispatcher<PagedAttentionPrefill::schema> &PagedAttentionPrefill::dispatcher() {
    static common::OpDispatcher<PagedAttentionPrefill::schema> dispatcher_;
    return dispatcher_;
};

void PagedAttentionPrefill::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, cache_lens);
    infinicore::context::setDevice(out->device());
    dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets, alibi_slopes, scale);
}

Tensor paged_attention_prefill(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
    auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
    paged_attention_prefill_(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets, alibi_slopes, scale);
    return out;
}

void paged_attention_prefill_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
    PagedAttentionPrefill::execute(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets, alibi_slopes, scale);
}

} // namespace infinicore::op

src/infinicore/ops/paged_attention_prefill/paged_attention_prefill_infiniop.cc (new file, mode 100644)

#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/paged_attention_prefill.hpp"

#include <infiniop.h>

namespace infinicore::op::paged_attention_prefill_impl::infiniop {

thread_local common::OpCache<size_t, infiniopPagedAttentionPrefillDescriptor_t> caches(
    100, // capacity
    [](infiniopPagedAttentionPrefillDescriptor_t &desc) {
        if (desc != nullptr) {
            INFINICORE_CHECK_ERROR(infiniopDestroyPagedAttentionPrefillDescriptor(desc));
            desc = nullptr;
        }
    });

void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
    size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets, alibi_slopes, scale);
    auto device = context::getDevice();
    auto &cache = caches.getCache(device);
    auto desc_opt = cache.get(seed);
    infiniopPagedAttentionPrefillDescriptor_t desc = nullptr;

    if (!desc_opt) {
        INFINICORE_CHECK_ERROR(infiniopCreatePagedAttentionPrefillDescriptor(
            context::getInfiniopHandle(device), &desc,
            out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(),
            cache_lens->desc(), seq_lens->desc(), seq_offsets->desc(),
            alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr, scale));
        cache.put(seed, desc);
    } else {
        desc = *desc_opt;
    }

    size_t workspace_size = 0;
    INFINICORE_CHECK_ERROR(infiniopGetPagedAttentionPrefillWorkspaceSize(desc, &workspace_size));
    std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);

    INFINICORE_CHECK_ERROR(infiniopPagedAttentionPrefill(
        desc, workspace->data(), workspace_size,
        out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(),
        cache_lens->data(), seq_lens->data(), seq_offsets->data(),
        alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr, context::getStream()));
}

static bool registered = []() {
    PagedAttentionPrefill::dispatcher().registerAll(&calculate, false);
    return true;
}();

} // namespace infinicore::op::paged_attention_prefill_impl::infiniop

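This backend follows the same descriptor-caching pattern as the existing paged_attention backend: fold every argument that affects descriptor creation into a hash seed, look the descriptor up in a per-device, thread-local cache, and create it only on a miss. A reduced sketch of that pattern with standard containers (the names below are illustrative stand-ins, not the OpCache API):

#include <cstddef>
#include <unordered_map>

// Illustrative descriptor type and factory; stand-ins for the infiniop handles.
struct Descriptor { /* opaque backend state */ };
Descriptor *create_descriptor() { return new Descriptor{}; }

// Per-thread cache keyed by the hash seed of all shape-affecting arguments.
thread_local std::unordered_map<size_t, Descriptor *> cache;

Descriptor *get_or_create(size_t seed) {
    auto it = cache.find(seed);
    if (it != cache.end()) {
        return it->second; // hit: reuse the descriptor built for identical metadata
    }
    Descriptor *desc = create_descriptor(); // miss: build once, then memoize
    cache.emplace(seed, desc);
    return desc;
}
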
src/infinicore/pybind11/ops/paged_attention.hpp
@@ -8,21 +8,21 @@ namespace py = pybind11;

 namespace infinicore::ops {

-Tensor py_paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, pybind11::object alibi_slopes, float scale) {
+Tensor py_paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, pybind11::object alibi_slopes, float scale) {
     std::optional<Tensor> alibi_slopes_tensor = std::nullopt;
     if (!alibi_slopes.is_none()) {
         alibi_slopes_tensor = alibi_slopes.cast<Tensor>();
     }
-    return op::paged_attention(q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes_tensor, scale);
+    return op::paged_attention(q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes_tensor, scale);
 }

-void py_paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, pybind11::object alibi_slopes, float scale) {
+void py_paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, pybind11::object alibi_slopes, float scale) {
     std::optional<Tensor> alibi_slopes_tensor = std::nullopt;
     if (!alibi_slopes.is_none()) {
         alibi_slopes_tensor = alibi_slopes.cast<Tensor>();
     }
-    op::paged_attention_(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes_tensor, scale);
+    op::paged_attention_(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes_tensor, scale);
 }

 inline void bind_paged_attention(py::module &m) {
@@ -32,7 +32,7 @@ inline void bind_paged_attention(py::module &m) {
           py::arg("k_cache"),
           py::arg("v_cache"),
           py::arg("block_tables"),
-          py::arg("seq_lens"),
+          py::arg("cache_lens"),
           py::arg("alibi_slopes"),
           py::arg("scale"),
           R"doc(Paged attention of query and key cache tensors.)doc");
@@ -44,7 +44,7 @@ inline void bind_paged_attention(py::module &m) {
           py::arg("k_cache"),
           py::arg("v_cache"),
           py::arg("block_tables"),
-          py::arg("seq_lens"),
+          py::arg("cache_lens"),
           py::arg("alibi_slopes"),
           py::arg("scale"),
           R"doc(In-place paged attention of query and key cache tensors.)doc");

src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh
@@ -4,10 +4,10 @@
 namespace op::paged_attention_prefill::cuda {

 // Helper: binary search to determine which sequence global_token_idx belongs to
-__device__ __forceinline__ int find_seq_id(int token_idx, const int64_t *offset, int num_seqs) {
-    int low = 0, high = num_seqs - 1;
+__device__ __forceinline__ size_t find_seq_id(size_t token_idx, const int64_t *offset, size_t num_seqs) {
+    size_t low = 0, high = num_seqs - 1;
     while (low <= high) {
-        int mid = (low + high) >> 1;
+        size_t mid = (low + high) >> 1;
         if (token_idx >= offset[mid] && token_idx < offset[mid + 1]) {
             return mid;
         } else if (token_idx < offset[mid]) {
@@ -32,22 +32,22 @@ __global__ void pagedAttentionPrefillKernel(
     const size_t num_seqs) {

     // --- Use 2D grid coordinates ---
-    const int global_token_idx = blockIdx.x; // flattened global token index
-    const int head_idx = blockIdx.y;         // head index
-    const int dim_idx = threadIdx.x;         // dimension within the head
+    const size_t global_token_idx = blockIdx.x; // flattened global token index
+    const size_t head_idx = blockIdx.y;         // head index
+    const size_t dim_idx = threadIdx.x;         // dimension within the head

     if (dim_idx >= head_size) {
         return;
     }

     // --- Find the owning seq_idx via binary search over offsets ---
-    int seq_idx = find_seq_id(global_token_idx, offset_, num_seqs);
+    size_t seq_idx = find_seq_id(global_token_idx, offset_, num_seqs);
     // --- Length of this sequence's prefill in the current call
     const int64_t cur_new_len = seq_lens_[seq_idx];
     // --- The token's position relative to the start of its sequence
-    int q_token_idx = global_token_idx - offset_[seq_idx];
+    size_t q_token_idx = global_token_idx - offset_[seq_idx];

     const Tdata *q_ptr_base = q_ + global_token_idx * num_heads * head_size + head_idx * head_size;
     Tdata *out_ptr = out_ + global_token_idx * num_heads * head_size + head_idx * head_size;
@@ -65,14 +65,14 @@ __global__ void pagedAttentionPrefillKernel(
     // Pass 1: compute scores and find the maximum
     Tcompute max_score = -FLT_MAX;
-    for (int t = 0; t <= causal_limit; ++t) {
+    for (size_t t = 0; t <= causal_limit; ++t) {
         const int64_t b_idx = t / block_size;
         const int64_t t_off = t % block_size;
         const int64_t physical_block_id = block_table[b_idx];
         const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
         Tcompute score = 0.0f;
-        for (int d = 0; d < head_size; ++d) {
+        for (size_t d = 0; d < head_size; ++d) {
             score += static_cast<Tcompute>(q_ptr_base[d]) * static_cast<Tcompute>(k_vec[d]);
         }
         score *= static_cast<Tcompute>(scale);
@@ -86,14 +86,14 @@ __global__ void pagedAttentionPrefillKernel(
     // Pass 2: compute the sum of exponentials
     Tcompute sum_exp = 0.0f;
-    for (int t = 0; t <= causal_limit; ++t) {
+    for (size_t t = 0; t <= causal_limit; ++t) {
         const int64_t b_idx = t / block_size;
         const int64_t t_off = t % block_size;
         const int64_t physical_block_id = block_table[b_idx];
         const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
         Tcompute score = 0.0f;
-        for (int d = 0; d < head_size; ++d) {
+        for (size_t d = 0; d < head_size; ++d) {
             score += static_cast<Tcompute>(q_ptr_base[d]) * static_cast<Tcompute>(k_vec[d]);
         }
         score *= static_cast<Tcompute>(scale);
@@ -106,14 +106,14 @@ __global__ void pagedAttentionPrefillKernel(
     // Pass 3: weighted sum to produce the output
     Tcompute acc = 0.0f;
     Tcompute inv_sum = 1.0f / (sum_exp + 1e-6f);
-    for (int t = 0; t <= causal_limit; ++t) {
+    for (size_t t = 0; t <= causal_limit; ++t) {
         const int64_t b_idx = t / block_size;
         const int64_t t_off = t % block_size;
         const int64_t physical_block_id = block_table[b_idx];
         const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
         Tcompute score = 0.0f;
-        for (int d = 0; d < head_size; ++d) {
+        for (size_t d = 0; d < head_size; ++d) {
             score += static_cast<Tcompute>(q_ptr_base[d]) * static_cast<Tcompute>(k_vec[d]);
         }
         score *= static_cast<Tcompute>(scale);

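The kernel flattens all sequences' new tokens onto one grid axis and recovers the owning sequence by binary-searching offset_, which implies offset_ is an exclusive prefix sum of the per-sequence prefill lengths with num_seqs + 1 entries (this layout is inferred from the kernel's interval test, not stated anywhere in the diff). A host-side C++ sketch of that contract and of the same search:

#include <cassert>
#include <cstdint>
#include <vector>

// Assumed contract: offsets[i] is where sequence i's tokens start on the
// flattened token axis; offsets[num_seqs] is the total token count.
std::vector<int64_t> build_offsets(const std::vector<int64_t> &seq_lens) {
    std::vector<int64_t> offsets(seq_lens.size() + 1, 0);
    for (size_t i = 0; i < seq_lens.size(); ++i) {
        offsets[i + 1] = offsets[i] + seq_lens[i];
    }
    return offsets;
}

// Same search as the kernel's find_seq_id: locate the half-open interval
// [offsets[mid], offsets[mid + 1]) that contains token_idx.
size_t find_seq_id(size_t token_idx, const int64_t *offset, size_t num_seqs) {
    size_t low = 0, high = num_seqs - 1;
    while (low <= high) {
        size_t mid = (low + high) >> 1;
        if (static_cast<int64_t>(token_idx) >= offset[mid]
            && static_cast<int64_t>(token_idx) < offset[mid + 1]) {
            return mid;
        } else if (static_cast<int64_t>(token_idx) < offset[mid]) {
            high = mid - 1;
        } else {
            low = mid + 1;
        }
    }
    return num_seqs; // unreachable for in-range token_idx
}

int main() {
    auto offsets = build_offsets({3, 1, 4}); // -> {0, 3, 4, 8}
    assert(find_seq_id(0, offsets.data(), 3) == 0);
    assert(find_seq_id(3, offsets.data(), 3) == 1);
    assert(find_seq_id(7, offsets.data(), 3) == 2);
}
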
test/infinicore/ops/paged_attention.py
@@ -62,7 +62,7 @@ def parse_test_cases():
     max_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
     num_blocks = num_seqs * max_blocks_per_seq  # A reasonable number for testing

-    seq_lens_torch = torch.randint(1, max_seq_len, (num_seqs,), dtype=torch.int64)
+    cache_lens_torch = torch.randint(1, max_seq_len, (num_seqs,), dtype=torch.int64)
     block_tables = torch.arange(0, num_seqs * max_blocks_per_seq, dtype=torch.int64
@@ -75,7 +75,7 @@ def parse_test_cases():
     v_cache_shape = (num_blocks, num_kv_heads, block_size, head_size)
     block_tables_shape = block_tables.shape
-    seq_lens_shape = seq_lens_torch.shape
+    cache_lens_shape = cache_lens_torch.shape

     # Generate test cases for all data types
     for dtype in _TENSOR_DTYPES:
@@ -91,10 +91,10 @@ def parse_test_cases():
         set_tensor=block_tables,
         dtype=infinicore.int64,
     )
-    seq_lens_spec = TensorSpec.from_tensor(
-        seq_lens_shape,
+    cache_lens_spec = TensorSpec.from_tensor(
+        cache_lens_shape,
         init_mode=TensorInitializer.MANUAL,
-        set_tensor=seq_lens_torch,
+        set_tensor=cache_lens_torch,
         dtype=infinicore.int64,
     )
@@ -108,7 +108,7 @@ def parse_test_cases():
         k_cache_spec,
         v_cache_spec,
         block_tables_spec,
-        seq_lens_spec,
+        cache_lens_spec,
     ],
     kwargs={"alibi_slopes": None, "scale": scale},
     output_spec=None,
@@ -132,7 +132,7 @@ def ref_masked_attention(query, key, value, scale, attn_mask=None):
 def ref_single_query_cached_kv_attention(
-    query, key_cache, value_cache, block_tables, seq_lens, alibi_slopes, scale
+    query, key_cache, value_cache, block_tables, cache_lens, alibi_slopes, scale
 ):
     # Reference implementation for paged attention, iterating through each sequence.
     output = torch.empty_like(query)
@@ -143,7 +143,7 @@ def ref_single_query_cached_kv_attention(
     for i in range(num_seqs):
         q = query[i].unsqueeze(0)
-        seq_len = seq_lens[i].item()
+        seq_len = cache_lens[i].item()
         block_table = block_tables[i]
         keys_lst, values_lst = [], []

xmake/nvidia.lua
@@ -55,7 +55,7 @@ target("infiniop-nvidia")
         end
     end

-    add_cuflags("-Xcompiler=-Wno-error=deprecated-declarations")
+    add_cuflags("-Xcompiler=-Wno-error=deprecated-declarations", "-Xcompiler=-Wno-error=unused-function")

     local arch_opt = get_config("cuda_arch")
     if arch_opt and type(arch_opt) == "string" then