"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "37c5759cbe54ab98edbb410853b458449681afea"
Unverified commit e3645fd2, authored by Minh Chien Vu, committed by GitHub

Change unpacking of TF mobilebert inputs to use decorator (#16110)



* Change unpacking of TF mobilebert inputs to use decorator

* Move unpack_inputs as the top decorator

* make fixup
Co-authored-by: ChienVM <chien_vm@detomo.co.jp>
parent 5dbf36bd
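To make the change concrete, here is a minimal sketch of the decorator pattern this commit adopts. `unpack_inputs_sketch` is a hypothetical stand-in written for illustration; the real `unpack_inputs` decorator lives in `modeling_tf_utils.py` and also handles the TF-specific input handling that `input_processing` covered.

# Illustrative sketch only: a hypothetical stand-in, not the library's `unpack_inputs`.
import functools
import inspect


def unpack_inputs_sketch(func):
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # Bind every argument to its parameter name and fill in defaults,
        # roughly the bookkeeping `input_processing` used to do inside each `call`.
        bound = inspect.signature(func).bind(self, *args, **kwargs)
        bound.apply_defaults()
        unpacked = dict(bound.arguments)
        unpacked.pop("self")
        unpacked.pop("kwargs", None)  # extra **kwargs are dropped in this sketch
        # The wrapped `call` body can now read plain names such as `input_ids`
        # instead of looking them up in an `inputs` dict.
        return func(self, **unpacked)

    return wrapper

With this pattern the body of each `call` reads `input_ids`, `attention_mask`, and friends directly, so the explicit `input_processing(...)` block and every `inputs["..."]` lookup can be deleted, which is what the hunks below do with the real decorator.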
@@ -49,8 +49,8 @@ from ...modeling_tf_utils import (
     TFSequenceClassificationLoss,
     TFTokenClassificationLoss,
     get_initializer,
-    input_processing,
     keras_serializable,
+    unpack_inputs,
 )
 from ...tf_utils import shape_list
 from ...utils import logging
@@ -679,6 +679,7 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
         """
         raise NotImplementedError

+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -693,51 +694,29 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
         training=False,
         **kwargs,
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

-        if inputs["attention_mask"] is None:
-            inputs["attention_mask"] = tf.fill(input_shape, 1)
+        if attention_mask is None:
+            attention_mask = tf.fill(input_shape, 1)

-        if inputs["token_type_ids"] is None:
-            inputs["token_type_ids"] = tf.fill(input_shape, 0)
+        if token_type_ids is None:
+            token_type_ids = tf.fill(input_shape, 0)

-        embedding_output = self.embeddings(
-            inputs["input_ids"],
-            inputs["position_ids"],
-            inputs["token_type_ids"],
-            inputs["inputs_embeds"],
-            training=inputs["training"],
-        )
+        embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)

         # We create a 3D attention mask from a 2D tensor mask.
         # Sizes are [batch_size, 1, 1, to_seq_length]
         # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
         # this attention mask is more simple than the triangular masking of causal attention
         # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-        extended_attention_mask = tf.reshape(inputs["attention_mask"], (input_shape[0], 1, 1, input_shape[1]))
+        extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1]))

         # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
         # masked positions, this operation will create a tensor which is 0.0 for
@@ -754,25 +733,25 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
         # attention_probs has shape bsz x n_heads x N x N
         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
-        if inputs["head_mask"] is not None:
+        if head_mask is not None:
             raise NotImplementedError
         else:
-            inputs["head_mask"] = [None] * self.num_hidden_layers
+            head_mask = [None] * self.num_hidden_layers

         encoder_outputs = self.encoder(
             embedding_output,
             extended_attention_mask,
-            inputs["head_mask"],
-            inputs["output_attentions"],
-            inputs["output_hidden_states"],
-            inputs["return_dict"],
-            training=inputs["training"],
+            head_mask,
+            output_attentions,
+            output_hidden_states,
+            return_dict,
+            training=training,
         )

         sequence_output = encoder_outputs[0]
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

-        if not inputs["return_dict"]:
+        if not return_dict:
             return (
                 sequence_output,
                 pooled_output,
@@ -929,6 +908,7 @@ class TFMobileBertModel(TFMobileBertPreTrainedModel):
         super().__init__(config, *inputs, **kwargs)
         self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert")

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -950,9 +930,7 @@ class TFMobileBertModel(TFMobileBertPreTrainedModel):
         training=False,
         **kwargs,
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        outputs = self.mobilebert(
             input_ids=input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -963,19 +941,6 @@ class TFMobileBertModel(TFMobileBertPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.mobilebert(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )

         return outputs
@@ -1013,6 +978,7 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel):
         warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
         return self.name + "/" + self.predictions.name + "/" + self.predictions.predictions.name

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=TFMobileBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
     def call(
@@ -1044,10 +1010,8 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel):
         >>> outputs = model(input_ids)
         >>> prediction_scores, seq_relationship_scores = outputs[:2]
         ```"""
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
+        outputs = self.mobilebert(
+            input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -1057,26 +1021,13 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.mobilebert(
-            inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )

         sequence_output, pooled_output = outputs[:2]
         prediction_scores = self.predictions(sequence_output)
         seq_relationship_score = self.seq_relationship(pooled_output)

-        if not inputs["return_dict"]:
+        if not return_dict:
             return (prediction_scores, seq_relationship_score) + outputs[2:]

         return TFMobileBertForPreTrainingOutput(
@@ -1120,6 +1071,7 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModelingLoss):
         warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
         return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1148,10 +1100,8 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModelingLoss):
             config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
             loss is only computed for the tokens with labels
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
+        outputs = self.mobilebert(
+            input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -1160,28 +1110,14 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModelingLoss):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.mobilebert(
-            inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )

         sequence_output = outputs[0]
-        prediction_scores = self.predictions(sequence_output, training=inputs["training"])
+        prediction_scores = self.predictions(sequence_output, training=training)

-        loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], prediction_scores)
+        loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (prediction_scores,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
@@ -1224,6 +1160,7 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextSentencePredictionLoss):
         self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert")
         self.cls = TFMobileBertOnlyNSPHead(config, name="seq_relationship___cls")

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
     def call(
@@ -1259,10 +1196,8 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextSentencePredictionLoss):
         >>> logits = model(encoding["input_ids"], token_type_ids=encoding["token_type_ids"])[0]
         ```"""

-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
+        outputs = self.mobilebert(
+            input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -1271,32 +1206,18 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextSentencePredictionLoss):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            next_sentence_label=next_sentence_label,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.mobilebert(
-            inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )

         pooled_output = outputs[1]
         seq_relationship_scores = self.cls(pooled_output)

         next_sentence_loss = (
             None
-            if inputs["next_sentence_label"] is None
-            else self.hf_compute_loss(labels=inputs["next_sentence_label"], logits=seq_relationship_scores)
+            if next_sentence_label is None
+            else self.hf_compute_loss(labels=next_sentence_label, logits=seq_relationship_scores)
         )

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (seq_relationship_scores,) + outputs[2:]
             return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
@@ -1345,6 +1266,7 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSequenceClassificationLoss):
             config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
         )

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1373,10 +1295,8 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSequenceClassificationLoss):
             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
+        outputs = self.mobilebert(
+            input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -1385,30 +1305,16 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSequenceClassificationLoss):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.mobilebert(
-            inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )

         pooled_output = outputs[1]

-        pooled_output = self.dropout(pooled_output, training=inputs["training"])
+        pooled_output = self.dropout(pooled_output, training=training)
         logits = self.classifier(pooled_output)

-        loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
+        loss = None if labels is None else self.hf_compute_loss(labels, logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
@@ -1453,6 +1359,7 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAnsweringLoss):
             config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
         )

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1486,10 +1393,8 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAnsweringLoss):
             Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
             are not taken into account for computing the loss.
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
+        outputs = self.mobilebert(
+            input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -1498,22 +1403,7 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAnsweringLoss):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            start_positions=start_positions,
-            end_positions=end_positions,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.mobilebert(
-            inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )

         sequence_output = outputs[0]
@@ -1523,12 +1413,11 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAnsweringLoss):
         end_logits = tf.squeeze(end_logits, axis=-1)

         loss = None
-        if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
-            labels = {"start_position": inputs["start_positions"]}
-            labels["end_position"] = inputs["end_positions"]
+        if start_positions is not None and end_positions is not None:
+            labels = {"start_position": start_positions, "end_position": end_positions}
             loss = self.hf_compute_loss(labels, (start_logits, end_logits))

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (start_logits, end_logits) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
@@ -1586,6 +1475,7 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoiceLoss):
         """
         return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(
         MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
     )
@@ -1615,43 +1505,20 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoiceLoss):
             Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
             where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            labels=labels,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if inputs["input_ids"] is not None:
-            num_choices = shape_list(inputs["input_ids"])[1]
-            seq_length = shape_list(inputs["input_ids"])[2]
+        if input_ids is not None:
+            num_choices = shape_list(input_ids)[1]
+            seq_length = shape_list(input_ids)[2]
         else:
-            num_choices = shape_list(inputs["inputs_embeds"])[1]
-            seq_length = shape_list(inputs["inputs_embeds"])[2]
+            num_choices = shape_list(inputs_embeds)[1]
+            seq_length = shape_list(inputs_embeds)[2]

-        flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
-        flat_attention_mask = (
-            tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
-        )
-        flat_token_type_ids = (
-            tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
-        )
-        flat_position_ids = (
-            tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
-        )
+        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
+        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
+        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
+        flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
         flat_inputs_embeds = (
-            tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
-            if inputs["inputs_embeds"] is not None
+            tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
+            if inputs_embeds is not None
             else None
         )
         outputs = self.mobilebert(
@@ -1659,21 +1526,21 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoiceLoss):
             flat_attention_mask,
             flat_token_type_ids,
             flat_position_ids,
-            inputs["head_mask"],
+            head_mask,
             flat_inputs_embeds,
-            inputs["output_attentions"],
-            inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            output_attentions,
+            output_hidden_states,
+            return_dict=return_dict,
+            training=training,
         )
         pooled_output = outputs[1]
-        pooled_output = self.dropout(pooled_output, training=inputs["training"])
+        pooled_output = self.dropout(pooled_output, training=training)
         logits = self.classifier(pooled_output)
         reshaped_logits = tf.reshape(logits, (-1, num_choices))

-        loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], reshaped_logits)
+        loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (reshaped_logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
@@ -1738,6 +1605,7 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenClassificationLoss):
             config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
         )

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1764,10 +1632,8 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenClassificationLoss):
         labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
+        outputs = self.mobilebert(
+            input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -1776,30 +1642,16 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenClassificationLoss):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.mobilebert(
-            inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=return_dict,
-            training=inputs["training"],
         )

         sequence_output = outputs[0]

-        sequence_output = self.dropout(sequence_output, training=inputs["training"])
+        sequence_output = self.dropout(sequence_output, training=training)
         logits = self.classifier(sequence_output)

-        loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
+        loss = None if labels is None else self.hf_compute_loss(labels, logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
...