"""
Benchmark an sglang-hosted VLM on the MMMU benchmark.

Usage:
    Host the VLM: python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --port 30000

    Benchmark: python benchmark/mmmu/bench_sglang.py --port 30000 --concurrency 16

The evaluation output is saved to ./val_sglang.json and the result summary is logged.
"""

import argparse
import asyncio
import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import Any, List, Optional, Tuple

import aiohttp
import openai
from data_utils import save_json
from eval_utils import (
    EvalArgs,
    eval_result,
    get_sampling_params,
    prepare_samples,
    process_result,
)
from tqdm import tqdm

from sglang.test.test_utils import add_common_sglang_args_and_parse

AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)

@dataclass
class RequestFuncOutput:
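    """Outcome of a profiler start/stop request; only success/error are filled in here."""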
    generated_text: List[str] = field(default_factory=list)
    prompt_len: List[int] = field(default_factory=list)
    output_len: List[int] = field(default_factory=list)
    latency: List[float] = field(default_factory=list)
    ttft: List[float] = field(default_factory=list)
    itl: List[float] = field(default_factory=list)  # List of inter-token latencies

    success: bool = False
    error: str = ""


async def async_request_profile(api_url: str) -> RequestFuncOutput:
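    """POST to a profiling control endpoint (e.g. /start_profile) and report success."""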
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        output = RequestFuncOutput()
        try:
            async with session.post(url=api_url) as response:
                if response.status == 200:
                    output.success = True
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    return output


def _get_prefix_suffix(prompt: str) -> Tuple[str, str]:
    """Split the prompt into prefix and suffix."""
    prefix = prompt.split("<")[0]
    suffix = prompt.split(">", 1)[1]
    return prefix, suffix


async def process_sample(
    client: Any, sample: dict, sampling_params: dict, lora_path: Optional[str] = None
) -> Tuple[dict, str]:
    """Send a single sample to the LLM and return (sample, response)."""
    prompt = sample["final_input_prompt"]
    prefix, suffix = _get_prefix_suffix(prompt)
    image = sample["image"]
    assert image is not None
    image_path = sample["image_path"]
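    # image_path is forwarded as the image_url below; prepare_samples is assumed to
    # provide a path or URL that the sglang server can load.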
    # Forward the optional per-request LoRA adapter path to the server via extra_body.
    extra_body = None if lora_path is None else {"lora_path": lora_path}
    response = await client.chat.completions.create(
        model="default",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prefix},
                    {"type": "image_url", "image_url": {"url": image_path}},
                    {"type": "text", "text": suffix},
                ],
            }
        ],
        temperature=0,
        # Both the legacy max_tokens and the newer max_completion_tokens are set to the
        # same cap so the limit applies whichever field the server honors.
        max_completion_tokens=sampling_params["max_new_tokens"],
        max_tokens=sampling_params["max_new_tokens"],
        extra_body=extra_body,
    )
    return sample, response.choices[0].message.content


async def process_sample_with_semaphore(
    semaphore: asyncio.Semaphore,
    client: Any,
    sample: dict,
    sampling_params: dict,
    lora_path: Optional[str] = None,
) -> Tuple[dict, str]:
    """Wrap process_sample with a semaphore for concurrency control."""
    async with semaphore:
        return await process_sample(client, sample, sampling_params, lora_path)


async def eval_mmmu(args) -> None:
    """Main evaluation loop with concurrency control."""
    eval_args = EvalArgs.from_cli_args(args)
    sampling_params = get_sampling_params(eval_args)
    samples = prepare_samples(eval_args)
    lora_path = eval_args.lora_path
    answer_dict = {}
    out_samples = {}
    # The sglang server exposes an OpenAI-compatible API; the api_key is a dummy value.
    client = openai.AsyncOpenAI(
        api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1"
    )
    # Cap the number of in-flight requests at args.concurrency.
    semaphore = asyncio.Semaphore(args.concurrency)
    start = time.perf_counter()
    base_url = f"http://127.0.0.1:{args.port}"

    if args.profile:
        print("Starting profiler...")
        profile_output = await async_request_profile(
            api_url=f"{base_url}/start_profile"
        )
        if profile_output.success:
            print("Profiler started")

        samples = samples[: args.profile_number]

    # Fan out every sample as its own task; the semaphore bounds concurrency.
    tasks = [
        process_sample_with_semaphore(
            semaphore, client, sample, sampling_params, lora_path
        )
        for sample in samples
    ]

    for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
        sample, response = await coro
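        # Results arrive in completion order; process_result parses the model answer
        # into answer_dict/out_samples.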
        process_result(response, sample, answer_dict, out_samples)

    if args.profile:
        print("Stopping profiler...")
        profile_output = await async_request_profile(api_url=f"{base_url}/stop_profile")
        if profile_output.success:
            print("Profiler stopped")

    print(f"Benchmark time: {time.perf_counter() - start:.2f} s")
    args.output_path = "./val_sglang.json"
    save_json(args.output_path, out_samples)
    eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)


def parse_args():
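    # Combine the MMMU eval flags (EvalArgs) with the common sglang benchmark CLI flags
    # and parse them together.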
    parser = argparse.ArgumentParser()
    EvalArgs.add_cli_args(parser)
    args = add_common_sglang_args_and_parse(parser)
    return args


def main():
    args = parse_args()
    asyncio.run(eval_mmmu(args))


if __name__ == "__main__":
    main()