Unverified commit ef9c3ca3 authored by Chan Woo Kim, committed by GitHub

[Bug Fix] Beam search example in docs fails & a fix (integrating `max_length` in `BeamScorer.finalize()`) (#15555)

* added the test and fix

* had left out a comment
parent 9932ee4b
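The root cause: `beam_search()` forwards `stopping_criteria.max_length` to `BeamScorer.finalize()`, and that property is `None` whenever no `MaxLengthCriteria` was supplied, which is exactly the situation in the docstring example. A minimal sketch of the resulting failure (the concrete values below are hypothetical, not taken from the PR):

# Sketch of the pre-fix failure inside BeamSearchScorer.finalize() (hypothetical values)
sent_lengths_max = 7   # stands in for sent_lengths.max().item() + 1
max_length = None      # no MaxLengthCriteria supplied -> stopping_criteria.max_length is None

sent_max_len = min(sent_lengths_max, max_length)
# TypeError: '<' not supported between instances of 'NoneType' and 'int'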
@@ -332,7 +332,8 @@ class BeamSearchScorer(BeamScorer):
                 best_scores[i * self.num_beam_hyps_to_keep + j] = best_score

         # prepare for adding eos
-        sent_max_len = min(sent_lengths.max().item() + 1, max_length)
+        sent_lengths_max = sent_lengths.max().item() + 1
+        sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
         decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
         # shorter batches are padded if needed
         if sent_lengths.min().item() != sent_lengths.max().item():
@@ -341,7 +342,7 @@ class BeamSearchScorer(BeamScorer):
         # fill with hypotheses and eos_token_id if the latter fits in
         for i, hypo in enumerate(best):
             decoded[i, : sent_lengths[i]] = hypo
-            if sent_lengths[i] < max_length:
+            if sent_lengths[i] < sent_max_len:
                 decoded[i, sent_lengths[i]] = eos_token_id

         return UserDict(
...
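Both changed lines matter: the first avoids the `min(int, None)` comparison, and the second compares against the clamped `sent_max_len` rather than `max_length`, so the eos write also works when no limit is set. The guard's behavior, lifted into a hypothetical standalone helper for illustration (not part of the PR):

def compute_sent_max_len(sent_lengths_max, max_length):
    # clamp to max_length only when a limit was actually provided
    return min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max

assert compute_sent_max_len(7, 10) == 7     # within the limit: unchanged
assert compute_sent_max_len(12, 10) == 10   # over the limit: still clamped
assert compute_sent_max_len(7, None) == 7   # no limit: previously raised TypeError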
@@ -2315,6 +2315,48 @@ class GenerationIntegrationTests(unittest.TestCase):
         self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))

+    @slow
+    def test_beam_search_example_integration(self):
+        # exactly the example provided in the docstrings of beam search, which previously
+        # failed after directly copying from it. Refer to PR #15555
+        tokenizer = AutoTokenizer.from_pretrained("t5-base")
+        model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
+
+        encoder_input_str = "translate English to German: How old are you?"
+        encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
+
+        # let's run beam search using 3 beams
+        num_beams = 3
+        # define decoder start token ids
+        input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
+        input_ids = input_ids * model.config.decoder_start_token_id
+
+        # add encoder_outputs to model keyword arguments
+        model_kwargs = {
+            "encoder_outputs": model.get_encoder()(
+                encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
+            )
+        }
+
+        # instantiate beam scorer
+        beam_scorer = BeamSearchScorer(
+            batch_size=1,
+            num_beams=num_beams,
+            device=model.device,
+        )
+
+        # instantiate logits processors
+        logits_processor = LogitsProcessorList(
+            [
+                MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
+            ]
+        )
+
+        outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
+        outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+        self.assertListEqual(outputs, ["Wie alt bist du?"])
+
     @slow
     def test_constrained_beam_search(self):
         model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
...
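The new test drives `model.beam_search()` directly, mirroring the docstring example. For comparison, the same translation can be reproduced through the high-level `generate()` API, which builds the beam scorer and logits processors internally; a sketch assuming the `t5-base` checkpoint is available (the expected output matches the test's assertion):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

input_ids = tokenizer("translate English to German: How old are you?", return_tensors="pt").input_ids
# num_beams=3 and min_length=5 mirror the BeamSearchScorer / MinLengthLogitsProcessor setup above
outputs = model.generate(input_ids, num_beams=3, min_length=5)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))  # ["Wie alt bist du?"]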