Commit ac6f420f authored by Tang Lang's avatar Tang Lang Committed by chicm-ms
Browse files

Pruners refactor (#1820)

parent 9484efb5
......@@ -4,59 +4,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.compression.torch import L1FilterPruner
class vgg(nn.Module):
def __init__(self, init_weights=True):
super(vgg, self).__init__()
cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512]
self.cfg = cfg
self.feature = self.make_layers(cfg, True)
num_classes = 10
self.classifier = nn.Sequential(
nn.Linear(cfg[-1], 512),
nn.BatchNorm1d(512),
nn.ReLU(inplace=True),
nn.Linear(512, num_classes)
)
if init_weights:
self._initialize_weights()
def make_layers(self, cfg, batch_norm=True):
layers = []
in_channels = 3
for v in cfg:
if v == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
else:
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = v
return nn.Sequential(*layers)
def forward(self, x):
x = self.feature(x)
x = nn.AvgPool2d(2)(x)
x = x.view(x.size(0), -1)
y = self.classifier(x)
return y
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(0.5)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.weight.data.normal_(0, 0.01)
m.bias.data.zero_()
from models.cifar10.vgg import VGG
def train(model, device, train_loader, optimizer):
......@@ -111,7 +59,7 @@ def main():
])),
batch_size=200, shuffle=False)
model = vgg()
model = VGG(depth=16)
model.to(device)
# Train the base VGG-16 model
......@@ -162,7 +110,7 @@ def main():
# Test the exported model
print('=' * 10 + 'Test on the pruned model after fine tune' + '=' * 10)
new_model = vgg()
new_model = VGG(depth=16)
new_model.to(device)
new_model.load_state_dict(torch.load('pruned_vgg16_cifar10.pth'))
test(new_model, device, test_loader)
......
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
defaultcfg = {
11: [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512],
13: [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512],
16: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512],
19: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512],
}
class VGG(nn.Module):
def __init__(self, depth=16):
super(VGG, self).__init__()
cfg = defaultcfg[depth]
self.cfg = cfg
self.feature = self.make_layers(cfg, True)
num_classes = 10
self.classifier = nn.Sequential(
nn.Linear(cfg[-1], 512),
nn.BatchNorm1d(512),
nn.ReLU(inplace=True),
nn.Linear(512, num_classes)
)
self._initialize_weights()
def make_layers(self, cfg, batch_norm=False):
layers = []
in_channels = 3
for v in cfg:
if v == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
else:
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = v
return nn.Sequential(*layers)
def forward(self, x):
x = self.feature(x)
x = nn.AvgPool2d(2)(x)
x = x.view(x.size(0), -1)
y = self.classifier(x)
return y
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(0.5)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.weight.data.normal_(0, 0.01)
m.bias.data.zero_()
......@@ -5,59 +5,7 @@ import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.compression.torch import L1FilterPruner
from knowledge_distill.knowledge_distill import KnowledgeDistill
class vgg(nn.Module):
def __init__(self, init_weights=True):
super(vgg, self).__init__()
cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512]
self.cfg = cfg
self.feature = self.make_layers(cfg, True)
num_classes = 10
self.classifier = nn.Sequential(
nn.Linear(cfg[-1], 512),
nn.BatchNorm1d(512),
nn.ReLU(inplace=True),
nn.Linear(512, num_classes)
)
if init_weights:
self._initialize_weights()
def make_layers(self, cfg, batch_norm=True):
layers = []
in_channels = 3
for v in cfg:
if v == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
else:
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = v
return nn.Sequential(*layers)
def forward(self, x):
x = self.feature(x)
x = nn.AvgPool2d(2)(x)
x = x.view(x.size(0), -1)
y = self.classifier(x)
return y
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(0.5)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.weight.data.normal_(0, 0.01)
m.bias.data.zero_()
from models.cifar10.vgg import VGG
def train(model, device, train_loader, optimizer, kd=None):
......@@ -119,7 +67,7 @@ def main():
])),
batch_size=200, shuffle=False)
model = vgg()
model = VGG(depth=16)
model.to(device)
# Train the base VGG-16 model
......@@ -156,7 +104,7 @@ def main():
print('=' * 10 + 'Fine tuning' + '=' * 10)
optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
best_top1 = 0
kd_teacher_model = vgg()
kd_teacher_model = VGG(depth=16)
kd_teacher_model.to(device)
kd_teacher_model.load_state_dict(torch.load('vgg16_cifar10.pth'))
kd = KnowledgeDistill(kd_teacher_model, kd_T=5)
......@@ -173,7 +121,7 @@ def main():
# Test the exported model
print('=' * 10 + 'Test on the pruned model after fine tune' + '=' * 10)
new_model = vgg()
new_model = VGG(depth=16)
new_model.to(device)
new_model.load_state_dict(torch.load('pruned_vgg16_cifar10.pth'))
test(new_model, device, test_loader)
......
......@@ -4,53 +4,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.compression.torch import SlimPruner
class vgg(nn.Module):
def __init__(self, init_weights=True):
super(vgg, self).__init__()
cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512]
self.feature = self.make_layers(cfg, True)
num_classes = 10
self.classifier = nn.Linear(cfg[-1], num_classes)
if init_weights:
self._initialize_weights()
def make_layers(self, cfg, batch_norm=False):
layers = []
in_channels = 3
for v in cfg:
if v == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
else:
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = v
return nn.Sequential(*layers)
def forward(self, x):
x = self.feature(x)
x = nn.AvgPool2d(2)(x)
x = x.view(x.size(0), -1)
y = self.classifier(x)
return y
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(0.5)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.weight.data.normal_(0, 0.01)
m.bias.data.zero_()
from models.cifar10.vgg import VGG
def updateBN(model):
......@@ -114,7 +68,7 @@ def main():
])),
batch_size=200, shuffle=False)
model = vgg()
model = VGG(depth=19)
model.to(device)
# Train the base VGG-19 model
......@@ -165,7 +119,7 @@ def main():
# Test the exported model
print('=' * 10 + 'Test the export pruned model after fine tune' + '=' * 10)
new_model = vgg()
new_model = VGG(depth=19)
new_model.to(device)
new_model.load_state_dict(torch.load('pruned_vgg19_cifar10.pth'))
test(new_model, device, test_loader)
......
......@@ -5,7 +5,7 @@ import logging
import torch
from .compressor import Pruner
__all__ = ['LevelPruner', 'AGP_Pruner', 'FPGMPruner', 'L1FilterPruner', 'SlimPruner']
__all__ = ['LevelPruner', 'AGP_Pruner', 'SlimPruner', 'L1FilterPruner', 'L2FilterPruner', 'FPGMPruner']
logger = logging.getLogger('torch pruner')
......@@ -166,119 +166,132 @@ class AGP_Pruner(Pruner):
self.if_init_list[k] = True
class FPGMPruner(Pruner):
class SlimPruner(Pruner):
"""
A filter pruner via geometric median.
"Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration",
https://arxiv.org/pdf/1811.00250.pdf
A structured pruning algorithm that prunes channels by pruning the weights of BN layers.
Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan and Changshui Zhang
"Learning Efficient Convolutional Networks through Network Slimming", 2017 ICCV
https://arxiv.org/pdf/1708.06519.pdf
"""
def __init__(self, model, config_list):
"""
Parameters
----------
model : pytorch model
the model user wants to compress
config_list: list
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
"""
super().__init__(model, config_list)
self.mask_dict = {}
self.epoch_pruned_layers = set()
self.mask_calculated_ops = set()
weight_list = []
if len(config_list) > 1:
logger.warning('Slim pruner only supports 1 configuration')
config = config_list[0]
for (layer, config) in self.detect_modules_to_compress():
assert layer.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
weight_list.append(layer.module.weight.data.abs().clone())
all_bn_weights = torch.cat(weight_list)
k = int(all_bn_weights.shape[0] * config['sparsity'])
self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
def calc_mask(self, layer, config):
"""
Supports Conv1d, Conv2d
filter dimensions for Conv1d:
OUT: number of output channel
IN: number of input channel
LEN: filter length
filter dimensions for Conv2d:
OUT: number of output channel
IN: number of input channel
H: filter height
W: filter width
Calculate the mask of given layer.
Scale factors with the smallest absolute value in the BN layer are masked.
Parameters
----------
layer : LayerInfo
calculate mask for `layer`'s weight
the layer to instrument the compression operation
config : dict
the configuration for generating the mask
layer's pruning config
Returns
-------
torch.Tensor
mask of the layer's weight
"""
weight = layer.module.weight.data
assert 0 <= config.get('sparsity') < 1
assert layer.type in ['Conv1d', 'Conv2d']
assert layer.type in config['op_types']
op_name = layer.name
op_type = layer.type
assert op_type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
if op_name in self.mask_calculated_ops:
assert op_name in self.mask_dict
return self.mask_dict.get(op_name)
mask = torch.ones(weight.size()).type_as(weight)
try:
w_abs = weight.abs()
mask = torch.gt(w_abs, self.global_threshold).type_as(weight)
finally:
self.mask_dict.update({layer.name: mask})
self.mask_calculated_ops.add(layer.name)
if layer.name in self.epoch_pruned_layers:
assert layer.name in self.mask_dict
return self.mask_dict.get(layer.name)
return mask
masks = torch.ones(weight.size()).type_as(weight)
try:
num_filters = weight.size(0)
num_prune = int(num_filters * config.get('sparsity'))
if num_filters < 2 or num_prune < 1:
return masks
min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune)
for idx in min_gm_idx:
masks[idx] = 0.
finally:
self.mask_dict.update({layer.name: masks})
self.epoch_pruned_layers.add(layer.name)
class RankFilterPruner(Pruner):
"""
A structured pruning base class that prunes the filters with the smallest
importance criterion in convolution layers to achieve a preset level of network sparsity.
"""
return masks
def __init__(self, model, config_list):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
"""
def _get_min_gm_kernel_idx(self, weight, n):
assert len(weight.size()) in [3, 4]
super().__init__(model, config_list)
self.mask_calculated_ops = set()
dist_list = []
for out_i in range(weight.size(0)):
dist_sum = self._get_distance_sum(weight, out_i)
dist_list.append((dist_sum, out_i))
min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n]
return [x[1] for x in min_gm_kernels]
def _get_mask(self, base_mask, weight, num_prune):
return torch.ones(weight.size()).type_as(weight)
def _get_distance_sum(self, weight, out_idx):
def calc_mask(self, layer, config):
"""
Calculate the total distance between a specified filter (by out_idex and in_idx) and
all other filters.
Optimized verision of following naive implementation:
def _get_distance_sum(self, weight, in_idx, out_idx):
w = weight.view(-1, weight.size(-2), weight.size(-1))
dist_sum = 0.
for k in w:
dist_sum += torch.dist(k, weight[in_idx, out_idx], p=2)
return dist_sum
Calculate the mask of given layer.
Filters with the smallest importance criterion of the kernel weights are masked.
Parameters
----------
weight: Tensor
convolutional filter weight
out_idx: int
output channel index of specified filter, this method calculates the total distance
between this specified filter and all other filters.
layer : LayerInfo
the layer to instrument the compression operation
config : dict
layer's pruning config
Returns
-------
float32
The total distance
torch.Tensor
mask of the layer's weight
"""
logger.debug('weight size: %s', weight.size())
assert len(weight.size()) in [3, 4], 'unsupported weight shape'
w = weight.view(weight.size(0), -1)
anchor_w = w[out_idx].unsqueeze(0).expand(w.size(0), w.size(1))
x = w - anchor_w
x = (x * x).sum(-1)
x = torch.sqrt(x)
return x.sum()
def update_epoch(self, epoch):
self.epoch_pruned_layers = set()
weight = layer.module.weight.data
op_name = layer.name
op_type = layer.type
assert 0 <= config.get('sparsity') < 1
assert op_type in ['Conv1d', 'Conv2d']
assert op_type in config.get('op_types')
if op_name in self.mask_calculated_ops:
assert op_name in self.mask_dict
return self.mask_dict.get(op_name)
mask = torch.ones(weight.size()).type_as(weight)
try:
filters = weight.size(0)
num_prune = int(filters * config.get('sparsity'))
if filters < 2 or num_prune < 1:
return mask
mask = self._get_mask(mask, weight, num_prune)
finally:
self.mask_dict.update({op_name: mask})
self.mask_calculated_ops.add(op_name)
return mask.detach()
class L1FilterPruner(Pruner):
class L1FilterPruner(RankFilterPruner):
"""
A structured pruning algorithm that prunes the filters of smallest magnitude
weights sum in the convolution layers to achieve a preset level of network sparsity.
......@@ -299,107 +312,162 @@ class L1FilterPruner(Pruner):
"""
super().__init__(model, config_list)
self.mask_calculated_ops = set()
def calc_mask(self, layer, config):
def _get_mask(self, base_mask, weight, num_prune):
"""
Calculate the mask of given layer.
Filters with the smallest sum of its absolute kernel weights are masked.
Parameters
----------
layer : LayerInfo
the layer to instrument the compression operation
config : dict
layer's pruning config
base_mask : torch.Tensor
The basic mask with the same shape of weight, all item in the basic mask is 1.
weight : torch.Tensor
Layer's weight
num_prune : int
Num of filters to prune
Returns
-------
torch.Tensor
mask of the layer's weight
Mask of the layer's weight
"""
weight = layer.module.weight.data
op_name = layer.name
op_type = layer.type
assert op_type == 'Conv2d', 'L1FilterPruner only supports 2d convolution layer pruning'
if op_name in self.mask_calculated_ops:
assert op_name in self.mask_dict
return self.mask_dict.get(op_name)
mask = torch.ones(weight.size()).type_as(weight)
try:
filters = weight.shape[0]
w_abs = weight.abs()
k = int(filters * config['sparsity'])
if k == 0:
return torch.ones(weight.shape).type_as(weight)
w_abs_structured = w_abs.view(filters, -1).sum(dim=1)
threshold = torch.topk(w_abs_structured.view(-1), k, largest=False)[0].max()
threshold = torch.topk(w_abs_structured.view(-1), num_prune, largest=False)[0].max()
mask = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
finally:
self.mask_dict.update({layer.name: mask})
self.mask_calculated_ops.add(layer.name)
return mask
class SlimPruner(Pruner):
class L2FilterPruner(RankFilterPruner):
"""
A structured pruning algorithm that prunes channels by pruning the weights of BN layers.
Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan and Changshui Zhang
"Learning Efficient Convolutional Networks through Network Slimming", 2017 ICCV
https://arxiv.org/pdf/1708.06519.pdf
A structured pruning algorithm that prunes the filters with the
smallest L2 norm of the absolute kernel weights are masked.
"""
def __init__(self, model, config_list):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
"""
super().__init__(model, config_list)
self.mask_calculated_ops = set()
weight_list = []
if len(config_list) > 1:
logger.warning('Slim pruner only supports 1 configuration')
config = config_list[0]
for (layer, config) in self.detect_modules_to_compress():
assert layer.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
weight_list.append(layer.module.weight.data.abs().clone())
all_bn_weights = torch.cat(weight_list)
k = int(all_bn_weights.shape[0] * config['sparsity'])
self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
def calc_mask(self, layer, config):
def _get_mask(self, base_mask, weight, num_prune):
"""
Calculate the mask of given layer.
Scale factors with the smallest absolute value in the BN layer are masked.
Filters with the smallest L2 norm of the absolute kernel weights are masked.
Parameters
----------
layer : LayerInfo
the layer to instrument the compression operation
config : dict
layer's pruning config
base_mask : torch.Tensor
The basic mask with the same shape of weight, all item in the basic mask is 1.
weight : torch.Tensor
Layer's weight
num_prune : int
Num of filters to prune
Returns
-------
torch.Tensor
mask of the layer's weight
Mask of the layer's weight
"""
weight = layer.module.weight.data
op_name = layer.name
op_type = layer.type
assert op_type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
if op_name in self.mask_calculated_ops:
assert op_name in self.mask_dict
return self.mask_dict.get(op_name)
mask = torch.ones(weight.size()).type_as(weight)
try:
w_abs = weight.abs()
mask = torch.gt(w_abs, self.global_threshold).type_as(weight)
finally:
self.mask_dict.update({layer.name: mask})
self.mask_calculated_ops.add(layer.name)
filters = weight.shape[0]
w = weight.view(filters, -1)
w_l2_norm = torch.sqrt((w ** 2).sum(dim=1))
threshold = torch.topk(w_l2_norm.view(-1), num_prune, largest=False)[0].max()
mask = torch.gt(w_l2_norm, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
return mask
class FPGMPruner(RankFilterPruner):
"""
A filter pruner via geometric median.
"Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration",
https://arxiv.org/pdf/1811.00250.pdf
"""
def __init__(self, model, config_list):
"""
Parameters
----------
model : pytorch model
the model user wants to compress
config_list: list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
"""
super().__init__(model, config_list)
def _get_mask(self, base_mask, weight, num_prune):
"""
Calculate the mask of given layer.
Filters with the smallest sum of its absolute kernel weights are masked.
Parameters
----------
base_mask : torch.Tensor
The basic mask with the same shape of weight, all item in the basic mask is 1.
weight : torch.Tensor
Layer's weight
num_prune : int
Num of filters to prune
Returns
-------
torch.Tensor
Mask of the layer's weight
"""
min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune)
for idx in min_gm_idx:
base_mask[idx] = 0.
return base_mask
def _get_min_gm_kernel_idx(self, weight, n):
assert len(weight.size()) in [3, 4]
dist_list = []
for out_i in range(weight.size(0)):
dist_sum = self._get_distance_sum(weight, out_i)
dist_list.append((dist_sum, out_i))
min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n]
return [x[1] for x in min_gm_kernels]
def _get_distance_sum(self, weight, out_idx):
"""
Calculate the total distance between a specified filter (by out_idex and in_idx) and
all other filters.
Optimized verision of following naive implementation:
def _get_distance_sum(self, weight, in_idx, out_idx):
w = weight.view(-1, weight.size(-2), weight.size(-1))
dist_sum = 0.
for k in w:
dist_sum += torch.dist(k, weight[in_idx, out_idx], p=2)
return dist_sum
Parameters
----------
weight: Tensor
convolutional filter weight
out_idx: int
output channel index of specified filter, this method calculates the total distance
between this specified filter and all other filters.
Returns
-------
float32
The total distance
"""
logger.debug('weight size: %s', weight.size())
assert len(weight.size()) in [3, 4], 'unsupported weight shape'
w = weight.view(weight.size(0), -1)
anchor_w = w[out_idx].unsqueeze(0).expand(w.size(0), w.size(1))
x = w - anchor_w
x = (x * x).sum(-1)
x = torch.sqrt(x)
return x.sum()
def update_epoch(self, epoch):
self.mask_calculated_ops = set()
......@@ -58,8 +58,9 @@ def tf2(func):
return test_tf2_func
# for fpgm filter pruner test
w = np.array([[[[i+1]*3]*3]*5 for i in range(10)])
w = np.array([[[[i + 1] * 3] * 3] * 5 for i in range(10)])
class CompressorTestCase(TestCase):
......@@ -69,19 +70,19 @@ class CompressorTestCase(TestCase):
config_list = [{
'quant_types': ['weight'],
'quant_bits': 8,
'op_types':['Conv2d', 'Linear']
'op_types': ['Conv2d', 'Linear']
}, {
'quant_types': ['output'],
'quant_bits': 8,
'quant_start_step': 0,
'op_types':['ReLU']
'op_types': ['ReLU']
}]
model.relu = torch.nn.ReLU()
quantizer = torch_compressor.QAT_Quantizer(model, config_list)
quantizer.compress()
modules_to_compress = quantizer.get_modules_to_compress()
modules_to_compress_name = [ t[0].name for t in modules_to_compress]
modules_to_compress_name = [t[0].name for t in modules_to_compress]
assert "conv1" in modules_to_compress_name
assert "conv2" in modules_to_compress_name
assert "fc1" in modules_to_compress_name
......@@ -179,7 +180,8 @@ class CompressorTestCase(TestCase):
w = np.array([np.zeros((3, 3, 3)), np.ones((3, 3, 3)), np.ones((3, 3, 3)) * 2,
np.ones((3, 3, 3)) * 3, np.ones((3, 3, 3)) * 4])
model = TorchModel()
config_list = [{'sparsity': 0.2, 'op_names': ['conv1']}, {'sparsity': 0.6, 'op_names': ['conv2']}]
config_list = [{'sparsity': 0.2, 'op_types': ['Conv2d'], 'op_names': ['conv1']},
{'sparsity': 0.6, 'op_types': ['Conv2d'], 'op_names': ['conv2']}]
pruner = torch_compressor.L1FilterPruner(model, config_list)
model.conv1.weight.data = torch.tensor(w).float()
......@@ -236,12 +238,12 @@ class CompressorTestCase(TestCase):
config_list = [{
'quant_types': ['weight'],
'quant_bits': 8,
'op_types':['Conv2d', 'Linear']
'op_types': ['Conv2d', 'Linear']
}, {
'quant_types': ['output'],
'quant_bits': 8,
'quant_start_step': 0,
'op_types':['ReLU']
'op_types': ['ReLU']
}]
model.relu = torch.nn.ReLU()
quantizer = torch_compressor.QAT_Quantizer(model, config_list)
......@@ -271,5 +273,6 @@ class CompressorTestCase(TestCase):
assert math.isclose(model.relu.tracked_min_biased, 0.002, abs_tol=eps)
assert math.isclose(model.relu.tracked_max_biased, 0.00998, abs_tol=eps)
if __name__ == '__main__':
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment