"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "1f6f32c24338ad1ff17475b836c7b4505da77714"
Unverified Commit bbabbc16 authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #3135 from patrickvonplaten/refactor_beam_search_generate

Refactoring and bug fixing beam search generate
parents 7ac47bfe e33ed12c
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
# limitations under the License. # limitations under the License.
"""PyTorch BERT model.""" """PyTorch BERT model."""
import logging import logging
import os import os
import typing import typing
...@@ -758,6 +757,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -758,6 +757,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
else: else:
assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)." assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
# not allow to duplicate outputs when greedy decoding
if do_sample is False: if do_sample is False:
if num_beams == 1: if num_beams == 1:
# no_beam_search greedy generation conditions # no_beam_search greedy generation conditions
...@@ -781,15 +781,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -781,15 +781,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
cur_len = input_ids.shape[1] cur_len = input_ids.shape[1]
vocab_size = self.config.vocab_size vocab_size = self.config.vocab_size
if num_return_sequences != 1 and do_sample: # set effective batch size and effective batch multiplier according to do_sample
# Expand input to num return sequences if do_sample:
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
input_ids = input_ids.contiguous().view(
batch_size * num_return_sequences, cur_len
) # shape: (batch_size * num_return_sequences, cur_len)
effective_batch_size = batch_size * num_return_sequences effective_batch_size = batch_size * num_return_sequences
effective_batch_mult = num_return_sequences
else: else:
effective_batch_size = batch_size effective_batch_size = batch_size
effective_batch_mult = 1
# Expand input ids if num_beams > 1 or num_return_sequences > 1
if num_return_sequences > 1 or num_beams > 1:
input_ids_len = input_ids.shape[-1]
input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
input_ids = input_ids.contiguous().view(
effective_batch_size * num_beams, input_ids_len
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
if num_beams > 1: if num_beams > 1:
output = self._generate_beam_search( output = self._generate_beam_search(
...@@ -892,12 +898,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -892,12 +898,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# unfinished_sents is set to zero if eos in sentence # unfinished_sents is set to zero if eos in sentence
unfinished_sents.mul_((~eos_in_sents).long()) unfinished_sents.mul_((~eos_in_sents).long())
cur_len = cur_len + 1
# stop when there is a </s> in each sentence, or if we exceed the maximul length # stop when there is a </s> in each sentence, or if we exceed the maximul length
if unfinished_sents.max() == 0: if unfinished_sents.max() == 0:
break break
cur_len = cur_len + 1
# if there are different sentences lengths in the batch, some batches have to be padded # if there are different sentences lengths in the batch, some batches have to be padded
if sent_lengths.min().item() != sent_lengths.max().item(): if sent_lengths.min().item() != sent_lengths.max().item():
assert pad_token_id is not None, "`Pad_token_id` has to be defined if batches have different lengths" assert pad_token_id is not None, "`Pad_token_id` has to be defined if batches have different lengths"
...@@ -932,10 +938,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -932,10 +938,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
""" Generate sequences for each example with beam search. """ Generate sequences for each example with beam search.
""" """
# Expand input to num beams
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len)
# generated hypotheses # generated hypotheses
generated_hyps = [ generated_hyps = [
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size) BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
...@@ -943,7 +945,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -943,7 +945,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# scores for each sentence in the beam # scores for each sentence in the beam
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
# Greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times # Greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
if do_sample is False: if do_sample is False:
beam_scores[:, 1:] = -1e9 beam_scores[:, 1:] = -1e9
...@@ -996,6 +997,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -996,6 +997,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# Compute next scores # Compute next scores
next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2) next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2)
# sort the sampled vector to make sure that the first num_beams samples are the best
next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
next_tokens = torch.gather(next_tokens, -1, next_scores_indices) # (batch_size, num_beams * 2)
else: else:
# do greedy beam search # do greedy beam search
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size) scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
...@@ -1006,6 +1010,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -1006,6 +1010,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
next_scores = next_scores.view( next_scores = next_scores.view(
batch_size, num_beams * vocab_size batch_size, num_beams * vocab_size
) # (batch_size, num_beams * vocab_size) ) # (batch_size, num_beams * vocab_size)
next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True) next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams) assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)
...@@ -1041,14 +1046,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -1041,14 +1046,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
beam_id = idx // vocab_size beam_id = idx // vocab_size
token_id = idx % vocab_size token_id = idx % vocab_size
# add to generated hypotheses if end of sentence or last iteration effective_beam_id = batch_idx * num_beams + beam_id
# add to generated hypotheses if end of sentence
if eos_token_ids is not None and token_id.item() in eos_token_ids: if eos_token_ids is not None and token_id.item() in eos_token_ids:
generated_hyps[batch_idx].add( generated_hyps[batch_idx].add(
input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item(), input_ids[effective_beam_id].clone(), score.item(),
) )
else: else:
# add next predicted word if it is not eos_token # add next predicted word if it is not eos_token
next_sent_beam.append((score, token_id, batch_idx * num_beams + beam_id)) next_sent_beam.append((score, token_id, effective_beam_id))
# the beam for next step is full # the beam for next step is full
if len(next_sent_beam) == num_beams: if len(next_sent_beam) == num_beams:
...@@ -1073,25 +1079,35 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -1073,25 +1079,35 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
if past: if past:
past = self._reorder_cache(past, beam_idx) past = self._reorder_cache(past, beam_idx)
# update current length
cur_len = cur_len + 1
# stop when we are done with each sentence # stop when we are done with each sentence
if all(done): if all(done):
break break
# update current length
cur_len = cur_len + 1
# finalize all open beam hypotheses and end to generated hypotheses
for batch_idx in range(batch_size): for batch_idx in range(batch_size):
# Add all open beam hypothesis to generated_hyps if done[batch_idx]:
if not done[batch_idx]: continue
for idx, score in zip(next_tokens[batch_idx], next_scores[batch_idx]):
# get beam and word IDs # test that beam scores match previously calculated scores if not eos and batch_idx not done
beam_id = idx // vocab_size if eos_token_ids is not None and all(
token_id = idx % vocab_size (token_id % vocab_size).item() not in eos_token_ids for token_id in next_tokens[batch_idx]
generated_hyps[batch_idx].add( ):
input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item() assert torch.all(
next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
next_scores[:, :num_beams][batch_idx], beam_scores.view(batch_size, num_beams)[batch_idx]
) )
# need to add best num_beams hypotheses to generated hyps
for beam_id in range(num_beams):
effective_beam_id = batch_idx * num_beams + beam_id
final_score = beam_scores[effective_beam_id].item()
final_tokens = input_ids[effective_beam_id]
generated_hyps[batch_idx].add(final_tokens, final_score)
# depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
output_batch_size = batch_size if do_sample else batch_size * num_return_sequences output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment