Unverified commit c126a6cc, authored Jul 20, 2024 by zhyncs, committed by GitHub on Jul 19, 2024.

feat: add benchmark serving (#657)

Parent: ac971ff6

Showing 3 changed files with 660 additions and 0 deletions:

  python/sglang/bench.py                 +627  -0
  python/sglang/srt/openai_protocol.py    +17  -0
  python/sglang/srt/server.py             +16  -0
python/sglang/bench.py  (new file, 0 → 100644)
# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/backend_request_func.py
# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py

import argparse
import asyncio
import json
import os
import random
import resource
import sys
import time
import traceback
import warnings
from argparse import ArgumentParser as FlexibleArgumentParser
from dataclasses import dataclass, field
from typing import AsyncGenerator, List, Optional, Tuple, Union

import aiohttp
import numpy as np
import requests
from tqdm.asyncio import tqdm
from transformers import (
    AutoTokenizer,
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
)

AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)


@dataclass
class RequestFuncInput:
    prompt: str
    api_url: str
    prompt_len: int
    output_len: int
    model: str


@dataclass
class RequestFuncOutput:
    generated_text: str = ""
    success: bool = False
    latency: float = 0.0
    ttft: float = 0.0  # Time to first token
    itl: List[float] = field(default_factory=list)  # List of inter-token latencies
    prompt_len: int = 0
    error: str = ""


def remove_prefix(text: str, prefix: str) -> str:
    return text[len(prefix):] if text.startswith(prefix) else text


# set ignore_eos True by default
async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "completions"
    ), "OpenAI Completions API URL must end with 'completions'."

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "model": request_func_input.model,
            "prompt": request_func_input.prompt,
            "temperature": 0.0,
            "best_of": 1,
            "max_tokens": request_func_input.output_len,
            "stream": True,
            "ignore_eos": True,
        }
        headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(
                url=api_url, json=payload, headers=headers
            ) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
                        if chunk == "[DONE]":
                            latency = time.perf_counter() - st
                        else:
                            data = json.loads(chunk)

                            # NOTE: Some completion API might have a last
                            # usage summary response without a token so we
                            # want to check a token was generated
                            if data["choices"][0]["text"]:
                                timestamp = time.perf_counter()
                                # First token
                                if ttft == 0.0:
                                    ttft = time.perf_counter() - st
                                    output.ttft = ttft

                                # Decoding phase
                                output.itl.append(timestamp - most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                generated_text += data["choices"][0]["text"]

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


def get_model(pretrained_model_name_or_path: str) -> str:
    if os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true":
        import huggingface_hub.constants
        from modelscope import snapshot_download

        model_path = snapshot_download(
            model_id=pretrained_model_name_or_path,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"],
        )

        return model_path
    return pretrained_model_name_or_path


def get_tokenizer(
    pretrained_model_name_or_path: str,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    if pretrained_model_name_or_path is not None and not os.path.exists(
        pretrained_model_name_or_path
    ):
        pretrained_model_name_or_path = get_model(pretrained_model_name_or_path)
    return AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path, trust_remote_code=True
    )


ASYNC_REQUEST_FUNCS = {
    "sglang": async_request_openai_completions,
    "vllm": async_request_openai_completions,
    "lmdeploy": async_request_openai_completions,
}


@dataclass
class BenchmarkMetrics:
    completed: int
    total_input: int
    total_output: int
    request_throughput: float
    input_throughput: float
    output_throughput: float
    mean_ttft_ms: float
    median_ttft_ms: float
    std_ttft_ms: float
    p99_ttft_ms: float
    mean_tpot_ms: float
    median_tpot_ms: float
    std_tpot_ms: float
    p99_tpot_ms: float
    mean_itl_ms: float
    median_itl_ms: float
    std_itl_ms: float
    p99_itl_ms: float


def sample_sharegpt_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, int, int]]:
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    default_dataset_path = "ShareGPT_V3_unfiltered_cleaned_split.json"
    url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"

    if not os.path.isfile(dataset_path) and not os.path.isfile(default_dataset_path):
        print(f"Downloading dataset from {url}")
        try:
            response = requests.get(url, stream=True)
            response.raise_for_status()

            total_size = int(response.headers.get("content-length", 0))
            block_size = 8192

            with open(default_dataset_path, "wb") as f, tqdm(
                desc="Downloading",
                total=total_size,
                unit="iB",
                unit_scale=True,
                unit_divisor=1024,
            ) as progress_bar:
                for data in response.iter_content(block_size):
                    size = f.write(data)
                    progress_bar.update(size)

            print(f"Dataset downloaded and saved to {default_dataset_path}")
            dataset_path = default_dataset_path
        except requests.RequestException as e:
            raise Exception(f"Failed to download dataset: {e}")
    else:
        dataset_path = (
            dataset_path if os.path.isfile(dataset_path) else default_dataset_path
        )

    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [
        (data["conversations"][0]["value"], data["conversations"][1]["value"])
        for data in dataset
    ]

    # Shuffle the dataset.
    random.shuffle(dataset)

    # Filter out sequences that are too long or too short
    filtered_dataset: List[Tuple[str, int, int]] = []
    for i in range(len(dataset)):
        if len(filtered_dataset) == num_requests:
            break

        # Tokenize the prompts and completions.
        prompt = dataset[i][0]
        prompt_token_ids = tokenizer(prompt).input_ids
        completion = dataset[i][1]
        completion_token_ids = tokenizer(completion).input_ids
        prompt_len = len(prompt_token_ids)
        output_len = (
            len(completion_token_ids) if fixed_output_len is None else fixed_output_len
        )
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    return filtered_dataset


async def get_request(
    input_requests: List[Tuple[str, int, int]],
    request_rate: float,
) -> AsyncGenerator[Tuple[str, int, int], None]:
    input_requests = iter(input_requests)
    for request in input_requests:
        yield request

        if request_rate == float("inf"):
            # If the request rate is infinity, then we don't need to wait.
            continue

        # Sample the request interval from the exponential distribution.
        interval = np.random.exponential(1.0 / request_rate)
        # The next request will be sent after the interval.
        await asyncio.sleep(interval)


def calculate_metrics(
    input_requests: List[Tuple[str, int, int]],
    outputs: List[RequestFuncOutput],
    dur_s: float,
    tokenizer: PreTrainedTokenizerBase,
) -> Tuple[BenchmarkMetrics, List[int]]:
    actual_output_lens: List[int] = []
    total_input = 0
    completed = 0
    itls: List[float] = []
    tpots: List[float] = []
    ttfts: List[float] = []
    for i in range(len(outputs)):
        if outputs[i].success:
            # We use the tokenizer to count the number of output tokens for all
            # serving backends instead of looking at len(outputs[i].itl) since
            # multiple output tokens may be bundled together
            # Note : this may inflate the output token count slightly
            output_len = len(
                tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids
            )
            actual_output_lens.append(output_len)
            total_input += input_requests[i][1]
            if output_len > 1:
                tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
            itls += outputs[i].itl
            ttfts.append(outputs[i].ttft)
            completed += 1
        else:
            actual_output_lens.append(0)

    if completed == 0:
        warnings.warn(
            "All requests failed. This is likely due to a misconfiguration "
            "on the benchmark arguments.",
            stacklevel=2,
        )
    metrics = BenchmarkMetrics(
        completed=completed,
        total_input=total_input,
        total_output=sum(actual_output_lens),
        request_throughput=completed / dur_s,
        input_throughput=total_input / dur_s,
        output_throughput=sum(actual_output_lens) / dur_s,
        mean_ttft_ms=np.mean(ttfts or 0)
        * 1000,  # ttfts is empty if streaming is not supported by backend
        median_ttft_ms=np.median(ttfts or 0) * 1000,
        std_ttft_ms=np.std(ttfts or 0) * 1000,
        p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
        mean_tpot_ms=np.mean(tpots or 0) * 1000,
        median_tpot_ms=np.median(tpots or 0) * 1000,
        std_tpot_ms=np.std(tpots or 0) * 1000,
        p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
        mean_itl_ms=np.mean(itls or 0) * 1000,
        median_itl_ms=np.median(itls or 0) * 1000,
        std_itl_ms=np.std(itls or 0) * 1000,
        p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
    )

    return metrics, actual_output_lens


async def benchmark(
    backend: str,
    api_url: str,
    model_id: str,
    tokenizer: PreTrainedTokenizerBase,
    input_requests: List[Tuple[str, int, int]],
    request_rate: float,
    disable_tqdm: bool,
):
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]
    else:
        raise ValueError(f"Unknown backend: {backend}")

    print("Starting initial single prompt test run...")
    test_prompt, test_prompt_len, test_output_len = input_requests[0]
    test_input = RequestFuncInput(
        model=model_id,
        prompt=test_prompt,
        api_url=api_url,
        prompt_len=test_prompt_len,
        output_len=test_output_len,
    )
    test_output = await request_func(request_func_input=test_input)
    if not test_output.success:
        raise ValueError(
            "Initial test run failed - Please make sure benchmark arguments "
            f"are correctly specified. Error: {test_output.error}"
        )
    else:
        print("Initial test run completed. Starting main benchmark run...")

    pbar = None if disable_tqdm else tqdm(total=len(input_requests))

    benchmark_start_time = time.perf_counter()
    tasks: List[asyncio.Task] = []
    async for request in get_request(input_requests, request_rate):
        prompt, prompt_len, output_len = request
        request_func_input = RequestFuncInput(
            model=model_id,
            prompt=prompt,
            api_url=api_url,
            prompt_len=prompt_len,
            output_len=output_len,
        )
        tasks.append(
            asyncio.create_task(
                request_func(request_func_input=request_func_input, pbar=pbar)
            )
        )
    outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)

    if pbar is not None:
        pbar.close()

    benchmark_duration = time.perf_counter() - benchmark_start_time

    metrics, actual_output_lens = calculate_metrics(
        input_requests=input_requests,
        outputs=outputs,
        dur_s=benchmark_duration,
        tokenizer=tokenizer,
    )

    print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
    print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
    print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
    print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
    print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
    print(
        "{:<40} {:<10.2f}".format(
            "Request throughput (req/s):", metrics.request_throughput
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Input token throughput (tok/s):", metrics.input_throughput
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Output token throughput (tok/s):", metrics.output_throughput
        )
    )
    print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-"))
    print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
    print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
    print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
    print(
        "{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-")
    )
    print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
    print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
    print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
    print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-"))
    print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
    print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
    print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
    print("=" * 50)

    result = {
        "duration": benchmark_duration,
        "completed": metrics.completed,
        "total_input_tokens": metrics.total_input,
        "total_output_tokens": metrics.total_output,
        "request_throughput": metrics.request_throughput,
        "input_throughput": metrics.input_throughput,
        "output_throughput": metrics.output_throughput,
        "mean_ttft_ms": metrics.mean_ttft_ms,
        "median_ttft_ms": metrics.median_ttft_ms,
        "std_ttft_ms": metrics.std_ttft_ms,
        "p99_ttft_ms": metrics.p99_ttft_ms,
        "mean_tpot_ms": metrics.mean_tpot_ms,
        "median_tpot_ms": metrics.median_tpot_ms,
        "std_tpot_ms": metrics.std_tpot_ms,
        "p99_tpot_ms": metrics.p99_tpot_ms,
        "mean_itl_ms": metrics.mean_itl_ms,
        "median_itl_ms": metrics.median_itl_ms,
        "std_itl_ms": metrics.std_itl_ms,
        "p99_itl_ms": metrics.p99_itl_ms,
        "input_lens": [output.prompt_len for output in outputs],
        "output_lens": actual_output_lens,
        "ttfts": [output.ttft for output in outputs],
        "itls": [output.itl for output in outputs],
        "generated_texts": [output.generated_text for output in outputs],
        "errors": [output.error for output in outputs],
    }
    return result


def fire(args: argparse.Namespace):
    random.seed(args.seed)
    np.random.seed(args.seed)

    if args.port is None:
        args.port = {
            "sglang": 30000,
            "lmdeploy": 23333,
            "vllm": 8000,
        }.get(args.backend, 30000)

    api_url = (
        f"{args.base_url}/v1/completions"
        if args.base_url
        else f"http://{args.host}:{args.port}/v1/completions"
    )
    model_url = (
        f"{args.base_url}/v1/models"
        if args.base_url
        else f"http://{args.host}:{args.port}/v1/models"
    )

    if args.model is None:
        try:
            response = requests.get(model_url)
            model_list = response.json().get("data", [])
            args.model = model_list[0]["id"] if model_list else None
        except Exception as e:
            print(f"Failed to fetch model from {model_url}. Error: {e}")
            print("Please specify the correct host and port using `--host` and `--port`.")
            sys.exit(1)

    if args.model is None:
        print("No model specified or found. Please provide a model using `--model`.")
        sys.exit(1)

    print(f"{args}\n")

    backend = args.backend
    model_id = args.model
    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model

    tokenizer = get_tokenizer(tokenizer_id)

    assert args.dataset is not None
    input_requests = sample_sharegpt_requests(
        dataset_path=args.dataset,
        num_requests=args.num_prompts,
        tokenizer=tokenizer,
        fixed_output_len=args.sharegpt_output_len,
    )

    asyncio.run(
        benchmark(
            backend=backend,
            api_url=api_url,
            model_id=model_id,
            tokenizer=tokenizer,
            input_requests=input_requests,
            request_rate=args.request_rate,
            disable_tqdm=args.disable_tqdm,
        )
    )


# to avoid relying on SGLang's components
def set_ulimit(target_soft_limit=65535):
    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    if current_soft < target_soft_limit:
        try:
            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
        except ValueError as e:
            print(f"Fail to set RLIMIT_NOFILE: {e}")


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark the online serving throughput."
    )
    parser.add_argument(
        "--backend",
        type=str,
        required=True,
        choices=list(ASYNC_REQUEST_FUNCS.keys()),
        help="Must specify a backend, depending on the LLM Inference Engine.",
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default=None,
        help="Server or API base url if not using http host and port.",
    )
    parser.add_argument(
        "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
    )
    parser.add_argument(
        "--port",
        type=int,
        help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
    )
    parser.add_argument(
        "--dataset", type=str, default="sharegpt", help="Path to the ShareGPT dataset"
    )
    parser.add_argument(
        "--model",
        type=str,
        help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        help="Name or path of the tokenizer. If not set, using the model conf.",
    )
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=1000,
        help="Number of prompts to process. Default is 1000.",
    )
    parser.add_argument(
        "--sharegpt-output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
    )
    parser.add_argument(
        "--request-rate",
        type=float,
        default=128.0,
        help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
        "Otherwise, we use Poisson process to synthesize the request arrival times. Default is 128.0.",
    )
    parser.add_argument("--seed", type=int, default=0, help="Default is 0.")
    parser.add_argument(
        "--disable-tqdm",
        action="store_true",
        help="Specify to disable tqdm progress bar.",
    )

    set_ulimit()

    args = parser.parse_args()
    fire(args)
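
A minimal usage sketch for the new benchmark, driving the functions defined in bench.py above directly instead of through its CLI. The module path sglang.bench, the placeholder model name, and a server already listening on the default sglang port 30000 are assumptions; the ShareGPT file is downloaded automatically by sample_sharegpt_requests if it is missing.

import asyncio

from sglang.bench import benchmark, get_tokenizer, sample_sharegpt_requests  # assumed module path

model_id = "meta-llama/Llama-2-7b-chat-hf"  # placeholder model name
tokenizer = get_tokenizer(model_id)
input_requests = sample_sharegpt_requests(
    dataset_path="ShareGPT_V3_unfiltered_cleaned_split.json",  # fetched on demand if absent
    num_requests=100,
    tokenizer=tokenizer,
    fixed_output_len=None,
)
asyncio.run(
    benchmark(
        backend="sglang",
        api_url="http://0.0.0.0:30000/v1/completions",  # default sglang port used in fire()
        model_id=model_id,
        tokenizer=tokenizer,
        input_requests=input_requests,
        request_rate=128.0,
        disable_tqdm=False,
    )
)

The same run through the script's own entry point would be roughly `python python/sglang/bench.py --backend sglang --model <model> --num-prompts 100` against a running server.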
python/sglang/srt/openai_protocol.py
@@ -7,6 +7,23 @@ from pydantic import BaseModel, Field
 from typing_extensions import Literal


+class ModelCard(BaseModel):
+    """Model cards."""
+
+    id: str
+    object: str = "model"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    owned_by: str = "sglang"
+    root: Optional[str] = None
+
+
+class ModelList(BaseModel):
+    """Model list consists of model cards."""
+
+    object: str = "list"
+    data: List[ModelCard] = []
+
+
 class ErrorResponse(BaseModel):
     object: str = "error"
     message: str
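
A quick sketch of the JSON these two models serialize to, which is the payload shape the new /v1/models route in server.py returns. The model name and timestamp are placeholders, and the pydantic .json() helper is assumed to be available in the installed pydantic version.

from sglang.srt.openai_protocol import ModelCard, ModelList

card = ModelCard(id="my-model", root="my-model")  # placeholder model name
print(ModelList(data=[card]).json())
# e.g. {"object": "list", "data": [{"id": "my-model", "object": "model",
#       "created": 1721400000, "owned_by": "sglang", "root": "my-model"}]}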
python/sglang/srt/server.py
@@ -44,6 +44,7 @@ from sglang.srt.openai_api_adapter import (
     v1_chat_completions,
     v1_completions,
 )
+from sglang.srt.openai_protocol import ModelCard, ModelList
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     API_KEY_HEADER_NAME,

@@ -73,6 +74,21 @@ async def health() -> Response:
     return Response(status_code=200)


+def get_model_list():
+    """Available models."""
+    model_names = [tokenizer_manager.model_path]
+    return model_names
+
+
+@app.get("/v1/models")
+def available_models():
+    """Show available models."""
+    model_cards = []
+    for model_name in get_model_list():
+        model_cards.append(ModelCard(id=model_name, root=model_name))
+    return ModelList(data=model_cards)
+
+
 @app.get("/get_model_info")
 async def get_model_info():
     result = {
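
A short sketch of exercising the new endpoint, assuming an sglang server is already running on the default port 30000 (the same default bench.py falls back to). The response body is a serialized ModelList, which is exactly what bench.py reads when --model is not given.

import requests

resp = requests.get("http://0.0.0.0:30000/v1/models")
print(resp.json())  # e.g. {"object": "list", "data": [{"id": "<model path>", ...}]}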