Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
bd7edf36
Unverified
Commit
bd7edf36
authored
May 19, 2020
by
chicm-ms
Committed by
GitHub
May 19, 2020
Browse files
fix speedup issue (#2447)
parent
f8627a2f
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
27 additions
and
7 deletions
+27
-7
src/sdk/pynni/nni/compression/speedup/torch/compressor.py
src/sdk/pynni/nni/compression/speedup/torch/compressor.py
+2
-0
src/sdk/pynni/tests/test_model_speedup.py
src/sdk/pynni/tests/test_model_speedup.py
+25
-7
No files found.
src/sdk/pynni/nni/compression/speedup/torch/compressor.py
View file @
bd7edf36
...
...
@@ -163,9 +163,11 @@ class ModelSpeedup:
first, do mask/shape inference,
second, replace modules
"""
training
=
self
.
bound_model
.
training
_logger
.
info
(
"start to speed up the model"
)
_logger
.
info
(
"infer module masks..."
)
self
.
infer_modules_masks
()
_logger
.
info
(
"replace compressed modules..."
)
self
.
replace_compressed_modules
()
self
.
bound_model
.
train
(
training
)
_logger
.
info
(
"speedup done"
)
src/sdk/pynni/tests/test_model_speedup.py
View file @
bd7edf36
...
...
@@ -10,9 +10,11 @@ from torchvision.models.vgg import vgg16
from
torchvision.models.resnet
import
resnet18
from
unittest
import
TestCase
,
main
from
nni.compression.torch
import
L1FilterPruner
from
nni.compression.torch
import
L1FilterPruner
,
apply_compression_results
from
nni.compression.speedup.torch
import
ModelSpeedup
torch
.
manual_seed
(
0
)
class
BackboneModel1
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
...
...
@@ -58,7 +60,10 @@ class BigModel(torch.nn.Module):
x
=
self
.
fc3
(
x
)
return
x
dummy_input
=
torch
.
randn
(
2
,
1
,
28
,
28
)
SPARSITY
=
0.5
MODEL_FILE
,
MASK_FILE
=
'./11_model.pth'
,
'./l1_mask.pth'
def
prune_model_l1
(
model
):
config_list
=
[{
'sparsity'
:
SPARSITY
,
...
...
@@ -66,14 +71,14 @@ def prune_model_l1(model):
}]
pruner
=
L1FilterPruner
(
model
,
config_list
)
pruner
.
compress
()
pruner
.
export_model
(
model_path
=
'./11_model.pth'
,
mask_path
=
'./l1_mask.pth'
)
pruner
.
export_model
(
model_path
=
MODEL_FILE
,
mask_path
=
MASK_FILE
)
class
SpeedupTestCase
(
TestCase
):
def
test_speedup_vgg16
(
self
):
prune_model_l1
(
vgg16
())
model
=
vgg16
()
model
.
train
()
ms
=
ModelSpeedup
(
model
,
torch
.
randn
(
2
,
3
,
32
,
32
),
'./l1_mask.pth'
)
ms
=
ModelSpeedup
(
model
,
torch
.
randn
(
2
,
3
,
32
,
32
),
MASK_FILE
)
ms
.
speedup_model
()
orig_model
=
vgg16
()
...
...
@@ -88,20 +93,33 @@ class SpeedupTestCase(TestCase):
def
test_speedup_bigmodel
(
self
):
prune_model_l1
(
BigModel
())
model
=
BigModel
()
apply_compression_results
(
model
,
MASK_FILE
,
'cpu'
)
model
.
eval
()
mask_out
=
model
(
dummy_input
)
model
.
train
()
ms
=
ModelSpeedup
(
model
,
torch
.
randn
(
2
,
1
,
28
,
28
),
'./l1_mask.pth'
)
ms
=
ModelSpeedup
(
model
,
dummy_input
,
MASK_FILE
)
ms
.
speedup_model
()
assert
model
.
training
model
.
eval
()
speedup_out
=
model
(
dummy_input
)
if
not
torch
.
allclose
(
mask_out
,
speedup_out
,
atol
=
1e-07
):
print
(
'input:'
,
dummy_input
.
size
(),
torch
.
abs
(
dummy_input
).
sum
((
2
,
3
)))
print
(
'mask_out:'
,
mask_out
)
print
(
'speedup_out:'
,
speedup_out
)
raise
RuntimeError
(
'model speedup inference result is incorrect!'
)
orig_model
=
BigModel
()
assert
model
.
training
assert
model
.
backbone2
.
conv1
.
out_channels
==
int
(
orig_model
.
backbone2
.
conv1
.
out_channels
*
SPARSITY
)
assert
model
.
backbone2
.
conv2
.
in_channels
==
int
(
orig_model
.
backbone2
.
conv2
.
in_channels
*
SPARSITY
)
assert
model
.
backbone2
.
conv2
.
out_channels
==
int
(
orig_model
.
backbone2
.
conv2
.
out_channels
*
SPARSITY
)
assert
model
.
backbone2
.
fc1
.
in_features
==
int
(
orig_model
.
backbone2
.
fc1
.
in_features
*
SPARSITY
)
def
tearDown
(
self
):
os
.
remove
(
'./11_model.pth'
)
os
.
remove
(
'./l1_mask.pth'
)
os
.
remove
(
MODEL_FILE
)
os
.
remove
(
MASK_FILE
)
if
__name__
==
'__main__'
:
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment