Unverified Commit a1bbcf3f authored by Patrick von Platen, committed by GitHub

Refactoring the generate() function (#6949)

* first draft

* show design proposition for new generate method

* up

* make more readable

* make first version

* gpt2 tests pass

* make beam search for gpt2 work

* add first encoder-decoder code

* delete typo

* make t5 work

* save intermediate

* make bart work with beam search

* finish beam search bart / t5

* add default kwargs

* make more tests pass

* fix no bad words sampler

* some fixes and tests for all distribution processors

* fix test

* fix rag slow tests

* merge to master

* add nograd to generate

* make all slow tests pass

* speed up generate

* fix edge case bug

* small fix

* correct typo

* add type hints and docstrings

* fix typos in tests

* add beam search tests

* add tests for beam scorer

* fix test rag

* finish beam search tests

* move generation tests into separate file

* fix generation tests

* more tests

* add aggressive generation tests

* fix tests

* add gpt2 sample test

* add more docstring

* add more docs

* finish doc strings

* apply some more of Sylvain's and Sam's comments

* fix some typos

* make fix copies

* apply Lysandre's and Sylvain's comments

* final corrections on examples

* small fix for reformer
parent b63beb74
...@@ -272,3 +272,4 @@ conversion utilities for the following models:
internal/pipelines_utils
internal/tokenization_utils
internal/trainer_utils
internal/generation_utils
Utilities for Generation
-----------------------------------------------------------------------------------------------------------------------
This page lists all the utility functions used by :meth:`~transformers.PretrainedModel.generate`,
:meth:`~transformers.PretrainedModel.greedy_search`, :meth:`~transformers.PretrainedModel.sample`,
:meth:`~transformers.PretrainedModel.beam_search`, and :meth:`~transformers.PretrainedModel.beam_sample`.
Most of those are only useful if you are studying the code of the generate methods in the library.
LogitsProcessor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
A :class:`~transformers.LogitsProcessor` can be used to modify the prediction scores of a language model head for
generation.
.. autoclass:: transformers.LogitsProcessor
:members: __call__
.. autoclass:: transformers.LogitsProcessorList
:members: __call__
.. autoclass:: transformers.MinLengthLogitsProcessor
:members: __call__
.. autoclass:: transformers.TemperatureLogitsWarper
:members: __call__
.. autoclass:: transformers.RepetitionPenaltyLogitsProcessor
:members: __call__
.. autoclass:: transformers.TopPLogitsWarper
:members: __call__
.. autoclass:: transformers.TopKLogitsWarper
:members: __call__
.. autoclass:: transformers.NoRepeatNGramLogitsProcessor
:members: __call__
.. autoclass:: transformers.NoBadWordsLogitsProcessor
:members: __call__
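As a rough sketch of how these classes compose (the tensors and token ids below are illustrative placeholders, not
library defaults)::

    import torch
    from transformers import LogitsProcessorList, MinLengthLogitsProcessor, TopKLogitsWarper

    processors = LogitsProcessorList([MinLengthLogitsProcessor(min_length=10, eos_token_id=2)])
    warpers = LogitsProcessorList([TopKLogitsWarper(top_k=50)])

    input_ids = torch.tensor([[0, 5, 7]])  # (batch_size=1, sequence_length=3)
    scores = torch.randn(1, 100)  # (batch_size, vocab_size)
    scores = processors(input_ids, scores)  # EOS (id 2) is masked while the sequence is shorter than min_length
    scores = warpers(input_ids, scores)  # all but the 50 highest-scoring tokens are set to -inf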
BeamSearch
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.BeamScorer
:members: process, finalize
.. autoclass:: transformers.BeamSearchScorer
:members: process, finalize
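A minimal sketch of how a scorer is typically constructed before being handed to a beam search loop (the sizes below
are illustrative placeholders)::

    import torch
    from transformers import BeamSearchScorer

    beam_scorer = BeamSearchScorer(
        batch_size=2,  # number of source sequences decoded in parallel
        max_length=20,  # maximum length of the generated hypotheses
        num_beams=4,  # beams kept per source sequence
        device=torch.device("cpu"),
    )
    # beam_scorer.process(...) is then called at every generation step and
    # beam_scorer.finalize(...) once all beams are done or max_length is reached.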
...@@ -45,7 +45,7 @@ TFModelUtilsMixin
:members:
Generation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.generation_utils.GenerationMixin
......
...@@ -299,6 +299,19 @@ if is_torch_available():
TextDataset,
TextDatasetForNextSentencePrediction,
)
from .generation_beam_search import BeamScorer, BeamSearchScorer
from .generation_logits_process import (
LogitsProcessor,
LogitsProcessorList,
LogitsWarper,
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
from .generation_utils import top_k_top_p_filtering
from .modeling_albert import (
ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
......
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from collections import UserDict
from typing import Optional, Tuple
import torch
from .file_utils import add_start_docstrings
PROCESS_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using any class inheriting from :class:`~transformers.PretrainedTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are input IDs? <../glossary.html#input-ids>`__
next_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2 * num_beams)`):
Current scores of the top :obj:`2 * num_beams` non-finished beam hypotheses.
next_tokens (:obj:`torch.LongTensor` of shape :obj:`(batch_size, 2 * num_beams)`):
:obj:`input_ids` of the tokens corresponding to the top :obj:`2 * num_beams` non-finished beam hypotheses.
next_indices (:obj:`torch.LongTensor` of shape :obj:`(batch_size, 2 * num_beams)`):
Beam indices indicating to which beam hypothesis the :obj:`next_tokens` correspond.
pad_token_id (:obj:`int`, `optional`):
The id of the `padding` token.
eos_token_id (:obj:`int`, `optional`):
The id of the `end-of-sequence` token.
Return:
:obj:`UserDict`: A dictionary composed of the fields as defined above:
- **next_beam_scores** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Updated
scores of all non-finished beams.
- **next_beam_tokens** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Next tokens
to be added to the non-finished beam_hypotheses.
- **next_beam_indices** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Beam indices
indicating to which beam the next tokens shall be added.
"""
FINALIZE_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using any class inheriting from :class:`~transformers.PretrainedTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are input IDs? <../glossary.html#input-ids>`__
final_beam_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`):
The final scores of all non-finished beams.
final_beam_tokens (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`):
The last tokens to be added to the non-finished beam_hypotheses.
final_beam_indices (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`):
The beam indices indicating to which beam the :obj:`final_beam_tokens` shall be added.
pad_token_id (:obj:`int`, `optional`):
The id of the `padding` token.
eos_token_id (:obj:`int`, `optional`):
The id of the `end-of-sequence` token.
Return:
:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all
batches finished early due to the :obj:`eos_token_id`.
"""
class BeamScorer(ABC):
"""
Abstract base class for all beam scorers that are used for :meth:`~transformers.PretrainedModel.beam_search` and
:meth:`~transformers.PretrainedModel.beam_sample`.
"""
@abstractmethod
@add_start_docstrings(PROCESS_INPUTS_DOCSTRING)
def process(
self,
input_ids: torch.LongTensor,
next_scores: torch.FloatTensor,
next_tokens: torch.LongTensor,
next_indices: torch.LongTensor,
**kwargs
) -> Tuple[torch.Tensor]:
raise NotImplementedError("This is an abstract method.")
@abstractmethod
@add_start_docstrings(FINALIZE_INPUTS_DOCSTRING)
def finalize(
self,
input_ids: torch.LongTensor,
next_scores: torch.FloatTensor,
next_tokens: torch.LongTensor,
next_indices: torch.LongTensor,
**kwargs
) -> torch.LongTensor:
raise NotImplementedError("This is an abstract method.")
class BeamSearchScorer(BeamScorer):
r"""
:class:`transformers.BeamScorer` implementing standard beam search decoding.
Adapted in part from `Facebook's XLM beam search code
<https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529>`__.
Args:
batch_size (:obj:`int`):
Batch Size of :obj:`input_ids` for which beam search decoding is run in parallel.
max_length (:obj:`int`):
The maximum length of the sequence to be generated.
num_beams (:obj:`int`):
Number of beams for beam search.
device (:obj:`torch.device`):
Defines the device type (*e.g.*, :obj:`"cpu"` or :obj:`"cuda"`) on which this instance of
:obj:`BeamSearchScorer` will be allocated.
length_penalty (:obj:`float`, `optional`, defaults to 1.0):
Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the
model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer
sequences.
do_early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not.
num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1):
The number of beam hypotheses that shall be returned upon calling
:meth:`~transformer.BeamSearchScorer.finalize`.
"""
def __init__(
self,
batch_size: int,
max_length: int,
num_beams: int,
device: torch.device,
length_penalty: Optional[float] = 1.0,
do_early_stopping: Optional[bool] = False,
num_beam_hyps_to_keep: Optional[int] = 1,
):
self.max_length = max_length
self.num_beams = num_beams
self.device = device
self.length_penalty = length_penalty
self.do_early_stopping = do_early_stopping
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
self._is_init = False
self._beam_hyps = [
BeamHypotheses(
num_beams=self.num_beams,
max_length=self.max_length,
length_penalty=self.length_penalty,
early_stopping=self.do_early_stopping,
)
for _ in range(batch_size)
]
self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)
if not isinstance(num_beams, int) or num_beams <= 1:
raise ValueError(
f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead."
)
@property
def is_done(self) -> bool:
return self._done.all()
def process(
self,
input_ids: torch.LongTensor,
next_scores: torch.FloatTensor,
next_tokens: torch.LongTensor,
next_indices: torch.LongTensor,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
) -> Tuple[torch.Tensor]:
cur_len = input_ids.shape[-1]
batch_size = len(self._beam_hyps)
assert batch_size == (input_ids.shape[0] // self.num_beams)
device = input_ids.device
next_beam_scores = torch.zeros((batch_size, self.num_beams), dtype=next_scores.dtype, device=device)
next_beam_tokens = torch.zeros((batch_size, self.num_beams), dtype=next_tokens.dtype, device=device)
next_beam_indices = torch.zeros((batch_size, self.num_beams), dtype=next_indices.dtype, device=device)
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
assert (
len(beam_hyp) >= self.num_beams
), "Batch can only be done if at least {} beams have been generated".format(self.num_beams)
assert (
eos_token_id is not None and pad_token_id is not None
), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
# pad the batch
next_beam_scores[batch_idx, :] = 0
next_beam_tokens[batch_idx, :] = pad_token_id
next_beam_indices[batch_idx, :] = 0
continue
# next tokens for this sentence
beam_idx = 0
for beam_token_rank, (next_token, next_score, next_index) in enumerate(
zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
):
batch_beam_idx = batch_idx * self.num_beams + next_index
# add to generated hypotheses if end of sentence
if (eos_token_id is not None) and (next_token.item() == eos_token_id):
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.num_beams
if is_beam_token_worse_than_top_num_beams:
continue
beam_hyp.add(
input_ids[batch_beam_idx].clone(),
next_score.item(),
)
else:
# add next predicted token since it is not eos_token
next_beam_scores[batch_idx, beam_idx] = next_score
next_beam_tokens[batch_idx, beam_idx] = next_token
next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
beam_idx += 1
# once the beam for next step is full, don't add more tokens to it.
if beam_idx == self.num_beams:
break
if beam_idx < self.num_beams:
raise ValueError(
f"At most {self.num_beams} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
)
# Check if we are done so that we can save a pad step if all(done)
self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
next_scores[batch_idx].max().item(), cur_len
)
return UserDict(
{
"next_beam_scores": next_beam_scores.view(-1),
"next_beam_tokens": next_beam_tokens.view(-1),
"next_beam_indices": next_beam_indices.view(-1),
}
)
def finalize(
self,
input_ids: torch.LongTensor,
final_beam_scores: torch.FloatTensor,
final_beam_tokens: torch.LongTensor,
final_beam_indices: torch.LongTensor,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
) -> torch.LongTensor:
batch_size = len(self._beam_hyps)
# finalize all open beam hypotheses and add to generated hypotheses
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
continue
# need to add best num_beams hypotheses to generated hyps
for beam_id in range(self.num_beams):
batch_beam_idx = batch_idx * self.num_beams + beam_id
final_score = final_beam_scores[batch_beam_idx].item()
final_tokens = input_ids[batch_beam_idx]
beam_hyp.add(final_tokens, final_score)
# select the best hypotheses
sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
best = []
# retrieve best hypotheses
for i, beam_hyp in enumerate(self._beam_hyps):
sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
for j in range(self.num_beam_hyps_to_keep):
best_hyp = sorted_hyps.pop()[1]
sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
best.append(best_hyp)
# prepare for adding eos
sent_max_len = min(sent_lengths.max().item() + 1, self.max_length)
decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
# shorter batches are padded if needed
if sent_lengths.min().item() != sent_lengths.max().item():
assert pad_token_id is not None, "`pad_token_id` has to be defined"
decoded.fill_(pad_token_id)
# fill with hypotheses and eos_token_id if the latter fits in
for i, hypo in enumerate(best):
decoded[i, : sent_lengths[i]] = hypo
if sent_lengths[i] < self.max_length:
decoded[i, sent_lengths[i]] = eos_token_id
return decoded
class BeamHypotheses:
def __init__(self, num_beams: int, max_length: int, length_penalty: float, early_stopping: bool):
"""
Initialize n-best list of hypotheses.
"""
self.max_length = max_length - 1 # ignoring bos_token
self.length_penalty = length_penalty
self.early_stopping = early_stopping
self.num_beams = num_beams
self.beams = []
self.worst_score = 1e9
def __len__(self):
"""
Number of hypotheses in the list.
"""
return len(self.beams)
def add(self, hyp: torch.LongTensor, sum_logprobs: float):
"""
Add a new hypothesis to the list.
"""
score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
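# i.e. the length-normalized sum of log-probabilities; a larger (less negative) score is better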
if len(self) < self.num_beams or score > self.worst_score:
self.beams.append((score, hyp))
if len(self) > self.num_beams:
sorted_next_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
del self.beams[sorted_next_scores[0][1]]
self.worst_score = sorted_next_scores[1][0]
else:
self.worst_score = min(score, self.worst_score)
def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
"""
If there are enough hypotheses and none of the hypotheses being generated can become better than the worst
one in the heap, then we are done with this sentence.
"""
if len(self) < self.num_beams:
return False
elif self.early_stopping:
return True
else:
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
ret = self.worst_score >= cur_score
return ret
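As a quick, hedged numeric illustration of the length-penalized score computed in BeamHypotheses.add above (the
numbers are made up):
# score = sum_logprobs / (generated_length ** length_penalty)
sum_logprobs, length = -6.0, 4
print(sum_logprobs / length ** 1.0)  # -1.5, with the default length_penalty of 1.0
print(sum_logprobs / length ** 2.0)  # -0.375, a larger exponent shrinks the negative score, favoring longer hypotheses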
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from typing import Iterable, List
import numpy as np
import torch
from torch.nn import functional as F
from .file_utils import add_start_docstrings
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`~transformers.BertTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are input IDs? <../glossary.html#input-ids>`__
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
or scores for each vocabulary token after SoftMax.
Return:
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores.
"""
class LogitsProcessor(ABC):
"""Abstract base class for all logit processors that can be applied during generation."""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
"""Torch method for processing logits."""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
class LogitsWarper(ABC):
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
"""Torch method for warping logits."""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
class LogitsProcessorList(list):
"""
This class can be used to create a list of :class:`~transformers.LogitsProcessor` or
:class:`~transformers.LogitsWarper` to subsequently process a :obj:`scores` input tensor. This class inherits from
list and adds a specific `__call__` method to apply each :class:`~transformers.LogitsProcessor` or
:class:`~transformers.LogitsProcessor` to the inputs.
"""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
for processor in self:
scores = processor(input_ids, scores)
return scores
class MinLengthLogitsProcessor(LogitsProcessor):
r"""
:class:`transformers.LogitsProcessor` enforcing a min-length by setting EOS probability to 0.
Args:
min_length (:obj:`int`):
The minimum length below which the score of :obj:`eos_token_id` is set to :obj:`-float("Inf")`.
eos_token_id (:obj:`int`):
The id of the `end-of-sequence` token.
"""
def __init__(self, min_length: int, eos_token_id: int):
if not isinstance(min_length, int) or min_length < 0:
raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")
if not isinstance(eos_token_id, int) or eos_token_id < 0:
raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
self.min_length = min_length
self.eos_token_id = eos_token_id
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
if cur_len < self.min_length:
scores[:, self.eos_token_id] = -float("inf")
return scores
class TemperatureLogitsWarper(LogitsWarper):
r"""
:class:`transformers.LogitsWarper` for temperature (exponential scaling output probability distribution).
Args:
temperature (:obj:`float`):
The value used to modulate the logits distribution.
"""
def __init__(self, temperature: float):
if not isinstance(temperature, float) or not (temperature > 0):
raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")
self.temperature = temperature
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
scores = scores / self.temperature
return scores
class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
r"""
:class:`transformers.LogitsProcessor` enforcing an exponential penalty on repeated sequences.
Args:
repetition_penalty (:obj:`float`):
The parameter for repetition penalty. 1.0 means no penalty. See `this paper
<https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
"""
def __init__(self, penalty: float):
if not isinstance(penalty, float) or not (penalty > 0):
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
self.penalty = penalty
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
for i in range(scores.shape[0]):
for previous_token in set(input_ids[i].tolist()):
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
if scores[i, previous_token] < 0:
scores[i, previous_token] *= self.penalty
else:
scores[i, previous_token] /= self.penalty
return scores
class TopPLogitsWarper(LogitsWarper):
"""
:class:`transformers.LogitsWarper` that performs top-p filtering, i.e. restricting to the smallest set of most
probable tokens with probabilities that add up to :obj:`top_p` or higher.
Args:
top_p (:obj:`float`):
If set to < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are
kept for generation.
filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`):
All filtered values will be set to this float value.
min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
Minimum number of tokens that cannot be filtered.
"""
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
self.top_p = top_p
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
sorted_logits, sorted_indices = torch.sort(scores, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Mark tokens for removal once the cumulative probability exceeds top_p
sorted_indices_to_remove = cumulative_probs > self.top_p
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., : self.min_tokens_to_keep - 1] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores[indices_to_remove] = self.filter_value
return scores
class TopKLogitsWarper(LogitsWarper):
r"""
:class:`transformers.LogitsWarper` that performs top-k, i.e. restricting to the k highest probability elements.
Args:
top_k (:obj:`int`):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`):
All filtered values will be set to this float value.
min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
Minimum number of tokens that cannot be filtered.
"""
def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
self.top_k = top_k
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
scores[indices_to_remove] = self.filter_value
return scores
class NoRepeatNGramLogitsProcessor(LogitsProcessor):
r"""
:class:`transformers.LogitsProcessor` that enforces no repetition of n-grams. See `Fairseq
<https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345>`__.
Args:
ngram_size (:obj:`int`):
All ngrams of size :obj:`ngram_size` can only occur once.
"""
def __init__(self, ngram_size: int):
if not isinstance(ngram_size, int) or ngram_size <= 0:
raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
self.ngram_size = ngram_size
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
num_batch_hypotheses = scores.shape[0]
cur_len = input_ids.shape[-1]
banned_batch_tokens = self._calc_banned_ngram_tokens(input_ids, num_batch_hypotheses, cur_len)
for i, banned_tokens in enumerate(banned_batch_tokens):
scores[i, banned_tokens] = -float("inf")
return scores
def _calc_banned_ngram_tokens(
self, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int
) -> List[Iterable[int]]:
"""Copied from fairseq for no_repeat_ngram in beam_search"""
if cur_len + 1 < self.ngram_size:
# return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
return [[] for _ in range(num_hypos)]
generated_ngrams = [{} for _ in range(num_hypos)]
for idx in range(num_hypos):
gen_tokens = prev_input_ids[idx].tolist()
generated_ngram = generated_ngrams[idx]
for ngram in zip(*[gen_tokens[i:] for i in range(self.ngram_size)]):
prev_ngram_tuple = tuple(ngram[:-1])
generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
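# generated_ngram maps each (ngram_size - 1)-token prefix to the list of tokens that have already followed it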
def _get_generated_ngrams(hypo_idx):
# Before decoding the next token, prevent decoding of ngrams that have already appeared
start_idx = cur_len + 1 - self.ngram_size
ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
return generated_ngrams[hypo_idx].get(ngram_idx, [])
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
return banned_tokens
class NoBadWordsLogitsProcessor(LogitsProcessor):
"""
:class:`transformers.LogitsProcessor` that enforces that specified sequences will never be sampled.
Args:
bad_words_ids (:obj:`List[List[int]]`):
List of list of token ids that are not allowed to be generated. In order to get the tokens of the words
that should not appear in the generated text, use :obj:`tokenizer(bad_word,
add_prefix_space=True).input_ids`.
eos_token_id (:obj:`int`):
The id of the `end-of-sequence` token.
"""
def __init__(self, bad_words_ids: Iterable[Iterable[int]], eos_token_id: int):
if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0:
raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")
if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.")
if any(
any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids)
for bad_word_ids in bad_words_ids
):
raise ValueError(
f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
)
self.bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
for banned_token_seq in self.bad_words_ids:
assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format(
bad_words_ids
)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
banned_tokens = self._calc_banned_bad_words_ids(input_ids)
scores = self._set_scores_to_inf_for_banned_tokens(scores, banned_tokens)
return scores
def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
if len(tokens) == 0:
# if bad word tokens is just one token always ban it
return True
elif len(tokens) > len(prev_tokens):
# if bad word tokens are longer than prev input_ids they can't be equal
return False
elif prev_tokens[-len(tokens) :].tolist() == tokens:
# if tokens match
return True
else:
return False
def _calc_banned_bad_words_ids(self, prev_input_ids: Iterable[int]) -> Iterable[int]:
banned_tokens = []
for prev_input_ids_slice in prev_input_ids:
banned_tokens_slice = []
for banned_token_seq in self.bad_words_ids:
if self._tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False:
# if tokens do not match continue
continue
banned_tokens_slice.append(banned_token_seq[-1])
banned_tokens.append(banned_tokens_slice)
return banned_tokens
def _set_scores_to_inf_for_banned_tokens(self, scores: torch.Tensor, banned_tokens: List[List[int]]) -> torch.Tensor:
"""
Returns the scores with the banned token positions set to `-inf`. :obj:`banned_tokens` is expected to be a
list of length :obj:`batch_size`, where each entry is the list of token ids banned for that batch index.
Args:
scores: logits distribution of shape (batch size, vocabulary size)
banned_tokens: list of lists of banned token ids, one list per batch entry
"""
banned_mask_list = []
for idx, batch_banned_tokens in enumerate(banned_tokens):
for token in batch_banned_tokens:
banned_mask_list.append([idx, token])
if not banned_mask_list:
return scores
banned_mask = torch.LongTensor(banned_mask_list)
indices = torch.ones(len(banned_mask))
# A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates:
# [ 0 1 1 ]
# [ 0 0 0 ]
# [ 1 0 0 ]
banned_mask = (
torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool()
)
scores = scores.masked_fill(banned_mask, -float("inf"))
return scores
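A rough, self-contained sketch of chaining the warpers defined above for multinomial sampling (the tensors and
hyper-parameter values are illustrative, not library defaults):
import torch
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
warpers = LogitsProcessorList(
    [
        TemperatureLogitsWarper(temperature=0.7),  # sharpen the distribution
        TopKLogitsWarper(top_k=2),  # keep only the 2 highest-scoring tokens
        TopPLogitsWarper(top_p=0.9),  # then drop tokens outside the 0.9 probability mass
    ]
)
input_ids = torch.tensor([[0, 1]])  # (batch_size=1, sequence_length=2)
scores = torch.tensor([[2.0, 1.0, 0.5, -1.0]])  # (batch_size=1, vocab_size=4)
probs = torch.softmax(warpers(input_ids, scores), dim=-1)
next_token = torch.multinomial(probs, num_samples=1)  # sample only among the surviving tokens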
# coding=utf-8
# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...@@ -14,13 +14,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
from torch.nn import functional as F
from .file_utils import ModelOutput
from .generation_beam_search import BeamScorer, BeamSearchScorer
from .generation_logits_process import (
LogitsProcessorList,
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
from .utils import logging
...@@ -33,85 +43,245 @@ class GenerationMixin:
:class:`~transformers.PreTrainedModel`.
"""
def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
"""
Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to prepare inputs in the
generate method.
"""
return {"input_ids": input_ids}
def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
"""
Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to adjust the logits in
the generate method.
"""
return logits
def _prepare_input_ids_for_generation(self, bos_token_id: int) -> torch.LongTensor:
if bos_token_id is None:
raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
return torch.ones((1, 1), dtype=torch.long, device=self.device) * bos_token_id
def _prepare_attention_mask_for_generation(
self, input_ids: torch.Tensor, pad_token_id: int, eos_token_id: int
) -> torch.LongTensor:
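# only infer the mask from padding when a pad token is present in the inputs and differs from the EOS token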
is_pad_token_in_inputs_ids = (pad_token_id is not None) and (pad_token_id in input_ids)
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
(eos_token_id is not None) and (pad_token_id != eos_token_id)
)
if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id:
return input_ids.ne(pad_token_id).long()
return input_ids.new_ones(input_ids.shape)
def _prepare_encoder_decoder_kwargs_for_generation(
self, input_ids: torch.LongTensor, model_kwargs
) -> Dict[str, Any]:
# retrieve encoder hidden states
encoder = self.get_encoder()
encoder_kwargs = {
argument: value for argument, value in model_kwargs.items() if not argument.startswith("decoder_")
}
model_kwargs["encoder_outputs"]: ModelOutput = encoder(input_ids, return_dict=True, **encoder_kwargs)
return model_kwargs
def _prepare_decoder_input_ids_for_generation(
self, input_ids: torch.LongTensor, decoder_start_token_id: int = None, bos_token_id: int = None, **model_kwargs
) -> torch.LongTensor:
if "decoder_input_ids" in model_kwargs:
return model_kwargs["decoder_input_ids"]
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
decoder_input_ids = (
torch.ones((input_ids.shape[0], 1), dtype=input_ids.dtype, device=input_ids.device)
* decoder_start_token_id
)
return decoder_input_ids
def _get_pad_token_id(self, pad_token_id: int = None, eos_token_id: int = None) -> int:
if pad_token_id is None and eos_token_id is not None:
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
pad_token_id = eos_token_id
return pad_token_id
def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
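# resolution order: explicit/config decoder_start_token_id, the decoder sub-config's decoder_start_token_id,
# then bos_token_id, then the decoder sub-config's bos_token_id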
decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
)
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
if decoder_start_token_id is not None:
return decoder_start_token_id
elif (
hasattr(self.config, "decoder")
and hasattr(self.config.decoder, "decoder_start_token_id")
and self.config.decoder.decoder_start_token_id is not None
):
return self.config.decoder.decoder_start_token_id
elif bos_token_id is not None:
return bos_token_id
elif (
hasattr(self.config, "decoder")
and hasattr(self.config.decoder, "bos_token_id")
and self.config.decoder.bos_token_id is not None
):
return self.config.decoder.bos_token_id
raise ValueError(
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
)
@staticmethod
def _expand_inputs_for_generation(
input_ids: torch.LongTensor,
expand_size: int = 1,
is_encoder_decoder: bool = False,
attention_mask: torch.LongTensor = None,
encoder_outputs: ModelOutput = None,
**model_kwargs
) -> Tuple[torch.LongTensor, Dict[str, Any]]:
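# repeat every batch entry `expand_size` times (e.g. once per beam or per returned sequence),
# keeping the copies of the same example adjacent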
expanded_return_idx = (
torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
)
input_ids = input_ids.index_select(0, expanded_return_idx)
if attention_mask is not None:
model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)
if is_encoder_decoder:
assert encoder_outputs is not None
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
0, expanded_return_idx
)
model_kwargs["encoder_outputs"] = encoder_outputs
return input_ids, model_kwargs
@staticmethod
def _init_sequence_length_for_generation(
input_ids: torch.LongTensor, max_length: int
) -> Tuple[torch.Tensor, torch.Tensor, int]:
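# every sequence starts as unfinished (flag 1) with a provisional length of max_length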
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
sequence_lengths = input_ids.new(input_ids.shape[0]).fill_(max_length)
cur_len = input_ids.shape[-1]
return sequence_lengths, unfinished_sequences, cur_len
@staticmethod
def _update_seq_length_for_generation(
sequence_lengths: torch.LongTensor,
unfinished_sequences: torch.LongTensor,
cur_len: int,
is_eos_in_next_token: torch.BoolTensor,
) -> Tuple[torch.LongTensor, torch.LongTensor]:
# check if sentence is not finished yet
is_sent_unfinished = unfinished_sequences.mul(is_eos_in_next_token.long()).bool()
# update sentence length
sequence_lengths = sequence_lengths.masked_fill(is_sent_unfinished, cur_len)
unfinished_sequences = unfinished_sequences.mul((~is_eos_in_next_token).long())
return sequence_lengths, unfinished_sequences
@staticmethod
def _update_model_kwargs_for_generation(
outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
) -> Dict[str, Any]:
# update past
if "past_key_values" in outputs:
model_kwargs["past"] = outputs.past_key_values
elif "mems" in outputs:
model_kwargs["past"] = outputs.mems
elif "past_buckets_states" in outputs:
model_kwargs["past"] = outputs.past_buckets_states
else:
model_kwargs["past"] = None
# update attention mask
if not is_encoder_decoder:
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
return model_kwargs
@staticmethod
def _reorder_cache(past: Tuple[torch.Tensor], beam_idx: torch.Tensor) -> Tuple[torch.Tensor]:
"""
This function is used to re-order the :obj:`past_key_values` or :obj:`mems` cache if
:meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
called. This is required to match :obj:`past_key_values` or :obj:`mems` with the correct beam_idx at every
generation step.
For custom re-ordering of :obj:`past_key_values` or :obj:`mems`, the function should be implemented in
subclasses of :class:`~transformers.PreTrainedModel`.
"""
return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
"""
Enforce the repetition penalty (from the `CTRL paper <https://arxiv.org/abs/1909.05858>`__).
"""
for i in range(batch_size * num_beams):
for previous_token in set(prev_output_tokens[i].tolist()):
# if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
if lprobs[i, previous_token] < 0:
lprobs[i, previous_token] *= repetition_penalty
else:
lprobs[i, previous_token] /= repetition_penalty
def postprocess_next_token_scores(
self,
scores,
input_ids,
no_repeat_ngram_size,
bad_words_ids,
cur_len,
min_length,
max_length,
eos_token_id,
repetition_penalty,
batch_size,
num_beams,
):
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
self.enforce_repetition_penalty_(
scores,
batch_size,
num_beams,
input_ids,
repetition_penalty,
)
# set eos token prob to zero if min_length is not reached
if eos_token_id is not None and cur_len < min_length:
scores[:, eos_token_id] = -float("inf")
if no_repeat_ngram_size > 0:
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
num_batch_hypotheses = batch_size * num_beams
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
banned_batch_tokens = calc_banned_ngram_tokens(
input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
)
for i, banned_tokens in enumerate(banned_batch_tokens):
scores[i, banned_tokens] = -float("inf")
def _get_logits_warper(
self, top_k: int = None, top_p: float = None, temperature: float = None, num_beams: int = None
) -> LogitsProcessorList:
"""
This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
:obj:`~transformers.LogitsWarper` instances used for multinomial sampling.
"""
# init warp parameters
top_k = top_k if top_k is not None else self.config.top_k
top_p = top_p if top_p is not None else self.config.top_p
temperature = temperature if temperature is not None else self.config.temperature
# instantiate warpers list
warpers = LogitsProcessorList()
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
if top_k is not None and top_k != 0:
warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
if top_p is not None and top_p < 1.0:
warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
if temperature is not None and temperature != 1.0:
warpers.append(TemperatureLogitsWarper(temperature))
return warpers
def _get_logits_processor(
self,
repetition_penalty: float,
no_repeat_ngram_size: int,
bad_words_ids: List[List[int]],
min_length: int,
eos_token_id: int,
) -> LogitsProcessorList:
"""
This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
:obj:`~transformers.LogitsProcessor` instances used to modify the scores of the language model head.
"""
# init warp parameters
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
no_repeat_ngram_size = (
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
)
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
min_length = min_length if min_length is not None else self.config.min_length
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
# instantiate processors list
processors = LogitsProcessorList()
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
if repetition_penalty is not None and repetition_penalty != 1.0:
processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
if bad_words_ids is not None:
processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
if min_length is not None and eos_token_id is not None and min_length > -1:
processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
return processors
if bad_words_ids is not None:
# Exclude EOS token (already processed)
bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
# calculate a list of banned tokens according to bad words
banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids)
# Modify the scores in place by setting the banned tokens logits to `-inf`
set_scores_to_inf_for_banned_tokens(scores, banned_tokens)
return scores
@torch.no_grad()
def generate(
self,
input_ids: Optional[torch.LongTensor] = None,
max_length: Optional[int] = None,
min_length: Optional[int] = None,
do_sample: Optional[bool] = None,
...@@ -128,17 +298,13 @@ class GenerationMixin:
length_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
num_return_sequences: Optional[int] = None,
decoder_start_token_id: Optional[int] = None,
use_cache: Optional[bool] = None,
**model_kwargs
) -> torch.LongTensor:
r"""
Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
multinomial sampling, beam-search decoding, and beam-search multinomial sampling.
Adapted in part from `Facebook's XLM beam search code
<https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529>`__.
Apart from :obj:`input_ids` and :obj:`attention_mask`, all the arguments below will default to the value of the
attribute of the same name inside the :class:`~transformers.PretrainedConfig` of the model. The default values
...@@ -152,9 +318,6 @@ class GenerationMixin:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty
:obj:`torch.LongTensor` of shape :obj:`(1,)`.
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
initial input_ids for the decoder of encoder-decoder type models. If :obj:`None` then only
decoder_start_token_id is passed as the first token to the decoder.
max_length (:obj:`int`, `optional`, defaults to 20):
The maximum length of the sequence to be generated.
min_length (:obj:`int`, `optional`, defaults to 10):
...@@ -170,7 +333,7 @@ class GenerationMixin:
top_k (:obj:`int`, `optional`, defaults to 50):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p (:obj:`float`, `optional`, defaults to 1.0):
If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or
higher are kept for generation.
repetition_penalty (:obj:`float`, `optional`, defaults to 1.0):
The parameter for repetition penalty. 1.0 means no penalty. See `this paper
...@@ -182,792 +345,854 @@ class GenerationMixin:
eos_token_id (:obj:`int`, `optional`):
The id of the `end-of-sequence` token.
length_penalty (:obj:`float`, `optional`, defaults to 1.0):
Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the
model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer
sequences.
no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
If set to int > 0, all ngrams of that size can only occur once.
bad_words_ids(:obj:`List[List[int]]`, `optional`):
List of token ids that are not allowed to be generated. In order to get the tokens of the words that
should not appear in the generated text, use :obj:`tokenizer(bad_word,
add_prefix_space=True).input_ids`.
num_return_sequences(:obj:`int`, `optional`, defaults to 1):
The number of independently computed returned sequences for each element in the batch.
attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values are in ``[0, 1]``, 1 for
tokens that are not masked, and 0 for masked tokens. If not provided, will default to a tensor the same
shape as :obj:`input_ids` that masks the pad token. `What are attention masks?
<../glossary.html#attention-mask>`__
decoder_start_token_id (:obj:`int`, `optional`):
If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
speed up decoding.
model_kwargs:
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the
model is an Encoder-Decoder model, encoder specific kwargs should not be prefixed and decoder specific
kwargs should be prefixed with `decoder_`.
Return:
:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all
batches finished early due to the :obj:`eos_token_id`.
Examples::
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
>>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
>>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
>>> # do greedy decoding without providing a prompt
>>> outputs = model.generate(max_length=40)
>>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
>>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
>>> document = (
... "at least two people were killed in a suspected bomb attack on a passenger bus "
... "in the strife-torn southern philippines on monday , the military said."
... )
>>> # encode input context
>>> input_ids = tokenizer(document, return_tensors="pt").input_ids
>>> # generate 3 independent sequences using beam search decoding (5 beams)
>>> # with T5 encoder-decoder model conditioned on short news article.
>>> outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3)
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
>>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
>>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
>>> input_context = "The dog"
>>> # encode input context
>>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
>>> # generate 3 candidates using sampling
>>> outputs = model.generate(input_ids=input_ids, max_length=20, num_return_sequences=3, do_sample=True)
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
>>> tokenizer = AutoTokenizer.from_pretrained("ctrl")
>>> model = AutoModelForCausalLM.from_pretrained("ctrl")
>>> # "Legal" is one of the control codes for ctrl
>>> input_context = "Legal My neighbor is"
>>> # encode input context
>>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
>>> outputs = model.generate(input_ids=input_ids, max_length=20, repetition_penalty=1.2)
>>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> input_context = "My cute dog"
>>> # get tokens of words that should not be generated
>>> bad_words_ids = [tokenizer(bad_word, add_prefix_space=True).input_ids for bad_word in ["idiot", "stupid", "shut up"]]
>>> # encode input context
>>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
>>> # generate sequences without allowing bad_words to be generated
>>> outputs = model.generate(input_ids=input_ids, max_length=20, do_sample=True, bad_words_ids=bad_words_ids)
>>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
"""
# We cannot generate if the model does not have a LM head # set init values
if self.get_output_embeddings() is None: num_beams = num_beams if num_beams is not None else self.config.num_beams
raise AttributeError(
"You tried to generate sequences with a model that does not have a LM Head."
"Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )"
)
max_length = max_length if max_length is not None else self.config.max_length max_length = max_length if max_length is not None else self.config.max_length
min_length = min_length if min_length is not None else self.config.min_length
do_sample = do_sample if do_sample is not None else self.config.do_sample do_sample = do_sample if do_sample is not None else self.config.do_sample
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
use_cache = use_cache if use_cache is not None else self.config.use_cache
num_beams = num_beams if num_beams is not None else self.config.num_beams
temperature = temperature if temperature is not None else self.config.temperature
top_k = top_k if top_k is not None else self.config.top_k
top_p = top_p if top_p is not None else self.config.top_p
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
no_repeat_ngram_size = (
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
)
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
num_return_sequences = ( num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
) )
decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
)
if input_ids is not None: pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
batch_size = input_ids.shape[0] # overridden by the input batch_size bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
else: eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
batch_size = 1 use_cache = use_cache if use_cache is not None else self.config.use_cache
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
assert temperature > 0, "`temperature` should be strictly positive."
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
assert input_ids is not None or (
isinstance(bos_token_id, int) and bos_token_id >= 0
), "If input_ids is not defined, `bos_token_id` should be a positive integer."
assert pad_token_id is None or (
isinstance(pad_token_id, int) and (pad_token_id >= 0)
), "`pad_token_id` should be a positive integer."
assert (eos_token_id is None) or (
isinstance(eos_token_id, int) and (eos_token_id >= 0)
), "`eos_token_id` should be a positive integer."
assert length_penalty > 0, "`length_penalty` should be strictly positive."
assert (
isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
), "`no_repeat_ngram_size` should be a positive integer."
assert (
isinstance(num_return_sequences, int) and num_return_sequences > 0
), "`num_return_sequences` should be a strictly positive integer."
assert (
bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"
if input_ids is None: if input_ids is None:
assert isinstance(bos_token_id, int) and bos_token_id >= 0, ( # init `input_ids` with bos_token_id
"you should either supply a context to complete as `input_ids` input " input_ids = self._prepare_input_ids_for_generation(bos_token_id)
"or a `bos_token_id` (integer >= 0) as a first token to start the generation."
) if model_kwargs.get("attention_mask", None) is None:
input_ids = torch.full( # init `attention_mask` depending on `pad_token_id`
(batch_size, 1), model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
bos_token_id, input_ids, pad_token_id, eos_token_id
dtype=torch.long,
device=next(self.parameters()).device,
) )
else:
assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)." # special case if pad_token_id is not defined
# not allow to duplicate outputs when greedy decoding
if do_sample is False:
if num_beams == 1:
# no_beam_search greedy generation conditions
assert (
num_return_sequences == 1
), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"
else:
# beam_search greedy generation conditions
assert (
num_beams >= num_return_sequences
), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
# create attention mask if necessary
# TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
attention_mask = input_ids.ne(pad_token_id).long()
elif attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
# set pad_token_id to eos_token_id if not set. Important that this is done after
# attention_mask is created
if pad_token_id is None and eos_token_id is not None: if pad_token_id is None and eos_token_id is not None:
logger.warning( logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-ended generation.")
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
)
pad_token_id = eos_token_id pad_token_id = eos_token_id
# vocab size
if hasattr(self.config, "vocab_size"):
vocab_size = self.config.vocab_size
elif (
self.config.is_encoder_decoder
and hasattr(self.config, "decoder")
and hasattr(self.config.decoder, "vocab_size")
):
vocab_size = self.config.decoder.vocab_size
else:
raise ValueError("either self.config.vocab_size or self.config.decoder.vocab_size needs to be defined")
# set effective batch size and effective batch multiplier according to do_sample
if do_sample:
effective_batch_size = batch_size * num_return_sequences
effective_batch_mult = num_return_sequences
else:
effective_batch_size = batch_size
effective_batch_mult = 1
if self.config.is_encoder_decoder: if self.config.is_encoder_decoder:
if decoder_start_token_id is None: # add encoder_outputs to model_kwargs
# see if BOS token can be used for decoder_start_token_id model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
if bos_token_id is not None:
decoder_start_token_id = bos_token_id # set input_ids as decoder_input_ids
elif ( input_ids = self._prepare_decoder_input_ids_for_generation(
hasattr(self.config, "decoder") input_ids, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id, **model_kwargs
and hasattr(self.config.decoder, "bos_token_id")
and self.config.decoder.bos_token_id is not None
):
decoder_start_token_id = self.config.decoder.bos_token_id
else:
raise ValueError(
"decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
)
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
# get encoder and store encoder outputs
encoder = self.get_encoder()
encoder_outputs: ModelOutput = encoder(input_ids, attention_mask=attention_mask, return_dict=True)
# Expand input ids if num_beams > 1 or num_return_sequences > 1
if num_return_sequences > 1 or num_beams > 1:
input_ids_len = input_ids.shape[-1]
input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
attention_mask = attention_mask.unsqueeze(1).expand(
batch_size, effective_batch_mult * num_beams, input_ids_len
) )
input_ids = input_ids.contiguous().view( if "encoder_outputs" not in model_kwargs or not isinstance(model_kwargs["encoder_outputs"], ModelOutput):
effective_batch_size * num_beams, input_ids_len raise ValueError("Make sure that `model_kwargs` include `encoder_outputs` of type `ModelOutput`.")
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
attention_mask = attention_mask.contiguous().view( # determine generation mode
effective_batch_size * num_beams, input_ids_len is_greedy_gen_mode = (num_beams == 1) and do_sample is False
) # shape: (batch_size * num_return_sequences * num_beams, cur_len) is_sample_gen_mode = (num_beams == 1) and do_sample is True
is_beam_gen_mode = (num_beams > 1) and do_sample is False
is_beam_sample_gen_mode = (num_beams > 1) and do_sample is True
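        # these flags pick exactly one decoding strategy; the rest of generate() just dispatches:
        #   num_beams == 1, do_sample == False -> greedy_search
        #   num_beams == 1, do_sample == True  -> sample
        #   num_beams  > 1, do_sample == False -> beam_search
        #   num_beams  > 1, do_sample == True  -> beam_sample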
# set model_kwargs
model_kwargs["use_cache"] = use_cache
# get distribution pre_processing samplers
logits_processor = self._get_logits_processor(
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
min_length=min_length,
eos_token_id=eos_token_id,
)
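        # a rough sketch of what `_get_logits_processor` assembles here (which members end up in
        # the list depends on the arguments that are set; constructor details are assumed):
        #
        #     logits_processor = LogitsProcessorList([
        #         RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty),
        #         NoRepeatNGramLogitsProcessor(no_repeat_ngram_size),
        #         NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id),
        #         MinLengthLogitsProcessor(min_length, eos_token_id),
        #     ])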
if self.config.is_encoder_decoder: if is_greedy_gen_mode:
device = next(self.parameters()).device if num_return_sequences > 1:
if decoder_input_ids is not None: raise ValueError(
# give initial decoder input ids f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
input_ids = decoder_input_ids.repeat(effective_batch_size * num_beams, 1).to(device)
else:
# create empty decoder input_ids
input_ids = torch.full(
(effective_batch_size * num_beams, 1),
decoder_start_token_id,
dtype=torch.long,
device=device,
) )
cur_len = input_ids.shape[-1]
# greedy search
assert ( return self.greedy_search(
batch_size == encoder_outputs.last_hidden_state.shape[0] input_ids,
), f"expected encoder_outputs.last_hidden_state to have 1st dimension bs={batch_size}, got {encoder_outputs.last_hidden_state.shape[0]} " logits_processor=logits_processor,
max_length=max_length,
# expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1) pad_token_id=pad_token_id,
expanded_batch_idxs = ( eos_token_id=eos_token_id,
torch.arange(batch_size) **model_kwargs,
.view(-1, 1)
.repeat(1, num_beams * effective_batch_mult)
.view(-1)
.to(input_ids.device)
) )
# expand encoder_outputs elif is_sample_gen_mode:
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select( # get probability distribution warper
0, expanded_batch_idxs logits_warper = self._get_logits_warper(
top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
) )
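            # roughly equivalent to building the warpers by hand (a sketch; which warpers the
            # helper includes depends on the values passed):
            #
            #     logits_warper = LogitsProcessorList([
            #         TemperatureLogitsWarper(temperature),
            #         TopKLogitsWarper(top_k),
            #         TopPLogitsWarper(top_p),
            #     ])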
# save encoder_outputs in `model_kwargs` # expand input_ids with `num_return_sequences` additional sequences per batch
model_kwargs["encoder_outputs"] = encoder_outputs input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids,
expand_size=num_return_sequences,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
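            # e.g. with batch_size=2 and num_return_sequences=3 the prompt tensor of shape
            # (2, cur_len) becomes (6, cur_len); each returned sequence is then sampled from its
            # own copy of the prompt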
else: # sample
cur_len = input_ids.shape[-1] return self.sample(
input_ids,
logits_processor=logits_processor,
logits_warper=logits_warper,
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
**model_kwargs,
)
assert ( elif is_beam_gen_mode:
cur_len < max_length batch_size = input_ids.shape[0]
), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`"
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
if num_return_sequences > num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
if num_beams > 1: beam_scorer = BeamSearchScorer(
output = self._generate_beam_search( batch_size=batch_size,
max_length=max_length,
num_beams=num_beams,
device=self.device,
length_penalty=length_penalty,
do_early_stopping=early_stopping,
num_beam_hyps_to_keep=num_return_sequences,
)
# interleave with `num_beams`
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
)
return self.beam_search(
input_ids, input_ids,
cur_len=cur_len, beam_scorer,
logits_processor=logits_processor,
max_length=max_length, max_length=max_length,
min_length=min_length,
do_sample=do_sample,
early_stopping=early_stopping,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
pad_token_id=pad_token_id, pad_token_id=pad_token_id,
eos_token_id=eos_token_id, eos_token_id=eos_token_id,
batch_size=effective_batch_size, **model_kwargs,
num_return_sequences=num_return_sequences, )
length_penalty=length_penalty,
elif is_beam_sample_gen_mode:
logits_warper = self._get_logits_warper(
top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
)
batch_size = input_ids.shape[0] * num_return_sequences
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
max_length=max_length,
num_beams=num_beams, num_beams=num_beams,
vocab_size=vocab_size, device=self.device,
attention_mask=attention_mask, length_penalty=length_penalty,
use_cache=use_cache, do_early_stopping=early_stopping,
model_kwargs=model_kwargs,
) )
else:
output = self._generate_no_beam_search( # interleave with `num_beams * num_return_sequences`
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids,
expand_size=num_beams * num_return_sequences,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
return self.beam_sample(
input_ids, input_ids,
cur_len=cur_len, beam_scorer,
logits_processor=logits_processor,
logits_warper=logits_warper,
max_length=max_length, max_length=max_length,
min_length=min_length,
do_sample=do_sample,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
pad_token_id=pad_token_id, pad_token_id=pad_token_id,
eos_token_id=eos_token_id, eos_token_id=eos_token_id,
batch_size=effective_batch_size, **model_kwargs,
attention_mask=attention_mask,
use_cache=use_cache,
model_kwargs=model_kwargs,
) )
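        # a minimal usage sketch of the dispatch above (illustrative only; model, tokenizer and
        # prompt are placeholders):
        #
        #     input_ids = tokenizer("The dog", return_tensors="pt").input_ids
        #     model.generate(input_ids)                               # -> greedy_search
        #     model.generate(input_ids, do_sample=True, top_k=50)     # -> sample
        #     model.generate(input_ids, num_beams=4)                  # -> beam_search
        #     model.generate(input_ids, num_beams=4, do_sample=True)  # -> beam_sample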
return output def greedy_search(
def _generate_no_beam_search(
self, self,
input_ids, input_ids: torch.LongTensor,
cur_len, logits_processor: Optional[LogitsProcessorList] = None,
max_length, max_length: Optional[int] = None,
min_length, pad_token_id: Optional[int] = None,
do_sample, eos_token_id: Optional[int] = None,
temperature, **model_kwargs
top_k,
top_p,
repetition_penalty,
no_repeat_ngram_size,
bad_words_ids,
pad_token_id,
eos_token_id,
batch_size,
attention_mask,
use_cache,
model_kwargs,
): ):
r"""
Generates sequences for models with a language modeling head using greedy decoding.
Parameters:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty
:obj:`torch.LongTensor` of shape :obj:`(1,)`.
logits_processor (:obj:`LogitsProcessorList`, `optional`):
An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
:class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
head applied at each generation step.
max_length (:obj:`int`, `optional`, defaults to 20):
The maximum length of the sequence to be generated.
pad_token_id (:obj:`int`, `optional`):
The id of the `padding` token.
eos_token_id (:obj:`int`, `optional`):
The id of the `end-of-sequence` token.
model_kwargs:
Additional model specific keyword arguments will be forwarded to the :obj:`forward` function of the
model. If model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.
Return:
:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all
batches finished early due to the :obj:`eos_token_id`.
Examples::
>>> from transformers import (
... AutoTokenizer,
... AutoModelForCausalLM,
... LogitsProcessorList,
... MinLengthLogitsProcessor,
... )
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
>>> model.config.pad_token_id = model.config.eos_token_id
>>> input_prompt = "Today is a beautiful day, and"
>>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
>>> # instantiate logits processors
>>> logits_processor = LogitsProcessorList([
... MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id),
... ])
>>> outputs = model.greedy_search(input_ids, logits_processor=logits_processor)
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
""" """
Generate sequences for each example without beam search (num_beams == 1). All returned sequence are generated
independently.
"""
# length of generated sentences / unfinished sentences
unfinished_sents = input_ids.new(batch_size).fill_(1)
sent_lengths = input_ids.new(batch_size).fill_(max_length)
past = None # init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
max_length = max_length if max_length is not None else self.config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
# init sequence length tensors
sequence_lengths, unfinished_sequences, cur_len = self._init_sequence_length_for_generation(
input_ids, max_length
)
while cur_len < max_length: while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation( # prepare model inputs
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
)
# forward pass to get next token
outputs = self(**model_inputs, return_dict=True) outputs = self(**model_inputs, return_dict=True)
next_token_logits = outputs.logits[:, -1, :] next_token_logits = outputs.logits[:, -1, :]
scores = self.postprocess_next_token_scores( # pre-process distribution
scores=next_token_logits, scores = logits_processor(input_ids, next_token_logits)
input_ids=input_ids,
no_repeat_ngram_size=no_repeat_ngram_size, # argmax
bad_words_ids=bad_words_ids, next_tokens = torch.argmax(scores, dim=-1)
cur_len=cur_len,
min_length=min_length, # finished sentences should have their next token be a padding token
max_length=max_length, if eos_token_id is not None:
eos_token_id=eos_token_id, assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."
repetition_penalty=repetition_penalty, next_tokens = next_tokens * unfinished_sequences + (pad_token_id) * (1 - unfinished_sequences)
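            # e.g. with unfinished_sequences = [1, 0] and pad_token_id = 99, next_tokens [17, 42]
            # becomes [17, 99]: a sequence that already hit eos_token_id keeps receiving pads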
batch_size=batch_size,
num_beams=1, # add token and increase length by one
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
# update sequence length
if eos_token_id is not None:
sequence_lengths, unfinished_sequences = self._update_seq_length_for_generation(
sequence_lengths, unfinished_sequences, cur_len, next_tokens == eos_token_id
)
# update model kwargs
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
) )
# if model has past, then set the past variable to speed up decoding # stop when there is a </s> in each sentence, or if we exceed the maximum length
if "past_key_values" in outputs: if unfinished_sequences.max() == 0:
past = outputs.past_key_values break
elif "mems" in outputs:
past = outputs.mems # increase cur_len
cur_len = cur_len + 1
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens) return input_ids
if temperature != 1.0:
scores = scores / temperature def sample(
# Top-p/top-k filtering self,
next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p) input_ids: torch.LongTensor,
# Sample logits_processor: Optional[LogitsProcessorList] = None,
probs = F.softmax(next_token_logscores, dim=-1) logits_warper: Optional[LogitsProcessorList] = None,
next_token = torch.multinomial(probs, num_samples=1).squeeze(1) max_length: Optional[int] = None,
else: pad_token_id: Optional[int] = None,
# Greedy decoding eos_token_id: Optional[int] = None,
next_token = torch.argmax(next_token_logits, dim=-1) **model_kwargs
):
# update generations and finished sentences r"""
Generates sequences for models with a language modeling head using multinomial sampling.
Parameters:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty
:obj:`torch.LongTensor` of shape :obj:`(1,)`.
logits_processor (:obj:`LogitsProcessorList`, `optional`):
An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
:class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
head applied at each generation step.
logits_warper (:obj:`LogitsProcessorList`, `optional`):
An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
:class:`~transformers.LogitsWarper` used to warp the prediction score distribution of the language
modeling head applied before multinomial sampling at each generation step.
max_length (:obj:`int`, `optional`, defaults to 20):
The maximum length of the sequence to be generated.
pad_token_id (:obj:`int`, `optional`):
The id of the `padding` token.
eos_token_id (:obj:`int`, `optional`):
The id of the `end-of-sequence` token.
model_kwargs:
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If
model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.
Return:
:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all
batches finished early due to the :obj:`eos_token_id`.
Examples::
>>> from transformers import (
... AutoTokenizer,
... AutoModelForCausalLM,
... LogitsProcessorList,
... MinLengthLogitsProcessor,
... TopKLogitsWarper,
... TemperatureLogitsWarper,
... )
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
>>> model.config.pad_token_id = model.config.eos_token_id
>>> input_prompt = "Today is a beautiful day, and"
>>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
>>> # instantiate logits processors
>>> logits_processor = LogitsProcessorList([
... MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id),
... ])
>>> # instantiate logits warpers
>>> logits_warper = LogitsProcessorList([
... TopKLogitsWarper(50),
... TemperatureLogitsWarper(0.7),
... ])
>>> outputs = model.sample(input_ids, logits_processor=logits_processor, logits_warper=logits_warper)
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
"""
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
max_length = max_length if max_length is not None else self.config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
# init sequence length tensors
sequence_lengths, unfinished_sequences, cur_len = self._init_sequence_length_for_generation(
input_ids, max_length
)
# auto-regressive generation
while cur_len < max_length:
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# forward pass to get next token
outputs = self(**model_inputs, return_dict=True)
next_token_logits = outputs.logits[:, -1, :]
# pre-process distribution
scores = logits_processor(input_ids, next_token_logits)
scores = logits_warper(input_ids, scores)
# sample
probs = F.softmax(scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
# finished sentences should have their next token be a padding token
if eos_token_id is not None: if eos_token_id is not None:
# pad finished sentences if eos_token_id exist assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."
tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents) next_tokens = next_tokens * unfinished_sequences + (pad_token_id) * (1 - unfinished_sequences)
else:
tokens_to_add = next_token
# add token and increase length by one # add token and increase length by one
input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1) input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
cur_len = cur_len + 1 cur_len = cur_len + 1
# update sequence length
if eos_token_id is not None: if eos_token_id is not None:
eos_in_sents = tokens_to_add == eos_token_id sequence_lengths, unfinished_sequences = self._update_seq_length_for_generation(
# if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length sequence_lengths, unfinished_sequences, cur_len, next_tokens == eos_token_id
is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool() )
sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len)
# unfinished_sents is set to zero if eos in sentence # stop when there is a </s> in each sentence, or if we exceed the maximum length
unfinished_sents.mul_((~eos_in_sents).long()) if unfinished_sequences.max() == 0:
# stop when there is a </s> in each sentence, or if we exceed the maximum length
if unfinished_sents.max() == 0:
break break
# extend attention_mask for new generated input if only decoder # update model kwargs
if self.config.is_encoder_decoder is False: model_kwargs = self._update_model_kwargs_for_generation(
attention_mask = torch.cat( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 )
)
return input_ids return input_ids
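    # sample() follows the same loop as greedy_search(); the differences are the additional
    # `logits_warper` pass over the scores and drawing the next token with `torch.multinomial`
    # instead of taking the `argmax`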
def _generate_beam_search( def beam_search(
self, self,
input_ids, input_ids: torch.LongTensor,
cur_len, beam_scorer: BeamScorer,
max_length, logits_processor: Optional[LogitsProcessorList] = None,
min_length, max_length: Optional[int] = None,
do_sample, pad_token_id: Optional[int] = None,
early_stopping, eos_token_id: Optional[int] = None,
temperature, **model_kwargs
top_k,
top_p,
repetition_penalty,
no_repeat_ngram_size,
bad_words_ids,
pad_token_id,
eos_token_id,
batch_size,
num_return_sequences,
length_penalty,
num_beams,
vocab_size,
attention_mask,
use_cache,
model_kwargs,
): ):
"""Generate sequences for each example with beam search.""" r"""
Generates sequences for models with a language modeling head using beam search decoding.
# generated hypotheses Parameters:
generated_hyps = [
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
for _ in range(batch_size)
]
# scores for each sentence in the beam input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty
:obj:`torch.LongTensor` of shape :obj:`(1,)`.
beam_scorer (:obj:`BeamScorer`):
A derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are
constructed, stored and sorted during generation. For more information, the documentation of
:class:`~transformers.BeamScorer` should be read.
logits_processor (:obj:`LogitsProcessorList`, `optional`):
An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
:class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
head applied at each generation step.
max_length (:obj:`int`, `optional`, defaults to 20):
The maximum length of the sequence to be generated.
pad_token_id (:obj:`int`, `optional`):
The id of the `padding` token.
eos_token_id (:obj:`int`, `optional`):
The id of the `end-of-sequence` token.
model_kwargs:
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If
model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.
# for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times Return:
if do_sample is False: :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
beam_scores[:, 1:] = -1e9 sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,) batches finished early due to the :obj:`eos_token_id`.
Examples::
>>> from transformers import (
... AutoTokenizer,
... AutoModelForSeq2SeqLM,
... LogitsProcessorList,
... MinLengthLogitsProcessor,
... BeamSearchScorer,
... )
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
>>> encoder_input_str = "translate English to German: How old are you?"
>>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
>>> # let's run beam search using 3 beams
>>> num_beams = 3
>>> # define decoder start token ids
>>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
>>> input_ids = input_ids * model.config.decoder_start_token_id
>>> # add encoder_outputs to model keyword arguments
>>> model_kwargs = {
... "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True)
... }
>>> # instantiate beam scorer
>>> beam_scorer = BeamSearchScorer(
... batch_size=1,
... max_length=model.config.max_length,
... num_beams=num_beams,
... device=model.device,
... )
>>> # instantiate logits processors
>>> logits_processor = LogitsProcessorList([
... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
... ])
>>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
"""
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
max_length = max_length if max_length is not None else self.config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
# cache compute states batch_size = len(beam_scorer._beam_hyps)
past = None num_beams = beam_scorer.num_beams
# done sentences batch_beam_size, cur_len = input_ids.shape
done = [False for _ in range(batch_size)]
assert (
num_beams * batch_size == batch_beam_size
), "Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view((batch_size * num_beams,))
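        # all `num_beams` beams of an example start from the same prompt, so only beam 0 keeps a
        # score of 0; the others get -1e9 so that the first top-k step does not select
        # `num_beams` identical continuations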
while cur_len < max_length: while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation( model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
outputs = self(**model_inputs, return_dict=True)
next_token_logits = outputs.logits[:, -1, :]
# model-specific logits adjustment hook (a no-op by default; Bart, *e.g.*, overrides it)
next_token_logits = self.adjust_logits_during_generation(
next_token_logits, cur_len=cur_len, max_length=max_length
) )
outputs = self(**model_inputs, return_dict=True) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs.logits[:, -1, :] # (batch_size * num_beams, vocab_size)
# if model has past, then set the past variable to speed up decoding
if "past_key_values" in outputs:
past = outputs.past_key_values
elif "mems" in outputs:
past = outputs.mems
if self.config.is_encoder_decoder and do_sample is False:
# TODO (PVP) still a bit hacky here - there might be a better solution
next_token_logits = self.adjust_logits_during_generation(
next_token_logits, cur_len=cur_len, max_length=max_length
)
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size) next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
scores = self.postprocess_next_token_scores( next_token_scores = logits_processor(input_ids, next_token_scores)
scores=scores, next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
input_ids=input_ids, # reshape for beam search
no_repeat_ngram_size=no_repeat_ngram_size, vocab_size = next_token_scores.shape[-1]
bad_words_ids=bad_words_ids, next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
cur_len=cur_len,
min_length=min_length, next_token_scores, next_tokens = torch.topk(
max_length=max_length, next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
eos_token_id=eos_token_id,
repetition_penalty=repetition_penalty,
batch_size=batch_size,
num_beams=num_beams,
) )
assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format( next_indices = next_tokens // vocab_size
scores.shape, (batch_size * num_beams, vocab_size) next_tokens = next_tokens % vocab_size
# stateless
beam_outputs = beam_scorer.process(
input_ids,
next_token_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
) )
beam_scores = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
beam_idx = beam_outputs["next_beam_indices"]
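            # `process` stores hypotheses that just ended in `eos_token_id` inside the scorer and
            # returns, for each of the batch_size * num_beams slots, the running score, the token
            # to append, and the index of the beam to continue from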
if do_sample: input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size) cur_len = cur_len + 1
# Temperature
if temperature != 1.0:
_scores = _scores / temperature
# Top-p/top-k filtering
_scores = top_k_top_p_filtering(
_scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
) # (batch_size * num_beams, vocab_size)
# re-organize to group the beam together to sample from all beam_idxs
_scores = _scores.contiguous().view(
batch_size, num_beams * vocab_size
) # (batch_size, num_beams * vocab_size)
# Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
probs = F.softmax(_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) # (batch_size, num_beams * 2)
# Compute next scores
next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2)
# sort the sampled vector to make sure that the first num_beams samples are the best
next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
next_tokens = torch.gather(next_tokens, -1, next_scores_indices) # (batch_size, num_beams * 2)
else:
next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
# re-organize to group the beam together (we are keeping top hypothesis across beams)
next_scores = next_scores.view(
batch_size, num_beams * vocab_size
) # (batch_size, num_beams * vocab_size)
next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)
# next batch beam content
next_batch_beam = []
# for each sentence
for batch_idx in range(batch_size):
# if we are done with this sentence, add a pad token
if done[batch_idx]:
assert (
len(generated_hyps[batch_idx]) >= num_beams
), "Batch can only be done if at least {} beams have been generated".format(num_beams)
assert (
eos_token_id is not None and pad_token_id is not None
), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
continue
# next sentence beam content, this will get added to next_batch_beam
next_sent_beam = []
# next tokens for this sentence
for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
zip(next_tokens[batch_idx], next_scores[batch_idx])
):
# get beam and token IDs
beam_id = beam_token_id // vocab_size
token_id = beam_token_id % vocab_size
effective_beam_id = batch_idx * num_beams + beam_id
# add to generated hypotheses if end of sentence
if (eos_token_id is not None) and (token_id.item() == eos_token_id):
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
if is_beam_token_worse_than_top_num_beams:
continue
generated_hyps[batch_idx].add(
input_ids[effective_beam_id].clone(),
beam_token_score.item(),
)
else:
# add next predicted token since it is not eos_token
next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
# once the beam for next step is full, don't add more tokens to it.
if len(next_sent_beam) == num_beams:
break
# Check if we are done so that we can save a pad step if all(done)
done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
next_scores[batch_idx].max().item(), cur_len
)
# update next beam content model_kwargs = self._update_model_kwargs_for_generation(
assert len(next_sent_beam) == num_beams, "Beam should always be full" outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
next_batch_beam.extend(next_sent_beam) )
assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step" if model_kwargs["past"] is not None:
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
# stop when we are done with each sentence if beam_scorer.is_done:
if all(done):
break break
# sanity check / prepare next batch decoded = beam_scorer.finalize(
assert len(next_batch_beam) == batch_size * num_beams input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
beam_scores = beam_scores.new([x[0] for x in next_batch_beam]) )
beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
beam_idx = input_ids.new([x[2] for x in next_batch_beam])
# re-order batch and update current length return decoded
input_ids = input_ids[beam_idx, :]
input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
cur_len = cur_len + 1
# re-order internal states def beam_sample(
if past is not None: self,
past = self._reorder_cache(past, beam_idx) input_ids: torch.LongTensor,
beam_scorer: BeamScorer,
logits_processor: Optional[LogitsProcessorList] = None,
logits_warper: Optional[LogitsProcessorList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
**model_kwargs
):
r"""
Generates sequences for models with a language modeling head using beam search with multinomial sampling.
# extend attention_mask for new generated input if only decoder Parameters:
if self.config.is_encoder_decoder is False:
attention_mask = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
# finalize all open beam hypotheses and add to generated hypotheses input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
for batch_idx in range(batch_size): The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty
if done[batch_idx]: :obj:`torch.LongTensor` of shape :obj:`(1,)`.
continue beam_scorer (:obj:`BeamScorer`):
A derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are
# test that beam scores match previously calculated scores if not eos and batch_idx not done constructed, stored and sorted during generation. For more information, the documentation of
if eos_token_id is not None and all( :class:`~transformers.BeamScorer` should be read.
(token_id % vocab_size).item() != eos_token_id for token_id in next_tokens[batch_idx] logits_processor (:obj:`LogitsProcessorList`, `optional`):
): An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
assert torch.all( :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx] head applied at each generation step.
), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format( logits_warper (:obj:`LogitsProcessorList`, `optional`):
next_scores[:, :num_beams][batch_idx], An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
beam_scores.view(batch_size, num_beams)[batch_idx], :class:`~transformers.LogitsWarper` used to warp the prediction score distribution of the language
) modeling head applied before multinomial sampling at each generation step.
max_length (:obj:`int`, `optional`, defaults to 20):
The maximum length of the sequence to be generated.
pad_token_id (:obj:`int`, `optional`):
The id of the `padding` token.
eos_token_id (:obj:`int`, `optional`):
The id of the `end-of-sequence` token.
model_kwargs:
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If
model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.
# need to add best num_beams hypotheses to generated hyps Return:
for beam_id in range(num_beams): :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
effective_beam_id = batch_idx * num_beams + beam_id sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all
final_score = beam_scores[effective_beam_id].item() batches finished early due to the :obj:`eos_token_id`.
final_tokens = input_ids[effective_beam_id]
generated_hyps[batch_idx].add(final_tokens, final_score)
# depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
# select the best hypotheses
sent_lengths = input_ids.new(output_batch_size)
best = []
# retrieve best hypotheses
for i, hypotheses in enumerate(generated_hyps):
sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
for j in range(output_num_return_sequences_per_batch):
effective_batch_idx = output_num_return_sequences_per_batch * i + j
best_hyp = sorted_hyps.pop()[1]
sent_lengths[effective_batch_idx] = len(best_hyp)
best.append(best_hyp)
# prepare for adding eos
sent_max_len = min(sent_lengths.max().item() + 1, max_length)
decoded = input_ids.new(output_batch_size, sent_max_len)
# shorter batches are padded if needed
if sent_lengths.min().item() != sent_lengths.max().item():
assert pad_token_id is not None, "`pad_token_id` has to be defined"
decoded.fill_(pad_token_id)
# fill with hypotheses and eos_token_id if the latter fits in
for i, hypo in enumerate(best):
decoded[i, : sent_lengths[i]] = hypo
if sent_lengths[i] < max_length:
decoded[i, sent_lengths[i]] = eos_token_id
return decoded Examples::
@staticmethod >>> from transformers import (
def _reorder_cache(past: Tuple, beam_idx: Tensor) -> Tuple[Tensor]: ... AutoTokenizer,
return tuple(layer_past.index_select(1, beam_idx) for layer_past in past) ... AutoModelForSeq2SeqLM,
... LogitsProcessorList,
... MinLengthLogitsProcessor,
... TopKLogitsWarper,
... TemperatureLogitsWarper,
... BeamSearchScorer,
... )
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
>>> encoder_input_str = "translate English to German: How old are you?"
>>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
>>> # let's run beam search using 3 beams
>>> num_beams = 3
>>> # define decoder start token ids
>>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
>>> input_ids = input_ids * model.config.decoder_start_token_id
>>> # add encoder_outputs to model keyword arguments
>>> model_kwargs = {
... "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True)
... }
>>> # instantiate beam scorer
>>> beam_scorer = BeamSearchScorer(
... batch_size=1,
... max_length=model.config.max_length,
... num_beams=num_beams,
... device=model.device,
... )
>>> # instantiate logits processors
>>> logits_processor = LogitsProcessorList([
... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)
... ])
>>> # instantiate logits warpers
>>> logits_warper = LogitsProcessorList([
... TopKLogitsWarper(50),
... TemperatureLogitsWarper(0.7),
... ])
>>> outputs = model.beam_sample(
... input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs
... )
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
"""
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
max_length = max_length if max_length is not None else self.config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int) -> None: batch_size = len(beam_scorer._beam_hyps)
"""Copied from fairseq for no_repeat_ngram in beam_search""" num_beams = beam_scorer.num_beams
if cur_len + 1 < no_repeat_ngram_size:
# return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet batch_beam_size, cur_len = input_ids.shape
return [[] for _ in range(num_hypos)]
generated_ngrams = [{} for _ in range(num_hypos)] beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
for idx in range(num_hypos): beam_scores = beam_scores.view((batch_size * num_beams,))
gen_tokens = prev_input_ids[idx].tolist()
generated_ngram = generated_ngrams[idx] while cur_len < max_length:
for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]): model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
prev_ngram_tuple = tuple(ngram[:-1])
generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
def _get_generated_ngrams(hypo_idx):
# Before decoding the next token, prevent decoding of ngrams that have already appeared
start_idx = cur_len + 1 - no_repeat_ngram_size
ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
return generated_ngrams[hypo_idx].get(ngram_idx, [])
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
return banned_tokens
def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]) -> Iterable[int]:
banned_tokens = []
def _tokens_match(prev_tokens, tokens):
if len(tokens) == 0:
# if bad word tokens is just one token always ban it
return True
if len(tokens) > len(prev_tokens):
# if bad word tokens are longer than prev tokens they can't be equal
return False
if prev_tokens[-len(tokens) :] == tokens:
# if tokens match
return True
else:
return False
for prev_input_ids_slice in prev_input_ids: outputs = self(**model_inputs, return_dict=True)
banned_tokens_slice = [] next_token_logits = outputs.logits[:, -1, :]
for banned_token_seq in bad_words_ids: # adjust token scores (a no-op by default)
assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format( next_token_logits = self.adjust_logits_during_generation(
bad_words_ids next_token_logits, cur_len=cur_len, max_length=max_length
) )
if _tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False: next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
# if tokens do not match continue
continue
banned_tokens_slice.append(banned_token_seq[-1]) next_token_scores = logits_processor(input_ids, next_token_scores)
next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
next_token_scores = logits_warper(input_ids, next_token_scores)
banned_tokens.append(banned_tokens_slice) # reshape for beam search
vocab_size = next_token_scores.shape[-1]
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
return banned_tokens probs = F.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
next_tokens = torch.gather(next_tokens, -1, _indices)
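            # 2 * num_beams candidates are sampled so that enough live continuations remain even
            # if up to num_beams of them end in `eos_token_id`; sorting by score keeps the
            # candidates in the same best-first order that `topk` produces in beam_search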
def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: List[List[int]]) -> None: next_indices = next_tokens // vocab_size
""" next_tokens = next_tokens % vocab_size
Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be a list
of list of banned tokens to ban in the format [[batch index, vocabulary position],...
Args: # stateless
scores: logits distribution of shape (batch size, vocabulary size) beam_outputs = beam_scorer.process(
banned_tokens: list of list of tokens to ban of length (batch_size) input_ids,
""" next_token_scores,
banned_mask_list = [] next_tokens,
for idx, batch_banned_tokens in enumerate(banned_tokens): next_indices,
for token in batch_banned_tokens: pad_token_id=pad_token_id,
banned_mask_list.append([idx, token]) eos_token_id=eos_token_id,
if not banned_mask_list: )
return beam_scores = beam_outputs["next_beam_scores"]
banned_mask = torch.LongTensor(banned_mask_list) beam_next_tokens = beam_outputs["next_beam_tokens"]
indices = torch.ones(len(banned_mask)) beam_idx = beam_outputs["next_beam_indices"]
# A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates:
# [ 0 1 1 ] input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
# [ 0 0 0 ] cur_len = cur_len + 1
# [ 1 0 0 ]
model_kwargs = self._update_model_kwargs_for_generation(
banned_mask = torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool() outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
scores.masked_fill_(banned_mask, -float("inf")) )
if model_kwargs["past"] is not None:
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
if beam_scorer.is_done:
break
decoded = beam_scorer.finalize(
input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
)
return decoded
...@@ -980,73 +1205,11 @@ def top_k_top_p_filtering(
def top_k_top_p_filtering(
logits: Tensor,
top_k: int = 0,
top_p: float = 1.0,
filter_value: float = -float("Inf"),
min_tokens_to_keep: int = 1,
) -> Tensor:
"""
Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
if top_k > 0:
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs > top_p
if min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = filter_value
return logits
def top_k_top_p_filtering(
logits: torch.FloatTensor,
top_k: int = 0,
top_p: float = 1.0,
filter_value: float = -float("Inf"),
min_tokens_to_keep: int = 1,
) -> torch.FloatTensor:
"""
Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
if top_k > 0:
logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
None, logits
)
if 0 <= top_p <= 1.0:
logits = TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=min_tokens_to_keep)(None, logits)
return logits
class BeamHypotheses(object):
def __init__(self, num_beams, max_length, length_penalty, early_stopping):
"""
Initialize n-best list of hypotheses.
"""
self.max_length = max_length - 1  # ignoring bos_token
self.length_penalty = length_penalty
self.early_stopping = early_stopping
self.num_beams = num_beams
self.beams = []
self.worst_score = 1e9
def __len__(self):
"""
Number of hypotheses in the list.
"""
return len(self.beams)
def add(self, hyp, sum_logprobs):
"""
Add a new hypothesis to the list.
"""
score = sum_logprobs / len(hyp) ** self.length_penalty
if len(self) < self.num_beams or score > self.worst_score:
self.beams.append((score, hyp))
if len(self) > self.num_beams:
sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
del self.beams[sorted_scores[0][1]]
self.worst_score = sorted_scores[1][0]
else:
self.worst_score = min(score, self.worst_score)
def is_done(self, best_sum_logprobs, cur_len):
"""
If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst
one in the heap, then we are done with this sentence.
"""
if len(self) < self.num_beams:
return False
elif self.early_stopping:
return True
else:
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
ret = self.worst_score >= cur_score
return ret
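The refactored `top_k_top_p_filtering` above is now a thin wrapper around the stateless warpers, which can also be applied directly to a batch of next-token logits. A small self-contained sketch (shapes and values are invented for illustration):
# Hedged example: dummy logits, real TopK/TopP warpers from the library.
import torch
from transformers import TopKLogitsWarper, TopPLogitsWarper

logits = torch.randn(2, 50)                          # (batch_size, vocab_size) dummy scores
logits = TopKLogitsWarper(top_k=10)(None, logits)    # keep only the 10 highest-scoring tokens
logits = TopPLogitsWarper(top_p=0.9)(None, logits)   # then nucleus (top-p) filtering
probs = torch.softmax(logits, dim=-1)                # filtered positions are -inf -> probability 0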
...@@ -1084,7 +1084,7 @@ class BartForConditionalGeneration(PretrainedBartModel): ...@@ -1084,7 +1084,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
) )
def prepare_inputs_for_generation( def prepare_inputs_for_generation(
self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
): ):
return { return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed "input_ids": None, # encoder_outputs is defined. input_ids not needed
......
...@@ -514,12 +514,12 @@ class CTRLLMHeadModel(CTRLPreTrainedModel): ...@@ -514,12 +514,12 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
def get_output_embeddings(self): def get_output_embeddings(self):
return self.lm_head return self.lm_head
def prepare_inputs_for_generation(self, input_ids, past, **kwargs): def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, **kwargs):
# only last token for inputs_ids if past is defined in kwargs # only last token for inputs_ids if past is defined in kwargs
if past: if past:
input_ids = input_ids[:, -1].unsqueeze(-1) input_ids = input_ids[:, -1].unsqueeze(-1)
return {"input_ids": input_ids, "past_key_values": past, "use_cache": kwargs["use_cache"]} return {"input_ids": input_ids, "past_key_values": past, "use_cache": use_cache}
@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
@add_code_sample_docstrings( @add_code_sample_docstrings(
......
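The same pattern repeats across the model files in this commit: every `prepare_inputs_for_generation` argument other than the ids gets a default, because the refactored `generate()` calls this hook with only the model kwargs that were actually passed. A minimal sketch of the pattern for a hypothetical decoder-only model (not taken verbatim from any file here):
# Hedged sketch of the pattern; the method body mirrors the CTRL change above.
def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, **kwargs):
    # with a cache, only the last token has to be passed to the model
    if past is not None:
        input_ids = input_ids[:, -1].unsqueeze(-1)
    return {"input_ids": input_ids, "past_key_values": past, "use_cache": use_cache}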
...@@ -431,7 +431,7 @@ class EncoderDecoderModel(PreTrainedModel): ...@@ -431,7 +431,7 @@ class EncoderDecoderModel(PreTrainedModel):
encoder_attentions=encoder_outputs.attentions, encoder_attentions=encoder_outputs.attentions,
) )
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, encoder_outputs, **kwargs): def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, encoder_outputs=None, **kwargs):
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids) decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids)
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
input_dict = { input_dict = {
......
...@@ -1107,7 +1107,7 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel): ...@@ -1107,7 +1107,7 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
) )
def prepare_inputs_for_generation( def prepare_inputs_for_generation(
self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
): ):
return { return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed "input_ids": None, # encoder_outputs is defined. input_ids not needed
......
...@@ -1800,7 +1800,7 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel): ...@@ -1800,7 +1800,7 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
return loss return loss
def prepare_inputs_for_generation( def prepare_inputs_for_generation(
self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
): ):
assert encoder_outputs is not None, "`encoder_outputs` have to be passed for generation." assert encoder_outputs is not None, "`encoder_outputs` have to be passed for generation."
......
...@@ -22,6 +22,7 @@ import torch ...@@ -22,6 +22,7 @@ import torch
from .configuration_rag import RagConfig from .configuration_rag import RagConfig
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .file_utils import add_start_docstrings_to_model_forward, replace_return_docstrings from .file_utils import add_start_docstrings_to_model_forward, replace_return_docstrings
from .generation_beam_search import BeamSearchScorer
from .modeling_outputs import ModelOutput from .modeling_outputs import ModelOutput
from .modeling_utils import PreTrainedModel from .modeling_utils import PreTrainedModel
from .retrieval_rag import RagRetriever from .retrieval_rag import RagRetriever
...@@ -825,7 +826,7 @@ class RagSequenceForGeneration(RagPreTrainedModel): ...@@ -825,7 +826,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
num_return_sequences=None, # defaults to 1 num_return_sequences=None, # defaults to 1
num_beams=None, # defaults to 1 num_beams=None, # defaults to 1
n_docs=None, n_docs=None,
**kwargs **model_kwargs
): ):
""" """
Implements RAG sequence "thorough" decoding. Read the :meth:`~transformers.PreTrainedModel.generate`` Implements RAG sequence "thorough" decoding. Read the :meth:`~transformers.PreTrainedModel.generate``
...@@ -872,7 +873,6 @@ class RagSequenceForGeneration(RagPreTrainedModel): ...@@ -872,7 +873,6 @@ class RagSequenceForGeneration(RagPreTrainedModel):
) )
num_beams = num_beams if num_beams is not None else self.config.num_beams num_beams = num_beams if num_beams is not None else self.config.num_beams
# TODO(patrick) - clean up generate here
if self.retriever is not None and context_input_ids is None: if self.retriever is not None and context_input_ids is None:
question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0] question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
context_input_ids = self.retriever( context_input_ids = self.retriever(
...@@ -887,10 +887,9 @@ class RagSequenceForGeneration(RagPreTrainedModel): ...@@ -887,10 +887,9 @@ class RagSequenceForGeneration(RagPreTrainedModel):
context_input_ids = context_input_ids.to(input_ids) context_input_ids = context_input_ids.to(input_ids)
hypos = [] hypos = []
kwargs["num_beams"] = num_beams model_kwargs["num_beams"] = num_beams
kwargs["num_return_sequences"] = num_beams model_kwargs["num_return_sequences"] = num_beams
kwargs["attention_mask"] = None model_kwargs["attention_mask"] = None
kwargs["n_docs"] = n_docs
for index in range(len(input_ids)): for index in range(len(input_ids)):
# first, generate beams from documents: # first, generate beams from documents:
...@@ -898,7 +897,7 @@ class RagSequenceForGeneration(RagPreTrainedModel): ...@@ -898,7 +897,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
output_sequences = self.generator.generate( output_sequences = self.generator.generate(
generator_input_ids, generator_input_ids,
**kwargs, **model_kwargs,
) # n_docs * n_beam, tgt_len ) # n_docs * n_beam, tgt_len
if do_deduplication: if do_deduplication:
# do_deduplication, max_output_len # do_deduplication, max_output_len
...@@ -1018,7 +1017,15 @@ class RagTokenForGeneration(RagPreTrainedModel): ...@@ -1018,7 +1017,15 @@ class RagTokenForGeneration(RagPreTrainedModel):
return self.rag.generator.adjust_logits_during_generation(logits, cur_len=cur_len, max_length=max_length) return self.rag.generator.adjust_logits_during_generation(logits, cur_len=cur_len, max_length=max_length)
def prepare_inputs_for_generation( def prepare_inputs_for_generation(
self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, doc_scores, n_docs=None, **kwargs self,
decoder_input_ids,
past=None,
attention_mask=None,
use_cache=None,
encoder_outputs=None,
doc_scores=None,
n_docs=None,
**kwargs
): ):
return { return {
"input_ids": None, "input_ids": None,
...@@ -1222,11 +1229,12 @@ class RagTokenForGeneration(RagPreTrainedModel): ...@@ -1222,11 +1229,12 @@ class RagTokenForGeneration(RagPreTrainedModel):
eos_token_id=None, eos_token_id=None,
length_penalty=None, length_penalty=None,
no_repeat_ngram_size=None, no_repeat_ngram_size=None,
repetition_penalty=None,
bad_words_ids=None, bad_words_ids=None,
num_return_sequences=None, num_return_sequences=None,
decoder_start_token_id=None, decoder_start_token_id=None,
n_docs=None, n_docs=None,
**kwargs **model_kwargs
): ):
""" """
Implements RAG token decoding. Implements RAG token decoding.
...@@ -1307,22 +1315,15 @@ class RagTokenForGeneration(RagPreTrainedModel): ...@@ -1307,22 +1315,15 @@ class RagTokenForGeneration(RagPreTrainedModel):
""" """
# set default parameters # set default parameters
n_docs = n_docs if n_docs is not None else self.config.n_docs n_docs = n_docs if n_docs is not None else self.config.n_docs
max_length = max_length if max_length is not None else self.config.max_length
min_length = min_length if min_length is not None else self.config.min_length
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
use_cache = use_cache if use_cache is not None else self.config.use_cache
num_beams = num_beams if num_beams is not None else self.config.num_beams num_beams = num_beams if num_beams is not None else self.config.num_beams
bos_token_id = bos_token_id if bos_token_id is not None else self.config.generator.bos_token_id max_length = max_length if max_length is not None else self.config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.config.generator.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.generator.eos_token_id
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
no_repeat_ngram_size = (
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
)
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
num_return_sequences = ( num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
) )
bos_token_id = bos_token_id if bos_token_id is not None else self.config.generator.bos_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.generator.eos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.generator.pad_token_id
use_cache = use_cache if use_cache is not None else self.config.use_cache
decoder_start_token_id = ( decoder_start_token_id = (
decoder_start_token_id decoder_start_token_id
if decoder_start_token_id is not None if decoder_start_token_id is not None
...@@ -1365,7 +1366,7 @@ class RagTokenForGeneration(RagPreTrainedModel): ...@@ -1365,7 +1366,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
encoder = self.rag.generator.get_encoder() encoder = self.rag.generator.get_encoder()
encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask, return_dict=True) encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask, return_dict=True)
decoder_input_ids = torch.full( input_ids = torch.full(
(batch_size * num_beams, 1), (batch_size * num_beams, 1),
decoder_start_token_id, decoder_start_token_id,
dtype=torch.long, dtype=torch.long,
...@@ -1388,64 +1389,57 @@ class RagTokenForGeneration(RagPreTrainedModel): ...@@ -1388,64 +1389,57 @@ class RagTokenForGeneration(RagPreTrainedModel):
doc_scores = doc_scores.repeat_interleave(num_beams, dim=0) doc_scores = doc_scores.repeat_interleave(num_beams, dim=0)
# define start_len & additional parameters # define start_len & additional parameters
cur_len = 1 model_kwargs["doc_scores"] = doc_scores
vocab_size = self.config.generator.vocab_size model_kwargs["encoder_outputs"] = encoder_outputs
kwargs["doc_scores"] = doc_scores model_kwargs["attention_mask"] = context_attention_mask
kwargs["encoder_outputs"] = encoder_outputs model_kwargs["n_docs"] = n_docs
kwargs["n_docs"] = n_docs
pre_processor = self._get_logits_processor(
# not needed. TODO(PVP): change after generate refactor repetition_penalty=repetition_penalty,
do_sample = False no_repeat_ngram_size=no_repeat_ngram_size,
temperature = self.config.temperature bad_words_ids=bad_words_ids,
top_k = self.config.top_k min_length=min_length,
top_p = self.config.top_p eos_token_id=eos_token_id,
repetition_penalty = self.config.repetition_penalty )
if num_beams > 1: if num_beams == 1:
return self._generate_beam_search( if num_return_sequences > 1:
decoder_input_ids, raise ValueError(
cur_len=cur_len, f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
)
return self.greedy_search(
input_ids,
pre_processor=pre_processor,
max_length=max_length, max_length=max_length,
min_length=min_length,
do_sample=do_sample,
early_stopping=early_stopping,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
pad_token_id=pad_token_id, pad_token_id=pad_token_id,
eos_token_id=eos_token_id, eos_token_id=eos_token_id,
**model_kwargs,
)
elif num_beams > 1:
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
if num_return_sequences > num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
beam_scorer = BeamSearchScorer(
batch_size=batch_size, batch_size=batch_size,
num_return_sequences=num_return_sequences, max_length=max_length,
length_penalty=length_penalty,
num_beams=num_beams, num_beams=num_beams,
vocab_size=vocab_size, device=self.device,
attention_mask=context_attention_mask, length_penalty=length_penalty,
use_cache=use_cache, do_early_stopping=early_stopping,
model_kwargs=kwargs, num_beam_hyps_to_keep=num_return_sequences,
) )
else: return self.beam_search(
return self._generate_no_beam_search( input_ids,
decoder_input_ids, beam_scorer,
cur_len=cur_len, pre_processor=pre_processor,
max_length=max_length, max_length=max_length,
min_length=min_length,
do_sample=do_sample,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
pad_token_id=pad_token_id, pad_token_id=pad_token_id,
eos_token_id=eos_token_id, eos_token_id=eos_token_id,
batch_size=batch_size, **model_kwargs,
attention_mask=context_attention_mask,
use_cache=use_cache,
model_kwargs=kwargs,
) )
else:
raise ValueError(f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {num_beams}")
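The dispatch above is the heart of the refactor: the `generate()`-level code builds the logits processors once and hands them, together with a `BeamSearchScorer` when beam search is requested, to the small `greedy_search` / `beam_search` methods. A rough, self-contained sketch of the kind of object `_get_logits_processor` assembles (the concrete values below are invented for illustration):
# Hedged example: manually assembling a LogitsProcessorList with made-up settings.
import torch
from transformers import (
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
)

processors = LogitsProcessorList(
    [
        MinLengthLogitsProcessor(min_length=5, eos_token_id=2),
        NoRepeatNGramLogitsProcessor(3),
        RepetitionPenaltyLogitsProcessor(penalty=1.2),
    ]
)
input_ids = torch.tensor([[0, 4, 7]])    # tokens generated so far
scores = torch.randn(1, 50)              # raw next-token logits
scores = processors(input_ids, scores)   # same shape, with all constraints applied in order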
def get_input_embeddings(self): def get_input_embeddings(self):
return self.rag.generator.get_input_embeddings() return self.rag.generator.get_input_embeddings()
......
...@@ -638,7 +638,6 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -638,7 +638,6 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
rotations_shape = (self.num_attention_heads, vectors.shape[-1], num_hashes, rotation_size // 2) rotations_shape = (self.num_attention_heads, vectors.shape[-1], num_hashes, rotation_size // 2)
# create a random self.attention_head_size x num_hashes x num_buckets/2 # create a random self.attention_head_size x num_hashes x num_buckets/2
random_rotations = torch.randn(rotations_shape, device=vectors.device, dtype=vectors.dtype) random_rotations = torch.randn(rotations_shape, device=vectors.device, dtype=vectors.dtype)
# Output dim: Batch_Size x Num_Attn_Heads x Num_Hashes x Seq_Len x Num_Buckets/2 # Output dim: Batch_Size x Num_Attn_Heads x Num_Hashes x Seq_Len x Num_Buckets/2
rotated_vectors = torch.einsum("bmtd,mdhr->bmhtr", vectors, random_rotations) rotated_vectors = torch.einsum("bmtd,mdhr->bmhtr", vectors, random_rotations)
...@@ -1471,7 +1470,9 @@ class ReformerLayer(nn.Module): ...@@ -1471,7 +1470,9 @@ class ReformerLayer(nn.Module):
# every forward pass we sample a different seed # every forward pass we sample a different seed
# for dropout and save for forward fn in backward pass # for dropout and save for forward fn in backward pass
# to have correct dropout # to have correct dropout
self._init_attention_seed() if self.training:
self._init_attention_seed()
attn_outputs = self.attention( attn_outputs = self.attention(
hidden_states=hidden_states, hidden_states=hidden_states,
head_mask=head_mask, head_mask=head_mask,
...@@ -1494,7 +1495,8 @@ class ReformerLayer(nn.Module): ...@@ -1494,7 +1495,8 @@ class ReformerLayer(nn.Module):
# every forward pass we sample a different seed # every forward pass we sample a different seed
# for dropout and save seed for forward fn in backward # for dropout and save seed for forward fn in backward
# to have correct dropout # to have correct dropout
self._init_feed_forward_seed() if self.training:
self._init_feed_forward_seed()
# Y_2 = X_2 + g(Y_1) # Y_2 = X_2 + g(Y_1)
hidden_states = hidden_states + self.feed_forward(attn_output) hidden_states = hidden_states + self.feed_forward(attn_output)
...@@ -2263,7 +2265,7 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel): ...@@ -2263,7 +2265,7 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
attentions=reformer_outputs.attentions, attentions=reformer_outputs.attentions,
) )
def prepare_inputs_for_generation(self, input_ids, past, **kwargs): def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, num_hashes=None, **kwargs):
# only last token for inputs_ids if past is defined in kwargs # only last token for inputs_ids if past is defined in kwargs
if past is not None: if past is not None:
input_ids = input_ids[:, -1:] input_ids = input_ids[:, -1:]
...@@ -2271,12 +2273,10 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel): ...@@ -2271,12 +2273,10 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
inputs_dict = { inputs_dict = {
"input_ids": input_ids, "input_ids": input_ids,
"past_buckets_states": past, "past_buckets_states": past,
"use_cache": kwargs["use_cache"], "use_cache": use_cache,
"num_hashes": num_hashes,
} }
if "num_hashes" in kwargs:
inputs_dict["num_hashes"] = kwargs["num_hashes"]
return inputs_dict return inputs_dict
def _reorder_cache(self, past, beam_idx): def _reorder_cache(self, past, beam_idx):
......
...@@ -1232,7 +1232,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel): ...@@ -1232,7 +1232,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
encoder_attentions=encoder_outputs.attentions, encoder_attentions=encoder_outputs.attentions,
) )
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs): def prepare_inputs_for_generation(
self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
):
# cut decoder_input_ids if past is used # cut decoder_input_ids if past is used
if past is not None: if past is not None:
......
...@@ -1091,7 +1091,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): ...@@ -1091,7 +1091,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
else: else:
return self.crit.out_layers[-1] return self.crit.out_layers[-1]
def prepare_inputs_for_generation(self, input_ids, past, **model_kwargs): def prepare_inputs_for_generation(self, input_ids, past=None, **model_kwargs):
inputs = {} inputs = {}
# if past is defined in model kwargs then use it for faster decoding # if past is defined in model kwargs then use it for faster decoding
......
...@@ -1300,7 +1300,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): ...@@ -1300,7 +1300,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
def get_output_embeddings(self): def get_output_embeddings(self):
return self.lm_loss return self.lm_loss
def prepare_inputs_for_generation(self, input_ids, past, **kwargs): def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, **kwargs):
# Add dummy token at the end (no attention on this one) # Add dummy token at the end (no attention on this one)
effective_batch_size = input_ids.shape[0] effective_batch_size = input_ids.shape[0]
...@@ -1333,7 +1333,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): ...@@ -1333,7 +1333,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
"input_ids": input_ids, "input_ids": input_ids,
"perm_mask": perm_mask, "perm_mask": perm_mask,
"target_mapping": target_mapping, "target_mapping": target_mapping,
"use_cache": kwargs["use_cache"], "use_cache": use_cache,
} }
# if past is defined in model kwargs then use it for faster decoding # if past is defined in model kwargs then use it for faster decoding
......
...@@ -88,8 +88,8 @@ def is_pipeline_test(test_case): ...@@ -88,8 +88,8 @@ def is_pipeline_test(test_case):
""" """
Decorator marking a test as a pipeline test. Decorator marking a test as a pipeline test.
Pipeline tests are skipped by default and we can run only them by setting RUN_PIPELINE_TEST environment variable to
a truthy value and selecting the is_pipeline_test pytest mark.
Pipeline tests are skipped by default and we can run only them by setting RUN_PIPELINE_TESTS environment variable
to a truthy value and selecting the is_pipeline_test pytest mark.
""" """
if not _run_pipeline_tests: if not _run_pipeline_tests:
......
...@@ -104,6 +104,66 @@ class TextDatasetForNextSentencePrediction: ...@@ -104,6 +104,66 @@ class TextDatasetForNextSentencePrediction:
requires_pytorch(self) requires_pytorch(self)
class BeamScorer:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class BeamSearchScorer:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class LogitsProcessor:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class LogitsProcessorList:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class LogitsWarper:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class MinLengthLogitsProcessor:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class NoBadWordsLogitsProcessor:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class NoRepeatNGramLogitsProcessor:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class RepetitionPenaltyLogitsProcessor:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class TemperatureLogitsWarper:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class TopKLogitsWarper:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class TopPLogitsWarper:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
def top_k_top_p_filtering(*args, **kwargs): def top_k_top_p_filtering(*args, **kwargs):
requires_pytorch(top_k_top_p_filtering) requires_pytorch(top_k_top_p_filtering)
......
# coding=utf-8
# Copyright 2020 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, torch_device
from .test_modeling_common import floats_tensor, ids_tensor
if is_torch_available():
import torch
from transformers.generation_beam_search import BeamHypotheses, BeamSearchScorer
class BeamSearchTester:
def __init__(
self,
parent,
batch_size=3,
sequence_length=10,
vocab_size=99,
pad_token_id=0,
max_length=20,
num_beams=4,
length_penalty=2.0,
do_early_stopping=True,
num_beam_hyps_to_keep=2,
):
self.parent = parent
self.batch_size = batch_size
self.sequence_length = sequence_length
self.vocab_size = vocab_size
self.pad_token_id = pad_token_id
self.max_length = max_length
self.num_beams = num_beams
self.length_penalty = length_penalty
self.do_early_stopping = do_early_stopping
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
# cannot be randomly generated
self.eos_token_id = vocab_size + 1
def prepare_beam_scorer(self, **kwargs):
return BeamSearchScorer(
batch_size=kwargs.get("batch_size", self.batch_size),
max_length=kwargs.get("max_length", self.max_length),
num_beams=kwargs.get("num_beams", self.num_beams),
device=torch_device,
length_penalty=kwargs.get("length_penalty", self.length_penalty),
do_early_stopping=kwargs.get("do_early_stopping", self.do_early_stopping),
num_beam_hyps_to_keep=kwargs.get("num_beam_hyps_to_keep", self.num_beam_hyps_to_keep),
)
def prepare_inputs(self):
input_ids = ids_tensor((self.batch_size * self.num_beams, self.sequence_length), self.vocab_size)
next_tokens = ids_tensor((self.batch_size, 2 * self.num_beams), self.vocab_size).to(torch_device)
next_indices = ids_tensor((self.batch_size, 2 * self.num_beams), self.num_beams).to(torch_device)
next_scores, _ = (-floats_tensor((self.batch_size, 2 * self.num_beams)).to(torch_device)).sort(descending=True)
return (input_ids, next_tokens, next_indices, next_scores)
def check_beam_hypotheses(self, input_ids, *args):
# check that correct number of beam hypotheses is set in beam scorer
beam_scorer = self.prepare_beam_scorer(do_early_stopping=True)
beam_hyp = beam_scorer._beam_hyps[0]
self.parent.assertEqual(len(beam_scorer._beam_hyps), self.batch_size)
# check correct type
self.parent.assertTrue(isinstance(beam_hyp, BeamHypotheses))
# check that num_beams is correctly set
self.parent.assertEqual(beam_hyp.num_beams, self.num_beams)
# check with early stopping activated
for beam_idx in range(self.num_beams):
beam_hyp.add(input_ids[beam_idx], -10.0)
# if early stopping True -> score does not matter
self.parent.assertTrue(beam_hyp.is_done(-10.0, 5))
# re-init
beam_scorer = self.prepare_beam_scorer(do_early_stopping=False)
beam_hyp = beam_scorer._beam_hyps[0]
# add `num_beams + 1` beams to change `worst_score`
for beam_idx in range(self.num_beams + 1):
beam_hyp.add(input_ids[beam_idx], -10.0 + float(beam_idx))
# -10.0 is removed => -9.0 is worst score
self.parent.assertAlmostEqual(beam_hyp.worst_score, -9.0 / (self.sequence_length ** beam_hyp.length_penalty))
# -5.0 is better than worst score => should not be finished
self.parent.assertFalse(beam_hyp.is_done(-5.0, self.sequence_length))
# -20.0 is worse than worst score => should be finished
self.parent.assertTrue(beam_hyp.is_done(-20.0, self.sequence_length))
def check_beam_scorer_update(self, input_ids, next_tokens, next_indices, next_scores):
# check too many eos tokens
beam_scorer = self.prepare_beam_scorer()
tokens = next_tokens.clone()
tokens[0, :] = self.eos_token_id
with self.parent.assertRaises(ValueError):
beam_scorer.process(input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id)
# check all batches are done
beam_scorer = self.prepare_beam_scorer()
tokens = next_tokens.clone()
tokens[:, : self.num_beams] = self.eos_token_id
beam_scorer.process(input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id)
# beam scorer should be done
self.parent.assertTrue(beam_scorer.is_done)
# check that beams ending in `eos_token_id` are added to the hypotheses and the remaining candidates are kept
beam_scorer = self.prepare_beam_scorer()
tokens = next_tokens.clone()
tokens[:, 1] = self.eos_token_id
beam_outputs = beam_scorer.process(
input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id
)
output_scores = beam_outputs["next_beam_scores"]
output_tokens = beam_outputs["next_beam_tokens"]
output_indices = beam_outputs["next_beam_indices"]
def cut_expected_tensor(tensor):
return torch.cat([tensor[:, :1], tensor[:, 2 : self.num_beams + 1]], dim=1).flatten()
# check all outputs
# cut out id of eos token and take best `num_beams` outputs
expected_output_tokens = cut_expected_tensor(tokens)
expected_output_scores = cut_expected_tensor(next_scores)
# add num_beams * batch_idx
expected_output_indices = (
cut_expected_tensor(next_indices)
+ (torch.arange(self.num_beams * self.batch_size, device=torch_device) // self.num_beams) * self.num_beams
)
self.parent.assertListEqual(expected_output_tokens.tolist(), output_tokens.tolist())
self.parent.assertListEqual(expected_output_indices.tolist(), output_indices.tolist())
self.parent.assertTrue(torch.allclose(expected_output_scores, output_scores, atol=1e-3))
# make sure ids of eos token are correctly saved in beam_hyps of beam scorer
for batch_idx in range(self.batch_size):
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
self.parent.assertListEqual(
input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist()
)
def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
# max_length should be only one more than current input_ids to check that eos is correctly appended
max_length = self.sequence_length + 1
beam_scorer = self.prepare_beam_scorer(
num_beam_hyps_to_keep=1, max_length=max_length, length_penalty=1.0, do_early_stopping=False
)
# update beams and append to input_ids
tokens = next_tokens.clone()
# first batch, first output has to finish with eos token id since scores are correctly sorted
tokens[0, 0] = self.eos_token_id
# make sure corresponding score is as good as possible to surely be picked first
next_scores[0, 0] = 0.0
beam_outputs = beam_scorer.process(
input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id
)
output_scores = beam_outputs["next_beam_scores"]
output_tokens = beam_outputs["next_beam_tokens"]
output_indices = beam_outputs["next_beam_indices"]
input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1)
# finalize
decoded = beam_scorer.finalize(
input_ids,
output_scores,
output_tokens,
output_indices,
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id,
)
# since `num_beam_hyps_to_keep` = 1 => only return `batch_size` x `max_length`
self.parent.assertListEqual(list(decoded.shape), [self.batch_size, max_length])
# first batch has to finish with eos_token
self.parent.assertEqual(decoded[0, -1].item(), self.eos_token_id)
# other batches cannot finish with eos token
self.parent.assertNotEqual(decoded[1, -1].item(), self.eos_token_id)
self.parent.assertNotEqual(decoded[2, -1].item(), self.eos_token_id)
# now test that if `num_beam_hyps_to_keep` == `num_beams` => all beams are returned
beam_scorer.num_beam_hyps_to_keep = self.num_beams
decoded = beam_scorer.finalize(
input_ids,
output_scores,
output_tokens,
output_indices,
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id,
)
self.parent.assertListEqual(list(decoded.shape), [self.num_beams * self.batch_size, max_length])
@require_torch
class BeamSearchTest(unittest.TestCase):
def setUp(self):
self.beam_search_tester = BeamSearchTester(self)
def test_beam_hypotheses(self):
inputs = self.beam_search_tester.prepare_inputs()
self.beam_search_tester.check_beam_hypotheses(*inputs)
def test_beam_scorer_update(self):
inputs = self.beam_search_tester.prepare_inputs()
self.beam_search_tester.check_beam_scorer_update(*inputs)
def test_beam_scorer_finalize(self):
inputs = self.beam_search_tester.prepare_inputs()
self.beam_search_tester.check_beam_scores_finalize(*inputs)
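For readers following the new API outside the test harness, a self-contained sketch that exercises `BeamSearchScorer.process` and `finalize` directly, mirroring the shapes used by the tests above (all sizes are invented; `max_length` is still a constructor argument at this point in the refactor):
# Hedged example: random candidates, no model involved.
import torch
from transformers import BeamSearchScorer

batch_size, num_beams, vocab_size, seq_len = 2, 3, 50, 4
scorer = BeamSearchScorer(
    batch_size=batch_size,
    max_length=seq_len + 2,
    num_beams=num_beams,
    device=torch.device("cpu"),
)
input_ids = torch.randint(0, vocab_size, (batch_size * num_beams, seq_len))
next_scores, _ = torch.randn(batch_size, 2 * num_beams).log_softmax(-1).sort(descending=True)
next_tokens = torch.randint(0, vocab_size, (batch_size, 2 * num_beams))
next_indices = torch.randint(0, num_beams, (batch_size, 2 * num_beams))

# one step: pick the `num_beams` best continuations per batch (eos id chosen outside the random range)
out = scorer.process(input_ids, next_scores, next_tokens, next_indices, eos_token_id=vocab_size)
input_ids = torch.cat([input_ids[out["next_beam_indices"], :], out["next_beam_tokens"].unsqueeze(-1)], dim=-1)

decoded = scorer.finalize(
    input_ids, out["next_beam_scores"], out["next_beam_tokens"], out["next_beam_indices"],
    pad_token_id=0, eos_token_id=vocab_size,
)
print(decoded.shape)  # (batch_size * num_beam_hyps_to_keep, <= max_length)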