Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
d50e36a7
"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "80bc0c0ced1566549dec606f5069e909b86e86b0"
Unverified
Commit
d50e36a7
authored
Apr 30, 2025
by
Yi Zhang
Committed by
GitHub
Apr 29, 2025
Browse files
support vlm benchmark profile (#5905)
parent
8fefdd32
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
70 additions
and
3 deletions
+70
-3
benchmark/mmmu/bench_sglang.py
benchmark/mmmu/bench_sglang.py
+62
-3
benchmark/mmmu/eval_utils.py
benchmark/mmmu/eval_utils.py
+8
-0
No files found.
benchmark/mmmu/bench_sglang.py
View file @
d50e36a7
...
@@ -10,8 +10,14 @@ The eval output will be logged
...
@@ -10,8 +10,14 @@ The eval output will be logged
"""
"""
import
argparse
import
argparse
import
asyncio
import
sys
import
time
import
time
import
traceback
from
dataclasses
import
dataclass
,
field
from
typing
import
List
import
aiohttp
import
openai
import
openai
from
data_utils
import
save_json
from
data_utils
import
save_json
from
eval_utils
import
(
from
eval_utils
import
(
...
@@ -25,8 +31,41 @@ from tqdm import tqdm
...
@@ -25,8 +31,41 @@ from tqdm import tqdm
from
sglang.test.test_utils
import
add_common_sglang_args_and_parse
from
sglang.test.test_utils
import
add_common_sglang_args_and_parse
AIOHTTP_TIMEOUT
=
aiohttp
.
ClientTimeout
(
total
=
20
*
60
*
60
)
def
eval_mmmu
(
args
):
@
dataclass
class
RequestFuncOutput
:
generated_text
:
List
[
str
]
=
field
(
default_factory
=
list
)
prompt_len
:
List
[
int
]
=
field
(
default_factory
=
list
)
output_len
:
List
[
int
]
=
field
(
default_factory
=
list
)
latency
:
List
[
float
]
=
field
(
default_factory
=
list
)
ttft
:
List
[
float
]
=
field
(
default_factory
=
list
)
itl
:
List
[
float
]
=
field
(
default_factory
=
list
)
# List of inter-token latencies
success
:
bool
=
False
error
:
str
=
""
async
def
async_request_profile
(
api_url
:
str
)
->
RequestFuncOutput
:
async
with
aiohttp
.
ClientSession
(
timeout
=
AIOHTTP_TIMEOUT
)
as
session
:
output
=
RequestFuncOutput
()
try
:
async
with
session
.
post
(
url
=
api_url
)
as
response
:
if
response
.
status
==
200
:
output
.
success
=
True
else
:
output
.
error
=
response
.
reason
or
""
output
.
success
=
False
except
Exception
:
output
.
success
=
False
exc_info
=
sys
.
exc_info
()
output
.
error
=
""
.
join
(
traceback
.
format_exception
(
*
exc_info
))
return
output
async
def
eval_mmmu
(
args
):
eval_args
=
EvalArgs
.
from_cli_args
(
args
)
eval_args
=
EvalArgs
.
from_cli_args
(
args
)
out_samples
=
dict
()
out_samples
=
dict
()
...
@@ -38,9 +77,22 @@ def eval_mmmu(args):
...
@@ -38,9 +77,22 @@ def eval_mmmu(args):
answer_dict
=
{}
answer_dict
=
{}
# had to use an openai server, since SglImage doesn't support image data
# had to use an openai server, since SglImage doesn't support image data
client
=
openai
.
Client
(
api_key
=
"sk"
,
base_url
=
f
"http://127.0.0.1:
{
args
.
port
}
/v1"
)
base_url
=
f
"http://127.0.0.1:
{
args
.
port
}
"
client
=
openai
.
Client
(
api_key
=
"sk"
,
base_url
=
f
"
{
base_url
}
/v1"
)
start
=
time
.
time
()
start
=
time
.
time
()
if
args
.
profile
:
print
(
"Starting profiler..."
)
profile_output
=
await
async_request_profile
(
api_url
=
f
"
{
base_url
}
/start_profile"
)
if
profile_output
.
success
:
print
(
"Profiler started"
)
if
args
.
profile
:
samples
=
samples
[:
args
.
profile_number
]
for
i
,
sample
in
enumerate
(
tqdm
(
samples
)):
for
i
,
sample
in
enumerate
(
tqdm
(
samples
)):
prompt
=
sample
[
"final_input_prompt"
]
prompt
=
sample
[
"final_input_prompt"
]
prefix
=
prompt
.
split
(
"<"
)[
0
]
prefix
=
prompt
.
split
(
"<"
)[
0
]
...
@@ -49,6 +101,7 @@ def eval_mmmu(args):
...
@@ -49,6 +101,7 @@ def eval_mmmu(args):
assert
image
is
not
None
assert
image
is
not
None
image_path
=
sample
[
"image_path"
]
image_path
=
sample
[
"image_path"
]
# TODO: batch
# TODO: batch
response
=
client
.
chat
.
completions
.
create
(
response
=
client
.
chat
.
completions
.
create
(
model
=
"default"
,
model
=
"default"
,
messages
=
[
messages
=
[
...
@@ -77,6 +130,12 @@ def eval_mmmu(args):
...
@@ -77,6 +130,12 @@ def eval_mmmu(args):
response
=
response
.
choices
[
0
].
message
.
content
response
=
response
.
choices
[
0
].
message
.
content
process_result
(
response
,
sample
,
answer_dict
,
out_samples
)
process_result
(
response
,
sample
,
answer_dict
,
out_samples
)
if
args
.
profile
:
print
(
"Stopping profiler..."
)
profile_output
=
await
async_request_profile
(
api_url
=
f
"
{
base_url
}
/stop_profile"
)
if
profile_output
.
success
:
print
(
"Profiler stopped"
)
print
(
f
"Benchmark time:
{
time
.
time
()
-
start
}
"
)
print
(
f
"Benchmark time:
{
time
.
time
()
-
start
}
"
)
args
.
output_path
=
f
"./val_sglang.json"
args
.
output_path
=
f
"./val_sglang.json"
...
@@ -89,4 +148,4 @@ if __name__ == "__main__":
...
@@ -89,4 +148,4 @@ if __name__ == "__main__":
EvalArgs
.
add_cli_args
(
parser
)
EvalArgs
.
add_cli_args
(
parser
)
args
=
add_common_sglang_args_and_parse
(
parser
)
args
=
add_common_sglang_args_and_parse
(
parser
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
eval_mmmu
(
args
)
asyncio
.
run
(
eval_mmmu
(
args
)
)
benchmark/mmmu/eval_utils.py
View file @
d50e36a7
...
@@ -33,6 +33,8 @@ class EvalArgs:
...
@@ -33,6 +33,8 @@ class EvalArgs:
prompt_format_file
:
str
=
"prompt_format.yaml"
prompt_format_file
:
str
=
"prompt_format.yaml"
dataset_path
:
str
=
"MMMU/MMMU"
dataset_path
:
str
=
"MMMU/MMMU"
extra_request_body
:
Optional
[
str
]
=
None
extra_request_body
:
Optional
[
str
]
=
None
profile
:
bool
=
False
profile_number
:
int
=
5
@
staticmethod
@
staticmethod
def
add_cli_args
(
parser
:
argparse
.
ArgumentParser
):
def
add_cli_args
(
parser
:
argparse
.
ArgumentParser
):
...
@@ -65,6 +67,12 @@ class EvalArgs:
...
@@ -65,6 +67,12 @@ class EvalArgs:
help
=
"Append given JSON object to the request payload. You can use this to specify"
help
=
"Append given JSON object to the request payload. You can use this to specify"
"additional generate params like sampling params."
,
"additional generate params like sampling params."
,
)
)
parser
.
add_argument
(
"--profile"
,
action
=
"store_true"
,
help
=
"enable mmmu profile"
)
parser
.
add_argument
(
"--profile-number"
,
type
=
int
,
default
=
EvalArgs
.
profile_number
)
@
classmethod
@
classmethod
def
from_cli_args
(
cls
,
args
:
argparse
.
Namespace
):
def
from_cli_args
(
cls
,
args
:
argparse
.
Namespace
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment