Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
7220d47a
Commit
7220d47a
authored
Jun 17, 2019
by
thomwolf
Browse files
adding head pruning and tests
parent
8415a38b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
45 additions
and
4 deletions
+45
-4
pytorch_pretrained_bert/modeling.py
pytorch_pretrained_bert/modeling.py
+3
-4
tests/modeling_test.py
tests/modeling_test.py
+42
-0
No files found.
pytorch_pretrained_bert/modeling.py
View file @
7220d47a
...
...
@@ -51,12 +51,11 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
BERT_CONFIG_NAME
=
'bert_config.json'
TF_WEIGHTS_NAME
=
'model.ckpt'
def
prune_linear_layer
(
layer
,
index
,
dim
=
-
1
):
def
prune_linear_layer
(
layer
,
index
,
dim
=
0
):
""" Prune a linear layer (a model parameters) to keep only entries in index.
Return the pruned layer as a new layer with requires_grad=True.
Used to remove heads.
"""
dim
=
(
dim
+
100
)
%
2
index
=
index
.
to
(
layer
.
weight
.
device
)
W
=
layer
.
weight
.
index_select
(
dim
,
index
).
clone
().
detach
()
if
layer
.
bias
is
not
None
:
...
...
@@ -394,7 +393,7 @@ class BertAttention(nn.Module):
self
.
output
=
BertSelfOutput
(
config
)
def
prune_heads
(
self
,
heads
):
mask
=
torch
.
ones
(
self
.
self
.
n_heads
,
self
.
self
.
d_head
)
mask
=
torch
.
ones
(
self
.
self
.
n
um_attention
_heads
,
self
.
self
.
attention_head_size
)
for
head
in
heads
:
mask
[
head
]
=
0
mask
=
mask
.
view
(
-
1
).
contiguous
().
eq
(
1
)
...
...
@@ -403,7 +402,7 @@ class BertAttention(nn.Module):
self
.
self
.
query
=
prune_linear_layer
(
self
.
self
.
query
,
index
)
self
.
self
.
key
=
prune_linear_layer
(
self
.
self
.
key
,
index
)
self
.
self
.
value
=
prune_linear_layer
(
self
.
self
.
value
,
index
)
self
.
output
.
dense
=
prune_linear_layer
(
self
.
output
.
dense
,
index
,
dim
=
0
)
self
.
output
.
dense
=
prune_linear_layer
(
self
.
output
.
dense
,
index
,
dim
=
1
)
# Update hyper params
self
.
self
.
num_attention_heads
=
self
.
self
.
num_attention_heads
-
len
(
heads
)
self
.
self
.
all_head_size
=
self
.
self
.
attention_head_size
*
self
.
self
.
num_attention_heads
...
...
tests/modeling_test.py
View file @
7220d47a
...
...
@@ -334,6 +334,47 @@ class BertModelTest(unittest.TestCase):
self
.
batch_size
*
self
.
seq_length
*
self
.
hidden_size
//
self
.
num_attention_heads
)
def
create_and_check_bert_for_head_pruning
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
for
model_class
in
(
BertModel
,
BertForMaskedLM
,
BertForNextSentencePrediction
,
BertForPreTraining
,
BertForQuestionAnswering
,
BertForSequenceClassification
,
BertForTokenClassification
):
if
model_class
in
[
BertForSequenceClassification
,
BertForTokenClassification
]:
model
=
model_class
(
config
=
config
,
num_labels
=
self
.
num_labels
,
keep_multihead_output
=
True
)
else
:
model
=
model_class
(
config
=
config
,
keep_multihead_output
=
True
)
model
.
eval
()
bert_model
=
model
if
isinstance
(
model
,
BertModel
)
else
model
.
bert
heads_to_prune
=
{
0
:
list
(
range
(
1
,
self
.
num_attention_heads
)),
-
1
:
[
0
]}
bert_model
.
prune_heads
(
heads_to_prune
)
output
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
if
isinstance
(
model
,
BertModel
):
output
=
sum
(
t
.
sum
()
for
t
in
output
[
0
])
elif
isinstance
(
output
,
(
list
,
tuple
)):
output
=
sum
(
t
.
sum
()
for
t
in
output
)
output
=
output
.
sum
()
output
.
backward
()
multihead_outputs
=
bert_model
.
get_multihead_outputs
()
self
.
parent
.
assertEqual
(
len
(
multihead_outputs
),
self
.
num_hidden_layers
)
self
.
parent
.
assertListEqual
(
list
(
multihead_outputs
[
0
].
size
()),
[
self
.
batch_size
,
1
,
self
.
seq_length
,
self
.
hidden_size
//
self
.
num_attention_heads
])
self
.
parent
.
assertListEqual
(
list
(
multihead_outputs
[
1
].
size
()),
[
self
.
batch_size
,
self
.
num_attention_heads
,
self
.
seq_length
,
self
.
hidden_size
//
self
.
num_attention_heads
])
self
.
parent
.
assertListEqual
(
list
(
multihead_outputs
[
-
1
].
size
()),
[
self
.
batch_size
,
self
.
num_attention_heads
-
1
,
self
.
seq_length
,
self
.
hidden_size
//
self
.
num_attention_heads
])
def
test_default
(
self
):
self
.
run_tester
(
BertModelTest
.
BertModelTester
(
self
))
...
...
@@ -394,6 +435,7 @@ class BertModelTest(unittest.TestCase):
tester
.
create_and_check_bert_for_attentions
(
*
config_and_inputs
)
tester
.
create_and_check_bert_for_headmasking
(
*
config_and_inputs
)
tester
.
create_and_check_bert_for_head_pruning
(
*
config_and_inputs
)
@
classmethod
def
ids_tensor
(
cls
,
shape
,
vocab_size
,
rng
=
None
,
name
=
None
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment