Unverified Commit 0f35cda4 authored by Utku Saglam's avatar Utku Saglam Committed by GitHub
Browse files

TF clearer model variable naming: funnel (#16178)


Co-authored-by: default avatarutku saglam <utkusaglam@utku-MacBook-Pro.local>
parent ee27b3d7
......@@ -45,8 +45,8 @@ from ...modeling_tf_utils import (
TFSequenceClassificationLoss,
TFTokenClassificationLoss,
get_initializer,
input_processing,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...utils import logging
......@@ -748,6 +748,7 @@ class TFFunnelBaseLayer(tf.keras.layers.Layer):
def _prune_heads(self, heads_to_prune):
raise NotImplementedError # Not implemented yet in the library fr TF 2.0 models
@unpack_inputs
def call(
self,
input_ids=None,
......@@ -760,46 +761,33 @@ class TFFunnelBaseLayer(tf.keras.layers.Layer):
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.fill(input_shape, 1)
if attention_mask is None:
attention_mask = tf.fill(input_shape, 1)
if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.fill(input_shape, 0)
if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0)
if inputs["inputs_embeds"] is None:
inputs["inputs_embeds"] = self.embeddings(inputs["input_ids"], training=inputs["training"])
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids, training=training)
encoder_outputs = self.encoder(
inputs["inputs_embeds"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
inputs_embeds,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
return encoder_outputs
......@@ -834,6 +822,7 @@ class TFFunnelMainLayer(tf.keras.layers.Layer):
def _prune_heads(self, heads_to_prune):
raise NotImplementedError # Not implemented yet in the library fr TF 2.0 models
@unpack_inputs
def call(
self,
input_ids=None,
......@@ -846,66 +835,52 @@ class TFFunnelMainLayer(tf.keras.layers.Layer):
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.fill(input_shape, 1)
if attention_mask is None:
attention_mask = tf.fill(input_shape, 1)
if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.fill(input_shape, 0)
if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0)
if inputs["inputs_embeds"] is None:
inputs["inputs_embeds"] = self.embeddings(inputs["input_ids"], training=inputs["training"])
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids, training=training)
encoder_outputs = self.encoder(
inputs["inputs_embeds"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
output_attentions=inputs["output_attentions"],
inputs_embeds,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
output_attentions=output_attentions,
output_hidden_states=True,
return_dict=inputs["return_dict"],
training=inputs["training"],
return_dict=return_dict,
training=training,
)
decoder_outputs = self.decoder(
final_hidden=encoder_outputs[0],
first_block_hidden=encoder_outputs[1][self.block_sizes[0]],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
attention_mask=attention_mask,
token_type_ids=token_type_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
if not inputs["return_dict"]:
if not return_dict:
idx = 0
outputs = (decoder_outputs[0],)
if inputs["output_hidden_states"]:
if output_hidden_states:
idx += 1
outputs = outputs + (encoder_outputs[1] + decoder_outputs[idx],)
if inputs["output_attentions"]:
if output_attentions:
idx += 1
outputs = outputs + (encoder_outputs[2] + decoder_outputs[idx],)
return outputs
......@@ -913,11 +888,9 @@ class TFFunnelMainLayer(tf.keras.layers.Layer):
return TFBaseModelOutput(
last_hidden_state=decoder_outputs[0],
hidden_states=(encoder_outputs.hidden_states + decoder_outputs.hidden_states)
if inputs["output_hidden_states"]
else None,
attentions=(encoder_outputs.attentions + decoder_outputs.attentions)
if inputs["output_attentions"]
if output_hidden_states
else None,
attentions=(encoder_outputs.attentions + decoder_outputs.attentions) if output_attentions else None,
)
......@@ -1131,6 +1104,7 @@ class TFFunnelBaseModel(TFFunnelPreTrainedModel):
output_type=TFBaseModelOutput,
config_class=_CONFIG_FOR_DOC,
)
@unpack_inputs
def call(
self,
input_ids=None,
......@@ -1143,9 +1117,7 @@ class TFFunnelBaseModel(TFFunnelPreTrainedModel):
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
config=self.config,
return self.funnel(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
......@@ -1154,18 +1126,6 @@ class TFFunnelBaseModel(TFFunnelPreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return self.funnel(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
# Copied from transformers.models.distilbert.modeling_tf_distilbert.TFDistilBertModel.serving_output
......@@ -1185,6 +1145,7 @@ class TFFunnelModel(TFFunnelPreTrainedModel):
super().__init__(config, *inputs, **kwargs)
self.funnel = TFFunnelMainLayer(config, name="funnel")
@unpack_inputs
@add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
......@@ -1204,9 +1165,8 @@ class TFFunnelModel(TFFunnelPreTrainedModel):
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
config=self.config,
return self.funnel(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
......@@ -1215,18 +1175,6 @@ class TFFunnelModel(TFFunnelPreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return self.funnel(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
# Copied from transformers.models.distilbert.modeling_tf_distilbert.TFDistilBertModel.serving_output
......@@ -1250,6 +1198,7 @@ class TFFunnelForPreTraining(TFFunnelPreTrainedModel):
self.funnel = TFFunnelMainLayer(config, name="funnel")
self.discriminator_predictions = TFFunnelDiscriminatorPredictions(config, name="discriminator_predictions")
@unpack_inputs
@add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFFunnelForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def call(
......@@ -1279,33 +1228,20 @@ class TFFunnelForPreTraining(TFFunnelPreTrainedModel):
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
>>> logits = model(inputs).logits
```"""
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
discriminator_hidden_states = self.funnel(
input_ids,
attention_mask,
token_type_ids,
inputs_embeds,
output_attentions,
output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
discriminator_hidden_states = self.funnel(
inputs["input_ids"],
inputs["attention_mask"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
discriminator_sequence_output = discriminator_hidden_states[0]
logits = self.discriminator_predictions(discriminator_sequence_output)
if not inputs["return_dict"]:
if not return_dict:
return (logits,) + discriminator_hidden_states[1:]
return TFFunnelForPreTrainingOutput(
......@@ -1336,6 +1272,7 @@ class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss)
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.lm_head.name
@unpack_inputs
@add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
......@@ -1362,36 +1299,22 @@ class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss)
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
"""
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
outputs = self.funnel(
inputs["input_ids"],
inputs["attention_mask"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
inputs["output_attentions"],
inputs["output_hidden_states"],
input_ids,
attention_mask,
token_type_ids,
inputs_embeds,
output_attentions,
output_hidden_states,
return_dict=return_dict,
training=inputs["training"],
training=training,
)
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output, training=inputs["training"])
prediction_scores = self.lm_head(sequence_output, training=training)
loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], prediction_scores)
loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores)
if not inputs["return_dict"]:
if not return_dict:
output = (prediction_scores,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
......@@ -1425,6 +1348,7 @@ class TFFunnelForSequenceClassification(TFFunnelPreTrainedModel, TFSequenceClass
self.funnel = TFFunnelBaseLayer(config, name="funnel")
self.classifier = TFFunnelClassificationHead(config, config.num_labels, name="classifier")
@unpack_inputs
@add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
......@@ -1451,37 +1375,23 @@ class TFFunnelForSequenceClassification(TFFunnelPreTrainedModel, TFSequenceClass
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
outputs = self.funnel(
input_ids,
attention_mask,
token_type_ids,
inputs_embeds,
output_attentions,
output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
outputs = self.funnel(
inputs["input_ids"],
inputs["attention_mask"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
last_hidden_state = outputs[0]
pooled_output = last_hidden_state[:, 0]
logits = self.classifier(pooled_output, training=inputs["training"])
logits = self.classifier(pooled_output, training=training)
loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
loss = None if labels is None else self.hf_compute_loss(labels, logits)
if not inputs["return_dict"]:
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
......@@ -1524,6 +1434,7 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss):
"""
return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
@unpack_inputs
@add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
......@@ -1549,38 +1460,19 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss):
Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
"""
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
if inputs["input_ids"] is not None:
num_choices = shape_list(inputs["input_ids"])[1]
seq_length = shape_list(inputs["input_ids"])[2]
if input_ids is not None:
num_choices = shape_list(input_ids)[1]
seq_length = shape_list(input_ids)[2]
else:
num_choices = shape_list(inputs["inputs_embeds"])[1]
seq_length = shape_list(inputs["inputs_embeds"])[2]
num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs_embeds)[2]
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = (
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
)
flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_inputs_embeds = (
tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
if inputs["inputs_embeds"] is not None
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
if inputs_embeds is not None
else None
)
......@@ -1589,20 +1481,20 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss):
attention_mask=flat_attention_mask,
token_type_ids=flat_token_type_ids,
inputs_embeds=flat_inputs_embeds,
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
last_hidden_state = outputs[0]
pooled_output = last_hidden_state[:, 0]
logits = self.classifier(pooled_output, training=inputs["training"])
logits = self.classifier(pooled_output, training=training)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], reshaped_logits)
loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits)
if not inputs["return_dict"]:
if not return_dict:
output = (reshaped_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
......@@ -1653,6 +1545,7 @@ class TFFunnelForTokenClassification(TFFunnelPreTrainedModel, TFTokenClassificat
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
@unpack_inputs
@add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
......@@ -1677,38 +1570,24 @@ class TFFunnelForTokenClassification(TFFunnelPreTrainedModel, TFTokenClassificat
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
outputs = self.funnel(
input_ids,
attention_mask,
token_type_ids,
inputs_embeds,
output_attentions,
output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
outputs = self.funnel(
inputs["input_ids"],
inputs["attention_mask"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=inputs["training"])
sequence_output = self.dropout(sequence_output, training=training)
logits = self.classifier(sequence_output)
loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
loss = None if labels is None else self.hf_compute_loss(labels, logits)
if not inputs["return_dict"]:
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
......@@ -1744,6 +1623,7 @@ class TFFunnelForQuestionAnswering(TFFunnelPreTrainedModel, TFQuestionAnsweringL
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
@unpack_inputs
@add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
......@@ -1775,30 +1655,16 @@ class TFFunnelForQuestionAnswering(TFFunnelPreTrainedModel, TFQuestionAnsweringL
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
"""
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
outputs = self.funnel(
input_ids,
attention_mask,
token_type_ids,
inputs_embeds,
output_attentions,
output_hidden_states,
return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training,
kwargs_call=kwargs,
)
outputs = self.funnel(
inputs["input_ids"],
inputs["attention_mask"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
sequence_output = outputs[0]
......@@ -1808,11 +1674,11 @@ class TFFunnelForQuestionAnswering(TFFunnelPreTrainedModel, TFQuestionAnsweringL
end_logits = tf.squeeze(end_logits, axis=-1)
loss = None
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"], "end_position": inputs["end_positions"]}
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions, "end_position": end_positions}
loss = self.hf_compute_loss(labels, (start_logits, end_logits))
if not inputs["return_dict"]:
if not return_dict:
output = (start_logits, end_logits) + outputs[1:]
return ((loss,) + output) if loss is not None else output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment