Unverified Commit af37d183 authored by Joao Gante, committed by GitHub

Generate: documented function to compute the transition scores (#21191)


Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 91c2278b
@@ -37,6 +37,7 @@ and how to create and save a customized generation configuration, refer to the
[[autodoc]] generation.GenerationMixin
- generate
- compute_transition_scores
- greedy_search
- sample
- beam_search
@@ -924,42 +924,121 @@ class GenerationMixin:
default_list.extend(custom_list)
return default_list
def compute_transition_beam_scores(
def compute_transition_scores(
self,
sequences: torch.Tensor,
scores: Tuple[torch.Tensor],
beam_indices: torch.Tensor,
eos_token_id: Union[int, List[int]] = None,
):
"""compute the transition probabilities of sequences given generation
scores and beam indices"""
beam_indices: Optional[torch.Tensor] = None,
normalize_logits: bool = False,
) -> torch.Tensor:
"""
Computes the transition scores of sequences given the generation scores (and beam indices, if beam search was
used). This is a convenient method to quickly obtain the scores of the selected tokens at generation time.
Parameters:
sequences (`torch.LongTensor`):
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or
shorter if all batches finished early due to the `eos_token_id`.
scores (`tuple(torch.FloatTensor)`):
Transition scores for each vocabulary token at each generation step. Beam transition scores consist of
the log probabilities of tokens conditioned on the log softmax of previously generated tokens in the
same beam. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each
generated token), with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
beam_indices (`torch.LongTensor`, *optional*):
Beam indices of the generated token id at each generation step. `torch.LongTensor` of shape
`(batch_size*num_return_sequences, input_ids.shape[-1])`. Only required if `num_beams>1` was used at
generate-time.
normalize_logits (`bool`, *optional*, defaults to `False`):
Whether to normalize the logits (which, for legacy reasons, may be unnormalized).
Return:
`torch.Tensor`: A `torch.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)` containing
the transition scores (logits).
Examples:
```python
>>> from transformers import GPT2Tokenizer, AutoModelForCausalLM
>>> import numpy as np
>>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> tokenizer.pad_token_id = tokenizer.eos_token_id
>>> inputs = tokenizer(["Today is"], return_tensors="pt")
>>> # Example 1: Print the scores for each token generated with Greedy Search
>>> outputs = model.generate(**inputs, max_new_tokens=5, return_dict_in_generate=True, output_scores=True)
>>> transition_scores = model.compute_transition_scores(
... outputs.sequences, outputs.scores, normalize_logits=True
... )
>>> input_length = inputs.input_ids.shape[1]
>>> generated_tokens = outputs.sequences[:, input_length:]
>>> for tok, score in zip(generated_tokens[0], transition_scores[0]):
... # | token | token string | logits | probability
... print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.numpy():.4f} | {np.exp(score.numpy()):.2%}")
| 262 | the | -1.4136 | 24.33%
| 1110 | day | -2.6089 | 7.36%
| 618 | when | -2.0096 | 13.40%
| 356 | we | -1.8593 | 15.58%
| 460 | can | -2.5083 | 8.14%
>>> # Example 2: Reconstruct the sequence scores from Beam Search
>>> outputs = model.generate(
... **inputs,
... max_new_tokens=5,
... num_beams=4,
... num_return_sequences=4,
... return_dict_in_generate=True,
... output_scores=True,
... )
>>> transition_scores = model.compute_transition_scores(
... outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False
... )
>>> # If you sum the generated tokens' scores and apply the length penalty, you'll get the sequence scores.
>>> # Tip: set `normalize_logits=True` to recompute the scores from the normalized logits.
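>>> # (Positions of beams that finished early were masked to score 0, while real log scores are negative,
>>> # so counting the negative entries per row recovers how many tokens each sequence generated.)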
>>> output_length = inputs.input_ids.shape[1] + np.sum(transition_scores.numpy() < 0, axis=1)
>>> length_penalty = model.generation_config.length_penalty
>>> reconstructed_scores = transition_scores.sum(axis=1) / (output_length**length_penalty)
>>> print(np.allclose(outputs.sequences_scores, reconstructed_scores))
True
```"""
# 1. In absence of `beam_indices`, we can assume that we come from e.g. greedy search, which is equivalent
# to a beam search approach where the first (and only) beam is always selected
if beam_indices is None:
beam_indices = torch.arange(scores[0].shape[0]).view(-1, 1).to(sequences.device)
beam_indices = beam_indices.expand(-1, len(scores))
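# (each row of `beam_indices` then repeats its own batch index, i.e. every sequence "follows itself")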
# 2. reshape scores as [batch_size*vocab_size, # generation steps] with # generation steps being
# seq_len - input_length
scores = torch.stack(scores).reshape(len(scores), -1).transpose(0, 1)
# 2. cut beam_indices to longest beam length
# 3. Optionally normalize the logits (across the vocab dimension)
if normalize_logits:
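# expose the vocab axis as dim 1 so the log_softmax below normalizes over the vocabulary, then flatten back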
scores = scores.reshape(-1, self.config.vocab_size, scores.shape[-1])
scores = torch.nn.functional.log_softmax(scores, dim=1)
scores = scores.reshape(-1, scores.shape[-1])
# 4. cut beam_indices to longest beam length
beam_indices_mask = beam_indices < 0
max_beam_length = (1 - beam_indices_mask.long()).sum(-1).max()
beam_indices = beam_indices[:, :max_beam_length]
beam_indices_mask = beam_indices_mask[:, :max_beam_length]
# 3. Set indices of beams that finished early to 0
# 5. Set indices of beams that finished early to 0
# such indices will be masked correctly afterwards
beam_indices[beam_indices_mask] = 0
# 4. multiply beam_indices with vocab size to gather correctly from scores
# 6. multiply beam_indices with vocab size to gather correctly from scores
beam_sequence_indices = beam_indices * self.config.vocab_size
# 5. Define which indices contributed to scores
# 7. Define which indices contributed to scores
cut_idx = sequences.shape[-1] - max_beam_length
indices = sequences[:, cut_idx:] + beam_sequence_indices
# 6. Compute scores
# 8. Compute scores
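# gather along dim 0: transition_scores[i, t] = scores[indices[i, t], t]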
transition_scores = scores.gather(0, indices)
# 7. Mask out transition_scores of beams that stopped early
# 9. Mask out transition_scores of beams that stopped early
transition_scores[beam_indices_mask] = 0
return transition_scores
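For reviewers, the flattened gather in steps 6-8 can be sanity-checked on a toy example. The sketch below uses made-up shapes and values (illustrative only, not part of the commit): two beams, a three-token vocabulary, and two generation steps.

```python
import torch

vocab_size = 3
# `scores` as produced by step 2: shape (num_beams * vocab_size, num_steps)
scores = torch.tensor(
    [
        [0.1, 0.2],  # beam 0, token 0
        [0.3, 0.4],  # beam 0, token 1
        [0.5, 0.6],  # beam 0, token 2
        [0.7, 0.8],  # beam 1, token 0
        [0.9, 1.0],  # beam 1, token 1
        [1.1, 1.2],  # beam 1, token 2
    ]
)
beam_indices = torch.tensor([[0, 1]])  # step 0 came from beam 0, step 1 from beam 1
tokens = torch.tensor([[2, 0]])        # token 2 was chosen at step 0, token 0 at step 1

# steps 6-8: offset each token id into its beam's block of rows, then gather column-wise
indices = tokens + beam_indices * vocab_size  # tensor([[2, 3]])
transition_scores = scores.gather(0, indices)
print(transition_scores)  # tensor([[0.5000, 0.8000]])
```

Row 2 is token 2 of beam 0 and row 3 is token 0 of beam 1, so the gathered values match the chosen transitions.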
@@ -17,6 +17,8 @@
import inspect
import unittest
import numpy as np
from transformers import is_torch_available, pipeline
from transformers.testing_utils import require_torch, slow, torch_device
@@ -2485,6 +2487,58 @@ class GenerationIntegrationTests(unittest.TestCase):
self.assertListEqual(output_sequences_no_mask.tolist(), output_sequences_with_mask.tolist())
def test_transition_scores_greedy_search(self):
articles = ["Justin Timberlake", "Michael Phelps"]
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
outputs = model.generate(
input_ids=input_ids,
max_new_tokens=5,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=None,
return_dict_in_generate=True,
output_scores=True,
)
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores)
expected_scores = np.array(
[
[0.3596273, 0.39646253, 0.46157718, 0.4594633, 0.44866616],
[0.34934354, 0.4935004, 0.6373219, 0.5173545, 0.57517034],
]
)
self.assertTrue(np.allclose(transition_scores.cpu().numpy(), expected_scores))
def test_transition_scores_greedy_search_normalized(self):
articles = ["Justin Timberlake", "Michael Phelps"]
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
outputs = model.generate(
input_ids=input_ids,
max_new_tokens=5,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=None,
return_dict_in_generate=True,
output_scores=True,
)
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True)
expected_scores = np.array(
[
[-6.5532393, -6.5158753, -6.451863, -6.4527144, -6.459402],
[-6.5685124, -6.4277077, -6.282607, -6.399295, -6.340927],
]
)
self.assertTrue(np.allclose(transition_scores.cpu().numpy(), expected_scores))
def test_transition_scores_beam_search_encoder_decoder(self):
articles = [
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
@@ -2506,9 +2560,7 @@ class GenerationIntegrationTests(unittest.TestCase):
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
outputs = model.generate(input_ids=input_ids)
transition_scores = model.compute_transition_beam_scores(
outputs.sequences, outputs.scores, outputs.beam_indices
)
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices)
transition_scores_sum = transition_scores.sum(-1)
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
@@ -2533,9 +2585,7 @@ class GenerationIntegrationTests(unittest.TestCase):
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
outputs = model.generate(input_ids=input_ids)
transition_scores = model.compute_transition_beam_scores(
outputs.sequences, outputs.scores, outputs.beam_indices
)
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices)
transition_scores_sum = transition_scores.sum(-1)
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
@@ -2564,9 +2614,7 @@ class GenerationIntegrationTests(unittest.TestCase):
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
outputs = model.generate(input_ids=input_ids)
transition_scores = model.compute_transition_beam_scores(
outputs.sequences, outputs.scores, outputs.beam_indices
)
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices)
transition_scores_sum = transition_scores.sum(-1)
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
@@ -2593,9 +2641,7 @@ class GenerationIntegrationTests(unittest.TestCase):
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
outputs = model.generate(input_ids=input_ids)
transition_scores = model.compute_transition_beam_scores(
outputs.sequences, outputs.scores, outputs.beam_indices
)
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices)
transition_scores_sum = transition_scores.sum(-1)
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
@@ -2622,9 +2668,7 @@ class GenerationIntegrationTests(unittest.TestCase):
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
outputs = model.generate(input_ids=input_ids)
transition_scores = model.compute_transition_beam_scores(
outputs.sequences, outputs.scores, outputs.beam_indices
)
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices)
transition_scores_sum = transition_scores.sum(-1)
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
@@ -2653,7 +2697,7 @@ class GenerationIntegrationTests(unittest.TestCase):
length_penalty=0.0,
)
transition_scores = model.compute_transition_beam_scores(
transition_scores = model.compute_transition_scores(
sequences=result.sequences, scores=result.scores, beam_indices=result.beam_indices
)
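The repeated test change above is the call-site migration this commit enables. A runnable before/after sketch, using the same tiny test model as the tests (illustrative, not from the diff):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
inputs = tokenizer(["Justin Timberlake"], return_tensors="pt")

outputs = model.generate(
    **inputs, num_beams=4, max_new_tokens=5, return_dict_in_generate=True, output_scores=True
)

# before this commit (removed method, `beam_indices` mandatory):
# transition_scores = model.compute_transition_beam_scores(
#     outputs.sequences, outputs.scores, outputs.beam_indices
# )

# after: one method for all decoding strategies; `beam_indices` may be
# omitted for greedy search or sampling
transition_scores = model.compute_transition_scores(
    outputs.sequences, outputs.scores, outputs.beam_indices
)
```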