"vscode:/vscode.git/clone" did not exist on "0f095f79ef18998577e9c510f95725f4cfce039d"
bench_sglang.py 6.38 KB
Newer Older
1
"""
Benchmark the sglang-hosted VLM on the MMMU benchmark.

Usage:
    Host the VLM: python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --port 30000

    Benchmark: python benchmark/mmmu/bench_sglang.py --port 30000 --concurrency 16

Answers are saved to ./answer_sglang.json and the evaluation result to ./val_sglang.json.
"""

import argparse
import asyncio
import re
import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import Any, List, Optional, Tuple

import aiohttp
import openai
from data_utils import save_json
from eval_utils import (
    EvalArgs,
    eval_result,
    get_sampling_params,
    prepare_samples,
    process_result,
)
from tqdm import tqdm

from sglang.test.test_utils import add_common_sglang_args_and_parse

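# Generous 20-hour client timeout so very long benchmark runs are not cut off.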
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)


@dataclass
class RequestFuncOutput:
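    # Mirrors the per-request output struct used by sglang's serving benchmarks;
    # only `success` and `error` are used by the profiler helper below.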
    generated_text: List[str] = field(default_factory=list)
    prompt_len: List[int] = field(default_factory=list)
    output_len: List[int] = field(default_factory=list)
    latency: List[float] = field(default_factory=list)
    ttft: List[float] = field(default_factory=list)
    itl: List[float] = field(default_factory=list)  # List of inter-token latencies

    success: bool = False
    error: str = ""


async def async_request_profile(api_url: str) -> RequestFuncOutput:
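    """POST to a profiler endpoint (/start_profile or /stop_profile) and report success."""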
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        output = RequestFuncOutput()
        try:
            async with session.post(url=api_url) as response:
                if response.status == 200:
                    output.success = True
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    return output


def _get_prefix_suffix(prompt: str) -> Tuple[str, str]:
    """Split the prompt into prefix and suffix."""
    prefix = prompt.split("<")[0]
    suffix = prompt.split(">", 1)[1]
    return prefix, suffix


async def process_sample(
    client: Any, sample: dict, sampling_params: dict, lora_path: Optional[str] = None
) -> Tuple[dict, str]:
    """Send a single sample to the LLM and return (sample, response)."""
    prompt = sample["final_input_prompt"]
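    # MMMU prompts embed the image as an "<image N>" placeholder; split the text
    # around it so it can be sent as content parts on either side of the image.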
    prefix, suffix = _get_prefix_suffix(prompt)
    image = sample["image"]
    assert image is not None
    image_path = sample["image_path"]
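    # Per-request LoRA adapter selection is passed through the OpenAI client's extra_body.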
    extra_body = None if lora_path is None else {"lora_path": lora_path}
    response = await client.chat.completions.create(
        model="default",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prefix},
                    {"type": "image_url", "image_url": {"url": image_path}},
                    {"type": "text", "text": suffix},
                ],
            }
        ],
        temperature=0,
        # max_completion_tokens is the current OpenAI-style field; max_tokens is
        # also set since servers differ in which of the two they honor.
        max_completion_tokens=sampling_params["max_new_tokens"],
        max_tokens=sampling_params["max_new_tokens"],
        extra_body=extra_body,
    )
    return sample, response.choices[0].message.content


async def process_sample_with_semaphore(
    semaphore: asyncio.Semaphore,
    client: Any,
    sample: dict,
    sampling_params: dict,
    lora_path: Optional[str] = None,
) -> Tuple[dict, str]:
    """Wrap process_sample with a semaphore for concurrency control."""
    async with semaphore:
        return await process_sample(client, sample, sampling_params, lora_path)


async def eval_mmmu(args) -> None:
    """Main evaluation loop with concurrency control."""
    eval_args = EvalArgs.from_cli_args(args)
    sampling_params = get_sampling_params(eval_args)
    samples = prepare_samples(eval_args)
    lora_path = eval_args.lora_path
    answer_dict = {}
    out_samples = {}
    # Any non-empty API key works; a locally launched sglang server does not
    # validate it.
    client = openai.AsyncOpenAI(
        api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1"
    )
    start = time.perf_counter()
    base_url = f"http://127.0.0.1:{args.port}"

    if args.profile:
        print("Starting profiler...")
        profile_output = await async_request_profile(
            api_url=f"{base_url}/start_profile"
        )
        if profile_output.success:
            print("Profiler started")

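        # When profiling, only the first args.profile_number samples are run.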
        samples = samples[: args.profile_number]

    if args.concurrency == 1:
        # For concurrency == 1, run in sequential mode to ensure consistent order
        # this is mainly for profiling
        for sample in tqdm(samples):
            _, response = await process_sample(
                client, sample, sampling_params, lora_path
            )
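            # Pull the final answer out of the response with the user-supplied
            # args.response_answer_regex; fall back to the raw response text.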
            answer = (
                re.search(args.response_answer_regex, response)
                if response is not None
                else None
            )
            process_result(
                answer.group(1) if answer else response,
                sample,
                answer_dict,
                out_samples,
            )
    else:
        semaphore = asyncio.Semaphore(args.concurrency)
        tasks = [
            process_sample_with_semaphore(
                semaphore, client, sample, sampling_params, lora_path
            )
            for sample in samples
        ]

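        # as_completed yields results in completion order; each result carries
        # its own sample, so ordering does not matter here.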
        for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
            sample, response = await coro
            answer = (
                re.search(args.response_answer_regex, response)
                if response is not None
                else None
            )
            process_result(
                answer.group(1) if answer else response,
                sample,
                answer_dict,
                out_samples,
            )

    if args.profile:
        print("Stopping profiler...")
        profile_output = await async_request_profile(api_url=f"{base_url}/stop_profile")
        if profile_output.success:
            print("Profiler stopped")

    print(f"Benchmark time: {time.perf_counter() - start:.2f} s")
    args.output_path = "./answer_sglang.json"
    save_json(args.output_path, out_samples)
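    # Score the saved answers and write the evaluation report.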
    eval_result(
        model_answer_path=args.output_path,
        answer_dict=answer_dict,
        eval_output_path="./val_sglang.json",
    )


def parse_args():
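    # Combine the MMMU eval flags with the common sglang benchmark flags.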
    parser = argparse.ArgumentParser()
    EvalArgs.add_cli_args(parser)
    args = add_common_sglang_args_and_parse(parser)
    return args


def main():
    args = parse_args()
    asyncio.run(eval_mmmu(args))


if __name__ == "__main__":
    main()