Unverified commit 53985645 (sglang), authored Aug 03, 2024 by min-xu-et, committed by GitHub on Aug 03, 2024

latency test enhancement - part 1 (#909)

parent 70cc0749

Showing 2 changed files with 53 additions and 15 deletions (+53 -15)

python/pyproject.toml            +1  -1
python/sglang/bench_latency.py   +52 -14

python/pyproject.toml
@@ -21,7 +21,7 @@ dependencies = [

 [project.optional-dependencies]
 srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "packaging", "pillow",
-       "psutil", "pydantic", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.3.post1", "outlines>=0.0.44", "python-multipart"]
+       "psutil", "pydantic", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.3.post1", "outlines>=0.0.44", "python-multipart", "jsonlines"]
 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
 litellm = ["litellm>=1.0.0"]
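
The new "jsonlines" dependency is what bench_latency.py (below) uses to append one JSON record per benchmark run to --result-filename. As a rough illustrative sketch only (the file name results.jsonl and the record contents here are made up, not part of this commit), the library is used like this:

    import jsonlines

    # Append records, one JSON object per line, mirroring the pattern used in bench_latency.py.
    with jsonlines.open("results.jsonl", "a") as f:
        f.write_all([{"batch_size": 1, "total_latency": 0.18}])

    # Read the records back for later analysis.
    with jsonlines.open("results.jsonl") as reader:
        for record in reader:
            print(record["batch_size"], record["total_latency"])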

python/sglang/bench_latency.py
"""
Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.
# Usage (latency test):
# Usage (latency test)
with dummy weights
:
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
# Usage (correctness test):
python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
### Reference output:
### Reference output
(of the correctness test above, can be gpu dependent)
:
prefill logits (first half) tensor([[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
[ -9.1875, -10.2500, 2.7109, ..., -4.3359, -4.0664, -4.1328]],
...
...
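
Taken together with the CLI changes further down, a latency run that records its results to a file would presumably be invoked along these lines (the model path is copied from the docstring above; the result file name and the extra flag values are only examples, and this part-1 change still benchmarks only the first batch size):

    python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct \
        --load-format dummy --batch-size 1 8 16 --result-filename results.jsonl
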
@@ -31,7 +31,9 @@ import dataclasses
 import logging
 import multiprocessing
 import time
+from typing import Tuple

+import jsonlines
 import numpy as np
 import torch
 import torch.distributed as dist
@@ -47,25 +49,34 @@ from sglang.srt.utils import suppress_other_loggers

 @dataclasses.dataclass
 class BenchArgs:
-    batch_size: int = 1
+    batch_size: Tuple[int] = (1,)
     input_len: int = 1024
     output_len: int = 4
     result_filename: str = ""
     correctness_test: bool = False
     # This is only used for correctness test
     cut_len: int = 4

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
-        parser.add_argument("--batch-size", type=int, default=BenchArgs.batch_size)
+        parser.add_argument(
+            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
+        )
         parser.add_argument("--input-len", type=int, default=BenchArgs.input_len)
         parser.add_argument("--output-len", type=int, default=BenchArgs.output_len)
         parser.add_argument(
             "--result-filename", type=str, default=BenchArgs.result_filename
         )
         parser.add_argument("--correctness-test", action="store_true")
         parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
-        attrs = [attr.name for attr in dataclasses.fields(cls)]
-        return cls(**{attr: getattr(args, attr) for attr in attrs})
+        # use the default value's type to case the args into correct types.
+        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
+        return cls(
+            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
+        )


 def load_model(server_args, tp_rank):
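
The interplay between nargs="+" and the new from_cli_args is the subtle part of this hunk: argparse hands back a list, and from_cli_args casts every parsed value with the type of the corresponding field default, which turns that list back into a tuple. A minimal, self-contained sketch of that behavior (not the sglang code itself; the Args class below is a hypothetical stand-in for BenchArgs):

    import argparse
    import dataclasses
    from typing import Tuple


    @dataclasses.dataclass
    class Args:
        batch_size: Tuple[int] = (1,)
        input_len: int = 1024

        @classmethod
        def from_cli_args(cls, args: argparse.Namespace):
            # Cast each parsed value with the type of the field's default,
            # so the nargs="+" list becomes a tuple again.
            attrs = [(f.name, type(f.default)) for f in dataclasses.fields(cls)]
            return cls(**{name: typ(getattr(args, name)) for name, typ in attrs})


    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", type=int, nargs="+", default=Args.batch_size)
    parser.add_argument("--input-len", type=int, default=Args.input_len)

    ns = parser.parse_args(["--batch-size", "1", "8", "16"])
    print(Args.from_cli_args(ns))  # Args(batch_size=(1, 8, 16), input_len=1024)
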
@@ -93,7 +104,7 @@ def load_model(server_args, tp_rank):
     return model_runner, tokenizer


-def prepare_inputs(bench_args, tokenizer):
+def prepare_inputs_for_correctness_test(bench_args, tokenizer):
     prompts = [
         "The capital of France is",
         "The capital of the United Kindom is",
@@ -119,7 +130,9 @@ def prepare_inputs(bench_args, tokenizer):
     return input_ids, reqs


-def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
+def prepare_extend_inputs_for_correctness_test(
+    bench_args, input_ids, reqs, model_runner
+):
     for i in range(len(reqs)):
         req = reqs[i]
         req.input_ids += input_ids[i][bench_args.cut_len :]
@@ -129,8 +142,8 @@ def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
     return reqs


-def prepare_synthetic_inputs(bench_args, tokenizer):
-    input_ids = np.ones((bench_args.batch_size, bench_args.input_len), dtype=np.int32)
+def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
+    input_ids = np.ones((batch_size, input_len), dtype=np.int32)
     sampling_params = SamplingParams(
         temperature=0,
         max_new_tokens=BenchArgs.output_len,
@@ -179,7 +192,7 @@ def correctness_test(
     model_runner, tokenizer = load_model(server_args, tp_rank)

     # Prepare inputs
-    input_ids, reqs = prepare_inputs(bench_args, tokenizer)
+    input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)

     if bench_args.cut_len > 0:
         # Prefill
@@ -187,7 +200,9 @@ def correctness_test(
         rank_print("prefill logits (first half)", next_token_logits)

     # Prepare extend inputs
-    reqs = prepare_extend_inputs(bench_args, input_ids, reqs, model_runner)
+    reqs = prepare_extend_inputs_for_correctness_test(
+        bench_args, input_ids, reqs, model_runner
+    )

     # Extend
     next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
@@ -218,8 +233,13 @@ def latency_test(
         f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
     )

+    # To make this PR easier to review, for now, only do the first element in batch_size tuple.
+    bench_args.batch_size = bench_args.batch_size[0]
+
     # Prepare inputs
-    reqs = prepare_synthetic_inputs(bench_args, tokenizer)
+    reqs = prepare_synthetic_inputs_for_latency_test(
+        bench_args.batch_size, bench_args.input_len
+    )

     def clear():
         model_runner.req_to_token_pool.clear()
@@ -227,6 +247,11 @@ def latency_test(

     @torch.inference_mode()
     def run_once(output_len):
+        measurement_results = {
+            "batch_size": bench_args.batch_size,
+            "output_len": output_len,
+        }
+
         # Prefill
         torch.cuda.synchronize()
         tot_latency = 0
@@ -239,6 +264,8 @@ def latency_test(
         rank_print(
             f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
         )
+        measurement_results["prefill_latency"] = prefill_latency
+        measurement_results["prefill_throughput"] = throughput

         # Decode
         for i in range(output_len):
@@ -258,6 +285,8 @@ def latency_test(
         rank_print(
             f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
         )
+        measurement_results["avg_decode_latency"] = avg_decode_latency
+        measurement_results["avg_decode_throughput"] = avg_decode_throughput

         throughput = (
             (bench_args.input_len + bench_args.output_len)
@@ -267,13 +296,22 @@ def latency_test(
         rank_print(
             f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
         )
+        measurement_results["total_latency"] = tot_latency
+        measurement_results["total_throughput"] = throughput
+        return measurement_results

     # Warm up
     run_once(4)
     clear()

     # Run again
-    run_once(bench_args.output_len)
+    result_list = []
+    result_list.append(run_once(bench_args.output_len))
+
+    # Write results in jsonlines format.
+    if bench_args.result_filename:
+        with jsonlines.open(bench_args.result_filename, "a") as f:
+            f.write_all(result_list)


 def main(server_args, bench_args):
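
With these changes, each call to run_once() returns a dictionary that is appended to the result file as one JSON line. A hypothetical record, pretty-printed here for readability (the field values are invented for illustration; the keys are the ones run_once now collects):

    {
        "batch_size": 1,
        "output_len": 4,
        "prefill_latency": 0.123,
        "prefill_throughput": 8325.2,
        "avg_decode_latency": 0.0151,
        "avg_decode_throughput": 66.2,
        "total_latency": 0.183,
        "total_throughput": 5617.5
    }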