"src/include/functional.hpp" did not exist on "5e5c27a63b1637556a17e17546147da6cb6d732e"
Commit c45c30b6 authored by Ningxin Zheng, committed by GitHub

bug bash for sensitivity_pruner (#2815)



* fix bug

* update

* Remove the unnecessary debug info.
Signed-off-by: Ningxin <Ningxin.Zheng@microsoft.com>

* fix pylint error

* log the final sparsity config.

* fix pylint

* Fix typo and remove todo.
Signed-off-by: Ningxin <Ningxin.Zheng@microsoft.com>
parent b168b016
@@ -53,7 +53,7 @@ class MobileNet(nn.Module):
     def forward(self, x):
         x = self.conv1(x)
         x = self.features(x)
-        x = x.mean(3).mean(2)  # global average pooling
+        x = x.mean([2, 3])  # global average pooling
         x = self.classifier(x)
         return x
@@ -108,7 +108,10 @@ class MobileNetV2(nn.Module):
     def forward(self, x):
         x = self.features(x)
-        x = x.mean(3).mean(2)
+        # it's the same as .mean(3).mean(2), but
+        # speedup only supports the mean option
+        # whose output has only two dimensions
+        x = x.mean([2, 3])
         x = self.classifier(x)
         return x
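For reference, `Tensor.mean([2, 3])` computes the same global average pooling as the chained `mean(3).mean(2)`, but in a single op whose output is 2-D, which is what the speedup tool requires. A minimal sketch verifying the equivalence (the tensor shape here is illustrative):

```python
import torch

# NCHW activation map, e.g. a batch of 8 feature maps of size 32x7x7
x = torch.randn(8, 32, 7, 7)

# Old form: two chained reductions.
pooled_chained = x.mean(3).mean(2)

# New form: one reduction over both spatial dims; the output is
# directly (N, C), with no intermediate 3-D tensor.
pooled_single = x.mean([2, 3])

assert pooled_single.shape == (8, 32)
assert torch.allclose(pooled_chained, pooled_single)
```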
@@ -17,7 +17,7 @@ from ..utils.sensitivity_analysis import SensitivityAnalysis
 MAX_PRUNE_RATIO_PER_ITER = 0.95
 _logger = logging.getLogger('Sensitivity_Pruner')
-_logger.setLevel(logging.INFO)

 class SensitivityPruner(Pruner):
     """
@@ -202,10 +202,10 @@ class SensitivityPruner(Pruner):
             prune_ratios = sorted(sensitivities[layer].keys())
             last_ratio = 0
             for ratio in prune_ratios:
-                last_ratio = ratio
                 cur_acc = sensitivities[layer][ratio]
                 if cur_acc + threshold < ori_acc:
                     break
+                last_ratio = ratio
             max_ratio[layer] = last_ratio
         return max_ratio
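This reordering fixes an off-by-one: previously `last_ratio` was updated before the accuracy check, so the ratio that already violated the threshold could be returned. A standalone sketch of the corrected selection logic (a reimplementation for illustration, not the pruner's actual method):

```python
def max_acceptable_ratio(ratio_to_acc, ori_acc, threshold):
    """Return the largest prune ratio whose accuracy stays within
    `threshold` of the original accuracy."""
    last_ratio = 0
    for ratio in sorted(ratio_to_acc):
        if ratio_to_acc[ratio] + threshold < ori_acc:
            break           # this ratio already drops too much
        last_ratio = ratio  # only record ratios that passed the check
    return last_ratio

# With ori_acc=0.90 and threshold=0.05, 0.50 is the last ratio whose
# accuracy (0.86 >= 0.85) is still acceptable; 0.75 is rejected.
assert max_acceptable_ratio({0.25: 0.89, 0.50: 0.86, 0.75: 0.80}, 0.90, 0.05) == 0.50
```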
@@ -244,6 +244,7 @@ class SensitivityPruner(Pruner):
         # MAX_PRUNE_RATIO_PER_ITER we rescale all prune
         # ratios under this threshold
         if _Max > MAX_PRUNE_RATIO_PER_ITER:
             for layername in ratios:
                 ratios[layername] = ratios[layername] * \
                     MAX_PRUNE_RATIO_PER_ITER / _Max
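The rescaling caps the per-iteration prune ratio of every layer at MAX_PRUNE_RATIO_PER_ITER while preserving the relative proportions between layers. A small illustration of the arithmetic (layer names and values are made up):

```python
MAX_PRUNE_RATIO_PER_ITER = 0.95

ratios = {'conv1': 0.50, 'conv2': 0.99}  # hypothetical per-layer ratios
_max = max(ratios.values())

if _max > MAX_PRUNE_RATIO_PER_ITER:
    # Scale every ratio by the same factor, so the largest one lands
    # exactly on the cap and all proportions are preserved.
    ratios = {k: v * MAX_PRUNE_RATIO_PER_ITER / _max for k, v in ratios.items()}

print(ratios)  # {'conv1': 0.4798..., 'conv2': 0.95}
```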
@@ -317,6 +318,7 @@ class SensitivityPruner(Pruner):
             finetune_kwargs = {}
         if self.ori_acc is None:
             self.ori_acc = self.evaluator(*eval_args, **eval_kwargs)
+        assert isinstance(self.ori_acc, float) or isinstance(self.ori_acc, int)
         if not resume_sensitivity:
             self.sensitivities = self.analyzer.analysis(
                 val_args=eval_args, val_kwargs=eval_kwargs)
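The new assertion makes the evaluator contract explicit: it must return a single numeric metric, not a tensor or a dict. A minimal sketch of a compliant evaluator (function and loader names are assumptions for illustration):

```python
import torch

def evaluator(model, val_loader):
    """Return a single scalar accuracy, satisfying the new assertion."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data, target in val_loader:
            pred = model(data).argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    return correct / total  # a plain Python float
```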
@@ -330,6 +332,7 @@ class SensitivityPruner(Pruner):
         iteration_count = 0
         if self.checkpoint_dir is not None:
             os.makedirs(self.checkpoint_dir, exist_ok=True)
+        modules_wrapper_final = None
         while cur_ratio > target_ratio:
             iteration_count += 1
             # Each round has three steps:
@@ -343,9 +346,16 @@ class SensitivityPruner(Pruner):
             # layers according to the sensitivity result
             proportion = self.sparsity_proportion_calc(
                 ori_acc, self.acc_drop_threshold, self.sensitivities)
             new_pruneratio = self.normalize(proportion, self.sparsity_per_iter)
             cfg_list = self.create_cfg(new_pruneratio)
+            if not cfg_list:
+                _logger.error('The threshold is too small, please set a larger threshold')
+                return self.model
             _logger.debug('Pruner Config: %s', str(cfg_list))
+            cfg_str = ['%s:%.3f' % (cfg['op_names'][0], cfg['sparsity']) for cfg in cfg_list]
+            _logger.info('Current Sparsities: %s', ','.join(cfg_str))
             pruner = self.Pruner(self.model, cfg_list)
             pruner.compress()
             pruned_acc = self.evaluator(*eval_args, **eval_kwargs)
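The new INFO line is what the commit message calls "log the final sparsity config": it renders each layer's target sparsity compactly. A quick illustration of the formatting (the config values are made up):

```python
cfg_list = [
    {'op_names': ['conv1'], 'sparsity': 0.25, 'op_types': ['Conv2d']},
    {'op_names': ['conv2'], 'sparsity': 0.4128, 'op_types': ['Conv2d']},
]

cfg_str = ['%s:%.3f' % (cfg['op_names'][0], cfg['sparsity']) for cfg in cfg_list]
print(','.join(cfg_str))  # conv1:0.250,conv2:0.413
```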
@@ -367,6 +377,7 @@ class SensitivityPruner(Pruner):
                 self.analyzer.already_pruned[name] = sparsity
             # update the cur_ratio
             cur_ratio = 1 - self.current_sparsity()
+            modules_wrapper_final = pruner.get_modules_wrapper()
             del pruner
             _logger.info('Currently remained weights: %f', cur_ratio)
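Together with the `modules_wrapper_final = None` initialisation above and the re-wrapping at the end of `compress`, this implements a keep-the-last-snapshot pattern: each iteration's wrappers are saved before the temporary pruner is deleted, and the final set is re-attached once the loop exits. A schematic sketch of the pattern (the `Worker` class is a stand-in, not NNI API):

```python
class Worker:
    """Stand-in for the temporary per-iteration pruner object."""
    def __init__(self, step):
        self.step = step
    def get_state(self):
        return {'step': self.step}

snapshot = None
for step in range(3):
    worker = Worker(step)          # temporary per-iteration object
    snapshot = worker.get_state()  # keep the latest state ...
    del worker                     # ... before discarding the worker

# After the loop, only the final iteration's state survives and
# can be restored, mirroring self.modules_wrapper = modules_wrapper_final.
assert snapshot == {'step': 2}
```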
@@ -383,14 +394,19 @@ class SensitivityPruner(Pruner):
                 with open(cfg_path, 'w') as jf:
                     json.dump(cfg_list, jf)
                 self.analyzer.export(sensitivity_path)
             if cur_ratio > target_ratio:
                 # If this is the last prune iteration, skip the time-consuming
                 # sensitivity analysis
                 self.analyzer.load_state_dict(self.model.state_dict())
                 self.sensitivities = self.analyzer.analysis(
                     val_args=eval_args, val_kwargs=eval_kwargs)
         _logger.info('After Pruning: %.2f weights remains', cur_ratio)
+        self.modules_wrapper = modules_wrapper_final
+        self._wrap_model()
         return self.model

     def calc_mask(self, wrapper, **kwargs):
@@ -163,7 +163,7 @@ class SensitivityAnalysis:
         if val_kwargs is None:
             val_kwargs = {}
         # Get the original validation metric(accuracy/loss) before pruning
-        if self.ori_metric is None:
-            self.ori_metric = self.val_func(*val_args, **val_kwargs)
+        # Get the accuracy baseline before starting the analysis.
+        self.ori_metric = self.val_func(*val_args, **val_kwargs)
         namelist = list(self.target_layer.keys())
         if specified_layers is not None:

@@ -172,19 +172,21 @@ class SensitivityAnalysis:
         for name in namelist:
             self.sensitivities[name] = {}
             for sparsity in self.sparsities:
+                # here the sparsity is the relative sparsity of the
+                # remained weights
                 # Calculate the actual prune ratio based on the already pruned ratio
-                sparsity = (
+                real_sparsity = (
                     1.0 - self.already_pruned[name]) * sparsity + self.already_pruned[name]
                 # TODO In current L1/L2 Filter Pruner, the 'op_types' is still necessary
                 # I think the L1/L2 Pruner should specify the op_types automatically
                 # according to the op_names
-                cfg = [{'sparsity': sparsity, 'op_names': [
+                cfg = [{'sparsity': real_sparsity, 'op_names': [
                     name], 'op_types': ['Conv2d']}]
                 pruner = self.Pruner(self.model, cfg)
                 pruner.compress()
                 val_metric = self.val_func(*val_args, **val_kwargs)
                 logger.info('Layer: %s Sparsity: %.2f Validation Metric: %.4f',
-                            name, sparsity, val_metric)
+                            name, real_sparsity, val_metric)
                 self.sensitivities[name][sparsity] = val_metric
                 pruner._unwrap_model()
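The rename to `real_sparsity` also fixes a subtle bug: the old code overwrote the loop variable `sparsity`, so the sensitivities dict was keyed by the absolute ratio instead of the configured relative one. Each configured sparsity is relative to the weights surviving previous pruning rounds, and the formula converts it to an absolute ratio of the original weight count. A worked example of the conversion (values are illustrative):

```python
already_pruned = 0.40   # 40% of the layer is already gone
for sparsity in (0.25, 0.50, 0.75):
    # Prune `sparsity` of the *remaining* 60% of weights, expressed
    # as an absolute ratio of the original weight count.
    real_sparsity = (1.0 - already_pruned) * sparsity + already_pruned
    print('relative %.2f -> absolute %.2f' % (sparsity, real_sparsity))

# relative 0.25 -> absolute 0.55
# relative 0.50 -> absolute 0.70
# relative 0.75 -> absolute 0.85
```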