chenpangpang / transformers · Commits
"examples/run_squad_dataset_utils.py" did not exist on "335f57baf86094907a14de7ddc9f3e791ae3519b"
Commit 5e1207b8
authored Jun 14, 2019 by thomwolf
parent bcc9e93e

add attention to all bert models and add test

Showing 2 changed files with 140 additions and 47 deletions:
pytorch_pretrained_bert/modeling.py   +78  -38
tests/modeling_test.py                +62  -9
pytorch_pretrained_bert/modeling.py (view file @ 5e1207b8)
@@ -813,15 +813,20 @@ class BertForPreTraining(BertPreTrainedModel):
         masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
         ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertForPreTraining, self).__init__(config)
-        self.bert = BertModel(config)
+        self.output_attentions = output_attentions
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
         self.apply(self.init_bert_weights)

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None):
-        sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
-                                                   output_all_encoded_layers=False)
+        outputs = self.bert(input_ids, token_type_ids, attention_mask,
+                            output_all_encoded_layers=False)
+        if self.output_attentions:
+            all_attentions, sequence_output, pooled_output = outputs
+        else:
+            sequence_output, pooled_output = outputs
         prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

         if masked_lm_labels is not None and next_sentence_label is not None:
@@ -830,7 +835,8 @@ class BertForPreTraining(BertPreTrainedModel):
             next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
             total_loss = masked_lm_loss + next_sentence_loss
             return total_loss
-        else:
-            return prediction_scores, seq_relationship_score
+        elif self.output_attentions:
+            return all_attentions, prediction_scores, seq_relationship_score
+        return prediction_scores, seq_relationship_score
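The same `output_attentions` pattern is applied to every head below: the constructor grows an `output_attentions=False` flag that is forwarded to `BertModel`, and when the flag is set, the tuple returned by `forward` is prefixed with `all_attentions`, one attention tensor per layer. A minimal sketch of the new calling convention (the toy `BertConfig` values here are invented for illustration and are not part of this commit):

```python
import torch
from pytorch_pretrained_bert import BertConfig, BertForPreTraining

# Tiny hypothetical config so the example runs quickly; real checkpoints
# use bert-base/-large sized values.
config = BertConfig(vocab_size_or_config_json_file=100, hidden_size=32,
                    num_hidden_layers=2, num_attention_heads=4,
                    intermediate_size=64)
model = BertForPreTraining(config, output_attentions=True)
model.eval()

input_ids = torch.randint(0, 100, (1, 8))  # batch of 1, sequence length 8
with torch.no_grad():
    # Without labels, forward now returns three values instead of two.
    all_attentions, prediction_scores, seq_relationship_score = model(input_ids)

assert len(all_attentions) == config.num_hidden_layers  # one tensor per layer
```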
@@ -876,22 +882,28 @@ class BertForMaskedLM(BertPreTrainedModel):
         masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
         ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertForMaskedLM, self).__init__(config)
-        self.bert = BertModel(config)
+        self.output_attentions = output_attentions
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
         self.apply(self.init_bert_weights)

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None):
-        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask,
-                                       output_all_encoded_layers=False)
+        outputs = self.bert(input_ids, token_type_ids, attention_mask,
+                            output_all_encoded_layers=False)
+        if self.output_attentions:
+            all_attentions, sequence_output, _ = outputs
+        else:
+            sequence_output, _ = outputs
         prediction_scores = self.cls(sequence_output)

         if masked_lm_labels is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
             return masked_lm_loss
-        else:
-            return prediction_scores
+        elif self.output_attentions:
+            return all_attentions, prediction_scores
+        return prediction_scores
@@ -938,22 +950,28 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
         seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
         ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertForNextSentencePrediction, self).__init__(config)
-        self.bert = BertModel(config)
+        self.output_attentions = output_attentions
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.cls = BertOnlyNSPHead(config)
         self.apply(self.init_bert_weights)

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None):
-        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
-                                     output_all_encoded_layers=False)
-        seq_relationship_score = self.cls(pooled_output)
+        outputs = self.bert(input_ids, token_type_ids, attention_mask,
+                            output_all_encoded_layers=False)
+        if self.output_attentions:
+            all_attentions, _, pooled_output = outputs
+        else:
+            _, pooled_output = outputs
+        seq_relationship_score = self.cls(pooled_output)

         if next_sentence_label is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
             return next_sentence_loss
-        else:
-            return seq_relationship_score
+        elif self.output_attentions:
+            return all_attentions, seq_relationship_score
+        return seq_relationship_score
@@ -1002,16 +1020,21 @@ class BertForSequenceClassification(BertPreTrainedModel):
         logits = model(input_ids, token_type_ids, input_mask)
         ```
     """
-    def __init__(self, config, num_labels):
+    def __init__(self, config, num_labels, output_attentions=False):
         super(BertForSequenceClassification, self).__init__(config)
+        self.output_attentions = output_attentions
         self.num_labels = num_labels
-        self.bert = BertModel(config)
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, num_labels)
         self.apply(self.init_bert_weights)

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
-        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+        outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+        if self.output_attentions:
+            all_attentions, _, pooled_output = outputs
+        else:
+            _, pooled_output = outputs
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
@@ -1019,7 +1042,8 @@ class BertForSequenceClassification(BertPreTrainedModel):
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
             return loss
-        else:
-            return logits
+        elif self.output_attentions:
+            return all_attentions, logits
+        return logits
@@ -1067,10 +1091,11 @@ class BertForMultipleChoice(BertPreTrainedModel):
         logits = model(input_ids, token_type_ids, input_mask)
         ```
     """
-    def __init__(self, config, num_choices):
+    def __init__(self, config, num_choices, output_attentions=False):
         super(BertForMultipleChoice, self).__init__(config)
+        self.output_attentions = output_attentions
         self.num_choices = num_choices
-        self.bert = BertModel(config)
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, 1)
         self.apply(self.init_bert_weights)
@@ -1079,7 +1104,11 @@ class BertForMultipleChoice(BertPreTrainedModel):
         flat_input_ids = input_ids.view(-1, input_ids.size(-1))
         flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
         flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
-        _, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
+        outputs = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
+        if self.output_attentions:
+            all_attentions, _, pooled_output = outputs
+        else:
+            _, pooled_output = outputs
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
         reshaped_logits = logits.view(-1, self.num_choices)
@@ -1088,7 +1117,8 @@ class BertForMultipleChoice(BertPreTrainedModel):
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(reshaped_logits, labels)
             return loss
-        else:
-            return reshaped_logits
+        elif self.output_attentions:
+            return all_attentions, reshaped_logits
+        return reshaped_logits
@@ -1137,16 +1167,21 @@ class BertForTokenClassification(BertPreTrainedModel):
         logits = model(input_ids, token_type_ids, input_mask)
         ```
     """
-    def __init__(self, config, num_labels):
+    def __init__(self, config, num_labels, output_attentions=False):
         super(BertForTokenClassification, self).__init__(config)
+        self.output_attentions = output_attentions
         self.num_labels = num_labels
-        self.bert = BertModel(config)
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, num_labels)
         self.apply(self.init_bert_weights)

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
-        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+        outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+        if self.output_attentions:
+            all_attentions, sequence_output, _ = outputs
+        else:
+            sequence_output, _ = outputs
         sequence_output = self.dropout(sequence_output)
         logits = self.classifier(sequence_output)
@@ -1161,7 +1196,8 @@ class BertForTokenClassification(BertPreTrainedModel):
             else:
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
             return loss
-        else:
-            return logits
+        elif self.output_attentions:
+            return all_attentions, logits
+        return logits
@@ -1212,16 +1248,19 @@ class BertForQuestionAnswering(BertPreTrainedModel):
         start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
         ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertForQuestionAnswering, self).__init__(config)
-        self.bert = BertModel(config)
-        # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
-        # self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.output_attentions = output_attentions
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.qa_outputs = nn.Linear(config.hidden_size, 2)
         self.apply(self.init_bert_weights)

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
-        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+        outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+        if self.output_attentions:
+            all_attentions, sequence_output, _ = outputs
+        else:
+            sequence_output, _ = outputs
         logits = self.qa_outputs(sequence_output)
         start_logits, end_logits = logits.split(1, dim=-1)
         start_logits = start_logits.squeeze(-1)
@@ -1243,5 +1282,6 @@ class BertForQuestionAnswering(BertPreTrainedModel):
             end_loss = loss_fct(end_logits, end_positions)
             total_loss = (start_loss + end_loss) / 2
             return total_loss
-        else:
-            return start_logits, end_logits
+        elif self.output_attentions:
+            return all_attentions, start_logits, end_logits
+        return start_logits, end_logits
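That completes the modeling changes: all seven heads now accept the flag and surface the per-layer attention maps. As a quick sanity check of the shapes involved, here is a hedged sketch with a classification head (again with an invented toy config, not part of this commit):

```python
import torch
from pytorch_pretrained_bert import BertConfig, BertForSequenceClassification

config = BertConfig(vocab_size_or_config_json_file=100, hidden_size=32,
                    num_hidden_layers=2, num_attention_heads=4,
                    intermediate_size=64)
model = BertForSequenceClassification(config, num_labels=3, output_attentions=True)
model.eval()

input_ids = torch.randint(0, 100, (2, 10))  # batch 2, sequence length 10
with torch.no_grad():
    all_attentions, logits = model(input_ids)

# One attention tensor per layer, each [batch, num_heads, seq_len, seq_len].
for layer_attention in all_attentions:
    print(layer_attention.size())  # torch.Size([2, 4, 10, 10])
print(logits.size())               # torch.Size([2, 3])
```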
tests/modeling_test.py (view file @ 5e1207b8)
@@ -28,7 +28,7 @@ import torch
 from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM,
                                      BertForNextSentencePrediction, BertForPreTraining,
                                      BertForQuestionAnswering, BertForSequenceClassification,
-                                     BertForTokenClassification)
+                                     BertForTokenClassification, BertForMultipleChoice)
 from pytorch_pretrained_bert.modeling import PRETRAINED_MODEL_ARCHIVE_MAP
@@ -56,6 +56,7 @@ class BertModelTest(unittest.TestCase):
                      type_sequence_label_size=2,
                      initializer_range=0.02,
                      num_labels=3,
+                     num_choices=4,
                      scope=None):
             self.parent = parent
             self.batch_size = batch_size
@@ -77,6 +78,7 @@ class BertModelTest(unittest.TestCase):
             self.type_sequence_label_size = type_sequence_label_size
             self.initializer_range = initializer_range
             self.num_labels = num_labels
+            self.num_choices = num_choices
             self.scope = scope

         def prepare_config_and_inputs(self):
@@ -92,9 +94,11 @@ class BertModelTest(unittest.TestCase):
             sequence_labels = None
             token_labels = None
+            choice_labels = None
             if self.use_labels:
                 sequence_labels = BertModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
                 token_labels = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.num_labels)
+                choice_labels = BertModelTest.ids_tensor([self.batch_size], self.num_choices)

             config = BertConfig(
                 vocab_size_or_config_json_file=self.vocab_size,
@@ -109,14 +113,14 @@ class BertModelTest(unittest.TestCase):
                 type_vocab_size=self.type_vocab_size,
                 initializer_range=self.initializer_range)

-            return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels
+            return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

         def check_loss_output(self, result):
             self.parent.assertListEqual(list(result["loss"].size()), [])

-        def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+        def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertModel(config=config)
             model.eval()
             all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
@@ -137,7 +141,7 @@ class BertModelTest(unittest.TestCase):
             self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])

-        def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+        def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForMaskedLM(config=config)
             model.eval()
             loss = model(input_ids, token_type_ids, input_mask, token_labels)
@@ -153,7 +157,7 @@ class BertModelTest(unittest.TestCase):
                 list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size])

-        def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+        def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForNextSentencePrediction(config=config)
             model.eval()
             loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
@@ -170,7 +174,7 @@ class BertModelTest(unittest.TestCase):
                 [self.batch_size, 2])

-        def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+        def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForPreTraining(config=config)
             model.eval()
             loss = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels)
@@ -191,7 +195,7 @@ class BertModelTest(unittest.TestCase):
                 [self.batch_size, 2])

-        def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+        def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForQuestionAnswering(config=config)
             model.eval()
             loss = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels)
@@ -212,7 +216,7 @@ class BertModelTest(unittest.TestCase):
                 [self.batch_size, self.seq_length])

-        def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+        def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForSequenceClassification(config=config, num_labels=self.num_labels)
             model.eval()
             loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
@@ -229,7 +233,7 @@ class BertModelTest(unittest.TestCase):
                 [self.batch_size, self.num_labels])

-        def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+        def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForTokenClassification(config=config, num_labels=self.num_labels)
             model.eval()
             loss = model(input_ids, token_type_ids, input_mask, token_labels)
@@ -246,6 +250,49 @@ class BertModelTest(unittest.TestCase):
                 [self.batch_size, self.seq_length, self.num_labels])

+        def create_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+            model = BertForMultipleChoice(config=config, num_choices=self.num_choices)
+            model.eval()
+            multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+            multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+            multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+            loss = model(multiple_choice_inputs_ids,
+                         multiple_choice_token_type_ids,
+                         multiple_choice_input_mask,
+                         choice_labels)
+            logits = model(multiple_choice_inputs_ids,
+                           multiple_choice_token_type_ids,
+                           multiple_choice_input_mask)
+            outputs = {
+                "loss": loss,
+                "logits": logits,
+            }
+            return outputs
+
+        def check_bert_for_multiple_choice(self, result):
+            self.parent.assertListEqual(
+                list(result["logits"].size()),
+                [self.batch_size, self.num_choices])
+
+        def create_and_check_bert_for_attentions(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+            for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
+                                BertForPreTraining, BertForQuestionAnswering,
+                                BertForSequenceClassification, BertForTokenClassification):
+                if model_class in [BertForSequenceClassification, BertForTokenClassification]:
+                    model = model_class(config=config, num_labels=self.num_labels, output_attentions=True)
+                else:
+                    model = model_class(config=config, output_attentions=True)
+                model.eval()
+                output = model(input_ids, token_type_ids, input_mask)
+                attentions = output[0]
+                self.parent.assertEqual(len(attentions), self.num_hidden_layers)
+                self.parent.assertListEqual(
+                    list(attentions[0].size()),
+                    [self.batch_size, self.num_attention_heads, self.seq_length, self.seq_length])
+
     def test_default(self):
         self.run_tester(BertModelTest.BertModelTester(self))
@@ -300,6 +347,12 @@ class BertModelTest(unittest.TestCase):
         tester.check_bert_for_token_classification_output(output_result)
         tester.check_loss_output(output_result)

+        output_result = tester.create_bert_for_multiple_choice(*config_and_inputs)
+        tester.check_bert_for_multiple_choice(output_result)
+        tester.check_loss_output(output_result)
+
+        tester.create_and_check_bert_for_attentions(*config_and_inputs)
+
     @classmethod
     def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
         """Creates a random int32 tensor of the shape within the vocab size."""