Unverified Commit a1bbcf3f authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Refactoring the generate() function (#6949)

* first draft

* show design proposition for new generate method

* up

* make better readable

* make first version

* gpt2 tests pass

* make beam search for gpt2 work

* add first encoder-decoder code

* delete typo

* make t5 work

* save indermediate

* make bart work with beam search

* finish beam search bart / t5

* add default kwargs

* make more tests pass

* fix no bad words sampler

* some fixes and tests for all distribution processors

* fix test

* fix rag slow tests

* merge to master

* add nograd to generate

* make all slow tests pass

* speed up generate

* fix edge case bug

* small fix

* correct typo

* add type hints and docstrings

* fix typos in tests

* add beam search tests

* add tests for beam scorer

* fix test rag

* finish beam search tests

* move generation tests in seperate file

* fix generation tests

* more tests

* add aggressive generation tests

* fix tests

* add gpt2 sample test

* add more docstring

* add more docs

* finish doc strings

* apply some more of sylvains and sams comments

* fix some typos

* make fix copies

* apply lysandres and sylvains comments

* final corrections on examples

* small fix for reformer
parent b63beb74
...@@ -272,3 +272,4 @@ conversion utilities for the following models: ...@@ -272,3 +272,4 @@ conversion utilities for the following models:
internal/pipelines_utils internal/pipelines_utils
internal/tokenization_utils internal/tokenization_utils
internal/trainer_utils internal/trainer_utils
internal/generation_utils
Utilities for Generation
-----------------------------------------------------------------------------------------------------------------------
This page lists all the utility functions used by :meth:`~transformers.PretrainedModel.generate`,
:meth:`~transformers.PretrainedModel.greedy_search`, :meth:`~transformers.PretrainedModel.sample`,
:meth:`~transformers.PretrainedModel.beam_search`, and :meth:`~transformers.PretrainedModel.beam_sample`.
Most of those are only useful if you are studying the code of the generate methods in the library.
LogitsProcessor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
A :class:`~transformers.LogitsProcessor` can be used to modify the prediction scores of a language model head for
generation.
.. autoclass:: transformers.LogitsProcessor
:members: __call__
.. autoclass:: transformers.LogitsProcessorList
:members: __call__
.. autoclass:: transformers.MinLengthLogitsProcessor
:members: __call__
.. autoclass:: transformers.TemperatureLogitsWarper
:members: __call__
.. autoclass:: transformers.RepetitionPenaltyLogitsProcessor
:members: __call__
.. autoclass:: transformers.TopPLogitsWarper
:members: __call__
.. autoclass:: transformers.TopKLogitsWarper
:members: __call__
.. autoclass:: transformers.NoRepeatNGramLogitsProcessor
:members: __call__
.. autoclass:: transformers.NoBadWordsLogitsProcessor
:members: __call__
BeamSearch
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.BeamScorer
:members: process, finalize
.. autoclass:: transformers.BeamSearchScorer
:members: process, finalize
...@@ -45,7 +45,7 @@ TFModelUtilsMixin ...@@ -45,7 +45,7 @@ TFModelUtilsMixin
:members: :members:
Generative models Generation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.generation_utils.GenerationMixin .. autoclass:: transformers.generation_utils.GenerationMixin
......
...@@ -299,6 +299,19 @@ if is_torch_available(): ...@@ -299,6 +299,19 @@ if is_torch_available():
TextDataset, TextDataset,
TextDatasetForNextSentencePrediction, TextDatasetForNextSentencePrediction,
) )
from .generation_beam_search import BeamScorer, BeamSearchScorer
from .generation_logits_process import (
LogitsProcessor,
LogitsProcessorList,
LogitsWarper,
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
from .generation_utils import top_k_top_p_filtering from .generation_utils import top_k_top_p_filtering
from .modeling_albert import ( from .modeling_albert import (
ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST, ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
......
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from collections import UserDict
from typing import Optional, Tuple
import torch
from .file_utils import add_start_docstrings
PROCESS_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using any class inheriting from :class:`~transformers.PretrainedTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are input IDs? <../glossary.html#input-ids>`__
next_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2 * num_beams)`):
Current scores of the top :obj:`2 * num_beams` non-finished beam hypotheses.
next_tokens (:obj:`torch.LongTensor` of shape :obj:`(batch_size, 2 * num_beams)`):
:obj:`input_ids` of the tokens corresponding to the top :obj:`2 * num_beams` non-finished beam hypotheses.
next_indices (:obj:`torch.LongTensor` of shape :obj:`(batch_size, 2 * num_beams)`):
Beam indices indicating to which beam hypothesis the :obj:`next_tokens` correspond.
pad_token_id (:obj:`int`, `optional`):
The id of the `padding` token.
eos_token_id (:obj:`int`, `optional`):
The id of the `end-of-sequence` token.
Return:
:obj:`UserDict`: A dictionary composed of the fields as defined above:
- **next_beam_scores** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Updated
scores of all non-finished beams.
- **next_beam_tokens** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Next tokens
to be added to the non-finished beam_hypotheses.
- **next_beam_indices** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Beam indices
indicating to which beam the next tokens shall be added.
"""
FINALIZE_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using any class inheriting from :class:`~transformers.PretrainedTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are input IDs? <../glossary.html#input-ids>`__
final_beam_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`):
The final scores of all non-finished beams.
final_beam_tokens (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`):
The last tokens to be added to the non-finished beam_hypotheses.
final_beam_indices (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`):
The beam indices indicating to which beam the :obj:`final_beam_tokens` shall be added.
pad_token_id (:obj:`int`, `optional`):
The id of the `padding` token.
eos_token_id (:obj:`int`, `optional`):
The id of the `end-of-sequence` token.
Return:
:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all
batches finished early due to the :obj:`eos_token_id`.
"""
class BeamScorer(ABC):
"""
Abstract base class for all beam scorers that are used for :meth:`~transformers.PretrainedModel.beam_search` and
:meth:`~transformers.PretrainedModel.beam_sample`.
"""
@abstractmethod
@add_start_docstrings(PROCESS_INPUTS_DOCSTRING)
def process(
self,
input_ids: torch.LongTensor,
next_scores: torch.FloatTensor,
next_tokens: torch.LongTensor,
next_indices: torch.LongTensor,
**kwargs
) -> Tuple[torch.Tensor]:
raise NotImplementedError("This is an abstract method.")
@abstractmethod
@add_start_docstrings(FINALIZE_INPUTS_DOCSTRING)
def finalize(
self,
input_ids: torch.LongTensor,
next_scores: torch.FloatTensor,
next_tokens: torch.LongTensor,
next_indices: torch.LongTensor,
**kwargs
) -> torch.LongTensor:
raise NotImplementedError("This is an abstract method.")
class BeamSearchScorer(BeamScorer):
r"""
:class:`transformers.BeamScorer` implementing standard beam search decoding.
Adapted in part from `Facebook's XLM beam search code
<https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529>`__.
Args:
batch_size (:obj:`int`):
Batch Size of :obj:`input_ids` for which beam search decoding is run in parallel.
max_length (:obj:`int`):
The maximum length of the sequence to be generated.
num_beams (:obj:`int`):
Number of beams for beam search.
device (:obj:`torch.device`):
Defines the device type (*e.g.*, :obj:`"cpu"` or :obj:`"cuda"`) on which this instance of
:obj:`BeamSearchScorer` will be allocated.
length_penalty (:obj:`float`, `optional`, defaults to 1.0):
Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the
model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer
sequences.
do_early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not.
num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1):
The number of beam hypotheses that shall be returned upon calling
:meth:`~transformer.BeamSearchScorer.finalize`.
"""
def __init__(
self,
batch_size: int,
max_length: int,
num_beams: int,
device: torch.device,
length_penalty: Optional[float] = 1.0,
do_early_stopping: Optional[bool] = False,
num_beam_hyps_to_keep: Optional[int] = 1,
):
self.max_length = max_length
self.num_beams = num_beams
self.device = device
self.length_penalty = length_penalty
self.do_early_stopping = do_early_stopping
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
self._is_init = False
self._beam_hyps = [
BeamHypotheses(
num_beams=self.num_beams,
max_length=self.max_length,
length_penalty=self.length_penalty,
early_stopping=self.do_early_stopping,
)
for _ in range(batch_size)
]
self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)
if not isinstance(num_beams, int) or num_beams <= 1:
raise ValueError(
f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead."
)
@property
def is_done(self) -> bool:
return self._done.all()
def process(
self,
input_ids: torch.LongTensor,
next_scores: torch.FloatTensor,
next_tokens: torch.LongTensor,
next_indices: torch.LongTensor,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
) -> Tuple[torch.Tensor]:
cur_len = input_ids.shape[-1]
batch_size = len(self._beam_hyps)
assert batch_size == (input_ids.shape[0] // self.num_beams)
device = input_ids.device
next_beam_scores = torch.zeros((batch_size, self.num_beams), dtype=next_scores.dtype, device=device)
next_beam_tokens = torch.zeros((batch_size, self.num_beams), dtype=next_tokens.dtype, device=device)
next_beam_indices = torch.zeros((batch_size, self.num_beams), dtype=next_indices.dtype, device=device)
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
assert (
len(beam_hyp) >= self.num_beams
), "Batch can only be done if at least {} beams have been generated".format(self.num_beams)
assert (
eos_token_id is not None and pad_token_id is not None
), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
# pad the batch
next_beam_scores[batch_idx, :] = 0
next_beam_tokens[batch_idx, :] = pad_token_id
next_beam_indices[batch_idx, :] = 0
continue
# next tokens for this sentence
beam_idx = 0
for beam_token_rank, (next_token, next_score, next_index) in enumerate(
zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
):
batch_beam_idx = batch_idx * self.num_beams + next_index
# add to generated hypotheses if end of sentence
if (eos_token_id is not None) and (next_token.item() == eos_token_id):
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.num_beams
if is_beam_token_worse_than_top_num_beams:
continue
beam_hyp.add(
input_ids[batch_beam_idx].clone(),
next_score.item(),
)
else:
# add next predicted token since it is not eos_token
next_beam_scores[batch_idx, beam_idx] = next_score
next_beam_tokens[batch_idx, beam_idx] = next_token
next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
beam_idx += 1
# once the beam for next step is full, don't add more tokens to it.
if beam_idx == self.num_beams:
break
if beam_idx < self.num_beams:
raise ValueError(
f"At most {self.num_beams} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
)
# Check if we are done so that we can save a pad step if all(done)
self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
next_scores[batch_idx].max().item(), cur_len
)
return UserDict(
{
"next_beam_scores": next_beam_scores.view(-1),
"next_beam_tokens": next_beam_tokens.view(-1),
"next_beam_indices": next_beam_indices.view(-1),
}
)
def finalize(
self,
input_ids: torch.LongTensor,
final_beam_scores: torch.FloatTensor,
final_beam_tokens: torch.LongTensor,
final_beam_indices: torch.LongTensor,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
) -> torch.LongTensor:
batch_size = len(self._beam_hyps)
# finalize all open beam hypotheses and add to generated hypotheses
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
continue
# need to add best num_beams hypotheses to generated hyps
for beam_id in range(self.num_beams):
batch_beam_idx = batch_idx * self.num_beams + beam_id
final_score = final_beam_scores[batch_beam_idx].item()
final_tokens = input_ids[batch_beam_idx]
beam_hyp.add(final_tokens, final_score)
# select the best hypotheses
sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
best = []
# retrieve best hypotheses
for i, beam_hyp in enumerate(self._beam_hyps):
sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
for j in range(self.num_beam_hyps_to_keep):
best_hyp = sorted_hyps.pop()[1]
sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
best.append(best_hyp)
# prepare for adding eos
sent_max_len = min(sent_lengths.max().item() + 1, self.max_length)
decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
# shorter batches are padded if needed
if sent_lengths.min().item() != sent_lengths.max().item():
assert pad_token_id is not None, "`pad_token_id` has to be defined"
decoded.fill_(pad_token_id)
# fill with hypotheses and eos_token_id if the latter fits in
for i, hypo in enumerate(best):
decoded[i, : sent_lengths[i]] = hypo
if sent_lengths[i] < self.max_length:
decoded[i, sent_lengths[i]] = eos_token_id
return decoded
class BeamHypotheses:
def __init__(self, num_beams: int, max_length: int, length_penalty: float, early_stopping: bool):
"""
Initialize n-best list of hypotheses.
"""
self.max_length = max_length - 1 # ignoring bos_token
self.length_penalty = length_penalty
self.early_stopping = early_stopping
self.num_beams = num_beams
self.beams = []
self.worst_score = 1e9
def __len__(self):
"""
Number of hypotheses in the list.
"""
return len(self.beams)
def add(self, hyp: torch.LongTensor, sum_logprobs: float):
"""
Add a new hypothesis to the list.
"""
score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
if len(self) < self.num_beams or score > self.worst_score:
self.beams.append((score, hyp))
if len(self) > self.num_beams:
sorted_next_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
del self.beams[sorted_next_scores[0][1]]
self.worst_score = sorted_next_scores[1][0]
else:
self.worst_score = min(score, self.worst_score)
def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
"""
If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst
one in the heap, then we are done with this sentence.
"""
if len(self) < self.num_beams:
return False
elif self.early_stopping:
return True
else:
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
ret = self.worst_score >= cur_score
return ret
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from typing import Iterable, List
import numpy as np
import torch
from torch.nn import functional as F
from .file_utils import add_start_docstrings
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`~transformers.BertTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are input IDs? <../glossary.html#input-ids>`__
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
or scores for each vocabulary token after SoftMax.
Return:
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores.
"""
class LogitsProcessor(ABC):
"""Abstract base class for all logit processors that can be applied during generation."""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
"""Torch method for processing logits."""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
class LogitsWarper(ABC):
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
"""Torch method for warping logits."""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
class LogitsProcessorList(list):
"""
This class can be used to create a list of :class:`~transformers.LogitsProcessor` or
:class:`~transformers.LogitsWarper` to subsequently process a :obj:`scores` input tensor. This class inherits from
list and adds a specific `__call__` method to apply each :class:`~transformers.LogitsProcessor` or
:class:`~transformers.LogitsProcessor` to the inputs.
"""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
for processor in self:
scores = processor(input_ids, scores)
return scores
class MinLengthLogitsProcessor(LogitsProcessor):
r"""
:class:`transformers.LogitsProcessor` enforcing a min-length by setting EOS probability to 0.
Args:
min_length (:obj:`int`):
The minimum length below which the score of :obj:`eos_token_id` is set to :obj:`-float("Inf")`.
eos_token_id (:obj:`int`):
The id of the `end-of-sequence` token.
"""
def __init__(self, min_length: int, eos_token_id: int):
if not isinstance(min_length, int) or min_length < 0:
raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")
if not isinstance(eos_token_id, int) or eos_token_id < 0:
raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
self.min_length = min_length
self.eos_token_id = eos_token_id
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
if cur_len < self.min_length:
scores[:, self.eos_token_id] = -float("inf")
return scores
class TemperatureLogitsWarper(LogitsWarper):
r"""
:class:`transformers.LogitsWarper` for temperature (exponential scaling output probability distribution).
Args:
temperature (:obj:`float`):
The value used to module the logits distribution.
"""
def __init__(self, temperature: float):
if not isinstance(temperature, float) or not (temperature > 0):
raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")
self.temperature = temperature
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
scores = scores / self.temperature
return scores
class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
r"""
:class:`transformers.LogitsProcessor` enforcing an exponential penalty on repeated sequences.
Args:
repetition_penalty (:obj:`float`):
The parameter for repetition penalty. 1.0 means no penalty. See `this paper
<https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
"""
def __init__(self, penalty: float):
if not isinstance(penalty, float) or not (penalty > 0):
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
self.penalty = penalty
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
for i in range(scores.shape[0]):
for previous_token in set(input_ids[i].tolist()):
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
if scores[i, previous_token] < 0:
scores[i, previous_token] *= self.penalty
else:
scores[i, previous_token] /= self.penalty
return scores
class TopPLogitsWarper(LogitsWarper):
"""
:class:`transformers.LogitsWarper` that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <=
prob_cut_off.
Args:
top_p (:obj:`float`):
If set to < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are
kept for generation.
filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`):
All filtered values will be set to this float value.
min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
Minimum number of tokens that cannot be filtered.
"""
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
self.top_p = top_p
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
sorted_logits, sorted_indices = torch.sort(scores, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs > self.top_p
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., : self.min_tokens_to_keep - 1] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores[indices_to_remove] = self.filter_value
return scores
class TopKLogitsWarper(LogitsWarper):
r"""
:class:`transformers.LogitsWarper` that performs top-k, i.e. restricting to the k highest probability elements.
Args:
top_k (:obj:`int`):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`):
All filtered values will be set to this float value.
min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
Minimum number of tokens that cannot be filtered.
"""
def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
self.top_k = top_k
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
scores[indices_to_remove] = self.filter_value
return scores
class NoRepeatNGramLogitsProcessor(LogitsProcessor):
r"""
:class:`transformers.LogitsProcessor` that enforces no repetition of n-grams. See `Fairseq
<https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345>`__.
Args:
ngram_size (:obj:`int`):
All ngrams of size :obj:`ngram_size` can only occur once.
"""
def __init__(self, ngram_size: int):
if not isinstance(ngram_size, int) or ngram_size <= 0:
raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
self.ngram_size = ngram_size
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
num_batch_hypotheses = scores.shape[0]
cur_len = input_ids.shape[-1]
banned_batch_tokens = self._calc_banned_ngram_tokens(input_ids, num_batch_hypotheses, cur_len)
for i, banned_tokens in enumerate(banned_batch_tokens):
scores[i, banned_tokens] = -float("inf")
return scores
def _calc_banned_ngram_tokens(
self, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int
) -> List[Iterable[int]]:
"""Copied from fairseq for no_repeat_ngram in beam_search"""
if cur_len + 1 < self.ngram_size:
# return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
return [[] for _ in range(num_hypos)]
generated_ngrams = [{} for _ in range(num_hypos)]
for idx in range(num_hypos):
gen_tokens = prev_input_ids[idx].tolist()
generated_ngram = generated_ngrams[idx]
for ngram in zip(*[gen_tokens[i:] for i in range(self.ngram_size)]):
prev_ngram_tuple = tuple(ngram[:-1])
generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
def _get_generated_ngrams(hypo_idx):
# Before decoding the next token, prevent decoding of ngrams that have already appeared
start_idx = cur_len + 1 - self.ngram_size
ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
return generated_ngrams[hypo_idx].get(ngram_idx, [])
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
return banned_tokens
class NoBadWordsLogitsProcessor(LogitsProcessor):
"""
:class:`transformers.LogitsProcessor` that enforces that specified sequences will never be sampled.
Args:
bad_words_ids (:obj:`List[List[int]]`):
List of list of token ids that are not allowed to be generated. In order to get the tokens of the words
that should not appear in the generated text, use :obj:`tokenizer(bad_word,
add_prefix_space=True).input_ids`.
eos_token_id (:obj:`int`):
The id of the `end-of-sequence` token.
"""
def __init__(self, bad_words_ids: Iterable[Iterable[int]], eos_token_id: int):
if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0:
raise ValueError(f"`bad_words_ids` has to be a non-emtpy list, but is {bad_words_ids}.")
if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.")
if any(
any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids)
for bad_word_ids in bad_words_ids
):
raise ValueError(
f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
)
self.bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
for banned_token_seq in self.bad_words_ids:
assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format(
bad_words_ids
)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
banned_tokens = self._calc_banned_bad_words_ids(input_ids)
scores = self._set_scores_to_inf_for_banned_tokens(scores, banned_tokens)
return scores
def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
if len(tokens) == 0:
# if bad word tokens is just one token always ban it
return True
elif len(tokens) > len(prev_tokens):
# if bad word tokens are longer then prev input_ids they can't be equal
return False
elif prev_tokens[-len(tokens) :].tolist() == tokens:
# if tokens match
return True
else:
return False
def _calc_banned_bad_words_ids(self, prev_input_ids: Iterable[int]) -> Iterable[int]:
banned_tokens = []
for prev_input_ids_slice in prev_input_ids:
banned_tokens_slice = []
for banned_token_seq in self.bad_words_ids:
if self._tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False:
# if tokens do not match continue
continue
banned_tokens_slice.append(banned_token_seq[-1])
banned_tokens.append(banned_tokens_slice)
return banned_tokens
def _set_scores_to_inf_for_banned_tokens(self, scores: torch.Tensor, banned_tokens: List[List[int]]) -> None:
"""
Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be a
list of list of banned tokens to ban in the format [[batch index, vocabulary position],...
Args:
scores: logits distribution of shape (batch size, vocabulary size)
banned_tokens: list of list of tokens to ban of length (batch_size)
"""
banned_mask_list = []
for idx, batch_banned_tokens in enumerate(banned_tokens):
for token in batch_banned_tokens:
banned_mask_list.append([idx, token])
if not banned_mask_list:
return scores
banned_mask = torch.LongTensor(banned_mask_list)
indices = torch.ones(len(banned_mask))
# A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates:
# [ 0 1 1 ]
# [ 0 0 0 ]
# [ 1 0 0 ]
banned_mask = (
torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool()
)
scores = scores.masked_fill(banned_mask, -float("inf"))
return scores
This diff is collapsed.
...@@ -1084,7 +1084,7 @@ class BartForConditionalGeneration(PretrainedBartModel): ...@@ -1084,7 +1084,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
) )
def prepare_inputs_for_generation( def prepare_inputs_for_generation(
self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
): ):
return { return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed "input_ids": None, # encoder_outputs is defined. input_ids not needed
......
...@@ -514,12 +514,12 @@ class CTRLLMHeadModel(CTRLPreTrainedModel): ...@@ -514,12 +514,12 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
def get_output_embeddings(self): def get_output_embeddings(self):
return self.lm_head return self.lm_head
def prepare_inputs_for_generation(self, input_ids, past, **kwargs): def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, **kwargs):
# only last token for inputs_ids if past is defined in kwargs # only last token for inputs_ids if past is defined in kwargs
if past: if past:
input_ids = input_ids[:, -1].unsqueeze(-1) input_ids = input_ids[:, -1].unsqueeze(-1)
return {"input_ids": input_ids, "past_key_values": past, "use_cache": kwargs["use_cache"]} return {"input_ids": input_ids, "past_key_values": past, "use_cache": use_cache}
@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
@add_code_sample_docstrings( @add_code_sample_docstrings(
......
...@@ -431,7 +431,7 @@ class EncoderDecoderModel(PreTrainedModel): ...@@ -431,7 +431,7 @@ class EncoderDecoderModel(PreTrainedModel):
encoder_attentions=encoder_outputs.attentions, encoder_attentions=encoder_outputs.attentions,
) )
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, encoder_outputs, **kwargs): def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, encoder_outputs=None, **kwargs):
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids) decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids)
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
input_dict = { input_dict = {
......
...@@ -1107,7 +1107,7 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel): ...@@ -1107,7 +1107,7 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
) )
def prepare_inputs_for_generation( def prepare_inputs_for_generation(
self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
): ):
return { return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed "input_ids": None, # encoder_outputs is defined. input_ids not needed
......
...@@ -1800,7 +1800,7 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel): ...@@ -1800,7 +1800,7 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
return loss return loss
def prepare_inputs_for_generation( def prepare_inputs_for_generation(
self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
): ):
assert encoder_outputs is not None, "`encoder_outputs` have to be passed for generation." assert encoder_outputs is not None, "`encoder_outputs` have to be passed for generation."
......
...@@ -22,6 +22,7 @@ import torch ...@@ -22,6 +22,7 @@ import torch
from .configuration_rag import RagConfig from .configuration_rag import RagConfig
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .file_utils import add_start_docstrings_to_model_forward, replace_return_docstrings from .file_utils import add_start_docstrings_to_model_forward, replace_return_docstrings
from .generation_beam_search import BeamSearchScorer
from .modeling_outputs import ModelOutput from .modeling_outputs import ModelOutput
from .modeling_utils import PreTrainedModel from .modeling_utils import PreTrainedModel
from .retrieval_rag import RagRetriever from .retrieval_rag import RagRetriever
...@@ -825,7 +826,7 @@ class RagSequenceForGeneration(RagPreTrainedModel): ...@@ -825,7 +826,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
num_return_sequences=None, # defaults to 1 num_return_sequences=None, # defaults to 1
num_beams=None, # defaults to 1 num_beams=None, # defaults to 1
n_docs=None, n_docs=None,
**kwargs **model_kwargs
): ):
""" """
Implements RAG sequence "thorough" decoding. Read the :meth:`~transformers.PreTrainedModel.generate`` Implements RAG sequence "thorough" decoding. Read the :meth:`~transformers.PreTrainedModel.generate``
...@@ -872,7 +873,6 @@ class RagSequenceForGeneration(RagPreTrainedModel): ...@@ -872,7 +873,6 @@ class RagSequenceForGeneration(RagPreTrainedModel):
) )
num_beams = num_beams if num_beams is not None else self.config.num_beams num_beams = num_beams if num_beams is not None else self.config.num_beams
# TODO(patrick) - clean up generate here
if self.retriever is not None and context_input_ids is None: if self.retriever is not None and context_input_ids is None:
question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0] question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
context_input_ids = self.retriever( context_input_ids = self.retriever(
...@@ -887,10 +887,9 @@ class RagSequenceForGeneration(RagPreTrainedModel): ...@@ -887,10 +887,9 @@ class RagSequenceForGeneration(RagPreTrainedModel):
context_input_ids = context_input_ids.to(input_ids) context_input_ids = context_input_ids.to(input_ids)
hypos = [] hypos = []
kwargs["num_beams"] = num_beams model_kwargs["num_beams"] = num_beams
kwargs["num_return_sequences"] = num_beams model_kwargs["num_return_sequences"] = num_beams
kwargs["attention_mask"] = None model_kwargs["attention_mask"] = None
kwargs["n_docs"] = n_docs
for index in range(len(input_ids)): for index in range(len(input_ids)):
# first, generate beams from documents: # first, generate beams from documents:
...@@ -898,7 +897,7 @@ class RagSequenceForGeneration(RagPreTrainedModel): ...@@ -898,7 +897,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
output_sequences = self.generator.generate( output_sequences = self.generator.generate(
generator_input_ids, generator_input_ids,
**kwargs, **model_kwargs,
) # n_docs * n_beam, tgt_len ) # n_docs * n_beam, tgt_len
if do_deduplication: if do_deduplication:
# do_deduplication, max_output_len # do_deduplication, max_output_len
...@@ -1018,7 +1017,15 @@ class RagTokenForGeneration(RagPreTrainedModel): ...@@ -1018,7 +1017,15 @@ class RagTokenForGeneration(RagPreTrainedModel):
return self.rag.generator.adjust_logits_during_generation(logits, cur_len=cur_len, max_length=max_length) return self.rag.generator.adjust_logits_during_generation(logits, cur_len=cur_len, max_length=max_length)
def prepare_inputs_for_generation( def prepare_inputs_for_generation(
self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, doc_scores, n_docs=None, **kwargs self,
decoder_input_ids,
past=None,
attention_mask=None,
use_cache=None,
encoder_outputs=None,
doc_scores=None,
n_docs=None,
**kwargs
): ):
return { return {
"input_ids": None, "input_ids": None,
...@@ -1222,11 +1229,12 @@ class RagTokenForGeneration(RagPreTrainedModel): ...@@ -1222,11 +1229,12 @@ class RagTokenForGeneration(RagPreTrainedModel):
eos_token_id=None, eos_token_id=None,
length_penalty=None, length_penalty=None,
no_repeat_ngram_size=None, no_repeat_ngram_size=None,
repetition_penalty=None,
bad_words_ids=None, bad_words_ids=None,
num_return_sequences=None, num_return_sequences=None,
decoder_start_token_id=None, decoder_start_token_id=None,
n_docs=None, n_docs=None,
**kwargs **model_kwargs
): ):
""" """
Implements RAG token decoding. Implements RAG token decoding.
...@@ -1307,22 +1315,15 @@ class RagTokenForGeneration(RagPreTrainedModel): ...@@ -1307,22 +1315,15 @@ class RagTokenForGeneration(RagPreTrainedModel):
""" """
# set default parameters # set default parameters
n_docs = n_docs if n_docs is not None else self.config.n_docs n_docs = n_docs if n_docs is not None else self.config.n_docs
max_length = max_length if max_length is not None else self.config.max_length
min_length = min_length if min_length is not None else self.config.min_length
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
use_cache = use_cache if use_cache is not None else self.config.use_cache
num_beams = num_beams if num_beams is not None else self.config.num_beams num_beams = num_beams if num_beams is not None else self.config.num_beams
bos_token_id = bos_token_id if bos_token_id is not None else self.config.generator.bos_token_id max_length = max_length if max_length is not None else self.config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.config.generator.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.generator.eos_token_id
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
no_repeat_ngram_size = (
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
)
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
num_return_sequences = ( num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
) )
bos_token_id = bos_token_id if bos_token_id is not None else self.config.generator.bos_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.generator.eos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.generator.pad_token_id
use_cache = use_cache if use_cache is not None else self.config.use_cache
decoder_start_token_id = ( decoder_start_token_id = (
decoder_start_token_id decoder_start_token_id
if decoder_start_token_id is not None if decoder_start_token_id is not None
...@@ -1365,7 +1366,7 @@ class RagTokenForGeneration(RagPreTrainedModel): ...@@ -1365,7 +1366,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
encoder = self.rag.generator.get_encoder() encoder = self.rag.generator.get_encoder()
encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask, return_dict=True) encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask, return_dict=True)
decoder_input_ids = torch.full( input_ids = torch.full(
(batch_size * num_beams, 1), (batch_size * num_beams, 1),
decoder_start_token_id, decoder_start_token_id,
dtype=torch.long, dtype=torch.long,
...@@ -1388,64 +1389,57 @@ class RagTokenForGeneration(RagPreTrainedModel): ...@@ -1388,64 +1389,57 @@ class RagTokenForGeneration(RagPreTrainedModel):
doc_scores = doc_scores.repeat_interleave(num_beams, dim=0) doc_scores = doc_scores.repeat_interleave(num_beams, dim=0)
# define start_len & additional parameters # define start_len & additional parameters
cur_len = 1 model_kwargs["doc_scores"] = doc_scores
vocab_size = self.config.generator.vocab_size model_kwargs["encoder_outputs"] = encoder_outputs
kwargs["doc_scores"] = doc_scores model_kwargs["attention_mask"] = context_attention_mask
kwargs["encoder_outputs"] = encoder_outputs model_kwargs["n_docs"] = n_docs
kwargs["n_docs"] = n_docs
pre_processor = self._get_logits_processor(
# not needed. TODO(PVP): change after generate refactor repetition_penalty=repetition_penalty,
do_sample = False no_repeat_ngram_size=no_repeat_ngram_size,
temperature = self.config.temperature bad_words_ids=bad_words_ids,
top_k = self.config.top_k min_length=min_length,
top_p = self.config.top_p eos_token_id=eos_token_id,
repetition_penalty = self.config.repetition_penalty )
if num_beams > 1: if num_beams == 1:
return self._generate_beam_search( if num_return_sequences > 1:
decoder_input_ids, raise ValueError(
cur_len=cur_len, f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
)
return self.greedy_search(
input_ids,
pre_processor=pre_processor,
max_length=max_length, max_length=max_length,
min_length=min_length,
do_sample=do_sample,
early_stopping=early_stopping,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
pad_token_id=pad_token_id, pad_token_id=pad_token_id,
eos_token_id=eos_token_id, eos_token_id=eos_token_id,
**model_kwargs,
)
elif num_beams > 1:
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
if num_return_sequences > num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
beam_scorer = BeamSearchScorer(
batch_size=batch_size, batch_size=batch_size,
num_return_sequences=num_return_sequences, max_length=max_length,
length_penalty=length_penalty,
num_beams=num_beams, num_beams=num_beams,
vocab_size=vocab_size, device=self.device,
attention_mask=context_attention_mask, length_penalty=length_penalty,
use_cache=use_cache, do_early_stopping=early_stopping,
model_kwargs=kwargs, num_beam_hyps_to_keep=num_return_sequences,
) )
else: return self.beam_search(
return self._generate_no_beam_search( input_ids,
decoder_input_ids, beam_scorer,
cur_len=cur_len, pre_processor=pre_processor,
max_length=max_length, max_length=max_length,
min_length=min_length,
do_sample=do_sample,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
pad_token_id=pad_token_id, pad_token_id=pad_token_id,
eos_token_id=eos_token_id, eos_token_id=eos_token_id,
batch_size=batch_size, **model_kwargs,
attention_mask=context_attention_mask,
use_cache=use_cache,
model_kwargs=kwargs,
) )
else:
raise ValueError(f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {num_beams}")
def get_input_embeddings(self): def get_input_embeddings(self):
return self.rag.generator.get_input_embeddings() return self.rag.generator.get_input_embeddings()
......
...@@ -638,7 +638,6 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -638,7 +638,6 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
rotations_shape = (self.num_attention_heads, vectors.shape[-1], num_hashes, rotation_size // 2) rotations_shape = (self.num_attention_heads, vectors.shape[-1], num_hashes, rotation_size // 2)
# create a random self.attention_head_size x num_hashes x num_buckets/2 # create a random self.attention_head_size x num_hashes x num_buckets/2
random_rotations = torch.randn(rotations_shape, device=vectors.device, dtype=vectors.dtype) random_rotations = torch.randn(rotations_shape, device=vectors.device, dtype=vectors.dtype)
# Output dim: Batch_Size x Num_Attn_Heads x Num_Hashes x Seq_Len x Num_Buckets/2 # Output dim: Batch_Size x Num_Attn_Heads x Num_Hashes x Seq_Len x Num_Buckets/2
rotated_vectors = torch.einsum("bmtd,mdhr->bmhtr", vectors, random_rotations) rotated_vectors = torch.einsum("bmtd,mdhr->bmhtr", vectors, random_rotations)
...@@ -1471,7 +1470,9 @@ class ReformerLayer(nn.Module): ...@@ -1471,7 +1470,9 @@ class ReformerLayer(nn.Module):
# every forward pass we sample a different seed # every forward pass we sample a different seed
# for dropout and save for forward fn in backward pass # for dropout and save for forward fn in backward pass
# to have correct dropout # to have correct dropout
self._init_attention_seed() if self.training:
self._init_attention_seed()
attn_outputs = self.attention( attn_outputs = self.attention(
hidden_states=hidden_states, hidden_states=hidden_states,
head_mask=head_mask, head_mask=head_mask,
...@@ -1494,7 +1495,8 @@ class ReformerLayer(nn.Module): ...@@ -1494,7 +1495,8 @@ class ReformerLayer(nn.Module):
# every forward pass we sample a different seed # every forward pass we sample a different seed
# for dropout and save seed for forward fn in backward # for dropout and save seed for forward fn in backward
# to have correct dropout # to have correct dropout
self._init_feed_forward_seed() if self.training:
self._init_feed_forward_seed()
# Y_2 = X_2 + g(Y_1) # Y_2 = X_2 + g(Y_1)
hidden_states = hidden_states + self.feed_forward(attn_output) hidden_states = hidden_states + self.feed_forward(attn_output)
...@@ -2263,7 +2265,7 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel): ...@@ -2263,7 +2265,7 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
attentions=reformer_outputs.attentions, attentions=reformer_outputs.attentions,
) )
def prepare_inputs_for_generation(self, input_ids, past, **kwargs): def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, num_hashes=None, **kwargs):
# only last token for inputs_ids if past is defined in kwargs # only last token for inputs_ids if past is defined in kwargs
if past is not None: if past is not None:
input_ids = input_ids[:, -1:] input_ids = input_ids[:, -1:]
...@@ -2271,12 +2273,10 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel): ...@@ -2271,12 +2273,10 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
inputs_dict = { inputs_dict = {
"input_ids": input_ids, "input_ids": input_ids,
"past_buckets_states": past, "past_buckets_states": past,
"use_cache": kwargs["use_cache"], "use_cache": use_cache,
"num_hashes": num_hashes,
} }
if "num_hashes" in kwargs:
inputs_dict["num_hashes"] = kwargs["num_hashes"]
return inputs_dict return inputs_dict
def _reorder_cache(self, past, beam_idx): def _reorder_cache(self, past, beam_idx):
......
...@@ -1232,7 +1232,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel): ...@@ -1232,7 +1232,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
encoder_attentions=encoder_outputs.attentions, encoder_attentions=encoder_outputs.attentions,
) )
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs): def prepare_inputs_for_generation(
self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
):
# cut decoder_input_ids if past is used # cut decoder_input_ids if past is used
if past is not None: if past is not None:
......
...@@ -1091,7 +1091,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): ...@@ -1091,7 +1091,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
else: else:
return self.crit.out_layers[-1] return self.crit.out_layers[-1]
def prepare_inputs_for_generation(self, input_ids, past, **model_kwargs): def prepare_inputs_for_generation(self, input_ids, past=None, **model_kwargs):
inputs = {} inputs = {}
# if past is defined in model kwargs then use it for faster decoding # if past is defined in model kwargs then use it for faster decoding
......
...@@ -1300,7 +1300,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): ...@@ -1300,7 +1300,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
def get_output_embeddings(self): def get_output_embeddings(self):
return self.lm_loss return self.lm_loss
def prepare_inputs_for_generation(self, input_ids, past, **kwargs): def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, **kwargs):
# Add dummy token at the end (no attention on this one) # Add dummy token at the end (no attention on this one)
effective_batch_size = input_ids.shape[0] effective_batch_size = input_ids.shape[0]
...@@ -1333,7 +1333,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): ...@@ -1333,7 +1333,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
"input_ids": input_ids, "input_ids": input_ids,
"perm_mask": perm_mask, "perm_mask": perm_mask,
"target_mapping": target_mapping, "target_mapping": target_mapping,
"use_cache": kwargs["use_cache"], "use_cache": use_cache,
} }
# if past is defined in model kwargs then use it for faster decoding # if past is defined in model kwargs then use it for faster decoding
......
...@@ -88,8 +88,8 @@ def is_pipeline_test(test_case): ...@@ -88,8 +88,8 @@ def is_pipeline_test(test_case):
""" """
Decorator marking a test as a pipeline test. Decorator marking a test as a pipeline test.
Pipeline tests are skipped by default and we can run only them by setting RUN_PIPELINE_TEST environment variable to Pipeline tests are skipped by default and we can run only them by setting RUN_PIPELINE_TESTS environment variable
a truthy value and selecting the is_pipeline_test pytest mark. to a truthy value and selecting the is_pipeline_test pytest mark.
""" """
if not _run_pipeline_tests: if not _run_pipeline_tests:
......
...@@ -104,6 +104,66 @@ class TextDatasetForNextSentencePrediction: ...@@ -104,6 +104,66 @@ class TextDatasetForNextSentencePrediction:
requires_pytorch(self) requires_pytorch(self)
class BeamScorer:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class BeamSearchScorer:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class LogitsProcessor:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class LogitsProcessorList:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class LogitsWarper:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class MinLengthLogitsProcessor:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class NoBadWordsLogitsProcessor:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class NoRepeatNGramLogitsProcessor:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class RepetitionPenaltyLogitsProcessor:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class TemperatureLogitsWarper:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class TopKLogitsWarper:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class TopPLogitsWarper:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
def top_k_top_p_filtering(*args, **kwargs): def top_k_top_p_filtering(*args, **kwargs):
requires_pytorch(top_k_top_p_filtering) requires_pytorch(top_k_top_p_filtering)
......
# coding=utf-8
# Copyright 2020 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a clone of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, torch_device
from .test_modeling_common import floats_tensor, ids_tensor
if is_torch_available():
import torch
from transformers.generation_beam_search import BeamHypotheses, BeamSearchScorer
class BeamSearchTester:
def __init__(
self,
parent,
batch_size=3,
sequence_length=10,
vocab_size=99,
pad_token_id=0,
max_length=20,
num_beams=4,
length_penalty=2.0,
do_early_stopping=True,
num_beam_hyps_to_keep=2,
):
self.parent = parent
self.batch_size = batch_size
self.sequence_length = sequence_length
self.vocab_size = vocab_size
self.pad_token_id = pad_token_id
self.max_length = max_length
self.num_beams = num_beams
self.length_penalty = length_penalty
self.do_early_stopping = do_early_stopping
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
# cannot be randomely generated
self.eos_token_id = vocab_size + 1
def prepare_beam_scorer(self, **kwargs):
return BeamSearchScorer(
batch_size=kwargs.get("batch_size", self.batch_size),
max_length=kwargs.get("max_length", self.max_length),
num_beams=kwargs.get("num_beams", self.num_beams),
device=torch_device,
length_penalty=kwargs.get("length_penalty", self.length_penalty),
do_early_stopping=kwargs.get("do_early_stopping", self.do_early_stopping),
num_beam_hyps_to_keep=kwargs.get("num_beam_hyps_to_keep", self.num_beam_hyps_to_keep),
)
def prepare_inputs(self):
input_ids = ids_tensor((self.batch_size * self.num_beams, self.sequence_length), self.vocab_size)
next_tokens = ids_tensor((self.batch_size, 2 * self.num_beams), self.vocab_size).to(torch_device)
next_indices = ids_tensor((self.batch_size, 2 * self.num_beams), self.num_beams).to(torch_device)
next_scores, _ = (-floats_tensor((self.batch_size, 2 * self.num_beams)).to(torch_device)).sort(descending=True)
return (input_ids, next_tokens, next_indices, next_scores)
def check_beam_hypotheses(self, input_ids, *args):
# check that correct number of beam hypotheses is set in beam scorer
beam_scorer = self.prepare_beam_scorer(do_early_stopping=True)
beam_hyp = beam_scorer._beam_hyps[0]
self.parent.assertEqual(len(beam_scorer._beam_hyps), self.batch_size)
# check correct type
self.parent.assertTrue(isinstance(beam_hyp, BeamHypotheses))
# check that num_beams is correctly set
self.parent.assertEqual(beam_hyp.num_beams, self.num_beams)
# check for early stopping deactivated
for beam_idx in range(self.num_beams):
beam_hyp.add(input_ids[beam_idx], -10.0)
# if early stopping True -> score does not matter
self.parent.assertTrue(beam_hyp.is_done(-10.0, 5))
# re-init
beam_scorer = self.prepare_beam_scorer(do_early_stopping=False)
beam_hyp = beam_scorer._beam_hyps[0]
# add `num_beams + 1` beams to change `worst_score`
for beam_idx in range(self.num_beams + 1):
beam_hyp.add(input_ids[beam_idx], -10.0 + float(beam_idx))
# -10.0 is removed => -9.0 is worst score
self.parent.assertAlmostEqual(beam_hyp.worst_score, -9.0 / (self.sequence_length ** beam_hyp.length_penalty))
# -5.0 is better than worst score => should not be finished
self.parent.assertFalse(beam_hyp.is_done(-5.0, self.sequence_length))
# -20.0 is worse than worst score => should be finished
self.parent.assertTrue(beam_hyp.is_done(-20.0, self.sequence_length))
def check_beam_scorer_update(self, input_ids, next_tokens, next_indices, next_scores):
# check too many eos tokens
beam_scorer = self.prepare_beam_scorer()
tokens = next_tokens.clone()
tokens[0, :] = self.eos_token_id
with self.parent.assertRaises(ValueError):
beam_scorer.process(input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id)
# check all batches are done
beam_scorer = self.prepare_beam_scorer()
tokens = next_tokens.clone()
tokens[:, : self.num_beams] = self.eos_token_id
beam_scorer.process(input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id)
# beam scorer should be done
self.parent.assertTrue(beam_scorer.is_done)
# check
beam_scorer = self.prepare_beam_scorer()
tokens = next_tokens.clone()
tokens[:, 1] = self.eos_token_id
beam_outputs = beam_scorer.process(
input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id
)
output_scores = beam_outputs["next_beam_scores"]
output_tokens = beam_outputs["next_beam_tokens"]
output_indices = beam_outputs["next_beam_indices"]
def cut_expected_tensor(tensor):
return torch.cat([tensor[:, :1], tensor[:, 2 : self.num_beams + 1]], dim=1).flatten()
# check all outptus
# cut out id of eos token and take best `num_beams` outputs
expected_output_tokens = cut_expected_tensor(tokens)
expected_output_scores = cut_expected_tensor(next_scores)
# add num_beams * batch_idx
expected_output_indices = (
cut_expected_tensor(next_indices)
+ (torch.arange(self.num_beams * self.batch_size, device=torch_device) // self.num_beams) * self.num_beams
)
self.parent.assertListEqual(expected_output_tokens.tolist(), output_tokens.tolist())
self.parent.assertListEqual(expected_output_indices.tolist(), output_indices.tolist())
self.parent.assertTrue(torch.allclose(expected_output_scores, output_scores, atol=1e-3))
# make sure ids of eos token are correctly saved in beam_hyps of beam scorer
for batch_idx in range(self.batch_size):
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
self.parent.assertListEqual(
input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist()
)
def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
# max_length should be only one more than current input_ids to check that eos is correctly appended
max_length = self.sequence_length + 1
beam_scorer = self.prepare_beam_scorer(
num_beam_hyps_to_keep=1, max_length=max_length, length_penalty=1.0, do_early_stopping=False
)
# update beams and append to input_ids
tokens = next_tokens.clone()
# first batch, first output has to finish with eos token id since scores are correctly sorted
tokens[0, 0] = self.eos_token_id
# make sure corresponding score is as good as possible to surely be picked first
next_scores[0, 0] = 0.0
beam_outputs = beam_scorer.process(
input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id
)
output_scores = beam_outputs["next_beam_scores"]
output_tokens = beam_outputs["next_beam_tokens"]
output_indices = beam_outputs["next_beam_indices"]
input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1)
# finalize
decoded = beam_scorer.finalize(
input_ids,
output_scores,
output_tokens,
output_indices,
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id,
)
# since `num_beam_hyps_to_keep` = 1 => only return `batch_size` x `max_length`
self.parent.assertListEqual(list(decoded.shape), [self.batch_size, max_length])
# first batch has to finish with eos_token
self.parent.assertEqual(decoded[0, -1].item(), self.eos_token_id)
# other batches cannot finish with eos token
self.parent.assertNotEqual(decoded[1, -1].item(), self.eos_token_id)
self.parent.assertNotEqual(decoded[2, -1].item(), self.eos_token_id)
# now test that if `num_beam_hyps_to_keep` is 3 => all beams are returned
beam_scorer.num_beam_hyps_to_keep = self.num_beams
decoded = beam_scorer.finalize(
input_ids,
output_scores,
output_tokens,
output_indices,
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id,
)
self.parent.assertListEqual(list(decoded.shape), [self.num_beams * self.batch_size, max_length])
@require_torch
class BeamSearchTest(unittest.TestCase):
def setUp(self):
self.beam_search_tester = BeamSearchTester(self)
def test_beam_hypotheses(self):
inputs = self.beam_search_tester.prepare_inputs()
self.beam_search_tester.check_beam_hypotheses(*inputs)
def test_beam_scorer_update(self):
inputs = self.beam_search_tester.prepare_inputs()
self.beam_search_tester.check_beam_scorer_update(*inputs)
def test_beam_scorer_finalize(self):
inputs = self.beam_search_tester.prepare_inputs()
self.beam_search_tester.check_beam_scores_finalize(*inputs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment