"pipelines/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "80beca52c4269ac5dcd4955ae496f8c3a44d20ef"
Unverified Commit cf10d4cf authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Cleaning TensorFlow models (#5229)

* Cleaning TensorFlow models

Update all classes


style

* Don't average loss
parent 609e0c58
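
In short, every TF task model now follows one input contract: the label arguments (`labels`, or `start_positions`/`end_positions` for question answering) sit after `output_attentions` and `output_hidden_states` in `call`, they may also be supplied inside the input tuple or dict, and the computed loss keeps one entry per label element instead of being averaged inside the model. A minimal sketch of the resulting usage, assuming a transformers build that includes this commit (checkpoint, sentence, and label values are illustrative, not part of the diff):

    import tensorflow as tf
    from transformers import BertTokenizer, TFBertForSequenceClassification

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = TFBertForSequenceClassification.from_pretrained("bert-base-uncased")

    encoded = tokenizer.encode_plus("A test sentence.", return_tensors="tf")
    labels = tf.constant([1])

    # 1) labels as a keyword argument (it now follows output_attentions /
    #    output_hidden_states in the signature)
    loss, logits = model(encoded["input_ids"], labels=labels)[:2]

    # 2) labels inside the input dict; the model pops "labels" out before
    #    forwarding the remaining entries to the encoder
    dict_inputs = dict(encoded)
    dict_inputs["labels"] = labels
    loss, logits = model(dict_inputs)[:2]

    # "Don't average loss": one loss value per label element, not a scalar
    assert loss.shape == (1,)

This per-element loss shape is exactly what the new test_loss_computation test at the end of the diff asserts with self.assertEqual(loss.shape, [loss_size]).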
@@ -897,15 +897,15 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClassificationLoss):
     @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        labels=None,
         output_attentions=None,
         output_hidden_states=None,
+        labels=None,
         training=False,
     ):
         r"""
@@ -944,9 +944,15 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClassificationLoss):
         loss, logits = outputs[:2]

         """
+        if isinstance(inputs, (tuple, list)):
+            labels = inputs[8] if len(inputs) > 8 else labels
+            if len(inputs) > 8:
+                inputs = inputs[:8]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            labels = inputs.pop("labels", labels)
+
         outputs = self.albert(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -990,15 +996,15 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificationLoss):
     @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        labels=None,
         output_attentions=None,
         output_hidden_states=None,
+        labels=None,
         training=False,
     ):
         r"""
r""" r"""
...@@ -1035,8 +1041,15 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat ...@@ -1035,8 +1041,15 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
loss, scores = outputs[:2] loss, scores = outputs[:2]
""" """
if isinstance(inputs, (tuple, list)):
labels = inputs[8] if len(inputs) > 8 else labels
if len(inputs) > 8:
inputs = inputs[:8]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.albert( outputs = self.albert(
input_ids, inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
@@ -1078,19 +1091,16 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringLoss):
     @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        start_positions=None,
-        end_positions=None,
-        cls_index=None,
-        p_mask=None,
-        is_impossible=None,
         output_attentions=None,
         output_hidden_states=None,
+        start_positions=None,
+        end_positions=None,
         training=False,
     ):
         r"""
@@ -1139,8 +1149,17 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringLoss):
         answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])

         """
+        if isinstance(inputs, (tuple, list)):
+            start_positions = inputs[8] if len(inputs) > 8 else start_positions
+            end_positions = inputs[9] if len(inputs) > 9 else end_positions
+            if len(inputs) > 8:
+                inputs = inputs[:8]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            start_positions = inputs.pop("start_positions", start_positions)
+            end_positions = inputs.pop("end_positions", end_positions)
+
         outputs = self.albert(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -1202,9 +1221,9 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        labels=None,
         output_attentions=None,
         output_hidden_states=None,
+        labels=None,
         training=False,
     ):
         r"""
@@ -1255,8 +1274,10 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
             head_mask = inputs[4] if len(inputs) > 4 else head_mask
             inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
             output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
-            assert len(inputs) <= 7, "Too many inputs."
-        elif isinstance(inputs, dict):
+            output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
+            labels = inputs[8] if len(inputs) > 8 else labels
+            assert len(inputs) <= 9, "Too many inputs."
+        elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             attention_mask = inputs.get("attention_mask", attention_mask)
             token_type_ids = inputs.get("token_type_ids", token_type_ids)
@@ -1264,7 +1285,9 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
             head_mask = inputs.get("head_mask", head_mask)
             inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
             output_attentions = inputs.get("output_attentions", output_attentions)
-            assert len(inputs) <= 7, "Too many inputs."
+            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
+            labels = inputs.get("labels", labels)
+            assert len(inputs) <= 9, "Too many inputs."
         else:
             input_ids = inputs
...
@@ -932,15 +932,15 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassificationLoss):
     @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        labels=None,
         output_attentions=None,
         output_hidden_states=None,
+        labels=None,
         training=False,
     ):
         r"""
@@ -979,9 +979,15 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassificationLoss):
         loss, logits = outputs[:2]

         """
+        if isinstance(inputs, (tuple, list)):
+            labels = inputs[8] if len(inputs) > 8 else labels
+            if len(inputs) > 8:
+                inputs = inputs[:8]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            labels = inputs.pop("labels", labels)
+
         outputs = self.bert(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -1039,9 +1045,9 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        labels=None,
         output_attentions=None,
         output_hidden_states=None,
+        labels=None,
         training=False,
     ):
         r"""
@@ -1092,7 +1098,9 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
             head_mask = inputs[4] if len(inputs) > 4 else head_mask
             inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
             output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
-            assert len(inputs) <= 7, "Too many inputs."
+            output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
+            labels = inputs[8] if len(inputs) > 8 else labels
+            assert len(inputs) <= 9, "Too many inputs."
         elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             attention_mask = inputs.get("attention_mask", attention_mask)
@@ -1101,7 +1109,9 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
             head_mask = inputs.get("head_mask", head_mask)
             inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
             output_attentions = inputs.get("output_attentions", output_attentions)
-            assert len(inputs) <= 7, "Too many inputs."
+            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
+            labels = inputs.get("labels", labels)
+            assert len(inputs) <= 9, "Too many inputs."
         else:
             input_ids = inputs
@@ -1169,15 +1179,15 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
     @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        labels=None,
         output_attentions=None,
         output_hidden_states=None,
+        labels=None,
         training=False,
     ):
         r"""
@@ -1214,8 +1224,15 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
         loss, scores = outputs[:2]

         """
+        if isinstance(inputs, (tuple, list)):
+            labels = inputs[8] if len(inputs) > 8 else labels
+            if len(inputs) > 8:
+                inputs = inputs[:8]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            labels = inputs.pop("labels", labels)
+
         outputs = self.bert(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -1258,19 +1275,16 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
     @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        start_positions=None,
-        end_positions=None,
-        cls_index=None,
-        p_mask=None,
-        is_impossible=None,
         output_attentions=None,
         output_hidden_states=None,
+        start_positions=None,
+        end_positions=None,
         training=False,
     ):
         r"""
@@ -1317,8 +1331,17 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
         assert answer == "a nice puppet"

         """
+        if isinstance(inputs, (tuple, list)):
+            start_positions = inputs[8] if len(inputs) > 8 else start_positions
+            end_positions = inputs[9] if len(inputs) > 9 else end_positions
+            if len(inputs) > 8:
+                inputs = inputs[:8]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            start_positions = inputs.pop("start_positions", start_positions)
+            end_positions = inputs.pop("end_positions", end_positions)
+
         outputs = self.bert(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
...
@@ -715,13 +715,13 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSequenceClassificationLoss):
     @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         head_mask=None,
         inputs_embeds=None,
-        labels=None,
         output_attentions=None,
         output_hidden_states=None,
+        labels=None,
         training=False,
     ):
         r"""
@@ -760,8 +760,15 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSequenceClassificationLoss):
         loss, logits = outputs[:2]

         """
+        if isinstance(inputs, (tuple, list)):
+            labels = inputs[6] if len(inputs) > 6 else labels
+            if len(inputs) > 6:
+                inputs = inputs[:6]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            labels = inputs.pop("labels", labels)
+
         distilbert_output = self.distilbert(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
@@ -804,13 +811,13 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenClassificationLoss):
     @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         head_mask=None,
         inputs_embeds=None,
-        labels=None,
         output_attentions=None,
         output_hidden_states=None,
+        labels=None,
         training=False,
     ):
         r"""
@@ -847,8 +854,15 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenClassificationLoss):
         loss, scores = outputs[:2]

         """
+        if isinstance(inputs, (tuple, list)):
+            labels = inputs[6] if len(inputs) > 6 else labels
+            if len(inputs) > 6:
+                inputs = inputs[:6]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            labels = inputs.pop("labels", labels)
+
         outputs = self.distilbert(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
@@ -862,7 +876,7 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenClassificationLoss):
         sequence_output = self.dropout(sequence_output, training=training)
         logits = self.classifier(sequence_output)

-        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
+        outputs = (logits,) + outputs[1:]  # add hidden states and attention if they are here
         if labels is not None:
             loss = self.compute_loss(labels, logits)
@@ -881,7 +895,7 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoiceLoss):
         super().__init__(config, *inputs, **kwargs)

         self.distilbert = TFDistilBertMainLayer(config, name="distilbert")
-        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
+        self.dropout = tf.keras.layers.Dropout(config.seq_classif_dropout)
         self.pre_classifier = tf.keras.layers.Dense(
             config.dim,
             kernel_initializer=get_initializer(config.initializer_range),
@@ -908,9 +922,9 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoiceLoss):
         attention_mask=None,
         head_mask=None,
         inputs_embeds=None,
-        labels=None,
         output_attentions=None,
         output_hidden_states=None,
+        labels=None,
         training=False,
     ):
         r"""
@@ -958,13 +972,19 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoiceLoss):
             attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
             head_mask = inputs[2] if len(inputs) > 2 else head_mask
             inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
-            assert len(inputs) <= 4, "Too many inputs."
+            output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
+            output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
+            labels = inputs[6] if len(inputs) > 6 else labels
+            assert len(inputs) <= 7, "Too many inputs."
         elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             attention_mask = inputs.get("attention_mask", attention_mask)
             head_mask = inputs.get("head_mask", head_mask)
             inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
-            assert len(inputs) <= 4, "Too many inputs."
+            output_attentions = inputs.get("output_attentions", output_attentions)
+            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
+            labels = inputs.get("labels", labels)
+            assert len(inputs) <= 7, "Too many inputs."
         else:
             input_ids = inputs
@@ -977,12 +997,17 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoiceLoss):
         flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
         flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
+        flat_inputs_embeds = (
+            tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
+            if inputs_embeds is not None
+            else None
+        )
         flat_inputs = [
             flat_input_ids,
             flat_attention_mask,
             head_mask,
-            inputs_embeds,
+            flat_inputs_embeds,
             output_attentions,
             output_hidden_states,
         ]
@@ -1023,17 +1048,14 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAnsweringLoss):
     @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         head_mask=None,
         inputs_embeds=None,
-        start_positions=None,
-        end_positions=None,
-        cls_index=None,
-        p_mask=None,
-        is_impossible=None,
         output_attentions=None,
         output_hidden_states=None,
+        start_positions=None,
+        end_positions=None,
         training=False,
     ):
         r"""
@@ -1079,8 +1101,17 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAnsweringLoss):
         answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])

         """
+        if isinstance(inputs, (tuple, list)):
+            start_positions = inputs[6] if len(inputs) > 6 else start_positions
+            end_positions = inputs[7] if len(inputs) > 7 else end_positions
+            if len(inputs) > 6:
+                inputs = inputs[:6]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            start_positions = inputs.pop("start_positions", start_positions)
+            end_positions = inputs.pop("end_positions", end_positions)
+
         distilbert_output = self.distilbert(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
...
@@ -613,15 +613,15 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassificationLoss):
     @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        labels=None,
         output_attentions=None,
         output_hidden_states=None,
+        labels=None,
         training=False,
     ):
         r"""
@@ -658,9 +658,15 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassificationLoss):
         loss, scores = outputs[:2]

         """
+        if isinstance(inputs, (tuple, list)):
+            labels = inputs[8] if len(inputs) > 8 else labels
+            if len(inputs) > 8:
+                inputs = inputs[:8]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            labels = inputs.pop("labels", labels)
+
         discriminator_hidden_states = self.electra(
-            input_ids,
+            inputs,
             attention_mask,
             token_type_ids,
             position_ids,
@@ -701,19 +707,16 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnsweringLoss):
     @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        start_positions=None,
-        end_positions=None,
-        cls_index=None,
-        p_mask=None,
-        is_impossible=None,
         output_attentions=None,
         output_hidden_states=None,
+        start_positions=None,
+        end_positions=None,
         training=False,
     ):
         r"""
@@ -760,8 +763,17 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnsweringLoss):
         answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])

         """
+        if isinstance(inputs, (tuple, list)):
+            start_positions = inputs[8] if len(inputs) > 8 else start_positions
+            end_positions = inputs[9] if len(inputs) > 9 else end_positions
+            if len(inputs) > 8:
+                inputs = inputs[:8]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            start_positions = inputs.pop("start_positions", start_positions)
+            end_positions = inputs.pop("end_positions", end_positions)
+
         discriminator_hidden_states = self.electra(
-            input_ids,
+            inputs,
             attention_mask,
             token_type_ids,
             position_ids,
...
@@ -1080,15 +1080,15 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSequenceClassificationLoss):
     @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        labels=None,
         output_attentions=None,
         output_hidden_states=None,
+        labels=None,
         training=False,
     ):
         r"""
@@ -1127,9 +1127,15 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSequenceClassificationLoss):
         loss, logits = outputs[:2]

         """
+        if isinstance(inputs, (tuple, list)):
+            labels = inputs[8] if len(inputs) > 8 else labels
+            if len(inputs) > 8:
+                inputs = inputs[:8]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            labels = inputs.pop("labels", labels)
+
         outputs = self.mobilebert(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -1172,19 +1178,16 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAnsweringLoss):
     @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        start_positions=None,
-        end_positions=None,
-        cls_index=None,
-        p_mask=None,
-        is_impossible=None,
         output_attentions=None,
         output_hidden_states=None,
+        start_positions=None,
+        end_positions=None,
         training=False,
     ):
         r"""
@@ -1231,8 +1234,17 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAnsweringLoss):
         assert answer == "a nice puppet"

         """
+        if isinstance(inputs, (tuple, list)):
+            start_positions = inputs[8] if len(inputs) > 8 else start_positions
+            end_positions = inputs[9] if len(inputs) > 9 else end_positions
+            if len(inputs) > 8:
+                inputs = inputs[:8]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            start_positions = inputs.pop("start_positions", start_positions)
+            end_positions = inputs.pop("end_positions", end_positions)
+
         outputs = self.mobilebert(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -1294,9 +1306,9 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoiceLoss):
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        labels=None,
         output_attentions=None,
         output_hidden_states=None,
+        labels=None,
         training=False,
     ):
         r"""
@@ -1348,7 +1360,8 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoiceLoss):
             inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
             output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
             output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
-            assert len(inputs) <= 8, "Too many inputs."
+            labels = inputs[8] if len(inputs) > 8 else labels
+            assert len(inputs) <= 9, "Too many inputs."
         elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             attention_mask = inputs.get("attention_mask", attention_mask)
@@ -1358,7 +1371,8 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoiceLoss):
             inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
             output_attentions = inputs.get("output_attentions", output_attentions)
             output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
-            assert len(inputs) <= 8, "Too many inputs."
+            labels = inputs.get("labels", labels)
+            assert len(inputs) <= 9, "Too many inputs."
         else:
             input_ids = inputs
@@ -1426,15 +1440,15 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenClassificationLoss):
     @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        labels=None,
         output_attentions=None,
         output_hidden_states=None,
+        labels=None,
         training=False,
     ):
         r"""
@@ -1471,8 +1485,15 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenClassificationLoss):
         loss, scores = outputs[:2]

         """
+        if isinstance(inputs, (tuple, list)):
+            labels = inputs[8] if len(inputs) > 8 else labels
+            if len(inputs) > 8:
+                inputs = inputs[:8]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            labels = inputs.pop("labels", labels)
+
         outputs = self.mobilebert(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
...
@@ -33,6 +33,7 @@ from .modeling_tf_utils import (
     keras_serializable,
     shape_list,
 )
+from .tokenization_utils_base import BatchEncoding

 logger = logging.getLogger(__name__)
@@ -359,15 +360,15 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceClassificationLoss):
     @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        labels=None,
         output_attentions=None,
         output_hidden_states=None,
+        labels=None,
         training=False,
     ):
         r"""
@@ -400,8 +401,15 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceClassificationLoss):
         loss, logits = outputs[:2]

         """
+        if isinstance(inputs, (tuple, list)):
+            labels = inputs[8] if len(inputs) > 8 else labels
+            if len(inputs) > 8:
+                inputs = inputs[:8]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            labels = inputs.pop("labels", labels)
+
         outputs = self.roberta(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -457,9 +465,9 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss):
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        labels=None,
         output_attentions=None,
         output_hidden_states=None,
+        labels=None,
         training=False,
     ):
         r"""
@@ -509,15 +517,21 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss):
             position_ids = inputs[3] if len(inputs) > 3 else position_ids
             head_mask = inputs[4] if len(inputs) > 4 else head_mask
             inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
-            assert len(inputs) <= 6, "Too many inputs."
-        elif isinstance(inputs, dict):
+            output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
+            output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
+            labels = inputs[8] if len(inputs) > 8 else labels
+            assert len(inputs) <= 9, "Too many inputs."
+        elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             attention_mask = inputs.get("attention_mask", attention_mask)
             token_type_ids = inputs.get("token_type_ids", token_type_ids)
             position_ids = inputs.get("position_ids", position_ids)
             head_mask = inputs.get("head_mask", head_mask)
             inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
-            assert len(inputs) <= 6, "Too many inputs."
+            output_attentions = inputs.get("output_attentions", output_attentions)
+            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
+            labels = inputs.get("labels", labels)
+            assert len(inputs) <= 9, "Too many inputs."
         else:
             input_ids = inputs
@@ -580,15 +594,15 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassificationLoss):
     @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        labels=None,
         output_attentions=None,
         output_hidden_states=None,
+        labels=None,
         training=False,
     ):
         r"""
@@ -625,8 +639,15 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassificationLoss):
         loss, scores = outputs[:2]

         """
+        if isinstance(inputs, (tuple, list)):
+            labels = inputs[8] if len(inputs) > 8 else labels
+            if len(inputs) > 8:
+                inputs = inputs[:8]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            labels = inputs.pop("labels", labels)
+
         outputs = self.roberta(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -668,19 +689,16 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnsweringLoss):
     @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        start_positions=None,
-        end_positions=None,
-        cls_index=None,
-        p_mask=None,
-        is_impossible=None,
         output_attentions=None,
         output_hidden_states=None,
+        start_positions=None,
+        end_positions=None,
         training=False,
     ):
         r"""
@@ -729,8 +747,17 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnsweringLoss):
         answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])

         """
+        if isinstance(inputs, (tuple, list)):
+            start_positions = inputs[8] if len(inputs) > 8 else start_positions
+            end_positions = inputs[9] if len(inputs) > 9 else end_positions
+            if len(inputs) > 8:
+                inputs = inputs[:8]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            start_positions = inputs.pop("start_positions", start_positions)
+            end_positions = inputs.pop("end_positions", end_positions)
+
         outputs = self.roberta(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
...
@@ -759,7 +759,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificationLoss):
     @add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids,
+        inputs=None,
         attention_mask=None,
         langs=None,
         token_type_ids=None,
@@ -768,9 +768,9 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificationLoss):
         cache=None,
         head_mask=None,
         inputs_embeds=None,
-        labels=None,
         output_attentions=None,
         output_hidden_states=None,
+        labels=None,
         training=False,
     ):
         r"""
@@ -809,8 +809,15 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificationLoss):
         loss, logits = outputs[:2]

         """
+        if isinstance(inputs, (tuple, list)):
+            labels = inputs[11] if len(inputs) > 11 else labels
+            if len(inputs) > 11:
+                inputs = inputs[:11]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            labels = inputs.pop("labels", labels)
+
         transformer_outputs = self.transformer(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             langs=langs,
             token_type_ids=token_type_ids,
@@ -1090,7 +1097,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringLoss):
     @add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         langs=None,
         token_type_ids=None,
@@ -1099,13 +1106,10 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringLoss):
         cache=None,
         head_mask=None,
         inputs_embeds=None,
-        start_positions=None,
-        end_positions=None,
-        cls_index=None,
-        p_mask=None,
-        is_impossible=None,
         output_attentions=None,
         output_hidden_states=None,
+        start_positions=None,
+        end_positions=None,
         training=False,
     ):
         r"""
@@ -1151,9 +1155,17 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringLoss):
         answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])

         """
+        if isinstance(inputs, (tuple, list)):
+            start_positions = inputs[11] if len(inputs) > 11 else start_positions
+            end_positions = inputs[12] if len(inputs) > 12 else end_positions
+            if len(inputs) > 11:
+                inputs = inputs[:11]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            start_positions = inputs.pop("start_positions", start_positions)
+            end_positions = inputs.pop("end_positions", end_positions)
+
         transformer_outputs = self.transformer(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             langs=langs,
             token_type_ids=token_type_ids,
...
@@ -988,7 +988,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassificationLoss):
     @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         mems=None,
         perm_mask=None,
@@ -998,9 +998,9 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassificationLoss):
         head_mask=None,
         inputs_embeds=None,
         use_cache=True,
-        labels=None,
         output_attentions=None,
         output_hidden_states=None,
+        labels=None,
         training=False,
     ):
         r"""
@@ -1043,8 +1043,15 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassificationLoss):
         loss, logits = outputs[:2]

         """
+        if isinstance(inputs, (tuple, list)):
+            labels = inputs[12] if len(inputs) > 12 else labels
+            if len(inputs) > 12:
+                inputs = inputs[:12]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            labels = inputs.pop("labels", labels)
+
         transformer_outputs = self.transformer(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             mems=mems,
             perm_mask=perm_mask,
@@ -1100,7 +1107,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
     @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
     def call(
         self,
-        inputs,
+        inputs=None,
         token_type_ids=None,
         input_mask=None,
         attention_mask=None,
@@ -1110,9 +1117,9 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
         head_mask=None,
         inputs_embeds=None,
         use_cache=True,
-        labels=None,
         output_attentions=None,
         output_hidden_states=None,
+        labels=None,
         training=False,
     ):
         r"""
@@ -1168,7 +1175,8 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
             use_cache = inputs[9] if len(inputs) > 9 else use_cache
             output_attentions = inputs[10] if len(inputs) > 10 else output_attentions
             output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states
-            assert len(inputs) <= 12, "Too many inputs."
+            labels = inputs[12] if len(inputs) > 12 else labels
+            assert len(inputs) <= 13, "Too many inputs."
         elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             attention_mask = inputs.get("attention_mask", attention_mask)
@@ -1181,8 +1189,9 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
             inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
             use_cache = inputs.get("use_cache", use_cache)
             output_attentions = inputs.get("output_attentions", output_attentions)
-            output_hidden_states = inputs.get("output_hidden_states", output_attentions)
-            assert len(inputs) <= 12, "Too many inputs."
+            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
+            labels = inputs.get("labels", labels)
+            assert len(inputs) <= 13, "Too many inputs."
         else:
             input_ids = inputs
@@ -1197,6 +1206,11 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
         flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
         flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
         flat_input_mask = tf.reshape(input_mask, (-1, seq_length)) if input_mask is not None else None
+        flat_inputs_embeds = (
+            tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
+            if inputs_embeds is not None
+            else None
+        )

         flat_inputs = [
             flat_input_ids,
@@ -1207,7 +1221,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
             flat_token_type_ids,
             flat_input_mask,
             head_mask,
-            inputs_embeds,
+            flat_inputs_embeds,
             use_cache,
             output_attentions,
             output_hidden_states,
@@ -1245,7 +1259,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificationLoss):
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         mems=None,
         perm_mask=None,
@@ -1255,9 +1269,9 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificationLoss):
         head_mask=None,
         inputs_embeds=None,
         use_cache=True,
-        labels=None,
         output_attentions=None,
         output_hidden_states=None,
+        labels=None,
         training=False,
     ):
         r"""
@@ -1298,8 +1312,15 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificationLoss):
         loss, scores = outputs[:2]

         """
+        if isinstance(inputs, (tuple, list)):
+            labels = inputs[12] if len(inputs) > 12 else labels
+            if len(inputs) > 12:
+                inputs = inputs[:12]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            labels = inputs.pop("labels", labels)
+
         transformer_outputs = self.transformer(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             mems=mems,
             perm_mask=perm_mask,
@@ -1342,7 +1363,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnsweringLoss):
     @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         mems=None,
         perm_mask=None,
@@ -1352,13 +1373,10 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnsweringLoss):
         head_mask=None,
         inputs_embeds=None,
         use_cache=True,
-        start_positions=None,
-        end_positions=None,
-        cls_index=None,
-        p_mask=None,
-        is_impossible=None,
         output_attentions=None,
         output_hidden_states=None,
+        start_positions=None,
+        end_positions=None,
         training=False,
     ):
         r"""
@@ -1410,8 +1428,17 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnsweringLoss):
         answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])

         """
+        if isinstance(inputs, (tuple, list)):
+            start_positions = inputs[12] if len(inputs) > 12 else start_positions
+            end_positions = inputs[13] if len(inputs) > 13 else end_positions
+            if len(inputs) > 12:
+                inputs = inputs[:12]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            start_positions = inputs.pop("start_positions", start_positions)
+            end_positions = inputs.pop("end_positions", end_positions)
+
         transformer_outputs = self.transformer(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             mems=mems,
             perm_mask=perm_mask,
...
@@ -15,6 +15,7 @@
 import copy
+import inspect
 import os
 import random
 import tempfile
@@ -35,6 +36,9 @@ if is_tf_available():
         TFAdaptiveEmbedding,
         TFSharedEmbeddings,
         TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
+        TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
+        TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
+        TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
     )

     if _tf_gpu_memory_limit is not None:
...@@ -71,14 +75,25 @@ class TFModelTesterMixin: ...@@ -71,14 +75,25 @@ class TFModelTesterMixin:
test_resize_embeddings = True test_resize_embeddings = True
is_encoder_decoder = False is_encoder_decoder = False
def _prepare_for_class(self, inputs_dict, model_class): def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values(): if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
return { inputs_dict = {
k: tf.tile(tf.expand_dims(v, 1), (1, self.model_tester.num_choices, 1)) k: tf.tile(tf.expand_dims(v, 1), (1, self.model_tester.num_choices, 1))
if isinstance(v, tf.Tensor) and v.ndim != 0 if isinstance(v, tf.Tensor) and v.ndim != 0
else v else v
for k, v in inputs_dict.items() for k, v in inputs_dict.items()
} }
if return_labels:
if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
inputs_dict["labels"] = tf.ones(self.model_tester.batch_size)
elif model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size)
inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size)
elif model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size)
elif model_class in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values():
inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, self.model_tester.seq_length))
return inputs_dict
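With `return_labels=True`, `_prepare_for_class` now augments the shared inputs with dummy labels shaped per task family. A quick illustration of the shapes involved, using hypothetical sizes rather than the tester's actual defaults:

```python
import tensorflow as tf

batch_size, seq_length = 2, 5  # illustrative sizes only

tf.ones(batch_size)                 # multiple choice: one class index per example
tf.zeros(batch_size)                # sequence classification; also each of the
                                    # start/end position tensors for QA heads
tf.zeros((batch_size, seq_length))  # token classification: one label per token
```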
def test_initialization(self):
...@@ -572,6 +587,51 @@ class TFModelTesterMixin:
generated_ids = output_tokens[:, input_ids.shape[-1] :]
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
def test_loss_computation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
if getattr(model, "compute_loss", None):
# The number of elements in the loss should be the same as the number of elements in the label
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
added_label = prepared_for_class[list(prepared_for_class.keys() - inputs_dict.keys())[0]]
loss_size = tf.size(added_label)
# Test that model correctly compute the loss with kwargs
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
input_ids = prepared_for_class.pop("input_ids")
loss = model(input_ids, **prepared_for_class)[0]
self.assertEqual(loss.shape, [loss_size])
# Test that model correctly compute the loss with a dict
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
loss = model(prepared_for_class)[0]
self.assertEqual(loss.shape, [loss_size])
# Test that model correctly compute the loss with a tuple
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
# Get keys that were added with the _prepare_for_class function
label_keys = prepared_for_class.keys() - inputs_dict.keys()
signature = inspect.getfullargspec(model.call)[0]
# Create a dictionary holding the location of the tensors in the tuple
tuple_index_mapping = {1: "input_ids"}
for label_key in label_keys:
label_key_index = signature.index(label_key)
tuple_index_mapping[label_key_index] = label_key
sorted_tuple_index_mapping = sorted(tuple_index_mapping.items())
# Initialize a list with None, update the values and convert to a tuple
list_input = [None] * sorted_tuple_index_mapping[-1][0]
for index, value in sorted_tuple_index_mapping:
list_input[index - 1] = prepared_for_class[value]
tuple_input = tuple(list_input)
# Send to model
loss = model(tuple_input)[0]
self.assertEqual(loss.shape, [loss_size])
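The tuple branch is the subtle part of this test: it rebuilds a positional argument list from `call`'s signature, shifting every index down by one to drop `self`. A worked sketch of that index arithmetic, using a hypothetical signature modeled on the sequence-classification heads in this PR:

```python
import inspect

# Hypothetical stand-in for a TF*ForSequenceClassification.call signature.
def call(self, inputs=None, attention_mask=None, token_type_ids=None,
         position_ids=None, head_mask=None, inputs_embeds=None,
         output_attentions=None, output_hidden_states=None, labels=None,
         training=False):
    ...

signature = inspect.getfullargspec(call)[0]
# signature[0] is "self", so "inputs" has index 1 and "labels" index 9.
tuple_index_mapping = {1: "input_ids", signature.index("labels"): "labels"}

# The largest index fixes the tuple length; each slot is shifted down by
# one because the tuple handed to the model does not include "self".
list_input = [None] * max(tuple_index_mapping)
for index, value in sorted(tuple_index_mapping.items()):
    list_input[index - 1] = value  # a tensor in the real test
tuple_input = tuple(list_input)

assert tuple_input[0] == "input_ids" and tuple_input[8] == "labels"
```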
def _generate_random_bad_tokens(self, num_bad_tokens, model):
# special tokens cannot be bad tokens
special_tokens = []
...
...@@ -24,11 +24,14 @@ from .utils import require_tf
if is_tf_available():
import tensorflow as tf
from transformers.modeling_tf_distilbert import (
TFDistilBertModel,
TFDistilBertForMaskedLM,
TFDistilBertForQuestionAnswering,
TFDistilBertForSequenceClassification,
TFDistilBertForTokenClassification,
TFDistilBertForMultipleChoice,
)
...@@ -147,6 +150,35 @@ class TFDistilBertModelTester:
}
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])
def create_and_check_distilbert_for_multiple_choice(
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_choices = self.num_choices
model = TFDistilBertForMultipleChoice(config)
multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
inputs = {
"input_ids": multiple_choice_inputs_ids,
"attention_mask": multiple_choice_input_mask,
}
(logits,) = model(inputs)
result = {
"logits": logits.numpy(),
}
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
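The `tf.expand_dims` / `tf.tile` pattern used throughout these multiple-choice checks repeats each `(batch_size, seq_length)` input once per choice, yielding `(batch_size, num_choices, seq_length)`. A minimal shape check with made-up sizes:

```python
import tensorflow as tf

batch_size, seq_length, num_choices = 2, 5, 4  # illustrative sizes
input_ids = tf.zeros((batch_size, seq_length), dtype=tf.int32)

tiled = tf.tile(tf.expand_dims(input_ids, 1), (1, num_choices, 1))
print(tiled.shape)  # (2, 4, 5): the same sequence repeated for each choice
```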
def create_and_check_distilbert_for_token_classification(
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
model = TFDistilBertForTokenClassification(config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask}
(logits,) = model(inputs)
result = {
"logits": logits.numpy(),
}
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels])
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, input_mask, sequence_labels, token_labels, choice_labels) = config_and_inputs
...@@ -163,6 +195,8 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
TFDistilBertForMaskedLM,
TFDistilBertForQuestionAnswering,
TFDistilBertForSequenceClassification,
TFDistilBertForTokenClassification,
TFDistilBertForMultipleChoice,
)
if is_tf_available()
else None
...@@ -194,6 +228,14 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_distilbert_for_sequence_classification(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_distilbert_for_multiple_choice(*config_and_inputs)
def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_distilbert_for_token_classification(*config_and_inputs)
# @slow
# def test_model_from_pretrained(self):
#     for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST)[:1]:
...
...@@ -29,6 +29,7 @@ if is_tf_available():
TFElectraForMaskedLM,
TFElectraForPreTraining,
TFElectraForTokenClassification,
TFElectraForQuestionAnswering,
)
...@@ -137,6 +138,19 @@ class TFElectraModelTester:
}
self.parent.assertListEqual(list(result["prediction_scores"].shape), [self.batch_size, self.seq_length])
def create_and_check_electra_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFElectraForQuestionAnswering(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
start_logits, end_logits = model(inputs)
result = {
"start_logits": start_logits.numpy(),
"end_logits": end_logits.numpy(),
}
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
def create_and_check_electra_for_token_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
...@@ -192,6 +206,10 @@ class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_electra_for_pretraining(*config_and_inputs)
def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_electra_for_question_answering(*config_and_inputs)
def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_electra_for_token_classification(*config_and_inputs)
...
...@@ -32,6 +32,7 @@ if is_tf_available():
TFRobertaForSequenceClassification,
TFRobertaForTokenClassification,
TFRobertaForQuestionAnswering,
TFRobertaForMultipleChoice,
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
)
...@@ -154,6 +155,25 @@ class TFRobertaModelTester:
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
def create_and_check_roberta_for_multiple_choice(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_choices = self.num_choices
model = TFRobertaForMultipleChoice(config=config)
multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1))
inputs = {
"input_ids": multiple_choice_inputs_ids,
"attention_mask": multiple_choice_input_mask,
"token_type_ids": multiple_choice_token_type_ids,
}
(logits,) = model(inputs)
result = {
"logits": logits.numpy(),
}
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
...@@ -207,6 +227,10 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_roberta_for_question_answering(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_roberta_for_multiple_choice(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
for model_name in TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
...
...@@ -33,6 +33,7 @@ if is_tf_available():
TFXLNetForSequenceClassification,
TFXLNetForTokenClassification,
TFXLNetForQuestionAnsweringSimple,
TFXLNetForMultipleChoice,
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
)
...@@ -66,6 +67,7 @@ class TFXLNetModelTester:
self.bos_token_id = 1
self.eos_token_id = 2
self.pad_token_id = 5
self.num_choices = 4
def prepare_config_and_inputs(self):
input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
...@@ -316,6 +318,36 @@ class TFXLNetModelTester:
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
)
def create_and_check_xlnet_for_multiple_choice(
self,
config,
input_ids_1,
input_ids_2,
input_ids_q,
perm_mask,
input_mask,
target_mapping,
segment_ids,
lm_labels,
sequence_labels,
is_impossible_labels,
):
config.num_choices = self.num_choices
model = TFXLNetForMultipleChoice(config=config)
multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids_1, 1), (1, self.num_choices, 1))
multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
multiple_choice_token_type_ids = tf.tile(tf.expand_dims(segment_ids, 1), (1, self.num_choices, 1))
inputs = {
"input_ids": multiple_choice_inputs_ids,
"attention_mask": multiple_choice_input_mask,
"token_type_ids": multiple_choice_token_type_ids,
}
(logits,) = model(inputs)
result = {
"logits": logits.numpy(),
}
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
...@@ -345,6 +377,7 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
TFXLNetForSequenceClassification,
TFXLNetForTokenClassification,
TFXLNetForQuestionAnsweringSimple,
TFXLNetForMultipleChoice,
)
if is_tf_available()
else ()
...@@ -385,6 +418,10 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlnet_qa(*config_and_inputs)
def test_xlnet_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlnet_for_multiple_choice(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
for model_name in TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
...