Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2b91ac93
Commit
2b91ac93
authored
Sep 24, 2024
by
zhuwenwen
Browse files
优化pa小batch性能(pa_v2),优化pa小seq性能(pa_v1),reusekv=16优化
parent
de7d9456
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
140 additions
and
82 deletions
+140
-82
csrc/attention/attention_kernels_opt.cu
csrc/attention/attention_kernels_opt.cu
+133
-46
csrc/attention/static_switch.h
csrc/attention/static_switch.h
+6
-35
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+1
-1
No files found.
csrc/attention/attention_kernels_opt.cu
View file @
2b91ac93
...
...
@@ -43,7 +43,16 @@ inline std::string get_device_name()
const
std
::
string
raw_name
(
props
.
gcnArchName
);
return
raw_name
.
substr
(
0
,
raw_name
.
find
(
':'
));
// str.substr(0, npos) returns str.
}
static
inline
int
get_env_
(
const
char
*
env_var
)
{
if
(
char
*
value
=
std
::
getenv
(
env_var
))
{
return
atoi
(
value
);
}
return
0
;
}
static
const
int
PA_REUSE_KV_TIMES
=
get_env_
(
"PA_REUSE_KV_TIMES"
);
static
const
int
PA_BLOCK_SIZE
=
get_env_
(
"PA_BLOCK_SIZE"
);
static
const
int
PA_PRINT_PARAM
=
get_env_
(
"PA_PRINT_PARAM"
);
namespace
vllm
{
// Utility function for attention softmax.
...
...
@@ -344,7 +353,7 @@ __device__ void paged_attention_kernel_TC(
}
constexpr
int
NUM_ROWS_PER_THREAD
=
DIVIDE_ROUND_UP
(
HEAD_SIZE
,
WARP_SIZE
);
//2
if
constexpr
(
REUSE_KV_TIMES
<=
2
&&
(
NUM_WARPS
>
64
||
USE_PARTITIONING
)
){
if
constexpr
(
REUSE_KV_TIMES
<=
2
){
float
accs
[
REUSE_KV_TIMES
][
NUM_ROWS_PER_THREAD
];
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
...
...
@@ -727,6 +736,60 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt(
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
void
get_numberthread_and_reuse_kv_v1
(
int
&
num_thread
,
int
&
reusekv
,
int
batchsize
,
int
seq
,
int
qheads
,
int
kvheads
){
//mha
reusekv
=
1
;
if
(
qheads
==
kvheads
){
//llama 7B ,其他模型未可知
if
(
seq
<=
16
||
batchsize
>=
32
)
num_thread
=
64
;
else
if
(
batchsize
<=
2
)
num_thread
=
256
;
else
if
(
batchsize
<
8
)
num_thread
=
128
;
else
num_thread
=
64
;
return
;
}
// mqa
if
(
qheads
>
kvheads
*
4
){
if
(
seq
<
64
){
if
(
batchsize
<=
64
){
reusekv
=
1
;
num_thread
=
64
;}
else
if
(
batchsize
<
128
){
reusekv
=
2
;
num_thread
=
64
;}
else
{
reusekv
=
4
;
num_thread
=
64
;}
}
else
if
(
seq
<=
400
){
if
(
batchsize
<
16
){
reusekv
=
1
;
num_thread
=
256
;}
else
if
(
batchsize
<
64
){
reusekv
=
2
;
num_thread
=
256
;}
else
if
(
batchsize
<=
128
){
reusekv
=
4
;
if
(
qheads
%
7
==
0
)
num_thread
=
64
;
//qwen7b
else
num_thread
=
256
;
//llama70b
}
else
{
reusekv
=
8
;
num_thread
=
64
;}
}
else
if
(
seq
<=
1000
){
if
(
batchsize
<
16
){
reusekv
=
1
;
num_thread
=
256
;}
else
if
(
qheads
%
7
==
0
&&
batchsize
<=
128
){
//qwen7b
if
(
batchsize
<
64
){
reusekv
=
4
;
num_thread
=
256
;}
else
{
reusekv
=
4
;
num_thread
=
64
;}
}
else
if
(
batchsize
<=
64
){
reusekv
=
4
;
num_thread
=
256
;}
else
{
reusekv
=
8
;
num_thread
=
128
;}
}
else
if
(
seq
<
3900
)
{
reusekv
=
8
;
num_thread
=
256
;}
else
if
(
seq
<
7800
)
{
reusekv
=
4
;
num_thread
=
256
;}
else
{
reusekv
=
2
;
num_thread
=
256
;}
return
;
}
if
(
qheads
/
kvheads
>
4
&&
seq
<
3900
)
reusekv
=
8
;
else
if
(
qheads
/
kvheads
>
2
&&
seq
<
7800
)
reusekv
=
4
;
else
if
(
qheads
/
kvheads
>=
2
&&
seq
<
15600
)
reusekv
=
2
;
if
(
seq
<=
64
){
num_thread
=
64
;
if
(
batchsize
<=
64
)
reusekv
=
1
;
}
else
num_thread
=
256
;
}
// TODO(woosuk): Tune NUM_THREADS.
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
>
...
...
@@ -769,18 +832,15 @@ void paged_attention_v1_launcher_opt(
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
if
constexpr
(
BLOCK_SIZE
==
16
&&
IS_BLOCK_SPARSE
==
false
&&
sizeof
(
T
)
==
2
&&
KV_DTYPE
==
vllm
::
Fp8KVCacheDataType
::
kAuto
){
// if(head_size==128&&get_device_name()=="gfx928"){
REUSEKV_SWITCH_V1
([
&
]
{
constexpr
int
HEAD_SIZE
=
128
;
// constexpr int REUSE_KV_TIMES=8;
int
num_thread
=
64
;
if
(
REUSE_KV_TIMES
>
1
){
if
(
padded_max_seq_len
>
1024
||
num_heads
*
num_seqs
/
REUSE_KV_TIMES
<
600
)
num_thread
=
256
;
else
num_thread
=
128
;
}
else
if
(
num_heads
*
num_seqs
<
800
)
num_thread
=
128
;
NUM_THREADS_SWITCH
(
num_thread
,
[
&
]
{
constexpr
static
int
use_vmac
=
false
;
int
reusekv
,
num_thread
;
get_numberthread_and_reuse_kv_v1
(
num_thread
,
reusekv
,
num_seqs
,
padded_max_seq_len
,
num_heads
,
num_kv_heads
);
if
(
PA_REUSE_KV_TIMES
!=
0
&&
num_heads
>
num_kv_heads
)
reusekv
=
PA_REUSE_KV_TIMES
;
if
(
PA_BLOCK_SIZE
!=
0
)
num_thread
=
PA_BLOCK_SIZE
;
if
(
PA_PRINT_PARAM
)
printf
(
"reusekv=%d,num_thread=%d
\n
"
,
reusekv
,
num_thread
);
REUSEKV_SWITCH
(
reusekv
,[
&
]
{
NUM_THREADS_SWITCH
(
num_thread
,
[
&
]
{
//constexpr int NUM_THREADS = WARP_SIZE * REUSE_KV_TIMES;
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
logits_size
=
REUSE_KV_TIMES
*
padded_max_seq_len
*
2
;
...
...
@@ -799,7 +859,6 @@ void paged_attention_v1_launcher_opt(
});
});
}
// }
}
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
...
...
@@ -902,10 +961,43 @@ void paged_attention_v1_opt(
hipLaunchKernelGGL( \
(vllm::paged_attention_v2_reduce_kernel_opt<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE>), \
dim3(reduce_grid), dim3(
block
), reduce_shared_mem_size, stream, out_ptr, \
dim3(reduce_grid), dim3(
128
), reduce_shared_mem_size, stream, out_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions);
void
get_numberthread_and_reuse_kv_v2
(
int
&
num_thread
,
int
&
reusekv
,
int
batchsize
,
int
max_num_partitions
,
int
qheads
,
int
kvheads
){
reusekv
=
1
;
int
blocks
=
batchsize
*
qheads
*
max_num_partitions
;
if
(
qheads
==
kvheads
){
if
(
blocks
<=
80
||
blocks
>
8000
){
num_thread
=
256
;}
else
if
(
blocks
<=
160
){
num_thread
=
128
;}
else
num_thread
=
64
;
return
;
}
if
(
qheads
/
kvheads
>
8
&&
blocks
>
4000
){
reusekv
=
16
;
if
(
blocks
>
40000
)
num_thread
=
64
;
else
num_thread
=
128
;
}
else
if
(
qheads
/
kvheads
==
5
||
qheads
/
kvheads
==
7
){
if
(
blocks
<=
160
){
reusekv
=
1
;
num_thread
=
256
;}
else
if
(
blocks
<
640
/
5
*
qheads
/
kvheads
){
reusekv
=
4
;
num_thread
=
256
;}
else
if
(
blocks
<
1920
){
reusekv
=
8
;
num_thread
=
128
;}
else
{
reusekv
=
8
;
num_thread
=
64
;}
}
else
if
(
qheads
>
kvheads
*
4
){
if
(
blocks
<=
128
){
reusekv
=
1
;
num_thread
=
256
;}
else
if
(
blocks
<
1536
){
reusekv
=
4
;
num_thread
=
256
;}
else
if
(
blocks
<
6144
){
reusekv
=
8
;
num_thread
=
128
;}
else
{
reusekv
=
8
;
num_thread
=
64
;}
}
else
{
if
(
blocks
<=
128
){
reusekv
=
1
;
num_thread
=
256
;}
else
if
(
blocks
<
3000
){
reusekv
=
4
;
num_thread
=
256
;}
else
{
reusekv
=
4
;
num_thread
=
64
;}
}
}
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
int
PARTITION_SIZE
=
512
>
void
paged_attention_v2_launcher_opt
(
...
...
@@ -953,17 +1045,12 @@ void paged_attention_v2_launcher_opt(
//if(head_size==128&&get_device_name()=="gfx928"){
constexpr
int
HEAD_SIZE
=
128
;
constexpr
static
int
use_vmac
=
false
;
REUSEKV_SWITCH_V2
([
&
]
{
int
num_thread
;
if
(
REUSE_KV_TIMES
>
1
){
if
(
num_seqs
<
16
)
num_thread
=
256
;
else
if
(
max_num_partitions
*
num_seqs
*
num_heads
/
REUSE_KV_TIMES
>
4000
)
num_thread
=
64
;
else
num_thread
=
128
;
}
else
{
if
(
num_seqs
<
16
&&
max_num_partitions
<
10
)
num_thread
=
256
;
else
num_thread
=
64
;
}
int
reusekv
,
num_thread
;
get_numberthread_and_reuse_kv_v2
(
num_thread
,
reusekv
,
num_seqs
,
max_num_partitions
,
num_heads
,
num_kv_heads
);
if
(
PA_REUSE_KV_TIMES
!=
0
&&
num_heads
>
num_kv_heads
)
reusekv
=
PA_REUSE_KV_TIMES
;
if
(
PA_BLOCK_SIZE
!=
0
)
num_thread
=
PA_BLOCK_SIZE
;
if
(
PA_PRINT_PARAM
)
printf
(
"reusekv=%d,num_thread=%d
\n
"
,
reusekv
,
num_thread
);
REUSEKV_SWITCH
(
reusekv
,[
&
]
{
NUM_THREADS_SWITCH
(
num_thread
,
[
&
]
{
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
logits_size
=
REUSE_KV_TIMES
*
PARTITION_SIZE
*
2
;
...
...
csrc/attention/static_switch.h
View file @
2b91ac93
...
...
@@ -52,47 +52,18 @@
} \
}()
#define REUSEKV_SWITCH(
num_blocks ,
...) \
#define REUSEKV_SWITCH(
reusekv,
...) \
[&] { \
if (num_heads % 2 == 0 && num_heads / num_kv_heads >= 4 && num_blocks >= 1200){ \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \
} else if (num_heads / num_kv_heads >= 2 && num_blocks >= 1200){\
constexpr static int REUSE_KV_TIMES = 2; \
return __VA_ARGS__(); \
} else { \
constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \
} \
}()
#define REUSEKV_SWITCH_V2( ...) \
[&] { \
if (num_heads / num_kv_heads > 8 ){ \
if (reusekv==16){ \
constexpr static int REUSE_KV_TIMES = 16; \
return __VA_ARGS__(); \
}else if (num_heads / num_kv_heads > 4 ){ \
constexpr static int REUSE_KV_TIMES = 8; \
return __VA_ARGS__(); \
}else if (num_heads / num_kv_heads > 2 ){ \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \
} else { \
constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \
} \
}()
#define REUSEKV_SWITCH_V1( ...) \
[&] { \
if (num_heads/num_kv_heads >4 && padded_max_seq_len<3900){ \
return __VA_ARGS__();} \
else if (reusekv==8){ \
constexpr static int REUSE_KV_TIMES = 8; \
return __VA_ARGS__(); \
}else if (
num_heads/num_kv_heads >2 && padded_max_seq_len<7800
){ \
}else if (
reusekv==4
){ \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \
}else if (
num_heads/num_kv_heads ==2 && padded_max_seq_len<15600
){ \
}else if (
reusekv==2
){ \
constexpr static int REUSE_KV_TIMES = 2; \
return __VA_ARGS__(); \
}else { \
...
...
vllm/attention/ops/paged_attn.py
View file @
2b91ac93
...
...
@@ -127,7 +127,7 @@ class PagedAttention:
# use_v1 = (max_seq_len <= 8192
# and (max_num_partitions == 1 or num_seqs * num_heads > 512))
use_v1
=
(
max_seq_len
<
8192
and
(
max_seq_len
<
1000
or
num_seqs
*
num_heads
>
(
1024
if
num_kv_heads
<
num_heads
else
512
)))
and
(
max_seq_len
<
(
10
24
if
num_kv_heads
==
num_heads
else
6
00
)
or
num_seqs
*
num_heads
>
(
1024
if
num_kv_heads
<
num_heads
else
512
)))
if
use_v1
:
# Run PagedAttention V1.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment