"""
    Bench the sglang-hosted vLM with benchmark MMMU

    Usage:
        python benchmark/mmmu/bench_sglang.py --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl

    The eval output will be logged
"""

import argparse
import dataclasses
import random
import re
from io import BytesIO

from data_utils import save_json
from eval_utils import (
    EvalArgs,
    eval_result,
    get_sampling_params,
    parse_multi_choice_response,
    prepare_samples,
)
from tqdm import tqdm

from sglang import Engine
from sglang.srt.conversation import chat_templates
from sglang.srt.server_args import ServerArgs


def eval_mmmu(args):
    server_args = ServerArgs.from_cli_args(args)
    eval_args = EvalArgs.from_cli_args(args)

    if server_args.chat_template is None:
        raise ValueError("Chat template must be provided for this benchmark")

    samples = prepare_samples(eval_args)

    # Launch the model in-process via sglang's offline Engine API.
    backend = Engine(**dataclasses.asdict(server_args))

    out_samples = dict()

    sampling_params = get_sampling_params(eval_args)

    # Look up the image placeholder token used by the selected chat template.
    conv = chat_templates[server_args.chat_template].copy()
    image_token = conv.image_token
    answer_dict = {}
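
    # Query the engine sample by sample and collect predictions keyed by id.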
    for sample in tqdm(samples):
        prompt = sample["final_input_prompt"]
        image = sample["image"]

        # Swap the dataset's "<image n>" placeholders for the chat template's
        # image token.
        prompt = re.sub(r"<[^>]*>", image_token, prompt)

        if image is not None:
            # Re-encode the PIL image as PNG bytes for the engine.
            bytes_io = BytesIO()
            image.save(bytes_io, format="PNG")
            png_bytes = bytes_io.getvalue()

            response = backend.generate(
                prompt=prompt, image_data=[png_bytes], sampling_params=sampling_params
            )["text"]
        else:
            # `image` is None for samples that carry multiple images, which
            # this script does not handle. Fall back to a random guess for
            # multiple-choice questions; otherwise emit an invalid marker.
            if sample["question_type"] == "multiple-choice":
                response = random.choice(sample["all_choices"])
            else:
                response = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS"

        if sample["question_type"] == "multiple-choice":
            pred_ans = parse_multi_choice_response(
                response, sample["all_choices"], sample["index2ans"]
            )
        else:  # open question
            pred_ans = response
        out_samples[sample["id"]] = pred_ans

        # Record the ground-truth answer for scoring.
        answer_dict[sample["id"]] = {
            "question_type": sample["question_type"],
            "ground_truth": (
                sample["correct_choice"]
                if "correct_choice" in sample
                else sample["answer"]
            ),
        }

    # Dump the predictions and score them against the ground-truth answers.
    args.output_path = f"{args.model_path}_val_sglang.json"
    save_json(args.output_path, out_samples)
    eval_result(output_path=args.output_path, answer_dict=answer_dict)
    backend.shutdown()  # Release the engine's resources before exiting.


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    EvalArgs.add_cli_args(parser)
    args = parser.parse_args()

    eval_mmmu(args)