Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
ed121315
Commit
ed121315
authored
Feb 10, 2020
by
chicm-ms
Committed by
GitHub
Feb 10, 2020
Browse files
Merge pull request #2022 from microsoft/dev-pruner-dataparallel
Dev pruner DataParallel
parents
c7187946
8092c8bd
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
523 additions
and
222 deletions
+523
-222
examples/model_compress/MeanActivation_torch_cifar10.py
examples/model_compress/MeanActivation_torch_cifar10.py
+26
-10
examples/model_compress/fpgm_torch_mnist.py
examples/model_compress/fpgm_torch_mnist.py
+13
-6
examples/model_compress/lottery_torch_mnist_fc.py
examples/model_compress/lottery_torch_mnist_fc.py
+2
-0
examples/model_compress/multi_gpu.py
examples/model_compress/multi_gpu.py
+101
-0
examples/model_compress/slim_torch_cifar10.py
examples/model_compress/slim_torch_cifar10.py
+29
-16
src/sdk/pynni/nni/compression/torch/activation_rank_filter_pruners.py
...i/nni/compression/torch/activation_rank_filter_pruners.py
+15
-10
src/sdk/pynni/nni/compression/torch/compressor.py
src/sdk/pynni/nni/compression/torch/compressor.py
+249
-85
src/sdk/pynni/nni/compression/torch/pruners.py
src/sdk/pynni/nni/compression/torch/pruners.py
+68
-72
src/sdk/pynni/nni/compression/torch/weight_rank_filter_pruners.py
...pynni/nni/compression/torch/weight_rank_filter_pruners.py
+8
-9
src/sdk/pynni/tests/test_compressor.py
src/sdk/pynni/tests/test_compressor.py
+12
-14
No files found.
examples/model_compress/MeanActivation_torch_cifar10.py
View file @
ed121315
import
math
import
argparse
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
torchvision
import
datasets
,
transforms
from
nni.compression.torch
import
L1
FilterPruner
from
nni.compression.torch
import
ActivationMeanRank
FilterPruner
from
models.cifar10.vgg
import
VGG
...
...
@@ -40,6 +41,12 @@ def test(model, device, test_loader):
def
main
():
parser
=
argparse
.
ArgumentParser
(
"multiple gpu with pruning"
)
parser
.
add_argument
(
"--epochs"
,
type
=
int
,
default
=
160
)
parser
.
add_argument
(
"--retrain"
,
default
=
False
,
action
=
"store_true"
)
parser
.
add_argument
(
"--parallel"
,
default
=
False
,
action
=
"store_true"
)
args
=
parser
.
parse_args
()
torch
.
manual_seed
(
0
)
device
=
torch
.
device
(
'cuda'
)
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
...
...
@@ -63,10 +70,11 @@ def main():
model
.
to
(
device
)
# Train the base VGG-16 model
if
args
.
retrain
:
print
(
'='
*
10
+
'Train the unpruned base model'
+
'='
*
10
)
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.1
,
momentum
=
0.9
,
weight_decay
=
1e-4
)
lr_scheduler
=
torch
.
optim
.
lr_scheduler
.
CosineAnnealingLR
(
optimizer
,
160
,
0
)
for
epoch
in
range
(
160
):
for
epoch
in
range
(
args
.
epochs
):
train
(
model
,
device
,
train_loader
,
optimizer
)
test
(
model
,
device
,
test_loader
)
lr_scheduler
.
step
(
epoch
)
...
...
@@ -88,8 +96,16 @@ def main():
# Prune model and test accuracy without fine tuning.
print
(
'='
*
10
+
'Test on the pruned model before fine tune'
+
'='
*
10
)
pruner
=
L1
FilterPruner
(
model
,
configure_list
)
pruner
=
ActivationMeanRank
FilterPruner
(
model
,
configure_list
)
model
=
pruner
.
compress
()
if
args
.
parallel
:
if
torch
.
cuda
.
device_count
()
>
1
:
print
(
"use {} gpus for pruning"
.
format
(
torch
.
cuda
.
device_count
()))
model
=
nn
.
DataParallel
(
model
)
else
:
print
(
"only detect 1 gpu, fall back"
)
model
.
to
(
device
)
test
(
model
,
device
,
test_loader
)
# top1 = 88.19%
...
...
examples/model_compress/fpgm_torch_mnist.py
View file @
ed121315
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
torchvision
import
datasets
,
transforms
from
nni.compression.torch
import
FPGMPruner
...
...
@@ -6,10 +7,10 @@ from nni.compression.torch import FPGMPruner
class
Mnist
(
torch
.
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv1
=
torch
.
nn
.
Conv2d
(
1
,
20
,
5
,
1
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
20
,
50
,
5
,
1
)
self
.
fc1
=
torch
.
nn
.
Linear
(
4
*
4
*
50
,
500
)
self
.
fc2
=
torch
.
nn
.
Linear
(
500
,
10
)
self
.
conv1
=
nn
.
Conv2d
(
1
,
20
,
5
,
1
)
self
.
conv2
=
nn
.
Conv2d
(
20
,
50
,
5
,
1
)
self
.
fc1
=
nn
.
Linear
(
4
*
4
*
50
,
500
)
self
.
fc2
=
nn
.
Linear
(
500
,
10
)
def
forward
(
self
,
x
):
x
=
F
.
relu
(
self
.
conv1
(
x
))
...
...
@@ -27,8 +28,14 @@ class Mnist(torch.nn.Module):
return
num_zero_filters
,
num_filters
,
float
(
num_zero_filters
)
/
num_filters
def
print_conv_filter_sparsity
(
self
):
if
isinstance
(
self
.
conv1
,
nn
.
Conv2d
):
conv1_data
=
self
.
_get_conv_weight_sparsity
(
self
.
conv1
)
conv2_data
=
self
.
_get_conv_weight_sparsity
(
self
.
conv2
)
else
:
# self.conv1 is wrapped as PrunerModuleWrapper
conv1_data
=
self
.
_get_conv_weight_sparsity
(
self
.
conv1
.
module
)
conv2_data
=
self
.
_get_conv_weight_sparsity
(
self
.
conv2
.
module
)
print
(
'conv1: num zero filters: {}, num filters: {}, sparsity: {:.4f}'
.
format
(
conv1_data
[
0
],
conv1_data
[
1
],
conv1_data
[
2
]))
print
(
'conv2: num zero filters: {}, num filters: {}, sparsity: {:.4f}'
.
format
(
conv2_data
[
0
],
conv2_data
[
1
],
conv2_data
[
2
]))
...
...
examples/model_compress/lottery_torch_mnist_fc.py
View file @
ed121315
...
...
@@ -71,6 +71,8 @@ if __name__ == '__main__':
pruner
=
LotteryTicketPruner
(
model
,
configure_list
,
optimizer
)
pruner
.
compress
()
#model = nn.DataParallel(model)
for
i
in
pruner
.
get_prune_iterations
():
pruner
.
prune_iteration_start
()
loss
=
0
...
...
examples/model_compress/multi_gpu.py
0 → 100644
View file @
ed121315
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.utils.data
import
torchvision.datasets
as
datasets
import
torchvision.transforms
as
transforms
from
nni.compression.torch
import
SlimPruner
class
fc1
(
nn
.
Module
):
def
__init__
(
self
,
num_classes
=
10
):
super
(
fc1
,
self
).
__init__
()
self
.
conv1
=
nn
.
Conv2d
(
1
,
32
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
bn1
=
nn
.
BatchNorm2d
(
32
)
self
.
relu1
=
nn
.
ReLU
(
inplace
=
True
)
self
.
linear1
=
nn
.
Linear
(
32
*
28
*
28
,
300
)
self
.
relu2
=
nn
.
ReLU
(
inplace
=
True
)
self
.
linear2
=
nn
.
Linear
(
300
,
100
)
self
.
relu3
=
nn
.
ReLU
(
inplace
=
True
)
self
.
linear3
=
nn
.
Linear
(
100
,
num_classes
)
def
forward
(
self
,
x
):
x
=
self
.
conv1
(
x
)
x
=
self
.
bn1
(
x
)
x
=
self
.
relu1
(
x
)
x
=
torch
.
flatten
(
x
,
1
)
x
=
self
.
linear1
(
x
)
x
=
self
.
relu2
(
x
)
x
=
self
.
linear2
(
x
)
x
=
self
.
relu3
(
x
)
x
=
self
.
linear3
(
x
)
return
x
def
train
(
model
,
train_loader
,
optimizer
,
criterion
,
device
):
model
.
train
()
for
imgs
,
targets
in
train_loader
:
optimizer
.
zero_grad
()
imgs
,
targets
=
imgs
.
to
(
device
),
targets
.
to
(
device
)
output
=
model
(
imgs
)
train_loss
=
criterion
(
output
,
targets
)
train_loss
.
backward
()
optimizer
.
step
()
return
train_loss
.
item
()
def
test
(
model
,
test_loader
,
criterion
,
device
):
model
.
eval
()
test_loss
=
0
correct
=
0
with
torch
.
no_grad
():
for
data
,
target
in
test_loader
:
data
,
target
=
data
.
to
(
device
),
target
.
to
(
device
)
output
=
model
(
data
)
test_loss
+=
F
.
nll_loss
(
output
,
target
,
reduction
=
'sum'
).
item
()
# sum up batch loss
pred
=
output
.
data
.
max
(
1
,
keepdim
=
True
)[
1
]
# get the index of the max log-probability
correct
+=
pred
.
eq
(
target
.
data
.
view_as
(
pred
)).
sum
().
item
()
test_loss
/=
len
(
test_loader
.
dataset
)
accuracy
=
100.
*
correct
/
len
(
test_loader
.
dataset
)
return
accuracy
if
__name__
==
'__main__'
:
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))])
traindataset
=
datasets
.
MNIST
(
'./data'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
testdataset
=
datasets
.
MNIST
(
'./data'
,
train
=
False
,
transform
=
transform
)
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
traindataset
,
batch_size
=
60
,
shuffle
=
True
,
num_workers
=
10
,
drop_last
=
False
)
test_loader
=
torch
.
utils
.
data
.
DataLoader
(
testdataset
,
batch_size
=
60
,
shuffle
=
False
,
num_workers
=
10
,
drop_last
=
True
)
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
model
=
fc1
()
criterion
=
nn
.
CrossEntropyLoss
()
configure_list
=
[{
'prune_iterations'
:
5
,
'sparsity'
:
0.86
,
'op_types'
:
[
'BatchNorm2d'
]
}]
pruner
=
SlimPruner
(
model
,
configure_list
)
pruner
.
compress
()
if
torch
.
cuda
.
device_count
()
>
1
:
model
=
nn
.
DataParallel
(
model
)
model
.
to
(
device
)
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
1.2e-3
)
for
name
,
par
in
model
.
named_parameters
():
print
(
name
)
# for i in pruner.get_prune_iterations():
# pruner.prune_iteration_start()
loss
=
0
accuracy
=
0
for
epoch
in
range
(
10
):
loss
=
train
(
model
,
train_loader
,
optimizer
,
criterion
,
device
)
accuracy
=
test
(
model
,
test_loader
,
criterion
,
device
)
print
(
'current epoch: {0}, loss: {1}, accuracy: {2}'
.
format
(
epoch
,
loss
,
accuracy
))
# print('prune iteration: {0}, loss: {1}, accuracy: {2}'.format(i, loss, accuracy))
pruner
.
export_model
(
'model.pth'
,
'mask.pth'
)
\ No newline at end of file
examples/model_compress/slim_torch_cifar10.py
View file @
ed121315
import
math
import
argparse
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
...
...
@@ -6,7 +7,6 @@ from torchvision import datasets, transforms
from
nni.compression.torch
import
SlimPruner
from
models.cifar10.vgg
import
VGG
def
updateBN
(
model
):
for
m
in
model
.
modules
():
if
isinstance
(
m
,
nn
.
BatchNorm2d
):
...
...
@@ -49,6 +49,13 @@ def test(model, device, test_loader):
def
main
():
parser
=
argparse
.
ArgumentParser
(
"multiple gpu with pruning"
)
parser
.
add_argument
(
"--epochs"
,
type
=
int
,
default
=
160
)
parser
.
add_argument
(
"--retrain"
,
default
=
False
,
action
=
"store_true"
)
parser
.
add_argument
(
"--parallel"
,
default
=
False
,
action
=
"store_true"
)
args
=
parser
.
parse_args
()
torch
.
manual_seed
(
0
)
device
=
torch
.
device
(
'cuda'
)
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
...
...
@@ -70,15 +77,16 @@ def main():
model
=
VGG
(
depth
=
19
)
model
.
to
(
device
)
# Train the base VGG-19 model
if
args
.
retrain
:
print
(
'='
*
10
+
'Train the unpruned base model'
+
'='
*
10
)
epochs
=
160
epochs
=
args
.
epochs
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.1
,
momentum
=
0.9
,
weight_decay
=
1e-4
)
for
epoch
in
range
(
epochs
):
if
epoch
in
[
epochs
*
0.5
,
epochs
*
0.75
]:
for
param_group
in
optimizer
.
param_groups
:
param_group
[
'lr'
]
*=
0.1
print
(
"epoch {}"
.
format
(
epoch
))
train
(
model
,
device
,
train_loader
,
optimizer
,
True
)
test
(
model
,
device
,
test_loader
)
torch
.
save
(
model
.
state_dict
(),
'vgg19_cifar10.pth'
)
...
...
@@ -99,9 +107,14 @@ def main():
print
(
'='
*
10
+
'Test the pruned model before fine tune'
+
'='
*
10
)
pruner
=
SlimPruner
(
model
,
configure_list
)
model
=
pruner
.
compress
()
test
(
model
,
device
,
test_loader
)
# top1 = 93.55%
if
args
.
parallel
:
if
torch
.
cuda
.
device_count
()
>
1
:
print
(
"use {} gpus for pruning"
.
format
(
torch
.
cuda
.
device_count
()))
model
=
nn
.
DataParallel
(
model
)
# model = nn.DataParallel(model, device_ids=[0, 1])
else
:
print
(
"only detect 1 gpu, fall back"
)
model
.
to
(
device
)
# Fine tune the pruned model for 40 epochs and test accuracy
print
(
'='
*
10
+
'Fine tuning'
+
'='
*
10
)
optimizer_finetune
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.001
,
momentum
=
0.9
,
weight_decay
=
1e-4
)
...
...
src/sdk/pynni/nni/compression/torch/activation_rank_filter_pruners.py
View file @
ed121315
...
...
@@ -32,7 +32,7 @@ class ActivationRankFilterPruner(Pruner):
"""
super
().
__init__
(
model
,
config_list
)
self
.
mask_calculated_ops
=
set
()
self
.
register_buffer
(
"if_calculated"
,
torch
.
tensor
(
0
))
# pylint: disable=not-callable
self
.
statistics_batch_num
=
statistics_batch_num
self
.
collected_activation
=
{}
self
.
hooks
=
{}
...
...
@@ -48,22 +48,29 @@ class ActivationRankFilterPruner(Pruner):
"""
Compress the model, register a hook for collecting activations.
"""
if
self
.
modules_wrapper
is
not
None
:
# already compressed
return
self
.
bound_model
else
:
self
.
modules_wrapper
=
[]
modules_to_compress
=
self
.
detect_modules_to_compress
()
for
layer
,
config
in
modules_to_compress
:
self
.
_instrument_layer
(
layer
,
config
)
wrapper
=
self
.
_wrap_modules
(
layer
,
config
)
self
.
modules_wrapper
.
append
(
wrapper
)
self
.
collected_activation
[
layer
.
name
]
=
[]
def
_hook
(
module_
,
input_
,
output
,
name
=
layer
.
name
):
if
len
(
self
.
collected_activation
[
name
])
<
self
.
statistics_batch_num
:
self
.
collected_activation
[
name
].
append
(
self
.
activation
(
output
.
detach
().
cpu
()))
layer
.
module
.
register_forward_hook
(
_hook
)
wrapper
.
module
.
register_forward_hook
(
_hook
)
self
.
_wrap_model
()
return
self
.
bound_model
def
get_mask
(
self
,
base_mask
,
activations
,
num_prune
):
raise
NotImplementedError
(
'{} get_mask is not implemented'
.
format
(
self
.
__class__
.
__name__
))
def
calc_mask
(
self
,
layer
,
config
):
def
calc_mask
(
self
,
layer
,
config
,
**
kwargs
):
"""
Calculate the mask of given layer.
Filters with the smallest importance criterion which is calculated from the activation are masked.
...
...
@@ -82,14 +89,13 @@ class ActivationRankFilterPruner(Pruner):
"""
weight
=
layer
.
module
.
weight
.
data
op_name
=
layer
.
name
op_type
=
layer
.
type
assert
0
<=
config
.
get
(
'sparsity'
)
<
1
,
"sparsity must in the range [0, 1)"
assert
op_type
in
[
'Conv2d'
],
"only support Conv2d"
assert
op_type
in
config
.
get
(
'op_types'
)
if
op_name
in
self
.
mask
_calculated
_ops
:
assert
op_name
in
self
.
mask_dict
return
self
.
mask_dict
.
get
(
op_name
)
if
_calculated
=
kwargs
[
"if
_calculated
"
]
if
if_calculated
:
return
None
mask_weight
=
torch
.
ones
(
weight
.
size
()).
type_as
(
weight
).
detach
()
if
hasattr
(
layer
.
module
,
'bias'
)
and
layer
.
module
.
bias
is
not
None
:
mask_bias
=
torch
.
ones
(
layer
.
module
.
bias
.
size
()).
type_as
(
layer
.
module
.
bias
).
detach
()
...
...
@@ -104,8 +110,7 @@ class ActivationRankFilterPruner(Pruner):
mask
=
self
.
get_mask
(
mask
,
self
.
collected_activation
[
layer
.
name
],
num_prune
)
finally
:
if
len
(
self
.
collected_activation
[
layer
.
name
])
==
self
.
statistics_batch_num
:
self
.
mask_dict
.
update
({
op_name
:
mask
})
self
.
mask_calculated_ops
.
add
(
op_name
)
if_calculated
.
copy_
(
torch
.
tensor
(
1
))
# pylint: disable=not-callable
return
mask
...
...
src/sdk/pynni/nni/compression/torch/compressor.py
View file @
ed121315
...
...
@@ -14,8 +14,11 @@ class LayerInfo:
self
.
name
=
name
self
.
type
=
type
(
module
).
__name__
self
.
_forward
=
None
def
_setattr
(
model
,
name
,
module
):
name_list
=
name
.
split
(
"."
)
for
name
in
name_list
[:
-
1
]:
model
=
getattr
(
model
,
name
)
setattr
(
model
,
name_list
[
-
1
],
module
)
class
Compressor
:
"""
...
...
@@ -36,6 +39,9 @@ class Compressor:
self
.
bound_model
=
model
self
.
config_list
=
config_list
self
.
modules_to_compress
=
None
self
.
modules_wrapper
=
None
self
.
buffers
=
{}
self
.
is_wrapped
=
False
def
detect_modules_to_compress
(
self
):
"""
...
...
@@ -51,21 +57,60 @@ class Compressor:
self
.
modules_to_compress
.
append
((
layer
,
config
))
return
self
.
modules_to_compress
def
_wrap_model
(
self
):
"""
wrap all modules that needed to be compressed
"""
for
wrapper
in
reversed
(
self
.
get_modules_wrapper
()):
_setattr
(
self
.
bound_model
,
wrapper
.
name
,
wrapper
)
self
.
is_wrapped
=
True
def
_unwrap_model
(
self
):
"""
unwrap all modules that needed to be compressed
"""
for
wrapper
in
self
.
get_modules_wrapper
():
_setattr
(
self
.
bound_model
,
wrapper
.
name
,
wrapper
.
module
)
self
.
is_wrapped
=
False
def
compress
(
self
):
"""
Compress the model with algorithm implemented by subclass.
The model will be instrumented and user should never edit it after calling this method.
`self.modules_to_compress` records all the to-be-compressed layers
Returns
-------
torch.nn.Module
model with specified modules compressed.
"""
if
self
.
modules_wrapper
is
not
None
:
# already compressed
return
self
.
bound_model
else
:
self
.
modules_wrapper
=
[]
modules_to_compress
=
self
.
detect_modules_to_compress
()
for
layer
,
config
in
modules_to_compress
:
self
.
_instrument_layer
(
layer
,
config
)
wrapper
=
self
.
_wrap_modules
(
layer
,
config
)
self
.
modules_wrapper
.
append
(
wrapper
)
self
.
_wrap_model
()
return
self
.
bound_model
def
register_buffer
(
self
,
name
,
value
):
"""
To register buffers used in wrapped module's forward method.
"""
self
.
buffers
[
name
]
=
value
def
get_modules_to_compress
(
self
):
"""
To obtain all the to-be-compressed
layer
s.
To obtain all the to-be-compressed
module
s.
Returns
-------
...
...
@@ -75,6 +120,17 @@ class Compressor:
"""
return
self
.
modules_to_compress
def
get_modules_wrapper
(
self
):
"""
To obtain all the wrapped modules.
Returns
-------
list
a list of the wrapped modules
"""
return
self
.
modules_wrapper
def
select_config
(
self
,
layer
):
"""
Find the configuration for `layer` by parsing `self.config_list`
...
...
@@ -119,7 +175,7 @@ class Compressor:
If user want to update model every step, user can override this method
"""
def
_
instrument_layer
(
self
,
layer
,
config
):
def
_
wrap_modules
(
self
,
layer
,
config
):
"""
This method is implemented in the subclasses, i.e., `Pruner` and `Quantizer`
...
...
@@ -143,6 +199,59 @@ class Compressor:
expanded_op_types
.
append
(
op_type
)
return
expanded_op_types
class
PrunerModuleWrapper
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
module
,
module_name
,
module_type
,
config
,
pruner
):
"""
Wrap an module to enable data parallel, forward method customization and buffer registeration.
Parameters
----------
module : pytorch module
the module user wants to compress
config : dict
the configurations that users specify for compression
module_name : str
the name of the module to compress, wrapper module shares same name
module_type : str
the type of the module to compress
pruner : Pruner
the pruner used to calculate mask
"""
super
().
__init__
()
# origin layer information
self
.
module
=
module
self
.
name
=
module_name
self
.
type
=
module_type
# config and pruner
self
.
config
=
config
self
.
pruner
=
pruner
self
.
registered_buffers
=
{}
# register buffer for mask
self
.
register_buffer
(
"weight_mask"
,
torch
.
ones
(
self
.
module
.
weight
.
shape
))
self
.
registered_buffers
[
'weight_mask'
]
=
self
.
weight_mask
if
hasattr
(
self
.
module
,
'bias'
)
and
self
.
module
.
bias
is
not
None
:
self
.
register_buffer
(
"bias_mask"
,
torch
.
ones
(
self
.
module
.
bias
.
shape
))
else
:
self
.
register_buffer
(
"bias_mask"
,
None
)
self
.
registered_buffers
[
'bias_mask'
]
=
self
.
bias_mask
# register user specified buffer
for
name
in
self
.
pruner
.
buffers
:
self
.
register_buffer
(
name
,
self
.
pruner
.
buffers
[
name
].
clone
())
self
.
registered_buffers
[
name
]
=
getattr
(
self
,
name
)
def
forward
(
self
,
*
inputs
):
mask
=
self
.
pruner
.
calc_mask
(
LayerInfo
(
self
.
name
,
self
.
module
),
self
.
config
,
**
self
.
registered_buffers
)
if
mask
is
not
None
:
self
.
weight_mask
.
copy_
(
mask
[
'weight'
])
# apply mask to weight
self
.
module
.
weight
.
data
=
self
.
module
.
weight
.
data
.
mul_
(
self
.
weight_mask
)
# apply mask to bias
if
hasattr
(
self
.
module
,
'bias'
)
and
self
.
module
.
bias
is
not
None
:
if
mask
is
not
None
and
'bias'
in
mask
:
self
.
bias_mask
.
copy_
(
mask
[
'bias'
])
self
.
module
.
bias
.
data
=
self
.
module
.
bias
.
data
.
mul_
(
self
.
bias_mask
)
return
self
.
module
(
*
inputs
)
class
Pruner
(
Compressor
):
"""
...
...
@@ -158,9 +267,8 @@ class Pruner(Compressor):
def
__init__
(
self
,
model
,
config_list
):
super
().
__init__
(
model
,
config_list
)
self
.
mask_dict
=
{}
def
calc_mask
(
self
,
layer
,
config
):
def
calc_mask
(
self
,
layer
,
config
,
**
kwargs
):
"""
Pruners should overload this method to provide mask for weight tensors.
The mask must have the same shape and type comparing to the weight.
...
...
@@ -176,9 +284,9 @@ class Pruner(Compressor):
"""
raise
NotImplementedError
(
"Pruners must overload calc_mask()"
)
def
_
instrument_layer
(
self
,
layer
,
config
):
def
_
wrap_modules
(
self
,
layer
,
config
):
"""
Create a wrapper
forward function
to replace the original one.
Create a wrapper
module
to replace the original one.
Parameters
----------
...
...
@@ -187,30 +295,13 @@ class Pruner(Compressor):
config : dict
the configuration for generating the mask
"""
assert
layer
.
_forward
is
None
,
'Each model can only be compressed once'
if
not
_check_weight
(
layer
.
module
):
_logger
.
warning
(
'Module %s does not have parameter "weight"'
,
layer
.
name
)
return
layer
.
_forward
=
layer
.
module
.
forward
def
new_forward
(
*
inputs
):
mask
=
self
.
calc_mask
(
layer
,
config
)
# apply mask to weight
old_weight
=
layer
.
module
.
weight
.
data
mask_weight
=
mask
[
'weight'
]
layer
.
module
.
weight
.
data
=
old_weight
.
mul
(
mask_weight
)
# apply mask to bias
if
mask
.
__contains__
(
'bias'
)
and
hasattr
(
layer
.
module
,
'bias'
)
and
layer
.
module
.
bias
is
not
None
:
old_bias
=
layer
.
module
.
bias
.
data
mask_bias
=
mask
[
'bias'
]
layer
.
module
.
bias
.
data
=
old_bias
.
mul
(
mask_bias
)
# calculate forward
ret
=
layer
.
_forward
(
*
inputs
)
return
ret
layer
.
module
.
forward
=
new_forward
_logger
.
info
(
"compressing module %s."
,
layer
.
name
)
wrapper
=
PrunerModuleWrapper
(
layer
.
module
,
layer
.
name
,
layer
.
type
,
config
,
self
)
assert
hasattr
(
layer
.
module
,
'weight'
)
wrapper
.
to
(
layer
.
module
.
weight
.
device
)
return
wrapper
def
export_model
(
self
,
model_path
,
mask_path
=
None
,
onnx_path
=
None
,
input_shape
=
None
):
def
export_model
(
self
,
model_path
,
mask_path
=
None
,
onnx_path
=
None
,
input_shape
=
None
,
device
=
None
):
"""
Export pruned model weights, masks and onnx model(optional)
...
...
@@ -224,35 +315,138 @@ class Pruner(Compressor):
(optional) path to save onnx model
input_shape : list or tuple
input shape to onnx model
device : torch.device
device of the model, used to place the dummy input tensor for exporting onnx file.
the tensor is placed on cpu if ```device``` is None
"""
if
self
.
detect_modules_to_compress
()
and
not
self
.
mask_dict
:
_logger
.
warning
(
'You may not use self.mask_dict in base Pruner class to record masks'
)
#
if self.detect_modules_to_compress() and not self.mask_dict:
#
_logger.warning('You may not use self.mask_dict in base Pruner class to record masks')
assert
model_path
is
not
None
,
'model_path must be specified'
for
name
,
m
in
self
.
bound_model
.
named_modules
():
if
name
==
""
:
continue
masks
=
self
.
mask_dict
.
get
(
name
)
if
masks
is
not
None
:
mask_sum
=
masks
[
'weight'
].
sum
().
item
()
mask_num
=
masks
[
'weight'
].
numel
()
_logger
.
info
(
'Layer: %s Sparsity: %.2f'
,
name
,
1
-
mask_sum
/
mask_num
)
m
.
weight
.
data
=
m
.
weight
.
data
.
mul
(
masks
[
'weight'
])
if
masks
.
__contains__
(
'bias'
)
and
hasattr
(
m
,
'bias'
)
and
m
.
bias
is
not
None
:
m
.
bias
.
data
=
m
.
bias
.
data
.
mul
(
masks
[
'bias'
])
else
:
_logger
.
info
(
'Layer: %s NOT compressed'
,
name
)
mask_dict
=
{}
self
.
_unwrap_model
()
# used for generating correct state_dict name without wrapper state
for
wrapper
in
self
.
get_modules_wrapper
():
weight_mask
=
wrapper
.
weight_mask
bias_mask
=
wrapper
.
bias_mask
if
weight_mask
is
not
None
:
mask_sum
=
weight_mask
.
sum
().
item
()
mask_num
=
weight_mask
.
numel
()
_logger
.
info
(
'Layer: %s Sparsity: %.2f'
,
wrapper
.
name
,
1
-
mask_sum
/
mask_num
)
wrapper
.
module
.
weight
.
data
=
wrapper
.
module
.
weight
.
data
.
mul
(
weight_mask
)
if
bias_mask
is
not
None
:
wrapper
.
module
.
bias
.
data
=
wrapper
.
module
.
bias
.
data
.
mul
(
bias_mask
)
# save mask to dict
mask_dict
[
wrapper
.
name
]
=
{
"weight"
:
weight_mask
,
"bias"
:
bias_mask
}
torch
.
save
(
self
.
bound_model
.
state_dict
(),
model_path
)
_logger
.
info
(
'Model state_dict saved to %s'
,
model_path
)
if
mask_path
is
not
None
:
torch
.
save
(
self
.
mask_dict
,
mask_path
)
torch
.
save
(
mask_dict
,
mask_path
)
_logger
.
info
(
'Mask dict saved to %s'
,
mask_path
)
if
onnx_path
is
not
None
:
assert
input_shape
is
not
None
,
'input_shape must be specified to export onnx model'
# input info needed
if
device
is
None
:
device
=
torch
.
device
(
'cpu'
)
input_data
=
torch
.
Tensor
(
*
input_shape
)
torch
.
onnx
.
export
(
self
.
bound_model
,
input_data
,
onnx_path
)
torch
.
onnx
.
export
(
self
.
bound_model
,
input_data
.
to
(
device
)
,
onnx_path
)
_logger
.
info
(
'Model in onnx with input shape %s saved to %s'
,
input_data
.
shape
,
onnx_path
)
self
.
_wrap_model
()
def
load_model_state_dict
(
self
,
model_state
):
"""
Load the state dict saved from unwrapped model.
Parameters:
-----------
model_state : dict
state dict saved from unwrapped model
"""
if
self
.
is_wrapped
:
self
.
_unwrap_model
()
self
.
bound_model
.
load_state_dict
(
model_state
)
self
.
_wrap_model
()
else
:
self
.
bound_model
.
load_state_dict
(
model_state
)
class
QuantizerModuleWrapper
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
module
,
module_name
,
module_type
,
config
,
quantizer
):
"""
Wrap an module to enable data parallel, forward method customization and buffer registeration.
Parameters
----------
module : pytorch module
the module user wants to compress
config : dict
the configurations that users specify for compression
module_name : str
the name of the module to compress, wrapper module shares same name
module_type : str
the type of the module to compress
quantizer :quantizer
the quantizer used to calculate mask
"""
super
().
__init__
()
# origin layer information
self
.
module
=
module
self
.
name
=
module_name
self
.
type
=
module_type
# config and pruner
self
.
config
=
config
self
.
quantizer
=
quantizer
# register buffer and parameter
# old_weight is used to store origin weight and weight is used to store quantized weight
# the reason why weight is buffer instead of parameter is because in pytorch parameter is used as leaf
# if weight is leaf , then old_weight can not be updated.
if
'weight'
in
config
[
'quant_types'
]:
if
not
_check_weight
(
self
.
module
):
_logger
.
warning
(
'Module %s does not have parameter "weight"'
,
self
.
name
)
else
:
self
.
module
.
register_parameter
(
'old_weight'
,
torch
.
nn
.
Parameter
(
self
.
module
.
weight
))
delattr
(
self
.
module
,
'weight'
)
self
.
module
.
register_buffer
(
'weight'
,
self
.
module
.
old_weight
)
# register user specified buffer
self
.
registered_buffers
=
{}
for
name
in
self
.
quantizer
.
buffers
:
self
.
register_buffer
(
name
,
self
.
quantizer
.
buffers
[
name
].
clone
())
self
.
registered_buffers
[
name
]
=
getattr
(
self
,
name
)
def
forward
(
self
,
*
inputs
):
if
'input'
in
self
.
config
[
'quant_types'
]:
inputs
=
self
.
quantizer
.
quant_grad
.
apply
(
inputs
,
QuantType
.
QUANT_INPUT
,
self
.
quantizer
.
quantize_input
,
self
.
config
,
LayerInfo
(
self
.
name
,
self
.
module
),
**
self
.
registered_buffers
)
if
'weight'
in
self
.
config
[
'quant_types'
]
and
_check_weight
(
self
.
module
):
new_weight
=
self
.
quantizer
.
quant_grad
.
apply
(
self
.
module
.
old_weight
,
QuantType
.
QUANT_WEIGHT
,
self
.
quantizer
.
quantize_weight
,
self
.
config
,
LayerInfo
(
self
.
name
,
self
.
module
),
**
self
.
registered_buffers
)
self
.
module
.
weight
=
new_weight
result
=
self
.
module
(
*
inputs
)
else
:
result
=
self
.
module
(
*
inputs
)
if
'output'
in
self
.
config
[
'quant_types'
]:
result
=
self
.
quantizer
.
quant_grad
.
apply
(
result
,
QuantType
.
QUANT_OUTPUT
,
self
.
quantizer
.
quantize_output
,
self
.
config
,
LayerInfo
(
self
.
name
,
self
.
module
),
**
self
.
registered_buffers
)
return
result
class
Quantizer
(
Compressor
):
"""
...
...
@@ -303,7 +497,7 @@ class Quantizer(Compressor):
raise
NotImplementedError
(
'Quantizer must overload quantize_input()'
)
def
_
instrument_layer
(
self
,
layer
,
config
):
def
_
wrap_modules
(
self
,
layer
,
config
):
"""
Create a wrapper forward function to replace the original one.
Parameters
...
...
@@ -313,7 +507,6 @@ class Quantizer(Compressor):
config : dict
the configuration for quantization
"""
assert
layer
.
_forward
is
None
,
'Each model can only be compressed once'
assert
'quant_types'
in
config
,
'must provide quant_types in config'
assert
isinstance
(
config
[
'quant_types'
],
list
),
'quant_types must be list type'
assert
'quant_bits'
in
config
,
'must provide quant_bits in config'
...
...
@@ -323,35 +516,7 @@ class Quantizer(Compressor):
for
quant_type
in
config
[
'quant_types'
]:
assert
quant_type
in
config
[
'quant_bits'
],
'bits length for %s must be specified in quant_bits dict'
%
quant_type
if
'weight'
in
config
[
'quant_types'
]:
if
not
_check_weight
(
layer
.
module
):
_logger
.
warning
(
'Module %s does not have parameter "weight"'
,
layer
.
name
)
else
:
# old_weight is used to store origin weight and weight is used to store quantized weight
# the reason why weight is buffer instead of parameter is because in pytorch parameter is used as leaf
# if weight is leaf , then old_weight can not be updated.
layer
.
module
.
register_parameter
(
'old_weight'
,
torch
.
nn
.
Parameter
(
layer
.
module
.
weight
))
delattr
(
layer
.
module
,
'weight'
)
layer
.
module
.
register_buffer
(
'weight'
,
layer
.
module
.
old_weight
)
layer
.
_forward
=
layer
.
module
.
forward
def
new_forward
(
*
inputs
):
if
'input'
in
config
[
'quant_types'
]:
inputs
=
self
.
quant_grad
.
apply
(
inputs
,
QuantType
.
QUANT_INPUT
,
self
.
quantize_input
,
config
,
layer
)
if
'weight'
in
config
[
'quant_types'
]
and
_check_weight
(
layer
.
module
):
new_weight
=
self
.
quant_grad
.
apply
(
layer
.
module
.
old_weight
,
QuantType
.
QUANT_WEIGHT
,
self
.
quantize_weight
,
config
,
layer
)
layer
.
module
.
weight
=
new_weight
result
=
layer
.
_forward
(
*
inputs
)
else
:
result
=
layer
.
_forward
(
*
inputs
)
if
'output'
in
config
[
'quant_types'
]:
result
=
self
.
quant_grad
.
apply
(
result
,
QuantType
.
QUANT_OUTPUT
,
self
.
quantize_output
,
config
,
layer
)
return
result
layer
.
module
.
forward
=
new_forward
return
QuantizerModuleWrapper
(
layer
.
module
,
layer
.
name
,
layer
.
type
,
config
,
self
)
class
QuantType
:
"""
...
...
@@ -387,19 +552,18 @@ class QuantGrad(torch.autograd.Function):
return
grad_output
@
staticmethod
def
forward
(
ctx
,
tensor
,
quant_type
,
quant_func
,
config
,
layer
):
def
forward
(
ctx
,
tensor
,
quant_type
,
quant_func
,
config
,
layer
,
**
kwargs
):
ctx
.
save_for_backward
(
tensor
,
torch
.
Tensor
([
quant_type
]))
return
quant_func
(
tensor
,
config
,
op
=
layer
.
module
,
op_type
=
layer
.
type
,
op_name
=
layer
.
name
)
return
quant_func
(
tensor
,
config
,
op
=
layer
.
module
,
op_type
=
layer
.
type
,
op_name
=
layer
.
name
,
**
kwargs
)
@
classmethod
def
backward
(
cls
,
ctx
,
grad_output
):
tensor
,
quant_type
=
ctx
.
saved_variables
output
=
cls
.
quant_backward
(
tensor
,
grad_output
,
quant_type
)
return
output
,
None
,
None
,
None
,
None
return
output
,
None
,
None
,
None
,
None
,
None
def
_check_weight
(
module
):
try
:
return
isinstance
(
module
.
weight
.
data
,
torch
.
Tensor
)
except
AttributeError
:
return
False
\ No newline at end of file
src/sdk/pynni/nni/compression/torch/pruners.py
View file @
ed121315
...
...
@@ -27,9 +27,9 @@ class LevelPruner(Pruner):
"""
super
().
__init__
(
model
,
config_list
)
self
.
mask_calculated_ops
=
set
()
self
.
register_buffer
(
"if_calculated"
,
torch
.
tensor
(
0
))
# pylint: disable=not-callable
def
calc_mask
(
self
,
layer
,
config
):
def
calc_mask
(
self
,
layer
,
config
,
**
kwargs
):
"""
Calculate the mask of given layer
Parameters
...
...
@@ -45,8 +45,9 @@ class LevelPruner(Pruner):
"""
weight
=
layer
.
module
.
weight
.
data
op_name
=
layer
.
name
if
op_name
not
in
self
.
mask_calculated_ops
:
if_calculated
=
kwargs
[
"if_calculated"
]
if
not
if_calculated
:
w_abs
=
weight
.
abs
()
k
=
int
(
weight
.
numel
()
*
config
[
'sparsity'
])
if
k
==
0
:
...
...
@@ -54,12 +55,10 @@ class LevelPruner(Pruner):
threshold
=
torch
.
topk
(
w_abs
.
view
(
-
1
),
k
,
largest
=
False
)[
0
].
max
()
mask_weight
=
torch
.
gt
(
w_abs
,
threshold
).
type_as
(
weight
)
mask
=
{
'weight'
:
mask_weight
}
self
.
mask_dict
.
update
({
op_name
:
mask
})
self
.
mask_calculated_ops
.
add
(
op_name
)
else
:
assert
op_name
in
self
.
mask_dict
,
"op_name not in the mask_dict"
mask
=
self
.
mask_dict
[
op_name
]
if_calculated
.
copy_
(
torch
.
tensor
(
1
))
# pylint: disable=not-callable
return
mask
else
:
return
None
class
AGP_Pruner
(
Pruner
):
...
...
@@ -84,17 +83,20 @@ class AGP_Pruner(Pruner):
super
().
__init__
(
model
,
config_list
)
self
.
now_epoch
=
0
self
.
if_init_list
=
{}
self
.
register_buffer
(
"if_calculated"
,
torch
.
tensor
(
0
))
# pylint: disable=not-callable
def
calc_mask
(
self
,
layer
,
config
):
def
calc_mask
(
self
,
layer
,
config
,
**
kwargs
):
"""
Calculate the mask of given layer
Calculate the mask of given layer.
Scale factors with the smallest absolute value in the BN layer are masked.
Parameters
----------
layer : LayerInfo
the layer to instrument the compression operation
config : dict
layer's pruning config
kwargs: dict
buffers registered in __init__ function
Returns
-------
dict
...
...
@@ -102,12 +104,16 @@ class AGP_Pruner(Pruner):
"""
weight
=
layer
.
module
.
weight
.
data
op_name
=
layer
.
name
start_epoch
=
config
.
get
(
'start_epoch'
,
0
)
freq
=
config
.
get
(
'frequency'
,
1
)
if
self
.
now_epoch
>=
start_epoch
and
self
.
if_init_list
.
get
(
op_name
,
True
)
\
and
(
self
.
now_epoch
-
start_epoch
)
%
freq
==
0
:
mask
=
self
.
mask_dict
.
get
(
op_name
,
{
'weight'
:
torch
.
ones
(
weight
.
shape
).
type_as
(
weight
)})
if_calculated
=
kwargs
[
"if_calculated"
]
if
if_calculated
:
return
None
if
not
(
self
.
now_epoch
>=
start_epoch
and
(
self
.
now_epoch
-
start_epoch
)
%
freq
==
0
):
return
None
mask
=
{
'weight'
:
kwargs
[
'weight_mask'
]
if
'weight_mask'
in
kwargs
else
torch
.
ones
(
weight
.
shape
).
type_as
(
weight
)}
target_sparsity
=
self
.
compute_target_sparsity
(
config
)
k
=
int
(
weight
.
numel
()
*
target_sparsity
)
if
k
==
0
or
target_sparsity
>=
1
or
target_sparsity
<=
0
:
...
...
@@ -116,10 +122,8 @@ class AGP_Pruner(Pruner):
w_abs
=
weight
.
abs
()
*
mask
[
'weight'
]
threshold
=
torch
.
topk
(
w_abs
.
view
(
-
1
),
k
,
largest
=
False
)[
0
].
max
()
new_mask
=
{
'weight'
:
torch
.
gt
(
w_abs
,
threshold
).
type_as
(
weight
)}
self
.
mask_dict
.
update
({
op_name
:
new_mask
})
self
.
if_init_list
.
update
({
op_name
:
False
})
else
:
new_mask
=
self
.
mask_dict
.
get
(
op_name
,
{
'weight'
:
torch
.
ones
(
weight
.
shape
).
type_as
(
weight
)})
if_calculated
.
copy_
(
torch
.
tensor
(
1
))
# pylint: disable=not-callable
return
new_mask
def
compute_target_sparsity
(
self
,
config
):
...
...
@@ -165,9 +169,8 @@ class AGP_Pruner(Pruner):
if
epoch
>
0
:
self
.
now_epoch
=
epoch
for
k
in
self
.
if_init_list
.
keys
():
self
.
if_init_list
[
k
]
=
True
for
wrapper
in
self
.
get_modules_wrapper
():
wrapper
.
registered_buffers
[
'if_calculated'
].
copy_
(
torch
.
tensor
(
0
))
# pylint: disable=not-callable
class
SlimPruner
(
Pruner
):
"""
...
...
@@ -187,7 +190,6 @@ class SlimPruner(Pruner):
"""
super
().
__init__
(
model
,
config_list
)
self
.
mask_calculated_ops
=
set
()
weight_list
=
[]
if
len
(
config_list
)
>
1
:
logger
.
warning
(
'Slim pruner only supports 1 configuration'
)
...
...
@@ -198,8 +200,9 @@ class SlimPruner(Pruner):
all_bn_weights
=
torch
.
cat
(
weight_list
)
k
=
int
(
all_bn_weights
.
shape
[
0
]
*
config
[
'sparsity'
])
self
.
global_threshold
=
torch
.
topk
(
all_bn_weights
.
view
(
-
1
),
k
,
largest
=
False
)[
0
].
max
()
self
.
register_buffer
(
"if_calculated"
,
torch
.
tensor
(
0
))
# pylint: disable=not-callable
def
calc_mask
(
self
,
layer
,
config
):
def
calc_mask
(
self
,
layer
,
config
,
**
kwargs
):
"""
Calculate the mask of given layer.
Scale factors with the smallest absolute value in the BN layer are masked.
...
...
@@ -209,6 +212,8 @@ class SlimPruner(Pruner):
the layer to instrument the compression operation
config : dict
layer's pruning config
kwargs: dict
buffers registered in __init__ function
Returns
-------
dict
...
...
@@ -216,27 +221,21 @@ class SlimPruner(Pruner):
"""
weight
=
layer
.
module
.
weight
.
data
op_name
=
layer
.
name
op_type
=
layer
.
type
if_calculated
=
kwargs
[
"if_calculated"
]
assert
op_type
==
'BatchNorm2d'
,
'SlimPruner only supports 2d batch normalization layer pruning'
if
op_name
in
self
.
mask_calculated_ops
:
assert
op_name
in
self
.
mask_dict
return
self
.
mask_dict
.
get
(
op_name
)
if
if_calculated
:
return
None
base_mask
=
torch
.
ones
(
weight
.
size
()).
type_as
(
weight
).
detach
()
mask
=
{
'weight'
:
base_mask
.
detach
(),
'bias'
:
base_mask
.
clone
().
detach
()}
try
:
filters
=
weight
.
size
(
0
)
num_prune
=
int
(
filters
*
config
.
get
(
'sparsity'
))
if
filters
<
2
or
num_prune
<
1
:
return
mask
if
filters
>=
2
and
num_prune
>=
1
:
w_abs
=
weight
.
abs
()
mask_weight
=
torch
.
gt
(
w_abs
,
self
.
global_threshold
).
type_as
(
weight
)
mask_bias
=
mask_weight
.
clone
()
mask
=
{
'weight'
:
mask_weight
.
detach
(),
'bias'
:
mask_bias
.
detach
()}
finally
:
self
.
mask_dict
.
update
({
layer
.
name
:
mask
})
self
.
mask_calculated_ops
.
add
(
layer
.
name
)
if_calculated
.
copy_
(
torch
.
tensor
(
1
))
# pylint: disable=not-callable
return
mask
class
LotteryTicketPruner
(
Pruner
):
...
...
@@ -294,38 +293,23 @@ class LotteryTicketPruner(Pruner):
prune_iterations
=
config
[
'prune_iterations'
]
return
prune_iterations
def _print_masks(self, print_mask=False):
    """Print every layer's mask (optionally) and its current sparsity."""
    torch.set_printoptions(threshold=1000)
    for op_name, mask in self.mask_dict.items():
        print('op name: ', op_name)
        if print_mask:
            print('mask: ', mask)
        # current sparsity = fraction of zeroed weights
        kept = mask['weight'].sum().item()
        total = mask['weight'].numel()
        print('sparsity: ', 1 - kept / total)
    torch.set_printoptions(profile='default')
def
_calc_sparsity
(
self
,
sparsity
):
keep_ratio_once
=
(
1
-
sparsity
)
**
(
1
/
self
.
prune_iterations
)
curr_keep_ratio
=
keep_ratio_once
**
self
.
curr_prune_iteration
return
max
(
1
-
curr_keep_ratio
,
0
)
def
_calc_mask
(
self
,
weight
,
sparsity
,
op_name
):
def
_calc_mask
(
self
,
weight
,
sparsity
,
curr_w_mask
):
if
self
.
curr_prune_iteration
==
0
:
mask
=
torch
.
ones
(
weight
.
shape
).
type_as
(
weight
)
else
:
curr_sparsity
=
self
.
_calc_sparsity
(
sparsity
)
assert
self
.
mask_dict
.
get
(
op_name
)
is
not
None
curr_mask
=
self
.
mask_dict
.
get
(
op_name
)
w_abs
=
weight
.
abs
()
*
curr_mask
[
'weight'
]
w_abs
=
weight
.
abs
()
*
curr_w_mask
k
=
int
(
w_abs
.
numel
()
*
curr_sparsity
)
threshold
=
torch
.
topk
(
w_abs
.
view
(
-
1
),
k
,
largest
=
False
).
values
.
max
()
mask
=
torch
.
gt
(
w_abs
,
threshold
).
type_as
(
weight
)
return
{
'weight'
:
mask
}
def
calc_mask
(
self
,
layer
,
config
):
def
calc_mask
(
self
,
layer
,
config
,
**
kwargs
):
"""
Generate mask for the given ``weight``.
...
...
@@ -335,15 +319,17 @@ class LotteryTicketPruner(Pruner):
The layer to be pruned
config : dict
Pruning configurations for this weight
kwargs : dict
Auxiliary information
Returns
-------
tensor
The mask for this weight
The mask for this weight, it is ```None``` because this pruner
calculates and assigns masks in ```prune_iteration_start```,
no need to do anything in this function.
"""
assert
self
.
mask_dict
.
get
(
layer
.
name
)
is
not
None
,
'Please call iteration_start before training'
mask
=
self
.
mask_dict
[
layer
.
name
]
return
mask
return
None
def
get_prune_iterations
(
self
):
"""
...
...
@@ -368,16 +354,26 @@ class LotteryTicketPruner(Pruner):
self
.
curr_prune_iteration
+=
1
assert
self
.
curr_prune_iteration
<
self
.
prune_iterations
+
1
,
'Exceed the configured prune_iterations'
modules_wrapper
=
self
.
get_modules_wrapper
()
modules_to_compress
=
self
.
detect_modules_to_compress
()
for
layer
,
config
in
modules_to_compress
:
module_wrapper
=
None
for
wrapper
in
modules_wrapper
:
if
wrapper
.
name
==
layer
.
name
:
module_wrapper
=
wrapper
break
assert
module_wrapper
is
not
None
sparsity
=
config
.
get
(
'sparsity'
)
mask
=
self
.
_calc_mask
(
layer
.
module
.
weight
.
data
,
sparsity
,
layer
.
name
)
self
.
mask_dict
.
update
({
layer
.
name
:
mask
})
self
.
_print_masks
()
mask
=
self
.
_calc_mask
(
layer
.
module
.
weight
.
data
,
sparsity
,
module_wrapper
.
weight_mask
)
# TODO: directly use weight_mask is not good
module_wrapper
.
weight_mask
.
copy_
(
mask
[
'weight'
])
# there is no mask for bias
# reinit weights back to original after new masks are generated
if
self
.
reset_weights
:
self
.
_model
.
load_state_dict
(
self
.
_model_state
)
# should use this member function to reset model weights
self
.
load_model_state_dict
(
self
.
_model_state
)
self
.
_optimizer
.
load_state_dict
(
self
.
_optimizer_state
)
if
self
.
_lr_scheduler
is
not
None
:
self
.
_lr_scheduler
.
load_state_dict
(
self
.
_scheduler_state
)
src/sdk/pynni/nni/compression/torch/weight_rank_filter_pruners.py
View file @
ed121315
...
...
@@ -27,12 +27,12 @@ class WeightRankFilterPruner(Pruner):
"""
super
().
__init__
(
model
,
config_list
)
self
.
mask_calculated_ops
=
set
()
# operations whose mask has been calculated
self
.
register_buffer
(
"if_calculated"
,
torch
.
tensor
(
0
))
# pylint: disable=not-callable
def
get_mask
(
self
,
base_mask
,
weight
,
num_prune
):
raise
NotImplementedError
(
'{} get_mask is not implemented'
.
format
(
self
.
__class__
.
__name__
))
def
calc_mask
(
self
,
layer
,
config
):
def
calc_mask
(
self
,
layer
,
config
,
**
kwargs
):
"""
Calculate the mask of given layer.
Filters with the smallest importance criterion of the kernel weights are masked.
...
...
@@ -49,14 +49,13 @@ class WeightRankFilterPruner(Pruner):
"""
weight
=
layer
.
module
.
weight
.
data
op_name
=
layer
.
name
op_type
=
layer
.
type
assert
0
<=
config
.
get
(
'sparsity'
)
<
1
,
"sparsity must in the range [0, 1)"
assert
op_type
in
[
'Conv1d'
,
'Conv2d'
],
"only support Conv1d and Conv2d"
assert
op_type
in
config
.
get
(
'op_types'
)
if
op_name
in
self
.
mask
_calculated
_ops
:
assert
op_name
in
self
.
mask_dict
return
self
.
mask_dict
.
get
(
op_name
)
if
_calculated
=
kwargs
[
"if
_calculated
"
]
if
if_calculated
:
return
None
mask_weight
=
torch
.
ones
(
weight
.
size
()).
type_as
(
weight
).
detach
()
if
hasattr
(
layer
.
module
,
'bias'
)
and
layer
.
module
.
bias
is
not
None
:
mask_bias
=
torch
.
ones
(
layer
.
module
.
bias
.
size
()).
type_as
(
layer
.
module
.
bias
).
detach
()
...
...
@@ -70,8 +69,7 @@ class WeightRankFilterPruner(Pruner):
return
mask
mask
=
self
.
get_mask
(
mask
,
weight
,
num_prune
)
finally
:
self
.
mask_dict
.
update
({
op_name
:
mask
})
self
.
mask_calculated_ops
.
add
(
op_name
)
if_calculated
.
copy_
(
torch
.
tensor
(
1
))
# pylint: disable=not-callable
return
mask
...
...
@@ -259,4 +257,5 @@ class FPGMPruner(WeightRankFilterPruner):
return
x
.
sum
()
def update_epoch(self, epoch):
    """Reset every wrapper's ``if_calculated`` flag so masks are recomputed.

    Parameters
    ----------
    epoch : int
        Current epoch number; unused, kept for interface compatibility.
    """
    for wrapper in self.get_modules_wrapper():
        # 0 marks the mask as stale; calc_mask will recompute it.
        wrapper.registered_buffers['if_calculated'].copy_(torch.tensor(0))  # pylint: disable=not-callable
src/sdk/pynni/tests/test_compressor.py
View file @
ed121315
...
...
@@ -135,12 +135,11 @@ class CompressorTestCase(TestCase):
model
.
conv2
.
weight
.
data
=
torch
.
tensor
(
w
).
float
()
layer
=
torch_compressor
.
compressor
.
LayerInfo
(
'conv2'
,
model
.
conv2
)
masks
=
pruner
.
calc_mask
(
layer
,
config_list
[
0
])
masks
=
pruner
.
calc_mask
(
layer
,
config_list
[
0
]
,
if_calculated
=
torch
.
tensor
(
0
)
)
assert
all
(
torch
.
sum
(
masks
[
'weight'
],
(
1
,
2
,
3
)).
numpy
()
==
np
.
array
([
45.
,
45.
,
45.
,
45.
,
0.
,
0.
,
45.
,
45.
,
45.
,
45.
]))
pruner
.
update_epoch
(
1
)
model
.
conv2
.
weight
.
data
=
torch
.
tensor
(
w
).
float
()
masks
=
pruner
.
calc_mask
(
layer
,
config_list
[
1
])
masks
=
pruner
.
calc_mask
(
layer
,
config_list
[
1
]
,
if_calculated
=
torch
.
tensor
(
0
)
)
assert
all
(
torch
.
sum
(
masks
[
'weight'
],
(
1
,
2
,
3
)).
numpy
()
==
np
.
array
([
45.
,
45.
,
0.
,
0.
,
0.
,
0.
,
0.
,
0.
,
45.
,
45.
]))
@
tf2
...
...
@@ -159,7 +158,6 @@ class CompressorTestCase(TestCase):
assert
all
(
masks
.
sum
((
1
))
==
np
.
array
([
45.
,
45.
,
45.
,
45.
,
0.
,
0.
,
45.
,
45.
,
45.
,
45.
]))
pruner
.
update_epoch
(
1
)
model
.
layers
[
2
].
set_weights
([
weights
[
0
],
weights
[
1
].
numpy
()])
masks
=
pruner
.
calc_mask
(
layer
,
config_list
[
1
]).
numpy
()
masks
=
masks
.
reshape
((
-
1
,
masks
.
shape
[
-
1
])).
transpose
([
1
,
0
])
...
...
@@ -187,9 +185,9 @@ class CompressorTestCase(TestCase):
model
.
conv1
.
weight
.
data
=
torch
.
tensor
(
w
).
float
()
model
.
conv2
.
weight
.
data
=
torch
.
tensor
(
w
).
float
()
layer1
=
torch_compressor
.
compressor
.
LayerInfo
(
'conv1'
,
model
.
conv1
)
mask1
=
pruner
.
calc_mask
(
layer1
,
config_list
[
0
])
mask1
=
pruner
.
calc_mask
(
layer1
,
config_list
[
0
]
,
if_calculated
=
torch
.
tensor
(
0
)
)
layer2
=
torch_compressor
.
compressor
.
LayerInfo
(
'conv2'
,
model
.
conv2
)
mask2
=
pruner
.
calc_mask
(
layer2
,
config_list
[
1
])
mask2
=
pruner
.
calc_mask
(
layer2
,
config_list
[
1
]
,
if_calculated
=
torch
.
tensor
(
0
)
)
assert
all
(
torch
.
sum
(
mask1
[
'weight'
],
(
1
,
2
,
3
)).
numpy
()
==
np
.
array
([
0.
,
27.
,
27.
,
27.
,
27.
]))
assert
all
(
torch
.
sum
(
mask2
[
'weight'
],
(
1
,
2
,
3
)).
numpy
()
==
np
.
array
([
0.
,
0.
,
0.
,
27.
,
27.
]))
...
...
@@ -215,9 +213,9 @@ class CompressorTestCase(TestCase):
pruner
=
torch_compressor
.
SlimPruner
(
model
,
config_list
)
layer1
=
torch_compressor
.
compressor
.
LayerInfo
(
'bn1'
,
model
.
bn1
)
mask1
=
pruner
.
calc_mask
(
layer1
,
config_list
[
0
])
mask1
=
pruner
.
calc_mask
(
layer1
,
config_list
[
0
]
,
if_calculated
=
torch
.
tensor
(
0
)
)
layer2
=
torch_compressor
.
compressor
.
LayerInfo
(
'bn2'
,
model
.
bn2
)
mask2
=
pruner
.
calc_mask
(
layer2
,
config_list
[
0
])
mask2
=
pruner
.
calc_mask
(
layer2
,
config_list
[
0
]
,
if_calculated
=
torch
.
tensor
(
0
)
)
assert
all
(
mask1
[
'weight'
].
numpy
()
==
np
.
array
([
0.
,
1.
,
1.
,
1.
,
1.
]))
assert
all
(
mask2
[
'weight'
].
numpy
()
==
np
.
array
([
0.
,
1.
,
1.
,
1.
,
1.
]))
assert
all
(
mask1
[
'bias'
].
numpy
()
==
np
.
array
([
0.
,
1.
,
1.
,
1.
,
1.
]))
...
...
@@ -229,9 +227,9 @@ class CompressorTestCase(TestCase):
pruner
=
torch_compressor
.
SlimPruner
(
model
,
config_list
)
layer1
=
torch_compressor
.
compressor
.
LayerInfo
(
'bn1'
,
model
.
bn1
)
mask1
=
pruner
.
calc_mask
(
layer1
,
config_list
[
0
])
mask1
=
pruner
.
calc_mask
(
layer1
,
config_list
[
0
]
,
if_calculated
=
torch
.
tensor
(
0
)
)
layer2
=
torch_compressor
.
compressor
.
LayerInfo
(
'bn2'
,
model
.
bn2
)
mask2
=
pruner
.
calc_mask
(
layer2
,
config_list
[
0
])
mask2
=
pruner
.
calc_mask
(
layer2
,
config_list
[
0
]
,
if_calculated
=
torch
.
tensor
(
0
)
)
assert
all
(
mask1
[
'weight'
].
numpy
()
==
np
.
array
([
0.
,
0.
,
0.
,
1.
,
1.
]))
assert
all
(
mask2
[
'weight'
].
numpy
()
==
np
.
array
([
0.
,
0.
,
0.
,
1.
,
1.
]))
assert
all
(
mask1
[
'bias'
].
numpy
()
==
np
.
array
([
0.
,
0.
,
0.
,
1.
,
1.
]))
...
...
@@ -268,14 +266,14 @@ class CompressorTestCase(TestCase):
# test ema
x
=
torch
.
tensor
([[
-
0.2
,
0
],
[
0.1
,
0.2
]])
out
=
model
.
relu
(
x
)
assert
math
.
isclose
(
model
.
relu
.
tracked_min_biased
,
0
,
abs_tol
=
eps
)
assert
math
.
isclose
(
model
.
relu
.
tracked_max_biased
,
0.002
,
abs_tol
=
eps
)
assert
math
.
isclose
(
model
.
relu
.
module
.
tracked_min_biased
,
0
,
abs_tol
=
eps
)
assert
math
.
isclose
(
model
.
relu
.
module
.
tracked_max_biased
,
0.002
,
abs_tol
=
eps
)
quantizer
.
step
()
x
=
torch
.
tensor
([[
0.2
,
0.4
],
[
0.6
,
0.8
]])
out
=
model
.
relu
(
x
)
assert
math
.
isclose
(
model
.
relu
.
tracked_min_biased
,
0.002
,
abs_tol
=
eps
)
assert
math
.
isclose
(
model
.
relu
.
tracked_max_biased
,
0.00998
,
abs_tol
=
eps
)
assert
math
.
isclose
(
model
.
relu
.
module
.
tracked_min_biased
,
0.002
,
abs_tol
=
eps
)
assert
math
.
isclose
(
model
.
relu
.
module
.
tracked_max_biased
,
0.00998
,
abs_tol
=
eps
)
if
__name__
==
'__main__'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment