Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
d50e36a7
Unverified
Commit
d50e36a7
authored
Apr 30, 2025
by
Yi Zhang
Committed by
GitHub
Apr 29, 2025
Browse files
support vlm benchmark profile (#5905)
parent
8fefdd32
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
70 additions
and
3 deletions
+70
-3
benchmark/mmmu/bench_sglang.py
benchmark/mmmu/bench_sglang.py
+62
-3
benchmark/mmmu/eval_utils.py
benchmark/mmmu/eval_utils.py
+8
-0
No files found.
benchmark/mmmu/bench_sglang.py
View file @
d50e36a7
...
@@ -10,8 +10,14 @@ The eval output will be logged
...
@@ -10,8 +10,14 @@ The eval output will be logged
"""
"""
import
argparse
import
argparse
import
asyncio
import
sys
import
time
import
time
import
traceback
from
dataclasses
import
dataclass
,
field
from
typing
import
List
import
aiohttp
import
openai
import
openai
from
data_utils
import
save_json
from
data_utils
import
save_json
from
eval_utils
import
(
from
eval_utils
import
(
...
@@ -25,8 +31,41 @@ from tqdm import tqdm
...
@@ -25,8 +31,41 @@ from tqdm import tqdm
from
sglang.test.test_utils
import
add_common_sglang_args_and_parse
from
sglang.test.test_utils
import
add_common_sglang_args_and_parse
# Overall client timeout for profiler-control requests: 20 * 60 * 60 s = 20
# hours, deliberately generous so a long profiled benchmark run does not
# abort the HTTP session mid-flight.
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)
def
eval_mmmu
(
args
):
@dataclass
class RequestFuncOutput:
    """Mutable result record for one benchmark/profiler HTTP request.

    In this file only ``success`` and ``error`` are populated (by
    ``async_request_profile``); the list fields mirror the request-output
    schema used by other sglang benchmark scripts — presumably so results
    can be aggregated uniformly (TODO confirm against bench_serving).
    """

    generated_text: List[str] = field(default_factory=list)
    prompt_len: List[int] = field(default_factory=list)
    output_len: List[int] = field(default_factory=list)
    latency: List[float] = field(default_factory=list)
    ttft: List[float] = field(default_factory=list)  # presumably time-to-first-token; not set here
    itl: List[float] = field(default_factory=list)  # List of inter-token latencies
    # Outcome of the HTTP call: True on status 200, otherwise False with the
    # reason (or formatted traceback) in ``error``.
    success: bool = False
    error: str = ""
async def async_request_profile(api_url: str) -> RequestFuncOutput:
    """POST to a profiler control endpoint (e.g. /start_profile, /stop_profile).

    Args:
        api_url: Full URL of the profiler endpoint to hit.

    Returns:
        A ``RequestFuncOutput`` whose ``success`` reflects an HTTP 200
        response; on a non-200 status ``error`` carries the HTTP reason,
        and on an exception it carries the formatted traceback.
    """
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        output = RequestFuncOutput()
        try:
            async with session.post(url=api_url) as response:
                if response.status == 200:
                    output.success = True
                else:
                    # reason can be None (e.g. HTTP/2); normalize to "".
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            # Idiom fix: traceback.format_exc() renders the active exception
            # directly, replacing the manual
            # sys.exc_info() / "".join(traceback.format_exception(*exc_info))
            # dance with identical output.
            output.error = traceback.format_exc()
        return output
async
def
eval_mmmu
(
args
):
eval_args
=
EvalArgs
.
from_cli_args
(
args
)
eval_args
=
EvalArgs
.
from_cli_args
(
args
)
out_samples
=
dict
()
out_samples
=
dict
()
...
@@ -38,9 +77,22 @@ def eval_mmmu(args):
...
@@ -38,9 +77,22 @@ def eval_mmmu(args):
answer_dict
=
{}
answer_dict
=
{}
# had to use an openai server, since SglImage doesn't support image data
# had to use an openai server, since SglImage doesn't support image data
client
=
openai
.
Client
(
api_key
=
"sk"
,
base_url
=
f
"http://127.0.0.1:
{
args
.
port
}
/v1"
)
base_url
=
f
"http://127.0.0.1:
{
args
.
port
}
"
client
=
openai
.
Client
(
api_key
=
"sk"
,
base_url
=
f
"
{
base_url
}
/v1"
)
start
=
time
.
time
()
start
=
time
.
time
()
if
args
.
profile
:
print
(
"Starting profiler..."
)
profile_output
=
await
async_request_profile
(
api_url
=
f
"
{
base_url
}
/start_profile"
)
if
profile_output
.
success
:
print
(
"Profiler started"
)
if
args
.
profile
:
samples
=
samples
[:
args
.
profile_number
]
for
i
,
sample
in
enumerate
(
tqdm
(
samples
)):
for
i
,
sample
in
enumerate
(
tqdm
(
samples
)):
prompt
=
sample
[
"final_input_prompt"
]
prompt
=
sample
[
"final_input_prompt"
]
prefix
=
prompt
.
split
(
"<"
)[
0
]
prefix
=
prompt
.
split
(
"<"
)[
0
]
...
@@ -49,6 +101,7 @@ def eval_mmmu(args):
...
@@ -49,6 +101,7 @@ def eval_mmmu(args):
assert
image
is
not
None
assert
image
is
not
None
image_path
=
sample
[
"image_path"
]
image_path
=
sample
[
"image_path"
]
# TODO: batch
# TODO: batch
response
=
client
.
chat
.
completions
.
create
(
response
=
client
.
chat
.
completions
.
create
(
model
=
"default"
,
model
=
"default"
,
messages
=
[
messages
=
[
...
@@ -77,6 +130,12 @@ def eval_mmmu(args):
...
@@ -77,6 +130,12 @@ def eval_mmmu(args):
response
=
response
.
choices
[
0
].
message
.
content
response
=
response
.
choices
[
0
].
message
.
content
process_result
(
response
,
sample
,
answer_dict
,
out_samples
)
process_result
(
response
,
sample
,
answer_dict
,
out_samples
)
if
args
.
profile
:
print
(
"Stopping profiler..."
)
profile_output
=
await
async_request_profile
(
api_url
=
f
"
{
base_url
}
/stop_profile"
)
if
profile_output
.
success
:
print
(
"Profiler stopped"
)
print
(
f
"Benchmark time:
{
time
.
time
()
-
start
}
"
)
print
(
f
"Benchmark time:
{
time
.
time
()
-
start
}
"
)
args
.
output_path
=
f
"./val_sglang.json"
args
.
output_path
=
f
"./val_sglang.json"
...
@@ -89,4 +148,4 @@ if __name__ == "__main__":
...
@@ -89,4 +148,4 @@ if __name__ == "__main__":
EvalArgs
.
add_cli_args
(
parser
)
EvalArgs
.
add_cli_args
(
parser
)
args
=
add_common_sglang_args_and_parse
(
parser
)
args
=
add_common_sglang_args_and_parse
(
parser
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
eval_mmmu
(
args
)
asyncio
.
run
(
eval_mmmu
(
args
)
)
benchmark/mmmu/eval_utils.py
View file @
d50e36a7
...
@@ -33,6 +33,8 @@ class EvalArgs:
...
@@ -33,6 +33,8 @@ class EvalArgs:
prompt_format_file
:
str
=
"prompt_format.yaml"
prompt_format_file
:
str
=
"prompt_format.yaml"
dataset_path
:
str
=
"MMMU/MMMU"
dataset_path
:
str
=
"MMMU/MMMU"
extra_request_body
:
Optional
[
str
]
=
None
extra_request_body
:
Optional
[
str
]
=
None
profile
:
bool
=
False
profile_number
:
int
=
5
@
staticmethod
@
staticmethod
def
add_cli_args
(
parser
:
argparse
.
ArgumentParser
):
def
add_cli_args
(
parser
:
argparse
.
ArgumentParser
):
...
@@ -65,6 +67,12 @@ class EvalArgs:
...
@@ -65,6 +67,12 @@ class EvalArgs:
help
=
"Append given JSON object to the request payload. You can use this to specify"
help
=
"Append given JSON object to the request payload. You can use this to specify"
"additional generate params like sampling params."
,
"additional generate params like sampling params."
,
)
)
parser
.
add_argument
(
"--profile"
,
action
=
"store_true"
,
help
=
"enable mmmu profile"
)
parser
.
add_argument
(
"--profile-number"
,
type
=
int
,
default
=
EvalArgs
.
profile_number
)
@
classmethod
@
classmethod
def
from_cli_args
(
cls
,
args
:
argparse
.
Namespace
):
def
from_cli_args
(
cls
,
args
:
argparse
.
Namespace
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment