"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "9fc9c6dd7186732b1397765aa089f6d45c27c3ea"
Unverified Commit 92009bd2 authored by Zaili Wang, committed by GitHub

fix: fix MMMU loading issue (#11759)


Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 4ef981e2
@@ -764,6 +764,7 @@ def get_dataset(args, tokenizer, model_id=None):
             image_content=args.image_content,
             image_format=args.image_format,
             image_resolution=args.image_resolution,
+            backend=args.backend,
         )
     elif args.dataset_name == "generated-shared-prefix":
         assert not tokenize_prompt
@@ -781,6 +782,7 @@ def get_dataset(args, tokenizer, model_id=None):
         input_requests = sample_mmmu_requests(
             num_requests=args.num_prompts,
             processor=processor,
+            backend=args.backend,
             fixed_output_len=args.random_output_len,
             random_sample=True,
         )
@@ -1009,6 +1011,7 @@ async def get_mooncake_request_over_time(
 def sample_mmmu_requests(
     num_requests: int,
     processor: AutoProcessor | AutoTokenizer,
+    backend: str,
     fixed_output_len: Optional[int] = None,
     random_sample: bool = True,
 ) -> List[DatasetRow]:
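
Aside: the three hunks above all apply the same one-line pattern, threading the benchmark's --backend flag from get_dataset into each sampler. A minimal sketch of that threading, with the sampler reduced to a hypothetical stub (the argparse wiring here is illustrative, not taken from the file):

# Hedged sketch of how the new `backend` parameter flows through the call
# chain shown above; only the threading pattern mirrors the diff.
import argparse
from typing import List


def sample_mmmu_requests_stub(num_requests: int, backend: str) -> List[str]:
    # Stand-in for sample_mmmu_requests: each row's prompt format now
    # depends on which serving backend will receive it.
    return [f"mmmu-row(backend={backend})" for _ in range(num_requests)]


def get_dataset_stub(args) -> List[str]:
    # Mirrors the diff: args.backend is forwarded so prompt rendering can
    # be decided per-backend inside create_mm_data_row.
    return sample_mmmu_requests_stub(
        num_requests=args.num_prompts, backend=args.backend
    )


parser = argparse.ArgumentParser()
parser.add_argument("--backend", default="sglang-oai-chat")
parser.add_argument("--num-prompts", type=int, default=2)
print(get_dataset_stub(parser.parse_args([])))
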
@@ -1081,7 +1084,7 @@ def sample_mmmu_requests(
             text_prompt = f"Question: {question}\n\nAnswer: "
             output_len = fixed_output_len if fixed_output_len is not None else 256
             data_row = create_mm_data_row(
-                text_prompt, [image], [image_data], output_len, processor
+                text_prompt, [image], [image_data], output_len, processor, backend
             )
             filtered_dataset.append(data_row)
@@ -1316,13 +1319,19 @@ def parse_image_resolution(image_resolution: str) -> Tuple[int, int]:
     )
 
 
-def create_mm_data_row(text_prompt, images: list, images_base64, output_len, processor):
+def create_mm_data_row(
+    text_prompt, images: list, images_base64, output_len, processor, backend
+):
     try:
-        content_items = [
-            {"type": "image", "image": {"url": image_base64}}
-            for image_base64 in images_base64
-        ]
-        content_items.append({"type": "text", "text": text_prompt})
+        if type(processor).__name__ == "Phi4MMProcessor":
+            # <|endoftext10|> is the image token used in the phi-4-multimodal model.
+            content_items = text_prompt.replace("image 1", "|endoftext10|")
+        else:
+            content_items = [
+                {"type": "image", "image": {"url": image_base64}}
+                for image_base64 in images_base64
+            ]
+            content_items.append({"type": "text", "text": text_prompt})
         prompt_str = processor.apply_chat_template(
             [{"role": "user", "content": content_items}],
             add_generation_prompt=True,
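
Aside: the two content_items shapes built above are worth seeing side by side. The Phi-4-MM branch passes a plain string with the image token already inlined, while the default branch passes a structured list that the processor's chat template expands. The mimic below imitates a Qwen2-VL-style template in pure Python purely for illustration; the real rendering depends entirely on the processor's bundled template:

# Pure-Python mimic of a chat template handling both content_items shapes.
# Token strings here are Qwen2-VL-style placeholders, assumed for the sketch.
def mimic_chat_template(messages, add_generation_prompt=True):
    out = []
    for msg in messages:
        out.append(f"<|im_start|>{msg['role']}\n")
        content = msg["content"]
        if isinstance(content, str):
            # Phi-4-MM branch above: a plain string with the image token
            # already substituted into the prompt text.
            out.append(content)
        else:
            # Default branch: structured items expanded one by one.
            for item in content:
                if item["type"] == "image":
                    out.append("<|vision_start|><|image_pad|><|vision_end|>")
                else:
                    out.append(item["text"])
        out.append("<|im_end|>\n")
    if add_generation_prompt:
        out.append("<|im_start|>assistant\n")
    return "".join(out)


content_items = [
    {"type": "image", "image": {"url": "data:image/jpeg;base64,..."}},
    {"type": "text", "text": "Question: what is shown?\n\nAnswer: "},
]
print(mimic_chat_template([{"role": "user", "content": content_items}]))
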
@@ -1362,8 +1371,16 @@ def create_mm_data_row(text_prompt, images: list, images_base64, output_len, processor):
     # Vision tokens = total tokens - text tokens
     vision_prompt_len = prompt_len - text_prompt_len
 
+    use_raw_prompt = backend in [
+        "sglang-oai",
+        "sglang-oai-chat",
+        "vllm",
+        "vllm-chat",
+        "lmdeploy",
+        "lmdeploy-chat",
+    ]
     return DatasetRow(
-        prompt=text_prompt,
+        prompt=text_prompt if use_raw_prompt else prompt_str,
         prompt_len=prompt_len,
         output_len=output_len,
         text_prompt_len=text_prompt_len,
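
Aside: the use_raw_prompt list is the core of the fix. For the listed OpenAI-compatible backends the serving side applies the model's chat template itself, so sending the locally rendered prompt_str would presumably template the prompt twice; other backends receive the pre-rendered string. A minimal standalone reproduction of the selection rule:

# Minimal reproduction of the prompt-selection rule added in the diff.
RAW_PROMPT_BACKENDS = {
    "sglang-oai",
    "sglang-oai-chat",
    "vllm",
    "vllm-chat",
    "lmdeploy",
    "lmdeploy-chat",
}


def select_prompt(backend: str, text_prompt: str, prompt_str: str) -> str:
    # OpenAI-style endpoints receive the raw question text (the server
    # applies the chat template); everything else gets the string already
    # rendered by apply_chat_template.
    return text_prompt if backend in RAW_PROMPT_BACKENDS else prompt_str


assert select_prompt("vllm-chat", "raw Q", "<templated Q>") == "raw Q"
assert select_prompt("sglang", "raw Q", "<templated Q>") == "<templated Q>"
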
@@ -1382,6 +1399,7 @@ def sample_image_requests(
     image_content: str,
     image_format: str,
     image_resolution: str,
+    backend: str,
 ) -> List[DatasetRow]:
     """Generate requests with images.
@@ -1447,6 +1465,7 @@ def sample_image_requests(
             list(images_base64),
             int(output_lens[i]),
             processor,
+            backend,
         )
         dataset.append(data_row)
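
Aside: why the prompt choice matters downstream. The sketch below is not from this diff; the payload shapes are assumptions illustrating the two request styles a DatasetRow's prompt might feed into:

# Illustrative only: the two request styles implied by use_raw_prompt.
RAW_PROMPT_BACKENDS = {
    "sglang-oai", "sglang-oai-chat",
    "vllm", "vllm-chat",
    "lmdeploy", "lmdeploy-chat",
}


def build_request(backend: str, prompt: str, image_b64_url: str) -> dict:
    if backend in RAW_PROMPT_BACKENDS:
        # OpenAI-compatible: raw text and image are wrapped in a messages
        # list, and the server applies the chat template exactly once.
        return {
            "messages": [{
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": image_b64_url}},
                    {"type": "text", "text": prompt},
                ],
            }]
        }
    # Native endpoints: the already-templated string is sent verbatim.
    return {"text": prompt}


print(build_request("vllm-chat", "Question: ...", "data:image/jpeg;base64,..."))
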