chenpangpang / transformers · Commits

Commit 7f7c41b0, authored Nov 30, 2018 by thomwolf

tests for all model classes with and without labels

Parent: c588453a
Showing 1 changed file with 160 additions and 8 deletions.

tests/modeling_test.py (+160, -8)
@@ -22,7 +22,10 @@ import random
 import torch
 
-from pytorch_pretrained_bert import BertConfig, BertModel
+from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM,
+                                     BertForNextSentencePrediction, BertForPreTraining,
+                                     BertForQuestionAnswering, BertForSequenceClassification,
+                                     BertForTokenClassification)
 
 
 class BertModelTest(unittest.TestCase):
@@ -35,6 +38,7 @@ class BertModelTest(unittest.TestCase):
                      is_training=True,
                      use_input_mask=True,
                      use_token_type_ids=True,
+                     use_labels=True,
                      vocab_size=99,
                      hidden_size=32,
                      num_hidden_layers=5,
@@ -45,7 +49,9 @@ class BertModelTest(unittest.TestCase):
                      attention_probs_dropout_prob=0.1,
                      max_position_embeddings=512,
                      type_vocab_size=16,
+                     type_sequence_label_size=2,
                      initializer_range=0.02,
+                     num_labels=3,
                      scope=None):
             self.parent = parent
             self.batch_size = batch_size
@@ -53,6 +59,7 @@ class BertModelTest(unittest.TestCase):
             self.is_training = is_training
             self.use_input_mask = use_input_mask
             self.use_token_type_ids = use_token_type_ids
+            self.use_labels = use_labels
             self.vocab_size = vocab_size
             self.hidden_size = hidden_size
             self.num_hidden_layers = num_hidden_layers
@@ -63,10 +70,12 @@ class BertModelTest(unittest.TestCase):
             self.attention_probs_dropout_prob = attention_probs_dropout_prob
             self.max_position_embeddings = max_position_embeddings
             self.type_vocab_size = type_vocab_size
+            self.type_sequence_label_size = type_sequence_label_size
             self.initializer_range = initializer_range
+            self.num_labels = num_labels
             self.scope = scope
 
-        def create_model(self):
+        def prepare_config_and_inputs(self):
             input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
 
             input_mask = None
@@ -77,6 +86,12 @@ class BertModelTest(unittest.TestCase):
             if self.use_token_type_ids:
                 token_type_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
 
+            sequence_labels = None
+            token_labels = None
+            if self.use_labels:
+                sequence_labels = BertModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
+                token_labels = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.num_labels)
+
             config = BertConfig(
                 vocab_size_or_config_json_file=self.vocab_size,
                 hidden_size=self.hidden_size,
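Worth noting in the hunk above: the two label fixtures deliberately differ in rank. sequence_labels carries one class id per example, while token_labels carries one id per token position. A minimal shape check, using illustrative sizes rather than the tester's actual defaults:

    import torch

    batch_size, seq_length = 13, 7   # illustrative sizes, not necessarily the tester's defaults

    # One label per example, e.g. next-sentence prediction or sequence classification targets.
    sequence_labels = torch.randint(0, 2, (batch_size,))
    # One label per token, e.g. token classification targets.
    token_labels = torch.randint(0, 3, (batch_size, seq_length))

    assert list(sequence_labels.size()) == [batch_size]
    assert list(token_labels.size()) == [batch_size, seq_length]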
@@ -90,10 +105,16 @@ class BertModelTest(unittest.TestCase):
                 type_vocab_size=self.type_vocab_size,
                 initializer_range=self.initializer_range)
 
-            model = BertModel(config=config)
+            return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels
 
-            all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
+        def check_loss_output(self, result):
+            self.parent.assertListEqual(
+                list(result["loss"].size()),
+                [])
 
+        def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+            model = BertModel(config=config)
+            all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
             outputs = {
                 "sequence_output": all_encoder_layers[-1],
                 "pooled_output": pooled_output,
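The new check_loss_output asserts that every returned loss has an empty size list, i.e. that it is a 0-dimensional (scalar) tensor rather than a per-example vector. The idea in isolation:

    import torch
    import torch.nn.functional as F

    # A reduced loss is a 0-d tensor, so its size list is empty.
    loss = F.cross_entropy(torch.randn(4, 3), torch.tensor([0, 1, 2, 0]))
    assert list(loss.size()) == []

    # An unreduced loss keeps the batch dimension and would fail the check.
    per_example = F.cross_entropy(torch.randn(4, 3), torch.tensor([0, 1, 2, 0]),
                                  reduction="none")
    assert list(per_example.size()) == [4]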
@@ -101,13 +122,119 @@ class BertModelTest(unittest.TestCase):
             }
             return outputs
 
-        def check_output(self, result):
+        def check_bert_model_output(self, result):
+            self.parent.assertListEqual(
+                [size for layer in result["all_encoder_layers"] for size in layer.size()],
+                [self.batch_size, self.seq_length, self.hidden_size] * self.num_hidden_layers)
             self.parent.assertListEqual(
                 list(result["sequence_output"].size()),
                 [self.batch_size, self.seq_length, self.hidden_size])
             self.parent.assertListEqual(
                 list(result["pooled_output"].size()),
                 [self.batch_size, self.hidden_size])
 
+        def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+            model = BertForMaskedLM(config=config)
+            loss = model(input_ids, token_type_ids, input_mask, token_labels)
+            prediction_scores = model(input_ids, token_type_ids, input_mask)
+            outputs = {
+                "loss": loss,
+                "prediction_scores": prediction_scores,
+            }
+            return outputs
+
+        def check_bert_for_masked_lm_output(self, result):
+            self.parent.assertListEqual(
+                list(result["prediction_scores"].size()),
+                [self.batch_size, self.seq_length, self.vocab_size])
+
+        def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+            model = BertForNextSentencePrediction(config=config)
+            loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
+            seq_relationship_score = model(input_ids, token_type_ids, input_mask)
+            outputs = {
+                "loss": loss,
+                "seq_relationship_score": seq_relationship_score,
+            }
+            return outputs
+
+        def check_bert_for_next_sequence_prediction_output(self, result):
+            self.parent.assertListEqual(
+                list(result["seq_relationship_score"].size()),
+                [self.batch_size, 2])
+
+        def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+            model = BertForPreTraining(config=config)
+            loss = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels)
+            prediction_scores, seq_relationship_score = model(input_ids, token_type_ids, input_mask)
+            outputs = {
+                "loss": loss,
+                "prediction_scores": prediction_scores,
+                "seq_relationship_score": seq_relationship_score,
+            }
+            return outputs
+
+        def check_bert_for_pretraining_output(self, result):
+            self.parent.assertListEqual(
+                list(result["prediction_scores"].size()),
+                [self.batch_size, self.seq_length, self.vocab_size])
+            self.parent.assertListEqual(
+                list(result["seq_relationship_score"].size()),
+                [self.batch_size, 2])
+
+        def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+            model = BertForQuestionAnswering(config=config)
+            loss = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels)
+            start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
+            outputs = {
+                "loss": loss,
+                "start_logits": start_logits,
+                "end_logits": end_logits,
+            }
+            return outputs
+
+        def check_bert_for_question_answering_output(self, result):
+            self.parent.assertListEqual(
+                list(result["start_logits"].size()),
+                [self.batch_size, self.seq_length])
+            self.parent.assertListEqual(
+                list(result["end_logits"].size()),
+                [self.batch_size, self.seq_length])
+
+        def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+            model = BertForSequenceClassification(config=config, num_labels=self.num_labels)
+            loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
+            logits = model(input_ids, token_type_ids, input_mask)
+            outputs = {
+                "loss": loss,
+                "logits": logits,
+            }
+            return outputs
+
+        def check_bert_for_sequence_classification_output(self, result):
+            self.parent.assertListEqual(
+                list(result["logits"].size()),
+                [self.batch_size, self.num_labels])
+
+        def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+            model = BertForTokenClassification(config=config, num_labels=self.num_labels)
+            loss = model(input_ids, token_type_ids, input_mask, token_labels)
+            logits = model(input_ids, token_type_ids, input_mask)
+            outputs = {
+                "loss": loss,
+                "logits": logits,
+            }
+            return outputs
+
+        def check_bert_for_token_classification_output(self, result):
+            self.parent.assertListEqual(
+                list(result["logits"].size()),
+                [self.batch_size, self.seq_length, self.num_labels])
+
     def test_default(self):
         self.run_tester(BertModelTest.BertModelTester(self))
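Every create_* method above leans on the same convention in pytorch_pretrained_bert's heads: calling the model with labels returns a scalar loss, calling it without returns raw scores. A toy module sketching that dual-mode forward (ToyClassifier is hypothetical, not part of the library):

    import torch
    import torch.nn as nn

    class ToyClassifier(nn.Module):
        """Mimics the dual-mode forward used by the Bert* heads under test."""
        def __init__(self, hidden_size=32, num_labels=3):
            super(ToyClassifier, self).__init__()
            self.classifier = nn.Linear(hidden_size, num_labels)

        def forward(self, hidden_states, labels=None):
            logits = self.classifier(hidden_states)
            if labels is not None:
                # With labels: return a scalar loss, as check_loss_output expects.
                return nn.functional.cross_entropy(logits, labels)
            # Without labels: return the raw scores, as the check_*_output shape tests expect.
            return logits

    model = ToyClassifier()
    x = torch.randn(13, 32)
    y = torch.randint(0, 3, (13,))
    assert list(model(x, y).size()) == []      # loss is 0-d
    assert list(model(x).size()) == [13, 3]    # logits are [batch, num_labels]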
@@ -118,8 +245,33 @@ class BertModelTest(unittest.TestCase):
         self.assertEqual(obj["hidden_size"], 37)
 
     def run_tester(self, tester):
-        output_result = tester.create_model()
-        tester.check_output(output_result)
+        config_and_inputs = tester.prepare_config_and_inputs()
+        output_result = tester.create_bert_model(*config_and_inputs)
+        tester.check_bert_model_output(output_result)
+
+        output_result = tester.create_bert_for_masked_lm(*config_and_inputs)
+        tester.check_bert_for_masked_lm_output(output_result)
+        tester.check_loss_output(output_result)
+
+        output_result = tester.create_bert_for_next_sequence_prediction(*config_and_inputs)
+        tester.check_bert_for_next_sequence_prediction_output(output_result)
+        tester.check_loss_output(output_result)
+
+        output_result = tester.create_bert_for_pretraining(*config_and_inputs)
+        tester.check_bert_for_pretraining_output(output_result)
+        tester.check_loss_output(output_result)
+
+        output_result = tester.create_bert_for_question_answering(*config_and_inputs)
+        tester.check_bert_for_question_answering_output(output_result)
+        tester.check_loss_output(output_result)
+
+        output_result = tester.create_bert_for_sequence_classification(*config_and_inputs)
+        tester.check_bert_for_sequence_classification_output(output_result)
+        tester.check_loss_output(output_result)
+
+        output_result = tester.create_bert_for_token_classification(*config_and_inputs)
+        tester.check_bert_for_token_classification_output(output_result)
+        tester.check_loss_output(output_result)
 
     @classmethod
     def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
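The body of ids_tensor is collapsed in the diff above; functionally, it builds a random integer tensor of the requested shape with values drawn from [0, vocab_size). A rough, functionally equivalent sketch (not the file's exact body):

    import random

    import torch

    def ids_tensor(shape, vocab_size, rng=None):
        """Random ids in [0, vocab_size), shaped like `shape` (sketch, not the real helper)."""
        rng = rng or random.Random()
        total = 1
        for dim in shape:
            total *= dim
        values = [rng.randint(0, vocab_size - 1) for _ in range(total)]
        return torch.tensor(values, dtype=torch.long).view(shape)

    input_ids = ids_tensor([13, 7], vocab_size=99)   # mirrors the [batch_size, seq_length] usage above
    assert list(input_ids.size()) == [13, 7]
    assert 0 <= int(input_ids.min()) and int(input_ids.max()) < 99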