Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c8a7e932
Unverified
Commit
c8a7e932
authored
Jul 31, 2024
by
youkaichao
Committed by
GitHub
Jul 31, 2024
Browse files
[core][scheduler] simplify and improve scheduler (#6867)
parent
3c10591e
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
112 additions
and
214 deletions
+112
-214
tests/core/block/e2e/test_correctness.py
tests/core/block/e2e/test_correctness.py
+1
-1
tests/core/test_scheduler.py
tests/core/test_scheduler.py
+68
-95
vllm/core/policy.py
vllm/core/policy.py
+0
-45
vllm/core/scheduler.py
vllm/core/scheduler.py
+43
-73
No files found.
tests/core/block/e2e/test_correctness.py
View file @
c8a7e932
...
@@ -183,7 +183,7 @@ def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
...
@@ -183,7 +183,7 @@ def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
# Allow only 2 sequences of ~128 tokens in worst case.
# Allow only 2 sequences of ~128 tokens in worst case.
# Note 16 = 128/block_size
# Note 16 = 128/block_size
"num_gpu_blocks_override"
:
2
*
(
16
+
1
),
"num_gpu_blocks_override"
:
2
*
(
16
+
2
),
}
}
])
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{
...
...
tests/core/test_scheduler.py
View file @
c8a7e932
import
time
import
time
from
collections
import
deque
from
collections
import
deque
from
typing
import
Deque
,
List
,
Set
,
Tuple
from
typing
import
List
,
Set
,
Tuple
from
unittest.mock
import
MagicMock
from
unittest.mock
import
MagicMock
import
pytest
# noqa
import
pytest
# noqa
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.core.interfaces
import
AllocStatus
from
vllm.core.interfaces
import
AllocStatus
from
vllm.core.policy
import
PolicyFactory
from
vllm.core.scheduler
import
Scheduler
,
SchedulingBudget
from
vllm.core.scheduler
import
Scheduler
,
SchedulingBudget
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
Logprob
,
SequenceGroup
,
SequenceStatus
from
vllm.sequence
import
Logprob
,
SequenceGroup
,
SequenceStatus
...
@@ -348,10 +347,10 @@ def test_prefill_schedule_max_prompt_len():
...
@@ -348,10 +347,10 @@ def test_prefill_schedule_max_prompt_len():
"""
"""
scheduler
=
initialize_scheduler
(
max_model_len
=
30
)
scheduler
=
initialize_scheduler
(
max_model_len
=
30
)
_
,
seq_group
=
create_dummy_prompt
(
"0"
,
prompt_length
=
60
)
_
,
seq_group
=
create_dummy_prompt
(
"0"
,
prompt_length
=
60
)
waiting
=
deque
([
seq_group
]
)
scheduler
.
add_seq_group
(
seq_group
)
budget
=
create_token_budget
()
budget
=
create_token_budget
()
remaining_waiting
,
output
=
scheduler
.
_schedule_prefills
(
output
=
scheduler
.
_schedule_prefills
(
budget
,
None
)
waiting
,
budget
,
None
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
ignored_seq_groups
)
==
1
assert
len
(
output
.
ignored_seq_groups
)
==
1
assert
len
(
output
.
seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
0
assert
budget
.
num_batched_tokens
==
0
assert
budget
.
num_batched_tokens
==
0
...
@@ -364,15 +363,14 @@ def test_prefill_schedule_token_budget():
...
@@ -364,15 +363,14 @@ def test_prefill_schedule_token_budget():
Test token budget respected.
Test token budget respected.
"""
"""
scheduler
=
initialize_scheduler
()
scheduler
=
initialize_scheduler
()
waiting
:
Deque
[
SequenceGroup
]
=
deque
()
budget
=
create_token_budget
(
token_budget
=
0
)
budget
=
create_token_budget
(
token_budget
=
0
)
for
i
in
range
(
2
):
for
i
in
range
(
2
):
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
waiting
.
append
(
seq_group
)
scheduler
.
add_seq_group
(
seq_group
)
# 0 token budget == nothing is scheduled.
# 0 token budget == nothing is scheduled.
remaining_waiting
,
output
=
scheduler
.
_schedule_prefills
(
output
=
scheduler
.
_schedule_prefills
(
budget
,
None
)
waiting
,
budget
,
None
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
0
assert
budget
.
num_batched_tokens
==
0
assert
budget
.
num_batched_tokens
==
0
...
@@ -381,8 +379,8 @@ def test_prefill_schedule_token_budget():
...
@@ -381,8 +379,8 @@ def test_prefill_schedule_token_budget():
# 60 token budget == 1 request scheduled.
# 60 token budget == 1 request scheduled.
budget
=
create_token_budget
(
token_budget
=
60
)
budget
=
create_token_budget
(
token_budget
=
60
)
remaining_waiting
,
output
=
scheduler
.
_schedule_prefills
(
output
=
scheduler
.
_schedule_prefills
(
budget
,
None
)
waiting
,
budget
,
None
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
1
assert
len
(
output
.
seq_groups
)
==
1
assert
budget
.
num_batched_tokens
==
60
assert
budget
.
num_batched_tokens
==
60
...
@@ -391,14 +389,13 @@ def test_prefill_schedule_token_budget():
...
@@ -391,14 +389,13 @@ def test_prefill_schedule_token_budget():
# Test when current_batched_tokens respected.
# Test when current_batched_tokens respected.
scheduler
=
initialize_scheduler
()
scheduler
=
initialize_scheduler
()
waiting
=
deque
()
budget
=
create_token_budget
(
token_budget
=
60
)
budget
=
create_token_budget
(
token_budget
=
60
)
add_token_budget
(
budget
,
30
,
0
)
add_token_budget
(
budget
,
30
,
0
)
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
# Cannot schedule a prompt that doesn't fit the budget.
# Cannot schedule a prompt that doesn't fit the budget.
waiting
.
append
(
seq_group
)
scheduler
.
add_seq_group
(
seq_group
)
remaining_waiting
,
output
=
scheduler
.
_schedule_prefills
(
output
=
scheduler
.
_schedule_prefills
(
budget
,
None
)
waiting
,
budget
,
None
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
0
assert
budget
.
num_batched_tokens
==
30
assert
budget
.
num_batched_tokens
==
30
...
@@ -406,8 +403,8 @@ def test_prefill_schedule_token_budget():
...
@@ -406,8 +403,8 @@ def test_prefill_schedule_token_budget():
assert
len
(
remaining_waiting
)
==
1
assert
len
(
remaining_waiting
)
==
1
budget
=
create_token_budget
(
token_budget
=
90
)
budget
=
create_token_budget
(
token_budget
=
90
)
add_token_budget
(
budget
,
30
,
0
)
add_token_budget
(
budget
,
30
,
0
)
remaining_waiting
,
output
=
scheduler
.
_schedule_prefills
(
output
=
scheduler
.
_schedule_prefills
(
budget
,
None
)
waiting
,
budget
,
None
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
seq_groups
)
==
1
assert
len
(
output
.
seq_groups
)
==
1
assert
budget
.
num_batched_tokens
==
90
assert
budget
.
num_batched_tokens
==
90
assert
budget
.
num_curr_seqs
==
1
assert
budget
.
num_curr_seqs
==
1
...
@@ -419,13 +416,12 @@ def test_prefill_schedule_max_seqs():
...
@@ -419,13 +416,12 @@ def test_prefill_schedule_max_seqs():
Test max seq respected.
Test max seq respected.
"""
"""
scheduler
=
initialize_scheduler
()
scheduler
=
initialize_scheduler
()
waiting
:
Deque
[
SequenceGroup
]
=
deque
()
budget
=
create_token_budget
(
max_num_seqs
=
2
)
budget
=
create_token_budget
(
max_num_seqs
=
2
)
for
i
in
range
(
3
):
for
i
in
range
(
3
):
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
waiting
.
append
(
seq_group
)
scheduler
.
add_seq_group
(
seq_group
)
remaining_waiting
,
output
=
scheduler
.
_schedule_prefills
(
output
=
scheduler
.
_schedule_prefills
(
budget
,
None
)
waiting
,
budget
,
None
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
2
assert
len
(
output
.
seq_groups
)
==
2
assert
budget
.
num_batched_tokens
==
120
assert
budget
.
num_batched_tokens
==
120
...
@@ -433,13 +429,13 @@ def test_prefill_schedule_max_seqs():
...
@@ -433,13 +429,13 @@ def test_prefill_schedule_max_seqs():
assert
len
(
remaining_waiting
)
==
1
assert
len
(
remaining_waiting
)
==
1
# Verify curr_num_seqs respected.
# Verify curr_num_seqs respected.
waiting
=
deque
()
scheduler
.
waiting
=
deque
()
budget
=
create_token_budget
(
max_num_seqs
=
2
)
budget
=
create_token_budget
(
max_num_seqs
=
2
)
add_token_budget
(
budget
,
0
,
2
)
add_token_budget
(
budget
,
0
,
2
)
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
waiting
.
append
(
seq_group
)
scheduler
.
add_seq_group
(
seq_group
)
remaining_waiting
,
output
=
scheduler
.
_schedule_prefills
(
output
=
scheduler
.
_schedule_prefills
(
budget
,
None
)
waiting
,
budget
,
None
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
0
assert
budget
.
num_batched_tokens
==
0
assert
budget
.
num_batched_tokens
==
0
...
@@ -453,7 +449,6 @@ def test_prefill_schedule_max_lora():
...
@@ -453,7 +449,6 @@ def test_prefill_schedule_max_lora():
"""
"""
lora_config
=
LoRAConfig
(
max_lora_rank
=
8
,
max_loras
=
1
)
lora_config
=
LoRAConfig
(
max_lora_rank
=
8
,
max_loras
=
1
)
scheduler
=
initialize_scheduler
(
lora_config
=
lora_config
)
scheduler
=
initialize_scheduler
(
lora_config
=
lora_config
)
waiting
:
Deque
[
SequenceGroup
]
=
deque
()
budget
=
create_token_budget
(
token_budget
=
120
)
budget
=
create_token_budget
(
token_budget
=
120
)
curr_loras
:
Set
[
int
]
=
set
()
curr_loras
:
Set
[
int
]
=
set
()
for
i
in
range
(
2
):
for
i
in
range
(
2
):
...
@@ -463,7 +458,7 @@ def test_prefill_schedule_max_lora():
...
@@ -463,7 +458,7 @@ def test_prefill_schedule_max_lora():
lora_name
=
str
(
i
),
lora_name
=
str
(
i
),
lora_int_id
=
i
+
1
,
lora_int_id
=
i
+
1
,
lora_path
=
"abc"
))
lora_path
=
"abc"
))
waiting
.
append
(
seq_group
)
scheduler
.
add_seq_group
(
seq_group
)
# Add two more requests to verify lora is prioritized.
# Add two more requests to verify lora is prioritized.
# 0: Lora, 1: Lora, 2: regular, 3: regular
# 0: Lora, 1: Lora, 2: regular, 3: regular
# In the first iteration, index 0, 2 is scheduled.
# In the first iteration, index 0, 2 is scheduled.
...
@@ -471,10 +466,10 @@ def test_prefill_schedule_max_lora():
...
@@ -471,10 +466,10 @@ def test_prefill_schedule_max_lora():
# prioritized. Verify that.
# prioritized. Verify that.
for
i
in
range
(
2
,
4
):
for
i
in
range
(
2
,
4
):
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
waiting
.
append
(
seq_group
)
scheduler
.
add_seq_group
(
seq_group
)
# Schedule 2 requests (0 and 2)
# Schedule 2 requests (0 and 2)
remaining_waiting
,
output
=
scheduler
.
_schedule_prefills
(
output
=
scheduler
.
_schedule_prefills
(
budget
,
curr_loras
)
waiting
,
budget
,
curr_loras
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
2
assert
len
(
output
.
seq_groups
)
==
2
assert
budget
.
num_batched_tokens
==
120
assert
budget
.
num_batched_tokens
==
120
...
@@ -485,8 +480,8 @@ def test_prefill_schedule_max_lora():
...
@@ -485,8 +480,8 @@ def test_prefill_schedule_max_lora():
# Reset curr_loras so that it can be scheduled.
# Reset curr_loras so that it can be scheduled.
curr_loras
=
set
()
curr_loras
=
set
()
budget
=
create_token_budget
(
token_budget
=
60
)
budget
=
create_token_budget
(
token_budget
=
60
)
remaining_waiting
,
output
=
scheduler
.
_schedule_prefills
(
output
=
scheduler
.
_schedule_prefills
(
budget
,
curr_loras
)
remaining_waiting
,
budget
,
curr_loras
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
seq_groups
)
==
1
assert
len
(
output
.
seq_groups
)
==
1
assert
output
.
seq_groups
[
0
].
seq_group
.
request_id
==
"1"
assert
output
.
seq_groups
[
0
].
seq_group
.
request_id
==
"1"
assert
len
(
remaining_waiting
)
==
1
assert
len
(
remaining_waiting
)
==
1
...
@@ -499,31 +494,29 @@ def test_prefill_schedule_no_block_manager_capacity():
...
@@ -499,31 +494,29 @@ def test_prefill_schedule_no_block_manager_capacity():
Test sequence cannot be scheduled due to block manager has no capacity.
Test sequence cannot be scheduled due to block manager has no capacity.
"""
"""
scheduler
=
initialize_scheduler
()
scheduler
=
initialize_scheduler
()
waiting
:
Deque
[
SequenceGroup
]
=
deque
()
budget
=
create_token_budget
()
budget
=
create_token_budget
()
for
i
in
range
(
3
):
for
i
in
range
(
3
):
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
waiting
.
append
(
seq_group
)
scheduler
.
add_seq_group
(
seq_group
)
scheduler
.
block_manager
.
can_allocate
=
MagicMock
()
scheduler
.
block_manager
.
can_allocate
=
MagicMock
()
scheduler
.
block_manager
.
can_allocate
.
return_value
=
AllocStatus
.
LATER
scheduler
.
block_manager
.
can_allocate
.
return_value
=
AllocStatus
.
LATER
remainig_waiting
,
output
=
scheduler
.
_schedule_prefills
(
output
=
scheduler
.
_schedule_prefills
(
budget
,
None
)
waiting
,
budget
,
None
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
ignored_seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
0
assert
budget
.
num_batched_tokens
==
0
assert
budget
.
num_batched_tokens
==
0
assert
budget
.
num_curr_seqs
==
0
assert
budget
.
num_curr_seqs
==
0
assert
len
(
remainig_waiting
)
==
3
assert
len
(
remaini
n
g_waiting
)
==
3
scheduler
=
initialize_scheduler
()
scheduler
=
initialize_scheduler
()
waiting
=
deque
()
budget
=
create_token_budget
()
budget
=
create_token_budget
()
for
i
in
range
(
3
):
for
i
in
range
(
3
):
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
waiting
.
append
(
seq_group
)
scheduler
.
add_seq_group
(
seq_group
)
scheduler
.
block_manager
.
can_allocate
=
MagicMock
()
scheduler
.
block_manager
.
can_allocate
=
MagicMock
()
scheduler
.
block_manager
.
can_allocate
.
return_value
=
AllocStatus
.
NEVER
scheduler
.
block_manager
.
can_allocate
.
return_value
=
AllocStatus
.
NEVER
remaining_waiting
,
output
=
scheduler
.
_schedule_prefills
(
output
=
scheduler
.
_schedule_prefills
(
budget
,
None
)
waiting
,
budget
,
None
)
remaining_waiting
=
scheduler
.
waiting
assert
len
(
output
.
ignored_seq_groups
)
==
3
assert
len
(
output
.
ignored_seq_groups
)
==
3
assert
len
(
output
.
seq_groups
)
==
0
assert
len
(
output
.
seq_groups
)
==
0
assert
budget
.
num_batched_tokens
==
0
assert
budget
.
num_batched_tokens
==
0
...
@@ -536,14 +529,12 @@ def test_decode_schedule_preempted():
...
@@ -536,14 +529,12 @@ def test_decode_schedule_preempted():
Test decodes cannot be scheduled and preempted.
Test decodes cannot be scheduled and preempted.
"""
"""
scheduler
=
initialize_scheduler
()
scheduler
=
initialize_scheduler
()
running
:
Deque
[
SequenceGroup
]
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
curr_loras
=
None
for
i
in
range
(
3
):
for
i
in
range
(
3
):
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
running
.
append
(
seq_group
)
scheduler
.
_add_seq_group_to_
running
(
seq_group
)
scheduler
.
block_manager
.
can_append_slots
=
MagicMock
()
scheduler
.
block_manager
.
can_append_slots
=
MagicMock
()
def
cannot_append_second_group
(
seq_group
,
num_lookahead_slots
):
def
cannot_append_second_group
(
seq_group
,
num_lookahead_slots
):
...
@@ -555,8 +546,8 @@ def test_decode_schedule_preempted():
...
@@ -555,8 +546,8 @@ def test_decode_schedule_preempted():
# 1 cannot be scheduled, and the lowest priority (request 2)
# 1 cannot be scheduled, and the lowest priority (request 2)
# should be preempted. 1 will also be preempted.
# should be preempted. 1 will also be preempted.
budget
=
create_token_budget
()
budget
=
create_token_budget
()
remainig_running
,
output
=
scheduler
.
_schedule_running
(
output
=
scheduler
.
_schedule_running
(
budget
,
curr_loras
)
running
,
budget
,
curr_loras
,
policy
)
remainig_running
=
scheduler
.
running
assert
len
(
remainig_running
)
==
0
assert
len
(
remainig_running
)
==
0
assert
len
(
output
.
decode_seq_groups
)
==
1
assert
len
(
output
.
decode_seq_groups
)
==
1
assert
len
(
output
.
prefill_seq_groups
)
==
0
assert
len
(
output
.
prefill_seq_groups
)
==
0
...
@@ -577,14 +568,12 @@ def test_decode_swap_beam_search():
...
@@ -577,14 +568,12 @@ def test_decode_swap_beam_search():
Test best_of > 1 swap out blocks
Test best_of > 1 swap out blocks
"""
"""
scheduler
=
initialize_scheduler
()
scheduler
=
initialize_scheduler
()
running
:
Deque
[
SequenceGroup
]
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
curr_loras
=
None
budget
=
create_token_budget
()
budget
=
create_token_budget
()
for
i
in
range
(
3
):
for
i
in
range
(
3
):
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
,
best_of
=
2
)
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
,
best_of
=
2
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
running
.
append
(
seq_group
)
scheduler
.
_add_seq_group_to_
running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
budget
.
add_num_seqs
(
seq_group
.
request_id
,
budget
.
add_num_seqs
(
seq_group
.
request_id
,
seq_group
.
get_max_num_running_seqs
())
seq_group
.
get_max_num_running_seqs
())
...
@@ -603,8 +592,8 @@ def test_decode_swap_beam_search():
...
@@ -603,8 +592,8 @@ def test_decode_swap_beam_search():
expected_swap_mapping
=
[(
"5"
,
"7"
)]
expected_swap_mapping
=
[(
"5"
,
"7"
)]
scheduler
.
block_manager
.
swap_out
.
return_value
=
expected_swap_mapping
scheduler
.
block_manager
.
swap_out
.
return_value
=
expected_swap_mapping
remainig_running
,
output
=
scheduler
.
_schedule_running
(
output
=
scheduler
.
_schedule_running
(
budget
,
curr_loras
)
running
,
budget
,
curr_loras
,
policy
)
remainig_running
=
scheduler
.
running
assert
len
(
remainig_running
)
==
0
assert
len
(
remainig_running
)
==
0
assert
len
(
output
.
decode_seq_groups
)
==
2
assert
len
(
output
.
decode_seq_groups
)
==
2
assert
len
(
output
.
prefill_seq_groups
)
==
0
assert
len
(
output
.
prefill_seq_groups
)
==
0
...
@@ -628,20 +617,18 @@ def test_schedule_decode_blocks_to_copy_update():
...
@@ -628,20 +617,18 @@ def test_schedule_decode_blocks_to_copy_update():
"""
"""
scheduler
=
initialize_scheduler
()
scheduler
=
initialize_scheduler
()
_
,
seq_group
=
create_dummy_prompt
(
"1"
,
prompt_length
=
60
,
best_of
=
2
)
_
,
seq_group
=
create_dummy_prompt
(
"1"
,
prompt_length
=
60
,
best_of
=
2
)
running
:
Deque
[
SequenceGroup
]
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
curr_loras
=
None
scheduler
.
_allocate_and_set_running
(
seq_group
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
running
.
append
(
seq_group
)
scheduler
.
_add_seq_group_to_
running
(
seq_group
)
# The last request should be swapped out.
# The last request should be swapped out.
scheduler
.
block_manager
.
append_slots
=
MagicMock
()
scheduler
.
block_manager
.
append_slots
=
MagicMock
()
scheduler
.
block_manager
.
append_slots
.
return_value
=
[(
2
,
3
)]
scheduler
.
block_manager
.
append_slots
.
return_value
=
[(
2
,
3
)]
budget
=
create_token_budget
()
budget
=
create_token_budget
()
remaining_running
,
output
=
scheduler
.
_schedule_running
(
output
=
scheduler
.
_schedule_running
(
budget
,
curr_loras
)
running
,
budget
,
curr_loras
,
policy
)
remaining_running
=
scheduler
.
running
assert
len
(
remaining_running
)
==
0
assert
len
(
remaining_running
)
==
0
assert
len
(
output
.
decode_seq_groups
)
==
1
assert
len
(
output
.
decode_seq_groups
)
==
1
assert
len
(
output
.
prefill_seq_groups
)
==
0
assert
len
(
output
.
prefill_seq_groups
)
==
0
...
@@ -656,19 +643,17 @@ def test_schedule_decode_blocks_to_copy_update():
...
@@ -656,19 +643,17 @@ def test_schedule_decode_blocks_to_copy_update():
def
test_schedule_swapped_simple
():
def
test_schedule_swapped_simple
():
scheduler
=
initialize_scheduler
()
scheduler
=
initialize_scheduler
()
swapped
:
Deque
[
SequenceGroup
]
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
curr_loras
=
None
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
_
,
seq_group
=
create_dummy_prompt
(
"1"
,
prompt_length
=
60
,
best_of
=
2
)
_
,
seq_group
=
create_dummy_prompt
(
"1"
,
prompt_length
=
60
,
best_of
=
2
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
s
wapped
.
appe
n
d
(
seq_group
)
s
cheduler
.
_add_seq_group_to_sw
apped
(
seq_group
)
budget
=
create_token_budget
()
budget
=
create_token_budget
()
remaining_swapped
,
output
=
scheduler
.
_schedule_swapped
(
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
swapped
,
budget
,
curr_loras
,
policy
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
0
assert
len
(
remaining_swapped
)
==
0
assert
budget
.
num_batched_tokens
==
1
assert
budget
.
num_batched_tokens
==
1
assert
budget
.
num_curr_seqs
==
2
assert
budget
.
num_curr_seqs
==
2
...
@@ -683,8 +668,6 @@ def test_schedule_swapped_simple():
...
@@ -683,8 +668,6 @@ def test_schedule_swapped_simple():
def
test_schedule_swapped_max_token_budget
():
def
test_schedule_swapped_max_token_budget
():
scheduler
=
initialize_scheduler
()
scheduler
=
initialize_scheduler
()
swapped
:
Deque
[
SequenceGroup
]
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
curr_loras
=
None
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
for
_
in
range
(
2
):
for
_
in
range
(
2
):
...
@@ -692,11 +675,11 @@ def test_schedule_swapped_max_token_budget():
...
@@ -692,11 +675,11 @@ def test_schedule_swapped_max_token_budget():
scheduler
.
_allocate_and_set_running
(
seq_group
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
s
wapped
.
appe
n
d
(
seq_group
)
s
cheduler
.
_add_seq_group_to_sw
apped
(
seq_group
)
budget
=
create_token_budget
(
token_budget
=
1
)
budget
=
create_token_budget
(
token_budget
=
1
)
remaining_swapped
,
output
=
scheduler
.
_schedule_swapped
(
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
swapped
,
budget
,
curr_loras
,
policy
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
1
assert
len
(
remaining_swapped
)
==
1
assert
budget
.
num_batched_tokens
==
1
assert
budget
.
num_batched_tokens
==
1
assert
budget
.
num_curr_seqs
==
2
assert
budget
.
num_curr_seqs
==
2
...
@@ -706,8 +689,8 @@ def test_schedule_swapped_max_token_budget():
...
@@ -706,8 +689,8 @@ def test_schedule_swapped_max_token_budget():
# Verify num_batched_tokens are respected.
# Verify num_batched_tokens are respected.
budget
=
create_token_budget
(
token_budget
=
1
)
budget
=
create_token_budget
(
token_budget
=
1
)
add_token_budget
(
budget
,
1
,
0
)
add_token_budget
(
budget
,
1
,
0
)
remaining_swapped
,
output
=
scheduler
.
_schedule_swapped
(
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
remaining_swapped
,
budget
,
curr_loras
,
policy
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
1
assert
len
(
remaining_swapped
)
==
1
assert
budget
.
num_batched_tokens
==
1
assert
budget
.
num_batched_tokens
==
1
assert
budget
.
num_curr_seqs
==
0
assert
budget
.
num_curr_seqs
==
0
...
@@ -717,8 +700,6 @@ def test_schedule_swapped_max_token_budget():
...
@@ -717,8 +700,6 @@ def test_schedule_swapped_max_token_budget():
def
test_schedule_swapped_max_seqs
():
def
test_schedule_swapped_max_seqs
():
scheduler
=
initialize_scheduler
()
scheduler
=
initialize_scheduler
()
swapped
:
Deque
[
SequenceGroup
]
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
curr_loras
=
None
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
for
i
in
range
(
4
):
for
i
in
range
(
4
):
...
@@ -726,11 +707,11 @@ def test_schedule_swapped_max_seqs():
...
@@ -726,11 +707,11 @@ def test_schedule_swapped_max_seqs():
scheduler
.
_allocate_and_set_running
(
seq_group
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
s
wapped
.
appe
n
d
(
seq_group
)
s
cheduler
.
_add_seq_group_to_sw
apped
(
seq_group
)
budget
=
create_token_budget
(
max_num_seqs
=
2
)
budget
=
create_token_budget
(
max_num_seqs
=
2
)
remaining_swapped
,
output
=
scheduler
.
_schedule_swapped
(
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
swapped
,
budget
,
curr_loras
,
policy
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
2
assert
len
(
remaining_swapped
)
==
2
assert
budget
.
num_batched_tokens
==
2
assert
budget
.
num_batched_tokens
==
2
assert
budget
.
num_curr_seqs
==
2
assert
budget
.
num_curr_seqs
==
2
...
@@ -738,8 +719,8 @@ def test_schedule_swapped_max_seqs():
...
@@ -738,8 +719,8 @@ def test_schedule_swapped_max_seqs():
assert
len
(
output
.
prefill_seq_groups
)
==
0
assert
len
(
output
.
prefill_seq_groups
)
==
0
# Verify num_curr_seqs are respected.
# Verify num_curr_seqs are respected.
remaining_swapped
,
output
=
scheduler
.
_schedule_swapped
(
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
remaining_swapped
,
budget
,
curr_loras
,
policy
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
2
assert
len
(
remaining_swapped
)
==
2
assert
budget
.
num_batched_tokens
==
2
assert
budget
.
num_batched_tokens
==
2
assert
budget
.
num_curr_seqs
==
2
assert
budget
.
num_curr_seqs
==
2
...
@@ -750,8 +731,6 @@ def test_schedule_swapped_max_seqs():
...
@@ -750,8 +731,6 @@ def test_schedule_swapped_max_seqs():
def
test_schedule_swapped_max_loras
():
def
test_schedule_swapped_max_loras
():
lora_config
=
LoRAConfig
(
max_lora_rank
=
8
,
max_loras
=
1
)
lora_config
=
LoRAConfig
(
max_lora_rank
=
8
,
max_loras
=
1
)
scheduler
=
initialize_scheduler
(
lora_config
=
lora_config
)
scheduler
=
initialize_scheduler
(
lora_config
=
lora_config
)
swapped
:
Deque
[
SequenceGroup
]
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
:
Set
[
int
]
=
set
()
curr_loras
:
Set
[
int
]
=
set
()
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
for
i
in
range
(
2
):
for
i
in
range
(
2
):
...
@@ -764,11 +743,11 @@ def test_schedule_swapped_max_loras():
...
@@ -764,11 +743,11 @@ def test_schedule_swapped_max_loras():
scheduler
.
_allocate_and_set_running
(
seq_group
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
s
wapped
.
appe
n
d
(
seq_group
)
s
cheduler
.
_add_seq_group_to_sw
apped
(
seq_group
)
budget
=
create_token_budget
()
budget
=
create_token_budget
()
remaining_swapped
,
output
=
scheduler
.
_schedule_swapped
(
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
swapped
,
budget
,
curr_loras
,
policy
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
1
assert
len
(
remaining_swapped
)
==
1
assert
budget
.
num_batched_tokens
==
1
assert
budget
.
num_batched_tokens
==
1
assert
budget
.
num_curr_seqs
==
1
assert
budget
.
num_curr_seqs
==
1
...
@@ -779,8 +758,6 @@ def test_schedule_swapped_max_loras():
...
@@ -779,8 +758,6 @@ def test_schedule_swapped_max_loras():
def
test_schedule_swapped_cannot_swap_in
():
def
test_schedule_swapped_cannot_swap_in
():
scheduler
=
initialize_scheduler
()
scheduler
=
initialize_scheduler
()
swapped
:
Deque
[
SequenceGroup
]
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
curr_loras
=
None
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
for
_
in
range
(
2
):
for
_
in
range
(
2
):
...
@@ -788,15 +765,15 @@ def test_schedule_swapped_cannot_swap_in():
...
@@ -788,15 +765,15 @@ def test_schedule_swapped_cannot_swap_in():
scheduler
.
_allocate_and_set_running
(
seq_group
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
s
wapped
.
appe
n
d
(
seq_group
)
s
cheduler
.
_add_seq_group_to_sw
apped
(
seq_group
)
# The last request should be swapped out.
# The last request should be swapped out.
scheduler
.
block_manager
.
can_swap_in
=
MagicMock
()
scheduler
.
block_manager
.
can_swap_in
=
MagicMock
()
scheduler
.
block_manager
.
can_swap_in
.
return_value
=
AllocStatus
.
LATER
scheduler
.
block_manager
.
can_swap_in
.
return_value
=
AllocStatus
.
LATER
# Since we cannot swap in, none of the requests are swapped in.
# Since we cannot swap in, none of the requests are swapped in.
budget
=
create_token_budget
()
budget
=
create_token_budget
()
remaining_swapped
,
output
=
scheduler
.
_schedule_swapped
(
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
swapped
,
budget
,
curr_loras
,
policy
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
2
assert
len
(
remaining_swapped
)
==
2
assert
budget
.
num_batched_tokens
==
0
assert
budget
.
num_batched_tokens
==
0
assert
budget
.
num_curr_seqs
==
0
assert
budget
.
num_curr_seqs
==
0
...
@@ -806,8 +783,6 @@ def test_schedule_swapped_cannot_swap_in():
...
@@ -806,8 +783,6 @@ def test_schedule_swapped_cannot_swap_in():
def
test_infeasible_swap
():
def
test_infeasible_swap
():
scheduler
=
initialize_scheduler
()
scheduler
=
initialize_scheduler
()
swapped
:
Deque
[
SequenceGroup
]
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
curr_loras
=
None
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
for
_
in
range
(
2
):
for
_
in
range
(
2
):
...
@@ -815,15 +790,15 @@ def test_infeasible_swap():
...
@@ -815,15 +790,15 @@ def test_infeasible_swap():
scheduler
.
_allocate_and_set_running
(
seq_group
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
s
wapped
.
appe
n
d
(
seq_group
)
s
cheduler
.
_add_seq_group_to_sw
apped
(
seq_group
)
# The last request should be swapped out.
# The last request should be swapped out.
scheduler
.
block_manager
.
can_swap_in
=
MagicMock
()
scheduler
.
block_manager
.
can_swap_in
=
MagicMock
()
scheduler
.
block_manager
.
can_swap_in
.
return_value
=
AllocStatus
.
NEVER
scheduler
.
block_manager
.
can_swap_in
.
return_value
=
AllocStatus
.
NEVER
# Since we cannot swap in, none of the requests are swapped in.
# Since we cannot swap in, none of the requests are swapped in.
budget
=
create_token_budget
()
budget
=
create_token_budget
()
remaining_swapped
,
output
=
scheduler
.
_schedule_swapped
(
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
swapped
,
budget
,
curr_loras
,
policy
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
0
assert
len
(
remaining_swapped
)
==
0
assert
len
(
output
.
infeasible_seq_groups
)
==
2
assert
len
(
output
.
infeasible_seq_groups
)
==
2
assert
budget
.
num_batched_tokens
==
0
assert
budget
.
num_batched_tokens
==
0
...
@@ -834,23 +809,21 @@ def test_infeasible_swap():
...
@@ -834,23 +809,21 @@ def test_infeasible_swap():
def
test_schedule_swapped_blocks_to_copy
():
def
test_schedule_swapped_blocks_to_copy
():
scheduler
=
initialize_scheduler
()
scheduler
=
initialize_scheduler
()
swapped
:
Deque
[
SequenceGroup
]
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
curr_loras
=
None
_
,
seq_group
=
create_dummy_prompt
(
"1"
,
prompt_length
=
60
,
best_of
=
2
)
_
,
seq_group
=
create_dummy_prompt
(
"1"
,
prompt_length
=
60
,
best_of
=
2
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
s
wapped
.
appe
n
d
(
seq_group
)
s
cheduler
.
_add_seq_group_to_sw
apped
(
seq_group
)
# The last request should be swapped out.
# The last request should be swapped out.
scheduler
.
block_manager
.
append_slots
=
MagicMock
()
scheduler
.
block_manager
.
append_slots
=
MagicMock
()
scheduler
.
block_manager
.
append_slots
.
return_value
=
[(
2
,
3
)]
scheduler
.
block_manager
.
append_slots
.
return_value
=
[(
2
,
3
)]
budget
=
create_token_budget
()
budget
=
create_token_budget
()
remaining_swapped
,
output
=
scheduler
.
_schedule_swapped
(
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
swapped
,
budget
,
curr_loras
,
policy
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
0
assert
len
(
remaining_swapped
)
==
0
assert
len
(
output
.
decode_seq_groups
)
==
1
assert
len
(
output
.
decode_seq_groups
)
==
1
assert
len
(
output
.
prefill_seq_groups
)
==
0
assert
len
(
output
.
prefill_seq_groups
)
==
0
...
...
vllm/core/policy.py
deleted
100644 → 0
View file @
3c10591e
from
collections
import
deque
from
typing
import
Deque
from
vllm.sequence
import
SequenceGroup
class Policy:
    """Base class for policies that rank sequence groups by priority.

    Subclasses implement `get_priority`; `sort_by_priority` uses it to
    order a queue so that the highest-priority group comes first.
    """

    def get_priority(
        self,
        now: float,
        seq_group: SequenceGroup,
    ) -> float:
        """Return the priority of `seq_group` evaluated at time `now`.

        Larger values are placed earlier by `sort_by_priority`.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError

    def sort_by_priority(
        self,
        now: float,
        seq_groups: Deque[SequenceGroup],
    ) -> Deque[SequenceGroup]:
        """Return a new deque of `seq_groups` in descending priority order.

        Args:
            now: Timestamp forwarded to `get_priority` for every group.
            seq_groups: The groups to order. This deque is not modified.

        Returns:
            A freshly built deque; the group with the largest priority is
            at index 0.
        """
        return deque(
            sorted(
                seq_groups,
                key=lambda seq_group: self.get_priority(now, seq_group),
                # Highest priority first.
                reverse=True,
            ))
class FCFS(Policy):
    """First-come-first-served policy.

    A group's priority is `now` minus its recorded arrival time, so a
    group that arrived earlier gets a larger priority and is placed ahead
    of later arrivals by `Policy.sort_by_priority`.
    """

    def get_priority(
        self,
        now: float,
        seq_group: SequenceGroup,
    ) -> float:
        """Return `now - seq_group.metrics.arrival_time`."""
        return now - seq_group.metrics.arrival_time
class PolicyFactory:
    """Builds `Policy` instances from a registered policy name."""

    # Maps a policy name to the Policy subclass implementing it.
    _POLICY_REGISTRY = {'fcfs': FCFS}

    @classmethod
    def get_policy(cls, policy_name: str, **kwargs) -> Policy:
        """Instantiate the policy registered under `policy_name`.

        Args:
            policy_name: Key into the registry (currently only 'fcfs').
            **kwargs: Forwarded unchanged to the policy class constructor.

        Returns:
            A new instance of the registered policy class.

        Raises:
            KeyError: If `policy_name` is not in the registry.
        """
        return cls._POLICY_REGISTRY[policy_name](**kwargs)
vllm/core/scheduler.py
View file @
c8a7e932
...
@@ -8,7 +8,6 @@ from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
...
@@ -8,7 +8,6 @@ from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.core.interfaces
import
AllocStatus
,
BlockSpaceManager
from
vllm.core.interfaces
import
AllocStatus
,
BlockSpaceManager
from
vllm.core.policy
import
Policy
,
PolicyFactory
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
...
@@ -345,6 +344,16 @@ class Scheduler:
...
@@ -345,6 +344,16 @@ class Scheduler:
# Add sequence groups to the waiting queue.
# Add sequence groups to the waiting queue.
self
.
waiting
.
append
(
seq_group
)
self
.
waiting
.
append
(
seq_group
)
def
_add_seq_group_to_running
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
# Add sequence groups to the running queue.
# Only for testing purposes.
self
.
running
.
append
(
seq_group
)
def
_add_seq_group_to_swapped
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
# Add sequence groups to the swapped queue.
# Only for testing purposes.
self
.
swapped
.
append
(
seq_group
)
def
abort_seq_group
(
self
,
request_id
:
Union
[
str
,
Iterable
[
str
]])
->
None
:
def
abort_seq_group
(
self
,
request_id
:
Union
[
str
,
Iterable
[
str
]])
->
None
:
"""Aborts a sequence group with the given ID.
"""Aborts a sequence group with the given ID.
...
@@ -398,32 +407,26 @@ class Scheduler:
...
@@ -398,32 +407,26 @@ class Scheduler:
def
_schedule_running
(
def
_schedule_running
(
self
,
self
,
running_queue
:
deque
,
budget
:
SchedulingBudget
,
budget
:
SchedulingBudget
,
curr_loras
:
Optional
[
Set
[
int
]],
curr_loras
:
Optional
[
Set
[
int
]],
policy
:
Policy
,
enable_chunking
:
bool
=
False
,
enable_chunking
:
bool
=
False
,
)
->
Tuple
[
deque
,
SchedulerRunningOutputs
]
:
)
->
SchedulerRunningOutputs
:
"""Schedule sequence groups that are running.
"""Schedule sequence groups that are running.
Running queue should include decode and chunked prefill requests.
Running queue should include decode and chunked prefill requests.
Args:
Args:
running_queue: The queue that contains running requests (i.e.,
decodes). The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated
budget: The scheduling budget. The argument is in-place updated
when any decodes are preempted.
when any decodes are preempted.
curr_loras: Currently batched lora request ids. The argument is
curr_loras: Currently batched lora request ids. The argument is
in-place updated when any decodes are preempted.
in-place updated when any decodes are preempted.
policy: The sorting policy to sort running_queue.
enable_chunking: If True, seq group can be chunked and only a
enable_chunking: If True, seq group can be chunked and only a
chunked number of tokens are scheduled if
chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
all tokens.
Returns:
Returns:
A tuple of remaining running queue (should be always 0) after
SchedulerRunningOutputs.
scheduling and SchedulerRunningOutputs.
"""
"""
# Blocks that need to be swapped or copied before model execution.
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
...
@@ -436,10 +439,9 @@ class Scheduler:
...
@@ -436,10 +439,9 @@ class Scheduler:
# NOTE(woosuk): Preemption happens only when there is no available slot
# NOTE(woosuk): Preemption happens only when there is no available slot
# to keep all the sequence groups in the RUNNING state.
# to keep all the sequence groups in the RUNNING state.
# In this case, the policy is responsible for deciding which sequence
# groups to preempt.
running_queue
=
self
.
running
now
=
time
.
time
()
running_queue
=
policy
.
sort_by_priority
(
now
,
running_queue
)
while
running_queue
:
while
running_queue
:
seq_group
=
running_queue
[
0
]
seq_group
=
running_queue
[
0
]
num_running_tokens
=
self
.
_get_num_new_tokens
(
num_running_tokens
=
self
.
_get_num_new_tokens
(
...
@@ -503,7 +505,7 @@ class Scheduler:
...
@@ -503,7 +505,7 @@ class Scheduler:
if
curr_loras
is
not
None
and
seq_group
.
lora_int_id
>
0
:
if
curr_loras
is
not
None
and
seq_group
.
lora_int_id
>
0
:
curr_loras
.
add
(
seq_group
.
lora_int_id
)
curr_loras
.
add
(
seq_group
.
lora_int_id
)
return
running_queue
,
SchedulerRunningOutputs
(
return
SchedulerRunningOutputs
(
decode_seq_groups
=
decode_seq_groups
,
decode_seq_groups
=
decode_seq_groups
,
prefill_seq_groups
=
prefill_seq_groups
,
prefill_seq_groups
=
prefill_seq_groups
,
preempted
=
preempted
,
preempted
=
preempted
,
...
@@ -515,12 +517,10 @@ class Scheduler:
...
@@ -515,12 +517,10 @@ class Scheduler:
def
_schedule_swapped
(
def
_schedule_swapped
(
self
,
self
,
swapped_queue
:
deque
,
budget
:
SchedulingBudget
,
budget
:
SchedulingBudget
,
curr_loras
:
Optional
[
Set
[
int
]],
curr_loras
:
Optional
[
Set
[
int
]],
policy
:
Policy
,
enable_chunking
:
bool
=
False
,
enable_chunking
:
bool
=
False
,
)
->
Tuple
[
deque
,
SchedulerSwappedInOutputs
]
:
)
->
SchedulerSwappedInOutputs
:
"""Schedule sequence groups that are swapped out.
"""Schedule sequence groups that are swapped out.
It schedules swapped requests as long as it fits `budget` and
It schedules swapped requests as long as it fits `budget` and
...
@@ -528,20 +528,16 @@ class Scheduler:
...
@@ -528,20 +528,16 @@ class Scheduler:
`budget` and `curr_loras` are updated based on scheduled seq_groups.
`budget` and `curr_loras` are updated based on scheduled seq_groups.
Args:
Args:
swapped_queue: The queue that contains swapped out requests.
The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated
budget: The scheduling budget. The argument is in-place updated
when any requests are swapped in.
when any requests are swapped in.
curr_loras: Currently batched lora request ids. The argument is
curr_loras: Currently batched lora request ids. The argument is
in-place updated when any requests are swapped in.
in-place updated when any requests are swapped in.
policy: The sorting policy to sort swapped_queue.
enable_chunking: If True, seq group can be chunked and only a
enable_chunking: If True, seq group can be chunked and only a
chunked number of tokens are scheduled if
chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
all tokens.
Returns:
Returns:
A tuple of remaining swapped_queue after scheduling and
SchedulerSwappedInOutputs.
SchedulerSwappedInOutputs.
"""
"""
# Blocks that need to be swapped or copied before model execution.
# Blocks that need to be swapped or copied before model execution.
...
@@ -549,10 +545,10 @@ class Scheduler:
...
@@ -549,10 +545,10 @@ class Scheduler:
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
=
[]
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
=
[]
decode_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
decode_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
prefill_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
prefill_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
now
=
time
.
time
()
swapped_queue
=
policy
.
sort_by_priority
(
now
,
swapped_queue
)
infeasible_seq_groups
:
List
[
SequenceGroup
]
=
[]
infeasible_seq_groups
:
List
[
SequenceGroup
]
=
[]
swapped_queue
=
self
.
swapped
leftover_swapped
:
Deque
[
SequenceGroup
]
=
deque
()
leftover_swapped
:
Deque
[
SequenceGroup
]
=
deque
()
while
swapped_queue
:
while
swapped_queue
:
seq_group
=
swapped_queue
[
0
]
seq_group
=
swapped_queue
[
0
]
...
@@ -617,7 +613,7 @@ class Scheduler:
...
@@ -617,7 +613,7 @@ class Scheduler:
swapped_queue
.
extendleft
(
leftover_swapped
)
swapped_queue
.
extendleft
(
leftover_swapped
)
return
swapped_queue
,
SchedulerSwappedInOutputs
(
return
SchedulerSwappedInOutputs
(
decode_seq_groups
=
decode_seq_groups
,
decode_seq_groups
=
decode_seq_groups
,
prefill_seq_groups
=
prefill_seq_groups
,
prefill_seq_groups
=
prefill_seq_groups
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_in
=
blocks_to_swap_in
,
...
@@ -644,11 +640,10 @@ class Scheduler:
...
@@ -644,11 +640,10 @@ class Scheduler:
def
_schedule_prefills
(
def
_schedule_prefills
(
self
,
self
,
waiting_queue
:
deque
,
budget
:
SchedulingBudget
,
budget
:
SchedulingBudget
,
curr_loras
:
Optional
[
Set
[
int
]],
curr_loras
:
Optional
[
Set
[
int
]],
enable_chunking
:
bool
=
False
,
enable_chunking
:
bool
=
False
,
)
->
Tuple
[
deque
,
SchedulerPrefillOutputs
]
:
)
->
SchedulerPrefillOutputs
:
"""Schedule sequence groups that are in prefill stage.
"""Schedule sequence groups that are in prefill stage.
Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
...
@@ -660,8 +655,6 @@ class Scheduler:
...
@@ -660,8 +655,6 @@ class Scheduler:
`budget` and `curr_loras` are updated based on scheduled seq_groups.
`budget` and `curr_loras` are updated based on scheduled seq_groups.
Args:
Args:
waiting_queue: The queue that contains prefill requests.
The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated
budget: The scheduling budget. The argument is in-place updated
when any requests are scheduled.
when any requests are scheduled.
curr_loras: Currently batched lora request ids. The argument is
curr_loras: Currently batched lora request ids. The argument is
...
@@ -672,14 +665,12 @@ class Scheduler:
...
@@ -672,14 +665,12 @@ class Scheduler:
all tokens.
all tokens.
Returns:
Returns:
A tuple of remaining waiting_queue after scheduling and
SchedulerSwappedInOutputs.
SchedulerSwappedInOutputs.
"""
"""
ignored_seq_groups
:
List
[
SequenceGroup
]
=
[]
ignored_seq_groups
:
List
[
SequenceGroup
]
=
[]
seq_groups
:
List
[
SequenceGroup
]
=
[]
seq_groups
:
List
[
SequenceGroup
]
=
[]
# We don't sort waiting queue because we assume it is sorted.
# Copy the queue so that the input queue is not modified.
waiting_queue
=
self
.
waiting
waiting_queue
=
deque
([
s
for
s
in
waiting_queue
])
leftover_waiting_sequences
:
Deque
[
SequenceGroup
]
=
deque
()
leftover_waiting_sequences
:
Deque
[
SequenceGroup
]
=
deque
()
while
self
.
_passed_delay
(
time
.
time
())
and
waiting_queue
:
while
self
.
_passed_delay
(
time
.
time
())
and
waiting_queue
:
...
@@ -758,7 +749,7 @@ class Scheduler:
...
@@ -758,7 +749,7 @@ class Scheduler:
if
len
(
seq_groups
)
>
0
:
if
len
(
seq_groups
)
>
0
:
self
.
prev_prompt
=
True
self
.
prev_prompt
=
True
return
waiting_queue
,
SchedulerPrefillOutputs
(
return
SchedulerPrefillOutputs
(
seq_groups
=
seq_groups
,
seq_groups
=
seq_groups
,
ignored_seq_groups
=
ignored_seq_groups
,
ignored_seq_groups
=
ignored_seq_groups
,
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
is_prefill
=
True
))
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
is_prefill
=
True
))
...
@@ -785,53 +776,43 @@ class Scheduler:
...
@@ -785,53 +776,43 @@ class Scheduler:
seq_group
.
lora_int_id
for
seq_group
in
self
.
running
seq_group
.
lora_int_id
for
seq_group
in
self
.
running
if
seq_group
.
lora_int_id
>
0
)
if
self
.
lora_enabled
else
None
if
seq_group
.
lora_int_id
>
0
)
if
self
.
lora_enabled
else
None
remaining_waiting
,
prefills
=
(
self
.
waiting
,
prefills
=
SchedulerPrefillOutputs
.
create_empty
()
SchedulerPrefillOutputs
.
create_empty
())
running_scheduled
=
SchedulerRunningOutputs
.
create_empty
()
remaining_running
,
running_scheduled
=
(
swapped_in
=
SchedulerSwappedInOutputs
.
create_empty
()
self
.
running
,
SchedulerRunningOutputs
.
create_empty
())
remaining_swapped
,
swapped_in
=
(
self
.
swapped
,
SchedulerSwappedInOutputs
.
create_empty
())
# If any requests are swapped, prioritized swapped requests.
# If any requests are swapped, prioritized swapped requests.
if
not
self
.
swapped
:
if
not
self
.
swapped
:
remaining_waiting
,
prefills
=
self
.
_schedule_prefills
(
prefills
=
self
.
_schedule_prefills
(
budget
,
self
.
waiting
,
budget
,
curr_loras
,
enable_chunking
=
False
)
curr_loras
,
enable_chunking
=
False
)
fcfs_policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
# Don't schedule decodes if prefills are scheduled.
# Don't schedule decodes if prefills are scheduled.
# NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
# NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
# only contains decode requests, not chunked prefills.
# only contains decode requests, not chunked prefills.
if
len
(
prefills
.
seq_groups
)
==
0
:
if
len
(
prefills
.
seq_groups
)
==
0
:
remaining_running
,
running_scheduled
=
self
.
_schedule_running
(
running_scheduled
=
self
.
_schedule_running
(
budget
,
self
.
running
,
budget
,
curr_loras
,
curr_loras
,
fcfs_policy
,
enable_chunking
=
False
)
enable_chunking
=
False
)
# If any sequence group is preempted, do not swap in any sequence
# If any sequence group is preempted, do not swap in any sequence
# group. because it means there's no slot for new running requests.
# group. because it means there's no slot for new running requests.
if
len
(
running_scheduled
.
preempted
)
+
len
(
if
len
(
running_scheduled
.
preempted
)
+
len
(
running_scheduled
.
swapped_out
)
==
0
:
running_scheduled
.
swapped_out
)
==
0
:
remaining_swapped
,
swapped_in
=
self
.
_schedule_swapped
(
swapped_in
=
self
.
_schedule_swapped
(
budget
,
curr_loras
)
self
.
swapped
,
budget
,
curr_loras
,
fcfs_policy
)
assert
(
budget
.
num_batched_tokens
<=
assert
(
budget
.
num_batched_tokens
<=
self
.
scheduler_config
.
max_num_batched_tokens
)
self
.
scheduler_config
.
max_num_batched_tokens
)
assert
budget
.
num_curr_seqs
<=
self
.
scheduler_config
.
max_num_seqs
assert
budget
.
num_curr_seqs
<=
self
.
scheduler_config
.
max_num_seqs
# Update waiting requests.
# Update waiting requests.
self
.
waiting
=
remaining_waiting
self
.
waiting
.
extendleft
(
running_scheduled
.
preempted
)
self
.
waiting
.
extendleft
(
running_scheduled
.
preempted
)
# Update new running requests.
# Update new running requests.
self
.
running
=
remaining_running
self
.
running
.
extend
([
s
.
seq_group
for
s
in
prefills
.
seq_groups
])
self
.
running
.
extend
([
s
.
seq_group
for
s
in
prefills
.
seq_groups
])
self
.
running
.
extend
(
self
.
running
.
extend
(
[
s
.
seq_group
for
s
in
running_scheduled
.
decode_seq_groups
])
[
s
.
seq_group
for
s
in
running_scheduled
.
decode_seq_groups
])
self
.
running
.
extend
(
self
.
running
.
extend
(
[
s
.
seq_group
for
s
in
swapped_in
.
decode_seq_groups
])
[
s
.
seq_group
for
s
in
swapped_in
.
decode_seq_groups
])
# Update swapped requests.
# Update swapped requests.
self
.
swapped
=
remaining_swapped
self
.
swapped
.
extend
(
running_scheduled
.
swapped_out
)
self
.
swapped
.
extend
(
running_scheduled
.
swapped_out
)
preempted
=
(
len
(
running_scheduled
.
preempted
)
+
preempted
=
(
len
(
running_scheduled
.
preempted
)
+
len
(
running_scheduled
.
swapped_out
))
len
(
running_scheduled
.
swapped_out
))
...
@@ -877,42 +858,32 @@ class Scheduler:
...
@@ -877,42 +858,32 @@ class Scheduler:
)
)
curr_loras
:
Set
[
int
]
=
set
()
curr_loras
:
Set
[
int
]
=
set
()
remaining_waiting
,
prefills
=
(
self
.
waiting
,
prefills
=
SchedulerPrefillOutputs
.
create_empty
()
SchedulerPrefillOutputs
.
create_empty
())
swapped_in
=
SchedulerSwappedInOutputs
.
create_empty
()
remaining_running
,
running_scheduled
=
(
self
.
running
,
SchedulerRunningOutputs
.
create_empty
())
remaining_swapped
,
swapped_in
=
(
self
.
swapped
,
SchedulerSwappedInOutputs
.
create_empty
())
# Decoding should be always scheduled first by fcfs.
# Decoding should be always scheduled first by fcfs.
fcfs_policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
running_scheduled
=
self
.
_schedule_running
(
budget
,
remaining_running
,
running_scheduled
=
self
.
_schedule_running
(
self
.
running
,
budget
,
curr_loras
,
curr_loras
,
fcfs_policy
,
enable_chunking
=
True
)
enable_chunking
=
True
)
# Schedule swapped out requests.
# Schedule swapped out requests.
# If preemption happens, it means we don't have space for swap-in.
# If preemption happens, it means we don't have space for swap-in.
if
len
(
running_scheduled
.
preempted
)
+
len
(
if
len
(
running_scheduled
.
preempted
)
+
len
(
running_scheduled
.
swapped_out
)
==
0
:
running_scheduled
.
swapped_out
)
==
0
:
remaining_swapped
,
swapped_in
=
self
.
_schedule_swapped
(
swapped_in
=
self
.
_schedule_swapped
(
budget
,
curr_loras
)
self
.
swapped
,
budget
,
curr_loras
,
fcfs_policy
)
# Schedule new prefills.
# Schedule new prefills.
remaining_waiting
,
prefills
=
self
.
_schedule_prefills
(
prefills
=
self
.
_schedule_prefills
(
budget
,
self
.
waiting
,
budget
,
curr_loras
,
enable_chunking
=
True
)
curr_loras
,
enable_chunking
=
True
)
assert
(
budget
.
num_batched_tokens
<=
assert
(
budget
.
num_batched_tokens
<=
self
.
scheduler_config
.
max_num_batched_tokens
)
self
.
scheduler_config
.
max_num_batched_tokens
)
assert
budget
.
num_curr_seqs
<=
self
.
scheduler_config
.
max_num_seqs
assert
budget
.
num_curr_seqs
<=
self
.
scheduler_config
.
max_num_seqs
# Update waiting requests.
# Update waiting requests.
self
.
waiting
=
remaining_waiting
self
.
waiting
.
extendleft
(
running_scheduled
.
preempted
)
self
.
waiting
.
extendleft
(
running_scheduled
.
preempted
)
# Update new running requests.
# Update new running requests.
self
.
running
=
remaining_running
self
.
running
.
extend
([
s
.
seq_group
for
s
in
prefills
.
seq_groups
])
self
.
running
.
extend
([
s
.
seq_group
for
s
in
prefills
.
seq_groups
])
self
.
running
.
extend
(
self
.
running
.
extend
(
[
s
.
seq_group
for
s
in
running_scheduled
.
decode_seq_groups
])
[
s
.
seq_group
for
s
in
running_scheduled
.
decode_seq_groups
])
...
@@ -923,7 +894,6 @@ class Scheduler:
...
@@ -923,7 +894,6 @@ class Scheduler:
self
.
running
.
extend
(
self
.
running
.
extend
(
[
s
.
seq_group
for
s
in
swapped_in
.
prefill_seq_groups
])
[
s
.
seq_group
for
s
in
swapped_in
.
prefill_seq_groups
])
# Update swapped requests.
# Update swapped requests.
self
.
swapped
=
remaining_swapped
self
.
swapped
.
extend
(
running_scheduled
.
swapped_out
)
self
.
swapped
.
extend
(
running_scheduled
.
swapped_out
)
return
SchedulerOutputs
(
return
SchedulerOutputs
(
scheduled_seq_groups
=
(
prefills
.
seq_groups
+
scheduled_seq_groups
=
(
prefills
.
seq_groups
+
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment