"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "1546962f83397710fe095538d052dc74bd981707"
Commit 2de52a89 authored by Cjkkkk, committed by chicm-ms

add data parallel proposal (#1923)

* add data parallel proposal

* fix mask_weight bug

* add slim pruner support and example

* fix typo

* fix typo

* fix setattr error

* fix buffer update

* rename instrument_layer and prunerLayerWrapper

* fix pylint

* update reverse traversal

* add wrap and unwrap

* add register_buffer API

* update docstring

* update docstring

* add quantizer support

* fix typo

* update MeanActivationPruner, weight_rank_filter_pruner and example
parent 13d03757
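In short, this change replaces the old monkey-patching of each layer's forward function with real nn.Module wrappers whose masks live in registered buffers, so pruned models survive nn.DataParallel replication. A minimal usage sketch of the intended workflow (MyModel is a hypothetical user model; the pruner API follows the examples below):

import torch
import torch.nn as nn
from nni.compression.torch import L1FilterPruner

model = MyModel()  # hypothetical user model
configure_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
pruner = L1FilterPruner(model, configure_list)
model = pruner.compress()              # target layers are now PrunerModuleWrapper modules
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)     # wrappers and their mask buffers replicate per GPU
model.to(torch.device('cuda'))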
import math
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -40,6 +41,12 @@ def test(model, device, test_loader):
def main():
parser = argparse.ArgumentParser("multiple gpu with pruning")
parser.add_argument("--epochs", type=int, default=160)
parser.add_argument("--retrain", default=False, action="store_true")
parser.add_argument("--parallel", default=False, action="store_true")
args = parser.parse_args()
torch.manual_seed(0)
device = torch.device('cuda')
train_loader = torch.utils.data.DataLoader(
@@ -63,14 +70,15 @@ def main():
model.to(device)
# Train the base VGG-16 model
print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 160, 0)
for epoch in range(160):
train(model, device, train_loader, optimizer)
test(model, device, test_loader)
lr_scheduler.step(epoch)
torch.save(model.state_dict(), 'vgg16_cifar10.pth')
if args.retrain:
print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 160, 0)
for epoch in range(args.epochs):
train(model, device, train_loader, optimizer)
test(model, device, test_loader)
lr_scheduler.step(epoch)
torch.save(model.state_dict(), 'vgg16_cifar10.pth')
# Test base model accuracy
print('=' * 10 + 'Test on the original model' + '=' * 10)
@@ -90,6 +98,14 @@ def main():
print('=' * 10 + 'Test on the pruned model before fine tune' + '=' * 10)
pruner = L1FilterPruner(model, configure_list)
model = pruner.compress()
if args.parallel:
if torch.cuda.device_count() > 1:
print("use {} gpus for pruning".format(torch.cuda.device_count()))
model = nn.DataParallel(model)
else:
print("only detect 1 gpu, fall back")
model.to(device)
test(model, device, test_loader)
# top1 = 88.19%
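Fine-tuning the pruned model then proceeds as ordinary training, since each wrapper re-applies its mask on every forward pass. A minimal sketch of the loop that typically follows, reusing this file's train/test helpers (the epoch budget is illustrative):

optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
for epoch in range(40):  # illustrative fine-tune budget
    train(model, device, train_loader, optimizer_finetune)
    test(model, device, test_loader)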
......
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from nni.compression.torch import SlimPruner
class fc1(nn.Module):
def __init__(self, num_classes=10):
super(fc1, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2d(32)
self.relu1 = nn.ReLU(inplace=True)
self.linear1 = nn.Linear(32*28*28, 300)
self.relu2 = nn.ReLU(inplace=True)
self.linear2 = nn.Linear(300, 100)
self.relu3 = nn.ReLU(inplace=True)
self.linear3 = nn.Linear(100, num_classes)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu1(x)
x = torch.flatten(x,1)
x = self.linear1(x)
x = self.relu2(x)
x = self.linear2(x)
x = self.relu3(x)
x = self.linear3(x)
return x
def train(model, train_loader, optimizer, criterion, device):
model.train()
for imgs, targets in train_loader:
optimizer.zero_grad()
imgs, targets = imgs.to(device), targets.to(device)
output = model(imgs)
train_loss = criterion(output, targets)
train_loss.backward()
optimizer.step()
return train_loss.item()
def test(model, test_loader, criterion, device):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.cross_entropy(output, target, reduction='sum').item() # sum up batch loss; the model outputs logits, so use cross_entropy rather than nll_loss
pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
correct += pred.eq(target.data.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
accuracy = 100. * correct / len(test_loader.dataset)
return accuracy
if __name__ == '__main__':
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
traindataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
testdataset = datasets.MNIST('./data', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(traindataset, batch_size=60, shuffle=True, num_workers=10, drop_last=False)
test_loader = torch.utils.data.DataLoader(testdataset, batch_size=60, shuffle=False, num_workers=10, drop_last=True)
device = torch.device("cuda: 0" if torch.cuda.is_available() else "cpu")
model = fc1()
criterion = nn.CrossEntropyLoss()
configure_list = [{
'prune_iterations': 5,
'sparsity': 0.86,
'op_types': ['BatchNorm2d']
}]
pruner = SlimPruner(model, configure_list)
pruner.compress()
if torch.cuda.device_count() > 1:
model = nn.DataParallel(model)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1.2e-3)
for name, par in model.named_parameters():
print(name)
# for i in pruner.get_prune_iterations():
# pruner.prune_iteration_start()
loss = 0
accuracy = 0
for epoch in range(10):
loss = train(model, train_loader, optimizer, criterion, device)
accuracy = test(model, test_loader, criterion, device)
print('current epoch: {0}, loss: {1}, accuracy: {2}'.format(epoch, loss, accuracy))
# print('prune iteration: {0}, loss: {1}, accuracy: {2}'.format(i, loss, accuracy))
pruner.export_model('model.pth', 'mask.pth')
\ No newline at end of file
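For reference, the file written through export_model's mask_path is a plain dict mapping layer names to their weight/bias masks (see the compressor changes below), so the achieved sparsity can be checked afterwards with a short, hypothetical snippet:

masks = torch.load('mask.pth')
for name, m in masks.items():
    sparsity = 1 - m['weight'].sum().item() / m['weight'].numel()
    print('{}: sparsity {:.2f}'.format(name, sparsity))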
import math
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -6,7 +7,6 @@ from torchvision import datasets, transforms
from nni.compression.torch import SlimPruner
from models.cifar10.vgg import VGG
def updateBN(model):
for m in model.modules():
if isinstance(m, nn.BatchNorm2d):
@@ -49,6 +49,13 @@ def test(model, device, test_loader):
def main():
parser = argparse.ArgumentParser("multiple gpu with pruning")
parser.add_argument("--epochs", type=int, default=160)
parser.add_argument("--retrain", default=False, action="store_true")
parser.add_argument("--parallel", default=False, action="store_true")
args = parser.parse_args()
torch.manual_seed(0)
device = torch.device('cuda')
train_loader = torch.utils.data.DataLoader(
@@ -70,18 +77,19 @@ def main():
model = VGG(depth=19)
model.to(device)
# Train the base VGG-19 model
print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
epochs = 160
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
for epoch in range(epochs):
if epoch in [epochs * 0.5, epochs * 0.75]:
for param_group in optimizer.param_groups:
param_group['lr'] *= 0.1
train(model, device, train_loader, optimizer, True)
test(model, device, test_loader)
torch.save(model.state_dict(), 'vgg19_cifar10.pth')
if args.retrain:
print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
epochs = args.epochs
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
for epoch in range(epochs):
if epoch in [epochs * 0.5, epochs * 0.75]:
for param_group in optimizer.param_groups:
param_group['lr'] *= 0.1
print("epoch {}".format(epoch))
train(model, device, train_loader, optimizer, True)
test(model, device, test_loader)
torch.save(model.state_dict(), 'vgg19_cifar10.pth')
# Test base model accuracy
print('=' * 10 + 'Test the original model' + '=' * 10)
@@ -94,14 +102,19 @@ def main():
'sparsity': 0.7,
'op_types': ['BatchNorm2d'],
}]
# Prune model and test accuracy without fine tuning.
print('=' * 10 + 'Test the pruned model before fine tune' + '=' * 10)
pruner = SlimPruner(model, configure_list)
model = pruner.compress()
test(model, device, test_loader)
# top1 = 93.55%
if args.parallel:
if torch.cuda.device_count() > 1:
print("use {} gpus for pruning".format(torch.cuda.device_count()))
model = nn.DataParallel(model)
# model = nn.DataParallel(model, device_ids=[0, 1])
else:
print("only detect 1 gpu, fall back")
model.to(device)
# Fine tune the pruned model for 40 epochs and test accuracy
print('=' * 10 + 'Fine tuning' + '=' * 10)
optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
......
@@ -32,7 +32,7 @@ class ActivationRankFilterPruner(Pruner):
"""
super().__init__(model, config_list)
self.mask_calculated_ops = set()
self.register_buffer("if_calculated", torch.tensor(False)) # pylint: disable=not-callable
self.statistics_batch_num = statistics_batch_num
self.collected_activation = {}
self.hooks = {}
@@ -63,7 +63,7 @@ class ActivationRankFilterPruner(Pruner):
def get_mask(self, base_mask, activations, num_prune):
raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))
def calc_mask(self, layer, config):
def calc_mask(self, layer, config, **kwargs):
"""
Calculate the mask of given layer.
Filters with the smallest importance criterion which is calculated from the activation are masked.
@@ -82,14 +82,13 @@ class ActivationRankFilterPruner(Pruner):
"""
weight = layer.module.weight.data
op_name = layer.name
op_type = layer.type
assert 0 <= config.get('sparsity') < 1, "sparsity must be in the range [0, 1)"
assert op_type in ['Conv2d'], "only support Conv2d"
assert op_type in config.get('op_types')
if op_name in self.mask_calculated_ops:
assert op_name in self.mask_dict
return self.mask_dict.get(op_name)
if_calculated = kwargs["if_calculated"]
if if_calculated:
return None
mask_weight = torch.ones(weight.size()).type_as(weight).detach()
if hasattr(layer.module, 'bias') and layer.module.bias is not None:
mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
@@ -104,8 +103,7 @@ class ActivationRankFilterPruner(Pruner):
mask = self.get_mask(mask, self.collected_activation[layer.name], num_prune)
finally:
if len(self.collected_activation[layer.name]) == self.statistics_batch_num:
self.mask_dict.update({op_name: mask})
self.mask_calculated_ops.add(op_name)
if_calculated.copy_(torch.tensor(True)) # pylint: disable=not-callable
return mask
......
@@ -14,8 +14,11 @@ class LayerInfo:
self.name = name
self.type = type(module).__name__
self._forward = None
def _setattr(model, name, module):
name_list = name.split(".")
for name in name_list[:-1]:
model = getattr(model, name)
setattr(model, name_list[-1], module)
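_setattr walks a dotted module path (as produced by named_modules) and replaces the final attribute, which is how wrappers are swapped in and out of arbitrarily nested models. A small, self-contained illustration (the model is made up):

import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())  # illustrative model
_setattr(model, "0", nn.Identity())                   # equivalent to model[0] = nn.Identity()
assert isinstance(model[0], nn.Identity)
# nested names work the same way, e.g. _setattr(model, "features.0", wrapper)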
class Compressor:
"""
@@ -36,6 +39,8 @@ class Compressor:
self.bound_model = model
self.config_list = config_list
self.modules_to_compress = None
self.modules_wrapper = None
self.buffers = {}
def detect_modules_to_compress(self):
"""
@@ -51,21 +56,58 @@ class Compressor:
self.modules_to_compress.append((layer, config))
return self.modules_to_compress
def _wrap_model(self):
"""
Wrap all modules that need to be compressed.
"""
for wrapper in reversed(self.get_modules_wrapper()):
_setattr(self.bound_model, wrapper.name, wrapper)
def _unwrap_model(self):
"""
Unwrap all modules that were wrapped for compression.
"""
for wrapper in self.get_modules_wrapper():
_setattr(self.bound_model, wrapper.name, wrapper.module)
def compress(self):
"""
Compress the model with algorithm implemented by subclass.
The model will be instrumented and the user should never edit it after calling this method.
`self.modules_to_compress` records all the to-be-compressed layers
Returns
-------
torch.nn.Module
model with specified modules compressed.
"""
if self.modules_wrapper is not None:
# already compressed, return the wrapped model directly
return self.bound_model
else:
self.modules_wrapper = []
modules_to_compress = self.detect_modules_to_compress()
for layer, config in modules_to_compress:
self._instrument_layer(layer, config)
wrapper = self._wrap_modules(layer, config)
self.modules_wrapper.append(wrapper)
self._wrap_model()
return self.bound_model
def register_buffer(self, name, value):
"""
Register a buffer to be used in the wrapped modules' forward methods; each wrapper receives its own clone of the value.
"""
self.buffers[name] = value
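Each buffer registered here is cloned into every wrapper and handed back to calc_mask through **kwargs; the per-wrapper copy of, e.g., if_calculated replaces the old shared mask_calculated_ops set. A sketch of the pattern a pruner subclass follows (MyPruner is hypothetical, mirroring SlimPruner below):

class MyPruner(Pruner):  # hypothetical subclass
    def __init__(self, model, config_list):
        super().__init__(model, config_list)
        self.register_buffer("if_calculated", torch.tensor(False))

    def calc_mask(self, layer, config, **kwargs):
        if_calculated = kwargs["if_calculated"]   # this layer's own buffer copy
        if if_calculated:
            return None                           # mask already computed once
        mask = {'weight': torch.ones_like(layer.module.weight.data)}  # real pruners compute this
        if_calculated.copy_(torch.tensor(True))
        return mask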
def get_modules_to_compress(self):
"""
To obtain all the to-be-compressed layers.
To obtain all the to-be-compressed modules.
Returns
-------
@@ -75,6 +117,17 @@ class Compressor:
"""
return self.modules_to_compress
def get_modules_wrapper(self):
"""
To obtain all the wrapped modules.
Returns
-------
list
a list of the wrapped modules
"""
return self.modules_wrapper
def select_config(self, layer):
"""
Find the configuration for `layer` by parsing `self.config_list`
@@ -119,7 +172,7 @@ class Compressor:
If users want to update the model every step, they can override this method
"""
def _instrument_layer(self, layer, config):
def _wrap_modules(self, layer, config):
"""
This method is implemented in the subclasses, i.e., `Pruner` and `Quantizer`
@@ -143,6 +196,57 @@ class Compressor:
expanded_op_types.append(op_type)
return expanded_op_types
class PrunerModuleWrapper(torch.nn.Module):
def __init__(self, module, module_name, module_type, config, pruner):
"""
Wrap a module to enable data parallelism, forward-method customization and buffer registration.
Parameters
----------
module : pytorch module
the module user wants to compress
config : dict
the configurations that users specify for compression
module_name : str
the name of the module to compress; the wrapper shares the same name
module_type : str
the type of the module to compress
pruner : Pruner
the pruner used to calculate mask
"""
super().__init__()
# origin layer information
self.module = module
self.name = module_name
self.type = module_type
# config and pruner
self.config = config
self.pruner = pruner
# register buffer for mask
self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
if hasattr(self.module, 'bias') and self.module.bias is not None:
self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
else:
self.register_buffer("bias_mask", None)
# register user specified buffer
self.registered_buffers = {}
for name in self.pruner.buffers:
self.register_buffer(name, self.pruner.buffers[name].clone())
self.registered_buffers[name] = getattr(self, name)
def forward(self, *inputs):
mask = self.pruner.calc_mask(LayerInfo(self.name, self.module), self.config, **self.registered_buffers)
if mask is not None:
self.weight_mask.copy_(mask['weight'])
# apply mask to weight
self.module.weight.data = self.module.weight.data.mul_(self.weight_mask)
# apply mask to bias
if hasattr(self.module, 'bias') and self.module.bias is not None:
if mask is not None:
self.bias_mask.copy_(mask['bias'])
self.module.bias.data = self.module.bias.data.mul_(self.bias_mask)
return self.module(*inputs)
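Because weight_mask and bias_mask are registered buffers of an nn.Module, nn.DataParallel's replicate step copies them to each GPU together with the wrapped module and they appear in state_dict; the previous closure-based new_forward captured tensors that did not move with the model. A quick, illustrative check (assumes pruner is an already-constructed Pruner instance):

import torch.nn as nn

wrapper = PrunerModuleWrapper(nn.Conv2d(3, 8, 3), 'conv1', 'Conv2d', config={}, pruner=pruner)
print([name for name, _ in wrapper.named_buffers()])  # includes 'weight_mask' and 'bias_mask'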
class Pruner(Compressor):
"""
@@ -158,7 +262,6 @@ class Pruner(Compressor):
def __init__(self, model, config_list):
super().__init__(model, config_list)
self.mask_dict = {}
def calc_mask(self, layer, config):
"""
@@ -176,9 +279,9 @@ class Pruner(Compressor):
"""
raise NotImplementedError("Pruners must overload calc_mask()")
def _instrument_layer(self, layer, config):
def _wrap_modules(self, layer, config):
"""
Create a wrapper forward function to replace the original one.
Create a wrapper module to replace the original one.
Parameters
----------
@@ -187,28 +290,8 @@ class Pruner(Compressor):
config : dict
the configuration for generating the mask
"""
assert layer._forward is None, 'Each model can only be compressed once'
if not _check_weight(layer.module):
_logger.warning('Module %s does not have parameter "weight"', layer.name)
return
layer._forward = layer.module.forward
def new_forward(*inputs):
mask = self.calc_mask(layer, config)
# apply mask to weight
old_weight = layer.module.weight.data
mask_weight = mask['weight']
layer.module.weight.data = old_weight.mul(mask_weight)
# apply mask to bias
if mask.__contains__('bias') and hasattr(layer.module, 'bias') and layer.module.bias is not None:
old_bias = layer.module.bias.data
mask_bias = mask['bias']
layer.module.bias.data = old_bias.mul(mask_bias)
# calculate forward
ret = layer._forward(*inputs)
return ret
layer.module.forward = new_forward
_logger.info("compressing module %s.", layer.name)
return PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)
def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None):
"""
@@ -225,26 +308,29 @@ class Pruner(Compressor):
input_shape : list or tuple
input shape to onnx model
"""
if self.detect_modules_to_compress() and not self.mask_dict:
_logger.warning('You may not use self.mask_dict in base Pruner class to record masks')
# if self.detect_modules_to_compress() and not self.mask_dict:
# _logger.warning('You may not use self.mask_dict in base Pruner class to record masks')
assert model_path is not None, 'model_path must be specified'
for name, m in self.bound_model.named_modules():
if name == "":
continue
masks = self.mask_dict.get(name)
if masks is not None:
mask_sum = masks['weight'].sum().item()
mask_num = masks['weight'].numel()
_logger.info('Layer: %s Sparsity: %.2f', name, 1 - mask_sum / mask_num)
m.weight.data = m.weight.data.mul(masks['weight'])
if masks.__contains__('bias') and hasattr(m, 'bias') and m.bias is not None:
m.bias.data = m.bias.data.mul(masks['bias'])
else:
_logger.info('Layer: %s NOT compressed', name)
mask_dict = {}
self._unwrap_model() # used for generating correct state_dict name without wrapper state
for wrapper in self.get_modules_wrapper():
weight_mask = wrapper.weight_mask
bias_mask = wrapper.bias_mask
if weight_mask is not None:
mask_sum = weight_mask.sum().item()
mask_num = weight_mask.numel()
_logger.info('Layer: %s Sparsity: %.2f', wrapper.name, 1 - mask_sum / mask_num)
wrapper.module.weight.data = wrapper.module.weight.data.mul(weight_mask)
if bias_mask is not None:
wrapper.module.bias.data = wrapper.module.bias.data.mul(bias_mask)
# save mask to dict
mask_dict[wrapper.name] = {"weight": weight_mask, "bias": bias_mask}
torch.save(self.bound_model.state_dict(), model_path)
_logger.info('Model state_dict saved to %s', model_path)
if mask_path is not None:
torch.save(self.mask_dict, mask_path)
torch.save(mask_dict, mask_path)
_logger.info('Mask dict saved to %s', mask_path)
if onnx_path is not None:
assert input_shape is not None, 'input_shape must be specified to export onnx model'
@@ -253,6 +339,86 @@ class Pruner(Compressor):
torch.onnx.export(self.bound_model, input_data, onnx_path)
_logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)
self._wrap_model()
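Typical usage of export_model after fine-tuning, with ONNX export included (paths and input shape are illustrative):

pruner.export_model(model_path='pruned_vgg19_cifar10.pth',
                    mask_path='mask_vgg19_cifar10.pth',
                    onnx_path='pruned_vgg19_cifar10.onnx',
                    input_shape=[1, 3, 32, 32])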
class QuantizerModuleWrapper(torch.nn.Module):
def __init__(self, module, module_name, module_type, config, quantizer):
"""
Wrap a module to enable data parallelism, forward-method customization and buffer registration.
Parameters
----------
module : pytorch module
the module user wants to compress
config : dict
the configurations that users specify for compression
module_name : str
the name of the module to compress; the wrapper shares the same name
module_type : str
the type of the module to compress
quantizer : Quantizer
the quantizer that performs quantization for this module
"""
super().__init__()
# origin layer information
self.module = module
self.name = module_name
self.type = module_type
# config and quantizer
self.config = config
self.quantizer = quantizer
# register buffer and parameter
# old_weight stores the original weight; weight stores the quantized weight.
# weight is a buffer rather than a Parameter because PyTorch Parameters must be autograd leaves:
# if weight stayed a leaf, old_weight could never receive gradient updates.
if 'weight' in config['quant_types']:
if not _check_weight(self.module):
_logger.warning('Module %s does not have parameter "weight"', self.name)
else:
self.module.register_parameter('old_weight', torch.nn.Parameter(self.module.weight))
delattr(self.module, 'weight')
self.module.register_buffer('weight', self.module.old_weight)
# register user specified buffer
self.registered_buffers = {}
for name in self.quantizer.buffers:
self.register_buffer(name, self.quantizer.buffers[name].clone())
self.registered_buffers[name] = getattr(self, name)
def forward(self, *inputs):
if 'input' in self.config['quant_types']:
inputs = self.quantizer.quant_grad.apply(
inputs,
QuantType.QUANT_INPUT,
self.quantizer.quantize_input,
self.config,
LayerInfo(self.name, self.module),
**self.registered_buffers)
if 'weight' in self.config['quant_types'] and _check_weight(self.module):
new_weight = self.quantizer.quant_grad.apply(
self.module.old_weight,
QuantType.QUANT_WEIGHT,
self.quantizer.quantize_weight,
self.config,
LayerInfo(self.name, self.module),
**self.registered_buffers)
self.module.weight = new_weight
result = self.module(*inputs)
else:
result = self.module(*inputs)
if 'output' in self.config['quant_types']:
result = self.quantizer.quant_grad.apply(
result,
QuantType.QUANT_OUTPUT,
self.quantizer.quantize_output,
self.config,
LayerInfo(self.name, self.module),
**self.registered_buffers)
return result
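The old_weight registration above can be reproduced in isolation; after the swap the optimizer sees old_weight as the learnable leaf, while weight becomes a buffer that each forward pass overwrites with the quantized value:

import torch.nn as nn

conv = nn.Conv2d(3, 8, 3)
conv.register_parameter('old_weight', nn.Parameter(conv.weight))
delattr(conv, 'weight')
conv.register_buffer('weight', conv.old_weight)
print([n for n, _ in conv.named_parameters()])  # ['bias', 'old_weight'] -- the optimizer updates old_weight
print([n for n, _ in conv.named_buffers()])     # ['weight'] -- refreshed from old_weight each forward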
class Quantizer(Compressor):
"""
@@ -303,7 +469,7 @@ class Quantizer(Compressor):
raise NotImplementedError('Quantizer must overload quantize_input()')
def _instrument_layer(self, layer, config):
def _wrap_modules(self, layer, config):
"""
Create a wrapper forward function to replace the original one.
Parameters
@@ -313,7 +479,6 @@ class Quantizer(Compressor):
config : dict
the configuration for quantization
"""
assert layer._forward is None, 'Each model can only be compressed once'
assert 'quant_types' in config, 'must provide quant_types in config'
assert isinstance(config['quant_types'], list), 'quant_types must be list type'
assert 'quant_bits' in config, 'must provide quant_bits in config'
@@ -323,35 +488,7 @@ class Quantizer(Compressor):
for quant_type in config['quant_types']:
assert quant_type in config['quant_bits'], 'bits length for %s must be specified in quant_bits dict' % quant_type
if 'weight' in config['quant_types']:
if not _check_weight(layer.module):
_logger.warning('Module %s does not have parameter "weight"', layer.name)
else:
# old_weight is used to store origin weight and weight is used to store quantized weight
# the reason why weight is buffer instead of parameter is because in pytorch parameter is used as leaf
# if weight is leaf , then old_weight can not be updated.
layer.module.register_parameter('old_weight', torch.nn.Parameter(layer.module.weight))
delattr(layer.module, 'weight')
layer.module.register_buffer('weight', layer.module.old_weight)
layer._forward = layer.module.forward
def new_forward(*inputs):
if 'input' in config['quant_types']:
inputs = self.quant_grad.apply(inputs, QuantType.QUANT_INPUT, self.quantize_input, config, layer)
if 'weight' in config['quant_types'] and _check_weight(layer.module):
new_weight = self.quant_grad.apply(layer.module.old_weight, QuantType.QUANT_WEIGHT, self.quantize_weight, config, layer)
layer.module.weight = new_weight
result = layer._forward(*inputs)
else:
result = layer._forward(*inputs)
if 'output' in config['quant_types']:
result = self.quant_grad.apply(result, QuantType.QUANT_OUTPUT, self.quantize_output, config, layer)
return result
layer.module.forward = new_forward
return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self)
class QuantType:
"""
@@ -387,15 +524,15 @@ class QuantGrad(torch.autograd.Function):
return grad_output
@staticmethod
def forward(ctx, tensor, quant_type, quant_func, config, layer):
def forward(ctx, tensor, quant_type, quant_func, config, layer, **kwargs):
ctx.save_for_backward(tensor, torch.Tensor([quant_type]))
return quant_func(tensor, config, op=layer.module, op_type=layer.type, op_name=layer.name)
return quant_func(tensor, config, op=layer.module, op_type=layer.type, op_name=layer.name, **kwargs)
@classmethod
def backward(cls, ctx, grad_output):
tensor, quant_type = ctx.saved_variables
output = cls.quant_backward(tensor, grad_output, quant_type)
return output, None, None, None, None
return output, None, None, None, None, None
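quant_backward above defaults to the straight-through estimator: the non-differentiable quantization step is treated as the identity in the backward pass, so gradients reach old_weight unchanged. A minimal, self-contained illustration of the same idea (not the NNI API):

import torch

class RoundSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, w):
        return torch.round(w)      # non-differentiable step
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output         # pretend the forward was the identity

w = torch.tensor([0.2, 1.7], requires_grad=True)
RoundSTE.apply(w).sum().backward()
print(w.grad)                      # tensor([1., 1.])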
def _check_weight(module):
try:
......
@@ -187,7 +187,6 @@ class SlimPruner(Pruner):
"""
super().__init__(model, config_list)
self.mask_calculated_ops = set()
weight_list = []
if len(config_list) > 1:
logger.warning('Slim pruner only supports 1 configuration')
@@ -198,8 +197,9 @@
all_bn_weights = torch.cat(weight_list)
k = int(all_bn_weights.shape[0] * config['sparsity'])
self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
self.register_buffer("if_calculated", torch.tensor(False)) # pylint: disable=not-callable
def calc_mask(self, layer, config):
def calc_mask(self, layer, config, **kwargs):
"""
Calculate the mask of given layer.
Scale factors with the smallest absolute value in the BN layer are masked.
@@ -209,6 +209,8 @@
the layer to instrument the compression operation
config : dict
layer's pruning config
kwargs : dict
buffers registered in the __init__ function
Returns
-------
dict
@@ -216,27 +218,21 @@ """
"""
weight = layer.module.weight.data
op_name = layer.name
op_type = layer.type
if_calculated = kwargs["if_calculated"]
assert op_type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
if op_name in self.mask_calculated_ops:
assert op_name in self.mask_dict
return self.mask_dict.get(op_name)
if if_calculated:
return None
base_mask = torch.ones(weight.size()).type_as(weight).detach()
mask = {'weight': base_mask.detach(), 'bias': base_mask.clone().detach()}
try:
filters = weight.size(0)
num_prune = int(filters * config.get('sparsity'))
if filters < 2 or num_prune < 1:
return mask
filters = weight.size(0)
num_prune = int(filters * config.get('sparsity'))
if filters >= 2 and num_prune >= 1:
w_abs = weight.abs()
mask_weight = torch.gt(w_abs, self.global_threshold).type_as(weight)
mask_bias = mask_weight.clone()
mask = {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
finally:
self.mask_dict.update({layer.name: mask})
self.mask_calculated_ops.add(layer.name)
if_calculated.copy_(torch.tensor(True)) # pylint: disable=not-callable
return mask
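The threshold used above is global across all configured BN layers: their scale factors are concatenated and the k smallest (k = sparsity x total) set the cut-off that calc_mask applies per layer with torch.gt. A standalone sketch of the thresholding math (numbers are illustrative):

import torch

gammas = torch.tensor([0.9, 0.05, 0.4, 0.02, 0.7])        # |BN scale factors| across layers
k = int(gammas.numel() * 0.4)                              # sparsity 0.4 -> prune 2 of 5
threshold = torch.topk(gammas.view(-1), k, largest=False)[0].max()
mask = torch.gt(gammas, threshold).float()                 # 0.02 and 0.05 fall below the cut-off
print(threshold.item(), mask)                              # 0.05, tensor([1., 0., 1., 0., 1.])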
class LotteryTicketPruner(Pruner):
......
@@ -27,12 +27,12 @@ class WeightRankFilterPruner(Pruner):
"""
super().__init__(model, config_list)
self.mask_calculated_ops = set() # operations whose mask has been calculated
self.register_buffer("if_calculated", torch.tensor(False)) # pylint: disable=not-callable
def get_mask(self, base_mask, weight, num_prune):
raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))
def calc_mask(self, layer, config):
def calc_mask(self, layer, config, **kwargs):
"""
Calculate the mask of given layer.
Filters with the smallest importance criterion of the kernel weights are masked.
@@ -49,14 +49,13 @@ class WeightRankFilterPruner(Pruner):
"""
weight = layer.module.weight.data
op_name = layer.name
op_type = layer.type
assert 0 <= config.get('sparsity') < 1, "sparsity must be in the range [0, 1)"
assert op_type in ['Conv1d', 'Conv2d'], "only support Conv1d and Conv2d"
assert op_type in config.get('op_types')
if op_name in self.mask_calculated_ops:
assert op_name in self.mask_dict
return self.mask_dict.get(op_name)
if_calculated = kwargs["if_calculated"]
if if_calculated:
return None
mask_weight = torch.ones(weight.size()).type_as(weight).detach()
if hasattr(layer.module, 'bias') and layer.module.bias is not None:
mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
@@ -70,8 +69,7 @@ class WeightRankFilterPruner(Pruner):
return mask
mask = self.get_mask(mask, weight, num_prune)
finally:
self.mask_dict.update({op_name: mask})
self.mask_calculated_ops.add(op_name)
if_calculated.copy_(torch.tensor(True)) # pylint: disable=not-callable
return mask
......