norm / vllm / Commits

Commit 42f1042e (unverified)

Enhance SamplingParams (#96)

Authored May 11, 2023 by Woosuk Kwon; committed by GitHub on May 11, 2023.
Parent: 55f8b0a5

Showing 7 changed files with 36 additions and 54 deletions.
benchmark/benchmark_latency.py               +11  −11
cacheflow/core/scheduler.py                   +1   −1
cacheflow/frontend/fastapi_frontend.py        +2   −2
cacheflow/model_executor/layers/sampler.py    +2   −2
cacheflow/sampling_params.py                 +18  −36
gradio_webserver.py                           +1   −1
simple_server.py                              +1   −1
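Taken together, the diffs below make two changes to the SamplingParams API: the constructor now takes keyword arguments with defaults instead of requiring a dict routed through SamplingParams.from_dict (which is deleted), and two fields are renamed, max_num_steps → max_tokens and num_logprobs → logprobs, matching the OpenAI API's parameter names. A rough before/after sketch of a call site (the values are illustrative, not from the diff):

    # Before: collect parameters in a dict and validate via from_dict.
    params = SamplingParams.from_dict({
        "temperature": 0.8,
        "max_num_steps": 128,   # old field name
    })

    # After: plain keyword arguments with defaults (n=1, temperature=1.0,
    # top_p=1.0, top_k=-1, max_tokens=16, logprobs=0, ...).
    params = SamplingParams(temperature=0.8, max_tokens=128)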
benchmark/benchmark_latency.py

@@ -6,7 +6,7 @@ from tqdm import tqdm
 import numpy as np
 import torch

-from cacheflow.master.server import (
+from cacheflow.core.server import (
     add_server_arguments, process_server_arguments,
     init_local_server_and_frontend_with_arguments)
 from cacheflow.sampling_params import SamplingParams
@@ -15,15 +15,14 @@ from cacheflow.sampling_params import SamplingParams
 def main(args: argparse.Namespace):
     server, frontend = init_local_server_and_frontend_with_arguments(args)

-    sampling_params_dict = {
-        'n': args.n,
-        'temperature': 0.0 if args.use_beam_search else 1.0,
-        'top_p': 1.0,
-        'use_beam_search': args.use_beam_search,
-        'stop_token_ids': set(),
-        'max_num_steps': args.output_len,
-    }
-    sampling_params = SamplingParams.from_dict(sampling_params_dict)
+    sampling_params = SamplingParams(
+        n=args.n,
+        temperature=0.0 if args.use_beam_search else 1.0,
+        top_p=1.0,
+        use_beam_search=args.use_beam_search,
+        stop_token_ids=set(),
+        max_tokens=args.output_len,
+    )
     print(sampling_params)

     input_token_ids = [0] * args.input_len
@@ -31,7 +30,8 @@ def main(args: argparse.Namespace):
         if profile:
             torch.cuda.cudart().cudaProfilerStart()
         for _ in range(args.batch_size):
-            frontend._add_query(input_token_ids, sampling_params)
+            dummy_prompt = ""
+            frontend._add_query(dummy_prompt, input_token_ids, sampling_params)
         server.add_sequence_groups(frontend.get_inputs())
         start_time = time.time()
         while True:
cacheflow/core/scheduler.py

@@ -316,7 +316,7 @@ class Scheduler:
                 continue

             # Check if the sequence has reached the maximum number of steps.
-            max_num_steps = self.sampling_params[group_id].max_num_steps
+            max_num_steps = self.sampling_params[group_id].max_tokens
             if self.num_steps[group_id] == max_num_steps:
                 self._free_seq(seq)
                 continue
cacheflow/frontend/fastapi_frontend.py

@@ -89,8 +89,8 @@ class FastAPIServer:
     async def generate(self, request_dict: Dict):
         # Preprocess the request.
-        prompt = request_dict["prompt"]
-        sampling_params = SamplingParams.from_dict(request_dict)
+        prompt = request_dict.pop("prompt")
+        sampling_params = SamplingParams(**request_dict)
         sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
         token_ids = self.tokenizer.encode(prompt)
         seqs: List[Sequence] = []
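With this change, the FastAPI handler splits a JSON request into the prompt and the sampling parameters by popping prompt and unpacking the rest. One behavioral consequence (my inference, not stated in the commit): unrecognized keys now raise TypeError from __init__ rather than the ValueError that from_dict produced. A minimal sketch with a hypothetical payload:

    request_dict = {"prompt": "Hello", "temperature": 0.8, "max_tokens": 64}
    prompt = request_dict.pop("prompt")               # strip the non-sampling key
    sampling_params = SamplingParams(**request_dict)  # unknown keys raise TypeError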
cacheflow/model_executor/layers/sampler.py

@@ -367,7 +367,7 @@ def _sample(
         next_token_ids = _sample_from_prompt(prob, sampling_params)
         # Get top-k log probabilities for the next tokens.
         next_logprobs = _get_topk_logprobs(
-            logprob, sampling_params.num_logprobs)
+            logprob, sampling_params.logprobs)

         # Build the output.
         for seq_id, next_token_id in zip(seq_ids, next_token_ids):
@@ -392,7 +392,7 @@ def _sample(
         next_logprobs: Dict[int, Dict[int, float]] = {}
         for i, seq_id in enumerate(seq_ids):
             next_logprobs[seq_id] = _get_topk_logprobs(
-                logprob[i], sampling_params.num_logprobs)
+                logprob[i], sampling_params.logprobs)

         # Build the output.
         for seq_id, parent_seq_id, next_token_id in zip(
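The sampler reads the renamed logprobs field as the number of top log-probabilities to attach to each generated token (the new default of 0 requests none). A short usage sketch:

    params = SamplingParams(n=1, logprobs=5)  # keep the top-5 logprobs per step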
cacheflow/sampling_params.py

@@ -5,16 +5,16 @@ class SamplingParams:

     def __init__(
         self,
-        n: int,
-        presence_penalty: float,
-        frequency_penalty: float,
-        temperature: float,
-        top_p: float,
-        top_k: int,
-        use_beam_search: bool,
-        stop_token_ids: Set[int],
-        max_num_steps: int,
-        num_logprobs: int,
+        n: int = 1,
+        presence_penalty: float = 0.0,
+        frequency_penalty: float = 0.0,
+        temperature: float = 1.0,
+        top_p: float = 1.0,
+        top_k: int = -1,
+        use_beam_search: bool = False,
+        stop_token_ids: Set[int] = set(),
+        max_tokens: int = 16,
+        logprobs: int = 0,
     ) -> None:
         if n < 1:
             raise ValueError(f"n must be at least 1, got {n}.")
@@ -32,12 +32,12 @@ class SamplingParams:
         if top_k < -1 or top_k == 0:
             raise ValueError(f"top_k must be -1 (disable), or at least 1, "
                              f"got {top_k}.")
-        if max_num_steps < 1:
+        if max_tokens < 1:
             raise ValueError(
-                f"max_num_steps must be at least 1, got {max_num_steps}.")
-        if num_logprobs < 0:
+                f"max_tokens must be at least 1, got {max_tokens}.")
+        if logprobs < 0:
             raise ValueError(
-                f"num_logprobs must be non-negative, got {num_logprobs}.")
+                f"logprobs must be non-negative, got {logprobs}.")

         if use_beam_search:
             if n == 1:
@@ -72,8 +72,8 @@ class SamplingParams:
         self.top_k = top_k
         self.use_beam_search = use_beam_search
         self.stop_token_ids = stop_token_ids
-        self.max_num_steps = max_num_steps
-        self.num_logprobs = num_logprobs
+        self.max_tokens = max_tokens
+        self.logprobs = logprobs

     def __repr__(self) -> str:
         return (f"SamplingParams(n={self.n}, "
@@ -84,23 +84,5 @@ class SamplingParams:
                 f"top_k={self.top_k},"
                 f"use_beam_search={self.use_beam_search}, "
                 f"stop_token_ids={self.stop_token_ids}, "
-                f"max_num_steps={self.max_num_steps}, "
-                f"num_logprobs={self.num_logprobs}")
-
-    @classmethod
-    def from_dict(cls, d: Dict) -> "SamplingParams":
-        sampling_params = cls(
-            n=d.pop("n", 1),
-            presence_penalty=d.pop("presence_penalty", 0.0),
-            frequency_penalty=d.pop("frequency_penalty", 0.0),
-            temperature=d.pop("temperature", 1.0),
-            top_p=d.pop("top_p", 1.0),
-            top_k=d.pop("top_k", -1),
-            use_beam_search=d.pop("use_beam_search", False),
-            stop_token_ids=set(d.pop("stop_token_ids", set())),
-            max_num_steps=d.pop("max_num_steps", 16),
-            num_logprobs=d.pop("num_logprobs", 0),
-        )
-        if d:
-            raise ValueError(f"Unrecognized keys in dict: {d.keys()}")
-        return sampling_params
+                f"max_tokens={self.max_tokens}, "
+                f"logprobs={self.logprobs}")
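With defaults on every argument, SamplingParams() is now a valid call, equivalent to spelling out n=1, temperature=1.0, top_p=1.0, top_k=-1, use_beam_search=False, max_tokens=16, logprobs=0. One caveat worth flagging in review: stop_token_ids: Set[int] = set() is a mutable default, evaluated once at definition time, so instances constructed without an explicit set share a single object; the frontend's stop_token_ids.add(...) call above would therefore leak stop tokens into later requests. A small sketch of the behavior:

    from cacheflow.sampling_params import SamplingParams

    p = SamplingParams()            # uses the shared default stop_token_ids set
    q = SamplingParams()
    p.stop_token_ids.add(2)         # mutates the one shared set object
    assert 2 in q.stop_token_ids    # q observes the same mutation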
gradio_webserver.py

@@ -10,7 +10,7 @@ def http_bot(prompt):
     headers = {"User-Agent": "Cacheflow Client"}
     pload = {
         "prompt": prompt,
-        "max_num_steps": 128,
+        "max_tokens": 128,
     }
     response = requests.post(args.model_url, headers=headers, json=pload, stream=True)
simple_server.py

@@ -18,7 +18,7 @@ def main(args: argparse.Namespace):
     while True:
         if test_inputs:
             text, sampling_params_dict = test_inputs.pop(0)
-            sampling_params = SamplingParams.from_dict(sampling_params_dict)
+            sampling_params = SamplingParams(**sampling_params_dict)
             sampling_params = frontend.add_eos_token(sampling_params)
             frontend.query(text, sampling_params)
         server.add_sequence_groups(frontend.get_inputs())