"""
    Bench the sglang-hosted vLM with benchmark MMMU

    Usage:
        python benchmark/mmmu/bench_sglang.py --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl

    The eval output will be logged
"""

import argparse

import openai
from data_utils import save_json
from eval_utils import (
    EvalArgs,
    eval_result,
    get_sampling_params,
    prepare_samples,
    process_result,
)
from tqdm import tqdm

from sglang.test.test_utils import add_common_sglang_args_and_parse


def eval_mmmu(args):
    eval_args = EvalArgs.from_cli_args(args)

    out_samples = dict()

    sampling_params = get_sampling_params(eval_args)

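    # Load the MMMU samples to evaluate (the validation split, judging by the
    # output file name below); each sample carries the rendered prompt and image.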
    samples = prepare_samples(eval_args)

    answer_dict = {}

    # Use the OpenAI-compatible endpoint of the sglang server, since SglImage
    # does not support raw image data. The API key is a placeholder; a local
    # server launched without an API key does not check it.
    client = openai.Client(api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1")

    for i, sample in enumerate(tqdm(samples)):
        prompt = sample["final_input_prompt"]
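        # The prompt embeds the image as a placeholder token (e.g. "<image 1>");
        # split the text into the part before "<" and the part after ">".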
        prefix = prompt.split("<")[0]
        suffix = prompt.split(">")[1]
        image = sample["image"]
        assert image is not None
        image_path = sample["image_path"]
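        # image_path is produced by prepare_samples and is passed as-is in the
        # image_url field of the request below.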
        # TODO: batch
        response = client.chat.completions.create(
            model="default",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": prefix,
                        },
                        {
                            "type": "image_url",
                            "image_url": {"url": image_path},
                        },
                        {
                            "type": "text",
                            "text": suffix,
                        },
                    ],
                }
            ],
            temperature=0,
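            # Both the newer max_completion_tokens field and the legacy max_tokens
            # field are set, presumably so the cap applies regardless of API version.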
            max_completion_tokens=sampling_params["max_new_tokens"],
            max_tokens=sampling_params["max_new_tokens"],
        )
        response = response.choices[0].message.content
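        # process_result parses the raw answer and fills answer_dict / out_samples,
        # which eval_result and save_json consume below.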
        process_result(response, sample, answer_dict, out_samples)

    args.output_path = "./val_sglang.json"
    save_json(args.output_path, out_samples)
    eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Register the eval-specific flags before parsing; add_common_sglang_args_and_parse
    # adds the common sglang args and then parses the command line.
    EvalArgs.add_cli_args(parser)
    args = add_common_sglang_args_and_parse(parser)

    eval_mmmu(args)