Commit 8e5587fb authored by thomwolf

few fixes on sampling

parent 641a8dec
@@ -23,14 +23,12 @@ import json
 import logging
 import os
 from io import open
-import warnings

 import six
 import torch
 from torch import nn
 from torch.nn import CrossEntropyLoss
 from torch.nn import functional as F
-from tqdm import trange

 from .configuration_utils import PretrainedConfig
 from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
@@ -82,7 +80,6 @@ class PreTrainedModel(nn.Module):
                 "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                     self.__class__.__name__, self.__class__.__name__
                 ))
-
         # Save config in model
         self.config = config
@@ -220,9 +217,6 @@ class PreTrainedModel(nn.Module):
         # Tie weights if needed
         self.tie_weights()

-        # Initialize decoding head if we have output embeddings
-
     def prune_heads(self, heads_to_prune):
         """ Prunes heads of the base model.
@@ -569,30 +563,36 @@ class PreTrainedModel(nn.Module):
         cur_len = input_ids.shape[1]
         vocab_size = self.config.vocab_size

+        if num_return_sequences != 1:
+            # Expand input to num return sequences
+            input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
+            input_ids = input_ids.contiguous().view(batch_size * num_return_sequences, cur_len)  # (batch_size * num_return_sequences, cur_len)
+            effective_batch_size = batch_size * num_return_sequences
+        else:
+            effective_batch_size = batch_size
+
         if num_beams > 1:
-            return self._generate_beam_search(input_ids, cur_len, max_length, do_sample,
-                                              temperature, top_k, top_p, repetition_penalty,
-                                              pad_token_id, eos_token_ids, batch_size,
-                                              num_return_sequences,
-                                              length_penalty, num_beams, vocab_size)
-        return self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample,
-                                             temperature, top_k, top_p, repetition_penalty,
-                                             pad_token_id, eos_token_ids, batch_size,
-                                             num_return_sequences)
+            output = self._generate_beam_search(input_ids, cur_len, max_length, do_sample,
+                                                temperature, top_k, top_p, repetition_penalty,
+                                                pad_token_id, eos_token_ids, effective_batch_size,
+                                                length_penalty, num_beams, vocab_size)
+        else:
+            output = self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample,
+                                                   temperature, top_k, top_p, repetition_penalty,
+                                                   pad_token_id, eos_token_ids, effective_batch_size)
+
+        if num_return_sequences != 1:
+            output = output.view(batch_size, num_return_sequences, -1)
+        return output
     def _generate_no_beam_search(self, input_ids, cur_len, max_length, do_sample,
                                  temperature, top_k, top_p, repetition_penalty,
-                                 pad_token_id, eos_token_ids, batch_size,
-                                 num_return_sequences):
-        """ Generate `num_return_sequences` sequences per batch example without beam search (num_beams == 1).
+                                 pad_token_id, eos_token_ids, batch_size):
+        """ Generate sequences for each example without beam search (num_beams == 1).
             All returned sequence are generated independantly.
         """
-        # Expand input to num return sequences
-        input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
-        input_ids = input_ids.contiguous().view(batch_size*num_return_sequences, cur_len)  # (batch_size*num_return_sequences, cur_len)
-
         # current position / max lengths / length of generated sentences / unfinished sentences
-        unfinished_sents = input_ids.new(batch_size*num_return_sequences).fill_(1)
+        unfinished_sents = input_ids.new(batch_size).fill_(1)

         # cache compute states
         pasts = None
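
Note: the expansion for `num_return_sequences` now happens once in `generate()` instead of inside `_generate_no_beam_search`. A minimal standalone sketch of the expand/reshape trick (toy tensor and hypothetical sizes, not the library API):

import torch

# Hypothetical prompt batch: 2 examples of 3 tokens, 4 sequences requested per example.
batch_size, cur_len, num_return_sequences = 2, 3, 4
input_ids = torch.arange(batch_size * cur_len).view(batch_size, cur_len)

# Same expand + view pattern as in generate(): repeat each example along a new
# dimension, then flatten so the decoding loop only ever sees a flat batch.
flat = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
flat = flat.contiguous().view(batch_size * num_return_sequences, cur_len)
print(flat.shape)  # torch.Size([8, 3])

# After decoding, generate() folds the result back per example:
print(flat.view(batch_size, num_return_sequences, -1).shape)  # torch.Size([2, 4, 3])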
@@ -604,7 +604,7 @@ class PreTrainedModel(nn.Module):

             # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
             if repetition_penalty != 1.0:
-                for i in range(batch_size*num_return_sequences):
+                for i in range(batch_size):
                     for previous_tokens in set(input_ids[i].tolist()):
                         next_token_logits[i, previous_tokens] /= repetition_penalty
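
Aside: a self-contained sketch of the CTRL-style repetition penalty applied above, with toy values (it mirrors the loop in the diff: the logit of every previously generated token is divided by the penalty so it becomes less likely to be sampled again):

import torch

repetition_penalty = 1.3
generated = torch.tensor([[2, 5, 2]])   # tokens produced so far (toy example)
next_token_logits = torch.randn(1, 6)   # hypothetical logits over a 6-token vocab

# Every token id already present in a sequence gets its logit divided by the penalty.
for i in range(generated.shape[0]):
    for previous_token in set(generated[i].tolist()):
        next_token_logits[i, previous_token] /= repetition_penalty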
@@ -635,22 +635,14 @@ class PreTrainedModel(nn.Module):
         if cur_len == max_length:
             input_ids[:, -1].masked_fill_(unfinished_sents.to(dtype=torch.bool), eos_token_ids[0])

-        if num_return_sequences != 1:
-            input_ids = input_ids.view(batch_size, num_return_sequences, -1)
-
         return input_ids

     def _generate_beam_search(self, input_ids, cur_len, max_length, do_sample,
                               temperature, top_k, top_p, repetition_penalty,
                               pad_token_id, eos_token_ids, batch_size,
-                              num_return_sequences,
                               length_penalty, num_beams, vocab_size):
-        """ Generate `num_return_sequences` sequences per batch example with beam search.
-            We return the top-`num_return_sequences` beams.
-            `num_return_sequences` should be bigger than `num_beams` (we default to the min of both)
+        """ Generate sequences for each example with beam search.
         """
-        num_return_sequences = min(num_return_sequences, num_beams)
-
         # Expand input to num beams
         input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
         input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len)  # (batch_size * num_beams, cur_len)
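
Side note on the end-of-generation handling kept above: a tiny illustration (toy tensors, hypothetical EOS id) of the masked_fill_ pattern that forces an EOS token onto sequences still unfinished when max_length is reached:

import torch

eos_token_id = 2                              # hypothetical EOS id
input_ids = torch.randint(3, 100, (3, 4))     # 3 sequences of length 4
unfinished_sents = torch.tensor([1, 0, 1])    # sequences 0 and 2 have not emitted EOS

# Overwrite the last position of every unfinished sequence with EOS; the slice
# input_ids[:, -1] is a view, so masked_fill_ updates input_ids in place.
input_ids[:, -1].masked_fill_(unfinished_sents.to(dtype=torch.bool), eos_token_id)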
@@ -685,7 +677,7 @@ class PreTrainedModel(nn.Module):
                 if temperature != 1.0:
                     scores = scores / temperature
                 # Top-p/top-k filtering
-                scores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)  # (batch_size * num_beams, vocab_size)
+                scores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2)  # (batch_size * num_beams, vocab_size)
                 # Sample 2 next words for each beam (so we have some spare tokens and match output of greedy beam search)
                 next_words = torch.multinomial(F.softmax(scores, dim=-1), num_samples=2)  # (batch_size * num_beams, 2)
                 # Compute next scores
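
For intuition on `min_tokens_to_keep=2`: the filter now always leaves at least two candidate tokens per beam, so drawing two samples per beam cannot fail even when top-p would otherwise keep a single token, and there is a spare candidate in case the first one is EOS. A minimal sketch with toy scores (hypothetical shapes, not the surrounding beam-search state):

import torch
import torch.nn.functional as F

num_beams, vocab_size = 3, 5
scores = torch.randn(num_beams, vocab_size)          # toy post-filtering scores

# Two candidates per beam, i.e. 2 * num_beams candidates per example, keeping
# the shapes consistent with the greedy beam-search path.
probs = F.softmax(scores, dim=-1)
next_words = torch.multinomial(probs, num_samples=2)  # (num_beams, 2)
next_scores = torch.gather(scores, -1, next_words)    # scores of the sampled candidates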
@@ -778,41 +770,35 @@ class PreTrainedModel(nn.Module):
             # print("")

         # select the best hypotheses
-        tgt_len = input_ids.new(batch_size, num_return_sequences)
-        bests = []
+        tgt_len = input_ids.new(batch_size)
+        best = []
         for i, hypotheses in enumerate(generated_hyps):
-            best_hyps = [hyp[1] for hyp in sorted(hypotheses.hyp, key=lambda hyp: hyp[0])[-num_return_sequences:]]
-            for j, hyp in enumerate(best_hyps):
-                tgt_len[i, j] = len(hyp) + 1  # +1 for the <EOS> symbol
-            bests.append(best_hyps)
+            best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
+            tgt_len[i] = len(best_hyp) + 1  # +1 for the <EOS> symbol
+            best.append(best_hyp)

         # generate target batch
-        decoded = input_ids.new(batch_size, num_return_sequences, tgt_len.max().item()).fill_(pad_token_id)
-        for i, hyps in enumerate(bests):
-            for j, hypo in enumerate(hyps):
-                decoded[i, j, :tgt_len[i, j] - 1] = hypo
-                decoded[i, j, tgt_len[i, j] - 1] = eos_token_ids[0]
-        if num_return_sequences == 1:
-            decoded = decoded.squeeze(1)
-
-        # # sanity check
-        # assert (decoded == eos_token_ids[0]).sum() == 2 * batch_size
+        decoded = input_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id)
+        for i, hypo in enumerate(best):
+            decoded[i, :tgt_len[i] - 1] = hypo
+            decoded[i, tgt_len[i] - 1] = eos_token_ids[0]

         return decoded


-def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
+def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf'), min_tokens_to_keep=1):
     """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
         Args:
-            logits: logits distribution shape (batch size x vocabulary size)
+            logits: logits distribution shape (batch size, vocabulary size)
             if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
             if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                 Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+            Make sure we keep at least min_tokens_to_keep per batch example in the output
         From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
     """
-    top_k = min(top_k, logits.size(-1))  # Safety check
     if top_k > 0:
+        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
         # Remove all tokens with a probability less than the last token of the top-k
         indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
         logits[indices_to_remove] = filter_value
@@ -821,8 +807,11 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')
         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
         cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

-        # Remove tokens with cumulative probability above the threshold
+        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
         sorted_indices_to_remove = cumulative_probs > top_p
+        if min_tokens_to_keep > 1:
+            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
+            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
         # Shift the indices to the right to keep also the first token above the threshold
         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
         sorted_indices_to_remove[..., 0] = 0
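
For reference, a small usage sketch of the updated filter (toy logits with hypothetical values; the import path is an assumption and may differ between library versions):

import torch
import torch.nn.functional as F
from transformers.modeling_utils import top_k_top_p_filtering  # assumed location

logits = torch.randn(2, 10)   # toy next-token logits for a batch of 2, vocab of 10

# Keep the 5 most likely tokens, then the 90% nucleus, while guaranteeing at
# least 2 surviving tokens per row (the guarantee beam sampling relies on).
filtered = top_k_top_p_filtering(logits.clone(), top_k=5, top_p=0.9, min_tokens_to_keep=2)
next_token = torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)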