chenpangpang / transformers · Commit 5e1207b8

add attention to all bert models and add test

authored Jun 14, 2019 by thomwolf
parent bcc9e93e
Showing 2 changed files with 140 additions and 47 deletions:

  pytorch_pretrained_bert/modeling.py   +78 −38
  tests/modeling_test.py                +62 −9
pytorch_pretrained_bert/modeling.py
@@ -813,15 +813,20 @@ class BertForPreTraining(BertPreTrainedModel):
         masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertForPreTraining, self).__init__(config)
-        self.bert = BertModel(config)
+        self.output_attentions = output_attentions
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
         self.apply(self.init_bert_weights)

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None):
-        sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
-                                                   output_all_encoded_layers=False)
+        outputs = self.bert(input_ids, token_type_ids, attention_mask,
+                            output_all_encoded_layers=False)
+        if self.output_attentions:
+            all_attentions, sequence_output, pooled_output = outputs
+        else:
+            sequence_output, pooled_output = outputs
         prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

         if masked_lm_labels is not None and next_sentence_label is not None:
@@ -830,8 +835,9 @@ class BertForPreTraining(BertPreTrainedModel):
             next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
             total_loss = masked_lm_loss + next_sentence_loss
             return total_loss
-        else:
-            return prediction_scores, seq_relationship_score
+        elif self.output_attentions:
+            return all_attentions, prediction_scores, seq_relationship_score
+        return prediction_scores, seq_relationship_score


 class BertForMaskedLM(BertPreTrainedModel):
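The pattern above repeats for every head model below: the constructor gains an `output_attentions` flag that is forwarded to `BertModel`, and whenever `forward` is not returning a loss, the per-layer attention probabilities are prepended to the usual outputs. A minimal sketch of the new calling convention (the config values and input ids here are illustrative, not part of this commit):

```python
import torch
from pytorch_pretrained_bert import BertConfig, BertForPreTraining

# Illustrative config; any valid BertConfig works the same way.
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                    num_hidden_layers=12, num_attention_heads=12,
                    intermediate_size=3072)
model = BertForPreTraining(config, output_attentions=True)
model.eval()

input_ids = torch.randint(0, 32000, (1, 16))  # (batch_size, seq_length)
# With no labels passed, forward now returns the attentions first:
all_attentions, prediction_scores, seq_relationship_score = model(input_ids)
# all_attentions holds one tensor per layer, each of shape
# (batch_size, num_attention_heads, seq_length, seq_length).
```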
@@ -876,23 +882,29 @@ class BertForMaskedLM(BertPreTrainedModel):
         masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertForMaskedLM, self).__init__(config)
-        self.bert = BertModel(config)
+        self.output_attentions = output_attentions
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
         self.apply(self.init_bert_weights)

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None):
-        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask,
-                                       output_all_encoded_layers=False)
+        outputs = self.bert(input_ids, token_type_ids, attention_mask,
+                            output_all_encoded_layers=False)
+        if self.output_attentions:
+            all_attentions, sequence_output, _ = outputs
+        else:
+            sequence_output, _ = outputs
         prediction_scores = self.cls(sequence_output)

         if masked_lm_labels is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
             return masked_lm_loss
-        else:
-            return prediction_scores
+        elif self.output_attentions:
+            return all_attentions, prediction_scores
+        return prediction_scores


 class BertForNextSentencePrediction(BertPreTrainedModel):
@@ -938,23 +950,29 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
         seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertForNextSentencePrediction, self).__init__(config)
-        self.bert = BertModel(config)
+        self.output_attentions = output_attentions
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.cls = BertOnlyNSPHead(config)
         self.apply(self.init_bert_weights)

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None):
-        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
-                                     output_all_encoded_layers=False)
+        outputs = self.bert(input_ids, token_type_ids, attention_mask,
+                            output_all_encoded_layers=False)
+        if self.output_attentions:
+            all_attentions, _, pooled_output = outputs
+        else:
+            _, pooled_output = outputs
         seq_relationship_score = self.cls(pooled_output)

         if next_sentence_label is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
             return next_sentence_loss
-        else:
-            return seq_relationship_score
+        elif self.output_attentions:
+            return all_attentions, seq_relationship_score
+        return seq_relationship_score


 class BertForSequenceClassification(BertPreTrainedModel):
@@ -1002,16 +1020,21 @@ class BertForSequenceClassification(BertPreTrainedModel):
         logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_labels):
+    def __init__(self, config, num_labels, output_attentions=False):
         super(BertForSequenceClassification, self).__init__(config)
+        self.output_attentions = output_attentions
         self.num_labels = num_labels
-        self.bert = BertModel(config)
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, num_labels)
         self.apply(self.init_bert_weights)

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
-        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+        outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+        if self.output_attentions:
+            all_attentions, _, pooled_output = outputs
+        else:
+            _, pooled_output = outputs
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
@@ -1019,8 +1042,9 @@ class BertForSequenceClassification(BertPreTrainedModel):
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
             return loss
-        else:
-            return logits
+        elif self.output_attentions:
+            return all_attentions, logits
+        return logits


 class BertForMultipleChoice(BertPreTrainedModel):
@@ -1067,10 +1091,11 @@ class BertForMultipleChoice(BertPreTrainedModel):
         logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_choices):
+    def __init__(self, config, num_choices, output_attentions=False):
         super(BertForMultipleChoice, self).__init__(config)
+        self.output_attentions = output_attentions
         self.num_choices = num_choices
-        self.bert = BertModel(config)
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, 1)
         self.apply(self.init_bert_weights)
@@ -1079,7 +1104,11 @@ class BertForMultipleChoice(BertPreTrainedModel):
         flat_input_ids = input_ids.view(-1, input_ids.size(-1))
         flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
         flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
-        _, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
+        outputs = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
+        if self.output_attentions:
+            all_attentions, _, pooled_output = outputs
+        else:
+            _, pooled_output = outputs
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
         reshaped_logits = logits.view(-1, self.num_choices)
@@ -1088,8 +1117,9 @@ class BertForMultipleChoice(BertPreTrainedModel):
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(reshaped_logits, labels)
             return loss
-        else:
-            return reshaped_logits
+        elif self.output_attentions:
+            return all_attentions, reshaped_logits
+        return reshaped_logits


 class BertForTokenClassification(BertPreTrainedModel):
@@ -1137,16 +1167,21 @@ class BertForTokenClassification(BertPreTrainedModel):
         logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_labels):
+    def __init__(self, config, num_labels, output_attentions=False):
         super(BertForTokenClassification, self).__init__(config)
+        self.output_attentions = output_attentions
         self.num_labels = num_labels
-        self.bert = BertModel(config)
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, num_labels)
         self.apply(self.init_bert_weights)

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
-        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+        outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+        if self.output_attentions:
+            all_attentions, sequence_output, _ = outputs
+        else:
+            sequence_output, _ = outputs
         sequence_output = self.dropout(sequence_output)
         logits = self.classifier(sequence_output)
@@ -1161,8 +1196,9 @@ class BertForTokenClassification(BertPreTrainedModel):
             else:
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
             return loss
-        else:
-            return logits
+        elif self.output_attentions:
+            return all_attentions, logits
+        return logits


 class BertForQuestionAnswering(BertPreTrainedModel):
@@ -1212,16 +1248,19 @@ class BertForQuestionAnswering(BertPreTrainedModel):
         start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertForQuestionAnswering, self).__init__(config)
-        self.bert = BertModel(config)
+        self.output_attentions = output_attentions
         # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
+        self.bert = BertModel(config, output_attentions=output_attentions)
         # self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.qa_outputs = nn.Linear(config.hidden_size, 2)
         self.apply(self.init_bert_weights)

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
-        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+        outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+        if self.output_attentions:
+            all_attentions, sequence_output, _ = outputs
+        else:
+            sequence_output, _ = outputs
         logits = self.qa_outputs(sequence_output)
         start_logits, end_logits = logits.split(1, dim=-1)
         start_logits = start_logits.squeeze(-1)
@@ -1243,5 +1282,6 @@ class BertForQuestionAnswering(BertPreTrainedModel):
             end_loss = loss_fct(end_logits, end_positions)
             total_loss = (start_loss + end_loss) / 2
             return total_loss
-        else:
-            return start_logits, end_logits
+        elif self.output_attentions:
+            return all_attentions, start_logits, end_logits
+        return start_logits, end_logits
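That completes the modeling changes. In practice the new flag should compose with `from_pretrained` the same way the other constructor arguments do; a hedged usage sketch ('bert-base-uncased' and `num_labels=2` are illustrative, and it is assumed here that `from_pretrained` forwards extra keyword arguments to `__init__`, as it does for `num_labels`):

```python
import torch
from pytorch_pretrained_bert import BertForSequenceClassification

# Assumption: from_pretrained passes extra kwargs through to __init__.
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased', num_labels=2, output_attentions=True)
model.eval()

input_ids = torch.tensor([[101, 7592, 2088, 102]])  # [CLS] hello world [SEP]
with torch.no_grad():
    all_attentions, logits = model(input_ids)  # attentions come first
    # Passing labels still returns only the loss; attentions are dropped:
    loss = model(input_ids, labels=torch.tensor([1]))
```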
tests/modeling_test.py
@@ -28,7 +28,7 @@ import torch
 from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM,
                                      BertForNextSentencePrediction, BertForPreTraining,
                                      BertForQuestionAnswering, BertForSequenceClassification,
-                                     BertForTokenClassification)
+                                     BertForTokenClassification, BertForMultipleChoice)
 from pytorch_pretrained_bert.modeling import PRETRAINED_MODEL_ARCHIVE_MAP
@@ -56,6 +56,7 @@ class BertModelTest(unittest.TestCase):
                      type_sequence_label_size=2,
                      initializer_range=0.02,
                      num_labels=3,
+                     num_choices=4,
                      scope=None):
             self.parent = parent
             self.batch_size = batch_size
@@ -77,6 +78,7 @@ class BertModelTest(unittest.TestCase):
             self.type_sequence_label_size = type_sequence_label_size
             self.initializer_range = initializer_range
             self.num_labels = num_labels
+            self.num_choices = num_choices
             self.scope = scope

         def prepare_config_and_inputs(self):
@@ -92,9 +94,11 @@ class BertModelTest(unittest.TestCase):
             sequence_labels = None
             token_labels = None
+            choice_labels = None
             if self.use_labels:
                 sequence_labels = BertModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
                 token_labels = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.num_labels)
+                choice_labels = BertModelTest.ids_tensor([self.batch_size], self.num_choices)

             config = BertConfig(
                 vocab_size_or_config_json_file=self.vocab_size,
@@ -109,14 +113,14 @@ class BertModelTest(unittest.TestCase):
                 type_vocab_size=self.type_vocab_size,
                 initializer_range=self.initializer_range)

-            return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels
+            return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

         def check_loss_output(self, result):
             self.parent.assertListEqual(
                 list(result["loss"].size()),
                 [])

-        def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+        def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertModel(config=config)
             model.eval()
             all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
@@ -137,7 +141,7 @@ class BertModelTest(unittest.TestCase):
             self.parent.assertListEqual(
                 list(result["pooled_output"].size()),
                 [self.batch_size, self.hidden_size])

-        def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+        def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForMaskedLM(config=config)
             model.eval()
             loss = model(input_ids, token_type_ids, input_mask, token_labels)
@@ -153,7 +157,7 @@ class BertModelTest(unittest.TestCase):
                 list(result["prediction_scores"].size()),
                 [self.batch_size, self.seq_length, self.vocab_size])

-        def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+        def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForNextSentencePrediction(config=config)
             model.eval()
             loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
@@ -170,7 +174,7 @@ class BertModelTest(unittest.TestCase):
                 [self.batch_size, 2])

-        def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+        def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForPreTraining(config=config)
             model.eval()
             loss = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels)
@@ -191,7 +195,7 @@ class BertModelTest(unittest.TestCase):
                 [self.batch_size, 2])

-        def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+        def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForQuestionAnswering(config=config)
             model.eval()
             loss = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels)
@@ -212,7 +216,7 @@ class BertModelTest(unittest.TestCase):
                 [self.batch_size, self.seq_length])

-        def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+        def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForSequenceClassification(config=config, num_labels=self.num_labels)
             model.eval()
             loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
@@ -229,7 +233,7 @@ class BertModelTest(unittest.TestCase):
                 [self.batch_size, self.num_labels])

-        def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+        def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForTokenClassification(config=config, num_labels=self.num_labels)
             model.eval()
             loss = model(input_ids, token_type_ids, input_mask, token_labels)
@@ -246,6 +250,49 @@ class BertModelTest(unittest.TestCase):
                 [self.batch_size, self.seq_length, self.num_labels])

+        def create_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+            model = BertForMultipleChoice(config=config, num_choices=self.num_choices)
+            model.eval()
+            multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+            multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+            multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+            loss = model(multiple_choice_inputs_ids,
+                         multiple_choice_token_type_ids,
+                         multiple_choice_input_mask,
+                         choice_labels)
+            logits = model(multiple_choice_inputs_ids,
+                           multiple_choice_token_type_ids,
+                           multiple_choice_input_mask)
+            outputs = {
+                "loss": loss,
+                "logits": logits,
+            }
+            return outputs
+
+        def check_bert_for_multiple_choice(self, result):
+            self.parent.assertListEqual(
+                list(result["logits"].size()),
+                [self.batch_size, self.num_choices])
+
+        def create_and_check_bert_for_attentions(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+            for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
+                                BertForPreTraining, BertForQuestionAnswering,
+                                BertForSequenceClassification, BertForTokenClassification):
+                if model_class in [BertForSequenceClassification, BertForTokenClassification]:
+                    model = model_class(config=config, num_labels=self.num_labels, output_attentions=True)
+                else:
+                    model = model_class(config=config, output_attentions=True)
+                model.eval()
+                output = model(input_ids, token_type_ids, input_mask)
+                attentions = output[0]
+                self.parent.assertEqual(len(attentions), self.num_hidden_layers)
+                self.parent.assertListEqual(
+                    list(attentions[0].size()),
+                    [self.batch_size, self.num_attention_heads, self.seq_length, self.seq_length])
+
     def test_default(self):
         self.run_tester(BertModelTest.BertModelTester(self))
@@ -300,6 +347,12 @@ class BertModelTest(unittest.TestCase):
         tester.check_bert_for_token_classification_output(output_result)
         tester.check_loss_output(output_result)

+        output_result = tester.create_bert_for_multiple_choice(*config_and_inputs)
+        tester.check_bert_for_multiple_choice(output_result)
+        tester.check_loss_output(output_result)
+
+        tester.create_and_check_bert_for_attentions(*config_and_inputs)
+
     @classmethod
     def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
         """Creates a random int32 tensor of the shape within the vocab size."""
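The new tester methods hook into the existing `test_default` entry point, so the whole matrix runs under plain `unittest`. Presumably something like the following exercises it from the repository root (the invocation is a sketch, assuming `tests.modeling_test` is importable from there):

```python
import unittest

# Load and run the updated BERT model tests, including the new
# multiple-choice and attention-output checks wired into test_default.
suite = unittest.defaultTestLoader.loadTestsFromName(
    'tests.modeling_test.BertModelTest.test_default')
unittest.TextTestRunner(verbosity=2).run(suite)
```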