Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
96c4d3d9
Commit
96c4d3d9
authored
Jun 17, 2019
by
thomwolf
Browse files
add head masking tests
parent
34858ae1
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
142 additions
and
28 deletions
+142
-28
pytorch_pretrained_bert/modeling.py
pytorch_pretrained_bert/modeling.py
+100
-28
tests/modeling_test.py
tests/modeling_test.py
+42
-0
No files found.
pytorch_pretrained_bert/modeling.py
View file @
96c4d3d9
...
...
@@ -51,6 +51,32 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
BERT_CONFIG_NAME
=
'bert_config.json'
TF_WEIGHTS_NAME
=
'model.ckpt'
def
prune_linear_layer
(
layer
,
index
,
dim
=-
1
):
""" Prune a linear layer (a model parameters) to keep only entries in index.
Return the pruned layer as a new layer with requires_grad=True.
Used to remove heads.
"""
dim
=
(
dim
+
100
)
%
2
index
=
index
.
to
(
layer
.
weight
.
device
)
W
=
layer
.
weight
.
index_select
(
dim
,
index
).
clone
().
detach
()
if
layer
.
bias
is
not
None
:
if
dim
==
1
:
b
=
layer
.
bias
.
clone
().
detach
()
else
:
b
=
layer
.
bias
[
index
].
clone
().
detach
()
new_size
=
list
(
layer
.
weight
.
size
())
new_size
[
dim
]
=
len
(
index
)
new_layer
=
nn
.
Linear
(
new_size
[
1
],
new_size
[
0
],
bias
=
layer
.
bias
is
not
None
)
new_layer
.
weight
.
requires_grad
=
False
new_layer
.
weight
.
copy_
(
W
.
contiguous
())
new_layer
.
weight
.
requires_grad
=
True
if
layer
.
bias
is
not
None
:
new_layer
.
bias
.
requires_grad
=
False
new_layer
.
bias
.
copy_
(
b
.
contiguous
())
new_layer
.
bias
.
requires_grad
=
True
return
new_layer
def
load_tf_weights_in_bert
(
model
,
tf_checkpoint_path
):
""" Load tf checkpoints in a pytorch model
"""
...
...
@@ -329,12 +355,7 @@ class BertSelfAttention(nn.Module):
attention_probs
=
self
.
dropout
(
attention_probs
)
# Mask heads if we want to
# attention_probs has shape bsz x n_heads x N x N
if
head_mask
is
not
None
:
if
head_mask
.
dim
()
==
1
:
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
elif
head_mask
.
dim
()
==
2
:
head_mask
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
# We can define heads to mask for each instance in the batch
attention_probs
=
attention_probs
*
head_mask
context_layer
=
torch
.
matmul
(
attention_probs
,
value_layer
)
...
...
@@ -365,12 +386,28 @@ class BertSelfOutput(nn.Module):
class
BertAttention
(
nn
.
Module
):
def
__init__
(
self
,
config
,
output_attentions
=
False
):
def
__init__
(
self
,
config
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
super
(
BertAttention
,
self
).
__init__
()
self
.
output_attentions
=
output_attentions
self
.
self
=
BertSelfAttention
(
config
,
output_attentions
=
output_attentions
)
self
.
self
=
BertSelfAttention
(
config
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
self
.
output
=
BertSelfOutput
(
config
)
def
prune_heads
(
self
,
heads
):
mask
=
torch
.
ones
(
self
.
self
.
n_heads
,
self
.
self
.
d_head
)
for
head
in
heads
:
mask
[
head
]
=
0
mask
=
mask
.
view
(
-
1
).
contiguous
().
eq
(
1
)
index
=
torch
.
arange
(
len
(
mask
))[
mask
].
long
()
# Prune linear layers
self
.
self
.
query
=
prune_linear_layer
(
self
.
self
.
query
,
index
)
self
.
self
.
key
=
prune_linear_layer
(
self
.
self
.
key
,
index
)
self
.
self
.
value
=
prune_linear_layer
(
self
.
self
.
value
,
index
)
self
.
output
.
dense
=
prune_linear_layer
(
self
.
output
.
dense
,
index
,
dim
=
0
)
# Update hyper params
self
.
self
.
num_attention_heads
=
self
.
self
.
num_attention_heads
-
len
(
heads
)
self
.
self
.
all_head_size
=
self
.
self
.
attention_head_size
*
self
.
self
.
num_attention_heads
def
forward
(
self
,
input_tensor
,
attention_mask
,
head_mask
=
None
):
self_output
=
self
.
self
(
input_tensor
,
attention_mask
,
head_mask
)
if
self
.
output_attentions
:
...
...
@@ -411,10 +448,11 @@ class BertOutput(nn.Module):
class
BertLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
,
output_attentions
=
False
):
def
__init__
(
self
,
config
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
super
(
BertLayer
,
self
).
__init__
()
self
.
output_attentions
=
output_attentions
self
.
attention
=
BertAttention
(
config
,
output_attentions
=
output_attentions
)
self
.
attention
=
BertAttention
(
config
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
self
.
intermediate
=
BertIntermediate
(
config
)
self
.
output
=
BertOutput
(
config
)
...
...
@@ -430,10 +468,11 @@ class BertLayer(nn.Module):
class
BertEncoder
(
nn
.
Module
):
def
__init__
(
self
,
config
,
output_attentions
=
False
):
def
__init__
(
self
,
config
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
super
(
BertEncoder
,
self
).
__init__
()
self
.
output_attentions
=
output_attentions
layer
=
BertLayer
(
config
,
output_attentions
=
output_attentions
)
layer
=
BertLayer
(
config
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
self
.
layer
=
nn
.
ModuleList
([
copy
.
deepcopy
(
layer
)
for
_
in
range
(
config
.
num_hidden_layers
)])
def
forward
(
self
,
hidden_states
,
attention_mask
,
output_all_encoded_layers
=
True
,
head_mask
=
None
):
...
...
@@ -741,14 +780,28 @@ class BertModel(BertPreTrainedModel):
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
```
"""
def
__init__
(
self
,
config
,
output_attentions
=
False
):
def
__init__
(
self
,
config
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
super
(
BertModel
,
self
).
__init__
(
config
)
self
.
output_attentions
=
output_attentions
self
.
embeddings
=
BertEmbeddings
(
config
)
self
.
encoder
=
BertEncoder
(
config
,
output_attentions
=
output_attentions
)
self
.
encoder
=
BertEncoder
(
config
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
self
.
pooler
=
BertPooler
(
config
)
self
.
apply
(
self
.
init_bert_weights
)
def
prune_heads
(
self
,
heads_to_prune
):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
for
layer
,
heads
in
heads_to_prune
.
items
():
self
.
encoder
.
layer
[
layer
].
attention
.
prune_heads
(
heads
)
def
get_multihead_outputs
(
self
):
""" Gather all multi-head outputs.
Return: list (layers) of multihead module outputs with gradients
"""
return
[
layer
.
attention
.
self
.
multihead_output
for
layer
in
self
.
encoder
.
layer
]
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
,
attention_mask
=
None
,
output_all_encoded_layers
=
True
,
head_mask
=
None
):
if
attention_mask
is
None
:
attention_mask
=
torch
.
ones_like
(
input_ids
)
...
...
@@ -770,6 +823,17 @@ class BertModel(BertPreTrainedModel):
extended_attention_mask
=
extended_attention_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# fp16 compatibility
extended_attention_mask
=
(
1.0
-
extended_attention_mask
)
*
-
10000.0
# Prepare head mask if needed
# 1 in head_mask indicate we need to mask the head
# attention_probs has shape bsz x n_heads x N x N
if
head_mask
is
not
None
:
if
head_mask
.
dim
()
==
1
:
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
elif
head_mask
.
dim
()
==
2
:
head_mask
=
head_mask
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
# We can specify head_mask for each instance in batch
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
head_mask
=
(
1.0
-
head_mask
)
embedding_output
=
self
.
embeddings
(
input_ids
,
token_type_ids
)
encoded_layers
=
self
.
encoder
(
embedding_output
,
extended_attention_mask
,
...
...
@@ -836,10 +900,11 @@ class BertForPreTraining(BertPreTrainedModel):
masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
```
"""
def
__init__
(
self
,
config
,
output_attentions
=
False
):
def
__init__
(
self
,
config
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
super
(
BertForPreTraining
,
self
).
__init__
(
config
)
self
.
output_attentions
=
output_attentions
self
.
bert
=
BertModel
(
config
,
output_attentions
=
output_attentions
)
self
.
bert
=
BertModel
(
config
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
self
.
cls
=
BertPreTrainingHeads
(
config
,
self
.
bert
.
embeddings
.
word_embeddings
.
weight
)
self
.
apply
(
self
.
init_bert_weights
)
...
...
@@ -905,10 +970,11 @@ class BertForMaskedLM(BertPreTrainedModel):
masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
```
"""
def
__init__
(
self
,
config
,
output_attentions
=
False
):
def
__init__
(
self
,
config
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
super
(
BertForMaskedLM
,
self
).
__init__
(
config
)
self
.
output_attentions
=
output_attentions
self
.
bert
=
BertModel
(
config
,
output_attentions
=
output_attentions
)
self
.
bert
=
BertModel
(
config
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
self
.
cls
=
BertOnlyMLMHead
(
config
,
self
.
bert
.
embeddings
.
word_embeddings
.
weight
)
self
.
apply
(
self
.
init_bert_weights
)
...
...
@@ -974,10 +1040,11 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
```
"""
def
__init__
(
self
,
config
,
output_attentions
=
False
):
def
__init__
(
self
,
config
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
super
(
BertForNextSentencePrediction
,
self
).
__init__
(
config
)
self
.
output_attentions
=
output_attentions
self
.
bert
=
BertModel
(
config
,
output_attentions
=
output_attentions
)
self
.
bert
=
BertModel
(
config
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
self
.
cls
=
BertOnlyNSPHead
(
config
)
self
.
apply
(
self
.
init_bert_weights
)
...
...
@@ -1045,11 +1112,12 @@ class BertForSequenceClassification(BertPreTrainedModel):
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def
__init__
(
self
,
config
,
num_labels
=
2
,
output_attentions
=
False
):
def
__init__
(
self
,
config
,
num_labels
=
2
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
super
(
BertForSequenceClassification
,
self
).
__init__
(
config
)
self
.
output_attentions
=
output_attentions
self
.
num_labels
=
num_labels
self
.
bert
=
BertModel
(
config
,
output_attentions
=
output_attentions
)
self
.
bert
=
BertModel
(
config
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
self
.
dropout
=
nn
.
Dropout
(
config
.
hidden_dropout_prob
)
self
.
classifier
=
nn
.
Linear
(
config
.
hidden_size
,
num_labels
)
self
.
apply
(
self
.
init_bert_weights
)
...
...
@@ -1116,11 +1184,12 @@ class BertForMultipleChoice(BertPreTrainedModel):
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def
__init__
(
self
,
config
,
num_choices
=
2
,
output_attentions
=
False
):
def
__init__
(
self
,
config
,
num_choices
=
2
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
super
(
BertForMultipleChoice
,
self
).
__init__
(
config
)
self
.
output_attentions
=
output_attentions
self
.
num_choices
=
num_choices
self
.
bert
=
BertModel
(
config
,
output_attentions
=
output_attentions
)
self
.
bert
=
BertModel
(
config
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
self
.
dropout
=
nn
.
Dropout
(
config
.
hidden_dropout_prob
)
self
.
classifier
=
nn
.
Linear
(
config
.
hidden_size
,
1
)
self
.
apply
(
self
.
init_bert_weights
)
...
...
@@ -1192,11 +1261,12 @@ class BertForTokenClassification(BertPreTrainedModel):
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def
__init__
(
self
,
config
,
num_labels
=
2
,
output_attentions
=
False
):
def
__init__
(
self
,
config
,
num_labels
=
2
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
super
(
BertForTokenClassification
,
self
).
__init__
(
config
)
self
.
output_attentions
=
output_attentions
self
.
num_labels
=
num_labels
self
.
bert
=
BertModel
(
config
,
output_attentions
=
output_attentions
)
self
.
bert
=
BertModel
(
config
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
self
.
dropout
=
nn
.
Dropout
(
config
.
hidden_dropout_prob
)
self
.
classifier
=
nn
.
Linear
(
config
.
hidden_size
,
num_labels
)
self
.
apply
(
self
.
init_bert_weights
)
...
...
@@ -1273,14 +1343,16 @@ class BertForQuestionAnswering(BertPreTrainedModel):
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
```
"""
def
__init__
(
self
,
config
,
output_attentions
=
False
):
def
__init__
(
self
,
config
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
super
(
BertForQuestionAnswering
,
self
).
__init__
(
config
)
self
.
output_attentions
=
output_attentions
self
.
bert
=
BertModel
(
config
,
output_attentions
=
output_attentions
)
self
.
bert
=
BertModel
(
config
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
self
.
qa_outputs
=
nn
.
Linear
(
config
.
hidden_size
,
2
)
self
.
apply
(
self
.
init_bert_weights
)
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
,
attention_mask
=
None
,
start_positions
=
None
,
end_positions
=
None
,
head_mask
=
None
):
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
,
attention_mask
=
None
,
start_positions
=
None
,
end_positions
=
None
,
head_mask
=
None
):
outputs
=
self
.
bert
(
input_ids
,
token_type_ids
,
attention_mask
,
output_all_encoded_layers
=
False
,
head_mask
=
head_mask
)
...
...
tests/modeling_test.py
View file @
96c4d3d9
...
...
@@ -293,6 +293,47 @@ class BertModelTest(unittest.TestCase):
[
self
.
batch_size
,
self
.
num_attention_heads
,
self
.
seq_length
,
self
.
seq_length
])
def
create_and_check_bert_for_headmasking
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
for
model_class
in
(
BertModel
,
BertForMaskedLM
,
BertForNextSentencePrediction
,
BertForPreTraining
,
BertForQuestionAnswering
,
BertForSequenceClassification
,
BertForTokenClassification
):
if
model_class
in
[
BertForSequenceClassification
,
BertForTokenClassification
]:
model
=
model_class
(
config
=
config
,
num_labels
=
self
.
num_labels
,
keep_multihead_output
=
True
)
else
:
model
=
model_class
(
config
=
config
,
keep_multihead_output
=
True
)
model
.
eval
()
head_mask
=
torch
.
ones
(
self
.
num_attention_heads
).
to
(
input_ids
.
device
)
head_mask
[
0
]
=
0.0
head_mask
[
-
1
]
=
0.0
# Mask all but the first and last heads
output
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
head_mask
=
head_mask
)
if
isinstance
(
model
,
BertModel
):
output
=
sum
(
t
.
sum
()
for
t
in
output
[
0
])
elif
isinstance
(
output
,
(
list
,
tuple
)):
output
=
sum
(
t
.
sum
()
for
t
in
output
)
output
=
output
.
sum
()
output
.
backward
()
multihead_outputs
=
(
model
if
isinstance
(
model
,
BertModel
)
else
model
.
bert
).
get_multihead_outputs
()
self
.
parent
.
assertEqual
(
len
(
multihead_outputs
),
self
.
num_hidden_layers
)
self
.
parent
.
assertListEqual
(
list
(
multihead_outputs
[
0
].
size
()),
[
self
.
batch_size
,
self
.
num_attention_heads
,
self
.
seq_length
,
self
.
hidden_size
//
self
.
num_attention_heads
])
self
.
parent
.
assertEqual
(
len
(
multihead_outputs
[
0
][:,
1
:(
self
.
num_attention_heads
-
1
),
:,
:].
nonzero
()),
0
)
self
.
parent
.
assertEqual
(
len
(
multihead_outputs
[
0
][:,
0
,
:,
:].
nonzero
()),
self
.
batch_size
*
self
.
seq_length
*
self
.
hidden_size
//
self
.
num_attention_heads
)
self
.
parent
.
assertEqual
(
len
(
multihead_outputs
[
0
][:,
self
.
num_attention_heads
-
1
,
:,
:].
nonzero
()),
self
.
batch_size
*
self
.
seq_length
*
self
.
hidden_size
//
self
.
num_attention_heads
)
def
test_default
(
self
):
self
.
run_tester
(
BertModelTest
.
BertModelTester
(
self
))
...
...
@@ -352,6 +393,7 @@ class BertModelTest(unittest.TestCase):
tester
.
check_loss_output
(
output_result
)
tester
.
create_and_check_bert_for_attentions
(
*
config_and_inputs
)
tester
.
create_and_check_bert_for_headmasking
(
*
config_and_inputs
)
@
classmethod
def
ids_tensor
(
cls
,
shape
,
vocab_size
,
rng
=
None
,
name
=
None
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment