Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4ea48fb3
Unverified
Commit
4ea48fb3
authored
Feb 08, 2025
by
Woosuk Kwon
Committed by
GitHub
Feb 08, 2025
Browse files
[V1][Minor] Move cascade attn logic outside _prepare_inputs (#12943)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
e31498bd
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
89 additions
and
61 deletions
+89
-61
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+89
-61
No files found.
vllm/v1/worker/gpu_model_runner.py
View file @
4ea48fb3
...
...
@@ -476,12 +476,82 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
device
,
non_blocking
=
True
).
long
()
# Prepare for cascade attention if needed.
common_prefix_len
=
(
scheduler_output
.
num_common_prefix_blocks
*
self
.
block_size
)
common_prefix_len
=
self
.
_compute_cascade_attn_prefix_len
(
num_scheduled_tokens
,
scheduler_output
.
num_common_prefix_blocks
,
)
use_cascade
=
common_prefix_len
>
0
if
use_cascade
:
# TODO: Optimize.
cu_prefix_query_lens
=
torch
.
tensor
(
[
0
,
total_num_scheduled_tokens
],
dtype
=
torch
.
int32
,
device
=
self
.
device
)
prefix_kv_lens
=
torch
.
tensor
([
common_prefix_len
],
dtype
=
torch
.
int32
,
device
=
self
.
device
)
suffix_kv_lens
=
(
self
.
seq_lens_np
[:
num_reqs
]
-
common_prefix_len
)
suffix_kv_lens
=
torch
.
from_numpy
(
suffix_kv_lens
).
to
(
self
.
device
)
else
:
cu_prefix_query_lens
=
None
prefix_kv_lens
=
None
suffix_kv_lens
=
None
attn_metadata
=
FlashAttentionMetadata
(
num_actual_tokens
=
total_num_scheduled_tokens
,
max_query_len
=
max_num_scheduled_tokens
,
query_start_loc
=
query_start_loc
,
max_seq_len
=
max_seq_len
,
seq_lens
=
seq_lens
,
block_table
=
(
self
.
input_batch
.
block_table
.
get_device_tensor
()[:
num_reqs
]),
slot_mapping
=
slot_mapping
,
use_cascade
=
use_cascade
,
common_prefix_len
=
common_prefix_len
,
cu_prefix_query_lens
=
cu_prefix_query_lens
,
prefix_kv_lens
=
prefix_kv_lens
,
suffix_kv_lens
=
suffix_kv_lens
,
)
# Hot-Swap lora model
if
self
.
lora_config
:
self
.
set_active_loras
(
self
.
input_batch
,
num_scheduled_tokens
)
# NOTE(woosuk): Due to chunked prefills, the batch may contain partial
# requests. While we should not sample any token from these partial
# requests, we do so for simplicity. We will ignore the sampled
# tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices
=
query_start_loc
[
1
:]
-
1
return
attn_metadata
,
logits_indices
def
_compute_cascade_attn_prefix_len
(
self
,
num_scheduled_tokens
:
np
.
ndarray
,
num_common_prefix_blocks
:
int
,
)
->
int
:
"""Compute the length of the common prefix for cascade attention.
NOTE(woosuk): The common prefix length returned by this function
represents the length used specifically for cascade attention, not the
actual number of tokens shared between requests. When cascade attention
is disabled (use_cascade=False), this function returns 0 even if
requests share common tokens. Additionally, the common prefix length is
truncated to a multiple of the block size and may be further truncated
due to implementation details explained below.
Args:
num_scheduled_tokens: Number of tokens scheduled per request.
num_common_prefix_blocks: Number of shared KV cache blocks.
Returns:
int: Length of common prefix in tokens.
"""
common_prefix_len
=
num_common_prefix_blocks
*
self
.
block_size
if
common_prefix_len
==
0
:
# Common case.
use_cascade
=
False
else
:
return
0
# NOTE(woosuk): Cascade attention uses two attention kernels: one
# for the common prefix and the other for the rest. For the first
# kernel, we concatenate all the query tokens (possibly from
...
...
@@ -521,6 +591,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# and the second kernel will get an empty input. While this is not
# a fundamental problem, our current implementation does not support
# this case.
num_reqs
=
len
(
num_scheduled_tokens
)
common_prefix_len
=
min
(
common_prefix_len
,
self
.
input_batch
.
num_computed_tokens_cpu
[:
num_reqs
].
min
())
...
...
@@ -536,50 +607,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
use_sliding_window
=
self
.
sliding_window
is
not
None
,
num_sms
=
self
.
num_sms
,
)
if
use_cascade
:
# TODO: Optimize.
cu_prefix_query_lens
=
torch
.
tensor
(
[
0
,
total_num_scheduled_tokens
],
dtype
=
torch
.
int32
,
device
=
self
.
device
)
prefix_kv_lens
=
torch
.
tensor
([
common_prefix_len
],
dtype
=
torch
.
int32
,
device
=
self
.
device
)
suffix_kv_lens
=
(
self
.
seq_lens_np
[:
num_reqs
]
-
common_prefix_len
)
suffix_kv_lens
=
torch
.
from_numpy
(
suffix_kv_lens
).
to
(
self
.
device
)
else
:
cu_prefix_query_lens
=
None
prefix_kv_lens
=
None
suffix_kv_lens
=
None
attn_metadata
=
FlashAttentionMetadata
(
num_actual_tokens
=
total_num_scheduled_tokens
,
max_query_len
=
max_num_scheduled_tokens
,
query_start_loc
=
query_start_loc
,
max_seq_len
=
max_seq_len
,
seq_lens
=
seq_lens
,
block_table
=
(
self
.
input_batch
.
block_table
.
get_device_tensor
()[:
num_reqs
]),
slot_mapping
=
slot_mapping
,
use_cascade
=
use_cascade
,
common_prefix_len
=
common_prefix_len
,
cu_prefix_query_lens
=
cu_prefix_query_lens
,
prefix_kv_lens
=
prefix_kv_lens
,
suffix_kv_lens
=
suffix_kv_lens
,
)
# Hot-Swap lora model
if
self
.
lora_config
:
self
.
set_active_loras
(
self
.
input_batch
,
num_scheduled_tokens
)
# NOTE(woosuk): Due to chunked prefills, the batch may contain partial
# requests. While we should not sample any token from these partial
# requests, we do so for simplicity. We will ignore the sampled
# tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices
=
query_start_loc
[
1
:]
-
1
return
attn_metadata
,
logits_indices
return
common_prefix_len
if
use_cascade
else
0
def
_calc_mrope_positions
(
self
,
scheduler_output
:
"SchedulerOutput"
):
mrope_pos_ptr
=
0
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment