Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
47c7ea14
Unverified
Commit
47c7ea14
authored
Jul 15, 2021
by
Ningxin Zheng
Committed by
GitHub
Jul 15, 2021
Browse files
Add several speedup examples (#3880)
parent
5fe24500
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
109 additions
and
2 deletions
+109
-2
examples/model_compress/pruning/speedup/model_speedup.py
examples/model_compress/pruning/speedup/model_speedup.py
+1
-1
examples/model_compress/pruning/speedup/speedup_mobilnetv2.py
...ples/model_compress/pruning/speedup/speedup_mobilnetv2.py
+21
-0
examples/model_compress/pruning/speedup/speedup_nanodet.py
examples/model_compress/pruning/speedup/speedup_nanodet.py
+39
-0
examples/model_compress/pruning/speedup/speedup_yolov3.py
examples/model_compress/pruning/speedup/speedup_yolov3.py
+36
-0
nni/compression/pytorch/utils/utils.py
nni/compression/pytorch/utils/utils.py
+8
-1
test/ut/sdk/test_model_speedup.py
test/ut/sdk/test_model_speedup.py
+4
-0
No files found.
examples/model_compress/pruning/model_speedup.py
→
examples/model_compress/pruning/
speedup/
model_speedup.py
View file @
47c7ea14
...
...
@@ -7,7 +7,7 @@ import torch.nn.functional as F
from
torchvision
import
datasets
,
transforms
import
sys
sys
.
path
.
append
(
'../models'
)
sys
.
path
.
append
(
'../
../
models'
)
from
cifar10.vgg
import
VGG
from
mnist.lenet
import
LeNet
...
...
examples/model_compress/pruning/speedup/speedup_mobilnetv2.py
0 → 100644
View file @
47c7ea14
"""Prune MobileNetV2 with L1FilterPruner and shrink it with ModelSpeedup."""
import torch
from torchvision.models import mobilenet_v2
from nni.compression.pytorch import ModelSpeedup
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner

# Load a pretrained MobileNetV2 and a random batch used for tracing
# and for the smoke-test inference at the end.
net = mobilenet_v2(pretrained=True)
example = torch.rand(8, 3, 416, 416)

# Prune every Conv2d layer to 50% sparsity and export the weight masks.
config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.5}]
pruner = L1FilterPruner(net, config_list)
pruner.compress()
pruner.export_model('./model', './mask')

# The pruner wraps the model's modules; unwrap them before running
# speedup on the same model instance.
pruner._unwrap_model()

# Physically remove the masked filters, then run one forward pass as a check.
speedup = ModelSpeedup(net, example, './mask')
speedup.speedup_model()
net(example)
\ No newline at end of file
examples/model_compress/pruning/speedup/speedup_nanodet.py
0 → 100644
View file @
47c7ea14
"""Prune a NanoDet detector with L1FilterPruner and shrink it with ModelSpeedup.

The NanoDet model can be installed from https://github.com/RangiLyu/nanodet.git
"""
import torch
from nanodet.model.arch import build_model
from nanodet.util import cfg, load_config
from nni.compression.pytorch import ModelSpeedup
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner

# Build the detector from its YAML configuration.
config_path = r"nanodet/config/nanodet-RepVGG-A0_416.yml"
load_config(cfg, config_path)
detector = build_model(cfg.model)

# Random batch used both for tracing and for the smoke-test inference below.
example = torch.rand(8, 3, 416, 416)

# These three conv layers feed reshape-like functions that speedup cannot
# replace, so they are excluded from pruning; such layers can also be found
# with the `not_safe_to_prune` helper.
skip_layers = {'head.gfl_cls.0', 'head.gfl_cls.1', 'head.gfl_cls.2'}
op_names = [
    layer_name
    for layer_name, layer in detector.named_modules()
    if isinstance(layer, torch.nn.Conv2d) and layer_name not in skip_layers
]

# Prune the selected Conv2d layers to 50% sparsity and export the masks.
config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.5, 'op_names': op_names}]
pruner = L1FilterPruner(detector, config_list)
pruner.compress()
pruner.export_model('./model', './mask')

# The pruner wraps the model's modules; unwrap them before running
# speedup on the same model instance.
pruner._unwrap_model()

# Physically remove the masked filters, then run one forward pass as a check.
speedup = ModelSpeedup(detector, example, './mask')
speedup.speedup_model()
detector(example)
\ No newline at end of file
examples/model_compress/pruning/speedup/speedup_yolov3.py
0 → 100644
View file @
47c7ea14
"""Prune YOLOv3 with L1FilterPruner and shrink it with ModelSpeedup.

The YOLOv3 implementation can be downloaded from
https://github.com/eriklindernoren/PyTorch-YOLOv3.git
"""
import torch
from pytorchyolo import models
from nni.compression.pytorch import ModelSpeedup
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner, LevelPruner
from nni.compression.pytorch.utils import not_safe_to_prune

# Root of the PyTorch-YOLOv3 checkout; replace this path with yours.
prefix = '/home/user/PyTorch-YOLOv3'

# Load the pretrained YOLO model and run one tracing pass on a random batch.
net = models.load_model("%s/config/yolov3.cfg" % prefix, "%s/yolov3.weights" % prefix)
net.eval()
example = torch.rand(8, 3, 320, 320)
net(example)

# Layers followed by shape-dependent functions may break after pruning;
# `not_safe_to_prune` reports them so they can be skipped.
unsafe = not_safe_to_prune(net, example)

# One config entry per prunable Conv2d layer, each pruned to 60% sparsity.
config_list = [
    {'op_types': ['Conv2d'], 'sparsity': 0.6, 'op_names': [layer_name]}
    for layer_name, layer in net.named_modules()
    if layer_name not in unsafe and isinstance(layer, torch.nn.Conv2d)
]

# Prune the model and export the weight masks.
pruner = L1FilterPruner(net, config_list)
pruner.compress()
pruner.export_model('./model', './mask')

# The pruner wraps the model's modules; unwrap them before running
# speedup on the same model instance.
pruner._unwrap_model()

# Physically remove the masked filters, then run one forward pass as a check.
speedup = ModelSpeedup(net, example, './mask')
speedup.speedup_model()
net(example)
nni/compression/pytorch/utils/utils.py
View file @
47c7ea14
...
...
@@ -70,7 +70,14 @@ def randomize_tensor(tensor, start=1, end=100):
def
not_safe_to_prune
(
model
,
dummy_input
):
"""
Get the layers that are safe to prune(will not bring the shape conflict).
Get the layers that are not safe to prune(may bring the shape conflict).
For example, if the output tensor of a conv layer is directly followed by
a shape-dependent function(such as reshape/view), then this conv layer
may be not safe to be pruned. Pruning may change the output shape of
this conv layer and result in shape problems. This function finds all the
layers that are directly followed by shape-dependent functions (view, reshape, etc.).
If you run the inference after the speedup and run into a shape related error,
please exclude the layers returned by this function and try again.
Parameters
----------
...
...
test/ut/sdk/test_model_speedup.py
View file @
47c7ea14
...
...
@@ -361,6 +361,10 @@ class SpeedupTestCase(TestCase):
self
.
speedup_integration
(
model_list
)
def
speedup_integration
(
self
,
model_list
,
speedup_cfg
=
None
):
# Note: hack trick, may be updated in the future
if
'win'
in
sys
.
platform
or
'Win'
in
sys
.
platform
:
print
(
'Skip test_speedup_integration on windows due to memory limit!'
)
return
Gen_cfg_funcs
=
[
generate_random_sparsity
,
generate_random_sparsity_v2
]
# for model_name in ['vgg16', 'resnet18', 'mobilenet_v2', 'squeezenet1_1', 'densenet121',
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment