sglang (zhaoyu6) / commit 6ea1e6ac

Unverified commit, authored May 02, 2025 by XinyuanTong and committed via GitHub on May 02, 2025.

Support MMMU benchmark for InternVL (#5968)
Parent: 3409aaab

Showing 2 changed files with 139 additions and 12 deletions (+139 / -12):

  benchmark/mmmu/bench_hf.py        +45 / -12
  benchmark/mmmu/internvl_utils.py  +94 /  -0
benchmark/mmmu/bench_hf.py (view file @ 6ea1e6ac)
@@ -17,6 +17,13 @@ from transformers import AutoModel, AutoProcessor, GenerationConfig
 @torch.no_grad()
 def eval_mmmu(args):
     eval_args = EvalArgs.from_cli_args(args)
+    sampling_params = get_sampling_params(eval_args)
+    generation_config = GenerationConfig(
+        max_new_tokens=sampling_params["max_new_tokens"],
+        do_sample=False,
+    )
+
     try:
         from transformers import AutoModelForImageTextToText
@@ -27,12 +34,28 @@ def eval_mmmu(args):
         )
     except Exception as first_exception:
         try:
-            model = AutoModel.from_pretrained(
-                args.model_path,
-                torch_dtype="auto",
-                trust_remote_code=True,
-                init_tts=False,
-            )
+            # check if the model belongs to InternVL
+            if "InternVL" in args.model_path:
+                from internvl_utils import load_image
+                from transformers import AutoTokenizer
+
+                tokenizer = AutoTokenizer.from_pretrained(args.model_path)
+                model = AutoModel.from_pretrained(
+                    args.model_path,
+                    torch_dtype="auto",
+                    trust_remote_code=True,
+                )
+                generation_config_internvl = dict(
+                    max_new_tokens=sampling_params["max_new_tokens"], do_sample=False
+                )
+            else:
+                model = AutoModel.from_pretrained(
+                    args.model_path,
+                    torch_dtype="auto",
+                    trust_remote_code=True,
+                    init_tts=False,
+                )
         except Exception as second_exception:
             raise RuntimeError(
                 f"Failed to load model: First attempt failed with {first_exception}, "
@@ -48,12 +71,6 @@ def eval_mmmu(args):
     samples = prepare_samples(eval_args)
     out_samples = dict()
-    sampling_params = get_sampling_params(eval_args)
-    generation_config = GenerationConfig(
-        max_new_tokens=sampling_params["max_new_tokens"],
-        do_sample=False,
-    )
-
     answer_dict = {}
     for sample in tqdm(samples):
         prompt = sample["final_input_prompt"]
@@ -61,6 +78,22 @@ def eval_mmmu(args):
         prefix = prompt.split("<")[0]
         suffix = prompt.split(">")[1]
         assert image is not None
+        if "InternVL" in args.model_path:
+            pixel_values = load_image(sample["image_path"]).to(torch.bfloat16).cuda()
+            contents = ""
+            if prefix:
+                contents += prefix
+            contents += "<image>\n"
+            if suffix:
+                contents += suffix
+            response = model.chat(
+                tokenizer, pixel_values, contents, generation_config_internvl
+            )
+            print(f"response: {response}")
+            process_result(response, sample, answer_dict, out_samples)
+            continue
+
         contents = []
         if prefix:
             contents += [{"type": "text", "text": prefix}]
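With this change, pointing the HF benchmark at an InternVL checkpoint exercises the new load_image/model.chat path instead of the processor-based one. A hedged invocation sketch (the flag name is inferred from the args.model_path attribute used above and may differ; the checkpoint is only an example, and EvalArgs may expose further options):

    python benchmark/mmmu/bench_hf.py --model-path OpenGVLab/InternVL3-1B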
benchmark/mmmu/internvl_utils.py (new file, mode 0 → 100644, view file @ 6ea1e6ac)
# copy from https://huggingface.co/OpenGVLab/InternVL3-1B
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose(
        [
            T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
            T.Resize(
                (input_size, input_size), interpolation=InterpolationMode.BICUBIC
            ),
            T.ToTensor(),
            T.Normalize(mean=MEAN, std=STD),
        ]
    )
    return transform
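A quick sanity check of the transform (a sketch; the solid-gray test image is arbitrary): any PIL image is resized to input_size x input_size and returned as an ImageNet-normalized tensor.

    from PIL import Image

    transform = build_transform(input_size=448)
    img = Image.new("RGB", (8, 8), color=(128, 128, 128))
    print(transform(img).shape)  # torch.Size([3, 448, 448])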
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float("inf")
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio
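Worked example (a sketch): for a 1600x900 image (aspect ratio ~1.78) with image_size=448 and the default max_num=12 grid set, the closest candidate ratio is 2.0, shared by (2, 1) and (4, 2). The tie-breaker keeps the larger grid whenever the image area exceeds half of that grid's pixel budget, so a large image like this one gets the finer (4, 2) tiling:

    ratios = sorted(
        {(i, j) for n in range(1, 13) for i in range(1, n + 1)
         for j in range(1, n + 1) if i * j <= 12},
        key=lambda x: x[0] * x[1],
    )
    print(find_closest_aspect_ratio(1600 / 900, ratios, 1600, 900, 448))  # (4, 2)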
def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set(
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if i * j <= max_num and i * j >= min_num
    )
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size
    )

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size,
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images
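Usage sketch, continuing the 1600x900 example: the image is resized to 1792x896 (a 4x2 grid of 448-pixel tiles) and cut into 8 crops; with use_thumbnail=True a 448x448 global view is appended, giving 9 images in total.

    img = Image.new("RGB", (1600, 900))  # stand-in for a real photo
    tiles = dynamic_preprocess(img, image_size=448, use_thumbnail=True, max_num=12)
    print(len(tiles), tiles[0].size)  # 9 (448, 448)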
def load_image(image_file, input_size=448, max_num=12):
    image = Image.open(image_file).convert("RGB")
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(
        image, image_size=input_size, use_thumbnail=True, max_num=max_num
    )
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values
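End-to-end sketch tying the helper back to bench_hf.py ("demo.jpg" is a placeholder path; model and tokenizer are assumed to be loaded as in the diff above): load_image stacks one normalized tensor per tile, so the 1600x900 example yields a (9, 3, 448, 448) batch that feeds InternVL's model.chat directly.

    pixel_values = load_image("demo.jpg").to(torch.bfloat16).cuda()
    response = model.chat(
        tokenizer,
        pixel_values,
        "<image>\nDescribe this image.",
        dict(max_new_tokens=64, do_sample=False),
    )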