chenpangpang/transformers · Commit 5d63ca6c (unverified)

[ctrl] fix pruning of MultiHeadAttention (#4904)

Authored Jun 10, 2020 by Amil Khare; committed by GitHub on Jun 10, 2020
Parent: 4e10acb3
Showing 2 changed files with 21 additions and 3 deletions (+21 -3):

- src/transformers/modeling_ctrl.py (+20 -2)
- tests/test_modeling_ctrl.py (+1 -1)
src/transformers/modeling_ctrl.py
```diff
@@ -25,7 +25,7 @@ from torch.nn import CrossEntropyLoss

 from .configuration_ctrl import CTRLConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_utils import Conv1D, PreTrainedModel
+from .modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer


 logger = logging.getLogger(__name__)
```
```diff
@@ -95,6 +95,24 @@ class MultiHeadAttention(torch.nn.Module):
         self.Wv = torch.nn.Linear(d_model_size, d_model_size)

         self.dense = torch.nn.Linear(d_model_size, d_model_size)
+        self.pruned_heads = set()
+
+    def prune_heads(self, heads):
+        attention_head_size = self.d_model_size // self.num_heads
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, attention_head_size, self.pruned_heads)
+
+        # Prune linear layers
+        self.Wq = prune_linear_layer(self.Wq, index)
+        self.Wk = prune_linear_layer(self.Wk, index)
+        self.Wv = prune_linear_layer(self.Wv, index)
+        self.dense = prune_linear_layer(self.dense, index, dim=1)
+
+        # Update hyper params
+        self.num_heads = self.num_heads - len(heads)
+        self.d_model_size = attention_head_size * self.num_heads
+        self.pruned_heads = self.pruned_heads.union(heads)

     def split_into_heads(self, x, batch_size):
         x = x.reshape(batch_size, -1, self.num_heads, self.depth)
```
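For reference, a minimal sketch of the index arithmetic the new method relies on: `find_pruneable_heads_and_indices` builds a mask over the flattened `num_heads * head_size` feature dimension, and `prune_linear_layer` keeps only the selected rows (or input columns, with `dim=1`). The toy sizes below are assumptions for illustration, not values from CTRL.

```python
import torch

# Assumed toy sizes for illustration: 4 heads of size 2 (d_model_size = 8).
num_heads, head_size = 4, 2
heads = {1}  # head to prune

# Mirrors the index computation inside find_pruneable_heads_and_indices:
# zero out the pruned head in a (num_heads, head_size) mask, then keep the
# surviving positions of the flattened feature dimension.
mask = torch.ones(num_heads, head_size)
for head in heads:
    mask[head] = 0
index = torch.arange(num_heads * head_size)[mask.flatten().bool()]
print(index)  # tensor([0, 1, 4, 5, 6, 7]) -> rows kept in Wq, Wk, Wv
```

`prune_linear_layer` then slices the layer's weight (and bias) down to those indices; `dense` is pruned with `dim=1` because its input is the concatenation of the surviving head outputs, while its output size stays `d_model_size`.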
```diff
@@ -306,7 +324,7 @@ class CTRLModel(CTRLPreTrainedModel):
             heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
         """
         for layer, heads in heads_to_prune.items():
-            self.h[layer].attn.prune_heads(heads)
+            self.h[layer].multi_head_attention.prune_heads(heads)

     @add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
     def forward(
```
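A quick usage sketch of the now-working path. The CTRL block stores its attention under `multi_head_attention` (which is why the old `attn` lookup failed); the config sizes here are assumed small values for illustration, not the released checkpoint:

```python
from transformers import CTRLConfig, CTRLModel

# Assumed small config for illustration.
config = CTRLConfig(n_layer=2, n_head=4, n_embd=64, dff=128)
model = CTRLModel(config)

# Prune heads 0 and 2 in layer 0; prune_heads on PreTrainedModel dispatches to
# the per-layer MultiHeadAttention.prune_heads added in this commit.
model.prune_heads({0: [0, 2]})

attn = model.h[0].multi_head_attention
print(attn.num_heads)        # 2 heads remain
print(attn.Wq.out_features)  # 32 = 2 heads * 16-dim head size
```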
tests/test_modeling_ctrl.py
```diff
@@ -32,7 +32,7 @@ class CTRLModelTest(ModelTesterMixin, unittest.TestCase):

     all_model_classes = (CTRLModel, CTRLLMHeadModel) if is_torch_available() else ()
     all_generative_model_classes = (CTRLLMHeadModel,) if is_torch_available() else ()
-    test_pruning = False
+    test_pruning = True
     test_torchscript = False
     test_resize_embeddings = False
     test_head_masking = False
```
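Flipping `test_pruning` to `True` re-enables the shared pruning tests from `ModelTesterMixin` for CTRL. A rough standalone sketch of what those tests exercise (assumed toy config; the mixin covers more cases, including pruning configured at load time):

```python
import torch
from transformers import CTRLConfig, CTRLModel

# Assumed toy config for illustration.
config = CTRLConfig(n_layer=2, n_head=4, n_embd=64, dff=128)
model = CTRLModel(config)
model.eval()

model.prune_heads({0: [0, 1], 1: [3]})
input_ids = torch.randint(0, config.vocab_size, (1, 8))
with torch.no_grad():
    outputs = model(input_ids)  # forward still runs with pruned heads

assert model.h[0].multi_head_attention.num_heads == 2
assert model.h[1].multi_head_attention.num_heads == 3
```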