Unverified commit af37d183, authored by Joao Gante, committed by GitHub

Generate: documented function to compute the transition scores (#21191)


Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 91c2278b
@@ -37,6 +37,7 @@ and how to create and save a customized generation configuration, refer to the
 [[autodoc]] generation.GenerationMixin
 	- generate
+	- compute_transition_scores
 	- greedy_search
 	- sample
 	- beam_search
......
@@ -924,42 +924,121 @@ class GenerationMixin:
         default_list.extend(custom_list)
         return default_list
 
-    def compute_transition_beam_scores(
+    def compute_transition_scores(
         self,
         sequences: torch.Tensor,
         scores: Tuple[torch.Tensor],
-        beam_indices: torch.Tensor,
-        eos_token_id: Union[int, List[int]] = None,
-    ):
-        """compute the transition probabilities of sequences given generation
-        scores and beam indices"""
+        beam_indices: Optional[torch.Tensor] = None,
+        normalize_logits: bool = False,
+    ) -> torch.Tensor:
+        """
+        Computes the transition scores of sequences given the generation scores (and beam indices, if beam search was
+        used). This is a convenient method to quickly obtain the scores of the selected tokens at generation time.
+
+        Parameters:
+            sequences (`torch.LongTensor`):
+                The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or
+                shorter if all batches finished early due to the `eos_token_id`.
+            scores (`tuple(torch.FloatTensor)`):
+                Transition scores for each vocabulary token at each generation step. Beam transition scores consist
+                of log probabilities of tokens conditioned on log softmax of previously generated tokens. Tuple of
+                `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with
+                each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
+            beam_indices (`torch.LongTensor`, *optional*):
+                Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
+                `(batch_size*num_return_sequences, input_ids.shape[-1])`. Only required if `num_beams>1` at
+                generate-time.
+            normalize_logits (`bool`, *optional*, defaults to `False`):
+                Whether to normalize the logits (which, for legacy reasons, may be unnormalized).
+
+        Return:
+            `torch.Tensor`: A `torch.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)` containing
+                the transition scores (logits).
+
+        Examples:
+
+        ```python
+        >>> from transformers import GPT2Tokenizer, AutoModelForCausalLM
+        >>> import numpy as np
+
+        >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
+        >>> tokenizer.pad_token_id = tokenizer.eos_token_id
+        >>> inputs = tokenizer(["Today is"], return_tensors="pt")
+
+        >>> # Example 1: Print the scores for each token generated with Greedy Search
+        >>> outputs = model.generate(**inputs, max_new_tokens=5, return_dict_in_generate=True, output_scores=True)
+        >>> transition_scores = model.compute_transition_scores(
+        ...     outputs.sequences, outputs.scores, normalize_logits=True
+        ... )
+        >>> input_length = inputs.input_ids.shape[1]
+        >>> generated_tokens = outputs.sequences[:, input_length:]
+        >>> for tok, score in zip(generated_tokens[0], transition_scores[0]):
+        ...     # | token | token string | logits | probability
+        ...     print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.numpy():.4f} | {np.exp(score.numpy()):.2%}")
+        |   262 |  the     | -1.4136 | 24.33%
+        |  1110 |  day     | -2.6089 | 7.36%
+        |   618 |  when    | -2.0096 | 13.40%
+        |   356 |  we      | -1.8593 | 15.58%
+        |   460 |  can     | -2.5083 | 8.14%
+
+        >>> # Example 2: Reconstruct the sequence scores from Beam Search
+        >>> outputs = model.generate(
+        ...     **inputs,
+        ...     max_new_tokens=5,
+        ...     num_beams=4,
+        ...     num_return_sequences=4,
+        ...     return_dict_in_generate=True,
+        ...     output_scores=True,
+        ... )
+        >>> transition_scores = model.compute_transition_scores(
+        ...     outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False
+        ... )
+        >>> # If you sum the generated tokens' scores and apply the length penalty, you'll get the sequence scores.
+        >>> # Tip: set `normalize_logits=True` to recompute the scores from the normalized logits.
+        >>> output_length = inputs.input_ids.shape[1] + np.sum(transition_scores.numpy() < 0, axis=1)
+        >>> length_penalty = model.generation_config.length_penalty
+        >>> reconstructed_scores = transition_scores.sum(axis=1) / (output_length**length_penalty)
+        >>> print(np.allclose(outputs.sequences_scores, reconstructed_scores))
+        True
+        ```"""
+        # 1. In absence of `beam_indices`, we can assume that we come from e.g. greedy search, which is equivalent
+        # to a beam search approach where the first (and only) beam is always selected
+        if beam_indices is None:
+            beam_indices = torch.arange(scores[0].shape[0]).view(-1, 1).to(sequences.device)
+            beam_indices = beam_indices.expand(-1, len(scores))
 
-        # 1. reshape scores as [vocab_size * batch_size, # generation steps]
-        # with batch_size being 2 * vocab_size and # generation steps being
+        # 2. reshape scores as [batch_size*vocab_size, # generation steps] with # generation steps being
         # seq_len - input_length
         scores = torch.stack(scores).reshape(len(scores), -1).transpose(0, 1)
 
-        # 2. cut beam_indices to longest beam length
+        # 3. Optionally normalize the logits (across the vocab dimension)
+        if normalize_logits:
+            scores = scores.reshape(-1, self.config.vocab_size, scores.shape[-1])
+            scores = torch.nn.functional.log_softmax(scores, dim=1)
+            scores = scores.reshape(-1, scores.shape[-1])
+
+        # 4. cut beam_indices to longest beam length
         beam_indices_mask = beam_indices < 0
         max_beam_length = (1 - beam_indices_mask.long()).sum(-1).max()
         beam_indices = beam_indices[:, :max_beam_length]
         beam_indices_mask = beam_indices_mask[:, :max_beam_length]
 
-        # 3. Set indices of beams that finished early to 0
+        # 5. Set indices of beams that finished early to 0
         # such indices will be masked correctly afterwards
         beam_indices[beam_indices_mask] = 0
 
-        # 4. multiply beam_indices with vocab size to gather correctly from scores
+        # 6. multiply beam_indices with vocab size to gather correctly from scores
         beam_sequence_indices = beam_indices * self.config.vocab_size
 
-        # 5. Define which indices contributed to scores
+        # 7. Define which indices contributed to scores
         cut_idx = sequences.shape[-1] - max_beam_length
         indices = sequences[:, cut_idx:] + beam_sequence_indices
 
-        # 6. Compute scores
+        # 8. Compute scores
         transition_scores = scores.gather(0, indices)
 
-        # 7. Mask out transition_scores of beams that stopped early
+        # 9. Mask out transition_scores of beams that stopped early
         transition_scores[beam_indices_mask] = 0
 
         return transition_scores
......
@@ -17,6 +17,8 @@
 import inspect
 import unittest
 
+import numpy as np
+
 from transformers import is_torch_available, pipeline
 from transformers.testing_utils import require_torch, slow, torch_device
@@ -2485,6 +2487,58 @@ class GenerationIntegrationTests(unittest.TestCase):
         self.assertListEqual(output_sequences_no_mask.tolist(), output_sequences_with_mask.tolist())
 
+    def test_transition_scores_greedy_search(self):
+        articles = ["Justin Timberlake", "Michael Phelps"]
+        tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        tokenizer.pad_token = tokenizer.eos_token
+
+        model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
+
+        outputs = model.generate(
+            input_ids=input_ids,
+            max_new_tokens=5,
+            pad_token_id=tokenizer.eos_token_id,
+            eos_token_id=None,
+            return_dict_in_generate=True,
+            output_scores=True,
+        )
+
+        transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores)
+
+        expected_scores = np.array(
+            [
+                [0.3596273, 0.39646253, 0.46157718, 0.4594633, 0.44866616],
+                [0.34934354, 0.4935004, 0.6373219, 0.5173545, 0.57517034],
+            ]
+        )
+        self.assertTrue(np.allclose(transition_scores.cpu().numpy(), expected_scores))
+
+    def test_transition_scores_greedy_search_normalized(self):
+        articles = ["Justin Timberlake", "Michael Phelps"]
+        tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        tokenizer.pad_token = tokenizer.eos_token
+
+        model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
+
+        outputs = model.generate(
+            input_ids=input_ids,
+            max_new_tokens=5,
+            pad_token_id=tokenizer.eos_token_id,
+            eos_token_id=None,
+            return_dict_in_generate=True,
+            output_scores=True,
+        )
+
+        transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True)
+
+        expected_scores = np.array(
+            [
+                [-6.5532393, -6.5158753, -6.451863, -6.4527144, -6.459402],
+                [-6.5685124, -6.4277077, -6.282607, -6.399295, -6.340927],
+            ]
+        )
+        self.assertTrue(np.allclose(transition_scores.cpu().numpy(), expected_scores))
+
     def test_transition_scores_beam_search_encoder_decoder(self):
         articles = [
             "Justin Timberlake and Jessica Biel, welcome to parenthood.",
@@ -2506,9 +2560,7 @@ class GenerationIntegrationTests(unittest.TestCase):
         input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
         outputs = model.generate(input_ids=input_ids)
 
-        transition_scores = model.compute_transition_beam_scores(
-            outputs.sequences, outputs.scores, outputs.beam_indices
-        )
+        transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices)
         transition_scores_sum = transition_scores.sum(-1)
 
         self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
@@ -2533,9 +2585,7 @@ class GenerationIntegrationTests(unittest.TestCase):
         input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
         outputs = model.generate(input_ids=input_ids)
 
-        transition_scores = model.compute_transition_beam_scores(
-            outputs.sequences, outputs.scores, outputs.beam_indices
-        )
+        transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices)
         transition_scores_sum = transition_scores.sum(-1)
 
         self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
@@ -2564,9 +2614,7 @@ class GenerationIntegrationTests(unittest.TestCase):
         input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
         outputs = model.generate(input_ids=input_ids)
 
-        transition_scores = model.compute_transition_beam_scores(
-            outputs.sequences, outputs.scores, outputs.beam_indices
-        )
+        transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices)
         transition_scores_sum = transition_scores.sum(-1)
 
         self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
@@ -2593,9 +2641,7 @@ class GenerationIntegrationTests(unittest.TestCase):
         input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
         outputs = model.generate(input_ids=input_ids)
 
-        transition_scores = model.compute_transition_beam_scores(
-            outputs.sequences, outputs.scores, outputs.beam_indices
-        )
+        transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices)
         transition_scores_sum = transition_scores.sum(-1)
 
         self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
@@ -2622,9 +2668,7 @@ class GenerationIntegrationTests(unittest.TestCase):
         input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
         outputs = model.generate(input_ids=input_ids)
 
-        transition_scores = model.compute_transition_beam_scores(
-            outputs.sequences, outputs.scores, outputs.beam_indices
-        )
+        transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices)
         transition_scores_sum = transition_scores.sum(-1)
 
         self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
@@ -2653,7 +2697,7 @@ class GenerationIntegrationTests(unittest.TestCase):
             length_penalty=0.0,
         )
 
-        transition_scores = model.compute_transition_beam_scores(
+        transition_scores = model.compute_transition_scores(
             sequences=result.sequences, scores=result.scores, beam_indices=result.beam_indices
         )
......
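The beam-search tests above all assert the same invariant: summing the per-token transition scores (and applying the length penalty, as in Example 2 of the new docstring) reproduces `outputs.sequences_scores`. A small numpy sketch of that arithmetic on made-up values (not part of this commit; unlike Example 2, it does not add the prompt length to `output_length`):

```python
import numpy as np

# stand-in transition scores for two sequences; 0.0 marks a step masked out
# because that beam stopped early (step 9 of the function zeroes those entries)
transition_scores = np.array([[-1.2, -0.7, -2.1], [-0.9, -1.5, 0.0]])
length_penalty = 1.0

# with log-prob scores, strictly negative entries correspond to real generated tokens
output_length = np.sum(transition_scores < 0, axis=1)
reconstructed = transition_scores.sum(axis=1) / (output_length**length_penalty)
print(reconstructed)  # [-1.33333333 -1.2]
```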