Unverified commit 29d49924, authored by Julien Plu, committed by GitHub

New TF model inputs (#8602)

* Apply on BERT and ALBERT

* Update TF Bart

* Add input processing to TF BART

* Add input processing for TF CTRL

* Add input processing to TF Distilbert

* Add input processing to TF DPR

* Add input processing to TF Electra

* Add input processing for TF Flaubert

* Add deprecated arguments

* Add input processing to TF XLM

* remove unused imports

* Add input processing to TF Funnel

* Add input processing to TF GPT2

* Add input processing to TF Longformer

* Add input processing to TF Lxmert

* Apply style

* Add input processing to TF Mobilebert

* Add input processing to TF GPT

* Add input processing to TF Roberta

* Add input processing to TF T5

* Add input processing to TF TransfoXL

* Apply style

* Rebase on master

* Bug fix

* Retry to bugfix

* Retry bug fix

* Fix wrong model name

* Try another fix

* Fix BART

* Fix input processing

* Apply style

* Put the deprecated warnings in the input processing function

* Remove the unused imports

* Raise an error when len(kwargs)>0

* test ModelOutput instead of TFBaseModelOutput

* Bug fix

* Address Patrick's comments

* Address Patrick's comments

* Address Sylvain's comments

* Add the new inputs in new Longformer models

* Update the template with the new input processing

* Remove useless assert

* Apply style

* Trigger CI
parent 82d443a7
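The commits above converge on a single pattern: every TF model's `call()` now routes its arguments through a shared `input_processing` helper instead of hand-rolled tuple/dict parsing. As a minimal sketch (not part of this commit) of the three calling styles the helper normalizes — it assumes TensorFlow plus a `transformers` checkout containing this change, and the tiny config values are made up so no pretrained weights are needed:

```python
import tensorflow as tf
from transformers import BertConfig, TFBertModel

# Toy-sized, randomly initialized model so the snippet runs offline
model = TFBertModel(
    BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=1, num_attention_heads=2, intermediate_size=64)
)
ids = tf.constant([[1, 2, 3]])
mask = tf.constant([[1, 1, 1]])

out = model(input_ids=ids, attention_mask=mask)           # keyword style, the canonical form now
out = model({"input_ids": ids, "attention_mask": mask})   # dict / BatchEncoding style
out = model([ids, mask])                                  # list/tuple matched positionally against the signature
```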
--- a/src/transformers/generation_tf_utils.py
+++ b/src/transformers/generation_tf_utils.py
@@ -34,7 +34,7 @@ class TFGenerationMixin:
         Implement in subclasses of :class:`~transformers.TFPreTrainedModel` for custom behavior to prepare inputs in
         the generate method.
         """
-        return {"inputs": inputs}
+        return {"input_ids": inputs}

     def _use_cache(self, outputs, use_cache):
         """During generation, decide whether to pass the `past` variable to the next forward pass."""
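A quick hedged sketch of the rename above: models that do not override `prepare_inputs_for_generation` now key the prompt as `input_ids`, matching the new `call()` signatures. Toy config, no pretrained weights:

```python
import tensorflow as tf
from transformers import BertConfig, TFBertModel

model = TFBertModel(
    BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=1, num_attention_heads=2, intermediate_size=64)
)
# Base-class behavior from TFGenerationMixin after this commit
print(model.prepare_inputs_for_generation(tf.constant([[1, 2, 3]])))
# {'input_ids': <tf.Tensor ...>}
```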
--- a/src/transformers/modeling_tf_utils.py
+++ b/src/transformers/modeling_tf_utils.py
@@ -14,7 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """TF general model utils."""
+
 import functools
+import inspect
 import os
 import re
 import warnings

@@ -27,8 +29,17 @@ from tensorflow.python.keras import backend as K
 from tensorflow.python.keras.saving import hdf5_format

 from .configuration_utils import PretrainedConfig
-from .file_utils import DUMMY_INPUTS, TF2_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url
+from .file_utils import (
+    DUMMY_INPUTS,
+    TF2_WEIGHTS_NAME,
+    WEIGHTS_NAME,
+    ModelOutput,
+    cached_path,
+    hf_bucket_url,
+    is_remote_url,
+)
 from .generation_tf_utils import TFGenerationMixin
+from .tokenization_utils_base import BatchEncoding
 from .utils import logging
@@ -236,6 +247,110 @@ class TFNextSentencePredictionLoss:
         return loss_fn(next_sentence_label, next_sentence_reduced_logits)


+def input_processing(func, input_ids, **kwargs):
+    signature = dict(inspect.signature(func).parameters)
+    signature.pop("kwargs", None)
+    parameter_names = list(signature.keys())
+    output = {}
+    allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict)
+
+    if "inputs" in kwargs["kwargs_call"]:
+        warnings.warn(
+            "The `inputs` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
+            FutureWarning,
+        )
+
+        output["input_ids"] = kwargs["kwargs_call"].pop("inputs")
+
+    if "decoder_cached_states" in kwargs["kwargs_call"]:
+        warnings.warn(
+            "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
+            FutureWarning,
+        )
+        output["past_key_values"] = kwargs["kwargs_call"].pop("decoder_cached_states")
+
+    if len(kwargs["kwargs_call"]) > 0:
+        raise ValueError(
+            f"The following keyword arguments are not supported by this model: {list(kwargs['kwargs_call'].keys())}."
+        )
+
+    for k, v in kwargs.items():
+        if isinstance(v, allowed_types) or v is None:
+            output[k] = v
+        else:
+            raise ValueError(f"Data of type {type(v)} is not allowed, only tf.Tensor is accepted for {k}.")
+
+    if isinstance(input_ids, (tuple, list)):
+        for i, input in enumerate(input_ids):
+            # EagerTensors don't allow using the .name property, so we check for a real Tensor
+            if type(input) == tf.Tensor:
+                # Tensor names always follow the pattern `name:device_id`, so we check
+                # only the name, not the device id
+                tensor_name = input.name.split(":")[0]
+
+                if tensor_name in parameter_names:
+                    output[tensor_name] = input
+                else:
+                    raise ValueError(
+                        f"The tensor named {input.name} does not belong to the authorized list of names {parameter_names}."
+                    )
+            elif isinstance(input, allowed_types) or input is None:
+                output[parameter_names[i]] = input
+            else:
+                raise ValueError(
+                    f"Data of type {type(input)} is not allowed, only tf.Tensor is accepted for {parameter_names[i]}."
+                )
+    elif isinstance(input_ids, (dict, BatchEncoding)):
+        if "inputs" in input_ids:
+            warnings.warn(
+                "The `inputs` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
+                FutureWarning,
+            )
+
+            output["input_ids"] = input_ids.pop("inputs")
+
+        if "decoder_cached_states" in input_ids:
+            warnings.warn(
+                "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
+                FutureWarning,
+            )
+            output["past_key_values"] = input_ids.pop("decoder_cached_states")
+
+        for k, v in dict(input_ids).items():
+            if not isinstance(v, allowed_types):
+                raise ValueError(f"Data of type {type(v)} is not allowed, only tf.Tensor is accepted for {k}.")
+            else:
+                output[k] = v
+    else:
+        if isinstance(input_ids, tf.Tensor) or input_ids is None:
+            output[parameter_names[0]] = input_ids
+        else:
+            raise ValueError(
+                f"Data of type {type(input_ids)} is not allowed, only tf.Tensor is accepted for {parameter_names[0]}."
+            )
+
+    for name in parameter_names:
+        if name not in list(output.keys()) and name != "args":
+            output[name] = kwargs.pop(name, signature[name].default)
+
+    # When creating a SavedModel, TF calls the method with LayerCall.__call__(args, **kwargs),
+    # so we have to handle that special case here
+    if "args" in output:
+        if output["args"] is not None and type(output["args"]) == tf.Tensor:
+            tensor_name = output["args"].name.split(":")[0]
+            output[tensor_name] = output["args"]
+        else:
+            # `args` in this case is always the first parameter, i.e. `input_ids`
+            output["input_ids"] = output["args"]
+
+        del output["args"]
+
+    if "kwargs" in output:
+        del output["kwargs"]
+
+    return output
+
+
 def load_tf_weights(model, resolved_archive_file):
     """
     Detect missing and unexpected layers and load the TF weights according to their names and shapes.
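To illustrate the helper above (a sketch, not part of the diff): for a toy signature, `input_processing` returns one entry per named parameter and rejects keywords the signature does not declare. The toy function and tensor values here are made up.

```python
import tensorflow as tf
from transformers.modeling_tf_utils import input_processing

def toy_call(input_ids=None, attention_mask=None, training=False, **kwargs):
    ...

ids = tf.constant([[7, 8, 9]])
parsed = input_processing(func=toy_call, input_ids=ids, attention_mask=None, training=False, kwargs_call={})
print(parsed["input_ids"] is ids, parsed["attention_mask"])  # True None

# Any keyword the signature does not declare is rejected up front:
try:
    input_processing(func=toy_call, input_ids=ids, attention_mask=None, training=False, kwargs_call={"bogus": 1})
except ValueError as err:
    print(err)  # lists the unsupported keyword arguments
```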
@@ -385,6 +500,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
             :obj:`tf.keras.layers.Layer`: A torch module mapping vocabulary to hidden states.
         """
         base_model = getattr(self, self.base_model_prefix, self)
+
         if base_model is not self:
             return base_model.get_input_embeddings()
         else:
@@ -1047,8 +1163,13 @@ def shape_list(tensor: tf.Tensor) -> List[int]:
     Returns:
         :obj:`List[int]`: The shape of the tensor as a list.
     """
-    static = tensor.shape.as_list()
     dynamic = tf.shape(tensor)

+    if tensor.shape == tf.TensorShape(None):
+        return dynamic
+
+    static = tensor.shape.as_list()
+
     return [dynamic[i] if s is None else s for i, s in enumerate(static)]
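The `shape_list` tweak matters inside `tf.function`, where some dimensions are unknown until runtime and the static shape alone is not usable. A small sketch (illustrative, assumes this `transformers` revision):

```python
import tensorflow as tf
from transformers.modeling_tf_utils import shape_list

@tf.function(input_signature=[tf.TensorSpec([None, 8], tf.float32)])
def flatten(x):
    dims = shape_list(x)  # [<dynamic batch tensor>, 8]: static where known, dynamic otherwise
    return tf.reshape(x, [dims[0] * dims[1]])

print(flatten(tf.zeros([2, 8])).shape)  # (16,)
```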
--- a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py
+++ b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py
@@ -14,16 +14,14 @@
 # limitations under the License.
 """TF BlenderBot model, ported from the fairseq repo."""

-from ...file_utils import add_start_docstrings, is_tf_available
+import tensorflow as tf
+
+from ...file_utils import add_start_docstrings
 from ...utils import logging
 from ..bart.modeling_tf_bart import BART_START_DOCSTRING, LARGE_NEGATIVE, TFBartForConditionalGeneration
 from .configuration_blenderbot import BlenderbotConfig


-if is_tf_available():
-    import tensorflow as tf
-

 _CONFIG_FOR_DOC = "BlenderbotConfig"

 START_DOCSTRING = BART_START_DOCSTRING.replace(
--- a/src/transformers/models/ctrl/modeling_tf_ctrl.py
+++ b/src/transformers/models/ctrl/modeling_tf_ctrl.py
@@ -15,7 +15,6 @@
 # limitations under the License.
 """ TF 2.0 CTRL model."""

-
 import numpy as np
 import tensorflow as tf

@@ -25,10 +24,10 @@ from ...modeling_tf_utils import (
     TFCausalLanguageModelingLoss,
     TFPreTrainedModel,
     TFSharedEmbeddings,
+    input_processing,
     keras_serializable,
     shape_list,
 )
-from ...tokenization_utils import BatchEncoding
 from ...utils import logging
 from .configuration_ctrl import CTRLConfig

@@ -252,7 +251,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
     def call(
         self,
-        inputs,
+        input_ids=None,
         past=None,
         attention_mask=None,
         token_type_ids=None,

@@ -264,79 +263,72 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
         output_hidden_states=None,
         return_dict=None,
         training=False,
+        **kwargs,
     ):
-        if isinstance(inputs, (tuple, list)):
-            input_ids = inputs[0]
-            past = inputs[1] if len(inputs) > 1 else past
-            attention_mask = inputs[2] if len(inputs) > 2 else attention_mask
-            token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
-            position_ids = inputs[4] if len(inputs) > 4 else position_ids
-            head_mask = inputs[5] if len(inputs) > 5 else head_mask
-            inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
-            use_cache = inputs[7] if len(inputs) > 7 else use_cache
-            output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
-            output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states
-            return_dict = inputs[10] if len(inputs) > 10 else return_dict
-            assert len(inputs) <= 11, "Too many inputs."
-        elif isinstance(inputs, (dict, BatchEncoding)):
-            input_ids = inputs.get("input_ids")
-            past = inputs.get("past", past)
-            attention_mask = inputs.get("attention_mask", attention_mask)
-            token_type_ids = inputs.get("token_type_ids", token_type_ids)
-            position_ids = inputs.get("position_ids", position_ids)
-            head_mask = inputs.get("head_mask", head_mask)
-            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
-            use_cache = inputs.get("use_cache", use_cache)
-            output_attentions = inputs.get("output_attentions", output_attentions)
-            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
-            return_dict = inputs.get("return_dict", return_dict)
-            assert len(inputs) <= 11, "Too many inputs."
-        else:
-            input_ids = inputs
-
-        output_attentions = output_attentions if output_attentions is not None else self.output_attentions
-        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
-        use_cache = use_cache if use_cache is not None else self.use_cache
-        return_dict = return_dict if return_dict is not None else self.return_dict
+        inputs = input_processing(
+            func=self.call,
+            input_ids=input_ids,
+            past=past,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+            kwargs_call=kwargs,
+        )
+        output_attentions = (
+            inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
+        )
+        output_hidden_states = (
+            inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
+        )
+        use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.use_cache
+        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict

         # If using past key value states, only the last tokens
         # should be given as an input
-        if past is not None:
-            if input_ids is not None:
-                input_ids = input_ids[:, -1:]
-            if inputs_embeds is not None:
-                inputs_embeds = inputs_embeds[:, -1:]
-            if token_type_ids is not None:
-                token_type_ids = token_type_ids[:, -1:]
+        if inputs["past"] is not None:
+            if inputs["input_ids"] is not None:
+                inputs["input_ids"] = inputs["input_ids"][:, -1:]
+            if inputs["inputs_embeds"] is not None:
+                inputs["inputs_embeds"] = inputs["inputs_embeds"][:, -1:]
+            if inputs["token_type_ids"] is not None:
+                inputs["token_type_ids"] = inputs["token_type_ids"][:, -1:]

-        if input_ids is not None and inputs_embeds is not None:
+        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif input_ids is not None:
-            input_shape = shape_list(input_ids)
-            input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
-        elif inputs_embeds is not None:
-            input_shape = shape_list(inputs_embeds)[:-1]
+        elif inputs["input_ids"] is not None:
+            input_shape = shape_list(inputs["input_ids"])
+            inputs["input_ids"] = tf.reshape(inputs["input_ids"], [-1, input_shape[-1]])
+        elif inputs["inputs_embeds"] is not None:
+            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

-        if past is None:
+        if inputs["past"] is None:
             past_length = 0
-            past = [None] * len(self.h)
+            inputs["past"] = [None] * len(self.h)
         else:
-            past_length = shape_list(past[0][0])[-2]
+            past_length = shape_list(inputs["past"][0][0])[-2]

-        if position_ids is None:
-            position_ids = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
-            position_ids = tf.tile(position_ids, [input_shape[0], 1])
+        if inputs["position_ids"] is None:
+            inputs["position_ids"] = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[
+                tf.newaxis, :
+            ]
+            inputs["position_ids"] = tf.tile(inputs["position_ids"], [input_shape[0], 1])

         # Attention mask.
-        if attention_mask is not None:
+        if inputs["attention_mask"] is not None:
             # We create a 3D attention mask from a 2D tensor mask.
             # Sizes are [batch_size, 1, 1, to_seq_length]
             # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
             # this attention mask is more simple than the triangular masking of causal attention
             # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-            attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
+            inputs["attention_mask"] = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]

             # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
             # masked positions, this operation will create a tensor which is 0.0 for

@@ -344,61 +336,63 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
             # Since we are adding it to the raw scores before the softmax, this is
             # effectively the same as removing these entirely.

-            attention_mask = tf.cast(attention_mask, tf.float32)
-            attention_mask = (1.0 - attention_mask) * -10000.0
+            inputs["attention_mask"] = tf.cast(inputs["attention_mask"], tf.float32)
+            inputs["attention_mask"] = (1.0 - inputs["attention_mask"]) * -10000.0
         else:
-            attention_mask = None
+            inputs["attention_mask"] = None

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
         # attention_probs has shape bsz x n_heads x N x N
         # head_mask has shape n_layer x batch x n_heads x N x N
-        if head_mask is not None:
+        if inputs["head_mask"] is not None:
             raise NotImplementedError
         else:
-            head_mask = [None] * self.num_layers
+            inputs["head_mask"] = [None] * self.num_layers

-        if token_type_ids is not None:
-            token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
-            token_type_embeds = self.w(token_type_ids, mode="embedding")
+        if inputs["token_type_ids"] is not None:
+            inputs["token_type_ids"] = tf.reshape(
+                inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]]
+            )
+            token_type_embeds = self.w(inputs["token_type_ids"], mode="embedding")
             token_type_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))
         else:
             token_type_embeds = 0

-        position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
+        inputs["position_ids"] = tf.reshape(inputs["position_ids"], [-1, shape_list(inputs["position_ids"])[-1]])

-        if inputs_embeds is None:
-            inputs_embeds = self.w(input_ids, mode="embedding")
+        if inputs["inputs_embeds"] is None:
+            inputs["inputs_embeds"] = self.w(inputs["input_ids"], mode="embedding")

         seq_len = input_shape[-1]
         mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)

-        inputs_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))
+        inputs["inputs_embeds"] *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))

-        pos_embeds = tf.gather(self.pos_encoding, position_ids)
+        pos_embeds = tf.gather(self.pos_encoding, inputs["position_ids"])

-        hidden_states = inputs_embeds + pos_embeds + token_type_embeds
+        hidden_states = inputs["inputs_embeds"] + pos_embeds + token_type_embeds

-        hidden_states = self.dropout(hidden_states, training=training)
+        hidden_states = self.dropout(hidden_states, training=inputs["training"])

         output_shape = input_shape + [shape_list(hidden_states)[-1]]
-        presents = () if use_cache else None
+        presents = () if inputs["use_cache"] else None
         all_hidden_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None

-        for i, (h, layer_past) in enumerate(zip(self.h, past)):
+        for i, (h, layer_past) in enumerate(zip(self.h, inputs["past"])):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
             outputs = h(
                 hidden_states,
                 mask,
                 layer_past,
-                attention_mask,
-                head_mask[i],
-                use_cache,
+                inputs["attention_mask"],
+                inputs["head_mask"][i],
+                inputs["use_cache"],
                 output_attentions,
-                training=training,
+                training=inputs["training"],
             )
             hidden_states, present = outputs[:2]

-            if use_cache:
+            if inputs["use_cache"]:
                 presents = presents + (present,)

             if output_attentions:

@@ -554,8 +548,52 @@ class TFCTRLModel(TFCTRLPreTrainedModel):
         output_type=TFBaseModelOutputWithPast,
         config_class=_CONFIG_FOR_DOC,
     )
-    def call(self, inputs, **kwargs):
-        outputs = self.transformer(inputs, **kwargs)
+    def call(
+        self,
+        input_ids=None,
+        past=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        training=False,
+        **kwargs,
+    ):
+        inputs = input_processing(
+            func=self.call,
+            input_ids=input_ids,
+            past=past,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+            kwargs_call=kwargs,
+        )
+        outputs = self.transformer(
+            input_ids=inputs["input_ids"],
+            past=inputs["past"],
+            attention_mask=inputs["attention_mask"],
+            token_type_ids=inputs["token_type_ids"],
+            position_ids=inputs["position_ids"],
+            head_mask=inputs["head_mask"],
+            inputs_embeds=inputs["inputs_embeds"],
+            use_cache=inputs["use_cache"],
+            output_attentions=inputs["output_attentions"],
+            output_hidden_states=inputs["output_hidden_states"],
+            return_dict=inputs["return_dict"],
+            training=inputs["training"],
+        )

         return outputs

@@ -600,7 +638,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
         if past:
             inputs = tf.expand_dims(inputs[:, -1], -1)

-        return {"inputs": inputs, "past": past, "use_cache": kwargs["use_cache"]}
+        return {"input_ids": inputs, "past": past, "use_cache": kwargs["use_cache"]}

     @add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(

@@ -611,7 +649,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
     )
     def call(
         self,
-        inputs,
+        input_ids=None,
         past=None,
         attention_mask=None,
         token_type_ids=None,

@@ -624,22 +662,16 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
         return_dict=None,
         labels=None,
         training=False,
+        **kwargs,
     ):
         r"""
         labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
             Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
             config.vocab_size - 1]``.
         """
-        return_dict = return_dict if return_dict is not None else self.transformer.return_dict
-        if isinstance(inputs, (tuple, list)):
-            labels = inputs[11] if len(inputs) > 11 else labels
-            if len(inputs) > 11:
-                inputs = inputs[:11]
-        elif isinstance(inputs, (dict, BatchEncoding)):
-            labels = inputs.pop("labels", labels)
-
-        transformer_outputs = self.transformer(
-            inputs,
+        inputs = input_processing(
+            func=self.call,
+            input_ids=input_ids,
             past=past,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,

@@ -650,7 +682,24 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            labels=labels,
             training=training,
+            kwargs_call=kwargs,
+        )
+        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
+        transformer_outputs = self.transformer(
+            input_ids=inputs["input_ids"],
+            past=inputs["past"],
+            attention_mask=inputs["attention_mask"],
+            token_type_ids=inputs["token_type_ids"],
+            position_ids=inputs["position_ids"],
+            head_mask=inputs["head_mask"],
+            inputs_embeds=inputs["inputs_embeds"],
+            use_cache=inputs["use_cache"],
+            output_attentions=inputs["output_attentions"],
+            output_hidden_states=inputs["output_hidden_states"],
+            return_dict=return_dict,
+            training=inputs["training"],
         )
         hidden_states = transformer_outputs[0]

@@ -658,10 +707,10 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
         logits = self.lm_head(hidden_states)

         loss = None
-        if labels is not None:
+        if inputs["labels"] is not None:
             # shift labels to the left and cut last logit token
             logits = logits[:, :-1]
-            labels = labels[:, 1:]
+            labels = inputs["labels"][:, 1:]
             loss = self.compute_loss(labels, logits)

         if not return_dict:
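A hedged sketch of the new CTRL calling convention: `labels` now travels through `input_processing` like any other keyword argument. The tiny config values below are made up so the snippet runs without downloading pretrained weights.

```python
import tensorflow as tf
from transformers import CTRLConfig, TFCTRLLMHeadModel

config = CTRLConfig(vocab_size=100, n_layer=2, n_head=2, n_embd=64, dff=128)
model = TFCTRLLMHeadModel(config)
ids = tf.constant([[1, 2, 3, 4]])

out = model(input_ids=ids, labels=ids, return_dict=True)
# Per-token loss; logits are cut by one position when labels are passed
print(out.loss.shape, out.logits.shape)  # e.g. (3,) (1, 3, 100)
```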
--- a/src/transformers/models/flaubert/modeling_tf_flaubert.py
+++ b/src/transformers/models/flaubert/modeling_tf_flaubert.py
@@ -22,8 +22,7 @@ from typing import Optional, Tuple

 import tensorflow as tf

-from transformers.activations_tf import get_tf_activation
-
+from ...activations_tf import get_tf_activation
 from ...file_utils import (
     ModelOutput,
     add_code_sample_docstrings,

@@ -31,8 +30,14 @@ from ...file_utils import (
     add_start_docstrings_to_model_forward,
 )
 from ...modeling_tf_outputs import TFBaseModelOutput
-from ...modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, get_initializer, keras_serializable, shape_list
-from ...tokenization_utils import BatchEncoding
+from ...modeling_tf_utils import (
+    TFPreTrainedModel,
+    TFSharedEmbeddings,
+    get_initializer,
+    input_processing,
+    keras_serializable,
+    shape_list,
+)
 from ...utils import logging
 from ..xlm.modeling_tf_xlm import (
     TFXLMForMultipleChoice,

@@ -229,8 +234,56 @@ class TFFlaubertModel(TFFlaubertPreTrainedModel):
         output_type=TFBaseModelOutput,
         config_class=_CONFIG_FOR_DOC,
     )
-    def call(self, inputs, **kwargs):
-        outputs = self.transformer(inputs, **kwargs)
+    def call(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        langs=None,
+        token_type_ids=None,
+        position_ids=None,
+        lengths=None,
+        cache=None,
+        head_mask=None,
+        inputs_embeds=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        training=False,
+        **kwargs,
+    ):
+        inputs = input_processing(
+            func=self.call,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            langs=langs,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            lengths=lengths,
+            cache=cache,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+            kwargs_call=kwargs,
+        )
+        outputs = self.transformer(
+            input_ids=inputs["input_ids"],
+            attention_mask=inputs["attention_mask"],
+            langs=inputs["langs"],
+            token_type_ids=inputs["token_type_ids"],
+            position_ids=inputs["position_ids"],
+            lengths=inputs["lengths"],
+            cache=inputs["cache"],
+            head_mask=inputs["head_mask"],
+            inputs_embeds=inputs["inputs_embeds"],
+            output_attentions=inputs["output_attentions"],
+            output_hidden_states=inputs["output_hidden_states"],
+            return_dict=inputs["return_dict"],
+            training=inputs["training"],
+        )

         return outputs

@@ -351,7 +404,7 @@ class TFFlaubertTransformerFFN(tf.keras.layers.Layer):
 class TFFlaubertMainLayer(tf.keras.layers.Layer):
     config_class = FlaubertConfig

-    def __init__(self, config, *inputs, **kwargs):
+    def __init__(self, config, **kwargs):
         super().__init__(**kwargs)

         self.n_heads = config.n_heads

@@ -417,7 +470,7 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
     def call(
         self,
-        inputs,
+        input_ids=None,
         attention_mask=None,
         langs=None,
         token_type_ids=None,

@@ -430,64 +483,57 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
         output_hidden_states=None,
         return_dict=None,
         training=False,
+        **kwargs,
     ):
         # removed: src_enc=None, src_len=None
-        if isinstance(inputs, (tuple, list)):
-            input_ids = inputs[0]
-            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
-            langs = inputs[2] if len(inputs) > 2 else langs
-            token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
-            position_ids = inputs[4] if len(inputs) > 4 else position_ids
-            lengths = inputs[5] if len(inputs) > 5 else lengths
-            cache = inputs[6] if len(inputs) > 6 else cache
-            head_mask = inputs[7] if len(inputs) > 7 else head_mask
-            inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
-            output_attentions = inputs[9] if len(inputs) > 9 else output_attentions
-            output_hidden_states = inputs[10] if len(inputs) > 10 else output_hidden_states
-            return_dict = inputs[11] if len(inputs) > 11 else return_dict
-            assert len(inputs) <= 12, "Too many inputs."
-        elif isinstance(inputs, (dict, BatchEncoding)):
-            input_ids = inputs.get("input_ids")
-            attention_mask = inputs.get("attention_mask", attention_mask)
-            langs = inputs.get("langs", langs)
-            token_type_ids = inputs.get("token_type_ids", token_type_ids)
-            position_ids = inputs.get("position_ids", position_ids)
-            lengths = inputs.get("lengths", lengths)
-            cache = inputs.get("cache", cache)
-            head_mask = inputs.get("head_mask", head_mask)
-            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
-            output_attentions = inputs.get("output_attentions", output_attentions)
-            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
-            return_dict = inputs.get("return_dict", return_dict)
-            assert len(inputs) <= 12, "Too many inputs."
-        else:
-            input_ids = inputs
-
-        output_attentions = output_attentions if output_attentions is not None else self.output_attentions
-        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
-        return_dict = return_dict if return_dict is not None else self.return_dict
+        inputs = input_processing(
+            func=self.call,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            langs=langs,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            lengths=lengths,
+            cache=cache,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+            kwargs_call=kwargs,
+        )
+        output_attentions = (
+            inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
+        )
+        output_hidden_states = (
+            inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
+        )
+        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict

-        if input_ids is not None and inputs_embeds is not None:
+        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif input_ids is not None:
-            bs, slen = shape_list(input_ids)
-        elif inputs_embeds is not None:
-            bs, slen = shape_list(inputs_embeds)[:2]
+        elif inputs["input_ids"] is not None:
+            bs, slen = shape_list(inputs["input_ids"])
+        elif inputs["inputs_embeds"] is not None:
+            bs, slen = shape_list(inputs["inputs_embeds"])[:2]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

-        if lengths is None:
-            if input_ids is not None:
-                lengths = tf.reduce_sum(tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=tf.int32), axis=1)
+        if inputs["lengths"] is None:
+            if inputs["input_ids"] is not None:
+                inputs["lengths"] = tf.reduce_sum(
+                    tf.cast(tf.not_equal(inputs["input_ids"], self.pad_index), dtype=tf.int32), axis=1
+                )
             else:
-                lengths = tf.convert_to_tensor([slen] * bs, tf.int32)
+                inputs["lengths"] = tf.convert_to_tensor([slen] * bs, tf.int32)
         # mask = input_ids != self.pad_index

         # check inputs
         # assert shape_list(lengths)[0] == bs
-        tf.debugging.assert_equal(
-            shape_list(lengths)[0], bs
-        ), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
+        tf.debugging.assert_equal(
+            shape_list(inputs["lengths"])[0],
+            bs,
+            message=f"Expected batch size {shape_list(inputs['lengths'])[0]} and received batch size {bs} mismatched",
+        )
         # assert lengths.max().item() <= slen
         # input_ids = input_ids.transpose(0, 1)  # batch size as dimension 0
         # assert (src_enc is None) == (src_len is None)

@@ -496,26 +542,26 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
         #     assert src_enc.size(0) == bs

         # generate masks
-        mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask)
+        mask, attn_mask = get_masks(slen, inputs["lengths"], self.causal, padding_mask=inputs["attention_mask"])
         # if self.is_decoder and src_enc is not None:
         #     src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]

         # position_ids
-        if position_ids is None:
-            position_ids = tf.expand_dims(tf.range(slen), axis=0)
+        if inputs["position_ids"] is None:
+            inputs["position_ids"] = tf.expand_dims(tf.range(slen), axis=0)
         else:
             # assert shape_list(position_ids) == [bs, slen]  # (slen, bs)
-            tf.debugging.assert_equal(
-                shape_list(position_ids), [bs, slen]
-            ), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
+            tf.debugging.assert_equal(
+                shape_list(inputs["position_ids"]),
+                [bs, slen],
+                message=f"Position id shape {shape_list(inputs['position_ids'])} and input shape {[bs, slen]} mismatched",
+            )
             # position_ids = position_ids.transpose(0, 1)

         # langs
-        if langs is not None:
+        if inputs["langs"] is not None:
             # assert shape_list(langs) == [bs, slen]  # (slen, bs)
-            tf.debugging.assert_equal(
-                shape_list(langs), [bs, slen]
-            ), f"Lang shape {shape_list(langs)} and input shape {[bs, slen]} mismatched"
+            tf.debugging.assert_equal(
+                shape_list(inputs["langs"]),
+                [bs, slen],
+                message=f"Lang shape {shape_list(inputs['langs'])} and input shape {[bs, slen]} mismatched",
+            )
             # langs = langs.transpose(0, 1)

         # Prepare head mask if needed

@@ -523,34 +569,34 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
         # attention_probs has shape bsz x n_heads x N x N
         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen]
-        if head_mask is not None:
+        if inputs["head_mask"] is not None:
             raise NotImplementedError
         else:
-            head_mask = [None] * self.n_layers
+            inputs["head_mask"] = [None] * self.n_layers

         # do not recompute cached elements
-        if cache is not None and input_ids is not None:
-            _slen = slen - cache["slen"]
-            input_ids = input_ids[:, -_slen:]
-            position_ids = position_ids[:, -_slen:]
-            if langs is not None:
-                langs = langs[:, -_slen:]
+        if inputs["cache"] is not None and inputs["input_ids"] is not None:
+            _slen = slen - inputs["cache"]["slen"]
+            inputs["input_ids"] = inputs["input_ids"][:, -_slen:]
+            inputs["position_ids"] = inputs["position_ids"][:, -_slen:]
+            if inputs["langs"] is not None:
+                inputs["langs"] = inputs["langs"][:, -_slen:]
             mask = mask[:, -_slen:]
             attn_mask = attn_mask[:, -_slen:]

         # embeddings
-        if inputs_embeds is None:
-            inputs_embeds = self.embeddings(input_ids)
+        if inputs["inputs_embeds"] is None:
+            inputs["inputs_embeds"] = self.embeddings(inputs["input_ids"])

-        tensor = inputs_embeds + self.position_embeddings(position_ids)
+        tensor = inputs["inputs_embeds"] + self.position_embeddings(inputs["position_ids"])

-        if langs is not None and self.use_lang_emb:
-            tensor = tensor + self.lang_embeddings(langs)
-        if token_type_ids is not None:
-            tensor = tensor + self.embeddings(token_type_ids)
+        if inputs["langs"] is not None and self.use_lang_emb:
+            tensor = tensor + self.lang_embeddings(inputs["langs"])
+        if inputs["token_type_ids"] is not None:
+            tensor = tensor + self.embeddings(inputs["token_type_ids"])

         tensor = self.layer_norm_emb(tensor)
-        tensor = self.dropout(tensor, training=training)
+        tensor = self.dropout(tensor, training=inputs["training"])
         tensor = tensor * mask[..., tf.newaxis]

         # hidden_states and attentions cannot be None in graph mode.

@@ -562,7 +608,7 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
             # LayerDrop
             dropout_probability = tf.random.uniform([1], 0, 1)

-            if training and tf.less(dropout_probability, self.layerdrop):
+            if inputs["training"] and tf.less(dropout_probability, self.layerdrop):
                 continue

             if output_hidden_states:

@@ -571,27 +617,39 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
             # self attention
             if not self.pre_norm:
                 attn_outputs = self.attentions[i](
-                    tensor, attn_mask, None, cache, head_mask[i], output_attentions, training=training
+                    tensor,
+                    attn_mask,
+                    None,
+                    inputs["cache"],
+                    inputs["head_mask"][i],
+                    output_attentions,
+                    training=inputs["training"],
                 )
                 attn = attn_outputs[0]

                 if output_attentions:
                     attentions = attentions + (attn_outputs[1],)

-                attn = self.dropout(attn, training=training)
+                attn = self.dropout(attn, training=inputs["training"])
                 tensor = tensor + attn
                 tensor = self.layer_norm1[i](tensor)
             else:
                 tensor_normalized = self.layer_norm1[i](tensor)
                 attn_outputs = self.attentions[i](
-                    tensor_normalized, attn_mask, None, cache, head_mask[i], output_attentions, training=training
+                    tensor_normalized,
+                    attn_mask,
+                    None,
+                    inputs["cache"],
+                    inputs["head_mask"][i],
+                    output_attentions,
+                    training=inputs["training"],
                 )
                 attn = attn_outputs[0]

                 if output_attentions:
                     attentions = attentions + (attn_outputs[1],)

-                attn = self.dropout(attn, training=training)
+                attn = self.dropout(attn, training=inputs["training"])
                 tensor = tensor + attn

             # encoder attention (for decoder only)

@@ -616,8 +674,8 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
             hidden_states = hidden_states + (tensor,)

         # update cache length
-        if cache is not None:
-            cache["slen"] += tensor.size(1)
+        if inputs["cache"] is not None:
+            inputs["cache"]["slen"] += shape_list(tensor)[1]

         # move back sequence length to dimension 0
         # tensor = tensor.transpose(0, 1)

@@ -724,7 +782,7 @@ class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel):
             langs = tf.ones_like(inputs) * lang_id
         else:
             langs = None
-        return {"inputs": inputs, "langs": langs}
+        return {"input_ids": inputs, "langs": langs}

     @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(

@@ -733,11 +791,56 @@ class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel):
         output_type=TFFlaubertWithLMHeadModelOutput,
         config_class=_CONFIG_FOR_DOC,
     )
-    def call(self, inputs, **kwargs):
-        return_dict = kwargs.get("return_dict")
-        return_dict = return_dict if return_dict is not None else self.transformer.return_dict
-        transformer_outputs = self.transformer(inputs, **kwargs)
+    def call(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        langs=None,
+        token_type_ids=None,
+        position_ids=None,
+        lengths=None,
+        cache=None,
+        head_mask=None,
+        inputs_embeds=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        training=False,
+        **kwargs,
+    ):
+        inputs = input_processing(
+            func=self.call,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            langs=langs,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            lengths=lengths,
+            cache=cache,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+            kwargs_call=kwargs,
+        )
+        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
+        transformer_outputs = self.transformer(
+            input_ids=inputs["input_ids"],
+            attention_mask=inputs["attention_mask"],
+            langs=inputs["langs"],
+            token_type_ids=inputs["token_type_ids"],
+            position_ids=inputs["position_ids"],
+            lengths=inputs["lengths"],
+            cache=inputs["cache"],
+            head_mask=inputs["head_mask"],
+            inputs_embeds=inputs["inputs_embeds"],
+            output_attentions=inputs["output_attentions"],
+            output_hidden_states=inputs["output_hidden_states"],
+            return_dict=return_dict,
+            training=inputs["training"],
+        )

         output = transformer_outputs[0]
         outputs = self.pred_layer(output)
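A hedged sketch of calling TFFlaubertModel with the new named arguments, including the Flaubert-specific `langs` tensor. The config values are made up so the snippet runs without pretrained weights:

```python
import tensorflow as tf
from transformers import FlaubertConfig, TFFlaubertModel

config = FlaubertConfig(vocab_size=100, emb_dim=32, n_layers=2, n_heads=2, n_langs=2, use_lang_emb=True)
model = TFFlaubertModel(config)
ids = tf.constant([[5, 6, 7]])
langs = tf.zeros_like(ids)  # language id 0 for every token

out = model(input_ids=ids, langs=langs, return_dict=True)
print(out.last_hidden_state.shape)  # (1, 3, 32)
```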
--- a/src/transformers/models/lxmert/modeling_tf_lxmert.py
+++ b/src/transformers/models/lxmert/modeling_tf_lxmert.py
@@ -16,7 +16,6 @@
 # limitations under the License.
 """ TF 2.0 LXMERT model. """

-
 from dataclasses import dataclass
 from typing import Dict, Optional, Tuple

@@ -30,8 +29,7 @@ from ...file_utils import (
     add_start_docstrings_to_model_forward,
     replace_return_docstrings,
 )
-from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
-from ...tokenization_utils_base import BatchEncoding
+from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, input_processing, keras_serializable, shape_list
 from ...utils import logging
 from .configuration_lxmert import LxmertConfig

@@ -716,7 +714,7 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
     def call(
         self,
-        inputs,
+        input_ids=None,
         visual_feats=None,
         visual_pos=None,
         attention_mask=None,

@@ -727,60 +725,55 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
         output_hidden_states=None,
         return_dict=None,
         training=False,
+        **kwargs,
     ):
-        if isinstance(inputs, (tuple, list)):
-            input_ids = inputs[0]
-            visual_feats = inputs[1] if len(inputs) > 1 else visual_feats
-            visual_pos = inputs[2] if len(inputs) > 2 else visual_pos
-            attention_mask = inputs[3] if len(inputs) > 3 else attention_mask
-            visual_attention_mask = inputs[4] if len(inputs) > 4 else visual_attention_mask
-            token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids
-            inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
-            output_attentions = inputs[7] if len(inputs) > 7 else output_attentions
-            output_hidden_states = inputs[8] if len(inputs) > 8 else output_hidden_states
-            return_dict = inputs[9] if len(inputs) > 9 else return_dict
-            assert len(inputs) <= 10, "Too many inputs."
-        elif isinstance(inputs, dict):
-            input_ids = inputs.get("input_ids")
-            visual_feats = inputs.get("visual_feats", visual_feats)
-            visual_pos = inputs.get("visual_pos", visual_pos)
-            attention_mask = inputs.get("attention_mask", attention_mask)
-            visual_attention_mask = inputs.get("visual_attention_mask", visual_attention_mask)
-            token_type_ids = inputs.get("token_type_ids", token_type_ids)
-            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
-            output_attentions = inputs.get("output_attentions", output_attentions)
-            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
-            return_dict = inputs.get("return_dict", return_dict)
-            assert len(inputs) <= 10, "Too many inputs."
-        else:
-            input_ids = inputs
-
-        output_attentions = output_attentions if output_attentions is not None else self.output_attentions
-        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
-        return_dict = return_dict if return_dict is not None else self.return_dict
+        inputs = input_processing(
+            func=self.call,
+            input_ids=input_ids,
+            visual_feats=visual_feats,
+            visual_pos=visual_pos,
+            attention_mask=attention_mask,
+            visual_attention_mask=visual_attention_mask,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+            kwargs_call=kwargs,
+        )
+        output_attentions = (
+            inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
+        )
+        output_hidden_states = (
+            inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
+        )
+        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict

-        if input_ids is not None and inputs_embeds is not None:
+        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif input_ids is not None:
-            input_shape = shape_list(input_ids)
-        elif inputs_embeds is not None:
-            input_shape = shape_list(inputs_embeds)[:-1]
+        elif inputs["input_ids"] is not None:
+            input_shape = shape_list(inputs["input_ids"])
+        elif inputs["inputs_embeds"] is not None:
+            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
-        if visual_pos is None or visual_feats is None:
+
+        if inputs["visual_pos"] is None or inputs["visual_feats"] is None:
             raise ValueError("visual_feats and visual_pos cannot be `None` in LXMERT's `call` method.")

-        if attention_mask is None:
-            attention_mask = tf.fill(input_shape, 1)
-        if token_type_ids is None:
-            token_type_ids = tf.fill(input_shape, 0)
+        if inputs["attention_mask"] is None:
+            inputs["attention_mask"] = tf.fill(input_shape, 1)
+
+        if inputs["token_type_ids"] is None:
+            inputs["token_type_ids"] = tf.fill(input_shape, 0)

         # We create a 3D attention mask from a 2D tensor mask.
         # Sizes are [batch_size, 1, 1, to_seq_length]
         # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
         # this attention mask is more simple than the triangular masking of causal attention
         # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-        extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
+        extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]

         # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
         # masked positions, this operation will create a tensor which is 0.0 for

@@ -791,8 +784,8 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
         extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

-        if visual_attention_mask is not None:
-            extended_visual_attention_mask = visual_attention_mask[:, tf.newaxis, tf.newaxis, :]
+        if inputs["visual_attention_mask"] is not None:
+            extended_visual_attention_mask = inputs["visual_attention_mask"][:, tf.newaxis, tf.newaxis, :]
             extended_visual_attention_mask = tf.cast(extended_visual_attention_mask, tf.float32)
             extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * -10000.0

@@ -800,17 +793,19 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
             extended_visual_attention_mask = None

         # Positional Word Embeddings
-        embedding_output = self.embeddings([input_ids, token_type_ids, inputs_embeds], training=training)
+        embedding_output = self.embeddings(
+            [inputs["input_ids"], inputs["token_type_ids"], inputs["inputs_embeds"]], training=inputs["training"]
+        )

         # Run Lxmert encoder
         encoder_outputs = self.encoder(
             embedding_output,
             extended_attention_mask,
-            visual_feats,
-            visual_pos,
+            inputs["visual_feats"],
+            inputs["visual_pos"],
             extended_visual_attention_mask,
             output_attentions=output_attentions,
-            training=training,
+            training=inputs["training"],
         )
         visual_encoder_outputs, lang_encoder_outputs = encoder_outputs[:2]
         vision_hidden_states = visual_encoder_outputs[0]

@@ -977,8 +972,50 @@ class TFLxmertModel(TFLxmertPreTrainedModel):
         output_type=TFLxmertModelOutput,
         config_class=_CONFIG_FOR_DOC,
     )
-    def call(self, inputs, *args, **kwargs):
-        outputs = self.lxmert(inputs, *args, **kwargs)
+    def call(
+        self,
+        input_ids=None,
+        visual_feats=None,
+        visual_pos=None,
+        attention_mask=None,
+        visual_attention_mask=None,
+        token_type_ids=None,
+        inputs_embeds=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        training=False,
+        **kwargs,
+    ):
+        inputs = input_processing(
+            func=self.call,
+            input_ids=input_ids,
+            visual_feats=visual_feats,
+            visual_pos=visual_pos,
+            attention_mask=attention_mask,
+            visual_attention_mask=visual_attention_mask,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+            kwargs_call=kwargs,
+        )
+        outputs = self.lxmert(
+            input_ids=inputs["input_ids"],
+            visual_feats=inputs["visual_feats"],
+            visual_pos=inputs["visual_pos"],
+            attention_mask=inputs["attention_mask"],
+            visual_attention_mask=inputs["visual_attention_mask"],
+            token_type_ids=inputs["token_type_ids"],
+            inputs_embeds=inputs["inputs_embeds"],
+            output_attentions=inputs["output_attentions"],
+            output_hidden_states=inputs["output_hidden_states"],
+            return_dict=inputs["return_dict"],
+            training=inputs["training"],
+        )

         return outputs

@@ -1228,7 +1265,7 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
     @replace_return_docstrings(output_type=TFLxmertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
     def call(
         self,
-        inputs=None,
+        input_ids=None,
         visual_feats=None,
         visual_pos=None,
         attention_mask=None,

@@ -1242,6 +1279,8 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        training=False,
+        **kwargs,
     ):
         r"""
         masked_lm_labels (``tf.Tensor`` of shape ``(batch_size, sequence_length)``, `optional`):

@@ -1263,31 +1302,38 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
         Returns:
         """
-        if isinstance(inputs, (tuple, list)):
-            masked_lm_labels = inputs[7] if len(inputs) > 7 else masked_lm_labels
-            obj_labels = inputs[8] if len(inputs) > 8 else obj_labels
-            matched_label = inputs[9] if len(inputs) > 9 else matched_label
-            ans = inputs[10] if len(inputs) > 10 else ans
-            if len(inputs) > 10:
-                inputs = inputs[:10]
-        elif isinstance(inputs, (dict, BatchEncoding)):
-            masked_lm_labels = inputs.pop("masked_lm_labels", masked_lm_labels)
-            obj_labels = inputs.pop("obj_labels", obj_labels)
-            matched_label = inputs.pop("matched_label", matched_label)
-            ans = inputs.pop("ans", ans)
-
-        return_dict = return_dict if return_dict is not None else self.lxmert.return_dict
-        lxmert_output = self.lxmert(
-            inputs,
+        inputs = input_processing(
+            func=self.call,
+            input_ids=input_ids,
             visual_feats=visual_feats,
             visual_pos=visual_pos,
             attention_mask=attention_mask,
             visual_attention_mask=visual_attention_mask,
             token_type_ids=token_type_ids,
             inputs_embeds=inputs_embeds,
-            output_hidden_states=output_hidden_states,
+            masked_lm_labels=masked_lm_labels,
+            obj_labels=obj_labels,
+            matched_label=matched_label,
+            ans=ans,
             output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            training=training,
+            kwargs_call=kwargs,
+        )
+        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.lxmert.return_dict
+        lxmert_output = self.lxmert(
+            input_ids=inputs["input_ids"],
+            visual_feats=inputs["visual_feats"],
+            visual_pos=inputs["visual_pos"],
+            attention_mask=inputs["attention_mask"],
+            visual_attention_mask=inputs["visual_attention_mask"],
+            token_type_ids=inputs["token_type_ids"],
+            inputs_embeds=inputs["inputs_embeds"],
+            output_attentions=inputs["output_attentions"],
+            output_hidden_states=inputs["output_hidden_states"],
+            return_dict=inputs["return_dict"],
+            training=inputs["training"],
         )

         lang_output, visual_output, pooled_output = (

@@ -1303,29 +1349,34 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
         total_loss = (
             None
-            if (masked_lm_labels is None and matched_label is None and obj_labels is None and ans is None)
+            if (
+                inputs["masked_lm_labels"] is None
+                and inputs["matched_label"] is None
+                and inputs["obj_labels"] is None
+                and inputs["ans"] is None
+            )
             else tf.constant(0.0)
         )
         losses = ()
-        if masked_lm_labels is not None and self.task_mask_lm:
+        if inputs["masked_lm_labels"] is not None and self.task_mask_lm:
             masked_lm_loss = self.loss_fcts["ce"](
-                tf.reshape(masked_lm_labels, [-1]),
+                tf.reshape(inputs["masked_lm_labels"], [-1]),
                 tf.reshape(lang_prediction_scores, [-1, self.config.vocab_size]),
             )
             total_loss += masked_lm_loss
             losses += (masked_lm_loss,)
-        if matched_label is not None and self.task_matched:
+        if inputs["matched_label"] is not None and self.task_matched:
             matched_loss = self.loss_fcts["ce"](
-                tf.reshape(matched_label, [-1]),
+                tf.reshape(inputs["matched_label"], [-1]),
                 tf.reshape(cross_relationship_score, [-1, 2]),
             )
             total_loss += matched_loss
             losses += (matched_loss,)
-        if obj_labels is not None and self.task_obj_predict:
+        if inputs["obj_labels"] is not None and self.task_obj_predict:
             total_visn_loss = 0.0
             visn_prediction_scores_dict = self.obj_predict_head(visual_output)
             for key, key_info in self.visual_losses.items():
-                label, mask_conf = obj_labels[key]
+                label, mask_conf = inputs["obj_labels"][key]
                 output_dim = key_info["num"]
                 loss_fct_name = key_info["loss"]
                 label_shape = key_info["shape"]

@@ -1343,7 +1394,7 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
                 total_visn_loss += visn_loss
                 losses += (visn_loss,)
             total_loss += total_visn_loss
-        if ans is not None and self.task_qa:
+        if inputs["ans"] is not None and self.task_qa:
             answer_loss = self.loss_fcts["ce"](
-                tf.reshape(ans, [-1]), tf.reshape(answer_score, [-1, self.num_qa_labels])
+                tf.reshape(inputs["ans"], [-1]), tf.reshape(answer_score, [-1, self.num_qa_labels])
             )
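Finally, a hedged sketch of the LXMERT call after the change: text and visual inputs are all plain keyword arguments. The tiny config and the random "visual features" are made up so the snippet runs without pretrained weights:

```python
import tensorflow as tf
from transformers import LxmertConfig, TFLxmertModel

config = LxmertConfig(
    vocab_size=100, hidden_size=64, num_attention_heads=4,
    l_layers=1, x_layers=1, r_layers=1, visual_feat_dim=16, visual_pos_dim=4,
)
model = TFLxmertModel(config)

out = model(
    input_ids=tf.constant([[1, 2, 3]]),
    visual_feats=tf.random.uniform([1, 5, 16]),  # (batch, num_boxes, feat_dim)
    visual_pos=tf.random.uniform([1, 5, 4]),     # (batch, num_boxes, pos_dim)
    return_dict=True,
)
print(out.language_output.shape, out.vision_output.shape)  # (1, 3, 64) (1, 5, 64)
```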