Commit 641a8dec authored by thomwolf

clean up code and add arbitrary number of return sequences

parent 77d39720
@@ -62,18 +62,19 @@ class PretrainedConfig(object):
         self.is_decoder = kwargs.pop('is_decoder', False)
         # Parameters for sequence generation
-        self.generate_max_length = kwargs.pop('generate_max_length', 20)
-        self.generate_do_sample = kwargs.pop('generate_do_sample', False)
-        self.generate_num_beams = kwargs.pop('generate_num_beams', 1)
-        self.generate_temperature = kwargs.pop('generate_temperature', 1.0)
-        self.generate_top_k = kwargs.pop('generate_top_k', 50)
-        self.generate_top_p = kwargs.pop('generate_top_p', 1.0)
-        self.generate_repetition_penalty = kwargs.pop('generate_repetition_penalty', 1.0)
-        self.generate_bos_token_id = kwargs.pop('generate_bos_token_id', 0)
-        self.generate_pad_token_id = kwargs.pop('generate_pad_token_id', 0)
-        self.generate_eos_token_ids = kwargs.pop('generate_eos_token_ids', 0)
-        self.generate_batch_size = kwargs.pop('generate_batch_size', 1)
-        self.generate_length_penalty = kwargs.pop('generate_length_penalty', 1.)
+        self.max_length = kwargs.pop('max_length', 20)
+        self.do_sample = kwargs.pop('do_sample', False)
+        self.num_beams = kwargs.pop('num_beams', 1)
+        self.temperature = kwargs.pop('temperature', 1.0)
+        self.top_k = kwargs.pop('top_k', 50)
+        self.top_p = kwargs.pop('top_p', 1.0)
+        self.repetition_penalty = kwargs.pop('repetition_penalty', 1.0)
+        self.bos_token_id = kwargs.pop('bos_token_id', 0)
+        self.pad_token_id = kwargs.pop('pad_token_id', 0)
+        self.eos_token_ids = kwargs.pop('eos_token_ids', 0)
+        self.batch_size = kwargs.pop('batch_size', 1)
+        self.length_penalty = kwargs.pop('length_penalty', 1.)
+        self.num_return_sequences = kwargs.pop('num_return_sequences', 1)

     def save_pretrained(self, save_directory):
         """ Save a configuration object to the directory `save_directory`, so that it
...
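For reference, the renamed attributes follow a simple pattern: defaults are stored on the config via kwargs.pop, and generate() further down falls back to them whenever a call-time argument is None. A minimal, self-contained sketch of that pattern (ToyConfig and toy_generate are illustrative names, not library code):

# Minimal sketch of the config-default pattern used above (ToyConfig and
# toy_generate are illustrative names, not part of the library).
class ToyConfig:
    def __init__(self, **kwargs):
        # Generation defaults live on the config and can be overridden per call.
        self.max_length = kwargs.pop("max_length", 20)
        self.num_return_sequences = kwargs.pop("num_return_sequences", 1)


def toy_generate(config, max_length=None, num_return_sequences=None):
    # Call-time arguments win; otherwise fall back to the config values.
    max_length = max_length if max_length is not None else config.max_length
    num_return_sequences = (num_return_sequences if num_return_sequences is not None
                            else config.num_return_sequences)
    return max_length, num_return_sequences


print(toy_generate(ToyConfig()))                        # (20, 1)
print(toy_generate(ToyConfig(num_return_sequences=3)))  # (20, 3)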
@@ -25,7 +25,6 @@ from torch import nn
from tqdm import trange

from .modeling_auto import AutoModel, AutoModelWithLMHead
-from .modeling_utils import Sampler

logger = logging.getLogger(__name__)

@@ -203,100 +202,6 @@ class PreTrainedEncoderDecoder(nn.Module):
        return decoder_outputs + encoder_outputs
-    def decode(
-        self,
-        encoder_input_ids,
-        decoder_prompt_ids=None,
-        device=torch.device("cpu"),
-        length=10,
-        do_sample=False,
-        temperature=1.0,
-        k=9,
-        p=0.,
-        repetition_penalty=1.,
-        **kwargs
-    ):
-        """ Generic sequence generator for encoder-decoder models.
-
-        For encoder-decoders the generation consists in:
-        - Performing a forward pass through the encoder once;
-        - Passing the encoder's hidden states to a decoding mechanism that
-          repeatedly calls the decoder to generate sequences.
-
-        The method currently supports greedy decoding and sampling. See the
-        documentation of the `Sampler` class for more information about the
-        parameters related to sampling.
-
-        Params:
-            **encoder_input_ids**: `torch.LongTensor` of shape (1, sequence_length)
-                The sequence to encode.
-            **decoder_prompt_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length)
-                The sequence used as a prompt for the generation. If `None` the method initializes
-                it as an empty `torch.LongTensor` of shape (1,).
-            **device**: (`optional`) `torch.device`
-                The device on which the prompt_ids will be initialized if not provided.
-            **length**: (`optional`) int
-                The length of the sequence to be generated.
-            **do_sample**: (`optional`) bool
-                If set to `False` we use greedy decoding; otherwise sampling.
-            **temperature**: (`optional`) float
-                The value used to modulate the next token probabilities.
-            **k**: (`optional`) int
-                The parameter used for top-k filtering.
-            **p**: (`optional`) float
-                The parameter for nucleus sampling. Must be between 0 and 1.
-            **repetition_penalty**: (`optional`) float
-                The parameter for repetition penalty.
-        """
-        if decoder_prompt_ids is None:
-            decoder_prompt_ids = torch.tensor([[]], dtype=torch.long, device=device)
-
-        # When the model does not have a LM head `get_output_embeddings`
-        # returns `None`. We use this mechanism to determine whether we
-        # should proceed with decoding or not.
-        if self.decoder.get_output_embeddings() is None:
-            raise AttributeError("You tried to generate sequences with a decoder that does not have a LM head.")
-
-        # The following checks that the decoder is on the same device as the one
-        # that is specified. It only works for models that fit on one GPU.
-        decoder_device = next(self.decoder.parameters()).device
-        if decoder_device != decoder_prompt_ids.device:
-            warnings.warn(
-                "The decoder is not on the same device as the prompt. Expected {}, got {}.".format(
-                    decoder_prompt_ids.device, decoder_device
-                )
-            )
-
-        kwargs_encoder, kwargs_decoder = self.prepare_model_kwargs(**kwargs)
-        with torch.no_grad():
-            encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
-            encoder_hidden_states = encoder_outputs[0]
-        kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
-
-        sampler_config = {
-            "k": k,
-            "p": p,
-            "do_sample": do_sample,
-            "temperature": temperature,
-            "repetition_penalty": repetition_penalty,
-        }
-        return self._greedy_decode_or_sample(
-            decoder_prompt_ids, length, sampler_config, **kwargs_decoder
-        )
-
-    def _greedy_decode_or_sample(self, prompt_ids, length, sampler_config, **kwargs_decoder):
-        sampler = Sampler(**sampler_config)
-        with torch.no_grad():
-            generated_sequence = prompt_ids
-            for _ in trange(length):
-                arguments = self.decoder._prepare_inputs_for_decoding(generated_sequence, **kwargs_decoder)
-                outputs = self.decoder(**arguments)
-                next_tokens_logits = outputs[0][:, -1, :]
-                next_tokens = sampler.get_one_token(next_tokens_logits, generated_sequence)
-                generated_sequence = torch.cat((generated_sequence, next_tokens), dim=1)
-
-        return generated_sequence.squeeze(0)
    @staticmethod
    def prepare_model_kwargs(**kwargs):
        """ Prepare the encoder and decoder's keyword arguments.
...
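The removed decode()/_greedy_decode_or_sample() pair implemented the usual autoregressive loop: run the decoder, take the logits of the last position, pick one token, append it, repeat. A minimal sketch of that loop with a stub in place of a real decoder (stub_decoder is invented purely for illustration):

import torch

def stub_decoder(sequence, vocab_size=10):
    # Stand-in for a real decoder: random logits for the next position.
    return torch.randn(sequence.shape[0], vocab_size)

def greedy_decode(prompt_ids, length):
    generated = prompt_ids
    for _ in range(length):
        next_token_logits = stub_decoder(generated)                    # (batch, vocab_size)
        next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        generated = torch.cat((generated, next_tokens), dim=1)
    return generated

print(greedy_decode(torch.tensor([[1, 2, 3]]), length=5).shape)  # torch.Size([1, 8])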
@@ -494,7 +494,7 @@ class PreTrainedModel(nn.Module):
    def generate(self, input_ids=None, max_length=None, do_sample=None, num_beams=None,
                 temperature=None, top_k=None, top_p=None, repetition_penalty=None,
                 bos_token_id=None, pad_token_id=None, eos_token_ids=None, batch_size=None,
-                 length_penalty=None, **kwargs):
+                 length_penalty=None, num_return_sequences=None, **kwargs):
        """ Sequence generator for models with a LM head.

        The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
@@ -526,18 +526,19 @@ class PreTrainedModel(nn.Module):
        if self.get_output_embeddings() is None:
            raise AttributeError("You tried to generate sequences with a model that does not have a LM head.")

-        max_length = max_length if max_length is not None else self.config.generate_max_length
-        do_sample = do_sample if do_sample is not None else self.config.generate_do_sample
-        num_beams = num_beams if num_beams is not None else self.config.generate_num_beams
-        temperature = temperature if temperature is not None else self.config.generate_temperature
-        top_k = top_k if top_k is not None else self.config.generate_top_k
-        top_p = top_p if top_p is not None else self.config.generate_top_p
-        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.generate_repetition_penalty
-        bos_token_id = bos_token_id if bos_token_id is not None else self.config.generate_bos_token_id
-        pad_token_id = pad_token_id if pad_token_id is not None else self.config.generate_pad_token_id
-        eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.generate_eos_token_ids
-        batch_size = batch_size if batch_size is not None else self.config.generate_batch_size
-        length_penalty = length_penalty if length_penalty is not None else self.config.generate_length_penalty
+        max_length = max_length if max_length is not None else self.config.max_length
+        do_sample = do_sample if do_sample is not None else self.config.do_sample
+        num_beams = num_beams if num_beams is not None else self.config.num_beams
+        temperature = temperature if temperature is not None else self.config.temperature
+        top_k = top_k if top_k is not None else self.config.top_k
+        top_p = top_p if top_p is not None else self.config.top_p
+        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
+        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
+        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
+        eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids
+        batch_size = batch_size if batch_size is not None else self.config.batch_size
+        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
+        num_return_sequences = num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences

        if input_ids is not None:
            batch_size = input_ids.shape[0]  # overridden by the input batch_size
@@ -547,8 +548,8 @@ class PreTrainedModel(nn.Module):
        assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
        assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
        assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
-        assert temperature > 0, "`temperature` should be positive."
-        assert isinstance(top_k, int) and top_k > 0, "`top_k` should be a strictly positive integer."
+        assert temperature > 0, "`temperature` should be strictly positive."
+        assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
        assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
        assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
        assert isinstance(bos_token_id, int) and bos_token_id >= 0, "`bos_token_id` should be a positive integer."
@@ -557,30 +558,41 @@ class PreTrainedModel(nn.Module):
            "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
        assert isinstance(batch_size, int) and batch_size > 0, "`batch_size` should be a strictly positive integer."
        assert length_penalty > 0, "`length_penalty` should be strictly positive."
+        assert isinstance(num_return_sequences, int) and num_return_sequences > 0, "`num_return_sequences` should be a strictly positive integer."
        if input_ids is None:
            input_ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device)
        else:
-            assert input_ids.dims() == 2, "Input prompt should be of shape (batch_size, sequence length)."
+            assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."

        # current position and vocab size
        cur_len = input_ids.shape[1]
        vocab_size = self.config.vocab_size

        if num_beams > 1:
-            return self._generate_beam_search(input_ids, cur_len, max_length, do_sample, length_penalty,
-                                              num_beams, pad_token_id, eos_token_ids, vocab_size, batch_size)
+            return self._generate_beam_search(input_ids, cur_len, max_length, do_sample,
+                                              temperature, top_k, top_p, repetition_penalty,
+                                              pad_token_id, eos_token_ids, batch_size,
+                                              num_return_sequences,
+                                              length_penalty, num_beams, vocab_size)

        return self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample,
                                             temperature, top_k, top_p, repetition_penalty,
-                                             pad_token_id, eos_token_ids, batch_size)
+                                             pad_token_id, eos_token_ids, batch_size,
+                                             num_return_sequences)

    def _generate_no_beam_search(self, input_ids, cur_len, max_length, do_sample,
                                 temperature, top_k, top_p, repetition_penalty,
-                                 pad_token_id, eos_token_ids, batch_size):
-        """ Generate a sentence without beam search (num_beams == 1). """
+                                 pad_token_id, eos_token_ids, batch_size,
+                                 num_return_sequences):
+        """ Generate `num_return_sequences` sequences per batch example without beam search (num_beams == 1).
+            All returned sequences are generated independently.
+        """
+        # Expand input to num return sequences
+        input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
+        input_ids = input_ids.contiguous().view(batch_size * num_return_sequences, cur_len)  # (batch_size * num_return_sequences, cur_len)

        # current position / max lengths / length of generated sentences / unfinished sentences
-        unfinished_sents = input_ids.new(batch_size).fill_(1)
+        unfinished_sents = input_ids.new(batch_size * num_return_sequences).fill_(1)

        # cache compute states
        pasts = None
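The expansion added above duplicates each prompt num_return_sequences times so that the independent samples can be drawn in a single batched forward pass. A small shape check with made-up values:

import torch

batch_size, num_return_sequences, cur_len = 2, 3, 4
input_ids = torch.arange(batch_size * cur_len).view(batch_size, cur_len)

# (batch_size, cur_len) -> (batch_size, num_return_sequences, cur_len) -> flat batch
expanded = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
flat = expanded.contiguous().view(batch_size * num_return_sequences, cur_len)

print(flat.shape)                                          # torch.Size([6, 4])
print(flat[0].equal(flat[1]) and flat[1].equal(flat[2]))   # True: rows 0-2 are copies of prompt 0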
@@ -592,9 +604,9 @@ class PreTrainedModel(nn.Module):

            # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
            if repetition_penalty != 1.0:
-                for i in range(batch_size):
-                    for _ in set(input_ids[i].tolist()):
-                        next_token_logits[i, _] /= repetition_penalty
+                for i in range(batch_size * num_return_sequences):
+                    for previous_tokens in set(input_ids[i].tolist()):
+                        next_token_logits[i, previous_tokens] /= repetition_penalty

            if do_sample:
                # Temperature (higher temperature => more likely to sample low probability tokens)
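The loop above applies the CTRL-style repetition penalty: logits of tokens that already appear in a sequence are divided by the penalty, making them less attractive. A toy numeric illustration (values are made up):

import torch

next_token_logits = torch.tensor([[2.0, 1.0, 0.5, 0.1]])  # (batch=1, vocab=4)
input_ids = torch.tensor([[0, 2, 2]])                     # tokens 0 and 2 were already generated
repetition_penalty = 1.3

for i in range(input_ids.shape[0]):
    for token_id in set(input_ids[i].tolist()):
        next_token_logits[i, token_id] /= repetition_penalty

print(next_token_logits)  # the (positive) logits of tokens 0 and 2 are scaled down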
@@ -603,16 +615,16 @@ class PreTrainedModel(nn.Module):
                # Top-p/top-k filtering
                next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
                # Sample
-                next_token = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1)
+                next_token = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1).squeeze(1)
            else:
                # Greedy decoding
-                next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
+                next_token = torch.argmax(next_token_logits, dim=-1)

            # update generations and finished sentences
            tokens_to_add = next_token * unfinished_sents + pad_token_id * (1 - unfinished_sents)
-            input_ids = torch.cat([input_ids, tokens_to_add], dim=-1)
+            input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
            for eos_token_id in eos_token_ids:
-                unfinished_sents.mul_(tokens_to_add.squeeze(-1).ne(eos_token_id).long())
+                unfinished_sents.mul_(tokens_to_add.ne(eos_token_id).long())
            cur_len = cur_len + 1

            # stop when there is a </s> in each sentence, or if we exceed the maximum length
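In the hunk above, sampling draws one token per row from the softmax while greedy decoding takes the argmax, and unfinished_sents (a 0/1 mask) forces already-finished sequences to keep receiving the pad token. A minimal sketch of both paths and the masking step (tensors are made up):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
next_token_logits = torch.tensor([[0.1, 2.0, 0.3],
                                  [0.2, 0.1, 1.8]])  # (batch=2, vocab=3)

# Sampling path: one multinomial draw per sequence from the softmax distribution.
sampled = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1).squeeze(1)
print(sampled)          # one sampled token id per sequence

# Greedy path: plain argmax.
greedy = torch.argmax(next_token_logits, dim=-1)

# Finished sequences (mask 0) receive the pad token instead of a new token.
pad_token_id = 0
unfinished_sents = torch.tensor([1, 0])  # second sequence already produced an EOS
tokens_to_add = greedy * unfinished_sents + pad_token_id * (1 - unfinished_sents)
print(tokens_to_add)    # tensor([1, 0]): real token for seq 0, pad token for seq 1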
@@ -621,13 +633,24 @@ class PreTrainedModel(nn.Module):

        # add eos_token_ids to unfinished sentences
        if cur_len == max_length:
-            input_ids[:, -1].masked_fill_(unfinished_sents.byte(), eos_token_ids[0])
+            input_ids[:, -1].masked_fill_(unfinished_sents.to(dtype=torch.bool), eos_token_ids[0])
+
+        if num_return_sequences != 1:
+            input_ids = input_ids.view(batch_size, num_return_sequences, -1)

        return input_ids

-    def _generate_beam_search(self, input_ids, cur_len, max_length, do_sample, length_penalty,
-                              num_beams, pad_token_id, eos_token_ids, vocab_size, batch_size):
-        """ Generate a sentence with beam search. """
+    def _generate_beam_search(self, input_ids, cur_len, max_length, do_sample,
+                              temperature, top_k, top_p, repetition_penalty,
+                              pad_token_id, eos_token_ids, batch_size,
+                              num_return_sequences,
+                              length_penalty, num_beams, vocab_size):
+        """ Generate `num_return_sequences` sequences per batch example with beam search.
+            We return the top-`num_return_sequences` beams.
+            `num_return_sequences` should not be bigger than `num_beams` (we default to the min of the two).
+        """
+        num_return_sequences = min(num_return_sequences, num_beams)

        # Expand input to num beams
        input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
        input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len)  # (batch_size * num_beams, cur_len)
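As with num_return_sequences above, beam search replicates each prompt num_beams times; setting every beam score except the first to -1e9 keeps the first top-k step from selecting num_beams identical continuations. A small sketch of the expansion and score initialization:

import torch

batch_size, num_beams, cur_len = 2, 3, 4

# Prompts are replicated num_beams times, exactly as in the expansion above.
input_ids = torch.arange(batch_size * cur_len).view(batch_size, cur_len)
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len)
print(input_ids.shape)   # torch.Size([6, 4])

# All beams start from the same prompt, so only beam 0 gets score 0; the rest
# get -1e9 so they do not all propose the same first token.
beam_scores = torch.zeros((batch_size, num_beams))
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view(-1)  # (batch_size * num_beams,)
print(beam_scores)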
@@ -638,7 +661,7 @@ class PreTrainedModel(nn.Module):

        # scores for each sentence in the beam
        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores[:, 1:] = -1e9
-        beam_scores = beam_scores.view(-1)
+        beam_scores = beam_scores.view(-1)  # shape (batch_size * num_beams,)

        # cache compute states
        pasts = None  # self.prepare_pasts()
@@ -650,16 +673,38 @@ class PreTrainedModel(nn.Module):
            model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
            scores = self(**model_inputs)[0]  # (batch_size * num_beams, cur_len, vocab_size)
            scores = scores[:, -1, :]         # (batch_size * num_beams, vocab_size)

+            # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
+            if repetition_penalty != 1.0:
+                for i in range(batch_size * num_beams):
+                    for previous_tokens in set(input_ids[i].tolist()):
+                        scores[i, previous_tokens] /= repetition_penalty
+
+            if do_sample:
+                # Temperature (higher temperature => more likely to sample low probability tokens)
+                if temperature != 1.0:
+                    scores = scores / temperature
+                # Top-p/top-k filtering
+                scores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)  # (batch_size * num_beams, vocab_size)
+                # Sample 2 next words for each beam (so we have some spare tokens and match output of greedy beam search)
+                next_words = torch.multinomial(F.softmax(scores, dim=-1), num_samples=2)  # (batch_size * num_beams, 2)
+                # Compute next scores
+                _scores = F.log_softmax(scores, dim=-1)  # (batch_size * num_beams, vocab_size)
+                _scores = torch.gather(_scores, -1, next_words)  # (batch_size * num_beams, 2)
+                next_scores = _scores + beam_scores[:, None].expand_as(_scores)  # (batch_size * num_beams, 2)
+                # Match shape of greedy beam search
+                next_words = next_words.view(batch_size, 2 * num_beams)    # (batch_size, 2 * num_beams)
+                next_scores = next_scores.view(batch_size, 2 * num_beams)  # (batch_size, 2 * num_beams)
-            scores = F.log_softmax(scores, dim=-1)  # (batch_size * num_beams, vocab_size)
-            assert scores.size() == (batch_size * num_beams, vocab_size)
-            # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
-            _scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)
-            # re-organize to group the beam together (we are keeping top hypothesis across beams)
-            _scores = _scores.view(batch_size, num_beams * vocab_size)  # (batch_size, num_beams * vocab_size)
-            next_scores, next_words = torch.topk(_scores, 2*num_beams, dim=1, largest=True, sorted=True)
+            else:
+                # do greedy beam search
+                scores = F.log_softmax(scores, dim=-1)  # (batch_size * num_beams, vocab_size)
+                assert scores.size() == (batch_size * num_beams, vocab_size)
+                # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
+                _scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)
+                # re-organize to group the beam together (we are keeping top hypothesis across beams)
+                _scores = _scores.view(batch_size, num_beams * vocab_size)  # (batch_size, num_beams * vocab_size)
+                next_scores, next_words = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)

            assert next_scores.size() == next_words.size() == (batch_size, 2 * num_beams)

            # next batch beam content
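In the sampling branch added above, two candidate tokens are drawn per beam, their log-probabilities are gathered and added to the running beam scores, and the result is reshaped to (batch_size, 2 * num_beams) so it matches the greedy top-k path. A shape-level sketch with random logits:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, num_beams, vocab_size = 2, 2, 5
scores = torch.randn(batch_size * num_beams, vocab_size)
beam_scores = torch.zeros(batch_size * num_beams)

# Draw 2 candidate tokens per beam, add their log-probs to the running beam
# scores, then group candidates per batch example.
next_words = torch.multinomial(F.softmax(scores, dim=-1), num_samples=2)   # (batch*beams, 2)
log_probs = torch.gather(F.log_softmax(scores, dim=-1), -1, next_words)    # (batch*beams, 2)
next_scores = log_probs + beam_scores[:, None].expand_as(log_probs)

print(next_words.view(batch_size, 2 * num_beams).shape)   # torch.Size([2, 4])
print(next_scores.view(batch_size, 2 * num_beams).shape)  # torch.Size([2, 4])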
@@ -733,32 +778,36 @@ class PreTrainedModel(nn.Module):
        #     print("")

        # select the best hypotheses
-        tgt_len = input_ids.new(batch_size)
-        best = []
+        tgt_len = input_ids.new(batch_size, num_return_sequences)
+        bests = []

        for i, hypotheses in enumerate(generated_hyps):
-            best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
-            tgt_len[i] = len(best_hyp) + 1  # +1 for the <EOS> symbol
-            best.append(best_hyp)
+            best_hyps = [hyp[1] for hyp in sorted(hypotheses.hyp, key=lambda hyp: hyp[0])[-num_return_sequences:]]
+            for j, hyp in enumerate(best_hyps):
+                tgt_len[i, j] = len(hyp) + 1  # +1 for the <EOS> symbol
+            bests.append(best_hyps)

        # generate target batch
-        decoded = input_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id)
-        for i, hypo in enumerate(best):
-            decoded[i, :tgt_len[i] - 1] = hypo
-            decoded[i, tgt_len[i] - 1] = eos_token_ids[0]
+        decoded = input_ids.new(batch_size, num_return_sequences, tgt_len.max().item()).fill_(pad_token_id)
+        for i, hyps in enumerate(bests):
+            for j, hypo in enumerate(hyps):
+                decoded[i, j, :tgt_len[i, j] - 1] = hypo
+                decoded[i, j, tgt_len[i, j] - 1] = eos_token_ids[0]
+
+        if num_return_sequences == 1:
+            decoded = decoded.squeeze(1)

        # # sanity check
        # assert (decoded == eos_token_ids[0]).sum() == 2 * batch_size

        return decoded
-def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
+def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (batch size x vocabulary size)
-            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
-            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+            if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
+            if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
@@ -768,7 +817,7 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

-    if top_p > 0.0:
+    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
...
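Only the head of top_k_top_p_filtering is visible in this diff; the sketch below is a self-contained reimplementation for illustration, with the nucleus part reconstructed from the gist linked in the docstring, so treat its details as an approximation rather than the exact library body:

import torch
import torch.nn.functional as F

def top_k_top_p_filter(logits, top_k=0, top_p=1.0, filter_value=-float("Inf")):
    if top_k > 0:
        # Remove every token whose logit is below the k-th largest logit.
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens once the cumulative probability exceeds top_p,
        # always keeping at least the most probable token.
        sorted_remove = cumulative_probs > top_p
        sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
        sorted_remove[..., 0] = False
        remove = sorted_remove.scatter(1, sorted_indices, sorted_remove)
        logits[remove] = filter_value
    return logits

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
print(top_k_top_p_filter(logits.clone(), top_k=2))   # only the two largest logits survive
print(top_k_top_p_filter(logits.clone(), top_p=0.6)) # low-probability tail set to -inf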
# coding=utf-8
import sys
import unittest

import numpy as np
import pytest

from transformers import is_torch_available

if is_torch_available():
    import torch

    from transformers import (
        BertConfig,
        BertModel,
        GPT2Config,
        GPT2LMHeadModel,
        OpenAIGPTConfig,
        OpenAIGPTLMHeadModel,
        TransfoXLConfig,
        TransfoXLLMHeadModel,
        XLMConfig,
        XLMWithLMHeadModel,
        XLNetConfig,
        XLNetLMHeadModel,
        Model2Model,
    )
    from transformers.modeling_utils import Sampler
else:
    pytestmark = pytest.mark.skip("Require Torch")
class SamplerTest(unittest.TestCase):
    def test_nucleus_sampling(self):
        inf = -float("Inf")
        test_cases = (
            {
                "p": 0,
                "logits": torch.tensor([0.3, 0.1, 0.2]),
                "expected": torch.tensor([0.3, 0.1, 0.2]),
            },
            {
                "p": 0.01,
                "logits": torch.tensor([0.3, 0.1, 0.2]),
                "expected": torch.tensor([0.3, inf, inf]),
            },
            {
                "p": 1,
                "logits": torch.tensor([0.3, 0.1, 0.2]),
                "expected": torch.tensor([0.3, 0.1, 0.2]),
            },
            {
                "p": 0.2,
                "logits": torch.tensor([0.7, 0.1, 0.2]),
                "expected": torch.tensor([0.7, inf, inf]),
            },
            {
                "p": 0.71,
                "logits": torch.tensor([0.7, 0.1, 0.2]),
                "expected": torch.tensor([0.7, inf, 0.2]),
            },
            {
                "p": 0.71,
                "logits": torch.tensor([0.1, 0.7, 0.2]),
                "expected": torch.tensor([inf, 0.7, 0.2]),
            },
            {
                "p": 0.71,
                "logits": torch.tensor([0.7, 0.2, 0.1]),
                "expected": torch.tensor([0.7, 0.2, inf]),
            },
            {
                "p": 0.91,
                "logits": torch.tensor([0.7, 0.1, 0.2]),
                "expected": torch.tensor([0.7, 0.1, 0.2]),
            },
        )
        for case in test_cases:
            config = {
                "do_sample": True,
                "temperature": 1.0,
                "k": 0,
                "p": case["p"],
                "repetition_penalty": 1.0,
            }
            sampler = Sampler(**config)
            filtered_logits = sampler.apply_nucleus_filter(case["logits"])
            np.testing.assert_array_equal(case["expected"].numpy(), filtered_logits.numpy())
    def test_top_k_filter(self):
        inf = -float("Inf")
        test_cases = (
            {
                "k": 0,
                "logits": torch.tensor([0.7, 0.1, 0.2]),
                "expected": torch.tensor([0.7, 0.1, 0.2]),
            },
            {
                "k": 1,
                "logits": torch.tensor([0.7, 0.1, 0.2]),
                "expected": torch.tensor([0.7, inf, inf]),
            },
            {
                "k": 2,
                "logits": torch.tensor([0.7, 0.1, 0.2]),
                "expected": torch.tensor([0.7, inf, 0.2]),
            },
            {
                "k": 3,
                "logits": torch.tensor([0.7, 0.1, 0.2]),
                "expected": torch.tensor([0.7, 0.1, 0.2]),
            },
        )
        for case in test_cases:
            config = {
                "do_sample": True,
                "temperature": 1.0,
                "k": case["k"],
                "p": 0,
                "repetition_penalty": 1.0,
            }
            sampler = Sampler(**config)
            filtered_logits = sampler.apply_top_k_filter(case["logits"])
            np.testing.assert_array_equal(case["expected"].numpy(), filtered_logits.numpy())
    @pytest.mark.skipif(sys.version_info < (3, 2), reason="assertWarns() requires Python >= 3.2")
    def test_wrong_k_value(self):
        case = {"k": 10, "vocab_size": 5}
        config = {
            "do_sample": True,
            "temperature": 1.0,
            "k": case["k"],
            "p": 0,
            "repetition_penalty": 1.0,
        }
        sampler = Sampler(**config)
        next_token_logits = torch.rand(case["vocab_size"]).unsqueeze(0)
        past_sequence = torch.tensor([])
        with self.assertWarns(UserWarning):
            _ = sampler.get_one_token(next_token_logits, past_sequence)

    def test_zero_temperature(self):
        temperature = 0
        config = {
            "do_sample": True,
            "temperature": temperature,
            "k": 0,
            "p": 0,
            "repetition_penalty": 1.0,
        }
        sampler = Sampler(**config)
        next_token_logits = torch.rand(10).unsqueeze(0)
        past_sequence = torch.tensor([])
        with self.assertRaises(ZeroDivisionError):
            _ = sampler.get_one_token(next_token_logits, past_sequence)
class SamplerSingleStackTest(unittest.TestCase):
    def test_raises_exception_when_no_LM_head(self):
        models = [BertModel(BertConfig())]
        for model in models:
            with self.assertRaises(AttributeError):
                model.decode()

    @pytest.mark.slow
    def test_forward_pass_and_output_length(self):
        models = {
            "XLNet": XLNetLMHeadModel(XLNetConfig()),
            "XLM": XLMWithLMHeadModel(XLMConfig()),
            "TransfoXL": TransfoXLLMHeadModel(TransfoXLConfig()),
            "GPT2": GPT2LMHeadModel(GPT2Config()),
            "GPT": OpenAIGPTLMHeadModel(OpenAIGPTConfig()),
        }
        kwargs = {
            "XLNet": {},
            "XLM": {"mask_token": 0},
            "TransfoXL": {},
            "GPT2": {},
            "GPT": {},
        }
        prompt = torch.tensor([[1, 2, 3]], dtype=torch.long)
        generated_length = 5
        expected_length = 8
        for name, model in models.items():
            kwargs_model = kwargs[name]
            output = model.decode(prompt_ids=prompt, length=generated_length, **kwargs_model)
            self.assertEqual(len(output), expected_length)


class SamplerEncoderDecoderTest(unittest.TestCase):
    @pytest.mark.slow
    def test_forward_pass_and_output_length(self):
        model = Model2Model.from_pretrained("bert-base-uncased")
        encoder_input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
        prompt = torch.tensor([[1, 2, 3]], dtype=torch.long)
        generated_length = 5
        expected_length = 8
        output = model.decode(
            encoder_input_ids,
            decoder_prompt_ids=prompt,
            k=2,
            p=0.5,
            repetition_penalty=2,
            length=generated_length,
        )
        self.assertEqual(len(output), expected_length)


if __name__ == "__main__":
    unittest.main()
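Putting the pieces together, a hedged usage sketch of the generate() signature introduced in this commit; the gpt2 checkpoint, prompt, and sampling settings are arbitrary illustrative choices, and with num_return_sequences != 1 the output is reshaped to (batch_size, num_return_sequences, sequence_length) as in the no-beam-search hunk above:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Illustrative only: any checkpoint with a LM head should work the same way.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

input_ids = torch.tensor([tokenizer.encode("The commit adds")], dtype=torch.long)
outputs = model.generate(input_ids, max_length=20, do_sample=True,
                         top_k=50, top_p=0.95, num_return_sequences=3)

# outputs has shape (batch_size, num_return_sequences, sequence_length) here.
for sequence in outputs[0]:
    print(tokenizer.decode(sequence.tolist()))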