# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""General utilities."""

import json
import os
import sys
from datetime import datetime

import torch

try:
    from transformer_engine.pytorch.optimizers import multi_tensor_applier, multi_tensor_l2norm
except ImportError:
    try:
        from amp_C import multi_tensor_l2norm
        from apex.multi_tensor_apply import multi_tensor_applier
    except ImportError:
        import warnings
        warnings.warn(
            'Transformer Engine and Apex are not installed. '
            'Falling back to local implementations of '
            'multi_tensor_applier and multi_tensor_l2norm'
        )
        from megatron.core.utils import (
            local_multi_tensor_l2_norm as multi_tensor_l2norm,
            local_multi_tensor_applier as multi_tensor_applier,
        )

from megatron.training import (
    get_args,
    get_adlr_autoresume,
)
from megatron.core import DistributedDataParallel as DDP
from megatron.core.distributed.custom_fsdp import FullyShardedDataParallel as custom_FSDP
from megatron.core import mpu
from megatron.core.datasets.utils import get_blend_from_list
from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
from megatron.core.utils import (
    get_batch_on_this_cp_rank,
    get_data_parallel_group_if_dtensor,
    to_local_if_dtensor,
)
from megatron.legacy.model import Float16Module
from megatron.legacy.model.module import param_is_not_shared

try:
    from megatron.core.distributed import TorchFullyShardedDataParallel as torch_FSDP

    ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, torch_FSDP, custom_FSDP, Float16Module)
except ImportError:
    ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, custom_FSDP, Float16Module)


def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES):
    """Unwrap the given model (or list of models) from DDP/FSDP/Float16Module wrappers."""
    return_list = True
    if not isinstance(model, list):
        model = [model]
        return_list = False
    unwrapped_model = []
    for model_module in model:
        while isinstance(model_module, module_instances):
            model_module = model_module.module
        unwrapped_model.append(model_module)
    if not return_list:
        return unwrapped_model[0]
    return unwrapped_model

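# Illustrative usage sketch (comments only, not part of the original module).
# `ddp_gpt_model` and `virtual_pp_chunks` are hypothetical names; the while-loop
# above peels wrappers off one at a time, so nested wrapping such as
# DDP(Float16Module(model)) is also handled:
#
#     core_model = unwrap_model(ddp_gpt_model)        # single model in, single model out
#     core_chunks = unwrap_model(virtual_pp_chunks)   # list in, list out
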
def calc_params_l2_norm(model, force_create_fp32_copy=False):
    """Calculate the l2 norm of the model parameters."""
    args = get_args()
    if not isinstance(model, list):
        model = [model]
    # Separate MoE and dense params.
    params_data = []
    moe_params_data = []
    sharded_params_data = []
    data_parallel_group = None
    custom_fsdp_all_param_is_shared = False
    for model_chunk in model:
        for param in model_chunk.parameters():
            data_parallel_group = get_data_parallel_group_if_dtensor(param, data_parallel_group)
            is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
            if not is_not_tp_duplicate:
                continue
            assert is_not_tp_duplicate

            if hasattr(param, "fully_shard_param_local_shard"):
                param = param.fully_shard_param_local_shard
                assert [
                    getattr(p, "fully_shard_param_local_shard", None) is not None
                    for p in model_chunk.parameters()
                ]
                custom_fsdp_all_param_is_shared = True
                if param.numel() == 0:
                    continue

            if not getattr(param, 'allreduce', True):
                # TODO: Implement memory optimization for MoE parameters.
                assert param_is_not_shared(param)
                param = to_local_if_dtensor(param)
                moe_params_data.append(param.data.float() if args.bf16 else param.data)
            else:
                if param_is_not_shared(param):
                    param = to_local_if_dtensor(param)
                    if args.bf16:
                        if not force_create_fp32_copy and hasattr(param, 'main_param'):
                            if getattr(param, 'main_param_sharded', False):
                                if param.main_param is not None:
                                    sharded_params_data.append(param.main_param)
                            else:
                                params_data.append(param.main_param)
                        else:
                            # Fall back to making an fp32 copy of the parameter if the
                            # `.main_param` attribute is not available.
                            params_data.append(param.data.float())
                    else:
                        params_data.append(param.data)

    # Calculate norm.
    dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')
    if len(params_data) > 0:
        norm, _ = multi_tensor_applier(
            multi_tensor_l2norm,
            dummy_overflow_buf,
            [params_data],
            False  # no per-parameter norm.
        )
        norm_2 = norm * norm
    else:
        norm_2 = torch.zeros((1,), dtype=torch.float32, device='cuda')

    if data_parallel_group is not None:
        torch.distributed.all_reduce(norm_2,
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=data_parallel_group)

    # Add norm contribution from params with sharded main_params. These norms need to be
    # accumulated across the DP group since the main parameters are sharded because of
    # the distributed optimizer.
    if len(sharded_params_data) > 0:
        dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')
        sharded_norm, _ = multi_tensor_applier(
            multi_tensor_l2norm,
            dummy_overflow_buf,
            [sharded_params_data],
            False  # no per-parameter norm.
        )
        sharded_norm_2 = sharded_norm * sharded_norm
        # Sum over all DP groups.
        torch.distributed.all_reduce(
            sharded_norm_2,
            op=torch.distributed.ReduceOp.SUM,
            group=mpu.get_data_parallel_group()
        )
        norm_2 += sharded_norm_2

    if custom_fsdp_all_param_is_shared:
        torch.distributed.all_reduce(norm_2,
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=mpu.get_data_parallel_group())

    # Sum across all model-parallel GPUs (tensor + pipeline).
    torch.distributed.all_reduce(
        norm_2,
        op=torch.distributed.ReduceOp.SUM,
        group=mpu.get_model_parallel_group()
    )

    # Add norm contribution from expert layers in MoEs.
    if len(moe_params_data) > 0:
        moe_norm, _ = multi_tensor_applier(
            multi_tensor_l2norm,
            dummy_overflow_buf,
            [moe_params_data],
            False  # no per-parameter norm.
        )
        moe_norm_2 = moe_norm * moe_norm

        if custom_fsdp_all_param_is_shared:
            torch.distributed.all_reduce(moe_norm_2,
                                         op=torch.distributed.ReduceOp.SUM,
                                         group=mpu.get_expert_data_parallel_group())

        # Sum across expert tensor, model and pipeline parallel GPUs.
        torch.distributed.all_reduce(
            moe_norm_2,
            op=torch.distributed.ReduceOp.SUM,
            group=mpu.get_expert_tensor_model_pipeline_parallel_group()
        )
        norm_2 += moe_norm_2

    return norm_2.item() ** 0.5


def average_losses_across_data_parallel_group(losses):
    """Reduce a tensor of losses across all GPUs."""
    averaged_losses = torch.cat(
        [loss.clone().detach().view(1) for loss in losses])
    torch.distributed.all_reduce(averaged_losses,
                                 group=mpu.get_data_parallel_group())
    averaged_losses = averaged_losses / \
        torch.distributed.get_world_size(group=mpu.get_data_parallel_group())

    return averaged_losses

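# Worked example (illustrative comments, not part of the original module):
# `calc_params_l2_norm` accumulates *squared* local norms and only takes the
# square root at the very end. If one model-parallel rank holds parameters with
# local L2 norm 3.0 and another holds parameters with local norm 4.0, the
# all-reduce sums 9.0 + 16.0 = 25.0 and the function returns 5.0.
#
# `average_losses_across_data_parallel_group` similarly all-reduces the stacked
# losses and divides by the data-parallel world size: losses of 2.0 and 4.0 on
# two data-parallel ranks yield 3.0 on every rank.
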
def reduce_max_stat_across_model_parallel_group(stat: float) -> float:
    """
    Ranks without an optimizer will have no grad_norm or num_zeros_in_grad stats,
    and we need to ensure the logging and writer rank has those values.
    This function reduces a stat tensor across the model parallel group.
    We use an all_reduce max since the values have already been summed across
    optimizer ranks where possible.
    """
    if stat is None:
        stat = -1.0
    stat = torch.tensor([stat], dtype=torch.float32, device=torch.cuda.current_device())
    torch.distributed.all_reduce(
        stat, op=torch.distributed.ReduceOp.MAX, group=mpu.get_model_parallel_group()
    )
    if stat.item() == -1.0:
        return None
    else:
        return stat.item()


def logical_and_across_model_parallel_group(input: bool) -> bool:
    """Reduce a bool value across the model parallel group with a logical AND."""
    if input is True:
        input = 1
    else:
        input = 0
    input = torch.tensor([input], dtype=torch.int, device=torch.cuda.current_device())
    torch.distributed.all_reduce(
        input, op=torch.distributed.ReduceOp.MIN, group=mpu.get_model_parallel_group()
    )
    return bool(input.item())


def report_memory(name):
    """Simple GPU memory report."""
    mega_bytes = 1024.0 * 1024.0
    string = name + ' memory (MB)'
    string += ' | allocated: {}'.format(
        torch.cuda.memory_allocated() / mega_bytes)
    string += ' | max allocated: {}'.format(
        torch.cuda.max_memory_allocated() / mega_bytes)
    string += ' | reserved: {}'.format(
        torch.cuda.memory_reserved() / mega_bytes)
    string += ' | max reserved: {}'.format(
        torch.cuda.max_memory_reserved() / mega_bytes)
    if mpu.get_data_parallel_rank() == 0:
        print("[Rank {}] {}".format(torch.distributed.get_rank(), string),
              flush=True)


def print_params_min_max_norm(optimizer, iteration):
    """Print min, max, and norm of all parameters."""
    index = 0
    rank = torch.distributed.get_rank()
    string = 'iteration, rank, index, tensor-model-parallel, min, max, norm\n'
    optimizer_ = optimizer.optimizer
    for param_group in optimizer_.param_groups:
        for param in param_group['params']:
            index += 1
            min_ = param.data.min()
            max_ = param.data.max()
            norm = torch.linalg.norm(param.data)
            string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
                iteration, rank, index, int(param.tensor_model_parallel))
            string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
    print(string, flush=True)


def check_adlr_autoresume_termination(iteration, model, optimizer,
                                      opt_param_scheduler):
    """Check for an autoresume signal and exit if it is received."""
    from megatron.training.checkpointing import save_checkpoint

    args = get_args()
    autoresume = get_adlr_autoresume()
    # Add barrier to ensure consistency.
    torch.distributed.barrier()
    if autoresume.termination_requested():
        if args.save:
            save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
        print_rank_0(">>> autoresume termination request found!")
        if torch.distributed.get_rank() == 0:
            autoresume.request_resume()
        print_rank_0(">>> training terminated. Returning")
        sys.exit(0)


def get_ltor_masks_and_position_ids(data,
                                    eod_token,
                                    reset_position_ids,
                                    reset_attention_mask,
                                    eod_mask_loss):
    """Build masks and position ids for a left-to-right model."""

    # Extract batch size and sequence length.
    micro_batch_size, seq_length = data.size()

    # Attention mask (lower triangular).
    if reset_attention_mask:
        att_mask_batch = micro_batch_size
    else:
        att_mask_batch = 1
    attention_mask = torch.tril(torch.ones(
        (att_mask_batch, seq_length, seq_length), device=data.device)).view(
            att_mask_batch, 1, seq_length, seq_length)

    # Loss mask.
    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
    if eod_mask_loss:
        loss_mask[data == eod_token] = 0.0

    # Position ids.
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    # We need to clone as the ids will be modified based on batch index.
    if reset_position_ids:
        position_ids = position_ids.clone()

    if reset_position_ids or reset_attention_mask:
        # Loop through the batches:
        for b in range(micro_batch_size):

            # Find indices where the EOD token is.
            eod_index = position_ids[b, data[b] == eod_token]
            # Detach indices from positions if going to modify positions.
            if reset_position_ids:
                eod_index = eod_index.clone()

            # Loop through EOD indices:
            prev_index = 0
            for j in range(eod_index.size()[0]):
                i = eod_index[j]
                # Mask attention across the EOD boundary.
                if reset_attention_mask:
                    attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
                # Reset positions.
                if reset_position_ids:
                    position_ids[b, (i + 1):] -= (i + 1 - prev_index)
                    prev_index = i + 1

    # Convert attention mask to binary:
    attention_mask = (attention_mask < 0.5)

    return attention_mask, loss_mask, position_ids

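# Worked example (illustrative comments, not part of the original module):
# with data[b] = [5, 9, 7, 3, 8] and eod_token = 9,
#   * eod_mask_loss=True        -> loss_mask[b]    = [1, 0, 1, 1, 1]
#   * reset_position_ids=True   -> position_ids[b] = [0, 1, 0, 1, 2]
#   * reset_attention_mask=True -> tokens 2..4 can no longer attend to
#     tokens 0..1 (attention_mask[b, 0, 2:, :2] is zeroed out).
# After the final `attention_mask = (attention_mask < 0.5)`, the returned
# boolean mask is True at positions that should be masked out.
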
def print_rank_0(message):
    """If distributed is initialized, print only on rank 0."""
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            print(message, flush=True)
    else:
        print(message, flush=True)


def is_rank0():
    """Return True if called on rank 0, False otherwise."""
    return torch.distributed.is_initialized() and torch.distributed.get_rank() == 0


def is_last_rank():
    """Return True if called on the last rank."""
    return torch.distributed.get_rank() == (
        torch.distributed.get_world_size() - 1)


def print_rank_last(message):
    """If distributed is initialized, print only on the last rank."""
    if torch.distributed.is_initialized():
        if is_last_rank():
            print(message, flush=True)
    else:
        print(message, flush=True)


def get_device_arch_version():
    """Return the GPU arch version (8: Ampere, 9: Hopper, 10: Blackwell, ...)."""
    return torch.cuda.get_device_properties(torch.device("cuda:0")).major


def append_to_progress_log(string, barrier=True):
    """Append the given string to the progress log."""
    args = get_args()
    if args.save is None:
        return
    progress_log_filename = os.path.join(args.save, "progress.txt")
    if barrier:
        torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        with open(progress_log_filename, 'a') as f:
            job_id = os.getenv('SLURM_JOB_ID', '')
            num_gpus = args.world_size
            f.write(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\tJob ID: {job_id}\t"
                    f"# GPUs: {num_gpus}\t{string}\n")


def get_blend_and_blend_per_split(args):
    """Get blend and blend_per_split from the passed-in arguments."""
    use_data_path = args.data_path is not None or \
        args.data_args_path is not None
    use_per_split_data_path = any(
        elt is not None
        for elt in [args.train_data_path, args.valid_data_path, args.test_data_path]) or \
        args.per_split_data_args_path is not None

    blend = None
    blend_per_split = None
    if use_data_path:
        if args.data_args_path is not None:
            assert args.data_path is None
            with open(args.data_args_path, 'r') as f:
                blend = get_blend_from_list(f.read().split())
        else:
            assert args.data_path is not None
            blend = get_blend_from_list(args.data_path)
    elif use_per_split_data_path:
        if args.per_split_data_args_path is not None:
            with open(args.per_split_data_args_path, 'r') as f:
                per_split_data_args = json.load(f)
                # Each element in blend_per_split should be a list of files (and optional
                # weights), so split the string if needed.
                for split in ["train", "valid", "test"]:
                    if isinstance(per_split_data_args[split], str):
                        per_split_data_args[split] = per_split_data_args[split].split()

                blend_per_split = [
                    get_blend_from_list(per_split_data_args["train"]),
                    get_blend_from_list(per_split_data_args["valid"]),
                    get_blend_from_list(per_split_data_args["test"])
                ]
        else:
            blend_per_split = [
                get_blend_from_list(args.train_data_path),
                get_blend_from_list(args.valid_data_path),
                get_blend_from_list(args.test_data_path)
            ]
    else:
        blend, blend_per_split = None, None

    return blend, blend_per_split

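# Illustrative sketch (assumed file contents, not part of the original module):
# `args.per_split_data_args_path` points at a JSON file whose "train"/"valid"/
# "test" entries are either whitespace-separated strings or lists, e.g.
#
#     {
#         "train": "0.7 /data/corpus_a_text_document 0.3 /data/corpus_b_text_document",
#         "valid": ["/data/valid_text_document"],
#         "test":  ["/data/test_text_document"]
#     }
#
# String entries are split on whitespace before being handed to
# `get_blend_from_list`; the paths and weights shown here are made up.
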
def get_batch_on_this_tp_rank(data_iterator):
    """Fetch a batch on tensor-parallel rank 0 and broadcast it to the other TP ranks."""

    args = get_args()

    def _broadcast(item):
        if item is not None:
            torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(),
                                        group=mpu.get_tensor_model_parallel_group())

    if mpu.get_tensor_model_parallel_rank() == 0:

        if data_iterator is not None:
            data = next(data_iterator)
        else:
            data = None

        batch = {
            'tokens': data["tokens"].cuda(non_blocking=True),
            'labels': data["labels"].cuda(non_blocking=True),
            'loss_mask': data["loss_mask"].cuda(non_blocking=True),
            'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking=True),
            'position_ids': data["position_ids"].cuda(non_blocking=True)
        }

        if args.pipeline_model_parallel_size == 1:
            _broadcast(batch['tokens'])
            _broadcast(batch['labels'])
            _broadcast(batch['loss_mask'])
            _broadcast(batch['attention_mask'])
            _broadcast(batch['position_ids'])

        elif mpu.is_pipeline_first_stage():
            _broadcast(batch['tokens'])
            _broadcast(batch['attention_mask'])
            _broadcast(batch['position_ids'])

        elif mpu.is_pipeline_last_stage():
            _broadcast(batch['labels'])
            _broadcast(batch['loss_mask'])
            _broadcast(batch['attention_mask'])

    else:

        tokens = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.int64,
                             device=torch.cuda.current_device())
        labels = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.int64,
                             device=torch.cuda.current_device())
        loss_mask = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.float32,
                                device=torch.cuda.current_device())
        if args.create_attention_mask_in_dataloader:
            attention_mask = torch.empty(
                (args.micro_batch_size, 1, args.seq_length, args.seq_length), dtype=torch.bool,
                device=torch.cuda.current_device()
            )
        else:
            attention_mask = None
        position_ids = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.int64,
                                   device=torch.cuda.current_device())

        if args.pipeline_model_parallel_size == 1:
            _broadcast(tokens)
            _broadcast(labels)
            _broadcast(loss_mask)
            _broadcast(attention_mask)
            _broadcast(position_ids)

        elif mpu.is_pipeline_first_stage():
            labels = None
            loss_mask = None

            _broadcast(tokens)
            _broadcast(attention_mask)
            _broadcast(position_ids)

        elif mpu.is_pipeline_last_stage():
            tokens = None
            position_ids = None

            _broadcast(labels)
            _broadcast(loss_mask)
            _broadcast(attention_mask)

        batch = {
            'tokens': tokens,
            'labels': labels,
            'loss_mask': loss_mask,
            'attention_mask': attention_mask,
            'position_ids': position_ids
        }

    return batch


def update_use_dist_ckpt(args):
    """Use distributed checkpointing for any checkpoint format other than 'torch'."""
    args.use_dist_ckpt = args.ckpt_format != "torch"
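# Illustrative notes (comments only, not part of the original module):
# * In `get_batch_on_this_tp_rank`, only tensor-parallel rank 0 reads from
#   `data_iterator`; the other ranks allocate empty (micro_batch_size,
#   seq_length) buffers and receive the tensors via broadcast. With pipeline
#   parallelism, first stages only need tokens/position_ids and last stages
#   only need labels/loss_mask, so the remaining entries are set to None.
# * `update_use_dist_ckpt` enables distributed checkpointing for every
#   checkpoint format except the legacy "torch" format; for example, with
#   args.ckpt_format == "torch" it sets args.use_dist_ckpt to False.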