norm / vllm / Commits / f746ced0

Unverified commit f746ced0, authored May 21, 2023 by Woosuk Kwon, committed by GitHub on May 21, 2023

Implement stop strings and best_of (#114)

Parent: c3442c1f
Showing 9 changed files with 162 additions and 116 deletions (+162, -116).
cacheflow/core/block_manager.py             +6   -6
cacheflow/core/scheduler.py                 +15  -48
cacheflow/entrypoints/fastapi_server.py     +3   -2
cacheflow/model_executor/layers/sampler.py  +8   -8
cacheflow/outputs.py                        +23  -21
cacheflow/sampling_params.py                +29  -11
cacheflow/sequence.py                       +21  -5
cacheflow/server/llm_server.py              +55  -13
examples/simple_server.py                   +2   -2

cacheflow/core/block_manager.py

@@ -80,7 +80,7 @@ class BlockSpaceManager:
     def can_allocate(self, seq_group: SequenceGroup) -> bool:
         # FIXME(woosuk): Here we assume that all sequences in the group share
         # the same prompt. This may not be true for preempted sequences.
-        seq = seq_group.seqs[0]
+        seq = seq_group.get_seqs()[0]
         num_required_blocks = len(seq.logical_token_blocks)
         num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
         # Use watermark to avoid frequent cache eviction.

@@ -88,7 +88,7 @@ class BlockSpaceManager:
     def allocate(self, seq_group: SequenceGroup) -> None:
         # NOTE: Here we assume that all sequences in the group have the same prompt.
-        seq = seq_group.seqs[0]
+        seq = seq_group.get_seqs()[0]
         # Allocate new physical token blocks that will store the prompt tokens.
         block_table: BlockTable = []

@@ -99,7 +99,7 @@ class BlockSpaceManager:
             block_table.append(block)

         # Assign the block table for each sequence.
-        for seq in seq_group.seqs:
+        for seq in seq_group.get_seqs():
             self.block_tables[seq.seq_id] = block_table.copy()

     def can_append_slot(self, seq_group: SequenceGroup) -> bool:

@@ -147,7 +147,7 @@ class BlockSpaceManager:
         # NOTE: Here, we assume that the physical blocks are only shared by
         # the sequences in the same group.
         blocks: Set[PhysicalTokenBlock] = set()
-        for seq in seq_group.seqs:
+        for seq in seq_group.get_seqs():
             if seq.status == SequenceStatus.FINISHED:
                 continue
             block_table = self.block_tables[seq.seq_id]

@@ -168,7 +168,7 @@ class BlockSpaceManager:
     def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
         # CPU block -> GPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
-        for seq in seq_group.seqs:
+        for seq in seq_group.get_seqs():
             if seq.status == SequenceStatus.FINISHED:
                 continue
             new_block_table: BlockTable = []

@@ -199,7 +199,7 @@ class BlockSpaceManager:
     def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
         # GPU block -> CPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
-        for seq in seq_group.seqs:
+        for seq in seq_group.get_seqs():
             if seq.status == SequenceStatus.FINISHED:
                 continue
             new_block_table: BlockTable = []
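
Every hunk in this file swaps direct access to `seq_group.seqs` for the `get_seqs()` accessor. The accessor itself is not part of this diff; as a rough, hypothetical sketch of the shape the call sites above assume (an optional status filter, defaulting to all sequences):

    # Hypothetical sketch only -- the real SequenceGroup.get_seqs() lives in
    # cacheflow/sequence.py and is not shown in this commit.
    from typing import List, Optional


    class SequenceGroupSketch:
        def __init__(self, seqs: List["Sequence"]) -> None:
            self.seqs = seqs

        def get_seqs(self, status: Optional["SequenceStatus"] = None) -> List["Sequence"]:
            if status is None:
                # No filter: return every sequence in the group.
                return self.seqs
            # Otherwise return only the sequences currently in the requested state.
            return [seq for seq in self.seqs if seq.status == status]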

cacheflow/core/scheduler.py

@@ -73,8 +73,6 @@ class Scheduler:
         self.waiting: List[SequenceGroup] = []
         # Sequence groups in the RUNNING state.
         self.running: List[SequenceGroup] = []
-        # Mapping: request_id -> num_steps.
-        self.num_steps: Dict[str, int] = {}
         # Sequence groups in the SWAPPED state.
         self.swapped: List[SequenceGroup] = []

@@ -84,7 +82,6 @@ class Scheduler:
     def add_seq_group(self, seq_group: SequenceGroup) -> None:
         # Add sequence groups to the waiting queue.
-        assert seq_group.request_id not in self.num_steps
         self.waiting.append(seq_group)

     def has_unfinished_seqs(self) -> bool:

@@ -178,7 +175,7 @@ class Scheduler:
                     break

                 # If the number of batched tokens exceeds the limit, stop.
-                num_prompt_tokens = seq_group.seqs[0].get_len()
+                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
                 if (num_batched_tokens + num_prompt_tokens
                     > self.scheduler_config.max_num_batched_tokens):
                     break

@@ -278,15 +275,8 @@ class Scheduler:
     ) -> List[SequenceGroup]:
         # Update the running sequences and free blocks.
         for seq_group in self.running:
-            request_id = seq_group.request_id
-            self.num_steps[request_id] += 1
-            stop_token_ids = seq_group.sampling_params.stop_token_ids
-
-            # Process beam search results before processing the next tokens.
-            for seq in seq_group.seqs:
-                if seq.status == SequenceStatus.FINISHED:
-                    continue
-
+            # Process beam search results before processing the new tokens.
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 output = seq_outputs[seq.seq_id]
                 if seq.seq_id != output.parent_seq_id:
                     # The sequence is a fork of the parent sequence (beam search).

@@ -297,43 +287,27 @@ class Scheduler:
                     parent_seq.fork(seq)
                     self.block_manager.fork(parent_seq, seq)

-            # Process the next tokens.
-            for seq in seq_group.seqs:
-                if seq.status == SequenceStatus.FINISHED:
-                    continue
-
+            # Process the new tokens.
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 # Append a new token to the sequence.
                 output = seq_outputs[seq.seq_id]
                 seq.append_token(output.output_token, output.logprobs)
+        return self.running.copy()

-                # Check if the sequence has generated a stop token.
-                if output.output_token in stop_token_ids:
-                    self._free_seq(seq)
-                    continue
+    def free_seq(self, seq: Sequence) -> None:
+        seq.status = SequenceStatus.FINISHED
+        self.block_manager.free(seq)

-                # Check if the sequence has reached the maximum number of steps.
-                max_num_steps = seq_group.sampling_params.max_tokens
-                if self.num_steps[request_id] == max_num_steps:
-                    self._free_seq(seq)
-                    continue
-
-        # Update the running sequences.
-        updated = self.running.copy()
-        running: List[SequenceGroup] = []
-        for seq_group in self.running:
-            if seq_group.is_finished():
-                self._free_seq_group(seq_group)
-            else:
-                running.append(seq_group)
-        self.running = running
-        return updated
+    def free_finished_seq_groups(self) -> None:
+        self.running = [
+            seq_group for seq_group in self.running
+            if not seq_group.is_finished()
+        ]

     def _allocate(self, seq_group: SequenceGroup) -> None:
         self.block_manager.allocate(seq_group)
-        for seq in seq_group.seqs:
+        for seq in seq_group.get_seqs():
             seq.status = SequenceStatus.RUNNING
-        if seq_group.request_id not in self.num_steps:
-            self.num_steps[seq_group.request_id] = 0

     def _append_slot(
         self,

@@ -403,13 +377,6 @@ class Scheduler:
             self._swap_out(seq_group, blocks_to_swap_out)
             self.swapped.append(seq_group)

-    def _free_seq(self, seq: Sequence) -> None:
-        seq.status = SequenceStatus.FINISHED
-        self.block_manager.free(seq)
-
-    def _free_seq_group(self, seq_group: SequenceGroup) -> None:
-        del self.num_steps[seq_group.request_id]
-
     def _swap_in(
         self,
         seq_group: SequenceGroup,

cacheflow/entrypoints/fastapi_server.py

@@ -123,6 +123,7 @@ if __name__ == "__main__":
     parallel_config = server_configs[2]
     distributed_init_method, stage_devices = initialize_cluster(parallel_config)

-    server = FastAPIServer(
-        args.use_ray, *server_configs, distributed_init_method, stage_devices)
+    server = FastAPIServer(
+        args.use_ray, *server_configs, distributed_init_method, stage_devices,
+        log_stats=not args.disable_log_stats)

     uvicorn.run(app, host=args.host, port=args.port, log_level="info")

cacheflow/model_executor/layers/sampler.py

@@ -283,20 +283,20 @@ def _sample_from_prompt(
 ) -> List[int]:
     if sampling_params.use_beam_search:
         # Beam search.
-        beam_width = sampling_params.n
+        beam_width = sampling_params.best_of
         _, next_token_ids = torch.topk(prob, beam_width)
         next_token_ids = next_token_ids.tolist()
     elif sampling_params.temperature == 0.0:
         # Greedy sampling.
-        assert sampling_params.n == 1
+        assert sampling_params.best_of == 1
         next_token_id = torch.argmax(prob)
         next_token_ids = [next_token_id.item()]
     else:
         # Random sampling.
-        # Sample n tokens for the prompt.
-        n = sampling_params.n
+        # Sample `best_of` tokens for the prompt.
+        num_seqs = sampling_params.best_of
         next_token_ids = torch.multinomial(
-            prob, num_samples=n, replacement=True)
+            prob, num_samples=num_seqs, replacement=True)
         next_token_ids = next_token_ids.tolist()
     return next_token_ids

@@ -308,7 +308,7 @@ def _sample_from_generation_tokens(
     seq_logprobs: List[float],
     sampling_params: SamplingParams,
 ) -> Tuple[List[int], List[int]]:
-    # NOTE(woosuk): sampling_params.n can be greater than
+    # NOTE(woosuk): sampling_params.best_of can be greater than
     # len(seq_ids) because some sequences in the group might have
     # been already terminated.
     if sampling_params.use_beam_search:

@@ -372,7 +372,7 @@ def _sample(
         seq_ids, sampling_params = seq_group
         if i < input_metadata.num_prompts:
             # Generate the next tokens for a prompt input.
-            assert len(seq_ids) == sampling_params.n
+            assert len(seq_ids) == sampling_params.best_of
             prob = probs[idx]
             logprob = logprobs[idx]
             idx += 1

@@ -397,7 +397,7 @@ def _sample(
             # Sample the next tokens.
             seq_logprobs = [
-                input_metadata.seq_data[seq_id].cumulative_logprobs
+                input_metadata.seq_data[seq_id].cumulative_logprob
                 for seq_id in seq_ids
             ]
             parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
                 seq_ids, prob, logprob, seq_logprobs, sampling_params)
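
To see how `best_of` drives the prompt-sampling branches above, here is a small, self-contained illustration (toy logits, not the server's code path): the beam-search branch takes the `best_of` highest-probability tokens, while the random-sampling branch draws `best_of` tokens with replacement.

    import torch

    # Toy next-token distribution over a vocabulary of 8 tokens.
    logits = torch.tensor([0.1, 2.0, 0.3, 1.5, 0.0, 0.7, 0.2, 1.0])
    prob = torch.softmax(logits, dim=-1)
    best_of = 4

    # Beam-search branch: keep the best_of most likely tokens.
    _, beam_token_ids = torch.topk(prob, best_of)

    # Random-sampling branch: draw best_of tokens with replacement.
    sampled_token_ids = torch.multinomial(prob, num_samples=best_of, replacement=True)

    print(beam_token_ids.tolist(), sampled_token_ids.tolist())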

cacheflow/outputs.py

-from typing import Dict, List, Union
-
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from typing import Dict, List

 from cacheflow.sequence import SequenceGroup

@@ -9,20 +7,23 @@ class CompletionOutput:
     def __init__(
         self,
+        index: int,
         text: str,
         token_ids: List[int],
-        cumulative_logprobs: float,
+        cumulative_logprob: float,
         logprobs: List[Dict[int, float]],
     ) -> None:
+        self.index = index
         self.text = text
         self.token_ids = token_ids
-        self.cumulative_logprobs = cumulative_logprobs
+        self.cumulative_logprob = cumulative_logprob
         self.logprobs = logprobs

     def __repr__(self) -> str:
-        return (f"CompletionOutput(output={self.text!r}, "
+        return (f"CompletionOutput(index={self.index}, "
+                f"text={self.text!r}, "
                 f"token_ids={self.token_ids}, "
-                f"cumulative_logprobs={self.cumulative_logprobs}, "
+                f"cumulative_logprob={self.cumulative_logprob}, "
                 f"logprobs={self.logprobs})")

@@ -43,31 +44,32 @@ class RequestOutput:
         self.done = done

     @staticmethod
-    def from_seq_group(
-        seq_group: SequenceGroup,
-        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-    ) -> "RequestOutput":
-        outputs: List[CompletionOutput] = []
+    def from_seq_group(seq_group: SequenceGroup) -> "RequestOutput":
+        # Get the top-n sequences.
+        n = seq_group.sampling_params.n
         seqs = seq_group.get_seqs()
-        for seq in seqs:
-            output_token_ids = seq.data.output_token_ids
-            output_str = tokenizer.decode(output_token_ids,
-                                          skip_special_tokens=True)
-            seq_logprobs = seq.data.cumulative_logprobs
-
+        assert n <= len(seqs)
+        sorted_seqs = sorted(
+            seqs, key=lambda seq: seq.get_cumulative_logprob(), reverse=True)
+        top_n_seqs = sorted_seqs[:n]
+
+        # Create the outputs.
+        outputs: List[CompletionOutput] = []
+        for seq in top_n_seqs:
             logprobs = seq.output_logprobs
             if seq_group.sampling_params.logprobs == 0:
                 # NOTE: We need to take care of this case because the sequence
                 # always has the logprobs of the sampled tokens even if the
                 # logprobs are not requested.
                 logprobs = {}
-            output = CompletionOutput(output_str, output_token_ids,
-                                      seq_logprobs, logprobs)
+            output = CompletionOutput(seqs.index(seq), seq.output_text,
+                                      seq.get_output_token_ids(),
+                                      seq.get_cumulative_logprob(), logprobs)
             outputs.append(output)

         # Every sequence in the sequence group should have the same prompt.
-        prompt = seqs[0].prompt
-        prompt_token_ids = seqs[0].data.prompt_token_ids
+        prompt = top_n_seqs[0].prompt
+        prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
         return RequestOutput(seq_group.request_id, prompt, prompt_token_ids,
                              outputs, seq_group.is_finished())
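
The reworked `from_seq_group` ranks all `best_of` candidate sequences by cumulative log-probability and keeps only the top `n`. A toy version of that selection step, independent of the cacheflow classes:

    # Each candidate is (name, cumulative log-probability); higher is better.
    candidates = [("seq-a", -1.2), ("seq-b", -0.4), ("seq-c", -3.1), ("seq-d", -0.9)]
    n = 2

    # Sort in descending order of cumulative logprob and keep the top n.
    top_n = sorted(candidates, key=lambda c: c[1], reverse=True)[:n]
    print(top_n)  # [('seq-b', -0.4), ('seq-d', -0.9)]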

cacheflow/sampling_params.py

 """Sampling parameters for text generation."""
-from typing import Set
+from typing import List, Optional, Union


 class SamplingParams:

@@ -10,8 +10,12 @@ class SamplingParams:
     In addition, we support beam search, which is not supported by OpenAI.

     Args:
-        n: Number of output sequences to generate from the given prompt. This is
-            regarded as the beam width when using beam search.
+        n: Number of output sequences to return for the given prompt.
+        best_of: Number of output sequences that are generated from the prompt.
+            From these `best_of` sequences, the top `n` sequences are returned.
+            `best_of` must be greater than or equal to `n`. This is treated as
+            the beam width when `use_beam_search` is True. By default, `best_of`
+            is set to `n`.
         presence_penalty: Float that penalizes new tokens based on whether they
             appear in the generated text so far. Values > 0 encourage the model
             to use new tokens, while values < 0 encourage the model to repeat

@@ -28,7 +32,10 @@ class SamplingParams:
         top_k: Integer that controls the number of top tokens to consider. Set
             to -1 to consider all tokens.
         use_beam_search: Whether to use beam search instead of sampling.
-        stop_token_ids: Set of token IDs that indicate the end of a sequence.
+        stop: List of strings that stop the generation when they are generated.
+            The returned output will not contain the stop strings.
+        ignore_eos: Whether to ignore the EOS token and continue generating
+            tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
         logprobs: Number of log probabilities to return per output token.
     """

@@ -36,24 +43,28 @@ class SamplingParams:
     def __init__(
         self,
         n: int = 1,
+        best_of: Optional[int] = None,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
         use_beam_search: bool = False,
-        stop_token_ids: Set[int] = set(),
+        stop: Union[str, List[str]] = [],
+        ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: int = 0,
     ) -> None:
         self.n = n
+        self.best_of = best_of if best_of is not None else n
         self.presence_penalty = presence_penalty
         self.frequency_penalty = frequency_penalty
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
         self.use_beam_search = use_beam_search
-        self.stop_token_ids = stop_token_ids
+        self.stop = [stop] if isinstance(stop, str) else list(stop)
+        self.ignore_eos = ignore_eos
         self.max_tokens = max_tokens
         self.logprobs = logprobs

@@ -67,6 +78,9 @@ class SamplingParams:
     def _verify_args(self) -> None:
         if self.n < 1:
             raise ValueError(f"n must be at least 1, got {self.n}.")
+        if self.best_of < self.n:
+            raise ValueError(f"best_of must be greater than or equal to n, "
+                             f"got n={self.n} and best_of={self.best_of}.")
         if not -2.0 <= self.presence_penalty <= 2.0:
             raise ValueError("presence_penalty must be in [-2, 2], got "
                              f"{self.presence_penalty}.")

@@ -89,8 +103,9 @@ class SamplingParams:
             raise ValueError(
                 f"logprobs must be non-negative, got {self.logprobs}.")

     def _verity_beam_search(self) -> None:
-        if self.n == 1:
-            raise ValueError("n must be greater than 1 when using beam search.")
+        if self.best_of == 1:
+            raise ValueError("best_of must be greater than 1 when using beam "
+                             f"search. Got {self.best_of}.")
         if self.temperature > 0.0:
             raise ValueError("temperature must be 0 when using beam search.")
         if self.top_p < 1.0:

@@ -99,8 +114,9 @@ class SamplingParams:
             raise ValueError("top_k must be -1 when using beam search.")

     def _verify_greedy_sampling(self) -> None:
-        if self.n > 1:
-            raise ValueError("n must be 1 when using greedy sampling.")
+        if self.best_of > 1:
+            raise ValueError("best_of must be 1 when using greedy sampling."
+                             f"Got {self.best_of}.")
         if self.top_p < 1.0:
             raise ValueError("top_p must be 1 when using greedy sampling.")
         if self.top_k != -1:

@@ -108,12 +124,14 @@ class SamplingParams:
     def __repr__(self) -> str:
         return (f"SamplingParams(n={self.n}, "
+                f"best_of={self.best_of}, "
                 f"presence_penalty={self.presence_penalty}, "
                 f"frequency_penalty={self.frequency_penalty}, "
                 f"temperature={self.temperature}, "
                 f"top_p={self.top_p}, "
                 f"top_k={self.top_k},"
                 f"use_beam_search={self.use_beam_search}, "
-                f"stop_token_ids={self.stop_token_ids}, "
+                f"stop={self.stop}, "
+                f"ignore_eos={self.ignore_eos}, "
                 f"max_tokens={self.max_tokens}, "
                 f"logprobs={self.logprobs})")
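
Putting the new parameters together, a plausible request using the signature introduced above might look like the following (the concrete values are illustrative only):

    from cacheflow.sampling_params import SamplingParams

    # Generate 5 candidates, return the 2 with the highest cumulative logprob,
    # and stop a sequence as soon as it emits "\n\n" or "###" (the stop string
    # itself is stripped from the returned text).
    params = SamplingParams(
        n=2,
        best_of=5,
        temperature=0.8,
        top_p=0.95,
        stop=["\n\n", "###"],
        max_tokens=64,
    )
    print(params)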

cacheflow/sequence.py

@@ -22,11 +22,18 @@ class SequenceData:
         self.prompt_token_ids = prompt_token_ids
         self.output_token_ids: List[int] = []
-        self.cumulative_logprobs = 0.0
+        self.cumulative_logprob = 0.0
+
+    def append_token(self, token_id: int, logprob: float) -> None:
+        self.output_token_ids.append(token_id)
+        self.cumulative_logprob += logprob

     def get_len(self) -> int:
         return len(self.output_token_ids) + len(self.prompt_token_ids)

+    def get_output_len(self) -> int:
+        return len(self.output_token_ids)
+
     def get_token_ids(self) -> List[int]:
         return self.prompt_token_ids + self.output_token_ids

@@ -37,9 +44,9 @@ class SequenceData:
     def __repr__(self) -> str:
         return (f"SequenceData("
                 f"prompt={self.prompt}, "
                 f"prompt_token_ids={self.prompt_token_ids}, "
-                f"output_token_ids={self.output_token_ids})")
+                f"output_token_ids={self.output_token_ids}, "
+                f"cumulative_logprob={self.cumulative_logprob})")


 class Sequence:

@@ -57,6 +64,7 @@ class Sequence:
         self.data = SequenceData(prompt_token_ids)
         self.output_logprobs: List[Dict[int, float]] = []
+        self.output_text = ""

         self.logical_token_blocks: List[LogicalTokenBlock] = []
         # Initialize the logical token blocks with the prompt token ids.

@@ -88,18 +96,26 @@ class Sequence:
         assert token_id in logprobs
         self._append_tokens_to_blocks([token_id])
         self.output_logprobs.append(logprobs)
-        self.data.output_token_ids.append(token_id)
-        self.data.cumulative_logprobs += logprobs[token_id]
+        self.data.append_token(token_id, logprobs[token_id])

     def get_len(self) -> int:
         return self.data.get_len()

+    def get_output_len(self) -> int:
+        return self.data.get_output_len()
+
     def get_token_ids(self) -> List[int]:
         return self.data.get_token_ids()

     def get_last_token_id(self) -> int:
         return self.data.get_last_token_id()

+    def get_output_token_ids(self) -> List[int]:
+        return self.data.output_token_ids
+
+    def get_cumulative_logprob(self) -> float:
+        return self.data.cumulative_logprob
+
     def fork(self, child_seq: 'Sequence') -> 'Sequence':
         child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
         child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)

cacheflow/server/llm_server.py

@@ -13,7 +13,7 @@ from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.tokenizer_utils import get_tokenizer
-from cacheflow.sequence import Sequence, SequenceGroup
+from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
 from cacheflow.utils import Counter
 from cacheflow.worker.worker import Worker

@@ -49,7 +49,6 @@ class LLMServer:
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
         self.log_stats = log_stats
-        self._verify_args()

         self.tokenizer = get_tokenizer(model_config.model)

@@ -124,15 +123,11 @@ class LLMServer:
         # Create the sequences.
         block_size = self.cache_config.block_size
         seqs: List[Sequence] = []
-        for _ in range(sampling_params.n):
+        for _ in range(sampling_params.best_of):
             seq_id = next(self.seq_counter)
             seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
             seqs.append(seq)

-        # FIXME(woosuk)
-        # Add the EOS token to the stop token list.
-        sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
-
         # Create the sequence group.
         seq_group = SequenceGroup(request_id, seqs, sampling_params,
                                   arrival_time)

@@ -157,18 +152,65 @@ class LLMServer:
             blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
             blocks_to_copy=scheduler_outputs.blocks_to_copy,
         )
-        # Update the scheduler.
-        updated_seq_groups = self.scheduler.update(output)
+        # Update the scheduler with the model outputs.
+        seq_groups = self.scheduler.update(output)
+
+        # Decode the sequences.
+        self._decode_sequences(seq_groups)
+        # Stop the sequences that meet the stopping criteria.
+        self._stop_sequences(seq_groups)
+        # Free the finished sequence groups.
+        self.scheduler.free_finished_seq_groups()

         # Create the outputs.
         request_outputs: List[RequestOutput] = []
-        for seq_group in updated_seq_groups:
-            # TODO(woosuk): Batch-decode the outputs for speedup.
-            request_output = RequestOutput.from_seq_group(seq_group,
-                                                          self.tokenizer)
+        for seq_group in seq_groups:
+            request_output = RequestOutput.from_seq_group(seq_group)
             request_outputs.append(request_output)
         return request_outputs

+    def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
+        # Batch-decode the sequence outputs.
+        seqs: List[Sequence] = []
+        for seq_group in seq_groups:
+            seqs.extend(seq_group.get_seqs(status=SequenceStatus.RUNNING))
+        output_tokens_per_seq = []
+        for seq in seqs:
+            output_tokens_per_seq.append(seq.get_output_token_ids())
+        output_texts = self.tokenizer.batch_decode(output_tokens_per_seq,
+                                                   skip_special_tokens=True)
+        # Update the sequences with the output texts.
+        for seq, output_text in zip(seqs, output_texts):
+            seq.output_text = output_text
+
+    def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
+        # Stop the sequences.
+        for seq_group in seq_groups:
+            sampling_params = seq_group.sampling_params
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+                # Check if the sequence has generated a stop string.
+                stopped = False
+                for stop_str in sampling_params.stop:
+                    if seq.output_text.endswith(stop_str):
+                        # Truncate the output text so that the stop string is
+                        # not included in the output.
+                        seq.output_text = seq.output_text[:-len(stop_str)]
+                        self.scheduler.free_seq(seq)
+                        stopped = True
+                        break
+                if stopped:
+                    continue
+
+                # Check if the sequence has reached max_tokens.
+                if seq.get_output_len() == sampling_params.max_tokens:
+                    self.scheduler.free_seq(seq)
+                    continue
+                # Check if the sequence has generated the EOS token.
+                if not sampling_params.ignore_eos:
+                    if seq.get_last_token_id() == self.tokenizer.eos_token_id:
+                        self.scheduler.free_seq(seq)
+                        continue
+
     def _run_workers(
         self,
         method: str,
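
The heart of the new `_stop_sequences` logic is the endswith-and-truncate check. Isolated as a plain helper (hypothetical name, same behavior as the loop above):

    from typing import List, Tuple


    def apply_stop_strings(output_text: str, stop: List[str]) -> Tuple[str, bool]:
        # Return (possibly truncated text, whether a stop string was matched).
        for stop_str in stop:
            if output_text.endswith(stop_str):
                # Drop the stop string itself from the returned text.
                return output_text[:-len(stop_str)], True
        return output_text, False


    print(apply_stop_strings("The answer is 42.###", ["\n\n", "###"]))
    # ('The answer is 42.', True)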

examples/simple_server.py

@@ -15,9 +15,9 @@ def main(args: argparse.Namespace):
         ("To be or not to be,",
          SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
         ("What is the meaning of life?",
-         SamplingParams(n=2, temperature=0.8, top_p=0.95, frequency_penalty=0.1)),
+         SamplingParams(n=2, best_of=5, temperature=0.8, top_p=0.95, frequency_penalty=0.1)),
         ("It is only with the heart that one can see rightly",
-         SamplingParams(n=3, use_beam_search=True, temperature=0.0)),
+         SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)),
     ]

     # Run the server.