Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3dcb3e8b
Unverified
Commit
3dcb3e8b
authored
Apr 04, 2024
by
SangBin Cho
Committed by
GitHub
Apr 03, 2024
Browse files
[3/N] Refactor scheduler for chunked prefill scheduling (#3550)
parent
c64cf386
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
1020 additions
and
255 deletions
+1020
-255
tests/core/test_scheduler.py
tests/core/test_scheduler.py
+532
-6
tests/core/utils.py
tests/core/utils.py
+13
-6
vllm/core/scheduler.py
vllm/core/scheduler.py
+456
-241
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+1
-1
vllm/utils.py
vllm/utils.py
+18
-1
No files found.
tests/core/test_scheduler.py
View file @
3dcb3e8b
import
time
import
time
from
collections
import
deque
from
typing
import
List
from
typing
import
List
from
unittest.mock
import
MagicMock
import
pytest
# noqa
import
pytest
# noqa
from
vllm.config
import
CacheConfig
,
SchedulerConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.core.scheduler
import
Scheduler
from
vllm.core.interfaces
import
AllocStatus
from
vllm.core.policy
import
PolicyFactory
from
vllm.core.scheduler
import
Scheduler
,
SchedulingBudget
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
Logprob
,
SequenceGroup
from
vllm.sequence
import
Logprob
,
SequenceGroup
from
.utils
import
create_dummy_prompt
from
.utils
import
create_dummy_prompt
...
@@ -177,7 +182,6 @@ def test_scheduler_max_seqs():
...
@@ -177,7 +182,6 @@ def test_scheduler_max_seqs():
def
test_scheduler_delay_factor
():
def
test_scheduler_delay_factor
():
block_size
=
4
block_size
=
4
scheduler_config
=
SchedulerConfig
(
100
,
64
,
16
,
delay_factor
=
0.5
)
scheduler_config
=
SchedulerConfig
(
100
,
64
,
16
,
delay_factor
=
0.5
)
cache_config
=
CacheConfig
(
block_size
,
1.0
,
1
,
"auto"
)
cache_config
=
CacheConfig
(
block_size
,
1.0
,
1
,
"auto"
)
...
@@ -189,7 +193,7 @@ def test_scheduler_delay_factor():
...
@@ -189,7 +193,7 @@ def test_scheduler_delay_factor():
_
,
seq_group
=
create_dummy_prompt
(
"0"
,
prompt_length
=
block_size
)
_
,
seq_group
=
create_dummy_prompt
(
"0"
,
prompt_length
=
block_size
)
scheduler
.
add_seq_group
(
seq_group
)
scheduler
.
add_seq_group
(
seq_group
)
seq_group_meta
,
out
=
scheduler
.
schedule
()
seq_group_meta
,
out
=
scheduler
.
schedule
()
assert
out
.
prompt_run
assert
out
.
num_prefill_groups
>
0
assert
seq_group_meta
[
0
].
request_id
==
'0'
assert
seq_group_meta
[
0
].
request_id
==
'0'
# wait for a second before scheduling next prompt
# wait for a second before scheduling next prompt
...
@@ -199,11 +203,533 @@ def test_scheduler_delay_factor():
...
@@ -199,11 +203,533 @@ def test_scheduler_delay_factor():
# second prompt should *not* be scheduled
# second prompt should *not* be scheduled
seq_group_meta
,
out
=
scheduler
.
schedule
()
seq_group_meta
,
out
=
scheduler
.
schedule
()
assert
not
out
.
prompt_run
assert
out
.
num_prefill_groups
==
0
assert
seq_group_meta
[
0
].
request_id
==
'0'
assert
seq_group_meta
[
0
].
request_id
==
'0'
# wait for more than 0.5 second and try again
# wait for more than 0.5 second and try again
time
.
sleep
(
0.6
)
time
.
sleep
(
0.6
)
seq_group_meta
,
out
=
scheduler
.
schedule
()
seq_group_meta
,
out
=
scheduler
.
schedule
()
assert
out
.
prompt_run
assert
out
.
num_prefill_groups
>
0
assert
seq_group_meta
[
0
].
request_id
==
'1'
assert
seq_group_meta
[
0
].
request_id
==
'1'
def test_swapped_out_prioritized():
    """A swapped-out group is swapped back in before a new prefill runs."""
    scheduler = initialize_scheduler(max_num_seqs=6)
    # Three groups with best_of=2 each, i.e. 2 * 3 == 6 sequences.
    for request_idx in range(3):
        group = create_dummy_prompt(str(request_idx),
                                    prompt_length=60,
                                    best_of=2)[1]
        scheduler.add_seq_group(group)
    _, out = scheduler.schedule()
    # All three prefills are scheduled in the first step.
    assert len(out.scheduled_seq_groups) == 3

    # Refuse to append slots for request "2" so that it gets swapped out.
    # Parameter names are kept as-is in case the scheduler passes keywords.
    def can_append_unless_last(seq_group, num_lookahead_slots):
        return seq_group.request_id != "2"

    scheduler.block_manager.can_append_slots = MagicMock(
        side_effect=can_append_unless_last)

    _, out = scheduler.schedule()
    assert len(out.scheduled_seq_groups) == 2
    assert out.num_batched_tokens == 2
    assert out.blocks_to_swap_out != {}
    assert out.blocks_to_swap_in == {}

    # Queue one more request; the swap-in must win over the new prefill.
    # The original reused the loop variable here, which is 2 after range(3),
    # so the new request carries the id "2" as well.
    group = create_dummy_prompt("2", prompt_length=60, best_of=2)[1]
    scheduler.add_seq_group(group)
    _, out = scheduler.schedule()
    assert len(out.scheduled_seq_groups) == 3
    # Three decodes: the swapped-out group is back in.
    assert out.num_batched_tokens == 3
    assert out.blocks_to_swap_in != {}
    assert out.blocks_to_swap_out == {}
def initialize_scheduler(*,
                         max_num_seqs=1000,
                         max_token_budget=1000,
                         max_model_len=1000,
                         lora_config=None):
    """Build a Scheduler with a small cache for the tests in this module.

    Args:
        max_num_seqs: Forwarded to SchedulerConfig (second positional).
        max_token_budget: Forwarded to SchedulerConfig (first positional).
        max_model_len: Forwarded to SchedulerConfig (third positional).
        lora_config: Optional LoRA config handed to the Scheduler.

    Returns:
        A Scheduler whose cache config has block size 4 and 8 CPU / 8 GPU
        blocks.
    """
    sched_cfg = SchedulerConfig(max_token_budget, max_num_seqs, max_model_len)
    cache_cfg = CacheConfig(4, 1.0, 1, "auto")
    # Block counts are set by hand after construction.
    cache_cfg.num_cpu_blocks = 8
    cache_cfg.num_gpu_blocks = 8
    return Scheduler(sched_cfg, cache_cfg, lora_config)
def create_token_budget(num_batched_tokens: int = 0,
                        num_curr_seqs: int = 0,
                        token_budget: int = 10000,
                        max_num_seqs: int = 10000) -> SchedulingBudget:
    """Create a SchedulingBudget with generous defaults for tests.

    Every argument is forwarded unchanged, by keyword, to SchedulingBudget.
    """
    budget_fields = {
        "num_batched_tokens": num_batched_tokens,
        "num_curr_seqs": num_curr_seqs,
        "token_budget": token_budget,
        "max_num_seqs": max_num_seqs,
    }
    return SchedulingBudget(**budget_fields)
def test_prefill_schedule_max_prompt_len():
    """
    Test that a prompt longer than the scheduler's limit is ignored.

    The scheduler is built with max_model_len=30 and receives a 60-token
    prompt; the group must land in ignored_seq_groups, consume no budget and
    leave the waiting queue.
    """
    scheduler = initialize_scheduler(max_model_len=30)
    # Request ids are strings everywhere else in this module; pass "0"
    # rather than the int 0 so the group carries a str request_id.
    _, seq_group = create_dummy_prompt("0", prompt_length=60)
    waiting = deque([seq_group])
    budget = create_token_budget()
    remaining_waiting, output = scheduler._schedule_prefills(
        waiting, budget, None)
    assert len(output.ignored_seq_groups) == 1
    assert len(output.seq_groups) == 0
    # Nothing was scheduled, so the budget is untouched.
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(remaining_waiting) == 0
def test_prefill_schedule_token_budget():
    """
    Test that prefill scheduling never exceeds the token budget.
    """
    scheduler = initialize_scheduler()
    budget = create_token_budget(token_budget=0)
    waiting = deque()
    for request_idx in range(2):
        waiting.append(
            create_dummy_prompt(str(request_idx), prompt_length=60)[1])

    # With a token budget of 0 nothing can be scheduled.
    remaining_waiting, output = scheduler._schedule_prefills(
        waiting, budget, None)
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(remaining_waiting) == 2

    # A budget of 60 tokens fits exactly one 60-token prompt.
    budget = create_token_budget(token_budget=60)
    remaining_waiting, output = scheduler._schedule_prefills(
        waiting, budget, None)
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 1
    assert budget.num_batched_tokens == 60
    assert budget.num_curr_seqs == 1
    assert len(remaining_waiting) == 1

    # Tokens already batched count against the budget too.
    scheduler = initialize_scheduler()
    budget = create_token_budget(num_batched_tokens=30, token_budget=60)
    # The original reused the loop variable here (1 after range(2)).
    late_group = create_dummy_prompt("1", prompt_length=60)[1]
    # 30 already used + 60 new > 60: the prompt does not fit.
    waiting = deque([late_group])
    remaining_waiting, output = scheduler._schedule_prefills(
        waiting, budget, None)
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 30
    assert budget.num_curr_seqs == 0
    assert len(remaining_waiting) == 1

    # 30 already used + 60 new == 90: now it fits.
    budget = create_token_budget(num_batched_tokens=30, token_budget=90)
    remaining_waiting, output = scheduler._schedule_prefills(
        waiting, budget, None)
    assert len(output.seq_groups) == 1
    assert budget.num_batched_tokens == 90
    assert budget.num_curr_seqs == 1
    assert len(remaining_waiting) == 0
def test_prefill_schedule_max_seqs():
    """
    Test that prefill scheduling never exceeds max_num_seqs.
    """
    scheduler = initialize_scheduler()
    budget = create_token_budget(max_num_seqs=2)
    waiting = deque()
    for request_idx in range(3):
        waiting.append(
            create_dummy_prompt(str(request_idx), prompt_length=60)[1])

    # Only two of the three prompts may be scheduled.
    remaining_waiting, output = scheduler._schedule_prefills(
        waiting, budget, None)
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 2
    assert budget.num_batched_tokens == 120
    assert budget.num_curr_seqs == 2
    assert len(remaining_waiting) == 1

    # Sequences already counted in the budget are respected as well.
    budget = create_token_budget(num_curr_seqs=2, max_num_seqs=2)
    # The original reused the loop variable here (2 after range(3)).
    waiting = deque([create_dummy_prompt("2", prompt_length=60)[1]])
    remaining_waiting, output = scheduler._schedule_prefills(
        waiting, budget, None)
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 2
    assert len(remaining_waiting) == 1
def test_prefill_schedule_max_lora():
    """
    Test that max_loras is respected and skipped LoRA requests go first.
    """
    scheduler = initialize_scheduler(
        lora_config=LoRAConfig(max_lora_rank=8, max_loras=1))
    budget = create_token_budget(token_budget=120)
    curr_loras = set()
    waiting = deque()

    # Requests 0 and 1 each use a different LoRA.
    for request_idx in range(2):
        lora = LoRARequest(lora_name=str(request_idx),
                           lora_int_id=request_idx + 1,
                           lora_local_path="abc")
        waiting.append(
            create_dummy_prompt(str(request_idx),
                                prompt_length=60,
                                lora_request=lora)[1])

    # Requests 2 and 3 are regular (no LoRA), giving the queue
    # 0: LoRA, 1: LoRA, 2: regular, 3: regular.
    # The first pass should pick 0 and 2; request 1 is skipped because it
    # would exceed max_loras and must be prioritized afterwards.
    for request_idx in (2, 3):
        waiting.append(
            create_dummy_prompt(str(request_idx), prompt_length=60)[1])

    # First pass: two requests scheduled (0 and 2).
    remaining_waiting, output = scheduler._schedule_prefills(
        waiting, budget, curr_loras)
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 2
    assert budget.num_batched_tokens == 120
    assert budget.num_curr_seqs == 2
    assert len(remaining_waiting) == 2
    assert len(curr_loras) == 1

    # Second pass: with curr_loras reset, the skipped LoRA request ("1")
    # is the next one scheduled under FCFS.
    curr_loras = set()
    budget = create_token_budget(token_budget=60)
    remaining_waiting, output = scheduler._schedule_prefills(
        remaining_waiting, budget, curr_loras)
    assert len(output.seq_groups) == 1
    assert output.seq_groups[0].seq_group.request_id == "1"
    assert len(remaining_waiting) == 1
    assert len(curr_loras) == 1
    assert budget.num_batched_tokens == 60
def test_prefill_schedule_no_block_manager_capacity():
    """
    Test prefill scheduling when the block manager cannot allocate.
    """
    # Case 1: AllocStatus.LATER -> every request stays in the waiting queue.
    scheduler = initialize_scheduler()
    budget = create_token_budget()
    waiting = deque()
    for request_idx in range(3):
        waiting.append(
            create_dummy_prompt(str(request_idx), prompt_length=60)[1])
    scheduler.block_manager.can_allocate = MagicMock(
        return_value=AllocStatus.LATER)
    remaining_waiting, output = scheduler._schedule_prefills(
        waiting, budget, None)
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(remaining_waiting) == 3

    # Case 2: AllocStatus.NEVER -> every request is ignored and dropped
    # from the waiting queue.
    scheduler = initialize_scheduler()
    budget = create_token_budget()
    waiting = deque()
    for request_idx in range(3):
        waiting.append(
            create_dummy_prompt(str(request_idx), prompt_length=60)[1])
    scheduler.block_manager.can_allocate = MagicMock(
        return_value=AllocStatus.NEVER)
    remaining_waiting, output = scheduler._schedule_prefills(
        waiting, budget, None)
    assert len(output.ignored_seq_groups) == 3
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(remaining_waiting) == 0
def test_decode_schedule_preempted():
    """
    Test that decodes which cannot get slots are preempted.
    """
    scheduler = initialize_scheduler()
    policy = PolicyFactory.get_policy(policy_name="fcfs")
    curr_loras = None
    running = deque()
    for request_idx in range(3):
        group = create_dummy_prompt(str(request_idx), prompt_length=60)[1]
        scheduler._allocate_and_set_running(group)
        running.append(group)

    # Slots can be appended for every request except "1".
    # Parameter names are kept as-is in case the scheduler passes keywords.
    def can_append_unless_request_1(seq_group, num_lookahead_slots):
        return seq_group.request_id != "1"

    scheduler.block_manager.can_append_slots = MagicMock(
        side_effect=can_append_unless_request_1)

    # Request 1 cannot be scheduled, so the lowest-priority request (2) is
    # preempted first; request 1 ends up preempted as well.
    budget = create_token_budget(num_batched_tokens=3, num_curr_seqs=3)
    remaining_running, output = scheduler._schedule_decodes(
        running, budget, curr_loras, policy)
    assert len(remaining_running) == 0
    assert len(output.seq_groups) == 1
    assert output.seq_groups[0].seq_group.request_id == "0"
    assert len(output.preempted) == 2
    # The budget must reflect the two preempted requests.
    assert budget.num_batched_tokens == 1
    assert budget.num_curr_seqs == 1
    # Both victims are preempted rather than swapped out.
    assert output.blocks_to_swap_out == {}
    # No blocks need copying.
    assert output.blocks_to_copy == {}
def test_decode_swap_beam_search():
    """
    Test that a best_of > 1 group is swapped out instead of preempted.
    """
    scheduler = initialize_scheduler()
    policy = PolicyFactory.get_policy(policy_name="fcfs")
    curr_loras = None
    running = deque()
    for request_idx in range(3):
        group = create_dummy_prompt(str(request_idx),
                                    prompt_length=60,
                                    best_of=2)[1]
        scheduler._allocate_and_set_running(group)
        running.append(group)

    # Refuse slots for request "2" so that it has to leave the batch.
    # Parameter names are kept as-is in case the scheduler passes keywords.
    def can_append_unless_last(seq_group, num_lookahead_slots):
        return seq_group.request_id != "2"

    scheduler.block_manager.can_append_slots = MagicMock(
        side_effect=can_append_unless_last)
    # Stub swap_out so the resulting block mapping is known in advance.
    expected_swap_mapping = {"5": "7"}
    scheduler.block_manager.swap_out = MagicMock(
        return_value=expected_swap_mapping)

    budget = create_token_budget(num_batched_tokens=3, num_curr_seqs=3)
    remaining_running, output = scheduler._schedule_decodes(
        running, budget, curr_loras, policy)
    assert len(remaining_running) == 0
    assert len(output.seq_groups) == 2
    assert output.seq_groups[0].seq_group.request_id == "0"
    assert output.seq_groups[1].seq_group.request_id == "1"
    assert len(output.preempted) == 0
    assert len(output.swapped_out) == 1
    # The budget must reflect the request that left the batch.
    assert budget.num_batched_tokens == 2
    # The swapped-out group was created with best_of=2, and the sequence
    # count drops by 2 (3 -> 1).
    assert budget.num_curr_seqs == 1
    # The group is swapped out, so the stubbed mapping is reported.
    assert output.blocks_to_swap_out == expected_swap_mapping
    # No blocks need copying.
    assert output.blocks_to_copy == {}
def test_schedule_decode_blocks_to_copy_update():
    """
    Verify that decode scheduling reports the blocks_to_copy mapping.
    """
    scheduler = initialize_scheduler()
    group = create_dummy_prompt("1", prompt_length=60, best_of=2)[1]
    policy = PolicyFactory.get_policy(policy_name="fcfs")
    curr_loras = None
    scheduler._allocate_and_set_running(group)
    running = deque([group])

    # Stub append_slots so it reports one source -> destinations copy.
    scheduler.block_manager.append_slots = MagicMock(return_value={2: [3]})

    budget = create_token_budget()
    remaining_running, output = scheduler._schedule_decodes(
        running, budget, curr_loras, policy)
    assert len(remaining_running) == 0
    assert len(output.seq_groups) == 1
    assert len(output.preempted) == 0
    assert len(output.swapped_out) == 0
    # Nothing is preempted, hence nothing to swap out.
    assert output.blocks_to_swap_out == {}
    # The mapping returned by append_slots must show up in the output.
    assert output.blocks_to_copy == {2: [3]}
def test_schedule_swapped_simple():
    """A single swapped-out group is swapped back in with mirrored blocks."""
    scheduler = initialize_scheduler()
    policy = PolicyFactory.get_policy(policy_name="fcfs")
    curr_loras = None
    blocks_to_swap_out = {}
    group = create_dummy_prompt("1", prompt_length=60, best_of=2)[1]
    scheduler._allocate_and_set_running(group)
    # _swap_out is handed blocks_to_swap_out; it is compared against below.
    scheduler._swap_out(group, blocks_to_swap_out)
    swapped = deque([group])

    budget = create_token_budget()
    remaining_swapped, output = scheduler._schedule_swapped(
        swapped, budget, curr_loras, policy)
    assert len(remaining_swapped) == 0
    assert budget.num_batched_tokens == 1
    assert budget.num_curr_seqs == 2
    assert len(output.seq_groups) == 1
    # Inverting the swap-in mapping must give back the swap-out mapping.
    inverted_swap_in = {
        value: key
        for key, value in output.blocks_to_swap_in.items()
    }
    assert blocks_to_swap_out == inverted_swap_in
def test_schedule_swapped_max_token_budget():
    """Swap-in scheduling stops once the token budget is used up."""
    scheduler = initialize_scheduler()
    policy = PolicyFactory.get_policy(policy_name="fcfs")
    curr_loras = None
    blocks_to_swap_out = {}
    swapped = deque()
    # Two swapped-out groups, both created with request id "1".
    for _unused in range(2):
        group = create_dummy_prompt("1", prompt_length=60, best_of=2)[1]
        scheduler._allocate_and_set_running(group)
        scheduler._swap_out(group, blocks_to_swap_out)
        swapped.append(group)

    # A budget of one token admits exactly one group.
    budget = create_token_budget(token_budget=1)
    remaining_swapped, output = scheduler._schedule_swapped(
        swapped, budget, curr_loras, policy)
    assert len(remaining_swapped) == 1
    assert budget.num_batched_tokens == 1
    assert budget.num_curr_seqs == 2
    assert len(output.seq_groups) == 1

    # With the single token already spent, nothing more is swapped in.
    budget = create_token_budget(num_batched_tokens=1, token_budget=1)
    remaining_swapped, output = scheduler._schedule_swapped(
        remaining_swapped, budget, curr_loras, policy)
    assert len(remaining_swapped) == 1
    assert budget.num_batched_tokens == 1
    assert budget.num_curr_seqs == 0
    assert len(output.seq_groups) == 0
def test_schedule_swapped_max_seqs():
    """Swap-in scheduling stops once max_num_seqs is reached."""
    scheduler = initialize_scheduler()
    policy = PolicyFactory.get_policy(policy_name="fcfs")
    curr_loras = None
    blocks_to_swap_out = {}
    swapped = deque()
    # Two swapped-out groups, both created with request id "1".
    for _unused in range(2):
        group = create_dummy_prompt("1", prompt_length=60, best_of=2)[1]
        scheduler._allocate_and_set_running(group)
        scheduler._swap_out(group, blocks_to_swap_out)
        swapped.append(group)

    # Each group was created with best_of=2; a cap of 2 admits one group.
    budget = create_token_budget(max_num_seqs=2)
    remaining_swapped, output = scheduler._schedule_swapped(
        swapped, budget, curr_loras, policy)
    assert len(remaining_swapped) == 1
    assert budget.num_batched_tokens == 1
    assert budget.num_curr_seqs == 2
    assert len(output.seq_groups) == 1

    # With the cap already reached, nothing more is swapped in.
    budget = create_token_budget(num_curr_seqs=2, max_num_seqs=2)
    remaining_swapped, output = scheduler._schedule_swapped(
        remaining_swapped, budget, curr_loras, policy)
    assert len(remaining_swapped) == 1
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 2
    assert len(output.seq_groups) == 0
def test_schedule_swapped_max_loras():
    """Swap-in scheduling admits only as many LoRAs as max_loras allows."""
    scheduler = initialize_scheduler(
        lora_config=LoRAConfig(max_lora_rank=8, max_loras=1))
    policy = PolicyFactory.get_policy(policy_name="fcfs")
    curr_loras = set()
    blocks_to_swap_out = {}
    swapped = deque()
    # Two swapped-out groups, each using a different LoRA.
    for request_idx in range(2):
        lora = LoRARequest(lora_name=str(request_idx),
                           lora_int_id=request_idx + 1,
                           lora_local_path="abc")
        group = create_dummy_prompt(str(request_idx),
                                    prompt_length=60,
                                    lora_request=lora)[1]
        scheduler._allocate_and_set_running(group)
        scheduler._swap_out(group, blocks_to_swap_out)
        swapped.append(group)

    # max_loras=1, so only one of the two groups can come back.
    budget = create_token_budget()
    remaining_swapped, output = scheduler._schedule_swapped(
        swapped, budget, curr_loras, policy)
    assert len(remaining_swapped) == 1
    assert budget.num_batched_tokens == 1
    assert budget.num_curr_seqs == 1
    assert len(output.seq_groups) == 1
    assert len(curr_loras) == 1
def test_schedule_swapped_cannot_swap_in():
    """Nothing is swapped in when the block manager refuses swap-in."""
    scheduler = initialize_scheduler()
    policy = PolicyFactory.get_policy(policy_name="fcfs")
    curr_loras = None
    blocks_to_swap_out = {}
    swapped = deque()
    # Two swapped-out groups, both created with request id "1".
    for _unused in range(2):
        group = create_dummy_prompt("1", prompt_length=60, best_of=2)[1]
        scheduler._allocate_and_set_running(group)
        scheduler._swap_out(group, blocks_to_swap_out)
        swapped.append(group)

    # Make the block manager reject every swap-in request.
    scheduler.block_manager.can_swap_in = MagicMock(return_value=False)

    # Both groups must remain swapped and the budget must stay untouched.
    budget = create_token_budget()
    remaining_swapped, output = scheduler._schedule_swapped(
        swapped, budget, curr_loras, policy)
    assert len(remaining_swapped) == 2
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(output.seq_groups) == 0
def test_schedule_swapped_blocks_to_copy():
    """Swap-in scheduling reports the blocks_to_copy from append_slots."""
    scheduler = initialize_scheduler()
    policy = PolicyFactory.get_policy(policy_name="fcfs")
    curr_loras = None
    group = create_dummy_prompt("1", prompt_length=60, best_of=2)[1]
    scheduler._allocate_and_set_running(group)
    blocks_to_swap_out = {}
    scheduler._swap_out(group, blocks_to_swap_out)
    swapped = deque([group])

    # Stub append_slots so it reports one source -> destinations copy.
    scheduler.block_manager.append_slots = MagicMock(return_value={2: [3]})

    budget = create_token_budget()
    remaining_swapped, output = scheduler._schedule_swapped(
        swapped, budget, curr_loras, policy)
    assert len(remaining_swapped) == 0
    assert len(output.seq_groups) == 1
    assert output.blocks_to_copy == {2: [3]}
tests/core/utils.py
View file @
3dcb3e8b
import
time
import
time
from
typing
import
Tuple
from
typing
import
Optional
,
Tuple
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
Logprob
,
Sequence
,
SequenceGroup
from
vllm.sequence
import
Logprob
,
Sequence
,
SequenceGroup
def
create_dummy_prompt
(
def
create_dummy_prompt
(
request_id
:
str
,
request_id
:
str
,
prompt_length
:
int
,
prompt_length
:
int
,
block_size
:
int
=
None
)
->
Tuple
[
Sequence
,
SequenceGroup
]:
block_size
:
Optional
[
int
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
use_beam_search
:
bool
=
False
,
best_of
:
int
=
1
,
)
->
Tuple
[
Sequence
,
SequenceGroup
]:
if
not
block_size
:
if
not
block_size
:
block_size
=
prompt_length
block_size
=
prompt_length
...
@@ -17,8 +22,10 @@ def create_dummy_prompt(
...
@@ -17,8 +22,10 @@ def create_dummy_prompt(
prompt_tokens
=
list
(
range
(
prompt_length
))
prompt_tokens
=
list
(
range
(
prompt_length
))
prompt_str
=
" "
.
join
([
str
(
t
)
for
t
in
prompt_tokens
])
prompt_str
=
" "
.
join
([
str
(
t
)
for
t
in
prompt_tokens
])
prompt
=
Sequence
(
int
(
request_id
),
prompt_str
,
prompt_tokens
,
block_size
)
prompt
=
Sequence
(
int
(
request_id
),
prompt_str
,
prompt_tokens
,
block_size
)
seq_group
=
SequenceGroup
(
request_id
,
[
prompt
],
SamplingParams
(),
seq_group
=
SequenceGroup
(
time
.
time
(),
None
)
request_id
,
[
prompt
],
SamplingParams
(
use_beam_search
=
use_beam_search
,
best_of
=
best_of
),
time
.
time
(),
lora_request
)
return
prompt
,
seq_group
return
prompt
,
seq_group
...
...
vllm/core/scheduler.py
View file @
3dcb3e8b
...
@@ -6,11 +6,12 @@ from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
...
@@ -6,11 +6,12 @@ from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.core.interfaces
import
AllocStatus
,
BlockSpaceManager
from
vllm.core.interfaces
import
AllocStatus
,
BlockSpaceManager
from
vllm.core.policy
import
PolicyFactory
from
vllm.core.policy
import
Policy
,
PolicyFactory
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
(
Sequence
,
SequenceData
,
SequenceGroup
,
from
vllm.sequence
import
(
Sequence
,
SequenceData
,
SequenceGroup
,
SequenceGroupMetadata
,
SequenceStatus
)
SequenceGroupMetadata
,
SequenceStatus
)
from
vllm.utils
import
merge_dicts
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -28,9 +29,19 @@ class PreemptionMode(enum.Enum):
...
@@ -28,9 +29,19 @@ class PreemptionMode(enum.Enum):
RECOMPUTE
=
enum
.
auto
()
RECOMPUTE
=
enum
.
auto
()
# seq_group: SequenceGroup to schedule.
@
dataclass
# token_chunk_size: The number of prefill tokens to be processed in the next
class
SchedulingBudget
:
# step.
"""The available slots for scheduling."""
num_batched_tokens
:
int
num_curr_seqs
:
int
token_budget
:
int
max_num_seqs
:
int
def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int):
    """Return whether the extra tokens and sequences fit in this budget.

    Args:
        num_new_tokens: Tokens to add on top of num_batched_tokens.
        num_new_seqs: Sequences to add on top of num_curr_seqs.

    Returns:
        True only if the token total stays within token_budget and the
        sequence total stays within max_num_seqs.
    """
    tokens_after = self.num_batched_tokens + num_new_tokens
    if not tokens_after <= self.token_budget:
        return False
    seqs_after = self.num_curr_seqs + num_new_seqs
    return seqs_after <= self.max_num_seqs
@
dataclass
@
dataclass
class
ScheduledSequenceGroup
:
class
ScheduledSequenceGroup
:
# A sequence group that's scheduled.
# A sequence group that's scheduled.
...
@@ -41,53 +52,28 @@ class ScheduledSequenceGroup:
...
@@ -41,53 +52,28 @@ class ScheduledSequenceGroup:
token_chunk_size
:
int
token_chunk_size
:
int
@
dataclass
class
SchedulerOutputs
:
class
SchedulerOutputs
:
# Scheduled sequence groups.
def
__init__
(
scheduled_seq_groups
:
Iterable
[
ScheduledSequenceGroup
]
self
,
# Number of prefill groups scheduled.
scheduled_seq_groups
:
Iterable
[
ScheduledSequenceGroup
],
num_prefill_groups
:
int
prompt_run
:
bool
,
num_batched_tokens
:
int
,
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
ignored_seq_groups
:
List
[
SequenceGroup
],
num_lookahead_slots
:
int
,
)
->
None
:
"""A list of sequence groups to be scheduled as a single batch.
Args:
scheduled_seq_groups: A tuple of scheduled sequence group and its
token chunk size.
prompt_run: True if all sequence groups are in prefill phase.
If False, all sequence groups are in decoding phase.
num_batched_tokens: Total number of batched tokens.
blocks_to_swap_in: Blocks to swap in. Dict of CPU -> GPU block
number.
blocks_to_swap_out: Blocks to swap out. Dict of GPU -> CPU block
number.
blocks_to_copy: Blocks to copy. Source to a list of dest blocks.
ignored_seq_groups: Sequence groups that are going to be ignored.
"""
# A tuple of scheduled sequence group and its chunk size.
self
.
scheduled_seq_groups
:
ScheduledSequenceGroup
=
scheduled_seq_groups
# True if all sequence groups are in prefill phase. If False, all
# sequence groups are in decoding phase.
self
.
prompt_run
:
bool
=
prompt_run
# Total number of batched tokens.
# Total number of batched tokens.
self
.
num_batched_tokens
:
int
=
num_batched_tokens
num_batched_tokens
:
int
# Blocks to swap in. Dict of CPU -> GPU block number.
# Blocks to swap in. Dict of CPU -> GPU block number.
self
.
blocks_to_swap_in
:
Dict
[
int
,
int
]
=
blocks_to_swap_in
blocks_to_swap_in
:
Dict
[
int
,
int
]
# Blocks to swap out. Dict of GPU -> CPU block number.
# Blocks to swap out. Dict of GPU -> CPU block number.
self
.
blocks_to_swap_out
:
Dict
[
int
,
int
]
=
blocks_to_swap_out
blocks_to_swap_out
:
Dict
[
int
,
int
]
# Blocks to copy. Source to a list of dest blocks.
# Blocks to copy. Source to a list of dest blocks.
self
.
blocks_to_copy
:
Dict
[
int
,
List
[
int
]]
=
blocks_to_copy
blocks_to_copy
:
Dict
[
int
,
List
[
int
]]
# Sequence groups that are going to be ignored.
# Sequence groups that are going to be ignored.
self
.
ignored_seq_groups
:
List
[
SequenceGroup
]
=
ignored_seq_groups
ignored_seq_groups
:
List
[
SequenceGroup
]
# The number of slots for lookahead decoding.
num_lookahead_slots
:
int
def
__post_init__
(
self
):
# Swap in and swap out should never happen at the same time.
# Swap in and swap out should never happen at the same time.
assert
not
(
blocks_to_swap_in
and
blocks_to_swap_out
)
assert
not
(
self
.
blocks_to_swap_in
and
self
.
blocks_to_swap_out
)
self
.
num_lookahead_slots
=
num_lookahead_slots
self
.
num_loras
:
int
=
len
(
self
.
lora_requests
)
self
.
num_loras
:
int
=
len
(
self
.
lora_requests
)
if
self
.
num_loras
>
0
:
if
self
.
num_loras
>
0
:
...
@@ -108,6 +94,73 @@ class SchedulerOutputs:
...
@@ -108,6 +94,73 @@ class SchedulerOutputs:
return
{
g
.
seq_group
.
lora_request
for
g
in
self
.
scheduled_seq_groups
}
return
{
g
.
seq_group
.
lora_request
for
g
in
self
.
scheduled_seq_groups
}
@dataclass
class SchedulerDecodeOutputs:
    """Outputs of the decoding phase of the scheduler."""
    # Selected sequence groups for decoding.
    # NOTE(review): tests in this commit read `.seq_group` on the items, so
    # the elements may be ScheduledSequenceGroup rather than SequenceGroup;
    # confirm and tighten the annotation.
    seq_groups: List[SequenceGroup]
    # The preempted sequences.
    preempted: List[SequenceGroup]
    # Sequences that are swapped out.
    swapped_out: List[SequenceGroup]
    # The blocks to swap out.
    blocks_to_swap_out: Dict[int, int]
    # The blocks to copy.
    blocks_to_copy: Dict[int, List[int]]
    # Number of lookahead slots; create_empty() sets it to 0.
    num_lookahead_slots: int

    @classmethod
    def create_empty(cls) -> "SchedulerDecodeOutputs":
        """Return an instance with empty collections and 0 lookahead slots."""
        # Build through `cls` so that a subclass gets an instance of itself.
        return cls(
            seq_groups=[],
            preempted=[],
            swapped_out=[],
            blocks_to_swap_out={},
            blocks_to_copy={},
            num_lookahead_slots=0,
        )
@dataclass
class SchedulerSwappedInOutputs:
    """Outputs of the swap-in phase of the scheduler."""
    # Sequence groups selected to be swapped in.
    seq_groups: List[SequenceGroup]
    # The blocks to swap in.
    blocks_to_swap_in: Dict[int, int]
    # The blocks to copy.
    blocks_to_copy: Dict[int, List[int]]
    # Number of lookahead slots; create_empty() sets it to 0.
    num_lookahead_slots: int

    @classmethod
    def create_empty(cls) -> "SchedulerSwappedInOutputs":
        """Return an instance with empty collections and 0 lookahead slots."""
        # Build through `cls` so that a subclass gets an instance of itself.
        return cls(
            seq_groups=[],
            blocks_to_swap_in={},
            blocks_to_copy={},
            num_lookahead_slots=0,
        )
@dataclass
class SchedulerPrefillOutputs:
    """Outputs of the prefill phase of the scheduler."""
    # Selected sequence groups for prefill.
    # NOTE(review): tests in this commit read `.seq_group` on the items, so
    # the elements may be ScheduledSequenceGroup rather than SequenceGroup;
    # confirm and tighten the annotation.
    seq_groups: List[SequenceGroup]
    # Ignored sequence groups.
    ignored_seq_groups: List[SequenceGroup]
    # Number of lookahead slots; create_empty() sets it to 0.
    num_lookahead_slots: int

    @classmethod
    def create_empty(cls) -> "SchedulerPrefillOutputs":
        """Return an instance with empty lists and 0 lookahead slots."""
        # Build through `cls` so that a subclass gets an instance of itself.
        return cls(
            seq_groups=[],
            ignored_seq_groups=[],
            num_lookahead_slots=0,
        )
class
Scheduler
:
class
Scheduler
:
def
__init__
(
def
__init__
(
...
@@ -123,6 +176,7 @@ class Scheduler:
...
@@ -123,6 +176,7 @@ class Scheduler:
# LoRAs. This should be improved in the future.
# LoRAs. This should be improved in the future.
self
.
lora_config
=
lora_config
self
.
lora_config
=
lora_config
# TODO(sang): Fix it after chunked prefill is enabled.
self
.
prompt_limit
=
min
(
self
.
scheduler_config
.
max_model_len
,
self
.
prompt_limit
=
min
(
self
.
scheduler_config
.
max_model_len
,
self
.
scheduler_config
.
max_num_batched_tokens
)
self
.
scheduler_config
.
max_num_batched_tokens
)
...
@@ -142,10 +196,13 @@ class Scheduler:
...
@@ -142,10 +196,13 @@ class Scheduler:
enable_caching
=
self
.
cache_config
.
enable_prefix_caching
)
enable_caching
=
self
.
cache_config
.
enable_prefix_caching
)
# Sequence groups in the WAITING state.
# Sequence groups in the WAITING state.
# Contain new prefill or preempted requests.
self
.
waiting
:
Deque
[
SequenceGroup
]
=
deque
()
self
.
waiting
:
Deque
[
SequenceGroup
]
=
deque
()
# Sequence groups in the RUNNING state.
# Sequence groups in the RUNNING state.
# Contain decode requests.
self
.
running
:
Deque
[
SequenceGroup
]
=
deque
()
self
.
running
:
Deque
[
SequenceGroup
]
=
deque
()
# Sequence groups in the SWAPPED state.
# Sequence groups in the SWAPPED state.
# Contain decode requests that are swapped out.
self
.
swapped
:
Deque
[
SequenceGroup
]
=
deque
()
self
.
swapped
:
Deque
[
SequenceGroup
]
=
deque
()
# Time at previous scheduling step
# Time at previous scheduling step
...
@@ -159,8 +216,14 @@ class Scheduler:
...
@@ -159,8 +216,14 @@ class Scheduler:
def
lora_enabled
(
self
)
->
bool
:
def
lora_enabled
(
self
)
->
bool
:
return
bool
(
self
.
lora_config
)
return
bool
(
self
.
lora_config
)
@
property
def
num_decoding_tokens_per_seq
(
self
)
->
int
:
"""The number of new tokens counted for each sequence in a decode step."""
return
1
def
add_seq_group
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
def
add_seq_group
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
# Add sequence groups to the waiting queue.
# Add sequence groups to the waiting queue.
logger
.
debug
(
f
"add_seq_group
{
seq_group
.
request_id
}
"
)
self
.
waiting
.
append
(
seq_group
)
self
.
waiting
.
append
(
seq_group
)
def
abort_seq_group
(
self
,
request_id
:
Union
[
str
,
Iterable
[
str
]])
->
None
:
def
abort_seq_group
(
self
,
request_id
:
Union
[
str
,
Iterable
[
str
]])
->
None
:
...
@@ -205,50 +268,237 @@ class Scheduler:
...
@@ -205,50 +268,237 @@ class Scheduler:
def
get_num_unfinished_seq_groups
(
self
)
->
int
:
def
get_num_unfinished_seq_groups
(
self
)
->
int
:
return
len
(
self
.
waiting
)
+
len
(
self
.
running
)
+
len
(
self
.
swapped
)
return
len
(
self
.
waiting
)
+
len
(
self
.
running
)
+
len
(
self
.
swapped
)
def
_schedule
(
self
)
->
SchedulerOutputs
:
def _schedule_decodes(
    self,
    running_queue: deque,
    budget: SchedulingBudget,
    curr_loras: Optional[Set[int]],
    policy: Policy,
) -> Tuple[deque, SchedulerDecodeOutputs]:
    """Schedule sequence groups in a decoding stage.

    NOTE(sang): All the RUNNING num_batched_tokens, num_curr_seqs,
    and curr_loras should be already included in `budget` and `curr_loras`.
    The API doesn't ADD UP these values.

    Note that `budget` and `curr_loras` are still subtracted/discarded when
    any running requests are preempted from this API.

    Args:
        running_queue: The queue that contains running requests (i.e.,
            decodes). The method works on the queue returned by
            `policy.sort_by_priority` (assumed to be a new deque so the
            caller's queue is left untouched -- TODO confirm).
        budget: The scheduling budget. The argument is in-place updated
            when any decodes are preempted.
        curr_loras: Currently batched lora request ids. The argument is
            in-place updated when any decodes are preempted.
        policy: The sorting policy to sort running_queue.

    Returns:
        A tuple of the remaining running queue (always empty; asserted
        before returning) and SchedulerDecodeOutputs.
    """
    # Blocks that need to be swapped or copied before model execution.
    blocks_to_swap_out: Dict[int, int] = {}
    blocks_to_copy: Dict[int, List[int]] = {}

    seq_groups: List[ScheduledSequenceGroup] = []
    preempted: List[SequenceGroup] = []
    swapped_out: List[SequenceGroup] = []

    # NOTE(woosuk): Preemption happens only when there is no available slot
    # to keep all the sequence groups in the RUNNING state.
    # In this case, the policy is responsible for deciding which sequence
    # groups to preempt.
    now = time.time()
    running_queue = policy.sort_by_priority(now, running_queue)

    while running_queue:
        seq_group = running_queue[0]
        num_running_tokens = (
            seq_group.num_seqs(status=SequenceStatus.RUNNING) *
            self.num_decoding_tokens_per_seq)
        num_running_seqs = seq_group.get_max_num_running_seqs()

        running_queue.popleft()
        while not self._can_append_slots(seq_group):
            # Increase the budget as requests are preempted.
            # NOTE(review): the amounts released on every iteration are
            # those of `seq_group`, even when the group actually preempted
            # below is a victim -- confirm this accounting is intended.
            budget.num_batched_tokens -= num_running_tokens
            budget.num_curr_seqs -= num_running_seqs
            if curr_loras is not None and seq_group.lora_int_id > 0:
                # `curr_loras` is a set: `set.pop()` accepts no argument,
                # so the id has to be dropped with `discard`, which is also
                # a no-op when an earlier iteration already removed it.
                curr_loras.discard(seq_group.lora_int_id)

            if running_queue:
                # Preempt the lowest-priority sequence groups.
                victim_seq_group = running_queue.pop()
                preempted_mode = self._preempt(victim_seq_group,
                                               blocks_to_swap_out)
                if preempted_mode == PreemptionMode.RECOMPUTE:
                    preempted.append(victim_seq_group)
                else:
                    swapped_out.append(victim_seq_group)
            else:
                # No other sequence groups can be preempted.
                # Preempt the current sequence group.
                preempted_mode = self._preempt(seq_group,
                                               blocks_to_swap_out)
                if preempted_mode == PreemptionMode.RECOMPUTE:
                    preempted.append(seq_group)
                else:
                    swapped_out.append(seq_group)
                break
        else:
            # Reached only when the inner loop ended without `break`,
            # i.e. `seq_group` itself was not preempted.
            logger.debug(f"append slot for {seq_group}")
            self._append_slots(seq_group, blocks_to_copy)
            seq_groups.append(
                ScheduledSequenceGroup(seq_group=seq_group,
                                       token_chunk_size=1))

    # Make sure all queues are updated.
    assert len(running_queue) == 0

    return running_queue, SchedulerDecodeOutputs(
        seq_groups=seq_groups,
        preempted=preempted,
        swapped_out=swapped_out,
        blocks_to_swap_out=blocks_to_swap_out,
        blocks_to_copy=blocks_to_copy,
        num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=False))
def _schedule_swapped(
    self,
    swapped_queue: deque,
    budget: SchedulingBudget,
    curr_loras: Optional[Set[int]],
    policy: Policy,
) -> Tuple[deque, SchedulerSwappedInOutputs]:
    """Bring swapped-out sequence groups back while resources allow.

    Groups are visited in the order produced by `policy`. A group is
    swapped in when the block manager accepts it, the LoRA limit is not
    exceeded and `budget` can take its tokens and sequences. Scheduling
    stops at the first group that cannot be swapped in or does not fit the
    budget; groups that only fail the LoRA limit are skipped and put back
    at the front of the returned queue.

    Args:
        swapped_queue: The queue that contains swapped out requests. The
            method works on the queue returned by
            `policy.sort_by_priority` (assumed to be a new deque so the
            caller's queue is left untouched -- TODO confirm).
        budget: The scheduling budget; updated in place for every group
            that is swapped in.
        curr_loras: Currently batched lora request ids; updated in place
            for every group that is swapped in.
        policy: The sorting policy to sort swapped_queue.

    Returns:
        A tuple of the swapped queue left after scheduling and
        SchedulerSwappedInOutputs.
    """
    # Blocks that need to be swapped or copied before model execution.
    swap_in_map: Dict[int, int] = {}
    copy_map: Dict[int, List[int]] = {}
    scheduled: List[ScheduledSequenceGroup] = []
    # Groups set aside because no LoRA slot is free for them right now.
    skipped_for_lora = deque()

    swapped_queue = policy.sort_by_priority(time.time(), swapped_queue)

    while swapped_queue:
        candidate = swapped_queue[0]

        # If the sequence group cannot be swapped in, stop.
        if not self.block_manager.can_swap_in(candidate):
            break

        lora_id = candidate.lora_int_id if self.lora_enabled else 0
        if (lora_id > 0 and lora_id not in curr_loras
                and len(curr_loras) >= self.lora_config.max_loras):
            # We don't have a space for another LoRA, so
            # we ignore this request for now.
            skipped_for_lora.appendleft(candidate)
            swapped_queue.popleft()
            continue

        # The total number of sequences in the RUNNING state should not
        # exceed the maximum number of sequences.
        seqs_needed = candidate.get_max_num_running_seqs()
        tokens_needed = (
            candidate.num_seqs(status=SequenceStatus.SWAPPED) *
            self.num_decoding_tokens_per_seq)
        if not budget.can_schedule(num_new_tokens=tokens_needed,
                                   num_new_seqs=seqs_needed):
            break

        if lora_id > 0 and curr_loras is not None:
            curr_loras.add(lora_id)
        swapped_queue.popleft()
        self._swap_in(candidate, swap_in_map)
        self._append_slots(candidate, copy_map)
        scheduled.append(
            ScheduledSequenceGroup(candidate, token_chunk_size=1))
        budget.num_batched_tokens += tokens_needed
        budget.num_curr_seqs += seqs_needed

    # `appendleft` above followed by `extendleft` here restores the
    # skipped groups to the front in their original relative order.
    swapped_queue.extendleft(skipped_for_lora)

    return swapped_queue, SchedulerSwappedInOutputs(
        seq_groups=scheduled,
        blocks_to_swap_in=swap_in_map,
        blocks_to_copy=copy_map,
        num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=False))
def
_schedule_prefills
(
self
,
waiting_queue
:
deque
,
budget
:
SchedulingBudget
,
curr_loras
:
Optional
[
Set
[
int
]],
)
->
Tuple
[
deque
,
SchedulerPrefillOutputs
]:
"""Schedule sequence groups that are in prefill stage.
Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
as a new prefill (that starts from beginning -> most recently generated
tokens).
It schedules waiting requests as long as it fits `budget` and
curr_loras <= max_lora from the scheduling config. The input arguments
`budget` and `curr_loras` are updated based on scheduled seq_groups.
Args:
waiting_queue: The queue that contains prefill requests.
The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated
when any requests are scheduled.
curr_loras: Currently batched lora request ids. The argument is
in-place updated when any requests are scheduled.
Returns:
A tuple of remaining waiting_queue after scheduling and
SchedulerPrefillOutputs.
"""
ignored_seq_groups
:
List
[
SequenceGroup
]
=
[]
ignored_seq_groups
:
List
[
SequenceGroup
]
=
[]
scheduled
:
List
[
SequenceGroup
]
=
[]
seq_groups
:
List
[
SequenceGroup
]
=
[]
# The total number of sequences on the fly, including the
# We don't sort waiting queue because we assume it is sorted.
# requests in the generation phase.
# Copy the queue so that the input queue is not modified.
num_curr_seqs
=
sum
(
seq_group
.
get_max_num_running_seqs
()
waiting_queue
=
deque
([
s
for
s
in
waiting_queue
])
for
seq_group
in
self
.
running
)
curr_loras
=
set
(
seq_group
.
lora_int_id
for
seq_group
in
self
.
running
)
if
self
.
lora_enabled
else
None
# Optimization: We do not sort the waiting queue since the preempted
# sequence groups are added to the front and the new sequence groups
# are added to the back.
leftover_waiting_sequences
=
deque
()
leftover_waiting_sequences
=
deque
()
num_batched_tokens
=
0
while
self
.
_passed_delay
(
time
.
time
())
and
waiting_queue
:
while
self
.
_passed_delay
(
now
)
and
self
.
waiting
:
seq_group
=
waiting_queue
[
0
]
seq_group
=
self
.
waiting
[
0
]
waiting_seqs
=
seq_group
.
get_seqs
(
waiting_seqs
=
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
WAITING
)
status
=
SequenceStatus
.
WAITING
)
assert
len
(
waiting_seqs
)
==
1
,
(
assert
len
(
waiting_seqs
)
==
1
,
(
"Waiting sequence group should have only one prompt "
"Waiting sequence group should have only one prompt "
"sequence."
)
"sequence."
)
# get_len includes output tokens if the request has been
# preempted.
num_prompt_tokens
=
waiting_seqs
[
0
].
get_len
()
num_prefill_tokens
=
waiting_seqs
[
0
].
get_len
()
if
num_prompt_tokens
>
self
.
prompt_limit
:
if
num_prefill_tokens
>
self
.
prompt_limit
:
logger
.
warning
(
logger
.
warning
(
f
"Input prompt (
{
num_pr
efill
_tokens
}
tokens) is too "
f
"Input prompt (
{
num_pr
ompt
_tokens
}
tokens) is too
long
"
f
"long
and exceeds limit of
{
self
.
prompt_limit
}
"
)
f
"
and exceeds limit of
{
self
.
prompt_limit
}
"
)
for
seq
in
waiting_seqs
:
for
seq
in
waiting_seqs
:
seq
.
status
=
SequenceStatus
.
FINISHED_IGNORED
seq
.
status
=
SequenceStatus
.
FINISHED_IGNORED
ignored_seq_groups
.
append
(
seq_group
)
ignored_seq_groups
.
append
(
seq_group
)
self
.
waiting
.
popleft
()
waiting
_queue
.
popleft
()
continue
continue
# If the sequence group cannot be allocated, stop.
# If the sequence group cannot be allocated, stop.
...
@@ -257,162 +507,126 @@ class Scheduler:
...
@@ -257,162 +507,126 @@ class Scheduler:
break
break
elif
can_allocate
==
AllocStatus
.
NEVER
:
elif
can_allocate
==
AllocStatus
.
NEVER
:
logger
.
warning
(
logger
.
warning
(
f
"Input prompt (
{
num_pr
efill
_tokens
}
tokens) is too "
f
"Input prompt (
{
num_pr
ompt
_tokens
}
tokens) is too
long
"
f
"long
and exceeds the capacity of block_manager"
)
f
"
and exceeds the capacity of block_manager"
)
for
seq
in
waiting_seqs
:
for
seq
in
waiting_seqs
:
seq
.
status
=
SequenceStatus
.
FINISHED_IGNORED
seq
.
status
=
SequenceStatus
.
FINISHED_IGNORED
ignored_seq_groups
.
append
(
seq_group
)
ignored_seq_groups
.
append
(
seq_group
)
self
.
waiting
.
popleft
()
waiting
_queue
.
popleft
()
continue
continue
lora_int_id
=
0
lora_int_id
=
0
if
self
.
lora_enabled
:
if
self
.
lora_enabled
:
lora_int_id
=
seq_group
.
lora_int_id
lora_int_id
=
seq_group
.
lora_int_id
if
(
lora_int_id
>
0
and
lora_int_id
not
in
curr_loras
if
(
self
.
lora_enabled
and
lora_int_id
>
0
and
lora_int_id
not
in
curr_loras
and
len
(
curr_loras
)
>=
self
.
lora_config
.
max_loras
):
and
len
(
curr_loras
)
>=
self
.
lora_config
.
max_loras
):
# We don't have a space for another LoRA, so
# We don't have a space for another LoRA, so
# we ignore this request for now.
# we ignore this request for now.
leftover_waiting_sequences
.
appendleft
(
seq_group
)
leftover_waiting_sequences
.
appendleft
(
seq_group
)
self
.
waiting
.
popleft
()
waiting
_queue
.
popleft
()
continue
continue
# If the number of batched tokens exceeds the limit, stop.
num_batched_tokens
+=
num_prefill_tokens
if
(
num_batched_tokens
>
self
.
scheduler_config
.
max_num_batched_tokens
):
break
# The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences.
num_new_seqs
=
seq_group
.
get_max_num_running_seqs
()
num_new_seqs
=
seq_group
.
get_max_num_running_seqs
()
if
(
num_curr_seqs
+
num_new_seqs
>
if
not
budget
.
can_schedule
(
num_new_tokens
=
num_prompt_tokens
,
self
.
scheduler_config
.
max_num
_seqs
):
num_new_seqs
=
num_new
_seqs
):
break
break
if
lora_int_id
>
0
:
# Can schedule this request.
if
curr_loras
is
not
None
and
lora_int_id
>
0
:
curr_loras
.
add
(
lora_int_id
)
curr_loras
.
add
(
lora_int_id
)
self
.
waiting
.
popleft
()
waiting
_queue
.
popleft
()
self
.
_allocate
(
seq_group
)
self
.
_allocate
_and_set_running
(
seq_group
)
self
.
running
.
append
(
seq_group
)
seq_groups
.
append
(
num_curr_seqs
+=
num_new_seqs
ScheduledSequenceGroup
(
seq_group
=
seq_group
,
scheduled
.
append
(
token_chunk_size
=
num_prompt_tokens
))
ScheduledSequenceGroup
(
budget
.
num_batched_tokens
+=
num_prompt_tokens
seq_group
=
seq_group
,
budget
.
num_curr_seqs
+=
num_new_seqs
token_chunk_size
=
num_prefill_tokens
))
self
.
waiting
.
extendleft
(
leftover_waiting_sequences
)
# Queue requests that couldn't be scheduled.
waiting_queue
.
extendleft
(
leftover_waiting_sequences
)
if
scheduled
or
ignored_
seq_groups
:
if
len
(
seq_groups
)
>
0
:
self
.
prev_prompt
=
True
self
.
prev_prompt
=
True
scheduler_outputs
=
SchedulerOutputs
(
scheduled_seq_groups
=
scheduled
,
return
waiting_queue
,
SchedulerPrefillOutputs
(
prompt_run
=
True
,
seq_groups
=
seq_groups
,
num_batched_tokens
=
num_batched_tokens
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
ignored_seq_groups
=
ignored_seq_groups
,
ignored_seq_groups
=
ignored_seq_groups
,
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
is_prefill
=
True
))
is_prefill
=
True
),
)
return
scheduler_outputs
# NOTE(woosuk): Preemption happens only when there is no available slot
def
_schedule
(
self
)
->
SchedulerOutputs
:
# to keep all the sequence groups in the RUNNING state.
"""Batch requests that are queued..
# In this case, the policy is responsible for deciding which sequence
# groups to preempt.
self
.
running
=
self
.
policy
.
sort_by_priority
(
now
,
self
.
running
)
# Reserve new token slots for the running sequence groups.
The current policy is designed to optimize the throughput. First,
running
:
Deque
[
SequenceGroup
]
=
deque
()
it batches as many prefill requests as possible. And it schedules
preempted
:
List
[
SequenceGroup
]
=
[]
decodes. If there's a pressure on GPU memory, decode requests can
while
self
.
running
:
be swapped or preempted.
seq_group
=
self
.
running
.
popleft
()
"""
while
not
self
.
_can_append_slots
(
seq_group
):
# Include running requests to the budget.
if
self
.
running
:
budget
=
SchedulingBudget
(
# Preempt the lowest-priority sequence groups.
num_batched_tokens
=
sum
(
victim_seq_group
=
self
.
running
.
pop
()
seq_group
.
num_seqs
(
status
=
SequenceStatus
.
RUNNING
)
self
.
_preempt
(
victim_seq_group
,
blocks_to_swap_out
)
for
seq_group
in
self
.
running
),
preempted
.
append
(
victim_seq_group
)
num_curr_seqs
=
sum
(
seq_group
.
get_max_num_running_seqs
()
else
:
for
seq_group
in
self
.
running
),
# No other sequence groups can be preempted.
token_budget
=
self
.
scheduler_config
.
max_num_batched_tokens
,
# Preempt the current sequence group.
max_num_seqs
=
self
.
scheduler_config
.
max_num_seqs
,
self
.
_preempt
(
seq_group
,
blocks_to_swap_out
)
)
preempted
.
append
(
seq_group
)
break
else
:
# Append new slots to the sequence group.
self
.
_append_slots
(
seq_group
,
blocks_to_copy
)
running
.
append
(
seq_group
)
self
.
running
=
running
# Swap in the sequence groups in the SWAPPED state if possible.
self
.
swapped
=
self
.
policy
.
sort_by_priority
(
now
,
self
.
swapped
)
if
not
preempted
:
num_curr_seqs
=
sum
(
seq_group
.
get_max_num_running_seqs
()
for
seq_group
in
self
.
running
)
curr_loras
=
set
(
curr_loras
=
set
(
seq_group
.
lora_int_id
seq_group
.
lora_int_id
for
seq_group
in
self
.
running
)
if
self
.
lora_enabled
else
None
for
seq_group
in
self
.
running
)
if
self
.
lora_enabled
else
None
leftover_swapped
=
deque
()
remaining_waiting
,
prefills
=
(
self
.
waiting
,
SchedulerPrefillOutputs
.
create_empty
())
remaining_running
,
decodes
=
(
self
.
running
,
SchedulerDecodeOutputs
.
create_empty
())
remaining_swapped
,
swapped_in
=
(
self
.
swapped
,
SchedulerSwappedInOutputs
.
create_empty
())
while
self
.
swapped
:
# If any requests are swapped, prioritize the swapped requests.
seq_group
=
self
.
swapped
[
0
]
if
not
self
.
swapped
:
lora_int_id
=
0
remaining_waiting
,
prefills
=
self
.
_schedule_prefills
(
if
self
.
lora_enabled
:
self
.
waiting
,
budget
,
curr_loras
)
lora_int_id
=
seq_group
.
lora_int_id
if
(
lora_int_id
>
0
and
lora_int_id
not
in
curr_loras
# Don't schedule decodes if prefills are scheduled.
and
len
(
curr_loras
)
>=
self
.
lora_config
.
max_loras
):
if
len
(
prefills
.
seq_groups
)
==
0
:
# We don't have a space for another LoRA, so
remaining_running
,
decodes
=
self
.
_schedule_decodes
(
# we ignore this request for now.
self
.
running
,
budget
,
curr_loras
,
self
.
policy
)
leftover_swapped
.
appendleft
(
seq_group
)
# If any sequence group is preempted, do not swap in any sequence
self
.
swapped
.
popleft
()
# group, because it means there's no slot for new running requests.
continue
if
len
(
decodes
.
preempted
)
+
len
(
decodes
.
swapped_out
)
==
0
:
remaining_swapped
,
swapped_in
=
self
.
_schedule_swapped
(
# If the sequence group cannot be swapped in, stop.
self
.
swapped
,
budget
,
curr_loras
,
self
.
policy
)
if
not
self
.
_can_swap_in
(
seq_group
):
break
assert
(
budget
.
num_batched_tokens
<=
self
.
scheduler_config
.
max_num_batched_tokens
)
# The total number of sequences in the RUNNING state should not
assert
budget
.
num_curr_seqs
<=
self
.
scheduler_config
.
max_num_seqs
# exceed the maximum number of sequences.
num_new_seqs
=
seq_group
.
get_max_num_running_seqs
()
# Update waiting requests.
if
(
num_curr_seqs
+
num_new_seqs
>
self
.
waiting
=
remaining_waiting
self
.
scheduler_config
.
max_num_seqs
):
self
.
waiting
.
extendleft
(
decodes
.
preempted
)
break
# Update new running requests.
self
.
running
=
remaining_running
if
lora_int_id
>
0
:
self
.
running
.
extend
([
s
.
seq_group
for
s
in
prefills
.
seq_groups
])
curr_loras
.
add
(
lora_int_id
)
self
.
running
.
extend
([
s
.
seq_group
for
s
in
decodes
.
seq_groups
])
self
.
swapped
.
popleft
()
self
.
running
.
extend
([
s
.
seq_group
for
s
in
swapped_in
.
seq_groups
])
self
.
_swap_in
(
seq_group
,
blocks_to_swap_in
)
# Update swapped requests.
self
.
_append_slots
(
seq_group
,
blocks_to_copy
)
self
.
swapped
=
remaining_swapped
num_curr_seqs
+=
num_new_seqs
self
.
swapped
.
extend
(
decodes
.
swapped_out
)
self
.
running
.
append
(
seq_group
)
return
SchedulerOutputs
(
self
.
swapped
.
extendleft
(
leftover_swapped
)
scheduled_seq_groups
=
prefills
.
seq_groups
+
decodes
.
seq_groups
+
swapped_in
.
seq_groups
,
# Each sequence in the generation phase only takes one token slot.
num_prefill_groups
=
len
(
prefills
.
seq_groups
),
# Therefore, the number of batched tokens is equal to the number of
num_batched_tokens
=
budget
.
num_batched_tokens
,
# sequences in the RUNNING state.
blocks_to_swap_in
=
swapped_in
.
blocks_to_swap_in
,
num_batched_tokens
=
sum
(
blocks_to_swap_out
=
decodes
.
blocks_to_swap_out
,
seq_group
.
num_seqs
(
status
=
SequenceStatus
.
RUNNING
)
blocks_to_copy
=
merge_dicts
(
decodes
.
blocks_to_copy
,
for
seq_group
in
self
.
running
)
swapped_in
.
blocks_to_copy
),
ignored_seq_groups
=
prefills
.
ignored_seq_groups
,
scheduler_outputs
=
SchedulerOutputs
(
num_lookahead_slots
=
(
prefills
.
num_lookahead_slots
+
scheduled_seq_groups
=
[
decodes
.
num_lookahead_slots
+
ScheduledSequenceGroup
(
seq_group
=
running_group
,
swapped_in
.
num_lookahead_slots
),
token_chunk_size
=
1
)
for
running_group
in
self
.
running
],
prompt_run
=
False
,
num_batched_tokens
=
num_batched_tokens
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
ignored_seq_groups
=
[],
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
is_prefill
=
False
),
)
)
return
scheduler_outputs
def
_can_append_slots
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
def
_can_append_slots
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
"""Determine whether or not we have enough space in the KV cache to
"""Determine whether or not we have enough space in the KV cache to
...
@@ -444,7 +658,8 @@ class Scheduler:
...
@@ -444,7 +658,8 @@ class Scheduler:
# Create input data structures.
# Create input data structures.
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
for
scheduled_seq_group
in
scheduler_outputs
.
scheduled_seq_groups
:
for
i
,
scheduled_seq_group
in
enumerate
(
scheduler_outputs
.
scheduled_seq_groups
):
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
=
scheduled_seq_group
.
seq_group
token_chunk_size
=
scheduled_seq_group
.
token_chunk_size
token_chunk_size
=
scheduled_seq_group
.
token_chunk_size
seq_group
.
maybe_set_first_scheduled_time
(
now
)
seq_group
.
maybe_set_first_scheduled_time
(
now
)
...
@@ -464,9 +679,12 @@ class Scheduler:
...
@@ -464,9 +679,12 @@ class Scheduler:
self
.
block_manager
.
get_common_computed_block_ids
(
self
.
block_manager
.
get_common_computed_block_ids
(
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
)))
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
)))
# It assumes the scheduled_seq_groups is ordered by
# prefill < decoding.
is_prompt
=
i
<
scheduler_outputs
.
num_prefill_groups
seq_group_metadata
=
SequenceGroupMetadata
(
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
seq_group
.
request_id
,
request_id
=
seq_group
.
request_id
,
is_prompt
=
scheduler_outputs
.
prompt
_run
,
is_prompt
=
is_
prompt
,
seq_data
=
seq_data
,
seq_data
=
seq_data
,
sampling_params
=
seq_group
.
sampling_params
,
sampling_params
=
seq_group
.
sampling_params
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
...
@@ -479,7 +697,7 @@ class Scheduler:
...
@@ -479,7 +697,7 @@ class Scheduler:
# the subsequent comms can still use delta, but
# the subsequent comms can still use delta, but
# `multi_modal_data` will be None.
# `multi_modal_data` will be None.
multi_modal_data
=
seq_group
.
multi_modal_data
multi_modal_data
=
seq_group
.
multi_modal_data
if
scheduler_outputs
.
prompt_run
else
None
,
if
scheduler_outputs
.
num_prefill_groups
>
0
else
None
,
)
)
seq_group_metadata_list
.
append
(
seq_group_metadata
)
seq_group_metadata_list
.
append
(
seq_group_metadata
)
...
@@ -504,7 +722,7 @@ class Scheduler:
...
@@ -504,7 +722,7 @@ class Scheduler:
self
.
running
=
deque
(
seq_group
for
seq_group
in
self
.
running
self
.
running
=
deque
(
seq_group
for
seq_group
in
self
.
running
if
not
seq_group
.
is_finished
())
if
not
seq_group
.
is_finished
())
def
_allocate
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
def
_allocate
_and_set_running
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
self
.
block_manager
.
allocate
(
seq_group
)
self
.
block_manager
.
allocate
(
seq_group
)
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
WAITING
):
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
WAITING
):
seq
.
status
=
SequenceStatus
.
RUNNING
seq
.
status
=
SequenceStatus
.
RUNNING
...
@@ -539,7 +757,7 @@ class Scheduler:
...
@@ -539,7 +757,7 @@ class Scheduler:
seq_group
:
SequenceGroup
,
seq_group
:
SequenceGroup
,
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
preemption_mode
:
Optional
[
PreemptionMode
]
=
None
,
preemption_mode
:
Optional
[
PreemptionMode
]
=
None
,
)
->
Non
e
:
)
->
PreemptionMod
e
:
# If preemption mode is not specified, we determine the mode as follows:
# If preemption mode is not specified, we determine the mode as follows:
# We use recomputation by default since it incurs lower overhead than
# We use recomputation by default since it incurs lower overhead than
# swapping. However, when the sequence group has multiple sequences
# swapping. However, when the sequence group has multiple sequences
...
@@ -562,6 +780,7 @@ class Scheduler:
...
@@ -562,6 +780,7 @@ class Scheduler:
self
.
_preempt_by_swap
(
seq_group
,
blocks_to_swap_out
)
self
.
_preempt_by_swap
(
seq_group
,
blocks_to_swap_out
)
else
:
else
:
raise
AssertionError
(
"Invalid preemption mode."
)
raise
AssertionError
(
"Invalid preemption mode."
)
return
preemption_mode
def
_preempt_by_recompute
(
def
_preempt_by_recompute
(
self
,
self
,
...
@@ -573,9 +792,6 @@ class Scheduler:
...
@@ -573,9 +792,6 @@ class Scheduler:
seq
.
status
=
SequenceStatus
.
WAITING
seq
.
status
=
SequenceStatus
.
WAITING
self
.
free_seq
(
seq
)
self
.
free_seq
(
seq
)
seq
.
reset_state_for_recompute
()
seq
.
reset_state_for_recompute
()
# NOTE: For FCFS, we insert the preempted sequence group to the front
# of the waiting queue.
self
.
waiting
.
appendleft
(
seq_group
)
def
_preempt_by_swap
(
def
_preempt_by_swap
(
self
,
self
,
...
@@ -583,7 +799,6 @@ class Scheduler:
...
@@ -583,7 +799,6 @@ class Scheduler:
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
)
->
None
:
)
->
None
:
self
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
self
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
self
.
swapped
.
append
(
seq_group
)
def
_swap_in
(
def
_swap_in
(
self
,
self
,
...
...
vllm/engine/llm_engine.py
View file @
3dcb3e8b
...
@@ -728,7 +728,7 @@ class LLMEngine:
...
@@ -728,7 +728,7 @@ class LLMEngine:
time_per_output_tokens
=
[]
time_per_output_tokens
=
[]
time_e2e_requests
=
[]
time_e2e_requests
=
[]
if
scheduler_outputs
is
not
None
:
if
scheduler_outputs
is
not
None
:
prompt_run
=
scheduler_outputs
.
prompt_run
prompt_run
=
scheduler_outputs
.
num_prefill_groups
>
0
# Number of Tokens.
# Number of Tokens.
if
prompt_run
:
if
prompt_run
:
...
...
vllm/utils.py
View file @
3dcb3e8b
...
@@ -6,7 +6,7 @@ import socket
...
@@ -6,7 +6,7 @@ import socket
import
subprocess
import
subprocess
import
uuid
import
uuid
import
warnings
import
warnings
from
collections
import
OrderedDict
from
collections
import
OrderedDict
,
defaultdict
from
functools
import
lru_cache
,
partial
from
functools
import
lru_cache
,
partial
from
platform
import
uname
from
platform
import
uname
from
typing
import
(
Any
,
Awaitable
,
Callable
,
Generic
,
Hashable
,
List
,
from
typing
import
(
Any
,
Awaitable
,
Callable
,
Generic
,
Hashable
,
List
,
...
@@ -450,3 +450,20 @@ def maybe_expand_dim(tensor: torch.Tensor,
...
@@ -450,3 +450,20 @@ def maybe_expand_dim(tensor: torch.Tensor,
if
tensor
.
ndim
<
target_dims
:
if
tensor
.
ndim
<
target_dims
:
tensor
=
tensor
.
view
(
-
1
,
*
([
size
]
*
(
target_dims
-
tensor
.
ndim
)))
tensor
=
tensor
.
view
(
-
1
,
*
([
size
]
*
(
target_dims
-
tensor
.
ndim
)))
return
tensor
return
tensor
def merge_dicts(dict1: "dict[Any, list[Any]]",
                dict2: "dict[Any, list[Any]]") -> "dict[Any, list[Any]]":
    """Merge two dicts that map a key to a list of items.

    A key present in only one input keeps its items. When a key is present
    in both inputs, the two lists are concatenated, with the items from
    ``dict1`` placed before the items from ``dict2``. The inputs and their
    lists are not mutated; the result holds newly built lists.

    Args:
        dict1: First mapping; its items come first on a key conflict.
        dict2: Second mapping.

    Returns:
        A plain ``dict`` with the union of the keys of both inputs.
    """
    # The annotations above are quoted so they are never evaluated when the
    # function is defined: subscripting the builtin ``dict``/``list`` raises
    # TypeError on Python versions older than 3.9.
    merged_dict = defaultdict(list)

    # dict1 is processed first so that its items lead on key conflicts.
    for source in (dict1, dict2):
        for key, value in source.items():
            merged_dict[key].extend(value)

    # Return a plain dict so that missing keys raise KeyError for callers.
    return dict(merged_dict)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment