# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from copy import deepcopy
from functools import partial

import torch
import yaml

import torch._dynamo
torch._dynamo.config.suppress_errors = True

from megatron.training.checkpointing import get_checkpoint_name
from megatron.training import get_args, get_timers, get_tokenizer, pretrain, print_rank_0
from megatron.training.utils import is_last_rank
from megatron.training.arguments import core_transformer_config_from_args
from megatron.core import mpu, parallel_state, tensor_parallel
from megatron.core.enums import ModelType
from megatron.core.parallel_state import get_tensor_model_parallel_rank
from megatron.core.num_microbatches_calculator import get_num_microbatches
from megatron.energon import (
    LimitDataset,
    RepeatDataset,
    WorkerConfig,
    get_loader,
    get_savable_loader,
    get_train_dataset,
    get_val_datasets,
)

from megatron_patch.model.llava_mcore.layer_specs import (
    get_layer_spec,
    get_layer_spec_te,
    get_mlp_module_spec,
)
from megatron_patch.model.llava_mcore.llava_model import IMAGE_TOKEN, IGNORE_INDEX, LLaVAModel
from megatron_patch.model.llava_mcore.vision.clip_vit_model import get_num_image_embeddings
from megatron_patch.model.llava_mcore.transformer_config import (
    get_language_model_config,
    get_vision_model_config,
    get_vision_projection_config,
)
from megatron_patch.arguments import get_patch_args

from dataset_helpers import TaskEncoder, print_error_handler


def model_provider(
    pre_process=True, post_process=True, add_encoder=True, add_decoder=True, parallel_output=True
) -> LLaVAModel:
    args = get_args()

    use_te = args.transformer_impl == "transformer_engine"

    print_rank_0('building a multimodal model ...')

    num_image_embeddings = get_num_image_embeddings(
        args.img_h,
        args.img_w,
        args.patch_dim,
        args.vision_model_type,
        args.disable_vision_class_token,
        1,
    )
    old_seq_length = args.seq_length
    args.seq_length = args.encoder_seq_length = num_image_embeddings
    if torch.distributed.get_rank() == 0 and old_seq_length != args.seq_length:
        warnings.warn(
            f"Changed seq_length and encoder_seq_length (vision model sequence length) from "
            f"{old_seq_length} to num_image_tokens ({num_image_embeddings})"
        )

    max_num_image_embeddings = (args.max_num_tiles + int(args.use_thumbnail)) * num_image_embeddings

    assert (
        args.decoder_seq_length is not None
    ), "Please provide --decoder-seq-length to set the language model sequence length"
    assert (
        args.decoder_seq_length > max_num_image_embeddings
    ), "Language model sequence length must be greater than the maximum number of image embeddings"
    if args.decoder_seq_length > args.max_position_embeddings:
        args.max_position_embeddings = args.decoder_seq_length
        warnings.warn(
            f"Expanded max_position_embeddings to {args.max_position_embeddings} "
            f"to accommodate the maximum language model sequence length"
        )
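    # A single base transformer config is derived from the command-line arguments and
    # then specialized below for the language model, the vision encoder, and the
    # vision projection MLP. calculate_per_token_loss is enabled so that loss_func()
    # further down can return an unreduced (loss sum, token count) pair rather than a
    # pre-averaged scalar.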
    base_config = core_transformer_config_from_args(get_args())
    base_config.language_model_type = args.language_model_type
    base_config.vision_model_type = args.vision_model_type
    base_config.calculate_per_token_loss = True

    language_config = deepcopy(base_config)
    language_config = get_language_model_config(language_config)

    if use_te:
        language_transformer_layer_spec = get_layer_spec_te(
            is_vit=False
        )  # TENorm detects LayerNorm/RMS automatically.
    else:
        language_transformer_layer_spec = get_layer_spec(
            is_vit=False, normalization=language_config.normalization
        )

    vision_config = deepcopy(base_config)
    vision_config = get_vision_model_config(
        vision_config, apply_query_key_layer_scaling=args.apply_query_key_layer_scaling
    )

    vision_model_type = args.vision_model_type
    if vision_model_type in ["clip", "siglip"]:
        if use_te:
            vision_transformer_layer_spec = get_layer_spec_te(
                is_vit=True
            )  # TENorm detects LayerNorm/RMS automatically.
        else:
            vision_transformer_layer_spec = get_layer_spec(
                is_vit=True, normalization=vision_config.normalization
            )
    else:
        raise RuntimeError("unsupported vision model type", vision_model_type)

    vision_projection_config = deepcopy(base_config)
    vision_projection_config = get_vision_projection_config(
        vision_projection_config, language_config.hidden_size
    )

    if args.encoder_pipeline_model_parallel_size > 0:
        assert (
            args.encoder_pipeline_model_parallel_size == 1
        ), "vision model and projection can only live on 1 pipeline stage."
        vision_config.pipeline_model_parallel_size = args.encoder_pipeline_model_parallel_size
        vision_projection_config.pipeline_model_parallel_size = (
            args.encoder_pipeline_model_parallel_size
        )
        if args.encoder_tensor_model_parallel_size > 0:
            vision_config.tensor_model_parallel_size = args.encoder_tensor_model_parallel_size
            vision_projection_config.tensor_model_parallel_size = (
                args.encoder_tensor_model_parallel_size
            )

    vision_projection_layer_spec = get_mlp_module_spec(use_te=use_te).submodules

    model = LLaVAModel(
        language_transformer_config=language_config,
        language_transformer_layer_spec=language_transformer_layer_spec,
        language_vocab_size=args.padded_vocab_size,
        language_max_sequence_length=args.decoder_seq_length,
        vision_transformer_config=vision_config,
        vision_transformer_layer_spec=vision_transformer_layer_spec,
        drop_vision_class_token=args.disable_vision_class_token,
        vision_projection_config=vision_projection_config,
        vision_projection_layer_spec=vision_projection_layer_spec,
        vision_projection_type="mlp",
        allow_missing_vision_projection_checkpoint=args.allow_missing_vision_projection_checkpoint,
        parallel_output=parallel_output,
        language_position_embedding_type=args.position_embedding_type,
        language_rotary_percent=args.rotary_percent,
        pre_process=pre_process,
        post_process=post_process,
        add_encoder=add_encoder,
        add_decoder=add_decoder,
        img_h=args.img_h,
        img_w=args.img_w,
        patch_dim=args.patch_dim,
        language_rotary_base=args.rotary_base,
        language_rope_scaling=args.use_rope_scaling,
        image_token_index=get_tokenizer().convert_tokens_to_ids(IMAGE_TOKEN),
    )

    model.freeze(
        freeze_language_model=args.freeze_LM,
        freeze_vision_model=args.freeze_ViT,
        freeze_vision_projection=False,
    )

    return model


def get_batch(data_iterator):
    """Generate a batch"""
    imgs = None
    tokens = None
    labels = None
    loss_mask = None
    attention_mask = None
    position_ids = None
    num_tiles = None

    # Broadcast data.
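    # Only tensor-parallel rank 0 reads from the data iterator; tensor_parallel.broadcast_data
    # then replicates the text, target, image and tile-count tensors to the remaining ranks
    # of the tensor-parallel group so every rank sees an identical micro-batch.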
torch.cuda.nvtx.range_push("get_data") if data_iterator is not None and get_tensor_model_parallel_rank() == 0: data = next(data_iterator) else: data = None data_text = tensor_parallel.broadcast_data(["text"], data, torch.int64)["text"] target = tensor_parallel.broadcast_data(["target"], data, torch.int64)["target"] imgs = tensor_parallel.broadcast_data(["imgs"], data, torch.float32)["imgs"] num_tiles = tensor_parallel.broadcast_data(["num_tiles"], data, torch.int)["num_tiles"] # Dummy image, no image. if imgs.shape == torch.Size([1, 1]): imgs = torch.tensor([], dtype=torch.float32, device=data_text.device) num_tiles = torch.tensor([], dtype=torch.int, device=data_text.device) torch.cuda.nvtx.range_pop() tokens_ = data_text.long() torch.cuda.nvtx.range_push("index tokens") tokenizer = get_tokenizer() text_length = tokens_.shape[1] tokens = tokens_[:, :text_length].contiguous() labels = target[:, 1 : text_length + 1].contiguous() assert tokens.shape == labels.shape, f"tokens: {tokens.shape} != labels: {labels.shape}" torch.cuda.nvtx.range_pop() torch.cuda.nvtx.range_push("get_ltor_masks_and_position_ids") attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( tokens, labels, tokenizer.pad ) torch.cuda.nvtx.range_pop() return tokens, labels, loss_mask, attention_mask, position_ids, imgs, num_tiles def get_ltor_masks_and_position_ids(input_ids, target, pad_token): """Build masks and position id for left to right model.""" seq_length = input_ids.shape[1] # Position ids. position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) position_ids = position_ids.unsqueeze(0).expand_as(input_ids) # Loss mask. loss_mask = torch.ones(target.size(), dtype=torch.float, device=input_ids.device) loss_mask[target == pad_token] = 0.0 # mask paddings loss_mask[target == IGNORE_INDEX] = 0.0 # mask prompts # Attention mask. attention_mask = None return attention_mask, loss_mask, position_ids def loss_func(loss_mask, output_tensor): losses = output_tensor.float() loss_mask = loss_mask.contiguous().view(-1).float() total_tokens = loss_mask.sum() total_loss = torch.sum(losses.view(-1) * loss_mask) loss = torch.cat([total_loss.view(1), total_tokens.view(1)]) reporting_loss = loss.clone().detach() torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group()) local_num_tokens = loss[1].clone().detach().to(torch.int) return (total_loss, local_num_tokens, {'lm loss': (reporting_loss[0], reporting_loss[1])}) def forward_step(data_iterator, model: LLaVAModel): """Forward training step. Args: data_iterator (torch.utils.data.dataloader): Input data iterator model: Multimodal model Returns: output_tensor (torch.Tensor): Loss of shape [b, s] if labels are provided, otherwise logits of shape [b, s, vocab_size]. loss_func (callable): Loss function with a loss mask specified. """ timers = get_timers() # Get the batch. timers('batch-generator', log_level=2).start() tokens, labels, loss_mask, attention_mask, position_ids, images, num_image_tiles = get_batch( data_iterator ) timers('batch-generator').stop() output_tensor, loss_mask = model( images, tokens, position_ids, attention_mask, labels, loss_mask, num_image_tiles=num_image_tiles, ) return output_tensor, partial(loss_func, loss_mask) def llava_embedding_ranks(pp_ranks): """LLava's embedding ranks consist of the decoder's first and last ranks (ie, the ViT has no embeddings). Args: pp_ranks: A list of global ranks that constitute a pipeline group. 
""" args = get_args() # encoder size is also the index to the first rank of the decoder. epp = args.encoder_pipeline_model_parallel_size last_rank = pp_ranks[-1] if len(pp_ranks) == 1 or pp_ranks[epp] == last_rank: return [last_rank] else: return [pp_ranks[epp], last_rank] def llava_position_embedding_ranks(pp_ranks): """LLava's embedding ranks consist of the singular rank of the model or the decoder's first rank. Args: pp_ranks: A list of global ranks that constitute a pipeline group. """ args = get_args() # encoder size is also the index to the first rank of the decoder. epp = args.encoder_pipeline_model_parallel_size last_rank = pp_ranks[-1] if len(pp_ranks) == 1: return [last_rank] else: return [pp_ranks[epp]] def run_online_eval(model): """Run an evaluation benchmark during training.""" args = get_args() # Online evaluation config is not defined. Do nothing. if not args.online_evaluation_config: return [] from megatron_patch.model.llava_mcore.transformer_config import EvaluationConfig from run_text_generation import generate_and_write_samples with open(args.online_evaluation_config, "r") as f: config_dict = yaml.safe_load(f) config = EvaluationConfig(**config_dict) # The inference code assumes the first rank is the leader. # Tensorboard writer is on the last rank. # We must write to a storage space that all ranks see. output_dir = os.path.join(args.save, "online_eval") os.makedirs(output_dir, exist_ok=True) config.output_path = os.path.join(output_dir, args.language_model_type) # The actual generation. generate_and_write_samples(model[0].module, config, print_output=False) # Make sure the first rank is done writing so that the last rank can run eval. torch.distributed.barrier() if not is_last_rank(): return [] # Run evaluation. if config.task == "TextVQA": from evaluate_textvqa import textvqa_eval avg_acc = textvqa_eval(config.output_path) return [{"TextVQA accuracy": avg_acc}] else: raise NotImplementedError(f"online evaluation of {config.task} not implemented yet") def write_online_eval_to_tensorboard(data, iteration, writer): """Write online evaluation data to Tensorboard.""" if not writer: return for item in data: for k, v in item.items(): writer.add_scalar(k, v, iteration) def datasets_provider(worker_config=None): """Create multimodal train, validation and test datasets.""" args = get_args() train_dataset = get_train_dataset( args.train_data_path[0], batch_size=args.micro_batch_size, task_encoder=TaskEncoder(), worker_config=worker_config, virtual_epoch_length=1000, max_samples_per_sequence=100, shuffle_buffer_size=100, handler=print_error_handler, image_decode="pil", ) val_datasets = get_val_datasets( args.valid_data_path[0], batch_size=args.micro_batch_size, # This is the total number over all workers # limit=args.eval_iters * get_num_microbatches(), task_encoder=TaskEncoder(), worker_config=worker_config, handler=print_error_handler, image_decode="pil", ) val_datasets_without_source_datasets = [ # Limit the dataset to eval_iters * num_microbatches LimitDataset( # Repeat the inner dataset in case it's too short RepeatDataset(val_ds, worker_config=worker_config), length=args.eval_iters * get_num_microbatches(), worker_config=worker_config, reset_after_epoch=True, ) for val_ds, _src_ds in val_datasets ] return train_dataset, val_datasets_without_source_datasets, None def train_valid_test_dataloaders_provider(train_val_test_num_samples): """Build multimodal train, validation and test dataloaders.""" if get_tensor_model_parallel_rank() != 0: return None, None, None args = 
    worker_debug_path = None
    worker_log_level = 0

    rank = parallel_state.get_data_parallel_rank()
    world_size = parallel_state.get_data_parallel_world_size()
    data_parallel_group = parallel_state.get_data_parallel_group()

    worker_config = WorkerConfig(
        rank=rank,
        world_size=world_size,
        num_workers=args.num_workers,
        data_parallel_group=data_parallel_group,
        worker_debug_path=worker_debug_path,
        worker_log_level=worker_log_level,
    )
    train_ds, valid_ds1, test_ds = datasets_provider(worker_config)

    train_dataloader = get_savable_loader(train_ds, worker_config=worker_config)
    if args.load is not None:
        if getattr(args, "dataloader_save", None):
            dp_rank = parallel_state.get_data_parallel_rank()
            data_save_name = get_checkpoint_name(
                args.dataloader_save,
                args.iteration,
                basename=f"train_dataloader_dprank{dp_rank:03d}.pt",
            )
            if os.path.exists(data_save_name):
                try:
                    dataset_state_dict = torch.load(data_save_name, map_location="cpu")
                    train_dataloader.restore_state_rank(
                        dataset_state_dict["dataloader_state_dict"]
                    )
                    print_rank_0(f"restored dataset state from {data_save_name}")
                except Exception as e:
                    print_rank_0("loading dataloader checkpoint failed. Skipping. " + str(e))

    valid_dataloader = [
        EnergonDataloader(get_loader(valid_ds, worker_config=worker_config))
        for valid_ds in valid_ds1
    ]
    test_dataloader = None

    return EnergonDataloader(train_dataloader), valid_dataloader, EnergonDataloader(test_dataloader)


class EnergonDataloader:
    """A wrapper to use Megatron Energon dataloader with the Megatron-LM training loop."""

    def __init__(self, dataloader):
        self._dataloader = dataloader
        self._iter = iter(cyclic_iter(dataloader))

    def __next__(self):
        return self._iter.__next__()

    def __iter__(self):
        return self._iter.__iter__()

    def save_state(self):
        return self._dataloader.save_state_rank()


def cyclic_iter(iter):
    while True:
        for x in iter:
            yield x


if __name__ == "__main__":
    train_valid_test_dataloaders_provider.is_distributed = True

    pretrain(
        train_valid_test_dataloaders_provider,
        model_provider,
        ModelType.encoder_and_decoder,
        forward_step,
        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
        extra_args_provider=get_patch_args,
        process_non_loss_data_func=write_online_eval_to_tensorboard,
        get_embedding_ranks=llava_embedding_ranks,
        get_position_embedding_ranks=llava_position_embedding_ranks,
        non_loss_data_func=run_online_eval,
    )