Commit ed121315 authored by chicm-ms, committed by GitHub

Merge pull request #2022 from microsoft/dev-pruner-dataparallel

Dev pruner DataParallel
parents c7187946 8092c8bd
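
Editor's note: at a glance, the flow this merge enables is to compress a model with a pruner and then hand the wrapped model to nn.DataParallel. A minimal sketch, assuming an NNI build containing this commit; the tiny Sequential model and the LevelPruner config are illustrative, not part of the diff:

    import torch
    import torch.nn as nn
    from nni.compression.torch import LevelPruner

    # Illustrative model; any weighted module works.
    model = nn.Sequential(nn.Conv2d(1, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))
    pruner = LevelPruner(model, [{'sparsity': 0.5, 'op_types': ['Conv2d']}])
    model = pruner.compress()            # pruned layers become wrapper modules
    if torch.cuda.device_count() > 1:    # masks live in registered buffers,
        model = nn.DataParallel(model)   # so replicas stay consistent
    model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
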
 import math
+import argparse
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torchvision import datasets, transforms
-from nni.compression.torch import L1FilterPruner
+from nni.compression.torch import ActivationMeanRankFilterPruner
 from models.cifar10.vgg import VGG
@@ -40,6 +41,12 @@ def test(model, device, test_loader):
 def main():
+    parser = argparse.ArgumentParser("multiple gpu with pruning")
+    parser.add_argument("--epochs", type=int, default=160)
+    parser.add_argument("--retrain", default=False, action="store_true")
+    parser.add_argument("--parallel", default=False, action="store_true")
+    args = parser.parse_args()
     torch.manual_seed(0)
     device = torch.device('cuda')
     train_loader = torch.utils.data.DataLoader(
@@ -63,10 +70,11 @@ def main():
     model.to(device)
     # Train the base VGG-16 model
+    if args.retrain:
         print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
         optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
         lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 160, 0)
-        for epoch in range(160):
+        for epoch in range(args.epochs):
             train(model, device, train_loader, optimizer)
             test(model, device, test_loader)
             lr_scheduler.step(epoch)
@@ -88,8 +96,16 @@ def main():
     # Prune model and test accuracy without fine tuning.
     print('=' * 10 + 'Test on the pruned model before fine tune' + '=' * 10)
-    pruner = L1FilterPruner(model, configure_list)
+    pruner = ActivationMeanRankFilterPruner(model, configure_list)
     model = pruner.compress()
+    if args.parallel:
+        if torch.cuda.device_count() > 1:
+            print("use {} gpus for pruning".format(torch.cuda.device_count()))
+            model = nn.DataParallel(model)
+        else:
+            print("only detect 1 gpu, fall back")
+    model.to(device)
     test(model, device, test_loader)
     # top1 = 88.19%
...
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from torchvision import datasets, transforms
 from nni.compression.torch import FPGMPruner
@@ -6,10 +7,10 @@ from nni.compression.torch import FPGMPruner
 class Mnist(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
-        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
-        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
-        self.fc2 = torch.nn.Linear(500, 10)
+        self.conv1 = nn.Conv2d(1, 20, 5, 1)
+        self.conv2 = nn.Conv2d(20, 50, 5, 1)
+        self.fc1 = nn.Linear(4 * 4 * 50, 500)
+        self.fc2 = nn.Linear(500, 10)

     def forward(self, x):
         x = F.relu(self.conv1(x))
@@ -27,8 +28,14 @@ class Mnist(torch.nn.Module):
         return num_zero_filters, num_filters, float(num_zero_filters) / num_filters

     def print_conv_filter_sparsity(self):
+        if isinstance(self.conv1, nn.Conv2d):
             conv1_data = self._get_conv_weight_sparsity(self.conv1)
             conv2_data = self._get_conv_weight_sparsity(self.conv2)
+        else:
+            # self.conv1 is wrapped as PrunerModuleWrapper
+            conv1_data = self._get_conv_weight_sparsity(self.conv1.module)
+            conv2_data = self._get_conv_weight_sparsity(self.conv2.module)
         print('conv1: num zero filters: {}, num filters: {}, sparsity: {:.4f}'.format(conv1_data[0], conv1_data[1], conv1_data[2]))
         print('conv2: num zero filters: {}, num filters: {}, sparsity: {:.4f}'.format(conv2_data[0], conv2_data[1], conv2_data[2]))
...
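
Editor's note: the isinstance branch added above exists because pruner.compress() replaces each pruned layer with a PrunerModuleWrapper (defined later in this diff) whose original layer is kept at .module. A sketch of the access pattern, assuming model is the Mnist instance from this example and the imports above:

    # Before compression self.conv1 is an nn.Conv2d; afterwards it is a
    # PrunerModuleWrapper, and the raw convolution sits at .module.
    conv1 = model.conv1 if isinstance(model.conv1, nn.Conv2d) else model.conv1.module
    print(type(conv1).__name__)  # 'Conv2d' either way
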
@@ -71,6 +71,8 @@ if __name__ == '__main__':
     pruner = LotteryTicketPruner(model, configure_list, optimizer)
     pruner.compress()
+    #model = nn.DataParallel(model)

     for i in pruner.get_prune_iterations():
         pruner.prune_iteration_start()
         loss = 0
...
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.data
+import torchvision.datasets as datasets
+import torchvision.transforms as transforms
+from nni.compression.torch import SlimPruner
+
+class fc1(nn.Module):
+    def __init__(self, num_classes=10):
+        super(fc1, self).__init__()
+        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
+        self.bn1 = nn.BatchNorm2d(32)
+        self.relu1 = nn.ReLU(inplace=True)
+        self.linear1 = nn.Linear(32 * 28 * 28, 300)
+        self.relu2 = nn.ReLU(inplace=True)
+        self.linear2 = nn.Linear(300, 100)
+        self.relu3 = nn.ReLU(inplace=True)
+        self.linear3 = nn.Linear(100, num_classes)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu1(x)
+        x = torch.flatten(x, 1)
+        x = self.linear1(x)
+        x = self.relu2(x)
+        x = self.linear2(x)
+        x = self.relu3(x)
+        x = self.linear3(x)
+        return x
+
+def train(model, train_loader, optimizer, criterion, device):
+    model.train()
+    for imgs, targets in train_loader:
+        optimizer.zero_grad()
+        imgs, targets = imgs.to(device), targets.to(device)
+        output = model(imgs)
+        train_loss = criterion(output, targets)
+        train_loss.backward()
+        optimizer.step()
+    return train_loss.item()
+
+def test(model, test_loader, criterion, device):
+    model.eval()
+    test_loss = 0
+    correct = 0
+    with torch.no_grad():
+        for data, target in test_loader:
+            data, target = data.to(device), target.to(device)
+            output = model(data)
+            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
+            pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
+            correct += pred.eq(target.data.view_as(pred)).sum().item()
+    test_loss /= len(test_loader.dataset)
+    accuracy = 100. * correct / len(test_loader.dataset)
+    return accuracy
+
+if __name__ == '__main__':
+    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
+    traindataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
+    testdataset = datasets.MNIST('./data', train=False, transform=transform)
+    train_loader = torch.utils.data.DataLoader(traindataset, batch_size=60, shuffle=True, num_workers=10, drop_last=False)
+    test_loader = torch.utils.data.DataLoader(testdataset, batch_size=60, shuffle=False, num_workers=10, drop_last=True)
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = fc1()
+    criterion = nn.CrossEntropyLoss()
+
+    configure_list = [{
+        'prune_iterations': 5,
+        'sparsity': 0.86,
+        'op_types': ['BatchNorm2d']
+    }]
+    pruner = SlimPruner(model, configure_list)
+    pruner.compress()
+
+    if torch.cuda.device_count() > 1:
+        model = nn.DataParallel(model)
+    model.to(device)
+    optimizer = torch.optim.Adam(model.parameters(), lr=1.2e-3)
+
+    for name, par in model.named_parameters():
+        print(name)
+
+    # for i in pruner.get_prune_iterations():
+    #     pruner.prune_iteration_start()
+    loss = 0
+    accuracy = 0
+    for epoch in range(10):
+        loss = train(model, train_loader, optimizer, criterion, device)
+        accuracy = test(model, test_loader, criterion, device)
+        print('current epoch: {0}, loss: {1}, accuracy: {2}'.format(epoch, loss, accuracy))
+        # print('prune iteration: {0}, loss: {1}, accuracy: {2}'.format(i, loss, accuracy))
+    pruner.export_model('model.pth', 'mask.pth')
\ No newline at end of file
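
Editor's note: the named_parameters() loop in this new example makes the wrapper structure visible. When more than one GPU is visible, DataParallel contributes a leading 'module.' prefix, and PrunerModuleWrapper contributes an inner '.module' on wrapped layers (here only bn1, given the BatchNorm2d op_types config). Roughly (illustrative; exact names depend on the model):

    #   module.conv1.weight
    #   module.bn1.module.weight
    #   module.bn1.module.bias
    #   module.linear1.weight
    #   ...
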
 import math
+import argparse
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -6,7 +7,6 @@ from torchvision import datasets, transforms
 from nni.compression.torch import SlimPruner
 from models.cifar10.vgg import VGG

 def updateBN(model):
     for m in model.modules():
         if isinstance(m, nn.BatchNorm2d):
@@ -49,6 +49,13 @@ def test(model, device, test_loader):
 def main():
+    parser = argparse.ArgumentParser("multiple gpu with pruning")
+    parser.add_argument("--epochs", type=int, default=160)
+    parser.add_argument("--retrain", default=False, action="store_true")
+    parser.add_argument("--parallel", default=False, action="store_true")
+    args = parser.parse_args()
     torch.manual_seed(0)
     device = torch.device('cuda')
     train_loader = torch.utils.data.DataLoader(
@@ -70,15 +77,16 @@ def main():
     model = VGG(depth=19)
     model.to(device)
     # Train the base VGG-19 model
+    if args.retrain:
         print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
-        epochs = 160
+        epochs = args.epochs
         optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
         for epoch in range(epochs):
             if epoch in [epochs * 0.5, epochs * 0.75]:
                 for param_group in optimizer.param_groups:
                     param_group['lr'] *= 0.1
+            print("epoch {}".format(epoch))
             train(model, device, train_loader, optimizer, True)
             test(model, device, test_loader)
         torch.save(model.state_dict(), 'vgg19_cifar10.pth')
@@ -99,9 +107,14 @@ def main():
     print('=' * 10 + 'Test the pruned model before fine tune' + '=' * 10)
     pruner = SlimPruner(model, configure_list)
     model = pruner.compress()
-    test(model, device, test_loader)
-    # top1 = 93.55%
+    if args.parallel:
+        if torch.cuda.device_count() > 1:
+            print("use {} gpus for pruning".format(torch.cuda.device_count()))
+            model = nn.DataParallel(model)
+            # model = nn.DataParallel(model, device_ids=[0, 1])
+        else:
+            print("only detect 1 gpu, fall back")
+    model.to(device)

     # Fine tune the pruned model for 40 epochs and test accuracy
     print('=' * 10 + 'Fine tuning' + '=' * 10)
     optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
...
@@ -32,7 +32,7 @@ class ActivationRankFilterPruner(Pruner):
         """
         super().__init__(model, config_list)
-        self.mask_calculated_ops = set()
+        self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
         self.statistics_batch_num = statistics_batch_num
         self.collected_activation = {}
         self.hooks = {}
@@ -48,22 +48,29 @@ class ActivationRankFilterPruner(Pruner):
         """
         Compress the model, register a hook for collecting activations.
         """
+        if self.modules_wrapper is not None:
+            # already compressed
+            return self.bound_model
+        else:
+            self.modules_wrapper = []
         modules_to_compress = self.detect_modules_to_compress()
         for layer, config in modules_to_compress:
-            self._instrument_layer(layer, config)
+            wrapper = self._wrap_modules(layer, config)
+            self.modules_wrapper.append(wrapper)
             self.collected_activation[layer.name] = []

             def _hook(module_, input_, output, name=layer.name):
                 if len(self.collected_activation[name]) < self.statistics_batch_num:
                     self.collected_activation[name].append(self.activation(output.detach().cpu()))
-            layer.module.register_forward_hook(_hook)
+            wrapper.module.register_forward_hook(_hook)
+        self._wrap_model()
         return self.bound_model

     def get_mask(self, base_mask, activations, num_prune):
         raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Calculate the mask of given layer.
         Filters with the smallest importance criterion which is calculated from the activation are masked.
@@ -82,14 +89,13 @@ class ActivationRankFilterPruner(Pruner):
         """
         weight = layer.module.weight.data
-        op_name = layer.name
         op_type = layer.type
         assert 0 <= config.get('sparsity') < 1, "sparsity must in the range [0, 1)"
         assert op_type in ['Conv2d'], "only support Conv2d"
         assert op_type in config.get('op_types')
-        if op_name in self.mask_calculated_ops:
-            assert op_name in self.mask_dict
-            return self.mask_dict.get(op_name)
+        if_calculated = kwargs["if_calculated"]
+        if if_calculated:
+            return None
         mask_weight = torch.ones(weight.size()).type_as(weight).detach()
         if hasattr(layer.module, 'bias') and layer.module.bias is not None:
             mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
@@ -104,8 +110,7 @@ class ActivationRankFilterPruner(Pruner):
             mask = self.get_mask(mask, self.collected_activation[layer.name], num_prune)
         finally:
             if len(self.collected_activation[layer.name]) == self.statistics_batch_num:
-                self.mask_dict.update({op_name: mask})
-                self.mask_calculated_ops.add(op_name)
+                if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
         return mask
...
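
Editor's note: the if_calculated flag above is registered as a tensor buffer rather than a Python bool so that in-place copy_ updates are visible wherever the buffer is referenced, and so nn.DataParallel can replicate it along with the module. A standalone sketch of that buffer pattern (illustrative class name):

    import torch

    class Flagged(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # a tensor buffer travels with the module: state_dict, .to(device),
            # and nn.DataParallel replication all handle it automatically
            self.register_buffer('if_calculated', torch.tensor(0))

    flag = Flagged()
    if not flag.if_calculated:
        # ... compute the mask once, as the pruners above do ...
        flag.if_calculated.copy_(torch.tensor(1))  # in-place flip, like calc_mask
    print(bool(flag.if_calculated))  # True
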
@@ -14,8 +14,11 @@ class LayerInfo:
         self.name = name
         self.type = type(module).__name__

-        self._forward = None
+def _setattr(model, name, module):
+    name_list = name.split(".")
+    for name in name_list[:-1]:
+        model = getattr(model, name)
+    setattr(model, name_list[-1], module)

 class Compressor:
     """
@@ -36,6 +39,9 @@ class Compressor:
         self.bound_model = model
         self.config_list = config_list
         self.modules_to_compress = None
+        self.modules_wrapper = None
+        self.buffers = {}
+        self.is_wrapped = False

     def detect_modules_to_compress(self):
         """
@@ -51,21 +57,60 @@ class Compressor:
                     self.modules_to_compress.append((layer, config))
         return self.modules_to_compress

+    def _wrap_model(self):
+        """
+        wrap all modules that needed to be compressed
+        """
+        for wrapper in reversed(self.get_modules_wrapper()):
+            _setattr(self.bound_model, wrapper.name, wrapper)
+        self.is_wrapped = True
+
+    def _unwrap_model(self):
+        """
+        unwrap all modules that needed to be compressed
+        """
+        for wrapper in self.get_modules_wrapper():
+            _setattr(self.bound_model, wrapper.name, wrapper.module)
+        self.is_wrapped = False
+
     def compress(self):
         """
         Compress the model with algorithm implemented by subclass.
         The model will be instrumented and user should never edit it after calling this method.
         `self.modules_to_compress` records all the to-be-compressed layers
+
+        Returns
+        -------
+        torch.nn.Module
+            model with specified modules compressed.
         """
+        if self.modules_wrapper is not None:
+            # already compressed
+            return self.bound_model
+        else:
+            self.modules_wrapper = []
         modules_to_compress = self.detect_modules_to_compress()
         for layer, config in modules_to_compress:
-            self._instrument_layer(layer, config)
+            wrapper = self._wrap_modules(layer, config)
+            self.modules_wrapper.append(wrapper)
+        self._wrap_model()
         return self.bound_model

+    def register_buffer(self, name, value):
+        """
+        To register buffers used in wrapped module's forward method.
+        """
+        self.buffers[name] = value
+
     def get_modules_to_compress(self):
         """
-        To obtain all the to-be-compressed layers.
+        To obtain all the to-be-compressed modules.

         Returns
         -------
@@ -75,6 +120,17 @@ class Compressor:
         """
         return self.modules_to_compress

+    def get_modules_wrapper(self):
+        """
+        To obtain all the wrapped modules.
+
+        Returns
+        -------
+        list
+            a list of the wrapped modules
+        """
+        return self.modules_wrapper
+
     def select_config(self, layer):
         """
         Find the configuration for `layer` by parsing `self.config_list`
@@ -119,7 +175,7 @@ class Compressor:
         If user want to update model every step, user can override this method
         """

-    def _instrument_layer(self, layer, config):
+    def _wrap_modules(self, layer, config):
         """
         This method is implemented in the subclasses, i.e., `Pruner` and `Quantizer`
@@ -143,6 +199,59 @@ class Compressor:
                 expanded_op_types.append(op_type)
         return expanded_op_types
+class PrunerModuleWrapper(torch.nn.Module):
+    def __init__(self, module, module_name, module_type, config, pruner):
+        """
+        Wrap a module to enable data parallel, forward method customization and buffer registration.
+
+        Parameters
+        ----------
+        module : pytorch module
+            the module user wants to compress
+        config : dict
+            the configurations that users specify for compression
+        module_name : str
+            the name of the module to compress, wrapper module shares same name
+        module_type : str
+            the type of the module to compress
+        pruner : Pruner
+            the pruner used to calculate mask
+        """
+        super().__init__()
+        # origin layer information
+        self.module = module
+        self.name = module_name
+        self.type = module_type
+        # config and pruner
+        self.config = config
+        self.pruner = pruner
+        self.registered_buffers = {}
+        # register buffer for mask
+        self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
+        self.registered_buffers['weight_mask'] = self.weight_mask
+        if hasattr(self.module, 'bias') and self.module.bias is not None:
+            self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
+        else:
+            self.register_buffer("bias_mask", None)
+        self.registered_buffers['bias_mask'] = self.bias_mask
+        # register user specified buffer
+        for name in self.pruner.buffers:
+            self.register_buffer(name, self.pruner.buffers[name].clone())
+            self.registered_buffers[name] = getattr(self, name)
+
+    def forward(self, *inputs):
+        mask = self.pruner.calc_mask(LayerInfo(self.name, self.module), self.config, **self.registered_buffers)
+        if mask is not None:
+            self.weight_mask.copy_(mask['weight'])
+        # apply mask to weight
+        self.module.weight.data = self.module.weight.data.mul_(self.weight_mask)
+        # apply mask to bias
+        if hasattr(self.module, 'bias') and self.module.bias is not None:
+            if mask is not None and 'bias' in mask:
+                self.bias_mask.copy_(mask['bias'])
+            self.module.bias.data = self.module.bias.data.mul_(self.bias_mask)
+        return self.module(*inputs)
+
 class Pruner(Compressor):
     """
@@ -158,9 +267,8 @@ class Pruner(Compressor):

     def __init__(self, model, config_list):
         super().__init__(model, config_list)
-        self.mask_dict = {}

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Pruners should overload this method to provide mask for weight tensors.
         The mask must have the same shape and type comparing to the weight.
@@ -176,9 +284,9 @@ class Pruner(Compressor):
         """
         raise NotImplementedError("Pruners must overload calc_mask()")

-    def _instrument_layer(self, layer, config):
+    def _wrap_modules(self, layer, config):
         """
-        Create a wrapper forward function to replace the original one.
+        Create a wrapper module to replace the original one.

         Parameters
         ----------
@@ -187,30 +295,13 @@ class Pruner(Compressor):
         config : dict
             the configuration for generating the mask
         """
-        assert layer._forward is None, 'Each model can only be compressed once'
-        if not _check_weight(layer.module):
-            _logger.warning('Module %s does not have parameter "weight"', layer.name)
-            return
-        layer._forward = layer.module.forward
-
-        def new_forward(*inputs):
-            mask = self.calc_mask(layer, config)
-            # apply mask to weight
-            old_weight = layer.module.weight.data
-            mask_weight = mask['weight']
-            layer.module.weight.data = old_weight.mul(mask_weight)
-            # apply mask to bias
-            if mask.__contains__('bias') and hasattr(layer.module, 'bias') and layer.module.bias is not None:
-                old_bias = layer.module.bias.data
-                mask_bias = mask['bias']
-                layer.module.bias.data = old_bias.mul(mask_bias)
-            # calculate forward
-            ret = layer._forward(*inputs)
-            return ret
-
-        layer.module.forward = new_forward
+        _logger.info("compressing module %s.", layer.name)
+        wrapper = PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)
+        assert hasattr(layer.module, 'weight')
+        wrapper.to(layer.module.weight.device)
+        return wrapper
-    def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None):
+    def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None, device=None):
         """
         Export pruned model weights, masks and onnx model(optional)
@@ -224,35 +315,138 @@ class Pruner(Compressor):
             (optional) path to save onnx model
         input_shape : list or tuple
             input shape to onnx model
+        device : torch.device
+            device of the model, used to place the dummy input tensor for exporting onnx file.
+            the tensor is placed on cpu if ```device``` is None
         """
-        if self.detect_modules_to_compress() and not self.mask_dict:
-            _logger.warning('You may not use self.mask_dict in base Pruner class to record masks')
+        # if self.detect_modules_to_compress() and not self.mask_dict:
+        #     _logger.warning('You may not use self.mask_dict in base Pruner class to record masks')
         assert model_path is not None, 'model_path must be specified'
-        for name, m in self.bound_model.named_modules():
-            if name == "":
-                continue
-            masks = self.mask_dict.get(name)
-            if masks is not None:
-                mask_sum = masks['weight'].sum().item()
-                mask_num = masks['weight'].numel()
-                _logger.info('Layer: %s Sparsity: %.2f', name, 1 - mask_sum / mask_num)
-                m.weight.data = m.weight.data.mul(masks['weight'])
-                if masks.__contains__('bias') and hasattr(m, 'bias') and m.bias is not None:
-                    m.bias.data = m.bias.data.mul(masks['bias'])
-            else:
-                _logger.info('Layer: %s NOT compressed', name)
+        mask_dict = {}
+        self._unwrap_model()  # used for generating correct state_dict name without wrapper state
+
+        for wrapper in self.get_modules_wrapper():
+            weight_mask = wrapper.weight_mask
+            bias_mask = wrapper.bias_mask
+            if weight_mask is not None:
+                mask_sum = weight_mask.sum().item()
+                mask_num = weight_mask.numel()
+                _logger.info('Layer: %s Sparsity: %.2f', wrapper.name, 1 - mask_sum / mask_num)
+                wrapper.module.weight.data = wrapper.module.weight.data.mul(weight_mask)
+            if bias_mask is not None:
+                wrapper.module.bias.data = wrapper.module.bias.data.mul(bias_mask)
+            # save mask to dict
+            mask_dict[wrapper.name] = {"weight": weight_mask, "bias": bias_mask}
+
         torch.save(self.bound_model.state_dict(), model_path)
         _logger.info('Model state_dict saved to %s', model_path)
         if mask_path is not None:
-            torch.save(self.mask_dict, mask_path)
+            torch.save(mask_dict, mask_path)
             _logger.info('Mask dict saved to %s', mask_path)
         if onnx_path is not None:
             assert input_shape is not None, 'input_shape must be specified to export onnx model'
             # input info needed
+            if device is None:
+                device = torch.device('cpu')
             input_data = torch.Tensor(*input_shape)
-            torch.onnx.export(self.bound_model, input_data, onnx_path)
+            torch.onnx.export(self.bound_model, input_data.to(device), onnx_path)
             _logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)
+
+        self._wrap_model()
+
+    def load_model_state_dict(self, model_state):
+        """
+        Load the state dict saved from unwrapped model.
+
+        Parameters
+        ----------
+        model_state : dict
+            state dict saved from unwrapped model
+        """
+        if self.is_wrapped:
+            self._unwrap_model()
+            self.bound_model.load_state_dict(model_state)
+            self._wrap_model()
+        else:
+            self.bound_model.load_state_dict(model_state)
+
+class QuantizerModuleWrapper(torch.nn.Module):
+    def __init__(self, module, module_name, module_type, config, quantizer):
+        """
+        Wrap a module to enable data parallel, forward method customization and buffer registration.
+
+        Parameters
+        ----------
+        module : pytorch module
+            the module user wants to compress
+        config : dict
+            the configurations that users specify for compression
+        module_name : str
+            the name of the module to compress, wrapper module shares same name
+        module_type : str
+            the type of the module to compress
+        quantizer : Quantizer
+            the quantizer used to calculate mask
+        """
+        super().__init__()
+        # origin layer information
+        self.module = module
+        self.name = module_name
+        self.type = module_type
+        # config and quantizer
+        self.config = config
+        self.quantizer = quantizer
+
+        # register buffer and parameter
+        # old_weight is used to store origin weight and weight is used to store quantized weight
+        # the reason why weight is buffer instead of parameter is because in pytorch parameter is used as leaf
+        # if weight is leaf, then old_weight can not be updated.
+        if 'weight' in config['quant_types']:
+            if not _check_weight(self.module):
+                _logger.warning('Module %s does not have parameter "weight"', self.name)
+            else:
+                self.module.register_parameter('old_weight', torch.nn.Parameter(self.module.weight))
+                delattr(self.module, 'weight')
+                self.module.register_buffer('weight', self.module.old_weight)
+
+        # register user specified buffer
+        self.registered_buffers = {}
+        for name in self.quantizer.buffers:
+            self.register_buffer(name, self.quantizer.buffers[name].clone())
+            self.registered_buffers[name] = getattr(self, name)
+
+    def forward(self, *inputs):
+        if 'input' in self.config['quant_types']:
+            inputs = self.quantizer.quant_grad.apply(
+                inputs,
+                QuantType.QUANT_INPUT,
+                self.quantizer.quantize_input,
+                self.config,
+                LayerInfo(self.name, self.module),
+                **self.registered_buffers)
+
+        if 'weight' in self.config['quant_types'] and _check_weight(self.module):
+            new_weight = self.quantizer.quant_grad.apply(
+                self.module.old_weight,
+                QuantType.QUANT_WEIGHT,
+                self.quantizer.quantize_weight,
+                self.config,
+                LayerInfo(self.name, self.module),
+                **self.registered_buffers)
+            self.module.weight = new_weight
+            result = self.module(*inputs)
+        else:
+            result = self.module(*inputs)
+
+        if 'output' in self.config['quant_types']:
+            result = self.quantizer.quant_grad.apply(
+                result,
+                QuantType.QUANT_OUTPUT,
+                self.quantizer.quantize_output,
+                self.config,
+                LayerInfo(self.name, self.module),
+                **self.registered_buffers)
+        return result
+
 class Quantizer(Compressor):
     """
@@ -303,7 +497,7 @@ class Quantizer(Compressor):
         raise NotImplementedError('Quantizer must overload quantize_input()')

-    def _instrument_layer(self, layer, config):
+    def _wrap_modules(self, layer, config):
         """
         Create a wrapper forward function to replace the original one.

         Parameters
@@ -313,7 +507,6 @@ class Quantizer(Compressor):
         config : dict
             the configuration for quantization
         """
-        assert layer._forward is None, 'Each model can only be compressed once'
         assert 'quant_types' in config, 'must provide quant_types in config'
         assert isinstance(config['quant_types'], list), 'quant_types must be list type'
         assert 'quant_bits' in config, 'must provide quant_bits in config'
@@ -323,35 +516,7 @@ class Quantizer(Compressor):
         for quant_type in config['quant_types']:
             assert quant_type in config['quant_bits'], 'bits length for %s must be specified in quant_bits dict' % quant_type

-        if 'weight' in config['quant_types']:
-            if not _check_weight(layer.module):
-                _logger.warning('Module %s does not have parameter "weight"', layer.name)
-            else:
-                # old_weight is used to store origin weight and weight is used to store quantized weight
-                # the reason why weight is buffer instead of parameter is because in pytorch parameter is used as leaf
-                # if weight is leaf, then old_weight can not be updated.
-                layer.module.register_parameter('old_weight', torch.nn.Parameter(layer.module.weight))
-                delattr(layer.module, 'weight')
-                layer.module.register_buffer('weight', layer.module.old_weight)
-        layer._forward = layer.module.forward
-
-        def new_forward(*inputs):
-            if 'input' in config['quant_types']:
-                inputs = self.quant_grad.apply(inputs, QuantType.QUANT_INPUT, self.quantize_input, config, layer)
-            if 'weight' in config['quant_types'] and _check_weight(layer.module):
-                new_weight = self.quant_grad.apply(layer.module.old_weight, QuantType.QUANT_WEIGHT, self.quantize_weight, config, layer)
-                layer.module.weight = new_weight
-                result = layer._forward(*inputs)
-            else:
-                result = layer._forward(*inputs)
-            if 'output' in config['quant_types']:
-                result = self.quant_grad.apply(result, QuantType.QUANT_OUTPUT, self.quantize_output, config, layer)
-            return result
-
-        layer.module.forward = new_forward
+        return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self)

 class QuantType:
     """
@@ -387,19 +552,18 @@ class QuantGrad(torch.autograd.Function):
         return grad_output

     @staticmethod
-    def forward(ctx, tensor, quant_type, quant_func, config, layer):
+    def forward(ctx, tensor, quant_type, quant_func, config, layer, **kwargs):
         ctx.save_for_backward(tensor, torch.Tensor([quant_type]))
-        return quant_func(tensor, config, op=layer.module, op_type=layer.type, op_name=layer.name)
+        return quant_func(tensor, config, op=layer.module, op_type=layer.type, op_name=layer.name, **kwargs)

     @classmethod
     def backward(cls, ctx, grad_output):
         tensor, quant_type = ctx.saved_variables
         output = cls.quant_backward(tensor, grad_output, quant_type)
-        return output, None, None, None, None
+        return output, None, None, None, None, None

 def _check_weight(module):
     try:
         return isinstance(module.weight.data, torch.Tensor)
     except AttributeError:
         return False
\ No newline at end of file
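
Editor's note: one consequence of the wrapper design is that state_dict keys change while a model is wrapped; export_model therefore unwraps before saving, and load_model_state_dict mirrors that when restoring. A usage sketch, assuming pruner is any Pruner from this file after compress(); the paths and MNIST-like input shape are placeholders:

    # Export: unwraps internally so state_dict keys carry no wrapper prefixes,
    # multiplies masks into the weights, then re-wraps.
    pruner.export_model(model_path='pruned.pth', mask_path='mask.pth',
                        onnx_path='pruned.onnx', input_shape=(1, 1, 28, 28),
                        device=torch.device('cpu'))  # new arg: dummy-input device

    # Restore a state dict that was saved from the unwrapped model:
    pruner.load_model_state_dict(torch.load('pruned.pth'))
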
@@ -27,9 +27,9 @@ class LevelPruner(Pruner):
         """
         super().__init__(model, config_list)
-        self.mask_calculated_ops = set()
+        self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Calculate the mask of given layer

         Parameters
@@ -45,8 +45,9 @@ class LevelPruner(Pruner):
         """
         weight = layer.module.weight.data
-        op_name = layer.name
-        if op_name not in self.mask_calculated_ops:
+        if_calculated = kwargs["if_calculated"]
+
+        if not if_calculated:
             w_abs = weight.abs()
             k = int(weight.numel() * config['sparsity'])
             if k == 0:
@@ -54,12 +55,10 @@ class LevelPruner(Pruner):
             threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
             mask_weight = torch.gt(w_abs, threshold).type_as(weight)
             mask = {'weight': mask_weight}
-            self.mask_dict.update({op_name: mask})
-            self.mask_calculated_ops.add(op_name)
-        else:
-            assert op_name in self.mask_dict, "op_name not in the mask_dict"
-            mask = self.mask_dict[op_name]
-        return mask
+            if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
+            return mask
+        else:
+            return None

 class AGP_Pruner(Pruner):
@@ -84,17 +83,20 @@ class AGP_Pruner(Pruner):
         super().__init__(model, config_list)
         self.now_epoch = 0
-        self.if_init_list = {}
+        self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
-        Calculate the mask of given layer
+        Calculate the mask of given layer.
+        Scale factors with the smallest absolute value in the BN layer are masked.

         Parameters
         ----------
         layer : LayerInfo
             the layer to instrument the compression operation
         config : dict
             layer's pruning config
+        kwargs: dict
+            buffers registered in __init__ function

         Returns
         -------
         dict
@@ -102,12 +104,16 @@ class AGP_Pruner(Pruner):
         """
         weight = layer.module.weight.data
-        op_name = layer.name
         start_epoch = config.get('start_epoch', 0)
         freq = config.get('frequency', 1)
-        if self.now_epoch >= start_epoch and self.if_init_list.get(op_name, True) \
-                and (self.now_epoch - start_epoch) % freq == 0:
-            mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
+
+        if_calculated = kwargs["if_calculated"]
+        if if_calculated:
+            return None
+        if not (self.now_epoch >= start_epoch and (self.now_epoch - start_epoch) % freq == 0):
+            return None
+
+        mask = {'weight': kwargs['weight_mask'] if 'weight_mask' in kwargs else torch.ones(weight.shape).type_as(weight)}
         target_sparsity = self.compute_target_sparsity(config)
         k = int(weight.numel() * target_sparsity)
         if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
@@ -116,10 +122,8 @@ class AGP_Pruner(Pruner):
         w_abs = weight.abs() * mask['weight']
         threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
         new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
-            self.mask_dict.update({op_name: new_mask})
-            self.if_init_list.update({op_name: False})
-        else:
-            new_mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
+        if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
+
         return new_mask
     def compute_target_sparsity(self, config):
@@ -165,9 +169,8 @@ class AGP_Pruner(Pruner):
         if epoch > 0:
             self.now_epoch = epoch
-            for k in self.if_init_list.keys():
-                self.if_init_list[k] = True
+            for wrapper in self.get_modules_wrapper():
+                wrapper.registered_buffers['if_calculated'].copy_(torch.tensor(0)) # pylint: disable=not-callable

 class SlimPruner(Pruner):
     """
@@ -187,7 +190,6 @@ class SlimPruner(Pruner):
         """
         super().__init__(model, config_list)
-        self.mask_calculated_ops = set()
         weight_list = []
         if len(config_list) > 1:
             logger.warning('Slim pruner only supports 1 configuration')
@@ -198,8 +200,9 @@ class SlimPruner(Pruner):
         all_bn_weights = torch.cat(weight_list)
         k = int(all_bn_weights.shape[0] * config['sparsity'])
         self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
+        self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Calculate the mask of given layer.
         Scale factors with the smallest absolute value in the BN layer are masked.
@@ -209,6 +212,8 @@ class SlimPruner(Pruner):
             the layer to instrument the compression operation
         config : dict
             layer's pruning config
+        kwargs: dict
+            buffers registered in __init__ function

         Returns
         -------
         dict
@@ -216,27 +221,21 @@ class SlimPruner(Pruner):
         """
         weight = layer.module.weight.data
-        op_name = layer.name
         op_type = layer.type
+        if_calculated = kwargs["if_calculated"]
         assert op_type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
-        if op_name in self.mask_calculated_ops:
-            assert op_name in self.mask_dict
-            return self.mask_dict.get(op_name)
+        if if_calculated:
+            return None
         base_mask = torch.ones(weight.size()).type_as(weight).detach()
         mask = {'weight': base_mask.detach(), 'bias': base_mask.clone().detach()}
-        try:
-            filters = weight.size(0)
-            num_prune = int(filters * config.get('sparsity'))
-            if filters < 2 or num_prune < 1:
-                return mask
+        filters = weight.size(0)
+        num_prune = int(filters * config.get('sparsity'))
+        if filters >= 2 and num_prune >= 1:
             w_abs = weight.abs()
             mask_weight = torch.gt(w_abs, self.global_threshold).type_as(weight)
             mask_bias = mask_weight.clone()
             mask = {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
-        finally:
-            self.mask_dict.update({layer.name: mask})
-            self.mask_calculated_ops.add(layer.name)
+
+        if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
         return mask
 class LotteryTicketPruner(Pruner):
@@ -294,38 +293,23 @@ class LotteryTicketPruner(Pruner):
             prune_iterations = config['prune_iterations']
         return prune_iterations

-    def _print_masks(self, print_mask=False):
-        torch.set_printoptions(threshold=1000)
-        for op_name in self.mask_dict.keys():
-            mask = self.mask_dict[op_name]
-            print('op name: ', op_name)
-            if print_mask:
-                print('mask: ', mask)
-            # calculate current sparsity
-            mask_num = mask['weight'].sum().item()
-            mask_size = mask['weight'].numel()
-            print('sparsity: ', 1 - mask_num / mask_size)
-        torch.set_printoptions(profile='default')
-
     def _calc_sparsity(self, sparsity):
         keep_ratio_once = (1 - sparsity) ** (1 / self.prune_iterations)
         curr_keep_ratio = keep_ratio_once ** self.curr_prune_iteration
         return max(1 - curr_keep_ratio, 0)

-    def _calc_mask(self, weight, sparsity, op_name):
+    def _calc_mask(self, weight, sparsity, curr_w_mask):
         if self.curr_prune_iteration == 0:
             mask = torch.ones(weight.shape).type_as(weight)
         else:
             curr_sparsity = self._calc_sparsity(sparsity)
-            assert self.mask_dict.get(op_name) is not None
-            curr_mask = self.mask_dict.get(op_name)
-            w_abs = weight.abs() * curr_mask['weight']
+            w_abs = weight.abs() * curr_w_mask
             k = int(w_abs.numel() * curr_sparsity)
             threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
             mask = torch.gt(w_abs, threshold).type_as(weight)
         return {'weight': mask}

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Generate mask for the given ``weight``.

@@ -335,15 +319,17 @@ class LotteryTicketPruner(Pruner):
             The layer to be pruned
         config : dict
             Pruning configurations for this weight
+        kwargs : dict
+            Auxiliary information

         Returns
         -------
         tensor
-            The mask for this weight
+            The mask for this weight, it is ```None``` because this pruner
+            calculates and assigns masks in ```prune_iteration_start```,
+            no need to do anything in this function.
         """
-        assert self.mask_dict.get(layer.name) is not None, 'Please call iteration_start before training'
-        mask = self.mask_dict[layer.name]
-        return mask
+        return None

     def get_prune_iterations(self):
         """
@@ -368,16 +354,26 @@ class LotteryTicketPruner(Pruner):
         self.curr_prune_iteration += 1
         assert self.curr_prune_iteration < self.prune_iterations + 1, 'Exceed the configured prune_iterations'

+        modules_wrapper = self.get_modules_wrapper()
         modules_to_compress = self.detect_modules_to_compress()
         for layer, config in modules_to_compress:
+            module_wrapper = None
+            for wrapper in modules_wrapper:
+                if wrapper.name == layer.name:
+                    module_wrapper = wrapper
+                    break
+            assert module_wrapper is not None
+
             sparsity = config.get('sparsity')
-            mask = self._calc_mask(layer.module.weight.data, sparsity, layer.name)
-            self.mask_dict.update({layer.name: mask})
-        self._print_masks()
+            mask = self._calc_mask(layer.module.weight.data, sparsity, module_wrapper.weight_mask)
+            # TODO: directly use weight_mask is not good
+            module_wrapper.weight_mask.copy_(mask['weight'])
+            # there is no mask for bias

         # reinit weights back to original after new masks are generated
         if self.reset_weights:
-            self._model.load_state_dict(self._model_state)
+            # should use this member function to reset model weights
+            self.load_model_state_dict(self._model_state)
             self._optimizer.load_state_dict(self._optimizer_state)
             if self._lr_scheduler is not None:
                 self._lr_scheduler.load_state_dict(self._scheduler_state)
@@ -27,12 +27,12 @@ class WeightRankFilterPruner(Pruner):
         """
         super().__init__(model, config_list)
-        self.mask_calculated_ops = set()  # operations whose mask has been calculated
+        self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable

     def get_mask(self, base_mask, weight, num_prune):
         raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Calculate the mask of given layer.
         Filters with the smallest importance criterion of the kernel weights are masked.
@@ -49,14 +49,13 @@ class WeightRankFilterPruner(Pruner):
         """
         weight = layer.module.weight.data
-        op_name = layer.name
         op_type = layer.type
         assert 0 <= config.get('sparsity') < 1, "sparsity must in the range [0, 1)"
         assert op_type in ['Conv1d', 'Conv2d'], "only support Conv1d and Conv2d"
         assert op_type in config.get('op_types')
-        if op_name in self.mask_calculated_ops:
-            assert op_name in self.mask_dict
-            return self.mask_dict.get(op_name)
+        if_calculated = kwargs["if_calculated"]
+        if if_calculated:
+            return None
         mask_weight = torch.ones(weight.size()).type_as(weight).detach()
         if hasattr(layer.module, 'bias') and layer.module.bias is not None:
             mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
@@ -70,8 +69,7 @@ class WeightRankFilterPruner(Pruner):
                 return mask
             mask = self.get_mask(mask, weight, num_prune)
         finally:
-            self.mask_dict.update({op_name: mask})
-            self.mask_calculated_ops.add(op_name)
+            if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
         return mask
@@ -259,4 +257,5 @@ class FPGMPruner(WeightRankFilterPruner):
         return x.sum()

     def update_epoch(self, epoch):
-        self.mask_calculated_ops = set()
+        for wrapper in self.get_modules_wrapper():
+            wrapper.registered_buffers['if_calculated'].copy_(torch.tensor(0)) # pylint: disable=not-callable
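
Editor's note: for epoch-driven pruners, update_epoch now resets every wrapper's if_calculated buffer instead of clearing a Python set, so masks are recomputed on the next forward pass. A training-loop sketch under that assumption, where train/test are the usual helpers from the examples earlier in this diff:

    for epoch in range(num_epochs):
        pruner.update_epoch(epoch)   # clears each wrapper's if_calculated flag
        train(model, device, train_loader, optimizer)
        test(model, device, test_loader)
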
@@ -135,12 +135,11 @@ class CompressorTestCase(TestCase):
         model.conv2.weight.data = torch.tensor(w).float()
         layer = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
-        masks = pruner.calc_mask(layer, config_list[0])
+        masks = pruner.calc_mask(layer, config_list[0], if_calculated=torch.tensor(0))
         assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))

-        pruner.update_epoch(1)
         model.conv2.weight.data = torch.tensor(w).float()
-        masks = pruner.calc_mask(layer, config_list[1])
+        masks = pruner.calc_mask(layer, config_list[1], if_calculated=torch.tensor(0))
         assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))

     @tf2
@@ -159,7 +158,6 @@ class CompressorTestCase(TestCase):
         assert all(masks.sum((1)) == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))

-        pruner.update_epoch(1)
         model.layers[2].set_weights([weights[0], weights[1].numpy()])
         masks = pruner.calc_mask(layer, config_list[1]).numpy()
         masks = masks.reshape((-1, masks.shape[-1])).transpose([1, 0])
@@ -187,9 +185,9 @@ class CompressorTestCase(TestCase):
         model.conv1.weight.data = torch.tensor(w).float()
         model.conv2.weight.data = torch.tensor(w).float()
         layer1 = torch_compressor.compressor.LayerInfo('conv1', model.conv1)
-        mask1 = pruner.calc_mask(layer1, config_list[0])
+        mask1 = pruner.calc_mask(layer1, config_list[0], if_calculated=torch.tensor(0))
         layer2 = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
-        mask2 = pruner.calc_mask(layer2, config_list[1])
+        mask2 = pruner.calc_mask(layer2, config_list[1], if_calculated=torch.tensor(0))
         assert all(torch.sum(mask1['weight'], (1, 2, 3)).numpy() == np.array([0., 27., 27., 27., 27.]))
         assert all(torch.sum(mask2['weight'], (1, 2, 3)).numpy() == np.array([0., 0., 0., 27., 27.]))
@@ -215,9 +213,9 @@ class CompressorTestCase(TestCase):
         pruner = torch_compressor.SlimPruner(model, config_list)

         layer1 = torch_compressor.compressor.LayerInfo('bn1', model.bn1)
-        mask1 = pruner.calc_mask(layer1, config_list[0])
+        mask1 = pruner.calc_mask(layer1, config_list[0], if_calculated=torch.tensor(0))
         layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2)
-        mask2 = pruner.calc_mask(layer2, config_list[0])
+        mask2 = pruner.calc_mask(layer2, config_list[0], if_calculated=torch.tensor(0))
         assert all(mask1['weight'].numpy() == np.array([0., 1., 1., 1., 1.]))
         assert all(mask2['weight'].numpy() == np.array([0., 1., 1., 1., 1.]))
         assert all(mask1['bias'].numpy() == np.array([0., 1., 1., 1., 1.]))
@@ -229,9 +227,9 @@ class CompressorTestCase(TestCase):
         pruner = torch_compressor.SlimPruner(model, config_list)

         layer1 = torch_compressor.compressor.LayerInfo('bn1', model.bn1)
-        mask1 = pruner.calc_mask(layer1, config_list[0])
+        mask1 = pruner.calc_mask(layer1, config_list[0], if_calculated=torch.tensor(0))
         layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2)
-        mask2 = pruner.calc_mask(layer2, config_list[0])
+        mask2 = pruner.calc_mask(layer2, config_list[0], if_calculated=torch.tensor(0))
         assert all(mask1['weight'].numpy() == np.array([0., 0., 0., 1., 1.]))
         assert all(mask2['weight'].numpy() == np.array([0., 0., 0., 1., 1.]))
         assert all(mask1['bias'].numpy() == np.array([0., 0., 0., 1., 1.]))
@@ -268,14 +266,14 @@ class CompressorTestCase(TestCase):
         # test ema
         x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
         out = model.relu(x)
-        assert math.isclose(model.relu.tracked_min_biased, 0, abs_tol=eps)
-        assert math.isclose(model.relu.tracked_max_biased, 0.002, abs_tol=eps)
+        assert math.isclose(model.relu.module.tracked_min_biased, 0, abs_tol=eps)
+        assert math.isclose(model.relu.module.tracked_max_biased, 0.002, abs_tol=eps)

         quantizer.step()
         x = torch.tensor([[0.2, 0.4], [0.6, 0.8]])
         out = model.relu(x)
-        assert math.isclose(model.relu.tracked_min_biased, 0.002, abs_tol=eps)
-        assert math.isclose(model.relu.tracked_max_biased, 0.00998, abs_tol=eps)
+        assert math.isclose(model.relu.module.tracked_min_biased, 0.002, abs_tol=eps)
+        assert math.isclose(model.relu.module.tracked_max_biased, 0.00998, abs_tol=eps)

 if __name__ == '__main__':
...