Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
2ee76dc7
Unverified
Commit
2ee76dc7
authored
Sep 17, 2020
by
Ningxin Zheng
Committed by
GitHub
Sep 17, 2020
Browse files
Add the input channel dependency pruning. (#2865)
parent
062b037f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
106 additions
and
3 deletions
+106
-3
src/sdk/pynni/nni/compression/torch/utils/shape_dependency.py
...sdk/pynni/nni/compression/torch/utils/shape_dependency.py
+106
-3
No files found.
src/sdk/pynni/nni/compression/torch/utils/shape_dependency.py
View file @
2ee76dc7
...
...
@@ -4,13 +4,13 @@
import
csv
import
logging
__all__
=
[
'ChannelDependency'
,
'GroupDependency'
,
'CatPaddingDependency'
]
__all__
=
[
'ChannelDependency'
,
'GroupDependency'
,
'CatPaddingDependency'
,
'InputChannelDependency'
]
CONV_TYPE
=
'aten::_convolution'
ADD_TYPES
=
[
'aten::add'
,
'aten::add_'
]
CAT_TYPE
=
'aten::cat'
logger
=
logging
.
getLogger
(
'Shape_Dependency'
)
RESHAPE_OPS
=
[
CAT_TYPE
,
'aten::view'
,
'aten::reshape'
,
'aten::flatten'
,
'aten::mean'
]
class
Dependency
:
def
__init__
(
self
,
model
=
None
,
dummy_input
=
None
,
traced_model
=
None
):
...
...
@@ -37,7 +37,7 @@ class Dependency:
class
ChannelDependency
(
Dependency
):
def
__init__
(
self
,
model
=
None
,
dummy_input
=
None
,
traced_model
=
None
):
"""
This model analyze the channel dependencis between the conv
This model analyze the channel dependenci
e
s between the conv
layers in a model.
Parameters
...
...
@@ -185,6 +185,109 @@ class ChannelDependency(Dependency):
d_sets
.
append
(
tmp_set
)
return
d_sets
def
reshape_break_channel_dependency
(
op_node
):
"""
The reshape operations such as (reshape, view, flatten) may break
the channel dependency. We need to check the input parameters of
these reshape operations to check if this reshape node will break
the channel dependency. However, it's complicated to analyze the the input
parameters for each reshape function and infer if it will break the channel
dependency. So currently, we just check if the input channel and the output
channel is the same, if so, then we can say the original reshape function
doesn't want to change the number of the channels, which means the channel
dependency is not broken. In contrast, the original reshap operation wants
to change the number of channels, so it breaks the channel dependency.
Parameters
----------
opnode: NodePyOP
A Op node of the graph.
Returns
-------
bool
If this operation will break the channel dependency.
"""
in_shape
=
op_node
.
auxiliary
[
'in_shape'
]
out_shape
=
op_node
.
auxiliary
[
'out_shape'
]
in_channel
=
in_shape
[
1
]
out_channel
=
out_shape
[
1
]
return
in_channel
!=
out_channel
class
InputChannelDependency
(
ChannelDependency
):
"""
Some pruners may prune the input channel of the convolutional
layers. While pruning the input channel of the convolutional layers,
the layers that share the same input tensor should prune the same
channels, and we say these layers that share the same input tensor/channel
has the input channel dependency. If we only prune the input channel of one
layer in the dependency set, there will be a shape conflict for the other
layers in the same dependency set, which may trigger a runtime error.
Here we judge whether the application will truncate the dependency by analyzing
whether the number of channels before and after the operation has changed.
If not, the input channel dependency will be passed to the following nodes.
"""
def
__init__
(
self
,
model
,
dummy_input
=
None
,
traced_model
=
None
):
"""
This model analyze the input channel dependencies between the conv
layers in a model.
Parameters
----------
model : torch.nn.Module
The model to be analyzed.
data : torch.Tensor
The example input data to trace the network architecture.
traced_model : torch._C.Graph
if we alreay has the traced graph of the target model, we donnot
need to trace the model again.
"""
super
(
InputChannelDependency
,
self
).
__init__
(
model
,
dummy_input
,
traced_model
)
def
_get_following_convs
(
self
,
tensor
):
queue
=
[]
key_layers
=
[]
queue
.
extend
(
self
.
graph
.
input_to_node
[
tensor
])
while
queue
:
curnode
=
queue
.
pop
(
0
)
if
curnode
.
op_type
==
'Conv2d'
or
curnode
.
op_type
==
'Linear'
:
# find the first met conv
key_layers
.
append
(
curnode
.
name
)
continue
elif
curnode
.
op_type
in
RESHAPE_OPS
:
# check if the reshape operation will break the channel dependency
if
reshape_break_channel_dependency
(
curnode
):
# reshape operations also breaks the dependency relationship
continue
successors
=
self
.
graph
.
find_successors
(
curnode
.
unique_name
)
successors
=
[
self
.
graph
.
name_to_node
[
name
]
for
name
in
successors
]
for
layer
in
successors
:
queue
.
append
(
layer
)
return
key_layers
def
build_dependency
(
self
):
"""
Build the input channel dependencies.
The `InputChannelDependency` indicates the layers that have
dependencies when pruning the input channel of the conv layers.
In contrast, `ChannelDependency` indicates the dependent layers
when pruning the output channles of conv layers (for example, L1FilterPruner).
"""
# unpack the tuple or list manually
self
.
graph
.
unpack_manually
()
for
tensor
in
self
.
graph
.
input_to_node
:
# start from this tensor, find all the conv layers that
# take this tensor as input. Similar to the `ChannelDependency`
# the conv layer will truncate the dependencies
layers
=
self
.
_get_following_convs
(
tensor
)
dependency_set
=
set
(
layers
)
for
layer
in
layers
:
if
layer
in
self
.
dependency
:
dependency_set
.
update
(
self
.
dependency
[
layer
])
for
layer
in
dependency_set
:
self
.
dependency
[
layer
]
=
dependency_set
class
CatPaddingDependency
(
ChannelDependency
):
def
__init__
(
self
,
model
=
None
,
dummy_input
=
None
,
traced_model
=
None
):
super
(
CatPaddingDependency
,
self
).
__init__
(
model
,
dummy_input
,
traced_model
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment