Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0d2ccb8d
Commit
0d2ccb8d
authored
Mar 15, 2025
by
zhangshao
Browse files
调整pa tc和非tc调用关系
parent
4f8d38c8
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
8 additions
and
128 deletions
+8
-128
csrc/attention/attention_kernels_opt_tc.cu
csrc/attention/attention_kernels_opt_tc.cu
+5
-64
csrc/attention/attention_with_mask_kernels_opt_tc.cu
csrc/attention/attention_with_mask_kernels_opt_tc.cu
+1
-63
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+2
-1
No files found.
csrc/attention/attention_kernels_opt_tc.cu
View file @
0d2ccb8d
...
...
@@ -999,28 +999,6 @@ void paged_attention_v2_launcher_opt_tc(
break; \
}
void
paged_attention_v2
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
torch
::
Tensor
&
max_logits
,
// [num_seqs, num_heads, max_num_partitions]
torch
::
Tensor
&
tmp_out
,
// [num_seqs, num_heads, max_num_partitions, head_size]
torch
::
Tensor
&
query
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
key_cache
,
// [num_blocks, num_heads, head_size/x, block_size, x]
torch
::
Tensor
&
value_cache
,
// [num_blocks, num_heads, head_size, block_size]
int64_t
num_kv_heads
,
// [num_heads]
double
scale
,
torch
::
Tensor
&
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
void
paged_attention_v2_opt_tc
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
...
...
@@ -1043,37 +1021,10 @@ void paged_attention_v2_opt_tc(
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
if
(
kv_cache_dtype
!=
"auto"
||
query
.
dtype
()
==
at
::
ScalarType
::
Float
||
is_block_sparse
||
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
(
device_name
!=
"gfx928"
&&
device_name
!=
"gfx936"
)){
paged_attention_v2
(
out
,
exp_sums
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
}
else
{
DISPATCH_BY_KV_CACHE_DTYPE
(
query
.
dtype
(),
kv_cache_dtype
,
CALL_V2_LAUNCHER_BLOCK_SIZE
)
}
}
void
paged_attention_v1
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
query
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
key_cache
,
// [num_blocks, num_heads, head_size/x, block_size, x]
torch
::
Tensor
&
value_cache
,
// [num_blocks, num_heads, head_size, block_size]
int64_t
num_kv_heads
,
// [num_heads]
double
scale
,
torch
::
Tensor
&
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
void
paged_attention_v1_opt_tc
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
query
,
// [num_seqs, num_heads, head_size]
...
...
@@ -1091,20 +1042,10 @@ void paged_attention_v1_opt_tc(
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
if
(
kv_cache_dtype
!=
"auto"
||
query
.
dtype
()
==
at
::
ScalarType
::
Float
||
is_block_sparse
||
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
(
device_name
!=
"gfx928"
&&
device_name
!=
"gfx936"
)){
paged_attention_v1
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
}
else
{
paged_attention_v2_opt_tc
(
out
,
out
,
out
,
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
}
}
#undef WARP_SIZE
...
...
csrc/attention/attention_with_mask_kernels_opt_tc.cu
View file @
0d2ccb8d
...
...
@@ -910,30 +910,6 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
break; \
}
void
paged_attention_v2_with_mask
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
torch
::
Tensor
&
max_logits
,
// [num_seqs, num_heads, max_num_partitions]
torch
::
Tensor
&
tmp_out
,
// [num_seqs, num_heads, max_num_partitions, head_size]
torch
::
Tensor
&
query
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
key_cache
,
// [num_blocks, num_heads, head_size/x, block_size, x]
torch
::
Tensor
&
value_cache
,
// [num_blocks, num_heads, head_size, block_size]
int64_t
num_kv_heads
,
// [num_heads]
double
scale
,
torch
::
Tensor
&
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
,
// [num_seqs, max_seq_len]
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
int64_t
attn_masks_stride
);
void
paged_attention_v2_opt_tc_with_mask
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
...
...
@@ -958,38 +934,10 @@ void paged_attention_v2_opt_tc_with_mask(
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
// [num_seqs, max_seq_len]
const
int64_t
attn_masks_stride
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
if
(
kv_cache_dtype
!=
"auto"
||
query
.
dtype
()
==
at
::
ScalarType
::
Float
||
is_block_sparse
||
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
(
device_name
!=
"gfx928"
&&
device_name
!=
"gfx936"
)){
paged_attention_v2_with_mask
(
out
,
exp_sums
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
);
}
else
{
DISPATCH_BY_KV_CACHE_DTYPE
(
query
.
dtype
(),
kv_cache_dtype
,
CALL_V2_LAUNCHER_BLOCK_SIZE
)
}
}
void
paged_attention_v1_with_mask
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
query
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
key_cache
,
// [num_blocks, num_heads, head_size/x, block_size, x]
torch
::
Tensor
&
value_cache
,
// [num_blocks, num_heads, head_size, block_size]
int64_t
num_kv_heads
,
// [num_heads]
double
scale
,
torch
::
Tensor
&
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
// [num_seqs, max_seq_len]
const
int64_t
attn_masks_stride
);
void
paged_attention_v1_opt_tc_with_mask
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
...
...
@@ -1010,20 +958,10 @@ void paged_attention_v1_opt_tc_with_mask(
const
int64_t
blocksparse_head_sliding_step
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
// [num_seqs, max_seq_len]
const
int64_t
attn_masks_stride
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
if
(
kv_cache_dtype
!=
"auto"
||
query
.
dtype
()
==
at
::
ScalarType
::
Float
||
is_block_sparse
||
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
(
device_name
!=
"gfx928"
&&
device_name
!=
"gfx936"
)){
paged_attention_v1_with_mask
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
);
}
else
{
paged_attention_v2_opt_tc_with_mask
(
out
,
out
,
out
,
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
);
}
}
#undef WARP_SIZE
...
...
vllm/attention/ops/paged_attn.py
View file @
0d2ccb8d
...
...
@@ -16,6 +16,7 @@ if HAS_TRITON:
_PARTITION_SIZE
=
512
gpuname
=
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
name
support_tc
=
gpuname
.
startswith
(
'K100_AI'
)
or
gpuname
.
startswith
(
'BW'
)
use_tc
=
envs
.
VLLM_USE_OPT_OP
and
envs
.
VLLM_USE_TC_PAGED_ATTN
and
support_tc
@
dataclass
class
PagedAttentionMetadata
:
...
...
@@ -131,7 +132,7 @@ class PagedAttention:
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
if
envs
.
VLLM_USE_TC_PAGED_ATTN
and
support_tc
:
if
use_tc
and
head_size
==
128
:
if
envs
.
VLLM_USE_PA_PRINT_PARAM
:
print
(
"PA V1 SIZE:"
)
print
(
f
"query.shape =
{
query
.
shape
}
, key_cache.shape =
{
key_cache
.
shape
}
, value_cache.shape =
{
value_cache
.
shape
}
"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment