# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import argparse
import os
import sys
import torch
# Add megatron to the path.
sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))
)
def combine(input_files, module_prefixes, output_files):
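    """Merge per-module state dict files into combined checkpoint files.

    Consecutive groups of len(input_files) / len(output_files) input files are merged into
    one output file each, with every model parameter name prefixed by its module prefix.
    """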
    num_inputs_per_output = len(input_files) // len(output_files)
for output_idx, output_file in enumerate(output_files):
combined_state_dict = None
lb = output_idx * num_inputs_per_output
ub = (output_idx + 1) * num_inputs_per_output
current_input_files = input_files[lb:ub]
current_module_prefixes = module_prefixes[lb:ub]
for i, (input_file, module_prefix) in enumerate(
zip(current_input_files, current_module_prefixes)
):
# initialize the combined state dict using the first provided input file
current_state_dict = torch.load(input_file)
if i == 0:
combined_state_dict = current_state_dict.copy()
combined_state_dict["model"] = dict()
# copy model state dict and prefix names with the given module keys.
for k, v in current_state_dict["model"].items():
combined_state_dict["model"]["%s.%s" % (module_prefix, k)] = v
torch.save(combined_state_dict, output_file)
print("saved:", output_file)
print("done.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="""
Combine multiple state dicts into a single state dict.
The combined state dict is first initialized by taking a copy of the first provided input state dict.
To avoid conflicts in model parameter names, a prefix must be provided for each input file.
Model parameter names will be renamed from <original name> to <model prefix>.<original name>.
Example usage:
python combine_state_dicts.py --input language_model.pt vision_model.pt --prefixes language_model vision_model --output multimodal.pt
""",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument("--input", nargs="*", required=True, help="paths to input state dict files")
parser.add_argument(
"--prefixes",
nargs="*",
required=True,
help="prefixes to use with each input model's parameters",
)
parser.add_argument(
"--output", nargs="*", required=True, help="path(s) to output state dict file"
)
args = parser.parse_args()
assert len(args.input) > 1, "must provide more than 1 input model to combine"
    assert len(args.input) == len(args.prefixes), "each input model must have a corresponding prefix"
assert (
len(args.input) % len(args.output) == 0
), "each output file must use the same number of input files"
combine(args.input, args.prefixes, args.output)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import torch
from megatron.training.activations import quick_gelu, squared_relu
def get_language_model_config(config):
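    """Apply language-model-specific overrides, selected by config.language_model_type, on top of the base transformer config."""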
if config.language_model_type == "2b":
config.add_bias_linear = False
config.bias_activation_fusion = False
config.gated_linear_unit = True
config.apply_query_key_layer_scaling = True
config.layernorm_zero_centered_gamma = True
config.bias_dropout_fusion = False
config.rotary_percent = 0.5
config.apply_rope_fusion = False
config.attention_softmax_in_fp32 = True
elif config.language_model_type == "8b":
config.add_bias_linear = False
config.bias_activation_fusion = False
config.gated_linear_unit = False
config.apply_query_key_layer_scaling = True
config.layernorm_zero_centered_gamma = True
config.bias_dropout_fusion = False
config.rotary_percent = 0.5
config.attention_dropout = 0.0
config.apply_rope_fusion = False
config.activation_func = squared_relu
config.ffn_hidden_size = 16384
config.masked_softmax_fusion = True
config.attention_softmax_in_fp32 = True
config.num_query_groups = 32
config.kv_channels = 128
config.rotary_interleaved = False
    elif config.language_model_type == "llama3_8b":
config.activation_func = torch.nn.functional.silu
config.add_bias_linear = False
config.bias_activation_fusion = False
config.gated_linear_unit = True
config.apply_query_key_layer_scaling = True
config.layernorm_zero_centered_gamma = (
False # Zero centered gamma not supported for RMSNorm
)
config.bias_dropout_fusion = False
config.te_attn_mask_type = None
config.rotary_percent = 0.5
config.apply_rope_fusion = False
config.attention_softmax_in_fp32 = True
config.ffn_hidden_size = 14336
return config
def get_vision_model_config(config, apply_query_key_layer_scaling=False):
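    """Apply vision-encoder-specific settings on top of the base transformer config."""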
config.num_layers = 24
config.num_attention_heads = 16
config.add_bias_linear = True
config.add_qkv_bias = True
config.hidden_size = 1024
config.hidden_dropout = 0.0
config.attention_dropout = 0.0
config.ffn_hidden_size = 4096
config.gated_linear_unit = False
config.activation_func = quick_gelu
config.kv_channels = 64
config.num_attention_heads = 16
config.num_query_groups = 16
config.layernorm_zero_centered_gamma = False
config.apply_query_key_layer_scaling = apply_query_key_layer_scaling
config.bias_activation_fusion = False
config.bias_dropout_fusion = False
config.attention_softmax_in_fp32 = True
return config
def get_vision_projection_config(config, hidden_size):
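    """Configure the vision projection MLP, whose hidden_size is set to the language model hidden size."""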
config.gated_linear_unit = False
config.bias_activation_fusion = False
config.add_bias_linear = False
config.hidden_size = hidden_size
if config.language_model_type == "2b":
config.ffn_hidden_size = 5440
config.activation_func = torch.nn.functional.gelu
    elif config.language_model_type == "8b":
config.ffn_hidden_size = 16384
config.activation_func = squared_relu
elif config.language_model_type == "llama3_8b":
config.ffn_hidden_size = 14336
config.activation_func = torch.nn.functional.silu
return config
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import torch
from dataset_helpers import TaskEncoder, print_error_handler
from megatron.core import mpu
from megatron.energon import (
LimitDataset,
RepeatDataset,
WorkerConfig,
get_loader,
get_savable_loader,
get_train_dataset,
get_val_datasets,
)
from megatron.training import get_args, get_num_microbatches, print_rank_0
from megatron.training.checkpointing import get_checkpoint_name
def datasets_provider(worker_config=None):
"""Create multimodal train, validation and test datasets."""
args = get_args()
dname = args.data_path[0] if type(args.data_path) is list else args.data_path
train_dataset = get_train_dataset(
dname,
batch_size=args.micro_batch_size,
task_encoder=TaskEncoder(),
worker_config=worker_config,
virtual_epoch_length=1000,
max_samples_per_sequence=100,
shuffle_buffer_size=100,
handler=print_error_handler,
image_decode="pil",
)
val_datasets = get_val_datasets(
dname,
batch_size=args.micro_batch_size,
# This is the total number over all workers
# limit=args.eval_iters * get_num_microbatches(),
task_encoder=TaskEncoder(),
worker_config=worker_config,
handler=print_error_handler,
image_decode="pil",
)
val_datasets_without_source_datasets = [
# Limit the dataset to eval_iters * num_microbatches
LimitDataset(
# Repeat the inner dataset in case it's too short
RepeatDataset(val_ds, worker_config=worker_config),
length=args.eval_iters * get_num_microbatches(),
worker_config=worker_config,
reset_after_epoch=True,
)
for val_ds, _src_ds in val_datasets
]
return train_dataset, val_datasets_without_source_datasets, None
def train_valid_test_dataloaders_provider(train_val_test_num_samples):
"""Build multimodal train, validation and test dataloaders."""
args = get_args()
worker_debug_path = None
worker_log_level = 0
rank = mpu.get_data_parallel_rank()
world_size = mpu.get_data_parallel_world_size()
data_parallel_group = mpu.get_data_parallel_group()
worker_config = WorkerConfig(
rank=rank,
world_size=world_size,
num_workers=args.num_workers,
data_parallel_group=data_parallel_group,
worker_debug_path=worker_debug_path,
worker_log_level=worker_log_level,
)
train_ds, valid_ds1, test_ds = datasets_provider(worker_config)
train_dataloader = get_savable_loader(train_ds, worker_config=worker_config)
if args.load is not None:
if hasattr(args, "dataloader_path"):
dp_rank = (
mpu.get_data_parallel_rank()
if torch.distributed.is_initialized()
else 0
)
data_save_name = get_checkpoint_name(
args.dataloader_path,
args.iteration,
save_basename=f"train_dataloader_dprank{dp_rank:03d}.pt",
)
try:
dataset_state_dict = torch.load(
data_save_name, map_location="cpu"
)
if (
"dataset_state_dict" in dataset_state_dict.keys()
and dataset_state_dict["train_data_path"]
!= args.train_data_path
):
print_rank_0(
f"Not restoring dataset state from {data_save_name}, path to dataset changed from {dataset_state_dict['train_data_path']} to {args.train_data_path}"
)
else:
train_dataloader.restore_state_rank(
dataset_state_dict["dataloader_state_dict"]
)
print_rank_0(
f"restoring dataset state from {data_save_name}"
)
except Exception as e:
print_rank_0(
"loading dataloader checkpoint failed. Skipping. " + str(e)
)
valid_dataloader = [
iter(cyclic_iter(get_loader(valid_ds, worker_config=worker_config)))
for valid_ds in valid_ds1
]
test_dataloader = None
return iter(cyclic_iter(train_dataloader)), valid_dataloader, iter(cyclic_iter(test_dataloader))
def cyclic_iter(iter):
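    """Yield items from the given iterator indefinitely."""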
while True:
for x in iter:
yield x
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import dataclasses
import json
import random
import re
import sys
import traceback
from dataclasses import dataclass
from typing import Any, List, Dict, Optional, Tuple, Union
import numpy as np
import torch
from PIL import Image, ImageDraw
from torchvision import transforms as T
from torchvision.transforms import Compose, RandAugment, RandomResizedCrop, Resize, ToPILImage
from megatron.core import mpu
from megatron.energon import Batch, CaptioningSample, DefaultTaskEncoder, OCRSample, VQASample
from megatron.energon.transforms import CustomTransform, MergeTransform
from megatron.training import get_args
from megatron.training.tokenizer import build_tokenizer

# nltk is optional; it is only needed when --split-sentences is enabled.
try:
    import nltk
    nltk_available = True
except ImportError:
    nltk_available = False
try:
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC
except ImportError:
BICUBIC = Image.BICUBIC
# Imagenet's mean and std.
pixel_mean = [123.675, 116.28, 103.53]
pixel_std = [58.395, 57.12, 57.375]
def convert_to_rgb(image):
return image.convert("RGB")
def _transform_train(img_h, img_w):
return Compose([
ToPILImage(),
RandomResizedCrop((img_h, img_w), scale=(0.5, 1.0)),
convert_to_rgb,
])
def _transform_train_aug(img_h, img_w):
return Compose([
ToPILImage(),
RandomResizedCrop((img_h, img_w), scale=(0.5, 1.0)),
convert_to_rgb,
RandAugment(2, 5, isPIL=True, augs=['Identity', 'AutoContrast', 'Brightness', 'Sharpness', 'Equalize',
'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
])
def _transform_test(img_h, img_w):
return Compose([
ToPILImage(),
Resize((img_h, img_w)),
convert_to_rgb,
])
class RandomResize(CustomTransform):
"""Resizes the image by a random scale factor in the given interval, but at most max_size"""
def __init__(self, min_scale: float, max_scale: float, max_size: int):
self._min_scale = min_scale
self._max_scale = max_scale
self._max_size = max_size
def apply_transform(self, matrix: np.ndarray, dst_size: np.ndarray) -> Tuple[Any, Any, Any]:
scale = random.uniform(self._min_scale, self._max_scale)
new_size = tuple(int(x * scale) for x in dst_size)
if max(new_size) > self._max_size:
scale = self._max_size / max(new_size)
new_size = tuple(int(x * scale) for x in dst_size)
matrix = self.scale(scale, scale) @ matrix
dst_size = np.array(new_size, dtype=dst_size.dtype)
return matrix, dst_size, (self.__class__.__name__, scale)
class RandomResizeLongEdge(CustomTransform):
"""Resizes the image's longer edge to a random length between min_size and max_size pixels."""
def __init__(self, min_size: int, max_size: int):
self._min_size = min_size
self._max_size = max_size
def apply_transform(self, matrix: np.ndarray, dst_size: np.ndarray) -> Tuple[Any, Any, Any]:
new_long = random.randint(self._min_size, self._max_size)
if dst_size[0] > dst_size[1]: # h > w
new_w, new_h = int(new_long * dst_size[1] / dst_size[0]), new_long
else: # w > h
new_w, new_h = new_long, int(new_long * dst_size[0] / dst_size[1])
new_size = (new_h, new_w)
matrix = self.scale(new_w / dst_size[1], new_h / dst_size[0]) @ matrix
dst_size = np.array(new_size, dtype=dst_size.dtype)
return matrix, dst_size, (self.__class__.__name__, new_size)
class RandomPad(CustomTransform):
    """Pads the image to the given size, placing the original image at a random position within the larger canvas.
    If the image is already larger than the given size along a dimension, no padding is applied along that dimension."""
def __init__(self, size: Tuple[int, int]):
self._new_size = size # h, w
def apply_transform(self, matrix: np.ndarray, dst_size: np.ndarray) -> Tuple[Any, Any, Any]:
h_pad = max(self._new_size[0] - dst_size[0], 0)
w_pad = max(self._new_size[1] - dst_size[1], 0)
if h_pad == 0 and w_pad == 0:
return matrix, dst_size, (self.__class__.__name__, None)
else:
# TODO: fix me
# top = random.randint(0, h_pad)
# left = random.randint(0, w_pad)
top = 0
left = 0
matrix = self.translate(left, top) @ matrix
dst_size = np.array(self._new_size, dtype=dst_size.dtype)
return matrix, dst_size, (self.__class__.__name__, (top, left))
def _get_ocr_document_visual_transform(IMG_H=1024, IMG_W=1024):
document_visual_transform = T.Compose(
[
MergeTransform(
[
# T.RandomResizedCrop(size=FINAL_SIZE, scale=(0.5, 1.0), ratio=(0.8, 1.2)),
RandomResizeLongEdge(960, 1008), # Note: 1008 comes from list(range(960, 1024, 16))[-1]
T.RandomRotation(5, interpolation=T.InterpolationMode.BILINEAR),
T.RandomPerspective(distortion_scale=0.1, p=0.1),
RandomPad((IMG_H, IMG_W)),
]
),
T.ColorJitter(brightness=(0.8, 1.2), contrast=(0.7, 1.0)),
T.RandomGrayscale(p=0.5),
T.RandomInvert(p=0.5),
T.RandomAdjustSharpness(sharpness_factor=0.0, p=0.5),
T.RandomAdjustSharpness(sharpness_factor=2.0, p=0.5),
# LogImage(),
# T.ToTensor(),
# T.Normalize(IMAGE_MEAN, IMAGE_STD),
]
)
return document_visual_transform
def _get_ocr_document_identity_transform(IMG_H=1024, IMG_W=1024):
long_edge = max(IMG_H, IMG_W)
document_identity_transform = T.Compose(
[
MergeTransform(
[
RandomResizeLongEdge(long_edge, long_edge),
RandomPad((long_edge, long_edge)),
]
)
]
)
return document_identity_transform
def _get_ocr_paragraph_visual_transform(IMG_H=1024, IMG_W=1024):
paragraph_visual_transform = T.Compose(
[
MergeTransform(
[
# T.RandomResizedCrop(size=FINAL_SIZE, scale=(0.5, 1.0), ratio=(0.8, 1.2)),
RandomResize(0.5, 2.0, min(IMG_H, IMG_W)), #FINAL_SIZE),
T.RandomRotation(1, interpolation=T.InterpolationMode.BILINEAR),
T.RandomPerspective(distortion_scale=0.1, p=0.1),
RandomPad((IMG_H, IMG_W)),
]
),
T.ColorJitter(brightness=(0.8, 1.2), contrast=(0.7, 1.0)),
T.RandomGrayscale(p=0.5),
T.RandomInvert(p=0.5),
# T.RandomAdjustSharpness(sharpness_factor=0.0, p=0.5),
# T.RandomAdjustSharpness(sharpness_factor=2.0, p=0.5),
# LogImage(),
# T.ToTensor(),
# T.Normalize(IMAGE_MEAN, IMAGE_STD),
]
)
return paragraph_visual_transform
# Type for an encoded sample, produced by encode_sample() before batching.
@dataclass
class ImageTaskSample:
__key__: str
__subflavors__: Dict
# (c, h, w)
img: torch.Tensor
text: np.ndarray
prompt_len: np.int64
img_clip: Optional[torch.Tensor] = None
# Type for the batched data produced by batch(); encode_batch() converts it to a plain dict.
@dataclass
class ImageTaskBatch(Batch):
__keys__: List[str]
__subflavors__: List[Dict]
# (n, c, h, w)
img: torch.Tensor
# (n, seq_len)
text: torch.Tensor
# (n, 1)
prompt_len: torch.Tensor
# (n, c, h, w)
img_clip: Optional[torch.Tensor] = None
class IdentitySplitter(object):
def tokenize(self, *text):
return text
class Tokenizer:
def __init__(self):
args = get_args()
self.args = args
self.IMAGE_TOKEN_INDEX = -200
self.initializer()
def initializer(self):
        # Use the Tokenizer class as a container for global data.
Tokenizer.tokenizer = build_tokenizer(self.args)
self.eod_token = Tokenizer.tokenizer.eod
self.split_token = 313131
if (
hasattr(self.args, "split_sentences") and self.args.split_sentences
): # default false
if not nltk_available:
print("NLTK is not available to split sentences.")
exit()
library = "tokenizers/punkt/{}.pickle".format("english")
# print("loading: " + library)
splitter = nltk.load(library)
if self.args.keep_newlines:
# this prevents punkt from eating newlines after sentences
Tokenizer.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
train_text=splitter._params, lang_vars=CustomLanguageVars()
)
else:
Tokenizer.splitter = splitter
else:
Tokenizer.splitter = IdentitySplitter()
def __call__(self, text: str, padded: bool = True): # -> torch.Tensor:
sentence = Tokenizer.splitter.tokenize(text)[0]
sentence = Tokenizer.tokenizer.tokenize(sentence)
return sentence
def pad(self, content, seq_len=1024):
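        # Pad the token sequence with EOD tokens up to seq_len; longer sequences are left unchanged.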
out = np.pad(content, pad_width=(0,max(0,seq_len-len(content))), mode='constant', constant_values=self.eod_token)
return out
class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatch, dict]):
"""A simple task encoder for captioning."""
def __init__(
self
):
# Specify the batch_type for default batching (batching is performed here "manually" by
# overwriting the `batch` method)
super().__init__()
self.args = get_args()
self.tokenizer = Tokenizer()
self.manual_prompts = json.load(open(self.args.prompt_path))
self.seq_len = self.args.seq_length
self.txt_to_token_dict = {}
self.img_h, self.img_w = self.args.img_h, self.args.img_w
self.pixel_mean = torch.Tensor(pixel_mean).view(-1, 1, 1)
self.pixel_std = torch.Tensor(pixel_std).view(-1, 1, 1)
self.ocr_document_visual_transform = _get_ocr_document_visual_transform(self.img_h, self.img_w)
self.ocr_document_identity_transform = _get_ocr_document_identity_transform(self.img_h, self.img_w)
self.ocr_paragraph_visual_transform = _get_ocr_paragraph_visual_transform(self.img_h, self.img_w)
def get_visual_transform(self, img_sample, sample_augmentation=False):
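        """Resize so the longer edge matches the target size, optionally augment, then normalize and pad to (img_h, img_w)."""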
raw_h, raw_w = img_sample.shape[0], img_sample.shape[1]
ratio = float(max(self.img_h, self.img_w)) / max(raw_h, raw_w)
scaled_h, scaled_w = int(raw_h * ratio + 0.5), int(raw_w * ratio + 0.5)
        # Decide whether this sample should be augmented.
if sample_augmentation:
# further check if augmentation is a global flag in args
if self.args.aug:
visual_transform = _transform_train_aug(scaled_h, scaled_w)
else:
visual_transform = _transform_train(scaled_h, scaled_w)
else:
visual_transform = _transform_test(scaled_h, scaled_w)
img = visual_transform(img_sample)
# Normalize pixel values.
img = (torch.Tensor(np.array(img)).permute(2, 0, 1) - self.pixel_mean) / self.pixel_std
# Pad to target image size.
delta_h, delta_w = self.img_h - scaled_h, self.img_w - scaled_w
img = torch.nn.functional.pad(img, (0, delta_w, 0, delta_h))
return img
def encode_sample(self, sample: Union[
CaptioningSample, OCRSample, VQASample]
):
if isinstance(sample, OCRSample):
yield self.encode_ocr(sample)
elif isinstance(sample, CaptioningSample):
yield self.encode_captioning(sample)
elif isinstance(sample, VQASample):
yield self.encode_vqa(sample)
else:
raise NotImplementedError('Sample format not supported')
yield None
def encode_captioning(self, sample: CaptioningSample):
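        """Preprocess the image, pick a random captioning prompt, and tokenize prompt + caption into one padded sequence."""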
sample_augmentation = sample.__subflavors__["augmentation"] == True
img = self.get_visual_transform(np.array(sample.image), sample_augmentation=sample_augmentation)
# randomly select a prompt
if 'CaptioningDetailed' in sample.__subflavors__["type"]:
prompt_idx = np.random.randint(len(self.manual_prompts["CaptioningDetailed"]["raw"]))
cur_prompt = self.manual_prompts["CaptioningDetailed"]["raw"][prompt_idx]
else:
prompt_idx = np.random.randint(len(self.manual_prompts["Captioning"]["raw"]))
cur_prompt = self.manual_prompts["Captioning"]["raw"][prompt_idx]
if cur_prompt not in self.txt_to_token_dict:
self.txt_to_token_dict[cur_prompt] = self.tokenizer(cur_prompt)
cur_prompt = self.txt_to_token_dict[cur_prompt]
prompt_len = len(cur_prompt)
caption = sample.caption
if 'SplitByLine' in sample.__subflavors__["type"]:
# caption = re.sub(r"\n+", "\n", caption)
caption_list = caption.split('\n')
caption_list = [caption for caption in caption_list if caption.strip() != '']
caption = np.random.choice(caption_list)
caption_token = self.tokenizer(caption.strip())
if len(caption.strip()) == 0:
raise RuntimeError('Empty string in caption!')
seq_len = self.seq_len + 4
text_sample = np.concatenate([[self.tokenizer.IMAGE_TOKEN_INDEX], cur_prompt, caption_token])
text_sample = self.tokenizer.pad(text_sample, seq_len)
text_sample = text_sample[:seq_len]
return ImageTaskSample(
__key__=sample.__key__,
__subflavors__=sample.__subflavors__,
img=img,
text=text_sample,
prompt_len=prompt_len
)
def encode_vqa(self, sample: VQASample):
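        """Preprocess the image (or use a dummy tensor for no-image samples) and tokenize the question and answer (sampled by weight when multiple answers are given)."""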
task_name = None
no_image_flag = True if '-noimage' in sample.__key__ else False
if 'pretrain' in sample.__key__:
task_name = 'pretrain'
else:
task_name = sample.__key__.split("/")[0]
sample_augmentation = sample.__subflavors__["augmentation"] == True
if no_image_flag:
img = torch.from_numpy(np.array([0]).astype(np.float32))
else:
img = self.get_visual_transform(np.array(sample.image), sample_augmentation=sample_augmentation)
if "<image>" in sample.context:
sample.context = sample.context.replace("<image>","")
if task_name != 'pretrain' and sample.context[-1:] != "\n":
sample.context = sample.context + "\n"
question_token = self.tokenizer(sample.context)
if isinstance(sample.answers, list):
answer_list = sample.answers
weight_list = np.array(sample.answer_weights).astype(np.float32)
weight_list = weight_list / np.sum(weight_list)
answer_idx = np.random.choice(weight_list.shape[0], 1, p=weight_list)[0]
answer = answer_list[answer_idx]
answer_token = self.tokenizer(answer)
else:
answer_token = self.tokenizer(sample.answers)
prompt_len = len(question_token)
seq_len = self.seq_len + 4
text_sample = np.concatenate([[self.tokenizer.IMAGE_TOKEN_INDEX], question_token, answer_token])
text_sample = self.tokenizer.pad(text_sample, seq_len)
return ImageTaskSample(
__key__=sample.__key__,
__subflavors__=sample.__subflavors__,
img=img,
text=text_sample,
prompt_len=prompt_len
)
def encode_ocr(self, sample: OCRSample) -> ImageTaskSample:
if sample.__subflavors__["type"] == "document":
visual_transform = self.ocr_document_visual_transform
elif sample.__subflavors__["type"] == "paragraph":
visual_transform = self.ocr_paragraph_visual_transform
elif sample.__subflavors__["augmentation"] == False:
visual_transform = self.ocr_document_identity_transform
else:
raise ValueError(f"Unknown subflavor {sample.__subflavors__}")
if sample.words_boxes is not None and sample.words_boxes.shape[1] >= 5:
# Boxes with conf below 0.9 are skipped
filter_words_mask = sample.words_boxes[:, 4] < 0.9
filter_boxes = sample.words_boxes[filter_words_mask, :4]
for x, y, x2, y2 in filter_boxes:
if isinstance(sample.image, Image.Image):
draw = ImageDraw.Draw(sample.image)
draw.rectangle([int(x), int(y), (int(x2), int(y2))], fill=0)
else:
sample.image[:, int(y) : int(y2) + 1, int(x) : int(x2) + 1] = 0
text = " ".join(
text for skip, text in zip(filter_words_mask, sample.words_text) if not skip
)
else:
text = " ".join(sample.text.splitlines())
match = re.search(r'"text_sequence": "(.*?)"', text)
if match:
text = match.group(1)
img = visual_transform(sample.image)
img_clip = None
img = (torch.Tensor(np.array(img)).permute(2, 0, 1) - self.pixel_mean) / self.pixel_std
img = torch.nn.functional.pad(img, (0, self.img_w - img.shape[2], 0, self.img_h - img.shape[1]))
# randomly select a prompt
prompt_idx = np.random.randint(len(self.manual_prompts["OCR"]["raw"]))
cur_prompt = self.manual_prompts["OCR"]["raw"][prompt_idx]
if cur_prompt not in self.txt_to_token_dict:
self.txt_to_token_dict[cur_prompt] = self.tokenizer(cur_prompt)
cur_prompt = self.txt_to_token_dict[cur_prompt]
text_sample = self.tokenizer(text)
prompt_len = len(cur_prompt)
seq_len = self.seq_len + 4
text_sample = np.concatenate([cur_prompt, text_sample])
text_sample = self.tokenizer.pad(text_sample, seq_len=seq_len)
text_sample = text_sample[:seq_len]
return ImageTaskSample(
__key__=sample.__key__,
__subflavors__=sample.__subflavors__,
img=img,
img_clip=img_clip,
text=text_sample,
prompt_len=prompt_len
)
def batch(self, samples: List[ImageTaskSample]) -> ImageTaskBatch:
batch = ImageTaskBatch(
__keys__=[s.__key__ for s in samples],
__subflavors__=[s.__subflavors__ for s in samples],
img=torch.stack([s.img for s in samples]),
text=torch.from_numpy(np.stack([s.text for s in samples], axis=0).astype(np.int64)),
prompt_len=torch.from_numpy(np.array([s.prompt_len for s in samples], dtype=np.int64))
)
return batch
def encode_batch(self, batch: ImageTaskBatch) -> dict:
raw = dataclasses.asdict(batch)
del raw["__subflavors__"]
return raw
def print_error_handler(exc: Exception, key: Optional[str]):
print(
f"The following exception occurred in the dataloader for sample {key} and is skipped",
file=sys.stderr,
)
traceback.print_exc()
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import argparse
import glob
import json
from pycocoevalcap.eval import COCOEvalCap
from pycocotools.coco import COCO
def convert_to_coco_format(input_path):
"""Convert input files to COCO compatible format."""
output_file_path = input_path + "-captioning-merged.json"
pattern = input_path + "-captioning-[0-9].*jsonl"
input_file_paths = glob.glob(pattern)
captions = []
for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
question_id = res['sample_id']
caption = res['caption'].rstrip('.').lower()
captions.append({"image_id": question_id, "caption": caption})
with open(output_file_path, "w") as output_file:
json.dump(captions, output_file)
return output_file_path
def coco_captioning_eval(input_path, groundtruth_file):
"""Run COCO captioning evaluation."""
coco = COCO(groundtruth_file)
input_file = convert_to_coco_format(input_path)
coco_result = coco.loadRes(input_file)
coco_eval = COCOEvalCap(coco, coco_result)
# Evaluate on the input subset of images.
coco_eval.params['image_id'] = coco_result.getImgIds()
coco_eval.evaluate()
for metric, score in coco_eval.eval.items():
print(metric, score)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--input-path", type=str, required=True, help="Path to input file(s)")
parser.add_argument(
"--groundtruth-path", type=str, required=True, help="Path to groundtruth file"
)
args = parser.parse_args()
coco_captioning_eval(args.input_path, args.groundtruth_path)
import argparse
import glob
import json
import subprocess
def convert_to_mmmu_format(input_path):
"""Convert input files to MMMU compatible format."""
output_file_path = input_path + "-MMMU-merged.json"
pattern = input_path + "-MMMU-[0-9].*jsonl"
input_file_paths = glob.glob(pattern)
output = dict()
for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
sample_id = res["sample_id"]
prediction = res["prediction"]
output[sample_id] = prediction
with open(output_file_path, "w") as output_file:
json.dump(output, output_file)
return output_file_path
def main():
    # By default, use the validation groundtruth file from the MMMU repo. This assumes the MMMU GitHub repo has been cloned under examples/multimodal.
default_groundtruth_path = "examples/multimodal/MMMU/eval/answer_dict_val.json"
parser = argparse.ArgumentParser()
parser.add_argument("--input-path", type=str, required=True, help="Path to input file(s)")
parser.add_argument(
"--groundtruth-path",
type=str,
default=default_groundtruth_path,
help="Path to groundtruth file. Defaults to the validation file in the MMMU repo.",
)
args = parser.parse_args()
result_file = convert_to_mmmu_format(args.input_path)
# The MMMU repo has a script for running the actual evaluation but no API. So launching the script here.
output = subprocess.run(
[
"python",
"examples/multimodal/MMMU/eval/main_eval_only.py",
"--output_path",
            result_file,
            "--answer_path",
            args.groundtruth_path,
],
capture_output=True,
text=True,
)
print(output.stdout)
if __name__ == "__main__":
main()
import argparse
import glob
import json
import re
# This can help resolve an import error of an mmf dependency that is not needed.
try:
from mmf.utils.m4c_evaluators import TextVQAAccuracyEvaluator
except ModuleNotFoundError:
from mmf.utils.m4c_evaluators import TextVQAAccuracyEvaluator
def merge_input_files(input_path):
"""Merge input files to a format compatible with the evaluator."""
output_file_path = input_path + "-TextVQA-merged.json"
pattern = input_path + "-TextVQA-[0-9].*jsonl"
input_file_paths = glob.glob(pattern)
results = []
for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
results.append(res)
with open(output_file_path, "w") as output_file:
json.dump(results, output_file)
return output_file_path
# Note: This is based on https://github.com/haotian-liu/LLaVA/blob/c121f0432da27facab705978f83c4ada465e46fd/llava/eval/eval_textvqa.py#L17
# and slightly modified.
def prompt_processor(prompt):
if prompt.startswith('OCR tokens: '):
pattern = r"Question: (.*?) Short answer:"
match = re.search(pattern, prompt, re.DOTALL)
question = match.group(1)
elif "Reference OCR token: " in prompt and len(prompt.split("\n")) == 3:
if prompt.startswith("Reference OCR token:"):
question = prompt.split("\n")[1]
else:
question = prompt.split("\n")[0]
elif len(prompt.split("\n")) == 2:
question = prompt.split("\n")[0]
else:
raise RuntimeError("unexpected prompt format")
return question.lower()
# Note: This is based on https://github.com/haotian-liu/LLaVA/blob/c121f0432da27facab705978f83c4ada465e46fd/llava/eval/eval_textvqa.py#L35
# and slightly modified.
def evaluate(result_file_path, groundtruth_path):
with open(groundtruth_path) as groundtruth_file:
groundtruth = json.load(groundtruth_file)["data"]
groundtruth = {(gt["image_id"], gt["question"].lower()): gt["answers"] for gt in groundtruth}
with open(result_file_path, "r") as result_file:
results = json.load(result_file)
predictions = []
for result in results:
gt_answers = groundtruth[(result["sample_id"], prompt_processor(result["prompt"]))]
predictions.append({"pred_answer": result["text"], "gt_answers": gt_answers})
evaluator = TextVQAAccuracyEvaluator()
print(
'Samples: {}\nAccuracy: {:.2f}%\n'.format(
len(predictions), 100.0 * evaluator.eval_pred_list(predictions)
)
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--input-path', type=str, help="Path to input file(s)")
parser.add_argument('--groundtruth-path', type=str, help="Path to groundtruth file")
args = parser.parse_args()
result_file_path = merge_input_files(args.input_path)
evaluate(result_file_path, args.groundtruth_path)
import argparse
import glob
import json
from open_flamingo.eval.vqa_metric import compute_vqa_accuracy
def merge_input_files(input_path):
"""Merge input files to a format compatible with the evaluator."""
output_file_path = input_path + "-VQAv2-merged.json"
pattern = input_path + "-VQAv2-[0-9].*jsonl"
input_file_paths = glob.glob(pattern)
results = []
for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
res["question_id"] = res["sample_id"]
results.append(res)
with open(output_file_path, "w") as output_file:
json.dump(results, output_file)
return output_file_path
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--input-path', type=str, help="Path to input file(s)")
parser.add_argument('--groundtruth-path', type=str, help="Path to groundtruth file")
parser.add_argument('--question-path', type=str, help="Path to questions file")
args = parser.parse_args()
result_file = merge_input_files(args.input_path)
accuracy = compute_vqa_accuracy(result_file, args.question_path, args.groundtruth_path)
print(accuracy)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import torch
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.custom_layers.transformer_engine import (
    TEColumnParallelLinear,
    TEDotProductAttention,
    TELayerNormColumnParallelLinear,
    TERowParallelLinear,
)
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
class TorchLayerNormWrapper(torch.nn.LayerNorm):
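    """torch.nn.LayerNorm adapted to the (config, hidden_size, eps) constructor signature expected by the layer spec."""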
def __init__(self, config, hidden_size, eps):
super().__init__(hidden_size, eps)
def get_layer_spec(is_vit=False) -> ModuleSpec:
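    """Transformer layer spec built from local (non-Transformer-Engine) modules."""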
mlp = get_mlp_module_spec(use_te=False)
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=FusedLayerNorm if not is_vit else TorchLayerNormWrapper,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=FusedLayerNorm if not is_vit else TorchLayerNormWrapper,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
),
)
def get_layer_spec_te(is_vit=False) -> ModuleSpec:
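    """Transformer layer spec built from Transformer Engine modules."""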
attn_mask_type = AttnMaskType.no_mask if is_vit else AttnMaskType.causal
mlp = get_mlp_module_spec_te()
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": attn_mask_type},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=IdentityOp,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
),
)
def get_mlp_module_spec(use_te: bool = True) -> ModuleSpec:
# Dense MLP w/ or w/o TE modules.
return ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear,
linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
),
)
def get_mlp_module_spec_te() -> ModuleSpec:
return ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TELayerNormColumnParallelLinear,
linear_fc2=TERowParallelLinear,
),
)
{
"Captioning": {
"raw": [
"Can you briefly explain what you see in the image?",
"Describe what's happening in this image in one short sentence.",
"Write a short caption that accurately represents the content of this image.",
"Please generate a descriptive caption for the image provided.",
"How would you summarize the scene depicted in the picture in short?"
]
},
"OCR": {
"raw": [
"Can you read the text from image and output here?",
"Extract and document the text from the provided image.",
"Converting the text embedded in this image into a readable document.",
"Transcribe all the text you find.",
"Can you extract all visible text from the image here?"
]
},
"VQA": {
"raw": [
"Given the image, answer the following question with few words.",
"Answer the following question: ",
"What is the answer to this question?",
"Write the answer: ",
"Please answer this question: "
]
}
}
#!/bin/bash
# Pretrain a multimodal model.
export NCCL_IB_SL=1
export CUDA_DEVICE_MAX_CONNECTIONS=1
DATETIME=`date +'%y-%m-%d-%H-%M-%S'`
MODEL_NAME="mcore-llava-8b-${DATETIME}"
# Check that the user has set an output path for model checkpoints.
if [[ -z $WORKSPACE ]]; then
echo "Please set WORKSPACE for storing your model checkpoints."
exit 1
fi
SOURCE=`pwd`
OUTPUT_BASE="${WORKSPACE}/output"
OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}"
FINETUNE_DIR=${OUTPUT}/checkpoints
LOGS_DIR="${OUTPUT}/logs"
TENSORBOARD_DIR="${OUTPUT}/tensorboard"
if [[ -z $LOAD_NAME ]]; then
echo "Please set LOAD_NAME for input model name."
exit 1
fi
if [[ -z $TOKENIZER_MODEL ]]; then
echo "Please set TOKENIZER_MODEL for tokenizer model name."
exit 1
fi
CHECKPOINT_DIR="${WORKSPACE}/${LOAD_NAME}"
DATA_TRAIN="${SOURCE}/examples/multimodal/pretrain_dataset.yaml"
DATA_VALID="${SOURCE}/examples/multimodal/pretrain_dataset.yaml"
DEBUG=1
if [[ $DEBUG -eq 1 ]]; then
BZ=8
NW=1
HD=0.0
LI=1
EXTRA_ARGS=""
NONDETERMINISTIC_ATTN=0
else
BZ=256
NW=2
HD=0.1
LI=10
EXTRA_ARGS=""
NONDETERMINISTIC_ATTN=1
fi
OPTIONS=" \
--num-workers ${NW} \
--exit-duration-in-mins 230 \
--use-flash-attn \
--apply-layernorm-1p \
--untie-embeddings-and-output-weights \
--disable-bias-linear \
--position-embedding-type rope \
--rotary-percent 0.5 \
--squared-relu \
--attention-dropout 0.0 \
--hidden-dropout ${HD} \
--tensor-model-parallel-size 4 \
--pipeline-model-parallel-size 1 \
--num-layers 32 \
--hidden-size 4096 \
--num-attention-heads 32 \
--seq-length 1024 \
--max-position-embeddings 4096 \
--train-samples 410000 \
--micro-batch-size 1 \
--global-batch-size ${BZ} \
--lr-decay-samples 25600000 \
--lr-warmup-samples 83200 \
--lr 1e-5 \
--min-lr 2.5e-6 \
--lr-decay-style cosine \
--log-interval ${LI} \
--eval-iters 10 \
--eval-interval 1000 \
--tokenizer-type GPTSentencePieceTokenizer \
--tokenizer-model ${WORKSPACE}/${TOKENIZER_MODEL} \
--data-path ${DATA_TRAIN} \
--valid-path ${DATA_VALID} \
--prompt-path ${SOURCE}/examples/multimodal/manual_prompts.json \
--save-interval 1000 \
--save ${FINETUNE_DIR} \
--load ${CHECKPOINT_DIR} \
--split 100,0,0 \
--clip-grad 0.5 \
--weight-decay 0.1 \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--init-method-std 0.014 \
--log-params-norm \
--log-num-zeros-in-grad \
--bf16 \
--eod-mask-loss \
--finetune \
--freeze-LM \
--freeze-ViT \
--patch-dim 14 \
--img-h 336 \
--img-w 336 \
--dataloader-type external \
--tensorboard-dir ${TENSORBOARD_DIR} \
--language-model-type=8b \
--disable-vision-class-token \
${EXTRA_ARGS} \
--distributed-timeout-minutes 60 \
--allow-missing-vision-projection-checkpoint \
--use-te
"
export NVTE_APPLY_QK_LAYER_SCALING=1
export NVTE_ALLOW_NONDETERMINISTIC_ALGO=${NONDETERMINISTIC_ATTN}
# MULTI GPU
torchrun --nproc_per_node 8 examples/multimodal/train.py ${OPTIONS}
__module__: megatron.energon
__class__: Metadataset
splits:
train:
datasets:
- weight: 1.
path: /workspace/data/pretrain/train/dataset
subflavors:
augmentation: false
val:
datasets:
- weight: 1.
path: /workspace/data/pretrain/validation/dataset
subflavors:
augmentation: false
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Generate text using a vision language model."""
import glob
import json
import logging
import os
import sys
from collections import defaultdict
from functools import partial
# Add megatron to the path.
sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))
)
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, ToPILImage
from train import add_multimodal_extra_args, get_image_token_count, model_provider
from megatron.inference.text_generation.api import generate_and_post_process
from megatron.inference.text_generation.forward_step import ForwardStep
from megatron.training import get_args, get_model, print_rank_0
from megatron.training.checkpointing import load_checkpoint
from megatron.training.initialize import initialize_megatron
def add_text_generation_args(parser):
"""Text generation arguments."""
group = parser.add_argument_group(title='Vision language model text generation')
group.add_argument("--temperature", type=float, default=1.0, help='Sampling temperature.')
group.add_argument("--top_p", type=float, default=0.0, help='Top p sampling.')
group.add_argument("--top_k", type=int, default=0, help='Top k sampling.')
group.add_argument(
"--out-seq-length", type=int, default=1024, help='Length of the output generated text.'
)
group.add_argument("--output-path", type=str, required=True, help='Output file path')
group.add_argument('--input-image-path', type=str, required=True, help="Input image directory")
group.add_argument('--input-metadata-path', type=str, help="Input metadata path")
group.add_argument(
'--num-partitions', type=int, default=0, help="Number of partitions for inputs."
)
group.add_argument('--partition-id', type=int, default=0, help="Partition index")
group.add_argument("--drop-vision-class-token", action="store_true", default=False)
group.add_argument("--gt-path", type=str, help="Optional ground truth file")
group.add_argument("--task", type=str, help="Generation task to run")
# Add common multimodal arguments needed for e.g. building the model.
parser = add_multimodal_extra_args(parser)
return parser
def preprocess_image(target_h, target_w, img):
"""Example image preprocessing. Resizes input image to target size.
Args:
target_h (int): Target height in pixels.
target_w (int): Target width in pixels
img (np.array [h, w, c]): Input image in a numpy array.
Returns:
output_img (torch.Tensor [c, h, w]): Input image resized to target size.
"""
# Imagenet's mean and std for normalization.
pixel_mean = [123.675, 116.28, 103.53]
pixel_std = [58.395, 57.12, 57.375]
pixel_mean = torch.Tensor(pixel_mean).view(-1, 1, 1)
pixel_std = torch.Tensor(pixel_std).view(-1, 1, 1)
# Resize image considering ratio between input and target image sizes.
img_h, img_w = img.shape[0], img.shape[1]
ratio = float(max(target_h, target_w)) / max(img_h, img_w)
scaled_h, scaled_w = int(img_h * ratio + 0.5), int(img_w * ratio + 0.5)
image_transform = Compose(
[ToPILImage(), Resize((scaled_h, scaled_w)), lambda x: x.convert("RGB")]
)
img = image_transform(img)
# Normalize pixel values.
img = (torch.Tensor(np.array(img)).permute(2, 0, 1) - pixel_mean) / pixel_std
# Pad to target size.
delta_h, delta_w = target_h - scaled_h, target_w - scaled_w
output_img = torch.nn.functional.pad(img, (0, delta_w, 0, delta_h))
return output_img
def _get_partition_bounds(total_num_samples, num_partitions, partition_id):
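    """Return the [start, end) sample index range handled by the given partition."""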
samples_per_partition = total_num_samples // num_partitions
return samples_per_partition * partition_id, samples_per_partition * (partition_id + 1)
def generate_samples(model):
"""Text generation using a trained vision language model."""
args = get_args()
images = []
questions, answers = [], []
samples, sample_ids = [], []
if args.task in ("TextVQA", "VQAv2"):
input_metadata_path = args.input_metadata_path
if input_metadata_path.endswith(".json"):
samples = json.load(open(input_metadata_path))
elif input_metadata_path.endswith(".jsonl"):
with open(input_metadata_path, 'r') as jsonl_file:
json_list = list(jsonl_file)
samples = [json.loads(json_str) for json_str in json_list]
else:
            raise NotImplementedError("unsupported input metadata file format")
# Optionally, process only a subset of the input files.
if args.num_partitions > 0:
lb, ub = _get_partition_bounds(len(samples), args.num_partitions, args.partition_id)
samples = samples[lb:ub]
num_samples = len(samples)
for i in range(len(samples)):
sample = samples[i]
img_file = "{}/{}".format(args.input_image_path, sample["image"])
img_sample = np.array(Image.open(img_file))
processed_img = preprocess_image(args.img_h, args.img_w, img_sample)
images.append(processed_img.reshape(-1, 3, args.img_h, args.img_w))
if args.task == "VQAv2":
questions.append(sample["question"])
answers.append(sample["answer"])
elif args.task == 'TextVQA':
questions.append(sample["text"])
sample_ids.append(sample["question_id"])
if len(images) == num_samples:
break
elif args.task == "captioning":
image_files = sorted(glob.glob(args.input_image_path + "/*"))
# Optionally, process only a subset of the input files.
if args.num_partitions > 0:
lb, ub = _get_partition_bounds(len(image_files), args.num_partitions, args.partition_id)
image_files = image_files[lb:ub]
num_samples = len(image_files)
images = []
# Run image preprocessing.
for image_file in image_files:
img = np.array(Image.open(image_file))
img = preprocess_image(args.img_h, args.img_w, img)
images.append(img.reshape(-1, 3, args.img_h, args.img_w))
image_id = int(image_file.split("_")[-1].split(".")[0])
sample_ids.append(image_id)
# Load optional ground truth.
gt_sample_id_to_captions = defaultdict(list)
if args.gt_path:
gts = json.load(open(args.gt_path))
for gt in gts["annotations"]:
gt_sample_id_to_captions[gt["image_id"]].append(gt['caption'])
elif args.task == 'MMMU':
# The following downloads the MMMU dataset from HuggingFace and uses the API from the MMMU github repo to run MMMU evaluation.
import datasets
from evaluation.MMMU.eval.utils.data_utils import (
CAT_SHORT2LONG,
construct_prompt,
load_yaml,
process_single_sample,
)
all_mmmu_datasets = []
hf_datasets_cache = os.environ["HF_DATASETS_CACHE"]
assert hf_datasets_cache != "", "Please set the environment variable HF_DATASETS_CACHE."
for subject in CAT_SHORT2LONG.values():
subject_dataset = datasets.load_dataset(
"MMMU/MMMU", subject, split=datasets.Split.VALIDATION, cache_dir=hf_datasets_cache
)
all_mmmu_datasets.append(subject_dataset)
dataset = datasets.concatenate_datasets(all_mmmu_datasets)
# Optionally, process only a subset of the input files.
start_idx = 0
end_idx = len(dataset)
if args.num_partitions > 0:
start_idx, end_idx = _get_partition_bounds(
len(dataset), args.num_partitions, args.partition_id
)
# Using the LLaVA config from the MMMU repo.
config = load_yaml("evaluation/MMMU/eval/configs/llava1.5.yaml")
for k, v in config.items():
if isinstance(v, list):
assert len(v) == 1, "only one value supported."
config[k] = v[0]
for idx in range(start_idx, end_idx):
sample = dataset[idx]
sample = process_single_sample(sample)
sample = construct_prompt(sample, config)
# Skip samples with no images or multiple images. Not supported yet.
if "image" not in sample or "<image 2>" in sample['final_input_prompt']:
continue
img = np.array(sample['image'].convert("RGB"))
img = preprocess_image(args.img_h, args.img_w, img)
images.append(img.reshape(-1, 3, args.img_h, args.img_w))
sample_ids.append(sample['id'])
# TODO: Support different image positions.
prompt = sample['final_input_prompt']
prompt = prompt.replace("<image 1>", "")
questions.append(prompt.strip())
answers.append(sample['answer'])
samples.append(sample)
num_samples = len(samples)
else:
raise NotImplementedError("unsupported task")
idx = 0
while idx < num_samples:
image = images[idx].cuda()
sample_id = sample_ids[idx]
if args.task == "captioning":
prompt = "Give a short and clear explanation of the subsequent image.\n"
elif args.task == "TextVQA":
prompt = questions[idx]
elif args.task == "VQAv2":
prompt = questions[idx]
prompt += "\nAnswer the question using a single word or phrase."
elif args.task == "MMMU":
prompt = questions[idx]
forward_step = partial(VLMForwardStep, image, get_image_token_count())
if torch.distributed.get_rank() == 0:
resp_sentences, _, _, _ = generate_and_post_process(
model,
forward_step=forward_step,
prompts=[prompt],
tokens_to_generate=args.out_seq_length,
return_output_log_probs=False,
top_k_sampling=args.top_k,
top_p_sampling=args.top_p,
add_BOS=False,
temperature=args.temperature,
random_seed=123,
)
for prompt, generation in zip([prompt], resp_sentences):
output = {
"sample_id": sample_id,
"prompt": prompt,
}
output_name = ""
if args.task == "captioning":
output_name = "caption"
elif args.task == "VQAv2":
output_name = "answer"
elif args.task in ("TextVQA", "MMMU"):
output_name = "text"
generated = generation[len(prompt) :]
output[output_name] = generated
if args.task == "captioning":
output["ground_truth"] = gt_sample_id_to_captions[sample_id]
elif args.task == "VQAv2":
output["ground_truth"] = answers[idx]
elif args.task == "MMMU":
sample = samples[idx]
prediction = generated
if sample["question_type"] == "multiple-choice":
from evaluation.MMMU.eval.utils.eval_utils import (
parse_multi_choice_response,
)
prediction = parse_multi_choice_response(
generated, sample["all_choices"], sample["index2ans"]
)
output["prediction"] = prediction
print_rank_0(output)
yield output
idx += 1
else:
generate_and_post_process(model, forward_step=forward_step)
idx += 1
def generate_and_write_samples(model):
args = get_args()
for output in generate_samples(model):
if torch.distributed.get_rank() == 0:
with open(args.output_path, 'a') as f:
f.write(json.dumps(output) + "\n")
class VLMForwardStep(ForwardStep):
def __init__(self, images, num_image_tokens, model, max_batch_size, max_sequence_length):
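        # Keep the image batch and extend the maximum sequence length by the number of image tokens.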
super().__init__(model, max_batch_size, max_sequence_length + num_image_tokens)
self._images = images
def _forward(self, tokens, position_ids, attention_mask):
return self.model(
self._images,
tokens,
position_ids,
attention_mask=None,
inference_params=self.inference_params,
)
def __call__(self, tokens, position_ids, attention_mask):
logits = super().__call__(tokens, position_ids, attention_mask)
# On the first inference iteration, we compute image tokens.
# Update the sequence length offset by the number of image tokens.
num_tokens = tokens.size(1)
if num_tokens > 1:
self.inference_params.sequence_len_offset += self.inference_params.key_value_memory_dict[
"image_tokens_count"
]
return logits
def main():
"""Vision language model text generation."""
logging.getLogger(__name__).warning("Models using pipeline parallelism are not supported yet.")
initialize_megatron(extra_args_provider=add_text_generation_args)
def wrapped_model_provider(pre_process, post_process):
return model_provider(pre_process, post_process, parallel_output=False)
# Set up model and load checkpoint.
model = get_model(wrapped_model_provider, wrap_with_ddp=False)
args = get_args()
if args.load is not None:
_ = load_checkpoint(model, None, None)
model = model[0]
model.eval()
generate_and_write_samples(model)
if __name__ == "__main__":
main()
#!/bin/bash
# Run SFT on a pretrained multimodal model.
export NCCL_IB_SL=1
export CUDA_DEVICE_MAX_CONNECTIONS=1
DATETIME=`date +'%y-%m-%d-%H-%M-%S'`
MODEL_NAME="mcore-llava-sft-${DATETIME}"
# Check that the user has set an output path for model checkpoints.
if [[ -z $WORKSPACE ]]; then
echo "Please set WORKSPACE for storing your model checkpoints."
exit 1
fi
SOURCE=`pwd`
OUTPUT_BASE="${WORKSPACE}/output"
OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}"
FINETUNE_DIR=${OUTPUT}/checkpoints
LOGS_DIR="${OUTPUT}/logs"
TENSORBOARD_DIR="${OUTPUT}/tensorboard"
if [[ -z $LOAD_NAME ]]; then
echo "Please set LOAD_NAME for input model name."
exit 1
fi
if [[ -z $TOKENIZER_MODEL ]]; then
echo "Please set TOKENIZER_MODEL for tokenizer model name."
exit 1
fi
CHECKPOINT_DIR="${WORKSPACE}/output/${LOAD_NAME}/checkpoints"
DATA_TRAIN="${SOURCE}/examples/multimodal/sft_dataset.yaml"
DATA_VALID="${SOURCE}/examples/multimodal/sft_dataset.yaml"
DEBUG=0
if [[ $DEBUG -eq 1 ]]; then
BZ=8
NW=1
LI=1
HD=0.0
EXTRA_ARGS=""
else
BZ=128
NW=1
LI=10
HD=0.1
EXTRA_ARGS=""
fi
OPTIONS=" \
--num-workers ${NW} \
--use-flash-attn \
--apply-layernorm-1p \
--untie-embeddings-and-output-weights \
--disable-bias-linear \
--position-embedding-type rope \
--rotary-percent 0.5 \
--squared-relu \
--attention-dropout 0.0 \
--hidden-dropout ${HD} \
--tensor-model-parallel-size 4 \
--pipeline-model-parallel-size 1 \
--num-layers 32 \
--hidden-size 4096 \
--num-attention-heads 32 \
--seq-length 1024 \
--max-position-embeddings 4096 \
--train-samples 665000 \
--micro-batch-size 1 \
--global-batch-size ${BZ} \
--lr-decay-samples 25600000 \
--lr-warmup-samples 83200 \
--lr 1e-6 \
--min-lr 1e-7 \
--lr-decay-style cosine \
--log-interval ${LI} \
--eval-iters 10 \
--eval-interval 1000 \
--tokenizer-type GPTSentencePieceTokenizer \
--tokenizer-model ${WORKSPACE}/${TOKENIZER_MODEL} \
--data-path ${DATA_TRAIN} \
--valid-path ${DATA_VALID} \
--prompt-path ${SOURCE}/examples/multimodal/manual_prompts.json \
--save-interval 1000 \
--exit-duration-in-mins 230 \
--save ${FINETUNE_DIR} \
--load ${CHECKPOINT_DIR} \
--split 100,0,0 \
--clip-grad 0.5 \
--weight-decay 0.1 \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--init-method-std 0.014 \
--log-params-norm \
--log-num-zeros-in-grad \
--bf16 \
--eod-mask-loss \
--finetune \
--freeze-ViT \
--patch-dim 14 \
--img-h 336 \
--img-w 336 \
--dataloader-type external \
--tensorboard-dir ${TENSORBOARD_DIR} \
--language-model-type=8b \
--disable-vision-class-token \
${EXTRA_ARGS} \
--distributed-timeout-minutes 60 \
"
export NVTE_APPLY_QK_LAYER_SCALING=1
# MULTI GPU
torchrun --nproc_per_node 8 examples/multimodal/train.py ${OPTIONS}
__module__: megatron.energon
__class__: Metadataset
splits:
train:
datasets:
- weight: 1.
path: /workspace/data/sft/train/dataset
subflavors:
augmentation: false
val:
datasets:
- weight: 1.
path: /workspace/data/sft/validation/dataset
subflavors:
augmentation: false
#!/bin/bash
export NCCL_IB_SL=1
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NVTE_APPLY_QK_LAYER_SCALING=1
INPUT_METADATA_PATH="placeholder"
GROUNDTRUTH_PATH="placeholder"
while [[ $# -gt 0 ]]; do
case $1 in
--input-image-path)
INPUT_IMAGE_PATH="$2"
shift
shift
;;
--input-metadata-path)
INPUT_METADATA_PATH="$2"
shift
shift
;;
-g|--groundtruth-path)
GROUNDTRUTH_PATH="$2"
shift
shift
;;
-o|--output-path)
OUTPUT_PATH="$2"
shift
shift
;;
-m|--model-path)
MODEL_PATH="$2"
shift
shift
;;
-t|--tokenizer-path)
TOKENIZER_PATH="$2"
shift
shift
;;
--task)
TASK="$2"
shift
shift
;;
        --gt-path)
GROUNDTRUTH_PATH="$2"
shift
shift
;;
-*|--*)
echo "Invalid option $1"
exit 1
;;
esac
done
# Please modify these as needed.
NUM_PARTITIONS=100
START=0
END=2
for PARTITION_ID in $( eval echo {$START..$END} )
do
torchrun --nproc_per_node 4 examples/multimodal/run_text_generation.py \
--use-flash-attn \
--language-model-type 8b \
--apply-layernorm-1p \
--untie-embeddings-and-output-weights \
--disable-bias-linear \
--position-embedding-type rope \
--rotary-percent 0.5 \
--squared-relu \
--attention-dropout 0.0 \
--hidden-dropout 0.0 \
--tensor-model-parallel-size 4 \
--pipeline-model-parallel-size 1 \
--num-layers 32 \
--hidden-size 4096 \
--num-attention-heads 32 \
--max-position-embeddings 4096 \
--no-masked-softmax-fusion \
--load ${MODEL_PATH} \
--tokenizer-type GPTSentencePieceTokenizer \
--tokenizer-model ${TOKENIZER_PATH} \
--bf16 \
--micro-batch-size 1 \
--seq-length 99 \
--out-seq-length 700 \
--temperature 1.0 \
--img-h 336 \
--img-w 336 \
--patch-dim 14 \
--seed 153 \
--top_k 1 \
--disable-vision-class-token \
--no-load-rng \
--no-load-optim \
        --input-image-path ${INPUT_IMAGE_PATH} \
        --input-metadata-path ${INPUT_METADATA_PATH} \
        --task ${TASK} \
--num-partitions ${NUM_PARTITIONS} \
--partition-id ${PARTITION_ID} \
        --output-path ${OUTPUT_PATH}-${TASK}-${PARTITION_ID}.jsonl \
--gt-path ${GROUNDTRUTH_PATH}
done
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Pretrain or SFT multimodal."""
from copy import deepcopy
from functools import partial
import os
import sys
import torch
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir, os.path.pardir)))
from megatron.training import get_args, get_timers, get_tokenizer, print_rank_0
from megatron.training.arguments import core_transformer_config_from_args
from megatron.core import mpu, tensor_parallel
from megatron.core.enums import ModelType
from config import get_language_model_config, get_vision_model_config, get_vision_projection_config
from megatron.core.models.multimodal.llava_model import LLaVAModel
from layer_specs import get_layer_spec, get_mlp_module_spec, get_layer_spec_te
from megatron.training import pretrain
from megatron.training.utils import average_losses_across_data_parallel_group
from dataloader_provider import train_valid_test_dataloaders_provider
def model_provider(pre_process=True, post_process=True, parallel_output=True) -> LLaVAModel:
"""Builds the model.
Args:
pre_process (bool): Enable preprocessing in the model. NOTE: Not used at the moment.
post_process (bool): Enable postprocessing in the model. NOTE: Not used at the moment.
parallel_output (bool): Enable parallel model output.
Returns:
model: A multimodal model.
"""
args = get_args()
use_te = args.use_te
print_rank_0('building a multimodal model ...')
base_config = core_transformer_config_from_args(get_args())
base_config.language_model_type = args.language_model_type
language_config = deepcopy(base_config)
language_config = get_language_model_config(language_config)
if use_te:
language_transformer_layer_spec = get_layer_spec_te(is_vit=False)
else:
language_transformer_layer_spec = get_layer_spec(is_vit=False)
vision_config = deepcopy(base_config)
vision_config = get_vision_model_config(vision_config, apply_query_key_layer_scaling=use_te)
if use_te:
vision_transformer_layer_spec = get_layer_spec_te(is_vit=True)
else:
vision_transformer_layer_spec = get_layer_spec(is_vit=True)
vision_projection_config = deepcopy(base_config)
vision_projection_config = get_vision_projection_config(vision_projection_config, language_config.hidden_size)
vision_projection_layer_spec = get_mlp_module_spec(use_te=use_te).submodules
model = LLaVAModel(
language_transformer_config=language_config,
language_transformer_layer_spec=language_transformer_layer_spec,
language_vocab_size=args.padded_vocab_size,
language_max_sequence_length=args.max_position_embeddings,
vision_transformer_config=vision_config,
vision_transformer_layer_spec=vision_transformer_layer_spec,
drop_vision_class_token=args.disable_vision_class_token,
vision_projection_config=vision_projection_config,
vision_projection_layer_spec=vision_projection_layer_spec,
vision_projection_type="mlp",
allow_missing_vision_projection_checkpoint=args.allow_missing_vision_projection_checkpoint,
parallel_output=parallel_output,
language_position_embedding_type=args.position_embedding_type,
language_rotary_percent=args.rotary_percent,
)
model.freeze(freeze_language_model=args.freeze_LM, freeze_vision_model=args.freeze_ViT, freeze_vision_projection=False)
return model
def get_batch(data_iterator):
"""Generate a batch"""
args = get_args()
tokens = None
labels = None
loss_mask = None
attention_mask = None
position_ids = None
# Broadcast data.
torch.cuda.nvtx.range_push("get_data")
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
data_text = tensor_parallel.broadcast_data(["text"], data, torch.int64)["text"]
data_img = tensor_parallel.broadcast_data(["img"], data, torch.float32)
prompt_len = tensor_parallel.broadcast_data(["prompt_len"], data, torch.int64)["prompt_len"]
torch.cuda.nvtx.range_pop()
tokens_ = data_text.long()
img_raw = data_img['img'].reshape(-1, 3, args.img_h, args.img_w)
torch.cuda.nvtx.range_push("index tokens")
tokenizer = get_tokenizer()
tokens = tokens_[:, :args.seq_length].contiguous()
labels = tokens_[:, 1:args.seq_length+1].contiguous()
torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_push("get_ltor_masks_and_position_ids")
attention_mask, loss_mask, position_ids = \
get_ltor_masks_and_position_ids(tokens, tokenizer.eod,
args.reset_position_ids,
args.reset_attention_mask,
args.eod_mask_loss,
question_length=prompt_len)
torch.cuda.nvtx.range_pop()
loss_mask, labels, attention_mask = _preprocess_data_for_llava(loss_mask, labels, attention_mask)
tokens = tokens[:, 1:] # drop image index token
return tokens, labels, loss_mask, attention_mask, position_ids, img_raw
def get_image_token_count():
args = get_args()
add_class_token = not args.disable_vision_class_token
num_patches_per_dim_h = args.img_h // args.patch_dim
num_patches_per_dim_w = args.img_w // args.patch_dim
num_patches = num_patches_per_dim_h * num_patches_per_dim_w
num_image_tokens = num_patches + (1 if add_class_token else 0)
return num_image_tokens
def _preprocess_data_for_llava(loss_mask, labels, attention_mask):
"""Preprocess data sample to the format expected by a LLaVA model."""
num_image_tokens = get_image_token_count()
batch_size = loss_mask.shape[0]
loss_mask2 = torch.cat(
[torch.zeros(batch_size, num_image_tokens - 1, dtype=torch.float32, device=loss_mask.device), loss_mask], dim=1
)
labels2 = torch.cat([torch.zeros(batch_size, num_image_tokens - 1, dtype=torch.int64, device=labels.device), labels], dim=1)
full_seq_length = len(labels2[0])
attention_mask2 = torch.tril(torch.ones((1, 1, full_seq_length, full_seq_length), device=attention_mask.device))
attention_mask2 = attention_mask2 < 0.5
return loss_mask2, labels2, attention_mask2
def get_ltor_masks_and_position_ids(data,
eod_token,
reset_position_ids,
reset_attention_mask,
eod_mask_loss,
question_length=None,
weights=None):
"""Build masks and position id for left to right model."""
# Extract batch size and sequence length.
micro_batch_size, seq_length = data.size()
# Attention mask (lower triangular).
if reset_attention_mask:
att_mask_batch = micro_batch_size
else:
att_mask_batch = 1
attention_mask = torch.tril(torch.ones(
(att_mask_batch, seq_length, seq_length), device=data.device)).view(
att_mask_batch, 1, seq_length, seq_length)
# Loss mask.
loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
if eod_mask_loss:
loss_mask[data == eod_token] = 0.0
# Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long,
device=data.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
    # We need to clone as the ids will be modified based on the batch index.
if reset_position_ids:
position_ids = position_ids.clone()
if question_length is not None:
for b in range(micro_batch_size):
loss_mask[b, :max(0, question_length[b].item() - 1)] = 0.0
if reset_position_ids or reset_attention_mask:
# Loop through the batches:
for b in range(micro_batch_size):
            # Find indices where the EOD token is.
eod_index = position_ids[b, data[b] == eod_token]
            # Detach indices from positions if we are going to modify positions.
if reset_position_ids:
eod_index = eod_index.clone()
            # Loop through EOD indices:
prev_index = 0
for j in range(eod_index.size()[0]):
i = eod_index[j]
                # Mask attention across the EOD boundary.
if reset_attention_mask:
attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
# Reset positions.
if reset_position_ids:
position_ids[b, (i + 1):] -= (i + 1 - prev_index)
prev_index = i + 1
# Convert attention mask to binary:
attention_mask = (attention_mask < 0.5)
if weights is not None:
loss_mask = loss_mask * weights
return attention_mask, loss_mask, position_ids
def loss_func(loss_mask, output_tensor):
losses = output_tensor.float()
if loss_mask is not None:
loss_mask = loss_mask.view(-1).float()
        loss = torch.sum(losses.view(-1) * loss_mask) / max(1, loss_mask.sum())
else:
loss = torch.mean(losses)
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]}
def forward_step(data_iterator, model: LLaVAModel):
"""Forward training step.
Args:
data_iterator (torch.utils.data.dataloader): Input data iterator
model: Multimodal model
Returns:
output_tensor (torch.Tensor): Loss of shape [b, s] if labels are provided, otherwise logits of shape [b, s, vocab_size].
loss_func (callable): Loss function with a loss mask specified.
"""
timers = get_timers()
# Get the batch.
timers('batch-generator', log_level=2).start()
tokens, labels, loss_mask, attention_mask, position_ids, images = get_batch(data_iterator)
timers('batch-generator').stop()
output_tensor = model(images, tokens, position_ids, attention_mask, labels=labels)
return output_tensor, partial(loss_func, loss_mask)
def add_multimodal_extra_args(parser):
"""Extra arguments."""
group = parser.add_argument_group(title='multimodal arguments')
    group.add_argument('--valid-path', nargs='*', default=None,
                       help='Path to the validation dataset. Accepted format: '
                            '1) a single data path, 2) multiple datasets in the '
                            'form: dataset1-weight dataset1-path dataset2-weight '
                            'dataset2-path ...')
group.add_argument('--dataset-config', type=str, default=None)
group.add_argument("--prompt-path", type=str, default=None)
group.add_argument('--freeze-LM', action='store_true', default=False)
group.add_argument('--freeze-ViT', action='store_true', default=False)
group.add_argument('--language-model-type', type=str, required=True)
group.add_argument("--disable-vision-class-token", action="store_true", default=False)
group.add_argument("--allow-missing-vision-projection-checkpoint", action="store_true", default=False)
group.add_argument("--use-te", action="store_true", default=False)
return parser
if __name__ == "__main__":
train_valid_test_dataloaders_provider.is_distributed = True
pretrain(
train_valid_test_dataloaders_provider,
model_provider,
ModelType.encoder_or_decoder,
forward_step,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
extra_args_provider=add_multimodal_extra_args,
)
# RETRO MODEL
## Table of contents
- [1. Training Setup](#1-training-setup)
- [2. Data Preprocessing](#2-data-preprocessing)
- [3. Configurations](#3-configurations)
## 1. Training Setup
<a id="markdown-training-setup" name="training-setup"></a>
To run the model using a Docker container, run it as follows:
```
PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:23.09-py3
CHECKPOINT_PATH="" #<Specify path>
TENSORBOARD_LOGS_PATH="" #<Specify path>
docker run \
  --gpus=all \
  --ipc=host \
  --workdir /workspace/megatron-lm \
  -v /path/to/data:/path/to/data \
  -v /path/to/megatron-lm:/workspace/megatron-lm \
  $PYTORCH_IMAGE \
  bash examples/retro/train_retro_2b_distributed.sh $CHECKPOINT_PATH $TENSORBOARD_LOGS_PATH
```
NOTE: Depending on the environment you are running in, the above command may look slightly different.
NOTE: Because Retro preprocesses and caches elements of the pretraining dataset before training begins, some arguments are auto-loaded from the Retro preprocessing configuration and do not need to be passed again at pretraining time (see the sketch after the list below). These loaded arguments include:
- `--data-path`
- `--data-cache-path`
- `--eval-interval`
- `--eval-iters`
- `--global-batch-size`
- `--tokenizer-type`
- `--tokenizer-model`
- `--vocab-file`
- `--merge-file`
- `--seed`
- `--seq-length`
- `--train-samples`
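Because these values come from the Retro project directory, the pretraining launch itself only needs to point at that directory and supply the remaining model and optimization flags. Below is a minimal sketch, assuming a single node with 8 GPUs and placeholder paths; the full working example is `train_retro_2b_distributed.sh`, shown later in this document.
```
torchrun --nproc_per_node 8 pretrain_retro.py \
    --retro-project-dir <path/to/retro/project/directory> \
    --retro-add-retriever \
    --num-layers 32 \
    --hidden-size 2048 \
    --num-attention-heads 32 \
    --micro-batch-size 4 \
    --lr 6.0e-4 \
    --min-lr 6.0e-5 \
    --lr-decay-style cosine \
    --bf16 \
    --save <path/to/checkpoints> \
    --load <path/to/checkpoints>
    # --data-path, --seq-length, --tokenizer-type/-model, --global-batch-size, etc.
    # are intentionally omitted: they are auto-loaded from the Retro project dir.
```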
## 2. Data Preprocessing
<a id="markdown-data-preprocessing" name="data-preprocessing"></a>
Retro preprocesses and caches data prior to pretraining in order to greatly speed it up. During data preprocessing, the retrieval database is built and neighbor IDs are queried for each sample within the pretraining dataset. Please see `preprocess_data.sh` for an example script to preprocess data for Retro, and the sketch below for the typical order of its stages. The reference documentation for data preprocessing can be found [here](tools/retro/README.md).
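For reference, the preprocessing stages are typically run in order via the task argument of the preprocessing script. This is a sketch, assuming the script lives at `examples/retro/preprocess_data.sh` and the project directory inside it has been configured:
```
bash examples/retro/preprocess_data.sh db-build         # build the retrieval database
bash examples/retro/preprocess_data.sh index-train      # train the search index
bash examples/retro/preprocess_data.sh index-add        # add data to the index
bash examples/retro/preprocess_data.sh query-neighbors  # query neighbors for the pretraining dataset
```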
## 3. Configurations
<a id="markdown-configurations" name="configurations"></a>
The example in this folder shows you how to run a 2B model. Below are a few other example configurations.
### 857M
```
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--seq-length 2048 \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
```
### 4B
```
--num-layers 48 \
--hidden-size 2560 \
--num-attention-heads 32 \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
```
#!/bin/bash
set -u
unset NCCL_DEBUG
######## Megatron, Retro dirs. ########
REPO_DIR="<path/to/megatron/repo>"
RETRO_PROJECT_DIR="<path/to/retro/project/directory>"
######## Task (e.g., db, index, query). ########
# This script takes a single argument, which specifies the retro task to be
# performed. The available tasks are: db-build, index-train, index-add, and
# query-neighbors.
# ~~ Examples ~~
# RETRO_TASKS="db-build" # Build the retrieval database
# RETRO_TASKS="index-train" # Train the index
# RETRO_TASKS="index-add" # Add data to the index
# RETRO_TASKS="query-neighbors" # Perform query pretraining for neighbors
# You can also provide the task as a command-line argument when executing the
# script. Example: ./preprocess_data.sh index-add
RETRO_TASKS=$1
######## Data. ########
DATA_BLEND="<see --data-path in arguments.py>"
######## Index. ########
RETRO_INDEX_STR="OPQ32_64,IVF65536_HNSW8,PQ32"
RETRO_INDEX_NTRAIN=66625331
RETRO_INDEX_TRAIN_LOAD_FRACTION=0.97
RETRO_INDEX_ADD_LOAD_FRACTION=0.95
######## GPT. ########
RETRO_GPT_SEED=1234
RETRO_GPT_SPLIT="98,2,0"
RETRO_GPT_DATA_PATH=${DATA_BLEND}
RETRO_GPT_TRAIN_SAMPLES=200000
RETRO_GPT_EVAL_INTERVAL=2000
RETRO_GPT_EVAL_ITERS=50
RETRO_GPT_LR_DECAY_SAMPLES=175000
RETRO_GPT_LR_WARMUP_SAMPLES=10000
RETRO_GPT_SEQ_LENGTH=2048
RETRO_GPT_GLOBAL_BATCH_SIZE=256
RETRO_GPT_CHUNK_LENGTH=64
######## Query. ########
RETRO_QUERY_NUM_NEIGHBORS_QUERY=200
RETRO_QUERY_NUM_NEIGHBORS_SAVE=20
RETRO_QUERY_EF_SEARCH=32
RETRO_QUERY_NPROBE=4096
######## Args. ########
ARGS=" \
--distributed-timeout-minutes 600 \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--micro-batch-size 1 \
--global-batch-size ${RETRO_GPT_GLOBAL_BATCH_SIZE} \
--seq-length 512 \
--max-position-embeddings 512 \
--load ${RETRO_PROJECT_DIR}/checkpoints/bert \
--exit-on-missing-checkpoint \
--no-load-optim \
--data-path [null] \
--tokenizer-type BertWordPieceLowerCase \
--vocab-file ${RETRO_PROJECT_DIR}/tokenizer/bert-large-uncased-vocab.txt \
--split ${RETRO_GPT_SPLIT} \
--distributed-backend nccl \
--lr 0.0001 \
--lr-decay-style linear \
--min-lr 1.0e-5 \
--train-samples ${RETRO_GPT_TRAIN_SAMPLES} \
--lr-decay-samples ${RETRO_GPT_LR_DECAY_SAMPLES} \
--lr-warmup-samples ${RETRO_GPT_LR_WARMUP_SAMPLES} \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--eval-interval ${RETRO_GPT_EVAL_INTERVAL} \
--eval-iters ${RETRO_GPT_EVAL_ITERS} \
--bf16 \
--no-data-sharding \
--no-gradient-accumulation-fusion \
--no-async-tensor-model-parallel-allreduce \
--bert-embedder-type megatron \
--output-bert-embeddings \
\
--retro-project-dir ${RETRO_PROJECT_DIR} \
--retro-tasks ${RETRO_TASKS} \
--retro-bert-vocab-file tokenizer/bert-large-uncased-vocab.txt \
--retro-bert-tokenizer-type BertWordPieceLowerCase \
\
--retro-gpt-seed ${RETRO_GPT_SEED} \
--retro-gpt-tokenizer-type GPTSentencePieceTokenizer \
--retro-gpt-tokenizer-model /path/to/tokenizer/model \
--retro-gpt-seq-length ${RETRO_GPT_SEQ_LENGTH} \
--retro-gpt-chunk-length ${RETRO_GPT_CHUNK_LENGTH} \
--retro-gpt-global-batch-size ${RETRO_GPT_GLOBAL_BATCH_SIZE} \
--retro-gpt-eval-interval ${RETRO_GPT_EVAL_INTERVAL} \
--retro-gpt-eval-iters ${RETRO_GPT_EVAL_ITERS} \
--retro-gpt-split ${RETRO_GPT_SPLIT} \
--retro-gpt-data-path ${RETRO_GPT_DATA_PATH} \
--retro-gpt-train-samples ${RETRO_GPT_TRAIN_SAMPLES} \
\
--retro-index-str ${RETRO_INDEX_STR} \
--retro-index-ntrain ${RETRO_INDEX_NTRAIN} \
--retro-index-train-load-fraction ${RETRO_INDEX_TRAIN_LOAD_FRACTION} \
--retro-index-add-load-fraction ${RETRO_INDEX_ADD_LOAD_FRACTION} \
--no-retro-index-delete-training-embeddings \
--no-retro-index-delete-added-codes \
\
--retro-query-num-neighbors-query ${RETRO_QUERY_NUM_NEIGHBORS_QUERY} \
--retro-query-num-neighbors-save ${RETRO_QUERY_NUM_NEIGHBORS_SAVE} \
--retro-query-ef-search ${RETRO_QUERY_EF_SEARCH} \
--retro-query-nprobe ${RETRO_QUERY_NPROBE} \
"
######## Command. ########
NPROCS=8 # Number of GPUs.
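# NOTE: NODE_RANK and MASTER_ADDR are used below but not set in this script; because
# `set -u` is active, export them before launching (e.g. NODE_RANK=0 and
# MASTER_ADDR=localhost for a single-node run), or the command will fail.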
CMD="\
cd ${REPO_DIR} && pwd && \
export PYTHONPATH=$PYTHONPATH:${REPO_DIR} && \
python -m torch.distributed.run \
--nproc_per_node ${NPROCS} \
--nnodes 1 \
--node_rank ${NODE_RANK} \
--master_addr ${MASTER_ADDR} \
--master_port 6000 \
tools/retro/preprocess_data.py ${ARGS} \
"
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
echo "CMD = '$CMD'."
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
eval $CMD
#!/bin/bash
# Runs the "2B" parameter Retro model.
export CUDA_DEVICE_MAX_CONNECTIONS=1
GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6000
NUM_NODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NUM_NODES))
CHECKPOINT_PATH=$1 #<Specify path>
TENSORBOARD_LOGS_PATH=$2 #<Specify path>
DISTRIBUTED_ARGS=(
--nproc_per_node $GPUS_PER_NODE
--nnodes $NUM_NODES
--master_addr $MASTER_ADDR
--master_port $MASTER_PORT
)
######## GPT or Retro? ########
# 0 : GPT.
# 1 : Retro.
ADD_RETRIEVER=1
######## Megatron, Retro dirs. ########
RETRO_PROJECT_DIR="<path/to/retro/project/directory>"
######## Model, training args. ########
# ** Note: --seq-length auto loaded from Retro project dir.
RETRO_MODEL_ARGS=(
--num-layers 32
--hidden-size 2048
--num-attention-heads 32
)
# ** Note: --data-path, --tokenizer-type, and --tokenizer-model auto loaded from Retro project dir.
DATA_ARGS=(
--split 98,2,0
)
MODEL_PARALLEL_ARGS=(
--tensor-model-parallel-size 8
--pipeline-model-parallel-size 1
)
# ** Note: --eval-interval, --eval-iters auto loaded from Retro project dir.
EVAL_AND_LOGGING_ARGS=(
--log-interval 100
--save-interval 10000
--eval-interval 1000
--save $CHECKPOINT_PATH
--load $CHECKPOINT_PATH
--eval-iters 10
--tensorboard-dir $TENSORBOARD_LOGS_PATH
)
TRAINING_ARGS=" \
--retro-project-dir ${RETRO_PROJECT_DIR} \
--transformer-impl transformer_engine \
--num-workers 8 \
--micro-batch-size 4 \
--lr-decay-samples 166400000 \
--lr-warmup-samples 162761 \
--lr 6.0e-4 \
--min-lr 6.0e-5 \
--lr-decay-style cosine \
--clip-grad 1.0 \
--weight-decay 0.1 \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--init-method-std 0.023 \
--log-params-norm \
--log-num-zeros-in-grad \
--bf16 \
--no-data-sharding \
"
if [ "$ADD_RETRIEVER" = "1" ]; then
TRAINING_ARGS+=" --retro-add-retriever"
fi
######## Command. ########
torchrun ${DISTRIBUTED_ARGS[@]} pretrain_retro.py \
${RETRO_MODEL_ARGS[@]} \
${TRAINING_ARGS} \
${MODEL_PARALLEL_ARGS[@]} \
${DATA_ARGS[@]} \
${EVAL_AND_LOGGING_ARGS[@]}