Unverified Commit e07d0647 authored by Lifu Huang, committed by GitHub

Support LoRA in MMMU benchmark script. (#7218)

parent 3c2274fb
@@ -18,6 +18,15 @@ python benchmark/mmmu/bench_sglang.py --port 30000 --concurrency 16
 
 You can adjust the `--concurrency` to control the number of concurrent OpenAI calls.
 
+You can use `--lora-path` to specify the LoRA adapter to apply during benchmarking, e.g.:
+
+```
+# Launch the server with LoRA enabled
+python -m sglang.launch_server --model-path microsoft/Phi-4-multimodal-instruct --port 30000 --trust-remote-code --disable-radix-cache --lora-paths vision=<LoRA path>
+
+# Apply the LoRA adapter during inference
+python benchmark/mmmu/bench_sglang.py --concurrency 8 --lora-path vision
+```
+
 ### Evaluate hf
 ```
...
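Under the hood, the benchmark script forwards the adapter name to the server on every request. Below is a minimal sketch of that request, assuming the server launched above on port 30000; the `base_url` and placeholder `api_key` are illustrative, not taken from the commit:

```
import openai

# Minimal sketch: the adapter name registered at launch ("vision") is passed
# through extra_body, which the OpenAI client serializes as a top-level
# "lora_path" field in the JSON request body.
client = openai.OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Describe the image in one sentence."}],
    extra_body={"lora_path": "vision"},  # omit to run against the base model
)
print(response.choices[0].message.content)
```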
@@ -15,7 +15,7 @@ import sys
 import time
 import traceback
 from dataclasses import dataclass, field
-from typing import Any, List, Tuple
+from typing import Any, List, Optional, Tuple
 
 import aiohttp
 import openai
@@ -73,7 +73,7 @@ def _get_prefix_suffix(prompt: str) -> Tuple[str, str]:
 
 
 async def process_sample(
-    client: Any, sample: dict, sampling_params: dict
+    client: Any, sample: dict, sampling_params: dict, lora_path: Optional[str] = None
 ) -> Tuple[dict, str]:
     """Send a single sample to the LLM and return (sample, response)."""
     prompt = sample["final_input_prompt"]
@@ -81,6 +81,7 @@ async def process_sample(
     image = sample["image"]
     assert image is not None
     image_path = sample["image_path"]
+    extra_body = None if lora_path is None else {"lora_path": lora_path}
     response = await client.chat.completions.create(
         model="default",
         messages=[
@@ -96,16 +97,21 @@ async def process_sample(
         temperature=0,
         max_completion_tokens=sampling_params["max_new_tokens"],
         max_tokens=sampling_params["max_new_tokens"],
+        extra_body=extra_body,
     )
     return sample, response.choices[0].message.content
 
 
 async def process_sample_with_semaphore(
-    semaphore: asyncio.Semaphore, client: Any, sample: dict, sampling_params: dict
+    semaphore: asyncio.Semaphore,
+    client: Any,
+    sample: dict,
+    sampling_params: dict,
+    lora_path: Optional[str] = None,
 ) -> Tuple[dict, str]:
     """Wrap process_sample with a semaphore for concurrency control."""
     async with semaphore:
-        return await process_sample(client, sample, sampling_params)
+        return await process_sample(client, sample, sampling_params, lora_path)
 
 
 async def eval_mmmu(args) -> None:
@@ -113,6 +119,7 @@ async def eval_mmmu(args) -> None:
     eval_args = EvalArgs.from_cli_args(args)
     sampling_params = get_sampling_params(eval_args)
     samples = prepare_samples(eval_args)
+    lora_path = eval_args.lora_path
 
     answer_dict = {}
     out_samples = {}
     client = openai.AsyncOpenAI(
@@ -133,7 +140,9 @@ async def eval_mmmu(args) -> None:
         samples = samples[: args.profile_number]
 
     tasks = [
-        process_sample_with_semaphore(semaphore, client, sample, sampling_params)
+        process_sample_with_semaphore(
+            semaphore, client, sample, sampling_params, lora_path
+        )
         for sample in samples
     ]
...
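For reference, the concurrency control threaded through these changes is the standard semaphore-plus-gather pattern. Here is a self-contained sketch; `fake_request` is a hypothetical stand-in for the real OpenAI call:

```
import asyncio

async def fake_request(i: int) -> str:
    # Stand-in for the real chat.completions.create call.
    await asyncio.sleep(0.1)
    return f"answer-{i}"

async def bounded(semaphore: asyncio.Semaphore, i: int) -> str:
    # At most `concurrency` coroutines are past this point at once.
    async with semaphore:
        return await fake_request(i)

async def main() -> None:
    semaphore = asyncio.Semaphore(8)  # mirrors --concurrency 8
    results = await asyncio.gather(*(bounded(semaphore, i) for i in range(32)))
    print(len(results), "responses")

asyncio.run(main())
```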
@@ -36,17 +36,22 @@ class EvalArgs:
     profile: bool = False
     profile_number: int = 5
     concurrency: int = 1
+    lora_path: Optional[str] = None
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         parser.add_argument(
-            "--result-filename", type=str, default=EvalArgs.result_filename
+            "--result-filename",
+            type=str,
+            default=EvalArgs.result_filename,
+            help="The filename to save the evaluation results.",
         )
         parser.add_argument(
-            "--image-pixels-limit", type=int, default=EvalArgs.image_pixels_limit
+            "--image-pixels-limit",
+            type=int,
+            default=EvalArgs.image_pixels_limit,
+            help="The maximum number of pixels allowed for an image. If an image exceeds this limit, it will be skipped during evaluation.",
         )
         parser.add_argument(
             "--dataset-path",
             type=str,
@@ -59,7 +64,12 @@
             type=str,
             help="The path to the prompt format of mmmu. If not set, the default format llava_config.yaml will be used.",
         )
-        parser.add_argument("--split", type=str, default=EvalArgs.split)
+        parser.add_argument(
+            "--split",
+            type=str,
+            default=EvalArgs.split,
+            help='Split of the dataset to use for evaluation. Default is "validation".',
+        )
         parser.add_argument(
             "--extra-request-body",
             metavar='{"key1": "value1", "key2": "value2"}',
@@ -72,9 +82,23 @@
             "--profile", action="store_true", help="enable mmmu profile"
         )
         parser.add_argument(
-            "--profile-number", type=int, default=EvalArgs.profile_number
+            "--profile-number",
+            type=int,
+            default=EvalArgs.profile_number,
+            help="Number of samples to profile. If not set, all samples will be profiled.",
        )
-        parser.add_argument("--concurrency", type=int, default=EvalArgs.concurrency)
+        parser.add_argument(
+            "--concurrency",
+            type=int,
+            default=EvalArgs.concurrency,
+            help="Number of concurrent requests to make during evaluation. Default is 1, which means no concurrency.",
+        )
+        parser.add_argument(
+            "--lora-path",
+            type=str,
+            default=EvalArgs.lora_path,
+            help="Specify the LoRA path to use for evaluation. If specified, the value will be sent in the body of every request as `lora_path`.",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
...
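The `from_cli_args` classmethod visible at the end of the hunk typically builds the dataclass from the matching attributes of the parsed namespace. A sketch under that assumption; `DemoArgs` is hypothetical, not the real `EvalArgs`:

```
import argparse
from dataclasses import dataclass, fields
from typing import Optional

@dataclass
class DemoArgs:
    concurrency: int = 1
    lora_path: Optional[str] = None

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "DemoArgs":
        # Copy only the attributes that correspond to dataclass fields;
        # argparse maps --lora-path to the lora_path attribute.
        return cls(**{f.name: getattr(args, f.name) for f in fields(cls)})

parser = argparse.ArgumentParser()
parser.add_argument("--concurrency", type=int, default=DemoArgs.concurrency)
parser.add_argument("--lora-path", type=str, default=DemoArgs.lora_path)
print(DemoArgs.from_cli_args(parser.parse_args(["--lora-path", "vision"])))
# -> DemoArgs(concurrency=1, lora_path='vision')
```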