change / sglang · Commits

Commit 8f2c522a (unverified), authored Jan 16, 2025 by Lianmin Zheng, committed by GitHub on Jan 16, 2025. Parent: 75964177.

Improve benchmark scripts and error message printing (#2922)

Showing 8 changed files with 110 additions and 55 deletions (+110 / -55):
python/sglang/bench_offline_throughput.py         +22  -15
python/sglang/bench_serving.py                    +37  -28
python/sglang/srt/managers/io_struct.py            +6   -0
python/sglang/srt/managers/scheduler.py            +2   -1
python/sglang/srt/managers/tokenizer_manager.py   +34   -7
python/sglang/srt/server.py                        +6   -2
python/sglang/test/test_utils.py                   +1   -0
test/srt/test_moe_ep.py                            +2   -2
python/sglang/bench_offline_throughput.py  (view file @ 8f2c522a)

@@ -39,14 +39,15 @@ class BenchArgs:
     dataset_path: str = ""
     num_prompts: int = 1000
     sharegpt_output_len: Optional[int] = None
+    sharegpt_context_len: Optional[int] = None
     random_input_len: int = 1024
     random_output_len: int = 1024
     random_range_ratio: float = 0.0
-    gen_num_groups: int = 64
-    gen_prompts_per_group: int = 16
-    gen_system_prompt_len: int = 2048
-    gen_question_len: int = 128
-    gen_output_len: int = 256
+    gsp_num_groups: int = 64
+    gsp_prompts_per_group: int = 16
+    gsp_system_prompt_len: int = 2048
+    gsp_question_len: int = 128
+    gsp_output_len: int = 256
     disable_ignore_eos: bool = False
     extra_request_body: Optional[str] = None
     seed: int = 1
...
@@ -82,6 +83,12 @@ class BenchArgs:
             default=BenchArgs.sharegpt_output_len,
             help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
         )
+        parser.add_argument(
+            "--sharegpt-context-len",
+            type=int,
+            default=BenchArgs.sharegpt_context_len,
+            help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
+        )
         parser.add_argument(
             "--random-input-len",
             type=int,
...
@@ -102,35 +109,35 @@ class BenchArgs:
             "used only for random dataset.",
         )
         parser.add_argument(
-            "--gen-num-groups",
+            "--gsp-num-groups",
             type=int,
-            default=BenchArgs.gen_num_groups,
+            default=BenchArgs.gsp_num_groups,
             help="Number of groups with shared prefix, used"
             "only for generate-shared-prefix",
         )
         parser.add_argument(
-            "--gen-prompts-per-group",
+            "--gsp-prompts-per-group",
             type=int,
-            default=BenchArgs.gen_prompts_per_group,
+            default=BenchArgs.gsp_prompts_per_group,
             help="Number of prompts per group of shared prefix, used"
             "only for generate-shared-prefix",
         )
         parser.add_argument(
-            "--gen-system-prompt-len",
+            "--gsp-system-prompt-len",
             type=int,
-            default=BenchArgs.gen_system_prompt_len,
+            default=BenchArgs.gsp_system_prompt_len,
             help="System prompt length, used"
             "only for generate-shared-prefix",
         )
         parser.add_argument(
-            "--gen-question-len",
+            "--gsp-question-len",
             type=int,
-            default=BenchArgs.gen_question_len,
+            default=BenchArgs.gsp_question_len,
             help="Question length, used"
             "only for generate-shared-prefix",
         )
         parser.add_argument(
-            "--gen-output-len",
+            "--gsp-output-len",
             type=int,
-            default=BenchArgs.gen_output_len,
+            default=BenchArgs.gsp_output_len,
             help="Target length in tokens for outputs in generated-shared-prefix dataset",
         )
         parser.add_argument(
...
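Note on the rename above: argparse derives the attribute name on the parsed namespace from the flag, so changing --gen-* to --gsp-* also renames every args.gen_* read site. A minimal standalone sketch (not part of the commit) of that coupling:

    import argparse

    parser = argparse.ArgumentParser()
    # After this commit the flag is --gsp-num-groups (dest: gsp_num_groups);
    # before it was --gen-num-groups (dest: gen_num_groups).
    parser.add_argument("--gsp-num-groups", type=int, default=64)
    args = parser.parse_args(["--gsp-num-groups", "32"])
    assert args.gsp_num_groups == 32  # read back via the new attribute name

Existing scripts that still pass the old --gen-* flags will fail with "unrecognized arguments" and need updating.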
python/sglang/bench_serving.py  (view file @ 8f2c522a)

@@ -452,6 +452,7 @@ def get_dataset(args, tokenizer):
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
             fixed_output_len=args.sharegpt_output_len,
+            context_len=args.sharegpt_context_len,
         )
     elif args.dataset_name == "random":
         input_requests = sample_random_requests(
...
@@ -464,11 +465,11 @@ def get_dataset(args, tokenizer):
         )
     elif args.dataset_name == "generated-shared-prefix":
         input_requests = sample_generated_shared_prefix_requests(
-            num_groups=args.gen_num_groups,
-            prompts_per_group=args.gen_prompts_per_group,
-            system_prompt_len=args.gen_system_prompt_len,
-            question_len=args.gen_question_len,
-            output_len=args.gen_output_len,
+            num_groups=args.gsp_num_groups,
+            prompts_per_group=args.gsp_prompts_per_group,
+            system_prompt_len=args.gsp_system_prompt_len,
+            question_len=args.gsp_question_len,
+            output_len=args.gsp_output_len,
             tokenizer=tokenizer,
         )
     else:
...
@@ -560,6 +561,7 @@ def sample_sharegpt_requests(
     num_requests: int,
     tokenizer: PreTrainedTokenizerBase,
     fixed_output_len: Optional[int] = None,
+    context_len: Optional[int] = None,
 ) -> List[Tuple[str, int, int]]:
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")
...
@@ -597,14 +599,15 @@ def sample_sharegpt_requests(
         output_len = (
             len(completion_token_ids) if fixed_output_len is None else fixed_output_len
         )
-        if prompt_len < 4 or output_len < 4:
+        if prompt_len < 1 or output_len < 1:
             # Prune too short sequences.
             continue
-        if prompt_len > 1024 or (
-            prompt_len + output_len > 2048 and fixed_output_len is None
-        ):
+        if context_len and prompt_len + output_len > context_len:
             # Prune too long sequences.
             continue
         filtered_dataset.append((prompt, prompt_len, output_len))

     print(f"#Input tokens: {np.sum([x[1] for x in filtered_dataset])}")
...
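The pruning hunk above replaces the hard-coded 1024/2048 caps with the user-supplied context length. A self-contained sketch (illustration only, not the commit's code) of the new filter semantics:

    from typing import List, Optional, Tuple

    def prune(requests: List[Tuple[str, int, int]],
              context_len: Optional[int]) -> List[Tuple[str, int, int]]:
        kept = []
        for prompt, prompt_len, output_len in requests:
            if prompt_len < 1 or output_len < 1:
                continue  # prune empty sequences
            if context_len and prompt_len + output_len > context_len:
                continue  # prune requests that would overflow the model context
            kept.append((prompt, prompt_len, output_len))
        return kept

    # With context_len=None nothing is dropped for length, matching the new default.
    assert prune([("p", 10, 2000)], None) == [("p", 10, 2000)]
    assert prune([("p", 10, 2000)], 1024) == []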
@@ -706,8 +709,8 @@ def get_gen_prefix_cache_path(args, tokenizer):
     # Create a unique cache filename based on the generation parameters
     cache_key = (
-        f"gen_prefix_{args.gen_num_groups}_{args.gen_prompts_per_group}_"
-        f"{args.gen_system_prompt_len}_{args.gen_question_len}_{args.gen_output_len}_"
+        f"gen_shared_prefix_{args.gsp_num_groups}_{args.gsp_prompts_per_group}_"
+        f"{args.gsp_system_prompt_len}_{args.gsp_question_len}_{args.gsp_output_len}_"
         f"{tokenizer.__class__.__name__}.pkl"
     )
     return cache_dir / cache_key
...
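For concreteness, the renamed key format above produces filenames like the following (all values invented for illustration):

    class _Args:  # stand-in for the parsed namespace, illustration only
        gsp_num_groups = 64
        gsp_prompts_per_group = 16
        gsp_system_prompt_len = 2048
        gsp_question_len = 128
        gsp_output_len = 256

    args = _Args()
    tok_name = "LlamaTokenizerFast"  # tokenizer.__class__.__name__ in the real code
    cache_key = (
        f"gen_shared_prefix_{args.gsp_num_groups}_{args.gsp_prompts_per_group}_"
        f"{args.gsp_system_prompt_len}_{args.gsp_question_len}_{args.gsp_output_len}_"
        f"{tok_name}.pkl"
    )
    print(cache_key)  # gen_shared_prefix_64_16_2048_128_256_LlamaTokenizerFast.pkl

Because the prefix changed from gen_prefix_ to gen_shared_prefix_, previously cached datasets will not be reused after upgrading.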
@@ -1374,6 +1377,12 @@ if __name__ == "__main__":
         default=None,
         help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
     )
+    parser.add_argument(
+        "--sharegpt-context-len",
+        type=int,
+        default=None,
+        help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
+    )
     parser.add_argument(
         "--random-input-len",
         type=int,
...
@@ -1453,49 +1462,49 @@ if __name__ == "__main__":
         help="Append given JSON object to the request payload. You can use this to specify"
         "additional generate params like sampling params.",
     )
-    parser.add_argument(
-        "--profile",
-        action="store_true",
-        help="Use Torch Profiler. The endpoint must be launched with "
-        "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
-    )
-    parser.add_argument(
-        "--lora-name",
-        type=str,
-        default=None,
-        help="The name of LoRA adapter",
-    )

     group = parser.add_argument_group("generated-shared-prefix dataset arguments")
     group.add_argument(
-        "--gen-num-groups",
+        "--gsp-num-groups",
         type=int,
         default=64,
         help="Number of system prompt groups for generated-shared-prefix dataset",
     )
     group.add_argument(
-        "--gen-prompts-per-group",
+        "--gsp-prompts-per-group",
         type=int,
         default=16,
         help="Number of prompts per system prompt group for generated-shared-prefix dataset",
     )
     group.add_argument(
-        "--gen-system-prompt-len",
+        "--gsp-system-prompt-len",
         type=int,
         default=2048,
         help="Target length in tokens for system prompts in generated-shared-prefix dataset",
     )
     group.add_argument(
-        "--gen-question-len",
+        "--gsp-question-len",
         type=int,
         default=128,
         help="Target length in tokens for questions in generated-shared-prefix dataset",
     )
     group.add_argument(
-        "--gen-output-len",
+        "--gsp-output-len",
         type=int,
         default=256,
         help="Target length in tokens for outputs in generated-shared-prefix dataset",
     )
+    parser.add_argument(
+        "--profile",
+        action="store_true",
+        help="Use Torch Profiler. The endpoint must be launched with "
+        "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
+    )
+    parser.add_argument(
+        "--lora-name",
+        type=str,
+        default=None,
+        help="The name of LoRA adapter",
+    )

     args = parser.parse_args()
     run_benchmark(args)
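The relocation of --profile and --lora-name after the dataset group is cosmetic: argparse argument groups only affect the layout of --help output, not parsing. A tiny sketch (illustration only) of that behavior:

    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group("generated-shared-prefix dataset arguments")
    group.add_argument("--gsp-num-groups", type=int, default=64)
    parser.add_argument("--profile", action="store_true")

    # Groups change only where flags appear in --help; parsing is unaffected.
    args = parser.parse_args([])
    assert args.gsp_num_groups == 64 and args.profile is False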
python/sglang/srt/managers/io_struct.py  (view file @ 8f2c522a)

@@ -59,6 +59,9 @@ class GenerateReqInput:
     return_text_in_logprobs: bool = False
     # Whether to stream output.
     stream: bool = False
+    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
+    log_metrics: bool = True

     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
     # LoRA related
...
@@ -196,6 +199,7 @@ class GenerateReqInput:
                 top_logprobs_num=self.top_logprobs_num[i],
                 return_text_in_logprobs=self.return_text_in_logprobs,
                 stream=self.stream,
+                log_metrics=self.log_metrics,
                 modalities=self.modalities[i] if self.modalities else None,
                 lora_path=self.lora_path[i] if self.lora_path is not None else None,
             )
...
@@ -243,6 +247,8 @@ class EmbeddingReqInput:
     sampling_params: Union[List[Dict], Dict] = None
     # Dummy input embeds for compatibility
     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
+    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
+    log_metrics: bool = True

     def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
...
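The new field defaults to True, so existing callers are unaffected; only callers that explicitly opt out (such as the health check in server.py below) skip metrics. A tiny sketch of the default/opt-out behavior:

    from dataclasses import dataclass

    @dataclass
    class _ReqSketch:  # illustration only, not the real GenerateReqInput
        log_metrics: bool = True

    assert _ReqSketch().log_metrics                       # normal requests: logged
    assert not _ReqSketch(log_metrics=False).log_metrics  # health probes: skipped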
python/sglang/srt/managers/scheduler.py  (view file @ 8f2c522a)

@@ -631,7 +631,8 @@ class Scheduler:
         if len(req.origin_input_ids) > self.max_req_input_len:
             logger.warning(
                 "Request length is longer than the KV cache pool size or "
-                "the max context length. Truncated!!!"
+                "the max context length. Truncated. "
+                f"{len(req.origin_input_ids)=}, {self.max_req_input_len=}."
             )
             req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
...
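The rewritten warning uses Python 3.8+ self-documenting f-string fields ({expr=}), which print both the expression and its value. A standalone example of the resulting message shape:

    origin_input_ids = list(range(5))
    max_req_input_len = 3
    print(f"{len(origin_input_ids)=}, {max_req_input_len=}.")
    # Output: len(origin_input_ids)=5, max_req_input_len=3.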
python/sglang/srt/managers/tokenizer_manager.py  (view file @ 8f2c522a)

@@ -79,6 +79,7 @@ from sglang.srt.utils import (
     get_zmq_socket,
     kill_process_tree,
 )
+from sglang.utils import get_exception_traceback

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
...
@@ -640,7 +641,9 @@ class TokenizerManager:
             self.to_create_loop = False
             loop = asyncio.get_event_loop()
-            self.asyncio_tasks.add(loop.create_task(self.handle_loop()))
+            self.asyncio_tasks.add(
+                loop.create_task(print_exception_wrapper(self.handle_loop))
+            )

             # We cannot add signal handler when the tokenizer manager is not in
             # the main thread due to the CPython limitation.
...
@@ -653,7 +656,9 @@ class TokenizerManager:
                 "not in the main thread. This disables graceful shutdown of the "
                 "tokenizer manager when SIGTERM is received."
             )
-            self.asyncio_tasks.add(loop.create_task(self.sigterm_watchdog()))
+            self.asyncio_tasks.add(
+                loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
+            )

     async def sigterm_watchdog(self):
         while not self.gracefully_exit:
...
@@ -738,9 +743,13 @@ class TokenizerManager:
                     state.finished = recv_obj.finished_reasons[i] is not None
                     state.event.set()

-                    if self.enable_metrics:
+                    if self.enable_metrics and state.obj.log_metrics:
                         self.collect_metrics(state, recv_obj, i)
-                    if self.dump_requests_folder and state.finished:
+                    if (
+                        self.dump_requests_folder
+                        and state.finished
+                        and state.obj.log_metrics
+                    ):
                         self.dump_requests(state, out_dict)
             elif isinstance(recv_obj, OpenSessionReqOutput):
                 self.session_futures[recv_obj.session_id].set_result(
...
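The gating in the last hunk above now requires both the server-level switch and the per-request flag; a compact sketch of the combined condition (names shortened, illustration only):

    def should_collect(enable_metrics: bool, log_metrics: bool) -> bool:
        # collect_metrics runs only when metrics are enabled server-wide
        # AND the request did not opt out (health checks pass log_metrics=False).
        return enable_metrics and log_metrics

    assert should_collect(True, True)
    assert not should_collect(True, False)  # e.g. health_generate probes
    assert not should_collect(False, True)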
@@ -887,20 +896,38 @@ class TokenizerManager:
         )

         if len(self.dump_request_list) >= self.dump_requests_threshold:
+            filename = os.path.join(
+                self.dump_requests_folder,
+                datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
+            )
+            logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")
+
             to_dump = self.dump_request_list
             self.dump_request_list = []

             def background_task():
                 os.makedirs(self.dump_requests_folder, exist_ok=True)
-                current_time = datetime.now()
-                filename = current_time.strftime("%Y-%m-%d_%H-%M-%S") + ".pkl"
-                with open(os.path.join(self.dump_requests_folder, filename), "wb") as f:
+                with open(filename, "wb") as f:
                     pickle.dump(to_dump, f)

             # Schedule the task to run in the background without awaiting it
             asyncio.create_task(asyncio.to_thread(background_task))


+async def print_exception_wrapper(func):
+    """
+    Sometimes an asyncio function does not print exception.
+    We do another wrapper to handle the exception.
+    """
+    try:
+        await func()
+    except Exception:
+        traceback = get_exception_traceback()
+        logger.error(f"TokenizerManager hit an exception: {traceback}")
+        kill_process_tree(os.getpid(), include_parent=True)
+        sys.exit(1)
+
+
 class SignalHandler:
     def __init__(self, tokenizer_manager):
         self.tokenizer_manager = tokenizer_manager
...
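Why the wrapper helps: an exception raised inside a bare asyncio task may only surface when the task is garbage-collected, so a crashed handle_loop could fail silently. A self-contained sketch of the same pattern, with the logger and process kill swapped for print and re-raise (illustration only):

    import asyncio

    async def print_exception_wrapper(func):
        try:
            await func()  # func is a coroutine function, passed uncalled
        except Exception as exc:
            print(f"task hit an exception: {exc!r}")
            raise

    async def flaky():
        raise RuntimeError("boom")

    async def main():
        task = asyncio.create_task(print_exception_wrapper(flaky))
        await asyncio.gather(task, return_exceptions=True)

    asyncio.run(main())  # prints: task hit an exception: RuntimeError('boom')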
python/sglang/srt/server.py  (view file @ 8f2c522a)

@@ -135,9 +135,13 @@ async def health_generate(request: Request) -> Response:
     sampling_params = {"max_new_tokens": 1, "temperature": 0.7}

     if tokenizer_manager.is_generation:
-        gri = GenerateReqInput(input_ids=[0], sampling_params=sampling_params)
+        gri = GenerateReqInput(
+            input_ids=[0], sampling_params=sampling_params, log_metrics=False
+        )
     else:
-        gri = EmbeddingReqInput(input_ids=[0], sampling_params=sampling_params)
+        gri = EmbeddingReqInput(
+            input_ids=[0], sampling_params=sampling_params, log_metrics=False
+        )

     try:
         async for _ in tokenizer_manager.generate_request(gri, request):
...
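With log_metrics=False, the health probe still exercises the full tokenize-schedule-generate path but no longer skews request metrics or ends up in dumped request logs. A usage sketch, assuming a locally launched server (host, port, and endpoint path here are assumptions, not values from this commit):

    import requests

    # The handler above performs a real 1-token generation end to end.
    resp = requests.get("http://localhost:30000/health_generate", timeout=10)
    print(resp.status_code)  # expect 200 when the server is healthy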
python/sglang/test/test_utils.py  (view file @ 8f2c522a)

@@ -560,6 +560,7 @@ def run_bench_serving(
         tokenizer=tokenizer,
         num_prompts=num_prompts,
         sharegpt_output_len=None,
+        sharegpt_context_len=None,
         random_input_len=random_input_len,
         random_output_len=random_output_len,
         random_range_ratio=0.0,
...
test/srt/test_moe_ep.py  (view file @ 8f2c522a)

@@ -44,7 +44,7 @@ class TestEpMoE(unittest.TestCase):
         )

         metrics = run_eval(args)
-        assert metrics["score"] >= 0.5
+        self.assertGreater(metrics["score"], 0.5)

     def test_mgsm_en(self):
         args = SimpleNamespace(
...
@@ -56,7 +56,7 @@ class TestEpMoE(unittest.TestCase):
         )

         metrics = run_eval(args)
-        assert metrics["score"] >= 0.8
+        self.assertGreater(metrics["score"], 0.8)


 class TestEpMoEFP8(unittest.TestCase):
...
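Besides switching to unittest's assertion helper, which reports the offending values on failure instead of a bare AssertionError, note that the comparison tightens from >= to strict >. A standalone sketch of the failure-message difference:

    import unittest

    class Demo(unittest.TestCase):
        def test_score(self):
            metrics = {"score": 0.4}  # invented value for illustration
            # Fails with "0.4 not greater than 0.5" rather than a bare assert.
            self.assertGreater(metrics["score"], 0.5)

    if __name__ == "__main__":
        unittest.main()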