OpenDAS / nni — commit 92f6754e (unverified)
Authored May 25, 2021 by colorjam; committed by GitHub on May 25, 2021

[Model Compression] Update api of iterative pruners (#3507)
Parent: 26f47727
Showing 20 changed files with 1010 additions and 461 deletions (+1010 −461)
examples/model_compress/pruning/naive_prune_torch.py                        +6    −6
examples/model_compress/quantization/BNN_quantizer_cifar10.py               +0    −1
examples/model_compress/quantization/DoReFaQuantizer_torch_mnist.py         +4    −22
examples/model_compress/quantization/QAT_torch_quantizer.py                 +4    −23
examples/model_compress/quantization/mixed_precision_speedup_mnist.py       +5    −24
nni/algorithms/compression/pytorch/pruning/__init__.py                      +4    −5
nni/algorithms/compression/pytorch/pruning/admm_pruner.py                   +0    −177
nni/algorithms/compression/pytorch/pruning/agp.py                           +0    −151
nni/algorithms/compression/pytorch/pruning/auto_compress_pruner.py          +16   −33
nni/algorithms/compression/pytorch/pruning/constants_pruner.py              +1    −1
nni/algorithms/compression/pytorch/pruning/dependency_aware_pruner.py       +162  −0
nni/algorithms/compression/pytorch/pruning/finegrained_pruning_masker.py    +0    −0
nni/algorithms/compression/pytorch/pruning/iterative_pruner.py              +576  −0
nni/algorithms/compression/pytorch/pruning/lottery_ticket.py                +1    −1
nni/algorithms/compression/pytorch/pruning/one_shot_pruner.py               +169  −0
nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py     +12   −4
nni/algorithms/compression/pytorch/quantization/quantizers.py               +10   −2
nni/algorithms/compression/tensorflow/pruning/__init__.py                   +1    −1
nni/algorithms/compression/tensorflow/pruning/one_shot_pruner.py            +0    −0
nni/compression/pytorch/compressor.py                                       +39   −10
examples/model_compress/pruning/naive_prune_torch.py

@@ -10,15 +10,16 @@ import logging
 import argparse
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
 from torchvision import datasets, transforms
 from torch.optim.lr_scheduler import StepLR
-from models.mnist.lenet import LeNet
 from nni.algorithms.compression.pytorch.pruning import LevelPruner
 import nni
+import sys
+sys.path.append('../models')
+from mnist.lenet import LeNet

 _logger = logging.getLogger('mnist_example')
 _logger.setLevel(logging.INFO)

@@ -108,7 +109,7 @@ def main(args):
         'op_types': ['default'],
     }]
-    pruner = LevelPruner(model, prune_config, optimizer_finetune)
+    pruner = LevelPruner(model, prune_config)
     model = pruner.compress()

     # fine-tuning

@@ -149,5 +150,4 @@ if __name__ == '__main__':
                         help='target overall target sparsity')
     args = parser.parse_args()
     main(args)
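The hunk above is the user-facing part of this commit: one-shot pruners such as LevelPruner no longer take an optimizer in their constructor. A minimal sketch of the updated call, with a stand-in model instead of the example's LeNet and an assumed sparsity value (the script reads it from argparse):

import torch
from nni.algorithms.compression.pytorch.pruning import LevelPruner

# stand-in model; the example uses LeNet from ../models
model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.Linear(10, 2))

prune_config = [{
    'sparsity': 0.5,          # assumed value for illustration
    'op_types': ['default'],
}]

pruner = LevelPruner(model, prune_config)   # no optimizer argument any more
model = pruner.compress()                   # masks are computed and applied once, up front
# fine-tuning then proceeds with a regular optimizer, outside the pruner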
examples/model_compress/quantization/BNN_quantizer_cifar10.py

@@ -31,7 +31,6 @@ class VGG_Cifar10(nn.Module):
             nn.BatchNorm2d(256, eps=1e-4, momentum=0.1),
             nn.Hardtanh(inplace=True),
             nn.Conv2d(256, 512, kernel_size=3, padding=1, bias=False),
             nn.BatchNorm2d(512, eps=1e-4, momentum=0.1),
             nn.Hardtanh(inplace=True),
examples/model_compress/quantization/DoReFaQuantizer_torch_mnist.py

@@ -3,27 +3,9 @@ import torch.nn.functional as F
 from torchvision import datasets, transforms
 from nni.algorithms.compression.pytorch.quantization import DoReFaQuantizer

-class Mnist(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
-        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
-        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
-        self.fc2 = torch.nn.Linear(500, 10)
-        self.relu1 = torch.nn.ReLU6()
-        self.relu2 = torch.nn.ReLU6()
-        self.relu3 = torch.nn.ReLU6()
-
-    def forward(self, x):
-        x = self.relu1(self.conv1(x))
-        x = F.max_pool2d(x, 2, 2)
-        x = self.relu2(self.conv2(x))
-        x = F.max_pool2d(x, 2, 2)
-        x = x.view(-1, 4 * 4 * 50)
-        x = self.relu3(self.fc1(x))
-        x = self.fc2(x)
-        return F.log_softmax(x, dim=1)
+import sys
+sys.path.append('../models')
+from mnist.naive import NaiveModel

 def train(model, quantizer, device, train_loader, optimizer):

@@ -66,7 +48,7 @@ def main():
         datasets.MNIST('data', train=False, transform=trans),
         batch_size=1000, shuffle=True)

-    model = Mnist()
+    model = NaiveModel()
     model = model.to(device)
     configure_list = [{
         'quant_types': ['weight'],
examples/model_compress/quantization/QAT_torch_quantizer.py

@@ -3,28 +3,9 @@ import torch.nn.functional as F
 from torchvision import datasets, transforms
 from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer

-class Mnist(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
-        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
-        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
-        self.fc2 = torch.nn.Linear(500, 10)
-        self.relu1 = torch.nn.ReLU6()
-        self.relu2 = torch.nn.ReLU6()
-        self.relu3 = torch.nn.ReLU6()
-
-    def forward(self, x):
-        x = self.relu1(self.conv1(x))
-        x = F.max_pool2d(x, 2, 2)
-        x = self.relu2(self.conv2(x))
-        x = F.max_pool2d(x, 2, 2)
-        x = x.view(-1, 4 * 4 * 50)
-        x = self.relu3(self.fc1(x))
-        x = self.fc2(x)
-        return F.log_softmax(x, dim=1)
+import sys
+sys.path.append('../models')
+from mnist.naive import NaiveModel

 def train(model, quantizer, device, train_loader, optimizer):
     model.train()

@@ -66,7 +47,7 @@ def main():
         datasets.MNIST('data', train=False, transform=trans),
         batch_size=1000, shuffle=True)

-    model = Mnist()
+    model = NaiveModel()
     '''you can change this to DoReFaQuantizer to implement it
         DoReFaQuantizer(configure_list).compress(model)
     '''
examples/model_compress/quantization/mixed_precision_speedup_mnist.py

@@ -5,28 +5,9 @@ from torchvision import datasets, transforms
 from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
 from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT

-class Mnist(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
-        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
-        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
-        self.fc2 = torch.nn.Linear(500, 10)
-        self.relu1 = torch.nn.ReLU6()
-        self.relu2 = torch.nn.ReLU6()
-        self.relu3 = torch.nn.ReLU6()
-        self.max_pool1 = torch.nn.MaxPool2d(2, 2)
-        self.max_pool2 = torch.nn.MaxPool2d(2, 2)
-
-    def forward(self, x):
-        x = self.relu1(self.conv1(x))
-        x = self.max_pool1(x)
-        x = self.relu2(self.conv2(x))
-        x = self.max_pool2(x)
-        x = x.view(-1, 4 * 4 * 50)
-        x = self.relu3(self.fc1(x))
-        x = self.fc2(x)
-        return F.log_softmax(x, dim=1)
+import sys
+sys.path.append('../models')
+from mnist.naive import NaiveModel

 def train(model, device, train_loader, optimizer):

@@ -74,7 +55,7 @@ def test_trt(engine, test_loader):
     print("Inference elapsed_time (whole dataset): {}s".format(time_elasped))

 def post_training_quantization_example(train_loader, test_loader, device):
-    model = Mnist()
+    model = NaiveModel()

     config = {
         'conv1':{'weight_bit':8, 'activation_bit':8},

@@ -99,7 +80,7 @@ def post_training_quantization_example(train_loader, test_loader, device):
     test_trt(engine, test_loader)

 def quantization_aware_training_example(train_loader, test_loader, device):
-    model = Mnist()
+    model = NaiveModel()

     configure_list = [{
         'quant_types': ['weight', 'output'],
nni/algorithms/compression/pytorch/pruning/__init__.py

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-from .finegrained_pruning import *
-from .structured_pruning import *
-from .one_shot import *
-from .agp import *
+from .finegrained_pruning_masker import *
+from .structured_pruning_masker import *
+from .one_shot_pruner import *
+from .iterative_pruner import *
 from .lottery_ticket import LotteryTicketPruner
 from .simulated_annealing_pruner import SimulatedAnnealingPruner
 from .net_adapt_pruner import NetAdaptPruner
-from .admm_pruner import ADMMPruner
 from .auto_compress_pruner import AutoCompressPruner
 from .sensitivity_pruner import SensitivityPruner
 from .amc import AMCPruner
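The public import path is unchanged by the renames; the star-imports above are expected to re-export the pruner classes from their new modules, with iterative_pruner.py presumably now carrying ADMMPruner and AGPPruner. A hedged sketch of imports that should still resolve after this commit (class locations inferred from the files touched in this diff, not from a pinned release):

from nni.algorithms.compression.pytorch.pruning import (
    LevelPruner,          # one_shot_pruner.py
    L1FilterPruner,       # one_shot_pruner.py
    AGPPruner,            # previously agp.py, expected in iterative_pruner.py
    ADMMPruner,           # previously admm_pruner.py, expected in iterative_pruner.py
    LotteryTicketPruner,  # lottery_ticket.py
)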
nni/algorithms/compression/pytorch/pruning/admm_pruner.py    (deleted, 100644 → 0)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import torch
from schema import And, Optional
import copy

from nni.compression.pytorch.utils.config_validation import CompressorSchema
from .constants import MASKER_DICT
from .one_shot import OneshotPruner

_logger = logging.getLogger(__name__)


class ADMMPruner(OneshotPruner):
    """
    A Pytorch implementation of ADMM Pruner algorithm.

    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned.
    config_list : list
        List on pruning configs.
    trainer : function
        Function used for the first subproblem.
        Users should write this function as a normal function to train the Pytorch model
        and include `model, optimizer, criterion, epoch, callback` as function arguments.
        Here `callback` acts as an L2 regulizer as presented in the formula (7) of the original paper.
        The logic of `callback` is implemented inside the Pruner,
        users are just required to insert `callback()` between `loss.backward()` and `optimizer.step()`.
        Example::

            def trainer(model, criterion, optimizer, epoch, callback):
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                train_loader = ...
                model.train()
                for batch_idx, (data, target) in enumerate(train_loader):
                    data, target = data.to(device), target.to(device)
                    optimizer.zero_grad()
                    output = model(data)
                    loss = criterion(output, target)
                    loss.backward()
                    # callback should be inserted between loss.backward() and optimizer.step()
                    if callback:
                        callback()
                    optimizer.step()
    num_iterations : int
        Total number of iterations.
    training_epochs : int
        Training epochs of the first subproblem.
    row : float
        Penalty parameters for ADMM training.
    base_algo : str
        Base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`. Given the sparsity distribution among the ops,
        the assigned `base_algo` is used to decide which filters/channels/weights to prune.
    """

    def __init__(self, model, config_list, trainer, num_iterations=30, training_epochs=5, row=1e-4, base_algo='l1'):
        self._base_algo = base_algo

        super().__init__(model, config_list)

        self._trainer = trainer
        self._num_iterations = num_iterations
        self._training_epochs = training_epochs
        self._row = row

        self.set_wrappers_attribute("if_calculated", False)
        self.masker = MASKER_DICT[self._base_algo](self.bound_model, self)

    def validate_config(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned
        config_list : list
            List on pruning configs
        """
        if self._base_algo == 'level':
            schema = CompressorSchema([{
                'sparsity': And(float, lambda n: 0 < n < 1),
                Optional('op_types'): [str],
                Optional('op_names'): [str],
            }], model, _logger)
        elif self._base_algo in ['l1', 'l2', 'fpgm']:
            schema = CompressorSchema([{
                'sparsity': And(float, lambda n: 0 < n < 1),
                'op_types': ['Conv2d'],
                Optional('op_names'): [str]
            }], model, _logger)

        schema.validate(config_list)

    def _projection(self, weight, sparsity, wrapper):
        '''
        Return the Euclidean projection of the weight matrix according to the pruning mode.

        Parameters
        ----------
        weight : tensor
            original matrix
        sparsity : float
            the ratio of parameters which need to be set to zero
        wrapper: PrunerModuleWrapper
            layer wrapper of this layer

        Returns
        -------
        tensor
            the projected matrix
        '''
        wrapper_copy = copy.deepcopy(wrapper)
        wrapper_copy.module.weight.data = weight
        return weight.data.mul(self.masker.calc_mask(sparsity, wrapper_copy)['weight_mask'])

    def compress(self):
        """
        Compress the model with ADMM.

        Returns
        -------
        torch.nn.Module
            model with specified modules compressed.
        """
        _logger.info('Starting ADMM Compression...')

        # initiaze Z, U
        # Z_i^0 = W_i^0
        # U_i^0 = 0
        Z = []
        U = []
        for wrapper in self.get_modules_wrapper():
            z = wrapper.module.weight.data
            Z.append(z)
            U.append(torch.zeros_like(z))

        optimizer = torch.optim.Adam(self.bound_model.parameters(), lr=1e-3, weight_decay=5e-5)

        # Loss = cross_entropy + l2 regulization + \Sum_{i=1}^N \row_i ||W_i - Z_i^k + U_i^k||^2
        criterion = torch.nn.CrossEntropyLoss()

        # callback function to do additonal optimization, refer to the deriatives of Formula (7)
        def callback():
            for i, wrapper in enumerate(self.get_modules_wrapper()):
                wrapper.module.weight.data -= self._row * \
                    (wrapper.module.weight.data - Z[i] + U[i])

        # optimization iteration
        for k in range(self._num_iterations):
            _logger.info('ADMM iteration : %d', k)

            # step 1: optimize W with AdamOptimizer
            for epoch in range(self._training_epochs):
                self._trainer(self.bound_model, optimizer=optimizer, criterion=criterion, epoch=epoch, callback=callback)

            # step 2: update Z, U
            # Z_i^{k+1} = projection(W_i^{k+1} + U_i^k)
            # U_i^{k+1} = U^k + W_i^{k+1} - Z_i^{k+1}
            for i, wrapper in enumerate(self.get_modules_wrapper()):
                z = wrapper.module.weight.data + U[i]
                Z[i] = self._projection(z, wrapper.config['sparsity'], wrapper)
                U[i] = U[i] + wrapper.module.weight.data - Z[i]

        # apply prune
        self.update_mask()

        _logger.info('Compression finished.')

        return self.bound_model
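For reference, the core of the deleted compress() is the classic two-block ADMM loop: train W with the callback acting as the augmented-Lagrangian term, then refresh the auxiliary variable Z and the dual variable U. Below is a self-contained sketch of that second step only, with a hypothetical project() standing in for _projection() and a toy magnitude-based projection:

import torch

def admm_zu_step(weight, U, project, sparsity):
    """One ADMM auxiliary/dual update, mirroring step 2 of the deleted compress():
    Z^{k+1} = projection(W^{k+1} + U^k),  U^{k+1} = U^k + W^{k+1} - Z^{k+1}."""
    Z_new = project(weight + U, sparsity)   # `project` is a stand-in for _projection
    U_new = U + weight - Z_new
    return Z_new, U_new

# toy projection: zero out roughly the smallest-magnitude `sparsity` fraction of entries
def project(w, sparsity):
    threshold = w.abs().flatten().kthvalue(int(sparsity * w.numel())).values
    return w * (w.abs() >= threshold)

w = torch.randn(4, 4)
U = torch.zeros_like(w)
Z, U = admm_zu_step(w, U, project, sparsity=0.5)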
nni/algorithms/compression/pytorch/pruning/agp.py    (deleted, 100644 → 0)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
An automated gradual pruning algorithm that prunes the smallest magnitude
weights to achieve a preset level of network sparsity.
Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the
efficacy of pruning for model compression", 2017 NIPS Workshop on Machine
Learning of Phones and other Consumer Devices.
"""

import logging
import torch
from schema import And, Optional

from .constants import MASKER_DICT
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from nni.compression.pytorch.compressor import Pruner

__all__ = ['AGPPruner']

logger = logging.getLogger('torch pruner')


class AGPPruner(Pruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned.
    config_list : listlist
        Supported keys:
            - initial_sparsity: This is to specify the sparsity when compressor starts to compress.
            - final_sparsity: This is to specify the sparsity when compressor finishes to compress.
            - start_epoch: This is to specify the epoch number when compressor starts to compress, default start from epoch 0.
            - end_epoch: This is to specify the epoch number when compressor finishes to compress.
            - frequency: This is to specify every *frequency* number epochs compressor compress once, default frequency=1.
    optimizer: torch.optim.Optimizer
        Optimizer used to train model.
    pruning_algorithm: str
        Algorithms being used to prune model,
        choose from `['level', 'slim', 'l1', 'l2', 'fpgm', 'taylorfo', 'apoz', 'mean_activation']`, by default `level`
    """

    def __init__(self, model, config_list, optimizer, pruning_algorithm='level'):
        super().__init__(model, config_list, optimizer)
        assert isinstance(optimizer, torch.optim.Optimizer), "AGP pruner is an iterative pruner, please pass optimizer of the model to it"
        self.masker = MASKER_DICT[pruning_algorithm](model, self)
        self.now_epoch = 0
        self.set_wrappers_attribute("if_calculated", False)

    def validate_config(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned
        config_list : list
            List on pruning configs
        """
        schema = CompressorSchema([{
            'initial_sparsity': And(float, lambda n: 0 <= n <= 1),
            'final_sparsity': And(float, lambda n: 0 <= n <= 1),
            'start_epoch': And(int, lambda n: n >= 0),
            'end_epoch': And(int, lambda n: n >= 0),
            'frequency': And(int, lambda n: n > 0),
            Optional('op_types'): [str],
            Optional('op_names'): [str]
        }], model, logger)

        schema.validate(config_list)

    def calc_mask(self, wrapper, wrapper_idx=None):
        """
        Calculate the mask of given layer.
        Scale factors with the smallest absolute value in the BN layer are masked.

        Parameters
        ----------
        wrapper : Module
            the layer to instrument the compression operation
        wrapper_idx: int
            index of this wrapper in pruner's all wrappers

        Returns
        -------
        dict | None
            Dictionary for storing masks, keys of the dict:
            'weight_mask': weight mask tensor
            'bias_mask': bias mask tensor (optional)
        """
        config = wrapper.config

        start_epoch = config.get('start_epoch', 0)
        freq = config.get('frequency', 1)

        if wrapper.if_calculated:
            return None
        if not (self.now_epoch >= start_epoch and (self.now_epoch - start_epoch) % freq == 0):
            return None

        target_sparsity = self.compute_target_sparsity(config)
        new_mask = self.masker.calc_mask(sparsity=target_sparsity, wrapper=wrapper, wrapper_idx=wrapper_idx)
        if new_mask is not None:
            wrapper.if_calculated = True

        return new_mask

    def compute_target_sparsity(self, config):
        """
        Calculate the sparsity for pruning

        Parameters
        ----------
        config : dict
            Layer's pruning config

        Returns
        -------
        float
            Target sparsity to be pruned
        """
        end_epoch = config.get('end_epoch', 1)
        start_epoch = config.get('start_epoch', 0)
        freq = config.get('frequency', 1)
        final_sparsity = config.get('final_sparsity', 0)
        initial_sparsity = config.get('initial_sparsity', 0)

        if end_epoch <= start_epoch or initial_sparsity >= final_sparsity:
            logger.warning('your end epoch <= start epoch or initial_sparsity >= final_sparsity')
            return final_sparsity

        if end_epoch <= self.now_epoch:
            return final_sparsity

        span = ((end_epoch - start_epoch - 1) // freq) * freq
        assert span > 0
        target_sparsity = (final_sparsity +
                           (initial_sparsity - final_sparsity) *
                           (1.0 - ((self.now_epoch - start_epoch) / span)) ** 3)
        return target_sparsity

    def update_epoch(self, epoch):
        """
        Update epoch

        Parameters
        ----------
        epoch : int
            current training epoch
        """
        if epoch > 0:
            self.now_epoch = epoch
            for wrapper in self.get_modules_wrapper():
                wrapper.if_calculated = False
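The schedule implemented by compute_target_sparsity() above interpolates cubically from initial_sparsity to final_sparsity. A standalone sketch of the same formula with illustrative values (the numbers below are not taken from any NNI config):

def agp_target_sparsity(now_epoch, initial_sparsity=0.0, final_sparsity=0.8,
                        start_epoch=0, end_epoch=10, frequency=1):
    """Standalone sketch of AGPPruner.compute_target_sparsity."""
    if end_epoch <= start_epoch or initial_sparsity >= final_sparsity:
        return final_sparsity
    if end_epoch <= now_epoch:
        return final_sparsity
    span = ((end_epoch - start_epoch - 1) // frequency) * frequency
    return final_sparsity + (initial_sparsity - final_sparsity) * \
        (1.0 - (now_epoch - start_epoch) / span) ** 3

# the cubic schedule prunes aggressively at first and flattens out near end_epoch
print([round(agp_target_sparsity(e), 3) for e in range(0, 11)])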
nni/algorithms/compression/pytorch/pruning/auto_compress_pruner.py

@@ -13,8 +13,7 @@ from nni.compression.pytorch import ModelSpeedup
 from nni.compression.pytorch.compressor import Pruner
 from nni.compression.pytorch.utils.config_validation import CompressorSchema
 from .simulated_annealing_pruner import SimulatedAnnealingPruner
-from .admm_pruner import ADMMPruner
+from .iterative_pruner import ADMMPruner

 _logger = logging.getLogger(__name__)

@@ -34,26 +33,7 @@ class AutoCompressPruner(Pruner):
     trainer : function
         Function used for the first subproblem of ADMM Pruner.
         Users should write this function as a normal function to train the Pytorch model
-        and include `model, optimizer, criterion, epoch, callback` as function arguments.
-        Here `callback` acts as an L2 regulizer as presented in the formula (7) of the original paper.
-        The logic of `callback` is implemented inside the Pruner,
-        users are just required to insert `callback()` between `loss.backward()` and `optimizer.step()`.
-        Example::
-
-            def trainer(model, criterion, optimizer, epoch, callback):
-                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-                train_loader = ...
-                model.train()
-                for batch_idx, (data, target) in enumerate(train_loader):
-                    data, target = data.to(device), target.to(device)
-                    optimizer.zero_grad()
-                    output = model(data)
-                    loss = criterion(output, target)
-                    loss.backward()
-                    # callback should be inserted between loss.backward() and optimizer.step()
-                    if callback:
-                        callback()
-                    optimizer.step()
+        and include `model, optimizer, criterion, epoch` as function arguments.
     evaluator : function
         function to evaluate the pruned model.
         This function should include `model` as the only parameter, and returns a scalar value.

@@ -80,8 +60,8 @@ class AutoCompressPruner(Pruner):
     optimize_mode : str
         optimize mode, `maximize` or `minimize`, by default `maximize`.
     base_algo : str
-        Base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`. Given the sparsity distribution among the ops,
-        the assigned `base_algo` is used to decide which filters/channels/weights to prune.
+        Base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`. Given the sparsity distribution among the
+        ops, the assigned `base_algo` is used to decide which filters/channels/weights to prune.
     start_temperature : float
         Start temperature of the simulated annealing process.
     stop_temperature : float

@@ -92,7 +72,7 @@ class AutoCompressPruner(Pruner):
         Initial perturbation magnitude to the sparsities. The magnitude decreases with current temperature.
     admm_num_iterations : int
         Number of iterations of ADMM Pruner.
-    admm_training_epochs : int
+    admm_epochs_per_iteration : int
         Training epochs of the first optimization subproblem of ADMMPruner.
     row : float
         Penalty parameters for ADMM training.

@@ -100,18 +80,19 @@ class AutoCompressPruner(Pruner):
         PATH to store temporary experiment data.
     """

-    def __init__(self, model, config_list, trainer, evaluator, dummy_input,
+    def __init__(self, model, config_list, trainer, criterion, evaluator, dummy_input,
                  num_iterations=3, optimize_mode='maximize', base_algo='l1',
                  # SimulatedAnnealing related
                  start_temperature=100, stop_temperature=20, cool_down_rate=0.9, perturbation_magnitude=0.35,
                  # ADMM related
-                 admm_num_iterations=30, admm_training_epochs=5, row=1e-4,
+                 admm_num_iterations=30, admm_epochs_per_iteration=5, row=1e-4,
                  experiment_data_dir='./'):
         # original model
         self._model_to_prune = model
         self._base_algo = base_algo
         self._trainer = trainer
+        self._criterion = criterion
         self._evaluator = evaluator
         self._dummy_input = dummy_input
         self._num_iterations = num_iterations

@@ -125,7 +106,7 @@ class AutoCompressPruner(Pruner):
         # hyper parameters for ADMM algorithm
         self._admm_num_iterations = admm_num_iterations
-        self._admm_training_epochs = admm_training_epochs
+        self._admm_epochs_per_iteration = admm_epochs_per_iteration
         self._row = row

         # overall pruning rate

@@ -174,12 +155,12 @@ class AutoCompressPruner(Pruner):
         """
         _logger.info('Starting AutoCompress pruning...')

         sparsity_each_round = 1 - pow(1 - self._sparsity, 1 / self._num_iterations)

         for i in range(self._num_iterations):
             _logger.info('Pruning iteration: %d', i)
             _logger.info('Target sparsity this round: %s',
                          1 - pow(1 - sparsity_each_round, i + 1))

             # SimulatedAnnealingPruner
             _logger.info(

@@ -204,9 +185,10 @@ class AutoCompressPruner(Pruner):
             ADMMpruner = ADMMPruner(
                 model=copy.deepcopy(self._model_to_prune),
                 config_list=config_list,
+                criterion=self._criterion,
                 trainer=self._trainer,
                 num_iterations=self._admm_num_iterations,
-                training_epochs=self._admm_training_epochs,
+                epochs_per_iteration=self._admm_epochs_per_iteration,
                 row=self._row,
                 base_algo=self._base_algo)
             ADMMpruner.compress()

@@ -214,12 +196,13 @@ class AutoCompressPruner(Pruner):
             ADMMpruner.export_model(os.path.join(self._experiment_data_dir, 'model_admm_masked.pth'),
                                     os.path.join(self._experiment_data_dir, 'mask.pth'))

-            # use speed up to prune the model before next iteration, because SimulatedAnnealingPruner & ADMMPruner don't take masked models
+            # use speed up to prune the model before next iteration,
+            # because SimulatedAnnealingPruner & ADMMPruner don't take masked models
             self._model_to_prune.load_state_dict(torch.load(os.path.join(
                 self._experiment_data_dir, 'model_admm_masked.pth')))

             masks_file = os.path.join(self._experiment_data_dir, 'mask.pth')
-            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            device = next(self._model_to_prune.parameters()).device

             _logger.info('Speeding up models...')
             m_speedup = ModelSpeedup(self._model_to_prune, self._dummy_input, masks_file, device)
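Putting the renamed arguments together, here is a hedged sketch of constructing the updated AutoCompressPruner. The keyword names follow the signature shown in the hunks above; `model`, `trainer`, `evaluator` and `dummy_input` are assumed to be defined elsewhere (the trainer now takes `model, optimizer, criterion, epoch` and the evaluator takes `model` and returns a scalar):

import torch
from nni.algorithms.compression.pytorch.pruning import AutoCompressPruner

config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
pruner = AutoCompressPruner(
    model, config_list, trainer=trainer,
    criterion=torch.nn.CrossEntropyLoss(),   # new: the loss is passed in explicitly
    evaluator=evaluator, dummy_input=dummy_input,
    num_iterations=3, admm_num_iterations=30,
    admm_epochs_per_iteration=5)             # renamed from admm_training_epochs
model = pruner.compress()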
nni/algorithms/compression/pytorch/pruning/constants_pruner.py

@@ -2,7 +2,7 @@
 # Licensed under the MIT license.

-from .one_shot import LevelPruner, L1FilterPruner, L2FilterPruner, FPGMPruner
+from .one_shot_pruner import LevelPruner, L1FilterPruner, L2FilterPruner, FPGMPruner

 PRUNER_DICT = {
     'level': LevelPruner,
nni/algorithms/compression/pytorch/pruning/dependency_aware_pruner.py    (new file, 0 → 100644)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from schema import And, Optional, SchemaError
from nni.common.graph_utils import TorchModuleGraph
from nni.compression.pytorch.utils.shape_dependency import ChannelDependency, GroupDependency
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from nni.compression.pytorch.compressor import Pruner
from .constants import MASKER_DICT

__all__ = ['DependencyAwarePruner']

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class DependencyAwarePruner(Pruner):
    """
    DependencyAwarePruner has two ways to calculate the masks
    for conv layers. In the normal way, the DependencyAwarePruner
    will calculate the mask of each layer separately. For example, each
    conv layer determine which filters should be pruned according to its L1
    norm. In constrast, in the dependency-aware way, the layers that in a
    dependency group will be pruned jointly and these layers will be forced
    to prune the same channels.
    """

    def __init__(self, model, config_list, optimizer=None, pruning_algorithm='level', dependency_aware=False,
                 dummy_input=None, **algo_kwargs):
        super().__init__(model, config_list=config_list, optimizer=optimizer)

        self.dependency_aware = dependency_aware
        self.dummy_input = dummy_input

        if self.dependency_aware:
            if not self._supported_dependency_aware():
                raise ValueError('This pruner does not support dependency aware!')

            errmsg = "When dependency_aware is set, the dummy_input should not be None"
            assert self.dummy_input is not None, errmsg
            # Get the TorchModuleGraph of the target model
            # to trace the model, we need to unwrap the wrappers
            self._unwrap_model()
            self.graph = TorchModuleGraph(model, dummy_input)
            self._wrap_model()
            self.channel_depen = ChannelDependency(traced_model=self.graph.trace)
            self.group_depen = GroupDependency(traced_model=self.graph.trace)
            self.channel_depen = self.channel_depen.dependency_sets
            self.channel_depen = {name: sets for sets in self.channel_depen for name in sets}
            self.group_depen = self.group_depen.dependency_sets

        self.masker = MASKER_DICT[pruning_algorithm](model, self, **algo_kwargs)
        # set the dependency-aware switch for the masker
        self.masker.dependency_aware = dependency_aware
        self.set_wrappers_attribute("if_calculated", False)

    def calc_mask(self, wrapper, wrapper_idx=None):
        if not wrapper.if_calculated:
            sparsity = wrapper.config['sparsity']
            masks = self.masker.calc_mask(sparsity=sparsity, wrapper=wrapper, wrapper_idx=wrapper_idx)

            # masker.calc_mask returns None means calc_mask is not calculated sucessfully, can try later
            if masks is not None:
                wrapper.if_calculated = True
            return masks
        else:
            return None

    def update_mask(self):
        if not self.dependency_aware:
            # if we use the normal way to update the mask,
            # then call the update_mask of the father class
            super(DependencyAwarePruner, self).update_mask()
        else:
            # if we update the mask in a dependency-aware way
            # then we call _dependency_update_mask
            self._dependency_update_mask()

    def validate_config(self, model, config_list):
        schema = CompressorSchema([{
            Optional('sparsity'): And(float, lambda n: 0 < n < 1),
            Optional('op_types'): ['Conv2d'],
            Optional('op_names'): [str],
            Optional('exclude'): bool
        }], model, logger)

        schema.validate(config_list)
        for config in config_list:
            if 'exclude' not in config and 'sparsity' not in config:
                raise SchemaError('Either sparisty or exclude must be specified!')

    def _supported_dependency_aware(self):
        raise NotImplementedError

    def _dependency_calc_mask(self, wrappers, channel_dsets, wrappers_idx=None):
        """
        calculate the masks for the conv layers in the same
        channel dependecy set. All the layers passed in have
        the same number of channels.

        Parameters
        ----------
        wrappers: list
            The list of the wrappers that in the same channel dependency
            set.
        wrappers_idx: list
            The list of the indexes of wrapppers.

        Returns
        -------
        masks: dict
            A dict object that contains the masks of the layers in this
            dependency group, the key is the name of the convolutional layers.
        """
        # The number of the groups for each conv layers
        # Note that, this number may be different from its
        # original number of groups of filters.
        groups = [self.group_depen[_w.name] for _w in wrappers]
        sparsities = [_w.config['sparsity'] for _w in wrappers]
        masks = self.masker.calc_mask(
            sparsities, wrappers, wrappers_idx, channel_dsets=channel_dsets, groups=groups)
        if masks is not None:
            # if masks is None, then the mask calculation fails.
            # for example, in activation related maskers, we should
            # pass enough batches of data to the model, so that the
            # masks can be calculated successfully.
            for _w in wrappers:
                _w.if_calculated = True
        return masks

    def _dependency_update_mask(self):
        """
        In the original update_mask, the wraper of each layer will update its
        own mask according to the sparsity specified in the config_list. However, in
        the _dependency_update_mask, we may prune several layers at the same
        time according the sparsities and the channel/group dependencies.
        """
        name2wrapper = {x.name: x for x in self.get_modules_wrapper()}
        wrapper2index = {x: i for i, x in enumerate(self.get_modules_wrapper())}
        for wrapper in self.get_modules_wrapper():
            if wrapper.if_calculated:
                continue
            # find all the conv layers that have channel dependecy with this layer
            # and prune all these layers at the same time.
            _names = [x for x in self.channel_depen[wrapper.name]]
            logger.info('Pruning the dependent layers: %s', ','.join(_names))
            _wrappers = [name2wrapper[name] for name in _names if name in name2wrapper]
            _wrapper_idxes = [wrapper2index[_w] for _w in _wrappers]

            masks = self._dependency_calc_mask(_wrappers, _names, wrappers_idx=_wrapper_idxes)
            if masks is not None:
                for layer in masks:
                    for mask_type in masks[layer]:
                        assert hasattr(name2wrapper[layer], mask_type), "there is no attribute '%s' in wrapper on %s" \
                            % (mask_type, layer)
                        setattr(name2wrapper[layer], mask_type, masks[layer][mask_type])
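The dict comprehension in the constructor above turns the list of channel-dependency sets into a per-layer lookup. A toy illustration with made-up layer names:

# each element of dependency_sets groups layers that must prune the same channels
dependency_sets = [{'conv1', 'conv2'}, {'conv3'}]
channel_depen = {name: sets for sets in dependency_sets for name in sets}
# channel_depen == {'conv1': {'conv1', 'conv2'}, 'conv2': {'conv1', 'conv2'}, 'conv3': {'conv3'}}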
nni/algorithms/compression/pytorch/pruning/finegrained_pruning.py → nni/algorithms/compression/pytorch/pruning/finegrained_pruning_masker.py    (file moved)
nni/algorithms/compression/pytorch/pruning/one_shot.py → nni/algorithms/compression/pytorch/pruning/iterative_pruner.py    (renamed; diff collapsed in this view)
nni/algorithms/compression/pytorch/pruning/lottery_ticket.py

@@ -7,7 +7,7 @@ import torch
 from schema import And, Optional

 from nni.compression.pytorch.utils.config_validation import CompressorSchema
 from nni.compression.pytorch.compressor import Pruner
-from .finegrained_pruning import LevelPrunerMasker
+from .finegrained_pruning_masker import LevelPrunerMasker

 logger = logging.getLogger('torch pruner')
nni/algorithms/compression/pytorch/pruning/one_shot_pruner.py    (new file, 0 → 100644)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from schema import And, Optional

from nni.compression.pytorch.utils.config_validation import CompressorSchema
from .dependency_aware_pruner import DependencyAwarePruner

__all__ = ['LevelPruner', 'L1FilterPruner', 'L2FilterPruner', 'FPGMPruner']

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class OneshotPruner(DependencyAwarePruner):
    """
    Prune model to an exact pruning level for one time.
    """

    def __init__(self, model, config_list, pruning_algorithm='level', dependency_aware=False, dummy_input=None,
                 **algo_kwargs):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned
        config_list : list
            List on pruning configs
        pruning_algorithm: str
            algorithms being used to prune model
        dependency_aware: bool
            If prune the model in a dependency-aware way.
        dummy_input : torch.Tensor
            The dummy input to analyze the topology constraints. Note that,
            the dummy_input should on the same device with the model.
        algo_kwargs: dict
            Additional parameters passed to pruning algorithm masker class
        """
        super().__init__(model, config_list, None, pruning_algorithm, dependency_aware, dummy_input, **algo_kwargs)

    def validate_config(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned
        config_list : list
            List on pruning configs
        """
        schema = CompressorSchema([{
            'sparsity': And(float, lambda n: 0 < n < 1),
            Optional('op_types'): [str],
            Optional('op_names'): [str]
        }], model, logger)

        schema.validate(config_list)


class LevelPruner(OneshotPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : This is to specify the sparsity operations to be compressed to.
            - op_types : Operation types to prune.
    """

    def __init__(self, model, config_list):
        super().__init__(model, config_list, pruning_algorithm='level')

    def _supported_dependency_aware(self):
        return False


class L1FilterPruner(OneshotPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : This is to specify the sparsity operations to be compressed to.
            - op_types : Only Conv2d is supported in L1FilterPruner.
    dependency_aware: bool
        If prune the model in a dependency-aware way. If it is `True`, this pruner will
        prune the model according to the l2-norm of weights and the channel-dependency or
        group-dependency of the model. In this way, the pruner will force the conv layers
        that have dependencies to prune the same channels, so the speedup module can better
        harvest the speed benefit from the pruned model. Note that, if this flag is set True
        , the dummy_input cannot be None, because the pruner needs a dummy input to trace the
        dependency between the conv layers.
    dummy_input : torch.Tensor
        The dummy input to analyze the topology constraints. Note that, the dummy_input
        should on the same device with the model.
    """

    def __init__(self, model, config_list, dependency_aware=False, dummy_input=None):
        super().__init__(model, config_list, pruning_algorithm='l1', dependency_aware=dependency_aware,
                         dummy_input=dummy_input)

    def _supported_dependency_aware(self):
        return True


class L2FilterPruner(OneshotPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : This is to specify the sparsity operations to be compressed to.
            - op_types : Only Conv2d is supported in L2FilterPruner.
    dependency_aware: bool
        If prune the model in a dependency-aware way. If it is `True`, this pruner will
        prune the model according to the l2-norm of weights and the channel-dependency or
        group-dependency of the model. In this way, the pruner will force the conv layers
        that have dependencies to prune the same channels, so the speedup module can better
        harvest the speed benefit from the pruned model. Note that, if this flag is set True
        , the dummy_input cannot be None, because the pruner needs a dummy input to trace the
        dependency between the conv layers.
    dummy_input : torch.Tensor
        The dummy input to analyze the topology constraints. Note that, the dummy_input
        should on the same device with the model.
    """

    def __init__(self, model, config_list, dependency_aware=False, dummy_input=None):
        super().__init__(model, config_list, pruning_algorithm='l2', dependency_aware=dependency_aware,
                         dummy_input=dummy_input)

    def _supported_dependency_aware(self):
        return True


class FPGMPruner(OneshotPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : This is to specify the sparsity operations to be compressed to.
            - op_types : Only Conv2d is supported in FPGM Pruner.
    dependency_aware: bool
        If prune the model in a dependency-aware way. If it is `True`, this pruner will
        prune the model according to the l2-norm of weights and the channel-dependency or
        group-dependency of the model. In this way, the pruner will force the conv layers
        that have dependencies to prune the same channels, so the speedup module can better
        harvest the speed benefit from the pruned model. Note that, if this flag is set True
        , the dummy_input cannot be None, because the pruner needs a dummy input to trace the
        dependency between the conv layers.
    dummy_input : torch.Tensor
        The dummy input to analyze the topology constraints. Note that, the dummy_input
        should on the same device with the model.
    """

    def __init__(self, model, config_list, dependency_aware=False, dummy_input=None):
        super().__init__(model, config_list, pruning_algorithm='fpgm', dependency_aware=dependency_aware,
                         dummy_input=dummy_input)

    def _supported_dependency_aware(self):
        return True
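A minimal sketch of the dependency-aware mode described in the docstrings above, assuming torchvision is available to provide a stand-in network with channel dependencies (skip connections); the dummy input must live on the same device as the model:

import torch
from torchvision.models import resnet18
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner

model = resnet18()
dummy_input = torch.rand(1, 3, 224, 224)   # required when dependency_aware=True
config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]

pruner = L1FilterPruner(model, config_list, dependency_aware=True, dummy_input=dummy_input)
model = pruner.compress()                  # dependent conv layers are pruned jointly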
nni/algorithms/compression/pytorch/pruning/structured_pruning.py → nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py    (renamed)

@@ -474,8 +474,8 @@ class TaylorFOWeightFilterPrunerMasker(StructuredWeightMasker):
     def __init__(self, model, pruner, statistics_batch_num=1):
         super().__init__(model, pruner)
         self.pruner.statistics_batch_num = statistics_batch_num
-        self.pruner.set_wrappers_attribute("contribution", None)
         self.pruner.iterations = 0
+        self.pruner.set_wrappers_attribute("contribution", None)
         self.pruner.patch_optimizer(self.calc_contributions)

     def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx, channel_masks=None):

@@ -499,6 +499,7 @@ class TaylorFOWeightFilterPrunerMasker(StructuredWeightMasker):
         """
         if self.pruner.iterations >= self.pruner.statistics_batch_num:
             return

         for wrapper in self.pruner.get_modules_wrapper():
             filters = wrapper.module.weight.size(0)
             contribution = (

@@ -677,16 +678,24 @@ class SlimPrunerMasker(WeightMasker):
     def __init__(self, model, pruner, **kwargs):
         super().__init__(model, pruner)
         self.global_threshold = None

     def _get_global_threshold(self):
         weight_list = []
-        for (layer, _) in pruner.get_modules_to_compress():
+        for (layer, _) in self.pruner.get_modules_to_compress():
             weight_list.append(layer.module.weight.data.abs().clone())
         all_bn_weights = torch.cat(weight_list)
-        k = int(all_bn_weights.shape[0] * pruner.config_list[0]['sparsity'])
+        k = int(all_bn_weights.shape[0] * self.pruner.config_list[0]['sparsity'])
         self.global_threshold = torch.topk(
             all_bn_weights.view(-1), k, largest=False)[0].max()
         print(f'set global threshold to {self.global_threshold}')

     def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
         assert wrapper.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
         if self.global_threshold is None:
             self._get_global_threshold()
         weight = wrapper.module.weight.data.clone()
         if wrapper.weight_mask is not None:
             # apply base mask for iterative pruning

@@ -706,7 +715,6 @@ class SlimPrunerMasker(WeightMasker):
             ), 'bias_mask': mask_bias.detach()}
         return mask

 def least_square_sklearn(X, Y):
     from sklearn.linear_model import LinearRegression
     reg = LinearRegression(fit_intercept=False)
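The _get_global_threshold() shown above ranks the absolute BatchNorm scale factors of all configured layers at once and keeps the k-th smallest value as the global pruning threshold. A standalone sketch of the same computation on made-up tensors (real usage reads the gamma weights of the model's BatchNorm2d layers):

import torch

sparsity = 0.5
bn_scales = [torch.rand(16), torch.rand(32)]             # |gamma| of two BatchNorm2d layers
all_bn_weights = torch.cat([w.abs() for w in bn_scales])
k = int(all_bn_weights.shape[0] * sparsity)
global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
mask = (bn_scales[0].abs() > global_threshold).float()   # channels kept in the first layer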
nni/algorithms/compression/pytorch/quantization/quantizers.py

@@ -148,6 +148,7 @@ class QAT_Quantizer(Quantizer):
         super().__init__(model, config_list, optimizer)
         self.quant_grad = QATGrad.apply
         modules_to_compress = self.get_modules_to_compress()
+        device = next(model.parameters()).device
         self.bound_model.register_buffer("steps", torch.Tensor([1]))
         for layer, config in modules_to_compress:
             layer.module.register_buffer("zero_point", torch.Tensor([0.0]))

@@ -161,7 +162,7 @@ class QAT_Quantizer(Quantizer):
             layer.module.register_buffer('activation_bit', torch.zeros(1))
             layer.module.register_buffer('tracked_min_activation', torch.zeros(1))
             layer.module.register_buffer('tracked_max_activation', torch.zeros(1))
+        self.bound_model.to(device)

     def _del_simulated_attr(self, module):
         """

@@ -359,7 +360,7 @@ class QAT_Quantizer(Quantizer):
         """
         override `compressor` `step` method, quantization only happens after certain number of steps
         """
         self.bound_model.steps += 1

 class DoReFaQuantizer(Quantizer):

@@ -370,10 +371,12 @@ class DoReFaQuantizer(Quantizer):
     def __init__(self, model, config_list, optimizer=None):
         super().__init__(model, config_list, optimizer)
+        device = next(model.parameters()).device
         modules_to_compress = self.get_modules_to_compress()
         for layer, config in modules_to_compress:
             if "weight" in config.get("quant_types", []):
                 layer.module.register_buffer('weight_bit', torch.zeros(1))
+        self.bound_model.to(device)

     def _del_simulated_attr(self, module):
         """

@@ -474,11 +477,13 @@ class BNNQuantizer(Quantizer):
     def __init__(self, model, config_list, optimizer=None):
         super().__init__(model, config_list, optimizer)
+        device = next(model.parameters()).device
         self.quant_grad = ClipGrad.apply
         modules_to_compress = self.get_modules_to_compress()
         for layer, config in modules_to_compress:
             if "weight" in config.get("quant_types", []):
                 layer.module.register_buffer('weight_bit', torch.zeros(1))
+        self.bound_model.to(device)

     def _del_simulated_attr(self, module):
         """

@@ -589,6 +594,7 @@ class LsqQuantizer(Quantizer):
             types of nn.module you want to apply quantization, eg. 'Conv2d'
         """
         super().__init__(model, config_list, optimizer)
+        device = next(model.parameters()).device
         self.quant_grad = QuantForward()
         modules_to_compress = self.get_modules_to_compress()
         self.bound_model.register_buffer("steps", torch.Tensor([1]))

@@ -631,6 +637,8 @@ class LsqQuantizer(Quantizer):
                     self.optimizer.add_param_group({"params": layer.module.input_scale})
+        self.bound_model.to(device)

     @staticmethod
     def grad_scale(x, scale):
         """
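The lines added to each quantizer constructor above follow one pattern: capture the device of the original parameters before registering new buffers, then move the bound model back so the buffers end up on the same device as the weights. A small sketch of that pattern in isolation:

import torch

model = torch.nn.Linear(4, 2)
device = next(model.parameters()).device              # remember where the model lives
model.register_buffer('weight_bit', torch.zeros(1))   # new buffers are created on CPU
model.to(device)                                      # keep parameters and buffers together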
nni/algorithms/compression/tensorflow/pruning/__init__.py

-from .one_shot import *
+from .one_shot_pruner import *
nni/algorithms/compression/tensorflow/pruning/one_shot.py → nni/algorithms/compression/tensorflow/pruning/one_shot_pruner.py    (file moved)
nni/compression/pytorch/compressor.py

@@ -8,7 +8,6 @@ from . import default_layers
 _logger = logging.getLogger(__name__)

 class LayerInfo:
     def __init__(self, name, module):
         self.module = module

@@ -235,7 +234,6 @@ class Compressor:
         """
         raise NotImplementedError()

     def add_activation_collector(self, collector):
         self._fwd_hook_id += 1
         self._fwd_hook_handles[self._fwd_hook_id] = []

@@ -264,6 +262,18 @@ class Compressor:
         if self.optimizer is not None:
             self.optimizer.step = types.MethodType(patch_step(self.optimizer.step), self.optimizer)

+    def patch_optimizer_before(self, *tasks):
+        def patch_step(old_step):
+            def new_step(_, *args, **kwargs):
+                for task in tasks:
+                    task()
+                # call origin optimizer step method
+                output = old_step(*args, **kwargs)
+                return output
+            return new_step
+        if self.optimizer is not None:
+            self.optimizer.step = types.MethodType(patch_step(self.optimizer.step), self.optimizer)

 class PrunerModuleWrapper(torch.nn.Module):
     def __init__(self, module, module_name, module_type, config, pruner):
         """

@@ -319,8 +329,6 @@ class Pruner(Compressor):
     def __init__(self, model, config_list, optimizer=None):
         super().__init__(model, config_list, optimizer)
-        if optimizer is not None:
-            self.patch_optimizer(self.update_mask)

     def compress(self):
         self.update_mask()

@@ -433,6 +441,27 @@ class Pruner(Compressor):
         else:
             self.bound_model.load_state_dict(model_state)

+    def get_pruned_weights(self, dim=0):
+        """
+        Log the simulated prune sparsity.
+
+        Parameters
+        ----------
+        dim : int
+            the pruned dim.
+        """
+        for _, wrapper in enumerate(self.get_modules_wrapper()):
+            weight_mask = wrapper.weight_mask
+            mask_size = weight_mask.size()
+            if len(mask_size) == 1:
+                index = torch.nonzero(weight_mask.abs() != 0).tolist()
+            else:
+                sum_idx = list(range(len(mask_size)))
+                sum_idx.remove(dim)
+                index = torch.nonzero(weight_mask.abs().sum(sum_idx) != 0).tolist()
+            _logger.info(f'simulated prune {wrapper.name} remain/total: {len(index)}/{weight_mask.size(dim)}')

 class QuantizerModuleWrapper(torch.nn.Module):
     def __init__(self, module, module_name, module_type, config, quantizer):
         """

@@ -549,7 +578,6 @@ class Quantizer(Compressor):
         """
         raise NotImplementedError('Quantizer must overload quantize_input()')

     def _wrap_modules(self, layer, config):
         """
         Create a wrapper forward function to replace the original one.

@@ -571,7 +599,7 @@ class Quantizer(Compressor):
         return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self)

-    def export_model_save(self, model, model_path, calibration_config=None, calibration_path=None, onnx_path=None, \
+    def export_model_save(self, model, model_path, calibration_config=None, calibration_path=None, onnx_path=None,
                           input_shape=None, device=None):
         """
         This method helps save pytorch model, calibration config, onnx model in quantizer.

@@ -671,6 +699,7 @@ class QuantGrad(torch.autograd.Function):
             quantized x without clamped
         """
         return ((x / scale) + zero_point).round()

     @classmethod
     def get_bits_length(cls, config, quant_type):
         """

@@ -703,8 +732,8 @@ class QuantGrad(torch.autograd.Function):
         grad_output : Tensor
             gradient of the output of quantization operation
         scale : Tensor
             the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, `QuantType.QUANT_OUTPUT`,
             you can define different behavior for different types.
         zero_point : Tensor
             zero_point for quantizing tensor
         qmin : Tensor
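The new patch_optimizer_before() above wraps optimizer.step() so that the given tasks run immediately before every step. A self-contained sketch of the same MethodType-based patching outside the Compressor class:

import types
import torch

def patch_optimizer_before(optimizer, *tasks):
    """Run `tasks` right before each optimizer.step() call (sketch of the pattern above)."""
    def patch_step(old_step):
        def new_step(_, *args, **kwargs):
            for task in tasks:
                task()                          # e.g. recompute masks, collect statistics
            return old_step(*args, **kwargs)    # then call the original bound step
        return new_step
    optimizer.step = types.MethodType(patch_step(optimizer.step), optimizer)

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
patch_optimizer_before(opt, lambda: print('before step'))
model(torch.rand(1, 4)).sum().backward()
opt.step()                                      # prints 'before step', then updates weights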