sglang, commit ebf69964 (unverified)

latency test enhancement - final part (#921)

Authored Aug 04, 2024 by min-xu-et; committed by GitHub on Aug 04, 2024.
Parent: 141e8c71

Showing 2 changed files with 145 additions and 35 deletions (+145, -35):

  python/pyproject.toml             +3    -1
  python/sglang/bench_latency.py    +142  -34

python/pyproject.toml

@@ -20,14 +20,16 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
-srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "jsonlines",
+srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
        "packaging", "pillow", "psutil", "pydantic", "python-multipart",
        "torch", "uvicorn", "uvloop", "zmq",
        "vllm==0.5.3.post1", "outlines>=0.0.44"]
 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
 litellm = ["litellm>=1.0.0"]
+test = ["jsonlines", "matplotlib", "pandas"]
 all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
+dev = ["sglang[all]", "sglang[test]"]
 
 [project.urls]
 "Homepage" = "https://github.com/sgl-project/sglang"
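The new test extra (jsonlines, matplotlib, pandas) appears to back the benchmark's new result pipeline: one JSONL record is appended per measured data point and later read back for plotting, while dev is simply all plus test. Below is a minimal sketch of that round trip, assuming a writable out.jsonl; the field names mirror measurement_results in the bench_latency.py diff further down, and the prefill_throughput value is a placeholder, not a measurement.

import jsonlines
import pandas as pd

# One record per (batch_size, input_len, output_len) data point; the
# prefill_throughput number below is invented for illustration only.
record = {
    "run_name": "before",
    "batch_size": 1,
    "input_len": 256,
    "output_len": 32,
    "prefill_throughput": 1234.5,
}

# Append to the result file, the same pattern latency_test uses with --result-filename.
with jsonlines.open("out.jsonl", mode="a") as f:
    f.write_all([record])

# Read everything back, the same pattern plot_latency_test uses before querying it.
df = pd.read_json("out.jsonl", lines=True)
print(df)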

python/sglang/bench_latency.py
"""
"""
Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.
Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.
# Usage (latency test) with dummy weights:
# Usage (latency test)
## with dummy weights:
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
## sweep through multiple data points and store (append) the results in a jsonl file:
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl
## do some changes, and store the results under a different run_name:
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl --run-name after
## plot the results in series of lines:
python -m sglang.bench_latency --result-filename out.jsonl --graph-sql="select run_name, batch_size, prefill_throughput from results"
# Usage (correctness test):
# Usage (correctness test):
python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
##
#
Reference output (of the correctness test above, can be gpu dependent):
## Reference output (of the correctness test above, can be gpu dependent):
prefill logits (first half) tensor([[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
prefill logits (first half) tensor([[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
[ -9.1875, -10.2500, 2.7109, ..., -4.3359, -4.0664, -4.1328]],
[ -9.1875, -10.2500, 2.7109, ..., -4.3359, -4.0664, -4.1328]],
@@ -28,13 +36,16 @@ I'm going to the park
 import argparse
 import dataclasses
+import itertools
 import logging
 import multiprocessing
+import os
+import sqlite3
 import time
 from typing import Tuple
 
+import jsonlines
 import numpy as np
+import pandas as pd
 import torch
 import torch.distributed as dist
@@ -49,26 +60,42 @@ from sglang.srt.utils import suppress_other_loggers
 @dataclasses.dataclass
 class BenchArgs:
+    run_name: str = "before"
     batch_size: Tuple[int] = (1,)
-    input_len: int = 1024
-    output_len: int = 4
+    input_len: Tuple[int] = (1024,)
+    output_len: Tuple[int] = (4,)
     result_filename: str = ""
     correctness_test: bool = False
     # This is only used for correctness test
     cut_len: int = 4
+    # Plotting args
+    graph_sql: str = (
+        "select run_name, batch_size, prefill_throughput from results where run_name='before'"
+    )
+    graph_filename: str = "out.png"
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
+        parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
         parser.add_argument(
             "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
         )
-        parser.add_argument("--input-len", type=int, default=BenchArgs.input_len)
-        parser.add_argument("--output-len", type=int, default=BenchArgs.output_len)
+        parser.add_argument(
+            "--input-len", type=int, nargs="+", default=BenchArgs.input_len
+        )
+        parser.add_argument(
+            "--output-len", type=int, nargs="+", default=BenchArgs.output_len
+        )
         parser.add_argument(
             "--result-filename", type=str, default=BenchArgs.result_filename
         )
         parser.add_argument("--correctness-test", action="store_true")
        parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
+        # graphing
+        parser.add_argument("--graph-sql", type=str, default=BenchArgs.graph_sql)
+        parser.add_argument(
+            "--graph-filename", type=str, default=BenchArgs.graph_filename
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
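With nargs="+", batch_size, input_len, and output_len now accept several values in one invocation, and latency_test sweeps their cross product. The sketch below is a standalone illustration of that argparse behavior; only the three option names are taken from BenchArgs, and plain list defaults stand in for the dataclass tuples.

import argparse
import itertools

# Standalone illustration of the nargs="+" parsing used by BenchArgs above.
parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, nargs="+", default=[1])
parser.add_argument("--input-len", type=int, nargs="+", default=[1024])
parser.add_argument("--output-len", type=int, nargs="+", default=[4])

args = parser.parse_args(["--batch-size", "1", "12", "14", "--input-len", "256", "512"])
print(args.batch_size)  # [1, 12, 14]
print(args.input_len)   # [256, 512]
print(args.output_len)  # [4]

# The sweep in latency_test visits the cross product of the three lists.
points = list(itertools.product(args.batch_size, args.input_len, args.output_len))
print(len(points))      # 6 combinations for the arguments above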
@@ -222,15 +249,21 @@ def correctness_test(
 
 @torch.inference_mode()
 def latency_test_run_once(
-    model_runner, rank_print, reqs, batch_size, input_len, output_len
+    run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len
 ):
+    max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
+    if batch_size > max_batch_size:
+        rank_print(
+            f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
+        )
+        return
 
     # Clear the pools.
     model_runner.req_to_token_pool.clear()
     model_runner.token_to_kv_pool.clear()
 
     measurement_results = {
-        "run_name": "before",
+        "run_name": run_name,
         "batch_size": batch_size,
         "input_len": input_len,
         "output_len": output_len,
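The guard added above skips any (batch_size, input_len, output_len) combination whose token budget exceeds what the runner's memory pool can hold at once, so an aggressive sweep does not abort. A small sketch of the same rule follows, with max_total_num_tokens set to a made-up placeholder rather than a value taken from a real ModelRunner.

# Standalone sketch of the skip rule: a point is dropped when batch_size
# exceeds the number of requests whose tokens fit in the pool at once.
max_total_num_tokens = 200_000  # placeholder, not a real pool size

def fits(batch_size, input_len, output_len):
    max_batch_size = max_total_num_tokens // (input_len + output_len)
    return batch_size <= max_batch_size

print(fits(1, 256, 32))      # True
print(fits(12, 512, 256))    # True  (200_000 // 768 = 260 >= 12)
print(fits(1024, 512, 256))  # False (260 < 1024), so this run is skipped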
@@ -291,49 +324,119 @@ def latency_test(
     # Load the model
     model_runner, tokenizer = load_model(server_args, tp_rank)
-    rank_print(
-        f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
-    )
 
-    # To make this PR easier to review, for now, only do the first element in batch_size tuple.
-    bench_args.batch_size = bench_args.batch_size[0]
-
-    # Prepare inputs
+    # Prepare inputs for warm up
     reqs = prepare_synthetic_inputs_for_latency_test(
-        bench_args.batch_size, bench_args.input_len
+        bench_args.batch_size[0], bench_args.input_len[0]
     )
 
     # Warm up
     latency_test_run_once(
-        model_runner, rank_print, reqs, bench_args.batch_size, bench_args.input_len, 4
+        bench_args.run_name,
+        model_runner,
+        rank_print,
+        reqs,
+        bench_args.batch_size[0],
+        bench_args.input_len[0],
+        4,  # shorter decoding to speed up the warmup
     )
 
-    # Run again
+    # Run the sweep
     result_list = []
-    result_list.append(
-        latency_test_run_once(
-            model_runner,
-            rank_print,
-            reqs,
-            bench_args.batch_size,
-            bench_args.input_len,
-            bench_args.output_len,
-        )
-    )
+    for bs, il, ol in itertools.product(
+        bench_args.batch_size, bench_args.input_len, bench_args.output_len
+    ):
+        reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
+        ret = latency_test_run_once(
+            bench_args.run_name, model_runner, rank_print, reqs, bs, il, ol
+        )
+        if ret is not None:
+            result_list.append(ret)
 
-    # Write results in jsonlines format.
-    if bench_args.result_filename:
-        import jsonlines
-
+    # Write results in jsonlines format on rank 0.
+    if tp_rank == 0 and bench_args.result_filename:
         with jsonlines.open(bench_args.result_filename, "a") as f:
             f.write_all(result_list)
 
 
+def plot_latency_test(
+    server_args,
+    bench_args,
+    tp_rank,
+):
+    assert tp_rank == 0
+
+    # read the jsonl file and put in sqlite
+    df = pd.read_json(bench_args.result_filename, lines=True)
+    conn = sqlite3.connect(":memory:")
+    cur = conn.cursor()
+
+    # get the columns and their types
+    column_names = list(df.iloc[0].keys())
+    type_dict = {
+        str: "TEXT",
+        np.int64: "INTEGER",
+        np.float64: "FLOAT",
+    }
+    column_types = [type_dict[type(i)] for i in list(df.iloc[0])]
+
+    # create the table
+    cur.execute(
+        f"""
+        CREATE TABLE IF NOT EXISTS results (
+            {", ".join([f"{name} {type}" for name, type in zip(column_names, column_types)])}
+        )
+        """
+    )
+    conn.commit()
+
+    # write the results to DB
+    df.to_sql("results", conn, if_exists="replace", index=False)
+    conn.commit()
+
+    # read it back using sql
+    df = pd.read_sql_query(bench_args.graph_sql, conn)
+    conn.close()
+
+    # plot it and save to a file
+    import matplotlib.pyplot as plt
+
+    assert (
+        len(df.columns) == 3
+    ), f"The sql should have fetched <series, x, y> columns, not {df.columns}"
+    for label in df[df.columns[0]].unique():
+        q = f"{df.columns[0]}=='{label}'"
+        series = df.query(q)
+        plt.plot(series[df.columns[1]], series[df.columns[2]], label=q, marker="o")
+    plt.xlabel(df.columns[1])
+    plt.ylabel(df.columns[2])
+    plt.legend()
+    plt.savefig(bench_args.graph_filename, dpi=300)
+
+    # if in kitty, just dump it to the terminal
+    if os.environ["TERM"] == "xterm-kitty":
+        os.system(
+            f"kitty icat --use-window-size 1,1,600,600 {bench_args.graph_filename}"
+        )
+
+
 def main(server_args, bench_args):
-    if bench_args.correctness_test:
-        work_func = correctness_test
+    print(bench_args)
+
+    if server_args.model_path:
+        if bench_args.correctness_test:
+            work_func = correctness_test
+        else:
+            work_func = latency_test
+    elif os.path.isfile(bench_args.result_filename):
+        assert bench_args.graph_filename, "please provide a filename for the graph"
+        work_func = plot_latency_test
     else:
-        work_func = latency_test
+        raise ValueError(
+            "Provide --model-path for running the tests or "
+            "provide --result-filename for plotting the results"
+        )
 
     if server_args.tp_size == 1:
         work_func(server_args, bench_args, 0)
@@ -361,6 +464,11 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     ServerArgs.add_cli_args(parser)
     BenchArgs.add_cli_args(parser)
+    # For this script, model-path is not required
+    assert (
+        parser._actions[1].option_strings[0] == "--model-path"
+    ), "options changed, this code need to be updated"
+    parser._actions[1].required = False
     args = parser.parse_args()
 
     server_args = ServerArgs.from_cli_args(args)
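The --graph-sql contract in plot_latency_test is that the query returns exactly three columns, read as <series, x, y>; each distinct value in the first column becomes one line in the saved figure. Below is a small pandas-only sketch of that grouping, with invented numbers standing in for real benchmark output.

import pandas as pd

# The first column is the series key, the second the x axis, the third the y axis.
# All values here are placeholders, not measured results.
df = pd.DataFrame(
    {
        "run_name": ["before", "before", "after", "after"],
        "batch_size": [1, 12, 1, 12],
        "prefill_throughput": [1000.0, 9000.0, 1100.0, 9900.0],
    }
)
for label in df[df.columns[0]].unique():
    series = df.query(f"{df.columns[0]}=='{label}'")
    # Each (x, y) pair below would become one point on the line for this label.
    print(label, list(zip(series[df.columns[1]], series[df.columns[2]])))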
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment