Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dee71fba
Commit
dee71fba
authored
May 06, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.5.post1' into v0.8.5.post1-dev
parents
8a12a939
3015d563
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
116 additions
and
2 deletions
+116
-2
tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler.py
+77
-0
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+35
-1
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+4
-1
No files found.
tests/v1/core/test_scheduler.py
View file @
dee71fba
...
@@ -1165,3 +1165,80 @@ def test_kv_connector_handles_preemption():
...
@@ -1165,3 +1165,80 @@ def test_kv_connector_handles_preemption():
# All memory should be freed since nothing is running.
# All memory should be freed since nothing is running.
assert
scheduler
.
kv_cache_manager
.
block_pool
.
get_num_free_blocks
()
\
assert
scheduler
.
kv_cache_manager
.
block_pool
.
get_num_free_blocks
()
\
==
NUM_BLOCKS
-
1
==
NUM_BLOCKS
-
1
def make_output(scheduler: Scheduler):
    """Build a ModelRunnerOutput covering every currently running request.

    Each request in ``scheduler.running`` is assigned one sampled token
    (the constant 1000). No speculative tokens or logprobs are attached,
    and the prompt-logprobs mapping is empty.
    """
    running_ids = [req.request_id for req in scheduler.running]
    index_by_id = {req_id: idx for idx, req_id in enumerate(running_ids)}
    return ModelRunnerOutput(
        req_ids=running_ids,
        req_id_to_index=index_by_id,
        sampled_token_ids=[[1000]] * len(running_ids),
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
    )
def assert_scheduler_empty(scheduler: Scheduler):
    """Check that a drained scheduler holds no leftover state (no leaks)."""
    # Scheduler metadata: every per-request container must be empty.
    for attr in ("requests", "waiting", "running", "finished_req_ids",
                 "_cached_reqs_data"):
        assert len(getattr(scheduler, attr)) == 0

    # EncoderCacheManager.
    encoder_cache = scheduler.encoder_cache_manager
    assert len(encoder_cache.freed) == 0
    assert len(encoder_cache.cached) == 0

    # KVCache Manager.
    kv_manager = scheduler.kv_cache_manager
    for attr in ("req_to_blocks", "req_to_block_hashes", "num_cached_block"):
        assert len(getattr(kv_manager, attr)) == 0

    # The free queue is expected to hold all GPU blocks except one.
    block_pool = kv_manager.block_pool
    free_count = block_pool.free_block_queue.num_free_blocks
    assert free_count == block_pool.num_gpu_blocks - 1

    # NOTE(rob): just the ref count on blocks will be 0. The hash
    # value, etc will remain since we lazily evict for prefix cache,
    # so block._block_hash and block_pool.cached_block_hash_to_block
    # are deliberately not checked here.
    assert all(block.ref_cnt == 0 for block in block_pool.blocks)
def test_memory_leak():
    """Test that we do not have a memory leak."""
    scheduler = create_scheduler(enable_prefix_caching=True)

    NUM_REQUESTS = 5
    NUM_TOKENS = 10
    MAX_TOKENS = 10
    requests = create_requests(
        num_requests=NUM_REQUESTS,
        num_tokens=NUM_TOKENS,
        max_tokens=MAX_TOKENS,
    )

    def apply_step(step_output):
        # Feed a fake model result for the running requests back into
        # the scheduler.
        scheduler.update_from_output(step_output, make_output(scheduler))

    # Add the requests one at a time, running a scheduling step after each.
    for request in requests:
        scheduler.add_request(request)
        apply_step(scheduler.schedule())

    # Keep stepping until a schedule() call leaves nothing running.
    pending_output = scheduler.schedule()
    while len(scheduler.running) != 0:
        apply_step(pending_output)
        pending_output = scheduler.schedule()

    # Confirm no memory leak.
    assert_scheduler_empty(scheduler)
vllm/v1/attention/backends/flash_attn.py
View file @
dee71fba
...
@@ -10,9 +10,11 @@ from vllm import _custom_ops as ops
...
@@ -10,9 +10,11 @@ from vllm import _custom_ops as ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
,
AttentionMetadata
,
AttentionType
,
is_quantized_kv_cache
)
is_quantized_kv_cache
)
from
vllm.attention.layer
import
Attention
from
vllm.attention.ops.merge_attn_states
import
merge_attn_states
from
vllm.attention.ops.merge_attn_states
import
merge_attn_states
from
vllm.attention.utils.fa_utils
import
(
flash_attn_supports_fp8
,
from
vllm.attention.utils.fa_utils
import
(
flash_attn_supports_fp8
,
get_flash_attn_version
)
get_flash_attn_version
)
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
from
vllm.utils
import
cdiv
...
@@ -276,13 +278,23 @@ def make_local_attention_virtual_batches(
...
@@ -276,13 +278,23 @@ def make_local_attention_virtual_batches(
block_table_local
block_table_local
def _get_sliding_window_configs(
        vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
    """Collect the distinct sliding window settings across attention layers.

    Looks up every ``Attention`` layer in ``vllm_config`` and gathers the
    ``sliding_window`` value of each layer's implementation into a set.
    Every layer implementation must be a ``FlashAttentionImpl``.
    """
    attention_layers = get_layers_from_vllm_config(vllm_config, Attention)
    found: set[Optional[tuple[int, int]]] = set()
    for attn_layer in attention_layers.values():
        impl = attn_layer.impl
        assert isinstance(impl, FlashAttentionImpl)
        found.add(impl.sliding_window)
    return found
class
FlashAttentionMetadataBuilder
:
class
FlashAttentionMetadataBuilder
:
def
__init__
(
self
,
runner
:
"GPUModelRunner"
):
def
__init__
(
self
,
runner
:
"GPUModelRunner"
):
model_config
=
runner
.
model_config
model_config
=
runner
.
model_config
self
.
runner
=
runner
self
.
runner
=
runner
self
.
aot_schedule
=
(
get_flash_attn_version
()
==
3
)
self
.
num_heads_q
=
model_config
.
get_num_attention_heads
(
self
.
num_heads_q
=
model_config
.
get_num_attention_heads
(
runner
.
parallel_config
)
runner
.
parallel_config
)
self
.
num_heads_kv
=
model_config
.
get_num_kv_heads
(
self
.
num_heads_kv
=
model_config
.
get_num_kv_heads
(
...
@@ -290,6 +302,11 @@ class FlashAttentionMetadataBuilder:
...
@@ -290,6 +302,11 @@ class FlashAttentionMetadataBuilder:
self
.
headdim
=
model_config
.
get_head_size
()
self
.
headdim
=
model_config
.
get_head_size
()
self
.
page_size
=
self
.
runner
.
block_size
self
.
page_size
=
self
.
runner
.
block_size
self
.
aot_schedule
=
(
get_flash_attn_version
()
==
3
)
# Sliding window size to be used with the AOT scheduler will be
# populated on first build() call.
self
.
aot_sliding_window
:
Optional
[
tuple
[
int
,
int
]]
=
None
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
return
False
return
False
...
@@ -307,6 +324,22 @@ class FlashAttentionMetadataBuilder:
...
@@ -307,6 +324,22 @@ class FlashAttentionMetadataBuilder:
slot_mapping
=
self
.
runner
.
slot_mapping_cpu
[:
num_actual_tokens
].
to
(
slot_mapping
=
self
.
runner
.
slot_mapping_cpu
[:
num_actual_tokens
].
to
(
self
.
runner
.
device
,
non_blocking
=
True
).
long
()
self
.
runner
.
device
,
non_blocking
=
True
).
long
()
if
self
.
aot_sliding_window
is
None
:
self
.
aot_sliding_window
=
(
-
1
,
-
1
)
# For the AOT scheduler we need the sliding window value to be
# constant for all layers. We have to populate this on the first
# build() call, once the layers are constructed (it cannot be
# populated in __init__).
if
self
.
aot_schedule
:
sliding_window_configs
=
_get_sliding_window_configs
(
self
.
runner
.
vllm_config
)
if
len
(
sliding_window_configs
)
==
1
:
sliding_window_config
=
sliding_window_configs
.
pop
()
if
sliding_window_config
is
not
None
:
self
.
aot_sliding_window
=
sliding_window_config
elif
len
(
sliding_window_configs
)
>
1
:
self
.
aot_schedule
=
False
def
schedule
(
batch_size
,
cu_query_lens
,
max_query_len
,
seqlens
,
def
schedule
(
batch_size
,
cu_query_lens
,
max_query_len
,
seqlens
,
max_seq_len
,
causal
):
max_seq_len
,
causal
):
if
self
.
aot_schedule
:
if
self
.
aot_schedule
:
...
@@ -321,6 +354,7 @@ class FlashAttentionMetadataBuilder:
...
@@ -321,6 +354,7 @@ class FlashAttentionMetadataBuilder:
page_size
=
self
.
page_size
,
page_size
=
self
.
page_size
,
cu_seqlens_q
=
cu_query_lens
,
cu_seqlens_q
=
cu_query_lens
,
causal
=
causal
,
causal
=
causal
,
window_size
=
self
.
aot_sliding_window
,
)
)
return
None
return
None
...
...
vllm/v1/core/sched/scheduler.py
View file @
dee71fba
...
@@ -739,7 +739,10 @@ class Scheduler(SchedulerInterface):
...
@@ -739,7 +739,10 @@ class Scheduler(SchedulerInterface):
# Return the cached request data to the queue so they can be reused.
# Return the cached request data to the queue so they can be reused.
for
req_data
in
scheduler_output
.
scheduled_cached_reqs
:
for
req_data
in
scheduler_output
.
scheduled_cached_reqs
:
self
.
_cached_reqs_data
[
req_data
.
req_id
].
append
(
req_data
)
# NOTE(rob): since we free stopped reqs above, adding stopped reqs
# to _cached_reqs_data will cause a memory leak.
if
req_data
.
req_id
not
in
self
.
finished_req_ids
:
self
.
_cached_reqs_data
[
req_data
.
req_id
].
append
(
req_data
)
self
.
running
=
new_running
self
.
running
=
new_running
engine_core_outputs
=
EngineCoreOutputs
(
engine_core_outputs
=
EngineCoreOutputs
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment