Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0c70376b
Commit
0c70376b
authored
Sep 20, 2024
by
zhuwenwen
Browse files
add pa tc
parent
fe1ec8c5
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
614 additions
and
583 deletions
+614
-583
csrc/attention/attention_kernels_opt.cu
csrc/attention/attention_kernels_opt.cu
+565
-562
csrc/attention/static_switch.h
csrc/attention/static_switch.h
+44
-18
requirements-rocm.txt
requirements-rocm.txt
+1
-1
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+4
-2
No files found.
csrc/attention/attention_kernels_opt.cu
View file @
0c70376b
...
...
@@ -20,12 +20,30 @@ typedef __hip_bfloat16 __nv_bfloat16;
#define WARP_SIZE warpSize
#endif
#include "static_switch.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
inline
std
::
string
get_device_name
()
{
hipDeviceProp_t
props
{};
int
device
;
auto
status
=
hipGetDevice
(
&
device
);
if
(
status
!=
hipSuccess
)
{
return
std
::
string
();
}
status
=
hipGetDeviceProperties
(
&
props
,
device
);
if
(
status
!=
hipSuccess
)
{
return
std
::
string
();
}
const
std
::
string
raw_name
(
props
.
gcnArchName
);
return
raw_name
.
substr
(
0
,
raw_name
.
find
(
':'
));
// str.substr(0, npos) returns str.
}
namespace
vllm
{
// Utility function for attention softmax.
...
...
@@ -64,16 +82,63 @@ inline __device__ float block_sum(float* red_smem, float sum) {
return
VLLM_SHFL_SYNC
(
sum
,
0
);
}
using
half4_t
=
__attribute__
(
(
__vector_size__
(
4
*
sizeof
(
_Float16
))
))
_Float16
;
using
v4bh
=
__attribute__
(
(
__vector_size__
(
4
*
sizeof
(
short
))
))
short
;
using
float4_t
=
__attribute__
(
(
__vector_size__
(
4
*
sizeof
(
float
))
))
float
;
struct
half4x2
{
half4_t
data
[
2
];
};
template
<
bool
is_half
>
inline
__device__
void
float4_2_half4
(
half4_t
&
dst
,
const
float4_t
&
src
)
{
if
constexpr
(
is_half
){
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
){
dst
[
i
]
=
src
[
i
];
}
}
else
{
__nv_bfloat16
*
out
=
reinterpret_cast
<
__nv_bfloat16
*>
(
&
dst
);
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
){
out
[
i
]
=
__float2bfloat16
(
src
[
i
]);
}
}
}
template
<
bool
is_half
>
inline
__device__
void
v_mmac_f32_16x16x16_f16
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float4_t
&
reg_c
)
{
if
constexpr
(
is_half
){
asm
volatile
(
"v_mmac_f32_16x16x16_f16 %0, %1, %2, %0"
:
"=v"
(
reg_c
)
:
"v"
(
reg_a
),
"v"
(
reg_b
),
"0"
(
reg_c
));
}
else
{
asm
volatile
(
"v_mmac_f32_16x16x16_bf16 %0, %1, %2, %0"
:
"=v"
(
reg_c
)
:
"v"
(
reg_a
),
"v"
(
reg_b
),
"0"
(
reg_c
));
}
}
template
<
bool
is_half
,
bool
use_vmac
>
inline
__device__
void
builtin_amdgcn_mmac
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float4_t
&
reg_c
)
{
if
constexpr
(
use_vmac
){
v_mmac_f32_16x16x16_f16
<
is_half
>
(
reg_a
,
reg_b
,
reg_c
);}
else
{
if
constexpr
(
is_half
){
reg_c
=
__builtin_amdgcn_mmac_f32_16x16x16f16
(
reg_a
,
reg_b
,
reg_c
);}
else
{
reg_c
=
__builtin_amdgcn_mmac_f32_16x16x16bf16
(
*
(
v4bh
*
)
&
reg_a
,
*
(
v4bh
*
)
&
reg_b
,
reg_c
);
}
}
}
// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
int
REUSE_KV_TIMES
=
1
,
bool
odd_nheads
=
false
,
int
PARTITION_SIZE
=
0
>
// Zero means no partitioning.
__device__
void
paged_attention_kernel_opt
(
bool
IS_BLOCK_SPARSE
,
int
REUSE_KV_TIMES
,
bool
use_vmac
,
int
PARTITION_SIZE
=
0
>
// Zero means no partitioning.
__device__
void
paged_attention_kernel_TC
(
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
// max_num_partitions]
...
...
@@ -84,8 +149,8 @@ __device__ void paged_attention_kernel_opt(
// head_size/x, block_size, x]
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
// head_size, block_size]
const
int
num_heads
,
// [num_heads]
const
int
num_kv_heads
,
// [num_
kv_
heads]
const
int
num_heads
,
const
int
num_kv_heads
,
// [num_heads]
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
seq_lens
,
// [num_seqs]
...
...
@@ -99,285 +164,170 @@ __device__ void paged_attention_kernel_opt(
const
int
partition_idx
=
blockIdx
.
y
;
const
int
max_num_partitions
=
gridDim
.
y
;
constexpr
bool
USE_PARTITIONING
=
PARTITION_SIZE
>
0
;
const
int
seq_len
=
seq_lens
[
seq_idx
];
const
int
seq_len
=
__builtin_amdgcn_readfirstlane
(
seq_lens
[
seq_idx
]
)
;
if
(
USE_PARTITIONING
&&
partition_idx
*
PARTITION_SIZE
>=
seq_len
)
{
// No work to do. Terminate the thread block.
return
;
}
if
constexpr
(
sizeof
(
scalar_t
)
==
2
){
constexpr
bool
is_half
=
std
::
is_same
<
scalar_t
,
uint16_t
>::
value
;
static_assert
(
HEAD_SIZE
<=
4
*
NUM_THREADS
,
"HEAD_SIZE<=4*NUM_THREADS"
);
const
int
num_seq_blocks
=
DIVIDE_ROUND_UP
(
seq_len
,
BLOCK_SIZE
);
const
int
num_blocks_per_partition
=
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_seq_blocks
;
const
int
num_blocks_per_partition
=
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_seq_blocks
;
const
int
partition_size
=
USE_PARTITIONING
?
PARTITION_SIZE
:
num_seq_blocks
*
BLOCK_SIZE
;
// [start_block_idx, end_block_idx) is the range of blocks to process.
const
int
start_block_idx
=
USE_PARTITIONING
?
partition_idx
*
num_blocks_per_partition
:
0
;
const
int
end_block_idx
=
MIN
(
start_block_idx
+
num_blocks_per_partition
,
num_seq_blocks
);
const
int
num_blocks
=
end_block_idx
-
start_block_idx
;
const
int
start_block_idx
=
partition_idx
*
num_blocks_per_partition
;
//0,64,128…
const
int
end_block_idx
=
MIN
(
start_block_idx
+
num_blocks_per_partition
,
num_seq_blocks
);
//64,128,192…
const
int
num_blocks
=
end_block_idx
-
start_block_idx
;
//64 or 1-63
// [start_token_idx, end_token_idx) is the range of tokens to process.
const
int
start_token_idx
=
start_block_idx
*
BLOCK_SIZE
;
const
int
end_token_idx
=
MIN
(
start_token_idx
+
num_blocks
*
BLOCK_SIZE
,
seq_len
);
const
int
num_tokens
=
end_token_idx
-
start_token_idx
;
constexpr
int
THREAD_GROUP_SIZE
=
MAX
(
WARP_SIZE
/
BLOCK_SIZE
,
1
);
constexpr
int
NUM_THREAD_GROUPS
=
NUM_THREADS
/
THREAD_GROUP_SIZE
;
// Note: This assumes THREAD_GROUP_SIZE
// divides NUM_THREADS
assert
(
NUM_THREADS
%
THREAD_GROUP_SIZE
==
0
);
constexpr
int
NUM_TOKENS_PER_THREAD_GROUP
=
DIVIDE_ROUND_UP
(
BLOCK_SIZE
,
WARP_SIZE
);
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
const
int
start_token_idx
=
start_block_idx
*
BLOCK_SIZE
;
//0,1024,2048…
const
int
end_token_idx
=
MIN
(
start_token_idx
+
num_blocks
*
BLOCK_SIZE
,
seq_len
);
//1024,2048,3072…
const
int
num_tokens
=
end_token_idx
-
start_token_idx
;
//1024 or 1-1023
// divides NUM_THREADS
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
//4
constexpr
int
x
=
16
/
sizeof
(
cache_t
);
//8
const
int
thread_idx
=
threadIdx
.
x
;
// const int warp_idx_vec = thread_idx / WARP_SIZE;
// int warp_idx =0;
// asm volatile("v_readfirstlane_b32 %0,%1"
// : "=s"(warp_idx)
// : "v"(warp_idx_vec)
// :);
// // const int warp_idx = thread_idx / WARP_SIZE;
// const int lane = thread_idx % WARP_SIZE;
//const int warp_idx = thread_idx / WARP_SIZE;
const
int
warp_idx
=
__builtin_amdgcn_readfirstlane
(
thread_idx
/
WARP_SIZE
);
const
int
lane
=
thread_idx
%
WARP_SIZE
;
int
warp_id_vec
=
threadIdx
.
x
/
WARP_SIZE
;
//warp id in a block
int
warp_idx
=
0
;
asm
volatile
(
"v_readfirstlane_b32 %0,%1"
:
"=s"
(
warp_idx
)
:
"v"
(
warp_id_vec
)
:
);
// const int head_idx = blockIdx.x;
// const int num_heads = gridDim.x;
const
int
rowid
=
lane
%
16
;
const
int
rows
=
lane
/
16
;
const
int
num_queries_per_kv
=
num_heads
/
num_kv_heads
;
// const float alibi_slope =
// alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
// A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread
// group fetch or compute 16 bytes at a time. For example, if the size of a
// thread group is 4 and the data type is half, then the vector size is 16 /
// (4 * sizeof(half)) == 2.
constexpr
int
VEC_SIZE
=
MAX
(
32
/
(
THREAD_GROUP_SIZE
*
sizeof
(
scalar_t
)),
1
);
using
K_vec
=
typename
Vec
<
scalar_t
,
VEC_SIZE
>::
Type
;
using
Q_vec
=
typename
Vec
<
scalar_t
,
VEC_SIZE
>::
Type
;
using
Quant_vec
=
typename
Vec
<
cache_t
,
VEC_SIZE
>::
Type
;
constexpr
int
NUM_ELEMS_PER_THREAD
=
HEAD_SIZE
/
THREAD_GROUP_SIZE
;
constexpr
int
NUM_VECS_PER_THREAD
=
NUM_ELEMS_PER_THREAD
/
VEC_SIZE
;
const
int
thread_group_idx
=
thread_idx
/
THREAD_GROUP_SIZE
;
const
int
thread_group_offset
=
thread_idx
%
THREAD_GROUP_SIZE
;
// Load the query to registers.
// Each thread in a thread group has a different part of the query.
// For example, if the the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the query, and the second thread
// has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
// q is split from a qkv tensor, it may not be contiguous.
// const scalar_t* q_ptr = q + seq_idx * q_stride;
const
scalar_t
*
q_ptr_offset
=
q
+
seq_idx
*
q_stride
;
__shared__
Q_vec
q_vecs
[
REUSE_KV_TIMES
*
THREAD_GROUP_SIZE
][
NUM_VECS_PER_THREAD
];
// #pragma unroll
// for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
// i += NUM_THREAD_GROUPS) {
// const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
// q_vecs[thread_group_offset][i] =
// *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
// }
// __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a
// // memory wall right before we use q_vecs
const
int
num_blocks_per_kv
=
((
num_queries_per_kv
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
);
const
int
odd_tg_round
=
(((
blockIdx
.
z
*
gridDim
.
y
*
gridDim
.
x
)
+
blockIdx
.
y
*
gridDim
.
x
)
/
128
)
%
2
;
const
int
mid_x
=
gridDim
.
x
/
2
;
const
int
blockIdx_shift
=
(
odd_tg_round
|
(
gridDim
.
x
&
1
))
?
blockIdx
.
x
:
(
blockIdx
.
x
<
mid_x
?
(
blockIdx
.
x
+
mid_x
)
:
(
blockIdx
.
x
-
mid_x
));
const
int
head_idx
=
(
blockIdx_shift
/
num_blocks_per_kv
)
*
num_queries_per_kv
+
(
blockIdx_shift
%
num_blocks_per_kv
)
*
REUSE_KV_TIMES
;
//const int head_idx=(blockIdx.x / num_blocks_per_kv) * num_queries_per_kv + (blockIdx.x % num_blocks_per_kv) * REUSE_KV_TIMES;
int
q_boundary
=
REUSE_KV_TIMES
;
if
(
num_heads
<
REUSE_KV_TIMES
*
gridDim
.
x
&&
(
num_blocks_per_kv
-
1
)
*
REUSE_KV_TIMES
==
head_idx
%
num_queries_per_kv
)
q_boundary
=
num_queries_per_kv
-
(
num_blocks_per_kv
-
1
)
*
REUSE_KV_TIMES
;
const
int
kv_head_idx
=
head_idx
/
num_queries_per_kv
;
constexpr
int
reuse_group
=
(
REUSE_KV_TIMES
-
1
)
/
4
+
1
;
float
alibi_slope
[
reuse_group
]
=
{
0.
f
};
if
(
alibi_slopes
!=
nullptr
){
for
(
int
i
=
0
;
i
<
reuse_group
;
i
++
){
int
reuse_kv_idx
=
rows
+
i
*
4
;
if
(
reuse_kv_idx
<
q_boundary
)
alibi_slope
[
i
]
=
alibi_slopes
[
head_idx
+
reuse_kv_idx
];
}
}
float
qk_max
[
reuse_group
];
for
(
int
i
=
0
;
i
<
reuse_group
;
i
++
){
qk_max
[
i
]
=-
FLT_MAX
;
}
const
scalar_t
*
q_ptr
=
q
+
seq_idx
*
q_stride
+
head_idx
*
HEAD_SIZE
;
half4x2
q_vec
;
q_vec
.
data
[
0
]
=
{
0
,
0
,
0
,
0
};
q_vec
.
data
[
1
]
=
{
0
,
0
,
0
,
0
};
__shared__
half4x2
q_vecs
[
REUSE_KV_TIMES
][
16
];
//if(thread_idx==0)printf("blockIdx.x==%d,q_boundary=%d,head_idx=%d,kv_head_idx=%d\n",blockIdx.x,q_boundary,head_idx,kv_head_idx);
for
(
int
i
=
0
;
i
<
REUSE_KV_TIMES
;
i
++
){
if
(
thread_idx
<
16
){
q_vecs
[
i
][
thread_idx
]
=*
reinterpret_cast
<
const
half4x2
*>
(
q_ptr
+
i
*
HEAD_SIZE
+
thread_idx
*
8
);
}
}
__syncthreads
();
// Memory planning.
extern
__shared__
char
shared_mem
[];
// NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
floa
t
*
logits
=
reinterpret_cast
<
floa
t
*>
(
shared_mem
);
scalar_
t
*
logits
=
reinterpret_cast
<
scalar_
t
*>
(
shared_mem
);
// Workspace for reduction.
__shared__
float
red_smem
[
REUSE_KV_TIMES
][
2
*
NUM_WARPS
];
// float (*red_smem)[2 * NUM_WARPS] = reinterpret_cast<float(*)[2 * NUM_WARPS]>(&shared_mem[10*1024]);
// __shared__ char shared_mem[12 * 1024];
// float* logits = reinterpret_cast<float*>(shared_mem);
// __shared__ float red_smem[REUSE_KV_TIMES][2 * NUM_WARPS];
// x == THREAD_GROUP_SIZE * VEC_SIZE
// Each thread group fetches x elements from the key at a time.
constexpr
int
x
=
16
/
sizeof
(
cache_t
);
float
qk_max
[
REUSE_KV_TIMES
];
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
REUSE_KV_TIMES
;
reuse_kv_idx
++
)
{
qk_max
[
reuse_kv_idx
]
=
-
FLT_MAX
;
}
const
int
num_blocks_per_kv
=
((
num_queries_per_kv
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
);
const
int
head_idx_soffset
=
(
blockIdx
.
x
/
num_blocks_per_kv
)
*
num_queries_per_kv
+
(
blockIdx
.
x
%
num_blocks_per_kv
)
*
REUSE_KV_TIMES
;
const
int
kv_head_idx
=
head_idx_soffset
/
num_queries_per_kv
;
const
int
q_boundary
=
(
kv_head_idx
+
1
)
*
num_queries_per_kv
;
#pragma unroll
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
REUSE_KV_TIMES
;
reuse_kv_idx
++
)
{
const
int
head_idx
=
head_idx_soffset
+
reuse_kv_idx
;
//blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx;
const
scalar_t
*
q_ptr
=
q_ptr_offset
+
head_idx
*
HEAD_SIZE
;
#pragma unroll
for
(
int
i
=
thread_group_idx
;
i
<
NUM_VECS_PER_THREAD
;
i
+=
NUM_THREAD_GROUPS
)
{
const
int
vec_idx
=
thread_group_offset
+
i
*
THREAD_GROUP_SIZE
;
q_vecs
[
reuse_kv_idx
*
THREAD_GROUP_SIZE
+
thread_group_offset
][
i
]
=
*
reinterpret_cast
<
const
Q_vec
*>
(
q_ptr
+
vec_idx
*
VEC_SIZE
);
}
}
__syncthreads
();
// TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
__shared__
float
red_smem
[
2
*
NUM_WARPS
];
// Iterate over the key blocks.
// Each warp fetches a block of keys for each iteration.
// Each thread group in a warp fetches a key from the block, and computes
// dot product with the query.
const
int
*
block_table
=
block_tables
+
seq_idx
*
max_num_blocks_per_seq
;
for
(
int
block_idx
=
start_block_idx
+
warp_idx
;
block_idx
<
end_block_idx
;
block_idx
+=
NUM_WARPS
)
{
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
REUSE_KV_TIMES
;
reuse_kv_idx
++
)
{
const
int
head_idx
=
head_idx_soffset
+
reuse_kv_idx
;
//blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx;
if
(
!
odd_nheads
||
head_idx
<
q_boundary
)
{
// blocksparse specific vars
int
bs_block_offset
;
int
q_bs_block_id
;
if
constexpr
(
IS_BLOCK_SPARSE
)
{
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id
=
(
seq_len
-
1
)
/
blocksparse_block_size
;
if
(
blocksparse_head_sliding_step
>=
0
)
// sliding on q heads
bs_block_offset
=
(
tp_rank
*
num_heads
+
head_idx
)
*
blocksparse_head_sliding_step
+
1
;
else
// sliding on kv heads
bs_block_offset
=
(
tp_rank
*
num_kv_heads
+
kv_head_idx
)
*
(
-
blocksparse_head_sliding_step
)
+
1
;
}
if
constexpr
(
IS_BLOCK_SPARSE
)
{
const
int
k_bs_block_id
=
block_idx
*
BLOCK_SIZE
/
blocksparse_block_size
;
const
bool
is_remote
=
((
k_bs_block_id
+
bs_block_offset
)
%
blocksparse_vert_stride
==
0
);
const
bool
is_local
=
(
k_bs_block_id
>
q_bs_block_id
-
blocksparse_local_blocks
);
if
(
!
is_remote
&&
!
is_local
)
{
for
(
int
i
=
0
;
i
<
NUM_TOKENS_PER_THREAD_GROUP
;
i
++
)
{
const
int
physical_block_offset
=
(
thread_group_idx
+
i
*
WARP_SIZE
)
%
BLOCK_SIZE
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
physical_block_offset
;
if
(
thread_group_offset
==
0
)
{
// NOTE(linxihui): assign very large number to skipped tokens to
// avoid contribution to the sumexp softmax normalizer. This will
// not be used at computing sum(softmax*v) as the blocks will be
// skipped.
logits
[
token_idx
-
start_token_idx
]
=
-
FLT_MAX
;
}
}
continue
;
}
}
const
float
alibi_slope
=
alibi_slopes
==
nullptr
?
0.
f
:
alibi_slopes
[
head_idx
];
const
int64_t
physical_block_number
=
static_cast
<
int64_t
>
(
block_table
[
block_idx
]);
// blocksparse specific vars
int
bs_block_offset
;
int
q_bs_block_id
;
const
cache_t
*
k_ptr_base
=
k_cache
+
kv_head_idx
*
kv_head_stride
+
lane
*
8
;
// Load a key to registers.
// Each thread in a thread group has a different part of the key.
// For example, if the the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the key, and the second thread
// has 1, 5, 9, ... th vectors of the key, and so on.
for
(
int
i
=
0
;
i
<
NUM_TOKENS_PER_THREAD_GROUP
;
i
++
)
{
const
int
physical_block_offset
=
(
thread_group_idx
+
i
*
WARP_SIZE
)
%
BLOCK_SIZE
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
physical_block_offset
;
K_vec
k_vecs
[
NUM_VECS_PER_THREAD
];
if
(
reuse_kv_idx
==
0
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
NUM_VECS_PER_THREAD
;
j
++
)
{
const
cache_t
*
k_ptr
=
k_cache
+
physical_block_number
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
+
physical_block_offset
*
x
;
const
int
vec_idx
=
thread_group_offset
+
j
*
THREAD_GROUP_SIZE
;
const
int
offset1
=
(
vec_idx
*
VEC_SIZE
)
/
x
;
const
int
offset2
=
(
vec_idx
*
VEC_SIZE
)
%
x
;
for
(
int
block_idx
=
start_block_idx
+
warp_idx
;
block_idx
<
end_block_idx
;
block_idx
+=
NUM_WARPS
)
{
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
)
{
k_vecs
[
j
]
=
*
reinterpret_cast
<
const
K_vec
*>
(
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
}
else
{
// Vector conversion from Quant_vec to K_vec.
Quant_vec
k_vec_quant
=
*
reinterpret_cast
<
const
Quant_vec
*>
(
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
k_vecs
[
j
]
=
fp8
::
scaled_convert
<
K_vec
,
Quant_vec
,
KV_DTYPE
>
(
k_vec_quant
,
k_scale
);
}
const
int64_t
physical_block_number
=
static_cast
<
int64_t
>
(
block_table
[
block_idx
]);
const
cache_t
*
k_ptr
=
k_ptr_base
+
physical_block_number
*
kv_block_stride
;
float4_t
qk_vec
=
{
0
,
0
,
0
,
0
};
half4x2
k_vec
[
2
];
k_vec
[
0
]
=*
reinterpret_cast
<
const
half4x2
*>
(
k_ptr
);
#pragma unroll
for
(
int
i
=
0
;
i
<
3
;
i
++
){
if
(
rowid
<
q_boundary
)
q_vec
=
q_vecs
[
rowid
][
i
*
4
+
rows
];
k_vec
[
1
-
i
%
2
]
=*
reinterpret_cast
<
const
half4x2
*>
(
k_ptr
+
(
i
+
1
)
*
512
);
builtin_amdgcn_mmac
<
is_half
,
use_vmac
>
(
k_vec
[
i
%
2
].
data
[
0
],
q_vec
.
data
[
0
],
qk_vec
);
builtin_amdgcn_mmac
<
is_half
,
use_vmac
>
(
k_vec
[
i
%
2
].
data
[
1
],
q_vec
.
data
[
1
],
qk_vec
);
}
//tail
{
if
(
rowid
<
q_boundary
)
q_vec
=
q_vecs
[
rowid
][
3
*
4
+
rows
];
builtin_amdgcn_mmac
<
is_half
,
use_vmac
>
(
k_vec
[
1
].
data
[
0
],
q_vec
.
data
[
0
],
qk_vec
);
v_mmac_f32_16x16x16_f16
<
is_half
>
(
k_vec
[
1
].
data
[
1
],
q_vec
.
data
[
1
],
qk_vec
);
}
#pragma unroll
for
(
int
i
=
0
;
i
<
reuse_group
;
i
++
){
int
reuse_kv_idx
=
rows
+
i
*
4
;
if
(
reuse_kv_idx
<
REUSE_KV_TIMES
){
if
(
reuse_kv_idx
>=
q_boundary
)
qk_vec
[
i
]
=
0
;
else
qk_vec
[
i
]
*=
scale
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
rowid
;
if
(
alibi_slope
[
i
]
!=
0
){
float
alibi
=
alibi_slope
[
i
]
*
(
token_idx
-
seq_len
+
1
);
qk_vec
[
i
]
+=
alibi
;
}
const
bool
mask
=
(
token_idx
>=
seq_len
);
if
(
mask
){
from_float
(
logits
[
partition_size
*
reuse_kv_idx
+
token_idx
-
start_token_idx
]
,
0.
f
);
}
else
{
from_float
(
logits
[
partition_size
*
reuse_kv_idx
+
token_idx
-
start_token_idx
]
,
qk_vec
[
i
]);
qk_max
[
i
]
=
fmaxf
(
qk_max
[
i
],
qk_vec
[
i
]);
}
}
__builtin_amdgcn_sched_barrier
(
0
);
// Compute dot product.
// This includes a reduction across the threads in the same thread group.
float
qk
=
scale
*
Qk_dot
<
scalar_t
,
THREAD_GROUP_SIZE
>::
dot
(
q_vecs
[
reuse_kv_idx
*
THREAD_GROUP_SIZE
+
thread_group_offset
],
k_vecs
);
// Add the ALiBi bias if slopes are given.
qk
+=
(
alibi_slope
!=
0
)
?
alibi_slope
*
(
token_idx
-
seq_len
+
1
)
:
0
;
__builtin_amdgcn_sched_barrier
(
0
);
if
(
thread_group_offset
==
0
)
{
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
const
bool
mask
=
token_idx
>=
seq_len
;
logits
[(
reuse_kv_idx
*
partition_size
)
+
(
token_idx
-
start_token_idx
)]
=
mask
?
0.
f
:
qk
;
// Update the max value.
qk_max
[
reuse_kv_idx
]
=
mask
?
qk_max
[
reuse_kv_idx
]
:
fmaxf
(
qk_max
[
reuse_kv_idx
],
qk
);
}
}
}
}
}
// Get the sum of the exp values.
float
exp_sum
[
REUSE_KV_TIMES
]
=
{
0.
f
};
// Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value.
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
REUSE_KV_TIMES
;
reuse_kv_idx
++
)
{
const
int
head_idx
=
head_idx_soffset
+
reuse_kv_idx
;
if
(
!
odd_nheads
||
head_idx
<
q_boundary
)
{
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
THREAD_GROUP_SIZE
;
mask
/=
2
)
{
qk_max
[
reuse_kv_idx
]
=
fmaxf
(
qk_max
[
reuse_kv_idx
],
VLLM_SHFL_XOR_SYNC
(
qk_max
[
reuse_kv_idx
],
mask
));
}
if
(
lane
==
0
)
{
red_smem
[
reuse_kv_idx
][
warp_idx
]
=
qk_max
[
reuse_kv_idx
];
}
__syncthreads
();
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
q_boundary
;
reuse_kv_idx
++
)
{
const
int
head_idx_
=
head_idx
+
reuse_kv_idx
;
float
qk_max_tmp
=
qk_max
[
reuse_kv_idx
/
4
];
float
exp_sum
=
0.
f
;
#pragma unroll
for
(
int
mask
=
8
;
mask
>=
1
;
mask
/=
2
)
{
qk_max_tmp
=
fmaxf
(
qk_max_tmp
,
VLLM_SHFL_XOR_SYNC
(
qk_max_tmp
,
mask
));
}
if
(
rowid
==
0
&&
reuse_kv_idx
%
4
==
rows
)
{
red_smem
[
warp_idx
]
=
qk_max_tmp
;
}
__syncthreads
();
// TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence.
qk_max
[
reuse_kv_idx
]
=
lane
<
NUM_WARPS
?
red_smem
[
reuse_kv_idx
][
lane
]
:
-
FLT_MAX
;
// TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence.
qk_max
_tmp
=
lane
<
NUM_WARPS
?
red_smem
[
lane
]
:
-
FLT_MAX
;
#pragma unroll
for
(
int
mask
=
NUM_WARPS
/
2
;
mask
>=
1
;
mask
/=
2
)
{
qk_max
[
reuse_kv_idx
]
=
fmaxf
(
qk_max
[
reuse_kv_idx
]
,
VLLM_SHFL_XOR_SYNC
(
qk_max
[
reuse_kv_idx
]
,
mask
));
qk_max
_tmp
=
fmaxf
(
qk_max
_tmp
,
VLLM_SHFL_XOR_SYNC
(
qk_max
_tmp
,
mask
));
}
// Broadcast the max qk value to all threads.
qk_max
[
reuse_kv_idx
]
=
VLLM_SHFL_SYNC
(
qk_max
[
reuse_kv_idx
],
0
);
qk_max_tmp
=
VLLM_SHFL_SYNC
(
qk_max_tmp
,
0
);
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
float
val
=
__expf
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
-
qk_max
[
reuse_kv_idx
]
);
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
=
val
;
exp_sum
[
reuse_kv_idx
]
+=
val
;
float
val
=
__expf
(
to_float
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
)
-
qk_max
_tmp
);
from_float
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
,
val
)
;
exp_sum
+=
val
;
}
exp_sum
[
reuse_kv_idx
]
=
block_sum
<
NUM_WARPS
>
(
&
red_smem
[
reuse_kv_idx
][
NUM_WARPS
],
exp_sum
[
reuse_kv_idx
]);
exp_sum
=
block_sum
<
NUM_WARPS
>
(
&
red_smem
[
NUM_WARPS
],
exp_sum
);
// Compute softmax.
const
float
inv_sum
=
__fdividef
(
1.
f
,
exp_sum
[
reuse_kv_idx
]
+
1e-6
f
);
const
float
inv_sum
=
__fdividef
(
1.
f
,
exp_sum
+
1e-6
f
);
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
*=
inv_sum
;
from_float
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
,
to_float
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
])
*
inv_sum
)
;
}
__syncthreads
();
...
...
@@ -385,222 +335,212 @@ __device__ void paged_attention_kernel_opt(
if
(
USE_PARTITIONING
&&
thread_idx
==
0
)
{
float
*
max_logits_ptr
=
max_logits
+
seq_idx
*
num_heads
*
max_num_partitions
+
head_idx
*
max_num_partitions
+
partition_idx
;
*
max_logits_ptr
=
qk_max
[
reuse_kv_idx
]
;
head_idx
_
*
max_num_partitions
+
partition_idx
;
*
max_logits_ptr
=
qk_max
_tmp
;
float
*
exp_sums_ptr
=
exp_sums
+
seq_idx
*
num_heads
*
max_num_partitions
+
head_idx
*
max_num_partitions
+
partition_idx
;
*
exp_sums_ptr
=
exp_sum
[
reuse_kv_idx
]
;
head_idx
_
*
max_num_partitions
+
partition_idx
;
*
exp_sums_ptr
=
exp_sum
;
}
}
}
// Each thread will fetch 16 bytes from the value cache at a time.
constexpr
int
V_VEC_SIZE
=
MIN
(
16
/
sizeof
(
scalar_t
),
BLOCK_SIZE
);
using
V_vec
=
typename
Vec
<
scalar_t
,
V_VEC_SIZE
>::
Type
;
using
L_vec
=
typename
Vec
<
scalar_t
,
V_VEC_SIZE
>::
Type
;
using
V_quant_vec
=
typename
Vec
<
cache_t
,
V_VEC_SIZE
>::
Type
;
using
Float_L_vec
=
typename
FloatVec
<
L_vec
>::
Type
;
constexpr
int
NUM_V_VECS_PER_ROW
=
BLOCK_SIZE
/
V_VEC_SIZE
;
constexpr
int
NUM_ROWS_PER_ITER
=
WARP_SIZE
/
NUM_V_VECS_PER_ROW
;
constexpr
int
NUM_ROWS_PER_THREAD
=
DIVIDE_ROUND_UP
(
HEAD_SIZE
,
NUM_ROWS_PER_ITER
);
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float
accs
[
REUSE_KV_TIMES
][
NUM_ROWS_PER_THREAD
];
#pragma unroll
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
REUSE_KV_TIMES
;
reuse_kv_idx
++
)
{
constexpr
int
NUM_ROWS_PER_THREAD
=
DIVIDE_ROUND_UP
(
HEAD_SIZE
,
WARP_SIZE
);
//2
if
constexpr
(
REUSE_KV_TIMES
<=
2
&&
(
NUM_WARPS
>
64
||
USE_PARTITIONING
)){
float
accs
[
REUSE_KV_TIMES
][
NUM_ROWS_PER_THREAD
];
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
accs
[
reuse_kv_idx
][
i
]
=
0.
f
;
#pragma unroll
for
(
int
k
=
0
;
k
<
REUSE_KV_TIMES
;
k
++
)
{
accs
[
k
][
i
]
=
0.
f
;
}
}
}
scalar_t
zero_value
;
zero
(
zero_value
);
for
(
int
block_idx
=
start_block_idx
+
warp_idx
;
block_idx
<
end_block_idx
;
block_idx
+=
NUM_WARPS
)
{
const
int64_t
physical_block_number
=
static_cast
<
int64_t
>
(
block_table
[
block_idx
]);
const
int
physical_block_offset
=
(
lane
%
NUM_V_VECS_PER_ROW
)
*
V_VEC_SIZE
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
physical_block_offset
;
L_vec
logits_vec
;
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
V_vec
v_vec
;
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
REUSE_KV_TIMES
;
reuse_kv_idx
++
)
{
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
// blocksparse specific vars
const
int
head_idx
=
head_idx_soffset
+
reuse_kv_idx
;
int
bs_block_offset
;
int
q_bs_block_id
;
if
constexpr
(
IS_BLOCK_SPARSE
)
{
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id
=
(
seq_len
-
1
)
/
blocksparse_block_size
;
if
(
blocksparse_head_sliding_step
>=
0
)
// sliding on q heads
bs_block_offset
=
(
tp_rank
*
num_heads
+
head_idx
)
*
blocksparse_head_sliding_step
+
1
;
else
// sliding on kv heads
bs_block_offset
=
(
tp_rank
*
num_kv_heads
+
kv_head_idx
)
*
(
-
blocksparse_head_sliding_step
)
+
1
;
scalar_t
zero_value
;
zero
(
zero_value
);
for
(
int
block_idx
=
start_block_idx
+
warp_idx
;
block_idx
<
end_block_idx
;
block_idx
+=
NUM_WARPS
)
{
const
int64_t
physical_block_number
=
static_cast
<
int64_t
>
(
block_table
[
block_idx
]);
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
rows
*
4
;
half4_t
logits_vec
=
{
0
,
0
,
0
,
0
};
if
(
rowid
<
4
*
q_boundary
){
logits_vec
=*
reinterpret_cast
<
half4_t
*>
(
logits
+
rowid
/
4
*
partition_size
+
token_idx
-
start_token_idx
);
}
if
constexpr
(
IS_BLOCK_SPARSE
)
{
int
v_bs_block_id
=
block_idx
*
BLOCK_SIZE
/
blocksparse_block_size
;
if
(
!
((
v_bs_block_id
+
bs_block_offset
)
%
blocksparse_vert_stride
==
0
)
&&
!
((
v_bs_block_id
>
q_bs_block_id
-
blocksparse_local_blocks
)))
{
continue
;
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
+
rows
*
4
+
rowid
*
16
;
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
#pragma unroll
for
(
int
k
=
0
;
k
<
4
;
k
++
){
int
offset
=
i
*
1024
+
k
*
256
;
half4_t
v_vec
=*
reinterpret_cast
<
const
half4_t
*>
(
v_ptr
+
offset
);
if
(
block_idx
==
num_seq_blocks
-
1
)
{
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
);
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
v_vec_ptr
[
j
]
=
token_idx
+
j
<
seq_len
?
v_vec_ptr
[
j
]
:
zero_value
;
}
}
float4_t
out_vec
=
{
0
,
0
,
0
,
0
};
builtin_amdgcn_mmac
<
is_half
,
use_vmac
>
(
v_vec
,
logits_vec
,
out_vec
);
if
(
rows
==
k
){
for
(
int
resuseid
=
0
;
resuseid
<
REUSE_KV_TIMES
;
resuseid
++
){
accs
[
resuseid
][
i
]
+=
out_vec
[
resuseid
];
}
}
}
}
if
(
!
odd_nheads
||
head_idx
<
q_boundary
)
{
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
;
from_float
(
logits_vec
,
*
reinterpret_cast
<
Float_L_vec
*>
(
logits
+
(
reuse_kv_idx
*
partition_size
)
+
token_idx
-
start_token_idx
));
// scalar_t* logits_vec_ptr = reinterpret_cast<scalar_t*>(&logits_vec);
// for(int i=0;i<8;++i){
// from_float(*(logits_vec_ptr+i), 1000);
// }
if
(
reuse_kv_idx
==
0
)
{
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
)
{
const
int
offset
=
row_idx
*
BLOCK_SIZE
+
physical_block_offset
;
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
)
{
v_vec
=
*
reinterpret_cast
<
const
V_vec
*>
(
v_ptr
+
offset
);
}
else
{
V_quant_vec
v_quant_vec
=
*
reinterpret_cast
<
const
V_quant_vec
*>
(
v_ptr
+
offset
);
// Vector conversion from V_quant_vec to V_vec.
v_vec
=
fp8
::
scaled_convert
<
V_vec
,
V_quant_vec
,
KV_DTYPE
>
(
v_quant_vec
,
v_scale
);
}
}
__syncthreads
();
// Perform reduction across warps.
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
q_boundary
;
reuse_kv_idx
++
)
{
float
*
out_smem
=
reinterpret_cast
<
float
*>
(
shared_mem
);
#pragma unroll
for
(
int
i
=
NUM_WARPS
;
i
>
1
;
i
/=
2
)
{
int
mid
=
i
/
2
;
// Upper warps write to shared memory.
if
(
warp_idx
>=
mid
&&
warp_idx
<
i
)
{
float
*
dst
=
&
out_smem
[(
reuse_kv_idx
*
(
NUM_WARPS
/
2
)
*
HEAD_SIZE
)
+
(
warp_idx
-
mid
)
*
HEAD_SIZE
];
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
+
i
*
WARP_SIZE
;
dst
[
row_idx
]
=
accs
[
reuse_kv_idx
][
i
];
}
}
if
(
block_idx
==
num_seq_blocks
-
1
)
{
// NOTE(woosuk): When v_vec contains the tokens that are out of the
//
context, we should explicitly zero out the values since they may
// contain NaNs. See
// https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
);
#pragma unroll
for
(
int
j
=
0
;
j
<
V_VEC_SIZE
;
j
++
)
{
v_vec_ptr
[
j
]
=
token_idx
+
j
<
seq_len
?
v_vec_ptr
[
j
]
:
zero_value
;
__syncthreads
();
//
Lower warps update the output.
if
(
warp_idx
<
mid
)
{
const
float
*
src
=
&
out_smem
[(
reuse_kv_idx
*
(
NUM_WARPS
/
2
)
*
HEAD_SIZE
)
+
warp_idx
*
HEAD_SIZE
];
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
+
i
*
WARP_SIZE
;
accs
[
reuse_kv_idx
][
i
]
+=
src
[
row_idx
]
;
}
}
// if(threadIdx.x==0){
// scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
// scalar_t* logits_vec_ptr = reinterpret_cast<scalar_t*>(&logits_vec);
// for(int i=0;i<8;++i){
// printf("v_vec[%d] = %f\n",i, half_to_float(v_vec_ptr[i]));
// // from_float(*(v_vec_ptr + i), 1000);
// }
// for(int i=0;i<8;++i){
// printf("logits_vec[%d] = %f\n",i,half_to_float(logits_vec_ptr[i]));
// // from_float(*(logits_vec_ptr + i), 1000);
// }
// }
// accs[reuse_kv_idx][i] += dot(logits_vec, v_vec);
}
}
accs
[
reuse_kv_idx
][
i
]
+=
dot
(
logits_vec
,
v_vec
);
__syncthreads
();
}
// Write the final output.
if
(
warp_idx
==
0
)
{
scalar_t
*
out_ptr
=
out
+
seq_idx
*
num_heads
*
max_num_partitions
*
HEAD_SIZE
+
(
head_idx
+
reuse_kv_idx
)
*
max_num_partitions
*
HEAD_SIZE
+
partition_idx
*
HEAD_SIZE
;
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
+
i
*
WARP_SIZE
;
from_float
(
*
(
out_ptr
+
row_idx
),
accs
[
reuse_kv_idx
][
i
]);
}
}
}
}
// Perform reduction within each warp.
#pragma unroll
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
REUSE_KV_TIMES
;
reuse_kv_idx
++
)
{
int
head_idx
=
head_idx_soffset
+
reuse_kv_idx
;
if
(
!
odd_nheads
||
head_idx
<
q_boundary
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
float
acc
=
accs
[
reuse_kv_idx
][
i
];
#pragma unroll
for
(
int
mask
=
NUM_V_VECS_PER_ROW
/
2
;
mask
>=
1
;
mask
/=
2
)
{
acc
+=
VLLM_SHFL_XOR_SYNC
(
acc
,
mask
);
else
{
constexpr
int
GROUPS
=
reuse_group
*
4
;
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float
accs
[
GROUPS
][
NUM_ROWS_PER_THREAD
];
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
#pragma unroll
for
(
int
k
=
0
;
k
<
GROUPS
;
k
++
)
{
accs
[
k
][
i
]
=
0.
f
;
}
}
accs
[
reuse_kv_idx
][
i
]
=
acc
;
}
// NOTE(woosuk): A barrier is required because the shared memory space for
// logits is reused for the output.
__syncthreads
();
// Perform reduction across warps.
float
*
out_smem
=
reinterpret_cast
<
float
*>
(
shared_mem
);
#pragma unroll
for
(
int
i
=
NUM_WARPS
;
i
>
1
;
i
/=
2
)
{
int
mid
=
i
/
2
;
// Upper warps write to shared memory.
if
(
warp_idx
>=
mid
&&
warp_idx
<
i
)
{
float
*
dst
=
&
out_smem
[(
reuse_kv_idx
*
(
NUM_WARPS
/
2
)
*
HEAD_SIZE
)
+
(
warp_idx
-
mid
)
*
HEAD_SIZE
];
#pragma unroll
scalar_t
zero_value
;
zero
(
zero_value
);
for
(
int
block_idx
=
start_block_idx
+
warp_idx
;
block_idx
<
end_block_idx
;
block_idx
+=
NUM_WARPS
)
{
const
int64_t
physical_block_number
=
static_cast
<
int64_t
>
(
block_table
[
block_idx
]);
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
rows
*
4
;
half4_t
logits_vec
=
{
0
,
0
,
0
,
0
};
if
(
rowid
<
q_boundary
){
logits_vec
=*
reinterpret_cast
<
half4_t
*>
(
logits
+
rowid
*
partition_size
+
token_idx
-
start_token_idx
);
}
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
+
rows
*
4
+
rowid
*
16
;
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
dst
[
row_idx
]
=
accs
[
reuse_kv_idx
][
i
];
#pragma unroll
for
(
int
k
=
0
;
k
<
4
;
k
++
){
int
offset
=
i
*
1024
+
k
*
256
;
half4_t
v_vec
=*
reinterpret_cast
<
const
half4_t
*>
(
v_ptr
+
offset
);
if
(
block_idx
==
num_seq_blocks
-
1
)
{
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
);
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
v_vec_ptr
[
j
]
=
token_idx
+
j
<
seq_len
?
v_vec_ptr
[
j
]
:
zero_value
;
}
}
float4_t
out_vec
=
{
0
,
0
,
0
,
0
};
builtin_amdgcn_mmac
<
is_half
,
use_vmac
>
(
v_vec
,
logits_vec
,
out_vec
);
for
(
int
g
=
0
;
g
<
reuse_group
;
g
++
){
accs
[
g
*
4
+
k
][
i
]
+=
out_vec
[
g
];
}
}
}
}
}
}
__syncthreads
();
// Lower warps update the output.
if
(
warp_idx
<
mid
)
{
const
float
*
src
=
&
out_smem
[(
reuse_kv_idx
*
(
NUM_WARPS
/
2
)
*
HEAD_SIZE
)
+
warp_idx
*
HEAD_SIZE
];
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
accs
[
reuse_kv_idx
][
i
]
+=
src
[
row_idx
];
// Perform reduction across warps.
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
GROUPS
;
reuse_kv_idx
++
)
{
float
*
out_smem
=
reinterpret_cast
<
float
*>
(
shared_mem
);
#pragma unroll
for
(
int
i
=
NUM_WARPS
;
i
>
1
;
i
/=
2
)
{
int
mid
=
i
/
2
;
// Upper warps write to shared memory.
if
(
warp_idx
>=
mid
&&
warp_idx
<
i
)
{
float
*
dst
=
&
out_smem
[(
reuse_kv_idx
*
(
NUM_WARPS
/
2
)
*
HEAD_SIZE
)
+
(
warp_idx
-
mid
)
*
HEAD_SIZE
];
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
+
i
*
WARP_SIZE
;
dst
[
row_idx
]
=
accs
[
reuse_kv_idx
][
i
];
}
}
__syncthreads
();
if
(
warp_idx
<
mid
)
{
const
float
*
src
=
&
out_smem
[(
reuse_kv_idx
*
(
NUM_WARPS
/
2
)
*
HEAD_SIZE
)
+
warp_idx
*
HEAD_SIZE
];
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
+
i
*
WARP_SIZE
;
accs
[
reuse_kv_idx
][
i
]
+=
src
[
row_idx
];
}
}
__syncthreads
();
}
// Write the final output.
}
__syncthreads
();
}
// Write the final output.
if
(
warp_idx
==
0
)
{
scalar_t
*
out_ptr
=
out
+
seq_idx
*
num_heads
*
max_num_partitions
*
HEAD_SIZE
+
head_idx
*
max_num_partitions
*
HEAD_SIZE
+
partition_idx
*
HEAD_SIZE
;
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
from_float
(
*
(
out_ptr
+
row_idx
),
accs
[
reuse_kv_idx
][
i
]);
if
(
warp_idx
==
0
)
{
for
(
int
g
=
0
;
g
<
reuse_group
;
g
++
){
int
reusekvid
=
g
*
4
+
rows
;
if
(
reusekvid
<
q_boundary
){
scalar_t
*
out_ptr
=
out
+
seq_idx
*
num_heads
*
max_num_partitions
*
HEAD_SIZE
+
(
head_idx
+
reusekvid
)
*
max_num_partitions
*
HEAD_SIZE
+
partition_idx
*
HEAD_SIZE
;
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
k
=
0
;
k
<
4
;
k
++
){
const
int
row_idx
=
rowid
+
16
*
k
+
i
*
WARP_SIZE
;
from_float
(
*
(
out_ptr
+
row_idx
),
accs
[
g
*
4
+
k
][
i
]);
}
}
}
}
}
}
}
}
// Grid: (num_heads, num_seqs, 1).
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
int
REUSE_KV_TIMES
=
1
,
bool
IS_BLOCK_SPARSE
,
bool
odd_nheads
=
false
>
__global__
__launch_bounds__
(
256
,
1
)
void
paged_attention_v1_kernel_opt
(
bool
IS_BLOCK_SPARSE
,
int
REUSE_KV_TIMES
,
bool
use_vmac
>
__global__
void
paged_attention_v1_kernel_TC
(
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
// head_size, block_size]
const
int
num_heads
,
// [num_heads]
const
int
num_heads
,
const
int
num_kv_heads
,
// [num_heads]
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
...
...
@@ -608,28 +548,27 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel_opt(
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
paged_attention_kernel_opt
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
odd_nheads
>
(
/* exp_sums */
nullptr
,
/* max_logits */
nullptr
,
out
,
q
,
k_cache
,
v_cache
,
num_heads
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
}
#ifdef __gfx928__
paged_attention_kernel_TC
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
use_vmac
>
(
/* exp_sums */
nullptr
,
/* max_logits */
nullptr
,
out
,
q
,
k_cache
,
v_cache
,
num_heads
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
#endif
}
// Grid: (num_heads, num_seqs, max_num_partitions).
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
int
REUSE_KV_TIMES
,
int
PARTITION_SIZE
,
bool
IS_BLOCK_SPARSE
,
int
REUSE_KV_TIMES
,
bool
use_vmac
,
int
PARTITION_SIZE
,
bool
odd_nheads
=
false
>
__global__
__launch_bounds__
(
256
,
1
)
void
paged_attention_v2_kernel_
opt
(
__global__
__launch_bounds__
(
256
,
1
)
void
paged_attention_v2_kernel_
TC
(
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
// max_num_partitions]
...
...
@@ -640,7 +579,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel_opt(
// head_size/x, block_size, x]
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
// head_size, block_size]
const
int
num_heads
,
// [num_heads]
const
int
num_heads
,
// [num_heads]
const
int
num_kv_heads
,
// [num_kv_heads]
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
...
...
@@ -648,23 +587,24 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel_opt(
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
paged_attention_kernel_opt
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
odd_nheads
,
PARTITION_SIZE
>
(
exp_sums
,
max_logits
,
tmp_out
,
q
,
k_cache
,
v_cache
,
num_heads
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
#ifdef __gfx928__
paged_attention_kernel_TC
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
use_vmac
,
PARTITION_SIZE
>
(
exp_sums
,
max_logits
,
tmp_out
,
q
,
k_cache
,
v_cache
,
num_heads
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
#endif
}
// Grid: (num_heads, num_seqs).
template
<
typename
scalar_t
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
int
PARTITION_SIZE
>
__global__
__launch_bounds__
(
256
,
1
)
void
paged_attention_v2_reduce_kernel_opt
(
template
<
typename
scalar_t
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
int
PARTITION_SIZE
>
__global__
__launch_bounds__
(
256
,
1
)
void
paged_attention_v2_reduce_kernel_opt
(
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
const
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads,
// max_num_partitions]
...
...
@@ -717,7 +657,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt(
// Get the global max logit.
// Reduce within the warp.
#pragma unroll
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
max_logit
=
fmaxf
(
max_logit
,
VLLM_SHFL_XOR_SYNC
(
max_logit
,
mask
));
}
...
...
@@ -727,7 +667,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt(
__syncthreads
();
// Reduce across warps.
max_logit
=
lane
<
NUM_WARPS
?
red_smem
[
lane
]
:
-
FLT_MAX
;
#pragma unroll
#pragma unroll
for
(
int
mask
=
NUM_WARPS
/
2
;
mask
>=
1
;
mask
/=
2
)
{
max_logit
=
fmaxf
(
max_logit
,
VLLM_SHFL_XOR_SYNC
(
max_logit
,
mask
));
}
...
...
@@ -757,7 +697,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt(
head_idx
*
max_num_partitions
*
HEAD_SIZE
;
scalar_t
*
out_ptr
=
out
+
seq_idx
*
num_heads
*
HEAD_SIZE
+
head_idx
*
HEAD_SIZE
;
#pragma unroll
#pragma unroll
for
(
int
i
=
threadIdx
.
x
;
i
<
HEAD_SIZE
;
i
+=
NUM_THREADS
)
{
float
acc
=
0.0
f
;
for
(
int
j
=
0
;
j
<
num_partitions
;
++
j
)
{
...
...
@@ -770,37 +710,27 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt(
}
// namespace vllm
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
#define LAUNCH_PAGED_ATTENTION_V1_TC(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel_
opt
<T, CACHE_T, HEAD_SIZE, \
((void*)vllm::paged_attention_v1_kernel_
TC
<T, CACHE_T, HEAD_SIZE, \
BLOCK_SIZE, NUM_THREADS, \
KV_DTYPE,
REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads
>), \
KV_DTYPE,
IS_BLOCK_SPARSE,REUSE_KV_TIMES,use_vmac
>), \
shared_mem_size); \
hipLaunchKernelGGL((
vllm::paged_attention_v1_kernel_
opt
<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE,
REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>)
\
, dim3(
grid
)
,
dim3(
block
)
, shared_mem_size, stream
,
\
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads,
num_kv_heads, \
vllm::paged_attention_v1_kernel_
TC
<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE,
IS_BLOCK_SPARSE,REUSE_KV_TIMES,use_vmac>
\
<<<
grid, block, shared_mem_size, stream
>>>(
\
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads,num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks,
\
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
// #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
// vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
// NUM_THREADS, KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads> \
// <<<dim3(grid), dim3(block)>>>( \
// out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
// scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
// alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
// kv_scale, tp_rank, blocksparse_local_blocks, \
// blocksparse_vert_stride, blocksparse_block_size, \
// blocksparse_head_sliding_step);
// TODO(woosuk): Tune NUM_THREADS.
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
>
void
paged_attention_v1_launcher
(
void
paged_attention_v1_launcher
_opt
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
max_seq_len
,
...
...
@@ -816,11 +746,11 @@ void paged_attention_v1_launcher(
int
kv_block_stride
=
key_cache
.
stride
(
0
);
int
kv_head_stride
=
key_cache
.
stride
(
1
);
int
num_threads
=
128
;
if
(
num_heads
!=
num_kv_heads
){
num_threads
=
256
;
// printf("paged_attention_v1\n");
if
(
num_heads
!=
num_kv_heads
)
{
num_threads
=
256
;
}
[[
maybe_unused
]]
int
thread_group_size
=
MAX
(
WARP_SIZE
/
BLOCK_SIZE
,
1
);
[[
maybe_unused
]]
int
thread_group_size
=
MAX
(
WARP_SIZE
/
BLOCK_SIZE
,
1
);
assert
(
head_size
%
thread_group_size
==
0
);
// NOTE: alibi_slopes is optional.
...
...
@@ -835,39 +765,48 @@ void paged_attention_v1_launcher(
CACHE_T
*
value_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
value_cache
.
data_ptr
());
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
int
padded_max_seq_len
=
DIVIDE_ROUND_UP
(
max_seq_len
,
BLOCK_SIZE
)
*
BLOCK_SIZE
;
REUSEKV_SWITCH_V1
(
num_heads
*
num_seqs
,
[
&
]
{
BOOL_SWITCH
((
num_heads
/
num_kv_heads
%
REUSE_KV_TIMES
!=
0
),
odd_nheads
,
[
&
]
{
HEADSIZE_SWITCH
(
head_size
,
[
&
]
{
NUM_THREADS_SWITCH
(
num_threads
,
[
&
]
{
OPT_SWITCH
(
num_heads
==
num_kv_heads
,
[
&
]
{
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
logits_size
=
REUSE_KV_TIMES
*
padded_max_seq_len
*
sizeof
(
float
);
int
outputs_size
=
REUSE_KV_TIMES
*
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
// Keep that in sync with the logic here!
int
shared_mem_size
=
::
max
(
logits_size
,
outputs_size
);
if
(
num_heads
==
num_kv_heads
)
shared_mem_size
=
::
max
(
12
*
1024
,
shared_mem_size
);
// int shared_mem_size = ::max(31*1024, ::max(logits_size, outputs_size));
// std::cout<<"shared_mem_size = "<<shared_mem_size<<std::endl;
dim3
grid
((
num_heads
/
num_kv_heads
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
*
num_kv_heads
,
1
,
num_seqs
);
dim3
block
(
NUM_THREADS
);
const
at
::
hip
::
OptionalHIPGuardMasqueradingAsCUDA
device_guard
(
device_of
(
query
));
const
hipStream_t
stream
=
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
();
LAUNCH_PAGED_ATTENTION_V1
(
HEAD_SIZE
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
if
constexpr
(
BLOCK_SIZE
==
16
&&
IS_BLOCK_SPARSE
==
false
&&
sizeof
(
T
)
==
2
&&
KV_DTYPE
==
vllm
::
Fp8KVCacheDataType
::
kAuto
){
// if(head_size==128&&get_device_name()=="gfx928"){
REUSEKV_SWITCH_V1
([
&
]
{
constexpr
int
HEAD_SIZE
=
128
;
// constexpr int REUSE_KV_TIMES=8;
int
num_thread
=
64
;
if
(
REUSE_KV_TIMES
>
1
){
if
(
padded_max_seq_len
>
1024
||
num_heads
*
num_seqs
/
REUSE_KV_TIMES
<
600
)
num_thread
=
256
;
else
num_thread
=
128
;
}
else
if
(
num_heads
*
num_seqs
<
800
)
num_thread
=
128
;
NUM_THREADS_SWITCH
(
num_thread
,
[
&
]
{
constexpr
static
int
use_vmac
=
false
;
//constexpr int NUM_THREADS = WARP_SIZE * REUSE_KV_TIMES;
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
logits_size
=
REUSE_KV_TIMES
*
padded_max_seq_len
*
2
;
int
outputs_size
=
REUSE_KV_TIMES
*
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
if
(
REUSE_KV_TIMES
==
1
)
outputs_size
=
0
;
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
// Keep that in sync with the logic here!
int
shared_mem_size
=
::
max
(
logits_size
,
outputs_size
);
if
(
num_heads
==
num_kv_heads
)
shared_mem_size
=
::
max
(
12
*
1024
,
shared_mem_size
);
// int shared_mem_size = ::max(31*1024, ::max(logits_size, outputs_size));
// std::cout<<"shared_mem_size = "<<shared_mem_size<<std::endl;
// printf("REUSE_KV_TIMES=%d,use_vmac=%d\n",REUSE_KV_TIMES,(int)use_vmac);
dim3
grid
((
num_heads
/
num_kv_heads
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
*
num_kv_heads
,
1
,
num_seqs
);
dim3
block
(
NUM_THREADS
);
LAUNCH_PAGED_ATTENTION_V1_TC
(
HEAD_SIZE
);
});
});
});
});
});
}
// }
}
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
paged_attention_v1_launcher
_opt
<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank,
\
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step);
...
...
@@ -899,6 +838,24 @@ void paged_attention_v1_launcher(
break; \
}
void
paged_attention_v1
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
query
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
key_cache
,
// [num_blocks, num_heads, head_size/x, block_size, x]
torch
::
Tensor
&
value_cache
,
// [num_blocks, num_heads, head_size, block_size]
int64_t
num_kv_heads
,
// [num_heads]
double
scale
,
torch
::
Tensor
&
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
void
paged_attention_v1_opt
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
query
,
// [num_seqs, num_heads, head_size]
...
...
@@ -912,37 +869,46 @@ void paged_attention_v1_opt(
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
DISPATCH_BY_KV_CACHE_DTYPE
(
query
.
dtype
(),
kv_cache_dtype
,
CALL_V1_LAUNCHER_BLOCK_SIZE
)
if
(
kv_cache_dtype
!=
"auto"
||
query
.
dtype
()
==
at
::
ScalarType
::
Float
||
is_block_sparse
||
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
get_device_name
()
!=
"gfx928"
){
paged_attention_v1
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
}
else
{
DISPATCH_BY_KV_CACHE_DTYPE
(
query
.
dtype
(),
kv_cache_dtype
,
CALL_V1_LAUNCHER_BLOCK_SIZE
)
}
}
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
hipLaunchKernelGGL(( vllm::paged_attention_v2_kernel_opt<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \
REUSE_KV_TIMES, PARTITION_SIZE, odd_nheads>) \
, dim3(grid), dim3(block), shared_mem_size, stream, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_heads, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \
hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel_opt<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE>) \
, dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions);
#define LAUNCH_PAGED_ATTENTION_V2_TC(HEAD_SIZE) \
hipLaunchKernelGGL( \
(vllm::paged_attention_v2_kernel_TC< \
T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, \
IS_BLOCK_SPARSE, REUSE_KV_TIMES,use_vmac, PARTITION_SIZE>), \
dim3(grid), dim3(block), shared_mem_size, stream, exp_sums_ptr, \
max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, \
num_heads, num_kv_heads, scale, block_tables_ptr, seq_lens_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step); \
hipLaunchKernelGGL( \
(vllm::paged_attention_v2_reduce_kernel_opt<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE>), \
dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, out_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions);
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
int
NUM_THREADS
=
256
,
int
PARTITION_SIZE
=
512
>
void
paged_attention_v2_launcher
(
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
int
PARTITION_SIZE
=
512
>
void
paged_attention_v2_launcher_opt
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
...
...
@@ -958,8 +924,8 @@ void paged_attention_v2_launcher(
int
q_stride
=
query
.
stride
(
0
);
int
kv_block_stride
=
key_cache
.
stride
(
0
);
int
kv_head_stride
=
key_cache
.
stride
(
1
);
[[
maybe_unused
]]
int
thread_group_size
=
MAX
(
WARP_SIZE
/
BLOCK_SIZE
,
1
);
// printf("paged_attention_v2\n");
int
thread_group_size
=
MAX
(
WARP_SIZE
/
BLOCK_SIZE
,
1
);
assert
(
head_size
%
thread_group_size
==
0
);
// NOTE: alibi_slopes is optional.
...
...
@@ -977,40 +943,46 @@ void paged_attention_v2_launcher(
CACHE_T
*
value_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
value_cache
.
data_ptr
());
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
dim3
reduce_grid
(
num_heads
,
num_seqs
);
int
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
REUSEKV_SWITCH
(
num_heads
*
max_num_partitions
*
num_seqs
,
[
&
]
{
BOOL_SWITCH
((
num_heads
/
num_kv_heads
%
REUSE_KV_TIMES
!=
0
),
odd_nheads
,
[
&
]
{
HEADSIZE_SWITCH
(
head_size
,
[
&
]
{
OPT_SWITCH
(
num_heads
==
num_kv_heads
,
[
&
]
{
int
logits_size
=
REUSE_KV_TIMES
*
PARTITION_SIZE
*
sizeof
(
float
);
int
outputs_size
=
REUSE_KV_TIMES
*
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
// For paged attention v2 kernel.
// dim3 grid(num_heads, max_num_partitions, num_seqs);
dim3
grid
;
grid
.
x
=
(
num_heads
/
num_kv_heads
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
*
num_kv_heads
;
grid
.
y
=
max_num_partitions
;
grid
.
z
=
num_seqs
;
// int shared_mem_size = ::max(1024*32, ::max(logits_size, outputs_size));
int
shared_mem_size
=
::
max
(
logits_size
,
outputs_size
);
// For paged attention v2 reduce kernel.
dim3
reduce_grid
(
num_heads
,
num_seqs
);
int
reduce_shared_mem_size
=
2
*
max_num_partitions
*
sizeof
(
float
);
dim3
block
(
NUM_THREADS
);
const
at
::
hip
::
OptionalHIPGuardMasqueradingAsCUDA
device_guard
(
device_of
(
query
));
const
hipStream_t
stream
=
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
();
LAUNCH_PAGED_ATTENTION_V2
(
HEAD_SIZE
);
int
reduce_shared_mem_size
=
2
*
max_num_partitions
*
sizeof
(
float
);
if
constexpr
(
BLOCK_SIZE
==
16
&&
IS_BLOCK_SPARSE
==
false
&&
sizeof
(
T
)
==
2
&&
KV_DTYPE
==
vllm
::
Fp8KVCacheDataType
::
kAuto
){
//if(head_size==128&&get_device_name()=="gfx928"){
constexpr
int
HEAD_SIZE
=
128
;
constexpr
static
int
use_vmac
=
false
;
REUSEKV_SWITCH_V2
([
&
]
{
int
num_thread
;
if
(
REUSE_KV_TIMES
>
1
){
if
(
num_seqs
<
16
)
num_thread
=
256
;
else
if
(
max_num_partitions
*
num_seqs
*
num_heads
/
REUSE_KV_TIMES
>
4000
)
num_thread
=
64
;
else
num_thread
=
128
;
}
else
{
if
(
num_seqs
<
16
&&
max_num_partitions
<
10
)
num_thread
=
256
;
else
num_thread
=
64
;
}
NUM_THREADS_SWITCH
(
num_thread
,
[
&
]
{
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
logits_size
=
REUSE_KV_TIMES
*
PARTITION_SIZE
*
2
;
int
outputs_size
=
REUSE_KV_TIMES
*
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
dim3
grid
;
grid
.
x
=
(
num_heads
/
num_kv_heads
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
*
num_kv_heads
;
grid
.
y
=
max_num_partitions
;
grid
.
z
=
num_seqs
;
dim3
block
(
NUM_THREADS
);
int
shared_mem_size
=
::
max
(
logits_size
,
outputs_size
);
LAUNCH_PAGED_ATTENTION_V2_TC
(
HEAD_SIZE
);
});
});
}
);
});
}
//}
}
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE,
\
paged_attention_v2_launcher
_opt
<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
...
...
@@ -1046,7 +1018,7 @@ void paged_attention_v2_launcher(
break; \
}
void
paged_attention_v2
_opt
(
void
paged_attention_v2
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
torch
::
Tensor
&
max_logits
,
// [num_seqs, num_heads, max_num_partitions]
...
...
@@ -1063,13 +1035,44 @@ void paged_attention_v2_opt(
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
void
paged_attention_v2_opt
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
torch
::
Tensor
&
max_logits
,
// [num_seqs, num_heads, max_num_partitions]
torch
::
Tensor
&
tmp_out
,
// [num_seqs, num_heads, max_num_partitions, head_size]
torch
::
Tensor
&
query
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
key_cache
,
// [num_blocks, num_heads, head_size/x, block_size, x]
torch
::
Tensor
&
value_cache
,
// [num_blocks, num_heads, head_size, block_size]
int64_t
num_kv_heads
,
// [num_heads]
double
scale
,
torch
::
Tensor
&
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
DISPATCH_BY_KV_CACHE_DTYPE
(
query
.
dtype
(),
kv_cache_dtype
,
CALL_V2_LAUNCHER_BLOCK_SIZE
)
if
(
kv_cache_dtype
!=
"auto"
||
query
.
dtype
()
==
at
::
ScalarType
::
Float
||
is_block_sparse
||
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
get_device_name
()
!=
"gfx928"
){
paged_attention_v2
(
out
,
exp_sums
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
}
else
{
DISPATCH_BY_KV_CACHE_DTYPE
(
query
.
dtype
(),
kv_cache_dtype
,
CALL_V2_LAUNCHER_BLOCK_SIZE
)
}
}
#undef WARP_SIZE
...
...
csrc/attention/static_switch.h
View file @
0c70376b
...
...
@@ -9,25 +9,17 @@
} \
}()
#define OPT_SWITCH(COND, ...) \
[&] { \
if (COND) { \
constexpr static int opt = 1; \
return __VA_ARGS__(); \
} else { \
constexpr static int opt = 2; \
return __VA_ARGS__(); \
} \
}()
#define NUM_THREADS_SWITCH(NUM_THREAD, ...) \
[&] { \
if (NUM_THREAD == 256) { \
constexpr static int NUM_THREADS = 256; \
return __VA_ARGS__(); \
}
else
{
\
}else
if (NUM_THREAD == 128) {
\
constexpr static int NUM_THREADS = 128; \
return __VA_ARGS__(); \
} else { \
constexpr static int NUM_THREADS = 64; \
return __VA_ARGS__(); \
} \
}()
...
...
@@ -45,12 +37,12 @@
} else if (HEADDIM == 112) { \
constexpr static int HEAD_SIZE = 112; \
return __VA_ARGS__(); \
} else if (HEADDIM == 120) { \
constexpr static int HEAD_SIZE = 120; \
return __VA_ARGS__(); \
} else if (HEADDIM == 128) { \
constexpr static int HEAD_SIZE = 128; \
return __VA_ARGS__(); \
} else if (HEADDIM == 192) { \
constexpr static int HEAD_SIZE = 192; \
return __VA_ARGS__(); \
} else if (HEADDIM == 256) { \
constexpr static int HEAD_SIZE = 256; \
return __VA_ARGS__(); \
...
...
@@ -74,14 +66,48 @@
} \
}()
#define REUSEKV_SWITCH_V1(num_blocks , ...) \
#define REUSEKV_SWITCH_V2( ...) \
[&] { \
if (num_heads / num_kv_heads > 8 ){ \
constexpr static int REUSE_KV_TIMES = 16; \
return __VA_ARGS__(); \
}else if (num_heads / num_kv_heads > 4 ){ \
constexpr static int REUSE_KV_TIMES = 8; \
return __VA_ARGS__(); \
}else if (num_heads / num_kv_heads > 2 ){ \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \
} else { \
constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \
} \
}()
#define REUSEKV_SWITCH_V1( ...) \
[&] { \
if (num_heads > num_kv_heads && num_blocks >= 1200){ \
if (num_heads/num_kv_heads >4 && padded_max_seq_len<3900){ \
constexpr static int REUSE_KV_TIMES = 8; \
return __VA_ARGS__(); \
}else if (num_heads/num_kv_heads >2 && padded_max_seq_len<7800){ \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \
}else if (num_heads/num_kv_heads ==2 && padded_max_seq_len<15600){ \
constexpr static int REUSE_KV_TIMES = 2; \
return __VA_ARGS__(); \
}
else { \
}else { \
constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \
} \
}()
#define USEVMAC_SWITCH_V1(num_blocks , ...) \
[&] { \
if (REUSE_KV_TIMES==1&&(num_blocks >2500 || padded_max_seq_len > 2048)){ \
constexpr static int use_vmac = false; \
return __VA_ARGS__(); \
} else { \
constexpr static int use_vmac = true; \
return __VA_ARGS__(); \
} \
}()
\ No newline at end of file
requirements-rocm.txt
View file @
0c70376b
...
...
@@ -14,4 +14,4 @@ torch == 2.3.0
triton == 2.1.0
flash_attn == 2.6.1
xformers == 0.0.25
lmslim == 0.1.0
\ No newline at end of file
lmslim == 0.1.1
\ No newline at end of file
vllm/attention/ops/paged_attn.py
View file @
0c70376b
...
...
@@ -124,8 +124,10 @@ class PagedAttention:
# to parallelize.
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1
=
(
max_seq_len
<=
8192
and
(
max_num_partitions
==
1
or
num_seqs
*
num_heads
>
512
))
# use_v1 = (max_seq_len <= 8192
# and (max_num_partitions == 1 or num_seqs * num_heads > 512))
use_v1
=
(
max_seq_len
<
8192
and
(
max_seq_len
<
1000
or
num_seqs
*
num_heads
>
(
1024
if
num_kv_heads
<
num_heads
else
512
)))
if
use_v1
:
# Run PagedAttention V1.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment