Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
286d5bb6
"...resnet50_tensorflow.git" did not exist on "580aa3f6c0864d778c822066cdccbba7f40da100"
Commit
286d5bb6
authored
Dec 20, 2019
by
Aymeric Augustin
Browse files
Use a random temp dir for writing pruned models in tests.
parent
478e456e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
15 deletions
+9
-15
transformers/tests/modeling_common_test.py
transformers/tests/modeling_common_test.py
+9
-15
No files found.
transformers/tests/modeling_common_test.py
View file @
286d5bb6
...
@@ -353,12 +353,11 @@ class CommonTestCases:
...
@@ -353,12 +353,11 @@ class CommonTestCases:
heads_to_prune
=
{
0
:
list
(
range
(
1
,
self
.
model_tester
.
num_attention_heads
)),
heads_to_prune
=
{
0
:
list
(
range
(
1
,
self
.
model_tester
.
num_attention_heads
)),
-
1
:
[
0
]}
-
1
:
[
0
]}
model
.
prune_heads
(
heads_to_prune
)
model
.
prune_heads
(
heads_to_prune
)
directory
=
"pruned_model"
if
not
os
.
path
.
exists
(
directory
):
with
TemporaryDirectory
()
as
temp_dir_name
:
os
.
makedirs
(
directory
)
model
.
save_pretrained
(
temp_dir_name
)
model
.
save_pretrained
(
directory
)
model
=
model_class
.
from_pretrained
(
temp_dir_name
)
model
=
model_class
.
from_pretrained
(
directory
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
outputs
=
model
(
**
inputs_dict
)
outputs
=
model
(
**
inputs_dict
)
...
@@ -367,7 +366,6 @@ class CommonTestCases:
...
@@ -367,7 +366,6 @@ class CommonTestCases:
self
.
assertEqual
(
attentions
[
1
].
shape
[
-
3
],
self
.
model_tester
.
num_attention_heads
)
self
.
assertEqual
(
attentions
[
1
].
shape
[
-
3
],
self
.
model_tester
.
num_attention_heads
)
self
.
assertEqual
(
attentions
[
-
1
].
shape
[
-
3
],
self
.
model_tester
.
num_attention_heads
-
1
)
self
.
assertEqual
(
attentions
[
-
1
].
shape
[
-
3
],
self
.
model_tester
.
num_attention_heads
-
1
)
shutil
.
rmtree
(
directory
)
def
test_head_pruning_save_load_from_config_init
(
self
):
def
test_head_pruning_save_load_from_config_init
(
self
):
if
not
self
.
test_pruning
:
if
not
self
.
test_pruning
:
...
@@ -427,14 +425,10 @@ class CommonTestCases:
...
@@ -427,14 +425,10 @@ class CommonTestCases:
self
.
assertEqual
(
attentions
[
2
].
shape
[
-
3
],
self
.
model_tester
.
num_attention_heads
)
self
.
assertEqual
(
attentions
[
2
].
shape
[
-
3
],
self
.
model_tester
.
num_attention_heads
)
self
.
assertEqual
(
attentions
[
3
].
shape
[
-
3
],
self
.
model_tester
.
num_attention_heads
)
self
.
assertEqual
(
attentions
[
3
].
shape
[
-
3
],
self
.
model_tester
.
num_attention_heads
)
directory
=
"pruned_model"
with
TemporaryDirectory
()
as
temp_dir_name
:
model
.
save_pretrained
(
temp_dir_name
)
if
not
os
.
path
.
exists
(
directory
):
model
=
model_class
.
from_pretrained
(
temp_dir_name
)
os
.
makedirs
(
directory
)
model
.
to
(
torch_device
)
model
.
save_pretrained
(
directory
)
model
=
model_class
.
from_pretrained
(
directory
)
model
.
to
(
torch_device
)
shutil
.
rmtree
(
directory
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
outputs
=
model
(
**
inputs_dict
)
outputs
=
model
(
**
inputs_dict
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment