Unverified Commit 2e12b907 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

TF generate refactor - Greedy Search (#15562)



* TF generate start refactor

* Add tf tests for sample generate

* re-organize

* boom boom

* Apply suggestions from code review

* re-add

* add all code

* make random greedy pass

* make encoder-decoder random work

* further improvements

* delete bogus file

* make gpt2 and t5 tests work

* finish logits tests

* correct logits processors

* correct past / encoder_outputs drama

* refactor some methods

* another fix

* refactor shape_list

* fix more shape list

* import shape
_list

* finish docs

* fix imports

* make style

* correct tf utils

* Fix TFRag as well

* Apply Lysandre's and Sylvais suggestions

* Update tests/test_generation_tf_logits_process.py
Co-authored-by: default avatarMatt <Rocketknight1@users.noreply.github.com>

* Update src/transformers/tf_utils.py
Co-authored-by: default avatarMatt <Rocketknight1@users.noreply.github.com>

* remove cpu according to gante

* correct logit processor
Co-authored-by: default avatarMatt <Rocketknight1@users.noreply.github.com>
parent a3dbbc34
...@@ -38,8 +38,8 @@ from ...modeling_tf_utils import ( ...@@ -38,8 +38,8 @@ from ...modeling_tf_utils import (
get_initializer, get_initializer,
input_processing, input_processing,
keras_serializable, keras_serializable,
shape_list,
) )
from ...tf_utils import shape_list
from ...utils import logging from ...utils import logging
from ..xlm.modeling_tf_xlm import ( from ..xlm.modeling_tf_xlm import (
TFXLMForMultipleChoice, TFXLMForMultipleChoice,
......
...@@ -47,8 +47,8 @@ from ...modeling_tf_utils import ( ...@@ -47,8 +47,8 @@ from ...modeling_tf_utils import (
get_initializer, get_initializer,
input_processing, input_processing,
keras_serializable, keras_serializable,
shape_list,
) )
from ...tf_utils import shape_list
from ...utils import logging from ...utils import logging
from .configuration_funnel import FunnelConfig from .configuration_funnel import FunnelConfig
......
...@@ -44,8 +44,8 @@ from ...modeling_tf_utils import ( ...@@ -44,8 +44,8 @@ from ...modeling_tf_utils import (
get_initializer, get_initializer,
input_processing, input_processing,
keras_serializable, keras_serializable,
shape_list,
) )
from ...tf_utils import shape_list
from ...utils import logging from ...utils import logging
from .configuration_gpt2 import GPT2Config from .configuration_gpt2 import GPT2Config
......
...@@ -28,13 +28,8 @@ from ...file_utils import ( ...@@ -28,13 +28,8 @@ from ...file_utils import (
replace_return_docstrings, replace_return_docstrings,
) )
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
from ...modeling_tf_utils import ( from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
TFPreTrainedModel, from ...tf_utils import shape_list
booleans_processing,
get_initializer,
keras_serializable,
shape_list,
)
from ...tokenization_utils_base import BatchEncoding from ...tokenization_utils_base import BatchEncoding
from ...utils import logging from ...utils import logging
from .configuration_hubert import HubertConfig from .configuration_hubert import HubertConfig
......
...@@ -39,8 +39,8 @@ from ...modeling_tf_utils import ( ...@@ -39,8 +39,8 @@ from ...modeling_tf_utils import (
get_initializer, get_initializer,
input_processing, input_processing,
keras_serializable, keras_serializable,
shape_list,
) )
from ...tf_utils import shape_list
from ...utils import logging from ...utils import logging
from .configuration_layoutlm import LayoutLMConfig from .configuration_layoutlm import LayoutLMConfig
......
...@@ -39,8 +39,8 @@ from ...modeling_tf_utils import ( ...@@ -39,8 +39,8 @@ from ...modeling_tf_utils import (
get_initializer, get_initializer,
input_processing, input_processing,
keras_serializable, keras_serializable,
shape_list,
) )
from ...tf_utils import shape_list
from ...utils import logging from ...utils import logging
from .configuration_led import LEDConfig from .configuration_led import LEDConfig
......
...@@ -38,8 +38,8 @@ from ...modeling_tf_utils import ( ...@@ -38,8 +38,8 @@ from ...modeling_tf_utils import (
get_initializer, get_initializer,
input_processing, input_processing,
keras_serializable, keras_serializable,
shape_list,
) )
from ...tf_utils import shape_list
from ...utils import logging from ...utils import logging
from .configuration_longformer import LongformerConfig from .configuration_longformer import LongformerConfig
......
...@@ -45,8 +45,8 @@ from ...modeling_tf_utils import ( ...@@ -45,8 +45,8 @@ from ...modeling_tf_utils import (
TFWrappedEmbeddings, TFWrappedEmbeddings,
input_processing, input_processing,
keras_serializable, keras_serializable,
shape_list,
) )
from ...tf_utils import shape_list
from ...utils import logging from ...utils import logging
from .configuration_marian import MarianConfig from .configuration_marian import MarianConfig
......
...@@ -44,8 +44,8 @@ from ...modeling_tf_utils import ( ...@@ -44,8 +44,8 @@ from ...modeling_tf_utils import (
TFWrappedEmbeddings, TFWrappedEmbeddings,
input_processing, input_processing,
keras_serializable, keras_serializable,
shape_list,
) )
from ...tf_utils import shape_list
from ...utils import logging from ...utils import logging
from .configuration_mbart import MBartConfig from .configuration_mbart import MBartConfig
......
...@@ -51,8 +51,8 @@ from ...modeling_tf_utils import ( ...@@ -51,8 +51,8 @@ from ...modeling_tf_utils import (
get_initializer, get_initializer,
input_processing, input_processing,
keras_serializable, keras_serializable,
shape_list,
) )
from ...tf_utils import shape_list
from ...utils import logging from ...utils import logging
from .configuration_mobilebert import MobileBertConfig from .configuration_mobilebert import MobileBertConfig
......
...@@ -47,8 +47,8 @@ from ...modeling_tf_utils import ( ...@@ -47,8 +47,8 @@ from ...modeling_tf_utils import (
get_initializer, get_initializer,
input_processing, input_processing,
keras_serializable, keras_serializable,
shape_list,
) )
from ...tf_utils import shape_list
from ...utils import logging from ...utils import logging
from .configuration_mpnet import MPNetConfig from .configuration_mpnet import MPNetConfig
......
...@@ -39,8 +39,8 @@ from ...modeling_tf_utils import ( ...@@ -39,8 +39,8 @@ from ...modeling_tf_utils import (
get_initializer, get_initializer,
input_processing, input_processing,
keras_serializable, keras_serializable,
shape_list,
) )
from ...tf_utils import shape_list
from ...utils import logging from ...utils import logging
from .configuration_openai import OpenAIGPTConfig from .configuration_openai import OpenAIGPTConfig
......
...@@ -45,8 +45,8 @@ from ...modeling_tf_utils import ( ...@@ -45,8 +45,8 @@ from ...modeling_tf_utils import (
TFWrappedEmbeddings, TFWrappedEmbeddings,
input_processing, input_processing,
keras_serializable, keras_serializable,
shape_list,
) )
from ...tf_utils import shape_list
from ...utils import logging from ...utils import logging
from .configuration_pegasus import PegasusConfig from .configuration_pegasus import PegasusConfig
......
...@@ -1269,6 +1269,8 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss ...@@ -1269,6 +1269,8 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
) )
if return_dict_in_generate: if return_dict_in_generate:
# TODO(Patrick): `encoder_outputs`, `past` hack.
# Remove after cleaning encoder-decoder outputs
if output_attentions: if output_attentions:
model_kwargs["encoder_attentions"] = encoder_outputs.attentions model_kwargs["encoder_attentions"] = encoder_outputs.attentions
if output_hidden_states: if output_hidden_states:
...@@ -1350,28 +1352,35 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss ...@@ -1350,28 +1352,35 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
**model_kwargs, # encoder_outputs is here as in Pytorch's version **model_kwargs, # encoder_outputs is here as in Pytorch's version
) )
else: else:
return self._generate_no_beam_search( pre_processor = self._get_logits_processor(
decoder_input_ids,
cur_len=cur_len,
max_length=max_length,
min_length=min_length,
do_sample=do_sample,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty, repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size, no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids, bad_words_ids=bad_words_ids,
min_length=min_length,
eos_token_id=eos_token_id,
)
# TODO(Patrick) clean-up once generate is fully cleaned up
model_kwargs["attention_mask"] = context_attention_mask
# TODO(Patrick) remove once generate is fully cleaned up
model_kwargs.pop("output_hidden_states", None)
model_kwargs.pop("output_attentions", None)
model_kwargs.pop("output_scores", None)
# TODO(Patrick): `encoder_outputs`, `past` hack.
# Remove after cleaning encoder-decoder outputs
model_kwargs["past"] = encoder_outputs
return self.greedy_search(
input_ids=decoder_input_ids,
max_length=max_length,
pad_token_id=pad_token_id, pad_token_id=pad_token_id,
eos_token_id=eos_token_id, eos_token_id=eos_token_id,
batch_size=batch_size, logits_processor=pre_processor,
vocab_size=vocab_size, output_attentions=output_attentions,
attention_mask=context_attention_mask, output_hidden_states=output_hidden_states,
use_cache=use_cache, output_scores=output_scores,
forced_bos_token_id=None,
forced_eos_token_id=None,
return_dict_in_generate=return_dict_in_generate, return_dict_in_generate=return_dict_in_generate,
**model_kwargs, # encoder_outputs is here as in Pytorch's version **model_kwargs,
) )
def get_input_embeddings(self): def get_input_embeddings(self):
......
...@@ -51,8 +51,8 @@ from ...modeling_tf_utils import ( ...@@ -51,8 +51,8 @@ from ...modeling_tf_utils import (
get_initializer, get_initializer,
input_processing, input_processing,
keras_serializable, keras_serializable,
shape_list,
) )
from ...tf_utils import shape_list
from ...utils import logging from ...utils import logging
from .configuration_rembert import RemBertConfig from .configuration_rembert import RemBertConfig
......
...@@ -52,8 +52,8 @@ from ...modeling_tf_utils import ( ...@@ -52,8 +52,8 @@ from ...modeling_tf_utils import (
get_initializer, get_initializer,
input_processing, input_processing,
keras_serializable, keras_serializable,
shape_list,
) )
from ...tf_utils import shape_list
from ...utils import logging from ...utils import logging
from .configuration_roberta import RobertaConfig from .configuration_roberta import RobertaConfig
......
...@@ -51,8 +51,8 @@ from ...modeling_tf_utils import ( ...@@ -51,8 +51,8 @@ from ...modeling_tf_utils import (
get_initializer, get_initializer,
input_processing, input_processing,
keras_serializable, keras_serializable,
shape_list,
) )
from ...tf_utils import shape_list
from ...utils import logging from ...utils import logging
from .configuration_roformer import RoFormerConfig from .configuration_roformer import RoFormerConfig
......
...@@ -39,8 +39,8 @@ from ...modeling_tf_utils import ( ...@@ -39,8 +39,8 @@ from ...modeling_tf_utils import (
TFSharedEmbeddings, TFSharedEmbeddings,
input_processing, input_processing,
keras_serializable, keras_serializable,
shape_list,
) )
from ...tf_utils import shape_list
from ...utils import logging from ...utils import logging
from .configuration_speech_to_text import Speech2TextConfig from .configuration_speech_to_text import Speech2TextConfig
......
...@@ -44,8 +44,8 @@ from ...modeling_tf_utils import ( ...@@ -44,8 +44,8 @@ from ...modeling_tf_utils import (
TFWrappedEmbeddings, TFWrappedEmbeddings,
input_processing, input_processing,
keras_serializable, keras_serializable,
shape_list,
) )
from ...tf_utils import shape_list
from ...utils import logging from ...utils import logging
from .configuration_t5 import T5Config from .configuration_t5 import T5Config
......
...@@ -45,8 +45,8 @@ from ...modeling_tf_utils import ( ...@@ -45,8 +45,8 @@ from ...modeling_tf_utils import (
get_initializer, get_initializer,
input_processing, input_processing,
keras_serializable, keras_serializable,
shape_list,
) )
from ...tf_utils import shape_list
from ...utils import logging from ...utils import logging
from .configuration_tapas import TapasConfig from .configuration_tapas import TapasConfig
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment