Commit 4567a27f authored by jerrrrry

"core13.0"
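# MotionBench evaluation: merges generation output files and scores them with the
# shared VQA accuracy helper (compute_vqa_accuracy, task="MotionBench").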
import argparse
import json
from .evaluate_vqav2 import compute_vqa_accuracy
from .evaluate_mmmu import get_input_output_paths
def merge_input_files(input_path):
"""Merge input files to a format compatible with the evaluator."""
input_file_paths, output_file_path = get_input_output_paths(input_path, task="MotionBench")
results = []
collected = set()
for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
res["question_id"] = res["sample_id"]
if res['sample_id'] in collected:
continue
collected.add(res['sample_id'])
results.append(res)
with open(output_file_path, "w") as output_file:
json.dump(results, output_file, indent=4, sort_keys=True)
return output_file_path
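# Illustrative input line (assumed minimal shape, inferred from the keys read here and
# in compute_vqa_accuracy): {"sample_id": 0, "answer": "A", "gt_answer": ["A"]}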
def motionbench_eval(input_path):
result_file_path = merge_input_files(input_path)
return compute_vqa_accuracy(result_file_path, task="MotionBench")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--input-path', type=str, help="Path to input file(s)")
args = parser.parse_args()
avg_acc = motionbench_eval(args.input_path)
print(f"MotionBench accuracy: {avg_acc:.2f}")
import argparse
import json
from .evaluate_mmmu import get_input_output_paths
def merge_input_files(input_path):
"""Merge input files to a format compatible with the evaluator."""
input_file_paths, output_file_path = get_input_output_paths(input_path, task="MVBench")
results = []
collected = set()
for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
res["question_id"] = "{}-{}".format(res['task_type'], res['sample_id'])
if res['sample_id'] in collected:
continue
collected.add(res['sample_id'])
results.append(res)
with open(output_file_path, "w") as output_file:
json.dump(results, output_file, indent=4, sort_keys=True)
return output_file_path
# The following code is adapted from
# https://github.com/OpenGVLab/Ask-Anything/blob/main/video_chat2/mvbench.ipynb
# which is licensed under the MIT license. More details on the license can be
# found at https://github.com/OpenGVLab/Ask-Anything/tree/main?tab=MIT-1-ov-file#readme
def check_ans(pred, gt):
flag = False
pred_list = pred.lower().split(' ')
pred_option, pred_content = pred_list[0], ' '.join(pred_list[1:])
gt_list = gt.lower().split(' ')
gt_option, gt_content = gt_list[0], ' '.join(gt_list[1:])
if gt_content[-1] == '.':
gt_content = gt_content[:-1]
if pred_option.replace('.', '') in gt_option:
flag = True
elif gt_option in pred_option:
flag = True
return flag
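# Illustrative behaviour (the option letters are compared; the option contents are ignored):
#   check_ans("(A) a red ball.", "(A) a red ball") -> True
#   check_ans("B.", "(A) a red ball")              -> False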
def create_result_dict(result_list):
correct = 0
total = 0
res_list = []
acc_dict = {}
for idx, result_obj in enumerate(result_list):
task_type = result_obj['task_type']
if task_type not in acc_dict:
acc_dict[task_type] = [0, 0] # correct, total
acc_dict[task_type][1] += 1
total += 1
pred = result_obj['answer']
gt = result_obj['gt_answer'][0]
res_list.append({
'pred': pred,
'gt': gt
})
if check_ans(pred=pred, gt=gt):
acc_dict[task_type][0] += 1
correct += 1
print(f"Total Acc: {correct / total * 100 :.2f}%")
print('-' * 30, task_type, '-' * 30)
return acc_dict
def combine_all_res(acc_dict):
final_res = dict()
correct = 0
total = 0
for k, v in acc_dict.items():
final_res[k] = v[0] / v[1] * 100
correct += v[0]
total += v[1]
final_res['total-acc'] = correct / total * 100
print(final_res)
return final_res
def mvbench_eval(input_path):
result_file_path = merge_input_files(input_path)
merged_results = json.load(open(result_file_path))
acc_dict = create_result_dict(merged_results)
return combine_all_res(acc_dict)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--input-path', type=str, help="Path to input file(s)")
args = parser.parse_args()
avg_acc_dict = mvbench_eval(args.input_path)
print(f"MVBench {avg_acc_dict}")
import argparse
import json
from .evaluate_mmmu import get_input_output_paths
def merge_input_files(input_path):
"""Merge input files to a format compatible with the evaluator."""
input_file_paths, output_file_path = get_input_output_paths(input_path, task="PhysGameBench")
results = []
collected = set()
for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
res["question_id"] = res["sample_id"]
if res['sample_id'] in collected:
continue
collected.add(res['sample_id'])
results.append(res)
with open(output_file_path, "w") as output_file:
json.dump(results, output_file, indent=4, sort_keys=True)
return output_file_path
# The following function is adapted from
# https://github.com/PhysGame/PhysGame/blob/main/physvlm/test/PhysGame_bench/utils.py#L101
# which is licensed under the Apache 2.0 license. More details on the license can be
# found at https://github.com/PhysGame/PhysGame/tree/main?tab=Apache-2.0-1-ov-file#readme
def check_ans(pred, gt):
flag = False
pred_list = pred.lower().split(' ')
pred_option, pred_content = pred_list[0], ' '.join(pred_list[1:])
gt_list = gt.lower().split(' ')
gt_option, gt_content = gt_list[0], ' '.join(gt_list[1:])
if gt_content[-1] == '.':
gt_content = gt_content[:-1]
if pred_option.replace('.', '') in gt_option:
flag = True
elif gt_option in pred_option:
flag = True
return flag
def compute_all_acc(result_list):
correct, total = 0, 0
subclass_cnt = {}
for res in result_list:
total += 1
pred = res['answer']
gt = res['gt_answer'][0]
subclass = res['subclass']
if gt.lower().replace(".", "") == pred.lower().replace(".", ""):
correct += 1
if subclass not in subclass_cnt.keys():
subclass_cnt.update({subclass: [1, 1]})
else:
subclass_cnt[subclass][0] += 1
subclass_cnt[subclass][1] += 1
else:
if subclass not in subclass_cnt.keys():
subclass_cnt.update({subclass: [0, 1]})
else:
subclass_cnt[subclass][1] += 1
result_acc_dict = {
"Physgame-Total-Acc": correct / total * 100
}
print(f'Physgame-Total-Acc: {correct / total * 100 :.2f}%')
for sub_i in subclass_cnt.keys():
print(f'Physgame-{sub_i}-Acc: {subclass_cnt[sub_i][0] / subclass_cnt[sub_i][1] * 100 :.2f}%')
result_acc_dict[f'Physgame-{sub_i}-Acc'] = subclass_cnt[sub_i][0] / subclass_cnt[sub_i][1] * 100
return result_acc_dict
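# Note (illustrative): predictions are compared to the ground truth after lowercasing and
# stripping dots, so e.g. pred "(A)." matches gt "(A)"; subclass_cnt maps each subclass to
# [num_correct, num_total].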
def phys_game_bench_eval(input_path):
result_file_path = merge_input_files(input_path)
merged_results = json.load(open(result_file_path))
return compute_all_acc(merged_results)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--input-path', type=str, help="Path to input file(s)")
args = parser.parse_args()
avg_acc = phys_game_bench_eval(args.input_path)
print(f"PhysGameBench accuracy: {avg_acc:.2f}")
import argparse
import json
from typing import List
from .evaluate_mmmu import get_input_output_paths
from open_flamingo.eval.vqa_metric import VQAEval
# ANLS score calculation based on https://github.com/shunk031/ANLS/blob/6472e1d71e84d6cee28e3c6d2e18564bafaa312d/anls/metrics/dist.py#L1
# and https://github.com/shunk031/ANLS/blob/6472e1d71e84d6cee28e3c6d2e18564bafaa312d/anls/metrics/score.py#L6
# MIT License. Copyright (c) 2022 Shunsuke KITADA
def levenshtein_distance(s1: str, s2: str) -> int:
if len(s1) > len(s2):
s1, s2 = s2, s1
distances = list(range(len(s1) + 1))
for i2, c2 in enumerate(s2):
dists = [i2 + 1]
for i1, c1 in enumerate(s1):
if c1 == c2:
dists.append(distances[i1])
else:
dists.append(1 + min((distances[i1], distances[i1 + 1], dists[-1])))
distances = dists
return distances[-1]
def normalized_levenshtein_distance(s1: str, s2: str) -> float:
dist = levenshtein_distance(s1, s2)
length = max(len(s1.upper()), len(s2.upper()))
return 0.0 if length == 0 else dist / length
def similarity_function(prediction: str, gold_label: str, threshold: float) -> float:
nl_score = normalized_levenshtein_distance(prediction, gold_label)
return 1 - nl_score if nl_score < threshold else 0.0
def anls_score(
prediction: str, gold_labels: List[str], threshold: float = 0.5
) -> float:
# not case sensitive, but space sensitive
y_pred = " ".join(prediction.strip().lower().split())
anls_scores: List[float] = []
for gold_label in gold_labels:
# not case sensitive, but space sensitive
y_true = " ".join(gold_label.strip().lower().split())
anls_score = similarity_function(y_pred, y_true, threshold)
anls_scores.append(anls_score)
score = max(anls_scores)
return score
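# Worked example (illustrative): anls_score("the nvidia corp", ["The NVIDIA Corp."]) compares the
# normalized strings "the nvidia corp" and "the nvidia corp." (Levenshtein distance 1, length 16),
# giving 1 - 1/16 = 0.9375; anls_score("cat", ["dog"]) is 0.0 because 3/3 = 1.0 >= threshold 0.5.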
def merge_input_files(input_path):
"""Merge input files to a format compatible with the evaluator."""
input_file_paths, output_file_path = get_input_output_paths(input_path, task="VQAv2")
results = dict()
for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
sample_id = res["sample_id"]
# Skip possible duplicates.
if sample_id in results:
continue
res["question_id"] = sample_id
results[sample_id] = res
results = list(results.values())
with open(output_file_path, "w") as output_file:
json.dump(results, output_file, indent=4, sort_keys=True)
return output_file_path
def is_number(n: str):
"""Check if input is a number."""
try:
float(n)
return True
except ValueError:
return False
def compute_vqa_accuracy(result_file, task):
"""Compute VQA accuracy."""
merged_results = json.load(open(result_file))
vqa = VQAEval(vqa=None, vqaRes=None)
all_acc = []
for res in merged_results:
pred = res["answer"]
pred = vqa.processPunctuation(pred)
pred = vqa.processDigitArticle(pred)
gt = res["gt_answer"]
gt = [vqa.processPunctuation(ans) for ans in gt]
gt = [vqa.processDigitArticle(ans) for ans in gt]
# ChartQA uses relaxed accuracy:
# "We consider an answer to be correct if it is within 5% of the gold answer.
# For non-numeric answers, we still need an exact match to consider an answer to be correct."
if task == "ChartQA":
acc = 0.0
assert len(gt) == 1, "expected exactly one groundtruth answer."
gt = gt[0]
pred = pred.rstrip("%")
gt = gt.rstrip("%")
if is_number(pred) and is_number(gt):
pred = float(pred)
gt = float(gt)
if pred >= (gt * 0.95) and pred <= (gt * 1.05):
acc = 1.0
elif pred == gt:
acc = 1.0
all_acc.append(acc)
elif task in ("VQAv2", "TextVQA"):
num_match = sum([pred == ans for ans in gt])
acc = min(1.0, num_match / 3.0)
all_acc.append(acc)
elif task in ("SPDocVQA", "InfoVQA"):
acc = anls_score(prediction=pred, gold_labels=gt, threshold=0.5)
all_acc.append(acc)
elif task in ("AI2D", "RealworldQA", "MotionBench"):
assert len(gt) == 1, f"Expected exactly 1 GT, got {gt}"
acc = pred == gt[0]
all_acc.append(acc)
else:
raise NotImplementedError(f"unknown task {task}")
acc_avg = sum(all_acc) / len(all_acc) * 100
return acc_avg
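# Illustrative scoring behaviour (worked examples, not part of the original code):
#   ChartQA (relaxed accuracy): pred "42" vs. gt "40" -> 40 * 0.95 <= 42 <= 40 * 1.05, so acc = 1.0.
#   VQAv2 / TextVQA: a prediction matching 2 of the annotator answers scores min(1.0, 2 / 3), about 0.67.
#   SPDocVQA / InfoVQA: ANLS as computed by anls_score() above.
#   AI2D / RealworldQA / MotionBench: exact match against the single ground-truth answer.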
def vqav2_eval(input_path):
"""Run VQAv2 evaluation."""
result_file = merge_input_files(input_path)
avg_acc = compute_vqa_accuracy(result_file, task="VQAv2")
return avg_acc
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--input-path', type=str, help="Path to input file(s)")
args = parser.parse_args()
avg_acc = vqav2_eval(args.input_path)
print(f"===== VQAv2 Accuracy {avg_acc:.2f}% =====")
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Evaluation datasets."""
import glob
import json
import os
import re
from collections import defaultdict
import numpy as np
import torch
from image_processing import ImageTransform
from PIL import Image
from megatron.training import print_rank_0
def _get_partition_bounds(
total_num_samples, num_samples_per_partition, num_partitions, partition_id
):
if num_samples_per_partition == 0:
samples_per_partition = [
int(x) for x in np.linspace(0, total_num_samples, num_partitions + 1)
]
return samples_per_partition[partition_id], samples_per_partition[partition_id + 1]
return num_samples_per_partition * partition_id, num_samples_per_partition * (partition_id + 1)
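# Worked example (illustrative): with total_num_samples=100, num_partitions=4 and
# num_samples_per_partition=0, np.linspace(0, 100, 5) gives [0, 25, 50, 75, 100], so
# partition_id=1 covers samples [25, 50). With num_samples_per_partition=30,
# partition_id=2 instead covers [60, 90).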
class VQADataset(torch.utils.data.Dataset):
"""VQA evaluation dataset."""
def __init__(
self,
input_image_path,
gt_path,
num_samples_per_partition,
num_partitions,
partition_id,
keys,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
vision_model_type,
split="validation"
):
samples = json.load(open(gt_path, encoding='utf-8'))
if "data" in samples:
samples = samples["data"]
# Optionally, process only a subset of the input files.
if num_partitions > 0:
lb, ub = _get_partition_bounds(
len(samples), num_samples_per_partition, num_partitions, partition_id
)
samples = samples[lb:ub]
self._keys = keys
self._samples = samples
self._input_image_path = input_image_path
self._img_h = img_h
self._img_w = img_w
self._use_tiling = use_tiling
self._max_num_tiles = max_num_tiles
self._use_thumbnail = use_thumbnail
self._transform_img = ImageTransform(img_h, vision_model_type)
self._split = split
def __len__(self):
return len(self._samples)
def __getitem__(self, idx):
sample = self._samples[idx]
img_file = "{}/{}".format(self._input_image_path, sample[self._keys["image_id"]])
if not os.path.exists(img_file):
img_file += ".jpg"
if not os.path.exists(img_file):
img_file = img_file.replace('.jpg', '.png')
img = Image.open(img_file)
imgs = self._transform_img(
img,
self._img_h,
self._img_w,
self._use_tiling,
self._max_num_tiles,
self._use_thumbnail,
augment=False,
)
tile_count = torch.tensor([len(imgs)], dtype=torch.int)
sample_id = idx
if "sample_id" in self._keys:
sample_id = sample[self._keys["sample_id"]]
metadata = "" # Not used.
return (
torch.stack(imgs),
tile_count,
sample_id,
sample[self._keys["question"]],
[""] if self._split == "test" else sample[self._keys["answer"]],
metadata,
)
class CaptioningDataset(torch.utils.data.Dataset):
"""Captioning evaluation dataset."""
def __init__(
self,
input_image_path,
gt_path,
num_samples_per_partition,
num_partitions,
partition_id,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
vision_model_type,
):
image_files = sorted(glob.glob(input_image_path + "/*"))
# Optionally, process only a subset of the input files.
if num_partitions > 0:
lb, ub = _get_partition_bounds(
len(image_files), num_samples_per_partition, num_partitions, partition_id
)
image_files = image_files[lb:ub]
gts = json.load(open(gt_path))
answers = defaultdict(list)
for gt in gts["annotations"]:
answers[gt["image_id"]].append(gt['caption'])
self._image_files = image_files
self._answers = answers
self._img_h = img_h
self._img_w = img_w
self._use_tiling = use_tiling
self._max_num_tiles = max_num_tiles
self._use_thumbnail = use_thumbnail
self._transform_img = ImageTransform(img_h, vision_model_type)
def __len__(self):
return len(self._image_files)
def __getitem__(self, idx):
img_file = self._image_files[idx]
try:
image_id = int(img_file.split("_")[-1].split(".")[0]) # coco
except ValueError:
image_id = int(img_file.split("/")[-1].split(".")[0]) # flickr
img = Image.open(img_file)
imgs = self._transform_img(
img,
self._img_h,
self._img_w,
self._use_tiling,
self._max_num_tiles,
self._use_thumbnail,
augment=False,
)
tile_count = torch.tensor([len(imgs)], dtype=torch.int)
question = "" # Fixed for all samples.
metadata = "" # Not used.
return torch.stack(imgs), tile_count, image_id, question, self._answers[image_id], metadata
class MMMUDataset(torch.utils.data.Dataset):
"""MMMU evaluation dataset."""
def __init__(
self,
input_image_path,
num_samples_per_partition,
num_partitions,
partition_id,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
prompt_style,
vision_model_type,
split="validation",
):
import datasets
from .mmmu_utils import CAT_SHORT2LONG, load_yaml
# The following downloads the MMMU dataset from HuggingFace and uses the API from the MMMU github repo to run MMMU evaluation.
all_mmmu_datasets = []
hf_datasets_cache = os.environ["HF_DATASETS_CACHE"]
assert hf_datasets_cache != "", "Please set the environment variable HF_DATASETS_CACHE."
for subject in CAT_SHORT2LONG.values():
# Use a local copy of the dataset if it exists (can be faster); otherwise, load it from HuggingFace.
if os.path.exists(input_image_path):
subject_dataset = datasets.load_dataset(
os.path.join(input_image_path, subject),
split=split,
cache_dir=hf_datasets_cache,
verification_mode="no_checks",
)
else:
subject_dataset = datasets.load_dataset(
"MMMU/MMMU",
subject,
split=split,
cache_dir=hf_datasets_cache,
)
all_mmmu_datasets.append(subject_dataset)
dataset = datasets.concatenate_datasets(all_mmmu_datasets)
dataset = [s for s in dataset if s['id'].startswith("val")]
# Optionally, process only a subset of the input files.
if num_partitions > 0:
lb, ub = _get_partition_bounds(
len(dataset), num_samples_per_partition, num_partitions, partition_id
)
dataset = dataset[lb:ub]
# Using the LLaVA config from the MMMU repo.
config = load_yaml("examples/multimodal/MMMU/mmmu/configs/llava1.5.yaml")
for k, v in config.items():
if isinstance(v, list):
assert len(v) == 1, "only one value supported."
config[k] = v[0]
self._config = config
self._dataset = dataset
self._img_h = img_h
self._img_w = img_w
self._use_tiling = use_tiling
self._max_num_tiles = max_num_tiles
self._use_thumbnail = use_thumbnail
self._prompt_style = prompt_style
self._transform_img = ImageTransform(img_h, vision_model_type)
def __len__(self):
return len(self._dataset)
def process_image_tag(self, q):
q = q.strip()
# heuristic way of removing <image 1>
if q == '<image 1>':
q = 'Answer the question in the image.'
elif ':<image 1>' in q:
q = q.replace(':<image 1>', ' in the image. ')
q = q.strip()
elif ': <image 1>' in q:
q = q.replace(': <image 1>', ' in the image. ')
q = q.strip()
elif '.<image 1>' in q or '. <image 1>' in q:
q_list = q.split('<image 1>')
q_list = [part.strip() for part in q_list if part.strip() != '']
q = ' '.join(q_list)
elif q.startswith('<image 1> '):
if q[10].isupper():
q = q.replace('<image 1>', '')
else:
q = q.replace('<image 1>', 'The image')
q = q.strip()
elif q.startswith('<image 1>'):
q = q.replace('<image 1>', '')
elif q.endswith('<image 1>?'):
q = q.replace('<image 1>', 'the image')
elif q.endswith('?<image 1>') or q.endswith('? <image 1>') or q.endswith('\n<image 1>'):
q = q.replace('<image 1>', '')
q = q.strip()
elif ' <image 1> ' in q:
q = q.replace('<image 1>', 'the image')
elif ' <image 1>' in q:
q = q.replace('<image 1>', 'the image')
elif '()<image 1>' in q:
q = q.replace('()<image 1>', '')
elif '(<image 1>)' in q:
q = q.replace('(<image 1>)', '')
elif '<image 1>.' in q:
q = q.replace("<image 1>.", ". ")
else:
q = q.replace("<image 1>", ". ")
q = q.strip()
# remove <image 2> through <image 7>
for i in range(2, 8):
q = q.replace(f"<image {i}>", "")
return q
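# Illustrative examples of the heuristics above (not exhaustive):
#   "<image 1> What is shown?"        -> "What is shown?"
#   "Describe <image 1>."             -> "Describe the image."
#   "What is the area of <image 1>?"  -> "What is the area of the image?"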
def __getitem__(self, idx):
from .mmmu_utils import construct_prompt, process_single_sample
sample = self._dataset[idx]
# Use the single image approach from the MMMU repo.
if self._prompt_style == "single_image":
sample = process_single_sample(sample)
sample = construct_prompt(sample, self._config)
img = sample["image"]
sample_imgs = self._transform_img(
img,
self._img_h,
self._img_w,
self._use_tiling,
self._max_num_tiles,
self._use_thumbnail,
augment=False,
)
sample_num_tiles = [len(sample_imgs)]
prompt = sample["final_input_prompt"]
sample["final_input_prompt"] = self.process_image_tag(prompt)
elif self._prompt_style == "vlmevalkit":
sample = construct_prompt(sample, self._config)
if sample["question_type"] == "multiple-choice":
question = sample["question"]
options = ""
for k, v in sample["index2ans"].items():
options += f"{k}. {v}\n"
final_prompt = f"{question}\n"
if "hint" in sample:
final_prompt += f"Hint: {sample['hint']}\n"
if "task_instructions" in sample:
final_prompt += f"Task instructions: {sample['task_instructions']}\n"
final_prompt += options
final_prompt += "Answer with the option's letter from the given choices directly."
sample["final_input_prompt"] = final_prompt.rstrip()
else:
question = sample["question"]
final_prompt = f"{question}\n"
final_prompt += "Answer the question directly."
sample["final_input_prompt"] = final_prompt.rstrip()
sample_imgs = []
sample_num_tiles = []
img_indices = sorted(list(set(re.findall(r"<image (\d+)", sample["final_input_prompt"]))))
# If there are multiple input images, we need to avoid the number of image embeddings getting too large.
adjusted_max_num_tiles = max(1, self._max_num_tiles // len(img_indices))
adjusted_max_num_tiles = min(adjusted_max_num_tiles, self._max_num_tiles)
for img_idx in img_indices:
img_key = f"image_{img_idx}"
img_str = f"<image {img_idx}>"
img = sample[img_key]
assert img is not None, f"{img_str} is in prompt but not in sample images"
imgs = self._transform_img(
img,
self._img_h,
self._img_w,
self._use_tiling,
adjusted_max_num_tiles,
self._use_thumbnail,
augment=False,
) # List of tiles.
sample_imgs.extend(imgs)
sample_num_tiles.append(len(imgs))
sample["final_input_prompt"] = " ".join([f'<image {i + 1}><image>' for i in range(len(img_indices))]) + "\n" + sample["final_input_prompt"]
elif self._prompt_style == "multi_image":
sample = construct_prompt(sample, self._config)
sample_imgs = []
sample_num_tiles = []
img_indices = re.findall(r"<image (\d+)", sample["final_input_prompt"])
# If there are multiple input images, we need to avoid the number of image embeddings getting too large.
adjusted_max_num_tiles = max(1, self._max_num_tiles // len(img_indices))
for img_idx in img_indices:
img_key = f"image_{img_idx}"
img_str = f"<image {img_idx}>"
img = sample[img_key]
assert img is not None, f"{img_str} is in prompt but not in sample images"
# Note: Only replace the current image tag.
sample["final_input_prompt"] = sample["final_input_prompt"].replace(
img_str, "<image>", 1
)
imgs = self._transform_img(
img,
self._img_h,
self._img_w,
self._use_tiling,
adjusted_max_num_tiles,
self._use_thumbnail,
augment=False,
) # List of tiles.
sample_imgs.extend(imgs)
sample_num_tiles.append(len(imgs))
# Sanity check.
for i in range(1, 8):
assert (
f"<image {i}>" not in sample["final_input_prompt"]
), "prompt contains unhandled image tags"
else:
raise ValueError(f"unknown prompt style {self._prompt_style}")
# MMMU specific metadata.
metadata = {"question_type": sample["question_type"],
"field": sample["field"],
"subfield": sample["subfield"]}
if sample["question_type"] == "multiple-choice":
metadata["index2ans"] = sample["index2ans"]
metadata["all_choices"] = sample["all_choices"]
prompt = sample['final_input_prompt']
tile_count = torch.tensor(sample_num_tiles, dtype=torch.int)
return (
torch.stack(sample_imgs),
tile_count,
sample["id"],
prompt,
sample["answer"],
metadata,
)
class VideoMMEDataset(torch.utils.data.Dataset):
"Video MME evaluation dataset."
def __init__(
self,
input_image_path,
gt_path,
num_samples_per_partition,
num_partitions,
partition_id,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
num_frames,
vision_model_type,
):
ground_truth_original = json.load(open(gt_path))
ground_truth = []
for gt in ground_truth_original:
video_path = gt["url"]
video_path = video_path.replace("https://www.youtube.com/watch?v=", "")
video_path = video_path.replace("https://m.youtube.com/watch?v=", "")
video_path = os.path.join(input_image_path, video_path + ".mp4")
if not os.path.exists(video_path):
continue
gt["video_path"] = video_path
ground_truth.append(gt)
ground_truth = sorted(ground_truth, key=lambda gt: gt["video_path"])
print_rank_0(f"Found {len(ground_truth)} videos to process.")
if num_partitions > 0:
start_idx, end_idx = _get_partition_bounds(
len(ground_truth), num_samples_per_partition, num_partitions, partition_id
)
ground_truth = ground_truth[start_idx:end_idx]
self._ground_truth = ground_truth
self._img_h = img_h
self._img_w = img_w
self._use_tiling = False
self._max_num_tiles = max_num_tiles
self._use_thumbnail = use_thumbnail
self._num_frames = num_frames
self._transform_img = ImageTransform(img_h, vision_model_type)
def __len__(self):
return len(self._ground_truth)
def __getitem__(self, idx):
from torchvision.io import read_video
gt = self._ground_truth[idx]
video, _, _ = read_video(gt["video_path"], start_pts=0, end_pts=None, pts_unit='sec')
video = video.numpy()
selected_frames = torch.linspace(0, video.shape[0] - 1, self._num_frames).long()
video_frames = video[selected_frames]
if self._num_frames == 1:
video_frames = video_frames[None]
imgs = []
for img in video_frames:
from torchvision.transforms import ToPILImage
to_pil = ToPILImage()
img = to_pil(img)
imgs += self._transform_img(
img, self._img_h, self._img_w, self._use_tiling, self._max_num_tiles,
self._use_thumbnail, augment=False,
)
for question in gt["questions"]:
# Very hacky, but we essentially re-create gt holding only the
# question of interest. This is to make this generation script
# compatible with the Video MME evaluation script.
question_dict = {
"video_id": gt["video_id"],
"duration_category": gt["duration_category"],
"video_category": gt["video_category"],
"video_subcategory": gt["video_subcategory"],
"url": gt["url"],
"questions": [question],
}
num_tiles = torch.tensor([len(imgs)], dtype=torch.int)
answer = ""
metadata = ""
return (
torch.stack(imgs),
num_tiles,
question["question_id"],
question_dict,
answer,
metadata,
)
class OCRBenchDataset(torch.utils.data.Dataset):
"""OCRBench evaluation dataset."""
def __init__(
self,
input_image_path,
gt_path,
num_samples_per_partition,
num_partitions,
partition_id,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
vision_model_type,
):
gt = json.load(open(gt_path, encoding='utf-8'))
if num_partitions > 0:
start_idx, end_idx = _get_partition_bounds(
len(gt), num_samples_per_partition, num_partitions, partition_id
)
gt = gt[start_idx:end_idx]
self._input_image_path = input_image_path
self._gt = gt
self._img_h = img_h
self._img_w = img_w
self._use_tiling = use_tiling
self._max_num_tiles = max_num_tiles
self._use_thumbnail = use_thumbnail
self._transform_img = ImageTransform(img_h, vision_model_type)
def __len__(self):
return len(self._gt)
def __getitem__(self, idx):
img_path = os.path.join(self._input_image_path, self._gt[idx]['image_path'])
img = Image.open(img_path)
imgs = self._transform_img(
img,
self._img_h,
self._img_w,
self._use_tiling,
self._max_num_tiles,
self._use_thumbnail,
augment=False,
)
tile_count = torch.tensor([len(imgs)], dtype=torch.int)
metadata = {
"dataset_name": self._gt[idx]["dataset_name"],
"data_type": self._gt[idx]["type"],
}
return (
torch.stack(imgs),
tile_count,
idx,
self._gt[idx]["question"],
self._gt[idx]["answers"],
metadata,
)
class MathVistaDataset(torch.utils.data.Dataset):
"""MathVista evaluation dataset."""
def __init__(
self,
input_image_path,
num_samples_per_partition,
num_partitions,
partition_id,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
vision_model_type,
):
import datasets
hf_datasets_cache = os.environ["HF_DATASETS_CACHE"]
assert hf_datasets_cache != "", "Please set the environment variable HF_DATASETS_CACHE."
if os.path.exists(input_image_path):
dataset = datasets.load_dataset(
input_image_path, cache_dir=hf_datasets_cache, verification_mode="no_checks", split="train"
)
else:
dataset = datasets.load_dataset(
"AI4Math/MathVista", split="testmini", cache_dir=hf_datasets_cache
)
if num_partitions > 0:
start_idx, end_idx = _get_partition_bounds(
len(dataset), num_samples_per_partition, num_partitions, partition_id
)
dataset = dataset[start_idx:end_idx]
self._dataset = dataset
self._img_h = img_h
self._img_w = img_w
self._use_tiling = use_tiling
self._max_num_tiles = max_num_tiles
self._use_thumbnail = use_thumbnail
self._transform_img = ImageTransform(img_h, vision_model_type)
def __len__(self):
return len(self._dataset["pid"])
def __getitem__(self, idx):
# Already a PIL object.
img = self._dataset['decoded_image'][idx]
imgs = self._transform_img(
img,
self._img_h,
self._img_w,
self._use_tiling,
self._max_num_tiles,
self._use_thumbnail,
augment=False,
)
tile_count = torch.tensor([len(imgs)], dtype=torch.int)
question_id = self._dataset["pid"][idx]
question = self._dataset["question"][idx]
question_type = self._dataset["question_type"][idx] # free_form or multi_choice
query = self._dataset["query"][idx]
choices = self._dataset["choices"][idx]
answer = self._dataset["answer"][idx]
if question_type == 'multi_choice':
start_chr = 'A'
choices_str = ''
index2ans = {}
all_choices = []
for choice in choices:
all_choices.append(start_chr)
index2ans[start_chr] = choice
choices_str += f"{start_chr}. {choice}\n"
start_chr = chr(ord(start_chr) + 1)
question = question + '\n' + choices_str
question = question + "Answer with the option's letter from the given choices directly."
answer = chr(ord('A') + choices.index(answer))
else:
question = query.replace("Hint: ", "")
index2ans = {}
all_choices = []
metadata = {
"question_type": question_type,
"index2ans": index2ans,
"all_choices": all_choices,
}
return torch.stack(imgs), tile_count, question_id, question, answer, metadata
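# Note (illustrative): for multi_choice questions the ground-truth answer is mapped to its
# option letter, e.g. choices ["2", "4", "8"] with answer "4" yields answer "B".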
class AI2DDataset(torch.utils.data.Dataset):
"""AI2D evaluation dataset."""
def __init__(
self,
input_image_path,
gt_path,
num_samples_per_partition,
num_partitions,
partition_id,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
vision_model_type,
):
with open(gt_path, 'r') as f:
jsonl = list(f)
gt = [json.loads(json_str) for json_str in jsonl]
if num_partitions > 0:
start_idx, end_idx = _get_partition_bounds(
len(gt), num_samples_per_partition, num_partitions, partition_id
)
gt = gt[start_idx:end_idx]
self._gt = gt
self._input_image_path = input_image_path
self._img_h = img_h
self._img_w = img_w
self._use_tiling = use_tiling
self._max_num_tiles = max_num_tiles
self._use_thumbnail = use_thumbnail
self._transform_img = ImageTransform(img_h, vision_model_type)
def __len__(self):
return len(self._gt)
def __getitem__(self, idx):
img_path = os.path.join(self._input_image_path, self._gt[idx]['image'].split("/")[-1])
img = Image.open(img_path)
imgs = self._transform_img(
img,
self._img_h,
self._img_w,
self._use_tiling,
self._max_num_tiles,
self._use_thumbnail,
augment=False,
)
tile_count = torch.tensor([len(imgs)], dtype=torch.int)
metadata = "" # Not used.
return (
torch.stack(imgs),
tile_count,
self._gt[idx]["question_id"],
self._gt[idx]["question"],
self._gt[idx]["answer"],
metadata,
)
class RDTableBenchDataset(torch.utils.data.Dataset):
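"""RD-TableBench evaluation dataset."""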
def __init__(
self,
input_image_path,
gt_path,
num_samples_per_partition,
num_partitions,
partition_id,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
vision_model_type,
):
gt_paths = sorted(glob.glob(os.path.join(gt_path, "*.html")))
gt = []
for gt_path in gt_paths:
img_path = os.path.join(input_image_path, os.path.basename(gt_path).replace(".html", ".jpg"))
with open(gt_path) as f:
html = f.read()
gt.append({
"answer": html,
"image": img_path,
})
if num_partitions > 0:
start_idx, end_idx = _get_partition_bounds(
len(gt), num_samples_per_partition, num_partitions, partition_id
)
gt = gt[start_idx:end_idx]
self._input_image_path = input_image_path
self._gt = gt
self._img_h = img_h
self._img_w = img_w
self._use_tiling = use_tiling
self._max_num_tiles = max_num_tiles
self._use_thumbnail = use_thumbnail
self._transform_img = ImageTransform(img_h, vision_model_type)
def __len__(self):
return len(self._gt)
def __getitem__(self, idx):
img_path = os.path.join(self._input_image_path, self._gt[idx]['image'])
img = Image.open(img_path)
imgs = self._transform_img(
img,
self._img_h,
self._img_w,
self._use_tiling,
self._max_num_tiles,
self._use_thumbnail,
augment=False,
)
tile_count = torch.tensor([len(imgs)], dtype=torch.int)
metadata = ""
prompt = (
"Convert the image to an HTML table. The output should begin with <table> and end with </table>. "
"Specify rowspan and colspan attributes when they are greater than 1. Do not specify any other attributes. "
"Only use table related HTML tags, no additional formatting is required."
)
return (
torch.stack(imgs),
tile_count,
idx,
prompt,
self._gt[idx]["answer"],
metadata,
)
class RealworldQADataset(torch.utils.data.Dataset):
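"""RealworldQA evaluation dataset."""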
def __init__(
self,
input_image_path,
gt_path,
num_samples_per_partition,
num_partitions,
partition_id,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
vision_model_type,
):
gt = json.load(open(gt_path, encoding='utf-8'))
if num_partitions > 0:
start_idx, end_idx = _get_partition_bounds(
len(gt), num_samples_per_partition, num_partitions, partition_id
)
gt = gt[start_idx:end_idx]
self._gt = gt
self._input_image_path = input_image_path
self._img_h = img_h
self._img_w = img_w
self._use_tiling = use_tiling
self._max_num_tiles = max_num_tiles
self._use_thumbnail = use_thumbnail
self._transform_img = ImageTransform(img_h, vision_model_type)
def __len__(self):
return len(self._gt)
def __getitem__(self, idx):
img_path = os.path.join(self._input_image_path, self._gt[idx]['image'])
img = Image.open(img_path)
imgs = self._transform_img(
img,
self._img_h,
self._img_w,
self._use_tiling,
self._max_num_tiles,
self._use_thumbnail,
augment=False,
)
question_id = int(self._gt[idx]['image'].replace(".webp", ""))
question = self._gt[idx]["question"]
if self._gt[idx]['question_type'] == "multi-choice":
choices = self._gt[idx]["choices"]
start_chr = 'A'
choices_str = ''
index2ans = {}
all_choices = []
for choice in choices:
all_choices.append(start_chr)
index2ans[start_chr] = choice
choices_str += f"{start_chr}. {choice}\n"
start_chr = chr(ord(start_chr) + 1)
question = question + '\n' + choices_str
question = question + "Answer with the option's letter from the given choices directly."
answer = chr(ord('A') + self._gt[idx]['correct_choice_index'])
else:
question = question + "\nAnswer the question using a single word or phrase."
answer = self._gt[idx]['answer']
tile_count = torch.tensor([len(imgs)], dtype=torch.int)
metadata = "" # Not used.
return (
torch.stack(imgs),
tile_count,
question_id,
question,
[answer],
metadata,
)
class MotionBenchDataset(torch.utils.data.Dataset):
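"""MotionBench evaluation dataset."""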
def __init__(
self,
input_image_path,
gt_path,
num_samples_per_partition,
num_partitions,
partition_id,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
num_frames,
vision_model_type,
split
):
with open(gt_path) as f:
ground_truth_original = [json.loads(line) for line in f]
ground_truth = []
for gt in ground_truth_original:
# video path handling
video_path = gt['video_path']
if ".mp4" not in video_path:
video_path = f"{video_path}.mp4"
video_path = os.path.join(input_image_path, video_path)
if not os.path.exists(video_path):
continue
gt["video_path"] = video_path
ground_truth.append(gt)
ground_truth = sorted(ground_truth, key=lambda gt: gt["video_path"])
print_rank_0(f"Found {len(ground_truth)} videos to process.")
if num_partitions > 0:
start_idx, end_idx = _get_partition_bounds(
len(ground_truth), num_samples_per_partition, num_partitions, partition_id
)
ground_truth = ground_truth[start_idx:end_idx]
self._ground_truth = ground_truth
self._img_h = img_h
self._img_w = img_w
self._use_tiling = False
self._max_num_tiles = max_num_tiles
self._use_thumbnail = use_thumbnail
self._num_frames = num_frames
self._transform_img = ImageTransform(img_h, vision_model_type)
def __len__(self):
return len(self._ground_truth)
def __getitem__(self, idx):
gt = self._ground_truth[idx]
from torchvision.io.video import read_video
video, _, _ = read_video(gt["video_path"], start_pts=0, end_pts=None, pts_unit='sec')
video = video.permute((0, 3, 1, 2))
selected_frames = torch.linspace(0, video.shape[0] - 1, min(self._num_frames, video.shape[0])).long()
video_frames = video[selected_frames]
if self._num_frames == 1:
video_frames = video_frames[None]
imgs = []
for img in video_frames:
from torchvision.transforms import ToPILImage
to_pil = ToPILImage()
img = to_pil(img)
imgs += self._transform_img(
img,
self._img_h,
self._img_w,
self._use_tiling,
self._max_num_tiles,
self._use_thumbnail,
augment=False,
)
num_tiles = torch.tensor([len(imgs)], dtype=torch.int)
q_id = gt['qa'][0]['uid']
question = gt['qa'][0]['question']
answer = gt['qa'][0]['answer']
metadata = ""
return (
torch.stack(imgs),
num_tiles,
q_id,
question,
answer,
metadata,
)
# The following class is adapted from
# https://github.com/PhysGame/PhysGame/blob/main/physvlm/test/PhysGame_bench/utils.py#L27
# which is licensed under the Apache 2.0 license. More details on the license can be
# found at https://github.com/PhysGame/PhysGame/tree/main?tab=Apache-2.0-1-ov-file#readme
class PhysGameBenchDataset(torch.utils.data.Dataset):
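"""PhysGame benchmark evaluation dataset."""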
def __init__(
self,
input_image_path,
gt_path,
num_samples_per_partition,
num_partitions,
partition_id,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
num_frames,
vision_model_type,
split
):
ground_truth_original = json.load(open(gt_path, encoding='utf-8'))
ground_truth = []
for gt in ground_truth_original:
video_path = os.path.join(input_image_path, gt['question_id']) + ".mp4"
if not os.path.exists(video_path):
continue
gt["video_path"] = video_path
ground_truth.append(gt)
ground_truth = sorted(ground_truth, key=lambda gt: gt["video_path"])
print_rank_0(f"Found {len(ground_truth)} videos to process.")
if num_partitions > 0:
start_idx, end_idx = _get_partition_bounds(
len(ground_truth), num_samples_per_partition, num_partitions, partition_id
)
ground_truth = ground_truth[start_idx:end_idx]
self._ground_truth = ground_truth
self._img_h = img_h
self._img_w = img_w
self._use_tiling = False
self._max_num_tiles = max_num_tiles
self._use_thumbnail = use_thumbnail
self._num_frames = num_frames
self._transform_img = ImageTransform(img_h, vision_model_type)
def __len__(self):
return len(self._ground_truth)
def _qa_template(self, data):
question = f"Question: {data['question']}\n"
question += "Options:\n"
answer = data['answer']
for ch, c in data['options'].items():
question += f"({ch}) {c}\n"
question = question.rstrip()
return question, answer
def __getitem__(self, idx):
gt = self._ground_truth[idx]
from torchvision.io.video import read_video
video, _, _ = read_video(gt["video_path"], start_pts=0, end_pts=None, pts_unit='sec')
video = video.permute((0, 3, 1, 2))
selected_frames = torch.linspace(0, video.shape[0] - 1, min(self._num_frames, video.shape[0])).long()
video_frames = video[selected_frames]
if self._num_frames == 1:
video_frames = video_frames[None]
imgs = []
for img in video_frames:
from torchvision.transforms import ToPILImage
to_pil = ToPILImage()
img = to_pil(img)
imgs += self._transform_img(
img,
self._img_h,
self._img_w,
self._use_tiling,
self._max_num_tiles,
self._use_thumbnail,
augment=False,
)
num_tiles = torch.tensor([len(imgs)], dtype=torch.int)
q_id = gt['question_id']
question, answer = self._qa_template(gt)
metadata = {
'class': gt['class_anno'],
'subclass': gt['subclass_anno']
}
return (
torch.stack(imgs),
num_tiles,
q_id,
question,
answer,
metadata,
)
# The following class is adapted from
# https://github.com/OpenGVLab/Ask-Anything/blob/main/video_chat2/mvbench.ipynb
# which is licensed under the MIT license. More details on the license can be
# found at https://github.com/OpenGVLab/Ask-Anything/tree/main?tab=MIT-1-ov-file#readme
class MVBenchDataset(torch.utils.data.Dataset):
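"""MVBench evaluation dataset."""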
def __init__(
self,
input_image_path,
gt_path,
num_samples_per_partition,
num_partitions,
partition_id,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
num_frames,
vision_model_type,
split
):
data_list = {
"Action Sequence": ("action_sequence.json", f"{input_image_path}/star/Charades_v1_480/", "video", True), # has start & end
"Action Prediction": ("action_prediction.json", f"{input_image_path}/star/Charades_v1_480/", "video", True), # has start & end
"Action Antonym": ("action_antonym.json", f"{input_image_path}/ssv2_video/", "video", False),
"Fine-grained Action": ("fine_grained_action.json", f"{input_image_path}/Moments_in_Time_Raw/videos/", "video", False),
"Unexpected Action": ("unexpected_action.json", f"{input_image_path}/FunQA_test/test/", "video", False),
"Object Existence": ("object_existence.json", f"{input_image_path}/clevrer/video_validation/", "video", False),
"Object Interaction": ("object_interaction.json", f"{input_image_path}/star/Charades_v1_480/", "video", True), # has start & end
"Object Shuffle": ("object_shuffle.json", f"{input_image_path}/perception/videos/", "video", False),
"Moving Direction": ("moving_direction.json", f"{input_image_path}/clevrer/video_validation/", "video", False),
"Action Localization": ("action_localization.json", f"{input_image_path}/sta/sta_video/", "video", True), # has start & end
"Scene Transition": ("scene_transition.json", f"{input_image_path}/scene_qa/video/", "video", False),
"Action Count": ("action_count.json", f"{input_image_path}/perception/videos/", "video", False),
"Moving Count": ("moving_count.json", f"{input_image_path}/clevrer/video_validation/", "video", False),
"Moving Attribute": ("moving_attribute.json", f"{input_image_path}/clevrer/video_validation/", "video", False),
"State Change": ("state_change.json", f"{input_image_path}/perception/videos/", "video", False),
"Fine-grained Pose": ("fine_grained_pose.json", f"{input_image_path}/nturgbd/", "video", False),
"Character Order": ("character_order.json", f"{input_image_path}/perception/videos/", "video", False),
"Egocentric Navigation": ("egocentric_navigation.json", f"{input_image_path}/vlnqa/", "video", False),
"Episodic Reasoning": ("episodic_reasoning.json", f"{input_image_path}/tvqa/frames_fps3_hq/", "frame", True), # has start & end, read frame
"Counterfactual Inference": ("counterfactual_inference.json", f"{input_image_path}/clevrer/video_validation/", "video", False)
}
ground_truth = []
for k, v in data_list.items():
with open(os.path.join(gt_path, v[0]), 'r') as f:
json_data = json.load(f)
for data_id, data in enumerate(json_data):
ground_truth.append({
'task_type': k,
'prefix': v[1],
'data_type': v[2],
'bound': v[3],
'data': data,
'question_id': f"{k}-{data_id}"
})
print("total ground truth ==> ", len(ground_truth))
self.decord_method = {
'video': self.read_video_ours,
'frame': self.read_frame,
}
if num_partitions > 0:
start_idx, end_idx = _get_partition_bounds(
len(ground_truth), num_samples_per_partition, num_partitions, partition_id
)
ground_truth = ground_truth[start_idx:end_idx]
print("Partitioned ==> ", {start_idx}, {end_idx}, len(ground_truth))
self._ground_truth = ground_truth
self._img_h = img_h
self._img_w = img_w
self._use_tiling = False
self._max_num_tiles = max_num_tiles
self._use_thumbnail = use_thumbnail
self._num_frames = num_frames
self._transform_img = ImageTransform(img_h, vision_model_type)
def __len__(self):
return len(self._ground_truth)
def get_index(self, bound, fps, max_frame, first_idx=0):
if bound:
start, end = bound[0], bound[1]
else:
start, end = -100000, 100000
start_idx = max(first_idx, round(start * fps))
end_idx = min(round(end * fps), max_frame)
seg_size = float(end_idx - start_idx) / self._num_frames
frame_indices = np.array([
int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
for idx in range(self._num_frames)
])
return frame_indices
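# Worked example (illustrative): with bound=(0, 40), fps=2, max_frame=100 and num_frames=4,
# start_idx=0 and end_idx=80, so seg_size=20 and the sampled frame indices are [10, 30, 50, 70].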
def qa_template(self, data):
question = f"Question: {data['question']}\n"
question += "Options:\n"
answer = data['answer']
answer_idx = -1
for idx, c in enumerate(data['candidates']):
question += f"({chr(ord('A') + idx)}) {c}\n"
if c == answer:
answer_idx = idx
question = question.rstrip()
answer = f"({chr(ord('A') + answer_idx)}) {answer}"
return question, answer
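# Illustrative example: data = {'question': 'How many objects moved?',
# 'candidates': ['one', 'two'], 'answer': 'two'} yields
#   question = "Question: How many objects moved?\nOptions:\n(A) one\n(B) two"
#   answer   = "(B) two"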
def read_frame(self, video_path, bound=None, fps=2):
max_frame = len(os.listdir(video_path))
images_group = list()
frame_indices = self.get_index(bound, fps, max_frame, first_idx=1) # frame_idx starts from 1
for frame_index in frame_indices:
img = Image.open(os.path.join(video_path, f"{frame_index:05d}.jpg"))
images_group.append(img)
return images_group
def read_video_ours(self, video_path, bound=None):
from torchvision.io.video import read_video
video, _, v_meta_info = read_video(video_path, start_pts=0, end_pts=None, pts_unit='sec')
video = video.permute((0, 3, 1, 2))
fps = float(v_meta_info['video_fps'])
max_frame = len(video) - 1
selected_frames_indices = self.get_index(bound, fps, max_frame, first_idx=0)
video_frames = video[selected_frames_indices]
return video_frames
def __getitem__(self, idx):
data = self._ground_truth[idx]
bound = None
if data['bound']:
bound = (
data['data']['start'],
data['data']['end'],
)
video_path = os.path.join(data['prefix'], data['data']['video'])
video_decode_func = self.decord_method[data['data_type']]
video_frames = video_decode_func(video_path, bound)
imgs = []
for img in video_frames:
from torchvision.transforms import ToPILImage
if data['data_type'] == 'video':
to_pil = ToPILImage()
img = to_pil(img)
imgs += self._transform_img(
img, self._img_h, self._img_w, self._use_tiling, self._max_num_tiles,
self._use_thumbnail, augment=False
)
num_tiles = torch.tensor([len(imgs)], dtype=torch.int)
q_id = data['question_id']
metadata = {'task_type': data['task_type']}
question, answer = self.qa_template(data['data'])
return (
torch.stack(imgs),
num_tiles,
q_id,
question,
answer,
metadata,
)
class ExampleInferenceDataset(torch.utils.data.Dataset):
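"""Example dataset for running inference on custom, user-defined samples."""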
def __init__(
self,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
vision_model_type,
):
# Define your own inference samples here. The following is an example.
samples = [
# Use <image> token to indicate the image position.
{"image_paths": ["examples/multimodal/assets/pretrain_curves.png"], "question": "<image>\nWhat is the curve?"},
# Optional: if you have an answer for the question.
{"image_paths": ["examples/multimodal/assets/pretrain_curves.png"], "question": "What is the curve?<image>", "answer": "It's a loss function curve."},
# If you have multiple images for the question, then use <image> token to indicate the image positions.
{"image_paths": ["examples/multimodal/assets/pretrain_curves.png", "examples/multimodal/assets/pretrain_curves.png"], "question": "<image>What is the curve?<image>"},
# Text only sample.
{"question": "Who is Jensen Huang?"},
]
self._samples = samples
self._img_h = img_h
self._img_w = img_w
self._use_tiling = use_tiling
self._max_num_tiles = max_num_tiles
self._use_thumbnail = use_thumbnail
self._transform_img = ImageTransform(img_h, vision_model_type)
def __len__(self):
return len(self._samples)
def __getitem__(self, idx):
sample = self._samples[idx]
sample_imgs = []
sample_tile_count = []
for image_path in sample.get("image_paths", []):
img = Image.open(image_path)
imgs = self._transform_img(
img,
self._img_h,
self._img_w,
self._use_tiling,
self._max_num_tiles,
self._use_thumbnail,
augment=False,
)
sample_imgs.extend(imgs)
sample_tile_count.append(len(imgs))
sample_id = idx
metadata = "" # Not used.
return (
torch.stack(sample_imgs) if len(sample_imgs) > 0 else torch.tensor([]),
torch.tensor(sample_tile_count, dtype=torch.int),
sample_id,
sample["question"],
sample.get("answer", ""),
metadata,
)
def get_evaluation_dataset(
task,
input_image_path,
gt_path,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
num_samples_per_partition,
num_partitions,
partition_id,
num_frames,
vision_model_type,
split="validation",
):
"""Get an evaluation dataset."""
if task == "TextVQA":
keys = {
"image_id": "image_id",
"sample_id": "question_id",
"question": "question",
"answer": "answers",
}
dataset = VQADataset(
input_image_path,
gt_path,
num_samples_per_partition,
num_partitions,
partition_id,
keys,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
vision_model_type,
)
elif task == "VQAv2":
keys = {
"image_id": "image",
"sample_id": "question_id",
"question": "question",
"answer": "answer",
}
dataset = VQADataset(
input_image_path,
gt_path,
num_samples_per_partition,
num_partitions,
partition_id,
keys,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
vision_model_type,
)
elif task == "ChartQA":
keys = {"image_id": "imgname", "question": "query", "answer": "label"}
dataset = VQADataset(
input_image_path,
gt_path,
num_samples_per_partition,
num_partitions,
partition_id,
keys,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
vision_model_type,
)
elif task == "captioning":
dataset = CaptioningDataset(
input_image_path,
gt_path,
num_samples_per_partition,
num_partitions,
partition_id,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
vision_model_type,
)
elif task == 'MMMU':
# Note:
# - prompt_style="single_image" uses only one image like in the MMMU repo example.
# - prompt_style="multi_image" uses multiple input images.
# - prompt_style="vlmevalkit" is similar to https://github.com/open-compass/VLMEvalKit/blob/5d3cebcf18ef4bfbadc3bd3ef80bdc7aad2c6557/vlmeval/vlm/internvl_chat.py#L499
dataset = MMMUDataset(
input_image_path,
num_samples_per_partition,
num_partitions,
partition_id,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
prompt_style="single_image",
vision_model_type=vision_model_type,
split=split,
)
elif task == 'RealworldQA':
dataset = RealworldQADataset(
input_image_path,
gt_path,
num_samples_per_partition,
num_partitions,
partition_id,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
vision_model_type=vision_model_type,
)
elif task in ["OCRBench", "OCRBench_v2"]:
dataset = OCRBenchDataset(
input_image_path,
gt_path,
num_samples_per_partition,
num_partitions,
partition_id,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
vision_model_type,
)
elif task == "MathVista":
dataset = MathVistaDataset(
input_image_path,
num_samples_per_partition,
num_partitions,
partition_id,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
vision_model_type,
)
elif task == "AI2D":
dataset = AI2DDataset(
input_image_path,
gt_path,
num_samples_per_partition,
num_partitions,
partition_id,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
vision_model_type=vision_model_type,
)
elif task == "SPDocVQA":
keys = {"sample_id": "questionId", "image_id": "image", "question": "question", "answer": "answers"}
dataset = VQADataset(
input_image_path,
gt_path,
num_samples_per_partition,
num_partitions,
partition_id,
keys,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
vision_model_type,
)
elif task == "InfoVQA":
keys = {"sample_id": "questionId", "image_id": "image_local_name", "question": "question", "answer": "answers"}
dataset = VQADataset(
input_image_path,
gt_path,
num_samples_per_partition,
num_partitions,
partition_id,
keys,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
vision_model_type,
)
elif task == "RD_TableBench":
dataset = RDTableBenchDataset(
input_image_path,
gt_path,
num_samples_per_partition,
num_partitions,
partition_id,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
vision_model_type,
)
### video QA
elif task == "VideoMME":
dataset = VideoMMEDataset(
input_image_path,
gt_path,
num_samples_per_partition,
num_partitions,
partition_id,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
num_frames,
vision_model_type,
)
elif task == "MotionBench":
dataset = MotionBenchDataset(
input_image_path,
gt_path,
num_samples_per_partition,
num_partitions,
partition_id,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
num_frames,
vision_model_type,
split=split
)
elif task == "PhysGameBench":
dataset = PhysGameBenchDataset(
input_image_path,
gt_path,
num_samples_per_partition,
num_partitions,
partition_id,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
num_frames,
vision_model_type,
split=split
)
elif task == "MVBench":
dataset = MVBenchDataset(
input_image_path,
gt_path,
num_samples_per_partition,
num_partitions,
partition_id,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
num_frames,
vision_model_type,
split=split
)
elif task == "inference":
dataset = ExampleInferenceDataset(
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
vision_model_type,
)
else:
raise NotImplementedError(f"unsupported task {task}")
return dataset
# The following code is adapted from
# https://github.com/MMMU-Benchmark/MMMU/blob/main/mmmu/utils/data_utils.py,
# which is licensed under the Apache License 2.0. More details on the license can be
# found at https://github.com/MMMU-Benchmark/MMMU/tree/main?tab=Apache-2.0-1-ov-file#readme
"""Utils for data load, save, and process (e.g., prompt construction)"""
import os
import json
import yaml
import re
DOMAIN_CAT2SUB_CAT = {
'Art and Design': ['Art', 'Art_Theory', 'Design', 'Music'],
'Business': ['Accounting', 'Economics', 'Finance', 'Manage', 'Marketing'],
'Science': ['Biology', 'Chemistry', 'Geography', 'Math', 'Physics', ],
'Health and Medicine': ['Basic_Medical_Science', 'Clinical_Medicine', 'Diagnostics_and_Laboratory_Medicine',
'Pharmacy', 'Public_Health'],
'Humanities and Social Science': ['History', 'Literature', 'Sociology', 'Psychology'],
'Tech and Engineering': ['Agriculture', 'Architecture_and_Engineering', 'Computer_Science', 'Electronics',
'Energy_and_Power', 'Materials', 'Mechanical_Engineering'],
}
CAT_SHORT2LONG = {
'acc': 'Accounting',
'agri': 'Agriculture',
'arch': 'Architecture_and_Engineering',
'art': 'Art',
'art_theory': 'Art_Theory',
'bas_med': 'Basic_Medical_Science',
'bio': 'Biology',
'chem': 'Chemistry',
'cli_med': 'Clinical_Medicine',
'cs': 'Computer_Science',
'design': 'Design',
'diag_med': 'Diagnostics_and_Laboratory_Medicine',
'econ': 'Economics',
'elec': 'Electronics',
'ep': 'Energy_and_Power',
'fin': 'Finance',
'geo': 'Geography',
'his': 'History',
'liter': 'Literature',
'manage': 'Manage',
'mark': 'Marketing',
'mate': 'Materials',
'math': 'Math',
'mech': 'Mechanical_Engineering',
'music': 'Music',
'phar': 'Pharmacy',
'phys': 'Physics',
'psy': 'Psychology',
'pub_health': 'Public_Health',
'socio': 'Sociology'
}
def load_yaml(file_path):
with open(file_path, 'r') as stream:
try:
yaml_dict = yaml.safe_load(stream)
except yaml.YAMLError as exc:
print(exc)
return yaml_dict
def parse_img_path(text):
matches = re.findall("<img='(.*?)'>", text)
return matches
def process_single_sample(data):
question = data['question']
o_imgs_paths = []
for option in data['options']:
current_o_imgs_paths = parse_img_path(option)
for img_path in current_o_imgs_paths:
o_imgs_paths.append(img_path)
categories = list(CAT_SHORT2LONG.values())
for c in categories:
if c in data['id']:
field = c.lower().replace('_', ' ')
break
if len(o_imgs_paths) > 1: # multiple images in options, used for random selection
return {'id': data['id'], 'question': question, 'options': data['options'], 'answer': data['answer'],
'image': None, 'question_type': data['question_type'],
'field': field, 'subfield': data['subfield']}
else:
return {'id': data['id'], 'question': question, 'options': data['options'], 'answer': data['answer'],
'image': data['image_1'], 'question_type': data['question_type'],
'field': field, 'subfield': data['subfield']}
# DATA PROCESSING
def construct_prompt(sample, config):
question = sample['question'].strip()
options = eval(sample['options'])
example = ""
if sample['question_type'] == 'multiple-choice':
start_chr = 'A'
prediction_range = []
index2ans = {}
for option in options:
prediction_range.append(start_chr)
example += f"({start_chr}) {option}\n"
index2ans[start_chr] = option
start_chr = chr(ord(start_chr) + 1)
empty_prompt_sample_structure = config['multi_choice_example_format']
empty_prompt = empty_prompt_sample_structure.format(question, example)
res_dict = {'type': 'multichoice'}
res_dict['index2ans'] = index2ans
res_dict['correct_choice'] = sample['answer']
res_dict['all_choices'] = prediction_range
res_dict['empty_prompt'] = empty_prompt
if config['task_instructions']:
res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
else:
res_dict['final_input_prompt'] = empty_prompt
res_dict['gt_content'] = options[ord(sample['answer'].upper()) - ord('A')]
else:
empty_prompt_sample_structure = config['short_ans_example_format']
empty_prompt = empty_prompt_sample_structure.format(question)
res_dict = {'type': 'open'}
res_dict['empty_prompt'] = empty_prompt
if config['task_instructions']:
res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
else:
res_dict['final_input_prompt'] = empty_prompt
res_dict['gt_content'] = sample['answer']
res_dict.update(sample)
return res_dict
"""Response Parsing and Evaluation for various models"""
from typing import Dict
import re
import random
import numpy as np
# ----------- Process Multi-choice -------------
def parse_multi_choice_response(response, all_choices, index2ans):
"""
Parse the prediction from the generated response.
Return the predicted index e.g., A, B, C, D.
"""
for char in [',', '.', '!', '?', ';', ':', "'"]:
response = response.strip(char)
response = " " + response + " " # add space to avoid partial match
index_ans = True
ans_with_brack = False
candidates = []
for choice in all_choices: # e.g., (A) (B) (C) (D) A) B) C) D)
if f'({choice})' in response or f'{choice})' in response:
candidates.append(choice)
ans_with_brack = True
if len(candidates) == 0:
for choice in all_choices: # e.g., A B C D
if f' {choice} ' in response:
candidates.append(choice)
# if none of the above yields candidates and the response is longer than 5 tokens, try to match the option contents instead
if len(candidates) == 0 and len(response.split()) > 5:
for index, ans in index2ans.items():
if ans.lower() in response.lower():
candidates.append(index)
index_ans = False # it's content ans.
if len(candidates) == 0: # still no answer found, fall back to the first choice.
pred_index = all_choices[0]
elif len(candidates) > 1:
start_indexes = []
if index_ans:
if ans_with_brack:
for can in candidates:
index = response.rfind(f'({can})')
start_indexes.append(index) # -1 will be ignored anyway
else:
for can in candidates:
index = response.rfind(f" {can} ")
start_indexes.append(index)
else:
for can in candidates:
index = response.lower().rfind(index2ans[can].lower())
start_indexes.append(index)
# get the last one
pred_index = candidates[np.argmax(start_indexes)]
else: # if only one candidate, use it.
pred_index = candidates[0]
return pred_index
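# Illustrative usage (the response and options below are made up, not from the benchmark):
#   parse_multi_choice_response("I think the answer is (B) cat.", ["A", "B", "C", "D"],
#                               {"A": "dog", "B": "cat", "C": "fish", "D": "bird"})
#   returns "B": the bracketed option letter is found directly in the response.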
# ----------- Process Open -------------
def check_is_number(string):
"""
    Check if the given string is a number.
"""
try:
float(string.replace(',', ''))
return True
except ValueError:
        # not a number (e.g., contains non-numeric characters)
return False
def normalize_str(string):
"""
Normalize the str to lower case and make them float numbers if possible.
"""
    # if the string is a number, convert it to a float rounded to 2 decimals
string = string.strip()
is_number = check_is_number(string)
if is_number:
string = string.replace(',', '')
string = float(string)
# leave 2 decimal
string = round(string, 2)
return [string]
else: # it's likely to be a string
# lower it
string = string.lower()
if len(string) == 1:
return [" " + string, string + " "] # avoid trivial matches
return [string]
def extract_numbers(string):
"""
    Extract all forms of numbers from a string with regex.
"""
# Pattern for numbers with commas
pattern_commas = r'-?\b\d{1,3}(?:,\d{3})+\b'
# Pattern for scientific notation
pattern_scientific = r'-?\d+(?:\.\d+)?[eE][+-]?\d+'
# Pattern for simple numbers without commas
pattern_simple = r'-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])'
# Extract numbers with commas
numbers_with_commas = re.findall(pattern_commas, string)
# Extract numbers in scientific notation
numbers_scientific = re.findall(pattern_scientific, string)
# Extract simple numbers without commas
numbers_simple = re.findall(pattern_simple, string)
# Combine all extracted numbers
all_numbers = numbers_with_commas + numbers_scientific + numbers_simple
return all_numbers
def parse_open_response(response):
"""
Parse the prediction from the generated response.
Return a list of predicted strings or numbers.
"""
# content = content.strip("\n").strip(".").strip(" ")
def get_key_subresponses(response):
key_responses = []
response = response.strip().strip(".").lower()
sub_responses = re.split(r'\.\s(?=[A-Z])|\n', response)
indicators_of_keys = ['could be ', 'so ', 'is ',
'thus ', 'therefore ', 'final ', 'answer ', 'result ']
key_responses = []
for index, resp in enumerate(sub_responses):
# if last one, accept it's an equation (the entire response can be just one sentence with equation)
if index == len(sub_responses) - 1:
indicators_of_keys.extend(['='])
shortest_key_response = None # the shortest response that may contain the answer (tail part of the response)
for indicator in indicators_of_keys:
if indicator in resp:
if not shortest_key_response:
shortest_key_response = resp.split(indicator)[-1].strip()
else:
if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
shortest_key_response = resp.split(indicator)[-1].strip()
if shortest_key_response:
# and it's not trivial
if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]:
key_responses.append(shortest_key_response)
        if len(key_responses) == 0:  # did not find any
return [response]
return key_responses
# pdb.set_trace()
key_responses = get_key_subresponses(response)
pred_list = key_responses.copy() # keep the original string response
for resp in key_responses:
pred_list.extend(extract_numbers(resp))
tmp_pred_list = []
for i in range(len(pred_list)):
tmp_pred_list.extend(normalize_str(pred_list[i]))
pred_list = tmp_pred_list
# remove duplicates
pred_list = list(set(pred_list))
return pred_list
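# Illustrative usage (made-up response): parse_open_response("Therefore the answer is 3,000.")
# returns [3000.0] -- "3,000" is extracted after the "is " indicator and normalized to a float.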
# ----------- Evaluation -------------
def eval_multi_choice(gold_i, pred_i):
"""
Evaluate a multiple choice instance.
"""
correct = False
    # only if the prediction exactly matches the gold answer do we consider it correct
if isinstance(gold_i, list):
for answer in gold_i:
if answer == pred_i:
correct = True
break
else: # gold_i is a string
if gold_i == pred_i:
correct = True
return correct
def eval_open(gold_i, pred_i):
"""
Evaluate an open question instance
"""
correct = False
if isinstance(gold_i, list):
# use float to avoid trivial matches
norm_answers = []
for answer in gold_i:
norm_answers.extend(normalize_str(answer))
else:
norm_answers = normalize_str(gold_i)
for pred in pred_i: # pred is already normalized in parse response phase
if isinstance(pred, str): # if it's a string, then find if ans in the pred_i
for norm_ans in norm_answers:
# only see if the string answer in the string pred
if isinstance(norm_ans, str) and norm_ans in pred:
if not correct:
correct = True
break
else: # it's a float number
if pred in norm_answers:
if not correct:
correct = True
break
return correct
# ----------- Batch Evaluation -------------
def evaluate(samples):
"""
Batch evaluation for multiple choice and open questions.
"""
pred_correct = 0
judge_dict = dict()
for sample in samples:
gold_i = sample['answer']
pred_i = sample['parsed_pred']
if sample['question_type'] == 'multiple-choice':
correct = eval_multi_choice(gold_i, pred_i)
else: # open question
correct = eval_open(gold_i, pred_i)
if correct:
judge_dict[sample['id']] = 'Correct'
pred_correct += 1
else:
judge_dict[sample['id']] = 'Wrong'
if len(samples) == 0:
        return judge_dict, {'acc': 0}
return judge_dict, {'acc': pred_correct / len(samples)}
# ----------- Calculate Accuracy -------------
def calculate_ins_level_acc(results: Dict):
"""Calculate the instruction level accuracy for given Subject results"""
acc = 0
ins_num = 0
for cat_results in results.values():
acc += cat_results['acc'] * cat_results['num_example']
ins_num += cat_results['num_example']
if ins_num == 0:
return 0
return acc / ins_num
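# Example: {"Art": {"acc": 0.5, "num_example": 10}, "Math": {"acc": 0.8, "num_example": 30}}
# yields (0.5 * 10 + 0.8 * 30) / 40 = 0.725, i.e. an example-weighted average over subjects.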
def mmmu_main_eval(output_dict, task_cfg):
answer_dict = json.load(open(task_cfg["answer_dict"]))
# group by category
output_dict_w_cat = {}
for data_id, parsed_pred in output_dict.items():
category = "_".join(data_id.split("_")[1:-1])
if category not in output_dict_w_cat:
output_dict_w_cat.update({category: {}})
output_dict_w_cat[category].update({data_id: parsed_pred})
# group by category
answer_dict_w_cat = {}
for data_id, parsed_pred in answer_dict.items():
category = "_".join(data_id.split("_")[1:-1])
if category not in answer_dict_w_cat:
answer_dict_w_cat.update({category: {}})
answer_dict_w_cat[category].update({data_id: parsed_pred})
evaluation_result = {}
for category in CAT_SHORT2LONG.values():
# print("Evaluating: {}".format(category))
# get cat_outputs and cat_answers
try:
cat_outputs = output_dict_w_cat[category]
cat_answers = answer_dict_w_cat[category]
except KeyError:
print("Skipping {} for not found".format(category))
continue
        examples_to_eval = []
        for data_id, parsed_pred in cat_outputs.items():
            question_type = cat_answers[data_id]['question_type']
            if question_type != 'multiple-choice':
                parsed_pred = parse_open_response(parsed_pred)  # mainly for type consistency (make it a number, etc.)
            examples_to_eval.append({
                "id": data_id,
                "question_type": question_type,
                "answer": cat_answers[data_id]['ground_truth'],
                "parsed_pred": parsed_pred
            })
        judge_dict, metric_dict = evaluate(examples_to_eval)
        metric_dict.update({"num_example": len(examples_to_eval)})
evaluation_result[category] = metric_dict
printable_results = {}
# pdb.set_trace()
# add domain Subject
for domain, in_domain_cats in DOMAIN_CAT2SUB_CAT.items():
in_domain_cat_results = {}
        for cat_name in in_domain_cats:  # use the order in DOMAIN_CAT2SUB_CAT
            if cat_name in evaluation_result:
                in_domain_cat_results[cat_name] = evaluation_result[cat_name]
in_domain_ins_acc = calculate_ins_level_acc(in_domain_cat_results)
in_domain_data_num = sum([cat_results['num_example'] for cat_results in in_domain_cat_results.values()])
printable_results['Overall-' + domain] = {"num": int(in_domain_data_num),
"acc": round(in_domain_ins_acc, 4)
}
# add sub category
for cat_name, cat_results in in_domain_cat_results.items():
printable_results[cat_name] = {"num": int(cat_results['num_example']),
"acc": round(cat_results['acc'], 4)
}
# table.append(["-----------------------------", "-----", "----"])
all_ins_acc = calculate_ins_level_acc(evaluation_result)
printable_results['Overall'] = {
"num": sum([cat_results['num_example'] for cat_results in evaluation_result.values()]),
"acc": round(all_ins_acc, 4)
}
return printable_results
if __name__ == '__main__':
tasks = yaml.safe_load(open("eval_config/eval_mmmu_yi.yaml"))['datasets']
print(tasks)
with open("eval_results.json") as f:
merged_results = json.load(f)
eval_samples = []
eval_output_dict = {}
for res in merged_results:
pred_ans = res["answer"].upper()
gt_ans = res['gt_answer']
if res['question_type'] == 'multiple-choice':
parsed_pred = parse_multi_choice_response(pred_ans, res['all_choices'], res['index2ans'])
if pred_ans != parsed_pred:
print(f"MC: Original: {pred_ans}, Parsed: {parsed_pred}")
eval_samples.append(
{
'id': res['question_id'],
'question_type': res['question_type'],
'answer': res['gt_answer'], # the content in option, not answer index.
'response': pred_ans,
'parsed_pred': parsed_pred,
'index2ans': res['index2ans'],
}
)
eval_output_dict[res['question_id']] = parsed_pred
else:
parsed_pred = parse_open_response(pred_ans)
if pred_ans != parsed_pred:
print(f"Open: Original: {pred_ans}, Parsed: {parsed_pred}")
eval_samples.append(
{
'id': res['question_id'],
'question_type': res['question_type'],
'answer': res['gt_answer'],
'response': pred_ans,
'parsed_pred': parsed_pred,
}
)
eval_output_dict[res['question_id']] = pred_ans
json.dump(eval_output_dict, open("validation_mmmu_iter6000_merged.0.53.sorted.json", "w"), indent=4, sort_keys=True)
x = mmmu_main_eval(eval_output_dict,
task_cfg=tasks['mmmu'])
print(x)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. Except portions as noted which are Copyright (c) 2023 OpenGVLab and licensed under the MIT license found in LICENSE.
from torchvision import transforms as T
from torchvision.transforms import Compose
from torchvision.transforms.functional import InterpolationMode
IMAGENET_PIXEL_MEAN = [0.485, 0.456, 0.406]
IMAGENET_PIXEL_STD = [0.229, 0.224, 0.225]
SIGLIP_PIXEL_MEAN = [0.5, 0.5, 0.5]
SIGLIP_PIXEL_STD = [0.5, 0.5, 0.5]
CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711]
RADIO_G_PIXEL_MEAN = [0.4850, 0.4560, 0.4060]
RADIO_G_PIXEL_STD = [0.2230, 0.2240, 0.2250]
pixel_statistics = {
"clip": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
"siglip": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
"internvit": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD),
"radio": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
"radio-g": (RADIO_G_PIXEL_MEAN, RADIO_G_PIXEL_STD),
"cradio-g": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
"internvit300M": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD),
"huggingface": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
}
# From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L685
# Copyright (c) 2023 OpenGVLab.
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
best_ratio_diff = float('inf')
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
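# Example: a 1000x500 image (aspect ratio 2.0) with tilings of up to 6 tiles and image_size=448
# selects (2, 1), since that tiling matches the aspect ratio exactly.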
def find_closest_area_weighted_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
"""
Find the best number of tiles based on the aspect ratio and the area covered by the tiles.
"""
best_factor = float('-inf')
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
factor_based_on_area_n_ratio = (
min((ratio[0]*ratio[1]*image_size*image_size)/ area, 0.6) *
min(target_aspect_ratio/aspect_ratio, aspect_ratio/target_aspect_ratio))
if factor_based_on_area_n_ratio > best_factor:
best_factor = factor_based_on_area_n_ratio
best_ratio = ratio
return best_ratio
class ImageTransform:
"""Image transformation."""
def __init__(self, input_size, vision_model_type):
self._transform = _build_transform(input_size, vision_model_type)
self._vision_model_type = vision_model_type
def __call__(self, img, img_h, img_w, use_tiling=False, max_num_tiles=1, use_thumbnail=False, augment=False, find_closest_aspect_ratio_fn=find_closest_aspect_ratio):
assert not augment, "Image augmentation not implemented."
if use_tiling:
assert img_h == img_w, "dynamic tiling expects equal tile height and width"
imgs = dynamic_preprocess(
img, min_num=1, max_num=max_num_tiles, image_size=img_h, use_thumbnail=use_thumbnail,
find_closest_aspect_ratio_fn=find_closest_aspect_ratio_fn)
imgs = [self._transform(img) for img in imgs]
else:
imgs = [self._transform(img)]
return imgs
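# Illustrative usage (arguments below are placeholders, not a prescribed configuration):
#   transform = ImageTransform(input_size=448, vision_model_type="siglip")
#   tiles = transform(pil_image, img_h=448, img_w=448, use_tiling=True, max_num_tiles=12, use_thumbnail=True)
#   # "tiles" is a list of normalized tensors, one per tile (plus a thumbnail when requested).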
# From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L702
# Copyright (c) 2023 OpenGVLab.
def dynamic_preprocess(
image, min_num=1, max_num=6, image_size=448, use_thumbnail=False,
find_closest_aspect_ratio_fn=find_closest_aspect_ratio):
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set(
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
i * j <= max_num and i * j >= min_num)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio_fn(
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images
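# Example: a 1000x500 PIL image with max_num=6, image_size=448 and use_thumbnail=True is resized
# to 896x448, split into a 2x1 grid of 448x448 tiles, and a 448x448 thumbnail is appended (3 images).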
# Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79
# and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276
def _build_transform(input_size, vision_model_type):
if vision_model_type in ("siglip", "internvit", "internvit300M", "radio", "radio-g", "cradio-g"):
pixel_mean, pixel_std = pixel_statistics[vision_model_type]
transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=pixel_mean, std=pixel_std)
])
elif vision_model_type == "clip":
pixel_mean, pixel_std = pixel_statistics[vision_model_type]
transform = Compose([
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.ToTensor(),
T.Normalize(mean=pixel_mean, std=pixel_std),
])
elif vision_model_type.startswith("hf://"):
from megatron.core.models.huggingface.module import get_hf_model_type
model_type = get_hf_model_type(vision_model_type)
if "siglip" in model_type:
from transformers.models.siglip.image_processing_siglip import SiglipImageProcessor
processor = SiglipImageProcessor(size={"height": input_size, "width": input_size})
def transform(x):
x = x.convert("RGB") if x.mode != "RGB" else x
x = processor(x, return_tensors="pt")
return x["pixel_values"][0]
else:
raise NotImplementedError(f"image processing not defined for huggingface model {vision_model_type}")
else:
raise NotImplementedError(f"image processing not defined for vision model {vision_model_type}")
return transform
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
from functools import partial
import torch
from megatron.core.transformer.transformer_layer import TransformerLayer
def _bias_dropout_add_func_layer_scaling(ls, x_with_bias, residual, prob, training):
    x, bias = x_with_bias  # unpack
    residual = residual if residual.dtype == x.dtype else residual.to(x.dtype)
    if bias is not None:
        x = x + bias
    out = torch.nn.functional.dropout(x, p=prob, training=training)
    out = residual + out * ls
    return out
def bias_dropout_add_unfused_layer_scaling(ls, training):
"""Bias-dropout-add as in Megatron but with added LayerScaling handling."""
def _bias_dropout_add(x_with_bias, residual, prob):
return _bias_dropout_add_func_layer_scaling(ls, x_with_bias, residual, prob, training)
return _bias_dropout_add
def get_bias_dropout_add_layer_scaling(ls, training, fused):
"""Bias-dropout-add as in Megatron but with added LayerScaling handling."""
assert not fused, "Fused bias-dropout-add not implemented for LayerScaling."
return bias_dropout_add_unfused_layer_scaling(ls, training)
# Add LayerScaling to our default TransformerLayer.
class LayerScalingTransformerLayer(TransformerLayer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.ls1 = torch.nn.Parameter(torch.ones(self.config.hidden_size))
self.ls2 = torch.nn.Parameter(torch.ones(self.config.hidden_size))
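        # ls1/ls2 are learnable per-channel scales (LayerScale-style) applied to the attention and
        # MLP residual branches, respectively, via the bias-dropout-add functions wrapped below.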
self.self_attn_bda = partial(self.self_attn_bda, self.ls1)
self.mlp_bda = partial(self.mlp_bda, self.ls2)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import torch
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules
from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules
from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules
from megatron.core.ssm.mlp_layer import MLPLayer
try:
from megatron.core.extensions.transformer_engine import (
TEColumnParallelLinear,
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TENorm,
TERowParallelLinear,
)
HAVE_TE = True
except ImportError:
HAVE_TE = False
try:
import apex
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
from megatron.core.transformer.torch_norm import WrappedTorchNorm
HAVE_APEX = True
LNImpl = FusedLayerNorm
except ImportError:
import warnings
from megatron.core.transformer.torch_norm import WrappedTorchNorm
warnings.warn(f'Apex is not installed. Falling back to Torch Norm')
LNImpl = WrappedTorchNorm
def get_layer_spec(is_vit, normalization) -> ModuleSpec:
attn_mask_type = AttnMaskType.no_mask if is_vit else AttnMaskType.causal
if normalization == "LayerNorm":
norm = LNImpl
elif normalization == "RMSNorm":
if HAVE_TE:
norm = TENorm
else:
            version = torch.__version__.split('.')
            version_geq_2_4 = (
                int(version[0]) > 2
                or (
                    int(version[0]) == 2
                    and int(version[1]) >= 4
                )
            )
            assert version_geq_2_4, "Torch version >= 2.4.0 is required for RMSNorm"
            if HAVE_APEX:
                # warnings is only imported at module level when Apex is missing, so import it here.
                import warnings
                warnings.warn('Apex does not support RMSNorm. Falling back to Torch Norm')
                norm = WrappedTorchNorm
else:
raise RuntimeError("unknown normalization", normalization)
mlp = get_mlp_module_spec(use_te=False) # doesn't include norm.
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=norm,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": attn_mask_type},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=norm,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
),
)
def get_layer_spec_te(is_vit=False, padding=False) -> ModuleSpec:
attn_mask_type = AttnMaskType.no_mask if is_vit else AttnMaskType.causal
# Padding mask is needed for e.g. Context Parallel.
if padding:
assert not is_vit, "padding_causal mask not used with ViT"
attn_mask_type = AttnMaskType.padding_causal
mlp = get_norm_mlp_module_spec_te()
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": attn_mask_type},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=IdentityOp,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
),
)
def get_mamba_layer_spec_te(padding=False) -> ModuleSpec:
attn_mask_type = AttnMaskType.causal
# Padding mask is needed for e.g. Context Parallel.
if padding:
attn_mask_type = AttnMaskType.padding_causal
return ModuleSpec(
module=MambaStack,
submodules=MambaStackSubmodules(
mamba_layer=ModuleSpec(
module=MambaLayer,
submodules=MambaLayerSubmodules(
mixer=ModuleSpec(
module=MambaMixer,
submodules=MambaMixerSubmodules(
in_proj=TELayerNormColumnParallelLinear, out_proj=TERowParallelLinear
),
),
mamba_bda=get_bias_dropout_add,
),
),
# Started with spec from gpt_layer_specs.py (with MLP removed)
# Using the TE spec because we had problems getting the non-TE spec
# working
attention_layer=ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": attn_mask_type},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
),
),
self_attn_bda=get_bias_dropout_add,
),
),
# Started with spec from gpt_layer_specs.py
# Using the TE spec because we had problems getting the non-TE spec
# working
mlp_layer=ModuleSpec(
module=MLPLayer,
submodules=TransformerLayerSubmodules(
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear
),
),
mlp_bda=get_bias_dropout_add,
),
),
),
)
def get_mlp_module_spec(use_te: bool = True) -> ModuleSpec:
# Dense MLP w/ or w/o TE modules.
return ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear,
linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
),
)
def get_norm_mlp_module_spec_te() -> ModuleSpec:
return ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear
),
)
FROM nvcr.io/nvidia/pytorch:25.04-py3
RUN apt update && \
apt -y upgrade && \
apt install -y --no-install-recommends \
software-properties-common \
build-essential \
python3-pip \
python3-dev \
bash \
git \
vim \
python-is-python3 \
default-jre \
net-tools \
wget \
curl \
rsync \
zip \
unzip \
htop \
tmux \
bmon
RUN pip install --upgrade pip
RUN git clone https://github.com/Dao-AILab/causal-conv1d.git && cd causal-conv1d && git checkout && CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install . --no-build-isolation
RUN git clone https://github.com/state-spaces/mamba.git && cd mamba && git checkout && MAMBA_FORCE_BUILD=TRUE pip install . --no-build-isolation
RUN pip install numpy
RUN pip install einops einops-exts sentencepiece braceexpand webdataset packaging
RUN pip install transformers datasets accelerate timm
RUN pip install pytest-cov pytest_mock nltk wrapt
RUN pip install black isort pylint mypy click
RUN pip install mistral-common tiktoken
RUN pip install git+https://github.com/openai/CLIP.git
RUN pip install fairscale fire blobfile
# Use --no-deps for the following to avoid outdated and unnecessary dependencies.
RUN pip install mmf --no-deps
RUN pip install open_clip_torch open-flamingo[eval] --no-deps
RUN pip install zarr "tensorstore==0.1.45"
RUN pip install git+https://github.com/NVIDIA/Megatron-Energon.git#egg=megatron-energon[av_decode]
# Llama-3.1-Nemotron-Nano-VL-8B-V1
See [Hugging Face](https://huggingface.co/nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1) for details.
# Checkpoints
- [HuggingFace version](https://huggingface.co/nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1)
- [Megatron-Core version](https://huggingface.co/nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1-mcore)
# Setup
## Docker image
See `examples/multimodal/llama_3p1_nemotron_nano_vl_8b_v1/Dockerfile`.
## Dataset preparation
We use [Megatron Energon](https://github.com/NVIDIA/Megatron-Energon) for multimodal dataloading.
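As a rough sketch (the dataset path below is a placeholder, not part of this repo), a WebDataset-based
multimodal dataset can typically be made Energon-compatible with the `energon prepare` command and then
referenced from the blend YAML files used by the training scripts:
```
energon prepare <path to your webdataset shards>
```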
## Model
You can download trained Megatron checkpoints with tensor parallel size 1 and 4 [here](https://huggingface.co/nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1-mcore).
Alternatively, you can follow the steps in [Model conversion](#model-conversion) and [Training](#training) below to prepare a model
and run pretraining and SFT from scratch using a prepared dataset.
### Model conversion
#### Language model conversion
We start from [meta-llama/Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) on Hugging Face.
Download it and run the following command to convert it to Megatron format.
```
export LLAMA_DOWNLOAD_DIR=<downloaded hf model directory>
CUDA_DEVICE_MAX_CONNECTIONS=1 python tools/checkpoint/convert.py --bf16 --model-type GPT --loader llama_mistral --saver core \
--target-tensor-parallel-size 4 --checkpoint-type hf \
--load-dir $LLAMA_DOWNLOAD_DIR --save-dir llama3p1 --tokenizer-model $LLAMA_DOWNLOAD_DIR \
--saver-transformer-impl transformer_engine --model-size llama3
```
#### Vision model conversion
You can run the following command to convert RADIO to an mcore compatible format:
```
python examples/multimodal/model_converter/radio_converter.py --output radio_tp_4 --tensor-parallel-size 4 --use-te \
--version c-radio_v2-vlm-h --model-type radio_v2.5-h
```
#### Combined checkpoint
Combine the language and vision model by running:
```
examples/multimodal/combine_lm_vision_checkpoints.sh <language model directory> <vision model directory> <output directory>
```
# Training
1. Pretraining: we provide an example pretraining script at `examples/multimodal/llama_3p1_nemotron_nano_vl_8b_v1/pretraining_llama_3p1_nemotron_nano_vl_8b_v1.sh`.
2. SFT: we provide an example SFT script at `examples/multimodal/llama_3p1_nemotron_nano_vl_8b_v1/sft_llama_3p1_nemotron_nano_vl_8b_v1.sh`.
# Inference and evaluation
To run a simple inference example:
```
export LLAMA_NEMOTRON_NANO_VL_PATH=<path to the megatron tp=4 checkpoint>
examples/multimodal/llama_3p1_nemotron_nano_vl_8b_v1/text_generation.sh --model-path $LLAMA_NEMOTRON_NANO_VL_PATH \
--task inference --output-path inference-example --tensor-model-parallel-size 4
```
To evaluate the model, you can change `--task` to `MMMU` or `TextVQA`, for example.
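For example, a hedged sketch of an MMMU evaluation run (paths are placeholders; the flags mirror the
inference example above and `text_generation.sh` in this directory; task-specific flags such as
`--input-image-path` and `--gt-path` may also be required depending on the task):
```
examples/multimodal/llama_3p1_nemotron_nano_vl_8b_v1/text_generation.sh --model-path $LLAMA_NEMOTRON_NANO_VL_PATH \
    --task MMMU --output-path mmmu-eval --tensor-model-parallel-size 4
```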
#!/bin/bash
# Your SBATCH commands here if using SLURM.
# Please launch this script from megatron-lm root.
# Train a multimodal model.
export CUDA_DEVICE_MAX_CONNECTIONS=1
USER=$SLURM_JOB_USER
# Auto-detect batch or interactive mode.
which srun
BATCH=$((1-$?))
DEBUG=0
USE_TILING=1
# Remember to update model and job name if running in batch mode!!
if [[ $BATCH -eq 0 ]]; then
DATETIME=`date +'%y-%m-%d-%H-%M-%S'`
MODEL_NAME="interactive_pretraining_llama_3p1_nemotron_nano_vl_8b_v1_${DATETIME}"
SPECIAL_TOKENS="--special-tokens <image> <img> </img> <quad> </quad> <ref> </ref> <box> </box>"
DEBUG=1
else
MODEL_NAME="pretraining_llama_3p1_nemotron_nano_vl_8b_v1"
SPECIAL_TOKENS="--special-tokens \<image\> \<img\> \</img\> \<quad\> \</quad\> \<ref\> \</ref\> \<box\> \</box\>"
fi
WORKSPACE="<some dir>"
SOURCE=`pwd`
OUTPUT_BASE="${WORKSPACE}/output"
OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}"
FINETUNE_DIR=${OUTPUT}/checkpoints
LOGS_DIR="${OUTPUT}/logs"
TENSORBOARD_DIR="${OUTPUT}/tensorboard"
TP=4
CHECKPOINT_DIR="${WORKSPACE}/output/${LOAD_NAME}/checkpoints"
DATA_TRAIN="${SOURCE}/examples/multimodal/llama_3p1_nemotron_nano_vl_8b_v1/pretrain_blend.yaml"
if [[ $DEBUG -eq 1 ]]; then
MBZ=1
BZ=1
NW=0
AD=0.0
HD=0.0
LI=1
NONDETERMINISTIC_ATTN=1
NUM_GPU=4
export CUDA_VISIBLE_DEVICES=0,1,2,3
else
MBZ=1
BZ=1024
NW=8
AD=0.0
HD=0.0
LI=5
EXTRA_ARGS=""
NONDETERMINISTIC_ATTN=1
NUM_GPU=8
fi
SEQ_LEN=1024
DECODER_SEQ_LEN=4096
if [[ $USE_TILING -eq 1 ]]; then
EXTRA_ARGS+=" --pixel-shuffle --use-tiling --max-num-tiles 12 --use-thumbnail"
SEQ_LEN=256
fi
OPTIONS=" \
--use-checkpoint-args \
--disable-bias-linear \
--tokenizer-type MultimodalTokenizer \
--tokenizer-model meta-llama/Llama-3.1-8B-Instruct \
--transformer-impl transformer_engine \
--normalization RMSNorm \
--group-query-attention \
--num-query-groups 8 \
--no-masked-softmax-fusion \
--attention-softmax-in-fp32 \
--attention-dropout ${AD} \
--hidden-dropout ${HD} \
--untie-embeddings-and-output-weights \
--position-embedding-type rope \
--rotary-percent 1.0 \
--rotary-base 500000 \
--use-rope-scaling \
--swiglu \
--tensor-model-parallel-size ${TP} \
--pipeline-model-parallel-size 1 \
--num-layers 32 \
--hidden-size 4096 \
--ffn-hidden-size 14336 \
--num-attention-heads 32 \
--use-distributed-optimizer \
--use-te \
--num-workers ${NW} \
--exit-duration-in-mins 230 \
--seq-length ${SEQ_LEN} \
--decoder-seq-length ${DECODER_SEQ_LEN} \
--max-position-embeddings 131072 \
--train-samples 1491231 \
--lr-warmup-samples 102400 \
--micro-batch-size ${MBZ} \
--global-batch-size ${BZ} \
--lr 2e-4 \
--min-lr 0.0 \
--lr-decay-style cosine \
--log-interval ${LI} \
--eval-iters 10 \
--eval-interval 500 \
--data-path ${DATA_TRAIN} \
--prompt-path ${SOURCE}/examples/multimodal/manual_prompts.json \
--save-interval 5000 \
--save ${FINETUNE_DIR} \
--load ${FINETUNE_DIR} \
--dataloader-save ${FINETUNE_DIR}/dataloader \
--pretrained-checkpoint ${CHECKPOINT_DIR} \
--split 100,0,0 \
--clip-grad 1.0 \
--weight-decay 1e-2 \
--adam-beta1 0.9 \
--adam-beta2 0.999 \
--init-method-std 0.02 \
--log-params-norm \
--log-num-zeros-in-grad \
--bf16 \
--eod-mask-loss \
--freeze-ViT \
--freeze-LM \
--patch-dim 16 \
--img-h 512 \
--img-w 512 \
--dataloader-type external \
--tensorboard-dir ${TENSORBOARD_DIR} \
--language-model-type=llama3.1_8b \
${EXTRA_ARGS} \
--distributed-timeout-minutes 60 \
--allow-missing-vision-projection-checkpoint \
--vision-model-type radio \
--tokenizer-prompt-format llama3p1 \
--use-loss-scaling \
${SPECIAL_TOKENS} \
--ckpt-format torch \
--image-tag-type internvl \
--force-system-message \
--disable-vision-class-token \
--use-area-weighted-aspect-ratio \
--inference-max-seq-length 32768 \
"
export NVTE_APPLY_QK_LAYER_SCALING=0
export NVTE_ALLOW_NONDETERMINISTIC_ALGO=${NONDETERMINISTIC_ATTN}
# Interactive or batch mode
if [[ $BATCH -eq 0 ]]; then
torchrun --nproc_per_node ${NUM_GPU} examples/multimodal/train.py ${OPTIONS}
else
run_cmd="python -u ${SOURCE}/examples/multimodal/train.py ${OPTIONS}"
DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'`
srun -l --verbose \
--container-image <path to docker image> \
--container-mounts "<some mount>" \
--output=${LOGS_DIR}/%x_%j_$DATETIME.log \
sh -c "echo ${run_cmd}; ${run_cmd}"
set +x
fi
#!/bin/bash
# Your SBATCH commands here if using SLURM.
# Please launch this script from megatron-lm root.
# Train a multimodal model.
export CUDA_DEVICE_MAX_CONNECTIONS=1
USER=$SLURM_JOB_USER
# Auto-detect batch or interactive mode.
which srun
BATCH=$((1-$?))
DEBUG=0
USE_TILING=1
# Remember to update model and job name if running in batch mode!!
if [[ $BATCH -eq 0 ]]; then
DATETIME=`date +'%y-%m-%d-%H-%M-%S'`
MODEL_NAME="interactive_sft_llama_3p1_nemotron_nano_vl_8b_v1_${DATETIME}"
SPECIAL_TOKENS="--special-tokens <image> <img> </img> <quad> </quad> <ref> </ref> <box> </box>"
DEBUG=1
else
MODEL_NAME="sft_llama_3p1_nemotron_nano_vl_8b_v1"
SPECIAL_TOKENS="--special-tokens \<image\> \<img\> \</img\> \<quad\> \</quad\> \<ref\> \</ref\> \<box\> \</box\>"
fi
WORKSPACE="<some dir>"
SOURCE=`pwd`
OUTPUT_BASE="${WORKSPACE}/output"
OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}"
FINETUNE_DIR=${OUTPUT}/checkpoints
LOGS_DIR="${OUTPUT}/logs"
TENSORBOARD_DIR="${OUTPUT}/tensorboard"
TP=4
CHECKPOINT_DIR="${WORKSPACE}/output/${LOAD_NAME}/checkpoints/pretraining_llama_3p1_nemotron_nano_vl_8b_v1"
DATA_TRAIN="${SOURCE}/examples/multimodal/llama_3p1_nemotron_nano_vl_8b_v1/sft_blend.yaml"
SEQ_LEN=1024
DECODER_SEQ_LEN=16384
if [[ $DEBUG -eq 1 ]]; then
MBZ=1
BZ=2
NW=0
AD=0.0
HD=0.0
LI=1
EVAL_INTERVAL=1
NONDETERMINISTIC_ATTN=1
NUM_GPU=8
else
MBZ=1
BZ=128
NW=8
AD=0.0
HD=0.0
LI=5
EXTRA_ARGS=""
NONDETERMINISTIC_ATTN=1
NUM_GPU=8
EVAL_INTERVAL=2000
fi
if [[ $USE_TILING -eq 1 ]]; then
EXTRA_ARGS+=" --pixel-shuffle --use-tiling --max-num-tiles 12 --use-thumbnail"
SEQ_LEN=256
fi
OPTIONS=" \
--use-checkpoint-args \
--disable-bias-linear \
--tokenizer-type MultimodalTokenizer \
--tokenizer-model meta-llama/Llama-3.1-8B-Instruct \
--transformer-impl transformer_engine \
--normalization RMSNorm \
--group-query-attention \
--num-query-groups 8 \
--no-masked-softmax-fusion \
--attention-softmax-in-fp32 \
--attention-dropout ${AD} \
--hidden-dropout ${HD} \
--untie-embeddings-and-output-weights \
--position-embedding-type rope \
--rotary-percent 1.0 \
--rotary-base 500000 \
--use-rope-scaling \
--swiglu \
--tensor-model-parallel-size ${TP} \
--pipeline-model-parallel-size 1 \
--num-layers 32 \
--hidden-size 4096 \
--ffn-hidden-size 14336 \
--num-attention-heads 32 \
--use-distributed-optimizer \
--use-te \
--num-workers ${NW} \
--exit-duration-in-mins 230 \
--seq-length ${SEQ_LEN} \
--decoder-seq-length ${DECODER_SEQ_LEN} \
--max-position-embeddings 131072 \
--train-samples 2494236 \
--lr-warmup-fraction 0.03 \
--micro-batch-size ${MBZ} \
--global-batch-size ${BZ} \
--lr 2e-5 \
--min-lr 0.0 \
--lr-decay-style cosine \
--log-interval ${LI} \
--eval-iters 10 \
--eval-interval ${EVAL_INTERVAL} \
--data-path ${DATA_TRAIN} \
--prompt-path ${SOURCE}/examples/multimodal/manual_prompts.json \
--save-interval 2000 \
--save ${FINETUNE_DIR} \
--load ${FINETUNE_DIR} \
--pretrained-checkpoint ${CHECKPOINT_DIR} \
--dataloader-save ${FINETUNE_DIR}/dataloader \
--split 100,0,0 \
--clip-grad 1.0 \
--weight-decay 0.05 \
--adam-beta1 0.9 \
--adam-beta2 0.999 \
--init-method-std 0.014 \
--bf16 \
--eod-mask-loss \
--patch-dim 16 \
--img-h 512 \
--img-w 512 \
--dataloader-type external \
--tensorboard-dir ${TENSORBOARD_DIR} \
--language-model-type=llama3.1_8b \
${EXTRA_ARGS} \
--distributed-timeout-minutes 60 \
--vision-model-type radio \
--tokenizer-prompt-format llama3p1 \
--use-loss-scaling \
--packing-seq-length ${DECODER_SEQ_LEN} \
${SPECIAL_TOKENS} \
--ckpt-format torch \
--image-tag-type internvl \
--disable-vision-class-token \
--recompute-granularity full \
--recompute-method block \
--recompute-num-layers 32 \
--recompute-vision \
--use-area-weighted-aspect-ratio \
--inference-max-seq-length 32768 \
"
export NVTE_APPLY_QK_LAYER_SCALING=0
export NVTE_ALLOW_NONDETERMINISTIC_ALGO=${NONDETERMINISTIC_ATTN}
# Interactive or batch mode
if [[ $BATCH -eq 0 ]]; then
torchrun --nproc_per_node ${NUM_GPU} examples/multimodal/train.py ${OPTIONS}
else
run_cmd="cd ${SOURCE}; python -u examples/multimodal/train.py ${OPTIONS}"
DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'`
srun -l --verbose \
--container-image <path to docker image> \
--container-mounts "<some mount>" \
--output=${LOGS_DIR}/%x_%j_$DATETIME.log \
sh -c "${run_cmd}"
set +x
fi
#!/bin/bash
export NCCL_IB_SL=1
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NVTE_APPLY_QK_LAYER_SCALING=0
INPUT_IMAGE_PATH="placeholder"
GROUNDTRUTH_PATH="placeholder"
NUM_FRAMES=1
TP=4
OUT_SEQ_LEN=1024
INFERENCE_MAX_SEQ_LEN=8192
USE_TILING=1
MAX_NUM_TILES=12
while [[ $# -gt 0 ]]; do
case $1 in
--tensor-model-parallel-size)
TP="$2"
shift
shift
;;
--input-image-path)
INPUT_IMAGE_PATH="$2"
shift
shift
;;
--num-frames)
NUM_FRAMES="$2"
shift
shift
;;
--out-seq-length)
OUT_SEQ_LEN="$2"
shift
shift
;;
--inference-max-seq-length)
INFERENCE_MAX_SEQ_LEN="$2"
shift
shift
;;
--max-num-tiles)
MAX_NUM_TILES="$2"
shift
shift
;;
    -g|--groundtruth-path|--gt-path)
      GROUNDTRUTH_PATH="$2"
      shift
      shift
      ;;
-o|--output-path)
OUTPUT_PATH="$2"
shift
shift
;;
-m|--model-path)
MODEL_PATH="$2"
shift
shift
;;
--task)
TASK="$2"
shift
shift
;;
-*|--*)
echo "Invalid option $1"
exit 1
;;
esac
done
# Please modify these as needed.
NUM_PARTITIONS=0
START=0
END=0
SEQ_LEN=1024
DECODER_SEQ_LEN=16384
EXTRA_ARGS=""
if [[ $USE_TILING -eq 1 ]]; then
EXTRA_ARGS+=" --pixel-shuffle --use-tiling --max-num-tiles ${MAX_NUM_TILES} --use-thumbnail"
SEQ_LEN=256
fi
for PARTITION_ID in $( eval echo {$START..$END} )
do
torchrun --nproc_per_node ${TP} examples/multimodal/run_text_generation.py \
--attention-softmax-in-fp32 \
--transformer-impl transformer_engine \
--use-te \
--use-checkpoint-args \
--normalization RMSNorm \
--language-model-type=llama3.1_8b \
--untie-embeddings-and-output-weights \
--disable-bias-linear \
--position-embedding-type rope \
--rotary-percent 1.0 \
--rotary-base 500000 \
--use-rope-scaling \
--swiglu \
--attention-dropout 0.0 \
--hidden-dropout 0.0 \
--tensor-model-parallel-size ${TP} \
--pipeline-model-parallel-size 1 \
--group-query-attention \
--num-query-groups 8 \
--num-layers 32 \
--hidden-size 4096 \
--ffn-hidden-size 14336 \
--num-attention-heads 32 \
--max-position-embeddings 131072 \
--no-masked-softmax-fusion \
--load ${MODEL_PATH} \
--tokenizer-type MultimodalTokenizer \
--tokenizer-model /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/mcore_mmodal_models/models--meta-llama--Meta-Llama-3.1-8B-Instruct/snapshots/5206a32e0bd3067aef1ce90f5528ade7d866253f/ \
--tokenizer-prompt-format llama3p1 \
--bf16 \
--micro-batch-size 1 \
--seq-length ${SEQ_LEN} \
--decoder-seq-length ${DECODER_SEQ_LEN} \
--out-seq-length ${OUT_SEQ_LEN} \
--inference-max-seq-length ${INFERENCE_MAX_SEQ_LEN} \
--temperature 1.0 \
--img-h 512 \
--img-w 512 \
--patch-dim 16 \
--seed 153 \
--top_k 1 \
--no-load-rng \
--no-load-optim \
--input-image-path ${INPUT_IMAGE_PATH} \
--num-partitions ${NUM_PARTITIONS} \
--partition-id ${PARTITION_ID} \
--output-path ${OUTPUT_PATH} \
--gt-path ${GROUNDTRUTH_PATH} \
--task ${TASK} \
${EXTRA_ARGS} \
--vision-model-type radio \
--num-frames ${NUM_FRAMES} \
--special-tokens "<image>" "<img>" "</img>" "<quad>" "</quad>" "<ref>" "</ref>" "<box>" "</box>" \
--ckpt-format torch \
--image-tag-type internvl \
--disable-vision-class-token \
--force-system-message \
--exit-on-missing-checkpoint
done
{
"COMMENT": "Sources for these prompts include https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/viewer and https://huggingface.co/datasets/HuggingFaceM4/M3IT",
"Captioning": {
"raw": [
"Can you briefly explain what you see in the image?",
"Describe what's happening in this image in one short sentence.",
"Write a short caption that accurately represents the content of this image.",
"Please generate a descriptive caption for the image provided.",
"How would you summarize the scene depicted in the picture in short?",
"Describe the image briefly.",
"Write a succinct description of the image, capturing its main components, the relationships between them, and any notable details.",
"Create a concise caption that accurately describes the main elements in the image provided.",
"Write a brief, yet comprehensive, description of the image.",
"Describe the image in a clear and concise manner.",
"For the given image, provide a one-sentence summary that captures the most important details.",
"Generate a short caption for the picture.",
"Write a short and informative description that highlights the primary subjects and actions occurring in the given image.",
"Provide a concise and informative caption for the image, focusing on the primary subjects.",
"Write a clear description of the image, make sure the key features are well covered.",
"Offer a succinct explanation of the picture presented."
]
},
"CaptioningPretraining": {
"raw": [
"Generate a short caption of the image.",
"Describe the image concisely.",
"Provide a brief description of the given image."
],
"llava": [
"Give a brief description of image.",
"Give a brief description of the image.",
"Provide a brief description of the given image.",
"Provide a one-sentence caption for the provided image.",
"Write a terse but informative summary of the picture.",
"Describe the image concisely.",
"Generate a clear and concise summary of the photo."
]
},
"OCR": {
"raw": [
"Can you read the text from image and output here?",
"Extract and document the text from the provided image.",
"Converting the text embedded in this image into a readable document.",
"Transcribe all the text you find.",
"Can you extract all visible text from the image here?"
]
}
}
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import warnings
import logging
from copy import deepcopy
import torch
from config import get_language_model_config, get_vision_model_config, get_vision_projection_config
from layer_specs import (get_layer_spec, get_layer_spec_te, get_mlp_module_spec, get_norm_mlp_module_spec_te,
get_mamba_layer_spec_te)
from megatron.core.models.multimodal.llava_model import IMAGE_TOKEN, LLaVAModel
from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings
from megatron.training import get_args, get_tokenizer, print_rank_0
from megatron.training.arguments import core_transformer_config_from_args
from megatron.core.utils import log_single_rank
def model_provider(
pre_process=True, post_process=True, add_encoder=True, add_decoder=True, parallel_output=True
) -> LLaVAModel:
"""Builds the model.
Args:
pre_process (bool): Include the embedding layer in the gpt decoder (used with pipeline parallelism). Defaults to True.
post_process (bool): Include an output layer and a layernorm in the gpt decoder (used with pipeline parallelism). Defaults to True.
add_encoder (bool): Construct the encoder module (used with pipeline parallelism). Defaults to True. When we use pipelining, the encoder
will live on only a subset of the pipeline stages (specifically, only the first stage).
add_decoder (bool): Construct the decoder module (used with pipeline parallelism). Defaults to True. When we use pipelining, the decoder
will live on only a subset of the pipeline stages (specifically, every stage after the first one).
parallel_output (bool): Enable parallel model output.
Returns:
model: A multimodal model.
"""
args = get_args()
# Deprecation warning for encoder pipeline parallelism
if args.encoder_pipeline_model_parallel_size > 0 or args.encoder_tensor_model_parallel_size > 0:
warnings.warn(
"Encoder-specific pipeline parallelism functionality is deprecated and will be removed in core_r0.14.0. "
"This includes the parameters 'encoder_tensor_model_parallel_size' and 'encoder_pipeline_model_parallel_size', "
"as well as all associated encoder pipeline parallel logic and infrastructure. "
"This functionality is being replaced by the new 'orthotope' parallelism management system, which provides "
"a more general and flexible approach to handling complex parallelism configurations including encoder-decoder models. "
"Please refrain from building new features or dependencies on encoder pipeline parallelism as this entire "
"capability will not be supported in future releases. For migration guidance and information on the orthotope "
"system, please refer to the Megatron-LM documentation.",
DeprecationWarning,
stacklevel=2
)
    assert args.encoder_pipeline_model_parallel_size <= 1, "LLaVA does not support pp>1 for the encoder on its own pipeline rank"
use_te = args.use_te
print_rank_0('building a multimodal model ...')
num_image_embeddings = get_num_image_embeddings(
args.img_h,
args.img_w,
args.patch_dim,
args.vision_model_type,
args.disable_vision_class_token,
1,
args.pixel_shuffle,
args.use_tile_tags,
args.max_num_tiles,
args.tokenizer_prompt_format
)
old_seq_length = args.seq_length
args.seq_length = args.encoder_seq_length = num_image_embeddings
if old_seq_length != args.seq_length:
log_single_rank(
logging.getLogger(__name__),
logging.WARNING,
f"Changed seq_length and encoder_seq_length (vision model sequence length) from {old_seq_length} to num_image_tokens ({num_image_embeddings})"
)
max_num_image_embeddings = max((args.max_num_tiles + int(args.use_thumbnail)), args.num_frames) * num_image_embeddings
assert (
args.decoder_seq_length is not None
), "Please provide --decoder-seq-length to set the language model sequence length"
assert (
args.decoder_seq_length > max_num_image_embeddings
), "Language model sequence length must be greater than the maximum number of image embeddings"
if args.decoder_seq_length > args.max_position_embeddings:
args.max_position_embeddings = args.decoder_seq_length
warnings.warn(
f"Expanded max_position_embeddings to {args.max_position_embeddings} to accommodate the maximum language model sequence length"
)
language_model_type = args.language_model_type
vision_model_type = args.vision_model_type
base_config = core_transformer_config_from_args(get_args())
base_config.language_model_type = args.language_model_type
base_config.vision_model_type = args.vision_model_type
base_config.calculate_per_token_loss = True
language_config = deepcopy(base_config)
language_config = get_language_model_config(language_config)
if language_model_type.startswith("hf://"):
assert args.tensor_model_parallel_size == 1, "Huggingface models do not support --tensor-model-parallel-size > 1"
assert args.pipeline_model_parallel_size < 2, "Huggingface models do not support --pipeline-model-parallel-size > 1"
assert not args.sequence_parallel, "Huggingface models do not support --sequence-parallel"
assert args.context_parallel_size < 2, "Huggingface models do not support --context-parallel-size > 1"
if language_model_type.startswith("hf://"):
language_transformer_layer_spec = None
elif use_te:
# Padding mask needed for SP/CP.
padding = args.context_parallel_size > 1 and args.sequence_parallel
if args.language_model_type.startswith('nemotron5-hybrid'):
language_transformer_layer_spec = get_mamba_layer_spec_te(padding=padding)
else:
language_transformer_layer_spec = get_layer_spec_te(
is_vit=False, padding=padding
) # TENorm detects LayerNorm/RMS automatically.
else:
language_transformer_layer_spec = get_layer_spec(
is_vit=False, normalization=language_config.normalization
)
vision_config = deepcopy(base_config)
vision_config = get_vision_model_config(
vision_config, apply_query_key_layer_scaling=args.apply_query_key_layer_scaling
)
if vision_model_type.startswith("hf://"):
assert args.encoder_tensor_model_parallel_size < 2, "Huggingface vision encoders do not support --encoder-tensor-model-parallel-size > 1"
assert args.encoder_pipeline_model_parallel_size == 0, "Huggingface vision encoders do not support --encoder-pipeline-model-parallel-size > 0"
assert not args.sequence_parallel, "Huggingface models do not support --sequence-parallel"
assert args.context_parallel_size < 2, "Huggingface models do not support --context-parallel-size > 1"
if vision_model_type in ["clip", "siglip", "radio", "cradio-g"]:
if use_te:
vision_transformer_layer_spec = get_layer_spec_te(
is_vit=True
) # TENorm detects LayerNorm/RMS automatically.
else:
vision_transformer_layer_spec = get_layer_spec(
is_vit=True, normalization=vision_config.normalization
)
elif vision_model_type == "radio-g":
if use_te:
from radio.radio_g import get_radio_g_layer_spec_te
vision_transformer_layer_spec = get_radio_g_layer_spec_te() # TENorm detects LayerNorm/RMS automatically.
else:
from radio.radio_g import get_radio_g_layer_spec
vision_transformer_layer_spec = get_radio_g_layer_spec(
normalization=vision_config.normalization
)
elif vision_model_type == "internvit":
from nvlm.internvit import get_internvit_layer_spec
vision_transformer_layer_spec = get_internvit_layer_spec(use_te=use_te)
elif vision_model_type == "internvit300M":
from nvlm.internvit import get_internvit300M_layer_spec
vision_transformer_layer_spec = get_internvit300M_layer_spec(use_te=use_te)
elif vision_model_type.startswith("hf://"):
vision_transformer_layer_spec = None
else:
raise RuntimeError("unsupported vision model type", vision_model_type)
vision_projection_config = deepcopy(base_config)
vision_projection_config = get_vision_projection_config(
vision_projection_config, language_config.hidden_size
)
# --encoder-pipeline-model-parallel-size 1 will enable a separate pipeline stage for the vision model.
if args.encoder_pipeline_model_parallel_size > 0:
assert (
args.encoder_pipeline_model_parallel_size == 1
), "vision model and projection can only live on 1 pipeline stage."
if args.encoder_tensor_model_parallel_size > 0:
vision_config.tensor_model_parallel_size = args.encoder_tensor_model_parallel_size
vision_projection_config.tensor_model_parallel_size = (
args.encoder_tensor_model_parallel_size
)
# Make sure vision model pipeline parallel size is not inherited from the language model pipeline parallel size.
    # 0 is not a valid value for this config field, hence max(1, ...).
vision_config.pipeline_model_parallel_size = max(1, args.encoder_pipeline_model_parallel_size)
vision_projection_config.pipeline_model_parallel_size = vision_config.pipeline_model_parallel_size
# Make sure the vision model does not inherit first and last pipeline num layers from the language model.
vision_config.first_pipeline_num_layers = vision_config.last_pipeline_num_layers = None
if vision_projection_config.normalization:
vision_projection_layer_spec = get_norm_mlp_module_spec_te().submodules
else:
vision_projection_layer_spec = get_mlp_module_spec(use_te=use_te).submodules
# Toggle --recompute* for the vision and language model separately.
if args.recompute_vision:
if vision_config.recompute_method is not None and vision_config.recompute_granularity is not None:
vision_config.recompute_num_layers = vision_config.num_layers
else:
vision_config.recompute_granularity = None
vision_config.recompute_method = None
vision_config.recompute_num_layers = None
vision_projection_config.recompute_granularity = None
vision_projection_config.recompute_method = None
vision_projection_config.recompute_num_layers = None
# TODO: Vision model and projection do not use SP/CP yet.
vision_config.sequence_parallel = False
vision_config.context_parallel_size = 1
vision_config.tp_comm_overlap = False
vision_projection_config.sequence_parallel = False
vision_projection_config.context_parallel_size = 1
vision_projection_config.tp_comm_overlap = False
tokenizer = get_tokenizer()
image_token_index = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
assert image_token_index is not None, f"IMAGE_TOKEN={IMAGE_TOKEN} needs to be added using the --special-tokens arg."
tile_tags = _get_tile_tags(args, tokenizer)
model = LLaVAModel(
language_transformer_config=language_config,
language_transformer_layer_spec=language_transformer_layer_spec,
language_vocab_size=args.padded_vocab_size,
language_max_sequence_length=args.decoder_seq_length,
vision_transformer_config=vision_config,
vision_transformer_layer_spec=vision_transformer_layer_spec,
drop_vision_class_token=args.disable_vision_class_token,
vision_projection_config=vision_projection_config,
vision_projection_layer_spec=vision_projection_layer_spec,
vision_projection_type="mlp",
allow_missing_vision_projection_checkpoint=args.allow_missing_vision_projection_checkpoint,
parallel_output=parallel_output,
share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
language_position_embedding_type=args.position_embedding_type,
language_rotary_percent=args.rotary_percent,
pre_process=pre_process,
post_process=post_process,
add_encoder=add_encoder,
add_decoder=add_decoder,
img_h=args.img_h,
img_w=args.img_w,
patch_dim=args.patch_dim,
language_rotary_base=args.rotary_base,
language_rope_scaling=args.use_rope_scaling,
hybrid_attention_ratio=args.hybrid_attention_ratio,
hybrid_mlp_ratio=args.hybrid_mlp_ratio,
hybrid_override_pattern=args.hybrid_override_pattern,
fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
image_token_index=image_token_index,
pixel_shuffle=args.pixel_shuffle,
tile_tags=tile_tags,
max_num_tiles=args.max_num_tiles,
tokenizer_type=args.tokenizer_prompt_format,
)
model.freeze(
freeze_language_model=args.freeze_LM,
freeze_vision_model=args.freeze_ViT,
freeze_vision_projection=False,
)
return model
def _get_tile_tags(args, tokenizer):
"""Tile tags are used in NVLM to surround image tiles with text tags."""
if not args.use_tile_tags:
return None
    # We expect all tile tags to have the same tokenized length.
if args.max_num_tiles < 10:
thumbnail_tag_text = "<tile_global_thumbnail>"
if args.tokenizer_prompt_format == "nvlm-yi-34b":
thumbnail_tag_text = "<tile_global>"
if args.tokenizer_prompt_format.startswith("nemotron"):
tile_tags_text = [f"<tile_{i:02d}>" for i in range(1, args.max_num_tiles + 1)] + [thumbnail_tag_text]
else:
tile_tags_text = [f"<tile_{i}>" for i in range(1, args.max_num_tiles + 1)] + [thumbnail_tag_text]
elif args.max_num_tiles <= 12:
thumbnail_tag_text = "<tile_global_thumbnail0>"
if args.tokenizer_prompt_format == "nvlm-yi-34b":
thumbnail_tag_text = "<tile_global0>"
elif args.tokenizer_prompt_format.startswith("nemotron") or args.tokenizer_prompt_format == "llama3p1":
thumbnail_tag_text = "<tile_global_thumbnail>"
tile_tags_text = [f"<tile_{i:02d}>" for i in range(1, args.max_num_tiles + 1)] + [thumbnail_tag_text]
else:
raise ValueError("We only support max_num_tiles <= 12 when using nvlm image_tag_type")
start_idx = 0
if tokenizer._prompt_config.has_bos:
start_idx = 1
# Convert to tokens [num_tiles, tile_seq_len].
tile_tags = [tokenizer.tokenize(t)[start_idx:] for t in tile_tags_text]
return tile_tags
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import argparse
import os
import torch
import clip
def convert(download_root, output_path, tensor_parallel_size, use_te):
device = "cuda"
model, _ = clip.load("ViT-L/14@336px", device=device, download_root=download_root)
state_dict = model.state_dict()
new_state_dicts = [{"model": dict()} for _ in range(tensor_parallel_size)]
# Indices from mapping pytorch multihead attention to megatron.
kv_channels = 64
hidden_dim = 1024
num_heads = 16
indices = []
for i in range(num_heads):
lb = i * kv_channels
ub = (i + 1) * kv_channels
indices.append(torch.arange(lb, ub, dtype=torch.int))
indices.append(torch.arange(hidden_dim + lb, hidden_dim + ub, dtype=torch.int))
indices.append(torch.arange(2 * hidden_dim + lb, 2 * hidden_dim + ub, dtype=torch.int))
indices = torch.cat(indices)
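    # Indexing with "indices" regroups PyTorch's stacked [all-Q, all-K, all-V] in_proj rows into
    # Megatron's per-head-interleaved [q_0, k_0, v_0, q_1, k_1, v_1, ...] qkv layout.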
for name, tensor in state_dict.items():
# Skip text model.
if "visual" not in name:
continue
# Skip final layers not used in our model.
if name == "visual.proj" or "ln_post" in name:
continue
# Map parameter names to ones used in megatron.
new_name = ""
new_tensor = tensor
if new_tensor.dtype == torch.float16:
new_tensor = new_tensor.to(torch.float32)
# This is used for chunking some tensors to target tensor parallel size.
chunk_dim = None
if "class_embedding" in name:
new_name = "class_token"
# Our model uses class token that is expanded to input dimensions already.
new_tensor = new_tensor.expand(1, 1, -1)
elif "positional_embedding" in name:
new_name = "position_embeddings.weight"
elif "conv1" in name:
new_name = "conv1.weight"
elif "ln_pre.weight" in name:
new_name = "ln_pre.weight"
elif "ln_pre.bias" in name:
new_name = "ln_pre.bias"
elif "transformer.resblocks" in name:
layer_idx = name.split(".")[3]
base = f"decoder.layers.{layer_idx}"
if "attn.in_proj_weight" in name:
new_name = f"{base}.self_attention.linear_qkv.weight"
new_tensor = new_tensor[indices]
chunk_dim = 0
elif "attn.in_proj_bias" in name:
new_name = f"{base}.self_attention.linear_qkv.bias"
new_tensor = new_tensor[indices]
chunk_dim = 0
elif "attn.out_proj.weight" in name:
new_name = f"{base}.self_attention.linear_proj.weight"
chunk_dim = 1
elif "attn.out_proj.bias" in name:
new_name = f"{base}.self_attention.linear_proj.bias"
elif "ln_1.weight" in name:
new_name = f"{base}.input_layernorm.weight"
if use_te:
new_name = f"{base}.self_attention.linear_qkv.layer_norm_weight"
elif "ln_1.bias" in name:
new_name = f"{base}.input_layernorm.bias"
if use_te:
new_name = f"{base}.self_attention.linear_qkv.layer_norm_bias"
elif "mlp.c_fc.weight" in name:
new_name = f"{base}.mlp.linear_fc1.weight"
chunk_dim = 0
elif "mlp.c_fc.bias" in name:
new_name = f"{base}.mlp.linear_fc1.bias"
chunk_dim = 0
elif "mlp.c_proj.weight" in name:
new_name = f"{base}.mlp.linear_fc2.weight"
chunk_dim = 1
elif "mlp.c_proj.bias" in name:
new_name = f"{base}.mlp.linear_fc2.bias"
elif "ln_2.weight" in name:
new_name = f"{base}.pre_mlp_layernorm.weight"
if use_te:
new_name = f"{base}.mlp.linear_fc1.layer_norm_weight"
elif "ln_2.bias" in name:
new_name = f"{base}.pre_mlp_layernorm.bias"
if use_te:
new_name = f"{base}.mlp.linear_fc1.layer_norm_bias"
assert new_name != "", f"unexpected layer name {name}"
if chunk_dim is None:
new_tensors = [new_tensor for _ in range(tensor_parallel_size)]
else:
new_tensors = torch.chunk(new_tensor, tensor_parallel_size, dim=chunk_dim)
for i in range(tensor_parallel_size):
# chunk() creates a view of a bigger tensor. clone() is used here to avoid excessive storage.
new_state_dicts[i]["model"][new_name] = new_tensors[i].clone()
# TE sets _extra_state (for FP8 purposes), so set an empty one here for compatibility.
extra_state_layers = ("linear_qkv", "linear_proj", "linear_fc1", "linear_fc2")
is_extra_state_layer = any([l in new_name for l in extra_state_layers])
if use_te and is_extra_state_layer:
layer = new_name.split(".")[-2]
if layer in extra_state_layers:
extra_state_name = (
new_name[: new_name.rfind(".") + 1] + "_extra_state"
) # Replace the weight name.
new_state_dicts[i]["model"][extra_state_name] = None
for i in range(tensor_parallel_size):
output_dir_tp = os.path.join(output_path, "iter_0000001", f"mp_rank_0{i}")
os.makedirs(output_dir_tp)
output_path_tp = os.path.join(output_dir_tp, "model_optim_rng.pt")
torch.save(new_state_dicts[i], output_path_tp)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="""
Convert OpenAI CLIP ViT weights to megatron format.
Example usage:
python clip_converter.py --download-root /some/download/folder --output /some/output/folder --tensor-parallel-size 4
""",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--download-root", type=str, required=True, help="Download folder for OpenAI CLIP weights"
)
parser.add_argument(
"--output", type=str, required=True, help="output directory for megatron state dict file(s)"
)
parser.add_argument(
"--tensor-parallel-size", type=int, default=1, help="model tensor parallel size"
)
parser.add_argument("--use-te", action="store_true", help="Use Transformer Engine")
args = parser.parse_args()
convert(args.download_root, args.output, args.tensor_parallel_size, args.use_te)
print("done.")
import argparse
import os
import torch
from transformers import AutoModel
def convert(model_name, output_path, tensor_parallel_size, use_te):
"""Convert InternViT HF checkpoint to mcore."""
hf_model = AutoModel.from_pretrained(
model_name,
trust_remote_code=True
)
hf_state_dict = hf_model.state_dict()
new_state_dicts = [{"model": dict()} for _ in range(tensor_parallel_size)]
hidden_size = 3200
num_heads = 25
dim = 128
order = torch.ones(3 * hidden_size).long()
for j in range(num_heads):
for i in range(dim):
order[i + dim*3*j] = j*dim+i
order[dim + i + dim*3*j] = j*dim+i+num_heads*dim
order[dim*2 + i + dim*3*j] = j*dim+i+num_heads*dim*2
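# Illustrative sketch (toy numbers): for num_heads=2 and dim=2 the HF qkv weight
# rows are laid out as [Q0 Q1 | K0 K1 | V0 V1] = rows 0-3, 4-7, 8-11, and `order`
# rearranges them into Megatron's per-head layout [Q0 K0 V0 | Q1 K1 V1], i.e.
# order = [0, 1, 4, 5, 8, 9, 2, 3, 6, 7, 10, 11].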
for name, tensor in hf_state_dict.items():
# Map parameter names to ones used in megatron.
new_name = ""
new_tensor = tensor
# This is used for chunking some tensors to target tensor parallel size.
chunk_dim = None
if "embeddings.class_embedding" in name:
new_name = "class_token"
elif "embeddings.patch_embedding.weight" in name:
new_name = "conv1.weight"
elif "embeddings.patch_embedding.bias" in name:
new_name = "conv1.bias"
elif "embeddings.position_embedding" in name:
new_name = "position_embeddings.weight"
new_tensor = new_tensor.squeeze(0)
elif "encoder.layers" in name:
layer_idx = name.split(".")[2]
base = f"decoder.layers.{layer_idx}"
head_dim = 128
if tensor_parallel_size == 1:
num_padded_heads = 25
elif tensor_parallel_size == 8:
# Note: 25 heads are not divisible by 8 and we don't currently support an uneven head split with tensor parallelism.
# So we pad with dummy all-zero heads. Please use an even number of attention heads in your model.
num_padded_heads = 32
else:
raise NotImplementedError(f"invalid tensor parallel size value: {tensor_parallel_size}")
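# (With head_dim=128, padding 25 heads to 32 gives 3*128*32 = 12288 linear_qkv
# rows, so each of the 8 tensor-parallel ranks holds exactly 4 heads, some of
# them zero-padded dummies.)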
if "ls1" in name:
new_name = f"{base}.ls1"
elif "ls2" in name:
new_name = f"{base}.ls2"
elif "attn.qkv.weight" in name:
new_name = f"{base}.self_attention.linear_qkv.weight"
num_tensors = 3
padded_dim = head_dim * num_padded_heads * num_tensors
padded_tensor = torch.zeros((padded_dim, new_tensor.shape[-1]), dtype=new_tensor.dtype, device=new_tensor.device)
padded_tensor[:new_tensor.shape[0], :] = new_tensor[order]
new_tensor = padded_tensor
chunk_dim = 0
elif "attn.q_norm.weight" in name:
new_name = f"{base}.self_attention.q_layernorm.weight"
num_tensors = 1
padded_dim = head_dim * num_padded_heads * num_tensors
padded_tensor = torch.zeros(padded_dim, dtype=new_tensor.dtype, device=new_tensor.device)
padded_tensor[:new_tensor.shape[0]] = new_tensor
new_tensor = padded_tensor
chunk_dim = 0
elif "attn.k_norm.weight" in name:
new_name = f"{base}.self_attention.k_layernorm.weight"
num_tensors = 1
padded_dim = head_dim * num_padded_heads * num_tensors
padded_tensor = torch.zeros(padded_dim, dtype=new_tensor.dtype, device=new_tensor.device)
padded_tensor[:new_tensor.shape[0]] = new_tensor
new_tensor = padded_tensor
chunk_dim = 0
elif "attn.proj.weight" in name:
new_name = f"{base}.self_attention.linear_proj.weight"
num_tensors = 1
padded_dim = head_dim * num_padded_heads * num_tensors
padded_tensor = torch.zeros((new_tensor.shape[0], padded_dim), dtype=new_tensor.dtype, device=new_tensor.device)
padded_tensor[:, :new_tensor.shape[-1]] = new_tensor
new_tensor = padded_tensor
chunk_dim = 1
elif "attn.proj.bias" in name:
new_name = f"{base}.self_attention.linear_proj.bias"
elif "mlp.fc1.weight" in name:
new_name = f"{base}.mlp.linear_fc1.weight"
chunk_dim = 0
elif "mlp.fc1.bias" in name:
new_name = f"{base}.mlp.linear_fc1.bias"
chunk_dim = 0
elif "mlp.fc2.weight" in name:
new_name = f"{base}.mlp.linear_fc2.weight"
chunk_dim = 1
elif "mlp.fc2.bias" in name:
new_name = f"{base}.mlp.linear_fc2.bias"
elif "norm1" in name:
new_name = f"{base}.input_layernorm.weight"
elif "norm2" in name:
new_name = f"{base}.pre_mlp_layernorm.weight"
else:
raise RuntimeError("unexpected transformer layer name", name)
else:
raise RuntimeError("unexpected layer name", name)
assert new_name != "", f"unexpected layer name {name}"
# TE sets _extra_state (for FP8 purposes), so set an empty one here for compatibility.
extra_state_layers = ("linear_qkv", "linear_proj", "linear_fc1", "linear_fc2")
is_extra_state_layer = any([l in new_name for l in extra_state_layers])
if use_te and is_extra_state_layer:
layer = new_name.split(".")[-2]
if layer in extra_state_layers:
extra_state_name = (
new_name[: new_name.rfind(".") + 1] + "_extra_state"
) # Replace the weight name.
for i in range(tensor_parallel_size):
new_state_dicts[i]["model"][extra_state_name] = None
if chunk_dim is None:
new_tensors = [new_tensor for _ in range(tensor_parallel_size)]
else:
new_tensors = torch.chunk(new_tensor, tensor_parallel_size, dim=chunk_dim)
for i in range(tensor_parallel_size):
new_state_dicts[i]["model"][new_name] = new_tensors[i].clone()
for i in range(tensor_parallel_size):
output_dir_tp = os.path.join(output_path, f"iter_0000001/mp_rank_0{i}")
os.makedirs(output_dir_tp, exist_ok=True)
output_path_tp = os.path.join(output_dir_tp, "model_optim_rng.pt")
torch.save(new_state_dicts[i], output_path_tp)
print("saved file", output_path_tp)
print("done")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="InternVIT HuggingFace to Mcore converter")
parser.add_argument("--model-name", type=str, default="OpenGVLab/InternViT-6B-448px-V1-5", help="Model name in HuggingFace")
parser.add_argument("--output-dir", type=str, required=True, help="Output directory for the mcore model.")
parser.add_argument("--use-te", action="store_true", default=True)
parser.add_argument("--tensor-parallel-size", type=int, required=True)
args = parser.parse_args()
convert(args.model_name, args.output_dir, args.tensor_parallel_size, args.use_te)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import argparse
import os
import torch
def convert_radio_h(output_path, tensor_parallel_size, use_te, version):
device = "cuda"
version = version if version is not None else 'radio_v2.5-h'
model = torch.hub.load('NVlabs/RADIO', 'radio_model', version=version, progress=True)
state_dict = model.state_dict()
new_state_dicts = [{"model": dict()} for _ in range(tensor_parallel_size)]
# Indices from mapping pytorch multihead attention to megatron.
kv_channels = 80
hidden_dim = 1280
num_heads = 16
indices = []
for i in range(num_heads):
lb = i * kv_channels
ub = (i + 1) * kv_channels
indices.append(torch.arange(lb, ub, dtype=torch.int))
indices.append(torch.arange(hidden_dim + lb, hidden_dim + ub, dtype=torch.int))
indices.append(torch.arange(2 * hidden_dim + lb, 2 * hidden_dim + ub, dtype=torch.int))
indices = torch.cat(indices)
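# Same per-head [Q, K, V] interleaving as in the CLIP converter above, here with
# 80-dim heads and a 1280-wide hidden size.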
for name, tensor in state_dict.items():
# Map parameter names to ones used in megatron.
new_name = ""
new_tensor = tensor
if new_tensor.dtype == torch.float16:
new_tensor = new_tensor.to(torch.float32)
# This is used for chunking some tensors to target tensor parallel size.
chunk_dim = None
if "summary_idxs" in name:
continue
elif "patch_generator" in name:
if "embedder" in name:
new_name = "embedder.weight"
chunk_dim = 0
elif "cls_token" in name:
new_name = "class_token"
elif "pos_embed" in name:
new_name = "position_embeddings"
elif "input_conditioner" in name:
continue
elif "blocks" in name:
layer_idx = name.split(".")[2]
base = f"decoder.layers.{layer_idx}"
if "attn.qkv.weight" in name:
new_name = f"{base}.self_attention.linear_qkv.weight"
new_tensor = new_tensor[indices]
chunk_dim = 0
elif "attn.qkv.bias" in name:
new_name = f"{base}.self_attention.linear_qkv.bias"
new_tensor = new_tensor[indices]
chunk_dim = 0
elif "attn.proj.weight" in name:
new_name = f"{base}.self_attention.linear_proj.weight"
chunk_dim = 1
elif "attn.proj.bias" in name:
new_name = f"{base}.self_attention.linear_proj.bias"
elif "norm1.weight" in name:
new_name = f"{base}.input_layernorm.weight"
if use_te:
new_name = f"{base}.self_attention.linear_qkv.layer_norm_weight"
elif "norm1.bias" in name:
new_name = f"{base}.input_layernorm.bias"
if use_te:
new_name = f"{base}.self_attention.linear_qkv.layer_norm_bias"
elif "mlp.fc1.weight" in name:
new_name = f"{base}.mlp.linear_fc1.weight"
chunk_dim = 0
elif "mlp.fc1.bias" in name:
new_name = f"{base}.mlp.linear_fc1.bias"
chunk_dim = 0
elif "mlp.fc2.weight" in name:
new_name = f"{base}.mlp.linear_fc2.weight"
chunk_dim = 1
elif "mlp.fc2.bias" in name:
new_name = f"{base}.mlp.linear_fc2.bias"
elif "norm2.weight" in name:
new_name = f"{base}.pre_mlp_layernorm.weight"
if use_te:
new_name = f"{base}.mlp.linear_fc1.layer_norm_weight"
elif "norm2.bias" in name:
new_name = f"{base}.pre_mlp_layernorm.bias"
if use_te:
new_name = f"{base}.mlp.linear_fc1.layer_norm_bias"
assert new_name != "", f"unexpected layer name {name}"
if chunk_dim is None:
new_tensors = [new_tensor for _ in range(tensor_parallel_size)]
else:
new_tensors = torch.chunk(new_tensor, tensor_parallel_size, dim=chunk_dim)
for i in range(tensor_parallel_size):
# chunk() creates a view of a bigger tensor. clone() is used here to avoid excessive storage.
new_state_dicts[i]["model"][new_name] = new_tensors[i].clone()
# TE sets _extra_state (for FP8 purposes), so set an empty one here for compatibility.
extra_state_layers = ("linear_qkv", "linear_proj", "linear_fc1", "linear_fc2")
is_extra_state_layer = any([l in new_name for l in extra_state_layers])
if use_te and is_extra_state_layer:
layer = new_name.split(".")[-2]
if layer in extra_state_layers:
extra_state_name = (
new_name[: new_name.rfind(".") + 1] + "_extra_state"
) # Replace the weight name.
new_state_dicts[i]["model"][extra_state_name] = None
for i in range(tensor_parallel_size):
output_dir_tp = os.path.join(output_path, "iter_0000001", f"mp_rank_0{i}")
os.makedirs(output_dir_tp)
output_path_tp = os.path.join(output_dir_tp, "model_optim_rng.pt")
torch.save(new_state_dicts[i], output_path_tp)
with open(os.path.join(output_path, "latest_checkpointed_iteration.txt"), "w") as f:
f.write("1")
def convert_radio_g(output_path, tensor_parallel_size, use_te, version):
device = "cuda"
version = version if version is not None else 'radio_v2.5-g'
model = torch.hub.load('NVlabs/RADIO', 'radio_model', version=version, progress=True)
state_dict = model.state_dict()
new_state_dicts = [{"model": dict()} for _ in range(tensor_parallel_size)]
# Indices from mapping pytorch multihead attention to megatron.
kv_channels = 64
hidden_dim = 1536
num_heads = 24
ffn_hidden_dim = 4096
indices = []
for i in range(num_heads):
lb = i * kv_channels
ub = (i + 1) * kv_channels
indices.append(torch.arange(lb, ub, dtype=torch.int))
indices.append(torch.arange(hidden_dim + lb, hidden_dim + ub, dtype=torch.int))
indices.append(torch.arange(2 * hidden_dim + lb, 2 * hidden_dim + ub, dtype=torch.int))
indices = torch.cat(indices)
mlp_indices = []
step = ffn_hidden_dim // tensor_parallel_size
for i in range(tensor_parallel_size):
mlp_indices.append(torch.arange(i * step, (i + 1) * step, dtype=torch.int))
mlp_indices.append(torch.arange(ffn_hidden_dim + i * step, ffn_hidden_dim + (i + 1) * step, dtype=torch.int))
mlp_indices = torch.cat(mlp_indices)
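# The w12 weight below is the fused gate/up projection mapped to linear_fc1 (a
# gated, SwiGLU-style MLP is assumed here). Illustrative sketch with toy numbers:
# for ffn_hidden_dim=4 and tensor_parallel_size=2, w12 rows are [gate 0-3 | up 4-7]
# and mlp_indices = [0, 1, 4, 5, 2, 3, 6, 7], so after torch.chunk along dim 0
# each rank holds its matching gate and up slices.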
for name, tensor in state_dict.items():
# Map parameter names to ones used in megatron.
new_names = []
new_tensor = tensor
if new_tensor.dtype == torch.float16:
new_tensor = new_tensor.to(torch.float32)
new_tensors = [new_tensor]
# This is used for chunking some tensors to target tensor parallel size.
chunk_dim = None
if "model" not in name:
continue
elif "patch_generator" in name:
if "embedder.weight" in name:
new_names.append("embedder.weight")
chunk_dim = 0
elif "embedder.bias" in name:
new_names.append("embedder.bias")
chunk_dim = 0
elif "cls_token" in name:
new_names.append("class_token")
elif "pos_embed" in name:
new_names.append("position_embeddings")
elif "input_conditioner" in name:
continue
elif "mask_token" in name:
new_names.append("mask_token")
elif "inner.norm" in name:
if "norm.weight" in name:
new_names.append("ln_post.weight")
elif "norm.bias" in name:
new_names.append("ln_post.bias")
elif "blocks" in name:
layer_idx = name.split(".")[3]
base = f"decoder.layers.{layer_idx}"
if "attn.qkv.weight" in name:
new_names.append(f"{base}.self_attention.linear_qkv.weight")
new_tensors[0] = new_tensors[0][indices]
chunk_dim = 0
elif "attn.qkv.bias" in name:
new_names.append(f"{base}.self_attention.linear_qkv.bias")
new_tensors[0] = new_tensors[0][indices]
chunk_dim = 0
elif "attn.proj.weight" in name:
new_names.append(f"{base}.self_attention.linear_proj.weight")
chunk_dim = 1
elif "attn.proj.bias" in name:
new_names.append(f"{base}.self_attention.linear_proj.bias")
elif "norm1.weight" in name:
new_name = f"{base}.input_layernorm.weight"
if use_te:
new_name = f"{base}.self_attention.linear_qkv.layer_norm_weight"
new_names.append(new_name)
elif "norm1.bias" in name:
new_name = f"{base}.input_layernorm.bias"
if use_te:
new_name = f"{base}.self_attention.linear_qkv.layer_norm_bias"
new_names.append(new_name)
elif "mlp.w12.weight" in name:
new_names.append(f"{base}.mlp.linear_fc1.weight")
new_tensors[0] = new_tensors[0][mlp_indices]
chunk_dim = 0
elif "mlp.w12.bias" in name:
new_names.append(f"{base}.mlp.linear_fc1.bias")
new_tensors[0] = new_tensors[0][mlp_indices]
chunk_dim = 0
elif "mlp.w3.weight" in name:
new_names.append(f"{base}.mlp.linear_fc2.weight")
chunk_dim = 1
elif "mlp.w3.bias" in name:
new_names.append(f"{base}.mlp.linear_fc2.bias")
elif "norm2.weight" in name:
new_name = f"{base}.pre_mlp_layernorm.weight"
if use_te:
new_name = f"{base}.mlp.linear_fc1.layer_norm_weight"
new_names.append(new_name)
elif "norm2.bias" in name:
new_name = f"{base}.pre_mlp_layernorm.bias"
if use_te:
new_name = f"{base}.mlp.linear_fc1.layer_norm_bias"
new_names.append(new_name)
elif "ls1.grandma" in name:
new_names.append(f"{base}.ls1")
elif "ls2.grandma" in name:
new_names.append(f"{base}.ls2")
assert len(new_names) == len(new_tensors), f"{new_names} {new_tensors}"
for new_name, new_tensor in zip(new_names, new_tensors):
if chunk_dim is None:
tp_new_tensors = [new_tensor for _ in range(tensor_parallel_size)]
else:
tp_new_tensors = torch.chunk(new_tensor, tensor_parallel_size, dim=chunk_dim)
for i in range(tensor_parallel_size):
# chunk() creates a view of a bigger tensor. clone() is used here to avoid excessive storage.
new_state_dicts[i]["model"][new_name] = tp_new_tensors[i].clone()
# TE sets _extra_state (for FP8 purposes), so set an empty one here for compatibility.
extra_state_layers = ("linear_qkv", "linear_proj", "linear_fc1", "linear_fc2")
is_extra_state_layer = any([l in new_name for l in extra_state_layers])
if use_te and is_extra_state_layer:
layer = new_name.split(".")[-2]
if layer in extra_state_layers:
extra_state_name = (
new_name[: new_name.rfind(".") + 1] + "_extra_state"
) # Replace the weight name.
new_state_dicts[i]["model"][extra_state_name] = None
for i in range(tensor_parallel_size):
output_dir_tp = os.path.join(output_path, "iter_0000001", f"mp_rank_0{i}")
os.makedirs(output_dir_tp)
output_path_tp = os.path.join(output_dir_tp, "model_optim_rng.pt")
torch.save(new_state_dicts[i], output_path_tp)
with open(os.path.join(output_path, "latest_checkpointed_iteration.txt"), "w") as f:
f.write("1")
def convert(output_path, tensor_parallel_size, use_te, model_type, version):
if model_type == "radio_v2.5-h":
convert_radio_h(output_path, tensor_parallel_size, use_te, version)
elif model_type == "radio_v2.5-g":
convert_radio_g(output_path, tensor_parallel_size, use_te, version)
else:
raise NotImplementedError(f"Converter doesn't support model type {model_type}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="""
Convert RADIO weights to megatron format.
Example usage:
python radio_converter.py --output /some/output/folder --tensor-parallel-size 4
""",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--output", type=str, required=True, help="output directory for megatron state dict file(s)"
)
parser.add_argument(
"--tensor-parallel-size", type=int, default=1, help="model tensor parallel size"
)
parser.add_argument("--use-te", action="store_true", help="Use Transformer Engine")
parser.add_argument("--model-type", required=True, type=str, choices=['radio_v2.5-h', 'radio_v2.5-g'], help="Type of radio to load for conversion")
parser.add_argument("--version", type=str, default=None, help="Version to pass to torch.hub.load. Can be a local path or a version RADIO on torch hub. By default use the version from the model type.")
args = parser.parse_args()
convert(args.output, args.tensor_parallel_size, args.use_te, args.model_type, args.version)
print("done.")
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import argparse
import os
from transformers import PaliGemmaForConditionalGeneration
import torch
def convert(output_path, tensor_parallel_size, use_te):
device = "cuda"
model_id = "google/paligemma-3b-pt-448"
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval()
model = model.to(device)
print(model.config)
for name, tensor in model.state_dict().items():
if "vision_model" not in name:
continue
shape_str = "(" + ", ".join([str(x) for x in tensor.shape]) + ")"
print(f"{name:<75} {shape_str:>20}")
state_dict = model.state_dict()
new_state_dicts = [{"model": dict()} for _ in range(tensor_parallel_size)]
def add_chunck_tensor(new_tensor, new_name, chunk_dim=None):
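# Split the tensor along chunk_dim (if given) and store it in every
# tensor-parallel rank's state dict, adding a TE _extra_state entry when needed.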
if chunk_dim is None:
new_tensors = [new_tensor for _ in range(tensor_parallel_size)]
else:
new_tensors = torch.chunk(new_tensor, tensor_parallel_size, dim=chunk_dim)
for i in range(tensor_parallel_size):
# chunk() creates a view of a bigger tensor. clone() is used here to avoid excessive storage.
new_state_dicts[i]["model"][new_name] = new_tensors[i].clone()
# TE sets _extra_state (for FP8 purposes), so set an empty one here for compatibility.
extra_state_layers = ("linear_qkv", "linear_proj", "linear_fc1", "linear_fc2")
is_extra_state_layer = any([l in new_name for l in extra_state_layers])
if use_te and is_extra_state_layer:
layer = new_name.split(".")[-2]
if layer in extra_state_layers:
extra_state_name = (
new_name[: new_name.rfind(".") + 1] + "_extra_state"
) # Replace the weight name.
new_state_dicts[i]["model"][extra_state_name] = None
for name, tensor in state_dict.items():
if tensor.dtype == torch.float16:
state_dict[name] = tensor.to(torch.float32)
add_chunck_tensor(
state_dict["vision_tower.vision_model.embeddings.position_embedding.weight"],
"position_embeddings.weight")
add_chunck_tensor(
state_dict["vision_tower.vision_model.embeddings.patch_embedding.weight"],
"conv1.weight")
add_chunck_tensor(
state_dict["vision_tower.vision_model.embeddings.patch_embedding.bias"],
"conv1.bias")
head_dim = 72
num_head = 16
for layer_idx in range(27):
origin_base = f"vision_tower.vision_model.encoder.layers.{layer_idx}"
target_base = f"decoder.layers.{layer_idx}"
for param_type in ["weight", "bias"]:
# QKV
q_proj_params = state_dict[f"{origin_base}.self_attn.q_proj.{param_type}"]
k_proj_params = state_dict[f"{origin_base}.self_attn.k_proj.{param_type}"]
v_proj_params = state_dict[f"{origin_base}.self_attn.v_proj.{param_type}"]
# Do some tensor manipulation because megatron expects a single fused
# QKV projection tensor with rows ordered as
# [(Q1, K1, V1), (Q2, K2, V2), ...], where Qi is the query of the
# i-th head with dimension head_dim.
new_tensor = torch.concatenate([
q_proj_params.view(num_head, head_dim, -1),
k_proj_params.view(num_head, head_dim, -1),
v_proj_params.view(num_head, head_dim, -1)], axis=1).view(
3*head_dim*num_head, -1)
if param_type == "bias":
new_tensor = new_tensor[:, 0]
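# Illustrative sketch (toy numbers): with num_head=2 and head_dim=2, the separate
# q/k/v projections [q0 q1], [k0 k1], [v0 v1] are regrouped into rows
# [q0 k0 v0 q1 k1 v1], matching the fused linear_qkv layout built above.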
new_name = f"{target_base}.self_attention.linear_qkv.{param_type}"
add_chunck_tensor(new_tensor, new_name, chunk_dim=0)
# linear_proj
add_chunck_tensor(
state_dict[f"{origin_base}.self_attn.out_proj.{param_type}"],
f"{target_base}.self_attention.linear_proj.{param_type}",
chunk_dim=1 if param_type == "weight" else None)
# layer_norm
new_name = f"{target_base}.input_layernorm.{param_type}"
if use_te:
new_name = f"{target_base}.self_attention.linear_qkv.layer_norm_{param_type}"
add_chunck_tensor(
state_dict[f"{origin_base}.layer_norm1.{param_type}"],
new_name)
# FC 1
add_chunck_tensor(
state_dict[f"{origin_base}.mlp.fc1.{param_type}"],
f"{target_base}.mlp.linear_fc1.{param_type}",
chunk_dim=0)
# FC 2
add_chunck_tensor(
state_dict[f"{origin_base}.mlp.fc2.{param_type}"],
f"{target_base}.mlp.linear_fc2.{param_type}",
chunk_dim=1 if param_type=="weight" else None)
# layer_norm
new_name = f"{target_base}.pre_mlp_layernorm.{param_type}"
if use_te:
new_name = f"{target_base}.mlp.linear_fc1.layer_norm_{param_type}"
add_chunck_tensor(
state_dict[f"{origin_base}.layer_norm2.{param_type}"],
new_name)
add_chunck_tensor(
state_dict["vision_tower.vision_model.post_layernorm.weight"],
"ln_post.weight")
add_chunck_tensor(
state_dict["vision_tower.vision_model.post_layernorm.bias"],
"ln_post.bias")
for i in range(tensor_parallel_size):
output_dir_tp = os.path.join(output_path, "iter_0000001", f"mp_rank_0{i}")
os.makedirs(output_dir_tp)
output_path_tp = os.path.join(output_dir_tp, "model_optim_rng.pt")
torch.save(new_state_dicts[i], output_path_tp)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="""
Convert SigLIP weights to megatron format.
Example usage:
python siglip_converter.py --tensor-parallel-size 4 --output google_paligemma_3b_pt_44_mcore_tp_4 --use-te
examples/multimodal/combine_mistral_clip.sh Mistral-7B-Instruct-v0.3-mcore-tp4 google_paligemma_3b_pt_44_mcore_tp_4 mistral_7b_instruct_v0p3_google_paligemma_3b_pt_44_mcore_tp_4
""",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--output", type=str, required=True, help="output directory for megatron state dict file(s)"
)
parser.add_argument(
"--tensor-parallel-size", type=int, default=1, help="model tensor parallel size"
)
parser.add_argument("--use-te", action="store_true", help="Use Transformer Engine")
args = parser.parse_args()
convert(args.output, args.tensor_parallel_size, args.use_te)
print("done.")