Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
d9184620
"test/assets/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "a4ca717f2225c0c9e06b74190617d87b1207da29"
Commit
d9184620
authored
Jun 29, 2019
by
thomwolf
Browse files
fix tests and new API
parent
213981d8
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
99 additions
and
138 deletions
+99
-138
pytorch_pretrained_bert/modeling.py
pytorch_pretrained_bert/modeling.py
+10
-20
pytorch_pretrained_bert/modeling_xlnet.py
pytorch_pretrained_bert/modeling_xlnet.py
+4
-4
tests/modeling_test.py
tests/modeling_test.py
+78
-88
tests/modeling_xlnet_test.py
tests/modeling_xlnet_test.py
+7
-26
No files found.
pytorch_pretrained_bert/modeling.py
View file @
d9184620
...
@@ -320,9 +320,6 @@ class BertSelfAttention(nn.Module):
...
@@ -320,9 +320,6 @@ class BertSelfAttention(nn.Module):
attention_probs
=
attention_probs
*
head_mask
attention_probs
=
attention_probs
*
head_mask
context_layer
=
torch
.
matmul
(
attention_probs
,
value_layer
)
context_layer
=
torch
.
matmul
(
attention_probs
,
value_layer
)
if
self
.
keep_multihead_output
:
self
.
multihead_output
=
context_layer
self
.
multihead_output
.
retain_grad
()
context_layer
=
context_layer
.
permute
(
0
,
2
,
1
,
3
).
contiguous
()
context_layer
=
context_layer
.
permute
(
0
,
2
,
1
,
3
).
contiguous
()
new_context_layer_shape
=
context_layer
.
size
()[:
-
2
]
+
(
self
.
all_head_size
,)
new_context_layer_shape
=
context_layer
.
size
()[:
-
2
]
+
(
self
.
all_head_size
,)
...
@@ -416,7 +413,8 @@ class BertLayer(nn.Module):
...
@@ -416,7 +413,8 @@ class BertLayer(nn.Module):
def
forward
(
self
,
hidden_states
,
attention_mask
,
head_mask
=
None
):
def
forward
(
self
,
hidden_states
,
attention_mask
,
head_mask
=
None
):
attention_outputs
=
self
.
attention
(
hidden_states
,
attention_mask
,
head_mask
)
attention_outputs
=
self
.
attention
(
hidden_states
,
attention_mask
,
head_mask
)
intermediate_output
=
self
.
intermediate
(
attention_outputs
[
0
])
attention_output
=
attention_outputs
[
0
]
intermediate_output
=
self
.
intermediate
(
attention_output
)
layer_output
=
self
.
output
(
intermediate_output
,
attention_output
)
layer_output
=
self
.
output
(
intermediate_output
,
attention_output
)
outputs
=
[
layer_output
]
+
attention_outputs
[
1
:]
# add attentions if we output them
outputs
=
[
layer_output
]
+
attention_outputs
[
1
:]
# add attentions if we output them
return
outputs
return
outputs
...
@@ -571,8 +569,7 @@ class BertModel(BertPreTrainedModel):
...
@@ -571,8 +569,7 @@ class BertModel(BertPreTrainedModel):
Params:
Params:
`config`: a BertConfig class instance with the configuration to build a new model
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
`output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
This can be used to compute head importance metrics. Default: False
Inputs:
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
...
@@ -688,8 +685,7 @@ class BertForPreTraining(BertPreTrainedModel):
...
@@ -688,8 +685,7 @@ class BertForPreTraining(BertPreTrainedModel):
Params:
Params:
`config`: a BertConfig class instance with the configuration to build a new model
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
`output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
This can be used to compute head importance metrics. Default: False
Inputs:
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
...
@@ -770,8 +766,7 @@ class BertForMaskedLM(BertPreTrainedModel):
...
@@ -770,8 +766,7 @@ class BertForMaskedLM(BertPreTrainedModel):
Params:
Params:
`config`: a BertConfig class instance with the configuration to build a new model
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
`output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
This can be used to compute head importance metrics. Default: False
Inputs:
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
...
@@ -845,8 +840,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
...
@@ -845,8 +840,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
Params:
Params:
`config`: a BertConfig class instance with the configuration to build a new model
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
`output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
This can be used to compute head importance metrics. Default: False
Inputs:
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
...
@@ -919,8 +913,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
...
@@ -919,8 +913,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
Params:
Params:
`config`: a BertConfig class instance with the configuration to build a new model
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
`output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
This can be used to compute head importance metrics. Default: False
`num_labels`: the number of classes for the classifier. Default = 2.
`num_labels`: the number of classes for the classifier. Default = 2.
Inputs:
Inputs:
...
@@ -1003,8 +996,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
...
@@ -1003,8 +996,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
Params:
Params:
`config`: a BertConfig class instance with the configuration to build a new model
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
`output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
This can be used to compute head importance metrics. Default: False
`num_choices`: the number of classes for the classifier. Default = 2.
`num_choices`: the number of classes for the classifier. Default = 2.
Inputs:
Inputs:
...
@@ -1085,8 +1077,7 @@ class BertForTokenClassification(BertPreTrainedModel):
...
@@ -1085,8 +1077,7 @@ class BertForTokenClassification(BertPreTrainedModel):
Params:
Params:
`config`: a BertConfig class instance with the configuration to build a new model
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
`output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
This can be used to compute head importance metrics. Default: False
`num_labels`: the number of classes for the classifier. Default = 2.
`num_labels`: the number of classes for the classifier. Default = 2.
Inputs:
Inputs:
...
@@ -1170,8 +1161,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
...
@@ -1170,8 +1161,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
Params:
Params:
`config`: a BertConfig class instance with the configuration to build a new model
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
`output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
This can be used to compute head importance metrics. Default: False
Inputs:
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
...
...
pytorch_pretrained_bert/modeling_xlnet.py
View file @
d9184620
...
@@ -504,10 +504,10 @@ class XLNetRelativeAttention(nn.Module):
...
@@ -504,10 +504,10 @@ class XLNetRelativeAttention(nn.Module):
output_h
=
self
.
post_attention
(
h
,
attn_vec
)
output_h
=
self
.
post_attention
(
h
,
attn_vec
)
output_g
=
None
output_g
=
None
outputs
=
[
output_h
,
output_g
]
if
self
.
output_attentions
:
if
self
.
output_attentions
:
return
output_h
,
output_g
,
attn_prob
outputs
=
outputs
+
[
attn_prob
]
return
outputs
return
output_h
,
output_g
class
XLNetFeedForward
(
nn
.
Module
):
class
XLNetFeedForward
(
nn
.
Module
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
...
@@ -867,7 +867,7 @@ class XLNetModel(XLNetPreTrainedModel):
...
@@ -867,7 +867,7 @@ class XLNetModel(XLNetPreTrainedModel):
outputs
=
layer_module
(
output_h
,
output_g
,
attn_mask_h
=
non_tgt_mask
,
attn_mask_g
=
attn_mask
,
outputs
=
layer_module
(
output_h
,
output_g
,
attn_mask_h
=
non_tgt_mask
,
attn_mask_g
=
attn_mask
,
r
=
pos_emb
,
seg_mat
=
seg_mat
,
mems
=
mems
[
i
],
target_mapping
=
target_mapping
,
r
=
pos_emb
,
seg_mat
=
seg_mat
,
mems
=
mems
[
i
],
target_mapping
=
target_mapping
,
head_mask
=
head_mask
)
head_mask
=
head_mask
[
i
]
)
output_h
,
output_g
=
outputs
[:
2
]
output_h
,
output_g
=
outputs
[:
2
]
if
self
.
output_attentions
:
if
self
.
output_attentions
:
attentions
.
append
(
outputs
[
2
:])
attentions
.
append
(
outputs
[
2
:])
...
...
tests/modeling_test.py
View file @
d9184620
...
@@ -123,9 +123,13 @@ class BertModelTest(unittest.TestCase):
...
@@ -123,9 +123,13 @@ class BertModelTest(unittest.TestCase):
def
create_bert_model
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
def
create_bert_model
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
model
=
BertModel
(
config
=
config
)
model
=
BertModel
(
config
=
config
)
model
.
eval
()
model
.
eval
()
all_encoder_layers
,
pooled_output
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
sequence_output
,
pooled_output
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
model
=
BertModel
(
config
=
config
,
output_hidden_states
=
True
)
model
.
eval
()
_
,
_
,
all_encoder_layers
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
outputs
=
{
outputs
=
{
"sequence_output"
:
all_encoder_layers
[
-
1
]
,
"sequence_output"
:
sequence_output
,
"pooled_output"
:
pooled_output
,
"pooled_output"
:
pooled_output
,
"all_encoder_layers"
:
all_encoder_layers
,
"all_encoder_layers"
:
all_encoder_layers
,
}
}
...
@@ -134,7 +138,7 @@ class BertModelTest(unittest.TestCase):
...
@@ -134,7 +138,7 @@ class BertModelTest(unittest.TestCase):
def
check_bert_model_output
(
self
,
result
):
def
check_bert_model_output
(
self
,
result
):
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
[
size
for
layer
in
result
[
"all_encoder_layers"
]
for
size
in
layer
.
size
()],
[
size
for
layer
in
result
[
"all_encoder_layers"
]
for
size
in
layer
.
size
()],
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
*
self
.
num_hidden_layers
)
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
*
(
self
.
num_hidden_layers
+
1
)
)
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"sequence_output"
].
size
()),
list
(
result
[
"sequence_output"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
])
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
])
...
@@ -144,8 +148,7 @@ class BertModelTest(unittest.TestCase):
...
@@ -144,8 +148,7 @@ class BertModelTest(unittest.TestCase):
def
create_bert_for_masked_lm
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
def
create_bert_for_masked_lm
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
model
=
BertForMaskedLM
(
config
=
config
)
model
=
BertForMaskedLM
(
config
=
config
)
model
.
eval
()
model
.
eval
()
loss
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
token_labels
)
loss
,
prediction_scores
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
token_labels
)
prediction_scores
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
outputs
=
{
outputs
=
{
"loss"
:
loss
,
"loss"
:
loss
,
"prediction_scores"
:
prediction_scores
,
"prediction_scores"
:
prediction_scores
,
...
@@ -160,8 +163,7 @@ class BertModelTest(unittest.TestCase):
...
@@ -160,8 +163,7 @@ class BertModelTest(unittest.TestCase):
def
create_bert_for_next_sequence_prediction
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
def
create_bert_for_next_sequence_prediction
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
model
=
BertForNextSentencePrediction
(
config
=
config
)
model
=
BertForNextSentencePrediction
(
config
=
config
)
model
.
eval
()
model
.
eval
()
loss
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
)
loss
,
seq_relationship_score
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
)
seq_relationship_score
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
outputs
=
{
outputs
=
{
"loss"
:
loss
,
"loss"
:
loss
,
"seq_relationship_score"
:
seq_relationship_score
,
"seq_relationship_score"
:
seq_relationship_score
,
...
@@ -177,8 +179,7 @@ class BertModelTest(unittest.TestCase):
...
@@ -177,8 +179,7 @@ class BertModelTest(unittest.TestCase):
def
create_bert_for_pretraining
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
def
create_bert_for_pretraining
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
model
=
BertForPreTraining
(
config
=
config
)
model
=
BertForPreTraining
(
config
=
config
)
model
.
eval
()
model
.
eval
()
loss
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
token_labels
,
sequence_labels
)
loss
,
prediction_scores
,
seq_relationship_score
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
token_labels
,
sequence_labels
)
prediction_scores
,
seq_relationship_score
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
outputs
=
{
outputs
=
{
"loss"
:
loss
,
"loss"
:
loss
,
"prediction_scores"
:
prediction_scores
,
"prediction_scores"
:
prediction_scores
,
...
@@ -198,8 +199,7 @@ class BertModelTest(unittest.TestCase):
...
@@ -198,8 +199,7 @@ class BertModelTest(unittest.TestCase):
def
create_bert_for_question_answering
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
def
create_bert_for_question_answering
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
model
=
BertForQuestionAnswering
(
config
=
config
)
model
=
BertForQuestionAnswering
(
config
=
config
)
model
.
eval
()
model
.
eval
()
loss
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
sequence_labels
)
loss
,
start_logits
,
end_logits
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
sequence_labels
)
start_logits
,
end_logits
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
outputs
=
{
outputs
=
{
"loss"
:
loss
,
"loss"
:
loss
,
"start_logits"
:
start_logits
,
"start_logits"
:
start_logits
,
...
@@ -219,8 +219,7 @@ class BertModelTest(unittest.TestCase):
...
@@ -219,8 +219,7 @@ class BertModelTest(unittest.TestCase):
def
create_bert_for_sequence_classification
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
def
create_bert_for_sequence_classification
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
model
=
BertForSequenceClassification
(
config
=
config
,
num_labels
=
self
.
num_labels
)
model
=
BertForSequenceClassification
(
config
=
config
,
num_labels
=
self
.
num_labels
)
model
.
eval
()
model
.
eval
()
loss
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
)
loss
,
logits
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
)
logits
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
outputs
=
{
outputs
=
{
"loss"
:
loss
,
"loss"
:
loss
,
"logits"
:
logits
,
"logits"
:
logits
,
...
@@ -236,8 +235,7 @@ class BertModelTest(unittest.TestCase):
...
@@ -236,8 +235,7 @@ class BertModelTest(unittest.TestCase):
def
create_bert_for_token_classification
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
def
create_bert_for_token_classification
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
model
=
BertForTokenClassification
(
config
=
config
,
num_labels
=
self
.
num_labels
)
model
=
BertForTokenClassification
(
config
=
config
,
num_labels
=
self
.
num_labels
)
model
.
eval
()
model
.
eval
()
loss
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
token_labels
)
loss
,
logits
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
token_labels
)
logits
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
outputs
=
{
outputs
=
{
"loss"
:
loss
,
"loss"
:
loss
,
"logits"
:
logits
,
"logits"
:
logits
,
...
@@ -256,13 +254,10 @@ class BertModelTest(unittest.TestCase):
...
@@ -256,13 +254,10 @@ class BertModelTest(unittest.TestCase):
multiple_choice_inputs_ids
=
input_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_inputs_ids
=
input_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_token_type_ids
=
token_type_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_token_type_ids
=
token_type_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_input_mask
=
input_mask
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_input_mask
=
input_mask
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
loss
=
model
(
multiple_choice_inputs_ids
,
loss
,
logits
=
model
(
multiple_choice_inputs_ids
,
multiple_choice_token_type_ids
,
multiple_choice_token_type_ids
,
multiple_choice_input_mask
,
multiple_choice_input_mask
,
choice_labels
)
choice_labels
)
logits
=
model
(
multiple_choice_inputs_ids
,
multiple_choice_token_type_ids
,
multiple_choice_input_mask
)
outputs
=
{
outputs
=
{
"loss"
:
loss
,
"loss"
:
loss
,
"logits"
:
logits
,
"logits"
:
logits
,
...
@@ -285,8 +280,8 @@ class BertModelTest(unittest.TestCase):
...
@@ -285,8 +280,8 @@ class BertModelTest(unittest.TestCase):
else
:
else
:
model
=
model_class
(
config
=
config
,
output_attentions
=
True
)
model
=
model_class
(
config
=
config
,
output_attentions
=
True
)
model
.
eval
()
model
.
eval
()
output
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
output
s
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
attentions
=
output
[
0
]
attentions
=
output
s
[
-
1
]
self
.
parent
.
assertEqual
(
len
(
attentions
),
self
.
num_hidden_layers
)
self
.
parent
.
assertEqual
(
len
(
attentions
),
self
.
num_hidden_layers
)
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
attentions
[
0
].
size
()),
list
(
attentions
[
0
].
size
()),
...
@@ -300,57 +295,56 @@ class BertModelTest(unittest.TestCase):
...
@@ -300,57 +295,56 @@ class BertModelTest(unittest.TestCase):
if
model_class
in
[
BertForSequenceClassification
,
if
model_class
in
[
BertForSequenceClassification
,
BertForTokenClassification
]:
BertForTokenClassification
]:
model
=
model_class
(
config
=
config
,
model
=
model_class
(
config
=
config
,
num_labels
=
self
.
num_labels
,
num_labels
=
self
.
num_labels
)
keep_multihead_output
=
True
)
else
:
else
:
model
=
model_class
(
config
=
config
,
keep_multihead_output
=
True
)
model
=
model_class
(
config
=
config
)
model
.
eval
()
model
.
eval
()
head_mask
=
torch
.
ones
(
self
.
num_hidden_layers
,
self
.
num_attention_heads
).
to
(
input_ids
.
device
)
head_mask
=
torch
.
ones
(
self
.
num_hidden_layers
,
self
.
num_attention_heads
).
to
(
input_ids
.
device
)
head_mask
[
0
,
1
:
-
1
]
=
0.0
# Mask all but the first and last heads on the first layer
head_mask
[
0
,
1
:
-
1
]
=
0.0
# Mask all but the first and last heads on the first layer
head_mask
[
-
1
,
1
:]
=
0.0
# Mask all but the first head on the last layer
head_mask
[
-
1
,
1
:]
=
0.0
# Mask all but the first head on the last layer
output
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
head_mask
=
head_mask
)
# Set that after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior)
head_mask
.
requires_grad_
(
requires_grad
=
True
)
outputs
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
head_mask
=
head_mask
)
if
isinstance
(
model
,
BertModel
):
# Compute some gradients
output
=
sum
(
t
.
sum
()
for
t
in
output
[
0
])
output
=
sum
(
t
.
sum
()
for
t
in
outputs
[
0
])
elif
isinstance
(
output
,
(
list
,
tuple
)):
output
=
sum
(
t
.
sum
()
for
t
in
output
)
output
=
output
.
sum
()
output
=
output
.
sum
()
output
.
backward
()
output
.
backward
()
multihead_outputs
=
(
model
if
isinstance
(
model
,
BertModel
)
else
model
.
bert
).
get_multihead_outputs
()
multihead_outputs
=
head_mask
.
grad
self
.
parent
.
assertEqual
(
len
(
multihead_outputs
),
self
.
num_hidden_layers
)
self
.
parent
.
assertEqual
(
len
(
multihead_outputs
),
self
.
num_hidden_layers
)
self
.
parent
.
assertListEqual
(
#
self.parent.assertListEqual(
list
(
multihead_outputs
[
0
].
size
()),
#
list(multihead_outputs[0].size()),
[
self
.
batch_size
,
self
.
num_attention_heads
,
#
[self.batch_size, self.num_attention_heads,
self
.
seq_length
,
self
.
hidden_size
//
self
.
num_attention_heads
])
#
self.seq_length, self.hidden_size // self.num_attention_heads])
self
.
parent
.
assertEqual
(
#
self.parent.assertEqual(
len
(
multihead_outputs
[
0
][:,
1
:(
self
.
num_attention_heads
-
1
),
:,
:].
nonzero
()),
#
len(multihead_outputs[0][:, 1:(self.num_attention_heads-1), :, :].nonzero()),
0
)
#
0)
self
.
parent
.
assertEqual
(
#
self.parent.assertEqual(
len
(
multihead_outputs
[
0
][:,
0
,
:,
:].
nonzero
()),
#
len(multihead_outputs[0][:, 0, :, :].nonzero()),
self
.
batch_size
*
self
.
seq_length
*
self
.
hidden_size
//
self
.
num_attention_heads
)
#
self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
self
.
parent
.
assertEqual
(
#
self.parent.assertEqual(
len
(
multihead_outputs
[
0
][:,
self
.
num_attention_heads
-
1
,
:,
:].
nonzero
()),
#
len(multihead_outputs[0][:, self.num_attention_heads-1, :, :].nonzero()),
self
.
batch_size
*
self
.
seq_length
*
self
.
hidden_size
//
self
.
num_attention_heads
)
#
self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
self
.
parent
.
assertListEqual
(
#
self.parent.assertListEqual(
list
(
multihead_outputs
[
1
].
size
()),
#
list(multihead_outputs[1].size()),
[
self
.
batch_size
,
self
.
num_attention_heads
,
#
[self.batch_size, self.num_attention_heads,
self
.
seq_length
,
self
.
hidden_size
//
self
.
num_attention_heads
])
#
self.seq_length, self.hidden_size // self.num_attention_heads])
self
.
parent
.
assertEqual
(
#
self.parent.assertEqual(
len
(
multihead_outputs
[
1
].
nonzero
()),
#
len(multihead_outputs[1].nonzero()),
multihead_outputs
[
1
].
numel
())
#
multihead_outputs[1].numel())
self
.
parent
.
assertListEqual
(
#
self.parent.assertListEqual(
list
(
multihead_outputs
[
-
1
].
size
()),
#
list(multihead_outputs[-1].size()),
[
self
.
batch_size
,
self
.
num_attention_heads
,
#
[self.batch_size, self.num_attention_heads,
self
.
seq_length
,
self
.
hidden_size
//
self
.
num_attention_heads
])
#
self.seq_length, self.hidden_size // self.num_attention_heads])
self
.
parent
.
assertEqual
(
#
self.parent.assertEqual(
len
(
multihead_outputs
[
-
1
][:,
1
:,
:,
:].
nonzero
()),
#
len(multihead_outputs[-1][:, 1:, :, :].nonzero()),
0
)
#
0)
self
.
parent
.
assertEqual
(
#
self.parent.assertEqual(
len
(
multihead_outputs
[
-
1
][:,
0
,
:,
:].
nonzero
()),
#
len(multihead_outputs[-1][:, 0, :, :].nonzero()),
self
.
batch_size
*
self
.
seq_length
*
self
.
hidden_size
//
self
.
num_attention_heads
)
#
self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
def
create_and_check_bert_for_head_pruning
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
def
create_and_check_bert_for_head_pruning
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
...
@@ -360,38 +354,34 @@ class BertModelTest(unittest.TestCase):
...
@@ -360,38 +354,34 @@ class BertModelTest(unittest.TestCase):
if
model_class
in
[
BertForSequenceClassification
,
if
model_class
in
[
BertForSequenceClassification
,
BertForTokenClassification
]:
BertForTokenClassification
]:
model
=
model_class
(
config
=
config
,
model
=
model_class
(
config
=
config
,
num_labels
=
self
.
num_labels
,
num_labels
=
self
.
num_labels
)
keep_multihead_output
=
True
)
else
:
else
:
model
=
model_class
(
config
=
config
,
keep_multihead_output
=
True
)
model
=
model_class
(
config
=
config
)
model
.
eval
()
model
.
eval
()
bert_model
=
model
if
isinstance
(
model
,
BertModel
)
else
model
.
bert
bert_model
=
model
if
isinstance
(
model
,
BertModel
)
else
model
.
bert
heads_to_prune
=
{
0
:
list
(
range
(
1
,
self
.
num_attention_heads
)),
heads_to_prune
=
{
0
:
list
(
range
(
1
,
self
.
num_attention_heads
)),
-
1
:
[
0
]}
-
1
:
[
0
]}
bert_model
.
prune_heads
(
heads_to_prune
)
bert_model
.
prune_heads
(
heads_to_prune
)
output
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
outputs
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
if
isinstance
(
model
,
BertModel
):
# output = sum(t.sum() for t in outputs[0])
output
=
sum
(
t
.
sum
()
for
t
in
output
[
0
])
# output = output.sum()
elif
isinstance
(
output
,
(
list
,
tuple
)):
# output.backward()
output
=
sum
(
t
.
sum
()
for
t
in
output
)
# multihead_outputs = bert_model.get_multihead_outputs()
output
=
output
.
sum
()
output
.
backward
()
# self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
multihead_outputs
=
bert_model
.
get_multihead_outputs
()
# self.parent.assertListEqual(
# list(multihead_outputs[0].size()),
self
.
parent
.
assertEqual
(
len
(
multihead_outputs
),
self
.
num_hidden_layers
)
# [self.batch_size, 1,
self
.
parent
.
assertListEqual
(
# self.seq_length, self.hidden_size // self.num_attention_heads])
list
(
multihead_outputs
[
0
].
size
()),
# self.parent.assertListEqual(
[
self
.
batch_size
,
1
,
# list(multihead_outputs[1].size()),
self
.
seq_length
,
self
.
hidden_size
//
self
.
num_attention_heads
])
# [self.batch_size, self.num_attention_heads,
self
.
parent
.
assertListEqual
(
# self.seq_length, self.hidden_size // self.num_attention_heads])
list
(
multihead_outputs
[
1
].
size
()),
# self.parent.assertListEqual(
[
self
.
batch_size
,
self
.
num_attention_heads
,
# list(multihead_outputs[-1].size()),
self
.
seq_length
,
self
.
hidden_size
//
self
.
num_attention_heads
])
# [self.batch_size, self.num_attention_heads-1,
self
.
parent
.
assertListEqual
(
# self.seq_length, self.hidden_size // self.num_attention_heads])
list
(
multihead_outputs
[
-
1
].
size
()),
[
self
.
batch_size
,
self
.
num_attention_heads
-
1
,
self
.
seq_length
,
self
.
hidden_size
//
self
.
num_attention_heads
])
def
test_default
(
self
):
def
test_default
(
self
):
...
...
tests/modeling_xlnet_test.py
View file @
d9184620
...
@@ -134,26 +134,19 @@ class XLNetModelTest(unittest.TestCase):
...
@@ -134,26 +134,19 @@ class XLNetModelTest(unittest.TestCase):
model
=
XLNetLMHeadModel
(
config
)
model
=
XLNetLMHeadModel
(
config
)
model
.
eval
()
model
.
eval
()
loss_1
,
mems_1a
=
model
(
input_ids_1
,
token_type_ids
=
segment_ids
,
labels
=
lm_labels
)
loss_1
,
all_logits_1
,
mems_1
=
model
(
input_ids_1
,
token_type_ids
=
segment_ids
,
labels
=
lm_labels
)
all_logits_1
,
mems_1b
=
model
(
input_ids_1
,
token_type_ids
=
segment_ids
)
loss_2
,
mems_2a
=
model
(
input_ids_2
,
token_type_ids
=
segment_ids
,
labels
=
lm_labels
,
mems
=
mems_1a
)
loss_2
,
all_logits_2
,
mems_2
=
model
(
input_ids_2
,
token_type_ids
=
segment_ids
,
labels
=
lm_labels
,
mems
=
mems_1
)
all_logits_2
,
mems_2b
=
model
(
input_ids_2
,
token_type_ids
=
segment_ids
,
mems
=
mems_1b
)
logits
,
_
=
model
(
input_ids_q
,
logits
,
_
=
model
(
input_ids_q
,
perm_mask
=
perm_mask
,
target_mapping
=
target_mapping
,
inp_q
=
inp_q
)
perm_mask
=
perm_mask
,
target_mapping
=
target_mapping
,
inp_q
=
inp_q
)
outputs
=
{
outputs
=
{
"loss_1"
:
loss_1
,
"loss_1"
:
loss_1
,
"mems_1
a
"
:
mems_1
a
,
"mems_1"
:
mems_1
,
"all_logits_1"
:
all_logits_1
,
"all_logits_1"
:
all_logits_1
,
"mems_1b"
:
mems_1b
,
"loss_2"
:
loss_2
,
"loss_2"
:
loss_2
,
"mems_2
a
"
:
mems_2
a
,
"mems_2"
:
mems_2
,
"all_logits_2"
:
all_logits_2
,
"all_logits_2"
:
all_logits_2
,
"mems_2b"
:
mems_2b
,
}
}
return
outputs
return
outputs
...
@@ -165,14 +158,8 @@ class XLNetModelTest(unittest.TestCase):
...
@@ -165,14 +158,8 @@ class XLNetModelTest(unittest.TestCase):
list
(
result
[
"all_logits_1"
].
size
()),
list
(
result
[
"all_logits_1"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_1
a
"
]),
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_1"
]),
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
d_model
]]
*
self
.
n_layer
)
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
d_model
]]
*
self
.
n_layer
)
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_1b"
]),
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
d_model
]]
*
self
.
n_layer
)
self
.
parent
.
assertListEqual
(
list
(
mem
[
~
torch
.
isnan
(
mem
)].
sum
()
for
mem
in
result
[
"mems_1a"
]),
list
(
mem
[
~
torch
.
isnan
(
mem
)].
sum
()
for
mem
in
result
[
"mems_1b"
]))
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss_2"
].
size
()),
list
(
result
[
"loss_2"
].
size
()),
...
@@ -181,14 +168,8 @@ class XLNetModelTest(unittest.TestCase):
...
@@ -181,14 +168,8 @@ class XLNetModelTest(unittest.TestCase):
list
(
result
[
"all_logits_2"
].
size
()),
list
(
result
[
"all_logits_2"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_2
a
"
]),
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_2"
]),
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
d_model
]]
*
self
.
n_layer
)
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
d_model
]]
*
self
.
n_layer
)
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_2b"
]),
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
d_model
]]
*
self
.
n_layer
)
self
.
parent
.
assertListEqual
(
list
(
mem
[
~
torch
.
isnan
(
mem
)].
sum
()
for
mem
in
result
[
"mems_2a"
]),
list
(
mem
[
~
torch
.
isnan
(
mem
)].
sum
()
for
mem
in
result
[
"mems_2b"
]))
def
test_default
(
self
):
def
test_default
(
self
):
self
.
run_tester
(
XLNetModelTest
.
XLNetModelTester
(
self
))
self
.
run_tester
(
XLNetModelTest
.
XLNetModelTester
(
self
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment