"mmdet3d/vscode:/vscode.git/clone" did not exist on "b05017276c1c785cfd57cf4f4ebe89863d2783b0"
Commit 0d99ae1f authored by silencealiang

add

parent c271aaae
Pipeline #2498 canceled
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
@@ -16,7 +16,7 @@ You can build a docker container using `examples/multimodal/Dockerfile` to run t
 ### Language model

-Follow the instructions in [Mistral](../../docs/llama_mistral.md#mistral-7b) to download weights for Mistral-7B-Instruct-v0.3 (Base or Instruct) from HuggingFace and convert to mcore format with tensor parallel size 4.
+Follow the instructions in [Mistral](../../docs/llama_mistral.md#mistral-7b) to download weights for Mistral-7B-Instruct-v0.3 from HuggingFace and convert to mcore format with tensor parallel size 4.
 Please use the tokenizer from HuggingFace.

 ### Vision model
@@ -113,7 +113,7 @@ Run the following script:
 ```
 examples/multimodal/text_generation_mistral_clip.sh --input-image-path /path/to/input/images --output-path /some/output/directory \
-    --model-path /path/to/model.pt --tokenizer-path /path/to/tokenizer/ --gt-path /path/to/groundtruth/file --task generation-task-name
+    --model-path /path/to/model.pt --gt-path /path/to/groundtruth/file --task generation-task-name
 ```
 where `--task generation-task-name` is the name of the evaluation benchmark such as `captioning` or `MMMU`.
......
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
@@ -7,34 +7,20 @@ from megatron.training.activations import fast_gelu, quick_gelu, squared_relu

 def get_language_model_config(config):
-    if config.language_model_type == "2b":
+    if config.language_model_type == "llama3_8b":
+        config.activation_func = torch.nn.functional.silu
         config.add_bias_linear = False
         config.bias_activation_fusion = False
         config.gated_linear_unit = True
-        config.apply_query_key_layer_scaling = True
-        config.layernorm_zero_centered_gamma = True
-        config.bias_dropout_fusion = False
-        config.rotary_percent = 0.5
-        config.apply_rope_fusion = False
-        config.attention_softmax_in_fp32 = True
-    elif config.language_model_type == "8b":
-        config.add_bias_linear = False
-        config.bias_activation_fusion = False
-        config.gated_linear_unit = False
-        config.apply_query_key_layer_scaling = True
-        config.layernorm_zero_centered_gamma = True
+        config.apply_query_key_layer_scaling = False
+        config.layernorm_zero_centered_gamma = (
+            False  # Zero centered gamma not supported for RMSNorm
+        )
         config.bias_dropout_fusion = False
-        config.rotary_percent = 0.5
-        config.attention_dropout = 0.0
         config.apply_rope_fusion = False
-        config.activation_func = squared_relu
-        config.ffn_hidden_size = 16384
-        config.masked_softmax_fusion = True
         config.attention_softmax_in_fp32 = True
-        config.num_query_groups = 32
-        config.kv_channels = 128
-        config.rotary_interleaved = False
-    elif config.language_model_type == "llama3_8b":
+        config.ffn_hidden_size = 14336
+    elif config.language_model_type == "mistral_7b":
         config.activation_func = torch.nn.functional.silu
         config.add_bias_linear = False
         config.bias_activation_fusion = False
@@ -47,7 +33,7 @@ def get_language_model_config(config):
         config.apply_rope_fusion = False
         config.attention_softmax_in_fp32 = True
         config.ffn_hidden_size = 14336
-    elif config.language_model_type == "mistral_7b":
+    elif config.language_model_type == "yi-34b":
         config.activation_func = torch.nn.functional.silu
         config.add_bias_linear = False
         config.bias_activation_fusion = False
@@ -59,10 +45,11 @@
         config.bias_dropout_fusion = False
         config.apply_rope_fusion = False
         config.attention_softmax_in_fp32 = True
-        config.ffn_hidden_size = 14336
-    elif config.language_model_type == "yi-34b":
+        config.ffn_hidden_size = 20480
+    elif config.language_model_type == "qwen2.5_7B":
         config.activation_func = torch.nn.functional.silu
         config.add_bias_linear = False
+        config.add_qkv_bias = True
         config.bias_activation_fusion = False
         config.gated_linear_unit = True
         config.apply_query_key_layer_scaling = False
@@ -72,7 +59,7 @@ def get_language_model_config(config):
         config.bias_dropout_fusion = False
         config.apply_rope_fusion = False
         config.attention_softmax_in_fp32 = True
-        config.ffn_hidden_size = 20480
+        config.ffn_hidden_size = 18944
     elif config.language_model_type == "qwen2.0_72B":
         config.activation_func = torch.nn.functional.silu
         config.add_bias_linear = False
@@ -168,13 +155,7 @@ def get_vision_projection_config(config, hidden_size):
     config.bias_activation_fusion = False
     config.add_bias_linear = False
     config.hidden_size = hidden_size  # Used as the vision projection output size, i.e., the input to the language model.
-    if config.language_model_type == "2b":
-        config.ffn_hidden_size = 5440
-        config.activation_func = torch.nn.functional.gelu
-    if config.language_model_type == "8b":
-        config.ffn_hidden_size = 16384
-        config.activation_func = squared_relu
-    elif config.language_model_type == "llama3_8b":
+    if config.language_model_type == "llama3_8b":
         config.ffn_hidden_size = 14336
         config.activation_func = torch.nn.functional.gelu
     elif config.language_model_type == "mistral_7b":
@@ -185,6 +166,9 @@ def get_vision_projection_config(config, hidden_size):
         config.ffn_hidden_size = 20480
         config.normalization = "LayerNorm"
         config.activation_func = torch.nn.functional.gelu
+    elif config.language_model_type == "qwen2.5_7B":
+        config.ffn_hidden_size = 3584
+        config.activation_func = torch.nn.functional.gelu
     elif config.language_model_type == "qwen2.0_72B":
         config.ffn_hidden_size = 29568
         config.normalization = "LayerNorm"
......
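For orientation, both helpers in this file follow the same pattern: branch on `config.language_model_type` and override a handful of fields on the config object. Below is a condensed sketch of that dispatch, using a `SimpleNamespace` stand-in for the real Megatron `TransformerConfig` and only a subset of the fields set in the diff above; it is an illustration, not the full function.

```python
# Minimal sketch of the per-model-type override pattern used in config.py.
# SimpleNamespace stands in for the real Megatron TransformerConfig here.
from types import SimpleNamespace

import torch


def get_language_model_config(config):
    if config.language_model_type == "llama3_8b":
        config.activation_func = torch.nn.functional.silu
        config.add_bias_linear = False
        config.gated_linear_unit = True
        config.ffn_hidden_size = 14336
    elif config.language_model_type == "qwen2.5_7B":
        config.activation_func = torch.nn.functional.silu
        config.add_bias_linear = False
        config.add_qkv_bias = True  # Qwen2.5 keeps biases on the QKV projections.
        config.ffn_hidden_size = 18944
    else:
        raise ValueError(f"unknown language_model_type {config.language_model_type}")
    return config


config = get_language_model_config(SimpleNamespace(language_model_type="qwen2.5_7B"))
print(config.ffn_hidden_size)  # 18944
```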
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
@@ -2,16 +2,19 @@
 import bisect
 import dataclasses
 import json
+import re
 import sys
 import traceback
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple, Union

 from image_processing import get_visual_transform
+from PIL import Image
+from torchvision.transforms import ToPILImage

 import numpy as np
 import torch
-from megatron.core.models.multimodal.llava_model import IGNORE_INDEX, IMAGE_TOKEN
+from megatron.core.models.multimodal.llava_model import IGNORE_INDEX, IMAGE_TOKEN, VIDEO_TOKEN
 from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings
 from megatron.energon import (
     Batch,
@@ -175,6 +178,10 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
         self.img_h, self.img_w = self.args.img_h, self.args.img_w

+        # This map is used to reduce the number of tiles used per image if the number of tokens is
+        # larger than the decoder_seq_length.
+        self.num_tiles_degradation_map = {12:8, 8:6, 6:4, 4:2, 2:1, 1:1}
+
     def _get_total_seq_length(self, input_ids, num_tiles):
         """Calculate expected sequence length given text tokens length and number of tiles."""
         total_num_images = len(num_tiles)
@@ -237,7 +244,7 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
         prompt_idx = np.random.randint(len(prompt_list))
         cur_prompt = prompt_list[prompt_idx]
-        cur_prompt = "<image>\n" + cur_prompt + "\n"
+        cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt + "\n"

         caption = sample.caption.strip()
@@ -282,7 +289,7 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
         # LLAVA training: override text-prompt with just the image.
         conv = [
             # Note: no system message.
-            {"role": "user", "content": "<image>\n"},
+            {"role": "user", "content": IMAGE_TOKEN + "\n"},
             {"role": "assistant", "content": sample.answers},
         ]
@@ -307,66 +314,130 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
         """Encode SFT sample."""
         augment = sample.__subflavors__['augmentation'] if 'augmentation' in sample.__subflavors__ else False
         has_video = sample.__subflavors__['has_video'] if 'has_video' in sample.__subflavors__ else False
-        has_image = sample.__subflavors__['has_image'] if 'has_image' in sample.__subflavors__ else False
-        has_image = has_image or (hasattr(sample, "images") and len(sample.images) > 0)
-
-        if has_video:
-            # Grab the selected frames of the video as a tensor with shape
-            # fhwc: (num_frames, height, width, num_channels).
-            video_fhwc = sample.images[0].permute(0, 2, 3, 1)
-            selected_frames = torch.linspace(
-                0, video_fhwc.shape[0] - 1, self.args.num_frames).long()
-            video_frame_fhwc = video_fhwc[selected_frames]
-            imgs = []
-            for video_frame_hwc in video_frame_fhwc:
-                imgs += get_visual_transform(
-                    video_frame_hwc, self.img_h, self.img_w,
-                    self.args.use_tiling, self.args.max_num_tiles,
-                    self.args.use_thumbnail, augment, self.args.vision_model_type)
-            num_tiles = [len(imgs)]
-        elif has_image:
-            imgs = get_visual_transform(
-                sample.images[0], self.img_h, self.img_w, self.args.use_tiling, self.args.max_num_tiles, self.args.use_thumbnail, augment,
-                self.args.vision_model_type,
-            )
-            num_tiles = [len(imgs)]
-        else:
-            imgs = num_tiles = []
-            sample.__key__ = "{}-{}".format("no-image", sample.__key__)
-
-        conversation = []
+        has_image = False
+        if hasattr(sample, "images"):
+            # If this is a text-only sample and we are freezing the LM,
+            # then use a dummy input image.
+            if len(sample.images) == 0 and self.args.freeze_LM:
+                empty_img = Image.new('RGB', (self.args.img_w, self.args.img_h), (255, 255, 255))
+                sample.images.append(empty_img)
+            if len(sample.images) > 0 and not has_video:
+                has_image = True
+
         # Note: Some tokenizers may ignore the system prompt.
-        conversation.append({"role": "system", "content": "Answer the questions."})
-
-        has_image_token = False
-
+        conversation = [{"role": "system", "content": "Answer the questions."}]
+        # Format the conversation as a list of "user" / "assistant" turns.
         for text in sample.texts:
-            if IMAGE_TOKEN in text["value"]:
-                has_image_token = True
-
-            if text["from"] == "human":
-                role = "user"
-            elif text["from"] == "gpt":
-                role = "assistant"
-            else:
-                raise RuntimeError(f"unexpected role {text['from']} in {sample.texts}")
-
-            turn = {"role": role, "content": text["value"]}
-            conversation.append(turn)
-
-        # If the sample contains an image but none of the user messages has an image token,
-        # then add it to the first user message.
-        if len(imgs) > 0 and not has_image_token:
+            error_msg = f"unexpected role {text['from']} in {sample.texts}"
+            assert text["from"] in ["human", "gpt"], error_msg
+            conversation.append({
+                "role": "user" if text["from"] == "human" else "assistant",
+                "content": text["value"]})

+        # Replace the image tags <image-idx> with IMAGE_TOKEN and count the number of image tags.
+        number_image_tags = 0
+        image_tag_ids_list = []
+        for turn in conversation:
+            if turn["role"] == "user":
+                image_tag_ids = [int(x) - 1 for x in re.findall(r"<image-(\d+)>", turn["content"])]
+                image_tag_ids_list.extend(image_tag_ids)
+                turn["content"] = re.sub(r"<image-\d+>", IMAGE_TOKEN, turn["content"])
+                number_image_tags += turn["content"].count(IMAGE_TOKEN)
+                # For videos, we replace the image tag with the video tag.
+                if has_video:
+                    turn["content"] = turn["content"].replace(IMAGE_TOKEN, VIDEO_TOKEN)
+
+        # We re-order the images in sample.images according to how they appear in the conversation.
+        if len(image_tag_ids_list) > 0:
+            sample.images = [sample.images[idx] for idx in image_tag_ids_list]
+
+        # If there is only one image, but several image tags, we assume all the tags refer to the
+        # same image and duplicate the image:
+        if len(sample.images) == 1 and number_image_tags > 1:
+            sample.images = sample.images * number_image_tags
+
+        number_of_images = len(sample.images)
+        # Fail if there are more image or video tags than images or videos:
+        error_msg = (
+            f"Found {number_image_tags} image tags for {number_of_images} images. {sample.texts}")
+        assert number_image_tags <= number_of_images, error_msg
+
+        # If there are fewer image or video tags than images or videos, prepend the tags to the
+        # first user message:
+        if number_image_tags < number_of_images:
             for turn in conversation:
                 if turn["role"] == "user":
-                    turn["content"] = f"{IMAGE_TOKEN}\n" + turn["content"]
+                    tag_to_add = VIDEO_TOKEN if has_video else IMAGE_TOKEN
+                    turn["content"] = tag_to_add*(number_of_images-number_image_tags) + "\n" + turn["content"]
                     break

         input_ids, target = self.tokenizer.tokenize_conversation(conversation, True, False)

+        if has_image:
+            imgs = []
+            num_tiles = []
+            max_num_tiles = self.args.max_num_tiles
+            # We keep a buffer of 4 tokens for the question,
+            # the rest can be used for image tokens.
+            max_image_token_allowed = self.args.decoder_seq_length - len(input_ids) - 4
+            # We start by extracting as many tiles per image as possible, and decrease the max
+            # number of tiles if there are too many image tokens.
+            while True:
+                imgs = []
+                num_tiles = []
+                for img in sample.images:
+                    img_tiles = get_visual_transform(
+                        img, self.img_h, self.img_w, self.args.use_tiling, max_num_tiles,
+                        self.args.use_thumbnail, augment, self.args.vision_model_type)
+                    imgs += img_tiles
+                    num_tiles += [len(img_tiles)]
+                if max_num_tiles == 1:
+                    break
+                if sum(num_tiles) * self.token_per_img_tile > max_image_token_allowed:
+                    if max_num_tiles in self.num_tiles_degradation_map:
+                        max_num_tiles = self.num_tiles_degradation_map[max_num_tiles]
+                    else:
+                        raise RuntimeError((
+                            f"Tried to decrease the number of tiles {max_num_tiles} but it's not ",
+                            f"defined in the degradation map {self.num_tiles_degradation_map}"))
+                else:
+                    break
+        elif has_video:
+            # We don't use tiling for videos to limit the number of tokens.
+            use_tiling=False
+            # Grab the selected frames of the video as a tensor with shape
+            # fchw: (num_frames, num_channels, height, width).
+            video_fchw = sample.images[0].permute(0, 1, 2, 3)
+            selected_frames = torch.linspace(
+                0, video_fchw.shape[0] - 1, self.args.num_frames).long()
+            video_fchw = video_fchw[selected_frames]
+            imgs = []
+            for video_chw in video_fchw:
+                to_pil = ToPILImage()
+                video_chw = to_pil(video_chw)
+                imgs += get_visual_transform(
+                    video_chw, self.img_h, self.img_w, use_tiling, self.args.max_num_tiles,
+                    self.args.use_thumbnail, augment, self.args.vision_model_type)
+            num_tiles = [len(imgs)]
+        else:
+            imgs = num_tiles = []
+
         if self.is_packing_enabled:
             input_ids, target = self._truncate_for_packing(input_ids, target, num_tiles)

+        # Some final checks with respect to the number of image tokens and images on the tokenized
+        # conversation. There can still be errors, for instance if a non-video sample happens to
+        # have our pre-defined video token, or if the packing truncation removed a necessary image
+        # tag.
+        number_image_token = np.sum(input_ids == self.img_token_id)
+        error_msg = (
+            f"Found {number_image_token} image tokens for len({num_tiles}) = {len(num_tiles)} image tiles in {conversation}.")
+        assert number_image_token == len(num_tiles), error_msg
+
+        error_msg = (
+            f"Found sum({num_tiles}) = {np.sum(num_tiles)} tiles for {len(imgs)} images in {conversation}.")
+        assert np.sum(num_tiles) == len(imgs), error_msg
+
         return ImageTaskSample(
             __key__=sample.__key__,
             __restore_key__=sample.__restore_key__,
@@ -407,8 +478,8 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
         if isinstance(sample, MultiChoiceVQASample):
             cur_prompt = format_multichoice_question(sample.context, sample.choices)
-            if "<image>" not in cur_prompt:
-                cur_prompt = "<image>\n" + cur_prompt
+            if IMAGE_TOKEN not in cur_prompt:
+                cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt
             cur_answer = format_multichoice_answer(sample.correct_choice_idx)
         elif isinstance(sample, VQASample):
             if 'docvqa' in sample.__key__:
@@ -423,8 +494,8 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
             cur_prompt = cur_prompt.format(sample.context)
-            if "<image>" not in cur_prompt:
-                cur_prompt = "<image>\n" + cur_prompt
+            if IMAGE_TOKEN not in cur_prompt:
+                cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt

             if isinstance(sample.answers, list):
                 answer_list = sample.answers
@@ -505,11 +576,11 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
         prompt_list = self.manual_prompts["DocPretraining"]["raw"]
         prompt_idx = np.random.randint(len(prompt_list))
         cur_prompt = prompt_list[prompt_idx]
-        if "<image>" not in cur_prompt:
-            cur_prompt = "<image>\n" + cur_prompt
+        if IMAGE_TOKEN not in cur_prompt:
+            cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt

-        # Make sure there is no extra <image> tag.
-        sample.text = sample.text.replace("<image>", "")
+        # Make sure there is no extra IMAGE_TOKEN tag.
+        sample.text = sample.text.replace(IMAGE_TOKEN, "")

         caption = sample.text.strip()
@@ -526,8 +597,8 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
         ref = sample.text
         region = sample.words_boxes

-        # Make sure there is no extra <image> tag
-        ref = ref.replace("<image>", "")
+        # Make sure there is no extra IMAGE_TOKEN tag
+        ref = ref.replace(IMAGE_TOKEN, "")

         if len(region) == 4:
             region = f"<box>({region[0]},{region[1]}),({region[2]},{region[3]})</box>"
@@ -550,8 +621,8 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
         prompt_idx = np.random.randint(len(prompt_list))
         cur_prompt = prompt_list[prompt_idx]
         cur_prompt = cur_prompt.format(prompt_content)
-        if "<image>" not in cur_prompt:
-            cur_prompt = "<image>\n" + cur_prompt
+        if IMAGE_TOKEN not in cur_prompt:
+            cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt

         return sample, cur_prompt, answer
@@ -559,8 +630,8 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
         """Format bbox coordinates as text."""
         assert len(bbox) == 4 or len(bbox) == 8

-        # Make sure there is no extra <image> tag
-        text = text.replace("<image>", "")
+        # Make sure there is no extra IMAGE_TOKEN tag
+        text = text.replace(IMAGE_TOKEN, "")

         if len(bbox) == 4:
             label_str = f"<ref>{text}</ref><box>({bbox[0]},{bbox[1]}),({bbox[2]},{bbox[3]})</box>"
@@ -582,8 +653,8 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
         prompt_idx = np.random.randint(len(prompt_list))
         cur_prompt = prompt_list[prompt_idx]
-        if "<image>" not in cur_prompt:
-            cur_prompt = "<image>\n" + cur_prompt
+        if IMAGE_TOKEN not in cur_prompt:
+            cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt

         cur_answer = answer

         return sample, cur_prompt, cur_answer
......
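The core of the new SFT image handling above is the tile budget loop: tiles are extracted at `max_num_tiles`, and if the resulting image tokens would exceed the remaining sequence budget, the tile count is stepped down through `num_tiles_degradation_map` and the extraction is retried. Below is a standalone sketch of that idea under stated assumptions: `fake_visual_transform`, `TOKENS_PER_TILE`, and `MAX_IMAGE_TOKENS_ALLOWED` are hypothetical stand-ins for `get_visual_transform`, `self.token_per_img_tile`, and the computed `max_image_token_allowed`.

```python
# Sketch of the tile-degradation loop from the SFT encoder. The constants and the
# fake transform below are illustrative stand-ins, not the real values.
NUM_TILES_DEGRADATION_MAP = {12: 8, 8: 6, 6: 4, 4: 2, 2: 1, 1: 1}
TOKENS_PER_TILE = 576            # assumed image tokens produced per tile
MAX_IMAGE_TOKENS_ALLOWED = 4096  # assumed budget left after the text tokens


def fake_visual_transform(image, max_num_tiles):
    """Pretend to tile an image: return one 'tile' per allowed slot."""
    return [image] * max_num_tiles


def tile_images(images, max_num_tiles):
    while True:
        imgs, num_tiles = [], []
        for img in images:
            img_tiles = fake_visual_transform(img, max_num_tiles)
            imgs += img_tiles
            num_tiles.append(len(img_tiles))
        if max_num_tiles == 1:
            break
        if sum(num_tiles) * TOKENS_PER_TILE > MAX_IMAGE_TOKENS_ALLOWED:
            # Too many image tokens: step the tile count down and retry.
            # (The real code raises RuntimeError if the value is not in the map.)
            max_num_tiles = NUM_TILES_DEGRADATION_MAP[max_num_tiles]
        else:
            break
    return imgs, num_tiles


imgs, num_tiles = tile_images(images=["img0", "img1"], max_num_tiles=12)
print(num_tiles)  # [2, 2]: stepped down 12 -> 8 -> 6 -> 4 -> 2 until 4 * 576 <= 4096
```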
@@ -9,19 +9,25 @@ def merge_input_files(input_path):
     """Merge input files to a format compatible with the evaluator."""
     input_file_paths, output_file_path = get_input_output_paths(input_path, task="AI2D")

-    results = []
+    results = dict()

     for input_file_path in input_file_paths:
         with open(input_file_path, "r") as input_file:
             for line in input_file:
                 res = json.loads(line)
-                results.append(
-                    {
-                        "question_id": res["sample_id"],
-                        "answer": res["answer"],
-                        "gt_answer": res["gt_answer"],
-                    }
-                )
+                sample_id = res["sample_id"]
+
+                # Ignore possible duplicates.
+                if sample_id in results:
+                    continue
+
+                results[sample_id] = {
+                    "question_id": sample_id,
+                    "answer": res["answer"],
+                    "gt_answer": res["gt_answer"],
+                }
+
+    results = list(results.values())

     with open(output_file_path, "w") as output_file:
         json.dump(results, output_file)
......
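This and the following result-merging scripts all make the same change: instead of appending every line to a list, they key a dict by `sample_id`, so any duplicated `sample_id` entries are written out only once, and then dump the values as a list. A generic, self-contained sketch of that pattern follows; the `results/*.jsonl` and `merged.json` paths are hypothetical, while the real scripts resolve their paths via `get_input_output_paths`.

```python
# Generic sketch of the duplicate-dropping merge used by the evaluation scripts.
import glob
import json

results = dict()
for input_file_path in sorted(glob.glob("results/*.jsonl")):  # hypothetical layout
    with open(input_file_path, "r") as input_file:
        for line in input_file:
            res = json.loads(line)
            sample_id = res["sample_id"]
            if sample_id in results:  # Ignore possible duplicates.
                continue
            results[sample_id] = {
                "question_id": sample_id,
                "answer": res["answer"],
                "gt_answer": res["gt_answer"],
            }

with open("merged.json", "w") as output_file:
    json.dump(list(results.values()), output_file)
```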
@@ -9,15 +9,22 @@ def merge_input_files(input_path):
     """Merge input files to a format compatible with the evaluator."""
     input_file_paths, output_file_path = get_input_output_paths(input_path, task="ChartQA")

-    results = []
+    results = dict()

     for input_file_path in input_file_paths:
         with open(input_file_path, "r") as input_file:
             for line in input_file:
                 res = json.loads(line)
-                res["question_id"] = res["sample_id"]
-                results.append(res)
+                sample_id = res["sample_id"]
+
+                # Ignore possible duplicates.
+                if sample_id in results:
+                    continue
+
+                res["question_id"] = sample_id
+                results[sample_id] = res
+
+    results = list(results.values())

     with open(output_file_path, "w") as output_file:
         json.dump(results, output_file)
......
@@ -11,20 +11,28 @@ def convert_to_coco_format(input_path):
     """Convert input files to COCO compatible format."""
     input_file_paths, output_file_path = get_input_output_paths(input_path, task="captioning")

-    captions = []
+    results = dict()

     for input_file_path in input_file_paths:
         with open(input_file_path, "r") as input_file:
             for line in input_file:
                 res = json.loads(line)
+                sample_id = res["sample_id"]

-                question_id = res['sample_id']
-                caption = res['caption'].rstrip('.').lower()
+                # Ignore possible duplicates.
+                if sample_id in results:
+                    continue

-                captions.append({"image_id": question_id, "caption": caption})
+                caption = res["caption"].rstrip(".").lower()
+
+                results[sample_id] = {
+                    "image_id": sample_id,
+                    "caption": caption,
+                }
+
+    results = list(results.values())

     with open(output_file_path, "w") as output_file:
-        json.dump(captions, output_file, indent=4)
+        json.dump(results, output_file, indent=4)

     return output_file_path
......
@@ -11,13 +11,21 @@ def merge_input_files(input_path):
     """Merge input files to a format compatible with the evaluator."""
     input_file_paths, output_file_path = get_input_output_paths(input_path, task="MathVista")

-    results = []
+    results = dict()

     for input_file_path in input_file_paths:
         with open(input_file_path, "r") as input_file:
             for line in input_file:
                 res = json.loads(line)
-                results.append(res)
+                sample_id = res["sample_id"]
+
+                # Remove possible duplicates.
+                if sample_id in results:
+                    continue
+
+                results[sample_id] = res
+
+    results = list(results.values())

     with open(output_file_path, "w") as output_file:
         json.dump(results, output_file)
......
@@ -2,9 +2,15 @@ import argparse
 import glob
 import json
 import os
+import sys
 import re
 import subprocess

+# Get the absolute path of the parent directory.
+parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
+# Add the parent directory to sys.path.
+sys.path.insert(0, parent_dir)
+
 from run_text_generation import get_output_path
 from config import EvaluationConfig
@@ -48,6 +54,10 @@ def convert_to_mmmu_format(input_path):
                 )

                 # MMMU eval script expects just a sample_id to prediction mapping.
+                # Skip possible duplicates.
+                if sample_id in output:
+                    continue
+
                 output[sample_id] = prediction

     with open(output_file_path, "w") as output_file:
......
@@ -8,13 +8,21 @@ def merge_input_files(input_path):
     """Merge input files to a format compatible with the evaluator."""
     input_file_paths, output_file_path = get_input_output_paths(input_path, task="OCRBench")

-    results = []
+    results = dict()

     for input_file_path in input_file_paths:
         with open(input_file_path, "r") as input_file:
             for line in input_file:
                 res = json.loads(line)
-                results.append(res)
+                sample_id = res["sample_id"]
+
+                # Remove possible duplicates.
+                if sample_id in results:
+                    continue
+
+                results[sample_id] = res
+
+    results = list(results.values())

     with open(output_file_path, "w") as output_file:
         json.dump(results, output_file)
......
@@ -9,22 +9,25 @@ def merge_input_files(input_path):
     """Merge input files to a format compatible with the evaluator."""
     input_file_paths, output_file_path = get_input_output_paths(input_path, task="TextVQA")

-    results = []
+    results = dict()

     for input_file_path in input_file_paths:
         with open(input_file_path, "r") as input_file:
             for line in input_file:
                 res = json.loads(line)
-                results.append(
-                    {
-                        "question_id": res["sample_id"],
-                        "answer": res["answer"],
-                        "gt_answer": res["gt_answer"],
-                    }
-                )
-
-    # Make order deterministic.
-    # results = sorted(results, key=lambda d: d["question_id"])
+                sample_id = res["sample_id"]
+
+                # Remove possible duplicates.
+                if sample_id in results:
+                    continue
+
+                results[sample_id] = {
+                    "question_id": sample_id,
+                    "answer": res["answer"],
+                    "gt_answer": res["gt_answer"],
+                }
+
+    results = list(results.values())

     with open(output_file_path, "w") as output_file:
         json.dump(results, output_file)
......