Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
45a060d6
Commit
45a060d6
authored
Feb 05, 2026
by
zhuwenwen
Browse files
Merge tag 'v0.15.1' into v0.15.1-dev
parents
99fc9fc3
1892993b
Changes
64
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
36 additions
and
20 deletions
+36
-20
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+0
-12
vllm/v1/core/kv_cache_coordinator.py
vllm/v1/core/kv_cache_coordinator.py
+23
-7
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+1
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+12
-0
No files found.
vllm/v1/attention/backends/flash_attn.py
View file @
45a060d6
...
...
@@ -315,18 +315,6 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
vllm_config
:
"VllmConfig"
,
kv_cache_spec
:
"AttentionSpec"
,
)
->
AttentionCGSupport
:
# FA2 does not support CUDA graphs with encoder-decoder models due to
# accuracy issues reported in https://github.com/vllm-project/vllm/issues/33091
if
(
vllm_config
.
model_config
.
is_encoder_decoder
and
get_flash_attn_version
()
==
2
):
logger
.
warning_once
(
"FlashAttention2 does not support CUDA graphs with "
"encoder-decoder models due to accuracy issues reported in #33091. "
"Disabling CUDA graph."
)
return
AttentionCGSupport
.
NEVER
return
cls
.
_cudagraph_support
def
__init__
(
...
...
vllm/v1/core/kv_cache_coordinator.py
View file @
45a060d6
...
...
@@ -479,6 +479,16 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
hit_length
=
max_cache_hit_length
hit_blocks_by_group
:
list
[
list
[
KVCacheBlock
]
|
None
]
=
[
None
]
*
num_groups
# Simple hybrid (1 full attn + 1 other): one iteration suffices.
# Full attn is always first if it exists. This avoids EAGLE drops
# being applied multiple times to non-full-attn groups.
# FIXME (yifan): However, for complex hybrid models with multiple attn
# groups, we still have the EAGLE spiral block dropping problem. See
# discussion in issue https://github.com/vllm-project/vllm/issues/32802.
is_simple_hybrid
=
len
(
self
.
attention_groups
)
==
2
and
isinstance
(
self
.
attention_groups
[
0
][
0
],
FullAttentionSpec
)
while
True
:
curr_hit_length
=
hit_length
...
...
@@ -495,10 +505,6 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
# the last iteration.
num_blocks
=
curr_hit_length
//
spec
.
block_size
curr_hit_length
=
num_blocks
*
spec
.
block_size
for
group_id
in
group_ids
:
blocks
=
hit_blocks_by_group
[
group_id
]
assert
blocks
is
not
None
del
blocks
[
num_blocks
:]
else
:
hit_blocks
=
manager_cls
.
find_longest_cache_hit
(
block_hashes
=
_get_block_hashes
(
spec
),
...
...
@@ -513,10 +519,20 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
for
group_id
,
blocks
in
zip
(
group_ids
,
hit_blocks
):
hit_blocks_by_group
[
group_id
]
=
blocks
if
curr_hit_length
<
hit_length
:
hit_length
=
curr_hit_length
else
:
if
curr_hit_length
>=
hit_length
:
break
hit_length
=
curr_hit_length
# Simple hybrid: exit after one iteration
if
is_simple_hybrid
:
break
# Truncate full attention blocks to final hit_length (if present)
spec
,
group_ids
,
_
=
self
.
attention_groups
[
0
]
if
isinstance
(
spec
,
FullAttentionSpec
):
num_blocks
=
hit_length
//
spec
.
block_size
for
group_id
in
group_ids
:
if
(
blks
:
=
hit_blocks_by_group
[
group_id
])
is
not
None
:
del
blks
[
num_blocks
:]
return
tuple
(
blocks
if
blocks
is
not
None
else
[]
for
blocks
in
hit_blocks_by_group
...
...
vllm/v1/core/sched/scheduler.py
View file @
45a060d6
...
...
@@ -1284,7 +1284,7 @@ class Scheduler(SchedulerInterface):
scheduled_spec_token_ids
=
(
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
)
)
if
scheduled_spec_token_ids
:
if
scheduled_spec_token_ids
and
generated_token_ids
:
num_draft_tokens
=
len
(
scheduled_spec_token_ids
)
num_accepted
=
len
(
generated_token_ids
)
-
1
num_rejected
=
num_draft_tokens
-
num_accepted
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
45a060d6
...
...
@@ -1382,12 +1382,14 @@ class GPUModelRunner(
num_scheduled_tokens
:
dict
[
str
,
int
],
kv_cache_spec
:
KVCacheSpec
,
num_reqs
:
int
,
for_cudagraph_capture
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
|
None
,
np
.
ndarray
|
None
]:
if
not
isinstance
(
kv_cache_spec
,
CrossAttentionSpec
):
return
None
,
None
# Zero out buffer for padding requests that are not actually scheduled (CGs)
self
.
encoder_seq_lens
.
np
[:
num_reqs
]
=
0
# Build encoder_seq_lens array mapping request indices to
# encoder lengths for inputs scheduled in this batch
for
req_id
in
num_scheduled_tokens
:
...
...
@@ -1404,6 +1406,15 @@ class GPUModelRunner(
feature
.
mm_position
.
length
for
feature
in
req_state
.
mm_features
)
self
.
encoder_seq_lens
.
np
[
req_index
]
=
encoder_input_tokens
if
for_cudagraph_capture
:
# During CUDA graph capture, we need to use realistic encoder lengths
# so that max_seqlen_k is captured with the correct value.
max_encoder_len
=
getattr
(
self
.
model_config
.
hf_config
,
"max_source_positions"
,
self
.
max_encoder_len
,
)
self
.
encoder_seq_lens
.
np
[:
num_reqs
]
=
max_encoder_len
self
.
encoder_seq_lens
.
copy_to_gpu
(
num_reqs
)
encoder_seq_lens
=
self
.
encoder_seq_lens
.
gpu
[:
num_reqs
]
...
...
@@ -1821,6 +1832,7 @@ class GPUModelRunner(
num_scheduled_tokens
or
{},
kv_cache_group
.
kv_cache_spec
,
num_reqs_padded
,
for_cudagraph_capture
=
for_cudagraph_capture
,
)
if
kv_cache_gid
>
0
:
cm
.
block_table_tensor
=
_get_block_table
(
kv_cache_gid
)
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment