Commit ed121315 authored by chicm-ms, committed by GitHub

Merge pull request #2022 from microsoft/dev-pruner-dataparallel

Dev pruner DataParallel
parents c7187946 8092c8bd
import math
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.compression.torch import L1FilterPruner
from nni.compression.torch import ActivationMeanRankFilterPruner
from models.cifar10.vgg import VGG
......@@ -40,6 +41,12 @@ def test(model, device, test_loader):
def main():
parser = argparse.ArgumentParser("multiple gpu with pruning")
parser.add_argument("--epochs", type=int, default=160)
parser.add_argument("--retrain", default=False, action="store_true")
parser.add_argument("--parallel", default=False, action="store_true")
args = parser.parse_args()
torch.manual_seed(0)
device = torch.device('cuda')
train_loader = torch.utils.data.DataLoader(
......@@ -63,10 +70,11 @@ def main():
model.to(device)
# Train the base VGG-16 model
if args.retrain:
print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 160, 0)
for epoch in range(160):
for epoch in range(args.epochs):
train(model, device, train_loader, optimizer)
test(model, device, test_loader)
lr_scheduler.step(epoch)
......@@ -88,8 +96,16 @@ def main():
# Prune model and test accuracy without fine tuning.
print('=' * 10 + 'Test on the pruned model before fine tune' + '=' * 10)
pruner = L1FilterPruner(model, configure_list)
pruner = ActivationMeanRankFilterPruner(model, configure_list)
model = pruner.compress()
if args.parallel:
if torch.cuda.device_count() > 1:
print("use {} gpus for pruning".format(torch.cuda.device_count()))
model = nn.DataParallel(model)
else:
print("only detect 1 gpu, fall back")
model.to(device)
test(model, device, test_loader)
# top1 = 88.19%
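A hedged sketch of how a fine-tune-and-export stage could follow here; the hyperparameters and file names below are illustrative, not the example's actual (truncated) values:
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
for epoch in range(40):
    train(model, device, train_loader, optimizer)
    test(model, device, test_loader)
pruner.export_model('pruned_model.pth', 'mask.pth')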
......
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.compression.torch import FPGMPruner
......@@ -6,10 +7,10 @@ from nni.compression.torch import FPGMPruner
class Mnist(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
self.fc2 = torch.nn.Linear(500, 10)
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.fc1 = nn.Linear(4 * 4 * 50, 500)
self.fc2 = nn.Linear(500, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
......@@ -27,8 +28,14 @@ class Mnist(torch.nn.Module):
return num_zero_filters, num_filters, float(num_zero_filters)/num_filters
def print_conv_filter_sparsity(self):
if isinstance(self.conv1, nn.Conv2d):
conv1_data = self._get_conv_weight_sparsity(self.conv1)
conv2_data = self._get_conv_weight_sparsity(self.conv2)
else:
# after compression, self.conv1 is wrapped as a PrunerModuleWrapper; the raw Conv2d is at self.conv1.module
conv1_data = self._get_conv_weight_sparsity(self.conv1.module)
conv2_data = self._get_conv_weight_sparsity(self.conv2.module)
print('conv1: num zero filters: {}, num filters: {}, sparsity: {:.4f}'.format(conv1_data[0], conv1_data[1], conv1_data[2]))
print('conv2: num zero filters: {}, num filters: {}, sparsity: {:.4f}'.format(conv2_data[0], conv2_data[1], conv2_data[2]))
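Since the pruner replaces compressed layers with wrapper modules, a small helper (hypothetical, not part of NNI) captures the same access pattern for any layer:
def unwrap(layer):
    # the pruner's wrapper keeps the original layer at .module
    return layer.module if hasattr(layer, 'module') else layer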
......
......@@ -71,6 +71,8 @@ if __name__ == '__main__':
pruner = LotteryTicketPruner(model, configure_list, optimizer)
pruner.compress()
#model = nn.DataParallel(model)
for i in pruner.get_prune_iterations():
pruner.prune_iteration_start()
loss = 0
......
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from nni.compression.torch import SlimPruner
class fc1(nn.Module):
def __init__(self, num_classes=10):
super(fc1, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2d(32)
self.relu1 = nn.ReLU(inplace=True)
self.linear1 = nn.Linear(32*28*28, 300)
self.relu2 = nn.ReLU(inplace=True)
self.linear2 = nn.Linear(300, 100)
self.relu3 = nn.ReLU(inplace=True)
self.linear3 = nn.Linear(100, num_classes)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu1(x)
x = torch.flatten(x,1)
x = self.linear1(x)
x = self.relu2(x)
x = self.linear2(x)
x = self.relu3(x)
x = self.linear3(x)
return x
def train(model, train_loader, optimizer, criterion, device):
model.train()
for imgs, targets in train_loader:
optimizer.zero_grad()
imgs, targets = imgs.to(device), targets.to(device)
output = model(imgs)
train_loss = criterion(output, targets)
train_loss.backward()
optimizer.step()
return train_loss.item()
def test(model, test_loader, criterion, device):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
correct += pred.eq(target.data.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
accuracy = 100. * correct / len(test_loader.dataset)
return accuracy
if __name__ == '__main__':
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
traindataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
testdataset = datasets.MNIST('./data', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(traindataset, batch_size=60, shuffle=True, num_workers=10, drop_last=False)
test_loader = torch.utils.data.DataLoader(testdataset, batch_size=60, shuffle=False, num_workers=10, drop_last=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = fc1()
criterion = nn.CrossEntropyLoss()
configure_list = [{
'prune_iterations': 5,
'sparsity': 0.86,
'op_types': ['BatchNorm2d']
}]
pruner = SlimPruner(model, configure_list)
pruner.compress()
if torch.cuda.device_count()>1:
model = nn.DataParallel(model)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1.2e-3)
for name, par in model.named_parameters():
print(name)
# for i in pruner.get_prune_iterations():
# pruner.prune_iteration_start()
loss = 0
accuracy = 0
for epoch in range(10):
loss = train(model, train_loader, optimizer, criterion, device)
accuracy = test(model, test_loader, criterion, device)
print('current epoch: {0}, loss: {1}, accuracy: {2}'.format(epoch, loss, accuracy))
# print('prune iteration: {0}, loss: {1}, accuracy: {2}'.format(i, loss, accuracy))
pruner.export_model('model.pth', 'mask.pth')
\ No newline at end of file
import math
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
......@@ -6,7 +7,6 @@ from torchvision import datasets, transforms
from nni.compression.torch import SlimPruner
from models.cifar10.vgg import VGG
def updateBN(model):
for m in model.modules():
if isinstance(m, nn.BatchNorm2d):
......@@ -49,6 +49,13 @@ def test(model, device, test_loader):
def main():
parser = argparse.ArgumentParser("multiple gpu with pruning")
parser.add_argument("--epochs", type=int, default=160)
parser.add_argument("--retrain", default=False, action="store_true")
parser.add_argument("--parallel", default=False, action="store_true")
args = parser.parse_args()
torch.manual_seed(0)
device = torch.device('cuda')
train_loader = torch.utils.data.DataLoader(
......@@ -70,15 +77,16 @@ def main():
model = VGG(depth=19)
model.to(device)
# Train the base VGG-19 model
if args.retrain:
print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
epochs = 160
epochs = args.epochs
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
for epoch in range(epochs):
if epoch in [epochs * 0.5, epochs * 0.75]:
for param_group in optimizer.param_groups:
param_group['lr'] *= 0.1
print("epoch {}".format(epoch))
train(model, device, train_loader, optimizer, True)
test(model, device, test_loader)
torch.save(model.state_dict(), 'vgg19_cifar10.pth')
......@@ -99,9 +107,14 @@ def main():
print('=' * 10 + 'Test the pruned model before fine tune' + '=' * 10)
pruner = SlimPruner(model, configure_list)
model = pruner.compress()
test(model, device, test_loader)
# top1 = 93.55%
if args.parallel:
if torch.cuda.device_count() > 1:
print("use {} gpus for pruning".format(torch.cuda.device_count()))
model = nn.DataParallel(model)
# model = nn.DataParallel(model, device_ids=[0, 1])
else:
print("only detect 1 gpu, fall back")
model.to(device)
# Fine tune the pruned model for 40 epochs and test accuracy
print('=' * 10 + 'Fine tuning' + '=' * 10)
optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
......
......@@ -32,7 +32,7 @@ class ActivationRankFilterPruner(Pruner):
"""
super().__init__(model, config_list)
self.mask_calculated_ops = set()
self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
self.statistics_batch_num = statistics_batch_num
self.collected_activation = {}
self.hooks = {}
......@@ -48,22 +48,29 @@ class ActivationRankFilterPruner(Pruner):
"""
Compress the model, register a hook for collecting activations.
"""
if self.modules_wrapper is not None:
# already compressed
return self.bound_model
else:
self.modules_wrapper = []
modules_to_compress = self.detect_modules_to_compress()
for layer, config in modules_to_compress:
self._instrument_layer(layer, config)
wrapper = self._wrap_modules(layer, config)
self.modules_wrapper.append(wrapper)
self.collected_activation[layer.name] = []
def _hook(module_, input_, output, name=layer.name):
if len(self.collected_activation[name]) < self.statistics_batch_num:
self.collected_activation[name].append(self.activation(output.detach().cpu()))
layer.module.register_forward_hook(_hook)
wrapper.module.register_forward_hook(_hook)
self._wrap_model()
return self.bound_model
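A hedged usage sketch: after compress(), a few inference batches must run so the hooks can collect activations before meaningful masks can be computed (calibration_loader and device are hypothetical names):
model = pruner.compress()
model.eval()
with torch.no_grad():
    for i, (data, _) in enumerate(calibration_loader):
        model(data.to(device))
        if i + 1 >= pruner.statistics_batch_num:
            break  # hooks stop recording after statistics_batch_num batches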
def get_mask(self, base_mask, activations, num_prune):
raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))
def calc_mask(self, layer, config):
def calc_mask(self, layer, config, **kwargs):
"""
Calculate the mask of the given layer.
Filters whose importance criterion, computed from the collected activations, is smallest are masked.
......@@ -82,14 +89,13 @@ class ActivationRankFilterPruner(Pruner):
"""
weight = layer.module.weight.data
op_name = layer.name
op_type = layer.type
assert 0 <= config.get('sparsity') < 1, "sparsity must be in the range [0, 1)"
assert op_type in ['Conv2d'], "only Conv2d is supported"
assert op_type in config.get('op_types')
if op_name in self.mask_calculated_ops:
assert op_name in self.mask_dict
return self.mask_dict.get(op_name)
if_calculated = kwargs["if_calculated"]
if if_calculated:
return None
mask_weight = torch.ones(weight.size()).type_as(weight).detach()
if hasattr(layer.module, 'bias') and layer.module.bias is not None:
mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
......@@ -104,8 +110,7 @@ class ActivationRankFilterPruner(Pruner):
mask = self.get_mask(mask, self.collected_activation[layer.name], num_prune)
finally:
if len(self.collected_activation[layer.name]) == self.statistics_batch_num:
self.mask_dict.update({op_name: mask})
self.mask_calculated_ops.add(op_name)
if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
return mask
......
......@@ -14,8 +14,11 @@ class LayerInfo:
self.name = name
self.type = type(module).__name__
self._forward = None
def _setattr(model, name, module):
name_list = name.split(".")
for name in name_list[:-1]:
model = getattr(model, name)
setattr(model, name_list[-1], module)
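A toy illustration of the dotted-name traversal, assuming _setattr above is in scope:
import torch.nn as nn
toy = nn.Sequential(nn.Sequential(nn.Conv2d(1, 8, 3)))
_setattr(toy, '0.0', nn.Identity())  # replaces toy[0][0] in place
assert isinstance(toy[0][0], nn.Identity)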
class Compressor:
"""
......@@ -36,6 +39,9 @@ class Compressor:
self.bound_model = model
self.config_list = config_list
self.modules_to_compress = None
self.modules_wrapper = None
self.buffers = {}
self.is_wrapped = False
def detect_modules_to_compress(self):
"""
......@@ -51,21 +57,60 @@ class Compressor:
self.modules_to_compress.append((layer, config))
return self.modules_to_compress
def _wrap_model(self):
"""
wrap all modules that need to be compressed
"""
for wrapper in reversed(self.get_modules_wrapper()):
_setattr(self.bound_model, wrapper.name, wrapper)
self.is_wrapped = True
def _unwrap_model(self):
"""
unwrap all modules that need to be compressed
"""
for wrapper in self.get_modules_wrapper():
_setattr(self.bound_model, wrapper.name, wrapper.module)
self.is_wrapped = False
def compress(self):
"""
Compress the model with the algorithm implemented by the subclass.
The model will be instrumented and the user should not edit it after calling this method.
`self.modules_to_compress` records all the to-be-compressed layers.
Returns
-------
torch.nn.Module
model with specified modules compressed.
"""
if self.modules_wrapper is not None:
# already compressed
return self.bound_model
else:
self.modules_wrapper = []
modules_to_compress = self.detect_modules_to_compress()
for layer, config in modules_to_compress:
self._instrument_layer(layer, config)
wrapper = self._wrap_modules(layer, config)
self.modules_wrapper.append(wrapper)
self._wrap_model()
return self.bound_model
def register_buffer(self, name, value):
"""
Register a buffer to be used in the wrapped modules' forward method; each wrapper stores its own clone.
"""
self.buffers[name] = value
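A hedged sketch of the buffer round trip: a pruner registers a per-module flag here, each wrapper clones it at construction time, and calc_mask later receives that clone through **kwargs. MyPruner is a hypothetical subclass mirroring the pattern used by the built-in pruners below, assuming Pruner from this module is in scope:
import torch

class MyPruner(Pruner):
    def __init__(self, model, config_list):
        super().__init__(model, config_list)
        # cloned into every PrunerModuleWrapper as a per-module flag
        self.register_buffer("if_calculated", torch.tensor(0))

    def calc_mask(self, layer, config, **kwargs):
        if kwargs["if_calculated"]:
            return None  # mask already computed for this module
        mask = {'weight': torch.ones_like(layer.module.weight.data)}
        kwargs["if_calculated"].copy_(torch.tensor(1))
        return mask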
def get_modules_to_compress(self):
"""
To obtain all the to-be-compressed layers.
To obtain all the to-be-compressed modules.
Returns
-------
......@@ -75,6 +120,17 @@ class Compressor:
"""
return self.modules_to_compress
def get_modules_wrapper(self):
"""
To obtain all the wrapped modules.
Returns
-------
list
a list of the wrapped modules
"""
return self.modules_wrapper
def select_config(self, layer):
"""
Find the configuration for `layer` by parsing `self.config_list`
......@@ -119,7 +175,7 @@ class Compressor:
If the user wants to update the model every step, they can override this method
"""
def _instrument_layer(self, layer, config):
def _wrap_modules(self, layer, config):
"""
This method is implemented in the subclasses, i.e., `Pruner` and `Quantizer`
......@@ -143,6 +199,59 @@ class Compressor:
expanded_op_types.append(op_type)
return expanded_op_types
class PrunerModuleWrapper(torch.nn.Module):
def __init__(self, module, module_name, module_type, config, pruner):
"""
Wrap a module to enable data parallelism, forward method customization and buffer registration.
Parameters
----------
module : pytorch module
the module the user wants to compress
config : dict
the configurations that users specify for compression
module_name : str
the name of the module to compress; the wrapper module shares the same name
module_type : str
the type of the module to compress
pruner : Pruner
the pruner used to calculate mask
"""
super().__init__()
# original layer information
self.module = module
self.name = module_name
self.type = module_type
# config and pruner
self.config = config
self.pruner = pruner
self.registered_buffers = {}
# register buffer for mask
self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
self.registered_buffers['weight_mask'] = self.weight_mask
if hasattr(self.module, 'bias') and self.module.bias is not None:
self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
else:
self.register_buffer("bias_mask", None)
self.registered_buffers['bias_mask'] = self.bias_mask
# register user-specified buffers
for name in self.pruner.buffers:
self.register_buffer(name, self.pruner.buffers[name].clone())
self.registered_buffers[name] = getattr(self, name)
def forward(self, *inputs):
mask = self.pruner.calc_mask(LayerInfo(self.name, self.module), self.config, **self.registered_buffers)
if mask is not None:
self.weight_mask.copy_(mask['weight'])
# apply mask to weight
self.module.weight.data = self.module.weight.data.mul_(self.weight_mask)
# apply mask to bias
if hasattr(self.module, 'bias') and self.module.bias is not None:
if mask is not None and 'bias' in mask:
self.bias_mask.copy_(mask['bias'])
self.module.bias.data = self.module.bias.data.mul_(self.bias_mask)
return self.module(*inputs)
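Because the masks live in registered buffers, the effective sparsity of any wrapped layer can be read back directly; layer_sparsity is a hypothetical helper:
def layer_sparsity(wrapper):
    # fraction of weights zeroed out by the pruner's mask
    mask = wrapper.weight_mask
    return 1.0 - mask.sum().item() / mask.numel()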
class Pruner(Compressor):
"""
......@@ -158,9 +267,8 @@ class Pruner(Compressor):
def __init__(self, model, config_list):
super().__init__(model, config_list)
self.mask_dict = {}
def calc_mask(self, layer, config):
def calc_mask(self, layer, config, **kwargs):
"""
Pruners should overload this method to provide a mask for weight tensors.
The mask must have the same shape and type as the weight.
......@@ -176,9 +284,9 @@ class Pruner(Compressor):
"""
raise NotImplementedError("Pruners must overload calc_mask()")
def _instrument_layer(self, layer, config):
def _wrap_modules(self, layer, config):
"""
Create a wrapper forward function to replace the original one.
Create a wrapper module to replace the original one.
Parameters
----------
......@@ -187,30 +295,13 @@ class Pruner(Compressor):
config : dict
the configuration for generating the mask
"""
assert layer._forward is None, 'Each model can only be compressed once'
if not _check_weight(layer.module):
_logger.warning('Module %s does not have parameter "weight"', layer.name)
return
layer._forward = layer.module.forward
def new_forward(*inputs):
mask = self.calc_mask(layer, config)
# apply mask to weight
old_weight = layer.module.weight.data
mask_weight = mask['weight']
layer.module.weight.data = old_weight.mul(mask_weight)
# apply mask to bias
if mask.__contains__('bias') and hasattr(layer.module, 'bias') and layer.module.bias is not None:
old_bias = layer.module.bias.data
mask_bias = mask['bias']
layer.module.bias.data = old_bias.mul(mask_bias)
# calculate forward
ret = layer._forward(*inputs)
return ret
layer.module.forward = new_forward
_logger.info("compressing module %s.", layer.name)
wrapper = PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)
assert hasattr(layer.module, 'weight')
wrapper.to(layer.module.weight.device)
return wrapper
def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None):
def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None, device=None):
"""
Export pruned model weights, masks, and optionally an onnx model
......@@ -224,35 +315,138 @@ class Pruner(Compressor):
(optional) path to save onnx model
input_shape : list or tuple
input shape to onnx model
device : torch.device
device of the model, used to place the dummy input tensor when exporting the onnx file.
the tensor is placed on cpu if ```device``` is None
"""
if self.detect_modules_to_compress() and not self.mask_dict:
_logger.warning('You may not use self.mask_dict in base Pruner class to record masks')
# if self.detect_modules_to_compress() and not self.mask_dict:
# _logger.warning('You may not use self.mask_dict in base Pruner class to record masks')
assert model_path is not None, 'model_path must be specified'
for name, m in self.bound_model.named_modules():
if name == "":
continue
masks = self.mask_dict.get(name)
if masks is not None:
mask_sum = masks['weight'].sum().item()
mask_num = masks['weight'].numel()
_logger.info('Layer: %s Sparsity: %.2f', name, 1 - mask_sum / mask_num)
m.weight.data = m.weight.data.mul(masks['weight'])
if masks.__contains__('bias') and hasattr(m, 'bias') and m.bias is not None:
m.bias.data = m.bias.data.mul(masks['bias'])
else:
_logger.info('Layer: %s NOT compressed', name)
mask_dict = {}
self._unwrap_model() # unwrap so the saved state_dict keys carry no wrapper state
for wrapper in self.get_modules_wrapper():
weight_mask = wrapper.weight_mask
bias_mask = wrapper.bias_mask
if weight_mask is not None:
mask_sum = weight_mask.sum().item()
mask_num = weight_mask.numel()
_logger.info('Layer: %s Sparsity: %.2f', wrapper.name, 1 - mask_sum / mask_num)
wrapper.module.weight.data = wrapper.module.weight.data.mul(weight_mask)
if bias_mask is not None:
wrapper.module.bias.data = wrapper.module.bias.data.mul(bias_mask)
# save mask to dict
mask_dict[wrapper.name] = {"weight": weight_mask, "bias": bias_mask}
torch.save(self.bound_model.state_dict(), model_path)
_logger.info('Model state_dict saved to %s', model_path)
if mask_path is not None:
torch.save(self.mask_dict, mask_path)
torch.save(mask_dict, mask_path)
_logger.info('Mask dict saved to %s', mask_path)
if onnx_path is not None:
assert input_shape is not None, 'input_shape must be specified to export onnx model'
# a dummy input is needed to trace the model for onnx export
if device is None:
device = torch.device('cpu')
input_data = torch.Tensor(*input_shape)
torch.onnx.export(self.bound_model, input_data, onnx_path)
torch.onnx.export(self.bound_model, input_data.to(device), onnx_path)
_logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)
self._wrap_model()
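A typical export call after fine-tuning (file names are hypothetical); input_shape and device only matter when an onnx path is given:
pruner.export_model('pruned_model.pth', mask_path='mask.pth',
                    onnx_path='pruned_model.onnx',
                    input_shape=[1, 3, 32, 32],
                    device=torch.device('cuda'))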
def load_model_state_dict(self, model_state):
"""
Load a state dict saved from the unwrapped model.
Parameters
----------
model_state : dict
state dict saved from the unwrapped model
"""
if self.is_wrapped:
self._unwrap_model()
self.bound_model.load_state_dict(model_state)
self._wrap_model()
else:
self.bound_model.load_state_dict(model_state)
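Usage sketch (path hypothetical): checkpoints saved from the plain model load cleanly even while the model is wrapped, because the pruner unwraps itself around load_state_dict:
state = torch.load('model_checkpoint.pth')
pruner.load_model_state_dict(state)  # safe whether or not the model is wrapped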
class QuantizerModuleWrapper(torch.nn.Module):
def __init__(self, module, module_name, module_type, config, quantizer):
"""
Wrap a module to enable data parallelism, forward method customization and buffer registration.
Parameters
----------
module : pytorch module
the module the user wants to compress
config : dict
the configurations that users specify for compression
module_name : str
the name of the module to compress; the wrapper module shares the same name
module_type : str
the type of the module to compress
quantizer : Quantizer
the quantizer used to quantize the module
"""
super().__init__()
# original layer information
self.module = module
self.name = module_name
self.type = module_type
# config and quantizer
self.config = config
self.quantizer = quantizer
# register buffer and parameter
# old_weight stores the original weight and weight stores the quantized weight
# weight is registered as a buffer rather than a parameter because pytorch treats
# parameters as leaf nodes; if weight were a leaf, old_weight could not be updated
if 'weight' in config['quant_types']:
if not _check_weight(self.module):
_logger.warning('Module %s does not have parameter "weight"', self.name)
else:
self.module.register_parameter('old_weight', torch.nn.Parameter(self.module.weight))
delattr(self.module, 'weight')
self.module.register_buffer('weight', self.module.old_weight)
# register user-specified buffers
self.registered_buffers = {}
for name in self.quantizer.buffers:
self.register_buffer(name, self.quantizer.buffers[name].clone())
self.registered_buffers[name] = getattr(self, name)
def forward(self, *inputs):
if 'input' in self.config['quant_types']:
inputs = self.quantizer.quant_grad.apply(
inputs,
QuantType.QUANT_INPUT,
self.quantizer.quantize_input,
self.config,
LayerInfo(self.name, self.module),
**self.registered_buffers)
if 'weight' in self.config['quant_types'] and _check_weight(self.module):
new_weight = self.quantizer.quant_grad.apply(
self.module.old_weight,
QuantType.QUANT_WEIGHT,
self.quantizer.quantize_weight,
self.config,
LayerInfo(self.name, self.module),
**self.registered_buffers)
self.module.weight = new_weight
result = self.module(*inputs)
else:
result = self.module(*inputs)
if 'output' in self.config['quant_types']:
result = self.quantizer.quant_grad.apply(
result,
QuantType.QUANT_OUTPUT,
self.quantizer.quantize_output,
self.config,
LayerInfo(self.name, self.module),
**self.registered_buffers)
return result
class Quantizer(Compressor):
"""
......@@ -303,7 +497,7 @@ class Quantizer(Compressor):
raise NotImplementedError('Quantizer must overload quantize_input()')
def _instrument_layer(self, layer, config):
def _wrap_modules(self, layer, config):
"""
Create a wrapper module to replace the original one.
Parameters
......@@ -313,7 +507,6 @@ class Quantizer(Compressor):
config : dict
the configuration for quantization
"""
assert layer._forward is None, 'Each model can only be compressed once'
assert 'quant_types' in config, 'must provide quant_types in config'
assert isinstance(config['quant_types'], list), 'quant_types must be list type'
assert 'quant_bits' in config, 'must provide quant_bits in config'
......@@ -323,35 +516,7 @@ class Quantizer(Compressor):
for quant_type in config['quant_types']:
assert quant_type in config['quant_bits'], 'bit length for %s must be specified in quant_bits dict' % quant_type
if 'weight' in config['quant_types']:
if not _check_weight(layer.module):
_logger.warning('Module %s does not have parameter "weight"', layer.name)
else:
# old_weight is used to store origin weight and weight is used to store quantized weight
# the reason why weight is buffer instead of parameter is because in pytorch parameter is used as leaf
# if weight is leaf , then old_weight can not be updated.
layer.module.register_parameter('old_weight', torch.nn.Parameter(layer.module.weight))
delattr(layer.module, 'weight')
layer.module.register_buffer('weight', layer.module.old_weight)
layer._forward = layer.module.forward
def new_forward(*inputs):
if 'input' in config['quant_types']:
inputs = self.quant_grad.apply(inputs, QuantType.QUANT_INPUT, self.quantize_input, config, layer)
if 'weight' in config['quant_types'] and _check_weight(layer.module):
new_weight = self.quant_grad.apply(layer.module.old_weight, QuantType.QUANT_WEIGHT, self.quantize_weight, config, layer)
layer.module.weight = new_weight
result = layer._forward(*inputs)
else:
result = layer._forward(*inputs)
if 'output' in config['quant_types']:
result = self.quant_grad.apply(result, QuantType.QUANT_OUTPUT, self.quantize_output, config, layer)
return result
layer.module.forward = new_forward
return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self)
class QuantType:
"""
......@@ -387,19 +552,18 @@ class QuantGrad(torch.autograd.Function):
return grad_output
@staticmethod
def forward(ctx, tensor, quant_type, quant_func, config, layer):
def forward(ctx, tensor, quant_type, quant_func, config, layer, **kwargs):
ctx.save_for_backward(tensor, torch.Tensor([quant_type]))
return quant_func(tensor, config, op=layer.module, op_type=layer.type, op_name=layer.name)
return quant_func(tensor, config, op=layer.module, op_type=layer.type, op_name=layer.name, **kwargs)
@classmethod
def backward(cls, ctx, grad_output):
tensor, quant_type = ctx.saved_variables
output = cls.quant_backward(tensor, grad_output, quant_type)
return output, None, None, None, None
return output, None, None, None, None, None
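# note: backward must return one gradient slot per forward argument; only the
# quantized tensor receives a real gradient, the non-tensor arguments get None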
def _check_weight(module):
try:
return isinstance(module.weight.data, torch.Tensor)
except AttributeError:
return False
\ No newline at end of file
......@@ -27,9 +27,9 @@ class LevelPruner(Pruner):
"""
super().__init__(model, config_list)
self.mask_calculated_ops = set()
self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
def calc_mask(self, layer, config):
def calc_mask(self, layer, config, **kwargs):
"""
Calculate the mask of the given layer
Parameters
......@@ -45,8 +45,9 @@ class LevelPruner(Pruner):
"""
weight = layer.module.weight.data
op_name = layer.name
if op_name not in self.mask_calculated_ops:
if_calculated = kwargs["if_calculated"]
if not if_calculated:
w_abs = weight.abs()
k = int(weight.numel() * config['sparsity'])
if k == 0:
......@@ -54,12 +55,10 @@ class LevelPruner(Pruner):
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
mask_weight = torch.gt(w_abs, threshold).type_as(weight)
mask = {'weight': mask_weight}
self.mask_dict.update({op_name: mask})
self.mask_calculated_ops.add(op_name)
else:
assert op_name in self.mask_dict, "op_name not in the mask_dict"
mask = self.mask_dict[op_name]
if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
return mask
else:
return None
class AGP_Pruner(Pruner):
......@@ -84,17 +83,20 @@ class AGP_Pruner(Pruner):
super().__init__(model, config_list)
self.now_epoch = 0
self.if_init_list = {}
self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
def calc_mask(self, layer, config):
def calc_mask(self, layer, config, **kwargs):
"""
Calculate the mask of given layer
Calculate the mask of the given layer.
Weights with the smallest absolute values are masked, with the target sparsity increasing over epochs according to the AGP schedule.
Parameters
----------
layer : LayerInfo
the layer to instrument the compression operation
config : dict
layer's pruning config
kwargs : dict
buffers registered in __init__ function
Returns
-------
dict
......@@ -102,12 +104,16 @@ class AGP_Pruner(Pruner):
"""
weight = layer.module.weight.data
op_name = layer.name
start_epoch = config.get('start_epoch', 0)
freq = config.get('frequency', 1)
if self.now_epoch >= start_epoch and self.if_init_list.get(op_name, True) \
and (self.now_epoch - start_epoch) % freq == 0:
mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
if_calculated = kwargs["if_calculated"]
if if_calculated:
return None
if not (self.now_epoch >= start_epoch and (self.now_epoch - start_epoch) % freq == 0):
return None
mask = {'weight': kwargs['weight_mask'] if 'weight_mask' in kwargs else torch.ones(weight.shape).type_as(weight)}
target_sparsity = self.compute_target_sparsity(config)
k = int(weight.numel() * target_sparsity)
if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
......@@ -116,10 +122,8 @@ class AGP_Pruner(Pruner):
w_abs = weight.abs() * mask['weight']
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
self.mask_dict.update({op_name: new_mask})
self.if_init_list.update({op_name: False})
else:
new_mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
return new_mask
def compute_target_sparsity(self, config):
......@@ -165,9 +169,8 @@ class AGP_Pruner(Pruner):
if epoch > 0:
self.now_epoch = epoch
for k in self.if_init_list.keys():
self.if_init_list[k] = True
for wrapper in self.get_modules_wrapper():
wrapper.registered_buffers['if_calculated'].copy_(torch.tensor(0)) # pylint: disable=not-callable
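A hedged AGP usage sketch: the training loop is expected to call update_epoch before each epoch so the if_calculated flags reset and masks are recomputed at the new target sparsity (loop and helper names follow the example files above):
pruner = AGP_Pruner(model, configure_list)
model = pruner.compress()
for epoch in range(num_epochs):
    pruner.update_epoch(epoch)  # reset flags and advance the sparsity schedule
    train(model, device, train_loader, optimizer)
    test(model, device, test_loader)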
class SlimPruner(Pruner):
"""
......@@ -187,7 +190,6 @@ class SlimPruner(Pruner):
"""
super().__init__(model, config_list)
self.mask_calculated_ops = set()
weight_list = []
if len(config_list) > 1:
logger.warning('Slim pruner only supports 1 configuration')
......@@ -198,8 +200,9 @@ class SlimPruner(Pruner):
all_bn_weights = torch.cat(weight_list)
k = int(all_bn_weights.shape[0] * config['sparsity'])
self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
def calc_mask(self, layer, config):
def calc_mask(self, layer, config, **kwargs):
"""
Calculate the mask of the given layer.
Scale factors with the smallest absolute value in the BN layer are masked.
......@@ -209,6 +212,8 @@ class SlimPruner(Pruner):
the layer to instrument the compression operation
config : dict
layer's pruning config
kwargs : dict
buffers registered in __init__ function
Returns
-------
dict
......@@ -216,27 +221,21 @@ class SlimPruner(Pruner):
"""
weight = layer.module.weight.data
op_name = layer.name
op_type = layer.type
if_calculated = kwargs["if_calculated"]
assert op_type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
if op_name in self.mask_calculated_ops:
assert op_name in self.mask_dict
return self.mask_dict.get(op_name)
if if_calculated:
return None
base_mask = torch.ones(weight.size()).type_as(weight).detach()
mask = {'weight': base_mask.detach(), 'bias': base_mask.clone().detach()}
try:
filters = weight.size(0)
num_prune = int(filters * config.get('sparsity'))
if filters < 2 or num_prune < 1:
return mask
if filters >= 2 and num_prune >= 1:
w_abs = weight.abs()
mask_weight = torch.gt(w_abs, self.global_threshold).type_as(weight)
mask_bias = mask_weight.clone()
mask = {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
finally:
self.mask_dict.update({layer.name: mask})
self.mask_calculated_ops.add(layer.name)
if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
return mask
class LotteryTicketPruner(Pruner):
......@@ -294,38 +293,23 @@ class LotteryTicketPruner(Pruner):
prune_iterations = config['prune_iterations']
return prune_iterations
def _print_masks(self, print_mask=False):
torch.set_printoptions(threshold=1000)
for op_name in self.mask_dict.keys():
mask = self.mask_dict[op_name]
print('op name: ', op_name)
if print_mask:
print('mask: ', mask)
# calculate current sparsity
mask_num = mask['weight'].sum().item()
mask_size = mask['weight'].numel()
print('sparsity: ', 1 - mask_num / mask_size)
torch.set_printoptions(profile='default')
def _calc_sparsity(self, sparsity):
keep_ratio_once = (1 - sparsity) ** (1 / self.prune_iterations)
curr_keep_ratio = keep_ratio_once ** self.curr_prune_iteration
return max(1 - curr_keep_ratio, 0)
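A worked instance of this schedule, using the example configuration sparsity=0.86 over prune_iterations=5:
sparsity, iterations = 0.86, 5
keep_once = (1 - sparsity) ** (1 / iterations)  # ≈ 0.675 kept per iteration
print([round(1 - keep_once ** i, 3) for i in range(1, iterations + 1)])
# -> [0.325, 0.545, 0.693, 0.793, 0.86]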
def _calc_mask(self, weight, sparsity, op_name):
def _calc_mask(self, weight, sparsity, curr_w_mask):
if self.curr_prune_iteration == 0:
mask = torch.ones(weight.shape).type_as(weight)
else:
curr_sparsity = self._calc_sparsity(sparsity)
assert self.mask_dict.get(op_name) is not None
curr_mask = self.mask_dict.get(op_name)
w_abs = weight.abs() * curr_mask['weight']
w_abs = weight.abs() * curr_w_mask
k = int(w_abs.numel() * curr_sparsity)
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
mask = torch.gt(w_abs, threshold).type_as(weight)
return {'weight': mask}
def calc_mask(self, layer, config):
def calc_mask(self, layer, config, **kwargs):
"""
Generate mask for the given ``weight``.
......@@ -335,15 +319,17 @@ class LotteryTicketPruner(Pruner):
The layer to be pruned
config : dict
Pruning configurations for this weight
kwargs : dict
Auxiliary information
Returns
-------
tensor
The mask for this weight
The mask for this weight; it is ```None``` because this pruner
calculates and assigns masks in ```prune_iteration_start```,
so nothing needs to be done in this function.
"""
assert self.mask_dict.get(layer.name) is not None, 'Please call iteration_start before training'
mask = self.mask_dict[layer.name]
return mask
return None
def get_prune_iterations(self):
"""
......@@ -368,16 +354,26 @@ class LotteryTicketPruner(Pruner):
self.curr_prune_iteration += 1
assert self.curr_prune_iteration < self.prune_iterations + 1, 'Exceed the configured prune_iterations'
modules_wrapper = self.get_modules_wrapper()
modules_to_compress = self.detect_modules_to_compress()
for layer, config in modules_to_compress:
module_wrapper = None
for wrapper in modules_wrapper:
if wrapper.name == layer.name:
module_wrapper = wrapper
break
assert module_wrapper is not None
sparsity = config.get('sparsity')
mask = self._calc_mask(layer.module.weight.data, sparsity, layer.name)
self.mask_dict.update({layer.name: mask})
self._print_masks()
mask = self._calc_mask(layer.module.weight.data, sparsity, module_wrapper.weight_mask)
# TODO: using weight_mask directly here is not ideal
module_wrapper.weight_mask.copy_(mask['weight'])
# there is no mask for bias
# reinit weights back to original after new masks are generated
if self.reset_weights:
self._model.load_state_dict(self._model_state)
# use this member function so the state dict keys match the unwrapped model
self.load_model_state_dict(self._model_state)
self._optimizer.load_state_dict(self._optimizer_state)
if self._lr_scheduler is not None:
self._lr_scheduler.load_state_dict(self._scheduler_state)
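The intended outer loop, matching the lottery-ticket example file above (epochs_per_iteration is a placeholder):
pruner = LotteryTicketPruner(model, configure_list, optimizer)
pruner.compress()
for i in pruner.get_prune_iterations():
    pruner.prune_iteration_start()  # compute new masks, optionally reset weights
    for epoch in range(epochs_per_iteration):
        train(model, train_loader, optimizer, criterion, device)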
......@@ -27,12 +27,12 @@ class WeightRankFilterPruner(Pruner):
"""
super().__init__(model, config_list)
self.mask_calculated_ops = set() # operations whose mask has been calculated
self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
def get_mask(self, base_mask, weight, num_prune):
raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))
def calc_mask(self, layer, config):
def calc_mask(self, layer, config, **kwargs):
"""
Calculate the mask of the given layer.
Filters whose importance criterion, computed from the kernel weights, is smallest are masked.
......@@ -49,14 +49,13 @@ class WeightRankFilterPruner(Pruner):
"""
weight = layer.module.weight.data
op_name = layer.name
op_type = layer.type
assert 0 <= config.get('sparsity') < 1, "sparsity must be in the range [0, 1)"
assert op_type in ['Conv1d', 'Conv2d'], "only Conv1d and Conv2d are supported"
assert op_type in config.get('op_types')
if op_name in self.mask_calculated_ops:
assert op_name in self.mask_dict
return self.mask_dict.get(op_name)
if_calculated = kwargs["if_calculated"]
if if_calculated:
return None
mask_weight = torch.ones(weight.size()).type_as(weight).detach()
if hasattr(layer.module, 'bias') and layer.module.bias is not None:
mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
......@@ -70,8 +69,7 @@ class WeightRankFilterPruner(Pruner):
return mask
mask = self.get_mask(mask, weight, num_prune)
finally:
self.mask_dict.update({op_name: mask})
self.mask_calculated_ops.add(op_name)
if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
return mask
......@@ -259,4 +257,5 @@ class FPGMPruner(WeightRankFilterPruner):
return x.sum()
def update_epoch(self, epoch):
self.mask_calculated_ops = set()
for wrapper in self.get_modules_wrapper():
wrapper.registered_buffers['if_calculated'].copy_(torch.tensor(0)) # pylint: disable=not-callable
......@@ -135,12 +135,11 @@ class CompressorTestCase(TestCase):
model.conv2.weight.data = torch.tensor(w).float()
layer = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
masks = pruner.calc_mask(layer, config_list[0])
masks = pruner.calc_mask(layer, config_list[0], if_calculated=torch.tensor(0))
assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))
pruner.update_epoch(1)
model.conv2.weight.data = torch.tensor(w).float()
masks = pruner.calc_mask(layer, config_list[1])
masks = pruner.calc_mask(layer, config_list[1], if_calculated=torch.tensor(0))
assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))
@tf2
......@@ -159,7 +158,6 @@ class CompressorTestCase(TestCase):
assert all(masks.sum((1)) == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))
pruner.update_epoch(1)
model.layers[2].set_weights([weights[0], weights[1].numpy()])
masks = pruner.calc_mask(layer, config_list[1]).numpy()
masks = masks.reshape((-1, masks.shape[-1])).transpose([1, 0])
......@@ -187,9 +185,9 @@ class CompressorTestCase(TestCase):
model.conv1.weight.data = torch.tensor(w).float()
model.conv2.weight.data = torch.tensor(w).float()
layer1 = torch_compressor.compressor.LayerInfo('conv1', model.conv1)
mask1 = pruner.calc_mask(layer1, config_list[0])
mask1 = pruner.calc_mask(layer1, config_list[0], if_calculated=torch.tensor(0))
layer2 = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
mask2 = pruner.calc_mask(layer2, config_list[1])
mask2 = pruner.calc_mask(layer2, config_list[1], if_calculated=torch.tensor(0))
assert all(torch.sum(mask1['weight'], (1, 2, 3)).numpy() == np.array([0., 27., 27., 27., 27.]))
assert all(torch.sum(mask2['weight'], (1, 2, 3)).numpy() == np.array([0., 0., 0., 27., 27.]))
......@@ -215,9 +213,9 @@ class CompressorTestCase(TestCase):
pruner = torch_compressor.SlimPruner(model, config_list)
layer1 = torch_compressor.compressor.LayerInfo('bn1', model.bn1)
mask1 = pruner.calc_mask(layer1, config_list[0])
mask1 = pruner.calc_mask(layer1, config_list[0], if_calculated=torch.tensor(0))
layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2)
mask2 = pruner.calc_mask(layer2, config_list[0])
mask2 = pruner.calc_mask(layer2, config_list[0], if_calculated=torch.tensor(0))
assert all(mask1['weight'].numpy() == np.array([0., 1., 1., 1., 1.]))
assert all(mask2['weight'].numpy() == np.array([0., 1., 1., 1., 1.]))
assert all(mask1['bias'].numpy() == np.array([0., 1., 1., 1., 1.]))
......@@ -229,9 +227,9 @@ class CompressorTestCase(TestCase):
pruner = torch_compressor.SlimPruner(model, config_list)
layer1 = torch_compressor.compressor.LayerInfo('bn1', model.bn1)
mask1 = pruner.calc_mask(layer1, config_list[0])
mask1 = pruner.calc_mask(layer1, config_list[0], if_calculated=torch.tensor(0))
layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2)
mask2 = pruner.calc_mask(layer2, config_list[0])
mask2 = pruner.calc_mask(layer2, config_list[0], if_calculated=torch.tensor(0))
assert all(mask1['weight'].numpy() == np.array([0., 0., 0., 1., 1.]))
assert all(mask2['weight'].numpy() == np.array([0., 0., 0., 1., 1.]))
assert all(mask1['bias'].numpy() == np.array([0., 0., 0., 1., 1.]))
......@@ -268,14 +266,14 @@ class CompressorTestCase(TestCase):
# test ema
x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
out = model.relu(x)
assert math.isclose(model.relu.tracked_min_biased, 0, abs_tol=eps)
assert math.isclose(model.relu.tracked_max_biased, 0.002, abs_tol=eps)
assert math.isclose(model.relu.module.tracked_min_biased, 0, abs_tol=eps)
assert math.isclose(model.relu.module.tracked_max_biased, 0.002, abs_tol=eps)
quantizer.step()
x = torch.tensor([[0.2, 0.4], [0.6, 0.8]])
out = model.relu(x)
assert math.isclose(model.relu.tracked_min_biased, 0.002, abs_tol=eps)
assert math.isclose(model.relu.tracked_max_biased, 0.00998, abs_tol=eps)
assert math.isclose(model.relu.module.tracked_min_biased, 0.002, abs_tol=eps)
assert math.isclose(model.relu.module.tracked_max_biased, 0.00998, abs_tol=eps)
if __name__ == '__main__':
......