Commit 2de52a89 authored by Cjkkkk, committed by chicm-ms

add data parallel proposal (#1923)

* add data parallel proposal

* fix mask_weight bug

* add slim pruner support and example

* fix typo

* fix typo

* fix setattr error

* fix buffer update

* rename instrument_layer and prunerLayerWrapper

* fix pylint

* update reverse traversal

* add wrap and unwrap

* add register_buffer API

* update docstring

* update docstring

* add quantizer support

* fix typo

* update MeanActivationPruner, weight_rank_filter_pruner and example
parent 13d03757
 import math
+import argparse
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -40,6 +41,12 @@ def test(model, device, test_loader):
 def main():
+    parser = argparse.ArgumentParser("multiple gpu with pruning")
+    parser.add_argument("--epochs", type=int, default=160)
+    parser.add_argument("--retrain", default=False, action="store_true")
+    parser.add_argument("--parallel", default=False, action="store_true")
+    args = parser.parse_args()
     torch.manual_seed(0)
     device = torch.device('cuda')
     train_loader = torch.utils.data.DataLoader(
@@ -63,10 +70,11 @@ def main():
     model.to(device)
     # Train the base VGG-16 model
+    if args.retrain:
         print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
         optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
         lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 160, 0)
-        for epoch in range(160):
+        for epoch in range(args.epochs):
             train(model, device, train_loader, optimizer)
             test(model, device, test_loader)
             lr_scheduler.step(epoch)
@@ -90,6 +98,14 @@ def main():
     print('=' * 10 + 'Test on the pruned model before fine tune' + '=' * 10)
     pruner = L1FilterPruner(model, configure_list)
     model = pruner.compress()
+    if args.parallel:
+        if torch.cuda.device_count() > 1:
+            print("use {} gpus for pruning".format(torch.cuda.device_count()))
+            model = nn.DataParallel(model)
+        else:
+            print("only detect 1 gpu, fall back")
+    model.to(device)
     test(model, device, test_loader)
     # top1 = 88.19%
...
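Note on ordering (not part of the diff): the point of this proposal is that `pruner.compress()` wraps the target layers first, and only then is the model handed to `nn.DataParallel`, so every replica carries the mask buffers. A minimal sketch of the pattern, with an illustrative toy model and config:

import torch
import torch.nn as nn
from nni.compression.torch import L1FilterPruner

# toy stand-in for the VGG-16 example; any module with Conv2d layers works
model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.Conv2d(16, 32, 3, padding=1))
configure_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]  # illustrative config

pruner = L1FilterPruner(model, configure_list)
model = pruner.compress()           # 1. wrap target layers; masks become registered buffers
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)  # 2. replicate the wrapped model; buffers are copied per GPU
model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))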
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from nni.compression.torch import SlimPruner
class fc1(nn.Module):
    def __init__(self, num_classes=10):
        super(fc1, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU(inplace=True)
        self.linear1 = nn.Linear(32*28*28, 300)
        self.relu2 = nn.ReLU(inplace=True)
        self.linear2 = nn.Linear(300, 100)
        self.relu3 = nn.ReLU(inplace=True)
        self.linear3 = nn.Linear(100, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = torch.flatten(x, 1)
        x = self.linear1(x)
        x = self.relu2(x)
        x = self.linear2(x)
        x = self.relu3(x)
        x = self.linear3(x)
        return x

def train(model, train_loader, optimizer, criterion, device):
    model.train()
    for imgs, targets in train_loader:
        optimizer.zero_grad()
        imgs, targets = imgs.to(device), targets.to(device)
        output = model(imgs)
        train_loss = criterion(output, targets)
        train_loss.backward()
        optimizer.step()
    return train_loss.item()

def test(model, test_loader, criterion, device):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    return accuracy

if __name__ == '__main__':
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    traindataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    testdataset = datasets.MNIST('./data', train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(traindataset, batch_size=60, shuffle=True, num_workers=10, drop_last=False)
    test_loader = torch.utils.data.DataLoader(testdataset, batch_size=60, shuffle=False, num_workers=10, drop_last=True)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = fc1()
    criterion = nn.CrossEntropyLoss()
    configure_list = [{
        'prune_iterations': 5,
        'sparsity': 0.86,
        'op_types': ['BatchNorm2d']
    }]
    pruner = SlimPruner(model, configure_list)
    pruner.compress()
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1.2e-3)
    for name, par in model.named_parameters():
        print(name)
    # for i in pruner.get_prune_iterations():
    #     pruner.prune_iteration_start()
    loss = 0
    accuracy = 0
    for epoch in range(10):
        loss = train(model, train_loader, optimizer, criterion, device)
        accuracy = test(model, test_loader, criterion, device)
        print('current epoch: {0}, loss: {1}, accuracy: {2}'.format(epoch, loss, accuracy))
        # print('prune iteration: {0}, loss: {1}, accuracy: {2}'.format(i, loss, accuracy))
    pruner.export_model('model.pth', 'mask.pth')
\ No newline at end of file
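For reference, the two files written by `pruner.export_model` above can be inspected afterwards; per the new `export_model` implementation later in this commit, the mask file is a dict keyed by layer name with `weight`/`bias` mask tensors. A short sketch (file names are the ones passed above):

import torch

# state_dict of the pruned model (masks already multiplied into the weights)
state_dict = torch.load('model.pth', map_location='cpu')
# {layer_name: {'weight': mask_tensor, 'bias': mask_tensor_or_None}, ...}
mask_dict = torch.load('mask.pth', map_location='cpu')
for name, masks in mask_dict.items():
    w = masks['weight']
    print('{}: sparsity {:.2f}'.format(name, 1 - w.sum().item() / w.numel()))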
 import math
+import argparse
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -6,7 +7,6 @@ from torchvision import datasets, transforms
 from nni.compression.torch import SlimPruner
 from models.cifar10.vgg import VGG

 def updateBN(model):
     for m in model.modules():
         if isinstance(m, nn.BatchNorm2d):
@@ -49,6 +49,13 @@ def test(model, device, test_loader):
 def main():
+    parser = argparse.ArgumentParser("multiple gpu with pruning")
+    parser.add_argument("--epochs", type=int, default=160)
+    parser.add_argument("--retrain", default=False, action="store_true")
+    parser.add_argument("--parallel", default=False, action="store_true")
+    args = parser.parse_args()
     torch.manual_seed(0)
     device = torch.device('cuda')
     train_loader = torch.utils.data.DataLoader(
@@ -70,15 +77,16 @@ def main():
     model = VGG(depth=19)
     model.to(device)
     # Train the base VGG-19 model
+    if args.retrain:
         print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
-        epochs = 160
+        epochs = args.epochs
         optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
         for epoch in range(epochs):
             if epoch in [epochs * 0.5, epochs * 0.75]:
                 for param_group in optimizer.param_groups:
                     param_group['lr'] *= 0.1
+            print("epoch {}".format(epoch))
             train(model, device, train_loader, optimizer, True)
             test(model, device, test_loader)
         torch.save(model.state_dict(), 'vgg19_cifar10.pth')
@@ -99,9 +107,14 @@ def main():
     print('=' * 10 + 'Test the pruned model before fine tune' + '=' * 10)
     pruner = SlimPruner(model, configure_list)
     model = pruner.compress()
-    test(model, device, test_loader)
-    # top1 = 93.55%
+    if args.parallel:
+        if torch.cuda.device_count() > 1:
+            print("use {} gpus for pruning".format(torch.cuda.device_count()))
+            model = nn.DataParallel(model)
+            # model = nn.DataParallel(model, device_ids=[0, 1])
+        else:
+            print("only detect 1 gpu, fall back")
+    model.to(device)
     # Fine tune the pruned model for 40 epochs and test accuracy
     print('=' * 10 + 'Fine tuning' + '=' * 10)
     optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
...
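The body of `updateBN` is collapsed in this hunk. In the network-slimming recipe this example follows, it adds an L1 subgradient on the BatchNorm scale factors between `loss.backward()` and `optimizer.step()`, so unimportant channels shrink toward zero during training and SlimPruner can threshold them. A sketch of the usual implementation (the coefficient 0.0001 is an assumption, not taken from this diff):

import torch
import torch.nn as nn

def updateBN(model, s=0.0001):
    # L1 subgradient s * sign(gamma) pushes BN scale factors toward zero
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.weight.grad.data.add_(s * torch.sign(m.weight.data))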
@@ -32,7 +32,7 @@ class ActivationRankFilterPruner(Pruner):
         """
         super().__init__(model, config_list)
-        self.mask_calculated_ops = set()
+        self.register_buffer("if_calculated", torch.tensor(False)) # pylint: disable=not-callable
         self.statistics_batch_num = statistics_batch_num
         self.collected_activation = {}
         self.hooks = {}
@@ -63,7 +63,7 @@ class ActivationRankFilterPruner(Pruner):
     def get_mask(self, base_mask, activations, num_prune):
         raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Calculate the mask of given layer.
         Filters with the smallest importance criterion which is calculated from the activation are masked.
@@ -82,14 +82,13 @@ class ActivationRankFilterPruner(Pruner):
         """
         weight = layer.module.weight.data
-        op_name = layer.name
         op_type = layer.type
         assert 0 <= config.get('sparsity') < 1, "sparsity must in the range [0, 1)"
         assert op_type in ['Conv2d'], "only support Conv2d"
         assert op_type in config.get('op_types')
-        if op_name in self.mask_calculated_ops:
-            assert op_name in self.mask_dict
-            return self.mask_dict.get(op_name)
+        if_calculated = kwargs["if_calculated"]
+        if if_calculated:
+            return None
         mask_weight = torch.ones(weight.size()).type_as(weight).detach()
         if hasattr(layer.module, 'bias') and layer.module.bias is not None:
             mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
@@ -104,8 +103,7 @@ class ActivationRankFilterPruner(Pruner):
                 mask = self.get_mask(mask, self.collected_activation[layer.name], num_prune)
         finally:
             if len(self.collected_activation[layer.name]) == self.statistics_batch_num:
-                self.mask_dict.update({op_name: mask})
-                self.mask_calculated_ops.add(op_name)
+                if_calculated.copy_(torch.tensor(True)) # pylint: disable=not-callable
         return mask
...
@@ -14,8 +14,11 @@ class LayerInfo:
         self.name = name
         self.type = type(module).__name__

-        self._forward = None
+def _setattr(model, name, module):
+    name_list = name.split(".")
+    for name in name_list[:-1]:
+        model = getattr(model, name)
+    setattr(model, name_list[-1], module)
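The helper above exists because a plain `setattr(self.bound_model, 'features.3', wrapper)` would create an attribute literally named "features.3" instead of replacing the nested submodule; walking the dotted path fixes that. A quick illustration with a hypothetical nested model:

import torch.nn as nn

model = nn.Sequential(nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()))  # submodules "0.0" and "0.1"
_setattr(model, "0.1", nn.Identity())  # replaces model[0][1] rather than adding an attribute "0.1"
assert isinstance(model[0][1], nn.Identity)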
 class Compressor:
     """
@@ -36,6 +39,8 @@ class Compressor:
         self.bound_model = model
         self.config_list = config_list
         self.modules_to_compress = None
+        self.modules_wrapper = None
+        self.buffers = {}

     def detect_modules_to_compress(self):
         """
@@ -51,21 +56,58 @@ class Compressor:
                 self.modules_to_compress.append((layer, config))
         return self.modules_to_compress

+    def _wrap_model(self):
+        """
+        Wrap all modules that need to be compressed.
+        """
+        for wrapper in reversed(self.get_modules_wrapper()):
+            _setattr(self.bound_model, wrapper.name, wrapper)
+
+    def _unwrap_model(self):
+        """
+        Unwrap all modules that were wrapped for compression.
+        """
+        for wrapper in self.get_modules_wrapper():
+            _setattr(self.bound_model, wrapper.name, wrapper.module)
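This wrap/unwrap pair is what lets `export_model` (further down) save a clean `state_dict`: temporarily put the original modules back, save, then re-wrap. A sketch of that round trip, assuming `pruner` is any `Pruner` whose `compress()` has already run:

pruner._unwrap_model()                   # original modules back in place
state = pruner.bound_model.state_dict()  # keys free of wrapper-only entries
pruner._wrap_model()                     # restore wrappers before further training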
     def compress(self):
         """
         Compress the model with algorithm implemented by subclass.
         The model will be instrumented and user should never edit it after calling this method.
         `self.modules_to_compress` records all the to-be-compressed layers.

+        Returns
+        -------
+        torch.nn.Module
+            model with specified modules compressed.
         """
+        if self.modules_wrapper is not None:
+            # already compressed
+            return
+        else:
+            self.modules_wrapper = []
         modules_to_compress = self.detect_modules_to_compress()
         for layer, config in modules_to_compress:
-            self._instrument_layer(layer, config)
+            wrapper = self._wrap_modules(layer, config)
+            self.modules_wrapper.append(wrapper)
+        self._wrap_model()
         return self.bound_model
+    def register_buffer(self, name, value):
+        """
+        To register buffers used in wrapped module's forward method.
+        """
+        self.buffers[name] = value
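This is how the pruners below share per-layer state with their wrappers: every value registered here is cloned into each module wrapper as a real PyTorch buffer, so `nn.DataParallel` replicates it to every GPU and it survives `state_dict` round trips. For example, the updated pruners register their `if_calculated` flag like this (taken from the `SlimPruner` change below):

# in a Pruner subclass __init__; each wrapper later receives its own clone
self.register_buffer("if_calculated", torch.tensor(False))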
     def get_modules_to_compress(self):
         """
-        To obtain all the to-be-compressed layers.
+        To obtain all the to-be-compressed modules.

         Returns
         -------
@@ -75,6 +117,17 @@ class Compressor:
         """
         return self.modules_to_compress

+    def get_modules_wrapper(self):
+        """
+        To obtain all the wrapped modules.
+
+        Returns
+        -------
+        list
+            a list of the wrapped modules
+        """
+        return self.modules_wrapper
+
     def select_config(self, layer):
         """
         Find the configuration for `layer` by parsing `self.config_list`
@@ -119,7 +172,7 @@ class Compressor:
         If users want to update the model every step, they can override this method.
         """

-    def _instrument_layer(self, layer, config):
+    def _wrap_modules(self, layer, config):
         """
         This method is implemented in the subclasses, i.e., `Pruner` and `Quantizer`
@@ -143,6 +196,57 @@ class Compressor:
                 expanded_op_types.append(op_type)
         return expanded_op_types
+class PrunerModuleWrapper(torch.nn.Module):
+    def __init__(self, module, module_name, module_type, config, pruner):
+        """
+        Wrap a module to enable data parallel, forward method customization and buffer registration.
+
+        Parameters
+        ----------
+        module : pytorch module
+            the module user wants to compress
+        config : dict
+            the configurations that users specify for compression
+        module_name : str
+            the name of the module to compress, wrapper module shares same name
+        module_type : str
+            the type of the module to compress
+        pruner : Pruner
+            the pruner used to calculate mask
+        """
+        super().__init__()
+        # origin layer information
+        self.module = module
+        self.name = module_name
+        self.type = module_type
+        # config and pruner
+        self.config = config
+        self.pruner = pruner
+        # register buffer for mask
+        self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
+        if hasattr(self.module, 'bias') and self.module.bias is not None:
+            self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
+        else:
+            self.register_buffer("bias_mask", None)
+        # register user specified buffers
+        self.registered_buffers = {}
+        for name in self.pruner.buffers:
+            self.register_buffer(name, self.pruner.buffers[name].clone())
+            self.registered_buffers[name] = getattr(self, name)
+
+    def forward(self, *inputs):
+        mask = self.pruner.calc_mask(LayerInfo(self.name, self.module), self.config, **self.registered_buffers)
+        if mask is not None:
+            self.weight_mask.copy_(mask['weight'])
+        # apply mask to weight
+        self.module.weight.data = self.module.weight.data.mul_(self.weight_mask)
+        # apply mask to bias
+        if hasattr(self.module, 'bias') and self.module.bias is not None:
+            if mask is not None:
+                self.bias_mask.copy_(mask['bias'])
+            self.module.bias.data = self.module.bias.data.mul_(self.bias_mask)
+        return self.module(*inputs)
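Because `weight_mask`, `bias_mask`, and the pruner-registered flags are stored via `torch.nn.Module.register_buffer` rather than as plain Python attributes, `nn.DataParallel` copies them onto every GPU together with the parameters, and they show up in `state_dict`. A small standalone check of that property (the conv layer is illustrative):

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3)
# stand-in for PrunerModuleWrapper's mask registration
conv.register_buffer("weight_mask", torch.ones(conv.weight.shape))
print("weight_mask" in conv.state_dict())  # True: buffers travel with the module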
 class Pruner(Compressor):
     """
@@ -158,7 +262,6 @@ class Pruner(Compressor):
     def __init__(self, model, config_list):
         super().__init__(model, config_list)
-        self.mask_dict = {}

     def calc_mask(self, layer, config):
         """
@@ -176,9 +279,9 @@ class Pruner(Compressor):
         """
         raise NotImplementedError("Pruners must overload calc_mask()")

-    def _instrument_layer(self, layer, config):
+    def _wrap_modules(self, layer, config):
         """
-        Create a wrapper forward function to replace the original one.
+        Create a wrapper module to replace the original one.

         Parameters
         ----------
@@ -187,28 +290,8 @@ class Pruner(Compressor):
         config : dict
             the configuration for generating the mask
         """
-        assert layer._forward is None, 'Each model can only be compressed once'
-        if not _check_weight(layer.module):
-            _logger.warning('Module %s does not have parameter "weight"', layer.name)
-            return
-        layer._forward = layer.module.forward
-
-        def new_forward(*inputs):
-            mask = self.calc_mask(layer, config)
-            # apply mask to weight
-            old_weight = layer.module.weight.data
-            mask_weight = mask['weight']
-            layer.module.weight.data = old_weight.mul(mask_weight)
-            # apply mask to bias
-            if mask.__contains__('bias') and hasattr(layer.module, 'bias') and layer.module.bias is not None:
-                old_bias = layer.module.bias.data
-                mask_bias = mask['bias']
-                layer.module.bias.data = old_bias.mul(mask_bias)
-            # calculate forward
-            ret = layer._forward(*inputs)
-            return ret
-
-        layer.module.forward = new_forward
+        _logger.info("compressing module %s.", layer.name)
+        return PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)

     def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None):
         """
@@ -225,26 +308,29 @@ class Pruner(Compressor):
         input_shape : list or tuple
             input shape to onnx model
         """
-        if self.detect_modules_to_compress() and not self.mask_dict:
-            _logger.warning('You may not use self.mask_dict in base Pruner class to record masks')
+        # if self.detect_modules_to_compress() and not self.mask_dict:
+        #     _logger.warning('You may not use self.mask_dict in base Pruner class to record masks')
         assert model_path is not None, 'model_path must be specified'
-        for name, m in self.bound_model.named_modules():
-            if name == "":
-                continue
-            masks = self.mask_dict.get(name)
-            if masks is not None:
-                mask_sum = masks['weight'].sum().item()
-                mask_num = masks['weight'].numel()
-                _logger.info('Layer: %s Sparsity: %.2f', name, 1 - mask_sum / mask_num)
-                m.weight.data = m.weight.data.mul(masks['weight'])
-                if masks.__contains__('bias') and hasattr(m, 'bias') and m.bias is not None:
-                    m.bias.data = m.bias.data.mul(masks['bias'])
-            else:
-                _logger.info('Layer: %s NOT compressed', name)
+        mask_dict = {}
+        self._unwrap_model()  # used for generating correct state_dict name without wrapper state
+        for wrapper in self.get_modules_wrapper():
+            weight_mask = wrapper.weight_mask
+            bias_mask = wrapper.bias_mask
+            if weight_mask is not None:
+                mask_sum = weight_mask.sum().item()
+                mask_num = weight_mask.numel()
+                _logger.info('Layer: %s Sparsity: %.2f', wrapper.name, 1 - mask_sum / mask_num)
+                wrapper.module.weight.data = wrapper.module.weight.data.mul(weight_mask)
+            if bias_mask is not None:
+                wrapper.module.bias.data = wrapper.module.bias.data.mul(bias_mask)
+            # save mask to dict
+            mask_dict[wrapper.name] = {"weight": weight_mask, "bias": bias_mask}
         torch.save(self.bound_model.state_dict(), model_path)
         _logger.info('Model state_dict saved to %s', model_path)
         if mask_path is not None:
-            torch.save(self.mask_dict, mask_path)
+            torch.save(mask_dict, mask_path)
             _logger.info('Mask dict saved to %s', mask_path)
         if onnx_path is not None:
             assert input_shape is not None, 'input_shape must be specified to export onnx model'
@@ -253,6 +339,86 @@ class Pruner(Compressor):
             torch.onnx.export(self.bound_model, input_data, onnx_path)
             _logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)
+        self._wrap_model()
+class QuantizerModuleWrapper(torch.nn.Module):
+    def __init__(self, module, module_name, module_type, config, quantizer):
+        """
+        Wrap a module to enable data parallel, forward method customization and buffer registration.
+
+        Parameters
+        ----------
+        module : pytorch module
+            the module user wants to compress
+        config : dict
+            the configurations that users specify for compression
+        module_name : str
+            the name of the module to compress, wrapper module shares same name
+        module_type : str
+            the type of the module to compress
+        quantizer : Quantizer
+            the quantizer used to quantize the module
+        """
+        super().__init__()
+        # origin layer information
+        self.module = module
+        self.name = module_name
+        self.type = module_type
+        # config and quantizer
+        self.config = config
+        self.quantizer = quantizer
+        # register buffer and parameter
+        # old_weight stores the origin weight and weight stores the quantized weight;
+        # weight is a buffer instead of a parameter because in pytorch a parameter is a leaf node,
+        # and if weight were a leaf, old_weight could not be updated.
+        if 'weight' in config['quant_types']:
+            if not _check_weight(self.module):
+                _logger.warning('Module %s does not have parameter "weight"', self.name)
+            else:
+                self.module.register_parameter('old_weight', torch.nn.Parameter(self.module.weight))
+                delattr(self.module, 'weight')
+                self.module.register_buffer('weight', self.module.old_weight)
+        # register user specified buffers
+        self.registered_buffers = {}
+        for name in self.quantizer.buffers:
+            self.register_buffer(name, self.quantizer.buffers[name].clone())
+            self.registered_buffers[name] = getattr(self, name)
+
+    def forward(self, *inputs):
+        if 'input' in self.config['quant_types']:
+            inputs = self.quantizer.quant_grad.apply(
+                inputs,
+                QuantType.QUANT_INPUT,
+                self.quantizer.quantize_input,
+                self.config,
+                LayerInfo(self.name, self.module),
+                **self.registered_buffers)
+        if 'weight' in self.config['quant_types'] and _check_weight(self.module):
+            new_weight = self.quantizer.quant_grad.apply(
+                self.module.old_weight,
+                QuantType.QUANT_WEIGHT,
+                self.quantizer.quantize_weight,
+                self.config,
+                LayerInfo(self.name, self.module),
+                **self.registered_buffers)
+            self.module.weight = new_weight
+            result = self.module(*inputs)
+        else:
+            result = self.module(*inputs)
+        if 'output' in self.config['quant_types']:
+            result = self.quantizer.quant_grad.apply(
+                result,
+                QuantType.QUANT_OUTPUT,
+                self.quantizer.quantize_output,
+                self.config,
+                LayerInfo(self.name, self.module),
+                **self.registered_buffers)
+        return result
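For reference, the wrapper's behavior is driven by a config of the shape the asserts in `Quantizer._wrap_modules` below expect: a `quant_types` list plus a `quant_bits` entry per type. A plausible example (bit widths and op_types are illustrative, not from this diff):

configure_list = [{
    'quant_types': ['weight', 'output'],       # which tensors to quantize
    'quant_bits': {'weight': 8, 'output': 8},  # bit width per quant type
    'op_types': ['Conv2d', 'Linear']
}]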
 class Quantizer(Compressor):
     """
@@ -303,7 +469,7 @@ class Quantizer(Compressor):
         raise NotImplementedError('Quantizer must overload quantize_input()')

-    def _instrument_layer(self, layer, config):
+    def _wrap_modules(self, layer, config):
         """
         Create a wrapper forward function to replace the original one.

         Parameters
@@ -313,7 +479,6 @@ class Quantizer(Compressor):
         config : dict
             the configuration for quantization
         """
-        assert layer._forward is None, 'Each model can only be compressed once'
         assert 'quant_types' in config, 'must provide quant_types in config'
         assert isinstance(config['quant_types'], list), 'quant_types must be list type'
         assert 'quant_bits' in config, 'must provide quant_bits in config'
@@ -323,35 +488,7 @@ class Quantizer(Compressor):
         for quant_type in config['quant_types']:
             assert quant_type in config['quant_bits'], 'bits length for %s must be specified in quant_bits dict' % quant_type

-        if 'weight' in config['quant_types']:
-            if not _check_weight(layer.module):
-                _logger.warning('Module %s does not have parameter "weight"', layer.name)
-            else:
-                # old_weight is used to store origin weight and weight is used to store quantized weight
-                # the reason why weight is buffer instead of parameter is because in pytorch parameter is used as leaf
-                # if weight is leaf , then old_weight can not be updated.
-                layer.module.register_parameter('old_weight', torch.nn.Parameter(layer.module.weight))
-                delattr(layer.module, 'weight')
-                layer.module.register_buffer('weight', layer.module.old_weight)
-        layer._forward = layer.module.forward
-
-        def new_forward(*inputs):
-            if 'input' in config['quant_types']:
-                inputs = self.quant_grad.apply(inputs, QuantType.QUANT_INPUT, self.quantize_input, config, layer)
-            if 'weight' in config['quant_types'] and _check_weight(layer.module):
-                new_weight = self.quant_grad.apply(layer.module.old_weight, QuantType.QUANT_WEIGHT, self.quantize_weight, config, layer)
-                layer.module.weight = new_weight
-                result = layer._forward(*inputs)
-            else:
-                result = layer._forward(*inputs)
-            if 'output' in config['quant_types']:
-                result = self.quant_grad.apply(result, QuantType.QUANT_OUTPUT, self.quantize_output, config, layer)
-            return result
-
-        layer.module.forward = new_forward
+        return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self)
 class QuantType:
     """
@@ -387,15 +524,15 @@ class QuantGrad(torch.autograd.Function):
         return grad_output

     @staticmethod
-    def forward(ctx, tensor, quant_type, quant_func, config, layer):
+    def forward(ctx, tensor, quant_type, quant_func, config, layer, **kwargs):
         ctx.save_for_backward(tensor, torch.Tensor([quant_type]))
-        return quant_func(tensor, config, op=layer.module, op_type=layer.type, op_name=layer.name)
+        return quant_func(tensor, config, op=layer.module, op_type=layer.type, op_name=layer.name, **kwargs)

     @classmethod
     def backward(cls, ctx, grad_output):
         tensor, quant_type = ctx.saved_variables
         output = cls.quant_backward(tensor, grad_output, quant_type)
-        return output, None, None, None, None
+        return output, None, None, None, None, None

 def _check_weight(module):
     try:
...
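`QuantGrad` is the straight-through-estimator hook: `forward` applies the user's `quant_func`, while `quant_backward` (its body is collapsed above, ending in `return grad_output`) passes the gradient through unchanged by default. A subclass only needs to override `quant_backward` for a different estimator; a hypothetical sketch matching the `cls.quant_backward(tensor, grad_output, quant_type)` call in `backward`:

class ClippedQuantGrad(QuantGrad):
    @staticmethod
    def quant_backward(tensor, grad_output, quant_type):
        # pass gradients only where the pre-quantization value was in [-1, 1]
        return grad_output * (tensor.abs() <= 1).type_as(grad_output)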
@@ -187,7 +187,6 @@ class SlimPruner(Pruner):
         """
         super().__init__(model, config_list)
-        self.mask_calculated_ops = set()
         weight_list = []
         if len(config_list) > 1:
             logger.warning('Slim pruner only supports 1 configuration')
@@ -198,8 +197,9 @@ class SlimPruner(Pruner):
         all_bn_weights = torch.cat(weight_list)
         k = int(all_bn_weights.shape[0] * config['sparsity'])
         self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
+        self.register_buffer("if_calculated", torch.tensor(False)) # pylint: disable=not-callable

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Calculate the mask of given layer.
         Scale factors with the smallest absolute value in the BN layer are masked.
@@ -209,6 +209,8 @@ class SlimPruner(Pruner):
             the layer to instrument the compression operation
         config : dict
             layer's pruning config
+        kwargs: dict
+            buffers registered in __init__ function
         Returns
         -------
         dict
@@ -216,27 +218,21 @@ class SlimPruner(Pruner):
         """
         weight = layer.module.weight.data
-        op_name = layer.name
         op_type = layer.type
+        if_calculated = kwargs["if_calculated"]
         assert op_type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
-        if op_name in self.mask_calculated_ops:
-            assert op_name in self.mask_dict
-            return self.mask_dict.get(op_name)
+        if if_calculated:
+            return None
         base_mask = torch.ones(weight.size()).type_as(weight).detach()
         mask = {'weight': base_mask.detach(), 'bias': base_mask.clone().detach()}
-        try:
-            filters = weight.size(0)
-            num_prune = int(filters * config.get('sparsity'))
-            if filters < 2 or num_prune < 1:
-                return mask
+        filters = weight.size(0)
+        num_prune = int(filters * config.get('sparsity'))
+        if filters >= 2 and num_prune >= 1:
             w_abs = weight.abs()
             mask_weight = torch.gt(w_abs, self.global_threshold).type_as(weight)
             mask_bias = mask_weight.clone()
             mask = {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
-        finally:
-            self.mask_dict.update({layer.name: mask})
-            self.mask_calculated_ops.add(layer.name)
+        if_calculated.copy_(torch.tensor(True)) # pylint: disable=not-callable
         return mask

 class LotteryTicketPruner(Pruner):
...
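`global_threshold` in `__init__` above is a single cutoff across all BatchNorm layers: concatenate every BN scale vector, take the k smallest values for the requested sparsity, and use the largest of them as the threshold each layer's mask is compared against. A tiny worked example (values made up; the scale factors are already magnitudes):

import torch

weight_list = [torch.tensor([0.9, 0.1, 0.5]), torch.tensor([0.05, 0.7, 0.3])]  # two BN layers
all_bn_weights = torch.cat(weight_list)  # 6 scale factors in total
k = int(all_bn_weights.shape[0] * 0.5)   # sparsity 0.5 -> k = 3
global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
print(global_threshold)  # tensor(0.3000); masks keep factors strictly above 0.3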
@@ -27,12 +27,12 @@ class WeightRankFilterPruner(Pruner):
         """
         super().__init__(model, config_list)
-        self.mask_calculated_ops = set()  # operations whose mask has been calculated
+        self.register_buffer("if_calculated", torch.tensor(False)) # pylint: disable=not-callable

     def get_mask(self, base_mask, weight, num_prune):
         raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Calculate the mask of given layer.
         Filters with the smallest importance criterion of the kernel weights are masked.
@@ -49,14 +49,13 @@ class WeightRankFilterPruner(Pruner):
         """
         weight = layer.module.weight.data
-        op_name = layer.name
         op_type = layer.type
         assert 0 <= config.get('sparsity') < 1, "sparsity must in the range [0, 1)"
         assert op_type in ['Conv1d', 'Conv2d'], "only support Conv1d and Conv2d"
         assert op_type in config.get('op_types')
-        if op_name in self.mask_calculated_ops:
-            assert op_name in self.mask_dict
-            return self.mask_dict.get(op_name)
+        if_calculated = kwargs["if_calculated"]
+        if if_calculated:
+            return None
         mask_weight = torch.ones(weight.size()).type_as(weight).detach()
         if hasattr(layer.module, 'bias') and layer.module.bias is not None:
             mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
@@ -70,8 +69,7 @@ class WeightRankFilterPruner(Pruner):
                 return mask
             mask = self.get_mask(mask, weight, num_prune)
         finally:
-            self.mask_dict.update({op_name: mask})
-            self.mask_calculated_ops.add(op_name)
+            if_calculated.copy_(torch.tensor(True)) # pylint: disable=not-callable
         return mask
...