sglang / Commits / ed0a3dd5

Unverified commit ed0a3dd5, authored Aug 08, 2025 by Zaili Wang, committed by GitHub Aug 07, 2025
Enhancements for bench_one_batch (#8703)
Co-authored-by: root <root@gnr630186.jf.intel.com>
parent 2e901e89
Showing 1 changed file with 113 additions and 17 deletions

python/sglang/bench_one_batch.py    +113 -17
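The commit adds a --prompt-filename option for benchmarking with custom prompts, a --profile-record-shapes option, and separate prefill/decode profiler traces. A sketch of how the new flags might be combined, modeled on the usage lines in the file's docstring (the model path, --load-format dummy, and the --batch-size/--input-len spellings are assumptions, not part of this diff):

python3 -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy \
    --batch-size 4 --input-len 1024 --output-len 16 \
    --prompt-filename prompts.txt \
    --profile --profile-record-shapes --profile-filename-prefix profile/run1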
python/sglang/bench_one_batch.py
@@ -43,6 +43,7 @@ I'm going to the park
 """

 import argparse
+import copy
 import dataclasses
 import itertools
 import json
@@ -84,12 +85,14 @@ class BenchArgs:
     batch_size: Tuple[int] = (1,)
     input_len: Tuple[int] = (1024,)
     output_len: Tuple[int] = (16,)
+    prompt_filename: str = ""
     result_filename: str = "result.jsonl"
     correctness_test: bool = False
     # This is only used for correctness test
     cut_len: int = 4
     log_decode_step: int = 0
     profile: bool = False
+    profile_record_shapes: bool = False
     profile_filename_prefix: str = "profile"

     @staticmethod
@@ -104,6 +107,9 @@ class BenchArgs:
         parser.add_argument(
             "--output-len", type=int, nargs="+", default=BenchArgs.output_len
         )
+        parser.add_argument(
+            "--prompt-filename", type=str, default=BenchArgs.prompt_filename
+        )
         parser.add_argument(
             "--result-filename", type=str, default=BenchArgs.result_filename
         )
@@ -118,6 +124,11 @@ class BenchArgs:
         parser.add_argument(
             "--profile", action="store_true", help="Use Torch Profiler."
         )
+        parser.add_argument(
+            "--profile-record-shapes",
+            action="store_true",
+            help="Record tensor shapes in profiling results.",
+        )
         parser.add_argument(
             "--profile-filename-prefix",
             type=str,
@@ -165,12 +176,16 @@ def load_model(server_args, port_args, tp_rank):
     return model_runner, tokenizer


-def prepare_inputs_for_correctness_test(bench_args, tokenizer):
-    prompts = [
-        "The capital of France is",
-        "The capital of the United Kindom is",
-        "Today is a sunny day and I like",
-    ]
+def prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts):
+    prompts = (
+        custom_prompts
+        if custom_prompts
+        else [
+            "The capital of France is",
+            "The capital of the United Kindom is",
+            "Today is a sunny day and I like",
+        ]
+    )
     input_ids = [tokenizer.encode(p) for p in prompts]
     sampling_params = SamplingParams(
         temperature=0,
@@ -211,8 +226,14 @@ def prepare_extend_inputs_for_correctness_test(
     return reqs


-def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
-    input_ids = np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
+def prepare_synthetic_inputs_for_latency_test(
+    batch_size, input_len, custom_inputs=None
+):
+    input_ids = (
+        custom_inputs
+        if custom_inputs
+        else np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
+    )
     sampling_params = SamplingParams(
         temperature=0,
         max_new_tokens=BenchArgs.output_len,
@@ -284,6 +305,30 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
     )


+def _read_prompts_from_file(prompt_file, rank_print):
+    """Read custom prompts from the file specified by `--prompt-filename`."""
+    if not prompt_file:
+        return []
+    if not os.path.exists(prompt_file):
+        rank_print(
+            f"Custom prompt file {prompt_file} not found. Using default inputs..."
+        )
+        return []
+    with open(prompt_file, "r") as pf:
+        return pf.readlines()
+
+
+def _save_profile_trace_results(profiler, filename):
+    parent_dir = os.path.dirname(os.path.abspath(filename))
+    os.makedirs(parent_dir, exist_ok=True)
+    profiler.export_chrome_trace(filename)
+    print(
+        profiler.key_averages(group_by_input_shape=True).table(
+            sort_by="self_cpu_time_total"
+        )
+    )
+
+
 def correctness_test(
     server_args,
     port_args,
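Since _read_prompts_from_file returns pf.readlines(), the file passed via --prompt-filename is treated as one prompt per line; the correctness path hands those lines to prepare_inputs_for_correctness_test, while the latency path further down strips and tokenizes them. A hypothetical prompts.txt, purely illustrative:

The capital of Japan is
Write a haiku about benchmarking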
@@ -298,7 +343,10 @@ def correctness_test(
     model_runner, tokenizer = load_model(server_args, port_args, tp_rank)

     # Prepare inputs
-    input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
+    custom_prompts = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
+    input_ids, reqs = prepare_inputs_for_correctness_test(
+        bench_args, tokenizer, custom_prompts
+    )
     rank_print(f"\n{input_ids=}\n")

     if bench_args.cut_len > 0:
@@ -344,6 +392,7 @@ def latency_test_run_once(
     device,
     log_decode_step,
     profile,
+    profile_record_shapes,
     profile_filename_prefix,
 ):
     max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
@@ -374,6 +423,7 @@ def latency_test_run_once(
                 torch.profiler.ProfilerActivity.CUDA,
             ],
             with_stack=True,
+            record_shapes=profile_record_shapes,
         )
         profiler.start()
@@ -391,10 +441,30 @@ def latency_test_run_once(
     measurement_results["prefill_latency"] = prefill_latency
     measurement_results["prefill_throughput"] = throughput

+    if profile:
+        profiler.stop()
+        profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
+        _save_profile_trace_results(profiler, profile_filename)
+        rank_print(
+            f"torch profiler chrome trace for prefill saved to {profile_filename}"
+        )
+
     # Decode
     decode_latencies = []
     for i in range(output_len - 1):
         synchronize(device)
+        if profile and i == output_len / 2:
+            profiler = None
+            profiler = torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                with_stack=True,
+                record_shapes=profile_record_shapes,
+            )
+            profiler.start()
+
         tic = time.perf_counter()
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
         synchronize(device)
@@ -407,13 +477,13 @@ def latency_test_run_once(
                 f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
             )

-    if profile:
-        profiler.stop()
-        profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}.trace.json.gz"
-        parent_dir = os.path.dirname(os.path.abspath(profile_filename))
-        os.makedirs(parent_dir, exist_ok=True)
-        profiler.export_chrome_trace(profile_filename)
-        rank_print(f"torch profiler chrome trace saved to {profile_filename}")
+        if profile and i == output_len / 2:
+            profiler.stop()
+            profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
+            _save_profile_trace_results(profiler, profile_filename)
+            rank_print(
+                f"torch profiler chrome trace for decoding 1 token saved to {profile_filename}"
+            )

     # Record decode timing from 2nd output
     if output_len > 1:
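With this change the profiler writes one trace per phase instead of a single combined trace. Per the f-strings above, with the default BenchArgs values (batch_size 1, input_len 1024, output_len 16) and the default prefix "profile", the filenames would come out as:

profile_batch1_input1024_output16_prefill.trace.json.gz
profile_batch1_input1024_output16_decode.trace.json.gz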
@@ -469,17 +539,42 @@ def latency_test(
         server_args.device,
         log_decode_step=0,
         profile=False,
+        profile_record_shapes=False,
         profile_filename_prefix="",  # not used
     )

     rank_print("Benchmark ...")

+    custom_inputs = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
+    custom_inputs = [tokenizer.encode(p.strip()) for p in custom_inputs]
+    custom_input_len = len(custom_inputs)
+
     # Run the sweep
     result_list = []
     for bs, il, ol in itertools.product(
         bench_args.batch_size, bench_args.input_len, bench_args.output_len
     ):
-        reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
+        bs_aligned_inputs = []
+        if custom_inputs:
+            if custom_input_len == bs:
+                bs_aligned_inputs = custom_inputs
+            elif custom_input_len > bs:
+                rank_print(
+                    f"Custom input size ({custom_input_len}) is larger than batch_size ({bs}). "
+                    f"Using the first {bs} prompts."
+                )
+                bs_aligned_inputs = copy.deepcopy(custom_inputs[:bs])
+            else:
+                rank_print(
+                    f"Custom input size ({custom_input_len}) is smaller than batch_size ({bs}). "
+                    f"Pad to the desired batch_size with the last prompt."
+                )
+                bs_aligned_inputs = copy.deepcopy(custom_inputs)
+                bs_aligned_inputs.extend(
+                    [bs_aligned_inputs[-1]] * (bs - custom_input_len)
+                )
+
+        reqs = prepare_synthetic_inputs_for_latency_test(bs, il, bs_aligned_inputs)
         ret = latency_test_run_once(
             bench_args.run_name,
             model_runner,
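The alignment rule applied inline above (use the custom prompts as-is when their count matches the batch size, truncate when there are too many, pad by repeating the last prompt when there are too few) can be read as a small standalone helper. A minimal sketch, not part of the commit, with an illustrative name:

import copy

def align_inputs_to_batch(custom_inputs, bs):
    """Return exactly bs tokenized prompts, or [] to fall back to synthetic inputs."""
    if not custom_inputs:
        return []
    if len(custom_inputs) == bs:
        return custom_inputs
    if len(custom_inputs) > bs:
        # Too many prompts: keep the first bs of them.
        return copy.deepcopy(custom_inputs[:bs])
    # Too few prompts: repeat the last one until the batch is full.
    padded = copy.deepcopy(custom_inputs)
    padded.extend([padded[-1]] * (bs - len(custom_inputs)))
    return padded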
@@ -491,6 +586,7 @@ def latency_test(
             server_args.device,
             bench_args.log_decode_step,
             bench_args.profile if tp_rank == 0 else None,
+            bench_args.profile_record_shapes if tp_rank == 0 else None,
             bench_args.profile_filename_prefix,
         )
         if ret is not None: