Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
445bbfd2
Unverified
Commit
445bbfd2
authored
Aug 04, 2020
by
Ningxin Zheng
Committed by
GitHub
Aug 04, 2020
Browse files
Speedup enhancement (#2719)
parent
41312de5
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
12 deletions
+29
-12
src/sdk/pynni/nni/compression/torch/speedup/compressor.py
src/sdk/pynni/nni/compression/torch/speedup/compressor.py
+8
-0
src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py
src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py
+18
-9
src/sdk/pynni/tests/test_model_speedup.py
src/sdk/pynni/tests/test_model_speedup.py
+3
-3
No files found.
src/sdk/pynni/nni/compression/torch/speedup/compressor.py
View file @
445bbfd2
...
@@ -141,6 +141,14 @@ class ModelSpeedup:
...
@@ -141,6 +141,14 @@ class ModelSpeedup:
"""
"""
for
module_name
,
mask
in
self
.
masks
.
items
():
for
module_name
,
mask
in
self
.
masks
.
items
():
_logger
.
debug
(
'Start mask inference from %s'
,
module_name
)
_logger
.
debug
(
'Start mask inference from %s'
,
module_name
)
if
module_name
not
in
self
.
torch_graph
.
name_to_node
:
# this module is not traced in the torch_graph,
# jit.trace only correctly records functions and
# modules which are not data dependent (e.g., do
# not have conditionals on data in tensors)
# so, if a node is not traced, we just skip it.
_logger
.
warning
(
'%s has mask, but not found in the traced graph, just skip it.'
,
module_name
)
continue
self
.
infer_module_mask
(
module_name
,
None
,
mask
=
mask
)
self
.
infer_module_mask
(
module_name
,
None
,
mask
=
mask
)
def
replace_compressed_modules
(
self
):
def
replace_compressed_modules
(
self
):
...
...
src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py
View file @
445bbfd2
...
@@ -222,6 +222,7 @@ infer_from_inshape = {
...
@@ -222,6 +222,7 @@ infer_from_inshape = {
'ReLU'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'ReLU'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'ReLU6'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'ReLU6'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::relu'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::relu'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::relu_'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'Conv2d'
:
lambda
module_masks
,
mask
:
conv2d_inshape
(
module_masks
,
mask
),
'Conv2d'
:
lambda
module_masks
,
mask
:
conv2d_inshape
(
module_masks
,
mask
),
'MaxPool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_inshape
(
module_masks
,
mask
),
'MaxPool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_inshape
(
module_masks
,
mask
),
'aten::max_pool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_inshape
(
module_masks
,
mask
),
'aten::max_pool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_inshape
(
module_masks
,
mask
),
...
@@ -241,7 +242,8 @@ infer_from_inshape = {
...
@@ -241,7 +242,8 @@ infer_from_inshape = {
'aten::cat'
:
lambda
module_mask
,
mask
,
cat_info
,
last_visited
:
cat_inshape
(
module_mask
,
mask
,
cat_info
,
last_visited
),
'aten::cat'
:
lambda
module_mask
,
mask
,
cat_info
,
last_visited
:
cat_inshape
(
module_mask
,
mask
,
cat_info
,
last_visited
),
'aten::mean'
:
lambda
module_masks
,
mask
,
shape
:
mean_inshape
(
module_masks
,
mask
,
shape
),
'aten::mean'
:
lambda
module_masks
,
mask
,
shape
:
mean_inshape
(
module_masks
,
mask
,
shape
),
'Dropout'
:
lambda
module_masks
,
mask
:
dropout_inshape
(
module_masks
,
mask
),
'Dropout'
:
lambda
module_masks
,
mask
:
dropout_inshape
(
module_masks
,
mask
),
'Dropout2d'
:
lambda
module_masks
,
mask
:
dropout_inshape
(
module_masks
,
mask
)
'Dropout2d'
:
lambda
module_masks
,
mask
:
dropout_inshape
(
module_masks
,
mask
),
'aten::dropout'
:
lambda
module_masks
,
mask
:
dropout_inshape
(
module_masks
,
mask
)
}
}
"""
"""
...
@@ -258,8 +260,14 @@ def dropout_inshape(module_masks, mask):
...
@@ -258,8 +260,14 @@ def dropout_inshape(module_masks, mask):
return
module_masks
.
output_mask
return
module_masks
.
output_mask
# if alreay visited
# if alreay visited
assert
module_masks
.
input_mask
<=
mask
assert
module_masks
.
input_mask
<=
mask
if
module_masks
.
input_mask
==
mask
:
# It should be the same, we pass the masks by the reference(not the value),
return
None
# so they acutually are two references of the same object(mask,
# module_masks.input_mask). So we should continue pass the mask
# to the following nodes even module_masks.input_mask == mask.
# if pass the mask by copy.deepcopy(), then we can stop when
# module_masks.input_mask == mask.
# if module_masks.input_mask == mask:
# return None
module_masks
.
set_input_mask
(
mask
)
module_masks
.
set_input_mask
(
mask
)
module_masks
.
set_output_mask
(
mask
)
module_masks
.
set_output_mask
(
mask
)
return
module_masks
.
output_mask
return
module_masks
.
output_mask
...
@@ -413,7 +421,8 @@ def linear_inshape(module_masks, mask):
...
@@ -413,7 +421,8 @@ def linear_inshape(module_masks, mask):
"""
"""
assert
isinstance
(
mask
,
CoarseMask
)
assert
isinstance
(
mask
,
CoarseMask
)
assert
mask
.
mask_index
[
0
]
is
None
assert
mask
.
mask_index
[
0
]
is
None
assert
module_masks
.
input_mask
is
None
if
module_masks
.
input_mask
is
not
None
:
assert
module_masks
.
input_mask
<=
mask
module_masks
.
set_input_mask
(
mask
)
module_masks
.
set_input_mask
(
mask
)
return
None
return
None
...
@@ -451,7 +460,10 @@ def view_inshape(module_masks, mask, shape):
...
@@ -451,7 +460,10 @@ def view_inshape(module_masks, mask, shape):
assert
mask
.
mask_index
[
0
]
is
None
assert
mask
.
mask_index
[
0
]
is
None
assert
mask
.
mask_index
[
2
]
is
None
assert
mask
.
mask_index
[
2
]
is
None
assert
mask
.
mask_index
[
3
]
is
None
assert
mask
.
mask_index
[
3
]
is
None
assert
module_masks
.
input_mask
is
None
# due to the cat operation, the same node may be
# accessed more than once
if
module_masks
.
input_mask
is
not
None
:
assert
module_masks
.
input_mask
<=
mask
module_masks
.
set_input_mask
(
mask
)
module_masks
.
set_input_mask
(
mask
)
output_cmask
=
CoarseMask
(
num_dim
=
2
)
output_cmask
=
CoarseMask
(
num_dim
=
2
)
index
=
[]
index
=
[]
...
@@ -535,12 +547,9 @@ def relu_inshape(module_masks, mask):
...
@@ -535,12 +547,9 @@ def relu_inshape(module_masks, mask):
The mask of its output tensor
The mask of its output tensor
"""
"""
assert
isinstance
(
mask
,
CoarseMask
)
assert
isinstance
(
mask
,
CoarseMask
)
# TODO: double check this assert, is it possible that a module is passed twice
if
module_masks
.
input_mask
is
not
None
:
if
module_masks
.
input_mask
is
not
None
:
# check if has a mask conflict
# check if has a mask conflict
assert
module_masks
.
input_mask
==
mask
assert
module_masks
.
input_mask
<=
mask
# No need to pass the mask again
return
None
# assert module_masks.input_mask is None, "A relu op can only be processed once"
# assert module_masks.input_mask is None, "A relu op can only be processed once"
module_masks
.
set_input_mask
(
mask
)
module_masks
.
set_input_mask
(
mask
)
module_masks
.
set_output_mask
(
mask
)
module_masks
.
set_output_mask
(
mask
)
...
...
src/sdk/pynni/tests/test_model_speedup.py
View file @
445bbfd2
...
@@ -145,18 +145,18 @@ class SpeedupTestCase(TestCase):
...
@@ -145,18 +145,18 @@ class SpeedupTestCase(TestCase):
assert
model
.
backbone2
.
fc1
.
in_features
==
int
(
orig_model
.
backbone2
.
fc1
.
in_features
*
SPARSITY
)
assert
model
.
backbone2
.
fc1
.
in_features
==
int
(
orig_model
.
backbone2
.
fc1
.
in_features
*
SPARSITY
)
def
test_speedup_integration
(
self
):
def
test_speedup_integration
(
self
):
for
model_name
in
[
'resnet18'
,
'squeezenet1_1'
,
'mobilenet_v2'
]:
for
model_name
in
[
'resnet18'
,
'squeezenet1_1'
,
'mobilenet_v2'
,
'densenet121'
,
'inception_v3'
]:
Model
=
getattr
(
models
,
model_name
)
Model
=
getattr
(
models
,
model_name
)
net
=
Model
(
pretrained
=
True
,
progress
=
False
).
to
(
device
)
net
=
Model
(
pretrained
=
True
,
progress
=
False
).
to
(
device
)
speedup_model
=
Model
().
to
(
device
)
net
.
eval
()
# this line is necessary
net
.
eval
()
# this line is necessary
speedup_model
.
eval
()
# random generate the prune config for the pruner
# random generate the prune config for the pruner
cfgs
=
generate_random_sparsity
(
net
)
cfgs
=
generate_random_sparsity
(
net
)
pruner
=
L1FilterPruner
(
net
,
cfgs
)
pruner
=
L1FilterPruner
(
net
,
cfgs
)
pruner
.
compress
()
pruner
.
compress
()
pruner
.
export_model
(
MODEL_FILE
,
MASK_FILE
)
pruner
.
export_model
(
MODEL_FILE
,
MASK_FILE
)
pruner
.
_unwrap_model
()
pruner
.
_unwrap_model
()
speedup_model
=
Model
().
to
(
device
)
speedup_model
.
eval
()
state_dict
=
torch
.
load
(
MODEL_FILE
)
state_dict
=
torch
.
load
(
MODEL_FILE
)
speedup_model
.
load_state_dict
(
state_dict
)
speedup_model
.
load_state_dict
(
state_dict
)
zero_bn_bias
(
net
)
zero_bn_bias
(
net
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment