Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
d9184620
Commit
d9184620
authored
Jun 29, 2019
by
thomwolf
Browse files
fix tests and new API
parent
213981d8
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
99 additions
and
138 deletions
+99
-138
pytorch_pretrained_bert/modeling.py
pytorch_pretrained_bert/modeling.py
+10
-20
pytorch_pretrained_bert/modeling_xlnet.py
pytorch_pretrained_bert/modeling_xlnet.py
+4
-4
tests/modeling_test.py
tests/modeling_test.py
+78
-88
tests/modeling_xlnet_test.py
tests/modeling_xlnet_test.py
+7
-26
No files found.
pytorch_pretrained_bert/modeling.py
View file @
d9184620
...
@@ -320,9 +320,6 @@ class BertSelfAttention(nn.Module):
...
@@ -320,9 +320,6 @@ class BertSelfAttention(nn.Module):
attention_probs
=
attention_probs
*
head_mask
attention_probs
=
attention_probs
*
head_mask
context_layer
=
torch
.
matmul
(
attention_probs
,
value_layer
)
context_layer
=
torch
.
matmul
(
attention_probs
,
value_layer
)
if
self
.
keep_multihead_output
:
self
.
multihead_output
=
context_layer
self
.
multihead_output
.
retain_grad
()
context_layer
=
context_layer
.
permute
(
0
,
2
,
1
,
3
).
contiguous
()
context_layer
=
context_layer
.
permute
(
0
,
2
,
1
,
3
).
contiguous
()
new_context_layer_shape
=
context_layer
.
size
()[:
-
2
]
+
(
self
.
all_head_size
,)
new_context_layer_shape
=
context_layer
.
size
()[:
-
2
]
+
(
self
.
all_head_size
,)
...
@@ -416,7 +413,8 @@ class BertLayer(nn.Module):
...
@@ -416,7 +413,8 @@ class BertLayer(nn.Module):
def
forward
(
self
,
hidden_states
,
attention_mask
,
head_mask
=
None
):
def
forward
(
self
,
hidden_states
,
attention_mask
,
head_mask
=
None
):
attention_outputs
=
self
.
attention
(
hidden_states
,
attention_mask
,
head_mask
)
attention_outputs
=
self
.
attention
(
hidden_states
,
attention_mask
,
head_mask
)
intermediate_output
=
self
.
intermediate
(
attention_outputs
[
0
])
attention_output
=
attention_outputs
[
0
]
intermediate_output
=
self
.
intermediate
(
attention_output
)
layer_output
=
self
.
output
(
intermediate_output
,
attention_output
)
layer_output
=
self
.
output
(
intermediate_output
,
attention_output
)
outputs
=
[
layer_output
]
+
attention_outputs
[
1
:]
# add attentions if we output them
outputs
=
[
layer_output
]
+
attention_outputs
[
1
:]
# add attentions if we output them
return
outputs
return
outputs
...
@@ -571,8 +569,7 @@ class BertModel(BertPreTrainedModel):
...
@@ -571,8 +569,7 @@ class BertModel(BertPreTrainedModel):
Params:
Params:
`config`: a BertConfig class instance with the configuration to build a new model
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
`output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
This can be used to compute head importance metrics. Default: False
Inputs:
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
...
@@ -688,8 +685,7 @@ class BertForPreTraining(BertPreTrainedModel):
...
@@ -688,8 +685,7 @@ class BertForPreTraining(BertPreTrainedModel):
Params:
Params:
`config`: a BertConfig class instance with the configuration to build a new model
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
`output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
This can be used to compute head importance metrics. Default: False
Inputs:
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
...
@@ -770,8 +766,7 @@ class BertForMaskedLM(BertPreTrainedModel):
...
@@ -770,8 +766,7 @@ class BertForMaskedLM(BertPreTrainedModel):
Params:
Params:
`config`: a BertConfig class instance with the configuration to build a new model
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
`output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
This can be used to compute head importance metrics. Default: False
Inputs:
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
...
@@ -845,8 +840,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
...
@@ -845,8 +840,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
Params:
Params:
`config`: a BertConfig class instance with the configuration to build a new model
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
`output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
This can be used to compute head importance metrics. Default: False
Inputs:
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
...
@@ -919,8 +913,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
...
@@ -919,8 +913,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
Params:
Params:
`config`: a BertConfig class instance with the configuration to build a new model
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
`output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
This can be used to compute head importance metrics. Default: False
`num_labels`: the number of classes for the classifier. Default = 2.
`num_labels`: the number of classes for the classifier. Default = 2.
Inputs:
Inputs:
...
@@ -1003,8 +996,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
...
@@ -1003,8 +996,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
Params:
Params:
`config`: a BertConfig class instance with the configuration to build a new model
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
`output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
This can be used to compute head importance metrics. Default: False
`num_choices`: the number of classes for the classifier. Default = 2.
`num_choices`: the number of classes for the classifier. Default = 2.
Inputs:
Inputs:
...
@@ -1085,8 +1077,7 @@ class BertForTokenClassification(BertPreTrainedModel):
...
@@ -1085,8 +1077,7 @@ class BertForTokenClassification(BertPreTrainedModel):
Params:
Params:
`config`: a BertConfig class instance with the configuration to build a new model
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
`output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
This can be used to compute head importance metrics. Default: False
`num_labels`: the number of classes for the classifier. Default = 2.
`num_labels`: the number of classes for the classifier. Default = 2.
Inputs:
Inputs:
...
@@ -1170,8 +1161,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
...
@@ -1170,8 +1161,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
Params:
Params:
`config`: a BertConfig class instance with the configuration to build a new model
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
`output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
This can be used to compute head importance metrics. Default: False
Inputs:
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
...
...
pytorch_pretrained_bert/modeling_xlnet.py
View file @
d9184620
...
@@ -504,10 +504,10 @@ class XLNetRelativeAttention(nn.Module):
...
@@ -504,10 +504,10 @@ class XLNetRelativeAttention(nn.Module):
output_h
=
self
.
post_attention
(
h
,
attn_vec
)
output_h
=
self
.
post_attention
(
h
,
attn_vec
)
output_g
=
None
output_g
=
None
outputs
=
[
output_h
,
output_g
]
if
self
.
output_attentions
:
if
self
.
output_attentions
:
return
output_h
,
output_g
,
attn_prob
outputs
=
outputs
+
[
attn_prob
]
return
outputs
return
output_h
,
output_g
class
XLNetFeedForward
(
nn
.
Module
):
class
XLNetFeedForward
(
nn
.
Module
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
...
@@ -867,7 +867,7 @@ class XLNetModel(XLNetPreTrainedModel):
...
@@ -867,7 +867,7 @@ class XLNetModel(XLNetPreTrainedModel):
outputs
=
layer_module
(
output_h
,
output_g
,
attn_mask_h
=
non_tgt_mask
,
attn_mask_g
=
attn_mask
,
outputs
=
layer_module
(
output_h
,
output_g
,
attn_mask_h
=
non_tgt_mask
,
attn_mask_g
=
attn_mask
,
r
=
pos_emb
,
seg_mat
=
seg_mat
,
mems
=
mems
[
i
],
target_mapping
=
target_mapping
,
r
=
pos_emb
,
seg_mat
=
seg_mat
,
mems
=
mems
[
i
],
target_mapping
=
target_mapping
,
head_mask
=
head_mask
)
head_mask
=
head_mask
[
i
]
)
output_h
,
output_g
=
outputs
[:
2
]
output_h
,
output_g
=
outputs
[:
2
]
if
self
.
output_attentions
:
if
self
.
output_attentions
:
attentions
.
append
(
outputs
[
2
:])
attentions
.
append
(
outputs
[
2
:])
...
...
tests/modeling_test.py
View file @
d9184620
...
@@ -123,9 +123,13 @@ class BertModelTest(unittest.TestCase):
...
@@ -123,9 +123,13 @@ class BertModelTest(unittest.TestCase):
def
create_bert_model
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
def
create_bert_model
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
model
=
BertModel
(
config
=
config
)
model
=
BertModel
(
config
=
config
)
model
.
eval
()
model
.
eval
()
all_encoder_layers
,
pooled_output
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
sequence_output
,
pooled_output
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
model
=
BertModel
(
config
=
config
,
output_hidden_states
=
True
)
model
.
eval
()
_
,
_
,
all_encoder_layers
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
outputs
=
{
outputs
=
{
"sequence_output"
:
all_encoder_layers
[
-
1
]
,
"sequence_output"
:
sequence_output
,
"pooled_output"
:
pooled_output
,
"pooled_output"
:
pooled_output
,
"all_encoder_layers"
:
all_encoder_layers
,
"all_encoder_layers"
:
all_encoder_layers
,
}
}
...
@@ -134,7 +138,7 @@ class BertModelTest(unittest.TestCase):
...
@@ -134,7 +138,7 @@ class BertModelTest(unittest.TestCase):
def
check_bert_model_output
(
self
,
result
):
def
check_bert_model_output
(
self
,
result
):
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
[
size
for
layer
in
result
[
"all_encoder_layers"
]
for
size
in
layer
.
size
()],
[
size
for
layer
in
result
[
"all_encoder_layers"
]
for
size
in
layer
.
size
()],
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
*
self
.
num_hidden_layers
)
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
*
(
self
.
num_hidden_layers
+
1
)
)
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"sequence_output"
].
size
()),
list
(
result
[
"sequence_output"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
])
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
])
...
@@ -144,8 +148,7 @@ class BertModelTest(unittest.TestCase):
...
@@ -144,8 +148,7 @@ class BertModelTest(unittest.TestCase):
def
create_bert_for_masked_lm
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
def
create_bert_for_masked_lm
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
model
=
BertForMaskedLM
(
config
=
config
)
model
=
BertForMaskedLM
(
config
=
config
)
model
.
eval
()
model
.
eval
()
loss
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
token_labels
)
loss
,
prediction_scores
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
token_labels
)
prediction_scores
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
outputs
=
{
outputs
=
{
"loss"
:
loss
,
"loss"
:
loss
,
"prediction_scores"
:
prediction_scores
,
"prediction_scores"
:
prediction_scores
,
...
@@ -160,8 +163,7 @@ class BertModelTest(unittest.TestCase):
...
@@ -160,8 +163,7 @@ class BertModelTest(unittest.TestCase):
def
create_bert_for_next_sequence_prediction
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
def
create_bert_for_next_sequence_prediction
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
model
=
BertForNextSentencePrediction
(
config
=
config
)
model
=
BertForNextSentencePrediction
(
config
=
config
)
model
.
eval
()
model
.
eval
()
loss
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
)
loss
,
seq_relationship_score
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
)
seq_relationship_score
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
outputs
=
{
outputs
=
{
"loss"
:
loss
,
"loss"
:
loss
,
"seq_relationship_score"
:
seq_relationship_score
,
"seq_relationship_score"
:
seq_relationship_score
,
...
@@ -177,8 +179,7 @@ class BertModelTest(unittest.TestCase):
...
@@ -177,8 +179,7 @@ class BertModelTest(unittest.TestCase):
def
create_bert_for_pretraining
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
def
create_bert_for_pretraining
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
model
=
BertForPreTraining
(
config
=
config
)
model
=
BertForPreTraining
(
config
=
config
)
model
.
eval
()
model
.
eval
()
loss
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
token_labels
,
sequence_labels
)
loss
,
prediction_scores
,
seq_relationship_score
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
token_labels
,
sequence_labels
)
prediction_scores
,
seq_relationship_score
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
outputs
=
{
outputs
=
{
"loss"
:
loss
,
"loss"
:
loss
,
"prediction_scores"
:
prediction_scores
,
"prediction_scores"
:
prediction_scores
,
...
@@ -198,8 +199,7 @@ class BertModelTest(unittest.TestCase):
...
@@ -198,8 +199,7 @@ class BertModelTest(unittest.TestCase):
def
create_bert_for_question_answering
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
def
create_bert_for_question_answering
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
model
=
BertForQuestionAnswering
(
config
=
config
)
model
=
BertForQuestionAnswering
(
config
=
config
)
model
.
eval
()
model
.
eval
()
loss
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
sequence_labels
)
loss
,
start_logits
,
end_logits
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
sequence_labels
)
start_logits
,
end_logits
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
outputs
=
{
outputs
=
{
"loss"
:
loss
,
"loss"
:
loss
,
"start_logits"
:
start_logits
,
"start_logits"
:
start_logits
,
...
@@ -219,8 +219,7 @@ class BertModelTest(unittest.TestCase):
...
@@ -219,8 +219,7 @@ class BertModelTest(unittest.TestCase):
def
create_bert_for_sequence_classification
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
def
create_bert_for_sequence_classification
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
model
=
BertForSequenceClassification
(
config
=
config
,
num_labels
=
self
.
num_labels
)
model
=
BertForSequenceClassification
(
config
=
config
,
num_labels
=
self
.
num_labels
)
model
.
eval
()
model
.
eval
()
loss
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
)
loss
,
logits
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
)
logits
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
outputs
=
{
outputs
=
{
"loss"
:
loss
,
"loss"
:
loss
,
"logits"
:
logits
,
"logits"
:
logits
,
...
@@ -236,8 +235,7 @@ class BertModelTest(unittest.TestCase):
...
@@ -236,8 +235,7 @@ class BertModelTest(unittest.TestCase):
def
create_bert_for_token_classification
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
def
create_bert_for_token_classification
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
model
=
BertForTokenClassification
(
config
=
config
,
num_labels
=
self
.
num_labels
)
model
=
BertForTokenClassification
(
config
=
config
,
num_labels
=
self
.
num_labels
)
model
.
eval
()
model
.
eval
()
loss
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
token_labels
)
loss
,
logits
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
token_labels
)
logits
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
outputs
=
{
outputs
=
{
"loss"
:
loss
,
"loss"
:
loss
,
"logits"
:
logits
,
"logits"
:
logits
,
...
@@ -256,13 +254,10 @@ class BertModelTest(unittest.TestCase):
...
@@ -256,13 +254,10 @@ class BertModelTest(unittest.TestCase):
multiple_choice_inputs_ids
=
input_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_inputs_ids
=
input_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_token_type_ids
=
token_type_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_token_type_ids
=
token_type_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_input_mask
=
input_mask
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_input_mask
=
input_mask
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
loss
=
model
(
multiple_choice_inputs_ids
,
loss
,
logits
=
model
(
multiple_choice_inputs_ids
,
multiple_choice_token_type_ids
,
multiple_choice_token_type_ids
,
multiple_choice_input_mask
,
multiple_choice_input_mask
,
choice_labels
)
choice_labels
)
logits
=
model
(
multiple_choice_inputs_ids
,
multiple_choice_token_type_ids
,
multiple_choice_input_mask
)
outputs
=
{
outputs
=
{
"loss"
:
loss
,
"loss"
:
loss
,
"logits"
:
logits
,
"logits"
:
logits
,
...
@@ -285,8 +280,8 @@ class BertModelTest(unittest.TestCase):
...
@@ -285,8 +280,8 @@ class BertModelTest(unittest.TestCase):
else
:
else
:
model
=
model_class
(
config
=
config
,
output_attentions
=
True
)
model
=
model_class
(
config
=
config
,
output_attentions
=
True
)
model
.
eval
()
model
.
eval
()
output
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
output
s
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
attentions
=
output
[
0
]
attentions
=
output
s
[
-
1
]
self
.
parent
.
assertEqual
(
len
(
attentions
),
self
.
num_hidden_layers
)
self
.
parent
.
assertEqual
(
len
(
attentions
),
self
.
num_hidden_layers
)
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
attentions
[
0
].
size
()),
list
(
attentions
[
0
].
size
()),
...
@@ -300,57 +295,56 @@ class BertModelTest(unittest.TestCase):
...
@@ -300,57 +295,56 @@ class BertModelTest(unittest.TestCase):
if
model_class
in
[
BertForSequenceClassification
,
if
model_class
in
[
BertForSequenceClassification
,
BertForTokenClassification
]:
BertForTokenClassification
]:
model
=
model_class
(
config
=
config
,
model
=
model_class
(
config
=
config
,
num_labels
=
self
.
num_labels
,
num_labels
=
self
.
num_labels
)
keep_multihead_output
=
True
)
else
:
else
:
model
=
model_class
(
config
=
config
,
keep_multihead_output
=
True
)
model
=
model_class
(
config
=
config
)
model
.
eval
()
model
.
eval
()
head_mask
=
torch
.
ones
(
self
.
num_hidden_layers
,
self
.
num_attention_heads
).
to
(
input_ids
.
device
)
head_mask
=
torch
.
ones
(
self
.
num_hidden_layers
,
self
.
num_attention_heads
).
to
(
input_ids
.
device
)
head_mask
[
0
,
1
:
-
1
]
=
0.0
# Mask all but the first and last heads on the first layer
head_mask
[
0
,
1
:
-
1
]
=
0.0
# Mask all but the first and last heads on the first layer
head_mask
[
-
1
,
1
:]
=
0.0
# Mask all but the first head on the last layer
head_mask
[
-
1
,
1
:]
=
0.0
# Mask all but the first head on the last layer
output
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
head_mask
=
head_mask
)
# Set that after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior)
head_mask
.
requires_grad_
(
requires_grad
=
True
)
outputs
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
head_mask
=
head_mask
)
if
isinstance
(
model
,
BertModel
):
# Compute some gradients
output
=
sum
(
t
.
sum
()
for
t
in
output
[
0
])
output
=
sum
(
t
.
sum
()
for
t
in
outputs
[
0
])
elif
isinstance
(
output
,
(
list
,
tuple
)):
output
=
sum
(
t
.
sum
()
for
t
in
output
)
output
=
output
.
sum
()
output
=
output
.
sum
()
output
.
backward
()
output
.
backward
()
multihead_outputs
=
(
model
if
isinstance
(
model
,
BertModel
)
else
model
.
bert
).
get_multihead_outputs
()
multihead_outputs
=
head_mask
.
grad
self
.
parent
.
assertEqual
(
len
(
multihead_outputs
),
self
.
num_hidden_layers
)
self
.
parent
.
assertEqual
(
len
(
multihead_outputs
),
self
.
num_hidden_layers
)
self
.
parent
.
assertListEqual
(
#
self.parent.assertListEqual(
list
(
multihead_outputs
[
0
].
size
()),
#
list(multihead_outputs[0].size()),
[
self
.
batch_size
,
self
.
num_attention_heads
,
#
[self.batch_size, self.num_attention_heads,
self
.
seq_length
,
self
.
hidden_size
//
self
.
num_attention_heads
])
#
self.seq_length, self.hidden_size // self.num_attention_heads])
self
.
parent
.
assertEqual
(
#
self.parent.assertEqual(
len
(
multihead_outputs
[
0
][:,
1
:(
self
.
num_attention_heads
-
1
),
:,
:].
nonzero
()),
#
len(multihead_outputs[0][:, 1:(self.num_attention_heads-1), :, :].nonzero()),
0
)
#
0)
self
.
parent
.
assertEqual
(
#
self.parent.assertEqual(
len
(
multihead_outputs
[
0
][:,
0
,
:,
:].
nonzero
()),
#
len(multihead_outputs[0][:, 0, :, :].nonzero()),
self
.
batch_size
*
self
.
seq_length
*
self
.
hidden_size
//
self
.
num_attention_heads
)
#
self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
self
.
parent
.
assertEqual
(
#
self.parent.assertEqual(
len
(
multihead_outputs
[
0
][:,
self
.
num_attention_heads
-
1
,
:,
:].
nonzero
()),
#
len(multihead_outputs[0][:, self.num_attention_heads-1, :, :].nonzero()),
self
.
batch_size
*
self
.
seq_length
*
self
.
hidden_size
//
self
.
num_attention_heads
)
#
self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
self
.
parent
.
assertListEqual
(
#
self.parent.assertListEqual(
list
(
multihead_outputs
[
1
].
size
()),
#
list(multihead_outputs[1].size()),
[
self
.
batch_size
,
self
.
num_attention_heads
,
#
[self.batch_size, self.num_attention_heads,
self
.
seq_length
,
self
.
hidden_size
//
self
.
num_attention_heads
])
#
self.seq_length, self.hidden_size // self.num_attention_heads])
self
.
parent
.
assertEqual
(
#
self.parent.assertEqual(
len
(
multihead_outputs
[
1
].
nonzero
()),
#
len(multihead_outputs[1].nonzero()),
multihead_outputs
[
1
].
numel
())
#
multihead_outputs[1].numel())
self
.
parent
.
assertListEqual
(
#
self.parent.assertListEqual(
list
(
multihead_outputs
[
-
1
].
size
()),
#
list(multihead_outputs[-1].size()),
[
self
.
batch_size
,
self
.
num_attention_heads
,
#
[self.batch_size, self.num_attention_heads,
self
.
seq_length
,
self
.
hidden_size
//
self
.
num_attention_heads
])
#
self.seq_length, self.hidden_size // self.num_attention_heads])
self
.
parent
.
assertEqual
(
#
self.parent.assertEqual(
len
(
multihead_outputs
[
-
1
][:,
1
:,
:,
:].
nonzero
()),
#
len(multihead_outputs[-1][:, 1:, :, :].nonzero()),
0
)
#
0)
self
.
parent
.
assertEqual
(
#
self.parent.assertEqual(
len
(
multihead_outputs
[
-
1
][:,
0
,
:,
:].
nonzero
()),
#
len(multihead_outputs[-1][:, 0, :, :].nonzero()),
self
.
batch_size
*
self
.
seq_length
*
self
.
hidden_size
//
self
.
num_attention_heads
)
#
self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
def
create_and_check_bert_for_head_pruning
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
def
create_and_check_bert_for_head_pruning
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
...
@@ -360,38 +354,34 @@ class BertModelTest(unittest.TestCase):
...
@@ -360,38 +354,34 @@ class BertModelTest(unittest.TestCase):
if
model_class
in
[
BertForSequenceClassification
,
if
model_class
in
[
BertForSequenceClassification
,
BertForTokenClassification
]:
BertForTokenClassification
]:
model
=
model_class
(
config
=
config
,
model
=
model_class
(
config
=
config
,
num_labels
=
self
.
num_labels
,
num_labels
=
self
.
num_labels
)
keep_multihead_output
=
True
)
else
:
else
:
model
=
model_class
(
config
=
config
,
keep_multihead_output
=
True
)
model
=
model_class
(
config
=
config
)
model
.
eval
()
model
.
eval
()
bert_model
=
model
if
isinstance
(
model
,
BertModel
)
else
model
.
bert
bert_model
=
model
if
isinstance
(
model
,
BertModel
)
else
model
.
bert
heads_to_prune
=
{
0
:
list
(
range
(
1
,
self
.
num_attention_heads
)),
heads_to_prune
=
{
0
:
list
(
range
(
1
,
self
.
num_attention_heads
)),
-
1
:
[
0
]}
-
1
:
[
0
]}
bert_model
.
prune_heads
(
heads_to_prune
)
bert_model
.
prune_heads
(
heads_to_prune
)
output
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
outputs
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
if
isinstance
(
model
,
BertModel
):
# output = sum(t.sum() for t in outputs[0])
output
=
sum
(
t
.
sum
()
for
t
in
output
[
0
])
# output = output.sum()
elif
isinstance
(
output
,
(
list
,
tuple
)):
# output.backward()
output
=
sum
(
t
.
sum
()
for
t
in
output
)
# multihead_outputs = bert_model.get_multihead_outputs()
output
=
output
.
sum
()
output
.
backward
()
# self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
multihead_outputs
=
bert_model
.
get_multihead_outputs
()
# self.parent.assertListEqual(
# list(multihead_outputs[0].size()),
self
.
parent
.
assertEqual
(
len
(
multihead_outputs
),
self
.
num_hidden_layers
)
# [self.batch_size, 1,
self
.
parent
.
assertListEqual
(
# self.seq_length, self.hidden_size // self.num_attention_heads])
list
(
multihead_outputs
[
0
].
size
()),
# self.parent.assertListEqual(
[
self
.
batch_size
,
1
,
# list(multihead_outputs[1].size()),
self
.
seq_length
,
self
.
hidden_size
//
self
.
num_attention_heads
])
# [self.batch_size, self.num_attention_heads,
self
.
parent
.
assertListEqual
(
# self.seq_length, self.hidden_size // self.num_attention_heads])
list
(
multihead_outputs
[
1
].
size
()),
# self.parent.assertListEqual(
[
self
.
batch_size
,
self
.
num_attention_heads
,
# list(multihead_outputs[-1].size()),
self
.
seq_length
,
self
.
hidden_size
//
self
.
num_attention_heads
])
# [self.batch_size, self.num_attention_heads-1,
self
.
parent
.
assertListEqual
(
# self.seq_length, self.hidden_size // self.num_attention_heads])
list
(
multihead_outputs
[
-
1
].
size
()),
[
self
.
batch_size
,
self
.
num_attention_heads
-
1
,
self
.
seq_length
,
self
.
hidden_size
//
self
.
num_attention_heads
])
def
test_default
(
self
):
def
test_default
(
self
):
...
...
tests/modeling_xlnet_test.py
View file @
d9184620
...
@@ -134,26 +134,19 @@ class XLNetModelTest(unittest.TestCase):
...
@@ -134,26 +134,19 @@ class XLNetModelTest(unittest.TestCase):
model
=
XLNetLMHeadModel
(
config
)
model
=
XLNetLMHeadModel
(
config
)
model
.
eval
()
model
.
eval
()
loss_1
,
mems_1a
=
model
(
input_ids_1
,
token_type_ids
=
segment_ids
,
labels
=
lm_labels
)
loss_1
,
all_logits_1
,
mems_1
=
model
(
input_ids_1
,
token_type_ids
=
segment_ids
,
labels
=
lm_labels
)
all_logits_1
,
mems_1b
=
model
(
input_ids_1
,
token_type_ids
=
segment_ids
)
loss_2
,
mems_2a
=
model
(
input_ids_2
,
token_type_ids
=
segment_ids
,
labels
=
lm_labels
,
mems
=
mems_1a
)
loss_2
,
all_logits_2
,
mems_2
=
model
(
input_ids_2
,
token_type_ids
=
segment_ids
,
labels
=
lm_labels
,
mems
=
mems_1
)
all_logits_2
,
mems_2b
=
model
(
input_ids_2
,
token_type_ids
=
segment_ids
,
mems
=
mems_1b
)
logits
,
_
=
model
(
input_ids_q
,
logits
,
_
=
model
(
input_ids_q
,
perm_mask
=
perm_mask
,
target_mapping
=
target_mapping
,
inp_q
=
inp_q
)
perm_mask
=
perm_mask
,
target_mapping
=
target_mapping
,
inp_q
=
inp_q
)
outputs
=
{
outputs
=
{
"loss_1"
:
loss_1
,
"loss_1"
:
loss_1
,
"mems_1
a
"
:
mems_1
a
,
"mems_1"
:
mems_1
,
"all_logits_1"
:
all_logits_1
,
"all_logits_1"
:
all_logits_1
,
"mems_1b"
:
mems_1b
,
"loss_2"
:
loss_2
,
"loss_2"
:
loss_2
,
"mems_2
a
"
:
mems_2
a
,
"mems_2"
:
mems_2
,
"all_logits_2"
:
all_logits_2
,
"all_logits_2"
:
all_logits_2
,
"mems_2b"
:
mems_2b
,
}
}
return
outputs
return
outputs
...
@@ -165,14 +158,8 @@ class XLNetModelTest(unittest.TestCase):
...
@@ -165,14 +158,8 @@ class XLNetModelTest(unittest.TestCase):
list
(
result
[
"all_logits_1"
].
size
()),
list
(
result
[
"all_logits_1"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_1
a
"
]),
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_1"
]),
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
d_model
]]
*
self
.
n_layer
)
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
d_model
]]
*
self
.
n_layer
)
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_1b"
]),
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
d_model
]]
*
self
.
n_layer
)
self
.
parent
.
assertListEqual
(
list
(
mem
[
~
torch
.
isnan
(
mem
)].
sum
()
for
mem
in
result
[
"mems_1a"
]),
list
(
mem
[
~
torch
.
isnan
(
mem
)].
sum
()
for
mem
in
result
[
"mems_1b"
]))
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss_2"
].
size
()),
list
(
result
[
"loss_2"
].
size
()),
...
@@ -181,14 +168,8 @@ class XLNetModelTest(unittest.TestCase):
...
@@ -181,14 +168,8 @@ class XLNetModelTest(unittest.TestCase):
list
(
result
[
"all_logits_2"
].
size
()),
list
(
result
[
"all_logits_2"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_2
a
"
]),
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_2"
]),
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
d_model
]]
*
self
.
n_layer
)
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
d_model
]]
*
self
.
n_layer
)
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_2b"
]),
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
d_model
]]
*
self
.
n_layer
)
self
.
parent
.
assertListEqual
(
list
(
mem
[
~
torch
.
isnan
(
mem
)].
sum
()
for
mem
in
result
[
"mems_2a"
]),
list
(
mem
[
~
torch
.
isnan
(
mem
)].
sum
()
for
mem
in
result
[
"mems_2b"
]))
def
test_default
(
self
):
def
test_default
(
self
):
self
.
run_tester
(
XLNetModelTest
.
XLNetModelTester
(
self
))
self
.
run_tester
(
XLNetModelTest
.
XLNetModelTester
(
self
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment