OpenDAS / nni / Commits

Commit 92f6754e (unverified)
Authored May 25, 2021 by colorjam, committed by GitHub on May 25, 2021
[Model Compression] Update api of iterative pruners (#3507)
Parent: 26f47727

Showing 20 changed files with 1010 additions and 461 deletions (+1010 −461)
examples/model_compress/pruning/naive_prune_torch.py                      (+6   −6)
examples/model_compress/quantization/BNN_quantizer_cifar10.py             (+0   −1)
examples/model_compress/quantization/DoReFaQuantizer_torch_mnist.py       (+4   −22)
examples/model_compress/quantization/QAT_torch_quantizer.py               (+4   −23)
examples/model_compress/quantization/mixed_precision_speedup_mnist.py     (+5   −24)
nni/algorithms/compression/pytorch/pruning/__init__.py                    (+4   −5)
nni/algorithms/compression/pytorch/pruning/admm_pruner.py                 (+0   −177)
nni/algorithms/compression/pytorch/pruning/agp.py                         (+0   −151)
nni/algorithms/compression/pytorch/pruning/auto_compress_pruner.py        (+16  −33)
nni/algorithms/compression/pytorch/pruning/constants_pruner.py            (+1   −1)
nni/algorithms/compression/pytorch/pruning/dependency_aware_pruner.py     (+162 −0)
nni/algorithms/compression/pytorch/pruning/finegrained_pruning_masker.py  (+0   −0)
nni/algorithms/compression/pytorch/pruning/iterative_pruner.py            (+576 −0)
nni/algorithms/compression/pytorch/pruning/lottery_ticket.py              (+1   −1)
nni/algorithms/compression/pytorch/pruning/one_shot_pruner.py             (+169 −0)
nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py   (+12  −4)
nni/algorithms/compression/pytorch/quantization/quantizers.py             (+10  −2)
nni/algorithms/compression/tensorflow/pruning/__init__.py                 (+1   −1)
nni/algorithms/compression/tensorflow/pruning/one_shot_pruner.py          (+0   −0)
nni/compression/pytorch/compressor.py                                     (+39  −10)
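The headline change in this commit is that the iterative pruners (AGP, ADMM, Slim, Taylor-FO, APoZ, mean-activation) now drive training themselves through a `trainer`/`criterion` pair instead of relying on the user to call `update_epoch()` from an external training loop. The following is a minimal sketch of how the updated AGPPruner is expected to be called after this change, based on the new signatures shown in the diff below; the toy model, dummy data, and hyper-parameters are illustrative placeholders, not part of the commit.

import torch
from nni.algorithms.compression.pytorch.pruning import AGPPruner

# A tiny stand-in model; any torch.nn.Module would do.
model = torch.nn.Sequential(
    torch.nn.Conv2d(1, 8, 3), torch.nn.ReLU(), torch.nn.Flatten(),
    torch.nn.Linear(8 * 26 * 26, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

# Per the new IterativePruner docstring, the trainer takes
# `model, optimizer, criterion, epoch` and runs one ordinary training epoch.
def trainer(model, optimizer, criterion, epoch):
    model.train()
    for _ in range(10):                       # stand-in for a real DataLoader
        data = torch.randn(4, 1, 28, 28)
        target = torch.randint(0, 10, (4,))
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()

config_list = [{'sparsity': 0.5, 'op_types': ['default']}]
pruner = AGPPruner(model, config_list, optimizer, trainer, criterion,
                   num_iterations=3, epochs_per_iteration=1,
                   pruning_algorithm='level')
model = pruner.compress()   # training and mask updates now happen inside compress()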
examples/model_compress/pruning/naive_prune_torch.py

@@ -10,15 +10,16 @@ import logging
 import argparse
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
 from torchvision import datasets, transforms
 from torch.optim.lr_scheduler import StepLR
-from models.mnist.lenet import LeNet
 from nni.algorithms.compression.pytorch.pruning import LevelPruner
-import nni
+import sys
+sys.path.append('../models')
+from mnist.lenet import LeNet

 _logger = logging.getLogger('mnist_example')
 _logger.setLevel(logging.INFO)

@@ -108,7 +109,7 @@ def main(args):
         'op_types': ['default'],
     }]
-    pruner = LevelPruner(model, prune_config, optimizer_finetune)
+    pruner = LevelPruner(model, prune_config)
     model = pruner.compress()
     # fine-tuning

@@ -149,5 +150,4 @@ if __name__ == '__main__':
                         help='target overall target sparsity')
     args = parser.parse_args()
     main(args)
\ No newline at end of file
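The example change above mirrors the one-shot/iterative split introduced by this commit: `LevelPruner` is now a one-shot pruner, so the fine-tuning optimizer is no longer passed to it. A minimal sketch of the updated call, with a placeholder model and config:

import torch
from nni.algorithms.compression.pytorch.pruning import LevelPruner

model = torch.nn.Linear(10, 2)                            # placeholder model
prune_config = [{'sparsity': 0.8, 'op_types': ['default']}]

pruner = LevelPruner(model, prune_config)                 # optimizer argument dropped
model = pruner.compress()
# fine-tuning afterwards uses the user's own optimizer, e.g.:
optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.01)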
examples/model_compress/quantization/BNN_quantizer_cifar10.py

@@ -31,7 +31,6 @@ class VGG_Cifar10(nn.Module):
             nn.BatchNorm2d(256, eps=1e-4, momentum=0.1),
             nn.Hardtanh(inplace=True),
             nn.Conv2d(256, 512, kernel_size=3, padding=1, bias=False),
             nn.BatchNorm2d(512, eps=1e-4, momentum=0.1),
             nn.Hardtanh(inplace=True),
examples/model_compress/quantization/DoReFaQuantizer_torch_mnist.py

@@ -3,27 +3,9 @@ import torch.nn.functional as F
 from torchvision import datasets, transforms
 from nni.algorithms.compression.pytorch.quantization import DoReFaQuantizer
+import sys
+sys.path.append('../models')
+from mnist.naive import NaiveModel

-class Mnist(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
-        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
-        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
-        self.fc2 = torch.nn.Linear(500, 10)
-        self.relu1 = torch.nn.ReLU6()
-        self.relu2 = torch.nn.ReLU6()
-        self.relu3 = torch.nn.ReLU6()
-
-    def forward(self, x):
-        x = self.relu1(self.conv1(x))
-        x = F.max_pool2d(x, 2, 2)
-        x = self.relu2(self.conv2(x))
-        x = F.max_pool2d(x, 2, 2)
-        x = x.view(-1, 4 * 4 * 50)
-        x = self.relu3(self.fc1(x))
-        x = self.fc2(x)
-        return F.log_softmax(x, dim=1)
-
 def train(model, quantizer, device, train_loader, optimizer):

@@ -66,7 +48,7 @@ def main():
         datasets.MNIST('data', train=False, transform=trans),
         batch_size=1000, shuffle=True)

-    model = Mnist()
+    model = NaiveModel()
     model = model.to(device)
     configure_list = [{
         'quant_types': ['weight'],
examples/model_compress/quantization/QAT_torch_quantizer.py

@@ -3,28 +3,9 @@ import torch.nn.functional as F
 from torchvision import datasets, transforms
 from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
+import sys
+sys.path.append('../models')
+from mnist.naive import NaiveModel

-class Mnist(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
-        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
-        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
-        self.fc2 = torch.nn.Linear(500, 10)
-        self.relu1 = torch.nn.ReLU6()
-        self.relu2 = torch.nn.ReLU6()
-        self.relu3 = torch.nn.ReLU6()
-
-    def forward(self, x):
-        x = self.relu1(self.conv1(x))
-        x = F.max_pool2d(x, 2, 2)
-        x = self.relu2(self.conv2(x))
-        x = F.max_pool2d(x, 2, 2)
-        x = x.view(-1, 4 * 4 * 50)
-        x = self.relu3(self.fc1(x))
-        x = self.fc2(x)
-        return F.log_softmax(x, dim=1)
-
 def train(model, quantizer, device, train_loader, optimizer):
     model.train()

@@ -66,7 +47,7 @@ def main():
         datasets.MNIST('data', train=False, transform=trans),
         batch_size=1000, shuffle=True)

-    model = Mnist()
+    model = NaiveModel()
     '''you can change this to DoReFaQuantizer to implement it
         DoReFaQuantizer(configure_list).compress(model)
     '''
examples/model_compress/quantization/mixed_precision_speedup_mnist.py

@@ -5,28 +5,9 @@ from torchvision import datasets, transforms
 from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
 from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT
+import sys
+sys.path.append('../models')
+from mnist.naive import NaiveModel

-class Mnist(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
-        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
-        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
-        self.fc2 = torch.nn.Linear(500, 10)
-        self.relu1 = torch.nn.ReLU6()
-        self.relu2 = torch.nn.ReLU6()
-        self.relu3 = torch.nn.ReLU6()
-        self.max_pool1 = torch.nn.MaxPool2d(2, 2)
-        self.max_pool2 = torch.nn.MaxPool2d(2, 2)
-
-    def forward(self, x):
-        x = self.relu1(self.conv1(x))
-        x = self.max_pool1(x)
-        x = self.relu2(self.conv2(x))
-        x = self.max_pool2(x)
-        x = x.view(-1, 4 * 4 * 50)
-        x = self.relu3(self.fc1(x))
-        x = self.fc2(x)
-        return F.log_softmax(x, dim=1)
-
 def train(model, device, train_loader, optimizer):

@@ -74,7 +55,7 @@ def test_trt(engine, test_loader):
     print("Inference elapsed_time (whole dataset): {}s".format(time_elasped))

 def post_training_quantization_example(train_loader, test_loader, device):
-    model = Mnist()
+    model = NaiveModel()

     config = {
         'conv1':{'weight_bit':8, 'activation_bit':8},

@@ -99,7 +80,7 @@ def post_training_quantization_example(train_loader, test_loader, device):
     test_trt(engine, test_loader)

 def quantization_aware_training_example(train_loader, test_loader, device):
-    model = Mnist()
+    model = NaiveModel()

     configure_list = [{
         'quant_types': ['weight', 'output'],
nni/algorithms/compression/pytorch/pruning/__init__.py

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-from .finegrained_pruning import *
+from .finegrained_pruning_masker import *
-from .structured_pruning import *
+from .structured_pruning_masker import *
-from .one_shot import *
+from .one_shot_pruner import *
-from .agp import *
+from .iterative_pruner import *
 from .lottery_ticket import LotteryTicketPruner
 from .simulated_annealing_pruner import SimulatedAnnealingPruner
 from .net_adapt_pruner import NetAdaptPruner
-from .admm_pruner import ADMMPruner
 from .auto_compress_pruner import AutoCompressPruner
 from .sensitivity_pruner import SensitivityPruner
 from .amc import AMCPruner
nni/algorithms/compression/pytorch/pruning/admm_pruner.py (deleted, 100644 → 0)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import torch
from schema import And, Optional
import copy

from nni.compression.pytorch.utils.config_validation import CompressorSchema
from .constants import MASKER_DICT
from .one_shot import OneshotPruner

_logger = logging.getLogger(__name__)


class ADMMPruner(OneshotPruner):
    """
    A Pytorch implementation of ADMM Pruner algorithm.

    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned.
    config_list : list
        List on pruning configs.
    trainer : function
        Function used for the first subproblem.
        Users should write this function as a normal function to train the Pytorch model
        and include `model, optimizer, criterion, epoch, callback` as function arguments.
        Here `callback` acts as an L2 regulizer as presented in the formula (7) of the original paper.
        The logic of `callback` is implemented inside the Pruner,
        users are just required to insert `callback()` between `loss.backward()` and `optimizer.step()`.
        Example::

            def trainer(model, criterion, optimizer, epoch, callback):
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                train_loader = ...
                model.train()
                for batch_idx, (data, target) in enumerate(train_loader):
                    data, target = data.to(device), target.to(device)
                    optimizer.zero_grad()
                    output = model(data)
                    loss = criterion(output, target)
                    loss.backward()
                    # callback should be inserted between loss.backward() and optimizer.step()
                    if callback:
                        callback()
                    optimizer.step()
    num_iterations : int
        Total number of iterations.
    training_epochs : int
        Training epochs of the first subproblem.
    row : float
        Penalty parameters for ADMM training.
    base_algo : str
        Base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`. Given the sparsity distribution among the ops,
        the assigned `base_algo` is used to decide which filters/channels/weights to prune.
    """

    def __init__(self, model, config_list, trainer, num_iterations=30, training_epochs=5, row=1e-4, base_algo='l1'):
        self._base_algo = base_algo

        super().__init__(model, config_list)

        self._trainer = trainer
        self._num_iterations = num_iterations
        self._training_epochs = training_epochs
        self._row = row

        self.set_wrappers_attribute("if_calculated", False)
        self.masker = MASKER_DICT[self._base_algo](self.bound_model, self)

    def validate_config(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned
        config_list : list
            List on pruning configs
        """
        if self._base_algo == 'level':
            schema = CompressorSchema([{
                'sparsity': And(float, lambda n: 0 < n < 1),
                Optional('op_types'): [str],
                Optional('op_names'): [str],
            }], model, _logger)
        elif self._base_algo in ['l1', 'l2', 'fpgm']:
            schema = CompressorSchema([{
                'sparsity': And(float, lambda n: 0 < n < 1),
                'op_types': ['Conv2d'],
                Optional('op_names'): [str]
            }], model, _logger)

        schema.validate(config_list)

    def _projection(self, weight, sparsity, wrapper):
        '''
        Return the Euclidean projection of the weight matrix according to the pruning mode.

        Parameters
        ----------
        weight : tensor
            original matrix
        sparsity : float
            the ratio of parameters which need to be set to zero
        wrapper: PrunerModuleWrapper
            layer wrapper of this layer

        Returns
        -------
        tensor
            the projected matrix
        '''
        wrapper_copy = copy.deepcopy(wrapper)
        wrapper_copy.module.weight.data = weight
        return weight.data.mul(self.masker.calc_mask(sparsity, wrapper_copy)['weight_mask'])

    def compress(self):
        """
        Compress the model with ADMM.

        Returns
        -------
        torch.nn.Module
            model with specified modules compressed.
        """
        _logger.info('Starting ADMM Compression...')

        # initiaze Z, U
        # Z_i^0 = W_i^0
        # U_i^0 = 0
        Z = []
        U = []
        for wrapper in self.get_modules_wrapper():
            z = wrapper.module.weight.data
            Z.append(z)
            U.append(torch.zeros_like(z))

        optimizer = torch.optim.Adam(self.bound_model.parameters(), lr=1e-3, weight_decay=5e-5)

        # Loss = cross_entropy + l2 regulization + \Sum_{i=1}^N \row_i ||W_i - Z_i^k + U_i^k||^2
        criterion = torch.nn.CrossEntropyLoss()

        # callback function to do additonal optimization, refer to the deriatives of Formula (7)
        def callback():
            for i, wrapper in enumerate(self.get_modules_wrapper()):
                wrapper.module.weight.data -= self._row * \
                    (wrapper.module.weight.data - Z[i] + U[i])

        # optimization iteration
        for k in range(self._num_iterations):
            _logger.info('ADMM iteration : %d', k)

            # step 1: optimize W with AdamOptimizer
            for epoch in range(self._training_epochs):
                self._trainer(self.bound_model, optimizer=optimizer, criterion=criterion, epoch=epoch, callback=callback)

            # step 2: update Z, U
            # Z_i^{k+1} = projection(W_i^{k+1} + U_i^k)
            # U_i^{k+1} = U^k + W_i^{k+1} - Z_i^{k+1}
            for i, wrapper in enumerate(self.get_modules_wrapper()):
                z = wrapper.module.weight.data + U[i]
                Z[i] = self._projection(z, wrapper.config['sparsity'], wrapper)
                U[i] = U[i] + wrapper.module.weight.data - Z[i]

        # apply prune
        self.update_mask()

        _logger.info('Compression finished.')

        return self.bound_model
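For reference, the two steps in `compress()` above follow the usual ADMM splitting; in the notation of the code comments, with rho the `row` penalty and the projection taken onto each layer's sparsity constraint, the iteration is roughly:

% Step 1 (trainer + callback): approximately minimize the augmented loss over W
\min_W \; \mathcal{L}(W) + \tfrac{\rho}{2} \sum_i \lVert W_i - Z_i^{k} + U_i^{k} \rVert_2^2
% Step 2: project onto the sparsity constraint, then update the scaled dual variable
Z_i^{k+1} = \Pi_{\mathcal{S}_i}\!\left(W_i^{k+1} + U_i^{k}\right), \qquad
U_i^{k+1} = U_i^{k} + W_i^{k+1} - Z_i^{k+1}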
nni/algorithms/compression/pytorch/pruning/agp.py (deleted, 100644 → 0)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
An automated gradual pruning algorithm that prunes the smallest magnitude
weights to achieve a preset level of network sparsity.

Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the
efficacy of pruning for model compression", 2017 NIPS Workshop on Machine
Learning of Phones and other Consumer Devices.
"""

import logging
import torch
from schema import And, Optional

from .constants import MASKER_DICT
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from nni.compression.pytorch.compressor import Pruner

__all__ = ['AGPPruner']

logger = logging.getLogger('torch pruner')


class AGPPruner(Pruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned.
    config_list : listlist
        Supported keys:
            - initial_sparsity: This is to specify the sparsity when compressor starts to compress.
            - final_sparsity: This is to specify the sparsity when compressor finishes to compress.
            - start_epoch: This is to specify the epoch number when compressor starts to compress, default start from epoch 0.
            - end_epoch: This is to specify the epoch number when compressor finishes to compress.
            - frequency: This is to specify every *frequency* number epochs compressor compress once, default frequency=1.
    optimizer: torch.optim.Optimizer
        Optimizer used to train model.
    pruning_algorithm: str
        Algorithms being used to prune model,
        choose from `['level', 'slim', 'l1', 'l2', 'fpgm', 'taylorfo', 'apoz', 'mean_activation']`, by default `level`
    """

    def __init__(self, model, config_list, optimizer, pruning_algorithm='level'):
        super().__init__(model, config_list, optimizer)
        assert isinstance(optimizer, torch.optim.Optimizer), "AGP pruner is an iterative pruner, please pass optimizer of the model to it"
        self.masker = MASKER_DICT[pruning_algorithm](model, self)
        self.now_epoch = 0
        self.set_wrappers_attribute("if_calculated", False)

    def validate_config(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned
        config_list : list
            List on pruning configs
        """
        schema = CompressorSchema([{
            'initial_sparsity': And(float, lambda n: 0 <= n <= 1),
            'final_sparsity': And(float, lambda n: 0 <= n <= 1),
            'start_epoch': And(int, lambda n: n >= 0),
            'end_epoch': And(int, lambda n: n >= 0),
            'frequency': And(int, lambda n: n > 0),
            Optional('op_types'): [str],
            Optional('op_names'): [str]
        }], model, logger)

        schema.validate(config_list)

    def calc_mask(self, wrapper, wrapper_idx=None):
        """
        Calculate the mask of given layer.
        Scale factors with the smallest absolute value in the BN layer are masked.

        Parameters
        ----------
        wrapper : Module
            the layer to instrument the compression operation
        wrapper_idx: int
            index of this wrapper in pruner's all wrappers

        Returns
        -------
        dict | None
            Dictionary for storing masks, keys of the dict:
            'weight_mask': weight mask tensor
            'bias_mask': bias mask tensor (optional)
        """
        config = wrapper.config

        start_epoch = config.get('start_epoch', 0)
        freq = config.get('frequency', 1)

        if wrapper.if_calculated:
            return None
        if not (self.now_epoch >= start_epoch and (self.now_epoch - start_epoch) % freq == 0):
            return None

        target_sparsity = self.compute_target_sparsity(config)
        new_mask = self.masker.calc_mask(sparsity=target_sparsity, wrapper=wrapper, wrapper_idx=wrapper_idx)
        if new_mask is not None:
            wrapper.if_calculated = True

        return new_mask

    def compute_target_sparsity(self, config):
        """
        Calculate the sparsity for pruning

        Parameters
        ----------
        config : dict
            Layer's pruning config

        Returns
        -------
        float
            Target sparsity to be pruned
        """
        end_epoch = config.get('end_epoch', 1)
        start_epoch = config.get('start_epoch', 0)
        freq = config.get('frequency', 1)
        final_sparsity = config.get('final_sparsity', 0)
        initial_sparsity = config.get('initial_sparsity', 0)

        if end_epoch <= start_epoch or initial_sparsity >= final_sparsity:
            logger.warning('your end epoch <= start epoch or initial_sparsity >= final_sparsity')
            return final_sparsity

        if end_epoch <= self.now_epoch:
            return final_sparsity

        span = ((end_epoch - start_epoch - 1) // freq) * freq
        assert span > 0
        target_sparsity = (final_sparsity +
                           (initial_sparsity - final_sparsity) *
                           (1.0 - ((self.now_epoch - start_epoch) / span)) ** 3)
        return target_sparsity

    def update_epoch(self, epoch):
        """
        Update epoch

        Parameters
        ----------
        epoch : int
            current training epoch
        """
        if epoch > 0:
            self.now_epoch = epoch
            for wrapper in self.get_modules_wrapper():
                wrapper.if_calculated = False
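The cubic schedule implemented by `compute_target_sparsity()` above (the Zhu & Gupta rule cited in the module docstring) can be written as:

s_t = s_f + (s_i - s_f)\left(1 - \frac{t - t_0}{\mathrm{span}}\right)^{3},
\qquad \mathrm{span} = \left\lfloor \frac{t_{\mathrm{end}} - t_0 - 1}{f} \right\rfloor \cdot f

where s_i and s_f are `initial_sparsity` and `final_sparsity`, t is the current epoch, t_0 is `start_epoch`, t_end is `end_epoch`, and f is `frequency`.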
nni/algorithms/compression/pytorch/pruning/auto_compress_pruner.py

@@ -13,8 +13,7 @@ from nni.compression.pytorch import ModelSpeedup
 from nni.compression.pytorch.compressor import Pruner
 from nni.compression.pytorch.utils.config_validation import CompressorSchema
 from .simulated_annealing_pruner import SimulatedAnnealingPruner
-from .admm_pruner import ADMMPruner
+from .iterative_pruner import ADMMPruner

 _logger = logging.getLogger(__name__)

@@ -34,26 +33,7 @@ class AutoCompressPruner(Pruner):
     trainer : function
         Function used for the first subproblem of ADMM Pruner.
         Users should write this function as a normal function to train the Pytorch model
-        and include `model, optimizer, criterion, epoch, callback` as function arguments.
-        Here `callback` acts as an L2 regulizer as presented in the formula (7) of the original paper.
-        The logic of `callback` is implemented inside the Pruner,
-        users are just required to insert `callback()` between `loss.backward()` and `optimizer.step()`.
-        Example::
-
-            def trainer(model, criterion, optimizer, epoch, callback):
-                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-                train_loader = ...
-                model.train()
-                for batch_idx, (data, target) in enumerate(train_loader):
-                    data, target = data.to(device), target.to(device)
-                    optimizer.zero_grad()
-                    output = model(data)
-                    loss = criterion(output, target)
-                    loss.backward()
-                    # callback should be inserted between loss.backward() and optimizer.step()
-                    if callback:
-                        callback()
-                    optimizer.step()
+        and include `model, optimizer, criterion, epoch` as function arguments.
     evaluator : function
         function to evaluate the pruned model.
         This function should include `model` as the only parameter, and returns a scalar value.

@@ -80,8 +60,8 @@ class AutoCompressPruner(Pruner):
     optimize_mode : str
         optimize mode, `maximize` or `minimize`, by default `maximize`.
     base_algo : str
-        Base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`. Given the sparsity distribution among the ops,
-        the assigned `base_algo` is used to decide which filters/channels/weights to prune.
+        Base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`. Given the sparsity distribution among the
+        ops, the assigned `base_algo` is used to decide which filters/channels/weights to prune.
     start_temperature : float
         Start temperature of the simulated annealing process.
     stop_temperature : float

@@ -92,7 +72,7 @@ class AutoCompressPruner(Pruner):
         Initial perturbation magnitude to the sparsities. The magnitude decreases with current temperature.
     admm_num_iterations : int
         Number of iterations of ADMM Pruner.
-    admm_training_epochs : int
+    admm_epochs_per_iteration : int
         Training epochs of the first optimization subproblem of ADMMPruner.
     row : float
         Penalty parameters for ADMM training.

@@ -100,18 +80,19 @@ class AutoCompressPruner(Pruner):
         PATH to store temporary experiment data.
     """

-    def __init__(self, model, config_list, trainer, evaluator, dummy_input,
+    def __init__(self, model, config_list, trainer, criterion, evaluator, dummy_input,
                  num_iterations=3, optimize_mode='maximize', base_algo='l1',
                  # SimulatedAnnealing related
                  start_temperature=100, stop_temperature=20, cool_down_rate=0.9, perturbation_magnitude=0.35,
                  # ADMM related
-                 admm_num_iterations=30, admm_training_epochs=5, row=1e-4,
+                 admm_num_iterations=30, admm_epochs_per_iteration=5, row=1e-4,
                  experiment_data_dir='./'):
         # original model
         self._model_to_prune = model
         self._base_algo = base_algo

         self._trainer = trainer
+        self._criterion = criterion
         self._evaluator = evaluator
         self._dummy_input = dummy_input
         self._num_iterations = num_iterations

@@ -125,7 +106,7 @@ class AutoCompressPruner(Pruner):
         # hyper parameters for ADMM algorithm
         self._admm_num_iterations = admm_num_iterations
-        self._admm_training_epochs = admm_training_epochs
+        self._admm_epochs_per_iteration = admm_epochs_per_iteration
         self._row = row

         # overall pruning rate

@@ -174,12 +155,12 @@ class AutoCompressPruner(Pruner):
         """
         _logger.info('Starting AutoCompress pruning...')

         sparsity_each_round = 1 - pow(1 - self._sparsity, 1 / self._num_iterations)

         for i in range(self._num_iterations):
             _logger.info('Pruning iteration: %d', i)
             _logger.info('Target sparsity this round: %s',
                          1 - pow(1 - sparsity_each_round, i + 1))

             # SimulatedAnnealingPruner
             _logger.info(

@@ -204,9 +185,10 @@ class AutoCompressPruner(Pruner):
             ADMMpruner = ADMMPruner(
                 model=copy.deepcopy(self._model_to_prune),
                 config_list=config_list,
+                criterion=self._criterion,
                 trainer=self._trainer,
                 num_iterations=self._admm_num_iterations,
-                training_epochs=self._admm_training_epochs,
+                epochs_per_iteration=self._admm_epochs_per_iteration,
                 row=self._row,
                 base_algo=self._base_algo)
             ADMMpruner.compress()

@@ -214,12 +196,13 @@ class AutoCompressPruner(Pruner):
             ADMMpruner.export_model(os.path.join(self._experiment_data_dir, 'model_admm_masked.pth'),
                                     os.path.join(self._experiment_data_dir, 'mask.pth'))

-            # use speed up to prune the model before next iteration, because SimulatedAnnealingPruner & ADMMPruner don't take masked models
+            # use speed up to prune the model before next iteration,
+            # because SimulatedAnnealingPruner & ADMMPruner don't take masked models
             self._model_to_prune.load_state_dict(torch.load(os.path.join(
                 self._experiment_data_dir, 'model_admm_masked.pth')))

             masks_file = os.path.join(self._experiment_data_dir, 'mask.pth')
-            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            device = next(self._model_to_prune.parameters()).device

             _logger.info('Speeding up models...')
             m_speedup = ModelSpeedup(self._model_to_prune, self._dummy_input, masks_file, device)
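After this change AutoCompressPruner forwards a `criterion` to its ADMM sub-pruner and the trainer no longer takes a `callback` argument. A hedged, illustrative sketch of the updated constructor call follows; the keyword names come from the diff above, while the model, trainer body, evaluator, and dummy input are placeholders supplied by the user:

import torch
from nni.algorithms.compression.pytorch.pruning import AutoCompressPruner

model = torch.nn.Sequential(                      # placeholder network
    torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU(),
    torch.nn.AdaptiveAvgPool2d(1), torch.nn.Flatten(), torch.nn.Linear(8, 10))

def trainer(model, optimizer, criterion, epoch):
    ...                                           # one ordinary training epoch, no callback

def evaluator(model):
    return 0.0                                    # placeholder for validation accuracy

pruner = AutoCompressPruner(
    model, [{'sparsity': 0.5, 'op_types': ['Conv2d']}],
    trainer=trainer, criterion=torch.nn.CrossEntropyLoss(),
    evaluator=evaluator, dummy_input=torch.randn(1, 3, 32, 32),
    num_iterations=3, admm_num_iterations=30, admm_epochs_per_iteration=5)
# pruner.compress() would then run the simulated-annealing + ADMM + speedup loop.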
nni/algorithms/compression/pytorch/pruning/constants_pruner.py

@@ -2,7 +2,7 @@
 # Licensed under the MIT license.

-from .one_shot import LevelPruner, L1FilterPruner, L2FilterPruner, FPGMPruner
+from .one_shot_pruner import LevelPruner, L1FilterPruner, L2FilterPruner, FPGMPruner

 PRUNER_DICT = {
     'level': LevelPruner,
nni/algorithms/compression/pytorch/pruning/dependency_aware_pruner.py (new file, 0 → 100644)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from schema import And, Optional, SchemaError
from nni.common.graph_utils import TorchModuleGraph
from nni.compression.pytorch.utils.shape_dependency import ChannelDependency, GroupDependency
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from nni.compression.pytorch.compressor import Pruner
from .constants import MASKER_DICT

__all__ = ['DependencyAwarePruner']

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class DependencyAwarePruner(Pruner):
    """
    DependencyAwarePruner has two ways to calculate the masks
    for conv layers. In the normal way, the DependencyAwarePruner
    will calculate the mask of each layer separately. For example, each
    conv layer determine which filters should be pruned according to its L1
    norm. In constrast, in the dependency-aware way, the layers that in a
    dependency group will be pruned jointly and these layers will be forced
    to prune the same channels.
    """

    def __init__(self, model, config_list, optimizer=None, pruning_algorithm='level', dependency_aware=False,
                 dummy_input=None, **algo_kwargs):
        super().__init__(model, config_list=config_list, optimizer=optimizer)

        self.dependency_aware = dependency_aware
        self.dummy_input = dummy_input

        if self.dependency_aware:
            if not self._supported_dependency_aware():
                raise ValueError('This pruner does not support dependency aware!')

            errmsg = "When dependency_aware is set, the dummy_input should not be None"
            assert self.dummy_input is not None, errmsg
            # Get the TorchModuleGraph of the target model
            # to trace the model, we need to unwrap the wrappers
            self._unwrap_model()
            self.graph = TorchModuleGraph(model, dummy_input)
            self._wrap_model()
            self.channel_depen = ChannelDependency(traced_model=self.graph.trace)
            self.group_depen = GroupDependency(traced_model=self.graph.trace)
            self.channel_depen = self.channel_depen.dependency_sets
            self.channel_depen = {name: sets for sets in self.channel_depen for name in sets}
            self.group_depen = self.group_depen.dependency_sets

        self.masker = MASKER_DICT[pruning_algorithm](model, self, **algo_kwargs)
        # set the dependency-aware switch for the masker
        self.masker.dependency_aware = dependency_aware
        self.set_wrappers_attribute("if_calculated", False)

    def calc_mask(self, wrapper, wrapper_idx=None):
        if not wrapper.if_calculated:
            sparsity = wrapper.config['sparsity']
            masks = self.masker.calc_mask(sparsity=sparsity, wrapper=wrapper, wrapper_idx=wrapper_idx)

            # masker.calc_mask returns None means calc_mask is not calculated sucessfully, can try later
            if masks is not None:
                wrapper.if_calculated = True

            return masks
        else:
            return None

    def update_mask(self):
        if not self.dependency_aware:
            # if we use the normal way to update the mask,
            # then call the update_mask of the father class
            super(DependencyAwarePruner, self).update_mask()
        else:
            # if we update the mask in a dependency-aware way
            # then we call _dependency_update_mask
            self._dependency_update_mask()

    def validate_config(self, model, config_list):
        schema = CompressorSchema([{
            Optional('sparsity'): And(float, lambda n: 0 < n < 1),
            Optional('op_types'): ['Conv2d'],
            Optional('op_names'): [str],
            Optional('exclude'): bool
        }], model, logger)

        schema.validate(config_list)
        for config in config_list:
            if 'exclude' not in config and 'sparsity' not in config:
                raise SchemaError('Either sparisty or exclude must be specified!')

    def _supported_dependency_aware(self):
        raise NotImplementedError

    def _dependency_calc_mask(self, wrappers, channel_dsets, wrappers_idx=None):
        """
        calculate the masks for the conv layers in the same
        channel dependecy set. All the layers passed in have
        the same number of channels.

        Parameters
        ----------
        wrappers: list
            The list of the wrappers that in the same channel dependency
            set.
        wrappers_idx: list
            The list of the indexes of wrapppers.
        Returns
        -------
        masks: dict
            A dict object that contains the masks of the layers in this
            dependency group, the key is the name of the convolutional layers.
        """
        # The number of the groups for each conv layers
        # Note that, this number may be different from its
        # original number of groups of filters.
        groups = [self.group_depen[_w.name] for _w in wrappers]
        sparsities = [_w.config['sparsity'] for _w in wrappers]
        masks = self.masker.calc_mask(
            sparsities, wrappers, wrappers_idx, channel_dsets=channel_dsets, groups=groups)
        if masks is not None:
            # if masks is None, then the mask calculation fails.
            # for example, in activation related maskers, we should
            # pass enough batches of data to the model, so that the
            # masks can be calculated successfully.
            for _w in wrappers:
                _w.if_calculated = True
        return masks

    def _dependency_update_mask(self):
        """
        In the original update_mask, the wraper of each layer will update its
        own mask according to the sparsity specified in the config_list. However, in
        the _dependency_update_mask, we may prune several layers at the same
        time according the sparsities and the channel/group dependencies.
        """
        name2wrapper = {x.name: x for x in self.get_modules_wrapper()}
        wrapper2index = {x: i for i, x in enumerate(self.get_modules_wrapper())}
        for wrapper in self.get_modules_wrapper():
            if wrapper.if_calculated:
                continue
            # find all the conv layers that have channel dependecy with this layer
            # and prune all these layers at the same time.
            _names = [x for x in self.channel_depen[wrapper.name]]
            logger.info('Pruning the dependent layers: %s', ','.join(_names))
            _wrappers = [name2wrapper[name] for name in _names if name in name2wrapper]
            _wrapper_idxes = [wrapper2index[_w] for _w in _wrappers]

            masks = self._dependency_calc_mask(_wrappers, _names, wrappers_idx=_wrapper_idxes)
            if masks is not None:
                for layer in masks:
                    for mask_type in masks[layer]:
                        assert hasattr(name2wrapper[layer], mask_type), "there is no attribute '%s' in wrapper on %s" \
                            % (mask_type, layer)
                        setattr(name2wrapper[layer], mask_type, masks[layer][mask_type])
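DependencyAwarePruner is the new shared base for the structured pruners; its dependency-aware mode needs a `dummy_input` so the model can be traced and its channel/group dependencies extracted. A hedged sketch of how a subclass such as L1FilterPruner is expected to be used in this mode (the model and input shape are placeholders, not part of the commit):

import torch
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, 3), torch.nn.ReLU(), torch.nn.Conv2d(16, 16, 3))
config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]

# dependency_aware=True forces layers in the same channel-dependency set to
# prune the same channels; dummy_input must be on the same device as the model.
pruner = L1FilterPruner(model, config_list, dependency_aware=True,
                        dummy_input=torch.rand(1, 3, 32, 32))
model = pruner.compress()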
nni/algorithms/compression/pytorch/pruning/finegrained_pruning.py → nni/algorithms/compression/pytorch/pruning/finegrained_pruning_masker.py
File moved
nni/algorithms/compression/pytorch/pruning/
one_shot
.py
→
nni/algorithms/compression/pytorch/pruning/
iterative_pruner
.py
View file @
92f6754e
...
@@ -2,46 +2,119 @@
...
@@ -2,46 +2,119 @@
# Licensed under the MIT license.
# Licensed under the MIT license.
import
logging
import
logging
from
schema
import
And
,
Optional
,
SchemaError
import
copy
from
nni.common.graph_utils
import
TorchModuleGraph
import
torch
from
nni.compression.pytorch.utils.shape_dependency
import
ChannelDependency
,
GroupDependency
from
schema
import
And
,
Optional
from
.constants
import
MASKER_DICT
from
nni.compression.pytorch.utils.config_validation
import
CompressorSchema
from
nni.compression.pytorch.utils.config_validation
import
CompressorSchema
from
nni.compression.pytorch.compressor
import
Pruner
from
.constants
import
MASKER_DICT
from
.dependency_aware_pruner
import
DependencyAwarePruner
__all__
=
[
'
Level
Pruner'
,
'
Slim
Pruner'
,
'
L1Filter
Pruner'
,
'
L2
FilterPruner'
,
'
FPGM
Pruner'
,
__all__
=
[
'
AGP
Pruner'
,
'
ADMM
Pruner'
,
'
Slim
Pruner'
,
'
TaylorFOWeight
FilterPruner'
,
'
ActivationAPoZRankFilter
Pruner'
,
'TaylorFOWeightFilterPruner'
,
'ActivationAPoZRankFilterPruner'
,
'ActivationMeanRankFilterPruner'
]
'ActivationMeanRankFilterPruner'
]
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
logger
.
setLevel
(
logging
.
INFO
)
logger
.
setLevel
(
logging
.
INFO
)
class
OneshotPruner
(
Pruner
):
class
IterativePruner
(
DependencyAware
Pruner
):
"""
"""
Prune model
to an exact pruning level for one time
.
Prune model
during the training process
.
"""
"""
def
__init__
(
self
,
model
,
config_list
,
pruning_algorithm
=
'level'
,
optimizer
=
None
,
**
algo_kwargs
):
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
,
pruning_algorithm
=
'slim'
,
trainer
=
None
,
criterion
=
None
,
num_iterations
=
20
,
epochs_per_iteration
=
5
,
dependency_aware
=
False
,
dummy_input
=
None
,
**
algo_kwargs
):
"""
"""
Parameters
Parameters
----------
----------
model
: torch.nn.Module
model: torch.nn.Module
Model to be pruned
Model to be pruned
config_list
: list
config_list: list
List on pruning configs
List on pruning configs
pruning_algorithm: str
algorithms being used to prune model
optimizer: torch.optim.Optimizer
optimizer: torch.optim.Optimizer
Optimizer used to train model
Optimizer used to train model
pruning_algorithm: str
algorithms being used to prune model
trainer: function
Function used to train the model.
Users should write this function as a normal function to train the Pytorch model
and include `model, optimizer, criterion, epoch` as function arguments.
criterion: function
Function used to calculate the loss between the target and the output.
num_iterations: int
Total number of iterations in pruning process. We will calculate mask at the end of an iteration.
epochs_per_iteration: Union[int, list]
The number of training epochs for each iteration. `int` represents the same value for each iteration.
`list` represents the specific value for each iteration.
dependency_aware: bool
If prune the model in a dependency-aware way.
dummy_input: torch.Tensor
The dummy input to analyze the topology constraints. Note that,
the dummy_input should on the same device with the model.
algo_kwargs: dict
algo_kwargs: dict
Additional parameters passed to pruning algorithm masker class
Additional parameters passed to pruning algorithm masker class
"""
"""
super
().
__init__
(
model
,
config_list
,
optimizer
,
pruning_algorithm
,
dependency_aware
,
dummy_input
,
**
algo_kwargs
)
if
isinstance
(
epochs_per_iteration
,
list
):
assert
len
(
epochs_per_iteration
)
==
num_iterations
,
'num_iterations should equal to the length of epochs_per_iteration'
self
.
epochs_per_iteration
=
epochs_per_iteration
else
:
self
.
epochs_per_iteration
=
[
epochs_per_iteration
]
*
num_iterations
self
.
_trainer
=
trainer
self
.
_criterion
=
criterion
super
().
__init__
(
model
,
config_list
,
optimizer
)
def
_fresh_calculated
(
self
):
for
wrapper
in
self
.
get_modules_wrapper
():
wrapper
.
if_calculated
=
False
def
compress
(
self
):
training
=
self
.
bound_model
.
training
self
.
bound_model
.
train
()
for
_
,
epochs_num
in
enumerate
(
self
.
epochs_per_iteration
):
self
.
_fresh_calculated
()
for
epoch
in
range
(
epochs_num
):
self
.
_trainer
(
self
.
bound_model
,
optimizer
=
self
.
optimizer
,
criterion
=
self
.
_criterion
,
epoch
=
epoch
)
self
.
update_mask
()
self
.
bound_model
.
train
(
training
)
return
self
.
bound_model
class
AGPPruner
(
IterativePruner
):
"""
Parameters
----------
model : torch.nn.Module
Model to be pruned.
config_list : listlist
Supported keys:
- sparsity : This is to specify the sparsity operations to be compressed to.
- op_types : See supported type in your specific pruning algorithm.
optimizer: torch.optim.Optimizer
Optimizer used to train model.
trainer: function
Function to train the model
criterion: function
Function used to calculate the loss between the target and the output.
num_iterations: int
Total number of iterations in pruning process. We will calculate mask at the end of an iteration.
epochs_per_iteration: int
The number of training epochs for each iteration.
pruning_algorithm: str
Algorithms being used to prune model,
choose from `['level', 'slim', 'l1', 'l2', 'fpgm', 'taylorfo', 'apoz', 'mean_activation']`, by default `level`
"""
def
__init__
(
self
,
model
,
config_list
,
optimizer
,
trainer
,
criterion
,
num_iterations
=
10
,
epochs_per_iteration
=
1
,
pruning_algorithm
=
'level'
):
super
().
__init__
(
model
,
config_list
,
optimizer
=
optimizer
,
trainer
=
trainer
,
criterion
=
criterion
,
num_iterations
=
num_iterations
,
epochs_per_iteration
=
epochs_per_iteration
)
assert
isinstance
(
optimizer
,
torch
.
optim
.
Optimizer
),
"AGP pruner is an iterative pruner, please pass optimizer of the model to it"
self
.
masker
=
MASKER_DICT
[
pruning_algorithm
](
model
,
self
)
self
.
now_epoch
=
0
self
.
freq
=
epochs_per_iteration
self
.
end_epoch
=
epochs_per_iteration
*
num_iterations
self
.
set_wrappers_attribute
(
"if_calculated"
,
False
)
self
.
set_wrappers_attribute
(
"if_calculated"
,
False
)
self
.
masker
=
MASKER_DICT
[
pruning_algorithm
](
model
,
self
,
**
algo_kwargs
)
def
validate_config
(
self
,
model
,
config_list
):
def
validate_config
(
self
,
model
,
config_list
):
"""
"""
...
@@ -53,276 +126,259 @@ class OneshotPruner(Pruner):
...
@@ -53,276 +126,259 @@ class OneshotPruner(Pruner):
List on pruning configs
List on pruning configs
"""
"""
schema
=
CompressorSchema
([{
schema
=
CompressorSchema
([{
'sparsity'
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
'sparsity'
:
And
(
float
,
lambda
n
:
0
<
=
n
<
=
1
),
Optional
(
'op_types'
):
[
str
],
Optional
(
'op_types'
):
[
str
],
Optional
(
'op_names'
):
[
str
]
Optional
(
'op_names'
):
[
str
]
}],
model
,
logger
)
}],
model
,
logger
)
schema
.
validate
(
config_list
)
schema
.
validate
(
config_list
)
def
_supported_dependency_aware
(
self
):
return
False
def
calc_mask
(
self
,
wrapper
,
wrapper_idx
=
None
):
def
calc_mask
(
self
,
wrapper
,
wrapper_idx
=
None
):
"""
"""
Calculate the mask of given layer
Calculate the mask of given layer.
Scale factors with the smallest absolute value in the BN layer are masked.
Parameters
Parameters
----------
----------
wrapper : Module
wrapper : Module
the
module
to instrument the compression operation
the
layer
to instrument the compression operation
wrapper_idx: int
wrapper_idx: int
index of this wrapper in pruner's all wrappers
index of this wrapper in pruner's all wrappers
Returns
Returns
-------
-------
dict
dict
| None
d
ictionary for storing masks, keys of the dict:
D
ictionary for storing masks, keys of the dict:
'weight_mask': weight mask tensor
'weight_mask': weight mask tensor
'bias_mask': bias mask tensor (optional)
'bias_mask': bias mask tensor (optional)
"""
"""
config
=
wrapper
.
config
if
wrapper
.
if_calculated
:
if
wrapper
.
if_calculated
:
return
None
return
None
sparsity
=
wrapper
.
config
[
'sparsity'
]
if
not
self
.
now_epoch
%
self
.
freq
==
0
:
if
not
wrapper
.
if_calculated
:
masks
=
self
.
masker
.
calc_mask
(
sparsity
=
sparsity
,
wrapper
=
wrapper
,
wrapper_idx
=
wrapper_idx
)
# masker.calc_mask returns None means calc_mask is not calculated sucessfully, can try later
if
masks
is
not
None
:
wrapper
.
if_calculated
=
True
return
masks
else
:
return
None
return
None
target_sparsity
=
self
.
compute_target_sparsity
(
config
)
new_mask
=
self
.
masker
.
calc_mask
(
sparsity
=
target_sparsity
,
wrapper
=
wrapper
,
wrapper_idx
=
wrapper_idx
)
class
LevelPruner
(
OneshotPruner
):
if
new_mask
is
not
None
:
"""
wrapper
.
if_calculated
=
True
Parameters
----------
model : torch.nn.Module
Model to be pruned
config_list : list
Supported keys:
- sparsity : This is to specify the sparsity operations to be compressed to.
- op_types : Operation types to prune.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
):
return
new_mask
super
().
__init__
(
model
,
config_list
,
pruning_algorithm
=
'level'
,
optimizer
=
optimizer
)
def
compute_target_sparsity
(
self
,
config
):
"""
Calculate the sparsity for pruning
Parameters
----------
config : dict
Layer's pruning config
Returns
-------
float
Target sparsity to be pruned
"""
initial_sparsity
=
0
self
.
target_sparsity
=
final_sparsity
=
config
.
get
(
'sparsity'
,
0
)
if
initial_sparsity
>=
final_sparsity
:
logger
.
warning
(
'your initial_sparsity >= final_sparsity'
)
return
final_sparsity
if
self
.
end_epoch
==
1
or
self
.
end_epoch
<=
self
.
now_epoch
:
return
final_sparsity
span
=
((
self
.
end_epoch
-
1
)
//
self
.
freq
)
*
self
.
freq
assert
span
>
0
self
.
target_sparsity
=
(
final_sparsity
+
(
initial_sparsity
-
final_sparsity
)
*
(
1.0
-
(
self
.
now_epoch
/
span
))
**
3
)
return
self
.
target_sparsity
def
update_epoch
(
self
,
epoch
):
"""
Update epoch
Parameters
----------
epoch : int
current training epoch
"""
if
epoch
>
0
:
self
.
now_epoch
=
epoch
for
wrapper
in
self
.
get_modules_wrapper
():
wrapper
.
if_calculated
=
False
# TODO: need refactor
def
compress
(
self
):
training
=
self
.
bound_model
.
training
self
.
bound_model
.
train
()
for
epoch
in
range
(
self
.
end_epoch
):
self
.
update_epoch
(
epoch
)
self
.
_trainer
(
self
.
bound_model
,
optimizer
=
self
.
optimizer
,
criterion
=
self
.
_criterion
,
epoch
=
epoch
)
self
.
update_mask
()
logger
.
info
(
f
'sparsity is
{
self
.
target_sparsity
:.
2
f
}
at epoch
{
epoch
}
'
)
self
.
get_pruned_weights
()
self
.
bound_model
.
train
(
training
)
class
SlimPruner
(
OneshotPruner
):
return
self
.
bound_model
class
ADMMPruner
(
IterativePruner
):
"""
"""
A Pytorch implementation of ADMM Pruner algorithm.
Parameters
Parameters
----------
----------
model : torch.nn.Module
model : torch.nn.Module
Model to be pruned
Model to be pruned
.
config_list : list
config_list : list
Supported keys:
List on pruning configs.
- sparsity : This is to specify the sparsity operations to be compressed to.
trainer : function
- op_types : Only BatchNorm2d is supported in Slim Pruner.
Function used for the first subproblem.
optimizer: torch.optim.Optimizer
Users should write this function as a normal function to train the Pytorch model
Optimizer used to train model
and include `model, optimizer, criterion, epoch` as function arguments.
"""
criterion: function
Function used to calculate the loss between the target and the output. By default, we use CrossEntropyLoss in ADMMPruner.
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
):
num_iterations: int
super
().
__init__
(
model
,
config_list
,
pruning_algorithm
=
'slim'
,
optimizer
=
optimizer
)
Total number of iterations in pruning process. We will calculate mask after we finish all iterations in ADMMPruner.
epochs_per_iteration: int
Training epochs of the first subproblem.
row : float
Penalty parameters for ADMM training.
base_algo : str
Base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`. Given the sparsity distribution among
the ops, the assigned `base_algo` is used to decide which filters/channels/weights to prune.
def
validate_config
(
self
,
model
,
config_list
):
"""
schema
=
CompressorSchema
([{
'sparsity'
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
'op_types'
:
[
'BatchNorm2d'
],
Optional
(
'op_names'
):
[
str
]
}],
model
,
logger
)
schema
.
validate
(
config_list
)
def
__init__
(
self
,
model
,
config_list
,
trainer
,
criterion
=
torch
.
nn
.
CrossEntropyLoss
(),
num_iterations
=
30
,
epochs_per_iteration
=
5
,
row
=
1e-4
,
base_algo
=
'l1'
):
self
.
_base_algo
=
base_algo
if
len
(
config_list
)
>
1
:
super
().
__init__
(
model
,
config_list
)
logger
.
warning
(
'Slim pruner only supports 1 configuration'
)
self
.
_trainer
=
trainer
self
.
optimizer
=
torch
.
optim
.
Adam
(
self
.
bound_model
.
parameters
(),
lr
=
1e-3
,
weight_decay
=
5e-5
)
self
.
_criterion
=
criterion
self
.
_num_iterations
=
num_iterations
self
.
_training_epochs
=
epochs_per_iteration
self
.
_row
=
row
class
_StructuredFilterPruner
(
OneshotPruner
):
self
.
set_wrappers_attribute
(
"if_calculated"
,
False
)
"""
self
.
masker
=
MASKER_DICT
[
self
.
_base_algo
](
self
.
bound_model
,
self
)
_StructuredFilterPruner has two ways to calculate the masks
for conv layers. In the normal way, the _StructuredFilterPruner
will calculate the mask of each layer separately. For example, each
conv layer determine which filters should be pruned according to its L1
norm. In constrast, in the dependency-aware way, the layers that in a
dependency group will be pruned jointly and these layers will be forced
to prune the same channels.
"""
def
__init__
(
self
,
model
,
config_list
,
pruning_algorithm
,
optimizer
=
None
,
dependency_aware
=
False
,
dummy_input
=
None
,
**
algo_kwargs
):
self
.
patch_optimizer_before
(
self
.
_callback
)
super
().
__init__
(
model
,
config_list
,
pruning_algorithm
=
pruning_algorithm
,
optimizer
=
optimizer
,
**
algo_kwargs
)
self
.
dependency_aware
=
dependency_aware
# set the dependency-aware switch for the masker
self
.
masker
.
dependency_aware
=
dependency_aware
self
.
dummy_input
=
dummy_input
if
self
.
dependency_aware
:
errmsg
=
"When dependency_aware is set, the dummy_input should not be None"
assert
self
.
dummy_input
is
not
None
,
errmsg
# Get the TorchModuleGraph of the target model
# to trace the model, we need to unwrap the wrappers
self
.
_unwrap_model
()
self
.
graph
=
TorchModuleGraph
(
model
,
dummy_input
)
self
.
_wrap_model
()
self
.
channel_depen
=
ChannelDependency
(
traced_model
=
self
.
graph
.
trace
)
self
.
group_depen
=
GroupDependency
(
traced_model
=
self
.
graph
.
trace
)
self
.
channel_depen
=
self
.
channel_depen
.
dependency_sets
self
.
channel_depen
=
{
name
:
sets
for
sets
in
self
.
channel_depen
for
name
in
sets
}
self
.
group_depen
=
self
.
group_depen
.
dependency_sets
def
update_mask
(
self
):
if
not
self
.
dependency_aware
:
# if we use the normal way to update the mask,
# then call the update_mask of the father class
super
(
_StructuredFilterPruner
,
self
).
update_mask
()
else
:
# if we update the mask in a dependency-aware way
# then we call _dependency_update_mask
self
.
_dependency_update_mask
()
def
validate_config
(
self
,
model
,
config_list
):
def
validate_config
(
self
,
model
,
config_list
):
schema
=
CompressorSchema
([{
"""
Optional
(
'sparsity'
):
And
(
float
,
lambda
n
:
0
<
n
<
1
),
Parameters
Optional
(
'op_types'
):
[
'Conv2d'
],
----------
Optional
(
'op_names'
):
[
str
],
model : torch.nn.Module
Optional
(
'exclude'
):
bool
Model to be pruned
}],
model
,
logger
)
config_list : list
List on pruning configs
"""
if
self
.
_base_algo
==
'level'
:
schema
=
CompressorSchema
([{
'sparsity'
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
Optional
(
'op_types'
):
[
str
],
Optional
(
'op_names'
):
[
str
],
}],
model
,
logger
)
elif
self
.
_base_algo
in
[
'l1'
,
'l2'
,
'fpgm'
]:
schema
=
CompressorSchema
([{
'sparsity'
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
'op_types'
:
[
'Conv2d'
],
Optional
(
'op_names'
):
[
str
]
}],
model
,
logger
)
schema
.
validate
(
config_list
)
schema
.
validate
(
config_list
)
for
config
in
config_list
:
if
'exclude'
not
in
config
and
'sparsity'
not
in
config
:
raise
SchemaError
(
'Either sparisty or exclude must be specified!'
)
    # --- ADMM-related methods appearing in this diff ---

    def _supported_dependency_aware(self):
        return False

    def _projection(self, weight, sparsity, wrapper):
        '''
        Return the Euclidean projection of the weight matrix according to the pruning mode.

        Parameters
        ----------
        weight : tensor
            original matrix
        sparsity : float
            the ratio of parameters which need to be set to zero
        wrapper: PrunerModuleWrapper
            layer wrapper of this layer

        Returns
        -------
        tensor
            the projected matrix
        '''
        wrapper_copy = copy.deepcopy(wrapper)
        wrapper_copy.module.weight.data = weight
        return weight.data.mul(self.masker.calc_mask(sparsity, wrapper_copy)['weight_mask'])

    def _callback(self):
        # callback function to do additional optimization, refer to the derivatives of Formula (7)
        for i, wrapper in enumerate(self.get_modules_wrapper()):
            wrapper.module.weight.data -= self._row * \
                (wrapper.module.weight.data - self.Z[i] + self.U[i])

    def compress(self):
        """
        Compress the model with ADMM.

        Returns
        -------
        torch.nn.Module
            model with specified modules compressed.
        """
        logger.info('Starting ADMM Compression...')

        # initialize Z, U
        # Z_i^0 = W_i^0
        # U_i^0 = 0
        self.Z = []
        self.U = []
        for wrapper in self.get_modules_wrapper():
            z = wrapper.module.weight.data
            self.Z.append(z)
            self.U.append(torch.zeros_like(z))

        # Loss = cross_entropy + l2 regularization + \Sum_{i=1}^N \row_i ||W_i - Z_i^k + U_i^k||^2
        # optimization iteration
        for k in range(self._num_iterations):
            logger.info('ADMM iteration : %d', k)

            # step 1: optimize W with AdamOptimizer
            for epoch in range(self._training_epochs):
                self._trainer(self.bound_model, optimizer=self.optimizer, criterion=self._criterion, epoch=epoch)

            # step 2: update Z, U
            # Z_i^{k+1} = projection(W_i^{k+1} + U_i^k)
            # U_i^{k+1} = U^k + W_i^{k+1} - Z_i^{k+1}
            for i, wrapper in enumerate(self.get_modules_wrapper()):
                z = wrapper.module.weight.data + self.U[i]
                self.Z[i] = self._projection(z, wrapper.config['sparsity'], wrapper)
                self.U[i] = self.U[i] + wrapper.module.weight.data - self.Z[i]

        # apply prune
        self.update_mask()

        logger.info('Compression finished.')

        return self.bound_model

    # --- dependency-aware mask update helpers ---

    def _dependency_calc_mask(self, wrappers, channel_dsets, wrappers_idx=None):
        """
        calculate the masks for the conv layers in the same
        channel dependency set. All the layers passed in have
        the same number of channels.

        Parameters
        ----------
        wrappers: list
            The list of the wrappers that in the same channel dependency
            set.
        wrappers_idx: list
            The list of the indexes of wrappers.

        Returns
        -------
        masks: dict
            A dict object that contains the masks of the layers in this
            dependency group, the key is the name of the convolutional layers.
        """
        # The number of the groups for each conv layers
        # Note that, this number may be different from its
        # original number of groups of filters.
        groups = [self.group_depen[_w.name] for _w in wrappers]
        sparsities = [_w.config['sparsity'] for _w in wrappers]
        masks = self.masker.calc_mask(
            sparsities, wrappers, wrappers_idx, channel_dsets=channel_dsets, groups=groups)
        if masks is not None:
            # if masks is None, then the mask calculation fails.
            # for example, in activation related maskers, we should
            # pass enough batches of data to the model, so that the
            # masks can be calculated successfully.
            for _w in wrappers:
                _w.if_calculated = True
        return masks

    def _dependency_update_mask(self):
        """
        In the original update_mask, the wrapper of each layer will update its
        own mask according to the sparsity specified in the config_list. However, in
        the _dependency_update_mask, we may prune several layers at the same
        time according the sparsities and the channel/group dependencies.
        """
        name2wrapper = {x.name: x for x in self.get_modules_wrapper()}
        wrapper2index = {x: i for i, x in enumerate(self.get_modules_wrapper())}
        for wrapper in self.get_modules_wrapper():
            if wrapper.if_calculated:
                continue
            # find all the conv layers that have channel dependency with this layer
            # and prune all these layers at the same time.
            _names = [x for x in self.channel_depen[wrapper.name]]
            logger.info('Pruning the dependent layers: %s', ','.join(_names))
            _wrappers = [name2wrapper[name] for name in _names if name in name2wrapper]
            _wrapper_idxes = [wrapper2index[_w] for _w in _wrappers]

            masks = self._dependency_calc_mask(_wrappers, _names, wrappers_idx=_wrapper_idxes)
            if masks is not None:
                for layer in masks:
                    for mask_type in masks[layer]:
                        assert hasattr(name2wrapper[layer], mask_type), \
                            "there is no attribute '%s' in wrapper on %s" % (mask_type, layer)
                        setattr(name2wrapper[layer], mask_type, masks[layer][mask_type])


class L1FilterPruner(_StructuredFilterPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : This is to specify the sparsity operations to be compressed to.
            - op_types : Only Conv2d is supported in L1FilterPruner.
    optimizer: torch.optim.Optimizer
        Optimizer used to train model
    dependency_aware: bool
        If prune the model in a dependency-aware way. If it is `True`, this pruner will
        prune the model according to the l2-norm of weights and the channel-dependency or
        group-dependency of the model. In this way, the pruner will force the conv layers
        that have dependencies to prune the same channels, so the speedup module can better
        harvest the speed benefit from the pruned model. Note that, if this flag is set True
        , the dummy_input cannot be None, because the pruner needs a dummy input to trace the
        dependency between the conv layers.
    dummy_input : torch.Tensor
        The dummy input to analyze the topology constraints. Note that, the dummy_input
        should on the same device with the model.
    """

    def __init__(self, model, config_list, optimizer=None, dependency_aware=False, dummy_input=None):
        super().__init__(model, config_list, pruning_algorithm='l1', optimizer=optimizer,
                         dependency_aware=dependency_aware, dummy_input=dummy_input)


class L2FilterPruner(_StructuredFilterPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : This is to specify the sparsity operations to be compressed to.
            - op_types : Only Conv2d is supported in L2FilterPruner.
    optimizer: torch.optim.Optimizer
        Optimizer used to train model
    dependency_aware: bool
        If prune the model in a dependency-aware way. If it is `True`, this pruner will
        prune the model according to the l2-norm of weights and the channel-dependency or
        group-dependency of the model. In this way, the pruner will force the conv layers
        that have dependencies to prune the same channels, so the speedup module can better
        harvest the speed benefit from the pruned model. Note that, if this flag is set True
        , the dummy_input cannot be None, because the pruner needs a dummy input to trace the
        dependency between the conv layers.
    dummy_input : torch.Tensor
        The dummy input to analyze the topology constraints. Note that, the dummy_input
        should on the same device with the model.
    """

    def __init__(self, model, config_list, optimizer=None, dependency_aware=False, dummy_input=None):
        super().__init__(model, config_list, pruning_algorithm='l2', optimizer=optimizer,
                         dependency_aware=dependency_aware, dummy_input=dummy_input)
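Since the dependency_aware and dummy_input contract described in these docstrings is easy to get wrong, here is a minimal usage sketch. The model, input shape, and import path are assumptions for illustration rather than code from this commit:

import torch
import torchvision.models as models
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner

model = models.resnet18()
# the dummy input must live on the same device as the model
dummy_input = torch.rand(1, 3, 224, 224)
config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]

# dependency_aware=True requires dummy_input so the pruner can trace
# channel/group dependencies between conv layers before masking
pruner = L1FilterPruner(model, config_list, dependency_aware=True, dummy_input=dummy_input)
pruner.compress()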
-class FPGMPruner(_StructuredFilterPruner):
+class SlimPruner(IterativePruner):
     """
     Parameters
     ----------
...
@@ -331,9 +387,19 @@ class FPGMPruner(_StructuredFilterPruner):
     config_list : list
         Supported keys:
             - sparsity : This is to specify the sparsity operations to be compressed to.
-            - op_types : Only Conv2d is supported in FPGMPruner.
+            - op_types : Only BatchNorm2d is supported in SlimPruner.
     optimizer: torch.optim.Optimizer
         Optimizer used to train model
+    trainer : function
+        Function used to sparsify BatchNorm2d scaling factors.
+        Users should write this function as a normal function to train the Pytorch model
+        and include `model, optimizer, criterion, epoch` as function arguments.
+    criterion : function
+        Function used to calculate the loss between the target and the output.
+    sparsity_training_epochs: int
+        The number of channel sparsity regularization training epochs before pruning.
+    scale : float
+        Penalty parameters for sparsification.
     dependency_aware: bool
         If prune the model in a dependency-aware way. If it is `True`, this pruner will
         prune the model according to the l2-norm of weights and the channel-dependency or
...
@@ -347,12 +413,35 @@ class FPGMPruner(_StructuredFilterPruner):
         should on the same device with the model.
     """
-    def __init__(self, model, config_list, optimizer=None, dependency_aware=False, dummy_input=None):
-        super().__init__(model, config_list, pruning_algorithm='fpgm',
-                         dependency_aware=dependency_aware, dummy_input=dummy_input, optimizer=optimizer)
+    def __init__(self, model, config_list, optimizer, trainer, criterion, sparsity_training_epochs=10, scale=0.0001,
+                 dependency_aware=False, dummy_input=None):
+        super().__init__(model, config_list, optimizer=optimizer, pruning_algorithm='slim', trainer=trainer,
+                         criterion=criterion, num_iterations=1, epochs_per_iteration=sparsity_training_epochs,
+                         dependency_aware=dependency_aware, dummy_input=dummy_input)
+        self.scale = scale
+        self.patch_optimizer_before(self._callback)
+
+    def validate_config(self, model, config_list):
+        schema = CompressorSchema([{
+            'sparsity': And(float, lambda n: 0 < n < 1),
+            'op_types': ['BatchNorm2d'],
+            Optional('op_names'): [str]
+        }], model, logger)
+
+        schema.validate(config_list)
+
+        if len(config_list) > 1:
+            logger.warning('Slim pruner only supports 1 configuration')
+
+    def _supported_dependency_aware(self):
+        return True
+
+    def _callback(self):
+        for _, wrapper in enumerate(self.get_modules_wrapper()):
+            wrapper.module.weight.grad.data.add_(self.scale * torch.sign(wrapper.module.weight.data))
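The updated constructors above expect a user-written trainer that follows the `model, optimizer, criterion, epoch` argument convention from the docstring. A minimal sketch, where the synthetic data loader is an assumption made only so the example is self-contained:

import torch
import torch.nn.functional as F

# a tiny synthetic dataset stands in for a real DataLoader
train_loader = [(torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,))) for _ in range(4)]

def trainer(model, optimizer, criterion, epoch):
    model.train()
    for data, target in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()   # the pruner's patched step runs its callbacks here

criterion = F.cross_entropy   # plugs into the `criterion` argument of the new constructors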
-class TaylorFOWeightFilterPruner(_StructuredFilterPruner):
+class TaylorFOWeightFilterPruner(IterativePruner):
     """
     Parameters
     ----------
...
@@ -364,8 +453,14 @@ class TaylorFOWeightFilterPruner(_StructuredFilterPruner):
             - op_types : Currently only Conv2d is supported in TaylorFOWeightFilterPruner.
     optimizer: torch.optim.Optimizer
         Optimizer used to train model
-    statistics_batch_num: int
-        The number of batches to statistic the activation.
+    trainer : function
+        Function used to sparsify BatchNorm2d scaling factors.
+        Users should write this function as a normal function to train the Pytorch model
+        and include `model, optimizer, criterion, epoch` as function arguments.
+    criterion : function
+        Function used to calculate the loss between the target and the output.
+    sparsity_training_epochs: int
+        The number of epochs to collect the contributions.
     dependency_aware: bool
         If prune the model in a dependency-aware way. If it is `True`, this pruner will
         prune the model according to the l2-norm of weights and the channel-dependency or
...
@@ -380,14 +475,17 @@ class TaylorFOWeightFilterPruner(_StructuredFilterPruner):
     """
-    def __init__(self, model, config_list, optimizer=None, statistics_batch_num=1,
-                 dependency_aware=False, dummy_input=None):
-        super().__init__(model, config_list, pruning_algorithm='taylorfo',
-                         dependency_aware=dependency_aware, dummy_input=dummy_input,
-                         optimizer=optimizer, statistics_batch_num=statistics_batch_num)
+    def __init__(self, model, config_list, optimizer, trainer, criterion, sparsity_training_epochs=1,
+                 dependency_aware=False, dummy_input=None):
+        super().__init__(model, config_list, optimizer=optimizer, pruning_algorithm='taylorfo',
+                         trainer=trainer, criterion=criterion, num_iterations=1,
+                         epochs_per_iteration=sparsity_training_epochs,
+                         dependency_aware=dependency_aware, dummy_input=dummy_input)
+
+    def _supported_dependency_aware(self):
+        return True
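For code written against the previous API, construction changes roughly as follows; `model`, `config_list`, `optimizer`, `trainer`, and `criterion` are assumed to be defined as in the earlier sketches:

# before this commit (statistics_batch_num-based API):
# pruner = TaylorFOWeightFilterPruner(model, config_list, optimizer, statistics_batch_num=1)

# after this commit (trainer/criterion-based API):
pruner = TaylorFOWeightFilterPruner(model, config_list, optimizer,
                                    trainer=trainer, criterion=criterion,
                                    sparsity_training_epochs=1)
pruner.compress()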
-class ActivationAPoZRankFilterPruner(_StructuredFilterPruner):
+class ActivationAPoZRankFilterPruner(IterativePruner):
     """
     Parameters
     ----------
...
@@ -399,10 +497,16 @@ class ActivationAPoZRankFilterPruner(_StructuredFilterPruner):
             - op_types : Only Conv2d is supported in ActivationAPoZRankFilterPruner.
     optimizer: torch.optim.Optimizer
         Optimizer used to train model
+    trainer: function
+        Function used to train the model.
+        Users should write this function as a normal function to train the Pytorch model
+        and include `model, optimizer, criterion, epoch` as function arguments.
+    criterion : function
+        Function used to calculate the loss between the target and the output.
     activation: str
         The activation type.
-    statistics_batch_num: int
-        The number of batches to statistic the activation.
+    sparsity_training_epochs: int
+        The number of epochs to statistic the activation.
     dependency_aware: bool
         If prune the model in a dependency-aware way. If it is `True`, this pruner will
         prune the model according to the l2-norm of weights and the channel-dependency or
...
@@ -417,14 +521,17 @@ class ActivationAPoZRankFilterPruner(_StructuredFilterPruner):
     """
-    def __init__(self, model, config_list, optimizer=None, activation='relu',
-                 statistics_batch_num=1, dependency_aware=False, dummy_input=None):
-        super().__init__(model, config_list, pruning_algorithm='apoz', optimizer=optimizer,
-                         dependency_aware=dependency_aware, dummy_input=dummy_input,
-                         activation=activation, statistics_batch_num=statistics_batch_num)
+    def __init__(self, model, config_list, optimizer, trainer, criterion, activation='relu',
+                 sparsity_training_epochs=1, dependency_aware=False, dummy_input=None):
+        super().__init__(model, config_list, pruning_algorithm='apoz', optimizer=optimizer, trainer=trainer,
+                         criterion=criterion, dependency_aware=dependency_aware, dummy_input=dummy_input,
+                         activation=activation, num_iterations=1,
+                         epochs_per_iteration=sparsity_training_epochs)
+
+    def _supported_dependency_aware(self):
+        return True
-class ActivationMeanRankFilterPruner(_StructuredFilterPruner):
+class ActivationMeanRankFilterPruner(IterativePruner):
     """
     Parameters
     ----------
...
@@ -436,9 +543,15 @@ class ActivationMeanRankFilterPruner(_StructuredFilterPruner):
             - op_types : Only Conv2d is supported in ActivationMeanRankFilterPruner.
     optimizer: torch.optim.Optimizer
         Optimizer used to train model.
+    trainer: function
+        Function used to train the model.
+        Users should write this function as a normal function to train the Pytorch model
+        and include `model, optimizer, criterion, epoch` as function arguments.
+    criterion : function
+        Function used to calculate the loss between the target and the output.
     activation: str
         The activation type.
-    statistics_batch_num: int
+    sparsity_training_epochs: int
         The number of batches to statistic the activation.
     dependency_aware: bool
         If prune the model in a dependency-aware way. If it is `True`, this pruner will
...
@@ -453,8 +566,11 @@ class ActivationMeanRankFilterPruner(_StructuredFilterPruner):
         should on the same device with the model.
     """
-    def __init__(self, model, config_list, optimizer=None, activation='relu',
-                 statistics_batch_num=1, dependency_aware=False, dummy_input=None):
-        super().__init__(model, config_list, pruning_algorithm='mean_activation', optimizer=optimizer,
-                         dependency_aware=dependency_aware, dummy_input=dummy_input,
-                         activation=activation, statistics_batch_num=statistics_batch_num)
+    def __init__(self, model, config_list, optimizer, trainer, criterion, activation='relu',
+                 sparsity_training_epochs=1, dependency_aware=False, dummy_input=None):
+        super().__init__(model, config_list, pruning_algorithm='mean_activation', optimizer=optimizer,
+                         trainer=trainer, criterion=criterion,
+                         dependency_aware=dependency_aware, dummy_input=dummy_input,
+                         activation=activation, num_iterations=1,
+                         epochs_per_iteration=sparsity_training_epochs)
+
+    def _supported_dependency_aware(self):
+        return True
nni/algorithms/compression/pytorch/pruning/lottery_ticket.py  View file @ 92f6754e
...
@@ -7,7 +7,7 @@ import torch
 from schema import And, Optional
 from nni.compression.pytorch.utils.config_validation import CompressorSchema
 from nni.compression.pytorch.compressor import Pruner
-from .finegrained_pruning import LevelPrunerMasker
+from .finegrained_pruning_masker import LevelPrunerMasker
 logger = logging.getLogger('torch pruner')
...
nni/algorithms/compression/pytorch/pruning/one_shot_pruner.py  0 → 100644  View file @ 92f6754e

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from schema import And, Optional

from nni.compression.pytorch.utils.config_validation import CompressorSchema
from .dependency_aware_pruner import DependencyAwarePruner

__all__ = ['LevelPruner', 'L1FilterPruner', 'L2FilterPruner', 'FPGMPruner']

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class OneshotPruner(DependencyAwarePruner):
    """
    Prune model to an exact pruning level for one time.
    """

    def __init__(self, model, config_list, pruning_algorithm='level', dependency_aware=False, dummy_input=None,
                 **algo_kwargs):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned
        config_list : list
            List on pruning configs
        pruning_algorithm: str
            algorithms being used to prune model
        dependency_aware: bool
            If prune the model in a dependency-aware way.
        dummy_input : torch.Tensor
            The dummy input to analyze the topology constraints. Note that,
            the dummy_input should on the same device with the model.
        algo_kwargs: dict
            Additional parameters passed to pruning algorithm masker class
        """
        super().__init__(model, config_list, None, pruning_algorithm, dependency_aware, dummy_input, **algo_kwargs)

    def validate_config(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned
        config_list : list
            List on pruning configs
        """
        schema = CompressorSchema([{
            'sparsity': And(float, lambda n: 0 < n < 1),
            Optional('op_types'): [str],
            Optional('op_names'): [str]
        }], model, logger)

        schema.validate(config_list)


class LevelPruner(OneshotPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : This is to specify the sparsity operations to be compressed to.
            - op_types : Operation types to prune.
    """

    def __init__(self, model, config_list):
        super().__init__(model, config_list, pruning_algorithm='level')

    def _supported_dependency_aware(self):
        return False


class L1FilterPruner(OneshotPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : This is to specify the sparsity operations to be compressed to.
            - op_types : Only Conv2d is supported in L1FilterPruner.
    dependency_aware: bool
        If prune the model in a dependency-aware way. If it is `True`, this pruner will
        prune the model according to the l2-norm of weights and the channel-dependency or
        group-dependency of the model. In this way, the pruner will force the conv layers
        that have dependencies to prune the same channels, so the speedup module can better
        harvest the speed benefit from the pruned model. Note that, if this flag is set True
        , the dummy_input cannot be None, because the pruner needs a dummy input to trace the
        dependency between the conv layers.
    dummy_input : torch.Tensor
        The dummy input to analyze the topology constraints. Note that, the dummy_input
        should on the same device with the model.
    """

    def __init__(self, model, config_list, dependency_aware=False, dummy_input=None):
        super().__init__(model, config_list, pruning_algorithm='l1', dependency_aware=dependency_aware,
                         dummy_input=dummy_input)

    def _supported_dependency_aware(self):
        return True


class L2FilterPruner(OneshotPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : This is to specify the sparsity operations to be compressed to.
            - op_types : Only Conv2d is supported in L2FilterPruner.
    dependency_aware: bool
        If prune the model in a dependency-aware way. If it is `True`, this pruner will
        prune the model according to the l2-norm of weights and the channel-dependency or
        group-dependency of the model. In this way, the pruner will force the conv layers
        that have dependencies to prune the same channels, so the speedup module can better
        harvest the speed benefit from the pruned model. Note that, if this flag is set True
        , the dummy_input cannot be None, because the pruner needs a dummy input to trace the
        dependency between the conv layers.
    dummy_input : torch.Tensor
        The dummy input to analyze the topology constraints. Note that, the dummy_input
        should on the same device with the model.
    """

    def __init__(self, model, config_list, dependency_aware=False, dummy_input=None):
        super().__init__(model, config_list, pruning_algorithm='l2', dependency_aware=dependency_aware,
                         dummy_input=dummy_input)

    def _supported_dependency_aware(self):
        return True


class FPGMPruner(OneshotPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : This is to specify the sparsity operations to be compressed to.
            - op_types : Only Conv2d is supported in FPGM Pruner.
    dependency_aware: bool
        If prune the model in a dependency-aware way. If it is `True`, this pruner will
        prune the model according to the l2-norm of weights and the channel-dependency or
        group-dependency of the model. In this way, the pruner will force the conv layers
        that have dependencies to prune the same channels, so the speedup module can better
        harvest the speed benefit from the pruned model. Note that, if this flag is set True
        , the dummy_input cannot be None, because the pruner needs a dummy input to trace the
        dependency between the conv layers.
    dummy_input : torch.Tensor
        The dummy input to analyze the topology constraints. Note that, the dummy_input
        should on the same device with the model.
    """

    def __init__(self, model, config_list, dependency_aware=False, dummy_input=None):
        super().__init__(model, config_list, pruning_algorithm='fpgm', dependency_aware=dependency_aware,
                         dummy_input=dummy_input)

    def _supported_dependency_aware(self):
        return True
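A quick sanity check of the relocated one-shot pruners; the toy model and the 50% sparsity value are illustrative assumptions:

import torch.nn as nn
from nni.algorithms.compression.pytorch.pruning import LevelPruner

model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))
config_list = [{'sparsity': 0.5, 'op_types': ['default']}]

pruner = LevelPruner(model, config_list)   # the new LevelPruner takes no optimizer argument
pruner.compress()                          # masks are applied in place on the wrapped modules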
nni/algorithms/compression/pytorch/pruning/structured_pruning.py → nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py  View file @ 92f6754e
...
@@ -474,8 +474,8 @@ class TaylorFOWeightFilterPrunerMasker(StructuredWeightMasker):
     def __init__(self, model, pruner, statistics_batch_num=1):
         super().__init__(model, pruner)
         self.pruner.statistics_batch_num = statistics_batch_num
-        self.pruner.set_wrappers_attribute("contribution", None)
         self.pruner.iterations = 0
+        self.pruner.set_wrappers_attribute("contribution", None)
         self.pruner.patch_optimizer(self.calc_contributions)

     def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx, channel_masks=None):
...
@@ -499,6 +499,7 @@ class TaylorFOWeightFilterPrunerMasker(StructuredWeightMasker):
         """
         if self.pruner.iterations >= self.pruner.statistics_batch_num:
             return

         for wrapper in self.pruner.get_modules_wrapper():
             filters = wrapper.module.weight.size(0)
             contribution = (
...
@@ -677,16 +678,24 @@ class SlimPrunerMasker(WeightMasker):
     def __init__(self, model, pruner, **kwargs):
         super().__init__(model, pruner)
+        self.global_threshold = None
+
+    def _get_global_threshold(self):
         weight_list = []
-        for (layer, _) in pruner.get_modules_to_compress():
+        for (layer, _) in self.pruner.get_modules_to_compress():
             weight_list.append(layer.module.weight.data.abs().clone())
         all_bn_weights = torch.cat(weight_list)
-        k = int(all_bn_weights.shape[0] * pruner.config_list[0]['sparsity'])
+        k = int(all_bn_weights.shape[0] * self.pruner.config_list[0]['sparsity'])
         self.global_threshold = torch.topk(
             all_bn_weights.view(-1), k, largest=False)[0].max()
+        print(f'set global threshold to {self.global_threshold}')

     def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
         assert wrapper.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
+        if self.global_threshold is None:
+            self._get_global_threshold()
         weight = wrapper.module.weight.data.clone()
         if wrapper.weight_mask is not None:
             # apply base mask for iterative pruning
...
@@ -706,7 +715,6 @@ class SlimPrunerMasker(WeightMasker):
             ), 'bias_mask': mask_bias.detach()}
         return mask

 def least_square_sklearn(X, Y):
     from sklearn.linear_model import LinearRegression
     reg = LinearRegression(fit_intercept=False)
...
nni/algorithms/compression/pytorch/quantization/quantizers.py  View file @ 92f6754e
...
@@ -148,6 +148,7 @@ class QAT_Quantizer(Quantizer):
         super().__init__(model, config_list, optimizer)
         self.quant_grad = QATGrad.apply
         modules_to_compress = self.get_modules_to_compress()
+        device = next(model.parameters()).device
         self.bound_model.register_buffer("steps", torch.Tensor([1]))
         for layer, config in modules_to_compress:
             layer.module.register_buffer("zero_point", torch.Tensor([0.0]))
...
@@ -161,7 +162,7 @@ class QAT_Quantizer(Quantizer):
             layer.module.register_buffer('activation_bit', torch.zeros(1))
             layer.module.register_buffer('tracked_min_activation', torch.zeros(1))
             layer.module.register_buffer('tracked_max_activation', torch.zeros(1))
+        self.bound_model.to(device)

     def _del_simulated_attr(self, module):
         """
...
@@ -359,7 +360,7 @@ class QAT_Quantizer(Quantizer):
         """
         override `compressor` `step` method, quantization only happens after certain number of steps
         """
         self.bound_model.steps += 1

 class DoReFaQuantizer(Quantizer):
...
@@ -370,10 +371,12 @@ class DoReFaQuantizer(Quantizer):
     def __init__(self, model, config_list, optimizer=None):
         super().__init__(model, config_list, optimizer)
+        device = next(model.parameters()).device
         modules_to_compress = self.get_modules_to_compress()
         for layer, config in modules_to_compress:
             if "weight" in config.get("quant_types", []):
                 layer.module.register_buffer('weight_bit', torch.zeros(1))
+        self.bound_model.to(device)

     def _del_simulated_attr(self, module):
         """
...
@@ -474,11 +477,13 @@ class BNNQuantizer(Quantizer):
     def __init__(self, model, config_list, optimizer=None):
         super().__init__(model, config_list, optimizer)
+        device = next(model.parameters()).device
         self.quant_grad = ClipGrad.apply
         modules_to_compress = self.get_modules_to_compress()
         for layer, config in modules_to_compress:
             if "weight" in config.get("quant_types", []):
                 layer.module.register_buffer('weight_bit', torch.zeros(1))
+        self.bound_model.to(device)

     def _del_simulated_attr(self, module):
         """
...
@@ -589,6 +594,7 @@ class LsqQuantizer(Quantizer):
            types of nn.module you want to apply quantization, eg. 'Conv2d'
         """
         super().__init__(model, config_list, optimizer)
+        device = next(model.parameters()).device
         self.quant_grad = QuantForward()
         modules_to_compress = self.get_modules_to_compress()
         self.bound_model.register_buffer("steps", torch.Tensor([1]))
...
@@ -631,6 +637,8 @@ class LsqQuantizer(Quantizer):
                 self.optimizer.add_param_group({"params": layer.module.input_scale})
+        self.bound_model.to(device)

     @staticmethod
     def grad_scale(x, scale):
         """
...
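The `device = next(model.parameters()).device` / `self.bound_model.to(device)` additions keep the buffers registered by each quantizer on the model's device. A hedged sketch of the case they address; the toy model and the 8-bit weight config are illustrative, not taken from the diff:

import torch
import torch.nn as nn
from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer

model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU())
if torch.cuda.is_available():
    model = model.cuda()   # buffers registered by the quantizer should follow this device

config_list = [{'quant_types': ['weight'], 'quant_bits': {'weight': 8}, 'op_types': ['Conv2d']}]
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

quantizer = QAT_Quantizer(model, config_list, optimizer)
quantizer.compress()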
nni/algorithms/compression/tensorflow/pruning/__init__.py  View file @ 92f6754e
-from .one_shot import *
+from .one_shot_pruner import *

nni/algorithms/compression/tensorflow/pruning/one_shot.py → nni/algorithms/compression/tensorflow/pruning/one_shot_pruner.py  View file @ 92f6754e
File moved
nni/compression/pytorch/compressor.py  View file @ 92f6754e
...
@@ -8,7 +8,6 @@ from . import default_layers
 _logger = logging.getLogger(__name__)

 class LayerInfo:
     def __init__(self, name, module):
         self.module = module
...
@@ -235,7 +234,6 @@ class Compressor:
         """
         raise NotImplementedError()

     def add_activation_collector(self, collector):
         self._fwd_hook_id += 1
         self._fwd_hook_handles[self._fwd_hook_id] = []
...
@@ -264,6 +262,18 @@ class Compressor:
         if self.optimizer is not None:
             self.optimizer.step = types.MethodType(patch_step(self.optimizer.step), self.optimizer)

+    def patch_optimizer_before(self, *tasks):
+        def patch_step(old_step):
+            def new_step(_, *args, **kwargs):
+                for task in tasks:
+                    task()
+                # call origin optimizer step method
+                output = old_step(*args, **kwargs)
+                return output
+            return new_step
+        if self.optimizer is not None:
+            self.optimizer.step = types.MethodType(patch_step(self.optimizer.step), self.optimizer)
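The new `patch_optimizer_before` runs the given callables immediately before every `optimizer.step()`; SlimPruner uses it above to add its L1 penalty gradient on the BatchNorm scaling factors. A standalone toy sketch of the same wrapping idea, independent of any NNI class:

import types
import torch

w = torch.nn.Parameter(torch.randn(4))
optimizer = torch.optim.SGD([w], lr=0.1)

def before_step():
    # e.g. add a sparsity-style penalty gradient, mirroring SlimPruner._callback
    w.grad.data.add_(0.0001 * torch.sign(w.data))

def patch_step(old_step):
    def new_step(_, *args, **kwargs):
        before_step()                     # runs before the real update
        return old_step(*args, **kwargs)
    return new_step

optimizer.step = types.MethodType(patch_step(optimizer.step), optimizer)

loss = (w ** 2).sum()
loss.backward()
optimizer.step()   # before_step() fires first, then SGD updates w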
 class PrunerModuleWrapper(torch.nn.Module):
     def __init__(self, module, module_name, module_type, config, pruner):
         """
...
@@ -319,8 +329,6 @@ class Pruner(Compressor):
     def __init__(self, model, config_list, optimizer=None):
         super().__init__(model, config_list, optimizer)
-        if optimizer is not None:
-            self.patch_optimizer(self.update_mask)

     def compress(self):
         self.update_mask()
...
@@ -386,7 +394,7 @@ class Pruner(Compressor):
         """
         assert model_path is not None, 'model_path must be specified'
         mask_dict = {}
         self._unwrap_model()  # used for generating correct state_dict name without wrapper state

         for wrapper in self.get_modules_wrapper():
             weight_mask = wrapper.weight_mask
...
@@ -433,6 +441,27 @@ class Pruner(Compressor):
         else:
             self.bound_model.load_state_dict(model_state)

+    def get_pruned_weights(self, dim=0):
+        """
+        Log the simulated prune sparsity.
+
+        Parameters
+        ----------
+        dim : int
+            the pruned dim.
+        """
+        for _, wrapper in enumerate(self.get_modules_wrapper()):
+            weight_mask = wrapper.weight_mask
+            mask_size = weight_mask.size()
+            if len(mask_size) == 1:
+                index = torch.nonzero(weight_mask.abs() != 0).tolist()
+            else:
+                sum_idx = list(range(len(mask_size)))
+                sum_idx.remove(dim)
+                index = torch.nonzero(weight_mask.abs().sum(sum_idx) != 0).tolist()
+            _logger.info(f'simulated prune {wrapper.name} remain/total: {len(index)}/{weight_mask.size(dim)}')
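`get_pruned_weights` only logs per-layer statistics, so after a compress run it serves as a quick check of how many filters survive. A short sketch, assuming a pruner set up as in the earlier examples:

import logging
logging.basicConfig(level=logging.INFO)   # make the _logger.info lines visible

pruner.compress()
# for Conv2d weight masks of shape (out_channels, in_channels, k, k),
# dim=0 reports how many output filters remain unmasked per layer
pruner.get_pruned_weights(dim=0)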
 class QuantizerModuleWrapper(torch.nn.Module):
     def __init__(self, module, module_name, module_type, config, quantizer):
         """
...
@@ -549,7 +578,6 @@ class Quantizer(Compressor):
         """
         raise NotImplementedError('Quantizer must overload quantize_input()')

     def _wrap_modules(self, layer, config):
         """
         Create a wrapper forward function to replace the original one.
...
@@ -571,8 +599,8 @@ class Quantizer(Compressor):
         return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self)

-    def export_model_save(self, model, model_path, calibration_config=None, calibration_path=None, onnx_path=None, \
-                          input_shape=None, device=None):
+    def export_model_save(self, model, model_path, calibration_config=None, calibration_path=None, onnx_path=None,
+                          input_shape=None, device=None):
         """
         This method helps save pytorch model, calibration config, onnx model in quantizer.
...
@@ -671,6 +699,7 @@ class QuantGrad(torch.autograd.Function):
            quantized x without clamped
         """
         return ((x / scale) + zero_point).round()

     @classmethod
     def get_bits_length(cls, config, quant_type):
         """
...
@@ -703,8 +732,8 @@ class QuantGrad(torch.autograd.Function):
         grad_output : Tensor
             gradient of the output of quantization operation
         scale : Tensor
             the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`,
             `QuantType.QUANT_OUTPUT`, you can define different behavior for different types.
         zero_point : Tensor
             zero_point for quantizing tensor
         qmin : Tensor
...