Unverified Commit 5e069633 authored by Yacine Jernite, committed by GitHub

Some changes to simplify the generation function (#5031)

* moving logits post-processing out of beam search

* moving logits post-processing out of beam search

* first step cache

* fix_Encoder_Decoder

* patrick_version_postprocess

* add_keyword_arg
parent 204ebc25
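
The hunks below make two coordinated changes: `generate()` now always seeds `past` for encoder-decoder models as a two-element tuple `(encoder_outputs, None)`, so each model's `prepare_inputs_for_generation` can unpack it without branching on "first step or not", and the per-step logits post-processing is factored into a shared `postprocess_next_token_scores` helper used by both the greedy/sampling and beam-search loops. A minimal standalone sketch of the new cache convention (the placeholder values are illustrative, not the library API):

```python
# Hypothetical sketch of the cache convention introduced by this commit.
encoder_outputs = ("dummy_encoder_hidden_states",)  # stands in for the real encoder outputs

# generate() seeds the cache like this; decoder-only models keep past = None.
past = (encoder_outputs, None) if encoder_outputs is not None else None

# Every decoding step can now unpack unconditionally.
encoder_outputs, decoder_cache = past
assert decoder_cache is None  # only true before the first decoder forward pass
```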
@@ -983,10 +983,6 @@ class BartForConditionalGeneration(PretrainedBartModel):
     def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs):
         assert past is not None, "past has to be defined for encoder_outputs"

-        # first step, decoder_cached_states are empty
-        if not past[1]:
-            encoder_outputs, decoder_cached_states = past, None
-        else:
-            encoder_outputs, decoder_cached_states = past
+        encoder_outputs, decoder_cached_states = past
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
......
@@ -298,7 +298,7 @@ class EncoderDecoderModel(PreTrainedModel):
         # first step
         if type(past) is tuple:
-            encoder_outputs = past
+            encoder_outputs, _ = past
         else:
             encoder_outputs = (past,)
......
@@ -1139,11 +1139,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
     def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, **kwargs):
         assert past is not None, "past has to be defined for encoder_outputs"

-        # first step
-        if len(past) < 2:
-            encoder_outputs, decoder_past_key_value_states = past, None
-        else:
-            encoder_outputs, decoder_past_key_value_states = past[0], past[1]
+        encoder_outputs, decoder_past_key_value_states = past

         return {
             "decoder_input_ids": input_ids,
@@ -1156,7 +1152,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
     def _reorder_cache(self, past, beam_idx):
         # if decoder past is not included in output
         # speedy decoding is disabled and no need to reorder
-        if len(past) < 2:
+        if past[1] is None:
             logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
             return past
......
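
Because `past` is now always the pair `(encoder_outputs, decoder_past)`, the "cache disabled" case is detected by checking whether the second slot is `None` rather than by the tuple length. A small hedged sketch of that check, with made-up placeholder values:

```python
# Sketch of the new _reorder_cache early-exit condition.
past_without_cache = (("encoder_states",), None)
past_with_cache = (("encoder_states",), ("layer_0_keys_and_values",))

for past in (past_without_cache, past_with_cache):
    if past[1] is None:
        print("use_cache is off: nothing to reorder")
    else:
        print("reorder the decoder cache along the beam dimension")
```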
@@ -813,6 +813,49 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             else:
                 lprobs[i, previous_token] /= repetition_penalty

+    def postprocess_next_token_scores(
+        self,
+        scores,
+        input_ids,
+        no_repeat_ngram_size,
+        bad_words_ids,
+        cur_len,
+        min_length,
+        max_length,
+        eos_token_id,
+        repetition_penalty,
+        batch_size,
+        num_beams,
+    ):
+        # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
+        if repetition_penalty != 1.0:
+            self.enforce_repetition_penalty_(
+                scores, batch_size, num_beams, input_ids, repetition_penalty,
+            )
+
+        # set eos token prob to zero if min_length is not reached
+        if eos_token_id is not None and cur_len < min_length:
+            scores[:, eos_token_id] = -float("inf")
+
+        if no_repeat_ngram_size > 0:
+            # calculate a list of banned tokens to prevent repetitively generating the same ngrams
+            num_batch_hypotheses = batch_size * num_beams
+            # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
+            banned_batch_tokens = calc_banned_ngram_tokens(
+                input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
+            )
+            for i, banned_tokens in enumerate(banned_batch_tokens):
+                scores[i, banned_tokens] = -float("inf")
+
+        if bad_words_ids is not None:
+            # calculate a list of banned tokens according to bad words
+            banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)
+
+            for i, banned_tokens in enumerate(banned_tokens):
+                scores[i, banned_tokens] = -float("inf")
+
+        return scores
+
     @torch.no_grad()
     def generate(
         self,
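
The new `postprocess_next_token_scores` bundles the repetition penalty, min-length EOS masking, n-gram blocking, and bad-word filtering that previously lived inline in both generation loops. A rough usage sketch, assuming a GPT-2 checkpoint is available and using tuple-style model outputs as in this version of the library (argument values are illustrative):

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer.encode("The quick brown fox", return_tensors="pt")
with torch.no_grad():
    logits = model(input_ids)[0][:, -1, :]  # next-token logits, shape (batch, vocab)

# Apply the shared post-processing the same way generate() now does internally.
scores = model.postprocess_next_token_scores(
    scores=logits,
    input_ids=input_ids,
    no_repeat_ngram_size=2,
    bad_words_ids=None,
    cur_len=input_ids.shape[-1],
    min_length=10,
    max_length=50,
    eos_token_id=tokenizer.eos_token_id,
    repetition_penalty=1.2,
    batch_size=input_ids.shape[0],
    num_beams=1,
)
```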
@@ -1222,7 +1265,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         unfinished_sents = input_ids.new(batch_size).fill_(1)
         sent_lengths = input_ids.new(batch_size).fill_(max_length)

-        past = encoder_outputs  # defined for encoder-decoder models, None for decoder-only models
+        past = (encoder_outputs, None) if encoder_outputs is not None else None

         while cur_len < max_length:
             model_inputs = self.prepare_inputs_for_generation(
@@ -1232,40 +1275,32 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             outputs = self(**model_inputs)
             next_token_logits = outputs[0][:, -1, :]

+            scores = self.postprocess_next_token_scores(
+                scores=next_token_logits,
+                input_ids=input_ids,
+                no_repeat_ngram_size=no_repeat_ngram_size,
+                bad_words_ids=bad_words_ids,
+                cur_len=cur_len,
+                min_length=min_length,
+                max_length=max_length,
+                eos_token_id=eos_token_id,
+                repetition_penalty=repetition_penalty,
+                batch_size=batch_size,
+                num_beams=1,
+            )
+
             # if model has past, then set the past variable to speed up decoding
             if self._use_cache(outputs, use_cache):
                 past = outputs[1]

-            # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
-            if repetition_penalty != 1.0:
-                self.enforce_repetition_penalty_(next_token_logits, batch_size, 1, input_ids, repetition_penalty)
-
-            if no_repeat_ngram_size > 0:
-                # calculate a list of banned tokens to prevent repetitively generating the same ngrams
-                # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
-                banned_tokens = calc_banned_ngram_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
-                for batch_idx in range(batch_size):
-                    next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")
-
-            if bad_words_ids is not None:
-                # calculate a list of banned tokens according to bad words
-                banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)
-                for batch_idx in range(batch_size):
-                    next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")
-
-            # set eos token prob to zero if min_length is not reached
-            if eos_token_id is not None and cur_len < min_length:
-                next_token_logits[:, eos_token_id] = -float("inf")
-
             if do_sample:
                 # Temperature (higher temperature => more likely to sample low probability tokens)
                 if temperature != 1.0:
-                    next_token_logits = next_token_logits / temperature
+                    scores = scores / temperature
                 # Top-p/top-k filtering
-                next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
+                next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)
                 # Sample
-                probs = F.softmax(next_token_logits, dim=-1)
+                probs = F.softmax(next_token_logscores, dim=-1)
                 next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
             else:
                 # Greedy decoding
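
In the sampling branch, the renaming to `scores` and `next_token_logscores` simply threads the post-processed scores through temperature scaling, top-k/top-p filtering, and multinomial sampling. A compact sketch of that pipeline on dummy logits, assuming `top_k_top_p_filtering` can be imported from `transformers.modeling_utils` (the module this diff touches):

```python
import torch
import torch.nn.functional as F
from transformers.modeling_utils import top_k_top_p_filtering

scores = torch.randn(2, 100)  # post-processed next-token scores for a batch of 2

temperature, top_k, top_p = 0.7, 50, 0.95
if temperature != 1.0:
    scores = scores / temperature

# Keep only top-k / nucleus tokens; everything else is set to -inf.
next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)

# Sample one token per sequence from the filtered distribution.
probs = F.softmax(next_token_logscores, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
```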
@@ -1300,18 +1335,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                     [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                 )

-        # if there are different sentences lengths in the batch, some batches have to be padded
-        if sent_lengths.min().item() != sent_lengths.max().item():
-            assert pad_token_id is not None, "`Pad_token_id` has to be defined if batches have different lengths"
-            # finished sents are filled with pad_token
-            decoded = input_ids.new(batch_size, sent_lengths.max().item()).fill_(pad_token_id)
-        else:
-            decoded = input_ids
-
-        for hypo_idx, hypo in enumerate(input_ids):
-            decoded[hypo_idx, : sent_lengths[hypo_idx]] = hypo[: sent_lengths[hypo_idx]]
-
-        return decoded
+        return input_ids
     def _generate_beam_search(
         self,
@@ -1357,7 +1381,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         beam_scores = beam_scores.view(-1)  # shape (batch_size * num_beams,)

         # cache compute states
-        past = encoder_outputs  # defined for encoder-decoder models, None for decoder-only models
+        past = (encoder_outputs, None) if encoder_outputs is not None else None

         # done sentences
         done = [False for _ in range(batch_size)]
@@ -1373,43 +1397,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             if self._use_cache(outputs, use_cache):
                 past = outputs[1]

-            # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
-            if repetition_penalty != 1.0:
-                self.enforce_repetition_penalty_(
-                    next_token_logits, batch_size, num_beams, input_ids, repetition_penalty,
-                )
-
-            if temperature != 1.0:
-                next_token_logits = next_token_logits / temperature
-
-            if self.config.is_encoder_decoder and do_sample is False:
-                # TODO (PVP) still a bit hacky here - there might be a better solution
-                next_token_logits = self.prepare_logits_for_generation(
-                    next_token_logits, cur_len=cur_len, max_length=max_length
-                )
-
             scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)

-            # set eos token prob to zero if min_length is not reached
-            if eos_token_id is not None and cur_len < min_length:
-                scores[:, eos_token_id] = -float("inf")
-
-            if no_repeat_ngram_size > 0:
-                # calculate a list of banned tokens to prevent repetitively generating the same ngrams
-                num_batch_hypotheses = batch_size * num_beams
-                # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
-                banned_batch_tokens = calc_banned_ngram_tokens(
-                    input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
-                )
-                for i, banned_tokens in enumerate(banned_batch_tokens):
-                    scores[i, banned_tokens] = -float("inf")
-
-            if bad_words_ids is not None:
-                # calculate a list of banned tokens according to bad words
-                banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)
-
-                for i, banned_tokens in enumerate(banned_tokens):
-                    scores[i, banned_tokens] = -float("inf")
+            scores = self.postprocess_next_token_scores(
+                scores=scores,
+                input_ids=input_ids,
+                no_repeat_ngram_size=no_repeat_ngram_size,
+                bad_words_ids=bad_words_ids,
+                cur_len=cur_len,
+                min_length=min_length,
+                max_length=max_length,
+                eos_token_id=eos_token_id,
+                repetition_penalty=repetition_penalty,
+                batch_size=batch_size,
+                num_beams=num_beams,
+            )
+
+            if self.config.is_encoder_decoder and do_sample is False:
+                # TODO (PVP) still a bit hacky here - there might be a better solution
+                scores = self.prepare_logits_for_generation(scores, cur_len=cur_len, max_length=max_length)

             assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
                 scores.shape, (batch_size * num_beams, vocab_size)
@@ -1417,6 +1423,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):

             if do_sample:
                 _scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)
+                # Temperature
+                if temperature != 1.0:
+                    _scores = _scores / temperature
                 # Top-p/top-k filtering
                 _scores = top_k_top_p_filtering(
                     _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
......
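
With post-processing factored out, the beam-sampling branch now applies temperature to `_scores` (the per-token log-probabilities plus the running beam scores) rather than to the raw logits. A small sketch of that ordering, with made-up shapes and values:

```python
import torch
import torch.nn.functional as F

batch_size, num_beams, vocab_size = 1, 3, 10
temperature = 0.8

next_token_logits = torch.randn(batch_size * num_beams, vocab_size)
beam_scores = torch.zeros(batch_size * num_beams)

scores = F.log_softmax(next_token_logits, dim=-1)           # (batch_size * num_beams, vocab_size)
_scores = scores + beam_scores[:, None].expand_as(scores)   # add running beam scores first
if temperature != 1.0:
    _scores = _scores / temperature                          # temperature now applied here
```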