Unverified commit fd8fb783 authored by J-shang, committed by GitHub

Fix the total sparsity bug in the SimulatedAnnealing and AutoCompress pruners (#4474)

parent 6b8efe3e
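Taken together, the changes below migrate the example script, the two task generators, and the tests from the per-layer `sparsity` key to `total_sparsity`, which these pruners can reallocate across layers during the search. A minimal usage sketch under that convention, assuming the NNI v2.x pruning API; the import path, toy model, and dummy evaluator are illustrative assumptions, not part of this commit:

```python
import torch
from nni.algorithms.compression.v2.pytorch.pruning import SimulatedAnnealingPruner

# Toy model and evaluator, just enough to drive the SA search;
# a real evaluator should return a validation metric worth maximizing.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU(), torch.nn.Conv2d(8, 4, 3))

def evaluator(model: torch.nn.Module) -> float:
    return 0.0  # placeholder metric

# `total_sparsity` is the key this commit standardizes on: the pruner may
# assign different ratios per Conv2d as long as the overall ratio is 0.8.
config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.8}]

pruner = SimulatedAnnealingPruner(model, config_list, evaluator, start_temperature=40, log_dir='./logs')
pruner.compress()
_, pruned_model, masks, _, _ = pruner.get_best_result()
```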
@@ -99,7 +99,7 @@ if __name__ == '__main__':
trainer(model, optimizer, criterion, i)
evaluator(model)
-    config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
+    config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.8}]
# evaluator in 'SimulatedAnnealingPruner' cannot be None.
pruner = SimulatedAnnealingPruner(model, config_list, pruning_algorithm=args.pruning_algo,
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
+import logging
from pathlib import Path
from typing import Dict, List, Callable, Optional
@@ -13,6 +14,8 @@ from .basic_pruner import ADMMPruner
from .iterative_pruner import IterativePruner, SimulatedAnnealingPruner
from .tools import LotteryTicketTaskGenerator
+_logger = logging.getLogger(__name__)
class AutoCompressTaskGenerator(LotteryTicketTaskGenerator):
def __init__(self, total_iteration: int, origin_model: Module, origin_config_list: List[Dict],
@@ -29,6 +32,13 @@ class AutoCompressTaskGenerator(LotteryTicketTaskGenerator):
log_dir=log_dir,
keep_intermediate_result=keep_intermediate_result)
+    def reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
+        # TODO: replace with validation here
+        for config in config_list:
+            if 'sparsity' in config or 'sparsity_per_layer' in config:
+                _logger.warning('Only `total_sparsity` can be allocated adaptively across layers; `sparsity` and `sparsity_per_layer` fix the same pruning ratio for every layer. Make sure this is what you want, otherwise please use `total_sparsity`.')
+        return super().reset(model, config_list, masks)
def _iterative_pruner_reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
self.iterative_pruner.task_generator._log_dir = Path(self._log_dir_root, 'SA')
self.iterative_pruner.reset(model, config_list=config_list, masks=masks)
......
@@ -187,6 +187,11 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
def reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
self.current_temperature = self.start_temperature
+        # TODO: replace with validation here
+        for config in config_list:
+            if 'sparsity' in config or 'sparsity_per_layer' in config:
+                _logger.warning('Only `total_sparsity` can be allocated adaptively across layers; `sparsity` and `sparsity_per_layer` fix the same pruning ratio for every layer. Make sure this is what you want, otherwise please use `total_sparsity`.')
self.weights_numel, self.masked_rate = get_model_weights_numel(model, config_list, masks)
self.target_sparsity_list = config_list_canonical(model, config_list)
self._adjust_target_sparsity()
@@ -281,7 +286,10 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
magnitude = self.current_temperature / self.start_temperature * self.perturbation_magnitude
for config, current_sparsity in zip(self.target_sparsity_list, self._current_sparsity_list):
if len(current_sparsity) == 0:
-                self._temp_config_list.extend(deepcopy(config))
+                sub_temp_config_list = [deepcopy(config) for _ in range(len(config['op_names']))]
+                for temp_config, op_name in zip(sub_temp_config_list, config['op_names']):
+                    temp_config.update({'total_sparsity': 0, 'op_names': [op_name]})
+                self._temp_config_list.extend(sub_temp_config_list)
self._temp_sparsity_list.append([])
continue
while True:
......
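The replaced line in the hunk above is the heart of the fix: calling `list.extend` with a dict iterates over the dict's keys, so the old code appended the strings `'op_names'` and `'total_sparsity'` to the temp config list instead of per-layer config dicts. A standalone sketch contrasting the two behaviors, using a made-up two-layer config for illustration:

```python
from copy import deepcopy

config = {'op_names': ['conv1', 'conv2'], 'total_sparsity': 0.8}

# Old behavior: extending a list with a dict appends the dict's *keys*.
buggy = []
buggy.extend(deepcopy(config))
print(buggy)  # ['op_names', 'total_sparsity']

# Fixed behavior: expand the config into one zero-sparsity entry per op,
# mirroring the new code in this hunk.
fixed = []
sub_temp_config_list = [deepcopy(config) for _ in config['op_names']]
for temp_config, op_name in zip(sub_temp_config_list, config['op_names']):
    temp_config.update({'total_sparsity': 0, 'op_names': [op_name]})
fixed.extend(sub_temp_config_list)
print(fixed)
# [{'op_names': ['conv1'], 'total_sparsity': 0},
#  {'op_names': ['conv2'], 'total_sparsity': 0}]
```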
@@ -98,7 +98,7 @@ class IterativePrunerTestCase(unittest.TestCase):
def test_simulated_annealing_pruner(self):
model = TorchModel()
-        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
+        config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.8}]
pruner = SimulatedAnnealingPruner(model, config_list, evaluator, start_temperature=40, log_dir='../../../logs')
pruner.compress()
_, pruned_model, masks, _, _ = pruner.get_best_result()
@@ -107,7 +107,7 @@ class IterativePrunerTestCase(unittest.TestCase):
def test_auto_compress_pruner(self):
model = TorchModel()
-        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
+        config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.8}]
admm_params = {
'trainer': trainer,
'traced_optimizer': get_optimizer(model),
......
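For completeness, a rough sketch of the AutoCompress flow that the second test exercises, fleshed out end to end. Everything outside the diff is an assumption based on the NNI v2.x pruning API of this era: the constructor signature, the `nni.trace` wrapping of the optimizer, the ADMM schedule knobs, and the `trainer`/`evaluator` callbacks are hypothetical stand-ins.

```python
import nni
import torch
from nni.algorithms.compression.v2.pytorch.pruning import AutoCompressPruner

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU(), torch.nn.Conv2d(8, 4, 3))
config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.8}]

def trainer(model, optimizer, criterion, epoch):
    pass  # hypothetical training loop for one epoch

def evaluator(model) -> float:
    return 0.0  # hypothetical validation metric

# ADMM sub-pruner settings; the optimizer is wrapped with nni.trace so the
# pruner can re-instantiate it on each iteration (assumed idiom).
admm_params = {
    'trainer': trainer,
    'traced_optimizer': nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01),
    'criterion': torch.nn.CrossEntropyLoss(),
    'iterations': 2,        # assumed ADMM schedule knobs
    'training_epochs': 1,
}
# SimulatedAnnealing sub-pruner settings; the evaluator scores candidates.
sa_params = {'evaluator': evaluator}

pruner = AutoCompressPruner(model, config_list, 3, admm_params, sa_params, log_dir='./logs')
pruner.compress()
_, pruned_model, masks, _, _ = pruner.get_best_result()
```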