Unverified Commit cf10d4cf authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Cleaning TensorFlow models (#5229)

* Cleaning TensorFlow models

Update all classes


stylr

* Don't average loss
parent 609e0c58
......@@ -897,15 +897,15 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
def call(
self,
input_ids=None,
inputs=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
labels=None,
training=False,
):
r"""
......@@ -944,9 +944,15 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
loss, logits = outputs[:2]
"""
if isinstance(inputs, (tuple, list)):
labels = inputs[8] if len(inputs) > 8 else labels
if len(inputs) > 8:
inputs = inputs[:8]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.albert(
input_ids,
inputs,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
......@@ -990,15 +996,15 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
def call(
self,
input_ids=None,
inputs=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
labels=None,
training=False,
):
r"""
......@@ -1035,8 +1041,15 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
loss, scores = outputs[:2]
"""
if isinstance(inputs, (tuple, list)):
labels = inputs[8] if len(inputs) > 8 else labels
if len(inputs) > 8:
inputs = inputs[:8]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.albert(
input_ids,
inputs,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
......@@ -1078,19 +1091,16 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
def call(
self,
input_ids=None,
inputs=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
start_positions=None,
end_positions=None,
cls_index=None,
p_mask=None,
is_impossible=None,
output_attentions=None,
output_hidden_states=None,
start_positions=None,
end_positions=None,
training=False,
):
r"""
......@@ -1139,8 +1149,17 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
"""
if isinstance(inputs, (tuple, list)):
start_positions = inputs[8] if len(inputs) > 8 else start_positions
end_positions = inputs[9] if len(inputs) > 9 else end_positions
if len(inputs) > 8:
inputs = inputs[:8]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
outputs = self.albert(
input_ids,
inputs,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
......@@ -1202,9 +1221,9 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
labels=None,
training=False,
):
r"""
......@@ -1255,8 +1274,10 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, dict):
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
labels = inputs[8] if len(inputs) > 8 else labels
assert len(inputs) <= 9, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
......@@ -1264,7 +1285,9 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 7, "Too many inputs."
output_hidden_states = inputs.get("output_hidden_states", output_attentions)
labels = inputs.get("labels", labels)
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
......
......@@ -932,15 +932,15 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
def call(
self,
input_ids=None,
inputs=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
labels=None,
training=False,
):
r"""
......@@ -979,9 +979,15 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
loss, logits = outputs[:2]
"""
if isinstance(inputs, (tuple, list)):
labels = inputs[8] if len(inputs) > 8 else labels
if len(inputs) > 8:
inputs = inputs[:8]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.bert(
input_ids,
inputs,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
......@@ -1039,9 +1045,9 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
labels=None,
training=False,
):
r"""
......@@ -1092,7 +1098,9 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
assert len(inputs) <= 7, "Too many inputs."
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
labels = inputs[8] if len(inputs) > 8 else labels
assert len(inputs) <= 9, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
......@@ -1101,7 +1109,9 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 7, "Too many inputs."
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
labels = inputs.get("labels", labels)
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
......@@ -1169,15 +1179,15 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
def call(
self,
input_ids=None,
inputs=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
labels=None,
training=False,
):
r"""
......@@ -1214,8 +1224,15 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
loss, scores = outputs[:2]
"""
if isinstance(inputs, (tuple, list)):
labels = inputs[8] if len(inputs) > 8 else labels
if len(inputs) > 8:
inputs = inputs[:8]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.bert(
input_ids,
inputs,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
......@@ -1258,19 +1275,16 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
def call(
self,
input_ids=None,
inputs=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
start_positions=None,
end_positions=None,
cls_index=None,
p_mask=None,
is_impossible=None,
output_attentions=None,
output_hidden_states=None,
start_positions=None,
end_positions=None,
training=False,
):
r"""
......@@ -1317,8 +1331,17 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
assert answer == "a nice puppet"
"""
if isinstance(inputs, (tuple, list)):
start_positions = inputs[8] if len(inputs) > 8 else start_positions
end_positions = inputs[9] if len(inputs) > 9 else end_positions
if len(inputs) > 8:
inputs = inputs[:8]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
outputs = self.bert(
input_ids,
inputs,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
......
......@@ -715,13 +715,13 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
def call(
self,
input_ids=None,
inputs=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
labels=None,
training=False,
):
r"""
......@@ -760,8 +760,15 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
loss, logits = outputs[:2]
"""
if isinstance(inputs, (tuple, list)):
labels = inputs[6] if len(inputs) > 6 else labels
if len(inputs) > 6:
inputs = inputs[:6]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
distilbert_output = self.distilbert(
input_ids,
inputs,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
......@@ -804,13 +811,13 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
def call(
self,
input_ids=None,
inputs=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
labels=None,
training=False,
):
r"""
......@@ -847,8 +854,15 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
loss, scores = outputs[:2]
"""
if isinstance(inputs, (tuple, list)):
labels = inputs[6] if len(inputs) > 6 else labels
if len(inputs) > 6:
inputs = inputs[:6]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.distilbert(
input_ids,
inputs,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
......@@ -862,7 +876,7 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
sequence_output = self.dropout(sequence_output, training=training)
logits = self.classifier(sequence_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
outputs = (logits,) + outputs[1:] # add hidden states and attention if they are here
if labels is not None:
loss = self.compute_loss(labels, logits)
......@@ -881,7 +895,7 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
super().__init__(config, *inputs, **kwargs)
self.distilbert = TFDistilBertMainLayer(config, name="distilbert")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.dropout = tf.keras.layers.Dropout(config.seq_classif_dropout)
self.pre_classifier = tf.keras.layers.Dense(
config.dim,
kernel_initializer=get_initializer(config.initializer_range),
......@@ -908,9 +922,9 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
attention_mask=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
labels=None,
training=False,
):
r"""
......@@ -958,13 +972,19 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
head_mask = inputs[2] if len(inputs) > 2 else head_mask
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
assert len(inputs) <= 4, "Too many inputs."
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
labels = inputs[6] if len(inputs) > 6 else labels
assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
assert len(inputs) <= 4, "Too many inputs."
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
labels = inputs.get("labels", labels)
assert len(inputs) <= 7, "Too many inputs."
else:
input_ids = inputs
......@@ -977,12 +997,17 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
if inputs_embeds is not None
else None
)
flat_inputs = [
flat_input_ids,
flat_attention_mask,
head_mask,
inputs_embeds,
flat_inputs_embeds,
output_attentions,
output_hidden_states,
]
......@@ -1023,17 +1048,14 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
def call(
self,
input_ids=None,
inputs=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
start_positions=None,
end_positions=None,
cls_index=None,
p_mask=None,
is_impossible=None,
output_attentions=None,
output_hidden_states=None,
start_positions=None,
end_positions=None,
training=False,
):
r"""
......@@ -1079,8 +1101,17 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
"""
if isinstance(inputs, (tuple, list)):
start_positions = inputs[6] if len(inputs) > 6 else start_positions
end_positions = inputs[7] if len(inputs) > 7 else end_positions
if len(inputs) > 6:
inputs = inputs[:6]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
distilbert_output = self.distilbert(
input_ids,
inputs,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
......
......@@ -613,15 +613,15 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
def call(
self,
input_ids=None,
inputs=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
labels=None,
training=False,
):
r"""
......@@ -658,9 +658,15 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
loss, scores = outputs[:2]
"""
if isinstance(inputs, (tuple, list)):
labels = inputs[8] if len(inputs) > 8 else labels
if len(inputs) > 8:
inputs = inputs[:8]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
discriminator_hidden_states = self.electra(
input_ids,
inputs,
attention_mask,
token_type_ids,
position_ids,
......@@ -701,19 +707,16 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
def call(
self,
input_ids=None,
inputs=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
start_positions=None,
end_positions=None,
cls_index=None,
p_mask=None,
is_impossible=None,
output_attentions=None,
output_hidden_states=None,
start_positions=None,
end_positions=None,
training=False,
):
r"""
......@@ -760,8 +763,17 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
"""
if isinstance(inputs, (tuple, list)):
start_positions = inputs[8] if len(inputs) > 8 else start_positions
end_positions = inputs[9] if len(inputs) > 9 else end_positions
if len(inputs) > 8:
inputs = inputs[:8]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
discriminator_hidden_states = self.electra(
input_ids,
inputs,
attention_mask,
token_type_ids,
position_ids,
......
......@@ -1080,15 +1080,15 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
def call(
self,
input_ids=None,
inputs=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
labels=None,
training=False,
):
r"""
......@@ -1127,9 +1127,15 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
loss, logits = outputs[:2]
"""
if isinstance(inputs, (tuple, list)):
labels = inputs[8] if len(inputs) > 8 else labels
if len(inputs) > 8:
inputs = inputs[:8]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.mobilebert(
input_ids,
inputs,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
......@@ -1172,19 +1178,16 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
def call(
self,
input_ids=None,
inputs=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
start_positions=None,
end_positions=None,
cls_index=None,
p_mask=None,
is_impossible=None,
output_attentions=None,
output_hidden_states=None,
start_positions=None,
end_positions=None,
training=False,
):
r"""
......@@ -1231,8 +1234,17 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
assert answer == "a nice puppet"
"""
if isinstance(inputs, (tuple, list)):
start_positions = inputs[8] if len(inputs) > 8 else start_positions
end_positions = inputs[9] if len(inputs) > 9 else end_positions
if len(inputs) > 8:
inputs = inputs[:8]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
outputs = self.mobilebert(
input_ids,
inputs,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
......@@ -1294,9 +1306,9 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
labels=None,
training=False,
):
r"""
......@@ -1348,7 +1360,8 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
assert len(inputs) <= 8, "Too many inputs."
labels = inputs[8] if len(inputs) > 8 else labels
assert len(inputs) <= 9, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
......@@ -1358,7 +1371,8 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
assert len(inputs) <= 8, "Too many inputs."
labels = inputs.get("labels", labels)
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
......@@ -1426,15 +1440,15 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
def call(
self,
input_ids=None,
inputs=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
labels=None,
training=False,
):
r"""
......@@ -1471,8 +1485,15 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
loss, scores = outputs[:2]
"""
if isinstance(inputs, (tuple, list)):
labels = inputs[8] if len(inputs) > 8 else labels
if len(inputs) > 8:
inputs = inputs[:8]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.mobilebert(
input_ids,
inputs,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
......
......@@ -33,6 +33,7 @@ from .modeling_tf_utils import (
keras_serializable,
shape_list,
)
from .tokenization_utils_base import BatchEncoding
logger = logging.getLogger(__name__)
......@@ -359,15 +360,15 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
@add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
def call(
self,
input_ids=None,
inputs=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
labels=None,
training=False,
):
r"""
......@@ -400,8 +401,15 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
loss, logits = outputs[:2]
"""
if isinstance(inputs, (tuple, list)):
labels = inputs[8] if len(inputs) > 8 else labels
if len(inputs) > 8:
inputs = inputs[:8]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.roberta(
input_ids,
inputs,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
......@@ -457,9 +465,9 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
labels=None,
training=False,
):
r"""
......@@ -509,15 +517,21 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
assert len(inputs) <= 6, "Too many inputs."
elif isinstance(inputs, dict):
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
labels = inputs[8] if len(inputs) > 8 else labels
assert len(inputs) <= 9, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
assert len(inputs) <= 6, "Too many inputs."
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_attentions)
labels = inputs.get("labels", labels)
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
......@@ -580,15 +594,15 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
@add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
def call(
self,
input_ids=None,
inputs=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
labels=None,
training=False,
):
r"""
......@@ -625,8 +639,15 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
loss, scores = outputs[:2]
"""
if isinstance(inputs, (tuple, list)):
labels = inputs[8] if len(inputs) > 8 else labels
if len(inputs) > 8:
inputs = inputs[:8]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.roberta(
input_ids,
inputs,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
......@@ -668,19 +689,16 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
@add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
def call(
self,
input_ids=None,
inputs=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
start_positions=None,
end_positions=None,
cls_index=None,
p_mask=None,
is_impossible=None,
output_attentions=None,
output_hidden_states=None,
start_positions=None,
end_positions=None,
training=False,
):
r"""
......@@ -729,8 +747,17 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
"""
if isinstance(inputs, (tuple, list)):
start_positions = inputs[8] if len(inputs) > 8 else start_positions
end_positions = inputs[9] if len(inputs) > 9 else end_positions
if len(inputs) > 8:
inputs = inputs[:8]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
outputs = self.roberta(
input_ids,
inputs,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
......
......@@ -759,7 +759,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
@add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
def call(
self,
input_ids,
inputs=None,
attention_mask=None,
langs=None,
token_type_ids=None,
......@@ -768,9 +768,9 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
cache=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
labels=None,
training=False,
):
r"""
......@@ -809,8 +809,15 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
loss, logits = outputs[:2]
"""
if isinstance(inputs, (tuple, list)):
labels = inputs[11] if len(inputs) > 11 else labels
if len(inputs) > 11:
inputs = inputs[:11]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
input_ids,
inputs,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
......@@ -1090,7 +1097,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
@add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
def call(
self,
input_ids=None,
inputs=None,
attention_mask=None,
langs=None,
token_type_ids=None,
......@@ -1099,13 +1106,10 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
cache=None,
head_mask=None,
inputs_embeds=None,
start_positions=None,
end_positions=None,
cls_index=None,
p_mask=None,
is_impossible=None,
output_attentions=None,
output_hidden_states=None,
start_positions=None,
end_positions=None,
training=False,
):
r"""
......@@ -1151,9 +1155,17 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
"""
if isinstance(inputs, (tuple, list)):
start_positions = inputs[11] if len(inputs) > 11 else start_positions
end_positions = inputs[12] if len(inputs) > 12 else end_positions
if len(inputs) > 11:
inputs = inputs[:11]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
transformer_outputs = self.transformer(
input_ids,
inputs,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
......
......@@ -988,7 +988,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
def call(
self,
input_ids=None,
inputs=None,
attention_mask=None,
mems=None,
perm_mask=None,
......@@ -998,9 +998,9 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
head_mask=None,
inputs_embeds=None,
use_cache=True,
labels=None,
output_attentions=None,
output_hidden_states=None,
labels=None,
training=False,
):
r"""
......@@ -1043,8 +1043,15 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
loss, logits = outputs[:2]
"""
if isinstance(inputs, (tuple, list)):
labels = inputs[12] if len(inputs) > 12 else labels
if len(inputs) > 12:
inputs = inputs[:12]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
input_ids,
inputs,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
......@@ -1100,7 +1107,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
def call(
self,
inputs,
inputs=None,
token_type_ids=None,
input_mask=None,
attention_mask=None,
......@@ -1110,9 +1117,9 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
head_mask=None,
inputs_embeds=None,
use_cache=True,
labels=None,
output_attentions=None,
output_hidden_states=None,
labels=None,
training=False,
):
r"""
......@@ -1168,7 +1175,8 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
use_cache = inputs[9] if len(inputs) > 9 else use_cache
output_attentions = inputs[10] if len(inputs) > 10 else output_attentions
output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states
assert len(inputs) <= 12, "Too many inputs."
labels = inputs[12] if len(inputs) > 12 else labels
assert len(inputs) <= 13, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
......@@ -1181,8 +1189,9 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_attentions)
assert len(inputs) <= 12, "Too many inputs."
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
labels = inputs.get("labels", labels)
assert len(inputs) <= 13, "Too many inputs."
else:
input_ids = inputs
......@@ -1197,6 +1206,11 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_input_mask = tf.reshape(input_mask, (-1, seq_length)) if input_mask is not None else None
flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
if inputs_embeds is not None
else None
)
flat_inputs = [
flat_input_ids,
......@@ -1207,7 +1221,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
flat_token_type_ids,
flat_input_mask,
head_mask,
inputs_embeds,
flat_inputs_embeds,
use_cache,
output_attentions,
output_hidden_states,
......@@ -1245,7 +1259,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
def call(
self,
input_ids=None,
inputs=None,
attention_mask=None,
mems=None,
perm_mask=None,
......@@ -1255,9 +1269,9 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
head_mask=None,
inputs_embeds=None,
use_cache=True,
labels=None,
output_attentions=None,
output_hidden_states=None,
labels=None,
training=False,
):
r"""
......@@ -1298,8 +1312,15 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
loss, scores = outputs[:2]
"""
if isinstance(inputs, (tuple, list)):
labels = inputs[12] if len(inputs) > 12 else labels
if len(inputs) > 12:
inputs = inputs[:12]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
input_ids,
inputs,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
......@@ -1342,7 +1363,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
def call(
self,
input_ids=None,
inputs=None,
attention_mask=None,
mems=None,
perm_mask=None,
......@@ -1352,13 +1373,10 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
head_mask=None,
inputs_embeds=None,
use_cache=True,
start_positions=None,
end_positions=None,
cls_index=None,
p_mask=None,
is_impossible=None,
output_attentions=None,
output_hidden_states=None,
start_positions=None,
end_positions=None,
training=False,
):
r"""
......@@ -1410,8 +1428,17 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
"""
if isinstance(inputs, (tuple, list)):
start_positions = inputs[12] if len(inputs) > 12 else start_positions
end_positions = inputs[13] if len(inputs) > 13 else end_positions
if len(inputs) > 12:
inputs = inputs[:12]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
transformer_outputs = self.transformer(
input_ids,
inputs,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
......
......@@ -15,6 +15,7 @@
import copy
import inspect
import os
import random
import tempfile
......@@ -35,6 +36,9 @@ if is_tf_available():
TFAdaptiveEmbedding,
TFSharedEmbeddings,
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
)
if _tf_gpu_memory_limit is not None:
......@@ -71,14 +75,25 @@ class TFModelTesterMixin:
test_resize_embeddings = True
is_encoder_decoder = False
def _prepare_for_class(self, inputs_dict, model_class):
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
return {
inputs_dict = {
k: tf.tile(tf.expand_dims(v, 1), (1, self.model_tester.num_choices, 1))
if isinstance(v, tf.Tensor) and v.ndim != 0
else v
for k, v in inputs_dict.items()
}
if return_labels:
if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
inputs_dict["labels"] = tf.ones(self.model_tester.batch_size)
elif model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size)
inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size)
elif model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size)
elif model_class in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values():
inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, self.model_tester.seq_length))
return inputs_dict
def test_initialization(self):
......@@ -572,6 +587,51 @@ class TFModelTesterMixin:
generated_ids = output_tokens[:, input_ids.shape[-1] :]
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
def test_loss_computation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
if getattr(model, "compute_loss", None):
# The number of elements in the loss should be the same as the number of elements in the label
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
added_label = prepared_for_class[list(prepared_for_class.keys() - inputs_dict.keys())[0]]
loss_size = tf.size(added_label)
# Test that model correctly compute the loss with kwargs
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
input_ids = prepared_for_class.pop("input_ids")
loss = model(input_ids, **prepared_for_class)[0]
self.assertEqual(loss.shape, [loss_size])
# Test that model correctly compute the loss with a dict
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
loss = model(prepared_for_class)[0]
self.assertEqual(loss.shape, [loss_size])
# Test that model correctly compute the loss with a tuple
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
# Get keys that were added with the _prepare_for_class function
label_keys = prepared_for_class.keys() - inputs_dict.keys()
signature = inspect.getfullargspec(model.call)[0]
# Create a dictionary holding the location of the tensors in the tuple
tuple_index_mapping = {1: "input_ids"}
for label_key in label_keys:
label_key_index = signature.index(label_key)
tuple_index_mapping[label_key_index] = label_key
sorted_tuple_index_mapping = sorted(tuple_index_mapping.items())
# Initialize a list with None, update the values and convert to a tuple
list_input = [None] * sorted_tuple_index_mapping[-1][0]
for index, value in sorted_tuple_index_mapping:
list_input[index - 1] = prepared_for_class[value]
tuple_input = tuple(list_input)
# Send to model
loss = model(tuple_input)[0]
self.assertEqual(loss.shape, [loss_size])
def _generate_random_bad_tokens(self, num_bad_tokens, model):
# special tokens cannot be bad tokens
special_tokens = []
......
......@@ -24,11 +24,14 @@ from .utils import require_tf
if is_tf_available():
import tensorflow as tf
from transformers.modeling_tf_distilbert import (
TFDistilBertModel,
TFDistilBertForMaskedLM,
TFDistilBertForQuestionAnswering,
TFDistilBertForSequenceClassification,
TFDistilBertForTokenClassification,
TFDistilBertForMultipleChoice,
)
......@@ -147,6 +150,35 @@ class TFDistilBertModelTester:
}
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])
def create_and_check_distilbert_for_multiple_choice(
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_choices = self.num_choices
model = TFDistilBertForMultipleChoice(config)
multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
inputs = {
"input_ids": multiple_choice_inputs_ids,
"attention_mask": multiple_choice_input_mask,
}
(logits,) = model(inputs)
result = {
"logits": logits.numpy(),
}
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
def create_and_check_distilbert_for_token_classification(
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
model = TFDistilBertForTokenClassification(config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask}
(logits,) = model(inputs)
result = {
"logits": logits.numpy(),
}
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels])
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, input_mask, sequence_labels, token_labels, choice_labels) = config_and_inputs
......@@ -163,6 +195,8 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
TFDistilBertForMaskedLM,
TFDistilBertForQuestionAnswering,
TFDistilBertForSequenceClassification,
TFDistilBertForTokenClassification,
TFDistilBertForMultipleChoice,
)
if is_tf_available()
else None
......@@ -194,6 +228,14 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_distilbert_for_sequence_classification(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_distilbert_for_multiple_choice(*config_and_inputs)
def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_distilbert_for_token_classification(*config_and_inputs)
# @slow
# def test_model_from_pretrained(self):
# for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
......
......@@ -29,6 +29,7 @@ if is_tf_available():
TFElectraForMaskedLM,
TFElectraForPreTraining,
TFElectraForTokenClassification,
TFElectraForQuestionAnswering,
)
......@@ -137,6 +138,19 @@ class TFElectraModelTester:
}
self.parent.assertListEqual(list(result["prediction_scores"].shape), [self.batch_size, self.seq_length])
def create_and_check_electra_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFElectraForQuestionAnswering(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
start_logits, end_logits = model(inputs)
result = {
"start_logits": start_logits.numpy(),
"end_logits": end_logits.numpy(),
}
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
def create_and_check_electra_for_token_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
......@@ -192,6 +206,10 @@ class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_electra_for_pretraining(*config_and_inputs)
def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_electra_for_question_answering(*config_and_inputs)
def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_electra_for_token_classification(*config_and_inputs)
......
......@@ -32,6 +32,7 @@ if is_tf_available():
TFRobertaForSequenceClassification,
TFRobertaForTokenClassification,
TFRobertaForQuestionAnswering,
TFRobertaForMultipleChoice,
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
)
......@@ -154,6 +155,25 @@ class TFRobertaModelTester:
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
def create_and_check_roberta_for_multiple_choice(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_choices = self.num_choices
model = TFRobertaForMultipleChoice(config=config)
multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1))
inputs = {
"input_ids": multiple_choice_inputs_ids,
"attention_mask": multiple_choice_input_mask,
"token_type_ids": multiple_choice_token_type_ids,
}
(logits,) = model(inputs)
result = {
"logits": logits.numpy(),
}
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
......@@ -207,6 +227,10 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_roberta_for_question_answering(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_roberta_for_multiple_choice(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
for model_name in TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
......
......@@ -33,6 +33,7 @@ if is_tf_available():
TFXLNetForSequenceClassification,
TFXLNetForTokenClassification,
TFXLNetForQuestionAnsweringSimple,
TFXLNetForMultipleChoice,
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
)
......@@ -66,6 +67,7 @@ class TFXLNetModelTester:
self.bos_token_id = 1
self.eos_token_id = 2
self.pad_token_id = 5
self.num_choices = 4
def prepare_config_and_inputs(self):
input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
......@@ -316,6 +318,36 @@ class TFXLNetModelTester:
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
)
def create_and_check_xlnet_for_multiple_choice(
self,
config,
input_ids_1,
input_ids_2,
input_ids_q,
perm_mask,
input_mask,
target_mapping,
segment_ids,
lm_labels,
sequence_labels,
is_impossible_labels,
):
config.num_choices = self.num_choices
model = TFXLNetForMultipleChoice(config=config)
multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids_1, 1), (1, self.num_choices, 1))
multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
multiple_choice_token_type_ids = tf.tile(tf.expand_dims(segment_ids, 1), (1, self.num_choices, 1))
inputs = {
"input_ids": multiple_choice_inputs_ids,
"attention_mask": multiple_choice_input_mask,
"token_type_ids": multiple_choice_token_type_ids,
}
(logits,) = model(inputs)
result = {
"logits": logits.numpy(),
}
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
......@@ -345,6 +377,7 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
TFXLNetForSequenceClassification,
TFXLNetForTokenClassification,
TFXLNetForQuestionAnsweringSimple,
TFXLNetForMultipleChoice,
)
if is_tf_available()
else ()
......@@ -385,6 +418,10 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlnet_qa(*config_and_inputs)
def test_xlnet_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlnet_for_multiple_choice(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
for model_name in TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment