"""
Benchmark the sglang-hosted VLM with the MMMU benchmark.

Usage:
    Host the VLM: python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl --port 30000

    Benchmark: python benchmark/mmmu/bench_sglang.py --port 30000

The eval output will be logged to ./val_sglang.json.
"""

import argparse
import asyncio
import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import List

import aiohttp
import openai
from data_utils import save_json
from eval_utils import (
    EvalArgs,
    eval_result,
    get_sampling_params,
    prepare_samples,
    process_result,
)
from tqdm import tqdm

from sglang.test.test_utils import add_common_sglang_args_and_parse

# Generous 20-hour total timeout so long benchmark runs are not cut off.
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)


@dataclass
class RequestFuncOutput:
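    """Result of a single HTTP request; this script only fills `success` and `error`."""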
    generated_text: List[str] = field(default_factory=list)
    prompt_len: List[int] = field(default_factory=list)
    output_len: List[int] = field(default_factory=list)
    latency: List[float] = field(default_factory=list)
    ttft: List[float] = field(default_factory=list)
    itl: List[float] = field(default_factory=list)  # List of inter-token latencies

    success: bool = False
    error: str = ""


async def async_request_profile(api_url: str) -> RequestFuncOutput:
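    """POST to a profiler control endpoint (/start_profile or /stop_profile) and report success."""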
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        output = RequestFuncOutput()
        try:
            async with session.post(url=api_url) as response:
                if response.status == 200:
                    output.success = True
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    return output


async def eval_mmmu(args):
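    """Run the MMMU samples through the OpenAI-compatible chat endpoint and score the answers."""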
    eval_args = EvalArgs.from_cli_args(args)

    out_samples = dict()

    sampling_params = get_sampling_params(eval_args)

    samples = prepare_samples(eval_args)

    answer_dict = {}

    # We had to use an OpenAI-compatible server, since SglImage doesn't support image data.
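    # "sk" is a placeholder API key; a locally launched sglang server does not check it by default.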
    base_url = f"http://127.0.0.1:{args.port}"
    client = openai.Client(api_key="sk", base_url=f"{base_url}/v1")

    start = time.time()

    if args.profile:
        print("Starting profiler...")
        profile_output = await async_request_profile(
            api_url=f"{base_url}/start_profile"
        )
        if profile_output.success:
            print("Profiler started")

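    # When profiling, only run the first `args.profile_number` samples.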
    if args.profile:
        samples = samples[: args.profile_number]

    for sample in tqdm(samples):
        prompt = sample["final_input_prompt"]
        # Split the prompt around the "<image 1>" placeholder so the image can be
        # inserted between the two text pieces.
        prefix = prompt.split("<")[0]
        suffix = prompt.split(">")[1]
        image = sample["image"]
        assert image is not None
        image_path = sample["image_path"]
        # TODO: batch
        response = client.chat.completions.create(
            model="default",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": prefix,
                        },
                        {
                            "type": "image_url",
                            "image_url": {"url": image_path},
                        },
                        {
                            "type": "text",
                            "text": suffix,
                        },
                    ],
                }
            ],
            temperature=0,
            max_tokens=sampling_params["max_new_tokens"],
        )
        response = response.choices[0].message.content
        process_result(response, sample, answer_dict, out_samples)

    if args.profile:
        print("Stopping profiler...")
        profile_output = await async_request_profile(api_url=f"{base_url}/stop_profile")
        if profile_output.success:
            print("Profiler stopped")

    print(f"Benchmark time: {time.time() - start}")

    args.output_path = "./val_sglang.json"
    save_json(args.output_path, out_samples)
    eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    EvalArgs.add_cli_args(parser)
    args = add_common_sglang_args_and_parse(parser)
    asyncio.run(eval_mmmu(args))