OpenDAS / nni / Commits / 3836689f

Unverified commit 3836689f
Authored Mar 04, 2022 by Ningxin Zheng, committed by GitHub on Mar 04, 2022

issue 4540 (#4594)

Parent: 21abc280
Showing 3 changed files with 54 additions and 4 deletions (+54 -4):
nni/compression/pytorch/speedup/infer_mask.py (+7 -3)
nni/compression/pytorch/utils/shape_dependency.py (+7 -1)
test/ut/compression/v1/test_model_speedup.py (+40 -0)
nni/compression/pytorch/speedup/infer_mask.py
...
@@ -171,10 +171,14 @@ class AutoMaskInference:
             # apply the input mask
             for tid, in_tensor in enumerate(self.dummy_input):
                 if isinstance(in_tensor, torch.Tensor) and self.in_masks[tid] is not None:
-                    in_tensor.data = in_tensor.data * \
-                        self.in_masks[tid] + \
-                        (1-self.in_masks[tid]) * self.in_constants[tid]
+                    # in_tensor.data = in_tensor.data * \
+                    #     self.in_masks[tid] + \
+                    #     (1-self.in_masks[tid]) * self.in_constants[tid]
+                    # issue-4540: when two tensors are multiplied, the constant part makes
+                    # the propagation weaker and leads to shape misalignment. Currently we
+                    # do not support constant folding, so we just remove the constant term here.
+                    in_tensor.data = in_tensor.data * \
+                        self.in_masks[tid]
 
     def __apply_weight_mask(self):
         """
...
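To make the removed constant term concrete, here is a small standalone sketch (toy shapes and a hypothetical folded constant, not NNI internals): with the constant kept, a pruned channel turns into a non-zero constant, so after an elementwise multiplication no output channel is identically zero and the mask can no longer be inferred from the zero pattern; with the constant dropped, the zeros propagate through the multiplication.

import torch

# Toy setup: channel 1 of x is pruned (mask == 0); `constant` stands in for the
# hypothetical folded constant of that channel.
mask = torch.tensor([1., 0., 1., 1.]).view(1, 4, 1, 1)
constant = torch.full((1, 4, 1, 1), 0.5)
x = torch.randn(1, 4, 8, 8)
scale = torch.rand(1, 4, 1, 1)               # e.g. an SE-style attention tensor

# Previous behaviour: the pruned channel becomes a constant 0.5, so x_old * scale
# has no all-zero channel and the pruned channel cannot be recovered downstream.
x_old = x * mask + (1 - mask) * constant
print((x_old * scale)[0, 1].abs().sum())     # non-zero

# Behaviour after this commit: the pruned channel stays exactly zero, and the
# zero pattern survives the multiplication.
x_new = x * mask
print((x_new * scale)[0, 1].abs().sum())     # zero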
nni/compression/pytorch/utils/shape_dependency.py
...
@@ -163,7 +163,13 @@ class ChannelDependency(Dependency):
             parent_layers = []
             # find the node that contains aten::add
             # or aten::cat operations
-            if node.op_type in ADD_TYPES:
+            if node.op_type in ADD_TYPES or node.op_type in MUL_TYPES:
+                # refer to issue 4540 for more details. Multiplication itself does not
+                # introduce a channel dependency, because the misaligned channels can
+                # propagate to each other. However, when one of the input tensors comes
+                # from a skip connection (residual), the channel propagation may fail
+                # (that input is also used by another layer and cannot be pruned), and
+                # in this case we need to fix the conflict manually.
                 parent_layers = self._get_parent_layers(node)
             elif node.op_type == CAT_TYPE:
                 # To determine if this cat operation will introduce channel
...
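As a rough usage sketch of the dependency this change introduces (assuming the ChannelDependency constructor and its dependency_sets property behave as in the NNI version this commit targets; the TinySE module below is a made-up example, not from the repository):

import torch
from nni.compression.pytorch.utils.shape_dependency import ChannelDependency

class TinySE(torch.nn.Module):
    # Minimal SE-style block: the pooled branch is multiplied back onto x,
    # so the aten::mul node sees one residual-style input.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)
        self.pool = torch.nn.AdaptiveAvgPool2d(1)
        self.fc = torch.nn.Conv2d(8, 8, 1)

    def forward(self, x):
        x = self.conv(x)
        scale = torch.sigmoid(self.fc(self.pool(x)))
        return x * scale

model = TinySE().eval()
dep = ChannelDependency(model, torch.randn(1, 3, 64, 64))
for dep_set in dep.dependency_sets:
    print(dep_set)   # with this change, 'conv' and 'fc' are expected to share a set

Treating the multiplication like an addition puts the parents of the mul node into one dependency set, so a pruner is forced to keep their output channels aligned.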
test/ut/compression/v1/test_model_speedup.py
...
@@ -512,6 +512,46 @@ class SpeedupTestCase(TestCase):
         print("Fine-grained speeduped model")
         print(model)
 
+    def test_multiplication_speedup(self):
+        """
+        Model from issue 4540.
+        """
+        class Net(torch.nn.Module):
+            def __init__(self,):
+                super(Net, self).__init__()
+                self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
+                self.input = torch.nn.Conv2d(3, 8, 3)
+                self.bn = torch.nn.BatchNorm2d(8)
+                self.fc1 = torch.nn.Conv2d(8, 16, 1)
+                self.fc2 = torch.nn.Conv2d(16, 8, 1)
+                self.activation = torch.nn.ReLU()
+                self.scale_activation = torch.nn.Hardsigmoid()
+                self.out = torch.nn.Conv2d(8, 12, 1)
+
+            def forward(self, input):
+                input = self.activation(self.bn(self.input(input)))
+                scale = self.avgpool(input)
+                out1 = self.activation(self.fc1(scale))
+                out1 = self.scale_activation(self.fc2(out1))
+                return self.out(out1 * input)
+
+        model = Net().to(device)
+        model.eval()
+        im = torch.ones(1, 3, 512, 512).to(device)
+        model(im)
+
+        cfg_list = []
+        for name, module in model.named_modules():
+            if isinstance(module, torch.nn.Conv2d):
+                cfg_list.append({'op_types': ['Conv2d'], 'sparsity': 0.3, 'op_names': [name]})
+
+        pruner = L1FilterPruner(model, cfg_list)
+        pruner.compress()
+        pruner.export_model(MODEL_FILE, MASK_FILE)
+        pruner._unwrap_model()
+        ms = ModelSpeedup(model, im, MASK_FILE)
+        ms.speedup_model()
+
     def tearDown(self):
         if os.path.exists(MODEL_FILE):
             os.remove(MODEL_FILE)
...
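A possible sanity check to append after ms.speedup_model() (hypothetical, not part of the commit; it reuses the model and im variables from the test above): run the compacted model once to confirm the elementwise multiplication no longer triggers a shape mismatch.

# Hypothetical follow-up to ms.speedup_model(); reuses `model` and `im` from the test.
with torch.no_grad():
    out = model(im)                       # must run end to end after speedup
print('output shape after speedup:', out.shape)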