"""
Bench the sglang-hosted vLM with benchmark MMMU

Usage:
    python benchmark/mmmu/bench_sglang.py --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl

The eval output will be logged
"""

import argparse
import time

import openai
from data_utils import save_json
from eval_utils import (
    EvalArgs,
    eval_result,
    get_sampling_params,
    prepare_samples,
    process_result,
)
from tqdm import tqdm

from sglang.test.test_utils import add_common_sglang_args_and_parse


def eval_mmmu(args):
    """Run the MMMU benchmark against an sglang server via its OpenAI-compatible API.

    Args:
        args: Parsed CLI namespace carrying the common sglang args (notably
            ``port``) plus the ``EvalArgs`` fields.

    Side effects:
        Queries the local server at ``http://127.0.0.1:{args.port}/v1`` once
        per sample, writes all answers to ``./val_sglang.json`` (path is also
        stored back on ``args.output_path``), prints the total wall time, and
        logs the eval result.
    """
    eval_args = EvalArgs.from_cli_args(args)

    out_samples = dict()
    sampling_params = get_sampling_params(eval_args)
    samples = prepare_samples(eval_args)
    answer_dict = {}

    # Had to use an openai server, since SglImage doesn't support image data.
    # The api_key is a dummy value; the local sglang server does not check it.
    client = openai.Client(api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1")

    start = time.time()
    for sample in tqdm(samples):
        prompt = sample["final_input_prompt"]
        # The prompt embeds a single image tag (e.g. "<image 1>"): send the
        # text before "<", the image, then the text after the first ">".
        prefix = prompt.split("<")[0]
        suffix = prompt.split(">")[1]
        assert sample["image"] is not None
        image_path = sample["image_path"]
        # TODO: batch
        response = client.chat.completions.create(
            model="default",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prefix},
                        {"type": "image_url", "image_url": {"url": image_path}},
                        {"type": "text", "text": suffix},
                    ],
                }
            ],
            temperature=0,
            # Send exactly one token-limit parameter: the OpenAI API rejects
            # requests that set both max_tokens and max_completion_tokens.
            max_completion_tokens=sampling_params["max_new_tokens"],
        )
        response = response.choices[0].message.content
        process_result(response, sample, answer_dict, out_samples)

    print(f"Benchmark time: {time.time() - start}")

    args.output_path = "./val_sglang.json"
    save_json(args.output_path, out_samples)
    eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Register the eval-specific flags BEFORE the common-args helper runs
    # parse_args(); otherwise any EvalArgs flag on the command line would be
    # rejected as an unrecognized argument during the first parse.
    EvalArgs.add_cli_args(parser)
    args = add_common_sglang_args_and_parse(parser)

    eval_mmmu(args)