OpenDAS / nni · Commits

Commit ed121315
Authored Feb 10, 2020 by chicm-ms; committed by GitHub, Feb 10, 2020

Merge pull request #2022 from microsoft/dev-pruner-dataparallel

Dev pruner DataParallel

Parents: c7187946, 8092c8bd
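In short, this merge replaces the hook-based layer instrumentation with torch.nn.Module wrappers (PrunerModuleWrapper / QuantizerModuleWrapper) whose masks and auxiliary state live in registered buffers, which is what allows a compressed model to be wrapped in torch.nn.DataParallel. A minimal sketch of the resulting workflow, assuming a user-defined MyModel (a placeholder, not part of this diff) and a SlimPruner config like those in the examples below:

import torch
import torch.nn as nn
from nni.compression.torch import SlimPruner

model = MyModel()  # hypothetical user-defined torch.nn.Module
configure_list = [{'sparsity': 0.7, 'op_types': ['BatchNorm2d']}]

pruner = SlimPruner(model, configure_list)
model = pruner.compress()          # target submodules are now PrunerModuleWrapper instances

# Masks are module buffers, so DataParallel can replicate them across GPU replicas.
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model.to(torch.device('cuda'))

# ... train or fine-tune as usual, then export weights and masks:
pruner.export_model('model.pth', 'mask.pth')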
Changes: 10 changed files with 523 additions and 222 deletions (+523 −222)
examples/model_compress/MeanActivation_torch_cifar10.py                  +26 −10
examples/model_compress/fpgm_torch_mnist.py                              +13 −6
examples/model_compress/lottery_torch_mnist_fc.py                         +2 −0
examples/model_compress/multi_gpu.py (new file)                         +101 −0
examples/model_compress/slim_torch_cifar10.py                            +29 −16
src/sdk/pynni/nni/compression/torch/activation_rank_filter_pruners.py    +15 −10
src/sdk/pynni/nni/compression/torch/compressor.py                       +249 −85
src/sdk/pynni/nni/compression/torch/pruners.py                           +68 −72
src/sdk/pynni/nni/compression/torch/weight_rank_filter_pruners.py         +8 −9
src/sdk/pynni/tests/test_compressor.py                                   +12 −14
examples/model_compress/MeanActivation_torch_cifar10.py (+26 −10)

 import math
+import argparse
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torchvision import datasets, transforms
-from nni.compression.torch import L1FilterPruner
+from nni.compression.torch import ActivationMeanRankFilterPruner
 from models.cifar10.vgg import VGG
...
@@ -40,6 +41,12 @@ def test(model, device, test_loader):
 def main():
+    parser = argparse.ArgumentParser("multiple gpu with pruning")
+    parser.add_argument("--epochs", type=int, default=160)
+    parser.add_argument("--retrain", default=False, action="store_true")
+    parser.add_argument("--parallel", default=False, action="store_true")
+    args = parser.parse_args()
     torch.manual_seed(0)
     device = torch.device('cuda')
     train_loader = torch.utils.data.DataLoader(
...
@@ -63,14 +70,15 @@ def main():
     model.to(device)

     # Train the base VGG-16 model
-    print('='*10 + 'Train the unpruned base model' + '='*10)
-    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
-    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 160, 0)
-    for epoch in range(160):
-        train(model, device, train_loader, optimizer)
-        test(model, device, test_loader)
-        lr_scheduler.step(epoch)
-    torch.save(model.state_dict(), 'vgg16_cifar10.pth')
+    if args.retrain:
+        print('='*10 + 'Train the unpruned base model' + '='*10)
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
+        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 160, 0)
+        for epoch in range(args.epochs):
+            train(model, device, train_loader, optimizer)
+            test(model, device, test_loader)
+            lr_scheduler.step(epoch)
+        torch.save(model.state_dict(), 'vgg16_cifar10.pth')

     # Test base model accuracy
     print('='*10 + 'Test on the original model' + '='*10)
...
@@ -88,8 +96,16 @@ def main():
     # Prune model and test accuracy without fine tuning.
     print('='*10 + 'Test on the pruned model before fine tune' + '='*10)
-    pruner = L1FilterPruner(model, configure_list)
+    pruner = ActivationMeanRankFilterPruner(model, configure_list)
     model = pruner.compress()
+    if args.parallel:
+        if torch.cuda.device_count() > 1:
+            print("use {} gpus for pruning".format(torch.cuda.device_count()))
+            model = nn.DataParallel(model)
+        else:
+            print("only detect 1 gpu, fall back")
+    model.to(device)
     test(model, device, test_loader)
     # top1 = 88.19%
...
examples/model_compress/fpgm_torch_mnist.py (+13 −6)

 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from torchvision import datasets, transforms
 from nni.compression.torch import FPGMPruner
...
@@ -6,10 +7,10 @@ from nni.compression.torch import FPGMPruner
 class Mnist(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
-        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
-        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
-        self.fc2 = torch.nn.Linear(500, 10)
+        self.conv1 = nn.Conv2d(1, 20, 5, 1)
+        self.conv2 = nn.Conv2d(20, 50, 5, 1)
+        self.fc1 = nn.Linear(4 * 4 * 50, 500)
+        self.fc2 = nn.Linear(500, 10)

     def forward(self, x):
         x = F.relu(self.conv1(x))
...
@@ -27,8 +28,14 @@ class Mnist(torch.nn.Module):
         return num_zero_filters, num_filters, float(num_zero_filters) / num_filters

     def print_conv_filter_sparsity(self):
-        conv1_data = self._get_conv_weight_sparsity(self.conv1)
-        conv2_data = self._get_conv_weight_sparsity(self.conv2)
+        if isinstance(self.conv1, nn.Conv2d):
+            conv1_data = self._get_conv_weight_sparsity(self.conv1)
+            conv2_data = self._get_conv_weight_sparsity(self.conv2)
+        else:
+            # self.conv1 is wrapped as PrunerModuleWrapper
+            conv1_data = self._get_conv_weight_sparsity(self.conv1.module)
+            conv2_data = self._get_conv_weight_sparsity(self.conv2.module)

         print('conv1: num zero filters: {}, num filters: {}, sparsity: {:.4f}'.format(conv1_data[0], conv1_data[1], conv1_data[2]))
         print('conv2: num zero filters: {}, num filters: {}, sparsity: {:.4f}'.format(conv2_data[0], conv2_data[1], conv2_data[2]))
...
examples/model_compress/lottery_torch_mnist_fc.py (+2 −0)

...
@@ -71,6 +71,8 @@ if __name__ == '__main__':
     pruner = LotteryTicketPruner(model, configure_list, optimizer)
     pruner.compress()
+    #model = nn.DataParallel(model)
+
     for i in pruner.get_prune_iterations():
         pruner.prune_iteration_start()
         loss = 0
...
examples/model_compress/multi_gpu.py (new file, +101 −0)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from nni.compression.torch import SlimPruner

class fc1(nn.Module):

    def __init__(self, num_classes=10):
        super(fc1, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU(inplace=True)
        self.linear1 = nn.Linear(32*28*28, 300)
        self.relu2 = nn.ReLU(inplace=True)
        self.linear2 = nn.Linear(300, 100)
        self.relu3 = nn.ReLU(inplace=True)
        self.linear3 = nn.Linear(100, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = torch.flatten(x, 1)
        x = self.linear1(x)
        x = self.relu2(x)
        x = self.linear2(x)
        x = self.relu3(x)
        x = self.linear3(x)
        return x

def train(model, train_loader, optimizer, criterion, device):
    model.train()
    for imgs, targets in train_loader:
        optimizer.zero_grad()
        imgs, targets = imgs.to(device), targets.to(device)
        output = model(imgs)
        train_loss = criterion(output, targets)
        train_loss.backward()
        optimizer.step()
    return train_loss.item()

def test(model, test_loader, criterion, device):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    return accuracy

if __name__ == '__main__':
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    traindataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    testdataset = datasets.MNIST('./data', train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(traindataset, batch_size=60, shuffle=True, num_workers=10, drop_last=False)
    test_loader = torch.utils.data.DataLoader(testdataset, batch_size=60, shuffle=False, num_workers=10, drop_last=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = fc1()
    criterion = nn.CrossEntropyLoss()

    configure_list = [{
        'prune_iterations': 5,
        'sparsity': 0.86,
        'op_types': ['BatchNorm2d']
    }]
    pruner = SlimPruner(model, configure_list)
    pruner.compress()

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1.2e-3)

    for name, par in model.named_parameters():
        print(name)

    # for i in pruner.get_prune_iterations():
    #     pruner.prune_iteration_start()
    loss = 0
    accuracy = 0
    for epoch in range(10):
        loss = train(model, train_loader, optimizer, criterion, device)
        accuracy = test(model, test_loader, criterion, device)
        print('current epoch: {0}, loss: {1}, accuracy: {2}'.format(epoch, loss, accuracy))
        # print('prune iteration: {0}, loss: {1}, accuracy: {2}'.format(i, loss, accuracy))
    pruner.export_model('model.pth', 'mask.pth')
\ No newline at end of file
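Note the ordering in this example: pruner.compress() wraps the target modules first, and nn.DataParallel is applied afterwards, so the mask buffers inside each wrapper are replicated to every GPU along with the rest of the module state.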
examples/model_compress/slim_torch_cifar10.py (+29 −16)

 import math
+import argparse
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
...
@@ -6,7 +7,6 @@ from torchvision import datasets, transforms
 from nni.compression.torch import SlimPruner
 from models.cifar10.vgg import VGG

-
 def updateBN(model):
     for m in model.modules():
         if isinstance(m, nn.BatchNorm2d):
...
@@ -49,6 +49,13 @@ def test(model, device, test_loader):
 def main():
+    parser = argparse.ArgumentParser("multiple gpu with pruning")
+    parser.add_argument("--epochs", type=int, default=160)
+    parser.add_argument("--retrain", default=False, action="store_true")
+    parser.add_argument("--parallel", default=False, action="store_true")
+    args = parser.parse_args()
     torch.manual_seed(0)
     device = torch.device('cuda')
     train_loader = torch.utils.data.DataLoader(
...
@@ -70,18 +77,19 @@ def main():
     model = VGG(depth=19)
     model.to(device)

     # Train the base VGG-19 model
-    print('='*10 + 'Train the unpruned base model' + '='*10)
-    epochs = 160
-    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
-    for epoch in range(epochs):
-        if epoch in [epochs * 0.5, epochs * 0.75]:
-            for param_group in optimizer.param_groups:
-                param_group['lr'] *= 0.1
-        train(model, device, train_loader, optimizer, True)
-        test(model, device, test_loader)
-    torch.save(model.state_dict(), 'vgg19_cifar10.pth')
+    if args.retrain:
+        print('='*10 + 'Train the unpruned base model' + '='*10)
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
+        epochs = args.epochs
+        for epoch in range(epochs):
+            if epoch in [epochs * 0.5, epochs * 0.75]:
+                for param_group in optimizer.param_groups:
+                    param_group['lr'] *= 0.1
+            train(model, device, train_loader, optimizer, True)
+            test(model, device, test_loader)
+            print("epoch {}".format(epoch))
+        torch.save(model.state_dict(), 'vgg19_cifar10.pth')

     # Test base model accuracy
     print('='*10 + 'Test the original model' + '='*10)
...
@@ -94,14 +102,19 @@ def main():
         'sparsity': 0.7,
         'op_types': ['BatchNorm2d'],
     }]

     # Prune model and test accuracy without fine tuning.
     print('='*10 + 'Test the pruned model before fine tune' + '='*10)
     pruner = SlimPruner(model, configure_list)
     model = pruner.compress()
-    test(model, device, test_loader)
-    # top1 = 93.55%
+    if args.parallel:
+        if torch.cuda.device_count() > 1:
+            print("use {} gpus for pruning".format(torch.cuda.device_count()))
+            model = nn.DataParallel(model)
+            # model = nn.DataParallel(model, device_ids=[0, 1])
+        else:
+            print("only detect 1 gpu, fall back")
+    model.to(device)

     # Fine tune the pruned model for 40 epochs and test accuracy
     print('='*10 + 'Fine tuning' + '='*10)
     optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
...
src/sdk/pynni/nni/compression/torch/activation_rank_filter_pruners.py (+15 −10)

...
@@ -32,7 +32,7 @@ class ActivationRankFilterPruner(Pruner):
         """
         super().__init__(model, config_list)
-        self.mask_calculated_ops = set()
+        self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
         self.statistics_batch_num = statistics_batch_num
         self.collected_activation = {}
         self.hooks = {}
...
@@ -48,22 +48,29 @@ class ActivationRankFilterPruner(Pruner):
         """
         Compress the model, register a hook for collecting activations.
         """
+        if self.modules_wrapper is not None:
+            # already compressed
+            return self.bound_model
+        else:
+            self.modules_wrapper = []
+
         modules_to_compress = self.detect_modules_to_compress()
         for layer, config in modules_to_compress:
-            self._instrument_layer(layer, config)
+            wrapper = self._wrap_modules(layer, config)
+            self.modules_wrapper.append(wrapper)
             self.collected_activation[layer.name] = []

             def _hook(module_, input_, output, name=layer.name):
                 if len(self.collected_activation[name]) < self.statistics_batch_num:
                     self.collected_activation[name].append(self.activation(output.detach().cpu()))

-            layer.module.register_forward_hook(_hook)
+            wrapper.module.register_forward_hook(_hook)
+        self._wrap_model()
         return self.bound_model

     def get_mask(self, base_mask, activations, num_prune):
         raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Calculate the mask of given layer.
         Filters with the smallest importance criterion which is calculated from the activation are masked.
...
@@ -82,14 +89,13 @@ class ActivationRankFilterPruner(Pruner):
         """
         weight = layer.module.weight.data
-        op_name = layer.name
         op_type = layer.type
         assert 0 <= config.get('sparsity') < 1, "sparsity must in the range [0, 1)"
         assert op_type in ['Conv2d'], "only support Conv2d"
         assert op_type in config.get('op_types')
-        if op_name in self.mask_calculated_ops:
-            assert op_name in self.mask_dict
-            return self.mask_dict.get(op_name)
+        if_calculated = kwargs["if_calculated"]
+        if if_calculated:
+            return None
         mask_weight = torch.ones(weight.size()).type_as(weight).detach()
         if hasattr(layer.module, 'bias') and layer.module.bias is not None:
             mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
...
@@ -104,8 +110,7 @@ class ActivationRankFilterPruner(Pruner):
             mask = self.get_mask(mask, self.collected_activation[layer.name], num_prune)
         finally:
             if len(self.collected_activation[layer.name]) == self.statistics_batch_num:
-                self.mask_dict.update({op_name: mask})
-                self.mask_calculated_ops.add(op_name)
+                if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
         return mask
...
src/sdk/pynni/nni/compression/torch/compressor.py (+249 −85)

...
@@ -14,8 +14,11 @@ class LayerInfo:
         self.name = name
         self.type = type(module).__name__

-        self._forward = None
+def _setattr(model, name, module):
+    name_list = name.split(".")
+    for name in name_list[:-1]:
+        model = getattr(model, name)
+    setattr(model, name_list[-1], module)

 class Compressor:
     """
...
@@ -36,6 +39,9 @@ class Compressor:
         self.bound_model = model
         self.config_list = config_list
         self.modules_to_compress = None
+        self.modules_wrapper = None
+        self.buffers = {}
+        self.is_wrapped = False

     def detect_modules_to_compress(self):
         """
...
@@ -51,21 +57,60 @@ class Compressor:
                     self.modules_to_compress.append((layer, config))
         return self.modules_to_compress

+    def _wrap_model(self):
+        """
+        wrap all modules that needed to be compressed
+        """
+        for wrapper in reversed(self.get_modules_wrapper()):
+            _setattr(self.bound_model, wrapper.name, wrapper)
+        self.is_wrapped = True
+
+    def _unwrap_model(self):
+        """
+        unwrap all modules that needed to be compressed
+        """
+        for wrapper in self.get_modules_wrapper():
+            _setattr(self.bound_model, wrapper.name, wrapper.module)
+        self.is_wrapped = False
+
     def compress(self):
         """
         Compress the model with algorithm implemented by subclass.
         The model will be instrumented and user should never edit it after calling this method.
         `self.modules_to_compress` records all the to-be-compressed layers
+
+        Returns
+        -------
+        torch.nn.Module
+            model with specified modules compressed.
         """
+        if self.modules_wrapper is not None:
+            # already compressed
+            return self.bound_model
+        else:
+            self.modules_wrapper = []
+
         modules_to_compress = self.detect_modules_to_compress()
         for layer, config in modules_to_compress:
-            self._instrument_layer(layer, config)
+            wrapper = self._wrap_modules(layer, config)
+            self.modules_wrapper.append(wrapper)
+        self._wrap_model()
         return self.bound_model

+    def register_buffer(self, name, value):
+        """
+        To register buffers used in wrapped module's forward method.
+        """
+        self.buffers[name] = value
+
     def get_modules_to_compress(self):
         """
-        To obtain all the to-be-compressed layers.
+        To obtain all the to-be-compressed modules.

         Returns
         -------
...
@@ -75,6 +120,17 @@ class Compressor:
         """
         return self.modules_to_compress

+    def get_modules_wrapper(self):
+        """
+        To obtain all the wrapped modules.
+
+        Returns
+        -------
+        list
+            a list of the wrapped modules
+        """
+        return self.modules_wrapper
+
     def select_config(self, layer):
         """
         Find the configuration for `layer` by parsing `self.config_list`
...
@@ -119,7 +175,7 @@ class Compressor:
         If user want to update model every step, user can override this method
         """

-    def _instrument_layer(self, layer, config):
+    def _wrap_modules(self, layer, config):
         """
         This method is implemented in the subclasses, i.e., `Pruner` and `Quantizer`
...
@@ -143,6 +199,59 @@ class Compressor:
                 expanded_op_types.append(op_type)
         return expanded_op_types

+class PrunerModuleWrapper(torch.nn.Module):
+    def __init__(self, module, module_name, module_type, config, pruner):
+        """
+        Wrap an module to enable data parallel, forward method customization and buffer registeration.
+
+        Parameters
+        ----------
+        module : pytorch module
+            the module user wants to compress
+        config : dict
+            the configurations that users specify for compression
+        module_name : str
+            the name of the module to compress, wrapper module shares same name
+        module_type : str
+            the type of the module to compress
+        pruner : Pruner
+            the pruner used to calculate mask
+        """
+        super().__init__()
+        # origin layer information
+        self.module = module
+        self.name = module_name
+        self.type = module_type
+        # config and pruner
+        self.config = config
+        self.pruner = pruner
+        self.registered_buffers = {}
+        # register buffer for mask
+        self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
+        self.registered_buffers['weight_mask'] = self.weight_mask
+        if hasattr(self.module, 'bias') and self.module.bias is not None:
+            self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
+        else:
+            self.register_buffer("bias_mask", None)
+        self.registered_buffers['bias_mask'] = self.bias_mask
+        # register user specified buffer
+        for name in self.pruner.buffers:
+            self.register_buffer(name, self.pruner.buffers[name].clone())
+            self.registered_buffers[name] = getattr(self, name)
+
+    def forward(self, *inputs):
+        mask = self.pruner.calc_mask(LayerInfo(self.name, self.module), self.config, **self.registered_buffers)
+        if mask is not None:
+            self.weight_mask.copy_(mask['weight'])
+        # apply mask to weight
+        self.module.weight.data = self.module.weight.data.mul_(self.weight_mask)
+        # apply mask to bias
+        if hasattr(self.module, 'bias') and self.module.bias is not None:
+            if mask is not None and 'bias' in mask:
+                self.bias_mask.copy_(mask['bias'])
+            self.module.bias.data = self.module.bias.data.mul_(self.bias_mask)
+        return self.module(*inputs)
+
 class Pruner(Compressor):
     """
...
@@ -158,9 +267,8 @@ class Pruner(Compressor):
     def __init__(self, model, config_list):
         super().__init__(model, config_list)
-        self.mask_dict = {}

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Pruners should overload this method to provide mask for weight tensors.
         The mask must have the same shape and type comparing to the weight.
...
@@ -176,9 +284,9 @@ class Pruner(Compressor):
         """
         raise NotImplementedError("Pruners must overload calc_mask()")

-    def _instrument_layer(self, layer, config):
+    def _wrap_modules(self, layer, config):
         """
-        Create a wrapper forward function to replace the original one.
+        Create a wrapper module to replace the original one.

         Parameters
         ----------
...
@@ -187,30 +295,13 @@ class Pruner(Compressor):
         config : dict
             the configuration for generating the mask
         """
-        assert layer._forward is None, 'Each model can only be compressed once'
-        if not _check_weight(layer.module):
-            _logger.warning('Module %s does not have parameter "weight"', layer.name)
-            return
-        layer._forward = layer.module.forward
-
-        def new_forward(*inputs):
-            mask = self.calc_mask(layer, config)
-            # apply mask to weight
-            old_weight = layer.module.weight.data
-            mask_weight = mask['weight']
-            layer.module.weight.data = old_weight.mul(mask_weight)
-            # apply mask to bias
-            if mask.__contains__('bias') and hasattr(layer.module, 'bias') and layer.module.bias is not None:
-                old_bias = layer.module.bias.data
-                mask_bias = mask['bias']
-                layer.module.bias.data = old_bias.mul(mask_bias)
-            # calculate forward
-            ret = layer._forward(*inputs)
-            return ret
-
-        layer.module.forward = new_forward
+        _logger.info("compressing module %s.", layer.name)
+        wrapper = PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)
+        assert hasattr(layer.module, 'weight')
+        wrapper.to(layer.module.weight.device)
+        return wrapper

-    def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None):
+    def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None, device=None):
         """
         Export pruned model weights, masks and onnx model(optional)
...
@@ -224,35 +315,138 @@ class Pruner(Compressor):
             (optional) path to save onnx model
         input_shape : list or tuple
             input shape to onnx model
+        device : torch.device
+            device of the model, used to place the dummy input tensor for exporting onnx file.
+            the tensor is placed on cpu if ```device``` is None
         """
-        if self.detect_modules_to_compress() and not self.mask_dict:
-            _logger.warning('You may not use self.mask_dict in base Pruner class to record masks')
+        # if self.detect_modules_to_compress() and not self.mask_dict:
+        #     _logger.warning('You may not use self.mask_dict in base Pruner class to record masks')
         assert model_path is not None, 'model_path must be specified'
-        for name, m in self.bound_model.named_modules():
-            if name == "":
-                continue
-            masks = self.mask_dict.get(name)
-            if masks is not None:
-                mask_sum = masks['weight'].sum().item()
-                mask_num = masks['weight'].numel()
-                _logger.info('Layer: %s Sparsity: %.2f', name, 1 - mask_sum / mask_num)
-                m.weight.data = m.weight.data.mul(masks['weight'])
-                if masks.__contains__('bias') and hasattr(m, 'bias') and m.bias is not None:
-                    m.bias.data = m.bias.data.mul(masks['bias'])
-            else:
-                _logger.info('Layer: %s NOT compressed', name)
+        mask_dict = {}
+        self._unwrap_model() # used for generating correct state_dict name without wrapper state
+        for wrapper in self.get_modules_wrapper():
+            weight_mask = wrapper.weight_mask
+            bias_mask = wrapper.bias_mask
+            if weight_mask is not None:
+                mask_sum = weight_mask.sum().item()
+                mask_num = weight_mask.numel()
+                _logger.info('Layer: %s Sparsity: %.2f', wrapper.name, 1 - mask_sum / mask_num)
+                wrapper.module.weight.data = wrapper.module.weight.data.mul(weight_mask)
+            if bias_mask is not None:
+                wrapper.module.bias.data = wrapper.module.bias.data.mul(bias_mask)
+            # save mask to dict
+            mask_dict[wrapper.name] = {"weight": weight_mask, "bias": bias_mask}
         torch.save(self.bound_model.state_dict(), model_path)
         _logger.info('Model state_dict saved to %s', model_path)
         if mask_path is not None:
-            torch.save(self.mask_dict, mask_path)
+            torch.save(mask_dict, mask_path)
            _logger.info('Mask dict saved to %s', mask_path)
         if onnx_path is not None:
             assert input_shape is not None, 'input_shape must be specified to export onnx model'
             # input info needed
+            if device is None:
+                device = torch.device('cpu')
             input_data = torch.Tensor(*input_shape)
-            torch.onnx.export(self.bound_model, input_data, onnx_path)
+            torch.onnx.export(self.bound_model, input_data.to(device), onnx_path)
             _logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)
+
+        self._wrap_model()
+
+    def load_model_state_dict(self, model_state):
+        """
+        Load the state dict saved from unwrapped model.
+
+        Parameters:
+        -----------
+        model_state : dict
+            state dict saved from unwrapped model
+        """
+        if self.is_wrapped:
+            self._unwrap_model()
+            self.bound_model.load_state_dict(model_state)
+            self._wrap_model()
+        else:
+            self.bound_model.load_state_dict(model_state)
+
+class QuantizerModuleWrapper(torch.nn.Module):
+    def __init__(self, module, module_name, module_type, config, quantizer):
+        """
+        Wrap an module to enable data parallel, forward method customization and buffer registeration.
+
+        Parameters
+        ----------
+        module : pytorch module
+            the module user wants to compress
+        config : dict
+            the configurations that users specify for compression
+        module_name : str
+            the name of the module to compress, wrapper module shares same name
+        module_type : str
+            the type of the module to compress
+        quantizer : quantizer
+            the quantizer used to calculate mask
+        """
+        super().__init__()
+        # origin layer information
+        self.module = module
+        self.name = module_name
+        self.type = module_type
+        # config and pruner
+        self.config = config
+        self.quantizer = quantizer
+
+        # register buffer and parameter
+        # old_weight is used to store origin weight and weight is used to store quantized weight
+        # the reason why weight is buffer instead of parameter is because in pytorch parameter is used as leaf
+        # if weight is leaf , then old_weight can not be updated.
+        if 'weight' in config['quant_types']:
+            if not _check_weight(self.module):
+                _logger.warning('Module %s does not have parameter "weight"', self.name)
+            else:
+                self.module.register_parameter('old_weight', torch.nn.Parameter(self.module.weight))
+                delattr(self.module, 'weight')
+                self.module.register_buffer('weight', self.module.old_weight)
+
+        # register user specified buffer
+        self.registered_buffers = {}
+        for name in self.quantizer.buffers:
+            self.register_buffer(name, self.quantizer.buffers[name].clone())
+            self.registered_buffers[name] = getattr(self, name)
+
+    def forward(self, *inputs):
+        if 'input' in self.config['quant_types']:
+            inputs = self.quantizer.quant_grad.apply(inputs, QuantType.QUANT_INPUT, self.quantizer.quantize_input, self.config, LayerInfo(self.name, self.module), **self.registered_buffers)
+        if 'weight' in self.config['quant_types'] and _check_weight(self.module):
+            new_weight = self.quantizer.quant_grad.apply(self.module.old_weight, QuantType.QUANT_WEIGHT, self.quantizer.quantize_weight, self.config, LayerInfo(self.name, self.module), **self.registered_buffers)
+            self.module.weight = new_weight
+            result = self.module(*inputs)
+        else:
+            result = self.module(*inputs)
+        if 'output' in self.config['quant_types']:
+            result = self.quantizer.quant_grad.apply(result, QuantType.QUANT_OUTPUT, self.quantizer.quantize_output, self.config, LayerInfo(self.name, self.module), **self.registered_buffers)
+        return result
+
 class Quantizer(Compressor):
     """
...
@@ -303,7 +497,7 @@ class Quantizer(Compressor):
         raise NotImplementedError('Quantizer must overload quantize_input()')

-    def _instrument_layer(self, layer, config):
+    def _wrap_modules(self, layer, config):
         """
         Create a wrapper forward function to replace the original one.

         Parameters
...
@@ -313,7 +507,6 @@ class Quantizer(Compressor):
         config : dict
             the configuration for quantization
         """
-        assert layer._forward is None, 'Each model can only be compressed once'
         assert 'quant_types' in config, 'must provide quant_types in config'
         assert isinstance(config['quant_types'], list), 'quant_types must be list type'
         assert 'quant_bits' in config, 'must provide quant_bits in config'
...
@@ -323,35 +516,7 @@ class Quantizer(Compressor):
         for quant_type in config['quant_types']:
             assert quant_type in config['quant_bits'], 'bits length for %s must be specified in quant_bits dict' % quant_type

-        if 'weight' in config['quant_types']:
-            if not _check_weight(layer.module):
-                _logger.warning('Module %s does not have parameter "weight"', layer.name)
-            else:
-                # old_weight is used to store origin weight and weight is used to store quantized weight
-                # the reason why weight is buffer instead of parameter is because in pytorch parameter is used as leaf
-                # if weight is leaf , then old_weight can not be updated.
-                layer.module.register_parameter('old_weight', torch.nn.Parameter(layer.module.weight))
-                delattr(layer.module, 'weight')
-                layer.module.register_buffer('weight', layer.module.old_weight)
-
-        layer._forward = layer.module.forward
-
-        def new_forward(*inputs):
-            if 'input' in config['quant_types']:
-                inputs = self.quant_grad.apply(inputs, QuantType.QUANT_INPUT, self.quantize_input, config, layer)
-            if 'weight' in config['quant_types'] and _check_weight(layer.module):
-                new_weight = self.quant_grad.apply(layer.module.old_weight, QuantType.QUANT_WEIGHT, self.quantize_weight, config, layer)
-                layer.module.weight = new_weight
-                result = layer._forward(*inputs)
-            else:
-                result = layer._forward(*inputs)
-            if 'output' in config['quant_types']:
-                result = self.quant_grad.apply(result, QuantType.QUANT_OUTPUT, self.quantize_output, config, layer)
-            return result
-
-        layer.module.forward = new_forward
+        return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self)

 class QuantType:
     """
...
@@ -387,19 +552,18 @@ class QuantGrad(torch.autograd.Function):
             return grad_output

     @staticmethod
-    def forward(ctx, tensor, quant_type, quant_func, config, layer):
+    def forward(ctx, tensor, quant_type, quant_func, config, layer, **kwargs):
         ctx.save_for_backward(tensor, torch.Tensor([quant_type]))
-        return quant_func(tensor, config, op=layer.module, op_type=layer.type, op_name=layer.name)
+        return quant_func(tensor, config, op=layer.module, op_type=layer.type, op_name=layer.name, **kwargs)

     @classmethod
     def backward(cls, ctx, grad_output):
         tensor, quant_type = ctx.saved_variables
         output = cls.quant_backward(tensor, grad_output, quant_type)
-        return output, None, None, None, None
+        return output, None, None, None, None, None

 def _check_weight(module):
     try:
         return isinstance(module.weight.data, torch.Tensor)
     except AttributeError:
         return False
\ No newline at end of file
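The new _setattr helper above is what makes wrapping reversible: it walks a dotted module path and swaps the leaf module in place, so _wrap_model and _unwrap_model can install and remove wrappers without touching the rest of the model. A toy illustration of the idea (the Sequential model here is made up for the example):

import torch.nn as nn

def _setattr(model, name, module):
    # walk 'features.3' -> getattr(model, 'features'), then setattr(..., '3', module)
    name_list = name.split(".")
    for name in name_list[:-1]:
        model = getattr(model, name)
    setattr(model, name_list[-1], module)

toy = nn.Sequential(nn.Conv2d(1, 8, 3), nn.ReLU())
_setattr(toy, "0", nn.Identity())  # swaps the conv out, as the wrap/unwrap cycle does
print(toy[0])                      # Identity()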
src/sdk/pynni/nni/compression/torch/pruners.py (+68 −72)

...
@@ -27,9 +27,9 @@ class LevelPruner(Pruner):
         """
         super().__init__(model, config_list)
-        self.mask_calculated_ops = set()
+        self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Calculate the mask of given layer

         Parameters
...
@@ -45,8 +45,9 @@ class LevelPruner(Pruner):
         """
         weight = layer.module.weight.data
-        op_name = layer.name
-        if op_name not in self.mask_calculated_ops:
+        if_calculated = kwargs["if_calculated"]
+        if not if_calculated:
             w_abs = weight.abs()
             k = int(weight.numel() * config['sparsity'])
             if k == 0:
...
@@ -54,12 +55,10 @@ class LevelPruner(Pruner):
             threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
             mask_weight = torch.gt(w_abs, threshold).type_as(weight)
             mask = {'weight': mask_weight}
-            self.mask_dict.update({op_name: mask})
-            self.mask_calculated_ops.add(op_name)
+            if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
+            return mask
         else:
-            assert op_name in self.mask_dict, "op_name not in the mask_dict"
-            mask = self.mask_dict[op_name]
-        return mask
+            return None

 class AGP_Pruner(Pruner):
...
@@ -84,17 +83,20 @@ class AGP_Pruner(Pruner):
         super().__init__(model, config_list)
         self.now_epoch = 0
-        self.if_init_list = {}
+        self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
-        Calculate the mask of given layer
+        Calculate the mask of given layer.
+        Scale factors with the smallest absolute value in the BN layer are masked.

         Parameters
         ----------
         layer : LayerInfo
             the layer to instrument the compression operation
         config : dict
             layer's pruning config
+        kwargs: dict
+            buffers registered in __init__ function

         Returns
         -------
         dict
...
@@ -102,24 +104,26 @@ class AGP_Pruner(Pruner):
         """
         weight = layer.module.weight.data
-        op_name = layer.name
         start_epoch = config.get('start_epoch', 0)
         freq = config.get('frequency', 1)
-        if self.now_epoch >= start_epoch and self.if_init_list.get(op_name, True) \
-                and (self.now_epoch - start_epoch) % freq == 0:
-            mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
-            target_sparsity = self.compute_target_sparsity(config)
-            k = int(weight.numel() * target_sparsity)
-            if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
-                return mask
-            # if we want to generate new mask, we should update weigth first
-            w_abs = weight.abs() * mask['weight']
-            threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
-            new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
-            self.mask_dict.update({op_name: new_mask})
-            self.if_init_list.update({op_name: False})
-        else:
-            new_mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
+        if_calculated = kwargs["if_calculated"]
+        if if_calculated:
+            return None
+        if not (self.now_epoch >= start_epoch and (self.now_epoch - start_epoch) % freq == 0):
+            return None
+        mask = {'weight': kwargs['weight_mask'] if 'weight_mask' in kwargs else torch.ones(weight.shape).type_as(weight)}
+        target_sparsity = self.compute_target_sparsity(config)
+        k = int(weight.numel() * target_sparsity)
+        if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
+            return mask
+        # if we want to generate new mask, we should update weigth first
+        w_abs = weight.abs() * mask['weight']
+        threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
+        new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
+        if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
         return new_mask

     def compute_target_sparsity(self, config):
...
@@ -165,9 +169,8 @@ class AGP_Pruner(Pruner):
         if epoch > 0:
             self.now_epoch = epoch
-            for k in self.if_init_list.keys():
-                self.if_init_list[k] = True
+            for wrapper in self.get_modules_wrapper():
+                wrapper.registered_buffers['if_calculated'].copy_(torch.tensor(0)) # pylint: disable=not-callable

 class SlimPruner(Pruner):
     """
...
@@ -187,7 +190,6 @@ class SlimPruner(Pruner):
         """
         super().__init__(model, config_list)
-        self.mask_calculated_ops = set()
         weight_list = []
         if len(config_list) > 1:
             logger.warning('Slim pruner only supports 1 configuration')
...
@@ -198,8 +200,9 @@ class SlimPruner(Pruner):
         all_bn_weights = torch.cat(weight_list)
         k = int(all_bn_weights.shape[0] * config['sparsity'])
         self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
+        self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Calculate the mask of given layer.
         Scale factors with the smallest absolute value in the BN layer are masked.
...
@@ -209,6 +212,8 @@ class SlimPruner(Pruner):
             the layer to instrument the compression operation
         config : dict
             layer's pruning config
+        kwargs: dict
+            buffers registered in __init__ function

         Returns
         -------
         dict
...
@@ -216,27 +221,21 @@ class SlimPruner(Pruner):
         """
         weight = layer.module.weight.data
-        op_name = layer.name
         op_type = layer.type
+        if_calculated = kwargs["if_calculated"]
         assert op_type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
-        if op_name in self.mask_calculated_ops:
-            assert op_name in self.mask_dict
-            return self.mask_dict.get(op_name)
+        if if_calculated:
+            return None
         base_mask = torch.ones(weight.size()).type_as(weight).detach()
         mask = {'weight': base_mask.detach(), 'bias': base_mask.clone().detach()}
-        try:
-            filters = weight.size(0)
-            num_prune = int(filters * config.get('sparsity'))
-            if filters >= 2 and num_prune >= 1:
-                w_abs = weight.abs()
-                mask_weight = torch.gt(w_abs, self.global_threshold).type_as(weight)
-                mask_bias = mask_weight.clone()
-                mask = {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
-        finally:
-            self.mask_dict.update({layer.name: mask})
-            self.mask_calculated_ops.add(layer.name)
+        filters = weight.size(0)
+        num_prune = int(filters * config.get('sparsity'))
+        if filters < 2 or num_prune < 1:
+            return mask
+        w_abs = weight.abs()
+        mask_weight = torch.gt(w_abs, self.global_threshold).type_as(weight)
+        mask_bias = mask_weight.clone()
+        mask = {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
+        if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
         return mask

 class LotteryTicketPruner(Pruner):
...
@@ -294,38 +293,23 @@ class LotteryTicketPruner(Pruner):
             prune_iterations = config['prune_iterations']
         return prune_iterations

-    def _print_masks(self, print_mask=False):
-        torch.set_printoptions(threshold=1000)
-        for op_name in self.mask_dict.keys():
-            mask = self.mask_dict[op_name]
-            print('op name: ', op_name)
-            if print_mask:
-                print('mask: ', mask)
-            # calculate current sparsity
-            mask_num = mask['weight'].sum().item()
-            mask_size = mask['weight'].numel()
-            print('sparsity: ', 1 - mask_num / mask_size)
-        torch.set_printoptions(profile='default')
-
     def _calc_sparsity(self, sparsity):
         keep_ratio_once = (1 - sparsity) ** (1 / self.prune_iterations)
         curr_keep_ratio = keep_ratio_once ** self.curr_prune_iteration
         return max(1 - curr_keep_ratio, 0)

-    def _calc_mask(self, weight, sparsity, op_name):
+    def _calc_mask(self, weight, sparsity, curr_w_mask):
         if self.curr_prune_iteration == 0:
             mask = torch.ones(weight.shape).type_as(weight)
         else:
             curr_sparsity = self._calc_sparsity(sparsity)
-            assert self.mask_dict.get(op_name) is not None
-            curr_mask = self.mask_dict.get(op_name)
-            w_abs = weight.abs() * curr_mask['weight']
+            w_abs = weight.abs() * curr_w_mask
             k = int(w_abs.numel() * curr_sparsity)
             threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
             mask = torch.gt(w_abs, threshold).type_as(weight)
         return {'weight': mask}

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Generate mask for the given ``weight``.
...
@@ -335,15 +319,17 @@ class LotteryTicketPruner(Pruner):
             The layer to be pruned
         config : dict
             Pruning configurations for this weight
+        kwargs : dict
+            Auxiliary information

         Returns
         -------
         tensor
-            The mask for this weight
+            The mask for this weight, it is ```None``` because this pruner
+            calculates and assigns masks in ```prune_iteration_start```,
+            no need to do anything in this function.
         """
-        assert self.mask_dict.get(layer.name) is not None, 'Please call iteration_start before training'
-        mask = self.mask_dict[layer.name]
-        return mask
+        return None

     def get_prune_iterations(self):
         """
...
@@ -368,16 +354,26 @@ class LotteryTicketPruner(Pruner):
         self.curr_prune_iteration += 1
         assert self.curr_prune_iteration < self.prune_iterations + 1, 'Exceed the configured prune_iterations'

+        modules_wrapper = self.get_modules_wrapper()
         modules_to_compress = self.detect_modules_to_compress()
         for layer, config in modules_to_compress:
+            module_wrapper = None
+            for wrapper in modules_wrapper:
+                if wrapper.name == layer.name:
+                    module_wrapper = wrapper
+                    break
+            assert module_wrapper is not None
+
             sparsity = config.get('sparsity')
-            mask = self._calc_mask(layer.module.weight.data, sparsity, layer.name)
-            self.mask_dict.update({layer.name: mask})
-        self._print_masks()
+            mask = self._calc_mask(layer.module.weight.data, sparsity, module_wrapper.weight_mask)
+            # TODO: directly use weight_mask is not good
+            module_wrapper.weight_mask.copy_(mask['weight'])
+            # there is no mask for bias

         # reinit weights back to original after new masks are generated
         if self.reset_weights:
-            self._model.load_state_dict(self._model_state)
+            # should use this member function to reset model weights
+            self.load_model_state_dict(self._model_state)
             self._optimizer.load_state_dict(self._optimizer_state)
             if self._lr_scheduler is not None:
                 self._lr_scheduler.load_state_dict(self._scheduler_state)
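Under the new protocol shown above, per-module state such as if_calculated is declared once with Compressor.register_buffer, cloned into every PrunerModuleWrapper, and handed back to calc_mask through **kwargs. A minimal custom pruner following that pattern might look like this sketch (MyThresholdPruner and its 0.01 threshold are made up for illustration):

import torch
from nni.compression.torch.compressor import Pruner

class MyThresholdPruner(Pruner):
    def __init__(self, model, config_list):
        super().__init__(model, config_list)
        # cloned into each wrapper; passed back to calc_mask via **kwargs
        self.register_buffer("if_calculated", torch.tensor(0))  # pylint: disable=not-callable

    def calc_mask(self, layer, config, **kwargs):
        if_calculated = kwargs["if_calculated"]
        if if_calculated:
            return None  # mask already computed; the wrapper keeps applying its buffer
        weight = layer.module.weight.data
        mask = {'weight': (weight.abs() > 0.01).type_as(weight)}  # illustrative threshold
        if_calculated.copy_(torch.tensor(1))  # pylint: disable=not-callable
        return mask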
src/sdk/pynni/nni/compression/torch/weight_rank_filter_pruners.py
View file @ ed121315
...
@@ -27,12 +27,12 @@ class WeightRankFilterPruner(Pruner):
         """
         super().__init__(model, config_list)
-        self.mask_calculated_ops = set()  # operations whose mask has been calculated
+        self.register_buffer("if_calculated", torch.tensor(0))  # pylint: disable=not-callable

     def get_mask(self, base_mask, weight, num_prune):
         raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Calculate the mask of given layer.
         Filters with the smallest importance criterion of the kernel weights are masked.
         ...
@@ -49,14 +49,13 @@ class WeightRankFilterPruner(Pruner):
         """
         weight = layer.module.weight.data
-        op_name = layer.name
         op_type = layer.type
         assert 0 <= config.get('sparsity') < 1, "sparsity must in the range [0, 1)"
         assert op_type in ['Conv1d', 'Conv2d'], "only support Conv1d and Conv2d"
         assert op_type in config.get('op_types')
-        if op_name in self.mask_calculated_ops:
-            assert op_name in self.mask_dict
-            return self.mask_dict.get(op_name)
+        if_calculated = kwargs["if_calculated"]
+        if if_calculated:
+            return None
         mask_weight = torch.ones(weight.size()).type_as(weight).detach()
         if hasattr(layer.module, 'bias') and layer.module.bias is not None:
             mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
         ...
@@ -70,8 +69,7 @@ class WeightRankFilterPruner(Pruner):
                 return mask
             mask = self.get_mask(mask, weight, num_prune)
         finally:
-            self.mask_dict.update({op_name: mask})
-            self.mask_calculated_ops.add(op_name)
+            if_calculated.copy_(torch.tensor(1))  # pylint: disable=not-callable
         return mask
         ...
@@ -259,4 +257,5 @@ class FPGMPruner(WeightRankFilterPruner):
         return x.sum()

     def update_epoch(self, epoch):
-        self.mask_calculated_ops = set()
+        for wrapper in self.get_modules_wrapper():
+            wrapper.registered_buffers['if_calculated'].copy_(torch.tensor(0))  # pylint: disable=not-callable
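The switch from a Python set (mask_calculated_ops) to a registered buffer (if_calculated) is what makes the bookkeeping survive nn.DataParallel: buffers are replicated to each device along with the module, while plain attributes on the pruner object are not. A minimal sketch of the idea follows; it is illustrative, not code from this commit, and assumes per-layer state lives on a small wrapper module.

# Minimal sketch: state registered as a buffer travels with the module through
# .to(device) and nn.DataParallel replication; a plain Python set would not.
import torch
import torch.nn as nn

class MaskedWrapper(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module
        self.register_buffer('if_calculated', torch.tensor(0))  # 0 = mask not computed yet

    def forward(self, x):
        if not bool(self.if_calculated):
            # compute the mask here, then flip the flag in place so the
            # change is visible as module state on this replica
            self.if_calculated.copy_(torch.tensor(1))
        return self.module(x)

wrapped = MaskedWrapper(nn.Linear(4, 4))
print(dict(wrapped.named_buffers()))  # if_calculated shows up as module state

Resetting is handled the same way: update_epoch() in the diff above zeroes each wrapper's if_calculated buffer in place, so masks are recomputed on the next calc_mask call.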
src/sdk/pynni/tests/test_compressor.py
View file @ ed121315
...
@@ -135,12 +135,11 @@ class CompressorTestCase(TestCase):
         model.conv2.weight.data = torch.tensor(w).float()
         layer = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
-        masks = pruner.calc_mask(layer, config_list[0])
+        masks = pruner.calc_mask(layer, config_list[0], if_calculated=torch.tensor(0))
         assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))
-        pruner.update_epoch(1)
         model.conv2.weight.data = torch.tensor(w).float()
-        masks = pruner.calc_mask(layer, config_list[1])
+        masks = pruner.calc_mask(layer, config_list[1], if_calculated=torch.tensor(0))
         assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))

     @tf2
...
@@ -159,7 +158,6 @@ class CompressorTestCase(TestCase):
         assert all(masks.sum((1)) == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))
-        pruner.update_epoch(1)
         model.layers[2].set_weights([weights[0], weights[1].numpy()])
         masks = pruner.calc_mask(layer, config_list[1]).numpy()
         masks = masks.reshape((-1, masks.shape[-1])).transpose([1, 0])
...
@@ -187,9 +185,9 @@ class CompressorTestCase(TestCase):
         model.conv1.weight.data = torch.tensor(w).float()
         model.conv2.weight.data = torch.tensor(w).float()
         layer1 = torch_compressor.compressor.LayerInfo('conv1', model.conv1)
-        mask1 = pruner.calc_mask(layer1, config_list[0])
+        mask1 = pruner.calc_mask(layer1, config_list[0], if_calculated=torch.tensor(0))
         layer2 = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
-        mask2 = pruner.calc_mask(layer2, config_list[1])
+        mask2 = pruner.calc_mask(layer2, config_list[1], if_calculated=torch.tensor(0))
         assert all(torch.sum(mask1['weight'], (1, 2, 3)).numpy() == np.array([0., 27., 27., 27., 27.]))
         assert all(torch.sum(mask2['weight'], (1, 2, 3)).numpy() == np.array([0., 0., 0., 27., 27.]))
...
@@ -215,9 +213,9 @@ class CompressorTestCase(TestCase):
         pruner = torch_compressor.SlimPruner(model, config_list)
         layer1 = torch_compressor.compressor.LayerInfo('bn1', model.bn1)
-        mask1 = pruner.calc_mask(layer1, config_list[0])
+        mask1 = pruner.calc_mask(layer1, config_list[0], if_calculated=torch.tensor(0))
         layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2)
-        mask2 = pruner.calc_mask(layer2, config_list[0])
+        mask2 = pruner.calc_mask(layer2, config_list[0], if_calculated=torch.tensor(0))
         assert all(mask1['weight'].numpy() == np.array([0., 1., 1., 1., 1.]))
         assert all(mask2['weight'].numpy() == np.array([0., 1., 1., 1., 1.]))
         assert all(mask1['bias'].numpy() == np.array([0., 1., 1., 1., 1.]))
...
@@ -229,9 +227,9 @@ class CompressorTestCase(TestCase):
         pruner = torch_compressor.SlimPruner(model, config_list)
         layer1 = torch_compressor.compressor.LayerInfo('bn1', model.bn1)
-        mask1 = pruner.calc_mask(layer1, config_list[0])
+        mask1 = pruner.calc_mask(layer1, config_list[0], if_calculated=torch.tensor(0))
         layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2)
-        mask2 = pruner.calc_mask(layer2, config_list[0])
+        mask2 = pruner.calc_mask(layer2, config_list[0], if_calculated=torch.tensor(0))
         assert all(mask1['weight'].numpy() == np.array([0., 0., 0., 1., 1.]))
         assert all(mask2['weight'].numpy() == np.array([0., 0., 0., 1., 1.]))
         assert all(mask1['bias'].numpy() == np.array([0., 0., 0., 1., 1.]))
...
@@ -268,14 +266,14 @@ class CompressorTestCase(TestCase):
         # test ema
         x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
         out = model.relu(x)
-        assert math.isclose(model.relu.tracked_min_biased, 0, abs_tol=eps)
+        assert math.isclose(model.relu.module.tracked_min_biased, 0, abs_tol=eps)
-        assert math.isclose(model.relu.tracked_max_biased, 0.002, abs_tol=eps)
+        assert math.isclose(model.relu.module.tracked_max_biased, 0.002, abs_tol=eps)
         quantizer.step()
         x = torch.tensor([[0.2, 0.4], [0.6, 0.8]])
         out = model.relu(x)
-        assert math.isclose(model.relu.tracked_min_biased, 0.002, abs_tol=eps)
+        assert math.isclose(model.relu.module.tracked_min_biased, 0.002, abs_tol=eps)
-        assert math.isclose(model.relu.tracked_max_biased, 0.00998, abs_tol=eps)
+        assert math.isclose(model.relu.module.tracked_max_biased, 0.00998, abs_tol=eps)

 if __name__ == '__main__':
...
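The updated tests pass if_calculated explicitly because calc_mask now reads the flag from kwargs instead of a pruner-level set; the quantizer assertions gain a .module indirection because tracked statistics now live on the wrapped module. A sketch of the resulting call pattern follows; setup of model, pruner, and config_list is elided as in the tests above, and the behavior shown matches the WeightRankFilterPruner diff.

# Sketch of the new calc_mask contract (illustrative): a fresh flag forces
# mask computation, the pruner flips it to 1 in place, and a second call with
# the already-set flag is a no-op returning None.
if_calculated = torch.tensor(0)
layer = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
masks = pruner.calc_mask(layer, config_list[0], if_calculated=if_calculated)
assert int(if_calculated) == 1
assert pruner.calc_mask(layer, config_list[0], if_calculated=if_calculated) is None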