OpenDAS / nni · Commits · 445bbfd2

Commit 445bbfd2 (unverified), authored Aug 04, 2020 by Ningxin Zheng, committed via GitHub on Aug 04, 2020.

Speedup enhancement (#2719)
Parent: 41312de5

Changes: 3 changed files, with 29 additions and 12 deletions (+29 -12)

src/sdk/pynni/nni/compression/torch/speedup/compressor.py   (+8 -0)
src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py  (+18 -9)
src/sdk/pynni/tests/test_model_speedup.py                   (+3 -3)
src/sdk/pynni/nni/compression/torch/speedup/compressor.py
@@ -141,6 +141,14 @@ class ModelSpeedup:
         """
         for module_name, mask in self.masks.items():
             _logger.debug('Start mask inference from %s', module_name)
+            if module_name not in self.torch_graph.name_to_node:
+                # this module is not traced in the torch_graph,
+                # jit.trace only correctly records functions and
+                # modules which are not data dependent (e.g., do
+                # not have conditionals on data in tensors)
+                # so, if a node is not traced, we just skip it.
+                _logger.warning('%s has mask, but not found in the traced graph, just skip it.', module_name)
+                continue
             self.infer_module_mask(module_name, None, mask=mask)

     def replace_compressed_modules(self):
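Why the new guard is needed: torch.jit.trace records only the operations executed for the given example input, so a module whose forward pass branches on tensor values can be missing from (or misrepresented in) the traced graph even though a mask exists for it. A minimal, hypothetical illustration (DataDependentBlock is not part of this commit):

import torch
import torch.nn as nn

class DataDependentBlock(nn.Module):
    """Toy module whose control flow depends on tensor *values*."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x):
        # jit.trace only records the branch taken for the example input,
        # so the resulting graph does not faithfully represent this module.
        if x.sum() > 0:
            return self.conv(x)
        return x

traced = torch.jit.trace(DataDependentBlock(), torch.randn(1, 3, 8, 8))
# PyTorch emits a TracerWarning here; the traced graph hard-codes one branch.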
src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py
@@ -222,6 +222,7 @@ infer_from_inshape = {
     'ReLU': lambda module_masks, mask: relu_inshape(module_masks, mask),
     'ReLU6': lambda module_masks, mask: relu_inshape(module_masks, mask),
     'aten::relu': lambda module_masks, mask: relu_inshape(module_masks, mask),
+    'aten::relu_': lambda module_masks, mask: relu_inshape(module_masks, mask),
     'Conv2d': lambda module_masks, mask: conv2d_inshape(module_masks, mask),
     'MaxPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
     'aten::max_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
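infer_from_inshape is a dispatch table keyed by module type or aten op name: during mask propagation, the handler for each visited node is looked up and called with that node's masks and the incoming coarse mask. A simplified sketch of such a lookup (propagate_in_mask is a hypothetical helper, not NNI's actual code path):

# Simplified, hypothetical dispatch helper -- not NNI's internals.
def propagate_in_mask(op_type, module_masks, mask, table=None):
    table = table if table is not None else infer_from_inshape
    if op_type not in table:
        raise RuntimeError('unsupported op type for input-shape inference: %s' % op_type)
    # e.g. table['aten::relu_'] resolves to the relu_inshape handler added above;
    # some entries ('aten::cat', 'aten::mean') take extra arguments, omitted here.
    return table[op_type](module_masks, mask)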
@@ -241,7 +242,8 @@ infer_from_inshape = {
     'aten::cat': lambda module_mask, mask, cat_info, last_visited: cat_inshape(module_mask, mask, cat_info, last_visited),
     'aten::mean': lambda module_masks, mask, shape: mean_inshape(module_masks, mask, shape),
     'Dropout': lambda module_masks, mask: dropout_inshape(module_masks, mask),
-    'Dropout2d': lambda module_masks, mask: dropout_inshape(module_masks, mask)
+    'Dropout2d': lambda module_masks, mask: dropout_inshape(module_masks, mask),
+    'aten::dropout': lambda module_masks, mask: dropout_inshape(module_masks, mask)
 }
 """
@@ -258,8 +260,14 @@ def dropout_inshape(module_masks, mask):
         return module_masks.output_mask
     # if already visited
     assert module_masks.input_mask <= mask
-    if module_masks.input_mask == mask:
-        return None
+    # It should be the same: we pass the masks by reference (not by value),
+    # so they are actually two references to the same object (mask and
+    # module_masks.input_mask). Therefore we should keep passing the mask
+    # to the following nodes even when module_masks.input_mask == mask.
+    # If the mask were passed via copy.deepcopy(), then we could stop when
+    # module_masks.input_mask == mask.
+    # if module_masks.input_mask == mask:
+    #     return None
     module_masks.set_input_mask(mask)
     module_masks.set_output_mask(mask)
     return module_masks.output_mask
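The comment above hinges on object identity: because the incoming mask object itself gets stored as module_masks.input_mask, an equality check cannot distinguish "already propagated through this path" from "just stored the incoming reference". A tiny, library-agnostic illustration:

import copy

incoming = {'pruned_channels': [1, 3]}    # stand-in for a mask object
stored = incoming                         # stored by reference, as in the code above
assert stored == incoming and stored is incoming    # equal *because* identical

stored_copy = copy.deepcopy(incoming)     # only a deep copy makes '==' a useful stop test
assert stored_copy == incoming and stored_copy is not incoming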
@@ -413,7 +421,8 @@ def linear_inshape(module_masks, mask):
     """
     assert isinstance(mask, CoarseMask)
     assert mask.mask_index[0] is None
-    assert module_masks.input_mask is None
+    if module_masks.input_mask is not None:
+        assert module_masks.input_mask <= mask
     module_masks.set_input_mask(mask)
     return None
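This relaxation (and the matching ones in view_inshape and relu_inshape below) tolerates revisiting a node, asserting only that the previously stored input mask is consistent with the incoming one. Assuming CoarseMask's `<=` behaves roughly like a containment test on the retained indices, a toy stand-in for such an ordering looks like:

# Toy stand-in for a coarse-mask ordering; NOT NNI's CoarseMask implementation.
class ToyMask:
    def __init__(self, kept_channels):
        self.kept = frozenset(kept_channels)

    def __le__(self, other):
        # one mask is "below" another if it retains a subset of the channels
        return self.kept <= other.kept

    def __eq__(self, other):
        return self.kept == other.kept

stored = ToyMask([0, 2])        # mask recorded on a previous visit
incoming = ToyMask([0, 2, 5])   # mask arriving via another path (e.g. through a cat)
assert stored <= incoming       # compatible revisit: the assertion above would pass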
@@ -451,7 +460,10 @@ def view_inshape(module_masks, mask, shape):
     assert mask.mask_index[0] is None
     assert mask.mask_index[2] is None
     assert mask.mask_index[3] is None
-    assert module_masks.input_mask is None
+    # due to the cat operation, the same node may be
+    # accessed more than once
+    if module_masks.input_mask is not None:
+        assert module_masks.input_mask <= mask
     module_masks.set_input_mask(mask)
     output_cmask = CoarseMask(num_dim=2)
     index = []
@@ -535,12 +547,9 @@ def relu_inshape(module_masks, mask):
         The mask of its output tensor
     """
     assert isinstance(mask, CoarseMask)
-    # TODO: double check this assert, is it possible that a module is passed twice
     if module_masks.input_mask is not None:
         # check if has a mask conflict
-        assert module_masks.input_mask == mask
+        assert module_masks.input_mask <= mask
+        # No need to pass the mask again
+        return None
     # assert module_masks.input_mask is None, "A relu op can only be processed once"
     module_masks.set_input_mask(mask)
     module_masks.set_output_mask(mask)
src/sdk/pynni/tests/test_model_speedup.py
@@ -145,18 +145,18 @@ class SpeedupTestCase(TestCase):
         assert model.backbone2.fc1.in_features == int(orig_model.backbone2.fc1.in_features * SPARSITY)

     def test_speedup_integration(self):
-        for model_name in ['resnet18', 'squeezenet1_1', 'mobilenet_v2']:
+        for model_name in ['resnet18', 'squeezenet1_1', 'mobilenet_v2', 'densenet121', 'inception_v3']:
             Model = getattr(models, model_name)
             net = Model(pretrained=True, progress=False).to(device)
-            speedup_model = Model().to(device)
             net.eval() # this line is necessary
-            speedup_model.eval()
             # random generate the prune config for the pruner
             cfgs = generate_random_sparsity(net)
             pruner = L1FilterPruner(net, cfgs)
             pruner.compress()
             pruner.export_model(MODEL_FILE, MASK_FILE)
             pruner._unwrap_model()
+            speedup_model = Model().to(device)
+            speedup_model.eval()
             state_dict = torch.load(MODEL_FILE)
             speedup_model.load_state_dict(state_dict)
             zero_bn_bias(net)
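For readers unfamiliar with the flow this test exercises, the sketch below walks the typical prune → export → speed-up loop as documented for NNI 1.x (nni.compression.torch); the fixed resnet18 choice, the file names, and the single-entry config list are illustrative, not taken from the test itself.

import torch
import torchvision.models as models
from nni.compression.torch import L1FilterPruner, ModelSpeedup

device = torch.device('cpu')
net = models.resnet18(pretrained=True, progress=False).to(device)
net.eval()

# prune, then export the pruned weights and the binary masks
cfgs = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]   # illustrative config
pruner = L1FilterPruner(net, cfgs)
pruner.compress()
pruner.export_model('model.pth', 'mask.pth')
pruner._unwrap_model()

# build a fresh model, load the pruned weights, then shrink it in place
speedup_model = models.resnet18().to(device)
speedup_model.eval()
speedup_model.load_state_dict(torch.load('model.pth'))
dummy_input = torch.randn(1, 3, 224, 224).to(device)
ModelSpeedup(speedup_model, dummy_input, 'mask.pth').speedup_model()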