Commit ac6f420f authored by Tang Lang's avatar Tang Lang Committed by chicm-ms
Browse files

Pruners refactor (#1820)

parent 9484efb5
...@@ -4,59 +4,7 @@ import torch.nn as nn ...@@ -4,59 +4,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torchvision import datasets, transforms from torchvision import datasets, transforms
from nni.compression.torch import L1FilterPruner from nni.compression.torch import L1FilterPruner
from models.cifar10.vgg import VGG
class vgg(nn.Module):
    """VGG-16-style convolutional network for CIFAR-10 (10 classes).

    Feature extractor uses 3x3 convolutions with BatchNorm, followed by a
    small BatchNorm-regularized MLP classifier head.
    """

    def __init__(self, init_weights=True):
        """
        Parameters
        ----------
        init_weights : bool
            When True, apply the custom weight initialization scheme.
        """
        super(vgg, self).__init__()
        self.cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512]
        self.feature = self.make_layers(self.cfg, True)
        self.classifier = nn.Sequential(
            nn.Linear(self.cfg[-1], 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 10),  # 10 CIFAR-10 classes
        )
        if init_weights:
            self._initialize_weights()

    def make_layers(self, cfg, batch_norm=True):
        """Build the convolutional feature extractor from *cfg*.

        'M' entries become 2x2 max-pools; integers become 3x3 conv layers
        (bias-free) optionally followed by BatchNorm, then ReLU.
        """
        layers = []
        channels = 3  # RGB input
        for item in cfg:
            if item == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            layers.append(nn.Conv2d(channels, item, kernel_size=3, padding=1, bias=False))
            if batch_norm:
                layers.append(nn.BatchNorm2d(item))
            layers.append(nn.ReLU(inplace=True))
            channels = item
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run features, 2x2 average pool, flatten, then the classifier head."""
        out = self.feature(x)
        out = nn.AvgPool2d(2)(out)
        out = out.view(out.size(0), -1)
        return self.classifier(out)

    def _initialize_weights(self):
        """He-style init for convs, constant 0.5 for BN scales, N(0, 0.01) for linears."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(0.5)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
def train(model, device, train_loader, optimizer): def train(model, device, train_loader, optimizer):
...@@ -111,7 +59,7 @@ def main(): ...@@ -111,7 +59,7 @@ def main():
])), ])),
batch_size=200, shuffle=False) batch_size=200, shuffle=False)
model = vgg() model = VGG(depth=16)
model.to(device) model.to(device)
# Train the base VGG-16 model # Train the base VGG-16 model
...@@ -162,7 +110,7 @@ def main(): ...@@ -162,7 +110,7 @@ def main():
# Test the exported model # Test the exported model
print('=' * 10 + 'Test on the pruned model after fine tune' + '=' * 10) print('=' * 10 + 'Test on the pruned model after fine tune' + '=' * 10)
new_model = vgg() new_model = VGG(depth=16)
new_model.to(device) new_model.to(device)
new_model.load_state_dict(torch.load('pruned_vgg16_cifar10.pth')) new_model.load_state_dict(torch.load('pruned_vgg16_cifar10.pth'))
test(new_model, device, test_loader) test(new_model, device, test_loader)
......
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# Per-depth VGG feature-extractor layouts: integers are 3x3 conv output
# channel counts, 'M' marks a 2x2 max-pool. Keys are the supported depths.
defaultcfg = {
    11: [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512],
    13: [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512],
    16: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512],
    19: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512],
}
class VGG(nn.Module):
    """Configurable-depth VGG network for CIFAR-10 (10 classes).

    The feature extractor is built from ``defaultcfg[depth]``; the classifier
    is a BatchNorm-regularized two-layer MLP head.
    """

    def __init__(self, depth=16):
        """
        Parameters
        ----------
        depth : int
            Network depth; must be one of the keys of ``defaultcfg``
            (11, 13, 16 or 19).

        Raises
        ------
        ValueError
            If ``depth`` is not a supported configuration (previously this
            surfaced as an uninformative ``KeyError``).
        """
        super(VGG, self).__init__()
        if depth not in defaultcfg:
            raise ValueError(
                'unsupported depth %r; expected one of %s' % (depth, sorted(defaultcfg)))
        cfg = defaultcfg[depth]
        self.cfg = cfg
        # BatchNorm is always enabled in the feature extractor.
        self.feature = self.make_layers(cfg, True)
        num_classes = 10
        self.classifier = nn.Sequential(
            nn.Linear(cfg[-1], 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_classes)
        )
        self._initialize_weights()

    def make_layers(self, cfg, batch_norm=False):
        """Build the conv stack: 'M' -> 2x2 max-pool, int -> bias-free 3x3 conv
        (optionally + BatchNorm) + ReLU.
        """
        layers = []
        in_channels = 3  # RGB input
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False)
                if batch_norm:
                    layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
                else:
                    layers += [conv2d, nn.ReLU(inplace=True)]
                in_channels = v
        return nn.Sequential(*layers)

    def forward(self, x):
        """Features, 2x2 average pool, flatten, classifier head."""
        x = self.feature(x)
        x = nn.AvgPool2d(2)(x)
        x = x.view(x.size(0), -1)
        y = self.classifier(x)
        return y

    def _initialize_weights(self):
        """He-style init for convs, constant 0.5 for BN scales, N(0, 0.01) for linears."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:  # convs are bias-free here; kept for safety
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(0.5)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
...@@ -5,59 +5,7 @@ import torch.nn.functional as F ...@@ -5,59 +5,7 @@ import torch.nn.functional as F
from torchvision import datasets, transforms from torchvision import datasets, transforms
from nni.compression.torch import L1FilterPruner from nni.compression.torch import L1FilterPruner
from knowledge_distill.knowledge_distill import KnowledgeDistill from knowledge_distill.knowledge_distill import KnowledgeDistill
from models.cifar10.vgg import VGG
class vgg(nn.Module):
    """VGG-16-style convolutional network for CIFAR-10 (10 classes).

    3x3 conv + BatchNorm feature extractor with a BatchNorm-regularized
    two-layer MLP classifier head.
    """

    def __init__(self, init_weights=True):
        """
        Parameters
        ----------
        init_weights : bool
            When True, apply the custom weight initialization scheme.
        """
        super(vgg, self).__init__()
        self.cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512]
        self.feature = self.make_layers(self.cfg, True)
        self.classifier = nn.Sequential(
            nn.Linear(self.cfg[-1], 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 10),  # 10 CIFAR-10 classes
        )
        if init_weights:
            self._initialize_weights()

    def make_layers(self, cfg, batch_norm=True):
        """Translate *cfg* into a Sequential: 'M' -> 2x2 max-pool, int -> 3x3
        bias-free conv (optionally + BatchNorm) + ReLU.
        """
        stages = []
        channels = 3  # RGB input
        for entry in cfg:
            if entry == 'M':
                stages.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            stages.append(nn.Conv2d(channels, entry, kernel_size=3, padding=1, bias=False))
            if batch_norm:
                stages.append(nn.BatchNorm2d(entry))
            stages.append(nn.ReLU(inplace=True))
            channels = entry
        return nn.Sequential(*stages)

    def forward(self, x):
        """Features, 2x2 average pool, flatten, classifier head."""
        feats = self.feature(x)
        feats = nn.AvgPool2d(2)(feats)
        flat = feats.view(feats.size(0), -1)
        return self.classifier(flat)

    def _initialize_weights(self):
        """He-style init for convs, constant 0.5 for BN scales, N(0, 0.01) for linears."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(0.5)
                module.bias.data.zero_()
            elif isinstance(module, nn.Linear):
                module.weight.data.normal_(0, 0.01)
                module.bias.data.zero_()
