Unverified Commit c5645e92 authored by XinyuanTong, committed by GitHub

feat: add concurrency evaluation logic in mmmu benchmark (#5782)

parent d33955d2
@@ -8,13 +8,15 @@ Host the VLM:
 python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl --port 30000
 ```
+It's recommended to reduce the memory usage by appending something like `--mem-fraction-static 0.6` to the command above.
 
 Benchmark:
 ```
-python benchmark/mmmu/bench_sglang.py --port 30000
+python benchmark/mmmu/bench_sglang.py --port 30000 --concurrency 16
 ```
-It's recommended to reduce the memory usage by appending something ike `--mem-fraction-static 0.6` to the command above.
+You can adjust the `--concurrency` to control the number of concurrent OpenAI calls.
 
 ### Evaluate hf
...
@@ -4,7 +4,7 @@ Bench the sglang-hosted vLM with benchmark MMMU
 Usage:
     Host the VLM: python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl --port 30000
-    Benchmark: python benchmark/mmmu/bench_sglang.py --port 30000
+    Benchmark: python benchmark/mmmu/bench_sglang.py --port 30000 --concurrency 16
 The eval output will be logged
 """
@@ -15,7 +15,7 @@ import sys
 import time
 import traceback
 from dataclasses import dataclass, field
-from typing import List
+from typing import Any, List, Tuple
 
 import aiohttp
 import openai
@@ -65,22 +65,62 @@ async def async_request_profile(api_url: str) -> RequestFuncOutput:
     return output
 
 
-async def eval_mmmu(args):
+def _get_prefix_suffix(prompt: str) -> Tuple[str, str]:
+    """Split the prompt into prefix and suffix."""
+    prefix = prompt.split("<")[0]
+    suffix = prompt.split(">", 1)[1]
+    return prefix, suffix
+
+
+async def process_sample(
+    client: Any, sample: dict, sampling_params: dict
+) -> Tuple[dict, str]:
+    """Send a single sample to the LLM and return (sample, response)."""
+    prompt = sample["final_input_prompt"]
+    prefix, suffix = _get_prefix_suffix(prompt)
+    image = sample["image"]
+    assert image is not None
+    image_path = sample["image_path"]
+    response = await client.chat.completions.create(
+        model="default",
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": prefix},
+                    {"type": "image_url", "image_url": {"url": image_path}},
+                    {"type": "text", "text": suffix},
+                ],
+            }
+        ],
+        temperature=0,
+        max_completion_tokens=sampling_params["max_new_tokens"],
+        max_tokens=sampling_params["max_new_tokens"],
+    )
+    return sample, response.choices[0].message.content
+
+
+async def process_sample_with_semaphore(
+    semaphore: asyncio.Semaphore, client: Any, sample: dict, sampling_params: dict
+) -> Tuple[dict, str]:
+    """Wrap process_sample with a semaphore for concurrency control."""
+    async with semaphore:
+        return await process_sample(client, sample, sampling_params)
+
+
+async def eval_mmmu(args) -> None:
+    """Main evaluation loop with concurrency control."""
     eval_args = EvalArgs.from_cli_args(args)
-    out_samples = dict()
     sampling_params = get_sampling_params(eval_args)
     samples = prepare_samples(eval_args)
     answer_dict = {}
+    out_samples = {}
 
-    # had to use an openai server, since SglImage doesn't support image data
-    base_url = f"http://127.0.0.1:{args.port}"
-    client = openai.Client(api_key="sk", base_url=f"{base_url}/v1")
+    client = openai.AsyncOpenAI(
+        api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1"
+    )
+    semaphore = asyncio.Semaphore(args.concurrency)
 
     start = time.time()
+    base_url = f"http://127.0.0.1:{args.port}"
 
     if args.profile:
         print("Starting profiler...")
@@ -90,44 +130,15 @@ async def eval_mmmu(args):
         if profile_output.success:
             print("Profiler started")
 
+    if args.profile:
         samples = samples[: args.profile_number]
 
-    for i, sample in enumerate(tqdm(samples)):
-        prompt = sample["final_input_prompt"]
-        prefix = prompt.split("<")[0]
-        suffix = prompt.split(">")[1]
-        image = sample["image"]
-        assert image is not None
-        image_path = sample["image_path"]
-        # TODO: batch
-        response = client.chat.completions.create(
-            model="default",
-            messages=[
-                {
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "text",
-                            "text": prefix,
-                        },
-                        {
-                            "type": "image_url",
-                            "image_url": {"url": image_path},
-                        },
-                        {
-                            "type": "text",
-                            "text": suffix,
-                        },
-                    ],
-                }
-            ],
-            temperature=0,
-            max_completion_tokens=sampling_params["max_new_tokens"],
-            max_tokens=sampling_params["max_new_tokens"],
-        )
-        response = response.choices[0].message.content
+    tasks = [
+        process_sample_with_semaphore(semaphore, client, sample, sampling_params)
+        for sample in samples
+    ]
+
+    for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
+        sample, response = await coro
         process_result(response, sample, answer_dict, out_samples)
 
     if args.profile:
@@ -137,15 +148,22 @@ async def eval_mmmu(args):
             print("Profiler stopped")
 
     print(f"Benchmark time: {time.time() - start}")
 
     args.output_path = f"./val_sglang.json"
     save_json(args.output_path, out_samples)
     eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)
 
 
-if __name__ == "__main__":
+def parse_args():
     parser = argparse.ArgumentParser()
     EvalArgs.add_cli_args(parser)
     args = add_common_sglang_args_and_parse(parser)
-    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
     asyncio.run(eval_mmmu(args))
+
+
+if __name__ == "__main__":
+    main()
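
The new `_get_prefix_suffix` helper assumes the MMMU prompt contains a single inline image placeholder (e.g. `<image 1>`); everything before the first `<` becomes the text prefix and everything after the first `>` becomes the suffix. A minimal sketch of that behaviour, with a hypothetical prompt string that is not taken from the benchmark data:

```python
def _get_prefix_suffix(prompt: str):
    # Same splitting logic as the benchmark: text before the first '<' is the
    # prefix, text after the first '>' is the suffix.
    prefix = prompt.split("<")[0]
    suffix = prompt.split(">", 1)[1]
    return prefix, suffix


# Hypothetical MMMU-style prompt; the exact placeholder text is an assumption.
prefix, suffix = _get_prefix_suffix("What is shown in <image 1>? Answer briefly.")
print(repr(prefix))  # 'What is shown in '
print(repr(suffix))  # '? Answer briefly.'
```

Using `maxsplit=1` in the suffix split keeps any later `>` characters in the question text intact, which the old `prompt.split(">")[1]` did not guarantee.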
@@ -35,6 +35,7 @@ class EvalArgs:
     extra_request_body: Optional[str] = None
     profile: bool = False
     profile_number: int = 5
+    concurrency: int = 1
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -73,6 +74,7 @@ class EvalArgs:
         parser.add_argument(
             "--profile-number", type=int, default=EvalArgs.profile_number
         )
+        parser.add_argument("--concurrency", type=int, default=EvalArgs.concurrency)
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
...
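
The concurrency pattern this change introduces (a semaphore capping in-flight requests, with `asyncio.as_completed` driving the progress bar) can be reduced to a small standalone sketch. The names `run_bounded`, `worker`, and `items` below are placeholders for illustration, not identifiers from the benchmark:

```python
import asyncio

from tqdm import tqdm


async def run_bounded(items, worker, concurrency: int = 1):
    """Run worker(item) for every item with at most `concurrency` coroutines
    in flight at once, mirroring the new --concurrency flag."""
    semaphore = asyncio.Semaphore(concurrency)

    async def guarded(item):
        async with semaphore:
            return await worker(item)

    tasks = [guarded(item) for item in items]
    results = []
    # as_completed yields each result as soon as its request finishes,
    # so the progress bar advances in completion order, as in the benchmark.
    for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
        results.append(await coro)
    return results


async def _demo():
    async def worker(i):
        await asyncio.sleep(0.01)  # stand-in for one OpenAI chat completion call
        return i * i

    return await run_bounded(range(8), worker, concurrency=4)


if __name__ == "__main__":
    print(asyncio.run(_demo()))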