gaoqiong / composable_kernel_ROCM · Commits

Commit 5a683756, authored Jan 12, 2025 by Po Yen, Chen

Re-format kernel

parent b618806b

Changes: 1 changed file with 1030 additions and 910 deletions
example/ck_tile/18_paged_attention/include/kernel/paged_attention_kernel.hpp (+1030, -910)
...
@@ -24,41 +24,42 @@

#include "attention/dtype_fp8.cuh"
#include "quantization/fp8/amd/quant_utils.cuh"

#if defined(__HIPCC__) && \
    (defined(__gfx90a__) || defined(__gfx940__) || \
     defined(__gfx941__) || defined(__gfx942__))
  #define __HIP__MI300_MI250__
#endif

#if defined(NDEBUG)
  #undef NDEBUG
  #include <assert.h>
  #define UNREACHABLE_CODE assert(false);
  #define NDEBUG
#else
  #define UNREACHABLE_CODE assert(false);
#endif

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

#if defined(__HIP__MI300_MI250__)  // TODO: Add NAVI support

  #define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32
  #define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16

using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float;
using float16x4 =
    __attribute__((__vector_size__(4 * sizeof(_Float16)))) _Float16;
typedef float16x4 _Half4;
typedef struct _Half8 {
  _Half4 xy[2];
} _Half8;

using bit16_t = uint16_t;
using bit16x4 = __attribute__((__vector_size__(4 * sizeof(uint16_t)))) uint16_t;
typedef bit16x4 _B16x4;
typedef struct _B16x8 {
  _B16x4 xy[2];
} _B16x8;
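The vector aliases above rely on the Clang/GCC __vector_size__ extension, so a floatx4 or bit16x4 behaves like four packed lanes that support operator[] and elementwise arithmetic. A minimal host-side sketch of that behaviour (illustrative only; assumes a Clang-compatible compiler such as hipcc):

  #include <cstdio>

  using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float;

  int main() {
    floatx4 a = {1.0f, 2.0f, 3.0f, 4.0f};
    floatx4 b = {0.5f, 0.5f, 0.5f, 0.5f};
    floatx4 c = a * b + a;  // elementwise: {1.5f, 3.0f, 4.5f, 6.0f}
    for (int i = 0; i < 4; i++) {
      printf("c[%d] = %f\n", i, c[i]);  // operator[] reads a single lane
    }
    return 0;
  }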
...
@@ -68,144 +69,191 @@ using bit8_t = uint8_t;

////// Non temporal load stores ///////

template <typename T>
__device__ __forceinline__ T load(T* addr) {
  return addr[0];
}

template <typename T>
__device__ __forceinline__ void store(T value, T* addr) {
  addr[0] = value;
}

template <typename T, int absz, int cbid, int blgp>
__device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA,
                                                  const _B16x4& inpB,
                                                  const floatx4& inpC) {
  if constexpr (std::is_same<T, _Float16>::value) {
    return __builtin_amdgcn_mfma_f32_4x4x4f16(inpA, inpB, inpC, absz, cbid,
                                              blgp);
  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
    return __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(inpA, inpB, inpC, absz, cbid,
                                                  blgp);
  } else {
    static_assert(false, "unsupported 16b dtype");
  }
}

template <typename T>
__device__ __forceinline__ float to_float(const T& inp) {
  if constexpr (std::is_same<T, _Float16>::value) {
    return (float)inp;
  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
    return __bfloat162float(inp);
  } else {
    static_assert(false, "unsupported 16b dtype");
  }
}

template <typename T>
__device__ __forceinline__ float to_float_b16(const bit16_t& inp) {
  union tmpcvt {
    bit16_t u;
    _Float16 f;
    __hip_bfloat16 b;
  } t16;
  t16.u = inp;
  if constexpr (std::is_same<T, _Float16>::value) {
    return (float)t16.f;
  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
    return __bfloat162float(t16.b);
  } else {
    static_assert(false, "unsupported 16b dtype");
  }
}

template <typename T>
__device__ __forceinline__ T from_float(const float& inp) {
  if constexpr (std::is_same<T, _Float16>::value) {
    return (_Float16)inp;
  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
    return __float2bfloat16(inp);
  } else {
    static_assert(false, "unsupported 16b dtype");
  }
}

template <typename T>
__device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) {
  union tmpcvt {
    uint16_t u;
    _Float16 f;
    __hip_bfloat16 b;
  } t16;
  _B16x4 ret;
  if constexpr (std::is_same<T, _Float16>::value) {
  #pragma unroll
    for (int i = 0; i < 4; i++) {
      t16.f = (_Float16)inp[i];
      ret[i] = t16.u;
    }
    return ret;
  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
  #pragma unroll
    for (int i = 0; i < 4; i++) {
      t16.b = __float2bfloat16(inp[i]);
      ret[i] = t16.u;
    }
    return ret;
  } else {
    static_assert(false, "unsupported 16b dtype");
  }
}

template <typename T>
__device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1,
                                        const _B16x4& inp2) {
  union tmpcvt {
    uint16_t u;
    _Float16 f;
    __hip_bfloat16 b;
  } t1, t2, res;
  _B16x4 ret;
  if constexpr (std::is_same<T, _Float16>::value) {
  #pragma unroll
    for (int i = 0; i < 4; i++) {
      t1.u = inp1[i];
      t2.u = inp2[i];
      res.f = t1.f + t2.f;
      ret[i] = res.u;
    }
    return ret;
  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
  #pragma unroll
    for (int i = 0; i < 4; i++) {
      t1.u = inp1[i];
      t2.u = inp2[i];
      res.b = t1.b + t2.b;
      ret[i] = res.u;
    }
    return ret;
  } else {
    static_assert(false, "unsupported 16b dtype");
  }
}

template <typename T, vllm::Fp8KVCacheDataType KV_DTYPE>
__device__ __forceinline__ _B16x8 scaled_convert_b8x8(const _B8x8 input,
                                                      const float scale) {
  union alignas(16) {
    uint4 u4;
    _B16x8 u16x8;
    vllm::bf16_8_t b16x8;
  } tmp;
  if constexpr (std::is_same<T, _Float16>::value) {
    tmp.u4 = vllm::fp8::scaled_convert<uint4, _B8x8, KV_DTYPE>(input, scale);
    return tmp.u16x8;
  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
    tmp.b16x8 = vllm::fp8::scaled_convert<vllm::bf16_8_t, _B8x8, KV_DTYPE>(
        input, scale);
    return tmp.u16x8;
  } else {
    static_assert(false, "unsupported 16b dtype");
  }
}
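These helpers give the kernel a single fp16/bf16 code path: float accumulators are packed back to 16-bit lanes with from_floatx4, and per-warp partial results are combined with addx4. A minimal sketch of that round trip (a hypothetical test kernel, not part of this file):

  // Sketch only: pack two floatx4 accumulators into 16-bit lanes of type T and add them.
  template <typename T>
  __global__ void pack_and_add_demo(const floatx4* a, const floatx4* b, _B16x4* out) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    const _B16x4 pa = from_floatx4<T>(a[i]);  // 4 x float32 -> 4 x 16-bit lanes of T
    const _B16x4 pb = from_floatx4<T>(b[i]);
    out[i] = addx4<T>(pa, pb);                // elementwise add performed in T precision
  }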
...
@@ -214,9 +262,13 @@ __device__ __forceinline__ _B16x8 scaled_convert_b8x8(const _B8x8 input,

// grid (num_seqs, num_partitions,num_heads/gqa_ratio)
// block (partition size)
template <typename scalar_t,
          typename cache_t,
          vllm::Fp8KVCacheDataType KV_DTYPE,
          typename OUTT,
          int BLOCK_SIZE,
          int HEAD_SIZE,
          int NUM_THREADS,
          int GQA_RATIO>
__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
    const scalar_t* __restrict__ q,  // [num_seqs, num_heads, head_size]
...
@@ -224,20 +276,26 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
                                     // head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
                                          // head_size, block_size]
    const int num_kv_heads, const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ context_lens,  // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride,
    float* __restrict__ exp_sums,    // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits,  // [num_seqs, num_heads,
                                     // max_num_partitions]
    scalar_t* __restrict__ out,    // [num_seqs, num_heads, max_num_partitions,
                                   // head_size]
    OUTT* __restrict__ final_out,  // [num_seqs, num_heads, head_size]
    int max_ctx_blocks, float k_scale, float v_scale,
    const float* __restrict__ fp8_out_scale_ptr) {
  constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
  const int warpid = threadIdx.x / WARP_SIZE;
  const int laneid = threadIdx.x % WARP_SIZE;
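For reference, the grid/block comment above maps one block to a single (sequence, partition, kv-head group) triple, with GQA_RATIO query heads sharing each kv head. A host-side launch sketch under those assumptions (variable names are illustrative, not the launcher used by this repository; assumes <hip/hip_runtime.h> and the usual host-side setup):

  // Sketch only: GQA_RATIO = num_heads / num_kv_heads, one partition of tokens per block.
  const int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, partition_size);
  dim3 grid(num_seqs, max_num_partitions, num_heads / gqa_ratio);
  dim3 block(NUM_THREADS);
  hipLaunchKernelGGL((paged_attention_ll4mi_QKV_kernel<scalar_t, cache_t, KV_DTYPE, OUTT,
                                                       BLOCK_SIZE, HEAD_SIZE, NUM_THREADS,
                                                       GQA_RATIO>),
                     grid, block, 0, stream,
                     q_ptr, k_cache_ptr, v_cache_ptr, num_kv_heads, scale,
                     block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq,
                     alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride,
                     exp_sums_ptr, max_logits_ptr, tmp_out_ptr, final_out_ptr,
                     max_ctx_blocks, k_scale, v_scale, fp8_out_scale_ptr);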
...
@@ -251,11 +309,11 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
  const int context_len = context_lens[seq_idx];
  const int partition_start_token_idx = partition_idx * partition_size;
  // exit if partition is out of context for seq
  if (partition_start_token_idx >= context_len) {
    return;
  }
  constexpr int QHLOOP =
      DIVIDE_ROUND_UP(GQA_RATIO, 4);  // each 4 lanes fetch 4 different qheads,
                                      // total qheads =8, so qhloop is 2
  constexpr int GQA_RATIO4 = 4 * QHLOOP;
  __shared__ float shared_qk_max[NWARPS][GQA_RATIO4 + 1];
...
@@ -266,16 +324,16 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
  _B16x8 Klocal[KHELOOP];
  _B8x8 Klocalb8[KHELOOP];
  constexpr int VHELOOP =
      HEAD_SIZE / WARP_SIZE;  // v head_size dimension is distributed across lanes
  constexpr int VTLOOP = 8;   // 16 separate 4xtokens across warp -> 16/2
                              // 8xtokens
  _B16x8 Vlocal[VHELOOP][VTLOOP];
  _B8x8 Vlocalb8[VHELOOP][VTLOOP];
  floatx4 dout[QHLOOP];
  float qk_max[QHLOOP];
  #pragma unroll
  for (int h = 0; h < QHLOOP; h++) {
    dout[h] = {0};
    qk_max[h] = -FLT_MAX;
  }
...
@@ -283,16 +341,19 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
  const int wg_start_head_idx = blockIdx.z * GQA_RATIO;
  const int wg_start_kv_head_idx = blockIdx.z;

  const int warp_start_token_idx =
      partition_start_token_idx + warpid * WARP_SIZE;
  if (warp_start_token_idx >= context_len) {  // warp out of context
  #pragma unroll
    for (int h = 0; h < GQA_RATIO4; h++) {
      shared_qk_max[warpid][h] = -FLT_MAX;
      shared_exp_sum[warpid][h] = 0.0f;
    }
  } else {  // warp within context
    const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
    const int last_ctx_block = num_context_blocks - 1;
...
@@ -302,23 +363,23 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
    const int local_token_idx = threadIdx.x;
    const int global_token_idx = partition_start_token_idx + local_token_idx;
    const int block_idx = (global_token_idx < context_len)
                              ? global_token_idx / BLOCK_SIZE
                              : last_ctx_block;
    // fetch block number for q and k
    // int32 physical_block_number leads to overflow when multiplied with
    // kv_block_stride
    const int64_t physical_block_number =
        static_cast<int64_t>(block_table[block_idx]);

    // fetch vphysical block numbers up front
    constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE;
    int vphysical_blocks[VBLOCKS];

    const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE;
    if constexpr (GQA_RATIO < 12) {
  #pragma unroll
      for (int b = 0; b < VBLOCKS; b++) {
        const int vblock_idx = warp_start_block_idx + b;
        const int vblock_idx_ctx =
            (vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block;
...
@@ -327,20 +388,22 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
    }

    // each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems
    const scalar_t* q_ptr =
        q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE;
    const _B16x8* q_ptrh8 = reinterpret_cast<const _B16x8*>(q_ptr);
    const int qhead_elemh8 = laneid / 4;
  #pragma unroll
    for (int h = 0; h < QHLOOP - 1; h++) {
      const int qhead_idx = h * 4 + lane4id;
      Qlocal[h] = q_ptrh8[qhead_idx * HEAD_SIZE / 8 + qhead_elemh8];
    }
    const int final_qhead_idx = 4 * (QHLOOP - 1) + lane4id;
    if (final_qhead_idx < GQA_RATIO) {
      Qlocal[QHLOOP - 1] =
          q_ptrh8[final_qhead_idx * HEAD_SIZE / 8 + qhead_elemh8];
    } else {
      Qlocal[QHLOOP - 1].xy[0] = {0};
      Qlocal[QHLOOP - 1].xy[1] = {0};
    }
...
@@ -351,17 +414,22 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
    const int physical_block_offset =
        local_token_idx % BLOCK_SIZE;  // since x=half8, physical_block_offset
                                       // is already cast as _H8
    if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) {
      const _B16x8* k_ptrh8 = reinterpret_cast<const _B16x8*>(k_ptr);
  #pragma unroll
      for (int d = 0; d < KHELOOP; d++) {
        Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset];
      }
    } else {
      constexpr int X = 16 / sizeof(cache_t);
      const cache_t* k_ptr2 = k_ptr + physical_block_offset * X;
  #pragma unroll
      for (int d = 0; d < KHELOOP; d++) {
        const int head_elem = d * 8;
        const int offset1 = head_elem / X;
        const int offset2 = head_elem % X;
...
@@ -371,20 +439,23 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
    }

    float alibi_slope[QHLOOP];
    if (alibi_slopes != nullptr) {
  #pragma unroll
      for (int h = 0; h < QHLOOP; h++) {
        const int qhead_idx = h * 4 + lane4id;
        alibi_slope[h] = (qhead_idx < GQA_RATIO)
                             ? alibi_slopes[wg_start_head_idx + qhead_idx]
                             : 0.f;
      }
    }

    // fetch vphysical block numbers up front
    if constexpr (GQA_RATIO >= 12) {
  #pragma unroll
      for (int b = 0; b < VBLOCKS; b++) {
        const int vblock_idx = warp_start_block_idx + b;
        const int vblock_idx_ctx =
            (vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block;
...
@@ -393,48 +464,53 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
    }

    const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride;
    if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) {
      const _B16x8* v_ptrh8 = reinterpret_cast<const _B16x8*>(v_ptr);
      // iterate over each v block
  #pragma unroll
      for (int b = 0; b < VBLOCKS; b++) {
        // int32 physical_block_number leads to overflow when multiplied with
        // kv_block_stride
        const int64_t vphysical_block_number =
            static_cast<int64_t>(vphysical_blocks[b]);
        const _B16x8* v_ptrh8b =
            v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8;
        // iterate over each head elem (within head_size)
  #pragma unroll
        for (int h = 0; h < VHELOOP; h++) {
          const int head_size_elem = h * WARP_SIZE + laneid;
          const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8;
          // iterate over all velems within block
  #pragma unroll
          for (int d = 0; d < BLOCK_SIZE / 8; d++) {
            Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d];
          }
        }
      }
    } else {
      const _B8x8* v_ptrh8 = reinterpret_cast<const _B8x8*>(v_ptr);
      // iterate over each v block
  #pragma unroll
      for (int b = 0; b < VBLOCKS; b++) {
        // int32 physical_block_number leads to overflow when multiplied with
        // kv_block_stride
        const int64_t vphysical_block_number =
            static_cast<int64_t>(vphysical_blocks[b]);
        const _B8x8* v_ptrh8b =
            v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8;
        // iterate over each head elem (within head_size)
  #pragma unroll
        for (int h = 0; h < VHELOOP; h++) {
          const int head_size_elem = h * WARP_SIZE + laneid;
          const _B8x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8;
          // iterate over all velems within block
  #pragma unroll
          for (int d = 0; d < BLOCK_SIZE / 8; d++) {
            // Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d];
            const _B8x8 Vlocalb8 = v_ptrh8be[d];
            Vlocal[h][b * BLOCK_SIZE / 8 + d] =
...
@@ -444,91 +520,80 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
      }
    }

    if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) {
  #pragma unroll
      for (int d = 0; d < KHELOOP; d++) {
        Klocal[d] =
            scaled_convert_b8x8<scalar_t, KV_DTYPE>(Klocalb8[d], k_scale);
      }
    }

  #pragma unroll
    for (int h = 0; h < QHLOOP; h++) {
      dout[h] = gcn_mfma_instr<scalar_t, 4, 0, 0>(Qlocal[h].xy[0], Klocal[0].xy[0], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 0, 0>(Qlocal[h].xy[1], Klocal[0].xy[1], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 1, 0>(Qlocal[h].xy[0], Klocal[1].xy[0], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 1, 0>(Qlocal[h].xy[1], Klocal[1].xy[1], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 2, 0>(Qlocal[h].xy[0], Klocal[2].xy[0], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 2, 0>(Qlocal[h].xy[1], Klocal[2].xy[1], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 3, 0>(Qlocal[h].xy[0], Klocal[3].xy[0], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 3, 0>(Qlocal[h].xy[1], Klocal[3].xy[1], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 4, 0>(Qlocal[h].xy[0], Klocal[4].xy[0], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 4, 0>(Qlocal[h].xy[1], Klocal[4].xy[1], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 5, 0>(Qlocal[h].xy[0], Klocal[5].xy[0], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 5, 0>(Qlocal[h].xy[1], Klocal[5].xy[1], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 6, 0>(Qlocal[h].xy[0], Klocal[6].xy[0], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 6, 0>(Qlocal[h].xy[1], Klocal[6].xy[1], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 7, 0>(Qlocal[h].xy[0], Klocal[7].xy[0], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 7, 0>(Qlocal[h].xy[1], Klocal[7].xy[1], dout[h]);
      if constexpr (KHELOOP > 8) {
        dout[h] = gcn_mfma_instr<scalar_t, 4, 8, 0>(Qlocal[h].xy[0], Klocal[8].xy[0], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 8, 0>(Qlocal[h].xy[1], Klocal[8].xy[1], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 9, 0>(Qlocal[h].xy[0], Klocal[9].xy[0], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 9, 0>(Qlocal[h].xy[1], Klocal[9].xy[1], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 10, 0>(Qlocal[h].xy[0], Klocal[10].xy[0], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 10, 0>(Qlocal[h].xy[1], Klocal[10].xy[1], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 11, 0>(Qlocal[h].xy[0], Klocal[11].xy[0], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 11, 0>(Qlocal[h].xy[1], Klocal[11].xy[1], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 12, 0>(Qlocal[h].xy[0], Klocal[12].xy[0], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 12, 0>(Qlocal[h].xy[1], Klocal[12].xy[1], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 13, 0>(Qlocal[h].xy[0], Klocal[13].xy[0], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 13, 0>(Qlocal[h].xy[1], Klocal[13].xy[1], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 14, 0>(Qlocal[h].xy[0], Klocal[14].xy[0], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 14, 0>(Qlocal[h].xy[1], Klocal[14].xy[1], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 15, 0>(Qlocal[h].xy[0], Klocal[15].xy[0], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 15, 0>(Qlocal[h].xy[1], Klocal[15].xy[1], dout[h]);
      }  // KHELOOP>8
      dout[h] *= scale;
    }
    // transpose dout so that 4 token ids are in each lane, and 4 heads are across
    // 4 lanes
  #pragma unroll
    for (int h = 0; h < QHLOOP; h++) {
      floatx4 tmp = {0};
  #pragma unroll
      for (int i = 0; i < 4; i++) {
        const float B = (lane4id == i) ? 1.0f : 0.0f;
        // const float A = (global_token_idx < context_len) ? dout[h][i] : 0.0f;
        tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(dout[h][i], B, tmp, 0, 0, 0);
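The loop just above uses __builtin_amdgcn_mfma_f32_4x4x1f32 with a one-hot B operand purely as a data-movement trick: accumulating the outer product over i redistributes the per-lane accumulators so that afterwards each lane holds four token values and four heads sit across four lanes. A rough host-side emulation of that 4x4 transpose within one 4-lane group (layout details simplified; illustrative only):

  #include <cassert>

  int main() {
    float dout[4][4];         // [lane][i]: per-lane accumulators before the shuffle
    float out[4][4] = {{0}};  // [lane][i]: per-lane result after the MFMA-based transpose
    for (int lane = 0; lane < 4; lane++)
      for (int i = 0; i < 4; i++) dout[lane][i] = 10.0f * lane + i;

    // Emulate tmp = mfma_f32_4x4x1f32(dout[h][i], B, tmp) over the unrolled i-loop:
    // A is the per-lane scalar dout[lane][i] and B is one-hot at lane i, so each
    // output lane accumulates one column of the 4x4 outer product.
    for (int i = 0; i < 4; i++)
      for (int col = 0; col < 4; col++)
        for (int row = 0; row < 4; row++)
          out[col][row] += dout[row][i] * ((col == i) ? 1.0f : 0.0f);

    for (int lane = 0; lane < 4; lane++)
      for (int i = 0; i < 4; i++) assert(out[lane][i] == dout[i][lane]);  // transposed
    return 0;
  }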
...
@@ -539,50 +604,58 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
    const int lane4_token_idx = 4 * (global_token_idx >> 2);
    const int alibi_offset = lane4_token_idx - context_len + 1;
    if (alibi_slopes != nullptr) {
  #pragma unroll
      for (int h = 0; h < QHLOOP; h++) {
  #pragma unroll
        for (int i = 0; i < 4; i++) {
          dout[h][i] += alibi_slope[h] * (alibi_offset + i);
        }
      }
    }

  #pragma unroll
    for (int h = 0; h < QHLOOP; h++) {
      qk_max[h] = -FLT_MAX;
  #pragma unroll
      for (int i = 0; i < 4; i++) {
        qk_max[h] = (lane4_token_idx + i < context_len)
                        ? fmaxf(qk_max[h], dout[h][i])
                        : qk_max[h];
      }
  #pragma unroll
      for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) {
        qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask));
      }
    }

    float exp_sum[QHLOOP];
  #pragma unroll
    for (int h = 0; h < QHLOOP; h++) {
      exp_sum[h] = 0.0f;
  #pragma unroll
      for (int i = 0; i < 4; i++) {
        dout[h][i] = (lane4_token_idx + i < context_len)
                         ? __expf(dout[h][i] - qk_max[h])
                         : 0.0f;
        exp_sum[h] += dout[h][i];
      }
  #pragma unroll
      for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) {
        exp_sum[h] += __shfl_xor(exp_sum[h], mask);
      }
    }

  #pragma unroll
    for (int h = 0; h < QHLOOP; h++) {
      const int head_idx = 4 * h + lane4id;
      shared_qk_max[warpid][head_idx] = qk_max[h];
      shared_exp_sum[warpid][head_idx] = exp_sum[h];
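The two __shfl_xor loops above are butterfly reductions: each step exchanges values with the lane whose id differs in one bit, and stopping at mask == 4 keeps an independent result per 4-lane group (one group per query head). A generic sketch of the same pattern (assumes the WARP_SIZE used by this kernel; not a helper defined in this file):

  // Sketch only: per-4-lane-group max reduction via the butterfly pattern used above.
  __device__ __forceinline__ float group4_max(float v) {
    for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) {
      v = fmaxf(v, __shfl_xor(v, mask));  // exchange with lane (laneid ^ mask)
    }
    return v;  // identical across all lanes of the same 4-lane group
  }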
...
@@ -592,95 +665,86 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
  __syncthreads();

  const int num_heads = gridDim.z * GQA_RATIO;
  float* max_logits_ptr =
      max_logits + seq_idx * num_heads * max_num_partitions + partition_idx;
  float* exp_sums_ptr =
      exp_sums + seq_idx * num_heads * max_num_partitions + partition_idx;
  #pragma unroll
  for (int h = 0; h < QHLOOP; h++) {
    float global_qk_max = -FLT_MAX;
    float warp_qk_max[NWARPS];
    const int head_idx = 4 * h + lane4id;
  #pragma unroll
    for (int w = 0; w < NWARPS; w++) {
      warp_qk_max[w] = shared_qk_max[w][head_idx];
      global_qk_max = fmaxf(global_qk_max, warp_qk_max[w]);
    }
    float global_exp_sum = 0.0f;
  #pragma unroll
    for (int w = 0; w < NWARPS; w++) {
      global_exp_sum +=
          shared_exp_sum[w][head_idx] * __expf(warp_qk_max[w] - global_qk_max);
    }
    if (head_idx < GQA_RATIO) {
      max_logits_ptr[(wg_start_head_idx + head_idx) * max_num_partitions] =
          global_qk_max;
      exp_sums_ptr[(wg_start_head_idx + head_idx) * max_num_partitions] =
          global_exp_sum;
    }
    const float global_inv_sum_scale =
        __fdividef(1.f, global_exp_sum + 1e-6f) *
        __expf(qk_max[h] - global_qk_max);
    dout[h] *= global_inv_sum_scale;
  }
  // logits[h] -> every 4 lanes hold 4 heads, each lane holds 4 tokens, there
  // are 4x16 tokens across warp
  _B16x4 logits[QHLOOP];
  #pragma unroll
  for (int h = 0; h < QHLOOP; h++) {
    logits[h] = from_floatx4<scalar_t>(dout[h]);
  }

  __shared__ _B16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1];

  if (warp_start_token_idx >= context_len) {  // warp out of context
  #pragma unroll
    for (int qh = 0; qh < QHLOOP; qh++) {
  #pragma unroll
      for (int vh = 0; vh < VHELOOP; vh++) {
        vout_shared[qh][vh][laneid][warpid] = {0};
      }
    }
  } else {  // warp in context
    // iterate across heads
  #pragma unroll
    for (int qh = 0; qh < QHLOOP; qh++) {
      // iterate over each v head elem (within head_size)
  #pragma unroll
      for (int vh = 0; vh < VHELOOP; vh++) {
        floatx4 acc = {0};
        // iterate over tokens
        acc = gcn_mfma_instr<scalar_t, 4, 0, 0>(logits[qh], Vlocal[vh][0].xy[0], acc);
        acc = gcn_mfma_instr<scalar_t, 4, 1, 0>(logits[qh], Vlocal[vh][0].xy[1], acc);
        acc = gcn_mfma_instr<scalar_t, 4, 2, 0>(logits[qh], Vlocal[vh][1].xy[0], acc);
        acc = gcn_mfma_instr<scalar_t, 4, 3, 0>(logits[qh], Vlocal[vh][1].xy[1], acc);
        acc = gcn_mfma_instr<scalar_t, 4, 4, 0>(logits[qh], Vlocal[vh][2].xy[0], acc);
        acc = gcn_mfma_instr<scalar_t, 4, 5, 0>(logits[qh], Vlocal[vh][2].xy[1], acc);
        acc = gcn_mfma_instr<scalar_t, 4, 6, 0>(logits[qh], Vlocal[vh][3].xy[0], acc);
        acc = gcn_mfma_instr<scalar_t, 4, 7, 0>(logits[qh], Vlocal[vh][3].xy[1], acc);
        acc = gcn_mfma_instr<scalar_t, 4, 8, 0>(logits[qh], Vlocal[vh][4].xy[0], acc);
        acc = gcn_mfma_instr<scalar_t, 4, 9, 0>(logits[qh], Vlocal[vh][4].xy[1], acc);
        acc = gcn_mfma_instr<scalar_t, 4, 10, 0>(logits[qh], Vlocal[vh][5].xy[0], acc);
        acc = gcn_mfma_instr<scalar_t, 4, 11, 0>(logits[qh], Vlocal[vh][5].xy[1], acc);
        acc = gcn_mfma_instr<scalar_t, 4, 12, 0>(logits[qh], Vlocal[vh][6].xy[0], acc);
        acc = gcn_mfma_instr<scalar_t, 4, 13, 0>(logits[qh], Vlocal[vh][6].xy[1], acc);
        acc = gcn_mfma_instr<scalar_t, 4, 14, 0>(logits[qh], Vlocal[vh][7].xy[0], acc);
        acc = gcn_mfma_instr<scalar_t, 4, 15, 0>(logits[qh], Vlocal[vh][7].xy[1], acc);
        vout_shared[qh][vh][laneid][warpid] = from_floatx4<scalar_t>(acc);
      }
    }
  }
...
@@ -688,42 +752,48 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
  __syncthreads();

  if (warpid == 0) {
    // const float out_scale = (fp8_out_scale_ptr != nullptr) ?
    // __fdividef(1.0f,(*fp8_out_scale_ptr)) : 1.0f;
    const float out_scale =
        (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;
    _B16x4 vout[QHLOOP][VHELOOP];
    // iterate across heads
  #pragma unroll
    for (int qh = 0; qh < QHLOOP; qh++) {
      // iterate over each v head elem (within head_size)
  #pragma unroll
      for (int vh = 0; vh < VHELOOP; vh++) {
        vout[qh][vh] = {0};
  #pragma unroll
        for (int w = 0; w < NWARPS; w++) {
          vout[qh][vh] =
              addx4<scalar_t>(vout[qh][vh], vout_shared[qh][vh][laneid][w]);
        }
      }
    }

    if (context_len > partition_size) {
      scalar_t* out_ptr = out +
                          seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
                          partition_idx * HEAD_SIZE;
      const int out_num_partitions = max_num_partitions;
      bit16_t* out_ptr_b16 = reinterpret_cast<bit16_t*>(out_ptr);
  #pragma unroll
      for (int qh = 0; qh < QHLOOP; qh++) {
  #pragma unroll
        for (int vh = 0; vh < VHELOOP; vh++) {
          const int head_size_elem = vh * WARP_SIZE + laneid;
  #pragma unroll
          for (int i = 0; i < 4; i++) {
            const int head_idx = 4 * qh + i;
            if (head_idx < GQA_RATIO) {
              out_ptr_b16[(wg_start_head_idx + head_idx) * out_num_partitions *
                              HEAD_SIZE +
                          head_size_elem] = vout[qh][vh][i];
...
@@ -732,31 +802,42 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
        }
      }
    }  // context_len > partition_size
    else {
      bit8_t* final_out_ptr_b8;
      bit16_t* final_out_ptr_b16;
      if constexpr (std::is_same<OUTT, bit8_t>::value) {
        final_out_ptr_b8 = final_out + seq_idx * num_heads * HEAD_SIZE;
      } else {
        OUTT* out_ptr = final_out + seq_idx * num_heads * HEAD_SIZE;
        final_out_ptr_b16 = reinterpret_cast<bit16_t*>(out_ptr);
      }
  #pragma unroll
      for (int qh = 0; qh < QHLOOP; qh++) {
  #pragma unroll
        for (int vh = 0; vh < VHELOOP; vh++) {
          const int head_size_elem = vh * WARP_SIZE + laneid;
  #pragma unroll
          for (int i = 0; i < 4; i++) {
            const int head_idx = 4 * qh + i;
            if (head_idx < GQA_RATIO) {
              if constexpr (std::is_same<OUTT, bit8_t>::value) {
                const float tmpf =
                    out_scale * to_float_b16<scalar_t>(vout[qh][vh][i]);
                const OUTT tmp = hip_fp8(tmpf).data;
                final_out_ptr_b8[(wg_start_head_idx + head_idx) * HEAD_SIZE +
                                 head_size_elem] = tmp;
              } else {
                final_out_ptr_b16[(wg_start_head_idx + head_idx) * HEAD_SIZE +
                                  head_size_elem] = vout[qh][vh][i];
              }
...
@@ -769,10 +850,13 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
}

// Grid: (num_heads, num_seqs).
template <typename scalar_t,
          typename OUTT,
          int HEAD_SIZE,
          int NUM_THREADS,
          int PARTITION_SIZE,
          int NPAR_LOOPS>
__global__
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
    OUTT* __restrict__ out,              // [num_seqs, num_heads, head_size]
    const float* __restrict__ exp_sums,  // [num_seqs, num_heads,
                                         // max_num_partitions]
...
@@ -781,13 +865,16 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
    const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
                                           // max_num_partitions, head_size]
    const int* __restrict__ context_lens,  // [num_seqs]
    const int max_num_partitions,
    const float* __restrict__ fp8_out_scale_ptr) {
  const int num_heads = gridDim.x;
  const int head_idx = blockIdx.x;
  const int seq_idx = blockIdx.y;
  const int context_len = context_lens[seq_idx];
  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
  if (num_partitions == 1) {
    // if num_partitions==1, main kernel will write to out directly, no work in
    // reduction kernel
    return;
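The remainder of the reduce kernel merges per-partition softmax statistics: each partition's exp-sum is first rescaled into a common reference frame with expf(max_i - global_max) and only then summed, as the hunks below show. A small host-side sketch of that merge for two partitions (made-up numbers, illustrative only):

  #include <cmath>
  #include <cstdio>

  int main() {
    // Per-partition statistics as the main kernel would have written them.
    const float max_logit[2] = {3.0f, 5.0f};
    const float exp_sum[2] = {2.5f, 1.75f};  // sums of exp(logit - max_logit[i])

    const float global_max = fmaxf(max_logit[0], max_logit[1]);
    float global_exp_sum = 0.0f;
    for (int i = 0; i < 2; i++) {
      // Same rescaling the reduce kernel applies with expf(reg_max_logit[i] - max_logit).
      global_exp_sum += exp_sum[i] * expf(max_logit[i] - global_max);
    }
    printf("global softmax denominator = %f\n", global_exp_sum);
    return 0;
  }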
...
@@ -801,10 +888,10 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
  // max num partitions supported is warp_size * NPAR_LOOPS
  __shared__ float shared_exp_sums[NPAR_LOOPS * WARP_SIZE];

  if (warpid == 0) {
    const float* max_logits_ptr = max_logits +
                                  seq_idx * num_heads * max_num_partitions +
                                  head_idx * max_num_partitions;

    // valid partition is the last valid partition in case threadid > num
    // partitions
...
@@ -812,70 +899,78 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
...
@@ -812,70 +899,78 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
        float reg_max_logit[NPAR_LOOPS];
        const int last_valid_partition = num_partitions - 1;
#pragma unroll
        for(int i = 0; i < NPAR_LOOPS; i++)
        {
            const int partition_no = i * WARP_SIZE + threadIdx.x;
            valid_partition[i] =
                (partition_no < num_partitions) ? partition_no : last_valid_partition;
        }
#pragma unroll
        for(int i = 0; i < NPAR_LOOPS; i++)
        {
            reg_max_logit[i] = max_logits_ptr[valid_partition[i]];
        }
        float max_logit = reg_max_logit[0];
#pragma unroll
        for(int i = 1; i < NPAR_LOOPS; i++)
        {
            max_logit = fmaxf(max_logit, reg_max_logit[i]);
        }
#pragma unroll
        for(int mask = WARP_SIZE / 2; mask >= 1; mask /= 2)
        {
            max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask));
        }
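        // Note: the xor "butterfly" above spreads the maximum across the wavefront in
        // log2(WARP_SIZE) steps; e.g. four lanes holding {3, 7, 2, 5} become {3, 7, 3, 7}
        // after mask = 2 and {7, 7, 7, 7} after mask = 1, so every lane ends up with the
        // same global max logit.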
        const float* exp_sums_ptr =
            exp_sums + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions;
        float rescaled_exp_sum[NPAR_LOOPS];
#pragma unroll
        for(int i = 0; i < NPAR_LOOPS; i++)
        {
            rescaled_exp_sum[i] = exp_sums_ptr[valid_partition[i]];
        }
#pragma unroll
        for(int i = 0; i < NPAR_LOOPS; i++)
        {
            const int partition_no = i * WARP_SIZE + threadIdx.x;
            rescaled_exp_sum[i] *=
                (partition_no < num_partitions) ? expf(reg_max_logit[i] - max_logit) : 0.0f;
        }
        float global_exp_sum = rescaled_exp_sum[0];
#pragma unroll
        for(int i = 1; i < NPAR_LOOPS; i++)
        {
            global_exp_sum += rescaled_exp_sum[i];
        }
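        // Lanes whose partition_no exceeded num_partitions were clamped to the last valid
        // partition when gathering, and their weight was forced to 0.0f above, so those
        // duplicated reads contribute nothing to global_exp_sum.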
#pragma unroll
        for(int i = 0; i < NPAR_LOOPS; i++)
        {
            const int partition_no = i * WARP_SIZE + threadIdx.x;
            shared_exp_sums[partition_no] = rescaled_exp_sum[i];
        }
#pragma unroll
        for(int mask = WARP_SIZE / 2; mask >= 1; mask /= 2)
        {
            global_exp_sum += __shfl_xor(global_exp_sum, mask);
        }
        if(threadIdx.x == 0)
        {
            shared_global_exp_sum = global_exp_sum;
        }
    } // warpid == 0
    const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
                                  head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x;
    constexpr int MAX_NPAR = 64;
    scalar_t tmps[MAX_NPAR];
    const float dzero = 0.0f;
#pragma unroll
    for(int j = 0; j < MAX_NPAR; j++)
    {
        tmps[j] = from_float<scalar_t>(dzero);
    }
    const int last_partition_offset = (num_partitions - 1) * HEAD_SIZE;
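    // Each thread gathers one head-dimension element (tmp_out_ptr is offset by threadIdx.x
    // and strided by HEAD_SIZE) from up to MAX_NPAR partitions into registers before the
    // weighted accumulation below.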
...
@@ -884,32 +979,32 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
    constexpr int JCHUNK = 16;
#pragma unroll
    for(int j = 0; j < JCHUNK * HEAD_SIZE; j += HEAD_SIZE)
    {
        // lastj is last valid partition
        const int lastj_offset = (j < num_partition_offset) ? j : last_partition_offset;
        tmps[idx] = tmp_out_ptr[lastj_offset];
        idx++;
    }
    __syncthreads();
    if(num_partitions > JCHUNK)
    {
#pragma unroll
        for(int j = JCHUNK * HEAD_SIZE; j < 2 * JCHUNK * HEAD_SIZE; j += HEAD_SIZE)
        {
            const int lastj_offset = (j < num_partition_offset) ? j : last_partition_offset;
            tmps[idx] = tmp_out_ptr[lastj_offset];
            idx++;
        }
        if(num_partitions > 2 * JCHUNK)
        {
#pragma unroll
            for(int j = 2 * JCHUNK * HEAD_SIZE; j < MAX_NPAR * HEAD_SIZE; j += HEAD_SIZE)
            {
                const int lastj_offset = (j < num_partition_offset) ? j : last_partition_offset;
                tmps[idx] = tmp_out_ptr[lastj_offset];
                idx++;
            }
...
@@ -918,64 +1013,77 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
    // Aggregate tmp_out to out.
    float acc = 0.0f;
#pragma unroll
    for(int j = 0; j < JCHUNK; j++)
    {
        acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
    }
    if(num_partitions > JCHUNK)
    {
#pragma unroll
        for(int j = JCHUNK; j < 2 * JCHUNK; j++)
        {
            acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
        }
        if(num_partitions > 2 * JCHUNK)
        {
#pragma unroll
            for(int j = 2 * JCHUNK; j < MAX_NPAR; j++)
            {
                acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
            }
        }
    }
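    // When NPAR_LOOPS > 1, partitions beyond the first MAX_NPAR are handled in further
    // MAX_NPAR-sized blocks: tmps is refilled for each block and accumulated against the
    // matching shared_exp_sums offset.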
    for(int p = 1; p < NPAR_LOOPS; p++)
    {
        if(num_partitions > p * MAX_NPAR)
        {
            idx = 0;
#pragma unroll
            for(int j = p * MAX_NPAR * HEAD_SIZE; j < (p + 1) * MAX_NPAR * HEAD_SIZE;
                j += HEAD_SIZE)
            {
                // lastj is last valid partition
                const int lastj_offset = (j < num_partition_offset) ? j : last_partition_offset;
                tmps[idx] = tmp_out_ptr[lastj_offset];
                idx++;
            }
#pragma unroll
            for(int j = 0; j < MAX_NPAR; j++)
            {
                acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j + p * MAX_NPAR];
            }
        }
    }
    const float inv_global_exp_sum = __fdividef(1.0f, shared_global_exp_sum + 1e-6f);
    // const float out_scale = (fp8_out_scale_ptr != nullptr) ?
    // __fdividef(1.0f,(*fp8_out_scale_ptr)) : 1.0f;
    const float out_scale = (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;
    acc *= inv_global_exp_sum;
    acc *= out_scale;
    OUTT* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
    if constexpr(std::is_same<OUTT, bit8_t>::value)
    {
        out_ptr[threadIdx.x] = hip_fp8(acc).data;
    }
    else
    {
        out_ptr[threadIdx.x] = from_float<scalar_t>(acc);
    }
}
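For readers following the reduction above, the same recombination can be written as a short host-side reference. The sketch below is illustrative only (the container-based layout and the function name are not part of this file); it assumes each partition's partial output was already normalized by that partition's own exp-sum, which is the convention the weighted accumulation above relies on.

// Illustrative host-side reference of the partition recombination (not used by the kernel).
#include <cmath>
#include <vector>

std::vector<float> reduce_partitions_reference(
    const std::vector<float>& max_logit,                // per-partition max logit m_i
    const std::vector<float>& exp_sum,                  // per-partition exp-sum s_i
    const std::vector<std::vector<float>>& partial_out, // per-partition output, locally normalized
    int head_size)
{
    const int num_partitions = static_cast<int>(max_logit.size());

    // Global max logit over all partitions.
    float global_max = max_logit[0];
    for(int i = 1; i < num_partitions; ++i)
        global_max = std::fmax(global_max, max_logit[i]);

    // Rescale every partition's exp-sum to the common max and total them.
    std::vector<float> weight(num_partitions);
    float global_exp_sum = 0.0f;
    for(int i = 0; i < num_partitions; ++i)
    {
        weight[i] = exp_sum[i] * std::exp(max_logit[i] - global_max);
        global_exp_sum += weight[i];
    }

    // Weighted sum of the partial outputs, normalized by the global exp-sum.
    std::vector<float> out(head_size, 0.0f);
    for(int d = 0; d < head_size; ++d)
    {
        float acc = 0.0f;
        for(int i = 0; i < num_partitions; ++i)
            acc += weight[i] * partial_out[i][d];
        out[d] = acc / (global_exp_sum + 1e-6f); // same epsilon guard as the kernel
    }
    return out;
}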
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
template <typename scalar_t,
          typename cache_t,
          vllm::Fp8KVCacheDataType KV_DTYPE,
          typename OUTT,
          int BLOCK_SIZE,
          int HEAD_SIZE,
          int NUM_THREADS,
          int GQA_RATIO>
__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
    const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
...
@@ -983,28 +1091,37 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
                                         // head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
                                         // head_size, block_size]
    const int num_kv_heads,
    const float scale,
    const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ context_lens, // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes, // [num_heads]
    const int q_stride,
    const int kv_block_stride,
    const int kv_head_stride,
    float* __restrict__ exp_sums,   // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits, // [num_seqs, num_heads,
                                    // max_num_partitions]
    scalar_t* __restrict__ out,     // [num_seqs, num_heads, max_num_partitions,
                                    // head_size]
    OUTT* __restrict__ final_out,   // [num_seqs, num_heads, head_size]
    int max_ctx_blocks,
    float k_scale,
    float v_scale,
    const float* __restrict__ fp8_out_scale_ptr)
{
    UNREACHABLE_CODE
}
// Grid: (num_heads, num_seqs).
template <typename scalar_t,
          typename OUTT,
          int HEAD_SIZE,
          int NUM_THREADS,
          int PARTITION_SIZE,
          int NPAR_LOOPS>
__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
    OUTT* __restrict__ out,             // [num_seqs, num_heads, head_size]
    const float* __restrict__ exp_sums, // [num_seqs, num_heads,
                                        // max_num_partitions]
...
@@ -1014,6 +1131,9 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
                                          // max_num_partitions, head_size]
    const int* __restrict__ context_lens, // [num_seqs]
    const int max_num_partitions,
    const float* __restrict__ fp8_out_scale_ptr)
{
    UNREACHABLE_CODE
}
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
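Both warp-level reductions in the reduce kernel (the max over max_logit and the sum over global_exp_sum) use the same __shfl_xor butterfly exchange. The snippet below is a plain host-side emulation of that exchange, included only to illustrate why every lane ends up holding the full-wavefront result; the wavefront size of 64 is an assumption for the gfx90a/gfx94x targets this file guards on.

#include <algorithm>
#include <array>

constexpr int kWavefront = 64; // assumed wavefront size on gfx90a/gfx94x

// Emulates: for(int mask = kWavefront / 2; mask >= 1; mask /= 2)
//               v = fmaxf(v, __shfl_xor(v, mask));
std::array<float, kWavefront> butterfly_max(std::array<float, kWavefront> lanes)
{
    for(int mask = kWavefront / 2; mask >= 1; mask /= 2)
    {
        std::array<float, kWavefront> partner{};
        for(int lane = 0; lane < kWavefront; ++lane)
            partner[lane] = lanes[lane ^ mask]; // value __shfl_xor would deliver to this lane
        for(int lane = 0; lane < kWavefront; ++lane)
            lanes[lane] = std::max(lanes[lane], partner[lane]);
    }
    return lanes; // every lane now holds the maximum of the original 64 values
}

Replacing std::max with addition gives the summation variant used for global_exp_sum.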