Commit abb23a78
Authored Nov 07, 2019 by Lysandre; committed by Lysandre Debut on Nov 26, 2019

Head pruning for ALBERT

Parent: 4374eaea
Showing 2 changed files with 49 additions and 5 deletions:

  transformers/modeling_albert.py (+42, -0)
  transformers/tests/modeling_albert_test.py (+7, -5)
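As a rough sketch of what this change enables (not part of the diff below): head pruning on an ALBERT model goes through the base-class entry point, roughly as in the snippet here. The checkpoint name, head indices, and input ids are placeholder values.

    import torch
    from transformers import AlbertModel

    # Hypothetical usage sketch: 'albert-base-v1' is an assumed checkpoint name.
    model = AlbertModel.from_pretrained("albert-base-v1")

    # Prune heads 0 and 2 of flattened layer 0, and head 1 of flattened layer 1.
    # PreTrainedModel.prune_heads dispatches to the AlbertModel._prune_heads added in this commit.
    model.prune_heads({0: [0, 2], 1: [1]})

    input_ids = torch.tensor([[31, 51, 99]])  # toy input
    outputs = model(input_ids)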
transformers/modeling_albert.py
...
...
@@ -145,6 +145,29 @@ class AlbertAttention(BertSelfAttention):
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.pruned_heads = set()
+
+    def prune_heads(self, heads):
+        if len(heads) == 0:
+            return
+        mask = torch.ones(self.num_attention_heads, self.attention_head_size)
+        heads = set(heads) - self.pruned_heads  # Convert to set and remove already pruned heads
+        for head in heads:
+            # Compute how many pruned heads are before the head and move the index accordingly
+            head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
+            mask[head] = 0
+        mask = mask.view(-1).contiguous().eq(1)
+        index = torch.arange(len(mask))[mask].long()
+
+        # Prune linear layers
+        self.query = prune_linear_layer(self.query, index)
+        self.key = prune_linear_layer(self.key, index)
+        self.value = prune_linear_layer(self.value, index)
+        self.dense = prune_linear_layer(self.dense, index, dim=1)
+
+        # Update hyper params and store pruned heads
+        self.num_attention_heads = self.num_attention_heads - len(heads)
+        self.all_head_size = self.attention_head_size * self.num_attention_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
+
     def forward(self, input_ids, attention_mask=None, head_mask=None):
         mixed_query_layer = self.query(input_ids)
         mixed_key_layer = self.key(input_ids)
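To make the mask/index bookkeeping above concrete, here is a small standalone sketch (not part of the commit) of the same computation with hypothetical sizes: 4 heads of size 2, pruning heads 1 and 3. The surviving flat positions are the rows kept in query/key/value and the columns kept in dense.

    import torch

    num_attention_heads, attention_head_size = 4, 2  # hypothetical sizes
    heads = {1, 3}                                    # heads to prune

    mask = torch.ones(num_attention_heads, attention_head_size)
    for head in heads:
        mask[head] = 0                                # zero out the pruned heads' slots
    mask = mask.view(-1).contiguous().eq(1)           # flatten to a boolean keep-mask
    index = torch.arange(len(mask))[mask].long()      # flat positions that survive pruning

    print(index)  # tensor([0, 1, 4, 5])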
...
...
@@ -409,6 +432,25 @@ class AlbertModel(AlbertPreTrainedModel):
         self.embeddings.word_embeddings = new_embeddings
         return self.embeddings.word_embeddings
 
+    def _prune_heads(self, heads_to_prune):
+        """ Prunes heads of the model.
+            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+            ALBERT has a different architecture in that its layers are shared across groups, which then have inner groups.
+            If an ALBERT model has 12 hidden layers and 2 hidden groups, with two inner groups, there
+            is a total of 4 different layers.
+
+            These layers are flattened: the indices [0,1] correspond to the two inner groups of the first hidden group,
+            while [2,3] correspond to the two inner groups of the second hidden group.
+
+            Any layer with an index other than [0,1,2,3] will result in an error.
+            See base class PreTrainedModel for more information about head pruning.
+        """
+        for layer, heads in heads_to_prune.items():
+            group_idx = int(layer / self.config.inner_group_num)
+            inner_group_idx = int(layer - group_idx * self.config.inner_group_num)
+            self.encoder.albert_layer_groups[group_idx].albert_layers[inner_group_idx].attention.prune_heads(heads)
+
     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
...
...
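The group/inner-group arithmetic in _prune_heads can be sanity-checked with a short standalone sketch (not part of the commit), assuming config.inner_group_num = 2 as in the docstring's 2-group, 2-inner-group example:

    inner_group_num = 2  # hypothetical value, matching the docstring's example

    for layer in [0, 1, 2, 3]:
        group_idx = int(layer / inner_group_num)
        inner_group_idx = int(layer - group_idx * inner_group_num)
        print(layer, "->", group_idx, inner_group_idx)

    # 0 -> 0 0   (first group, first inner layer)
    # 1 -> 0 1   (first group, second inner layer)
    # 2 -> 1 0   (second group, first inner layer)
    # 3 -> 1 1   (second group, second inner layer)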
transformers/tests/modeling_albert_test.py
...
...
@@ -35,7 +35,6 @@ else:
 class AlbertModelTest(CommonTestCases.CommonModelTester):
 
     all_model_classes = (AlbertModel, AlbertForMaskedLM) if is_torch_available() else ()
-    test_pruning = False
     test_head_masking = False
 
     class AlbertModelTester(object):
...
...
@@ -49,9 +48,10 @@ class AlbertModelTest(CommonTestCases.CommonModelTester):
                      use_token_type_ids=True,
                      use_labels=True,
                      vocab_size=99,
-                     hidden_size=32,
-                     num_hidden_layers=5,
-                     num_attention_heads=4,
+                     hidden_size=36,
+                     num_hidden_layers=6,
+                     num_hidden_groups=6,
+                     num_attention_heads=6,
                      intermediate_size=37,
                      hidden_act="gelu",
                      hidden_dropout_prob=0.1,
...
...
@@ -86,6 +86,7 @@ class AlbertModelTest(CommonTestCases.CommonModelTester):
             self.num_labels = num_labels
             self.num_choices = num_choices
             self.scope = scope
+            self.num_hidden_groups = num_hidden_groups
 
         def prepare_config_and_inputs(self):
             input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
...
...
@@ -117,7 +118,8 @@ class AlbertModelTest(CommonTestCases.CommonModelTester):
                 attention_probs_dropout_prob=self.attention_probs_dropout_prob,
                 max_position_embeddings=self.max_position_embeddings,
                 type_vocab_size=self.type_vocab_size,
-                initializer_range=self.initializer_range)
+                initializer_range=self.initializer_range,
+                num_hidden_groups=self.num_hidden_groups)
 
             return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
...
...