Unverified Commit c45c30b6 authored by Ningxin Zheng, committed by GitHub

bug bash for sensitivity_pruner (#2815)



* fix bug

* update

* Remove the unnecessary debug info.
Signed-off-by: Ningxin <Ningxin.Zheng@microsoft.com>

* fix pylint error

* log the final sparsity config.

* fix pylint

* Fix typo and remove todo.
Signed-off-by: Ningxin <Ningxin.Zheng@microsoft.com>
parent b168b016
@@ -53,7 +53,7 @@ class MobileNet(nn.Module):
     def forward(self, x):
         x = self.conv1(x)
         x = self.features(x)
-        x = x.mean(3).mean(2)  # global average pooling
+        x = x.mean([2, 3])  # global average pooling
         x = self.classifier(x)
         return x
...
@@ -108,7 +108,10 @@ class MobileNetV2(nn.Module):
     def forward(self, x):
         x = self.features(x)
-        x = x.mean(3).mean(2)
+        # it's the same as .mean(3).mean(2), but
+        # speedup only supports the mean operation
+        # whose output has only two dimensions
+        x = x.mean([2, 3])
         x = self.classifier(x)
         return x
...
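For reference, the two pooling forms are numerically identical, so this rewrite only changes how the graph looks to the speedup tooling, not the model's output. A quick sanity check (assumes PyTorch is installed):

```python
import torch

x = torch.randn(8, 32, 7, 7)   # (batch, channels, H, W)
old = x.mean(3).mean(2)        # chained reductions: (8, 32, 7) -> (8, 32)
new = x.mean([2, 3])           # single reduction over H and W: (8, 32)
assert old.shape == new.shape == (8, 32)
assert torch.allclose(old, new)
```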
@@ -17,7 +17,7 @@ from ..utils.sensitivity_analysis import SensitivityAnalysis
 MAX_PRUNE_RATIO_PER_ITER = 0.95
 _logger = logging.getLogger('Sensitivity_Pruner')
 _logger.setLevel(logging.INFO)

 class SensitivityPruner(Pruner):
     """
@@ -202,10 +202,10 @@ class SensitivityPruner(Pruner):
             prune_ratios = sorted(sensitivities[layer].keys())
             last_ratio = 0
             for ratio in prune_ratios:
-                last_ratio = ratio
                 cur_acc = sensitivities[layer][ratio]
                 if cur_acc + threshold < ori_acc:
                     break
+                last_ratio = ratio
             max_ratio[layer] = last_ratio
         return max_ratio
@@ -244,6 +244,7 @@ class SensitivityPruner(Pruner):
         # MAX_PRUNE_RATIO_PER_ITER we rescale all prune
         # ratios under this threshold
+        if _Max > MAX_PRUNE_RATIO_PER_ITER:
             for layername in ratios:
                 ratios[layername] = ratios[layername] * \
                     MAX_PRUNE_RATIO_PER_ITER / _Max
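The guard matters because the rescaling is a pure proportional scale by `MAX_PRUNE_RATIO_PER_ITER / _Max`; applied when `_Max` is already below the cap, it would scale every ratio *up* and prune more than the sensitivity analysis recommended. A small sketch with hypothetical layer ratios:

```python
MAX_PRUNE_RATIO_PER_ITER = 0.95

ratios = {'conv1': 0.99, 'conv2': 0.50}   # made-up per-layer prune ratios
_Max = max(ratios.values())
if _Max > MAX_PRUNE_RATIO_PER_ITER:
    for layername in ratios:
        ratios[layername] = ratios[layername] * MAX_PRUNE_RATIO_PER_ITER / _Max

print(ratios)  # conv1 capped at 0.95, conv2 scaled down to ~0.48
```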
@@ -317,6 +318,7 @@ class SensitivityPruner(Pruner):
             finetune_kwargs = {}
         if self.ori_acc is None:
             self.ori_acc = self.evaluator(*eval_args, **eval_kwargs)
+        assert isinstance(self.ori_acc, float) or isinstance(self.ori_acc, int)
         if not resume_sensitivity:
             self.sensitivities = self.analyzer.analysis(
                 val_args=eval_args, val_kwargs=eval_kwargs)
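The added assert pins down the evaluator contract: it must return a single numeric metric. A hypothetical evaluator satisfying that contract (the name and signature below are illustrative, not part of the NNI API):

```python
import torch

def evaluate_top1(model, val_loader):
    # Computes top-1 accuracy on a validation loader and returns it as
    # a plain float, so the new assert in compress() passes.
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            preds = model(inputs).argmax(dim=1)
            correct += (preds == targets).sum().item()
            total += targets.size(0)
    return correct / total
```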
@@ -330,6 +332,7 @@ class SensitivityPruner(Pruner):
         iteration_count = 0
         if self.checkpoint_dir is not None:
             os.makedirs(self.checkpoint_dir, exist_ok=True)
+        modules_wrapper_final = None
         while cur_ratio > target_ratio:
             iteration_count += 1
             # Each round has three steps:
@@ -343,9 +346,16 @@ class SensitivityPruner(Pruner):
             # layers according to the sensitivity result
             proportion = self.sparsity_proportion_calc(
                 ori_acc, self.acc_drop_threshold, self.sensitivities)
             new_pruneratio = self.normalize(proportion, self.sparsity_per_iter)
             cfg_list = self.create_cfg(new_pruneratio)
+            if not cfg_list:
+                _logger.error('The threshold is too small, please set a larger threshold')
+                return self.model
             _logger.debug('Pruner Config: %s', str(cfg_list))
+            cfg_str = ['%s:%.3f'%(cfg['op_names'][0], cfg['sparsity']) for cfg in cfg_list]
+            _logger.info('Current Sparsities: %s', ','.join(cfg_str))
             pruner = self.Pruner(self.model, cfg_list)
             pruner.compress()
             pruned_acc = self.evaluator(*eval_args, **eval_kwargs)
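For context, `create_cfg` produces a list of per-layer config dicts, and the new logging collapses it into one readable line. An illustration with made-up layer names:

```python
cfg_list = [
    {'sparsity': 0.312, 'op_names': ['conv1'], 'op_types': ['Conv2d']},
    {'sparsity': 0.104, 'op_names': ['conv2'], 'op_types': ['Conv2d']},
]
cfg_str = ['%s:%.3f'%(cfg['op_names'][0], cfg['sparsity']) for cfg in cfg_list]
print(','.join(cfg_str))  # conv1:0.312,conv2:0.104
```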
@@ -367,6 +377,7 @@ class SensitivityPruner(Pruner):
                 self.analyzer.already_pruned[name] = sparsity
             # update the cur_ratio
             cur_ratio = 1 - self.current_sparsity()
+            modules_wrapper_final = pruner.get_modules_wrapper()
             del pruner
             _logger.info('Currently remaining weights: %f', cur_ratio)
@@ -383,14 +394,19 @@ class SensitivityPruner(Pruner):
                 with open(cfg_path, 'w') as jf:
                     json.dump(cfg_list, jf)
                 self.analyzer.export(sensitivity_path)
+            if cur_ratio > target_ratio:
+                # If this is the last prune iteration, skip the time-consuming
+                # sensitivity analysis
                 self.analyzer.load_state_dict(self.model.state_dict())
                 self.sensitivities = self.analyzer.analysis(
                     val_args=eval_args, val_kwargs=eval_kwargs)
         _logger.info('After Pruning: %.2f weights remain', cur_ratio)
+        self.modules_wrapper = modules_wrapper_final
+        self._wrap_model()
         return self.model
     def calc_mask(self, wrapper, **kwargs):
...
@@ -163,7 +163,7 @@ class SensitivityAnalysis:
         if val_kwargs is None:
             val_kwargs = {}
-        # Get the original validation metric(accuracy/loss) before pruning
         if self.ori_metric is None:
+            # Get the accuracy baseline before starting the analysis.
            self.ori_metric = self.val_func(*val_args, **val_kwargs)
         namelist = list(self.target_layer.keys())
         if specified_layers is not None:
@@ -172,19 +172,21 @@
         for name in namelist:
             self.sensitivities[name] = {}
             for sparsity in self.sparsities:
-                # here the sparsity is the relative sparsity of the
-                # the remained weights
+                # Calculate the actual prune ratio based on the already pruned ratio
-                sparsity = (
+                real_sparsity = (
                     1.0 - self.already_pruned[name]) * sparsity + self.already_pruned[name]
-                # TODO In current L1/L2 Filter Pruner, the 'op_types' is still necessary
-                # I think the L1/L2 Pruner should specify the op_types automaticlly
-                # according to the op_names
-                cfg = [{'sparsity': sparsity, 'op_names': [
+                cfg = [{'sparsity': real_sparsity, 'op_names': [
                     name], 'op_types': ['Conv2d']}]
                 pruner = self.Pruner(self.model, cfg)
                 pruner.compress()
                 val_metric = self.val_func(*val_args, **val_kwargs)
                 logger.info('Layer: %s Sparsity: %.2f Validation Metric: %.4f',
-                            name, sparsity, val_metric)
+                            name, real_sparsity, val_metric)
                 self.sensitivities[name][sparsity] = val_metric
                 pruner._unwrap_model()
...
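The rename to `real_sparsity` makes the conversion explicit: the requested sparsity is relative to the weights that survived earlier iterations, so it is mapped to an absolute ratio before being handed to the pruner. A worked example with hypothetical numbers:

```python
# If 40% of a layer's weights are already pruned and the analysis asks
# for a further 50% of what remains, the absolute sparsity is 0.7.
already_pruned = 0.4
sparsity = 0.5                      # relative to the remaining weights
real_sparsity = (1.0 - already_pruned) * sparsity + already_pruned
print(real_sparsity)                # 0.7
```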