OpenDAS / nni · Commit 06db4729 (unverified)

refactor code structure of pruning algorithms (#1882)

Authored Dec 30, 2019 by QuanluZhang; committed by GitHub on Dec 30, 2019.
Parent: 9b49245e

Showing 5 changed files with 901 additions and 3 deletions (+901 −3):

- src/sdk/pynni/nni/compression/torch/__init__.py (+4 −3)
- src/sdk/pynni/nni/compression/torch/activation_rank_filter_pruners.py (+252 −0)
- src/sdk/pynni/nni/compression/torch/pruners.py (+383 −0)
- src/sdk/pynni/nni/compression/torch/quantizers.py (+0 −0)
- src/sdk/pynni/nni/compression/torch/weight_rank_filter_pruners.py (+262 −0)
src/sdk/pynni/nni/compression/torch/__init__.py

```diff
@@ -2,6 +2,7 @@
 # Licensed under the MIT license.
 
 from .compressor import LayerInfo, Compressor, Pruner, Quantizer
-from .builtin_pruners import *
-from .builtin_quantizers import *
-from .lottery_ticket import LotteryTicketPruner
+from .pruners import *
+from .weight_rank_filter_pruners import *
+from .activation_rank_filter_pruners import *
+from .quantizers import *
```
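Because the package root re-exports everything through the wildcard imports above, the file shuffle is transparent to downstream code. A minimal sketch of the unchanged user-facing API (the model and config below are illustrative assumptions, not part of this commit):

```python
import torch.nn as nn
# pruner classes are still imported from the package root, exactly as before the refactor
from nni.compression.torch import LevelPruner

model = nn.Linear(10, 10)                                   # toy model, for illustration only
config_list = [{'sparsity': 0.5, 'op_types': ['default']}]  # illustrative config

pruner = LevelPruner(model, config_list)
model = pruner.compress()
```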
src/sdk/pynni/nni/compression/torch/activation_rank_filter_pruners.py (new file, mode 100644)

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import torch
from .compressor import Pruner

__all__ = ['ActivationAPoZRankFilterPruner', 'ActivationMeanRankFilterPruner']

logger = logging.getLogger('torch activation rank filter pruners')


class ActivationRankFilterPruner(Pruner):
    """
    A structured pruning base class that prunes the filters with the smallest
    importance criterion in convolution layers (using activation values)
    to achieve a preset level of network sparsity.
    """

    def __init__(self, model, config_list, activation='relu', statistics_batch_num=1):
        """
        Parameters
        ----------
        model : torch.nn.module
            Model to be pruned
        config_list : list
            Supported keys for each list item:
                - sparsity: percentage of convolutional filters to be pruned.
        activation : str
            Activation function
        statistics_batch_num : int
            Number of batches used for activation statistics
        """
        super().__init__(model, config_list)
        self.mask_calculated_ops = set()
        self.statistics_batch_num = statistics_batch_num
        self.collected_activation = {}
        self.hooks = {}
        assert activation in ['relu', 'relu6']
        if activation == 'relu':
            self.activation = torch.nn.functional.relu
        elif activation == 'relu6':
            self.activation = torch.nn.functional.relu6
        else:
            self.activation = None

    def compress(self):
        """
        Compress the model: register a forward hook on each layer to collect activations.
        """
        modules_to_compress = self.detect_modules_to_compress()
        for layer, config in modules_to_compress:
            self._instrument_layer(layer, config)
            self.collected_activation[layer.name] = []

            def _hook(module_, input_, output, name=layer.name):
                if len(self.collected_activation[name]) < self.statistics_batch_num:
                    self.collected_activation[name].append(self.activation(output.detach().cpu()))

            layer.module.register_forward_hook(_hook)
        return self.bound_model

    def get_mask(self, base_mask, activations, num_prune):
        raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))

    def calc_mask(self, layer, config):
        """
        Calculate the mask of the given layer.
        Filters with the smallest importance criterion, computed from the
        collected activations, are masked.

        Parameters
        ----------
        layer : LayerInfo
            the layer to instrument the compression operation
        config : dict
            layer's pruning config

        Returns
        -------
        dict
            dictionary for storing masks
        """
        weight = layer.module.weight.data
        op_name = layer.name
        op_type = layer.type
        assert 0 <= config.get('sparsity') < 1, "sparsity must be in the range [0, 1)"
        assert op_type in ['Conv2d'], "only Conv2d is supported"
        assert op_type in config.get('op_types')
        if op_name in self.mask_calculated_ops:
            assert op_name in self.mask_dict
            return self.mask_dict.get(op_name)
        mask_weight = torch.ones(weight.size()).type_as(weight).detach()
        if hasattr(layer.module, 'bias') and layer.module.bias is not None:
            mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
        else:
            mask_bias = None
        mask = {'weight': mask_weight, 'bias': mask_bias}
        try:
            filters = weight.size(0)
            num_prune = int(filters * config.get('sparsity'))
            if filters < 2 or num_prune < 1 or \
                    len(self.collected_activation[layer.name]) < self.statistics_batch_num:
                return mask
            mask = self.get_mask(mask, self.collected_activation[layer.name], num_prune)
        finally:
            if len(self.collected_activation[layer.name]) == self.statistics_batch_num:
                self.mask_dict.update({op_name: mask})
                self.mask_calculated_ops.add(op_name)
        return mask


class ActivationAPoZRankFilterPruner(ActivationRankFilterPruner):
    """
    A structured pruning algorithm that prunes the filters with the
    smallest APoZ (average percentage of zeros) of output activations.
    Hengyuan Hu, Rui Peng, Yu-Wing Tai and Chi-Keung Tang,
    "Network Trimming: A Data-Driven Neuron Pruning Approach towards Efficient Deep Architectures", ICLR 2016.
    https://arxiv.org/abs/1607.03250
    """

    def __init__(self, model, config_list, activation='relu', statistics_batch_num=1):
        """
        Parameters
        ----------
        model : torch.nn.module
            Model to be pruned
        config_list : list
            Supported keys for each list item:
                - sparsity: percentage of convolutional filters to be pruned.
        activation : str
            Activation function
        statistics_batch_num : int
            Number of batches used for activation statistics
        """
        super().__init__(model, config_list, activation, statistics_batch_num)

    def get_mask(self, base_mask, activations, num_prune):
        """
        Calculate the mask of the given layer.
        Filters with the smallest APoZ (average percentage of zeros) of output
        activations are masked.

        Parameters
        ----------
        base_mask : dict
            The basic mask with the same shape as the weight; every item in the basic mask is 1.
        activations : list
            Layer's output activations
        num_prune : int
            Number of filters to prune

        Returns
        -------
        dict
            dictionary for storing masks
        """
        apoz = self._calc_apoz(activations)
        prune_indices = torch.argsort(apoz, descending=True)[:num_prune]
        for idx in prune_indices:
            base_mask['weight'][idx] = 0.
            if base_mask['bias'] is not None:
                base_mask['bias'][idx] = 0.
        return base_mask

    def _calc_apoz(self, activations):
        """
        Calculate APoZ (average percentage of zeros) of activations.

        Parameters
        ----------
        activations : list
            Layer's output activations

        Returns
        -------
        torch.Tensor
            Each filter's APoZ (average percentage of zeros) over the collected activations
        """
        activations = torch.cat(activations, 0)
        _eq_zero = torch.eq(activations, torch.zeros_like(activations))
        _apoz = torch.sum(_eq_zero, dim=(0, 2, 3)) / torch.numel(_eq_zero[:, 0, :, :])
        return _apoz


class ActivationMeanRankFilterPruner(ActivationRankFilterPruner):
    """
    A structured pruning algorithm that prunes the filters with the
    smallest mean value of output activations.
    Pavlo Molchanov, Stephen Tyree, Tero Karras, Timo Aila and Jan Kautz,
    "Pruning Convolutional Neural Networks for Resource Efficient Inference", ICLR 2017.
    https://arxiv.org/abs/1611.06440
    """

    def __init__(self, model, config_list, activation='relu', statistics_batch_num=1):
        """
        Parameters
        ----------
        model : torch.nn.module
            Model to be pruned
        config_list : list
            Supported keys for each list item:
                - sparsity: percentage of convolutional filters to be pruned.
        activation : str
            Activation function
        statistics_batch_num : int
            Number of batches used for activation statistics
        """
        super().__init__(model, config_list, activation, statistics_batch_num)

    def get_mask(self, base_mask, activations, num_prune):
        """
        Calculate the mask of the given layer.
        Filters with the smallest mean value of output activations are masked.

        Parameters
        ----------
        base_mask : dict
            The basic mask with the same shape as the weight; every item in the basic mask is 1.
        activations : list
            Layer's output activations
        num_prune : int
            Number of filters to prune

        Returns
        -------
        dict
            dictionary for storing masks
        """
        mean_activation = self._cal_mean_activation(activations)
        prune_indices = torch.argsort(mean_activation)[:num_prune]
        for idx in prune_indices:
            base_mask['weight'][idx] = 0.
            if base_mask['bias'] is not None:
                base_mask['bias'][idx] = 0.
        return base_mask

    def _cal_mean_activation(self, activations):
        """
        Calculate mean value of activations.

        Parameters
        ----------
        activations : list
            Layer's output activations

        Returns
        -------
        torch.Tensor
            Each filter's mean value over the collected output activations
        """
        activations = torch.cat(activations, 0)
        mean_activation = torch.mean(activations, dim=(0, 2, 3))
        return mean_activation
```
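A hedged usage sketch of the new activation-based pruners. Per the code above, `calc_mask` only produces a real mask once `statistics_batch_num` batches of activations have been collected by the forward hooks, so at least that many forward passes must run after `compress()`. The model, data, and config here are illustrative assumptions, not part of the commit:

```python
import torch
import torch.nn as nn
from nni.compression.torch import ActivationAPoZRankFilterPruner

# toy CNN: the pruner asserts that each pruned op_type is 'Conv2d'
model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(),
                      nn.Conv2d(16, 32, 3), nn.ReLU())

# prune 50% of the filters of every Conv2d layer (illustrative config)
config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]

pruner = ActivationAPoZRankFilterPruner(model, config_list,
                                        activation='relu',
                                        statistics_batch_num=1)
model = pruner.compress()

# run at least statistics_batch_num batches so the forward hooks can
# collect the activations that calc_mask ranks by APoZ
with torch.no_grad():
    model(torch.randn(8, 3, 32, 32))
```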
src/sdk/pynni/nni/compression/torch/lottery_ticket.py → src/sdk/pynni/nni/compression/torch/pruners.py

The file is renamed, and the LevelPruner, AGP_Pruner, and SlimPruner classes below are added ahead of the existing LotteryTicketPruner (+383 −0):

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import copy
import logging
import torch
from .compressor import Pruner

_logger = logging.getLogger(__name__)

__all__ = ['LevelPruner', 'AGP_Pruner', 'SlimPruner', 'LotteryTicketPruner']

logger = logging.getLogger('torch pruner')


class LevelPruner(Pruner):
    """
    Prune to an exact pruning level specification
    """

    def __init__(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.module
            Model to be pruned
        config_list : list
            List of pruning configs
        """
        super().__init__(model, config_list)
        self.mask_calculated_ops = set()

    def calc_mask(self, layer, config):
        """
        Calculate the mask of the given layer

        Parameters
        ----------
        layer : LayerInfo
            the layer to instrument the compression operation
        config : dict
            layer's pruning config

        Returns
        -------
        dict
            dictionary for storing masks
        """
        weight = layer.module.weight.data
        op_name = layer.name
        if op_name not in self.mask_calculated_ops:
            w_abs = weight.abs()
            k = int(weight.numel() * config['sparsity'])
            if k == 0:
                return {'weight': torch.ones(weight.shape).type_as(weight)}
            threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
            mask_weight = torch.gt(w_abs, threshold).type_as(weight)
            mask = {'weight': mask_weight}
            self.mask_dict.update({op_name: mask})
            self.mask_calculated_ops.add(op_name)
        else:
            assert op_name in self.mask_dict, "op_name not in the mask_dict"
            mask = self.mask_dict[op_name]
        return mask


class AGP_Pruner(Pruner):
    """
    An automated gradual pruning algorithm that prunes the smallest magnitude
    weights to achieve a preset level of network sparsity.
    Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the
    efficacy of pruning for model compression", 2017 NIPS Workshop on Machine
    Learning of Phones and other Consumer Devices,
    https://arxiv.org/pdf/1710.01878.pdf
    """

    def __init__(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.module
            Model to be pruned
        config_list : list
            List of pruning configs
        """
        super().__init__(model, config_list)
        self.now_epoch = 0
        self.if_init_list = {}

    def calc_mask(self, layer, config):
        """
        Calculate the mask of the given layer

        Parameters
        ----------
        layer : LayerInfo
            the layer to instrument the compression operation
        config : dict
            layer's pruning config

        Returns
        -------
        dict
            dictionary for storing masks
        """
        weight = layer.module.weight.data
        op_name = layer.name
        start_epoch = config.get('start_epoch', 0)
        freq = config.get('frequency', 1)
        if self.now_epoch >= start_epoch and self.if_init_list.get(op_name, True) \
                and (self.now_epoch - start_epoch) % freq == 0:
            mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
            target_sparsity = self.compute_target_sparsity(config)
            k = int(weight.numel() * target_sparsity)
            if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
                return mask
            # to generate a new mask, apply the current mask to the weights first
            w_abs = weight.abs() * mask['weight']
            threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
            new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
            self.mask_dict.update({op_name: new_mask})
            self.if_init_list.update({op_name: False})
        else:
            new_mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
        return new_mask

    def compute_target_sparsity(self, config):
        """
        Calculate the sparsity for pruning

        Parameters
        ----------
        config : dict
            Layer's pruning config

        Returns
        -------
        float
            Target sparsity to be pruned
        """
        end_epoch = config.get('end_epoch', 1)
        start_epoch = config.get('start_epoch', 0)
        freq = config.get('frequency', 1)
        final_sparsity = config.get('final_sparsity', 0)
        initial_sparsity = config.get('initial_sparsity', 0)
        if end_epoch <= start_epoch or initial_sparsity >= final_sparsity:
            logger.warning('your end epoch <= start epoch or initial_sparsity >= final_sparsity')
            return final_sparsity

        if end_epoch <= self.now_epoch:
            return final_sparsity

        span = ((end_epoch - start_epoch - 1) // freq) * freq
        assert span > 0
        target_sparsity = (final_sparsity +
                           (initial_sparsity - final_sparsity) *
                           (1.0 - ((self.now_epoch - start_epoch) / span)) ** 3)
        return target_sparsity

    def update_epoch(self, epoch):
        """
        Update epoch

        Parameters
        ----------
        epoch : int
            current training epoch
        """
        if epoch > 0:
            self.now_epoch = epoch
            for k in self.if_init_list.keys():
                self.if_init_list[k] = True


class SlimPruner(Pruner):
    """
    A structured pruning algorithm that prunes channels by pruning the weights of BN layers.
    Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan and Changshui Zhang,
    "Learning Efficient Convolutional Networks through Network Slimming", 2017 ICCV
    https://arxiv.org/pdf/1708.06519.pdf
    """

    def __init__(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.module
            Model to be pruned
        config_list : list
            Supported keys for each list item:
                - sparsity: percentage of convolutional filters to be pruned.
        """
        super().__init__(model, config_list)
        self.mask_calculated_ops = set()
        weight_list = []
        if len(config_list) > 1:
            logger.warning('Slim pruner only supports 1 configuration')
        config = config_list[0]
        for (layer, config) in self.detect_modules_to_compress():
            assert layer.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
            weight_list.append(layer.module.weight.data.abs().clone())
        all_bn_weights = torch.cat(weight_list)
        k = int(all_bn_weights.shape[0] * config['sparsity'])
        self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()

    def calc_mask(self, layer, config):
        """
        Calculate the mask of the given layer.
        Scale factors with the smallest absolute value in the BN layer are masked.

        Parameters
        ----------
        layer : LayerInfo
            the layer to instrument the compression operation
        config : dict
            layer's pruning config

        Returns
        -------
        dict
            dictionary for storing masks
        """
        weight = layer.module.weight.data
        op_name = layer.name
        op_type = layer.type
        assert op_type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
        if op_name in self.mask_calculated_ops:
            assert op_name in self.mask_dict
            return self.mask_dict.get(op_name)
        base_mask = torch.ones(weight.size()).type_as(weight).detach()
        mask = {'weight': base_mask.detach(), 'bias': base_mask.clone().detach()}
        try:
            filters = weight.size(0)
            num_prune = int(filters * config.get('sparsity'))
            if filters < 2 or num_prune < 1:
                return mask
            w_abs = weight.abs()
            mask_weight = torch.gt(w_abs, self.global_threshold).type_as(weight)
            mask_bias = mask_weight.clone()
            mask = {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
        finally:
            self.mask_dict.update({layer.name: mask})
            self.mask_calculated_ops.add(layer.name)
        return mask
```

The existing LotteryTicketPruner class follows, moved unchanged from lottery_ticket.py (collapsed in the diff).
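The AGP schedule implemented by `compute_target_sparsity` above interpolates from `initial_sparsity` to `final_sparsity` with a cubic ramp, s_t = s_f + (s_i − s_f)(1 − (t − t_0)/span)³, so sparsity rises quickly at first and flattens near the end. A standalone re-implementation of the same arithmetic, for illustration only (not part of the commit):

```python
def agp_target_sparsity(epoch, initial_sparsity=0.0, final_sparsity=0.8,
                        start_epoch=0, end_epoch=10, frequency=1):
    """Cubic sparsity ramp, mirroring AGP_Pruner.compute_target_sparsity."""
    if end_epoch <= epoch:
        return final_sparsity
    span = ((end_epoch - start_epoch - 1) // frequency) * frequency
    return (final_sparsity +
            (initial_sparsity - final_sparsity) *
            (1.0 - (epoch - start_epoch) / span) ** 3)

# steep early, flat late: epoch 0 -> 0.0, epoch 3 -> ~0.56,
# epoch 6 -> ~0.77, epoch 9 -> 0.8
for epoch in range(10):
    print(epoch, round(agp_target_sparsity(epoch), 3))
```

In training code, `pruner.update_epoch(epoch)` advances `now_epoch` and resets `if_init_list`, so the next forward pass recomputes masks at the new target sparsity.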
src/sdk/pynni/nni/compression/torch/builtin_quantizers.py → src/sdk/pynni/nni/compression/torch/quantizers.py

File moved, contents unchanged (+0 −0).
src/sdk/pynni/nni/compression/torch/builtin_pruners.py → src/sdk/pynni/nni/compression/torch/weight_rank_filter_pruners.py
The first hunk replaces the module's exports and logger, and deletes the LevelPruner, AGP_Pruner, and SlimPruner classes, which move verbatim to pruners.py (shown above):

```diff
@@ -5,240 +5,9 @@ import logging
 import torch
 from .compressor import Pruner
 
-__all__ = ['LevelPruner', 'AGP_Pruner', 'SlimPruner',
-           'L1FilterPruner', 'L2FilterPruner', 'FPGMPruner',
-           'ActivationAPoZRankFilterPruner', 'ActivationMeanRankFilterPruner']
-
-logger = logging.getLogger('torch pruner')
-
-class LevelPruner(Pruner):
-    # ... (moved to pruners.py) ...
-
-class AGP_Pruner(Pruner):
-    # ... (moved to pruners.py) ...
-
-class SlimPruner(Pruner):
-    # ... (moved to pruners.py) ...
+__all__ = ['L1FilterPruner', 'L2FilterPruner', 'FPGMPruner']
+
+logger = logging.getLogger('torch weight rank filter pruners')
 
 class WeightRankFilterPruner(Pruner):
     """
```
The next hunks rename the base class hook from `_get_mask` to `get_mask` and make the base implementation raise instead of silently returning an empty mask:

```diff
@@ -260,8 +29,8 @@ class WeightRankFilterPruner(Pruner):
         super().__init__(model, config_list)
         self.mask_calculated_ops = set()  # operations whose mask has been calculated
 
-    def _get_mask(self, base_mask, weight, num_prune):
-        return {'weight': None, 'bias': None}
+    def get_mask(self, base_mask, weight, num_prune):
+        raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))
 
     def calc_mask(self, layer, config):
         """
@@ -299,7 +68,7 @@ class WeightRankFilterPruner(Pruner):
             num_prune = int(filters * config.get('sparsity'))
             if filters < 2 or num_prune < 1:
                 return mask
-            mask = self._get_mask(mask, weight, num_prune)
+            mask = self.get_mask(mask, weight, num_prune)
         finally:
             self.mask_dict.update({op_name: mask})
             self.mask_calculated_ops.add(op_name)
@@ -328,7 +97,7 @@ class L1FilterPruner(WeightRankFilterPruner):
         super().__init__(model, config_list)
 
-    def _get_mask(self, base_mask, weight, num_prune):
+    def get_mask(self, base_mask, weight, num_prune):
         """
         Calculate the mask of given layer.
         Filters with the smallest sum of its absolute kernel weights are masked.
@@ -376,7 +145,7 @@ class L2FilterPruner(WeightRankFilterPruner):
         super().__init__(model, config_list)
 
-    def _get_mask(self, base_mask, weight, num_prune):
+    def get_mask(self, base_mask, weight, num_prune):
         """
         Calculate the mask of given layer.
         Filters with the smallest L2 norm of the absolute kernel weights are masked.
@@ -422,7 +191,7 @@ class FPGMPruner(WeightRankFilterPruner):
         """
         super().__init__(model, config_list)
 
-    def _get_mask(self, base_mask, weight, num_prune):
+    def get_mask(self, base_mask, weight, num_prune):
         """
         Calculate the mask of given layer.
         Filters with the smallest sum of its absolute kernel weights are masked.
```
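With `get_mask` now raising NotImplementedError in the base class, the contract is explicit: every concrete weight-rank pruner must override it. A hedged sketch of what a custom subclass would look like under this contract (the class, ranking rule, and import path are illustrative assumptions, not from the commit):

```python
import torch
from nni.compression.torch.weight_rank_filter_pruners import WeightRankFilterPruner

class RandomFilterPruner(WeightRankFilterPruner):
    """Toy example: rank filters randomly instead of by L1/L2 norm."""

    def get_mask(self, base_mask, weight, num_prune):
        # pick num_prune filter indices at random and zero their mask entries
        prune_indices = torch.randperm(weight.size(0))[:num_prune]
        for idx in prune_indices:
            base_mask['weight'][idx] = 0.
            if base_mask.get('bias') is not None:
                base_mask['bias'][idx] = 0.
        return base_mask
```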
```diff
@@ -491,251 +260,3 @@ class FPGMPruner(WeightRankFilterPruner):
     def update_epoch(self, epoch):
         self.mask_calculated_ops = set()
```

The remaining 248 deleted lines of this hunk are the ActivationRankFilterPruner, ActivationAPoZRankFilterPruner, and ActivationMeanRankFilterPruner classes, which move to the new activation_rank_filter_pruners.py (shown above) with `_get_mask` renamed to `get_mask`.