"...training_service/common/clusterJobRestServer.ts" did not exist on "b749266d3a75a410e83b409ebe99a027a70f2045"
bench_sglang.py 5.07 KB
Newer Older
1
"""
2
Bench the sglang-hosted vLM with benchmark MMMU
3

4
Usage:
5
    Host the VLM: python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --port 30000
6

7
    Benchmark: python benchmark/mmmu/bench_sglang.py --port 30000 --concurrency 16
8

9
The eval output will be logged
10
11
12
"""

import argparse
import asyncio
import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import Any, List, Tuple

import aiohttp
import openai
from data_utils import save_json
from eval_utils import (
    EvalArgs,
    eval_result,
    get_sampling_params,
    prepare_samples,
    process_result,
)
from tqdm import tqdm

from sglang.test.test_utils import add_common_sglang_args_and_parse

# Generous overall timeout (20 h): benchmark runs can be very long.
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)

@dataclass
class RequestFuncOutput:
    """Mutable accumulator for the outcome and metrics of one benchmark request."""

    # Per-token / per-request metric lists. Only `success` and `error` are
    # populated by the code visible in this file (async_request_profile);
    # the metric fields presumably mirror sglang's serving-benchmark output
    # struct — confirm against the shared bench utilities.
    generated_text: List[str] = field(default_factory=list)
    prompt_len: List[int] = field(default_factory=list)
    output_len: List[int] = field(default_factory=list)
    latency: List[float] = field(default_factory=list)
    ttft: List[float] = field(default_factory=list)  # presumably time-to-first-token
    itl: List[float] = field(default_factory=list)  # List of inter-token latencies

    # True iff the HTTP call returned status 200.
    success: bool = False
    # HTTP reason or formatted traceback when success is False.
    error: str = ""


async def async_request_profile(api_url: str) -> RequestFuncOutput:
    """POST to *api_url* (a profiler control endpoint) and report the outcome.

    HTTP errors and raised exceptions are captured in the returned
    RequestFuncOutput (`success` flag plus `error` text) instead of
    propagating to the caller.
    """
    result = RequestFuncOutput()
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        try:
            async with session.post(url=api_url) as response:
                if response.status != 200:
                    result.success = False
                    result.error = response.reason or ""
                else:
                    result.success = True
        except Exception:
            result.success = False
            # Same text as "".join(traceback.format_exception(*sys.exc_info())).
            result.error = traceback.format_exc()

    return result


def _get_prefix_suffix(prompt: str) -> Tuple[str, str]:
    """Split the prompt into prefix and suffix."""
    prefix = prompt.split("<")[0]
    suffix = prompt.split(">", 1)[1]
    return prefix, suffix


async def process_sample(
    client: Any, sample: dict, sampling_params: dict
) -> Tuple[dict, str]:
    """Send one sample to the LLM and return (sample, response text)."""
    prompt = sample["final_input_prompt"]
    prefix, suffix = _get_prefix_suffix(prompt)
    assert sample["image"] is not None
    # Interleave the prompt text with the image reference, OpenAI chat style.
    user_content = [
        {"type": "text", "text": prefix},
        {"type": "image_url", "image_url": {"url": sample["image_path"]}},
        {"type": "text", "text": suffix},
    ]
    # Both token-limit spellings are sent — presumably for servers that only
    # understand the older `max_tokens` name; confirm against the server API.
    token_limit = sampling_params["max_new_tokens"]
    response = await client.chat.completions.create(
        model="default",
        messages=[{"role": "user", "content": user_content}],
        temperature=0,
        max_completion_tokens=token_limit,
        max_tokens=token_limit,
    )
    return sample, response.choices[0].message.content


async def process_sample_with_semaphore(
    semaphore: asyncio.Semaphore, client: Any, sample: dict, sampling_params: dict
) -> Tuple[dict, str]:
    """Concurrency-limited wrapper: run process_sample under *semaphore*."""
    await semaphore.acquire()
    try:
        return await process_sample(client, sample, sampling_params)
    finally:
        semaphore.release()


async def eval_mmmu(args) -> None:
    """Run the MMMU benchmark against a locally hosted server.

    Fans the prepared samples out to the OpenAI-compatible endpoint with at
    most ``args.concurrency`` requests in flight, optionally bracketing the
    run with the server-side profiler, then saves and scores the answers.
    """
    eval_args = EvalArgs.from_cli_args(args)
    sampling_params = get_sampling_params(eval_args)
    samples = prepare_samples(eval_args)
    answer_dict = {}
    out_samples = {}
    client = openai.AsyncOpenAI(
        api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1"
    )
    semaphore = asyncio.Semaphore(args.concurrency)
    # Timer starts before profiler startup, matching the original measurement.
    bench_start = time.perf_counter()
    base_url = f"http://127.0.0.1:{args.port}"

    if args.profile:
        print("Starting profiler...")
        profiler_resp = await async_request_profile(
            api_url=f"{base_url}/start_profile"
        )
        if profiler_resp.success:
            print("Profiler started")
        # Profile only a prefix of the samples so the trace stays manageable.
        samples = samples[: args.profile_number]

    pending = [
        process_sample_with_semaphore(semaphore, client, sample, sampling_params)
        for sample in samples
    ]

    # Consume results as they finish so the progress bar tracks completions.
    for finished in tqdm(asyncio.as_completed(pending), total=len(pending)):
        sample, response = await finished
        process_result(response, sample, answer_dict, out_samples)

    if args.profile:
        print("Stopping profiler...")
        profiler_resp = await async_request_profile(api_url=f"{base_url}/stop_profile")
        if profiler_resp.success:
            print("Profiler stopped")

    print(f"Benchmark time: {time.perf_counter() - bench_start}")
    args.output_path = "./val_sglang.json"
    save_json(args.output_path, out_samples)
    eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)

def parse_args():
    """Build the CLI (eval args plus common sglang args) and parse argv."""
    parser = argparse.ArgumentParser()
    EvalArgs.add_cli_args(parser)
    return add_common_sglang_args_and_parse(parser)


def main():
    """Script entry point: parse CLI args and run the async benchmark."""
    asyncio.run(eval_mmmu(parse_args()))


if __name__ == "__main__":
    main()