"""
Bench the sglang-hosted VLM with the MMMU benchmark.

Usage:
    Host the VLM: python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --port 30000

    Benchmark: python benchmark/mmmu/bench_sglang.py --port 30000 --concurrency 16

The eval output will be logged.
"""

import argparse
import asyncio
import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import Any, List, Optional, Tuple

import aiohttp
import openai
from data_utils import save_json
from eval_utils import (
    EvalArgs,
    eval_result,
    get_sampling_params,
    prepare_samples,
    process_result,
)
from tqdm import tqdm

from sglang.test.test_utils import add_common_sglang_args_and_parse

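# Generous 20-hour client-side timeout so long benchmark runs are not cut off.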
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)


@dataclass
class RequestFuncOutput:
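    """Per-request result container; this script only populates ``success`` and ``error``."""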
    generated_text: List[str] = field(default_factory=list)
    prompt_len: List[int] = field(default_factory=list)
    output_len: List[int] = field(default_factory=list)
    latency: List[float] = field(default_factory=list)
    ttft: List[float] = field(default_factory=list)
    itl: List[float] = field(default_factory=list)  # List of inter-token latencies

    success: bool = False
    error: str = ""


async def async_request_profile(api_url: str) -> RequestFuncOutput:
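    """POST to a profiler control endpoint and record success or the failure reason."""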
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        output = RequestFuncOutput()
        try:
            async with session.post(url=api_url) as response:
                if response.status == 200:
                    output.success = True
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    return output


def _get_prefix_suffix(prompt: str) -> Tuple[str, str]:
    """Split the prompt into prefix and suffix."""
    prefix = prompt.split("<")[0]
    suffix = prompt.split(">", 1)[1]
    return prefix, suffix


async def process_sample(
    client: Any, sample: dict, sampling_params: dict, lora_path: Optional[str] = None
) -> Tuple[dict, str]:
    """Send a single sample to the LLM and return (sample, response)."""
    prompt = sample["final_input_prompt"]
    prefix, suffix = _get_prefix_suffix(prompt)
    image = sample["image"]
    assert image is not None
    image_path = sample["image_path"]
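    # The local image path is passed via image_url below; the sglang server is
    # expected to load the image from that path itself.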
    extra_body = None if lora_path is None else {"lora_path": lora_path}
    response = await client.chat.completions.create(
        model="default",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prefix},
                    {"type": "image_url", "image_url": {"url": image_path}},
                    {"type": "text", "text": suffix},
                ],
            }
        ],
        temperature=0,
        # Both the newer max_completion_tokens and the deprecated max_tokens are
        # set, since servers differ in which parameter they honor.
        max_completion_tokens=sampling_params["max_new_tokens"],
        max_tokens=sampling_params["max_new_tokens"],
        extra_body=extra_body,
    )
    return sample, response.choices[0].message.content


async def process_sample_with_semaphore(
    semaphore: asyncio.Semaphore,
    client: Any,
    sample: dict,
    sampling_params: dict,
    lora_path: Optional[str] = None,
) -> Tuple[dict, str]:
    """Wrap process_sample with a semaphore for concurrency control."""
    async with semaphore:
        return await process_sample(client, sample, sampling_params, lora_path)


async def eval_mmmu(args) -> None:
    """Main evaluation loop with concurrency control."""
    eval_args = EvalArgs.from_cli_args(args)
    sampling_params = get_sampling_params(eval_args)
    samples = prepare_samples(eval_args)
    lora_path = eval_args.lora_path
    answer_dict = {}
    out_samples = {}
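    # Any placeholder API key works here; a locally launched sglang server does
    # not enforce authentication unless --api-key was passed to launch_server.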
    client = openai.AsyncOpenAI(
        api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1"
    )
    start = time.perf_counter()
    base_url = f"http://127.0.0.1:{args.port}"

    if args.profile:
        print("Starting profiler...")
        profile_output = await async_request_profile(
            api_url=f"{base_url}/start_profile"
        )
        if profile_output.success:
            print("Profiler started")

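        # Profile only the first `profile_number` samples to keep the profiled run short.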
        samples = samples[: args.profile_number]

    if args.concurrency == 1:
        # With concurrency == 1, run sequentially to keep a deterministic order;
        # this is mainly useful for profiling.
        for sample in tqdm(samples):
            _, response = await process_sample(
                client, sample, sampling_params, lora_path
            )
            process_result(response, sample, answer_dict, out_samples)
    else:
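        # Cap in-flight requests with a semaphore and consume results as they
        # complete; tqdm therefore reports progress in completion order.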
        semaphore = asyncio.Semaphore(args.concurrency)
        tasks = [
            process_sample_with_semaphore(
                semaphore, client, sample, sampling_params, lora_path
            )
            for sample in samples
        ]

        for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
            sample, response = await coro
            process_result(response, sample, answer_dict, out_samples)

    if args.profile:
        print("Stopping profiler...")
        profile_output = await async_request_profile(api_url=f"{base_url}/stop_profile")
        if profile_output.success:
            print("Profiler stopped")

    print(f"Benchmark time: {time.perf_counter() - start}")
    args.output_path = "./val_sglang.json"
    save_json(args.output_path, out_samples)
    eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)


def parse_args():
    """Combine the MMMU eval CLI args with the common sglang benchmark args."""
    parser = argparse.ArgumentParser()
    EvalArgs.add_cli_args(parser)
    args = add_common_sglang_args_and_parse(parser)
    return args


def main():
    args = parse_args()
    asyncio.run(eval_mmmu(args))


if __name__ == "__main__":
    main()