Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e2bd7e16
Commit
e2bd7e16
authored
Oct 18, 2024
by
zhangshao
Browse files
解决pa v2 bug
parent
304e2bab
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
7 deletions
+9
-7
csrc/attention/attention_kernels_opt_tc.cu
csrc/attention/attention_kernels_opt_tc.cu
+9
-7
No files found.
csrc/attention/attention_kernels_opt_tc.cu
View file @
e2bd7e16
...
@@ -838,7 +838,6 @@ void paged_attention_v1_launcher_opt_tc(
...
@@ -838,7 +838,6 @@ void paged_attention_v1_launcher_opt_tc(
get_numberthread_and_reuse_kv_v1
(
num_thread
,
reusekv
,
num_seqs
,
padded_max_seq_len
,
num_heads
,
num_kv_heads
);
get_numberthread_and_reuse_kv_v1
(
num_thread
,
reusekv
,
num_seqs
,
padded_max_seq_len
,
num_heads
,
num_kv_heads
);
if
(
PA_REUSE_KV_TIMES
!=
0
&&
num_heads
>
num_kv_heads
)
reusekv
=
PA_REUSE_KV_TIMES
;
if
(
PA_REUSE_KV_TIMES
!=
0
&&
num_heads
>
num_kv_heads
)
reusekv
=
PA_REUSE_KV_TIMES
;
if
(
PA_BLOCK_SIZE
!=
0
)
num_thread
=
PA_BLOCK_SIZE
;
if
(
PA_BLOCK_SIZE
!=
0
)
num_thread
=
PA_BLOCK_SIZE
;
if
(
PA_PRINT_PARAM
)
printf
(
"reusekv=%d,num_thread=%d
\n
"
,
reusekv
,
num_thread
);
REUSEKV_SWITCH
(
reusekv
,[
&
]
{
REUSEKV_SWITCH
(
reusekv
,[
&
]
{
NUM_THREADS_SWITCH
(
num_thread
,
[
&
]
{
NUM_THREADS_SWITCH
(
num_thread
,
[
&
]
{
//constexpr int NUM_THREADS = WARP_SIZE * REUSE_KV_TIMES;
//constexpr int NUM_THREADS = WARP_SIZE * REUSE_KV_TIMES;
...
@@ -855,6 +854,8 @@ void paged_attention_v1_launcher_opt_tc(
...
@@ -855,6 +854,8 @@ void paged_attention_v1_launcher_opt_tc(
// printf("REUSE_KV_TIMES=%d,use_vmac=%d\n",REUSE_KV_TIMES,(int)use_vmac);
// printf("REUSE_KV_TIMES=%d,use_vmac=%d\n",REUSE_KV_TIMES,(int)use_vmac);
dim3
grid
((
num_heads
/
num_kv_heads
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
*
num_kv_heads
,
1
,
num_seqs
);
dim3
grid
((
num_heads
/
num_kv_heads
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
*
num_kv_heads
,
1
,
num_seqs
);
dim3
block
(
NUM_THREADS
);
dim3
block
(
NUM_THREADS
);
if
(
PA_PRINT_PARAM
)
printf
(
"reusekv=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d
\n
"
,
reusekv
,
num_thread
,
grid
.
x
,
grid
.
y
,
grid
.
z
,
num_heads
,
num_kv_heads
,
max_seq_len
,
num_seqs
);
LAUNCH_PAGED_ATTENTION_V1_TC
(
HEAD_SIZE
);
LAUNCH_PAGED_ATTENTION_V1_TC
(
HEAD_SIZE
);
});
});
});
});
...
@@ -897,7 +898,7 @@ void paged_attention_v1_launcher_opt_tc(
...
@@ -897,7 +898,7 @@ void paged_attention_v1_launcher_opt_tc(
break; \
break; \
}
}
void
paged_attention_v1
(
void
paged_attention_v1
_opt
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
query
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
query
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
torch
::
Tensor
&
...
@@ -935,7 +936,7 @@ void paged_attention_v1_opt_tc(
...
@@ -935,7 +936,7 @@ void paged_attention_v1_opt_tc(
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
if
(
kv_cache_dtype
!=
"auto"
||
query
.
dtype
()
==
at
::
ScalarType
::
Float
||
is_block_sparse
||
if
(
kv_cache_dtype
!=
"auto"
||
query
.
dtype
()
==
at
::
ScalarType
::
Float
||
is_block_sparse
||
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
get_device_name
()
!=
"gfx928"
){
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
get_device_name
()
!=
"gfx928"
){
paged_attention_v1
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
paged_attention_v1
_opt
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
blocksparse_block_size
,
blocksparse_head_sliding_step
);
...
@@ -961,7 +962,7 @@ void paged_attention_v1_opt_tc(
...
@@ -961,7 +962,7 @@ void paged_attention_v1_opt_tc(
hipLaunchKernelGGL( \
hipLaunchKernelGGL( \
(vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, NUM_THREADS, \
(vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE>), \
PARTITION_SIZE>), \
dim3(reduce_grid), dim3(
128
), reduce_shared_mem_size, stream, out_ptr, \
dim3(reduce_grid), dim3(
block
), reduce_shared_mem_size, stream, out_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions);
max_num_partitions);
...
@@ -1049,7 +1050,6 @@ void paged_attention_v2_launcher_opt_tc(
...
@@ -1049,7 +1050,6 @@ void paged_attention_v2_launcher_opt_tc(
get_numberthread_and_reuse_kv_v2
(
num_thread
,
reusekv
,
num_seqs
,
max_num_partitions
,
num_heads
,
num_kv_heads
);
get_numberthread_and_reuse_kv_v2
(
num_thread
,
reusekv
,
num_seqs
,
max_num_partitions
,
num_heads
,
num_kv_heads
);
if
(
PA_REUSE_KV_TIMES
!=
0
&&
num_heads
>
num_kv_heads
)
reusekv
=
PA_REUSE_KV_TIMES
;
if
(
PA_REUSE_KV_TIMES
!=
0
&&
num_heads
>
num_kv_heads
)
reusekv
=
PA_REUSE_KV_TIMES
;
if
(
PA_BLOCK_SIZE
!=
0
)
num_thread
=
PA_BLOCK_SIZE
;
if
(
PA_BLOCK_SIZE
!=
0
)
num_thread
=
PA_BLOCK_SIZE
;
if
(
PA_PRINT_PARAM
)
printf
(
"reusekv=%d,num_thread=%d
\n
"
,
reusekv
,
num_thread
);
REUSEKV_SWITCH
(
reusekv
,[
&
]
{
REUSEKV_SWITCH
(
reusekv
,[
&
]
{
NUM_THREADS_SWITCH
(
num_thread
,
[
&
]
{
NUM_THREADS_SWITCH
(
num_thread
,
[
&
]
{
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
...
@@ -1061,6 +1061,8 @@ void paged_attention_v2_launcher_opt_tc(
...
@@ -1061,6 +1061,8 @@ void paged_attention_v2_launcher_opt_tc(
grid
.
z
=
num_seqs
;
grid
.
z
=
num_seqs
;
dim3
block
(
NUM_THREADS
);
dim3
block
(
NUM_THREADS
);
int
shared_mem_size
=
::
max
(
logits_size
,
outputs_size
);
int
shared_mem_size
=
::
max
(
logits_size
,
outputs_size
);
if
(
PA_PRINT_PARAM
)
printf
(
"reusekv=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d
\n
"
,
reusekv
,
num_thread
,
grid
.
x
,
grid
.
y
,
grid
.
z
,
num_heads
,
num_kv_heads
,
max_seq_len
,
num_seqs
);
LAUNCH_PAGED_ATTENTION_V2_TC
(
HEAD_SIZE
);
LAUNCH_PAGED_ATTENTION_V2_TC
(
HEAD_SIZE
);
});
});
});
});
...
@@ -1105,7 +1107,7 @@ void paged_attention_v2_launcher_opt_tc(
...
@@ -1105,7 +1107,7 @@ void paged_attention_v2_launcher_opt_tc(
break; \
break; \
}
}
void
paged_attention_v2
(
void
paged_attention_v2
_opt
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
torch
::
Tensor
&
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
torch
::
Tensor
&
max_logits
,
// [num_seqs, num_heads, max_num_partitions]
torch
::
Tensor
&
max_logits
,
// [num_seqs, num_heads, max_num_partitions]
...
@@ -1151,7 +1153,7 @@ void paged_attention_v2_opt_tc(
...
@@ -1151,7 +1153,7 @@ void paged_attention_v2_opt_tc(
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
if
(
kv_cache_dtype
!=
"auto"
||
query
.
dtype
()
==
at
::
ScalarType
::
Float
||
is_block_sparse
||
if
(
kv_cache_dtype
!=
"auto"
||
query
.
dtype
()
==
at
::
ScalarType
::
Float
||
is_block_sparse
||
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
get_device_name
()
!=
"gfx928"
){
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
get_device_name
()
!=
"gfx928"
){
paged_attention_v2
(
out
,
exp_sums
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
paged_attention_v2
_opt
(
out
,
exp_sums
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
blocksparse_block_size
,
blocksparse_head_sliding_step
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment