# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Generation utilities."""

import torch
import torch.nn.functional as F

from megatron.training import get_args, get_tokenizer
from megatron.training import get_retro_args
from megatron.core import mpu
from megatron.training.utils import get_ltor_masks_and_position_ids, unwrap_model
from megatron.inference.text_generation.communication import (
    copy_from_last_to_first_pipeline_stage,
    broadcast_from_last_pipeline_stage,
    broadcast_from_last_to_first_pipeline_stage,
    broadcast_int_list,
    broadcast_tensor)
from megatron.inference.text_generation.generation import _build_attention_mask_and_position_ids
from megatron.inference.text_generation.sampling import sample


def retro_generate_tokens_probs_and_return_on_first_stage(
        model, tokens, lengths, neighbours_array=None,
        return_output_log_probs=False,
        top_k=0, top_p=0.0,
        temperature=1.0,
        use_eod_token_for_early_termination=True,
        stop_on_double_eol=False,
        stop_on_eol=False,
        logits_mask=None):
    """Main token generation function.

    Args:
        model: no interleaving is supported.
        tokens: prompt tokens extended to be of size [b, max-sequence-length].
        lengths: original prompt length, size: [b].
        neighbours_array: neighbours array of size [b, l, k, r].
        return_output_log_probs: flag to calculate the log probability of
            the generated tokens. Note that the log probability is the one
            from the original logit.
        top_k, top_p: top-k and top-p sampling parameters.
            Note that top-k = 1 is greedy. Also, these parameters are
            exclusive, meaning that:
                if top-k > 0 then we expect top-p = 0,
                if top-p > 0 then we check for top-k = 0.
        temperature: sampling temperature.
        use_eod_token_for_early_termination: if True, terminate early once
            all the sequences have reached this token.

    Note: Outside of model, the other parameters only need to be available
        on rank 0.

    Returns: Note that the sizes below are adjusted to a value lower than
        max-sequence-length if generation is terminated early.
        tokens: prompt and generated tokens. size: [b, :]
        generated_sequence_lengths: total length (including prompt) of
            the generated sequence. size: [b]
        output_log_probs: log probability of the selected tokens. size: [b, s]
    """

    args = get_args()
    retro_args = get_retro_args()
    tokenizer = get_tokenizer()

    batch_size = tokens.size(0)
    min_prompt_length = lengths.min().item()
    max_sequence_length = tokens.size(1)
    print("max_sequence_length", max_sequence_length)
    print("min_prompt_length", min_prompt_length)
    max_sequence_length = min(max_sequence_length, args.max_position_embeddings)

    # If the context is too big, there is no room left to generate tokens.
    if min_prompt_length >= max_sequence_length:
        raise ValueError("context length + tokens_to_generate too large")

    # Forward step.
    unwrapped_model = unwrap_model(model)
    unwrapped_model.language_model.seq_length = max_sequence_length

    # Added termination_id to support the case that we want to terminate the
    # generation once that id is generated.
    if hasattr(args, 'eos_id'):
        termination_id = args.eos_id
    else:
        termination_id = tokenizer.eod

    # ===================
    # Pre-allocate memory
    # ===================

    # Log probability of the sequence (prompt + generated tokens).
    output_log_probs = None
    output_log_probs_size = (batch_size, max_sequence_length - 1)
    # Lengths of the generated sequences, including prompts.
    generated_sequence_lengths = None
    if mpu.is_pipeline_last_stage():
        if return_output_log_probs:
            output_log_probs = torch.empty(output_log_probs_size,
                                           dtype=torch.float32,
                                           device=torch.cuda.current_device())
        generated_sequence_lengths = torch.ones(
            batch_size, dtype=torch.int64,
            device=torch.cuda.current_device()) * max_sequence_length

    # Whether we have reached a termination id.
    is_generation_done = torch.zeros(batch_size, dtype=torch.uint8,
                                     device=torch.cuda.current_device())

    # =============
    # Run inference
    # =============

    with torch.no_grad():
        attention_mask, position_ids = _build_attention_mask_and_position_ids(
            tokens)

        for context_length in range(min_prompt_length, max_sequence_length):
            prev_context_length = 0
            sizes_list = None
            neighbor_tokens_cuda_long_tensor = None

            # Get the chunks for retrieval.
            if torch.distributed.get_rank() == 0:
                neighbor_tokens = neighbours_array
                neighbor_tokens_cuda_long_tensor = torch.cuda.LongTensor(
                    neighbor_tokens.reshape(
                        (-1, retro_args.retro_gpt_retrieved_length)))
                sizes_list = [neighbor_tokens_cuda_long_tensor.size(0),  # Batch size
                              neighbor_tokens_cuda_long_tensor.size(1)]  # Sequence length
            sizes_tensor = broadcast_int_list(2, int_list=sizes_list)
            sizes = sizes_tensor.tolist()
            neighbor_tokens_cuda_long_tensor = broadcast_tensor(
                sizes, torch.int64, tensor=neighbor_tokens_cuda_long_tensor)

            _, _, neighbor_position_ids = get_ltor_masks_and_position_ids(
                neighbor_tokens_cuda_long_tensor,
                tokenizer.eod,
                args.reset_position_ids,
                args.reset_attention_mask,
                args.eod_mask_loss)
            neighbor_attention_mask = None

            # Pick the slice that we need to pass through the network.
            tokens2use = tokens[:, prev_context_length:4096]
            positions2use = position_ids[:, prev_context_length:4096]
            attention_mask2use = attention_mask[
                ..., prev_context_length:4096, :4096]

            logits = model(tokens2use, positions2use, attention_mask2use,
                           retriever_input_ids=neighbor_tokens_cuda_long_tensor,
                           retriever_position_ids=neighbor_position_ids,
                           retriever_attn_mask=neighbor_attention_mask)

            if mpu.is_pipeline_last_stage():
                # The last stage should always have an output.
                assert logits is not None

                # Sample.
                last_token_logits = logits[:, context_length - 1, :]
                # last_token_logits = logits[:, -1, :]

                # Word banning.
                if logits_mask is not None:
                    last_token_logits[:, logits_mask] = float('-Inf')

                new_sample = sample(last_token_logits,
                                    top_k=top_k,
                                    top_p=top_p,
                                    temperature=temperature,
                                    vocab_size=tokenizer.vocab_size)

                # If a prompt length is smaller than or equal to the current
                # context length, it means we have started generating tokens.
                started = lengths <= context_length
                # Update the tokens.
                tokens[started, context_length] = new_sample[started]

                # Calculate the log probabilities.
                if return_output_log_probs:
                    log_probs = F.log_softmax(logits, dim=2)
                    # Pick the tokens that we need to get the log
                    # probabilities for. Note that the next input token is
                    # the token which we selected in the current logits,
                    # so shift by 1.
                    indices = torch.unsqueeze(
                        tokens[
                            :,
                            (prev_context_length + 1):(context_length + 1)],
                        2)
                    output_log_probs[:, prev_context_length:context_length] = \
                        torch.gather(log_probs, 2, indices).squeeze(2)

            # Update the tokens on the first stage so the next input to
            # the network is correct.
            copy_from_last_to_first_pipeline_stage(batch_size, torch.int64,
                                                   tokens[:, context_length])

            # Update the context length for the next token generation.
            prev_context_length = context_length

            # Check if all the sequences have hit the termination_id.
            done = None
            if mpu.is_pipeline_last_stage():
                # TODO(rprenger) These stopping methods are tokenizer
                # dependent; instead, tokenization should be in the inference
                # loop so stop sequences can be used.
                if stop_on_double_eol:
                    hit_double_eol = (new_sample == 628).byte() & started.byte()
                    hit_two_eols = (new_sample == 198).byte() & (
                        tokens[:, context_length - 1] == 198).byte() & started.byte()
                    done_token = hit_double_eol | hit_two_eols
                elif stop_on_eol:
                    hit_double_eol = (new_sample == 628).byte() & started.byte()
                    hit_eol = (new_sample == 198).byte() & started.byte()
                    done_token = hit_double_eol | hit_eol
                elif context_length > min_prompt_length + 64:
                    # Previous retro v1 limitation.
                    done_token = 1
                else:
                    done_token = (new_sample == termination_id).byte() & \
                        started.byte()

                just_finished = (done_token & ~is_generation_done).bool()
                generated_sequence_lengths[just_finished.view(-1)] = \
                    context_length + 1
                is_generation_done = is_generation_done | done_token
                done = torch.all(is_generation_done)
            done = broadcast_from_last_pipeline_stage(1, torch.uint8,
                                                      tensor=done)
            if use_eod_token_for_early_termination and done:
                break

    # ====================================================
    # Update the length based on the max generated length.
    # ====================================================

    tokens = tokens[:, :(context_length + 1)]
    if mpu.is_pipeline_last_stage():
        if return_output_log_probs:
            output_log_probs = output_log_probs[:, :context_length]

    # ======================================
    # Broadcast to the first pipeline stage.
    # ======================================

    generated_sequence_lengths = broadcast_from_last_to_first_pipeline_stage(
        batch_size, torch.int64, generated_sequence_lengths)
    if return_output_log_probs:
        output_log_probs_size = (batch_size, context_length)
        output_log_probs = broadcast_from_last_to_first_pipeline_stage(
            output_log_probs_size, torch.float32, output_log_probs)

    return tokens, generated_sequence_lengths, output_log_probs
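

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original generation loop): a
# minimal example of how the function above might be invoked, assuming the
# Megatron environment (args, tokenizer, parallel state, model construction)
# has already been initialized elsewhere, e.g. by a text-generation driver.
# The helper name `_example_retro_generate` and the exact input shapes are
# hypothetical and shown only to document the expected inputs.
# ---------------------------------------------------------------------------
def _example_retro_generate(model, prompt_tokens, prompt_lengths,
                            neighbours_array):
    """Sketch of calling the retro generation loop.

    Assumes:
      * `prompt_tokens` is an int64 tensor of shape [b, max-sequence-length],
        already padded out to the full generation length.
      * `prompt_lengths` is an int64 tensor of shape [b] holding the original
        (unpadded) prompt lengths.
      * `neighbours_array` holds retrieved neighbour token ids of shape
        [b, l, k, r], as described in the docstring above.
    """
    tokens, generated_lengths, log_probs = \
        retro_generate_tokens_probs_and_return_on_first_stage(
            model,
            prompt_tokens,
            prompt_lengths,
            neighbours_array=neighbours_array,
            return_output_log_probs=True,
            top_k=1,       # top-k = 1 is greedy; keep top_p=0.0 when top_k > 0
            top_p=0.0,
            temperature=1.0)
    return tokens, generated_lengths, log_probs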