chenpangpang / transformers · Commits · 96c4d3d9

Commit 96c4d3d9 authored Jun 17, 2019 by thomwolf
add head masking tests
parent 34858ae1

Showing 2 changed files with 142 additions and 28 deletions (+142 -28)
pytorch_pretrained_bert/modeling.py   +100  -28
tests/modeling_test.py                +42   -0
pytorch_pretrained_bert/modeling.py
@@ -51,6 +51,32 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
 BERT_CONFIG_NAME = 'bert_config.json'
 TF_WEIGHTS_NAME = 'model.ckpt'
 
+def prune_linear_layer(layer, index, dim=-1):
+    """ Prune a linear layer (a model parameters) to keep only entries in index.
+        Return the pruned layer as a new layer with requires_grad=True.
+        Used to remove heads.
+    """
+    dim = (dim + 100) % 2
+    index = index.to(layer.weight.device)
+    W = layer.weight.index_select(dim, index).clone().detach()
+    if layer.bias is not None:
+        if dim == 1:
+            b = layer.bias.clone().detach()
+        else:
+            b = layer.bias[index].clone().detach()
+    new_size = list(layer.weight.size())
+    new_size[dim] = len(index)
+    new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None)
+    new_layer.weight.requires_grad = False
+    new_layer.weight.copy_(W.contiguous())
+    new_layer.weight.requires_grad = True
+    if layer.bias is not None:
+        new_layer.bias.requires_grad = False
+        new_layer.bias.copy_(b.contiguous())
+        new_layer.bias.requires_grad = True
+    return new_layer
+
+
 def load_tf_weights_in_bert(model, tf_checkpoint_path):
     """ Load tf checkpoints in a pytorch model
     """
@@ -329,12 +355,7 @@ class BertSelfAttention(nn.Module):
         attention_probs = self.dropout(attention_probs)
 
         # Mask heads if we want to
+        # attention_probs has shape bsz x n_heads x N x N
         if head_mask is not None:
-            if head_mask.dim() == 1:
-                head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
-            elif head_mask.dim() == 2:
-                head_mask.unsqueeze(-1).unsqueeze(-1)
-            # We can define heads to mask for each instance in the batch
             attention_probs = attention_probs * head_mask
 
         context_layer = torch.matmul(attention_probs, value_layer)
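
For intuition, a toy sketch (outside the diff) of the multiplication above: once head_mask is broadcast to (1, n_heads, 1, 1), the masked heads' attention weights become zero, so their slice of the context layer is zero as well.

    import torch

    bsz, n_heads, seq_len, head_dim = 2, 4, 5, 8
    attention_probs = torch.rand(bsz, n_heads, seq_len, seq_len).softmax(dim=-1)
    value_layer = torch.rand(bsz, n_heads, seq_len, head_dim)

    # Multiplier form consumed here: 1.0 keeps a head, 0.0 silences it.
    head_mask = torch.tensor([0., 1., 1., 0.]).view(1, n_heads, 1, 1)

    attention_probs = attention_probs * head_mask
    context_layer = torch.matmul(attention_probs, value_layer)

    print(context_layer[:, 0].abs().sum())   # tensor(0.) -- head 0 silenced
    print(context_layer[:, 1].abs().sum())   # > 0        -- head 1 untouched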
@@ -365,12 +386,28 @@ class BertSelfOutput(nn.Module):
 class BertAttention(nn.Module):
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(BertAttention, self).__init__()
         self.output_attentions = output_attentions
-        self.self = BertSelfAttention(config, output_attentions=output_attentions)
+        self.self = BertSelfAttention(config, output_attentions=output_attentions,
+                                      keep_multihead_output=keep_multihead_output)
         self.output = BertSelfOutput(config)
 
+    def prune_heads(self, heads):
+        mask = torch.ones(self.self.n_heads, self.self.d_head)
+        for head in heads:
+            mask[head] = 0
+        mask = mask.view(-1).contiguous().eq(1)
+        index = torch.arange(len(mask))[mask].long()
+        # Prune linear layers
+        self.self.query = prune_linear_layer(self.self.query, index)
+        self.self.key = prune_linear_layer(self.self.key, index)
+        self.self.value = prune_linear_layer(self.self.value, index)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=0)
+        # Update hyper params
+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+
     def forward(self, input_tensor, attention_mask, head_mask=None):
         self_output = self.self(input_tensor, attention_mask, head_mask)
         if self.output_attentions:
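
To see what prune_heads hands to prune_linear_layer, here is the index construction in isolation (a standalone sketch using plain integers for the head count and head size: 4 heads of size 8, pruning heads 1 and 3):

    import torch

    n_heads, d_head = 4, 8          # stand-ins for the attention module's sizes
    heads = [1, 3]                  # heads to prune

    mask = torch.ones(n_heads, d_head)
    for head in heads:
        mask[head] = 0
    mask = mask.view(-1).contiguous().eq(1)
    index = torch.arange(len(mask))[mask].long()

    print(index.tolist())
    # [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23]
    print(index.numel())            # 16 == (n_heads - len(heads)) * d_head

The surviving flat positions (heads 0 and 2 here) are then used to shrink the query/key/value projections and the output projection, and the stored head count is updated accordingly.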
@@ -411,10 +448,11 @@ class BertOutput(nn.Module):
 class BertLayer(nn.Module):
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(BertLayer, self).__init__()
         self.output_attentions = output_attentions
-        self.attention = BertAttention(config, output_attentions=output_attentions)
+        self.attention = BertAttention(config, output_attentions=output_attentions,
+                                       keep_multihead_output=keep_multihead_output)
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)
@@ -430,10 +468,11 @@ class BertLayer(nn.Module):
 class BertEncoder(nn.Module):
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(BertEncoder, self).__init__()
         self.output_attentions = output_attentions
-        layer = BertLayer(config, output_attentions=output_attentions)
+        layer = BertLayer(config, output_attentions=output_attentions,
+                          keep_multihead_output=keep_multihead_output)
         self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
 
     def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, head_mask=None):
@@ -741,14 +780,28 @@ class BertModel(BertPreTrainedModel):
     all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(BertModel, self).__init__(config)
         self.output_attentions = output_attentions
         self.embeddings = BertEmbeddings(config)
-        self.encoder = BertEncoder(config, output_attentions=output_attentions)
+        self.encoder = BertEncoder(config, output_attentions=output_attentions,
+                                   keep_multihead_output=keep_multihead_output)
         self.pooler = BertPooler(config)
         self.apply(self.init_bert_weights)
 
+    def prune_heads(self, heads_to_prune):
+        """ Prunes heads of the model.
+            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    def get_multihead_outputs(self):
+        """ Gather all multi-head outputs.
+            Return: list (layers) of multihead module outputs with gradients
+        """
+        return [layer.attention.self.multihead_output for layer in self.encoder.layer]
+
     def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, head_mask=None):
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
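
A minimal end-to-end sketch of the new surface (not from the repo; the tiny config values and the import are illustrative assumptions): build a small BertModel with keep_multihead_output=True, run it with a head_mask, and inspect the per-head outputs.

    import torch
    from pytorch_pretrained_bert.modeling import BertConfig, BertModel  # assumes this revision is installed

    # Deliberately tiny config so the sketch runs instantly.
    config = BertConfig(vocab_size_or_config_json_file=100, hidden_size=32,
                        num_hidden_layers=2, num_attention_heads=4,
                        intermediate_size=64, max_position_embeddings=64)

    model = BertModel(config, keep_multihead_output=True)
    model.eval()

    input_ids = torch.randint(0, 100, (2, 10))

    # Convention from the forward pass: 1.0 masks a head, 0.0 keeps it.
    head_mask = torch.zeros(4)
    head_mask[1] = 1.0

    sequence_output, pooled_output = model(input_ids, head_mask=head_mask,
                                           output_all_encoded_layers=False)

    per_head = model.get_multihead_outputs()   # one tensor per layer
    print(len(per_head), per_head[0].shape)    # 2  torch.Size([2, 4, 10, 8])

For permanent removal rather than masking, the new prune_heads entry point takes a dict mapping layer index to the heads to drop, e.g. model.prune_heads({0: [1]}), as described in its docstring above.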
@@ -770,6 +823,17 @@ class BertModel(BertPreTrainedModel):
         extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
 
+        # Prepare head mask if needed
+        # 1 in head_mask indicate we need to mask the head
+        # attention_probs has shape bsz x n_heads x N x N
+        if head_mask is not None:
+            if head_mask.dim() == 1:
+                head_mask = head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+            elif head_mask.dim() == 2:
+                head_mask = head_mask.unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each instance in batch
+            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to fload if need + fp16 compatibility
+            head_mask = (1.0 - head_mask)
+
         embedding_output = self.embeddings(input_ids, token_type_ids)
         encoded_layers = self.encoder(embedding_output,
                                       extended_attention_mask,
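
A worked example of the 2-D branch above (standalone sketch): a (batch, n_heads) mask lets each instance silence different heads, and the final 1.0 - head_mask turns the user convention (1 = mask this head) into the multiplier that BertSelfAttention consumes.

    import torch

    batch_head_mask = torch.tensor([[1., 0., 0., 0.],    # instance 0: mask head 0
                                    [0., 0., 0., 1.]])   # instance 1: mask head 3

    if batch_head_mask.dim() == 2:
        batch_head_mask = batch_head_mask.unsqueeze(-1).unsqueeze(-1)   # -> (2, 4, 1, 1)
    multiplier = 1.0 - batch_head_mask

    print(multiplier.shape)           # torch.Size([2, 4, 1, 1])
    print(multiplier[0].flatten())    # tensor([0., 1., 1., 1.])
    print(multiplier[1].flatten())    # tensor([1., 1., 1., 0.])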
@@ -836,10 +900,11 @@ class BertForPreTraining(BertPreTrainedModel):
     masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(BertForPreTraining, self).__init__(config)
         self.output_attentions = output_attentions
-        self.bert = BertModel(config, output_attentions=output_attentions)
+        self.bert = BertModel(config, output_attentions=output_attentions,
+                              keep_multihead_output=keep_multihead_output)
         self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
         self.apply(self.init_bert_weights)
@@ -905,10 +970,11 @@ class BertForMaskedLM(BertPreTrainedModel):
     masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(BertForMaskedLM, self).__init__(config)
         self.output_attentions = output_attentions
-        self.bert = BertModel(config, output_attentions=output_attentions)
+        self.bert = BertModel(config, output_attentions=output_attentions,
+                              keep_multihead_output=keep_multihead_output)
         self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
         self.apply(self.init_bert_weights)
@@ -974,10 +1040,11 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
     seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(BertForNextSentencePrediction, self).__init__(config)
         self.output_attentions = output_attentions
-        self.bert = BertModel(config, output_attentions=output_attentions)
+        self.bert = BertModel(config, output_attentions=output_attentions,
+                              keep_multihead_output=keep_multihead_output)
         self.cls = BertOnlyNSPHead(config)
         self.apply(self.init_bert_weights)
@@ -1045,11 +1112,12 @@ class BertForSequenceClassification(BertPreTrainedModel):
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_labels=2, output_attentions=False):
+    def __init__(self, config, num_labels=2, output_attentions=False, keep_multihead_output=False):
         super(BertForSequenceClassification, self).__init__(config)
         self.output_attentions = output_attentions
         self.num_labels = num_labels
-        self.bert = BertModel(config, output_attentions=output_attentions)
+        self.bert = BertModel(config, output_attentions=output_attentions,
+                              keep_multihead_output=keep_multihead_output)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, num_labels)
         self.apply(self.init_bert_weights)
@@ -1116,11 +1184,12 @@ class BertForMultipleChoice(BertPreTrainedModel):
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_choices=2, output_attentions=False):
+    def __init__(self, config, num_choices=2, output_attentions=False, keep_multihead_output=False):
         super(BertForMultipleChoice, self).__init__(config)
         self.output_attentions = output_attentions
         self.num_choices = num_choices
-        self.bert = BertModel(config, output_attentions=output_attentions)
+        self.bert = BertModel(config, output_attentions=output_attentions,
+                              keep_multihead_output=keep_multihead_output)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, 1)
         self.apply(self.init_bert_weights)
@@ -1192,11 +1261,12 @@ class BertForTokenClassification(BertPreTrainedModel):
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_labels=2, output_attentions=False):
+    def __init__(self, config, num_labels=2, output_attentions=False, keep_multihead_output=False):
         super(BertForTokenClassification, self).__init__(config)
         self.output_attentions = output_attentions
         self.num_labels = num_labels
-        self.bert = BertModel(config, output_attentions=output_attentions)
+        self.bert = BertModel(config, output_attentions=output_attentions,
+                              keep_multihead_output=keep_multihead_output)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, num_labels)
         self.apply(self.init_bert_weights)
@@ -1273,14 +1343,16 @@ class BertForQuestionAnswering(BertPreTrainedModel):
     start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(BertForQuestionAnswering, self).__init__(config)
         self.output_attentions = output_attentions
-        self.bert = BertModel(config, output_attentions=output_attentions)
+        self.bert = BertModel(config, output_attentions=output_attentions,
+                              keep_multihead_output=keep_multihead_output)
         self.qa_outputs = nn.Linear(config.hidden_size, 2)
         self.apply(self.init_bert_weights)
 
     def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None, head_mask=None):
         outputs = self.bert(input_ids, token_type_ids, attention_mask,
                             output_all_encoded_layers=False,
                             head_mask=head_mask)
tests/modeling_test.py
@@ -293,6 +293,47 @@ class BertModelTest(unittest.TestCase):
                     [self.batch_size, self.num_attention_heads, self.seq_length, self.seq_length])
 
+        def create_and_check_bert_for_headmasking(self, config, input_ids, token_type_ids, input_mask,
+                                                  sequence_labels, token_labels, choice_labels):
+            for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
+                                BertForPreTraining, BertForQuestionAnswering,
+                                BertForSequenceClassification, BertForTokenClassification):
+                if model_class in [BertForSequenceClassification, BertForTokenClassification]:
+                    model = model_class(config=config, num_labels=self.num_labels,
+                                        keep_multihead_output=True)
+                else:
+                    model = model_class(config=config, keep_multihead_output=True)
+                model.eval()
+                head_mask = torch.ones(self.num_attention_heads).to(input_ids.device)
+                head_mask[0] = 0.0
+                head_mask[-1] = 0.0  # Mask all but the first and last heads
+                output = model(input_ids, token_type_ids, input_mask, head_mask=head_mask)
+
+                if isinstance(model, BertModel):
+                    output = sum(t.sum() for t in output[0])
+                elif isinstance(output, (list, tuple)):
+                    output = sum(t.sum() for t in output)
+                output = output.sum()
+                output.backward()
+
+                multihead_outputs = (model if isinstance(model, BertModel) else model.bert).get_multihead_outputs()
+                self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
+                self.parent.assertListEqual(
+                    list(multihead_outputs[0].size()),
+                    [self.batch_size, self.num_attention_heads,
+                     self.seq_length, self.hidden_size // self.num_attention_heads])
+                self.parent.assertEqual(
+                    len(multihead_outputs[0][:, 1:(self.num_attention_heads-1), :, :].nonzero()),
+                    0)
+                self.parent.assertEqual(
+                    len(multihead_outputs[0][:, 0, :, :].nonzero()),
+                    self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
+                self.parent.assertEqual(
+                    len(multihead_outputs[0][:, self.num_attention_heads-1, :, :].nonzero()),
+                    self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
+
     def test_default(self):
         self.run_tester(BertModelTest.BertModelTester(self))
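
As the following hunk shows, the new check is added to the tester's default run, so it executes with the rest of the BERT model tests. Assuming pytest is available in the environment, an invocation along the lines of `python -m pytest tests/modeling_test.py::BertModelTest::test_default` (an illustrative command, not taken from the repo docs) exercises the head-masking path for every model class covered above.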
@@ -352,6 +393,7 @@ class BertModelTest(unittest.TestCase):
             tester.check_loss_output(output_result)
             tester.create_and_check_bert_for_attentions(*config_and_inputs)
+            tester.create_and_check_bert_for_headmasking(*config_and_inputs)
 
     @classmethod
     def ids_tensor(cls, shape, vocab_size, rng=None, name=None):