# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generation utilities."""

import torch
import torch.nn.functional as F

from megatron import get_args, get_tokenizer, mpu
from megatron.utils import get_ltor_masks_and_position_ids
from .communication import (
    copy_from_last_to_first_pipeline_stage,
    broadcast_from_last_pipeline_stage,
    broadcast_from_last_to_first_pipeline_stage)
from .forward_step import ForwardStep
from .sampling import sample


def score_and_return_on_first_stage(model, tokens, lengths):
    """Function for just scoring.

    Arguments:
        model: no interleaving is supported.
        tokens: prompt tokens extended to be of size [b, max_prompt_length]
        lengths: original prompt length, size: [b]

    Note: Outside of model, other parameters only need to be available on
    rank 0.

    Outputs:
        output_log_probs: log probability of the selected tokens. size: [b, s]
    """

    args = get_args()

    batch_size = tokens.size(0)
    max_prompt_length = lengths.max().item()
    assert max_prompt_length == tokens.size(1)
    max_sequence_length = min(max_prompt_length, args.max_position_embeddings)

    # forward step.
    forward_step = ForwardStep(model, batch_size, max_sequence_length)

    # ===================
    # Pre-allocate memory
    # ===================

    # Log probability of the sequence (prompt + generated tokens).
    output_log_probs = None
    output_log_probs_size = (batch_size, max_sequence_length - 1)

    if mpu.is_pipeline_last_stage():
        output_log_probs = torch.empty(output_log_probs_size,
                                       dtype=torch.float32,
                                       device=torch.cuda.current_device())

    # =============
    # Run inference
    # =============
    with torch.no_grad():
        attention_mask, position_ids = _build_attention_mask_and_position_ids(
            tokens)

        # logits will be meaningful only in the last pipeline stage.
        logits = forward_step(tokens, position_ids, attention_mask)

        if mpu.is_pipeline_last_stage():
            # The last stage should always have an output.
            assert logits is not None
            log_probs = F.log_softmax(logits, dim=2)

            # Pick the tokens that we need to get the log probabilities for.
            # Note that the next input token is the token which we selected
            # in the current logits, so shift by 1.
            indices = torch.unsqueeze(tokens[:, 1:], 2)
            output_log_probs = torch.gather(log_probs, 2, indices).squeeze(2)

    # ======================================
    # Broadcast to the first pipeline stage.
    # ======================================
    output_log_probs = broadcast_from_last_to_first_pipeline_stage(
        output_log_probs_size, torch.float32, output_log_probs)

    return tokens, lengths, output_log_probs


def generate_tokens_probs_and_return_on_first_stage(
        model, tokens, lengths,
        return_output_log_probs=False,
        top_k=0, top_p=0.0,
        temperature=1.0,
        use_eod_token_for_early_termination=True,
        stop_on_double_eol=False,
        stop_on_eol=False):
    """Main token generation function.

    Arguments:
        model: no interleaving is supported.
        tokens: prompt tokens extended to be of size [b, max-sequence-length]
        lengths: original prompt length, size: [b]
        return_output_log_probs: flag to calculate the log probability of
            the generated tokens. Note that the log probability is the one
            from the original logit.
        top_k, top_p: top-k and top-p sampling parameters.
            Note that top-k = 1 is greedy. Also, these parameters are
            exclusive, meaning that:
                if top-k > 0 then we expect top-p=0.
                if top-p > 0 then we check for top-k=0.
        temperature: sampling temperature.
        use_eod_token_for_early_termination: if True, do early termination
            if all the sequences have reached this token.

    Note: Outside of model, other parameters only need to be available on
    rank 0.

    Outputs: Note that the size is adjusted to a lower value than
    max-sequence-length if generation is terminated early.
        tokens: prompt and generated tokens. size: [b, :]
        generated_sequence_lengths: total length (including prompt) of
            the generated sequence. size: [b]
        output_log_probs: log probability of the selected tokens. size: [b, s]
    """

    args = get_args()
    tokenizer = get_tokenizer()

    batch_size = tokens.size(0)
    min_prompt_length = lengths.min().item()
    max_sequence_length = tokens.size(1)
    max_sequence_length = min(max_sequence_length,
                              args.max_position_embeddings)

    # If the context is too big, this happens.
    if min_prompt_length >= max_sequence_length:
        raise ValueError("context length + tokens_to_generate too large")

    # forward step.
    forward_step = ForwardStep(model, batch_size, max_sequence_length)

    # Added termination_id to support the case that we want to terminate the
    # generation once that id is generated.
    if hasattr(args, 'eos_id'):
        termination_id = args.eos_id
    else:
        termination_id = tokenizer.eod

    # ===================
    # Pre-allocate memory
    # ===================

    # Log probability of the sequence (prompt + generated tokens).
    output_log_probs = None
    output_log_probs_size = (batch_size, max_sequence_length - 1)
    # Lengths of the generated sequences, including prompts.
    generated_sequence_lengths = None
    if mpu.is_pipeline_last_stage():
        if return_output_log_probs:
            output_log_probs = torch.empty(output_log_probs_size,
                                           dtype=torch.float32,
                                           device=torch.cuda.current_device())
        generated_sequence_lengths = torch.ones(
            batch_size, dtype=torch.int64,
            device=torch.cuda.current_device()) * max_sequence_length

    # Whether we have reached a termination id.
    is_generation_done = torch.zeros(batch_size, dtype=torch.uint8,
                                     device=torch.cuda.current_device())

    # =============
    # Run inference
    # =============
    with torch.no_grad():
        attention_mask, position_ids = _build_attention_mask_and_position_ids(
            tokens)
        prev_context_length = 0
        for context_length in range(min_prompt_length, max_sequence_length):

            # Pick the slice that we need to pass through the network.
            tokens2use = tokens[:, prev_context_length:context_length]
            positions2use = position_ids[:, prev_context_length:context_length]
            attention_mask2use = attention_mask[
                ..., prev_context_length:context_length, :context_length]

            # logits will be meaningful only in the last pipeline stage.
            logits = forward_step(tokens2use, positions2use,
                                  attention_mask2use)

            if mpu.is_pipeline_last_stage():
                # The last stage should always have an output.
                assert logits is not None

                # Sample.
                last_token_logits = logits[:, -1, :]
                new_sample = sample(last_token_logits,
                                    top_k=top_k,
                                    top_p=top_p,
                                    temperature=temperature,
                                    vocab_size=tokenizer.vocab_size)

                # If a prompt length is smaller than or equal to the current
                # context length, it means we have started generating tokens.
                started = lengths <= context_length
                # Update the tokens.
                tokens[started, context_length] = new_sample[started]

                # Calculate the log probabilities.
                if return_output_log_probs:
                    log_probs = F.log_softmax(logits, dim=2)
                    # Pick the tokens that we need to get the log
                    # probabilities for. Note that the next input token is
                    # the token which we selected in the current logits,
                    # so shift by 1.
                    indices = torch.unsqueeze(
                        tokens[
                            :,
                            (prev_context_length + 1):(context_length + 1)],
                        2)
                    output_log_probs[:,
                                     prev_context_length:context_length] = \
                        torch.gather(log_probs, 2, indices).squeeze(2)

            # Update the tokens on the first stage so the next input to
            # the network is correct.
            copy_from_last_to_first_pipeline_stage(batch_size, torch.int64,
                                                   tokens[:, context_length])

            # Update the context length for the next token generation.
            prev_context_length = context_length

            # Check if all the sequences have hit the termination_id.
            done = None
            if mpu.is_pipeline_last_stage():
                # TODO(rprenger) These stopping methods are tokenizer
                # dependent; instead, tokenization should be in the inference
                # loop so stop sequences can be used.
                if stop_on_double_eol:
                    hit_double_eol = (new_sample == 628).byte() & \
                        started.byte()
                    hit_two_eols = (new_sample == 198).byte() & \
                        (tokens[:, context_length - 1] == 198).byte() & \
                        started.byte()
                    done_token = hit_double_eol | hit_two_eols
                elif stop_on_eol:
                    hit_double_eol = (new_sample == 628).byte() & \
                        started.byte()
                    hit_eol = (new_sample == 198).byte() & started.byte()
                    done_token = hit_double_eol | hit_eol
                else:
                    done_token = (new_sample == termination_id).byte() & \
                        started.byte()

                just_finished = (done_token & ~is_generation_done).bool()
                generated_sequence_lengths[just_finished.view(-1)] = \
                    context_length + 1
                is_generation_done = is_generation_done | done_token
                done = torch.all(is_generation_done)
            done = broadcast_from_last_pipeline_stage(1, torch.uint8,
                                                      tensor=done)
            if use_eod_token_for_early_termination and done:
                break

    # =========================================================
    # Update the length based on the maximum generated length.
    # =========================================================
    tokens = tokens[:, :(context_length + 1)]
    if mpu.is_pipeline_last_stage():
        if return_output_log_probs:
            output_log_probs = output_log_probs[:, :context_length]

    # ======================================
    # Broadcast to the first pipeline stage.
    # ======================================
    generated_sequence_lengths = broadcast_from_last_to_first_pipeline_stage(
        batch_size, torch.int64, generated_sequence_lengths)
    if return_output_log_probs:
        output_log_probs_size = (batch_size, context_length)
        output_log_probs = broadcast_from_last_to_first_pipeline_stage(
            output_log_probs_size, torch.float32, output_log_probs)

    return tokens, generated_sequence_lengths, output_log_probs


## from huggingface beam search
class BeamHypotheses(object):

    def __init__(self, num_beams, length_penalty=1.0, early_stopping=False):
        """
        Initialize n-best list of hypotheses.
        """
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.num_beams = num_beams
        self.beams = []
        self.worst_score = 1e9

    def __len__(self):
        """
        Number of hypotheses in the list.
        """
        return len(self.beams)

    def add(self, hyp, sum_logprobs, length):
        """
        Add a new hypothesis to the list.
""" score = sum_logprobs / length ** self.length_penalty if len(self) < self.num_beams or score > self.worst_score: self.beams.append((score, hyp)) if len(self) > self.num_beams: sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)]) del self.beams[sorted_scores[0][1]] self.worst_score = sorted_scores[1][0] else: self.worst_score = min(score, self.worst_score) def is_done(self, best_sum_logprobs, cur_len): """ If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst one in the heap, then we are done with this sentence. """ if len(self) < self.num_beams: return False elif self.early_stopping: return True else: cur_score = best_sum_logprobs / cur_len ** self.length_penalty ret = self.worst_score >= cur_score return ret def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, stop_token, num_return_gen, length_penalty): args = get_args() tokenizer = get_tokenizer() batch_size = tokens.size(0) assert(batch_size == 1) prompt_length = lengths.item() final_sequence_length = tokens.size(1) final_sequence_length = min(final_sequence_length, args.max_position_embeddings) # If the context is too big, this happens if prompt_length >= final_sequence_length: raise ValueError("context length + tokens_to_generate too large") # forward step. forward_step = ForwardStep(model, beam_size, final_sequence_length) beam_hyp = BeamHypotheses(beam_size, length_penalty) done = False scores = torch.zeros(beam_size, dtype=torch.float32, device=torch.cuda.current_device()).unsqueeze(1) # ============= # Run infernece # ============= with torch.no_grad(): tokens = tokens.repeat(beam_size, 1) attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens) prev_context_length = 0 for context_length in range(prompt_length, final_sequence_length): # Pick the slice that we need to pass through the network. tokens2use = tokens[:, prev_context_length:context_length] positions2use = position_ids[:, prev_context_length:context_length] attention_mask2use = attention_mask[ ..., prev_context_length:context_length, :context_length] # logits will be meanigful only in the last pipeline stage. 
            logits = forward_step(tokens2use, positions2use,
                                  attention_mask2use)

            if mpu.is_pipeline_last_stage():
                vocab_size = logits.size(2)
                log_probs = F.log_softmax(logits, dim=2)
                new_scores = log_probs[:, -1, :] + scores

                if context_length == prompt_length:
                    # If this is the first step, all beams are identical, so
                    # only consider the candidates of the first beam.
                    sorted_scores, indices = torch.sort(new_scores[0, :],
                                                        descending=True)
                else:
                    sorted_scores, indices = torch.sort(new_scores.view(-1),
                                                        descending=True)

                best_beam_ids = torch.div(indices[: 2 * beam_size],
                                          vocab_size).trunc().long()
                best_words = indices[: 2 * beam_size] % vocab_size
                best_scores = sorted_scores[: 2 * beam_size]

                next_beams = []
                for beam_token_rank, (token_id, beam_score, beam_id) in enumerate(
                        zip(best_words, best_scores, best_beam_ids)):
                    if token_id.item() == stop_token:
                        # If the beam token does not belong to the top
                        # num_beams tokens, it should not be added.
                        is_beam_token_worse_than_top_num_beams = \
                            beam_token_rank >= beam_size
                        if is_beam_token_worse_than_top_num_beams:
                            continue
                        beam_hyp.add(tokens[beam_id].clone(),
                                     beam_score,
                                     context_length + 1 - prompt_length)
                    else:
                        # Add the next predicted token since it is not the
                        # stop token.
                        next_beams.append((token_id, beam_score, beam_id))

                    if len(next_beams) == beam_size:
                        break

                if beam_hyp.is_done(best_scores.max().item(),
                                    context_length + 1 - prompt_length):
                    done = True
                    break

                best_batches = tokens.new([item[2] for item in next_beams])
                tokens = tokens[best_batches, :]
                tokens[:, context_length] = tokens.new(
                    [item[0] for item in next_beams])
                scores = scores.new(
                    [item[1] for item in next_beams]).unsqueeze(1)

                # Set the inference key values to make them consistent with
                # the best beam indices.
                forward_step.inference_params.swap_key_value_dict(best_batches)

            # Update the tokens on the first stage so the next input to
            # the network is correct.
            copy_from_last_to_first_pipeline_stage(batch_size, torch.int64,
                                                   tokens[:, context_length])

            # Update the context length for the next token generation.
            prev_context_length = context_length

        copy_from_last_to_first_pipeline_stage(scores.size(0), torch.float32,
                                               scores[:, 0])

        # If we cannot find the stop token, add the open beams to the
        # hypotheses.
        if not done:
            for beam_id in range(beam_size):
                beam_hyp.add(tokens[beam_id].clone(), scores[beam_id],
                             context_length + 1 - prompt_length)

        # Rank based on scores.
        sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True)
        num_return_gen = min(num_return_gen, len(sorted_hyps))
        scores = [sorted_hyps[i][0] for i in range(num_return_gen)]
        tokens = [sorted_hyps[i][1] for i in range(num_return_gen)]
        scores = torch.stack(scores, dim=0)
        tokens = torch.stack(tokens, dim=0)

    return tokens, scores


def _build_attention_mask_and_position_ids(tokens):
    """Build the attention mask and position ids for the input tokens."""

    # Since we are not interested in loss-mask and reset attention/position
    # is also False, eod_token is not used so it is safe to set it to None.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        data=tokens,
        eod_token=None,
        reset_position_ids=False,
        reset_attention_mask=False,
        eod_mask_loss=False)

    return attention_mask, position_ids
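

# ======================================================================
# Illustrative sketch only (not part of the inference path above): how
# BeamHypotheses maintains its n-best list. The generation functions in
# this file need an initialized Megatron model and distributed setup, so
# this sketch exercises just the pure-Python hypothesis bookkeeping with
# made-up token-id lists and log-probability sums. It assumes the
# megatron package and this module's sibling modules are importable
# (e.g. run with `python -m <package>.generation`).
# ======================================================================
if __name__ == '__main__':
    # Keep the two best hypotheses; a hypothesis is scored as
    # sum_logprobs / length ** length_penalty.
    demo_hyp = BeamHypotheses(num_beams=2, length_penalty=1.0)
    demo_hyp.add([5, 7, 9], sum_logprobs=-2.0, length=3)   # score ~ -0.667
    demo_hyp.add([5, 7, 11], sum_logprobs=-4.5, length=3)  # score = -1.5
    # A better candidate evicts the current worst one once the list is full.
    demo_hyp.add([5, 8], sum_logprobs=-1.0, length=2)      # score = -0.5
    assert len(demo_hyp) == 2
    # is_done() is True once no in-flight beam can beat the worst kept score.
    print(sorted(demo_hyp.beams, reverse=True))
    print(demo_hyp.is_done(best_sum_logprobs=-4.0, cur_len=3))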