Unverified Commit 518bd02c authored by Patrick von Platen, committed by GitHub

[Generation] Fix Transition probs (#17311)

* [Draft] fix transition probs

* up

* up

* up

* make it work

* fix

* finish

* update
parent e8714c03
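
In short: this change threads `beam_indices` through `BeamSearchScorer.process`/`finalize` and returns it as a single `torch.LongTensor` of per-step beam indices (padded with -1 for beams that stop early), so that `compute_transition_beam_scores` can recover correct per-step transition scores. A minimal sketch of the end-to-end usage this enables, mirroring the new `test_transition_scores_early_stopping` integration test below (the raw input string is an assumption; the test feeds the equivalent token ids directly):

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

input_ids = tokenizer("question: How are you? \n context: I had a long day, ", return_tensors="pt").input_ids
result = model.generate(
    input_ids,
    max_length=10,
    num_beams=4,
    num_return_sequences=3,
    do_sample=False,
    length_penalty=0.0,
    forced_eos_token_id=model.config.eos_token_id,
    return_dict_in_generate=True,
    output_scores=True,
)

# `result.beam_indices` is now a LongTensor, not a nested tuple of scalar tensors
transition_scores = model.compute_transition_beam_scores(
    sequences=result.sequences, scores=result.scores, beam_indices=result.beam_indices
)

# with length_penalty=0.0, the per-step transition scores sum back to the sequence scores
assert torch.allclose(transition_scores.sum(dim=1), result.sequences_scores, atol=1e-3)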
......@@ -212,6 +212,7 @@ class BeamSearchScorer(BeamScorer):
next_indices: torch.LongTensor,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
beam_indices: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor]:
cur_len = input_ids.shape[-1]
batch_size = len(self._beam_hyps)
......@@ -256,9 +257,16 @@ class BeamSearchScorer(BeamScorer):
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
if is_beam_token_worse_than_top_num_beams:
continue
if beam_indices is not None:
beam_index = beam_indices[batch_beam_idx]
beam_index = beam_index + (next_index,)
else:
beam_index = None
beam_hyp.add(
input_ids[batch_beam_idx].clone(),
next_score.item(),
beam_indices=beam_index,
)
else:
# add next predicted token since it is not eos_token
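
Aside: each hypothesis's provenance is tracked as a plain Python tuple that grows by one beam index per generation step, which is what `beam_index + (next_index,)` above does. A toy trace with hypothetical values:

beam_index = (2,)          # step 0: hypothesis lived on beam 2
for next_index in (0, 0):  # steps 1-2: re-ranked onto beam 0, then stays there
    beam_index = beam_index + (next_index,)
assert beam_index == (2, 0, 0)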
......@@ -299,6 +307,7 @@ class BeamSearchScorer(BeamScorer):
max_length: int,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
beam_indices: Optional[torch.LongTensor] = None,
) -> Tuple[torch.LongTensor]:
batch_size = len(self._beam_hyps)
......@@ -313,11 +322,13 @@ class BeamSearchScorer(BeamScorer):
batch_beam_idx = batch_idx * self.num_beams + beam_id
final_score = final_beam_scores[batch_beam_idx].item()
final_tokens = input_ids[batch_beam_idx]
beam_hyp.add(final_tokens, final_score)
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index)
# select the best hypotheses
sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
best = []
best_indices = []
best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)
# retrieve best hypotheses
......@@ -327,23 +338,42 @@ class BeamSearchScorer(BeamScorer):
best_hyp_tuple = sorted_hyps.pop()
best_score = best_hyp_tuple[0]
best_hyp = best_hyp_tuple[1]
best_index = best_hyp_tuple[2]
sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
# append to lists
# append hyp to lists
best.append(best_hyp)
# append indices to list
best_indices.append(best_index)
best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
# prepare for adding eos
sent_lengths_max = sent_lengths.max().item() + 1
sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
if len(best_indices) > 0 and best_indices[0] is not None:
indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
else:
indices = None
# shorter batches are padded if needed
if sent_lengths.min().item() != sent_lengths.max().item():
assert pad_token_id is not None, "`pad_token_id` has to be defined"
decoded.fill_(pad_token_id)
if indices is not None:
indices.fill_(-1)
# fill with hypotheses and eos_token_id if the latter fits in
for i, hypo in enumerate(best):
for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
decoded[i, : sent_lengths[i]] = hypo
if indices is not None:
indices[i, : len(best_idx)] = torch.tensor(best_idx)
if sent_lengths[i] < sent_max_len:
decoded[i, sent_lengths[i]] = eos_token_id
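
Aside: hypotheses that finish before `sent_max_len` leave `-1` padding in `indices`; `compute_transition_beam_scores` later uses those `-1`s as a mask. A minimal illustration with hypothetical lengths:

import torch

sent_max_len = 5
best_idx = (0, 1, 1)  # this hypothesis only ran for 3 steps
indices = torch.full((1, sent_max_len), -1, dtype=torch.long)
indices[0, : len(best_idx)] = torch.tensor(best_idx)
# indices is now tensor([[ 0,  1,  1, -1, -1]])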
......@@ -351,6 +381,7 @@ class BeamSearchScorer(BeamScorer):
{
"sequences": decoded,
"sequence_scores": best_scores,
"beam_indices": indices,
}
)
......@@ -789,6 +820,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
# prepare for adding eos
sent_lengths_max = sent_lengths.max().item() + 1
sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
# shorter batches are padded if needed
......@@ -801,6 +833,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
decoded[i, : sent_lengths[i]] = hypo
if sent_lengths[i] < sent_max_len:
decoded[i, sent_lengths[i]] = eos_token_id
return UserDict(
{
"sequences": decoded,
......@@ -826,15 +859,15 @@ class BeamHypotheses:
"""
return len(self.beams)
def add(self, hyp: torch.LongTensor, sum_logprobs: float):
def add(self, hyp: torch.LongTensor, sum_logprobs: float, beam_indices: Optional[torch.LongTensor] = None):
"""
Add a new hypothesis to the list.
"""
score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
if len(self) < self.num_beams or score > self.worst_score:
self.beams.append((score, hyp))
self.beams.append((score, hyp, beam_indices))
if len(self) > self.num_beams:
sorted_next_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
del self.beams[sorted_next_scores[0][1]]
self.worst_score = sorted_next_scores[1][0]
else:
......
......@@ -217,8 +217,8 @@ class BeamSearchDecoderOnlyOutput(ModelOutput):
`(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
`(batch_size*num_beams*num_return_sequences, config.vocab_size)`).
beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
tuple of `(max_length-input_ids.shape[-1],)`-shaped tuples of scalar `torch.LongTensor` tensors.
Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
`(batch_size*num_return_sequences, input_ids.shape[-1])`.
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
......@@ -230,7 +230,7 @@ class BeamSearchDecoderOnlyOutput(ModelOutput):
sequences: torch.LongTensor = None
sequences_scores: Optional[torch.FloatTensor] = None
scores: Optional[Tuple[torch.FloatTensor]] = None
beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
beam_indices: Optional[torch.LongTensor] = None
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
......@@ -254,8 +254,8 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
`(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams,
config.vocab_size)`).
beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
tuple of `(max_length-1,)`-shaped tuples of scalar `torch.LongTensor` tensors.
Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
`(batch_size*num_return_sequences, max_length-1)`.
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
......@@ -278,7 +278,7 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
sequences: torch.LongTensor = None
sequences_scores: Optional[torch.FloatTensor] = None
scores: Optional[Tuple[torch.FloatTensor]] = None
beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
beam_indices: Optional[torch.LongTensor] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
......@@ -303,8 +303,8 @@ class BeamSampleDecoderOnlyOutput(ModelOutput):
`(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
`(batch_size*num_beams*num_return_sequences, config.vocab_size)`).
beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
tuple of `(max_length-input_ids.shape[-1],)`-shaped tuples of scalar `torch.LongTensor` tensors.
Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
`(batch_size*num_return_sequences, input_ids.shape[-1])`.
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
......@@ -316,7 +316,7 @@ class BeamSampleDecoderOnlyOutput(ModelOutput):
sequences: torch.LongTensor = None
sequences_scores: Optional[torch.FloatTensor] = None
scores: Optional[Tuple[torch.FloatTensor]] = None
beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
beam_indices: Optional[torch.LongTensor] = None
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
......@@ -339,9 +339,9 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
`(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams,
config.vocab_size)`).
beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
tuple of `(max_length-1,)`-shaped tuples of scalar `torch.LongTensor` tensors.
beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
`(batch_size*num_return_sequences, max_length-1)`.
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
sequence_length, sequence_length)`.
......@@ -362,7 +362,7 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
sequences: torch.LongTensor = None
sequences_scores: Optional[torch.FloatTensor] = None
scores: Optional[Tuple[torch.FloatTensor]] = None
beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
beam_indices: Optional[torch.LongTensor] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
......@@ -811,32 +811,33 @@ class GenerationMixin:
"""compute the transition probabilities of sequences given generation
scores and beam indices"""
# reshape scores as [vocab_size * batch_size, # generation steps]
# 1. reshape scores as [batch_size * num_beams * vocab_size, # generation steps]
# with # generation steps being seq_len - input_length
scores = torch.stack(scores).reshape(len(scores), -1).transpose(0, 1)
# start of generated tokens
cut_idx = sequences.shape[-1] - scores.shape[-1]
# adjust for beam indices
beam_sequence_indices = torch.tensor(beam_indices, device=sequences.device) * self.config.vocab_size
# compute real indices
# 2. cut beam_indices to longest beam length
beam_indices_mask = beam_indices < 0
max_beam_length = (1 - beam_indices_mask.long()).sum(-1).max()
beam_indices = beam_indices[:, :max_beam_length]
beam_indices_mask = beam_indices_mask[:, :max_beam_length]
# 3. Set indices of beams that finished early to 0
# such indices will be masked correctly afterwards
beam_indices[beam_indices_mask] = 0
# 4. multiply beam_indices with vocab size to gather correctly from scores
beam_sequence_indices = beam_indices * self.config.vocab_size
# 5. Define which indices contributed to scores
cut_idx = sequences.shape[-1] - max_beam_length
indices = sequences[:, cut_idx:] + beam_sequence_indices
# gather scores and run
# 6. Compute scores
transition_scores = scores.gather(0, indices)
# if the EOS token was generated before the full sequence length `sequences.shape[-1]`,
# find its first occurrence so everything after it can be masked
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
if eos_token_id is not None:
is_eos_token_id = sequences[:, cut_idx:] == eos_token_id
# make sure first eos token still contributes to transition probs
is_eos_token_id[:, -1] = False
is_eos_token_id = is_eos_token_id.roll(1, -1)
# all indices after the eos should be masked
zero_transition_prob_mask = is_eos_token_id.cumsum(-1).bool()
# zero out padded probs
transition_scores.masked_fill_(zero_transition_prob_mask, 0.0)
# 7. Mask out transition_scores of beams that stopped early
transition_scores[beam_indices_mask] = 0
return transition_scores
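
A self-contained toy run of steps 2-7 (hypothetical shapes and values), showing how the flat row index `beam * vocab_size + token` gathers the right per-step score and how the `-1` padding of early-stopped beams is zeroed out:

import torch

vocab_size = 5
# 3 generation steps, 2 beams over a 5-token vocab, flattened as in step 1
step_scores = [torch.log_softmax(torch.randn(2, vocab_size), dim=-1) for _ in range(3)]
scores = torch.stack(step_scores).reshape(3, -1).transpose(0, 1)  # [2 * 5, 3]

# one returned sequence: beam 1, then beam 0, then finished early (-1)
beam_indices = torch.tensor([[1, 0, -1]])
tokens = torch.tensor([[3, 2, 0]])  # generated token ids at the gathered positions

beam_indices_mask = beam_indices < 0
beam_indices = beam_indices.masked_fill(beam_indices_mask, 0)  # steps 2-3
indices = beam_indices * vocab_size + tokens                   # steps 4-5
transition_scores = scores.gather(0, indices)                  # step 6
transition_scores[beam_indices_mask] = 0                       # step 7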
......@@ -2256,6 +2257,7 @@ class GenerationMixin:
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
beam_indices=beam_indices,
)
beam_scores = beam_outputs["next_beam_scores"]
......@@ -2290,25 +2292,19 @@ class GenerationMixin:
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
beam_indices=beam_indices,
)
if return_dict_in_generate:
if not output_scores:
sequence_outputs["sequence_scores"] = None
else:
num_return_sequences = beam_scorer.num_beam_hyps_to_keep
# return only as many indices as sequences
beam_indices = tuple(
(beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
)
beam_indices = sum(beam_indices, ())
if self.config.is_encoder_decoder:
return BeamSearchEncoderDecoderOutput(
sequences=sequence_outputs["sequences"],
sequences_scores=sequence_outputs["sequence_scores"],
scores=scores,
beam_indices=beam_indices,
beam_indices=sequence_outputs["beam_indices"],
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
......@@ -2320,7 +2316,7 @@ class GenerationMixin:
sequences=sequence_outputs["sequences"],
sequences_scores=sequence_outputs["sequence_scores"],
scores=scores,
beam_indices=beam_indices,
beam_indices=sequence_outputs["beam_indices"],
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
)
......@@ -2580,6 +2576,7 @@ class GenerationMixin:
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
beam_indices=beam_indices,
)
beam_scores = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
......@@ -2613,25 +2610,19 @@ class GenerationMixin:
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
beam_indices=beam_indices,
)
if return_dict_in_generate:
if not output_scores:
sequence_outputs["sequence_scores"] = None
else:
num_return_sequences = beam_scorer.num_beam_hyps_to_keep
# return only as many indices as sequences
beam_indices = tuple(
(beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
)
beam_indices = sum(beam_indices, ())
if self.config.is_encoder_decoder:
return BeamSampleEncoderDecoderOutput(
sequences=sequence_outputs["sequences"],
sequences_scores=sequence_outputs["sequence_scores"],
scores=scores,
beam_indices=beam_indices,
beam_indices=sequence_outputs["beam_indices"],
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
......@@ -2643,7 +2634,7 @@ class GenerationMixin:
sequences=sequence_outputs["sequences"],
sequences_scores=sequence_outputs["sequence_scores"],
scores=scores,
beam_indices=beam_indices,
beam_indices=sequence_outputs["beam_indices"],
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
)
......@@ -2909,6 +2900,7 @@ class GenerationMixin:
next_tokens = next_tokens % vocab_size
# stateless
process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
beam_outputs = beam_scorer.process(
group_input_ids,
next_token_scores,
......@@ -2916,6 +2908,7 @@ class GenerationMixin:
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
beam_indices=process_beam_indices,
)
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
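
Aside: in group beam search each beam group keeps its own tuple of per-beam index histories, so `sum(beam_indices, ())` concatenates the group tuples into the flat per-beam layout the scorer expects. A toy illustration (hypothetical values, assuming the per-group structure used by this method):

# two beam groups, two beams each; each inner tuple is one beam's history
beam_indices = [((0, 1), (1, 1)), ((2, 0), (3, 2))]
flat = sum(beam_indices, ())
assert flat == ((0, 1), (1, 1), (2, 0), (3, 2))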
......@@ -2971,6 +2964,7 @@ class GenerationMixin:
else:
this_peer_finished = True
final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
sequence_outputs = beam_scorer.finalize(
input_ids,
beam_scores,
......@@ -2979,26 +2973,19 @@ class GenerationMixin:
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
beam_indices=final_beam_indices,
)
if return_dict_in_generate:
if not output_scores:
sequence_outputs["sequence_scores"] = None
else:
beam_indices = sum(beam_indices, ())
num_return_sequences = beam_scorer.num_beam_hyps_to_keep
# return only as many indices as sequences
beam_indices = tuple(
(beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
)
beam_indices = sum(beam_indices, ())
if self.config.is_encoder_decoder:
return BeamSearchEncoderDecoderOutput(
sequences=sequence_outputs["sequences"],
sequences_scores=sequence_outputs["sequence_scores"],
scores=scores,
beam_indices=beam_indices,
beam_indices=sequence_outputs["beam_indices"],
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
......@@ -3010,6 +2997,7 @@ class GenerationMixin:
sequences=sequence_outputs["sequences"],
sequences_scores=sequence_outputs["sequence_scores"],
scores=scores,
beam_indices=sequence_outputs["beam_indices"],
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
)
......
......@@ -126,7 +126,11 @@ class BeamSearchTester:
tokens = next_tokens.clone()
tokens[:, : self.num_beams] = self.eos_token_id
beam_scorer.process(input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id)
beam_indices = torch.zeros_like(input_ids) + torch.arange(input_ids.shape[-1], device=input_ids.device)
beam_indices = tuple(tuple(b) for b in beam_indices)
beam_scorer.process(
input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id, beam_indices=beam_indices
)
# beam scorer should be done
self.parent.assertTrue(beam_scorer.is_done)
......@@ -136,7 +140,7 @@ class BeamSearchTester:
tokens = next_tokens.clone()
tokens[:, 1] = self.eos_token_id
beam_outputs = beam_scorer.process(
input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id
input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id, beam_indices=beam_indices
)
output_scores = beam_outputs["next_beam_scores"]
output_tokens = beam_outputs["next_beam_tokens"]
......@@ -161,10 +165,15 @@ class BeamSearchTester:
self.parent.assertTrue(torch.allclose(expected_output_scores, output_scores, atol=1e-3))
# make sure ids of eos token are correctly saved in beam_hyps of beam scorer
expected_beam_indices = list(range(10))
for batch_idx in range(self.batch_size):
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
self.parent.assertListEqual(
input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist()
input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist()
)
self.parent.assertListEqual(
expected_beam_indices + [next_indices[batch_idx, 1].item()],
torch.tensor(beam_scorer._beam_hyps[batch_idx].beams[0][2]).tolist(),
)
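
Aside: these assertions rely on the new 3-tuple layout of `BeamHypotheses.beams`. A toy reminder of the indexing, with hypothetical values:

import torch

# (score, tokens, beam_indices) as stored by the updated BeamHypotheses.add
beam = (-0.7, torch.tensor([5, 9, 2]), (0, 1, 1))
assert beam[1].tolist() == [5, 9, 2]  # tokens are now beams[0][1], no longer beams[0][-1]
assert beam[2] == (0, 1, 1)           # per-step beam indices sit at beams[0][2]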
def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
......@@ -188,6 +197,8 @@ class BeamSearchTester:
input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1)
# finalize
beam_indices = torch.zeros_like(input_ids) + torch.arange(input_ids.shape[-1], device=input_ids.device)
beam_indices = tuple(tuple(b) for b in beam_indices)
sequence_output = beam_scorer.finalize(
input_ids,
output_scores,
......@@ -196,6 +207,7 @@ class BeamSearchTester:
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id,
max_length=max_length,
beam_indices=beam_indices,
)
sequences = sequence_output["sequences"]
......@@ -225,6 +237,7 @@ class BeamSearchTester:
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id,
max_length=max_length,
beam_indices=beam_indices,
)
sequences = sequence_output["sequences"]
sequence_scores = sequence_output["sequence_scores"]
......@@ -394,7 +407,7 @@ class ConstrainedBeamSearchTester:
for batch_idx in range(self.batch_size):
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
self.parent.assertListEqual(
input_ids[correct_idx].tolist(), constrained_beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist()
input_ids[correct_idx].tolist(), constrained_beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist()
)
def check_constrained_beam_scorer_finalize(
......
......@@ -2322,6 +2322,94 @@ class GenerationIntegrationTests(unittest.TestCase):
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
@slow
def test_transition_scores_early_stopping(self):
# This is an aggressive test that makes sure that `beam_search`'s
# transition scores are computed correctly for varying `num_return_sequences`,
# `num_beams`, and `batch_size > 1`
# 2 x input_ids for "question: How are you? \n context: I had a long day, "
input_ids = torch.tensor(2 * [[822, 10, 571, 33, 25, 58, 2625, 10, 27, 141, 3, 9, 307, 239, 6, 1]]).to(
torch_device
)
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to(torch_device)
result = model.generate(
input_ids,
max_length=10,
return_dict_in_generate=True,
output_scores=True,
forced_eos_token_id=model.config.eos_token_id,
num_beams=4,
do_sample=False,
num_return_sequences=3,
length_penalty=0.0,
)
transition_scores = model.compute_transition_beam_scores(
sequences=result.sequences, scores=result.scores, beam_indices=result.beam_indices
)
sum_transition_scores = torch.sum(transition_scores, dim=1)
self.assertListEqual(sum_transition_scores.cpu().tolist(), result.sequences_scores.cpu().tolist())
def test_log_scores_sample_decoder_only(self):
articles = ["I need input_ids to generate", "Short and"]
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
inputs = tokenizer(articles, return_tensors="pt", padding=True).to(torch_device)
result = model.generate(
**inputs,
max_length=15,
return_dict_in_generate=True,
do_sample=False,
output_scores=True,
)
# decoder-only starts generating from `input_ids`
begin_generation = inputs.input_ids.shape[-1]
gen_sequences = result.sequences[:, begin_generation:]
probs = torch.stack(result.scores, dim=1).softmax(-1)
gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)
expected_probs = torch.tensor([[0.0014, 0.0015], [0.0014, 0.0014]])
self.assertTrue(torch.allclose(gen_probs.cpu(), expected_probs, atol=1e-3))
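
Aside: the `torch.gather` on dim 2 picks, for every batch element and generation step, the probability of the token that was actually generated. A toy version with hypothetical shapes:

import torch

probs = torch.rand(2, 4, 10)                  # [batch, steps, vocab]
gen_sequences = torch.randint(0, 10, (2, 4))  # [batch, steps]
gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)
assert gen_probs.shape == (2, 4)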
def test_log_scores_sample_encoder_decoder(self):
articles = ["I need input_ids to generate", "Short and"]
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(torch_device)
inputs = tokenizer(articles, return_tensors="pt", padding=True).to(torch_device)
result = model.generate(
**inputs,
max_length=3,
return_dict_in_generate=True,
do_sample=False,
num_beams=1,
output_scores=True,
)
# encoder-decoder has one decoder_start_token_id by default
begin_generation = 1
gen_sequences = result.sequences[:, begin_generation:]
probs = torch.stack(result.scores, dim=1).softmax(-1)
gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)
expected_probs = torch.tensor([[0.0013, 1.0000], [0.0013, 1.0000]])
self.assertTrue(torch.allclose(gen_probs.cpu(), expected_probs, atol=1e-3))
@slow
def test_beam_search_example_integration(self):
# exactly the example provided in the docstrings of beam search, which previously
......@@ -2366,8 +2454,8 @@ class GenerationIntegrationTests(unittest.TestCase):
@slow
def test_constrained_beam_search(self):
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
force_tokens = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
force_tokens_2 = tokenizer("big weapons", add_prefix_space=True, add_special_tokens=False).input_ids
......@@ -2403,8 +2491,8 @@ class GenerationIntegrationTests(unittest.TestCase):
@slow
def test_constrained_beam_search_mixed(self):
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
force_phrase = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
flexible_phrases = tokenizer(
......@@ -2442,8 +2530,8 @@ class GenerationIntegrationTests(unittest.TestCase):
@slow
def test_constrained_beam_search_mixed_mixin(self):
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
force_word = "scared"
force_flexible = ["scream", "screams", "screaming", "screamed"]
......