Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
7377e711
Commit
7377e711
authored
Feb 12, 2026
by
zhangyue
Browse files
issue/1008: adapt paged_attention_prefill
parent
f46e9f65
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
69 additions
and
15 deletions
+69
-15
src/infiniop/ops/paged_attention/cuda/kernel_v2.cuh
src/infiniop/ops/paged_attention/cuda/kernel_v2.cuh
+49
-0
src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh
src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh
+15
-15
src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu
...ttention_prefill/nvidia/paged_attention_prefill_nvidia.cu
+5
-0
No files found.
src/infiniop/ops/paged_attention/cuda/kernel_v2.cuh
View file @
7377e711
...
...
@@ -16,17 +16,66 @@ struct OnlineSoftmaxState {
}
};
// Sum-reduce `x` across the 32 lanes of the calling (logical) warp; every lane
// receives the warp's total.
//
// Preconditions (Iluvatar shared-memory path):
//   - Must be called by ALL threads of the block in uniform control flow
//     (uses __syncthreads()); never call it from a divergent branch.
//   - blockDim.x <= kMaxWarps * 32 (512 threads) and, presumably, a multiple
//     of 32 so every 32-slot group is fully populated — TODO confirm callers.
__device__ __forceinline__ float warpReduceSum(float x) {
#if defined(ENABLE_ILUVATAR_API)
    // Iluvatar may use warp size 64; __shfl_sync(0xffffffff) only covers 32 threads.
    // Use shared-memory tree reduce for portability across warp sizes.
    constexpr int kMaxWarps = 16;
    __shared__ float _reduce_buf[kMaxWarps * 32];
    const int lane = threadIdx.x & 31;    // lane within the 32-wide logical warp
    const int warp_id = threadIdx.x / 32; // index of the logical warp in the block
    _reduce_buf[threadIdx.x] = x;
    __syncthreads();
    // Binary tree reduction within each 32-slot group; the total lands in slot 0.
    for (int offset = 16; offset > 0; offset >>= 1) {
        if (lane < offset) {
            _reduce_buf[warp_id * 32 + lane] += _reduce_buf[warp_id * 32 + lane + offset];
        }
        __syncthreads();
    }
    const float result = _reduce_buf[warp_id * 32];
    // Barrier before returning: without it, a subsequent reduction could start
    // overwriting _reduce_buf while some threads are still reading this result
    // (write-after-read race across consecutive calls).
    __syncthreads();
    return result;
#else
    // Standard 32-lane butterfly via warp shuffles; lane 0 (and, after the
    // down-shuffles, all lanes that callers broadcast from) holds the sum.
    for (int offset = 16; offset > 0; offset >>= 1) {
        x += __shfl_down_sync(0xffffffff, x, offset);
    }
    return x;
#endif
}
// Broadcast lane `src_lane`'s value of `x` to every lane of the calling
// (logical 32-wide) warp.
//
// Preconditions (Iluvatar shared-memory path):
//   - Must be called by ALL threads of the block in uniform control flow
//     (uses __syncthreads()).
//   - blockDim.x <= 16 * 32 (one _bcast_buf slot per logical warp).
__device__ __forceinline__ float warpBroadcast(float x, int src_lane) {
#if defined(ENABLE_ILUVATAR_API)
    __shared__ float _bcast_buf[16];
    const int warp_id = threadIdx.x / 32;
    if ((threadIdx.x & 31) == src_lane) {
        _bcast_buf[warp_id] = x;
    }
    __syncthreads();
    const float result = _bcast_buf[warp_id];
    // Barrier after the read: otherwise the next broadcast's src_lane write
    // can race with lanes still reading this call's value.
    __syncthreads();
    return result;
#else
    return __shfl_sync(0xffffffff, x, src_lane);
#endif
}
// Max-reduce `x` across the 32 lanes of the calling (logical) warp; every lane
// receives the warp's maximum.
//
// Preconditions (Iluvatar shared-memory path):
//   - Must be called by ALL threads of the block in uniform control flow
//     (uses __syncthreads()); never call it from a divergent branch.
//   - blockDim.x <= 16 * 32 (512 threads) and, presumably, a multiple of 32
//     so every 32-slot group is fully populated — TODO confirm callers.
__device__ __forceinline__ float warpReduceMax(float x) {
#if defined(ENABLE_ILUVATAR_API)
    // Shared-memory tree reduce: portable across warp sizes (Iluvatar may be 64-wide).
    __shared__ float _reduce_buf[16 * 32];
    const int lane = threadIdx.x & 31;    // lane within the 32-wide logical warp
    const int warp_id = threadIdx.x / 32; // index of the logical warp in the block
    _reduce_buf[threadIdx.x] = x;
    __syncthreads();
    // Binary tree reduction within each 32-slot group; the max lands in slot 0.
    for (int offset = 16; offset > 0; offset >>= 1) {
        if (lane < offset) {
            float other = _reduce_buf[warp_id * 32 + lane + offset];
            float cur = _reduce_buf[warp_id * 32 + lane];
            _reduce_buf[warp_id * 32 + lane] = fmaxf(cur, other);
        }
        __syncthreads();
    }
    const float result = _reduce_buf[warp_id * 32];
    // Barrier before returning: without it, a subsequent reduction could start
    // overwriting _reduce_buf while some threads are still reading this result
    // (write-after-read race across consecutive calls).
    __syncthreads();
    return result;
#else
    // Standard 32-lane butterfly via warp shuffles.
    for (int offset = 16; offset > 0; offset >>= 1) {
        x = fmaxf(x, __shfl_down_sync(0xffffffff, x, offset));
    }
    return x;
#endif
}
__device__
__forceinline__
unsigned
int
cvtaToShared
(
const
void
*
ptr
)
{
...
...
src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh
View file @
7377e711
#ifndef __PAGED_ATTENTION_PREFILL_KERNEL_V2_CUH__
#define __PAGED_ATTENTION_PREFILL_KERNEL_V2_CUH__
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API)
|| defined(ENABLE_ILUVATAR_API)
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
...
...
@@ -194,8 +194,8 @@ __device__ void PagedAttentionPrefillWarpKernel(
l
=
l
*
alpha
+
beta
;
m
=
m_new
;
}
alpha
=
__shfl_sync
(
0xffffffff
,
alpha
,
0
);
beta
=
__shfl_sync
(
0xffffffff
,
beta
,
0
);
alpha
=
op
::
paged_attention
::
cuda
::
warpBroadcast
(
alpha
,
0
);
beta
=
op
::
paged_attention
::
cuda
::
warpBroadcast
(
beta
,
0
);
#if defined(__CUDA_ARCH__)
if
constexpr
(
std
::
is_same_v
<
Tdata
,
half
>
)
{
...
...
@@ -233,7 +233,7 @@ __device__ void PagedAttentionPrefillWarpKernel(
if
(
lane
==
0
)
{
inv_l
=
1.0
f
/
(
l
+
1e-6
f
);
}
inv_l
=
__shfl_sync
(
0xffffffff
,
inv_l
,
0
);
inv_l
=
op
::
paged_attention
::
cuda
::
warpBroadcast
(
inv_l
,
0
);
#pragma unroll
for
(
int
i
=
0
;
i
<
DIMS_PER_THREAD
;
++
i
)
{
...
...
@@ -411,8 +411,8 @@ __global__ void PagedAttentionPrefillWarpGlobalKernel(
l
=
l
*
alpha
+
beta
;
m
=
m_new
;
}
alpha
=
__shfl_sync
(
0xffffffff
,
alpha
,
0
);
beta
=
__shfl_sync
(
0xffffffff
,
beta
,
0
);
alpha
=
op
::
paged_attention
::
cuda
::
warpBroadcast
(
alpha
,
0
);
beta
=
op
::
paged_attention
::
cuda
::
warpBroadcast
(
beta
,
0
);
#if defined(__CUDA_ARCH__)
if
constexpr
(
std
::
is_same_v
<
Tdata
,
half
>
)
{
...
...
@@ -450,7 +450,7 @@ __global__ void PagedAttentionPrefillWarpGlobalKernel(
if
(
lane
==
0
)
{
inv_l
=
1.0
f
/
(
l
+
1e-6
f
);
}
inv_l
=
__shfl_sync
(
0xffffffff
,
inv_l
,
0
);
inv_l
=
op
::
paged_attention
::
cuda
::
warpBroadcast
(
inv_l
,
0
);
#pragma unroll
for
(
int
i
=
0
;
i
<
DIMS_PER_THREAD
;
++
i
)
{
...
...
@@ -785,8 +785,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernel(
l
=
l
*
alpha
+
beta
;
m
=
m_new
;
}
alpha
=
__shfl_sync
(
0xffffffff
,
alpha
,
0
);
beta
=
__shfl_sync
(
0xffffffff
,
beta
,
0
);
alpha
=
op
::
paged_attention
::
cuda
::
warpBroadcast
(
alpha
,
0
);
beta
=
op
::
paged_attention
::
cuda
::
warpBroadcast
(
beta
,
0
);
#if defined(__CUDA_ARCH__)
if
constexpr
(
std
::
is_same_v
<
Tdata
,
half
>
)
{
...
...
@@ -826,7 +826,7 @@ __device__ void PagedAttentionPrefillWarpCtaKernel(
if
(
lane
==
0
)
{
inv_l
=
1.0
f
/
(
l
+
1e-6
f
);
}
inv_l
=
__shfl_sync
(
0xffffffff
,
inv_l
,
0
);
inv_l
=
op
::
paged_attention
::
cuda
::
warpBroadcast
(
inv_l
,
0
);
#pragma unroll
for
(
int
i
=
0
;
i
<
DIMS_PER_THREAD
;
++
i
)
{
...
...
@@ -1270,7 +1270,7 @@ __device__ void PagedAttentionPrefillWarpCtaKernelPipelined(
if
(
lane
==
0
)
{
inv_l
=
1.0
f
/
(
l
+
1e-6
f
);
}
inv_l
=
__shfl_sync
(
0xffffffff
,
inv_l
,
0
);
inv_l
=
op
::
paged_attention
::
cuda
::
warpBroadcast
(
inv_l
,
0
);
#pragma unroll
for
(
int
i
=
0
;
i
<
DIMS_PER_THREAD
;
++
i
)
{
...
...
@@ -1961,8 +1961,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelKOnly(
l
=
l
*
alpha
+
beta
;
m
=
m_new
;
}
alpha
=
__shfl_sync
(
0xffffffff
,
alpha
,
0
);
beta
=
__shfl_sync
(
0xffffffff
,
beta
,
0
);
alpha
=
op
::
paged_attention
::
cuda
::
warpBroadcast
(
alpha
,
0
);
beta
=
op
::
paged_attention
::
cuda
::
warpBroadcast
(
beta
,
0
);
#if defined(__CUDA_ARCH__)
if
constexpr
(
std
::
is_same_v
<
Tdata
,
half
>
)
{
...
...
@@ -2002,7 +2002,7 @@ __device__ void PagedAttentionPrefillWarpCtaKernelKOnly(
if
(
lane
==
0
)
{
inv_l
=
1.0
f
/
(
l
+
1e-6
f
);
}
inv_l
=
__shfl_sync
(
0xffffffff
,
inv_l
,
0
);
inv_l
=
op
::
paged_attention
::
cuda
::
warpBroadcast
(
inv_l
,
0
);
#pragma unroll
for
(
int
i
=
0
;
i
<
DIMS_PER_THREAD
;
++
i
)
{
...
...
@@ -2131,7 +2131,7 @@ __device__ __forceinline__ void PagedAttentionPrefillMmaScoreWriteRow(
if
(
lane
==
0
)
{
inv_l
=
1.0
f
/
(
l
+
1e-6
f
);
}
inv_l
=
__shfl_sync
(
0xffffffff
,
inv_l
,
0
);
inv_l
=
op
::
paged_attention
::
cuda
::
warpBroadcast
(
inv_l
,
0
);
const
int64_t
q_token
=
q_start
+
static_cast
<
int64_t
>
(
q_token_local
);
half
*
out_ptr
=
out_
+
q_token
*
o_stride
+
static_cast
<
int64_t
>
(
head_idx
)
*
o_head_stride
;
...
...
src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu
View file @
7377e711
...
...
@@ -21,6 +21,11 @@ constexpr size_t ceilDiv(size_t a, size_t b) {
}
inline
const
char
*
default_prefill_kernel
(
const
PagedAttentionPrefillInfo
&
info
)
{
// Iluvatar: use warp (stable). Users can override via INFINIOP_FLASH_PREFILL_KERNEL.
#ifdef ENABLE_ILUVATAR_API
(
void
)
info
;
return
"warp"
;
#endif
// Heuristic auto-dispatch (v0.4):
// - Prefer the pipelined + tile-wise softmax kernel on FA2-compatible block_size=256.
// - Keep a conservative fallback for other shapes / older GPUs (cp.async is a no-op below SM80).
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment