Unverified commit 3dc82427, authored by Shamima, committed by GitHub

TF: removed inputs_processing and replaced with decorator in lxmert (#16414)
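
Every TF `call` body in this file used to begin with a long `input_processing(...)` call and then read each argument back out of the returned `inputs` dict. The `@unpack_inputs` decorator moves that normalization out of the method body: arguments arrive as plain keyword arguments, and the config-backed defaults for flags such as `output_attentions`, `output_hidden_states`, and `return_dict` are resolved before the body runs. A minimal sketch of the decorator pattern, as a hypothetical simplification (the real `unpack_inputs` in `modeling_tf_utils.py` handles more cases, e.g. a dict or tuple passed as the first positional argument):

    # Hypothetical, simplified sketch; not the actual implementation.
    import functools
    import inspect

    def unpack_inputs(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # Bind positional arguments to the parameter names of `call`,
            # so the body can use plain names (input_ids, attention_mask, ...)
            # instead of indexing into an `inputs` dict.
            params = list(inspect.signature(func).parameters)[1:]  # skip `self`
            kwargs.update(zip(params, args))
            # Fall back to the model config for output flags left as None.
            for flag in ("output_attentions", "output_hidden_states", "return_dict"):
                if flag in params and kwargs.get(flag) is None:
                    kwargs[flag] = getattr(self.config, flag, None)
            return func(self, **kwargs)
        return wrapper

Call sites are unchanged: `model(input_ids, visual_feats=..., visual_pos=...)` behaves exactly as before; only the boilerplate inside each `call` body goes away.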

parent b320d87e
@@ -23,7 +23,7 @@ from typing import Dict, Optional, Tuple
 import tensorflow as tf
 
 from ...activations_tf import get_tf_activation
-from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, input_processing, keras_serializable, shape_list
+from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list, unpack_inputs
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -671,6 +671,7 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
     def _prune_heads(self, heads_to_prune):
         raise NotImplementedError
 
+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -686,51 +687,33 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
         training=False,
         **kwargs,
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            visual_feats=visual_feats,
-            visual_pos=visual_pos,
-            attention_mask=attention_mask,
-            visual_attention_mask=visual_attention_mask,
-            token_type_ids=token_type_ids,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
-        if inputs["visual_pos"] is None or inputs["visual_feats"] is None:
+        if visual_pos is None or visual_feats is None:
             raise ValueError("visual_feats and visual_pos cannot be `None` in LXMERT's `call` method.")
 
-        if inputs["attention_mask"] is None:
-            inputs["attention_mask"] = tf.fill(input_shape, 1)
+        if attention_mask is None:
+            attention_mask = tf.fill(input_shape, 1)
 
-        if inputs["token_type_ids"] is None:
-            inputs["token_type_ids"] = tf.fill(input_shape, 0)
+        if token_type_ids is None:
+            token_type_ids = tf.fill(input_shape, 0)
 
         # Positional Word Embeddings
-        embedding_output = self.embeddings(
-            inputs["input_ids"], inputs["token_type_ids"], inputs["inputs_embeds"], training=inputs["training"]
-        )
+        embedding_output = self.embeddings(input_ids, token_type_ids, inputs_embeds, training)
 
         # We create a 3D attention mask from a 2D tensor mask.
         # Sizes are [batch_size, 1, 1, to_seq_length]
         # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
         # this attention mask is more simple than the triangular masking of causal attention
         # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-        extended_attention_mask = tf.reshape(inputs["attention_mask"], (input_shape[0], 1, 1, input_shape[1]))
+        extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1]))
 
         # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
         # masked positions, this operation will create a tensor which is 0.0 for
@@ -743,13 +726,9 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
         ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
         extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
 
-        if inputs["visual_attention_mask"] is not None:
-            extended_visual_attention_mask = tf.reshape(
-                inputs["visual_attention_mask"], (input_shape[0], 1, 1, input_shape[1])
-            )
-            extended_visual_attention_mask = tf.expand_dims(
-                tf.expand_dims(inputs["visual_attention_mask"], axis=1), axis=1
-            )
+        if visual_attention_mask is not None:
+            extended_visual_attention_mask = tf.reshape(visual_attention_mask, (input_shape[0], 1, 1, input_shape[1]))
+            extended_visual_attention_mask = tf.expand_dims(tf.expand_dims(visual_attention_mask, axis=1), axis=1)
             extended_visual_attention_mask = tf.cast(extended_visual_attention_mask, dtype=embedding_output.dtype)
             extended_visual_attention_mask = tf.multiply(
@@ -762,18 +741,18 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
         encoder_outputs = self.encoder(
             embedding_output,
             extended_attention_mask,
-            inputs["visual_feats"],
-            inputs["visual_pos"],
+            visual_feats,
+            visual_pos,
             extended_visual_attention_mask,
-            output_attentions=inputs["output_attentions"],
-            training=inputs["training"],
+            output_attentions,
+            training,
         )
         visual_encoder_outputs, lang_encoder_outputs = encoder_outputs[:2]
         vision_hidden_states = visual_encoder_outputs[0]
         language_hidden_states = lang_encoder_outputs[0]
 
         all_attentions = ()
-        if inputs["output_attentions"]:
+        if output_attentions:
             language_attentions = lang_encoder_outputs[1]
             vision_attentions = visual_encoder_outputs[1]
             cross_encoder_attentions = encoder_outputs[2]
@@ -783,24 +762,24 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
                 cross_encoder_attentions,
             )
 
-        hidden_states = (language_hidden_states, vision_hidden_states) if inputs["output_hidden_states"] else ()
+        hidden_states = (language_hidden_states, vision_hidden_states) if output_hidden_states else ()
         visual_output = vision_hidden_states[-1]
         lang_output = language_hidden_states[-1]
         pooled_output = self.pooler(lang_output)
 
-        if not inputs["return_dict"]:
+        if not return_dict:
             return (lang_output, visual_output, pooled_output) + hidden_states + all_attentions
 
         return TFLxmertModelOutput(
             pooled_output=pooled_output,
             language_output=lang_output,
             vision_output=visual_output,
-            language_hidden_states=language_hidden_states if inputs["output_hidden_states"] else None,
-            vision_hidden_states=vision_hidden_states if inputs["output_hidden_states"] else None,
-            language_attentions=language_attentions if inputs["output_attentions"] else None,
-            vision_attentions=vision_attentions if inputs["output_attentions"] else None,
-            cross_encoder_attentions=cross_encoder_attentions if inputs["output_attentions"] else None,
+            language_hidden_states=language_hidden_states if output_hidden_states else None,
+            vision_hidden_states=vision_hidden_states if output_hidden_states else None,
+            language_attentions=language_attentions if output_attentions else None,
+            vision_attentions=vision_attentions if output_attentions else None,
+            cross_encoder_attentions=cross_encoder_attentions if output_attentions else None,
         )
@@ -946,6 +925,7 @@ class TFLxmertModel(TFLxmertPreTrainedModel):
         super().__init__(config, *inputs, **kwargs)
         self.lxmert = TFLxmertMainLayer(config, name="lxmert")
 
+    @unpack_inputs
     @add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -968,34 +948,18 @@ class TFLxmertModel(TFLxmertPreTrainedModel):
         training=False,
         **kwargs,
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            visual_feats=visual_feats,
-            visual_pos=visual_pos,
-            attention_mask=attention_mask,
-            visual_attention_mask=visual_attention_mask,
-            token_type_ids=token_type_ids,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
         outputs = self.lxmert(
-            input_ids=inputs["input_ids"],
-            visual_feats=inputs["visual_feats"],
-            visual_pos=inputs["visual_pos"],
-            attention_mask=inputs["attention_mask"],
-            visual_attention_mask=inputs["visual_attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            input_ids,
+            visual_feats,
+            visual_pos,
+            attention_mask,
+            visual_attention_mask,
+            token_type_ids,
+            inputs_embeds,
+            output_attentions,
+            output_hidden_states,
+            return_dict,
+            training,
         )
 
         return outputs
@@ -1298,6 +1262,7 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
         warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
         return self.name + "/" + self.cls.name + "/" + self.cls.predictions.name
 
+    @unpack_inputs
     @add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFLxmertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
     def call(
@@ -1339,38 +1304,19 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
         Returns:
         """
 
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            visual_feats=visual_feats,
-            visual_pos=visual_pos,
-            attention_mask=attention_mask,
-            visual_attention_mask=visual_attention_mask,
-            token_type_ids=token_type_ids,
-            inputs_embeds=inputs_embeds,
-            masked_lm_labels=masked_lm_labels,
-            obj_labels=obj_labels,
-            matched_label=matched_label,
-            ans=ans,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
         lxmert_output = self.lxmert(
-            input_ids=inputs["input_ids"],
-            visual_feats=inputs["visual_feats"],
-            visual_pos=inputs["visual_pos"],
-            attention_mask=inputs["attention_mask"],
-            visual_attention_mask=inputs["visual_attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            input_ids,
+            visual_feats,
+            visual_pos,
+            attention_mask,
+            visual_attention_mask,
+            token_type_ids,
+            inputs_embeds,
+            output_attentions,
+            output_hidden_states,
+            return_dict,
+            training,
         )
 
         lang_output, visual_output, pooled_output = (
@@ -1386,34 +1332,29 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
         total_loss = (
             None
-            if (
-                inputs["masked_lm_labels"] is None
-                and inputs["matched_label"] is None
-                and inputs["obj_labels"] is None
-                and inputs["ans"] is None
-            )
+            if (masked_lm_labels is None and matched_label is None and obj_labels is None and ans is None)
             else tf.constant(0.0)
         )
         losses = ()
-        if inputs["masked_lm_labels"] is not None and self.task_mask_lm:
+        if masked_lm_labels is not None and self.task_mask_lm:
             masked_lm_loss = self.loss_fcts["ce"](
-                tf.reshape(inputs["masked_lm_labels"], [-1]),
+                tf.reshape(masked_lm_labels, [-1]),
                 tf.reshape(lang_prediction_scores, [-1, self.config.vocab_size]),
             )
             total_loss += masked_lm_loss
             losses += (masked_lm_loss,)
-        if inputs["matched_label"] is not None and self.task_matched:
+        if matched_label is not None and self.task_matched:
             matched_loss = self.loss_fcts["ce"](
-                tf.reshape(inputs["matched_label"], [-1]),
+                tf.reshape(matched_label, [-1]),
                 tf.reshape(cross_relationship_score, [-1, 2]),
             )
             total_loss += matched_loss
             losses += (matched_loss,)
-        if inputs["obj_labels"] is not None and self.task_obj_predict:
+        if obj_labels is not None and self.task_obj_predict:
             total_visn_loss = 0.0
             visn_prediction_scores_dict = self.obj_predict_head(visual_output)
             for key, key_info in self.visual_losses.items():
-                label, mask_conf = inputs["obj_labels"][key]
+                label, mask_conf = obj_labels[key]
                 output_dim = key_info["num"]
                 loss_fct_name = key_info["loss"]
                 label_shape = key_info["shape"]
@@ -1431,7 +1372,7 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
                 total_visn_loss += visn_loss
                 losses += (visn_loss,)
             total_loss += total_visn_loss
-        if inputs["ans"] is not None and self.task_qa:
+        if ans is not None and self.task_qa:
             answer_loss = self.loss_fcts["ce"](
                 tf.reshape(ans, [-1]), tf.reshape(answer_score, [-1, self.num_qa_labels])
             )
@@ -1444,7 +1385,7 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
             losses += (answer_loss,)
         # return total_loss, tf.stack(losses)[tf.new_axis, ...], answer_score.detach()
 
-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (
                 lang_prediction_scores,
                 cross_relationship_score,