norm / vllm · Commits

Commit 1132fae0, authored Feb 24, 2023 by Woosuk Kwon
Parent: 46ce1356

Add Frontend

Showing 2 changed files with 80 additions and 17 deletions:

    cacheflow/master/frontend.py    +56  -0
    cacheflow/master/scheduler.py   +24  -17
cacheflow/master/frontend.py (new file, mode 100644)
from typing import List, Optional, Tuple

from transformers import AutoTokenizer

from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
from cacheflow.utils import Counter


class Frontend:

    def __init__(
        self,
        model_name: str,
        block_size: int,
    ) -> None:
        self.block_size = block_size
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.seq_group_counter = Counter()
        self.seq_counter = Counter()
        self.inputs: List[Tuple[SequenceGroup, SamplingParams]] = []

    def query(
        self,
        prompt: str,
        sampling_params: Optional[SamplingParams] = None,
    ) -> None:
        if sampling_params is None:
            sampling_params = SamplingParams()
        token_ids: List[int] = self.tokenizer.encode(prompt)
        seqs: List[Sequence] = []
        for _ in range(sampling_params.n):
            seq_id = next(self.seq_counter)
            seq = Sequence(seq_id, token_ids, block_size=self.block_size)
            seqs.append(seq)
        group_id = next(self.seq_group_counter)
        seq_group = SequenceGroup(group_id, seqs)
        self.inputs.append((seq_group, sampling_params))

    def get_inputs(self) -> List[Tuple[SequenceGroup, SamplingParams]]:
        inputs = self.inputs
        self.inputs = []
        return inputs

    def print_response(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        for seq in seq_group.seqs:
            token_ids = seq.get_token_ids()
            output = self.tokenizer.decode(token_ids, skip_special_tokens=True)
            print(f'Seq {seq.seq_id}: {output}')
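Taken together, Frontend tokenizes a prompt once, fans it out into sampling_params.n candidate sequences sharing the same token ids, and buffers the resulting (seq_group, sampling_params) tuple until the scheduler asks for it. A minimal driver sketch of that flow, assuming cacheflow is importable and the model weights are available locally (the model name mirrors the one used elsewhere in this commit; the block size is an illustrative value):

from cacheflow.master.frontend import Frontend

# block_size=8 is an illustrative value, not taken from this commit.
frontend = Frontend(model_name='facebook/opt-125m', block_size=8)

# Uses the default SamplingParams(); the prompt is tokenized once and
# fanned out into n sequences that share the same prompt token ids.
frontend.query('Hello, my name is')

# The scheduler drains the buffer; a second call returns nothing new.
for seq_group, sampling_params in frontend.get_inputs():
    print(seq_group.group_id, len(seq_group.seqs))
assert frontend.get_inputs() == []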
cacheflow/master/scheduler.py
 from typing import Dict, List, Tuple

 from cacheflow.master.block_manager import BlockSpaceManager
+from cacheflow.master.frontend import Frontend
+from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence
 from cacheflow.sequence import SequenceGroup
 from cacheflow.sequence import SequenceStatus
@@ -12,11 +14,13 @@ class Scheduler:

     def __init__(
         self,
+        frontend: Frontend,
         controllers: List,
         block_size: int,
         num_gpu_blocks: int,
         num_cpu_blocks: int,
     ) -> None:
+        self.frontend = frontend
         self.controllers = controllers
         self.block_size = block_size
         self.num_gpu_blocks = num_gpu_blocks
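With this change the scheduler is constructed around a Frontend instance it can pull requests from. A hypothetical wiring sketch consistent with the new signature (the controller list and block counts are illustrative placeholders, not values from this commit):

from cacheflow.master.frontend import Frontend
from cacheflow.master.scheduler import Scheduler

frontend = Frontend(model_name='facebook/opt-125m', block_size=8)
scheduler = Scheduler(
    frontend=frontend,    # new: the scheduler pulls inputs from the frontend
    controllers=[],       # placeholder; real controllers manage the workers
    block_size=8,
    num_gpu_blocks=1024,  # illustrative KV-cache block counts
    num_cpu_blocks=256,
)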
@@ -33,16 +37,20 @@ class Scheduler:
         self.running: List[SequenceGroup] = []
         # Mapping: group_id -> num_steps.
         self.num_steps: Dict[int, int] = {}
-        # Mapping: group_id -> max_num_steps.
-        self.max_num_steps: Dict[int, int] = {}
-        # Mapping: group_id -> stop_token_ids.
-        self.stop_token_ids: Dict[int, List[int]] = {}
+        # Mapping: group_id -> sampling params.
+        self.sampling_params: Dict[int, SamplingParams] = {}
         # Swapped sequence groups (LIFO).
         self.swapped: List[SequenceGroup] = []
         # Pending sequence groups (FIFO).
         self.pending: List[SequenceGroup] = []

+    def _fetch_inputs(self) -> None:
+        inputs = self.frontend.get_inputs()
+        for seq_group, sampling_params in inputs:
+            self.pending.append(seq_group)
+            self.sampling_params[seq_group.group_id] = sampling_params
+
     def _free_seq(self, seq: Sequence) -> None:
         seq.status = SequenceStatus.FINISHED
         self.block_manager.free(seq)
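The new _fetch_inputs pairs with Frontend.get_inputs from the file above: the frontend swaps its buffer out and hands over (seq_group, sampling_params) tuples, which the scheduler queues FIFO in pending while recording each group's params. A self-contained toy of that drain semantics, using stand-in values instead of real SequenceGroup objects:

from typing import Any, Dict, List, Tuple

class ToyFrontend:
    """Stand-in with the same swap-and-clear semantics as Frontend.get_inputs."""

    def __init__(self) -> None:
        self.inputs: List[Tuple[str, Any]] = []

    def get_inputs(self) -> List[Tuple[str, Any]]:
        inputs = self.inputs
        self.inputs = []  # each request is handed over exactly once
        return inputs

frontend = ToyFrontend()
frontend.inputs.append(('group-0', 'params-0'))

pending: List[str] = []
sampling_params: Dict[str, Any] = {}
for seq_group, params in frontend.get_inputs():
    pending.append(seq_group)            # FIFO queue of new groups
    sampling_params[seq_group] = params  # per-group sampling settings

assert pending == ['group-0'] and frontend.get_inputs() == []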
@@ -145,6 +153,7 @@ class Scheduler:
         # TODO(woosuk): Add a batching policy to control the batch size.
         if not self.swapped:
             # FIXME(woosuk): Acquire a lock to protect pending.
+            self._fetch_inputs()
             for i, seq_group in enumerate(self.pending):
                 num_prompt_tokens = seq_group.seqs[0].get_len()
                 if self.block_manager.can_allocate(seq_group):
@@ -205,7 +214,7 @@ class Scheduler:
         for seq_group in self.running:
             group_id = seq_group.group_id
             self.num_steps[group_id] += 1
-            stop_token_ids = self.stop_token_ids[group_id]
+            stop_token_ids = self.sampling_params[group_id].stop_token_ids
             for seq in seq_group.seqs:
                 if seq.status == SequenceStatus.FINISHED:
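This hunk, like the ones above, reads per-group settings off the stored SamplingParams object instead of separate scheduler-side dicts. The actual class lives in cacheflow/sampling_params.py and is not part of this diff; below is a minimal sketch consistent with the attributes this commit accesses (n, stop_token_ids, max_num_steps), with assumed defaults:

from typing import List, Optional

class SamplingParams:
    """Sketch only: field names come from this diff, defaults are assumptions."""

    def __init__(
        self,
        n: int = 1,                                  # candidate sequences per prompt
        stop_token_ids: Optional[List[int]] = None,  # tokens that end generation
        max_num_steps: int = 16,                     # decoding-step budget per group
    ) -> None:
        self.n = n
        self.stop_token_ids = stop_token_ids if stop_token_ids is not None else []
        self.max_num_steps = max_num_steps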
@@ -230,24 +239,22 @@ class Scheduler:
                     continue
                 # Check if the sequence has reached the maximum number of steps.
-                if self.num_steps[group_id] == self.max_num_steps[group_id]:
+                max_num_steps = self.sampling_params[group_id].max_num_steps
+                if self.num_steps[group_id] == max_num_steps:
                     self._free_seq(seq)
                     continue

         # Update the running sequences.
         running: List[SequenceGroup] = []
         for seq_group in self.running:
-            if all(seq.status == SequenceStatus.FINISHED for seq in seq_group.seqs):
-                del self.num_steps[seq_group.group_id]
-                del self.max_num_steps[seq_group.group_id]
-                del self.stop_token_ids[seq_group.group_id]
-                # TODO: Return the seq_group to the client.
-                from transformers import AutoTokenizer
-                tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m')
-                for seq in seq_group.seqs:
-                    token_ids = seq.get_token_ids()
-                    output = tokenizer.decode(token_ids, skip_special_tokens=True)
-                    print(f'Seq {seq.seq_id}: {output}')
+            if seq_group.is_finished():
+                self._return(seq_group)
             else:
                 running.append(seq_group)
         self.running = running
+
+    def _return(self, seq_group: SequenceGroup) -> None:
+        group_id = seq_group.group_id
+        del self.num_steps[group_id]
+        del self.sampling_params[group_id]
+        self.frontend.print_response(seq_group)
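Two things fall out of this last hunk. First, seq_group.is_finished() replaces the inline check it supersedes, so the method presumably matches that predicate; a sketch inferred from the removed line (the real method would live with SequenceGroup in cacheflow/sequence.py, outside this diff):

# Sketch of SequenceGroup.is_finished, inferred from the removed inline check:
def is_finished(self) -> bool:
    return all(seq.status == SequenceStatus.FINISHED for seq in self.seqs)

Second, _return moves response handling out of the scheduler: instead of re-instantiating a hard-coded facebook/opt-125m tokenizer for every finished group, the scheduler frees its per-group bookkeeping and delegates decoding to Frontend.print_response, which reuses the tokenizer loaded once at startup.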