Project: OpenDAS / nni

Commit 0a8fbbed (unverified)
Loose the group dependency (#4128)

Authored Sep 01, 2021 by Ningxin Zheng; committed by GitHub on Sep 01, 2021.
Parent: d204d8bf
Changes: 6 changed files, with 67 additions and 42 deletions (+67 -42)

    nni/algorithms/compression/pytorch/pruning/dependency_aware_pruner.py   +2  -3
    nni/compression/pytorch/utils/__init__.py                               +23 -1
    nni/compression/pytorch/utils/mask_conflict.py                          +3  -3
    nni/compression/pytorch/utils/shape_dependency.py                       +20 -12
    nni/compression/pytorch/utils/utils.py                                  +1  -21
    test/ut/sdk/test_model_speedup.py                                       +18 -2
nni/algorithms/compression/pytorch/pruning/dependency_aware_pruner.py

@@ -44,9 +44,8 @@ class DependencyAwarePruner(Pruner):
             self._unwrap_model()
             self.graph = TorchModuleGraph(model, dummy_input)
             self._wrap_model()
             self.channel_depen = ChannelDependency(
-                model, dummy_input, traced_model=self.graph.trace)
-            self.group_depen = GroupDependency(
-                model, dummy_input, traced_model=self.graph.trace)
+                traced_model=self.graph.trace)
+            self.group_depen = GroupDependency(traced_model=self.graph.trace)
             self.channel_depen = self.channel_depen.dependency_sets
             self.channel_depen = {
                 name: sets for sets in self.channel_depen for name in sets}
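Note: this hunk rebuilds both dependency objects from the pruner's existing TorchModuleGraph trace instead of re-tracing the model. A minimal sketch of the same trace-reuse pattern outside the pruner, assuming nni and torch are installed; the two-conv model is illustrative, and model/dummy_input are passed explicitly, which both the old and the new signatures accept:

    import torch
    import torch.nn as nn
    from nni.common.graph_utils import TorchModuleGraph
    from nni.compression.pytorch.utils.shape_dependency import (
        ChannelDependency, GroupDependency)

    # illustrative model, not from the commit
    model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 32, 3))
    dummy_input = torch.rand(1, 3, 32, 32)

    # trace once, then share the same trace between both dependency analyses
    graph = TorchModuleGraph(model, dummy_input)
    channel_depen = ChannelDependency(model, dummy_input, traced_model=graph.trace)
    group_depen = GroupDependency(model, dummy_input, traced_model=graph.trace)
    print(channel_depen.dependency_sets)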
nni/compression/pytorch/utils/__init__.py

-from .utils import *
\ No newline at end of file
+from .utils import *
+from .shape_dependency import *
+from .shape_dependency import ReshapeDependency
+
+
+def not_safe_to_prune(model, dummy_input):
+    """
+    Get the layers that are not safe to prune(may bring the shape conflict).
+    For example, if the output tensor of a conv layer is directly followed by
+    a shape-dependent function(such as reshape/view), then this conv layer
+    may be not safe to be pruned. Pruning may change the output shape of
+    this conv layer and result in shape problems. This function find all the
+    layers that directly followed by the shape-dependent functions(view, reshape, etc).
+    If you run the inference after the speedup and run into a shape related error,
+    please exclude the layers returned by this function and try again.
+
+    Parameters
+    ----------
+    model: torch.nn.Module
+        The target model to prune.
+    dummy_input: torch.Tensor/list of torch.Tensor/tuple of Tensor
+    """
+    reshape_dset = ReshapeDependency(model, dummy_input)
+    return reshape_dset.dependency_sets
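Note: the relocated helper is meant to run before pruning, per its docstring. A hedged usage sketch; the torchvision model and input shape are illustrative, not from the commit:

    import torch
    import torchvision.models as models
    from nni.compression.pytorch.utils import not_safe_to_prune

    model = models.resnet18()
    dummy_input = torch.rand(1, 3, 224, 224)

    # layer names whose outputs feed view/reshape-style ops; leave these out
    # of the pruning config if inference after speedup hits shape errors
    skip_names = not_safe_to_prune(model, dummy_input)
    print(skip_names)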
nni/compression/pytorch/utils/mask_conflict.py

@@ -10,7 +10,7 @@ from .utils import get_module_by_name
 _logger = logging.getLogger('FixMaskConflict')

-def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
+def fix_mask_conflict(masks, model, dummy_input, traced=None):
     """
     MaskConflict fix the mask conflict for the channel dependencies
     and group dependency.
@@ -81,7 +81,7 @@ class MaskFix:
 class GroupMaskConflict(MaskFix):
-    def __init__(self, masks, model=None, dummy_input=None, traced=None):
+    def __init__(self, masks, model, dummy_input, traced=None):
         """
         GroupMaskConflict fix the mask conflict between the layers that
         has group dependecy with each other.
@@ -168,7 +168,7 @@ class GroupMaskConflict(MaskFix):
 class ChannelMaskConflict(MaskFix):
-    def __init__(self, masks, model=None, dummy_input=None, traced=None):
+    def __init__(self, masks, model, dummy_input, traced=None):
         """
         ChannelMaskConflict fix the mask conflict between the layers that
         has channel dependecy with each other.
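Note: under the new signatures, model and dummy_input are mandatory rather than defaulting to None. A sketch of a call site under that assumption; the two-layer model and the all-ones masks are illustrative, with masks following NNI's {layer_name: {'weight': tensor}} layout:

    import torch
    import torch.nn as nn
    from nni.compression.pytorch.utils.mask_conflict import fix_mask_conflict

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Conv2d(8, 8, 3, groups=4))
    dummy_input = torch.rand(1, 3, 16, 16)
    masks = {'0': {'weight': torch.ones(8, 3, 3, 3)},
             '1': {'weight': torch.ones(8, 2, 3, 3)}}

    # align masks across channel- and group-dependent layers so that
    # dependent layers end up pruning the same channels
    fixed_masks = fix_mask_conflict(masks, model, dummy_input)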
nni/compression/pytorch/utils/shape_dependency.py

@@ -3,10 +3,14 @@
 import csv
 import logging
+import torch
 import numpy as np
+
+from nni.compression.pytorch.compressor import PrunerModuleWrapper
+from .utils import get_module_by_name

 __all__ = ['ChannelDependency', 'GroupDependency',
            'InputChannelDependency', 'AttentionWeightDependency']

 CONV_TYPE = 'aten::_convolution'
@@ -45,6 +49,7 @@ class Dependency:
             # the model or a already traced model
             assert model is not None and dummy_input is not None
         self.graph = TorchModuleGraph(model, dummy_input, traced_model)
+        self.model = model
         self.dependency = dict()
         self.build_dependency()
@@ -85,7 +90,7 @@ def reshape_break_channel_dependency(op_node):
 class ChannelDependency(Dependency):
-    def __init__(self, model=None, dummy_input=None, traced_model=None, prune_type='Filter'):
+    def __init__(self, model, dummy_input, traced_model=None, prune_type='Filter'):
         """
         This model analyze the channel dependencies between the conv
         layers in a model.
@@ -261,7 +266,7 @@ class InputChannelDependency(ChannelDependency):
     If not, the input channel dependency will be passed to the following nodes.
     """
-    def __init__(self, model, dummy_input=None, traced_model=None):
+    def __init__(self, model, dummy_input, traced_model=None):
         """
         This model analyze the input channel dependencies between the conv
         layers in a model.
@@ -323,7 +328,7 @@ class InputChannelDependency(ChannelDependency):
 class GroupDependency(Dependency):
-    def __init__(self, model=None, dummy_input=None, traced_model=None):
+    def __init__(self, model, dummy_input, traced_model=None):
         """
         This model analyze the group dependencis between the conv
         layers in a model.
@@ -383,13 +388,17 @@ class GroupDependency(Dependency):
         group : int
             the number of the groups of the target conv layer.
         """
-        cpp_conv = list(filter(lambda x: x.kind() ==
-                               CONV_TYPE, node_group.node_cpps))
-        assert len(cpp_conv) == 1
-        cpp_conv = cpp_conv[0]
-        inputs = list(cpp_conv.inputs())
-        # get the number of the group from the input parameters
-        group = inputs[8].toIValue()
+        node_name = node_group.name
+        _, leaf_module = get_module_by_name(self.model, node_name)
+        if isinstance(leaf_module, PrunerModuleWrapper):
+            leaf_module = leaf_module.module
+        assert isinstance(
+            leaf_module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d))
+        group = leaf_module.groups
+        n_filter = leaf_module.out_channels
+        if n_filter == group:
+            # depthwise conv will not introduce extra group dependency
+            return 1
         return group

     def build_dependency(self):
@@ -712,4 +721,3 @@ class AttentionWeightDependency(Dependency):
             group = self.dependency[name]
             if len(group) > 0:
                 csv_w.writerow([name, group])
-
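Note: this is the change the commit title refers to. The group count is now read from the module's own attributes via get_module_by_name (which is why Dependency now stores self.model) instead of being parsed out of the traced graph's ninth convolution input, and a depthwise conv, where out_channels equals groups, is treated as adding no extra group dependency. A small illustration of that special case; conv_group_count is a hypothetical stand-in for the logic added above, not an NNI API:

    import torch.nn as nn

    def conv_group_count(conv):
        # hypothetical stand-in mirroring the committed logic
        if conv.out_channels == conv.groups:
            # depthwise conv will not introduce extra group dependency
            return 1
        return conv.groups

    grouped = nn.Conv2d(32, 64, 3, groups=4)     # regular grouped conv
    depthwise = nn.Conv2d(32, 32, 3, groups=32)  # depthwise conv
    print(conv_group_count(grouped))    # 4: filters must be pruned in groups of 4
    print(conv_group_count(depthwise))  # 1: no extra constraint on pruning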
nni/compression/pytorch/utils/utils.py

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

 import torch
-from .shape_dependency import ReshapeDependency

 torch_float_dtype = [torch.float, torch.float16,
                      torch.float32, torch.float64, torch.half, torch.double]
 torch_integer_dtype = [torch.uint8, torch.int16,
                        torch.short, torch.int32, torch.long, torch.bool]
@@ -67,23 +67,3 @@ def randomize_tensor(tensor, start=1, end=100):
     # with nn.init.uniform_
     torch.nn.init.uniform_(tensor.data, start, end)
-
-
-def not_safe_to_prune(model, dummy_input):
-    """
-    Get the layers that are not safe to prune(may bring the shape conflict).
-    For example, if the output tensor of a conv layer is directly followed by
-    a shape-dependent function(such as reshape/view), then this conv layer
-    may be not safe to be pruned. Pruning may change the output shape of
-    this conv layer and result in shape problems. This function find all the
-    layers that directly followed by the shape-dependent functions(view, reshape, etc).
-    If you run the inference after the speedup and run into a shape related error,
-    please exclude the layers returned by this function and try again.
-
-    Parameters
-    ----------
-    model: torch.nn.Module
-        The target model to prune.
-    dummy_input: torch.Tensor/list of torch.Tensor/tuple of Tensor
-    """
-    reshape_dset = ReshapeDependency(model, dummy_input)
-    return reshape_dset.dependency_sets
\ No newline at end of file
test/ut/sdk/test_model_speedup.py

@@ -174,10 +174,18 @@ def prune_model_l1(model):
 def generate_random_sparsity(model):
+    _start = 0.5
+    _end = 0.99
+    if isinstance(model, models.mobilenet.MobileNetV2):
+        # mobilenet models have great propagation characteristics
+        # so we use smaller sparsity ratio to avoid pruning the whole
+        # layer out
+        _start = 0.01
+        _end = 0.3
     cfg_list = []
     for name, module in model.named_modules():
         if isinstance(module, nn.Conv2d):
-            sparsity = np.random.uniform(0.5, 0.99)
+            sparsity = np.random.uniform(_start, _end)
             cfg_list.append({'op_types': ['Conv2d'], 'op_names': [name],
                              'sparsity': sparsity})
     return cfg_list
@@ -187,11 +195,19 @@ def generate_random_sparsity_v2(model):
     """
     Only select 50% layers to prune.
     """
+    _start = 0.5
+    _end = 0.99
+    if isinstance(model, models.mobilenet.MobileNetV2):
+        # mobilenet models have great propagation characteristics
+        # so we use smaller sparsity ratio to avoid pruning the whole
+        # layer out
+        _start = 0.01
+        _end = 0.3
     cfg_list = []
     for name, module in model.named_modules():
         if isinstance(module, nn.Conv2d):
             if np.random.uniform(0, 1.0) > 0.5:
-                sparsity = np.random.uniform(0.5, 0.99)
+                sparsity = np.random.uniform(_start, _end)
                 cfg_list.append({'op_types': ['Conv2d'], 'op_names': [name],
                                  'sparsity': sparsity})
     return cfg_list
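Note: the test comments carry the reasoning here: MobileNetV2's depthwise-heavy structure propagates pruning across many dependent layers, so high random sparsity can zero out an entire layer. A sketch of the range selection the tests now apply, with the isinstance check stubbed out; sparsity_range and is_mobilenet are illustrative names, not from the commit:

    import numpy as np

    def sparsity_range(is_mobilenet):
        # depthwise-heavy models propagate pruning widely, so the tests
        # keep sparsity low to avoid pruning a whole layer away
        return (0.01, 0.3) if is_mobilenet else (0.5, 0.99)

    lo, hi = sparsity_range(is_mobilenet=True)
    print(np.random.uniform(lo, hi))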