def train(model, device, train_loader, optimizer, kd=None): def train(model, device, train_loader, optimizer, kd=None):
...@@ -119,7 +67,7 @@ def main(): ...@@ -119,7 +67,7 @@ def main():
])), ])),
batch_size=200, shuffle=False) batch_size=200, shuffle=False)
model = vgg() model = VGG(depth=16)
model.to(device) model.to(device)
# Train the base VGG-16 model # Train the base VGG-16 model
...@@ -156,7 +104,7 @@ def main(): ...@@ -156,7 +104,7 @@ def main():
print('=' * 10 + 'Fine tuning' + '=' * 10) print('=' * 10 + 'Fine tuning' + '=' * 10)
optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4) optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
best_top1 = 0 best_top1 = 0
kd_teacher_model = vgg() kd_teacher_model = VGG(depth=16)
kd_teacher_model.to(device) kd_teacher_model.to(device)
kd_teacher_model.load_state_dict(torch.load('vgg16_cifar10.pth')) kd_teacher_model.load_state_dict(torch.load('vgg16_cifar10.pth'))
kd = KnowledgeDistill(kd_teacher_model, kd_T=5) kd = KnowledgeDistill(kd_teacher_model, kd_T=5)
...@@ -173,7 +121,7 @@ def main(): ...@@ -173,7 +121,7 @@ def main():
# Test the exported model # Test the exported model
print('=' * 10 + 'Test on the pruned model after fine tune' + '=' * 10) print('=' * 10 + 'Test on the pruned model after fine tune' + '=' * 10)
new_model = vgg() new_model = VGG(depth=16)
new_model.to(device) new_model.to(device)
new_model.load_state_dict(torch.load('pruned_vgg16_cifar10.pth')) new_model.load_state_dict(torch.load('pruned_vgg16_cifar10.pth'))
test(new_model, device, test_loader) test(new_model, device, test_loader)
......
...@@ -4,53 +4,7 @@ import torch.nn as nn ...@@ -4,53 +4,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torchvision import datasets, transforms from torchvision import datasets, transforms
from nni.compression.torch import SlimPruner from nni.compression.torch import SlimPruner
from models.cifar10.vgg import VGG
class vgg(nn.Module):
    """VGG-19-style convolutional network for CIFAR-10 (10 classes).

    3x3 conv + BatchNorm feature extractor with a single linear classifier —
    the slim-pruner example variant.
    """

    def __init__(self, init_weights=True):
        """
        Parameters
        ----------
        init_weights : bool
            When True, apply the custom weight initialization scheme.
        """
        super(vgg, self).__init__()
        cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512]
        self.feature = self.make_layers(cfg, True)
        self.classifier = nn.Linear(cfg[-1], 10)  # 10 CIFAR-10 classes
        if init_weights:
            self._initialize_weights()

    def make_layers(self, cfg, batch_norm=False):
        """Translate *cfg* into a Sequential: 'M' -> 2x2 max-pool, int -> 3x3
        bias-free conv (optionally + BatchNorm) + ReLU.
        """
        stages = []
        channels = 3  # RGB input
        for entry in cfg:
            if entry == 'M':
                stages.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            stages.append(nn.Conv2d(channels, entry, kernel_size=3, padding=1, bias=False))
            if batch_norm:
                stages.append(nn.BatchNorm2d(entry))
            stages.append(nn.ReLU(inplace=True))
            channels = entry
        return nn.Sequential(*stages)

    def forward(self, x):
        """Features, 2x2 average pool, flatten, linear classifier."""
        feats = self.feature(x)
        feats = nn.AvgPool2d(2)(feats)
        flat = feats.view(feats.size(0), -1)
        return self.classifier(flat)

    def _initialize_weights(self):
        """He-style init for convs, constant 0.5 for BN scales, N(0, 0.01) for linears."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(0.5)
                module.bias.data.zero_()
            elif isinstance(module, nn.Linear):
                module.weight.data.normal_(0, 0.01)
                module.bias.data.zero_()
def updateBN(model): def updateBN(model):
...@@ -114,7 +68,7 @@ def main(): ...@@ -114,7 +68,7 @@ def main():
])), ])),
batch_size=200, shuffle=False) batch_size=200, shuffle=False)
model = vgg() model = VGG(depth=19)
model.to(device) model.to(device)
# Train the base VGG-19 model # Train the base VGG-19 model
...@@ -165,7 +119,7 @@ def main(): ...@@ -165,7 +119,7 @@ def main():
# Test the exported model # Test the exported model
print('=' * 10 + 'Test the export pruned model after fine tune' + '=' * 10) print('=' * 10 + 'Test the export pruned model after fine tune' + '=' * 10)
new_model = vgg() new_model = VGG(depth=19)
new_model.to(device) new_model.to(device)
new_model.load_state_dict(torch.load('pruned_vgg19_cifar10.pth')) new_model.load_state_dict(torch.load('pruned_vgg19_cifar10.pth'))
test(new_model, device, test_loader) test(new_model, device, test_loader)
......
...@@ -5,7 +5,7 @@ import logging ...@@ -5,7 +5,7 @@ import logging
import torch import torch
from .compressor import Pruner from .compressor import Pruner
__all__ = ['LevelPruner', 'AGP_Pruner', 'FPGMPruner', 'L1FilterPruner', 'SlimPruner'] __all__ = ['LevelPruner', 'AGP_Pruner', 'SlimPruner', 'L1FilterPruner', 'L2FilterPruner', 'FPGMPruner']
logger = logging.getLogger('torch pruner') logger = logging.getLogger('torch pruner')
...@@ -166,119 +166,132 @@ class AGP_Pruner(Pruner): ...@@ -166,119 +166,132 @@ class AGP_Pruner(Pruner):
self.if_init_list[k] = True self.if_init_list[k] = True
class FPGMPruner(Pruner): class SlimPruner(Pruner):
""" """
A filter pruner via geometric median. A structured pruning algorithm that prunes channels by pruning the weights of BN layers.
"Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration", Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan and Changshui Zhang
https://arxiv.org/pdf/1811.00250.pdf "Learning Efficient Convolutional Networks through Network Slimming", 2017 ICCV
https://arxiv.org/pdf/1708.06519.pdf
""" """
def __init__(self, model, config_list): def __init__(self, model, config_list):
""" """
Parameters Parameters
---------- ----------
model : pytorch model config_list : list
the model user wants to compress
config_list: list
support key for each list item: support key for each list item:
- sparsity: percentage of convolutional filters to be pruned. - sparsity: percentage of convolutional filters to be pruned.
""" """
super().__init__(model, config_list) super().__init__(model, config_list)
self.mask_dict = {} self.mask_calculated_ops = set()
self.epoch_pruned_layers = set() weight_list = []
if len(config_list) > 1:
logger.warning('Slim pruner only supports 1 configuration')
config = config_list[0]
for (layer, config) in self.detect_modules_to_compress():
assert layer.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
weight_list.append(layer.module.weight.data.abs().clone())
all_bn_weights = torch.cat(weight_list)
k = int(all_bn_weights.shape[0] * config['sparsity'])
self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
def calc_mask(self, layer, config): def calc_mask(self, layer, config):
""" """
Supports Conv1d, Conv2d Calculate the mask of given layer.
filter dimensions for Conv1d: Scale factors with the smallest absolute value in the BN layer are masked.
OUT: number of output channel
IN: number of input channel
LEN: filter length
filter dimensions for Conv2d:
OUT: number of output channel
IN: number of input channel
H: filter height
W: filter width
Parameters Parameters
---------- ----------
layer : LayerInfo layer : LayerInfo
calculate mask for `layer`'s weight the layer to instrument the compression operation
config : dict config : dict
the configuration for generating the mask layer's pruning config
Returns
-------
torch.Tensor
mask of the layer's weight
""" """
weight = layer.module.weight.data weight = layer.module.weight.data
assert 0 <= config.get('sparsity') < 1 op_name = layer.name
assert layer.type in ['Conv1d', 'Conv2d'] op_type = layer.type
assert layer.type in config['op_types'] assert op_type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
if op_name in self.mask_calculated_ops:
assert op_name in self.mask_dict
return self.mask_dict.get(op_name)
mask = torch.ones(weight.size()).type_as(weight)
try:
w_abs = weight.abs()
mask = torch.gt(w_abs, self.global_threshold).type_as(weight)
finally:
self.mask_dict.update({layer.name: mask})
self.mask_calculated_ops.add(layer.name)
if layer.name in self.epoch_pruned_layers: return mask
assert layer.name in self.mask_dict
return self.mask_dict.get(layer.name)
masks = torch.ones(weight.size()).type_as(weight)
try: class RankFilterPruner(Pruner):
num_filters = weight.size(0) """
num_prune = int(num_filters * config.get('sparsity')) A structured pruning base class that prunes the filters with the smallest
if num_filters < 2 or num_prune < 1: importance criterion in convolution layers to achieve a preset level of network sparsity.
return masks """
min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune)
for idx in min_gm_idx:
masks[idx] = 0.
finally:
self.mask_dict.update({layer.name: masks})
self.epoch_pruned_layers.add(layer.name)
return masks def __init__(self, model, config_list):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
"""
def _get_min_gm_kernel_idx(self, weight, n): super().__init__(model, config_list)
assert len(weight.size()) in [3, 4] self.mask_calculated_ops = set()
dist_list = [] def _get_mask(self, base_mask, weight, num_prune):
for out_i in range(weight.size(0)): return torch.ones(weight.size()).type_as(weight)
dist_sum = self._get_distance_sum(weight, out_i)
dist_list.append((dist_sum, out_i))
min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n]
return [x[1] for x in min_gm_kernels]
def _get_distance_sum(self, weight, out_idx): def calc_mask(self, layer, config):
""" """
Calculate the total distance between a specified filter (by out_idex and in_idx) and Calculate the mask of given layer.
all other filters. Filters with the smallest importance criterion of the kernel weights are masked.
Optimized verision of following naive implementation:
def _get_distance_sum(self, weight, in_idx, out_idx):
w = weight.view(-1, weight.size(-2), weight.size(-1))
dist_sum = 0.
for k in w:
dist_sum += torch.dist(k, weight[in_idx, out_idx], p=2)
return dist_sum
Parameters Parameters
---------- ----------
weight: Tensor layer : LayerInfo
convolutional filter weight the layer to instrument the compression operation
out_idx: int config : dict
output channel index of specified filter, this method calculates the total distance layer's pruning config
between this specified filter and all other filters.
Returns Returns
------- -------
float32 torch.Tensor
The total distance mask of the layer's weight
""" """
logger.debug('weight size: %s', weight.size())
assert len(weight.size()) in [3, 4], 'unsupported weight shape'
w = weight.view(weight.size(0), -1)
anchor_w = w[out_idx].unsqueeze(0).expand(w.size(0), w.size(1))
x = w - anchor_w
x = (x * x).sum(-1)
x = torch.sqrt(x)
return x.sum()
def update_epoch(self, epoch): weight = layer.module.weight.data
self.epoch_pruned_layers = set() op_name = layer.name
op_type = layer.type
assert 0 <= config.get('sparsity') < 1
assert op_type in ['Conv1d', 'Conv2d']
assert op_type in config.get('op_types')
if op_name in self.mask_calculated_ops:
assert op_name in self.mask_dict
return self.mask_dict.get(op_name)
mask = torch.ones(weight.size()).type_as(weight)
try:
filters = weight.size(0)
num_prune = int(filters * config.get('sparsity'))
if filters < 2 or num_prune < 1:
return mask
mask = self._get_mask(mask, weight, num_prune)
finally:
self.mask_dict.update({op_name: mask})
self.mask_calculated_ops.add(op_name)
return mask.detach()
class L1FilterPruner(Pruner): class L1FilterPruner(RankFilterPruner):
""" """
A structured pruning algorithm that prunes the filters of smallest magnitude A structured pruning algorithm that prunes the filters of smallest magnitude
weights sum in the convolution layers to achieve a preset level of network sparsity. weights sum in the convolution layers to achieve a preset level of network sparsity.
...@@ -299,107 +312,162 @@ class L1FilterPruner(Pruner): ...@@ -299,107 +312,162 @@ class L1FilterPruner(Pruner):
""" """
super().__init__(model, config_list) super().__init__(model, config_list)
self.mask_calculated_ops = set()
def calc_mask(self, layer, config): def _get_mask(self, base_mask, weight, num_prune):
""" """
Calculate the mask of given layer. Calculate the mask of given layer.
Filters with the smallest sum of its absolute kernel weights are masked. Filters with the smallest sum of its absolute kernel weights are masked.
Parameters Parameters
---------- ----------
layer : LayerInfo base_mask : torch.Tensor
the layer to instrument the compression operation The basic mask with the same shape of weight, all item in the basic mask is 1.
config : dict weight : torch.Tensor
layer's pruning config Layer's weight
num_prune : int
Num of filters to prune
Returns Returns
------- -------
torch.Tensor torch.Tensor
mask of the layer's weight Mask of the layer's weight
""" """
weight = layer.module.weight.data filters = weight.shape[0]
op_name = layer.name w_abs = weight.abs()
op_type = layer.type w_abs_structured = w_abs.view(filters, -1).sum(dim=1)
assert op_type == 'Conv2d', 'L1FilterPruner only supports 2d convolution layer pruning' threshold = torch.topk(w_abs_structured.view(-1), num_prune, largest=False)[0].max()
if op_name in self.mask_calculated_ops: mask = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
assert op_name in self.mask_dict
return self.mask_dict.get(op_name)
mask = torch.ones(weight.size()).type_as(weight)
try:
filters = weight.shape[0]
w_abs = weight.abs()
k = int(filters * config['sparsity'])
if k == 0:
return torch.ones(weight.shape).type_as(weight)
w_abs_structured = w_abs.view(filters, -1).sum(dim=1)
threshold = torch.topk(w_abs_structured.view(-1), k, largest=False)[0].max()
mask = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
finally:
self.mask_dict.update({layer.name: mask})
self.mask_calculated_ops.add(layer.name)
return mask return mask
class SlimPruner(Pruner): class L2FilterPruner(RankFilterPruner):
""" """
A structured pruning algorithm that prunes channels by pruning the weights of BN layers. A structured pruning algorithm that prunes the filters with the
Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan and Changshui Zhang smallest L2 norm of the absolute kernel weights are masked.
"Learning Efficient Convolutional Networks through Network Slimming", 2017 ICCV
https://arxiv.org/pdf/1708.06519.pdf
""" """
def __init__(self, model, config_list): def __init__(self, model, config_list):
""" """
Parameters Parameters
---------- ----------
model : torch.nn.module
Model to be pruned
config_list : list config_list : list
support key for each list item: support key for each list item:
- sparsity: percentage of convolutional filters to be pruned. - sparsity: percentage of convolutional filters to be pruned.
""" """
super().__init__(model, config_list) super().__init__(model, config_list)
self.mask_calculated_ops = set()
weight_list = []
if len(config_list) > 1:
logger.warning('Slim pruner only supports 1 configuration')
config = config_list[0]
for (layer, config) in self.detect_modules_to_compress():
assert layer.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
weight_list.append(layer.module.weight.data.abs().clone())
all_bn_weights = torch.cat(weight_list)
k = int(all_bn_weights.shape[0] * config['sparsity'])
self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
def calc_mask(self, layer, config): def _get_mask(self, base_mask, weight, num_prune):
""" """
Calculate the mask of given layer. Calculate the mask of given layer.
Scale factors with the smallest absolute value in the BN layer are masked. Filters with the smallest L2 norm of the absolute kernel weights are masked.
Parameters Parameters
---------- ----------
layer : LayerInfo base_mask : torch.Tensor
the layer to instrument the compression operation The basic mask with the same shape of weight, all item in the basic mask is 1.
config : dict weight : torch.Tensor
layer's pruning config Layer's weight
num_prune : int
Num of filters to prune
Returns Returns
------- -------
torch.Tensor torch.Tensor
mask of the layer's weight Mask of the layer's weight
""" """
filters = weight.shape[0]
weight = layer.module.weight.data w = weight.view(filters, -1)
op_name = layer.name w_l2_norm = torch.sqrt((w ** 2).sum(dim=1))
op_type = layer.type threshold = torch.topk(w_l2_norm.view(-1), num_prune, largest=False)[0].max()
assert op_type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning' mask = torch.gt(w_l2_norm, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
if op_name in self.mask_calculated_ops:
assert op_name in self.mask_dict
return self.mask_dict.get(op_name)
mask = torch.ones(weight.size()).type_as(weight)
try:
w_abs = weight.abs()
mask = torch.gt(w_abs, self.global_threshold).type_as(weight)
finally:
self.mask_dict.update({layer.name: mask})
self.mask_calculated_ops.add(layer.name)
return mask return mask
class FPGMPruner(RankFilterPruner):
"""
A filter pruner via geometric median.
"Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration",
https://arxiv.org/pdf/1811.00250.pdf
"""
def __init__(self, model, config_list):
"""
Parameters
----------
model : pytorch model
the model user wants to compress
config_list: list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
"""
super().__init__(model, config_list)
def _get_mask(self, base_mask, weight, num_prune):
"""
Calculate the mask of given layer.
Filters with the smallest sum of its absolute kernel weights are masked.
Parameters
----------
base_mask : torch.Tensor
The basic mask with the same shape of weight, all item in the basic mask is 1.
weight : torch.Tensor
Layer's weight
num_prune : int
Num of filters to prune
Returns
-------
torch.Tensor
Mask of the layer's weight
"""
min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune)
for idx in min_gm_idx:
base_mask[idx] = 0.
return base_mask
def _get_min_gm_kernel_idx(self, weight, n):
assert len(weight.size()) in [3, 4]
dist_list = []
for out_i in range(weight.size(0)):
dist_sum = self._get_distance_sum(weight, out_i)
dist_list.append((dist_sum, out_i))
min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n]
return [x[1] for x in min_gm_kernels]
def _get_distance_sum(self, weight, out_idx):
"""
Calculate the total distance between a specified filter (by out_idex and in_idx) and
all other filters.
Optimized verision of following naive implementation:
def _get_distance_sum(self, weight, in_idx, out_idx):
w = weight.view(-1, weight.size(-2), weight.size(-1))
dist_sum = 0.
for k in w:
dist_sum += torch.dist(k, weight[in_idx, out_idx], p=2)
return dist_sum
Parameters
----------
weight: Tensor
convolutional filter weight
out_idx: int
output channel index of specified filter, this method calculates the total distance
between this specified filter and all other filters.
Returns
-------
float32
The total distance
"""
logger.debug('weight size: %s', weight.size())
assert len(weight.size()) in [3, 4], 'unsupported weight shape'
w = weight.view(weight.size(0), -1)
anchor_w = w[out_idx].unsqueeze(0).expand(w.size(0), w.size(1))
x = w - anchor_w
x = (x * x).sum(-1)
x = torch.sqrt(x)
return x.sum()
def update_epoch(self, epoch):
self.mask_calculated_ops = set()
...@@ -58,8 +58,9 @@ def tf2(func): ...@@ -58,8 +58,9 @@ def tf2(func):
return test_tf2_func return test_tf2_func
# for fpgm filter pruner test # for fpgm filter pruner test
w = np.array([[[[i+1]*3]*3]*5 for i in range(10)]) w = np.array([[[[i + 1] * 3] * 3] * 5 for i in range(10)])
class CompressorTestCase(TestCase): class CompressorTestCase(TestCase):
...@@ -69,19 +70,19 @@ class CompressorTestCase(TestCase): ...@@ -69,19 +70,19 @@ class CompressorTestCase(TestCase):
config_list = [{ config_list = [{
'quant_types': ['weight'], 'quant_types': ['weight'],
'quant_bits': 8, 'quant_bits': 8,
'op_types':['Conv2d', 'Linear'] 'op_types': ['Conv2d', 'Linear']
}, { }, {
'quant_types': ['output'], 'quant_types': ['output'],
'quant_bits': 8, 'quant_bits': 8,
'quant_start_step': 0, 'quant_start_step': 0,
'op_types':['ReLU'] 'op_types': ['ReLU']
}] }]
model.relu = torch.nn.ReLU() model.relu = torch.nn.ReLU()
quantizer = torch_compressor.QAT_Quantizer(model, config_list) quantizer = torch_compressor.QAT_Quantizer(model, config_list)
quantizer.compress() quantizer.compress()
modules_to_compress = quantizer.get_modules_to_compress() modules_to_compress = quantizer.get_modules_to_compress()
modules_to_compress_name = [ t[0].name for t in modules_to_compress] modules_to_compress_name = [t[0].name for t in modules_to_compress]
assert "conv1" in modules_to_compress_name assert "conv1" in modules_to_compress_name
assert "conv2" in modules_to_compress_name assert "conv2" in modules_to_compress_name
assert "fc1" in modules_to_compress_name assert "fc1" in modules_to_compress_name
...@@ -179,7 +180,8 @@ class CompressorTestCase(TestCase): ...@@ -179,7 +180,8 @@ class CompressorTestCase(TestCase):
w = np.array([np.zeros((3, 3, 3)), np.ones((3, 3, 3)), np.ones((3, 3, 3)) * 2, w = np.array([np.zeros((3, 3, 3)), np.ones((3, 3, 3)), np.ones((3, 3, 3)) * 2,
np.ones((3, 3, 3)) * 3, np.ones((3, 3, 3)) * 4]) np.ones((3, 3, 3)) * 3, np.ones((3, 3, 3)) * 4])
model = TorchModel() model = TorchModel()
config_list = [{'sparsity': 0.2, 'op_names': ['conv1']}, {'sparsity': 0.6, 'op_names': ['conv2']}] config_list = [{'sparsity': 0.2, 'op_types': ['Conv2d'], 'op_names': ['conv1']},
{'sparsity': 0.6, 'op_types': ['Conv2d'], 'op_names': ['conv2']}]
pruner = torch_compressor.L1FilterPruner(model, config_list) pruner = torch_compressor.L1FilterPruner(model, config_list)
model.conv1.weight.data = torch.tensor(w).float() model.conv1.weight.data = torch.tensor(w).float()
...@@ -236,12 +238,12 @@ class CompressorTestCase(TestCase): ...@@ -236,12 +238,12 @@ class CompressorTestCase(TestCase):
config_list = [{ config_list = [{
'quant_types': ['weight'], 'quant_types': ['weight'],
'quant_bits': 8, 'quant_bits': 8,
'op_types':['Conv2d', 'Linear'] 'op_types': ['Conv2d', 'Linear']
}, { }, {
'quant_types': ['output'], 'quant_types': ['output'],
'quant_bits': 8, 'quant_bits': 8,
'quant_start_step': 0, 'quant_start_step': 0,
'op_types':['ReLU'] 'op_types': ['ReLU']
}] }]
model.relu = torch.nn.ReLU() model.relu = torch.nn.ReLU()
quantizer = torch_compressor.QAT_Quantizer(model, config_list) quantizer = torch_compressor.QAT_Quantizer(model, config_list)
...@@ -253,7 +255,7 @@ class CompressorTestCase(TestCase): ...@@ -253,7 +255,7 @@ class CompressorTestCase(TestCase):
quantize_weight = quantizer.quantize_weight(weight, config_list[0], model.conv2) quantize_weight = quantizer.quantize_weight(weight, config_list[0], model.conv2)
assert math.isclose(model.conv2.scale, 5 / 255, abs_tol=eps) assert math.isclose(model.conv2.scale, 5 / 255, abs_tol=eps)
assert model.conv2.zero_point == 0 assert model.conv2.zero_point == 0
# range including 0 # range including 0
weight = torch.tensor([[-1, 2], [3, 5]]).float() weight = torch.tensor([[-1, 2], [3, 5]]).float()
quantize_weight = quantizer.quantize_weight(weight, config_list[0], model.conv2) quantize_weight = quantizer.quantize_weight(weight, config_list[0], model.conv2)
assert math.isclose(model.conv2.scale, 6 / 255, abs_tol=eps) assert math.isclose(model.conv2.scale, 6 / 255, abs_tol=eps)
...@@ -271,5 +273,6 @@ class CompressorTestCase(TestCase): ...@@ -271,5 +273,6 @@ class CompressorTestCase(TestCase):
assert math.isclose(model.relu.tracked_min_biased, 0.002, abs_tol=eps) assert math.isclose(model.relu.tracked_min_biased, 0.002, abs_tol=eps)
assert math.isclose(model.relu.tracked_max_biased, 0.00998, abs_tol=eps) assert math.isclose(model.relu.tracked_max_biased, 0.00998, abs_tol=eps)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment