jerrrrry / infinicore · Commits · 1c18c046

Commit 1c18c046, authored Jan 23, 2026 by PanZezhong, committed by wooway777 on Jan 27, 2026

issue/979 optimize paged attention

parent 97eced0e
Showing 10 changed files with 8209 additions and 251 deletions
src/infiniop/ops/paged_attention/cuda/kernel_v2.cuh                                 +2085  -0
src/infiniop/ops/paged_attention/info.h                                             +114   -35
src/infiniop/ops/paged_attention/nvidia/paged_attention_hd128.cu                    +1024  -0
src/infiniop/ops/paged_attention/nvidia/paged_attention_hd64.cu                     +524   -0
src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cu                   +320   -105
src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh                         +2361  -0
src/infiniop/ops/paged_attention_prefill/info.h                                     +124   -42
src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu   +1655  -65
test/infiniop/paged_attention.py                                                    +1     -2
test/infiniop/paged_attention_prefill.py                                            +1     -2
src/infiniop/ops/paged_attention/cuda/kernel_v2.cuh (new file, 0 → 100644)
#ifndef __PAGED_ATTENTION_KERNEL_V2_CUH__
#define __PAGED_ATTENTION_KERNEL_V2_CUH__

namespace op::paged_attention::cuda {

struct OnlineSoftmaxState {
    float m = -INFINITY;
    float l = 0.0f;

    __device__ __forceinline__ void update(float x, float &alpha, float &beta) {
        const float m_new = fmaxf(m, x);
        alpha = expf(m - m_new);
        beta = expf(x - m_new);
        l = l * alpha + beta;
        m = m_new;
    }
};
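// Illustrative usage sketch (not part of this commit): driving OnlineSoftmaxState over a
// stream of scores while keeping a running weighted sum of values, which is the pattern
// the decode kernels below implement with wider accumulators. `scores`, `values`, and `n`
// are hypothetical names.
//
//   OnlineSoftmaxState state;
//   float acc = 0.0f;
//   for (int t = 0; t < n; ++t) {
//       float alpha, beta;
//       state.update(scores[t], alpha, beta); // alpha rescales the old sum, beta weights the new row
//       acc = acc * alpha + beta * values[t];
//   }
//   float out = acc / state.l;                // equals dot(softmax(scores), values)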
__device__ __forceinline__ float warpReduceSum(float x) {
    for (int offset = 16; offset > 0; offset >>= 1) {
        x += __shfl_down_sync(0xffffffff, x, offset);
    }
    return x;
}

__device__ __forceinline__ float warpReduceMax(float x) {
    for (int offset = 16; offset > 0; offset >>= 1) {
        x = fmaxf(x, __shfl_down_sync(0xffffffff, x, offset));
    }
    return x;
}
__device__ __forceinline__ unsigned int cvtaToShared(const void *ptr) {
    return static_cast<unsigned int>(__cvta_generic_to_shared(ptr));
}

__device__ __forceinline__ void cpAsyncCaSharedGlobal16(void *dst_shared, const void *src_global) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
    const unsigned int dst = cvtaToShared(dst_shared);
    asm volatile("cp.async.ca.shared.global [%0], [%1], 16;\n" ::"r"(dst), "l"(src_global));
#else
    auto *dst = reinterpret_cast<uint4 *>(dst_shared);
    const auto *src = reinterpret_cast<const uint4 *>(src_global);
    *dst = *src;
#endif
}
__device__ __forceinline__ void cpAsyncCommit() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
    asm volatile("cp.async.commit_group;\n" ::);
#endif
}

template <int N>
__device__ __forceinline__ void cpAsyncWaitGroup() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
    asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
#endif
}

// cp.async.wait_group requires a compile-time immediate, so for small fixed
// stage counts we provide a tiny runtime switch.
__device__ __forceinline__ void cpAsyncWaitGroupRt(int n) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
    if (n <= 0) {
        cpAsyncWaitGroup<0>();
    } else if (n == 1) {
        cpAsyncWaitGroup<1>();
    } else {
        // Clamp to 2 because v0.4 CTA kernel uses STAGES=3.
        cpAsyncWaitGroup<2>();
    }
#else
    (void)n;
#endif
}

__device__ __forceinline__ void cpAsyncWaitAll() {
    cpAsyncWaitGroup<0>();
}
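// Illustrative sketch (not part of this commit) of how the CTA kernels below combine these
// helpers into a multi-stage pipeline; `issue_tile_copy` and `consume` are hypothetical
// placeholders for the per-tile cp.async loads and the QK/PV math.
//
//   int pending = 0;
//   for (int s = 0; s < min(STAGES, num_tiles); ++s) { issue_tile_copy(s); cpAsyncCommit(); ++pending; }
//   cpAsyncWaitGroupRt(min(pending - 1, STAGES - 1)); // tile 0 is now resident in shared memory
//   __syncthreads();
//   for (int tile = 0; tile < num_tiles; ++tile) {
//       consume(tile % STAGES);                        // compute on the resident buffer
//       if (tile + STAGES < num_tiles) { issue_tile_copy(tile % STAGES); cpAsyncCommit(); ++pending; }
//       if (tile + 1 < num_tiles) { cpAsyncWaitGroupRt(min(pending - 1, STAGES - 1)); __syncthreads(); }
//   }
//   cpAsyncWaitAll();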
template <typename Tindex, typename Tdata, int HEAD_SIZE>
__device__ void flashAttentionDecodeWarpKernel(
    Tdata *out_, const Tdata *q_, const Tdata *k_cache_, const Tdata *v_cache_,
    const Tindex *block_tables_, const Tindex *cache_lens_, const float *alibi_slopes_,
    size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride,
    ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride) {

    const int seq_idx = blockIdx.y;
    const int head_idx = blockIdx.x;
    const int lane = threadIdx.x;
    constexpr int kWarpSize = 32;
    static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4.");
    static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32.");
    constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize;

    const int seq_len = static_cast<int>(cache_lens_[seq_idx]);
    if (seq_len <= 0) {
        return;
    }

    const int num_heads = gridDim.x;
    const int num_queries_per_kv = num_heads / static_cast<int>(num_kv_heads);
    const int kv_head_idx = head_idx / num_queries_per_kv;
    const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
    constexpr float kLog2e = 1.4426950408889634f;
    const float scale_log2 = scale * kLog2e;

    const Tindex *block_table = block_tables_ + seq_idx * static_cast<int>(max_num_blocks_per_seq);

    // q/out are [num_seqs, num_heads, head_size]
    const Tdata *q_ptr = q_ + seq_idx * q_stride + head_idx * HEAD_SIZE;
    Tdata *out_ptr = out_ + seq_idx * o_stride + head_idx * HEAD_SIZE;

    float q_reg[DIMS_PER_THREAD];
    float acc[DIMS_PER_THREAD];
#pragma unroll
    for (int i = 0; i < DIMS_PER_THREAD; ++i) {
        const int dim = lane * DIMS_PER_THREAD + i;
        q_reg[i] = static_cast<float>(q_ptr[dim]);
        acc[i] = 0.0f;
    }

#if defined(__CUDA_ARCH__)
    float2 q_reg2[DIMS_PER_THREAD / 2];
    if constexpr (std::is_same_v<Tdata, half>) {
        const int dim_base = lane * DIMS_PER_THREAD;
        const half2 *q2 = reinterpret_cast<const half2 *>(q_ptr + dim_base);
#pragma unroll
        for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
            q_reg2[j] = __half22float2(q2[j]);
        }
    }
    if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
        const int dim_base = lane * DIMS_PER_THREAD;
        const __nv_bfloat162 *q2 = reinterpret_cast<const __nv_bfloat162 *>(q_ptr + dim_base);
#pragma unroll
        for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
            q_reg2[j] = __bfloat1622float2(q2[j]);
        }
    }
#endif

    float m = -INFINITY;
    float l = 0.0f;
    const int pbs = static_cast<int>(page_block_size);

    // Iterate by blocks to avoid per-token division/mod and redundant block_table loads.
    // Note: Per-token cp.async prefetching is generally too fine-grained for decode and can regress.
    // We keep the warp kernel simple and reserve cp.async pipelining for CTA tile kernels.
    int t_base = 0;
    for (int logical_block = 0; t_base < seq_len; ++logical_block, t_base += pbs) {
        int physical_block = 0;
        if (lane == 0) {
            physical_block = static_cast<int>(block_table[logical_block]);
        }
        physical_block = __shfl_sync(0xffffffff, physical_block, 0);

        const Tdata *k_base = k_cache_ + physical_block * k_batch_stride + kv_head_idx * k_head_stride;
        const Tdata *v_base = v_cache_ + physical_block * v_batch_stride + kv_head_idx * v_head_stride;
        const int token_end = min(pbs, seq_len - t_base);

        for (int token_in_block = 0; token_in_block < token_end; ++token_in_block) {
            const int t = t_base + token_in_block;
            const Tdata *k_ptr = k_base + token_in_block * k_row_stride;
            const Tdata *v_ptr = v_base + token_in_block * v_row_stride;

            float qk = 0.0f;
#if defined(__CUDA_ARCH__)
            if constexpr (std::is_same_v<Tdata, half>) {
                const int dim_base = lane * DIMS_PER_THREAD;
                const half2 *k2 = reinterpret_cast<const half2 *>(k_ptr + dim_base);
#pragma unroll
                for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
                    const float2 qf = q_reg2[j];
                    const float2 kf = __half22float2(k2[j]);
                    qk += qf.x * kf.x + qf.y * kf.y;
                }
            } else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
                const int dim_base = lane * DIMS_PER_THREAD;
                const __nv_bfloat162 *k2 = reinterpret_cast<const __nv_bfloat162 *>(k_ptr + dim_base);
#pragma unroll
                for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
                    const float2 qf = q_reg2[j];
                    const float2 kf = __bfloat1622float2(k2[j]);
                    qk += qf.x * kf.x + qf.y * kf.y;
                }
            } else
#endif
            {
#pragma unroll
                for (int i = 0; i < DIMS_PER_THREAD; ++i) {
                    const int dim = lane * DIMS_PER_THREAD + i;
                    qk += q_reg[i] * static_cast<float>(k_ptr[dim]);
                }
            }
            qk = warpReduceSum(qk);

            float alpha = 1.0f;
            float beta = 0.0f;
            if (lane == 0) {
                float score = qk * scale_log2;
                if (alibi_slope != 0.0f) {
                    score += (alibi_slope * static_cast<float>(t - (seq_len - 1))) * kLog2e;
                }
                const float m_new = fmaxf(m, score);
                alpha = exp2f(m - m_new);
                beta = exp2f(score - m_new);
                l = l * alpha + beta;
                m = m_new;
            }
            alpha = __shfl_sync(0xffffffff, alpha, 0);
            beta = __shfl_sync(0xffffffff, beta, 0);

#if defined(__CUDA_ARCH__)
            if constexpr (std::is_same_v<Tdata, half>) {
                const int dim_base = lane * DIMS_PER_THREAD;
                const half2 *v2 = reinterpret_cast<const half2 *>(v_ptr + dim_base);
#pragma unroll
                for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
                    const float2 vf = __half22float2(v2[j]);
                    acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x;
                    acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y;
                }
            } else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
                const int dim_base = lane * DIMS_PER_THREAD;
                const __nv_bfloat162 *v2 = reinterpret_cast<const __nv_bfloat162 *>(v_ptr + dim_base);
#pragma unroll
                for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
                    const float2 vf = __bfloat1622float2(v2[j]);
                    acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x;
                    acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y;
                }
            } else
#endif
            {
#pragma unroll
                for (int i = 0; i < DIMS_PER_THREAD; ++i) {
                    const int dim = lane * DIMS_PER_THREAD + i;
                    const float v_val = static_cast<float>(v_ptr[dim]);
                    acc[i] = acc[i] * alpha + beta * v_val;
                }
            }
        }
    }

    float inv_l = 0.0f;
    if (lane == 0) {
        inv_l = 1.0f / (l + 1e-6f);
    }
    inv_l = __shfl_sync(0xffffffff, inv_l, 0);

#pragma unroll
    for (int i = 0; i < DIMS_PER_THREAD; ++i) {
        const int dim = lane * DIMS_PER_THREAD + i;
        const float o = acc[i] * inv_l;
        if constexpr (std::is_same_v<Tdata, half>) {
            out_ptr[dim] = __float2half_rn(o);
        } else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
            out_ptr[dim] = __float2bfloat16_rn(o);
        } else {
            out_ptr[dim] = static_cast<Tdata>(o);
        }
    }
}
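// Worked identity behind the exp2f/kLog2e trick used throughout this file (illustrative
// note, not part of this commit): since exp(x) = 2^(x * log2(e)), the softmax can be
// tracked entirely in the base-2 domain, which maps to the cheaper exp2f instruction:
//
//   exp(scale * qk - m) = exp2((scale * log2(e)) * qk - m'),  where m' = m * log2(e)
//
// so the kernels fold log2(e) into the score once (scale_log2 = scale * kLog2e) and keep
// the running maximum m already in that scaled domain.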
// Split-KV decode (FA2-style): each split scans a shard of KV and writes partial (m, l, acc)
// to workspace, then a combine kernel merges splits into final out.
template <typename Tindex, typename Tdata, int HEAD_SIZE>
__device__ void flashAttentionDecodeSplitKvWarpKernel(
    float *partial_acc, // [num_splits, num_seqs, num_heads, head_size]
    float *partial_m,   // [num_splits, num_seqs, num_heads]
    float *partial_l,   // [num_splits, num_seqs, num_heads]
    const Tdata *q_, const Tdata *k_cache_, const Tdata *v_cache_,
    const Tindex *block_tables_, const Tindex *cache_lens_, const float *alibi_slopes_,
    size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride,
    ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    int num_splits) {

    const int seq_idx = blockIdx.y;
    const int head_idx = blockIdx.x;
    const int split_idx = static_cast<int>(blockIdx.z);
    const int lane = threadIdx.x;
    constexpr int kWarpSize = 32;
    static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4.");
    static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32.");
    constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize;

    const int seq_len = static_cast<int>(cache_lens_[seq_idx]);
    if (seq_len <= 0 || num_splits <= 0) {
        return;
    }

    // Split the [0, seq_len) range into num_splits contiguous shards.
    const int shard = (seq_len + num_splits - 1) / num_splits;
    const int start = split_idx * shard;
    const int end = min(seq_len, start + shard);

    if (start >= end) {
        // Empty shard => write neutral element.
        const int n = gridDim.y * gridDim.x;
        const int idx = (split_idx * n + seq_idx * gridDim.x + head_idx);
        if (lane == 0) {
            partial_m[idx] = -INFINITY;
            partial_l[idx] = 0.0f;
        }
#pragma unroll
        for (int i = 0; i < DIMS_PER_THREAD; ++i) {
            const int dim = lane * DIMS_PER_THREAD + i;
            partial_acc[idx * HEAD_SIZE + dim] = 0.0f;
        }
        return;
    }

    const int num_heads = gridDim.x;
    const int num_queries_per_kv = num_heads / static_cast<int>(num_kv_heads);
    const int kv_head_idx = head_idx / num_queries_per_kv;
    const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
    constexpr float kLog2e = 1.4426950408889634f;
    const float scale_log2 = scale * kLog2e;

    const Tindex *block_table = block_tables_ + seq_idx * static_cast<int>(max_num_blocks_per_seq);
    const Tdata *q_ptr = q_ + seq_idx * q_stride + head_idx * HEAD_SIZE;

    float q_reg[DIMS_PER_THREAD];
    float acc[DIMS_PER_THREAD];
#pragma unroll
    for (int i = 0; i < DIMS_PER_THREAD; ++i) {
        const int dim = lane * DIMS_PER_THREAD + i;
        q_reg[i] = static_cast<float>(q_ptr[dim]);
        acc[i] = 0.0f;
    }

#if defined(__CUDA_ARCH__)
    float2 q_reg2[DIMS_PER_THREAD / 2];
    if constexpr (std::is_same_v<Tdata, half>) {
        const int dim_base = lane * DIMS_PER_THREAD;
        const half2 *q2 = reinterpret_cast<const half2 *>(q_ptr + dim_base);
#pragma unroll
        for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
            q_reg2[j] = __half22float2(q2[j]);
        }
    }
    if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
        const int dim_base = lane * DIMS_PER_THREAD;
        const __nv_bfloat162 *q2 = reinterpret_cast<const __nv_bfloat162 *>(q_ptr + dim_base);
#pragma unroll
        for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
            q_reg2[j] = __bfloat1622float2(q2[j]);
        }
    }
#endif

    float m = -INFINITY;
    float l = 0.0f;
    const int pbs = static_cast<int>(page_block_size);

    // Scan only [start, end).
    int t = start;
    int logical_block = t / pbs;
    int token_in_block = t - logical_block * pbs;
    for (; t < end; ++logical_block) {
        int physical_block = 0;
        if (lane == 0) {
            physical_block = static_cast<int>(block_table[logical_block]);
        }
        physical_block = __shfl_sync(0xffffffff, physical_block, 0);

        const Tdata *k_base = k_cache_ + physical_block * k_batch_stride + kv_head_idx * k_head_stride;
        const Tdata *v_base = v_cache_ + physical_block * v_batch_stride + kv_head_idx * v_head_stride;
        const int token_end = min(pbs, end - logical_block * pbs);

        for (; token_in_block < token_end && t < end; ++token_in_block, ++t) {
            const Tdata *k_ptr = k_base + token_in_block * k_row_stride;
            const Tdata *v_ptr = v_base + token_in_block * v_row_stride;

            float qk = 0.0f;
#if defined(__CUDA_ARCH__)
            if constexpr (std::is_same_v<Tdata, half>) {
                const int dim_base = lane * DIMS_PER_THREAD;
                const half2 *k2 = reinterpret_cast<const half2 *>(k_ptr + dim_base);
#pragma unroll
                for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
                    const float2 qf = q_reg2[j];
                    const float2 kf = __half22float2(k2[j]);
                    qk += qf.x * kf.x + qf.y * kf.y;
                }
            } else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
                const int dim_base = lane * DIMS_PER_THREAD;
                const __nv_bfloat162 *k2 = reinterpret_cast<const __nv_bfloat162 *>(k_ptr + dim_base);
#pragma unroll
                for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
                    const float2 qf = q_reg2[j];
                    const float2 kf = __bfloat1622float2(k2[j]);
                    qk += qf.x * kf.x + qf.y * kf.y;
                }
            } else
#endif
            {
#pragma unroll
                for (int i = 0; i < DIMS_PER_THREAD; ++i) {
                    const int dim = lane * DIMS_PER_THREAD + i;
                    qk += q_reg[i] * static_cast<float>(k_ptr[dim]);
                }
            }
            qk = warpReduceSum(qk);

            float alpha = 1.0f;
            float beta = 0.0f;
            if (lane == 0) {
                float score = qk * scale_log2;
                if (alibi_slope != 0.0f) {
                    score += (alibi_slope * static_cast<float>(t - (seq_len - 1))) * kLog2e;
                }
                const float m_new = fmaxf(m, score);
                alpha = exp2f(m - m_new);
                beta = exp2f(score - m_new);
                l = l * alpha + beta;
                m = m_new;
            }
            alpha = __shfl_sync(0xffffffff, alpha, 0);
            beta = __shfl_sync(0xffffffff, beta, 0);

#if defined(__CUDA_ARCH__)
            if constexpr (std::is_same_v<Tdata, half>) {
                const int dim_base = lane * DIMS_PER_THREAD;
                const half2 *v2 = reinterpret_cast<const half2 *>(v_ptr + dim_base);
#pragma unroll
                for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
                    const float2 vf = __half22float2(v2[j]);
                    acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x;
                    acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y;
                }
            } else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
                const int dim_base = lane * DIMS_PER_THREAD;
                const __nv_bfloat162 *v2 = reinterpret_cast<const __nv_bfloat162 *>(v_ptr + dim_base);
#pragma unroll
                for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
                    const float2 vf = __bfloat1622float2(v2[j]);
                    acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x;
                    acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y;
                }
            } else
#endif
            {
#pragma unroll
                for (int i = 0; i < DIMS_PER_THREAD; ++i) {
                    const int dim = lane * DIMS_PER_THREAD + i;
                    const float v_val = static_cast<float>(v_ptr[dim]);
                    acc[i] = acc[i] * alpha + beta * v_val;
                }
            }
        }
        token_in_block = 0;
    }

    const int n = gridDim.y * gridDim.x;
    const int idx = (split_idx * n + seq_idx * gridDim.x + head_idx);
    if (lane == 0) {
        partial_m[idx] = m;
        partial_l[idx] = l;
    }
#pragma unroll
    for (int i = 0; i < DIMS_PER_THREAD; ++i) {
        const int dim = lane * DIMS_PER_THREAD + i;
        partial_acc[idx * HEAD_SIZE + dim] = acc[i];
    }
}
template <typename Tdata, int HEAD_SIZE>
__device__ void flashAttentionDecodeSplitKvCombineWarpKernel(
    Tdata *out_,
    const float *partial_acc, // [num_splits, num_seqs, num_heads, head_size]
    const float *partial_m,   // [num_splits, num_seqs, num_heads]
    const float *partial_l,   // [num_splits, num_seqs, num_heads]
    int num_splits, ptrdiff_t o_stride) {

    const int seq_idx = blockIdx.y;
    const int head_idx = blockIdx.x;
    const int lane = threadIdx.x;
    constexpr int kWarpSize = 32;
    static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32.");
    constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize;

    const int n = gridDim.y * gridDim.x;
    const int base = (seq_idx * gridDim.x + head_idx);

    float m = -INFINITY;
    if (lane == 0) {
        for (int s = 0; s < num_splits; ++s) {
            m = fmaxf(m, partial_m[s * n + base]);
        }
    }
    m = __shfl_sync(0xffffffff, m, 0);

    float l = 0.0f;
    if (lane == 0) {
        for (int s = 0; s < num_splits; ++s) {
            const float ms = partial_m[s * n + base];
            const float ls = partial_l[s * n + base];
            if (ls > 0.0f) {
                l += ls * exp2f(ms - m);
            }
        }
    }
    l = __shfl_sync(0xffffffff, l, 0);
    const float inv_l = 1.0f / (l + 1e-6f);

    // Combine acc for each dim.
    Tdata *out_ptr = out_ + seq_idx * o_stride + head_idx * HEAD_SIZE;
#pragma unroll
    for (int i = 0; i < DIMS_PER_THREAD; ++i) {
        const int dim = lane * DIMS_PER_THREAD + i;
        float acc = 0.0f;
        for (int s = 0; s < num_splits; ++s) {
            const float ms = partial_m[s * n + base];
            const float w = exp2f(ms - m);
            acc += partial_acc[(s * n + base) * HEAD_SIZE + dim] * w;
        }
        const float o = acc * inv_l;
        if constexpr (std::is_same_v<Tdata, half>) {
            out_ptr[dim] = __float2half_rn(o);
        } else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
            out_ptr[dim] = __float2bfloat16_rn(o);
        } else {
            out_ptr[dim] = static_cast<Tdata>(o);
        }
    }
}
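// Reference for the split merge above (illustrative note, not part of this commit). With
// per-split partials (m_s, l_s, acc_s) produced in the base-2 domain, the combine kernel
// computes, per (seq, head):
//
//   m   = max_s m_s
//   l   = sum_s l_s * 2^(m_s - m)
//   out = ( sum_s acc_s * 2^(m_s - m) ) / l
//
// which is the usual log-sum-exp rescaling, so the split-KV path matches the single-pass
// kernels up to floating-point rounding.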
// Split-KV decode with a CTA tile kernel (FA2-style): each CTA scans a shard of KV,
// writes partial (m, l, acc) to workspace, then a combine kernel merges splits.
template <typename Tindex, typename Tdata, int HEAD_SIZE, int CTA_THREADS, int TOKENS_PER_TILE>
__device__ void flashAttentionDecodeSplitKvCtaKernel(
    float *partial_acc, // [num_splits, num_seqs, num_heads, head_size]
    float *partial_m,   // [num_splits, num_seqs, num_heads]
    float *partial_l,   // [num_splits, num_seqs, num_heads]
    const Tdata *q_, const Tdata *k_cache_, const Tdata *v_cache_,
    const Tindex *block_tables_, const Tindex *cache_lens_, const float *alibi_slopes_,
    size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride,
    ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    int num_splits) {

    constexpr int kWarpSize = 32;
    static_assert(CTA_THREADS % kWarpSize == 0, "CTA_THREADS must be a multiple of 32.");
    static_assert(TOKENS_PER_TILE > 0 && TOKENS_PER_TILE <= 16, "TOKENS_PER_TILE should stay small.");
    constexpr int NUM_WARPS = CTA_THREADS / kWarpSize;
    static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4.");
    static_assert(HEAD_SIZE % CTA_THREADS == 0, "HEAD_SIZE must be divisible by CTA_THREADS.");
    constexpr int kPack = HEAD_SIZE / CTA_THREADS; // 2 (64@32t, 128@64t) or 4 (128@32t)
    static_assert(kPack == 2 || kPack == 4, "v0.4 split-kv CTA kernel supports kPack=2/4 only.");
    constexpr int kPackedDims = CTA_THREADS;
    constexpr int kComputeWarps = (kPackedDims + kWarpSize - 1) / kWarpSize;

    const int seq_idx = blockIdx.y;
    const int head_idx = blockIdx.x;
    const int split_idx = static_cast<int>(blockIdx.z);
    const int tid = threadIdx.x;
    const int lane = tid % kWarpSize;
    const int warp_id = tid / kWarpSize;

    const int seq_len = static_cast<int>(cache_lens_[seq_idx]);
    if (seq_len <= 0 || num_splits <= 0) {
        return;
    }

    // Split the [0, seq_len) range into num_splits contiguous shards.
    const int shard = (seq_len + num_splits - 1) / num_splits;
    const int start = split_idx * shard;
    const int end = min(seq_len, start + shard);
    const int n = gridDim.y * gridDim.x;
    const int idx = (split_idx * n + seq_idx * gridDim.x + head_idx);

    if (start >= end) {
        // Empty shard => write neutral element.
        if (tid == 0) {
            partial_m[idx] = -INFINITY;
            partial_l[idx] = 0.0f;
        }
        const int dim = tid * kPack;
        if constexpr (kPack == 2) {
            partial_acc[idx * HEAD_SIZE + dim + 0] = 0.0f;
            partial_acc[idx * HEAD_SIZE + dim + 1] = 0.0f;
        } else {
            partial_acc[idx * HEAD_SIZE + dim + 0] = 0.0f;
            partial_acc[idx * HEAD_SIZE + dim + 1] = 0.0f;
            partial_acc[idx * HEAD_SIZE + dim + 2] = 0.0f;
            partial_acc[idx * HEAD_SIZE + dim + 3] = 0.0f;
        }
        return;
    }

    const int num_heads = gridDim.x;
    const int num_queries_per_kv = num_heads / static_cast<int>(num_kv_heads);
    const int kv_head_idx = head_idx / num_queries_per_kv;
    const Tindex *block_table = block_tables_ + seq_idx * static_cast<int>(max_num_blocks_per_seq);
    const Tdata *q_ptr = q_ + seq_idx * q_stride + head_idx * HEAD_SIZE;
    const int dim = tid * kPack;

    float q0 = 0.0f, q1 = 0.0f, q2 = 0.0f, q3 = 0.0f;
#if defined(__CUDA_ARCH__)
    if constexpr (std::is_same_v<Tdata, half>) {
        if constexpr (kPack == 2) {
            const half2 qh2 = *reinterpret_cast<const half2 *>(q_ptr + dim);
            const float2 qf = __half22float2(qh2);
            q0 = qf.x; q1 = qf.y;
        } else {
            const half2 qh2_0 = *reinterpret_cast<const half2 *>(q_ptr + dim + 0);
            const half2 qh2_1 = *reinterpret_cast<const half2 *>(q_ptr + dim + 2);
            const float2 qf0 = __half22float2(qh2_0);
            const float2 qf1 = __half22float2(qh2_1);
            q0 = qf0.x; q1 = qf0.y; q2 = qf1.x; q3 = qf1.y;
        }
    } else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
        if constexpr (kPack == 2) {
            const __nv_bfloat162 qb2 = *reinterpret_cast<const __nv_bfloat162 *>(q_ptr + dim);
            const float2 qf = __bfloat1622float2(qb2);
            q0 = qf.x; q1 = qf.y;
        } else {
            const __nv_bfloat162 qb2_0 = *reinterpret_cast<const __nv_bfloat162 *>(q_ptr + dim + 0);
            const __nv_bfloat162 qb2_1 = *reinterpret_cast<const __nv_bfloat162 *>(q_ptr + dim + 2);
            const float2 qf0 = __bfloat1622float2(qb2_0);
            const float2 qf1 = __bfloat1622float2(qb2_1);
            q0 = qf0.x; q1 = qf0.y; q2 = qf1.x; q3 = qf1.y;
        }
    } else
#endif
    {
        q0 = static_cast<float>(q_ptr[dim + 0]);
        q1 = static_cast<float>(q_ptr[dim + 1]);
        if constexpr (kPack == 4) {
            q2 = static_cast<float>(q_ptr[dim + 2]);
            q3 = static_cast<float>(q_ptr[dim + 3]);
        }
    }

    float acc0 = 0.0f, acc1 = 0.0f, acc2 = 0.0f, acc3 = 0.0f;
    float m = -INFINITY;
    float l = 0.0f;

    __shared__ float warp_sums[TOKENS_PER_TILE][kComputeWarps];
    __shared__ float alpha_shared;
    __shared__ float weights_shared[TOKENS_PER_TILE];

    const int pbs = static_cast<int>(page_block_size);
    const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
    constexpr float kLog2e = 1.4426950408889634f;
    const float scale_log2 = scale * kLog2e;

    static_assert(sizeof(Tdata) == 2, "CTA split-kv kernel assumes fp16/bf16.");
    constexpr int CHUNK_ELEMS = 8; // 8 * 2 bytes = 16 bytes.
    constexpr int CHUNKS = HEAD_SIZE / CHUNK_ELEMS;
    constexpr int LOADS_PER_TILE = CHUNKS * TOKENS_PER_TILE;
    constexpr int STAGES = 3;
    __shared__ __align__(16) Tdata sh_k[STAGES][TOKENS_PER_TILE][HEAD_SIZE];
    __shared__ __align__(16) Tdata sh_v[STAGES][TOKENS_PER_TILE][HEAD_SIZE];

    const int first_block = start / pbs;
    const int last_block = (end - 1) / pbs;

    for (int logical_block = first_block; logical_block <= last_block; ++logical_block) {
        const int physical_block = static_cast<int>(block_table[logical_block]);
        const Tdata *k_base = k_cache_ + physical_block * k_batch_stride + kv_head_idx * k_head_stride;
        const Tdata *v_base = v_cache_ + physical_block * v_batch_stride + kv_head_idx * v_head_stride;
        const int t_base = logical_block * pbs;
        const int token_begin = (logical_block == first_block) ? (start - t_base) : 0;
        const int token_end = (logical_block == last_block) ? (end - t_base) : pbs;
        const int token_count = token_end - token_begin;
        if (token_count <= 0) {
            continue;
        }
        const int num_tiles = (token_count + TOKENS_PER_TILE - 1) / TOKENS_PER_TILE;

        int pending_groups = 0;
        const int preload = min(STAGES, num_tiles);
        for (int ti = 0; ti < preload; ++ti) {
            const int token_in_block = token_begin + ti * TOKENS_PER_TILE;
            const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block);
            for (int li = tid; li < LOADS_PER_TILE; li += CTA_THREADS) {
                const int tok = li / CHUNKS;
                const int chunk = li - tok * CHUNKS;
                const int off = chunk * CHUNK_ELEMS;
                if (tok < tile_n) {
                    const Tdata *k_src = k_base + (token_in_block + tok) * k_row_stride + off;
                    const Tdata *v_src = v_base + (token_in_block + tok) * v_row_stride + off;
                    cpAsyncCaSharedGlobal16(&sh_k[ti][tok][off], k_src);
                    cpAsyncCaSharedGlobal16(&sh_v[ti][tok][off], v_src);
                } else {
                    reinterpret_cast<uint4 *>(&sh_k[ti][tok][off])[0] = make_uint4(0, 0, 0, 0);
                    reinterpret_cast<uint4 *>(&sh_v[ti][tok][off])[0] = make_uint4(0, 0, 0, 0);
                }
            }
            cpAsyncCommit();
            ++pending_groups;
        }

        int desired_pending = pending_groups - 1;
        if (desired_pending < 0) {
            desired_pending = 0;
        }
        if (desired_pending > (STAGES - 1)) {
            desired_pending = (STAGES - 1);
        }
        cpAsyncWaitGroupRt(desired_pending);
        pending_groups = desired_pending;
        if constexpr (NUM_WARPS == 1) {
            __syncwarp();
        } else {
            __syncthreads();
        }

        for (int tile_idx = 0; tile_idx < num_tiles; ++tile_idx) {
            const int buf = tile_idx % STAGES;
            const int token_in_block = token_begin + tile_idx * TOKENS_PER_TILE;
            const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block);

            float partial[TOKENS_PER_TILE];
#pragma unroll
            for (int j = 0; j < TOKENS_PER_TILE; ++j) {
                if (j < tile_n) {
                    float k0 = 0.0f, k1 = 0.0f, k2 = 0.0f, k3 = 0.0f;
#if defined(__CUDA_ARCH__)
                    if constexpr (std::is_same_v<Tdata, half>) {
                        if constexpr (kPack == 2) {
                            const half2 kh2 = *reinterpret_cast<const half2 *>(&sh_k[buf][j][dim]);
                            const float2 kf = __half22float2(kh2);
                            k0 = kf.x; k1 = kf.y;
                        } else {
                            const half2 kh2_0 = *reinterpret_cast<const half2 *>(&sh_k[buf][j][dim + 0]);
                            const half2 kh2_1 = *reinterpret_cast<const half2 *>(&sh_k[buf][j][dim + 2]);
                            const float2 kf0 = __half22float2(kh2_0);
                            const float2 kf1 = __half22float2(kh2_1);
                            k0 = kf0.x; k1 = kf0.y; k2 = kf1.x; k3 = kf1.y;
                        }
                    } else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
                        if constexpr (kPack == 2) {
                            const __nv_bfloat162 kb2 = *reinterpret_cast<const __nv_bfloat162 *>(&sh_k[buf][j][dim]);
                            const float2 kf = __bfloat1622float2(kb2);
                            k0 = kf.x; k1 = kf.y;
                        } else {
                            const __nv_bfloat162 kb2_0 = *reinterpret_cast<const __nv_bfloat162 *>(&sh_k[buf][j][dim + 0]);
                            const __nv_bfloat162 kb2_1 = *reinterpret_cast<const __nv_bfloat162 *>(&sh_k[buf][j][dim + 2]);
                            const float2 kf0 = __bfloat1622float2(kb2_0);
                            const float2 kf1 = __bfloat1622float2(kb2_1);
                            k0 = kf0.x; k1 = kf0.y; k2 = kf1.x; k3 = kf1.y;
                        }
                    } else
#endif
                    {
                        k0 = static_cast<float>(sh_k[buf][j][dim + 0]);
                        k1 = static_cast<float>(sh_k[buf][j][dim + 1]);
                        if constexpr (kPack == 4) {
                            k2 = static_cast<float>(sh_k[buf][j][dim + 2]);
                            k3 = static_cast<float>(sh_k[buf][j][dim + 3]);
                        }
                    }
                    if constexpr (kPack == 2) {
                        partial[j] = fmaf(q0, k0, q1 * k1);
                    } else {
                        partial[j] = fmaf(q0, k0, fmaf(q1, k1, fmaf(q2, k2, q3 * k3)));
                    }
                } else {
                    partial[j] = 0.0f;
                }
            }

#pragma unroll
            for (int j = 0; j < TOKENS_PER_TILE; ++j) {
                const float sum = warpReduceSum(partial[j]);
                if (lane == 0 && warp_id < kComputeWarps) {
                    warp_sums[j][warp_id] = sum;
                }
            }
            if constexpr (NUM_WARPS == 1) {
                __syncwarp();
            } else {
                __syncthreads();
            }

            if (warp_id == 0) {
                float score = -INFINITY;
                if (lane < TOKENS_PER_TILE && lane < tile_n) {
                    float qk = 0.0f;
#pragma unroll
                    for (int w = 0; w < kComputeWarps; ++w) {
                        qk += warp_sums[lane][w];
                    }
                    const int t = t_base + token_in_block + lane;
                    score = qk * scale_log2;
                    if (alibi_slope != 0.0f) {
                        score += (alibi_slope * static_cast<float>(t - (seq_len - 1))) * kLog2e;
                    }
                }
                float tile_max = warpReduceMax(score);
                tile_max = __shfl_sync(0xffffffff, tile_max, 0);
                float m_new = 0.0f;
                if (lane == 0) {
                    m_new = fmaxf(m, tile_max);
                }
                m_new = __shfl_sync(0xffffffff, m_new, 0);
                float w = 0.0f;
                if (lane < TOKENS_PER_TILE && lane < tile_n) {
                    w = exp2f(score - m_new);
                }
                if (lane < TOKENS_PER_TILE) {
                    weights_shared[lane] = (lane < tile_n) ? w : 0.0f;
                }
                const float tile_sum = warpReduceSum(w);
                if (lane == 0) {
                    const float alpha = exp2f(m - m_new);
                    alpha_shared = alpha;
                    l = l * alpha + tile_sum;
                    m = m_new;
                }
            }
            if constexpr (NUM_WARPS == 1) {
                __syncwarp();
            } else {
                __syncthreads();
            }

            const float alpha = alpha_shared;
            float sum_wv0 = 0.0f, sum_wv1 = 0.0f, sum_wv2 = 0.0f, sum_wv3 = 0.0f;
#pragma unroll
            for (int j = 0; j < TOKENS_PER_TILE; ++j) {
                const float w = weights_shared[j];
                float v0 = 0.0f, v1 = 0.0f, v2 = 0.0f, v3 = 0.0f;
#if defined(__CUDA_ARCH__)
                if constexpr (std::is_same_v<Tdata, half>) {
                    if constexpr (kPack == 2) {
                        const half2 vh2 = *reinterpret_cast<const half2 *>(&sh_v[buf][j][dim]);
                        const float2 vf = __half22float2(vh2);
                        v0 = vf.x; v1 = vf.y;
                    } else {
                        const half2 vh2_0 = *reinterpret_cast<const half2 *>(&sh_v[buf][j][dim + 0]);
                        const half2 vh2_1 = *reinterpret_cast<const half2 *>(&sh_v[buf][j][dim + 2]);
                        const float2 vf0 = __half22float2(vh2_0);
                        const float2 vf1 = __half22float2(vh2_1);
                        v0 = vf0.x; v1 = vf0.y; v2 = vf1.x; v3 = vf1.y;
                    }
                } else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
                    if constexpr (kPack == 2) {
                        const __nv_bfloat162 vb2 = *reinterpret_cast<const __nv_bfloat162 *>(&sh_v[buf][j][dim]);
                        const float2 vf = __bfloat1622float2(vb2);
                        v0 = vf.x; v1 = vf.y;
                    } else {
                        const __nv_bfloat162 vb2_0 = *reinterpret_cast<const __nv_bfloat162 *>(&sh_v[buf][j][dim + 0]);
                        const __nv_bfloat162 vb2_1 = *reinterpret_cast<const __nv_bfloat162 *>(&sh_v[buf][j][dim + 2]);
                        const float2 vf0 = __bfloat1622float2(vb2_0);
                        const float2 vf1 = __bfloat1622float2(vb2_1);
                        v0 = vf0.x; v1 = vf0.y; v2 = vf1.x; v3 = vf1.y;
                    }
                } else
#endif
                {
                    v0 = static_cast<float>(sh_v[buf][j][dim + 0]);
                    v1 = static_cast<float>(sh_v[buf][j][dim + 1]);
                    if constexpr (kPack == 4) {
                        v2 = static_cast<float>(sh_v[buf][j][dim + 2]);
                        v3 = static_cast<float>(sh_v[buf][j][dim + 3]);
                    }
                }
                sum_wv0 = fmaf(w, v0, sum_wv0);
                sum_wv1 = fmaf(w, v1, sum_wv1);
                if constexpr (kPack == 4) {
                    sum_wv2 = fmaf(w, v2, sum_wv2);
                    sum_wv3 = fmaf(w, v3, sum_wv3);
                }
            }
            acc0 = acc0 * alpha + sum_wv0;
            acc1 = acc1 * alpha + sum_wv1;
            if constexpr (kPack == 4) {
                acc2 = acc2 * alpha + sum_wv2;
                acc3 = acc3 * alpha + sum_wv3;
            }

            const int prefetch_tile = tile_idx + STAGES;
            if (prefetch_tile < num_tiles) {
                const int token_prefetch = token_begin + prefetch_tile * TOKENS_PER_TILE;
                const int prefetch_n = min(TOKENS_PER_TILE, token_end - token_prefetch);
                for (int li = tid; li < LOADS_PER_TILE; li += CTA_THREADS) {
                    const int tok = li / CHUNKS;
                    const int chunk = li - tok * CHUNKS;
                    const int off = chunk * CHUNK_ELEMS;
                    if (tok < prefetch_n) {
                        const Tdata *k_src = k_base + (token_prefetch + tok) * k_row_stride + off;
                        const Tdata *v_src = v_base + (token_prefetch + tok) * v_row_stride + off;
                        cpAsyncCaSharedGlobal16(&sh_k[buf][tok][off], k_src);
                        cpAsyncCaSharedGlobal16(&sh_v[buf][tok][off], v_src);
                    } else {
                        reinterpret_cast<uint4 *>(&sh_k[buf][tok][off])[0] = make_uint4(0, 0, 0, 0);
                        reinterpret_cast<uint4 *>(&sh_v[buf][tok][off])[0] = make_uint4(0, 0, 0, 0);
                    }
                }
                cpAsyncCommit();
                ++pending_groups;
            }

            if (tile_idx + 1 < num_tiles) {
                int desired_pending2 = pending_groups - 1;
                if (desired_pending2 < 0) {
                    desired_pending2 = 0;
                }
                if (desired_pending2 > (STAGES - 1)) {
                    desired_pending2 = (STAGES - 1);
                }
                cpAsyncWaitGroupRt(desired_pending2);
                pending_groups = desired_pending2;
                if constexpr (NUM_WARPS == 1) {
                    __syncwarp();
                } else {
                    __syncthreads();
                }
            }
        }

        cpAsyncWaitAll();
        if constexpr (NUM_WARPS == 1) {
            __syncwarp();
        } else {
            __syncthreads();
        }
    }

    if (tid == 0) {
        partial_m[idx] = m;
        partial_l[idx] = l;
    }
    if constexpr (kPack == 2) {
        partial_acc[idx * HEAD_SIZE + dim + 0] = acc0;
        partial_acc[idx * HEAD_SIZE + dim + 1] = acc1;
    } else {
        partial_acc[idx * HEAD_SIZE + dim + 0] = acc0;
        partial_acc[idx * HEAD_SIZE + dim + 1] = acc1;
        partial_acc[idx * HEAD_SIZE + dim + 2] = acc2;
        partial_acc[idx * HEAD_SIZE + dim + 3] = acc3;
    }
}
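// Workspace sizing implied by the partial layouts above (illustrative note, not part of
// this commit). A launcher that uses the split-KV path needs, per launch:
//
//   size_t acc_elems = (size_t)num_splits * num_seqs * num_heads * HEAD_SIZE; // partial_acc (float)
//   size_t ml_elems  = (size_t)num_splits * num_seqs * num_heads;             // partial_m and partial_l (float each)
//
// where `num_splits`, `num_seqs`, and `num_heads` are the launch-time values that define
// gridDim.z, gridDim.y, and gridDim.x respectively.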
template <typename Tindex, typename Tdata, int HEAD_SIZE>
__device__ void flashAttentionDecodeCtaPipelinedKernel(
    Tdata *out_, const Tdata *q_, const Tdata *k_cache_, const Tdata *v_cache_,
    const Tindex *block_tables_, const Tindex *cache_lens_, const float *alibi_slopes_,
    size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride,
    ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride) {

    constexpr int kWarpSize = 32;
    static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4.");
    static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32.");
    constexpr int NUM_WARPS = HEAD_SIZE / kWarpSize;

    const int seq_idx = blockIdx.y;
    const int head_idx = blockIdx.x;
    const int tid = threadIdx.x;
    const int lane = tid % kWarpSize;
    const int warp_id = tid / kWarpSize;

    const int seq_len = static_cast<int>(cache_lens_[seq_idx]);
    if (seq_len <= 0) {
        return;
    }

    const int num_heads = gridDim.x;
    const int num_queries_per_kv = num_heads / static_cast<int>(num_kv_heads);
    const int kv_head_idx = head_idx / num_queries_per_kv;
    const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
    constexpr float kLog2e = 1.4426950408889634f;
    const float scale_log2 = scale * kLog2e;
    const Tindex *block_table = block_tables_ + seq_idx * static_cast<int>(max_num_blocks_per_seq);
    const Tdata *q_ptr = q_ + seq_idx * q_stride + head_idx * HEAD_SIZE;
    Tdata *out_ptr = out_ + seq_idx * o_stride + head_idx * HEAD_SIZE;

    const float q_val = static_cast<float>(q_ptr[tid]);
    float acc = 0.0f;
    float m = -INFINITY;
    float l = 0.0f;

    __shared__ Tdata sh_k[2][HEAD_SIZE];
    __shared__ Tdata sh_v[2][HEAD_SIZE];
    __shared__ float warp_sums[NUM_WARPS];
    __shared__ float alpha_s;
    __shared__ float beta_s;
    __shared__ int physical_block_s;

    constexpr int CHUNK_ELEMS = 8; // 8 * 2 bytes = 16 bytes.
    constexpr int CHUNKS = HEAD_SIZE / CHUNK_ELEMS;
    const int pbs = static_cast<int>(page_block_size);

    // Prefetch the very first token.
    int buf = 0;
    int t_base = 0;
    int token_in_block = 0;
    int logical_block = 0;
    {
        if (tid == 0) {
            physical_block_s = static_cast<int>(block_table[0]);
        }
        __syncthreads();
        const Tdata *k_base = k_cache_ + physical_block_s * k_batch_stride + kv_head_idx * k_head_stride;
        const Tdata *v_base = v_cache_ + physical_block_s * v_batch_stride + kv_head_idx * v_head_stride;
        if (tid < CHUNKS) {
            const int off = tid * CHUNK_ELEMS;
            cpAsyncCaSharedGlobal16(&sh_k[buf][off], (k_base + 0 * k_row_stride) + off);
            cpAsyncCaSharedGlobal16(&sh_v[buf][off], (v_base + 0 * v_row_stride) + off);
        }
        cpAsyncCommit();
        cpAsyncWaitAll();
        __syncthreads();
    }

    for (int t = 0; t < seq_len; ++t) {
        // Compute current token location within paged KV.
        const int next_t = t + 1;
        const bool has_next = next_t < seq_len;
        if (has_next) {
            const int next_block = next_t / pbs;
            const int next_in_block = next_t - next_block * pbs;
            if (next_block != logical_block) {
                logical_block = next_block;
                if (tid == 0) {
                    physical_block_s = static_cast<int>(block_table[logical_block]);
                }
                __syncthreads();
            }
            const Tdata *k_base = k_cache_ + physical_block_s * k_batch_stride + kv_head_idx * k_head_stride;
            const Tdata *v_base = v_cache_ + physical_block_s * v_batch_stride + kv_head_idx * v_head_stride;
            const Tdata *k_src = k_base + next_in_block * k_row_stride;
            const Tdata *v_src = v_base + next_in_block * v_row_stride;
            if (tid < CHUNKS) {
                const int off = tid * CHUNK_ELEMS;
                cpAsyncCaSharedGlobal16(&sh_k[buf ^ 1][off], k_src + off);
                cpAsyncCaSharedGlobal16(&sh_v[buf ^ 1][off], v_src + off);
            }
            cpAsyncCommit();
        }

        // Dot: each thread handles one dim, reduce across head dim.
        const float k_val = static_cast<float>(sh_k[buf][tid]);
        float partial = q_val * k_val;
        float warp_sum = warpReduceSum(partial);
        if (lane == 0) {
            warp_sums[warp_id] = warp_sum;
        }
        __syncthreads();

        float qk = 0.0f;
        if (warp_id == 0) {
            float v = (lane < NUM_WARPS) ? warp_sums[lane] : 0.0f;
            v = warpReduceSum(v);
            if (lane == 0) {
                qk = v;
                float score = qk * scale_log2;
                if (alibi_slope != 0.0f) {
                    score += (alibi_slope * static_cast<float>(t - (seq_len - 1))) * kLog2e;
                }
                const float m_new = fmaxf(m, score);
                const float alpha = exp2f(m - m_new);
                const float beta = exp2f(score - m_new);
                l = l * alpha + beta;
                m = m_new;
                alpha_s = alpha;
                beta_s = beta;
            }
        }
        __syncthreads();

        const float alpha = alpha_s;
        const float beta = beta_s;
        const float v_val = static_cast<float>(sh_v[buf][tid]);
        acc = acc * alpha + beta * v_val;

        if (has_next) {
            cpAsyncWaitAll();
            __syncthreads();
            buf ^= 1;
        }
    }

    __shared__ float inv_l_s;
    if (tid == 0) {
        inv_l_s = 1.0f / (l + 1e-6f);
    }
    __syncthreads();
    out_ptr[tid] = static_cast<Tdata>(acc * inv_l_s);
}
template <typename Tindex, typename Tdata, int HEAD_SIZE, int CTA_THREADS, int TOKENS_PER_TILE>
__device__ void flashAttentionDecodeCtaKernel(
    Tdata *out_, const Tdata *q_, const Tdata *k_cache_, const Tdata *v_cache_,
    const Tindex *block_tables_, const Tindex *cache_lens_, const float *alibi_slopes_,
    size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride,
    ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride) {

    constexpr int kWarpSize = 32;
    static_assert(CTA_THREADS % kWarpSize == 0, "CTA_THREADS must be a multiple of 32.");
    static_assert(TOKENS_PER_TILE > 0 && TOKENS_PER_TILE <= 16, "TOKENS_PER_TILE should stay small.");
    constexpr int NUM_WARPS = CTA_THREADS / kWarpSize;

    const int seq_idx = blockIdx.y;
    const int head_idx = blockIdx.x;
    const int tid = threadIdx.x;
    const int lane = tid % kWarpSize;
    const int warp_id = tid / kWarpSize;

    // Each thread owns a small packed vector of head dims. This lets us shrink the
    // CTA to 1-2 warps and reduce block-wide synchronization overhead.
    static_assert(HEAD_SIZE % CTA_THREADS == 0, "HEAD_SIZE must be divisible by CTA_THREADS.");
    constexpr int kPack = HEAD_SIZE / CTA_THREADS; // 2 (64@32t, 128@64t) or 4 (128@32t)
    static_assert(kPack == 2 || kPack == 4, "v0.4 CTA tile kernel supports kPack=2/4 only.");
    constexpr int kPackedDims = CTA_THREADS;
    constexpr int kComputeWarps = (kPackedDims + kWarpSize - 1) / kWarpSize;
    const int dim = tid * kPack;

    const int seq_len = static_cast<int>(cache_lens_[seq_idx]);
    if (seq_len <= 0) {
        return;
    }

    const int num_heads = gridDim.x;
    const int num_queries_per_kv = num_heads / static_cast<int>(num_kv_heads);
    const int kv_head_idx = head_idx / num_queries_per_kv;
    const Tindex *block_table = block_tables_ + seq_idx * static_cast<int>(max_num_blocks_per_seq);

    // q/out are [num_seqs, num_heads, head_size]
    const Tdata *q_ptr = q_ + seq_idx * q_stride + head_idx * HEAD_SIZE;
    Tdata *out_ptr = out_ + seq_idx * o_stride + head_idx * HEAD_SIZE;

    float q0 = 0.0f;
    float q1 = 0.0f;
    float q2 = 0.0f;
    float q3 = 0.0f;
#if defined(__CUDA_ARCH__)
    if constexpr (std::is_same_v<Tdata, half>) {
        if constexpr (kPack == 2) {
            const half2 qh2 = *reinterpret_cast<const half2 *>(q_ptr + dim);
            const float2 qf = __half22float2(qh2);
            q0 = qf.x; q1 = qf.y;
        } else {
            const half2 qh2_0 = *reinterpret_cast<const half2 *>(q_ptr + dim + 0);
            const half2 qh2_1 = *reinterpret_cast<const half2 *>(q_ptr + dim + 2);
            const float2 qf0 = __half22float2(qh2_0);
            const float2 qf1 = __half22float2(qh2_1);
            q0 = qf0.x; q1 = qf0.y; q2 = qf1.x; q3 = qf1.y;
        }
    } else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
        if constexpr (kPack == 2) {
            const __nv_bfloat162 qb2 = *reinterpret_cast<const __nv_bfloat162 *>(q_ptr + dim);
            const float2 qf = __bfloat1622float2(qb2);
            q0 = qf.x; q1 = qf.y;
        } else {
            const __nv_bfloat162 qb2_0 = *reinterpret_cast<const __nv_bfloat162 *>(q_ptr + dim + 0);
            const __nv_bfloat162 qb2_1 = *reinterpret_cast<const __nv_bfloat162 *>(q_ptr + dim + 2);
            const float2 qf0 = __bfloat1622float2(qb2_0);
            const float2 qf1 = __bfloat1622float2(qb2_1);
            q0 = qf0.x; q1 = qf0.y; q2 = qf1.x; q3 = qf1.y;
        }
    } else
#endif
    {
        q0 = static_cast<float>(q_ptr[dim + 0]);
        q1 = static_cast<float>(q_ptr[dim + 1]);
        if constexpr (kPack == 4) {
            q2 = static_cast<float>(q_ptr[dim + 2]);
            q3 = static_cast<float>(q_ptr[dim + 3]);
        }
    }

    float acc0 = 0.0f;
    float acc1 = 0.0f;
    float acc2 = 0.0f;
    float acc3 = 0.0f;
    float m = -INFINITY;
    float l = 0.0f;

    // Only the compute warps contribute QK partial sums. Keeping this array
    // compact reduces shared-memory traffic and bank pressure.
    __shared__ float warp_sums[TOKENS_PER_TILE][kComputeWarps];
    __shared__ float alpha_shared;
    __shared__ float weights_shared[TOKENS_PER_TILE];

    const int pbs = static_cast<int>(page_block_size);
    static_assert(sizeof(Tdata) == 2, "CTA tile kernel assumes 16B chunks map to 8 elements for fp16/bf16.");
    constexpr int CHUNK_ELEMS = 8; // 8 * 2 bytes = 16 bytes.
    constexpr int CHUNKS = HEAD_SIZE / CHUNK_ELEMS;
    constexpr int LOADS_PER_TILE = CHUNKS * TOKENS_PER_TILE;
    // Multi-stage cp.async pipeline. Using >= 3 stages allows us to keep
    // multiple groups in-flight and overlap global->shared copies with compute.
    constexpr int STAGES = 3;
    __shared__ __align__(16) Tdata sh_k[STAGES][TOKENS_PER_TILE][HEAD_SIZE];
    __shared__ __align__(16) Tdata sh_v[STAGES][TOKENS_PER_TILE][HEAD_SIZE];

    const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
    constexpr float kLog2e = 1.4426950408889634f;
    const float scale_log2 = scale * kLog2e;

    int t_base = 0;
    for (int logical_block = 0; t_base < seq_len; ++logical_block, t_base += pbs) {
        const int physical_block = static_cast<int>(block_table[logical_block]);
        const Tdata *k_base = k_cache_ + physical_block * k_batch_stride + kv_head_idx * k_head_stride;
        const Tdata *v_base = v_cache_ + physical_block * v_batch_stride + kv_head_idx * v_head_stride;
        const int token_end = min(pbs, seq_len - t_base);
        const int num_tiles = (token_end + TOKENS_PER_TILE - 1) / TOKENS_PER_TILE;
        if (num_tiles <= 0) {
            continue;
        }

        int pending_groups = 0;
        const int preload = min(STAGES, num_tiles);
        for (int ti = 0; ti < preload; ++ti) {
            const int token_in_block = ti * TOKENS_PER_TILE;
            const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block);
            for (int li = tid; li < LOADS_PER_TILE; li += CTA_THREADS) {
                const int tok = li / CHUNKS;
                const int chunk = li - tok * CHUNKS;
                const int off = chunk * CHUNK_ELEMS;
                if (tok < tile_n) {
                    const Tdata *k_src = k_base + (token_in_block + tok) * k_row_stride + off;
                    const Tdata *v_src = v_base + (token_in_block + tok) * v_row_stride + off;
                    cpAsyncCaSharedGlobal16(&sh_k[ti][tok][off], k_src);
                    cpAsyncCaSharedGlobal16(&sh_v[ti][tok][off], v_src);
                } else {
                    reinterpret_cast<uint4 *>(&sh_k[ti][tok][off])[0] = make_uint4(0, 0, 0, 0);
                    reinterpret_cast<uint4 *>(&sh_v[ti][tok][off])[0] = make_uint4(0, 0, 0, 0);
                }
            }
            cpAsyncCommit();
            ++pending_groups;
        }

        // Ensure tile 0 is ready. We want to keep up to (STAGES - 1) groups
        // in flight for overlap, but still make forward progress in the tail
        // when we stop issuing new prefetch groups.
        int desired_pending = pending_groups - 1;
        if (desired_pending < 0) {
            desired_pending = 0;
        }
        if (desired_pending > (STAGES - 1)) {
            desired_pending = (STAGES - 1);
        }
        cpAsyncWaitGroupRt(desired_pending);
        pending_groups = desired_pending;
        if constexpr (NUM_WARPS == 1) {
            __syncwarp();
        } else {
            __syncthreads();
        }

        for (int tile_idx = 0; tile_idx < num_tiles; ++tile_idx) {
            const int buf = tile_idx % STAGES;
            const int token_in_block = tile_idx * TOKENS_PER_TILE;
            const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block);

            float partial[TOKENS_PER_TILE];
#pragma unroll
            for (int j = 0; j < TOKENS_PER_TILE; ++j) {
                if (j < tile_n) {
                    float k0 = 0.0f;
                    float k1 = 0.0f;
                    float k2 = 0.0f;
                    float k3 = 0.0f;
#if defined(__CUDA_ARCH__)
                    if constexpr (std::is_same_v<Tdata, half>) {
                        if constexpr (kPack == 2) {
                            const half2 kh2 = *reinterpret_cast<const half2 *>(&sh_k[buf][j][dim]);
                            const float2 kf = __half22float2(kh2);
                            k0 = kf.x; k1 = kf.y;
                        } else {
                            const half2 kh2_0 = *reinterpret_cast<const half2 *>(&sh_k[buf][j][dim + 0]);
                            const half2 kh2_1 = *reinterpret_cast<const half2 *>(&sh_k[buf][j][dim + 2]);
                            const float2 kf0 = __half22float2(kh2_0);
                            const float2 kf1 = __half22float2(kh2_1);
                            k0 = kf0.x; k1 = kf0.y; k2 = kf1.x; k3 = kf1.y;
                        }
                    } else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
                        if constexpr (kPack == 2) {
                            const __nv_bfloat162 kb2 = *reinterpret_cast<const __nv_bfloat162 *>(&sh_k[buf][j][dim]);
                            const float2 kf = __bfloat1622float2(kb2);
                            k0 = kf.x; k1 = kf.y;
                        } else {
                            const __nv_bfloat162 kb2_0 = *reinterpret_cast<const __nv_bfloat162 *>(&sh_k[buf][j][dim + 0]);
                            const __nv_bfloat162 kb2_1 = *reinterpret_cast<const __nv_bfloat162 *>(&sh_k[buf][j][dim + 2]);
                            const float2 kf0 = __bfloat1622float2(kb2_0);
                            const float2 kf1 = __bfloat1622float2(kb2_1);
                            k0 = kf0.x; k1 = kf0.y; k2 = kf1.x; k3 = kf1.y;
                        }
                    } else
#endif
                    {
                        k0 = static_cast<float>(sh_k[buf][j][dim + 0]);
                        k1 = static_cast<float>(sh_k[buf][j][dim + 1]);
                        if constexpr (kPack == 4) {
                            k2 = static_cast<float>(sh_k[buf][j][dim + 2]);
                            k3 = static_cast<float>(sh_k[buf][j][dim + 3]);
                        }
                    }
                    if constexpr (kPack == 2) {
                        partial[j] = fmaf(q0, k0, q1 * k1);
                    } else {
                        partial[j] = fmaf(q0, k0, fmaf(q1, k1, fmaf(q2, k2, q3 * k3)));
                    }
                } else {
                    partial[j] = 0.0f;
                }
            }

#pragma unroll
            for (int j = 0; j < TOKENS_PER_TILE; ++j) {
                float sum = warpReduceSum(partial[j]);
                // Only compute warps contribute to qk; load-only warps would
                // otherwise write zeros and increase reduction overhead.
                if (lane == 0 && warp_id < kComputeWarps) {
                    warp_sums[j][warp_id] = sum;
                }
            }
            if constexpr (NUM_WARPS == 1) {
                __syncwarp();
            } else {
                __syncthreads();
            }

            if (warp_id == 0) {
                // Distribute token-wise score computation across lanes to avoid
                // serial loops in lane0. TOKENS_PER_TILE <= 16 by construction.
                float score = -INFINITY;
                if (lane < TOKENS_PER_TILE && lane < tile_n) {
                    float qk = 0.0f;
#pragma unroll
                    for (int w = 0; w < kComputeWarps; ++w) {
                        qk += warp_sums[lane][w];
                    }
                    const int t = t_base + token_in_block + lane;
                    score = qk * scale_log2;
                    if (alibi_slope != 0.0f) {
                        score += (alibi_slope * static_cast<float>(t - (seq_len - 1))) * kLog2e;
                    }
                }
                float tile_max = warpReduceMax(score);
                tile_max = __shfl_sync(0xffffffff, tile_max, 0);
                float m_new = 0.0f;
                if (lane == 0) {
                    m_new = fmaxf(m, tile_max);
                }
                m_new = __shfl_sync(0xffffffff, m_new, 0);
                float w = 0.0f;
                if (lane < TOKENS_PER_TILE && lane < tile_n) {
                    w = exp2f(score - m_new);
                }
                if (lane < TOKENS_PER_TILE) {
                    weights_shared[lane] = (lane < tile_n) ? w : 0.0f;
                }
                float tile_sum = warpReduceSum(w);
                if (lane == 0) {
                    const float alpha = exp2f(m - m_new);
                    alpha_shared = alpha;
                    l = l * alpha + tile_sum;
                    m = m_new;
                }
            }
            if constexpr (NUM_WARPS == 1) {
                __syncwarp();
            } else {
                __syncthreads();
            }

            const float alpha = alpha_shared;
            float sum_wv0 = 0.0f;
            float sum_wv1 = 0.0f;
            float sum_wv2 = 0.0f;
            float sum_wv3 = 0.0f;
#pragma unroll
            for (int j = 0; j < TOKENS_PER_TILE; ++j) {
                const float w = weights_shared[j];
                float v0 = 0.0f;
                float v1 = 0.0f;
                float v2 = 0.0f;
                float v3 = 0.0f;
#if defined(__CUDA_ARCH__)
                if constexpr (std::is_same_v<Tdata, half>) {
                    if constexpr (kPack == 2) {
                        const half2 vh2 = *reinterpret_cast<const half2 *>(&sh_v[buf][j][dim]);
                        const float2 vf = __half22float2(vh2);
                        v0 = vf.x; v1 = vf.y;
                    } else {
                        const half2 vh2_0 = *reinterpret_cast<const half2 *>(&sh_v[buf][j][dim + 0]);
                        const half2 vh2_1 = *reinterpret_cast<const half2 *>(&sh_v[buf][j][dim + 2]);
                        const float2 vf0 = __half22float2(vh2_0);
                        const float2 vf1 = __half22float2(vh2_1);
                        v0 = vf0.x; v1 = vf0.y; v2 = vf1.x; v3 = vf1.y;
                    }
                } else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
                    if constexpr (kPack == 2) {
                        const __nv_bfloat162 vb2 = *reinterpret_cast<const __nv_bfloat162 *>(&sh_v[buf][j][dim]);
                        const float2 vf = __bfloat1622float2(vb2);
                        v0 = vf.x; v1 = vf.y;
                    } else {
                        const __nv_bfloat162 vb2_0 = *reinterpret_cast<const __nv_bfloat162 *>(&sh_v[buf][j][dim + 0]);
                        const __nv_bfloat162 vb2_1 = *reinterpret_cast<const __nv_bfloat162 *>(&sh_v[buf][j][dim + 2]);
                        const float2 vf0 = __bfloat1622float2(vb2_0);
                        const float2 vf1 = __bfloat1622float2(vb2_1);
                        v0 = vf0.x; v1 = vf0.y; v2 = vf1.x; v3 = vf1.y;
                    }
                } else
#endif
                {
                    v0 = static_cast<float>(sh_v[buf][j][dim + 0]);
                    v1 = static_cast<float>(sh_v[buf][j][dim + 1]);
                    if constexpr (kPack == 4) {
                        v2 = static_cast<float>(sh_v[buf][j][dim + 2]);
                        v3 = static_cast<float>(sh_v[buf][j][dim + 3]);
                    }
                }
                sum_wv0 = fmaf(w, v0, sum_wv0);
                sum_wv1 = fmaf(w, v1, sum_wv1);
                if constexpr (kPack == 4) {
                    sum_wv2 = fmaf(w, v2, sum_wv2);
                    sum_wv3 = fmaf(w, v3, sum_wv3);
                }
            }
            acc0 = acc0 * alpha + sum_wv0;
            acc1 = acc1 * alpha + sum_wv1;
            if constexpr (kPack == 4) {
                acc2 = acc2 * alpha + sum_wv2;
                acc3 = acc3 * alpha + sum_wv3;
            }

            // Prefetch the tile that will reuse this buffer (STAGES steps ahead).
            const int prefetch_tile = tile_idx + STAGES;
            if (prefetch_tile < num_tiles) {
                const int token_prefetch = prefetch_tile * TOKENS_PER_TILE;
                const int prefetch_n = min(TOKENS_PER_TILE, token_end - token_prefetch);
                for (int li = tid; li < LOADS_PER_TILE; li += CTA_THREADS) {
                    const int tok = li / CHUNKS;
                    const int chunk = li - tok * CHUNKS;
                    const int off = chunk * CHUNK_ELEMS;
                    if (tok < prefetch_n) {
                        const Tdata *k_src = k_base + (token_prefetch + tok) * k_row_stride + off;
                        const Tdata *v_src = v_base + (token_prefetch + tok) * v_row_stride + off;
                        cpAsyncCaSharedGlobal16(&sh_k[buf][tok][off], k_src);
                        cpAsyncCaSharedGlobal16(&sh_v[buf][tok][off], v_src);
                    } else {
                        reinterpret_cast<uint4 *>(&sh_k[buf][tok][off])[0] = make_uint4(0, 0, 0, 0);
                        reinterpret_cast<uint4 *>(&sh_v[buf][tok][off])[0] = make_uint4(0, 0, 0, 0);
                    }
                }
                cpAsyncCommit();
                ++pending_groups;
            }

            if (tile_idx + 1 < num_tiles) {
                // Before consuming the next tile, ensure at least one group
                // completes. In steady state we keep (STAGES - 1) in flight; in
                // the tail (no more prefetches) we gradually drain.
                int desired_pending = pending_groups - 1;
                if (desired_pending < 0) {
                    desired_pending = 0;
                }
                if (desired_pending > (STAGES - 1)) {
                    desired_pending = (STAGES - 1);
                }
                cpAsyncWaitGroupRt(desired_pending);
                pending_groups = desired_pending;
                if constexpr (NUM_WARPS == 1) {
                    __syncwarp();
                } else {
                    __syncthreads();
                }
            }
        }

        // Drain any in-flight async copies before moving to the next paged block.
        cpAsyncWaitAll();
        if constexpr (NUM_WARPS == 1) {
            __syncwarp();
        } else {
            __syncthreads();
        }
    }

    __shared__ float inv_l_shared;
    if (tid == 0) {
        inv_l_shared = 1.0f / (l + 1e-6f);
    }
    if constexpr (NUM_WARPS == 1) {
        __syncwarp();
    } else {
        __syncthreads();
    }

    {
        const float s = inv_l_shared;
        const float o0 = acc0 * s;
        const float o1 = acc1 * s;
        const float o2 = acc2 * s;
        const float o3 = acc3 * s;
#if defined(__CUDA_ARCH__)
        if constexpr (std::is_same_v<Tdata, half>) {
            out_ptr[dim + 0] = __float2half_rn(o0);
            out_ptr[dim + 1] = __float2half_rn(o1);
            if constexpr (kPack == 4) {
                out_ptr[dim + 2] = __float2half_rn(o2);
                out_ptr[dim + 3] = __float2half_rn(o3);
            }
        } else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
            out_ptr[dim + 0] = __float2bfloat16_rn(o0);
            out_ptr[dim + 1] = __float2bfloat16_rn(o1);
            if constexpr (kPack == 4) {
                out_ptr[dim + 2] = __float2bfloat16_rn(o2);
                out_ptr[dim + 3] = __float2bfloat16_rn(o3);
            }
        } else
#endif
        {
            out_ptr[dim + 0] = static_cast<Tdata>(o0);
            out_ptr[dim + 1] = static_cast<Tdata>(o1);
            if constexpr (kPack == 4) {
                out_ptr[dim + 2] = static_cast<Tdata>(o2);
                out_ptr[dim + 3] = static_cast<Tdata>(o3);
            }
        }
    }
}
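// Illustrative note (not part of this commit) on the dim ownership used by the CTA tile
// kernels above: each thread owns kPack = HEAD_SIZE / CTA_THREADS consecutive head dims,
// starting at dim = tid * kPack, e.g.
//
//   HEAD_SIZE = 128, CTA_THREADS = 64  ->  kPack = 2, thread 5 owns dims 10..11
//   HEAD_SIZE = 128, CTA_THREADS = 32  ->  kPack = 4, thread 5 owns dims 20..23
//
// so one half2/__nv_bfloat162 load per thread (or two of them for kPack = 4) covers the
// whole head with no cross-thread overlap.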
// GQA/MQA fused decode kernel: one CTA computes outputs for NGROUPS query heads that
// share the same KV head. This reduces redundant K/V reads when num_heads > num_kv_heads.
//
// v0.4: implemented for head_dim=128 and NGROUPS=4 (common case: 32 Q heads / 8 KV heads).
template <typename Tindex, typename Tdata, int HEAD_SIZE, int CTA_THREADS, int TOKENS_PER_TILE, int NGROUPS>
__device__ void flashAttentionDecodeCtaGqaKernel(
    Tdata *out_, const Tdata *q_, const Tdata *k_cache_, const Tdata *v_cache_,
    const Tindex *block_tables_, const Tindex *cache_lens_, const float *alibi_slopes_,
    size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride) {
    constexpr int kWarpSize = 32;
    static_assert(HEAD_SIZE == 128, "v0.4 GQA fused CTA kernel is implemented for head_size=128 only.");
    static_assert(NGROUPS == 4, "v0.4 GQA fused CTA kernel is implemented for NGROUPS=4 only.");
    static_assert(CTA_THREADS % kWarpSize == 0, "CTA_THREADS must be a multiple of 32.");
    static_assert(TOKENS_PER_TILE > 0 && TOKENS_PER_TILE <= 16, "TOKENS_PER_TILE should stay small.");
    constexpr int NUM_WARPS = CTA_THREADS / kWarpSize;
    // Pack dims per thread. For head_dim=128 and CTA_THREADS=64, kPack=2.
    static_assert(HEAD_SIZE % CTA_THREADS == 0, "HEAD_SIZE must be divisible by CTA_THREADS.");
    constexpr int kPack = HEAD_SIZE / CTA_THREADS;
    static_assert(kPack == 2, "v0.4 GQA fused CTA kernel expects kPack=2.");
    constexpr int kPackedDims = CTA_THREADS;
    constexpr int kComputeWarps = (kPackedDims + kWarpSize - 1) / kWarpSize;
    const int seq_idx = blockIdx.y;
    const int kv_head_idx = blockIdx.x;
    const int tid = threadIdx.x;
    const int lane = tid % kWarpSize;
    const int warp_id = tid / kWarpSize;
    const int dim = tid * kPack;
    const int seq_len = static_cast<int>(cache_lens_[seq_idx]);
    if (seq_len <= 0) {
        return;
    }
    // v0.4 limitation: alibi slopes are per query head; support can be added later.
    if (alibi_slopes_ != nullptr) {
        return;
    }
    const Tindex *block_table = block_tables_ + seq_idx * static_cast<int>(max_num_blocks_per_seq);
    // q/out are [num_seqs, num_heads, head_size]. For a KV head, we handle NGROUPS query heads:
    //   q_head = kv_head * NGROUPS + g
    float q0[NGROUPS];
    float q1[NGROUPS];
#if defined(__CUDA_ARCH__)
    if constexpr (std::is_same_v<Tdata, half>) {
#pragma unroll
        for (int g = 0; g < NGROUPS; ++g) {
            const int q_head = kv_head_idx * NGROUPS + g;
            const Tdata *q_ptr = q_ + seq_idx * q_stride + q_head * HEAD_SIZE;
            const half2 qh2 = *reinterpret_cast<const half2 *>(q_ptr + dim);
            const float2 qf = __half22float2(qh2);
            q0[g] = qf.x;
            q1[g] = qf.y;
        }
    } else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
#pragma unroll
        for (int g = 0; g < NGROUPS; ++g) {
            const int q_head = kv_head_idx * NGROUPS + g;
            const Tdata *q_ptr = q_ + seq_idx * q_stride + q_head * HEAD_SIZE;
            const __nv_bfloat162 qb2 = *reinterpret_cast<const __nv_bfloat162 *>(q_ptr + dim);
            const float2 qf = __bfloat1622float2(qb2);
            q0[g] = qf.x;
            q1[g] = qf.y;
        }
    } else
#endif
    {
#pragma unroll
        for (int g = 0; g < NGROUPS; ++g) {
            const int q_head = kv_head_idx * NGROUPS + g;
            const Tdata *q_ptr = q_ + seq_idx * q_stride + q_head * HEAD_SIZE;
            q0[g] = static_cast<float>(q_ptr[dim + 0]);
            q1[g] = static_cast<float>(q_ptr[dim + 1]);
        }
    }
    float acc0[NGROUPS];
    float acc1[NGROUPS];
    float m[NGROUPS];
    float l[NGROUPS];
#pragma unroll
    for (int g = 0; g < NGROUPS; ++g) {
        acc0[g] = 0.0f;
        acc1[g] = 0.0f;
        m[g] = -INFINITY;
        l[g] = 0.0f;
    }
    __shared__ float warp_sums[NGROUPS][TOKENS_PER_TILE][kComputeWarps];
    __shared__ float alpha_shared[NGROUPS];
    __shared__ float weights_shared[NGROUPS][TOKENS_PER_TILE];
    const int pbs = static_cast<int>(page_block_size);
    constexpr float kLog2e = 1.4426950408889634f;
    const float scale_log2 = scale * kLog2e;
    static_assert(sizeof(Tdata) == 2, "CTA GQA kernel assumes fp16/bf16.");
    constexpr int CHUNK_ELEMS = 8; // 8 * 2 bytes = 16 bytes.
    constexpr int CHUNKS = HEAD_SIZE / CHUNK_ELEMS;
    constexpr int LOADS_PER_TILE = CHUNKS * TOKENS_PER_TILE;
    constexpr int STAGES = 3;
    __shared__ __align__(16) Tdata sh_k[STAGES][TOKENS_PER_TILE][HEAD_SIZE];
    __shared__ __align__(16) Tdata sh_v[STAGES][TOKENS_PER_TILE][HEAD_SIZE];
    int t_base = 0;
    for (int logical_block = 0; t_base < seq_len; ++logical_block, t_base += pbs) {
        const int physical_block = static_cast<int>(block_table[logical_block]);
        const Tdata *k_base = k_cache_ + physical_block * k_batch_stride + kv_head_idx * k_head_stride;
        const Tdata *v_base = v_cache_ + physical_block * v_batch_stride + kv_head_idx * v_head_stride;
        const int token_end = min(pbs, seq_len - t_base);
        const int num_tiles = (token_end + TOKENS_PER_TILE - 1) / TOKENS_PER_TILE;
        if (num_tiles <= 0) {
            continue;
        }
        int pending_groups = 0;
        const int preload = min(STAGES, num_tiles);
        for (int ti = 0; ti < preload; ++ti) {
            const int token_in_block = ti * TOKENS_PER_TILE;
            const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block);
            for (int li = tid; li < LOADS_PER_TILE; li += CTA_THREADS) {
                const int tok = li / CHUNKS;
                const int chunk = li - tok * CHUNKS;
                const int off = chunk * CHUNK_ELEMS;
                if (tok < tile_n) {
                    const Tdata *k_src = k_base + (token_in_block + tok) * k_row_stride + off;
                    const Tdata *v_src = v_base + (token_in_block + tok) * v_row_stride + off;
                    cpAsyncCaSharedGlobal16(&sh_k[ti][tok][off], k_src);
                    cpAsyncCaSharedGlobal16(&sh_v[ti][tok][off], v_src);
                } else {
                    reinterpret_cast<uint4 *>(&sh_k[ti][tok][off])[0] = make_uint4(0, 0, 0, 0);
                    reinterpret_cast<uint4 *>(&sh_v[ti][tok][off])[0] = make_uint4(0, 0, 0, 0);
                }
            }
            cpAsyncCommit();
            ++pending_groups;
        }
        int desired_pending = pending_groups - 1;
        if (desired_pending < 0) {
            desired_pending = 0;
        }
        if (desired_pending > (STAGES - 1)) {
            desired_pending = (STAGES - 1);
        }
        cpAsyncWaitGroupRt(desired_pending);
        pending_groups = desired_pending;
        if constexpr (NUM_WARPS == 1) {
            __syncwarp();
        } else {
            __syncthreads();
        }
        for (int tile_idx = 0; tile_idx < num_tiles; ++tile_idx) {
            const int buf = tile_idx % STAGES;
            const int token_in_block = tile_idx * TOKENS_PER_TILE;
            const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block);
            // Compute QK partial sums for each group and each token in the tile.
            float partial_qk[NGROUPS][TOKENS_PER_TILE];
#pragma unroll
            for (int g = 0; g < NGROUPS; ++g) {
#pragma unroll
                for (int j = 0; j < TOKENS_PER_TILE; ++j) {
                    if (j < tile_n) {
                        float k0 = 0.0f;
                        float k1 = 0.0f;
#if defined(__CUDA_ARCH__)
                        if constexpr (std::is_same_v<Tdata, half>) {
                            const half2 kh2 = *reinterpret_cast<const half2 *>(&sh_k[buf][j][dim]);
                            const float2 kf = __half22float2(kh2);
                            k0 = kf.x;
                            k1 = kf.y;
                        } else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
                            const __nv_bfloat162 kb2 = *reinterpret_cast<const __nv_bfloat162 *>(&sh_k[buf][j][dim]);
                            const float2 kf = __bfloat1622float2(kb2);
                            k0 = kf.x;
                            k1 = kf.y;
                        } else
#endif
                        {
                            k0 = static_cast<float>(sh_k[buf][j][dim + 0]);
                            k1 = static_cast<float>(sh_k[buf][j][dim + 1]);
                        }
                        partial_qk[g][j] = fmaf(q0[g], k0, q1[g] * k1);
                    } else {
                        partial_qk[g][j] = 0.0f;
                    }
                }
            }
#pragma unroll
            for (int g = 0; g < NGROUPS; ++g) {
#pragma unroll
                for (int j = 0; j < TOKENS_PER_TILE; ++j) {
                    const float sum = warpReduceSum(partial_qk[g][j]);
                    if (lane == 0 && warp_id < kComputeWarps) {
                        warp_sums[g][j][warp_id] = sum;
                    }
                }
            }
            if constexpr (NUM_WARPS == 1) {
                __syncwarp();
            } else {
                __syncthreads();
            }
            if (warp_id == 0) {
#pragma unroll
                for (int g = 0; g < NGROUPS; ++g) {
                    float score = -INFINITY;
                    if (lane < TOKENS_PER_TILE && lane < tile_n) {
                        float qk = 0.0f;
#pragma unroll
                        for (int w = 0; w < kComputeWarps; ++w) {
                            qk += warp_sums[g][lane][w];
                        }
                        score = qk * scale_log2;
                    }
                    float tile_max = warpReduceMax(score);
                    tile_max = __shfl_sync(0xffffffff, tile_max, 0);
                    float m_new = 0.0f;
                    if (lane == 0) {
                        m_new = fmaxf(m[g], tile_max);
                    }
                    m_new = __shfl_sync(0xffffffff, m_new, 0);
                    float w = 0.0f;
                    if (lane < TOKENS_PER_TILE && lane < tile_n) {
                        w = exp2f(score - m_new);
                    }
                    if (lane < TOKENS_PER_TILE) {
                        weights_shared[g][lane] = (lane < tile_n) ? w : 0.0f;
                    }
                    const float tile_sum = warpReduceSum(w);
                    if (lane == 0) {
                        const float alpha = exp2f(m[g] - m_new);
                        alpha_shared[g] = alpha;
                        l[g] = l[g] * alpha + tile_sum;
                        m[g] = m_new;
                    }
                }
            }
            if constexpr (NUM_WARPS == 1) {
                __syncwarp();
            } else {
                __syncthreads();
            }
            float alpha[NGROUPS];
            float sum_wv0[NGROUPS];
            float sum_wv1[NGROUPS];
#pragma unroll
            for (int g = 0; g < NGROUPS; ++g) {
                alpha[g] = alpha_shared[g];
                sum_wv0[g] = 0.0f;
                sum_wv1[g] = 0.0f;
            }
#pragma unroll
            for (int j = 0; j < TOKENS_PER_TILE; ++j) {
                float v0 = 0.0f;
                float v1 = 0.0f;
#if defined(__CUDA_ARCH__)
                if constexpr (std::is_same_v<Tdata, half>) {
                    const half2 vh2 = *reinterpret_cast<const half2 *>(&sh_v[buf][j][dim]);
                    const float2 vf = __half22float2(vh2);
                    v0 = vf.x;
                    v1 = vf.y;
                } else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
                    const __nv_bfloat162 vb2 = *reinterpret_cast<const __nv_bfloat162 *>(&sh_v[buf][j][dim]);
                    const float2 vf = __bfloat1622float2(vb2);
                    v0 = vf.x;
                    v1 = vf.y;
                } else
#endif
                {
                    v0 = static_cast<float>(sh_v[buf][j][dim + 0]);
                    v1 = static_cast<float>(sh_v[buf][j][dim + 1]);
                }
#pragma unroll
                for (int g = 0; g < NGROUPS; ++g) {
                    const float w = weights_shared[g][j];
                    sum_wv0[g] = fmaf(w, v0, sum_wv0[g]);
                    sum_wv1[g] = fmaf(w, v1, sum_wv1[g]);
                }
            }
#pragma unroll
            for (int g = 0; g < NGROUPS; ++g) {
                acc0[g] = acc0[g] * alpha[g] + sum_wv0[g];
                acc1[g] = acc1[g] * alpha[g] + sum_wv1[g];
            }
            const int prefetch_tile = tile_idx + STAGES;
            if (prefetch_tile < num_tiles) {
                const int token_prefetch = prefetch_tile * TOKENS_PER_TILE;
                const int prefetch_n = min(TOKENS_PER_TILE, token_end - token_prefetch);
                for (int li = tid; li < LOADS_PER_TILE; li += CTA_THREADS) {
                    const int tok = li / CHUNKS;
                    const int chunk = li - tok * CHUNKS;
                    const int off = chunk * CHUNK_ELEMS;
                    if (tok < prefetch_n) {
                        const Tdata *k_src = k_base + (token_prefetch + tok) * k_row_stride + off;
                        const Tdata *v_src = v_base + (token_prefetch + tok) * v_row_stride + off;
                        cpAsyncCaSharedGlobal16(&sh_k[buf][tok][off], k_src);
                        cpAsyncCaSharedGlobal16(&sh_v[buf][tok][off], v_src);
                    } else {
                        reinterpret_cast<uint4 *>(&sh_k[buf][tok][off])[0] = make_uint4(0, 0, 0, 0);
                        reinterpret_cast<uint4 *>(&sh_v[buf][tok][off])[0] = make_uint4(0, 0, 0, 0);
                    }
                }
                cpAsyncCommit();
                ++pending_groups;
            }
            if (tile_idx + 1 < num_tiles) {
                int desired_pending2 = pending_groups - 1;
                if (desired_pending2 < 0) {
                    desired_pending2 = 0;
                }
                if (desired_pending2 > (STAGES - 1)) {
                    desired_pending2 = (STAGES - 1);
                }
                cpAsyncWaitGroupRt(desired_pending2);
                pending_groups = desired_pending2;
                if constexpr (NUM_WARPS == 1) {
                    __syncwarp();
                } else {
                    __syncthreads();
                }
            }
        }
        cpAsyncWaitAll();
        if constexpr (NUM_WARPS == 1) {
            __syncwarp();
        } else {
            __syncthreads();
        }
    }
    // Write outputs for each group.
    __shared__ float inv_l_shared[NGROUPS];
    if (tid < NGROUPS) {
        inv_l_shared[tid] = 1.0f / (l[tid] + 1e-6f);
    }
    if constexpr (NUM_WARPS == 1) {
        __syncwarp();
    } else {
        __syncthreads();
    }
#pragma unroll
    for (int g = 0; g < NGROUPS; ++g) {
        const int q_head = kv_head_idx * NGROUPS + g;
        Tdata *out_ptr = out_ + seq_idx * o_stride + q_head * HEAD_SIZE;
        const float s = inv_l_shared[g];
        const float o0 = acc0[g] * s;
        const float o1 = acc1[g] * s;
#if defined(__CUDA_ARCH__)
        if constexpr (std::is_same_v<Tdata, half>) {
            out_ptr[dim + 0] = __float2half_rn(o0);
            out_ptr[dim + 1] = __float2half_rn(o1);
        } else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
            out_ptr[dim + 0] = __float2bfloat16_rn(o0);
            out_ptr[dim + 1] = __float2bfloat16_rn(o1);
        } else
#endif
        {
            out_ptr[dim + 0] = static_cast<Tdata>(o0);
            out_ptr[dim + 1] = static_cast<Tdata>(o1);
        }
    }
}
} // namespace op::paged_attention::cuda
#endif // __PAGED_ATTENTION_KERNEL_V2_CUH__
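Note how the GQA kernel folds the softmax scale into log2 space: scores are multiplied once by scale * log2(e) (kLog2e), so the inner loops can use the cheaper exp2f instead of expf. A small self-contained check of that identity, for illustration only (the scale value below is just an example, not taken from this commit):

#include <cmath>
#include <cstdio>

int main() {
    const float scale = 0.0883883f;            // e.g. 1/sqrt(128) for head_size=128
    const float kLog2e = 1.4426950408889634f;  // log2(e), same constant as the kernel
    const float qk = 3.7f;                     // some raw Q·K dot product
    // exp(scale * qk) == 2^(scale * log2(e) * qk), so exp2f on pre-scaled scores is equivalent.
    printf("expf: %.7f  exp2f: %.7f\n", expf(scale * qk), exp2f(scale * kLog2e * qk));
    return 0;
}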
src/infiniop/ops/paged_attention/info.h
View file @ 1c18c046
...
...
@@ -13,92 +13,171 @@ class PagedAttentionInfo {
    PagedAttentionInfo() = default;

public:
    // --- Data Types and Scale ---
    infiniDtype_t dtype;
    infiniDtype_t index_dtype;
    float scale;
    // --- Shape Dimensions ---
    size_t num_seqs;
    size_t num_heads;
    size_t num_kv_heads;
    size_t head_size;
    size_t block_size;
    size_t page_block_size;
    size_t max_num_blocks_per_seq;
    // --- Strides for Memory Layout ---
    ptrdiff_t q_stride;
    ptrdiff_t kv_block_stride;
    ptrdiff_t kv_head_stride;
    ptrdiff_t k_batch_stride;
    ptrdiff_t k_row_stride;
    ptrdiff_t k_head_stride;
    ptrdiff_t v_batch_stride;
    ptrdiff_t v_row_stride;
    ptrdiff_t v_head_stride;
    ptrdiff_t o_stride;
    ptrdiff_t block_table_batch_stride;
    ptrdiff_t cache_lens_stride;

    static utils::Result<PagedAttentionInfo> create(
        infiniopTensorDescriptor_t out_desc,
        infiniopTensorDescriptor_t q_desc,
        infiniopTensorDescriptor_t k_cache_desc,
        infiniopTensorDescriptor_t v_cache_desc,
        infiniopTensorDescriptor_t block_tables_desc,
        infiniopTensorDescriptor_t seq_lens_desc,
        infiniopTensorDescriptor_t cache_lens_desc,
        const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
        float scale) {
        auto dtype = q_desc->dtype();
        CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
        CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16);
        if (out_desc->dtype() != dtype || k_cache_desc->dtype() != dtype || v_cache_desc->dtype() != dtype) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        if (q_desc->ndim() != 3 || k_cache_desc->ndim() < 4 || v_cache_desc->ndim() < 4
            || block_tables_desc->ndim() != 2 || seq_lens_desc->ndim() != 1) {
        if (q_desc->ndim() != 3 || out_desc->ndim() != 3) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        if (k_cache_desc->ndim() != 4 || v_cache_desc->ndim() != 4) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        if (block_tables_desc->ndim() != 2 || cache_lens_desc->ndim() != 1) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        if (block_tables_desc->dtype() != INFINI_DTYPE_I64) {
        CHECK_OR_RETURN(q_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
        CHECK_OR_RETURN(out_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
        CHECK_OR_RETURN(k_cache_desc->stride(3) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
        CHECK_OR_RETURN(v_cache_desc->stride(3) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
        const auto block_tables_dt = block_tables_desc->dtype();
        const auto cache_lens_dt = cache_lens_desc->dtype();
        const bool debug_dtype = (std::getenv("INFINIOP_FLASH_DEBUG_DTYPE") != nullptr);
        const bool block_tables_ok = (block_tables_dt == INFINI_DTYPE_I64) || (block_tables_dt == INFINI_DTYPE_I32) || (block_tables_dt == INFINI_DTYPE_U32);
        const bool cache_lens_ok = (cache_lens_dt == INFINI_DTYPE_I64) || (cache_lens_dt == INFINI_DTYPE_I32) || (cache_lens_dt == INFINI_DTYPE_U32);
        if (!(block_tables_ok && cache_lens_ok)) {
            if (debug_dtype) {
                std::fprintf(stderr, "[flash_attention] Bad index dtype: block_tables=%d cache_lens=%d (expected I32/I64/U32)\n",
                             static_cast<int>(block_tables_dt), static_cast<int>(cache_lens_dt));
            }
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        if (block_tables_dt != cache_lens_dt) {
            // Keep them consistent to simplify backend dispatch.
            if (debug_dtype) {
                std::fprintf(stderr, "[flash_attention] Mismatched index dtype: block_tables=%d cache_lens=%d\n",
                             static_cast<int>(block_tables_dt), static_cast<int>(cache_lens_dt));
            }
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        CHECK_OR_RETURN(block_tables_desc->stride(1) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
        CHECK_OR_RETURN(cache_lens_desc->stride(0) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
        if (seq_lens_desc->dtype() != INFINI_DTYPE_I64) {
        if (alibi_slopes_desc.has_value() && alibi_slopes_desc.value() != nullptr) {
            if (alibi_slopes_desc.value()->dtype() != INFINI_DTYPE_F32) {
                return INFINI_STATUS_BAD_TENSOR_DTYPE;
            }
            if (alibi_slopes_desc.value()->ndim() != 1) {
                return INFINI_STATUS_BAD_TENSOR_SHAPE;
            }
            CHECK_OR_RETURN(alibi_slopes_desc.value()->stride(0) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
        }
        // --- Extract shape dimensions ---
        // Shapes
        auto q_shape = q_desc->shape();
        auto k_cache_shape = k_cache_desc->shape();
        auto k_shape = k_cache_desc->shape();
        const size_t num_seqs = q_shape[0];
        const size_t num_heads = q_shape[1];
        const size_t head_size = q_shape[2];
        const size_t num_blocks = k_shape[0];
        (void)num_blocks;
        const size_t page_block_size = k_shape[2];
        const size_t num_kv_heads = k_shape[1];
        // if (page_block_size % 256 != 0) {
        //     printf("paged block size %zu\n", page_block_size);
        //     return INFINI_STATUS_BAD_TENSOR_SHAPE;
        // }
        if (head_size != 64 && head_size != 128) {
            // First build only targets common FA2 head dims (expand later).
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        if (num_heads % num_kv_heads != 0) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        if (v_cache_desc->shape()[0] != k_shape[0] || v_cache_desc->shape()[1] != k_shape[1]
            || v_cache_desc->shape()[2] != k_shape[2] || v_cache_desc->shape()[3] != k_shape[3]) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        size_t num_seqs = q_shape[0];
        size_t num_heads = q_shape[1];
        size_t head_size = q_shape[2];
        if (out_desc->shape()[0] != q_shape[0] || out_desc->shape()[1] != q_shape[1] || out_desc->shape()[2] != q_shape[2]) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        if (head_size != 16 && head_size != 32 && head_size != 64 && head_size != 128 && head_size != 256) {
            std::cerr << "[Error] Now only supports head_size = 16/32/64/128/256, but got " << head_size << "." << std::endl;
        if (cache_lens_desc->shape()[0] != num_seqs) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        size_t num_kv_heads = k_cache_shape[1];
        size_t block_size = v_cache_desc->shape()[2]; // Using the V cache's block-size dimension is more reliable
        size_t max_num_blocks_per_seq = block_tables_desc->shape()[1];
        const size_t max_num_blocks_per_seq = block_tables_desc->shape()[1];
        // Strides (in elements)
        const ptrdiff_t q_stride = q_desc->stride(0);
        const ptrdiff_t o_stride = out_desc->stride(0);
        const ptrdiff_t k_batch_stride = k_cache_desc->stride(0);
        const ptrdiff_t k_row_stride = k_cache_desc->stride(2);
        const ptrdiff_t k_head_stride = k_cache_desc->stride(1);
        const ptrdiff_t v_batch_stride = v_cache_desc->stride(0);
        const ptrdiff_t v_row_stride = v_cache_desc->stride(2);
        const ptrdiff_t v_head_stride = v_cache_desc->stride(1);
        // --- Calculate max_seq_len for shared memory allocation ---
        // This is a safe upper bound.
        // info.max_seq_len = info.max_num_blocks_per_seq * info.block_size;
        // --- Extract strides for memory access ---
        ptrdiff_t q_stride = q_desc->stride(0);
        ptrdiff_t kv_block_stride = k_cache_desc->stride(0);
        ptrdiff_t kv_head_stride = k_cache_desc->stride(1);
        ptrdiff_t o_stride = out_desc->stride(0);
        const ptrdiff_t block_table_batch_stride = block_tables_desc->stride(0);
        const ptrdiff_t cache_lens_stride = cache_lens_desc->stride(0);
        return utils::Result<PagedAttentionInfo>(PagedAttentionInfo{
            dtype,
            block_tables_dt,
            scale,
            num_seqs,
            num_heads,
            num_kv_heads,
            head_size,
            block_size,
            page_block_size,
            max_num_blocks_per_seq,
            q_stride,
            kv_block_stride,
            kv_head_stride,
            o_stride});
            k_batch_stride,
            k_row_stride,
            k_head_stride,
            v_batch_stride,
            v_row_stride,
            v_head_stride,
            o_stride,
            block_table_batch_stride,
            cache_lens_stride,
        });
    }
};
...
...
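For a contiguous K/V cache laid out as [num_blocks, num_kv_heads, page_block_size, head_size], the strides that create() stores above reduce to simple products. A quick sketch of what those element-unit strides look like in that case (the geometry below is hypothetical, not taken from this commit):

#include <cstddef>
#include <cstdio>

int main() {
    // Example cache geometry: 8 KV heads, 16-token pages, head_size 128.
    const size_t num_kv_heads = 8, page_block_size = 16, head_size = 128;
    // Strides in elements for a contiguous [num_blocks, num_kv_heads, page_block_size, head_size] tensor.
    const ptrdiff_t k_head_stride = static_cast<ptrdiff_t>(page_block_size * head_size);                 // stride(1)
    const ptrdiff_t k_row_stride = static_cast<ptrdiff_t>(head_size);                                    // stride(2)
    const ptrdiff_t k_batch_stride = static_cast<ptrdiff_t>(num_kv_heads * page_block_size * head_size); // stride(0)
    printf("k_batch_stride=%td k_head_stride=%td k_row_stride=%td\n", k_batch_stride, k_head_stride, k_row_stride);
    return 0;
}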
src/infiniop/ops/paged_attention/nvidia/paged_attention_hd128.cu
0 → 100644
View file @ 1c18c046
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../cuda/kernel_v2.cuh"
namespace op::paged_attention::nvidia {

namespace {

constexpr int kMaxSplits = 8;

constexpr size_t ceilDiv(size_t a, size_t b) {
    return (a + b - 1) / b;
}

inline int getSmCount() {
    int device = 0;
    if (cudaGetDevice(&device) != cudaSuccess) {
        return 0;
    }
    int sm_count = 0;
    if (cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device) != cudaSuccess) {
        return 0;
    }
    return sm_count;
}

// A lightweight FA2-style "waves" heuristic.
//
// Important: our split-kv kernel shards the KV sequence length, so the main "work"
// dimension is tokens, not the number of pages. We use an upper bound for seqlen_k
// (max pages * page size), which matches common decode microbench where all seqs
// share the same cache length.
inline int chooseNumSplitsHeuristic(size_t num_heads, size_t num_seqs, size_t seqlen_k, int sm_count) {
    if (sm_count <= 0) {
        return 1;
    }
    if (num_heads == 0 || num_seqs == 0) {
        return 1;
    }
    if (seqlen_k <= 256) {
        return 1;
    }
    const size_t base_blocks = num_heads * num_seqs;
    int best_splits = 1;
    // Baseline: one kernel, base_blocks CTAs, each scanning seqlen_k tokens.
    size_t best_score = (ceilDiv(base_blocks, static_cast<size_t>(sm_count)) * seqlen_k);
    size_t prev_work_per_block = seqlen_k;
    for (int s = 2; s <= kMaxSplits; ++s) {
        const size_t blocks = base_blocks * static_cast<size_t>(s);
        const size_t waves_split = ceilDiv(blocks, static_cast<size_t>(sm_count));
        const size_t work_per_block = ceilDiv(seqlen_k, static_cast<size_t>(s));
        // If this split count doesn't reduce per-block work vs the previous split, it's effectively redundant.
        if (work_per_block == prev_work_per_block) {
            continue;
        }
        prev_work_per_block = work_per_block;
        // Combine is one extra kernel with base_blocks blocks; approximate as one more wave unit.
        const size_t waves_combine = ceilDiv(base_blocks, static_cast<size_t>(sm_count));
        const size_t score = waves_split * work_per_block + waves_combine;
        if (score < best_score) {
            best_score = score;
            best_splits = s;
        }
    }
    return best_splits;
}

} // namespace
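A host-side sketch of how the waves heuristic might be probed for a given decode shape (not part of this commit; the shapes are hypothetical and the functions are the ones defined just above in this translation unit):

// Hypothetical debug helper: print the split factor the heuristic would pick.
static void debugPrintChosenSplits() {
    const size_t num_heads = 32;  // query heads (example value)
    const size_t num_seqs = 4;    // decode batch size (example value)
    const size_t seqlen_k = 4096; // upper bound: max_num_blocks_per_seq * page_block_size
    const int sm_count = getSmCount();
    const int num_splits = chooseNumSplitsHeuristic(num_heads, num_seqs, seqlen_k, sm_count);
    std::printf("[debug] heads=%zu seqs=%zu seqlen_k=%zu sm=%d -> num_splits=%d\n",
                num_heads, num_seqs, seqlen_k, sm_count, num_splits);
}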
inline bool envBool(const char *name) {
    if (const char *env = std::getenv(name)) {
        return (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
    }
    return false;
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128Warp(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const Tindex *cache_lens, const float *alibi_slopes,
    size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride) {
    op::paged_attention::cuda::flashAttentionDecodeWarpKernel<Tindex, Tdata, 128>(
        out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, num_kv_heads, scale,
        max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128Cta(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const Tindex *cache_lens, const float *alibi_slopes,
    size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride) {
    // Default CTA variant (lower overhead).
    op::paged_attention::cuda::flashAttentionDecodeCtaKernel<Tindex, Tdata, 128, 64, 8>(
        out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, num_kv_heads, scale,
        max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128CtaTile16(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const Tindex *cache_lens, const float *alibi_slopes,
    size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride) {
    op::paged_attention::cuda::flashAttentionDecodeCtaKernel<Tindex, Tdata, 128, 64, 16>(
        out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, num_kv_heads, scale,
        max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128Cta32(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const Tindex *cache_lens, const float *alibi_slopes,
    size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride) {
    // Experimental 1-warp CTA variant for head_dim=128 (kPack=4).
    op::paged_attention::cuda::flashAttentionDecodeCtaKernel<Tindex, Tdata, 128, 32, 8>(
        out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, num_kv_heads, scale,
        max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128Cta32Tile16(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const Tindex *cache_lens, const float *alibi_slopes,
    size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride) {
    op::paged_attention::cuda::flashAttentionDecodeCtaKernel<Tindex, Tdata, 128, 32, 16>(
        out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, num_kv_heads, scale,
        max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128CtaGqa4(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const Tindex *cache_lens, const float *alibi_slopes,
    size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride) {
    // GQA fused kernel: CTA computes 4 query heads for one KV head (head_dim=128).
    op::paged_attention::cuda::flashAttentionDecodeCtaGqaKernel<Tindex, Tdata, 128, 64, 8, 4>(
        out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, num_kv_heads, scale,
        max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128SplitKv(
    float *partial_acc, float *partial_m, float *partial_l,
    const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const Tindex *cache_lens, const float *alibi_slopes,
    size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, int num_splits) {
    op::paged_attention::cuda::flashAttentionDecodeSplitKvWarpKernel<Tindex, Tdata, 128>(
        partial_acc, partial_m, partial_l, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride,
        k_row_stride, k_head_stride, v_batch_stride, v_row_stride, v_head_stride, num_splits);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128SplitKvCta(
    float *partial_acc, float *partial_m, float *partial_l,
    const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const Tindex *cache_lens, const float *alibi_slopes,
    size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, int num_splits) {
    op::paged_attention::cuda::flashAttentionDecodeSplitKvCtaKernel<Tindex, Tdata, 128, 64, 8>(
        partial_acc, partial_m, partial_l, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride,
        k_row_stride, k_head_stride, v_batch_stride, v_row_stride, v_head_stride, num_splits);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128SplitKvCtaTile16(
    float *partial_acc, float *partial_m, float *partial_l,
    const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const Tindex *cache_lens, const float *alibi_slopes,
    size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, int num_splits) {
    op::paged_attention::cuda::flashAttentionDecodeSplitKvCtaKernel<Tindex, Tdata, 128, 64, 16>(
        partial_acc, partial_m, partial_l, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride,
        k_row_stride, k_head_stride, v_batch_stride, v_row_stride, v_head_stride, num_splits);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128SplitKvCta32(
    float *partial_acc, float *partial_m, float *partial_l,
    const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const Tindex *cache_lens, const float *alibi_slopes,
    size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, int num_splits) {
    op::paged_attention::cuda::flashAttentionDecodeSplitKvCtaKernel<Tindex, Tdata, 128, 32, 8>(
        partial_acc, partial_m, partial_l, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride,
        k_row_stride, k_head_stride, v_batch_stride, v_row_stride, v_head_stride, num_splits);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128SplitKvCta32Tile16(
    float *partial_acc, float *partial_m, float *partial_l,
    const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const Tindex *cache_lens, const float *alibi_slopes,
    size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, int num_splits) {
    op::paged_attention::cuda::flashAttentionDecodeSplitKvCtaKernel<Tindex, Tdata, 128, 32, 16>(
        partial_acc, partial_m, partial_l, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride,
        k_row_stride, k_head_stride, v_batch_stride, v_row_stride, v_head_stride, num_splits);
}
template <typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128SplitKvCombine(
    Tdata *out, const float *partial_acc, const float *partial_m, const float *partial_l,
    int num_splits, ptrdiff_t o_stride) {
    op::paged_attention::cuda::flashAttentionDecodeSplitKvCombineWarpKernel<Tdata, 128>(
        out, partial_acc, partial_m, partial_l, num_splits, o_stride);
}
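Each split writes a partial state (partial_acc, partial_m, partial_l) for its KV shard; the combine kernel then merges these with a max-shifted reduction before the final normalization. A minimal sketch of that merge for one output element, for illustration only (plain C++; the real kernel operates on full head_size vectors per (seq, head), and its exponential base must match whatever the split kernels use for their running maxima):

#include <cmath>

// Merge num_splits partial softmax states into one normalized output value.
// acc[s]: exp-weighted V sum of split s, m[s]: its running max, l[s]: its denominator.
float combineSplits(const float *acc, const float *m, const float *l, int num_splits) {
    float m_all = -INFINITY;
    for (int s = 0; s < num_splits; ++s) {
        m_all = fmaxf(m_all, m[s]);
    }
    float l_all = 0.0f, acc_all = 0.0f;
    for (int s = 0; s < num_splits; ++s) {
        const float w = expf(m[s] - m_all); // rescale each shard to the global max
        l_all += w * l[s];
        acc_all += w * acc[s];
    }
    return acc_all / (l_all + 1e-6f);
}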
template <typename Tindex>
infiniStatus_t launch_decode_hd128_impl(
    void *workspace, size_t workspace_size, void *out, const void *q, const void *k_cache, const void *v_cache,
    infiniDtype_t dtype, const Tindex *block_tables, const Tindex *cache_lens, const float *alibi_slopes,
    size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
    cudaStream_t stream) {
    // Default decode config (2026-01-22): decode_flash_cta8_64_gqa_splitkv_4
    // Users can override any knob via the corresponding INFINIOP_FLASH_* env vars.
    bool use_cta = true;
    if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_KERNEL")) {
        // Backward-compatible: any non-"cta" value means "warp".
        use_cta = (std::strcmp(env, "cta") == 0);
    }
    bool use_gqa_fused = true;
    if (const char *env = std::getenv("INFINIOP_FLASH_GQA_FUSED")) {
        if (std::strcmp(env, "0") == 0 || std::strcmp(env, "false") == 0) {
            use_gqa_fused = false;
        } else {
            use_gqa_fused = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
        }
    }
    int cta_tile = 8;
    if (const char *env = std::getenv("INFINIOP_FLASH_CTA_TILE")) {
        const int v = std::atoi(env);
        if (v == 8 || v == 16) {
            cta_tile = v;
        }
    }
    int cta_threads = 64;
    if (const char *env = std::getenv("INFINIOP_FLASH_CTA_THREADS")) {
        const int v = std::atoi(env);
        if (v == 32 || v == 64) {
            cta_threads = v;
        }
    }
    dim3 block(use_cta ? static_cast<uint32_t>(cta_threads) : 32);
    bool use_split = true;
    bool use_split_auto = false;
    if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_SPLITKV")) {
        if (std::strcmp(env, "auto") == 0) {
            use_split_auto = true;
            use_split = false;
        } else {
            if (std::strcmp(env, "0") == 0 || std::strcmp(env, "false") == 0) {
                use_split = false;
            } else {
                use_split = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
            }
        }
    }
    int num_splits = 4;
    bool fixed_num_splits = true;
    if (const char *env = std::getenv("INFINIOP_FLASH_NUM_SPLITS")) {
        if (std::strcmp(env, "auto") == 0) {
            fixed_num_splits = false;
        } else {
            num_splits = std::atoi(env);
            fixed_num_splits = (num_splits > 0);
        }
    }
    if (num_splits < 1) {
        num_splits = 1;
    }
    if (num_splits > kMaxSplits) {
        num_splits = kMaxSplits;
    }
    const bool debug_dispatch = envBool("INFINIOP_FLASH_DEBUG_DISPATCH");
    auto dump_dispatch = [&](const char *path) {
        if (!debug_dispatch) {
            return;
        }
        // Avoid spamming: only print when the key dispatch signature changes.
        struct Sig {
            const char *path;
            int dtype;
            size_t heads;
            size_t kv_heads;
            size_t seqs;
            size_t pbs;
            size_t max_blocks;
            int cta_tile;
            int cta_threads;
            int split;
            int split_auto;
            int num_splits;
            int fixed;
            int gqa_fused;
        };
        static Sig last{};
        static bool has_last = false;
        Sig cur{
            path, static_cast<int>(dtype), num_heads, num_kv_heads, num_seqs, page_block_size,
            max_num_blocks_per_seq, cta_tile, cta_threads, static_cast<int>(use_split),
            static_cast<int>(use_split_auto), num_splits, static_cast<int>(fixed_num_splits),
            static_cast<int>(use_gqa_fused),
        };
        if (has_last && cur.path == last.path && cur.dtype == last.dtype && cur.heads == last.heads
            && cur.kv_heads == last.kv_heads && cur.seqs == last.seqs && cur.pbs == last.pbs
            && cur.max_blocks == last.max_blocks && cur.cta_tile == last.cta_tile
            && cur.cta_threads == last.cta_threads && cur.split == last.split
            && cur.split_auto == last.split_auto && cur.num_splits == last.num_splits
            && cur.fixed == last.fixed && cur.gqa_fused == last.gqa_fused) {
            return;
        }
        last = cur;
        has_last = true;
        fprintf(stderr,
                "[INFINIOP][paged_attention][hd128] dispatch: path=%s dtype=%d heads=%zu kv_heads=%zu seqs=%zu "
                "pbs=%zu max_blocks=%zu cta_tile=%d cta_threads=%d split=%d split_auto=%d num_splits=%d fixed=%d gqa_fused=%d\n",
                path, static_cast<int>(dtype), num_heads, num_kv_heads, num_seqs, page_block_size,
                max_num_blocks_per_seq, cta_tile, cta_threads, static_cast<int>(use_split),
                static_cast<int>(use_split_auto), num_splits, static_cast<int>(fixed_num_splits),
                static_cast<int>(use_gqa_fused));
    };
    // Split-kv auto mode: decide whether to split based on a heuristic.
    if (use_split_auto) {
        // Approximate seqlen_k by the per-seq KV capacity (paged KV upper bound).
        const size_t seqlen_k = max_num_blocks_per_seq * page_block_size;
        const int sm_count = getSmCount();
        num_splits = chooseNumSplitsHeuristic(num_heads, num_seqs, seqlen_k, sm_count);
        if (const char *dbg = std::getenv("INFINIOP_FLASH_DEBUG_SPLITS")) {
            if (std::strcmp(dbg, "1") == 0 || std::strcmp(dbg, "true") == 0) {
                static size_t last_seqlen_k = 0;
                if (last_seqlen_k != seqlen_k) {
                    last_seqlen_k = seqlen_k;
                    fprintf(stderr,
                            "[INFINIOP][paged_attention] splitkv auto(mode): sm=%d heads=%zu seqs=%zu seqlen_k~%zu -> num_splits=%d\n",
                            sm_count, num_heads, num_seqs, seqlen_k, num_splits);
                }
            }
        }
        // If auto picks 1, fall back to non-split to avoid extra workspace and kernel overhead.
        use_split = (num_splits > 1);
    }
    // const bool debug_dispatch = [] {
    //     if (const char *env = std::getenv("INFINIOP_FLASH_DEBUG_DISPATCH")) {
    //         return (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
    //     }
    //     return false;
    // }();
    // const char *selected_path = "unknown";
    // Optional: fuse GQA groups (4) when seqlen_q=1 decode and alibi is disabled.
    // This reuses K/V loads across query heads that share the same KV head.
    // Controlled by INFINIOP_FLASH_GQA_FUSED (default: enabled).
    if (use_gqa_fused && use_cta && !use_split && alibi_slopes == nullptr && num_kv_heads > 0
        && num_heads == num_kv_heads * 4) {
        dump_dispatch("cta_gqa_fused");
        dim3 grid_gqa(static_cast<uint64_t>(num_kv_heads), static_cast<uint64_t>(num_seqs), 1);
        if (dtype == INFINI_DTYPE_F16) {
            flashAttentionDecodeHd128CtaGqa4<Tindex, half><<<grid_gqa, 64, 0, stream>>>(
                static_cast<half *>(out), static_cast<const half *>(q), static_cast<const half *>(k_cache),
                static_cast<const half *>(v_cache), block_tables, cache_lens, nullptr, num_kv_heads, scale,
                max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride, k_head_stride,
                v_batch_stride, v_row_stride, v_head_stride, o_stride);
            return INFINI_STATUS_SUCCESS;
        }
        if (dtype == INFINI_DTYPE_BF16) {
            flashAttentionDecodeHd128CtaGqa4<Tindex, __nv_bfloat16><<<grid_gqa, 64, 0, stream>>>(
                static_cast<__nv_bfloat16 *>(out), static_cast<const __nv_bfloat16 *>(q),
                static_cast<const __nv_bfloat16 *>(k_cache), static_cast<const __nv_bfloat16 *>(v_cache),
                block_tables, cache_lens, nullptr, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                q_stride, k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
                v_head_stride, o_stride);
            return INFINI_STATUS_SUCCESS;
        }
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    dim3 grid(static_cast<uint64_t>(num_heads), static_cast<uint64_t>(num_seqs), 1);
    if (use_split) {
        dump_dispatch(use_cta ? "splitkv_cta" : "splitkv_warp");
        // }
        if (!fixed_num_splits) {
            // Approximate seqlen_k by the per-seq KV capacity (paged KV upper bound).
            const size_t seqlen_k = max_num_blocks_per_seq * page_block_size;
            const int sm_count = getSmCount();
            num_splits = chooseNumSplitsHeuristic(num_heads, num_seqs, seqlen_k, sm_count);
            if (const char *dbg = std::getenv("INFINIOP_FLASH_DEBUG_SPLITS")) {
                if (std::strcmp(dbg, "1") == 0 || std::strcmp(dbg, "true") == 0) {
                    static size_t last_seqlen_k = 0;
                    if (last_seqlen_k != seqlen_k) {
                        last_seqlen_k = seqlen_k;
                        fprintf(stderr,
                                "[INFINIOP][paged_attention] splitkv auto: sm=%d heads=%zu seqs=%zu seqlen_k~%zu -> num_splits=%d\n",
                                sm_count, num_heads, num_seqs, seqlen_k, num_splits);
                    }
                }
            }
        }
        const size_t n = num_seqs * num_heads;
        const size_t acc_elems = static_cast<size_t>(kMaxSplits) * n * 128;
        const size_t m_elems = static_cast<size_t>(kMaxSplits) * n;
        const size_t l_elems = static_cast<size_t>(kMaxSplits) * n;
        const size_t needed_bytes = (acc_elems + m_elems + l_elems) * sizeof(float);
        if (workspace == nullptr || workspace_size < needed_bytes) {
            return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
        }
        float *ws = static_cast<float *>(workspace);
        float *partial_acc = ws;
        float *partial_m = partial_acc + acc_elems;
        float *partial_l = partial_m + m_elems;
        dim3 grid_split(static_cast<uint64_t>(num_heads), static_cast<uint64_t>(num_seqs), static_cast<uint64_t>(num_splits));
        dim3 block_split(use_cta ? static_cast<uint32_t>(cta_threads) : 32);
        if (dtype == INFINI_DTYPE_F16) {
            if (use_cta) {
                if (cta_threads == 32) {
                    if (cta_tile == 16) {
                        flashAttentionDecodeHd128SplitKvCta32Tile16<Tindex, half><<<grid_split, block_split, 0, stream>>>(
                            partial_acc, partial_m, partial_l, static_cast<const half *>(q),
                            static_cast<const half *>(k_cache), static_cast<const half *>(v_cache),
                            block_tables, cache_lens, alibi_slopes, num_kv_heads, scale, max_num_blocks_per_seq,
                            page_block_size, q_stride, k_batch_stride, k_row_stride, k_head_stride,
                            v_batch_stride, v_row_stride, v_head_stride, num_splits);
                    } else {
                        flashAttentionDecodeHd128SplitKvCta32<Tindex, half><<<grid_split, block_split, 0, stream>>>(
                            partial_acc, partial_m, partial_l, static_cast<const half *>(q),
                            static_cast<const half *>(k_cache), static_cast<const half *>(v_cache),
                            block_tables, cache_lens, alibi_slopes, num_kv_heads, scale, max_num_blocks_per_seq,
                            page_block_size, q_stride, k_batch_stride, k_row_stride, k_head_stride,
                            v_batch_stride, v_row_stride, v_head_stride, num_splits);
                    }
                } else {
                    if (cta_tile == 16) {
                        flashAttentionDecodeHd128SplitKvCtaTile16<Tindex, half><<<grid_split, block_split, 0, stream>>>(
                            partial_acc, partial_m, partial_l, static_cast<const half *>(q),
                            static_cast<const half *>(k_cache), static_cast<const half *>(v_cache),
                            block_tables, cache_lens, alibi_slopes, num_kv_heads, scale, max_num_blocks_per_seq,
                            page_block_size, q_stride, k_batch_stride, k_row_stride, k_head_stride,
                            v_batch_stride, v_row_stride, v_head_stride, num_splits);
                    } else {
                        flashAttentionDecodeHd128SplitKvCta<Tindex, half><<<grid_split, block_split, 0, stream>>>(
                            partial_acc, partial_m, partial_l, static_cast<const half *>(q),
                            static_cast<const half *>(k_cache), static_cast<const half *>(v_cache),
                            block_tables, cache_lens, alibi_slopes, num_kv_heads, scale, max_num_blocks_per_seq,
                            page_block_size, q_stride, k_batch_stride, k_row_stride, k_head_stride,
                            v_batch_stride, v_row_stride, v_head_stride, num_splits);
                    }
                }
            } else {
                flashAttentionDecodeHd128SplitKv<Tindex, half><<<grid_split, block_split, 0, stream>>>(
                    partial_acc, partial_m, partial_l, static_cast<const half *>(q),
                    static_cast<const half *>(k_cache), static_cast<const half *>(v_cache),
                    block_tables, cache_lens, alibi_slopes, num_kv_heads, scale, max_num_blocks_per_seq,
                    page_block_size, q_stride, k_batch_stride, k_row_stride, k_head_stride,
                    v_batch_stride, v_row_stride, v_head_stride, num_splits);
            }
            flashAttentionDecodeHd128SplitKvCombine<half><<<grid, 32, 0, stream>>>(
                static_cast<half *>(out), partial_acc, partial_m, partial_l, num_splits, o_stride);
            return INFINI_STATUS_SUCCESS;
        }
        if (dtype == INFINI_DTYPE_BF16) {
            if (use_cta) {
                if (cta_threads == 32) {
                    if (cta_tile == 16) {
                        flashAttentionDecodeHd128SplitKvCta32Tile16<Tindex, __nv_bfloat16><<<grid_split, block_split, 0, stream>>>(
                            partial_acc, partial_m, partial_l, static_cast<const __nv_bfloat16 *>(q),
                            static_cast<const __nv_bfloat16 *>(k_cache), static_cast<const __nv_bfloat16 *>(v_cache),
                            block_tables, cache_lens, alibi_slopes, num_kv_heads, scale, max_num_blocks_per_seq,
                            page_block_size, q_stride, k_batch_stride, k_row_stride, k_head_stride,
                            v_batch_stride, v_row_stride, v_head_stride, num_splits);
                    } else {
                        flashAttentionDecodeHd128SplitKvCta32<Tindex, __nv_bfloat16><<<grid_split, block_split, 0, stream>>>(
                            partial_acc, partial_m, partial_l, static_cast<const __nv_bfloat16 *>(q),
                            static_cast<const __nv_bfloat16 *>(k_cache), static_cast<const __nv_bfloat16 *>(v_cache),
                            block_tables, cache_lens, alibi_slopes, num_kv_heads, scale, max_num_blocks_per_seq,
                            page_block_size, q_stride, k_batch_stride, k_row_stride, k_head_stride,
                            v_batch_stride, v_row_stride, v_head_stride, num_splits);
                    }
                } else {
                    if (cta_tile == 16) {
                        flashAttentionDecodeHd128SplitKvCtaTile16<Tindex, __nv_bfloat16><<<grid_split, block_split, 0, stream>>>(
                            partial_acc, partial_m, partial_l, static_cast<const __nv_bfloat16 *>(q),
                            static_cast<const __nv_bfloat16 *>(k_cache), static_cast<const __nv_bfloat16 *>(v_cache),
                            block_tables, cache_lens, alibi_slopes, num_kv_heads, scale, max_num_blocks_per_seq,
                            page_block_size, q_stride, k_batch_stride, k_row_stride, k_head_stride,
                            v_batch_stride, v_row_stride, v_head_stride, num_splits);
                    } else {
                        flashAttentionDecodeHd128SplitKvCta<Tindex, __nv_bfloat16><<<grid_split, block_split, 0, stream>>>(
                            partial_acc, partial_m, partial_l, static_cast<const __nv_bfloat16 *>(q),
                            static_cast<const __nv_bfloat16 *>(k_cache), static_cast<const __nv_bfloat16 *>(v_cache),
                            block_tables, cache_lens, alibi_slopes, num_kv_heads, scale, max_num_blocks_per_seq,
                            page_block_size, q_stride, k_batch_stride, k_row_stride, k_head_stride,
                            v_batch_stride, v_row_stride, v_head_stride, num_splits);
                    }
                }
            } else {
                flashAttentionDecodeHd128SplitKv<Tindex, __nv_bfloat16><<<grid_split, block_split, 0, stream>>>(
                    partial_acc, partial_m, partial_l, static_cast<const __nv_bfloat16 *>(q),
                    static_cast<const __nv_bfloat16 *>(k_cache), static_cast<const __nv_bfloat16 *>(v_cache),
                    block_tables, cache_lens, alibi_slopes, num_kv_heads, scale, max_num_blocks_per_seq,
                    page_block_size, q_stride, k_batch_stride, k_row_stride, k_head_stride,
                    v_batch_stride, v_row_stride, v_head_stride, num_splits);
            }
            flashAttentionDecodeHd128SplitKvCombine<__nv_bfloat16><<<grid, 32, 0, stream>>>(
                static_cast<__nv_bfloat16 *>(out), partial_acc, partial_m, partial_l, num_splits, o_stride);
            return INFINI_STATUS_SUCCESS;
        }
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    dump_dispatch(use_cta ? "cta_nosplit" : "warp_nosplit");
    if (dtype == INFINI_DTYPE_F16) {
        if (use_cta) {
            if (cta_tile == 16) {
                if (cta_threads == 32) {
                    flashAttentionDecodeHd128Cta32Tile16<Tindex, half><<<grid, block, 0, stream>>>(
                        static_cast<half *>(out), static_cast<const half *>(q), static_cast<const half *>(k_cache),
                        static_cast<const half *>(v_cache), block_tables, cache_lens, alibi_slopes, num_kv_heads,
                        scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride,
                        k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride);
                } else {
                    flashAttentionDecodeHd128CtaTile16<Tindex, half><<<grid, block, 0, stream>>>(
                        static_cast<half *>(out), static_cast<const half *>(q), static_cast<const half *>(k_cache),
                        static_cast<const half *>(v_cache), block_tables, cache_lens, alibi_slopes, num_kv_heads,
                        scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride,
                        k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride);
                }
            } else {
                if (cta_threads == 32) {
                    flashAttentionDecodeHd128Cta32<Tindex, half><<<grid, block, 0, stream>>>(
                        static_cast<half *>(out), static_cast<const half *>(q), static_cast<const half *>(k_cache),
                        static_cast<const half *>(v_cache), block_tables, cache_lens, alibi_slopes, num_kv_heads,
                        scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride,
                        k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride);
                } else {
                    flashAttentionDecodeHd128Cta<Tindex, half><<<grid, block, 0, stream>>>(
                        static_cast<half *>(out), static_cast<const half *>(q), static_cast<const half *>(k_cache),
                        static_cast<const half *>(v_cache), block_tables, cache_lens, alibi_slopes, num_kv_heads,
                        scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride,
                        k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride);
                }
            }
        } else {
            flashAttentionDecodeHd128Warp<Tindex, half><<<grid, block, 0, stream>>>(
                static_cast<half *>(out), static_cast<const half *>(q), static_cast<const half *>(k_cache),
                static_cast<const half *>(v_cache), block_tables, cache_lens, alibi_slopes, num_kv_heads,
                scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride,
                k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride);
        }
        return INFINI_STATUS_SUCCESS;
    }
    if (dtype == INFINI_DTYPE_BF16) {
        if (use_cta) {
            if (cta_tile == 16) {
                if (cta_threads == 32) {
                    flashAttentionDecodeHd128Cta32Tile16<Tindex, __nv_bfloat16><<<grid, block, 0, stream>>>(
                        static_cast<__nv_bfloat16 *>(out), static_cast<const __nv_bfloat16 *>(q),
                        static_cast<const __nv_bfloat16 *>(k_cache), static_cast<const __nv_bfloat16 *>(v_cache),
                        block_tables, cache_lens, alibi_slopes, num_kv_heads, scale, max_num_blocks_per_seq,
                        page_block_size, q_stride, k_batch_stride, k_row_stride, k_head_stride,
                        v_batch_stride, v_row_stride, v_head_stride, o_stride);
                } else {
                    flashAttentionDecodeHd128CtaTile16<Tindex, __nv_bfloat16><<<grid, block, 0, stream>>>(
                        static_cast<__nv_bfloat16 *>(out), static_cast<const __nv_bfloat16 *>(q),
                        static_cast<const __nv_bfloat16 *>(k_cache), static_cast<const __nv_bfloat16 *>(v_cache),
                        block_tables, cache_lens, alibi_slopes, num_kv_heads, scale, max_num_blocks_per_seq,
                        page_block_size, q_stride, k_batch_stride, k_row_stride, k_head_stride,
                        v_batch_stride, v_row_stride, v_head_stride, o_stride);
                }
            } else {
                if (cta_threads == 32) {
                    flashAttentionDecodeHd128Cta32<Tindex, __nv_bfloat16><<<grid, block, 0, stream>>>(
                        static_cast<__nv_bfloat16 *>(out), static_cast<const __nv_bfloat16 *>(q),
                        static_cast<const __nv_bfloat16 *>(k_cache), static_cast<const __nv_bfloat16 *>(v_cache),
                        block_tables, cache_lens, alibi_slopes, num_kv_heads, scale, max_num_blocks_per_seq,
                        page_block_size, q_stride, k_batch_stride, k_row_stride, k_head_stride,
                        v_batch_stride, v_row_stride, v_head_stride, o_stride);
                } else {
                    flashAttentionDecodeHd128Cta<Tindex, __nv_bfloat16><<<grid, block, 0, stream>>>(
                        static_cast<__nv_bfloat16 *>(out), static_cast<const __nv_bfloat16 *>(q),
                        static_cast<const __nv_bfloat16 *>(k_cache), static_cast<const __nv_bfloat16 *>(v_cache),
                        block_tables, cache_lens, alibi_slopes, num_kv_heads, scale, max_num_blocks_per_seq,
                        page_block_size, q_stride, k_batch_stride, k_row_stride, k_head_stride,
                        v_batch_stride, v_row_stride, v_head_stride, o_stride);
                }
            }
        } else {
            flashAttentionDecodeHd128Warp<Tindex, __nv_bfloat16><<<grid, block, 0, stream>>>(
                static_cast<__nv_bfloat16 *>(out), static_cast<const __nv_bfloat16 *>(q),
                static_cast<const __nv_bfloat16 *>(k_cache), static_cast<const __nv_bfloat16 *>(v_cache),
                block_tables, cache_lens, alibi_slopes, num_kv_heads, scale, max_num_blocks_per_seq,
                page_block_size, q_stride, k_batch_stride, k_row_stride, k_head_stride,
                v_batch_stride, v_row_stride, v_head_stride, o_stride);
        }
        return INFINI_STATUS_SUCCESS;
    }
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
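All of the dispatcher's knobs are read from the environment, so a benchmark harness can flip configurations without rebuilding. A small illustrative snippet using POSIX setenv (values taken from the parsing logic above; the helper name is made up):

#include <cstdlib>

// Force the CTA path with automatic split-kv selection and one-shot dispatch logging.
// These variables are read by launch_decode_hd128_impl; set them before the first launch.
static void configureDecodeForBench() {
    setenv("INFINIOP_FLASH_DECODE_KERNEL", "cta", 1);   // "cta", anything else selects the warp path
    setenv("INFINIOP_FLASH_CTA_THREADS", "64", 1);      // 32 or 64
    setenv("INFINIOP_FLASH_CTA_TILE", "8", 1);          // 8 or 16 tokens per tile
    setenv("INFINIOP_FLASH_DECODE_SPLITKV", "auto", 1); // "0"/"1"/"auto"
    setenv("INFINIOP_FLASH_NUM_SPLITS", "auto", 1);     // positive integer or "auto"
    setenv("INFINIOP_FLASH_DEBUG_DISPATCH", "1", 1);    // print the chosen path once per signature
}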
infiniStatus_t launch_decode_hd128_i64(
    void *workspace, size_t workspace_size, void *out, const void *q, const void *k_cache, const void *v_cache,
    infiniDtype_t dtype, const int64_t *block_tables, const int64_t *cache_lens, const float *alibi_slopes,
    size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
    cudaStream_t stream) {
    return launch_decode_hd128_impl<int64_t>(
        workspace, workspace_size, out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes,
        num_heads, num_seqs, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
        q_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, stream);
}
infiniStatus_t launch_decode_hd128_i32(
    void *workspace, size_t workspace_size, void *out, const void *q, const void *k_cache, const void *v_cache,
    infiniDtype_t dtype, const int32_t *block_tables, const int32_t *cache_lens, const float *alibi_slopes,
    size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
    cudaStream_t stream) {
    return launch_decode_hd128_impl<int32_t>(
        workspace, workspace_size, out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes,
        num_heads, num_seqs, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
        q_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, stream);
}
infiniStatus_t launch_decode_hd128_u32(
    void *workspace, size_t workspace_size, void *out, const void *q, const void *k_cache, const void *v_cache,
    infiniDtype_t dtype, const uint32_t *block_tables, const uint32_t *cache_lens, const float *alibi_slopes,
    size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
    cudaStream_t stream) {
    return launch_decode_hd128_impl<uint32_t>(
        workspace, workspace_size, out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes,
        num_heads, num_seqs, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
        q_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, stream);
}
} // namespace op::paged_attention::nvidia

src/infiniop/ops/paged_attention/nvidia/paged_attention_hd64.cu
0 → 100644
View file @ 1c18c046
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../cuda/kernel_v2.cuh"
namespace op::paged_attention::nvidia {

namespace {

constexpr int kMaxSplits = 8;

constexpr size_t ceilDiv(size_t a, size_t b) {
    return (a + b - 1) / b;
}

inline int getSmCount() {
    int device = 0;
    if (cudaGetDevice(&device) != cudaSuccess) {
        return 0;
    }
    int sm_count = 0;
    if (cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device) != cudaSuccess) {
        return 0;
    }
    return sm_count;
}

// A lightweight FA2-style "waves" heuristic.
//
// Important: our split-kv kernel shards the KV sequence length, so the main "work"
// dimension is tokens, not the number of pages. We use an upper bound for seqlen_k
// (max pages * page size), which matches common decode microbench where all seqs
// share the same cache length.
inline int chooseNumSplitsHeuristic(size_t num_heads, size_t num_seqs, size_t seqlen_k, int sm_count) {
    if (sm_count <= 0) {
        return 1;
    }
    if (num_heads == 0 || num_seqs == 0) {
        return 1;
    }
    if (seqlen_k <= 256) {
        return 1;
    }
    const size_t base_blocks = num_heads * num_seqs;
    int best_splits = 1;
    // Baseline: one kernel, base_blocks CTAs, each scanning seqlen_k tokens.
    size_t best_score = (ceilDiv(base_blocks, static_cast<size_t>(sm_count)) * seqlen_k);
    size_t prev_work_per_block = seqlen_k;
    for (int s = 2; s <= kMaxSplits; ++s) {
        const size_t blocks = base_blocks * static_cast<size_t>(s);
        const size_t waves_split = ceilDiv(blocks, static_cast<size_t>(sm_count));
        const size_t work_per_block = ceilDiv(seqlen_k, static_cast<size_t>(s));
        // If this split count doesn't reduce per-block work vs the previous split, it's effectively redundant.
        if (work_per_block == prev_work_per_block) {
            continue;
        }
        prev_work_per_block = work_per_block;
        // Combine is one extra kernel with base_blocks blocks; approximate as one more wave unit.
        const size_t waves_combine = ceilDiv(base_blocks, static_cast<size_t>(sm_count));
        const size_t score = waves_split * work_per_block + waves_combine;
        if (score < best_score) {
            best_score = score;
            best_splits = s;
        }
    }
    return best_splits;
}

} // namespace
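The heuristic above trades one extra combine pass for better SM occupancy. As a rough illustration (the shape and SM count below are hypothetical, not from this commit), a small decode batch leaves most SMs idle without splitting, so the heuristic picks a split count greater than one; a minimal host-side sketch that just queries it:

// Illustrative only: query the split heuristic for a hypothetical decode shape
// (32 query heads, 4 sequences, ~4096 cached tokens, 108 SMs) and print the choice.
// With only 32 * 4 = 128 (head, seq) blocks on 108 SMs, a single pass under-fills
// the GPU, so the score favors sharding the KV length across several splits.
static void print_split_choice_example() {
    const int splits = chooseNumSplitsHeuristic(/*num_heads=*/32, /*num_seqs=*/4,
                                                /*seqlen_k=*/4096, /*sm_count=*/108);
    std::printf("auto num_splits = %d (1 means no split-kv)\n", splits);
}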
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd64Warp(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const Tindex *cache_lens, const float *alibi_slopes,
    size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride) {
    op::paged_attention::cuda::flashAttentionDecodeWarpKernel<Tindex, Tdata, 64>(
        out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
        q_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride);
}

template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd64Cta(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const Tindex *cache_lens, const float *alibi_slopes,
    size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride) {
    // Default CTA variant (lower overhead).
    op::paged_attention::cuda::flashAttentionDecodeCtaKernel<Tindex, Tdata, 64, 32, 8>(
        out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
        q_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride);
}

template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd64CtaTile16(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const Tindex *cache_lens, const float *alibi_slopes,
    size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride) {
    op::paged_attention::cuda::flashAttentionDecodeCtaKernel<Tindex, Tdata, 64, 32, 16>(
        out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
        q_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride);
}

template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd64SplitKv(
    float *partial_acc, float *partial_m, float *partial_l,
    const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const Tindex *cache_lens, const float *alibi_slopes,
    size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, int num_splits) {
    op::paged_attention::cuda::flashAttentionDecodeSplitKvWarpKernel<Tindex, Tdata, 64>(
        partial_acc, partial_m, partial_l, q, k_cache, v_cache,
        block_tables, cache_lens, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
        q_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, num_splits);
}

template <typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd64SplitKvCombine(
    Tdata *out, const float *partial_acc, const float *partial_m, const float *partial_l,
    int num_splits, ptrdiff_t o_stride) {
    op::paged_attention::cuda::flashAttentionDecodeSplitKvCombineWarpKernel<Tdata, 64>(
        out, partial_acc, partial_m, partial_l, num_splits, o_stride);
}
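For readers unfamiliar with split-kv decode: each split processes a disjoint shard of the KV length and writes an unnormalized V accumulation plus its running softmax max m and running sum l; the combine kernel then merges the shards with a log-sum-exp rescale. The combine kernel's body is not shown in this hunk, so the following is only a host-side sketch of the merge it has to perform for one (seq, head) pair, assuming each split stores (acc, m, l) as just described.

// Minimal host-side sketch of the split-kv merge (illustrative names and layout).
#include <cmath>
#include <vector>

std::vector<float> combine_partials(const std::vector<std::vector<float>> &acc, // [split][head_dim], unnormalized
                                    const std::vector<float> &m,                // [split] running max (log domain)
                                    const std::vector<float> &l,                // [split] running sum
                                    int num_splits, int head_dim) {
    float m_global = -INFINITY;
    for (int s = 0; s < num_splits; ++s) m_global = std::fmax(m_global, m[s]);
    float l_global = 0.0f;
    std::vector<float> out(head_dim, 0.0f);
    for (int s = 0; s < num_splits; ++s) {
        const float w = std::exp(m[s] - m_global); // rescale each shard to the global max
        l_global += w * l[s];
        for (int d = 0; d < head_dim; ++d) out[d] += w * acc[s][d];
    }
    for (int d = 0; d < head_dim; ++d) out[d] /= l_global; // final softmax normalization
    return out;
}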
template <typename Tindex>
infiniStatus_t launch_decode_hd64_impl(
    void *workspace, size_t workspace_size, void *out, const void *q, const void *k_cache, const void *v_cache,
    infiniDtype_t dtype, const Tindex *block_tables, const Tindex *cache_lens, const float *alibi_slopes,
    size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
    cudaStream_t stream) {
    dim3 grid(static_cast<uint64_t>(num_heads), static_cast<uint64_t>(num_seqs), 1);

    bool use_cta = false;
    if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_KERNEL")) {
        use_cta = (std::strcmp(env, "cta") == 0);
    }
    int cta_tile = 8;
    if (const char *env = std::getenv("INFINIOP_FLASH_CTA_TILE")) {
        const int v = std::atoi(env);
        if (v == 8 || v == 16) {
            cta_tile = v;
        }
    }
    // For head_dim=64 we use a 1-warp CTA (32 threads) with packed loads.
    dim3 block(32);

    bool use_split = false;
    if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_SPLITKV")) {
        use_split = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
    }
    int num_splits = 4;
    bool fixed_num_splits = false;
    if (const char *env = std::getenv("INFINIOP_FLASH_NUM_SPLITS")) {
        if (std::strcmp(env, "auto") == 0) {
            fixed_num_splits = false;
        } else {
            num_splits = std::atoi(env);
            fixed_num_splits = (num_splits > 0);
        }
    }
    if (num_splits < 1) {
        num_splits = 1;
    }
    if (num_splits > kMaxSplits) {
        num_splits = kMaxSplits;
    }

    if (use_split) {
        if (use_cta) {
            // We currently only implement the split-kv path with warp kernels.
            // The CTA kernel is a separate non-split implementation.
            static bool warned = false;
            if (!warned) {
                warned = true;
                fprintf(stderr,
                        "[INFINIOP][paged_attention] split-kv is enabled; ignoring INFINIOP_FLASH_DECODE_KERNEL=cta "
                        "(CTA split-kv not implemented yet)\n");
            }
        }
        if (!fixed_num_splits) {
            // Approximate seqlen_k by the per-seq KV capacity (paged KV upper bound).
            const size_t seqlen_k = max_num_blocks_per_seq * page_block_size;
            const int sm_count = getSmCount();
            num_splits = chooseNumSplitsHeuristic(num_heads, num_seqs, seqlen_k, sm_count);
            if (const char *dbg = std::getenv("INFINIOP_FLASH_DEBUG_SPLITS")) {
                if (std::strcmp(dbg, "1") == 0 || std::strcmp(dbg, "true") == 0) {
                    static size_t last_seqlen_k = 0;
                    if (last_seqlen_k != seqlen_k) {
                        last_seqlen_k = seqlen_k;
                        fprintf(stderr,
                                "[INFINIOP][paged_attention] splitkv auto: sm=%d heads=%zu seqs=%zu seqlen_k~%zu -> num_splits=%d\n",
                                sm_count, num_heads, num_seqs, seqlen_k, num_splits);
                    }
                }
            }
        }
        const size_t n = num_seqs * num_heads;
        const size_t acc_elems = static_cast<size_t>(kMaxSplits) * n * 64;
        const size_t m_elems = static_cast<size_t>(kMaxSplits) * n;
        const size_t l_elems = static_cast<size_t>(kMaxSplits) * n;
        const size_t needed_bytes = (acc_elems + m_elems + l_elems) * sizeof(float);
        if (workspace == nullptr || workspace_size < needed_bytes) {
            return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
        }
        float *ws = static_cast<float *>(workspace);
        float *partial_acc = ws;
        float *partial_m = partial_acc + acc_elems;
        float *partial_l = partial_m + m_elems;
        dim3 grid_split(static_cast<uint64_t>(num_heads), static_cast<uint64_t>(num_seqs), static_cast<uint64_t>(num_splits));
        dim3 block_split(32);
        if (dtype == INFINI_DTYPE_F16) {
            flashAttentionDecodeHd64SplitKv<Tindex, half><<<grid_split, block_split, 0, stream>>>(
                partial_acc, partial_m, partial_l,
                static_cast<const half *>(q), static_cast<const half *>(k_cache), static_cast<const half *>(v_cache),
                block_tables, cache_lens, alibi_slopes, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                q_stride, k_batch_stride, k_row_stride, k_head_stride,
                v_batch_stride, v_row_stride, v_head_stride, num_splits);
            flashAttentionDecodeHd64SplitKvCombine<half><<<grid, 32, 0, stream>>>(
                static_cast<half *>(out), partial_acc, partial_m, partial_l, num_splits, o_stride);
            return INFINI_STATUS_SUCCESS;
        }
        if (dtype == INFINI_DTYPE_BF16) {
            flashAttentionDecodeHd64SplitKv<Tindex, __nv_bfloat16><<<grid_split, block_split, 0, stream>>>(
                partial_acc, partial_m, partial_l,
                static_cast<const __nv_bfloat16 *>(q), static_cast<const __nv_bfloat16 *>(k_cache), static_cast<const __nv_bfloat16 *>(v_cache),
                block_tables, cache_lens, alibi_slopes, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                q_stride, k_batch_stride, k_row_stride, k_head_stride,
                v_batch_stride, v_row_stride, v_head_stride, num_splits);
            flashAttentionDecodeHd64SplitKvCombine<__nv_bfloat16><<<grid, 32, 0, stream>>>(
                static_cast<__nv_bfloat16 *>(out), partial_acc, partial_m, partial_l, num_splits, o_stride);
            return INFINI_STATUS_SUCCESS;
        }
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    if (dtype == INFINI_DTYPE_F16) {
        if (use_cta) {
            if (cta_tile == 16) {
                flashAttentionDecodeHd64CtaTile16<Tindex, half><<<grid, block, 0, stream>>>(
                    static_cast<half *>(out), static_cast<const half *>(q),
                    static_cast<const half *>(k_cache), static_cast<const half *>(v_cache),
                    block_tables, cache_lens, alibi_slopes, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                    q_stride, k_batch_stride, k_row_stride, k_head_stride,
                    v_batch_stride, v_row_stride, v_head_stride, o_stride);
            } else {
                flashAttentionDecodeHd64Cta<Tindex, half><<<grid, block, 0, stream>>>(
                    static_cast<half *>(out), static_cast<const half *>(q),
                    static_cast<const half *>(k_cache), static_cast<const half *>(v_cache),
                    block_tables, cache_lens, alibi_slopes, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                    q_stride, k_batch_stride, k_row_stride, k_head_stride,
                    v_batch_stride, v_row_stride, v_head_stride, o_stride);
            }
        } else {
            flashAttentionDecodeHd64Warp<Tindex, half><<<grid, block, 0, stream>>>(
                static_cast<half *>(out), static_cast<const half *>(q),
                static_cast<const half *>(k_cache), static_cast<const half *>(v_cache),
                block_tables, cache_lens, alibi_slopes, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                q_stride, k_batch_stride, k_row_stride, k_head_stride,
                v_batch_stride, v_row_stride, v_head_stride, o_stride);
        }
        return INFINI_STATUS_SUCCESS;
    }
    if (dtype == INFINI_DTYPE_BF16) {
        if (use_cta) {
            if (cta_tile == 16) {
                flashAttentionDecodeHd64CtaTile16<Tindex, __nv_bfloat16><<<grid, block, 0, stream>>>(
                    static_cast<__nv_bfloat16 *>(out), static_cast<const __nv_bfloat16 *>(q),
                    static_cast<const __nv_bfloat16 *>(k_cache), static_cast<const __nv_bfloat16 *>(v_cache),
                    block_tables, cache_lens, alibi_slopes, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                    q_stride, k_batch_stride, k_row_stride, k_head_stride,
                    v_batch_stride, v_row_stride, v_head_stride, o_stride);
            } else {
                flashAttentionDecodeHd64Cta<Tindex, __nv_bfloat16><<<grid, block, 0, stream>>>(
                    static_cast<__nv_bfloat16 *>(out), static_cast<const __nv_bfloat16 *>(q),
                    static_cast<const __nv_bfloat16 *>(k_cache), static_cast<const __nv_bfloat16 *>(v_cache),
                    block_tables, cache_lens, alibi_slopes, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                    q_stride, k_batch_stride, k_row_stride, k_head_stride,
                    v_batch_stride, v_row_stride, v_head_stride, o_stride);
            }
        } else {
            flashAttentionDecodeHd64Warp<Tindex, __nv_bfloat16><<<grid, block, 0, stream>>>(
                static_cast<__nv_bfloat16 *>(out), static_cast<const __nv_bfloat16 *>(q),
                static_cast<const __nv_bfloat16 *>(k_cache), static_cast<const __nv_bfloat16 *>(v_cache),
                block_tables, cache_lens, alibi_slopes, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                q_stride, k_batch_stride, k_row_stride, k_head_stride,
                v_batch_stride, v_row_stride, v_head_stride, o_stride);
        }
        return INFINI_STATUS_SUCCESS;
    }
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
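The split-kv path above partitions one float workspace into partial_acc, partial_m and partial_l in that order, and always sizes the accumulator region for kMaxSplits so the buffer does not depend on the runtime split choice. A small host-side helper that mirrors the needed_bytes computation (a sketch, not an API exposed by this file):

// Sketch of the workspace the split-kv path expects for head_dim = 64:
// kMaxSplits slots of per-(seq, head) accumulators (64 floats) plus one m and one l scalar each.
inline size_t splitkv_workspace_bytes_hd64(size_t num_seqs, size_t num_heads) {
    constexpr size_t kMaxSplits = 8;
    const size_t n = num_seqs * num_heads;
    const size_t acc_elems = kMaxSplits * n * 64;
    const size_t ml_elems = 2 * kMaxSplits * n; // partial_m followed by partial_l
    return (acc_elems + ml_elems) * sizeof(float);
}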
infiniStatus_t launch_decode_hd64_i64(
    void *workspace, size_t workspace_size, void *out, const void *q, const void *k_cache, const void *v_cache,
    infiniDtype_t dtype, const int64_t *block_tables, const int64_t *cache_lens, const float *alibi_slopes,
    size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
    cudaStream_t stream) {
    return launch_decode_hd64_impl<int64_t>(
        workspace, workspace_size, out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes,
        num_heads, num_seqs, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
        q_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, stream);
}

infiniStatus_t launch_decode_hd64_i32(
    void *workspace, size_t workspace_size, void *out, const void *q, const void *k_cache, const void *v_cache,
    infiniDtype_t dtype, const int32_t *block_tables, const int32_t *cache_lens, const float *alibi_slopes,
    size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
    cudaStream_t stream) {
    return launch_decode_hd64_impl<int32_t>(
        workspace, workspace_size, out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes,
        num_heads, num_seqs, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
        q_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, stream);
}

infiniStatus_t launch_decode_hd64_u32(
    void *workspace, size_t workspace_size, void *out, const void *q, const void *k_cache, const void *v_cache,
    infiniDtype_t dtype, const uint32_t *block_tables, const uint32_t *cache_lens, const float *alibi_slopes,
    size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
    cudaStream_t stream) {
    return launch_decode_hd64_impl<uint32_t>(
        workspace, workspace_size, out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes,
        num_heads, num_seqs, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
        q_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, stream);
}
} // namespace op::paged_attention::nvidia

src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cu
View file @ 1c18c046
#include <cub/block/block_reduce.cuh>
#include <cuda_runtime.h>
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "paged_attention_nvidia.cuh"
template <typename Tdata, typename Tcompute, size_t HEAD_SIZE, size_t NUM_THREADS>
INFINIOP_CUDA_KERNEL pagedAttention(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const int64_t *block_tables, const int64_t *seq_lens, const float *alibi_slopes,
    const size_t num_kv_heads, const float scale, const size_t max_num_blocks_per_seq,
    const size_t block_size,
    const ptrdiff_t q_stride, const ptrdiff_t kv_block_stride, const ptrdiff_t kv_head_stride,
    const ptrdiff_t o_stride) {
    op::paged_attention::cuda::pagedAttentionKernel<Tdata, Tcompute, HEAD_SIZE, NUM_THREADS>(
        out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, num_kv_heads, scale,
        max_num_blocks_per_seq, block_size, q_stride, kv_block_stride, kv_head_stride, o_stride);
}
namespace op::paged_attention::nvidia {

infiniStatus_t launch_decode_hd64_i64(
    void *workspace, size_t workspace_size, void *out, const void *q, const void *k_cache, const void *v_cache,
    infiniDtype_t dtype, const int64_t *block_tables, const int64_t *cache_lens, const float *alibi_slopes,
    size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
    cudaStream_t stream);

infiniStatus_t launch_decode_hd64_i32(
    void *workspace, size_t workspace_size, void *out, const void *q, const void *k_cache, const void *v_cache,
    infiniDtype_t dtype, const int32_t *block_tables, const int32_t *cache_lens, const float *alibi_slopes,
    size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
    cudaStream_t stream);

infiniStatus_t launch_decode_hd64_u32(
    void *workspace, size_t workspace_size, void *out, const void *q, const void *k_cache, const void *v_cache,
    infiniDtype_t dtype, const uint32_t *block_tables, const uint32_t *cache_lens, const float *alibi_slopes,
    size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
    cudaStream_t stream);

infiniStatus_t launch_decode_hd128_i64(
    void *workspace, size_t workspace_size, void *out, const void *q, const void *k_cache, const void *v_cache,
    infiniDtype_t dtype, const int64_t *block_tables, const int64_t *cache_lens, const float *alibi_slopes,
    size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
    cudaStream_t stream);

infiniStatus_t launch_decode_hd128_i32(
    void *workspace, size_t workspace_size, void *out, const void *q, const void *k_cache, const void *v_cache,
    infiniDtype_t dtype, const int32_t *block_tables, const int32_t *cache_lens, const float *alibi_slopes,
    size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
    cudaStream_t stream);

infiniStatus_t launch_decode_hd128_u32(
    void *workspace, size_t workspace_size, void *out, const void *q, const void *k_cache, const void *v_cache,
    infiniDtype_t dtype, const uint32_t *block_tables, const uint32_t *cache_lens, const float *alibi_slopes,
    size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
    cudaStream_t stream);

struct Descriptor::Opaque {
    std::shared_ptr<device::nvidia::Handle::Internal> internal;
};
@@ -40,108 +79,284 @@ infiniStatus_t Descriptor::create(
     infiniopTensorDescriptor_t k_cache_desc,
     infiniopTensorDescriptor_t v_cache_desc,
     infiniopTensorDescriptor_t block_tables_desc,
-    infiniopTensorDescriptor_t seq_lens_desc,
+    infiniopTensorDescriptor_t cache_lens_desc,
     const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
     float scale) {
-    auto info = PagedAttentionInfo::create(out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, seq_lens_desc, alibi_slopes_desc, scale);
-    CHECK_RESULT(info);
+    auto info_res = PagedAttentionInfo::create(out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, cache_lens_desc, alibi_slopes_desc, scale);
+    CHECK_RESULT(info_res);
+    auto info = info_res.take();
+    // Reserve workspace for optional split-kv decode (partial acc + m/l).
+    // Workspace is independent of runtime env toggles; kernels will clamp num_splits <= kMaxSplits.
+    constexpr size_t kMaxSplits = 8;
+    const size_t per_split = info.num_seqs * info.num_heads * (info.head_size + 2) * sizeof(float);
+    const size_t workspace_bytes = kMaxSplits * per_split;
     *desc_ptr = new Descriptor(
         new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
-        info.take(), 0, handle->device, handle->device_id);
+        info, workspace_bytes, handle->device, handle->device_id);
     return INFINI_STATUS_SUCCESS;
 }
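To make the reservation above concrete: the workspace is kMaxSplits * num_seqs * num_heads * (head_size + 2) floats, regardless of whether split-kv is later enabled. The shape below is hypothetical, used only to show the arithmetic.

// Illustrative only: mirrors the per_split / workspace_bytes formula in create().
// For 4 sequences, 32 heads, head_size 128: 8 * 4 * 32 * (128 + 2) * 4 bytes ≈ 532 KB.
constexpr size_t pagedAttentionWorkspaceBytes(size_t num_seqs, size_t num_heads, size_t head_size) {
    constexpr size_t kMaxSplits = 8;
    return kMaxSplits * num_seqs * num_heads * (head_size + 2) * sizeof(float);
}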
template <size_t HEAD_SIZE, size_t NUM_THREADS>
infiniStatus_t launchKernel(void *out, const void *q, const void *k_cache, const void *v_cache,
                            infiniDtype_t dtype,
                            const void *block_tables, const void *seq_lens, const void *alibi_slopes,
                            size_t num_heads, size_t num_seqs,
                            size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t block_size,
                            ptrdiff_t q_stride, ptrdiff_t kv_block_stride, ptrdiff_t kv_head_stride, ptrdiff_t o_stride,
                            cudaStream_t stream) {
    dim3 grid(uint64_t(num_heads), uint64_t(num_seqs), 1);
    dim3 block(NUM_THREADS);
    size_t shared_mem_size = (HEAD_SIZE + max_num_blocks_per_seq * block_size + 2) * sizeof(float);
    if (dtype == INFINI_DTYPE_F16) {
        pagedAttention<half, float, HEAD_SIZE, NUM_THREADS>
            <<<grid, block, shared_mem_size, stream>>>(
                (half *)out, (const half *)q, (const half *)k_cache, (const half *)v_cache,
                (const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes,
                num_kv_heads, scale, max_num_blocks_per_seq, block_size,
                q_stride, kv_block_stride, kv_head_stride, o_stride);
    } else if (dtype == INFINI_DTYPE_BF16) {
        pagedAttention<__nv_bfloat16, float, HEAD_SIZE, NUM_THREADS>
            <<<grid, block, shared_mem_size, stream>>>(
                (__nv_bfloat16 *)out, (const __nv_bfloat16 *)q, (const __nv_bfloat16 *)k_cache, (const __nv_bfloat16 *)v_cache,
                (const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes,
                num_kv_heads, scale, max_num_blocks_per_seq, block_size,
                q_stride, kv_block_stride, kv_head_stride, o_stride);
    } else if (dtype == INFINI_DTYPE_F32) {
        pagedAttention<float, float, HEAD_SIZE, NUM_THREADS>
            <<<grid, block, shared_mem_size, stream>>>(
                (float *)out, (const float *)q, (const float *)k_cache, (const float *)v_cache,
                (const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes,
                num_kv_heads, scale, max_num_blocks_per_seq, block_size,
                q_stride, kv_block_stride, kv_head_stride, o_stride);
    } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    return INFINI_STATUS_SUCCESS;
}
 infiniStatus_t Descriptor::calculate(
     void *workspace, size_t workspace_size,
     void *out, const void *q, const void *k_cache, const void *v_cache,
-    const void *block_tables, const void *seq_lens, const void *alibi_slopes,
+    const void *block_tables, const void *cache_lens, const void *alibi_slopes,
     void *stream_) const {
-    cudaStream_t stream = (cudaStream_t)stream_;
-#define LAUNCH_HEADSIZE_BLOCKSIZE(__H_SIZE, __B_SIZE) \
-    launchKernel<__H_SIZE, __B_SIZE>( \
-        out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes, \
-        _info.num_heads, _info.num_seqs, \
-        _info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size, \
-        _info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride, \
-        stream);
-#define SWITCH_HEAD_SIZE(__B_SIZE) \
-    switch (_info.head_size) { \
-    case 16: \
-        LAUNCH_HEADSIZE_BLOCKSIZE(16, __B_SIZE) \
-        break; \
-    case 32: \
-        LAUNCH_HEADSIZE_BLOCKSIZE(32, __B_SIZE) \
-        break; \
-    case 64: \
-        LAUNCH_HEADSIZE_BLOCKSIZE(64, __B_SIZE) \
-        break; \
-    case 128: \
-        LAUNCH_HEADSIZE_BLOCKSIZE(128, __B_SIZE) \
-        break; \
-    case 256: \
-        LAUNCH_HEADSIZE_BLOCKSIZE(256, __B_SIZE) \
-        break; \
-    default: \
-        return INFINI_STATUS_BAD_TENSOR_SHAPE; \
-    }
-    if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
-        SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_1024)
-    } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
-        SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_512)
-    } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
-        SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_4096)
+    bool need_workspace = false;
+    if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_SPLITKV")) {
+        // "auto" may enable split-kv depending on the runtime heuristic.
+        need_workspace = (std::strcmp(env, "auto") == 0) || (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
     } else {
-        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
+        // Keep hd64 behavior unchanged, but for hd128 we default to split-kv decode, which needs workspace.
+        need_workspace = (_info.head_size == 128);
     }
+    if (need_workspace && workspace_size < _workspace_size) {
+        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
+    }
-#undef LAUNCH_HEADSIZE_BLOCKSIZE
-#undef SWITCH_HEAD_SIZE
+    auto stream = static_cast<cudaStream_t>(stream_);
-    return INFINI_STATUS_SUCCESS;
+    const float *alibi_ptr = (alibi_slopes == nullptr) ? nullptr : static_cast<const float *>(alibi_slopes);
+    if (_info.index_dtype == INFINI_DTYPE_I64) {
+        const auto *block_table_i64 = static_cast<const int64_t *>(block_tables);
+        const auto *cache_lens_i64 = static_cast<const int64_t *>(cache_lens);
+        switch (_info.head_size) {
+        case 64:
+            return launch_decode_hd64_i64(workspace, workspace_size, out, q, k_cache, v_cache, _info.dtype,
+                                          block_table_i64, cache_lens_i64, alibi_ptr,
+                                          _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
+                                          _info.max_num_blocks_per_seq, _info.page_block_size,
+                                          _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
+                                          _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, _info.o_stride, stream);
+        case 128:
+            return launch_decode_hd128_i64(workspace, workspace_size, out, q, k_cache, v_cache, _info.dtype,
+                                           block_table_i64, cache_lens_i64, alibi_ptr,
+                                           _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
+                                           _info.max_num_blocks_per_seq, _info.page_block_size,
+                                           _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
+                                           _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, _info.o_stride, stream);
+        default:
+            return INFINI_STATUS_BAD_TENSOR_SHAPE;
+        }
+    }
+    if (_info.index_dtype == INFINI_DTYPE_I32) {
+        const auto *block_table_i32 = static_cast<const int32_t *>(block_tables);
+        const auto *cache_lens_i32 = static_cast<const int32_t *>(cache_lens);
+        switch (_info.head_size) {
+        case 64:
+            return launch_decode_hd64_i32(workspace, workspace_size, out, q, k_cache, v_cache, _info.dtype,
+                                          block_table_i32, cache_lens_i32, alibi_ptr,
+                                          _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
+                                          _info.max_num_blocks_per_seq, _info.page_block_size,
+                                          _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
+                                          _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, _info.o_stride, stream);
+        case 128:
+            return launch_decode_hd128_i32(workspace, workspace_size, out, q, k_cache, v_cache, _info.dtype,
+                                           block_table_i32, cache_lens_i32, alibi_ptr,
+                                           _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
+                                           _info.max_num_blocks_per_seq, _info.page_block_size,
+                                           _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
+                                           _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, _info.o_stride, stream);
+        default:
+            return INFINI_STATUS_BAD_TENSOR_SHAPE;
+        }
+    }
+    if (_info.index_dtype == INFINI_DTYPE_U32) {
+        const auto *block_table_u32 = static_cast<const uint32_t *>(block_tables);
+        const auto *cache_lens_u32 = static_cast<const uint32_t *>(cache_lens);
+        switch (_info.head_size) {
+        case 64:
+            return launch_decode_hd64_u32(workspace, workspace_size, out, q, k_cache, v_cache, _info.dtype,
+                                          block_table_u32, cache_lens_u32, alibi_ptr,
+                                          _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
+                                          _info.max_num_blocks_per_seq, _info.page_block_size,
+                                          _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
+                                          _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, _info.o_stride, stream);
+        case 128:
+            return launch_decode_hd128_u32(workspace, workspace_size, out, q, k_cache, v_cache, _info.dtype,
+                                           block_table_u32, cache_lens_u32, alibi_ptr,
+                                           _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
+                                           _info.max_num_blocks_per_seq, _info.page_block_size,
+                                           _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
+                                           _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, _info.o_stride, stream);
+        default:
+            return INFINI_STATUS_BAD_TENSOR_SHAPE;
+        }
+    }
+    return INFINI_STATUS_BAD_TENSOR_DTYPE;
 }
 } // namespace op::paged_attention::nvidia
// #include <cub/block/block_reduce.cuh>
// #include "../../../devices/nvidia/nvidia_common.cuh"
// #include "../../../devices/nvidia/nvidia_kernel_common.cuh"
// #include "../../../reduce/cuda/reduce.cuh"
// #include "../cuda/kernel.cuh"
// #include "paged_attention_nvidia.cuh"
// template <typename Tdata, typename Tcompute, size_t HEAD_SIZE, size_t NUM_THREADS>
// INFINIOP_CUDA_KERNEL pagedAttention(
// Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
// const int64_t *block_tables, const int64_t *seq_lens, const float *alibi_slopes,
// const size_t num_kv_heads, const float scale, const size_t max_num_blocks_per_seq,
// const size_t block_size,
// const ptrdiff_t q_stride,
// const ptrdiff_t kv_block_stride,
// const ptrdiff_t kv_head_stride,
// const ptrdiff_t o_stride) {
// op::paged_attention::cuda::pagedAttentionKernel<Tdata, Tcompute, HEAD_SIZE, NUM_THREADS>(
// out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, num_kv_heads, scale,
// max_num_blocks_per_seq, block_size, q_stride, kv_block_stride, kv_head_stride, o_stride);
// }
// namespace op::paged_attention::nvidia {
// struct Descriptor::Opaque {
// std::shared_ptr<device::nvidia::Handle::Internal> internal;
// };
// Descriptor::~Descriptor() {
// delete _opaque;
// }
// infiniStatus_t Descriptor::create(
// infiniopHandle_t handle,
// Descriptor **desc_ptr,
// infiniopTensorDescriptor_t out_desc,
// infiniopTensorDescriptor_t q_desc,
// infiniopTensorDescriptor_t k_cache_desc,
// infiniopTensorDescriptor_t v_cache_desc,
// infiniopTensorDescriptor_t block_tables_desc,
// infiniopTensorDescriptor_t seq_lens_desc,
// const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
// float scale) {
// auto info = PagedAttentionInfo::create(out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, seq_lens_desc, alibi_slopes_desc, scale);
// CHECK_RESULT(info);
// *desc_ptr = new Descriptor(
// new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
// info.take(), 0, handle->device, handle->device_id);
// return INFINI_STATUS_SUCCESS;
// }
// template <size_t HEAD_SIZE, size_t NUM_THREADS>
// infiniStatus_t launchKernel(void *out, const void *q, const void *k_cache, const void *v_cache,
// infiniDtype_t dtype,
// const void *block_tables, const void *seq_lens, const void *alibi_slopes,
// size_t num_heads, size_t num_seqs,
// size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t block_size,
// ptrdiff_t q_stride, ptrdiff_t kv_block_stride, ptrdiff_t kv_head_stride, ptrdiff_t o_stride,
// cudaStream_t stream) {
// dim3 grid(uint64_t(num_heads), uint64_t(num_seqs), 1);
// dim3 block(NUM_THREADS);
// size_t shared_mem_size = (HEAD_SIZE + max_num_blocks_per_seq * block_size + 2) * sizeof(float);
// if (dtype == INFINI_DTYPE_F16) {
// pagedAttention<half, float, HEAD_SIZE, NUM_THREADS>
// <<<grid, block, shared_mem_size, stream>>>(
// (half *)out,
// (const half *)q, (const half *)k_cache, (const half *)v_cache,
// (const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads,
// scale, max_num_blocks_per_seq, block_size,
// q_stride, kv_block_stride, kv_head_stride, o_stride);
// } else if (dtype == INFINI_DTYPE_BF16) {
// pagedAttention<__nv_bfloat16, float, HEAD_SIZE, NUM_THREADS>
// <<<grid, block, shared_mem_size, stream>>>(
// (__nv_bfloat16 *)out, (const __nv_bfloat16 *)q, (const __nv_bfloat16 *)k_cache, (const __nv_bfloat16 *)v_cache,
// (const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads,
// scale, max_num_blocks_per_seq, block_size,
// q_stride, kv_block_stride, kv_head_stride, o_stride);
// } else if (dtype == INFINI_DTYPE_F32) {
// pagedAttention<float, float, HEAD_SIZE, NUM_THREADS>
// <<<grid, block, shared_mem_size, stream>>>(
// (float *)out, (const float *)q, (const float *)k_cache, (const float *)v_cache,
// (const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads,
// scale, max_num_blocks_per_seq, block_size,
// q_stride, kv_block_stride, kv_head_stride, o_stride);
// } else {
// return INFINI_STATUS_BAD_TENSOR_DTYPE;
// }
// return INFINI_STATUS_SUCCESS;
// }
// infiniStatus_t Descriptor::calculate(
// void *workspace, size_t workspace_size,
// void *out, const void *q, const void *k_cache, const void *v_cache,
// const void *block_tables, const void *seq_lens, const void *alibi_slopes,
// void *stream_) const {
// cudaStream_t stream = (cudaStream_t)stream_;
// #define LAUNCH_HEADSIZE_BLOCKSIZE(__H_SIZE, __B_SIZE) \
// launchKernel<__H_SIZE, __B_SIZE>( \
// out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes, \
// _info.num_heads, _info.num_seqs, \
// _info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size, \
// _info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride, \
// stream);
// #define SWITCH_HEAD_SIZE(__B_SIZE) \
// switch (_info.head_size) { \
// case 16: \
// LAUNCH_HEADSIZE_BLOCKSIZE(16, __B_SIZE) \
// break; \
// case 32: \
// LAUNCH_HEADSIZE_BLOCKSIZE(32, __B_SIZE) \
// break; \
// case 64: \
// LAUNCH_HEADSIZE_BLOCKSIZE(64, __B_SIZE) \
// break; \
// case 128: \
// LAUNCH_HEADSIZE_BLOCKSIZE(128, __B_SIZE) \
// break; \
// case 256: \
// LAUNCH_HEADSIZE_BLOCKSIZE(256, __B_SIZE) \
// break; \
// default: \
// return INFINI_STATUS_BAD_TENSOR_SHAPE; \
// }
// if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
// SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_1024)
// } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
// SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_512)
// } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
// SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_4096)
// } else {
// return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
// }
// #undef LAUNCH_HEADSIZE_BLOCKSIZE
// #undef SWITCH_HEAD_SIZE
// return INFINI_STATUS_SUCCESS;
// }
// } // namespace op::paged_attention::nvidia
src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh
0 → 100644
View file @ 1c18c046
#ifndef __PAGED_ATTENTION_PREFILL_KERNEL_V2_CUH__
#define __PAGED_ATTENTION_PREFILL_KERNEL_V2_CUH__
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <mma.h>
#include <cstdint>
#include <type_traits>
// Reuse warp-level primitives and math helpers from decode flash_attention kernels.
#include "../../paged_attention/cuda/kernel_v2.cuh"
namespace op::paged_attention_prefill::cuda {

__device__ __forceinline__ size_t find_seq_id(size_t token_idx, const int64_t *cu_seqlens_q, size_t num_seqs) {
    size_t low = 0, high = (num_seqs == 0) ? 0 : (num_seqs - 1);
    while (low <= high) {
        size_t mid = (low + high) >> 1;
        const size_t start = static_cast<size_t>(cu_seqlens_q[mid]);
        const size_t end = static_cast<size_t>(cu_seqlens_q[mid + 1]);
        if (token_idx >= start && token_idx < end) {
            return mid;
        } else if (token_idx < start) {
            if (mid == 0) {
                break;
            }
            high = mid - 1;
        } else {
            low = mid + 1;
        }
    }
    return 0;
}
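find_seq_id maps a flattened query-token index back to its sequence by binary-searching the prefix-sum array cu_seqlens_q. A small host-side illustration (standalone sketch, not part of the kernel; the linear scan below is only a readable stand-in for the same lookup):

// With three sequences of lengths 3, 2 and 4, cu_seqlens_q = {0, 3, 5, 9};
// global token 4 lies in [3, 5), so it belongs to sequence index 1.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t cu_seqlens_q[] = {0, 3, 5, 9};
    for (int64_t token = 0; token < 9; ++token) {
        int seq = 0;
        while (token >= cu_seqlens_q[seq + 1]) ++seq; // same answer as the device binary search
        std::printf("token %lld -> seq %d\n", static_cast<long long>(token), seq);
    }
    return 0;
}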
template <typename Tindex, typename Tdata, int HEAD_SIZE>
__device__ void PagedAttentionPrefillWarpKernel(
    Tdata *out_, const Tdata *q_, const Tdata *k_cache_, const Tdata *v_cache_,
    const Tindex *block_tables_, const int64_t *total_kv_lens_, const int64_t *cu_seqlens_q_, const float *alibi_slopes_,
    size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride,
    ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    constexpr int kWarpSize = 32;
    static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4.");
    static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32.");
    constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize;

    const int lane = threadIdx.x;
    const int head_idx = static_cast<int>(blockIdx.x);
    const int seq_idx = static_cast<int>(blockIdx.y);
    const int q_token_local = static_cast<int>(blockIdx.z);

    const int64_t q_start = cu_seqlens_q_[seq_idx];
    const int64_t q_end = cu_seqlens_q_[seq_idx + 1];
    const int q_len = static_cast<int>(q_end - q_start);
    if (q_token_local >= q_len) {
        return;
    }
    const int kv_len_total = static_cast<int>(total_kv_lens_[seq_idx]);
    const int history_len = kv_len_total - q_len;
    const int allowed_k_len = history_len + q_token_local + 1;
    if (allowed_k_len <= 0) {
        return;
    }

    const int num_heads = gridDim.x;
    const int num_queries_per_kv = num_heads / static_cast<int>(num_kv_heads);
    const int kv_head_idx = head_idx / num_queries_per_kv;
    const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
    constexpr float kLog2e = 1.4426950408889634f;
    const float scale_log2 = scale * kLog2e;

    const int64_t q_token = q_start + static_cast<int64_t>(q_token_local);
    const Tdata *q_ptr = q_ + q_token * q_stride + static_cast<int64_t>(head_idx) * q_head_stride;
    Tdata *out_ptr = out_ + q_token * o_stride + static_cast<int64_t>(head_idx) * o_head_stride;
    const Tindex *block_table = block_tables_ + static_cast<int64_t>(seq_idx) * static_cast<int64_t>(block_table_batch_stride);

    float q_reg[DIMS_PER_THREAD];
    float acc[DIMS_PER_THREAD];
#pragma unroll
    for (int i = 0; i < DIMS_PER_THREAD; ++i) {
        const int dim = lane * DIMS_PER_THREAD + i;
        q_reg[i] = static_cast<float>(q_ptr[dim]);
        acc[i] = 0.0f;
    }
#if defined(__CUDA_ARCH__)
    float2 q_reg2[DIMS_PER_THREAD / 2];
    if constexpr (std::is_same_v<Tdata, half>) {
        const int dim_base = lane * DIMS_PER_THREAD;
        const half2 *q2 = reinterpret_cast<const half2 *>(q_ptr + dim_base);
#pragma unroll
        for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
            q_reg2[j] = __half22float2(q2[j]);
        }
    }
    if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
        const int dim_base = lane * DIMS_PER_THREAD;
        const __nv_bfloat162 *q2 = reinterpret_cast<const __nv_bfloat162 *>(q_ptr + dim_base);
#pragma unroll
        for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
            q_reg2[j] = __bfloat1622float2(q2[j]);
        }
    }
#endif

    float m = -INFINITY;
    float l = 0.0f;
    const int pbs = static_cast<int>(page_block_size);

    int t_base = 0;
    for (int logical_block = 0; t_base < allowed_k_len; ++logical_block, t_base += pbs) {
        int physical_block = 0;
        if (lane == 0) {
            physical_block = static_cast<int>(block_table[logical_block]);
        }
        physical_block = __shfl_sync(0xffffffff, physical_block, 0);
        const Tdata *k_base = k_cache_ + static_cast<int64_t>(physical_block) * k_batch_stride + static_cast<int64_t>(kv_head_idx) * k_head_stride;
        const Tdata *v_base = v_cache_ + static_cast<int64_t>(physical_block) * v_batch_stride + static_cast<int64_t>(kv_head_idx) * v_head_stride;
        const int token_end = min(pbs, allowed_k_len - t_base);
        for (int token_in_block = 0; token_in_block < token_end; ++token_in_block) {
            const int t = t_base + token_in_block;
            const Tdata *k_ptr = k_base + static_cast<int64_t>(token_in_block) * k_row_stride;
            const Tdata *v_ptr = v_base + static_cast<int64_t>(token_in_block) * v_row_stride;
            float qk = 0.0f;
#if defined(__CUDA_ARCH__)
            if constexpr (std::is_same_v<Tdata, half>) {
                const int dim_base = lane * DIMS_PER_THREAD;
                const half2 *k2 = reinterpret_cast<const half2 *>(k_ptr + dim_base);
#pragma unroll
                for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
                    const float2 qf = q_reg2[j];
                    const float2 kf = __half22float2(k2[j]);
                    qk += qf.x * kf.x + qf.y * kf.y;
                }
            } else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
                const int dim_base = lane * DIMS_PER_THREAD;
                const __nv_bfloat162 *k2 = reinterpret_cast<const __nv_bfloat162 *>(k_ptr + dim_base);
#pragma unroll
                for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
                    const float2 qf = q_reg2[j];
                    const float2 kf = __bfloat1622float2(k2[j]);
                    qk += qf.x * kf.x + qf.y * kf.y;
                }
            } else
#endif
#pragma unroll
                for (int i = 0; i < DIMS_PER_THREAD; ++i) {
                    const int dim = lane * DIMS_PER_THREAD + i;
                    qk += q_reg[i] * static_cast<float>(k_ptr[dim]);
                }
            qk = op::paged_attention::cuda::warpReduceSum(qk);

            float alpha = 1.0f;
            float beta = 0.0f;
            if (lane == 0) {
                float score = qk * scale_log2;
                if (alibi_slope != 0.0f) {
                    const int causal_limit = allowed_k_len - 1;
                    score += (alibi_slope * static_cast<float>(t - causal_limit)) * kLog2e;
                }
                const float m_new = fmaxf(m, score);
                alpha = exp2f(m - m_new);
                beta = exp2f(score - m_new);
                l = l * alpha + beta;
                m = m_new;
            }
            alpha = __shfl_sync(0xffffffff, alpha, 0);
            beta = __shfl_sync(0xffffffff, beta, 0);

#if defined(__CUDA_ARCH__)
            if constexpr (std::is_same_v<Tdata, half>) {
                const int dim_base = lane * DIMS_PER_THREAD;
                const half2 *v2 = reinterpret_cast<const half2 *>(v_ptr + dim_base);
#pragma unroll
                for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
                    const float2 vf = __half22float2(v2[j]);
                    acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x;
                    acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y;
                }
            } else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
                const int dim_base = lane * DIMS_PER_THREAD;
                const __nv_bfloat162 *v2 = reinterpret_cast<const __nv_bfloat162 *>(v_ptr + dim_base);
#pragma unroll
                for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
                    const float2 vf = __bfloat1622float2(v2[j]);
                    acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x;
                    acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y;
                }
            } else
#endif
            {
#pragma unroll
                for (int i = 0; i < DIMS_PER_THREAD; ++i) {
                    const int dim = lane * DIMS_PER_THREAD + i;
                    const float v_val = static_cast<float>(v_ptr[dim]);
                    acc[i] = acc[i] * alpha + beta * v_val;
                }
            }
        }
    }

    float inv_l = 0.0f;
    if (lane == 0) {
        inv_l = 1.0f / (l + 1e-6f);
    }
    inv_l = __shfl_sync(0xffffffff, inv_l, 0);
#pragma unroll
    for (int i = 0; i < DIMS_PER_THREAD; ++i) {
        const int dim = lane * DIMS_PER_THREAD + i;
        const float o = acc[i] * inv_l;
        if constexpr (std::is_same_v<Tdata, half>) {
            out_ptr[dim] = __float2half_rn(o);
        } else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
            out_ptr[dim] = __float2bfloat16_rn(o);
        } else {
            out_ptr[dim] = static_cast<Tdata>(o);
        }
    }
}
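A note on the scale_log2 trick used in the kernel above: because exp(x) = exp2(x * log2(e)), folding kLog2e into the softmax scale lets the running-max update use exp2f instead of expf, which maps directly onto the GPU's special-function unit. A quick numerical check of the identity (illustrative, host-side):

#include <cassert>
#include <cmath>

static void check_exp2_identity() {
    const float kLog2e = 1.4426950408889634f;
    for (float x : {-3.0f, -0.5f, 0.0f, 1.25f, 4.0f}) {
        // exp(x) and exp2(x * log2(e)) agree up to float rounding.
        assert(std::fabs(std::exp(x) - std::exp2(x * kLog2e)) <= 1e-5f * std::exp(x));
    }
}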
template <typename Tindex, typename Tdata, int HEAD_SIZE>
__global__ void PagedAttentionPrefillWarpGlobalKernel(
    Tdata *out_, const Tdata *q_, const Tdata *k_cache_, const Tdata *v_cache_,
    const Tindex *block_tables_, const int64_t *total_kv_lens_, const int64_t *cu_seqlens_q_, const float *alibi_slopes_,
    size_t num_heads, size_t num_seqs, size_t num_kv_heads, size_t total_q_tokens,
    float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride,
    ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    constexpr int kWarpSize = 32;
    static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4.");
    static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32.");
    constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize;

    const int lane = threadIdx.x;
    const size_t head_idx = static_cast<size_t>(blockIdx.x);
    const size_t global_token_idx = static_cast<size_t>(blockIdx.y);
    if (lane >= kWarpSize || head_idx >= num_heads || global_token_idx >= total_q_tokens) {
        return;
    }
    const size_t seq_idx = find_seq_id(global_token_idx, cu_seqlens_q_, num_seqs);
    const int64_t q_start = cu_seqlens_q_[seq_idx];
    const int64_t q_end = cu_seqlens_q_[seq_idx + 1];
    const int q_len = static_cast<int>(q_end - q_start);
    const int q_token_local = static_cast<int>(global_token_idx - static_cast<size_t>(q_start));
    if (q_token_local < 0 || q_token_local >= q_len) {
        return;
    }
    const int kv_len_total = static_cast<int>(total_kv_lens_[seq_idx]);
    const int history_len = kv_len_total - q_len;
    const int allowed_k_len = history_len + q_token_local + 1;
    if (allowed_k_len <= 0) {
        return;
    }
    const int num_queries_per_kv = static_cast<int>(num_heads / num_kv_heads);
    const int kv_head_idx = static_cast<int>(head_idx) / num_queries_per_kv;
    const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
    constexpr float kLog2e = 1.4426950408889634f;
    const float scale_log2 = scale * kLog2e;

    const Tdata *q_ptr = q_ + static_cast<int64_t>(global_token_idx) * q_stride + static_cast<int64_t>(head_idx) * q_head_stride;
    Tdata *out_ptr = out_ + static_cast<int64_t>(global_token_idx) * o_stride + static_cast<int64_t>(head_idx) * o_head_stride;
    const Tindex *block_table = block_tables_ + static_cast<int64_t>(seq_idx) * static_cast<int64_t>(block_table_batch_stride);
    const int pbs = static_cast<int>(page_block_size);

    float q_reg[DIMS_PER_THREAD];
    float acc[DIMS_PER_THREAD];
#pragma unroll
    for (int i = 0; i < DIMS_PER_THREAD; ++i) {
        const int dim = lane * DIMS_PER_THREAD + i;
        q_reg[i] = static_cast<float>(q_ptr[dim]);
        acc[i] = 0.0f;
    }
#if defined(__CUDA_ARCH__)
    float2 q_reg2[DIMS_PER_THREAD / 2];
    if constexpr (std::is_same_v<Tdata, half>) {
        const int dim_base = lane * DIMS_PER_THREAD;
        const half2 *q2 = reinterpret_cast<const half2 *>(q_ptr + dim_base);
#pragma unroll
        for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
            q_reg2[j] = __half22float2(q2[j]);
        }
    }
    if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
        const int dim_base = lane * DIMS_PER_THREAD;
        const __nv_bfloat162 *q2 = reinterpret_cast<const __nv_bfloat162 *>(q_ptr + dim_base);
#pragma unroll
        for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
            q_reg2[j] = __bfloat1622float2(q2[j]);
        }
    }
#endif

    float m = -INFINITY;
    float l = 0.0f;
    // Iterate by pages to avoid per-token division/mod and redundant block_table loads.
    int t_base = 0;
    for (int logical_block = 0; t_base < allowed_k_len; ++logical_block, t_base += pbs) {
        const int32_t phys = static_cast<int32_t>(block_table[logical_block]);
        const Tdata *k_base = k_cache_ + static_cast<int64_t>(phys) * k_batch_stride + static_cast<int64_t>(kv_head_idx) * k_head_stride;
        const Tdata *v_base = v_cache_ + static_cast<int64_t>(phys) * v_batch_stride + static_cast<int64_t>(kv_head_idx) * v_head_stride;
        const int token_end = min(pbs, allowed_k_len - t_base);
        for (int token_in_block = 0; token_in_block < token_end; ++token_in_block) {
            const int t = t_base + token_in_block;
            const Tdata *k_ptr = k_base + static_cast<int64_t>(token_in_block) * k_row_stride;
            const Tdata *v_ptr = v_base + static_cast<int64_t>(token_in_block) * v_row_stride;
            float qk = 0.0f;
#if defined(__CUDA_ARCH__)
            if constexpr (std::is_same_v<Tdata, half>) {
                const int dim_base = lane * DIMS_PER_THREAD;
                const half2 *k2 = reinterpret_cast<const half2 *>(k_ptr + dim_base);
#pragma unroll
                for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
                    const float2 qf = q_reg2[j];
                    const float2 kf = __half22float2(k2[j]);
                    qk += qf.x * kf.x + qf.y * kf.y;
                }
            } else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
                const int dim_base = lane * DIMS_PER_THREAD;
                const __nv_bfloat162 *k2 = reinterpret_cast<const __nv_bfloat162 *>(k_ptr + dim_base);
#pragma unroll
                for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
                    const float2 qf = q_reg2[j];
                    const float2 kf = __bfloat1622float2(k2[j]);
                    qk += qf.x * kf.x + qf.y * kf.y;
                }
            } else
#endif
            {
#pragma unroll
                for (int i = 0; i < DIMS_PER_THREAD; ++i) {
                    const int dim = lane * DIMS_PER_THREAD + i;
                    qk += q_reg[i] * static_cast<float>(k_ptr[dim]);
                }
            }
            qk = op::paged_attention::cuda::warpReduceSum(qk);

            float alpha = 1.0f;
            float beta = 0.0f;
            if (lane == 0) {
                float score = qk * scale_log2;
                if (alibi_slope != 0.0f) {
                    const int causal_limit = allowed_k_len - 1;
                    score += (alibi_slope * static_cast<float>(t - causal_limit)) * kLog2e;
                }
                const float m_new = fmaxf(m, score);
                alpha = exp2f(m - m_new);
                beta = exp2f(score - m_new);
                l = l * alpha + beta;
                m = m_new;
            }
            alpha = __shfl_sync(0xffffffff, alpha, 0);
            beta = __shfl_sync(0xffffffff, beta, 0);

#if defined(__CUDA_ARCH__)
            if constexpr (std::is_same_v<Tdata, half>) {
                const int dim_base = lane * DIMS_PER_THREAD;
                const half2 *v2 = reinterpret_cast<const half2 *>(v_ptr + dim_base);
#pragma unroll
                for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
                    const float2 vf = __half22float2(v2[j]);
                    acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x;
                    acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y;
                }
            } else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
                const int dim_base = lane * DIMS_PER_THREAD;
                const __nv_bfloat162 *v2 = reinterpret_cast<const __nv_bfloat162 *>(v_ptr + dim_base);
#pragma unroll
                for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
                    const float2 vf = __bfloat1622float2(v2[j]);
                    acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x;
                    acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y;
                }
            } else
#endif
            {
#pragma unroll
                for (int i = 0; i < DIMS_PER_THREAD; ++i) {
                    const int dim = lane * DIMS_PER_THREAD + i;
                    const float v_val = static_cast<float>(v_ptr[dim]);
                    acc[i] = acc[i] * alpha + beta * v_val;
                }
            }
        }
    }

    float inv_l = 0.0f;
    if (lane == 0) {
        inv_l = 1.0f / (l + 1e-6f);
    }
    inv_l = __shfl_sync(0xffffffff, inv_l, 0);
#pragma unroll
    for (int i = 0; i < DIMS_PER_THREAD; ++i) {
        const int dim = lane * DIMS_PER_THREAD + i;
        const float o = acc[i] * inv_l;
        if constexpr (std::is_same_v<Tdata, half>) {
            out_ptr[dim] = __float2half_rn(o);
        } else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
            out_ptr[dim] = __float2bfloat16_rn(o);
        } else {
            out_ptr[dim] = static_cast<Tdata>(o);
        }
    }
}
template <typename Tindex, typename Tdata, typename Tcompute, int HEAD_SIZE>
__global__ void PagedAttentionPrefillReferenceKernel(
    Tdata *out_, const Tdata *q_, const Tdata *k_cache_, const Tdata *v_cache_,
    const Tindex *block_tables_, const int64_t *total_kv_lens_, const int64_t *cu_seqlens_q_, const float *alibi_slopes_,
    size_t num_heads, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride,
    ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride,
    size_t num_seqs) {
    const size_t global_token_idx = static_cast<size_t>(blockIdx.x);
    const size_t head_idx = static_cast<size_t>(blockIdx.y);
    const size_t dim_idx = static_cast<size_t>(threadIdx.x);
    if (dim_idx >= HEAD_SIZE || head_idx >= num_heads) {
        return;
    }
    const size_t seq_idx = find_seq_id(global_token_idx, cu_seqlens_q_, num_seqs);
    const size_t q_token_idx = global_token_idx - static_cast<size_t>(cu_seqlens_q_[seq_idx]);
    const size_t q_len = static_cast<size_t>(cu_seqlens_q_[seq_idx + 1] - cu_seqlens_q_[seq_idx]);
    const size_t total_kv_len = static_cast<size_t>(total_kv_lens_[seq_idx]);
    const size_t history_len = total_kv_len - q_len;
    const size_t causal_limit = history_len + q_token_idx;
    const size_t num_queries_per_kv = num_heads / num_kv_heads;
    const size_t kv_head_idx = head_idx / num_queries_per_kv;
    const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
    const Tdata *q_vec = q_ + static_cast<int64_t>(global_token_idx) * q_stride + static_cast<int64_t>(head_idx) * q_head_stride;
    Tdata *out_ptr = out_ + static_cast<int64_t>(global_token_idx) * o_stride + static_cast<int64_t>(head_idx) * o_head_stride;
    const Tindex *block_table = block_tables_ + static_cast<int64_t>(seq_idx) * static_cast<int64_t>(block_table_batch_stride);
    const size_t pbs = page_block_size;

    Tcompute max_score = -INFINITY;
    for (size_t t = 0; t <= causal_limit; ++t) {
        const size_t page = t / pbs;
        const size_t off = t - page * pbs;
        const ptrdiff_t phys = static_cast<ptrdiff_t>(block_table[page]);
        const Tdata *k_vec = k_cache_ + static_cast<int64_t>(phys) * k_batch_stride + static_cast<int64_t>(off) * k_row_stride + static_cast<int64_t>(kv_head_idx) * k_head_stride;
        Tcompute score = 0;
        for (size_t d = 0; d < HEAD_SIZE; ++d) {
            score += static_cast<Tcompute>(q_vec[d]) * static_cast<Tcompute>(k_vec[d]);
        }
        score *= static_cast<Tcompute>(scale);
        if (alibi_slope != 0.0f) {
            score += static_cast<Tcompute>(alibi_slope * static_cast<float>(t - causal_limit));
        }
        if (score > max_score) {
            max_score = score;
        }
    }

    Tcompute sum_exp = 0;
    for (size_t t = 0; t <= causal_limit; ++t) {
        const size_t page = t / pbs;
        const size_t off = t - page * pbs;
        const ptrdiff_t phys = static_cast<ptrdiff_t>(block_table[page]);
        const Tdata *k_vec = k_cache_ + static_cast<int64_t>(phys) * k_batch_stride + static_cast<int64_t>(off) * k_row_stride + static_cast<int64_t>(kv_head_idx) * k_head_stride;
        Tcompute score = 0;
        for (size_t d = 0; d < HEAD_SIZE; ++d) {
            score += static_cast<Tcompute>(q_vec[d]) * static_cast<Tcompute>(k_vec[d]);
        }
        score *= static_cast<Tcompute>(scale);
        if (alibi_slope != 0.0f) {
            score += static_cast<Tcompute>(alibi_slope * static_cast<float>(t - causal_limit));
        }
        sum_exp += static_cast<Tcompute>(expf(static_cast<float>(score - max_score)));
    }
    const Tcompute inv_sum = static_cast
<
Tcompute
>
(
1.0
f
)
/
(
sum_exp
+
static_cast
<
Tcompute
>
(
1e-6
f
));
Tcompute
acc
=
0
;
for
(
size_t
t
=
0
;
t
<=
causal_limit
;
++
t
)
{
const
size_t
page
=
t
/
pbs
;
const
size_t
off
=
t
-
page
*
pbs
;
const
ptrdiff_t
phys
=
static_cast
<
ptrdiff_t
>
(
block_table
[
page
]);
const
Tdata
*
k_vec
=
k_cache_
+
static_cast
<
int64_t
>
(
phys
)
*
k_batch_stride
+
static_cast
<
int64_t
>
(
off
)
*
k_row_stride
+
static_cast
<
int64_t
>
(
kv_head_idx
)
*
k_head_stride
;
Tcompute
score
=
0
;
for
(
size_t
d
=
0
;
d
<
HEAD_SIZE
;
++
d
)
{
score
+=
static_cast
<
Tcompute
>
(
q_vec
[
d
])
*
static_cast
<
Tcompute
>
(
k_vec
[
d
]);
}
score
*=
static_cast
<
Tcompute
>
(
scale
);
if
(
alibi_slope
!=
0.0
f
)
{
score
+=
static_cast
<
Tcompute
>
(
alibi_slope
*
static_cast
<
float
>
(
t
-
causal_limit
));
}
const
Tcompute
prob
=
static_cast
<
Tcompute
>
(
expf
(
static_cast
<
float
>
(
score
-
max_score
)))
*
inv_sum
;
const
Tdata
*
v_vec
=
v_cache_
+
static_cast
<
int64_t
>
(
phys
)
*
v_batch_stride
+
static_cast
<
int64_t
>
(
off
)
*
v_row_stride
+
static_cast
<
int64_t
>
(
kv_head_idx
)
*
v_head_stride
;
acc
+=
prob
*
static_cast
<
Tcompute
>
(
v_vec
[
dim_idx
]);
}
out_ptr
[
dim_idx
]
=
static_cast
<
Tdata
>
(
acc
);
}
template
<
typename
Tindex
,
typename
Tdata
,
int
HEAD_SIZE
,
int
BLOCK_M
,
int
BLOCK_N
>
__device__
void
PagedAttentionPrefillWarpCtaKernel
(
Tdata
*
out_
,
const
Tdata
*
q_
,
const
Tdata
*
k_cache_
,
const
Tdata
*
v_cache_
,
const
Tindex
*
block_tables_
,
const
int64_t
*
total_kv_lens_
,
const
int64_t
*
cu_seqlens_q_
,
const
float
*
alibi_slopes_
,
size_t
num_kv_heads
,
float
scale
,
size_t
max_num_blocks_per_seq
,
size_t
page_block_size
,
ptrdiff_t
block_table_batch_stride
,
ptrdiff_t
q_stride
,
ptrdiff_t
q_head_stride
,
ptrdiff_t
k_batch_stride
,
ptrdiff_t
k_row_stride
,
ptrdiff_t
k_head_stride
,
ptrdiff_t
v_batch_stride
,
ptrdiff_t
v_row_stride
,
ptrdiff_t
v_head_stride
,
ptrdiff_t
o_stride
,
ptrdiff_t
o_head_stride
)
{
static_assert
(
HEAD_SIZE
==
64
||
HEAD_SIZE
==
128
,
"Only head_size 64/128 supported in v0.4."
);
static_assert
(
BLOCK_M
>
0
&&
BLOCK_M
<=
16
,
"BLOCK_M must be small (warp-per-query design)."
);
static_assert
(
BLOCK_N
==
64
||
BLOCK_N
==
128
,
"BLOCK_N must be 64/128 in v0.4."
);
constexpr
int
kWarpSize
=
32
;
constexpr
int
DIMS_PER_THREAD
=
HEAD_SIZE
/
kWarpSize
;
static_assert
(
HEAD_SIZE
%
kWarpSize
==
0
,
"HEAD_SIZE must be divisible by 32."
);
const
int
lane
=
threadIdx
.
x
&
(
kWarpSize
-
1
);
const
int
warp_id
=
threadIdx
.
x
/
kWarpSize
;
if
(
warp_id
>=
BLOCK_M
)
{
return
;
}
const
int
head_idx
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
int
seq_idx
=
static_cast
<
int
>
(
blockIdx
.
y
);
const
int
m_block
=
static_cast
<
int
>
(
blockIdx
.
z
);
const
int64_t
q_start
=
cu_seqlens_q_
[
seq_idx
];
const
int64_t
q_end
=
cu_seqlens_q_
[
seq_idx
+
1
];
const
int
q_len
=
static_cast
<
int
>
(
q_end
-
q_start
);
if
(
q_len
<=
0
)
{
return
;
}
const
int
m_start
=
m_block
*
BLOCK_M
;
const
int
q_token_local
=
m_start
+
warp_id
;
// IMPORTANT: do not early-return for a subset of warps in this CTA because we use __syncthreads()
// later. Tail tiles are handled by masking inactive warps.
if
(
m_start
>=
q_len
)
{
return
;
// uniform across the CTA
}
const
bool
is_active
=
(
q_token_local
<
q_len
);
const
int64_t
kv_len_total_i64
=
total_kv_lens_
[
seq_idx
];
const
int
kv_len_total
=
static_cast
<
int
>
(
kv_len_total_i64
);
// history_len = total_kv_len - q_len (KV already includes current q tokens).
const
int
history_len
=
kv_len_total
-
q_len
;
const
int
allowed_k_len
=
is_active
?
(
history_len
+
q_token_local
+
1
)
:
0
;
const
int
num_heads
=
gridDim
.
x
;
const
int
num_queries_per_kv
=
num_heads
/
static_cast
<
int
>
(
num_kv_heads
);
const
int
kv_head_idx
=
head_idx
/
num_queries_per_kv
;
const
float
alibi_slope
=
(
alibi_slopes_
==
nullptr
)
?
0.0
f
:
alibi_slopes_
[
head_idx
];
constexpr
float
kLog2e
=
1.4426950408889634
f
;
const
float
scale_log2
=
scale
*
kLog2e
;
int64_t
q_token
=
q_start
;
if
(
is_active
)
{
q_token
+=
static_cast
<
int64_t
>
(
q_token_local
);
}
const
Tindex
*
block_table
=
block_tables_
+
static_cast
<
int64_t
>
(
seq_idx
)
*
static_cast
<
int64_t
>
(
block_table_batch_stride
);
const
Tdata
*
q_ptr
=
nullptr
;
Tdata
*
out_ptr
=
nullptr
;
if
(
is_active
)
{
q_ptr
=
q_
+
q_token
*
q_stride
+
static_cast
<
int64_t
>
(
head_idx
)
*
q_head_stride
;
out_ptr
=
out_
+
q_token
*
o_stride
+
static_cast
<
int64_t
>
(
head_idx
)
*
o_head_stride
;
}
float
q_reg
[
DIMS_PER_THREAD
];
float
acc
[
DIMS_PER_THREAD
];
#pragma unroll
for
(
int
i
=
0
;
i
<
DIMS_PER_THREAD
;
++
i
)
{
const
int
dim
=
lane
*
DIMS_PER_THREAD
+
i
;
q_reg
[
i
]
=
is_active
?
static_cast
<
float
>
(
q_ptr
[
dim
])
:
0.0
f
;
acc
[
i
]
=
0.0
f
;
}
#if defined(__CUDA_ARCH__)
float2
q_reg2
[
DIMS_PER_THREAD
/
2
];
#pragma unroll
for
(
int
j
=
0
;
j
<
DIMS_PER_THREAD
/
2
;
++
j
)
{
q_reg2
[
j
]
=
make_float2
(
q_reg
[
j
*
2
+
0
],
q_reg
[
j
*
2
+
1
]);
}
#endif
float
m
=
-
INFINITY
;
float
l
=
0.0
f
;
// For this CTA, we only need to scan up to the max allowed k among active warps.
const
int
max_q_in_tile
=
min
(
m_start
+
BLOCK_M
,
q_len
);
const
int
max_allowed_k_len
=
min
(
history_len
+
max_q_in_tile
,
kv_len_total
);
__shared__
int32_t
s_phys
[
BLOCK_N
];
__shared__
int32_t
s_off
[
BLOCK_N
];
// Ensure shared-memory tiles are aligned for half2/bfloat162 vector loads.
__shared__
__align__
(
16
)
Tdata
s_k
[
BLOCK_N
*
HEAD_SIZE
];
__shared__
__align__
(
16
)
Tdata
s_v
[
BLOCK_N
*
HEAD_SIZE
];
const
int
pbs
=
static_cast
<
int
>
(
page_block_size
);
for
(
int
k_base
=
0
;
k_base
<
max_allowed_k_len
;
k_base
+=
BLOCK_N
)
{
const
int
tile_n
=
min
(
BLOCK_N
,
max_allowed_k_len
-
k_base
);
// Precompute page mapping once per token in the tile.
for
(
int
t
=
threadIdx
.
x
;
t
<
tile_n
;
t
+=
blockDim
.
x
)
{
const
int
kpos
=
k_base
+
t
;
const
int
page
=
(
pbs
==
256
)
?
(
kpos
>>
8
)
:
(
kpos
/
pbs
);
const
int
off
=
(
pbs
==
256
)
?
(
kpos
&
255
)
:
(
kpos
-
page
*
pbs
);
const
int32_t
phys
=
static_cast
<
int32_t
>
(
block_table
[
page
]);
s_phys
[
t
]
=
phys
;
s_off
[
t
]
=
off
;
}
__syncthreads
();
// Load K/V tile into shared memory (contiguous in head_dim).
const
int
tile_elems
=
tile_n
*
HEAD_SIZE
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
tile_elems
;
idx
+=
blockDim
.
x
)
{
const
int
t
=
idx
/
HEAD_SIZE
;
const
int
dim
=
idx
-
t
*
HEAD_SIZE
;
const
int32_t
phys
=
s_phys
[
t
];
const
int32_t
off
=
s_off
[
t
];
const
Tdata
*
k_base_ptr
=
k_cache_
+
static_cast
<
int64_t
>
(
phys
)
*
k_batch_stride
+
static_cast
<
int64_t
>
(
off
)
*
k_row_stride
+
static_cast
<
int64_t
>
(
kv_head_idx
)
*
k_head_stride
;
const
Tdata
*
v_base_ptr
=
v_cache_
+
static_cast
<
int64_t
>
(
phys
)
*
v_batch_stride
+
static_cast
<
int64_t
>
(
off
)
*
v_row_stride
+
static_cast
<
int64_t
>
(
kv_head_idx
)
*
v_head_stride
;
s_k
[
t
*
HEAD_SIZE
+
dim
]
=
k_base_ptr
[
dim
];
s_v
[
t
*
HEAD_SIZE
+
dim
]
=
v_base_ptr
[
dim
];
}
__syncthreads
();
// Each warp processes one query token and scans the K/V tile.
for
(
int
t
=
0
;
t
<
tile_n
;
++
t
)
{
const
int
kpos
=
k_base
+
t
;
if
(
kpos
>=
allowed_k_len
)
{
break
;
}
const
Tdata
*
k_ptr
=
s_k
+
t
*
HEAD_SIZE
;
const
Tdata
*
v_ptr
=
s_v
+
t
*
HEAD_SIZE
;
float
qk
=
0.0
f
;
#if defined(__CUDA_ARCH__)
if
constexpr
(
std
::
is_same_v
<
Tdata
,
half
>
)
{
const
int
dim_base
=
lane
*
DIMS_PER_THREAD
;
const
half2
*
k2
=
reinterpret_cast
<
const
half2
*>
(
k_ptr
+
dim_base
);
#pragma unroll
for
(
int
j
=
0
;
j
<
DIMS_PER_THREAD
/
2
;
++
j
)
{
const
float2
qf
=
q_reg2
[
j
];
const
float2
kf
=
__half22float2
(
k2
[
j
]);
qk
+=
qf
.
x
*
kf
.
x
+
qf
.
y
*
kf
.
y
;
}
}
else
if
constexpr
(
std
::
is_same_v
<
Tdata
,
__nv_bfloat16
>
)
{
const
int
dim_base
=
lane
*
DIMS_PER_THREAD
;
const
__nv_bfloat162
*
k2
=
reinterpret_cast
<
const
__nv_bfloat162
*>
(
k_ptr
+
dim_base
);
#pragma unroll
for
(
int
j
=
0
;
j
<
DIMS_PER_THREAD
/
2
;
++
j
)
{
const
float2
qf
=
q_reg2
[
j
];
const
float2
kf
=
__bfloat1622float2
(
k2
[
j
]);
qk
+=
qf
.
x
*
kf
.
x
+
qf
.
y
*
kf
.
y
;
}
}
else
#endif
#pragma unroll
for
(
int
i
=
0
;
i
<
DIMS_PER_THREAD
;
++
i
)
{
const
int
dim
=
lane
*
DIMS_PER_THREAD
+
i
;
qk
+=
q_reg
[
i
]
*
static_cast
<
float
>
(
k_ptr
[
dim
]);
}
qk
=
op
::
paged_attention
::
cuda
::
warpReduceSum
(
qk
);
float
alpha
=
1.0
f
;
float
beta
=
0.0
f
;
if
(
lane
==
0
)
{
float
score
=
qk
*
scale_log2
;
if
(
alibi_slope
!=
0.0
f
)
{
// Causal prefill: last position is (allowed_k_len - 1) for this query.
score
+=
(
alibi_slope
*
static_cast
<
float
>
(
kpos
-
(
allowed_k_len
-
1
)))
*
kLog2e
;
}
const
float
m_new
=
fmaxf
(
m
,
score
);
alpha
=
exp2f
(
m
-
m_new
);
beta
=
exp2f
(
score
-
m_new
);
l
=
l
*
alpha
+
beta
;
m
=
m_new
;
}
alpha
=
__shfl_sync
(
0xffffffff
,
alpha
,
0
);
beta
=
__shfl_sync
(
0xffffffff
,
beta
,
0
);
#if defined(__CUDA_ARCH__)
if
constexpr
(
std
::
is_same_v
<
Tdata
,
half
>
)
{
const
int
dim_base
=
lane
*
DIMS_PER_THREAD
;
const
half2
*
v2
=
reinterpret_cast
<
const
half2
*>
(
v_ptr
+
dim_base
);
#pragma unroll
for
(
int
j
=
0
;
j
<
DIMS_PER_THREAD
/
2
;
++
j
)
{
const
float2
vf
=
__half22float2
(
v2
[
j
]);
acc
[
j
*
2
+
0
]
=
acc
[
j
*
2
+
0
]
*
alpha
+
beta
*
vf
.
x
;
acc
[
j
*
2
+
1
]
=
acc
[
j
*
2
+
1
]
*
alpha
+
beta
*
vf
.
y
;
}
}
else
if
constexpr
(
std
::
is_same_v
<
Tdata
,
__nv_bfloat16
>
)
{
const
int
dim_base
=
lane
*
DIMS_PER_THREAD
;
const
__nv_bfloat162
*
v2
=
reinterpret_cast
<
const
__nv_bfloat162
*>
(
v_ptr
+
dim_base
);
#pragma unroll
for
(
int
j
=
0
;
j
<
DIMS_PER_THREAD
/
2
;
++
j
)
{
const
float2
vf
=
__bfloat1622float2
(
v2
[
j
]);
acc
[
j
*
2
+
0
]
=
acc
[
j
*
2
+
0
]
*
alpha
+
beta
*
vf
.
x
;
acc
[
j
*
2
+
1
]
=
acc
[
j
*
2
+
1
]
*
alpha
+
beta
*
vf
.
y
;
}
}
else
#endif
{
#pragma unroll
for
(
int
i
=
0
;
i
<
DIMS_PER_THREAD
;
++
i
)
{
const
int
dim
=
lane
*
DIMS_PER_THREAD
+
i
;
const
float
v_val
=
static_cast
<
float
>
(
v_ptr
[
dim
]);
acc
[
i
]
=
acc
[
i
]
*
alpha
+
beta
*
v_val
;
}
}
}
__syncthreads
();
}
float
inv_l
=
0.0
f
;
if
(
lane
==
0
)
{
inv_l
=
1.0
f
/
(
l
+
1e-6
f
);
}
inv_l
=
__shfl_sync
(
0xffffffff
,
inv_l
,
0
);
#pragma unroll
for
(
int
i
=
0
;
i
<
DIMS_PER_THREAD
;
++
i
)
{
const
int
dim
=
lane
*
DIMS_PER_THREAD
+
i
;
const
float
out_val
=
acc
[
i
]
*
inv_l
;
if
(
!
is_active
)
{
continue
;
}
if
constexpr
(
std
::
is_same_v
<
Tdata
,
half
>
)
{
out_ptr
[
dim
]
=
__float2half_rn
(
out_val
);
}
else
if
constexpr
(
std
::
is_same_v
<
Tdata
,
__nv_bfloat16
>
)
{
out_ptr
[
dim
]
=
__float2bfloat16_rn
(
out_val
);
}
else
{
out_ptr
[
dim
]
=
static_cast
<
Tdata
>
(
out_val
);
}
}
}
// Pipelined CTA kernel (FA2-style): stage K/V loads with cp.async and overlap global->shared
// copies with compute.
//
// Design notes:
// - Keep shared memory <= 48KB for compatibility with multi-arch builds that include SM75.
// - Iterate by paged blocks (logical pages) so each tile stays within one physical block and
// avoids per-token (page, off) mapping arrays in shared memory.
// - One warp computes one query token (same as warpcta kernels). Warps with shorter causal
// limits simply mask the tail tokens but still participate in CTA-wide barriers.
template
<
typename
Tindex
,
typename
Tdata
,
int
HEAD_SIZE
,
int
BLOCK_M
,
int
TOKENS_PER_TILE
,
int
STAGES
>
__device__
void
PagedAttentionPrefillWarpCtaKernelPipelined
(
Tdata
*
out_
,
const
Tdata
*
q_
,
const
Tdata
*
k_cache_
,
const
Tdata
*
v_cache_
,
const
Tindex
*
block_tables_
,
const
int64_t
*
total_kv_lens_
,
const
int64_t
*
cu_seqlens_q_
,
const
float
*
alibi_slopes_
,
size_t
num_kv_heads
,
float
scale
,
size_t
max_num_blocks_per_seq
,
size_t
page_block_size
,
ptrdiff_t
block_table_batch_stride
,
ptrdiff_t
q_stride
,
ptrdiff_t
q_head_stride
,
ptrdiff_t
k_batch_stride
,
ptrdiff_t
k_row_stride
,
ptrdiff_t
k_head_stride
,
ptrdiff_t
v_batch_stride
,
ptrdiff_t
v_row_stride
,
ptrdiff_t
v_head_stride
,
ptrdiff_t
o_stride
,
ptrdiff_t
o_head_stride
)
{
static_assert
(
HEAD_SIZE
==
64
||
HEAD_SIZE
==
128
,
"Only head_size 64/128 supported in v0.4."
);
static_assert
(
BLOCK_M
>
0
&&
BLOCK_M
<=
16
,
"BLOCK_M must be <= 16."
);
static_assert
(
TOKENS_PER_TILE
==
32
,
"Pipelined CTA kernel currently assumes TOKENS_PER_TILE == 32."
);
static_assert
(
STAGES
>=
2
&&
STAGES
<=
3
,
"STAGES must be 2 or 3."
);
static_assert
(
sizeof
(
Tdata
)
==
2
,
"Pipelined CTA kernel supports only fp16/bf16."
);
constexpr
int
kWarpSize
=
32
;
static_assert
(
HEAD_SIZE
%
kWarpSize
==
0
,
"HEAD_SIZE must be divisible by 32."
);
constexpr
int
DIMS_PER_THREAD
=
HEAD_SIZE
/
kWarpSize
;
const
int
lane
=
threadIdx
.
x
&
(
kWarpSize
-
1
);
const
int
warp_id
=
threadIdx
.
x
/
kWarpSize
;
if
(
warp_id
>=
BLOCK_M
)
{
return
;
}
const
int
head_idx
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
int
seq_idx
=
static_cast
<
int
>
(
blockIdx
.
y
);
const
int
m_block
=
static_cast
<
int
>
(
blockIdx
.
z
);
const
int64_t
q_start
=
cu_seqlens_q_
[
seq_idx
];
const
int64_t
q_end
=
cu_seqlens_q_
[
seq_idx
+
1
];
const
int
q_len
=
static_cast
<
int
>
(
q_end
-
q_start
);
if
(
q_len
<=
0
)
{
return
;
}
const
int
m_start
=
m_block
*
BLOCK_M
;
const
int
q_token_local
=
m_start
+
warp_id
;
// Uniform return for empty tail CTAs (avoid deadlock with __syncthreads).
if
(
m_start
>=
q_len
)
{
return
;
}
const
bool
is_active
=
(
q_token_local
<
q_len
);
const
int
kv_len_total
=
static_cast
<
int
>
(
total_kv_lens_
[
seq_idx
]);
const
int
history_len
=
kv_len_total
-
q_len
;
const
int
allowed_k_len
=
is_active
?
(
history_len
+
q_token_local
+
1
)
:
0
;
const
int
num_heads
=
gridDim
.
x
;
const
int
num_queries_per_kv
=
num_heads
/
static_cast
<
int
>
(
num_kv_heads
);
const
int
kv_head_idx
=
head_idx
/
num_queries_per_kv
;
const
float
alibi_slope
=
(
alibi_slopes_
==
nullptr
)
?
0.0
f
:
alibi_slopes_
[
head_idx
];
constexpr
float
kLog2e
=
1.4426950408889634
f
;
const
float
scale_log2
=
scale
*
kLog2e
;
int64_t
q_token
=
q_start
;
if
(
is_active
)
{
q_token
+=
static_cast
<
int64_t
>
(
q_token_local
);
}
const
Tindex
*
block_table
=
block_tables_
+
static_cast
<
int64_t
>
(
seq_idx
)
*
static_cast
<
int64_t
>
(
block_table_batch_stride
);
const
Tdata
*
q_ptr
=
nullptr
;
Tdata
*
out_ptr
=
nullptr
;
if
(
is_active
)
{
q_ptr
=
q_
+
q_token
*
q_stride
+
static_cast
<
int64_t
>
(
head_idx
)
*
q_head_stride
;
out_ptr
=
out_
+
q_token
*
o_stride
+
static_cast
<
int64_t
>
(
head_idx
)
*
o_head_stride
;
}
float
q_reg
[
DIMS_PER_THREAD
];
float
acc
[
DIMS_PER_THREAD
];
#pragma unroll
for
(
int
i
=
0
;
i
<
DIMS_PER_THREAD
;
++
i
)
{
const
int
dim
=
lane
*
DIMS_PER_THREAD
+
i
;
q_reg
[
i
]
=
is_active
?
static_cast
<
float
>
(
q_ptr
[
dim
])
:
0.0
f
;
acc
[
i
]
=
0.0
f
;
}
#if defined(__CUDA_ARCH__)
float2
q_reg2
[
DIMS_PER_THREAD
/
2
];
#pragma unroll
for
(
int
j
=
0
;
j
<
DIMS_PER_THREAD
/
2
;
++
j
)
{
q_reg2
[
j
]
=
make_float2
(
q_reg
[
j
*
2
+
0
],
q_reg
[
j
*
2
+
1
]);
}
#endif
float
m
=
-
INFINITY
;
float
l
=
0.0
f
;
// For this CTA, scan KV up to the max causal limit among active warps.
const
int
max_q_in_tile
=
min
(
m_start
+
BLOCK_M
,
q_len
);
const
int
max_allowed_k_len
=
min
(
history_len
+
max_q_in_tile
,
kv_len_total
);
if
(
max_allowed_k_len
<=
0
)
{
// Nothing to attend to (should be rare). Produce zeros.
if
(
is_active
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
DIMS_PER_THREAD
;
++
i
)
{
const
int
dim
=
lane
*
DIMS_PER_THREAD
+
i
;
out_ptr
[
dim
]
=
Tdata
{};
}
}
return
;
}
// cp.async uses 16B chunks; for fp16/bf16 that's 8 elements.
constexpr
int
CHUNK_ELEMS
=
8
;
constexpr
int
CHUNKS
=
HEAD_SIZE
/
CHUNK_ELEMS
;
constexpr
int
LOADS_PER_TILE
=
CHUNKS
*
TOKENS_PER_TILE
;
// Multi-stage pipeline buffers.
__shared__
__align__
(
16
)
Tdata
sh_k
[
STAGES
][
TOKENS_PER_TILE
][
HEAD_SIZE
];
__shared__
__align__
(
16
)
Tdata
sh_v
[
STAGES
][
TOKENS_PER_TILE
][
HEAD_SIZE
];
// Per-warp scratch for tile-wise softmax (scores over TOKENS_PER_TILE).
// We keep scores in shared so each lane can load its token score (lane -> token index),
// then weights are broadcast via warp shuffles to avoid extra shared-memory traffic.
__shared__
float
sh_scores
[
BLOCK_M
][
TOKENS_PER_TILE
];
// Store Q in shared (per warp). This enables more tile-level parallelism in score
// computation without expensive cross-lane shuffles of Q registers.
__shared__
__align__
(
16
)
Tdata
sh_q
[
BLOCK_M
][
HEAD_SIZE
];
const
int
pbs
=
static_cast
<
int
>
(
page_block_size
);
const
int
tid
=
threadIdx
.
x
;
// Populate per-warp Q shared tile once.
#pragma unroll
for
(
int
i
=
0
;
i
<
DIMS_PER_THREAD
;
++
i
)
{
const
int
dim
=
lane
*
DIMS_PER_THREAD
+
i
;
sh_q
[
warp_id
][
dim
]
=
is_active
?
q_ptr
[
dim
]
:
Tdata
{};
}
__syncwarp
();
int
t_base
=
0
;
for
(
int
logical_block
=
0
;
t_base
<
max_allowed_k_len
;
++
logical_block
,
t_base
+=
pbs
)
{
const
int
physical_block
=
static_cast
<
int
>
(
block_table
[
logical_block
]);
const
Tdata
*
k_base
=
k_cache_
+
static_cast
<
int64_t
>
(
physical_block
)
*
k_batch_stride
+
static_cast
<
int64_t
>
(
kv_head_idx
)
*
k_head_stride
;
const
Tdata
*
v_base
=
v_cache_
+
static_cast
<
int64_t
>
(
physical_block
)
*
v_batch_stride
+
static_cast
<
int64_t
>
(
kv_head_idx
)
*
v_head_stride
;
const
int
token_end
=
min
(
pbs
,
max_allowed_k_len
-
t_base
);
const
int
num_tiles
=
(
token_end
+
TOKENS_PER_TILE
-
1
)
/
TOKENS_PER_TILE
;
if
(
num_tiles
<=
0
)
{
continue
;
}
int
pending_groups
=
0
;
const
int
preload
=
min
(
STAGES
,
num_tiles
);
for
(
int
ti
=
0
;
ti
<
preload
;
++
ti
)
{
const
int
token_in_block
=
ti
*
TOKENS_PER_TILE
;
const
int
tile_n
=
min
(
TOKENS_PER_TILE
,
token_end
-
token_in_block
);
for
(
int
li
=
tid
;
li
<
LOADS_PER_TILE
;
li
+=
blockDim
.
x
)
{
const
int
tok
=
li
/
CHUNKS
;
const
int
chunk
=
li
-
tok
*
CHUNKS
;
const
int
off
=
chunk
*
CHUNK_ELEMS
;
if
(
tok
<
tile_n
)
{
const
Tdata
*
k_src
=
k_base
+
static_cast
<
int64_t
>
(
token_in_block
+
tok
)
*
k_row_stride
+
off
;
const
Tdata
*
v_src
=
v_base
+
static_cast
<
int64_t
>
(
token_in_block
+
tok
)
*
v_row_stride
+
off
;
op
::
paged_attention
::
cuda
::
cpAsyncCaSharedGlobal16
(
&
sh_k
[
ti
][
tok
][
off
],
k_src
);
op
::
paged_attention
::
cuda
::
cpAsyncCaSharedGlobal16
(
&
sh_v
[
ti
][
tok
][
off
],
v_src
);
}
else
{
reinterpret_cast
<
uint4
*>
(
&
sh_k
[
ti
][
tok
][
off
])[
0
]
=
make_uint4
(
0
,
0
,
0
,
0
);
reinterpret_cast
<
uint4
*>
(
&
sh_v
[
ti
][
tok
][
off
])[
0
]
=
make_uint4
(
0
,
0
,
0
,
0
);
}
}
op
::
paged_attention
::
cuda
::
cpAsyncCommit
();
++
pending_groups
;
}
int
desired_pending
=
pending_groups
-
1
;
if
(
desired_pending
<
0
)
{
desired_pending
=
0
;
}
if
(
desired_pending
>
(
STAGES
-
1
))
{
desired_pending
=
(
STAGES
-
1
);
}
op
::
paged_attention
::
cuda
::
cpAsyncWaitGroupRt
(
desired_pending
);
pending_groups
=
desired_pending
;
__syncthreads
();
for
(
int
tile_idx
=
0
;
tile_idx
<
num_tiles
;
++
tile_idx
)
{
const
int
buf
=
tile_idx
%
STAGES
;
const
int
token_in_block
=
tile_idx
*
TOKENS_PER_TILE
;
const
int
tile_n
=
min
(
TOKENS_PER_TILE
,
token_end
-
token_in_block
);
const
int
global_k_base
=
t_base
+
token_in_block
;
// Tile-wise online softmax (more FA2-like than per-token update):
// 1) Compute scores for this tile (masked to each warp's causal limit).
// 2) Compute tile max + sumexp.
// 3) Accumulate weighted V for the tile.
// 4) Merge into running (m, l, acc) in a numerically stable way.
//
// NOTE: this does not yet implement MMA / full tile-level GEMM; it mainly reduces
// the serial (lane0) online-softmax update frequency from per-token to per-tile.
float
alpha
=
1.0
f
;
float
beta
=
0.0
f
;
float
tile_sumexp
=
0.0
f
;
float
tile_m
=
-
INFINITY
;
if
(
allowed_k_len
>
0
)
{
// 1) scores
// Increase tile-level parallelism vs the previous per-token loop:
// split the warp into 4 groups of 8 lanes; each group computes one token score in parallel.
constexpr
int
LANES_PER_GROUP
=
8
;
constexpr
int
GROUPS_PER_WARP
=
4
;
constexpr
int
DIMS_PER_GROUP_LANE
=
HEAD_SIZE
/
LANES_PER_GROUP
;
static_assert
(
HEAD_SIZE
%
LANES_PER_GROUP
==
0
,
"HEAD_SIZE must be divisible by 8."
);
const
int
group_id
=
lane
/
LANES_PER_GROUP
;
// [0..3]
const
int
lane_g
=
lane
&
(
LANES_PER_GROUP
-
1
);
// [0..7]
const
unsigned
int
group_mask
=
0xFFu
<<
(
group_id
*
LANES_PER_GROUP
);
for
(
int
j_base
=
0
;
j_base
<
TOKENS_PER_TILE
;
j_base
+=
GROUPS_PER_WARP
)
{
const
int
j
=
j_base
+
group_id
;
// token index in [0..31]
const
int
kpos
=
global_k_base
+
j
;
const
bool
token_in_tile
=
(
j
<
tile_n
);
const
bool
token_unmasked
=
token_in_tile
&&
(
kpos
<
allowed_k_len
);
float
qk_part
=
0.0
f
;
if
(
token_unmasked
)
{
const
Tdata
*
k_ptr
=
&
sh_k
[
buf
][
j
][
0
];
const
int
dim_base
=
lane_g
*
DIMS_PER_GROUP_LANE
;
#if defined(__CUDA_ARCH__)
if
constexpr
(
std
::
is_same_v
<
Tdata
,
half
>
)
{
const
half2
*
q2
=
reinterpret_cast
<
const
half2
*>
(
&
sh_q
[
warp_id
][
dim_base
]);
const
half2
*
k2
=
reinterpret_cast
<
const
half2
*>
(
k_ptr
+
dim_base
);
#pragma unroll
for
(
int
t
=
0
;
t
<
DIMS_PER_GROUP_LANE
/
2
;
++
t
)
{
const
float2
qf
=
__half22float2
(
q2
[
t
]);
const
float2
kf
=
__half22float2
(
k2
[
t
]);
qk_part
+=
qf
.
x
*
kf
.
x
+
qf
.
y
*
kf
.
y
;
}
}
else
if
constexpr
(
std
::
is_same_v
<
Tdata
,
__nv_bfloat16
>
)
{
const
__nv_bfloat162
*
q2
=
reinterpret_cast
<
const
__nv_bfloat162
*>
(
&
sh_q
[
warp_id
][
dim_base
]);
const
__nv_bfloat162
*
k2
=
reinterpret_cast
<
const
__nv_bfloat162
*>
(
k_ptr
+
dim_base
);
#pragma unroll
for
(
int
t
=
0
;
t
<
DIMS_PER_GROUP_LANE
/
2
;
++
t
)
{
const
float2
qf
=
__bfloat1622float2
(
q2
[
t
]);
const
float2
kf
=
__bfloat1622float2
(
k2
[
t
]);
qk_part
+=
qf
.
x
*
kf
.
x
+
qf
.
y
*
kf
.
y
;
}
}
else
#endif
{
#pragma unroll
for
(
int
t
=
0
;
t
<
DIMS_PER_GROUP_LANE
;
++
t
)
{
qk_part
+=
static_cast
<
float
>
(
sh_q
[
warp_id
][
dim_base
+
t
])
*
static_cast
<
float
>
(
k_ptr
[
dim_base
+
t
]);
}
}
}
// Reduce within 8-lane group.
for
(
int
offset
=
LANES_PER_GROUP
/
2
;
offset
>
0
;
offset
>>=
1
)
{
qk_part
+=
__shfl_down_sync
(
group_mask
,
qk_part
,
offset
,
LANES_PER_GROUP
);
}
if
(
lane_g
==
0
)
{
float
score
=
-
INFINITY
;
if
(
token_unmasked
)
{
score
=
qk_part
*
scale_log2
;
if
(
alibi_slope
!=
0.0
f
)
{
const
int
causal_limit
=
allowed_k_len
-
1
;
score
+=
(
alibi_slope
*
static_cast
<
float
>
(
kpos
-
causal_limit
))
*
kLog2e
;
}
}
sh_scores
[
warp_id
][
j
]
=
score
;
}
}
__syncwarp
();
// 2) tile max + sumexp (lane t corresponds to token t within the tile)
const
float
score_lane
=
(
lane
<
tile_n
)
?
sh_scores
[
warp_id
][
lane
]
:
-
INFINITY
;
float
tile_m_tmp
=
op
::
paged_attention
::
cuda
::
warpReduceMax
(
score_lane
);
tile_m_tmp
=
__shfl_sync
(
0xffffffff
,
tile_m_tmp
,
0
);
tile_m
=
tile_m_tmp
;
float
w_lane
=
0.0
f
;
if
(
lane
<
tile_n
&&
tile_m
!=
-
INFINITY
)
{
w_lane
=
exp2f
(
score_lane
-
tile_m
);
}
float
sumexp_tmp
=
op
::
paged_attention
::
cuda
::
warpReduceSum
(
w_lane
);
sumexp_tmp
=
__shfl_sync
(
0xffffffff
,
sumexp_tmp
,
0
);
tile_sumexp
=
sumexp_tmp
;
// 3) weighted V for this tile (per lane owns HEAD_SIZE/32 dims)
float
acc_tile
[
DIMS_PER_THREAD
];
#pragma unroll
for
(
int
i
=
0
;
i
<
DIMS_PER_THREAD
;
++
i
)
{
acc_tile
[
i
]
=
0.0
f
;
}
if
(
tile_sumexp
>
0.0
f
)
{
for
(
int
j
=
0
;
j
<
tile_n
;
++
j
)
{
// Broadcast weight for token j from lane j.
const
float
wj
=
__shfl_sync
(
0xffffffff
,
w_lane
,
j
);
const
Tdata
*
v_ptr
=
&
sh_v
[
buf
][
j
][
0
];
#if defined(__CUDA_ARCH__)
if
constexpr
(
std
::
is_same_v
<
Tdata
,
half
>
)
{
const
int
dim_base
=
lane
*
DIMS_PER_THREAD
;
const
half2
*
v2
=
reinterpret_cast
<
const
half2
*>
(
v_ptr
+
dim_base
);
#pragma unroll
for
(
int
jj
=
0
;
jj
<
DIMS_PER_THREAD
/
2
;
++
jj
)
{
const
float2
vf
=
__half22float2
(
v2
[
jj
]);
acc_tile
[
jj
*
2
+
0
]
+=
wj
*
vf
.
x
;
acc_tile
[
jj
*
2
+
1
]
+=
wj
*
vf
.
y
;
}
}
else
if
constexpr
(
std
::
is_same_v
<
Tdata
,
__nv_bfloat16
>
)
{
const
int
dim_base
=
lane
*
DIMS_PER_THREAD
;
const
__nv_bfloat162
*
v2
=
reinterpret_cast
<
const
__nv_bfloat162
*>
(
v_ptr
+
dim_base
);
#pragma unroll
for
(
int
jj
=
0
;
jj
<
DIMS_PER_THREAD
/
2
;
++
jj
)
{
const
float2
vf
=
__bfloat1622float2
(
v2
[
jj
]);
acc_tile
[
jj
*
2
+
0
]
+=
wj
*
vf
.
x
;
acc_tile
[
jj
*
2
+
1
]
+=
wj
*
vf
.
y
;
}
}
else
#endif
{
#pragma unroll
for
(
int
i
=
0
;
i
<
DIMS_PER_THREAD
;
++
i
)
{
const
int
dim
=
lane
*
DIMS_PER_THREAD
+
i
;
acc_tile
[
i
]
+=
wj
*
static_cast
<
float
>
(
v_ptr
[
dim
]);
}
}
}
}
// 4) merge tile into running (m, l, acc)
if
(
lane
==
0
)
{
if
(
tile_sumexp
>
0.0
f
&&
tile_m
!=
-
INFINITY
)
{
const
float
m_new
=
fmaxf
(
m
,
tile_m
);
alpha
=
exp2f
(
m
-
m_new
);
beta
=
exp2f
(
tile_m
-
m_new
);
l
=
l
*
alpha
+
tile_sumexp
*
beta
;
m
=
m_new
;
}
else
{
alpha
=
1.0
f
;
beta
=
0.0
f
;
}
}
alpha
=
__shfl_sync
(
0xffffffff
,
alpha
,
0
);
beta
=
__shfl_sync
(
0xffffffff
,
beta
,
0
);
#pragma unroll
for
(
int
i
=
0
;
i
<
DIMS_PER_THREAD
;
++
i
)
{
acc
[
i
]
=
acc
[
i
]
*
alpha
+
beta
*
acc_tile
[
i
];
}
}
// IMPORTANT: warps in this CTA can have different allowed_k_len (due to causal mask + history),
// so they may finish the token loop at different times. We must not start prefetching into
// the circular shared-memory buffer until all warps finish consuming the current tile.
__syncthreads
();
// Prefetch the tile that will reuse this buffer (STAGES steps ahead).
const
int
prefetch_tile
=
tile_idx
+
STAGES
;
if
(
prefetch_tile
<
num_tiles
)
{
const
int
token_prefetch
=
prefetch_tile
*
TOKENS_PER_TILE
;
const
int
prefetch_n
=
min
(
TOKENS_PER_TILE
,
token_end
-
token_prefetch
);
for
(
int
li
=
tid
;
li
<
LOADS_PER_TILE
;
li
+=
blockDim
.
x
)
{
const
int
tok
=
li
/
CHUNKS
;
const
int
chunk
=
li
-
tok
*
CHUNKS
;
const
int
off
=
chunk
*
CHUNK_ELEMS
;
if
(
tok
<
prefetch_n
)
{
const
Tdata
*
k_src
=
k_base
+
static_cast
<
int64_t
>
(
token_prefetch
+
tok
)
*
k_row_stride
+
off
;
const
Tdata
*
v_src
=
v_base
+
static_cast
<
int64_t
>
(
token_prefetch
+
tok
)
*
v_row_stride
+
off
;
op
::
paged_attention
::
cuda
::
cpAsyncCaSharedGlobal16
(
&
sh_k
[
buf
][
tok
][
off
],
k_src
);
op
::
paged_attention
::
cuda
::
cpAsyncCaSharedGlobal16
(
&
sh_v
[
buf
][
tok
][
off
],
v_src
);
}
else
{
reinterpret_cast
<
uint4
*>
(
&
sh_k
[
buf
][
tok
][
off
])[
0
]
=
make_uint4
(
0
,
0
,
0
,
0
);
reinterpret_cast
<
uint4
*>
(
&
sh_v
[
buf
][
tok
][
off
])[
0
]
=
make_uint4
(
0
,
0
,
0
,
0
);
}
}
op
::
paged_attention
::
cuda
::
cpAsyncCommit
();
++
pending_groups
;
}
if
(
tile_idx
+
1
<
num_tiles
)
{
int
desired_pending2
=
pending_groups
-
1
;
if
(
desired_pending2
<
0
)
{
desired_pending2
=
0
;
}
if
(
desired_pending2
>
(
STAGES
-
1
))
{
desired_pending2
=
(
STAGES
-
1
);
}
op
::
paged_attention
::
cuda
::
cpAsyncWaitGroupRt
(
desired_pending2
);
pending_groups
=
desired_pending2
;
__syncthreads
();
}
}
op
::
paged_attention
::
cuda
::
cpAsyncWaitAll
();
__syncthreads
();
}
float
inv_l
=
0.0
f
;
if
(
lane
==
0
)
{
inv_l
=
1.0
f
/
(
l
+
1e-6
f
);
}
inv_l
=
__shfl_sync
(
0xffffffff
,
inv_l
,
0
);
#pragma unroll
for
(
int
i
=
0
;
i
<
DIMS_PER_THREAD
;
++
i
)
{
const
int
dim
=
lane
*
DIMS_PER_THREAD
+
i
;
const
float
out_val
=
acc
[
i
]
*
inv_l
;
if
(
!
is_active
)
{
continue
;
}
if
constexpr
(
std
::
is_same_v
<
Tdata
,
half
>
)
{
out_ptr
[
dim
]
=
__float2half_rn
(
out_val
);
}
else
if
constexpr
(
std
::
is_same_v
<
Tdata
,
__nv_bfloat16
>
)
{
out_ptr
[
dim
]
=
__float2bfloat16_rn
(
out_val
);
}
else
{
out_ptr
[
dim
]
=
static_cast
<
Tdata
>
(
out_val
);
}
}
}
// Split-KV prefill (FA2-style): each split scans a shard of KV and writes partial (m, l, acc)
// to workspace. A separate combine kernel merges splits into the final output.
//
// Notes:
// - Implemented for the pipelined CTA kernel family (warpcta8pipe). We split by logical paged blocks.
// - Each warp still applies its own causal limit (allowed_k_len) so correctness is preserved.
template
<
typename
Tindex
,
typename
Tdata
,
int
HEAD_SIZE
,
int
BLOCK_M
,
int
TOKENS_PER_TILE
,
int
STAGES
>
__device__
void
PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv
(
float
*
partial_acc
,
// [num_splits, total_q_tokens, num_heads, head_size]
float
*
partial_m
,
// [num_splits, total_q_tokens, num_heads]
float
*
partial_l
,
// [num_splits, total_q_tokens, num_heads]
int
split_idx
,
int
num_splits
,
int
m_block
,
size_t
total_q_tokens
,
const
Tdata
*
q_
,
const
Tdata
*
k_cache_
,
const
Tdata
*
v_cache_
,
const
Tindex
*
block_tables_
,
const
int64_t
*
total_kv_lens_
,
const
int64_t
*
cu_seqlens_q_
,
const
float
*
alibi_slopes_
,
size_t
num_kv_heads
,
float
scale
,
size_t
max_num_blocks_per_seq
,
size_t
page_block_size
,
ptrdiff_t
block_table_batch_stride
,
ptrdiff_t
q_stride
,
ptrdiff_t
q_head_stride
,
ptrdiff_t
k_batch_stride
,
ptrdiff_t
k_row_stride
,
ptrdiff_t
k_head_stride
,
ptrdiff_t
v_batch_stride
,
ptrdiff_t
v_row_stride
,
ptrdiff_t
v_head_stride
)
{
(
void
)
max_num_blocks_per_seq
;
static_assert
(
HEAD_SIZE
==
64
||
HEAD_SIZE
==
128
,
"Only head_size 64/128 supported in v0.4."
);
static_assert
(
BLOCK_M
>
0
&&
BLOCK_M
<=
16
,
"BLOCK_M must be <= 16."
);
static_assert
(
TOKENS_PER_TILE
==
32
,
"Split-KV prefill assumes TOKENS_PER_TILE == 32."
);
static_assert
(
STAGES
>=
2
&&
STAGES
<=
3
,
"STAGES must be 2 or 3."
);
static_assert
(
sizeof
(
Tdata
)
==
2
,
"Split-KV prefill supports only fp16/bf16."
);
constexpr
int
kWarpSize
=
32
;
static_assert
(
HEAD_SIZE
%
kWarpSize
==
0
,
"HEAD_SIZE must be divisible by 32."
);
constexpr
int
DIMS_PER_THREAD
=
HEAD_SIZE
/
kWarpSize
;
const
int
lane
=
threadIdx
.
x
&
(
kWarpSize
-
1
);
const
int
warp_id
=
threadIdx
.
x
/
kWarpSize
;
if
(
warp_id
>=
BLOCK_M
)
{
return
;
}
const
int
head_idx
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
int
seq_idx
=
static_cast
<
int
>
(
blockIdx
.
y
);
const
int64_t
q_start
=
cu_seqlens_q_
[
seq_idx
];
const
int64_t
q_end
=
cu_seqlens_q_
[
seq_idx
+
1
];
const
int
q_len
=
static_cast
<
int
>
(
q_end
-
q_start
);
if
(
q_len
<=
0
)
{
return
;
}
const
int
m_start
=
m_block
*
BLOCK_M
;
const
int
q_token_local
=
m_start
+
warp_id
;
if
(
m_start
>=
q_len
)
{
return
;
// uniform
}
const
bool
is_active
=
(
q_token_local
<
q_len
);
const
int
kv_len_total
=
static_cast
<
int
>
(
total_kv_lens_
[
seq_idx
]);
const
int
history_len
=
kv_len_total
-
q_len
;
const
int
allowed_k_len
=
is_active
?
(
history_len
+
q_token_local
+
1
)
:
0
;
const
int
num_heads
=
gridDim
.
x
;
const
int
num_queries_per_kv
=
num_heads
/
static_cast
<
int
>
(
num_kv_heads
);
const
int
kv_head_idx
=
head_idx
/
num_queries_per_kv
;
const
float
alibi_slope
=
(
alibi_slopes_
==
nullptr
)
?
0.0
f
:
alibi_slopes_
[
head_idx
];
constexpr
float
kLog2e
=
1.4426950408889634
f
;
const
float
scale_log2
=
scale
*
kLog2e
;
int64_t
q_token
=
q_start
;
if
(
is_active
)
{
q_token
+=
static_cast
<
int64_t
>
(
q_token_local
);
}
const
size_t
n
=
total_q_tokens
*
static_cast
<
size_t
>
(
num_heads
);
size_t
base
=
0
;
if
(
is_active
)
{
base
=
static_cast
<
size_t
>
(
q_token
)
*
static_cast
<
size_t
>
(
num_heads
)
+
static_cast
<
size_t
>
(
head_idx
);
}
const
Tindex
*
block_table
=
block_tables_
+
static_cast
<
int64_t
>
(
seq_idx
)
*
static_cast
<
int64_t
>
(
block_table_batch_stride
);
const
Tdata
*
q_ptr
=
nullptr
;
if
(
is_active
)
{
q_ptr
=
q_
+
q_token
*
q_stride
+
static_cast
<
int64_t
>
(
head_idx
)
*
q_head_stride
;
}
float
q_reg
[
DIMS_PER_THREAD
];
float
acc
[
DIMS_PER_THREAD
];
#pragma unroll
for
(
int
i
=
0
;
i
<
DIMS_PER_THREAD
;
++
i
)
{
const
int
dim
=
lane
*
DIMS_PER_THREAD
+
i
;
q_reg
[
i
]
=
is_active
?
static_cast
<
float
>
(
q_ptr
[
dim
])
:
0.0
f
;
acc
[
i
]
=
0.0
f
;
}
float
m
=
-
INFINITY
;
float
l
=
0.0
f
;
const
int
max_q_in_tile
=
min
(
m_start
+
BLOCK_M
,
q_len
);
const
int
max_allowed_k_len
=
min
(
history_len
+
max_q_in_tile
,
kv_len_total
);
if
(
max_allowed_k_len
<=
0
)
{
if
(
is_active
)
{
const
size_t
idx
=
static_cast
<
size_t
>
(
split_idx
)
*
n
+
base
;
if
(
lane
==
0
)
{
partial_m
[
idx
]
=
-
INFINITY
;
partial_l
[
idx
]
=
0.0
f
;
}
#pragma unroll
for
(
int
i
=
0
;
i
<
DIMS_PER_THREAD
;
++
i
)
{
const
int
dim
=
lane
*
DIMS_PER_THREAD
+
i
;
partial_acc
[
idx
*
HEAD_SIZE
+
dim
]
=
0.0
f
;
}
}
return
;
}
const
int
pbs
=
static_cast
<
int
>
(
page_block_size
);
const
int
num_blocks_total
=
(
max_allowed_k_len
+
pbs
-
1
)
/
pbs
;
const
int
blocks_per_split
=
(
num_blocks_total
+
num_splits
-
1
)
/
num_splits
;
const
int
start_block
=
split_idx
*
blocks_per_split
;
const
int
end_block
=
min
(
num_blocks_total
,
start_block
+
blocks_per_split
);
if
(
start_block
>=
end_block
)
{
if
(
is_active
)
{
const
size_t
idx
=
static_cast
<
size_t
>
(
split_idx
)
*
n
+
base
;
if
(
lane
==
0
)
{
partial_m
[
idx
]
=
-
INFINITY
;
partial_l
[
idx
]
=
0.0
f
;
}
#pragma unroll
for
(
int
i
=
0
;
i
<
DIMS_PER_THREAD
;
++
i
)
{
const
int
dim
=
lane
*
DIMS_PER_THREAD
+
i
;
partial_acc
[
idx
*
HEAD_SIZE
+
dim
]
=
0.0
f
;
}
}
return
;
}
const
int
max_allowed_k_len_split
=
min
(
max_allowed_k_len
,
end_block
*
pbs
);
constexpr
int
CHUNK_ELEMS
=
8
;
constexpr
int
CHUNKS
=
HEAD_SIZE
/
CHUNK_ELEMS
;
constexpr
int
LOADS_PER_TILE
=
CHUNKS
*
TOKENS_PER_TILE
;
__shared__
__align__
(
16
)
Tdata
sh_k
[
STAGES
][
TOKENS_PER_TILE
][
HEAD_SIZE
];
__shared__
__align__
(
16
)
Tdata
sh_v
[
STAGES
][
TOKENS_PER_TILE
][
HEAD_SIZE
];
__shared__
float
sh_scores
[
BLOCK_M
][
TOKENS_PER_TILE
];
const
int
tid
=
threadIdx
.
x
;
int
t_base
=
start_block
*
pbs
;
for
(
int
logical_block
=
start_block
;
t_base
<
max_allowed_k_len_split
;
++
logical_block
,
t_base
+=
pbs
)
{
const
int
physical_block
=
static_cast
<
int
>
(
block_table
[
logical_block
]);
const
Tdata
*
k_base
=
k_cache_
+
static_cast
<
int64_t
>
(
physical_block
)
*
k_batch_stride
+
static_cast
<
int64_t
>
(
kv_head_idx
)
*
k_head_stride
;
const
Tdata
*
v_base
=
v_cache_
+
static_cast
<
int64_t
>
(
physical_block
)
*
v_batch_stride
+
static_cast
<
int64_t
>
(
kv_head_idx
)
*
v_head_stride
;
const
int
token_end
=
min
(
pbs
,
max_allowed_k_len_split
-
t_base
);
const
int
num_tiles
=
(
token_end
+
TOKENS_PER_TILE
-
1
)
/
TOKENS_PER_TILE
;
if
(
num_tiles
<=
0
)
{
continue
;
}
int
pending_groups
=
0
;
const
int
preload
=
min
(
STAGES
,
num_tiles
);
for
(
int
ti
=
0
;
ti
<
preload
;
++
ti
)
{
const
int
token_in_block
=
ti
*
TOKENS_PER_TILE
;
const
int
tile_n
=
min
(
TOKENS_PER_TILE
,
token_end
-
token_in_block
);
for
(
int
li
=
tid
;
li
<
LOADS_PER_TILE
;
li
+=
blockDim
.
x
)
{
const
int
tok
=
li
/
CHUNKS
;
const
int
chunk
=
li
-
tok
*
CHUNKS
;
const
int
off
=
chunk
*
CHUNK_ELEMS
;
if
(
tok
<
tile_n
)
{
const
Tdata
*
k_src
=
k_base
+
static_cast
<
int64_t
>
(
token_in_block
+
tok
)
*
k_row_stride
+
off
;
const
Tdata
*
v_src
=
v_base
+
static_cast
<
int64_t
>
(
token_in_block
+
tok
)
*
v_row_stride
+
off
;
op
::
paged_attention
::
cuda
::
cpAsyncCaSharedGlobal16
(
&
sh_k
[
ti
][
tok
][
off
],
k_src
);
op
::
paged_attention
::
cuda
::
cpAsyncCaSharedGlobal16
(
&
sh_v
[
ti
][
tok
][
off
],
v_src
);
}
else
{
reinterpret_cast
<
uint4
*>
(
&
sh_k
[
ti
][
tok
][
off
])[
0
]
=
make_uint4
(
0
,
0
,
0
,
0
);
reinterpret_cast
<
uint4
*>
(
&
sh_v
[
ti
][
tok
][
off
])[
0
]
=
make_uint4
(
0
,
0
,
0
,
0
);
}
}
op
::
paged_attention
::
cuda
::
cpAsyncCommit
();
++
pending_groups
;
}
int
desired_pending
=
pending_groups
-
1
;
if
(
desired_pending
<
0
)
{
desired_pending
=
0
;
}
if
(
desired_pending
>
(
STAGES
-
1
))
{
desired_pending
=
(
STAGES
-
1
);
}
op
::
paged_attention
::
cuda
::
cpAsyncWaitGroupRt
(
desired_pending
);
pending_groups
=
desired_pending
;
__syncthreads
();
for
(
int
tile_idx
=
0
;
tile_idx
<
num_tiles
;
++
tile_idx
)
{
const
int
buf
=
tile_idx
%
STAGES
;
const
int
token_in_block
=
tile_idx
*
TOKENS_PER_TILE
;
const
int
tile_n
=
min
(
TOKENS_PER_TILE
,
token_end
-
token_in_block
);
const
int
global_k_base
=
t_base
+
token_in_block
;
float
alpha
=
1.0
f
;
float
beta
=
0.0
f
;
float
tile_sumexp
=
0.0
f
;
float
tile_m
=
-
INFINITY
;
float
w_lane
=
0.0
f
;
if
(
allowed_k_len
>
0
)
{
// 1) scores
for
(
int
j
=
0
;
j
<
tile_n
;
++
j
)
{
const
int
kpos
=
global_k_base
+
j
;
const
bool
token_unmasked
=
(
kpos
<
allowed_k_len
);
float
qk
=
0.0
f
;
if
(
token_unmasked
)
{
const
Tdata
*
k_ptr
=
&
sh_k
[
buf
][
j
][
0
];
#if defined(__CUDA_ARCH__)
if
constexpr
(
std
::
is_same_v
<
Tdata
,
half
>
)
{
const
int
dim_base
=
lane
*
DIMS_PER_THREAD
;
const
half2
*
q2
=
reinterpret_cast
<
const
half2
*>
(
q_ptr
+
dim_base
);
const
half2
*
k2
=
reinterpret_cast
<
const
half2
*>
(
k_ptr
+
dim_base
);
#pragma unroll
for
(
int
ii
=
0
;
ii
<
DIMS_PER_THREAD
/
2
;
++
ii
)
{
const
float2
qf
=
__half22float2
(
q2
[
ii
]);
const
float2
kf
=
__half22float2
(
k2
[
ii
]);
qk
=
fmaf
(
qf
.
x
,
kf
.
x
,
qk
);
qk
=
fmaf
(
qf
.
y
,
kf
.
y
,
qk
);
}
}
else
if
constexpr
(
std
::
is_same_v
<
Tdata
,
__nv_bfloat16
>
)
{
const
int
dim_base
=
lane
*
DIMS_PER_THREAD
;
const
__nv_bfloat162
*
q2
=
reinterpret_cast
<
const
__nv_bfloat162
*>
(
q_ptr
+
dim_base
);
const
__nv_bfloat162
*
k2
=
reinterpret_cast
<
const
__nv_bfloat162
*>
(
k_ptr
+
dim_base
);
#pragma unroll
for
(
int
ii
=
0
;
ii
<
DIMS_PER_THREAD
/
2
;
++
ii
)
{
const
float2
qf
=
__bfloat1622float2
(
q2
[
ii
]);
const
float2
kf
=
__bfloat1622float2
(
k2
[
ii
]);
qk
=
fmaf
(
qf
.
x
,
kf
.
x
,
qk
);
qk
=
fmaf
(
qf
.
y
,
kf
.
y
,
qk
);
}
}
else
#endif
{
#pragma unroll
for
(
int
i
=
0
;
i
<
DIMS_PER_THREAD
;
++
i
)
{
const
int
dim
=
lane
*
DIMS_PER_THREAD
+
i
;
qk
=
fmaf
(
q_reg
[
i
],
static_cast
<
float
>
(
k_ptr
[
dim
]),
qk
);
}
}
}
qk
=
op
::
paged_attention
::
cuda
::
warpReduceSum
(
qk
);
if
(
lane
==
0
)
{
float
score
=
token_unmasked
?
(
qk
*
scale_log2
)
:
-
INFINITY
;
if
(
token_unmasked
&&
alibi_slope
!=
0.0
f
)
{
const
int
causal_limit
=
allowed_k_len
-
1
;
score
+=
(
alibi_slope
*
static_cast
<
float
>
(
kpos
-
causal_limit
))
*
kLog2e
;
}
sh_scores
[
warp_id
][
j
]
=
score
;
}
}
__syncwarp
();
// 2) tile max / sumexp
float
max_tmp
=
-
INFINITY
;
if
(
lane
<
tile_n
)
{
max_tmp
=
sh_scores
[
warp_id
][
lane
];
}
max_tmp
=
op
::
paged_attention
::
cuda
::
warpReduceMax
(
max_tmp
);
max_tmp
=
__shfl_sync
(
0xffffffff
,
max_tmp
,
0
);
tile_m
=
max_tmp
;
if
(
lane
<
tile_n
)
{
const
float
s
=
sh_scores
[
warp_id
][
lane
];
w_lane
=
(
s
==
-
INFINITY
)
?
0.0
f
:
exp2f
(
s
-
tile_m
);
}
else
{
w_lane
=
0.0
f
;
}
float
sumexp_tmp
=
op
::
paged_attention
::
cuda
::
warpReduceSum
(
w_lane
);
sumexp_tmp
=
__shfl_sync
(
0xffffffff
,
sumexp_tmp
,
0
);
tile_sumexp
=
sumexp_tmp
;
// 3) weighted V for this tile
float
acc_tile
[
DIMS_PER_THREAD
];
#pragma unroll
for
(
int
i
=
0
;
i
<
DIMS_PER_THREAD
;
++
i
)
{
acc_tile
[
i
]
=
0.0
f
;
}
if
(
tile_sumexp
>
0.0
f
)
{
for
(
int
j
=
0
;
j
<
tile_n
;
++
j
)
{
const
float
wj
=
__shfl_sync
(
0xffffffff
,
w_lane
,
j
);
const
Tdata
*
v_ptr
=
&
sh_v
[
buf
][
j
][
0
];
#if defined(__CUDA_ARCH__)
if
constexpr
(
std
::
is_same_v
<
Tdata
,
half
>
)
{
const
int
dim_base
=
lane
*
DIMS_PER_THREAD
;
const
half2
*
v2
=
reinterpret_cast
<
const
half2
*>
(
v_ptr
+
dim_base
);
#pragma unroll
for
(
int
jj
=
0
;
jj
<
DIMS_PER_THREAD
/
2
;
++
jj
)
{
const
float2
vf
=
__half22float2
(
v2
[
jj
]);
acc_tile
[
jj
*
2
+
0
]
+=
wj
*
vf
.
x
;
acc_tile
[
jj
*
2
+
1
]
+=
wj
*
vf
.
y
;
}
}
else
if
constexpr
(
std
::
is_same_v
<
Tdata
,
__nv_bfloat16
>
)
{
const
int
dim_base
=
lane
*
DIMS_PER_THREAD
;
const
__nv_bfloat162
*
v2
=
reinterpret_cast
<
const
__nv_bfloat162
*>
(
v_ptr
+
dim_base
);
#pragma unroll
for
(
int
jj
=
0
;
jj
<
DIMS_PER_THREAD
/
2
;
++
jj
)
{
const
float2
vf
=
__bfloat1622float2
(
v2
[
jj
]);
acc_tile
[
jj
*
2
+
0
]
+=
wj
*
vf
.
x
;
acc_tile
[
jj
*
2
+
1
]
+=
wj
*
vf
.
y
;
}
}
else
#endif
{
#pragma unroll
for
(
int
i
=
0
;
i
<
DIMS_PER_THREAD
;
++
i
)
{
const
int
dim
=
lane
*
DIMS_PER_THREAD
+
i
;
acc_tile
[
i
]
+=
wj
*
static_cast
<
float
>
(
v_ptr
[
dim
]);
}
}
}
}
// 4) merge tile into running (m, l, acc)
if
(
lane
==
0
)
{
if
(
tile_sumexp
>
0.0
f
&&
tile_m
!=
-
INFINITY
)
{
const
float
m_new
=
fmaxf
(
m
,
tile_m
);
alpha
=
exp2f
(
m
-
m_new
);
beta
=
exp2f
(
tile_m
-
m_new
);
l
=
l
*
alpha
+
tile_sumexp
*
beta
;
m
=
m_new
;
}
else
{
alpha
=
1.0
f
;
beta
=
0.0
f
;
}
}
alpha
=
__shfl_sync
(
0xffffffff
,
alpha
,
0
);
beta
=
__shfl_sync
(
0xffffffff
,
beta
,
0
);
#pragma unroll
for
(
int
i
=
0
;
i
<
DIMS_PER_THREAD
;
++
i
)
{
acc
[
i
]
=
acc
[
i
]
*
alpha
+
beta
*
acc_tile
[
i
];
}
}
__syncthreads
();
const
int
prefetch_tile
=
tile_idx
+
STAGES
;
if
(
prefetch_tile
<
num_tiles
)
{
const
int
token_prefetch
=
prefetch_tile
*
TOKENS_PER_TILE
;
const
int
prefetch_n
=
min
(
TOKENS_PER_TILE
,
token_end
-
token_prefetch
);
for
(
int
li
=
tid
;
li
<
LOADS_PER_TILE
;
li
+=
blockDim
.
x
)
{
const
int
tok
=
li
/
CHUNKS
;
const
int
chunk
=
li
-
tok
*
CHUNKS
;
const
int
off
=
chunk
*
CHUNK_ELEMS
;
if
(
tok
<
prefetch_n
)
{
const
Tdata
*
k_src
=
k_base
+
static_cast
<
int64_t
>
(
token_prefetch
+
tok
)
*
k_row_stride
+
off
;
const
Tdata
*
v_src
=
v_base
+
static_cast
<
int64_t
>
(
token_prefetch
+
tok
)
*
v_row_stride
+
off
;
op
::
paged_attention
::
cuda
::
cpAsyncCaSharedGlobal16
(
&
sh_k
[
buf
][
tok
][
off
],
k_src
);
op
::
paged_attention
::
cuda
::
cpAsyncCaSharedGlobal16
(
&
sh_v
[
buf
][
tok
][
off
],
v_src
);
}
else
{
reinterpret_cast
<
uint4
*>
(
&
sh_k
[
buf
][
tok
][
off
])[
0
]
=
make_uint4
(
0
,
0
,
0
,
0
);
reinterpret_cast
<
uint4
*>
(
&
sh_v
[
buf
][
tok
][
off
])[
0
]
=
make_uint4
(
0
,
0
,
0
,
0
);
}
}
op
::
paged_attention
::
cuda
::
cpAsyncCommit
();
++
pending_groups
;
}
if
(
tile_idx
+
1
<
num_tiles
)
{
int
desired_pending2
=
pending_groups
-
1
;
if
(
desired_pending2
<
0
)
{
desired_pending2
=
0
;
}
if
(
desired_pending2
>
(
STAGES
-
1
))
{
desired_pending2
=
(
STAGES
-
1
);
}
op
::
paged_attention
::
cuda
::
cpAsyncWaitGroupRt
(
desired_pending2
);
pending_groups
=
desired_pending2
;
__syncthreads
();
}
}
op
::
paged_attention
::
cuda
::
cpAsyncWaitAll
();
__syncthreads
();
}
if
(
is_active
)
{
const
size_t
idx
=
static_cast
<
size_t
>
(
split_idx
)
*
n
+
base
;
if
(
lane
==
0
)
{
partial_m
[
idx
]
=
m
;
partial_l
[
idx
]
=
l
;
}
#pragma unroll
for
(
int
i
=
0
;
i
<
DIMS_PER_THREAD
;
++
i
)
{
const
int
dim
=
lane
*
DIMS_PER_THREAD
+
i
;
partial_acc
[
idx
*
HEAD_SIZE
+
dim
]
=
acc
[
i
];
}
}
}
template
<
typename
Tdata
,
int
HEAD_SIZE
>
__device__
void
PagedAttentionPrefillSplitKvCombineWarpKernel
(
Tdata
*
out_
,
const
float
*
partial_acc
,
// [num_splits, total_q_tokens, num_heads, head_size]
const
float
*
partial_m
,
// [num_splits, total_q_tokens, num_heads]
const
float
*
partial_l
,
// [num_splits, total_q_tokens, num_heads]
int
num_splits
,
size_t
total_q_tokens
,
ptrdiff_t
o_stride
,
ptrdiff_t
o_head_stride
)
{
const
int
head_idx
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
int
token_idx
=
static_cast
<
int
>
(
blockIdx
.
y
);
const
int
lane
=
threadIdx
.
x
;
constexpr
int
kWarpSize
=
32
;
static_assert
(
HEAD_SIZE
%
kWarpSize
==
0
,
"HEAD_SIZE must be divisible by 32."
);
constexpr
int
DIMS_PER_THREAD
=
HEAD_SIZE
/
kWarpSize
;
const
int
num_heads
=
gridDim
.
x
;
const
size_t
n
=
total_q_tokens
*
static_cast
<
size_t
>
(
num_heads
);
const
size_t
base
=
static_cast
<
size_t
>
(
token_idx
)
*
static_cast
<
size_t
>
(
num_heads
)
+
static_cast
<
size_t
>
(
head_idx
);
float
m
=
-
INFINITY
;
if
(
lane
==
0
)
{
for
(
int
s
=
0
;
s
<
num_splits
;
++
s
)
{
m
=
fmaxf
(
m
,
partial_m
[
static_cast
<
size_t
>
(
s
)
*
n
+
base
]);
}
}
m
=
__shfl_sync
(
0xffffffff
,
m
,
0
);
float
l
=
0.0
f
;
if
(
lane
==
0
)
{
for
(
int
s
=
0
;
s
<
num_splits
;
++
s
)
{
const
float
ms
=
partial_m
[
static_cast
<
size_t
>
(
s
)
*
n
+
base
];
const
float
ls
=
partial_l
[
static_cast
<
size_t
>
(
s
)
*
n
+
base
];
if
(
ls
>
0.0
f
)
{
l
+=
ls
*
exp2f
(
ms
-
m
);
}
}
}
l
=
__shfl_sync
(
0xffffffff
,
l
,
0
);
const
float
inv_l
=
1.0
f
/
(
l
+
1e-6
f
);
Tdata
*
out_ptr
=
out_
+
static_cast
<
int64_t
>
(
token_idx
)
*
o_stride
+
static_cast
<
int64_t
>
(
head_idx
)
*
o_head_stride
;
#pragma unroll
for
(
int
i
=
0
;
i
<
DIMS_PER_THREAD
;
++
i
)
{
const
int
dim
=
lane
*
DIMS_PER_THREAD
+
i
;
float
acc
=
0.0
f
;
for
(
int
s
=
0
;
s
<
num_splits
;
++
s
)
{
const
float
ms
=
partial_m
[
static_cast
<
size_t
>
(
s
)
*
n
+
base
];
const
float
w
=
exp2f
(
ms
-
m
);
acc
+=
partial_acc
[(
static_cast
<
size_t
>
(
s
)
*
n
+
base
)
*
HEAD_SIZE
+
dim
]
*
w
;
}
const
float
o
=
acc
*
inv_l
;
if
constexpr
(
std
::
is_same_v
<
Tdata
,
half
>
)
{
out_ptr
[
dim
]
=
__float2half_rn
(
o
);
}
else
if
constexpr
(
std
::
is_same_v
<
Tdata
,
__nv_bfloat16
>
)
{
out_ptr
[
dim
]
=
__float2bfloat16_rn
(
o
);
}
else
{
out_ptr
[
dim
]
=
static_cast
<
Tdata
>
(
o
);
}
}
}
// Variant for large K tile where (K+V) shared memory would exceed the per-block limit on some GPUs.
// We keep K in shared memory for reuse across warps, but load V directly from global memory.
template
<
typename
Tindex
,
typename
Tdata
,
int
HEAD_SIZE
,
int
BLOCK_M
,
int
BLOCK_N
>
__device__
void
PagedAttentionPrefillWarpCtaKernelKOnly
(
Tdata
*
out_
,
const
Tdata
*
q_
,
const
Tdata
*
k_cache_
,
const
Tdata
*
v_cache_
,
const
Tindex
*
block_tables_
,
const
int64_t
*
total_kv_lens_
,
const
int64_t
*
cu_seqlens_q_
,
const
float
*
alibi_slopes_
,
size_t
num_kv_heads
,
float
scale
,
size_t
max_num_blocks_per_seq
,
size_t
page_block_size
,
ptrdiff_t
block_table_batch_stride
,
ptrdiff_t
q_stride
,
ptrdiff_t
q_head_stride
,
ptrdiff_t
k_batch_stride
,
ptrdiff_t
k_row_stride
,
ptrdiff_t
k_head_stride
,
ptrdiff_t
v_batch_stride
,
ptrdiff_t
v_row_stride
,
ptrdiff_t
v_head_stride
,
ptrdiff_t
o_stride
,
ptrdiff_t
o_head_stride
)
{
static_assert
(
HEAD_SIZE
==
64
||
HEAD_SIZE
==
128
,
"Only head_size 64/128 supported in v0.4."
);
static_assert
(
BLOCK_M
>
0
&&
BLOCK_M
<=
16
,
"BLOCK_M must be <=16."
);
static_assert
(
BLOCK_N
>
0
&&
BLOCK_N
<=
128
,
"BLOCK_N must be <=128."
);
constexpr
int
kWarpSize
=
32
;
constexpr
int
DIMS_PER_THREAD
=
HEAD_SIZE
/
kWarpSize
;
static_assert
(
HEAD_SIZE
%
kWarpSize
==
0
,
"HEAD_SIZE must be divisible by 32."
);
const
int
lane
=
threadIdx
.
x
&
(
kWarpSize
-
1
);
const
int
warp_id
=
threadIdx
.
x
/
kWarpSize
;
if
(
warp_id
>=
BLOCK_M
)
{
return
;
}
const
int
head_idx
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
int
seq_idx
=
static_cast
<
int
>
(
blockIdx
.
y
);
const
int
m_block
=
static_cast
<
int
>
(
blockIdx
.
z
);
const
int64_t
q_start
=
cu_seqlens_q_
[
seq_idx
];
const
int64_t
q_end
=
cu_seqlens_q_
[
seq_idx
+
1
];
const
int
q_len
=
static_cast
<
int
>
(
q_end
-
q_start
);
if
(
q_len
<=
0
)
{
return
;
}
const
int
m_start
=
m_block
*
BLOCK_M
;
const
int
q_token_local
=
m_start
+
warp_id
;
// IMPORTANT: do not early-return for a subset of warps in this CTA because we use __syncthreads()
// later. Tail tiles are handled by masking inactive warps.
if
(
m_start
>=
q_len
)
{
return
;
// uniform across the CTA
}
const
bool
is_active
=
(
q_token_local
<
q_len
);
const
int
kv_len_total
=
static_cast
<
int
>
(
total_kv_lens_
[
seq_idx
]);
const
int
history_len
=
kv_len_total
-
q_len
;
const
int
allowed_k_len
=
is_active
?
(
history_len
+
q_token_local
+
1
)
:
0
;
const
int
num_heads
=
gridDim
.
x
;
const
int
num_queries_per_kv
=
num_heads
/
static_cast
<
int
>
(
num_kv_heads
);
const
int
kv_head_idx
=
head_idx
/
num_queries_per_kv
;
const
float
alibi_slope
=
(
alibi_slopes_
==
nullptr
)
?
0.0
f
:
alibi_slopes_
[
head_idx
];
constexpr
float
kLog2e
=
1.4426950408889634
f
;
const
float
scale_log2
=
scale
*
kLog2e
;
int64_t
q_token
=
q_start
;
if
(
is_active
)
{
q_token
+=
static_cast
<
int64_t
>
(
q_token_local
);
}
const
Tindex
*
block_table
=
block_tables_
+
static_cast
<
int64_t
>
(
seq_idx
)
*
static_cast
<
int64_t
>
(
block_table_batch_stride
);
const
Tdata
*
q_ptr
=
nullptr
;
Tdata
*
out_ptr
=
nullptr
;
if
(
is_active
)
{
q_ptr
=
q_
+
q_token
*
q_stride
+
static_cast
<
int64_t
>
(
head_idx
)
*
q_head_stride
;
out_ptr
=
out_
+
q_token
*
o_stride
+
static_cast
<
int64_t
>
(
head_idx
)
*
o_head_stride
;
}
float
q_reg
[
DIMS_PER_THREAD
];
float
acc
[
DIMS_PER_THREAD
];
#pragma unroll
for
(
int
i
=
0
;
i
<
DIMS_PER_THREAD
;
++
i
)
{
const
int
dim
=
lane
*
DIMS_PER_THREAD
+
i
;
q_reg
[
i
]
=
is_active
?
static_cast
<
float
>
(
q_ptr
[
dim
])
:
0.0
f
;
acc
[
i
]
=
0.0
f
;
}
#if defined(__CUDA_ARCH__)
float2
q_reg2
[
DIMS_PER_THREAD
/
2
];
#pragma unroll
for
(
int
j
=
0
;
j
<
DIMS_PER_THREAD
/
2
;
++
j
)
{
q_reg2
[
j
]
=
make_float2
(
q_reg
[
j
*
2
+
0
],
q_reg
[
j
*
2
+
1
]);
}
#endif
float
m
=
-
INFINITY
;
float
l
=
0.0
f
;
const
int
max_q_in_tile
=
min
(
m_start
+
BLOCK_M
,
q_len
);
const
int
max_allowed_k_len
=
min
(
history_len
+
max_q_in_tile
,
kv_len_total
);
__shared__
int32_t
s_phys
[
BLOCK_N
];
__shared__
int32_t
s_off
[
BLOCK_N
];
__shared__
__align__
(
16
)
Tdata
s_k
[
BLOCK_N
*
HEAD_SIZE
];
const
int
pbs
=
static_cast
<
int
>
(
page_block_size
);
for
(
int
k_base
=
0
;
k_base
<
max_allowed_k_len
;
k_base
+=
BLOCK_N
)
{
const
int
tile_n
=
min
(
BLOCK_N
,
max_allowed_k_len
-
k_base
);
for
(
int
t
=
threadIdx
.
x
;
t
<
tile_n
;
t
+=
blockDim
.
x
)
{
const
int
kpos
=
k_base
+
t
;
const
int
page
=
(
pbs
==
256
)
?
(
kpos
>>
8
)
:
(
kpos
/
pbs
);
const
int
off
=
(
pbs
==
256
)
?
(
kpos
&
255
)
:
(
kpos
-
page
*
pbs
);
const
int32_t
phys
=
static_cast
<
int32_t
>
(
block_table
[
page
]);
s_phys
[
t
]
=
phys
;
s_off
[
t
]
=
off
;
}
__syncthreads
();
const
int
tile_elems
=
tile_n
*
HEAD_SIZE
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
tile_elems
;
idx
+=
blockDim
.
x
)
{
const
int
t
=
idx
/
HEAD_SIZE
;
const
int
dim
=
idx
-
t
*
HEAD_SIZE
;
const
int32_t
phys
=
s_phys
[
t
];
const
int32_t
off
=
s_off
[
t
];
const
Tdata
*
k_base_ptr
=
k_cache_
+
static_cast
<
int64_t
>
(
phys
)
*
k_batch_stride
+
static_cast
<
int64_t
>
(
off
)
*
k_row_stride
+
static_cast
<
int64_t
>
(
kv_head_idx
)
*
k_head_stride
;
s_k
[
t
*
HEAD_SIZE
+
dim
]
=
k_base_ptr
[
dim
];
}
__syncthreads
();
for
(
int
t
=
0
;
t
<
tile_n
;
++
t
)
{
const
int
kpos
=
k_base
+
t
;
if
(
kpos
>=
allowed_k_len
)
{
break
;
}
const
Tdata
*
k_ptr
=
s_k
+
t
*
HEAD_SIZE
;
const
int32_t
phys
=
s_phys
[
t
];
const
int32_t
off
=
s_off
[
t
];
const
Tdata
*
v_ptr
=
v_cache_
+
static_cast
<
int64_t
>
(
phys
)
*
v_batch_stride
+
static_cast
<
int64_t
>
(
off
)
*
v_row_stride
+
static_cast
<
int64_t
>
(
kv_head_idx
)
*
v_head_stride
;
float
qk
=
0.0
f
;
#if defined(__CUDA_ARCH__)
if
constexpr
(
std
::
is_same_v
<
Tdata
,
half
>
)
{
const
int
dim_base
=
lane
*
DIMS_PER_THREAD
;
const
half2
*
k2
=
reinterpret_cast
<
const
half2
*>
(
k_ptr
+
dim_base
);
#pragma unroll
for
(
int
j
=
0
;
j
<
DIMS_PER_THREAD
/
2
;
++
j
)
{
const
float2
qf
=
q_reg2
[
j
];
const
float2
kf
=
__half22float2
(
k2
[
j
]);
qk
+=
qf
.
x
*
kf
.
x
+
qf
.
y
*
kf
.
y
;
}
}
else
if
constexpr
(
std
::
is_same_v
<
Tdata
,
__nv_bfloat16
>
)
{
const
int
dim_base
=
lane
*
DIMS_PER_THREAD
;
const
__nv_bfloat162
*
k2
=
reinterpret_cast
<
const
__nv_bfloat162
*>
(
k_ptr
+
dim_base
);
#pragma unroll
for
(
int
j
=
0
;
j
<
DIMS_PER_THREAD
/
2
;
++
j
)
{
const
float2
qf
=
q_reg2
[
j
];
const
float2
kf
=
__bfloat1622float2
(
k2
[
j
]);
qk
+=
qf
.
x
*
kf
.
x
+
qf
.
y
*
kf
.
y
;
}
}
else
#endif
{
#pragma unroll
for
(
int
i
=
0
;
i
<
DIMS_PER_THREAD
;
++
i
)
{
const
int
dim
=
lane
*
DIMS_PER_THREAD
+
i
;
qk
+=
q_reg
[
i
]
*
static_cast
<
float
>
(
k_ptr
[
dim
]);
}
}
qk
=
op
::
paged_attention
::
cuda
::
warpReduceSum
(
qk
);
float
alpha
=
1.0
f
;
float
beta
=
0.0
f
;
if
(
lane
==
0
)
{
float
score
=
qk
*
scale_log2
;
if
(
alibi_slope
!=
0.0
f
)
{
score
+=
(
alibi_slope
*
static_cast
<
float
>
(
kpos
-
(
allowed_k_len
-
1
)))
*
kLog2e
;
}
const
float
m_new
=
fmaxf
(
m
,
score
);
alpha
=
exp2f
(
m
-
m_new
);
beta
=
exp2f
(
score
-
m_new
);
l
=
l
*
alpha
+
beta
;
m
=
m_new
;
}
alpha
=
__shfl_sync
(
0xffffffff
,
alpha
,
0
);
beta
=
__shfl_sync
(
0xffffffff
,
beta
,
0
);
#if defined(__CUDA_ARCH__)
if
constexpr
(
std
::
is_same_v
<
Tdata
,
half
>
)
{
const
int
dim_base
=
lane
*
DIMS_PER_THREAD
;
const
half2
*
v2
=
reinterpret_cast
<
const
half2
*>
(
v_ptr
+
dim_base
);
#pragma unroll
for
(
int
j
=
0
;
j
<
DIMS_PER_THREAD
/
2
;
++
j
)
{
const
float2
vf
=
__half22float2
(
v2
[
j
]);
acc
[
j
*
2
+
0
]
=
acc
[
j
*
2
+
0
]
*
alpha
+
beta
*
vf
.
x
;
acc
[
j
*
2
+
1
]
=
acc
[
j
*
2
+
1
]
*
alpha
+
beta
*
vf
.
y
;
}
}
else
if
constexpr
(
std
::
is_same_v
<
Tdata
,
__nv_bfloat16
>
)
{
const
int
dim_base
=
lane
*
DIMS_PER_THREAD
;
const
__nv_bfloat162
*
v2
=
reinterpret_cast
<
const
__nv_bfloat162
*>
(
v_ptr
+
dim_base
);
#pragma unroll
for
(
int
j
=
0
;
j
<
DIMS_PER_THREAD
/
2
;
++
j
)
{
const
float2
vf
=
__bfloat1622float2
(
v2
[
j
]);
acc
[
j
*
2
+
0
]
=
acc
[
j
*
2
+
0
]
*
alpha
+
beta
*
vf
.
x
;
acc
[
j
*
2
+
1
]
=
acc
[
j
*
2
+
1
]
*
alpha
+
beta
*
vf
.
y
;
}
}
else
#endif
{
#pragma unroll
for
(
int
i
=
0
;
i
<
DIMS_PER_THREAD
;
++
i
)
{
const
int
dim
=
lane
*
DIMS_PER_THREAD
+
i
;
const
float
v_val
=
static_cast
<
float
>
(
v_ptr
[
dim
]);
acc
[
i
]
=
acc
[
i
]
*
alpha
+
beta
*
v_val
;
}
}
}
__syncthreads
();
}
float
inv_l
=
0.0
f
;
if
(
lane
==
0
)
{
inv_l
=
1.0
f
/
(
l
+
1e-6
f
);
}
inv_l
=
__shfl_sync
(
0xffffffff
,
inv_l
,
0
);
#pragma unroll
for
(
int
i
=
0
;
i
<
DIMS_PER_THREAD
;
++
i
)
{
const
int
dim
=
lane
*
DIMS_PER_THREAD
+
i
;
const
float
out_val
=
acc
[
i
]
*
inv_l
;
if
(
!
is_active
)
{
continue
;
}
if
constexpr
(
std
::
is_same_v
<
Tdata
,
half
>
)
{
out_ptr
[
dim
]
=
__float2half_rn
(
out_val
);
}
else
if
constexpr
(
std
::
is_same_v
<
Tdata
,
__nv_bfloat16
>
)
{
out_ptr
[
dim
]
=
__float2bfloat16_rn
(
out_val
);
}
else
{
out_ptr
[
dim
]
=
static_cast
<
Tdata
>
(
out_val
);
}
}
}
// TensorCore (WMMA) score kernel (v0.4 experimental):
// - Target shape: head_dim=128, page_block_size=256, fp16.
// - Compute QK^T with WMMA into shared memory, then reuse the existing online-softmax + V accumulation
// pattern (SIMT) per query row.
//
// Notes:
// - This is a correctness-first kernel. It doesn't yet use MMA for PV (P * V) update.
// - We keep the same grid mapping as other prefill kernels: blockIdx = (head, seq, m_block).
template <int kWarpSize, int kBlockN, int kHeadDim, int kDimsPerThread>
__device__ __forceinline__ void PagedAttentionPrefillMmaScoreUpdateRow(
    int lane, int k_base, int allowed_k_len,
    const float *scores_row, // [kBlockN]
    const half *v_tile,      // [kBlockN, kHeadDim]
    float scale_log2, float alibi_slope_log2,
    float &m, float &l,
    float *acc) { // [kDimsPerThread]
    // Max over keys in this tile.
    float local_max = -INFINITY;
    for (int t = lane; t < kBlockN; t += kWarpSize) {
        const int kpos = k_base + t;
        if (kpos >= allowed_k_len) {
            continue;
        }
        float score = scores_row[t] * scale_log2;
        if (alibi_slope_log2 != 0.0f) {
            score += alibi_slope_log2 * static_cast<float>(kpos - (allowed_k_len - 1));
        }
        local_max = fmaxf(local_max, score);
    }
    float tile_m = op::paged_attention::cuda::warpReduceMax(local_max);
    tile_m = __shfl_sync(0xffffffff, tile_m, 0);
    // Sumexp + weighted V over keys in this tile, partitioned by lanes.
    float sumexp_lane = 0.0f;
    float acc_tile[kDimsPerThread] = {0.0f, 0.0f, 0.0f, 0.0f};
    const int dim_base = lane * kDimsPerThread;
    if (tile_m != -INFINITY) {
        for (int t = lane; t < kBlockN; t += kWarpSize) {
            const int kpos = k_base + t;
            if (kpos >= allowed_k_len) {
                continue;
            }
            float score = scores_row[t] * scale_log2;
            if (alibi_slope_log2 != 0.0f) {
                score += alibi_slope_log2 * static_cast<float>(kpos - (allowed_k_len - 1));
            }
            const float w = exp2f(score - tile_m);
            sumexp_lane += w;
            const half *v_ptr = v_tile + t * kHeadDim + dim_base;
            const half2 *v2 = reinterpret_cast<const half2 *>(v_ptr);
#pragma unroll
            for (int j = 0; j < kDimsPerThread / 2; ++j) {
                const float2 vf = __half22float2(v2[j]);
                acc_tile[j * 2 + 0] += w * vf.x;
                acc_tile[j * 2 + 1] += w * vf.y;
            }
        }
    }
    float tile_sumexp = op::paged_attention::cuda::warpReduceSum(sumexp_lane);
    tile_sumexp = __shfl_sync(0xffffffff, tile_sumexp, 0);
    float alpha = 1.0f;
    float beta = 0.0f;
    if (lane == 0) {
        if (tile_sumexp > 0.0f && tile_m != -INFINITY) {
            const float m_new = fmaxf(m, tile_m);
            alpha = exp2f(m - m_new);
            beta = exp2f(tile_m - m_new);
            l = l * alpha + tile_sumexp * beta;
            m = m_new;
        } else {
            alpha = 1.0f;
            beta = 0.0f;
        }
    }
    alpha = __shfl_sync(0xffffffff, alpha, 0);
    beta = __shfl_sync(0xffffffff, beta, 0);
#pragma unroll
    for (int i = 0; i < kDimsPerThread; ++i) {
        acc[i] = acc[i] * alpha + beta * acc_tile[i];
    }
}
template <int kWarpSize, int kHeadDim, int kDimsPerThread>
__device__ __forceinline__ void PagedAttentionPrefillMmaScoreWriteRow(
    int lane, bool active, int q_token_local, int64_t q_start, int head_idx,
    half *out_, ptrdiff_t o_stride, ptrdiff_t o_head_stride,
    float l, const float *acc) { // [kDimsPerThread]
    if (!active) {
        return;
    }
    float inv_l = 0.0f;
    if (lane == 0) {
        inv_l = 1.0f / (l + 1e-6f);
    }
    inv_l = __shfl_sync(0xffffffff, inv_l, 0);
    const int64_t q_token = q_start + static_cast<int64_t>(q_token_local);
    half *out_ptr = out_ + q_token * o_stride + static_cast<int64_t>(head_idx) * o_head_stride;
#pragma unroll
    for (int i = 0; i < kDimsPerThread; ++i) {
        const int dim = lane * kDimsPerThread + i;
        out_ptr[dim] = __float2half_rn(acc[i] * inv_l);
    }
}
template <typename Tindex>
__device__ void PagedAttentionPrefillWarpCta8MmaHd128Kernel(
    half *out_, const half *q_, const half *k_cache_, const half *v_cache_,
    const Tindex *block_tables_, const int64_t *total_kv_lens_, const int64_t *cu_seqlens_q_,
    const float *alibi_slopes_, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride,
    ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    (void)max_num_blocks_per_seq;
    constexpr int kWarpSize = 32;
    constexpr int kWarps = 8;
    constexpr int kHeadDim = 128;
    // Extra padding in the K dimension to reduce shared-memory bank conflicts for ldmatrix / wmma loads.
    // NOTE: FA2 uses a swizzled smem layout; padding is a smaller step that keeps our code simple.
    constexpr int kHeadDimSmem = 136; // must be a multiple of 8 for wmma::load_matrix_sync
    constexpr int kBlockM = 16;       // 2 rows per warp
    // Keep static shared memory <= 48KB for compatibility with build targets that cap SMEM at 0xC000.
    // kBlockN=64 brings s_q+s_k+s_v+s_scores+s_phys/s_off down to ~41KB.
    constexpr int kBlockN = 64;
    constexpr int kDimsPerThread = kHeadDim / kWarpSize;
    static_assert(kHeadDim % kWarpSize == 0, "head_dim must be divisible by 32.");
    const int lane = threadIdx.x & (kWarpSize - 1);
    const int warp_id = threadIdx.x / kWarpSize;
    if (warp_id >= kWarps) {
        return;
    }
    const int head_idx = static_cast<int>(blockIdx.x);
    const int seq_idx = static_cast<int>(blockIdx.y);
    const int m_block = static_cast<int>(blockIdx.z);
    const int64_t q_start = cu_seqlens_q_[seq_idx];
    const int64_t q_end = cu_seqlens_q_[seq_idx + 1];
    const int q_len = static_cast<int>(q_end - q_start);
    if (q_len <= 0) {
        return;
    }
    const int m_start = m_block * kBlockM;
    // Uniform early return for empty tail tiles (avoid deadlock with __syncthreads()).
    if (m_start >= q_len) {
        return;
    }
    const int kv_len_total = static_cast<int>(total_kv_lens_[seq_idx]);
    const int history_len = kv_len_total - q_len;
    // Clamp max k length for this CTA based on the last active query row in the tile.
    const int max_q_in_tile = min(m_start + kBlockM, q_len);
    const int max_allowed_k_len = min(history_len + max_q_in_tile, kv_len_total);
    const int num_heads = gridDim.x;
    const int num_queries_per_kv = num_heads / static_cast<int>(num_kv_heads);
    const int kv_head_idx = head_idx / num_queries_per_kv;
    const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
    constexpr float kLog2e = 1.4426950408889634f;
    const float scale_log2 = scale * kLog2e;
    const float alibi_slope_log2 = alibi_slope * kLog2e;
    const int pbs = static_cast<int>(page_block_size);
    const Tindex *block_table = block_tables_ + static_cast<int64_t>(seq_idx) * static_cast<int64_t>(block_table_batch_stride);
    // Shared memory:
    // - s_q: [kBlockM, kHeadDimSmem] (padded)
    // - s_k/s_v: [kBlockN, kHeadDim]
    // - s_scores: [kBlockM, kBlockN] raw dot products (no scale / alibi)
    __shared__ __align__(16) half s_q[kBlockM * kHeadDimSmem];
    __shared__ int32_t s_phys[kBlockN];
    __shared__ int32_t s_off[kBlockN];
    __shared__ __align__(16) half s_k[kBlockN * kHeadDimSmem];
    __shared__ __align__(16) half s_v[kBlockN * kHeadDimSmem];
    __shared__ __align__(16) float s_scores[kBlockM * kBlockN];
    // Load Q tile (pad inactive rows with 0).
    for (int idx = threadIdx.x; idx < kBlockM * kHeadDim; idx += blockDim.x) {
        const int r = idx / kHeadDim;
        const int d = idx - r * kHeadDim;
        const int q_token_local = m_start + r;
        if (q_token_local < q_len) {
            const int64_t q_token = q_start + static_cast<int64_t>(q_token_local);
            const half *q_ptr = q_ + q_token * q_stride + static_cast<int64_t>(head_idx) * q_head_stride;
            s_q[r * kHeadDimSmem + d] = q_ptr[d];
        } else {
            s_q[r * kHeadDimSmem + d] = __float2half_rn(0.0f);
        }
    }
    __syncthreads();
    // Two rows per warp: row0=warp_id, row1=warp_id+kWarps.
    const int row0 = warp_id;
    const int row1 = warp_id + kWarps;
    const bool active0 = (row0 < kBlockM) && ((m_start + row0) < q_len);
    const bool active1 = (row1 < kBlockM) && ((m_start + row1) < q_len);
    const int allowed0 = active0 ? min(history_len + (m_start + row0) + 1, kv_len_total) : 0;
    const int allowed1 = active1 ? min(history_len + (m_start + row1) + 1, kv_len_total) : 0;
    float m0 = -INFINITY, l0 = 0.0f;
    float m1 = -INFINITY, l1 = 0.0f;
    float acc0[kDimsPerThread] = {0.0f, 0.0f, 0.0f, 0.0f};
    float acc1[kDimsPerThread] = {0.0f, 0.0f, 0.0f, 0.0f};
    // Iterate over K/V tiles.
    for (int k_base = 0; k_base < max_allowed_k_len; k_base += kBlockN) {
        // Map logical k positions to physical blocks for this tile (pad the tail with -1).
        for (int t = threadIdx.x; t < kBlockN; t += blockDim.x) {
            const int kpos = k_base + t;
            if (kpos < max_allowed_k_len) {
                const int page = (pbs == 256) ? (kpos >> 8) : (kpos / pbs);
                const int off = (pbs == 256) ? (kpos & 255) : (kpos - page * pbs);
                s_phys[t] = static_cast<int32_t>(block_table[page]);
                s_off[t] = off;
            } else {
                s_phys[t] = -1;
                s_off[t] = 0;
            }
        }
        __syncthreads();
        // Load K/V tile into shared memory (pad with 0 for inactive tokens).
        for (int idx = threadIdx.x; idx < kBlockN * kHeadDim; idx += blockDim.x) {
            const int t = idx / kHeadDim;
            const int d = idx - t * kHeadDim;
            const int32_t phys = s_phys[t];
            if (phys >= 0) {
                const int32_t off = s_off[t];
                const half *k_ptr = k_cache_ + static_cast<int64_t>(phys) * k_batch_stride + static_cast<int64_t>(off) * k_row_stride + static_cast<int64_t>(kv_head_idx) * k_head_stride;
                const half *v_ptr = v_cache_ + static_cast<int64_t>(phys) * v_batch_stride + static_cast<int64_t>(off) * v_row_stride + static_cast<int64_t>(kv_head_idx) * v_head_stride;
                s_k[t * kHeadDimSmem + d] = k_ptr[d];
                s_v[t * kHeadDimSmem + d] = v_ptr[d];
            } else {
                s_k[t * kHeadDimSmem + d] = __float2half_rn(0.0f);
                s_v[t * kHeadDimSmem + d] = __float2half_rn(0.0f);
            }
        }
        __syncthreads();
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)
        // WMMA: each warp computes scores for 16 keys (one 16-column slice of the K tile) across all 16 rows.
        // For kBlockN=64, only the first 4 warps participate in WMMA score computation.
        namespace wmma = nvcuda::wmma;
        constexpr int kNSub = kBlockN / 16;
        if (warp_id < kNSub) {
            wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
            wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
            wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;
            wmma::fill_fragment(c_frag, 0.0f);
            const int n_sub = warp_id; // [0, kNSub)
            const half *q_tile = s_q;
            const half *k_tile = s_k + (n_sub * 16) * kHeadDimSmem;
            // K loop (head_dim=128).
#pragma unroll
            for (int kk = 0; kk < (kHeadDim / 16); ++kk) {
                wmma::load_matrix_sync(a_frag, q_tile + kk * 16, kHeadDimSmem);
                wmma::load_matrix_sync(b_frag, k_tile + kk * 16, kHeadDimSmem);
                wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
            }
            float *scores_tile = s_scores + n_sub * 16;
            wmma::store_matrix_sync(scores_tile, c_frag, kBlockN, wmma::mem_row_major);
        }
#else
        // No WMMA support on this architecture: fall back to scalar dot in the existing kernels.
        // (We keep scores as 0 so this kernel is effectively incorrect; host dispatch must avoid selecting it.)
        if (threadIdx.x == 0) {
            // Intentionally empty.
        }
#endif
        __syncthreads();
        // Online softmax + V update per row handled by the same warp across tiles.
        if (row0 < kBlockM) {
            PagedAttentionPrefillMmaScoreUpdateRow<kWarpSize, kBlockN, kHeadDim, kDimsPerThread>(
                lane, k_base, allowed0, s_scores + row0 * kBlockN, s_v,
                scale_log2, alibi_slope_log2, m0, l0, acc0);
        }
        if (row1 < kBlockM) {
            PagedAttentionPrefillMmaScoreUpdateRow<kWarpSize, kBlockN, kHeadDim, kDimsPerThread>(
                lane, k_base, allowed1, s_scores + row1 * kBlockN, s_v,
                scale_log2, alibi_slope_log2, m1, l1, acc1);
        }
        __syncthreads();
    }
    // Write outputs.
    if (row0 < kBlockM) {
        PagedAttentionPrefillMmaScoreWriteRow<kWarpSize, kHeadDim, kDimsPerThread>(
            lane, active0, m_start + row0, q_start, head_idx, out_, o_stride, o_head_stride, l0, acc0);
    }
    if (row1 < kBlockM) {
        PagedAttentionPrefillMmaScoreWriteRow<kWarpSize, kHeadDim, kDimsPerThread>(
            lane, active1, m_start + row1, q_start, head_idx, out_, o_stride, o_head_stride, l1, acc1);
    }
}

} // namespace op::paged_attention_prefill::cuda

#endif
src/infiniop/ops/paged_attention_prefill/info.h
View file @ 1c18c046
...
...
@@ -3,6 +3,7 @@
#include "../../../utils.h"
#include "../../tensor.h"
#include <cstring>
#include <iostream>
#include <optional>
#include <vector>
...
...
@@ -14,21 +15,30 @@ class PagedAttentionPrefillInfo {
public:
    infiniDtype_t dtype;
    infiniDtype_t index_dtype;
    float scale;
    size_t num_seqs;
    size_t total_q_tokens;
    size_t num_heads;
    size_t num_kv_heads;
    size_t head_size;
    size_t block_size;
    size_t page_block_size;
    size_t max_num_blocks_per_seq;
    size_t total_q_tokens;
    size_t num_blocks;
    ptrdiff_t q_stride;
    ptrdiff_t q_head_stride;
    ptrdiff_t kv_block_stride;
    ptrdiff_t kv_head_stride;
    ptrdiff_t k_batch_stride;
    ptrdiff_t k_row_stride;
    ptrdiff_t k_head_stride;
    ptrdiff_t v_batch_stride;
    ptrdiff_t v_row_stride;
    ptrdiff_t v_head_stride;
    ptrdiff_t o_stride;
    ptrdiff_t o_head_stride;
    ptrdiff_t block_table_batch_stride;
    static utils::Result<PagedAttentionPrefillInfo> create(
        infiniopTensorDescriptor_t out_desc,
...
...
@@ -36,89 +46,161 @@ public:
        infiniopTensorDescriptor_t k_cache_desc,
        infiniopTensorDescriptor_t v_cache_desc,
        infiniopTensorDescriptor_t block_tables_desc,
        infiniopTensorDescriptor_t total_kv_lens_desc,
        infiniopTensorDescriptor_t cum_seqlens_q_desc,
        const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
        float scale) {
        auto dtype = q_desc->dtype();
        CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16);
        if (out_desc->dtype() != dtype || k_cache_desc->dtype() != dtype || v_cache_desc->dtype() != dtype) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        // q/out: [total_q, heads, head_dim]
        if (q_desc->ndim() != 3 || out_desc->ndim() != 3) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        // FA2 paged KV layout: [num_blocks, page_block_size, kv_heads, head_dim]
        if (k_cache_desc->ndim() != 4 || v_cache_desc->ndim() != 4) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        if (block_tables_desc->ndim() != 2 || total_kv_lens_desc->ndim() != 1 || cum_seqlens_q_desc->ndim() != 1) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        CHECK_OR_RETURN(q_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
        CHECK_OR_RETURN(out_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
        CHECK_OR_RETURN(k_cache_desc->stride(3) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
        CHECK_OR_RETURN(v_cache_desc->stride(3) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
        // Index dtypes: allow I32/I64/U32 (v0.4 roadmap allows internal conversion to I32).
        const auto block_tables_dt = block_tables_desc->dtype();
        if (!((block_tables_dt == INFINI_DTYPE_I64) || (block_tables_dt == INFINI_DTYPE_I32) || (block_tables_dt == INFINI_DTYPE_U32))) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        // Keep it simple: require total_kv_lens + cum_seqlens_q to be int64 for now
        // (matches current paged_attention_prefill signature). We will convert to int32 internally later.
        if (total_kv_lens_desc->dtype() != INFINI_DTYPE_I64 || cum_seqlens_q_desc->dtype() != INFINI_DTYPE_I64) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        CHECK_OR_RETURN(block_tables_desc->stride(1) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
        if (alibi_slopes_desc.has_value() && alibi_slopes_desc.value() != nullptr) {
            if (alibi_slopes_desc.value()->dtype() != INFINI_DTYPE_F32) {
                return INFINI_STATUS_BAD_TENSOR_DTYPE;
            }
            if (alibi_slopes_desc.value()->ndim() != 1) {
                return INFINI_STATUS_BAD_TENSOR_SHAPE;
            }
            CHECK_OR_RETURN(alibi_slopes_desc.value()->stride(0) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
        }
        const auto q_shape = q_desc->shape();
        const auto k_shape = k_cache_desc->shape();
        const size_t total_q_tokens = q_shape[0];
        const size_t num_heads = q_shape[1];
        const size_t head_size = q_shape[2];
        const size_t num_blocks = k_shape[0];
        const size_t page_block_size = k_shape[2];
        const size_t num_kv_heads = k_shape[1];
        if (head_size != 64 && head_size != 128) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        if (num_heads % num_kv_heads != 0) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        // v_cache must match the inferred K layout.
        const auto v_shape = v_cache_desc->shape();
        if (v_shape[0] != num_blocks || v_shape[3] != head_size) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        if (v_shape[1] != num_kv_heads || v_shape[2] != page_block_size) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        if (out_desc->shape()[0] != q_shape[0] || out_desc->shape()[1] != q_shape[1] || out_desc->shape()[2] != q_shape[2]) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        const size_t num_seqs = total_kv_lens_desc->shape()[0];
        if (cum_seqlens_q_desc->shape()[0] != num_seqs + 1) {
            return INFINI_STATUS_BAD_PARAM;
        }
        const size_t max_num_blocks_per_seq = block_tables_desc->shape()[1];
        // Strides (in elements)
        const ptrdiff_t q_stride = q_desc->stride(0);
        const ptrdiff_t q_head_stride = q_desc->stride(1);
        const ptrdiff_t o_stride = out_desc->stride(0);
        const ptrdiff_t o_head_stride = out_desc->stride(1);
        const ptrdiff_t k_batch_stride = k_cache_desc->stride(0);
        const ptrdiff_t k_row_stride = k_cache_desc->stride(2);
        const ptrdiff_t k_head_stride = k_cache_desc->stride(1);
        const ptrdiff_t v_batch_stride = v_cache_desc->stride(0);
        const ptrdiff_t v_row_stride = v_cache_desc->stride(2);
        const ptrdiff_t v_head_stride = v_cache_desc->stride(1);
        const ptrdiff_t block_table_batch_stride = block_tables_desc->stride(0);
        if (const char *dbg = std::getenv("INFINIOP_DEBUG_PREFILL_INFO")) {
            static bool printed = false;
            if (!printed && std::strcmp(dbg, "1") == 0) {
                const auto bt_shape = block_tables_desc->shape();
                std::fprintf(stderr,
                             "[infiniop][flash_attention_prefill][info] k_shape=[%zu,%zu,%zu,%zu] k_strides=[%td,%td,%td,%td] (row_stride=%td head_stride=%td)\n",
                             static_cast<size_t>(k_shape[0]), static_cast<size_t>(k_shape[1]),
                             static_cast<size_t>(k_shape[2]), static_cast<size_t>(k_shape[3]),
                             k_cache_desc->stride(0), k_cache_desc->stride(1),
                             k_cache_desc->stride(2), k_cache_desc->stride(3),
                             k_row_stride, k_head_stride);
                std::fprintf(stderr,
                             "[infiniop][flash_attention_prefill][info] block_tables shape=[%zu,%zu] strides=[%td,%td]\n",
                             static_cast<size_t>(bt_shape[0]), static_cast<size_t>(bt_shape[1]),
                             block_tables_desc->stride(0), block_tables_desc->stride(1));
                printed = true;
            }
        }
        return utils::Result<PagedAttentionPrefillInfo>(PagedAttentionPrefillInfo{
            dtype, block_tables_dt, scale, num_seqs, total_q_tokens, num_heads, num_kv_heads, head_size,
            page_block_size, max_num_blocks_per_seq, num_blocks,
            q_stride, q_head_stride,
            k_batch_stride, k_row_stride, k_head_stride,
            v_batch_stride, v_row_stride, v_head_stride,
            o_stride, o_head_stride, block_table_batch_stride,
        });
    }
};
} // namespace op::paged_attention_prefill
#endif
src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu
View file @ 1c18c046
#include <cuda_fp16.h>
#include <float.h>
#include <math.h>
#include <stdint.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../cuda/kernel.cuh"
// #include "paged_attention_prefill_fa2.cuh"
#include "paged_attention_prefill_nvidia.cuh"
template <typename Tdata, typename Tcompute>
infiniStatus_t launchPagedAttentionPrefill(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const int64_t *block_tables, const int64_t *seq_lens, const int64_t *cum_seq_lens_q,
    const float *alibi_slopes, const size_t num_heads, const size_t num_seqs, const size_t num_kv_heads,
    const float scale, const size_t max_num_blocks_per_seq, const size_t block_size,
    const size_t total_q_tokens, const size_t head_size,
    const ptrdiff_t kv_block_stride, const ptrdiff_t kv_head_stride,
    const ptrdiff_t q_stride, const ptrdiff_t q_head_stride,
#include "../cuda/kernel_v2.cuh"
namespace op::paged_attention_prefill::nvidia {
namespace {

constexpr size_t ceilDiv(size_t a, size_t b) {
    return (a + b - 1) / b;
}

inline const char *default_prefill_kernel(const PagedAttentionPrefillInfo &info) {
    // Heuristic auto-dispatch (v0.4):
    // - Prefer the pipelined + tile-wise softmax kernel on FA2-compatible block_size=256.
    // - Keep a conservative fallback for other shapes / older GPUs (cp.async is a no-op below SM80).
    //
    // Users can always override via INFINIOP_FLASH_PREFILL_KERNEL.
    if (info.page_block_size == 256 && (info.dtype == INFINI_DTYPE_F16 || info.dtype == INFINI_DTYPE_BF16)) {
        if (info.head_size == 128) {
            return "warpcta8pipe";
        }
        // For head_size=64 we keep the previous default until we have broader perf coverage.
    }
    return "warpcta8";
}
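The dispatch code that consumes INFINIOP_FLASH_PREFILL_KERNEL is not part of this hunk; a minimal sketch of how such an override could sit on top of default_prefill_kernel is shown below. The select_prefill_kernel helper is hypothetical and only illustrates the intent of the comment above.

// --- illustrative sketch (not part of this commit) ---
// Hypothetical helper: environment override first, heuristic default otherwise.
// Relies on <cstdlib> already being included in this translation unit.
inline const char *select_prefill_kernel(const PagedAttentionPrefillInfo &info) {
    if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_KERNEL")) {
        if (env[0] != '\0') {
            return env; // e.g. "warpcta8", "warpcta8pipe", "warpcta8mma"
        }
    }
    return default_prefill_kernel(info);
}
// --- end illustrative sketch ---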
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128Warp(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride,
    ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    // Legacy per-seq launch (kept only as a wrapper; current "warp" impl uses a global-token kernel).
    op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpKernel<Tindex, Tdata, 128>(
        out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}

template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64Warp(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride,
    ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    // Legacy per-seq launch (kept only as a wrapper; current "warp" impl uses a global-token kernel).
    op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpKernel<Tindex, Tdata, 64>(
        out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}

template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride,
    ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    // 4 warps per CTA, one warp per query token.
    op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 128, 4, 64>(
        out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}

template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride,
    ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    // 4 warps per CTA, one warp per query token.
    op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 64, 4, 128>(
        out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}

template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride,
    ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    // 8 warps per CTA, one warp per query token.
    op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 128, 8, 64>(
        out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}

template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8N128(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride,
    ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    // 8 warps per CTA, one warp per query token, tile_n=128 for fewer K stages.
    // Note: we keep K in shared memory but load V from global to stay within the per-block shared limit.
    op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelKOnly<Tindex, Tdata, 128, 8, 128>(
        out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}

template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta8(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride,
    ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    // 8 warps per CTA, one warp per query token.
    op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 64, 8, 128>(
        out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8Pipe(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride,
    ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    // 8 warps per CTA, one warp per query token, with cp.async pipelining.
    op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelined<Tindex, Tdata, 128, 8, 32, 2>(
        out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}

template <typename Tindex>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8Mma(
    half *out, const half *q, const half *k_cache, const half *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride,
    ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCta8MmaHd128Kernel<Tindex>(
        out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}

template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta8Pipe(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride,
    ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    // 8 warps per CTA, one warp per query token, with cp.async pipelining.
    op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelined<Tindex, Tdata, 64, 8, 32, 2>(
        out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8PipeSplitKv(
    float *partial_acc, float *partial_m, float *partial_l, int num_splits, size_t total_q_tokens,
    const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride,
    ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride) {
    // Encode (split_idx, m_block) into blockIdx.z to allow a single kernel launch:
    // blockIdx.z in [0, num_splits * num_m_blocks).
    const int num_m_blocks = static_cast<int>((total_q_tokens + 8 - 1) / 8);
    const int bz = static_cast<int>(blockIdx.z);
    const int split_idx = bz / num_m_blocks;
    const int m_block = bz - split_idx * num_m_blocks;
    op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv<Tindex, Tdata, 128, 8, 32, 2>(
        partial_acc, partial_m, partial_l, split_idx, num_splits, m_block, total_q_tokens,
        q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride);
}

template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta8PipeSplitKv(
    float *partial_acc, float *partial_m, float *partial_l, int num_splits, size_t total_q_tokens,
    const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride,
    ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride) {
    const int num_m_blocks = static_cast<int>((total_q_tokens + 8 - 1) / 8);
    const int bz = static_cast<int>(blockIdx.z);
    const int split_idx = bz / num_m_blocks;
    const int m_block = bz - split_idx * num_m_blocks;
    op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv<Tindex, Tdata, 64, 8, 32, 2>(
        partial_acc, partial_m, partial_l, split_idx, num_splits, m_block, total_q_tokens,
        q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride);
}

template <typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128SplitKvCombine(
    Tdata *out, const float *partial_acc, const float *partial_m, const float *partial_l,
    int num_splits, size_t total_q_tokens, ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    op::paged_attention_prefill::cuda::PagedAttentionPrefillSplitKvCombineWarpKernel<Tdata, 128>(
        out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
}

template <typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64SplitKvCombine(
    Tdata *out, const float *partial_acc, const float *partial_m, const float *partial_l,
    int num_splits, size_t total_q_tokens, ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    op::paged_attention_prefill::cuda::PagedAttentionPrefillSplitKvCombineWarpKernel<Tdata, 64>(
        out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta16(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride,
    ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    // 16 warps per CTA, one warp per query token.
    op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 128, 16, 64>(
        out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}

template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta16(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride,
    ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    // 16 warps per CTA, one warp per query token.
    op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 64, 16, 128>(
        out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata, typename Tcompute>
infiniStatus_t launch_prefill_ref(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_heads, size_t num_seqs, size_t num_kv_heads,
    size_t total_q_tokens, size_t head_size, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride,
    ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride, cudaStream_t stream) {
    if (total_q_tokens == 0 || num_heads == 0) {
        return INFINI_STATUS_SUCCESS;
    }
    const dim3 grid(static_cast<uint32_t>(total_q_tokens), static_cast<uint32_t>(num_heads), 1);
    const dim3 block(static_cast<uint32_t>(head_size), 1, 1);
    if (head_size == 64) {
        op::paged_attention_prefill::cuda::PagedAttentionPrefillReferenceKernel<Tindex, Tdata, Tcompute, 64><<<grid, block, 0, stream>>>(
            out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
            num_heads, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
            q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
            v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride, num_seqs);
        return INFINI_STATUS_SUCCESS;
    }
    if (head_size == 128) {
        op::paged_attention_prefill::cuda::PagedAttentionPrefillReferenceKernel<Tindex, Tdata, Tcompute, 128><<<grid, block, 0, stream>>>(
            out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
            num_heads, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
            q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
            v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride, num_seqs);
        return INFINI_STATUS_SUCCESS;
    }
    return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warp(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_heads, size_t num_seqs, size_t num_kv_heads,
    size_t total_q_tokens, size_t head_size, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride,
    ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride, cudaStream_t stream) {
    const dim3 block(32, 1, 1);
    // Global-token launch:
    // - dramatically reduces grid size vs the legacy (num_seqs * total_q_tokens) launch
    // - matches PagedAttention varlen (cu_seqlens) mental model better
    const dim3 grid(static_cast<uint32_t>(num_heads), static_cast<uint32_t>(total_q_tokens), 1);
    switch (head_size) {
    case 64:
        op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpGlobalKernel<Tindex, Tdata, 64><<<grid, block, 0, stream>>>(
            out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
            num_heads, num_seqs, num_kv_heads, total_q_tokens, scale,
            max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
            q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
            v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
        return INFINI_STATUS_SUCCESS;
    case 128:
        op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpGlobalKernel<Tindex, Tdata, 128><<<grid, block, 0, stream>>>(
            out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
            num_heads, num_seqs, num_kv_heads, total_q_tokens, scale,
            max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
            q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
            v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
        return INFINI_STATUS_SUCCESS;
    default:
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
}
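As an illustrative data point for the grid-size comment above (the numbers are made up, not from this commit): with num_heads = 32, num_seqs = 8 and total_q_tokens = 4096, the global-token grid is 32 x 4096 = 131,072 blocks, whereas a legacy launch proportional to num_seqs * total_q_tokens would enqueue 32 x 8 x 4096 = 1,048,576 blocks, most of which would exit immediately for tokens outside their own sequence.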
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_heads, size_t num_seqs, size_t num_kv_heads,
    size_t total_q_tokens, size_t head_size, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride,
    ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride, cudaStream_t stream) {
    constexpr int kWarps = 4;
    constexpr int kThreads = kWarps * 32;
    const dim3 block(kThreads);
    const dim3 grid(static_cast<uint32_t>(num_heads), static_cast<uint32_t>(num_seqs),
                    static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
    switch (head_size) {
    case 64:
        PagedAttentionPrefillHd64WarpCta<Tindex, Tdata><<<grid, block, 0, stream>>>(
            out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
            num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
            q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
            v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
        return INFINI_STATUS_SUCCESS;
    case 128:
        PagedAttentionPrefillHd128WarpCta<Tindex, Tdata><<<grid, block, 0, stream>>>(
            out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
            num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
            q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
            v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
        return INFINI_STATUS_SUCCESS;
    default:
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_heads, size_t num_seqs, size_t num_kv_heads,
    size_t total_q_tokens, size_t head_size, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride,
    ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride, cudaStream_t stream) {
    constexpr int kWarps = 8;
    constexpr int kThreads = kWarps * 32;
    const dim3 block(kThreads);
    const dim3 grid(static_cast<uint32_t>(num_heads), static_cast<uint32_t>(num_seqs),
                    static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
    switch (head_size) {
    case 64:
        PagedAttentionPrefillHd64WarpCta8<Tindex, Tdata><<<grid, block, 0, stream>>>(
            out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
            num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
            q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
            v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
        return INFINI_STATUS_SUCCESS;
    case 128:
        PagedAttentionPrefillHd128WarpCta8<Tindex, Tdata><<<grid, block, 0, stream>>>(
            out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
            num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
            q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
            v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
        return INFINI_STATUS_SUCCESS;
    default:
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
}

namespace op::paged_attention_prefill::nvidia {
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8pipe(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_heads, size_t num_seqs, size_t num_kv_heads,
    size_t total_q_tokens, size_t head_size, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride,
    ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride, cudaStream_t stream) {
    constexpr int kWarps = 8;
    constexpr int kThreads = kWarps * 32;
    const dim3 block(kThreads);
    const dim3 grid(static_cast<uint32_t>(num_heads), static_cast<uint32_t>(num_seqs),
                    static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
    switch (head_size) {
    case 64:
        PagedAttentionPrefillHd64WarpCta8Pipe<Tindex, Tdata><<<grid, block, 0, stream>>>(
            out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
            num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
            q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
            v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
        return INFINI_STATUS_SUCCESS;
    case 128:
        PagedAttentionPrefillHd128WarpCta8Pipe<Tindex, Tdata><<<grid, block, 0, stream>>>(
            out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
            num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
            q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
            v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
        return INFINI_STATUS_SUCCESS;
    default:
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8mma(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_heads, size_t num_seqs, size_t num_kv_heads,
    size_t total_q_tokens, size_t head_size, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride,
    ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride, cudaStream_t stream) {
    // Current WMMA kernel only supports fp16 + head_dim=128.
    if constexpr (!std::is_same_v<Tdata, half>) {
        return launch_prefill_warpcta8pipe<Tindex, Tdata>(
            out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
            num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale,
            max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
            q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
            v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride, stream);
    }
    if (head_size != 128) {
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
    // Guardrail: the current WMMA-score kernel is correctness-first and can be extremely slow on long prompts.
    // Allow power users to force it via INFINIOP_FLASH_PREFILL_MMA_FORCE=1.
    const char *force_env = std::getenv("INFINIOP_FLASH_PREFILL_MMA_FORCE");
    const bool force_mma = (force_env != nullptr) && (std::strcmp(force_env, "1") == 0);
    const size_t seqlen_k_est = max_num_blocks_per_seq * page_block_size;
    if (!force_mma && seqlen_k_est > 4096) {
        static bool warned = false;
        if (!warned) {
            std::fprintf(stderr,
                         "[infiniop][paged_attention_prefill] warpcta8mma is experimental and very slow for long seqlen_k (est=%zu). "
                         "Falling back to warpcta8pipe. Set INFINIOP_FLASH_PREFILL_MMA_FORCE=1 to override.\n",
                         seqlen_k_est);
            warned = true;
        }
        return launch_prefill_warpcta8pipe<Tindex, Tdata>(
            out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
            num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale,
            max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
            q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
            v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride, stream);
    }
    // WMMA requires SM70+. If not supported (or if we can't query), fall back to the pipelined SIMT kernel.
    int device = 0;
    cudaDeviceProp prop{};
    if (cudaGetDevice(&device) == cudaSuccess && cudaGetDeviceProperties(&prop, device) == cudaSuccess) {
        if (prop.major < 7) {
            return launch_prefill_warpcta8pipe<Tindex, Tdata>(
                out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
                num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale,
                max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
                q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
                v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride, stream);
        }
    }
    constexpr int kWarps = 8;
    constexpr int kThreads = kWarps * 32;
    const dim3 block(kThreads);
    const dim3 grid(static_cast<uint32_t>(num_heads), static_cast<uint32_t>(num_seqs),
                    static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(16))));
    PagedAttentionPrefillHd128WarpCta8Mma<Tindex><<<grid, block, 0, stream>>>(
        static_cast<half *>(out), static_cast<const half *>(q),
        static_cast<const half *>(k_cache), static_cast<const half *>(v_cache),
        block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
    return INFINI_STATUS_SUCCESS;
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8pipe_splitkv(
    float *partial_acc,
    float *partial_m,
    float *partial_l,
    int num_splits,
    Tdata *out,
    const Tdata *q,
    const Tdata *k_cache,
    const Tdata *v_cache,
    const Tindex *block_tables,
    const int64_t *total_kv_lens,
    const int64_t *cu_seqlens_q,
    const float *alibi_slopes,
    size_t num_heads,
    size_t num_seqs,
    size_t num_kv_heads,
    size_t total_q_tokens,
    size_t head_size,
    float scale,
    size_t max_num_blocks_per_seq,
    size_t page_block_size,
    ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride,
    ptrdiff_t q_head_stride,
    ptrdiff_t k_batch_stride,
    ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride,
    ptrdiff_t v_row_stride,
    ptrdiff_t v_head_stride,
    ptrdiff_t o_stride,
    ptrdiff_t o_head_stride,
    cudaStream_t stream) {
    constexpr int kMaxSplits = 8;
    if (num_splits < 1) {
        num_splits = 1;
    }
    if (num_splits > kMaxSplits) {
        num_splits = kMaxSplits;
    }
    constexpr int kWarps = 8;
    constexpr int kThreads = kWarps * 32;
    const dim3 block(kThreads);
    const size_t num_m_blocks = ceilDiv(total_q_tokens, static_cast<size_t>(kWarps));
    // Single kernel launch with split_idx encoded in grid.z:
    // blockIdx.z in [0, num_splits * num_m_blocks).
    const dim3 grid(static_cast<uint32_t>(num_heads),
                    static_cast<uint32_t>(num_seqs),
                    static_cast<uint32_t>(num_m_blocks * static_cast<size_t>(num_splits)));
    switch (head_size) {
    case 64:
        PagedAttentionPrefillHd64WarpCta8PipeSplitKv<Tindex, Tdata><<<grid, block, 0, stream>>>(
            partial_acc, partial_m, partial_l, num_splits, total_q_tokens,
            q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
            num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
            block_table_batch_stride,
            q_stride, q_head_stride,
            k_batch_stride, k_row_stride, k_head_stride,
            v_batch_stride, v_row_stride, v_head_stride);
        break;
    case 128:
        PagedAttentionPrefillHd128WarpCta8PipeSplitKv<Tindex, Tdata><<<grid, block, 0, stream>>>(
            partial_acc, partial_m, partial_l, num_splits, total_q_tokens,
            q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
            num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
            block_table_batch_stride,
            q_stride, q_head_stride,
            k_batch_stride, k_row_stride, k_head_stride,
            v_batch_stride, v_row_stride, v_head_stride);
        break;
    default:
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
    // Combine: one warp per (token, head).
    const dim3 block2(32);
    const dim3 grid2(static_cast<uint32_t>(num_heads),
                     static_cast<uint32_t>(total_q_tokens),
                     1);
    switch (head_size) {
    case 64:
        PagedAttentionPrefillHd64SplitKvCombine<Tdata><<<grid2, block2, 0, stream>>>(
            out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens,
            o_stride, o_head_stride);
        return INFINI_STATUS_SUCCESS;
    case 128:
        PagedAttentionPrefillHd128SplitKvCombine<Tdata><<<grid2, block2, 0, stream>>>(
            out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens,
            o_stride, o_head_stride);
        return INFINI_STATUS_SUCCESS;
    default:
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
}
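// Illustrative sketch (assumptions, not the committed combine kernels): each split s writes an
// online-softmax partial for its slice of KV -- an accumulator acc_s[0..head_dim), a running max
// m_s and a running sum l_s. Assuming the accumulator is stored unnormalized, the combine step for
// one (token, head) row reduces to a log-sum-exp style merge:
//
//   __device__ void combine_one_row(float *out, const float *acc /* [num_splits][head_dim] */,
//                                   const float *m, const float *l, int num_splits, int head_dim) {
//       float m_all = -INFINITY;
//       for (int s = 0; s < num_splits; ++s) m_all = fmaxf(m_all, m[s]);
//       float l_all = 0.0f;
//       for (int s = 0; s < num_splits; ++s) l_all += expf(m[s] - m_all) * l[s];
//       for (int d = 0; d < head_dim; ++d) {
//           float v = 0.0f;
//           for (int s = 0; s < num_splits; ++s) v += expf(m[s] - m_all) * acc[s * head_dim + d];
//           out[d] = v / l_all; // final softmax normalization
//       }
//   }
//
// The committed Hd64/Hd128 combine kernels launch one warp per (token, head) (grid2/block2 above);
// presumably they parallelize the per-dimension loop across the warp's lanes and cast the result
// back to the output dtype.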
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8n128(
    Tdata *out,
    const Tdata *q,
    const Tdata *k_cache,
    const Tdata *v_cache,
    const Tindex *block_tables,
    const int64_t *total_kv_lens,
    const int64_t *cu_seqlens_q,
    const float *alibi_slopes,
    size_t num_heads,
    size_t num_seqs,
    size_t num_kv_heads,
    size_t total_q_tokens,
    size_t head_size,
    float scale,
    size_t max_num_blocks_per_seq,
    size_t page_block_size,
    ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride,
    ptrdiff_t q_head_stride,
    ptrdiff_t k_batch_stride,
    ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride,
    ptrdiff_t v_row_stride,
    ptrdiff_t v_head_stride,
    ptrdiff_t o_stride,
    ptrdiff_t o_head_stride,
    cudaStream_t stream) {
    constexpr int kWarps = 8;
    constexpr int kThreads = kWarps * 32;
    const dim3 block(kThreads);
    const dim3 grid(static_cast<uint32_t>(num_heads),
                    static_cast<uint32_t>(num_seqs),
                    static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
    // Only meaningful for head_dim=128.
    if (head_size != 128) {
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
    PagedAttentionPrefillHd128WarpCta8N128<Tindex, Tdata><<<grid, block, 0, stream>>>(
        out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
        block_table_batch_stride,
        q_stride, q_head_stride,
        k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride,
        o_stride, o_head_stride);
    return INFINI_STATUS_SUCCESS;
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta16(
    Tdata *out,
    const Tdata *q,
    const Tdata *k_cache,
    const Tdata *v_cache,
    const Tindex *block_tables,
    const int64_t *total_kv_lens,
    const int64_t *cu_seqlens_q,
    const float *alibi_slopes,
    size_t num_heads,
    size_t num_seqs,
    size_t num_kv_heads,
    size_t total_q_tokens,
    size_t head_size,
    float scale,
    size_t max_num_blocks_per_seq,
    size_t page_block_size,
    ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride,
    ptrdiff_t q_head_stride,
    ptrdiff_t k_batch_stride,
    ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride,
    ptrdiff_t v_row_stride,
    ptrdiff_t v_head_stride,
    ptrdiff_t o_stride,
    ptrdiff_t o_head_stride,
    cudaStream_t stream) {
    constexpr int kWarps = 16;
    constexpr int kThreads = kWarps * 32;
    const dim3 block(kThreads);
    const dim3 grid(static_cast<uint32_t>(num_heads),
                    static_cast<uint32_t>(num_seqs),
                    static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
    switch (head_size) {
    case 64:
        PagedAttentionPrefillHd64WarpCta16<Tindex, Tdata><<<grid, block, 0, stream>>>(
            out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
            num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
            block_table_batch_stride,
            q_stride, q_head_stride,
            k_batch_stride, k_row_stride, k_head_stride,
            v_batch_stride, v_row_stride, v_head_stride,
            o_stride, o_head_stride);
        return INFINI_STATUS_SUCCESS;
    case 128:
        PagedAttentionPrefillHd128WarpCta16<Tindex, Tdata><<<grid, block, 0, stream>>>(
            out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
            num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
            block_table_batch_stride,
            q_stride, q_head_stride,
            k_batch_stride, k_row_stride, k_head_stride,
            v_batch_stride, v_row_stride, v_head_stride,
            o_stride, o_head_stride);
        return INFINI_STATUS_SUCCESS;
    default:
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
}
} // namespace
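// Illustrative sketch (assumptions only, names hypothetical): every launcher above uses the same
// grid mapping, grid = (num_heads, num_seqs, ceil(total_q_tokens / rows_per_cta)). Inside a kernel,
// a CTA would recover its (head, seq) pair from blockIdx.x/y and intersect its row window with the
// packed query range of that sequence taken from cu_seqlens_q, e.g.:
//
//   __device__ bool cta_query_range(const int64_t *cu_seqlens_q, int rows_per_cta,
//                                   int64_t *q_begin, int64_t *q_end) {
//       const int seq = blockIdx.y;
//       const int64_t row_begin = static_cast<int64_t>(blockIdx.z) * rows_per_cta; // packed q index
//       const int64_t row_end = row_begin + rows_per_cta;
//       const int64_t seq_begin = cu_seqlens_q[seq];     // first packed q-token of this sequence
//       const int64_t seq_end = cu_seqlens_q[seq + 1];   // one past its last q-token
//       *q_begin = row_begin > seq_begin ? row_begin : seq_begin;
//       *q_end = row_end < seq_end ? row_end : seq_end;
//       return *q_begin < *q_end; // CTAs with an empty intersection can exit immediately
//   }
//
// rows_per_cta is 16 for the MMA launcher and kWarps (8 or 16) for the SIMT launchers above.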
struct Descriptor::Opaque {
    std::shared_ptr<device::nvidia::Handle::Internal> internal;
...
...
@@ -68,22 +1249,87 @@ infiniStatus_t Descriptor::create(
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t block_tables_desc,
    infiniopTensorDescriptor_t seq_lens_desc,
    infiniopTensorDescriptor_t cum_seq_lens_q_desc,
    infiniopTensorDescriptor_t total_kv_lens_desc,
    infiniopTensorDescriptor_t cum_seqlens_q_desc,
    const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
    float scale) {
    auto info = PagedAttentionPrefillInfo::create(
        out_desc, q_desc, k_cache_desc, v_cache_desc,
        block_tables_desc, seq_lens_desc, cum_seq_lens_q_desc,
        block_tables_desc, total_kv_lens_desc, cum_seqlens_q_desc,
        alibi_slopes_desc, scale);
    CHECK_RESULT(info);
    // Optional split-kv prefill requires workspace for partial (m, l, acc).
    // IMPORTANT: Unlike decode, prefill's total_q_tokens can be very large, so we must NOT reserve
    // a huge workspace unless the user explicitly enables split-kv.
    bool use_splitkv = false;
    if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) {
        use_splitkv = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
    }
    int num_splits = 1;
    if (use_splitkv) {
        if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_NUM_SPLITS")) {
            const int v = std::atoi(env);
            if (v > 0) {
                num_splits = v;
            }
        } else {
            num_splits = 4;
        }
        constexpr int kMaxSplits = 8;
        if (num_splits > kMaxSplits) {
            num_splits = kMaxSplits;
        }
    }
    const size_t n = info->total_q_tokens * info->num_heads;
    const size_t splitkv_workspace_bytes = use_splitkv
        ? (static_cast<size_t>(num_splits) * n * (info->head_size + 2) * sizeof(float))
        : 0;
// FA2-style kernel needs a workspace scratch for:
// - converting block_tables + total_kv_lens to int32
// - storing softmax LSE (only required to satisfy the upstream kernel contract)
// bool want_fa2 = false;
// if (const char *k_env = std::getenv("INFINIOP_FLASH_PREFILL_KERNEL")) {
// want_fa2 = (std::strcmp(k_env, "fa2") == 0);
// }
// bool fa2_materialize_kv = false;
// if (const char *env = std::getenv("INFINIOP_FA2_MATERIALIZE_PAGED_KV")) {
// fa2_materialize_kv = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
// }
// size_t fa2_workspace_bytes = 0;
// // FA2 prefill supports both fp16 and bf16 inputs (head_dim=128, block_size=256).
// // Workspace sizing is identical since both are 16-bit element types.
// if (want_fa2 && (info->dtype == INFINI_DTYPE_F16 || info->dtype == INFINI_DTYPE_BF16) &&
// info->head_size == 128 && info->page_block_size == 256) {
// const size_t bt_bytes = info->num_seqs * info->max_num_blocks_per_seq * sizeof(int);
// const size_t len_bytes = info->num_seqs * sizeof(int);
// const size_t cuq_bytes = (info->num_seqs + 1) * sizeof(int);
// const size_t cuk_bytes = (info->num_seqs + 1) * sizeof(int);
// const size_t lse_bytes = info->num_heads * info->total_q_tokens * sizeof(float);
// // Add a small alignment slack since we sub-allocate with alignment.
// fa2_workspace_bytes = bt_bytes + len_bytes + cuq_bytes + cuk_bytes + lse_bytes + 64;
// // Optional: materialize paged KV into the FA2-friendly physical layout
// // [num_blocks, page_block_size, kv_heads, head_dim] (token-major) to avoid
// // extremely strided loads when the framework stores KV as
// // [num_blocks, kv_heads, page_block_size, head_dim] (head-major).
// if (fa2_materialize_kv) {
// // Materialize per-seq contiguous KV in *sequence order*:
// // [num_seqs, max_num_blocks_per_seq * page_block_size, kv_heads, head_dim].
// const size_t kv_elems =
// info->num_seqs * info->max_num_blocks_per_seq * info->page_block_size * info->num_kv_heads * info->head_size;
// const size_t kv_bytes = kv_elems * sizeof(uint16_t); // 16-bit (fp16/bf16)
// // K + V + alignment slack
// fa2_workspace_bytes += 2 * kv_bytes + 64;
// }
// }
    const size_t workspace_bytes = splitkv_workspace_bytes;
// const size_t workspace_bytes = splitkv_workspace_bytes + fa2_workspace_bytes;
    *desc_ptr = new Descriptor(
        new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
        info.take(), 0, handle->device, handle->device_id);
        info.take(), workspace_bytes, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
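// Worked example (illustrative, arbitrary sizes): the split-kv workspace sized above is
//   num_splits * (total_q_tokens * num_heads) * (head_size + 2) * sizeof(float)
// e.g. with total_q_tokens = 1024, num_heads = 32, head_size = 128 and num_splits = 4:
//   4 * (1024 * 32) * (128 + 2) * 4 bytes = 68,157,440 bytes (65 MiB),
// i.e. head_size floats of partial accumulator plus one partial max (m) and one partial sum (l)
// per (split, q-token, head). With INFINIOP_FLASH_PREFILL_SPLITKV unset, workspace_bytes stays 0.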
...
...
@@ -92,35 +1338,379 @@ infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *out, const void *q, const void *k_cache, const void *v_cache,
    const void *block_tables,
    const void *seq_lens,
    const void *cum_seq_lens_q,
    const void *total_kv_lens,
    const void *cum_seqlens_q,
    const void *alibi_slopes,
    void *stream_) const {
    auto stream = static_cast<cudaStream_t>(stream_);
    cudaStream_t stream = (cudaStream_t)stream_;
#define LAUNCH_KERNEL(Tdata, Tcompute) \
launchPagedAttentionPrefill<Tdata, Tcompute>( \
(Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache, \
(const int64_t *)block_tables, (const int64_t *)seq_lens, (const int64_t *)cum_seq_lens_q, \
(const float *)alibi_slopes, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, \
_info.scale, _info.max_num_blocks_per_seq, \
_info.block_size, _info.total_q_tokens, \
_info.head_size, \
_info.kv_block_stride, _info.kv_head_stride, \
    const float *alibi_ptr = (alibi_slopes == nullptr) ? nullptr : static_cast<const float *>(alibi_slopes);
    const auto *total_kv_lens_i64 = static_cast<const int64_t *>(total_kv_lens);
    const auto *cu_seqlens_q_i64 = static_cast<const int64_t *>(cum_seqlens_q);
    bool use_splitkv = false;
    if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) {
        use_splitkv = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
    }
    int num_splits = 1;
    if (use_splitkv) {
        if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_NUM_SPLITS")) {
            const int v = std::atoi(env);
            if (v > 0) {
                num_splits = v;
            }
        } else {
            // Conservative default; users can override.
            num_splits = 4;
        }
        constexpr int kMaxSplits = 8;
        if (num_splits > kMaxSplits) {
            num_splits = kMaxSplits;
        }
        const size_t n = _info.total_q_tokens * _info.num_heads;
        const size_t required = static_cast<size_t>(num_splits) * n * (_info.head_size + 2) * sizeof(float);
        if (workspace_size < required) {
            return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
        }
    }
    if (use_splitkv) {
        const size_t n = _info.total_q_tokens * _info.num_heads;
        float *partial_acc = static_cast<float *>(workspace);
        float *partial_m = partial_acc + static_cast<size_t>(num_splits) * n * _info.head_size;
        float *partial_l = partial_m + static_cast<size_t>(num_splits) * n;
// Dispatch by (Tdata, Tindex). total_kv_lens + cu_seqlens_q are currently always int64.
#define DISPATCH_SPLITKV(Tindex, Tdata, BT_PTR) \
return launch_prefill_warpcta8pipe_splitkv<Tindex, Tdata>( \
partial_acc, partial_m, partial_l, num_splits, \
static_cast<Tdata *>(out), \
static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), \
static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(BT_PTR), \
total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
stream)
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream)
        if (_info.dtype == INFINI_DTYPE_F16) {
            return LAUNCH_KERNEL(half, float);
        } else if (_info.dtype == INFINI_DTYPE_BF16) {
            return LAUNCH_KERNEL(__nv_bfloat16, float);
        } else if (_info.dtype == INFINI_DTYPE_F32) {
            return LAUNCH_KERNEL(float, float);
            if (_info.index_dtype == INFINI_DTYPE_I64) {
                DISPATCH_SPLITKV(int64_t, half, block_tables);
            }
            if (_info.index_dtype == INFINI_DTYPE_I32) {
                DISPATCH_SPLITKV(int32_t, half, block_tables);
            }
            if (_info.index_dtype == INFINI_DTYPE_U32) {
                DISPATCH_SPLITKV(uint32_t, half, block_tables);
            }
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        if (_info.dtype == INFINI_DTYPE_BF16) {
            if (_info.index_dtype == INFINI_DTYPE_I64) {
                DISPATCH_SPLITKV(int64_t, __nv_bfloat16, block_tables);
            }
            if (_info.index_dtype == INFINI_DTYPE_I32) {
                DISPATCH_SPLITKV(int32_t, __nv_bfloat16, block_tables);
            }
            if (_info.index_dtype == INFINI_DTYPE_U32) {
                DISPATCH_SPLITKV(uint32_t, __nv_bfloat16, block_tables);
            }
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
#undef DISPATCH_SPLITKV
    }
// Default to the fastest validated kernel for supported shapes.
// "ref" is still available for debugging/correctness bisecting.
#define DISPATCH_KERNEL(Tindex, Tdata, Tcompute) \
do { \
const char *k_env = std::getenv("INFINIOP_FLASH_PREFILL_KERNEL"); \
const char *k = (k_env == nullptr) ? default_prefill_kernel(_info) : k_env; \
if (k_env != nullptr) { \
const bool known = (std::strcmp(k, "warp") == 0) || (std::strcmp(k, "warpcta") == 0) || (std::strcmp(k, "warpcta8") == 0) || (std::strcmp(k, "warpcta8pipe") == 0) || (std::strcmp(k, "warpcta8mma") == 0) || (std::strcmp(k, "warpcta8n128") == 0) || (std::strcmp(k, "warpcta16") == 0) || (std::strcmp(k, "ref") == 0); \
if (!known) { \
const char *fallback = default_prefill_kernel(_info); \
std::fprintf(stderr, \
"[infiniop][paged_attention_prefill] WARNING: unknown kernel '%s', falling back to '%s'\n", \
k, fallback); \
k = fallback; \
} \
} \
const char *dbg = std::getenv("INFINIOP_DEBUG_PREFILL_DISPATCH"); \
static bool printed_dispatch = false; \
if (!printed_dispatch && dbg != nullptr && std::strcmp(dbg, "1") == 0) { \
std::fprintf(stderr, \
"[infiniop][paged_attention_prefill] kernel=%s (override=%s head_size=%zu block=%zu dtype=%zu)\n", \
k, \
(k_env == nullptr ? "auto" : "env"), \
static_cast<size_t>(_info.head_size), \
static_cast<size_t>(_info.page_block_size), \
static_cast<size_t>(_info.dtype)); \
printed_dispatch = true; \
} \
if (std::strcmp(k, "warp") == 0) { \
return launch_prefill_warp<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta") == 0) { \
return launch_prefill<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta8") == 0) { \
return launch_prefill_warpcta8<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta8pipe") == 0) { \
return launch_prefill_warpcta8pipe<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if constexpr (std::is_same_v<Tdata, half>) { \
if (std::strcmp(k, "warpcta8mma") == 0) { \
return launch_prefill_warpcta8mma<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
} \
if (std::strcmp(k, "warpcta8n128") == 0) { \
return launch_prefill_warpcta8n128<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta16") == 0) { \
return launch_prefill_warpcta16<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "ref") == 0) { \
return launch_prefill_ref<Tindex, Tdata, Tcompute>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
return INFINI_STATUS_BAD_PARAM; \
} while (false)
#define DISPATCH_INDEX(Tindex) \
do { \
if (_info.dtype == INFINI_DTYPE_F16) { \
DISPATCH_KERNEL(Tindex, half, float); \
} \
if (_info.dtype == INFINI_DTYPE_BF16) { \
DISPATCH_KERNEL(Tindex, __nv_bfloat16, float); \
} \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
} while (false)
    if (_info.index_dtype == INFINI_DTYPE_I64) {
        DISPATCH_INDEX(int64_t);
    } else if (_info.index_dtype == INFINI_DTYPE_I32) {
        DISPATCH_INDEX(int32_t);
    } else if (_info.index_dtype == INFINI_DTYPE_U32) {
        DISPATCH_INDEX(uint32_t);
    }
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} // namespace op::paged_attention_prefill::nvidia
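// Usage note (illustrative, not part of the committed sources): the split-kv path reads
// INFINIOP_FLASH_PREFILL_SPLITKV / INFINIOP_FLASH_PREFILL_NUM_SPLITS both in Descriptor::create
// (to size the workspace) and in Descriptor::calculate (to validate it), so the two variables must
// hold the same values for the whole descriptor lifetime. A hypothetical caller could do:
//
//   #include <cstdlib>
//   void enable_splitkv_prefill() {
//       setenv("INFINIOP_FLASH_PREFILL_SPLITKV", "1", 1);    // opt in before creating descriptors
//       setenv("INFINIOP_FLASH_PREFILL_NUM_SPLITS", "4", 1); // clamped to kMaxSplits = 8
//   }
//
// Changing the values between create and calculate can make the pre-sized workspace too small,
// which calculate reports as INFINI_STATUS_INSUFFICIENT_WORKSPACE.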
// #include <cuda_fp16.h>
// #include <float.h>
// #include <math.h>
// #include <stdint.h>
// #include "../../../devices/nvidia/nvidia_common.cuh"
// #include "../../../devices/nvidia/nvidia_kernel_common.cuh"
// #include "../cuda/kernel.cuh"
// #include "paged_attention_prefill_nvidia.cuh"
// template <typename Tdata, typename Tcompute>
// infiniStatus_t launchPagedAttentionPrefill(
// Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
// const int64_t *block_tables,
// const int64_t *seq_lens,
// const int64_t *cum_seq_lens_q,
// const float *alibi_slopes,
// const size_t num_heads,
// const size_t num_seqs,
// const size_t num_kv_heads,
// const float scale,
// const size_t max_num_blocks_per_seq,
// const size_t block_size,
// const size_t total_q_tokens,
// const size_t head_size,
// const ptrdiff_t kv_block_stride,
// const ptrdiff_t kv_head_stride,
// const ptrdiff_t q_stride,
// const ptrdiff_t q_head_stride,
// cudaStream_t stream) {
// if (total_q_tokens == 0 || num_heads == 0) {
// return INFINI_STATUS_BAD_TENSOR_SHAPE;
// }
// dim3 grid(total_q_tokens, num_heads);
// dim3 block(head_size);
// op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel<Tdata, Tcompute>
// <<<grid, block, 0, stream>>>(
// out, q, k_cache, v_cache,
// block_tables, seq_lens, cum_seq_lens_q, alibi_slopes,
// num_heads, num_kv_heads, scale,
// max_num_blocks_per_seq, block_size,
// kv_block_stride, kv_head_stride,
// q_stride, q_head_stride,
// head_size,
// num_seqs);
// return INFINI_STATUS_SUCCESS;
// }
// namespace op::paged_attention_prefill::nvidia {
// struct Descriptor::Opaque {
// std::shared_ptr<device::nvidia::Handle::Internal> internal;
// };
// Descriptor::~Descriptor() {
// delete _opaque;
// }
// infiniStatus_t Descriptor::create(
// infiniopHandle_t handle,
// Descriptor **desc_ptr,
// infiniopTensorDescriptor_t out_desc,
// infiniopTensorDescriptor_t q_desc,
// infiniopTensorDescriptor_t k_cache_desc,
// infiniopTensorDescriptor_t v_cache_desc,
// infiniopTensorDescriptor_t block_tables_desc,
// infiniopTensorDescriptor_t seq_lens_desc,
// infiniopTensorDescriptor_t cum_seq_lens_q_desc,
// const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
// float scale) {
// auto info = PagedAttentionPrefillInfo::create(
// out_desc, q_desc, k_cache_desc, v_cache_desc,
// block_tables_desc, seq_lens_desc,
// cum_seq_lens_q_desc,
// alibi_slopes_desc, scale);
// CHECK_RESULT(info);
// *desc_ptr = new Descriptor(
// new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
// info.take(), 0, handle->device, handle->device_id);
// return INFINI_STATUS_SUCCESS;
// }
// infiniStatus_t Descriptor::calculate(
// void *workspace, size_t workspace_size,
// void *out, const void *q, const void *k_cache, const void *v_cache,
// const void *block_tables,
// const void *seq_lens,
// const void *cum_seq_lens_q,
// const void *alibi_slopes,
// void *stream_) const {
// cudaStream_t stream = (cudaStream_t)stream_;
// #define LAUNCH_KERNEL(Tdata, Tcompute) \
// launchPagedAttentionPrefill<Tdata, Tcompute>( \
// (Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache, \
// (const int64_t *)block_tables, (const int64_t *)seq_lens, (const int64_t *)cum_seq_lens_q, \
// (const float *)alibi_slopes, \
// _info.num_heads, _info.num_seqs, _info.num_kv_heads, \
// _info.scale, _info.max_num_blocks_per_seq, \
// _info.block_size, _info.total_q_tokens, \
// _info.head_size, \
// _info.kv_block_stride, _info.kv_head_stride, \
// _info.q_stride, _info.q_head_stride, \
// stream)
// if (_info.dtype == INFINI_DTYPE_F16) {
// return LAUNCH_KERNEL(half, float);
// } else if (_info.dtype == INFINI_DTYPE_BF16) {
// return LAUNCH_KERNEL(__nv_bfloat16, float);
// } else if (_info.dtype == INFINI_DTYPE_F32) {
// return LAUNCH_KERNEL(float, float);
// }
// return INFINI_STATUS_BAD_TENSOR_DTYPE;
// }
// } // namespace op::paged_attention_prefill::nvidia
test/infiniop/paged_attention.py
...
...
@@ -100,13 +100,12 @@ _TEST_CASES_ = [
]
# Data types for testing
_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16, InfiniDtype.F32]
_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16]
# Tolerance map for different data types
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2},
    InfiniDtype.BF16: {"atol": 5e-3, "rtol": 5e-2},
    InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5},
}
# Global flags for controlling test behavior
...
...
test/infiniop/paged_attention_prefill.py
...
...
@@ -32,10 +32,9 @@ _TEST_CASES = [
    (16, 128, 128, 128, 8, 16, 4),
]
_TENSOR_DTYPES = [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16]
_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16]
_TOLERANCE_MAP = {
    InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5},
    InfiniDtype.F16: {"atol": 1e-2, "rtol": 1e-2},
    InfiniDtype.BF16: {"atol": 2e-2, "rtol": 2e-2},
}
...
...