Unverified Commit e07d0647 authored by Lifu Huang's avatar Lifu Huang Committed by GitHub
Browse files

Support LoRA in MMMU benchmark script. (#7218)

parent 3c2274fb
......@@ -18,6 +18,15 @@ python benchmark/mmmu/bench_sglang.py --port 30000 --concurrency 16
You can adjust the `--concurrency` to control the number of concurrent OpenAI calls.
You can use `--lora-path` to specify the LoRA adapter to apply during benchmarking. E.g.,
```
# Launch server with LoRA enabled
python -m sglang.launch_server --model-path microsoft/Phi-4-multimodal-instruct --port 30000 --trust-remote-code --disable-radix-cache --lora-paths vision=<LoRA path>
# Apply LoRA adapter during inference
python benchmark/mmmu/bench_sglang.py --concurrency 8 --lora-path vision
```
### Evaluate hf
```
......
......@@ -15,7 +15,7 @@ import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import Any, List, Tuple
from typing import Any, List, Optional, Tuple
import aiohttp
import openai
......@@ -73,7 +73,7 @@ def _get_prefix_suffix(prompt: str) -> Tuple[str, str]:
async def process_sample(
client: Any, sample: dict, sampling_params: dict
client: Any, sample: dict, sampling_params: dict, lora_path: Optional[str] = None
) -> Tuple[dict, str]:
"""Send a single sample to the LLM and return (sample, response)."""
prompt = sample["final_input_prompt"]
......@@ -81,6 +81,7 @@ async def process_sample(
image = sample["image"]
assert image is not None
image_path = sample["image_path"]
extra_body = None if lora_path is None else {"lora_path": lora_path}
response = await client.chat.completions.create(
model="default",
messages=[
......@@ -96,16 +97,21 @@ async def process_sample(
temperature=0,
max_completion_tokens=sampling_params["max_new_tokens"],
max_tokens=sampling_params["max_new_tokens"],
extra_body=extra_body,
)
return sample, response.choices[0].message.content
async def process_sample_with_semaphore(
semaphore: asyncio.Semaphore, client: Any, sample: dict, sampling_params: dict
semaphore: asyncio.Semaphore,
client: Any,
sample: dict,
sampling_params: dict,
lora_path: Optional[str] = None,
) -> Tuple[dict, str]:
"""Wrap process_sample with a semaphore for concurrency control."""
async with semaphore:
return await process_sample(client, sample, sampling_params)
return await process_sample(client, sample, sampling_params, lora_path)
async def eval_mmmu(args) -> None:
......@@ -113,6 +119,7 @@ async def eval_mmmu(args) -> None:
eval_args = EvalArgs.from_cli_args(args)
sampling_params = get_sampling_params(eval_args)
samples = prepare_samples(eval_args)
lora_path = eval_args.lora_path
answer_dict = {}
out_samples = {}
client = openai.AsyncOpenAI(
......@@ -133,7 +140,9 @@ async def eval_mmmu(args) -> None:
samples = samples[: args.profile_number]
tasks = [
process_sample_with_semaphore(semaphore, client, sample, sampling_params)
process_sample_with_semaphore(
semaphore, client, sample, sampling_params, lora_path
)
for sample in samples
]
......
......@@ -36,17 +36,22 @@ class EvalArgs:
profile: bool = False
profile_number: int = 5
concurrency: int = 1
lora_path: Optional[str] = None
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--result-filename", type=str, default=EvalArgs.result_filename
"--result-filename",
type=str,
default=EvalArgs.result_filename,
help="The filename to save the evaluation results.",
)
parser.add_argument(
"--image-pixels-limit", type=int, default=EvalArgs.image_pixels_limit
"--image-pixels-limit",
type=int,
default=EvalArgs.image_pixels_limit,
help="The maximum number of pixels allowed for an image. If an image exceeds this limit, it will be skipped during evaluation.",
)
parser.add_argument(
"--dataset-path",
type=str,
......@@ -59,7 +64,12 @@ class EvalArgs:
type=str,
help="The path to the prompt format of mmmu. If not, a default format llava_config.yaml will be used",
)
parser.add_argument("--split", type=str, default=EvalArgs.split)
parser.add_argument(
"--split",
type=str,
default=EvalArgs.split,
help='Split of the dataset to use for evaluation. Default is "validation".',
)
parser.add_argument(
"--extra-request-body",
metavar='{"key1": "value1", "key2": "value2"}',
......@@ -72,9 +82,23 @@ class EvalArgs:
"--profile", action="store_true", help="enable mmmu profile"
)
parser.add_argument(
"--profile-number", type=int, default=EvalArgs.profile_number
"--profile-number",
type=int,
default=EvalArgs.profile_number,
help="Number of samples to profile when profiling is enabled. Defaults to 5.",
)
parser.add_argument(
"--concurrency",
type=int,
default=EvalArgs.concurrency,
help="Number of concurrent requests to make during evaluation. Default is 1, which means no concurrency.",
)
parser.add_argument(
"--lora-path",
type=str,
default=EvalArgs.lora_path,
help="Specify the LoRA adapter to use for evaluation. If set, the value is sent in the body of every request as `lora_path`.",
)
parser.add_argument("--concurrency", type=int, default=EvalArgs.concurrency)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment