Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9cc373f3
"vscode:/vscode.git/clone" did not exist on "5c3fbfe46bf62d339a42476120d9bf268fedfa24"
Unverified
Commit
9cc373f3
authored
Sep 19, 2024
by
Charlie Fu
Committed by
GitHub
Sep 19, 2024
Browse files
[Kernel][Amd] Add fp8 kv cache support for rocm custom paged attention (#8577)
parent
76515f30
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
246 additions
and
283 deletions
+246
-283
csrc/rocm/attention.cu
csrc/rocm/attention.cu
+161
-79
csrc/rocm/ops.h
csrc/rocm/ops.h
+2
-1
csrc/rocm/torch_bindings.cpp
csrc/rocm/torch_bindings.cpp
+2
-1
tests/kernels/test_attention.py
tests/kernels/test_attention.py
+63
-188
vllm/_custom_ops.py
vllm/_custom_ops.py
+3
-1
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+15
-13
No files found.
csrc/rocm/attention.cu
View file @
9cc373f3
...
...
@@ -18,8 +18,11 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <hip/hip_bf16.h>
#include "cuda_compat.h"
#include <algorithm>
#include "../attention/dtype_fp8.cuh"
#include "../quantization/fp8/amd/quant_utils.cuh"
#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \
defined(__gfx941__) || defined(__gfx942__))
...
...
@@ -38,7 +41,6 @@
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
#define WARP_SIZE 64
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
...
...
@@ -60,6 +62,8 @@ typedef struct _B16x8 {
_B16x4
xy
[
2
];
}
_B16x8
;
using
_B8x8
=
uint2
;
////// Non temporal load stores ///////
template
<
typename
T
>
...
...
@@ -168,17 +172,39 @@ __device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1,
}
}
template
<
typename
T
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
>
__device__
__forceinline__
_B16x8
scaled_convert_b8x8
(
const
_B8x8
input
,
const
float
scale
)
{
union
alignas
(
16
)
{
uint4
u4
;
_B16x8
u16x8
;
vllm
::
bf16_8_t
b16x8
;
}
tmp
;
if
constexpr
(
std
::
is_same
<
T
,
_Float16
>::
value
)
{
tmp
.
u4
=
vllm
::
fp8
::
scaled_convert
<
uint4
,
_B8x8
,
KV_DTYPE
>
(
input
,
scale
);
return
tmp
.
u16x8
;
}
else
if
constexpr
(
std
::
is_same
<
T
,
__hip_bfloat16
>::
value
)
{
tmp
.
b16x8
=
vllm
::
fp8
::
scaled_convert
<
vllm
::
bf16_8_t
,
_B8x8
,
KV_DTYPE
>
(
input
,
scale
);
return
tmp
.
u16x8
;
}
else
{
static_assert
(
false
,
"unsupported 16b dtype"
);
}
}
///////////////////////////////////////
// grid (num_seqs, num_partitions,num_heads/gqa_ratio)
// block (partition size)
template
<
typename
scalar_t
,
int
BLOCK_SIZE
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
template
<
typename
scalar_t
,
typename
cache_t
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
int
BLOCK_SIZE
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
int
GQA_RATIO
>
__global__
__launch_bounds__
(
NUM_THREADS
)
void
paged_attention_ll4mi_QKV_kernel
(
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
s
ca
lar
_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
const
ca
che
_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const
s
ca
lar
_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
const
ca
che
_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
// head_size, block_size]
const
int
num_kv_heads
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
...
...
@@ -192,10 +218,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, max_num_partitions,
// head_size]
scalar_t
*
__restrict__
final_out
,
// [num_seqs, num_heads, head_size]
#if 0
scalar_t* __restrict__ qk_out, // [num_heads, num_seqs, max_ctx_blocks,block_size]
#endif
int
max_ctx_blocks
)
{
int
max_ctx_blocks
,
float
k_scale
,
float
v_scale
)
{
constexpr
int
NWARPS
=
NUM_THREADS
/
WARP_SIZE
;
const
int
warpid
=
threadIdx
.
x
/
WARP_SIZE
;
const
int
laneid
=
threadIdx
.
x
%
WARP_SIZE
;
...
...
@@ -222,12 +245,14 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
constexpr
int
x
=
16
/
sizeof
(
scalar_t
);
constexpr
int
KHELOOP
=
HEAD_SIZE
/
x
;
_B16x8
Klocal
[
KHELOOP
];
_B8x8
Klocalb8
[
KHELOOP
];
constexpr
int
VHELOOP
=
HEAD_SIZE
/
WARP_SIZE
;
// v head_size dimension is distributed across lanes
constexpr
int
VTLOOP
=
8
;
// 16 separate 4xtokens across warp -> 16/2
// 8xtokens
_B16x8
Vlocal
[
VHELOOP
][
VTLOOP
];
_B8x8
Vlocalb8
[
VHELOOP
][
VTLOOP
];
floatx4
dout
[
QHLOOP
];
float
qk_max
[
QHLOOP
];
#pragma unroll
...
...
@@ -279,6 +304,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
(
vblock_idx
<=
last_ctx_block
)
?
vblock_idx
:
last_ctx_block
;
vphysical_blocks
[
b
]
=
block_table
[
vblock_idx_ctx
];
}
// each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems
const
scalar_t
*
q_ptr
=
q
+
seq_idx
*
q_stride
+
wg_start_head_idx
*
HEAD_SIZE
;
...
...
@@ -298,18 +324,30 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
Qlocal
[
QHLOOP
-
1
].
xy
[
1
]
=
{
0
};
}
const
s
ca
lar
_t
*
k_ptr
=
k_cache
+
physical_block_number
*
kv_block_stride
+
const
ca
che
_t
*
k_ptr
=
k_cache
+
physical_block_number
*
kv_block_stride
+
wg_start_kv_head_idx
*
kv_head_stride
;
const
int
physical_block_offset
=
local_token_idx
%
BLOCK_SIZE
;
// since x=half8, physical_block_offset
// is already cast as _H8
if
constexpr
(
KV_DTYPE
==
vllm
::
Fp8KVCacheDataType
::
kAuto
)
{
const
_B16x8
*
k_ptrh8
=
reinterpret_cast
<
const
_B16x8
*>
(
k_ptr
);
#pragma unroll
for
(
int
d
=
0
;
d
<
KHELOOP
;
d
++
)
{
Klocal
[
d
]
=
k_ptrh8
[
d
*
BLOCK_SIZE
+
physical_block_offset
];
}
}
else
{
constexpr
int
X
=
16
/
sizeof
(
cache_t
);
const
cache_t
*
k_ptr2
=
k_ptr
+
physical_block_offset
*
X
;
#pragma unroll
for
(
int
d
=
0
;
d
<
KHELOOP
;
d
++
)
{
const
int
head_elem
=
d
*
8
;
const
int
offset1
=
head_elem
/
X
;
const
int
offset2
=
head_elem
%
X
;
const
cache_t
*
k_ptr3
=
k_ptr2
+
offset1
*
BLOCK_SIZE
*
X
+
offset2
;
Klocalb8
[
d
]
=
*
reinterpret_cast
<
const
_B8x8
*>
(
k_ptr3
);
}
}
float
alibi_slope
[
QHLOOP
];
if
(
alibi_slopes
!=
nullptr
)
{
...
...
@@ -322,7 +360,8 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
}
}
const
scalar_t
*
v_ptr
=
v_cache
+
wg_start_kv_head_idx
*
kv_head_stride
;
const
cache_t
*
v_ptr
=
v_cache
+
wg_start_kv_head_idx
*
kv_head_stride
;
if
constexpr
(
KV_DTYPE
==
vllm
::
Fp8KVCacheDataType
::
kAuto
)
{
const
_B16x8
*
v_ptrh8
=
reinterpret_cast
<
const
_B16x8
*>
(
v_ptr
);
// iterate over each v block
#pragma unroll
...
...
@@ -345,6 +384,41 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
}
}
}
}
else
{
const
_B8x8
*
v_ptrh8
=
reinterpret_cast
<
const
_B8x8
*>
(
v_ptr
);
// iterate over each v block
#pragma unroll
for
(
int
b
=
0
;
b
<
VBLOCKS
;
b
++
)
{
// int32 physical_block_number leads to overflow when multiplied with
// kv_block_stride
const
int64_t
vphysical_block_number
=
static_cast
<
int64_t
>
(
vphysical_blocks
[
b
]);
const
_B8x8
*
v_ptrh8b
=
v_ptrh8
+
(
vphysical_block_number
*
kv_block_stride
)
/
8
;
// iterate over each head elem (within head_size)
#pragma unroll
for
(
int
h
=
0
;
h
<
VHELOOP
;
h
++
)
{
const
int
head_size_elem
=
h
*
WARP_SIZE
+
laneid
;
const
_B8x8
*
v_ptrh8be
=
v_ptrh8b
+
head_size_elem
*
BLOCK_SIZE
/
8
;
// iterate over all velems within block
#pragma unroll
for
(
int
d
=
0
;
d
<
BLOCK_SIZE
/
8
;
d
++
)
{
// Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d];
const
_B8x8
Vlocalb8
=
v_ptrh8be
[
d
];
Vlocal
[
h
][
b
*
BLOCK_SIZE
/
8
+
d
]
=
scaled_convert_b8x8
<
scalar_t
,
KV_DTYPE
>
(
Vlocalb8
,
v_scale
);
}
}
}
}
if
constexpr
(
KV_DTYPE
!=
vllm
::
Fp8KVCacheDataType
::
kAuto
)
{
#pragma unroll
for
(
int
d
=
0
;
d
<
KHELOOP
;
d
++
)
{
Klocal
[
d
]
=
scaled_convert_b8x8
<
scalar_t
,
KV_DTYPE
>
(
Klocalb8
[
d
],
k_scale
);
}
}
#pragma unroll
for
(
int
h
=
0
;
h
<
QHLOOP
;
h
++
)
{
...
...
@@ -794,13 +868,15 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
template
<
typename
scalar_t
,
int
BLOCK_SIZE
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
template
<
typename
scalar_t
,
typename
cache_t
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
int
BLOCK_SIZE
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
int
GQA_RATIO
>
__global__
__launch_bounds__
(
NUM_THREADS
)
void
paged_attention_ll4mi_QKV_kernel
(
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
s
ca
lar
_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
const
ca
che
_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const
s
ca
lar
_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
const
ca
che
_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
// head_size, block_size]
const
int
num_kv_heads
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
...
...
@@ -814,10 +890,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, max_num_partitions,
// head_size]
scalar_t
*
__restrict__
final_out
,
// [num_seqs, num_heads, head_size]
#if 0
scalar_t* __restrict__ qk_out, // [num_heads, num_seqs, max_ctx_blocks,block_size]
#endif
int
max_ctx_blocks
)
{
int
max_ctx_blocks
,
float
k_scale
,
float
v_scale
)
{
UNREACHABLE_CODE
}
...
...
@@ -839,26 +912,24 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
#define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \
paged_attention_ll4mi_QKV_kernel<T, BLOCK_SIZE, HEAD_SIZE, NTHR, GQA_RATIO> \
paged_attention_ll4mi_QKV_kernel<T, KVT, KV_DTYPE, BLOCK_SIZE, HEAD_SIZE, \
NTHR, GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks);
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \
k_scale, v_scale);
template
<
typename
T
,
int
BLOCK_SIZE
,
int
HEAD_SIZE
,
int
PARTITION_SIZE
=
256
>
template
<
typename
T
,
typename
KVT
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
int
BLOCK_SIZE
,
int
HEAD_SIZE
,
int
PARTITION_SIZE
=
512
>
void
paged_attention_custom_launcher
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
const
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context_lens
,
int
max_context_len
,
#if 0
torch::Tensor& qk_out,
torch::Tensor& softmax_out,
#endif
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
)
{
int
max_context_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
float
v_scale
)
{
int
num_seqs
=
query
.
size
(
0
);
int
num_heads
=
query
.
size
(
1
);
int
head_size
=
query
.
size
(
2
);
...
...
@@ -878,14 +949,10 @@ void paged_attention_custom_launcher(
float
*
max_logits_ptr
=
reinterpret_cast
<
float
*>
(
max_logits
.
data_ptr
());
T
*
tmp_out_ptr
=
reinterpret_cast
<
T
*>
(
tmp_out
.
data_ptr
());
T
*
query_ptr
=
reinterpret_cast
<
T
*>
(
query
.
data_ptr
());
T
*
key_cache_ptr
=
reinterpret_cast
<
T
*>
(
key_cache
.
data_ptr
());
T
*
value_cache_ptr
=
reinterpret_cast
<
T
*>
(
value_cache
.
data_ptr
());
KV
T
*
key_cache_ptr
=
reinterpret_cast
<
KV
T
*>
(
key_cache
.
data_ptr
());
KV
T
*
value_cache_ptr
=
reinterpret_cast
<
KV
T
*>
(
value_cache
.
data_ptr
());
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
context_lens_ptr
=
context_lens
.
data_ptr
<
int
>
();
#if 0
T* qk_out_ptr = reinterpret_cast<T*>(qk_out.data_ptr());
T* softmax_out_ptr = reinterpret_cast<T*>(softmax_out.data_ptr());
#endif
const
int
max_ctx_blocks
=
DIVIDE_ROUND_UP
(
max_context_len
,
BLOCK_SIZE
);
const
int
max_num_partitions
=
...
...
@@ -972,32 +1039,32 @@ void paged_attention_custom_launcher(
}
}
#define CALL_CUSTOM_LAUNCHER(T, BLK_SIZE, HEAD_SIZE)
\
paged_attention_custom_launcher<T, BLK_SIZE, HEAD_SIZE>(
\
#define CALL_CUSTOM_LAUNCHER(T,
KVT, KV_DTYPE,
BLK_SIZE, HEAD_SIZE) \
paged_attention_custom_launcher<T,
KVT, KV_DTYPE,
BLK_SIZE, HEAD_SIZE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, context_lens, max_context_len, \
alibi_slopes);
alibi_slopes
, k_scale, v_scale
);
#define CALL_CUSTOM_LAUNCHER_BLK(T,
HEAD_SIZE)
\
#define CALL_CUSTOM_LAUNCHER_BLK(T,
KVT, KV_DTYPE, HEAD_SIZE)
\
switch (block_size) { \
case 16: \
CALL_CUSTOM_LAUNCHER(T, 16, HEAD_SIZE);
\
CALL_CUSTOM_LAUNCHER(T,
KVT, KV_DTYPE,
16, HEAD_SIZE); \
break; \
case 32: \
CALL_CUSTOM_LAUNCHER(T, 32, HEAD_SIZE);
\
CALL_CUSTOM_LAUNCHER(T,
KVT, KV_DTYPE,
32, HEAD_SIZE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T
)
\
#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T
, KVT, KV_DTYPE)
\
switch (head_size) { \
case 64: \
CALL_CUSTOM_LAUNCHER_BLK(T,
64);
\
CALL_CUSTOM_LAUNCHER_BLK(T,
KVT, KV_DTYPE, 64);
\
break; \
case 128: \
CALL_CUSTOM_LAUNCHER_BLK(T,
128);
\
CALL_CUSTOM_LAUNCHER_BLK(T,
KVT, KV_DTYPE, 128);
\
break; \
default: \
TORCH_CHECK(false, "Unsupported head size: ", head_size); \
...
...
@@ -1020,16 +1087,31 @@ void paged_attention(
torch
::
Tensor
&
context_lens
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_context_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
)
{
assert
(
kv_cache_dtype
==
"auto"
);
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
)
{
const
int
head_size
=
query
.
size
(
2
);
if
(
kv_cache_dtype
==
"auto"
)
{
if
(
query
.
dtype
()
==
at
::
ScalarType
::
Half
)
{
CALL_CUSTOM_LAUNCHER_BLK_HEAD
(
_Float16
,
_Float16
,
vllm
::
Fp8KVCacheDataType
::
kAuto
);
}
else
if
(
query
.
dtype
()
==
at
::
ScalarType
::
BFloat16
)
{
CALL_CUSTOM_LAUNCHER_BLK_HEAD
(
__hip_bfloat16
,
__hip_bfloat16
,
vllm
::
Fp8KVCacheDataType
::
kAuto
);
}
else
{
TORCH_CHECK
(
false
,
"Unsupported data type: "
,
query
.
dtype
());
}
}
else
if
(
kv_cache_dtype
==
"fp8"
||
kv_cache_dtype
==
"fp8_e4m3"
)
{
if
(
query
.
dtype
()
==
at
::
ScalarType
::
Half
)
{
CALL_CUSTOM_LAUNCHER_BLK_HEAD
(
_Float16
);
CALL_CUSTOM_LAUNCHER_BLK_HEAD
(
_Float16
,
uint8_t
,
vllm
::
Fp8KVCacheDataType
::
kFp8E4M3
);
}
else
if
(
query
.
dtype
()
==
at
::
ScalarType
::
BFloat16
)
{
CALL_CUSTOM_LAUNCHER_BLK_HEAD
(
__hip_bfloat16
);
CALL_CUSTOM_LAUNCHER_BLK_HEAD
(
__hip_bfloat16
,
uint8_t
,
vllm
::
Fp8KVCacheDataType
::
kFp8E4M3
);
}
else
{
TORCH_CHECK
(
false
,
"Unsupported data type: "
,
query
.
dtype
());
}
}
else
{
TORCH_CHECK
(
false
,
"Unsupported KV cache dtype: "
,
kv_cache_dtype
);
}
}
#undef WARP_SIZE
...
...
csrc/rocm/ops.h
View file @
9cc373f3
...
...
@@ -10,4 +10,5 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
torch
::
Tensor
&
context_lens
,
int64_t
block_size
,
int64_t
max_context_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
);
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
);
csrc/rocm/torch_bindings.cpp
View file @
9cc373f3
...
...
@@ -26,7 +26,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
" Tensor context_lens, int block_size,"
" int max_context_len,"
" Tensor? alibi_slopes,"
" str kv_cache_dtype) -> ()"
);
" str kv_cache_dtype,"
" float k_scale, float v_scale) -> ()"
);
rocm_ops
.
impl
(
"paged_attention"
,
torch
::
kCUDA
,
&
paged_attention
);
}
...
...
tests/kernels/test_attention.py
View file @
9cc373f3
...
...
@@ -31,8 +31,7 @@ NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
# FlashAttention forward only supports head dimension at most 128
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
HEAD_SIZES
=
[
64
,
80
,
96
,
112
,
120
,
128
,
192
,
256
]
if
not
is_hip
()
else
[
64
,
80
,
96
,
112
,
128
]
HEAD_SIZES
=
[
64
,
80
,
96
,
112
,
120
,
128
,
192
,
256
]
BLOCK_SIZES
=
[
16
,
32
]
USE_ALIBI
=
[
False
,
True
]
...
...
@@ -114,7 +113,8 @@ def ref_single_query_cached_kv_attention(
output
[
i
].
copy_
(
out
,
non_blocking
=
True
)
@
pytest
.
mark
.
parametrize
(
"version"
,
[
"v1"
,
"v2"
])
@
pytest
.
mark
.
parametrize
(
"version"
,
[
"v1"
,
"v2"
]
if
not
is_hip
()
else
[
"v1"
,
"v2"
,
"rocm"
])
@
pytest
.
mark
.
parametrize
(
"num_seqs"
,
NUM_GEN_SEQS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
...
...
@@ -137,7 +137,8 @@ def test_paged_attention(
seed
:
int
,
device
:
str
,
)
->
None
:
if
kv_cache_dtype
==
"fp8"
and
head_size
%
16
:
if
((
kv_cache_dtype
==
"fp8"
and
head_size
%
16
)
or
(
version
==
"rocm"
and
head_size
not
in
(
64
,
128
))):
pytest
.
skip
()
seed_everything
(
seed
)
...
...
@@ -206,7 +207,7 @@ def test_paged_attention(
kv_cache_dtype
,
k_scale
,
v_scale
,
0
,
0
,
0
,
64
,
0
),
cond
=
(
head_size
==
HEAD_SIZES
[
0
]))
elif
version
==
"v2"
:
elif
version
in
(
"v2"
,
"rocm"
)
:
num_partitions
=
((
max_seq_len
+
PARTITION_SIZE
-
1
)
//
PARTITION_SIZE
)
assert
PARTITION_SIZE
%
block_size
==
0
num_seqs
,
num_heads
,
head_size
=
output
.
shape
...
...
@@ -219,6 +220,7 @@ def test_paged_attention(
dtype
=
torch
.
float32
,
)
max_logits
=
torch
.
empty_like
(
exp_sums
)
if
version
==
"v2"
:
ops
.
paged_attention_v2
(
output
,
exp_sums
,
...
...
@@ -240,10 +242,38 @@ def test_paged_attention(
)
opcheck
(
torch
.
ops
.
_C
.
paged_attention_v2
,
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
0
,
0
,
0
,
64
,
0
),
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
0
,
0
,
0
,
64
,
0
),
cond
=
(
head_size
==
HEAD_SIZES
[
0
]))
else
:
ops
.
paged_attention_rocm
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
)
opcheck
(
torch
.
ops
.
_rocm_C
.
paged_attention
,
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
),
cond
=
(
head_size
==
HEAD_SIZES
[
0
]))
else
:
...
...
@@ -328,162 +358,6 @@ def ref_multi_query_kv_attention(
return
torch
.
cat
(
ref_outputs
,
dim
=
0
)
@
pytest
.
mark
.
parametrize
(
"version"
,
[
"rocm"
])
@
pytest
.
mark
.
parametrize
(
"num_seqs"
,
NUM_GEN_SEQS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
[
64
,
128
])
# only test 64 128
@
pytest
.
mark
.
parametrize
(
"use_alibi"
,
USE_ALIBI
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
[
"auto"
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
skipif
(
not
is_hip
(),
reason
=
"only for rocm"
)
def
test_paged_attention_rocm
(
kv_cache_factory
,
version
:
str
,
num_seqs
:
int
,
num_heads
:
Tuple
[
int
,
int
],
head_size
:
int
,
use_alibi
:
bool
,
block_size
:
int
,
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
str
,
seed
:
int
,
device
:
str
,
)
->
None
:
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
scale
=
float
(
1.0
/
(
head_size
**
0.5
))
num_query_heads
,
num_kv_heads
=
num_heads
query
=
torch
.
empty
(
num_seqs
,
num_query_heads
,
head_size
,
dtype
=
dtype
)
query
.
uniform_
(
-
scale
,
scale
)
assert
num_query_heads
%
num_kv_heads
==
0
num_queries_per_kv
=
num_query_heads
//
num_kv_heads
alibi_slopes
=
None
if
use_alibi
:
alibi_slopes
=
torch
.
randn
(
num_query_heads
,
dtype
=
torch
.
float
)
context_lens
=
[
random
.
randint
(
1
,
MAX_SEQ_LEN
)
for
_
in
range
(
num_seqs
)]
context_lens
[
-
1
]
=
MAX_SEQ_LEN
#context_lens = [8192 for _ in range(num_seqs)]
max_context_len
=
max
(
context_lens
)
context_lens
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int
)
#print('>>> ctx lens', context_lens)
# Create the block tables.
max_num_blocks_per_seq
=
(
max_context_len
+
block_size
-
1
)
//
block_size
block_tables
=
[]
for
_
in
range
(
num_seqs
):
block_table
=
[
random
.
randint
(
0
,
NUM_BLOCKS
-
1
)
for
_
in
range
(
max_num_blocks_per_seq
)
]
block_tables
.
append
(
block_table
)
block_tables
=
torch
.
tensor
(
block_tables
,
dtype
=
torch
.
int
)
# Create the KV caches.
key_caches
,
value_caches
=
kv_cache_factory
(
NUM_BLOCKS
,
block_size
,
1
,
num_kv_heads
,
head_size
,
kv_cache_dtype
,
dtype
,
seed
,
device
)
key_cache
,
value_cache
=
key_caches
[
0
],
value_caches
[
0
]
# TODO(charlifu) enable fp8 kv cache
# Using default kv_scale
# kv_scale = 1.0
# Call the paged attention kernel.
output
=
torch
.
empty_like
(
query
)
PARTITION_SIZE_ROCM
=
256
num_partitions
=
((
max_context_len
+
PARTITION_SIZE_ROCM
-
1
)
//
PARTITION_SIZE_ROCM
)
assert
PARTITION_SIZE_ROCM
%
block_size
==
0
num_seqs
,
num_heads
,
head_size
=
output
.
shape
tmp_output
=
torch
.
empty
(
size
=
(
num_seqs
,
num_heads
,
num_partitions
,
head_size
),
dtype
=
output
.
dtype
,
)
exp_sums
=
torch
.
empty
(
size
=
(
num_seqs
,
num_heads
,
num_partitions
),
dtype
=
torch
.
float32
,
)
max_logits
=
torch
.
empty_like
(
exp_sums
)
if
version
==
"rocm"
:
ops
.
paged_attention_rocm
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
context_lens
,
block_size
,
max_context_len
,
alibi_slopes
,
kv_cache_dtype
,
)
else
:
raise
AssertionError
(
f
"Unknown version:
{
version
}
"
)
# Run the reference implementation.
if
kv_cache_dtype
==
"fp8"
:
# Convert cache data back to dtype.
x
=
16
//
torch
.
tensor
([],
dtype
=
dtype
).
element_size
()
key_cache_shape
=
(
NUM_BLOCKS
,
num_kv_heads
,
head_size
//
x
,
block_size
,
x
)
dequantized_key_cache
=
torch
.
empty
(
size
=
key_cache_shape
,
dtype
=
dtype
,
device
=
device
)
ops
.
convert_fp8
(
key_cache
,
dequantized_key_cache
)
key_cache
=
dequantized_key_cache
value_cache_shape
=
value_cache
.
shape
dequantized_value_cache
=
torch
.
empty
(
size
=
value_cache_shape
,
dtype
=
dtype
,
device
=
device
)
ops
.
convert_fp8
(
value_cache
,
dequantized_value_cache
)
value_cache
=
dequantized_value_cache
ref_output
=
torch
.
empty_like
(
query
)
ref_single_query_cached_kv_attention
(
ref_output
,
query
,
num_queries_per_kv
,
key_cache
,
value_cache
,
block_tables
,
context_lens
,
scale
,
alibi_slopes
,
)
# NOTE(woosuk): Due to the kernel-level differences in the two
# implementations, there is a small numerical difference in the two
# outputs. Thus, we use a relaxed tolerance for the test.
atol
=
get_default_atol
(
output
)
if
is_hip
()
else
1e-3
rtol
=
get_default_rtol
(
output
)
if
is_hip
()
else
1e-5
# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
# so we use a relaxed tolerance for the test.
atol
,
rtol
=
1e-4
,
1e-5
if
dtype
==
torch
.
bfloat16
:
atol
,
rtol
=
2e-4
,
1e-5
if
use_alibi
:
if
dtype
==
torch
.
half
:
atol
,
rtol
=
5e-4
,
1e-5
if
dtype
==
torch
.
bfloat16
:
atol
,
rtol
=
1e-3
,
1e-5
if
kv_cache_dtype
==
"fp8"
:
atol
,
rtol
=
1e-2
,
1e-5
assert
torch
.
allclose
(
output
,
ref_output
,
atol
=
atol
,
rtol
=
rtol
)
# TODO(woosuk): Add tests for USE_ALIBI=True.
@
pytest
.
mark
.
parametrize
(
"num_seqs"
,
NUM_PREFILL_SEQS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
...
...
@@ -491,7 +365,8 @@ def test_paged_attention_rocm(
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
skipif
(
is_hip
(),
reason
=
"skip for rocm"
)
@
pytest
.
mark
.
skipif
(
is_hip
(),
reason
=
"Xformers backend is not supported on ROCm."
)
@
torch
.
inference_mode
()
def
test_multi_query_kv_attention
(
num_seqs
:
int
,
...
...
vllm/_custom_ops.py
View file @
9cc373f3
...
...
@@ -146,12 +146,14 @@ def paged_attention_rocm(
max_seq_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
k_scale
:
float
,
v_scale
:
float
,
)
->
None
:
torch
.
ops
.
_rocm_C
.
paged_attention
(
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
)
kv_cache_dtype
,
k_scale
,
v_scale
)
# pos encoding ops
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
9cc373f3
...
...
@@ -17,8 +17,8 @@ from vllm.platforms import current_platform
logger
=
init_logger
(
__name__
)
_PARTITION_SIZE
=
256
ON_NAVI
=
"gfx1"
in
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
_PARTITION_SIZE
_ROCM
=
512
_
ON_NAVI
=
"gfx1"
in
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
class
ROCmFlashAttentionBackend
(
AttentionBackend
):
...
...
@@ -489,14 +489,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
num_seqs
,
num_heads
,
head_size
=
decode_query
.
shape
block_size
=
value_cache
.
shape
[
3
]
gqa_ratio
=
num_heads
//
self
.
num_kv_heads
use_custom
=
use_rocm_custom_paged_attention
(
decode_query
.
dtype
,
head_size
,
block_size
,
self
.
kv_cache_dtype
,
gqa_ratio
,
decode_meta
.
max_decode_seq_len
)
use_custom
=
_
use_rocm_custom_paged_attention
(
decode_query
.
dtype
,
head_size
,
block_size
,
gqa_ratio
,
decode_meta
.
max_decode_seq_len
)
if
use_custom
:
max_seq_len
=
decode_meta
.
max_decode_seq_len
max_num_partitions
=
((
max_seq_len
+
_PARTITION_SIZE
-
1
)
//
_PARTITION_SIZE
)
assert
_PARTITION_SIZE
%
block_size
==
0
max_num_partitions
=
(
(
max_seq_len
+
_PARTITION_SIZE_ROCM
-
1
)
//
_PARTITION_SIZE_ROCM
)
assert
_PARTITION_SIZE_ROCM
%
block_size
==
0
tmp_output
=
torch
.
empty
(
size
=
(
num_seqs
,
num_heads
,
max_num_partitions
,
head_size
),
dtype
=
output
.
dtype
,
...
...
@@ -524,6 +525,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
max_seq_len
,
self
.
alibi_slopes
,
self
.
kv_cache_dtype
,
k_scale
,
v_scale
,
)
else
:
output
[
num_prefill_tokens
:]
=
PagedAttention
.
forward_decode
(
...
...
@@ -580,12 +583,11 @@ def _sdpa_attention(
return
output
def
use_rocm_custom_paged_attention
(
qtype
:
torch
.
dtype
,
head_size
:
int
,
block_size
:
int
,
kv_cache_dtype
:
str
,
gqa_ratio
:
int
,
max_seq_len
:
int
)
->
bool
:
def
_
use_rocm_custom_paged_attention
(
qtype
:
torch
.
dtype
,
head_size
:
int
,
block_size
:
int
,
gqa_ratio
:
int
,
max_seq_len
:
int
)
->
bool
:
# rocm custom page attention not support on navi (gfx1*)
return
(
not
ON_NAVI
and
(
qtype
==
torch
.
half
or
qtype
==
torch
.
bfloat16
)
return
(
not
_
ON_NAVI
and
(
qtype
==
torch
.
half
or
qtype
==
torch
.
bfloat16
)
and
(
head_size
==
64
or
head_size
==
128
)
and
(
block_size
==
16
or
block_size
==
32
)
and
kv_cache_dtype
==
"auto"
and
(
gqa_ratio
>=
1
and
gqa_ratio
<=
16
)
and
max_seq_len
<=
32768
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment