norm / vllm · Commits · 05921a9a

Unverified commit 05921a9a, authored Jan 07, 2024 by Nadav Shmayovits, committed by GitHub on Jan 07, 2024.

    Changed scheduler to use deques instead of lists (#2290)

    Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

Parent: d0215a58

Showing 3 changed files with 28 additions and 24 deletions:

    vllm/core/policy.py        +10  -8
    vllm/core/scheduler.py     +14  -14
    vllm/engine/llm_engine.py  +4   -2
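The motivation behind this commit is the usual list-vs-deque trade-off: list.pop(0) and list.insert(0, x) shift every remaining element and cost O(n), while collections.deque supports popleft() and appendleft() in O(1). The diff below even removes a TODO comment to that effect. A minimal sketch, not part of this commit, illustrating the difference on a FIFO drain (sizes and timings are only indicative):

# Minimal sketch (not part of this commit): FIFO drain with list vs. deque.
import timeit
from collections import deque

N = 50_000

def drain_list() -> None:
    q = list(range(N))
    while q:
        q.pop(0)        # O(n): shifts every remaining element

def drain_deque() -> None:
    q = deque(range(N))
    while q:
        q.popleft()     # O(1): no shifting

if __name__ == "__main__":
    print("list.pop(0):    ", timeit.timeit(drain_list, number=1))
    print("deque.popleft():", timeit.timeit(drain_deque, number=1))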
vllm/core/policy.py  (+10 -8)  view file @ 05921a9a

-from typing import List
+from collections import deque
+from typing import Deque

 from vllm.sequence import SequenceGroup
...
@@ -15,13 +16,14 @@ class Policy:
     def sort_by_priority(
         self,
         now: float,
-        seq_groups: List[SequenceGroup],
-    ) -> List[SequenceGroup]:
-        return sorted(
-            seq_groups,
-            key=lambda seq_group: self.get_priority(now, seq_group),
-            reverse=True,
-        )
+        seq_groups: Deque[SequenceGroup],
+    ) -> Deque[SequenceGroup]:
+        return deque(
+            sorted(
+                seq_groups,
+                key=lambda seq_group: self.get_priority(now, seq_group),
+                reverse=True,
+            ))


 class FCFS(Policy):
...
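For context, a hedged sketch of the new deque-in/deque-out shape of sort_by_priority: sorted() always returns a list, so the result is rewrapped in a deque before being handed back to the scheduler, which can then popleft() the highest-priority group in O(1). The class and data-type names below are simplified stand-ins, not the real vLLM classes.

# Hedged sketch, not the actual vLLM code: a simplified FCFS-style policy
# mirroring the Deque-based signature introduced by this commit.
from collections import deque
from dataclasses import dataclass
from typing import Deque


@dataclass
class FakeSeqGroup:           # stand-in for vllm.sequence.SequenceGroup
    request_id: str
    arrival_time: float


class FcfsLikePolicy:
    def get_priority(self, now: float, seq_group: FakeSeqGroup) -> float:
        # Earlier arrival => larger priority (first come, first served).
        return now - seq_group.arrival_time

    def sort_by_priority(
        self,
        now: float,
        seq_groups: Deque[FakeSeqGroup],
    ) -> Deque[FakeSeqGroup]:
        # Same pattern as the commit: sort, then rewrap as a deque.
        return deque(
            sorted(seq_groups,
                   key=lambda g: self.get_priority(now, g),
                   reverse=True))


if __name__ == "__main__":
    q: Deque[FakeSeqGroup] = deque(
        [FakeSeqGroup("b", 2.0), FakeSeqGroup("a", 1.0)])
    q = FcfsLikePolicy().sort_by_priority(now=3.0, seq_groups=q)
    assert q.popleft().request_id == "a"   # oldest request comes out first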
vllm/core/scheduler.py  (+14 -14)  view file @ 05921a9a

+from collections import deque
 import enum
 import time
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from typing import Deque, Dict, Iterable, List, Optional, Tuple, Union

 from vllm.config import CacheConfig, SchedulerConfig
 from vllm.core.block_manager import AllocStatus, BlockSpaceManager
...
@@ -29,7 +30,7 @@ class SchedulerOutputs:
     def __init__(
         self,
-        scheduled_seq_groups: List[SequenceGroup],
+        scheduled_seq_groups: Iterable[SequenceGroup],
         prompt_run: bool,
         num_batched_tokens: int,
         blocks_to_swap_in: Dict[int, int],
...
@@ -75,13 +76,12 @@ class Scheduler:
             num_cpu_blocks=self.cache_config.num_cpu_blocks,
             sliding_window=self.cache_config.sliding_window)

-        # TODO(zhuohan): Use deque instead of list for better performance.
         # Sequence groups in the WAITING state.
-        self.waiting: List[SequenceGroup] = []
+        self.waiting: Deque[SequenceGroup] = deque()
         # Sequence groups in the RUNNING state.
-        self.running: List[SequenceGroup] = []
+        self.running: Deque[SequenceGroup] = deque()
         # Sequence groups in the SWAPPED state.
-        self.swapped: List[SequenceGroup] = []
+        self.swapped: Deque[SequenceGroup] = deque()

     def add_seq_group(self, seq_group: SequenceGroup) -> None:
         # Add sequence groups to the waiting queue.
...
@@ -152,7 +152,7 @@ class Scheduler:
                 for seq in waiting_seqs:
                     seq.status = SequenceStatus.FINISHED_IGNORED
                 ignored_seq_groups.append(seq_group)
-                self.waiting.pop(0)
+                self.waiting.popleft()
                 continue

             # If the sequence group cannot be allocated, stop.
...
@@ -166,7 +166,7 @@ class Scheduler:
                 for seq in waiting_seqs:
                     seq.status = SequenceStatus.FINISHED_IGNORED
                 ignored_seq_groups.append(seq_group)
-                self.waiting.pop(0)
+                self.waiting.popleft()
                 continue

             # If the number of batched tokens exceeds the limit, stop.
...
@@ -188,7 +188,7 @@ class Scheduler:
                     break
                 seq_lens = new_seq_lens

-                seq_group = self.waiting.pop(0)
+                seq_group = self.waiting.popleft()
                 self._allocate(seq_group)
                 self.running.append(seq_group)
                 num_curr_seqs += num_new_seqs
...
@@ -214,14 +214,14 @@ class Scheduler:
         self.running = self.policy.sort_by_priority(now, self.running)

         # Reserve new token slots for the running sequence groups.
-        running: List[SequenceGroup] = []
+        running: Deque[SequenceGroup] = deque()
         preempted: List[SequenceGroup] = []
         while self.running:
-            seq_group = self.running.pop(0)
+            seq_group = self.running.popleft()
             while not self.block_manager.can_append_slot(seq_group):
                 if self.running:
                     # Preempt the lowest-priority sequence groups.
-                    victim_seq_group = self.running.pop(-1)
+                    victim_seq_group = self.running.pop()
                     self._preempt(victim_seq_group, blocks_to_swap_out)
                     preempted.append(victim_seq_group)
                 else:
...
@@ -255,7 +255,7 @@ class Scheduler:
                     self.scheduler_config.max_num_seqs):
                 break

-            seq_group = self.swapped.pop(0)
+            seq_group = self.swapped.popleft()
             self._swap_in(seq_group, blocks_to_swap_in)
             self._append_slot(seq_group, blocks_to_copy)
             num_curr_seqs += num_new_seqs
...
@@ -376,7 +376,7 @@ class Scheduler:
             self.block_manager.free(seq)
         # NOTE: For FCFS, we insert the preempted sequence group to the front
         # of the waiting queue.
-        self.waiting.insert(0, seq_group)
+        self.waiting.appendleft(seq_group)

     def _preempt_by_swap(
         self,
...
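The scheduler.py hunks are a mechanical mapping of list operations onto their deque equivalents. A hedged summary sketch of that mapping (illustrative only; the queue and its contents are hypothetical stand-ins for self.waiting, self.running, and self.swapped):

# Hedged sketch of the list -> deque mapping used throughout scheduler.py.
from collections import deque

queue = deque(["req-0", "req-1", "req-2"])

head = queue.popleft()    # replaces list.pop(0): O(1) instead of O(n)
tail = queue.pop()        # replaces list.pop(-1): drop the lowest-priority victim
queue.appendleft(head)    # replaces list.insert(0, x): put a preempted group back
                          # at the front, preserving FCFS order
queue.append("req-3")     # unchanged: append works the same on both types

print(list(queue))        # ['req-0', 'req-1', 'req-3']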
vllm/engine/llm_engine.py  (+4 -2)  view file @ 05921a9a

...
@@ -601,8 +601,10 @@ class LLMEngine:
         # Create the outputs.
         request_outputs: List[RequestOutput] = []
-        for seq_group in (scheduled_seq_groups +
-                          scheduler_outputs.ignored_seq_groups):
+        for seq_group in scheduled_seq_groups:
+            request_output = RequestOutput.from_seq_group(seq_group)
+            request_outputs.append(request_output)
+        for seq_group in scheduler_outputs.ignored_seq_groups:
             request_output = RequestOutput.from_seq_group(seq_group)
             request_outputs.append(request_output)
...
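The engine-side change follows from the scheduler change: scheduled_seq_groups is now typed as an Iterable and may be backed by a deque, and a deque cannot be concatenated with a list using +, so the single loop over the concatenation is split into two loops. A minimal sketch of the incompatibility (illustrative only; the variable names are not from the commit):

# Minimal sketch: why the single loop over (scheduled + ignored) was split.
from collections import deque

scheduled = deque(["s0", "s1"])   # may be deque-backed after this commit
ignored = ["i0"]                  # still a plain list

try:
    combined = scheduled + ignored        # TypeError: can only concatenate deque to deque
except TypeError:
    combined = list(scheduled) + ignored  # one workaround; the commit instead
                                          # iterates each collection separately

outputs = []
for group in scheduled:
    outputs.append(group)
for group in ignored:
    outputs.append(group)

assert outputs == ["s0", "s1", "i0"]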