Unverified Commit b9a7a95d authored by SparkSnail, committed by GitHub

Merge pull request #223 from microsoft/master

merge master
parents f9ee589c 0c7f22fb
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.nas.pytorch.spos import SPOSEvolution
from network import ShuffleNetV2OneShot
class EvolutionWithFlops(SPOSEvolution):
"""
    This tuner extends the evolution tuner by limiting the FLOPs of the candidates it generates.
    It needs a function that computes the FLOPs of a candidate.
"""
def __init__(self, flops_limit=330E6, **kwargs):
super().__init__(**kwargs)
self.model = ShuffleNetV2OneShot()
self.flops_limit = flops_limit
def _is_legal(self, cand):
if not super()._is_legal(cand):
return False
if self.model.get_candidate_flops(cand) > self.flops_limit:
return False
return True
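
# Editor's sketch of how this tuner is wired up (hypothetical, not part of the
# commit): NNI instantiates the tuner from the experiment config, and
# _is_legal() then rejects both duplicate candidates (handled by the base
# class) and candidates whose estimated FLOPs exceed flops_limit, so the
# evolved population stays within the 330M-FLOPs budget by default.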
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn as nn
class CrossEntropyLabelSmooth(nn.Module):
def __init__(self, num_classes, epsilon):
super(CrossEntropyLabelSmooth, self).__init__()
self.num_classes = num_classes
self.epsilon = epsilon
self.logsoftmax = nn.LogSoftmax(dim=1)
def forward(self, inputs, targets):
log_probs = self.logsoftmax(inputs)
targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
loss = (-targets * log_probs).mean(0).sum()
return loss
def accuracy(output, target, topk=(1, 5)):
""" Computes the precision@k for the specified values of k """
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
# one-hot case
if target.ndimension() > 1:
target = target.max(1)[1]
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = dict()
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
return res
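
# Illustrative usage (editor's sketch, not part of the commit): with
# epsilon=0.1 and 10 classes, the smoothed target is 0.9 + 0.1/10 = 0.91 for
# the true class and 0.01 elsewhere, so each target row still sums to 1.
if __name__ == "__main__":
    criterion = CrossEntropyLabelSmooth(num_classes=10, epsilon=0.1)
    logits = torch.randn(8, 10)
    labels = torch.randint(0, 10, (8,))
    print("loss:", criterion(logits, labels).item())  # scalar smoothed loss
    print("topk:", accuracy(logits, labels))  # dict like {'acc1': ..., 'acc5': ...}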
@@ -5,7 +5,8 @@ import logging
 import torch
 
 from .compressor import Pruner
 
-__all__ = ['LevelPruner', 'AGP_Pruner', 'SlimPruner', 'L1FilterPruner', 'L2FilterPruner', 'FPGMPruner']
+__all__ = ['LevelPruner', 'AGP_Pruner', 'SlimPruner', 'L1FilterPruner', 'L2FilterPruner', 'FPGMPruner',
+           'ActivationAPoZRankFilterPruner', 'ActivationMeanRankFilterPruner']
 
 logger = logging.getLogger('torch pruner')
@@ -26,7 +27,7 @@ class LevelPruner(Pruner):
         """
         super().__init__(model, config_list)
-        self.if_init_list = {}
+        self.mask_calculated_ops = set()
 
     def calc_mask(self, layer, config):
         """
@@ -39,22 +40,24 @@ class LevelPruner(Pruner):
             layer's pruning config
 
         Returns
         -------
-        torch.Tensor
-            mask of the layer's weight
+        dict
+            dictionary for storing masks
         """
         weight = layer.module.weight.data
         op_name = layer.name
-        if self.if_init_list.get(op_name, True):
+        if op_name not in self.mask_calculated_ops:
             w_abs = weight.abs()
             k = int(weight.numel() * config['sparsity'])
             if k == 0:
                 return torch.ones(weight.shape).type_as(weight)
             threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
-            mask = torch.gt(w_abs, threshold).type_as(weight)
+            mask_weight = torch.gt(w_abs, threshold).type_as(weight)
+            mask = {'weight': mask_weight}
             self.mask_dict.update({op_name: mask})
-            self.if_init_list.update({op_name: False})
+            self.mask_calculated_ops.add(op_name)
         else:
+            assert op_name in self.mask_dict, "op_name not in the mask_dict"
             mask = self.mask_dict[op_name]
         return mask
@@ -94,8 +97,8 @@ class AGP_Pruner(Pruner):
             layer's pruning config
 
         Returns
         -------
-        torch.Tensor
-            mask of the layer's weight
+        dict
+            dictionary for storing masks
         """
         weight = layer.module.weight.data
@@ -104,7 +107,7 @@ class AGP_Pruner(Pruner):
         freq = config.get('frequency', 1)
         if self.now_epoch >= start_epoch and self.if_init_list.get(op_name, True) \
                 and (self.now_epoch - start_epoch) % freq == 0:
-            mask = self.mask_dict.get(op_name, torch.ones(weight.shape).type_as(weight))
+            mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
             target_sparsity = self.compute_target_sparsity(config)
             k = int(weight.numel() * target_sparsity)
             if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
@@ -112,11 +115,11 @@ class AGP_Pruner(Pruner):
             # if we want to generate new mask, we should update weight first
             w_abs = weight.abs() * mask
             threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
-            new_mask = torch.gt(w_abs, threshold).type_as(weight)
+            new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
             self.mask_dict.update({op_name: new_mask})
             self.if_init_list.update({op_name: False})
         else:
-            new_mask = self.mask_dict.get(op_name, torch.ones(weight.shape).type_as(weight))
+            new_mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
         return new_mask
 
     def compute_target_sparsity(self, config):
@@ -208,8 +211,8 @@ class SlimPruner(Pruner):
             layer's pruning config
 
         Returns
         -------
-        torch.Tensor
-            mask of the layer's weight
+        dict
+            dictionary for storing masks
         """
         weight = layer.module.weight.data
@@ -219,10 +222,17 @@ class SlimPruner(Pruner):
         if op_name in self.mask_calculated_ops:
             assert op_name in self.mask_dict
             return self.mask_dict.get(op_name)
-        mask = torch.ones(weight.size()).type_as(weight)
+        base_mask = torch.ones(weight.size()).type_as(weight).detach()
+        mask = {'weight': base_mask.detach(), 'bias': base_mask.clone().detach()}
         try:
+            filters = weight.size(0)
+            num_prune = int(filters * config.get('sparsity'))
+            if filters < 2 or num_prune < 1:
+                return mask
             w_abs = weight.abs()
-            mask = torch.gt(w_abs, self.global_threshold).type_as(weight)
+            mask_weight = torch.gt(w_abs, self.global_threshold).type_as(weight)
+            mask_bias = mask_weight.clone()
+            mask = {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
         finally:
             self.mask_dict.update({layer.name: mask})
             self.mask_calculated_ops.add(layer.name)
@@ -230,7 +240,7 @@ class SlimPruner(Pruner):
         return mask
 
 
-class RankFilterPruner(Pruner):
+class WeightRankFilterPruner(Pruner):
     """
     A structured pruning base class that prunes the filters with the smallest
     importance criterion in convolution layers to achieve a preset level of network sparsity.
@@ -248,10 +258,10 @@ class WeightRankFilterPruner(Pruner):
         """
         super().__init__(model, config_list)
-        self.mask_calculated_ops = set()
+        self.mask_calculated_ops = set()  # operations whose mask has been calculated
 
     def _get_mask(self, base_mask, weight, num_prune):
-        return torch.ones(weight.size()).type_as(weight)
+        return {'weight': None, 'bias': None}
 
     def calc_mask(self, layer, config):
         """
@@ -265,20 +275,25 @@ class WeightRankFilterPruner(Pruner):
             layer's pruning config
 
         Returns
         -------
-        torch.Tensor
-            mask of the layer's weight
+        dict
+            dictionary for storing masks
         """
         weight = layer.module.weight.data
         op_name = layer.name
         op_type = layer.type
-        assert 0 <= config.get('sparsity') < 1
-        assert op_type in ['Conv1d', 'Conv2d']
+        assert 0 <= config.get('sparsity') < 1, "sparsity must be in the range [0, 1)"
+        assert op_type in ['Conv1d', 'Conv2d'], "only support Conv1d and Conv2d"
         assert op_type in config.get('op_types')
         if op_name in self.mask_calculated_ops:
             assert op_name in self.mask_dict
             return self.mask_dict.get(op_name)
-        mask = torch.ones(weight.size()).type_as(weight)
+        mask_weight = torch.ones(weight.size()).type_as(weight).detach()
+        if hasattr(layer.module, 'bias') and layer.module.bias is not None:
+            mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
+        else:
+            mask_bias = None
+        mask = {'weight': mask_weight, 'bias': mask_bias}
         try:
             filters = weight.size(0)
             num_prune = int(filters * config.get('sparsity'))
@@ -288,10 +303,10 @@ class WeightRankFilterPruner(Pruner):
         finally:
             self.mask_dict.update({op_name: mask})
             self.mask_calculated_ops.add(op_name)
-        return mask.detach()
+        return mask
 
 
-class L1FilterPruner(RankFilterPruner):
+class L1FilterPruner(WeightRankFilterPruner):
     """
     A structured pruning algorithm that prunes the filters of smallest magnitude
     weights sum in the convolution layers to achieve a preset level of network sparsity.
@@ -319,31 +334,33 @@ class L1FilterPruner(WeightRankFilterPruner):
         Filters with the smallest sum of its absolute kernel weights are masked.
 
         Parameters
         ----------
-        base_mask : torch.Tensor
-            The basic mask with the same shape of weight, all item in the basic mask is 1.
+        base_mask : dict
+            The basic mask with the same shape of weight or bias, all item in the basic mask is 1.
         weight : torch.Tensor
             Layer's weight
         num_prune : int
             Num of filters to prune
 
         Returns
         -------
-        torch.Tensor
-            Mask of the layer's weight
+        dict
+            dictionary for storing masks
         """
         filters = weight.shape[0]
         w_abs = weight.abs()
         w_abs_structured = w_abs.view(filters, -1).sum(dim=1)
         threshold = torch.topk(w_abs_structured.view(-1), num_prune, largest=False)[0].max()
-        mask = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
+        mask_weight = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
+        mask_bias = torch.gt(w_abs_structured, threshold).type_as(weight)
 
-        return mask
+        return {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
 
 
-class L2FilterPruner(RankFilterPruner):
+class L2FilterPruner(WeightRankFilterPruner):
     """
     A structured pruning algorithm that prunes the filters with the
-    smallest L2 norm of the absolute kernel weights are masked.
+    smallest L2 norm of the weights.
     """
 
     def __init__(self, model, config_list):
@@ -365,27 +382,28 @@ class L2FilterPruner(WeightRankFilterPruner):
         Filters with the smallest L2 norm of the absolute kernel weights are masked.
 
         Parameters
         ----------
-        base_mask : torch.Tensor
-            The basic mask with the same shape of weight, all item in the basic mask is 1.
+        base_mask : dict
+            The basic mask with the same shape of weight or bias, all item in the basic mask is 1.
         weight : torch.Tensor
             Layer's weight
         num_prune : int
             Num of filters to prune
 
         Returns
         -------
-        torch.Tensor
-            Mask of the layer's weight
+        dict
+            dictionary for storing masks
         """
         filters = weight.shape[0]
         w = weight.view(filters, -1)
         w_l2_norm = torch.sqrt((w ** 2).sum(dim=1))
         threshold = torch.topk(w_l2_norm.view(-1), num_prune, largest=False)[0].max()
-        mask = torch.gt(w_l2_norm, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
+        mask_weight = torch.gt(w_l2_norm, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
+        mask_bias = torch.gt(w_l2_norm, threshold).type_as(weight)
 
-        return mask
+        return {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
 
 
-class FPGMPruner(RankFilterPruner):
+class FPGMPruner(WeightRankFilterPruner):
     """
     A filter pruner via geometric median.
     "Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration",
@@ -410,20 +428,22 @@ class FPGMPruner(WeightRankFilterPruner):
         Filters with the smallest sum of its absolute kernel weights are masked.
 
         Parameters
         ----------
-        base_mask : torch.Tensor
-            The basic mask with the same shape of weight, all item in the basic mask is 1.
+        base_mask : dict
+            The basic mask with the same shape of weight and bias, all item in the basic mask is 1.
         weight : torch.Tensor
             Layer's weight
         num_prune : int
             Num of filters to prune
 
         Returns
         -------
-        torch.Tensor
-            Mask of the layer's weight
+        dict
+            dictionary for storing masks
         """
         min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune)
         for idx in min_gm_idx:
-            base_mask[idx] = 0.
+            base_mask['weight'][idx] = 0.
+            if base_mask['bias'] is not None:
+                base_mask['bias'][idx] = 0.
         return base_mask
 
     def _get_min_gm_kernel_idx(self, weight, n):
@@ -471,3 +491,251 @@ class FPGMPruner(WeightRankFilterPruner):
     def update_epoch(self, epoch):
         self.mask_calculated_ops = set()
class ActivationRankFilterPruner(Pruner):
"""
A structured pruning base class that prunes the filters with the smallest
importance criterion in convolution layers to achieve a preset level of network sparsity.
Hengyuan Hu, Rui Peng, Yu-Wing Tai and Chi-Keung Tang,
"Network Trimming: A Data-Driven Neuron Pruning Approach towards Efficient Deep Architectures", ICLR 2016.
https://arxiv.org/abs/1607.03250
Pavlo Molchanov, Stephen Tyree, Tero Karras, Timo Aila and Jan Kautz,
"Pruning Convolutional Neural Networks for Resource Efficient Inference", ICLR 2017.
https://arxiv.org/abs/1611.06440
"""
def __init__(self, model, config_list, activation='relu', statistics_batch_num=1):
"""
Parameters
----------
        model : torch.nn.Module
Model to be pruned
config_list : list
            supported keys for each list item:
- sparsity: percentage of convolutional filters to be pruned.
activation : str
Activation function
statistics_batch_num : int
Num of batches for activation statistics
"""
super().__init__(model, config_list)
self.mask_calculated_ops = set()
self.statistics_batch_num = statistics_batch_num
self.collected_activation = {}
self.hooks = {}
assert activation in ['relu', 'relu6']
if activation == 'relu':
self.activation = torch.nn.functional.relu
elif activation == 'relu6':
self.activation = torch.nn.functional.relu6
else:
self.activation = None
def compress(self):
"""
Compress the model, register a hook for collecting activations.
"""
modules_to_compress = self.detect_modules_to_compress()
for layer, config in modules_to_compress:
self._instrument_layer(layer, config)
self.collected_activation[layer.name] = []
def _hook(module_, input_, output, name=layer.name):
if len(self.collected_activation[name]) < self.statistics_batch_num:
self.collected_activation[name].append(self.activation(output.detach().cpu()))
layer.module.register_forward_hook(_hook)
return self.bound_model
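
    # Editor's note: the hook registered in compress() collects at most
    # `statistics_batch_num` batches of (detached, CPU) activations per layer;
    # calc_mask() below keeps returning an all-ones mask until that many
    # batches have been seen, and only then computes and caches the real mask.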
def _get_mask(self, base_mask, activations, num_prune):
return {'weight': None, 'bias': None}
def calc_mask(self, layer, config):
"""
Calculate the mask of given layer.
Filters with the smallest importance criterion which is calculated from the activation are masked.
Parameters
----------
layer : LayerInfo
the layer to instrument the compression operation
config : dict
layer's pruning config
Returns
-------
dict
dictionary for storing masks
"""
weight = layer.module.weight.data
op_name = layer.name
op_type = layer.type
        assert 0 <= config.get('sparsity') < 1, "sparsity must be in the range [0, 1)"
assert op_type in ['Conv2d'], "only support Conv2d"
assert op_type in config.get('op_types')
if op_name in self.mask_calculated_ops:
assert op_name in self.mask_dict
return self.mask_dict.get(op_name)
mask_weight = torch.ones(weight.size()).type_as(weight).detach()
if hasattr(layer.module, 'bias') and layer.module.bias is not None:
mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
else:
mask_bias = None
mask = {'weight': mask_weight, 'bias': mask_bias}
try:
filters = weight.size(0)
num_prune = int(filters * config.get('sparsity'))
if filters < 2 or num_prune < 1 or len(self.collected_activation[layer.name]) < self.statistics_batch_num:
return mask
mask = self._get_mask(mask, self.collected_activation[layer.name], num_prune)
finally:
if len(self.collected_activation[layer.name]) == self.statistics_batch_num:
self.mask_dict.update({op_name: mask})
self.mask_calculated_ops.add(op_name)
return mask
class ActivationAPoZRankFilterPruner(ActivationRankFilterPruner):
"""
A structured pruning algorithm that prunes the filters with the
smallest APoZ(average percentage of zeros) of output activations.
Hengyuan Hu, Rui Peng, Yu-Wing Tai and Chi-Keung Tang,
"Network Trimming: A Data-Driven Neuron Pruning Approach towards Efficient Deep Architectures", ICLR 2016.
https://arxiv.org/abs/1607.03250
"""
def __init__(self, model, config_list, activation='relu', statistics_batch_num=1):
"""
Parameters
----------
        model : torch.nn.Module
Model to be pruned
config_list : list
            supported keys for each list item:
- sparsity: percentage of convolutional filters to be pruned.
activation : str
Activation function
statistics_batch_num : int
Num of batches for activation statistics
"""
super().__init__(model, config_list, activation, statistics_batch_num)
def _get_mask(self, base_mask, activations, num_prune):
"""
Calculate the mask of given layer.
Filters with the smallest APoZ(average percentage of zeros) of output activations are masked.
Parameters
----------
base_mask : dict
The basic mask with the same shape of weight, all item in the basic mask is 1.
activations : list
Layer's output activations
num_prune : int
Num of filters to prune
Returns
-------
dict
dictionary for storing masks
"""
apoz = self._calc_apoz(activations)
prune_indices = torch.argsort(apoz, descending=True)[:num_prune]
for idx in prune_indices:
base_mask['weight'][idx] = 0.
if base_mask['bias'] is not None:
base_mask['bias'][idx] = 0.
return base_mask
def _calc_apoz(self, activations):
"""
Calculate APoZ(average percentage of zeros) of activations.
Parameters
----------
activations : list
Layer's output activations
Returns
-------
torch.Tensor
Filter's APoZ(average percentage of zeros) of the activations
"""
activations = torch.cat(activations, 0)
_eq_zero = torch.eq(activations, torch.zeros_like(activations))
_apoz = torch.sum(_eq_zero, dim=(0, 2, 3)) / torch.numel(_eq_zero[:, 0, :, :])
return _apoz
class ActivationMeanRankFilterPruner(ActivationRankFilterPruner):
"""
A structured pruning algorithm that prunes the filters with the
smallest mean value of output activations.
Pavlo Molchanov, Stephen Tyree, Tero Karras, Timo Aila and Jan Kautz,
"Pruning Convolutional Neural Networks for Resource Efficient Inference", ICLR 2017.
https://arxiv.org/abs/1611.06440
"""
def __init__(self, model, config_list, activation='relu', statistics_batch_num=1):
"""
Parameters
----------
        model : torch.nn.Module
Model to be pruned
config_list : list
            supported keys for each list item:
- sparsity: percentage of convolutional filters to be pruned.
activation : str
Activation function
statistics_batch_num : int
Num of batches for activation statistics
"""
super().__init__(model, config_list, activation, statistics_batch_num)
def _get_mask(self, base_mask, activations, num_prune):
"""
Calculate the mask of given layer.
        Filters with the smallest mean value of output activations are masked.
Parameters
----------
base_mask : dict
The basic mask with the same shape of weight, all item in the basic mask is 1.
activations : list
Layer's output activations
num_prune : int
Num of filters to prune
Returns
-------
dict
dictionary for storing masks
"""
mean_activation = self._cal_mean_activation(activations)
prune_indices = torch.argsort(mean_activation)[:num_prune]
for idx in prune_indices:
base_mask['weight'][idx] = 0.
if base_mask['bias'] is not None:
base_mask['bias'][idx] = 0.
return base_mask
def _cal_mean_activation(self, activations):
"""
Calculate mean value of activations.
Parameters
----------
activations : list
Layer's output activations
Returns
-------
torch.Tensor
Filter's mean value of the output activations
"""
activations = torch.cat(activations, 0)
mean_activation = torch.mean(activations, dim=(0, 2, 3))
return mean_activation
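
# Editor's sketch (not part of the commit): APoZ on a toy activation batch.
# For an (N, C, H, W) activation, APoZ of a filter is the fraction of zero
# entries over the batch and spatial dimensions; ActivationAPoZRankFilterPruner
# prunes the filters with the highest APoZ first.
if __name__ == "__main__":
    collected = [torch.relu(torch.randn(2, 4, 3, 3))]  # one collected batch
    acts = torch.cat(collected, 0)
    eq_zero = torch.eq(acts, torch.zeros_like(acts))
    apoz = torch.sum(eq_zero, dim=(0, 2, 3)).float() / torch.numel(eq_zero[:, 0, :, :])
    print(apoz)  # one value per filter, around 0.5 for ReLU of N(0, 1) noise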
@@ -3,7 +3,7 @@
 import logging
 
 import torch
 
-from .compressor import Quantizer
+from .compressor import Quantizer, QuantGrad, QuantType
 
 __all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer']
@@ -240,4 +240,34 @@ class DoReFaQuantizer(Quantizer):
     def quantize(self, input_ri, q_bits):
         scale = pow(2, q_bits)-1
         output = torch.round(input_ri*scale)/scale
-        return output
\ No newline at end of file
+        return output
class ClipGrad(QuantGrad):
@staticmethod
def quant_backward(tensor, grad_output, quant_type):
if quant_type == QuantType.QUANT_OUTPUT:
grad_output[torch.abs(tensor) > 1] = 0
return grad_output
class BNNQuantizer(Quantizer):
"""Binarized Neural Networks, as defined in:
Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1
(https://arxiv.org/abs/1602.02830)
"""
def __init__(self, model, config_list):
super().__init__(model, config_list)
self.quant_grad = ClipGrad
def quantize_weight(self, weight, config, **kwargs):
out = torch.sign(weight)
# remove zeros
out[out == 0] = 1
return out
def quantize_output(self, output, config, **kwargs):
out = torch.sign(output)
# remove zeros
out[out == 0] = 1
return out
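
# Editor's sketch (not part of the commit): the binarization above maps every
# value to +1 or -1, pushing exact zeros to +1 so no weight or activation is
# silently zeroed out.
if __name__ == "__main__":
    x = torch.tensor([-0.7, 0.0, 0.3])
    out = torch.sign(x)
    out[out == 0] = 1
    print(out)  # tensor([-1., 1., 1.])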
@@ -16,6 +16,7 @@ class LayerInfo:
         self._forward = None
 
+
 class Compressor:
     """
     Abstract base PyTorch compressor
@@ -193,10 +194,16 @@ class Pruner(Compressor):
         layer._forward = layer.module.forward
 
         def new_forward(*inputs):
+            mask = self.calc_mask(layer, config)
             # apply mask to weight
             old_weight = layer.module.weight.data
-            mask = self.calc_mask(layer, config)
-            layer.module.weight.data = old_weight.mul(mask)
+            mask_weight = mask['weight']
+            layer.module.weight.data = old_weight.mul(mask_weight)
+            # apply mask to bias
+            if mask.__contains__('bias') and hasattr(layer.module, 'bias') and layer.module.bias is not None:
+                old_bias = layer.module.bias.data
+                mask_bias = mask['bias']
+                layer.module.bias.data = old_bias.mul(mask_bias)
             # calculate forward
             ret = layer._forward(*inputs)
             return ret
@@ -224,12 +231,14 @@ class Pruner(Compressor):
         for name, m in self.bound_model.named_modules():
             if name == "":
                 continue
-            mask = self.mask_dict.get(name)
-            if mask is not None:
-                mask_sum = mask.sum().item()
-                mask_num = mask.numel()
+            masks = self.mask_dict.get(name)
+            if masks is not None:
+                mask_sum = masks['weight'].sum().item()
+                mask_num = masks['weight'].numel()
                 _logger.info('Layer: %s Sparsity: %.2f', name, 1 - mask_sum / mask_num)
-                m.weight.data = m.weight.data.mul(mask)
+                m.weight.data = m.weight.data.mul(masks['weight'])
+                if masks.__contains__('bias') and hasattr(m, 'bias') and m.bias is not None:
+                    m.bias.data = m.bias.data.mul(masks['bias'])
             else:
                 _logger.info('Layer: %s NOT compressed', name)
         torch.save(self.bound_model.state_dict(), model_path)
@@ -258,7 +267,6 @@ class Quantizer(Compressor):
         """
         quantize should overload this method to quantize weight.
         This method is effectively hooked to :meth:`forward` of the model.
-
         Parameters
         ----------
         weight : Tensor
@@ -272,7 +280,6 @@ class Quantizer(Compressor):
         """
         quantize should overload this method to quantize output.
         This method is effectively hooked to :meth:`forward` of the model.
-
         Parameters
         ----------
         output : Tensor
@@ -286,7 +293,6 @@ class Quantizer(Compressor):
         """
         quantize should overload this method to quantize input.
         This method is effectively hooked to :meth:`forward` of the model.
-
         Parameters
         ----------
         inputs : Tensor
@@ -300,7 +306,6 @@ class Quantizer(Compressor):
     def _instrument_layer(self, layer, config):
         """
         Create a wrapper forward function to replace the original one.
-
         Parameters
         ----------
         layer : LayerInfo
@@ -365,7 +370,6 @@ class QuantGrad(torch.autograd.Function):
         """
         This method should be overridden by a subclass to provide a customized backward function;
         the default implementation is the Straight-Through Estimator.
-
         Parameters
         ----------
         tensor : Tensor
@@ -375,7 +379,6 @@ class QuantGrad(torch.autograd.Function):
         quant_type : QuantType
             the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, `QuantType.QUANT_OUTPUT`,
             you can define different behavior for different types.
-
         Returns
         -------
         tensor
@@ -399,3 +402,4 @@ def _check_weight(module):
         return isinstance(module.weight.data, torch.Tensor)
     except AttributeError:
-        return False
\ No newline at end of file
+        return False
@@ -17,6 +17,7 @@ class LotteryTicketPruner(Pruner):
     4. Reset the remaining parameters to their values in theta_0, creating the winning ticket f(x;m*theta_0).
     5. Repeat step 2, 3, and 4.
     """
+
     def __init__(self, model, config_list, optimizer, lr_scheduler=None, reset_weights=True):
         """
         Parameters
@@ -55,7 +56,8 @@ class LotteryTicketPruner(Pruner):
             assert 'prune_iterations' in config, 'prune_iterations must exist in your config'
             assert 'sparsity' in config, 'sparsity must exist in your config'
             if prune_iterations is not None:
-                assert prune_iterations == config['prune_iterations'], 'The values of prune_iterations must be equal in your config'
+                assert prune_iterations == config[
+                    'prune_iterations'], 'The values of prune_iterations must be equal in your config'
             prune_iterations = config['prune_iterations']
         return prune_iterations
@@ -67,8 +69,8 @@ class LotteryTicketPruner(Pruner):
         if print_mask:
             print('mask: ', mask)
         # calculate current sparsity
-        mask_num = mask.sum().item()
-        mask_size = mask.numel()
+        mask_num = mask['weight'].sum().item()
+        mask_size = mask['weight'].numel()
         print('sparsity: ', 1 - mask_num / mask_size)
         torch.set_printoptions(profile='default')
@@ -84,11 +86,11 @@ class LotteryTicketPruner(Pruner):
         curr_sparsity = self._calc_sparsity(sparsity)
         assert self.mask_dict.get(op_name) is not None
         curr_mask = self.mask_dict.get(op_name)
-        w_abs = weight.abs() * curr_mask
+        w_abs = weight.abs() * curr_mask['weight']
         k = int(w_abs.numel() * curr_sparsity)
         threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
         mask = torch.gt(w_abs, threshold).type_as(weight)
-        return mask
+        return {'weight': mask}
 
     def calc_mask(self, layer, config):
         """
...
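
# Editor's sketch (hedged; this wiring follows NNI's documented
# LotteryTicketPruner usage as best recalled, and is not part of this commit):
#
#     pruner = LotteryTicketPruner(model, config_list, optimizer)
#     pruner.compress()
#     for _ in pruner.get_prune_iterations():
#         pruner.prune_iteration_start()
#         # train for some epochs, then evaluate; masks are now dicts with a
#         # 'weight' entry, as returned by calc_mask above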
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .evolution import SPOSEvolution
from .mutator import SPOSSupernetTrainingMutator
from .trainer import SPOSSupernetTrainer
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
import os
import re
from collections import deque
import numpy as np
from nni.tuner import Tuner
from nni.nas.pytorch.classic_nas.mutator import LAYER_CHOICE, INPUT_CHOICE
_logger = logging.getLogger(__name__)
class SPOSEvolution(Tuner):
def __init__(self, max_epochs=20, num_select=10, num_population=50, m_prob=0.1,
num_crossover=25, num_mutation=25):
"""
Initialize SPOS Evolution Tuner.
Parameters
----------
max_epochs : int
Maximum number of epochs to run.
num_select : int
Number of survival candidates of each epoch.
num_population : int
Number of candidates at the start of each epoch. If candidates generated by
crossover and mutation are not enough, the rest will be filled with random
candidates.
m_prob : float
The probability of mutation.
num_crossover : int
Number of candidates generated by crossover in each epoch.
num_mutation : int
Number of candidates generated by mutation in each epoch.
"""
assert num_population >= num_select
self.max_epochs = max_epochs
self.num_select = num_select
self.num_population = num_population
self.m_prob = m_prob
self.num_crossover = num_crossover
self.num_mutation = num_mutation
self.epoch = 0
self.candidates = []
self.search_space = None
self.random_state = np.random.RandomState(0)
# async status
self._to_evaluate_queue = deque()
self._sending_parameter_queue = deque()
self._pending_result_ids = set()
self._reward_dict = dict()
self._id2candidate = dict()
self._st_callback = None
def update_search_space(self, search_space):
"""
Handle the initialization/update event of search space.
"""
self._search_space = search_space
self._next_round()
def _next_round(self):
_logger.info("Epoch %d, generating...", self.epoch)
if self.epoch == 0:
self._get_random_population()
self.export_results(self.candidates)
else:
best_candidates = self._select_top_candidates()
self.export_results(best_candidates)
if self.epoch >= self.max_epochs:
return
self.candidates = self._get_mutation(best_candidates) + self._get_crossover(best_candidates)
self._get_random_population()
self.epoch += 1
def _random_candidate(self):
chosen_arch = dict()
for key, val in self._search_space.items():
if val["_type"] == LAYER_CHOICE:
choices = val["_value"]
index = self.random_state.randint(len(choices))
chosen_arch[key] = {"_value": choices[index], "_idx": index}
elif val["_type"] == INPUT_CHOICE:
raise NotImplementedError("Input choice is not implemented yet.")
return chosen_arch
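
    # Illustrative shape of a sampled architecture (hypothetical key and op
    # names, for exposition only):
    #     {"LayerChoice1": {"_value": "conv_3x3", "_idx": 0},
    #      "LayerChoice2": {"_value": "conv_5x5", "_idx": 1}}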
def _add_to_evaluate_queue(self, cand):
_logger.info("Generate candidate %s, adding to eval queue.", self._get_architecture_repr(cand))
self._reward_dict[self._hashcode(cand)] = 0.
self._to_evaluate_queue.append(cand)
def _get_random_population(self):
while len(self.candidates) < self.num_population:
cand = self._random_candidate()
if self._is_legal(cand):
_logger.info("Random candidate generated.")
self._add_to_evaluate_queue(cand)
self.candidates.append(cand)
def _get_crossover(self, best):
result = []
for _ in range(10 * self.num_crossover):
cand_p1 = best[self.random_state.randint(len(best))]
cand_p2 = best[self.random_state.randint(len(best))]
assert cand_p1.keys() == cand_p2.keys()
cand = {k: cand_p1[k] if self.random_state.randint(2) == 0 else cand_p2[k]
for k in cand_p1.keys()}
if self._is_legal(cand):
result.append(cand)
self._add_to_evaluate_queue(cand)
if len(result) >= self.num_crossover:
break
_logger.info("Found %d architectures with crossover.", len(result))
return result
def _get_mutation(self, best):
result = []
for _ in range(10 * self.num_mutation):
cand = best[self.random_state.randint(len(best))].copy()
mutation_sample = np.random.random_sample(len(cand))
for s, k in zip(mutation_sample, cand):
if s < self.m_prob:
choices = self._search_space[k]["_value"]
index = self.random_state.randint(len(choices))
cand[k] = {"_value": choices[index], "_idx": index}
if self._is_legal(cand):
result.append(cand)
self._add_to_evaluate_queue(cand)
if len(result) >= self.num_mutation:
break
_logger.info("Found %d architectures with mutation.", len(result))
return result
def _get_architecture_repr(self, cand):
return re.sub(r"\".*?\": \{\"_idx\": (\d+), \"_value\": \".*?\"\}", r"\1",
self._hashcode(cand))
def _is_legal(self, cand):
if self._hashcode(cand) in self._reward_dict:
return False
return True
def _select_top_candidates(self):
reward_query = lambda cand: self._reward_dict[self._hashcode(cand)]
_logger.info("All candidate rewards: %s", list(map(reward_query, self.candidates)))
result = sorted(self.candidates, key=reward_query, reverse=True)[:self.num_select]
_logger.info("Best candidate rewards: %s", list(map(reward_query, result)))
return result
@staticmethod
def _hashcode(d):
return json.dumps(d, sort_keys=True)
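
    # Editor's note: sort_keys makes the JSON canonical, so equal architectures
    # always map to the same dedup key in _reward_dict, e.g. (illustrative):
    #     json.dumps({"b": 1, "a": 2}, sort_keys=True) == '{"a": 2, "b": 1}'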
def _bind_and_send_parameters(self):
"""
        There are two types of resources: parameter ids and candidates. This function is called
        whenever a resource of either type becomes available, to bind them together and send new trials with st_callback.
"""
result = []
while self._sending_parameter_queue and self._to_evaluate_queue:
parameter_id = self._sending_parameter_queue.popleft()
parameters = self._to_evaluate_queue.popleft()
self._id2candidate[parameter_id] = parameters
result.append(parameters)
self._pending_result_ids.add(parameter_id)
self._st_callback(parameter_id, parameters)
_logger.info("Send parameter [%d] %s.", parameter_id, self._get_architecture_repr(parameters))
return result
def generate_multiple_parameters(self, parameter_id_list, **kwargs):
"""
Callback function necessary to implement a tuner. This will put more parameter ids into the
parameter id queue.
"""
if "st_callback" in kwargs and self._st_callback is None:
self._st_callback = kwargs["st_callback"]
for parameter_id in parameter_id_list:
self._sending_parameter_queue.append(parameter_id)
self._bind_and_send_parameters()
        return []  # always return nothing here; parameters are sent via st_callback, and returning them too might cause over-sending
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
"""
Callback function. Receive a trial result.
"""
_logger.info("Candidate %d, reported reward %f", parameter_id, value)
self._reward_dict[self._hashcode(self._id2candidate[parameter_id])] = value
def trial_end(self, parameter_id, success, **kwargs):
"""
Callback function when a trial is ended and resource is released.
"""
self._pending_result_ids.remove(parameter_id)
if not self._pending_result_ids and not self._to_evaluate_queue:
# a new epoch now
self._next_round()
assert self._st_callback is not None
self._bind_and_send_parameters()
def export_results(self, result):
"""
Export a number of candidates to `checkpoints` dir.
Parameters
----------
        result : list of dict
"""
os.makedirs("checkpoints", exist_ok=True)
for i, cand in enumerate(result):
converted = dict()
for cand_key, cand_val in cand.items():
onehot = [k == cand_val["_idx"] for k in range(len(self._search_space[cand_key]["_value"]))]
converted[cand_key] = onehot
with open(os.path.join("checkpoints", "%03d_%03d.json" % (self.epoch, i)), "w") as fp:
json.dump(converted, fp)
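
# Editor's note (illustrative): export_results writes one JSON file per
# candidate, encoding each choice as a one-hot list; for a key with three
# choices and "_idx" == 1, the file contains {"LayerChoice1": [false, true, false]}
# (hypothetical key name). Epoch 0, candidate 0 lands in checkpoints/000_000.json.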
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import numpy as np
from nni.nas.pytorch.random import RandomMutator
_logger = logging.getLogger(__name__)
class SPOSSupernetTrainingMutator(RandomMutator):
def __init__(self, model, flops_func=None, flops_lb=None, flops_ub=None,
flops_bin_num=7, flops_sample_timeout=500):
"""
Parameters
----------
model : nn.Module
flops_func : callable
            Callable that takes a candidate from `sample_search` and returns its FLOPs. When `flops_func`
is None, functions related to flops will be deactivated.
flops_lb : number
Lower bound of flops.
flops_ub : number
Upper bound of flops.
flops_bin_num : number
            Number of bins the interval of flops is divided into to ensure uniformity. A bigger number
            gives a more uniform distribution, but makes sampling slower.
flops_sample_timeout : int
            Maximum number of attempts to sample before giving up and using a random candidate.
"""
super().__init__(model)
self._flops_func = flops_func
if self._flops_func is not None:
self._flops_bin_num = flops_bin_num
self._flops_bins = [flops_lb + (flops_ub - flops_lb) / flops_bin_num * i for i in range(flops_bin_num + 1)]
self._flops_sample_timeout = flops_sample_timeout
def sample_search(self):
"""
Sample a candidate for training. When `flops_func` is not None, candidates will be sampled uniformly
relative to flops.
Returns
-------
dict
"""
if self._flops_func is not None:
for times in range(self._flops_sample_timeout):
idx = np.random.randint(self._flops_bin_num)
cand = super().sample_search()
if self._flops_bins[idx] <= self._flops_func(cand) <= self._flops_bins[idx + 1]:
_logger.debug("Sampled candidate flops %f in %d times.", cand, times)
return cand
_logger.warning("Failed to sample a flops-valid candidate within %d tries.", self._flops_sample_timeout)
return super().sample_search()
def sample_final(self):
"""
        Implemented only to satisfy the interface of Mutator.
"""
return self.sample_search()
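
# Editor's sketch (assumed bounds, not part of the commit): with
# flops_lb=290e6, flops_ub=360e6 and flops_bin_num=7, the bin edges are
# 290e6, 300e6, ..., 360e6. sample_search() picks a bin uniformly at random
# and rejection-samples until a candidate's FLOPs land inside it, spreading
# supernet training evenly across the FLOPs range.
if __name__ == "__main__":
    flops_lb, flops_ub, flops_bin_num = 290e6, 360e6, 7
    bins = [flops_lb + (flops_ub - flops_lb) / flops_bin_num * i
            for i in range(flops_bin_num + 1)]
    print(bins)  # [290000000.0, 300000000.0, ..., 360000000.0]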
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
from nni.nas.pytorch.trainer import Trainer
from nni.nas.pytorch.utils import AverageMeterGroup
from .mutator import SPOSSupernetTrainingMutator
logger = logging.getLogger(__name__)
class SPOSSupernetTrainer(Trainer):
"""
This trainer trains a supernet that can be used for evolution search.
"""
def __init__(self, model, loss, metrics,
optimizer, num_epochs, train_loader, valid_loader,
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
callbacks=None):
assert torch.cuda.is_available()
super().__init__(model, mutator if mutator is not None else SPOSSupernetTrainingMutator(model),
loss, metrics, optimizer, num_epochs, None, None,
batch_size, workers, device, log_frequency, callbacks)
self.train_loader = train_loader
self.valid_loader = valid_loader
def train_one_epoch(self, epoch):
self.model.train()
meters = AverageMeterGroup()
for step, (x, y) in enumerate(self.train_loader):
self.optimizer.zero_grad()
self.mutator.reset()
logits = self.model(x)
loss = self.loss(logits, y)
loss.backward()
self.optimizer.step()
metrics = self.metrics(logits, y)
metrics["loss"] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.train_loader), meters)
def validate_one_epoch(self, epoch):
self.model.eval()
meters = AverageMeterGroup()
with torch.no_grad():
for step, (x, y) in enumerate(self.valid_loader):
self.mutator.reset()
logits = self.model(x)
loss = self.loss(logits, y)
metrics = self.metrics(logits, y)
metrics["loss"] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("Epoch [%s/%s] Validation Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.valid_loader), meters)
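
# Editor's sketch (hypothetical wiring, not part of the commit): the mutator is
# reset on every step, so each mini-batch trains a different single-path
# sub-network of the supernet.
#
#     trainer = SPOSSupernetTrainer(model, criterion, accuracy, optimizer,
#                                   num_epochs=120, train_loader=train_loader,
#                                   valid_loader=valid_loader)
#     trainer.train()  # assuming the base Trainer exposes train()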
@@ -136,12 +136,12 @@ class CompressorTestCase(TestCase):
         model.conv2.weight.data = torch.tensor(w).float()
         layer = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
         masks = pruner.calc_mask(layer, config_list[0])
-        assert all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))
+        assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))
 
         pruner.update_epoch(1)
         model.conv2.weight.data = torch.tensor(w).float()
         masks = pruner.calc_mask(layer, config_list[1])
-        assert all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))
+        assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))
 
     @tf2
     def test_tf_fpgm_pruner(self):
@@ -190,8 +190,8 @@ class CompressorTestCase(TestCase):
         mask1 = pruner.calc_mask(layer1, config_list[0])
         layer2 = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
         mask2 = pruner.calc_mask(layer2, config_list[1])
-        assert all(torch.sum(mask1, (1, 2, 3)).numpy() == np.array([0., 27., 27., 27., 27.]))
-        assert all(torch.sum(mask2, (1, 2, 3)).numpy() == np.array([0., 0., 0., 27., 27.]))
+        assert all(torch.sum(mask1['weight'], (1, 2, 3)).numpy() == np.array([0., 27., 27., 27., 27.]))
+        assert all(torch.sum(mask2['weight'], (1, 2, 3)).numpy() == np.array([0., 0., 0., 27., 27.]))
 
     def test_torch_slim_pruner(self):
         """
@@ -218,8 +218,10 @@ class CompressorTestCase(TestCase):
         mask1 = pruner.calc_mask(layer1, config_list[0])
         layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2)
         mask2 = pruner.calc_mask(layer2, config_list[0])
-        assert all(mask1.numpy() == np.array([0., 1., 1., 1., 1.]))
-        assert all(mask2.numpy() == np.array([0., 1., 1., 1., 1.]))
+        assert all(mask1['weight'].numpy() == np.array([0., 1., 1., 1., 1.]))
+        assert all(mask2['weight'].numpy() == np.array([0., 1., 1., 1., 1.]))
+        assert all(mask1['bias'].numpy() == np.array([0., 1., 1., 1., 1.]))
+        assert all(mask2['bias'].numpy() == np.array([0., 1., 1., 1., 1.]))
 
         config_list = [{'sparsity': 0.6, 'op_types': ['BatchNorm2d']}]
         model.bn1.weight.data = torch.tensor(w).float()
@@ -230,8 +232,10 @@ class CompressorTestCase(TestCase):
         mask1 = pruner.calc_mask(layer1, config_list[0])
         layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2)
         mask2 = pruner.calc_mask(layer2, config_list[0])
-        assert all(mask1.numpy() == np.array([0., 0., 0., 1., 1.]))
-        assert all(mask2.numpy() == np.array([0., 0., 0., 1., 1.]))
+        assert all(mask1['weight'].numpy() == np.array([0., 0., 0., 1., 1.]))
+        assert all(mask2['weight'].numpy() == np.array([0., 0., 0., 1., 1.]))
+        assert all(mask1['bias'].numpy() == np.array([0., 0., 0., 1., 1.]))
+        assert all(mask2['bias'].numpy() == np.array([0., 0., 0., 1., 1.]))
 
     def test_torch_QAT_quantizer(self):
         model = TorchModel()
...
@@ -65,7 +65,7 @@ tuner_schema_dict = {
         'builtinTunerName': 'SMAC',
         Optional('classArgs'): {
             'optimize_mode': setChoice('optimize_mode', 'maximize', 'minimize'),
-            'config_dedup': setType('config_dedup', bool)
+            Optional('config_dedup'): setType('config_dedup', bool)
         },
         Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
         Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
...