"""
Benchmark an sglang-hosted VLM on the MMMU benchmark.

Usage:
    Host the VLM: python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --port 30000

    Benchmark: python benchmark/mmmu/bench_sglang.py --port 30000 --concurrency 16

The eval output will be logged.
"""

import argparse
import asyncio
import re
import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import Any, List, Optional, Tuple

import aiohttp
import openai
from data_utils import save_json
from eval_utils import (
    EvalArgs,
    eval_result,
    get_sampling_params,
    prepare_samples,
    process_result,
)
from tqdm import tqdm

from sglang.test.test_utils import add_common_sglang_args_and_parse

AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)

@dataclass
class RequestFuncOutput:
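    """Outcome of a single auxiliary HTTP request.

    Only ``success`` and ``error`` are populated by ``async_request_profile``
    below; the metric fields are unused in this script.
    """
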
    generated_text: List[str] = field(default_factory=list)
    prompt_len: List[int] = field(default_factory=list)
    output_len: List[int] = field(default_factory=list)
    latency: List[float] = field(default_factory=list)
    ttft: List[float] = field(default_factory=list)
    itl: List[float] = field(default_factory=list)  # List of inter-token latencies

    success: bool = False
    error: str = ""


async def async_request_profile(api_url: str) -> RequestFuncOutput:
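    """POST to ``api_url`` and report success/failure.

    Used here to start/stop the server-side profiler via
    ``{base_url}/start_profile`` and ``{base_url}/stop_profile``.
    """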
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        output = RequestFuncOutput()
        try:
            async with session.post(url=api_url) as response:
                if response.status == 200:
                    output.success = True
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    return output


def _get_prefix_suffix(prompt: str) -> Tuple[str, str]:
    """Split the prompt around its image placeholder.

    Returns the text before the first "<" and the text after the first ">",
    i.e. the text surrounding a token such as "<image 1>" in an MMMU prompt.
    """
    prefix = prompt.split("<")[0]
    suffix = prompt.split(">", 1)[1]
    return prefix, suffix


async def process_sample(
    client: Any, sample: dict, sampling_params: dict, lora_path: Optional[str] = None
) -> Tuple[dict, str]:
    """Send a single sample to the VLM and return (sample, response)."""
    prompt = sample["final_input_prompt"]
    prefix, suffix = _get_prefix_suffix(prompt)
    image = sample["image"]
    assert image is not None
    image_path = sample["image_path"]
    extra_body = None if lora_path is None else {"lora_path": lora_path}
    # Rebuild the prompt as an OpenAI-style multimodal message: the text before
    # the image, the image itself, then the text after it.
    response = await client.chat.completions.create(
        model="default",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prefix},
                    {"type": "image_url", "image_url": {"url": image_path}},
                    {"type": "text", "text": suffix},
                ],
            }
        ],
        temperature=0,
        max_completion_tokens=sampling_params["max_new_tokens"],
        max_tokens=sampling_params["max_new_tokens"],
        extra_body=extra_body,
    )
    return sample, response.choices[0].message.content


async def process_sample_with_semaphore(
    semaphore: asyncio.Semaphore,
    client: Any,
    sample: dict,
    sampling_params: dict,
    lora_path: Optional[str] = None,
) -> Tuple[dict, str]:
    """Wrap process_sample with a semaphore for concurrency control."""
    async with semaphore:
        return await process_sample(client, sample, sampling_params, lora_path)


async def eval_mmmu(args) -> None:
    """Main evaluation loop with concurrency control."""
    eval_args = EvalArgs.from_cli_args(args)
    sampling_params = get_sampling_params(eval_args)
    samples = prepare_samples(eval_args)
    lora_path = eval_args.lora_path
    answer_dict = {}
    out_samples = {}
    # "sk" is just a placeholder API key for the OpenAI-compatible endpoint of
    # the locally launched server.
    client = openai.AsyncOpenAI(
        api_key="sk",
        base_url=f"http://127.0.0.1:{args.port}/v1",
        timeout=20 * 60 * 60,
    )
    start = time.perf_counter()
    base_url = f"http://127.0.0.1:{args.port}"

    if args.profile:
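        # Start the server-side profiler and only benchmark the first
        # `args.profile_number` samples.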
        print("Starting profiler...")
        profile_output = await async_request_profile(
            api_url=f"{base_url}/start_profile"
        )
        if profile_output.success:
            print("Profiler started")

        samples = samples[: args.profile_number]

    if args.concurrency == 1:
        # For concurrency == 1, run sequentially to ensure a consistent order;
        # this is mainly useful for profiling.
        for sample in tqdm(samples):
            _, response = await process_sample(
                client, sample, sampling_params, lora_path
            )
            sample["original_response"] = response
            # Extract the final answer with the user-supplied regex; if it does
            # not match, fall back to the raw response.
            answer = (
                re.search(args.response_answer_regex, response)
                if response is not None
                else None
            )
            process_result(
                answer.group(1).strip() if answer else response,
                sample,
                answer_dict,
                out_samples,
            )
    else:
        semaphore = asyncio.Semaphore(args.concurrency)
        tasks = [
            process_sample_with_semaphore(
                semaphore, client, sample, sampling_params, lora_path
            )
            for sample in samples
        ]

        # Consume results as they complete; the semaphore caps the number of
        # in-flight requests at args.concurrency.
        for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
            sample, response = await coro
            sample["original_response"] = response
            answer = (
                re.search(args.response_answer_regex, response)
                if response is not None
                else None
            )
            process_result(
                answer.group(1).strip() if answer else response,
                sample,
                answer_dict,
                out_samples,
            )

    if args.profile:
        print("Stopping profiler...")
        profile_output = await async_request_profile(api_url=f"{base_url}/stop_profile")
        if profile_output.success:
            print("Profiler stopped")

    print(f"Benchmark time: {time.perf_counter() - start}")
    args.output_path = "./answer_sglang.json"
    save_json(args.output_path, out_samples)
    eval_result(
        model_answer_path=args.output_path,
        answer_dict=answer_dict,
        eval_output_path="./val_sglang.json",
    )


def parse_args():
    parser = argparse.ArgumentParser()
    EvalArgs.add_cli_args(parser)
    args = add_common_sglang_args_and_parse(parser)
    return args


def main():
    args = parse_args()
    asyncio.run(eval_mmmu(args))


if __name__ == "__main__":
    main()