"""
    Bench the sglang-hosted VLM with the MMMU benchmark.

    Usage:
        python benchmark/mmmu/bench_sglang.py --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl

    The eval output will be logged
"""

import argparse
import base64
import dataclasses
import random
from io import BytesIO

from data_utils import save_json
from eval_utils import (
    EvalArgs,
    eval_result,
    get_sampling_params,
    prepare_samples,
    process_result,
)
from tqdm import tqdm

from sglang import Engine
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.openai_api.protocol import ChatCompletionRequest
from sglang.srt.server_args import ServerArgs

def eval_mmmu(args):
    """Run the MMMU benchmark against an sglang ``Engine`` and log the result.

    Args:
        args: ``argparse.Namespace`` carrying both ``ServerArgs`` and
            ``EvalArgs`` CLI options (see ``__main__`` at the bottom of
            this file).

    Raises:
        ValueError: if ``--chat-template`` was not supplied; the benchmark
            needs it to render the multimodal prompt.

    Side effects:
        Writes the per-sample answers to ``{args.model_path}_val_sglang.json``
        (and stores that path on ``args.output_path``), then runs the
        MMMU eval on that file.
    """
    server_args = ServerArgs.from_cli_args(args)
    eval_args = EvalArgs.from_cli_args(args)

    if server_args.chat_template is None:
        raise ValueError("Chat template must be provided for this benchmark")

    backend = Engine(**dataclasses.asdict(server_args))

    out_samples = dict()

    sampling_params = get_sampling_params(eval_args)

    samples = prepare_samples(eval_args)
    answer_dict = {}

    for sample in tqdm(samples):
        prompt = sample["final_input_prompt"]
        image = sample["image"]

        if image is not None:
            # BUG FIX: the image encoding and request construction used to run
            # unconditionally, so a sample with image=None crashed on
            # image.save() before ever reaching the fallback branch below.
            # It now only runs when there is a single image to encode.
            buff = BytesIO()
            image.save(buff, format="PNG")
            base64_str = base64.b64encode(buff.getvalue()).decode("utf-8")

            # Split the prompt around the "<image 1>"-style placeholder:
            # text before the first "<" and text after the first ">".
            # NOTE(review): assumes exactly one placeholder per prompt —
            # confirm against prepare_samples.
            prefix = prompt.split("<")[0]
            suffix = prompt.split(">")[1]

            # OpenAI-style chat request; the bytes are PNG, so the data URL is
            # labeled image/png (it was previously mislabeled image/jpeg).
            request_dict = {
                "model": "",
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": prefix,
                            },
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/png;base64,{base64_str}"
                                },
                            },
                            {
                                "type": "text",
                                "text": suffix,
                            },
                        ],
                    }
                ],
            }

            # Render the request through the chosen chat template so the raw
            # prompt string matches what the model was trained on.
            conv = generate_chat_conv(
                ChatCompletionRequest(**request_dict),
                template_name=server_args.chat_template,
            )
            response = backend.generate(
                prompt=conv.get_prompt(),
                image_data=conv.image_data,
                sampling_params=sampling_params,
            )["text"]
        else:  # multiple images actually
            # The engine path above handles a single image only; fall back to
            # a random choice (multiple-choice) or an explicit marker string.
            if sample["question_type"] == "multiple-choice":
                all_choices = sample["all_choices"]
                response = random.choice(all_choices)
            else:
                response = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS"

        process_result(response, sample, answer_dict, out_samples)

    args.output_path = f"{args.model_path}_val_sglang.json"
    save_json(args.output_path, out_samples)

    eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)

    backend.shutdown()


if __name__ == "__main__":
    # Collect both server and eval CLI options on a single parser, then run.
    cli_parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(cli_parser)
    EvalArgs.add_cli_args(cli_parser)

    eval_mmmu(cli_parser.parse_args())