Commit 8e5587fb authored by thomwolf

few fixes on sampling

parent 641a8dec
@@ -23,14 +23,12 @@ import json
import logging
import os
from io import open
import warnings
import six
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from tqdm import trange
from .configuration_utils import PretrainedConfig
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
@@ -82,7 +80,6 @@ class PreTrainedModel(nn.Module):
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
# Save config in model
self.config = config
@@ -220,9 +217,6 @@ class PreTrainedModel(nn.Module):
# Tie weights if needed
self.tie_weights()
# Initialize decoding head if we have output embeddings
def prune_heads(self, heads_to_prune):
""" Prunes heads of the base model.
@@ -569,30 +563,36 @@ class PreTrainedModel(nn.Module):
cur_len = input_ids.shape[1]
vocab_size = self.config.vocab_size
if num_return_sequences != 1:
# Expand input to num return sequences
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
input_ids = input_ids.contiguous().view(batch_size * num_return_sequences, cur_len) # (batch_size * num_return_sequences, cur_len)
effective_batch_size = batch_size * num_return_sequences
else:
effective_batch_size = batch_size
if num_beams > 1:
return self._generate_beam_search(input_ids, cur_len, max_length, do_sample,
temperature, top_k, top_p, repetition_penalty,
pad_token_id, eos_token_ids, batch_size,
num_return_sequences,
length_penalty, num_beams, vocab_size)
return self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample,
output = self._generate_beam_search(input_ids, cur_len, max_length, do_sample,
temperature, top_k, top_p, repetition_penalty,
pad_token_id, eos_token_ids, effective_batch_size,
length_penalty, num_beams, vocab_size)
else:
output = self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample,
temperature, top_k, top_p, repetition_penalty,
pad_token_id, eos_token_ids, batch_size,
num_return_sequences)
pad_token_id, eos_token_ids, effective_batch_size)
if num_return_sequences != 1:
output = output.view(batch_size, num_return_sequences, -1)
return output
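For context on the change above: the prompt is now expanded to the effective batch once in `generate()` (rather than inside each helper), and the flat output is folded back per example at the end. A minimal, self-contained sketch of that expand/reshape pattern with made-up sizes (not the library code itself):

```python
import torch

batch_size, num_return_sequences, cur_len = 2, 3, 5
input_ids = torch.randint(0, 100, (batch_size, cur_len))  # toy prompt ids

# Repeat each prompt num_return_sequences times -> effective batch for generation
expanded = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
expanded = expanded.contiguous().view(batch_size * num_return_sequences, cur_len)
assert expanded.shape == (batch_size * num_return_sequences, cur_len)

# After generation, the flat (effective_batch, seq_len) output is folded back per example
generated = torch.cat([expanded, torch.zeros(expanded.size(0), 4, dtype=torch.long)], dim=-1)
output = generated.view(batch_size, num_return_sequences, -1)
assert output.shape == (batch_size, num_return_sequences, cur_len + 4)
```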
def _generate_no_beam_search(self, input_ids, cur_len, max_length, do_sample,
temperature, top_k, top_p, repetition_penalty,
pad_token_id, eos_token_ids, batch_size,
num_return_sequences):
""" Generate `num_return_sequences` sequences per batch example without beam search (num_beams == 1).
pad_token_id, eos_token_ids, batch_size):
""" Generate sequences for each example without beam search (num_beams == 1).
All returned sequences are generated independently.
"""
# Expand input to num return sequences
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
input_ids = input_ids.contiguous().view(batch_size*num_return_sequences, cur_len) # (batch_size*num_return_sequences, cur_len)
# current position / max lengths / length of generated sentences / unfinished sentences
unfinished_sents = input_ids.new(batch_size*num_return_sequences).fill_(1)
unfinished_sents = input_ids.new(batch_size).fill_(1)
# cache compute states
pasts = None
@@ -604,7 +604,7 @@ class PreTrainedModel(nn.Module):
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
for i in range(batch_size*num_return_sequences):
for i in range(batch_size):
for previous_tokens in set(input_ids[i].tolist()):
next_token_logits[i, previous_tokens] /= repetition_penalty
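The hunk above only changes the loop bound (the expansion now happens earlier in `generate()`), but for reference, here is a small standalone sketch of the CTRL-style repetition penalty as applied in this code, with made-up logits and token ids:

```python
import torch

repetition_penalty = 1.2
batch_size, vocab_size = 2, 10
next_token_logits = torch.randn(batch_size, vocab_size)   # toy next-token logits
input_ids = torch.tensor([[1, 3, 3, 7],                    # tokens generated so far
                          [2, 2, 5, 9]])

# Penalize tokens that already appear in each sequence (CTRL, arXiv:1909.05858)
for i in range(batch_size):
    for previous_token in set(input_ids[i].tolist()):
        next_token_logits[i, previous_token] /= repetition_penalty
```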
@@ -635,22 +635,14 @@ class PreTrainedModel(nn.Module):
if cur_len == max_length:
input_ids[:, -1].masked_fill_(unfinished_sents.to(dtype=torch.bool), eos_token_ids[0])
if num_return_sequences != 1:
input_ids = input_ids.view(batch_size, num_return_sequences, -1)
return input_ids
def _generate_beam_search(self, input_ids, cur_len, max_length, do_sample,
temperature, top_k, top_p, repetition_penalty,
pad_token_id, eos_token_ids, batch_size,
num_return_sequences,
length_penalty, num_beams, vocab_size):
""" Generate `num_return_sequences` sequences per batch example with beam search.
We return the top-`num_return_sequences` beams.
`num_return_sequences` should not be larger than `num_beams` (we default to the min of both)
""" Generate sequences for each example with beam search.
"""
num_return_sequences = min(num_return_sequences, num_beams)
# Expand input to num beams
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len)
@@ -685,7 +677,7 @@ class PreTrainedModel(nn.Module):
if temperature != 1.0:
scores = scores / temperature
# Top-p/top-k filtering
scores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p) # (batch_size * num_beams, vocab_size)
scores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2) # (batch_size * num_beams, vocab_size)
# Sample 2 next words for each beam (so we have some spare tokens and match output of greedy beam search)
next_words = torch.multinomial(F.softmax(scores, dim=-1), num_samples=2) # (batch_size * num_beams, 2)
# Compute next scores
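As a reading aid for the sampling branch of beam search: after filtering with `min_tokens_to_keep=2`, two candidate tokens are drawn per beam so there is a spare candidate per beam. A rough sketch of the shapes involved, assuming per-beam next-token logits; the log-softmax gather is my assumption of how the "next scores" step that follows could look, not a quote of the library code:

```python
import torch
import torch.nn.functional as F

batch_size, num_beams, vocab_size = 2, 3, 50
scores = torch.randn(batch_size * num_beams, vocab_size)   # toy per-beam next-token logits

# Two candidate tokens per beam: shape (batch_size * num_beams, 2)
next_words = torch.multinomial(F.softmax(scores, dim=-1), num_samples=2)

# Hypothetical scoring of the drawn candidates via their log-probabilities
log_probs = F.log_softmax(scores, dim=-1)
next_scores = torch.gather(log_probs, -1, next_words)      # (batch_size * num_beams, 2)
```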
@@ -778,41 +770,35 @@ class PreTrainedModel(nn.Module):
# print("")
# select the best hypotheses
tgt_len = input_ids.new(batch_size, num_return_sequences)
bests = []
tgt_len = input_ids.new(batch_size)
best = []
for i, hypotheses in enumerate(generated_hyps):
best_hyps = [hyp[1] for hyp in sorted(hypotheses.hyp, key=lambda hyp: hyp[0])[-num_return_sequences:]]
for j, hyp in enumerate(best_hyps):
tgt_len[i, j] = len(hyp) + 1 # +1 for the <EOS> symbol
bests.append(best_hyps)
best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
tgt_len[i] = len(best_hyp) + 1 # +1 for the <EOS> symbol
best.append(best_hyp)
# generate target batch
decoded = input_ids.new(batch_size, num_return_sequences, tgt_len.max().item()).fill_(pad_token_id)
for i, hyps in enumerate(bests):
for j, hypo in enumerate(hyps):
decoded[i, j, :tgt_len[i, j] - 1] = hypo
decoded[i, j, tgt_len[i, j] - 1] = eos_token_ids[0]
if num_return_sequences == 1:
decoded = decoded.squeeze(1)
# # sanity check
# assert (decoded == eos_token_ids[0]).sum() == 2 * batch_size
decoded = input_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id)
for i, hypo in enumerate(best):
decoded[i, :tgt_len[i] - 1] = hypo
decoded[i, tgt_len[i] - 1] = eos_token_ids[0]
return decoded
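The rewritten block above now keeps only the single best hypothesis per example and pads the batch to the longest one. A toy sketch of that selection-and-padding step, using a hand-built list of (score, tokens) hypotheses in place of the real `generated_hyps` objects:

```python
import torch

pad_token_id, eos_token_id = 0, 2
# Hypothetical finished hypotheses: one list of (score, token_tensor) pairs per example
generated_hyps = [
    [(-1.2, torch.tensor([5, 6, 7])), (-0.4, torch.tensor([5, 9]))],
    [(-0.8, torch.tensor([3, 3, 3, 4]))],
]
batch_size = len(generated_hyps)

best, tgt_len = [], torch.zeros(batch_size, dtype=torch.long)
for i, hyps in enumerate(generated_hyps):
    best_hyp = max(hyps, key=lambda x: x[0])[1]   # highest-scoring hypothesis
    tgt_len[i] = len(best_hyp) + 1                # +1 for the trailing <EOS>
    best.append(best_hyp)

# Pad every example to the longest hypothesis and terminate it with <EOS>
decoded = torch.full((batch_size, int(tgt_len.max())), pad_token_id, dtype=torch.long)
for i, hypo in enumerate(best):
    length = int(tgt_len[i])
    decoded[i, :length - 1] = hypo
    decoded[i, length - 1] = eos_token_id
# decoded -> tensor([[5, 9, 2, 0, 0],
#                    [3, 3, 3, 4, 2]])
```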
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf'), min_tokens_to_keep=1):
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size x vocabulary size)
logits: logits distribution shape (batch size, vocabulary size)
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
Make sure we keep at least min_tokens_to_keep per batch example in the output
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
top_k = min(top_k, logits.size(-1)) # Safety check
if top_k > 0:
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
@@ -821,8 +807,11 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
# Remove tokens with cumulative probability above the threshold (tokens with cumulative probability 0 are kept)
sorted_indices_to_remove = cumulative_probs > top_p
if min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
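Taken together, the `min_tokens_to_keep` changes guarantee that at least that many tokens survive filtering, which is what lets the beam-search sampling branch above draw two candidates per beam. A small usage sketch with made-up logits, assuming `top_k_top_p_filtering` as defined in this file is in scope:

```python
import torch
import torch.nn.functional as F

# Made-up next-token logits for 2 examples over a 6-token vocabulary;
# the second row is very peaked, so plain top_p=0.9 would keep a single token.
logits = torch.tensor([[ 4.0,  3.0,  0.5,  0.1, -1.0, -2.0],
                       [10.0, -5.0, -5.0, -5.0, -5.0, -5.0]])

filtered = top_k_top_p_filtering(logits.clone(), top_k=0, top_p=0.9, min_tokens_to_keep=2)

# At least two tokens per row keep a finite logit, so drawing two samples cannot fail
assert int((filtered > -float('Inf')).sum(dim=-1).min()) >= 2
next_tokens = torch.multinomial(F.softmax(filtered, dim=-1), num_samples=2)
```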