Unverified commit b7018abf, authored by Joao Gante and committed by GitHub

TF: Unpack model inputs through a decorator (#15907)

* MVP

* apply decorator to TFBertModel

* finish updating bert

* update rembert (copy-linked to bert)

* update roberta (copy-linked to bert); Fix args

* Now working for non-text modalities
parent 19597998
@@ -344,6 +344,46 @@ def booleans_processing(config, **kwargs):
     return final_booleans


+def unpack_inputs(func):
+    """
+    Decorator that processes the inputs to a Keras layer, passing them to the layer as keyword arguments. This enables
+    downstream use of the inputs by their variable name, even if they arrive packed as a dictionary in the first input
+    (common case in Keras).
+
+    Args:
+        func (`callable`):
+            The callable function of the TensorFlow model.
+
+    Returns:
+        A callable that wraps the original `func` with the behavior described above.
+    """
+    original_signature = inspect.signature(func)
+
+    @functools.wraps(func)
+    def run_call_with_unpacked_inputs(self, *args, **kwargs):
+        # isolates the actual `**kwargs` for the decorated function
+        kwargs_call = {key: val for key, val in kwargs.items() if key not in dict(original_signature.parameters)}
+        fn_args_and_kwargs = {key: val for key, val in kwargs.items() if key not in kwargs_call}
+        fn_args_and_kwargs.update({"kwargs_call": kwargs_call})
+
+        # move any arg into kwargs, if they exist
+        fn_args_and_kwargs.update(dict(zip(func.__code__.co_varnames[1:], args)))
+
+        # process the inputs and call the wrapped function
+        main_input_name = getattr(self, "main_input_name", func.__code__.co_varnames[1])
+        main_input = fn_args_and_kwargs.pop(main_input_name)
+        unpacked_inputs = input_processing(func, self.config, main_input, **fn_args_and_kwargs)
+        return func(self, **unpacked_inputs)
+
+    # Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This
+    # function does not follow wrapper chains (i.e. ignores `functools.wraps()`), meaning that without the line below
+    # Keras would attempt to check the first argument against the literal signature of the wrapper.
+    run_call_with_unpacked_inputs.__signature__ = original_signature
+
+    return run_call_with_unpacked_inputs
+
+
 def input_processing(func, config, input_ids, **kwargs):
     """
     Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input
......
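As a side note, here is a minimal usage sketch of what the decorated models accept after this change (the checkpoint name and example sentence are only illustrative, and the snippet assumes `transformers` with TensorFlow installed). The calling flexibility itself is not new — `input_processing` already provided it — but the decorator now performs the unpacking before the `call` body runs, so the body can refer to `input_ids`, `attention_mask`, etc. directly:

```python
import tensorflow as tf
from transformers import BertTokenizer, TFBertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = TFBertModel.from_pretrained("bert-base-uncased")

encoding = tokenizer("Decorators keep the call bodies clean.", return_tensors="tf")

# All three invocations reach the decorated `call` with the same unpacked
# keyword arguments (`input_ids`, `attention_mask`, ...):
out_packed = model(encoding)  # dict packed into the first argument (common Keras case)
out_keywords = model(input_ids=encoding["input_ids"], attention_mask=encoding["attention_mask"])
out_positional = model(encoding["input_ids"])

# The outputs are equivalent regardless of how the inputs were passed.
tf.debugging.assert_near(out_packed.last_hidden_state, out_keywords.last_hidden_state)
```

The `run_call_with_unpacked_inputs.__signature__ = original_signature` assignment above is what keeps Keras happy: Keras checks the literal signature of `call` via `inspect.getfullargspec()`, and without that line it would see the wrapper's `(self, *args, **kwargs)` instead of the model's real argument list.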
@@ -55,8 +55,8 @@ from ...modeling_tf_utils import (
     TFSequenceClassificationLoss,
     TFTokenClassificationLoss,
     get_initializer,
-    input_processing,
     keras_serializable,
+    unpack_inputs,
 )
 from ...tf_utils import shape_list
 from ...utils import logging
@@ -720,6 +720,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
         """
         raise NotImplementedError

+    @unpack_inputs
     def call(
         self,
         input_ids: Optional[TFModelInputType] = None,
@@ -738,59 +739,40 @@ class TFBertMainLayer(tf.keras.layers.Layer):
         training: bool = False,
         **kwargs,
     ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            encoder_hidden_states=encoder_hidden_states,
-            encoder_attention_mask=encoder_attention_mask,
-            past_key_values=past_key_values,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
         if not self.config.is_decoder:
-            inputs["use_cache"] = False
+            use_cache = False

-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         batch_size, seq_length = input_shape

-        if inputs["past_key_values"] is None:
+        if past_key_values is None:
             past_key_values_length = 0
-            inputs["past_key_values"] = [None] * len(self.encoder.layer)
+            past_key_values = [None] * len(self.encoder.layer)
         else:
-            past_key_values_length = shape_list(inputs["past_key_values"][0][0])[-2]
+            past_key_values_length = shape_list(past_key_values[0][0])[-2]

-        if inputs["attention_mask"] is None:
-            inputs["attention_mask"] = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)
+        if attention_mask is None:
+            attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)

-        if inputs["token_type_ids"] is None:
-            inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0)
+        if token_type_ids is None:
+            token_type_ids = tf.fill(dims=input_shape, value=0)

         embedding_output = self.embeddings(
-            input_ids=inputs["input_ids"],
-            position_ids=inputs["position_ids"],
-            token_type_ids=inputs["token_type_ids"],
-            inputs_embeds=inputs["inputs_embeds"],
+            input_ids=input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
             past_key_values_length=past_key_values_length,
-            training=inputs["training"],
+            training=training,
         )

         # We create a 3D attention mask from a 2D tensor mask.
@@ -798,7 +780,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
         # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
         # this attention mask is more simple than the triangular masking of causal attention
         # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-        attention_mask_shape = shape_list(inputs["attention_mask"])
+        attention_mask_shape = shape_list(attention_mask)

         mask_seq_length = seq_length + past_key_values_length
         # Copied from `modeling_tf_t5.py`
@@ -811,18 +793,18 @@ class TFBertMainLayer(tf.keras.layers.Layer):
                 tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
                 seq_ids[None, :, None],
             )
-            causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype)
-            extended_attention_mask = causal_mask * inputs["attention_mask"][:, None, :]
+            causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype)
+            extended_attention_mask = causal_mask * attention_mask[:, None, :]
             attention_mask_shape = shape_list(extended_attention_mask)
             extended_attention_mask = tf.reshape(
                 extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
             )
-            if inputs["past_key_values"][0] is not None:
+            if past_key_values[0] is not None:
                 # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length]
                 extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
         else:
             extended_attention_mask = tf.reshape(
-                inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
+                attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
             )

         # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
@@ -836,18 +818,16 @@ class TFBertMainLayer(tf.keras.layers.Layer):
         extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)

         # Copied from `modeling_tf_t5.py` with -1e9 -> -10000
-        if self.is_decoder and inputs["encoder_attention_mask"] is not None:
+        if self.is_decoder and encoder_attention_mask is not None:
             # If a 2D ou 3D attention mask is provided for the cross-attention
             # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
             # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
-            inputs["encoder_attention_mask"] = tf.cast(
-                inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype
-            )
-            num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"]))
+            encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype)
+            num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
             if num_dims_encoder_attention_mask == 3:
-                encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :]
+                encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
             if num_dims_encoder_attention_mask == 2:
-                encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :]
+                encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]

             # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
             # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
@@ -863,29 +843,29 @@ class TFBertMainLayer(tf.keras.layers.Layer):
         # attention_probs has shape bsz x n_heads x N x N
         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
-        if inputs["head_mask"] is not None:
+        if head_mask is not None:
             raise NotImplementedError
         else:
-            inputs["head_mask"] = [None] * self.config.num_hidden_layers
+            head_mask = [None] * self.config.num_hidden_layers

         encoder_outputs = self.encoder(
             hidden_states=embedding_output,
             attention_mask=extended_attention_mask,
-            head_mask=inputs["head_mask"],
-            encoder_hidden_states=inputs["encoder_hidden_states"],
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_extended_attention_mask,
-            past_key_values=inputs["past_key_values"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
         )

         sequence_output = encoder_outputs[0]
         pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None

-        if not inputs["return_dict"]:
+        if not return_dict:
             return (
                 sequence_output,
                 pooled_output,
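To make the mask arithmetic above concrete, here is a small self-contained sketch (plain TensorFlow, not the model code) of how a 2D padding mask becomes the additive `extended_attention_mask`: reshape it so it broadcasts over heads and query positions, then map 1/0 to 0/-10000 so padded positions effectively vanish after the softmax.

```python
import tensorflow as tf

# A batch of 2 sequences, padded to length 4: 1 = real token, 0 = padding.
attention_mask = tf.constant([[1, 1, 1, 0],
                              [1, 1, 0, 0]], dtype=tf.float32)

# Reshape to [batch_size, 1, 1, to_seq_length] so it broadcasts over
# [batch_size, num_heads, from_seq_length, to_seq_length] attention scores.
extended_attention_mask = tf.reshape(attention_mask, (2, 1, 1, 4))

# Turn 1/0 into 0/-10000: attended positions are left unchanged, padded
# positions are pushed toward -inf so they get ~0 weight after the softmax.
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

scores = tf.random.normal((2, 12, 4, 4))  # fake attention scores for 12 heads
probs = tf.nn.softmax(scores + extended_attention_mask, axis=-1)
print(probs[0, 0, 0])                     # the padded (last) position gets ~0 probability
```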
@@ -1063,6 +1043,7 @@ class TFBertModel(TFBertPreTrainedModel):
         self.bert = TFBertMainLayer(config, name="bert")

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1108,9 +1089,7 @@ class TFBertModel(TFBertPreTrainedModel):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`). Set to `False` during training, `True` during generation
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        outputs = self.bert(
             input_ids=input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1125,25 +1104,7 @@ class TFBertModel(TFBertPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
         )
-        outputs = self.bert(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            encoder_hidden_states=inputs["encoder_hidden_states"],
-            encoder_attention_mask=inputs["encoder_attention_mask"],
-            past_key_values=inputs["past_key_values"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
-        )

         return outputs

     def serving_output(
@@ -1196,6 +1157,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
         warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
         return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=TFBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
     def call(
@@ -1244,9 +1206,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
         >>> outputs = model(input_ids)
         >>> prediction_scores, seq_relationship_scores = outputs[:2]
         ```"""
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        outputs = self.bert(
             input_ids=input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1256,34 +1216,19 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
-            next_sentence_label=next_sentence_label,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.bert(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         sequence_output, pooled_output = outputs[:2]
-        prediction_scores = self.mlm(sequence_output=sequence_output, training=inputs["training"])
+        prediction_scores = self.mlm(sequence_output=sequence_output, training=training)
         seq_relationship_score = self.nsp(pooled_output=pooled_output)
         total_loss = None

-        if inputs["labels"] is not None and inputs["next_sentence_label"] is not None:
-            d_labels = {"labels": inputs["labels"]}
-            d_labels["next_sentence_label"] = inputs["next_sentence_label"]
+        if labels is not None and next_sentence_label is not None:
+            d_labels = {"labels": labels}
+            d_labels["next_sentence_label"] = next_sentence_label
             total_loss = self.hf_compute_loss(labels=d_labels, logits=(prediction_scores, seq_relationship_score))

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (prediction_scores, seq_relationship_score) + outputs[2:]
             return ((total_loss,) + output) if total_loss is not None else output
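For reference, the pretraining objective that `hf_compute_loss` receives above (a dict with `labels` and `next_sentence_label`, and a tuple of logits) boils down to two cross-entropies. The sketch below is a simplified stand-in under that reading, not the library's implementation: a masked-LM term that ignores positions labelled -100, plus a 2-way next-sentence term.

```python
import tensorflow as tf

def pretraining_loss_sketch(mlm_labels, prediction_scores, nsp_labels, seq_relationship_score):
    # Simplified stand-in, not `hf_compute_loss`: masked-LM cross-entropy over the
    # positions whose label is not -100, plus a next-sentence cross-entropy.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")

    flat_labels = tf.reshape(mlm_labels, (-1,))
    flat_logits = tf.reshape(prediction_scores, [-1, tf.shape(prediction_scores)[-1]])
    active = tf.not_equal(flat_labels, -100)  # -100 marks positions that are not scored
    mlm_loss = tf.reduce_mean(loss_fn(tf.boolean_mask(flat_labels, active), tf.boolean_mask(flat_logits, active)))

    nsp_loss = tf.reduce_mean(loss_fn(tf.reshape(nsp_labels, (-1,)), seq_relationship_score))
    return mlm_loss + nsp_loss

# Toy shapes: batch of 2, sequence length 4, vocabulary of 10.
print(float(pretraining_loss_sketch(
    mlm_labels=tf.constant([[-100, 3, -100, 7], [1, -100, -100, -100]]),
    prediction_scores=tf.random.normal((2, 4, 10)),
    nsp_labels=tf.constant([0, 1]),
    seq_relationship_score=tf.random.normal((2, 2)),
)))
```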
@@ -1336,6 +1281,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
         warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
         return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1364,9 +1310,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        outputs = self.bert(
             input_ids=input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1376,31 +1320,13 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.bert(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         sequence_output = outputs[0]
-        prediction_scores = self.mlm(sequence_output=sequence_output, training=inputs["training"])
-        loss = (
-            None
-            if inputs["labels"] is None
-            else self.hf_compute_loss(labels=inputs["labels"], logits=prediction_scores)
-        )
+        prediction_scores = self.mlm(sequence_output=sequence_output, training=training)
+        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (prediction_scores,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
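The `-100` convention mentioned in the docstring above can be shown with a tiny sketch (the ids are arbitrary): labels hold the original token id at the positions being predicted and -100 everywhere else, so only the masked positions contribute to the loss.

```python
import tensorflow as tf

input_ids = tf.constant([[101, 7592, 103, 2088, 102]])   # 103 standing in for a [MASK] token
labels = tf.constant([[-100, -100, 2157, -100, -100]])   # only the masked position is scored

# Positions with label == -100 are excluded from the masked-LM loss.
active = tf.not_equal(labels, -100)
print(tf.boolean_mask(labels, active).numpy())           # -> [2157]
```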
@@ -1455,6 +1381,7 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}

+    @unpack_inputs
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
         checkpoint=_CHECKPOINT_FOR_DOC,
@@ -1503,9 +1430,7 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
            Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
            config.vocab_size - 1]`.
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        outputs = self.bert(
             input_ids=input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1519,37 +1444,19 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.bert(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            encoder_hidden_states=inputs["encoder_hidden_states"],
-            encoder_attention_mask=inputs["encoder_attention_mask"],
-            past_key_values=inputs["past_key_values"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         sequence_output = outputs[0]
-        logits = self.mlm(sequence_output=sequence_output, training=inputs["training"])
+        logits = self.mlm(sequence_output=sequence_output, training=training)
         loss = None

-        if inputs["labels"] is not None:
+        if labels is not None:
             # shift labels to the left and cut last logit token
             shifted_logits = logits[:, :-1]
-            labels = inputs["labels"][:, 1:]
+            labels = labels[:, 1:]
             loss = self.hf_compute_loss(labels=labels, logits=shifted_logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
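The label shifting in the causal-LM branch above is easy to see with a tiny worked example (token ids and vocabulary size are arbitrary): the logits at step t are scored against the token at step t + 1, so the last logit step and the first label are dropped.

```python
import tensorflow as tf

# A single sequence of 5 token ids; the model is trained to predict each next token.
labels = tf.constant([[101, 7592, 2088, 999, 102]])   # arbitrary example ids
logits = tf.random.normal((1, 5, 30522))              # [batch, seq_len, vocab_size]

# Position t's logits are scored against the token at position t + 1,
# so drop the last logit step and the first label.
shifted_logits = logits[:, :-1]                       # predictions for steps 0..3
shifted_labels = labels[:, 1:]                        # targets      for steps 1..4

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
print(float(loss_fn(shifted_labels, shifted_logits)))
```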
@@ -1597,6 +1504,7 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredi
         self.bert = TFBertMainLayer(config, name="bert")
         self.nsp = TFBertNSPHead(config, name="nsp___cls")

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
     def call(
@@ -1633,9 +1541,7 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredi
         >>> logits = model(encoding["input_ids"], token_type_ids=encoding["token_type_ids"])[0]
         >>> assert logits[0][0] < logits[0][1] # the next sentence was random
         ```"""
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        outputs = self.bert(
             input_ids=input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1645,31 +1551,17 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredi
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            next_sentence_label=next_sentence_label,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.bert(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         pooled_output = outputs[1]
         seq_relationship_scores = self.nsp(pooled_output=pooled_output)
         next_sentence_loss = (
             None
-            if inputs["next_sentence_label"] is None
-            else self.hf_compute_loss(labels=inputs["next_sentence_label"], logits=seq_relationship_scores)
+            if next_sentence_label is None
+            else self.hf_compute_loss(labels=next_sentence_label, logits=seq_relationship_scores)
         )

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (seq_relationship_scores,) + outputs[2:]
             return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
@@ -1715,6 +1607,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
             name="classifier",
         )

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1743,9 +1636,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        outputs = self.bert(
             input_ids=input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1755,28 +1646,14 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.bert(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         pooled_output = outputs[1]
-        pooled_output = self.dropout(inputs=pooled_output, training=inputs["training"])
+        pooled_output = self.dropout(inputs=pooled_output, training=training)
         logits = self.classifier(inputs=pooled_output)
-        loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
+        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
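The docstring above notes that the head switches between regression and classification based on `config.num_labels`; here is a standalone sketch of that rule with a hypothetical helper (not the library's loss code): mean-squared error when there is a single label, sparse cross-entropy otherwise.

```python
import tensorflow as tf

def sequence_classification_loss_sketch(labels, logits):
    # Hypothetical helper: regression (MSE) when there is a single label,
    # sparse cross-entropy when there are several classes.
    num_labels = logits.shape[-1]
    if num_labels == 1:
        return tf.reduce_mean(
            tf.keras.losses.mean_squared_error(tf.cast(labels, tf.float32), tf.squeeze(logits, axis=-1))
        )
    return tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(labels, logits)

print(float(sequence_classification_loss_sketch(tf.constant([1, 0]), tf.random.normal((2, 3)))))      # classification
print(float(sequence_classification_loss_sketch(tf.constant([0.7, 1.2]), tf.random.normal((2, 1)))))  # regression
```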
@@ -1825,6 +1702,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
         """
         return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1852,51 +1730,26 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
            where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            labels=labels,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if inputs["input_ids"] is not None:
-            num_choices = shape_list(inputs["input_ids"])[1]
-            seq_length = shape_list(inputs["input_ids"])[2]
+        if input_ids is not None:
+            num_choices = shape_list(input_ids)[1]
+            seq_length = shape_list(input_ids)[2]
         else:
-            num_choices = shape_list(inputs["inputs_embeds"])[1]
-            seq_length = shape_list(inputs["inputs_embeds"])[2]
+            num_choices = shape_list(inputs_embeds)[1]
+            seq_length = shape_list(inputs_embeds)[2]

-        flat_input_ids = (
-            tf.reshape(tensor=inputs["input_ids"], shape=(-1, seq_length)) if inputs["input_ids"] is not None else None
-        )
+        flat_input_ids = tf.reshape(tensor=input_ids, shape=(-1, seq_length)) if input_ids is not None else None
         flat_attention_mask = (
-            tf.reshape(tensor=inputs["attention_mask"], shape=(-1, seq_length))
-            if inputs["attention_mask"] is not None
-            else None
+            tf.reshape(tensor=attention_mask, shape=(-1, seq_length)) if attention_mask is not None else None
         )
         flat_token_type_ids = (
-            tf.reshape(tensor=inputs["token_type_ids"], shape=(-1, seq_length))
-            if inputs["token_type_ids"] is not None
-            else None
+            tf.reshape(tensor=token_type_ids, shape=(-1, seq_length)) if token_type_ids is not None else None
         )
         flat_position_ids = (
-            tf.reshape(tensor=inputs["position_ids"], shape=(-1, seq_length))
-            if inputs["position_ids"] is not None
-            else None
+            tf.reshape(tensor=position_ids, shape=(-1, seq_length)) if position_ids is not None else None
         )
         flat_inputs_embeds = (
-            tf.reshape(tensor=inputs["inputs_embeds"], shape=(-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
-            if inputs["inputs_embeds"] is not None
+            tf.reshape(tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3]))
+            if inputs_embeds is not None
             else None
         )
         outputs = self.bert(
@@ -1904,22 +1757,20 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
             attention_mask=flat_attention_mask,
             token_type_ids=flat_token_type_ids,
             position_ids=flat_position_ids,
-            head_mask=inputs["head_mask"],
+            head_mask=head_mask,
             inputs_embeds=flat_inputs_embeds,
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
         )
         pooled_output = outputs[1]
-        pooled_output = self.dropout(inputs=pooled_output, training=inputs["training"])
+        pooled_output = self.dropout(inputs=pooled_output, training=training)
         logits = self.classifier(inputs=pooled_output)
         reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices))
-        loss = (
-            None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=reshaped_logits)
-        )
+        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (reshaped_logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
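The multiple-choice head above is mostly shape bookkeeping; this standalone sketch (random tensors, hypothetical sizes) shows the round trip: choices are folded into the batch dimension before the encoder and folded back out into per-choice logits afterwards.

```python
import tensorflow as tf

batch_size, num_choices, seq_length = 2, 4, 7
input_ids = tf.random.uniform((batch_size, num_choices, seq_length), maxval=30522, dtype=tf.int32)

# Flatten choices into the batch dimension so the encoder sees ordinary
# [batch_size * num_choices, seq_length] sequences.
flat_input_ids = tf.reshape(input_ids, (-1, seq_length))          # shape (8, 7)

# Stand-in for the pooled encoder output + classifier: one score per flattened row.
pooled_scores = tf.random.normal((batch_size * num_choices, 1))

# Fold the choices back out so each example gets one logit per candidate answer.
reshaped_logits = tf.reshape(pooled_scores, (-1, num_choices))    # shape (2, 4)
print(flat_input_ids.shape, reshaped_logits.shape)
```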
@@ -1985,6 +1836,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
             name="classifier",
         )

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -2011,9 +1863,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
         labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        outputs = self.bert(
             input_ids=input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -2023,28 +1873,14 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.bert(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         sequence_output = outputs[0]
-        sequence_output = self.dropout(inputs=sequence_output, training=inputs["training"])
+        sequence_output = self.dropout(inputs=sequence_output, training=training)
         logits = self.classifier(inputs=sequence_output)
-        loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
+        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
@@ -2091,6 +1927,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
             name="qa_outputs",
         )

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -2124,9 +1961,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        outputs = self.bert(
             input_ids=input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -2136,22 +1971,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            start_positions=start_positions,
-            end_positions=end_positions,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.bert(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         sequence_output = outputs[0]
         logits = self.qa_outputs(inputs=sequence_output)
@@ -2160,12 +1980,12 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
         end_logits = tf.squeeze(input=end_logits, axis=-1)
         loss = None

-        if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
-            labels = {"start_position": inputs["start_positions"]}
-            labels["end_position"] = inputs["end_positions"]
+        if start_positions is not None and end_positions is not None:
+            labels = {"start_position": start_positions}
+            labels["end_position"] = end_positions
             loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits))

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (start_logits, end_logits) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
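The question-answering loss built above from `start_positions`/`end_positions` follows the usual span-extraction recipe; a simplified standalone version (random logits, not `hf_compute_loss` itself) looks like this:

```python
import tensorflow as tf

seq_length, batch_size = 16, 2
start_logits = tf.random.normal((batch_size, seq_length))
end_logits = tf.random.normal((batch_size, seq_length))
start_positions = tf.constant([3, 0])   # gold answer start indices
end_positions = tf.constant([5, 0])     # gold answer end indices

# The usual span-extraction objective: cross-entropy over sequence positions for
# the start index, the same for the end index, averaged together.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
start_loss = loss_fn(start_positions, start_logits)
end_loss = loss_fn(end_positions, end_logits)
loss = (start_loss + end_loss) / 2.0
print(float(loss))
```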
......
...@@ -49,8 +49,8 @@ from ...modeling_tf_utils import ( ...@@ -49,8 +49,8 @@ from ...modeling_tf_utils import (
TFSequenceClassificationLoss, TFSequenceClassificationLoss,
TFTokenClassificationLoss, TFTokenClassificationLoss,
get_initializer, get_initializer,
input_processing,
keras_serializable, keras_serializable,
unpack_inputs,
) )
from ...tf_utils import shape_list from ...tf_utils import shape_list
from ...utils import logging from ...utils import logging
...@@ -642,6 +642,7 @@ class TFRemBertMainLayer(tf.keras.layers.Layer): ...@@ -642,6 +642,7 @@ class TFRemBertMainLayer(tf.keras.layers.Layer):
""" """
raise NotImplementedError raise NotImplementedError
@unpack_inputs
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.call # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.call
def call( def call(
self, self,
...@@ -661,59 +662,40 @@ class TFRemBertMainLayer(tf.keras.layers.Layer): ...@@ -661,59 +662,40 @@ class TFRemBertMainLayer(tf.keras.layers.Layer):
training: bool = False, training: bool = False,
**kwargs, **kwargs,
) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]: ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
if not self.config.is_decoder: if not self.config.is_decoder:
inputs["use_cache"] = False use_cache = False
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif inputs["input_ids"] is not None: elif input_ids is not None:
input_shape = shape_list(inputs["input_ids"]) input_shape = shape_list(input_ids)
elif inputs["inputs_embeds"] is not None: elif inputs_embeds is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1] input_shape = shape_list(inputs_embeds)[:-1]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
batch_size, seq_length = input_shape batch_size, seq_length = input_shape
if inputs["past_key_values"] is None: if past_key_values is None:
past_key_values_length = 0 past_key_values_length = 0
inputs["past_key_values"] = [None] * len(self.encoder.layer) past_key_values = [None] * len(self.encoder.layer)
else: else:
past_key_values_length = shape_list(inputs["past_key_values"][0][0])[-2] past_key_values_length = shape_list(past_key_values[0][0])[-2]
if inputs["attention_mask"] is None: if attention_mask is None:
inputs["attention_mask"] = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)
if inputs["token_type_ids"] is None: if token_type_ids is None:
inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0) token_type_ids = tf.fill(dims=input_shape, value=0)
embedding_output = self.embeddings( embedding_output = self.embeddings(
input_ids=inputs["input_ids"], input_ids=input_ids,
position_ids=inputs["position_ids"], position_ids=position_ids,
token_type_ids=inputs["token_type_ids"], token_type_ids=token_type_ids,
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length, past_key_values_length=past_key_values_length,
training=inputs["training"], training=training,
) )
# We create a 3D attention mask from a 2D tensor mask. # We create a 3D attention mask from a 2D tensor mask.
...@@ -721,7 +703,7 @@ class TFRemBertMainLayer(tf.keras.layers.Layer): ...@@ -721,7 +703,7 @@ class TFRemBertMainLayer(tf.keras.layers.Layer):
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention # this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask_shape = shape_list(inputs["attention_mask"]) attention_mask_shape = shape_list(attention_mask)
mask_seq_length = seq_length + past_key_values_length mask_seq_length = seq_length + past_key_values_length
# Copied from `modeling_tf_t5.py` # Copied from `modeling_tf_t5.py`
...@@ -734,18 +716,18 @@ class TFRemBertMainLayer(tf.keras.layers.Layer): ...@@ -734,18 +716,18 @@ class TFRemBertMainLayer(tf.keras.layers.Layer):
tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
seq_ids[None, :, None], seq_ids[None, :, None],
) )
causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype) causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype)
extended_attention_mask = causal_mask * inputs["attention_mask"][:, None, :] extended_attention_mask = causal_mask * attention_mask[:, None, :]
attention_mask_shape = shape_list(extended_attention_mask) attention_mask_shape = shape_list(extended_attention_mask)
extended_attention_mask = tf.reshape( extended_attention_mask = tf.reshape(
extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
) )
if inputs["past_key_values"][0] is not None: if past_key_values[0] is not None:
# attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length] # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length]
extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
else: else:
extended_attention_mask = tf.reshape( extended_attention_mask = tf.reshape(
inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
) )
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
...@@ -759,18 +741,16 @@ class TFRemBertMainLayer(tf.keras.layers.Layer): ...@@ -759,18 +741,16 @@ class TFRemBertMainLayer(tf.keras.layers.Layer):
extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
# Copied from `modeling_tf_t5.py` with -1e9 -> -10000 # Copied from `modeling_tf_t5.py` with -1e9 -> -10000
if self.is_decoder and inputs["encoder_attention_mask"] is not None: if self.is_decoder and encoder_attention_mask is not None:
# If a 2D ou 3D attention mask is provided for the cross-attention # If a 2D ou 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
inputs["encoder_attention_mask"] = tf.cast( encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype)
inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
)
num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"]))
if num_dims_encoder_attention_mask == 3: if num_dims_encoder_attention_mask == 3:
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :] encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
if num_dims_encoder_attention_mask == 2: if num_dims_encoder_attention_mask == 2:
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :] encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
...@@ -786,29 +766,29 @@ class TFRemBertMainLayer(tf.keras.layers.Layer): ...@@ -786,29 +766,29 @@ class TFRemBertMainLayer(tf.keras.layers.Layer):
# attention_probs has shape bsz x n_heads x N x N # attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if inputs["head_mask"] is not None: if head_mask is not None:
raise NotImplementedError raise NotImplementedError
else: else:
inputs["head_mask"] = [None] * self.config.num_hidden_layers head_mask = [None] * self.config.num_hidden_layers
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
hidden_states=embedding_output, hidden_states=embedding_output,
attention_mask=extended_attention_mask, attention_mask=extended_attention_mask,
head_mask=inputs["head_mask"], head_mask=head_mask,
encoder_hidden_states=inputs["encoder_hidden_states"], encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask, encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=inputs["past_key_values"], past_key_values=past_key_values,
use_cache=inputs["use_cache"], use_cache=use_cache,
output_attentions=inputs["output_attentions"], output_attentions=output_attentions,
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=output_hidden_states,
return_dict=inputs["return_dict"], return_dict=return_dict,
training=inputs["training"], training=training,
) )
sequence_output = encoder_outputs[0] sequence_output = encoder_outputs[0]
pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None
if not inputs["return_dict"]: if not return_dict:
return ( return (
sequence_output, sequence_output,
pooled_output, pooled_output,
...@@ -955,6 +935,7 @@ class TFRemBertModel(TFRemBertPreTrainedModel): ...@@ -955,6 +935,7 @@ class TFRemBertModel(TFRemBertPreTrainedModel):
self.rembert = TFRemBertMainLayer(config, name="rembert") self.rembert = TFRemBertMainLayer(config, name="rembert")
@unpack_inputs
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC, processor_class=_TOKENIZER_FOR_DOC,
...@@ -1000,9 +981,7 @@ class TFRemBertModel(TFRemBertPreTrainedModel): ...@@ -1000,9 +981,7 @@ class TFRemBertModel(TFRemBertPreTrainedModel):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`). Set to `False` during training, `True` during generation. `past_key_values`). Set to `False` during training, `True` during generation.
""" """
inputs = input_processing( outputs = self.rembert(
func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -1017,23 +996,6 @@ class TFRemBertModel(TFRemBertPreTrainedModel): ...@@ -1017,23 +996,6 @@ class TFRemBertModel(TFRemBertPreTrainedModel):
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
training=training, training=training,
kwargs_call=kwargs,
)
outputs = self.rembert(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"],
past_key_values=inputs["past_key_values"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
) )
return outputs return outputs
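Reviewer note: the net effect of the decorator on a model `call` like the one above is that the body works with plain argument names in both call styles. Below is a minimal, self-contained sketch of that dict-unpacking pattern; `toy_unpack` and `ToyLayer` are made-up names for illustration and this is not the library's actual implementation.

```python
import functools
import inspect


def toy_unpack(func):
    """Illustrative decorator: if the first positional argument is a dict,
    spread it into keyword arguments before calling ``func``."""
    original_signature = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        if args and isinstance(args[0], dict):
            packed = dict(args[0])
            packed.update(kwargs)
            return func(self, **packed)
        return func(self, *args, **kwargs)

    # keep the original argument names visible to callers that inspect them
    wrapper.__signature__ = original_signature
    return wrapper


class ToyLayer:
    @toy_unpack
    def call(self, input_ids=None, attention_mask=None, training=False):
        # the body can refer to the arguments by name in both call styles
        return {"ids": input_ids, "mask": attention_mask, "training": training}


layer = ToyLayer()
assert layer.call(input_ids=[1, 2], attention_mask=[1, 1]) == layer.call(
    {"input_ids": [1, 2], "attention_mask": [1, 1]}
)
```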
...@@ -1077,6 +1039,7 @@ class TFRemBertForMaskedLM(TFRemBertPreTrainedModel, TFMaskedLanguageModelingLos ...@@ -1077,6 +1039,7 @@ class TFRemBertForMaskedLM(TFRemBertPreTrainedModel, TFMaskedLanguageModelingLos
def get_lm_head(self) -> tf.keras.layers.Layer: def get_lm_head(self) -> tf.keras.layers.Layer:
return self.mlm.predictions return self.mlm.predictions
@unpack_inputs
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC, processor_class=_TOKENIZER_FOR_DOC,
...@@ -1105,9 +1068,7 @@ class TFRemBertForMaskedLM(TFRemBertPreTrainedModel, TFMaskedLanguageModelingLos ...@@ -1105,9 +1068,7 @@ class TFRemBertForMaskedLM(TFRemBertPreTrainedModel, TFMaskedLanguageModelingLos
config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked); the config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked); the
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
""" """
inputs = input_processing( outputs = self.rembert(
func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -1117,31 +1078,13 @@ class TFRemBertForMaskedLM(TFRemBertPreTrainedModel, TFMaskedLanguageModelingLos ...@@ -1117,31 +1078,13 @@ class TFRemBertForMaskedLM(TFRemBertPreTrainedModel, TFMaskedLanguageModelingLos
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
outputs = self.rembert(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
prediction_scores = self.mlm(sequence_output=sequence_output, training=inputs["training"]) prediction_scores = self.mlm(sequence_output=sequence_output, training=training)
loss = ( loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores)
None
if inputs["labels"] is None
else self.hf_compute_loss(labels=inputs["labels"], logits=prediction_scores)
)
if not inputs["return_dict"]: if not return_dict:
output = (prediction_scores,) + outputs[2:] output = (prediction_scores,) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
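The `-100` convention from the docstring above can be made concrete with a short numpy sketch; `toy_masked_lm_loss` is an illustrative stand-in, not the library's `hf_compute_loss`.

```python
import numpy as np


def toy_masked_lm_loss(labels, logits):
    """Mean negative log-likelihood over positions whose label is not -100."""
    keep = labels != -100                                    # (batch, seq) bool mask
    z = logits - logits.max(axis=-1, keepdims=True)          # numerically stable log-softmax
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    safe_labels = np.where(keep, labels, 0)                  # avoid indexing with -100
    picked = np.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    return -(picked * keep).sum() / keep.sum()


labels = np.array([[-100, 5, -100, 2]])                      # only two positions contribute
logits = np.random.randn(1, 4, 10)                           # toy vocab of 10
print(toy_masked_lm_loss(labels, logits))
```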
...@@ -1188,6 +1131,7 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos ...@@ -1188,6 +1131,7 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past} return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
@unpack_inputs
@add_code_sample_docstrings( @add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC, processor_class=_TOKENIZER_FOR_DOC,
checkpoint="rembert", checkpoint="rembert",
...@@ -1236,9 +1180,7 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos ...@@ -1236,9 +1180,7 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos
Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
config.vocab_size - 1]`. config.vocab_size - 1]`.
""" """
inputs = input_processing( outputs = self.rembert(
func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -1252,37 +1194,19 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos ...@@ -1252,37 +1194,19 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
outputs = self.rembert(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"],
past_key_values=inputs["past_key_values"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
logits = self.mlm(sequence_output=sequence_output, training=inputs["training"]) logits = self.mlm(sequence_output=sequence_output, training=training)
loss = None loss = None
if inputs["labels"] is not None: if labels is not None:
# shift labels to the left and cut last logit token # shift labels to the left and cut last logit token
shifted_logits = logits[:, :-1] shifted_logits = logits[:, :-1]
labels = inputs["labels"][:, 1:] labels = labels[:, 1:]
loss = self.hf_compute_loss(labels=labels, logits=shifted_logits) loss = self.hf_compute_loss(labels=labels, logits=shifted_logits)
if not inputs["return_dict"]: if not return_dict:
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
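The left-shift performed before the causal LM loss pairs the logits at position `t` with the label at position `t + 1`; a quick shape check with toy values:

```python
import numpy as np

logits = np.random.randn(1, 4, 6)      # scores for positions 0..3 over a toy vocab of 6
labels = np.array([[1, 2, 3, 4]])      # token ids observed at positions 0..3

shifted_logits = logits[:, :-1]        # positions 0..2, each predicting the *next* token
shifted_labels = labels[:, 1:]         # tokens at positions 1..3

print(shifted_logits.shape, shifted_labels.shape)  # (1, 3, 6) (1, 3)
```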
...@@ -1338,6 +1262,7 @@ class TFRemBertForSequenceClassification(TFRemBertPreTrainedModel, TFSequenceCla ...@@ -1338,6 +1262,7 @@ class TFRemBertForSequenceClassification(TFRemBertPreTrainedModel, TFSequenceCla
name="classifier", name="classifier",
) )
@unpack_inputs
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC, processor_class=_TOKENIZER_FOR_DOC,
...@@ -1366,9 +1291,7 @@ class TFRemBertForSequenceClassification(TFRemBertPreTrainedModel, TFSequenceCla ...@@ -1366,9 +1291,7 @@ class TFRemBertForSequenceClassification(TFRemBertPreTrainedModel, TFSequenceCla
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy). `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
""" """
inputs = input_processing( outputs = self.rembert(
func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -1378,28 +1301,14 @@ class TFRemBertForSequenceClassification(TFRemBertPreTrainedModel, TFSequenceCla ...@@ -1378,28 +1301,14 @@ class TFRemBertForSequenceClassification(TFRemBertPreTrainedModel, TFSequenceCla
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
outputs = self.rembert(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
) )
pooled_output = outputs[1] pooled_output = outputs[1]
pooled_output = self.dropout(inputs=pooled_output, training=inputs["training"]) pooled_output = self.dropout(inputs=pooled_output, training=training)
logits = self.classifier(inputs=pooled_output) logits = self.classifier(inputs=pooled_output)
loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits) loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
if not inputs["return_dict"]: if not return_dict:
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -1444,6 +1353,7 @@ class TFRemBertForMultipleChoice(TFRemBertPreTrainedModel, TFMultipleChoiceLoss) ...@@ -1444,6 +1353,7 @@ class TFRemBertForMultipleChoice(TFRemBertPreTrainedModel, TFMultipleChoiceLoss)
""" """
return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)} return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
@unpack_inputs
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC, processor_class=_TOKENIZER_FOR_DOC,
...@@ -1471,51 +1381,27 @@ class TFRemBertForMultipleChoice(TFRemBertPreTrainedModel, TFMultipleChoiceLoss) ...@@ -1471,51 +1381,27 @@ class TFRemBertForMultipleChoice(TFRemBertPreTrainedModel, TFMultipleChoiceLoss)
Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above) where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
""" """
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
if inputs["input_ids"] is not None: if input_ids is not None:
num_choices = shape_list(inputs["input_ids"])[1] num_choices = shape_list(input_ids)[1]
seq_length = shape_list(inputs["input_ids"])[2] seq_length = shape_list(input_ids)[2]
else: else:
num_choices = shape_list(inputs["inputs_embeds"])[1] num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs["inputs_embeds"])[2] seq_length = shape_list(inputs_embeds)[2]
flat_input_ids = ( flat_input_ids = tf.reshape(tensor=input_ids, shape=(-1, seq_length)) if input_ids is not None else None
tf.reshape(tensor=inputs["input_ids"], shape=(-1, seq_length)) if inputs["input_ids"] is not None else None
)
flat_attention_mask = ( flat_attention_mask = (
tf.reshape(tensor=inputs["attention_mask"], shape=(-1, seq_length)) tf.reshape(tensor=attention_mask, shape=(-1, seq_length)) if attention_mask is not None else None
if inputs["attention_mask"] is not None
else None
) )
flat_token_type_ids = ( flat_token_type_ids = (
tf.reshape(tensor=inputs["token_type_ids"], shape=(-1, seq_length)) tf.reshape(tensor=token_type_ids, shape=(-1, seq_length)) if token_type_ids is not None else None
if inputs["token_type_ids"] is not None
else None
) )
flat_position_ids = ( flat_position_ids = (
tf.reshape(tensor=inputs["position_ids"], shape=(-1, seq_length)) tf.reshape(tensor=position_ids, shape=(-1, seq_length)) if position_ids is not None else None
if inputs["position_ids"] is not None
else None
) )
flat_inputs_embeds = ( flat_inputs_embeds = (
tf.reshape(tensor=inputs["inputs_embeds"], shape=(-1, seq_length, shape_list(inputs["inputs_embeds"])[3])) tf.reshape(tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3]))
if inputs["inputs_embeds"] is not None if inputs_embeds is not None
else None else None
) )
outputs = self.rembert( outputs = self.rembert(
...@@ -1523,22 +1409,20 @@ class TFRemBertForMultipleChoice(TFRemBertPreTrainedModel, TFMultipleChoiceLoss) ...@@ -1523,22 +1409,20 @@ class TFRemBertForMultipleChoice(TFRemBertPreTrainedModel, TFMultipleChoiceLoss)
attention_mask=flat_attention_mask, attention_mask=flat_attention_mask,
token_type_ids=flat_token_type_ids, token_type_ids=flat_token_type_ids,
position_ids=flat_position_ids, position_ids=flat_position_ids,
head_mask=inputs["head_mask"], head_mask=head_mask,
inputs_embeds=flat_inputs_embeds, inputs_embeds=flat_inputs_embeds,
output_attentions=inputs["output_attentions"], output_attentions=output_attentions,
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=output_hidden_states,
return_dict=inputs["return_dict"], return_dict=return_dict,
training=inputs["training"], training=training,
) )
pooled_output = outputs[1] pooled_output = outputs[1]
pooled_output = self.dropout(inputs=pooled_output, training=inputs["training"]) pooled_output = self.dropout(inputs=pooled_output, training=training)
logits = self.classifier(inputs=pooled_output) logits = self.classifier(inputs=pooled_output)
reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices)) reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices))
loss = ( loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits)
None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=reshaped_logits)
)
if not inputs["return_dict"]: if not return_dict:
output = (reshaped_logits,) + outputs[2:] output = (reshaped_logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
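For the multiple-choice head, every choice is scored as its own sequence and the per-choice scores are folded back into one row per example afterwards; the sizes below are illustrative.

```python
import numpy as np

batch_size, num_choices, seq_length = 2, 4, 7
input_ids = np.zeros((batch_size, num_choices, seq_length), dtype=np.int32)

flat_input_ids = input_ids.reshape(-1, seq_length)             # (8, 7): one row per choice
per_choice_scores = np.zeros((batch_size * num_choices, 1))    # classifier output, one score each
reshaped_logits = per_choice_scores.reshape(-1, num_choices)   # (2, 4): one row per example

print(flat_input_ids.shape, reshaped_logits.shape)
```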
...@@ -1589,6 +1473,7 @@ class TFRemBertForTokenClassification(TFRemBertPreTrainedModel, TFTokenClassific ...@@ -1589,6 +1473,7 @@ class TFRemBertForTokenClassification(TFRemBertPreTrainedModel, TFTokenClassific
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
) )
@unpack_inputs
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC, processor_class=_TOKENIZER_FOR_DOC,
...@@ -1615,9 +1500,7 @@ class TFRemBertForTokenClassification(TFRemBertPreTrainedModel, TFTokenClassific ...@@ -1615,9 +1500,7 @@ class TFRemBertForTokenClassification(TFRemBertPreTrainedModel, TFTokenClassific
labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
""" """
inputs = input_processing( outputs = self.rembert(
func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -1627,28 +1510,14 @@ class TFRemBertForTokenClassification(TFRemBertPreTrainedModel, TFTokenClassific ...@@ -1627,28 +1510,14 @@ class TFRemBertForTokenClassification(TFRemBertPreTrainedModel, TFTokenClassific
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
outputs = self.rembert(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
sequence_output = self.dropout(inputs=sequence_output, training=inputs["training"]) sequence_output = self.dropout(inputs=sequence_output, training=training)
logits = self.classifier(inputs=sequence_output) logits = self.classifier(inputs=sequence_output)
loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits) loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
if not inputs["return_dict"]: if not return_dict:
output = (logits,) + outputs[1:] output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -1684,6 +1553,7 @@ class TFRemBertForQuestionAnswering(TFRemBertPreTrainedModel, TFQuestionAnswerin ...@@ -1684,6 +1553,7 @@ class TFRemBertForQuestionAnswering(TFRemBertPreTrainedModel, TFQuestionAnswerin
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
) )
@unpack_inputs
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC, processor_class=_TOKENIZER_FOR_DOC,
...@@ -1717,9 +1587,7 @@ class TFRemBertForQuestionAnswering(TFRemBertPreTrainedModel, TFQuestionAnswerin ...@@ -1717,9 +1587,7 @@ class TFRemBertForQuestionAnswering(TFRemBertPreTrainedModel, TFQuestionAnswerin
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
are not taken into account for computing the loss. are not taken into account for computing the loss.
""" """
inputs = input_processing( outputs = self.rembert(
func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -1729,22 +1597,7 @@ class TFRemBertForQuestionAnswering(TFRemBertPreTrainedModel, TFQuestionAnswerin ...@@ -1729,22 +1597,7 @@ class TFRemBertForQuestionAnswering(TFRemBertPreTrainedModel, TFQuestionAnswerin
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training, training=training,
kwargs_call=kwargs,
)
outputs = self.rembert(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
logits = self.qa_outputs(inputs=sequence_output) logits = self.qa_outputs(inputs=sequence_output)
...@@ -1753,12 +1606,12 @@ class TFRemBertForQuestionAnswering(TFRemBertPreTrainedModel, TFQuestionAnswerin ...@@ -1753,12 +1606,12 @@ class TFRemBertForQuestionAnswering(TFRemBertPreTrainedModel, TFQuestionAnswerin
end_logits = tf.squeeze(input=end_logits, axis=-1) end_logits = tf.squeeze(input=end_logits, axis=-1)
loss = None loss = None
if inputs["start_positions"] is not None and inputs["end_positions"] is not None: if start_positions is not None and end_positions is not None:
labels = {"start_position": inputs["start_positions"]} labels = {"start_position": start_positions}
labels["end_position"] = inputs["end_positions"] labels["end_position"] = end_positions
loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits)) loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits))
if not inputs["return_dict"]: if not return_dict:
output = (start_logits, end_logits) + outputs[2:] output = (start_logits, end_logits) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
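A shape sketch of the span head (sizes illustrative): the two-unit `qa_outputs` dense layer emits a pair of values per token, which are split and squeezed into start and end logits.

```python
import numpy as np

batch_size, seq_length = 2, 5
token_scores = np.random.randn(batch_size, seq_length, 2)    # output of the 2-unit dense head

start_logits, end_logits = np.split(token_scores, 2, axis=-1)
start_logits = np.squeeze(start_logits, axis=-1)             # (2, 5)
end_logits = np.squeeze(end_logits, axis=-1)                 # (2, 5)

print(start_logits.shape, end_logits.shape)
```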
......
...@@ -50,8 +50,8 @@ from ...modeling_tf_utils import ( ...@@ -50,8 +50,8 @@ from ...modeling_tf_utils import (
TFSequenceClassificationLoss, TFSequenceClassificationLoss,
TFTokenClassificationLoss, TFTokenClassificationLoss,
get_initializer, get_initializer,
input_processing,
keras_serializable, keras_serializable,
unpack_inputs,
) )
from ...tf_utils import shape_list from ...tf_utils import shape_list
from ...utils import logging from ...utils import logging
...@@ -606,6 +606,7 @@ class TFRobertaMainLayer(tf.keras.layers.Layer): ...@@ -606,6 +606,7 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
""" """
raise NotImplementedError raise NotImplementedError
@unpack_inputs
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.call # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.call
def call( def call(
self, self,
...@@ -625,59 +626,40 @@ class TFRobertaMainLayer(tf.keras.layers.Layer): ...@@ -625,59 +626,40 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
training: bool = False, training: bool = False,
**kwargs, **kwargs,
) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]: ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
if not self.config.is_decoder: if not self.config.is_decoder:
inputs["use_cache"] = False use_cache = False
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif inputs["input_ids"] is not None: elif input_ids is not None:
input_shape = shape_list(inputs["input_ids"]) input_shape = shape_list(input_ids)
elif inputs["inputs_embeds"] is not None: elif inputs_embeds is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1] input_shape = shape_list(inputs_embeds)[:-1]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
batch_size, seq_length = input_shape batch_size, seq_length = input_shape
if inputs["past_key_values"] is None: if past_key_values is None:
past_key_values_length = 0 past_key_values_length = 0
inputs["past_key_values"] = [None] * len(self.encoder.layer) past_key_values = [None] * len(self.encoder.layer)
else: else:
past_key_values_length = shape_list(inputs["past_key_values"][0][0])[-2] past_key_values_length = shape_list(past_key_values[0][0])[-2]
if inputs["attention_mask"] is None: if attention_mask is None:
inputs["attention_mask"] = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)
if inputs["token_type_ids"] is None: if token_type_ids is None:
inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0) token_type_ids = tf.fill(dims=input_shape, value=0)
embedding_output = self.embeddings( embedding_output = self.embeddings(
input_ids=inputs["input_ids"], input_ids=input_ids,
position_ids=inputs["position_ids"], position_ids=position_ids,
token_type_ids=inputs["token_type_ids"], token_type_ids=token_type_ids,
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length, past_key_values_length=past_key_values_length,
training=inputs["training"], training=training,
) )
# We create a 3D attention mask from a 2D tensor mask. # We create a 3D attention mask from a 2D tensor mask.
...@@ -685,7 +667,7 @@ class TFRobertaMainLayer(tf.keras.layers.Layer): ...@@ -685,7 +667,7 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention # this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask_shape = shape_list(inputs["attention_mask"]) attention_mask_shape = shape_list(attention_mask)
mask_seq_length = seq_length + past_key_values_length mask_seq_length = seq_length + past_key_values_length
# Copied from `modeling_tf_t5.py` # Copied from `modeling_tf_t5.py`
...@@ -698,18 +680,18 @@ class TFRobertaMainLayer(tf.keras.layers.Layer): ...@@ -698,18 +680,18 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
seq_ids[None, :, None], seq_ids[None, :, None],
) )
causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype) causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype)
extended_attention_mask = causal_mask * inputs["attention_mask"][:, None, :] extended_attention_mask = causal_mask * attention_mask[:, None, :]
attention_mask_shape = shape_list(extended_attention_mask) attention_mask_shape = shape_list(extended_attention_mask)
extended_attention_mask = tf.reshape( extended_attention_mask = tf.reshape(
extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
) )
if inputs["past_key_values"][0] is not None: if past_key_values[0] is not None:
# attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length]` # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length]`
extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
else: else:
extended_attention_mask = tf.reshape( extended_attention_mask = tf.reshape(
inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
) )
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
...@@ -723,18 +705,16 @@ class TFRobertaMainLayer(tf.keras.layers.Layer): ...@@ -723,18 +705,16 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
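A small numpy walk-through of the mask construction above (length-4 sequence, no cache, head dimension omitted): the triangular comparison keeps only positions `j <= i`, the padding mask zeroes out the rest, and the subtract/multiply step turns the result into an additive bias of 0 or -10000.

```python
import numpy as np

attention_mask = np.array([[1, 1, 1, 0]])     # last position is padding
seq_len = attention_mask.shape[1]
seq_ids = np.arange(seq_len)

# position i may attend to positions j <= i
causal_mask = (seq_ids[None, None, :] <= seq_ids[None, :, None]).astype(np.float32)

extended = causal_mask * attention_mask[:, None, :]   # (1, 4, 4)
additive = (1.0 - extended) * -10000.0                # 0.0 where attending, -10000.0 elsewhere
print(additive[0])
```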
# Copied from `modeling_tf_t5.py` with -1e9 -> -10000 # Copied from `modeling_tf_t5.py` with -1e9 -> -10000
if self.is_decoder and inputs["encoder_attention_mask"] is not None: if self.is_decoder and encoder_attention_mask is not None:
# If a 2D or 3D attention mask is provided for the cross-attention # If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
inputs["encoder_attention_mask"] = tf.cast( encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype)
inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
)
num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"]))
if num_dims_encoder_attention_mask == 3: if num_dims_encoder_attention_mask == 3:
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :] encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
if num_dims_encoder_attention_mask == 2: if num_dims_encoder_attention_mask == 2:
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :] encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
...@@ -750,29 +730,29 @@ class TFRobertaMainLayer(tf.keras.layers.Layer): ...@@ -750,29 +730,29 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
# attention_probs has shape bsz x n_heads x N x N # attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if inputs["head_mask"] is not None: if head_mask is not None:
raise NotImplementedError raise NotImplementedError
else: else:
inputs["head_mask"] = [None] * self.config.num_hidden_layers head_mask = [None] * self.config.num_hidden_layers
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
hidden_states=embedding_output, hidden_states=embedding_output,
attention_mask=extended_attention_mask, attention_mask=extended_attention_mask,
head_mask=inputs["head_mask"], head_mask=head_mask,
encoder_hidden_states=inputs["encoder_hidden_states"], encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask, encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=inputs["past_key_values"], past_key_values=past_key_values,
use_cache=inputs["use_cache"], use_cache=use_cache,
output_attentions=inputs["output_attentions"], output_attentions=output_attentions,
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=output_hidden_states,
return_dict=inputs["return_dict"], return_dict=return_dict,
training=inputs["training"], training=training,
) )
sequence_output = encoder_outputs[0] sequence_output = encoder_outputs[0]
pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None
if not inputs["return_dict"]: if not return_dict:
return ( return (
sequence_output, sequence_output,
pooled_output, pooled_output,
...@@ -932,6 +912,7 @@ class TFRobertaModel(TFRobertaPreTrainedModel): ...@@ -932,6 +912,7 @@ class TFRobertaModel(TFRobertaPreTrainedModel):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
self.roberta = TFRobertaMainLayer(config, name="roberta") self.roberta = TFRobertaMainLayer(config, name="roberta")
@unpack_inputs
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC, processor_class=_TOKENIZER_FOR_DOC,
...@@ -977,9 +958,7 @@ class TFRobertaModel(TFRobertaPreTrainedModel): ...@@ -977,9 +958,7 @@ class TFRobertaModel(TFRobertaPreTrainedModel):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`). Set to `False` during training, `True` during generation. `past_key_values`). Set to `False` during training, `True` during generation.
""" """
inputs = input_processing( outputs = self.roberta(
func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -994,23 +973,6 @@ class TFRobertaModel(TFRobertaPreTrainedModel): ...@@ -994,23 +973,6 @@ class TFRobertaModel(TFRobertaPreTrainedModel):
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
training=training, training=training,
kwargs_call=kwargs,
)
outputs = self.roberta(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"],
past_key_values=inputs["past_key_values"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
) )
return outputs return outputs
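A usage sketch of what the decorated `call` accepts, using standard `transformers` APIs (assumes the `roberta-base` checkpoint can be downloaded); both call styles below reach the same unpacked code path.

```python
from transformers import AutoTokenizer, TFRobertaModel

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
model = TFRobertaModel.from_pretrained("roberta-base")

encoded = tokenizer("Hello world", return_tensors="tf")

# keyword arguments, or the whole encoding packed as the first input
out_kwargs = model(input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"])
out_packed = model(dict(encoded))

print(out_kwargs.last_hidden_state.shape)
print(out_packed.last_hidden_state.shape)
```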
...@@ -1107,6 +1069,7 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos ...@@ -1107,6 +1069,7 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.lm_head.name return self.name + "/" + self.lm_head.name
@unpack_inputs
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC, processor_class=_TOKENIZER_FOR_DOC,
...@@ -1135,10 +1098,8 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos ...@@ -1135,10 +1098,8 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos
config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked); the config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked); the
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
""" """
inputs = input_processing( outputs = self.roberta(
func=self.call, input_ids,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
...@@ -1147,29 +1108,15 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos ...@@ -1147,29 +1108,15 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
outputs = self.roberta(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output) prediction_scores = self.lm_head(sequence_output)
loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], prediction_scores) loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores)
if not inputs["return_dict"]: if not return_dict:
output = (prediction_scores,) + outputs[2:] output = (prediction_scores,) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -1221,6 +1168,7 @@ class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLos ...@@ -1221,6 +1168,7 @@ class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLos
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past} return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
@unpack_inputs
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC, processor_class=_TOKENIZER_FOR_DOC,
...@@ -1270,9 +1218,7 @@ class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLos ...@@ -1270,9 +1218,7 @@ class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLos
Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
config.vocab_size - 1]`. config.vocab_size - 1]`.
""" """
inputs = input_processing( outputs = self.roberta(
func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -1286,38 +1232,20 @@ class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLos ...@@ -1286,38 +1232,20 @@ class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLos
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
outputs = self.roberta(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"],
past_key_values=inputs["past_key_values"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
logits = self.lm_head(hidden_states=sequence_output, training=inputs["training"]) logits = self.lm_head(hidden_states=sequence_output, training=training)
loss = None loss = None
if inputs["labels"] is not None: if labels is not None:
# shift labels to the left and cut last logit token # shift labels to the left and cut last logit token
shifted_logits = logits[:, :-1] shifted_logits = logits[:, :-1]
labels = inputs["labels"][:, 1:] labels = labels[:, 1:]
loss = self.hf_compute_loss(labels=labels, logits=shifted_logits) loss = self.hf_compute_loss(labels=labels, logits=shifted_logits)
if not inputs["return_dict"]: if not return_dict:
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -1399,6 +1327,7 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla ...@@ -1399,6 +1327,7 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta") self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta")
self.classifier = TFRobertaClassificationHead(config, name="classifier") self.classifier = TFRobertaClassificationHead(config, name="classifier")
@unpack_inputs
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC, processor_class=_TOKENIZER_FOR_DOC,
...@@ -1427,10 +1356,8 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla ...@@ -1427,10 +1356,8 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy). `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
""" """
inputs = input_processing( outputs = self.roberta(
func=self.call, input_ids,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
...@@ -1439,28 +1366,14 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla ...@@ -1439,28 +1366,14 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
outputs = self.roberta(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
logits = self.classifier(sequence_output, training=inputs["training"]) logits = self.classifier(sequence_output, training=training)
loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits) loss = None if labels is None else self.hf_compute_loss(labels, logits)
if not inputs["return_dict"]: if not return_dict:
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
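The regression-vs-classification rule described in the docstring can be sketched as follows; `toy_sequence_loss` is illustrative only and is not the library's `TFSequenceClassificationLoss`.

```python
import numpy as np


def toy_sequence_loss(labels, logits, num_labels):
    if num_labels == 1:
        # regression: mean squared error against a single score per example
        return np.mean((logits.squeeze(-1) - labels) ** 2)
    # classification: sparse cross-entropy over num_labels classes
    z = logits - logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return -np.mean(log_probs[np.arange(len(labels)), labels])


print(toy_sequence_loss(np.array([0.5, 1.5]), np.random.randn(2, 1), num_labels=1))
print(toy_sequence_loss(np.array([0, 2]), np.random.randn(2, 3), num_labels=3))
```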
...@@ -1510,6 +1423,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss) ...@@ -1510,6 +1423,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
""" """
return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)} return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
@unpack_inputs
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC, processor_class=_TOKENIZER_FOR_DOC,
...@@ -1537,60 +1451,38 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss) ...@@ -1537,60 +1451,38 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above) where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
""" """
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
if inputs["input_ids"] is not None: if input_ids is not None:
num_choices = shape_list(inputs["input_ids"])[1] num_choices = shape_list(input_ids)[1]
seq_length = shape_list(inputs["input_ids"])[2] seq_length = shape_list(input_ids)[2]
else: else:
num_choices = shape_list(inputs_embeds)[1] num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs_embeds)[2] seq_length = shape_list(inputs_embeds)[2]
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = ( flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
) flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_position_ids = (
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
)
outputs = self.roberta( outputs = self.roberta(
flat_input_ids, flat_input_ids,
flat_attention_mask, flat_attention_mask,
flat_token_type_ids, flat_token_type_ids,
flat_position_ids, flat_position_ids,
inputs["head_mask"], head_mask,
inputs["inputs_embeds"], inputs_embeds,
inputs["output_attentions"], output_attentions,
inputs["output_hidden_states"], output_hidden_states,
return_dict=inputs["return_dict"], return_dict=return_dict,
training=inputs["training"], training=training,
) )
pooled_output = outputs[1] pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, training=inputs["training"]) pooled_output = self.dropout(pooled_output, training=training)
logits = self.classifier(pooled_output) logits = self.classifier(pooled_output)
reshaped_logits = tf.reshape(logits, (-1, num_choices)) reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], reshaped_logits) loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits)
if not inputs["return_dict"]: if not return_dict:
output = (reshaped_logits,) + outputs[2:] output = (reshaped_logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -1647,6 +1539,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific ...@@ -1647,6 +1539,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
) )
@unpack_inputs
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC, processor_class=_TOKENIZER_FOR_DOC,
...@@ -1673,10 +1566,8 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific ...@@ -1673,10 +1566,8 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
""" """
inputs = input_processing( outputs = self.roberta(
func=self.call, input_ids,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
...@@ -1685,30 +1576,16 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific ...@@ -1685,30 +1576,16 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
outputs = self.roberta(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=inputs["training"]) sequence_output = self.dropout(sequence_output, training=training)
logits = self.classifier(sequence_output) logits = self.classifier(sequence_output)
loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits) loss = None if labels is None else self.hf_compute_loss(labels, logits)
if not inputs["return_dict"]: if not return_dict:
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -1747,6 +1624,7 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin ...@@ -1747,6 +1624,7 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
) )
@unpack_inputs
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC, processor_class=_TOKENIZER_FOR_DOC,
...@@ -1780,10 +1658,8 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin ...@@ -1780,10 +1658,8 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
are not taken into account for computing the loss. are not taken into account for computing the loss.
""" """
inputs = input_processing( outputs = self.roberta(
func=self.call, input_ids,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
...@@ -1792,22 +1668,7 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin ...@@ -1792,22 +1668,7 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training, training=training,
kwargs_call=kwargs,
)
outputs = self.roberta(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
...@@ -1817,12 +1678,12 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin ...@@ -1817,12 +1678,12 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
end_logits = tf.squeeze(end_logits, axis=-1) end_logits = tf.squeeze(end_logits, axis=-1)
loss = None loss = None
if inputs["start_positions"] is not None and inputs["end_positions"] is not None: if start_positions is not None and end_positions is not None:
labels = {"start_position": inputs["start_positions"]} labels = {"start_position": start_positions}
labels["end_position"] = inputs["end_positions"] labels["end_position"] = end_positions
loss = self.hf_compute_loss(labels, (start_logits, end_logits)) loss = self.hf_compute_loss(labels, (start_logits, end_logits))
if not inputs["return_dict"]: if not return_dict:
output = (start_logits, end_logits) + outputs[2:] output = (start_logits, end_logits) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
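For reference, `hf_compute_loss` here comes from the question-answering loss mixin and conceptually averages a sparse categorical cross-entropy over the start and end logits. A rough, self-contained sketch of that computation (a simplification, not the library's exact implementation):

```python
import tensorflow as tf

def qa_loss_sketch(start_positions, end_positions, start_logits, end_logits):
    # One cross-entropy term per span endpoint, averaged; positions index into the sequence axis.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE
    )
    start_loss = loss_fn(start_positions, start_logits)
    end_loss = loss_fn(end_positions, end_logits)
    return (start_loss + end_loss) / 2.0

# Toy shapes: batch of 2, sequence length of 8
start_logits = tf.random.normal((2, 8))
end_logits = tf.random.normal((2, 8))
loss = qa_loss_sketch(tf.constant([1, 3]), tf.constant([4, 6]), start_logits, end_logits)
print(loss.shape)  # (2,) — one loss value per example
```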
......
...@@ -37,8 +37,8 @@ from ...modeling_tf_utils import ( ...@@ -37,8 +37,8 @@ from ...modeling_tf_utils import (
TFCausalLanguageModelingLoss, TFCausalLanguageModelingLoss,
TFPreTrainedModel, TFPreTrainedModel,
TFSharedEmbeddings, TFSharedEmbeddings,
input_processing,
keras_serializable, keras_serializable,
unpack_inputs,
) )
from ...tf_utils import shape_list from ...tf_utils import shape_list
from ...utils import logging from ...utils import logging
...@@ -781,6 +781,7 @@ class TFSpeech2TextEncoder(tf.keras.layers.Layer): ...@@ -781,6 +781,7 @@ class TFSpeech2TextEncoder(tf.keras.layers.Layer):
attention_mask = tf.cast(tf.reverse(tf.math.cumsum(tf.reverse(attention_mask, [-1]), -1), [-1]), tf.int64) attention_mask = tf.cast(tf.reverse(tf.math.cumsum(tf.reverse(attention_mask, [-1]), -1), [-1]), tf.int64)
return attention_mask return attention_mask
@unpack_inputs
def call( def call(
self, self,
input_features=None, input_features=None,
...@@ -822,81 +823,64 @@ class TFSpeech2TextEncoder(tf.keras.layers.Layer): ...@@ -822,81 +823,64 @@ class TFSpeech2TextEncoder(tf.keras.layers.Layer):
return_dict (`bool`, *optional*): return_dict (`bool`, *optional*):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
""" """
inputs = input_processing( if input_features is None:
func=self.call,
config=self.config,
input_ids=input_features,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
if "input_ids" in inputs:
inputs["input_features"] = inputs.pop("input_ids")
if inputs["input_features"] is None:
raise ValueError("You have to specify input_features") raise ValueError("You have to specify input_features")
inputs_embeds = self.conv(inputs["input_features"]) inputs_embeds = self.conv(input_features)
inputs_embeds = self.embed_scale * inputs_embeds inputs_embeds = self.embed_scale * inputs_embeds
# subsample attention mask if necessary # subsample attention mask if necessary
if inputs["attention_mask"] is not None: if attention_mask is not None:
inputs["attention_mask"] = self._get_feature_vector_attention_mask( attention_mask = self._get_feature_vector_attention_mask(inputs_embeds.shape[1], attention_mask)
inputs_embeds.shape[1], inputs["attention_mask"] padding_mask = tf.cast(tf.math.not_equal(attention_mask, 1), tf.int64)
)
padding_mask = tf.cast(tf.math.not_equal(inputs["attention_mask"], 1), tf.int64)
else: else:
padding_mask = tf.zeros(inputs_embeds.shape[:-1], dtype=tf.int64) padding_mask = tf.zeros(inputs_embeds.shape[:-1], dtype=tf.int64)
embed_pos = self.embed_positions(padding_mask) embed_pos = self.embed_positions(padding_mask)
hidden_states = inputs_embeds + embed_pos hidden_states = inputs_embeds + embed_pos
hidden_states = self.dropout(hidden_states, training=inputs["training"]) hidden_states = self.dropout(hidden_states, training=training)
# check attention mask and invert # check attention mask and invert
if inputs["attention_mask"] is not None: if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
inputs["attention_mask"] = _expand_mask(inputs["attention_mask"]) attention_mask = _expand_mask(attention_mask)
encoder_states = () if inputs["output_hidden_states"] else None encoder_states = () if output_hidden_states else None
all_attentions = () if inputs["output_attentions"] else None all_attentions = () if output_attentions else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask has a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA, so they are only run in eager mode. # The tf.debugging asserts are not compliant with XLA, so they are only run in eager mode.
if inputs["head_mask"] is not None and tf.executing_eagerly(): if head_mask is not None and tf.executing_eagerly():
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0], shape_list(head_mask)[0],
len(self.layers), len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.",
) )
for idx, encoder_layer in enumerate(self.layers): for idx, encoder_layer in enumerate(self.layers):
if inputs["output_hidden_states"]: if output_hidden_states:
encoder_states = encoder_states + (hidden_states,) encoder_states = encoder_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1) dropout_probability = random.uniform(0, 1)
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer if training and (dropout_probability < self.layerdrop): # skip the layer
continue continue
hidden_states, attn = encoder_layer( hidden_states, attn = encoder_layer(
hidden_states, hidden_states,
inputs["attention_mask"], attention_mask,
inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, head_mask[idx] if head_mask is not None else None,
training=training, training=training,
) )
if inputs["output_attentions"]: if output_attentions:
all_attentions += (attn,) all_attentions += (attn,)
hidden_states = self.layer_norm(hidden_states) hidden_states = self.layer_norm(hidden_states)
if inputs["output_hidden_states"]: if output_hidden_states:
encoder_states = encoder_states + (hidden_states,) encoder_states = encoder_states + (hidden_states,)
if not inputs["return_dict"]: if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return TFBaseModelOutput( return TFBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
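Because the convolutional front-end shortens the time axis, the frame-level `attention_mask` no longer matches `inputs_embeds.shape[1]`, which is why it is passed through `_get_feature_vector_attention_mask` above. A simplified sketch of that idea, assuming a fixed subsampling factor instead of the model's exact convolution arithmetic:

```python
import tensorflow as tf

def downsample_attention_mask(attention_mask: tf.Tensor, feature_len: int, factor: int = 4) -> tf.Tensor:
    """Build a mask over `feature_len` conv outputs from a frame-level mask (toy version)."""
    # Number of valid input frames per example -> number of valid conv outputs.
    input_lengths = tf.reduce_sum(tf.cast(attention_mask, tf.int32), axis=-1)
    output_lengths = tf.cast(tf.math.ceil(input_lengths / factor), tf.int32)
    # 1 for positions below the downsampled length, 0 for padding.
    return tf.sequence_mask(output_lengths, maxlen=feature_len, dtype=tf.int64)

frame_mask = tf.constant([[1] * 20 + [0] * 4])  # 24 frames, 20 of them valid
print(downsample_attention_mask(frame_mask, feature_len=6))  # -> [[1 1 1 1 1 0]]
```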
...@@ -957,6 +941,7 @@ class TFSpeech2TextDecoder(tf.keras.layers.Layer): ...@@ -957,6 +941,7 @@ class TFSpeech2TextDecoder(tf.keras.layers.Layer):
return combined_attention_mask return combined_attention_mask
@unpack_inputs
def call( def call(
self, self,
input_ids=None, input_ids=None,
...@@ -1034,114 +1019,92 @@ class TFSpeech2TextDecoder(tf.keras.layers.Layer): ...@@ -1034,114 +1019,92 @@ class TFSpeech2TextDecoder(tf.keras.layers.Layer):
return_dict (`bool`, *optional*): return_dict (`bool`, *optional*):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
""" """
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif inputs["input_ids"] is not None: elif input_ids is not None:
input_shape = shape_list(inputs["input_ids"]) input_shape = shape_list(input_ids)
elif inputs["inputs_embeds"] is not None: elif inputs_embeds is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1] input_shape = shape_list(inputs_embeds)[:-1]
else: else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
# past_key_values_length # past_key_values_length
past_key_values_length = ( past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0
shape_list(inputs["past_key_values"][0][0])[2] if inputs["past_key_values"] is not None else 0
)
if inputs["inputs_embeds"] is None: if inputs_embeds is None:
inputs_embeds = self.embed_tokens(inputs["input_ids"]) * self.embed_scale inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
else: else:
inputs_embeds = inputs["inputs_embeds"] inputs_embeds = inputs_embeds
attention_mask = self._prepare_decoder_attention_mask( attention_mask = self._prepare_decoder_attention_mask(
inputs["attention_mask"], input_shape, inputs_embeds, past_key_values_length attention_mask, input_shape, inputs_embeds, past_key_values_length
) )
# expand encoder attention mask # expand encoder attention mask
if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None: if encoder_hidden_states is not None and encoder_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
inputs["encoder_attention_mask"] = _expand_mask(inputs["encoder_attention_mask"], tgt_len=input_shape[-1]) encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1])
# embed positions # embed positions
positions = self.embed_positions(inputs["input_ids"], past_key_values_length=past_key_values_length) positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
hidden_states = inputs_embeds + positions hidden_states = inputs_embeds + positions
hidden_states = self.dropout(hidden_states, training=inputs["training"]) hidden_states = self.dropout(hidden_states, training=training)
# decoder layers # decoder layers
all_hidden_states = () if inputs["output_hidden_states"] else None all_hidden_states = () if output_hidden_states else None
all_self_attns = () if inputs["output_attentions"] else None all_self_attns = () if output_attentions else None
all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None all_cross_attns = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if inputs["use_cache"] else None next_decoder_cache = () if use_cache else None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA, so they are only run in eager mode. # The tf.debugging asserts are not compliant with XLA, so they are only run in eager mode.
for attn_mask in ["head_mask", "cross_attn_head_mask"]: for attn_mask in [head_mask, cross_attn_head_mask]:
if inputs[attn_mask] is not None and tf.executing_eagerly(): if attn_mask is not None and tf.executing_eagerly():
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(inputs[attn_mask])[0], shape_list(attn_mask)[0],
len(self.layers), len(self.layers),
message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.", message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
) )
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if inputs["output_hidden_states"]: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
dropout_probability = random.uniform(0, 1) dropout_probability = random.uniform(0, 1)
if inputs["training"] and (dropout_probability < self.layerdrop): if training and (dropout_probability < self.layerdrop):
continue continue
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None
cross_attn_layer_head_mask = ( cross_attn_layer_head_mask = cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
inputs["cross_attn_head_mask"][idx] if inputs["cross_attn_head_mask"] is not None else None
)
hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer( hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
hidden_states, hidden_states,
attention_mask=attention_mask, attention_mask=attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"], encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=inputs["encoder_attention_mask"], encoder_attention_mask=encoder_attention_mask,
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, layer_head_mask=head_mask[idx] if head_mask is not None else None,
cross_attn_layer_head_mask=cross_attn_layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask,
past_key_value=past_key_value, past_key_value=past_key_value,
) )
if inputs["use_cache"]: if use_cache:
next_decoder_cache += (present_key_value,) next_decoder_cache += (present_key_value,)
if inputs["output_attentions"]: if output_attentions:
all_self_attns += (layer_self_attn,) all_self_attns += (layer_self_attn,)
if inputs["encoder_hidden_states"] is not None: if encoder_hidden_states is not None:
all_cross_attns += (layer_cross_attn,) all_cross_attns += (layer_cross_attn,)
hidden_states = self.layer_norm(hidden_states) hidden_states = self.layer_norm(hidden_states)
if inputs["output_hidden_states"]: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None next_cache = next_decoder_cache if use_cache else None
if not inputs["return_dict"]: if not return_dict:
return hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attns return hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attns
else: else:
return TFBaseModelOutputWithPastAndCrossAttentions( return TFBaseModelOutputWithPastAndCrossAttentions(
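The `past_key_values_length` bookkeeping in this block is what enables incremental decoding: once keys/values are cached, only the newest token ids are fed in and their positions are offset by the cache length. A toy illustration of that offset (names and shapes are illustrative only):

```python
import tensorflow as tf

def decoding_positions(input_ids: tf.Tensor, past_key_values_length: int) -> tf.Tensor:
    """Positions for the current decoder step, offset by how many tokens are already cached."""
    seq_len = tf.shape(input_ids)[1]
    return tf.range(past_key_values_length, past_key_values_length + seq_len)

# First call: no cache, 5 tokens -> positions [0, 1, 2, 3, 4]
print(decoding_positions(tf.zeros((1, 5), tf.int32), past_key_values_length=0))
# Next call: 5 tokens cached, only the newest token is passed -> position [5]
print(decoding_positions(tf.zeros((1, 1), tf.int32), past_key_values_length=5))
```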
...@@ -1170,6 +1133,7 @@ class TFSpeech2TextMainLayer(tf.keras.layers.Layer): ...@@ -1170,6 +1133,7 @@ class TFSpeech2TextMainLayer(tf.keras.layers.Layer):
def set_input_embeddings(self, new_embeddings): def set_input_embeddings(self, new_embeddings):
self.decoder.embed_tokens = new_embeddings self.decoder.embed_tokens = new_embeddings
@unpack_inputs
def call( def call(
self, self,
input_features=None, input_features=None,
...@@ -1189,85 +1153,61 @@ class TFSpeech2TextMainLayer(tf.keras.layers.Layer): ...@@ -1189,85 +1153,61 @@ class TFSpeech2TextMainLayer(tf.keras.layers.Layer):
training=False, training=False,
**kwargs **kwargs
): ):
inputs = input_processing( output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
func=self.call, output_hidden_states = (
config=self.config, output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
input_ids=input_features, )
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_features=input_features,
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
training=training, training=training,
kwargs_call=kwargs,
)
if "input_ids" in inputs:
inputs["input_features"] = inputs.pop("input_ids")
# if the attribute is not set, fetch it from the config
for attr in ("output_attentions", "output_hidden_states", "use_cache"):
if inputs[attr] is None:
inputs[attr] = getattr(self.config, attr)
inputs["return_dict"] = (
inputs["return_dict"] if inputs["return_dict"] is not None else self.config.use_return_dict
)
if inputs["encoder_outputs"] is None:
inputs["encoder_outputs"] = self.encoder(
input_features=inputs["input_features"],
attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
) )
# If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
elif inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], TFBaseModelOutput): elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput):
inputs["encoder_outputs"] = TFBaseModelOutput( encoder_outputs = TFBaseModelOutput(
last_hidden_state=inputs["encoder_outputs"][0], last_hidden_state=encoder_outputs[0],
hidden_states=inputs["encoder_outputs"][1] if len(inputs["encoder_outputs"]) > 1 else None, hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=inputs["encoder_outputs"][2] if len(inputs["encoder_outputs"]) > 2 else None, attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
) )
# If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False
elif not inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], tuple): elif not return_dict and not isinstance(encoder_outputs, tuple):
inputs["encoder_outputs"] = inputs["encoder_outputs"].to_tuple() encoder_outputs = encoder_outputs.to_tuple()
# downsample encoder attention mask # downsample encoder attention mask
if inputs["attention_mask"] is not None: if attention_mask is not None:
inputs["encoder_attention_mask"] = self.encoder._get_feature_vector_attention_mask( encoder_attention_mask = self.encoder._get_feature_vector_attention_mask(
inputs["encoder_outputs"][0].shape[1], inputs["attention_mask"] encoder_outputs[0].shape[1], attention_mask
) )
else: else:
inputs["encoder_attention_mask"] = None encoder_attention_mask = None
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs = self.decoder( decoder_outputs = self.decoder(
input_ids=inputs["decoder_input_ids"], input_ids=decoder_input_ids,
attention_mask=inputs["decoder_attention_mask"], attention_mask=decoder_attention_mask,
encoder_hidden_states=inputs["encoder_outputs"][0], encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=inputs["encoder_attention_mask"], encoder_attention_mask=encoder_attention_mask,
head_mask=inputs["decoder_head_mask"], head_mask=decoder_head_mask,
cross_attn_head_mask=inputs["cross_attn_head_mask"], cross_attn_head_mask=cross_attn_head_mask,
past_key_values=inputs["past_key_values"], past_key_values=past_key_values,
inputs_embeds=inputs["decoder_inputs_embeds"], inputs_embeds=decoder_inputs_embeds,
use_cache=inputs["use_cache"], use_cache=use_cache,
output_attentions=inputs["output_attentions"], output_attentions=output_attentions,
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=output_hidden_states,
return_dict=inputs["return_dict"], return_dict=return_dict,
training=inputs["training"], training=training,
) )
if not inputs["return_dict"]: if not return_dict:
return decoder_outputs + inputs["encoder_outputs"] return decoder_outputs + encoder_outputs
return TFSeq2SeqModelOutput( return TFSeq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state, last_hidden_state=decoder_outputs.last_hidden_state,
...@@ -1275,9 +1215,9 @@ class TFSpeech2TextMainLayer(tf.keras.layers.Layer): ...@@ -1275,9 +1215,9 @@ class TFSpeech2TextMainLayer(tf.keras.layers.Layer):
decoder_hidden_states=decoder_outputs.hidden_states, decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions, decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions, cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states, encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions, encoder_attentions=encoder_outputs.attentions,
) )
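The `encoder_outputs` branches above only normalize types: regardless of what the caller passes, the decoder sees a `TFBaseModelOutput` when `return_dict=True` and a plain tuple otherwise. A small standalone illustration of the two conversions:

```python
import tensorflow as tf
from transformers.modeling_tf_outputs import TFBaseModelOutput

hidden = tf.zeros((1, 10, 256))  # stand-in for an encoder last_hidden_state

# Caller passed a tuple but return_dict=True -> wrap it in a ModelOutput.
wrapped = TFBaseModelOutput(last_hidden_state=hidden, hidden_states=None, attentions=None)

# Caller passed a ModelOutput but return_dict=False -> unwrap it back to a tuple.
unwrapped = wrapped.to_tuple()  # (last_hidden_state,) — None fields are dropped
```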
...@@ -1297,6 +1237,7 @@ class TFSpeech2TextModel(TFSpeech2TextPreTrainedModel): ...@@ -1297,6 +1237,7 @@ class TFSpeech2TextModel(TFSpeech2TextPreTrainedModel):
def get_decoder(self): def get_decoder(self):
return self.model.decoder return self.model.decoder
@unpack_inputs
@add_start_docstrings_to_model_forward(SPEECH_TO_TEXT_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(SPEECH_TO_TEXT_INPUTS_DOCSTRING)
@add_code_sample_docstrings( @add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC, processor_class=_TOKENIZER_FOR_DOC,
...@@ -1323,10 +1264,8 @@ class TFSpeech2TextModel(TFSpeech2TextPreTrainedModel): ...@@ -1323,10 +1264,8 @@ class TFSpeech2TextModel(TFSpeech2TextPreTrainedModel):
training=False, training=False,
**kwargs **kwargs
): ):
inputs = input_processing( outputs = self.model(
func=self.call, input_features=input_features,
config=self.config,
input_ids=input_features,
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
...@@ -1341,27 +1280,6 @@ class TFSpeech2TextModel(TFSpeech2TextPreTrainedModel): ...@@ -1341,27 +1280,6 @@ class TFSpeech2TextModel(TFSpeech2TextPreTrainedModel):
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
training=training, training=training,
kwargs_call=kwargs,
)
if "input_ids" in inputs:
inputs["input_features"] = inputs.pop("input_ids")
outputs = self.model(
input_features=inputs["input_features"],
attention_mask=inputs["attention_mask"],
decoder_input_ids=inputs["decoder_input_ids"],
decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
) )
return outputs return outputs
...@@ -1412,6 +1330,7 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCaus ...@@ -1412,6 +1330,7 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCaus
def set_output_embeddings(self, new_embeddings): def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings self.lm_head = new_embeddings
@unpack_inputs
@add_start_docstrings_to_model_forward(SPEECH_TO_TEXT_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(SPEECH_TO_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def call( def call(
...@@ -1473,61 +1392,35 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCaus ...@@ -1473,61 +1392,35 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCaus
>>> transcription = processor.batch_decode(generated_ids) >>> transcription = processor.batch_decode(generated_ids)
```""" ```"""
inputs = input_processing( return_dict = return_dict if return_dict is not None else self.config.use_return_dict
func=self.call,
config=self.config, if labels is not None:
input_ids=input_features, if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
outputs = self.model(
input_features=input_features,
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask, cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
decoder_inputs_embeds=decoder_inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds,
labels=labels,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
training=training, training=training,
kwargs_call=kwargs,
)
if "input_ids" in inputs:
inputs["input_features"] = inputs.pop("input_ids")
inputs["return_dict"] = (
inputs["return_dict"] if inputs["return_dict"] is not None else self.config.use_return_dict
)
if inputs["labels"] is not None:
if inputs["decoder_input_ids"] is None:
inputs["decoder_input_ids"] = shift_tokens_right(
inputs["labels"], self.config.pad_token_id, self.config.decoder_start_token_id
)
outputs = self.model(
input_features=inputs["input_features"],
attention_mask=inputs["attention_mask"],
decoder_input_ids=inputs["decoder_input_ids"],
encoder_outputs=inputs["encoder_outputs"],
decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
) )
lm_logits = self.lm_head(outputs[0]) lm_logits = self.lm_head(outputs[0])
masked_lm_loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], lm_logits) masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)
if not inputs["return_dict"]: if not return_dict:
output = (lm_logits,) + outputs[1:] output = (lm_logits,) + outputs[1:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
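When `labels` are provided without explicit `decoder_input_ids`, the decoder inputs are derived by shifting the labels right and prepending the decoder start token (standard teacher forcing). A toy version of what `shift_tokens_right` produces, with made-up token ids (the real helper handles edge cases this sketch ignores):

```python
import tensorflow as tf

def shift_tokens_right_sketch(labels, pad_token_id, decoder_start_token_id):
    # Prepend the start token, drop the last position, and replace label padding (-100) with pad_token_id.
    start = tf.fill([tf.shape(labels)[0], 1], decoder_start_token_id)
    shifted = tf.concat([start, labels[:, :-1]], axis=-1)
    return tf.where(shifted == -100, pad_token_id, shifted)

labels = tf.constant([[42, 17, 9, 1]])  # 1 plays the role of eos/pad in this toy vocabulary
print(shift_tokens_right_sketch(labels, pad_token_id=1, decoder_start_token_id=2))
# -> [[ 2 42 17  9]]
```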
......