sglang / Commits / 53985645

Commit 53985645 (Unverified), authored Aug 03, 2024 by min-xu-et, committed by GitHub on Aug 03, 2024

latency test enhancement - part 1 (#909)

parent 70cc0749

Changes: 2 changed files, with 53 additions and 15 deletions (+53, -15)

  python/pyproject.toml            +1   -1
  python/sglang/bench_latency.py   +52  -14
python/pyproject.toml

...
@@ -21,7 +21,7 @@ dependencies = [
 [project.optional-dependencies]
 srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "packaging", "pillow",
        "psutil", "pydantic", "torch", "uvicorn", "uvloop", "zmq",
-       "vllm==0.5.3.post1", "outlines>=0.0.44", "python-multipart"]
+       "vllm==0.5.3.post1", "outlines>=0.0.44", "python-multipart", "jsonlines"]
 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
 litellm = ["litellm>=1.0.0"]
...
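The only change here is the new jsonlines runtime dependency; bench_latency.py (below) uses it to append benchmark records in JSON Lines format, one JSON object per line. A minimal sketch of that append-style usage, with a hypothetical file name and record:

import jsonlines

# Hypothetical record; the real keys come from measurement_results in bench_latency.py below.
record = {"batch_size": 1, "output_len": 4}

# Mode "a" appends, so repeated benchmark runs accumulate one JSON object per line.
with jsonlines.open("results.jsonl", "a") as writer:
    writer.write_all([record])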
python/sglang/bench_latency.py

 """
 Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.
 
-# Usage (latency test):
+# Usage (latency test) with dummy weights:
 python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
 
 # Usage (correctness test):
 python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
 
-### Reference output:
+### Reference output (of the correctness test above, can be gpu dependent):
 prefill logits (first half) tensor([[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
         [-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
         [ -9.1875, -10.2500, 2.7109, ..., -4.3359, -4.0664, -4.1328]],
...
@@ -31,7 +31,9 @@ import dataclasses
 import logging
 import multiprocessing
 import time
+from typing import Tuple
 
+import jsonlines
 import numpy as np
 import torch
 import torch.distributed as dist
...
@@ -47,25 +49,34 @@ from sglang.srt.utils import suppress_other_loggers
 
 @dataclasses.dataclass
 class BenchArgs:
-    batch_size: int = 1
+    batch_size: Tuple[int] = (1,)
     input_len: int = 1024
     output_len: int = 4
+    result_filename: str = ""
     correctness_test: bool = False
     # This is only used for correctness test
     cut_len: int = 4
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
-        parser.add_argument("--batch-size", type=int, default=BenchArgs.batch_size)
+        parser.add_argument(
+            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
+        )
         parser.add_argument("--input-len", type=int, default=BenchArgs.input_len)
         parser.add_argument("--output-len", type=int, default=BenchArgs.output_len)
+        parser.add_argument(
+            "--result-filename", type=str, default=BenchArgs.result_filename
+        )
         parser.add_argument("--correctness-test", action="store_true")
         parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
-        attrs = [attr.name for attr in dataclasses.fields(cls)]
-        return cls(**{attr: getattr(args, attr) for attr in attrs})
+        # use the default value's type to case the args into correct types.
+        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
+        return cls(
+            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
+        )
 
 
 def load_model(server_args, tp_rank):
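Since --batch-size now takes nargs="+", argparse hands from_cli_args a Python list, and the new (attr.name, type(attr.default)) pairs cast each value back to the type declared by the dataclass default. A standalone sketch of that mechanism, using a trimmed-down stand-in for BenchArgs (class and field names here are illustrative, not the class from the diff):

import argparse
import dataclasses

@dataclasses.dataclass
class _MiniArgs:  # stand-in for BenchArgs, two fields only
    batch_size: tuple = (1,)
    input_len: int = 1024

parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, nargs="+", default=_MiniArgs.batch_size)
parser.add_argument("--input-len", type=int, default=_MiniArgs.input_len)
args = parser.parse_args(["--batch-size", "1", "8", "32"])

# argparse returns a list for nargs="+"; casting with the default's type
# turns it back into the tuple the dataclass declares.
attrs = [(f.name, type(f.default)) for f in dataclasses.fields(_MiniArgs)]
print(_MiniArgs(**{name: cast(getattr(args, name)) for name, cast in attrs}))
# -> _MiniArgs(batch_size=(1, 8, 32), input_len=1024)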
...
@@ -93,7 +104,7 @@ def load_model(server_args, tp_rank):
     return model_runner, tokenizer
 
 
-def prepare_inputs(bench_args, tokenizer):
+def prepare_inputs_for_correctness_test(bench_args, tokenizer):
     prompts = [
         "The capital of France is",
         "The capital of the United Kindom is",
...
@@ -119,7 +130,9 @@ def prepare_inputs(bench_args, tokenizer):
     return input_ids, reqs
 
 
-def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
+def prepare_extend_inputs_for_correctness_test(
+    bench_args, input_ids, reqs, model_runner
+):
     for i in range(len(reqs)):
         req = reqs[i]
         req.input_ids += input_ids[i][bench_args.cut_len :]
...
@@ -129,8 +142,8 @@ def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
     return reqs
 
 
-def prepare_synthetic_inputs(bench_args, tokenizer):
-    input_ids = np.ones((bench_args.batch_size, bench_args.input_len), dtype=np.int32)
+def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
+    input_ids = np.ones((batch_size, input_len), dtype=np.int32)
     sampling_params = SamplingParams(
         temperature=0,
         max_new_tokens=BenchArgs.output_len,
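The helper now takes plain (batch_size, input_len) instead of the whole bench_args, so it can be invoked once per batch size when the batch_size tuple is eventually swept. A tiny standalone illustration of the dummy prompt array it builds (the values below are just an example invocation, not defaults from the diff):

import numpy as np

# Same shape as the np.ones call in the diff: one row of token id 1 per
# request in the batch, input_len tokens long.
batch_size, input_len = 8, 1024
input_ids = np.ones((batch_size, input_len), dtype=np.int32)
print(input_ids.shape, input_ids.dtype)  # (8, 1024) int32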
...
@@ -179,7 +192,7 @@ def correctness_test(
     model_runner, tokenizer = load_model(server_args, tp_rank)
 
     # Prepare inputs
-    input_ids, reqs = prepare_inputs(bench_args, tokenizer)
+    input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
 
     if bench_args.cut_len > 0:
         # Prefill
...
@@ -187,7 +200,9 @@ def correctness_test(
         rank_print("prefill logits (first half)", next_token_logits)
 
     # Prepare extend inputs
-    reqs = prepare_extend_inputs(bench_args, input_ids, reqs, model_runner)
+    reqs = prepare_extend_inputs_for_correctness_test(
+        bench_args, input_ids, reqs, model_runner
+    )
 
     # Extend
     next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
...
@@ -218,8 +233,13 @@ def latency_test(
         f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
     )
 
+    # To make this PR easier to review, for now, only do the first element in batch_size tuple.
+    bench_args.batch_size = bench_args.batch_size[0]
+
     # Prepare inputs
-    reqs = prepare_synthetic_inputs(bench_args, tokenizer)
+    reqs = prepare_synthetic_inputs_for_latency_test(
+        bench_args.batch_size, bench_args.input_len
+    )
 
     def clear():
         model_runner.req_to_token_pool.clear()
...
@@ -227,6 +247,11 @@ def latency_test(
 
     @torch.inference_mode()
     def run_once(output_len):
+        measurement_results = {
+            "batch_size": bench_args.batch_size,
+            "output_len": output_len,
+        }
+
         # Prefill
         torch.cuda.synchronize()
         tot_latency = 0
...
@@ -239,6 +264,8 @@ def latency_test(
         rank_print(
             f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
         )
+        measurement_results["prefill_latency"] = prefill_latency
+        measurement_results["prefill_throughput"] = throughput
 
         # Decode
         for i in range(output_len):
...
@@ -258,6 +285,8 @@ def latency_test(
         rank_print(
             f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
         )
+        measurement_results["avg_decode_latency"] = avg_decode_latency
+        measurement_results["avg_decode_throughput"] = avg_decode_throughput
 
         throughput = (
             (bench_args.input_len + bench_args.output_len)
...
@@ -267,13 +296,22 @@ def latency_test(
         rank_print(
             f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
         )
+        measurement_results["total_latency"] = tot_latency
+        measurement_results["total_throughput"] = throughput
+        return measurement_results
 
     # Warm up
     run_once(4)
     clear()
 
     # Run again
-    run_once(bench_args.output_len)
+    result_list = []
+    result_list.append(run_once(bench_args.output_len))
+
+    # Write results in jsonlines format.
+    if bench_args.result_filename:
+        with jsonlines.open(bench_args.result_filename, "a") as f:
+            f.write_all(result_list)
 
 
 def main(server_args, bench_args):
...
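With --result-filename set, each run now appends its measurement_results record as one JSON object per line. A hedged sketch of reading those records back, assuming a hypothetical results.jsonl path; the keys follow the dict built in this diff:

import jsonlines

# Iterate the appended records; each line is one run's measurement_results dict.
with jsonlines.open("results.jsonl") as reader:
    for record in reader:
        print(record["batch_size"], record["output_len"],
              record["prefill_throughput"], record["total_latency"])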