Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
719cb373
Commit
719cb373
authored
Aug 21, 2019
by
LysandreJik
Browse files
Pruning for GPT and GPT-2
parent
fc1fbae4
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
24 additions
and
5 deletions
+24
-5
pytorch_transformers/modeling_gpt2.py
pytorch_transformers/modeling_gpt2.py
+6
-0
pytorch_transformers/modeling_openai.py
pytorch_transformers/modeling_openai.py
+6
-0
pytorch_transformers/tests/modeling_common_test.py
pytorch_transformers/tests/modeling_common_test.py
+12
-5
No files found.
pytorch_transformers/modeling_gpt2.py
View file @
719cb373
...
...
@@ -453,6 +453,12 @@ class GPT2Model(GPT2PreTrainedModel):
self
.
h
=
nn
.
ModuleList
([
Block
(
config
.
n_ctx
,
config
,
scale
=
True
)
for
_
in
range
(
config
.
n_layer
)])
self
.
ln_f
=
LayerNorm
(
config
.
n_embd
,
eps
=
config
.
layer_norm_epsilon
)
if
hasattr
(
config
,
"pruned_heads"
):
pruned_heads
=
config
.
pruned_heads
.
copy
().
items
()
for
layer
,
heads
in
pruned_heads
:
if
self
.
h
[
int
(
layer
)].
attn
.
n_head
==
config
.
n_head
:
self
.
prune_heads
({
int
(
layer
):
list
(
map
(
int
,
heads
))})
self
.
apply
(
self
.
init_weights
)
def
_resize_token_embeddings
(
self
,
new_num_tokens
):
...
...
pytorch_transformers/modeling_openai.py
View file @
719cb373
...
...
@@ -456,6 +456,12 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
self
.
drop
=
nn
.
Dropout
(
config
.
embd_pdrop
)
self
.
h
=
nn
.
ModuleList
([
Block
(
config
.
n_ctx
,
config
,
scale
=
True
)
for
_
in
range
(
config
.
n_layer
)])
if
hasattr
(
config
,
"pruned_heads"
):
pruned_heads
=
config
.
pruned_heads
.
copy
().
items
()
for
layer
,
heads
in
pruned_heads
:
if
self
.
h
[
int
(
layer
)].
attn
.
n_head
==
config
.
n_head
:
self
.
prune_heads
({
int
(
layer
):
list
(
map
(
int
,
heads
))})
self
.
apply
(
self
.
init_weights
)
def
_resize_token_embeddings
(
self
,
new_num_tokens
):
...
...
pytorch_transformers/tests/modeling_common_test.py
View file @
719cb373
...
...
@@ -213,13 +213,12 @@ class CommonTestCases:
if
not
self
.
test_pruning
:
return
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
if
"head_mask"
in
inputs_dict
:
del
inputs_dict
[
"head_mask"
]
for
model_class
in
self
.
all_model_classes
:
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
if
"head_mask"
in
inputs_dict
:
del
inputs_dict
[
"head_mask"
]
config
.
output_attentions
=
True
config
.
output_hidden_states
=
False
model
=
model_class
(
config
=
config
)
...
...
@@ -244,6 +243,10 @@ class CommonTestCases:
for
model_class
in
self
.
all_model_classes
:
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
if
"head_mask"
in
inputs_dict
:
del
inputs_dict
[
"head_mask"
]
config
.
output_attentions
=
True
config
.
output_hidden_states
=
False
model
=
model_class
(
config
=
config
)
...
...
@@ -274,6 +277,10 @@ class CommonTestCases:
for
model_class
in
self
.
all_model_classes
:
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
if
"head_mask"
in
inputs_dict
:
del
inputs_dict
[
"head_mask"
]
config
.
output_attentions
=
True
config
.
output_hidden_states
=
False
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment