Unverified Commit 26a2e365 authored by Daniel Stancl, committed by GitHub

Add output in a dictionary for TF `generate` method (#12139)

* Add output args to greedy search

* Fix critical typo + run `make style` and `make quality`

* Handle generate_beam_search

* Add dict-specific tests and fix the placement of encoder outputs

* Add specific outputs

* Update doc

* Fix typo

* Adjust handling of encoder_outputs + fix generation for T5

* Fix generate for RAG

* Fix handling of output_attentions when target_mapping is not None

Take care of the case where target_mapping is provided, since the attentions are then returned as 2-tuples.

Change from:
if inputs["output_attentions"]:
    attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)

to:
if inputs["output_attentions"]:
    if inputs["target_mapping"] is not None:
        # when target_mapping is provided, there are 2-tuple of attentions
        attentions = tuple(
            tuple(tf.transpose(attn_stream, perm=(2, 3, 0, 1)) for attn_stream in t) for t in attentions
        )
    else:
        attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
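A minimal, illustrative sketch (not part of the commit) of why the transpose has to be applied one level deeper; the shapes are assumptions for illustration, matching the (qlen, klen, batch, num_heads) layout that perm=(2, 3, 0, 1) implies:

import tensorflow as tf

# One attention tensor per layer, laid out as (qlen, klen, batch, num_heads);
# perm=(2, 3, 0, 1) turns it into (batch, num_heads, qlen, klen).
attn = tf.zeros((5, 5, 2, 4))

# Without target_mapping, each layer contributes a single tensor.
attentions = (attn, attn)
flat = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
print(flat[0].shape)  # (2, 4, 5, 5)

# With target_mapping, each layer contributes a 2-tuple (one tensor per attention
# stream), so the same transpose is applied to every tensor inside each 2-tuple.
attentions = ((attn, attn), (attn, attn))
nested = tuple(
    tuple(tf.transpose(attn_stream, perm=(2, 3, 0, 1)) for attn_stream in t) for t in attentions
)
print(nested[0][0].shape)  # (2, 4, 5, 5)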

* Rename kwargs to model_kwargs

* Run `make style` and `make quality`

* Move imports in test_modeling_tf_common.py

Move the ModelOutput-related imports in test_modeling_tf_common.py
inside the `if is_tf_available():` block.

* Rewrite nested if-statements

* Fix added tests
parent d4be4984
@@ -1063,7 +1063,11 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
         num_return_sequences=None,
         decoder_start_token_id=None,
         n_docs=None,
-        **kwargs
+        output_scores=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict_in_generate=None,
+        **model_kwargs
     ):
         """
         Implements TFRAG token decoding.
@@ -1137,6 +1141,18 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
                 If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
             n_docs (:obj:`int`, `optional`, defaults to :obj:`config.n_docs`)
                 Number of documents to retrieve and/or number of documents for which to generate an answer.
+            output_attentions (:obj:`bool`, `optional`, defaults to `False`):
+                Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
+                returned tensors for more details.
+            output_hidden_states (:obj:`bool`, `optional`, defaults to `False`):
+                Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
+                for more details.
+            output_scores (:obj:`bool`, `optional`, defaults to `False`):
+                Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details.
+            return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`):
+                Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
+            model_specific_kwargs:
+                Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.

         Return:
             :obj:`tf.Tensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
@@ -1167,6 +1183,21 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
             else self.config.generator.decoder_start_token_id
         )

+        output_scores = output_scores if output_scores is not None else self.config.output_scores
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict_in_generate = (
+            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
+        )
+
+        model_kwargs["output_scores"] = output_scores
+        model_kwargs["output_attentions"] = output_attentions
+        model_kwargs["output_hidden_states"] = output_hidden_states
+        model_kwargs["encoder_attentions"] = None
+        model_kwargs["encoder_hidden_states"] = None
+
         # retrieve docs
         if self.retriever is not None and context_input_ids is None:
             question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
@@ -1200,7 +1231,19 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
         batch_size = context_input_ids.shape[0] // n_docs

         encoder = self.rag.generator.get_encoder()
-        encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask, return_dict=True)
+        encoder_outputs = encoder(
+            input_ids=context_input_ids,
+            attention_mask=context_attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=True,
+        )
+
+        if return_dict_in_generate:
+            if output_attentions:
+                model_kwargs["encoder_attentions"] = encoder_outputs.attentions
+            if output_hidden_states:
+                model_kwargs["encoder_hidden_states"] = encoder_outputs.hidden_states

         decoder_input_ids = tf.fill(
             (batch_size * num_beams, 1),
@@ -1238,9 +1281,9 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
         # define start_len & additional parameters
         cur_len = 1
         vocab_size = self.config.generator.vocab_size
-        kwargs["doc_scores"] = doc_scores
-        kwargs["encoder_outputs"] = encoder_outputs
-        kwargs["n_docs"] = n_docs
+        model_kwargs["doc_scores"] = doc_scores
+        model_kwargs["encoder_outputs"] = encoder_outputs
+        model_kwargs["n_docs"] = n_docs

         # not needed. TODO(PVP): change after generate refactor
         do_sample = False
@@ -1274,7 +1317,8 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
                 use_cache=use_cache,
                 forced_bos_token_id=None,
                 forced_eos_token_id=None,
-                **kwargs,  # encoder_outputs is here as in Pytorch's version
+                return_dict_in_generate=return_dict_in_generate,
+                **model_kwargs,  # encoder_outputs is here as in Pytorch's version
             )
         else:
             return self._generate_no_beam_search(
@@ -1297,7 +1341,8 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
                 use_cache=use_cache,
                 forced_bos_token_id=None,
                 forced_eos_token_id=None,
-                **kwargs,  # encoder_outputs is here as in Pytorch's version
+                return_dict_in_generate=return_dict_in_generate,
+                **model_kwargs,  # encoder_outputs is here as in Pytorch's version
             )

     def get_input_embeddings(self):
...
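As a usage sketch (not part of the diff), the new arguments can be passed straight to generate; the T5 checkpoint below is only an illustrative choice, and which fields end up populated depends on the flags and the search mode:

from transformers import T5Tokenizer, TFT5ForConditionalGeneration

# Illustrative checkpoint; any TF generative model exposes the same arguments.
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")

input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="tf").input_ids

# Request scores, attentions and hidden states, and get them back in a ModelOutput
# instead of a bare tensor of token ids.
outputs = model.generate(
    input_ids,
    do_sample=False,
    output_scores=True,
    output_attentions=True,
    output_hidden_states=True,
    return_dict_in_generate=True,
)

print(type(outputs).__name__)   # TFGreedySearchEncoderDecoderOutput for an encoder-decoder model
print(outputs.sequences.shape)  # the generated token ids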
@@ -1481,6 +1481,10 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
             encoder_outputs, past_key_values = past, None
         else:
             encoder_outputs, past_key_values = past[0], past[1]
+        if "encoder_hidden_states" in kwargs:
+            encoder_outputs = (*encoder_outputs, kwargs["encoder_hidden_states"])
+        if "encoder_attentions" in kwargs:
+            encoder_outputs = (*encoder_outputs, kwargs["encoder_attentions"])

         # cut decoder_input_ids if past is used
         if past_key_values is not None:
...
@@ -796,7 +796,13 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         else:
             hidden_states = tuple(tf.transpose(hs, perm=(1, 0, 2)) for hs in hidden_states)
         if inputs["output_attentions"]:
-            attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
+            if inputs["target_mapping"] is not None:
+                # when target_mapping is provided, there are 2-tuple of attentions
+                attentions = tuple(
+                    tuple(tf.transpose(attn_stream, perm=(2, 3, 0, 1)) for attn_stream in t) for t in attentions
+                )
+            else:
+                attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)

         if not inputs["return_dict"]:
             return tuple(v for v in [output, new_mems, hidden_states, attentions] if v is not None)
...
@@ -61,6 +61,16 @@ if is_tf_available():
         TFSharedEmbeddings,
         tf_top_k_top_p_filtering,
     )
+    from transformers.generation_tf_utils import (
+        TFBeamSampleDecoderOnlyOutput,
+        TFBeamSampleEncoderDecoderOutput,
+        TFBeamSearchDecoderOnlyOutput,
+        TFBeamSearchEncoderDecoderOutput,
+        TFGreedySearchDecoderOnlyOutput,
+        TFGreedySearchEncoderDecoderOutput,
+        TFSampleDecoderOnlyOutput,
+        TFSampleEncoderDecoderOutput,
+    )

     if _tf_gpu_memory_limit is not None:
         gpus = tf.config.list_physical_devices("GPU")
@@ -1100,6 +1110,37 @@ class TFModelTesterMixin:
             generated_ids = output_tokens[:, input_ids.shape[-1] :]
             self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))

+    def test_lm_head_model_no_beam_search_generate_dict_outputs(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        input_ids = inputs_dict.get("input_ids", None)
+
+        # iterate over all generative models
+        for model_class in self.all_generative_model_classes:
+            model = model_class(config)
+            output_greedy = model.generate(
+                input_ids,
+                do_sample=False,
+                output_scores=True,
+                output_hidden_states=True,
+                output_attentions=True,
+                return_dict_in_generate=True,
+            )
+            output_sample = model.generate(
+                input_ids,
+                do_sample=True,
+                output_scores=True,
+                output_hidden_states=True,
+                output_attentions=True,
+                return_dict_in_generate=True,
+            )
+
+            if model.config.is_encoder_decoder:
+                self.assertIsInstance(output_greedy, TFGreedySearchEncoderDecoderOutput)
+                self.assertIsInstance(output_sample, TFSampleEncoderDecoderOutput)
+            else:
+                self.assertIsInstance(output_greedy, TFGreedySearchDecoderOnlyOutput)
+                self.assertIsInstance(output_sample, TFSampleDecoderOnlyOutput)
+
     def test_lm_head_model_random_beam_search_generate(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         input_ids = inputs_dict.get("input_ids", None)
@@ -1140,6 +1181,39 @@ class TFModelTesterMixin:
             generated_ids = output_tokens[:, input_ids.shape[-1] :]
             self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))

+    def test_lm_head_model_beam_search_generate_dict_outputs(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        input_ids = inputs_dict.get("input_ids", None)
+
+        # iterate over all generative models
+        for model_class in self.all_generative_model_classes:
+            model = model_class(config)
+            output_beam_search = model.generate(
+                input_ids,
+                num_beams=2,
+                do_sample=False,
+                output_scores=True,
+                output_hidden_states=True,
+                output_attentions=True,
+                return_dict_in_generate=True,
+            )
+            output_beam_sample = model.generate(
+                input_ids,
+                num_beams=2,
+                do_sample=True,
+                output_scores=True,
+                output_hidden_states=True,
+                output_attentions=True,
+                return_dict_in_generate=True,
+            )
+
+            if model.config.is_encoder_decoder:
+                self.assertIsInstance(output_beam_search, TFBeamSearchEncoderDecoderOutput)
+                self.assertIsInstance(output_beam_sample, TFBeamSampleEncoderDecoderOutput)
+            else:
+                self.assertIsInstance(output_beam_search, TFBeamSearchDecoderOnlyOutput)
+                self.assertIsInstance(output_beam_sample, TFBeamSampleDecoderOnlyOutput)
+
     def test_loss_computation(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         for model_class in self.all_model_classes:
...
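A companion sketch (again not part of the diff) mirroring what the new tests assert for a decoder-only model; the GPT-2 checkpoint is only an illustrative choice:

from transformers import GPT2Tokenizer, TFGPT2LMHeadModel
from transformers.generation_tf_utils import TFBeamSearchDecoderOnlyOutput

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
input_ids = tokenizer("The new generate API", return_tensors="tf").input_ids

output_beam_search = model.generate(
    input_ids,
    num_beams=2,
    do_sample=False,
    output_scores=True,
    output_attentions=True,
    output_hidden_states=True,
    return_dict_in_generate=True,
)

# Decoder-only models get the *DecoderOnlyOutput variants, encoder-decoder models
# the *EncoderDecoderOutput variants, which is exactly what the tests above check.
assert isinstance(output_beam_search, TFBeamSearchDecoderOnlyOutput)
print(output_beam_search.sequences.shape)  # (batch_size * num_return_sequences, sequence_length)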