Unverified commit 98be3bd3 authored by Mick, committed by GitHub

refactor: rewrite bench-mmmu-sglang (#4458)

parent a98290ae
@@ -3,7 +3,11 @@
 ### Evaluate sglang
 ```
-python benchmark/mmmu/bench_sglang.py --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl
+python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --port 30000
+```
+```
+python benchmark/mmmu/bench_sglang.py --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl --port 30000
 ```
 It's recommended to reduce the memory usage by appending something like `--mem-fraction-static 0.6` to the command above.
...
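For concreteness, appending the suggested flag to the server launch from the README snippet above would look like this (a sketch reusing the README's own values):

```
python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --port 30000 --mem-fraction-static 0.6
```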
"""
Bench the huggingface vLM with benchmark MMMU
Usage:
python benchmark/mmmu/bench_hf.py --model-path Qwen/Qwen2-VL-7B-Instruct
The eval output will be logged
"""
import argparse import argparse
import random
import torch import torch
from data_utils import save_json from data_utils import save_json
@@ -53,48 +43,31 @@ def eval_mmmu(args):
         image = sample["image"]
         prefix = prompt.split("<")[0]
         suffix = prompt.split(">")[1]
-        if image is not None:
-            messages = [
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "text", "text": prefix},
-                        {
-                            "type": "image",
-                            "image": image,
-                        },
-                        {"type": "text", "text": suffix},
-                    ],
-                }
-            ]
-            text = processor.apply_chat_template(
-                messages, tokenize=False, add_generation_prompt=True
-            )
-            inputs = processor(
-                text=[text],
-                images=[image],
-                padding=True,
-                return_tensors="pt",
-            ).to(model.device)
-            generated_ids = model.generate(
-                **inputs, generation_config=generation_config
-            )
-            response = processor.decode(
-                generated_ids[0],
-                skip_special_tokens=True,
-                clean_up_tokenization_spaces=False,
-            )[len(text) :]
-            print(f"response: {response}")
-        else:  # multiple images actually
-            if sample["question_type"] == "multiple-choice":
-                all_choices = sample["all_choices"]
-                response = random.choice(all_choices)
-            else:
-                response = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS"
+        assert image is not None
+        contents = []
+        if prefix:
+            contents += [{"type": "text", "text": prefix}]
+        contents += [
+            {
+                "type": "image",
+                "image": sample["image_path"],
+            }
+        ]
+        if suffix:
+            contents += [{"type": "text", "text": suffix}]
+        messages = [{"role": "user", "content": contents}]
+        model_inputs = processor.apply_chat_template(
+            messages,
+            tokenize=True,
+            return_dict=True,
+            add_generation_prompt=True,
+            return_tensors="pt",
+        ).to(model.device)
+        input_len = model_inputs["input_ids"].shape[-1]
+        generation = model.generate(**model_inputs, generation_config=generation_config)
+        generation = generation[0][input_len:]
+        response = processor.decode(generation, skip_special_tokens=True)
+        print(f"response: {response}")
         process_result(response, sample, answer_dict, out_samples)

     args.output_path = f"{args.model_path}_val_hf.json"
...
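As a standalone illustration of the tokenize-and-slice generation path adopted in bench_hf.py above, here is a minimal sketch assuming a recent transformers release; the model-loading line, image path, question text, and `max_new_tokens` value are placeholders not taken from the commit:

```python
# Sketch of the pattern used above: tokenize the chat template directly,
# then slice the prompt tokens off the generated sequence before decoding.
# Placeholder values: image path, question text, max_new_tokens, dtype.
import torch
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

model_id = "Qwen/Qwen2-VL-7B-Instruct"
processor = AutoProcessor.from_pretrained(model_id)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

messages = [{"role": "user", "content": [
    {"type": "image", "image": "/tmp/mmmu/image_0.png"},  # placeholder path
    {"type": "text", "text": "What is shown in this image?"},
]}]
# tokenize=True + return_dict=True yields model-ready tensors directly.
model_inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

input_len = model_inputs["input_ids"].shape[-1]
generation = model.generate(**model_inputs, max_new_tokens=64)
# Drop the prompt tokens so only the new completion is decoded.
response = processor.decode(generation[0][input_len:], skip_special_tokens=True)
print(response)
```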
@@ -8,11 +8,8 @@
 """

 import argparse
-import base64
-import dataclasses
-import random
-from io import BytesIO

+import openai
 from data_utils import save_json
 from eval_utils import (
     EvalArgs,
@@ -23,21 +20,12 @@ from eval_utils import (
 )
 from tqdm import tqdm

-from sglang import Engine
-from sglang.srt.conversation import generate_chat_conv
-from sglang.srt.openai_api.protocol import ChatCompletionRequest
-from sglang.srt.server_args import ServerArgs
+from sglang.test.test_utils import add_common_sglang_args_and_parse


 def eval_mmmu(args):
-    server_args = ServerArgs.from_cli_args(args)
     eval_args = EvalArgs.from_cli_args(args)

-    if server_args.chat_template is None:
-        raise ValueError("Chat template must be provided for this benchmark")
-
-    backend = Engine(**dataclasses.asdict(server_args))
-
     out_samples = dict()

     sampling_params = get_sampling_params(eval_args)
@@ -46,17 +34,20 @@ def eval_mmmu(args):
     answer_dict = {}
-    for sample in tqdm(samples):
+    # had to use an openai server, since SglImage doesn't support image data
+    client = openai.Client(api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1")
+    for i, sample in enumerate(tqdm(samples)):
         prompt = sample["final_input_prompt"]
-        image = sample["image"]
-        buff = BytesIO()
-        image.save(buff, format="PNG")
-        base64_str = base64.b64encode(buff.getvalue()).decode("utf-8")
         prefix = prompt.split("<")[0]
         suffix = prompt.split(">")[1]
-        request_dict = {
-            "model": "",
-            "messages": [
+        image = sample["image"]
+        assert image is not None
+        image_path = sample["image_path"]
+        # TODO: batch
+        response = client.chat.completions.create(
+            model="default",
+            messages=[
                 {
                     "role": "user",
                     "content": [
@@ -66,9 +57,7 @@ def eval_mmmu(args):
                         },
                         {
                             "type": "image_url",
-                            "image_url": {
-                                "url": f"data:image/jpeg;base64,{base64_str}"
-                            },
+                            "image_url": {"url": image_path},
                         },
                         {
                             "type": "text",
@@ -77,40 +66,21 @@
                     ],
                 }
             ],
-        }
-
-        conv = generate_chat_conv(
-            ChatCompletionRequest(**request_dict),
-            template_name=server_args.chat_template,
+            temperature=0,
+            max_completion_tokens=sampling_params["max_new_tokens"],
+            max_tokens=sampling_params["max_new_tokens"],
         )
-        prompt = conv.get_prompt()
-
-        if image is not None:
-            gen_out = backend.generate(
-                prompt=prompt,
-                image_data=conv.image_data,
-                sampling_params=sampling_params,
-            )["text"]
-            response = gen_out
-        else:  # multiple images actually
-            if sample["question_type"] == "multiple-choice":
-                all_choices = sample["all_choices"]
-                response = random.choice(all_choices)
-            else:
-                response = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS"
+        response = response.choices[0].message.content
         process_result(response, sample, answer_dict, out_samples)

-    args.output_path = f"{args.model_path}_val_sglang.json"
+    args.output_path = f"./val_sglang.json"
     save_json(args.output_path, out_samples)
     eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)
-    backend.shutdown()


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ServerArgs.add_cli_args(parser)
+    args = add_common_sglang_args_and_parse(parser)
     EvalArgs.add_cli_args(parser)
     args = parser.parse_args()
...
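The rewritten bench_sglang.py no longer embeds an Engine; it sends OpenAI-compatible chat requests to a separately launched server. A minimal sketch of the same request shape, assuming a server already running on port 30000; the image path, question text, and `max_tokens` value are placeholders (the `api_key="sk"` and `model="default"` values come from the diff itself):

```python
# Sketch of the request the rewritten benchmark sends to a local sglang server.
# Assumes: `python -m sglang.launch_server --model-path ... --port 30000` is running.
import openai

client = openai.Client(api_key="sk", base_url="http://127.0.0.1:30000/v1")
response = client.chat.completions.create(
    model="default",
    messages=[
        {
            "role": "user",
            "content": [
                # A local file path, as written by prepare_samples (placeholder here).
                {"type": "image_url", "image_url": {"url": "/tmp/mmmu/image_0.png"}},
                {"type": "text", "text": "What is shown in this image?"},
            ],
        }
    ],
    temperature=0,
    max_tokens=128,
)
print(response.choices[0].message.content)
```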
@@ -19,11 +19,11 @@ from data_utils import (
     process_single_sample,
 )
 from datasets import concatenate_datasets, load_dataset
+from tqdm import tqdm


 @dataclasses.dataclass
 class EvalArgs:
-    backend: str = "engine"
     seed: int = 42
     split: str = "validation"
     # Default setting to make the benchmark available on A100 for most 7B models
@@ -35,7 +35,6 @@ class EvalArgs:
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
-        parser.add_argument("--backend", type=str, default=EvalArgs.backend)
         parser.add_argument(
             "--result-filename", type=str, default=EvalArgs.result_filename
         )
@@ -108,7 +107,7 @@ def prepare_samples(eval_args: EvalArgs):
     # run for each subject
     sub_dataset_list = []
-    for subject in CAT_SHORT2LONG.values():
+    for subject in tqdm(CAT_SHORT2LONG.values()):
         sub_dataset = load_dataset(
             eval_args.dataset_path, subject, split=eval_args.split
         )
@@ -121,19 +120,31 @@
     ## prepare images
     samples = []
     skip_count = 0
-    for i, sample in enumerate(dataset):
+    # use image file as input to ensure the consistency between sglang and hf
+    images_path = os.path.expanduser("~/.cache/mmmu/images")
+    os.makedirs(images_path, exist_ok=True)
+    print(f"Saving images to: {images_path}")
+    for i, sample in enumerate(tqdm(dataset)):
         sample = process_single_sample(sample)
         sample = construct_prompt(sample, eval_args.config)
         image = sample["image"]
         width, height = image.size
         if width * height >= eval_args.image_pixels_limit:
             skip_count += 1
             continue
+        image_path = f"{images_path}/image_{i}.png"
+        if not os.path.exists(image_path):
+            image.save(image_path)
+        sample["image_path"] = image_path
         samples.append(sample)
     print(
         f"skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset"
     )
+    print("samples have been prepared")
     return samples
...
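The prepare_samples change above is what lets the HF and sglang benchmarks consume identical inputs: each image is written to a shared cache directory once, and both scripts reference it by path afterwards. A small sketch of that pattern, with a hypothetical helper name (`cache_image` is not in the commit; the cache directory follows the diff):

```python
# Sketch of the save-once, reference-by-path caching used in prepare_samples.
import os

from PIL import Image

images_path = os.path.expanduser("~/.cache/mmmu/images")
os.makedirs(images_path, exist_ok=True)


def cache_image(image: Image.Image, index: int) -> str:
    """Write the PIL image to the shared cache once; return its stable path."""
    image_path = f"{images_path}/image_{index}.png"
    if not os.path.exists(image_path):
        image.save(image_path)
    return image_path
```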