OpenDAS / nni · Commit 0aea0a56 (unverified)

Authored Aug 03, 2021 by Ningxin Zheng; committed by GitHub on Aug 03, 2021.

Support Speedup for Slim Pruner. (#4008)

parent d8e56857
Showing 2 changed files, with 55 additions and 11 deletions:

nni/compression/pytorch/utils/mask_conflict.py (+37, -6)
nni/compression/pytorch/utils/shape_dependency.py (+18, -5)
nni/compression/pytorch/utils/mask_conflict.py
@@ -184,14 +184,10 @@ class ChannelMaskConflict(MaskFix):
         super(ChannelMaskConflict, self).__init__(
             masks, model, dummy_input, traced)
         self.conv_prune_dim = detect_mask_prune_dim(masks, model)
+        self.channel_prune_type = detect_channel_prune_type(masks, model)
         _logger.info('Detected conv prune dim: %d', self.conv_prune_dim)

     def fix_mask(self):
         """
         Fix the mask conflict before the mask inference for the layers that
         have shape dependencies. This function should be called before the
         mask inference of the 'speedup' module.
         """
@@ -200,7 +196,8 @@ class ChannelMaskConflict(MaskFix):
         """
         if self.conv_prune_dim == 0:
             channel_depen = ChannelDependency(
-                self.model, self.dummy_input, self.traced)
+                self.model, self.dummy_input, self.traced,
+                self.channel_prune_type)
         else:
             channel_depen = InputChannelDependency(
                 self.model, self.dummy_input, self.traced)
@@ -307,10 +304,44 @@ class ChannelMaskConflict(MaskFix):
         return self.masks

+
+def detect_channel_prune_type(masks, model):
+    """
+    A user can prune a channel in two ways: 1) prune the corresponding
+    filter of the conv layer (all the filter-related pruners), or 2) prune
+    the BN layer that follows a conv (Slim pruner). This function finds
+    the pruning type of the masks.
+
+    Parameters
+    ----------
+    masks: dict
+        A dict object that stores the masks.
+    model: nn.Module
+        Model object which the mask can be applied on.
+
+    Returns
+    -------
+    prune_type: str
+        Could be Filter or Batchnorm
+    """
+    prune_type = 'Filter'
+    all_batch_norm = True
+    for layer_name in masks:
+        _, m = get_module_by_name(model, layer_name)
+        if m is None or (not isinstance(m, torch.nn.BatchNorm2d)):
+            all_batch_norm = False
+            break
+    if all_batch_norm:
+        # If all masks are on batchnorm layers, the prune type is Batchnorm.
+        # Note: we currently do not support pruning both Conv and BatchNorm
+        # at the same time.
+        prune_type = 'Batchnorm'
+    return prune_type
+
+
 def detect_mask_prune_dim(masks, model):
     """
     Detect how the masks of convolutional layers are pruned.

     Parameters
     ----------
     masks: dict
     ...
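The new helper only inspects which modules the mask keys point at, so its behaviour is easy to check in isolation. A minimal sketch (the toy model and the mask layout are illustrative, not from the commit; `detect_channel_prune_type` is the module-level function added above):

```python
import torch
import torch.nn as nn
from nni.compression.pytorch.utils.mask_conflict import detect_channel_prune_type

# Toy model for illustration only.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))

# Masks as a BatchNorm-based pruner such as Slim emits them: keyed by
# layer name, masking the BN scale/shift (assumed mask layout).
slim_masks = {'1': {'weight': torch.ones(8), 'bias': torch.ones(8)}}
print(detect_channel_prune_type(slim_masks, model))    # expected: 'Batchnorm'

# A filter pruner masks the Conv2d weight instead.
filter_masks = {'0': {'weight': torch.ones(8, 3, 3, 3)}}
print(detect_channel_prune_type(filter_masks, model))  # expected: 'Filter'
```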
nni/compression/pytorch/utils/shape_dependency.py
@@ -85,7 +85,7 @@ def reshape_break_channel_dependency(op_node):
 class ChannelDependency(Dependency):
-    def __init__(self, model=None, dummy_input=None, traced_model=None):
+    def __init__(self, model=None, dummy_input=None, traced_model=None, prune_type='Filter'):
         """
         This class analyzes the channel dependencies between the conv
         layers in a model.
@@ -98,7 +98,18 @@ class ChannelDependency(Dependency):
         traced_model : torch._C.Graph
             if we already have the traced graph of the target model, we do
             not need to trace the model again.
+        prune_type: str
+            This parameter indicates the channel pruning type: 1) `Filter`:
+            prune the filters of the convolution layer to prune the
+            corresponding channels; 2) `Batchnorm`: prune the channels in
+            the batchnorm layer.
         """
+        self.prune_type = prune_type
+        self.target_types = []
+        if self.prune_type == 'Filter':
+            self.target_types.extend(['Conv2d', 'Linear', 'ConvTranspose2d'])
+        elif self.prune_type == 'Batchnorm':
+            self.target_types.append('BatchNorm2d')
         super(ChannelDependency, self).__init__(
             model, dummy_input, traced_model)
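With `prune_type='Batchnorm'`, the dependency sets are built over `BatchNorm2d` nodes instead of `Conv2d`/`Linear`/`ConvTranspose2d` ones, which matches where Slim-pruner masks live. A hedged usage sketch (the torchvision model and output path are illustrative; `export` writes the CSV produced by the writer code shown further down):

```python
import torch
from torchvision.models import resnet18
from nni.compression.pytorch.utils.shape_dependency import ChannelDependency

model = resnet18()
dummy_input = torch.ones(1, 3, 224, 224)

# Default 'Filter' mode: group conv/linear layers whose output channels
# must be pruned together.
conv_dep = ChannelDependency(model, dummy_input)

# 'Batchnorm' mode: group the BN layers instead, matching Slim masks.
bn_dep = ChannelDependency(model, dummy_input, prune_type='Batchnorm')
bn_dep.export('bn_channel_dependency.csv')  # illustrative output path
```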
@@ -114,12 +125,13 @@ class ChannelDependency(Dependency):
         parent_layers: list
             nearest parent conv/linear layers for the target node.
         """
         parent_layers = []
         queue = []
         queue.append(node)
         while queue:
             curnode = queue.pop(0)
-            if curnode.op_type == 'Conv2d' or curnode.op_type == 'Linear' or curnode.op_type == 'ConvTranspose2d':
+            if curnode.op_type in self.target_types:
                 # stop at the first layer whose type is in target_types
                 parent_layers.append(curnode.name)
                 continue
@@ -130,6 +142,7 @@ class ChannelDependency(Dependency):
             parents = [self.graph.name_to_node[name] for name in parents]
             for parent in parents:
                 queue.append(parent)
         return parent_layers

     def build_dependency(self):
@@ -193,7 +206,7 @@ class ChannelDependency(Dependency):
             csv_w = csv.writer(csvf, delimiter=',')
             csv_w.writerow(header)
             for node in self.graph.nodes_py.nodes_op:
-                if node.op_type != 'Conv2d' or node in visited:
+                if node.op_type not in self.target_types or node in visited:
                     continue
                 setid += 1
                 row = ['Set %d' % setid]
@@ -220,7 +233,7 @@ class ChannelDependency(Dependency):
         d_sets = []
         visited = set()
         for node in self.graph.nodes_py.nodes_op:
-            if (node.op_type != 'Conv2d' and node.op_type != 'Linear') or node in visited:
+            if node.op_type not in self.target_types or node in visited:
                 continue
             tmp_set = set()
             if node.name not in self.dependency:
             ...
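Taken together, the two files let the speedup pipeline consume masks produced on BatchNorm layers, which is what the commit title promises. A rough end-to-end sketch, assuming the NNI v2.x pruning API of the time (the SlimPruner training arguments are elided and vary by version; paths and the sparsity config are illustrative):

```python
import torch
from torchvision.models import vgg16
from nni.algorithms.compression.pytorch.pruning import SlimPruner
from nni.compression.pytorch import ModelSpeedup

dummy_input = torch.ones(1, 3, 224, 224)

# Prune BatchNorm scaling factors with the Slim pruner (illustrative
# config; the constructor may also require optimizer/trainer arguments
# depending on the NNI version).
model = vgg16()
config_list = [{'sparsity': 0.5, 'op_types': ['BatchNorm2d']}]
pruner = SlimPruner(model, config_list)
pruner.compress()
pruner.export_model('slim_vgg16.pth', 'slim_mask_vgg16.pth')

# Apply the BatchNorm masks to a fresh copy of the model. Before this
# commit, speedup assumed 'Filter' (Conv2d) masks; now the Batchnorm
# prune type is detected and propagated through the channel dependencies.
model = vgg16()
ModelSpeedup(model, dummy_input, 'slim_mask_vgg16.pth').speedup_model()
```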