Commit 0d99ae1f authored by silencealiang

add

parent c271aaae
Pipeline #2498 canceled
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
......@@ -16,7 +16,7 @@ You can build a docker container using `examples/multimodal/Dockerfile` to run t
### Language model
Follow the instructions in [Mistral](../../docs/llama_mistral.md#mistral-7b) to download weights for Mistral-7B-Instruct-v0.3 (Base or Instruct) from HuggingFace and convert to mcore format with tensor parallel size 4.
Follow the instructions in [Mistral](../../docs/llama_mistral.md#mistral-7b) to download weights for Mistral-7B-Instruct-v0.3 from HuggingFace and convert to mcore format with tensor parallel size 4.
Please use the tokenizer from HuggingFace.
### Vision model
......@@ -113,7 +113,7 @@ Run the following script:
```
examples/multimodal/text_generation_mistral_clip.sh --input-image-path /path/to/input/images --output-path /some/output/directory \
--model-path /path/to/model.pt --tokenizer-path /path/to/tokenizer/ --gt-path /path/to/groundtruth/file --task generation-task-name
--model-path /path/to/model.pt --gt-path /path/to/groundtruth/file --task generation-task-name
```
where `--task generation-task-name` is the name of the evaluation benchmark such as `captioning` or `MMMU`.
......
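For the tokenizer step above, a minimal sketch of loading the Mistral tokenizer from HuggingFace with `transformers` (the repository ID is an assumption; point it at whichever Mistral-7B-Instruct-v0.3 checkpoint you downloaded):

```python
# Minimal sketch: load the HuggingFace tokenizer mentioned in the README above.
# The repository ID is assumed; gated checkpoints may require `huggingface-cli login`.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
print(tokenizer("Describe this image.", return_tensors="pt"))
```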
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
......@@ -7,34 +7,20 @@ from megatron.training.activations import fast_gelu, quick_gelu, squared_relu
def get_language_model_config(config):
if config.language_model_type == "2b":
if config.language_model_type == "llama3_8b":
config.activation_func = torch.nn.functional.silu
config.add_bias_linear = False
config.bias_activation_fusion = False
config.gated_linear_unit = True
config.apply_query_key_layer_scaling = True
config.layernorm_zero_centered_gamma = True
config.bias_dropout_fusion = False
config.rotary_percent = 0.5
config.apply_rope_fusion = False
config.attention_softmax_in_fp32 = True
elif config.language_model_type == "8b":
config.add_bias_linear = False
config.bias_activation_fusion = False
config.gated_linear_unit = False
config.apply_query_key_layer_scaling = True
config.layernorm_zero_centered_gamma = True
config.apply_query_key_layer_scaling = False
config.layernorm_zero_centered_gamma = (
False # Zero centered gamma not supported for RMSNorm
)
config.bias_dropout_fusion = False
config.rotary_percent = 0.5
config.attention_dropout = 0.0
config.apply_rope_fusion = False
config.activation_func = squared_relu
config.ffn_hidden_size = 16384
config.masked_softmax_fusion = True
config.attention_softmax_in_fp32 = True
config.num_query_groups = 32
config.kv_channels = 128
config.rotary_interleaved = False
elif config.language_model_type == "llama3_8b":
config.ffn_hidden_size = 14336
elif config.language_model_type == "mistral_7b":
config.activation_func = torch.nn.functional.silu
config.add_bias_linear = False
config.bias_activation_fusion = False
......@@ -47,7 +33,7 @@ def get_language_model_config(config):
config.apply_rope_fusion = False
config.attention_softmax_in_fp32 = True
config.ffn_hidden_size = 14336
elif config.language_model_type == "mistral_7b":
elif config.language_model_type == "yi-34b":
config.activation_func = torch.nn.functional.silu
config.add_bias_linear = False
config.bias_activation_fusion = False
......@@ -59,10 +45,11 @@ def get_language_model_config(config):
config.bias_dropout_fusion = False
config.apply_rope_fusion = False
config.attention_softmax_in_fp32 = True
config.ffn_hidden_size = 14336
elif config.language_model_type == "yi-34b":
config.ffn_hidden_size = 20480
elif config.language_model_type == "qwen2.5_7B":
config.activation_func = torch.nn.functional.silu
config.add_bias_linear = False
config.add_qkv_bias = True
config.bias_activation_fusion = False
config.gated_linear_unit = True
config.apply_query_key_layer_scaling = False
......@@ -72,7 +59,7 @@ def get_language_model_config(config):
config.bias_dropout_fusion = False
config.apply_rope_fusion = False
config.attention_softmax_in_fp32 = True
config.ffn_hidden_size = 20480
config.ffn_hidden_size = 18944
elif config.language_model_type == "qwen2.0_72B":
config.activation_func = torch.nn.functional.silu
config.add_bias_linear = False
......@@ -168,13 +155,7 @@ def get_vision_projection_config(config, hidden_size):
config.bias_activation_fusion = False
config.add_bias_linear = False
config.hidden_size = hidden_size # Used as the vision projection output size, i.e., the input to the language model.
if config.language_model_type == "2b":
config.ffn_hidden_size = 5440
config.activation_func = torch.nn.functional.gelu
if config.language_model_type == "8b":
config.ffn_hidden_size = 16384
config.activation_func = squared_relu
elif config.language_model_type == "llama3_8b":
if config.language_model_type == "llama3_8b":
config.ffn_hidden_size = 14336
config.activation_func = torch.nn.functional.gelu
elif config.language_model_type == "mistral_7b":
......@@ -185,6 +166,9 @@ def get_vision_projection_config(config, hidden_size):
config.ffn_hidden_size = 20480
config.normalization = "LayerNorm"
config.activation_func = torch.nn.functional.gelu
elif config.language_model_type == "qwen2.5_7B":
config.ffn_hidden_size = 3584
config.activation_func = torch.nn.functional.gelu
elif config.language_model_type == "qwen2.0_72B":
config.ffn_hidden_size = 29568
config.normalization = "LayerNorm"
......
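As a rough, self-contained sketch of the dispatch pattern in `get_language_model_config` above: the function branches on `config.language_model_type` and fills in architecture-specific fields. `SimpleNamespace` stands in for the real Megatron transformer config, and only a few of the fields set in the diff are reproduced:

```python
# Sketch of the per-model dispatch used by get_language_model_config above.
# SimpleNamespace is a stand-in for the real Megatron transformer config object.
from types import SimpleNamespace

import torch


def sketch_language_model_config(config):
    # Branch on the model type and set a couple of the architecture-specific
    # fields shown in the diff (only the activation and ffn_hidden_size here).
    if config.language_model_type == "llama3_8b":
        config.activation_func = torch.nn.functional.silu
        config.ffn_hidden_size = 14336
    elif config.language_model_type == "mistral_7b":
        config.activation_func = torch.nn.functional.silu
        config.ffn_hidden_size = 14336
    elif config.language_model_type == "qwen2.5_7B":
        config.activation_func = torch.nn.functional.silu
        config.ffn_hidden_size = 18944
    else:
        raise ValueError(f"unknown language_model_type {config.language_model_type}")
    return config


cfg = sketch_language_model_config(SimpleNamespace(language_model_type="qwen2.5_7B"))
print(cfg.ffn_hidden_size)  # 18944
```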
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
......@@ -2,16 +2,19 @@
import bisect
import dataclasses
import json
import re
import sys
import traceback
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
from image_processing import get_visual_transform
from PIL import Image
from torchvision.transforms import ToPILImage
import numpy as np
import torch
from megatron.core.models.multimodal.llava_model import IGNORE_INDEX, IMAGE_TOKEN
from megatron.core.models.multimodal.llava_model import IGNORE_INDEX, IMAGE_TOKEN, VIDEO_TOKEN
from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings
from megatron.energon import (
Batch,
......@@ -175,6 +178,10 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
self.img_h, self.img_w = self.args.img_h, self.args.img_w
# This map is used to reduce the number of tiles used per image if the number of tokens is
# larger than the decoder_seq_length.
self.num_tiles_degradation_map = {12:8, 8:6, 6:4, 4:2, 2:1, 1:1}
def _get_total_seq_length(self, input_ids, num_tiles):
"""Calculate expected sequence length given text tokens length and number of tiles."""
total_num_images = len(num_tiles)
......@@ -237,7 +244,7 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
prompt_idx = np.random.randint(len(prompt_list))
cur_prompt = prompt_list[prompt_idx]
cur_prompt = "<image>\n" + cur_prompt + "\n"
cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt + "\n"
caption = sample.caption.strip()
......@@ -282,7 +289,7 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
# LLAVA training: override text-prompt with just the image.
conv = [
# Note: no system message.
{"role": "user", "content": "<image>\n"},
{"role": "user", "content": IMAGE_TOKEN + "\n"},
{"role": "assistant", "content": sample.answers},
]
......@@ -307,66 +314,130 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
"""Encode SFT sample."""
augment = sample.__subflavors__['augmentation'] if 'augmentation' in sample.__subflavors__ else False
has_video = sample.__subflavors__['has_video'] if 'has_video' in sample.__subflavors__ else False
has_image = sample.__subflavors__['has_image'] if 'has_image' in sample.__subflavors__ else False
has_image = has_image or (hasattr(sample, "images") and len(sample.images) > 0)
if has_video:
# Grab the selected frames of the video as a tensor with shape
# fhwc: (num_frames, height, width, num_channels).
video_fhwc = sample.images[0].permute(0, 2, 3, 1)
selected_frames = torch.linspace(
0, video_fhwc.shape[0] - 1, self.args.num_frames).long()
video_frame_fhwc = video_fhwc[selected_frames]
imgs = []
for video_frame_hwc in video_frame_fhwc:
imgs += get_visual_transform(
video_frame_hwc, self.img_h, self.img_w,
self.args.use_tiling, self.args.max_num_tiles,
self.args.use_thumbnail, augment, self.args.vision_model_type)
num_tiles = [len(imgs)]
elif has_image:
imgs = get_visual_transform(
sample.images[0], self.img_h, self.img_w, self.args.use_tiling, self.args.max_num_tiles, self.args.use_thumbnail, augment,
self.args.vision_model_type,
)
num_tiles = [len(imgs)]
else:
imgs = num_tiles = []
sample.__key__ = "{}-{}".format("no-image", sample.__key__)
has_image = False
if hasattr(sample, "images"):
# If this is a text-only sample and we are freezing the LM,
# then use a dummy input image.
if len(sample.images) == 0 and self.args.freeze_LM:
empty_img = Image.new('RGB', (self.args.img_w, self.args.img_h), (255, 255, 255))
sample.images.append(empty_img)
if len(sample.images) > 0 and not has_video:
has_image = True
conversation = []
# Note: Some tokenizers may ignore the system prompt.
conversation.append({"role": "system", "content": "Answer the questions."})
has_image_token = False
conversation = [{"role": "system", "content": "Answer the questions."}]
# Format the conversation as a list of "user" / "assistant" turns.
for text in sample.texts:
if IMAGE_TOKEN in text["value"]:
has_image_token = True
if text["from"] == "human":
role = "user"
elif text["from"] == "gpt":
role = "assistant"
else:
raise RuntimeError(f"unexpected role {text['from']} in {sample.texts}")
turn = {"role": role, "content": text["value"]}
conversation.append(turn)
# If the sample contains an image but none of the user messages has an image token,
# then add it to the first user message.
if len(imgs) > 0 and not has_image_token:
error_msg = f"unexpected role {text['from']} in {sample.texts}"
assert text["from"] in ["human", "gpt"], error_msg
conversation.append({
"role": "user" if text["from"] == "human" else "assistant",
"content": text["value"]})
# Replace the image tags <image-idx> with IMAGE_TOKEN and count the number of image tags
number_image_tags = 0
image_tag_ids_list = []
for turn in conversation:
if turn["role"] == "user":
image_tag_ids = [int(x) - 1 for x in re.findall(r"<image-(\d+)>", turn["content"])]
image_tag_ids_list.extend(image_tag_ids)
turn["content"] = re.sub(r"<image-\d+>", IMAGE_TOKEN, turn["content"])
number_image_tags += turn["content"].count(IMAGE_TOKEN)
# For videos, we replace the image tag with the video tag
if has_video:
turn["content"] = turn["content"].replace(IMAGE_TOKEN, VIDEO_TOKEN)
# We re-order the images in sample.images according to how they appear in the conversation.
if len(image_tag_ids_list) > 0:
sample.images = [sample.images[idx] for idx in image_tag_ids_list]
# If there is only one image, but several image tags, we assume all the tags refer to the
# same image and duplicate the image:
if len(sample.images) == 1 and number_image_tags > 1:
sample.images = sample.images * number_image_tags
number_of_images = len(sample.images)
# Fail if there are more image or video tags than images or videos:
error_msg = (
f"Found {number_image_tags} image tags for {number_of_images} images. {sample.texts}")
assert number_image_tags <= number_of_images, error_msg
# If there are fewer image or video tags than images or videos, prepend the missing tags
# to the first user message:
if number_image_tags < number_of_images:
for turn in conversation:
if turn["role"] == "user":
turn["content"] = f"{IMAGE_TOKEN}\n" + turn["content"]
tag_to_add = VIDEO_TOKEN if has_video else IMAGE_TOKEN
turn["content"] = tag_to_add*(number_of_images-number_image_tags) + "\n" + turn["content"]
break
input_ids, target = self.tokenizer.tokenize_conversation(conversation, True, False)
if has_image:
imgs = []
num_tiles = []
max_num_tiles = self.args.max_num_tiles
# We keep a buffer of 4 tokens for the question,
# the rest can be used for image tokens.
max_image_token_allowed = self.args.decoder_seq_length - len(input_ids) - 4
# We start by extracting as many tiles per image as possible, and decrease the max
# number of tiles if there are too many image tokens.
while True:
imgs = []
num_tiles = []
for img in sample.images:
img_tiles = get_visual_transform(
img, self.img_h, self.img_w, self.args.use_tiling, max_num_tiles,
self.args.use_thumbnail, augment, self.args.vision_model_type)
imgs += img_tiles
num_tiles += [len(img_tiles)]
if max_num_tiles == 1:
break
if sum(num_tiles) * self.token_per_img_tile > max_image_token_allowed:
if max_num_tiles in self.num_tiles_degradation_map:
max_num_tiles = self.num_tiles_degradation_map[max_num_tiles]
else:
raise RuntimeError(
f"Tried to decrease the number of tiles {max_num_tiles} but it's not "
f"defined in the degradation map {self.num_tiles_degradation_map}")
else:
break
elif has_video:
# We don't use tiling for videos to limit the number of tokens.
use_tiling=False
# Grab the selected frames of the video as a tensor with shape
# fhwc: (num_frames, num_channels, height, width).
video_fchw = sample.images[0].permute(0, 1, 2, 3)
selected_frames = torch.linspace(
0, video_fchw.shape[0] - 1, self.args.num_frames).long()
video_fchw = video_fchw[selected_frames]
imgs = []
for video_chw in video_fchw:
to_pil = ToPILImage()
video_chw = to_pil(video_chw)
imgs += get_visual_transform(
video_chw, self.img_h, self.img_w, use_tiling, self.args.max_num_tiles,
self.args.use_thumbnail, augment, self.args.vision_model_type)
num_tiles = [len(imgs)]
else:
imgs = num_tiles = []
if self.is_packing_enabled:
input_ids, target = self._truncate_for_packing(input_ids, target, num_tiles)
# Some final checks with respect to the number of image tokens and images on the tokenized
# conversation. There can still be errors, for instance if a non-video sample happens to
# have our pre-defined video token, or if the packing truncation removed a necessary image
# tag.
number_image_token = np.sum(input_ids == self.img_token_id)
error_msg = (
f"Found {number_image_token} image tokens for len({num_tiles}) = {len(num_tiles)} image tiles in {conversation}.")
assert number_image_token == len(num_tiles), error_msg
error_msg = (
f"Found sum({num_tiles}) = {np.sum(num_tiles)} tiles for {len(imgs)} images in {conversation}.")
assert np.sum(num_tiles) == len(imgs), error_msg
return ImageTaskSample(
__key__=sample.__key__,
__restore_key__=sample.__restore_key__,
......@@ -407,8 +478,8 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
if isinstance(sample, MultiChoiceVQASample):
cur_prompt = format_multichoice_question(sample.context, sample.choices)
if "<image>" not in cur_prompt:
cur_prompt = "<image>\n" + cur_prompt
if IMAGE_TOKEN not in cur_prompt:
cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt
cur_answer = format_multichoice_answer(sample.correct_choice_idx)
elif isinstance(sample, VQASample):
if 'docvqa' in sample.__key__:
......@@ -423,8 +494,8 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
cur_prompt = cur_prompt.format(sample.context)
if "<image>" not in cur_prompt:
cur_prompt = "<image>\n" + cur_prompt
if IMAGE_TOKEN not in cur_prompt:
cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt
if isinstance(sample.answers, list):
answer_list = sample.answers
......@@ -505,11 +576,11 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
prompt_list = self.manual_prompts["DocPretraining"]["raw"]
prompt_idx = np.random.randint(len(prompt_list))
cur_prompt = prompt_list[prompt_idx]
if "<image>" not in cur_prompt:
cur_prompt = "<image>\n" + cur_prompt
if IMAGE_TOKEN not in cur_prompt:
cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt
# Make sure there is no extra <image> tag.
sample.text = sample.text.replace("<image>", "")
# Make sure there is no extra IMAGE_TOKEN tag.
sample.text = sample.text.replace(IMAGE_TOKEN, "")
caption = sample.text.strip()
......@@ -526,8 +597,8 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
ref = sample.text
region = sample.words_boxes
# Make sure there is no extra <image> tag
ref = ref.replace("<image>", "")
# Make sure there is no extra IMAGE_TOKEN tag
ref = ref.replace(IMAGE_TOKEN, "")
if len(region) == 4:
region = f"<box>({region[0]},{region[1]}),({region[2]},{region[3]})</box>"
......@@ -550,8 +621,8 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
prompt_idx = np.random.randint(len(prompt_list))
cur_prompt = prompt_list[prompt_idx]
cur_prompt = cur_prompt.format(prompt_content)
if "<image>" not in cur_prompt:
cur_prompt = "<image>\n" + cur_prompt
if IMAGE_TOKEN not in cur_prompt:
cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt
return sample, cur_prompt, answer
......@@ -559,8 +630,8 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
"""Format bbox coordinates as text."""
assert len(bbox) == 4 or len(bbox) == 8
# Make sure there is no extra <image> tag
text = text.replace("<image>", "")
# Make sure there is no extra IMAGE_TOKEN tag
text = text.replace(IMAGE_TOKEN, "")
if len(bbox) == 4:
label_str = f"<ref>{text}</ref><box>({bbox[0]},{bbox[1]}),({bbox[2]},{bbox[3]})</box>"
......@@ -582,8 +653,8 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
prompt_idx = np.random.randint(len(prompt_list))
cur_prompt = prompt_list[prompt_idx]
if "<image>" not in cur_prompt:
cur_prompt = "<image>\n" + cur_prompt
if IMAGE_TOKEN not in cur_prompt:
cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt
cur_answer = answer
return sample, cur_prompt, cur_answer
......
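The most involved change in `dataset.py` above is the tile-degradation loop in `encode_sft_sample`: it starts from `args.max_num_tiles`, and whenever the resulting image tokens would overflow the remaining decoder sequence budget it steps the tile cap down through `num_tiles_degradation_map`. A standalone sketch of that logic (the function name, the assumption that every image yields exactly the capped number of tiles, and the example numbers are illustrative):

```python
# Standalone sketch of the tile-degradation loop added in encode_sft_sample.
NUM_TILES_DEGRADATION_MAP = {12: 8, 8: 6, 6: 4, 4: 2, 2: 1, 1: 1}


def pick_num_tiles(num_images, max_num_tiles, tokens_per_tile, max_image_tokens_allowed):
    """Step max_num_tiles down until the image tokens fit the remaining budget."""
    while True:
        # Assume every image produces exactly max_num_tiles tiles; the real code
        # calls get_visual_transform() and counts the tiles it actually returns.
        total_tiles = num_images * max_num_tiles
        if max_num_tiles == 1 or total_tiles * tokens_per_tile <= max_image_tokens_allowed:
            return max_num_tiles
        if max_num_tiles not in NUM_TILES_DEGRADATION_MAP:
            raise RuntimeError(
                f"Tried to decrease the number of tiles {max_num_tiles} but it's not "
                f"defined in the degradation map {NUM_TILES_DEGRADATION_MAP}")
        max_num_tiles = NUM_TILES_DEGRADATION_MAP[max_num_tiles]


# Example: 3 images, 256 tokens per tile, room for 2048 image tokens -> 2 tiles per image.
print(pick_num_tiles(num_images=3, max_num_tiles=12, tokens_per_tile=256,
                     max_image_tokens_allowed=2048))
```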
......@@ -9,19 +9,25 @@ def merge_input_files(input_path):
"""Merge input files to a format compatible with the evaluator."""
input_file_paths, output_file_path = get_input_output_paths(input_path, task="AI2D")
results = []
results = dict()
for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
results.append(
{
"question_id": res["sample_id"],
"answer": res["answer"],
"gt_answer": res["gt_answer"],
}
)
sample_id = res["sample_id"]
# Ignore possible duplicates.
if sample_id in results:
continue
results[sample_id] = {
"question_id": sample_id,
"answer": res["answer"],
"gt_answer": res["gt_answer"],
}
results = list(results.values())
with open(output_file_path, "w") as output_file:
json.dump(results, output_file)
......
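The ChartQA, captioning, MathVista, MMMU, OCRBench, and TextVQA hunks below apply the same change as the AI2D script above: results are collected into a dict keyed by `sample_id` so that duplicate generations are dropped before the merged file is written. A self-contained sketch of that shared pattern (the helper name and the paths in the usage comment are illustrative):

```python
# Sketch of the shared merge-and-deduplicate pattern used by the evaluation scripts.
import json


def merge_result_shards(input_file_paths, output_file_path):
    results = dict()
    for input_file_path in input_file_paths:
        with open(input_file_path, "r") as input_file:
            for line in input_file:
                res = json.loads(line)
                sample_id = res["sample_id"]
                # Keep only the first occurrence of each sample_id.
                if sample_id in results:
                    continue
                res["question_id"] = sample_id
                results[sample_id] = res
    with open(output_file_path, "w") as output_file:
        json.dump(list(results.values()), output_file)


# Usage (paths are illustrative):
# merge_result_shards(["rank0.jsonl", "rank1.jsonl"], "merged.json")
```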
......@@ -9,15 +9,22 @@ def merge_input_files(input_path):
"""Merge input files to a format compatible with the evaluator."""
input_file_paths, output_file_path = get_input_output_paths(input_path, task="ChartQA")
results = []
results = dict()
for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
res["question_id"] = res["sample_id"]
sample_id = res["sample_id"]
results.append(res)
# Ignore possible duplicates.
if sample_id in results:
continue
res["question_id"] = sample_id
results[sample_id] = res
results = list(results.values())
with open(output_file_path, "w") as output_file:
json.dump(results, output_file)
......
......@@ -11,20 +11,28 @@ def convert_to_coco_format(input_path):
"""Convert input files to COCO compatible format."""
input_file_paths, output_file_path = get_input_output_paths(input_path, task="captioning")
captions = []
results = dict()
for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
sample_id = res["sample_id"]
question_id = res['sample_id']
caption = res['caption'].rstrip('.').lower()
# Ignore possible duplicates.
if sample_id in results:
continue
captions.append({"image_id": question_id, "caption": caption})
caption = res["caption"].rstrip(".").lower()
results[sample_id] = {
"image_id": sample_id,
"caption": caption,
}
results = list(results.values())
with open(output_file_path, "w") as output_file:
json.dump(captions, output_file, indent=4)
json.dump(results, output_file, indent=4)
return output_file_path
......
......@@ -11,13 +11,21 @@ def merge_input_files(input_path):
"""Merge input files to a format compatible with the evaluator."""
input_file_paths, output_file_path = get_input_output_paths(input_path, task="MathVista")
results = []
results = dict()
for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
results.append(res)
sample_id = res["sample_id"]
# Remove possible duplicates.
if sample_id in results:
continue
results[sample_id] = res
results = list(results.values())
with open(output_file_path, "w") as output_file:
json.dump(results, output_file)
......
......@@ -2,9 +2,15 @@ import argparse
import glob
import json
import os
import sys
import re
import subprocess
# Get the absolute path of the parent directory
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
# Add the parent directory to sys.path
sys.path.insert(0, parent_dir)
from run_text_generation import get_output_path
from config import EvaluationConfig
......@@ -48,6 +54,10 @@ def convert_to_mmmu_format(input_path):
)
# MMMU eval script expects just a sample_id to prediction mapping.
# Skip possible duplicates.
if sample_id in output:
continue
output[sample_id] = prediction
with open(output_file_path, "w") as output_file:
......
......@@ -8,13 +8,21 @@ def merge_input_files(input_path):
"""Merge input files to a format compatible with the evaluator."""
input_file_paths, output_file_path = get_input_output_paths(input_path, task="OCRBench")
results = []
results = dict()
for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
results.append(res)
sample_id = res["sample_id"]
# Remove possible duplicates.
if sample_id in results:
continue
results[sample_id] = res
results = list(results.values())
with open(output_file_path, "w") as output_file:
json.dump(results, output_file)
......
......@@ -9,22 +9,25 @@ def merge_input_files(input_path):
"""Merge input files to a format compatible with the evaluator."""
input_file_paths, output_file_path = get_input_output_paths(input_path, task="TextVQA")
results = []
results = dict()
for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
results.append(
{
"question_id": res["sample_id"],
"answer": res["answer"],
"gt_answer": res["gt_answer"],
}
)
# Make order deterministic.
# results = sorted(results, key=lambda d: d["question_id"])
sample_id = res["sample_id"]
# Remove possible duplicates.
if sample_id in results:
continue
results[sample_id] = {
"question_id": sample_id,
"answer": res["answer"],
"gt_answer": res["gt_answer"],
}
results = list(results.values())
with open(output_file_path, "w") as output_file:
json.dump(results, output_file)
......