Unverified Commit 29d49924 authored by Julien Plu's avatar Julien Plu Committed by GitHub
Browse files

New TF model inputs (#8602)

* Apply on BERT and ALBERT

* Update TF Bart

* Add input processing to TF BART

* Add input processing for TF CTRL

* Add input processing to TF Distilbert

* Add input processing to TF DPR

* Add input processing to TF Electra

* Add input processing for TF Flaubert

* Add deprecated arguments

* Add input processing to TF XLM

* remove unused imports

* Add input processing to TF Funnel

* Add input processing to TF GPT2

* Add input processing to TF Longformer

* Add input processing to TF Lxmert

* Apply style

* Add input processing to TF Mobilebert

* Add input processing to TF GPT

* Add input processing to TF Roberta

* Add input processing to TF T5

* Add input processing to TF TransfoXL

* Apply style

* Rebase on master

* Bug fix

* Retry to bugfix

* Retry bug fix

* Fix wrong model name

* Try another fix

* Fix BART

* Fix input precessing

* Apply style

* Put the deprecated warnings in the input processing function

* Remove the unused imports

* Raise an error when len(kwargs)>0

* test ModelOutput instead of TFBaseModelOutput

* Bug fix

* Address Patrick's comments

* Address Patrick's comments

* Address Sylvain's comments

* Add the new inputs in new Longformer models

* Update the template with the new input processing

* Remove useless assert

* Apply style

* Trigger CI
parent 82d443a7
......@@ -34,7 +34,7 @@ class TFGenerationMixin:
Implement in subclasses of :class:`~transformers.TFPreTrainedModel` for custom behavior to prepare inputs in
the generate method.
"""
return {"inputs": inputs}
return {"input_ids": inputs}
def _use_cache(self, outputs, use_cache):
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
......
......@@ -14,7 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF general model utils."""
import functools
import inspect
import os
import re
import warnings
......@@ -27,8 +29,17 @@ from tensorflow.python.keras import backend as K
from tensorflow.python.keras.saving import hdf5_format
from .configuration_utils import PretrainedConfig
from .file_utils import DUMMY_INPUTS, TF2_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url
from .file_utils import (
DUMMY_INPUTS,
TF2_WEIGHTS_NAME,
WEIGHTS_NAME,
ModelOutput,
cached_path,
hf_bucket_url,
is_remote_url,
)
from .generation_tf_utils import TFGenerationMixin
from .tokenization_utils_base import BatchEncoding
from .utils import logging
......@@ -236,6 +247,110 @@ class TFNextSentencePredictionLoss:
return loss_fn(next_sentence_label, next_sentence_reduced_logits)
def input_processing(func, input_ids, **kwargs):
signature = dict(inspect.signature(func).parameters)
signature.pop("kwargs", None)
parameter_names = list(signature.keys())
output = {}
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict)
if "inputs" in kwargs["kwargs_call"]:
warnings.warn(
"The `inputs` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
FutureWarning,
)
output["input_ids"] = kwargs["kwargs_call"].pop("inputs")
if "decoder_cached_states" in kwargs["kwargs_call"]:
warnings.warn(
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
output["past_key_values"] = kwargs["kwargs_call"].pop("decoder_cached_states")
if len(kwargs["kwargs_call"]) > 0:
raise ValueError(
f"The following keyword arguments are not supported by this model: {list(kwargs['kwargs_call'].keys())}."
)
for k, v in kwargs.items():
if isinstance(v, allowed_types) or v is None:
output[k] = v
else:
raise ValueError(f"Data of type {type(v)} is not allowed only tf.Tensor is accepted for {k}.")
if isinstance(input_ids, (tuple, list)):
for i, input in enumerate(input_ids):
# EagerTensors don't allow to use the .name property so we check for a real Tensor
if type(input) == tf.Tensor:
# Tensor names have always the pattern name:device_id then we check only the
# name and not the device id
tensor_name = input.name.split(":")[0]
if tensor_name in parameter_names:
output[tensor_name] = input
else:
raise ValueError(
f"The tensor named {input.name} does not belong to the authorized list of names {parameter_names}."
)
elif isinstance(input, allowed_types) or input is None:
output[parameter_names[i]] = input
else:
raise ValueError(
f"Data of type {type(input)} is not allowed only tf.Tensor is accepted for {parameter_names[i]}."
)
elif isinstance(input_ids, (dict, BatchEncoding)):
if "inputs" in input_ids:
warnings.warn(
"The `inputs` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
FutureWarning,
)
output["input_ids"] = input_ids.pop("inputs")
if "decoder_cached_states" in input_ids:
warnings.warn(
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
output["past_key_values"] = input_ids.pop("decoder_cached_states")
for k, v in dict(input_ids).items():
if not isinstance(v, allowed_types):
raise ValueError(f"Data of type {type(v)} is not allowed only tf.Tensor is accepted for {k}.")
else:
output[k] = v
else:
if isinstance(input_ids, tf.Tensor) or input_ids is None:
output[parameter_names[0]] = input_ids
else:
raise ValueError(
f"Data of type {type(input_ids)} is not allowed only tf.Tensor is accepted for {parameter_names[0]}."
)
for name in parameter_names:
if name not in list(output.keys()) and name != "args":
output[name] = kwargs.pop(name, signature[name].default)
# When creating a SavedModel TF calls the method with LayerCall.__call__(args, **kwargs)
# So to respect the proper output we have to add this exception
if "args" in output:
if output["args"] is not None and type(output["args"]) == tf.Tensor:
tensor_name = output["args"].name.split(":")[0]
output[tensor_name] = output["args"]
else:
# `args` in this case is always the first parameter, then `input_ids`
output["input_ids"] = output["args"]
del output["args"]
if "kwargs" in output:
del output["kwargs"]
return output
def load_tf_weights(model, resolved_archive_file):
"""
Detect missing and unexpected layers and load the TF weights accordingly to their names and shapes.
......@@ -385,6 +500,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
:obj:`tf.keras.layers.Layer`: A torch module mapping vocabulary to hidden states.
"""
base_model = getattr(self, self.base_model_prefix, self)
if base_model is not self:
return base_model.get_input_embeddings()
else:
......@@ -1047,8 +1163,13 @@ def shape_list(tensor: tf.Tensor) -> List[int]:
Returns:
:obj:`List[int]`: The shape of the tensor as a list.
"""
static = tensor.shape.as_list()
dynamic = tf.shape(tensor)
if tensor.shape == tf.TensorShape(None):
return dynamic.as_list()
static = tensor.shape.as_list()
return [dynamic[i] if s is None else s for i, s in enumerate(static)]
......
......@@ -14,16 +14,14 @@
# limitations under the License.
"""TF BlenderBot model, ported from the fairseq repo."""
from ...file_utils import add_start_docstrings, is_tf_available
import tensorflow as tf
from ...file_utils import add_start_docstrings
from ...utils import logging
from ..bart.modeling_tf_bart import BART_START_DOCSTRING, LARGE_NEGATIVE, TFBartForConditionalGeneration
from .configuration_blenderbot import BlenderbotConfig
if is_tf_available():
import tensorflow as tf
_CONFIG_FOR_DOC = "BlenderbotConfig"
START_DOCSTRING = BART_START_DOCSTRING.replace(
......
......@@ -15,7 +15,6 @@
# limitations under the License.
""" TF 2.0 CTRL model."""
import numpy as np
import tensorflow as tf
......@@ -25,10 +24,10 @@ from ...modeling_tf_utils import (
TFCausalLanguageModelingLoss,
TFPreTrainedModel,
TFSharedEmbeddings,
input_processing,
keras_serializable,
shape_list,
)
from ...tokenization_utils import BatchEncoding
from ...utils import logging
from .configuration_ctrl import CTRLConfig
......@@ -252,7 +251,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
def call(
self,
inputs,
input_ids=None,
past=None,
attention_mask=None,
token_type_ids=None,
......@@ -264,79 +263,72 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
past = inputs[1] if len(inputs) > 1 else past
attention_mask = inputs[2] if len(inputs) > 2 else attention_mask
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
position_ids = inputs[4] if len(inputs) > 4 else position_ids
head_mask = inputs[5] if len(inputs) > 5 else head_mask
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
use_cache = inputs[7] if len(inputs) > 7 else use_cache
output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states
return_dict = inputs[10] if len(inputs) > 10 else return_dict
assert len(inputs) <= 11, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
past = inputs.get("past", past)
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 11, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
use_cache = use_cache if use_cache is not None else self.use_cache
return_dict = return_dict if return_dict is not None else self.return_dict
inputs = input_processing(
func=self.call,
input_ids=input_ids,
past=past,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.use_cache
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
# If using past key value states, only the last tokens
# should be given as an input
if past is not None:
if input_ids is not None:
input_ids = input_ids[:, -1:]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[:, -1:]
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1:]
if input_ids is not None and inputs_embeds is not None:
if inputs["past"] is not None:
if inputs["input_ids"] is not None:
inputs["input_ids"] = inputs["input_ids"][:, -1:]
if inputs["inputs_embeds"] is not None:
inputs["inputs_embeds"] = inputs["inputs_embeds"][:, -1:]
if inputs["token_type_ids"] is not None:
inputs["token_type_ids"] = inputs["token_type_ids"][:, -1:]
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
inputs["input_ids"] = tf.reshape(inputs["input_ids"], [-1, input_shape[-1]])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if past is None:
if inputs["past"] is None:
past_length = 0
past = [None] * len(self.h)
inputs["past"] = [None] * len(self.h)
else:
past_length = shape_list(past[0][0])[-2]
if position_ids is None:
position_ids = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
position_ids = tf.tile(position_ids, [input_shape[0], 1])
past_length = shape_list(inputs["past"][0][0])[-2]
if inputs["position_ids"] is None:
inputs["position_ids"] = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[
tf.newaxis, :
]
inputs["position_ids"] = tf.tile(inputs["position_ids"], [input_shape[0], 1])
# Attention mask.
if attention_mask is not None:
if inputs["attention_mask"] is not None:
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
inputs["attention_mask"] = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
......@@ -344,61 +336,63 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = tf.cast(attention_mask, tf.float32)
attention_mask = (1.0 - attention_mask) * -10000.0
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], tf.float32)
inputs["attention_mask"] = (1.0 - inputs["attention_mask"]) * -10000.0
else:
attention_mask = None
inputs["attention_mask"] = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
if head_mask is not None:
if inputs["head_mask"] is not None:
raise NotImplementedError
else:
head_mask = [None] * self.num_layers
inputs["head_mask"] = [None] * self.num_layers
if token_type_ids is not None:
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
token_type_embeds = self.w(token_type_ids, mode="embedding")
if inputs["token_type_ids"] is not None:
inputs["token_type_ids"] = tf.reshape(
inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]]
)
token_type_embeds = self.w(inputs["token_type_ids"], mode="embedding")
token_type_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))
else:
token_type_embeds = 0
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
inputs["position_ids"] = tf.reshape(inputs["position_ids"], [-1, shape_list(inputs["position_ids"])[-1]])
if inputs_embeds is None:
inputs_embeds = self.w(input_ids, mode="embedding")
if inputs["inputs_embeds"] is None:
inputs["inputs_embeds"] = self.w(inputs["input_ids"], mode="embedding")
seq_len = input_shape[-1]
mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
inputs_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))
inputs["inputs_embeds"] *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))
pos_embeds = tf.gather(self.pos_encoding, position_ids)
pos_embeds = tf.gather(self.pos_encoding, inputs["position_ids"])
hidden_states = inputs_embeds + pos_embeds + token_type_embeds
hidden_states = inputs["inputs_embeds"] + pos_embeds + token_type_embeds
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = self.dropout(hidden_states, training=inputs["training"])
output_shape = input_shape + [shape_list(hidden_states)[-1]]
presents = () if use_cache else None
presents = () if inputs["use_cache"] else None
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
for i, (h, layer_past) in enumerate(zip(self.h, past)):
for i, (h, layer_past) in enumerate(zip(self.h, inputs["past"])):
if output_hidden_states:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = h(
hidden_states,
mask,
layer_past,
attention_mask,
head_mask[i],
use_cache,
inputs["attention_mask"],
inputs["head_mask"][i],
inputs["use_cache"],
output_attentions,
training=training,
training=inputs["training"],
)
hidden_states, present = outputs[:2]
if use_cache:
if inputs["use_cache"]:
presents = presents + (present,)
if output_attentions:
......@@ -554,8 +548,52 @@ class TFCTRLModel(TFCTRLPreTrainedModel):
output_type=TFBaseModelOutputWithPast,
config_class=_CONFIG_FOR_DOC,
)
def call(self, inputs, **kwargs):
outputs = self.transformer(inputs, **kwargs)
def call(
self,
input_ids=None,
past=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
past=past,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.transformer(
input_ids=inputs["input_ids"],
past=inputs["past"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
......@@ -600,7 +638,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
if past:
inputs = tf.expand_dims(inputs[:, -1], -1)
return {"inputs": inputs, "past": past, "use_cache": kwargs["use_cache"]}
return {"input_ids": inputs, "past": past, "use_cache": kwargs["use_cache"]}
@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
......@@ -611,7 +649,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
)
def call(
self,
inputs,
input_ids=None,
past=None,
attention_mask=None,
token_type_ids=None,
......@@ -624,22 +662,16 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
config.vocab_size - 1]``.
"""
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[11] if len(inputs) > 11 else labels
if len(inputs) > 11:
inputs = inputs[:11]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
past=past,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
......@@ -650,7 +682,24 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
past=inputs["past"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
hidden_states = transformer_outputs[0]
......@@ -658,10 +707,10 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
if inputs["labels"] is not None:
# shift labels to the left and cut last logit token
logits = logits[:, :-1]
labels = labels[:, 1:]
labels = inputs["labels"][:, 1:]
loss = self.compute_loss(labels, logits)
if not return_dict:
......
......@@ -22,8 +22,7 @@ from typing import Optional, Tuple
import tensorflow as tf
from transformers.activations_tf import get_tf_activation
from ...activations_tf import get_tf_activation
from ...file_utils import (
ModelOutput,
add_code_sample_docstrings,
......@@ -31,8 +30,14 @@ from ...file_utils import (
add_start_docstrings_to_model_forward,
)
from ...modeling_tf_outputs import TFBaseModelOutput
from ...modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, get_initializer, keras_serializable, shape_list
from ...tokenization_utils import BatchEncoding
from ...modeling_tf_utils import (
TFPreTrainedModel,
TFSharedEmbeddings,
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...utils import logging
from ..xlm.modeling_tf_xlm import (
TFXLMForMultipleChoice,
......@@ -229,8 +234,56 @@ class TFFlaubertModel(TFFlaubertPreTrainedModel):
output_type=TFBaseModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(self, inputs, **kwargs):
outputs = self.transformer(inputs, **kwargs)
def call(
self,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
position_ids=None,
lengths=None,
cache=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
langs=inputs["langs"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
lengths=inputs["lengths"],
cache=inputs["cache"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
......@@ -351,7 +404,7 @@ class TFFlaubertTransformerFFN(tf.keras.layers.Layer):
class TFFlaubertMainLayer(tf.keras.layers.Layer):
config_class = FlaubertConfig
def __init__(self, config, *inputs, **kwargs):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.n_heads = config.n_heads
......@@ -417,7 +470,7 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
......@@ -430,64 +483,57 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
# removed: src_enc=None, src_len=None
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
langs = inputs[2] if len(inputs) > 2 else langs
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
position_ids = inputs[4] if len(inputs) > 4 else position_ids
lengths = inputs[5] if len(inputs) > 5 else lengths
cache = inputs[6] if len(inputs) > 6 else cache
head_mask = inputs[7] if len(inputs) > 7 else head_mask
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
output_attentions = inputs[9] if len(inputs) > 9 else output_attentions
output_hidden_states = inputs[10] if len(inputs) > 10 else output_hidden_states
return_dict = inputs[11] if len(inputs) > 11 else return_dict
assert len(inputs) <= 12, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
langs = inputs.get("langs", langs)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
lengths = inputs.get("lengths", lengths)
cache = inputs.get("cache", cache)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 12, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
bs, slen = shape_list(input_ids)
elif inputs_embeds is not None:
bs, slen = shape_list(inputs_embeds)[:2]
elif inputs["input_ids"] is not None:
bs, slen = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
bs, slen = shape_list(inputs["inputs_embeds"])[:2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if lengths is None:
if input_ids is not None:
lengths = tf.reduce_sum(tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=tf.int32), axis=1)
if inputs["lengths"] is None:
if inputs["input_ids"] is not None:
inputs["lengths"] = tf.reduce_sum(
tf.cast(tf.not_equal(inputs["input_ids"], self.pad_index), dtype=tf.int32), axis=1
)
else:
lengths = tf.convert_to_tensor([slen] * bs, tf.int32)
inputs["lengths"] = tf.convert_to_tensor([slen] * bs, tf.int32)
# mask = input_ids != self.pad_index
# check inputs
# assert shape_list(lengths)[0] == bs
tf.debugging.assert_equal(
shape_list(lengths)[0], bs
), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
shape_list(inputs["lengths"])[0], bs
), f"Expected batch size {shape_list(inputs['lengths'])[0]} and received batch size {bs} mismatched"
# assert lengths.max().item() <= slen
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
# assert (src_enc is None) == (src_len is None)
......@@ -496,26 +542,26 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
# assert src_enc.size(0) == bs
# generate masks
mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask)
mask, attn_mask = get_masks(slen, inputs["lengths"], self.causal, padding_mask=inputs["attention_mask"])
# if self.is_decoder and src_enc is not None:
# src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
# position_ids
if position_ids is None:
position_ids = tf.expand_dims(tf.range(slen), axis=0)
if inputs["position_ids"] is None:
inputs["position_ids"] = tf.expand_dims(tf.range(slen), axis=0)
else:
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
tf.debugging.assert_equal(
shape_list(position_ids), [bs, slen]
), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
shape_list(inputs["position_ids"]), [bs, slen]
), f"Position id shape {shape_list(inputs['position_ids'])} and input shape {[bs, slen]} mismatched"
# position_ids = position_ids.transpose(0, 1)
# langs
if langs is not None:
if inputs["langs"] is not None:
# assert shape_list(langs) == [bs, slen] # (slen, bs)
tf.debugging.assert_equal(
shape_list(langs), [bs, slen]
), f"Lang shape {shape_list(langs)} and input shape {[bs, slen]} mismatched"
shape_list(inputs["langs"]), [bs, slen]
), f"Lang shape {shape_list(inputs['langs'])} and input shape {[bs, slen]} mismatched"
# langs = langs.transpose(0, 1)
# Prepare head mask if needed
......@@ -523,34 +569,34 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen]
if head_mask is not None:
if inputs["head_mask"] is not None:
raise NotImplementedError
else:
head_mask = [None] * self.n_layers
inputs["head_mask"] = [None] * self.n_layers
# do not recompute cached elements
if cache is not None and input_ids is not None:
_slen = slen - cache["slen"]
input_ids = input_ids[:, -_slen:]
position_ids = position_ids[:, -_slen:]
if langs is not None:
langs = langs[:, -_slen:]
if inputs["cache"] is not None and inputs["input_ids"] is not None:
_slen = slen - inputs["cache"]["slen"]
inputs["input_ids"] = inputs["input_ids"][:, -_slen:]
inputs["position_ids"] = inputs["position_ids"][:, -_slen:]
if inputs["langs"] is not None:
inputs["langs"] = inputs["langs"][:, -_slen:]
mask = mask[:, -_slen:]
attn_mask = attn_mask[:, -_slen:]
# embeddings
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids)
if inputs["inputs_embeds"] is None:
inputs["inputs_embeds"] = self.embeddings(inputs["input_ids"])
tensor = inputs_embeds + self.position_embeddings(position_ids)
tensor = inputs["inputs_embeds"] + self.position_embeddings(inputs["position_ids"])
if langs is not None and self.use_lang_emb:
tensor = tensor + self.lang_embeddings(langs)
if token_type_ids is not None:
tensor = tensor + self.embeddings(token_type_ids)
if inputs["langs"] is not None and self.use_lang_emb:
tensor = tensor + self.lang_embeddings(inputs["langs"])
if inputs["token_type_ids"] is not None:
tensor = tensor + self.embeddings(inputs["token_type_ids"])
tensor = self.layer_norm_emb(tensor)
tensor = self.dropout(tensor, training=training)
tensor = self.dropout(tensor, training=inputs["training"])
tensor = tensor * mask[..., tf.newaxis]
# hidden_states and attentions cannot be None in graph mode.
......@@ -562,7 +608,7 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
# LayerDrop
dropout_probability = tf.random.uniform([1], 0, 1)
if training and tf.less(dropout_probability, self.layerdrop):
if inputs["training"] and tf.less(dropout_probability, self.layerdrop):
continue
if output_hidden_states:
......@@ -571,27 +617,39 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
# self attention
if not self.pre_norm:
attn_outputs = self.attentions[i](
tensor, attn_mask, None, cache, head_mask[i], output_attentions, training=training
tensor,
attn_mask,
None,
inputs["cache"],
inputs["head_mask"][i],
output_attentions,
training=inputs["training"],
)
attn = attn_outputs[0]
if output_attentions:
attentions = attentions + (attn_outputs[1],)
attn = self.dropout(attn, training=training)
attn = self.dropout(attn, training=inputs["training"])
tensor = tensor + attn
tensor = self.layer_norm1[i](tensor)
else:
tensor_normalized = self.layer_norm1[i](tensor)
attn_outputs = self.attentions[i](
tensor_normalized, attn_mask, None, cache, head_mask[i], output_attentions, training=training
tensor_normalized,
attn_mask,
None,
inputs["cache"],
inputs["head_mask"][i],
output_attentions,
training=inputs["training"],
)
attn = attn_outputs[0]
if output_attentions:
attentions = attentions + (attn_outputs[1],)
attn = self.dropout(attn, training=training)
attn = self.dropout(attn, training=inputs["training"])
tensor = tensor + attn
# encoder attention (for decoder only)
......@@ -616,8 +674,8 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
hidden_states = hidden_states + (tensor,)
# update cache length
if cache is not None:
cache["slen"] += tensor.size(1)
if inputs["cache"] is not None:
inputs["cache"]["slen"] += tensor.size(1)
# move back sequence length to dimension 0
# tensor = tensor.transpose(0, 1)
......@@ -724,7 +782,7 @@ class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel):
langs = tf.ones_like(inputs) * lang_id
else:
langs = None
return {"inputs": inputs, "langs": langs}
return {"input_ids": inputs, "langs": langs}
@add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
......@@ -733,11 +791,56 @@ class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel):
output_type=TFFlaubertWithLMHeadModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(self, inputs, **kwargs):
return_dict = kwargs.get("return_dict")
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
transformer_outputs = self.transformer(inputs, **kwargs)
def call(
self,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
position_ids=None,
lengths=None,
cache=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
langs=inputs["langs"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
lengths=inputs["lengths"],
cache=inputs["cache"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
output = transformer_outputs[0]
outputs = self.pred_layer(output)
......
......@@ -16,7 +16,6 @@
# limitations under the License.
""" TF 2.0 LXMERT model. """
from dataclasses import dataclass
from typing import Dict, Optional, Tuple
......@@ -30,8 +29,7 @@ from ...file_utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
from ...tokenization_utils_base import BatchEncoding
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, input_processing, keras_serializable, shape_list
from ...utils import logging
from .configuration_lxmert import LxmertConfig
......@@ -716,7 +714,7 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
def call(
self,
inputs,
input_ids=None,
visual_feats=None,
visual_pos=None,
attention_mask=None,
......@@ -727,60 +725,55 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
visual_feats = inputs[1] if len(inputs) > 1 else visual_feats
visual_pos = inputs[2] if len(inputs) > 2 else visual_pos
attention_mask = inputs[3] if len(inputs) > 3 else attention_mask
visual_attention_mask = inputs[4] if len(inputs) > 4 else visual_attention_mask
token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
output_attentions = inputs[7] if len(inputs) > 7 else output_attentions
output_hidden_states = inputs[8] if len(inputs) > 8 else output_hidden_states
return_dict = inputs[9] if len(inputs) > 9 else return_dict
assert len(inputs) <= 10, "Too many inputs."
elif isinstance(inputs, dict):
input_ids = inputs.get("input_ids")
visual_feats = inputs.get("visual_feats", visual_feats)
visual_pos = inputs.get("visual_pos", visual_pos)
attention_mask = inputs.get("attention_mask", attention_mask)
visual_attention_mask = inputs.get("visual_attention_mask", visual_attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 10, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
inputs = input_processing(
func=self.call,
input_ids=input_ids,
visual_feats=visual_feats,
visual_pos=visual_pos,
attention_mask=attention_mask,
visual_attention_mask=visual_attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if visual_pos is None or visual_feats is None:
if inputs["visual_pos"] is None or inputs["visual_feats"] is None:
raise ValueError("visual_feats and visual_pos cannot be `None` in LXMERT's `call` method.")
if attention_mask is None:
attention_mask = tf.fill(input_shape, 1)
if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0)
if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.fill(input_shape, 1)
if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.fill(input_shape, 0)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
......@@ -791,8 +784,8 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
if visual_attention_mask is not None:
extended_visual_attention_mask = visual_attention_mask[:, tf.newaxis, tf.newaxis, :]
if inputs["visual_attention_mask"] is not None:
extended_visual_attention_mask = inputs["visual_attention_mask"][:, tf.newaxis, tf.newaxis, :]
extended_visual_attention_mask = tf.cast(extended_visual_attention_mask, tf.float32)
extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * -10000.0
......@@ -800,17 +793,19 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
extended_visual_attention_mask = None
# Positional Word Embeddings
embedding_output = self.embeddings([input_ids, token_type_ids, inputs_embeds], training=training)
embedding_output = self.embeddings(
[inputs["input_ids"], inputs["token_type_ids"], inputs["inputs_embeds"]], training=inputs["training"]
)
# Run Lxmert encoder
encoder_outputs = self.encoder(
embedding_output,
extended_attention_mask,
visual_feats,
visual_pos,
inputs["visual_feats"],
inputs["visual_pos"],
extended_visual_attention_mask,
output_attentions=output_attentions,
training=training,
training=inputs["training"],
)
visual_encoder_outputs, lang_encoder_outputs = encoder_outputs[:2]
vision_hidden_states = visual_encoder_outputs[0]
......@@ -977,8 +972,50 @@ class TFLxmertModel(TFLxmertPreTrainedModel):
output_type=TFLxmertModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(self, inputs, *args, **kwargs):
outputs = self.lxmert(inputs, *args, **kwargs)
def call(
self,
input_ids=None,
visual_feats=None,
visual_pos=None,
attention_mask=None,
visual_attention_mask=None,
token_type_ids=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
visual_feats=visual_feats,
visual_pos=visual_pos,
attention_mask=attention_mask,
visual_attention_mask=visual_attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.lxmert(
input_ids=inputs["input_ids"],
visual_feats=inputs["visual_feats"],
visual_pos=inputs["visual_pos"],
attention_mask=inputs["attention_mask"],
visual_attention_mask=inputs["visual_attention_mask"],
token_type_ids=inputs["token_type_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
......@@ -1228,7 +1265,7 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
@replace_return_docstrings(output_type=TFLxmertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
inputs=None,
input_ids=None,
visual_feats=None,
visual_pos=None,
attention_mask=None,
......@@ -1242,6 +1279,8 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
r"""
masked_lm_labels (``tf.Tensor`` of shape ``(batch_size, sequence_length)``, `optional`):
......@@ -1263,31 +1302,38 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
Returns:
"""
if isinstance(inputs, (tuple, list)):
masked_lm_labels = inputs[7] if len(inputs) > 7 else masked_lm_labels
obj_labels = inputs[8] if len(inputs) > 8 else obj_labels
matched_label = inputs[9] if len(inputs) > 9 else matched_label
ans = inputs[10] if len(inputs) > 10 else ans
if len(inputs) > 10:
inputs = inputs[:10]
elif isinstance(inputs, (dict, BatchEncoding)):
masked_lm_labels = inputs.pop("masked_lm_labels", masked_lm_labels)
obj_labels = inputs.pop("obj_labels", obj_labels)
matched_label = inputs.pop("matched_label", matched_label)
ans = inputs.pop("ans", ans)
return_dict = return_dict if return_dict is not None else self.lxmert.return_dict
lxmert_output = self.lxmert(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
visual_feats=visual_feats,
visual_pos=visual_pos,
attention_mask=attention_mask,
visual_attention_mask=visual_attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
masked_lm_labels=masked_lm_labels,
obj_labels=obj_labels,
matched_label=matched_label,
ans=ans,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.lxmert.return_dict
lxmert_output = self.lxmert(
input_ids=inputs["input_ids"],
visual_feats=inputs["visual_feats"],
visual_pos=inputs["visual_pos"],
attention_mask=inputs["attention_mask"],
visual_attention_mask=inputs["visual_attention_mask"],
token_type_ids=inputs["token_type_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
lang_output, visual_output, pooled_output = (
......@@ -1303,29 +1349,34 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
total_loss = (
None
if (masked_lm_labels is None and matched_label is None and obj_labels is None and ans is None)
if (
inputs["masked_lm_labels"] is None
and inputs["matched_label"] is None
and inputs["obj_labels"] is None
and inputs["ans"] is None
)
else tf.constant(0.0)
)
losses = ()
if masked_lm_labels is not None and self.task_mask_lm:
if inputs["masked_lm_labels"] is not None and self.task_mask_lm:
masked_lm_loss = self.loss_fcts["ce"](
tf.reshape(masked_lm_labels, [-1]),
tf.reshape(inputs["masked_lm_labels"], [-1]),
tf.reshape(lang_prediction_scores, [-1, self.config.vocab_size]),
)
total_loss += masked_lm_loss
losses += (masked_lm_loss,)
if matched_label is not None and self.task_matched:
if inputs["matched_label"] is not None and self.task_matched:
matched_loss = self.loss_fcts["ce"](
tf.reshape(matched_label, [-1]),
tf.reshape(inputs["matched_label"], [-1]),
tf.reshape(cross_relationship_score, [-1, 2]),
)
total_loss += matched_loss
losses += (matched_loss,)
if obj_labels is not None and self.task_obj_predict:
if inputs["obj_labels"] is not None and self.task_obj_predict:
total_visn_loss = 0.0
visn_prediction_scores_dict = self.obj_predict_head(visual_output)
for key, key_info in self.visual_losses.items():
label, mask_conf = obj_labels[key]
label, mask_conf = inputs["obj_labels"][key]
output_dim = key_info["num"]
loss_fct_name = key_info["loss"]
label_shape = key_info["shape"]
......@@ -1343,7 +1394,7 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
total_visn_loss += visn_loss
losses += (visn_loss,)
total_loss += total_visn_loss
if ans is not None and self.task_qa:
if inputs["ans"] is not None and self.task_qa:
answer_loss = self.loss_fcts["ce"](
tf.reshape(ans, [-1]), tf.reshape(answer_score, [-1, self.num_qa_labels])
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment