OpenDAS / nni · Commit 8bc74a2c (unverified)
Authored Oct 09, 2020 by chicm-ms, committed by GitHub on Oct 09, 2020
Parent: f43719a8

Speedup supports channel pruning (#2906)

Showing 9 changed files with 639 additions and 144 deletions (+639 -144).
Changed files:
  src/sdk/pynni/nni/_graph_utils.py                                 +32    -0
  src/sdk/pynni/nni/compression/torch/pruning/one_shot.py            +8    -4
  src/sdk/pynni/nni/compression/torch/speedup/compress_modules.py   +13    -6
  src/sdk/pynni/nni/compression/torch/speedup/compressor.py         +18   -28
  src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py       +290   -45
  src/sdk/pynni/nni/compression/torch/utils/mask_conflict.py       +132   -57
  src/sdk/pynni/nni/compression/torch/utils/utils.py                +30    -0
  src/sdk/pynni/tests/test_compression_utils.py                      +1    -1
  src/sdk/pynni/tests/test_model_speedup.py                        +115    -3
src/sdk/pynni/nni/_graph_utils.py
@@ -426,6 +426,36 @@ class TorchModuleGraph(TorchGraph):
         cat_info['in_shape'] = input_shapes
         return cat_info
 
+    def _extract_linear_shape_info(self, node_group):
+        """
+        Extract linear input/output tensor shape info from its aten::addmm op.
+
+        Parameters
+        ----------
+        node_group : NodePyGroup
+            NodePyGroup object associated with the linear module.
+
+        Returns
+        -------
+        dict
+            Includes the shape of the input tensor and the shape of the output tensor
+        """
+        for cpp_node in node_group.node_cpps:
+            if cpp_node.kind() == 'aten::addmm':
+                # https://github.com/pytorch/pytorch/blob/1.6/torch/nn/functional.py#L1682
+                # inputs of aten::addmm:
+                # inputs[0] is bias
+                # inputs[1] is input data
+                # inputs[2] is weight
+                t_input = list(cpp_node.inputs())[1]
+                t_output = cpp_node.output()
+                assert isinstance(t_input.type(), torch._C.TensorType)
+                assert isinstance(t_output.type(), torch._C.TensorType)
+                in_shape = t_input.type().sizes()
+                out_shape = t_output.type().sizes()
+                return {'in_shape': in_shape, 'out_shape': out_shape}
+        return None
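A small sketch (not part of the commit) of the aten::addmm convention this helper relies on. Under PyTorch 1.6, where F.linear on a 2-D input lowers to torch.addmm per the link in the code, a traced Linear exposes (bias, input, weight) as the op's inputs; the module and tensor names below are illustrative only.

import torch
import torch.nn as nn

traced = torch.jit.trace(nn.Linear(8, 4), torch.randn(2, 8))
for node in traced.graph.nodes():
    if node.kind() == 'aten::addmm':
        t_input = list(node.inputs())[1]    # inputs[0]=bias, [1]=input, [2]=weight
        print(t_input.type())               # the tracer records the tensor type/shape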
 
     def _extract_shape_info(self, node):
         """
         Extract the shape information of ```aten::view``` node
...
@@ -701,6 +731,8 @@ class TorchModuleGraph(TorchGraph):
                 cpp_node = list(filter(lambda x: x.kind() == node_group.op_type,
                                        node_group.node_cpps))[0]
                 node_group.auxiliary = self._extract_shape_info(cpp_node)
+            elif node_group.op_type == 'Linear':
+                node_group.auxiliary = self._extract_linear_shape_info(node_group)
             elif node_group.op_type == CAT_KIND:
                 # get the detail information for cat func
                 cpp_node = list(filter(lambda x: x.kind() == node_group.op_type,
...
src/sdk/pynni/nni/compression/torch/pruning/one_shot.py
@@ -2,7 +2,7 @@
 # Licensed under the MIT license.
 
 import logging
-from schema import And, Optional
+from schema import And, Optional, SchemaError
 from nni._graph_utils import TorchModuleGraph
 from nni.compression.torch.utils.shape_dependency import ChannelDependency, GroupDependency
 from .constants import MASKER_DICT
...
@@ -186,12 +186,16 @@ class _StructuredFilterPruner(OneshotPruner):
     def validate_config(self, model, config_list):
         schema = CompressorSchema([{
-            'sparsity': And(float, lambda n: 0 < n < 1),
-            'op_types': ['Conv2d'],
-            Optional('op_names'): [str]
+            Optional('sparsity'): And(float, lambda n: 0 < n < 1),
+            Optional('op_types'): ['Conv2d'],
+            Optional('op_names'): [str],
+            Optional('exclude'): bool
         }], model, logger)
 
         schema.validate(config_list)
+        for config in config_list:
+            if 'exclude' not in config and 'sparsity' not in config:
+                raise SchemaError('Either sparsity or exclude must be specified!')
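The relaxed schema makes 'sparsity' optional so that a config entry can exempt layers via the new 'exclude' flag instead. A minimal sketch of what now validates, mirroring the channel_prune config in the test file below (values are illustrative):

config_list = [
    {'sparsity': 0.5, 'op_types': ['Conv2d']},     # layers to prune
    {'op_names': ['conv1'], 'exclude': True},      # exempt by name, no sparsity needed
]
# An entry with neither 'sparsity' nor 'exclude' now raises SchemaError.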
 
     def _dependency_calc_mask(self, wrappers, channel_dsets, wrappers_idx=None):
         """
...
src/sdk/pynni/nni/compression/torch/speedup/compress_modules.py
@@ -116,15 +116,19 @@ def replace_conv2d(conv, mask):
     else:
         out_channels_index = mask.output_mask.mask_index[1]
         out_channels = out_channels_index.size()[0]
-    _logger.debug("replace conv2d with in_channels: %d, out_channels: %d", in_channels, out_channels)
+
+    groups = conv.groups
+    if conv.in_channels == conv.out_channels == conv.groups:
+        # remove groups for depthwise layers
+        assert in_channels == out_channels
+        groups = in_channels
+    _logger.debug("replace conv2d %s with in_channels: %d, out_channels: %d",
+                  mask.module_name, in_channels, out_channels)
     new_conv = torch.nn.Conv2d(in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=conv.kernel_size,
                                stride=conv.stride,
                                padding=conv.padding,
                                dilation=conv.dilation,
-                               groups=conv.groups,
+                               groups=groups,
                                bias=conv.bias is not None,
                                padding_mode=conv.padding_mode)
...
@@ -142,13 +146,16 @@ def replace_conv2d(conv, mask):
     # channel is also divided into several groups and each group
     # filter may have different input channel indexes.
     input_step = int(conv.in_channels / conv.groups)
-    in_channels_group = int(in_channels / conv.groups)
-    filter_step = int(out_channels / conv.groups)
-    if mask.input_mask is not None:
+    in_channels_group = int(in_channels / groups)
+    filter_step = int(out_channels / groups)
+    if mask.input_mask is not None and not (in_channels == out_channels == groups):
         for groupid in range(conv.groups):
             start = groupid * input_step
             end = (groupid + 1) * input_step
             current_input_index = list(filter(lambda x: start <= x and x < end, in_channels_index.tolist()))
             if not current_input_index:
                 # there is no kept channel in current group
                 continue
             # shift the global index into the group index
             current_input_index = [x - start for x in current_input_index]
     # if the groups is larger than 1, the input channels of each
...
src/sdk/pynni/nni/compression/torch/speedup/compressor.py
@@ -4,34 +4,13 @@
 import logging
 import torch
 from nni.compression.torch.utils.mask_conflict import fix_mask_conflict
+from nni.compression.torch.utils.utils import get_module_by_name
 from .compress_modules import replace_module
-from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape
+from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, \
+    infer_from_outshape, set_conv_prune_dim
 
 _logger = logging.getLogger(__name__)
 
-def get_module_by_name(model, module_name):
-    """
-    Get a module specified by its module name
-
-    Parameters
-    ----------
-    model : pytorch model
-        the pytorch model from which to get its module
-    module_name : str
-        the name of the required module
-
-    Returns
-    -------
-    module, module
-        the parent module of the required module, the required module
-    """
-    name_list = module_name.split(".")
-    for name in name_list[:-1]:
-        model = getattr(model, name)
-    leaf_module = getattr(model, name_list[-1])
-    return model, leaf_module
-
 class ModelSpeedup:
     """
     This class is to speedup the model with provided weight mask
...
@@ -87,7 +66,8 @@ class ModelSpeedup:
         if module_name in self.inferred_masks:
             module_masks = self.inferred_masks[module_name]
         else:
-            module_masks = ModuleMasks(module_name)
+            _, m = get_module_by_name(self.bound_model, module_name)
+            module_masks = ModuleMasks(module_name, m)
             self.inferred_masks[module_name] = module_masks
 
         m_type = self.torch_graph.name_to_node[module_name].op_type
...
@@ -98,7 +78,12 @@ class ModelSpeedup:
             raise RuntimeError(
                 "Has not supported inferring input/output shape from mask for module/function: `{}`, {}"
                 .format(m_type, module_name))
-        input_cmask, output_cmask = infer_from_mask[m_type](module_masks, mask)
+        if m_type in ['Linear']:
+            input_cmask, output_cmask = infer_from_mask[m_type](
+                module_masks, mask, self.torch_graph.name_to_node[module_name].auxiliary)
+        else:
+            input_cmask, output_cmask = infer_from_mask[m_type](module_masks, mask)
     if in_shape is not None:
         _logger.debug("in_shape is not None")
         if not m_type in infer_from_inshape:
...
@@ -124,7 +109,10 @@ class ModelSpeedup:
             raise RuntimeError(
                 "Has not supported inferring input shape from output shape for module/function: `{}`, {}"
                 .format(m_type, module_name))
-        input_cmask = infer_from_outshape[m_type](module_masks, out_shape)
+        if m_type in ['aten::view', 'aten::flatten', 'aten::mean', 'aten::reshape']:
+            input_cmask = infer_from_outshape[m_type](
+                module_masks, out_shape, self.torch_graph.name_to_node[module_name].auxiliary)
+        else:
+            input_cmask = infer_from_outshape[m_type](module_masks, out_shape)
 
     if input_cmask:
         predecessors = self.torch_graph.find_predecessors(module_name)
...
@@ -178,7 +166,6 @@ class ModelSpeedup:
         else:
             raise RuntimeError("Unsupported node type: {}".format(g_node.type))
 
     def speedup_model(self):
         """
         There are basically two steps:
...
@@ -187,8 +174,11 @@ class ModelSpeedup:
         """
         training = self.bound_model.training
         _logger.info("start to speed up the model")
         _logger.info("fix the mask conflict of the interdependent layers")
-        fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
+        _, conv_prune_dim = fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
+        set_conv_prune_dim(conv_prune_dim)
 
         _logger.info("infer module masks...")
         self.infer_modules_masks()
         _logger.info("replace compressed modules...")
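Taken together, speedup now auto-detects whether the masks prune filters (dim 0) or input channels (dim 1) and propagates that choice into shape inference. An illustrative end-to-end sketch of the user-facing flow (the model and 'mask.pth' are placeholders; a pruner would export the mask file, as in the tests below):

import torch
from torchvision.models import resnet18
from nni.compression.torch import ModelSpeedup

model = resnet18()
dummy_input = torch.randn(1, 3, 224, 224)
ms = ModelSpeedup(model, dummy_input, 'mask.pth')
ms.speedup_model()   # internally: fix_mask_conflict(...) -> conv_prune_dim,
                     # then set_conv_prune_dim(conv_prune_dim) before inference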
...

src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py
@@ -6,8 +6,22 @@ One is given output shape, infer its input shape and initialization parameters (e.g., weight's shape)
 The other is given input shape, infer its output shape and initialization parameters (e.g., weight's shape)
 """
 
 import logging
 import torch
 
 _logger = logging.getLogger(__name__)
 
+conv_prune_dim = -1
+
+def set_conv_prune_dim(dim):
+    """
+    Parameters:
+        dim: int
+            0: filter pruning
+            1: channel pruning
+    """
+    global conv_prune_dim
+    conv_prune_dim = dim
 
 class CoarseMask:
     """
...
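Filter pruning (dim 0) removes whole output filters of a Conv2d weight, while channel pruning (dim 1) removes, across every filter, the kernels belonging to one input channel. A small sketch of my own (not from the commit) showing the two mask layouts on a weight of shape [out, in, kh, kw], and the reductions that reveal the pruned dim:

import torch

weight = torch.ones(4, 3, 3, 3)           # [out_channels, in_channels, kh, kw]

filter_mask = torch.ones_like(weight)
filter_mask[2] = 0                        # dim-0 (filter) pruning: drop filter 2

channel_mask = torch.ones_like(weight)
channel_mask[:, 1] = 0                    # dim-1 (channel) pruning: drop input channel 1

print((filter_mask.sum((1, 2, 3)) == 0).nonzero().flatten())   # tensor([2])
print((channel_mask.sum((0, 2, 3)) == 0).nonzero().flatten())  # tensor([1])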
@@ -160,7 +174,7 @@ class ModuleMasks:
     The masks of a module, including the masks for weights, inputs, output
     """
-    def __init__(self, module_name):
+    def __init__(self, module_name, module=None):
         """
         Parameters
         ----------
...
@@ -168,6 +182,7 @@ class ModuleMasks:
             The name of the module or function
         """
         self.module_name = module_name
+        self.module = module
         self.param_masks = dict()
         self.input_mask = None
         self.output_mask = None
...
@@ -202,8 +217,8 @@ class ModuleMasks:
         self.output_mask = mask
 
     def __repr__(self):
-        return 'input_mask: {}, output_mask: {}, param_masks: {}'.format(
-            self.input_mask, self.output_mask, self.param_masks)
+        return 'module_name: {}, input_mask: {}, output_mask: {}, param_masks: {}'.format(
+            self.module_name, self.input_mask, self.output_mask, self.param_masks)
...
@@ -212,7 +227,8 @@ Infer input and output shape of a module/function from its weight mask
 """
 infer_from_mask = {
     'BatchNorm2d': lambda module_masks, mask: batchnorm2d_mask(module_masks, mask),
-    'Conv2d': lambda module_masks, mask: conv2d_mask(module_masks, mask)
+    'Conv2d': lambda module_masks, mask: conv2d_mask(module_masks, mask),
+    'Linear': lambda module_masks, mask, shape: linear_mask(module_masks, mask, shape)
 }
 
 """
...
@@ -260,7 +276,34 @@ infer_from_inshape = {
 Infer input and weight shape of a module/function from its output shape
 """
 infer_from_outshape = {
-    'Conv2d': lambda module_masks, mask: conv2d_outshape(module_masks, mask)
+    'Conv2d': lambda module_masks, mask: conv2d_outshape(module_masks, mask),
+    'BatchNorm2d': lambda module_masks, mask: batchnorm2d_outshape(module_masks, mask),
+
+    'MaxPool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),
+    'aten::max_pool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),
+    'aten::avg_pool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),
+    'aten::adaptive_avg_pool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),
+    'AvgPool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),
+    'AdaptiveAvgPool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),
+
+    'ReLU': lambda module_masks, mask: relu_outshape(module_masks, mask),
+    'ReLU6': lambda module_masks, mask: relu_outshape(module_masks, mask),
+    'aten::relu': lambda module_masks, mask: relu_outshape(module_masks, mask),
+    'aten::tanh': lambda module_masks, mask: relu_outshape(module_masks, mask),
+    'aten::tanh_': lambda module_masks, mask: relu_outshape(module_masks, mask),
+    'aten::hardtanh': lambda module_masks, mask: relu_outshape(module_masks, mask),
+    'aten::hardtanh_': lambda module_masks, mask: relu_outshape(module_masks, mask),
+    'aten::relu_': lambda module_masks, mask: relu_outshape(module_masks, mask),
+
+    'aten::add_': lambda module_masks, mask: add_outshape(module_masks, mask),
+    'aten::add': lambda module_mask, mask: add_outshape(module_mask, mask),
+
+    'aten::flatten': lambda module_mask, mask, shape: view_outshape(module_mask, mask, shape),
+    'aten::view': lambda module_masks, mask, shape: view_outshape(module_masks, mask, shape),
+    'aten::reshape': lambda module_masks, mask, shape: view_outshape(module_masks, mask, shape),
+    'aten::mean': lambda module_masks, mask, shape: mean_outshape(module_masks, mask, shape),
+
+    'Dropout': lambda module_masks, mask: dropout_outshape(module_masks, mask),
+    'Dropout2d': lambda module_masks, mask: dropout_outshape(module_masks, mask),
+    'aten::dropout': lambda module_masks, mask: dropout_outshape(module_masks, mask)
 }
 
 def dropout_inshape(module_masks, mask):
...
@@ -282,7 +325,15 @@ def dropout_inshape(module_masks, mask):
         module_masks.set_output_mask(mask)
     return module_masks.output_mask
 
+def dropout_outshape(module_masks, mask):
+    if module_masks.output_mask is None:
+        module_masks.set_output_mask(mask)
+        module_masks.set_input_mask(mask)
+        return module_masks.input_mask
+    # if already visited
+    assert all(module_masks.output_mask.mask_index[1] == mask.mask_index[1])
+    return module_masks.output_mask
 
 def cat_inshape(module_masks, mask, cat_info, last_visited):
     """
...
@@ -382,6 +433,20 @@ def add_inshape(module_masks, mask):
             raise Exception('Mask conflict happens!')
     return None
 
+def add_outshape(module_masks, mask):
+    """
+    Infer the input mask of the add operation from the
+    output mask.
+    """
+    assert isinstance(mask, CoarseMask)
+    if module_masks.output_mask is None:
+        module_masks.set_output_mask(mask)
+        module_masks.set_input_mask(mask)
+        return mask
+    else:
+        assert all(module_masks.output_mask.mask_index[1] == mask.mask_index[1])
+        return mask
 
 def batchnorm2d_inshape(module_masks, mask):
     """
...
@@ -412,6 +477,34 @@ def batchnorm2d_inshape(module_masks, mask):
     module_masks.set_param_masks('bias', weight_cmask)
     return mask
 
+def batchnorm2d_outshape(module_masks, mask):
+    """
+    We assume only the second dimension has coarse grained mask
+
+    Parameters
+    ----------
+    module_masks : ModuleMasks
+        The ModuleMasks instance of the batchnorm2d
+    mask : CoarseMask
+        The mask of its input tensor
+
+    Returns
+    -------
+    CoarseMask
+        The mask of its output tensor
+    """
+    assert isinstance(mask, CoarseMask)
+    assert len(mask.mask_index) in [2, 4]
+    assert mask.mask_index[1] is not None
+    assert mask.mask_index[0] is None
+
+    module_masks.set_input_mask(mask)
+    module_masks.set_output_mask(mask)
+    weight_cmask = CoarseMask(num_dim=1)
+    weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
+    module_masks.set_param_masks('weight', weight_cmask)
+    module_masks.set_param_masks('bias', weight_cmask)
+    return mask
 
 def linear_inshape(module_masks, mask):
     """
...
@@ -484,6 +577,42 @@ def view_inshape(module_masks, mask, shape):
     module_masks.set_output_mask(output_cmask)
     return output_cmask
 
+def view_outshape(module_masks, mask, shape):
+    """
+    Parameters
+    ----------
+    module_masks : ModuleMasks
+        The ModuleMasks instance of the ```flatten``` op
+    mask : CoarseMask
+        The mask of its output tensor
+    shape : dict
+        Original shape of its input and output tensors
+
+    Returns
+    -------
+    CoarseMask
+        The mask of its input tensor
+    """
+    # NOTE: the case constrained by the following four asserts
+    assert shape['in_shape'][0] == shape['out_shape'][0]
+    assert len(shape['in_shape']) == 4
+    assert len(shape['out_shape']) == 2
+    assert shape['out_shape'][1] == shape['in_shape'][1] * \
+        shape['in_shape'][2] * shape['in_shape'][3]
+
+    assert isinstance(mask, CoarseMask)
+    assert mask.mask_index[1] is not None
+    assert mask.mask_index[0] is None
+
+    module_masks.set_output_mask(mask)
+    input_cmask = CoarseMask(num_dim=4)
+    index = []
+    step_size = shape['in_shape'][2] * shape['in_shape'][3]
+    for loc in mask.mask_index[1]:
+        index.extend([loc * step_size + i for i in range(step_size)])
+    input_cmask.add_index_mask(dim=1, index=torch.tensor(index))  # pylint: disable=not-callable
+    module_masks.set_input_mask(input_cmask)
+    return input_cmask
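The index expansion above undoes flattening: when an N×C×H×W tensor is flattened to N×(C·H·W), input channel c occupies flattened positions [c·H·W, (c+1)·H·W). A quick worked check of my own, not part of the diff:

# If in_shape is [N, 2, 2, 2], then step_size = 2 * 2 = 4, and keeping input
# channel 1 keeps flattened positions 1*4 .. 1*4 + 3:
step_size = 2 * 2
kept_channels = [1]
kept_positions = [c * step_size + i for c in kept_channels for i in range(step_size)]
assert kept_positions == [4, 5, 6, 7]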
 
 def size_inshape(module_masks, mask):
     """
...
@@ -513,6 +642,26 @@ def mean_inshape(module_masks, mask, shape):
     module_masks.set_output_mask(output_cmask)
     return output_cmask
 
+def mean_outshape(module_masks, mask, shape):
+    """
+    Similar to the view operation; currently mask inference only supports
+    the mean operation on the 3rd and 4th dimensions.
+    """
+    assert shape['in_shape'][0] == shape['out_shape'][0]
+    assert shape['out_shape'][1] == shape['in_shape'][1]
+    assert len(shape['in_shape']) == 4
+    assert len(shape['out_shape']) == 2
+
+    assert isinstance(mask, CoarseMask)
+    assert mask.mask_index[1] is not None
+    assert mask.mask_index[0] is None
+
+    module_masks.set_output_mask(mask)
+    input_cmask = CoarseMask(num_dim=4)
+    input_cmask.add_index_mask(dim=1, index=mask.mask_index[1])
+    module_masks.set_input_mask(input_cmask)
+    return input_cmask
 
 def maxpool2d_inshape(module_masks, mask):
     """
     Assume only the second dimension is masked
...
@@ -541,6 +690,29 @@ def maxpool2d_inshape(module_masks, mask):
     module_masks.set_output_mask(mask)
     return mask
 
+def maxpool2d_outshape(module_masks, mask):
+    """
+    Assume only the second dimension is masked
+
+    Parameters
+    ----------
+    module_masks : ModuleMasks
+        The ModuleMasks instance of the maxpool2d
+    mask : CoarseMask
+        The mask of its input tensor
+
+    Returns
+    -------
+    CoarseMask
+        The mask of its output tensor
+    """
+    assert isinstance(mask, CoarseMask)
+    assert mask.mask_index[1] is not None
+    assert mask.mask_index[0] is None
+
+    module_masks.set_input_mask(mask)
+    module_masks.set_output_mask(mask)
+    return mask
 
 def relu_inshape(module_masks, mask):
     """
...
@@ -558,25 +730,44 @@ def relu_inshape(module_masks, mask):
     """
     assert isinstance(mask, CoarseMask)
-    # assert module_masks.input_mask is None, "A relu op can only be processed once"
+    if module_masks.input_mask is not None:
+        # check if there is a mask conflict;
+        # mask conflict should be solved before speedup
+        assert module_masks.input_mask <= mask
     module_masks.set_input_mask(mask)
     module_masks.set_output_mask(mask)
     return mask
 
+def relu_outshape(module_masks, mask):
+    """
+    Parameters
+    ----------
+    module_masks : ModuleMasks
+        The ModuleMasks instance of the relu
+    mask : CoarseMask
+        The mask of its input tensor
+
+    Returns
+    -------
+    CoarseMask
+        The mask of its output tensor
+    """
+    assert isinstance(mask, CoarseMask)
+    if module_masks.output_mask is not None:
+        # mask conflict should be solved before speedup
+        assert all(module_masks.output_mask.mask_index[1] == mask.mask_index[1])
+    module_masks.set_input_mask(mask)
+    module_masks.set_output_mask(mask)
+    return mask
 
 def batchnorm2d_mask(module_masks, mask):
     """
     Infer input and output shape from weight mask
 
     Parameters
     ----------
     module_masks : ModuleMasks
         The ModuleMasks instance of the batchnorm2d
     mask : dict
         The mask of its weights, from the user provided mask file
 
     Returns
     -------
     CoarseMask, CoarseMask
...
@@ -601,6 +792,38 @@ def batchnorm2d_mask(module_masks, mask):
     module_masks.set_output_mask(output_cmask)
     return input_cmask, output_cmask
 
+def linear_mask(module_masks, mask, shape):
+    """
+    Infer input and output shape from weight mask, with limitations:
+    only inferring the input mask is supported.
+
+    Parameters
+    ----------
+    module_masks : ModuleMasks
+        The ModuleMasks instance of the Linear
+    mask : dict
+        The mask of its weights, from the user provided mask file
+    shape: dict
+        Shape of its input and output tensors
+
+    Returns
+    -------
+    CoarseMask, CoarseMask
+        The mask of its input tensor, the mask of its output tensor
+    """
+    assert 'weight' in mask
+    num_input_dim = len(shape['in_shape'])
+
+    # Input data of the Linear module can have multiple dimensions;
+    # here we only support inferring the coarse mask on the last dimension.
+    nonzero_index = torch.nonzero(mask['weight'].sum(0), as_tuple=True)[0]
+
+    # infer shape of input tensor
+    input_cmask = CoarseMask(num_dim=num_input_dim)
+    input_cmask.add_index_mask(dim=num_input_dim - 1, index=nonzero_index)
+
+    module_masks.set_input_mask(input_cmask)
+    return input_cmask, None
 
 def conv2d_mask(module_masks, mask):
     """
...
@@ -618,12 +841,15 @@ def conv2d_mask(module_masks, mask):
     CoarseMask, CoarseMask
         The mask of its input tensor, the mask of its output tensor
     """
-    def convert_to_coarse_mask(mask):
+    def convert_to_coarse_mask(mask, dim=0):
         """
         Parameters
         ----------
         mask : dict
             Weight mask from user provided mask file
+        dim: int
+            0: filter pruning
+            1: channel pruning
         Returns
         -------
...
@@ -632,64 +858,69 @@ def conv2d_mask(module_masks, mask):
         """
         assert 'weight' in mask
         assert isinstance(mask['weight'], torch.Tensor)
+        assert dim in [0, 1]
         weight_mask = mask['weight']
-        shape = weight_mask.size()
-        ones = torch.ones(shape[1:]).to(weight_mask.device)
-        zeros = torch.zeros(shape[1:]).to(weight_mask.device)
-        index = []
-        for i in range(shape[0]):
-            if torch.all(torch.eq(weight_mask[i], ones)):
-                index.append(i)
-            elif torch.all(torch.eq(weight_mask[i], zeros)):
-                continue
-            else:
-                index = None
-                break
+        sum_idx = (1, 2, 3) if dim == 0 else (0, 2, 3)
+        index = torch.nonzero(weight_mask.abs().sum(sum_idx) != 0, as_tuple=True)[0]
+        if len(index) == weight_mask.shape[dim]:
+            # full mask
+            index = None
         if index is None:
             return None, None, None
         else:
             index = torch.LongTensor(index).to(weight_mask.device)
             weight_cmask = CoarseMask(num_dim=4)
-            weight_cmask.add_index_mask(dim=0, index=index)
+            weight_cmask.add_index_mask(dim=dim, index=index)
             bias_cmask = None
-            if 'bias' in mask and mask['bias'] is not None:
+            if dim == 0 and 'bias' in mask and mask['bias'] is not None:
                 bias_index = torch.nonzero(mask['bias'], as_tuple=True)[0]
                 assert torch.all(torch.eq(index, bias_index)), \
                     "bias mask should be consistent with weight mask"
                 bias_cmask = CoarseMask(num_dim=1)
                 bias_cmask.add_index_mask(dim=0, index=bias_index)
             return index, weight_cmask, bias_cmask
 
-    index, weight_cmask, bias_cmask = convert_to_coarse_mask(mask)
+    index, weight_cmask, bias_cmask = convert_to_coarse_mask(mask, dim=conv_prune_dim)
     if index is None:
         # TODO: fine grained mask speedup
         return None, None
     # deal with coarse grain mask
+    # mask conflict should be solved by fix_mask_conflict before speedup
     if 'weight' in module_masks.param_masks:
-        module_masks.param_masks['weight'].merge(weight_cmask)
-        module_masks.param_masks['bias'].merge(bias_cmask)
+        assert module_masks.param_masks['weight'] == weight_cmask
     else:
         module_masks.set_param_masks('weight', weight_cmask)
-        module_masks.set_param_masks('bias', bias_cmask)
-    output_cmask = CoarseMask(num_dim=4)
-    output_cmask.add_index_mask(dim=1, index=index)
-    if module_masks.output_mask is None:
-        module_masks.set_output_mask(output_cmask)
-    else:
-        module_masks.output_mask.merge(output_cmask)
-    return None, module_masks.output_mask
+        if conv_prune_dim == 0:
+            module_masks.set_param_masks('bias', bias_cmask)
+    io_cmask = CoarseMask(num_dim=4)
+    io_cmask.add_index_mask(dim=1, index=index)
+
+    if conv_prune_dim == 0:
+        if module_masks.output_mask is None:
+            module_masks.set_output_mask(io_cmask)
+        else:
+            assert module_masks.output_mask == io_cmask
+        return None, module_masks.output_mask
+    else:
+        if module_masks.input_mask is None:
+            module_masks.set_input_mask(io_cmask)
+        else:
+            assert module_masks.input_mask == io_cmask
+        return module_masks.input_mask, None
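The rewritten convert_to_coarse_mask replaces the old per-filter equality loop with a single reduction: summing |mask| over every dim except the prune dim flags which slices are entirely zero. A quick sanity sketch of my own, not part of the diff:

import torch

weight_mask = torch.ones(4, 3, 3, 3)     # [out, in, kh, kw]
weight_mask[:, 1] = 0                    # input channel 1 fully pruned
sum_idx = (0, 2, 3)                      # dim=1 -> channel pruning
kept = torch.nonzero(weight_mask.abs().sum(sum_idx) != 0, as_tuple=True)[0]
assert kept.tolist() == [0, 2]           # channels 0 and 2 survive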
 
 def conv2d_inshape(module_masks, mask):
     """
     Shape change of input tensor does not affect the shape of its output tensor
 
     Parameters
     ----------
     module_masks : ModuleMasks
         The ModuleMasks instance of the conv2d
     mask : CoarseMask
         The mask of its input tensor
 
     Returns
     -------
     CoarseMask
...
@@ -701,8 +932,15 @@ def conv2d_inshape(module_masks, mask):
     else:
         # the same conv layer may be accessed more
         # than once, such as a concat operation.
-        assert module_masks.input_mask <= mask
-        module_masks.input_mask.merge(mask)
+        # mask conflict should be solved by fix_mask_conflict before speedup
+        assert module_masks.input_mask == mask
+
+    # shape changes pass through depthwise conv layers
+    m = module_masks.module
+    if m.in_channels == m.out_channels == m.groups:
+        module_masks.output_mask = mask
+        module_masks.input_mask = mask
+        return mask
     return None
...
@@ -728,18 +966,25 @@ def conv2d_outshape(module_masks, mask):
     assert mask.mask_index[2] is None
     assert mask.mask_index[3] is None
 
-    if module_masks.output_mask is not None:
-        assert isinstance(module_masks.output_mask, CoarseMask)
-        # set shape of output
-        mask = module_masks.output_mask.merge(mask)
-    else:
+    if module_masks.output_mask is None:
         module_masks.output_mask = mask
-    # infer shape of parameters
+    else:
+        # mask conflict should be solved by fix_mask_conflict before speedup;
+        # mask and module_masks.output_mask may have different numbers of dimensions
+        # since they could be passed by linear or conv2d
+        assert all(module_masks.output_mask.mask_index[1] == mask.mask_index[1])
+
     weight_cmask = CoarseMask(num_dim=4)
     weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
     bias_cmask = CoarseMask(num_dim=1)
     bias_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
     module_masks.set_param_masks('weight', weight_cmask)
     module_masks.set_param_masks('bias', bias_cmask)
-    # input shape is not changed
+
+    # shape changes pass through depthwise conv layers
+    m = module_masks.module
+    if m.in_channels == m.out_channels == m.groups:
+        module_masks.output_mask = mask
+        module_masks.input_mask = mask
+        return mask
     return None
src/sdk/pynni/nni/compression/torch/utils/mask_conflict.py
@@ -4,9 +4,10 @@ import os
 import logging
 import torch
 import numpy as np
-from .shape_dependency import ChannelDependency, GroupDependency, CatPaddingDependency
+from .shape_dependency import ChannelDependency, GroupDependency, \
+    CatPaddingDependency, InputChannelDependency
+from .utils import get_module_by_name
 # logging.basicConfig(level = logging.DEBUG)
-_logger = logging.getLogger('FixMaskConflict')
+_logger = logging.getLogger(__name__)
 
 def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
     """
...
@@ -45,7 +46,7 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
     masks = fix_channel_mask.fix_mask()
     padding_cat_mask = CatMaskPadding(masks, model, dummy_input, traced)
     masks = padding_cat_mask.fix_mask()
-    return masks
+    return masks, fix_channel_mask.conv_prune_dim
 
 class MaskFix:
     def __init__(self, masks, model=None, dummy_input=None, traced=None):
...
@@ -221,74 +222,148 @@ class ChannelMaskConflict(MaskFix):
         we do not use the model and dummy_input to get the trace graph.
         """
         super(ChannelMaskConflict, self).__init__(masks, model, dummy_input, traced)
+        self.conv_prune_dim = detect_mask_prune_dim(masks, model)
+        _logger.info('detected conv prune dim: %s', self.conv_prune_dim)
 
     def fix_mask(self):
         """
         Fix the mask conflict before the mask inference for the layers that
         have shape dependencies. This function should be called before the
-        mask inference of the 'speedup' module.
+        mask inference of the 'speedup' module. Only structured pruning masks
+        are supported.
         """
-        channel_depen = ChannelDependency(self.model, self.dummy_input, self.traced)
+        if self.conv_prune_dim == 0:
+            channel_depen = ChannelDependency(self.model, self.dummy_input, self.traced)
+        else:
+            channel_depen = InputChannelDependency(self.model, self.dummy_input, self.traced)
         depen_sets = channel_depen.dependency_sets
+        sum_idx = (1, 2, 3) if self.conv_prune_dim == 0 else (0, 2, 3)
         for dset in depen_sets:
-            if len(dset) == 1:
-                # This layer has no channel dependency with other layers
+            if len(dset) <= 1:
                 continue
-            channel_remain = set()
+            # channel_masks is a list, each element is None or a vector, for example:
+            # [[0, 1, 1, 0, 0], [0, 0, 1, 1, 0], None], None means no channel
+            # is pruned.
+            channel_masks = []
             fine_grained = False
-            out_channels = None
-            # A flag that represents if all the layers in
-            # the dependency set are pruned
-            all_pruned = True
             for name in dset:
-                if name not in self.masks:
-                    # this layer is not pruned
-                    all_pruned = False
-                    continue
-                w_mask = self.masks[name]['weight']
-                if out_channels is None:
-                    out_channels = w_mask.size(0)
-                shape = w_mask.size()
-                count = np.prod(shape[1:])
-                all_ones = (w_mask.flatten(1).sum(-1) == count).nonzero().squeeze(1).tolist()
-                all_zeros = (w_mask.flatten(1).sum(-1) == 0).nonzero().squeeze(1).tolist()
-                if len(all_ones) + len(all_zeros) < w_mask.size(0):
-                    # In fine-grained pruning, there is no need to check
-                    # the shape conflict
-                    _logger.info('Layers %s using fine-grained pruning', ','.join(dset))
-                    fine_grained = True
-                    break
-                channel_remain.update(all_ones)
-                _logger.debug('Layer: %s ', name)
-                _logger.debug('Original pruned filters: %s', str(all_zeros))
-            # Update the masks for the layers in the dependency set
-            if fine_grained or out_channels is None:
-                # if use the fine-grained pruner or all the layers in
-                # this dependency set are not pruned
+                if name in self.masks:
+                    _, m = get_module_by_name(self.model, name)
+                    assert m is not None
+                    mask = self.masks[name]['weight']
+                    if type(m).__name__ == 'Conv2d':
+                        channel_mask = (mask.abs().sum(sum_idx) != 0).int()
+                        channel_masks.append(channel_mask)
+                        if (channel_mask.sum() * (mask.numel() / mask.shape[self.conv_prune_dim])).item() \
+                                != (mask > 0).sum().item():
+                            fine_grained = True
+                    elif type(m).__name__ == 'Linear':
+                        channel_masks.append((mask.abs().sum(0) != 0).int())
+                    elif type(m).__name__ == 'BatchNorm2d':
+                        channel_masks.append(mask.int())
+                    else:
+                        raise RuntimeError(f'unsupported module type: {type(m).__name__}')
+                else:
+                    # no mask means not pruned, equivalent to full masks
+                    channel_masks.append(None)
+            if fine_grained:
+                _logger.info('fine-grained mask detected, skip solving conflict for this set: %s', dset)
                 continue
-            if not all_pruned:
-                # if some layers are not pruned at all
-                # then all the layers in this dependency set
-                # cannot be pruned due to the shape dependency.
-                channel_remain.update(range(out_channels))
-            ori_channels = 0
+            if all(x is None for x in channel_masks):
+                continue
+            num_channels_list = [len(x) for x in channel_masks if x is not None]
+            # number of channels in same set should be identical
+            assert len(set(num_channels_list)) == 1
+            num_channels = num_channels_list[0]
+
+            for i, dim_mask in enumerate(channel_masks):
+                if dim_mask is None:
+                    channel_masks[i] = torch.ones(num_channels).int()
+
+            # merge masks with 'or'
+            merged_channel_mask = channel_masks[0].clone()
+            for i in range(1, len(channel_masks)):
+                merged_channel_mask = ((merged_channel_mask + channel_masks[i]) != 0).int()
+
+            merged_index = torch.nonzero(merged_channel_mask, as_tuple=True)[0]
+
             for name in dset:
                 if name not in self.masks:
-                    # this layer is not pruned at all
+                    # this layer is not pruned at all; in this case,
+                    # the other layers in the same dset should not
+                    # be pruned either
+                    assert all(merged_channel_mask)
                     continue
-                mask = self.masks[name]
-                w_shape = mask['weight'].size()
-                ori_channels = w_shape[0]
-                for i in channel_remain:
-                    mask['weight'][i] = torch.ones(w_shape[1:])
-                    if 'bias' in mask and mask['bias'] is not None:
-                        mask['bias'][i] = 1
-            _logger.info(','.join(dset))
-            _logger.info('Pruned Filters after fixing conflict:')
-            pruned_filters = set(list(range(ori_channels))) - channel_remain
-            _logger.info(str(sorted(pruned_filters)))
+                orig_mask = self.masks[name]['weight']
+                _, m = get_module_by_name(self.model, name)
+                new_mask = torch.zeros_like(orig_mask)
+                if type(m).__name__ == 'Conv2d':
+                    if self.conv_prune_dim == 0:
+                        new_mask[merged_index, :, :, :] = 1.
+                    else:
+                        new_mask[:, merged_index, :, :] = 1.
+                elif type(m).__name__ == 'Linear':
+                    new_mask[:, merged_index] = 1.
+                elif type(m).__name__ == 'BatchNorm2d':
+                    new_mask = merged_index.type_as(orig_mask)
+                else:
+                    raise RuntimeError(f'unsupported module type: {type(m).__name__}')
+
+                self.masks[name]['weight'] = new_mask
+                if 'bias' in self.masks[name] and self.masks[name]['bias'] is not None:
+                    if type(m).__name__ == 'Conv2d':
+                        assert self.conv_prune_dim == 0
+                    self.masks[name]['bias'] = merged_channel_mask.type_as(self.masks[name]['bias'])
 
         return self.masks
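Conflict resolution merges the per-layer channel masks with a logical OR, so a channel survives if any layer in the dependency set keeps it. A tiny worked example of my own:

import torch

# 1 = kept channel; two layers in one dependency set
m1 = torch.tensor([0, 1, 1, 0]).int()
m2 = torch.tensor([1, 1, 0, 0]).int()
merged = ((m1 + m2) != 0).int()          # OR-merge, as in fix_mask() above
assert merged.tolist() == [1, 1, 1, 0]   # only channel 3 stays pruned everywhere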
 
+def detect_mask_prune_dim(masks, model):
+    """
+    Detect how the masks of convolutional layers are pruned.
+
+    Parameters
+    ----------
+    masks: dict
+        A dict object that stores the masks.
+    model: nn.Module
+        Model object which the mask can be applied on.
+
+    Returns
+    -------
+    How the masks of convolutional layers are pruned; this depends on the pruning
+    algorithm: it returns 1 for masks generated by AMCPruner, and 0 for masks
+    generated by the rest of the NNI builtin pruners.
+    0: filter pruning, which prunes filters of the weights so that channels of the
+       output feature maps are pruned.
+    1: channel pruning, which prunes the kernels corresponding to each input channel
+       so that channels of the input feature maps are pruned.
+    """
+    dim0_preserved, dim1_preserved = 0., 0.
+    dim0_num, dim1_num = 0., 0.
+    for module_name in masks:
+        _, m = get_module_by_name(model, module_name)
+        if m is None or type(m).__name__ != 'Conv2d':
+            continue
+
+        mask = masks[module_name]['weight'].clone()
+        assert (mask >= 0).sum() == mask.numel(), \
+            "mask values should be greater than or equal to 0."
+        mask = (mask > 0).int()
+        mask = mask.view(mask.shape[0], mask.shape[1], -1)
+        dim0_mask = (mask.sum((1, 2)) > 0).int()
+        dim1_mask = (mask.sum((0, 2)) > 0).int()
+        dim0_preserved += dim0_mask.sum().item()
+        dim1_preserved += dim1_mask.sum().item()
+        dim0_num += len(dim0_mask)
+        dim1_num += len(dim1_mask)
+
+    if dim0_num == 0 or dim1_num == 0:
+        _logger.warning('no multi-dimension masks found.')
+        return 0
+
+    dim0_sparsity = 1. - dim0_preserved / dim0_num
+    dim1_sparsity = 1. - dim1_preserved / dim1_num
+    _logger.info('dim0 sparsity: %f', dim0_sparsity)
+    _logger.info('dim1 sparsity: %f', dim1_sparsity)
+
+    if dim0_sparsity == dim1_sparsity == 0.:
+        _logger.warning('nothing masked.')
+    if dim0_sparsity > 0 and dim1_sparsity > 0:
+        _logger.warning('both dim0 and dim1 masks found.')
+
+    return 0 if dim0_sparsity >= dim1_sparsity else 1
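A hedged sketch of how this detection plugs into fix_mask_conflict; the mask file name, model, and dummy_input are placeholders, and the mask dict layout ({layer_name: {'weight': Tensor, 'bias': Tensor}}) follows what pruner.export_model() writes in the tests of this commit:

from nni.compression.torch.utils.mask_conflict import fix_mask_conflict

fixed_masks, conv_prune_dim = fix_mask_conflict('mask.pth', model, dummy_input)
# conv_prune_dim is 0 for filter pruning and 1 for channel pruning (e.g. AMCPruner)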
src/sdk/pynni/nni/compression/torch/utils/utils.py (new file, 0 → 100644)
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+def get_module_by_name(model, module_name):
+    """
+    Get a module specified by its module name
+
+    Parameters
+    ----------
+    model : pytorch model
+        the pytorch model from which to get its module
+    module_name : str
+        the name of the required module
+
+    Returns
+    -------
+    module, module
+        the parent module of the required module, the required module
+    """
+    name_list = module_name.split(".")
+    for name in name_list[:-1]:
+        if hasattr(model, name):
+            model = getattr(model, name)
+        else:
+            return None, None
+    if hasattr(model, name_list[-1]):
+        leaf_module = getattr(model, name_list[-1])
+        return model, leaf_module
+    else:
+        return None, None
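Compared with the version removed from compressor.py, this shared helper returns (None, None) instead of raising when a dotted name does not resolve. A quick usage sketch of my own:

import torch.nn as nn
from nni.compression.torch.utils.utils import get_module_by_name

model = nn.Sequential()
model.add_module('backbone', nn.Sequential(nn.Conv2d(3, 8, 3)))
parent, leaf = get_module_by_name(model, 'backbone.0')   # dotted-path lookup
assert isinstance(leaf, nn.Conv2d)
assert get_module_by_name(model, 'no.such.module') == (None, None)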
src/sdk/pynni/tests/test_compression_utils.py
@@ -115,7 +115,7 @@ class AnalysisUtilsTest(TestCase):
         pruner.export_model(ck_file, mask_file)
         pruner._unwrap_model()
         # Fix the mask conflict
-        fixed_mask = fix_mask_conflict(mask_file, net, dummy_input)
+        fixed_mask, _ = fix_mask_conflict(mask_file, net, dummy_input)
 
         # use the channel dependency ground truth to check if
         # fix the mask conflict successfully
...
src/sdk/pynni/tests/test_model_speedup.py
@@ -12,6 +12,8 @@ from torchvision.models.resnet import resnet18
 from unittest import TestCase, main
 from nni.compression.torch import L1FilterPruner, apply_compression_results, ModelSpeedup
+from nni.compression.torch.pruning.weight_masker import WeightMasker
+from nni.compression.torch.pruning.one_shot import _StructuredFilterPruner
 
 torch.manual_seed(0)
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
...
@@ -104,6 +106,74 @@ def zero_bn_bias(model):
             shape = module.running_mean.data.size()
             module.running_mean = torch.zeros(shape).to(device)
 
+class L1ChannelMasker(WeightMasker):
+    def __init__(self, model, pruner):
+        self.model = model
+        self.pruner = pruner
+
+    def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
+        msg = 'module type {} is not supported!'.format(wrapper.type)
+        #assert wrapper.type == 'Conv2d', msg
+        weight = wrapper.module.weight.data
+        bias = None
+        if hasattr(wrapper.module, 'bias') and wrapper.module.bias is not None:
+            bias = wrapper.module.bias.data
+
+        if wrapper.weight_mask is None:
+            mask_weight = torch.ones(weight.size()).type_as(weight).detach()
+        else:
+            mask_weight = wrapper.weight_mask.clone()
+        if bias is not None:
+            if wrapper.bias_mask is None:
+                mask_bias = torch.ones(bias.size()).type_as(bias).detach()
+            else:
+                mask_bias = wrapper.bias_mask.clone()
+        else:
+            mask_bias = None
+        base_mask = {'weight_mask': mask_weight, 'bias_mask': mask_bias}
+
+        num_total = weight.size(1)
+        num_prune = int(num_total * sparsity)
+
+        if num_total < 2 or num_prune < 1:
+            return base_mask
+        w_abs = weight.abs()
+        if wrapper.type == 'Conv2d':
+            w_abs_structured = w_abs.sum((0, 2, 3))
+            threshold = torch.topk(w_abs_structured, num_prune, largest=False)[0].max()
+            mask_weight = torch.gt(w_abs_structured, threshold)[None, :, None, None] \
+                .expand_as(weight).type_as(weight)
+            return {'weight_mask': mask_weight.detach()}
+        else:
+            # Linear
+            assert wrapper.type == 'Linear'
+            w_abs_structured = w_abs.sum((0))
+            threshold = torch.topk(w_abs_structured, num_prune, largest=False)[0].max()
+            mask_weight = torch.gt(w_abs_structured, threshold)[None, :] \
+                .expand_as(weight).type_as(weight)
+            return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias}
+
+class L1ChannelPruner(_StructuredFilterPruner):
+    def __init__(self, model, config_list, optimizer=None, dependency_aware=False, dummy_input=None):
+        super().__init__(model, config_list, pruning_algorithm='l1', optimizer=optimizer,
+                         dependency_aware=dependency_aware, dummy_input=dummy_input)
+
+    def validate_config(self, model, config_list):
+        pass
+
+def channel_prune(model):
+    config_list = [{
+        'sparsity': SPARSITY,
+        'op_types': ['Conv2d', 'Linear']
+    }, {
+        'op_names': ['conv1'],
+        'exclude': True
+    }]
+
+    pruner = L1ChannelPruner(model, config_list)
+    masker = L1ChannelMasker(model, pruner)
+    pruner.masker = masker
+    pruner.compress()
+    pruner.export_model(model_path=MODEL_FILE, mask_path=MASK_FILE)
 
 class SpeedupTestCase(TestCase):
     def test_speedup_vgg16(self):
         prune_model_l1(vgg16())
...
@@ -145,10 +215,20 @@ class SpeedupTestCase(TestCase):
         assert model.backbone2.fc1.in_features == int(orig_model.backbone2.fc1.in_features * SPARSITY)
 
     def test_speedup_integration(self):
-        for model_name in ['resnet18', 'squeezenet1_1', 'mobilenet_v2', 'densenet121', 'densenet169', 'inception_v3']:
+        for model_name in ['resnet18', 'squeezenet1_1', 'mobilenet_v2', 'densenet121',
+                           'densenet169', 'inception_v3', 'resnet50']:
+            kwargs = {'pretrained': True}
+            if model_name == 'resnet50':
+                # testing multiple groups
+                kwargs = {'pretrained': False, 'groups': 4}
+
             Model = getattr(models, model_name)
-            net = Model(pretrained=True, progress=False).to(device)
-            speedup_model = Model().to(device)
+            net = Model(**kwargs).to(device)
+            speedup_model = Model(**kwargs).to(device)
             net.eval()  # this line is necessary
             speedup_model.eval()
             # random generate the prune config for the pruner
...
@@ -165,6 +245,9 @@ class SpeedupTestCase(TestCase):
             data = torch.ones(BATCH_SIZE, 3, 224, 224).to(device)
             ms = ModelSpeedup(speedup_model, data, MASK_FILE)
             ms.speedup_model()
 
+            speedup_model.eval()
+
             ori_out = net(data)
             speeded_out = speedup_model(data)
             ori_sum = torch.sum(ori_out).item()
...
@@ -174,6 +257,35 @@ class SpeedupTestCase(TestCase):
         assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
             (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
 
+    def test_channel_prune(self):
+        orig_net = resnet18(num_classes=10).to(device)
+        channel_prune(orig_net)
+        state_dict = torch.load(MODEL_FILE)
+
+        orig_net = resnet18(num_classes=10).to(device)
+        orig_net.load_state_dict(state_dict)
+        apply_compression_results(orig_net, MASK_FILE)
+        orig_net.eval()
+
+        net = resnet18(num_classes=10).to(device)
+        net.load_state_dict(state_dict)
+        net.eval()
+
+        data = torch.randn(BATCH_SIZE, 3, 224, 224).to(device)
+        ms = ModelSpeedup(net, data, MASK_FILE)
+        ms.speedup_model()
+        ms.bound_model(data)
+        net.eval()
+
+        ori_sum = orig_net(data).abs().sum().item()
+        speeded_sum = net(data).abs().sum().item()
+
+        print(ori_sum, speeded_sum)
+        assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
+            (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
 
     def tearDown(self):
         os.remove(MODEL_FILE)
         os.remove(MASK_FILE)