Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a5b976df
Commit
a5b976df
authored
Mar 18, 2025
by
zhuwenwen
Browse files
解决PA部分size计算错误的问题
优化bf16精度 解决bf16精度问题,解决cudagraph精度问题 调整pa tc和非tc调用关系
parent
10ce38cc
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
156 additions
and
206 deletions
+156
-206
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+156
-206
No files found.
vllm/attention/ops/paged_attn.py
View file @
a5b976df
...
@@ -14,7 +14,9 @@ if HAS_TRITON:
...
@@ -14,7 +14,9 @@ if HAS_TRITON:
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE
=
512
_PARTITION_SIZE
=
512
gpuname
=
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
name
support_tc
=
gpuname
.
startswith
(
'K100_AI'
)
or
gpuname
.
startswith
(
'BW'
)
use_tc
=
envs
.
VLLM_USE_OPT_OP
and
envs
.
VLLM_USE_TC_PAGED_ATTN
and
support_tc
@
dataclass
@
dataclass
class
PagedAttentionMetadata
:
class
PagedAttentionMetadata
:
...
@@ -128,22 +130,13 @@ class PagedAttention:
...
@@ -128,22 +130,13 @@ class PagedAttention:
# to parallelize.
# to parallelize.
# TODO(woosuk): Tune this heuristic.
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
if
envs
.
VLLM_USE_TC_PAGED_ATTN
:
use_v1
=
(
max_seq_len
<
8192
and
(
max_seq_len
<
(
1024
if
num_kv_heads
==
num_heads
else
600
)
or
num_seqs
*
num_heads
>
(
1024
if
num_kv_heads
<
num_heads
else
512
)))
else
:
use_v1
=
(
max_seq_len
<=
8192
and
(
max_num_partitions
==
1
or
num_seqs
*
num_heads
>
512
))
if
use_v1
:
# Run PagedAttention V1.
if
use_tc
and
head_size
==
128
:
if
envs
.
VLLM_USE_PA_PRINT_PARAM
:
if
envs
.
VLLM_USE_PA_PRINT_PARAM
:
print
(
"PA V1 SIZE:"
)
print
(
"PA V1 SIZE:"
)
print
(
f
"query.shape =
{
query
.
shape
}
, key_cache.shape =
{
key_cache
.
shape
}
, value_cache.shape =
{
value_cache
.
shape
}
"
)
print
(
f
"query.shape =
{
query
.
shape
}
, key_cache.shape =
{
key_cache
.
shape
}
, value_cache.shape =
{
value_cache
.
shape
}
"
)
print
(
f
"num_kv_heads =
{
num_kv_heads
}
, scale =
{
scale
:.
3
f
}
, block_tables.shape =
{
block_tables
.
shape
}
, seq_lens.shape =
{
seq_lens
.
shape
}
, block_size =
{
block_size
}
, max_seq_len =
{
max_seq_len
}
"
)
print
(
f
"num_kv_heads =
{
num_kv_heads
}
, scale =
{
scale
:.
3
f
}
, block_tables.shape =
{
block_tables
.
shape
}
, seq_lens.shape =
{
seq_lens
.
shape
}
, block_size =
{
block_size
}
, max_seq_len =
{
max_seq_len
}
"
)
if
envs
.
VLLM_USE_OPT_OP
:
if
envs
.
VLLM_USE_TC_PAGED_ATTN
:
if
attn_masks
is
None
:
if
attn_masks
is
None
:
ops
.
paged_attention_v1_opt_tc
(
ops
.
paged_attention_v1_opt_tc
(
output
,
output
,
...
@@ -190,7 +183,18 @@ class PagedAttention:
...
@@ -190,7 +183,18 @@ class PagedAttention:
attn_masks
,
attn_masks
,
attn_masks_stride
attn_masks_stride
)
)
else
:
return
output
use_v1
=
(
max_seq_len
<=
8192
and
(
max_num_partitions
==
1
or
num_seqs
*
num_heads
>
512
))
if
use_v1
:
# Run PagedAttention V1.
if
envs
.
VLLM_USE_PA_PRINT_PARAM
:
print
(
"PA V1 SIZE:"
)
print
(
f
"query.shape =
{
query
.
shape
}
, key_cache.shape =
{
key_cache
.
shape
}
, value_cache.shape =
{
value_cache
.
shape
}
"
)
print
(
f
"num_kv_heads =
{
num_kv_heads
}
, scale =
{
scale
:.
3
f
}
, block_tables.shape =
{
block_tables
.
shape
}
, seq_lens.shape =
{
seq_lens
.
shape
}
, block_size =
{
block_size
}
, max_seq_len =
{
max_seq_len
}
"
)
if
envs
.
VLLM_USE_OPT_OP
:
if
attn_masks
is
None
:
if
attn_masks
is
None
:
ops
.
paged_attention_v1_opt
(
ops
.
paged_attention_v1_opt
(
output
,
output
,
...
@@ -306,60 +310,6 @@ class PagedAttention:
...
@@ -306,60 +310,6 @@ class PagedAttention:
print
(
f
"num_kv_heads =
{
num_kv_heads
}
, scale =
{
scale
:.
3
f
}
, block_tables.shape =
{
block_tables
.
shape
}
, seq_lens.shape =
{
seq_lens
.
shape
}
, block_size =
{
block_size
}
, max_seq_len =
{
max_seq_len
}
"
)
print
(
f
"num_kv_heads =
{
num_kv_heads
}
, scale =
{
scale
:.
3
f
}
, block_tables.shape =
{
block_tables
.
shape
}
, seq_lens.shape =
{
seq_lens
.
shape
}
, block_size =
{
block_size
}
, max_seq_len =
{
max_seq_len
}
"
)
if
envs
.
VLLM_USE_OPT_OP
:
if
envs
.
VLLM_USE_OPT_OP
:
if
envs
.
VLLM_USE_TC_PAGED_ATTN
:
if
attn_masks
is
None
:
ops
.
paged_attention_v2_opt_tc
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
else
:
ops
.
paged_attention_v2_opt_tc_with_mask
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
else
:
if
attn_masks
is
None
:
if
attn_masks
is
None
:
ops
.
paged_attention_v2_opt
(
ops
.
paged_attention_v2_opt
(
output
,
output
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment