Unverified commit d50e36a7, authored by Yi Zhang and committed by GitHub
Browse files

support vlm benchmark profile (#5905)

parent 8fefdd32
...@@ -10,8 +10,14 @@ The eval output will be logged ...@@ -10,8 +10,14 @@ The eval output will be logged
""" """
import argparse import argparse
import asyncio
import sys
import time import time
import traceback
from dataclasses import dataclass, field
from typing import List
import aiohttp
import openai import openai
from data_utils import save_json from data_utils import save_json
from eval_utils import ( from eval_utils import (
...@@ -25,8 +31,41 @@ from tqdm import tqdm ...@@ -25,8 +31,41 @@ from tqdm import tqdm
from sglang.test.test_utils import add_common_sglang_args_and_parse from sglang.test.test_utils import add_common_sglang_args_and_parse
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)
def eval_mmmu(args):
@dataclass
class RequestFuncOutput:
    """Result record for a request issued by this benchmark script.

    Every list-valued field is built with ``default_factory=list``, so each
    instance owns its own empty lists rather than sharing one default.
    In this file only ``success`` and ``error`` are populated, by
    ``async_request_profile``.
    """

    # Per-request measurement slots; all start empty.
    generated_text: List[str] = field(default_factory=list)
    prompt_len: List[int] = field(default_factory=list)
    output_len: List[int] = field(default_factory=list)
    latency: List[float] = field(default_factory=list)
    ttft: List[float] = field(default_factory=list)
    # Inter-token latencies.
    itl: List[float] = field(default_factory=list)

    # Outcome flag; stays False unless the request is marked successful.
    success: bool = False
    # Failure description; empty when nothing went wrong.
    error: str = ""
async def async_request_profile(api_url: str) -> RequestFuncOutput:
    """POST to a profiler control endpoint and report whether it succeeded.

    Args:
        api_url: Full URL to POST to. Callers in this file pass the server's
            ``/start_profile`` or ``/stop_profile`` endpoint.

    Returns:
        A ``RequestFuncOutput`` whose ``success`` is True only when the
        server answers with HTTP 200. For any other status, ``error`` holds
        the response reason phrase, or ``"HTTP <status>"`` when the server
        sent no reason. If the request raises, ``error`` holds the formatted
        traceback. This function does not propagate request exceptions.
    """
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        # ``success`` defaults to False, so only the happy path sets it.
        output = RequestFuncOutput()
        try:
            async with session.post(url=api_url) as response:
                if response.status == 200:
                    output.success = True
                else:
                    # Never leave ``error`` empty on failure: fall back to
                    # the status code when no reason phrase is available.
                    output.error = response.reason or f"HTTP {response.status}"
        except Exception:
            # Best-effort call: record the failure instead of aborting the
            # benchmark run that wraps the profiler start/stop requests.
            output.error = traceback.format_exc()

    return output
async def eval_mmmu(args):
eval_args = EvalArgs.from_cli_args(args) eval_args = EvalArgs.from_cli_args(args)
out_samples = dict() out_samples = dict()
...@@ -38,9 +77,22 @@ def eval_mmmu(args): ...@@ -38,9 +77,22 @@ def eval_mmmu(args):
answer_dict = {} answer_dict = {}
# had to use an openai server, since SglImage doesn't support image data # had to use an openai server, since SglImage doesn't support image data
client = openai.Client(api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1") base_url = f"http://127.0.0.1:{args.port}"
client = openai.Client(api_key="sk", base_url=f"{base_url}/v1")
start = time.time() start = time.time()
if args.profile:
print("Starting profiler...")
profile_output = await async_request_profile(
api_url=f"{base_url}/start_profile"
)
if profile_output.success:
print("Profiler started")
if args.profile:
samples = samples[: args.profile_number]
for i, sample in enumerate(tqdm(samples)): for i, sample in enumerate(tqdm(samples)):
prompt = sample["final_input_prompt"] prompt = sample["final_input_prompt"]
prefix = prompt.split("<")[0] prefix = prompt.split("<")[0]
...@@ -49,6 +101,7 @@ def eval_mmmu(args): ...@@ -49,6 +101,7 @@ def eval_mmmu(args):
assert image is not None assert image is not None
image_path = sample["image_path"] image_path = sample["image_path"]
# TODO: batch # TODO: batch
response = client.chat.completions.create( response = client.chat.completions.create(
model="default", model="default",
messages=[ messages=[
...@@ -77,6 +130,12 @@ def eval_mmmu(args): ...@@ -77,6 +130,12 @@ def eval_mmmu(args):
response = response.choices[0].message.content response = response.choices[0].message.content
process_result(response, sample, answer_dict, out_samples) process_result(response, sample, answer_dict, out_samples)
if args.profile:
print("Stopping profiler...")
profile_output = await async_request_profile(api_url=f"{base_url}/stop_profile")
if profile_output.success:
print("Profiler stopped")
print(f"Benchmark time: {time.time() - start}") print(f"Benchmark time: {time.time() - start}")
args.output_path = f"./val_sglang.json" args.output_path = f"./val_sglang.json"
...@@ -89,4 +148,4 @@ if __name__ == "__main__": ...@@ -89,4 +148,4 @@ if __name__ == "__main__":
EvalArgs.add_cli_args(parser) EvalArgs.add_cli_args(parser)
args = add_common_sglang_args_and_parse(parser) args = add_common_sglang_args_and_parse(parser)
args = parser.parse_args() args = parser.parse_args()
eval_mmmu(args) asyncio.run(eval_mmmu(args))
...@@ -33,6 +33,8 @@ class EvalArgs: ...@@ -33,6 +33,8 @@ class EvalArgs:
prompt_format_file: str = "prompt_format.yaml" prompt_format_file: str = "prompt_format.yaml"
dataset_path: str = "MMMU/MMMU" dataset_path: str = "MMMU/MMMU"
extra_request_body: Optional[str] = None extra_request_body: Optional[str] = None
profile: bool = False
profile_number: int = 5
@staticmethod @staticmethod
def add_cli_args(parser: argparse.ArgumentParser): def add_cli_args(parser: argparse.ArgumentParser):
...@@ -65,6 +67,12 @@ class EvalArgs: ...@@ -65,6 +67,12 @@ class EvalArgs:
help="Append given JSON object to the request payload. You can use this to specify" help="Append given JSON object to the request payload. You can use this to specify"
"additional generate params like sampling params.", "additional generate params like sampling params.",
) )
parser.add_argument(
"--profile", action="store_true", help="enable mmmu profile"
)
parser.add_argument(
"--profile-number", type=int, default=EvalArgs.profile_number
)
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace): def from_cli_args(cls, args: argparse.Namespace):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment