# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""General utilities."""

import sys

import torch

from megatron import get_args
from megatron import get_adlr_autoresume
from megatron import mpu
from megatron import print_rank_0
from megatron.checkpointing import save_checkpoint
from megatron.fp16 import FP16_Optimizer


def reduce_losses(losses):
    """Reduce a tensor of losses across all GPUs."""
    reduced_losses = torch.cat(
        [loss.clone().detach().view(1) for loss in losses])
    torch.distributed.all_reduce(reduced_losses)
    reduced_losses = reduced_losses / torch.distributed.get_world_size()
    return reduced_losses


def report_memory(name):
    """Simple GPU memory report."""
    mega_bytes = 1024.0 * 1024.0
    string = name + ' memory (MB)'
    string += ' | allocated: {}'.format(
        torch.cuda.memory_allocated() / mega_bytes)
    string += ' | max allocated: {}'.format(
        torch.cuda.max_memory_allocated() / mega_bytes)
    string += ' | cached: {}'.format(torch.cuda.memory_cached() / mega_bytes)
    string += ' | max cached: {}'.format(
        torch.cuda.max_memory_cached() / mega_bytes)
    print_rank_0(string)


def print_params_min_max_norm(optimizer, iteration):
    """Print min, max, and norm of all parameters."""
    index = 0
    rank = torch.distributed.get_rank()
    string = 'iteration, rank, index, model-parallel, min, max, norm\n'
    optimizer_ = optimizer
    # FP16_Optimizer wraps the real optimizer; unwrap it to reach the params.
    if isinstance(optimizer, FP16_Optimizer):
        optimizer_ = optimizer.optimizer
    for param_group in optimizer_.param_groups:
        for param in param_group['params']:
            index += 1
            min_ = param.data.min()
            max_ = param.data.max()
            norm = param.data.norm()
            string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
                iteration, rank, index, int(param.model_parallel))
            string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
    print(string, flush=True)


def check_adlr_autoresume_termination(iteration, model, optimizer,
                                      lr_scheduler):
    """Check for autoresume signal and exit if it is received."""
    args = get_args()
    autoresume = get_adlr_autoresume()
    # Add barrier to ensure consistency.
    torch.distributed.barrier()
    if autoresume.termination_requested():
        if args.save:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)
        print_rank_0(">>> autoresume termination request found!")
        if torch.distributed.get_rank() == 0:
            autoresume.request_resume()
        print_rank_0(">>> training terminated. Returning")
        sys.exit(0)
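
# Example use of the helpers above (illustrative sketch only, not part of the
# original training loop): it assumes torch.distributed has already been
# initialized, and `lm_loss` / `sop_loss` are hypothetical per-GPU scalar
# loss tensors produced by a forward pass.
#
#     reduced = reduce_losses([lm_loss, sop_loss])
#     print_rank_0('lm loss: {:.6E} | sop loss: {:.6E}'.format(
#         reduced[0].item(), reduced[1].item()))
#     report_memory('after forward/backward')
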

def get_ltor_masks_and_position_ids(data,
                                    eod_token,
                                    reset_position_ids,
                                    reset_attention_mask,
                                    eod_mask_loss):
    """Build masks and position ids for a left-to-right model."""

    # Extract batch size and sequence length.
    batch_size, seq_length = data.size()

    # Attention mask (lower triangular).
    if reset_attention_mask:
        att_mask_batch = batch_size
    else:
        att_mask_batch = 1
    attention_mask = torch.tril(torch.ones(
        (att_mask_batch, seq_length, seq_length), device=data.device)).view(
            att_mask_batch, 1, seq_length, seq_length)

    # Loss mask.
    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
    if eod_mask_loss:
        loss_mask[data == eod_token] = 0.0

    # Position ids.
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    # We need to clone as the ids will be modified based on batch index.
    if reset_position_ids:
        position_ids = position_ids.clone()

    if reset_position_ids or reset_attention_mask:
        # Loop through the batches:
        for b in range(batch_size):

            # Find indices where the EOD token is.
            eod_index = position_ids[b, data[b] == eod_token]
            # Detach indices from positions if going to modify positions.
            if reset_position_ids:
                eod_index = eod_index.clone()

            # Loop through EOD indices:
            prev_index = 0
            for j in range(eod_index.size()[0]):
                i = eod_index[j]
                # Mask out attention across the EOD boundary.
                if reset_attention_mask:
                    attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
                # Reset positions so each document restarts from zero.
                if reset_position_ids:
                    position_ids[b, (i + 1):] -= (i + 1 - prev_index)
                    prev_index = i + 1

    return attention_mask, loss_mask, position_ids


def vocab_size_with_padding(num_tokens, args):
    """Pad the vocab size so it is divisible by the model-parallel world size
    times the requested divisibility multiple."""
    after = num_tokens
    multiple = args.make_vocab_size_divisible_by * \
        mpu.get_model_parallel_world_size()
    while (after % multiple) != 0:
        after += 1
    print_rank_0('> padded vocab (size: {}) with {} dummy '
                 'tokens (new size: {})'.format(
                     num_tokens, after - num_tokens, after))
    return after
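

# Minimal, self-contained sanity check for get_ltor_masks_and_position_ids.
# This is an illustrative sketch only (not part of the Megatron training
# code): it uses a hypothetical EOD token id of 0 and a single CPU batch.
if __name__ == '__main__':
    _tokens = torch.tensor([[5, 7, 0, 3, 4, 0, 9, 2]], dtype=torch.long)
    _att_mask, _loss_mask, _pos_ids = get_ltor_masks_and_position_ids(
        _tokens, eod_token=0, reset_position_ids=True,
        reset_attention_mask=True, eod_mask_loss=True)
    # Position ids restart after each EOD token: [0, 1, 2, 0, 1, 2, 0, 1].
    print(_pos_ids)
    # Loss mask zeroes out the EOD positions: [1, 1, 0, 1, 1, 0, 1, 1].
    print(_loss_mask)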