OpenDAS / nni — commit 7eedec46 (unverified)

Authored Jul 14, 2021 by Ningxin Zheng; committed via GitHub on Jul 14, 2021.
Model Speedup Refactor (#3462)

Parent commit: 5b99b598
Showing 13 changed files, with 2218 additions and 1701 deletions:
docs/en_US/Compression/CompressionReference.rst        +0    -3
nni/common/graph_utils.py                              +12   -2
nni/compression/pytorch/speedup/compress_modules.py    +403  -192
nni/compression/pytorch/speedup/compressor.py          +453  -126
nni/compression/pytorch/speedup/infer_mask.py          +378  -0
nni/compression/pytorch/speedup/infer_shape.py         +0    -1146
nni/compression/pytorch/speedup/jit_translate.py       +553  -0
nni/compression/pytorch/utils/__init__.py              +1    -0
nni/compression/pytorch/utils/mask_conflict.py         +50   -90
nni/compression/pytorch/utils/shape_dependency.py      +174  -102
nni/compression/pytorch/utils/utils.py                 +52   -0
test/ut/sdk/test_compression_utils.py                  +1    -1
test/ut/sdk/test_model_speedup.py                      +141  -39
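For orientation before the per-file diffs: the refactor keeps NNI's documented two-call entry point for model speedup. A minimal usage sketch, assuming a model already pruned by an NNI pruner (MyModel and the mask path are placeholders):

import torch
from nni.compression.pytorch import ModelSpeedup

model = MyModel()                                    # hypothetical pruned model
dummy_input = torch.rand(8, 3, 224, 224)             # first dim is the batch size
ms = ModelSpeedup(model, dummy_input, 'mask.pth')    # 'mask.pth' produced by a pruner
ms.speedup_model()                                   # model is now physically smaller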
docs/en_US/Compression/CompressionReference.rst

@@ -140,9 +140,6 @@ Topology Utilities

 .. autoclass:: nni.compression.pytorch.utils.shape_dependency.GroupDependency
    :members:

 .. autoclass:: nni.compression.pytorch.utils.mask_conflict.CatMaskPadding
    :members:

 .. autoclass:: nni.compression.pytorch.utils.mask_conflict.GroupMaskConflict
    :members:

(Three documentation lines are removed in this hunk; the collapsed diff view does not preserve the removal markers.)
nni/common/graph_utils.py

@@ -71,7 +71,11 @@ class TorchGraph:

     def _trace(self, model, dummy_input):
         training = model.training
         model.eval()
-        self.trace = torch.jit.trace(model, dummy_input)
+        kw_args = {}
+        if torch.__version__ >= '1.6.0':
+            # only pytorch 1.6.0 and above has the strict option
+            kw_args['strict'] = False
+        self.trace = torch.jit.trace(model, dummy_input, **kw_args)
         torch._C._jit_pass_inline(self.trace.graph)
         model.train(training)
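The strict=False flag matters when a traced model returns mutable containers (lists or dicts): with the default strict=True, torch.jit.trace raises an error on such outputs. A minimal sketch of the difference (hypothetical module, PyTorch >= 1.6):

import torch

class DictOutput(torch.nn.Module):
    def forward(self, x):
        # returning a dict makes strict tracing fail
        return {'out': x * 2}

m, dummy = DictOutput(), torch.rand(2, 3)
traced = torch.jit.trace(m, dummy, strict=False)   # ok
# torch.jit.trace(m, dummy)  # would raise: dict outputs require strict=False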
@@ -247,6 +251,7 @@ class TorchModuleGraph(TorchGraph):

     def __init__(self, model=None, dummy_input=None, traced_model=None):
         super().__init__(model, dummy_input, traced_model)
         self.global_count = 0
+        self.reused_module = set()
         self.name_to_node, self.input_to_node, self.output_to_node = self._build_graph()
         self._extract_auxiliary_info()
@@ -390,9 +395,12 @@ class TorchModuleGraph(TorchGraph):

                 outputs.append(output_name)
             else:
                 outputs.append(output_name)
+        unique_outputs = list(set(outputs))
+        # remove the duplicated output names while keeping their original order
+        unique_outputs.sort(key=outputs.index)
         nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
-                             node_group, inputs=list(inputs), outputs=list(outputs))
+                             node_group, inputs=list(inputs), outputs=unique_outputs)
         return nodepy

     def _extract_cat_info(self, node_group, cpp_node):
@@ -724,6 +732,8 @@ class TorchModuleGraph(TorchGraph):

                 unique_name = module_name
                 if use_count > 0:
                     unique_name = module_name + '.%d' % use_count
+                    self.reused_module.add(unique_name)
+                    self.reused_module.add(module_name)
                 node_group = self._expand_module_node(
                     node, module_name, unique_name, module_to_type[module_name],
                     node_cpps, input_to_node, output_to_node, 'module')
nni/compression/pytorch/speedup/compress_modules.py
@@ -3,222 +3,394 @@

 import logging
 import torch
-from .infer_shape import ModuleMasks
+import torch.nn as nn

 _logger = logging.getLogger(__name__)
The replace_module dispatch table is rewritten. Every entry now receives masks, the (in_masks, output_mask, weight_masks) tuple produced by the new mask inference, instead of the old ModuleMasks object, and coverage grows from 14 to 33 op types (adding BatchNorm1d, LayerNorm, Upsample, and a long list of element-wise activations). The old ModuleMasks-based bodies of the replace_* functions below are deleted and rewritten against the mask tuple.

replace_module = {
    'BatchNorm2d': lambda module, masks: replace_batchnorm2d(module, masks),
    'BatchNorm1d': lambda module, masks: replace_batchnorm1d(module, masks),
    'Conv2d': lambda module, masks: replace_conv2d(module, masks),
    'Linear': lambda module, masks: replace_linear(module, masks),
    'MaxPool2d': lambda module, masks: no_replace(module, masks),
    'AvgPool2d': lambda module, masks: no_replace(module, masks),
    'AdaptiveAvgPool2d': lambda module, masks: no_replace(module, masks),
    'ReLU': lambda module, masks: no_replace(module, masks),
    'ReLU6': lambda module, masks: no_replace(module, masks),
    'LeakyReLU': lambda module, masks: no_replace(module, masks),
    'ELU': lambda module, masks: no_replace(module, masks),
    'Hardtanh': lambda module, masks: no_replace(module, masks),
    'Hardsigmoid': lambda module, masks: no_replace(module, masks),
    'LogSigmoid': lambda module, masks: no_replace(module, masks),
    'PReLU': lambda module, masks: replace_prelu(module, masks),
    'RReLU': lambda module, masks: no_replace(module, masks),
    'SELU': lambda module, masks: no_replace(module, masks),
    'CELU': lambda module, masks: no_replace(module, masks),
    'GELU': lambda module, masks: no_replace(module, masks),
    'Sigmoid': lambda module, masks: no_replace(module, masks),
    'SiLU': lambda module, masks: no_replace(module, masks),
    'Mish': lambda module, masks: no_replace(module, masks),
    'Tanh': lambda module, masks: no_replace(module, masks),
    'Softplus': lambda module, masks: no_replace(module, masks),
    'Softshrink': lambda module, masks: no_replace(module, masks),
    'Softmax': lambda module, masks: no_replace(module, masks),
    'Tanhshrink': lambda module, masks: no_replace(module, masks),
    'Dropout': lambda module, masks: no_replace(module, masks),
    'Dropout2d': lambda module, masks: no_replace(module, masks),
    'Dropout3d': lambda module, masks: no_replace(module, masks),
    'Upsample': lambda module, masks: no_replace(module, masks),
    'LayerNorm': lambda module, masks: replace_layernorm(module, masks),
    'ConvTranspose2d': lambda module, masks: replace_convtranspose2d(module, masks)
}
def convert_to_coarse_mask(t_mask, dim):
    """
    Convert the mask tensor to the coarse-grained mask tensor.

    Parameters
    ----------
    t_mask: torch.Tensor
        The tensor only has 1s and 0s; 0 indicates this value is masked
        and 1 indicates the corresponding value is not masked.
    dim: int
        Try to reduce the mask tensor on this dimension.

    Returns
    -------
    indexes: torch.Tensor
        The indexes of the sparsity that can be structurally removed.
    remained_indexes: torch.Tensor
        The indexes of values that need to be remained.
    """
    assert isinstance(t_mask, torch.Tensor)
    shape = list(t_mask.size())
    n_dims = len(shape)
    dim_list = list(range(n_dims))
    # try to reduce the mask from the dim-th dimension
    dim_list.remove(dim)
    t_merged = torch.sum(t_mask, dim_list)
    assert t_merged.size(0) == shape[dim]
    all_pruned = t_merged == 0
    need_remain = t_merged != 0
    # return the indexes of the sparsity that can be removed
    indexes = torch.nonzero(all_pruned, as_tuple=True)[0]
    remained_indexes = torch.nonzero(need_remain, as_tuple=True)[0]
    return indexes, remained_indexes
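A quick worked example of the reduction (values assumed for illustration): for a fine-grained mask of shape [batch, channels], summing over every dimension except dim=1 leaves a per-channel count, and channels whose count is zero can be structurally removed:

# channel 1 is fully masked across the batch, channels 0 and 2 are not
t_mask = torch.tensor([[1., 0., 1.],
                       [1., 0., 0.]])
pruned, remained = convert_to_coarse_mask(t_mask, dim=1)
print(pruned)    # tensor([1])    -> channel 1 can be removed
print(remained)  # tensor([0, 2]) -> channels 0 and 2 are kept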
def no_replace(module, masks):
    """
    No need to replace
    """
    _logger.debug("no need to replace")
    return module
def replace_prelu(prelu, masks):
    """
    Parameters
    ----------
    prelu : torch.nn.PReLU
        The prelu module to be replaced
    masks : tuple of masks
        The input/output/weight masks of the target module

    Returns
    -------
    torch.nn.PReLU
        The new prelu module
    """
    in_masks, output_mask, weight_mask = masks
    assert len(in_masks) == 1
    assert isinstance(output_mask, torch.Tensor)
    in_mask = in_masks[0]
    weight_mask = weight_mask['weight']
    pruned_in, remained_in = convert_to_coarse_mask(in_mask, 1)
    pruned_out, remained_out = convert_to_coarse_mask(output_mask, 1)
    n_remained_in = weight_mask.size(0) - pruned_in.size(0)
    n_remained_out = weight_mask.size(0) - pruned_out.size(0)
    remained_in, remained_out = remained_in.to(
        prelu.weight.device), remained_out.to(prelu.weight.device)
    assert n_remained_in == n_remained_out
    if n_remained_in == 0:
        return torch.nn.Identity()
    new_prelu = torch.nn.PReLU(n_remained_in)
    new_prelu.weight.data = torch.index_select(prelu.weight.data, 0, remained_in)
    return new_prelu
def replace_linear(linear, masks):
    """
    This function will replace the original linear module according to
    the inferred masks. It supports both fine-grained and coarse-grained
    sparsity: in the fine-grained scenario, it removes the whole
    columns/rows that happen to be totally covered by the masks.

    Parameters
    ----------
    linear : torch.nn.Linear
        The linear module to be replaced
    masks : Tuple of the input masks, output masks and weight masks
        Tuple of the masks, for example
        ([input_m1, input_m2], [output_m], {'weight':weight_m})

    Returns
    -------
    torch.nn.Linear
        The new linear module
    """
    in_masks, output_mask, weight_mask = masks
    assert isinstance(linear, nn.Linear)
    assert len(in_masks) == 1
    assert isinstance(output_mask, torch.Tensor)
    in_mask = in_masks[0]
    weight_mask = weight_mask['weight']
    # N C K
    pruned_in, remained_in = convert_to_coarse_mask(in_mask, 1)
    pruned_out, remained_out = convert_to_coarse_mask(output_mask, 1)
    n_remained_in = weight_mask.size(1) - pruned_in.size(0)
    n_remained_out = weight_mask.size(0) - pruned_out.size(0)
    remained_in, remained_out = remained_in.to(
        linear.weight.device), remained_out.to(linear.weight.device)
    _logger.info("replace linear with new in_features: %d, out_features: %d",
                 n_remained_in, n_remained_out)
    need_bias = False
    if linear.bias is not None:
        need_bias = True
    new_linear = torch.nn.Linear(in_features=n_remained_in,
                                 out_features=n_remained_out,
                                 bias=need_bias)
    new_linear.to(linear.weight.device)
    # Copy the remained weight from the original module
    with torch.no_grad():
        tmp_weight_data = torch.index_select(linear.weight.data, 0, remained_out)
        new_linear.weight.data = torch.index_select(tmp_weight_data, 1, remained_in)
        if linear.bias is not None:
            new_linear.bias.data = torch.index_select(linear.bias.data, 0, remained_out)
    return new_linear
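A hedged sketch of calling the new interface directly; the mask tuple layout follows the docstring above, and the shapes are invented for illustration:

linear = nn.Linear(4, 3)
# mask out input feature 3 and output feature 1 for every sample in the batch
in_m = torch.ones(8, 4);  in_m[:, 3] = 0       # (batch, in_features)
out_m = torch.ones(8, 3); out_m[:, 1] = 0      # (batch, out_features)
w_m = {'weight': torch.ones_like(linear.weight)}
small = replace_linear(linear, ([in_m], out_m, w_m))
print(small)   # Linear(in_features=3, out_features=2, bias=True)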
def replace_batchnorm1d(norm, masks):
    """
    Parameters
    ----------
    norm : torch.nn.BatchNorm1d
        The batchnorm module to be replaced
    masks : Tuple of the input masks, output masks and weight masks
        Tuple of the masks, for example
        ([input_m1, input_m2], [output_m], {'weight':weight_m})

    Returns
    -------
    torch.nn.BatchNorm1d
        The new batchnorm module
    """
    in_masks, output_mask, _ = masks
    assert isinstance(norm, nn.BatchNorm1d)
    in_mask = in_masks[0]
    # N, C
    _, remained_in = convert_to_coarse_mask(in_mask, 1)
    _, remained_out = convert_to_coarse_mask(output_mask, 1)
    assert remained_in.size(0) == remained_out.size(0)
    num_features = remained_in.size(0)
    _logger.info("replace batchnorm1d with num_features: %d", num_features)
    new_norm = torch.nn.BatchNorm1d(num_features=num_features,
                                    eps=norm.eps,
                                    momentum=norm.momentum,
                                    affine=norm.affine,
                                    track_running_stats=norm.track_running_stats)
    # assign weights
    new_norm.weight.data = torch.index_select(norm.weight.data, 0, remained_in)
    new_norm.bias.data = torch.index_select(norm.bias.data, 0, remained_in)
    new_norm.running_mean.data = torch.index_select(norm.running_mean.data, 0, remained_in)
    new_norm.running_var.data = torch.index_select(norm.running_var.data, 0, remained_in)
    return new_norm
def replace_batchnorm2d(norm, masks):
    """
    Parameters
    ----------
    norm : torch.nn.BatchNorm2d
        The batchnorm module to be replaced
    masks : Tuple of the input masks, output masks and weight masks
        Tuple of the masks, for example
        ([input_m1, input_m2], [output_m], {'weight':weight_m})

    Returns
    -------
    torch.nn.BatchNorm2d
        The new batchnorm module
    """
    in_masks, output_mask, _ = masks
    assert isinstance(norm, nn.BatchNorm2d)
    in_mask = in_masks[0]
    # N, C, H, W
    _, remained_in = convert_to_coarse_mask(in_mask, 1)
    _, remained_out = convert_to_coarse_mask(output_mask, 1)
    assert remained_in.size(0) == remained_out.size(0)
    num_features = remained_in.size(0)
    _logger.info("replace batchnorm2d with num_features: %d", num_features)
    new_norm = torch.nn.BatchNorm2d(num_features=num_features,
                                    eps=norm.eps,
                                    momentum=norm.momentum,
                                    affine=norm.affine,
                                    track_running_stats=norm.track_running_stats)
    # assign weights
    new_norm.weight.data = torch.index_select(norm.weight.data, 0, remained_in)
    new_norm.bias.data = torch.index_select(norm.bias.data, 0, remained_in)
    new_norm.running_mean.data = torch.index_select(norm.running_mean.data, 0, remained_in)
    new_norm.running_var.data = torch.index_select(norm.running_var.data, 0, remained_in)
    return new_norm
def replace_conv2d(conv, masks):
    """
    Replace the original conv with a new one according to the inferred
    masks. The function supports both fine-grained and coarse-grained
    sparsity: in the fine-grained scenario, it removes the filters that
    happen to be totally covered by the fine-grained sparsity.

    Parameters
    ----------
    conv : torch.nn.Conv2d
        The conv2d module to be replaced
    masks : Tuple of the input masks, output masks and weight masks
        Tuple of the masks, for example
        ([input_m1, input_m2], [output_m], {'weight':weight_m})

    Returns
    -------
    torch.nn.Conv2d
        The new conv2d module
    """
    in_masks, output_mask, weight_masks = masks
    assert isinstance(conv, nn.Conv2d)
    # the conv layer should only have one input tensor
    assert len(in_masks) == 1
    in_mask = in_masks[0]
    weight_mask = weight_masks['weight']
    pruned_in, remained_in = convert_to_coarse_mask(in_mask, 1)
    pruned_out, remained_out = convert_to_coarse_mask(output_mask, 1)
    n_remained_in = weight_mask.size(1) * conv.groups - pruned_in.size(0)
    n_remained_out = weight_mask.size(0) - pruned_out.size(0)
    assert n_remained_in == remained_in.size(0)
    assert n_remained_out == remained_out.size(0)
    k_size1, k_size2 = conv.kernel_size
    # Note: We should resolve the group dependency of the conv layers before
    # running into here.
    # Check if the mask tensor meets the group dependency and calculate the
    # new number of groups after pruning.
    # the original step size of the input channel for each group
    ori_inchannel_step = int(conv.in_channels / conv.groups)
    # the original step size of the output channel for each group
    ori_outchannel_step = int(conv.out_channels / conv.groups)
    # calculate the new_inchannel_step and new_outchannel_step first
    new_inchannel_step = new_outchannel_step = None
    for groupid in range(conv.groups):
        in_start = groupid * ori_inchannel_step
        in_end = in_start + ori_inchannel_step
        out_start = groupid * ori_outchannel_step
        out_end = out_start + ori_outchannel_step
        current_input_index = list(
            filter(lambda x: in_start <= x and x < in_end, remained_in.tolist()))
        current_output_index = list(
            filter(lambda x: out_start <= x and x < out_end, remained_out.tolist()))
        if len(current_input_index) == 0:
            # if the whole group is pruned
            continue
        else:
            new_inchannel_step = len(current_input_index)
            new_outchannel_step = len(current_output_index)
            break
    tmp_weight = torch.ones(n_remained_out, new_inchannel_step, k_size1, k_size2)
    tmp_weight = tmp_weight.to(conv.weight.device)
    assert n_remained_in % new_inchannel_step == 0
    assert n_remained_out % new_outchannel_step == 0
    new_groups = 0
    for groupid in range(conv.groups):
        in_start = groupid * ori_inchannel_step
        in_end = in_start + ori_inchannel_step
        out_start = groupid * ori_outchannel_step
        out_end = out_start + ori_outchannel_step
        current_input_index = list(
            filter(lambda x: in_start <= x and x < in_end, remained_in.tolist()))
        current_output_index = list(
            filter(lambda x: out_start <= x and x < out_end, remained_out.tolist()))
        # remap the global index to the group index
        current_input_index = [x - in_start for x in current_input_index]
        if len(current_input_index) == 0:
            # if the whole group is pruned
            assert len(current_output_index) == 0
            continue
        # check if the number of remained channels of each group is the same
        assert len(current_input_index) == new_inchannel_step
        assert len(current_output_index) == new_outchannel_step
        # copy the weight into tmp_weight
        new_out_start = new_outchannel_step * new_groups
        new_out_end = new_out_start + new_outchannel_step
        tmp_weight[new_out_start:new_out_end] = torch.index_select(
            conv.weight[current_output_index], 1,
            torch.as_tensor(current_input_index, dtype=torch.long).to(conv.weight.device))
        new_groups += 1
    _logger.debug("replace conv2d with in_channels: %d, out_channels: %d",
                  n_remained_in, n_remained_out)
    # need_bias is a flag that indicates whether the conv layer needs a
    # bias: if the original conv doesn't have a bias and there is
    # no constant to be folded into the bias, need_bias is False.
    need_bias = conv.bias is not None
    new_conv = torch.nn.Conv2d(in_channels=n_remained_in,
                               out_channels=n_remained_out,
                               kernel_size=conv.kernel_size,
                               stride=conv.stride,
                               padding=conv.padding,
                               dilation=conv.dilation,
                               groups=new_groups,
                               bias=need_bias,
                               padding_mode=conv.padding_mode)
    new_conv.to(conv.weight.device)
    # the caller (replace_compressed_modules) wraps replacement in torch.no_grad()
    new_conv.weight.copy_(tmp_weight)
    # copy the bias data
    if conv.bias is not None:
        new_conv.bias.data.copy_(torch.index_select(conv.bias.data, 0, remained_out))
    return new_conv
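To make the group bookkeeping concrete, here is a small assumed example: a conv with groups=2 and in_channels=4, where one input channel is kept in each group. The per-group remap mirrors the loop above and yields new_inchannel_step = 1 with new_groups = 2:

# groups=2, in_channels=4 -> channels {0,1} belong to group 0, {2,3} to group 1
remained_in = [0, 2]               # one channel kept per group
ori_inchannel_step = 4 // 2        # = 2
for groupid in range(2):
    in_start = groupid * ori_inchannel_step
    in_end = in_start + ori_inchannel_step
    kept = [x - in_start for x in remained_in if in_start <= x < in_end]
    print(groupid, kept)           # 0 [0] / 1 [0] -> evenly pruned, assertion passes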
def replace_convtranspose2d(convtrans, masks):
    """
    We need another replace function for
    convtranspose2d, because the layout of
    ...
    Parameters
    ----------
    convtrans : torch.nn.ConvTranspose2d
        The conv2d module to be replaced
    masks : Tuple of the input masks, output masks and weight masks
        Tuple of the masks, for example
        ([input_m1, input_m2], [output_m], {'weight':weight_m})

    Returns
    -------
    torch.nn.ConvTranspose2d
        The new conv2d module
    """
    in_masks, output_mask, weight_masks = masks
    assert isinstance(convtrans, torch.nn.ConvTranspose2d)
    assert len(in_masks) == 1
    in_mask = in_masks[0]
    weight_mask = weight_masks['weight']
    pruned_in, remained_in = convert_to_coarse_mask(in_mask, 1)
    pruned_out, remained_out = convert_to_coarse_mask(output_mask, 1)
    # ConvTranspose2d has the weight shape of [N_in, N_out/groups, k1, k2]
    n_remained_in = weight_mask.size(0) - pruned_in.size(0)
    n_remained_out = weight_mask.size(1) * convtrans.groups - pruned_out.size(0)
    assert n_remained_in == remained_in.size(0)
    assert n_remained_out == remained_out.size(0)
    k_size1, k_size2 = convtrans.kernel_size
    # Note: we should resolve the group dependency of the convtrans layers before
    # running into this function
    ori_inchannel_step = int(convtrans.in_channels / convtrans.groups)
    ori_outchannel_step = int(convtrans.out_channels / convtrans.groups)
    new_inchannel_step = new_outchannel_step = None
    for groupid in range(convtrans.groups):
        in_start = groupid * ori_inchannel_step
        in_end = in_start + ori_inchannel_step
        out_start = groupid * ori_outchannel_step
        out_end = out_start + ori_outchannel_step
        current_input_index = list(
            filter(lambda x: in_start <= x and x < in_end, remained_in.tolist()))
        current_output_index = list(
            filter(lambda x: out_start <= x and x < out_end, remained_out.tolist()))
        if len(current_input_index) == 0:
            # if the whole group is pruned
            continue
        else:
            new_inchannel_step = len(current_input_index)
            new_outchannel_step = len(current_output_index)
            break
    tmp_weight = torch.ones(n_remained_in, new_outchannel_step, k_size1, k_size2)
    tmp_weight = tmp_weight.to(convtrans.weight.device)
    assert n_remained_in % new_inchannel_step == 0
    assert n_remained_out % new_outchannel_step == 0
    new_groups = 0
    for groupid in range(convtrans.groups):
        # copy the weights of this group
        in_start = groupid * ori_inchannel_step
        in_end = in_start + ori_inchannel_step
        out_start = groupid * ori_outchannel_step
        out_end = out_start + ori_outchannel_step
        current_input_index = list(
            filter(lambda x: in_start <= x and x < in_end, remained_in.tolist()))
        current_output_index = list(
            filter(lambda x: out_start <= x and x < out_end, remained_out.tolist()))
        # remap the global index to the group index;
        # in the convtranspose layer, the groups are on
        # the output channel dimension
        current_output_index = [x - out_start for x in current_output_index]
        if len(current_input_index) == 0:
            # if the whole group is pruned
            assert len(current_output_index) == 0
            continue
        # check if the number of remained channels of each group is the same
        assert len(current_input_index) == new_inchannel_step
        assert len(current_output_index) == new_outchannel_step
        # copy the weight into tmp_weight
        new_in_start = new_inchannel_step * new_groups
        new_in_end = new_in_start + new_inchannel_step
        tmp_weight[new_in_start:new_in_end] = torch.index_select(
            convtrans.weight[current_input_index], 1,
            torch.as_tensor(current_output_index, dtype=torch.long).to(convtrans.weight.device))
        new_groups += 1
    _logger.debug('Replace convtranspose2d with in_channels:%d out_channels:%d',
                  n_remained_in, n_remained_out)
    new_convtrans = torch.nn.ConvTranspose2d(in_channels=n_remained_in,
                                             out_channels=n_remained_out,
                                             kernel_size=convtrans.kernel_size,
                                             stride=convtrans.stride,
                                             padding=convtrans.padding,
                                             dilation=convtrans.dilation,
                                             groups=new_groups,
                                             bias=convtrans.bias is not None,
                                             padding_mode=convtrans.padding_mode)
    new_convtrans.to(convtrans.weight.device)
    new_convtrans.weight.copy_(tmp_weight)
    if convtrans.bias is not None:
        if output_mask is not None:
            new_convtrans.bias.data[:] = torch.index_select(
                convtrans.bias.data, 0, remained_out)
        else:
            new_convtrans.bias.data.copy_(convtrans.bias.data)
    return new_convtrans
def replace_layernorm(layernorm, masks):
    in_masks, _, _ = masks
    assert isinstance(layernorm, nn.LayerNorm)
    assert len(in_masks) == 1
    in_mask = in_masks[0]
    dim_n = len(in_mask.size())
    new_shape = []
    for i in range(1, dim_n):
        sum_dims = list(range(0, dim_n))
        sum_dims.remove(i)
        reduced = torch.sum(in_mask, sum_dims)
        n_remained = torch.sum(reduced > 0)
        new_shape.append(n_remained)
    return nn.LayerNorm(tuple(new_shape), layernorm.eps, layernorm.elementwise_affine)
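A small assumed example of the shape computation: for an input mask of shape [batch, C, L] where the last two positions on the L axis are fully masked across the batch, the rebuilt LayerNorm normalizes over the reduced shape:

mask = torch.ones(8, 4, 6)
mask[:, :, 4:] = 0        # last two positions masked for every sample
ln = nn.LayerNorm([4, 6])
new_ln = replace_layernorm(ln, ([mask], None, None))
# the new normalized shape drops the fully masked positions: 4 x 4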
nni/compression/pytorch/speedup/compressor.py
The old import of the hand-written shape-inference rules (from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape, set_conv_prune_dim) is removed along with infer_shape.py itself; the new imports pull in the data-driven AutoMaskInference and the jit translation helpers.

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import queue
import logging
import copy
import torch
import torch.nn as nn
from nni.common.graph_utils import build_module_graph
from nni.compression.pytorch.utils.mask_conflict import fix_mask_conflict
from nni.compression.pytorch.utils.utils import get_module_by_name
from .compress_modules import replace_module
from .infer_mask import AutoMaskInference
from .jit_translate import jit_to_python_function
from ..utils import rand_like_with_shape

_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)


class ModelSpeedup:
    """
    This class is to speedup the model with provided weight mask.
    """
The constructor gains batch_dim and confidence parameters. The old infer_module_mask entry point is removed; sparsity is now propagated by update_direct_sparsity/update_indirect_sparsity below.

    def __init__(self, model, dummy_input, masks_file, map_location=None,
                 batch_dim=0, confidence=8):
        """
        Parameters
        ----------
        model : pytorch model
            The model user wants to speed up
        dummy_input : pytorch tensor, tuple of tensor, list of tensor
            Note: The first dimension of the dummy_input should be the batchsize.
            The dummy input for ```jit.trace```, users should put it on the right
            device.
        masks_file : str
            The path of user provided mask file
        map_location : str
            the device on which masks are placed, same to map_location in ```torch.load```
        batch_dim : int
            the index of batch dimension in the dummy_input
        confidence: the confidence coefficient of the sparsity inference. This value is
            actually used as the batchsize of the dummy_input.
        """
        assert confidence > 1
        # The auto inference will change the values of the parameters in the model,
        # so we need to make a copy before the mask inference
        self.ori_state_dict = copy.deepcopy(model.state_dict())
        self.bound_model = model
        self.inferred_masks = dict()  # key: module_name, value: ModuleMasks
        self.batch_dim = batch_dim
        self.dummy_input, self.device = self._random_model_input(dummy_input, confidence, batch_dim)
        self.torch_graph = build_module_graph(model, self.dummy_input)
        # dict object to save the auto inference objects of the submodules
        self.auto_inferences = {}
        # the index dict to find the corresponding torch._C.Value object
        # according to the debug name;
        # we need the dummy_input to infer the mask automatically, so we save
        # the indexes from the tensor's debugname to the torch._C.Value object
        self.debugname_to_value = {}
        # load the mask tensor to the same device as the dummy_input;
        # self.masks saves the mask tensors pruned by the user and the inferred
        # masks of the other modules
        self.masks = torch.load(
            masks_file, map_location if map_location is not None else str(self.device))
        self.constant = {}
        # self.internal_result saves the internal output of the submodules
        self.internal_result = {}
    def _random_model_input(self, dummy_input, confidence, batch_dim):
        """
        Get a new random dummy input according to the original dummy_input
        and confidence, batch_dim.

        Parameters
        ----------
        dummy_input: Tensor or list/dict of Tensors
            The dummy_input given by the user.
        confidence: int
            The new batch size of the generated dummy_input.
        batch_dim: int
            The index of the batch dimension.

        Returns
        ------
        new_dummy_input: Tensor or list/dict of Tensors
            The generated dummy_input for mask inference.
        device: torch.device
            The device of the generated dummy_inputs
        """
        input_errmsg = 'Only support the tensor, list/tuple/dict of tensors as input'
        # Some models may use a list of tensors as input, for example transformers
        new_dummy_input, device = None, None
        if isinstance(dummy_input, torch.Tensor):
            input_shape = list(dummy_input.size())
            # set the batchsize to the confidence ratio
            input_shape[batch_dim] = confidence
            new_dummy_input = rand_like_with_shape(input_shape, dummy_input)
            device = dummy_input.device
        elif isinstance(dummy_input, (tuple, list)):
            # else if the dummy input is a list/tuple
            new_dummy_input = []
            old_batchsize = dummy_input[0].size(0)
            device = dummy_input[0].device
            for _, t_input in enumerate(dummy_input):
                assert isinstance(t_input, torch.Tensor), input_errmsg
                assert t_input.size(0) == old_batchsize, 'The first dimension should be batchsize \
                    and the batchsize of all inputs should be the same!'
                input_shape = list(t_input.size())
                input_shape[batch_dim] = confidence
                # rand_func = torch.randint if t_input.dtype
                new_dummy_input.append(rand_like_with_shape(input_shape, t_input))
        elif isinstance(dummy_input, dict):
            new_dummy_input = {}
            tmp_key = list(dummy_input.keys())[0]
            old_batchsize = dummy_input[tmp_key].size(0)
            device = dummy_input[tmp_key].device
            for in_name, t_input in dummy_input.items():
                assert isinstance(t_input, torch.Tensor), input_errmsg
                assert old_batchsize == t_input.size(0), 'The first dimension should be batchsize \
                    and the batchsize of all inputs should be the same!'
                input_shape = list(t_input.size())
                input_shape[batch_dim] = confidence
                new_dummy_input[in_name] = rand_like_with_shape(input_shape, t_input)
        else:
            raise TypeError(input_errmsg)
        return new_dummy_input, device
    def _prepare_dummy_input(self, node):
        """
        Prepare the dummy_input for the auto mask inference.

        Parameters
        ----------
        node: NodePyGroup

        Returns
        -------
        dummy_input: list
            List of tensors that will be used as input for the target node.
        debugnames: list of strs
            Debugnames of the dummy_inputs.
        """
        _logger.debug('Prepare auto mask inference for node: %s', node.unique_name)
        # prepare the inputs and outputs mask for this node;
        # if there is already a mask in self.masks, then use
        # the original mask tensor, else create a new one
        inputs_name = node.inputs
        # build the dummy_input, in_masks for the target node
        dummy_input = []
        debugnames = []
        for _input in inputs_name:
            if _input not in self.internal_result:
                # If the input debug name is not in self.internal_result,
                # then this node isn't an output tensor of any predecessor
                # node. It is an attribute of the submodule, such as
                # weight or bias. We will skip these tensors.
                # If we don't want this specific judgement here, we can merge
                # the `prim::GetAttr` node of the weight/bias tensor into the key
                # node, such as `conv`.
                # This is caused by the `merge_module_node` function in
                # _graph_utils.py, because it doesn't merge the prim::GetAttr
                # node into the key node. In the current version of _graph_utils.py,
                # we only merge the nodes that have the same scope name; however,
                # the scope name of the corresponding prim::GetAttr node of the `weight`
                # tensor is None.
                continue
            # The detach operation here is for in-place operations. We cannot
            # directly call backward on the output tensor of an in-place operator.
            dummy_input.append(self.internal_result[_input].detach())
            debugnames.append(_input)
        return dummy_input, debugnames
The old infer_module_mask body, which dispatched to the hand-written shape-inference rules (infer_from_mask, infer_from_inshape, infer_from_outshape) and recursed into predecessors/successors, is removed; the new code infers sparsity by actually executing the submodule on masked dummy inputs.

    def update_direct_sparsity(self, node):
        """
        Update the direct sparsity for the target node. Here direct sparsity
        means the sparsity in the output tensor that is caused by the sparsity
        in the input tensors/weight tensors.
        """
        # this name is consistent with the name returned by named_modules()
        module_name = node.name
        _logger.info('Update mask for %s', module_name)
        unique_name = node.unique_name
        dummy_input, input_debugname = self._prepare_dummy_input(node)
        # get the input mask from self.masks;
        # note: the input masks of the successor nodes are
        # already created by the predecessor node
        in_masks = [self.masks[debugname] for debugname in input_debugname]
        in_constants = [self.constant[debugname] for debugname in input_debugname]
        if node.type == 'func':
            # We cannot get the runnable function directly from the jit traced
            # graph, so we translate it back to a python function. Note: the function
            # is applicable to both cpu/gpu devices; the output tensors will be on the
            # same device as the input tensors.
            func = jit_to_python_function(node, self)
            if func is None:
                # no need to infer the sparsity for this node
                self.auto_inferences[unique_name] = None
                return
            # functions don't have weights
            _auto_infer = AutoMaskInference(
                func, dummy_input, in_masks, in_constants=in_constants,
                batch_dim=self.batch_dim)
        else:
            weight_mask = None
            if module_name in self.masks:
                weight_mask = self.masks[module_name]
            _, module = get_module_by_name(self.bound_model, module_name)
            _auto_infer = AutoMaskInference(
                module, dummy_input, in_masks, weight_mask, in_constants=in_constants,
                state_dict=copy.deepcopy(module.state_dict()), batch_dim=self.batch_dim)
        self.auto_inferences[unique_name] = _auto_infer
        _auto_infer.name = node.unique_name
        _auto_infer.update_direct_sparsity()
        # also save the input debug names into the auto_infer
        _auto_infer.input_debugname = input_debugname
        # Update the mask tensor and the internal output of the submodules.
        # After manually unpacking the tuple/list of tensors, the number of outputs
        # of each node should always be one (except for the TupleUnpack node at the end
        # of the whole model).
        assert len(node.outputs) == 1, 'The number of the output should be one after the Tuple unpacked manually'
        out_debugname = node.outputs[0]
        # update the output mask into self.masks
        self.masks[out_debugname] = _auto_infer.output_mask
        self.constant[out_debugname] = _auto_infer.out_constant
        # update the output result into self.internal_result, so that
        # the successor nodes can take these output tensors as inputs
        self.internal_result[out_debugname] = _auto_infer.output
        # update the parameter mask of the node
        self.masks[module_name] = _auto_infer.weight_mask
    def update_indirect_sparsity(self, node):
        """
        This function will update the indirect sparsity. To explain what
        indirect sparsity is: suppose there are two tensors TA and TB, and
        we perform the calculation TC = TA x TB, where TC is also a tensor.
        Once some values in TA are masked to zeros, the corresponding
        positions in TB are also potential sparsities, because they have no
        effect on the final output (the gradient of these positions in TB equals
        0 all the time). This function is to find the potential sparsity caused
        by other sparsity (we call it indirect sparsity here). Basically we can find
        these potential sparsities through gradients.

        Parameters
        ---------
        node: the NodePy
            The target node to update the indirect sparsity
        """
        module_name = node.name
        _logger.info('Update indirect sparsity for %s', module_name)
        unique_name = node.unique_name
        if unique_name in self.auto_inferences and self.auto_inferences[unique_name] is not None:
            # if the auto inference object is already in self.auto_inferences, then
            # directly update the previous one
            _logger.info('Update the indirect sparsity for the %s', unique_name)
            auto_infer = self.auto_inferences[unique_name]
            auto_infer.update_indirect_sparsity()
            # pass the gradient to the predecessor nodes
            for in_id, tin in enumerate(auto_infer.dummy_input):
                debug_name = auto_infer.input_debugname[in_id]
                last_output = self.internal_result[debug_name]
                # TODO what if last output is a tuple/list of tensors
                if last_output.grad is not None and tin.grad is not None:
                    last_output.grad.data += tin.grad.data
                else:
                    last_output.grad = tin.grad
        else:
            _logger.warning('Note: %s does not have corresponding mask inference object', node.name)
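The gradient trick in the docstring can be illustrated standalone (toy tensors, not the NNI API): after masking part of TA, the gradient flowing back into the corresponding positions of TB stays zero, exposing them as prunable:

ta = torch.tensor([[0., 0.], [1., 1.]], requires_grad=True)   # first row masked
tb = torch.tensor([[2., 3.], [4., 5.]], requires_grad=True)
tc = ta * tb
tc.sum().backward()
print(tb.grad)   # tensor([[0., 0.], [1., 1.]]) -> zeros mark indirect sparsity in tb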
    def _vnode_to_value(self, c_node):
        """
        Translate the C Value node into values/tensors.
        """
        errmsg = "Only support the torch._C.Value type"
        assert isinstance(c_node, torch._C.Value), errmsg
        if isinstance(c_node.type(), torch._C.TensorType):
            shape = tuple(c_node.type().sizes())
            dtype = c_node.type().scalarType()
            # TODO should use a more general way to get the input
            if dtype.startswith('Float') or dtype.startswith('Double'):
                return torch.rand(shape).to(self.device)
            else:
                # This small range is due to `ReLU6`; we will add
                # manually specified mask inference rules for several
                # ops in the future, so that we can remove the constraint.
                return torch.randint(0, 10, shape, device=self.device)
        else:
            value = c_node.toIValue()
            # TODO support more kinds of value nodes
            errmsg = "Doesn't support convert %s to values" % str(c_node.type())
            # currently we only support tensors and constant values
            assert value is not None, errmsg
            return value
The old infer_modules_masks simply looped over the user-provided masks, called infer_module_mask, and skipped modules that jit.trace did not record; it is replaced by a topological two-phase pass over the traced graph.

    def infer_modules_masks(self):
        """
        Infer the mask for all layers in the module. This function can be divided into
        two steps: first, forward inference of the masks; second, backward inference
        of the masks. We keep repeating these two steps until the masks of the model
        don't change.
        """
        # unpack the tensor tuple/list before the mask inference
        self.torch_graph.unpack_manually()
        # find the input/output tensors of the whole graph
        graph_input = []
        graph_output = []
        for name, nodeio in self.torch_graph.nodes_py.nodes_io.items():
            if nodeio.input_or_output == 'input':
                graph_input.append((name, nodeio))
                # also put the graph input tensor into the internal_result
                # TODO if we can find the corresponding relation between the value node
                # and the dummy_inputs, we can use the input values in the dummy_input
                value = self._vnode_to_value(self.debugname_to_value[name])
                self.internal_result[name] = value
                # create the mask tensor for the input value
                if isinstance(self.internal_result[name], torch.Tensor):
                    self.masks[name] = torch.ones_like(value)
                    self.constant[name] = torch.zeros_like(value)
            elif nodeio.input_or_output == 'output':
                graph_output.append((name, nodeio))
        # count the degree for the nodes in the graph
        in_degree = {}
        out_degree = {}
        visit_queue = queue.Queue()
        for node in self.torch_graph.nodes_py.nodes_op:
            successors = self.torch_graph.find_successors(node.unique_name)
            out_degree[node.unique_name] = len(successors)
            predecessors = self.torch_graph.find_predecessors(node.unique_name)
            in_degree[node.unique_name] = len(predecessors)
            if in_degree[node.unique_name] == 0:
                visit_queue.put(node)
        # forward mask inference
        while not visit_queue.empty():
            curnode = visit_queue.get()
            # forward mask inference for curnode
            self.update_direct_sparsity(curnode)
            successors = self.torch_graph.find_successors(curnode.unique_name)
            for successor in successors:
                in_degree[successor] -= 1
                if in_degree[successor] == 0:
                    visit_queue.put(self.torch_graph.name_to_node[successor])
        # backward mask inference
        for unique_name in out_degree:
            if out_degree[unique_name] == 0:
                visit_queue.put(self.torch_graph.name_to_node[unique_name])
        while not visit_queue.empty():
            curnode = visit_queue.get()
            self.update_indirect_sparsity(curnode)
            predecessors = self.torch_graph.find_predecessors(curnode.unique_name)
            for predecessor in predecessors:
                out_degree[predecessor] -= 1
                if out_degree[predecessor] == 0:
                    visit_queue.put(self.torch_graph.name_to_node[predecessor])
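The traversal itself is plain Kahn-style topological ordering; a toy sketch with a hand-built DAG (node names and edges invented for illustration):

import queue

edges = {'conv1': ['relu1'], 'relu1': ['fc'], 'fc': []}   # toy graph
in_degree = {n: 0 for n in edges}
for n in edges:
    for succ in edges[n]:
        in_degree[succ] += 1
q = queue.Queue()
for n in edges:
    if in_degree[n] == 0:
        q.put(n)
while not q.empty():
    cur = q.get()
    print(cur)                      # conv1, relu1, fc: the direct-sparsity order
    for succ in edges[cur]:
        in_degree[succ] -= 1
        if in_degree[succ] == 0:
            q.put(succ)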
The old per-module replacement body, which looked up self.inferred_masks[module_name] and replaced each module inline, moves into replace_submodule below.

    def replace_compressed_modules(self):
        """
        ...
        NOTE: ```func``` type cannot be replaced as it is not a module, thus, one limitation
        is that ```func``` should be not required to be replaced.
        """
        with torch.no_grad():
            for unique_name in self.auto_inferences:
                self.replace_submodule(unique_name)
    def replace_submodule(self, unique_name, reindex_dim=None, reindex=None):
        """
        Replace the submodule according to the inferred sparsity.

        unique_name: str
            The unique_name of the submodule to replace.
        reindex_dim: int
            The dimension of the re-index operation.
        reindex: Reindex
            The index tensor. Normally this variable is None. If we want to reindex the
            output of this submodule, we can pass the index by this parameter.
        """
        class ReindexModule(nn.Module):
            """
            ReindexModule is used to resolve the mask conflict when replacing the submodule.
            Basically, we can use two ways to resolve the mask conflict: (1) unmask some
            values (will introduce more computation overhead), or (2) reindex and pad the output
            tensor of the target op (introduces more memory access overhead). Currently this
            method is shut down; in the future, we will merge these two methods into a graph
            pass which is used to resolve the mask conflict.
            """
            def __init__(self, ori_module, reindex_dim, reindex):
                super(ReindexModule, self).__init__()
                self.ori_module = ori_module
                self.reindex_dim = reindex_dim
                self.reindex = reindex
                tmp_index = [slice(None, None) for i in range(reindex_dim + 1)]
                # the index for the tensor
                tmp_index[reindex_dim] = reindex
                self.t_index = tuple(tmp_index)

            def forward(self, x):
                tmpout = self.ori_module(x)
                shape = list(tmpout.size())
                shape[self.reindex_dim] = self.reindex.size(0)
                out = torch.zeros(tuple(shape), device=tmpout.device,
                                  requires_grad=tmpout.requires_grad)
                out[self.t_index] = tmpout
                return out

        assert unique_name in self.auto_inferences
        g_node = self.torch_graph.name_to_node[unique_name]
        _logger.debug("replace %s, in %s type, with op_type %s",
                      unique_name, g_node.type, g_node.op_type)
        auto_infer = self.auto_inferences[unique_name]
        if g_node.type == 'module':
            if g_node.unique_name in self.torch_graph.reused_module:
                if reindex_dim is not None:
                    _logger.warning('Cannot replace a reused module with padding operator!!')
                    return None
            super_module, leaf_module = get_module_by_name(self.bound_model, g_node.name)
            m_type = g_node.op_type
            if not m_type in replace_module:
                raise RuntimeError("Has not supported replacing the module: `{}`".format(m_type))
            _logger.info("replace module (name: %s, op_type: %s)", g_node.name, m_type)
            compressed_module = replace_module[m_type](leaf_module, auto_infer.get_masks())
            new_submodule = compressed_module
            if reindex_dim is None:
                setattr(super_module, g_node.name.split('.')[-1], compressed_module)
            elif reindex_dim is not None and reindex is not None:
                # reindex the output of this submodule and replace the original module
                new_submodule = ReindexModule(compressed_module, reindex_dim, reindex)
                setattr(super_module, g_node.name.split('.')[-1], new_submodule)
            return new_submodule
        elif g_node.type == 'func':
            _logger.info("Warning: cannot replace (name: %s, op_type: %s) which is func type",
                         unique_name, g_node.op_type)
            return None
        else:
            raise RuntimeError("Unsupported node type: {}".format(g_node.type))
    def initialize_speedup(self):
        """
        Do some initial work for speedup.
        """
        # initialize self.debugname_to_value:
        # build a mapping table from the debug name of a tensor
        # to its value node in the graph
        traced_graph = self.torch_graph.trace.graph
        for node in traced_graph.nodes():
            for _input in node.inputs():
                debug_name = _input.debugName()
                if debug_name not in self.debugname_to_value:
                    self.debugname_to_value[debug_name] = _input
            for _output in node.outputs():
                debug_name = _output.debugName()
                if debug_name not in self.debugname_to_value:
                    self.debugname_to_value[debug_name] = _output
        # put the model itself into internal_result to perform the
        # value inference for 'prim::GetAttr'; the first ClassType
        # of the whole graph is the model class
        for graph_input in traced_graph.inputs():
            if graph_input.type().kind() == 'ClassType':
                self.internal_result[graph_input.debugName()] = self.bound_model
                break
    def speedup_model(self):
        """
        There are basically two steps: first, do mask/shape inference;
        second, replace modules.
        """
        _logger.info("start to speed up the model")
        self.initialize_speedup()
        training = self.bound_model.training
        # set to the evaluation mode
        self.bound_model.train(False)
        # TODO: suppose to fix the conflict after the sparsity propagation,
        # which is more elegant
        fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)

        _logger.info("infer module masks...")
        self.infer_modules_masks()
        _logger.info('resolve the mask conflict')

        # load the original state dict before replacing the modules
        self.bound_model.load_state_dict(self.ori_state_dict)
        _logger.info("replace compressed modules...")
        # the mask conflict should be already resolved
        self.replace_compressed_modules()
        self.bound_model.train(training)
        _logger.info("speedup done")
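A minimal usage sketch of the refactored speedup flow above, assuming a pruner has already exported a mask file; the model choice and the 'mask.pth' path are placeholders for illustration, not part of this commit.

import torch
from torchvision.models import resnet18
from nni.compression.pytorch import ModelSpeedup

model = resnet18()
dummy_input = torch.rand(8, 3, 224, 224)
# 'mask.pth' is assumed to have been exported beforehand by a pruner
ModelSpeedup(model, dummy_input, 'mask.pth').speedup_model()
out = model(dummy_input)   # the replaced modules now run with smaller shapes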
nni/compression/pytorch/speedup/infer_mask.py 0 → 100644
View file @ 7eedec46
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import torch
import torch.nn as nn

from ..utils import randomize_tensor, torch_float_dtype, torch_integer_dtype

_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)

STD_DELTA = 1e-6
class AutoMaskInference:
    def __init__(self, module, dummy_input, in_masks=None, weight_mask=None,
                 output_mask=None, name=None, in_constants=None, state_dict=None, batch_dim=0):
        """
        This class infers the mask of the target module automatically.
        update_direct_sparsity infers the output mask according to the
        input masks; in contrast, update_indirect_sparsity infers the
        input masks according to the given output masks. The newly
        found sparsity is incrementally updated into the original in_masks
        and output_mask.

        Parameters
        ----------
        module: torch.nn.Module/function
            The target module to infer the mask for. Needs to be callable.
        dummy_input: torch.Tensor/list of Tensors
            The dummy_input of the target module.
        in_masks: list of torch.Tensor
            The input masks of the target module. If in_masks is not None, then
            update_direct_sparsity and update_indirect_sparsity will incrementally
            update the given in_masks; else, AutoMaskInference will create new
            in_masks for the target module.
        output_mask: torch.Tensor
            The output mask of the target module. Similar to in_masks, if output_mask
            is not None, then update_direct_sparsity and update_indirect_sparsity will
            incrementally update the given output_mask; else AutoMaskInference will create
            one output_mask for the target module.
        weight_mask: dict of the weight masks
            The weight masks of the target module; the key is the corresponding name of
            the mask. For example: {'weight': torch.ones(1000, 1000), 'bias': torch.ones(1000)}
        name: str
            Name of the target module.
        in_constants: list of torch.Tensor
            The corresponding constant values of the in_masks.
        state_dict: dict of torch.Tensor
            The original values of the weights.
        batch_dim: int
            The index of the batch dimension of the input tensors.
        """
        errmsg = '%s is not callable, should pass the nn.Module/function' % str(module)
        assert callable(module), errmsg
        self.module = module
        # Initialize the dummy_input
        if isinstance(dummy_input, list):
            # if there are multiple input variables
            self.dummy_input = dummy_input
        else:
            # if there is only one input variable
            self.dummy_input = [dummy_input]
        # Initialize the masks for input tensors
        self.in_masks = in_masks if in_masks is not None else [
            None] * len(self.dummy_input)
        # note: iterate over self.dummy_input (the normalized list), not the raw argument
        self.in_constants = in_constants if in_constants is not None else [
            torch.zeros_like(x) for x in self.dummy_input]
        for in_id, _ in enumerate(self.in_masks):
            if self.in_masks[in_id] is None and \
                    isinstance(self.dummy_input[in_id], torch.Tensor):
                # if the input mask is None, create an all-ones mask for the corresponding input tensor
                self.in_masks[in_id] = torch.ones_like(self.dummy_input[in_id])
                # ones_like will put the created mask on the same device as the dummy_input
        # Initialize the mask for output tensors
        self.output = self.module(*self.dummy_input)
        # self.output.requires_grad_()
        if output_mask is not None:
            # assume the given output mask is right
            self.output_mask = output_mask
        else:
            if isinstance(self.output, torch.Tensor):
                self.output_mask = torch.ones_like(self.output)
            elif isinstance(self.output, list) or isinstance(self.output, tuple):
                self.output_mask = []
                for o_tensor in self.output:
                    if isinstance(o_tensor, torch.Tensor):
                        self.output_mask.append(torch.ones_like(o_tensor))
                    else:
                        # if one of the outputs is not a tensor, set the corresponding
                        # mask to None
                        self.output_mask.append(None)
            else:
                self.output_mask = None
        # Initialize the mask for the parameters
        self.weights = {}
        self.weight_mask = {}
        if weight_mask:
            self.weight_mask.update(weight_mask)
        if isinstance(self.module, nn.Module):
            # a function should not have parameters;
            # get all the parameter tensors of the target module
            # (use para_name here to avoid shadowing the `name` argument)
            for para_name, para in module.named_parameters():
                self.weights[para_name] = para
                if para_name not in self.weight_mask:
                    self.weight_mask[para_name] = torch.ones_like(para.data)
        self.name = name
        self.state_dict = state_dict
        # TODO support the other batch dimension in the future
        self.batch_dim = batch_dim
    def random_init(self, start=0.1, end=8.0):
        """
        Randomly initialize the weights of the module. The value of
        the tensor will not affect the mask auto inference.
        """
        # currently we set the random range to 0.1-8.0 because of ReLU6;
        # if we use a range that is far larger than 6, it may infer a wrong mask
        # when the confidence is low. In the future, we will add mask inference
        # rules for ReLU6 to break this range constraint.
        with torch.no_grad():
            for tensor in self.dummy_input:
                if isinstance(tensor, torch.Tensor) and len(tensor.size()) > 0:
                    # if the tensor is a scalar, then skip this tensor
                    randomize_tensor(tensor, start, end)
            for para in self.weights:
                randomize_tensor(self.weights[para].data, start, end)
    def zero_grad(self):
        """
        Set the gradients of the weights and input tensors to zero.
        """
        with torch.no_grad():
            # set the weights' gradients to zero
            if isinstance(self.module, nn.Module):
                self.module.zero_grad()
            # also zero the gradients of the input tensors
            for tensor in self.dummy_input:
                if isinstance(tensor, torch.Tensor):
                    if tensor.grad is not None:
                        tensor.grad.data.zero_()

    def requires_grad_(self, flag=True):
        """
        Set the requires_grad of the input tensors and parameters to flag.
        """
        for t_in in self.dummy_input:
            if isinstance(t_in, torch.Tensor) and t_in.dtype in torch_float_dtype:
                # only float types can require the gradient;
                # enable the auto gradient
                t_in.requires_grad_(flag)
        for para_name in self.weights:
            if self.weights[para_name].dtype in torch_float_dtype:
                self.weights[para_name].requires_grad_(flag)
    def apply_mask(self):
        self.__apply_input_mask()
        self.__apply_weight_mask()

    def __apply_input_mask(self):
        """
        Apply the masks of the input tensors.
        """
        with torch.no_grad():
            # apply the input mask
            for tid, in_tensor in enumerate(self.dummy_input):
                if isinstance(in_tensor, torch.Tensor) and self.in_masks[tid] is not None:
                    in_tensor.data = in_tensor.data * self.in_masks[tid] + \
                        (1 - self.in_masks[tid]) * self.in_constants[tid]

    def __apply_weight_mask(self):
        """
        Apply the weight mask of this module.
        """
        with torch.no_grad():
            # apply the weight mask
            for para in self.weights:
                if para in self.weight_mask:
                    self.weights[para].data *= self.weight_mask[para].data
    def isconstants(self, tout):
        """
        Find the constants in the tensor tout. This function returns a mask tensor that
        indicates whether each value in tout is a constant, and one more tensor that
        holds the values of those constants.

        Parameters
        ----------
        tout: torch.Tensor
            The target output tensor in which to find the constants.

        Returns
        -------
        mask: torch.Tensor
            The mask tensor (same shape as tout) that indicates whether
            the corresponding value is a constant.
        constant: torch.Tensor
            The tensor (same shape as tout) that holds the values of
            the constants in tout.
        """
        assert isinstance(tout, torch.Tensor)
        out_mask = torch.ones_like(tout)
        constant = torch.zeros_like(tout)
        # judge whether tout is a scalar (a tensor that only has one value)
        if len(tout.size()) == 0:
            # tout is a scalar tensor; we take this scalar as a constant.
            # Usually, the scalar tensor is returned by the size() function.
            constant = tout
            return out_mask, constant
        if tout.dtype in torch_integer_dtype:
            # Pytorch cannot use torch.mean and torch.std to process
            # integers :(, so if the dtype of the input tensor is integer,
            # we need to check whether it is constant by ourselves.
            # Note: the first dimension should be the batch dimension
            same = tout[:] == tout[0]
            reduced = torch.sum(same, dim=0)
            is_constant = reduced == tout.size(0)
            out_mask[:, is_constant] = 0
            constant[:, is_constant] = tout[0][is_constant]
        else:
            # calculate the std of the output along the batch dimension
            std = torch.std(tout, dim=0)
            # calculate the mean value of the output along the batch dimension
            mean = torch.mean(tout, dim=0)
            mask_pos = std < STD_DELTA
            out_mask[:, mask_pos] = 0
            constant[:, mask_pos] = mean[mask_pos]
        return out_mask, constant
    def update_indirect_sparsity(self):
        """
        This function updates the indirect sparsity. To explain what
        indirect sparsity is: for example, there are two tensors TA and TB, and
        we perform the calculation TC = TA x TB, in which TC is also a tensor.
        Once some values in TA are masked to zeros, the corresponding
        positions in TB are also potential sparsity, because they have no
        effect on the final output (the gradient of these positions in TB equals
        0 all the time). This function is to find the potential sparsity caused
        by other sparsity (we call it indirect sparsity here). Basically we can
        find this potential sparsity through the gradient.
        """
        # Each node only updates the output mask when we backward-
        # update the output mask. This is because some ops may
        # involve a broadcast operation; for example, OP A's output
        # tensor may be taken by two OPs (B, C) as input. So we cannot
        # directly update the input mask at OP B or C. We can only
        # update the mask of A's output tensor when B and C are
        # already updated (gradients are already calculated and added to
        # A's output tensor).
        # Besides, updating the mask of A's output tensor equals updating
        # the input masks of OPs B and C.
        if isinstance(self.output, torch.Tensor) and self.output.grad is not None:
            # if the output has a gradient, this node has successor
            # nodes and the successor nodes have already updated their
            # indirect sparsity; we can mask the values whose gradient
            # is always zero
            gradient_sum = torch.sum(torch.abs(self.output.grad.data), dim=0)
            _grad_zero = gradient_sum == 0
            for batchid in range(self.output.size(0)):
                # set the same mask value for the whole batch
                self.output_mask[batchid][_grad_zero] = 0
        elif isinstance(self.output, tuple) or isinstance(self.output, list):
            assert isinstance(self.output_mask, (tuple, list))
            for oid, tout in enumerate(self.output):
                errmsg = 'The output only supports tensor/list of tensors'
                assert isinstance(tout, torch.Tensor), errmsg
                # use the gradient of this output tensor
                gradient_sum = torch.sum(torch.abs(tout.grad.data), dim=0)
                _grad_zero = gradient_sum == 0
                for batchid in range(tout.size(0)):
                    # set the same mask value for the whole batch
                    self.output_mask[oid][batchid][_grad_zero] = 0
        self.requires_grad_(True)
        # Forward inference with auto gradient enabled
        # Note: tensors that need gradient cannot be used in in-place operators
        self.random_init()
        self.apply_mask()
        # Some operators may have in-place operations, so we need to clone the input
        # before passing it to self.module
        tmp_dummy_input = [x.clone() if isinstance(x, torch.Tensor) else x
                           for x in self.dummy_input]
        output = self.module(*tmp_dummy_input)
        if output.grad_fn is None:
            # the output does not have a gradient function
            return
        # Note: output may be a tensor or a list/tuple of tensors
        if isinstance(output, torch.Tensor):
            output.backward(self.output_mask)
        elif isinstance(output, list) or isinstance(output, tuple):
            for tid, t_out in enumerate(output):
                t_out.backward(self.output_mask[tid])
        # update the sparsity of the parameters
        for para_name in self.weights:
            grad_zero = self.weights[para_name].grad.data == 0
            self.weight_mask[para_name][grad_zero] = 0
    def update_direct_sparsity(self):
        # we don't need the gradient in the forward inference
        out_mask = None
        constant = None
        with torch.no_grad():
            # Note: we need to randomly init the input one more time here!
            # Because some operations have in-place variants, such as relu_,
            # the in-place operation may modify or write 0s into the dummy_input
            self.random_init()
            # apply the mask for the input tensor and the weight tensor
            self.apply_mask()
            # Note: due to in-place operators, such as relu_,
            # ori_out may be the same tensor as dummy_input,
            # so we use clone and detach to create a new tensor with
            # the same values.
            out = self.module(*self.dummy_input)
            if isinstance(out, torch.Tensor):
                out_mask, constant = self.isconstants(out.clone().detach())
            elif isinstance(out, tuple) or isinstance(out, list):
                out_mask = []
                constant = []
                for tout in out:
                    _mask, _constant = self.isconstants(tout.clone().detach())
                    out_mask.append(_mask)
                    constant.append(_constant)
            else:
                _logger.warning(
                    'Only support the OP whose output is tensor/tuple of tensors/list of tensors')
            # We also need to randomize the parameters of the module, because if the
            # weights of the model have an unmasked 0, our output sparsity inference
            # may be wrong. However, after randomizing the weights/parameters, the
            # constants in the output tensors may differ from the constants calculated
            # from the original state_dict. So, to get the right constants and eliminate
            # the bias between the model before and after sparsity inference, we need to
            # reload the state_dict and recalculate the constants.
            # Currently we also get the constant values at the same time as inferring the
            # mask; in the future, we will separate the constant inference into a single
            # graph pass.
            if len(self.weights) > 0 and self.state_dict is not None:
                self.module.load_state_dict(self.state_dict)
                # apply weight mask
                self.__apply_weight_mask()
                out = self.module(*self.dummy_input).clone().detach()
                if isinstance(out, torch.Tensor):
                    constant = torch.zeros_like(out)
                    constant_pos = out_mask == 0
                    constant[constant_pos] = out[constant_pos]
                elif isinstance(out, (list, tuple)):
                    constant = []
                    for i, tout in enumerate(out):
                        _tmp = torch.zeros_like(tout)
                        sparsity_pos = out_mask[i] == 0
                        _tmp[sparsity_pos] = tout[sparsity_pos]
                        constant.append(_tmp)
            if isinstance(out_mask, torch.Tensor):
                assert isinstance(self.output_mask, torch.Tensor)
                self.output_mask *= out_mask
            elif isinstance(out_mask, list):
                for i, _ in enumerate(out_mask):
                    self.output_mask[i] *= out_mask[i]
            else:
                _logger.warning('There is no output sparsity')
            # also save the out_constant
            self.out_constant = constant
    def get_masks(self):
        return (self.in_masks, self.output_mask, self.weight_mask)
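To make the direct/indirect distinction above concrete, here is a small self-contained sketch (not part of the commit; the shapes and values are invented for illustration). Direct sparsity is read from the statistics of the output over the batch dimension; indirect sparsity is read from zero gradients, exactly the signal update_indirect_sparsity uses.

import torch

# Direct sparsity: a position that is constant across the batch has zero std.
batch_out = torch.randn(4, 3)
batch_out[:, 1] = 5.0                       # position 1 is constant across the batch
print(torch.std(batch_out, dim=0) < 1e-6)   # tensor([False,  True, False])

# Indirect sparsity: TC = TA x TB; masking TA's first value zeroes the
# gradient of the TB entries it multiplies, marking them prunable.
ta = torch.tensor([[0.0, 2.0]])                        # first value masked to 0
tb = torch.tensor([[3.0], [4.0]], requires_grad=True)
tc = ta.matmul(tb)
tc.backward(torch.ones_like(tc))
print(tb.grad)                                         # tensor([[0.], [2.]])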
nni/compression/pytorch/speedup/infer_shape.py deleted 100644 → 0
View file @ 5b99b598
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
For each operation or module, there are two functions.
One is given the output shape, and infers its input shape and initialization parameters (e.g., the weight's shape);
the other is given the input shape, and infers its output shape and initialization parameters (e.g., the weight's shape).
"""

import logging
import torch

_logger = logging.getLogger(__name__)

conv_prune_dim = -1


def set_conv_prune_dim(dim):
    """
    Parameters:
        dim: int
            0: filter pruning
            1: channel pruning
    """
    global conv_prune_dim
    conv_prune_dim = dim
class CoarseMask:
    """
    Coarse-grained mask for a given tensor; here the tensor could be the weights,
    the input tensor, or the output tensor
    """

    def __init__(self, num_dim):
        """
        Parameters
        ----------
        num_dim : int
            The number of dimensions of the tensor that will be masked
        """
        self.mask_index = [None for _ in range(num_dim)]

    def add_index_mask(self, dim, index):
        """
        Add a mask for the specified dimension

        Parameters
        ----------
        dim : int
            The dimension to add the mask on
        index : tensor
            The mask for this dimension; it is a 1-dimensional tensor which specifies
            the indices of the elements that are not pruned
        """
        self.mask_index[dim] = index
    @staticmethod
    def merge_index(index_a, index_b):
        """
        Parameters
        ----------
        index_a : tensor
            One index (1-dimensional) tensor
        index_b : tensor
            The other index (1-dimensional) tensor

        Returns
        -------
        tensor
            The merged index (1-dimensional) tensor.
            Note that the output tensor will be moved
            to the same device as index_a.
        """
        device = index_a.device
        s = set()
        for num in index_a.tolist():
            # We need to convert the tensor to a list here.
            # Directly traversing the tensor with a for loop
            # returns tensor(x) objects; even when the values
            # are the same, they are different tensor objects,
            # so the set would contain multiple tensor objects
            # that have the same value. For example:
            # for num in torch.ones(2):
            #     s.add(num)
            # s would be {tensor(1), tensor(1)}
            s.add(num)
        for num in index_b.tolist():
            s.add(num)
        # move the output tensor to the same device as index_a
        return torch.tensor(sorted(s)).to(device)  # pylint: disable=not-callable
    def merge(self, cmask):
        """
        Merge another CoarseMask

        Parameters
        ----------
        cmask : CoarseMask
            Another CoarseMask to merge

        Returns
        -------
        list
            The member variable ```mask_index```
        """
        assert isinstance(cmask, CoarseMask)
        assert len(self.mask_index) == len(cmask.mask_index), \
            "Only masks with the same number of dimensions can be merged"
        for i, index in enumerate(self.mask_index):
            if index is None:
                self.mask_index[i] = cmask.mask_index[i]
            elif cmask.mask_index[i] is not None:
                self.mask_index[i] = CoarseMask.merge_index(self.mask_index[i],
                                                            cmask.mask_index[i])
        return self.mask_index

    def __repr__(self):
        return 'mask_index: {}'.format(self.mask_index)
    def eq_on_dim(self, other, dim):
        assert isinstance(other, CoarseMask)
        if self.mask_index[dim] is None and other.mask_index[dim] is None:
            return True
        elif isinstance(self.mask_index[dim], torch.Tensor) \
                and isinstance(other.mask_index[dim], torch.Tensor):
            return torch.equal(self.mask_index[dim], other.mask_index[dim])
        else:
            return False

    def __eq__(self, other):
        assert isinstance(other, CoarseMask)
        if len(self.mask_index) != len(other.mask_index):
            return False
        for i in range(len(self.mask_index)):
            if not self.eq_on_dim(other, i):
                return False
        return True
    def __lt__(self, other):
        """
        Judge whether this mask is a subset of another CoarseMask.
        """
        assert isinstance(other, CoarseMask)
        for dim, _ in enumerate(self.mask_index):
            # if self has more dimensions
            if dim >= len(other.mask_index):
                return False
            if self.mask_index[dim] is None:
                # if there is no mask on this dimension, then we mask
                # fewer values than the other CoarseMask
                continue
            elif other.mask_index[dim] is None:
                return False
            else:
                s1 = set(self.mask_index[dim].tolist())
                s2 = set(other.mask_index[dim].tolist())
                if not s1 < s2:
                    return False
        return True

    def __le__(self, other):
        """
        Return whether self's mask is less than or equal to other's mask.
        """
        assert isinstance(other, CoarseMask)
        if self.__lt__(other) or self.__eq__(other):
            return True
        return False

    def __ne__(self, other):
        return not self.__eq__(other)
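A quick illustrative session (not from the source) showing how CoarseMask stores kept indices and how merging takes their union:

import torch

a = CoarseMask(num_dim=4)
a.add_index_mask(dim=1, index=torch.tensor([0, 2]))
b = CoarseMask(num_dim=4)
b.add_index_mask(dim=1, index=torch.tensor([1, 2]))
a.merge(b)
print(a)          # mask_index: [None, tensor([0, 1, 2]), None, None]
print(b <= a)     # True: b's kept set on every dimension is a subset of a's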
class ModuleMasks:
    """
    The masks of a module, including the masks for weights, inputs, output
    """

    def __init__(self, module_name, module=None):
        """
        Parameters
        ----------
        module_name : str
            The name of the module or function
        """
        self.module_name = module_name
        self.module = module
        self.param_masks = dict()
        self.input_mask = None
        self.output_mask = None

    def set_param_masks(self, name, mask):
        """
        Parameters
        ----------
        name : str
            The name of the weight
        mask : CoarseMask
            The mask for this weight
        """
        self.param_masks[name] = mask

    def set_input_mask(self, mask):
        """
        Parameters
        ----------
        mask : CoarseMask
            The mask for input
        """
        self.input_mask = mask

    def set_output_mask(self, mask):
        """
        Parameters
        ----------
        mask : CoarseMask
            The mask for output
        """
        self.output_mask = mask

    def __repr__(self):
        return 'module_name: {}, input_mask: {}, output_mask: {}, param_masks: {}'.format(
            self.module_name, self.input_mask, self.output_mask, self.param_masks)
"""
Infer input and output shape of a module/function from its weight mask
"""
infer_from_mask
=
{
'BatchNorm2d'
:
lambda
module_masks
,
mask
:
batchnorm2d_mask
(
module_masks
,
mask
),
'Conv2d'
:
lambda
module_masks
,
mask
:
conv2d_mask
(
module_masks
,
mask
),
'ConvTranspose2d'
:
lambda
module_masks
,
mask
:
convtranspose2d_mask
(
module_masks
,
mask
),
'Linear'
:
lambda
module_masks
,
mask
,
shape
:
linear_mask
(
module_masks
,
mask
,
shape
)
}
"""
Infer output and weight shape of a module/function from its input shape
"""
infer_from_inshape
=
{
'ReLU'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'ReLU6'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'PReLU'
:
lambda
module_masks
,
mask
:
prelu_inshape
(
module_masks
,
mask
),
'Sigmoid'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::relu'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::tanh'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::tanh_'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::hardtanh'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::hardtanh_'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::relu_'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::sigmoid'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'Conv2d'
:
lambda
module_masks
,
mask
:
conv2d_inshape
(
module_masks
,
mask
),
'ConvTranspose2d'
:
lambda
module_masks
,
mask
:
convtranspose2d_inshape
(
module_masks
,
mask
),
'MaxPool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_inshape
(
module_masks
,
mask
),
'aten::max_pool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_inshape
(
module_masks
,
mask
),
'aten::avg_pool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_inshape
(
module_masks
,
mask
),
'aten::adaptive_avg_pool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_inshape
(
module_masks
,
mask
),
'AvgPool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_inshape
(
module_masks
,
mask
),
'AdaptiveAvgPool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_inshape
(
module_masks
,
mask
),
'aten::size'
:
lambda
module_masks
,
mask
:
size_inshape
(
module_masks
,
mask
),
'aten::view'
:
lambda
module_masks
,
mask
,
shape
:
view_inshape
(
module_masks
,
mask
,
shape
),
'aten::reshape'
:
lambda
module_masks
,
mask
,
shape
:
view_inshape
(
module_masks
,
mask
,
shape
),
# support only start_dim=1
'aten::flatten'
:
lambda
module_masks
,
mask
,
shape
:
view_inshape
(
module_masks
,
mask
,
shape
),
'Linear'
:
lambda
module_masks
,
mask
:
linear_inshape
(
module_masks
,
mask
),
'BatchNorm2d'
:
lambda
module_masks
,
mask
:
batchnorm2d_inshape
(
module_masks
,
mask
),
'aten::add_'
:
lambda
module_masks
,
mask
:
add_inshape
(
module_masks
,
mask
),
'aten::add'
:
lambda
module_mask
,
mask
:
add_inshape
(
module_mask
,
mask
),
# mul has the similar behaviour with add, they both request
# the input tesors to have the same shape
'aten::mul'
:
lambda
module_mask
,
mask
:
add_inshape
(
module_mask
,
mask
),
'aten::mul_'
:
lambda
module_mask
,
mask
:
add_inshape
(
module_mask
,
mask
),
'aten::cat'
:
lambda
module_mask
,
mask
,
cat_info
,
last_visited
:
cat_inshape
(
module_mask
,
mask
,
cat_info
,
last_visited
),
'aten::mean'
:
lambda
module_masks
,
mask
,
shape
:
mean_inshape
(
module_masks
,
mask
,
shape
),
'Dropout'
:
lambda
module_masks
,
mask
:
dropout_inshape
(
module_masks
,
mask
),
'Dropout2d'
:
lambda
module_masks
,
mask
:
dropout_inshape
(
module_masks
,
mask
),
'aten::dropout'
:
lambda
module_masks
,
mask
:
dropout_inshape
(
module_masks
,
mask
),
'aten::detach'
:
lambda
module_masks
,
mask
:
dropout_inshape
(
module_masks
,
mask
)
}
"""
Infer input and weight shape of a module/function from its output shape
"""
infer_from_outshape
=
{
'Conv2d'
:
lambda
module_masks
,
mask
:
conv2d_outshape
(
module_masks
,
mask
),
'ConvTranspose2d'
:
lambda
module_masks
,
mask
:
convtranspose2d_outshape
(
module_masks
,
mask
),
'BatchNorm2d'
:
lambda
module_masks
,
mask
:
batchnorm2d_outshape
(
module_masks
,
mask
),
'MaxPool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_outshape
(
module_masks
,
mask
),
'aten::max_pool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_outshape
(
module_masks
,
mask
),
'aten::avg_pool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_outshape
(
module_masks
,
mask
),
'aten::adaptive_avg_pool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_outshape
(
module_masks
,
mask
),
'AvgPool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_outshape
(
module_masks
,
mask
),
'AdaptiveAvgPool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_outshape
(
module_masks
,
mask
),
'ReLU'
:
lambda
module_masks
,
mask
:
relu_outshape
(
module_masks
,
mask
),
'PReLU'
:
lambda
module_masks
,
mask
:
prelu_outshape
(
module_masks
,
mask
),
'ReLU6'
:
lambda
module_masks
,
mask
:
relu_outshape
(
module_masks
,
mask
),
'aten::relu'
:
lambda
module_masks
,
mask
:
relu_outshape
(
module_masks
,
mask
),
'aten::tanh'
:
lambda
module_masks
,
mask
:
relu_outshape
(
module_masks
,
mask
),
'aten::tanh_'
:
lambda
module_masks
,
mask
:
relu_outshape
(
module_masks
,
mask
),
'aten::hardtanh'
:
lambda
module_masks
,
mask
:
relu_outshape
(
module_masks
,
mask
),
'aten::hardtanh_'
:
lambda
module_masks
,
mask
:
relu_outshape
(
module_masks
,
mask
),
'aten::relu_'
:
lambda
module_masks
,
mask
:
relu_outshape
(
module_masks
,
mask
),
'aten::add_'
:
lambda
module_masks
,
mask
:
add_outshape
(
module_masks
,
mask
),
'aten::add'
:
lambda
module_mask
,
mask
:
add_outshape
(
module_mask
,
mask
),
'aten::flatten'
:
lambda
module_mask
,
mask
,
shape
:
view_outshape
(
module_mask
,
mask
,
shape
),
'aten::view'
:
lambda
module_masks
,
mask
,
shape
:
view_outshape
(
module_masks
,
mask
,
shape
),
'aten::reshape'
:
lambda
module_masks
,
mask
,
shape
:
view_outshape
(
module_masks
,
mask
,
shape
),
'aten::mean'
:
lambda
module_masks
,
mask
,
shape
:
mean_outshape
(
module_masks
,
mask
,
shape
),
'Dropout'
:
lambda
module_masks
,
mask
:
dropout_outshape
(
module_masks
,
mask
),
'Dropout2d'
:
lambda
module_masks
,
mask
:
dropout_outshape
(
module_masks
,
mask
),
'aten::dropout'
:
lambda
module_masks
,
mask
:
dropout_outshape
(
module_masks
,
mask
),
'aten::detach'
:
lambda
module_masks
,
mask
:
dropout_outshape
(
module_masks
,
mask
)
}
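An illustrative call through one of these dispatch tables (the module name and mask values below are assumptions for the example, not from the commit; the handler functions are defined further down in this file):

import torch

mm = ModuleMasks('features.bn1')
bn_mask = {'weight': torch.tensor([1., 0., 1.]), 'bias': torch.tensor([1., 0., 1.])}
in_cmask, out_cmask = infer_from_mask['BatchNorm2d'](mm, bn_mask)
print(out_cmask.mask_index[1])   # tensor([0, 2]): channel 1 is pruned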
def dropout_inshape(module_masks, mask):
    if module_masks.input_mask is None:
        module_masks.set_input_mask(mask)
        module_masks.set_output_mask(mask)
        return module_masks.output_mask
    # if already visited
    assert module_masks.input_mask <= mask
    # It should be the same. We pass the masks by reference (not by value),
    # so they actually are two references to the same object (mask,
    # module_masks.input_mask). So we should continue passing the mask
    # to the following nodes even when module_masks.input_mask == mask.
    # If we passed the mask by copy.deepcopy(), then we could stop when
    # module_masks.input_mask == mask.
    # if module_masks.input_mask == mask:
    #     return None
    module_masks.set_input_mask(mask)
    module_masks.set_output_mask(mask)
    return module_masks.output_mask


def dropout_outshape(module_masks, mask):
    if module_masks.output_mask is None:
        module_masks.set_output_mask(mask)
        module_masks.set_input_mask(mask)
        return module_masks.input_mask
    # if already visited
    assert all(module_masks.output_mask.mask_index[1] == mask.mask_index[1])
    return module_masks.output_mask
def cat_inshape(module_masks, mask, cat_info, last_visited):
    """
    Infer the output mask of the cat operation from the
    input mask.

    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the Conv2d
    mask : CoarseMask
        The mask of its input tensor
    cat_info: dict
        Dict object that records the necessary information
        of the cat operation, such as the order of the input
        tensors.
    last_visited: str
        The unique_name of the last visited node group.

    Returns
    -------
    CoarseMask
        The mask of its output tensor
    """
    assert isinstance(mask, CoarseMask)
    out_shape = cat_info['out_shape']
    cat_dim = cat_info['cat_dim']
    in_order = cat_info['in_order']
    in_shape = cat_info['in_shape']
    if module_masks.output_mask is None:
        # First visit to this cat node:
        # initialize the mask based on
        # the number of output channels.
        output_mask = CoarseMask(num_dim=len(out_shape))
        for dim, _ in enumerate(out_shape):
            if dim == cat_dim:
                if mask.mask_index[dim] is None:
                    continue
                device = mask.mask_index[dim].device
                # calculate the offset of the mask
                pos = in_order.index(last_visited)
                offsets = [in_shape[i][cat_dim] for i, _ in enumerate(in_shape)]
                offset = 0
                for i in range(pos):
                    offset += offsets[i]
                _tmp_mask = (mask.mask_index[dim] + offset).to(device)
                output_mask.mask_index[dim] = _tmp_mask
            else:
                # directly copy the mask
                if mask.mask_index[dim] is not None:
                    output_mask.mask_index[dim] = mask.mask_index[dim].data.clone()
        module_masks.set_output_mask(output_mask)
        return module_masks.output_mask
    # If this cat node has already been visited, we need to
    # validate that the mask is legal: for the cat operation,
    # the masks on the 'cat_dim' dimension should be stitched
    # together. On the other dimensions, the masks should be
    # the same, else the mask is not legal.
    for dim, _ in enumerate(out_shape):
        if dim == cat_dim:
            if mask.mask_index[dim] is None:
                continue
            pos = in_order.index(last_visited)
            offsets = [in_shape[i][cat_dim] for i, _ in enumerate(in_shape)]
            offset = 0
            for i in range(pos):
                offset += offsets[i]
            device = mask.mask_index[dim].device
            new_mask = mask.mask_index[dim] + offset
            module_masks.output_mask.mask_index[dim] = CoarseMask.merge_index(
                module_masks.output_mask.mask_index[dim], new_mask).to(device)
        else:
            assert module_masks.output_mask.eq_on_dim(mask, dim)
    return module_masks.output_mask
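A tiny numeric check of the offset logic above (values invented for illustration): kept channel indices of the second cat input are shifted by the channel count of the inputs that precede it.

import torch

first_in_channels = 4                # channels contributed by the first cat input
kept = torch.tensor([0, 2])          # kept channels of the second input
print(kept + first_in_channels)      # tensor([4, 6]): their indices in the cat output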
def add_inshape(module_masks, mask):
    """
    Infer the output mask of the add operation from the
    input mask.
    """
    assert isinstance(mask, CoarseMask)
    if module_masks.input_mask is None:
        module_masks.set_input_mask(mask)
        module_masks.set_output_mask(mask)
        # module_masks.input_mask = mask
        return mask
    # If already visited, validate whether there is a conflict:
    # if the mask is different from the previous input_mask,
    # then there is a mask conflict.
    if mask != module_masks.input_mask:
        raise Exception('Mask conflict happens!')
    return None


def add_outshape(module_masks, mask):
    """
    Infer the input mask of the add operation from the
    output mask.
    """
    assert isinstance(mask, CoarseMask)
    if module_masks.output_mask is None:
        module_masks.set_output_mask(mask)
        module_masks.set_input_mask(mask)
        return mask
    else:
        assert all(module_masks.output_mask.mask_index[1] == mask.mask_index[1])
        return mask
def batchnorm2d_inshape(module_masks, mask):
    """
    We assume only the second dimension has a coarse-grained mask

    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the batchnorm2d
    mask : CoarseMask
        The mask of its input tensor

    Returns
    -------
    CoarseMask
        The mask of its output tensor
    """
    assert isinstance(mask, CoarseMask)
    assert mask.mask_index[1] is not None
    assert mask.mask_index[0] is None
    assert mask.mask_index[2] is None
    assert mask.mask_index[3] is None
    module_masks.set_input_mask(mask)
    module_masks.set_output_mask(mask)
    weight_cmask = CoarseMask(num_dim=1)
    weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
    module_masks.set_param_masks('weight', weight_cmask)
    module_masks.set_param_masks('bias', weight_cmask)
    return mask


def batchnorm2d_outshape(module_masks, mask):
    """
    We assume only the second dimension has a coarse-grained mask

    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the batchnorm2d
    mask : CoarseMask
        The mask of its output tensor

    Returns
    -------
    CoarseMask
        The mask of its input tensor
    """
    assert isinstance(mask, CoarseMask)
    assert len(mask.mask_index) in [2, 4]
    assert mask.mask_index[1] is not None
    assert mask.mask_index[0] is None
    module_masks.set_input_mask(mask)
    module_masks.set_output_mask(mask)
    weight_cmask = CoarseMask(num_dim=1)
    weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
    module_masks.set_param_masks('weight', weight_cmask)
    module_masks.set_param_masks('bias', weight_cmask)
    return mask
def linear_inshape(module_masks, mask):
    """
    A coarse-grained input mask does not change the shape of the weights or the output tensor

    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the linear
    mask : CoarseMask
        The mask of its input tensor

    Returns
    -------
    CoarseMask
        The mask of its output tensor; ```None``` means the shape of the output tensor is not changed
    """
    assert isinstance(mask, CoarseMask)
    assert mask.mask_index[0] is None
    if module_masks.input_mask is not None:
        assert module_masks.input_mask <= mask
    module_masks.set_input_mask(mask)
    return None


def view_inshape(module_masks, mask, shape):
    """
    This is a limited support.

    TODO: consider replacing tensor.view with nn.Flatten, because tensor.view is not
    a module and thus cannot be replaced by our framework.

    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the ```view``` op
    mask : CoarseMask
        The mask of its input tensor
    shape : dict
        Original shape of its input and output tensors

    Returns
    -------
    CoarseMask
        The mask of its output tensor
    """
    # NOTE: the case constrained by the following four asserts
    assert shape['in_shape'][0] == shape['out_shape'][0]
    assert len(shape['in_shape']) == 4
    assert len(shape['out_shape']) == 2
    assert shape['out_shape'][1] == shape['in_shape'][1] * \
        shape['in_shape'][2] * shape['in_shape'][3]

    assert isinstance(mask, CoarseMask)
    assert mask.mask_index[1] is not None
    assert mask.mask_index[0] is None
    assert mask.mask_index[2] is None
    assert mask.mask_index[3] is None
    # due to the cat operation, the same node may be
    # accessed more than once
    if module_masks.input_mask is not None:
        assert module_masks.input_mask <= mask
    module_masks.set_input_mask(mask)
    output_cmask = CoarseMask(num_dim=2)
    index = []
    step_size = shape['in_shape'][2] * shape['in_shape'][3]
    for loc in mask.mask_index[1]:
        index.extend([loc * step_size + i for i in range(step_size)])
    output_cmask.add_index_mask(
        dim=1, index=torch.tensor(index).to(mask.mask_index[1].device))  # pylint: disable=not-callable
    module_masks.set_output_mask(output_cmask)
    return output_cmask
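A worked example (illustrative, not from the source) of the index expansion above: a kept channel c of an N x C x H x W tensor maps to positions [c*H*W, (c+1)*H*W) after flattening to N x (C*H*W).

import torch

kept_channels = torch.tensor([1, 3])
step_size = 2 * 2                                    # H * W
index = [c * step_size + i for c in kept_channels.tolist() for i in range(step_size)]
print(index)    # [4, 5, 6, 7, 12, 13, 14, 15]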
def view_outshape(module_masks, mask, shape):
    """
    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the ```view``` op
    mask : CoarseMask
        The mask of its output tensor
    shape : dict
        Original shape of its input and output tensors

    Returns
    -------
    CoarseMask
        The mask of its input tensor
    """
    # NOTE: the case constrained by the following four asserts
    assert shape['in_shape'][0] == shape['out_shape'][0]
    assert len(shape['in_shape']) == 4
    assert len(shape['out_shape']) == 2
    assert shape['out_shape'][1] == shape['in_shape'][1] * \
        shape['in_shape'][2] * shape['in_shape'][3]

    assert isinstance(mask, CoarseMask)
    assert mask.mask_index[1] is not None
    assert mask.mask_index[0] is None

    module_masks.set_output_mask(mask)
    input_cmask = CoarseMask(num_dim=4)
    index = set()
    step_size = shape['in_shape'][2] * shape['in_shape'][3]
    for loc in mask.mask_index[1]:
        index.add(loc // step_size)
    index = sorted(list(index))
    input_cmask.add_index_mask(
        dim=1, index=torch.tensor(index).to(mask.mask_index[1].device))  # pylint: disable=not-callable
    module_masks.set_input_mask(input_cmask)
    return input_cmask


def size_inshape(module_masks, mask):
    """
    No need to do anything for this ```size``` op
    """
    return None
def mean_inshape(module_masks, mask, shape):
    """
    Similar to the view operation; currently mask inference only supports
    the mean operation on the 3rd and 4th dimensions.
    """
    assert shape['in_shape'][0] == shape['out_shape'][0]
    assert shape['out_shape'][1] == shape['in_shape'][1]
    assert len(shape['in_shape']) == 4
    assert len(shape['out_shape']) == 2

    assert isinstance(mask, CoarseMask)
    assert mask.mask_index[1] is not None
    assert mask.mask_index[0] is None
    assert mask.mask_index[2] is None
    assert mask.mask_index[3] is None
    module_masks.set_input_mask(mask)

    output_cmask = CoarseMask(num_dim=2)
    output_cmask.add_index_mask(dim=1, index=mask.mask_index[1])
    module_masks.set_output_mask(output_cmask)
    return output_cmask


def mean_outshape(module_masks, mask, shape):
    """
    Similar to the view operation; currently mask inference only supports
    the mean operation on the 3rd and 4th dimensions.
    """
    assert shape['in_shape'][0] == shape['out_shape'][0]
    assert shape['out_shape'][1] == shape['in_shape'][1]
    assert len(shape['in_shape']) == 4
    assert len(shape['out_shape']) == 2

    assert isinstance(mask, CoarseMask)
    assert mask.mask_index[1] is not None
    assert mask.mask_index[0] is None
    module_masks.set_output_mask(mask)

    input_cmask = CoarseMask(num_dim=4)
    input_cmask.add_index_mask(dim=1, index=mask.mask_index[1])
    module_masks.set_input_mask(input_cmask)
    return input_cmask
def maxpool2d_inshape(module_masks, mask):
    """
    Assume only the second dimension is masked

    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the maxpool2d
    mask : CoarseMask
        The mask of its input tensor

    Returns
    -------
    CoarseMask
        The mask of its output tensor
    """
    assert isinstance(mask, CoarseMask)
    assert mask.mask_index[1] is not None
    assert mask.mask_index[0] is None
    assert mask.mask_index[2] is None
    assert mask.mask_index[3] is None
    if module_masks.input_mask is not None:
        assert module_masks.input_mask <= mask
    # assert module_masks.input_mask is None
    module_masks.set_input_mask(mask)
    module_masks.set_output_mask(mask)
    return mask


def maxpool2d_outshape(module_masks, mask):
    """
    Assume only the second dimension is masked

    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the maxpool2d
    mask : CoarseMask
        The mask of its output tensor

    Returns
    -------
    CoarseMask
        The mask of its input tensor
    """
    assert isinstance(mask, CoarseMask)
    assert mask.mask_index[1] is not None
    assert mask.mask_index[0] is None
    module_masks.set_input_mask(mask)
    module_masks.set_output_mask(mask)
    return mask
def prelu_inshape(module_masks, mask):
    """
    We assume only the second dimension has a coarse-grained mask

    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the PReLU
    mask : CoarseMask
        The mask of its input tensor

    Returns
    -------
    CoarseMask
        The mask of its output tensor
    """
    assert isinstance(mask, CoarseMask)
    assert mask.mask_index[1] is not None
    assert mask.mask_index[0] is None
    assert mask.mask_index[2] is None
    assert mask.mask_index[3] is None
    module_masks.set_input_mask(mask)
    module_masks.set_output_mask(mask)
    weight_cmask = CoarseMask(num_dim=1)
    weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
    module_masks.set_param_masks('weight', weight_cmask)
    return mask


def prelu_outshape(module_masks, mask):
    """
    We assume only the second dimension has a coarse-grained mask

    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the PReLU
    mask : CoarseMask
        The mask of its output tensor

    Returns
    -------
    CoarseMask
        The mask of its input tensor
    """
    assert isinstance(mask, CoarseMask)
    assert mask.mask_index[1] is not None
    assert mask.mask_index[0] is None
    assert mask.mask_index[2] is None
    assert mask.mask_index[3] is None
    weight_cmask = CoarseMask(num_dim=4)
    weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
    module_masks.set_param_masks('weight', weight_cmask)
    return mask


def relu_inshape(module_masks, mask):
    """
    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the relu
    mask : CoarseMask
        The mask of its input tensor

    Returns
    -------
    CoarseMask
        The mask of its output tensor
    """
    assert isinstance(mask, CoarseMask)
    if module_masks.input_mask is not None:
        # mask conflicts should be resolved before speedup
        assert module_masks.input_mask <= mask
    # assert module_masks.input_mask is None, "A relu op can only be processed once"
    module_masks.set_input_mask(mask)
    module_masks.set_output_mask(mask)
    return mask


def relu_outshape(module_masks, mask):
    """
    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the relu
    mask : CoarseMask
        The mask of its output tensor

    Returns
    -------
    CoarseMask
        The mask of its input tensor
    """
    assert isinstance(mask, CoarseMask)
    if module_masks.output_mask is not None:
        # mask conflicts should be resolved before speedup
        assert all(module_masks.output_mask.mask_index[1] == mask.mask_index[1])
    module_masks.set_input_mask(mask)
    module_masks.set_output_mask(mask)
    return mask
def batchnorm2d_mask(module_masks, mask):
    """
    Infer the input and output shape from the weight mask

    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the batchnorm2d
    mask : dict
        The mask of its weights, from the user-provided mask file

    Returns
    -------
    CoarseMask, CoarseMask
        The mask of its input tensor, the mask of its output tensor
    """
    assert 'weight' in mask and 'bias' in mask
    sum_mask = mask['weight'] + mask['bias']
    nonzero_index = torch.nonzero(sum_mask, as_tuple=True)[0]
    # infer the shape of the parameters
    param_cmask = CoarseMask(num_dim=1)
    param_cmask.add_index_mask(dim=0, index=nonzero_index)
    module_masks.set_param_masks('weight', param_cmask)
    module_masks.set_param_masks('bias', param_cmask)
    # infer the shape of the input tensor
    input_cmask = CoarseMask(num_dim=4)
    input_cmask.add_index_mask(dim=1,
                               index=torch.nonzero(mask['weight'], as_tuple=True)[0])
    module_masks.set_input_mask(input_cmask)
    # infer the shape of the output tensor
    output_cmask = CoarseMask(num_dim=4)
    output_cmask.add_index_mask(dim=1, index=nonzero_index)
    module_masks.set_output_mask(output_cmask)
    return input_cmask, output_cmask


def linear_mask(module_masks, mask, shape):
    """
    Infer the input and output shape from the weight mask, with limitations:
    only inferring the input mask is supported.

    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the Linear
    mask : dict
        The mask of its weights, from the user-provided mask file
    shape: dict
        Shape of its input and output tensors

    Returns
    -------
    CoarseMask, CoarseMask
        The mask of its input tensor, the mask of its output tensor
    """
    assert 'weight' in mask
    num_input_dim = len(shape['in_shape'])

    # Input data of a Linear module can have multiple dimensions.
    # Here we only support inferring the coarse mask on the last dimension.
    nonzero_index = torch.nonzero(mask['weight'].sum(0), as_tuple=True)[0]

    # infer the shape of the input tensor
    input_cmask = CoarseMask(num_dim=num_input_dim)
    input_cmask.add_index_mask(dim=num_input_dim - 1, index=nonzero_index)

    module_masks.set_input_mask(input_cmask)
    return input_cmask, None
def conv2d_mask(module_masks, mask):
    """
    Infer the input and output shape from the weight mask

    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the conv2d
    mask : dict
        The mask of its weights, from the user-provided mask file

    Returns
    -------
    CoarseMask, CoarseMask
        The mask of its input tensor, the mask of its output tensor
    """
    def convert_to_coarse_mask(mask, dim=0):
        """
        Parameters
        ----------
        mask : dict
            Weight mask from the user-provided mask file
        dim: int
            0: filter pruning
            1: channel pruning

        Returns
        -------
        LongTensor, CoarseMask, CoarseMask
            Index of the masked dimension, weight mask, bias mask
        """
        assert 'weight' in mask
        assert isinstance(mask['weight'], torch.Tensor)
        assert dim in [0, 1]

        weight_mask = mask['weight']
        sum_idx = (1, 2, 3) if dim == 0 else (0, 2, 3)
        index = torch.nonzero(weight_mask.abs().sum(sum_idx) != 0, as_tuple=True)[0]
        index = index.long().to(weight_mask.device)
        weight_cmask = CoarseMask(num_dim=4)
        weight_cmask.add_index_mask(dim=dim, index=index)
        bias_cmask = None
        if dim == 0 and 'bias' in mask and mask['bias'] is not None:
            bias_index = torch.nonzero(mask['bias'], as_tuple=True)[0]
            assert torch.all(torch.eq(index, bias_index)), \
                "bias mask should be consistent with weight mask"
            bias_cmask = CoarseMask(num_dim=1)
            bias_cmask.add_index_mask(dim=0, index=bias_index)
        return index, weight_cmask, bias_cmask

    index, weight_cmask, bias_cmask = convert_to_coarse_mask(mask, dim=conv_prune_dim)

    if index is None:
        # TODO: fine-grained mask speedup
        return None, None
    # deal with the coarse-grained mask;
    # mask conflicts should be resolved by fix_mask_conflict before speedup
    if 'weight' in module_masks.param_masks:
        assert module_masks.param_masks['weight'] == weight_cmask
    else:
        module_masks.set_param_masks('weight', weight_cmask)
        if conv_prune_dim == 0:
            module_masks.set_param_masks('bias', bias_cmask)

    io_cmask = CoarseMask(num_dim=4)
    io_cmask.add_index_mask(dim=1, index=index)

    if conv_prune_dim == 0:
        if module_masks.output_mask is None:
            module_masks.set_output_mask(io_cmask)
        else:
            assert module_masks.output_mask == io_cmask
        return None, module_masks.output_mask
    else:
        if module_masks.input_mask is None:
            module_masks.set_input_mask(io_cmask)
        else:
            assert module_masks.input_mask == io_cmask
        return module_masks.input_mask, None
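A standalone sketch of convert_to_coarse_mask's core step for filter pruning (dim=0), with invented shapes: filters whose mask sums to zero across the remaining dimensions are pruned.

import torch

weight_mask = torch.ones(4, 3, 3, 3)
weight_mask[1] = 0                                             # prune filter 1
index = torch.nonzero(weight_mask.abs().sum((1, 2, 3)) != 0, as_tuple=True)[0]
print(index)    # tensor([0, 2, 3])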
def conv2d_inshape(module_masks, mask):
    """
    A shape change of the input tensor does not affect the shape of its output tensor

    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the conv2d
    mask : CoarseMask
        The mask of its input tensor

    Returns
    -------
    CoarseMask
        The mask of its output tensor
    """
    assert isinstance(mask, CoarseMask)
    if module_masks.input_mask is None:
        module_masks.set_input_mask(mask)
    else:
        # The same conv layer may be accessed more
        # than once, such as in a concat operation.
        # Mask conflicts should be resolved by fix_mask_conflict before speedup.
        assert module_masks.input_mask == mask
    # shape changes pass through depthwise conv layers
    m = module_masks.module
    if m.in_channels == m.out_channels == m.groups:
        module_masks.output_mask = mask
        module_masks.input_mask = mask
        return mask
    return None


def conv2d_outshape(module_masks, mask):
    """
    Assume only the second dimension is masked

    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the conv2d
    mask : CoarseMask
        The mask of its output tensor

    Returns
    -------
    CoarseMask
        The mask of its input tensor
    """
    assert isinstance(mask, CoarseMask)
    assert mask.mask_index[1] is not None
    assert mask.mask_index[0] is None
    assert mask.mask_index[2] is None
    assert mask.mask_index[3] is None

    if module_masks.output_mask is None:
        module_masks.output_mask = mask
    else:
        # Mask conflicts should be resolved by fix_mask_conflict before speedup;
        # mask and module_masks.output_mask may have different numbers of dimensions,
        # since they could be passed by linear or conv2d.
        assert all(module_masks.output_mask.mask_index[1] == mask.mask_index[1])

    weight_cmask = CoarseMask(num_dim=4)
    weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
    bias_cmask = CoarseMask(num_dim=1)
    bias_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
    module_masks.set_param_masks('weight', weight_cmask)
    module_masks.set_param_masks('bias', bias_cmask)
    # shape changes pass through depthwise conv layers
    m = module_masks.module
    if m.in_channels == m.out_channels == m.groups:
        module_masks.output_mask = mask
        module_masks.input_mask = mask
        return mask
    return None
def convtranspose2d_mask(module_masks, mask):
    # TODO support ConvTranspose2d pruning for the L1FilterPruner
    raise Exception(
        "The current filter pruner cannot prune ConvTranspose2d; pruning ConvTranspose2d will be supported later")


def convtranspose2d_inshape(module_masks, mask):
    """
    A shape change of the input tensor does not affect the shape of its output tensor

    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the conv2d
    mask : CoarseMask
        The mask of its input tensor

    Returns
    -------
    CoarseMask
        The mask of its output tensor
    """
    assert isinstance(mask, CoarseMask)
    if module_masks.input_mask is None:
        module_masks.set_input_mask(mask)
    else:
        # The same conv layer may be accessed more
        # than once, such as in a concat operation.
        # Mask conflicts should be resolved by fix_mask_conflict before speedup.
        assert module_masks.input_mask == mask
    # shape changes pass through depthwise conv layers
    m = module_masks.module
    if m.in_channels == m.out_channels == m.groups:
        module_masks.output_mask = mask
        module_masks.input_mask = mask
        return mask
    return None


def convtranspose2d_outshape(module_masks, mask):
    assert isinstance(mask, CoarseMask)
    assert mask.mask_index[1] is not None
    assert mask.mask_index[0] is None
    assert mask.mask_index[2] is None
    assert mask.mask_index[3] is None

    if module_masks.output_mask is None:
        module_masks.output_mask = mask
    else:
        # Mask conflicts should be resolved by fix_mask_conflict before speedup;
        # mask and module_masks.output_mask may have different numbers of dimensions,
        # since they could be passed by linear or conv2d.
        assert all(module_masks.output_mask.mask_index[1] == mask.mask_index[1])

    weight_cmask = CoarseMask(num_dim=4)
    # Note: the memory layout of ConvTranspose2d weights is C_in, C_out, k1, k2
    weight_cmask.add_index_mask(dim=1, index=mask.mask_index[1])
    bias_cmask = CoarseMask(num_dim=1)
    bias_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
    module_masks.set_param_masks('weight', weight_cmask)
    module_masks.set_param_masks('bias', bias_cmask)
    # shape changes pass through depthwise conv layers
    m = module_masks.module
    if m.in_channels == m.out_channels == m.groups:
        module_masks.output_mask = mask
        module_masks.input_mask = mask
        return mask
    return None
nni/compression/pytorch/speedup/jit_translate.py 0 → 100644
View file @ 7eedec46
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import re
import logging
from functools import partial
import torch

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def translate_list(list_node, speedup=None):
    """
    Get the list of values from the list-construct node.

    Parameters
    ----------
    list_node: Torch.C.Value
        The cpp node of the target list.
    speedup: ModelSpeedup
        The model speedup object.

    Returns
    -------
    values: list
        The list of values in the target cpp list node.
    """
    # the node that creates the list
    create_node = list_node.node()
    assert create_node.kind() == 'prim::ListConstruct'
    inputs = list(create_node.inputs())
    values = []
    for _i in inputs:
        debugName = _i.debugName()
        if speedup is not None and debugName in speedup.internal_result:
            # this value is the result of other nodes, such as
            # aten::size
            values.append(speedup.internal_result[debugName].item())
        else:
            # if the corresponding value is a constant
            values.append(_i.toIValue())
    return values
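For context, a hedged illustration of where prim::ListConstruct nodes come from (a standalone sketch, not part of this file): tracing an op that takes a list argument, such as avg_pool2d, produces a ListConstruct feeding the kernel size.

import torch

def f(x):
    return torch.nn.functional.avg_pool2d(x, kernel_size=[2, 2])

graph = torch.jit.trace(f, torch.randn(1, 1, 4, 4)).graph
print(graph)   # contains a prim::ListConstruct feeding aten::avg_pool2d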
def parse_constant(cvalue, speedup):
    """
    Parse the constant value from this node

    Parameters
    ----------
    cvalue: Torch.C.Value
        The cpp node of the target constant value.
    speedup: ModelSpeedup
        The model speedup object.

    Returns
    -------
    value: int/float/tensor
        The constant value parsed from the node.
    """
    logger.debug('Try to parse the constant value: %s', cvalue.debugName())
    if cvalue.toIValue() is not None:
        return cvalue.toIValue()
    if cvalue.debugName() in speedup.internal_result:
        return speedup.internal_result[cvalue.debugName()]
    # get the operator node of this value
    op_node = cvalue.node()
    inputs = op_node.inputs()
    input_values = [parse_constant(_i, speedup) for _i in inputs]
    func = trans_from_jit_to_python[op_node.kind()](op_node, speedup)
    return func(*input_values)
def dropout_python(node, speedup):
    return torch.dropout


def flatten_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    start_dim = inputs[1].toIValue()
    end_dim = inputs[2].toIValue()
    new_flatten = partial(torch.flatten, start_dim=start_dim, end_dim=end_dim)
    return new_flatten


def relu_inplace_python(node, speedup):
    return torch.relu_


def relu_python(node, speedup):
    return torch.relu


def sigmoid_python(node, speedup):
    return torch.sigmoid


def mean_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim_list = translate_list(inputs[1], speedup)
    keep_dim = inputs[2].toIValue()
    new_mean = partial(torch.mean, dim=tuple(dim_list), keepdim=keep_dim)
    return new_mean
def add_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    constant = None
    for i in range(2):
        input_i = inputs[i]
        debug_name = input_i.debugName()
        if debug_name not in speedup.internal_result:
            # this input is a constant value
            # TODO: what if this input is a constant tensor
            if input_i.toIValue() is not None:
                constant = parse_constant(input_i, speedup)
                break
    if constant is None:
        return torch.add
    else:
        new_add = partial(torch.add, constant)
        return new_add
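A minimal illustration (values invented) of the partial-binding pattern these translators rely on: the constant operand is pre-bound, so the returned callable takes only the remaining tensor input.

from functools import partial
import torch

new_add = partial(torch.add, 5)        # bind the constant operand
print(new_add(torch.tensor([1, 2])))   # tensor([6, 7])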
def floor_div_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    divisor = inputs[1]
    constant = None
    if divisor.debugName() not in speedup.internal_result:
        # the divisor is a constant value/tensor
        constant = parse_constant(divisor, speedup)
    if constant is None:
        return torch.floor_divide
    else:
        new_op = partial(torch.floor_divide, other=constant)
        return new_op


def mul_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    constant = None
    for i in range(2):
        input_i = inputs[i]
        debug_name = input_i.debugName()
        if debug_name not in speedup.internal_result:
            constant = parse_constant(input_i, speedup)
            # the two inputs cannot both be constants at the same time
            break
    if constant is None:
        return torch.mul
    else:
        new_mul = partial(torch.mul, constant)
        return new_mul
def
transpose_python
(
node
,
speedup
):
return
torch
.
t
def
transpose2_python
(
node
,
speedup
):
c_node
=
node
.
key_node
inputs
=
list
(
c_node
.
inputs
())
dim_1
=
inputs
[
1
].
toIValue
()
dim_2
=
inputs
[
2
].
toIValue
()
new_transpose
=
partial
(
torch
.
transpose
,
dim0
=
dim_1
,
dim1
=
dim_2
)
return
new_transpose
def matmul_python(node, speedup):
    return torch.matmul


def div_python(node, speedup):
    # The second input parameter of torch.div can be a tensor or a
    # constant; if it is a constant, we need to bind it into the
    # returned function
    c_node = node.key_node
    inputs = list(c_node.inputs())
    if inputs[1].debugName() in speedup.internal_result:
        # the second input parameter is the output of other nodes
        return torch.div
    else:
        other = inputs[1].toIValue()
        new_div = partial(torch.div, other=other)
        return new_div


def softmax_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim = inputs[1].toIValue()
    new_softmax = partial(torch.softmax, dim=dim)
    return new_softmax


def contiguous_python(node, speedup):
    class contiguousModule(torch.nn.Module):
        def forward(self, x):
            return x.contiguous()
    return contiguousModule()


def gelu_python(node, speedup):
    return torch.nn.GELU()


def avgpool2d_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    kernel_size = translate_list(inputs[1], speedup)
    stride = translate_list(inputs[2], speedup)
    padding = translate_list(inputs[3], speedup)
    new_avgpool = partial(torch.nn.functional.avg_pool2d, kernel_size=kernel_size,
                          stride=stride, padding=padding)
    return new_avgpool


def adaptive_avgpool_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    output_size = translate_list(inputs[1], speedup)
    new_avgpool = torch.nn.AdaptiveAvgPool2d(output_size)
    return new_avgpool


def tupleunpack_python(node, speedup):
    # Note: tuple unpack should only exist at the end of the model,
    # and there is no need to replace it or propagate the mask
    return None


def num2tensor_python(node, speedup):
    return torch.nn.Identity()


def exp_python(node, speedup):
    return torch.exp


def squeeze_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim = None
    if len(inputs) > 1:
        dim = parse_constant(inputs[1], speedup)
    new_squeeze = partial(torch.squeeze, dim=dim)
    return new_squeeze
##########################################################
# Split Line
# The following modules/functions cannot be translated into
# a single function, so we use a torch.nn.Module to wrap the
# core function and return the torch.nn.Module instead
##########################################################


def slice_python(node, speedup):
    class SliceMoudle(torch.nn.Module):
        def __init__(self, sliceobj):
            super(SliceMoudle, self).__init__()
            self.sliceobj = sliceobj

        def forward(self, x, *args):
            # args is for the slice dimension and indexes; however,
            # we already got them from the cpp nodes. Note that although
            # we don't need the slice indexes any more, we cannot remove
            # this parameter, because there may be multiple inputs passed
            # from previous nodes such as aten::size
            logger.info('Model has Slice operation, and the operand size=%s, Slice object:%s',
                        str(x.size()), str(self.sliceobj))
            return x[self.sliceobj]

    c_node = node.key_node
    inputs = list(c_node.inputs())
    slice_dim = parse_constant(inputs[1], speedup)
    slice_start = parse_constant(inputs[2], speedup)
    slice_end = parse_constant(inputs[3], speedup)
    slice_step = parse_constant(inputs[4], speedup)
    slice_obj = slice(slice_start, slice_end, slice_step)
    slice_list = []
    for _ in range(slice_dim):
        slice_list.append(slice(None, None))
    logger.info('Slice dim:%s, Slice obj:%s', str(slice_dim), str(slice_obj))
    slice_list.append(slice_obj)
    return SliceMoudle(tuple(slice_list))
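The tuple built above is plain Python indexing machinery. A standalone sketch for `slice_dim=1, start=0, end=2, step=1`:

```python
import torch

# Full slice on dim 0, the real slice on dim 1 -- exactly what
# SliceMoudle stores in self.sliceobj.
slice_obj = (slice(None, None), slice(0, 2, 1))
x = torch.rand(4, 8, 16)
print(x[slice_obj].shape)  # torch.Size([4, 2, 16])
```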
def select_python(node, speedup):
    class SelectModule(torch.nn.Module):
        def __init__(self, dim, index):
            super(SelectModule, self).__init__()
            self.dim = dim
            self.index = index

        def forward(self, x):
            return x.select(self.dim, self.index)

    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim = inputs[1].toIValue()
    index = inputs[2].toIValue()
    return SelectModule(dim, index)


def size_python(node, speedup):
    class SizeMoudle(torch.nn.Module):
        def __init__(self, sizedim):
            super(SizeMoudle, self).__init__()
            self.sizedim = sizedim

        def forward(self, x):
            return torch.as_tensor([x.size(self.sizedim)], dtype=torch.long)

    c_node = node.key_node
    inputs = list(c_node.inputs())
    size_dim = inputs[1].toIValue()
    return SizeMoudle(size_dim)


def toint_python(node, speedup):
    class ToIntModule(torch.nn.Module):
        def forward(self, x):
            return x.to(torch.int)
    return ToIntModule()
def view_python(node, speedup):
    class ViewModule(torch.nn.Module):
        def __init__(self, shape):
            super(ViewModule, self).__init__()
            self.shape = shape
            logger.info('View Module output size: %s', str(self.shape))

        def forward(self, *args):
            return args[0].view(self.shape)

    c_node = node.key_node
    inputs = list(c_node.inputs())
    shape = translate_list(inputs[1], speedup)
    return ViewModule(shape)


def reshape_python(node, speedup):
    class ReshapeModule(torch.nn.Module):
        def __init__(self, shape):
            super(ReshapeModule, self).__init__()
            self.shape = shape
            logger.info('Reshape Module output size: %s', str(self.shape))

        def forward(self, *args):
            return args[0].view(self.shape)

    c_node = node.key_node
    inputs = list(c_node.inputs())
    shape = translate_list(inputs[1], speedup)
    return ReshapeModule(shape)


def permute_python(node, speedup):
    class PermuteModule(torch.nn.Module):
        def __init__(self, dimlist):
            super(PermuteModule, self).__init__()
            self.dimlist = dimlist

        def forward(self, x):
            return x.permute(self.dimlist)

    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim_list = translate_list(inputs[1], speedup)
    return PermuteModule(dim_list)


def getattr_python(node, speedup):
    """
    Note: ops starting with prim:: are not taken as key nodes,
    so we directly pass the cpp node into this function.

    Parameters
    ----------
    node: torch._C.Node
        The cpp node of prim::GetAttr
    speedup: ModelSpeedup
        The corresponding speedup object.
    """
    class GetModule(torch.nn.Module):
        def __init__(self, key):
            super(GetModule, self).__init__()
            self.key = key

        def forward(self, obj):
            logger.info('Get attribute: %s', self.key)
            return getattr(obj, self.key)

    # get the name of the attribute, for example
    # prim::GetAttr[name="module_list"](%self.1)
    assert node.kind() == 'prim::GetAttr'
    pattern = '\[name=\"(.*?)\"\]'
    key_words = re.findall(pattern, str(node))
    assert len(key_words) == 1
    return GetModule(key_words[0])
def upsample_bilinear2d_python(node, speedup):
    class UpsampleModule(torch.nn.Module):
        def __init__(self, size_list, scale_list):
            super(UpsampleModule, self).__init__()
            self.size_list = size_list
            self.scale_list = scale_list

        def forward(self, *args):
            """
            The first element of args is the target tensor to upsample;
            the following parameters are unused, because we already got
            the size_list and the scale_list by parsing the cpp nodes.
            """
            return torch.nn.functional.upsample_bilinear(
                args[0], size=self.size_list, scale_factor=self.scale_list)

    c_node = node.key_node
    inputs = list(c_node.inputs())
    size_list_node = inputs[1].node()
    scale_list_node = inputs[3].node()
    size_list = None
    scale_list = None
    if size_list_node.kind() == 'prim::ListConstruct':
        size_list = translate_list(inputs[1], speedup)
    if scale_list_node.kind() == 'prim::ListConstruct':
        scale_list = translate_list(inputs[3], speedup)
    return UpsampleModule(size_list, scale_list)


def typeas_python(node, speedup):
    """
    Currently only supports type_as float.
    TODO: support more types in type_as; we need to figure out
    how to get the scalar type from torch._C.TensorType.
    """
    class TypeasModule(torch.nn.Module):
        def __init__(self, dtype=torch.float):
            super(TypeasModule, self).__init__()
            self.example = torch.zeros(1, dtype=dtype)

        def forward(self, x):
            return x.type_as(self.example)
    return TypeasModule()


def to_python(node, speedup):
    # for the time being, only the device parameter is supported
    class ToModule(torch.nn.Module):
        def __init__(self, device):
            super(ToModule, self).__init__()
            self.device = device

        def forward(self, x):
            return x.to(self.device)

    c_node = node.key_node
    inputs = list(c_node.inputs())
    device = inputs[3].toIValue()
    return ToModule(device)


def cat_python(node, speedup):
    class CatModule(torch.nn.Module):
        def __init__(self, cat_dim):
            super(CatModule, self).__init__()
            self.cat_dim = cat_dim

        def forward(self, *args):
            return torch.cat(args, dim=self.cat_dim)

    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim = inputs[1].toIValue()
    return CatModule(dim)
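aten::cat takes a variable number of input tensors, which is why it is wrapped in a module rather than a partial. A standalone sketch of the same pattern (CatSketch is an illustrative name, not part of the diff):

```python
import torch

class CatSketch(torch.nn.Module):
    # the cat dim is fixed at construction; the tensors arrive as *args
    def __init__(self, cat_dim):
        super().__init__()
        self.cat_dim = cat_dim

    def forward(self, *args):
        return torch.cat(args, dim=self.cat_dim)

m = CatSketch(cat_dim=1)
print(m(torch.rand(2, 3), torch.rand(2, 5)).shape)  # torch.Size([2, 8])
```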
trans_from_jit_to_python = {
    'aten::add': add_python,
    'aten::add_': add_python,
    'aten::mul': mul_python,
    'aten::mul_': mul_python,
    'aten::relu': relu_python,
    'aten::relu_': relu_inplace_python,
    'aten::sigmoid': sigmoid_python,
    'aten::sigmoid_': sigmoid_python,
    # for mask inference, tanh behaves like relu
    'aten::tanh': relu_python,
    'aten::tanh_': relu_python,
    'aten::flatten': flatten_python,
    'aten::mean': mean_python,
    'aten::dropout': dropout_python,
    'aten::slice': slice_python,
    'aten::select': select_python,
    'aten::size': size_python,
    'aten::t': transpose_python,
    'aten::transpose': transpose2_python,
    'aten::Int': toint_python,
    'aten::view': view_python,
    'aten::reshape': reshape_python,
    'aten::permute': permute_python,
    'aten::matmul': matmul_python,
    'aten::div': div_python,
    'aten::floor_divide': floor_div_python,
    'aten::softmax': softmax_python,
    'aten::contiguous': contiguous_python,
    'aten::gelu': gelu_python,
    'aten::cat': cat_python,
    'aten::avg_pool2d': avgpool2d_python,
    'aten::max_pool2d': avgpool2d_python,
    'aten::adaptive_avg_pool2d': adaptive_avgpool_python,
    'aten::to': to_python,
    'aten::type_as': typeas_python,
    'aten::upsample_bilinear2d': upsample_bilinear2d_python,
    'aten::exp': exp_python,
    'aten::squeeze': squeeze_python,
    'prim::TupleUnpack': tupleunpack_python,
    'prim::ListUnpack': tupleunpack_python,
    'prim::NumToTensor': num2tensor_python,
    'prim::GetAttr': getattr_python
}
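The table above dispatches on the jit op kind. A minimal sketch of a lookup, assuming `trans_from_jit_to_python` from this file is in scope; the simple translators such as `relu_python` ignore their node/speedup arguments, so `None` placeholders suffice here:

```python
import torch

# 'aten::relu' maps to relu_python, which simply returns torch.relu
# as the python-side equivalent of the jit op.
func = trans_from_jit_to_python['aten::relu'](None, None)
assert func is torch.relu
print(func(torch.tensor([-1.0, 2.0])))  # tensor([0., 2.])
```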
def jit_to_python_function(node, speedup):
    """
    Return a callable object to infer the mask according to the
    node.op_type.

    Parameters
    ----------
    node: NodeGroup
        The target node for mask inference
    speedup: ModelSpeedup
        The speedup object of the target model.

    Returns
    -------
    func: callable object (nn.Module/function)
        The translated function used to infer the mask;
        if the current op_type is not supported, return None.
    """
    logger.debug('Translate C function %s into its python version', node.op_type)
    if node.op_type not in trans_from_jit_to_python:
        logger.error('%s is not supported! Please report an issue at https://github.com/microsoft/nni. Thanks~',
                     node.op_type)
        # return None to skip the mask inference for this node
        return None
    return trans_from_jit_to_python[node.op_type](node, speedup)
nni/compression/pytorch/utils/__init__.py
View file @ 7eedec46
from .utils import *
\ No newline at end of file
nni/compression/pytorch/utils/mask_conflict.py
View file @ 7eedec46
...
@@ -4,10 +4,10 @@ import os
import logging
import torch
import numpy as np
-from .shape_dependency import ChannelDependency, GroupDependency, CatPaddingDependency, InputChannelDependency
+from .shape_dependency import ChannelDependency, GroupDependency, InputChannelDependency
from .utils import get_module_by_name
# logging.basicConfig(level = logging.DEBUG)
-_logger = logging.getLogger(__name__)
+_logger = logging.getLogger('FixMaskConflict')

def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
...
@@ -21,7 +21,7 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
        A dict object that stores the masks or the path of the mask file
    model : torch.nn.Module
        model to fix the mask conflict
-    dummy_input : torch.Tensor
+    dummy_input : torch.Tensor/list of tensors/dict of tensors
        input example to trace the model
    traced : torch._C.torch.jit.TopLevelTracedModule
        the traced model of the target model; if this parameter is not None,
...
@@ -48,9 +48,7 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
    masks = fix_group_mask.fix_mask()
    fix_channel_mask = ChannelMaskConflict(masks, model, dummy_input, traced)
    masks = fix_channel_mask.fix_mask()
-    padding_cat_mask = CatMaskPadding(masks, model, dummy_input, traced)
-    masks = padding_cat_mask.fix_mask()
-    return masks, fix_channel_mask.conv_prune_dim
+    return masks

class MaskFix:
...
@@ -78,70 +76,6 @@ class MaskFix:
        torch.save(self.masks, path)
-class CatMaskPadding(MaskFix):
-    def __init__(self, masks, model, dummy_input=None, traced=None):
-        """
-        CatMaskPadding finds the layers whose output tensors are passed
-        to the same cat operation. The cat operation concatenates the
-        masks of the input tensors as the output mask, so when some of
-        the input layers of the cat operation are not pruned, we still
-        need to pass the masks of these non-pruned layers (the masks are
-        all ones) to the cat operation to ensure the shape of the output
-        mask is right.
-
-        Parameters
-        ----------
-        masks : dict
-            a dict object that stores the masks
-        model : torch.nn.Module
-            model to fix the mask conflict
-        dummy_input : torch.Tensor
-            input example to trace the model
-        traced : torch._C.torch.jit.TopLevelTracedModule
-            the traced model of the target model; if this parameter is not None,
-            we do not use the model and dummy_input to get the trace graph.
-        """
-        super(CatMaskPadding, self).__init__(masks, model, dummy_input, traced)
-
-    def fix_mask(self):
-        cat_padding_depen = CatPaddingDependency(self.model, self.dummy_input, self.traced)
-        name_to_module = {}
-        for name, module in self.model.named_modules():
-            name_to_module[name] = module
-        depen = cat_padding_depen.dependency_sets
-        for layers in depen:
-            device = None
-            count = 0
-            for layer in layers:
-                if layer in self.masks:
-                    count += 1
-                    if device is None:
-                        device = self.masks[layer]['weight'].device
-            if count == 0:
-                # no layer is pruned
-                continue
-            elif count == len(layers):
-                # all the layers have been pruned
-                continue
-            # pad the mask for the non-pruned layers
-            for layer in layers:
-                if layer in self.masks:
-                    continue
-                module = name_to_module[layer]
-                w_shape = module.weight.data.size()
-                w_mask = torch.ones(w_shape).to(device)
-                b_mask = None
-                if hasattr(module, 'bias') and module.bias is not None:
-                    # module.bias may be None
-                    b_shape = module.bias.data.size()
-                    b_mask = torch.ones(b_shape).to(device)
-                self.masks[layer] = {'weight': w_mask, 'bias': b_mask}
-        return self.masks

class GroupMaskConflict(MaskFix):
    def __init__(self, masks, model=None, dummy_input=None, traced=None):
        """
...
@@ -172,9 +106,11 @@ class GroupMaskConflict(MaskFix):
        group_depen = GroupDependency(self.model, self.dummy_input, self.traced)
        depens = group_depen.dependency
+        min_groups = group_depen.min_groups
        _logger.info(depens)
        for layername in depens:
-            group = depens[layername]
+            group_max = depens[layername]
+            group_min = min_groups[layername]
            if layername not in self.masks:
                # this layer is not pruned
                continue
...
@@ -187,29 +123,43 @@ class GroupMaskConflict(MaskFix):
                # In fine-grained pruning, skip this layer
                _logger.info('Layers %s using fine-grained pruning', layername)
                continue
-            assert shape[0] % group == 0
+            assert shape[0] % group_max == 0
            # Find the minimum number of masked filters per group (mini_masked).
            # Because the pruned filters must still be divisible into the same
            # number of groups, we can only prune mini_masked filters per group.
-            step = shape[0] / group
+            step = shape[0] / group_max
            group_masked = []
-            for i in range(group):
+            for i in range(group_max):
                _start = step * i
                _end = step * (i + 1)
                _tmp_list = list(filter(lambda x: _start <= x and x < _end, all_zeros))
                group_masked.append(_tmp_list)
            mini_masked = min([len(x) for x in group_masked])
+            need_unmask = set()
            for gm in group_masked:
                for i in range(mini_masked, len(gm)):
                    # To keep the output channel number divisible by the
                    # number of groups, we unmask (set to ones) the extra
                    # filters in each group.
                    pos = gm[i]
-                    self.masks[layername]['weight'][pos] = torch.ones(shape[1:])
-                    if 'bias' in self.masks[layername] and self.masks[layername]['bias'] is not None:
-                        self.masks[layername]['bias'][pos] = 1
+                    need_unmask.add(pos)
+            step = shape[0] / group_min
+            for i in range(group_min):
+                _start = step * i
+                _end = step * (i + 1)
+                _tmp_list = list(filter(lambda x: _start <= x and x < _end, all_zeros))
+                if len(_tmp_list) == step:
+                    # if the whole group is removed, then we don't have to unmask
+                    # the filters in this group
+                    for pos in _tmp_list:
+                        if pos in need_unmask:
+                            need_unmask.remove(pos)
+            for pos in need_unmask:
+                self.masks[layername]['weight'][pos] = torch.ones(shape[1:])
+                if hasattr(self.masks[layername], 'bias'):
+                    self.masks[layername]['bias'][pos] = 1
        return self.masks
...
@@ -234,9 +184,14 @@ class ChannelMaskConflict(MaskFix):
        super(ChannelMaskConflict, self).__init__(masks, model, dummy_input, traced)
        self.conv_prune_dim = detect_mask_prune_dim(masks, model)
-        _logger.info('detected conv prune dim: %s', self.conv_prune_dim)
+        _logger.info('Detected conv prune dim: %d', self.conv_prune_dim)

    def fix_mask(self):
        """
        Fix the mask conflict before the mask inference for the layers that
        have shape dependencies. This function should be called before the
        mask inference of the 'speedup' module.
@@ -274,7 +229,12 @@ class ChannelMaskConflict(MaskFix):
                if (channel_mask.sum() * (mask.numel() / mask.shape[self.conv_prune_dim])).item() \
                        != (mask > 0).sum().item():
                    fine_grained = True
            elif type(m).__name__ == 'Linear':
-                channel_masks.append((mask.abs().sum(0) != 0).int())
+                if self.conv_prune_dim == 1:
+                    channel_masks.append((mask.abs().sum(0) != 0).int())
+                else:
+                    channel_masks.append((mask.abs().sum(1) != 0).int())
            elif type(m).__name__ == 'BatchNorm2d':
                channel_masks.append(mask.int())
            elif type(m).__name__ == 'ConvTranspose2d':
...
@@ -293,9 +253,7 @@ class ChannelMaskConflict(MaskFix):
                # no mask means not pruned, equivalent to full masks
                channel_masks.append(None)
        if fine_grained:
-            _logger.info("Fine-grianed mask detected")
+            _logger.info('fine-grained mask detected, skip solving conflict for this set: %s', dset)
+            continue
        if all(x is None for x in channel_masks):
            continue
        num_channels_list = [len(x)
@@ -306,7 +264,8 @@ class ChannelMaskConflict(MaskFix):
        for i, dim_mask in enumerate(channel_masks):
            if dim_mask is None:
                channel_masks[i] = torch.ones(num_channels).int().to(device)
        # merge masks with 'or'
        merged_channel_mask = channel_masks[0].clone()
...
@@ -329,19 +288,22 @@ class ChannelMaskConflict(MaskFix):
                else:
                    new_mask[:, merged_index, :, :] = 1.
            elif type(m).__name__ == 'Linear':
-                new_mask[:, merged_index] = 1.
+                if self.conv_prune_dim == 0:
+                    new_mask[merged_index, :] = 1
+                elif self.conv_prune_dim == 1:
+                    new_mask[:, merged_index] = 1.
            elif type(m).__name__ == 'BatchNorm2d':
                new_mask = merged_channel_mask.type_as(orig_mask)
            else:
                raise RuntimeError(f'unsupported module type: {type(m).__name__}')
            self.masks[name]['weight'] = new_mask
            if 'bias' in self.masks[name] and self.masks[name]['bias'] is not None:
-                if type(m).__name__ == 'Conv2d':
-                    assert self.conv_prune_dim == 0
-                    self.masks[name]['bias'] = merged_channel_mask.type_as(self.masks[name]['bias'])
+                if self.conv_prune_dim == 0:
+                    self.masks[name]['bias'] = merged_channel_mask.type_as(self.masks[name]['bias'])
        return self.masks
...
@@ -349,14 +311,12 @@ class ChannelMaskConflict(MaskFix):
def detect_mask_prune_dim(masks, model):
    """
    Detect how the masks of convolutional layers are pruned.

    Parameters
    ----------
    masks: dict
        A dict object that stores the masks.
    model: nn.Module
        Model object which the mask can be applied on.

    Returns
    -------
    How the masks of convolutional layers are pruned; this depends on the pruning algorithm. It should
...
nni/compression/pytorch/utils/shape_dependency.py
View file @ 7eedec46
...
@@ -3,18 +3,34 @@
import csv
import logging
import numpy as np

-__all__ = ['ChannelDependency', 'GroupDependency', 'CatPaddingDependency', 'InputChannelDependency']
+__all__ = ['ChannelDependency', 'GroupDependency', 'InputChannelDependency']

CONV_TYPE = 'aten::_convolution'
ADD_TYPES = ['aten::add', 'aten::add_']
MUL_TYPES = ['aten::mul', 'aten::mul_']
CAT_TYPE = 'aten::cat'
logger = logging.getLogger('Shape_Dependency')
RESHAPE_OPS = [CAT_TYPE, 'aten::view', 'aten::reshape', 'aten::flatten', 'aten::mean']

def lcm_list(L):
    lcm = 1
    for i in L:
        lcm = np.lcm(lcm, i)
    return lcm

def gcd_list(L):
    gcd = L[0]
    for i in L:
        gcd = np.gcd(gcd, i)
    return gcd
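For intuition, a short sketch using the two helpers just defined (assuming they are in scope):

```python
import numpy as np  # lcm_list / gcd_list above build on np.lcm / np.gcd

groups = [2, 4]          # e.g. two group convs reading the same conv's output
print(lcm_list(groups))  # 4: the pruned filter count must stay divisible by 4
print(gcd_list(groups))  # 2: here min(groups) == gcd, so min_groups would be 2
```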
class Dependency:
    def __init__(self, model=None, dummy_input=None, traced_model=None):
        """
...
@@ -38,6 +54,35 @@ class Dependency:
        raise NotImplementedError

def reshape_break_channel_dependency(op_node):
    """
    Reshape operations such as reshape, view and flatten may break
    the channel dependency. We need to check the input parameters of
    these reshape operations to see whether the node breaks the channel
    dependency. However, it is complicated to analyze the input
    parameters of each reshape function and infer whether it breaks the
    channel dependency. So currently we just check whether the input
    channel count equals the output channel count: if so, the reshape
    function does not change the number of channels and the channel
    dependency is not broken; otherwise, the reshape operation changes
    the number of channels and breaks the channel dependency.

    Parameters
    ----------
    op_node: NodePyOP
        An op node of the graph.

    Returns
    -------
    bool
        True if this operation breaks the channel dependency.
    """
    in_shape = op_node.auxiliary['in_shape']
    out_shape = op_node.auxiliary['out_shape']
    in_channel = in_shape[1]
    out_channel = out_shape[1]
    return in_channel != out_channel
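A minimal sketch of that channel check on plain shapes: a flatten that keeps dim 1 intact preserves the dependency, while one that folds channels into features breaks it (`breaks_dependency` is an illustrative stand-in for the logic above):

```python
# Shapes as they would appear in op_node.auxiliary['in_shape'/'out_shape'].
def breaks_dependency(in_shape, out_shape):
    return in_shape[1] != out_shape[1]

print(breaks_dependency([8, 16, 4, 4], [8, 16, 16]))  # False: channels kept
print(breaks_dependency([8, 16, 4, 4], [8, 256]))     # True: channels folded away
```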
class ChannelDependency(Dependency):
    def __init__(self, model=None, dummy_input=None, traced_model=None):
        """
...
@@ -80,6 +125,9 @@ class ChannelDependency(Dependency):
                # find the first met conv
                parent_layers.append(curnode.name)
                continue
+            elif curnode.op_type in RESHAPE_OPS:
+                if reshape_break_channel_dependency(curnode):
+                    continue
            parents = self.graph.find_predecessors(curnode.unique_name)
            parents = [self.graph.name_to_node[name] for name in parents]
            for parent in parents:
...
@@ -176,7 +224,7 @@ class ChannelDependency(Dependency):
        d_sets = []
        visited = set()
        for node in self.graph.nodes_py.nodes_op:
-            if node.op_type != 'Conv2d' or node in visited:
+            if (node.op_type != 'Conv2d' and node.op_type != 'Linear') or node in visited:
                continue
            tmp_set = set()
            if node.name not in self.dependency:
...
@@ -190,35 +238,6 @@ class ChannelDependency(Dependency):
        return d_sets

-def reshape_break_channel_dependency(op_node):
-    ...  # body unchanged; this function now lives at the top of the file (see above)

class InputChannelDependency(ChannelDependency):
    """
    Some pruners may prune the input channel of the convolutional
...
@@ -295,67 +314,6 @@ class InputChannelDependency(ChannelDependency):
            self.dependency[layer] = dependency_set

-class CatPaddingDependency(ChannelDependency):
-    def __init__(self, model=None, dummy_input=None, traced_model=None):
-        super(CatPaddingDependency, self).__init__(model, dummy_input, traced_model)
-
-    def build_dependency(self):
-        """
-        Build the cat padding dependencies.
-        If the output features of several layers are stitched together
-        by a cat operation, then these layers have cat padding dependencies.
-        This is because when inferring the cat mask, we need all the input
-        masks for the cat operation. At this time we need to know the source
-        of all input vectors of a cat operation.
-        """
-        for node in self.graph.nodes_py.nodes_op:
-            parent_layers = []
-            if node.op_type == CAT_TYPE:
-                parent_layers = self._get_parent_layers(node)
-                dependency_set = set(parent_layers)
-                # merge the dependencies
-                for parent in parent_layers:
-                    if parent in self.dependency:
-                        dependency_set.update(self.dependency[parent])
-                # save the dependencies
-                for _node in dependency_set:
-                    self.dependency[_node] = dependency_set
-
-    @property
-    def dependency_sets(self):
-        d_sets = []
-        visited = set()
-        for nodename in self.dependency:
-            if nodename in visited:
-                continue
-            d_sets.append(self.dependency[nodename])
-        return d_sets
-
-    def export(self, filepath):
-        """
-        Export the dependencies into a file.
-        In the output file, each line contains a set of layers
-        whose output features are stitched together by the cat
-        operation.
-
-        output example:
-        Dependency Set, Layers
-        set1, Conv1, Conv2
-        set2, Conv3, Conv4
-        """
-        header = ['Dependency Set', 'Layers']
-        setid = 0
-        with open(filepath, 'w') as csvf:
-            csv_w = csv.writer(csvf, delimiter=',')
-            csv_w.writerow(header)
-            for layers in self.dependency_sets:
-                setid += 1
-                row = ['Set %d' % setid]
-                row.extend(list(layers))
-                csv_w.writerow(row)
class GroupDependency(Dependency):
    def __init__(self, model=None, dummy_input=None, traced_model=None):
        """
...
@@ -372,6 +330,7 @@ class GroupDependency(Dependency):
        if we already have the traced graph of the target model, we do not
        need to trace the model again.
        """
+        self.min_groups = {}
        super(GroupDependency, self).__init__(model, dummy_input, traced_model)
    def _get_parent_convs(self, node):
...
@@ -451,27 +410,33 @@ class GroupDependency(Dependency):
        key: the name of conv layers; value: the minimum value that the number of
        filters should be divisible by.
        """
+        self.groups = {}
        for node in self.graph.nodes_py.nodes_op:
            if node.op_type == 'Conv2d' or node.op_type == 'ConvTranspose2d':
                group = self._get_conv_groups(node)
-                if node.name in self.dependency:
+                if node.name in self.groups:
                    # a conv layer whose group is larger than 1 requires its
                    # number of output channels to be divisible by the number of groups.
-                    self.dependency[node.name] = max(self.dependency[node.name], group)
+                    self.groups[node.name].append(group)
                else:
-                    self.dependency[node.name] = group
+                    self.groups[node.name] = [group]
                if group > 1:
                    # a conv layer whose group is larger than 1 requires the number
                    # of output channels of its parent conv layers to be divisible by group.
                    parent_convs = self._get_parent_convs(node)
                    for parent in parent_convs:
-                        if parent in self.dependency:
-                            self.dependency[parent] = max(self.dependency[parent], group)
+                        if parent in self.groups:
+                            self.groups[parent].append(group)
                        else:
-                            self.dependency[parent] = group
+                            self.groups[parent] = [group]
+        for name in self.groups:
+            self.dependency[name] = lcm_list(self.groups[name])
+            if min(self.groups[name]) == gcd_list(self.groups[name]):
+                self.min_groups[name] = min(self.groups[name])
+            else:
+                self.min_groups[name] = 1
        return self.dependency
    def export(self, filepath):
...
@@ -501,3 +466,110 @@ class GroupDependency(Dependency):
    @property
    def dependency_sets(self):
        return self.dependency

class ReshapeDependency(Dependency):
    def __init__(self, model=None, dummy_input=None, traced_model=None):
        """
        Some models may have view/reshape functions; such functions may have fixed parameters
        and cannot be replaced at all. Therefore, these functions may put some constraints on
        their input shapes. In this class, we find the direct input conv/linear layers of these
        reshape functions. If you get a shape conflict when running the forward inference on the
        speeded-up model, please try removing these layers from the pruner config list and try again.

        Parameters
        ----------
        model : torch.nn.Module
            The model to be analyzed.
        dummy_input : torch.Tensor
            The example input data to trace the network architecture.
        traced_model : torch._C.Graph
            if we already have the traced graph of the target model, we do not
            need to trace the model again.
        """
        super(ReshapeDependency, self).__init__(model, dummy_input, traced_model)

    def _get_parent_layers(self, node):
        """
        Find the nearest parent conv layers for the target node.

        Parameters
        ----------
        node : torch._C.Node
            target node.

        Returns
        -------
        parent_layers : list
            nearest parent conv/linear layers for the target node.
        """
        parent_layers = []
        queue = []
        queue.append(node)
        while queue:
            curnode = queue.pop(0)
            if curnode.op_type == 'Conv2d' or curnode.op_type == 'Linear' or curnode.op_type == 'ConvTranspose2d':
                # find the first met conv
                parent_layers.append(curnode.name)
                continue
            parents = self.graph.find_predecessors(curnode.unique_name)
            parents = [self.graph.name_to_node[name] for name in parents]
            for parent in parents:
                queue.append(parent)
        return parent_layers

    def build_dependency(self):
        """
        Build the channel dependency for the conv layers in the model.
        """
        # unpack the tuple/list manually before analyzing the channel dependency
        self.graph.unpack_manually()
        for node in self.graph.nodes_py.nodes_op:
            parent_layers = []
            # find the nodes that contain aten::view or aten::reshape operations
            if node.op_type in ['aten::view', 'aten::reshape']:
                logger.info('Detect reshape-like functions: %s', node.op_type)
                parent_layers = self._get_parent_layers(node)
                print('Parent layers', parent_layers)
                self.dependency[node.unique_name] = parent_layers

    def export(self, filepath):
        """
        Export the reshape dependencies as a csv file.

        Output example:
        Reshape OP, Dependent Layers
        model.view.1,layer1.1.conv2,layer1.0.conv2,conv1
        model.mean.1,layer1.0.conv1
        model.reshape.1,layer1.1.conv1
        """
        header = ['Reshape OP', 'Dependent Layers']
        with open(filepath, 'w') as csvf:
            csv_w = csv.writer(csvf, delimiter=',')
            csv_w.writerow(header)
            for reshape_op in self.dependency:
                # note: list.extend returns None, so build the row in two steps
                row = [reshape_op]
                row.extend(self.dependency[reshape_op])
                csv_w.writerow(row)

    @property
    def dependency_sets(self):
        """
        Get the list of the dependency sets.

        Returns
        -------
        dependency_sets : list
            list of the dependency sets. For example,
            [set(['conv1', 'conv2']), set(['conv3', 'conv4'])]
        """
        d_sets = []
        for reshape_node in self.dependency:
            d_sets.extend(self.dependency[reshape_node])
        d_sets = list(set(d_sets))
        return d_sets
nni/compression/pytorch/utils/utils.py
View file @ 7eedec46
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
from .shape_dependency import ReshapeDependency

torch_float_dtype = [torch.float, torch.float16, torch.float32, torch.float64, torch.half, torch.double]
torch_integer_dtype = [torch.uint8, torch.int16, torch.short, torch.int32, torch.long, torch.bool]


def get_module_by_name(model, module_name):
    """
...
@@ -28,3 +33,50 @@ def get_module_by_name(model, module_name):
        return model, leaf_module
    else:
        return None, None
def rand_like_with_shape(shape, ori_t):
    """
    Return a new random tensor like the original tensor.
    """
    assert isinstance(ori_t, torch.Tensor)
    device = ori_t.device
    dtype = ori_t.dtype
    require_grad = ori_t.requires_grad
    lower_bound = torch.min(ori_t)
    higher_bound = torch.max(ori_t)
    if dtype in torch_integer_dtype:
        return torch.randint(lower_bound, higher_bound + 1, shape, dtype=dtype, device=device)
    else:
        return torch.rand(shape, dtype=dtype, device=device, requires_grad=require_grad)
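A quick usage sketch, assuming `rand_like_with_shape` above is in scope:

```python
import torch

template = torch.randint(0, 10, (4, 4))     # integer template: randint branch
t = rand_like_with_shape((2, 3), template)  # new shape, same dtype and device
print(t.dtype, t.shape)                     # torch.int64 torch.Size([2, 3])
```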
def randomize_tensor(tensor, start=1, end=100):
    """
    Randomize the target tensor according to the given range.
    """
    assert isinstance(tensor, torch.Tensor)
    if tensor.dtype in torch_integer_dtype:
        # integer tensors can only be randomized by torch.randint
        # torch.randint(int(start), int(end), tensor.size(), out=tensor.data, dtype=tensor.dtype)
        pass
    else:
        # we can use nn.init.uniform_ to randomize this tensor
        # Note: a tensor with an integer type cannot be randomized
        # with nn.init.uniform_
        torch.nn.init.uniform_(tensor.data, start, end)
def not_safe_to_prune(model, dummy_input):
    """
    Get the layers that are not safe to prune (pruning them may
    cause shape conflicts at reshape-like operations).

    Parameters
    ----------
    model: torch.nn.Module
        The target model to prune.
    dummy_input: torch.Tensor/list of torch.Tensor/tuple of Tensor
        The example input to trace the model.
    """
    reshape_dset = ReshapeDependency(model, dummy_input)
    return reshape_dset.dependency_sets
\ No newline at end of file
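A usage sketch of `not_safe_to_prune` (the model choice here is illustrative, and the function is assumed importable):

```python
import torch
from torchvision.models import resnet18

# Layers feeding fixed view/reshape ops should be left out of the
# pruner config list; this call reports them.
model = resnet18()
dummy_input = torch.rand(1, 3, 224, 224)
print(not_safe_to_prune(model, dummy_input))
```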
test/ut/sdk/test_compression_utils.py
View file @ 7eedec46
...
@@ -116,7 +116,7 @@ class AnalysisUtilsTest(TestCase):
        pruner.export_model(ck_file, mask_file)
        pruner._unwrap_model()
        # Fix the mask conflict
-        fixed_mask, _ = fix_mask_conflict(mask_file, net, dummy_input)
+        fixed_mask = fix_mask_conflict(mask_file, net, dummy_input)
        # use the channel dependency ground truth to check whether
        # the mask conflict was fixed successfully
...
test/ut/sdk/test_model_speedup.py
View file @ 7eedec46
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import os
import gc
import psutil
import sys
import numpy as np
...
@@ -9,18 +11,20 @@ import torch
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
-from torchvision.models.vgg import vgg16
+from torchvision.models.vgg import vgg16, vgg11
from torchvision.models.resnet import resnet18
+from torchvision.models.mobilenet import mobilenet_v2
import unittest
from unittest import TestCase, main
from nni.compression.pytorch import ModelSpeedup, apply_compression_results
-from nni.algorithms.compression.pytorch.pruning import L1FilterPruner
+from nni.algorithms.compression.pytorch.pruning import L1FilterPruner, LevelPruner
from nni.algorithms.compression.pytorch.pruning.weight_masker import WeightMasker
from nni.algorithms.compression.pytorch.pruning.dependency_aware_pruner import DependencyAwarePruner

+torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 2
# the relative distance
RELATIVE_THRESHOLD = 0.01
...
@@ -105,6 +109,55 @@ class TransposeModel(torch.nn.Module):
        return x

class TupleUnpack_backbone(nn.Module):
    def __init__(self, width):
        super(TupleUnpack_backbone, self).__init__()
        self.model_backbone = mobilenet_v2(pretrained=False, width_mult=width, num_classes=3)

    def forward(self, x):
        x1 = self.model_backbone.features[:7](x)
        x2 = self.model_backbone.features[7:14](x1)
        x3 = self.model_backbone.features[14:18](x2)
        return [x1, x2, x3]

class TupleUnpack_FPN(nn.Module):
    def __init__(self):
        super(TupleUnpack_FPN, self).__init__()
        self.conv1 = nn.Conv2d(32, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
        self.conv2 = nn.Conv2d(96, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
        self.conv3 = nn.Conv2d(320, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
        # self.init_weights()

    def forward(self, inputs):
        """Forward function."""
        laterals = []
        laterals.append(self.conv1(inputs[0]))  # inputs[0] == x1
        laterals.append(self.conv2(inputs[1]))  # inputs[1] == x2
        laterals.append(self.conv3(inputs[2]))  # inputs[2] == x3
        return laterals

class TupleUnpack_Model(nn.Module):
    def __init__(self):
        super(TupleUnpack_Model, self).__init__()
        self.backbone = TupleUnpack_backbone(1.0)
        self.fpn = TupleUnpack_FPN()

    def forward(self, x):
        x1 = self.backbone(x)
        out = self.fpn(x1)
        return out

dummy_input = torch.randn(2, 1, 28, 28)
SPARSITY = 0.5
MODEL_FILE, MASK_FILE = './11_model.pth', './l1_mask.pth'
...
@@ -129,6 +182,7 @@ def generate_random_sparsity(model):
                             'sparsity': sparsity})
    return cfg_list

def generate_random_sparsity_v2(model):
    """
    Only select 50% of the layers to prune.
...
@@ -139,9 +193,10 @@ def generate_random_sparsity_v2(model):
        if np.random.uniform(0, 1.0) > 0.5:
            sparsity = np.random.uniform(0.5, 0.99)
-            cfg_list.append({'op_types': ['Conv2d'], 'op_names': [name], 'sparsity': sparsity})
+            cfg_list.append({'op_types': ['Conv2d'], 'op_names': [name],
+                             'sparsity': sparsity})
    return cfg_list

def zero_bn_bias(model):
    with torch.no_grad():
        for name, module in model.named_modules():
...
@@ -231,19 +286,6 @@ def channel_prune(model):
class SpeedupTestCase(TestCase):
    def test_speedup_vgg16(self):
        prune_model_l1(vgg16())
        model = vgg16()
        model.train()
        ms = ModelSpeedup(model, torch.randn(2, 3, 32, 32), MASK_FILE)
        ms.speedup_model()

        orig_model = vgg16()
        assert model.training
        assert model.features[2].out_channels == int(orig_model.features[2].out_channels * SPARSITY)
        assert model.classifier[0].in_features == int(orig_model.classifier[0].in_features * SPARSITY)

    def test_speedup_bigmodel(self):
        prune_model_l1(BigModel())
...
@@ -253,7 +295,7 @@ class SpeedupTestCase(TestCase):
        mask_out = model(dummy_input)

        model.train()
-        ms = ModelSpeedup(model, dummy_input, MASK_FILE)
+        ms = ModelSpeedup(model, dummy_input, MASK_FILE, confidence=2)
        ms.speedup_model()
        assert model.training
...
@@ -289,7 +331,7 @@ class SpeedupTestCase(TestCase):
        new_model = TransposeModel()
        state_dict = torch.load(MODEL_FILE)
        new_model.load_state_dict(state_dict)
-        ms = ModelSpeedup(new_model, dummy_input, MASK_FILE)
+        ms = ModelSpeedup(new_model, dummy_input, MASK_FILE, confidence=2)
        ms.speedup_model()
        zero_bn_bias(ori_model)
        zero_bn_bias(new_model)
...
@@ -297,26 +339,34 @@ class SpeedupTestCase(TestCase):
        new_out = new_model(dummy_input)
        ori_sum = torch.sum(ori_out)
        speeded_sum = torch.sum(new_out)
        print('Transpose Speedup Test: ori_sum={} speedup_sum={}'.format(ori_sum, speeded_sum))
        # FIXME: This test case might fail randomly, no idea why
        # Example: https://msrasrg.visualstudio.com/NNIOpenSource/_build/results?buildId=16282
        assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
            (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)

-    def test_speedup_integration(self):
-        # skip this test on windows (7GB mem available) due to memory limit
-        # Note: hack trick, may be updated in the future
-        if 'win' in sys.platform or 'Win' in sys.platform:
-            print('Skip test_speedup_integration on windows due to memory limit!')
+    def test_speedup_integration_small(self):
+        model_list = ['resnet18', 'mobilenet_v2', 'alexnet']
+        self.speedup_integration(model_list)

+    def test_speedup_integration_big(self):
+        model_list = ['vgg11', 'vgg16', 'resnet34', 'squeezenet1_1',
+                      'densenet121', 'resnet50', 'wide_resnet50_2']
+        mem_info = psutil.virtual_memory()
+        ava_gb = mem_info.available / 1024.0 / 1024 / 1024
+        print('Available memory size: %.2f GB' % ava_gb)
+        if ava_gb < 8.0:
+            # the memory size is too small and we may run into an OOM exception
+            # Skip this test in the pipeline test due to the memory limitation
+            return
+        self.speedup_integration(model_list)

+    def speedup_integration(self, model_list, speedup_cfg=None):
        Gen_cfg_funcs = [generate_random_sparsity, generate_random_sparsity_v2]
-        for model_name in ['resnet18', 'mobilenet_v2', 'squeezenet1_1', 'densenet121', 'densenet169',
-                           # 'inception_v3' inception is too large and may fail the pipeline
-                           'resnet50']:
+        # for model_name in ['vgg16', 'resnet18', 'mobilenet_v2', 'squeezenet1_1', 'densenet121',
+        #                    # 'inception_v3' inception is too large and may fail the pipeline
+        #                    'resnet50']:
+        for model_name in model_list:
            for gen_cfg_func in Gen_cfg_funcs:
                kwargs = {
                    'pretrained': True
...
@@ -334,7 +384,10 @@ class SpeedupTestCase(TestCase):
                speedup_model.eval()
                # randomly generate the prune config for the pruner
                cfgs = gen_cfg_func(net)
                print("Testing {} with compression config \n {}".format(model_name, cfgs))
+                if len(cfgs) == 0:
+                    continue
                pruner = L1FilterPruner(net, cfgs)
                pruner.compress()
                pruner.export_model(MODEL_FILE, MASK_FILE)
...
@@ -345,7 +398,10 @@ class SpeedupTestCase(TestCase):
                zero_bn_bias(speedup_model)

                data = torch.ones(BATCH_SIZE, 3, 128, 128).to(device)
-                ms = ModelSpeedup(speedup_model, data, MASK_FILE)
+                if speedup_cfg is None:
+                    speedup_cfg = {}
+                ms = ModelSpeedup(speedup_model, data, MASK_FILE, confidence=2, **speedup_cfg)
                ms.speedup_model()

                speedup_model.eval()
...
@@ -355,12 +411,13 @@ class SpeedupTestCase(TestCase):
                ori_sum = torch.sum(ori_out).item()
                speeded_sum = torch.sum(speeded_out).item()
                print('Sum of the output of %s (before speedup):' % model_name, ori_sum)
                print('Sum of the output of %s (after speedup):' % model_name, speeded_sum)
                assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
                    (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
+                print("Collecting Garbage")
+                gc.collect(2)

    def test_channel_prune(self):
        orig_net = resnet18(num_classes=10).to(device)
...
@@ -378,7 +435,7 @@ class SpeedupTestCase(TestCase):
        net.eval()

        data = torch.randn(BATCH_SIZE, 3, 128, 128).to(device)
-        ms = ModelSpeedup(net, data, MASK_FILE)
+        ms = ModelSpeedup(net, data, MASK_FILE, confidence=2)
        ms.speedup_model()
        ms.bound_model(data)
...
@@ -391,11 +448,56 @@ class SpeedupTestCase(TestCase):
        assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
            (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)

    def test_speedup_tupleunpack(self):
        """This test is reported in issue 3645."""
        model = TupleUnpack_Model()
        cfg_list = [{'op_types': ['Conv2d'], 'sparsity': 0.5}]
        dummy_input = torch.rand(2, 3, 224, 224)
        pruner = L1FilterPruner(model, cfg_list)
        pruner.compress()
        model(dummy_input)
        pruner.export_model(MODEL_FILE, MASK_FILE)
        ms = ModelSpeedup(model, dummy_input, MASK_FILE, confidence=2)
        ms.speedup_model()

    def test_finegrained_speedup(self):
        """Test the speedup on fine-grained sparsity."""
        class MLP(nn.Module):
            def __init__(self):
                super(MLP, self).__init__()
                self.fc1 = nn.Linear(1024, 1024)
                self.fc2 = nn.Linear(1024, 1024)
                self.fc3 = nn.Linear(1024, 512)
                self.fc4 = nn.Linear(512, 10)

            def forward(self, x):
                x = x.view(-1, 1024)
                x = self.fc1(x)
                x = self.fc2(x)
                x = self.fc3(x)
                x = self.fc4(x)
                return x

        model = MLP().to(device)
        dummy_input = torch.rand(16, 1, 32, 32).to(device)
        cfg_list = [{'op_types': ['Linear'], 'sparsity': 0.99}]
        pruner = LevelPruner(model, cfg_list)
        pruner.compress()
        print('Original Arch')
        print(model)
        pruner.export_model(MODEL_FILE, MASK_FILE)
        pruner._unwrap_model()
        ms = ModelSpeedup(model, dummy_input, MASK_FILE, confidence=4)
        ms.speedup_model()
        print("Fine-grained sped-up model")
        print(model)

    def tearDown(self):
        if os.path.exists(MODEL_FILE):
            os.remove(MODEL_FILE)
        if os.path.exists(MASK_FILE):
            os.remove(MASK_FILE)
        # GC to release memory
        gc.collect(2)
if __name__ == '__main__':
...