# bench_sglang.py
1
"""
2
Bench the sglang-hosted vLM with benchmark MMMU
3

4
Usage:
5
6
7
    Host the VLM: python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl --port 30000

    Benchmark: python benchmark/mmmu/bench_sglang.py --port 30000
8

9
The eval output will be logged
10
11
12
"""

import argparse
import time

import openai
from tqdm import tqdm

from sglang.test.test_utils import add_common_sglang_args_and_parse

from data_utils import save_json
from eval_utils import (
    EvalArgs,
    eval_result,
    get_sampling_params,
    prepare_samples,
    process_result,
)


def eval_mmmu(args):
    """Run the MMMU benchmark against an sglang server and log eval results.

    Sends each prepared MMMU sample to the OpenAI-compatible endpoint hosted at
    ``http://127.0.0.1:{args.port}/v1``, collects model answers, writes them to
    ``./val_sglang.json``, and evaluates them.

    Args:
        args: Parsed CLI namespace. Must carry the fields consumed by
            ``EvalArgs.from_cli_args`` plus ``port``; ``output_path`` is set
            on it as a side effect.
    """
    eval_args = EvalArgs.from_cli_args(args)

    out_samples = dict()

    sampling_params = get_sampling_params(eval_args)

    samples = prepare_samples(eval_args)

    answer_dict = {}

    # had to use an openai server, since SglImage doesn't support image data
    client = openai.Client(api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1")

    start = time.time()
    for sample in tqdm(samples):
        prompt = sample["final_input_prompt"]
        # Split the prompt around the single inline image placeholder:
        # text before "<" is the prefix, text after the first ">" is the suffix.
        # NOTE(review): assumes exactly one "<image ...>" tag per prompt — confirm
        # against prepare_samples.
        prefix = prompt.split("<")[0]
        suffix = prompt.split(">")[1]
        image = sample["image"]
        assert image is not None
        image_path = sample["image_path"]
        # TODO: batch
        response = client.chat.completions.create(
            model="default",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": prefix,
                        },
                        {
                            "type": "image_url",
                            "image_url": {"url": image_path},
                        },
                        {
                            "type": "text",
                            "text": suffix,
                        },
                    ],
                }
            ],
            temperature=0,
            # Both token-limit spellings are passed for compatibility across
            # server/client versions.
            max_completion_tokens=sampling_params["max_new_tokens"],
            max_tokens=sampling_params["max_new_tokens"],
        )
        response = response.choices[0].message.content
        process_result(response, sample, answer_dict, out_samples)

    print(f"Benchmark time: {time.time() - start}")

    args.output_path = "./val_sglang.json"
    save_json(args.output_path, out_samples)
    eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    EvalArgs.add_cli_args(parser)
90
    args = add_common_sglang_args_and_parse(parser)
91
92
    args = parser.parse_args()
    eval_mmmu(args)