Unverified Commit 2e12b907 authored by Patrick von Platen, committed by GitHub

TF generate refactor - Greedy Search (#15562)



* TF generate start refactor

* Add tf tests for sample generate

* re-organize

* boom boom

* Apply suggestions from code review

* re-add

* add all code

* make random greedy pass

* make encoder-decoder random work

* further improvements

* delete bogus file

* make gpt2 and t5 tests work

* finish logits tests

* correct logits processors

* correct past / encoder_outputs drama

* refactor some methods

* another fix

* refactor shape_list

* fix more shape list

* import shape_list

* finish docs

* fix imports

* make style

* correct tf utils

* Fix TFRag as well

* Apply Lysandre's and Sylvain's suggestions

* Update tests/test_generation_tf_logits_process.py
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Update src/transformers/tf_utils.py
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* remove cpu according to gante

* correct logit processor
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
parent a3dbbc34
......@@ -148,6 +148,24 @@ generation.
[[autodoc]] InfNanRemoveLogitsProcessor
- __call__
[[autodoc]] TFLogitsProcessor
- __call__
[[autodoc]] TFLogitsProcessorList
- __call__
[[autodoc]] TFMinLengthLogitsProcessor
- __call__
[[autodoc]] TFNoBadWordsLogitsProcessor
- __call__
[[autodoc]] TFNoRepeatNGramLogitsProcessor
- __call__
[[autodoc]] TFRepetitionPenaltyLogitsProcessor
- __call__
[[autodoc]] FlaxLogitsProcessor
- __call__
......
......@@ -1592,6 +1592,14 @@ if is_tf_available():
_import_structure["activations_tf"] = []
_import_structure["benchmark.benchmark_args_tf"] = ["TensorFlowBenchmarkArguments"]
_import_structure["benchmark.benchmark_tf"] = ["TensorFlowBenchmark"]
_import_structure["generation_tf_logits_process"] = [
"TFLogitsProcessor",
"TFLogitsProcessorList",
"TFMinLengthLogitsProcessor",
"TFNoBadWordsLogitsProcessor",
"TFNoRepeatNGramLogitsProcessor",
"TFRepetitionPenaltyLogitsProcessor",
]
_import_structure["generation_tf_utils"] = ["tf_top_k_top_p_filtering"]
_import_structure["keras_callbacks"] = ["KerasMetricCallback", "PushToHubCallback"]
_import_structure["modeling_tf_outputs"] = []
......@@ -2046,6 +2054,7 @@ if is_tf_available():
]
)
_import_structure["optimization_tf"] = ["AdamWeightDecay", "GradientAccumulator", "WarmUp", "create_optimizer"]
_import_structure["tf_utils"] = []
_import_structure["trainer_tf"] = ["TFTrainer"]
else:
......@@ -3572,6 +3581,14 @@ if TYPE_CHECKING:
# Benchmarks
from .benchmark.benchmark_tf import TensorFlowBenchmark
from .generation_tf_logits_process import (
TFLogitsProcessor,
TFLogitsProcessorList,
TFMinLengthLogitsProcessor,
TFNoBadWordsLogitsProcessor,
TFNoRepeatNGramLogitsProcessor,
TFRepetitionPenaltyLogitsProcessor,
)
from .generation_tf_utils import tf_top_k_top_p_filtering
from .keras_callbacks import KerasMetricCallback, PushToHubCallback
from .modeling_tf_layoutlm import (
......
......@@ -14,7 +14,6 @@
# limitations under the License.
import inspect
from abc import ABC
import jax
import jax.lax as lax
......@@ -48,7 +47,7 @@ LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
"""
class FlaxLogitsProcessor(ABC):
class FlaxLogitsProcessor:
"""Abstract base class for all logit processors that can be applied during generation."""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
......@@ -59,7 +58,7 @@ class FlaxLogitsProcessor(ABC):
)
class FlaxLogitsWarper(ABC):
class FlaxLogitsWarper:
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
......
......@@ -15,7 +15,6 @@
import inspect
import math
from abc import ABC
from typing import Callable, Iterable, List, Optional
import numpy as np
......@@ -49,7 +48,7 @@ LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
"""
class LogitsProcessor(ABC):
class LogitsProcessor:
"""Abstract base class for all logit processors that can be applied during generation."""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
......@@ -60,7 +59,7 @@ class LogitsProcessor(ABC):
)
class LogitsWarper(ABC):
class LogitsWarper:
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
......
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import List
import numpy as np
import tensorflow as tf
from .file_utils import add_start_docstrings
from .tf_utils import set_tensor_by_indices_to_value
from .utils.logging import get_logger
logger = get_logger(__name__)
TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
Args:
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
scores (`tf.Tensor` of shape `(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be logits for each vocabulary token when not using
beam search, or log softmax for each vocabulary token when using beam search.
kwargs:
Additional logits processor specific kwargs.
Return:
`tf.Tensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.
"""
class TFLogitsProcessor:
"""Abstract base class for all logit processors that can be applied during generation."""
@add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
"""TF method for processing logits."""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
class TFLogitsProcessorList(list):
"""
This class can be used to create a list of [`TFLogitsProcessor`] to subsequently process a `scores` input tensor.
This class inherits from list and adds a specific *__call__* method to apply each [`TFLogitsProcessor`] to the
inputs.
"""
@add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, **kwargs) -> tf.Tensor:
for processor in self:
function_args = inspect.signature(processor.__call__).parameters
if len(function_args) > 2:
if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
raise ValueError(
f"Make sure that all the required parameters: {list(function_args.keys())} for "
f"{processor.__class__} are passed to the logits processor."
)
scores = processor(input_ids, scores, **kwargs)
else:
scores = processor(input_ids, scores)
return scores
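As a quick illustration of the dispatch above: a processor whose `__call__` takes only `(input_ids, scores)` is invoked directly, while one with extra parameters receives them through `kwargs` after validation. A minimal sketch, assuming eager execution (`TimesTwoProcessor` is a made-up toy processor, not part of the library):

```python
import tensorflow as tf


class TimesTwoProcessor(TFLogitsProcessor):  # hypothetical toy processor for illustration
    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
        return scores * 2.0


processors = TFLogitsProcessorList([TimesTwoProcessor()])
out = processors(tf.constant([[0]]), tf.ones((1, 4)))  # every entry becomes 2.0
```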
class TFMinLengthLogitsProcessor(TFLogitsProcessor):
r"""
[`TFLogitsProcessor`] enforcing a min-length by setting EOS probability to 0.
Args:
min_length (`int`):
The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`int`):
The id of the *end-of-sequence* token.
"""
def __init__(self, min_length: int, eos_token_id: int):
if not isinstance(min_length, int) or min_length < 0:
raise ValueError(f"`min_length` has to be a non-negative integer, but is {min_length}")
if not isinstance(eos_token_id, int) or eos_token_id < 0:
raise ValueError(f"`eos_token_id` has to be a non-negative integer, but is {eos_token_id}")
self.min_length = min_length
self.eos_token_id = eos_token_id
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
# create boolean flag to decide if min length penalty should be applied
cur_len = input_ids.shape[-1]
apply_penalty = 1 - tf.clip_by_value(cur_len - self.min_length, 0, 1)
# TODO(Matt) - this if statement has to be rewritten for XLA. Leaving it now though since
# generate is not XLA-compilable anyway
if apply_penalty:
eos_token_id_mask = tf.broadcast_to(tf.range(scores.shape[-1]) == self.eos_token_id, scores.shape)
scores = set_tensor_by_indices_to_value(scores, eos_token_id_mask, float("-inf"))
return scores
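In other words, while the running length is still below `min_length`, the EOS column is forced to `-inf` so neither greedy search nor sampling can select it. A minimal sketch, assuming eager execution:

```python
import tensorflow as tf

processor = TFMinLengthLogitsProcessor(min_length=5, eos_token_id=1)
input_ids = tf.constant([[3, 4, 5]], dtype=tf.int32)  # cur_len = 3 < min_length
scores = tf.zeros((1, 4))
out = processor(input_ids, scores)  # out[0, 1] is -inf, the other entries stay 0.0
```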
class TFRepetitionPenaltyLogitsProcessor(TFLogitsProcessor):
r"""
[`TFLogitsProcessor`] enforcing an exponential penalty on repeated sequences.
Args:
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
"""
def __init__(self, penalty: float):
if not isinstance(penalty, float) or not (penalty > 0):
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
self.penalty = penalty
def _create_score_penalties(self, input_ids, logits):
# create logit penalties for already seen input_ids
token_penalties = np.ones(logits.shape)
prev_input_ids = [np.unique(input_id) for input_id in input_ids.numpy()]
for i, prev_input_id in enumerate(prev_input_ids):
logit_penalized = logits[i].numpy()[prev_input_id]
logit_penalties = np.zeros(logit_penalized.shape)
# if previous logit score is < 0 then multiply repetition penalty else divide
logit_penalties[logit_penalized < 0] = self.penalty
logit_penalties[logit_penalized > 0] = 1 / self.penalty
np.put(token_penalties[i], prev_input_id, logit_penalties)
return tf.convert_to_tensor(token_penalties, dtype=tf.float32)
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
score_penalties = self._create_score_penalties(input_ids, scores)
scores = tf.math.multiply(scores, score_penalties)
return scores
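The penalty is multiplicative: an already-seen token with a positive logit is divided by `penalty`, and one with a negative logit is multiplied by it, so repetition becomes less likely in both cases. A small worked example, assuming eager execution:

```python
import tensorflow as tf

processor = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
input_ids = tf.constant([[0, 1]], dtype=tf.int32)  # tokens 0 and 1 were already generated
scores = tf.constant([[2.0, -2.0, 2.0]], dtype=tf.float32)
out = processor(input_ids, scores)
# token 0:  2.0 / 2.0 =  1.0  (positive logit is divided)
# token 1: -2.0 * 2.0 = -4.0  (negative logit is multiplied)
# token 2:  2.0 unchanged     (never generated)
```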
class TFNoBadWordsLogitsProcessor(TFLogitsProcessor):
"""
[`TFLogitsProcessor`] that enforces that specified sequences will never be sampled.
Args:
bad_words_ids (`List[List[int]]`):
List of lists of token ids that are not allowed to be generated. In order to get the tokens of the words
that should not appear in the generated text, use `tokenizer(bad_word, add_prefix_space=True).input_ids`.
eos_token_id (`int`):
The id of the *end-of-sequence* token.
"""
def __init__(self, bad_words_ids: List[List[int]], eos_token_id: int):
if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0:
raise ValueError(f"`bad_words_ids` has to be a non-emtpy list, but is {bad_words_ids}.")
if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.")
if any(
any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids)
for bad_word_ids in bad_words_ids
):
raise ValueError(
f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
)
self.bad_words_ids = bad_words_ids
def calc_banned_bad_words_ids(self, prev_input_ids):
banned_tokens = []
def _tokens_match(prev_tokens, tokens):
if len(tokens) == 0:
# a single-token bad word has an empty prefix, so it always matches and is always banned
return True
if len(tokens) > len(prev_tokens):
# if bad word tokens are longer than prev tokens they can't be equal
return False
if prev_tokens[-len(tokens) :] == tokens:
# if tokens match
return True
else:
return False
for prev_input_ids_slice in prev_input_ids:
banned_tokens_slice = []
for banned_token_seq in self.bad_words_ids:
assert (
len(banned_token_seq) > 0
), f"Banned words token sequences {self.bad_words_ids} cannot have an empty list"
if _tokens_match(prev_input_ids_slice.numpy().tolist(), banned_token_seq[:-1]) is False:
# if tokens do not match continue
continue
banned_tokens_slice.append(banned_token_seq[-1])
banned_tokens.append(banned_tokens_slice)
return banned_tokens
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
vocab_size = scores.shape[-1]
# calculate a list of banned tokens according to bad words
banned_tokens = self.calc_banned_bad_words_ids(input_ids)
banned_tokens_indices_mask = []
for banned_tokens_slice in banned_tokens:
banned_tokens_indices_mask.append([token in banned_tokens_slice for token in range(vocab_size)])
scores = set_tensor_by_indices_to_value(
scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
)
return scores
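Concretely, a one-token bad word is always banned, while a multi-token bad word only has its last token banned when its preceding tokens match the end of the generated sequence. A minimal sketch, assuming eager execution:

```python
import tensorflow as tf

processor = TFNoBadWordsLogitsProcessor(bad_words_ids=[[3], [4, 5]], eos_token_id=0)
input_ids = tf.constant([[1, 4]], dtype=tf.int32)  # the sequence currently ends in 4
scores = tf.zeros((1, 6))
out = processor(input_ids, scores)
# out[0, 3] is -inf (single-token bad word) and out[0, 5] is -inf (prefix [4] matched)
```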
class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor):
r"""
[`TFLogitsProcessor`] that enforces no repetition of n-grams. See
[Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).
Args:
ngram_size (`int`):
All ngrams of size `ngram_size` can only occur once.
"""
def __init__(self, ngram_size: int):
if not isinstance(ngram_size, int) or ngram_size <= 0:
raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
self.ngram_size = ngram_size
def calc_banned_ngram_tokens(self, prev_input_ids, num_hypos, cur_len):
# Copied from fairseq for no_repeat_ngram in beam_search
if cur_len + 1 < self.ngram_size:
# return no banned tokens if we haven't generated ngram_size tokens yet
return [[] for _ in range(num_hypos)]
generated_ngrams = [{} for _ in range(num_hypos)]
for idx in range(num_hypos):
gen_tokens = prev_input_ids[idx].numpy().tolist()
generated_ngram = generated_ngrams[idx]
for ngram in zip(*[gen_tokens[i:] for i in range(self.ngram_size)]):
prev_ngram_tuple = tuple(ngram[:-1])
generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
def _get_generated_ngrams(hypo_idx):
# Before decoding the next token, prevent decoding of ngrams that have already appeared
start_idx = cur_len + 1 - self.ngram_size
ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].numpy().tolist())
return generated_ngrams[hypo_idx].get(ngram_idx, [])
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
return banned_tokens
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
batch_size, vocab_size = scores.shape
cur_len = input_ids.shape[-1]
banned_tokens = self.calc_banned_ngram_tokens(input_ids, batch_size, cur_len)
# create banned_tokens boolean mask
banned_tokens_indices_mask = []
for banned_tokens_slice in banned_tokens:
banned_tokens_indices_mask.append([token in banned_tokens_slice for token in range(vocab_size)])
scores = set_tensor_by_indices_to_value(
scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
)
return scores
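Put differently, the processor collects every n-gram generated so far and bans any token that would complete a repeat of one of them. A minimal sketch, assuming eager execution:

```python
import tensorflow as tf

processor = TFNoRepeatNGramLogitsProcessor(ngram_size=2)
# the bigram (2, 3) already occurred and the sequence currently ends in 2,
# so generating 3 next would repeat it
input_ids = tf.constant([[1, 2, 3, 2]], dtype=tf.int32)
scores = tf.zeros((1, 5))
out = processor(input_ids, scores)  # out[0, 3] is -inf, everything else stays 0.0
```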
......@@ -16,12 +16,20 @@
import inspect
from dataclasses import dataclass
from typing import Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from .file_utils import ModelOutput
from .generation_tf_logits_process import (
TFLogitsProcessorList,
TFMinLengthLogitsProcessor,
TFNoBadWordsLogitsProcessor,
TFNoRepeatNGramLogitsProcessor,
TFRepetitionPenaltyLogitsProcessor,
)
from .tf_utils import set_tensor_by_indices_to_value, shape_list
from .utils import logging
......@@ -476,18 +484,18 @@ class TFGenerationMixin:
If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
[`~file_utils.ModelOutput`] types are:
- [`~generation_utils.TFGreedySearchDecoderOnlyOutput`],
- [`~generation_utils.TFSampleDecoderOnlyOutput`],
- [`~generation_utils.TFBeamSearchDecoderOnlyOutput`],
- [`~generation_utils.TFBeamSampleDecoderOnlyOutput`]
- [`~generation_tf_utils.TFGreedySearchDecoderOnlyOutput`],
- [`~generation_tf_utils.TFSampleDecoderOnlyOutput`],
- [`~generation_tf_utils.TFBeamSearchDecoderOnlyOutput`],
- [`~generation_tf_utils.TFBeamSampleDecoderOnlyOutput`]
If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
[`~file_utils.ModelOutput`] types are:
- [`~generation_utils.TFGreedySearchEncoderDecoderOutput`],
- [`~generation_utils.TFSampleEncoderDecoderOutput`],
- [`~generation_utils.TFBeamSearchEncoderDecoderOutput`],
- [`~generation_utils.TFBeamSampleEncoderDecoderOutput`]
- [`~generation_tf_utils.TFGreedySearchEncoderDecoderOutput`],
- [`~generation_tf_utils.TFSampleEncoderDecoderOutput`],
- [`~generation_tf_utils.TFBeamSearchEncoderDecoderOutput`],
- [`~generation_tf_utils.TFBeamSampleEncoderDecoderOutput`]
Examples:
......@@ -547,6 +555,38 @@ class TFGenerationMixin:
input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids
) # generate sequences without allowing bad_words to be generated
```"""
num_beams = num_beams if num_beams is not None else self.config.num_beams
do_sample = do_sample if do_sample is not None else self.config.do_sample
is_greedy_gen_mode = num_beams == 1 and do_sample is False
if is_greedy_gen_mode:
return self._generate(
input_ids=input_ids,
max_length=max_length,
min_length=min_length,
do_sample=do_sample,
early_stopping=early_stopping,
num_beams=num_beams,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
bad_words_ids=bad_words_ids,
bos_token_id=bos_token_id,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
length_penalty=length_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
num_return_sequences=num_return_sequences,
attention_mask=attention_mask,
decoder_start_token_id=decoder_start_token_id,
use_cache=use_cache,
output_scores=output_scores,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
)
# We cannot generate if the model does not have a LM head
if self.get_output_embeddings() is None:
......@@ -557,12 +597,11 @@ class TFGenerationMixin:
max_length = max_length if max_length is not None else self.config.max_length
min_length = min_length if min_length is not None else self.config.min_length
do_sample = do_sample if do_sample is not None else self.config.do_sample
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
num_beams = num_beams if num_beams is not None else self.config.num_beams
temperature = temperature if temperature is not None else self.config.temperature
top_k = top_k if top_k is not None else self.config.top_k
top_p = top_p if top_p is not None else self.config.top_p
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
......@@ -632,7 +671,7 @@ class TFGenerationMixin:
bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"
# This block corresponds to the following line in `generation_utils`:
# This block corresponds to the following line in `generation_tf_utils`:
# "input_ids = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))"
# with the following differences:
# 1. In PT, `generate()`'s `model_kwargs` can accept `encoder_outputs`, but not the case in TF.
......@@ -751,14 +790,13 @@ class TFGenerationMixin:
cur_len < max_length
), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`"
if num_beams > 1:
output = self._generate_beam_search(
if num_beams == 1:
return self._generate_no_beam_search(
input_ids,
cur_len=cur_len,
max_length=max_length,
min_length=min_length,
do_sample=do_sample,
early_stopping=early_stopping,
temperature=temperature,
top_k=top_k,
top_p=top_p,
......@@ -768,25 +806,21 @@ class TFGenerationMixin:
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
batch_size=effective_batch_size,
num_return_sequences=num_return_sequences,
length_penalty=length_penalty,
num_beams=num_beams,
vocab_size=vocab_size,
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
use_cache=use_cache,
forced_bos_token_id=forced_bos_token_id,
forced_eos_token_id=forced_eos_token_id,
return_dict_in_generate=return_dict_in_generate,
**model_kwargs,
)
else:
output = self._generate_no_beam_search(
return self._generate_beam_search(
input_ids,
cur_len=cur_len,
max_length=max_length,
min_length=min_length,
do_sample=do_sample,
early_stopping=early_stopping,
temperature=temperature,
top_k=top_k,
top_p=top_p,
......@@ -796,16 +830,19 @@ class TFGenerationMixin:
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
batch_size=effective_batch_size,
num_return_sequences=num_return_sequences,
length_penalty=length_penalty,
num_beams=num_beams,
vocab_size=vocab_size,
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
use_cache=use_cache,
forced_bos_token_id=forced_bos_token_id,
forced_eos_token_id=forced_eos_token_id,
return_dict_in_generate=return_dict_in_generate,
**model_kwargs,
)
return output
def _generate_no_beam_search(
self,
input_ids,
......@@ -1488,6 +1525,676 @@ class TFGenerationMixin:
else:
return logits
def _generate(
self,
input_ids=None,
max_length=None,
min_length=None,
do_sample=None,
early_stopping=None,
num_beams=None,
temperature=None,
top_k=None,
top_p=None,
repetition_penalty=None,
bad_words_ids=None,
bos_token_id=None,
pad_token_id=None,
eos_token_id=None,
length_penalty=None,
no_repeat_ngram_size=None,
num_return_sequences=None,
attention_mask=None,
decoder_start_token_id=None,
use_cache=None,
output_scores=None,
output_attentions=None,
output_hidden_states=None,
return_dict_in_generate=None,
forced_bos_token_id=None,
forced_eos_token_id=None,
**model_kwargs,
) -> Union[TFGreedySearchOutput, TFSampleOutput, TFBeamSearchOutput, TFBeamSampleOutput, tf.Tensor]:
r"""
Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
beam-search decoding, sampling with temperature, and sampling with top-k or nucleus sampling.
Adapted in part from [Facebook's XLM beam search
code](https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529).
Apart from `input_ids` and `attention_mask`, all the arguments below will default to the value of the attribute
of the same name inside the [`PretrainedConfig`] of the model. The default values indicated are the default
values of those config attributes.
Most of these parameters are explained in more detail in [this blog
post](https://huggingface.co/blog/how-to-generate).
Parameters:
input_ids (`tf.Tensor` of `dtype=tf.int32` and shape `(batch_size, sequence_length)`, *optional*):
The sequence used as a prompt for the generation. If `None` the method initializes it with
`bos_token_id` and a batch size of 1.
max_length (`int`, *optional*, defaults to 20):
The maximum length of the sequence to be generated.
min_length (`int`, *optional*, defaults to 10):
The minimum length of the sequence to be generated.
do_sample (`bool`, *optional*, defaults to `False`):
Whether or not to use sampling; use greedy decoding otherwise.
early_stopping (`bool`, *optional*, defaults to `False`):
Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
num_beams (`int`, *optional*, defaults to 1):
Number of beams for beam search. 1 means no beam search.
temperature (`float`, *optional*, defaults to 1.0):
The value used to modulate the next token probabilities.
top_k (`int`, *optional*, defaults to 50):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p (`float`, *optional*, defaults to 1.0):
If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher
are kept for generation.
repetition_penalty (`float`, *optional*, defaults to 1.0):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
bos_token_id (`int`, *optional*):
The id of the *beginning-of-sequence* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
length_penalty (`float`, *optional*, defaults to 1.0):
Exponential penalty to the length. 1.0 means no penalty.
Set to values < 1.0 in order to encourage the model to generate shorter sequences, or to a value > 1.0 in
order to encourage the model to produce longer sequences.
no_repeat_ngram_size (`int`, *optional*, defaults to 0):
If set to int > 0, all ngrams of that size can only occur once.
bad_words_ids (`List[List[int]]`, *optional*):
List of token ids that are not allowed to be generated. In order to get the tokens of the words that
should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.
num_return_sequences (`int`, *optional*, defaults to 1):
The number of independently computed returned sequences for each element in the batch.
attention_mask (`tf.Tensor` of `dtype=tf.int32` and shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values are in `[0, 1]`, 1 for tokens
that are not masked, and 0 for masked tokens.
If not provided, will default to a tensor the same shape as `input_ids` that masks the pad token.
[What are attention masks?](../glossary#attention-mask)
decoder_start_token_id (`int`, *optional*):
If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
speed up decoding.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
output_hidden_states (`bool`, *optional*, defaults to `False`):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more details.
output_scores (`bool`, *optional*, defaults to `False`):
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
forced_bos_token_id (`int`, *optional*):
The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful
for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be
the target language token.
forced_eos_token_id (`int`, *optional*):
The id of the token to force as the last generated token when `max_length` is reached.
model_kwargs:
Additional model-specific kwargs will be forwarded to the `call` function of the model.
Return:
[`~file_utils.ModelOutput`] or `tf.Tensor`: A [`~file_utils.ModelOutput`] (if
`return_dict_in_generate=True` or when `config.return_dict_in_generate=True`) or a `tf.Tensor`.
If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
[`~file_utils.ModelOutput`] types are:
- [`~generation_tf_utils.TFGreedySearchDecoderOnlyOutput`],
- [`~generation_tf_utils.TFSampleDecoderOnlyOutput`],
- [`~generation_tf_utils.TFBeamSearchDecoderOnlyOutput`],
- [`~generation_tf_utils.TFBeamSampleDecoderOnlyOutput`]
If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
[`~file_utils.ModelOutput`] types are:
- [`~generation_tf_utils.TFGreedySearchEncoderDecoderOutput`],
- [`~generation_tf_utils.TFSampleEncoderDecoderOutput`],
- [`~generation_tf_utils.TFBeamSearchEncoderDecoderOutput`],
- [`~generation_tf_utils.TFBeamSampleEncoderDecoderOutput`]
Examples:
```python
tokenizer = AutoTokenizer.from_pretrained("distilgpt2") # Initialize tokenizer
model = TFAutoModelWithLMHead.from_pretrained("distilgpt2")
# Greedy decoding
outputs = model.generate(max_length=40)
print(f"Generated: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
tokenizer = AutoTokenizer.from_pretrained("openai-gpt")
model = TFAutoModelWithLMHead.from_pretrained("openai-gpt")
input_context = "The dog"
input_ids = tokenizer.encode(input_context, return_tensors="tf") # encode input context
# Generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5)
# 3 output sequences were generated
for i in range(3):
print(f"Generated {i}: {tokenizer.decode(outputs[i], skip_special_tokens=True)}")
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = TFAutoModelWithLMHead.from_pretrained("distilgpt2")
input_context = "The dog"
input_ids = tokenizer.encode(input_context, return_tensors="tf")
# Generate 3 candidates using sampling
outputs = model.generate(
input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True
)
# 3 output sequences were generated
for i in range(3):
print(f"Generated {i}: {tokenizer.decode(outputs[i], skip_special_tokens=True)}")
tokenizer = AutoTokenizer.from_pretrained("ctrl")
model = TFAutoModelWithLMHead.from_pretrained("ctrl")
# "Legal" is one of the control codes for ctrl
input_context = "Legal My neighbor is"
input_ids = tokenizer.encode(input_context, return_tensors="tf")
outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2)
print(f"Generated: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = TFAutoModelWithLMHead.from_pretrained("gpt2")
input_context = "My cute dog"
bad_words_ids = [
tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ["idiot", "stupid", "shut up"]
]
input_ids = tokenizer.encode(input_context, return_tensors="tf")
# generate sequences without allowing bad_words to be generated
outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)
```"""
# 1. Set generation parameters if not already defined
max_length = max_length if max_length is not None else self.config.max_length
min_length = min_length if min_length is not None else self.config.min_length
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
)
num_beams = num_beams if num_beams is not None else self.config.num_beams
do_sample = do_sample if do_sample is not None else self.config.do_sample
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
)
if pad_token_id is None and eos_token_id is not None:
logger.warning(f"Setting `pad_token_id` to {eos_token_id} (first `eos_token_id`) to generate sequence")
pad_token_id = eos_token_id
# 2. Define model inputs
input_ids = self._prepare_model_inputs(input_ids, bos_token_id)
# input_ids now has to be defined and cannot be None anymore
batch_size = input_ids.shape[0]
# 3. Prepare other model kwargs
model_kwargs["output_attentions"] = output_attentions
model_kwargs["output_hidden_states"] = output_hidden_states
model_kwargs["use_cache"] = use_cache
requires_attention_mask = "encoder_outputs" not in model_kwargs
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(input_ids, pad_token_id)
if self.config.is_encoder_decoder:
# if model is encoder decoder model, we create encoder_outputs and add to `model_kwargs`
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
input_ids, return_dict_in_generate, model_kwargs
)
# TODO(Patrick) - ugly `past`/`encoder_output` hack here which requires a bigger
# refactor of all generation models in TF. `past` should be
# optional everywhere and not be set equal to encoder_outputs
model_kwargs["past"] = model_kwargs.get("encoder_outputs")[:1] if self.config.is_encoder_decoder else None
# 4. Prepare `input_ids` which will be used for auto-regressive generation
if self.config.is_encoder_decoder:
# if encoder-decoder then `input_ids` come from `decoder_start_token_id`
input_ids = self._prepare_decoder_input_ids_for_generation(
batch_size,
decoder_start_token_id=decoder_start_token_id,
bos_token_id=bos_token_id,
model_kwargs=model_kwargs,
)
if input_ids.shape[-1] >= max_length:
raise ValueError(
f"The context has {input_ids.shape[-1]} number of tokens, "
f"but `max_length` is only {max_length}. "
"Please make sure that `max_length` is bigger than the number of tokens, "
"by setting either `generate(max_length=...,...)` or `config.max_length = ...`"
)
# 5. determine generation mode
# TODO(Matt, Joao, Patrick) - add more use cases here
is_greedy_gen_mode = (num_beams == 1) and do_sample is False
# 6. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
min_length=min_length,
eos_token_id=eos_token_id,
)
# 7. go into different generation modes
if is_greedy_gen_mode:
if num_return_sequences > 1:
raise ValueError(
f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
)
# 8. run greedy search
return self.greedy_search(
input_ids,
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
logits_processor=logits_processor,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
**model_kwargs,
)
# TODO(Matt, Joao, Patrick) - add more sub-generation methods here
def _prepare_attention_mask_for_generation(
self,
input_ids: tf.Tensor,
pad_token_id: int,
) -> tf.Tensor:
# prepare `attention_mask` if not passed
if (pad_token_id is not None) and (pad_token_id in input_ids.numpy()):
return tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=tf.int32)
else:
return tf.ones(input_ids.shape[:2], dtype=tf.int32)
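The effect of this helper: if the prompt contains the pad token, the mask marks those positions; otherwise an all-ones mask is returned. A minimal sketch, assuming eager execution and that `model` is any model carrying `TFGenerationMixin`:

```python
import tensorflow as tf

input_ids = tf.constant([[5, 6, 0], [5, 0, 0]], dtype=tf.int32)  # 0 is the pad token here
mask = model._prepare_attention_mask_for_generation(input_ids, pad_token_id=0)
# -> [[1, 1, 0], [1, 0, 0]]; if no pad token appeared, tf.ones would be returned instead
```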
def _prepare_encoder_decoder_kwargs_for_generation(
self, input_ids: tf.Tensor, return_dict_in_generate, model_kwargs
) -> Dict[str, Any]:
# TODO(Patrick) - remove `return_dict_in_generate` flag input once `past`/`encoder_outputs`
# is cleaned
# get encoder and store encoder outputs
encoder = self.get_encoder()
# prepare encoder args and encoder kwargs from model kwargs
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
encoder_kwargs = {
argument: value
for argument, value in model_kwargs.items()
if not any(argument.startswith(p) for p in irrelevant_prefix)
}
# vision models don't use `attention_mask`.
signature = dict(inspect.signature(encoder.call).parameters)
if "attention_mask" not in signature:
encoder_kwargs.pop("attention_mask")
encoder_outputs = encoder(input_ids, **encoder_kwargs)
model_kwargs["encoder_outputs"] = encoder_outputs
# TODO(Patrick): `encoder_outputs`, `past` hack. Currently, `encoder_attentions` and
# `encoder_hidden_states` have to be separated from encoder_outputs and passed
# under other names because of the `encoder_outputs`/`past` hack. The
# `prepare_inputs_for_generation` methods of all encoder-decoder models need a clean-up to fix this.
if return_dict_in_generate:
model_kwargs["encoder_attentions"] = encoder_outputs.get("attentions", None)
model_kwargs["encoder_hidden_states"] = encoder_outputs.get("hidden_states", None)
return model_kwargs
def _prepare_decoder_input_ids_for_generation(
self,
batch_size: int,
decoder_start_token_id: int = None,
bos_token_id: int = None,
model_kwargs: Optional[Dict[str, tf.Tensor]] = None,
) -> tf.Tensor:
# prepare `input_ids` for decoder if model is encoder-decoder
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
return model_kwargs.pop("decoder_input_ids")
else:
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
return tf.ones((batch_size, 1), dtype=tf.int32) * decoder_start_token_id
def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
# retrieve decoder_start_token_id for encoder-decoder models
# fall back to bos_token_id if necessary
decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
)
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
if decoder_start_token_id is not None:
return decoder_start_token_id
elif (
hasattr(self.config, "decoder")
and hasattr(self.config.decoder, "decoder_start_token_id")
and self.config.decoder.decoder_start_token_id is not None
):
return self.config.decoder.decoder_start_token_id
elif bos_token_id is not None:
return bos_token_id
elif (
hasattr(self.config, "decoder")
and hasattr(self.config.decoder, "bos_token_id")
and self.config.decoder.bos_token_id is not None
):
return self.config.decoder.bos_token_id
raise ValueError(
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
)
def _prepare_model_inputs(self, inputs: Optional[tf.Tensor] = None, bos_token_id: Optional[int] = None):
# TODO(Patrick) - adapt this function when making `generate` more flexible
# for all kinds of input types
if inputs is None:
# if no `inputs` are passed create prompt of size (1,1) filled with BOS token
if not isinstance(bos_token_id, int) or bos_token_id < 0:
raise ValueError(
"you should either supply a context to complete as `input_ids` input "
"or a `bos_token_id` (integer >= 0) as a first token to start the generation."
)
return tf.cast(tf.fill((1, 1), bos_token_id), dtype=tf.int32)
return inputs
def _update_model_kwargs_for_generation(
self, outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
) -> Dict[str, Any]:
# update past
if self._use_cache(outputs, model_kwargs["use_cache"]):
# TODO(Patrick): `past`/`encoder_outputs` hack. This should be
# removed when cleaning up the encoder-decoder models
# if model has past, then set the past variable to speed up decoding
# make this method static then as well
model_kwargs["past"] = outputs[1]
elif "past_key_values" in outputs:
model_kwargs["past"] = outputs.past_key_values
elif "mems" in outputs:
model_kwargs["past"] = outputs.mems
elif "past_buckets_states" in outputs:
model_kwargs["past"] = outputs.past_buckets_states
elif "past" in model_kwargs:
# TODO(Patrick): `past`/`encoder_outputs` hack, to be
# removed when cleaning up the encoder-decoder models.
# This branch should then no longer be necessary.
pass
else:
model_kwargs["past"] = None
# update attention mask
if not is_encoder_decoder:
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = tf.concat(
[attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1
)
return model_kwargs
def _get_logits_processor(
self,
repetition_penalty: float,
no_repeat_ngram_size: int,
bad_words_ids: List[List[int]],
min_length: int,
eos_token_id: int,
) -> TFLogitsProcessorList:
"""
This method returns a [`TFLogitsProcessorList`] object that contains all relevant [`TFLogitsProcessor`]
instances used to modify the scores of the language model head.
"""
processors = TFLogitsProcessorList()
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
no_repeat_ngram_size = (
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
)
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
# instantiate processors list
if repetition_penalty is not None and repetition_penalty != 1.0:
processors.append(TFRepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
processors.append(TFNoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
if bad_words_ids is not None:
processors.append(TFNoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
if min_length is not None and eos_token_id is not None and min_length > -1:
processors.append(TFMinLengthLogitsProcessor(min_length, eos_token_id))
return processors
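A processor is only appended when its controlling argument is set to a non-default value, so the returned list is often short. A hypothetical call, assuming `model` carries `TFGenerationMixin`, its config defines an `eos_token_id`, and `config.bad_words_ids` is unset:

```python
processors = model._get_logits_processor(
    repetition_penalty=1.2,
    no_repeat_ngram_size=3,
    bad_words_ids=None,
    min_length=10,
    eos_token_id=model.config.eos_token_id,
)
# -> [TFRepetitionPenaltyLogitsProcessor, TFNoRepeatNGramLogitsProcessor, TFMinLengthLogitsProcessor]
```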
def greedy_search(
self,
input_ids: tf.Tensor,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
logits_processor: Optional[TFLogitsProcessorList] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
**model_kwargs,
) -> Union[TFGreedySearchOutput, tf.Tensor]:
r"""
Generates sequences for models with a language modeling head using greedy decoding.
Parameters:
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
logits_processor (`TFLogitsProcessorList`, *optional*):
An instance of [`TFLogitsProcessorList`]. List of instances of class derived from [`TFLogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
max_length (`int`, *optional*, defaults to 20):
The maximum length of the sequence to be generated.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
output_hidden_states (`bool`, *optional*, defaults to `False`):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more details.
output_scores (`bool`, *optional*, defaults to `False`):
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
model_kwargs:
Additional model specific keyword arguments will be forwarded to the `call` function of the model. If the
model is an encoder-decoder model, the kwargs should include `encoder_outputs`.
Return:
[`~generation_tf_utils.TFGreedySearchDecoderOnlyOutput`],
[`~generation_tf_utils.TFGreedySearchEncoderDecoderOutput`] or `tf.Tensor`: A `tf.Tensor` containing the
generated tokens (default behaviour) or a [`~generation_tf_utils.TFGreedySearchDecoderOnlyOutput`] if
`model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a
[`~generation_tf_utils.TFGreedySearchEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`.
Examples:
```python
>>> from transformers import (
... AutoTokenizer,
... TFAutoModelForCausalLM,
... TFLogitsProcessorList,
... TFMinLengthLogitsProcessor,
... )
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> model = TFAutoModelForCausalLM.from_pretrained("gpt2")
>>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
>>> model.config.pad_token_id = model.config.eos_token_id
>>> input_prompt = "Today is a beautiful day, and"
>>> input_ids = tokenizer(input_prompt, return_tensors="tf").input_ids
>>> # instantiate logits processors
>>> logits_processor = TFLogitsProcessorList(
... [
... TFMinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id),
... ]
... )
>>> outputs = model.greedy_search(input_ids, logits_processor=logits_processor)
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
```"""
# init values
logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
)
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# TODO(Patrick): `encoder_outputs`, `past` hack. Currently T5, Bart expect `encoder_outputs`
# to be wrapped into the `past` variable. This is a bad design and needs
# to be updated.
# Remove the following lines when updating all encoder-decoder models
encoder_outputs = model_kwargs.pop("encoder_outputs", None)
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = encoder_outputs.get("attentions") if output_attentions else None
encoder_hidden_states = encoder_outputs.get("hidden_states") if output_hidden_states else None
# keep track of which sequences are already finished
unfinished_sequences = tf.ones_like(input_ids[:, 0])
cur_len = input_ids.shape[-1]
while cur_len < max_length:
# TODO(Patrick): remove following line by cleaning up `prepare_inputs_for_generation`
# in all models
model_kwargs["use_cache"] = None if "use_cache" not in model_kwargs else model_kwargs["use_cache"]
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# forward pass to get next token
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
next_token_logits = outputs.logits[:, -1, :]
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
scores += (next_token_logits,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states:
decoder_hidden_states += (
(outputs.decoder_hidden_states,)
if self.config.is_encoder_decoder
else (outputs.hidden_states,)
)
# pre-process distribution
next_tokens_scores = logits_processor(input_ids, next_token_logits)
# argmax
next_tokens = tf.cast(tf.argmax(next_tokens_scores, axis=-1), tf.int32)
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
if pad_token_id is None:
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# update generated ids, model inputs, and length for next step
input_ids = tf.concat([input_ids, next_tokens[:, None]], axis=-1)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
cur_len = cur_len + 1
# if eos_token was found in one sentence, set sentence to finished
if eos_token_id is not None:
eos_in_sents = next_tokens == eos_token_id
# if sentence is unfinished and the token to add is eos
is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
unfinished_sequences, tf.cast(eos_in_sents, tf.int32)
)
# unfinished_sequences is set to zero if eos in sentence
unfinished_sequences -= is_sents_unfinished_and_token_to_add_is_eos
# stop when each sentence is finished, or if we exceed the maximum length
if tf.math.reduce_max(unfinished_sequences) == 0:
break
if return_dict_in_generate:
if self.config.is_encoder_decoder:
return TFGreedySearchEncoderDecoderOutput(
sequences=input_ids,
scores=scores,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
)
else:
return TFGreedySearchDecoderOnlyOutput(
sequences=input_ids,
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
)
else:
return input_ids
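Stripped of output bookkeeping and the `past`/`encoder_outputs` hack, the loop above reduces to: take the argmax of the processed logits, pad out finished rows, append, and stop once every row has emitted EOS. A condensed, hypothetical sketch assuming eager execution (`step_fn` stands in for the model forward pass):

```python
import tensorflow as tf


def toy_greedy_decode(step_fn, input_ids, max_length, eos_token_id, pad_token_id):
    # step_fn maps (batch, seq_len) token ids to next-token logits of shape (batch, vocab)
    unfinished = tf.ones_like(input_ids[:, 0])
    while input_ids.shape[-1] < max_length:
        next_tokens = tf.cast(tf.argmax(step_fn(input_ids), axis=-1), tf.int32)
        # finished rows keep emitting the padding token
        next_tokens = next_tokens * unfinished + pad_token_id * (1 - unfinished)
        input_ids = tf.concat([input_ids, next_tokens[:, None]], axis=-1)
        unfinished = unfinished * tf.cast(next_tokens != eos_token_id, tf.int32)
        if tf.math.reduce_max(unfinished) == 0:
            break
    return input_ids
```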
def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
# create logit penalties for already seen input_ids
......@@ -1628,12 +2335,6 @@ def scatter_values_on_batch_indices(values, batch_indices):
return tf.scatter_nd(pair_indices, tf.reshape(values, [-1]), shape)
def set_tensor_by_indices_to_value(tensor, indices, value):
# create value_tensor since tensor value assignment is not possible in TF
value_tensor = tf.zeros_like(tensor) + value
return tf.where(indices, value_tensor, tensor)
def sample_without_replacement(logits, num_samples):
"""
Categorical sampling without replacement is currently not implemented; the Gumbel-max trick will do for now. See
......@@ -1644,13 +2345,6 @@ def sample_without_replacement(logits, num_samples):
return indices
def shape_list(x):
"""Deal with dynamic shape in tensorflow cleanly."""
static = x.shape.as_list()
dynamic = tf.shape(x)
return [dynamic[i] if s is None else s for i, s in enumerate(static)]
class BeamHypotheses(object):
def __init__(self, num_beams, max_length, length_penalty, early_stopping):
"""
......
......@@ -54,6 +54,7 @@ from .file_utils import (
)
from .generation_tf_utils import TFGenerationMixin
from .modeling_tf_outputs import TFSeq2SeqLMOutput
from .tf_utils import shape_list
from .tokenization_utils_base import BatchEncoding
from .utils import logging
......@@ -2041,29 +2042,6 @@ class TFSequenceSummary(tf.keras.layers.Layer):
cls._auto_class = auto_class
def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> List[int]:
"""
Deal with dynamic shape in tensorflow cleanly.
Args:
tensor (`tf.Tensor` or `np.ndarray`): The tensor we want the shape of.
Returns:
`List[int]`: The shape of the tensor as a list.
"""
if isinstance(tensor, np.ndarray):
return list(tensor.shape)
dynamic = tf.shape(tensor)
if tensor.shape == tf.TensorShape(None):
return dynamic
static = tensor.shape.as_list()
return [dynamic[i] if s is None else s for i, s in enumerate(static)]
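The payoff shows up inside `tf.function`, where batch dimensions are often unknown at trace time: `shape_list` returns static Python ints where it can and scalar tensors where it must, so the result can be used directly in further shape arithmetic. A minimal sketch:

```python
import tensorflow as tf


@tf.function(input_signature=[tf.TensorSpec([None, 8], tf.float32)])
def flatten(x):
    batch, dim = shape_list(x)  # batch is a dynamic scalar tensor, dim is the static int 8
    return tf.reshape(x, [batch * dim])
```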
def get_initializer(initializer_range: float = 0.02) -> tf.initializers.TruncatedNormal:
"""
Creates a `tf.initializers.TruncatedNormal` with the given range.
......
......@@ -51,8 +51,8 @@ from ...modeling_tf_utils import (
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...tf_utils import shape_list
from ...utils import logging
from .configuration_albert import AlbertConfig
......
......@@ -44,8 +44,8 @@ from ...modeling_tf_utils import (
TFWrappedEmbeddings,
input_processing,
keras_serializable,
shape_list,
)
from ...tf_utils import shape_list
from ...utils import logging
from .configuration_bart import BartConfig
......
......@@ -57,8 +57,8 @@ from ...modeling_tf_utils import (
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...tf_utils import shape_list
from ...utils import logging
from .configuration_bert import BertConfig
......
......@@ -46,8 +46,8 @@ from ...modeling_tf_utils import (
TFWrappedEmbeddings,
input_processing,
keras_serializable,
shape_list,
)
from ...tf_utils import shape_list
from ...utils import logging
from .configuration_blenderbot import BlenderbotConfig
......
......@@ -44,8 +44,8 @@ from ...modeling_tf_utils import (
TFWrappedEmbeddings,
input_processing,
keras_serializable,
shape_list,
)
from ...tf_utils import shape_list
from ...utils import logging
from .configuration_blenderbot_small import BlenderbotSmallConfig
......
......@@ -39,8 +39,8 @@ from ...modeling_tf_utils import (
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...tf_utils import shape_list
from ...utils import logging
from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
......
......@@ -43,8 +43,8 @@ from ...modeling_tf_utils import (
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...tf_utils import shape_list
from ...utils import logging
from .configuration_convbert import ConvBertConfig
......
......@@ -30,8 +30,8 @@ from ...modeling_tf_utils import (
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...tf_utils import shape_list
from ...utils import logging
from .configuration_ctrl import CTRLConfig
......
......@@ -39,8 +39,8 @@ from ...modeling_tf_utils import (
TFTokenClassificationLoss,
get_initializer,
input_processing,
shape_list,
)
from ...tf_utils import shape_list
from ...utils import logging
from .configuration_deberta import DebertaConfig
......
......@@ -38,8 +38,8 @@ from ...modeling_tf_utils import (
TFTokenClassificationLoss,
get_initializer,
input_processing,
shape_list,
)
from ...tf_utils import shape_list
from ...utils import logging
from .configuration_deberta_v2 import DebertaV2Config
......
......@@ -45,8 +45,8 @@ from ...modeling_tf_utils import (
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...tf_utils import shape_list
from ...utils import logging
from .configuration_distilbert import DistilBertConfig
......
......@@ -50,8 +50,8 @@ from ...modeling_tf_utils import (
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...tf_utils import shape_list
from ...utils import logging
from .configuration_electra import ElectraConfig
......
......@@ -30,13 +30,8 @@ from ...file_utils import (
replace_return_docstrings,
)
from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput
from ...modeling_tf_utils import (
TFCausalLanguageModelingLoss,
TFPreTrainedModel,
get_initializer,
input_processing,
shape_list,
)
from ...modeling_tf_utils import TFCausalLanguageModelingLoss, TFPreTrainedModel, get_initializer, input_processing
from ...tf_utils import shape_list
from ...utils import logging
from ..auto.configuration_auto import AutoConfig
from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM
......