OpenDAS / vllm_cscc
Commit 3b61cb45
Authored Dec 09, 2024 by Woosuk Kwon. Committed by GitHub, Dec 09, 2024.

[V1] Further reduce CPU overheads in flash-attn (#10989)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

parent edc4fa31

Showing 2 changed files with 28 additions and 7 deletions

csrc/cache_kernels.cu  +12 -2
vllm/v1/attention/backends/flash_attn.py  +16 -5

csrc/cache_kernels.cu
@@ -307,10 +307,20 @@ void reshape_and_cache_flash(
     torch::Tensor& key_cache,    // [num_blocks, block_size, num_heads, head_size]
     torch::Tensor& value_cache,  // [num_blocks, block_size, num_heads, head_size]
-    torch::Tensor& slot_mapping,  // [num_tokens]
+    torch::Tensor& slot_mapping,  // [num_tokens] or [num_actual_tokens]
     const std::string& kv_cache_dtype,
     const double k_scale,
     const double v_scale) {
-  int num_tokens = key.size(0);
+  // NOTE(woosuk): In vLLM V1, key.size(0) can be different from
+  // slot_mapping.size(0) because of padding for CUDA graphs.
+  // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
+  // both include padding.
+  // In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0)
+  // since key includes padding for CUDA graphs, while slot_mapping does not.
+  // In this case, slot_mapping.size(0) represents the actual number of tokens
+  // before padding.
+  // For compatibility with both cases, we use slot_mapping.size(0) as the
+  // number of tokens.
+  int num_tokens = slot_mapping.size(0);
   int num_heads = key.size(1);
   int head_size = key.size(2);
   int block_size = key_cache.size(1);
...
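
The effect of taking the token count from `slot_mapping` rather than from `key` can be modeled in a few lines of plain Python. This is only a sketch under stated assumptions: `reshape_and_cache_sketch` is a hypothetical name, it handles a single head and keys only, and it assumes a slot index splits into a block index and an in-block offset by dividing by `block_size`, which the hunk above does not show.

```python
from typing import List


def reshape_and_cache_sketch(
    key: List[List[float]],              # [num_padded_tokens][head_size]
    key_cache: List[List[List[float]]],  # [num_blocks][block_size][head_size]
    slot_mapping: List[int],             # [num_actual_tokens], not padded
) -> int:
    """Copy only the first len(slot_mapping) rows of `key` into the cache."""
    block_size = len(key_cache[0])
    # Mirrors the change in the diff: the loop bound comes from slot_mapping,
    # so trailing padding rows in `key` are never read.
    num_tokens = len(slot_mapping)
    for token_idx in range(num_tokens):
        # Assumption (not shown in the diff): slot = block * block_size + offset.
        block_idx, block_offset = divmod(slot_mapping[token_idx], block_size)
        key_cache[block_idx][block_offset] = list(key[token_idx])
    return num_tokens


# Three real tokens followed by one padding row.
key = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [0.0, 0.0]]
# Two blocks of two slots each, initially zeroed.
key_cache = [[[0.0, 0.0] for _ in range(2)] for _ in range(2)]
slot_mapping = [0, 1, 3]

written = reshape_and_cache_sketch(key, key_cache, slot_mapping)
```

Because the loop is bounded by `len(slot_mapping)`, the caller can pass the padded `key` as-is, which is what lets the Python side drop its `[:num_actual_tokens]` slices.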

vllm/v1/attention/backends/flash_attn.py
@@ -138,14 +138,25 @@ class FlashAttentionImpl(AttentionImpl):
             # Profiling run.
             return output
 
-        num_actual_tokens = attn_metadata.num_actual_tokens
+        # IMPORTANT!
+        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
+        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
+        # in this method. For example, `view` and `slice` (or `[:n]`) operations
+        # are surprisingly slow even in the case they do not invoke any GPU ops.
+        # Minimize the PyTorch ops in this method as much as possible.
+        # Whenever making a change in this method, please benchmark the
+        # performance to make sure it does not introduce any overhead.
 
+        num_actual_tokens = attn_metadata.num_actual_tokens
         # Reshape the input keys and values and store them in the cache.
-        key_cache = kv_cache[0]
-        value_cache = kv_cache[1]
+        # NOTE(woosuk): Here, key and value are padded while slot_mapping is
+        # not padded. However, we don't need to do key[:num_actual_tokens] and
+        # value[:num_actual_tokens] because the reshape_and_cache_flash op uses
+        # the slot_mapping's shape to determine the number of actual tokens.
+        key_cache, value_cache = kv_cache.unbind(0)
         torch.ops._C_cache_ops.reshape_and_cache_flash(
-            key[:num_actual_tokens],
-            value[:num_actual_tokens],
+            key,
+            value,
             key_cache,
             value_cache,
             attn_metadata.slot_mapping,
...
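
The note in this hunk asks for a benchmark whenever the method changes. One rough way to compare the per-call CPU cost of the ops it mentions is a small timing helper like the sketch below. The helper name `cpu_time_per_call_us`, the tensor shapes, and the iteration counts are illustrative assumptions, not part of the commit. It uses CPU tensors, so it only measures Python-side dispatch overhead, and the comparison runs only if PyTorch is installed.

```python
import time
from typing import Callable


def cpu_time_per_call_us(fn: Callable[[], object], iters: int = 10_000) -> float:
    """Return the average wall-clock time of one fn() call, in microseconds."""
    for _ in range(100):  # warm-up so one-time setup costs are not counted
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters * 1e6


try:
    import torch
except ImportError:  # the helper above still works without PyTorch
    torch = None

if torch is not None:
    # Hypothetical shapes, chosen only for illustration.
    key = torch.zeros(512, 8, 64)             # padded [num_tokens, heads, head_size]
    kv_cache = torch.zeros(2, 16, 16, 8, 64)  # [2, blocks, block_size, heads, head_size]
    n = 300                                   # stand-in for num_actual_tokens

    print("key[:n]             us/call:", cpu_time_per_call_us(lambda: key[:n]))
    print("kv_cache[0] and [1] us/call:",
          cpu_time_per_call_us(lambda: (kv_cache[0], kv_cache[1])))
    print("kv_cache.unbind(0)  us/call:",
          cpu_time_per_call_us(lambda: kv_cache.unbind(0)))
```

The absolute numbers depend on the machine and the PyTorch build, so they are mainly useful for comparing the variants against each other before and after a change.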