Unverified commit 92f6754e authored by colorjam, committed by GitHub

[Model Compression] Update api of iterative pruners (#3507)

parent 26f47727
@@ -10,15 +10,16 @@ import logging
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from models.mnist.lenet import LeNet
from nni.algorithms.compression.pytorch.pruning import LevelPruner
import nni
import sys
sys.path.append('../models')
from mnist.lenet import LeNet
_logger = logging.getLogger('mnist_example')
_logger.setLevel(logging.INFO)
@@ -108,7 +109,7 @@ def main(args):
'op_types': ['default'],
}]
pruner = LevelPruner(model, prune_config, optimizer_finetune)
pruner = LevelPruner(model, prune_config)
model = pruner.compress()
# fine-tuning
@@ -149,5 +150,4 @@ if __name__ == '__main__':
help='target overall sparsity')
args = parser.parse_args()
main(args)
\ No newline at end of file
main(args)
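
With this commit LevelPruner is a pure one-shot pruner, so the fine-tuning optimizer is no longer passed to it. A minimal sketch of the updated usage (the sparsity value is illustrative):

import sys
sys.path.append('../models')
from mnist.lenet import LeNet
from nni.algorithms.compression.pytorch.pruning import LevelPruner

model = LeNet()
prune_config = [{'sparsity': 0.5, 'op_types': ['default']}]
pruner = LevelPruner(model, prune_config)  # the optimizer argument is gone
model = pruner.compress()                  # masks are computed once, up front
# fine-tuning then runs outside the pruner, with any optimizer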
@@ -31,7 +31,6 @@ class VGG_Cifar10(nn.Module):
nn.BatchNorm2d(256, eps=1e-4, momentum=0.1),
nn.Hardtanh(inplace=True),
nn.Conv2d(256, 512, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(512, eps=1e-4, momentum=0.1),
nn.Hardtanh(inplace=True),
......
@@ -3,27 +3,9 @@ import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.algorithms.compression.pytorch.quantization import DoReFaQuantizer
class Mnist(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
self.fc2 = torch.nn.Linear(500, 10)
self.relu1 = torch.nn.ReLU6()
self.relu2 = torch.nn.ReLU6()
self.relu3 = torch.nn.ReLU6()
def forward(self, x):
x = self.relu1(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = self.relu2(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4 * 4 * 50)
x = self.relu3(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
import sys
sys.path.append('../models')
from mnist.naive import NaiveModel
def train(model, quantizer, device, train_loader, optimizer):
@@ -66,7 +48,7 @@ def main():
datasets.MNIST('data', train=False, transform=trans),
batch_size=1000, shuffle=True)
model = Mnist()
model = NaiveModel()
model = model.to(device)
configure_list = [{
'quant_types': ['weight'],
......
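
The inline Mnist class deleted above moves to a shared module; judging from the removed code, models/mnist/naive.py presumably defines an equivalent NaiveModel along these lines (a sketch reconstructed from the deleted lines, not the verbatim file):

import torch
import torch.nn.functional as F

class NaiveModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
        self.fc2 = torch.nn.Linear(500, 10)
        self.relu1 = torch.nn.ReLU6()
        self.relu2 = torch.nn.ReLU6()
        self.relu3 = torch.nn.ReLU6()
        self.max_pool1 = torch.nn.MaxPool2d(2, 2)
        self.max_pool2 = torch.nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.max_pool1(x)
        x = self.relu2(self.conv2(x))
        x = self.max_pool2(x)
        x = x.view(-1, 4 * 4 * 50)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)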
@@ -3,28 +3,9 @@ import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
class Mnist(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
self.fc2 = torch.nn.Linear(500, 10)
self.relu1 = torch.nn.ReLU6()
self.relu2 = torch.nn.ReLU6()
self.relu3 = torch.nn.ReLU6()
def forward(self, x):
x = self.relu1(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = self.relu2(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4 * 4 * 50)
x = self.relu3(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
import sys
sys.path.append('../models')
from mnist.naive import NaiveModel
def train(model, quantizer, device, train_loader, optimizer):
model.train()
@@ -66,7 +47,7 @@ def main():
datasets.MNIST('data', train=False, transform=trans),
batch_size=1000, shuffle=True)
model = Mnist()
model = NaiveModel()
'''you can change this to DoReFaQuantizer to use that quantizer instead
DoReFaQuantizer(configure_list).compress(model)
'''
......
@@ -5,28 +5,9 @@ from torchvision import datasets, transforms
from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT
class Mnist(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
self.fc2 = torch.nn.Linear(500, 10)
self.relu1 = torch.nn.ReLU6()
self.relu2 = torch.nn.ReLU6()
self.relu3 = torch.nn.ReLU6()
self.max_pool1 = torch.nn.MaxPool2d(2, 2)
self.max_pool2 = torch.nn.MaxPool2d(2, 2)
def forward(self, x):
x = self.relu1(self.conv1(x))
x = self.max_pool1(x)
x = self.relu2(self.conv2(x))
x = self.max_pool2(x)
x = x.view(-1, 4 * 4 * 50)
x = self.relu3(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
import sys
sys.path.append('../models')
from mnist.naive import NaiveModel
def train(model, device, train_loader, optimizer):
@@ -74,7 +55,7 @@ def test_trt(engine, test_loader):
print("Inference elapsed_time (whole dataset): {}s".format(time_elasped))
def post_training_quantization_example(train_loader, test_loader, device):
model = Mnist()
model = NaiveModel()
config = {
'conv1':{'weight_bit':8, 'activation_bit':8},
@@ -99,7 +80,7 @@ def post_training_quantization_example(train_loader, test_loader, device):
test_trt(engine, test_loader)
def quantization_aware_training_example(train_loader, test_loader, device):
model = Mnist()
model = NaiveModel()
configure_list = [{
'quant_types': ['weight', 'output'],
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .finegrained_pruning import *
from .structured_pruning import *
from .one_shot import *
from .agp import *
from .finegrained_pruning_masker import *
from .structured_pruning_masker import *
from .one_shot_pruner import *
from .iterative_pruner import *
from .lottery_ticket import LotteryTicketPruner
from .simulated_annealing_pruner import SimulatedAnnealingPruner
from .net_adapt_pruner import NetAdaptPruner
from .admm_pruner import ADMMPruner
from .auto_compress_pruner import AutoCompressPruner
from .sensitivity_pruner import SensitivityPruner
from .amc import AMCPruner
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
from schema import And, Optional
import copy
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from .constants import MASKER_DICT
from .one_shot import OneshotPruner
_logger = logging.getLogger(__name__)
class ADMMPruner(OneshotPruner):
"""
A PyTorch implementation of the ADMM pruning algorithm.
Parameters
----------
model : torch.nn.Module
Model to be pruned.
config_list : list
List of pruning configs.
trainer : function
Function used for the first subproblem.
Users should write this function as a normal function to train the PyTorch model
and include `model, optimizer, criterion, epoch, callback` as function arguments.
Here `callback` acts as an L2 regularizer, as presented in formula (7) of the original paper.
The logic of `callback` is implemented inside the Pruner;
users only need to insert `callback()` between `loss.backward()` and `optimizer.step()`.
Example::
def trainer(model, criterion, optimizer, epoch, callback):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader = ...
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
# callback should be inserted between loss.backward() and optimizer.step()
if callback:
callback()
optimizer.step()
num_iterations : int
Total number of iterations.
training_epochs : int
Training epochs of the first subproblem.
row : float
Penalty parameter (rho) for ADMM training.
base_algo : str
Base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`. Given the sparsity distribution among the ops,
the assigned `base_algo` is used to decide which filters/channels/weights to prune.
"""
def __init__(self, model, config_list, trainer, num_iterations=30, training_epochs=5, row=1e-4, base_algo='l1'):
self._base_algo = base_algo
super().__init__(model, config_list)
self._trainer = trainer
self._num_iterations = num_iterations
self._training_epochs = training_epochs
self._row = row
self.set_wrappers_attribute("if_calculated", False)
self.masker = MASKER_DICT[self._base_algo](self.bound_model, self)
def validate_config(self, model, config_list):
"""
Parameters
----------
model : torch.nn.Module
Model to be pruned
config_list : list
List of pruning configs
"""
if self._base_algo == 'level':
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
Optional('op_types'): [str],
Optional('op_names'): [str],
}], model, _logger)
elif self._base_algo in ['l1', 'l2', 'fpgm']:
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
'op_types': ['Conv2d'],
Optional('op_names'): [str]
}], model, _logger)
schema.validate(config_list)
def _projection(self, weight, sparsity, wrapper):
'''
Return the Euclidean projection of the weight matrix according to the pruning mode.
Parameters
----------
weight : tensor
original matrix
sparsity : float
the ratio of parameters which need to be set to zero
wrapper: PrunerModuleWrapper
layer wrapper of this layer
Returns
-------
tensor
the projected matrix
'''
wrapper_copy = copy.deepcopy(wrapper)
wrapper_copy.module.weight.data = weight
return weight.data.mul(self.masker.calc_mask(sparsity, wrapper_copy)['weight_mask'])
def compress(self):
"""
Compress the model with ADMM.
Returns
-------
torch.nn.Module
model with specified modules compressed.
"""
_logger.info('Starting ADMM Compression...')
# initialize Z, U
# Z_i^0 = W_i^0
# U_i^0 = 0
Z = []
U = []
for wrapper in self.get_modules_wrapper():
z = wrapper.module.weight.data
Z.append(z)
U.append(torch.zeros_like(z))
optimizer = torch.optim.Adam(
self.bound_model.parameters(), lr=1e-3, weight_decay=5e-5)
# Loss = cross_entropy + L2 regularization + \sum_{i=1}^N (\rho_i / 2) ||W_i - Z_i^k + U_i^k||^2
criterion = torch.nn.CrossEntropyLoss()
# callback function to do additional optimization; refer to the derivatives of formula (7)
def callback():
for i, wrapper in enumerate(self.get_modules_wrapper()):
wrapper.module.weight.data -= self._row * \
(wrapper.module.weight.data - Z[i] + U[i])
# optimization iteration
for k in range(self._num_iterations):
_logger.info('ADMM iteration : %d', k)
# step 1: optimize W with the Adam optimizer
for epoch in range(self._training_epochs):
self._trainer(self.bound_model, optimizer=optimizer,
criterion=criterion, epoch=epoch, callback=callback)
# step 2: update Z, U
# Z_i^{k+1} = projection(W_i^{k+1} + U_i^k)
# U_i^{k+1} = U^k + W_i^{k+1} - Z_i^{k+1}
for i, wrapper in enumerate(self.get_modules_wrapper()):
z = wrapper.module.weight.data + U[i]
Z[i] = self._projection(z, wrapper.config['sparsity'], wrapper)
U[i] = U[i] + wrapper.module.weight.data - Z[i]
# apply pruning
self.update_mask()
_logger.info('Compression finished.')
return self.bound_model
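
For reference, a minimal ADMMPruner invocation consistent with the docstring above; the model, the data loader, and the hyperparameters are placeholders:

def trainer(model, criterion, optimizer, epoch, callback):
    model.train()
    for data, target in train_loader:  # train_loader supplied by the user
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        if callback:  # adds the ADMM penalty gradient
            callback()
        optimizer.step()

config_list = [{'sparsity': 0.8, 'op_types': ['Conv2d']}]
pruner = ADMMPruner(model, config_list, trainer=trainer,
                    num_iterations=30, training_epochs=5, row=1e-4, base_algo='l1')
model = pruner.compress()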
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
An automated gradual pruning algorithm that prunes the smallest magnitude
weights to achieve a preset level of network sparsity.
Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the
efficacy of pruning for model compression", 2017 NIPS Workshop on Machine
Learning of Phones and other Consumer Devices.
"""
import logging
import torch
from schema import And, Optional
from .constants import MASKER_DICT
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from nni.compression.pytorch.compressor import Pruner
__all__ = ['AGPPruner']
logger = logging.getLogger('torch pruner')
class AGPPruner(Pruner):
"""
Parameters
----------
model : torch.nn.Module
Model to be pruned.
config_list : list
Supported keys:
- initial_sparsity: This is to specify the sparsity when the compressor starts to compress.
- final_sparsity: This is to specify the sparsity when the compressor finishes compressing.
- start_epoch: This is to specify the epoch number when the compressor starts to compress, 0 by default.
- end_epoch: This is to specify the epoch number when the compressor finishes compressing.
- frequency: This is to specify that the compressor compresses once every *frequency* epochs, 1 by default.
optimizer: torch.optim.Optimizer
Optimizer used to train model.
pruning_algorithm: str
Algorithms being used to prune model,
choose from `['level', 'slim', 'l1', 'l2', 'fpgm', 'taylorfo', 'apoz', 'mean_activation']`, by default `level`
"""
def __init__(self, model, config_list, optimizer, pruning_algorithm='level'):
super().__init__(model, config_list, optimizer)
assert isinstance(optimizer, torch.optim.Optimizer), "AGP pruner is an iterative pruner, please pass the model's optimizer to it"
self.masker = MASKER_DICT[pruning_algorithm](model, self)
self.now_epoch = 0
self.set_wrappers_attribute("if_calculated", False)
def validate_config(self, model, config_list):
"""
Parameters
----------
model : torch.nn.Module
Model to be pruned
config_list : list
List of pruning configs
"""
schema = CompressorSchema([{
'initial_sparsity': And(float, lambda n: 0 <= n <= 1),
'final_sparsity': And(float, lambda n: 0 <= n <= 1),
'start_epoch': And(int, lambda n: n >= 0),
'end_epoch': And(int, lambda n: n >= 0),
'frequency': And(int, lambda n: n > 0),
Optional('op_types'): [str],
Optional('op_names'): [str]
}], model, logger)
schema.validate(config_list)
def calc_mask(self, wrapper, wrapper_idx=None):
"""
Calculate the mask of given layer.
The mask is computed by the configured masker at the sparsity scheduled for the current epoch.
Parameters
----------
wrapper : Module
the layer to instrument the compression operation
wrapper_idx: int
index of this wrapper in pruner's all wrappers
Returns
-------
dict | None
Dictionary for storing masks, keys of the dict:
'weight_mask': weight mask tensor
'bias_mask': bias mask tensor (optional)
"""
config = wrapper.config
start_epoch = config.get('start_epoch', 0)
freq = config.get('frequency', 1)
if wrapper.if_calculated:
return None
if not (self.now_epoch >= start_epoch and (self.now_epoch - start_epoch) % freq == 0):
return None
target_sparsity = self.compute_target_sparsity(config)
new_mask = self.masker.calc_mask(sparsity=target_sparsity, wrapper=wrapper, wrapper_idx=wrapper_idx)
if new_mask is not None:
wrapper.if_calculated = True
return new_mask
def compute_target_sparsity(self, config):
"""
Calculate the sparsity for pruning
Parameters
----------
config : dict
Layer's pruning config
Returns
-------
float
Target sparsity to be pruned
"""
end_epoch = config.get('end_epoch', 1)
start_epoch = config.get('start_epoch', 0)
freq = config.get('frequency', 1)
final_sparsity = config.get('final_sparsity', 0)
initial_sparsity = config.get('initial_sparsity', 0)
if end_epoch <= start_epoch or initial_sparsity >= final_sparsity:
logger.warning('your end epoch <= start epoch or initial_sparsity >= final_sparsity')
return final_sparsity
if end_epoch <= self.now_epoch:
return final_sparsity
span = ((end_epoch - start_epoch - 1) // freq) * freq
assert span > 0
target_sparsity = (final_sparsity +
(initial_sparsity - final_sparsity) *
(1.0 - ((self.now_epoch - start_epoch) / span)) ** 3)
return target_sparsity
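# Worked example of the cubic schedule above (illustrative numbers): with
# initial_sparsity=0, final_sparsity=0.8, start_epoch=0, end_epoch=10 and
# frequency=1, span = 9 and target_sparsity is 0.0 at epoch 0, ~0.56 at
# epoch 3, ~0.79 at epoch 7, and 0.8 from epoch 9 on, i.e. pruning is
# aggressive early and tapers off cubically.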
def update_epoch(self, epoch):
"""
Update epoch
Parameters
----------
epoch : int
current training epoch
"""
if epoch > 0:
self.now_epoch = epoch
for wrapper in self.get_modules_wrapper():
wrapper.if_calculated = False
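
Because AGPPruner updates masks through the patched optimizer and an explicit epoch counter, a typical training loop looks like the sketch below (train/test are user-supplied placeholders):

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
config_list = [{'initial_sparsity': 0.0, 'final_sparsity': 0.8,
                'start_epoch': 0, 'end_epoch': 10, 'frequency': 1,
                'op_types': ['default']}]
pruner = AGPPruner(model, config_list, optimizer, pruning_algorithm='level')
model = pruner.compress()
for epoch in range(10):
    pruner.update_epoch(epoch)  # advance the schedule and reset if_calculated
    train(model, optimizer)     # each optimizer.step() triggers a mask update
    test(model)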
@@ -13,8 +13,7 @@ from nni.compression.pytorch import ModelSpeedup
from nni.compression.pytorch.compressor import Pruner
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from .simulated_annealing_pruner import SimulatedAnnealingPruner
from .admm_pruner import ADMMPruner
from .iterative_pruner import ADMMPruner
_logger = logging.getLogger(__name__)
@@ -34,26 +33,7 @@ class AutoCompressPruner(Pruner):
trainer : function
Function used for the first subproblem of ADMM Pruner.
Users should write this function as a normal function to train the PyTorch model
and include `model, optimizer, criterion, epoch, callback` as function arguments.
Here `callback` acts as an L2 regularizer, as presented in formula (7) of the original paper.
The logic of `callback` is implemented inside the Pruner;
users only need to insert `callback()` between `loss.backward()` and `optimizer.step()`.
Example::
def trainer(model, criterion, optimizer, epoch, callback):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader = ...
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
# callback should be inserted between loss.backward() and optimizer.step()
if callback:
callback()
optimizer.step()
and include `model, optimizer, criterion, epoch` as function arguments.
evaluator : function
function to evaluate the pruned model.
This function should take `model` as its only parameter and return a scalar value.
@@ -80,8 +60,8 @@ class AutoCompressPruner(Pruner):
optimize_mode : str
optimize mode, `maximize` or `minimize`, by default `maximize`.
base_algo : str
Base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`. Given the sparsity distribution among the ops,
the assigned `base_algo` is used to decide which filters/channels/weights to prune.
Base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`. Given the sparsity distribution among
the ops, the assigned `base_algo` is used to decide which filters/channels/weights to prune.
start_temperature : float
Start temperature of the simulated annealing process.
stop_temperature : float
@@ -92,7 +72,7 @@ class AutoCompressPruner(Pruner):
Initial perturbation magnitude to the sparsities. The magnitude decreases with current temperature.
admm_num_iterations : int
Number of iterations of ADMM Pruner.
admm_training_epochs : int
admm_epochs_per_iteration : int
Training epochs of the first optimization subproblem of ADMMPruner.
row : float
Penalty parameter (rho) for ADMM training.
@@ -100,18 +80,19 @@ class AutoCompressPruner(Pruner):
Path to store temporary experiment data.
"""
def __init__(self, model, config_list, trainer, evaluator, dummy_input,
def __init__(self, model, config_list, trainer, criterion, evaluator, dummy_input,
num_iterations=3, optimize_mode='maximize', base_algo='l1',
# SimulatedAnnealing related
start_temperature=100, stop_temperature=20, cool_down_rate=0.9, perturbation_magnitude=0.35,
# ADMM related
admm_num_iterations=30, admm_training_epochs=5, row=1e-4,
admm_num_iterations=30, admm_epochs_per_iteration=5, row=1e-4,
experiment_data_dir='./'):
# original model
self._model_to_prune = model
self._base_algo = base_algo
self._trainer = trainer
self._criterion = criterion
self._evaluator = evaluator
self._dummy_input = dummy_input
self._num_iterations = num_iterations
@@ -125,7 +106,7 @@ class AutoCompressPruner(Pruner):
# hyper parameters for ADMM algorithm
self._admm_num_iterations = admm_num_iterations
self._admm_training_epochs = admm_training_epochs
self._admm_epochs_per_iteration = admm_epochs_per_iteration
self._row = row
# overall pruning rate
@@ -174,12 +155,12 @@
"""
_logger.info('Starting AutoCompress pruning...')
sparsity_each_round = 1 - pow(1-self._sparsity, 1/self._num_iterations)
sparsity_each_round = 1 - pow(1 - self._sparsity, 1 / self._num_iterations)
for i in range(self._num_iterations):
_logger.info('Pruning iteration: %d', i)
_logger.info('Target sparsity this round: %s',
1-pow(1-sparsity_each_round, i+1))
1 - pow(1 - sparsity_each_round, i + 1))
# SimulatedAnnealingPruner
_logger.info(
@@ -204,9 +185,10 @@
ADMMpruner = ADMMPruner(
model=copy.deepcopy(self._model_to_prune),
config_list=config_list,
criterion=self._criterion,
trainer=self._trainer,
num_iterations=self._admm_num_iterations,
training_epochs=self._admm_training_epochs,
epochs_per_iteration=self._admm_epochs_per_iteration,
row=self._row,
base_algo=self._base_algo)
ADMMpruner.compress()
@@ -214,12 +196,13 @@
ADMMpruner.export_model(os.path.join(self._experiment_data_dir, 'model_admm_masked.pth'), os.path.join(
self._experiment_data_dir, 'mask.pth'))
# use speed up to prune the model before next iteration, because SimulatedAnnealingPruner & ADMMPruner don't take masked models
# use speed up to prune the model before next iteration,
# because SimulatedAnnealingPruner & ADMMPruner don't take masked models
self._model_to_prune.load_state_dict(torch.load(os.path.join(
self._experiment_data_dir, 'model_admm_masked.pth')))
masks_file = os.path.join(self._experiment_data_dir, 'mask.pth')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = next(self._model_to_prune.parameters()).device
_logger.info('Speeding up models...')
m_speedup = ModelSpeedup(self._model_to_prune, self._dummy_input, masks_file, device)
......
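
The per-round sparsity used above follows from compounding: each round shrinks the kept fraction of weights by the same factor, so num_iterations rounds at rate r keep (1 - r)^num_iterations of the weights overall. A quick check of the arithmetic:

sparsity, num_iterations = 0.5, 3
r = 1 - pow(1 - sparsity, 1 / num_iterations)  # ~0.206 pruned per round
for i in range(num_iterations):
    print(1 - pow(1 - r, i + 1))               # cumulative: ~0.206, ~0.370, 0.5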
@@ -2,7 +2,7 @@
# Licensed under the MIT license.
from .one_shot import LevelPruner, L1FilterPruner, L2FilterPruner, FPGMPruner
from .one_shot_pruner import LevelPruner, L1FilterPruner, L2FilterPruner, FPGMPruner
PRUNER_DICT = {
'level': LevelPruner,
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from schema import And, Optional, SchemaError
from nni.common.graph_utils import TorchModuleGraph
from nni.compression.pytorch.utils.shape_dependency import ChannelDependency, GroupDependency
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from nni.compression.pytorch.compressor import Pruner
from .constants import MASKER_DICT
__all__ = ['DependencyAwarePruner']
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class DependencyAwarePruner(Pruner):
"""
DependencyAwarePruner has two ways to calculate the masks
for conv layers. In the normal way, the DependencyAwarePruner
will calculate the mask of each layer separately. For example, each
conv layer determines which filters should be pruned according to its L1
norm. In contrast, in the dependency-aware way, the layers that are in a
dependency group will be pruned jointly, and these layers will be forced
"""
def __init__(self, model, config_list, optimizer=None, pruning_algorithm='level', dependency_aware=False,
dummy_input=None, **algo_kwargs):
super().__init__(model, config_list=config_list, optimizer=optimizer)
self.dependency_aware = dependency_aware
self.dummy_input = dummy_input
if self.dependency_aware:
if not self._supported_dependency_aware():
raise ValueError('This pruner does not support dependency aware!')
errmsg = "When dependency_aware is set, the dummy_input should not be None"
assert self.dummy_input is not None, errmsg
# Get the TorchModuleGraph of the target model
# to trace the model, we need to unwrap the wrappers
self._unwrap_model()
self.graph = TorchModuleGraph(model, dummy_input)
self._wrap_model()
self.channel_depen = ChannelDependency(
traced_model=self.graph.trace)
self.group_depen = GroupDependency(traced_model=self.graph.trace)
self.channel_depen = self.channel_depen.dependency_sets
self.channel_depen = {
name: sets for sets in self.channel_depen for name in sets}
self.group_depen = self.group_depen.dependency_sets
self.masker = MASKER_DICT[pruning_algorithm](
model, self, **algo_kwargs)
# set the dependency-aware switch for the masker
self.masker.dependency_aware = dependency_aware
self.set_wrappers_attribute("if_calculated", False)
def calc_mask(self, wrapper, wrapper_idx=None):
if not wrapper.if_calculated:
sparsity = wrapper.config['sparsity']
masks = self.masker.calc_mask(
sparsity=sparsity, wrapper=wrapper, wrapper_idx=wrapper_idx)
# masker.calc_mask returning None means the mask was not calculated successfully; it can be tried again later
if masks is not None:
wrapper.if_calculated = True
return masks
else:
return None
def update_mask(self):
if not self.dependency_aware:
# if we use the normal way to update the mask,
# then call the update_mask of the parent class
super(DependencyAwarePruner, self).update_mask()
else:
# if we update the mask in a dependency-aware way
# then we call _dependency_update_mask
self._dependency_update_mask()
def validate_config(self, model, config_list):
schema = CompressorSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
Optional('op_types'): ['Conv2d'],
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)
schema.validate(config_list)
for config in config_list:
if 'exclude' not in config and 'sparsity' not in config:
raise SchemaError('Either sparsity or exclude must be specified!')
def _supported_dependency_aware(self):
raise NotImplementedError
def _dependency_calc_mask(self, wrappers, channel_dsets, wrappers_idx=None):
"""
calculate the masks for the conv layers in the same
channel dependency set. All the layers passed in have
the same number of channels.
Parameters
----------
wrappers: list
The list of the wrappers that in the same channel dependency
set.
wrappers_idx: list
The list of the indexes of the wrappers.
Returns
-------
masks: dict
A dict object that contains the masks of the layers in this
dependency group, the key is the name of the convolutional layers.
"""
# The number of groups for each conv layer.
# Note that this number may differ from the
# layer's original number of filter groups.
groups = [self.group_depen[_w.name] for _w in wrappers]
sparsities = [_w.config['sparsity'] for _w in wrappers]
masks = self.masker.calc_mask(
sparsities, wrappers, wrappers_idx, channel_dsets=channel_dsets, groups=groups)
if masks is not None:
# if masks is None, then the mask calculation fails.
# for example, in activation related maskers, we should
# pass enough batches of data to the model, so that the
# masks can be calculated successfully.
for _w in wrappers:
_w.if_calculated = True
return masks
def _dependency_update_mask(self):
"""
In the original update_mask, the wrapper of each layer will update its
own mask according to the sparsity specified in the config_list. However, in
_dependency_update_mask, we may prune several layers at the same
time according to the sparsities and the channel/group dependencies.
"""
name2wrapper = {x.name: x for x in self.get_modules_wrapper()}
wrapper2index = {x: i for i, x in enumerate(self.get_modules_wrapper())}
for wrapper in self.get_modules_wrapper():
if wrapper.if_calculated:
continue
# find all the conv layers that have a channel dependency with this layer
# and prune all these layers at the same time.
_names = [x for x in self.channel_depen[wrapper.name]]
logger.info('Pruning the dependent layers: %s', ','.join(_names))
_wrappers = [name2wrapper[name]
for name in _names if name in name2wrapper]
_wrapper_idxes = [wrapper2index[_w] for _w in _wrappers]
masks = self._dependency_calc_mask(
_wrappers, _names, wrappers_idx=_wrapper_idxes)
if masks is not None:
for layer in masks:
for mask_type in masks[layer]:
assert hasattr(name2wrapper[layer], mask_type), "there is no attribute '%s' in wrapper on %s" \
% (mask_type, layer)
setattr(name2wrapper[layer], mask_type, masks[layer][mask_type])
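
To make the channel dependency concrete: layers whose outputs are combined elementwise must keep the same output channels, as in this toy model (illustrative, not part of NNI):

import torch

class TwoBranch(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 8, 3, padding=1)
        self.conv2 = torch.nn.Conv2d(3, 8, 3, padding=1)

    def forward(self, x):
        # the elementwise add puts conv1 and conv2 in one dependency set:
        # a dependency-aware pruner must mask the same output channels in both
        return self.conv1(x) + self.conv2(x)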
@@ -7,7 +7,7 @@ import torch
from schema import And, Optional
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from nni.compression.pytorch.compressor import Pruner
from .finegrained_pruning import LevelPrunerMasker
from .finegrained_pruning_masker import LevelPrunerMasker
logger = logging.getLogger('torch pruner')
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from schema import And, Optional
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from .dependency_aware_pruner import DependencyAwarePruner
__all__ = ['LevelPruner', 'L1FilterPruner', 'L2FilterPruner', 'FPGMPruner']
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class OneshotPruner(DependencyAwarePruner):
"""
Prune the model to an exact pruning level in one shot.
"""
def __init__(self, model, config_list, pruning_algorithm='level', dependency_aware=False, dummy_input=None,
**algo_kwargs):
"""
Parameters
----------
model : torch.nn.Module
Model to be pruned
config_list : list
List of pruning configs
pruning_algorithm: str
Algorithm used to prune the model
dependency_aware: bool
Whether to prune the model in a dependency-aware way.
dummy_input : torch.Tensor
The dummy input used to analyze the topology constraints. Note that
the dummy_input should be on the same device as the model.
algo_kwargs: dict
Additional parameters passed to pruning algorithm masker class
"""
super().__init__(model, config_list, None, pruning_algorithm, dependency_aware, dummy_input, **algo_kwargs)
def validate_config(self, model, config_list):
"""
Parameters
----------
model : torch.nn.Module
Model to be pruned
config_list : list
List of pruning configs
"""
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
Optional('op_types'): [str],
Optional('op_names'): [str]
}], model, logger)
schema.validate(config_list)
class LevelPruner(OneshotPruner):
"""
Parameters
----------
model : torch.nn.Module
Model to be pruned
config_list : list
Supported keys:
- sparsity : This is to specify the target sparsity of the operations to be compressed.
- op_types : Operation types to prune.
"""
def __init__(self, model, config_list):
super().__init__(model, config_list, pruning_algorithm='level')
def _supported_dependency_aware(self):
return False
class L1FilterPruner(OneshotPruner):
"""
Parameters
----------
model : torch.nn.Module
Model to be pruned
config_list : list
Supported keys:
- sparsity : This is to specify the target sparsity of the operations to be compressed.
- op_types : Only Conv2d is supported in L1FilterPruner.
dependency_aware: bool
Whether to prune the model in a dependency-aware way. If it is `True`, this pruner will
prune the model according to the l1-norm of weights and the channel-dependency or
group-dependency of the model. In this way, the pruner will force the conv layers
that have dependencies to prune the same channels, so the speedup module can better
harvest the speed benefit from the pruned model. Note that if this flag is set to True,
the dummy_input cannot be None, because the pruner needs a dummy input to trace the
dependency between the conv layers.
dummy_input : torch.Tensor
The dummy input used to analyze the topology constraints. Note that the dummy_input
should be on the same device as the model.
"""
def __init__(self, model, config_list, dependency_aware=False, dummy_input=None):
super().__init__(model, config_list, pruning_algorithm='l1', dependency_aware=dependency_aware,
dummy_input=dummy_input)
def _supported_dependency_aware(self):
return True
class L2FilterPruner(OneshotPruner):
"""
Parameters
----------
model : torch.nn.Module
Model to be pruned
config_list : list
Supported keys:
- sparsity : This is to specify the target sparsity of the operations to be compressed.
- op_types : Only Conv2d is supported in L2FilterPruner.
dependency_aware: bool
Whether to prune the model in a dependency-aware way. If it is `True`, this pruner will
prune the model according to the l2-norm of weights and the channel-dependency or
group-dependency of the model. In this way, the pruner will force the conv layers
that have dependencies to prune the same channels, so the speedup module can better
harvest the speed benefit from the pruned model. Note that if this flag is set to True,
the dummy_input cannot be None, because the pruner needs a dummy input to trace the
dependency between the conv layers.
dummy_input : torch.Tensor
The dummy input used to analyze the topology constraints. Note that the dummy_input
should be on the same device as the model.
"""
def __init__(self, model, config_list, dependency_aware=False, dummy_input=None):
super().__init__(model, config_list, pruning_algorithm='l2', dependency_aware=dependency_aware,
dummy_input=dummy_input)
def _supported_dependency_aware(self):
return True
class FPGMPruner(OneshotPruner):
"""
Parameters
----------
model : torch.nn.Module
Model to be pruned
config_list : list
Supported keys:
- sparsity : This is to specify the target sparsity of the operations to be compressed.
- op_types : Only Conv2d is supported in FPGM Pruner.
dependency_aware: bool
Whether to prune the model in a dependency-aware way. If it is `True`, this pruner will
prune the model according to the distance of each filter from the geometric median, and
the channel-dependency or group-dependency of the model. In this way, the pruner will
force the conv layers that have dependencies to prune the same channels, so the speedup
module can better harvest the speed benefit from the pruned model. Note that if this flag
is set to True, the dummy_input cannot be None, because the pruner needs a dummy input to
trace the dependency between the conv layers.
dummy_input : torch.Tensor
The dummy input used to analyze the topology constraints. Note that the dummy_input
should be on the same device as the model.
"""
def __init__(self, model, config_list, dependency_aware=False, dummy_input=None):
super().__init__(model, config_list, pruning_algorithm='fpgm', dependency_aware=dependency_aware,
dummy_input=dummy_input)
def _supported_dependency_aware(self):
return True
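
A dependency-aware invocation, sketched; the model and the dummy input shape are placeholders, and the dummy input must live on the model's device:

import torch
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner

config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
dummy_input = torch.rand(1, 3, 224, 224).to(next(model.parameters()).device)
pruner = L1FilterPruner(model, config_list,
                        dependency_aware=True, dummy_input=dummy_input)
model = pruner.compress()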
@@ -474,8 +474,8 @@ class TaylorFOWeightFilterPrunerMasker(StructuredWeightMasker):
def __init__(self, model, pruner, statistics_batch_num=1):
super().__init__(model, pruner)
self.pruner.statistics_batch_num = statistics_batch_num
self.pruner.set_wrappers_attribute("contribution", None)
self.pruner.iterations = 0
self.pruner.set_wrappers_attribute("contribution", None)
self.pruner.patch_optimizer(self.calc_contributions)
def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx, channel_masks=None):
@@ -499,6 +499,7 @@
"""
if self.pruner.iterations >= self.pruner.statistics_batch_num:
return
for wrapper in self.pruner.get_modules_wrapper():
filters = wrapper.module.weight.size(0)
contribution = (
@@ -677,16 +678,24 @@ class SlimPrunerMasker(WeightMasker):
def __init__(self, model, pruner, **kwargs):
super().__init__(model, pruner)
self.global_threshold = None
def _get_global_threshold(self):
weight_list = []
for (layer, _) in pruner.get_modules_to_compress():
for (layer, _) in self.pruner.get_modules_to_compress():
weight_list.append(layer.module.weight.data.abs().clone())
all_bn_weights = torch.cat(weight_list)
k = int(all_bn_weights.shape[0] * pruner.config_list[0]['sparsity'])
k = int(all_bn_weights.shape[0] * self.pruner.config_list[0]['sparsity'])
self.global_threshold = torch.topk(
all_bn_weights.view(-1), k, largest=False)[0].max()
print(f'set global threshold to {self.global_threshold}')
def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
assert wrapper.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
if self.global_threshold is None:
self._get_global_threshold()
weight = wrapper.module.weight.data.clone()
if wrapper.weight_mask is not None:
# apply base mask for iterative pruning
@@ -706,7 +715,6 @@
), 'bias_mask': mask_bias.detach()}
return mask
def least_square_sklearn(X, Y):
from sklearn.linear_model import LinearRegression
reg = LinearRegression(fit_intercept=False)
......
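
The global threshold above is the k-th smallest BN scaling factor across all pruned layers; a tiny numeric check of the same torch.topk pattern (values are made up):

import torch

all_bn_weights = torch.tensor([0.9, 0.05, 0.4, 0.01, 0.7, 0.2])
k = int(all_bn_weights.shape[0] * 0.5)  # sparsity 0.5 -> k = 3
threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
print(threshold)  # tensor(0.2000): the three smallest scales fall at or below it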
@@ -148,6 +148,7 @@ class QAT_Quantizer(Quantizer):
super().__init__(model, config_list, optimizer)
self.quant_grad = QATGrad.apply
modules_to_compress = self.get_modules_to_compress()
device = next(model.parameters()).device
self.bound_model.register_buffer("steps", torch.Tensor([1]))
for layer, config in modules_to_compress:
layer.module.register_buffer("zero_point", torch.Tensor([0.0]))
@@ -161,7 +162,7 @@
layer.module.register_buffer('activation_bit', torch.zeros(1))
layer.module.register_buffer('tracked_min_activation', torch.zeros(1))
layer.module.register_buffer('tracked_max_activation', torch.zeros(1))
self.bound_model.to(device)
def _del_simulated_attr(self, module):
"""
@@ -359,7 +360,7 @@
"""
override `compressor` `step` method, quantization only happens after certain number of steps
"""
self.bound_model.steps +=1
self.bound_model.steps += 1
class DoReFaQuantizer(Quantizer):
@@ -370,10 +371,12 @@ class DoReFaQuantizer(Quantizer):
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
device = next(model.parameters()).device
modules_to_compress = self.get_modules_to_compress()
for layer, config in modules_to_compress:
if "weight" in config.get("quant_types", []):
layer.module.register_buffer('weight_bit', torch.zeros(1))
self.bound_model.to(device)
def _del_simulated_attr(self, module):
"""
@@ -474,11 +477,13 @@ class BNNQuantizer(Quantizer):
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
device = next(model.parameters()).device
self.quant_grad = ClipGrad.apply
modules_to_compress = self.get_modules_to_compress()
for layer, config in modules_to_compress:
if "weight" in config.get("quant_types", []):
layer.module.register_buffer('weight_bit', torch.zeros(1))
self.bound_model.to(device)
def _del_simulated_attr(self, module):
"""
@@ -589,6 +594,7 @@ class LsqQuantizer(Quantizer):
types of nn.module you want to apply quantization, eg. 'Conv2d'
"""
super().__init__(model, config_list, optimizer)
device = next(model.parameters()).device
self.quant_grad = QuantForward()
modules_to_compress = self.get_modules_to_compress()
self.bound_model.register_buffer("steps", torch.Tensor([1]))
@@ -631,6 +637,8 @@
self.optimizer.add_param_group({"params": layer.module.input_scale})
self.bound_model.to(device)
@staticmethod
def grad_scale(x, scale):
"""
......
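
The recurring device = next(model.parameters()).device ... self.bound_model.to(device) pattern added in these hunks guards against register_buffer creating new buffers on the CPU while the model lives on the GPU. A minimal illustration (assumes a CUDA device is available):

import torch

if torch.cuda.is_available():
    m = torch.nn.Linear(4, 4).cuda()
    device = next(m.parameters()).device  # remember the original device
    m.register_buffer('zero_point', torch.Tensor([0.0]))  # new buffer lands on CPU
    print(m.zero_point.device)  # cpu
    m.to(device)                # move parameters and buffers back together
    print(m.zero_point.device)  # cuda:0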
from .one_shot import *
from .one_shot_pruner import *
@@ -8,7 +8,6 @@ from . import default_layers
_logger = logging.getLogger(__name__)
class LayerInfo:
def __init__(self, name, module):
self.module = module
@@ -235,7 +234,6 @@ class Compressor:
"""
raise NotImplementedError()
def add_activation_collector(self, collector):
self._fwd_hook_id += 1
self._fwd_hook_handles[self._fwd_hook_id] = []
@@ -264,6 +262,18 @@
if self.optimizer is not None:
self.optimizer.step = types.MethodType(patch_step(self.optimizer.step), self.optimizer)
def patch_optimizer_before(self, *tasks):
def patch_step(old_step):
def new_step(_, *args, **kwargs):
for task in tasks:
task()
# call the original optimizer step method
output = old_step(*args, **kwargs)
return output
return new_step
if self.optimizer is not None:
self.optimizer.step = types.MethodType(patch_step(self.optimizer.step), self.optimizer)
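
patch_optimizer_before wraps optimizer.step so the given tasks run just before each step (the existing patch_optimizer runs them after the original step). The same monkey-patching pattern in isolation, as a standalone sketch:

import types
import torch

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

def patch_step(old_step):
    def new_step(_, *args, **kwargs):
        print('task runs before the real step')  # e.g. a mask or scale update
        return old_step(*args, **kwargs)         # old_step is already bound
    return new_step

optimizer.step = types.MethodType(patch_step(optimizer.step), optimizer)
model(torch.rand(1, 2)).sum().backward()
optimizer.step()  # prints the message, then applies the SGD update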
class PrunerModuleWrapper(torch.nn.Module):
def __init__(self, module, module_name, module_type, config, pruner):
"""
@@ -319,8 +329,6 @@ class Pruner(Compressor):
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
if optimizer is not None:
self.patch_optimizer(self.update_mask)
def compress(self):
self.update_mask()
@@ -386,7 +394,7 @@
"""
assert model_path is not None, 'model_path must be specified'
mask_dict = {}
self._unwrap_model() # used for generating correct state_dict name without wrapper state
self._unwrap_model() # used for generating correct state_dict name without wrapper state
for wrapper in self.get_modules_wrapper():
weight_mask = wrapper.weight_mask
@@ -433,6 +441,27 @@
else:
self.bound_model.load_state_dict(model_state)
def get_pruned_weights(self, dim=0):
"""
Log the simulated prune sparsity.
Parameters
----------
dim : int
the pruned dim.
"""
for _, wrapper in enumerate(self.get_modules_wrapper()):
weight_mask = wrapper.weight_mask
mask_size = weight_mask.size()
if len(mask_size) == 1:
index = torch.nonzero(weight_mask.abs() != 0).tolist()
else:
sum_idx = list(range(len(mask_size)))
sum_idx.remove(dim)
index = torch.nonzero(weight_mask.abs().sum(sum_idx) != 0).tolist()
_logger.info(f'simulated prune {wrapper.name} remain/total: {len(index)}/{weight_mask.size(dim)}')
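
For a conv weight mask of shape (filters, channels, k, k), summing over every dim except the pruned one and counting the nonzero results yields the remaining filters; a quick check of the indexing above:

import torch

weight_mask = torch.ones(4, 3, 3, 3)
weight_mask[1] = 0                      # prune filter 1
sum_idx = [1, 2, 3]                     # all dims except dim=0
index = torch.nonzero(weight_mask.abs().sum(sum_idx) != 0).tolist()
print(len(index), '/', weight_mask.size(0))  # 3 / 4 filters remain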
class QuantizerModuleWrapper(torch.nn.Module):
def __init__(self, module, module_name, module_type, config, quantizer):
"""
@@ -549,7 +578,6 @@
"""
raise NotImplementedError('Quantizer must overload quantize_input()')
def _wrap_modules(self, layer, config):
"""
Create a wrapper forward function to replace the original one.
@@ -571,8 +599,8 @@
return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self)
def export_model_save(self, model, model_path, calibration_config=None, calibration_path=None, onnx_path=None, \
input_shape=None, device=None):
def export_model_save(self, model, model_path, calibration_config=None, calibration_path=None, onnx_path=None,
input_shape=None, device=None):
"""
This method helps save pytorch model, calibration config, onnx model in quantizer.
@@ -671,6 +699,7 @@ class QuantGrad(torch.autograd.Function):
quantized x without clamping
"""
return ((x / scale) + zero_point).round()
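# e.g. with scale=0.1 and zero_point=128 (sketch values), x=0.3 quantizes to
# round(0.3 / 0.1 + 128) = 131; clamping to [qmin, qmax] is applied elsewhere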
@classmethod
def get_bits_length(cls, config, quant_type):
"""
@@ -703,8 +732,8 @@
grad_output : Tensor
gradient of the output of quantization operation
scale : Tensor
the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, `QuantType.QUANT_OUTPUT`,
you can define different behavior for different types.
the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`,
`QuantType.QUANT_OUTPUT`, you can define different behavior for different types.
zero_point : Tensor
zero_point for quantizing tensor
qmin : Tensor
......