chenpangpang / transformers

Commit 7f7c41b0, authored Nov 30, 2018 by thomwolf
Parent: c588453a

    tests for all model classes with and without labels

1 changed file with 160 additions and 8 deletions:

tests/modeling_test.py (+160, -8)
@@ -22,7 +22,10 @@ import random
 import torch

-from pytorch_pretrained_bert import BertConfig, BertModel
+from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM,
+                                     BertForNextSentencePrediction, BertForPreTraining,
+                                     BertForQuestionAnswering, BertForSequenceClassification,
+                                     BertForTokenClassification)


 class BertModelTest(unittest.TestCase):
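The widened import pulls in every BERT head class. All of them share the forward convention the new tests exercise: calling a head with labels returns a scalar training loss, while calling it without labels returns raw scores. A minimal sketch of that contract, assuming the pytorch_pretrained_bert 0.x API; vocab_size=99 and hidden_size=32 come from the tester defaults visible below, the remaining config values are illustrative (num_attention_heads must divide hidden_size), and the keyword name masked_lm_labels comes from the library, not from this diff:

    import torch
    from pytorch_pretrained_bert import BertConfig, BertForMaskedLM

    # Tiny config in the spirit of the tester defaults below.
    config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=32,
                        num_hidden_layers=5, num_attention_heads=4,
                        intermediate_size=37)
    model = BertForMaskedLM(config=config)

    input_ids = torch.randint(0, 99, (13, 7))      # [batch_size, seq_length]
    token_labels = torch.randint(0, 99, (13, 7))   # per-token target ids

    loss = model(input_ids, masked_lm_labels=token_labels)  # with labels: scalar loss
    scores = model(input_ids)                               # without labels: [13, 7, 99]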
@@ -35,6 +38,7 @@ class BertModelTest(unittest.TestCase):
                      is_training=True,
                      use_input_mask=True,
                      use_token_type_ids=True,
+                     use_labels=True,
                      vocab_size=99,
                      hidden_size=32,
                      num_hidden_layers=5,
@@ -45,7 +49,9 @@ class BertModelTest(unittest.TestCase):
                      attention_probs_dropout_prob=0.1,
                      max_position_embeddings=512,
                      type_vocab_size=16,
+                     type_sequence_label_size=2,
                      initializer_range=0.02,
+                     num_labels=3,
                      scope=None):
             self.parent = parent
             self.batch_size = batch_size
@@ -53,6 +59,7 @@ class BertModelTest(unittest.TestCase):
             self.is_training = is_training
             self.use_input_mask = use_input_mask
             self.use_token_type_ids = use_token_type_ids
+            self.use_labels = use_labels
             self.vocab_size = vocab_size
             self.hidden_size = hidden_size
             self.num_hidden_layers = num_hidden_layers
@@ -63,10 +70,12 @@ class BertModelTest(unittest.TestCase):
             self.attention_probs_dropout_prob = attention_probs_dropout_prob
             self.max_position_embeddings = max_position_embeddings
             self.type_vocab_size = type_vocab_size
+            self.type_sequence_label_size = type_sequence_label_size
             self.initializer_range = initializer_range
+            self.num_labels = num_labels
             self.scope = scope

-        def create_model(self):
+        def prepare_config_and_inputs(self):
             input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

             input_mask = None
@@ -77,6 +86,12 @@ class BertModelTest(unittest.TestCase):
             if self.use_token_type_ids:
                 token_type_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

+            sequence_labels = None
+            token_labels = None
+            if self.use_labels:
+                sequence_labels = BertModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
+                token_labels = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.num_labels)
+
             config = BertConfig(
                 vocab_size_or_config_json_file=self.vocab_size,
                 hidden_size=self.hidden_size,
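The new label tensors reuse ids_tensor, the tester's random-integer helper, whose body is collapsed at the bottom of this diff. A hedged re-creation, only to make the calls above readable (the actual helper may differ in details):

    import random
    import torch

    def ids_tensor(shape, vocab_size, rng=None, name=None):
        """Create a random long tensor of `shape` with ids drawn from [0, vocab_size)."""
        rng = rng or random.Random()
        total_dims = 1
        for dim in shape:
            total_dims *= dim
        values = [rng.randint(0, vocab_size - 1) for _ in range(total_dims)]
        return torch.tensor(values, dtype=torch.long).view(shape)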
@@ -90,10 +105,16 @@ class BertModelTest(unittest.TestCase):
                 type_vocab_size=self.type_vocab_size,
                 initializer_range=self.initializer_range)

-            model = BertModel(config=config)
+            return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels

-            all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
+        def check_loss_output(self, result):
+            self.parent.assertListEqual(
+                list(result["loss"].size()),
+                [])
+
+        def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+            model = BertModel(config=config)
+
+            all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)

             outputs = {
                 "sequence_output": all_encoder_layers[-1],
                 "pooled_output": pooled_output,
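check_loss_output asserts that each head's loss is a 0-dimensional tensor: once the per-example losses are reduced (by default to a mean), no dimensions remain, so list(loss.size()) is the empty list. A quick self-contained check of that fact:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 3)               # [batch, classes]
    targets = torch.tensor([0, 2, 1, 0])
    loss = F.cross_entropy(logits, targets)  # default reduction yields a scalar
    assert list(loss.size()) == []           # exactly what check_loss_output expects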
@@ -101,13 +122,119 @@ class BertModelTest(unittest.TestCase):
             }
             return outputs

-        def check_output(self, result):
+        def check_bert_model_output(self, result):
             self.parent.assertListEqual(
                 [size for layer in result["all_encoder_layers"] for size in layer.size()],
                 [self.batch_size, self.seq_length, self.hidden_size] * self.num_hidden_layers)
             self.parent.assertListEqual(
                 list(result["sequence_output"].size()),
                 [self.batch_size, self.seq_length, self.hidden_size])
             self.parent.assertListEqual(
                 list(result["pooled_output"].size()),
                 [self.batch_size, self.hidden_size])

+        def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+            model = BertForMaskedLM(config=config)
+            loss = model(input_ids, token_type_ids, input_mask, token_labels)
+            prediction_scores = model(input_ids, token_type_ids, input_mask)
+            outputs = {
+                "loss": loss,
+                "prediction_scores": prediction_scores,
+            }
+            return outputs
+
+        def check_bert_for_masked_lm_output(self, result):
+            self.parent.assertListEqual(
+                list(result["prediction_scores"].size()),
+                [self.batch_size, self.seq_length, self.vocab_size])
+
+        def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+            model = BertForNextSentencePrediction(config=config)
+            loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
+            seq_relationship_score = model(input_ids, token_type_ids, input_mask)
+            outputs = {
+                "loss": loss,
+                "seq_relationship_score": seq_relationship_score,
+            }
+            return outputs
+
+        def check_bert_for_next_sequence_prediction_output(self, result):
+            self.parent.assertListEqual(
+                list(result["seq_relationship_score"].size()),
+                [self.batch_size, 2])
+
+        def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+            model = BertForPreTraining(config=config)
+            loss = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels)
+            prediction_scores, seq_relationship_score = model(input_ids, token_type_ids, input_mask)
+            outputs = {
+                "loss": loss,
+                "prediction_scores": prediction_scores,
+                "seq_relationship_score": seq_relationship_score,
+            }
+            return outputs
+
+        def check_bert_for_pretraining_output(self, result):
+            self.parent.assertListEqual(
+                list(result["prediction_scores"].size()),
+                [self.batch_size, self.seq_length, self.vocab_size])
+            self.parent.assertListEqual(
+                list(result["seq_relationship_score"].size()),
+                [self.batch_size, 2])
+
+        def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+            model = BertForQuestionAnswering(config=config)
+            loss = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels)
+            start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
+            outputs = {
+                "loss": loss,
+                "start_logits": start_logits,
+                "end_logits": end_logits,
+            }
+            return outputs
+
+        def check_bert_for_question_answering_output(self, result):
+            self.parent.assertListEqual(
+                list(result["start_logits"].size()),
+                [self.batch_size, self.seq_length])
+            self.parent.assertListEqual(
+                list(result["end_logits"].size()),
+                [self.batch_size, self.seq_length])
+
+        def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+            model = BertForSequenceClassification(config=config, num_labels=self.num_labels)
+            loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
+            logits = model(input_ids, token_type_ids, input_mask)
+            outputs = {
+                "loss": loss,
+                "logits": logits,
+            }
+            return outputs
+
+        def check_bert_for_sequence_classification_output(self, result):
+            self.parent.assertListEqual(
+                list(result["logits"].size()),
+                [self.batch_size, self.num_labels])
+
+        def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+            model = BertForTokenClassification(config=config, num_labels=self.num_labels)
+            loss = model(input_ids, token_type_ids, input_mask, token_labels)
+            logits = model(input_ids, token_type_ids, input_mask)
+            outputs = {
+                "loss": loss,
+                "logits": logits,
+            }
+            return outputs
+
+        def check_bert_for_token_classification_output(self, result):
+            self.parent.assertListEqual(
+                list(result["logits"].size()),
+                [self.batch_size, self.seq_length, self.num_labels])
+
     def test_default(self):
         self.run_tester(BertModelTest.BertModelTester(self))
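As a reading aid (not part of the commit), the shape contracts pinned down by the check_* methods above, with B = batch_size, S = seq_length, V = vocab_size, L = num_labels:

    EXPECTED_OUTPUT_SHAPES = {
        "BertForMaskedLM":               {"prediction_scores": ("B", "S", "V")},
        "BertForNextSentencePrediction": {"seq_relationship_score": ("B", 2)},
        "BertForPreTraining":            {"prediction_scores": ("B", "S", "V"),
                                          "seq_relationship_score": ("B", 2)},
        "BertForQuestionAnswering":      {"start_logits": ("B", "S"),
                                          "end_logits": ("B", "S")},
        "BertForSequenceClassification": {"logits": ("B", "L")},
        "BertForTokenClassification":    {"logits": ("B", "S", "L")},
    }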
@@ -118,8 +245,33 @@ class BertModelTest(unittest.TestCase):
         self.assertEqual(obj["hidden_size"], 37)

     def run_tester(self, tester):
-        output_result = tester.create_model()
-        tester.check_output(output_result)
+        config_and_inputs = tester.prepare_config_and_inputs()
+        output_result = tester.create_bert_model(*config_and_inputs)
+        tester.check_bert_model_output(output_result)
+
+        output_result = tester.create_bert_for_masked_lm(*config_and_inputs)
+        tester.check_bert_for_masked_lm_output(output_result)
+        tester.check_loss_output(output_result)
+
+        output_result = tester.create_bert_for_next_sequence_prediction(*config_and_inputs)
+        tester.check_bert_for_next_sequence_prediction_output(output_result)
+        tester.check_loss_output(output_result)
+
+        output_result = tester.create_bert_for_pretraining(*config_and_inputs)
+        tester.check_bert_for_pretraining_output(output_result)
+        tester.check_loss_output(output_result)
+
+        output_result = tester.create_bert_for_question_answering(*config_and_inputs)
+        tester.check_bert_for_question_answering_output(output_result)
+        tester.check_loss_output(output_result)
+
+        output_result = tester.create_bert_for_sequence_classification(*config_and_inputs)
+        tester.check_bert_for_sequence_classification_output(output_result)
+        tester.check_loss_output(output_result)
+
+        output_result = tester.create_bert_for_token_classification(*config_and_inputs)
+        tester.check_bert_for_token_classification_output(output_result)
+        tester.check_loss_output(output_result)

     @classmethod
     def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
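run_tester now drives every head from a single prepare_config_and_inputs call, so all models see identical random inputs. To execute the suite, a plain unittest invocation suffices; a sketch, since the repository's preferred runner isn't shown on this page:

    # Run from the repository root; discovers tests/modeling_test.py and runs
    # BertModelTest.test_default, which exercises all seven model classes.
    import unittest

    suite = unittest.defaultTestLoader.discover("tests", pattern="modeling_test.py")
    unittest.TextTestRunner(verbosity=2).run(suite)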