Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c8a7e932
Unverified
Commit
c8a7e932
authored
Jul 31, 2024
by
youkaichao
Committed by
GitHub
Jul 31, 2024
Browse files
[core][scheduler] simplify and improve scheduler (#6867)
parent
3c10591e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
112 additions
and
214 deletions
+112
-214
tests/core/block/e2e/test_correctness.py
tests/core/block/e2e/test_correctness.py
+1
-1
tests/core/test_scheduler.py
tests/core/test_scheduler.py
+68
-95
vllm/core/policy.py
vllm/core/policy.py
+0
-45
vllm/core/scheduler.py
vllm/core/scheduler.py
+43
-73
No files found.
tests/core/block/e2e/test_correctness.py
View file @
c8a7e932
...
...
@@ -183,7 +183,7 @@ def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
# Allow only 2 sequences of ~128 tokens in worst case.
# Note 16 = 128/block_size
"num_gpu_blocks_override"
:
2
*
(
16
+
1
),
"num_gpu_blocks_override"
:
2
*
(
16
+
2
),
}
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{
...
...
tests/core/test_scheduler.py
View file @
c8a7e932
import
time
from
collections
import
deque
from
typing
import
Deque
,
List
,
Set
,
Tuple
from
typing
import
List
,
Set
,
Tuple
from
unittest.mock
import
MagicMock
import
pytest
# noqa
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.core.interfaces
import
AllocStatus
from
vllm.core.policy
import
PolicyFactory
from
vllm.core.scheduler
import
Scheduler
,
SchedulingBudget
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
Logprob
,
SequenceGroup
,
SequenceStatus
...
...
@@ -348,10 +347,10 @@ def test_prefill_schedule_max_prompt_len():
"""
scheduler
=
initialize_scheduler
(
max_model_len
=
30
)
_
,
seq_group
=
create_dummy_prompt
(
"0"
,
prompt_length
=
60
)
waiting
=
deque
([
seq_group
]
)
scheduler
.
add_seq_group
(
seq_group
)
budget
=
create_token_budget
()
remaining_waiting
,
output
=
scheduler
.
_schedule_prefills
(
waiting
,
budget
,
None
)
output
=
scheduler
.
_schedule_prefills
(
budget
,
None
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
ignored_seq_groups
)
==
1
assert
len
(
output
.
seq_groups
)
==
0
assert
budget
.
num_batched_tokens
==
0
...
...
@@ -364,15 +363,14 @@ def test_prefill_schedule_token_budget():
Test token budget respected.
"""
scheduler
=
initialize_scheduler
()
waiting
:
Deque
[
SequenceGroup
]
=
deque
()
budget
=
create_token_budget
(
token_budget
=
0
)
for
i
in
range
(
2
):
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
waiting
.
append
(
seq_group
)
scheduler
.
add_seq_group
(
seq_group
)
# 0 token budget == nothing is scheduled.
remaining_waiting
,
output
=
scheduler
.
_schedule_prefills
(
waiting
,
budget
,
None
)
output
=
scheduler
.
_schedule_prefills
(
budget
,
None
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
0
assert
budget
.
num_batched_tokens
==
0
...
...
@@ -381,8 +379,8 @@ def test_prefill_schedule_token_budget():
# 60 token budget == 1 request scheduled.
budget
=
create_token_budget
(
token_budget
=
60
)
remaining_waiting
,
output
=
scheduler
.
_schedule_prefills
(
waiting
,
budget
,
None
)
output
=
scheduler
.
_schedule_prefills
(
budget
,
None
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
1
assert
budget
.
num_batched_tokens
==
60
...
...
@@ -391,14 +389,13 @@ def test_prefill_schedule_token_budget():
# Test when current_batched_tokens respected.
scheduler
=
initialize_scheduler
()
waiting
=
deque
()
budget
=
create_token_budget
(
token_budget
=
60
)
add_token_budget
(
budget
,
30
,
0
)
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
# Cannot schedule a prompt that doesn't fit the budget.
waiting
.
append
(
seq_group
)
remaining_waiting
,
output
=
scheduler
.
_schedule_prefills
(
waiting
,
budget
,
None
)
scheduler
.
add_seq_group
(
seq_group
)
output
=
scheduler
.
_schedule_prefills
(
budget
,
None
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
0
assert
budget
.
num_batched_tokens
==
30
...
...
@@ -406,8 +403,8 @@ def test_prefill_schedule_token_budget():
assert
len
(
remaining_waiting
)
==
1
budget
=
create_token_budget
(
token_budget
=
90
)
add_token_budget
(
budget
,
30
,
0
)
remaining_waiting
,
output
=
scheduler
.
_schedule_prefills
(
waiting
,
budget
,
None
)
output
=
scheduler
.
_schedule_prefills
(
budget
,
None
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
seq_groups
)
==
1
assert
budget
.
num_batched_tokens
==
90
assert
budget
.
num_curr_seqs
==
1
...
...
@@ -419,13 +416,12 @@ def test_prefill_schedule_max_seqs():
Test max seq respected.
"""
scheduler
=
initialize_scheduler
()
waiting
:
Deque
[
SequenceGroup
]
=
deque
()
budget
=
create_token_budget
(
max_num_seqs
=
2
)
for
i
in
range
(
3
):
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
waiting
.
append
(
seq_group
)
remaining_waiting
,
output
=
scheduler
.
_schedule_prefills
(
waiting
,
budget
,
None
)
scheduler
.
add_seq_group
(
seq_group
)
output
=
scheduler
.
_schedule_prefills
(
budget
,
None
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
2
assert
budget
.
num_batched_tokens
==
120
...
...
@@ -433,13 +429,13 @@ def test_prefill_schedule_max_seqs():
assert
len
(
remaining_waiting
)
==
1
# Verify curr_num_seqs respected.
waiting
=
deque
()
scheduler
.
waiting
=
deque
()
budget
=
create_token_budget
(
max_num_seqs
=
2
)
add_token_budget
(
budget
,
0
,
2
)
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
waiting
.
append
(
seq_group
)
remaining_waiting
,
output
=
scheduler
.
_schedule_prefills
(
waiting
,
budget
,
None
)
scheduler
.
add_seq_group
(
seq_group
)
output
=
scheduler
.
_schedule_prefills
(
budget
,
None
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
0
assert
budget
.
num_batched_tokens
==
0
...
...
@@ -453,7 +449,6 @@ def test_prefill_schedule_max_lora():
"""
lora_config
=
LoRAConfig
(
max_lora_rank
=
8
,
max_loras
=
1
)
scheduler
=
initialize_scheduler
(
lora_config
=
lora_config
)
waiting
:
Deque
[
SequenceGroup
]
=
deque
()
budget
=
create_token_budget
(
token_budget
=
120
)
curr_loras
:
Set
[
int
]
=
set
()
for
i
in
range
(
2
):
...
...
@@ -463,7 +458,7 @@ def test_prefill_schedule_max_lora():
lora_name
=
str
(
i
),
lora_int_id
=
i
+
1
,
lora_path
=
"abc"
))
waiting
.
append
(
seq_group
)
scheduler
.
add_seq_group
(
seq_group
)
# Add two more requests to verify lora is prioritized.
# 0: Lora, 1: Lora, 2: regular, 3: regular
# In the first iteration, index 0, 2 is scheduled.
...
...
@@ -471,10 +466,10 @@ def test_prefill_schedule_max_lora():
# prioritized. Verify that.
for
i
in
range
(
2
,
4
):
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
waiting
.
append
(
seq_group
)
scheduler
.
add_seq_group
(
seq_group
)
# Schedule 2 requests (0 and 2)
remaining_waiting
,
output
=
scheduler
.
_schedule_prefills
(
waiting
,
budget
,
curr_loras
)
output
=
scheduler
.
_schedule_prefills
(
budget
,
curr_loras
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
2
assert
budget
.
num_batched_tokens
==
120
...
...
@@ -485,8 +480,8 @@ def test_prefill_schedule_max_lora():
# Reset curr_loras so that it can be scheduled.
curr_loras
=
set
()
budget
=
create_token_budget
(
token_budget
=
60
)
remaining_waiting
,
output
=
scheduler
.
_schedule_prefills
(
remaining_waiting
,
budget
,
curr_loras
)
output
=
scheduler
.
_schedule_prefills
(
budget
,
curr_loras
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
seq_groups
)
==
1
assert
output
.
seq_groups
[
0
].
seq_group
.
request_id
==
"1"
assert
len
(
remaining_waiting
)
==
1
...
...
@@ -499,31 +494,29 @@ def test_prefill_schedule_no_block_manager_capacity():
Test sequence cannot be scheduled due to block manager has no capacity.
"""
scheduler
=
initialize_scheduler
()
waiting
:
Deque
[
SequenceGroup
]
=
deque
()
budget
=
create_token_budget
()
for
i
in
range
(
3
):
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
waiting
.
append
(
seq_group
)
scheduler
.
add_seq_group
(
seq_group
)
scheduler
.
block_manager
.
can_allocate
=
MagicMock
()
scheduler
.
block_manager
.
can_allocate
.
return_value
=
AllocStatus
.
LATER
remainig_waiting
,
output
=
scheduler
.
_schedule_prefills
(
waiting
,
budget
,
None
)
output
=
scheduler
.
_schedule_prefills
(
budget
,
None
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
0
assert
budget
.
num_batched_tokens
==
0
assert
budget
.
num_curr_seqs
==
0
assert
len
(
remainig_waiting
)
==
3
assert
len
(
remaini
n
g_waiting
)
==
3
scheduler
=
initialize_scheduler
()
waiting
=
deque
()
budget
=
create_token_budget
()
for
i
in
range
(
3
):
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
waiting
.
append
(
seq_group
)
scheduler
.
add_seq_group
(
seq_group
)
scheduler
.
block_manager
.
can_allocate
=
MagicMock
()
scheduler
.
block_manager
.
can_allocate
.
return_value
=
AllocStatus
.
NEVER
remaining_waiting
,
output
=
scheduler
.
_schedule_prefills
(
waiting
,
budget
,
None
)
output
=
scheduler
.
_schedule_prefills
(
budget
,
None
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
ignored_seq_groups
)
==
3
assert
len
(
output
.
seq_groups
)
==
0
assert
budget
.
num_batched_tokens
==
0
...
...
@@ -536,14 +529,12 @@ def test_decode_schedule_preempted():
Test decodes cannot be scheduled and preempted.
"""
scheduler
=
initialize_scheduler
()
running
:
Deque
[
SequenceGroup
]
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
for
i
in
range
(
3
):
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
running
.
append
(
seq_group
)
scheduler
.
_add_seq_group_to_
running
(
seq_group
)
scheduler
.
block_manager
.
can_append_slots
=
MagicMock
()
def
cannot_append_second_group
(
seq_group
,
num_lookahead_slots
):
...
...
@@ -555,8 +546,8 @@ def test_decode_schedule_preempted():
# 1 cannot be scheduled, and the lowest priority (request 2)
# should be preempted. 1 will also be preempted.
budget
=
create_token_budget
()
remainig_running
,
output
=
scheduler
.
_schedule_running
(
running
,
budget
,
curr_loras
,
policy
)
output
=
scheduler
.
_schedule_running
(
budget
,
curr_loras
)
remainig_running
=
scheduler
.
running
assert
len
(
remainig_running
)
==
0
assert
len
(
output
.
decode_seq_groups
)
==
1
assert
len
(
output
.
prefill_seq_groups
)
==
0
...
...
@@ -577,14 +568,12 @@ def test_decode_swap_beam_search():
Test best_of > 1 swap out blocks
"""
scheduler
=
initialize_scheduler
()
running
:
Deque
[
SequenceGroup
]
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
budget
=
create_token_budget
()
for
i
in
range
(
3
):
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
,
best_of
=
2
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
running
.
append
(
seq_group
)
scheduler
.
_add_seq_group_to_
running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
budget
.
add_num_seqs
(
seq_group
.
request_id
,
seq_group
.
get_max_num_running_seqs
())
...
...
@@ -603,8 +592,8 @@ def test_decode_swap_beam_search():
expected_swap_mapping
=
[(
"5"
,
"7"
)]
scheduler
.
block_manager
.
swap_out
.
return_value
=
expected_swap_mapping
remainig_running
,
output
=
scheduler
.
_schedule_running
(
running
,
budget
,
curr_loras
,
policy
)
output
=
scheduler
.
_schedule_running
(
budget
,
curr_loras
)
remainig_running
=
scheduler
.
running
assert
len
(
remainig_running
)
==
0
assert
len
(
output
.
decode_seq_groups
)
==
2
assert
len
(
output
.
prefill_seq_groups
)
==
0
...
...
@@ -628,20 +617,18 @@ def test_schedule_decode_blocks_to_copy_update():
"""
scheduler
=
initialize_scheduler
()
_
,
seq_group
=
create_dummy_prompt
(
"1"
,
prompt_length
=
60
,
best_of
=
2
)
running
:
Deque
[
SequenceGroup
]
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
running
.
append
(
seq_group
)
scheduler
.
_add_seq_group_to_
running
(
seq_group
)
# The last request should be swapped out.
scheduler
.
block_manager
.
append_slots
=
MagicMock
()
scheduler
.
block_manager
.
append_slots
.
return_value
=
[(
2
,
3
)]
budget
=
create_token_budget
()
remaining_running
,
output
=
scheduler
.
_schedule_running
(
running
,
budget
,
curr_loras
,
policy
)
output
=
scheduler
.
_schedule_running
(
budget
,
curr_loras
)
remaining_running
=
scheduler
.
running
assert
len
(
remaining_running
)
==
0
assert
len
(
output
.
decode_seq_groups
)
==
1
assert
len
(
output
.
prefill_seq_groups
)
==
0
...
...
@@ -656,19 +643,17 @@ def test_schedule_decode_blocks_to_copy_update():
def
test_schedule_swapped_simple
():
scheduler
=
initialize_scheduler
()
swapped
:
Deque
[
SequenceGroup
]
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
_
,
seq_group
=
create_dummy_prompt
(
"1"
,
prompt_length
=
60
,
best_of
=
2
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
s
wapped
.
appe
n
d
(
seq_group
)
s
cheduler
.
_add_seq_group_to_sw
apped
(
seq_group
)
budget
=
create_token_budget
()
remaining_swapped
,
output
=
scheduler
.
_schedule_swapped
(
swapped
,
budget
,
curr_loras
,
policy
)
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
0
assert
budget
.
num_batched_tokens
==
1
assert
budget
.
num_curr_seqs
==
2
...
...
@@ -683,8 +668,6 @@ def test_schedule_swapped_simple():
def
test_schedule_swapped_max_token_budget
():
scheduler
=
initialize_scheduler
()
swapped
:
Deque
[
SequenceGroup
]
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
for
_
in
range
(
2
):
...
...
@@ -692,11 +675,11 @@ def test_schedule_swapped_max_token_budget():
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
s
wapped
.
appe
n
d
(
seq_group
)
s
cheduler
.
_add_seq_group_to_sw
apped
(
seq_group
)
budget
=
create_token_budget
(
token_budget
=
1
)
remaining_swapped
,
output
=
scheduler
.
_schedule_swapped
(
swapped
,
budget
,
curr_loras
,
policy
)
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
1
assert
budget
.
num_batched_tokens
==
1
assert
budget
.
num_curr_seqs
==
2
...
...
@@ -706,8 +689,8 @@ def test_schedule_swapped_max_token_budget():
# Verify num_batched_tokens are respected.
budget
=
create_token_budget
(
token_budget
=
1
)
add_token_budget
(
budget
,
1
,
0
)
remaining_swapped
,
output
=
scheduler
.
_schedule_swapped
(
remaining_swapped
,
budget
,
curr_loras
,
policy
)
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
1
assert
budget
.
num_batched_tokens
==
1
assert
budget
.
num_curr_seqs
==
0
...
...
@@ -717,8 +700,6 @@ def test_schedule_swapped_max_token_budget():
def
test_schedule_swapped_max_seqs
():
scheduler
=
initialize_scheduler
()
swapped
:
Deque
[
SequenceGroup
]
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
for
i
in
range
(
4
):
...
...
@@ -726,11 +707,11 @@ def test_schedule_swapped_max_seqs():
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
s
wapped
.
appe
n
d
(
seq_group
)
s
cheduler
.
_add_seq_group_to_sw
apped
(
seq_group
)
budget
=
create_token_budget
(
max_num_seqs
=
2
)
remaining_swapped
,
output
=
scheduler
.
_schedule_swapped
(
swapped
,
budget
,
curr_loras
,
policy
)
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
2
assert
budget
.
num_batched_tokens
==
2
assert
budget
.
num_curr_seqs
==
2
...
...
@@ -738,8 +719,8 @@ def test_schedule_swapped_max_seqs():
assert
len
(
output
.
prefill_seq_groups
)
==
0
# Verify num_curr_seqs are respected.
remaining_swapped
,
output
=
scheduler
.
_schedule_swapped
(
remaining_swapped
,
budget
,
curr_loras
,
policy
)
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
2
assert
budget
.
num_batched_tokens
==
2
assert
budget
.
num_curr_seqs
==
2
...
...
@@ -750,8 +731,6 @@ def test_schedule_swapped_max_seqs():
def
test_schedule_swapped_max_loras
():
lora_config
=
LoRAConfig
(
max_lora_rank
=
8
,
max_loras
=
1
)
scheduler
=
initialize_scheduler
(
lora_config
=
lora_config
)
swapped
:
Deque
[
SequenceGroup
]
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
:
Set
[
int
]
=
set
()
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
for
i
in
range
(
2
):
...
...
@@ -764,11 +743,11 @@ def test_schedule_swapped_max_loras():
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
s
wapped
.
appe
n
d
(
seq_group
)
s
cheduler
.
_add_seq_group_to_sw
apped
(
seq_group
)
budget
=
create_token_budget
()
remaining_swapped
,
output
=
scheduler
.
_schedule_swapped
(
swapped
,
budget
,
curr_loras
,
policy
)
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
1
assert
budget
.
num_batched_tokens
==
1
assert
budget
.
num_curr_seqs
==
1
...
...
@@ -779,8 +758,6 @@ def test_schedule_swapped_max_loras():
def
test_schedule_swapped_cannot_swap_in
():
scheduler
=
initialize_scheduler
()
swapped
:
Deque
[
SequenceGroup
]
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
for
_
in
range
(
2
):
...
...
@@ -788,15 +765,15 @@ def test_schedule_swapped_cannot_swap_in():
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
s
wapped
.
appe
n
d
(
seq_group
)
s
cheduler
.
_add_seq_group_to_sw
apped
(
seq_group
)
# The last request should be swapped out.
scheduler
.
block_manager
.
can_swap_in
=
MagicMock
()
scheduler
.
block_manager
.
can_swap_in
.
return_value
=
AllocStatus
.
LATER
# Since we cannot swap in, none of the requests are swapped in.
budget
=
create_token_budget
()
remaining_swapped
,
output
=
scheduler
.
_schedule_swapped
(
swapped
,
budget
,
curr_loras
,
policy
)
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
2
assert
budget
.
num_batched_tokens
==
0
assert
budget
.
num_curr_seqs
==
0
...
...
@@ -806,8 +783,6 @@ def test_schedule_swapped_cannot_swap_in():
def
test_infeasible_swap
():
scheduler
=
initialize_scheduler
()
swapped
:
Deque
[
SequenceGroup
]
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
for
_
in
range
(
2
):
...
...
@@ -815,15 +790,15 @@ def test_infeasible_swap():
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
s
wapped
.
appe
n
d
(
seq_group
)
s
cheduler
.
_add_seq_group_to_sw
apped
(
seq_group
)
# The last request should be swapped out.
scheduler
.
block_manager
.
can_swap_in
=
MagicMock
()
scheduler
.
block_manager
.
can_swap_in
.
return_value
=
AllocStatus
.
NEVER
# Since we cannot swap in, none of the requests are swapped in.
budget
=
create_token_budget
()
remaining_swapped
,
output
=
scheduler
.
_schedule_swapped
(
swapped
,
budget
,
curr_loras
,
policy
)
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
0
assert
len
(
output
.
infeasible_seq_groups
)
==
2
assert
budget
.
num_batched_tokens
==
0
...
...
@@ -834,23 +809,21 @@ def test_infeasible_swap():
def
test_schedule_swapped_blocks_to_copy
():
scheduler
=
initialize_scheduler
()
swapped
:
Deque
[
SequenceGroup
]
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
_
,
seq_group
=
create_dummy_prompt
(
"1"
,
prompt_length
=
60
,
best_of
=
2
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
s
wapped
.
appe
n
d
(
seq_group
)
s
cheduler
.
_add_seq_group_to_sw
apped
(
seq_group
)
# The last request should be swapped out.
scheduler
.
block_manager
.
append_slots
=
MagicMock
()
scheduler
.
block_manager
.
append_slots
.
return_value
=
[(
2
,
3
)]
budget
=
create_token_budget
()
remaining_swapped
,
output
=
scheduler
.
_schedule_swapped
(
swapped
,
budget
,
curr_loras
,
policy
)
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
0
assert
len
(
output
.
decode_seq_groups
)
==
1
assert
len
(
output
.
prefill_seq_groups
)
==
0
...
...
vllm/core/policy.py
deleted
100644 → 0
View file @
3c10591e
from
collections
import
deque
from
typing
import
Deque
from
vllm.sequence
import
SequenceGroup
class
Policy
:
def
get_priority
(
self
,
now
:
float
,
seq_group
:
SequenceGroup
,
)
->
float
:
raise
NotImplementedError
def
sort_by_priority
(
self
,
now
:
float
,
seq_groups
:
Deque
[
SequenceGroup
],
)
->
Deque
[
SequenceGroup
]:
return
deque
(
sorted
(
seq_groups
,
key
=
lambda
seq_group
:
self
.
get_priority
(
now
,
seq_group
),
reverse
=
True
,
))
class
FCFS
(
Policy
):
def
get_priority
(
self
,
now
:
float
,
seq_group
:
SequenceGroup
,
)
->
float
:
return
now
-
seq_group
.
metrics
.
arrival_time
class
PolicyFactory
:
_POLICY_REGISTRY
=
{
'fcfs'
:
FCFS
}
@
classmethod
def
get_policy
(
cls
,
policy_name
:
str
,
**
kwargs
)
->
Policy
:
return
cls
.
_POLICY_REGISTRY
[
policy_name
](
**
kwargs
)
vllm/core/scheduler.py
View file @
c8a7e932
...
...
@@ -8,7 +8,6 @@ from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.core.interfaces
import
AllocStatus
,
BlockSpaceManager
from
vllm.core.policy
import
Policy
,
PolicyFactory
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
...
...
@@ -345,6 +344,16 @@ class Scheduler:
# Add sequence groups to the waiting queue.
self
.
waiting
.
append
(
seq_group
)
def
_add_seq_group_to_running
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
# Add sequence groups to the running queue.
# Only for testing purposes.
self
.
running
.
append
(
seq_group
)
def
_add_seq_group_to_swapped
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
# Add sequence groups to the swapped queue.
# Only for testing purposes.
self
.
swapped
.
append
(
seq_group
)
def
abort_seq_group
(
self
,
request_id
:
Union
[
str
,
Iterable
[
str
]])
->
None
:
"""Aborts a sequence group with the given ID.
...
...
@@ -398,32 +407,26 @@ class Scheduler:
def
_schedule_running
(
self
,
running_queue
:
deque
,
budget
:
SchedulingBudget
,
curr_loras
:
Optional
[
Set
[
int
]],
policy
:
Policy
,
enable_chunking
:
bool
=
False
,
)
->
Tuple
[
deque
,
SchedulerRunningOutputs
]
:
)
->
SchedulerRunningOutputs
:
"""Schedule sequence groups that are running.
Running queue should include decode and chunked prefill requests.
Args:
running_queue: The queue that contains running requests (i.e.,
decodes). The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated
when any decodes are preempted.
curr_loras: Currently batched lora request ids. The argument is
in-place updated when any decodes are preempted.
policy: The sorting policy to sort running_queue.
enable_chunking: If True, seq group can be chunked and only a
chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
Returns:
A tuple of remaining running queue (should be always 0) after
scheduling and SchedulerRunningOutputs.
SchedulerRunningOutputs.
"""
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
...
...
@@ -436,10 +439,9 @@ class Scheduler:
# NOTE(woosuk): Preemption happens only when there is no available slot
# to keep all the sequence groups in the RUNNING state.
# In this case, the policy is responsible for deciding which sequence
# groups to preempt.
now
=
time
.
time
()
running_queue
=
policy
.
sort_by_priority
(
now
,
running_queue
)
running_queue
=
self
.
running
while
running_queue
:
seq_group
=
running_queue
[
0
]
num_running_tokens
=
self
.
_get_num_new_tokens
(
...
...
@@ -503,7 +505,7 @@ class Scheduler:
if
curr_loras
is
not
None
and
seq_group
.
lora_int_id
>
0
:
curr_loras
.
add
(
seq_group
.
lora_int_id
)
return
running_queue
,
SchedulerRunningOutputs
(
return
SchedulerRunningOutputs
(
decode_seq_groups
=
decode_seq_groups
,
prefill_seq_groups
=
prefill_seq_groups
,
preempted
=
preempted
,
...
...
@@ -515,12 +517,10 @@ class Scheduler:
def
_schedule_swapped
(
self
,
swapped_queue
:
deque
,
budget
:
SchedulingBudget
,
curr_loras
:
Optional
[
Set
[
int
]],
policy
:
Policy
,
enable_chunking
:
bool
=
False
,
)
->
Tuple
[
deque
,
SchedulerSwappedInOutputs
]
:
)
->
SchedulerSwappedInOutputs
:
"""Schedule sequence groups that are swapped out.
It schedules swapped requests as long as it fits `budget` and
...
...
@@ -528,20 +528,16 @@ class Scheduler:
`budget` and `curr_loras` are updated based on scheduled seq_groups.
Args:
swapped_queue: The queue that contains swapped out requests.
The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated
when any requests are swapped in.
curr_loras: Currently batched lora request ids. The argument is
in-place updated when any requests are swapped in.
policy: The sorting policy to sort swapped_queue.
enable_chunking: If True, seq group can be chunked and only a
chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
Returns:
A tuple of remaining swapped_queue after scheduling and
SchedulerSwappedInOutputs.
"""
# Blocks that need to be swapped or copied before model execution.
...
...
@@ -549,10 +545,10 @@ class Scheduler:
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
=
[]
decode_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
prefill_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
now
=
time
.
time
()
swapped_queue
=
policy
.
sort_by_priority
(
now
,
swapped_queue
)
infeasible_seq_groups
:
List
[
SequenceGroup
]
=
[]
swapped_queue
=
self
.
swapped
leftover_swapped
:
Deque
[
SequenceGroup
]
=
deque
()
while
swapped_queue
:
seq_group
=
swapped_queue
[
0
]
...
...
@@ -617,7 +613,7 @@ class Scheduler:
swapped_queue
.
extendleft
(
leftover_swapped
)
return
swapped_queue
,
SchedulerSwappedInOutputs
(
return
SchedulerSwappedInOutputs
(
decode_seq_groups
=
decode_seq_groups
,
prefill_seq_groups
=
prefill_seq_groups
,
blocks_to_swap_in
=
blocks_to_swap_in
,
...
...
@@ -644,11 +640,10 @@ class Scheduler:
def
_schedule_prefills
(
self
,
waiting_queue
:
deque
,
budget
:
SchedulingBudget
,
curr_loras
:
Optional
[
Set
[
int
]],
enable_chunking
:
bool
=
False
,
)
->
Tuple
[
deque
,
SchedulerPrefillOutputs
]
:
)
->
SchedulerPrefillOutputs
:
"""Schedule sequence groups that are in prefill stage.
Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
...
...
@@ -660,8 +655,6 @@ class Scheduler:
`budget` and `curr_loras` are updated based on scheduled seq_groups.
Args:
waiting_queue: The queue that contains prefill requests.
The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated
when any requests are scheduled.
curr_loras: Currently batched lora request ids. The argument is
...
...
@@ -672,14 +665,12 @@ class Scheduler:
all tokens.
Returns:
A tuple of remaining waiting_queue after scheduling and
SchedulerSwappedInOutputs.
"""
ignored_seq_groups
:
List
[
SequenceGroup
]
=
[]
seq_groups
:
List
[
SequenceGroup
]
=
[]
# We don't sort waiting queue because we assume it is sorted.
# Copy the queue so that the input queue is not modified.
waiting_queue
=
deque
([
s
for
s
in
waiting_queue
])
waiting_queue
=
self
.
waiting
leftover_waiting_sequences
:
Deque
[
SequenceGroup
]
=
deque
()
while
self
.
_passed_delay
(
time
.
time
())
and
waiting_queue
:
...
...
@@ -758,7 +749,7 @@ class Scheduler:
if
len
(
seq_groups
)
>
0
:
self
.
prev_prompt
=
True
return
waiting_queue
,
SchedulerPrefillOutputs
(
return
SchedulerPrefillOutputs
(
seq_groups
=
seq_groups
,
ignored_seq_groups
=
ignored_seq_groups
,
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
is_prefill
=
True
))
...
...
@@ -785,53 +776,43 @@ class Scheduler:
seq_group
.
lora_int_id
for
seq_group
in
self
.
running
if
seq_group
.
lora_int_id
>
0
)
if
self
.
lora_enabled
else
None
remaining_waiting
,
prefills
=
(
self
.
waiting
,
SchedulerPrefillOutputs
.
create_empty
())
remaining_running
,
running_scheduled
=
(
self
.
running
,
SchedulerRunningOutputs
.
create_empty
())
remaining_swapped
,
swapped_in
=
(
self
.
swapped
,
SchedulerSwappedInOutputs
.
create_empty
())
prefills
=
SchedulerPrefillOutputs
.
create_empty
()
running_scheduled
=
SchedulerRunningOutputs
.
create_empty
()
swapped_in
=
SchedulerSwappedInOutputs
.
create_empty
()
# If any requests are swapped, prioritized swapped requests.
if
not
self
.
swapped
:
remaining_waiting
,
prefills
=
self
.
_schedule_prefills
(
self
.
waiting
,
budget
,
curr_loras
,
enable_chunking
=
False
)
prefills
=
self
.
_schedule_prefills
(
budget
,
curr_loras
,
enable_chunking
=
False
)
fcfs_policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
# Don't schedule decodes if prefills are scheduled.
# NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
# only contains decode requests, not chunked prefills.
if
len
(
prefills
.
seq_groups
)
==
0
:
remaining_running
,
running_scheduled
=
self
.
_schedule_running
(
self
.
running
,
budget
,
curr_loras
,
fcfs_policy
,
enable_chunking
=
False
)
running_scheduled
=
self
.
_schedule_running
(
budget
,
curr_loras
,
enable_chunking
=
False
)
# If any sequence group is preempted, do not swap in any sequence
# group. because it means there's no slot for new running requests.
if
len
(
running_scheduled
.
preempted
)
+
len
(
running_scheduled
.
swapped_out
)
==
0
:
remaining_swapped
,
swapped_in
=
self
.
_schedule_swapped
(
self
.
swapped
,
budget
,
curr_loras
,
fcfs_policy
)
swapped_in
=
self
.
_schedule_swapped
(
budget
,
curr_loras
)
assert
(
budget
.
num_batched_tokens
<=
self
.
scheduler_config
.
max_num_batched_tokens
)
assert
budget
.
num_curr_seqs
<=
self
.
scheduler_config
.
max_num_seqs
# Update waiting requests.
self
.
waiting
=
remaining_waiting
self
.
waiting
.
extendleft
(
running_scheduled
.
preempted
)
# Update new running requests.
self
.
running
=
remaining_running
self
.
running
.
extend
([
s
.
seq_group
for
s
in
prefills
.
seq_groups
])
self
.
running
.
extend
(
[
s
.
seq_group
for
s
in
running_scheduled
.
decode_seq_groups
])
self
.
running
.
extend
(
[
s
.
seq_group
for
s
in
swapped_in
.
decode_seq_groups
])
# Update swapped requests.
self
.
swapped
=
remaining_swapped
self
.
swapped
.
extend
(
running_scheduled
.
swapped_out
)
preempted
=
(
len
(
running_scheduled
.
preempted
)
+
len
(
running_scheduled
.
swapped_out
))
...
...
@@ -877,42 +858,32 @@ class Scheduler:
)
curr_loras
:
Set
[
int
]
=
set
()
remaining_waiting
,
prefills
=
(
self
.
waiting
,
SchedulerPrefillOutputs
.
create_empty
())
remaining_running
,
running_scheduled
=
(
self
.
running
,
SchedulerRunningOutputs
.
create_empty
())
remaining_swapped
,
swapped_in
=
(
self
.
swapped
,
SchedulerSwappedInOutputs
.
create_empty
())
prefills
=
SchedulerPrefillOutputs
.
create_empty
()
swapped_in
=
SchedulerSwappedInOutputs
.
create_empty
()
# Decoding should be always scheduled first by fcfs.
fcfs_policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
remaining_running
,
running_scheduled
=
self
.
_schedule_running
(
self
.
running
,
budget
,
curr_loras
,
fcfs_policy
,
enable_chunking
=
True
)
running_scheduled
=
self
.
_schedule_running
(
budget
,
curr_loras
,
enable_chunking
=
True
)
# Schedule swapped out requests.
# If preemption happens, it means we don't have space for swap-in.
if
len
(
running_scheduled
.
preempted
)
+
len
(
running_scheduled
.
swapped_out
)
==
0
:
remaining_swapped
,
swapped_in
=
self
.
_schedule_swapped
(
self
.
swapped
,
budget
,
curr_loras
,
fcfs_policy
)
swapped_in
=
self
.
_schedule_swapped
(
budget
,
curr_loras
)
# Schedule new prefills.
remaining_waiting
,
prefills
=
self
.
_schedule_prefills
(
self
.
waiting
,
budget
,
curr_loras
,
enable_chunking
=
True
)
prefills
=
self
.
_schedule_prefills
(
budget
,
curr_loras
,
enable_chunking
=
True
)
assert
(
budget
.
num_batched_tokens
<=
self
.
scheduler_config
.
max_num_batched_tokens
)
assert
budget
.
num_curr_seqs
<=
self
.
scheduler_config
.
max_num_seqs
# Update waiting requests.
self
.
waiting
=
remaining_waiting
self
.
waiting
.
extendleft
(
running_scheduled
.
preempted
)
# Update new running requests.
self
.
running
=
remaining_running
self
.
running
.
extend
([
s
.
seq_group
for
s
in
prefills
.
seq_groups
])
self
.
running
.
extend
(
[
s
.
seq_group
for
s
in
running_scheduled
.
decode_seq_groups
])
...
...
@@ -923,7 +894,6 @@ class Scheduler:
self
.
running
.
extend
(
[
s
.
seq_group
for
s
in
swapped_in
.
prefill_seq_groups
])
# Update swapped requests.
self
.
swapped
=
remaining_swapped
self
.
swapped
.
extend
(
running_scheduled
.
swapped_out
)
return
SchedulerOutputs
(
scheduled_seq_groups
=
(
prefills
.
seq_groups
+
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment