Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4ea48fb3
Unverified
Commit
4ea48fb3
authored
Feb 08, 2025
by
Woosuk Kwon
Committed by
GitHub
Feb 08, 2025
Browse files
[V1][Minor] Move cascade attn logic outside _prepare_inputs (#12943)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
e31498bd
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
89 additions
and
61 deletions
+89
-61
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+89
-61
No files found.
vllm/v1/worker/gpu_model_runner.py
View file @
4ea48fb3
...
@@ -476,67 +476,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -476,67 +476,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
device
,
non_blocking
=
True
).
long
()
self
.
device
,
non_blocking
=
True
).
long
()
# Prepare for cascade attention if needed.
# Prepare for cascade attention if needed.
common_prefix_len
=
(
scheduler_output
.
num_common_prefix_blocks
*
common_prefix_len
=
self
.
_compute_cascade_attn_prefix_len
(
self
.
block_size
)
num_scheduled_tokens
,
if
common_prefix_len
==
0
:
scheduler_output
.
num_common_prefix_blocks
,
# Common case.
)
use_cascade
=
False
use_cascade
=
common_prefix_len
>
0
else
:
# NOTE(woosuk): Cascade attention uses two attention kernels: one
# for the common prefix and the other for the rest. For the first
# kernel, we concatenate all the query tokens (possibly from
# different requests) and treat them as if they are from the same
# request. Then, we use bi-directional attention to process the
# common prefix in the KV cache. Importantly, this means that the
# first kernel does not do any masking.
# Consider the following example:
# Request 1's input query: [D, E, X]
# Request 1's kv cache: [A, B, C, D, E, X]
# Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
# Request 2's input query: [E, Y]
# Request 2's kv cache: [A, B, C, D, E, Y]
# Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])
# If we use [A, B, C, D, E] as the common prefix, then the
# first kernel will compute the bi-directional attention between
# input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
# However, this is wrong because D in Request 1 should not attend to
# E in the common prefix (i.e., we need masking).
# To avoid this, [A, B, C, D] should be the common prefix.
# That is, the common prefix should be capped by the minimum
# num_computed_tokens among the requests, and plus one to include
# the first token of the query.
# In practice, we use [A, B, C] as the common prefix, instead of
# [A, B, C, D] (i.e., the common prefix is capped by the minimum
# num_computed_tokens, without plus one).
# This is because of an implementation detail: We want to always
# use two kernels for cascade attention. Let's imagine:
# Request 3's input query: [D]
# Request 3's kv cache: [A, B, C, D]
# Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D])
# If we use [A, B, C, D] as the common prefix for Request 1-3,
# then Request 3 will be processed only by the first kernel,
# and the second kernel will get an empty input. While this is not
# a fundamental problem, our current implementation does not support
# this case.
common_prefix_len
=
min
(
common_prefix_len
,
self
.
input_batch
.
num_computed_tokens_cpu
[:
num_reqs
].
min
())
# common_prefix_len should be a multiple of the block size.
common_prefix_len
=
(
common_prefix_len
//
self
.
block_size
*
self
.
block_size
)
use_cascade
=
FlashAttentionBackend
.
use_cascade_attention
(
common_prefix_len
=
common_prefix_len
,
query_lens
=
num_scheduled_tokens
,
num_query_heads
=
self
.
num_query_heads
,
num_kv_heads
=
self
.
num_kv_heads
,
use_alibi
=
False
,
# FIXME
use_sliding_window
=
self
.
sliding_window
is
not
None
,
num_sms
=
self
.
num_sms
,
)
if
use_cascade
:
if
use_cascade
:
# TODO: Optimize.
# TODO: Optimize.
cu_prefix_query_lens
=
torch
.
tensor
(
cu_prefix_query_lens
=
torch
.
tensor
(
...
@@ -581,6 +525,90 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -581,6 +525,90 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logits_indices
=
query_start_loc
[
1
:]
-
1
logits_indices
=
query_start_loc
[
1
:]
-
1
return
attn_metadata
,
logits_indices
return
attn_metadata
,
logits_indices
def
_compute_cascade_attn_prefix_len
(
self
,
num_scheduled_tokens
:
np
.
ndarray
,
num_common_prefix_blocks
:
int
,
)
->
int
:
"""Compute the length of the common prefix for cascade attention.
NOTE(woosuk): The common prefix length returned by this function
represents the length used specifically for cascade attention, not the
actual number of tokens shared between requests. When cascade attention
is disabled (use_cascade=False), this function returns 0 even if
requests share common tokens. Additionally, the common prefix length is
truncated to a multiple of the block size and may be further truncated
due to implementation details explained below.
Args:
num_scheduled_tokens: Number of tokens scheduled per request.
num_common_prefix_blocks: Number of shared KV cache blocks.
Returns:
int: Length of common prefix in tokens.
"""
common_prefix_len
=
num_common_prefix_blocks
*
self
.
block_size
if
common_prefix_len
==
0
:
# Common case.
return
0
# NOTE(woosuk): Cascade attention uses two attention kernels: one
# for the common prefix and the other for the rest. For the first
# kernel, we concatenate all the query tokens (possibly from
# different requests) and treat them as if they are from the same
# request. Then, we use bi-directional attention to process the
# common prefix in the KV cache. Importantly, this means that the
# first kernel does not do any masking.
# Consider the following example:
# Request 1's input query: [D, E, X]
# Request 1's kv cache: [A, B, C, D, E, X]
# Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
# Request 2's input query: [E, Y]
# Request 2's kv cache: [A, B, C, D, E, Y]
# Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])
# If we use [A, B, C, D, E] as the common prefix, then the
# first kernel will compute the bi-directional attention between
# input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
# However, this is wrong because D in Request 1 should not attend to
# E in the common prefix (i.e., we need masking).
# To avoid this, [A, B, C, D] should be the common prefix.
# That is, the common prefix should be capped by the minimum
# num_computed_tokens among the requests, and plus one to include
# the first token of the query.
# In practice, we use [A, B, C] as the common prefix, instead of
# [A, B, C, D] (i.e., the common prefix is capped by the minimum
# num_computed_tokens, without plus one).
# This is because of an implementation detail: We want to always
# use two kernels for cascade attention. Let's imagine:
# Request 3's input query: [D]
# Request 3's kv cache: [A, B, C, D]
# Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D])
# If we use [A, B, C, D] as the common prefix for Request 1-3,
# then Request 3 will be processed only by the first kernel,
# and the second kernel will get an empty input. While this is not
# a fundamental problem, our current implementation does not support
# this case.
num_reqs
=
len
(
num_scheduled_tokens
)
common_prefix_len
=
min
(
common_prefix_len
,
self
.
input_batch
.
num_computed_tokens_cpu
[:
num_reqs
].
min
())
# common_prefix_len should be a multiple of the block size.
common_prefix_len
=
(
common_prefix_len
//
self
.
block_size
*
self
.
block_size
)
use_cascade
=
FlashAttentionBackend
.
use_cascade_attention
(
common_prefix_len
=
common_prefix_len
,
query_lens
=
num_scheduled_tokens
,
num_query_heads
=
self
.
num_query_heads
,
num_kv_heads
=
self
.
num_kv_heads
,
use_alibi
=
False
,
# FIXME
use_sliding_window
=
self
.
sliding_window
is
not
None
,
num_sms
=
self
.
num_sms
,
)
return
common_prefix_len
if
use_cascade
else
0
def
_calc_mrope_positions
(
self
,
scheduler_output
:
"SchedulerOutput"
):
def
_calc_mrope_positions
(
self
,
scheduler_output
:
"SchedulerOutput"
):
mrope_pos_ptr
=
0
mrope_pos_ptr
=
0
num_reqs
=
self
.
input_batch
.
num_reqs
num_reqs
=
self
.
input_batch
.
num_reqs
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment