Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
228a714a
Commit
228a714a
authored
Mar 03, 2025
by
zhangshao
Browse files
解决PA部分size计算错误的问题
parent
40083064
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
562 additions
and
891 deletions
+562
-891
csrc/attention/attention_kernels_opt_tc.cu
csrc/attention/attention_kernels_opt_tc.cu
+401
-589
csrc/attention/attention_with_mask_kernels_opt_tc.cu
csrc/attention/attention_with_mask_kernels_opt_tc.cu
+7
-97
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+154
-205
No files found.
csrc/attention/attention_kernels_opt_tc.cu
View file @
228a714a
This diff is collapsed.
Click to expand it.
csrc/attention/attention_with_mask_kernels_opt_tc.cu
View file @
228a714a
...
@@ -316,13 +316,12 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -316,13 +316,12 @@ __global__ void paged_attention_kernel_TC_with_mask(
}
}
__syncthreads
();
__syncthreads
();
constexpr
int
NUM_ROWS_PER_THREAD
=
DIVIDE_ROUND_UP
(
HEAD_SIZE
,
WARP_SIZE
);
//2
constexpr
int
NUM_ROWS_PER_THREAD
=
DIVIDE_ROUND_UP
(
HEAD_SIZE
,
WARP_SIZE
);
//2
if
(
q_boundary
<=
2
){
if
constexpr
(
REUSE_KV_TIMES
<=
2
){
constexpr
int
acc_size
=
REUSE_KV_TIMES
==
1
?
1
:
2
;
float
accs
[
REUSE_KV_TIMES
][
NUM_ROWS_PER_THREAD
];
float
accs
[
acc_size
][
NUM_ROWS_PER_THREAD
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
acc_size
;
k
++
)
for
(
int
k
=
0
;
k
<
REUSE_KV_TIMES
;
k
++
)
{
{
accs
[
k
][
i
]
=
0.
f
;
accs
[
k
][
i
]
=
0.
f
;
}
}
...
@@ -356,7 +355,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -356,7 +355,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
float4_t
out_vec
=
{
0
,
0
,
0
,
0
};
float4_t
out_vec
=
{
0
,
0
,
0
,
0
};
builtin_amdgcn_mmac
<
is_half
>
(
v_vec
,
logits_vec
,
out_vec
);
builtin_amdgcn_mmac
<
is_half
>
(
v_vec
,
logits_vec
,
out_vec
);
if
(
rows
==
k
){
if
(
rows
==
k
){
for
(
int
resuseid
=
0
;
resuseid
<
acc_size
;
resuseid
++
){
for
(
int
resuseid
=
0
;
resuseid
<
REUSE_KV_TIMES
;
resuseid
++
){
accs
[
resuseid
][
i
]
+=
out_vec
[
resuseid
];
accs
[
resuseid
][
i
]
+=
out_vec
[
resuseid
];
}
}
}
}
...
@@ -366,8 +365,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -366,8 +365,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
__syncthreads
();
__syncthreads
();
using
floatV_t
=
__attribute__
(
(
__vector_size__
(
NUM_ROWS_PER_THREAD
*
sizeof
(
float
))
))
float
;
using
floatV_t
=
__attribute__
(
(
__vector_size__
(
NUM_ROWS_PER_THREAD
*
sizeof
(
float
))
))
float
;
// Perform reduction across warps.
// Perform reduction across warps.
#pragma unroll
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
q_boundary
;
reuse_kv_idx
++
)
{
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
acc_size
;
reuse_kv_idx
++
)
{
if
constexpr
(
NUM_THREADS
>
64
){
if
constexpr
(
NUM_THREADS
>
64
){
floatV_t
*
out_smem
=
reinterpret_cast
<
floatV_t
*>
(
shared_mem
);
floatV_t
*
out_smem
=
reinterpret_cast
<
floatV_t
*>
(
shared_mem
);
#pragma unroll
#pragma unroll
...
@@ -780,97 +778,9 @@ __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_v2_reduce_kern
...
@@ -780,97 +778,9 @@ __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_v2_reduce_kern
max_num_partitions,PARTITION_SIZE);}
max_num_partitions,PARTITION_SIZE);}
static
void
get_numberthread_and_reuse_kv_v2
(
int
&
num_thread
,
int
&
reusekv
,
int
&
PARTITION_SIZE
,
int
&
max_num_partitions
,
void
get_numberthread_and_reuse_kv_v2
(
int
&
num_thread
,
int
&
reusekv
,
int
&
PARTITION_SIZE
,
int
&
max_num_partitions
,
int
batchsize
,
int
max_seq_len
,
int
qheads
,
int
kvheads
,
int
num_blocks
)
int
batchsize
,
int
max_seq_len
,
int
qheads
,
int
kvheads
,
int
num_blocks
);
{
reusekv
=
1
;
num_thread
=
256
;
PARTITION_SIZE
=
512
;
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
if
(
max_seq_len
==
8192
&&
num_blocks
==
1024
){
//ali test
if
(
batchsize
==
1
&&
qheads
==
16
&&
kvheads
==
16
){
num_thread
=
128
;
return
;}
if
(
batchsize
==
1
&&
qheads
==
32
&&
kvheads
==
32
){
num_thread
=
64
;
return
;}
if
(
batchsize
==
1
){
if
(
qheads
==
52
){
reusekv
=
8
;
return
;}
if
(
qheads
==
13
){
reusekv
=
2
;
return
;}
reusekv
=
4
;
return
;
}
if
(
batchsize
==
64
){
if
(
qheads
==
13
){
PARTITION_SIZE
=
256
;
num_thread
=
128
;
reusekv
=
8
;}
else
if
(
qheads
==
32
){
PARTITION_SIZE
=
1024
;
reusekv
=
8
;}
else
if
(
qheads
==
52
||
qheads
==
26
){
reusekv
=
16
;}
else
reusekv
=
8
;
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
return
;
}
}
if
(
qheads
==
kvheads
){
if
(
max_seq_len
<=
8192
){
if
(
batchsize
*
qheads
>=
512
){
max_num_partitions
=
1
;
num_thread
=
64
;
}
if
(
qheads
==
32
&&
max_seq_len
<=
1024
)
max_num_partitions
=
1
;
}
return
;
}
if
(
max_seq_len
<
800
)
max_num_partitions
=
1
;
if
(
qheads
>
kvheads
*
4
){
if
(
max_seq_len
<=
1000
||
max_seq_len
<
1500
&&
(
batchsize
>=
8
&&
qheads
>=
8
||
batchsize
>=
64
)
||
max_seq_len
<
1900
&&
batchsize
>=
8
&&
qheads
==
28
)
max_num_partitions
=
1
;
int
blocks
=
max_num_partitions
*
batchsize
*
qheads
;
if
(
device_name
==
"gfx928"
){
if
(
batchsize
*
qheads
>
1024
&&
max_seq_len
>=
2000
){
max_num_partitions
=
1
;
if
(
max_seq_len
<
3900
)
reusekv
=
8
;
else
if
(
max_seq_len
<
7800
)
reusekv
=
4
;
else
{
PARTITION_SIZE
=
2048
;
reusekv
=
8
;
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
}
return
;
}
}
if
(
max_num_partitions
==
1
){
if
(
max_seq_len
<
512
){
int
bytes
=
max_seq_len
*
qheads
*
batchsize
;
if
(
bytes
<
51200
)
reusekv
=
1
;
else
if
(
bytes
<
256000
)
reusekv
=
4
;
else
reusekv
=
8
;
return
;
}
if
(
batchsize
<
4
||
batchsize
==
4
&&
qheads
==
8
)
reusekv
=
1
;
else
if
(
batchsize
<
32
||
batchsize
<=
64
&&
qheads
==
8
)
reusekv
=
4
;
else
reusekv
=
8
;
return
;
}
if
(
blocks
<
150
)
return
;
if
(
blocks
<
600
||
qheads
<=
kvheads
*
4
){
reusekv
=
4
;
return
;}
reusekv
=
8
;
return
;
}
if
(
device_name
==
"gfx928"
){
if
(
batchsize
*
qheads
>
1024
&&
max_seq_len
>=
2000
){
max_num_partitions
=
1
;
if
(
max_seq_len
<
7800
)
reusekv
=
4
;
else
{
PARTITION_SIZE
=
2048
;
reusekv
=
4
;
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
}
return
;
}
}
if
(
max_seq_len
<=
1000
||
max_seq_len
<=
1500
&&
(
qheads
>
4
&&
batchsize
>=
16
||
batchsize
>=
64
))
max_num_partitions
=
1
;
int
blocks
=
max_num_partitions
*
batchsize
*
qheads
;
if
(
blocks
>=
150
||
batchsize
>=
16
||
qheads
>=
8
&&
(
batchsize
>=
4
||
max_seq_len
>=
2000
))
reusekv
=
4
;
}
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
>
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
>
void
paged_attention_v2_launcher_opt_tc_with_mask
(
void
paged_attention_v2_launcher_opt_tc_with_mask
(
...
...
vllm/attention/ops/paged_attn.py
View file @
228a714a
...
@@ -14,7 +14,8 @@ if HAS_TRITON:
...
@@ -14,7 +14,8 @@ if HAS_TRITON:
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE
=
512
_PARTITION_SIZE
=
512
gpuname
=
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
name
support_tc
=
gpuname
.
startswith
(
'K100_AI'
)
or
gpuname
.
startswith
(
'BW'
)
@
dataclass
@
dataclass
class
PagedAttentionMetadata
:
class
PagedAttentionMetadata
:
...
@@ -128,12 +129,62 @@ class PagedAttention:
...
@@ -128,12 +129,62 @@ class PagedAttention:
# to parallelize.
# to parallelize.
# TODO(woosuk): Tune this heuristic.
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
if
envs
.
VLLM_USE_TC_PAGED_ATTN
:
use_v1
=
(
max_seq_len
<
8192
and
(
max_seq_len
<
(
1024
if
num_kv_heads
==
num_heads
else
600
)
or
num_seqs
*
num_heads
>
(
1024
if
num_kv_heads
<
num_heads
else
512
)))
if
envs
.
VLLM_USE_TC_PAGED_ATTN
and
support_tc
:
else
:
if
envs
.
VLLM_USE_PA_PRINT_PARAM
:
use_v1
=
(
max_seq_len
<=
8192
print
(
"PA V1 SIZE:"
)
and
(
max_num_partitions
==
1
or
num_seqs
*
num_heads
>
512
))
print
(
f
"query.shape =
{
query
.
shape
}
, key_cache.shape =
{
key_cache
.
shape
}
, value_cache.shape =
{
value_cache
.
shape
}
"
)
print
(
f
"num_kv_heads =
{
num_kv_heads
}
, scale =
{
scale
:.
3
f
}
, block_tables.shape =
{
block_tables
.
shape
}
, seq_lens.shape =
{
seq_lens
.
shape
}
, block_size =
{
block_size
}
, max_seq_len =
{
max_seq_len
}
"
)
if
attn_masks
is
None
:
ops
.
paged_attention_v1_opt_tc
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
else
:
ops
.
paged_attention_v1_opt_tc_with_mask
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
return
output
use_v1
=
(
max_seq_len
<=
8192
and
(
max_num_partitions
==
1
or
num_seqs
*
num_heads
>
512
))
if
use_v1
:
if
use_v1
:
# Run PagedAttention V1.
# Run PagedAttention V1.
...
@@ -143,100 +194,52 @@ class PagedAttention:
...
@@ -143,100 +194,52 @@ class PagedAttention:
print
(
f
"num_kv_heads =
{
num_kv_heads
}
, scale =
{
scale
:.
3
f
}
, block_tables.shape =
{
block_tables
.
shape
}
, seq_lens.shape =
{
seq_lens
.
shape
}
, block_size =
{
block_size
}
, max_seq_len =
{
max_seq_len
}
"
)
print
(
f
"num_kv_heads =
{
num_kv_heads
}
, scale =
{
scale
:.
3
f
}
, block_tables.shape =
{
block_tables
.
shape
}
, seq_lens.shape =
{
seq_lens
.
shape
}
, block_size =
{
block_size
}
, max_seq_len =
{
max_seq_len
}
"
)
if
envs
.
VLLM_USE_OPT_OP
:
if
envs
.
VLLM_USE_OPT_OP
:
if
envs
.
VLLM_USE_TC_PAGED_ATTN
:
if
attn_masks
is
None
:
if
attn_masks
is
None
:
ops
.
paged_attention_v1_opt
(
ops
.
paged_attention_v1_opt_tc
(
output
,
output
,
query
,
query
,
key_cache
,
key_cache
,
value_cache
,
value_cache
,
num_kv_heads
,
num_kv_heads
,
scale
,
scale
,
block_tables
,
block_tables
,
seq_lens
,
seq_lens
,
block_size
,
block_size
,
max_seq_len
,
max_seq_len
,
alibi_slopes
,
alibi_slopes
,
kv_cache_dtype
,
kv_cache_dtype
,
k_scale
,
k_scale
,
v_scale
,
v_scale
,
tp_rank
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_block_size
,
blocksparse_head_sliding_step
blocksparse_head_sliding_step
)
)
else
:
ops
.
paged_attention_v1_opt_tc_with_mask
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
else
:
else
:
if
attn_masks
is
None
:
ops
.
paged_attention_v1_opt_with_mask
(
ops
.
paged_attention_v1_opt
(
output
,
output
,
query
,
query
,
key_cache
,
key_cache
,
value_cache
,
value_cache
,
num_kv_heads
,
num_kv_heads
,
scale
,
scale
,
block_tables
,
block_tables
,
seq_lens
,
seq_lens
,
block_size
,
block_size
,
max_seq_len
,
max_seq_len
,
alibi_slopes
,
alibi_slopes
,
kv_cache_dtype
,
kv_cache_dtype
,
k_scale
,
k_scale
,
v_scale
,
v_scale
,
tp_rank
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
blocksparse_head_sliding_step
attn_masks
,
)
attn_masks_stride
else
:
)
ops
.
paged_attention_v1_opt_with_mask
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
else
:
else
:
if
attn_masks
is
None
:
if
attn_masks
is
None
:
ops
.
paged_attention_v1
(
ops
.
paged_attention_v1
(
...
@@ -306,112 +309,58 @@ class PagedAttention:
...
@@ -306,112 +309,58 @@ class PagedAttention:
print
(
f
"num_kv_heads =
{
num_kv_heads
}
, scale =
{
scale
:.
3
f
}
, block_tables.shape =
{
block_tables
.
shape
}
, seq_lens.shape =
{
seq_lens
.
shape
}
, block_size =
{
block_size
}
, max_seq_len =
{
max_seq_len
}
"
)
print
(
f
"num_kv_heads =
{
num_kv_heads
}
, scale =
{
scale
:.
3
f
}
, block_tables.shape =
{
block_tables
.
shape
}
, seq_lens.shape =
{
seq_lens
.
shape
}
, block_size =
{
block_size
}
, max_seq_len =
{
max_seq_len
}
"
)
if
envs
.
VLLM_USE_OPT_OP
:
if
envs
.
VLLM_USE_OPT_OP
:
if
envs
.
VLLM_USE_TC_PAGED_ATTN
:
if
attn_masks
is
None
:
if
attn_masks
is
None
:
ops
.
paged_attention_v2_opt
(
ops
.
paged_attention_v2_opt_tc
(
output
,
output
,
exp_sums
,
exp_sums
,
max_logits
,
max_logits
,
tmp_output
,
tmp_output
,
query
,
query
,
key_cache
,
key_cache
,
value_cache
,
value_cache
,
num_kv_heads
,
num_kv_heads
,
scale
,
scale
,
block_tables
,
block_tables
,
seq_lens
,
seq_lens
,
block_size
,
block_size
,
max_seq_len
,
max_seq_len
,
alibi_slopes
,
alibi_slopes
,
kv_cache_dtype
,
kv_cache_dtype
,
k_scale
,
k_scale
,
v_scale
,
v_scale
,
tp_rank
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_block_size
,
blocksparse_head_sliding_step
blocksparse_head_sliding_step
)
)
else
:
ops
.
paged_attention_v2_opt_tc_with_mask
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
else
:
else
:
if
attn_masks
is
None
:
ops
.
paged_attention_v2_opt_with_mask
(
ops
.
paged_attention_v2_opt
(
output
,
output
,
exp_sums
,
exp_sums
,
max_logits
,
max_logits
,
tmp_output
,
tmp_output
,
query
,
query
,
key_cache
,
key_cache
,
value_cache
,
value_cache
,
num_kv_heads
,
num_kv_heads
,
scale
,
scale
,
block_tables
,
block_tables
,
seq_lens
,
seq_lens
,
block_size
,
block_size
,
max_seq_len
,
max_seq_len
,
alibi_slopes
,
alibi_slopes
,
kv_cache_dtype
,
kv_cache_dtype
,
k_scale
,
k_scale
,
v_scale
,
v_scale
,
tp_rank
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
blocksparse_head_sliding_step
attn_masks
,
)
attn_masks_stride
else
:
)
ops
.
paged_attention_v2_opt_with_mask
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
else
:
else
:
if
attn_masks
is
None
:
if
attn_masks
is
None
:
ops
.
paged_attention_v2
(
ops
.
paged_attention_v2
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment