"vscode:/vscode.git/clone" did not exist on "90a624f697e5176b7400ffc647ec64531df05be2"
Unverified commit c5645e92 authored by XinyuanTong, committed by GitHub

feat: add concurrency evaluation logic in mmmu benchmark (#5782)

parent d33955d2
@@ -8,13 +8,15 @@ Host the VLM:
python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl --port 30000
```
It's recommended to reduce the memory usage by appending something like `--mem-fraction-static 0.6` to the command above.
Benchmark:
```
python benchmark/mmmu/bench_sglang.py --port 30000
python benchmark/mmmu/bench_sglang.py --port 30000 --concurrency 16
```
It's recommended to reduce the memory usage by appending something like `--mem-fraction-static 0.6` to the server launch command above.
You can adjust `--concurrency` to control the number of concurrent OpenAI API calls.
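For example, a complete run that uses both flags looks like this (the values shown are just the ones suggested above):
```
python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl --port 30000 --mem-fraction-static 0.6
python benchmark/mmmu/bench_sglang.py --port 30000 --concurrency 16
```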
### Evaluate hf
@@ -4,7 +4,7 @@ Bench the sglang-hosted VLM with benchmark MMMU
Usage:
Host the VLM: python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl --port 30000
Benchmark: python benchmark/mmmu/bench_sglang.py --port 30000
Benchmark: python benchmark/mmmu/bench_sglang.py --port 30000 --concurrency 16
The eval output will be logged
"""
@@ -15,7 +15,7 @@ import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import List
from typing import Any, List, Tuple
import aiohttp
import openai
@@ -65,22 +65,62 @@ async def async_request_profile(api_url: str) -> RequestFuncOutput:
    return output


async def eval_mmmu(args):


def _get_prefix_suffix(prompt: str) -> Tuple[str, str]:
    """Split the prompt into prefix and suffix."""
    prefix = prompt.split("<")[0]
    suffix = prompt.split(">", 1)[1]
    return prefix, suffix
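# Example (assumes the MMMU prompt embeds an "<image 1>" placeholder between question and options):
#   _get_prefix_suffix("What does the chart show? <image 1> Answer with the option letter.")
#   -> ("What does the chart show? ", " Answer with the option letter.")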
async def process_sample(
    client: Any, sample: dict, sampling_params: dict
) -> Tuple[dict, str]:
    """Send a single sample to the LLM and return (sample, response)."""
    prompt = sample["final_input_prompt"]
    prefix, suffix = _get_prefix_suffix(prompt)
    image = sample["image"]
    assert image is not None
    image_path = sample["image_path"]
    response = await client.chat.completions.create(
        model="default",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prefix},
                    {"type": "image_url", "image_url": {"url": image_path}},
                    {"type": "text", "text": suffix},
                ],
            }
        ],
        temperature=0,
        max_completion_tokens=sampling_params["max_new_tokens"],
        max_tokens=sampling_params["max_new_tokens"],
    )
    return sample, response.choices[0].message.content


async def process_sample_with_semaphore(
    semaphore: asyncio.Semaphore, client: Any, sample: dict, sampling_params: dict
) -> Tuple[dict, str]:
    """Wrap process_sample with a semaphore for concurrency control."""
    async with semaphore:
        return await process_sample(client, sample, sampling_params)

async def eval_mmmu(args) -> None:
    """Main evaluation loop with concurrency control."""
    eval_args = EvalArgs.from_cli_args(args)
    out_samples = dict()
    sampling_params = get_sampling_params(eval_args)
    samples = prepare_samples(eval_args)
    answer_dict = {}
    # had to use an openai server, since SglImage doesn't support image data
    base_url = f"http://127.0.0.1:{args.port}"
    client = openai.Client(api_key="sk", base_url=f"{base_url}/v1")
    out_samples = {}
    client = openai.AsyncOpenAI(
        api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1"
    )
    semaphore = asyncio.Semaphore(args.concurrency)
    start = time.time()
    base_url = f"http://127.0.0.1:{args.port}"
    if args.profile:
        print("Starting profiler...")
@@ -90,44 +130,15 @@ async def eval_mmmu(args):
        if profile_output.success:
            print("Profiler started")

    if args.profile:
        samples = samples[: args.profile_number]

    for i, sample in enumerate(tqdm(samples)):
        prompt = sample["final_input_prompt"]
        prefix = prompt.split("<")[0]
        suffix = prompt.split(">")[1]
        image = sample["image"]
        assert image is not None
        image_path = sample["image_path"]
        # TODO: batch
        response = client.chat.completions.create(
            model="default",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": prefix,
                        },
                        {
                            "type": "image_url",
                            "image_url": {"url": image_path},
                        },
                        {
                            "type": "text",
                            "text": suffix,
                        },
                    ],
                }
            ],
            temperature=0,
            max_completion_tokens=sampling_params["max_new_tokens"],
            max_tokens=sampling_params["max_new_tokens"],
        )
        response = response.choices[0].message.content
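    # Fan out one task per sample; the semaphore caps in-flight OpenAI requests at --concurrency.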
    tasks = [
        process_sample_with_semaphore(semaphore, client, sample, sampling_params)
        for sample in samples
    ]
    for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
        sample, response = await coro
        process_result(response, sample, answer_dict, out_samples)

    if args.profile:
@@ -137,15 +148,22 @@ async def eval_mmmu(args):
print("Profiler stopped")
print(f"Benchmark time: {time.time() - start}")
args.output_path = f"./val_sglang.json"
save_json(args.output_path, out_samples)
eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)
if __name__ == "__main__":
def parse_args():
parser = argparse.ArgumentParser()
EvalArgs.add_cli_args(parser)
args = add_common_sglang_args_and_parse(parser)
args = parser.parse_args()
return args
def main():
args = parse_args()
asyncio.run(eval_mmmu(args))
if __name__ == "__main__":
main()
@@ -35,6 +35,7 @@ class EvalArgs:
    extra_request_body: Optional[str] = None
    profile: bool = False
    profile_number: int = 5
    concurrency: int = 1

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
@@ -73,6 +74,7 @@ class EvalArgs:
        parser.add_argument(
            "--profile-number", type=int, default=EvalArgs.profile_number
        )
        parser.add_argument("--concurrency", type=int, default=EvalArgs.concurrency)

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
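Taken together, the new flag is wired from the CLI into the evaluation loop roughly as follows; this is a minimal standalone sketch of that wiring (the real script registers the argument via `EvalArgs.add_cli_args` and also pulls in the common sglang args):
```
import argparse
import asyncio

# Mirrors the new EvalArgs field and CLI flag added in this commit (default 1).
parser = argparse.ArgumentParser()
parser.add_argument("--concurrency", type=int, default=1)
args = parser.parse_args(["--concurrency", "16"])

# eval_mmmu then bounds the number of in-flight OpenAI requests with this semaphore.
semaphore = asyncio.Semaphore(args.concurrency)
```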