"vscode:/vscode.git/clone" did not exist on "094891c01a29ebef8bc1977f434b33283f9114c9"
Unverified Commit 4fa44d63 authored by Mick, committed by GitHub

chore: improve mmmu benchmark (#7000)


Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
parent e6312d27
@@ -125,7 +125,6 @@ async def eval_mmmu(args) -> None:
     client = openai.AsyncOpenAI(
         api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1"
     )
-    semaphore = asyncio.Semaphore(args.concurrency)
     start = time.perf_counter()
     base_url = f"http://127.0.0.1:{args.port}"
@@ -139,16 +138,26 @@ async def eval_mmmu(args) -> None:
         samples = samples[: args.profile_number]
-    tasks = [
-        process_sample_with_semaphore(
-            semaphore, client, sample, sampling_params, lora_path
-        )
-        for sample in samples
-    ]
-    for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
-        sample, response = await coro
-        process_result(response, sample, answer_dict, out_samples)
+    if args.concurrency == 1:
+        # For concurrency == 1, run in sequential mode to ensure consistent order
+        # this is mainly for profiling
+        for sample in tqdm(samples):
+            _, response = await process_sample(
+                client, sample, sampling_params, lora_path
+            )
+            process_result(response, sample, answer_dict, out_samples)
+    else:
+        semaphore = asyncio.Semaphore(args.concurrency)
+        tasks = [
+            process_sample_with_semaphore(
+                semaphore, client, sample, sampling_params, lora_path
+            )
+            for sample in samples
+        ]
+        for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
+            sample, response = await coro
+            process_result(response, sample, answer_dict, out_samples)
 
     if args.profile:
         print("Stopping profiler...")
@@ -27,8 +27,7 @@ from tqdm import tqdm
 class EvalArgs:
     seed: int = 42
     split: str = "validation"
-    # Default setting to make the benchmark available on A100 for most 7B models
-    image_pixels_limit: int = 4300000
+    image_pixels_limit: int = -1
     result_filename: str = ""
     prompt_format_file: str = "prompt_format.yaml"
     dataset_path: str = "MMMU/MMMU"
@@ -190,7 +189,7 @@ def prepare_samples(eval_args: EvalArgs):
         sample = construct_prompt(sample, eval_args.config)
         image = sample["image"]
         width, height = image.size
-        if width * height >= eval_args.image_pixels_limit:
+        if 0 < eval_args.image_pixels_limit <= width * height:
             return None, True
         # Use a unique identifier for the image path to avoid potential collisions if indices reset
         image_path = f"{images_path}/image_{sample['id']}.png"
@@ -217,6 +216,8 @@ def prepare_samples(eval_args: EvalArgs):
         elif sample:
             samples.append(sample)
     samples.sort(key=lambda x: x["final_input_prompt"])
+    print(
+        f"Skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset"
+    )