OpenDAS / nni · Commits

Commit 7eedec46 (unverified)
Authored Jul 14, 2021 by Ningxin Zheng; committed by GitHub on Jul 14, 2021
Parent commit: 5b99b598

Model Speedup Refactor (#3462)

Showing 13 changed files, with 2218 additions and 1701 deletions (+2218 −1701).
docs/en_US/Compression/CompressionReference.rst        +0    −3
nni/common/graph_utils.py                              +12   −2
nni/compression/pytorch/speedup/compress_modules.py    +403  −192
nni/compression/pytorch/speedup/compressor.py          +453  −126
nni/compression/pytorch/speedup/infer_mask.py          +378  −0
nni/compression/pytorch/speedup/infer_shape.py         +0    −1146
nni/compression/pytorch/speedup/jit_translate.py       +553  −0
nni/compression/pytorch/utils/__init__.py              +1    −0
nni/compression/pytorch/utils/mask_conflict.py         +50   −90
nni/compression/pytorch/utils/shape_dependency.py      +174  −102
nni/compression/pytorch/utils/utils.py                 +52   −0
test/ut/sdk/test_compression_utils.py                  +1    −1
test/ut/sdk/test_model_speedup.py                      +141  −39
docs/en_US/Compression/CompressionReference.rst

@@ -140,9 +140,6 @@ Topology Utilities
 ..  autoclass:: nni.compression.pytorch.utils.shape_dependency.GroupDependency
     :members:
 
-..  autoclass:: nni.compression.pytorch.utils.mask_conflict.CatMaskPadding
-    :members:
-
 ..  autoclass:: nni.compression.pytorch.utils.mask_conflict.GroupMaskConflict
     :members:
nni/common/graph_utils.py

@@ -71,7 +71,11 @@ class TorchGraph:
     def _trace(self, model, dummy_input):
         training = model.training
         model.eval()
-        self.trace = torch.jit.trace(model, dummy_input)
+        kw_args = {}
+        if torch.__version__ >= '1.6.0':
+            # only PyTorch versions greater than 1.6.0 have the strict option
+            kw_args['strict'] = False
+        self.trace = torch.jit.trace(model, dummy_input, **kw_args)
         torch._C._jit_pass_inline(self.trace.graph)
         model.train(training)

@@ -247,6 +251,7 @@ class TorchModuleGraph(TorchGraph):
     def __init__(self, model=None, dummy_input=None, traced_model=None):
         super().__init__(model, dummy_input, traced_model)
         self.global_count = 0
+        self.reused_module = set()
         self.name_to_node, self.input_to_node, self.output_to_node = self._build_graph()
         self._extract_auxiliary_info()

@@ -390,9 +395,12 @@ class TorchModuleGraph(TorchGraph):
                 outputs.append(output_name)
             else:
                 outputs.append(output_name)
+        unique_outputs = list(set(outputs))
+        # remove the duplicated output names
+        unique_outputs.sort(key=outputs.index)
         nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
-                             node_group, inputs=list(inputs), outputs=list(outputs))
+                             node_group, inputs=list(inputs), outputs=unique_outputs)
         return nodepy

     def _extract_cat_info(self, node_group, cpp_node):

@@ -724,6 +732,8 @@ class TorchModuleGraph(TorchGraph):
                 unique_name = module_name
                 if use_count > 0:
                     unique_name = module_name + '.%d' % use_count
+                    self.reused_module.add(unique_name)
+                    self.reused_module.add(module_name)
                 node_group = self._expand_module_node(
                     node, module_name, unique_name, module_to_type[module_name],
                     node_cpps, input_to_node, output_to_node, 'module')
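For orientation, here is a minimal usage sketch (not part of the commit; torchvision's resnet18 is just an arbitrary example model) of the TorchModuleGraph class these hunks modify:

import torch
import torchvision.models as models
from nni.common.graph_utils import TorchModuleGraph

model = models.resnet18()
dummy_input = torch.randn(8, 3, 224, 224)
# TorchModuleGraph traces the model internally (passing strict=False on
# PyTorch >= 1.6, as added above) and groups the jit nodes by module
graph = TorchModuleGraph(model, dummy_input)
print(len(graph.name_to_node), 'module-level nodes')
# modules that are called more than once in forward(), tracked by the
# new reused_module set
print('reused modules:', graph.reused_module)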
nni/compression/pytorch/speedup/compress_modules.py
(This diff is collapsed on the original page and is not reproduced here.)

nni/compression/pytorch/speedup/compressor.py
(This diff is collapsed on the original page and is not reproduced here.)
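Although the compressor.py diff is not shown, the refactored entry point it contains is the ModelSpeedup class. A typical invocation, per NNI's documentation (the mask file name is an assumption):

import torch
from torchvision.models import resnet18
from nni.compression.pytorch import ModelSpeedup

model = resnet18()
dummy_input = torch.randn(8, 3, 224, 224)
# 'mask.pth' is a mask file previously exported by an NNI pruner
m_speedup = ModelSpeedup(model, dummy_input, 'mask.pth')
# infers the masks for every node (see infer_mask.py below), then replaces
# the pruned layers with physically smaller ones
m_speedup.speedup_model()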
nni/compression/pytorch/speedup/infer_mask.py  (new file, mode 100644)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import torch
import torch.nn as nn

from ..utils import randomize_tensor, torch_float_dtype, torch_integer_dtype

_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)

STD_DELTA = 1e-6


class AutoMaskInference:
    def __init__(self, module, dummy_input, in_masks=None, weight_mask=None,
                 output_mask=None, name=None, in_constants=None, state_dict=None, batch_dim=0):
        """
        This class infers the mask of the target module automatically.
        update_direct_sparsity infers the output mask according to the
        input masks; in contrast, update_indirect_sparsity infers the
        input masks according to the given output masks. The newly found
        sparsity is incrementally merged into the original in_masks and
        output_mask.

        Parameters
        ----------
        module: torch.nn.Module/function
            The target module whose mask is inferred. Must be callable.
        dummy_input: torch.Tensor/list of Tensors
            The dummy input of the target module.
        in_masks: list of torch.Tensor
            The input masks of the target module. If in_masks is not None,
            update_direct_sparsity and update_indirect_sparsity will
            incrementally update the given in_masks; otherwise,
            AutoMaskInference will create new in_masks for the target module.
        output_mask: torch.Tensor
            The output mask of the target module. Similar to in_masks, if
            output_mask is not None, update_direct_sparsity and
            update_indirect_sparsity will incrementally update the given
            output_mask; otherwise AutoMaskInference will create one
            output_mask for the target module.
        weight_mask: dict of the weight masks
            The weight masks of the target module; the key is the name of the
            corresponding mask. For example:
            {'weight': torch.ones(1000, 1000), 'bias': torch.ones(1000)}
        name: str
            Name of the target module.
        in_constants: list of torch.Tensor
            The corresponding constant values of the in_masks.
        state_dict: dict of torch.Tensor
            The original values of the weights.
        batch_dim: int
            The index of the batch dimension of the input tensors.
        """
        errmsg = '%s is not callable, should pass the nn.Module/function' % str(module)
        assert callable(module), errmsg
        self.module = module
        # Initialize the dummy_input
        if isinstance(dummy_input, list):
            # if there are multiple input variables
            self.dummy_input = dummy_input
        else:
            # if there is only one input variable
            self.dummy_input = [dummy_input]

        # Initialize the masks for input tensors
        self.in_masks = in_masks if in_masks is not None else [
            None] * len(self.dummy_input)
        self.in_constants = in_constants if in_constants is not None else [
            torch.zeros_like(x) for x in dummy_input]
        for in_id, _ in enumerate(self.in_masks):
            if self.in_masks[in_id] is None and \
                    isinstance(self.dummy_input[in_id], torch.Tensor):
                # if the input mask is None, create an all-ones mask for the
                # corresponding input tensor
                self.in_masks[in_id] = torch.ones_like(self.dummy_input[in_id])
                # ones_like puts the created mask on the same device as the dummy_input

        # Initialize the mask for output tensors
        self.output = self.module(*dummy_input)
        # self.output.requires_grad_()
        if output_mask is not None:
            # assume the given output mask is right
            self.output_mask = output_mask
        else:
            if isinstance(self.output, torch.Tensor):
                self.output_mask = torch.ones_like(self.output)
            elif isinstance(self.output, list) or isinstance(self.output, tuple):
                self.output_mask = []
                for o_tensor in self.output:
                    if isinstance(o_tensor, torch.Tensor):
                        self.output_mask.append(torch.ones_like(o_tensor))
                    else:
                        # if one of the outputs is not a tensor, set the
                        # corresponding mask to None
                        self.output_mask.append(None)
            else:
                self.output_mask = None

        # Initialize the mask for the parameters
        self.weights = {}
        self.weight_mask = {}
        if weight_mask:
            self.weight_mask.update(weight_mask)
        if isinstance(self.module, nn.Module):
            # a plain function should not have parameters;
            # get all the parameter tensors of the target module
            for name, para in module.named_parameters():
                self.weights[name] = para
                if name not in self.weight_mask:
                    self.weight_mask[name] = torch.ones_like(para.data)
        self.name = name
        self.state_dict = state_dict
        # TODO support other batch dimensions in the future
        self.batch_dim = batch_dim

    def random_init(self, start=0.1, end=8.0):
        """
        Randomly initialize the weights of the module. The values of
        the tensors will not affect the mask auto-inference.
        """
        # currently we set the random range to 0.1-8.0 because of ReLU6:
        # if we use a range far larger than 6, a wrong mask may be inferred
        # when the confidence is low. In the future, we will add mask
        # inference rules for ReLU6 to remove this range constraint.
        with torch.no_grad():
            for tensor in self.dummy_input:
                if isinstance(tensor, torch.Tensor) and len(tensor.size()) > 0:
                    # if the tensor is a scalar, then skip this tensor
                    randomize_tensor(tensor, start, end)
            for para in self.weights:
                randomize_tensor(self.weights[para].data, start, end)

    def zero_grad(self):
        """
        Set the gradients of the weights and input tensors to zeros.
        """
        with torch.no_grad():
            # set the weight's gradient to zero
            if isinstance(self.module, nn.Module):
                self.module.zero_grad()
            # also zero the gradients of the input tensors
            for tensor in self.dummy_input:
                if isinstance(tensor, torch.Tensor):
                    if tensor.grad is not None:
                        tensor.grad.data.zero_()

    def requires_grad_(self, flag=True):
        """
        Set requires_grad of the input tensors and parameters to flag.
        """
        for t_in in self.dummy_input:
            if isinstance(t_in, torch.Tensor) and t_in.dtype in torch_float_dtype:
                # only float types can require gradients;
                # enable autograd for this input
                t_in.requires_grad_(flag)
        for para_name in self.weights:
            if self.weights[para_name].dtype in torch_float_dtype:
                self.weights[para_name].requires_grad_(flag)

    def apply_mask(self):
        self.__apply_input_mask()
        self.__apply_weight_mask()

    def __apply_input_mask(self):
        """
        Apply the masks of the input tensors.
        """
        with torch.no_grad():
            # apply the input mask
            for tid, in_tensor in enumerate(self.dummy_input):
                if isinstance(in_tensor, torch.Tensor) and self.in_masks[tid] is not None:
                    in_tensor.data = in_tensor.data * \
                        self.in_masks[tid] + \
                        (1 - self.in_masks[tid]) * self.in_constants[tid]

    def __apply_weight_mask(self):
        """
        Apply the weight masks of this module.
        """
        with torch.no_grad():
            # apply the weight mask
            for para in self.weights:
                if para in self.weight_mask:
                    self.weights[para].data *= self.weight_mask[para].data

    def isconstants(self, tout):
        """
        Find the constants in the tensor tout. This function returns a mask
        tensor that indicates whether each value in tout is a constant, and
        one more tensor that holds the values of those constants.

        Parameters
        ----------
        tout: torch.Tensor
            The target output tensor in which to find the constants.

        Returns
        -------
        mask: torch.Tensor
            A mask tensor (same shape as tout) that indicates whether the
            corresponding value is a constant.
        constant: torch.Tensor
            A tensor (same shape as tout) that holds the values of the
            constants in tout.
        """
        assert isinstance(tout, torch.Tensor)
        out_mask = torch.ones_like(tout)
        constant = torch.zeros_like(tout)
        # judge if tout is a scalar (a tensor that has only one value)
        if len(tout.size()) == 0:
            # tout is a scalar tensor; we take this scalar as a constant.
            # Usually, a scalar tensor is returned by the size() function.
            constant = tout
            return out_mask, constant
        if tout.dtype in torch_integer_dtype:
            # PyTorch cannot use torch.mean and torch.std to process
            # integers :( , so if the dtype of the input tensor is integer,
            # we need to check whether it is constant by ourselves.
            # Note: the first dimension should be the batch dimension
            same = tout[:] == tout[0]
            reduced = torch.sum(same, dim=0)
            is_constant = reduced == tout.size(0)
            out_mask[:, is_constant] = 0
            constant[:, is_constant] = tout[0][is_constant]
        else:
            # calculate the std of the output along the batch dimension
            std = torch.std(tout, dim=0)
            # calculate the mean value of the output along the batch dimension
            mean = torch.mean(tout, dim=0)
            mask_pos = std < STD_DELTA
            out_mask[:, mask_pos] = 0
            constant[:, mask_pos] = mean[mask_pos]
        return out_mask, constant

    def update_indirect_sparsity(self):
        """
        This function updates the indirect sparsity. To explain what indirect
        sparsity is: suppose there are two tensors TA and TB and we perform
        the calculation TC = TA x TB, where TC is also a tensor. Once some
        values in TA are masked to zeros, the corresponding positions in TB
        are also potential sparsity, because they have no effect on the final
        output (the gradients of these positions in TB are equal to zero all
        the time). This function finds the potential sparsity caused by other
        sparsity (we call it indirect sparsity here). Basically, we can find
        this potential sparsity through the gradient.
        """
        # Each node only updates its output mask when we propagate the masks
        # backwards. This is because some ops may involve broadcasting; for
        # example, OP A's output tensor may be taken by two OPs (B, C) as
        # input, so we cannot directly update the input mask at OP B or C.
        # We can only update the mask of A's output tensor once B and C have
        # already been updated (their gradients already calculated and added
        # to A's output tensor). Besides, updating the mask of A's output
        # tensor is equivalent to updating the input masks of OPs B and C.
        if isinstance(self.output, torch.Tensor) and self.output.grad is not None:
            # if the output has a gradient, this node has successor nodes and
            # the successor nodes have already updated their indirect
            # sparsity; we can mask the values whose gradients are always zero
            gradient_sum = torch.sum(torch.abs(self.output.grad.data), dim=0)
            _grad_zero = gradient_sum == 0
            for batchid in range(self.output.size(0)):
                # set the same mask value for the whole batch
                self.output_mask[batchid][_grad_zero] = 0
        elif isinstance(self.output, tuple) or isinstance(self.output, list):
            assert isinstance(self.output_mask, (tuple, list))
            for oid, tout in enumerate(self.output):
                errmsg = 'The output only support tensor/list of tensors'
                assert isinstance(tout, torch.Tensor), errmsg
                gradient_sum = torch.sum(
                    torch.abs(self.output.grad.data), dim=0)
                _grad_zero = gradient_sum == 0
                for batchid in range(self.output.size(0)):
                    # set the same mask value for the whole batch
                    self.output_mask[oid][batchid][_grad_zero] = 0
        self.requires_grad_(True)
        # Forward inference with autograd enabled
        # Note: tensors that need gradients cannot be used in in-place operators
        self.random_init()
        self.apply_mask()
        # Some operators may have in-place operations, so we need to clone
        # the input before passing it to self.module
        tmp_dummy_input = [x.clone() if isinstance(
            x, torch.Tensor) else x for x in self.dummy_input]
        output = self.module(*tmp_dummy_input)
        if output.grad_fn is None:
            # the output does not have a gradient function
            return
        # Note: output may be a tensor or a list/tuple of tensors
        if isinstance(output, torch.Tensor):
            output.backward(self.output_mask)
        elif isinstance(output, list) or isinstance(output, tuple):
            for tid, t_out in enumerate(output):
                t_out.backward(self.output_mask[tid])
        # update the sparsity of the parameters
        for para_name in self.weights:
            grad_zero = self.weights[para_name].grad.data == 0
            self.weight_mask[para_name][grad_zero] = 0

    def update_direct_sparsity(self):
        # we don't need the gradient in the forward inference
        out_mask = None
        constant = None
        with torch.no_grad():
            # Note: we need to randomly init the input one more time here!
            # Some operations are in-place, such as relu_; an in-place
            # operation may modify or write zeros into the dummy_input
            self.random_init()
            # apply the masks to the input tensors and the weight tensors
            self.apply_mask()
            # Note: due to in-place operators, such as relu_, ori_out may be
            # the same tensor as dummy_input, so we use clone and detach to
            # create a new tensor with the same values.
            out = self.module(*self.dummy_input)
            if isinstance(out, torch.Tensor):
                out_mask, constant = self.isconstants(out.clone().detach())
            elif isinstance(out, tuple) or isinstance(out, list):
                out_mask = []
                constant = []
                for tout in out:
                    _mask, _constant = self.isconstants(tout.clone().detach())
                    out_mask.append(_mask)
                    constant.append(_constant)
            else:
                _logger.warning(
                    'Only support the OP whose output is tensor/tuple of tensors/list of tensors')
            # We also need to randomize the parameters of the module, because
            # if the weights of the model have an unmasked 0, our output
            # sparsity inference may be wrong. However, after randomizing the
            # weights/parameters, the constants in the output tensors may
            # differ from the constants calculated from the original
            # state_dict. So, to get the right constants and eliminate the
            # bias between the model before and after sparsity inference, we
            # need to reload the state_dict and recalculate the constants.
            # Currently, we get the constant values at the same time as
            # inferring the mask; in the future, we will separate the
            # constant inference into a single graph pass.
            if len(self.weights) > 0 and self.state_dict is not None:
                self.module.load_state_dict(self.state_dict)
                # apply the weight mask
                self.__apply_weight_mask()
                out = self.module(*self.dummy_input).clone().detach()
                if isinstance(out, torch.Tensor):
                    constant = torch.zeros_like(out)
                    constant_pos = out_mask == 0
                    constant[constant_pos] = out[constant_pos]
                elif isinstance(out, (list, tuple)):
                    constant = []
                    for i, tout in enumerate(out):
                        _tmp = torch.zeros_like(tout)
                        sparsity_pos = out_mask[i] == 0
                        _tmp[sparsity_pos] = tout[sparsity_pos]
                        constant.append(_tmp)
        if isinstance(out_mask, torch.Tensor):
            assert isinstance(self.output_mask, torch.Tensor)
            self.output_mask *= out_mask
        elif isinstance(out_mask, list):
            for i, _ in enumerate(out_mask):
                self.output_mask[i] *= out_mask[i]
        else:
            _logger.warning('There is no output sparsity')
        # also save the out_constant
        self.out_constant = constant

    def get_masks(self):
        return (self.in_masks, self.output_mask, self.weight_mask)
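The core trick of update_direct_sparsity, standing alone: feed randomized, masked inputs through the op and treat output positions with near-zero variance over the batch dimension as constants. A self-contained sketch (not part of the commit):

import torch

STD_DELTA = 1e-6
linear = torch.nn.Linear(4, 3, bias=False)
weight_mask = torch.ones_like(linear.weight)
weight_mask[1] = 0                      # pretend output filter 1 is pruned
with torch.no_grad():
    linear.weight *= weight_mask        # apply the weight mask
    x = torch.rand(16, 4) * 7.9 + 0.1   # random batch in [0.1, 8.0), as in random_init
    out = linear(x)
    std = torch.std(out, dim=0)         # variation across the batch dimension
    out_mask = torch.ones_like(out)
    out_mask[:, std < STD_DELTA] = 0    # constant positions become output sparsity
print(out_mask[0])                      # tensor([1., 0., 1.])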
nni/compression/pytorch/speedup/infer_shape.py  (deleted, file mode 100644 → 0)
(This diff is collapsed on the original page and is not reproduced here.)
nni/compression/pytorch/speedup/jit_translate.py  (new file, mode 100644)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import re
import logging
from functools import partial
import torch

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def translate_list(list_node, speedup=None):
    """
    Get the list of values from the list construct node.

    Parameters
    ----------
    list_node: Torch.C.Value
        The cpp node of the target list.
    speedup: ModelSpeedup
        The model speedup object.

    Returns
    -------
    values: list
        The list of values in the target cpp list node.
    """
    # the node that creates the list
    create_node = list_node.node()
    assert create_node.kind() == 'prim::ListConstruct'
    inputs = list(create_node.inputs())
    values = []
    for _i in inputs:
        debugName = _i.debugName()
        if speedup is not None and debugName in speedup.internal_result:
            # this value is the result of other nodes, such as aten::size
            values.append(speedup.internal_result[debugName].item())
        else:
            # the corresponding value is a constant
            values.append(_i.toIValue())
    return values


def parse_constant(cvalue, speedup):
    """
    Parse the constant values from this node.

    Parameters
    ----------
    cvalue: Torch.C.Value
        The cpp node of the target constant value.
    speedup: ModelSpeedup
        The model speedup object.

    Returns
    -------
    value: int/float/tensor
        The constant value parsed from the node.
    """
    logger.debug('Try to parse the constant value: %s', cvalue.debugName())
    if cvalue.toIValue() is not None:
        return cvalue.toIValue()
    if cvalue.debugName() in speedup.internal_result:
        return speedup.internal_result[cvalue.debugName()]
    # get the operator node of this value
    op_node = cvalue.node()
    inputs = op_node.inputs()
    input_values = [parse_constant(_i, speedup) for _i in inputs]
    func = trans_from_jit_to_python[op_node.kind()](op_node, speedup)
    return func(*input_values)


def dropout_python(node, speedup):
    return torch.dropout


def flatten_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    start_dim = inputs[1].toIValue()
    end_dim = inputs[2].toIValue()
    new_flatten = partial(torch.flatten, start_dim=start_dim, end_dim=end_dim)
    return new_flatten


def relu_inplace_python(node, speedup):
    return torch.relu_


def relu_python(node, speedup):
    return torch.relu


def sigmoid_python(node, speedup):
    return torch.sigmoid


def mean_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim_list = translate_list(inputs[1], speedup)
    keep_dim = inputs[2].toIValue()
    new_mean = partial(torch.mean, dim=tuple(dim_list), keepdim=keep_dim)
    return new_mean


def add_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    constant = None
    for i in range(2):
        input_i = inputs[i]
        debug_name = input_i.debugName()
        if debug_name not in speedup.internal_result:
            # this input is a constant value
            # TODO: what if this input is a constant tensor
            if input_i.toIValue() is not None:
                constant = parse_constant(input_i, speedup)
                break
    if constant is None:
        return torch.add
    else:
        new_add = partial(torch.add, constant)
        return new_add


def floor_div_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    divisor = inputs[1]
    constant = None
    if divisor.debugName() not in speedup.internal_result:
        # divisor is a constant value/tensor
        constant = parse_constant(divisor, speedup)
    if constant is None:
        return torch.floor_divide
    else:
        new_op = partial(torch.floor_divide, other=constant)
        return new_op


def mul_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    constant = None
    for i in range(2):
        input_i = inputs[i]
        debug_name = input_i.debugName()
        if debug_name not in speedup.internal_result:
            constant = parse_constant(input_i, speedup)
            # the two inputs cannot both be constants at the same time
            break
    if constant is None:
        return torch.mul
    else:
        new_mul = partial(torch.mul, constant)
        return new_mul


def transpose_python(node, speedup):
    return torch.t


def transpose2_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim_1 = inputs[1].toIValue()
    dim_2 = inputs[2].toIValue()
    new_transpose = partial(torch.transpose, dim0=dim_1, dim1=dim_2)
    return new_transpose


def matmul_python(node, speedup):
    return torch.matmul


def div_python(node, speedup):
    # The second input parameter of torch.div can be a tensor or a
    # constant; if it is a constant, we need to bind it into the
    # returned callable.
    c_node = node.key_node
    inputs = list(c_node.inputs())
    if inputs[1].debugName() in speedup.internal_result:
        # the second input parameter is the output of other nodes
        return torch.div
    else:
        other = inputs[1].toIValue()
        new_div = partial(torch.div, other=other)
        return new_div


def softmax_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim = inputs[1].toIValue()
    new_softmax = partial(torch.softmax, dim=dim)
    return new_softmax


def contiguous_python(node, speedup):
    class contiguousModule(torch.nn.Module):
        def forward(self, x):
            return x.contiguous()
    return contiguousModule()


def gelu_python(node, speedup):
    return torch.nn.GELU()


def avgpool2d_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    kernel_size = translate_list(inputs[1], speedup)
    stride = translate_list(inputs[2], speedup)
    padding = translate_list(inputs[3], speedup)
    new_avgpool = partial(torch.nn.functional.avg_pool2d, kernel_size=kernel_size,
                          stride=stride, padding=padding)
    return new_avgpool


def adaptive_avgpool_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    output_size = translate_list(inputs[1], speedup)
    new_avgpool = torch.nn.AdaptiveAvgPool2d(output_size)
    return new_avgpool


def tupleunpack_python(node, speedup):
    # Note: tuple unpack should only exist at the end of the model,
    # and there is no need to replace/propagate the mask
    return None


def num2tensor_python(node, speedup):
    return torch.nn.Identity()


def exp_python(node, speedup):
    return torch.exp


def squeeze_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim = None
    if len(inputs) > 1:
        dim = parse_constant(inputs[1], speedup)
    new_squeeze = partial(torch.squeeze, dim=dim)
    return new_squeeze


##########################################################
# Split Line
# The following modules/functions cannot be translated into a
# single function, so we use torch.nn.Module to wrap the
# core function, and return the torch.nn.Module instead
##########################################################


def slice_python(node, speedup):
    class SliceMoudle(torch.nn.Module):
        def __init__(self, sliceobj):
            super(SliceMoudle, self).__init__()
            self.sliceobj = sliceobj

        def forward(self, x, *args):
            # args is for the slice dimension and indexes, but we already
            # got them from the cpp nodes. Note that, although we don't
            # need the slice indexes any more, we cannot remove this
            # parameter here, because there may be multiple inputs passed
            # from previous nodes such as aten::size
            logger.info('Model has Slice operation, and the operand size=%s, Slice object:%s',
                        str(x.size()), str(self.sliceobj))
            return x[self.sliceobj]

    c_node = node.key_node
    inputs = list(c_node.inputs())
    slice_dim = parse_constant(inputs[1], speedup)
    slice_start = parse_constant(inputs[2], speedup)
    slice_end = parse_constant(inputs[3], speedup)
    slice_step = parse_constant(inputs[4], speedup)
    slice_obj = slice(slice_start, slice_end, slice_step)
    slice_list = []
    for _ in range(slice_dim):
        slice_list.append(slice(None, None))
    logger.info('Slice dim:%s, Slice obj:%s', str(slice_dim), str(slice_obj))
    slice_list.append(slice_obj)
    return SliceMoudle(tuple(slice_list))


def select_python(node, speedup):
    class SelectModule(torch.nn.Module):
        def __init__(self, dim, index):
            super(SelectModule, self).__init__()
            self.dim = dim
            self.index = index

        def forward(self, x):
            return x.select(self.dim, self.index)

    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim = inputs[1].toIValue()
    index = inputs[2].toIValue()
    return SelectModule(dim, index)


def size_python(node, speedup):
    # return None
    class SizeMoudle(torch.nn.Module):
        def __init__(self, sizedim):
            super(SizeMoudle, self).__init__()
            self.sizedim = sizedim

        def forward(self, x):
            return torch.as_tensor([x.size(self.sizedim)], dtype=torch.long)
            # return torch.tensor(x.size(self.sizedim))

    c_node = node.key_node
    inputs = list(c_node.inputs())
    size_dim = inputs[1].toIValue()
    return SizeMoudle(size_dim)


def toint_python(node, speedup):
    class ToIntModule(torch.nn.Module):
        def forward(self, x):
            return x.to(torch.int)
    return ToIntModule()


def view_python(node, speedup):
    class ViewModule(torch.nn.Module):
        def __init__(self, shape):
            super(ViewModule, self).__init__()
            self.shape = shape
            logger.info('View Module output size: %s', str(self.shape))

        def forward(self, *args):
            return args[0].view(self.shape)

    c_node = node.key_node
    inputs = list(c_node.inputs())
    shape = translate_list(inputs[1], speedup)
    return ViewModule(shape)


def reshape_python(node, speedup):
    class ReshapeModule(torch.nn.Module):
        def __init__(self, shape):
            super(ReshapeModule, self).__init__()
            self.shape = shape
            logger.info('Reshape Module output size: %s', str(self.shape))

        def forward(self, *args):
            return args[0].view(self.shape)

    c_node = node.key_node
    inputs = list(c_node.inputs())
    shape = translate_list(inputs[1], speedup)
    return ReshapeModule(shape)


def permute_python(node, speedup):
    class PermuteModule(torch.nn.Module):
        def __init__(self, dimlist):
            super(PermuteModule, self).__init__()
            self.dimlist = dimlist

        def forward(self, x):
            return x.permute(self.dimlist)

    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim_list = translate_list(inputs[1], speedup)
    return PermuteModule(dim_list)


def getattr_python(node, speedup):
    """
    Note: ops starting with prim:: are not taken as key nodes,
    so we directly pass the cpp node into this function.

    Parameters
    ----------
    node: torch._C.Node
        The cpp node of prim::GetAttr.
    speedup: ModelSpeedup
        The corresponding speedup object.
    """
    class GetModule(torch.nn.Module):
        def __init__(self, key):
            super(GetModule, self).__init__()
            self.key = key

        def forward(self, obj):
            logger.info('Get attribute: %s', self.key)
            return getattr(obj, self.key)

    # get the name of the attribute, for example
    # prim::GetAttr[name="module_list"](%self.1)
    assert node.kind() == 'prim::GetAttr'
    pattern = '\[name=\"(.*?)\"\]'
    key_words = re.findall(pattern, str(node))
    assert len(key_words) == 1
    return GetModule(key_words[0])


def upsample_bilinear2d_python(node, speedup):
    class UpsampleModule(torch.nn.Module):
        def __init__(self, size_list, scale_list):
            super(UpsampleModule, self).__init__()
            self.size_list = size_list
            self.scale_list = scale_list

        def forward(self, *args):
            """
            The first element of args is the target tensor to upsample;
            the following parameters are useless, because we already
            got the size_list and the scale_list by parsing the cpp_nodes.
            """
            return torch.nn.functional.upsample_bilinear(args[0],
                                                         size=self.size_list,
                                                         scale_factor=self.scale_list)

    c_node = node.key_node
    inputs = list(c_node.inputs())
    size_list_node = inputs[1].node()
    scale_list_node = inputs[3].node()
    size_list = None
    scale_list = None
    if size_list_node.kind() == 'prim::ListConstruct':
        size_list = translate_list(inputs[1], speedup)
    if scale_list_node.kind() == 'prim::ListConstruct':
        scale_list = translate_list(inputs[3], speedup)
    return UpsampleModule(size_list, scale_list)


def typeas_python(node, speedup):
    """
    Currently only supports type_as float.
    TODO: support more types in type_as; need to figure out
    how to get the scalar type from torch._C.TensorType.
    """
    class TypeasModule(torch.nn.Module):
        def __init__(self, dtype=torch.float):
            self.example = torch.zeros(1, dtype=dtype)

        def forward(self, x):
            return x.type_as(self.example)
    return TypeasModule()


def to_python(node, speedup):
    # for the time being, only device parameters are supported
    class ToModule(torch.nn.Module):
        def __init__(self, device):
            super(ToModule, self).__init__()

        def forward(self, x):
            return x.to(device)

    c_node = node.key_node
    inputs = list(c_node.inputs())
    device = inputs[3].toIValue()
    return ToModule(device)


def cat_python(node, speedup):
    class CatModule(torch.nn.Module):
        def __init__(self, cat_dim):
            super(CatModule, self).__init__()
            self.cat_dim = cat_dim

        def forward(self, *args):
            return torch.cat(args, dim=self.cat_dim)

    c_node = node.key_node
    inputs = list(c_node.inputs())
    dim = inputs[1].toIValue()
    return CatModule(dim)


trans_from_jit_to_python = {
    'aten::add': add_python,
    'aten::add_': add_python,
    'aten::mul': mul_python,
    'aten::mul_': mul_python,
    'aten::relu': relu_python,
    'aten::relu_': relu_inplace_python,
    'aten::sigmoid': sigmoid_python,
    'aten::sigmoid_': sigmoid_python,
    # tanh behaves like relu
    'aten::tanh': relu_python,
    'aten::tanh_': relu_python,
    'aten::flatten': flatten_python,
    'aten::mean': mean_python,
    'aten::dropout': dropout_python,
    'aten::slice': slice_python,
    'aten::select': select_python,
    'aten::size': size_python,
    'aten::t': transpose_python,
    'aten::transpose': transpose2_python,
    'aten::Int': toint_python,
    'aten::view': view_python,
    'aten::reshape': reshape_python,
    'aten::permute': permute_python,
    'aten::matmul': matmul_python,
    'aten::div': div_python,
    'aten::floor_divide': floor_div_python,
    'aten::softmax': softmax_python,
    'aten::contiguous': contiguous_python,
    'aten::gelu': gelu_python,
    'aten::cat': cat_python,
    'aten::avg_pool2d': avgpool2d_python,
    'aten::max_pool2d': avgpool2d_python,
    'aten::adaptive_avg_pool2d': adaptive_avgpool_python,
    'aten::to': to_python,
    'aten::type_as': typeas_python,
    'aten::upsample_bilinear2d': upsample_bilinear2d_python,
    'aten::exp': exp_python,
    'aten::squeeze': squeeze_python,
    'prim::TupleUnpack': tupleunpack_python,
    'prim::ListUnpack': tupleunpack_python,
    'prim::NumToTensor': num2tensor_python,
    'prim::GetAttr': getattr_python
}


def jit_to_python_function(node, speedup):
    """
    Return a callable object used to infer the mask according to the
    node.op_type.

    Parameters
    ----------
    node: NodeGroup
        The target node whose mask is inferred.
    speedup: ModelSpeedup
        The speedup object of the target model.

    Returns
    -------
    func: callable object (nn.Module/function)
        The translated function used to infer the mask;
        if the current op_type is not supported, return None.
    """
    logger.debug('Translate C function %s into its python version', node.op_type)
    if node.op_type not in trans_from_jit_to_python:
        logger.error('%s is not Supported! Please report an issue at https://github.com/microsoft/nni. Thanks~',
                     node.op_type)
        # return None to skip the mask inference for this node
        return None
    return trans_from_jit_to_python[node.op_type](node, speedup)
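The pattern shared by most translators above, in isolation: constant arguments recovered from the jit node (via toIValue or translate_list) are frozen into a plain python callable with functools.partial, so mask inference can later call it with only the tensor inputs. A minimal sketch (the hard-coded dims stand in for values a real pass would parse from the node):

import torch
from functools import partial

start_dim, end_dim = 1, -1   # a real pass reads these with inputs[1].toIValue() etc.
new_flatten = partial(torch.flatten, start_dim=start_dim, end_dim=end_dim)

x = torch.randn(2, 3, 4)
# the bound callable behaves exactly like the original op
assert torch.equal(new_flatten(x), torch.flatten(x, 1, -1))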
nni/compression/pytorch/utils/__init__.py

+from .utils import *
\ No newline at end of file
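This re-export makes the helpers added to utils.py in this commit (randomize_tensor, torch_float_dtype, torch_integer_dtype, ...) importable from the package, which is how infer_mask.py above reaches them. A usage sketch (the in-place uniform-fill behavior is assumed from how infer_mask.py calls it):

import torch
from nni.compression.pytorch.utils import randomize_tensor

t = torch.empty(4, 4)
randomize_tensor(t, 0.1, 8.0)   # fill t in place with random values in [0.1, 8.0)
print(t.min().item() >= 0.1)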
nni/compression/pytorch/utils/mask_conflict.py

@@ -4,10 +4,10 @@ import os
 import logging
 import torch
 import numpy as np
-from .shape_dependency import ChannelDependency, GroupDependency, CatPaddingDependency, InputChannelDependency
+from .shape_dependency import ChannelDependency, GroupDependency, InputChannelDependency
 from .utils import get_module_by_name
 # logging.basicConfig(level = logging.DEBUG)
-_logger = logging.getLogger(__name__)
+_logger = logging.getLogger('FixMaskConflict')

 def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):

@@ -21,7 +21,7 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
         A dict object that stores the masks or the path of the mask file
     model : torch.nn.Module
         model to fix the mask conflict
-    dummy_input : torch.Tensor
+    dummy_input : torch.Tensor/list of tensors/dict of tensors
         input example to trace the model
     traced : torch._C.torch.jit.TopLevelTracedModule
         the traced model of the target model; if this parameter is not None,

@@ -48,9 +48,7 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
     masks = fix_group_mask.fix_mask()
     fix_channel_mask = ChannelMaskConflict(masks, model, dummy_input, traced)
     masks = fix_channel_mask.fix_mask()
-    padding_cat_mask = CatMaskPadding(masks, model, dummy_input, traced)
-    masks = padding_cat_mask.fix_mask()
-    return masks
+    return masks, fix_channel_mask.conv_prune_dim

 class MaskFix:

@@ -78,70 +76,6 @@ class MaskFix:
         torch.save(self.masks, path)

-class CatMaskPadding(MaskFix):
-    def __init__(self, masks, model, dummy_input=None, traced=None):
-        """
-        CatMaskPadding finds the layers whose output tensors are passed
-        to the same cat operation. The cat operation concatenates the
-        masks of the input tensors as the output mask, so when some of
-        the input layers of the cat operation are not pruned, we still
-        need to pass the masks of these non-pruned layers (masks of all
-        ones) to the cat operation to ensure that the shape of the
-        output mask is right.
-
-        Parameters
-        ----------
-        masks : dict
-            a dict object that stores the masks
-        model : torch.nn.Module
-            model to fix the mask conflict
-        dummy_input : torch.Tensor
-            input example to trace the model
-        traced : torch._C.torch.jit.TopLevelTracedModule
-            the traced model of the target model; if this parameter is not None,
-            we do not use the model and dummy_input to get the trace graph.
-        """
-        super(CatMaskPadding, self).__init__(masks, model, dummy_input, traced)
-
-    def fix_mask(self):
-        cat_padding_depen = CatPaddingDependency(
-            self.model, self.dummy_input, self.traced)
-        name_to_module = {}
-        for name, module in self.model.named_modules():
-            name_to_module[name] = module
-        depen = cat_padding_depen.dependency_sets
-        for layers in depen:
-            device = None
-            count = 0
-            for layer in layers:
-                if layer in self.masks:
-                    count += 1
-                    if device is None:
-                        device = self.masks[layer]['weight'].device
-            if count == 0:
-                # no layer is pruned
-                continue
-            elif count == len(layers):
-                # all the layers have been pruned
-                continue
-            # pad the mask for the non-pruned layers
-            for layer in layers:
-                if layer in self.masks:
-                    continue
-                module = name_to_module[layer]
-                w_shape = module.weight.data.size()
-                w_mask = torch.ones(w_shape).to(device)
-                b_mask = None
-                if hasattr(module, 'bias') and module.bias is not None:
-                    # module.bias may be None
-                    b_shape = module.bias.data.size()
-                    b_mask = torch.ones(b_shape).to(device)
-                self.masks[layer] = {'weight': w_mask, 'bias': b_mask}
-        return self.masks

 class GroupMaskConflict(MaskFix):
     def __init__(self, masks, model=None, dummy_input=None, traced=None):
         """

@@ -172,9 +106,11 @@ class GroupMaskConflict(MaskFix):
         group_depen = GroupDependency(
             self.model, self.dummy_input, self.traced)
         depens = group_depen.dependency
+        min_groups = group_depen.min_groups
         _logger.info(depens)
         for layername in depens:
-            group = depens[layername]
+            group_max = depens[layername]
+            group_min = min_groups[layername]
             if layername not in self.masks:
                 # this layer not pruned
                 continue

@@ -187,28 +123,42 @@ class GroupMaskConflict(MaskFix):
                 # In fine-grained pruning, skip this layer
                 _logger.info('Layers %s using fine-grained pruning', layername)
                 continue
-            assert shape[0] % group == 0
+            assert shape[0] % group_max == 0
             # Find the number of masked filters for each group (mini_masked).
             # Because the pruned filters must still be divisible into the
             # same number of groups, we can only prune mini_masked filters
             # for each group.
-            step = shape[0] / group
+            step = shape[0] / group_max
             group_masked = []
-            for i in range(group):
+            for i in range(group_max):
                 _start = step * i
                 _end = step * (i + 1)
                 _tmp_list = list(
                     filter(lambda x: _start <= x and x < _end, all_zeros))
                 group_masked.append(_tmp_list)
             mini_masked = min([len(x) for x in group_masked])
+            need_unmask = set()
             for gm in group_masked:
                 for i in range(mini_masked, len(gm)):
                     # To keep the output channel number divisible by the
                     # groups, we set the masks of the following filters to zero.
                     pos = gm[i]
-                    self.masks[layername]['weight'][pos] = torch.ones(
-                        shape[1:])
-                    if 'bias' in self.masks[layername] and self.masks[layername]['bias'] is not None:
-                        self.masks[layername]['bias'][pos] = 1
+                    need_unmask.add(pos)
+            step = shape[0] / group_min
+            for i in range(group_min):
+                _start = step * i
+                _end = step * (i + 1)
+                _tmp_list = list(
+                    filter(lambda x: _start <= x and x < _end, all_zeros))
+                if len(_tmp_list) == step:
+                    # if the whole group is removed, then we don't need to
+                    # unmask the filters in this group
+                    for pos in _tmp_list:
+                        if pos in need_unmask:
+                            need_unmask.remove(pos)
+            for pos in need_unmask:
+                self.masks[layername]['weight'][pos] = torch.ones(shape[1:])
+                if hasattr(self.masks[layername], 'bias'):
+                    self.masks[layername]['bias'][pos] = 1
         return self.masks

@@ -234,9 +184,14 @@ class ChannelMaskConflict(MaskFix):
         super(ChannelMaskConflict, self).__init__(
             masks, model, dummy_input, traced)
         self.conv_prune_dim = detect_mask_prune_dim(masks, model)
-        _logger.info('detected conv prune dim: %s', self.conv_prune_dim)
+        _logger.info('Dectected conv prune dim" %d', self.conv_prune_dim)

     def fix_mask(self):
         """
         Fix the mask conflict before the mask inference for the layers that
         has shape dependencies. This function should be called before the

@@ -274,7 +229,12 @@ class ChannelMaskConflict(MaskFix):
                 if (channel_mask.sum() * (mask.numel() / mask.shape[self.conv_prune_dim])).item() != (mask > 0).sum().item():
                     fine_grained = True
             elif type(m).__name__ == 'Linear':
-                channel_masks.append((mask.abs().sum(0) != 0).int())
+                if self.conv_prune_dim == 1:
+                    channel_masks.append(
+                        (mask.abs().sum(0) != 0).int())
+                else:
+                    channel_masks.append(
+                        (mask.abs().sum(1) != 0).int())
             elif type(m).__name__ == 'BatchNorm2d':
                 channel_masks.append(mask.int())
             elif type(m).__name__ == 'ConvTranspose2d':

@@ -293,9 +253,7 @@ class ChannelMaskConflict(MaskFix):
                 # no mask means not pruned, equivalent to full masks
                 channel_masks.append(None)
         if fine_grained:
-            _logger.info("Fine-grianed mask detected")
+            _logger.info('fine-grained mask detected, skip solving conflict for this set: %s', dset)
+            continue
         if all(x is None for x in channel_masks):
             continue
         num_channels_list = [len(x)

@@ -306,7 +264,8 @@ class ChannelMaskConflict(MaskFix):
         for i, dim_mask in enumerate(channel_masks):
             if dim_mask is None:
                 channel_masks[i] = torch.ones(num_channels).int().to(device)
         # merge masks with 'or'
         merged_channel_mask = channel_masks[0].clone()

@@ -329,17 +288,20 @@ class ChannelMaskConflict(MaskFix):
             else:
                 new_mask[:, merged_index, :, :] = 1.
         elif type(m).__name__ == 'Linear':
-            new_mask[:, merged_index] = 1.
+            if self.conv_prune_dim == 0:
+                new_mask[merged_index, :] = 1
+            elif self.conv_prune_dim == 1:
+                new_mask[:, merged_index] = 1.
         elif type(m).__name__ == 'BatchNorm2d':
             new_mask = merged_channel_mask.type_as(orig_mask)
         else:
             raise RuntimeError(
                 f'unsupported module type: {type(m).__name__}')
         self.masks[name]['weight'] = new_mask
         if 'bias' in self.masks[name] and self.masks[name]['bias'] is not None:
             if type(m).__name__ == 'Conv2d':
                 assert self.conv_prune_dim == 0
-            self.masks[name]['bias'] = merged_channel_mask.type_as(
-                self.masks[name]['bias'])
+            if self.conv_prune_dim == 0:
+                self.masks[name]['bias'] = merged_channel_mask.type_as(
+                    self.masks[name]['bias'])

@@ -349,14 +311,12 @@ class ChannelMaskConflict(MaskFix):
 def detect_mask_prune_dim(masks, model):
     """
     Detect how the masks of convolutional layers are pruned.

     Parameters
     ----------
     masks: dict
         A dict object that stores the masks.
     model: nn.Module
         Model object which the mask can be applied on.

     Returns
     -------
     How the masks of convolutional layers are pruned; this depends on the pruning algorithm; it should
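With CatMaskPadding gone, fix_mask_conflict now runs only the group and channel passes and returns the detected conv pruning dimension alongside the fixed masks. A sketch of the new calling convention (the two-branch add network and the hand-built mask dict are contrived for illustration):

import torch
import torch.nn as nn
from nni.compression.pytorch.utils.mask_conflict import fix_mask_conflict

class TwoBranch(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3, padding=1)
        self.conv2 = nn.Conv2d(3, 8, 3, padding=1)

    def forward(self, x):
        # the element-wise add makes conv1 and conv2 channel-dependent
        return self.conv1(x) + self.conv2(x)

net = TwoBranch()
masks = {
    'conv1': {'weight': torch.ones(8, 3, 3, 3)},
    'conv2': {'weight': torch.ones(8, 3, 3, 3)},
}
masks['conv1']['weight'][0] = 0   # conv1 prunes filter 0, conv2 prunes nothing
dummy_input = torch.randn(1, 3, 16, 16)
# channel masks in one dependency set are merged with 'or', so the
# conflicting filter is restored; the conv prune dim (0 here) is returned too
fixed_masks, conv_prune_dim = fix_mask_conflict(masks, net, dummy_input)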
nni/compression/pytorch/utils/shape_dependency.py

@@ -3,18 +3,34 @@
 import csv
 import logging
+import numpy as np

-__all__ = ['ChannelDependency', 'GroupDependency', 'CatPaddingDependency', 'InputChannelDependency']
+__all__ = ['ChannelDependency', 'GroupDependency', 'InputChannelDependency']

 CONV_TYPE = 'aten::_convolution'
 ADD_TYPES = ['aten::add', 'aten::add_']
+MUL_TYPES = ['aten::mul', 'atem::mul_']
 CAT_TYPE = 'aten::cat'
 logger = logging.getLogger('Shape_Dependency')
 RESHAPE_OPS = [CAT_TYPE, 'aten::view',
                'aten::reshape', 'aten::flatten', 'aten::mean']

+def lcm_list(L):
+    lcm = 1
+    for i in L:
+        lcm = np.lcm(lcm, i)
+    return lcm
+
+def gcd_list(L):
+    gcd = L[0]
+    for i in L:
+        gcd = np.gcd(gcd, i)
+    return gcd

 class Dependency:
     def __init__(self, model=None, dummy_input=None, traced_model=None):

@@ -38,6 +54,35 @@ class Dependency:
         raise NotImplementedError

+def reshape_break_channel_dependency(op_node):
+    """
+    Reshape operations such as reshape, view, and flatten may break
+    the channel dependency. We need to check the input parameters of
+    these reshape operations to determine whether a given reshape node
+    breaks the channel dependency. However, it is complicated to analyze
+    the input parameters of each reshape function and infer whether it
+    breaks the channel dependency. So currently we just check whether the
+    input channel count and the output channel count are the same: if so,
+    the original reshape function does not intend to change the number of
+    channels, which means the channel dependency is not broken. Otherwise,
+    the original reshape operation intends to change the number of
+    channels, so it breaks the channel dependency.
+
+    Parameters
+    ----------
+    op_node: NodePyOP
+        An op node of the graph.
+
+    Returns
+    -------
+    bool
+        Whether this operation breaks the channel dependency.
+    """
+    in_shape = op_node.auxiliary['in_shape']
+    out_shape = op_node.auxiliary['out_shape']
+    in_channel = in_shape[1]
+    out_channel = out_shape[1]
+    return in_channel != out_channel

 class ChannelDependency(Dependency):
     def __init__(self, model=None, dummy_input=None, traced_model=None):
         """
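A quick illustration (not from the commit) of the heuristic: a reshape preserves the channel dependency only when dimension 1 is unchanged.

import torch

x = torch.randn(2, 8, 4, 4)
y = x.view(2, 8, 16)      # in_channel == out_channel == 8 -> dependency kept
z = x.view(2, 32, 2, 2)   # 8 != 32 -> dependency broken
print(y.shape[1], z.shape[1])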
@@ -80,6 +125,9 @@ class ChannelDependency(Dependency):
...
@@ -80,6 +125,9 @@ class ChannelDependency(Dependency):
# find the first met conv
# find the first met conv
parent_layers
.
append
(
curnode
.
name
)
parent_layers
.
append
(
curnode
.
name
)
continue
continue
elif
curnode
.
op_type
in
RESHAPE_OPS
:
if
reshape_break_channel_dependency
(
curnode
):
continue
parents
=
self
.
graph
.
find_predecessors
(
curnode
.
unique_name
)
parents
=
self
.
graph
.
find_predecessors
(
curnode
.
unique_name
)
parents
=
[
self
.
graph
.
name_to_node
[
name
]
for
name
in
parents
]
parents
=
[
self
.
graph
.
name_to_node
[
name
]
for
name
in
parents
]
for
parent
in
parents
:
for
parent
in
parents
:
...
@@ -176,7 +224,7 @@ class ChannelDependency(Dependency):
...
@@ -176,7 +224,7 @@ class ChannelDependency(Dependency):
d_sets
=
[]
d_sets
=
[]
visited
=
set
()
visited
=
set
()
for
node
in
self
.
graph
.
nodes_py
.
nodes_op
:
for
node
in
self
.
graph
.
nodes_py
.
nodes_op
:
if
node
.
op_type
!=
'Conv2d'
or
node
in
visited
:
if
(
node
.
op_type
!=
'Conv2d'
and
node
.
op_type
!=
'Linear'
)
or
node
in
visited
:
continue
continue
tmp_set
=
set
()
tmp_set
=
set
()
if
node
.
name
not
in
self
.
dependency
:
if
node
.
name
not
in
self
.
dependency
:
...
@@ -190,35 +238,6 @@ class ChannelDependency(Dependency):
...
@@ -190,35 +238,6 @@ class ChannelDependency(Dependency):
return
d_sets
return
d_sets
def
reshape_break_channel_dependency
(
op_node
):
"""
The reshape operations such as (reshape, view, flatten) may break
the channel dependency. We need to check the input parameters of
these reshape operations to check if this reshape node will break
the channel dependency. However, it's complicated to analyze the the input
parameters for each reshape function and infer if it will break the channel
dependency. So currently, we just check if the input channel and the output
channel is the same, if so, then we can say the original reshape function
doesn't want to change the number of the channels, which means the channel
dependency is not broken. In contrast, the original reshap operation wants
to change the number of channels, so it breaks the channel dependency.
Parameters
----------
opnode: NodePyOP
A Op node of the graph.
Returns
-------
bool
If this operation will break the channel dependency.
"""
in_shape
=
op_node
.
auxiliary
[
'in_shape'
]
out_shape
=
op_node
.
auxiliary
[
'out_shape'
]
in_channel
=
in_shape
[
1
]
out_channel
=
out_shape
[
1
]
return
in_channel
!=
out_channel
class
InputChannelDependency
(
ChannelDependency
):
class
InputChannelDependency
(
ChannelDependency
):
"""
"""
Some pruners may prune the input channel of the convolutional
Some pruners may prune the input channel of the convolutional
...
@@ -295,67 +314,6 @@ class InputChannelDependency(ChannelDependency):
...
@@ -295,67 +314,6 @@ class InputChannelDependency(ChannelDependency):
self
.
dependency
[
layer
]
=
dependency_set
self
.
dependency
[
layer
]
=
dependency_set
-class CatPaddingDependency(ChannelDependency):
-    def __init__(self, model=None, dummy_input=None, traced_model=None):
-        super(CatPaddingDependency, self).__init__(model, dummy_input, traced_model)
-
-    def build_dependency(self):
-        """
-        Build the cat padding dependencies.
-        If the output features of several layers are stitched together
-        by a cat operation, then these layers have cat padding dependencies.
-        This is because when inferring the cat mask, we need all the input
-        masks for the cat operation. At this time we need to know the source
-        of all input vectors of a cat operation.
-        """
-        for node in self.graph.nodes_py.nodes_op:
-            parent_layers = []
-            if node.op_type == CAT_TYPE:
-                parent_layers = self._get_parent_layers(node)
-                dependency_set = set(parent_layers)
-                # merge the dependencies
-                for parent in parent_layers:
-                    if parent in self.dependency:
-                        dependency_set.update(self.dependency[parent])
-                # save the dependencies
-                for _node in dependency_set:
-                    self.dependency[_node] = dependency_set
-
-    @property
-    def dependency_sets(self):
-        d_sets = []
-        visited = set()
-        for nodename in self.dependency:
-            if nodename in visited:
-                continue
-            d_sets.append(self.dependency[nodename])
-        return d_sets
-
-    def export(self, filepath):
-        """
-        Export the dependencies into a file.
-        In the output file, each line contains a set of layers
-        whose output features are stitched together by the cat
-        operation.
-        output example:
-        Dependency Set, Layers
-        set1, Conv1, Conv2
-        set2, Conv3, Conv4
-        """
-        header = ['Dependency Set', 'Layers']
-        setid = 0
-        with open(filepath, 'w') as csvf:
-            csv_w = csv.writer(csvf, delimiter=',')
-            csv_w.writerow(header)
-            for layers in self.dependency_sets:
-                setid += 1
-                row = ['Set %d' % setid]
-                row.extend(list(layers))
-                csv_w.writerow(row)
class GroupDependency(Dependency):
    def __init__(self, model=None, dummy_input=None, traced_model=None):
        """
...
@@ -372,6 +330,7 @@ class GroupDependency(Dependency):
        if we already have the traced graph of the target model, we do not
        need to trace the model again.
        """
+       self.min_groups = {}
        super(GroupDependency, self).__init__(model, dummy_input, traced_model)
    def _get_parent_convs(self, node):
...
@@ -451,27 +410,33 @@ class GroupDependency(Dependency):
        key: the name of conv layers, value: the minimum value that the number of
        filters should be divisible by.
        """
+       self.groups = {}
        for node in self.graph.nodes_py.nodes_op:
            if node.op_type == 'Conv2d' or node.op_type == 'ConvTranspose2d':
                group = self._get_conv_groups(node)
-               if node.name in self.dependency:
+               if node.name in self.groups:
                    # the conv layer whose group is larger than 1 will require that
                    # its number of output channels be divisible by the number of groups.
-                   self.dependency[node.name] = max(
-                       self.dependency[node.name], group)
+                   self.groups[node.name].append(group)
                else:
-                   self.dependency[node.name] = group
+                   self.groups[node.name] = [group]
                if group > 1:
                    # for the conv layer whose group is larger than 1, it will require the number
                    # of output channels of their parent conv layer to be divisible by group.
                    parent_convs = self._get_parent_convs(node)
                    for parent in parent_convs:
-                       if parent in self.dependency:
-                           self.dependency[parent] = max(
-                               self.dependency[parent], group)
+                       if parent in self.groups:
+                           self.groups[parent].append(group)
                        else:
-                           self.dependency[parent] = group
+                           self.groups[parent] = [group]
+       for name in self.groups:
+           self.dependency[name] = lcm_list(self.groups[name])
+           if min(self.groups[name]) == gcd_list(self.groups[name]):
+               self.min_groups[name] = min(self.groups[name])
+           else:
+               self.min_groups[name] = 1
        return self.dependency
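For intuition about the new reduction at the end of build_dependency, here is a small sketch. It assumes lcm_list/gcd_list compute the least common multiple and greatest common divisor of a list, as their names suggest; treat that as an assumption:

from functools import reduce
from math import gcd

def lcm_list(values):   # assumed semantics: least common multiple of all entries
    return reduce(lambda a, b: a * b // gcd(a, b), values)

def gcd_list(values):   # assumed semantics: greatest common divisor of all entries
    return reduce(gcd, values)

groups = [2, 4]  # e.g. one conv traced with groups=2 and again with groups=4
print(lcm_list(groups))                 # 4 -> its filter count must stay divisible by 4
print(min(groups) == gcd_list(groups))  # True -> the min_groups entry would be min(groups) == 2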
    def export(self, filepath):
...
@@ -501,3 +466,110 @@ class GroupDependency(Dependency):
    @property
    def dependency_sets(self):
        return self.dependency
+class ReshapeDependency(Dependency):
+    def __init__(self, model=None, dummy_input=None, traced_model=None):
+        """
+        Some models may have view/reshape functions; such functions may have fixed parameters
+        and cannot be replaced at all. Therefore, these functions may put constraints on their
+        input shapes. In this class, we find the direct input conv/linear layers of these
+        reshape functions. If you get a shape conflict when running forward inference on the
+        speeduped model, please try removing these layers from the pruner config list and try again.
+
+        Parameters
+        ----------
+        model : torch.nn.Module
+            The model to be analyzed.
+        data : torch.Tensor
+            The example input data to trace the network architecture.
+        traced_model : torch._C.Graph
+            if we already have the traced graph of the target model, we do not
+            need to trace the model again.
+        """
+        super(ReshapeDependency, self).__init__(model, dummy_input, traced_model)
+
+    def _get_parent_layers(self, node):
+        """
+        Find the nearest parent conv layers for the target node.
+
+        Parameters
+        ----------
+        node : torch._C.Node
+            target node.
+        Returns
+        -------
+        parent_layers : list
+            nearest parent conv/linear layers for the target node.
+        """
+        parent_layers = []
+        queue = []
+        queue.append(node)
+        while queue:
+            curnode = queue.pop(0)
+            if curnode.op_type == 'Conv2d' or curnode.op_type == 'Linear' or curnode.op_type == 'ConvTranspose2d':
+                # find the first met conv
+                parent_layers.append(curnode.name)
+                continue
+            parents = self.graph.find_predecessors(curnode.unique_name)
+            parents = [self.graph.name_to_node[name] for name in parents]
+            for parent in parents:
+                queue.append(parent)
+        return parent_layers
+
+    def build_dependency(self):
+        """
+        Build the channel dependency for the conv layers
+        in the model.
+        """
+        # unpack the tuple/list manually before analyzing the
+        # channel dependency
+        self.graph.unpack_manually()
+        for node in self.graph.nodes_py.nodes_op:
+            parent_layers = []
+            # find the nodes that contain aten::view
+            # or aten::reshape operations
+            if node.op_type in ['aten::view', 'aten::reshape']:
+                logger.info('Detect reshape-like functions: %s', node.op_type)
+                parent_layers = self._get_parent_layers(node)
+                print('Parent layers', parent_layers)
+                self.dependency[node.unique_name] = parent_layers
+
+    def export(self, filepath):
+        """
+        export the reshape dependencies as a csv file.
+        Output example:
+        Reshape OP, Dependent Layers
+        model.view.1,layer1.1.conv2,layer1.0.conv2,conv1
+        model.mean.1,layer1.0.conv1
+        model.reshape.1,layer1.1.conv1
+        """
+        header = ['Reshape OP', 'Dependent Layers']
+        with open(filepath, 'w') as csvf:
+            csv_w = csv.writer(csvf, delimiter=',')
+            csv_w.writerow(header)
+            for reshape_op in self.dependency:
+                row = [reshape_op]
+                row.extend(self.dependency[reshape_op])
+                csv_w.writerow(row)
+
+    @property
+    def dependency_sets(self):
+        """
+        Get the list of the dependency set.
+
+        Returns
+        -------
+        dependency_sets : list
+            list of the dependency sets. For example,
+            [set(['conv1', 'conv2']), set(['conv3', 'conv4'])]
+        """
+        d_sets = []
+        for reshape_node in self.dependency:
+            d_sets.extend(self.dependency[reshape_node])
+        d_sets = list(set(d_sets))
+        return d_sets
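A usage sketch for the new class; the toy model and the printed result are illustrative, and the import path follows this commit:

import torch
import torch.nn as nn
from nni.compression.pytorch.utils.shape_dependency import ReshapeDependency

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.fc = nn.Linear(8 * 26 * 26, 10)

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)  # reshape op whose target shape depends on conv's channels
        return self.fc(x)

dep = ReshapeDependency(Net(), torch.randn(1, 3, 28, 28))
print(dep.dependency_sets)  # e.g. ['conv'], i.e. layers risky to prune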
nni/compression/pytorch/utils/utils.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
+import torch
+from .shape_dependency import ReshapeDependency
+
+torch_float_dtype = [torch.float, torch.float16,
+                     torch.float32, torch.float64, torch.half, torch.double]
+torch_integer_dtype = [torch.uint8, torch.int16,
+                       torch.short, torch.int32, torch.long, torch.bool]
def get_module_by_name(model, module_name):
    """
...
@@ -28,3 +33,50 @@ def get_module_by_name(model, module_name):
        return model, leaf_module
    else:
        return None, None
+def rand_like_with_shape(shape, ori_t):
+    """
+    Return a new random tensor like the original
+    tensor.
+    """
+    assert isinstance(ori_t, torch.Tensor)
+    device = ori_t.device
+    dtype = ori_t.dtype
+    require_grad = ori_t.requires_grad
+    lower_bound = torch.min(ori_t)
+    higher_bound = torch.max(ori_t)
+    if dtype in torch_integer_dtype:
+        return torch.randint(lower_bound, higher_bound + 1, shape, dtype=dtype, device=device)
+    else:
+        return torch.rand(shape, dtype=dtype, device=device, requires_grad=require_grad)
+
+def randomize_tensor(tensor, start=1, end=100):
+    """
+    Randomize the target tensor according to the given
+    range.
+    """
+    assert isinstance(tensor, torch.Tensor)
+    if tensor.dtype in torch_integer_dtype:
+        # integer tensors can only be randomized by torch.randint
+        # torch.randint(int(start), int(end), tensor.size(), out=tensor.data, dtype=tensor.dtype)
+        pass
+    else:
+        # we can use nn.init.uniform_ to randomize this tensor
+        # Note: a tensor with an integer type cannot be randomized
+        # with nn.init.uniform_
+        torch.nn.init.uniform_(tensor.data, start, end)
+
+def not_safe_to_prune(model, dummy_input):
+    """
+    Get the layers that are not safe to prune (may bring shape conflicts).
+
+    Parameters
+    ----------
+    model : torch.nn.Module
+        The target model to prune.
+    dummy_input : torch.Tensor/list of torch.Tensor/tuple of Tensor
+    """
+    reshape_dset = ReshapeDependency(model, dummy_input)
+    return reshape_dset.dependency_sets
\ No newline at end of file
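A short sketch of the new helpers; the printed values are illustrative:

import torch
from nni.compression.pytorch.utils.utils import rand_like_with_shape, randomize_tensor

t = torch.rand(4, 4)
r = rand_like_with_shape((2, 4, 4), t)   # same dtype/device as t, new shape
print(r.shape)                           # torch.Size([2, 4, 4])

w = torch.zeros(3, 3)
randomize_tensor(w, start=1, end=2)      # in-place uniform fill in [1, 2)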
test/ut/sdk/test_compression_utils.py
...
@@ -116,7 +116,7 @@ class AnalysisUtilsTest(TestCase):
        pruner.export_model(ck_file, mask_file)
        pruner._unwrap_model()
        # Fix the mask conflict
-       fixed_mask, _ = fix_mask_conflict(mask_file, net, dummy_input)
+       fixed_mask = fix_mask_conflict(mask_file, net, dummy_input)
        # use the channel dependency ground truth to check if
        # the mask conflict is fixed successfully
...
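As the hunk above suggests, fix_mask_conflict now returns the fixed mask directly instead of a tuple. A hedged sketch of the updated call, assuming the usual NNI mask layout of {layer: {'weight': ..., 'bias': ...}}:

from nni.compression.pytorch.utils.mask_conflict import fix_mask_conflict

# net, dummy_input and mask_file prepared as in the test above
fixed_mask = fix_mask_conflict(mask_file, net, dummy_input)
for layer, masks in fixed_mask.items():
    print(layer, masks['weight'].shape)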
test/ut/sdk/test_model_speedup.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
+import logging
import os
+import gc
import psutil
import sys
import numpy as np
...
@@ -9,18 +11,20 @@ import torch
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
-from torchvision.models.vgg import vgg16
+from torchvision.models.vgg import vgg16, vgg11
from torchvision.models.resnet import resnet18
+from torchvision.models.mobilenet import mobilenet_v2
import unittest
from unittest import TestCase, main
from nni.compression.pytorch import ModelSpeedup, apply_compression_results
-from nni.algorithms.compression.pytorch.pruning import L1FilterPruner
+from nni.algorithms.compression.pytorch.pruning import L1FilterPruner, LevelPruner
from nni.algorithms.compression.pytorch.pruning.weight_masker import WeightMasker
from nni.algorithms.compression.pytorch.pruning.dependency_aware_pruner import DependencyAwarePruner

torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 2
# the relative distance
RELATIVE_THRESHOLD = 0.01
...
@@ -105,6 +109,55 @@ class TransposeModel(torch.nn.Module):
        return x

+class TupleUnpack_backbone(nn.Module):
+    def __init__(self, width):
+        super(TupleUnpack_backbone, self).__init__()
+        self.model_backbone = mobilenet_v2(pretrained=False, width_mult=width, num_classes=3)
+
+    def forward(self, x):
+        x1 = self.model_backbone.features[:7](x)
+        x2 = self.model_backbone.features[7:14](x1)
+        x3 = self.model_backbone.features[14:18](x2)
+        return [x1, x2, x3]
+
+class TupleUnpack_FPN(nn.Module):
+    def __init__(self):
+        super(TupleUnpack_FPN, self).__init__()
+        self.conv1 = nn.Conv2d(32, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
+        self.conv2 = nn.Conv2d(96, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
+        self.conv3 = nn.Conv2d(320, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
+        # self.init_weights()
+
+    def forward(self, inputs):
+        """Forward function."""
+        laterals = []
+        laterals.append(self.conv1(inputs[0]))  # inputs[0] == x1
+        laterals.append(self.conv2(inputs[1]))  # inputs[1] == x2
+        laterals.append(self.conv3(inputs[2]))  # inputs[2] == x3
+        return laterals
+
+class TupleUnpack_Model(nn.Module):
+    def __init__(self):
+        super(TupleUnpack_Model, self).__init__()
+        self.backbone = TupleUnpack_backbone(1.0)
+        self.fpn = TupleUnpack_FPN()
+
+    def forward(self, x):
+        x1 = self.backbone(x)
+        out = self.fpn(x1)
+        return out
dummy_input = torch.randn(2, 1, 28, 28)
SPARSITY = 0.5
MODEL_FILE, MASK_FILE = './11_model.pth', './l1_mask.pth'
...
@@ -129,6 +182,7 @@ def generate_random_sparsity(model):
                             'sparsity': sparsity})
    return cfg_list

def generate_random_sparsity_v2(model):
    """
    Only select 50% layers to prune.
...
@@ -142,6 +196,7 @@ def generate_random_sparsity_v2(model):
                             'sparsity': sparsity})
    return cfg_list

def zero_bn_bias(model):
    with torch.no_grad():
        for name, module in model.named_modules():
...
@@ -231,19 +286,6 @@ def channel_prune(model):
class SpeedupTestCase(TestCase):
-   def test_speedup_vgg16(self):
-       prune_model_l1(vgg16())
-       model = vgg16()
-       model.train()
-       ms = ModelSpeedup(model, torch.randn(2, 3, 32, 32), MASK_FILE)
-       ms.speedup_model()
-       orig_model = vgg16()
-       assert model.training
-       assert model.features[2].out_channels == int(orig_model.features[2].out_channels * SPARSITY)
-       assert model.classifier[0].in_features == int(orig_model.classifier[0].in_features * SPARSITY)
-
    def test_speedup_bigmodel(self):
        prune_model_l1(BigModel())
...
@@ -253,7 +295,7 @@ class SpeedupTestCase(TestCase):
        mask_out = model(dummy_input)
        model.train()
-       ms = ModelSpeedup(model, dummy_input, MASK_FILE)
+       ms = ModelSpeedup(model, dummy_input, MASK_FILE, confidence=2)
        ms.speedup_model()
        assert model.training
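Several tests in this file now pass confidence=2 to ModelSpeedup. Judging from the refactored speedup code in this commit, it appears to control the batch size of the randomized dummy input used for sparsity and shape inference; treat that reading as an assumption. An end-to-end sketch mirroring the tests:

import torch
from torchvision.models import resnet18
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner
from nni.compression.pytorch import ModelSpeedup

net = resnet18()
pruner = L1FilterPruner(net, [{'op_types': ['Conv2d'], 'sparsity': 0.5}])
pruner.compress()
pruner.export_model('./model.pth', './mask.pth')
pruner._unwrap_model()

# confidence=2: assumed to set the inference batch size, as in the tests above
ms = ModelSpeedup(net, torch.randn(2, 3, 224, 224), './mask.pth', confidence=2)
ms.speedup_model()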
...
@@ -289,7 +331,7 @@ class SpeedupTestCase(TestCase):
        new_model = TransposeModel()
        state_dict = torch.load(MODEL_FILE)
        new_model.load_state_dict(state_dict)
-       ms = ModelSpeedup(new_model, dummy_input, MASK_FILE)
+       ms = ModelSpeedup(new_model, dummy_input, MASK_FILE, confidence=2)
        ms.speedup_model()
        zero_bn_bias(ori_model)
        zero_bn_bias(new_model)
...
@@ -297,26 +339,34 @@ class SpeedupTestCase(TestCase):
        new_out = new_model(dummy_input)
        ori_sum = torch.sum(ori_out)
        speeded_sum = torch.sum(new_out)
        print('Transpose Speedup Test: ori_sum={} speedup_sum={}'.format(ori_sum, speeded_sum))
        assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
            (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
-   def test_speedup_integration(self):
-       # skip this test on windows (7GB mem available) due to memory limit
-       # Note: hack trick, may be updated in the future
-       if 'win' in sys.platform or 'Win' in sys.platform:
-           print('Skip test_speedup_integration on windows due to memory limit!')
-           return
+   # FIXME: This test case might fail randomly, no idea why
+   def test_speedup_integration_small(self):
+       # Example: https://msrasrg.visualstudio.com/NNIOpenSource/_build/results?buildId=16282
+       model_list = ['resnet18', 'mobilenet_v2', 'alexnet']
+       self.speedup_integration(model_list)
+
+   def test_speedup_integration_big(self):
+       model_list = ['vgg11', 'vgg16', 'resnet34', 'squeezenet1_1',
+                     'densenet121', 'resnet50', 'wide_resnet50_2']
+       mem_info = psutil.virtual_memory()
+       ava_gb = mem_info.available / 1024.0 / 1024 / 1024
+       print('Available memory size: %.2f GB' % ava_gb)
+       if ava_gb < 8.0:
+           # memory size is too small, so we may run into an OOM exception
+           # Skip this test in the pipeline due to the memory limitation
+           return
+       self.speedup_integration(model_list)
+
+   def speedup_integration(self, model_list, speedup_cfg=None):
        Gen_cfg_funcs = [generate_random_sparsity, generate_random_sparsity_v2]
-       for model_name in ['resnet18', 'mobilenet_v2', 'squeezenet1_1', 'densenet121', 'densenet169',
-                          # 'inception_v3' inception is too large and may fail the pipeline
-                          'resnet50']:
+       # for model_name in ['vgg16', 'resnet18', 'mobilenet_v2', 'squeezenet1_1', 'densenet121',
+       #                    # 'inception_v3' inception is too large and may fail the pipeline
+       #                    'resnet50']:
+       for model_name in model_list:
            for gen_cfg_func in Gen_cfg_funcs:
                kwargs = {
                    'pretrained': True
...
@@ -334,7 +384,10 @@ class SpeedupTestCase(TestCase):
                speedup_model.eval()
                # randomly generate the prune config for the pruner
                cfgs = gen_cfg_func(net)
                print("Testing {} with compression config\n{}".format(model_name, cfgs))
+               if len(cfgs) == 0:
+                   continue
                pruner = L1FilterPruner(net, cfgs)
                pruner.compress()
                pruner.export_model(MODEL_FILE, MASK_FILE)
...
@@ -345,7 +398,10 @@ class SpeedupTestCase(TestCase):
                zero_bn_bias(speedup_model)
                data = torch.ones(BATCH_SIZE, 3, 128, 128).to(device)
-               ms = ModelSpeedup(speedup_model, data, MASK_FILE)
+               if speedup_cfg is None:
+                   speedup_cfg = {}
+               ms = ModelSpeedup(speedup_model, data, MASK_FILE, confidence=2, **speedup_cfg)
                ms.speedup_model()
                speedup_model.eval()
...
@@ -360,7 +416,8 @@ class SpeedupTestCase(TestCase):
                    model_name, speeded_sum)
                assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
                    (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
+               print("Collecting Garbage")
+               gc.collect(2)
    def test_channel_prune(self):
        orig_net = resnet18(num_classes=10).to(device)
...
@@ -378,7 +435,7 @@ class SpeedupTestCase(TestCase):
        net.eval()
        data = torch.randn(BATCH_SIZE, 3, 128, 128).to(device)
-       ms = ModelSpeedup(net, data, MASK_FILE)
+       ms = ModelSpeedup(net, data, MASK_FILE, confidence=2)
        ms.speedup_model()
        ms.bound_model(data)
...
@@ -391,11 +448,56 @@ class SpeedupTestCase(TestCase):
        assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
            (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
+   def test_speedup_tupleunpack(self):
+       """This test is reported in issue3645"""
+       model = TupleUnpack_Model()
+       cfg_list = [{'op_types': ['Conv2d'], 'sparsity': 0.5}]
+       dummy_input = torch.rand(2, 3, 224, 224)
+       pruner = L1FilterPruner(model, cfg_list)
+       pruner.compress()
+       model(dummy_input)
+       pruner.export_model(MODEL_FILE, MASK_FILE)
+       ms = ModelSpeedup(model, dummy_input, MASK_FILE, confidence=2)
+       ms.speedup_model()
+
+   def test_finegrained_speedup(self):
+       """Test the speedup on fine-grained sparsity."""
+       class MLP(nn.Module):
+           def __init__(self):
+               super(MLP, self).__init__()
+               self.fc1 = nn.Linear(1024, 1024)
+               self.fc2 = nn.Linear(1024, 1024)
+               self.fc3 = nn.Linear(1024, 512)
+               self.fc4 = nn.Linear(512, 10)
+
+           def forward(self, x):
+               x = x.view(-1, 1024)
+               x = self.fc1(x)
+               x = self.fc2(x)
+               x = self.fc3(x)
+               x = self.fc4(x)
+               return x
+
+       model = MLP().to(device)
+       dummy_input = torch.rand(16, 1, 32, 32).to(device)
+       cfg_list = [{'op_types': ['Linear'], 'sparsity': 0.99}]
+       pruner = LevelPruner(model, cfg_list)
+       pruner.compress()
+       print('Original Arch')
+       print(model)
+       pruner.export_model(MODEL_FILE, MASK_FILE)
+       pruner._unwrap_model()
+       ms = ModelSpeedup(model, dummy_input, MASK_FILE, confidence=4)
+       ms.speedup_model()
+       print("Fine-grained speeduped model")
+       print(model)
    def tearDown(self):
        if os.path.exists(MODEL_FILE):
            os.remove(MODEL_FILE)
        if os.path.exists(MASK_FILE):
            os.remove(MASK_FILE)
+       # GC to release memory
+       gc.collect(2)

if __name__ == '__main__':
...