Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2b91ac93
"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "9d72daf4ced05a5fec1ad8ea2914a39296f402da"
Commit
2b91ac93
authored
Sep 24, 2024
by
zhuwenwen
Browse files
优化pa小batch性能(pa_v2),优化pa小seq性能(pa_v1),reusekv=16优化
parent
de7d9456
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
140 additions
and
82 deletions
+140
-82
csrc/attention/attention_kernels_opt.cu
csrc/attention/attention_kernels_opt.cu
+133
-46
csrc/attention/static_switch.h
csrc/attention/static_switch.h
+6
-35
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+1
-1
No files found.
csrc/attention/attention_kernels_opt.cu
View file @
2b91ac93
...
@@ -43,7 +43,16 @@ inline std::string get_device_name()
...
@@ -43,7 +43,16 @@ inline std::string get_device_name()
const
std
::
string
raw_name
(
props
.
gcnArchName
);
const
std
::
string
raw_name
(
props
.
gcnArchName
);
return
raw_name
.
substr
(
0
,
raw_name
.
find
(
':'
));
// str.substr(0, npos) returns str.
return
raw_name
.
substr
(
0
,
raw_name
.
find
(
':'
));
// str.substr(0, npos) returns str.
}
}
static
inline
int
get_env_
(
const
char
*
env_var
)
{
if
(
char
*
value
=
std
::
getenv
(
env_var
))
{
return
atoi
(
value
);
}
return
0
;
}
static
const
int
PA_REUSE_KV_TIMES
=
get_env_
(
"PA_REUSE_KV_TIMES"
);
static
const
int
PA_BLOCK_SIZE
=
get_env_
(
"PA_BLOCK_SIZE"
);
static
const
int
PA_PRINT_PARAM
=
get_env_
(
"PA_PRINT_PARAM"
);
namespace
vllm
{
namespace
vllm
{
// Utility function for attention softmax.
// Utility function for attention softmax.
...
@@ -344,7 +353,7 @@ __device__ void paged_attention_kernel_TC(
...
@@ -344,7 +353,7 @@ __device__ void paged_attention_kernel_TC(
}
}
constexpr
int
NUM_ROWS_PER_THREAD
=
DIVIDE_ROUND_UP
(
HEAD_SIZE
,
WARP_SIZE
);
//2
constexpr
int
NUM_ROWS_PER_THREAD
=
DIVIDE_ROUND_UP
(
HEAD_SIZE
,
WARP_SIZE
);
//2
if
constexpr
(
REUSE_KV_TIMES
<=
2
&&
(
NUM_WARPS
>
64
||
USE_PARTITIONING
)
){
if
constexpr
(
REUSE_KV_TIMES
<=
2
){
float
accs
[
REUSE_KV_TIMES
][
NUM_ROWS_PER_THREAD
];
float
accs
[
REUSE_KV_TIMES
][
NUM_ROWS_PER_THREAD
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
...
@@ -723,10 +732,64 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt(
...
@@ -723,10 +732,64 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt(
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads,num_kv_heads, \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads,num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks,
\
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
blocksparse_head_sliding_step);
void
get_numberthread_and_reuse_kv_v1
(
int
&
num_thread
,
int
&
reusekv
,
int
batchsize
,
int
seq
,
int
qheads
,
int
kvheads
){
//mha
reusekv
=
1
;
if
(
qheads
==
kvheads
){
//llama 7B ,其他模型未可知
if
(
seq
<=
16
||
batchsize
>=
32
)
num_thread
=
64
;
else
if
(
batchsize
<=
2
)
num_thread
=
256
;
else
if
(
batchsize
<
8
)
num_thread
=
128
;
else
num_thread
=
64
;
return
;
}
// mqa
if
(
qheads
>
kvheads
*
4
){
if
(
seq
<
64
){
if
(
batchsize
<=
64
){
reusekv
=
1
;
num_thread
=
64
;}
else
if
(
batchsize
<
128
){
reusekv
=
2
;
num_thread
=
64
;}
else
{
reusekv
=
4
;
num_thread
=
64
;}
}
else
if
(
seq
<=
400
){
if
(
batchsize
<
16
){
reusekv
=
1
;
num_thread
=
256
;}
else
if
(
batchsize
<
64
){
reusekv
=
2
;
num_thread
=
256
;}
else
if
(
batchsize
<=
128
){
reusekv
=
4
;
if
(
qheads
%
7
==
0
)
num_thread
=
64
;
//qwen7b
else
num_thread
=
256
;
//llama70b
}
else
{
reusekv
=
8
;
num_thread
=
64
;}
}
else
if
(
seq
<=
1000
){
if
(
batchsize
<
16
){
reusekv
=
1
;
num_thread
=
256
;}
else
if
(
qheads
%
7
==
0
&&
batchsize
<=
128
){
//qwen7b
if
(
batchsize
<
64
){
reusekv
=
4
;
num_thread
=
256
;}
else
{
reusekv
=
4
;
num_thread
=
64
;}
}
else
if
(
batchsize
<=
64
){
reusekv
=
4
;
num_thread
=
256
;}
else
{
reusekv
=
8
;
num_thread
=
128
;}
}
else
if
(
seq
<
3900
)
{
reusekv
=
8
;
num_thread
=
256
;}
else
if
(
seq
<
7800
)
{
reusekv
=
4
;
num_thread
=
256
;}
else
{
reusekv
=
2
;
num_thread
=
256
;}
return
;
}
if
(
qheads
/
kvheads
>
4
&&
seq
<
3900
)
reusekv
=
8
;
else
if
(
qheads
/
kvheads
>
2
&&
seq
<
7800
)
reusekv
=
4
;
else
if
(
qheads
/
kvheads
>=
2
&&
seq
<
15600
)
reusekv
=
2
;
if
(
seq
<=
64
){
num_thread
=
64
;
if
(
batchsize
<=
64
)
reusekv
=
1
;
}
else
num_thread
=
256
;
}
// TODO(woosuk): Tune NUM_THREADS.
// TODO(woosuk): Tune NUM_THREADS.
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
>
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
>
...
@@ -750,7 +813,7 @@ void paged_attention_v1_launcher_opt(
...
@@ -750,7 +813,7 @@ void paged_attention_v1_launcher_opt(
if
(
num_heads
!=
num_kv_heads
)
{
if
(
num_heads
!=
num_kv_heads
)
{
num_threads
=
256
;
num_threads
=
256
;
}
}
[[
maybe_unused
]]
int
thread_group_size
=
MAX
(
WARP_SIZE
/
BLOCK_SIZE
,
1
);
[[
maybe_unused
]]
int
thread_group_size
=
MAX
(
WARP_SIZE
/
BLOCK_SIZE
,
1
);
assert
(
head_size
%
thread_group_size
==
0
);
assert
(
head_size
%
thread_group_size
==
0
);
// NOTE: alibi_slopes is optional.
// NOTE: alibi_slopes is optional.
...
@@ -769,44 +832,40 @@ void paged_attention_v1_launcher_opt(
...
@@ -769,44 +832,40 @@ void paged_attention_v1_launcher_opt(
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
if
constexpr
(
BLOCK_SIZE
==
16
&&
IS_BLOCK_SPARSE
==
false
&&
sizeof
(
T
)
==
2
&&
KV_DTYPE
==
vllm
::
Fp8KVCacheDataType
::
kAuto
){
if
constexpr
(
BLOCK_SIZE
==
16
&&
IS_BLOCK_SPARSE
==
false
&&
sizeof
(
T
)
==
2
&&
KV_DTYPE
==
vllm
::
Fp8KVCacheDataType
::
kAuto
){
// if(head_size==128&&get_device_name()=="gfx928"){
constexpr
int
HEAD_SIZE
=
128
;
REUSEKV_SWITCH_V1
([
&
]
{
constexpr
static
int
use_vmac
=
false
;
constexpr
int
HEAD_SIZE
=
128
;
int
reusekv
,
num_thread
;
// constexpr int REUSE_KV_TIMES=8;
get_numberthread_and_reuse_kv_v1
(
num_thread
,
reusekv
,
num_seqs
,
padded_max_seq_len
,
num_heads
,
num_kv_heads
);
int
num_thread
=
64
;
if
(
PA_REUSE_KV_TIMES
!=
0
&&
num_heads
>
num_kv_heads
)
reusekv
=
PA_REUSE_KV_TIMES
;
if
(
REUSE_KV_TIMES
>
1
){
if
(
PA_BLOCK_SIZE
!=
0
)
num_thread
=
PA_BLOCK_SIZE
;
if
(
padded_max_seq_len
>
1024
||
num_heads
*
num_seqs
/
REUSE_KV_TIMES
<
600
)
num_thread
=
256
;
if
(
PA_PRINT_PARAM
)
printf
(
"reusekv=%d,num_thread=%d
\n
"
,
reusekv
,
num_thread
);
else
num_thread
=
128
;
REUSEKV_SWITCH
(
reusekv
,[
&
]
{
}
NUM_THREADS_SWITCH
(
num_thread
,
[
&
]
{
else
if
(
num_heads
*
num_seqs
<
800
)
num_thread
=
128
;
//constexpr int NUM_THREADS = WARP_SIZE * REUSE_KV_TIMES;
NUM_THREADS_SWITCH
(
num_thread
,
[
&
]
{
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
constexpr
static
int
use_vmac
=
false
;
int
logits_size
=
REUSE_KV_TIMES
*
padded_max_seq_len
*
2
;
//constexpr int NUM_THREADS = WARP_SIZE * REUSE_KV_TIMES;
int
outputs_size
=
REUSE_KV_TIMES
*
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
if
(
REUSE_KV_TIMES
==
1
)
outputs_size
=
0
;
int
logits_size
=
REUSE_KV_TIMES
*
padded_max_seq_len
*
2
;
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
int
outputs_size
=
REUSE_KV_TIMES
*
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
// Keep that in sync with the logic here!
if
(
REUSE_KV_TIMES
==
1
)
outputs_size
=
0
;
int
shared_mem_size
=
::
max
(
logits_size
,
outputs_size
);
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
if
(
num_heads
==
num_kv_heads
)
shared_mem_size
=
::
max
(
12
*
1024
,
shared_mem_size
);
// Keep that in sync with the logic here!
// int shared_mem_size = ::max(31*1024, ::max(logits_size, outputs_size));
int
shared_mem_size
=
::
max
(
logits_size
,
outputs_size
);
// std::cout<<"shared_mem_size = "<<shared_mem_size<<std::endl;
if
(
num_heads
==
num_kv_heads
)
shared_mem_size
=
::
max
(
12
*
1024
,
shared_mem_size
);
// printf("REUSE_KV_TIMES=%d,use_vmac=%d\n",REUSE_KV_TIMES,(int)use_vmac);
// int shared_mem_size = ::max(31*1024, ::max(logits_size, outputs_size));
dim3
grid
((
num_heads
/
num_kv_heads
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
*
num_kv_heads
,
1
,
num_seqs
);
// std::cout<<"shared_mem_size = "<<shared_mem_size<<std::endl;
dim3
block
(
NUM_THREADS
);
// printf("REUSE_KV_TIMES=%d,use_vmac=%d\n",REUSE_KV_TIMES,(int)use_vmac);
LAUNCH_PAGED_ATTENTION_V1_TC
(
HEAD_SIZE
);
dim3
grid
((
num_heads
/
num_kv_heads
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
*
num_kv_heads
,
1
,
num_seqs
);
dim3
block
(
NUM_THREADS
);
LAUNCH_PAGED_ATTENTION_V1_TC
(
HEAD_SIZE
);
});
});
});
});
}
}
// }
}
}
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v1_launcher_opt<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
paged_attention_v1_launcher_opt<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \
IS_BLOCK_SPARSE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank,
\
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step);
blocksparse_block_size, blocksparse_head_sliding_step);
...
@@ -902,10 +961,43 @@ void paged_attention_v1_opt(
...
@@ -902,10 +961,43 @@ void paged_attention_v1_opt(
hipLaunchKernelGGL( \
hipLaunchKernelGGL( \
(vllm::paged_attention_v2_reduce_kernel_opt<T, HEAD_SIZE, NUM_THREADS, \
(vllm::paged_attention_v2_reduce_kernel_opt<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE>), \
PARTITION_SIZE>), \
dim3(reduce_grid), dim3(
block
), reduce_shared_mem_size, stream, out_ptr, \
dim3(reduce_grid), dim3(
128
), reduce_shared_mem_size, stream, out_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions);
max_num_partitions);
void
get_numberthread_and_reuse_kv_v2
(
int
&
num_thread
,
int
&
reusekv
,
int
batchsize
,
int
max_num_partitions
,
int
qheads
,
int
kvheads
){
reusekv
=
1
;
int
blocks
=
batchsize
*
qheads
*
max_num_partitions
;
if
(
qheads
==
kvheads
){
if
(
blocks
<=
80
||
blocks
>
8000
){
num_thread
=
256
;}
else
if
(
blocks
<=
160
){
num_thread
=
128
;}
else
num_thread
=
64
;
return
;
}
if
(
qheads
/
kvheads
>
8
&&
blocks
>
4000
){
reusekv
=
16
;
if
(
blocks
>
40000
)
num_thread
=
64
;
else
num_thread
=
128
;
}
else
if
(
qheads
/
kvheads
==
5
||
qheads
/
kvheads
==
7
){
if
(
blocks
<=
160
){
reusekv
=
1
;
num_thread
=
256
;}
else
if
(
blocks
<
640
/
5
*
qheads
/
kvheads
){
reusekv
=
4
;
num_thread
=
256
;}
else
if
(
blocks
<
1920
){
reusekv
=
8
;
num_thread
=
128
;}
else
{
reusekv
=
8
;
num_thread
=
64
;}
}
else
if
(
qheads
>
kvheads
*
4
){
if
(
blocks
<=
128
){
reusekv
=
1
;
num_thread
=
256
;}
else
if
(
blocks
<
1536
){
reusekv
=
4
;
num_thread
=
256
;}
else
if
(
blocks
<
6144
){
reusekv
=
8
;
num_thread
=
128
;}
else
{
reusekv
=
8
;
num_thread
=
64
;}
}
else
{
if
(
blocks
<=
128
){
reusekv
=
1
;
num_thread
=
256
;}
else
if
(
blocks
<
3000
){
reusekv
=
4
;
num_thread
=
256
;}
else
{
reusekv
=
4
;
num_thread
=
64
;}
}
}
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
int
PARTITION_SIZE
=
512
>
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
int
PARTITION_SIZE
=
512
>
void
paged_attention_v2_launcher_opt
(
void
paged_attention_v2_launcher_opt
(
...
@@ -953,17 +1045,12 @@ void paged_attention_v2_launcher_opt(
...
@@ -953,17 +1045,12 @@ void paged_attention_v2_launcher_opt(
//if(head_size==128&&get_device_name()=="gfx928"){
//if(head_size==128&&get_device_name()=="gfx928"){
constexpr
int
HEAD_SIZE
=
128
;
constexpr
int
HEAD_SIZE
=
128
;
constexpr
static
int
use_vmac
=
false
;
constexpr
static
int
use_vmac
=
false
;
REUSEKV_SWITCH_V2
([
&
]
{
int
reusekv
,
num_thread
;
int
num_thread
;
get_numberthread_and_reuse_kv_v2
(
num_thread
,
reusekv
,
num_seqs
,
max_num_partitions
,
num_heads
,
num_kv_heads
);
if
(
REUSE_KV_TIMES
>
1
){
if
(
PA_REUSE_KV_TIMES
!=
0
&&
num_heads
>
num_kv_heads
)
reusekv
=
PA_REUSE_KV_TIMES
;
if
(
num_seqs
<
16
)
num_thread
=
256
;
if
(
PA_BLOCK_SIZE
!=
0
)
num_thread
=
PA_BLOCK_SIZE
;
else
if
(
max_num_partitions
*
num_seqs
*
num_heads
/
REUSE_KV_TIMES
>
4000
)
num_thread
=
64
;
if
(
PA_PRINT_PARAM
)
printf
(
"reusekv=%d,num_thread=%d
\n
"
,
reusekv
,
num_thread
);
else
num_thread
=
128
;
REUSEKV_SWITCH
(
reusekv
,[
&
]
{
}
else
{
if
(
num_seqs
<
16
&&
max_num_partitions
<
10
)
num_thread
=
256
;
else
num_thread
=
64
;
}
NUM_THREADS_SWITCH
(
num_thread
,
[
&
]
{
NUM_THREADS_SWITCH
(
num_thread
,
[
&
]
{
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
logits_size
=
REUSE_KV_TIMES
*
PARTITION_SIZE
*
2
;
int
logits_size
=
REUSE_KV_TIMES
*
PARTITION_SIZE
*
2
;
...
@@ -982,7 +1069,7 @@ void paged_attention_v2_launcher_opt(
...
@@ -982,7 +1069,7 @@ void paged_attention_v2_launcher_opt(
}
}
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v2_launcher_opt<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
paged_attention_v2_launcher_opt<T, CACHE_T, BLOCK_SIZE, KV_DTYPE,
\
IS_BLOCK_SPARSE>( \
IS_BLOCK_SPARSE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
...
...
csrc/attention/static_switch.h
View file @
2b91ac93
...
@@ -52,47 +52,18 @@
...
@@ -52,47 +52,18 @@
} \
} \
}()
}()
#define REUSEKV_SWITCH(
num_blocks ,
...) \
#define REUSEKV_SWITCH(
reusekv,
...) \
[&] { \
[&] { \
if (num_heads % 2 == 0 && num_heads / num_kv_heads >= 4 && num_blocks >= 1200){ \
if (reusekv==16){ \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \
} else if (num_heads / num_kv_heads >= 2 && num_blocks >= 1200){\
constexpr static int REUSE_KV_TIMES = 2; \
return __VA_ARGS__(); \
} else { \
constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \
} \
}()
#define REUSEKV_SWITCH_V2( ...) \
[&] { \
if (num_heads / num_kv_heads > 8 ){ \
constexpr static int REUSE_KV_TIMES = 16; \
constexpr static int REUSE_KV_TIMES = 16; \
return __VA_ARGS__(); \
return __VA_ARGS__();} \
}else if (num_heads / num_kv_heads > 4 ){ \
else if (reusekv==8){ \
constexpr static int REUSE_KV_TIMES = 8; \
return __VA_ARGS__(); \
}else if (num_heads / num_kv_heads > 2 ){ \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \
} else { \
constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \
} \
}()
#define REUSEKV_SWITCH_V1( ...) \
[&] { \
if (num_heads/num_kv_heads >4 && padded_max_seq_len<3900){ \
constexpr static int REUSE_KV_TIMES = 8; \
constexpr static int REUSE_KV_TIMES = 8; \
return __VA_ARGS__(); \
return __VA_ARGS__(); \
}else if (
num_heads/num_kv_heads >2 && padded_max_seq_len<7800
){ \
}else if (
reusekv==4
){ \
constexpr static int REUSE_KV_TIMES = 4; \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \
return __VA_ARGS__(); \
}else if (
num_heads/num_kv_heads ==2 && padded_max_seq_len<15600
){ \
}else if (
reusekv==2
){ \
constexpr static int REUSE_KV_TIMES = 2; \
constexpr static int REUSE_KV_TIMES = 2; \
return __VA_ARGS__(); \
return __VA_ARGS__(); \
}else { \
}else { \
...
...
vllm/attention/ops/paged_attn.py
View file @
2b91ac93
...
@@ -127,7 +127,7 @@ class PagedAttention:
...
@@ -127,7 +127,7 @@ class PagedAttention:
# use_v1 = (max_seq_len <= 8192
# use_v1 = (max_seq_len <= 8192
# and (max_num_partitions == 1 or num_seqs * num_heads > 512))
# and (max_num_partitions == 1 or num_seqs * num_heads > 512))
use_v1
=
(
max_seq_len
<
8192
use_v1
=
(
max_seq_len
<
8192
and
(
max_seq_len
<
1000
or
num_seqs
*
num_heads
>
(
1024
if
num_kv_heads
<
num_heads
else
512
)))
and
(
max_seq_len
<
(
10
24
if
num_kv_heads
==
num_heads
else
6
00
)
or
num_seqs
*
num_heads
>
(
1024
if
num_kv_heads
<
num_heads
else
512
)))
if
use_v1
:
if
use_v1
:
# Run PagedAttention V1.
# Run PagedAttention V1.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment