# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Pretrain and SFT GPT."""

import torch

from functools import partial
from typing import List, Optional, Tuple

from megatron.core import parallel_state
from megatron.training import inprocess_restart
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset
from megatron.core.enums import ModelType
from megatron.core.models.gpt import GPTModel
from megatron.core.rerun_state_machine import get_rerun_state_machine
from megatron.core.utils import get_attr_wrapped_model, StragglerDetector
from megatron.core.tokenizers.text.utils.build_tokenizer import build_tokenizer
from megatron.training import get_args, get_timers, get_tokenizer, pretrain, print_rank_0
from megatron.training.utils import (
    get_batch_on_this_cp_rank,
    get_batch_on_this_tp_rank,
    get_blend_and_blend_per_split,
    is_first_or_last_pipeline_stage,
)
from megatron.training.datasets.sft_dataset import SFTDataset
from model_provider import model_provider
from gpt_builders import gpt_builder

# ModelOpt is an optional dependency; remember whether it could be imported.
try:
    from megatron.post_training.arguments import add_modelopt_args
    from megatron.post_training.loss_func import loss_func as loss_func_modelopt

    has_nvidia_modelopt = True
except ImportError:
    has_nvidia_modelopt = False

# NOTE(review): these imports are deliberately left after the Megatron ones and
# in their original order; `megatron_adaptor` is not referenced by name below,
# so it is presumably imported for its side effects -- confirm before moving.
from dcu_megatron.core.parallel_state import get_virtual_vocab_parallel_chunk
from input_store import InputStore
from dcu_megatron import megatron_adaptor

# Module-wide straggler detector used as a context manager in forward_step.
stimer = StragglerDetector()


def get_batch(data_iterator, vp_stage=None, microbatch_id=None):
    """Fetch the micro-batch that this rank should consume.

    Args:
        data_iterator: Input data iterator, handed to ``get_batch_on_this_tp_rank``.
        vp_stage: Virtual pipeline stage, forwarded to
            ``is_first_or_last_pipeline_stage``.
        microbatch_id: Key under which the batch is saved to, or loaded from,
            ``InputStore`` when vocab parallelism is enabled.

    Returns:
        Five ``None`` values when this rank does not need data. Otherwise either
        the result of ``InputStore.get_batch(microbatch_id)`` or the ``values()``
        of the freshly loaded batch; ``forward_step`` unpacks the result as
        ``(tokens, labels, loss_mask, attention_mask, position_ids)``.
    """
    args = get_args()
    vocab_parallel = args.enable_vocab_parallel
    no_batch = (None, None, None, None, None)

    # TODO: this is pretty hacky, find a better way
    # Without vocab parallelism only the first/last pipeline stages read data.
    if not (vocab_parallel or is_first_or_last_pipeline_stage(vp_stage)):
        return no_batch

    # Under the "dualpipev" schedule the last pipeline stage gets no batch here.
    if args.schedule_method == "dualpipev" and parallel_state.is_pipeline_last_stage(
        ignore_virtual=True
    ):
        return no_batch

    # With vocab parallelism, every chunk other than chunk 2 reuses the batch
    # previously stored for this micro-batch instead of reading new data.
    if vocab_parallel and get_virtual_vocab_parallel_chunk() != 2:
        return InputStore.get_batch(microbatch_id)

    # Load the batch for this TP rank, then slice it along the sequence
    # dimension for context parallelism.
    batch = get_batch_on_this_cp_rank(get_batch_on_this_tp_rank(data_iterator))

    # Chunk 2 is the one that stores the batch for the other chunks to pick up.
    if vocab_parallel and get_virtual_vocab_parallel_chunk() == 2:
        InputStore.save_batch(microbatch_id, batch.values())

    return batch.values()


# A loss counts as "spiky" when it is 10x the max loss observed.
SPIKY_LOSS_FACTOR = 10


def loss_func(
    loss_mask: torch.Tensor, output_tensor: torch.Tensor, model: Optional[GPTModel] = None
):
    """Compute the masked loss for one micro-batch.

    Args:
        loss_mask (torch.Tensor): Mask selecting which loss entries contribute.
        output_tensor (torch.Tensor): Tensor holding the losses.
        model (GPTModel, optional): The (possibly wrapped) model; only forwarded
            to the ModelOpt loss function.

    Returns:
        A 3-tuple of:
            the summed masked loss for this micro-batch,
            the mask total cast to ``torch.int`` (the token count),
            a dict ``{'lm loss': tensor}`` whose tensor concatenates the detached
            loss and the token count for reporting.
    """
    args = get_args()

    # [ModelOpt] Hand off entirely when ModelOpt is installed and enabled.
    if has_nvidia_modelopt and getattr(args, 'modelopt_enabled', False):
        return loss_func_modelopt(loss_mask, output_tensor, model=model)

    flat_losses = output_tensor.view(-1).float()
    flat_mask = loss_mask.view(-1).float()
    loss = torch.sum(flat_losses * flat_mask)

    # Validate this rank's loss before any DP all-reduce.
    # NOTE(review): the validate_result calls are intentionally kept as separate,
    # direct call sites; the rerun state machine may identify validations by
    # their caller -- confirm before folding them into a loop or helper.
    rerun_state_machine = get_rerun_state_machine()
    if args.check_for_nan_in_loss_and_grad:
        rerun_state_machine.validate_result(
            result=loss,
            rejection_func=torch.isnan,
            message="found NaN in local forward loss calculation",
            tolerance=0.0,  # forward pass calculations are deterministic
            fatal=True,
        )
        rerun_state_machine.validate_result(
            result=loss,
            rejection_func=torch.isinf,
            message="found Inf in local forward loss calculation",
            tolerance=0.0,  # forward pass calculations are deterministic
            fatal=True,
        )

    # Optionally flag (non-fatally) a loss that is unexpectedly large.
    if args.check_for_spiky_loss:
        spiky_check = partial(
            rerun_state_machine.is_unexpectedly_large,
            threshold=SPIKY_LOSS_FACTOR,
            context="loss",
        )
        rerun_state_machine.validate_result(
            result=loss,
            rejection_func=spiky_check,
            message="Spiky loss",
            tolerance=0.0,  # forward pass calculations are deterministic
            fatal=False,
        )

    num_tokens = flat_mask.sum().clone().detach().to(torch.int)
    reporting_loss = torch.cat([loss.clone().detach().view(1), num_tokens.view(1)])
    return loss, num_tokens, {'lm loss': reporting_loss}


def forward_step(
    data_iterator, model: GPTModel, return_schedule_plan: bool = False, microbatch_id=None
):
    """Run one forward training step.

    Args:
        data_iterator: Input data iterator.
        model (GPTModel): The GPT model.
        return_schedule_plan (bool): When true (and legacy models are not in
            use), return the model's schedule plan instead of its output tensor.
        microbatch_id: Forwarded to ``get_batch``.

    Returns:
        A pair ``(output_tensor_or_schedule_plan, loss_callable)`` where the
        callable is ``loss_func`` with ``loss_mask`` and ``model`` pre-bound.
    """
    args = get_args()
    timers = get_timers()

    # Fetch the batch, timed by both the Megatron timer and the straggler detector.
    timers('batch-generator', log_level=2).start()
    with stimer(bdata=True):
        vp_stage = get_attr_wrapped_model(model, "vp_stage")
        batch = get_batch(data_iterator, vp_stage, microbatch_id=microbatch_id)
        tokens, labels, loss_mask, attention_mask, position_ids = batch
    timers('batch-generator').stop()

    # [ModelOpt]: model is needed to access ModelOpt distillation losses
    bound_loss_func = partial(loss_func, loss_mask, model=model)

    with stimer:
        if args.use_legacy_models:
            # Legacy models are called without the loss mask.
            output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
        elif return_schedule_plan:
            assert args.overlap_moe_expert_parallel_comm, \
                "overlap_moe_expert_parallel_comm must be enabled to return the schedule plan"
            schedule_plan = model.build_schedule_plan(
                tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
            )
            return schedule_plan, bound_loss_func
        else:
            output_tensor = model(
                tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
            )

    return output_tensor, bound_loss_func


def is_dataset_built_on_rank(vp_stage=None):
    """Tell whether this rank should build the datasets.

    Only tensor-model-parallel rank 0 qualifies. Unless vocab parallelism is
    enabled, the rank must additionally be on a first or last pipeline stage.

    Args:
        vp_stage: Virtual pipeline stage, forwarded to
            ``is_first_or_last_pipeline_stage``.
    """
    stage_qualifies = (
        get_args().enable_vocab_parallel or is_first_or_last_pipeline_stage(vp_stage)
    )
    return stage_qualifies and parallel_state.get_tensor_model_parallel_rank() == 0


def core_gpt_dataset_config_from_args(args):
    """Translate the parsed training arguments into a ``GPTDatasetConfig``.

    Args:
        args: Parsed arguments; the tokenizer choice, blend settings and the
            dataset options below are all read from it.

    Returns:
        GPTDatasetConfig: The dataset configuration.
    """
    tokenizer = get_tokenizer() if args.legacy_tokenizer else build_tokenizer(args)

    # Sometimes --data-path is too long, so the blend may be parsed from a file.
    blend: Optional[Tuple[List[str], Optional[List[float]]]]
    blend_per_split: Optional[List[Optional[Tuple[List[str], Optional[List[float]]]]]]
    blend, blend_per_split = get_blend_and_blend_per_split(args)

    return GPTDatasetConfig(
        random_seed=args.seed,
        sequence_length=args.seq_length,
        blend=blend,
        blend_per_split=blend_per_split,
        split=args.split,
        multiple_validation_sets=args.multiple_validation_sets,
        full_validation=args.full_validation,
        num_dataset_builder_threads=args.num_dataset_builder_threads,
        path_to_cache=args.data_cache_path,
        mmap_bin_files=args.mmap_bin_files,
        tokenizer=tokenizer,
        reset_position_ids=args.reset_position_ids,
        reset_attention_mask=args.reset_attention_mask,
        eod_mask_loss=args.eod_mask_loss,
        create_attention_mask=args.create_attention_mask_in_dataloader,
        object_storage_cache_path=args.object_storage_cache_path,
        mid_level_dataset_surplus=args.mid_level_dataset_surplus,
    )


def train_valid_test_datasets_provider(train_val_test_num_samples, vp_stage=None):
    """Build the train, validation and test datasets.

    Args:
        train_val_test_num_samples: A list containing the number of samples for
            the train, validation and test datasets.
        vp_stage: Virtual pipeline stage, bound into ``is_dataset_built_on_rank``.

    Returns:
        The ``(train_ds, valid_ds, test_ds)`` triple produced by the builder.
    """
    args = get_args()
    config = core_gpt_dataset_config_from_args(args)

    # SFT takes precedence; mock data is only considered when SFT is off.
    if args.sft:
        dataset_type = SFTDataset
    elif args.mock_data:
        dataset_type = MockGPTDataset
    else:
        dataset_type = GPTDataset

    print_rank_0("> building train, validation, and test datasets for GPT ...")

    builder = BlendedMegatronDatasetBuilder(
        dataset_type,
        train_val_test_num_samples,
        partial(is_dataset_built_on_rank, vp_stage=vp_stage),
        config,
    )
    train_ds, valid_ds, test_ds = builder.build()

    print_rank_0("> finished creating GPT datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":

    # Temporary for transition to core datasets
    train_valid_test_datasets_provider.is_distributed = True

    # Optionally wrap pretrain for in-process restart; also yields the store
    # that is passed back into pretrain below.
    pretrain, store = inprocess_restart.maybe_wrap_for_inprocess_restart(pretrain)

    pretrain(
        train_valid_test_datasets_provider,
        partial(model_provider, gpt_builder),
        ModelType.encoder_or_decoder,
        forward_step,
        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
        extra_args_provider=add_modelopt_args if has_nvidia_modelopt else None,
        store=store,
    )