"vscode:/vscode.git/clone" did not exist on "aa4b2cc36fedcaa4bfacdaad4c6a878fc82b9860"
Commit c47394b0 authored by Patrick von Platen

refactoring and bug fixing beam search generate

parent ff9e79ba
@@ -15,11 +15,11 @@
 # limitations under the License.
 """PyTorch BERT model."""
 import logging
 import os
 import typing
-import ipdb
 import torch
 from torch import nn
 from torch.nn import CrossEntropyLoss
@@ -758,6 +758,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         else:
             assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
+        # not allow to duplicate outputs when greedy decoding
         if do_sample is False:
             if num_beams == 1:
                 # no_beam_search greedy generation conditions
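Note on the comment added above: greedy decoding (`do_sample=False` with a single beam) is deterministic, so requesting several return sequences for one input can only produce identical copies. A hypothetical sketch of the kind of argument check this comment points at; the function name and the exact message are assumptions for illustration, not code from this commit:

```python
# Hypothetical sketch: greedy decoding with one beam is deterministic, so more
# than one return sequence per input would just duplicate the same output.
def check_greedy_return_sequences(do_sample: bool, num_beams: int, num_return_sequences: int) -> None:
    if do_sample is False and num_beams == 1:
        assert num_return_sequences == 1, (
            "Greedy decoding with a single beam is deterministic; "
            "num_return_sequences > 1 would only duplicate outputs."
        )
```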
@@ -781,15 +782,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         cur_len = input_ids.shape[1]
         vocab_size = self.config.vocab_size
-        if num_return_sequences != 1 and do_sample:
-            # Expand input to num return sequences
-            input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
-            input_ids = input_ids.contiguous().view(
-                batch_size * num_return_sequences, cur_len
-            )  # shape: (batch_size * num_return_sequences, cur_len)
+        # set effective batch size and effective batch multiplier according to do_sample
+        if do_sample:
             effective_batch_size = batch_size * num_return_sequences
+            effective_batch_mult = num_return_sequences
         else:
             effective_batch_size = batch_size
+            effective_batch_mult = 1
+        # Expand input ids if num_beams > 1 or num_return_sequences > 1
+        if num_return_sequences > 1 or num_beams > 1:
+            input_ids_len = input_ids.shape[-1]
+            input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
+            input_ids = input_ids.contiguous().view(
+                effective_batch_size * num_beams, input_ids_len
+            )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
         if num_beams > 1:
             output = self._generate_beam_search(
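The refactor above moves all input expansion into `generate()`: each prompt is repeated `effective_batch_mult * num_beams` times, where `effective_batch_mult` is `num_return_sequences` when sampling and 1 otherwise. A minimal sketch of the resulting shapes; the toy sizes are assumptions for the example:

```python
import torch

# Toy illustration of the expansion done in generate(); all sizes here are made up.
batch_size, cur_len = 2, 3
num_beams, num_return_sequences, do_sample = 2, 3, True

input_ids = torch.arange(batch_size * cur_len).view(batch_size, cur_len)

effective_batch_mult = num_return_sequences if do_sample else 1
effective_batch_size = batch_size * effective_batch_mult

# every prompt row is repeated effective_batch_mult * num_beams times, keeping the
# copies of one example adjacent so its beams/samples stay grouped together
input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, cur_len)
input_ids = input_ids.contiguous().view(effective_batch_size * num_beams, cur_len)

print(input_ids.shape)  # torch.Size([12, 3]) == (batch_size * num_return_sequences * num_beams, cur_len)
```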
@@ -892,12 +899,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             # unfinished_sents is set to zero if eos in sentence
             unfinished_sents.mul_((~eos_in_sents).long())
+            cur_len = cur_len + 1
             # stop when there is a </s> in each sentence, or if we exceed the maximul length
             if unfinished_sents.max() == 0:
                 break
-            cur_len = cur_len + 1
         # if there are different sentences lengths in the batch, some batches have to be padded
         if sent_lengths.min().item() != sent_lengths.max().item():
             assert pad_token_id is not None, "`Pad_token_id` has to be defined if batches have different lengths"
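The fix above moves the `cur_len` increment ahead of the early break, so the token generated in the final step is still counted when every sentence finishes at once. A generic sketch of the loop shape (not the library code) showing the off-by-one this avoids:

```python
# Generic sketch: if the counter is bumped only after the stop check, the token
# produced in the very last step is not reflected in cur_len when all sequences
# finish simultaneously.
cur_len = 1
max_length = 5

while cur_len < max_length:
    # ... generate one token per sequence here ...
    finished = True  # pretend every sequence emitted EOS in this step

    cur_len = cur_len + 1  # count the token we just generated *before* breaking

    if finished:
        break

print(cur_len)  # 2: the freshly generated token is included in the length
```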
@@ -932,10 +939,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         """ Generate sequences for each example with beam search.
         """
-        # Expand input to num beams
-        input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
-        input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len)  # (batch_size * num_beams, cur_len)
         # generated hypotheses
         generated_hyps = [
             BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
@@ -945,7 +948,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
         # Greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
-        if do_sample is False:
+        # if do_sample is False:
         beam_scores[:, 1:] = -1e9
         beam_scores = beam_scores.view(-1)  # shape (batch_size * num_beams,)
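About the masked beam scores: at step 0 every beam holds the same prompt, so without pushing the scores of beams 1..num_beams-1 down to -1e9 the top-k step would pick the same tokens once per beam and all beams would stay identical. A toy illustration of the effect, with made-up sizes:

```python
import torch

# Toy illustration: why all but the first beam start with a score of -1e9.
batch_size, num_beams, vocab_size = 1, 3, 10

beam_scores = torch.zeros((batch_size, num_beams))
beam_scores[:, 1:] = -1e9  # only beam 0 can contribute candidates at the first step

# at step 0 every beam holds the identical prompt, so their next-token log-probs are identical too
log_probs = torch.log_softmax(torch.randn(vocab_size), dim=-1).repeat(batch_size * num_beams, 1)

next_scores = log_probs + beam_scores.view(-1, 1)                    # (num_beams, vocab_size)
next_scores = next_scores.view(batch_size, num_beams * vocab_size)
top_scores, top_ids = torch.topk(next_scores, 2 * num_beams, dim=1)

# with the mask, all 2 * num_beams candidates come from beam 0 and are distinct tokens;
# without it, the best tokens would simply be repeated once per (identical) beam
print(top_ids // vocab_size)  # tensor([[0, 0, 0, 0, 0, 0]])
```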
@@ -996,6 +999,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 # Compute next scores
                 next_scores = torch.gather(_scores, -1, next_tokens)  # (batch_size, num_beams * 2)
+                # sort the sampled vector to make sure that the first num_beams samples are the best
+                next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
+                next_tokens = torch.gather(next_tokens, -1, next_scores_indices)  # (batch_size, num_beams * 2)
             else:
                 # do greedy beam search
                 scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
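In the sampling branch, `torch.multinomial` returns candidate tokens in no particular score order, while the beam bookkeeping later assumes the first `num_beams` entries per example are the best ones. The added sort restores that invariant; a small self-contained sketch, where the shapes and seed are assumptions:

```python
import torch

# Toy illustration of the added sort: sampled candidates come back in arbitrary
# order, but downstream code expects the best-scoring ones first.
torch.manual_seed(0)
batch_size, num_beams, vocab_size = 1, 2, 8

_scores = torch.log_softmax(torch.randn(batch_size, num_beams * vocab_size), dim=-1)

# sample 2 * num_beams candidate ids per example (no particular score order)
next_tokens = torch.multinomial(_scores.exp(), num_samples=2 * num_beams)
next_scores = torch.gather(_scores, -1, next_tokens)             # (batch_size, 2 * num_beams)

# sort so that the first num_beams samples are the best, mirroring the diff above
next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
next_tokens = torch.gather(next_tokens, -1, next_scores_indices)

print(next_scores)  # now monotonically decreasing along dim=1
```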
@@ -1006,6 +1012,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 next_scores = next_scores.view(
                     batch_size, num_beams * vocab_size
                 )  # (batch_size, num_beams * vocab_size)
                 next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
             assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)
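Flattening the scores to `(batch_size, num_beams * vocab_size)` lets a single top-k rank continuations across all beams of one example at once; asking for `2 * num_beams` candidates leaves head-room so EOS candidates can close hypotheses without starving the beam. A toy sketch of the flattened index arithmetic, with assumed sizes:

```python
import torch

# Toy illustration of the flattened top-k over (num_beams * vocab_size).
batch_size, num_beams, vocab_size = 1, 2, 6

next_scores = torch.randn(batch_size * num_beams, vocab_size)        # per-beam scores
next_scores = next_scores.view(batch_size, num_beams * vocab_size)   # one row per example

next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)

# a flattened index encodes both the beam it came from and the token id
beam_ids = next_tokens // vocab_size   # which beam the candidate extends
token_ids = next_tokens % vocab_size   # which vocabulary item it adds
print(beam_ids, token_ids)
```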
@@ -1041,14 +1048,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                     beam_id = idx // vocab_size
                     token_id = idx % vocab_size
-                    # add to generated hypotheses if end of sentence or last iteration
+                    effective_beam_id = batch_idx * num_beams + beam_id
+                    # add to generated hypotheses if end of sentence
                     if eos_token_ids is not None and token_id.item() in eos_token_ids:
                         generated_hyps[batch_idx].add(
-                            input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item(),
+                            input_ids[effective_beam_id].clone(), score.item(),
                         )
                     else:
                         # add next predicted word if it is not eos_token
-                        next_sent_beam.append((score, token_id, batch_idx * num_beams + beam_id))
+                        next_sent_beam.append((score, token_id, effective_beam_id))
                     # the beam for next step is full
                     if len(next_sent_beam) == num_beams:
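The `effective_beam_id` introduced above is simply the row of `input_ids` that a candidate extends, since `input_ids` is laid out as `batch_size * num_beams` rows with the beams of one example adjacent. A toy check of that indexing, with assumed sizes:

```python
import torch

# Toy check of the effective_beam_id indexing: input_ids holds batch_size * num_beams
# rows, with the num_beams rows of one example stored next to each other.
batch_size, num_beams, cur_len = 2, 3, 4
input_ids = torch.arange(batch_size * num_beams * cur_len).view(batch_size * num_beams, cur_len)

batch_idx, beam_id = 1, 2                      # assumed example: last beam of the second example
effective_beam_id = batch_idx * num_beams + beam_id

# same row as indexing a (batch_size, num_beams, cur_len) view directly
assert torch.equal(
    input_ids[effective_beam_id],
    input_ids.view(batch_size, num_beams, cur_len)[batch_idx, beam_id],
)
```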
@@ -1073,25 +1081,35 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             if past:
                 past = self._reorder_cache(past, beam_idx)
+            # update current length
+            cur_len = cur_len + 1
             # stop when we are done with each sentence
             if all(done):
                 break
-            # update current length
-            cur_len = cur_len + 1
+        # finalize all open beam hypotheses and end to generated hypotheses
         for batch_idx in range(batch_size):
-            # Add all open beam hypothesis to generated_hyps
-            if not done[batch_idx]:
-                for idx, score in zip(next_tokens[batch_idx], next_scores[batch_idx]):
-                    # get beam and word IDs
-                    beam_id = idx // vocab_size
-                    token_id = idx % vocab_size
-                    generated_hyps[batch_idx].add(
-                        input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item()
+            if done[batch_idx]:
+                continue
+            # test that beam scores match previously calculated scores if not eos and batch_idx not done
+            if eos_token_ids is not None and all(
+                (token_id % vocab_size).item() not in eos_token_ids for token_id in next_tokens[batch_idx]
+            ):
+                assert torch.all(
+                    next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
+                ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
+                    next_scores[:, :num_beams][batch_idx], beam_scores.view(batch_size, num_beams)[batch_idx]
                 )
+            # need to add best num_beams hypotheses to generated hyps
+            for beam_id in range(num_beams):
+                effective_beam_id = batch_idx * num_beams + beam_id
+                final_score = beam_scores[effective_beam_id].item()
+                final_tokens = input_ids[effective_beam_id]
+                generated_hyps[batch_idx].add(final_tokens, final_score)
         # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
         output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
         output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
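The new finalization block: every example that never produced EOS still has `num_beams` open beams living in `input_ids`, so after the sanity check against the accumulated `beam_scores` each of those rows is added as a final hypothesis. A simplified sketch of that pattern, where `MiniHyps` is a stand-in for the library's `BeamHypotheses` and all sizes are assumptions:

```python
import torch

# Simplified sketch of the finalization step; MiniHyps is a stand-in for the
# library's BeamHypotheses and only keeps (score, tokens) pairs.
class MiniHyps:
    def __init__(self):
        self.beams = []

    def add(self, tokens, score):
        self.beams.append((score, tokens))

batch_size, num_beams, cur_len = 2, 2, 5
input_ids = torch.randint(0, 10, (batch_size * num_beams, cur_len))
beam_scores = torch.randn(batch_size * num_beams)
done = [True, False]                      # example 0 already finished via EOS

generated_hyps = [MiniHyps() for _ in range(batch_size)]

for batch_idx in range(batch_size):
    if done[batch_idx]:
        continue
    # every still-open beam of this example becomes a final hypothesis
    for beam_id in range(num_beams):
        effective_beam_id = batch_idx * num_beams + beam_id
        generated_hyps[batch_idx].add(
            input_ids[effective_beam_id], beam_scores[effective_beam_id].item()
        )

print(len(generated_hyps[0].beams), len(generated_hyps[1].beams))  # 0 2
```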