Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dd5fa7e0
Unverified
Commit
dd5fa7e0
authored
May 21, 2025
by
Hosang
Committed by
GitHub
May 21, 2025
Browse files
[ROCm][Kernel][V1] Enable AMD Radeon GPU Custom Paged Attention on v1 (#17004)
Signed-off-by:
Hosang Yoon
<
hosang.yoon@amd.com
>
parent
2b161045
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
1930 additions
and
189 deletions
+1930
-189
benchmarks/kernels/benchmark_paged_attention.py
benchmarks/kernels/benchmark_paged_attention.py
+5
-1
csrc/rocm/attention.cu
csrc/rocm/attention.cu
+1880
-171
tests/kernels/attention/test_attention.py
tests/kernels/attention/test_attention.py
+7
-1
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+2
-1
vllm/attention/ops/chunked_prefill_paged_decode.py
vllm/attention/ops/chunked_prefill_paged_decode.py
+2
-1
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+34
-14
No files found.
benchmarks/kernels/benchmark_paged_attention.py
View file @
dd5fa7e0
...
...
@@ -84,7 +84,10 @@ def main(
if
version
==
"v2"
:
if
current_platform
.
is_rocm
():
global
PARTITION_SIZE
PARTITION_SIZE
=
1024
if
not
args
.
custom_paged_attn
else
PARTITION_SIZE_ROCM
if
not
args
.
custom_paged_attn
and
not
current_platform
.
is_navi
():
PARTITION_SIZE
=
1024
else
:
PARTITION_SIZE
=
PARTITION_SIZE_ROCM
num_partitions
=
(
max_seq_len
+
PARTITION_SIZE
-
1
)
//
PARTITION_SIZE
tmp_output
=
torch
.
empty
(
size
=
(
num_seqs
,
num_query_heads
,
num_partitions
,
head_size
),
...
...
@@ -159,6 +162,7 @@ def main(
scale
,
block_tables
,
seq_lens
,
None
,
block_size
,
max_seq_len
,
alibi_slopes
,
...
...
csrc/rocm/attention.cu
View file @
dd5fa7e0
...
...
@@ -30,6 +30,14 @@
#define __HIP__GFX9__
#endif
#if defined(__HIPCC__) && (defined(__gfx1100__) || defined(__gfx1101__))
#define __HIP__GFX11__
#endif
#if defined(__HIPCC__) && (defined(__gfx1200__) || defined(__gfx1201__))
#define __HIP__GFX12__
#endif
#if defined(NDEBUG)
#undef NDEBUG
#include <assert.h>
...
...
@@ -43,7 +51,7 @@
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
#if defined(__HIP__GFX9__)
// TODO: Add NAVI support
#if defined(__HIP__GFX9__)
#define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32
#define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16
...
...
@@ -1482,198 +1490,1697 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
}
}
#el
se // !
defined(__HIP__GFX
9
__)
TODO: Add NAVI support
#el
if
defined(__HIP__GFX
11
__)
// clang-format off
template
<
typename
scalar_t
,
typename
cache_t
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
typename
OUTT
,
int
BLOCK_SIZE
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
bool
ALIBI_ENABLED
,
int
GQA_RATIO
>
__global__
__launch_bounds__
(
NUM_THREADS
)
void
paged_attention_ll4mi_QKV_mfma16_kernel
(
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads, head_size/x, block_size, x]
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads, head_size, block_size]
const
int
num_kv_heads
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context_lens
,
// [num_seqs]
const
int
*
__restrict__
query_start_loc_ptr
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads, max_num_partitions]
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, max_num_partitions, head_size]
OUTT
*
__restrict__
final_out
,
// [num_seqs, num_heads, head_size]
int
max_ctx_blocks
,
const
float
*
k_scale
,
const
float
*
v_scale
)
{
UNREACHABLE_CODE
using
floatx8
=
__attribute__
((
__vector_size__
(
8
*
sizeof
(
float
))))
float
;
using
bit16_t
=
uint16_t
;
using
bit16x4
=
__attribute__
((
__vector_size__
(
4
*
sizeof
(
uint16_t
))))
uint16_t
;
typedef
bit16x4
_B16x4
;
using
bit16x8
=
__attribute__
((
__vector_size__
(
8
*
sizeof
(
uint16_t
))))
uint16_t
;
union
b16x8_u
{
bit16x8
u16x8
;
_B16x4
xy
[
2
];
};
typedef
b16x8_u
_B16x8
;
using
bit16x16
=
__attribute__
((
__vector_size__
(
16
*
sizeof
(
uint16_t
))))
uint16_t
;
union
b16x16_u
{
bit16x16
u16x16
;
_B16x8
xy
[
2
];
};
typedef
b16x16_u
_B16x16
;
using
_B8x8
=
uint2
;
using
bit8_t
=
uint8_t
;
typedef
struct
_B8x16
{
_B8x8
xy
[
2
];
}
_B8x16
;
template
<
typename
T
,
int
absz
,
int
cbid
,
int
blgp
>
__device__
__forceinline__
floatx8
gcn_wmma16x16x16_instr
(
const
bit16x16
&
inpA
,
const
bit16x16
&
inpB
,
const
floatx8
&
inpC
)
{
if
constexpr
(
std
::
is_same
<
T
,
_Float16
>::
value
)
{
return
__builtin_amdgcn_wmma_f32_16x16x16_f16_w32
(
inpA
,
inpB
,
inpC
);
}
else
if
constexpr
(
std
::
is_same
<
T
,
__hip_bfloat16
>::
value
)
{
return
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32
(
inpA
,
inpB
,
inpC
);
}
else
{
static_assert
(
false
,
"unsupported 16b dtype"
);
}
}
template
<
typename
T
>
__device__
__forceinline__
float
to_float
(
const
T
&
inp
)
{
if
constexpr
(
std
::
is_same
<
T
,
_Float16
>::
value
)
{
return
(
float
)
inp
;
}
else
if
constexpr
(
std
::
is_same
<
T
,
__hip_bfloat16
>::
value
)
{
return
__bfloat162float
(
inp
);
}
else
{
static_assert
(
false
,
"unsupported 16b dtype"
);
}
}
template
<
typename
T
>
__device__
__forceinline__
T
from_float
(
const
float
&
inp
)
{
if
constexpr
(
std
::
is_same
<
T
,
_Float16
>::
value
)
{
return
(
_Float16
)
inp
;
}
else
if
constexpr
(
std
::
is_same
<
T
,
__hip_bfloat16
>::
value
)
{
return
__float2bfloat16
(
inp
);
}
else
{
static_assert
(
false
,
"unsupported 16b dtype"
);
}
}
template
<
typename
T
>
__device__
__forceinline__
_B16x8
from_floatx8
(
const
floatx8
&
inp
)
{
if
constexpr
(
std
::
is_same
<
T
,
_Float16
>::
value
)
{
union
h2cvt
{
__half2
h2
[
4
];
_B16x8
b16x8
;
}
u
;
u
.
h2
[
0
]
=
__float22half2_rn
(
make_float2
(
inp
[
0
],
inp
[
1
]));
u
.
h2
[
1
]
=
__float22half2_rn
(
make_float2
(
inp
[
2
],
inp
[
3
]));
u
.
h2
[
2
]
=
__float22half2_rn
(
make_float2
(
inp
[
4
],
inp
[
5
]));
u
.
h2
[
3
]
=
__float22half2_rn
(
make_float2
(
inp
[
6
],
inp
[
7
]));
return
u
.
b16x8
;
}
else
if
constexpr
(
std
::
is_same
<
T
,
__hip_bfloat16
>::
value
)
{
union
b2cvt
{
__hip_bfloat162
b2
[
4
];
_B16x8
b16x8
;
}
u
;
u
.
b2
[
0
]
=
__float22bfloat162_rn
(
make_float2
(
inp
[
0
],
inp
[
1
]));
u
.
b2
[
1
]
=
__float22bfloat162_rn
(
make_float2
(
inp
[
2
],
inp
[
3
]));
u
.
b2
[
2
]
=
__float22bfloat162_rn
(
make_float2
(
inp
[
4
],
inp
[
5
]));
u
.
b2
[
3
]
=
__float22bfloat162_rn
(
make_float2
(
inp
[
6
],
inp
[
7
]));
return
u
.
b16x8
;
}
else
{
static_assert
(
false
,
"unsupported 16b dtype"
);
}
}
// clang-format off
template
<
typename
scalar_t
,
typename
cache_t
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
typename
OUTT
,
int
BLOCK_SIZE
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
bool
ALIBI_ENABLED
,
int
GQA_RATIO
>
int
HEAD_SIZE
,
int
NUM_THREADS
,
bool
ALIBI_ENABLED
,
int
GQA_RATIO
>
__global__
__launch_bounds__
(
NUM_THREADS
)
void
paged_attention_ll4mi_QKV_mfma
4
_kernel
(
__launch_bounds__
(
NUM_THREADS
,
3
)
void
paged_attention_ll4mi_QKV_mfma
16
_kernel
(
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads, head_size/x, block_size, x]
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads, head_size, block_size]
const
int
num_kv_heads
,
const
float
scale
,
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
// head_size, block_size]
const
int
num_kv_heads
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context_lens
,
// [num_seqs]
const
int
*
__restrict__
query_start_loc_ptr
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads, max_num_partitions]
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, max_num_partitions, head_size]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
// max_num_partitions]
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, max_num_partitions,
// head_size]
OUTT
*
__restrict__
final_out
,
// [num_seqs, num_heads, head_size]
int
max_ctx_blocks
,
const
float
*
k_scale
,
const
float
*
v_scale
)
{
UNREACHABLE_CODE
}
// clang-format on
constexpr
int
NWARPS
=
NUM_THREADS
/
WARP_SIZE
;
// 8 warps on gfx11
const
int
warpid
=
threadIdx
.
x
/
WARP_SIZE
;
const
int
laneid
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
lane2id
=
laneid
%
2
;
const
int
lane4id
=
laneid
%
4
;
const
int
lane16id
=
laneid
%
16
;
const
int
rowid
=
laneid
/
16
;
// Grid: (num_heads, num_seqs).
template
<
typename
scalar_t
,
typename
OUTT
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
int
PARTITION_SIZE
,
int
NPAR_LOOPS
>
__global__
__launch_bounds__
(
NUM_THREADS
)
void
paged_attention_ll4mi_reduce_kernel
(
OUTT
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
const
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
const
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads, max_num_partitions]
const
scalar_t
*
__restrict__
tmp_out
,
// [num_seqs, num_heads, max_num_partitions, head_size]
const
int
*
__restrict__
context_lens
,
// [num_seqs]
const
int
*
__restrict__
query_start_loc_ptr
,
// [num_seqs]
const
int
max_num_partitions
,
const
float
*
__restrict__
fp8_out_scale_ptr
)
{
UNREACHABLE_CODE
}
// clang-format on
const
int
seq_idx
=
blockIdx
.
x
;
// NOTE queries with sequence len > 1 are prefills and taken care by another
// kernel.
if
(
query_start_loc_ptr
!=
nullptr
&&
(
query_start_loc_ptr
[
seq_idx
+
1
]
-
query_start_loc_ptr
[
seq_idx
])
!=
1
)
{
return
;
}
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
const
int
partition_idx
=
blockIdx
.
y
;
#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \
paged_attention_ll4mi_QKV_mfma16_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
HEAD_SIZE, NTHR, ALIBI_ENABLED, \
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks, k_scale_ptr, v_scale_ptr);
constexpr
int
T_PAR_SIZE
=
256
;
// token partition size set to 256
#define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO) \
paged_attention_ll4mi_QKV_mfma4_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
HEAD_SIZE, NTHR, ALIBI_ENABLED, \
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks, k_scale_ptr, v_scale_ptr);
const
int
max_num_partitions
=
gridDim
.
y
;
#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \
paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE, \
PARTITION_SIZE, NPAR_LOOPS> \
<<<reduce_grid, reduce_block, 0, stream>>>( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \
context_lens_ptr, query_start_loc_ptr, max_num_partitions, \
fp8_out_scale_ptr);
const
int
context_len
=
context_lens
[
seq_idx
];
// length of a seq
template
<
typename
T
,
typename
KVT
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
int
BLOCK_SIZE
,
int
HEAD_SIZE
,
typename
OUTT
,
int
PARTITION_SIZE_OLD
,
bool
ALIBI_ENABLED
>
void
paged_attention_custom_launcher
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
const
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context_lens
,
const
std
::
optional
<
torch
::
Tensor
>&
query_start_loc
,
int
max_context_len
,
const
std
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
std
::
optional
<
torch
::
Tensor
>&
fp8_out_scale
)
{
int
num_seqs
=
block_tables
.
size
(
0
);
int
num_heads
=
query
.
size
(
1
);
int
head_size
=
query
.
size
(
2
);
int
max_num_blocks_per_seq
=
block_tables
.
size
(
1
);
int
q_stride
=
query
.
stride
(
0
);
int
kv_block_stride
=
key_cache
.
stride
(
0
);
int
kv_head_stride
=
key_cache
.
stride
(
1
);
const
int
partition_start_token_idx
=
partition_idx
*
T_PAR_SIZE
;
// exit if partition is out of context for seq
if
(
partition_start_token_idx
>=
context_len
)
{
return
;
}
// NOTE: query start location is optional for V0 decode should not be used.
// If batch contains mix of prefills and decode, prefills should be skipped.
const
int
*
query_start_loc_ptr
=
query_start_loc
?
reinterpret_cast
<
const
int
*>
(
query_start_loc
.
value
().
data_ptr
())
:
nullptr
;
constexpr
int
GQA_RATIO2
=
DIVIDE_ROUND_UP
(
GQA_RATIO
,
2
);
// NOTE: alibi_slopes is optional.
const
float
*
alibi_slopes_ptr
=
alibi_slopes
?
reinterpret_cast
<
const
float
*>
(
alibi_slopes
.
value
().
data_ptr
())
:
nullptr
;
__shared__
float
shared_qk_max
[
NWARPS
][
16
+
1
];
__shared__
float
shared_exp_sum
[
NWARPS
][
16
+
1
];
// shared_logits is used for multiple purposes
__shared__
_B16x16
shared_logits
[
NWARPS
][
2
][
16
][
2
];
float
*
exp_sums_ptr
=
reinterpret_cast
<
float
*>
(
exp_sums
.
data_ptr
());
float
*
max_logits_ptr
=
reinterpret_cast
<
float
*>
(
max_logits
.
data_ptr
());
T
*
tmp_out_ptr
=
reinterpret_cast
<
T
*>
(
tmp_out
.
data_ptr
());
T
*
query_ptr
=
reinterpret_cast
<
T
*>
(
query
.
data_ptr
());
KVT
*
key_cache_ptr
=
reinterpret_cast
<
KVT
*>
(
key_cache
.
data_ptr
());
KVT
*
value_cache_ptr
=
reinterpret_cast
<
KVT
*>
(
value_cache
.
data_ptr
());
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
context_lens_ptr
=
context_lens
.
data_ptr
<
int
>
();
const
float
*
k_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
k_scale
.
data_ptr
());
const
float
*
v_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
v_scale
.
data_ptr
());
// NOTE: fp8_out_scale is optional.
const
auto
fp8_out_scale_ptr
=
fp8_out_scale
?
static_cast
<
const
float
*>
(
fp8_out_scale
.
value
().
data_ptr
())
:
nullptr
;
OUTT
*
out_ptr
=
reinterpret_cast
<
OUTT
*>
(
out
.
data_ptr
());
// for QK wmma16x16, layout is QHead/Tokenx16 across every 16 lanes,
// 32 Bytes HeadElements in each lane, 2x16B HeadElements across a row of warp
constexpr
int
ROWS_PER_WARP
=
WARP_SIZE
/
16
/
2
;
// rows refers to 16 lanes; refer dpp terminology
constexpr
int
CONTIGUOUS_KV_ELEMS_16B_LOAD
=
16
/
sizeof
(
cache_t
);
// 8 for 16 bit cache type, 16 for 8 bit types
constexpr
int
QKHE_PER_FETCH
=
CONTIGUOUS_KV_ELEMS_16B_LOAD
*
ROWS_PER_WARP
;
// each fetch across a warp fetches these many elements
constexpr
int
QKHELOOP
=
HEAD_SIZE
/
QKHE_PER_FETCH
;
// 2xQKHE_16B across
// warp
const
int
max_ctx_blocks
=
DIVIDE_ROUND_UP
(
max_context_len
,
BLOCK_SIZE
);
_B16x16
Qlocal
[
QKHELOOP
/
2
];
// note that 16 contiguous elements of Q should
// be fetched per lane for 16 bit cache types
// partition size is fixed at 256 since both mfma4 and mfma16 kernels support
// it mfma4 kernel also supports partition size 512
constexpr
int
PARTITION_SIZE
=
256
;
const
int
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_context_len
,
PARTITION_SIZE
);
const
int
gqa_ratio
=
num_heads
/
num_kv_heads
;
assert
(
num_heads
%
num_kv_heads
==
0
);
assert
(
head_size
==
HEAD_SIZE
);
constexpr
int
CONTIGUOUS_SCALAR_ELEMS_16B
=
16
/
sizeof
(
scalar_t
);
constexpr
int
NTHR
=
256
;
dim3
grid
(
num_seqs
,
max_num_partitions
,
num_kv_heads
);
dim3
block
(
NTHR
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
constexpr
int
TOKENS_PER_WARP
=
T_PAR_SIZE
/
NWARPS
;
// sub partition of tokens per warp for qk calculation
constexpr
int
TLOOP
=
TOKENS_PER_WARP
/
16
;
// each wmma16x16x16 instruction processes 16 tokens
// mfma4 kernel is faster than mfma16 for gqa_ratio <= 4
switch
(
gqa_ratio
)
{
case
1
:
LAUNCH_CUSTOM_ATTENTION_MFMA4
(
1
);
break
;
case
2
:
LAUNCH_CUSTOM_ATTENTION_MFMA4
(
2
);
break
;
case
3
:
LAUNCH_CUSTOM_ATTENTION_MFMA4
(
3
);
break
;
case
4
:
LAUNCH_CUSTOM_ATTENTION_MFMA4
(
4
);
break
;
case
5
:
LAUNCH_CUSTOM_ATTENTION_MFMA16
(
5
);
break
;
case
6
:
LAUNCH_CUSTOM_ATTENTION_MFMA16
(
6
);
break
;
case
7
:
_B16x16
Klocal
[
TLOOP
]
[
QKHELOOP
/
2
];
// can be interpreted as B8x16 for 8 bit types
const
int
wg_start_head_idx
=
blockIdx
.
z
*
GQA_RATIO
;
const
int
wg_start_kv_head_idx
=
blockIdx
.
z
;
const
int
total_num_heads
=
gridDim
.
z
*
GQA_RATIO
;
// for QK wmma, tokens in multiples of TOKENS_PER_WARP are spread across warps
// each wmma takes QH16xT16x16HE across warp
// repeat wmma across QKHELOOP dimension
// output layout from QKwmma : QH16xT8x2 16 qheads across 16 lanes, 16 tokens
// across 2 rows x 8 tokens per lane
const
int64_t
query_start_off
=
static_cast
<
int64_t
>
(
query_start_loc_ptr
?
query_start_loc_ptr
[
seq_idx
]
:
seq_idx
);
if
(
GQA_RATIO
==
1
)
{
const
int
local_qhead_idx
=
lane16id
%
GQA_RATIO
;
const
int
global_qhead_idx
=
wg_start_head_idx
+
local_qhead_idx
;
const
scalar_t
*
q_ptr
=
q
+
query_start_off
*
q_stride
+
global_qhead_idx
*
HEAD_SIZE
;
if
(
lane16id
<
GQA_RATIO
)
{
#pragma unroll
for
(
int
qkhe_depth
=
0
;
qkhe_depth
<
QKHELOOP
/
2
;
qkhe_depth
++
)
{
const
scalar_t
*
q_fetch_ptr
=
q_ptr
+
qkhe_depth
*
QKHE_PER_FETCH
*
2
;
const
_B16x16
*
q_fetch_ptr_32B
=
reinterpret_cast
<
const
_B16x16
*>
(
q_fetch_ptr
);
Qlocal
[
qkhe_depth
]
=
*
q_fetch_ptr_32B
;
}
}
}
else
{
// fetch Q in shared across warps and then write to registers
const
int
local_qhead_idx
=
2
*
warpid
+
rowid
;
const
int
global_qhead_idx
=
wg_start_head_idx
+
local_qhead_idx
;
const
scalar_t
*
q_ptr
=
q
+
query_start_off
*
q_stride
+
global_qhead_idx
*
HEAD_SIZE
;
const
int
qhead_element
=
lane16id
*
CONTIGUOUS_SCALAR_ELEMS_16B
;
if
((
local_qhead_idx
<
GQA_RATIO
)
&&
(
qhead_element
<
HEAD_SIZE
))
{
const
scalar_t
*
q_fetch_ptr
=
q_ptr
+
qhead_element
;
const
_B16x8
*
q_fetch_ptr_16B
=
reinterpret_cast
<
const
_B16x8
*>
(
q_fetch_ptr
);
_B16x8
tmp
=
*
q_fetch_ptr_16B
;
const
int
offset1
=
lane16id
/
2
;
// 16 contiguous chunks of head elems are spread across 8x2lanes
shared_logits
[
offset1
][
lane2id
][
local_qhead_idx
][
0
].
xy
[
0
]
=
tmp
;
}
__syncthreads
();
#pragma unroll
for
(
int
qkhe_depth
=
0
;
qkhe_depth
<
QKHELOOP
/
2
;
qkhe_depth
++
)
{
Qlocal
[
qkhe_depth
].
xy
[
0
]
=
shared_logits
[
qkhe_depth
][
0
][
lane16id
%
GQA_RATIO
][
0
].
xy
[
0
];
Qlocal
[
qkhe_depth
].
xy
[
1
]
=
shared_logits
[
qkhe_depth
][
1
][
lane16id
%
GQA_RATIO
][
0
].
xy
[
0
];
}
}
const
int
num_context_blocks
=
DIVIDE_ROUND_UP
(
context_len
,
BLOCK_SIZE
);
const
int
last_ctx_block
=
num_context_blocks
-
1
;
const
int
*
block_table_seq
=
block_tables
+
seq_idx
*
max_num_blocks_per_seq
;
int
kphysical_block_number
[
TLOOP
];
// fetch k physical block numbers
for
(
int
token_depth
=
0
;
token_depth
<
TLOOP
;
token_depth
++
)
{
const
int
klocal_token_idx
=
TOKENS_PER_WARP
*
warpid
+
token_depth
*
16
+
lane16id
;
const
int
kglobal_token_idx
=
partition_start_token_idx
+
klocal_token_idx
;
const
int
kblock_idx
=
(
kglobal_token_idx
<
context_len
)
?
kglobal_token_idx
/
BLOCK_SIZE
:
last_ctx_block
;
kphysical_block_number
[
token_depth
]
=
block_table_seq
[
kblock_idx
];
}
constexpr
int
KX
=
16
/
sizeof
(
cache_t
);
const
cache_t
*
k_ptr
=
k_cache
+
wg_start_kv_head_idx
*
kv_head_stride
;
const
int
row_head_elem
=
0
;
for
(
int
token_depth
=
0
;
token_depth
<
TLOOP
;
token_depth
++
)
{
const
int64_t
kblock_number
=
static_cast
<
int64_t
>
(
kphysical_block_number
[
token_depth
]);
const
cache_t
*
k_ptr2
=
k_ptr
+
kblock_number
*
kv_block_stride
;
const
int
klocal_token_idx
=
TOKENS_PER_WARP
*
warpid
+
token_depth
*
16
+
lane16id
;
const
int
kglobal_token_idx
=
partition_start_token_idx
+
klocal_token_idx
;
const
int
kphysical_block_offset
=
klocal_token_idx
%
BLOCK_SIZE
;
const
cache_t
*
k_ptr3
=
k_ptr2
+
kphysical_block_offset
*
KX
;
for
(
int
qkhe_depth
=
0
;
qkhe_depth
<
QKHELOOP
;
qkhe_depth
++
)
{
const
int
head_elem
=
row_head_elem
+
qkhe_depth
*
QKHE_PER_FETCH
;
const
int
offset1
=
head_elem
/
KX
;
const
int
offset2
=
head_elem
%
KX
;
const
cache_t
*
k_fetch_ptr
=
k_ptr3
+
offset1
*
BLOCK_SIZE
*
KX
+
offset2
;
const
_B16x8
*
k_fetch_ptr_16B
=
reinterpret_cast
<
const
_B16x8
*>
(
k_fetch_ptr
);
Klocal
[
token_depth
][
qkhe_depth
/
2
].
xy
[
qkhe_depth
%
2
]
=
*
k_fetch_ptr_16B
;
}
}
constexpr
int
VTOKENS_PER_LANE
=
TOKENS_PER_WARP
/
ROWS_PER_WARP
;
// 32/1 = 32 vtokens per lane
constexpr
int
VBLOCKS_PER_LANE
=
2
;
// assumes block size >=16
constexpr
int
VTLOOP
=
NWARPS
;
// corresponds to tokens across warps
constexpr
int
VTLANELOOP
=
DIVIDE_ROUND_UP
(
VTOKENS_PER_LANE
,
CONTIGUOUS_KV_ELEMS_16B_LOAD
);
// optimized for 16B fetches; assumes
// minimum block size is 16
constexpr
int
VHELOOP
=
DIVIDE_ROUND_UP
(
(
HEAD_SIZE
/
16
),
NWARPS
);
// head_size distributed across warps; each
// wmma instr works on 16 head elements
int
vphysical_block_number
[
VTLOOP
][
VBLOCKS_PER_LANE
];
// fetch v physical block numbers
for
(
int
vtoken_depth
=
0
;
vtoken_depth
<
VTLOOP
;
vtoken_depth
++
)
{
for
(
int
vblock_depth
=
0
;
vblock_depth
<
VBLOCKS_PER_LANE
;
vblock_depth
++
)
{
const
int
vlocal_token_idx
=
vtoken_depth
*
VTOKENS_PER_LANE
*
ROWS_PER_WARP
+
vblock_depth
*
BLOCK_SIZE
;
const
int
vglobal_token_idx
=
partition_start_token_idx
+
vlocal_token_idx
;
const
int
vblock_idx
=
(
vglobal_token_idx
<
context_len
)
?
vglobal_token_idx
/
BLOCK_SIZE
:
last_ctx_block
;
vphysical_block_number
[
vtoken_depth
][
vblock_depth
]
=
block_table_seq
[
vblock_idx
];
}
}
_B16x16
Vlocal
[
VTLOOP
][
VHELOOP
]
[
VTLANELOOP
/
2
];
// this can be interpreted as B8x16 too
const
cache_t
*
v_ptr
=
v_cache
+
wg_start_kv_head_idx
*
kv_head_stride
;
// v fetches are 16head elems across lanes x (16x2) tokens per lane
for
(
int
vhe_depth
=
0
;
vhe_depth
<
VHELOOP
;
vhe_depth
++
)
{
const
int
vhead_elem
=
vhe_depth
*
NWARPS
*
16
+
warpid
*
16
+
lane16id
;
const
cache_t
*
v_ptr2
=
v_ptr
+
vhead_elem
*
BLOCK_SIZE
;
for
(
int
vtoken_depth
=
0
;
vtoken_depth
<
VTLOOP
;
vtoken_depth
++
)
{
for
(
int
vfetch_depth
=
0
;
vfetch_depth
<
VTLANELOOP
;
vfetch_depth
++
)
{
const
int64_t
vblock_number
=
static_cast
<
int64_t
>
(
vphysical_block_number
[
vtoken_depth
]
[
vfetch_depth
/
VBLOCKS_PER_LANE
]);
const
cache_t
*
v_ptr3
=
v_ptr2
+
(
vblock_number
*
kv_block_stride
);
const
cache_t
*
v_fetch_ptr
=
v_ptr3
+
(
vfetch_depth
%
VBLOCKS_PER_LANE
)
*
CONTIGUOUS_KV_ELEMS_16B_LOAD
;
const
_B16x8
*
v_fetch_ptr_16B
=
reinterpret_cast
<
const
_B16x8
*>
(
v_fetch_ptr
);
Vlocal
[
vtoken_depth
][
vhe_depth
][
vfetch_depth
/
2
].
xy
[
vfetch_depth
%
2
]
=
*
v_fetch_ptr_16B
;
}
}
}
floatx8
dout
[
TLOOP
];
// qk wmma
for
(
int
token_depth
=
0
;
token_depth
<
TLOOP
;
token_depth
++
)
{
dout
[
token_depth
]
=
{
0
};
for
(
int
qkhe_depth
=
0
;
qkhe_depth
<
QKHELOOP
/
2
;
qkhe_depth
++
)
{
dout
[
token_depth
]
=
gcn_wmma16x16x16_instr
<
scalar_t
,
0
,
0
,
0
>
(
Klocal
[
token_depth
][
qkhe_depth
].
u16x16
,
Qlocal
[
qkhe_depth
].
u16x16
,
dout
[
token_depth
]);
}
dout
[
token_depth
]
*=
scale
;
}
// calculate qk_max and exp_sum per warp and write to shared memory
float
qk_max
=
-
FLT_MAX
;
float
exp_sum
=
0.0
f
;
const
int
qkout_token_idx
=
partition_start_token_idx
+
TOKENS_PER_WARP
*
warpid
+
rowid
;
for
(
int
token_depth
=
0
;
token_depth
<
TLOOP
;
token_depth
++
)
{
const
int
local_token_idx
=
qkout_token_idx
+
token_depth
*
16
;
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
const
float
tmp
=
(
local_token_idx
+
2
*
i
<
context_len
)
?
dout
[
token_depth
][
i
]
:
-
FLT_MAX
;
qk_max
=
fmaxf
(
qk_max
,
tmp
);
}
}
qk_max
=
fmaxf
(
qk_max
,
__shfl_xor
(
qk_max
,
16
));
for
(
int
token_depth
=
0
;
token_depth
<
TLOOP
;
token_depth
++
)
{
const
int
local_token_idx
=
qkout_token_idx
+
token_depth
*
16
;
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
const
float
tmp
=
(
local_token_idx
+
2
*
i
<
context_len
)
?
__expf
(
dout
[
token_depth
][
i
]
-
qk_max
)
:
0.0
f
;
dout
[
token_depth
][
i
]
=
tmp
;
exp_sum
+=
tmp
;
}
}
exp_sum
+=
__shfl_xor
(
exp_sum
,
16
);
__syncthreads
();
if
(
laneid
<
16
)
{
shared_qk_max
[
warpid
][
lane16id
]
=
qk_max
;
shared_exp_sum
[
warpid
][
lane16id
]
=
exp_sum
;
}
__syncthreads
();
// calculate partition qk_max and exp_sum
float
partition_qk_max
=
-
FLT_MAX
;
float
warp_qk_max_exp
[
NWARPS
];
float
partition_exp_sum
=
0.0
f
;
#pragma unroll
for
(
int
w
=
0
;
w
<
NWARPS
;
w
++
)
{
warp_qk_max_exp
[
w
]
=
shared_qk_max
[
w
][
lane16id
];
partition_qk_max
=
fmaxf
(
partition_qk_max
,
warp_qk_max_exp
[
w
]);
}
for
(
int
w
=
0
;
w
<
NWARPS
;
w
++
)
{
warp_qk_max_exp
[
w
]
=
__expf
(
warp_qk_max_exp
[
w
]
-
partition_qk_max
);
partition_exp_sum
+=
shared_exp_sum
[
w
][
lane16id
]
*
warp_qk_max_exp
[
w
];
}
const
float
inv_sum_scale
=
__fdividef
(
1.
f
,
partition_exp_sum
+
1e-6
f
)
*
warp_qk_max_exp
[
warpid
];
__syncthreads
();
// write logits to shared mem
#pragma unroll
for
(
int
token_depth
=
0
;
token_depth
<
TLOOP
;
token_depth
++
)
{
dout
[
token_depth
]
*=
inv_sum_scale
;
shared_logits
[
warpid
][
token_depth
][
lane16id
][
0
].
xy
[
rowid
]
=
from_floatx8
<
scalar_t
>
(
dout
[
token_depth
]);
}
__syncthreads
();
_B16x8
swp_buf
[
TLOOP
][
2
];
#pragma unroll
for
(
int
token_depth
=
0
;
token_depth
<
TLOOP
;
token_depth
++
)
{
swp_buf
[
token_depth
][
0
]
=
shared_logits
[
warpid
][
token_depth
][
lane16id
][
0
].
xy
[
0
];
swp_buf
[
token_depth
][
1
]
=
shared_logits
[
warpid
][
token_depth
][
lane16id
][
0
].
xy
[
1
];
}
#pragma unroll
for
(
int
token_depth
=
0
;
token_depth
<
TLOOP
;
token_depth
++
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
shared_logits
[
warpid
][
token_depth
][
lane16id
][
0
].
xy
[
rowid
].
u16x8
[
i
]
=
swp_buf
[
token_depth
][
i
%
2
].
u16x8
[
4
*
rowid
+
(
i
/
2
)];
}
}
// write out partition max_logits and exp_sum
if
(
threadIdx
.
x
<
GQA_RATIO
)
{
const
int
qhead_idx
=
lane16id
;
const
int
offset
=
seq_idx
*
total_num_heads
*
max_num_partitions
+
(
wg_start_head_idx
+
qhead_idx
)
*
max_num_partitions
+
partition_idx
;
max_logits
[
offset
]
=
partition_qk_max
;
exp_sums
[
offset
]
=
partition_exp_sum
;
}
__syncthreads
();
_B16x8
outelems
[
VHELOOP
];
// Softmax V wmma
// v layout: 16he across lanes x (16x2) tokens per lane
for
(
int
vhe_depth
=
0
;
vhe_depth
<
VHELOOP
;
vhe_depth
++
)
{
floatx8
tmp_out
=
{
0
};
for
(
int
vtoken_depth
=
0
;
vtoken_depth
<
VTLOOP
;
vtoken_depth
++
)
{
for
(
int
vfetch_depth
=
0
;
vfetch_depth
<
VTLANELOOP
/
2
;
vfetch_depth
++
)
{
const
int
offset
=
vfetch_depth
;
// if output format is 16 qheads across 16 lanes, 16 head elems spread
// across rows
tmp_out
=
gcn_wmma16x16x16_instr
<
scalar_t
,
0
,
0
,
0
>
(
Vlocal
[
vtoken_depth
][
vhe_depth
][
vfetch_depth
].
u16x16
,
shared_logits
[
vtoken_depth
][
offset
][
lane16id
][
0
].
u16x16
,
tmp_out
);
}
}
outelems
[
vhe_depth
]
=
from_floatx8
<
scalar_t
>
(
tmp_out
);
}
__syncthreads
();
#pragma unroll
for
(
int
vhe_depth
=
0
;
vhe_depth
<
VHELOOP
;
vhe_depth
++
)
{
shared_logits
[
warpid
][
vhe_depth
][
lane16id
][
0
].
xy
[
rowid
]
=
outelems
[
vhe_depth
];
// lane16 id head dimension; rowid head element
// dimension
}
__syncthreads
();
#pragma unroll
for
(
int
vhe_depth
=
0
;
vhe_depth
<
VHELOOP
;
vhe_depth
++
)
{
swp_buf
[
vhe_depth
][
0
]
=
shared_logits
[
warpid
][
vhe_depth
][
lane16id
][
0
].
xy
[
0
];
swp_buf
[
vhe_depth
][
1
]
=
shared_logits
[
warpid
][
vhe_depth
][
lane16id
][
0
].
xy
[
1
];
}
#pragma unroll
for
(
int
vhe_depth
=
0
;
vhe_depth
<
VHELOOP
;
vhe_depth
++
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
shared_logits
[
warpid
][
vhe_depth
][
lane16id
][
0
].
xy
[
rowid
].
u16x8
[
i
]
=
swp_buf
[
vhe_depth
][
i
%
2
].
u16x8
[
4
*
rowid
+
(
i
/
2
)];
}
}
__syncthreads
();
// write to tmp_out with coalesced writes after reading from shared mem
if
(
warpid
==
0
)
{
_B16x8
vout
[
GQA_RATIO2
];
// each lane writes out 16Bytes of tmp_out along head elem dimension
const
int
head_elem_idx
=
lane16id
*
8
;
if
(
head_elem_idx
<
HEAD_SIZE
)
{
for
(
int
h
=
0
;
h
<
GQA_RATIO2
;
h
++
)
{
const
int
local_head_idx
=
2
*
h
+
rowid
;
const
int
offset1
=
(
head_elem_idx
/
16
)
%
NWARPS
;
const
int
offset2
=
head_elem_idx
/
16
/
NWARPS
;
const
int
offset3
=
(
head_elem_idx
/
8
)
%
2
;
// num_he % num_row
vout
[
h
]
=
shared_logits
[
offset1
][
offset2
][
local_head_idx
][
0
].
xy
[
offset3
];
}
const
int
hsz_maxp_mult
=
HEAD_SIZE
*
max_num_partitions
;
scalar_t
*
out_ptr
=
out
+
seq_idx
*
total_num_heads
*
hsz_maxp_mult
+
partition_idx
*
HEAD_SIZE
;
for
(
int
h
=
0
;
h
<
GQA_RATIO2
;
h
++
)
{
const
int
local_head_idx
=
2
*
h
+
rowid
;
if
(
local_head_idx
<
GQA_RATIO
)
{
const
int
out_head_idx
=
wg_start_head_idx
+
local_head_idx
;
scalar_t
*
out_ptr2
=
out_ptr
+
out_head_idx
*
hsz_maxp_mult
;
scalar_t
*
out_ptr3
=
out_ptr2
+
head_elem_idx
;
_B16x8
*
out_ptr_B16x8
=
reinterpret_cast
<
_B16x8
*>
(
out_ptr3
);
*
out_ptr_B16x8
=
vout
[
h
];
}
}
}
}
}
template
<
typename
scalar_t
,
typename
cache_t
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
typename
OUTT
,
int
BLOCK_SIZE
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
bool
ALIBI_ENABLED
,
int
GQA_RATIO
>
__global__
__launch_bounds__
(
NUM_THREADS
)
void
paged_attention_ll4mi_QKV_mfma4_kernel
(
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
// head_size, block_size]
const
int
num_kv_heads
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context_lens
,
// [num_seqs]
const
int
*
__restrict__
query_start_loc_ptr
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
// max_num_partitions]
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, max_num_partitions,
// head_size]
OUTT
*
__restrict__
final_out
,
// [num_seqs, num_heads, head_size]
int
max_ctx_blocks
,
const
float
*
k_scale
,
const
float
*
v_scale
)
{
UNREACHABLE_CODE
}
// Grid: (num_heads, num_seqs).
template
<
typename
scalar_t
,
typename
OUTT
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
int
PARTITION_SIZE
,
int
NPAR_LOOPS
>
__global__
__launch_bounds__
(
NUM_THREADS
)
void
paged_attention_ll4mi_reduce_kernel
(
OUTT
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
const
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads,
// max_num_partitions]
const
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
// max_num_partitions]
const
scalar_t
*
__restrict__
tmp_out
,
// [num_seqs, num_heads,
// max_num_partitions, head_size]
const
int
*
__restrict__
context_lens
,
// [num_seqs]
const
int
*
__restrict__
query_start_loc_ptr
,
// [num_seqs]
const
int
max_num_partitions
,
const
float
*
__restrict__
fp8_out_scale_ptr
)
{
const
auto
num_heads
=
gridDim
.
x
;
const
auto
head_idx
=
blockIdx
.
x
;
const
auto
seq_idx
=
blockIdx
.
y
;
// NOTE queries with sequence len > 1 are prefills and taken care by another
// kernel.
if
(
query_start_loc_ptr
!=
nullptr
&&
(
query_start_loc_ptr
[
seq_idx
+
1
]
-
query_start_loc_ptr
[
seq_idx
]
!=
1
))
{
return
;
}
const
int
context_len
=
context_lens
[
seq_idx
];
const
int
num_partitions
=
DIVIDE_ROUND_UP
(
context_len
,
PARTITION_SIZE
);
[[
maybe_unused
]]
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
const
int
warpid
=
threadIdx
.
x
/
WARP_SIZE
;
[[
maybe_unused
]]
const
int
laneid
=
threadIdx
.
x
%
WARP_SIZE
;
__shared__
float
shared_global_exp_sum
;
// max num partitions supported is warp_size * NPAR_LOOPS
__shared__
float
shared_exp_sums
[
NPAR_LOOPS
*
WARP_SIZE
];
if
(
warpid
==
0
)
{
const
float
*
max_logits_ptr
=
max_logits
+
seq_idx
*
num_heads
*
max_num_partitions
+
head_idx
*
max_num_partitions
;
// valid partition is the last valid partition in case threadid > num
// partitions
int
valid_partition
[
NPAR_LOOPS
];
float
reg_max_logit
[
NPAR_LOOPS
];
const
int
last_valid_partition
=
num_partitions
-
1
;
#pragma unroll
for
(
int
i
=
0
;
i
<
NPAR_LOOPS
;
i
++
)
{
const
int
partition_no
=
i
*
WARP_SIZE
+
threadIdx
.
x
;
valid_partition
[
i
]
=
(
partition_no
<
num_partitions
)
?
partition_no
:
last_valid_partition
;
}
#pragma unroll
for
(
int
i
=
0
;
i
<
NPAR_LOOPS
;
i
++
)
{
reg_max_logit
[
i
]
=
max_logits_ptr
[
valid_partition
[
i
]];
}
float
max_logit
=
reg_max_logit
[
0
];
#pragma unroll
for
(
int
i
=
1
;
i
<
NPAR_LOOPS
;
i
++
)
{
max_logit
=
fmaxf
(
max_logit
,
reg_max_logit
[
i
]);
}
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
max_logit
=
fmaxf
(
max_logit
,
__shfl_xor
(
max_logit
,
mask
));
}
const
float
*
exp_sums_ptr
=
exp_sums
+
seq_idx
*
num_heads
*
max_num_partitions
+
head_idx
*
max_num_partitions
;
float
rescaled_exp_sum
[
NPAR_LOOPS
];
#pragma unroll
for
(
int
i
=
0
;
i
<
NPAR_LOOPS
;
i
++
)
{
rescaled_exp_sum
[
i
]
=
exp_sums_ptr
[
valid_partition
[
i
]];
}
#pragma unroll
for
(
int
i
=
0
;
i
<
NPAR_LOOPS
;
i
++
)
{
const
int
partition_no
=
i
*
WARP_SIZE
+
threadIdx
.
x
;
rescaled_exp_sum
[
i
]
*=
(
partition_no
<
num_partitions
)
?
expf
(
reg_max_logit
[
i
]
-
max_logit
)
:
0.0
f
;
}
float
global_exp_sum
=
rescaled_exp_sum
[
0
];
#pragma unroll
for
(
int
i
=
1
;
i
<
NPAR_LOOPS
;
i
++
)
{
global_exp_sum
+=
rescaled_exp_sum
[
i
];
}
#pragma unroll
for
(
int
i
=
0
;
i
<
NPAR_LOOPS
;
i
++
)
{
const
int
partition_no
=
i
*
WARP_SIZE
+
threadIdx
.
x
;
shared_exp_sums
[
partition_no
]
=
rescaled_exp_sum
[
i
];
}
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
global_exp_sum
+=
__shfl_xor
(
global_exp_sum
,
mask
);
}
if
(
threadIdx
.
x
==
0
)
{
shared_global_exp_sum
=
global_exp_sum
;
}
}
// warpid == 0
const
scalar_t
*
tmp_out_ptr
=
tmp_out
+
seq_idx
*
num_heads
*
max_num_partitions
*
HEAD_SIZE
+
head_idx
*
max_num_partitions
*
HEAD_SIZE
+
threadIdx
.
x
;
constexpr
int
MAX_NPAR
=
32
;
scalar_t
tmps
[
MAX_NPAR
];
const
float
dzero
=
0.0
f
;
#pragma unroll
for
(
int
j
=
0
;
j
<
MAX_NPAR
;
j
++
)
{
tmps
[
j
]
=
from_float
<
scalar_t
>
(
dzero
);
}
const
int
last_partition_offset
=
(
num_partitions
-
1
)
*
HEAD_SIZE
;
const
int
num_partition_offset
=
(
num_partitions
)
*
HEAD_SIZE
;
int
idx
=
0
;
constexpr
int
JCHUNK
=
16
;
#pragma unroll
for
(
int
j
=
0
;
j
<
JCHUNK
*
HEAD_SIZE
;
j
+=
HEAD_SIZE
)
{
// lastj is last valid partition
const
int
lastj_offset
=
(
j
<
num_partition_offset
)
?
j
:
last_partition_offset
;
tmps
[
idx
]
=
tmp_out_ptr
[
lastj_offset
];
idx
++
;
}
__syncthreads
();
if
(
num_partitions
>
JCHUNK
)
{
#pragma unroll
for
(
int
j
=
JCHUNK
*
HEAD_SIZE
;
j
<
2
*
JCHUNK
*
HEAD_SIZE
;
j
+=
HEAD_SIZE
)
{
const
int
lastj_offset
=
(
j
<
num_partition_offset
)
?
j
:
last_partition_offset
;
tmps
[
idx
]
=
tmp_out_ptr
[
lastj_offset
];
idx
++
;
}
if
(
num_partitions
>
2
*
JCHUNK
)
{
#pragma unroll
for
(
int
j
=
2
*
JCHUNK
*
HEAD_SIZE
;
j
<
MAX_NPAR
*
HEAD_SIZE
;
j
+=
HEAD_SIZE
)
{
const
int
lastj_offset
=
(
j
<
num_partition_offset
)
?
j
:
last_partition_offset
;
tmps
[
idx
]
=
tmp_out_ptr
[
lastj_offset
];
idx
++
;
}
}
}
// num_partitions > JCHUNK
// Aggregate tmp_out to out.
float
acc
=
0.0
f
;
#pragma unroll
for
(
int
j
=
0
;
j
<
JCHUNK
;
j
++
)
{
acc
+=
to_float
<
scalar_t
>
(
tmps
[
j
])
*
shared_exp_sums
[
j
];
}
if
(
num_partitions
>
JCHUNK
)
{
#pragma unroll
for
(
int
j
=
JCHUNK
;
j
<
2
*
JCHUNK
;
j
++
)
{
acc
+=
to_float
<
scalar_t
>
(
tmps
[
j
])
*
shared_exp_sums
[
j
];
}
if
(
num_partitions
>
2
*
JCHUNK
)
{
#pragma unroll
for
(
int
j
=
2
*
JCHUNK
;
j
<
MAX_NPAR
;
j
++
)
{
acc
+=
to_float
<
scalar_t
>
(
tmps
[
j
])
*
shared_exp_sums
[
j
];
}
}
}
for
(
int
p
=
1
;
p
<
NPAR_LOOPS
;
p
++
)
{
if
(
num_partitions
>
p
*
MAX_NPAR
)
{
idx
=
0
;
#pragma unroll
for
(
int
j
=
p
*
MAX_NPAR
*
HEAD_SIZE
;
j
<
(
p
+
1
)
*
MAX_NPAR
*
HEAD_SIZE
;
j
+=
HEAD_SIZE
)
{
// lastj is last valid partition
const
int
lastj_offset
=
(
j
<
num_partition_offset
)
?
j
:
last_partition_offset
;
tmps
[
idx
]
=
tmp_out_ptr
[
lastj_offset
];
idx
++
;
}
#pragma unroll
for
(
int
j
=
0
;
j
<
MAX_NPAR
;
j
++
)
{
acc
+=
to_float
<
scalar_t
>
(
tmps
[
j
])
*
shared_exp_sums
[
j
+
p
*
MAX_NPAR
];
}
}
}
const
float
inv_global_exp_sum
=
__fdividef
(
1.0
f
,
shared_global_exp_sum
+
1e-6
f
);
acc
*=
inv_global_exp_sum
;
const
int64_t
query_start_off
=
static_cast
<
int64_t
>
(
query_start_loc_ptr
?
query_start_loc_ptr
[
seq_idx
]
:
seq_idx
);
OUTT
*
out_ptr
=
out
+
query_start_off
*
num_heads
*
HEAD_SIZE
+
static_cast
<
int64_t
>
(
head_idx
)
*
HEAD_SIZE
;
out_ptr
[
threadIdx
.
x
]
=
from_float
<
scalar_t
>
(
acc
);
}
#elif defined(__HIP__GFX12__)
using
floatx8
=
__attribute__
((
__vector_size__
(
8
*
sizeof
(
float
))))
float
;
using
bit16_t
=
uint16_t
;
using
bit16x4
=
__attribute__
((
__vector_size__
(
4
*
sizeof
(
uint16_t
))))
uint16_t
;
typedef
bit16x4
_B16x4
;
using
bit16x8
=
__attribute__
((
__vector_size__
(
8
*
sizeof
(
uint16_t
))))
uint16_t
;
union
b16x8_u
{
bit16x8
u16x8
;
_B16x4
xy
[
2
];
};
typedef
b16x8_u
_B16x8
;
using
_B8x8
=
uint2
;
using
bit8_t
=
uint8_t
;
typedef
struct
_B8x16
{
_B8x8
xy
[
2
];
}
_B8x16
;
template
<
typename
T
,
int
absz
,
int
cbid
,
int
blgp
>
__device__
__forceinline__
floatx8
gcn_wmma16x16x16_instr
(
const
bit16x8
&
inpA
,
const
bit16x8
&
inpB
,
const
floatx8
&
inpC
)
{
if
constexpr
(
std
::
is_same
<
T
,
_Float16
>::
value
)
{
return
__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12
(
inpA
,
inpB
,
inpC
);
}
else
if
constexpr
(
std
::
is_same
<
T
,
__hip_bfloat16
>::
value
)
{
return
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12
(
inpA
,
inpB
,
inpC
);
}
else
{
static_assert
(
false
,
"unsupported 16b dtype"
);
}
}
template
<
typename
T
>
__device__
__forceinline__
float
to_float
(
const
T
&
inp
)
{
if
constexpr
(
std
::
is_same
<
T
,
_Float16
>::
value
)
{
return
(
float
)
inp
;
}
else
if
constexpr
(
std
::
is_same
<
T
,
__hip_bfloat16
>::
value
)
{
return
__bfloat162float
(
inp
);
}
else
{
static_assert
(
false
,
"unsupported 16b dtype"
);
}
}
template
<
typename
T
>
__device__
__forceinline__
float
to_float_b16
(
const
bit16_t
&
inp
)
{
union
tmpcvt
{
bit16_t
u
;
_Float16
f
;
__hip_bfloat16
b
;
}
t16
;
t16
.
u
=
inp
;
if
constexpr
(
std
::
is_same
<
T
,
_Float16
>::
value
)
{
return
(
float
)
t16
.
f
;
}
else
if
constexpr
(
std
::
is_same
<
T
,
__hip_bfloat16
>::
value
)
{
return
__bfloat162float
(
t16
.
b
);
}
else
{
static_assert
(
false
,
"unsupported 16b dtype"
);
}
}
template
<
typename
T
>
__device__
__forceinline__
T
from_float
(
const
float
&
inp
)
{
if
constexpr
(
std
::
is_same
<
T
,
_Float16
>::
value
)
{
return
(
_Float16
)
inp
;
}
else
if
constexpr
(
std
::
is_same
<
T
,
__hip_bfloat16
>::
value
)
{
return
__float2bfloat16
(
inp
);
}
else
{
static_assert
(
false
,
"unsupported 16b dtype"
);
}
}
template
<
typename
T
>
__device__
__forceinline__
_B16x8
from_floatx8
(
const
floatx8
&
inp
)
{
if
constexpr
(
std
::
is_same
<
T
,
_Float16
>::
value
)
{
union
h2cvt
{
__half2
h2
[
4
];
_B16x8
b16x8
;
}
u
;
u
.
h2
[
0
]
=
__float22half2_rn
(
make_float2
(
inp
[
0
],
inp
[
1
]));
u
.
h2
[
1
]
=
__float22half2_rn
(
make_float2
(
inp
[
2
],
inp
[
3
]));
u
.
h2
[
2
]
=
__float22half2_rn
(
make_float2
(
inp
[
4
],
inp
[
5
]));
u
.
h2
[
3
]
=
__float22half2_rn
(
make_float2
(
inp
[
6
],
inp
[
7
]));
return
u
.
b16x8
;
}
else
if
constexpr
(
std
::
is_same
<
T
,
__hip_bfloat16
>::
value
)
{
union
b2cvt
{
__hip_bfloat162
b2
[
4
];
_B16x8
b16x8
;
}
u
;
u
.
b2
[
0
]
=
__float22bfloat162_rn
(
make_float2
(
inp
[
0
],
inp
[
1
]));
u
.
b2
[
1
]
=
__float22bfloat162_rn
(
make_float2
(
inp
[
2
],
inp
[
3
]));
u
.
b2
[
2
]
=
__float22bfloat162_rn
(
make_float2
(
inp
[
4
],
inp
[
5
]));
u
.
b2
[
3
]
=
__float22bfloat162_rn
(
make_float2
(
inp
[
6
],
inp
[
7
]));
return
u
.
b16x8
;
}
else
{
static_assert
(
false
,
"unsupported 16b dtype"
);
}
}
// clang-format off
template
<
typename
scalar_t
,
typename
cache_t
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
typename
OUTT
,
int
BLOCK_SIZE
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
bool
ALIBI_ENABLED
,
int
GQA_RATIO
>
__global__
__launch_bounds__
(
NUM_THREADS
,
3
)
void
paged_attention_ll4mi_QKV_mfma16_kernel
(
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
// head_size, block_size]
const
int
num_kv_heads
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context_lens
,
// [num_seqs]
const
int
*
__restrict__
query_start_loc_ptr
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
// max_num_partitions]
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, max_num_partitions,
// head_size]
OUTT
*
__restrict__
final_out
,
// [num_seqs, num_heads, head_size]
int
max_ctx_blocks
,
const
float
*
k_scale
,
const
float
*
v_scale
)
{
// clang-format on
constexpr
int
NWARPS
=
NUM_THREADS
/
WARP_SIZE
;
// 8 warps on gfx11
const
int
warpid
=
threadIdx
.
x
/
WARP_SIZE
;
const
int
laneid
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
lane2id
=
laneid
%
2
;
const
int
lane4id
=
laneid
%
4
;
const
int
lane16id
=
laneid
%
16
;
const
int
rowid
=
laneid
/
16
;
const
int
seq_idx
=
blockIdx
.
x
;
// NOTE queries with sequence len > 1 are prefills and taken care by another
// kernel.
if
(
query_start_loc_ptr
!=
nullptr
&&
(
query_start_loc_ptr
[
seq_idx
+
1
]
-
query_start_loc_ptr
[
seq_idx
]
!=
1
))
{
return
;
}
const
int
partition_idx
=
blockIdx
.
y
;
constexpr
int
T_PAR_SIZE
=
256
;
// token partition size set to 256
const
int
max_num_partitions
=
gridDim
.
y
;
const
int
context_len
=
context_lens
[
seq_idx
];
// length of a seq
const
int
partition_start_token_idx
=
partition_idx
*
T_PAR_SIZE
;
// exit if partition is out of context for seq
if
(
partition_start_token_idx
>=
context_len
)
{
return
;
}
constexpr
int
GQA_RATIO2
=
DIVIDE_ROUND_UP
(
GQA_RATIO
,
2
);
__shared__
float
shared_qk_max
[
NWARPS
][
16
+
1
];
__shared__
float
shared_exp_sum
[
NWARPS
][
16
+
1
];
// shared_logits is used for multiple purposes
__shared__
_B16x8
shared_logits
[
NWARPS
][
2
][
16
][
2
];
// for QK wmma16x16_gfx12, layout is QHead/Tokenx16 across every 16 lanes,
// 16 Bytes HeadElements in each lane, 2x16B HeadElements across 2 rows of
// warp
constexpr
int
ROWS_PER_WARP
=
WARP_SIZE
/
16
;
// rows refers to 16 lanes; refer dpp terminology
constexpr
int
CONTIGUOUS_KV_ELEMS_16B_LOAD
=
16
/
sizeof
(
cache_t
);
// 8 for 16 bit cache type, 16 for 8 bit types
constexpr
int
QKHE_PER_FETCH
=
CONTIGUOUS_KV_ELEMS_16B_LOAD
*
ROWS_PER_WARP
;
// each fetch across a warp fetches these many elements
constexpr
int
QKHELOOP
=
HEAD_SIZE
/
QKHE_PER_FETCH
;
// 2xQKHE_16B across
// warp
_B16x8
Qlocal
[
QKHELOOP
];
// note that 16 contiguous elements of Q should
// be fetched per lane for 16 bit cache types
constexpr
int
CONTIGUOUS_SCALAR_ELEMS_16B
=
16
/
sizeof
(
scalar_t
);
constexpr
int
TOKENS_PER_WARP
=
T_PAR_SIZE
/
NWARPS
;
// sub partition of tokens per warp for qk calculation
constexpr
int
TLOOP
=
TOKENS_PER_WARP
/
16
;
// each wmma16x16x16 instruction processes 16 tokens
_B16x8
Klocal
[
TLOOP
]
[
QKHELOOP
];
// can be interpreted as B8x16 for 8 bit types
const
int
wg_start_head_idx
=
blockIdx
.
z
*
GQA_RATIO
;
const
int
wg_start_kv_head_idx
=
blockIdx
.
z
;
const
int
total_num_heads
=
gridDim
.
z
*
GQA_RATIO
;
// for QK wmma, tokens in multiples of TOKENS_PER_WARP are spread across warps
// each wmma takes QH16xT16x16HE across warp
// repeat wmma across QKHELOOP dimension
// output layout from QKwmma : QH16xT8x2 16 qheads across 16 lanes, 16 tokens
// across 2 rows x 8 tokens per lane
const
int64_t
query_start_off
=
static_cast
<
int64_t
>
(
query_start_loc_ptr
?
query_start_loc_ptr
[
seq_idx
]
:
seq_idx
);
if
(
GQA_RATIO
==
1
)
{
const
int
local_qhead_idx
=
lane16id
%
GQA_RATIO
;
const
int
global_qhead_idx
=
wg_start_head_idx
+
local_qhead_idx
;
const
scalar_t
*
q_ptr
=
q
+
query_start_off
*
q_stride
+
global_qhead_idx
*
HEAD_SIZE
+
rowid
*
CONTIGUOUS_KV_ELEMS_16B_LOAD
;
if
(
lane16id
<
GQA_RATIO
)
{
#pragma unroll
for
(
int
qkhe_depth
=
0
;
qkhe_depth
<
QKHELOOP
;
qkhe_depth
++
)
{
const
scalar_t
*
q_fetch_ptr
=
q_ptr
+
qkhe_depth
*
QKHE_PER_FETCH
;
const
_B16x8
*
q_fetch_ptr_16B
=
reinterpret_cast
<
const
_B16x8
*>
(
q_fetch_ptr
);
Qlocal
[
qkhe_depth
]
=
*
q_fetch_ptr_16B
;
}
}
}
else
{
// fetch Q in shared across warps and then write to registers
const
int
local_qhead_idx
=
2
*
warpid
+
rowid
;
const
int
global_qhead_idx
=
wg_start_head_idx
+
local_qhead_idx
;
const
scalar_t
*
q_ptr
=
q
+
query_start_off
*
q_stride
+
global_qhead_idx
*
HEAD_SIZE
;
const
int
qhead_element
=
lane16id
*
CONTIGUOUS_SCALAR_ELEMS_16B
;
if
((
local_qhead_idx
<
GQA_RATIO
)
&&
(
qhead_element
<
HEAD_SIZE
))
{
const
scalar_t
*
q_fetch_ptr
=
q_ptr
+
qhead_element
;
const
_B16x8
*
q_fetch_ptr_16B
=
reinterpret_cast
<
const
_B16x8
*>
(
q_fetch_ptr
);
_B16x8
tmp
=
*
q_fetch_ptr_16B
;
const
int
offset1
=
lane16id
/
2
;
// 16 contiguous chunks of head elems are spread across 8x2lanes
shared_logits
[
offset1
][
lane2id
][
local_qhead_idx
][
0
]
=
tmp
;
}
__syncthreads
();
#pragma unroll
for
(
int
qkhe_depth
=
0
;
qkhe_depth
<
QKHELOOP
;
qkhe_depth
++
)
{
Qlocal
[
qkhe_depth
]
=
shared_logits
[
qkhe_depth
][
rowid
][
lane16id
%
GQA_RATIO
][
0
];
}
}
const
int
num_context_blocks
=
DIVIDE_ROUND_UP
(
context_len
,
BLOCK_SIZE
);
const
int
last_ctx_block
=
num_context_blocks
-
1
;
const
int
*
block_table_seq
=
block_tables
+
seq_idx
*
max_num_blocks_per_seq
;
int
kphysical_block_number
[
TLOOP
];
// fetch k physical block numbers
for
(
int
token_depth
=
0
;
token_depth
<
TLOOP
;
token_depth
++
)
{
const
int
klocal_token_idx
=
TOKENS_PER_WARP
*
warpid
+
token_depth
*
16
+
lane16id
;
const
int
kglobal_token_idx
=
partition_start_token_idx
+
klocal_token_idx
;
const
int
kblock_idx
=
(
kglobal_token_idx
<
context_len
)
?
kglobal_token_idx
/
BLOCK_SIZE
:
last_ctx_block
;
kphysical_block_number
[
token_depth
]
=
block_table_seq
[
kblock_idx
];
}
constexpr
int
KX
=
16
/
sizeof
(
cache_t
);
const
cache_t
*
k_ptr
=
k_cache
+
wg_start_kv_head_idx
*
kv_head_stride
;
const
int
row_head_elem
=
rowid
*
CONTIGUOUS_KV_ELEMS_16B_LOAD
;
for
(
int
token_depth
=
0
;
token_depth
<
TLOOP
;
token_depth
++
)
{
const
int64_t
kblock_number
=
static_cast
<
int64_t
>
(
kphysical_block_number
[
token_depth
]);
const
cache_t
*
k_ptr2
=
k_ptr
+
kblock_number
*
kv_block_stride
;
const
int
klocal_token_idx
=
TOKENS_PER_WARP
*
warpid
+
token_depth
*
16
+
lane16id
;
const
int
kglobal_token_idx
=
partition_start_token_idx
+
klocal_token_idx
;
const
int
kphysical_block_offset
=
klocal_token_idx
%
BLOCK_SIZE
;
const
cache_t
*
k_ptr3
=
k_ptr2
+
kphysical_block_offset
*
KX
;
for
(
int
qkhe_depth
=
0
;
qkhe_depth
<
QKHELOOP
;
qkhe_depth
++
)
{
const
int
head_elem
=
row_head_elem
+
qkhe_depth
*
QKHE_PER_FETCH
;
const
int
offset1
=
head_elem
/
KX
;
const
int
offset2
=
head_elem
%
KX
;
const
cache_t
*
k_fetch_ptr
=
k_ptr3
+
offset1
*
BLOCK_SIZE
*
KX
+
offset2
;
const
_B16x8
*
k_fetch_ptr_16B
=
reinterpret_cast
<
const
_B16x8
*>
(
k_fetch_ptr
);
Klocal
[
token_depth
][
qkhe_depth
]
=
*
k_fetch_ptr_16B
;
}
}
constexpr
int
VTOKENS_PER_LANE
=
TOKENS_PER_WARP
/
ROWS_PER_WARP
;
// 32/2 = 16 vtokens per lane
constexpr
int
VBLOCKS_PER_LANE
=
1
;
// assumes block size >=16
constexpr
int
VTLOOP
=
NWARPS
;
// corresponds to tokens across warps
constexpr
int
VTLANELOOP
=
DIVIDE_ROUND_UP
(
VTOKENS_PER_LANE
,
CONTIGUOUS_KV_ELEMS_16B_LOAD
);
// optimized for 16B fetches; assumes
// minimum block size is 16
constexpr
int
VHELOOP
=
DIVIDE_ROUND_UP
(
(
HEAD_SIZE
/
16
),
NWARPS
);
// head_size distributed across warps; each
// wmma instr works on 16 head elements
int
vphysical_block_number
[
VTLOOP
][
VBLOCKS_PER_LANE
];
// fetch v physical block numbers
for
(
int
vtoken_depth
=
0
;
vtoken_depth
<
VTLOOP
;
vtoken_depth
++
)
{
for
(
int
vblock_depth
=
0
;
vblock_depth
<
VBLOCKS_PER_LANE
;
vblock_depth
++
)
{
const
int
vlocal_token_idx
=
vtoken_depth
*
VTOKENS_PER_LANE
*
ROWS_PER_WARP
+
rowid
*
VTOKENS_PER_LANE
+
vblock_depth
*
BLOCK_SIZE
;
const
int
vglobal_token_idx
=
partition_start_token_idx
+
vlocal_token_idx
;
const
int
vblock_idx
=
(
vglobal_token_idx
<
context_len
)
?
vglobal_token_idx
/
BLOCK_SIZE
:
last_ctx_block
;
vphysical_block_number
[
vtoken_depth
][
vblock_depth
]
=
block_table_seq
[
vblock_idx
];
}
}
_B16x8
Vlocal
[
VTLOOP
][
VHELOOP
]
[
VTLANELOOP
];
// this can be interpreted as B8x16 too
const
cache_t
*
v_ptr
=
v_cache
+
wg_start_kv_head_idx
*
kv_head_stride
+
((
rowid
*
VTOKENS_PER_LANE
)
%
BLOCK_SIZE
);
// v fetches are 16head elems across lanes x 16 tokens per lane
for
(
int
vhe_depth
=
0
;
vhe_depth
<
VHELOOP
;
vhe_depth
++
)
{
const
int
vhead_elem
=
vhe_depth
*
NWARPS
*
16
+
warpid
*
16
+
lane16id
;
const
cache_t
*
v_ptr2
=
v_ptr
+
vhead_elem
*
BLOCK_SIZE
;
for
(
int
vtoken_depth
=
0
;
vtoken_depth
<
VTLOOP
;
vtoken_depth
++
)
{
for
(
int
vfetch_depth
=
0
;
vfetch_depth
<
VTLANELOOP
;
vfetch_depth
++
)
{
const
int
vblock_depth
=
0
;
const
int64_t
vblock_number
=
static_cast
<
int64_t
>
(
vphysical_block_number
[
vtoken_depth
][
vblock_depth
]);
const
cache_t
*
v_ptr3
=
v_ptr2
+
(
vblock_number
*
kv_block_stride
);
const
cache_t
*
v_fetch_ptr
=
v_ptr3
+
vfetch_depth
*
CONTIGUOUS_KV_ELEMS_16B_LOAD
;
const
_B16x8
*
v_fetch_ptr_16B
=
reinterpret_cast
<
const
_B16x8
*>
(
v_fetch_ptr
);
Vlocal
[
vtoken_depth
][
vhe_depth
][
vfetch_depth
]
=
*
v_fetch_ptr_16B
;
}
}
}
floatx8
dout
[
TLOOP
];
// qk wmma
for
(
int
token_depth
=
0
;
token_depth
<
TLOOP
;
token_depth
++
)
{
dout
[
token_depth
]
=
{
0
};
for
(
int
qkhe_depth
=
0
;
qkhe_depth
<
QKHELOOP
;
qkhe_depth
++
)
{
dout
[
token_depth
]
=
gcn_wmma16x16x16_instr
<
scalar_t
,
0
,
0
,
0
>
(
Klocal
[
token_depth
][
qkhe_depth
].
u16x8
,
Qlocal
[
qkhe_depth
].
u16x8
,
dout
[
token_depth
]);
}
dout
[
token_depth
]
*=
scale
;
}
// calculate qk_max and exp_sum per warp and write to shared memory
float
qk_max
=
-
FLT_MAX
;
float
exp_sum
=
0.0
f
;
const
int
qkout_token_idx
=
partition_start_token_idx
+
TOKENS_PER_WARP
*
warpid
+
rowid
*
8
;
for
(
int
token_depth
=
0
;
token_depth
<
TLOOP
;
token_depth
++
)
{
const
int
local_token_idx
=
qkout_token_idx
+
token_depth
*
16
;
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
const
float
tmp
=
(
local_token_idx
+
i
<
context_len
)
?
dout
[
token_depth
][
i
]
:
-
FLT_MAX
;
qk_max
=
fmaxf
(
qk_max
,
tmp
);
}
}
qk_max
=
fmaxf
(
qk_max
,
__shfl_xor
(
qk_max
,
16
));
for
(
int
token_depth
=
0
;
token_depth
<
TLOOP
;
token_depth
++
)
{
const
int
local_token_idx
=
qkout_token_idx
+
token_depth
*
16
;
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
const
float
tmp
=
(
local_token_idx
+
i
<
context_len
)
?
__expf
(
dout
[
token_depth
][
i
]
-
qk_max
)
:
0.0
f
;
dout
[
token_depth
][
i
]
=
tmp
;
exp_sum
+=
tmp
;
}
}
exp_sum
+=
__shfl_xor
(
exp_sum
,
16
);
__syncthreads
();
if
(
laneid
<
16
)
{
shared_qk_max
[
warpid
][
lane16id
]
=
qk_max
;
shared_exp_sum
[
warpid
][
lane16id
]
=
exp_sum
;
}
__syncthreads
();
// calculate partition qk_max and exp_sum
float
partition_qk_max
=
-
FLT_MAX
;
float
warp_qk_max_exp
[
NWARPS
];
float
partition_exp_sum
=
0.0
f
;
#pragma unroll
for
(
int
w
=
0
;
w
<
NWARPS
;
w
++
)
{
warp_qk_max_exp
[
w
]
=
shared_qk_max
[
w
][
lane16id
];
partition_qk_max
=
fmaxf
(
partition_qk_max
,
warp_qk_max_exp
[
w
]);
}
for
(
int
w
=
0
;
w
<
NWARPS
;
w
++
)
{
warp_qk_max_exp
[
w
]
=
__expf
(
warp_qk_max_exp
[
w
]
-
partition_qk_max
);
partition_exp_sum
+=
shared_exp_sum
[
w
][
lane16id
]
*
warp_qk_max_exp
[
w
];
}
const
float
inv_sum_scale
=
__fdividef
(
1.
f
,
partition_exp_sum
+
1e-6
f
)
*
warp_qk_max_exp
[
warpid
];
__syncthreads
();
// write logits to shared mem
#pragma unroll
for
(
int
token_depth
=
0
;
token_depth
<
TLOOP
;
token_depth
++
)
{
dout
[
token_depth
]
*=
inv_sum_scale
;
shared_logits
[
warpid
][
token_depth
][
lane16id
][
rowid
]
=
from_floatx8
<
scalar_t
>
(
dout
[
token_depth
]);
}
// write out partition max_logits and exp_sum
if
(
threadIdx
.
x
<
GQA_RATIO
)
{
const
int
qhead_idx
=
lane16id
;
const
int
offset
=
seq_idx
*
total_num_heads
*
max_num_partitions
+
(
wg_start_head_idx
+
qhead_idx
)
*
max_num_partitions
+
partition_idx
;
max_logits
[
offset
]
=
partition_qk_max
;
exp_sums
[
offset
]
=
partition_exp_sum
;
}
__syncthreads
();
_B16x8
outelems
[
VHELOOP
];
// Softmax V wmma
// v layout: 16he across lanes x 16 tokens per lane
for
(
int
vhe_depth
=
0
;
vhe_depth
<
VHELOOP
;
vhe_depth
++
)
{
floatx8
tmp_out
=
{
0
};
for
(
int
vtoken_depth
=
0
;
vtoken_depth
<
VTLOOP
;
vtoken_depth
++
)
{
for
(
int
vfetch_depth
=
0
;
vfetch_depth
<
VTLANELOOP
;
vfetch_depth
++
)
{
const
int
offset
=
rowid
*
VTLANELOOP
+
vfetch_depth
;
const
int
offset1
=
offset
%
ROWS_PER_WARP
;
const
int
offset2
=
offset
/
ROWS_PER_WARP
;
// if output format is 16 qheads across 16 lanes, 16 head elems spread
// across rows
tmp_out
=
gcn_wmma16x16x16_instr
<
scalar_t
,
0
,
0
,
0
>
(
Vlocal
[
vtoken_depth
][
vhe_depth
][
vfetch_depth
].
u16x8
,
shared_logits
[
vtoken_depth
][
offset2
][
lane16id
][
offset1
].
u16x8
,
tmp_out
);
}
}
outelems
[
vhe_depth
]
=
from_floatx8
<
scalar_t
>
(
tmp_out
);
}
__syncthreads
();
#pragma unroll
for
(
int
vhe_depth
=
0
;
vhe_depth
<
VHELOOP
;
vhe_depth
++
)
{
shared_logits
[
warpid
][
vhe_depth
][
lane16id
][
rowid
]
=
outelems
[
vhe_depth
];
// lane16 id head dimension; rowid head element
// dimension
}
__syncthreads
();
// write to tmp_out with coalesced writes after reading from shared mem
if
(
warpid
==
0
)
{
_B16x8
vout
[
GQA_RATIO2
];
// each lane writes out 16Bytes of tmp_out along head elem dimension
const
int
head_elem_idx
=
lane16id
*
8
;
if
(
head_elem_idx
<
HEAD_SIZE
)
{
for
(
int
h
=
0
;
h
<
GQA_RATIO2
;
h
++
)
{
const
int
local_head_idx
=
2
*
h
+
rowid
;
const
int
offset1
=
(
head_elem_idx
/
16
)
%
NWARPS
;
const
int
offset2
=
head_elem_idx
/
16
/
NWARPS
;
const
int
offset3
=
(
head_elem_idx
/
8
)
%
2
;
// num_he % num_row
vout
[
h
]
=
shared_logits
[
offset1
][
offset2
][
local_head_idx
][
offset3
];
}
const
int
hsz_maxp_mult
=
HEAD_SIZE
*
max_num_partitions
;
scalar_t
*
out_ptr
=
out
+
seq_idx
*
total_num_heads
*
hsz_maxp_mult
+
partition_idx
*
HEAD_SIZE
;
for
(
int
h
=
0
;
h
<
GQA_RATIO2
;
h
++
)
{
const
int
local_head_idx
=
2
*
h
+
rowid
;
if
(
local_head_idx
<
GQA_RATIO
)
{
const
int
out_head_idx
=
wg_start_head_idx
+
local_head_idx
;
scalar_t
*
out_ptr2
=
out_ptr
+
out_head_idx
*
hsz_maxp_mult
;
scalar_t
*
out_ptr3
=
out_ptr2
+
head_elem_idx
;
_B16x8
*
out_ptr_B16x8
=
reinterpret_cast
<
_B16x8
*>
(
out_ptr3
);
*
out_ptr_B16x8
=
vout
[
h
];
}
}
}
}
}
template
<
typename
scalar_t
,
typename
cache_t
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
typename
OUTT
,
int
BLOCK_SIZE
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
bool
ALIBI_ENABLED
,
int
GQA_RATIO
>
__global__
__launch_bounds__
(
NUM_THREADS
)
void
paged_attention_ll4mi_QKV_mfma4_kernel
(
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
// head_size, block_size]
const
int
num_kv_heads
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context_lens
,
// [num_seqs]
const
int
*
__restrict__
query_start_loc_ptr
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
// max_num_partitions]
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, max_num_partitions,
// head_size]
OUTT
*
__restrict__
final_out
,
// [num_seqs, num_heads, head_size]
int
max_ctx_blocks
,
const
float
*
k_scale
,
const
float
*
v_scale
)
{
UNREACHABLE_CODE
}
// Grid: (num_heads, num_seqs).
template
<
typename
scalar_t
,
typename
OUTT
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
int
PARTITION_SIZE
,
int
NPAR_LOOPS
>
__global__
__launch_bounds__
(
NUM_THREADS
)
void
paged_attention_ll4mi_reduce_kernel
(
OUTT
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
const
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads,
// max_num_partitions]
const
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
// max_num_partitions]
const
scalar_t
*
__restrict__
tmp_out
,
// [num_seqs, num_heads,
// max_num_partitions, head_size]
const
int
*
__restrict__
context_lens
,
// [num_seqs]
const
int
*
__restrict__
query_start_loc_ptr
,
// [num_seqs]
const
int
max_num_partitions
,
const
float
*
__restrict__
fp8_out_scale_ptr
)
{
const
auto
num_heads
=
gridDim
.
x
;
const
auto
head_idx
=
blockIdx
.
x
;
const
auto
seq_idx
=
blockIdx
.
y
;
// NOTE queries with sequence len > 1 are prefills and taken care by another
// kernel.
if
(
query_start_loc_ptr
!=
nullptr
&&
(
query_start_loc_ptr
[
seq_idx
+
1
]
-
query_start_loc_ptr
[
seq_idx
]
!=
1
))
{
return
;
}
const
int
context_len
=
context_lens
[
seq_idx
];
const
int
num_partitions
=
DIVIDE_ROUND_UP
(
context_len
,
PARTITION_SIZE
);
[[
maybe_unused
]]
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
const
int
warpid
=
threadIdx
.
x
/
WARP_SIZE
;
[[
maybe_unused
]]
const
int
laneid
=
threadIdx
.
x
%
WARP_SIZE
;
__shared__
float
shared_global_exp_sum
;
// max num partitions supported is warp_size * NPAR_LOOPS
__shared__
float
shared_exp_sums
[
NPAR_LOOPS
*
WARP_SIZE
];
if
(
warpid
==
0
)
{
const
float
*
max_logits_ptr
=
max_logits
+
seq_idx
*
num_heads
*
max_num_partitions
+
head_idx
*
max_num_partitions
;
// valid partition is the last valid partition in case threadid > num
// partitions
int
valid_partition
[
NPAR_LOOPS
];
float
reg_max_logit
[
NPAR_LOOPS
];
const
int
last_valid_partition
=
num_partitions
-
1
;
#pragma unroll
for
(
int
i
=
0
;
i
<
NPAR_LOOPS
;
i
++
)
{
const
int
partition_no
=
i
*
WARP_SIZE
+
threadIdx
.
x
;
valid_partition
[
i
]
=
(
partition_no
<
num_partitions
)
?
partition_no
:
last_valid_partition
;
}
#pragma unroll
for
(
int
i
=
0
;
i
<
NPAR_LOOPS
;
i
++
)
{
reg_max_logit
[
i
]
=
max_logits_ptr
[
valid_partition
[
i
]];
}
float
max_logit
=
reg_max_logit
[
0
];
#pragma unroll
for
(
int
i
=
1
;
i
<
NPAR_LOOPS
;
i
++
)
{
max_logit
=
fmaxf
(
max_logit
,
reg_max_logit
[
i
]);
}
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
max_logit
=
fmaxf
(
max_logit
,
__shfl_xor
(
max_logit
,
mask
));
}
const
float
*
exp_sums_ptr
=
exp_sums
+
seq_idx
*
num_heads
*
max_num_partitions
+
head_idx
*
max_num_partitions
;
float
rescaled_exp_sum
[
NPAR_LOOPS
];
#pragma unroll
for
(
int
i
=
0
;
i
<
NPAR_LOOPS
;
i
++
)
{
rescaled_exp_sum
[
i
]
=
exp_sums_ptr
[
valid_partition
[
i
]];
}
#pragma unroll
for
(
int
i
=
0
;
i
<
NPAR_LOOPS
;
i
++
)
{
const
int
partition_no
=
i
*
WARP_SIZE
+
threadIdx
.
x
;
rescaled_exp_sum
[
i
]
*=
(
partition_no
<
num_partitions
)
?
expf
(
reg_max_logit
[
i
]
-
max_logit
)
:
0.0
f
;
}
float
global_exp_sum
=
rescaled_exp_sum
[
0
];
#pragma unroll
for
(
int
i
=
1
;
i
<
NPAR_LOOPS
;
i
++
)
{
global_exp_sum
+=
rescaled_exp_sum
[
i
];
}
#pragma unroll
for
(
int
i
=
0
;
i
<
NPAR_LOOPS
;
i
++
)
{
const
int
partition_no
=
i
*
WARP_SIZE
+
threadIdx
.
x
;
shared_exp_sums
[
partition_no
]
=
rescaled_exp_sum
[
i
];
}
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
global_exp_sum
+=
__shfl_xor
(
global_exp_sum
,
mask
);
}
if
(
threadIdx
.
x
==
0
)
{
shared_global_exp_sum
=
global_exp_sum
;
}
}
// warpid == 0
const
scalar_t
*
tmp_out_ptr
=
tmp_out
+
seq_idx
*
num_heads
*
max_num_partitions
*
HEAD_SIZE
+
head_idx
*
max_num_partitions
*
HEAD_SIZE
+
threadIdx
.
x
;
constexpr
int
MAX_NPAR
=
32
;
scalar_t
tmps
[
MAX_NPAR
];
const
float
dzero
=
0.0
f
;
#pragma unroll
for
(
int
j
=
0
;
j
<
MAX_NPAR
;
j
++
)
{
tmps
[
j
]
=
from_float
<
scalar_t
>
(
dzero
);
}
const
int
last_partition_offset
=
(
num_partitions
-
1
)
*
HEAD_SIZE
;
const
int
num_partition_offset
=
(
num_partitions
)
*
HEAD_SIZE
;
int
idx
=
0
;
constexpr
int
JCHUNK
=
16
;
#pragma unroll
for
(
int
j
=
0
;
j
<
JCHUNK
*
HEAD_SIZE
;
j
+=
HEAD_SIZE
)
{
// lastj is last valid partition
const
int
lastj_offset
=
(
j
<
num_partition_offset
)
?
j
:
last_partition_offset
;
tmps
[
idx
]
=
tmp_out_ptr
[
lastj_offset
];
idx
++
;
}
__syncthreads
();
if
(
num_partitions
>
JCHUNK
)
{
#pragma unroll
for
(
int
j
=
JCHUNK
*
HEAD_SIZE
;
j
<
2
*
JCHUNK
*
HEAD_SIZE
;
j
+=
HEAD_SIZE
)
{
const
int
lastj_offset
=
(
j
<
num_partition_offset
)
?
j
:
last_partition_offset
;
tmps
[
idx
]
=
tmp_out_ptr
[
lastj_offset
];
idx
++
;
}
if
(
num_partitions
>
2
*
JCHUNK
)
{
#pragma unroll
for
(
int
j
=
2
*
JCHUNK
*
HEAD_SIZE
;
j
<
MAX_NPAR
*
HEAD_SIZE
;
j
+=
HEAD_SIZE
)
{
const
int
lastj_offset
=
(
j
<
num_partition_offset
)
?
j
:
last_partition_offset
;
tmps
[
idx
]
=
tmp_out_ptr
[
lastj_offset
];
idx
++
;
}
}
}
// num_partitions > JCHUNK
// Aggregate tmp_out to out.
float
acc
=
0.0
f
;
#pragma unroll
for
(
int
j
=
0
;
j
<
JCHUNK
;
j
++
)
{
acc
+=
to_float
<
scalar_t
>
(
tmps
[
j
])
*
shared_exp_sums
[
j
];
}
if
(
num_partitions
>
JCHUNK
)
{
#pragma unroll
for
(
int
j
=
JCHUNK
;
j
<
2
*
JCHUNK
;
j
++
)
{
acc
+=
to_float
<
scalar_t
>
(
tmps
[
j
])
*
shared_exp_sums
[
j
];
}
if
(
num_partitions
>
2
*
JCHUNK
)
{
#pragma unroll
for
(
int
j
=
2
*
JCHUNK
;
j
<
MAX_NPAR
;
j
++
)
{
acc
+=
to_float
<
scalar_t
>
(
tmps
[
j
])
*
shared_exp_sums
[
j
];
}
}
}
for
(
int
p
=
1
;
p
<
NPAR_LOOPS
;
p
++
)
{
if
(
num_partitions
>
p
*
MAX_NPAR
)
{
idx
=
0
;
#pragma unroll
for
(
int
j
=
p
*
MAX_NPAR
*
HEAD_SIZE
;
j
<
(
p
+
1
)
*
MAX_NPAR
*
HEAD_SIZE
;
j
+=
HEAD_SIZE
)
{
// lastj is last valid partition
const
int
lastj_offset
=
(
j
<
num_partition_offset
)
?
j
:
last_partition_offset
;
tmps
[
idx
]
=
tmp_out_ptr
[
lastj_offset
];
idx
++
;
}
#pragma unroll
for
(
int
j
=
0
;
j
<
MAX_NPAR
;
j
++
)
{
acc
+=
to_float
<
scalar_t
>
(
tmps
[
j
])
*
shared_exp_sums
[
j
+
p
*
MAX_NPAR
];
}
}
}
const
float
inv_global_exp_sum
=
__fdividef
(
1.0
f
,
shared_global_exp_sum
+
1e-6
f
);
acc
*=
inv_global_exp_sum
;
const
int64_t
query_start_off
=
static_cast
<
int64_t
>
(
query_start_loc_ptr
?
query_start_loc_ptr
[
seq_idx
]
:
seq_idx
);
OUTT
*
out_ptr
=
out
+
query_start_off
*
num_heads
*
HEAD_SIZE
+
static_cast
<
int64_t
>
(
head_idx
)
*
HEAD_SIZE
;
out_ptr
[
threadIdx
.
x
]
=
from_float
<
scalar_t
>
(
acc
);
}
#else
// clang-format off
template
<
typename
scalar_t
,
typename
cache_t
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
typename
OUTT
,
int
BLOCK_SIZE
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
bool
ALIBI_ENABLED
,
int
GQA_RATIO
>
__global__
__launch_bounds__
(
NUM_THREADS
)
void
paged_attention_ll4mi_QKV_mfma16_kernel
(
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads, head_size/x, block_size, x]
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads, head_size, block_size]
const
int
num_kv_heads
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context_lens
,
// [num_seqs]
const
int
*
__restrict__
query_start_loc_ptr
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads, max_num_partitions]
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, max_num_partitions, head_size]
OUTT
*
__restrict__
final_out
,
// [num_seqs, num_heads, head_size]
int
max_ctx_blocks
,
const
float
*
k_scale
,
const
float
*
v_scale
)
{
UNREACHABLE_CODE
}
template
<
typename
scalar_t
,
typename
cache_t
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
typename
OUTT
,
int
BLOCK_SIZE
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
bool
ALIBI_ENABLED
,
int
GQA_RATIO
>
__global__
__launch_bounds__
(
NUM_THREADS
)
void
paged_attention_ll4mi_QKV_mfma4_kernel
(
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads, head_size/x, block_size, x]
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads, head_size, block_size]
const
int
num_kv_heads
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context_lens
,
// [num_seqs]
const
int
*
__restrict__
query_start_loc_ptr
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads, max_num_partitions]
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, max_num_partitions, head_size]
OUTT
*
__restrict__
final_out
,
// [num_seqs, num_heads, head_size]
int
max_ctx_blocks
,
const
float
*
k_scale
,
const
float
*
v_scale
)
{
UNREACHABLE_CODE
}
// Grid: (num_heads, num_seqs).
template
<
typename
scalar_t
,
typename
OUTT
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
int
PARTITION_SIZE
,
int
NPAR_LOOPS
>
__global__
__launch_bounds__
(
NUM_THREADS
)
void
paged_attention_ll4mi_reduce_kernel
(
OUTT
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
const
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
const
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads, max_num_partitions]
const
scalar_t
*
__restrict__
tmp_out
,
// [num_seqs, num_heads, max_num_partitions, head_size]
const
int
*
__restrict__
context_lens
,
// [num_seqs]
const
int
*
__restrict__
query_start_loc_ptr
,
// [num_seqs]
const
int
max_num_partitions
,
const
float
*
__restrict__
fp8_out_scale_ptr
)
{
UNREACHABLE_CODE
}
// clang-format on
#endif
#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \
paged_attention_ll4mi_QKV_mfma16_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
HEAD_SIZE, NTHR, ALIBI_ENABLED, \
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks, k_scale_ptr, v_scale_ptr);
#define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO) \
paged_attention_ll4mi_QKV_mfma4_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
HEAD_SIZE, NTHR, ALIBI_ENABLED, \
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks, k_scale_ptr, v_scale_ptr);
#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \
paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE, \
PARTITION_SIZE, NPAR_LOOPS> \
<<<reduce_grid, reduce_block, 0, stream>>>( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \
context_lens_ptr, query_start_loc_ptr, max_num_partitions, \
fp8_out_scale_ptr);
template
<
typename
T
,
typename
KVT
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
int
BLOCK_SIZE
,
int
HEAD_SIZE
,
typename
OUTT
,
int
PARTITION_SIZE_OLD
,
bool
ALIBI_ENABLED
>
void
paged_attention_custom_launcher
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
const
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context_lens
,
const
std
::
optional
<
torch
::
Tensor
>&
query_start_loc
,
int
max_context_len
,
const
std
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
std
::
optional
<
torch
::
Tensor
>&
fp8_out_scale
)
{
int
num_seqs
=
block_tables
.
size
(
0
);
int
num_heads
=
query
.
size
(
1
);
int
head_size
=
query
.
size
(
2
);
int
max_num_blocks_per_seq
=
block_tables
.
size
(
1
);
int
q_stride
=
query
.
stride
(
0
);
int
kv_block_stride
=
key_cache
.
stride
(
0
);
int
kv_head_stride
=
key_cache
.
stride
(
1
);
// NOTE: query start location is optional for V0 decode should not be used.
// If batch contains mix of prefills and decode, prefills should be skipped.
const
int
*
query_start_loc_ptr
=
query_start_loc
?
reinterpret_cast
<
const
int
*>
(
query_start_loc
.
value
().
data_ptr
())
:
nullptr
;
// NOTE: alibi_slopes is optional.
const
float
*
alibi_slopes_ptr
=
alibi_slopes
?
reinterpret_cast
<
const
float
*>
(
alibi_slopes
.
value
().
data_ptr
())
:
nullptr
;
float
*
exp_sums_ptr
=
reinterpret_cast
<
float
*>
(
exp_sums
.
data_ptr
());
float
*
max_logits_ptr
=
reinterpret_cast
<
float
*>
(
max_logits
.
data_ptr
());
T
*
tmp_out_ptr
=
reinterpret_cast
<
T
*>
(
tmp_out
.
data_ptr
());
T
*
query_ptr
=
reinterpret_cast
<
T
*>
(
query
.
data_ptr
());
KVT
*
key_cache_ptr
=
reinterpret_cast
<
KVT
*>
(
key_cache
.
data_ptr
());
KVT
*
value_cache_ptr
=
reinterpret_cast
<
KVT
*>
(
value_cache
.
data_ptr
());
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
context_lens_ptr
=
context_lens
.
data_ptr
<
int
>
();
const
float
*
k_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
k_scale
.
data_ptr
());
const
float
*
v_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
v_scale
.
data_ptr
());
// NOTE: fp8_out_scale is optional.
const
auto
fp8_out_scale_ptr
=
fp8_out_scale
?
static_cast
<
const
float
*>
(
fp8_out_scale
.
value
().
data_ptr
())
:
nullptr
;
OUTT
*
out_ptr
=
reinterpret_cast
<
OUTT
*>
(
out
.
data_ptr
());
const
int
max_ctx_blocks
=
DIVIDE_ROUND_UP
(
max_context_len
,
BLOCK_SIZE
);
// partition size is fixed at 256 since both mfma4 and mfma16 kernels support
// it mfma4 kernel also supports partition size 512
constexpr
int
PARTITION_SIZE
=
256
;
const
int
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_context_len
,
PARTITION_SIZE
);
const
int
gqa_ratio
=
num_heads
/
num_kv_heads
;
assert
(
num_heads
%
num_kv_heads
==
0
);
assert
(
head_size
==
HEAD_SIZE
);
constexpr
int
NTHR
=
256
;
dim3
grid
(
num_seqs
,
max_num_partitions
,
num_kv_heads
);
dim3
block
(
NTHR
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
// mfma4 kernel is faster than mfma16 for gqa_ratio <= 4
switch
(
gqa_ratio
)
{
case
1
:
LAUNCH_CUSTOM_ATTENTION_MFMA4
(
1
);
break
;
case
2
:
LAUNCH_CUSTOM_ATTENTION_MFMA4
(
2
);
break
;
case
3
:
LAUNCH_CUSTOM_ATTENTION_MFMA4
(
3
);
break
;
case
4
:
LAUNCH_CUSTOM_ATTENTION_MFMA4
(
4
);
break
;
case
5
:
LAUNCH_CUSTOM_ATTENTION_MFMA16
(
5
);
break
;
case
6
:
LAUNCH_CUSTOM_ATTENTION_MFMA16
(
6
);
break
;
case
7
:
LAUNCH_CUSTOM_ATTENTION_MFMA16
(
7
);
break
;
case
8
:
...
...
@@ -1744,13 +3251,195 @@ void paged_attention_custom_launcher(
}
}
template
<
typename
T
,
typename
KVT
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
int
BLOCK_SIZE
,
int
HEAD_SIZE
,
typename
OUTT
,
int
PARTITION_SIZE_OLD
,
bool
ALIBI_ENABLED
>
void
paged_attention_custom_launcher_navi
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
const
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context_lens
,
const
std
::
optional
<
torch
::
Tensor
>&
query_start_loc
,
int
max_context_len
,
const
std
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
)
{
int
num_seqs
=
block_tables
.
size
(
0
);
int
num_heads
=
query
.
size
(
1
);
int
head_size
=
query
.
size
(
2
);
int
max_num_blocks_per_seq
=
block_tables
.
size
(
1
);
int
q_stride
=
query
.
stride
(
0
);
int
kv_block_stride
=
key_cache
.
stride
(
0
);
int
kv_head_stride
=
key_cache
.
stride
(
1
);
// NOTE: query start location is optional for V0 decode should not be used.
// If batch contains mix of prefills and decode, prefills should be skipped.
const
int
*
query_start_loc_ptr
=
query_start_loc
?
reinterpret_cast
<
const
int
*>
(
query_start_loc
.
value
().
data_ptr
())
:
nullptr
;
// NOTE: Navi does not support alibi_slopes.
const
float
*
alibi_slopes_ptr
=
nullptr
;
float
*
exp_sums_ptr
=
reinterpret_cast
<
float
*>
(
exp_sums
.
data_ptr
());
float
*
max_logits_ptr
=
reinterpret_cast
<
float
*>
(
max_logits
.
data_ptr
());
T
*
tmp_out_ptr
=
reinterpret_cast
<
T
*>
(
tmp_out
.
data_ptr
());
T
*
query_ptr
=
reinterpret_cast
<
T
*>
(
query
.
data_ptr
());
KVT
*
key_cache_ptr
=
reinterpret_cast
<
KVT
*>
(
key_cache
.
data_ptr
());
KVT
*
value_cache_ptr
=
reinterpret_cast
<
KVT
*>
(
value_cache
.
data_ptr
());
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
context_lens_ptr
=
context_lens
.
data_ptr
<
int
>
();
const
float
*
k_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
k_scale
.
data_ptr
());
const
float
*
v_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
v_scale
.
data_ptr
());
// NOTE: Navi does not support fp8.
const
auto
fp8_out_scale_ptr
=
nullptr
;
OUTT
*
out_ptr
=
reinterpret_cast
<
OUTT
*>
(
out
.
data_ptr
());
const
int
max_ctx_blocks
=
DIVIDE_ROUND_UP
(
max_context_len
,
BLOCK_SIZE
);
constexpr
int
PARTITION_SIZE
=
256
;
const
int
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_context_len
,
PARTITION_SIZE
);
const
int
gqa_ratio
=
num_heads
/
num_kv_heads
;
assert
(
num_heads
%
num_kv_heads
==
0
);
assert
(
head_size
==
HEAD_SIZE
);
constexpr
int
NTHR
=
256
;
dim3
grid
(
num_seqs
,
max_num_partitions
,
num_kv_heads
);
dim3
block
(
NTHR
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
switch
(
gqa_ratio
)
{
case
1
:
LAUNCH_CUSTOM_ATTENTION_MFMA16
(
1
);
break
;
case
2
:
LAUNCH_CUSTOM_ATTENTION_MFMA16
(
2
);
break
;
case
3
:
LAUNCH_CUSTOM_ATTENTION_MFMA16
(
3
);
break
;
case
4
:
LAUNCH_CUSTOM_ATTENTION_MFMA16
(
4
);
break
;
case
5
:
LAUNCH_CUSTOM_ATTENTION_MFMA16
(
5
);
break
;
case
6
:
LAUNCH_CUSTOM_ATTENTION_MFMA16
(
6
);
break
;
case
7
:
LAUNCH_CUSTOM_ATTENTION_MFMA16
(
7
);
break
;
case
8
:
LAUNCH_CUSTOM_ATTENTION_MFMA16
(
8
);
break
;
case
9
:
LAUNCH_CUSTOM_ATTENTION_MFMA16
(
9
);
break
;
case
10
:
LAUNCH_CUSTOM_ATTENTION_MFMA16
(
10
);
break
;
case
11
:
LAUNCH_CUSTOM_ATTENTION_MFMA16
(
11
);
break
;
case
12
:
LAUNCH_CUSTOM_ATTENTION_MFMA16
(
12
);
break
;
case
13
:
LAUNCH_CUSTOM_ATTENTION_MFMA16
(
13
);
break
;
case
14
:
LAUNCH_CUSTOM_ATTENTION_MFMA16
(
14
);
break
;
case
15
:
LAUNCH_CUSTOM_ATTENTION_MFMA16
(
15
);
break
;
case
16
:
LAUNCH_CUSTOM_ATTENTION_MFMA16
(
16
);
break
;
default:
TORCH_CHECK
(
false
,
"Unsupported gqa ratio: "
,
gqa_ratio
);
break
;
}
dim3
reduce_grid
(
num_heads
,
num_seqs
);
dim3
reduce_block
(
head_size
);
const
int
warp_size
=
32
;
const
int
npar_loops
=
DIVIDE_ROUND_UP
(
max_num_partitions
,
warp_size
);
// reduction kernel supports upto 16 NPAR_loops * 32 (warp_size) * 256
// (partition size) = 128K context length
switch
(
npar_loops
)
{
case
1
:
LAUNCH_CUSTOM_REDUCTION
(
1
);
break
;
case
2
:
LAUNCH_CUSTOM_REDUCTION
(
2
);
break
;
case
3
:
LAUNCH_CUSTOM_REDUCTION
(
3
);
break
;
case
4
:
LAUNCH_CUSTOM_REDUCTION
(
4
);
break
;
case
5
:
LAUNCH_CUSTOM_REDUCTION
(
5
);
break
;
case
6
:
LAUNCH_CUSTOM_REDUCTION
(
6
);
break
;
case
7
:
LAUNCH_CUSTOM_REDUCTION
(
7
);
break
;
case
8
:
LAUNCH_CUSTOM_REDUCTION
(
8
);
break
;
case
9
:
LAUNCH_CUSTOM_REDUCTION
(
9
);
break
;
case
10
:
LAUNCH_CUSTOM_REDUCTION
(
10
);
break
;
case
11
:
LAUNCH_CUSTOM_REDUCTION
(
11
);
break
;
case
12
:
LAUNCH_CUSTOM_REDUCTION
(
12
);
break
;
case
13
:
LAUNCH_CUSTOM_REDUCTION
(
13
);
break
;
case
14
:
LAUNCH_CUSTOM_REDUCTION
(
14
);
break
;
case
15
:
LAUNCH_CUSTOM_REDUCTION
(
15
);
break
;
case
16
:
LAUNCH_CUSTOM_REDUCTION
(
16
);
break
;
default:
TORCH_CHECK
(
false
,
"Unsupported npar_loops: "
,
npar_loops
);
break
;
}
}
#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \
PSIZE, ALIBI_ENABLED) \
paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \
PSIZE, ALIBI_ENABLED>( \
if (!is_navi) { \
paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
OUTT, PSIZE, ALIBI_ENABLED>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, context_lens, query_start_loc, \
max_context_len, alibi_slopes, k_scale, v_scale, fp8_out_scale); \
} else { \
paged_attention_custom_launcher_navi< \
T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, ALIBI_ENABLED>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, context_lens, query_start_loc, \
max_context_len, alibi_slopes, k_scale, v_scale, fp8_out_scale);
max_context_len, alibi_slopes, k_scale, v_scale); \
}
#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
OUTT, PSIZE) \
...
...
@@ -1807,6 +3496,24 @@ void paged_attention_custom_launcher(
break; \
}
bool
is_navi_gpu
()
{
static
bool
is_cached
=
false
;
static
bool
result
;
if
(
!
is_cached
)
{
int
device_id
;
hipDeviceProp_t
deviceProp
;
hipGetDevice
(
&
device_id
);
hipGetDeviceProperties
(
&
deviceProp
,
device_id
);
std
::
string
arch
=
deviceProp
.
gcnArchName
;
result
=
arch
.
find
(
"gfx11"
)
==
0
||
arch
.
find
(
"gfx12"
)
==
0
;
is_cached
=
true
;
}
return
result
;
}
// clang-format off
void
paged_attention
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
...
...
@@ -1827,6 +3534,8 @@ void paged_attention(
torch
::
Tensor
&
v_scale
,
const
std
::
optional
<
torch
::
Tensor
>&
fp8_out_scale
)
{
// clang-format on
bool
is_navi
=
is_navi_gpu
();
const
int
head_size
=
query
.
size
(
2
);
if
(
kv_cache_dtype
==
"auto"
)
{
if
(
query
.
dtype
()
==
at
::
ScalarType
::
Half
)
{
...
...
tests/kernels/attention/test_attention.py
View file @
dd5fa7e0
...
...
@@ -148,6 +148,11 @@ def test_paged_attention(
or
(
version
==
"rocm"
and
head_size
not
in
(
64
,
128
))):
pytest
.
skip
()
if
(
version
==
"rocm"
and
current_platform
.
is_navi
()
and
(
kv_cache_dtype
==
"fp8"
or
head_size
!=
128
or
block_size
!=
16
or
use_alibi
)):
pytest
.
skip
()
global
PARTITION_SIZE
current_platform
.
seed_everything
(
seed
)
...
...
@@ -275,6 +280,7 @@ def test_paged_attention(
scale
,
block_tables
,
seq_lens
,
None
,
block_size
,
max_seq_len
,
alibi_slopes
,
...
...
@@ -286,7 +292,7 @@ def test_paged_attention(
opcheck
(
torch
.
ops
.
_rocm_C
.
paged_attention
,
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
seq_lens
,
None
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
),
cond
=
(
head_size
==
HEAD_SIZES
[
0
]
and
block_size
==
BLOCK_SIZES
[
0
]))
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
dd5fa7e0
...
...
@@ -861,7 +861,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
gqa_ratio
=
num_heads
//
self
.
num_kv_heads
use_custom
=
use_rocm_custom_paged_attention
(
decode_query
.
dtype
,
head_size
,
block_size
,
gqa_ratio
,
decode_meta
.
max_decode_seq_len
,
self
.
sliding_window
)
decode_meta
.
max_decode_seq_len
,
self
.
sliding_window
,
self
.
kv_cache_dtype
,
self
.
alibi_slopes
)
if
use_custom
:
max_seq_len
=
(
decode_meta
.
max_decode_seq_len
if
self
.
attn_type
!=
AttentionType
.
ENCODER_DECODER
else
...
...
vllm/attention/ops/chunked_prefill_paged_decode.py
View file @
dd5fa7e0
...
...
@@ -283,7 +283,8 @@ def chunked_prefill_paged_decode(
use_custom
=
use_rocm_custom_paged_attention
(
query
.
dtype
,
head_size
,
block_size
,
num_queries_per_kv
,
max_seq_len
,
sliding_window
)
max_seq_len
,
sliding_window
,
kv_cache_dtype
,
alibi_slopes
)
if
use_custom
:
_PARTITION_SIZE_ROCM
=
256
max_num_partitions
=
((
max_seq_len
+
_PARTITION_SIZE_ROCM
-
1
)
//
...
...
vllm/platforms/rocm.py
View file @
dd5fa7e0
...
...
@@ -102,27 +102,43 @@ def on_mi250_mi300() -> bool:
@
cache
def
use_rocm_custom_paged_attention
(
qtype
:
torch
.
dtype
,
head_size
:
int
,
block_size
:
int
,
gqa_ratio
:
int
,
def
use_rocm_custom_paged_attention
(
qtype
:
torch
.
dtype
,
head_size
:
int
,
block_size
:
int
,
gqa_ratio
:
int
,
max_seq_len
:
int
,
sliding_window
:
int
)
->
bool
:
sliding_window
:
int
,
kv_cache_dtype
:
str
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
)
->
bool
:
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
ON_GFX9
=
any
(
arch
in
GPU_ARCH
for
arch
in
[
"gfx90a"
,
"gfx942"
,
"gfx950"
])
ON_GFX11_GFX12
=
any
(
arch
in
GPU_ARCH
for
arch
in
[
"gfx11"
,
"gfx12"
])
# rocm custom page attention not support on gfx1*
# custom paged attn always supported on V0. On V1, requires sliding window
# disabled due to observed numerical discrepancy.
return
(
ON_GFX9
and
(
not
envs
.
VLLM_USE_V1
or
sliding_window
==
0
if
ON_GFX9
:
return
((
not
envs
.
VLLM_USE_V1
or
sliding_window
==
0
or
sliding_window
==
(
-
1
,
-
1
))
and
(
qtype
==
torch
.
half
or
qtype
==
torch
.
bfloat16
)
and
(
head_size
==
64
or
head_size
==
128
)
and
(
block_size
==
16
or
block_size
==
32
)
and
(
gqa_ratio
>=
1
and
gqa_ratio
<=
16
)
and
max_seq_len
<=
32768
and
(
envs
.
VLLM_ROCM_CUSTOM_PAGED_ATTN
)
and
(
gqa_ratio
>=
1
and
gqa_ratio
<=
16
)
and
max_seq_len
<=
32768
and
(
envs
.
VLLM_ROCM_CUSTOM_PAGED_ATTN
)
and
not
(
envs
.
VLLM_ROCM_USE_AITER_PAGED_ATTN
and
envs
.
VLLM_ROCM_USE_AITER
))
else
:
return
(
ON_GFX11_GFX12
and
(
not
envs
.
VLLM_USE_V1
or
sliding_window
==
0
or
sliding_window
==
(
-
1
,
-
1
))
and
(
qtype
==
torch
.
half
or
qtype
==
torch
.
bfloat16
)
and
head_size
==
128
and
block_size
==
16
and
(
gqa_ratio
>=
3
and
gqa_ratio
<=
16
)
and
max_seq_len
<=
32768
and
alibi_slopes
is
None
and
kv_cache_dtype
==
"auto"
and
envs
.
VLLM_ROCM_CUSTOM_PAGED_ATTN
)
class
RocmPlatform
(
Platform
):
_enum
=
PlatformEnum
.
ROCM
...
...
@@ -362,3 +378,7 @@ class RocmPlatform(Platform):
def
get_cu_count
(
cls
,
device_id
:
int
=
0
)
->
int
:
return
torch
.
cuda
.
get_device_properties
(
device_id
).
multi_processor_count
@
classmethod
def
is_navi
(
cls
)
->
bool
:
return
'gfx1'
in
torch
.
cuda
.
get_device_properties
(
0
).
gcnArchName
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment