Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
4efe2c57
"tests/python/common/test_heterograph.py" did not exist on "1c91f460d3e534ed549bf600820d7cc31a0981ff"
Unverified
Commit
4efe2c57
authored
Sep 10, 2025
by
Lzhang-hub
Committed by
GitHub
Sep 10, 2025
Browse files
support vlm model spec bench (#10173)
parent
5be8c2f7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
44 additions
and
15 deletions
+44
-15
scripts/playground/bench_speculative.py
scripts/playground/bench_speculative.py
+44
-15
No files found.
scripts/playground/bench_speculative.py
View file @
4efe2c57
...
...
@@ -16,8 +16,14 @@ from types import SimpleNamespace
import
numpy
as
np
import
requests
from
transformers
import
AutoTokenizer
from
sglang.bench_serving
import
DatasetRow
,
benchmark
,
set_global_args
from
sglang.bench_serving
import
(
DatasetRow
,
benchmark
,
sample_mmmu_requests
,
set_global_args
,
)
from
sglang.srt.server_args
import
ServerArgs
from
sglang.test.test_utils
import
(
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
...
...
@@ -48,20 +54,33 @@ class FakeTokenizer:
return
[]
def
send_one_batch
(
base_url
,
num_prompts
,
batch_size
):
padded_prompts
=
(
prompts
*
((
num_prompts
+
len
(
prompts
)
-
1
)
//
len
(
prompts
)))[
:
num_prompts
]
def
send_one_batch
(
base_url
,
num_prompts
,
batch_size
,
tokenizer
,
is_multimodal
):
# format: (prompt, input_len, output len). We set input_len as a dummy value 0.
input_requests
:
List
[
DatasetRow
]
=
[
DatasetRow
(
p
,
0
,
512
)
for
p
in
padded_prompts
]
if
is_multimodal
:
input_requests
=
sample_mmmu_requests
(
num_prompts
,
tokenizer
,
512
,
apply_chat_template
=
False
,
)
backend
=
"sglang-oai-chat"
api_url
=
f
"
{
base_url
}
/v1/chat/completions"
else
:
padded_prompts
=
(
prompts
*
((
num_prompts
+
len
(
prompts
)
-
1
)
//
len
(
prompts
)))[
:
num_prompts
]
input_requests
:
List
[
DatasetRow
]
=
[
DatasetRow
(
p
,
0
,
512
)
for
p
in
padded_prompts
]
backend
=
"sglang"
api_url
=
f
"
{
base_url
}
/generate"
# We need to set some dummy values in order to call `benchmark` below.
args
=
SimpleNamespace
(
disable_ignore_eos
=
False
,
disable_stream
=
False
,
return_logprob
=
False
,
backend
=
"sglang"
,
backend
=
backend
,
dataset_name
=
"custom"
,
num_prompts
=
None
,
sharegpt_output_len
=
None
,
...
...
@@ -73,13 +92,12 @@ def send_one_batch(base_url, num_prompts, batch_size):
output_details
=
False
,
)
set_global_args
(
args
)
tokenizer
=
FakeTokenizer
()
# Run benchmark
results
=
asyncio
.
run
(
benchmark
(
backend
=
"sglang"
,
api_url
=
f
"
{
base_url
}
/generate"
,
backend
=
backend
,
api_url
=
api_url
,
base_url
=
base_url
,
model_id
=
"default"
,
tokenizer
=
tokenizer
,
...
...
@@ -143,8 +161,6 @@ def main(args, server_args):
other_args
=
[]
else
:
other_args
=
[
"--speculative-algorithm"
,
"EAGLE"
,
"--speculative-num-steps"
,
steps
,
"--speculative-eagle-topk"
,
...
...
@@ -157,6 +173,8 @@ def main(args, server_args):
[
"--speculative-draft-model-path"
,
server_args
.
speculative_draft_model_path
,
"--speculative-algorithm"
,
server_args
.
speculative_algorithm
,
]
)
...
...
@@ -207,13 +225,23 @@ def main(args, server_args):
},
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
args
.
model_path
,
trust_remote_code
=
server_args
.
trust_remote_code
)
try
:
# Warmup
send_one_batch
(
base_url
,
batch_size
,
batch_size
)
send_one_batch
(
base_url
,
batch_size
,
batch_size
,
tokenizer
,
args
.
is_multimodal
)
# Benchmark
acc_length
,
step_time
,
speed
,
completion_tokens
=
send_one_batch
(
base_url
,
max
(
args
.
num_prompts
,
batch_size
),
batch_size
base_url
,
max
(
args
.
num_prompts
,
batch_size
),
batch_size
,
tokenizer
,
args
.
is_multimodal
,
)
finally
:
kill_process_tree
(
process
.
pid
)
...
...
@@ -273,6 +301,7 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--start"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--end"
,
type
=
int
)
parser
.
add_argument
(
"--output"
,
type
=
str
,
default
=
"output.jsonl"
)
parser
.
add_argument
(
"--is-multimodal"
,
action
=
"store_true"
,
default
=
False
)
args
=
parser
.
parse_args
()
server_args
:
ServerArgs
=
ServerArgs
.
from_cli_args
(
args
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment