change / sglang · Commits · 4b4a67f8

Unverified commit 4b4a67f8, authored Jul 21, 2024 by zhyncs, committed by GitHub Jul 20, 2024

feat: support TRT LLM benchmark and multiple benchmarks (#670)

parent 0ac94c36

Showing 1 changed file with 156 additions and 10 deletions:

python/sglang/bench_serving.py (+156 -10)
@@ -19,6 +19,7 @@ import traceback
import warnings
from argparse import ArgumentParser as FlexibleArgumentParser
from dataclasses import dataclass, field
from datetime import datetime
from typing import AsyncGenerator, List, Optional, Tuple, Union

import aiohttp
@@ -59,6 +60,72 @@ def remove_prefix(text: str, prefix: str) -> str:
    return text[len(prefix) :] if text.startswith(prefix) else text


# trt llm not support ignore_eos
# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505
async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        assert request_func_input.best_of == 1
        payload = {
            "accumulate_tokens": True,
            "text_input": request_func_input.prompt,
            "temperature": 0.0,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data:")

                        data = json.loads(chunk)
                        output.generated_text += data["text_output"]
                        timestamp = time.perf_counter()
                        # First token
                        if ttft == 0.0:
                            ttft = time.perf_counter() - st
                            output.ttft = ttft

                        # Decoding phase
                        else:
                            output.itl.append(timestamp - most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    output.latency = most_recent_timestamp - st
                    output.success = True
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


# set ignore_eos True by default
async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
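Illustration (not part of this commit): the streaming loop above expects each Triton `generate_stream` chunk to be a `data:`-prefixed JSON object and accumulates its `text_output` field. A minimal standalone sketch of that parsing step, using a made-up chunk:

    import json

    def remove_prefix(text: str, prefix: str) -> str:
        return text[len(prefix) :] if text.startswith(prefix) else text

    # Hypothetical chunk as it might arrive from /v2/models/ensemble/generate_stream
    chunk_bytes = b'data: {"text_output": "Hello"}'
    chunk = remove_prefix(chunk_bytes.strip().decode("utf-8"), "data:")
    print(json.loads(chunk)["text_output"])  # Hello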
@@ -167,6 +234,7 @@ ASYNC_REQUEST_FUNCS = {
    "sglang": async_request_openai_completions,
    "vllm": async_request_openai_completions,
    "lmdeploy": async_request_openai_completions,
    "trt": async_request_trt_llm,
}
@@ -449,6 +517,7 @@ async def benchmark(
    input_requests: List[Tuple[str, int, int]],
    request_rate: float,
    disable_tqdm: bool,
    enable_multi: bool,
):
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -542,6 +611,37 @@ async def benchmark(
    print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
    print("=" * 50)

    if enable_multi:
        if (
            metrics.median_ttft_ms is not None
            and metrics.mean_itl_ms is not None
            and metrics.output_throughput is not None
        ):
            result = {
                "dataset_name": args.dataset_name,
                "request_rate": request_rate,
                "median_ttft": metrics.median_ttft_ms,
                "median_itl": metrics.mean_itl_ms,
                "output_token_throughput": metrics.output_throughput,
                "sharegpt_output_len": args.sharegpt_output_len,
                "random_input_len": args.random_input_len,
                "random_output_len": args.random_output_len,
            }
        else:
            print(f"Error running benchmark for request rate: {request_rate}")
            print("-" * 30)

        # Determine output file name
        if args.output_file:
            output_file_name = args.output_file
        else:
            now = datetime.now().strftime("%m%d%H")
            output_file_name = f"{args.backend}_{now}.jsonl"

        # Append results to a JSONL file
        with open(output_file_name, "a") as file:
            file.write(json.dumps(result) + "\n")

    result = {
        "duration": benchmark_duration,
        "completed": metrics.completed,
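Illustration (not part of this commit): with `--multi`, the block above appends one JSON object per request rate to the output file; note that, as written, the `median_itl` key is populated from `metrics.mean_itl_ms`. A small sketch of reading such a file back, assuming a hypothetical file name produced by the `f"{args.backend}_{now}.jsonl"` scheme:

    import json

    # "trt_072018.jsonl" is a hypothetical name following the backend_%m%d%H pattern above
    with open("trt_072018.jsonl") as f:
        records = [json.loads(line) for line in f if line.strip()]

    for r in records:
        print(r["request_rate"], r["median_ttft"], r["output_token_throughput"])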
@@ -572,6 +672,11 @@ async def benchmark(
    return result


def parse_request_rate_range(request_rate_range):
    start, stop, step = map(int, request_rate_range.split(","))
    return list(range(start, stop, step))


def fire(args: argparse.Namespace):
    random.seed(args.seed)
    np.random.seed(args.seed)
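Illustration (not part of this commit): `parse_request_rate_range` expands a `start,stop,step` string with Python range semantics, so the default "2,34,2" sweeps request rates 2 through 32:

    >>> parse_request_rate_range("2,34,2")
    [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32]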
@@ -581,6 +686,7 @@ def fire(args: argparse.Namespace):
        "sglang": 30000,
        "lmdeploy": 23333,
        "vllm": 8000,
        "trt": 8000,
    }.get(args.backend, 30000)

    api_url = (
@@ -594,6 +700,16 @@ def fire(args: argparse.Namespace):
        else f"http://{args.host}:{args.port}/v1/models"
    )

    if args.backend == "trt":
        api_url = (
            f"{args.base_url}/v2/models/ensemble/generate_stream"
            if args.base_url
            else f"http://{args.host}:{args.port}/v2/models/ensemble/generate_stream"
        )
        if args.model is None:
            print("Please provide a model using `--model` when using `trt` backend.")
            sys.exit(1)

    if args.model is None:
        try:
            response = requests.get(model_url)
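Illustration (not part of this commit): for the `trt` backend the benchmark posts to Triton's ensemble streaming endpoint. With example values (not defaults taken from this file) and no `--base-url`:

    host, port = "127.0.0.1", 8000  # example values for illustration
    api_url = f"http://{host}:{port}/v2/models/ensemble/generate_stream"
    # -> http://127.0.0.1:8000/v2/models/ensemble/generate_stream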
@@ -637,17 +753,35 @@ def fire(args: argparse.Namespace):
    else:
        raise ValueError(f"Unknown dataset: {args.dataset_name}")

-    asyncio.run(
-        benchmark(
-            backend=backend,
-            api_url=api_url,
-            model_id=model_id,
-            tokenizer=tokenizer,
-            input_requests=input_requests,
-            request_rate=args.request_rate,
-            disable_tqdm=args.disable_tqdm,
-        )
-    )
+    if args.multi:
+        request_rates = parse_request_rate_range(args.request_rate_range)
+        for rate in request_rates:
+            asyncio.run(
+                benchmark(
+                    backend=backend,
+                    api_url=api_url,
+                    model_id=model_id,
+                    tokenizer=tokenizer,
+                    input_requests=input_requests,
+                    request_rate=rate,
+                    disable_tqdm=args.disable_tqdm,
+                    enable_multi=args.multi,
+                )
+            )
+    else:
+        asyncio.run(
+            benchmark(
+                backend=backend,
+                api_url=api_url,
+                model_id=model_id,
+                tokenizer=tokenizer,
+                input_requests=input_requests,
+                request_rate=args.request_rate,
+                disable_tqdm=args.disable_tqdm,
+                enable_multi=args.multi,
+            )
+        )

# to avoid relying on SGLang's components
@@ -751,6 +885,18 @@ if __name__ == "__main__":
        action="store_true",
        help="Specify to disable tqdm progress bar.",
    )
    parser.add_argument(
        "--multi",
        action="store_true",
        help="Use request rate range rather than single value.",
    )
    parser.add_argument(
        "--request-rate-range",
        type=str,
        default="2,34,2",
        help="Range of request rates in the format start,stop,step. Default is 2,34,2",
    )
    parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
    set_ulimit()
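Illustration (not part of this commit): a possible invocation of the new multi-rate mode, using only options referenced in this diff (other options such as dataset selection are omitted; `<model-name>` is a placeholder):

    python3 -m sglang.bench_serving --backend trt --model <model-name> \
        --multi --request-rate-range 2,34,2 --output-file trt_sweep.jsonl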