chenpangpang / transformers · Commits

Commit 7220d47a, authored Jun 17, 2019 by thomwolf

adding head pruning and tests

Parent: 8415a38b
Showing 2 changed files with 45 additions and 4 deletions:

    pytorch_pretrained_bert/modeling.py   +3 -4
    tests/modeling_test.py                +42 -0
pytorch_pretrained_bert/modeling.py

@@ -51,12 +51,11 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
 BERT_CONFIG_NAME = 'bert_config.json'
 TF_WEIGHTS_NAME = 'model.ckpt'
 
-def prune_linear_layer(layer, index, dim=-1):
+def prune_linear_layer(layer, index, dim=0):
     """ Prune a linear layer (a model parameters) to keep only entries in index.
         Return the pruned layer as a new layer with requires_grad=True.
         Used to remove heads.
     """
-    dim = (dim + 100) % 2
     index = index.to(layer.weight.device)
     W = layer.weight.index_select(dim, index).clone().detach()
     if layer.bias is not None:
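prune_linear_layer builds a brand-new, smaller nn.Linear rather than masking weights in place. The hunk above cuts off right after the bias check, so the following is a minimal self-contained sketch of the full function; the bias handling and the layer reconstruction are assumptions about the elided body, not a verbatim copy of the commit.

    import torch
    from torch import nn

    def prune_linear_layer(layer, index, dim=0):
        """Keep only the entries of `layer` selected by `index` along `dim`,
        returning a new layer whose parameters have requires_grad=True."""
        index = index.to(layer.weight.device)
        W = layer.weight.index_select(dim, index).clone().detach()
        b = None
        if layer.bias is not None:
            # Pruning input features (dim=1) leaves the bias untouched;
            # pruning output features (dim=0) drops the matching bias entries.
            b = layer.bias.clone().detach() if dim == 1 else layer.bias[index].clone().detach()
        new_size = list(layer.weight.size())
        new_size[dim] = len(index)
        new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None)
        new_layer.weight.requires_grad = False
        new_layer.weight.copy_(W.contiguous())
        new_layer.weight.requires_grad = True
        if b is not None:
            new_layer.bias.requires_grad = False
            new_layer.bias.copy_(b.contiguous())
            new_layer.bias.requires_grad = True
        return new_layer

    # Drop output features 1 and 3 of a 4-output layer: weight goes (4, 6) -> (2, 6).
    layer = nn.Linear(6, 4)
    pruned = prune_linear_layer(layer, torch.tensor([0, 2]), dim=0)
    assert pruned.weight.shape == (2, 6) and pruned.bias.shape == (2,)

With dim=0 the layer loses output features (rows of the weight plus the matching bias entries); with dim=1 it loses input features, which is why the bias is left untouched in that branch.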
@@ -394,7 +393,7 @@ class BertAttention(nn.Module):
         self.output = BertSelfOutput(config)
 
     def prune_heads(self, heads):
-        mask = torch.ones(self.self.n_heads, self.self.d_head)
+        mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
         for head in heads:
             mask[head] = 0
         mask = mask.view(-1).contiguous().eq(1)
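Each row of the (num_attention_heads, attention_head_size) mask corresponds to one head, so after zeroing the pruned heads and flattening, the True positions are exactly the flat weight rows to keep. A small sketch of that computation; the conversion from the boolean mask to the integer index used in the next hunk is elided by the diff, and torch.arange(len(mask))[mask] is an assumed reconstruction consistent with the index_select calls that follow:

    import torch

    num_attention_heads, attention_head_size = 4, 16
    heads = [1, 3]  # heads selected for pruning

    mask = torch.ones(num_attention_heads, attention_head_size)
    for head in heads:
        mask[head] = 0                       # zero out the whole row for a pruned head
    mask = mask.view(-1).contiguous().eq(1)  # flat boolean mask, True = keep

    index = torch.arange(len(mask))[mask]    # flat positions of the weights to keep
    assert index.numel() == (num_attention_heads - len(heads)) * attention_head_size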
@@ -403,7 +402,7 @@ class BertAttention(nn.Module):
         self.self.query = prune_linear_layer(self.self.query, index)
         self.self.key = prune_linear_layer(self.self.key, index)
         self.self.value = prune_linear_layer(self.self.value, index)
-        self.output.dense = prune_linear_layer(self.output.dense, index, dim=0)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
         # Update hyper params
         self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
         self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
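The dim=0 to dim=1 change on output.dense is the substantive fix in this hunk: query/key/value produce the per-head features, so they shrink along their output dimension (dim=0, the new default), while output.dense consumes those features and must shrink along its input dimension (dim=1). A toy shape check of that reasoning, using plain index_select in place of prune_linear_layer:

    import torch

    hidden_size = 8                       # toy config: 4 heads of size 2
    kept = torch.arange(6)                # keep heads 0-2, i.e. flat positions 0..5

    q_w = torch.randn(hidden_size, hidden_size)   # query.weight
    d_w = torch.randn(hidden_size, hidden_size)   # output.dense.weight

    q_pruned = q_w.index_select(0, kept)  # (6, 8): fewer output features
    d_pruned = d_w.index_select(1, kept)  # (8, 6): fewer input features

    x = torch.randn(2, hidden_size)
    ctx = x @ q_pruned.t()                # (2, 6) pruned per-head features
    out = ctx @ d_pruned.t()              # (2, 8) projected back to hidden_size
    assert out.shape == (2, hidden_size)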
tests/modeling_test.py

@@ -334,6 +334,47 @@ class BertModelTest(unittest.TestCase):
                     self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
 
+        def create_and_check_bert_for_head_pruning(self, config, input_ids, token_type_ids, input_mask,
+                                                   sequence_labels, token_labels, choice_labels):
+            for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
+                                BertForPreTraining, BertForQuestionAnswering,
+                                BertForSequenceClassification, BertForTokenClassification):
+                if model_class in [BertForSequenceClassification, BertForTokenClassification]:
+                    model = model_class(config=config, num_labels=self.num_labels, keep_multihead_output=True)
+                else:
+                    model = model_class(config=config, keep_multihead_output=True)
+                model.eval()
+                bert_model = model if isinstance(model, BertModel) else model.bert
+                heads_to_prune = {0: list(range(1, self.num_attention_heads)),
+                                  -1: [0]}
+                bert_model.prune_heads(heads_to_prune)
+                output = model(input_ids, token_type_ids, input_mask)
+
+                if isinstance(model, BertModel):
+                    output = sum(t.sum() for t in output[0])
+                elif isinstance(output, (list, tuple)):
+                    output = sum(t.sum() for t in output)
+                output = output.sum()
+                output.backward()
+                multihead_outputs = bert_model.get_multihead_outputs()
+
+                self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
+                self.parent.assertListEqual(
+                    list(multihead_outputs[0].size()),
+                    [self.batch_size, 1,
+                     self.seq_length, self.hidden_size // self.num_attention_heads])
+                self.parent.assertListEqual(
+                    list(multihead_outputs[1].size()),
+                    [self.batch_size, self.num_attention_heads,
+                     self.seq_length, self.hidden_size // self.num_attention_heads])
+                self.parent.assertListEqual(
+                    list(multihead_outputs[-1].size()),
+                    [self.batch_size, self.num_attention_heads - 1,
+                     self.seq_length, self.hidden_size // self.num_attention_heads])
+
     def test_default(self):
         self.run_tester(BertModelTest.BertModelTester(self))
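The heads_to_prune mapping in the new test maps layer index to a list of head indices, and negative layer keys count from the end as with any Python list (the test prunes all but head 0 in the first layer and head 0 in the last). A minimal end-to-end usage sketch, assuming a checkout of pytorch_pretrained_bert at this commit; the config values mirror the tester defaults:

    import torch
    from pytorch_pretrained_bert.modeling import BertConfig, BertModel

    config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=32,
                        num_hidden_layers=5, num_attention_heads=4,
                        intermediate_size=37)
    model = BertModel(config)
    model.eval()

    # Keep only head 0 in the first layer; drop head 0 in the last layer.
    model.prune_heads({0: [1, 2, 3], -1: [0]})

    input_ids = torch.randint(0, 99, (2, 7))
    encoded_layers, pooled = model(input_ids)
    print(pooled.shape)  # torch.Size([2, 32]): hidden size is unchanged by pruning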
@@ -394,6 +435,7 @@ class BertModelTest(unittest.TestCase):
         tester.create_and_check_bert_for_attentions(*config_and_inputs)
         tester.create_and_check_bert_for_headmasking(*config_and_inputs)
+        tester.create_and_check_bert_for_head_pruning(*config_and_inputs)
 
     @classmethod
     def ids_tensor(cls, shape, vocab_size, rng=None, name=None):