Unverified Commit 2e12b907 authored by Patrick von Platen, committed by GitHub

TF generate refactor - Greedy Search (#15562)



* TF generate start refactor

* Add tf tests for sample generate

* re-organize

* boom boom

* Apply suggestions from code review

* re-add

* add all code

* make random greedy pass

* make encoder-decoder random work

* further improvements

* delete bogus file

* make gpt2 and t5 tests work

* finish logits tests

* correct logits processors

* correct past / encoder_outputs drama

* refactor some methods

* another fix

* refactor shape_list

* fix more shape list

* import shape_list

* finish docs

* fix imports

* make style

* correct tf utils

* Fix TFRag as well

* Apply Lysandre's and Sylvain's suggestions

* Update tests/test_generation_tf_logits_process.py
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Update src/transformers/tf_utils.py
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* remove cpu according to gante

* correct logit processor
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
parent a3dbbc34
@@ -38,8 +38,8 @@ from ...modeling_tf_utils import (
     get_initializer,
     input_processing,
     keras_serializable,
-    shape_list,
 )
+from ...tf_utils import shape_list
 from ...utils import logging
 from ..xlm.modeling_tf_xlm import (
     TFXLMForMultipleChoice,
......
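The hunk above is the first of many identical changes in this commit: `shape_list` is dropped from the `modeling_tf_utils` import block and imported from the separate `tf_utils` module instead. For readers unfamiliar with the helper, the sketch below shows what a shape_list-style utility typically does in TF code; it is an illustration only, not necessarily the exact implementation moved by this commit.

import tensorflow as tf

def shape_list(tensor):
    # Illustrative helper: return static dimensions where they are known at
    # graph-construction time, and dynamic ones (via tf.shape) where they are not.
    static = tensor.shape.as_list()
    dynamic = tf.shape(tensor)
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]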
@@ -47,8 +47,8 @@ from ...modeling_tf_utils import (
     get_initializer,
     input_processing,
     keras_serializable,
-    shape_list,
 )
+from ...tf_utils import shape_list
 from ...utils import logging
 from .configuration_funnel import FunnelConfig
......
@@ -44,8 +44,8 @@ from ...modeling_tf_utils import (
     get_initializer,
     input_processing,
     keras_serializable,
-    shape_list,
 )
+from ...tf_utils import shape_list
 from ...utils import logging
 from .configuration_gpt2 import GPT2Config
......
@@ -28,13 +28,8 @@ from ...file_utils import (
     replace_return_docstrings,
 )
 from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
-from ...modeling_tf_utils import (
-    TFPreTrainedModel,
-    booleans_processing,
-    get_initializer,
-    keras_serializable,
-    shape_list,
-)
+from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
+from ...tf_utils import shape_list
 from ...tokenization_utils_base import BatchEncoding
 from ...utils import logging
 from .configuration_hubert import HubertConfig
......
@@ -39,8 +39,8 @@ from ...modeling_tf_utils import (
     get_initializer,
     input_processing,
     keras_serializable,
-    shape_list,
 )
+from ...tf_utils import shape_list
 from ...utils import logging
 from .configuration_layoutlm import LayoutLMConfig
......
@@ -39,8 +39,8 @@ from ...modeling_tf_utils import (
     get_initializer,
     input_processing,
     keras_serializable,
-    shape_list,
 )
+from ...tf_utils import shape_list
 from ...utils import logging
 from .configuration_led import LEDConfig
......
@@ -38,8 +38,8 @@ from ...modeling_tf_utils import (
     get_initializer,
     input_processing,
     keras_serializable,
-    shape_list,
 )
+from ...tf_utils import shape_list
 from ...utils import logging
 from .configuration_longformer import LongformerConfig
......
@@ -45,8 +45,8 @@ from ...modeling_tf_utils import (
     TFWrappedEmbeddings,
     input_processing,
     keras_serializable,
-    shape_list,
 )
+from ...tf_utils import shape_list
 from ...utils import logging
 from .configuration_marian import MarianConfig
......
@@ -44,8 +44,8 @@ from ...modeling_tf_utils import (
     TFWrappedEmbeddings,
     input_processing,
     keras_serializable,
-    shape_list,
 )
+from ...tf_utils import shape_list
 from ...utils import logging
 from .configuration_mbart import MBartConfig
......
@@ -51,8 +51,8 @@ from ...modeling_tf_utils import (
     get_initializer,
     input_processing,
     keras_serializable,
-    shape_list,
 )
+from ...tf_utils import shape_list
 from ...utils import logging
 from .configuration_mobilebert import MobileBertConfig
......
@@ -47,8 +47,8 @@ from ...modeling_tf_utils import (
     get_initializer,
     input_processing,
     keras_serializable,
-    shape_list,
 )
+from ...tf_utils import shape_list
 from ...utils import logging
 from .configuration_mpnet import MPNetConfig
......
@@ -39,8 +39,8 @@ from ...modeling_tf_utils import (
     get_initializer,
     input_processing,
     keras_serializable,
-    shape_list,
 )
+from ...tf_utils import shape_list
 from ...utils import logging
 from .configuration_openai import OpenAIGPTConfig
......
@@ -45,8 +45,8 @@ from ...modeling_tf_utils import (
     TFWrappedEmbeddings,
     input_processing,
     keras_serializable,
-    shape_list,
 )
+from ...tf_utils import shape_list
 from ...utils import logging
 from .configuration_pegasus import PegasusConfig
......
@@ -1269,6 +1269,8 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
)
if return_dict_in_generate:
# TODO(Patrick): `encoder_outputs`, `past` hack.
# Remove after cleaning encoder-decoder outputs
if output_attentions:
model_kwargs["encoder_attentions"] = encoder_outputs.attentions
if output_hidden_states:
@@ -1350,28 +1352,35 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
**model_kwargs, # encoder_outputs is here as in Pytorch's version
)
else:
return self._generate_no_beam_search(
decoder_input_ids,
cur_len=cur_len,
max_length=max_length,
min_length=min_length,
do_sample=do_sample,
temperature=temperature,
top_k=top_k,
top_p=top_p,
pre_processor = self._get_logits_processor(
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
min_length=min_length,
eos_token_id=eos_token_id,
)
# TODO(Patrick) clean-up once generate is fully cleaned up
model_kwargs["attention_mask"] = context_attention_mask
# TODO(Patrick) remove once generate is fully cleaned up
model_kwargs.pop("output_hidden_states", None)
model_kwargs.pop("output_attentions", None)
model_kwargs.pop("output_scores", None)
# TODO(Patrick): `encoder_outputs`, `past` hack.
# Remove after cleaning encoder-decoder outputs
model_kwargs["past"] = encoder_outputs
return self.greedy_search(
input_ids=decoder_input_ids,
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
batch_size=batch_size,
vocab_size=vocab_size,
attention_mask=context_attention_mask,
use_cache=use_cache,
forced_bos_token_id=None,
forced_eos_token_id=None,
logits_processor=pre_processor,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
**model_kwargs, # encoder_outputs is here as in Pytorch's version
**model_kwargs,
)
def get_input_embeddings(self):
......
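The TFRag hunk above replaces the old `_generate_no_beam_search` call with a `_get_logits_processor(...)` call followed by `greedy_search`. As a rough illustration of the control flow this pattern enables (the function and argument names below are assumptions for the sketch, not the actual `greedy_search` signature):

import tensorflow as tf

def greedy_decode_sketch(step_fn, input_ids, logits_processor, max_length, eos_token_id):
    # step_fn is assumed to map (batch, seq_len) token ids to (batch, seq_len, vocab) logits.
    cur_ids = input_ids
    while int(tf.shape(cur_ids)[1]) < max_length:
        next_token_logits = step_fn(cur_ids)[:, -1, :]
        # logits processors (min length, bad words, repetition penalty, ...) rescore the logits
        next_token_logits = logits_processor(cur_ids, next_token_logits)
        next_tokens = tf.argmax(next_token_logits, axis=-1, output_type=tf.int32)[:, None]
        cur_ids = tf.concat([cur_ids, next_tokens], axis=-1)
        if eos_token_id is not None and bool(tf.reduce_all(next_tokens == eos_token_id)):
            break
    return cur_ids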
@@ -51,8 +51,8 @@ from ...modeling_tf_utils import (
     get_initializer,
     input_processing,
     keras_serializable,
-    shape_list,
 )
+from ...tf_utils import shape_list
 from ...utils import logging
 from .configuration_rembert import RemBertConfig
......
@@ -52,8 +52,8 @@ from ...modeling_tf_utils import (
     get_initializer,
     input_processing,
     keras_serializable,
-    shape_list,
 )
+from ...tf_utils import shape_list
 from ...utils import logging
 from .configuration_roberta import RobertaConfig
......
@@ -51,8 +51,8 @@ from ...modeling_tf_utils import (
     get_initializer,
     input_processing,
     keras_serializable,
-    shape_list,
 )
+from ...tf_utils import shape_list
 from ...utils import logging
 from .configuration_roformer import RoFormerConfig
......
@@ -39,8 +39,8 @@ from ...modeling_tf_utils import (
     TFSharedEmbeddings,
     input_processing,
     keras_serializable,
-    shape_list,
 )
+from ...tf_utils import shape_list
 from ...utils import logging
 from .configuration_speech_to_text import Speech2TextConfig
......
@@ -44,8 +44,8 @@ from ...modeling_tf_utils import (
     TFWrappedEmbeddings,
     input_processing,
     keras_serializable,
-    shape_list,
 )
+from ...tf_utils import shape_list
 from ...utils import logging
 from .configuration_t5 import T5Config
......
@@ -45,8 +45,8 @@ from ...modeling_tf_utils import (
     get_initializer,
     input_processing,
     keras_serializable,
-    shape_list,
 )
+from ...tf_utils import shape_list
 from ...utils import logging
 from .configuration_tapas import TapasConfig
......
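With this refactor in place, greedy decoding on TF models (single beam, no sampling) should go through the new `greedy_search` path via the standard `generate` API. A minimal usage sketch; the GPT-2 checkpoint and prompt are just examples:

from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = TFGPT2LMHeadModel.from_pretrained("gpt2")

inputs = tokenizer("The TF generate refactor", return_tensors="tf")
# do_sample=False selects greedy decoding
output_ids = model.generate(inputs["input_ids"], max_length=20, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))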