Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
6074f7b8
Commit
6074f7b8
authored
Feb 10, 2026
by
zhushuang
Browse files
issue/1001 - feat: add paged attention prefill for moore gpu referencing nvidia
parent
3d3a277f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
281 additions
and
0 deletions
+281
-0
src/infiniop/ops/paged_attention_prefill/moore/paged_attention_prefill_kernel.h
..._attention_prefill/moore/paged_attention_prefill_kernel.h
+132
-0
src/infiniop/ops/paged_attention_prefill/moore/paged_attention_prefill_moore.h
...d_attention_prefill/moore/paged_attention_prefill_moore.h
+8
-0
src/infiniop/ops/paged_attention_prefill/moore/paged_attention_prefill_moore.mu
..._attention_prefill/moore/paged_attention_prefill_moore.mu
+126
-0
src/infiniop/ops/paged_attention_prefill/operator.cc
src/infiniop/ops/paged_attention_prefill/operator.cc
+15
-0
No files found.
src/infiniop/ops/paged_attention_prefill/moore/paged_attention_prefill_kernel.h
0 → 100644
View file @
6074f7b8
#ifndef __PAGED_ATTENTION_PREFILL_KERNEL_CUH__
#define __PAGED_ATTENTION_PREFILL_KERNEL_CUH__
namespace
op
::
paged_attention_prefill
::
cuda
{
// Map a flat (global) query-token index to the sequence that owns it, via
// binary search over the cumulative query-length prefix array.
//
// Preconditions (established by the caller's layout):
//   - cum_seq_lens_q has num_seqs + 1 monotonically non-decreasing entries,
//     with cum_seq_lens_q[0] == 0;
//   - sequence s owns tokens in [cum_seq_lens_q[s], cum_seq_lens_q[s + 1]).
//
// BUGFIX: the previous version computed `high = mid - 1` on size_t, which
// underflows to SIZE_MAX when mid == 0 and could then index far out of
// bounds if cum_seq_lens_q[0] were nonzero. This half-open formulation
// never subtracts below zero and returns the same index for valid inputs
// (the largest s with cum_seq_lens_q[s] <= token_idx).
__device__ __forceinline__ size_t find_seq_id(
    size_t token_idx,
    const int64_t *cum_seq_lens_q,
    size_t num_seqs) {
    // Invariant: the answer lies in [low, high) and
    // cum_seq_lens_q[low] <= token_idx (given cum_seq_lens_q[0] == 0).
    size_t low = 0, high = num_seqs;
    while (low + 1 < high) {
        const size_t mid = low + ((high - low) >> 1);
        if (token_idx >= static_cast<size_t>(cum_seq_lens_q[mid])) {
            low = mid;
        } else {
            high = mid;
        }
    }
    return low;
}
// Warp-level sum reduction over the lanes named in `mask` (safe for partial
// warps when the mask matches the truly participating lanes). After the
// shuffles, the lowest participating lane holds the total for a contiguous
// low-lane mask; other lanes hold partial sums.
__device__ __forceinline__ float warpReduceSum(float val, unsigned mask) {
    // Fully unrolled halving ladder: 16, 8, 4, 2, 1.
    val += __shfl_down_sync(mask, val, 16);
    val += __shfl_down_sync(mask, val, 8);
    val += __shfl_down_sync(mask, val, 4);
    val += __shfl_down_sync(mask, val, 2);
    val += __shfl_down_sync(mask, val, 1);
    return val;
}
// Block-level sum reduction. Returns the sum to all threads in the block.
// Supports blockDim.x up to 1024 (i.e. up to 32 warps).
//
// Two-stage scheme: (1) each warp reduces its own values and lane 0 parks
// the per-warp partial in shared memory; (2) warp 0 reduces the partials
// and publishes the total through shared[0].
//
// NOTE(review): __activemask() is used so that partial last warps (when
// blockDim.x is not a multiple of 32) do not shuffle with nonexistent
// lanes. This works here because every launched thread reaches this call
// together, but __activemask() is not a general substitute for an explicit
// participant mask — confirm all callers invoke this unconditionally.
//
// NOTE(review): there is no barrier between the final read of shared[0]
// and the next invocation's write to shared[wid]; back-to-back calls are
// only safe if the caller inserts a __syncthreads() in between (the prefill
// kernel below does).
__device__ __forceinline__ float blockReduceSum(float val) {
    __shared__ float shared[32]; // max 32 warps per block
    const int lane = threadIdx.x & 31; // lane index within the warp
    const int wid = threadIdx.x >> 5;  // warp index within the block
    const unsigned mask = __activemask();
    // Stage 1: intra-warp reduction; lane 0 of each warp holds the partial.
    val = warpReduceSum(val, mask);
    if (lane == 0) {
        shared[wid] = val;
    }
    __syncthreads();
    const int num_warps = (blockDim.x + 31) >> 5; // ceil(blockDim.x / 32)
    float sum = 0.0f;
    // Stage 2: warp 0 gathers the per-warp partials and reduces them.
    if (wid == 0) {
        sum = (lane < num_warps) ? shared[lane] : 0.0f;
        // Mask covering exactly the lanes holding valid partials.
        const unsigned mask0 = (num_warps >= 32) ? 0xffffffffu : ((1u << num_warps) - 1u);
        sum = warpReduceSum(sum, mask0);
        if (lane == 0) {
            shared[0] = sum;
        }
    }
    __syncthreads();
    // Broadcast the block-wide total to every thread.
    return shared[0];
}
// Paged-attention prefill kernel using a one-pass online (streaming) softmax.
//
// Grid layout : grid.x = global query-token index, grid.y = query head.
// Block layout: blockDim.x == head_size — thread d owns dimension d of the
//               query and output vectors. head_size must be <= 1024.
//
// For its (token, head) pair, each block computes
//   out = sum_t softmax(scale * q·k_t + alibi_bias_t) * v_t
// over KV positions t in [0, causal_limit], maintaining a running max (m)
// and denominator (l) so the KV cache is traversed only once.
//
// KV pages are resolved through block_tables_ (per-sequence page table of
// max_num_blocks_per_seq physical page ids, block_size tokens per page).
template <typename Tdata, typename Tcompute>
__global__ void pagedAttentionPrefillKernel(
    Tdata *out_,
    const Tdata *q_,
    const Tdata *k_cache_,
    const Tdata *v_cache_,
    const int64_t *block_tables_,
    const int64_t *total_kv_lens_,
    const int64_t *cum_seq_lens_q_,
    const float *alibi_slopes_,
    const size_t num_heads,
    const size_t num_kv_heads,
    const float scale,
    const size_t max_num_blocks_per_seq,
    const size_t block_size,
    const ptrdiff_t kv_block_stride,
    const ptrdiff_t kv_head_stride,
    const ptrdiff_t q_stride,
    const ptrdiff_t q_head_stride,
    const size_t head_size,
    const size_t num_seqs) {
    // Grid : x -> token, y -> head
    const size_t global_token_idx = blockIdx.x;
    const size_t head_idx = blockIdx.y;
    const size_t dim_idx = threadIdx.x;
    // Launch contract: blockDim.x == head_size, so this early-out must never
    // fire — the loop below contains __syncthreads(), which would deadlock
    // if only part of the block returned here.
    if (dim_idx >= head_size) {
        return;
    }

    __shared__ size_t sh_seq_idx;
    __shared__ size_t sh_causal_limit;
    __shared__ size_t sh_kv_head_idx;
    __shared__ float sh_scale_acc;
    __shared__ float sh_w;
    __shared__ float sh_inv_l;

    // Thread 0 resolves the owning sequence, causal bound, and KV head;
    // results are broadcast to the block through shared memory.
    if (dim_idx == 0) {
        sh_seq_idx = find_seq_id(global_token_idx, cum_seq_lens_q_, num_seqs);
        const size_t q_token_idx =
            global_token_idx - static_cast<size_t>(cum_seq_lens_q_[sh_seq_idx]);
        const size_t total_kv_len =
            static_cast<size_t>(total_kv_lens_[sh_seq_idx]);
        const size_t q_len = static_cast<size_t>(
            cum_seq_lens_q_[sh_seq_idx + 1] - cum_seq_lens_q_[sh_seq_idx]);
        const size_t history_len = total_kv_len - q_len;
        // This token may attend to KV positions [0, history_len + q_token_idx].
        sh_causal_limit = history_len + q_token_idx;
        // GQA/MQA: several query heads share one KV head.
        const size_t num_queries_per_kv = num_heads / num_kv_heads;
        sh_kv_head_idx = head_idx / num_queries_per_kv;
    }
    __syncthreads();

    const size_t seq_idx = sh_seq_idx;
    const size_t causal_limit = sh_causal_limit;
    const size_t kv_head_idx = sh_kv_head_idx;

    const Tdata *q_vec = q_ + global_token_idx * q_stride + head_idx * q_head_stride;
    // Output is densely packed: [token][head][dim].
    Tdata *out_ptr = out_ + global_token_idx * num_heads * head_size + head_idx * head_size;
    const int64_t *block_table = block_tables_ + seq_idx * max_num_blocks_per_seq;
    const float alibi_slope =
        (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];

    // Each thread caches its one query element in a register.
    const float qv = static_cast<float>(q_vec[dim_idx]);

    Tcompute acc = 0.0f; // running weighted-V accumulator for this dimension
    float m = -FLT_MAX;  // running max score (meaningful on thread 0 only)
    float l = 0.0f;      // running softmax denominator (thread 0 only)

    for (size_t t = 0; t <= causal_limit; ++t) {
        const size_t b_idx = t / block_size; // logical page
        const size_t t_off = t % block_size; // slot within the page
        const ptrdiff_t physical_block_id = block_table[b_idx];
        const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride
                           + kv_head_idx * kv_head_stride + t_off * head_size;
        // q·k over all head dimensions — every thread participates.
        const float dot = blockReduceSum(qv * static_cast<float>(k_vec[dim_idx]));
        if (dim_idx == 0) {
            float score = dot * static_cast<float>(scale);
            if (alibi_slope != 0.0f) {
                // BUGFIX: the original computed (t - causal_limit) in size_t,
                // which underflows for t < causal_limit and injected a huge
                // positive bias. ALiBi wants slope * (t - causal_limit), a
                // non-positive distance penalty; compute the non-negative
                // distance and subtract it instead.
                score -= alibi_slope * static_cast<float>(causal_limit - t);
            }
            // Online-softmax update: rescale previous state to the new max.
            const float m_new = fmaxf(m, score);
            const float scale_acc = expf(m - m_new);
            const float w = expf(score - m_new);
            l = l * scale_acc + w;
            m = m_new;
            sh_scale_acc = scale_acc;
            sh_w = w;
        }
        __syncthreads();
        const float scale_acc = sh_scale_acc;
        const float w = sh_w;
        const Tdata *v_vec = v_cache_ + physical_block_id * kv_block_stride
                           + kv_head_idx * kv_head_stride + t_off * head_size;
        acc = acc * static_cast<Tcompute>(scale_acc)
            + static_cast<Tcompute>(w) * static_cast<Tcompute>(v_vec[dim_idx]);
        // Keep sh_scale_acc / sh_w (and blockReduceSum's scratch) stable
        // until every thread has consumed them before the next iteration.
        __syncthreads();
    }

    if (dim_idx == 0) {
        // Epsilon guards the division; in practice l >= 1 because the
        // t == causal_limit term always contributes w == 1.
        sh_inv_l = 1.0f / (l + 1e-6f);
    }
    __syncthreads();
    // Normalize and write this thread's output dimension.
    out_ptr[dim_idx] = static_cast<Tdata>(acc * static_cast<Tcompute>(sh_inv_l));
}
}
// namespace op::paged_attention_prefill::cuda
#endif
src/infiniop/ops/paged_attention_prefill/moore/paged_attention_prefill_moore.h
0 → 100644
View file @
6074f7b8
#ifndef __PAGED_ATTENTION_PREFILL_MOORE_H__
#define __PAGED_ATTENTION_PREFILL_MOORE_H__

#include "../paged_attention_prefill.h"

// Declares op::paged_attention_prefill::moore::Descriptor via the shared
// DESCRIPTOR macro from the included op header (same pattern as the other
// device backends).
DESCRIPTOR(moore)

#endif // __PAGED_ATTENTION_PREFILL_MOORE_H__
src/infiniop/ops/paged_attention_prefill/moore/paged_attention_prefill_moore.mu
0 → 100644
View file @
6074f7b8
#include <musa_fp16.h>
#include <float.h>
#include <math.h>
#include <stdint.h>
#include "../../../devices/moore/moore_common.h"
#include "../../../devices/moore/moore_kernel_common.h"
#include "paged_attention_prefill_kernel.h"
#include "paged_attention_prefill_moore.h"
// Host-side launcher for the paged-attention prefill kernel.
//
// Launch shape: grid = (total_q_tokens, num_heads); block = head_size
// threads (one thread per head dimension). The kernel requires
// blockDim.x == head_size, so head_size must fit the 1024-thread block
// limit — previously an oversized head_size produced a silent launch
// failure; now it is rejected up front.
//
// Returns INFINI_STATUS_BAD_TENSOR_SHAPE for degenerate/unsupported shapes,
// INFINI_STATUS_SUCCESS after enqueuing the (asynchronous) kernel on stream.
template <typename Tdata, typename Tcompute>
infiniStatus_t launchPagedAttentionPrefill(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const int64_t *block_tables,
    const int64_t *seq_lens,
    const int64_t *cum_seq_lens_q,
    const float *alibi_slopes,
    const size_t num_heads,
    const size_t num_seqs,
    const size_t num_kv_heads,
    const float scale,
    const size_t max_num_blocks_per_seq,
    const size_t page_block_size,
    const size_t total_q_tokens,
    const size_t head_size,
    const ptrdiff_t k_batch_stride,
    const ptrdiff_t k_head_stride,
    const ptrdiff_t q_stride,
    const ptrdiff_t q_head_stride,
    musaStream_t stream) {
    if (total_q_tokens == 0 || num_heads == 0) {
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
    // One thread per head dimension: reject sizes that cannot form a valid
    // thread block (max 1024 threads per block).
    if (head_size == 0 || head_size > 1024) {
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
    dim3 grid(total_q_tokens, num_heads);
    dim3 block(head_size);
    op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel<Tdata, Tcompute>
        <<<grid, block, 0, stream>>>(
            out, q, k_cache, v_cache,
            block_tables, seq_lens, cum_seq_lens_q, alibi_slopes,
            num_heads, num_kv_heads, scale,
            max_num_blocks_per_seq, page_block_size,
            k_batch_stride, k_head_stride,
            q_stride, q_head_stride,
            head_size,
            num_seqs);
    return INFINI_STATUS_SUCCESS;
}
namespace op::paged_attention_prefill::moore {
// Backend-private state: keeps the Moore device handle's internals alive
// for the descriptor's lifetime.
struct Descriptor::Opaque {
    std::shared_ptr<device::moore::Handle::Internal> internal;
};

// Releases the backend-private state allocated in create().
Descriptor::~Descriptor() {
    delete _opaque;
}
// Validates the tensor descriptors, builds the op metadata, and allocates a
// Moore-backend descriptor. Workspace size is 0 — this op needs none.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t block_tables_desc,
    infiniopTensorDescriptor_t seq_lens_desc,
    infiniopTensorDescriptor_t cum_seq_lens_q_desc,
    const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
    float scale) {
    // Shape/dtype validation and metadata extraction are shared across
    // backends via PagedAttentionPrefillInfo.
    auto result = PagedAttentionPrefillInfo::create(
        out_desc, q_desc, k_cache_desc, v_cache_desc,
        block_tables_desc, seq_lens_desc, cum_seq_lens_q_desc,
        alibi_slopes_desc, scale);
    CHECK_RESULT(result);

    auto moore_handle = reinterpret_cast<device::moore::Handle *>(handle);
    auto opaque = new Opaque{moore_handle->internal()};
    *desc_ptr = new Descriptor(opaque, result.take(), 0,
                               handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Dispatches the prefill launch on the runtime dtype recorded in _info.
// `workspace`/`workspace_size` are unused (this op needs no workspace);
// `stream_` is an opaque musaStream_t supplied by the caller.
infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *out, const void *q, const void *k_cache, const void *v_cache,
    const void *block_tables,
    const void *seq_lens,
    const void *cum_seq_lens_q,
    const void *alibi_slopes,
    void *stream_) const {
    musaStream_t stream = (musaStream_t)stream_;
#define LAUNCH_KERNEL(Tdata, Tcompute) \
    launchPagedAttentionPrefill<Tdata, Tcompute>( \
        (Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache, \
        (const int64_t *)block_tables, (const int64_t *)seq_lens, (const int64_t *)cum_seq_lens_q, \
        (const float *)alibi_slopes, \
        _info.num_heads, _info.num_seqs, _info.num_kv_heads, \
        _info.scale, _info.max_num_blocks_per_seq, \
        _info.page_block_size, _info.total_q_tokens, \
        _info.head_size, \
        _info.k_batch_stride, _info.k_head_stride, \
        _info.q_stride, _info.q_head_stride, \
        stream)
    // One instantiation per supported element type; accumulation is always
    // performed in float.
    switch (_info.dtype) {
    case INFINI_DTYPE_F16:
        return LAUNCH_KERNEL(half, float);
    case INFINI_DTYPE_BF16:
        return LAUNCH_KERNEL(__mt_bfloat16, float);
    case INFINI_DTYPE_F32:
        return LAUNCH_KERNEL(float, float);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}
} // namespace op::paged_attention_prefill::moore
src/infiniop/ops/paged_attention_prefill/operator.cc
View file @
6074f7b8
...
...
@@ -8,6 +8,9 @@
#ifdef ENABLE_METAX_API
#include "metax/paged_attention_prefill_metax.h"
#endif
#ifdef ENABLE_MOORE_API
#include "moore/paged_attention_prefill_moore.h"
#endif
__C
infiniStatus_t
infiniopCreatePagedAttentionPrefillDescriptor
(
infiniopHandle_t
handle
,
...
...
@@ -44,6 +47,9 @@ __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
#endif
#ifdef ENABLE_ILUVATAR_API
CREATE
(
INFINI_DEVICE_ILUVATAR
,
nvidia
)
#endif
#ifdef ENABLE_MOORE_API
CREATE
(
INFINI_DEVICE_MOORE
,
moore
)
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
@@ -71,6 +77,9 @@ __C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
#endif
#ifdef ENABLE_ILUVATAR_API
GET
(
INFINI_DEVICE_ILUVATAR
,
nvidia
)
#endif
#ifdef ENABLE_MOORE_API
GET
(
INFINI_DEVICE_MOORE
,
moore
)
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
@@ -105,6 +114,9 @@ __C infiniStatus_t infiniopPagedAttentionPrefill(
#endif
#ifdef ENABLE_ILUVATAR_API
CALCULATE
(
INFINI_DEVICE_ILUVATAR
,
nvidia
)
#endif
#ifdef ENABLE_MOORE_API
CALCULATE
(
INFINI_DEVICE_MOORE
,
moore
)
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
@@ -131,6 +143,9 @@ __C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
#endif
#ifdef ENABLE_ILUVATAR_API
DESTROY
(
INFINI_DEVICE_ILUVATAR
,
nvidia
)
#endif
#ifdef ENABLE_MOORE_API
DESTROY
(
INFINI_DEVICE_MOORE
,
moore
)
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment