Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xdb4_94051
vllm
Commits
53f70e73
Commit
53f70e73
authored
Feb 24, 2023
by
Woosuk Kwon
Browse files
Reduce the number of states in scheduler
parent
762fd1c3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
44 additions
and
29 deletions
+44
-29
cacheflow/master/scheduler.py
cacheflow/master/scheduler.py
+44
-29
No files found.
cacheflow/master/scheduler.py
View file @
53f70e73
...
@@ -43,11 +43,6 @@ class Scheduler:
...
@@ -43,11 +43,6 @@ class Scheduler:
# Pending sequence groups (FIFO).
# Pending sequence groups (FIFO).
self
.
pending
:
List
[
SequenceGroup
]
=
[]
self
.
pending
:
List
[
SequenceGroup
]
=
[]
# Blocks that need to be swapped or copied before model execution.
self
.
blocks_to_swap_in
:
Dict
[
int
,
int
]
=
{}
self
.
blocks_to_swap_out
:
Dict
[
int
,
int
]
=
{}
self
.
blocks_to_copy
:
Dict
[
int
,
int
]
=
{}
def
_free_seq
(
self
,
seq
:
Sequence
)
->
None
:
def
_free_seq
(
self
,
seq
:
Sequence
)
->
None
:
seq
.
status
=
SequenceStatus
.
FINISHED
seq
.
status
=
SequenceStatus
.
FINISHED
self
.
block_manager
.
free
(
seq
)
self
.
block_manager
.
free
(
seq
)
...
@@ -57,36 +52,53 @@ class Scheduler:
...
@@ -57,36 +52,53 @@ class Scheduler:
for
seq
in
seq_group
.
seqs
:
for
seq
in
seq_group
.
seqs
:
seq
.
status
=
SequenceStatus
.
RUNNING
seq
.
status
=
SequenceStatus
.
RUNNING
self
.
running
.
append
(
seq_group
)
self
.
running
.
append
(
seq_group
)
# FIXME
# FIXME
(woosuk): Support interactive generation.
self
.
num_steps
[
seq_group
.
group_id
]
=
0
self
.
num_steps
[
seq_group
.
group_id
]
=
0
def
_append
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
def
_append
(
self
,
seq_group
:
SequenceGroup
,
blocks_to_copy
:
Dict
[
int
,
int
],
)
->
None
:
for
seq
in
seq_group
.
seqs
:
for
seq
in
seq_group
.
seqs
:
if
seq
.
status
==
SequenceStatus
.
FINISHED
:
if
seq
.
status
==
SequenceStatus
.
FINISHED
:
continue
continue
ret
=
self
.
block_manager
.
append
(
seq
)
ret
=
self
.
block_manager
.
append
(
seq
)
if
ret
is
not
None
:
if
ret
is
not
None
:
src_block
,
dst_block
=
ret
src_block
,
dst_block
=
ret
self
.
blocks_to_copy
[
src_block
]
=
dst_block
blocks_to_copy
[
src_block
]
=
dst_block
def
_swap_in
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
def
_swap_in
(
self
,
seq_group
:
SequenceGroup
,
blocks_to_swap_in
:
Dict
[
int
,
int
],
)
->
None
:
mapping
=
self
.
block_manager
.
swap_in
(
seq_group
)
mapping
=
self
.
block_manager
.
swap_in
(
seq_group
)
self
.
blocks_to_swap_in
.
update
(
mapping
)
blocks_to_swap_in
.
update
(
mapping
)
for
seq
in
seq_group
.
seqs
:
for
seq
in
seq_group
.
seqs
:
if
seq
.
status
==
SequenceStatus
.
SWAPPED
:
if
seq
.
status
==
SequenceStatus
.
SWAPPED
:
seq
.
status
=
SequenceStatus
.
RUNNING
seq
.
status
=
SequenceStatus
.
RUNNING
self
.
running
.
append
(
seq_group
)
self
.
running
.
append
(
seq_group
)
def
_swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
def
_swap_out
(
self
,
seq_group
:
SequenceGroup
,
blocks_to_swap_out
:
Dict
[
int
,
int
],
)
->
None
:
assert
self
.
block_manager
.
can_swap_out
(
seq_group
)
assert
self
.
block_manager
.
can_swap_out
(
seq_group
)
mapping
=
self
.
block_manager
.
swap_out
(
seq_group
)
mapping
=
self
.
block_manager
.
swap_out
(
seq_group
)
self
.
blocks_to_swap_out
.
update
(
mapping
)
blocks_to_swap_out
.
update
(
mapping
)
for
seq
in
seq_group
.
seqs
:
for
seq
in
seq_group
.
seqs
:
if
seq
.
status
==
SequenceStatus
.
RUNNING
:
if
seq
.
status
==
SequenceStatus
.
RUNNING
:
seq
.
status
=
SequenceStatus
.
SWAPPED
seq
.
status
=
SequenceStatus
.
SWAPPED
self
.
swapped
.
append
(
seq_group
)
self
.
swapped
.
append
(
seq_group
)
def
prepare
(
self
)
->
None
:
def
pre_step
(
self
)
->
None
:
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_in
:
Dict
[
int
,
int
]
=
{}
blocks_to_swap_out
:
Dict
[
int
,
int
]
=
{}
blocks_to_copy
:
Dict
[
int
,
int
]
=
{}
# 1. Prepare new slots for the running sequences.
# 1. Prepare new slots for the running sequences.
# NOTE: Here we implicitly assume FCFS scheduling.
# NOTE: Here we implicitly assume FCFS scheduling.
# That is, the most recently added sequence group is the first
# That is, the most recently added sequence group is the first
...
@@ -99,13 +111,13 @@ class Scheduler:
...
@@ -99,13 +111,13 @@ class Scheduler:
# OOM. Swap out the victim sequence groups.
# OOM. Swap out the victim sequence groups.
while
not
self
.
block_manager
.
can_append
(
seq_group
):
while
not
self
.
block_manager
.
can_append
(
seq_group
):
victim_seq_group
=
self
.
running
[
victim_idx
]
victim_seq_group
=
self
.
running
[
victim_idx
]
self
.
_swap_out
(
victim_seq_group
)
self
.
_swap_out
(
victim_seq_group
,
blocks_to_swap_out
)
victim_idx
-=
1
victim_idx
-=
1
if
i
>
victim_idx
:
if
i
>
victim_idx
:
# No other sequence groups can be swapped out.
# No other sequence groups can be swapped out.
break
break
else
:
else
:
self
.
_append
(
seq_group
)
self
.
_append
(
seq_group
,
blocks_to_copy
)
self
.
running
=
self
.
running
[:
victim_idx
+
1
]
self
.
running
=
self
.
running
[:
victim_idx
+
1
]
# 2. Swap in the swapped sequences if possible.
# 2. Swap in the swapped sequences if possible.
...
@@ -113,8 +125,8 @@ class Scheduler:
...
@@ -113,8 +125,8 @@ class Scheduler:
# The swapped sequences are in LIFO order.
# The swapped sequences are in LIFO order.
for
i
,
seq_group
in
enumerate
(
reversed
(
self
.
swapped
)):
for
i
,
seq_group
in
enumerate
(
reversed
(
self
.
swapped
)):
if
self
.
block_manager
.
can_swap_in
(
seq_group
):
if
self
.
block_manager
.
can_swap_in
(
seq_group
):
self
.
_swap_in
(
seq_group
)
self
.
_swap_in
(
seq_group
,
blocks_to_swap_in
)
self
.
_append
(
seq_group
)
self
.
_append
(
seq_group
,
blocks_to_copy
)
else
:
else
:
# OOM. Stop swapping.
# OOM. Stop swapping.
self
.
swapped
=
self
.
swapped
[:
len
(
self
.
swapped
)
-
i
]
self
.
swapped
=
self
.
swapped
[:
len
(
self
.
swapped
)
-
i
]
...
@@ -147,10 +159,18 @@ class Scheduler:
...
@@ -147,10 +159,18 @@ class Scheduler:
else
:
else
:
self
.
pending
.
clear
()
self
.
pending
.
clear
()
def
step
(
self
)
->
None
:
# Execute step.
# Ensure that either swap-in or swap-out is performed.
self
.
step
(
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
)
if
self
.
blocks_to_swap_in
:
assert
not
self
.
blocks_to_swap_out
def
step
(
self
,
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
int
],
)
->
None
:
# Ensure that swap-in and swap-out never happen at the same timestep.
if
blocks_to_swap_in
:
assert
not
blocks_to_swap_out
# Create input data structures.
# Create input data structures.
prompt_tokens
:
Dict
[
int
,
List
[
int
]]
=
{}
prompt_tokens
:
Dict
[
int
,
List
[
int
]]
=
{}
...
@@ -181,16 +201,11 @@ class Scheduler:
...
@@ -181,16 +201,11 @@ class Scheduler:
generation_tokens
,
generation_tokens
,
context_lens
,
context_lens
,
block_tables
,
block_tables
,
self
.
blocks_to_swap_in
.
copy
()
,
blocks_to_swap_in
,
self
.
blocks_to_swap_out
.
copy
()
,
blocks_to_swap_out
,
self
.
blocks_to_copy
.
copy
()
,
blocks_to_copy
,
)
)
# Clear for the next step.
self
.
blocks_to_swap_in
.
clear
()
self
.
blocks_to_swap_out
.
clear
()
self
.
blocks_to_copy
.
clear
()
def
post_step
(
def
post_step
(
self
,
self
,
next_tokens
:
Dict
[
int
,
Tuple
[
int
,
int
]],
next_tokens
:
Dict
[
int
,
Tuple
[
int
,
int
]],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment