Unverified Commit 2fdc7f6c authored by Patrick von Platen, committed by GitHub

correct greedy generation when doing beam search (#3078)

* correct greedy generation when doing beam search

* improve comment
parent 13afb712
@@ -754,6 +754,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         else:
             assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
 
+        if do_sample is False:
+            if num_beams == 1:
+                # no_beam_search greedy generation conditions
+                assert (
+                    num_return_sequences == 1
+                ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"
+            else:
+                # beam_search greedy generation conditions
+                assert (
+                    num_beams >= num_return_sequences
+                ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
+
         if pad_token_id is None and eos_token_ids is not None:
             logger.warning(
                 "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
@@ -764,7 +777,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         cur_len = input_ids.shape[1]
         vocab_size = self.config.vocab_size
 
-        if num_return_sequences != 1:
+        if num_return_sequences != 1 and do_sample:
             # Expand input to num return sequences
             input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
             input_ids = input_ids.contiguous().view(
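Gating this expansion on do_sample avoids duplicating the batch twice: for greedy beam search, the fan-out into num_return_sequences copies now happens when hypotheses are selected inside _generate_beam_search instead. A small sketch of the shape manipulation itself, using toy dimensions:

import torch

batch_size, cur_len, num_return_sequences = 2, 4, 3
input_ids = torch.arange(batch_size * cur_len).view(batch_size, cur_len)

# (batch_size, cur_len) -> (batch_size, num_return_sequences, cur_len) -> flat batch
expanded = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
flat = expanded.contiguous().view(batch_size * num_return_sequences, cur_len)

assert flat.shape == (batch_size * num_return_sequences, cur_len)
# Each prompt now appears num_return_sequences times in a row.
assert torch.equal(flat[0], flat[1]) and torch.equal(flat[1], flat[2])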
@@ -787,6 +800,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             pad_token_id,
             eos_token_ids,
             effective_batch_size,
+            num_return_sequences,
             length_penalty,
             num_beams,
             vocab_size,
@@ -826,6 +840,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             All returned sequences are generated independently.
         """
+
         # current position / max lengths / length of generated sentences / unfinished sentences
         unfinished_sents = input_ids.new(batch_size).fill_(1)
         sent_lengths = input_ids.new(batch_size).fill_(max_length)
@@ -906,12 +921,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         pad_token_id,
         eos_token_ids,
         batch_size,
+        num_return_sequences,
         length_penalty,
         num_beams,
         vocab_size,
     ):
         """ Generate sequences for each example with beam search.
         """
+
         # Expand input to num beams
         # assert input_ids.shape == (batch_size * num_beams, cur_len)
         input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
@@ -1057,20 +1074,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                         input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item()
                     )
 
+        # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
+        output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
+        output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
+
         # select the best hypotheses
-        sent_lengths = input_ids.new(batch_size)
+        sent_lengths = input_ids.new(output_batch_size)
         best = []
 
+        # retrieve best hypotheses
         for i, hypotheses in enumerate(generated_hyps):
-            best_hyp = max(hypotheses.beams, key=lambda x: x[0])[1]
-            sent_lengths[i] = len(best_hyp)
-            best.append(best_hyp)
+            sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
+            for j in range(output_num_return_sequences_per_batch):
+                effective_batch_idx = output_num_return_sequences_per_batch * i + j
+                best_hyp = sorted_hyps.pop()[1]
+                sent_lengths[effective_batch_idx] = len(best_hyp)
+                best.append(best_hyp)
 
         # shorter batches are filled with pad_token
         if sent_lengths.min().item() != sent_lengths.max().item():
             assert pad_token_id is not None, "`Pad_token_id` has to be defined"
             sent_max_len = min(sent_lengths.max().item() + 1, max_length)
-            decoded = input_ids.new(batch_size, sent_max_len).fill_(pad_token_id)
+            decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)
 
         # fill with hypothesis and eos_token_id if necessary
         for i, hypo in enumerate(best):
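Where the old code kept only the single best hypothesis per example (a max over beams), the new loop sorts each example's beams by score and pops the top output_num_return_sequences_per_batch hypotheses, writing them to consecutive slots of the enlarged output batch. The same pop-from-sorted pattern on toy (score, sequence) pairs:

# Toy beams for two examples: (score, sequence) pairs, unsorted.
generated_hyps = [
    [(-1.2, "a1"), (-0.3, "a2"), (-2.0, "a3")],
    [(-0.5, "b1"), (-1.8, "b2"), (-0.1, "b3")],
]
num_return_sequences_per_batch = 2

best = []
for hyps in generated_hyps:
    # Ascending sort by score; pop() then yields hypotheses best-first.
    sorted_hyps = sorted(hyps, key=lambda x: x[0])
    for _ in range(num_return_sequences_per_batch):
        best.append(sorted_hyps.pop()[1])

assert best == ["a2", "a1", "b3", "b1"]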
@@ -621,10 +621,19 @@ class ModelTesterMixin:
             # batch_size = 1, num_beams > 1
             self._check_generated_tokens(model.generate(max_length=5, num_beams=3))
 
+            with self.assertRaises(AssertionError):
+                # generating multiple sequences with greedy (no beam) generation
+                # is not allowed, as it would always produce the same sequences
+                model.generate(input_ids, do_sample=False, num_return_sequences=2)
+
+            with self.assertRaises(AssertionError):
+                # generating more sequences than there are beams is not possible
+                model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)
+
             # batch_size > 1, sample
             self._check_generated_tokens(model.generate(input_ids, num_return_sequences=3))
             # batch_size > 1, greedy
-            self._check_generated_tokens(model.generate(input_ids, do_sample=False, num_return_sequences=3))
+            self._check_generated_tokens(model.generate(input_ids, do_sample=False))
             # batch_size > 1, num_beams > 1, sample
             self._check_generated_tokens(model.generate(input_ids, num_beams=3, num_return_sequences=3,))
             # batch_size > 1, num_beams > 1, greedy
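From the user side, the fix means greedy beam search can now return several distinct hypotheses per input, while impossible configurations fail fast. A hedged usage sketch, assuming a GPT-2 checkpoint (the checkpoint choice and prompt are illustrative):

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer.encode("The dog", return_tensors="pt")

# Greedy beam search: returns the 3 best of 5 beams per input.
outputs = model.generate(input_ids, max_length=20, do_sample=False, num_beams=5, num_return_sequences=3)

# This would now raise an AssertionError: more return sequences than beams.
# model.generate(input_ids, do_sample=False, num_beams=2, num_return_sequences=3)

for seq in outputs:
    print(tokenizer.decode(seq, skip_special_tokens=True))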