norm / vllm · Commits · f746ced0
"git@developer.sourcefind.cn:norm/vllm.git" did not exist on "709a69176ea86f60786acb87ede52e62d9efb036"
Commit f746ced0 (unverified)
Implement stop strings and best_of (#114)
Authored May 21, 2023 by Woosuk Kwon; committed by GitHub on May 21, 2023
Parent: c3442c1f

Showing 9 changed files with 162 additions and 116 deletions (+162 -116):
cacheflow/core/block_manager.py (+6 -6)
cacheflow/core/scheduler.py (+15 -48)
cacheflow/entrypoints/fastapi_server.py (+3 -2)
cacheflow/model_executor/layers/sampler.py (+8 -8)
cacheflow/outputs.py (+23 -21)
cacheflow/sampling_params.py (+29 -11)
cacheflow/sequence.py (+21 -5)
cacheflow/server/llm_server.py (+55 -13)
examples/simple_server.py (+2 -2)
cacheflow/core/block_manager.py

@@ -80,7 +80,7 @@ class BlockSpaceManager:
     def can_allocate(self, seq_group: SequenceGroup) -> bool:
         # FIXME(woosuk): Here we assume that all sequences in the group share
         # the same prompt. This may not be true for preempted sequences.
-        seq = seq_group.seqs[0]
+        seq = seq_group.get_seqs()[0]
         num_required_blocks = len(seq.logical_token_blocks)
         num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
         # Use watermark to avoid frequent cache eviction.
@@ -88,7 +88,7 @@ class BlockSpaceManager:
     def allocate(self, seq_group: SequenceGroup) -> None:
         # NOTE: Here we assume that all sequences in the group have the same prompt.
-        seq = seq_group.seqs[0]
+        seq = seq_group.get_seqs()[0]
         # Allocate new physical token blocks that will store the prompt tokens.
         block_table: BlockTable = []
@@ -99,7 +99,7 @@ class BlockSpaceManager:
             block_table.append(block)
         # Assign the block table for each sequence.
-        for seq in seq_group.seqs:
+        for seq in seq_group.get_seqs():
             self.block_tables[seq.seq_id] = block_table.copy()

     def can_append_slot(self, seq_group: SequenceGroup) -> bool:
@@ -147,7 +147,7 @@ class BlockSpaceManager:
         # NOTE: Here, we assume that the physical blocks are only shared by
         # the sequences in the same group.
         blocks: Set[PhysicalTokenBlock] = set()
-        for seq in seq_group.seqs:
+        for seq in seq_group.get_seqs():
             if seq.status == SequenceStatus.FINISHED:
                 continue
             block_table = self.block_tables[seq.seq_id]
@@ -168,7 +168,7 @@ class BlockSpaceManager:
     def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
         # CPU block -> GPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
-        for seq in seq_group.seqs:
+        for seq in seq_group.get_seqs():
             if seq.status == SequenceStatus.FINISHED:
                 continue
             new_block_table: BlockTable = []
@@ -199,7 +199,7 @@ class BlockSpaceManager:
     def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
         # GPU block -> CPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
-        for seq in seq_group.seqs:
+        for seq in seq_group.get_seqs():
             if seq.status == SequenceStatus.FINISHED:
                 continue
             new_block_table: BlockTable = []
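The block manager change is mechanical: direct reads of the seqs attribute become calls to the get_seqs() accessor that the rest of this commit relies on (elsewhere it is called with a status= filter). A minimal sketch of such an accessor, using illustrative stand-in classes rather than the cacheflow definitions:

from enum import Enum, auto
from typing import List, Optional


class SequenceStatus(Enum):
    WAITING = auto()
    RUNNING = auto()
    FINISHED = auto()


class Sequence:
    def __init__(self, seq_id: int, status: SequenceStatus) -> None:
        self.seq_id = seq_id
        self.status = status


class SequenceGroup:
    def __init__(self, seqs: List[Sequence]) -> None:
        self.seqs = seqs

    def get_seqs(
        self, status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        # With no filter this behaves like the old `.seqs` attribute access;
        # with a status it returns only the matching sequences.
        if status is None:
            return self.seqs
        return [seq for seq in self.seqs if seq.status == status]


group = SequenceGroup([Sequence(0, SequenceStatus.RUNNING),
                       Sequence(1, SequenceStatus.FINISHED)])
print(len(group.get_seqs()))                        # 2
print(len(group.get_seqs(SequenceStatus.RUNNING)))  # 1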
cacheflow/core/scheduler.py

@@ -73,8 +73,6 @@ class Scheduler:
         self.waiting: List[SequenceGroup] = []
         # Sequence groups in the RUNNING state.
         self.running: List[SequenceGroup] = []
-        # Mapping: request_id -> num_steps.
-        self.num_steps: Dict[str, int] = {}
         # Sequence groups in the SWAPPED state.
         self.swapped: List[SequenceGroup] = []
@@ -84,7 +82,6 @@ class Scheduler:
     def add_seq_group(self, seq_group: SequenceGroup) -> None:
         # Add sequence groups to the waiting queue.
-        assert seq_group.request_id not in self.num_steps
         self.waiting.append(seq_group)

     def has_unfinished_seqs(self) -> bool:
@@ -178,7 +175,7 @@ class Scheduler:
                 break

             # If the number of batched tokens exceeds the limit, stop.
-            num_prompt_tokens = seq_group.seqs[0].get_len()
+            num_prompt_tokens = seq_group.get_seqs()[0].get_len()
             if (num_batched_tokens + num_prompt_tokens
                 > self.scheduler_config.max_num_batched_tokens):
                 break
@@ -278,15 +275,8 @@ class Scheduler:
     ) -> List[SequenceGroup]:
         # Update the running sequences and free blocks.
         for seq_group in self.running:
-            request_id = seq_group.request_id
-            self.num_steps[request_id] += 1
-            stop_token_ids = seq_group.sampling_params.stop_token_ids
-
-            # Process beam search results before processing the next tokens.
-            for seq in seq_group.seqs:
-                if seq.status == SequenceStatus.FINISHED:
-                    continue
+            # Process beam search results before processing the new tokens.
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 output = seq_outputs[seq.seq_id]
                 if seq.seq_id != output.parent_seq_id:
                     # The sequence is a fork of the parent sequence (beam search).
@@ -297,43 +287,27 @@ class Scheduler:
                     parent_seq.fork(seq)
                     self.block_manager.fork(parent_seq, seq)

-            # Process the next tokens.
-            for seq in seq_group.seqs:
-                if seq.status == SequenceStatus.FINISHED:
-                    continue
+            # Process the new tokens.
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 # Append a new token to the sequence.
                 output = seq_outputs[seq.seq_id]
                 seq.append_token(output.output_token, output.logprobs)
+        return self.running.copy()

-                # Check if the sequence has generated a stop token.
-                if output.output_token in stop_token_ids:
-                    self._free_seq(seq)
-                    continue
+    def free_seq(self, seq: Sequence) -> None:
+        seq.status = SequenceStatus.FINISHED
+        self.block_manager.free(seq)

-                # Check if the sequence has reached the maximum number of steps.
-                max_num_steps = seq_group.sampling_params.max_tokens
-                if self.num_steps[request_id] == max_num_steps:
-                    self._free_seq(seq)
-                    continue
-
-        # Update the running sequences.
-        updated = self.running.copy()
-        running: List[SequenceGroup] = []
-        for seq_group in self.running:
-            if seq_group.is_finished():
-                self._free_seq_group(seq_group)
-            else:
-                running.append(seq_group)
-        self.running = running
-        return updated
+    def free_finished_seq_groups(self) -> None:
+        self.running = [
+            seq_group for seq_group in self.running
+            if not seq_group.is_finished()
+        ]

     def _allocate(self, seq_group: SequenceGroup) -> None:
         self.block_manager.allocate(seq_group)
-        for seq in seq_group.seqs:
+        for seq in seq_group.get_seqs():
             seq.status = SequenceStatus.RUNNING
-        if seq_group.request_id not in self.num_steps:
-            self.num_steps[seq_group.request_id] = 0

     def _append_slot(
         self,
@@ -403,13 +377,6 @@ class Scheduler:
             self._swap_out(seq_group, blocks_to_swap_out)
             self.swapped.append(seq_group)

-    def _free_seq(self, seq: Sequence) -> None:
-        seq.status = SequenceStatus.FINISHED
-        self.block_manager.free(seq)
-
-    def _free_seq_group(self, seq_group: SequenceGroup) -> None:
-        del self.num_steps[seq_group.request_id]
-
     def _swap_in(
         self,
         seq_group: SequenceGroup,
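The net effect of this hunk is a change in division of labor: the scheduler no longer counts steps or checks stop tokens; update() only applies the sampler outputs and hands back the running groups, and the server later calls free_seq() and free_finished_seq_groups() once it has decided which sequences are done. A toy sketch of that split (stand-in classes, not the cacheflow ones):

from typing import List


class ToyGroup:
    def __init__(self, finished: bool = False) -> None:
        self.finished = finished

    def is_finished(self) -> bool:
        return self.finished


class ToyScheduler:
    """Toy stand-in: the scheduler tracks running groups; the caller
    decides when a sequence or group is finished."""

    def __init__(self) -> None:
        self.running: List[ToyGroup] = []

    def update(self) -> List[ToyGroup]:
        # Token appending is omitted; just hand the running groups back
        # to the caller, once per step.
        return self.running.copy()

    def free_finished_seq_groups(self) -> None:
        # Drop groups that the caller has marked finished.
        self.running = [g for g in self.running if not g.is_finished()]


scheduler = ToyScheduler()
scheduler.running = [ToyGroup(), ToyGroup(finished=True)]
groups = scheduler.update()        # the server decodes / applies stop criteria here
scheduler.free_finished_seq_groups()
print(len(groups), len(scheduler.running))  # 2 1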
cacheflow/entrypoints/fastapi_server.py

@@ -123,6 +123,7 @@ if __name__ == "__main__":
     parallel_config = server_configs[2]
     distributed_init_method, stage_devices = initialize_cluster(parallel_config)
     server = FastAPIServer(
         args.use_ray, *server_configs,
-        distributed_init_method, stage_devices)
+        distributed_init_method, stage_devices,
+        log_stats=not args.disable_log_stats)
     uvicorn.run(app, host=args.host, port=args.port, log_level="info")
cacheflow/model_executor/layers/sampler.py

@@ -283,20 +283,20 @@ def _sample_from_prompt(
 ) -> List[int]:
     if sampling_params.use_beam_search:
         # Beam search.
-        beam_width = sampling_params.n
+        beam_width = sampling_params.best_of
         _, next_token_ids = torch.topk(prob, beam_width)
         next_token_ids = next_token_ids.tolist()
     elif sampling_params.temperature == 0.0:
         # Greedy sampling.
-        assert sampling_params.n == 1
+        assert sampling_params.best_of == 1
         next_token_id = torch.argmax(prob)
         next_token_ids = [next_token_id.item()]
     else:
         # Random sampling.
-        # Sample n tokens for the prompt.
-        n = sampling_params.n
+        # Sample `best_of` tokens for the prompt.
+        num_seqs = sampling_params.best_of
         next_token_ids = torch.multinomial(
-            prob, num_samples=n, replacement=True)
+            prob, num_samples=num_seqs, replacement=True)
         next_token_ids = next_token_ids.tolist()
     return next_token_ids
@@ -308,7 +308,7 @@ def _sample_from_generation_tokens(
     seq_logprobs: List[float],
     sampling_params: SamplingParams,
 ) -> Tuple[List[int], List[int]]:
-    # NOTE(woosuk): sampling_params.n can be greater than
+    # NOTE(woosuk): sampling_params.best_of can be greater than
     # len(seq_ids) because some sequences in the group might have
     # been already terminated.
     if sampling_params.use_beam_search:
@@ -372,7 +372,7 @@ def _sample(
         seq_ids, sampling_params = seq_group
         if i < input_metadata.num_prompts:
             # Generate the next tokens for a prompt input.
-            assert len(seq_ids) == sampling_params.n
+            assert len(seq_ids) == sampling_params.best_of
             prob = probs[idx]
             logprob = logprobs[idx]
             idx += 1
@@ -397,7 +397,7 @@ def _sample(
             # Sample the next tokens.
             seq_logprobs = [
-                input_metadata.seq_data[seq_id].cumulative_logprobs
+                input_metadata.seq_data[seq_id].cumulative_logprob
                 for seq_id in seq_ids]
             parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
                 seq_ids, prob, logprob, seq_logprobs, sampling_params)
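The prompt-sampling path now draws best_of candidates instead of n. A standalone sketch of the same three branches with PyTorch (the function name sample_from_prompt is illustrative, not the cacheflow sampler; prob is assumed to be a plain 1-D probability vector):

import torch


def sample_from_prompt(prob: torch.Tensor, best_of: int,
                       temperature: float, use_beam_search: bool) -> list:
    """Return `best_of` candidate token ids from one prompt's distribution."""
    if use_beam_search:
        # Beam search: take the best_of highest-probability tokens.
        _, next_token_ids = torch.topk(prob, best_of)
        return next_token_ids.tolist()
    if temperature == 0.0:
        # Greedy sampling only makes sense for a single candidate.
        assert best_of == 1
        return [torch.argmax(prob).item()]
    # Random sampling: draw best_of tokens with replacement.
    next_token_ids = torch.multinomial(prob, num_samples=best_of,
                                       replacement=True)
    return next_token_ids.tolist()


probs = torch.softmax(torch.randn(32000), dim=-1)
print(sample_from_prompt(probs, best_of=4, temperature=0.8,
                         use_beam_search=False))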
cacheflow/outputs.py

-from typing import Dict, List, Union
-
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from typing import Dict, List

 from cacheflow.sequence import SequenceGroup

@@ -9,20 +7,23 @@ class CompletionOutput:
     def __init__(
         self,
+        index: int,
         text: str,
         token_ids: List[int],
-        cumulative_logprobs: float,
+        cumulative_logprob: float,
         logprobs: List[Dict[int, float]],
     ) -> None:
+        self.index = index
         self.text = text
         self.token_ids = token_ids
-        self.cumulative_logprobs = cumulative_logprobs
+        self.cumulative_logprob = cumulative_logprob
         self.logprobs = logprobs

     def __repr__(self) -> str:
-        return (f"CompletionOutput(output={self.text!r}, "
+        return (f"CompletionOutput(index={self.index}, "
+                f"text={self.text!r}, "
                 f"token_ids={self.token_ids}, "
-                f"cumulative_logprobs={self.cumulative_logprobs}, "
+                f"cumulative_logprob={self.cumulative_logprob}, "
                 f"logprobs={self.logprobs})")

@@ -43,31 +44,32 @@ class RequestOutput:
         self.done = done

     @staticmethod
-    def from_seq_group(
-        seq_group: SequenceGroup,
-        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-    ) -> "RequestOutput":
-        outputs: List[CompletionOutput] = []
+    def from_seq_group(seq_group: SequenceGroup) -> "RequestOutput":
+        # Get the top-n sequences.
+        n = seq_group.sampling_params.n
         seqs = seq_group.get_seqs()
-        for seq in seqs:
-            output_token_ids = seq.data.output_token_ids
-            output_str = tokenizer.decode(output_token_ids,
-                                          skip_special_tokens=True)
-            seq_logprobs = seq.data.cumulative_logprobs
+        assert n <= len(seqs)
+        sorted_seqs = sorted(
+            seqs, key=lambda seq: seq.get_cumulative_logprob(), reverse=True)
+        top_n_seqs = sorted_seqs[:n]
+
+        # Create the outputs.
+        outputs: List[CompletionOutput] = []
+        for seq in top_n_seqs:
             logprobs = seq.output_logprobs
             if seq_group.sampling_params.logprobs == 0:
                 # NOTE: We need to take care of this case because the sequence
                 # always has the logprobs of the sampled tokens even if the
                 # logprobs are not requested.
                 logprobs = {}
-            output = CompletionOutput(output_str, output_token_ids,
-                                      seq_logprobs, logprobs)
+            output = CompletionOutput(seqs.index(seq), seq.output_text,
+                                      seq.get_output_token_ids(),
+                                      seq.get_cumulative_logprob(), logprobs)
             outputs.append(output)

         # Every sequence in the sequence group should have the same prompt.
-        prompt = seqs[0].prompt
-        prompt_token_ids = seqs[0].data.prompt_token_ids
+        prompt = top_n_seqs[0].prompt
+        prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
         return RequestOutput(seq_group.request_id, prompt, prompt_token_ids,
                              outputs, seq_group.is_finished())
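RequestOutput.from_seq_group now keeps only the n best of the best_of candidates, ranked by cumulative log-probability. A minimal sketch of that selection with toy data (Candidate and top_n are illustrative names, not the cacheflow types):

from typing import List, NamedTuple


class Candidate(NamedTuple):
    text: str
    cumulative_logprob: float


def top_n(candidates: List[Candidate], n: int) -> List[Candidate]:
    # Rank the best_of candidates by cumulative log-probability and
    # return the n most likely ones, as from_seq_group now does.
    assert n <= len(candidates)
    return sorted(candidates, key=lambda c: c.cumulative_logprob,
                  reverse=True)[:n]


candidates = [Candidate("A", -1.2), Candidate("B", -0.4), Candidate("C", -3.0)]
print(top_n(candidates, n=2))  # B first, then A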
cacheflow/sampling_params.py

 """Sampling parameters for text generation."""
-from typing import Set
+from typing import List, Optional, Union


 class SamplingParams:

@@ -10,8 +10,12 @@ class SamplingParams:
     In addition, we support beam search, which is not supported by OpenAI.

     Args:
-        n: Number of output sequences to generate from the given prompt. This is
-            regarded as the beam width when using beam search.
+        n: Number of output sequences to return for the given prompt.
+        best_of: Number of output sequences that are generated from the prompt.
+            From these `best_of` sequences, the top `n` sequences are returned.
+            `best_of` must be greater than or equal to `n`. This is treated as
+            the beam width when `use_beam_search` is True. By default, `best_of`
+            is set to `n`.
         presence_penalty: Float that penalizes new tokens based on whether they
             appear in the generated text so far. Values > 0 encourage the model
             to use new tokens, while values < 0 encourage the model to repeat
@@ -28,7 +32,10 @@ class SamplingParams:
         top_k: Integer that controls the number of top tokens to consider. Set
             to -1 to consider all tokens.
         use_beam_search: Whether to use beam search instead of sampling.
-        stop_token_ids: Set of token IDs that indicate the end of a sequence.
+        stop: List of strings that stop the generation when they are generated.
+            The returned output will not contain the stop strings.
+        ignore_eos: Whether to ignore the EOS token and continue generating
+            tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
         logprobs: Number of log probabilities to return per output token.
     """
@@ -36,24 +43,28 @@ class SamplingParams:
     def __init__(
         self,
         n: int = 1,
+        best_of: Optional[int] = None,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
         use_beam_search: bool = False,
-        stop_token_ids: Set[int] = set(),
+        stop: Union[str, List[str]] = [],
+        ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: int = 0,
     ) -> None:
         self.n = n
+        self.best_of = best_of if best_of is not None else n
         self.presence_penalty = presence_penalty
         self.frequency_penalty = frequency_penalty
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
         self.use_beam_search = use_beam_search
-        self.stop_token_ids = stop_token_ids
+        self.stop = [stop] if isinstance(stop, str) else list(stop)
+        self.ignore_eos = ignore_eos
         self.max_tokens = max_tokens
         self.logprobs = logprobs
@@ -67,6 +78,9 @@ class SamplingParams:
     def _verify_args(self) -> None:
         if self.n < 1:
             raise ValueError(f"n must be at least 1, got {self.n}.")
+        if self.best_of < self.n:
+            raise ValueError(f"best_of must be greater than or equal to n, "
+                             f"got n={self.n} and best_of={self.best_of}.")
         if not -2.0 <= self.presence_penalty <= 2.0:
             raise ValueError("presence_penalty must be in [-2, 2], got "
                              f"{self.presence_penalty}.")
@@ -89,8 +103,9 @@ class SamplingParams:
                 f"logprobs must be non-negative, got {self.logprobs}.")

     def _verity_beam_search(self) -> None:
-        if self.n == 1:
-            raise ValueError("n must be greater than 1 when using beam search.")
+        if self.best_of == 1:
+            raise ValueError("best_of must be greater than 1 when using beam "
+                             f"search. Got {self.best_of}.")
         if self.temperature > 0.0:
             raise ValueError("temperature must be 0 when using beam search.")
         if self.top_p < 1.0:
@@ -99,8 +114,9 @@ class SamplingParams:
             raise ValueError("top_k must be -1 when using beam search.")

     def _verify_greedy_sampling(self) -> None:
-        if self.n > 1:
-            raise ValueError("n must be 1 when using greedy sampling.")
+        if self.best_of > 1:
+            raise ValueError("best_of must be 1 when using greedy sampling."
+                             f"Got {self.best_of}.")
         if self.top_p < 1.0:
             raise ValueError("top_p must be 1 when using greedy sampling.")
         if self.top_k != -1:
@@ -108,12 +124,14 @@ class SamplingParams:
     def __repr__(self) -> str:
         return (f"SamplingParams(n={self.n}, "
+                f"best_of={self.best_of}, "
                 f"presence_penalty={self.presence_penalty}, "
                 f"frequency_penalty={self.frequency_penalty}, "
                 f"temperature={self.temperature}, "
                 f"top_p={self.top_p}, "
                 f"top_k={self.top_k},"
                 f"use_beam_search={self.use_beam_search}, "
-                f"stop_token_ids={self.stop_token_ids}, "
+                f"stop={self.stop}, "
+                f"ignore_eos={self.ignore_eos}, "
                 f"max_tokens={self.max_tokens}, "
                 f"logprobs={self.logprobs})")
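A short usage sketch of the new interface, using only the constructor arguments shown in the diff above and assuming the cacheflow package from this tree is importable (the concrete stop strings are illustrative):

from cacheflow.sampling_params import SamplingParams

# Draw 5 candidate continuations, return the 2 most likely ones, and stop
# a sequence as soon as it generates "\n\n" or "###" (the stop string
# itself is stripped from the returned text).
params = SamplingParams(
    n=2,
    best_of=5,
    temperature=0.8,
    top_p=0.95,
    stop=["\n\n", "###"],
    max_tokens=128,
)
print(params)  # __repr__ now reports best_of, stop, and ignore_eos

# best_of must be >= n; this should be rejected by _verify_args(),
# assuming the constructor runs the verification as in the rest of the class.
try:
    SamplingParams(n=4, best_of=2)
except ValueError as e:
    print(e)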
cacheflow/sequence.py

@@ -22,11 +22,18 @@ class SequenceData:
         self.prompt_token_ids = prompt_token_ids
         self.output_token_ids: List[int] = []
-        self.cumulative_logprobs = 0.0
+        self.cumulative_logprob = 0.0
+
+    def append_token(self, token_id: int, logprob: float) -> None:
+        self.output_token_ids.append(token_id)
+        self.cumulative_logprob += logprob

     def get_len(self) -> int:
         return len(self.output_token_ids) + len(self.prompt_token_ids)

+    def get_output_len(self) -> int:
+        return len(self.output_token_ids)
+
     def get_token_ids(self) -> List[int]:
         return self.prompt_token_ids + self.output_token_ids
@@ -37,9 +44,9 @@ class SequenceData:
     def __repr__(self) -> str:
         return (f"SequenceData("
-                f"prompt={self.prompt}, "
                 f"prompt_token_ids={self.prompt_token_ids}, "
-                f"output_token_ids={self.output_token_ids})")
+                f"output_token_ids={self.output_token_ids}, "
+                f"cumulative_logprob={self.cumulative_logprob})")


 class Sequence:
@@ -57,6 +64,7 @@ class Sequence:
         self.data = SequenceData(prompt_token_ids)
         self.output_logprobs: List[Dict[int, float]] = []
+        self.output_text = ""

         self.logical_token_blocks: List[LogicalTokenBlock] = []
         # Initialize the logical token blocks with the prompt token ids.
@@ -88,18 +96,26 @@ class Sequence:
         assert token_id in logprobs
         self._append_tokens_to_blocks([token_id])
         self.output_logprobs.append(logprobs)
-        self.data.output_token_ids.append(token_id)
-        self.data.cumulative_logprobs += logprobs[token_id]
+        self.data.append_token(token_id, logprobs[token_id])

     def get_len(self) -> int:
         return self.data.get_len()

+    def get_output_len(self) -> int:
+        return self.data.get_output_len()
+
     def get_token_ids(self) -> List[int]:
         return self.data.get_token_ids()

     def get_last_token_id(self) -> int:
         return self.data.get_last_token_id()

+    def get_output_token_ids(self) -> List[int]:
+        return self.data.output_token_ids
+
+    def get_cumulative_logprob(self) -> float:
+        return self.data.cumulative_logprob
+
     def fork(self, child_seq: 'Sequence') -> 'Sequence':
         child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
         child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
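SequenceData.append_token now bundles the generated token ids with the running log-probability sum that the top-n selection in outputs.py relies on. A small standalone sketch of that accumulation (ToySequenceData is a stand-in, not the cacheflow class):

import math
from typing import List


class ToySequenceData:
    def __init__(self, prompt_token_ids: List[int]) -> None:
        self.prompt_token_ids = prompt_token_ids
        self.output_token_ids: List[int] = []
        self.cumulative_logprob = 0.0

    def append_token(self, token_id: int, logprob: float) -> None:
        # Keep the generated ids and the sum of their log-probabilities
        # in one place, mirroring the new SequenceData.append_token.
        self.output_token_ids.append(token_id)
        self.cumulative_logprob += logprob


data = ToySequenceData([1, 2, 3])
for token_id, p in [(7, 0.5), (9, 0.25)]:
    data.append_token(token_id, math.log(p))
print(data.output_token_ids, round(data.cumulative_logprob, 3))
# [7, 9] -2.079  (= log 0.5 + log 0.25)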
cacheflow/server/llm_server.py

@@ -13,7 +13,7 @@ from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.tokenizer_utils import get_tokenizer
-from cacheflow.sequence import Sequence, SequenceGroup
+from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
 from cacheflow.utils import Counter
 from cacheflow.worker.worker import Worker
@@ -49,7 +49,6 @@ class LLMServer:
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
         self.log_stats = log_stats
         self._verify_args()
-
         self.tokenizer = get_tokenizer(model_config.model)
@@ -124,15 +123,11 @@ class LLMServer:
         # Create the sequences.
         block_size = self.cache_config.block_size
         seqs: List[Sequence] = []
-        for _ in range(sampling_params.n):
+        for _ in range(sampling_params.best_of):
             seq_id = next(self.seq_counter)
             seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
             seqs.append(seq)

-        # FIXME(woosuk)
-        # Add the EOS token to the stop token list.
-        sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
-
         # Create the sequence group.
         seq_group = SequenceGroup(request_id, seqs, sampling_params,
                                   arrival_time)
@@ -157,18 +152,65 @@ class LLMServer:
             blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
             blocks_to_copy=scheduler_outputs.blocks_to_copy,
         )

-        # Update the scheduler.
-        updated_seq_groups = self.scheduler.update(output)
+        # Update the scheduler with the model outputs.
+        seq_groups = self.scheduler.update(output)
+
+        # Decode the sequences.
+        self._decode_sequences(seq_groups)
+        # Stop the sequences that meet the stopping criteria.
+        self._stop_sequences(seq_groups)
+        # Free the finished sequence groups.
+        self.scheduler.free_finished_seq_groups()

         # Create the outputs.
         request_outputs: List[RequestOutput] = []
-        for seq_group in updated_seq_groups:
-            # TODO(woosuk): Batch-decode the outputs for speedup.
-            request_output = RequestOutput.from_seq_group(seq_group,
-                                                          self.tokenizer)
+        for seq_group in seq_groups:
+            request_output = RequestOutput.from_seq_group(seq_group)
             request_outputs.append(request_output)
         return request_outputs

+    def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
+        # Batch-decode the sequence outputs.
+        seqs: List[Sequence] = []
+        for seq_group in seq_groups:
+            seqs.extend(seq_group.get_seqs(status=SequenceStatus.RUNNING))
+        output_tokens_per_seq = []
+        for seq in seqs:
+            output_tokens_per_seq.append(seq.get_output_token_ids())
+        output_texts = self.tokenizer.batch_decode(
+            output_tokens_per_seq, skip_special_tokens=True)
+        # Update the sequences with the output texts.
+        for seq, output_text in zip(seqs, output_texts):
+            seq.output_text = output_text
+
+    def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
+        # Stop the sequences.
+        for seq_group in seq_groups:
+            sampling_params = seq_group.sampling_params
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+                # Check if the sequence has generated a stop string.
+                stopped = False
+                for stop_str in sampling_params.stop:
+                    if seq.output_text.endswith(stop_str):
+                        # Truncate the output text so that the stop string is
+                        # not included in the output.
+                        seq.output_text = seq.output_text[:-len(stop_str)]
+                        self.scheduler.free_seq(seq)
+                        stopped = True
+                        break
+                if stopped:
+                    continue
+
+                # Check if the sequence has reached max_tokens.
+                if seq.get_output_len() == sampling_params.max_tokens:
+                    self.scheduler.free_seq(seq)
+                    continue
+
+                # Check if the sequence has generated the EOS token.
+                if not sampling_params.ignore_eos:
+                    if seq.get_last_token_id() == self.tokenizer.eos_token_id:
+                        self.scheduler.free_seq(seq)
+                        continue
+
     def _run_workers(
         self,
         method: str,
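The stop-string criterion runs on decoded text every step: if the accumulated output ends with a stop string, the string is cut off and the sequence is freed. A minimal standalone sketch of just that check (apply_stop_strings is an illustrative helper, not a cacheflow API):

from typing import List, Tuple


def apply_stop_strings(output_text: str,
                       stop: List[str]) -> Tuple[str, bool]:
    """Return (possibly truncated text, whether generation should stop).

    Mirrors the endswith/truncate logic in _stop_sequences; the returned
    text never contains the stop string itself.
    """
    for stop_str in stop:
        if output_text.endswith(stop_str):
            return output_text[:-len(stop_str)], True
    return output_text, False


text, stopped = apply_stop_strings("Hello, world!\n\n", stop=["\n\n", "###"])
print(repr(text), stopped)  # 'Hello, world!' True

Checking only endswith per step is enough here because the server re-decodes and re-checks after every generated token.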
examples/simple_server.py

@@ -15,9 +15,9 @@ def main(args: argparse.Namespace):
     ("To be or not to be,",
      SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
     ("What is the meaning of life?",
-     SamplingParams(n=2, temperature=0.8, top_p=0.95, frequency_penalty=0.1)),
+     SamplingParams(n=2, best_of=5, temperature=0.8, top_p=0.95, frequency_penalty=0.1)),
     ("It is only with the heart that one can see rightly",
-     SamplingParams(n=3, use_beam_search=True, temperature=0.0)),
+     SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)),
 ]
 # Run the server.