Unverified commit 3dc82427, authored by Shamima, committed by GitHub

TF: removed input_processing and replaced with decorator in lxmert (#16414)

parent b320d87e
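For context: `input_processing` normalized the many ways a TF model's `call` can receive arguments (positional, keyword, or a single dict in the Keras style) into one `inputs` dict, and every model method had to invoke it by hand. The `@unpack_inputs` decorator moves that normalization out of the method body, so `call` can work with ordinary named parameters. Below is a minimal sketch of the idea only; it is not the actual implementation in `modeling_tf_utils` (which also resolves config defaults and deprecated kwargs), and `unpack_inputs_sketch` / `ToyLayer` are hypothetical names:

```python
import functools


def unpack_inputs_sketch(func):
    """Toy stand-in for transformers' `unpack_inputs`: if the first positional
    argument is a dict (Keras calling convention), spread it into keyword
    arguments so the body of `call` sees plain local variables."""

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        if args and isinstance(args[0], dict):
            # Dict input: merge it with any explicit keyword arguments.
            merged = {**args[0], **kwargs}
            return func(self, **merged)
        return func(self, *args, **kwargs)

    return wrapper


class ToyLayer:
    @unpack_inputs_sketch
    def call(self, input_ids=None, attention_mask=None, training=False):
        # No `inputs["..."]` lookups needed inside the body anymore.
        return input_ids, attention_mask, training


layer = ToyLayer()
# Both calling styles resolve to the same named parameters:
print(layer.call([1, 2, 3], attention_mask=[1, 1, 1]))
print(layer.call({"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1]}))
```

The payoff is visible throughout the diff below: every `inputs["..."]` lookup becomes a plain local variable.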
@@ -23,7 +23,7 @@ from typing import Dict, Optional, Tuple

 import tensorflow as tf

 from ...activations_tf import get_tf_activation
-from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, input_processing, keras_serializable, shape_list
+from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list, unpack_inputs
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -671,6 +671,7 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
     def _prune_heads(self, heads_to_prune):
         raise NotImplementedError

+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -686,51 +687,33 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
         training=False,
         **kwargs,
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            visual_feats=visual_feats,
-            visual_pos=visual_pos,
-            attention_mask=attention_mask,
-            visual_attention_mask=visual_attention_mask,
-            token_type_ids=token_type_ids,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
-        if inputs["visual_pos"] is None or inputs["visual_feats"] is None:
+        if visual_pos is None or visual_feats is None:
             raise ValueError("visual_feats and visual_pos cannot be `None` in LXMERT's `call` method.")
-        if inputs["attention_mask"] is None:
-            inputs["attention_mask"] = tf.fill(input_shape, 1)
-        if inputs["token_type_ids"] is None:
-            inputs["token_type_ids"] = tf.fill(input_shape, 0)
+        if attention_mask is None:
+            attention_mask = tf.fill(input_shape, 1)
+        if token_type_ids is None:
+            token_type_ids = tf.fill(input_shape, 0)

         # Positional Word Embeddings
-        embedding_output = self.embeddings(
-            inputs["input_ids"], inputs["token_type_ids"], inputs["inputs_embeds"], training=inputs["training"]
-        )
+        embedding_output = self.embeddings(input_ids, token_type_ids, inputs_embeds, training)

         # We create a 3D attention mask from a 2D tensor mask.
         # Sizes are [batch_size, 1, 1, to_seq_length]
         # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
         # this attention mask is more simple than the triangular masking of causal attention
         # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-        extended_attention_mask = tf.reshape(inputs["attention_mask"], (input_shape[0], 1, 1, input_shape[1]))
+        extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1]))

         # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
         # masked positions, this operation will create a tensor which is 0.0 for
@@ -743,13 +726,9 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
         ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
         extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)

-        if inputs["visual_attention_mask"] is not None:
-            extended_visual_attention_mask = tf.reshape(
-                inputs["visual_attention_mask"], (input_shape[0], 1, 1, input_shape[1])
-            )
-            extended_visual_attention_mask = tf.expand_dims(
-                tf.expand_dims(inputs["visual_attention_mask"], axis=1), axis=1
-            )
+        if visual_attention_mask is not None:
+            extended_visual_attention_mask = tf.reshape(visual_attention_mask, (input_shape[0], 1, 1, input_shape[1]))
+            extended_visual_attention_mask = tf.expand_dims(tf.expand_dims(visual_attention_mask, axis=1), axis=1)
             extended_visual_attention_mask = tf.cast(extended_visual_attention_mask, dtype=embedding_output.dtype)
             extended_visual_attention_mask = tf.multiply(
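The mask arithmetic in the two hunks above is only renamed by this commit, not changed: a 2D padding mask of shape `[batch_size, seq_len]` is reshaped to `[batch_size, 1, 1, seq_len]` so it broadcasts across heads and query positions, then remapped so attended positions contribute 0.0 and masked positions contribute -10000.0 when added to the attention logits before softmax. A standalone illustration of that transform (assumes TensorFlow is installed):

```python
import tensorflow as tf

# 2D padding mask: 1 = attend, 0 = masked (batch_size=1, seq_len=4).
attention_mask = tf.constant([[1, 1, 1, 0]])
input_shape = attention_mask.shape  # [batch_size, seq_len]

# Reshape to [batch_size, 1, 1, seq_len] so it broadcasts over
# [batch_size, num_heads, from_seq_length, to_seq_length].
extended = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1]))
extended = tf.cast(extended, tf.float32)

# 1.0 -> 0.0 (no-op when added to logits); 0.0 -> -10000.0 (drives the
# post-softmax weight of masked positions to ~0).
one_cst = tf.constant(1.0)
ten_thousand_cst = tf.constant(-10000.0)
extended = tf.multiply(tf.subtract(one_cst, extended), ten_thousand_cst)

print(extended.numpy())  # [[[[    -0.     -0.     -0. -10000.]]]]
```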
@@ -762,18 +741,18 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
         encoder_outputs = self.encoder(
             embedding_output,
             extended_attention_mask,
-            inputs["visual_feats"],
-            inputs["visual_pos"],
+            visual_feats,
+            visual_pos,
             extended_visual_attention_mask,
-            output_attentions=inputs["output_attentions"],
-            training=inputs["training"],
+            output_attentions,
+            training,
         )
         visual_encoder_outputs, lang_encoder_outputs = encoder_outputs[:2]
         vision_hidden_states = visual_encoder_outputs[0]
         language_hidden_states = lang_encoder_outputs[0]

         all_attentions = ()
-        if inputs["output_attentions"]:
+        if output_attentions:
             language_attentions = lang_encoder_outputs[1]
             vision_attentions = visual_encoder_outputs[1]
             cross_encoder_attentions = encoder_outputs[2]
@@ -783,24 +762,24 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
                 cross_encoder_attentions,
             )

-        hidden_states = (language_hidden_states, vision_hidden_states) if inputs["output_hidden_states"] else ()
+        hidden_states = (language_hidden_states, vision_hidden_states) if output_hidden_states else ()

         visual_output = vision_hidden_states[-1]
         lang_output = language_hidden_states[-1]
         pooled_output = self.pooler(lang_output)

-        if not inputs["return_dict"]:
+        if not return_dict:
             return (lang_output, visual_output, pooled_output) + hidden_states + all_attentions

         return TFLxmertModelOutput(
             pooled_output=pooled_output,
             language_output=lang_output,
             vision_output=visual_output,
-            language_hidden_states=language_hidden_states if inputs["output_hidden_states"] else None,
-            vision_hidden_states=vision_hidden_states if inputs["output_hidden_states"] else None,
-            language_attentions=language_attentions if inputs["output_attentions"] else None,
-            vision_attentions=vision_attentions if inputs["output_attentions"] else None,
-            cross_encoder_attentions=cross_encoder_attentions if inputs["output_attentions"] else None,
+            language_hidden_states=language_hidden_states if output_hidden_states else None,
+            vision_hidden_states=vision_hidden_states if output_hidden_states else None,
+            language_attentions=language_attentions if output_attentions else None,
+            vision_attentions=vision_attentions if output_attentions else None,
+            cross_encoder_attentions=cross_encoder_attentions if output_attentions else None,
         )
@@ -946,6 +925,7 @@ class TFLxmertModel(TFLxmertPreTrainedModel):
         super().__init__(config, *inputs, **kwargs)
         self.lxmert = TFLxmertMainLayer(config, name="lxmert")

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -968,34 +948,18 @@ class TFLxmertModel(TFLxmertPreTrainedModel):
         training=False,
         **kwargs,
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            visual_feats=visual_feats,
-            visual_pos=visual_pos,
-            attention_mask=attention_mask,
-            visual_attention_mask=visual_attention_mask,
-            token_type_ids=token_type_ids,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
         outputs = self.lxmert(
-            input_ids=inputs["input_ids"],
-            visual_feats=inputs["visual_feats"],
-            visual_pos=inputs["visual_pos"],
-            attention_mask=inputs["attention_mask"],
-            visual_attention_mask=inputs["visual_attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            input_ids,
+            visual_feats,
+            visual_pos,
+            attention_mask,
+            visual_attention_mask,
+            token_type_ids,
+            inputs_embeds,
+            output_attentions,
+            output_hidden_states,
+            return_dict,
+            training,
         )

         return outputs
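From a caller's perspective the refactor is behavior-preserving: the decorator accepts the same calling conventions `input_processing` did, so existing user code keeps working. A usage sketch, assuming the `unc-nlp/lxmert-base-uncased` checkpoint; the 36 regions, 2048-d RoI features, and 4-d normalized box coordinates are illustrative assumptions based on LXMERT's default dimensions:

```python
import tensorflow as tf
from transformers import LxmertTokenizer, TFLxmertModel

tokenizer = LxmertTokenizer.from_pretrained("unc-nlp/lxmert-base-uncased")
model = TFLxmertModel.from_pretrained("unc-nlp/lxmert-base-uncased")

inputs = tokenizer("Who is eating the apple?", return_tensors="tf")
# Stand-in visual inputs: batch of 1, 36 regions, 2048-d features and
# 4-d normalized box coordinates (real inputs come from an object detector).
visual_feats = tf.random.uniform((1, 36, 2048))
visual_pos = tf.random.uniform((1, 36, 4))

# After this commit, @unpack_inputs performs the argument normalization
# that input_processing used to do inside `call`; the public API is unchanged.
outputs = model(**inputs, visual_feats=visual_feats, visual_pos=visual_pos)
print(outputs.pooled_output.shape)  # (1, 768)
```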
@@ -1298,6 +1262,7 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
         warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
         return self.name + "/" + self.cls.name + "/" + self.cls.predictions.name

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFLxmertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
     def call(
@@ -1339,38 +1304,19 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
         Returns:
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            visual_feats=visual_feats,
-            visual_pos=visual_pos,
-            attention_mask=attention_mask,
-            visual_attention_mask=visual_attention_mask,
-            token_type_ids=token_type_ids,
-            inputs_embeds=inputs_embeds,
-            masked_lm_labels=masked_lm_labels,
-            obj_labels=obj_labels,
-            matched_label=matched_label,
-            ans=ans,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
         lxmert_output = self.lxmert(
-            input_ids=inputs["input_ids"],
-            visual_feats=inputs["visual_feats"],
-            visual_pos=inputs["visual_pos"],
-            attention_mask=inputs["attention_mask"],
-            visual_attention_mask=inputs["visual_attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            input_ids,
+            visual_feats,
+            visual_pos,
+            attention_mask,
+            visual_attention_mask,
+            token_type_ids,
+            inputs_embeds,
+            output_attentions,
+            output_hidden_states,
+            return_dict,
+            training,
         )

         lang_output, visual_output, pooled_output = (
@@ -1386,34 +1332,29 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
         total_loss = (
             None
-            if (
-                inputs["masked_lm_labels"] is None
-                and inputs["matched_label"] is None
-                and inputs["obj_labels"] is None
-                and inputs["ans"] is None
-            )
+            if (masked_lm_labels is None and matched_label is None and obj_labels is None and ans is None)
             else tf.constant(0.0)
         )
         losses = ()
-        if inputs["masked_lm_labels"] is not None and self.task_mask_lm:
+        if masked_lm_labels is not None and self.task_mask_lm:
             masked_lm_loss = self.loss_fcts["ce"](
-                tf.reshape(inputs["masked_lm_labels"], [-1]),
+                tf.reshape(masked_lm_labels, [-1]),
                 tf.reshape(lang_prediction_scores, [-1, self.config.vocab_size]),
             )
             total_loss += masked_lm_loss
             losses += (masked_lm_loss,)
-        if inputs["matched_label"] is not None and self.task_matched:
+        if matched_label is not None and self.task_matched:
             matched_loss = self.loss_fcts["ce"](
-                tf.reshape(inputs["matched_label"], [-1]),
+                tf.reshape(matched_label, [-1]),
                 tf.reshape(cross_relationship_score, [-1, 2]),
             )
             total_loss += matched_loss
             losses += (matched_loss,)
-        if inputs["obj_labels"] is not None and self.task_obj_predict:
+        if obj_labels is not None and self.task_obj_predict:
             total_visn_loss = 0.0
             visn_prediction_scores_dict = self.obj_predict_head(visual_output)
             for key, key_info in self.visual_losses.items():
-                label, mask_conf = inputs["obj_labels"][key]
+                label, mask_conf = obj_labels[key]
                 output_dim = key_info["num"]
                 loss_fct_name = key_info["loss"]
                 label_shape = key_info["shape"]
@@ -1431,7 +1372,7 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
                 total_visn_loss += visn_loss
                 losses += (visn_loss,)
             total_loss += total_visn_loss
-        if inputs["ans"] is not None and self.task_qa:
+        if ans is not None and self.task_qa:
             answer_loss = self.loss_fcts["ce"](
                 tf.reshape(ans, [-1]), tf.reshape(answer_score, [-1, self.num_qa_labels])
             )
@@ -1444,7 +1385,7 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
             losses += (answer_loss,)
         # return total_loss, tf.stack(losses)[tf.new_axis, ...], answer_score.detach()

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (
                 lang_prediction_scores,
                 cross_relationship_score,
...
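The pretraining hunks above only swap `inputs["..."]` lookups for plain locals, but they preserve a flag-gated multi-task loss: `total_loss` stays `None` when no labels are supplied, otherwise it starts at 0.0 and each enabled task (masked LM, cross-modality matching, object prediction, QA) adds its term. A stripped-down sketch of that accumulation pattern, with toy loss values rather than LXMERT's real loss heads; `combine_pretraining_losses` is a hypothetical helper:

```python
import tensorflow as tf


def combine_pretraining_losses(task_losses):
    """Flag-gated multi-task sum mirroring the diffed control flow:
    `task_losses` maps task name -> loss tensor, or None when the task's
    labels were not supplied."""
    # No labels at all -> total stays None, matching the diff.
    if all(loss is None for loss in task_losses.values()):
        return None, ()

    total_loss = tf.constant(0.0)
    losses = ()
    for loss in task_losses.values():
        if loss is not None:  # Only tasks with labels contribute.
            total_loss += loss
            losses += (loss,)
    return total_loss, losses


total, parts = combine_pretraining_losses(
    {"masked_lm": tf.constant(2.3), "matched": tf.constant(0.7), "qa": None}
)
print(float(total))  # 3.0
print(len(parts))    # 2
```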