chenpangpang/transformers · Commit cf10d4cf (unverified)
Authored Jun 24, 2020 by Lysandre Debut; committed by GitHub on Jun 24, 2020
Cleaning TensorFlow models (#5229)

* Cleaning TensorFlow models: update the style of all classes
* Don't average loss
Parent: 609e0c58
Changes: 13 changed files with 483 additions and 126 deletions (+483 −126)
src/transformers/modeling_tf_albert.py       +40  −17
src/transformers/modeling_tf_bert.py         +39  −16
src/transformers/modeling_tf_distilbert.py   +50  −19
src/transformers/modeling_tf_electra.py      +22  −10
src/transformers/modeling_tf_mobilebert.py   +37  −16
src/transformers/modeling_tf_roberta.py      +44  −17
src/transformers/modeling_tf_xlm.py          +22  −10
src/transformers/modeling_tf_xlnet.py        +46  −19
tests/test_modeling_tf_common.py             +62   −2
tests/test_modeling_tf_distilbert.py         +42   −0
tests/test_modeling_tf_electra.py            +18   −0
tests/test_modeling_tf_roberta.py            +24   −0
tests/test_modeling_tf_xlnet.py              +37   −0
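Every hunk below applies the same two changes to each task head: the first argument of call() is renamed from input_ids to inputs, and labels (or start_positions/end_positions) is accepted through the inputs tuple or dict as well as by keyword. In tuple form the extra label entries sit after the slots the base model consumes, so the index differs per architecture (8 for the BERT-style signatures, 6 for DistilBERT, 11 for XLM, 12 for XLNet). Below is a minimal standalone sketch of the unpacking pattern the hunks add; unpack_labels and labels_index are illustrative names, not from the commit:

# Sketch of the label-unpacking pattern added to each head's call().
# `inputs` may be a tensor, a tuple/list, or a dict/BatchEncoding-like mapping.
def unpack_labels(inputs, labels=None, labels_index=8):
    if isinstance(inputs, (tuple, list)):
        labels = inputs[labels_index] if len(inputs) > labels_index else labels
        if len(inputs) > labels_index:
            inputs = inputs[:labels_index]  # strip labels before the base-model call
    elif isinstance(inputs, dict):
        labels = inputs.pop("labels", labels)
    return inputs, labels

features = {"input_ids": [[101, 2023, 102]], "labels": [1]}
features, labels = unpack_labels(features)
print(labels)  # [1]; "labels" has been popped out of the dict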
src/transformers/modeling_tf_albert.py

@@ -897,15 +897,15 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
     @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
+        labels=None,
         output_attentions=None,
         output_hidden_states=None,
-        labels=None,
         training=False,
     ):
         r"""
@@ -944,9 +944,15 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
         loss, logits = outputs[:2]
         """
+        if isinstance(inputs, (tuple, list)):
+            labels = inputs[8] if len(inputs) > 8 else labels
+            if len(inputs) > 8:
+                inputs = inputs[:8]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            labels = inputs.pop("labels", labels)
+
         outputs = self.albert(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -990,15 +996,15 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
     @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
+        labels=None,
         output_attentions=None,
         output_hidden_states=None,
-        labels=None,
         training=False,
     ):
         r"""
@@ -1035,8 +1041,15 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
         loss, scores = outputs[:2]
         """
+        if isinstance(inputs, (tuple, list)):
+            labels = inputs[8] if len(inputs) > 8 else labels
+            if len(inputs) > 8:
+                inputs = inputs[:8]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            labels = inputs.pop("labels", labels)
+
         outputs = self.albert(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -1078,19 +1091,16 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
     @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        cls_index=None,
-        p_mask=None,
-        is_impossible=None,
+        start_positions=None,
+        end_positions=None,
         output_attentions=None,
         output_hidden_states=None,
-        start_positions=None,
-        end_positions=None,
         training=False,
     ):
         r"""
@@ -1139,8 +1149,17 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
         answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
         """
+        if isinstance(inputs, (tuple, list)):
+            start_positions = inputs[8] if len(inputs) > 8 else start_positions
+            end_positions = inputs[9] if len(inputs) > 9 else end_positions
+            if len(inputs) > 8:
+                inputs = inputs[:8]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            start_positions = inputs.pop("start_positions", start_positions)
+            end_positions = inputs.pop("end_positions", start_positions)
+
         outputs = self.albert(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -1202,9 +1221,9 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
+        labels=None,
         output_attentions=None,
         output_hidden_states=None,
-        labels=None,
         training=False,
     ):
         r"""
@@ -1255,8 +1274,10 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
             head_mask = inputs[4] if len(inputs) > 4 else head_mask
             inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
             output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
-            assert len(inputs) <= 7, "Too many inputs."
-        elif isinstance(inputs, dict):
+            output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
+            labels = inputs[8] if len(inputs) > 8 else labels
+            assert len(inputs) <= 9, "Too many inputs."
+        elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             attention_mask = inputs.get("attention_mask", attention_mask)
             token_type_ids = inputs.get("token_type_ids", token_type_ids)
@@ -1264,7 +1285,9 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
             head_mask = inputs.get("head_mask", head_mask)
             inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
             output_attentions = inputs.get("output_attentions", output_attentions)
-            assert len(inputs) <= 7, "Too many inputs."
+            output_hidden_states = inputs.get("output_hidden_states", output_attentions)
+            labels = inputs.get("labels", labels)
+            assert len(inputs) <= 9, "Too many inputs."
         else:
             input_ids = inputs
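With labels moved into the signature and the dict branch above, a head can now compute its loss from labels supplied inside the input dict, which is the shape of data keras passes around. A minimal usage sketch, assuming a transformers install of this era; the checkpoint name and toy tensors are illustrative, not from the commit:

import tensorflow as tf
from transformers import TFAlbertForSequenceClassification

model = TFAlbertForSequenceClassification.from_pretrained("albert-base-v2")
batch = {
    "input_ids": tf.constant([[2, 10975, 15, 51, 1952, 25, 3]]),
    "labels": tf.constant([1]),
}
# "labels" is popped from the dict inside call(), so the loss comes back first.
loss, logits = model(batch)[:2]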
src/transformers/modeling_tf_bert.py

@@ -932,15 +932,15 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
     @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
+        labels=None,
         output_attentions=None,
         output_hidden_states=None,
-        labels=None,
         training=False,
     ):
         r"""
@@ -979,9 +979,15 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
         loss, logits = outputs[:2]
         """
+        if isinstance(inputs, (tuple, list)):
+            labels = inputs[8] if len(inputs) > 8 else labels
+            if len(inputs) > 8:
+                inputs = inputs[:8]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            labels = inputs.pop("labels", labels)
+
         outputs = self.bert(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -1039,9 +1045,9 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
+        labels=None,
         output_attentions=None,
         output_hidden_states=None,
-        labels=None,
         training=False,
     ):
         r"""
@@ -1092,7 +1098,9 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
             head_mask = inputs[4] if len(inputs) > 4 else head_mask
             inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
             output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
-            assert len(inputs) <= 7, "Too many inputs."
+            output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
+            labels = inputs[8] if len(inputs) > 8 else labels
+            assert len(inputs) <= 9, "Too many inputs."
         elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             attention_mask = inputs.get("attention_mask", attention_mask)
@@ -1101,7 +1109,9 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
             head_mask = inputs.get("head_mask", head_mask)
             inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
             output_attentions = inputs.get("output_attentions", output_attentions)
-            assert len(inputs) <= 7, "Too many inputs."
+            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
+            labels = inputs.get("labels", labels)
+            assert len(inputs) <= 9, "Too many inputs."
         else:
             input_ids = inputs
@@ -1169,15 +1179,15 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
     @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
+        labels=None,
         output_attentions=None,
         output_hidden_states=None,
-        labels=None,
         training=False,
     ):
         r"""
@@ -1214,8 +1224,15 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
         loss, scores = outputs[:2]
         """
+        if isinstance(inputs, (tuple, list)):
+            labels = inputs[8] if len(inputs) > 8 else labels
+            if len(inputs) > 8:
+                inputs = inputs[:8]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            labels = inputs.pop("labels", labels)
+
         outputs = self.bert(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -1258,19 +1275,16 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
     @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        cls_index=None,
-        p_mask=None,
-        is_impossible=None,
+        start_positions=None,
+        end_positions=None,
         output_attentions=None,
         output_hidden_states=None,
-        start_positions=None,
-        end_positions=None,
         training=False,
     ):
         r"""
@@ -1317,8 +1331,17 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
         assert answer == "a nice puppet"
         """
+        if isinstance(inputs, (tuple, list)):
+            start_positions = inputs[8] if len(inputs) > 8 else start_positions
+            end_positions = inputs[9] if len(inputs) > 9 else end_positions
+            if len(inputs) > 8:
+                inputs = inputs[:8]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            start_positions = inputs.pop("start_positions", start_positions)
+            end_positions = inputs.pop("end_positions", start_positions)
+
         outputs = self.bert(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
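The tuple branch mirrors the base model's positional layout: indices 0 through 7 are consumed by the base layer, so start_positions and end_positions ride at 8 and 9 and are sliced off before the inner call. A pure-Python sketch of those slicing semantics, with strings standing in for tensors:

# Toy demonstration of the tuple path added to TFBertForQuestionAnswering.call():
# everything past index 7 is stripped before the tuple reaches the base model.
inputs = ("input_ids", "attention_mask", "token_type_ids", "position_ids",
          "head_mask", "inputs_embeds", "output_attentions", "output_hidden_states",
          "start_positions", "end_positions")
start_positions = inputs[8] if len(inputs) > 8 else None
end_positions = inputs[9] if len(inputs) > 9 else None
if len(inputs) > 8:
    inputs = inputs[:8]
assert (start_positions, end_positions) == ("start_positions", "end_positions")
assert len(inputs) == 8  # the label entries no longer reach the base model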
src/transformers/modeling_tf_distilbert.py

@@ -715,13 +715,13 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
     @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         head_mask=None,
         inputs_embeds=None,
+        labels=None,
         output_attentions=None,
         output_hidden_states=None,
-        labels=None,
         training=False,
     ):
         r"""
@@ -760,8 +760,15 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
         loss, logits = outputs[:2]
         """
+        if isinstance(inputs, (tuple, list)):
+            labels = inputs[6] if len(inputs) > 6 else labels
+            if len(inputs) > 6:
+                inputs = inputs[:6]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            labels = inputs.pop("labels", labels)
+
         distilbert_output = self.distilbert(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
@@ -804,13 +811,13 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
     @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         head_mask=None,
         inputs_embeds=None,
+        labels=None,
         output_attentions=None,
         output_hidden_states=None,
-        labels=None,
         training=False,
     ):
         r"""
@@ -847,8 +854,15 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
         loss, scores = outputs[:2]
         """
+        if isinstance(inputs, (tuple, list)):
+            labels = inputs[6] if len(inputs) > 6 else labels
+            if len(inputs) > 6:
+                inputs = inputs[:6]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            labels = inputs.pop("labels", labels)
+
         outputs = self.distilbert(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
@@ -862,7 +876,7 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
         sequence_output = self.dropout(sequence_output, training=training)
         logits = self.classifier(sequence_output)
-        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
+        outputs = (logits,) + outputs[1:]  # add hidden states and attention if they are here
         if labels is not None:
             loss = self.compute_loss(labels, logits)
@@ -881,7 +895,7 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
         super().__init__(config, *inputs, **kwargs)
         self.distilbert = TFDistilBertMainLayer(config, name="distilbert")
-        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
+        self.dropout = tf.keras.layers.Dropout(config.seq_classif_dropout)
         self.pre_classifier = tf.keras.layers.Dense(
             config.dim,
             kernel_initializer=get_initializer(config.initializer_range),
@@ -908,9 +922,9 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
         attention_mask=None,
         head_mask=None,
         inputs_embeds=None,
+        labels=None,
         output_attentions=None,
         output_hidden_states=None,
-        labels=None,
         training=False,
     ):
         r"""
@@ -958,13 +972,19 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
             attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
             head_mask = inputs[2] if len(inputs) > 2 else head_mask
             inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
-            assert len(inputs) <= 4, "Too many inputs."
+            output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
+            output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
+            labels = inputs[6] if len(inputs) > 6 else labels
+            assert len(inputs) <= 7, "Too many inputs."
         elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             attention_mask = inputs.get("attention_mask", attention_mask)
             head_mask = inputs.get("head_mask", head_mask)
             inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
-            assert len(inputs) <= 4, "Too many inputs."
+            output_attentions = inputs.get("output_attentions", output_attentions)
+            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
+            labels = inputs.get("labels", labels)
+            assert len(inputs) <= 7, "Too many inputs."
         else:
             input_ids = inputs
@@ -977,12 +997,17 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
         flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
         flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
+        flat_inputs_embeds = (
+            tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
+            if inputs_embeds is not None
+            else None
+        )
         flat_inputs = [
             flat_input_ids,
             flat_attention_mask,
             head_mask,
-            inputs_embeds,
+            flat_inputs_embeds,
             output_attentions,
             output_hidden_states,
         ]
@@ -1023,17 +1048,14 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
     @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         head_mask=None,
        inputs_embeds=None,
-        cls_index=None,
-        p_mask=None,
-        is_impossible=None,
+        start_positions=None,
+        end_positions=None,
         output_attentions=None,
         output_hidden_states=None,
-        start_positions=None,
-        end_positions=None,
         training=False,
     ):
         r"""
@@ -1079,8 +1101,17 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
         answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
         """
+        if isinstance(inputs, (tuple, list)):
+            start_positions = inputs[6] if len(inputs) > 6 else start_positions
+            end_positions = inputs[7] if len(inputs) > 7 else end_positions
+            if len(inputs) > 6:
+                inputs = inputs[:6]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            start_positions = inputs.pop("start_positions", start_positions)
+            end_positions = inputs.pop("end_positions", start_positions)
+
         distilbert_output = self.distilbert(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
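The flat_inputs_embeds hunk fixes a real shape bug in the multiple-choice head: inputs_embeds arrives as (batch, num_choices, seq_length, dim) but was forwarded unflattened, while input_ids and attention_mask were reshaped. A small sketch of what the added reshape computes, with illustrative shapes:

import tensorflow as tf

batch, num_choices, seq_length, dim = 2, 4, 8, 16
inputs_embeds = tf.random.uniform((batch, num_choices, seq_length, dim))
# Collapse the choice dimension so embeddings line up with the flattened ids.
flat_inputs_embeds = tf.reshape(inputs_embeds, (-1, seq_length, inputs_embeds.shape[3]))
print(flat_inputs_embeds.shape)  # (8, 8, 16), i.e. (batch * num_choices, seq_length, dim)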
src/transformers/modeling_tf_electra.py

@@ -613,15 +613,15 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
     @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
+        labels=None,
         output_attentions=None,
         output_hidden_states=None,
-        labels=None,
         training=False,
     ):
         r"""
@@ -658,9 +658,15 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
         loss, scores = outputs[:2]
         """
+        if isinstance(inputs, (tuple, list)):
+            labels = inputs[8] if len(inputs) > 8 else labels
+            if len(inputs) > 8:
+                inputs = inputs[:8]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            labels = inputs.pop("labels", labels)
+
         discriminator_hidden_states = self.electra(
-            input_ids,
+            inputs,
             attention_mask,
             token_type_ids,
             position_ids,
@@ -701,19 +707,16 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
     @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        cls_index=None,
-        p_mask=None,
-        is_impossible=None,
+        start_positions=None,
+        end_positions=None,
         output_attentions=None,
         output_hidden_states=None,
-        start_positions=None,
-        end_positions=None,
         training=False,
     ):
         r"""
@@ -760,8 +763,17 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
         answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
         """
+        if isinstance(inputs, (tuple, list)):
+            start_positions = inputs[8] if len(inputs) > 8 else start_positions
+            end_positions = inputs[9] if len(inputs) > 9 else end_positions
+            if len(inputs) > 8:
+                inputs = inputs[:8]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            start_positions = inputs.pop("start_positions", start_positions)
+            end_positions = inputs.pop("end_positions", start_positions)
+
         discriminator_hidden_states = self.electra(
-            input_ids,
+            inputs,
             attention_mask,
             token_type_ids,
             position_ids,
src/transformers/modeling_tf_mobilebert.py

@@ -1080,15 +1080,15 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
     @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
+        labels=None,
         output_attentions=None,
         output_hidden_states=None,
-        labels=None,
         training=False,
     ):
         r"""
@@ -1127,9 +1127,15 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
         loss, logits = outputs[:2]
         """
+        if isinstance(inputs, (tuple, list)):
+            labels = inputs[8] if len(inputs) > 8 else labels
+            if len(inputs) > 8:
+                inputs = inputs[:8]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            labels = inputs.pop("labels", labels)
+
         outputs = self.mobilebert(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -1172,19 +1178,16 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
     @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        cls_index=None,
-        p_mask=None,
-        is_impossible=None,
+        start_positions=None,
+        end_positions=None,
         output_attentions=None,
         output_hidden_states=None,
-        start_positions=None,
-        end_positions=None,
         training=False,
     ):
         r"""
@@ -1231,8 +1234,17 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
         assert answer == "a nice puppet"
         """
+        if isinstance(inputs, (tuple, list)):
+            start_positions = inputs[8] if len(inputs) > 8 else start_positions
+            end_positions = inputs[9] if len(inputs) > 9 else end_positions
+            if len(inputs) > 8:
+                inputs = inputs[:8]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            start_positions = inputs.pop("start_positions", start_positions)
+            end_positions = inputs.pop("end_positions", start_positions)
+
         outputs = self.mobilebert(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -1294,9 +1306,9 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
+        labels=None,
         output_attentions=None,
         output_hidden_states=None,
-        labels=None,
         training=False,
     ):
         r"""
@@ -1348,7 +1360,8 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
             inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
             output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
             output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
-            assert len(inputs) <= 8, "Too many inputs."
+            labels = inputs[8] if len(inputs) > 8 else labels
+            assert len(inputs) <= 9, "Too many inputs."
         elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             attention_mask = inputs.get("attention_mask", attention_mask)
@@ -1358,7 +1371,8 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
             inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
             output_attentions = inputs.get("output_attentions", output_attentions)
             output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
-            assert len(inputs) <= 8, "Too many inputs."
+            labels = inputs.get("labels", labels)
+            assert len(inputs) <= 9, "Too many inputs."
         else:
             input_ids = inputs
@@ -1426,15 +1440,15 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
     @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
+        labels=None,
         output_attentions=None,
         output_hidden_states=None,
-        labels=None,
         training=False,
     ):
         r"""
@@ -1471,8 +1485,15 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
         loss, scores = outputs[:2]
         """
+        if isinstance(inputs, (tuple, list)):
+            labels = inputs[8] if len(inputs) > 8 else labels
+            if len(inputs) > 8:
+                inputs = inputs[:8]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            labels = inputs.pop("labels", labels)
+
         outputs = self.mobilebert(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
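The span-decoding one-liner repeated in the question-answering docstrings above runs as-is in eager mode; here it is with toy scores standing in for real model outputs:

import tensorflow as tf

all_tokens = ["[CLS]", "a", "nice", "puppet", "[SEP]"]
start_scores = tf.constant([[0.1, 0.2, 3.0, 0.3, 0.1]])  # argmax -> index 2
end_scores = tf.constant([[0.1, 0.1, 0.2, 4.0, 0.2]])    # argmax -> index 3
answer = " ".join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0] + 1])
print(answer)  # nice puppet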
src/transformers/modeling_tf_roberta.py

@@ -33,6 +33,7 @@ from .modeling_tf_utils import (
     keras_serializable,
     shape_list,
 )
+from .tokenization_utils_base import BatchEncoding
 
 logger = logging.getLogger(__name__)
@@ -359,15 +360,15 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
     @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
+        labels=None,
         output_attentions=None,
         output_hidden_states=None,
-        labels=None,
         training=False,
     ):
         r"""
@@ -400,8 +401,15 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
         loss, logits = outputs[:2]
         """
+        if isinstance(inputs, (tuple, list)):
+            labels = inputs[8] if len(inputs) > 8 else labels
+            if len(inputs) > 8:
+                inputs = inputs[:8]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            labels = inputs.pop("labels", labels)
+
         outputs = self.roberta(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -457,9 +465,9 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
+        labels=None,
         output_attentions=None,
         output_hidden_states=None,
-        labels=None,
         training=False,
     ):
         r"""
@@ -509,15 +517,21 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
             position_ids = inputs[3] if len(inputs) > 3 else position_ids
             head_mask = inputs[4] if len(inputs) > 4 else head_mask
             inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
-            assert len(inputs) <= 6, "Too many inputs."
-        elif isinstance(inputs, dict):
+            output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
+            output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
+            labels = inputs[8] if len(inputs) > 8 else labels
+            assert len(inputs) <= 9, "Too many inputs."
+        elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             attention_mask = inputs.get("attention_mask", attention_mask)
             token_type_ids = inputs.get("token_type_ids", token_type_ids)
             position_ids = inputs.get("position_ids", position_ids)
             head_mask = inputs.get("head_mask", head_mask)
             inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
-            assert len(inputs) <= 6, "Too many inputs."
+            output_attentions = inputs.get("output_attentions", output_attentions)
+            output_hidden_states = inputs.get("output_hidden_states", output_attentions)
+            labels = inputs.get("labels", labels)
+            assert len(inputs) <= 9, "Too many inputs."
         else:
             input_ids = inputs
@@ -580,15 +594,15 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
     @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
+        labels=None,
         output_attentions=None,
         output_hidden_states=None,
-        labels=None,
         training=False,
     ):
         r"""
@@ -625,8 +639,15 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
         loss, scores = outputs[:2]
         """
+        if isinstance(inputs, (tuple, list)):
+            labels = inputs[8] if len(inputs) > 8 else labels
+            if len(inputs) > 8:
+                inputs = inputs[:8]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            labels = inputs.pop("labels", labels)
+
         outputs = self.roberta(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -668,19 +689,16 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
     @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        cls_index=None,
-        p_mask=None,
-        is_impossible=None,
+        start_positions=None,
+        end_positions=None,
         output_attentions=None,
         output_hidden_states=None,
-        start_positions=None,
-        end_positions=None,
         training=False,
     ):
         r"""
@@ -729,8 +747,17 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
         answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
         """
+        if isinstance(inputs, (tuple, list)):
+            start_positions = inputs[8] if len(inputs) > 8 else start_positions
+            end_positions = inputs[9] if len(inputs) > 9 else end_positions
+            if len(inputs) > 8:
+                inputs = inputs[:8]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            start_positions = inputs.pop("start_positions", start_positions)
+            end_positions = inputs.pop("end_positions", start_positions)
+
         outputs = self.roberta(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
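The new BatchEncoding import in modeling_tf_roberta.py is what lets these isinstance checks accept a tokenizer's output directly. A usage sketch, assuming the mid-2020 API where the tokenizer is callable and returns a BatchEncoding; checkpoint names are illustrative:

import tensorflow as tf
from transformers import RobertaTokenizer, TFRobertaForSequenceClassification

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
model = TFRobertaForSequenceClassification.from_pretrained("roberta-base")

enc = tokenizer("A simple example", return_tensors="tf")  # a BatchEncoding
enc["labels"] = tf.constant([0])
# The dict/BatchEncoding branch pops "labels" and forwards the rest unchanged.
loss, logits = model(enc)[:2]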
src/transformers/modeling_tf_xlm.py

@@ -759,7 +759,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
     @add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids,
+        inputs=None,
         attention_mask=None,
         langs=None,
         token_type_ids=None,
@@ -768,9 +768,9 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
         cache=None,
         head_mask=None,
         inputs_embeds=None,
+        labels=None,
         output_attentions=None,
         output_hidden_states=None,
-        labels=None,
         training=False,
     ):
         r"""
@@ -809,8 +809,15 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
         loss, logits = outputs[:2]
         """
+        if isinstance(inputs, (tuple, list)):
+            labels = inputs[11] if len(inputs) > 11 else labels
+            if len(inputs) > 11:
+                inputs = inputs[:11]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            labels = inputs.pop("labels", labels)
+
         transformer_outputs = self.transformer(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             langs=langs,
             token_type_ids=token_type_ids,
@@ -1090,7 +1097,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
     @add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         langs=None,
         token_type_ids=None,
@@ -1099,13 +1106,10 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
         cache=None,
         head_mask=None,
         inputs_embeds=None,
-        cls_index=None,
-        p_mask=None,
-        is_impossible=None,
+        start_positions=None,
+        end_positions=None,
         output_attentions=None,
         output_hidden_states=None,
-        start_positions=None,
-        end_positions=None,
         training=False,
     ):
         r"""
@@ -1151,9 +1155,17 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
         answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
         """
+        if isinstance(inputs, (tuple, list)):
+            start_positions = inputs[11] if len(inputs) > 11 else start_positions
+            end_positions = inputs[12] if len(inputs) > 12 else end_positions
+            if len(inputs) > 11:
+                inputs = inputs[:11]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            start_positions = inputs.pop("start_positions", start_positions)
+            end_positions = inputs.pop("end_positions", start_positions)
+
         transformer_outputs = self.transformer(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             langs=langs,
             token_type_ids=token_type_ids,
src/transformers/modeling_tf_xlnet.py
View file @
cf10d4cf
...
@@ -988,7 +988,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
...
@@ -988,7 +988,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
@
add_start_docstrings_to_callable
(
XLNET_INPUTS_DOCSTRING
)
@
add_start_docstrings_to_callable
(
XLNET_INPUTS_DOCSTRING
)
def
call
(
def
call
(
self
,
self
,
input
_id
s
=
None
,
inputs
=
None
,
attention_mask
=
None
,
attention_mask
=
None
,
mems
=
None
,
mems
=
None
,
perm_mask
=
None
,
perm_mask
=
None
,
...
@@ -998,9 +998,9 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
...
@@ -998,9 +998,9 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
head_mask
=
None
,
head_mask
=
None
,
inputs_embeds
=
None
,
inputs_embeds
=
None
,
use_cache
=
True
,
use_cache
=
True
,
labels
=
None
,
output_attentions
=
None
,
output_attentions
=
None
,
output_hidden_states
=
None
,
output_hidden_states
=
None
,
labels
=
None
,
training
=
False
,
training
=
False
,
):
):
r
"""
r
"""
...
@@ -1043,8 +1043,15 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
...
@@ -1043,8 +1043,15 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
loss, logits = outputs[:2]
loss, logits = outputs[:2]
"""
"""
if
isinstance
(
inputs
,
(
tuple
,
list
)):
labels
=
inputs
[
12
]
if
len
(
inputs
)
>
12
else
labels
if
len
(
inputs
)
>
12
:
inputs
=
inputs
[:
12
]
elif
isinstance
(
inputs
,
(
dict
,
BatchEncoding
)):
labels
=
inputs
.
pop
(
"labels"
,
labels
)
transformer_outputs
=
self
.
transformer
(
transformer_outputs
=
self
.
transformer
(
input
_id
s
,
inputs
,
attention_mask
=
attention_mask
,
attention_mask
=
attention_mask
,
mems
=
mems
,
mems
=
mems
,
perm_mask
=
perm_mask
,
perm_mask
=
perm_mask
,
...
@@ -1100,7 +1107,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
...
@@ -1100,7 +1107,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
@
add_start_docstrings_to_callable
(
XLNET_INPUTS_DOCSTRING
)
@
add_start_docstrings_to_callable
(
XLNET_INPUTS_DOCSTRING
)
def
call
(
def
call
(
self
,
self
,
inputs
,
inputs
=
None
,
token_type_ids
=
None
,
token_type_ids
=
None
,
input_mask
=
None
,
input_mask
=
None
,
attention_mask
=
None
,
attention_mask
=
None
,
...
@@ -1110,9 +1117,9 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
...
@@ -1110,9 +1117,9 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
head_mask
=
None
,
head_mask
=
None
,
inputs_embeds
=
None
,
inputs_embeds
=
None
,
use_cache
=
True
,
use_cache
=
True
,
labels
=
None
,
output_attentions
=
None
,
output_attentions
=
None
,
output_hidden_states
=
None
,
output_hidden_states
=
None
,
labels
=
None
,
training
=
False
,
training
=
False
,
):
):
r
"""
r
"""
...
@@ -1168,7 +1175,8 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
...
@@ -1168,7 +1175,8 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
use_cache
=
inputs
[
9
]
if
len
(
inputs
)
>
9
else
use_cache
use_cache
=
inputs
[
9
]
if
len
(
inputs
)
>
9
else
use_cache
output_attentions
=
inputs
[
10
]
if
len
(
inputs
)
>
10
else
output_attentions
output_attentions
=
inputs
[
10
]
if
len
(
inputs
)
>
10
else
output_attentions
output_hidden_states
=
inputs
[
11
]
if
len
(
inputs
)
>
11
else
output_hidden_states
output_hidden_states
=
inputs
[
11
]
if
len
(
inputs
)
>
11
else
output_hidden_states
assert
len
(
inputs
)
<=
12
,
"Too many inputs."
labels
=
inputs
[
12
]
if
len
(
inputs
)
>
12
else
labels
assert
len
(
inputs
)
<=
13
,
"Too many inputs."
elif
isinstance
(
inputs
,
(
dict
,
BatchEncoding
)):
elif
isinstance
(
inputs
,
(
dict
,
BatchEncoding
)):
input_ids
=
inputs
.
get
(
"input_ids"
)
input_ids
=
inputs
.
get
(
"input_ids"
)
attention_mask
=
inputs
.
get
(
"attention_mask"
,
attention_mask
)
attention_mask
=
inputs
.
get
(
"attention_mask"
,
attention_mask
)
...
@@ -1181,8 +1189,9 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
...
@@ -1181,8 +1189,9 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
             inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
             use_cache = inputs.get("use_cache", use_cache)
             output_attentions = inputs.get("output_attentions", output_attentions)
-            output_hidden_states = inputs.get("output_hidden_states", output_attentions)
-            assert len(inputs) <= 12, "Too many inputs."
+            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
+            labels = inputs.get("labels", labels)
+            assert len(inputs) <= 13, "Too many inputs."
         else:
             input_ids = inputs
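Besides accepting `labels`, this hunk decouples the two output flags: the old dict path fell back to the `output_attentions` value when `output_hidden_states` was absent. A toy illustration with plain dicts (no model needed):

    inputs = {"output_attentions": True}  # caller asked for attentions only
    output_attentions = output_hidden_states = False

    output_attentions = inputs.get("output_attentions", output_attentions)   # True
    old_hidden = inputs.get("output_hidden_states", output_attentions)       # True  (bug: inherits the attentions flag)
    new_hidden = inputs.get("output_hidden_states", output_hidden_states)    # False (fix: keeps its own default)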
...
@@ -1197,6 +1206,11 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
         flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
         flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
         flat_input_mask = tf.reshape(input_mask, (-1, seq_length)) if input_mask is not None else None
+        flat_inputs_embeds = (
+            tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
+            if inputs_embeds is not None
+            else None
+        )
         flat_inputs = [
             flat_input_ids,
...
@@ -1207,7 +1221,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
             flat_token_type_ids,
             flat_input_mask,
             head_mask,
-            inputs_embeds,
+            flat_inputs_embeds,
             use_cache,
             output_attentions,
             output_hidden_states,
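The multiple-choice head flattens the choice axis before the encoder, and the new `flat_inputs_embeds` extends that trick to embedding inputs, which were previously passed through unflattened. A self-contained sketch of the shape arithmetic with toy sizes:

    import tensorflow as tf

    batch_size, num_choices, seq_length, hidden = 2, 4, 8, 16
    inputs_embeds = tf.random.uniform((batch_size, num_choices, seq_length, hidden))

    # collapse the choice axis so the encoder sees (batch * num_choices, seq_length, hidden)
    flat = tf.reshape(inputs_embeds, (-1, seq_length, inputs_embeds.shape[3]))
    assert flat.shape == (batch_size * num_choices, seq_length, hidden)

    # after the encoder, per-choice scores are reshaped back to (batch, num_choices)
    logits = tf.random.uniform((batch_size * num_choices, 1))
    reshaped_logits = tf.reshape(logits, (-1, num_choices))
    assert reshaped_logits.shape == (batch_size, num_choices)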
...
@@ -1245,7 +1259,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         mems=None,
         perm_mask=None,
...
@@ -1255,9 +1269,9 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
         head_mask=None,
         inputs_embeds=None,
         use_cache=True,
-        labels=None,
         output_attentions=None,
         output_hidden_states=None,
+        labels=None,
         training=False,
     ):
         r"""
...
@@ -1298,8 +1312,15 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
            loss, scores = outputs[:2]
        """
+        if isinstance(inputs, (tuple, list)):
+            labels = inputs[12] if len(inputs) > 12 else labels
+            if len(inputs) > 12:
+                inputs = inputs[:12]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            labels = inputs.pop("labels", labels)
         transformer_outputs = self.transformer(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             mems=mems,
             perm_mask=perm_mask,
...
@@ -1342,7 +1363,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
     @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
     def call(
         self,
-        input_ids=None,
+        inputs=None,
         attention_mask=None,
         mems=None,
         perm_mask=None,
...
@@ -1352,13 +1373,10 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
         head_mask=None,
         inputs_embeds=None,
         use_cache=True,
-        start_positions=None,
-        end_positions=None,
-        cls_index=None,
-        p_mask=None,
-        is_impossible=None,
         output_attentions=None,
         output_hidden_states=None,
+        start_positions=None,
+        end_positions=None,
         training=False,
     ):
         r"""
...
@@ -1410,8 +1428,17 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
            answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
        """
+        if isinstance(inputs, (tuple, list)):
+            start_positions = inputs[12] if len(inputs) > 12 else start_positions
+            end_positions = inputs[13] if len(inputs) > 13 else end_positions
+            if len(inputs) > 12:
+                inputs = inputs[:12]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            start_positions = inputs.pop("start_positions", start_positions)
+            end_positions = inputs.pop("end_positions", end_positions)
         transformer_outputs = self.transformer(
-            input_ids,
+            inputs,
             attention_mask=attention_mask,
             mems=mems,
             perm_mask=perm_mask,
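For the simple QA head the same convention applies to the two position tensors instead of a single `labels` key. A minimal sketch (the checkpoint and position indices are illustrative):

    import tensorflow as tf
    from transformers import TFXLNetForQuestionAnsweringSimple, XLNetTokenizer

    tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
    model = TFXLNetForQuestionAnsweringSimple.from_pretrained("xlnet-base-cased")

    inputs = tokenizer.encode_plus("Who wrote it?", "It was written by Jane.", return_tensors="tf")
    inputs["start_positions"] = tf.constant([5])  # popped by call()
    inputs["end_positions"] = tf.constant([6])

    # with positions supplied, the loss is prepended to the usual outputs
    loss, start_logits, end_logits = model(inputs)[:3]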
...
tests/test_modeling_tf_common.py
View file @ cf10d4cf
...
@@ -15,6 +15,7 @@
 import copy
+import inspect
 import os
 import random
 import tempfile
...
@@ -35,6 +36,9 @@ if is_tf_available():
         TFAdaptiveEmbedding,
         TFSharedEmbeddings,
         TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
+        TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
+        TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
+        TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
     )
     if _tf_gpu_memory_limit is not None:
...
@@ -71,14 +75,25 @@ class TFModelTesterMixin:
     test_resize_embeddings = True
     is_encoder_decoder = False

-    def _prepare_for_class(self, inputs_dict, model_class):
+    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
-            return {
+            inputs_dict = {
                 k: tf.tile(tf.expand_dims(v, 1), (1, self.model_tester.num_choices, 1))
                 if isinstance(v, tf.Tensor) and v.ndim != 0
                 else v
                 for k, v in inputs_dict.items()
             }
+        if return_labels:
+            if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
+                inputs_dict["labels"] = tf.ones(self.model_tester.batch_size)
+            elif model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
+                inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size)
+                inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size)
+            elif model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
+                inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size)
+            elif model_class in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values():
+                inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, self.model_tester.seq_length))
         return inputs_dict

     def test_initialization(self):
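The multiple-choice branch of `_prepare_for_class` tiles every non-scalar tensor once per choice. The comprehension is easy to check in isolation; a toy sketch with illustrative shapes:

    import tensorflow as tf

    num_choices = 4
    inputs_dict = {"input_ids": tf.constant([[1, 2, 3], [4, 5, 6]])}  # (batch=2, seq_len=3)

    expanded = {
        k: tf.tile(tf.expand_dims(v, 1), (1, num_choices, 1))
        if isinstance(v, tf.Tensor) and v.ndim != 0
        else v
        for k, v in inputs_dict.items()
    }
    print(expanded["input_ids"].shape)  # (2, 4, 3): every example repeated once per choice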
...
@@ -572,6 +587,51 @@ class TFModelTesterMixin:
         generated_ids = output_tokens[:, input_ids.shape[-1] :]
         self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))

+    def test_loss_computation(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            if getattr(model, "compute_loss", None):
+                # The number of elements in the loss should be the same as the number of elements in the label
+                prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
+                added_label = prepared_for_class[list(prepared_for_class.keys() - inputs_dict.keys())[0]]
+                loss_size = tf.size(added_label)
+
+                # Test that the model correctly computes the loss with kwargs
+                prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
+                input_ids = prepared_for_class.pop("input_ids")
+                loss = model(input_ids, **prepared_for_class)[0]
+                self.assertEqual(loss.shape, [loss_size])
+
+                # Test that the model correctly computes the loss with a dict
+                prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
+                loss = model(prepared_for_class)[0]
+                self.assertEqual(loss.shape, [loss_size])
+
+                # Test that the model correctly computes the loss with a tuple
+                prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
+
+                # Get keys that were added with the _prepare_for_class function
+                label_keys = prepared_for_class.keys() - inputs_dict.keys()
+                signature = inspect.getfullargspec(model.call)[0]
+
+                # Create a dictionary holding the location of the tensors in the tuple
+                tuple_index_mapping = {1: "input_ids"}
+                for label_key in label_keys:
+                    label_key_index = signature.index(label_key)
+                    tuple_index_mapping[label_key_index] = label_key
+                sorted_tuple_index_mapping = sorted(tuple_index_mapping.items())
+
+                # Initialize a list with None, update the values and convert to a tuple
+                list_input = [None] * sorted_tuple_index_mapping[-1][0]
+                for index, value in sorted_tuple_index_mapping:
+                    list_input[index - 1] = prepared_for_class[value]
+                tuple_input = tuple(list_input)
+
+                # Send to model
+                loss = model(tuple_input)[0]
+                self.assertEqual(loss.shape, [loss_size])
+
     def _generate_random_bad_tokens(self, num_bad_tokens, model):
         # special tokens cannot be bad tokens
         special_tokens = []
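Note that the test asserts `loss.shape == [loss_size]`, i.e. one loss value per label element rather than a single averaged scalar. The tuple branch maps label keywords back to positional slots via the `call()` signature; a standalone sketch of that indexing, using a hypothetical stand-in signature:

    import inspect

    def call(self, inputs=None, attention_mask=None, labels=None):  # hypothetical stand-in for model.call
        pass

    signature = inspect.getfullargspec(call)[0]        # ['self', 'inputs', 'attention_mask', 'labels']
    tuple_index_mapping = {1: "input_ids"}             # input ids occupy argument slot 1
    tuple_index_mapping[signature.index("labels")] = "labels"  # -> {1: 'input_ids', 3: 'labels'}

    # build the positional tuple, shifting by one to drop the implicit `self` slot
    prepared = {"input_ids": "IDS", "labels": "LABELS"}
    sorted_mapping = sorted(tuple_index_mapping.items())
    list_input = [None] * sorted_mapping[-1][0]        # three slots, for args 1..3
    for index, value in sorted_mapping:
        list_input[index - 1] = prepared[value]
    tuple_input = tuple(list_input)                    # ('IDS', None, 'LABELS')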
...
tests/test_modeling_tf_distilbert.py
View file @ cf10d4cf
...
@@ -24,11 +24,14 @@ from .utils import require_tf
 if is_tf_available():
     import tensorflow as tf
     from transformers.modeling_tf_distilbert import (
         TFDistilBertModel,
         TFDistilBertForMaskedLM,
         TFDistilBertForQuestionAnswering,
         TFDistilBertForSequenceClassification,
+        TFDistilBertForTokenClassification,
+        TFDistilBertForMultipleChoice,
     )
...
@@ -147,6 +150,35 @@ class TFDistilBertModelTester:
         }
         self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])

+    def create_and_check_distilbert_for_multiple_choice(
+        self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
+    ):
+        config.num_choices = self.num_choices
+        model = TFDistilBertForMultipleChoice(config)
+        multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
+        multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
+        inputs = {
+            "input_ids": multiple_choice_inputs_ids,
+            "attention_mask": multiple_choice_input_mask,
+        }
+        (logits,) = model(inputs)
+        result = {
+            "logits": logits.numpy(),
+        }
+        self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
+
+    def create_and_check_distilbert_for_token_classification(
+        self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
+    ):
+        config.num_labels = self.num_labels
+        model = TFDistilBertForTokenClassification(config)
+        inputs = {"input_ids": input_ids, "attention_mask": input_mask}
+        (logits,) = model(inputs)
+        result = {
+            "logits": logits.numpy(),
+        }
+        self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels])
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (config, input_ids, input_mask, sequence_labels, token_labels, choice_labels) = config_and_inputs
...
@@ -163,6 +195,8 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
             TFDistilBertForMaskedLM,
             TFDistilBertForQuestionAnswering,
             TFDistilBertForSequenceClassification,
+            TFDistilBertForTokenClassification,
+            TFDistilBertForMultipleChoice,
         )
         if is_tf_available()
         else None
...
@@ -194,6 +228,14 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_distilbert_for_sequence_classification(*config_and_inputs)

+    def test_for_multiple_choice(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_distilbert_for_multiple_choice(*config_and_inputs)
+
+    def test_for_token_classification(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_distilbert_for_token_classification(*config_and_inputs)
+
     # @slow
     # def test_model_from_pretrained(self):
     #     for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
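The new DistilBert heads can also be exercised outside the test harness; a sketch assuming the `TFDistilBertForMultipleChoice` class introduced alongside this commit (the tiny config values and random ids are illustrative):

    import tensorflow as tf
    from transformers import DistilBertConfig
    from transformers.modeling_tf_distilbert import TFDistilBertForMultipleChoice

    config = DistilBertConfig(vocab_size=100, dim=32, hidden_dim=64, n_layers=2, n_heads=2)
    config.num_choices = 4  # mirrors what the tester above sets
    model = TFDistilBertForMultipleChoice(config)

    # (batch, num_choices, seq_len) integer ids, as built by the tester above
    input_ids = tf.random.uniform((2, 4, 8), maxval=100, dtype=tf.int32)
    (logits,) = model({"input_ids": input_ids})
    print(logits.shape)  # (2, 4): one score per choice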
...
tests/test_modeling_tf_electra.py
View file @ cf10d4cf
...
@@ -29,6 +29,7 @@ if is_tf_available():
         TFElectraForMaskedLM,
         TFElectraForPreTraining,
         TFElectraForTokenClassification,
+        TFElectraForQuestionAnswering,
     )
...
@@ -137,6 +138,19 @@ class TFElectraModelTester:
         }
         self.parent.assertListEqual(list(result["prediction_scores"].shape), [self.batch_size, self.seq_length])

+    def create_and_check_electra_for_question_answering(
+        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+    ):
+        model = TFElectraForQuestionAnswering(config=config)
+        inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
+        start_logits, end_logits = model(inputs)
+        result = {
+            "start_logits": start_logits.numpy(),
+            "end_logits": end_logits.numpy(),
+        }
+        self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
+        self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
+
     def create_and_check_electra_for_token_classification(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
...
@@ -192,6 +206,10 @@ class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_electra_for_pretraining(*config_and_inputs)

+    def test_for_question_answering(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_electra_for_question_answering(*config_and_inputs)
+
     def test_for_token_classification(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_electra_for_token_classification(*config_and_inputs)
...
tests/test_modeling_tf_roberta.py
View file @ cf10d4cf
...
@@ -32,6 +32,7 @@ if is_tf_available():
         TFRobertaForSequenceClassification,
         TFRobertaForTokenClassification,
         TFRobertaForQuestionAnswering,
+        TFRobertaForMultipleChoice,
         TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
     )
...
@@ -154,6 +155,25 @@ class TFRobertaModelTester:
         self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
         self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])

+    def create_and_check_roberta_for_multiple_choice(
+        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+    ):
+        config.num_choices = self.num_choices
+        model = TFRobertaForMultipleChoice(config=config)
+        multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
+        multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
+        multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1))
+        inputs = {
+            "input_ids": multiple_choice_inputs_ids,
+            "attention_mask": multiple_choice_input_mask,
+            "token_type_ids": multiple_choice_token_type_ids,
+        }
+        (logits,) = model(inputs)
+        result = {
+            "logits": logits.numpy(),
+        }
+        self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (
...
@@ -207,6 +227,10 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_roberta_for_question_answering(*config_and_inputs)

+    def test_for_multiple_choice(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_roberta_for_multiple_choice(*config_and_inputs)
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
...
tests/test_modeling_tf_xlnet.py
View file @ cf10d4cf
...
@@ -33,6 +33,7 @@ if is_tf_available():
         TFXLNetForSequenceClassification,
         TFXLNetForTokenClassification,
         TFXLNetForQuestionAnsweringSimple,
+        TFXLNetForMultipleChoice,
         TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
     )
...
@@ -66,6 +67,7 @@ class TFXLNetModelTester:
         self.bos_token_id = 1
         self.eos_token_id = 2
         self.pad_token_id = 5
+        self.num_choices = 4

     def prepare_config_and_inputs(self):
         input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
...
@@ -316,6 +318,36 @@ class TFXLNetModelTester:
             [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
         )

+    def create_and_check_xlnet_for_multiple_choice(
+        self,
+        config,
+        input_ids_1,
+        input_ids_2,
+        input_ids_q,
+        perm_mask,
+        input_mask,
+        target_mapping,
+        segment_ids,
+        lm_labels,
+        sequence_labels,
+        is_impossible_labels,
+    ):
+        config.num_choices = self.num_choices
+        model = TFXLNetForMultipleChoice(config=config)
+        multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids_1, 1), (1, self.num_choices, 1))
+        multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
+        multiple_choice_token_type_ids = tf.tile(tf.expand_dims(segment_ids, 1), (1, self.num_choices, 1))
+        inputs = {
+            "input_ids": multiple_choice_inputs_ids,
+            "attention_mask": multiple_choice_input_mask,
+            "token_type_ids": multiple_choice_token_type_ids,
+        }
+        (logits,) = model(inputs)
+        result = {
+            "logits": logits.numpy(),
+        }
+        self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (
...
@@ -345,6 +377,7 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
             TFXLNetForSequenceClassification,
             TFXLNetForTokenClassification,
             TFXLNetForQuestionAnsweringSimple,
+            TFXLNetForMultipleChoice,
         )
         if is_tf_available()
         else ()
...
@@ -385,6 +418,10 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_xlnet_qa(*config_and_inputs)

+    def test_xlnet_for_multiple_choice(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_xlnet_for_multiple_choice(*config_and_inputs)
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
...