norm/vllm · Commit 3f942acf (unverified)
Authored May 22, 2023 by Woosuk Kwon; committed by GitHub on May 22, 2023
Fix latency benchmark script (#118)
Parent: 19d28994
Changes: 2 changed files with 43 additions and 31 deletions (+43 -31)

  benchmark/benchmark_latency.py   +33 -29
  cacheflow/entrypoints/llm.py     +10 -2
benchmark/benchmark_latency.py (view file @ 3f942acf)
 import argparse
 import time
-from typing import List
-from tqdm import tqdm
 import numpy as np
 import torch
+from tqdm import tqdm
 
-from cacheflow.core.server import (
-    add_server_arguments, process_server_arguments,
-    init_local_server_and_frontend_with_arguments)
-from cacheflow.sampling_params import SamplingParams
+from cacheflow import LLM, SamplingParams
 
 
 def main(args: argparse.Namespace):
-    server, frontend = init_local_server_and_frontend_with_arguments(args)
+    print(args)
+
+    # Process all the requests in a single batch if possible.
+    # NOTE(woosuk): If the request cannot be processed in a single batch,
+    # the server will automatically process the request in multiple batches.
+    llm = LLM(
+        model=args.model,
+        tensor_parallel_size=args.tensor_parallel_size,
+        max_num_seqs=args.batch_size,
+        max_num_batched_tokens=args.batch_size * args.input_len,
+    )
 
     sampling_params = SamplingParams(
         n=args.n,
         temperature=0.0 if args.use_beam_search else 1.0,
         top_p=1.0,
         use_beam_search=args.use_beam_search,
-        stop_token_ids=set(),
+        ignore_eos=True,
         max_tokens=args.output_len,
     )
     print(sampling_params)
-    input_token_ids = [0] * args.input_len
+    dummy_prompts = [""] * args.batch_size
+    dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size
 
-    def profile_step(profile=False):
+    def run_to_completion(profile: bool = False):
         if profile:
             torch.cuda.cudart().cudaProfilerStart()
-        for _ in range(args.batch_size):
-            dummy_prompt = ""
-            frontend._add_query(dummy_prompt, input_token_ids, sampling_params)
-        server.add_sequence_groups(frontend.get_inputs())
         start_time = time.time()
-        while True:
-            server.step()
-            if not server.has_unfinished_requests():
-                break
+
+        llm.generate(dummy_prompts, sampling_params, dummy_prompt_token_ids,
+                     use_tqdm=False)
+
         end_time = time.time()
         latency = end_time - start_time
         if profile:
             torch.cuda.cudart().cudaProfilerStop()
         return latency
 
-    print("Warm up step")
-    profile_step()
+    print("Warming up...")
+    run_to_completion(profile=False)
 
     # Benchmark.
     latencies = []
-    for _ in tqdm(range(3), desc="Profile step"):
-        latencies.append(profile_step())
+    for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
+        latencies.append(run_to_completion(profile=False))
     print(f'Avg latency: {np.mean(latencies)} seconds')
 
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
-        description='Benchmark the latency of decoding a single sentence.')
-    parser = add_server_arguments(parser)
+        description='Benchmark the latency of processing a single batch of '
+                    'requests till completion.')
+    parser.add_argument('--model', type=str, default='facebook/opt-125m')
+    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
     parser.add_argument('--input-len', type=int, default=32)
     parser.add_argument('--output-len', type=int, default=128)
     parser.add_argument('--batch-size', type=int, default=8)
-    parser.add_argument('--n', type=int, default=1)
+    parser.add_argument('--n', type=int, default=1,
+                        help='Number of generated sequences per prompt.')
     parser.add_argument('--use-beam-search', action='store_true')
+    parser.add_argument('--num-iters', type=int, default=3,
+                        help='Number of iterations to run.')
     args = parser.parse_args()
-    args = process_server_arguments(args)
-    args.max_num_batched_tokens = max(
-        args.max_num_batched_tokens, args.batch_size * args.input_len)
-    print(args)
     main(args)
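
For orientation, a minimal sketch (not part of the commit) of the pattern the fixed benchmark now follows: build an LLM engine, then time a single generate() call over a batch of dummy token IDs. The model name and sizes below are illustrative.

# Sketch only: mirrors what the fixed benchmark_latency.py does internally.
# Assumes the cacheflow package from this commit is installed and a GPU is available.
import time

from cacheflow import LLM, SamplingParams

batch_size, input_len, output_len = 8, 32, 128  # illustrative values
llm = LLM(model="facebook/opt-125m",
          max_num_seqs=batch_size,
          max_num_batched_tokens=batch_size * input_len)
sampling_params = SamplingParams(n=1, temperature=1.0, top_p=1.0,
                                 ignore_eos=True, max_tokens=output_len)

# One empty prompt string per request; the token IDs carry the actual input.
dummy_prompts = [""] * batch_size
dummy_prompt_token_ids = [[0] * input_len] * batch_size

start = time.time()
llm.generate(dummy_prompts, sampling_params, dummy_prompt_token_ids,
             use_tqdm=False)
print(f"Latency: {time.time() - start:.3f} seconds")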
cacheflow/entrypoints/llm.py (view file @ 3f942acf)
@@ -35,18 +35,26 @@ class LLM:
         self,
         prompts: List[str],
         sampling_params: Optional[SamplingParams] = None,
+        prompt_token_ids: Optional[List[List[int]]] = None,
         use_tqdm: bool = True,
     ) -> List[RequestOutput]:
         if sampling_params is None:
+            # Use default sampling params.
             sampling_params = SamplingParams()
 
         # Initialize tqdm.
         if use_tqdm:
             pbar = tqdm(total=len(prompts), desc="Processed prompts")
 
         # Add requests to the server.
-        for prompt in prompts:
+        for i in range(len(prompts)):
+            prompt = prompts[i]
+            if prompt_token_ids is None:
+                token_ids = None
+            else:
+                token_ids = prompt_token_ids[i]
             request_id = str(next(self.request_counter))
-            self.llm_server.add_request(request_id, prompt, sampling_params)
+            self.llm_server.add_request(request_id, prompt, sampling_params,
+                                        token_ids)
         # Run the server.
         outputs: List[RequestOutput] = []
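
This hunk is what makes the benchmark fix possible: LLM.generate() now accepts an optional prompt_token_ids argument, pairing the i-th token-ID list with the i-th prompt string. A hedged usage sketch follows; the model name and token IDs are illustrative, and it assumes the remaining LLM constructor arguments have defaults.

# Sketch only: exercising the new prompt_token_ids parameter of LLM.generate().
from cacheflow import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # illustrative model choice
sampling_params = SamplingParams(max_tokens=16)

# Pre-tokenized inputs; the prompt strings are empty placeholders, so the
# i-th token-ID list is what actually gets decoded for the i-th request.
prompt_token_ids = [[0, 1, 2, 3], [4, 5, 6]]
prompts = [""] * len(prompt_token_ids)

outputs = llm.generate(prompts, sampling_params, prompt_token_ids,
                       use_tqdm=False)
for output in outputs:
    print(output)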