Unverified Commit fd8fb783 authored by J-shang's avatar J-shang Committed by GitHub
Browse files

fix total sparsity in sa & auto compress pruner bug (#4474)

parent 6b8efe3e
...@@ -99,7 +99,7 @@ if __name__ == '__main__': ...@@ -99,7 +99,7 @@ if __name__ == '__main__':
trainer(model, optimizer, criterion, i) trainer(model, optimizer, criterion, i)
evaluator(model) evaluator(model)
config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}] config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.8}]
# evaluator in 'SimulatedAnnealingPruner' could not be None. # evaluator in 'SimulatedAnnealingPruner' could not be None.
pruner = SimulatedAnnealingPruner(model, config_list, pruning_algorithm=args.pruning_algo, pruner = SimulatedAnnealingPruner(model, config_list, pruning_algorithm=args.pruning_algo,
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
import logging
from pathlib import Path from pathlib import Path
from typing import Dict, List, Callable, Optional from typing import Dict, List, Callable, Optional
...@@ -13,6 +14,8 @@ from .basic_pruner import ADMMPruner ...@@ -13,6 +14,8 @@ from .basic_pruner import ADMMPruner
from .iterative_pruner import IterativePruner, SimulatedAnnealingPruner from .iterative_pruner import IterativePruner, SimulatedAnnealingPruner
from .tools import LotteryTicketTaskGenerator from .tools import LotteryTicketTaskGenerator
_logger = logging.getLogger(__name__)
class AutoCompressTaskGenerator(LotteryTicketTaskGenerator): class AutoCompressTaskGenerator(LotteryTicketTaskGenerator):
def __init__(self, total_iteration: int, origin_model: Module, origin_config_list: List[Dict], def __init__(self, total_iteration: int, origin_model: Module, origin_config_list: List[Dict],
...@@ -29,6 +32,13 @@ class AutoCompressTaskGenerator(LotteryTicketTaskGenerator): ...@@ -29,6 +32,13 @@ class AutoCompressTaskGenerator(LotteryTicketTaskGenerator):
log_dir=log_dir, log_dir=log_dir,
keep_intermediate_result=keep_intermediate_result) keep_intermediate_result=keep_intermediate_result)
def reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
    """Reset the task generator with a new model, config list and masks.

    Before delegating to the parent implementation, warn about any config
    entry that uses a fixed-ratio sparsity key (`sparsity` or
    `sparsity_per_layer`), since only `total_sparsity` lets this generator
    allocate a different sparse ratio to each layer.
    """
    # TODO: replace with validation here
    fixed_ratio_keys = ('sparsity', 'sparsity_per_layer')
    for config in config_list:
        # One warning per offending config entry, matching the original behavior.
        if any(key in config for key in fixed_ratio_keys):
            _logger.warning('Only `total_sparsity` can be differentially allocated sparse ratio to each layer, `sparsity` or `sparsity_per_layer` will allocate fixed sparse ratio to layers. Make sure you know what this will lead to, otherwise please use `total_sparsity`.')
    return super().reset(model, config_list, masks)
def _iterative_pruner_reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}): def _iterative_pruner_reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
self.iterative_pruner.task_generator._log_dir = Path(self._log_dir_root, 'SA') self.iterative_pruner.task_generator._log_dir = Path(self._log_dir_root, 'SA')
self.iterative_pruner.reset(model, config_list=config_list, masks=masks) self.iterative_pruner.reset(model, config_list=config_list, masks=masks)
......
...@@ -187,6 +187,11 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator): ...@@ -187,6 +187,11 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
def reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}): def reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
self.current_temperature = self.start_temperature self.current_temperature = self.start_temperature
# TODO: replace with validation here
for config in config_list:
if 'sparsity' in config or 'sparsity_per_layer' in config:
_logger.warning('Only `total_sparsity` can be differentially allocated sparse ratio to each layer, `sparsity` or `sparsity_per_layer` will allocate fixed sparse ratio to layers. Make sure you know what this will lead to, otherwise please use `total_sparsity`.')
self.weights_numel, self.masked_rate = get_model_weights_numel(model, config_list, masks) self.weights_numel, self.masked_rate = get_model_weights_numel(model, config_list, masks)
self.target_sparsity_list = config_list_canonical(model, config_list) self.target_sparsity_list = config_list_canonical(model, config_list)
self._adjust_target_sparsity() self._adjust_target_sparsity()
...@@ -281,7 +286,10 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator): ...@@ -281,7 +286,10 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
magnitude = self.current_temperature / self.start_temperature * self.perturbation_magnitude magnitude = self.current_temperature / self.start_temperature * self.perturbation_magnitude
for config, current_sparsity in zip(self.target_sparsity_list, self._current_sparsity_list): for config, current_sparsity in zip(self.target_sparsity_list, self._current_sparsity_list):
if len(current_sparsity) == 0: if len(current_sparsity) == 0:
self._temp_config_list.extend(deepcopy(config)) sub_temp_config_list = [deepcopy(config) for i in range(len(config['op_names']))]
for temp_config, op_name in zip(sub_temp_config_list, config['op_names']):
temp_config.update({'total_sparsity': 0, 'op_names': [op_name]})
self._temp_config_list.extend(sub_temp_config_list)
self._temp_sparsity_list.append([]) self._temp_sparsity_list.append([])
continue continue
while True: while True:
......
...@@ -98,7 +98,7 @@ class IterativePrunerTestCase(unittest.TestCase): ...@@ -98,7 +98,7 @@ class IterativePrunerTestCase(unittest.TestCase):
def test_simulated_annealing_pruner(self): def test_simulated_annealing_pruner(self):
model = TorchModel() model = TorchModel()
config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}] config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.8}]
pruner = SimulatedAnnealingPruner(model, config_list, evaluator, start_temperature=40, log_dir='../../../logs') pruner = SimulatedAnnealingPruner(model, config_list, evaluator, start_temperature=40, log_dir='../../../logs')
pruner.compress() pruner.compress()
_, pruned_model, masks, _, _ = pruner.get_best_result() _, pruned_model, masks, _, _ = pruner.get_best_result()
...@@ -107,7 +107,7 @@ class IterativePrunerTestCase(unittest.TestCase): ...@@ -107,7 +107,7 @@ class IterativePrunerTestCase(unittest.TestCase):
def test_auto_compress_pruner(self): def test_auto_compress_pruner(self):
model = TorchModel() model = TorchModel()
config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}] config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.8}]
admm_params = { admm_params = {
'trainer': trainer, 'trainer': trainer,
'traced_optimizer': get_optimizer(model), 'traced_optimizer': get_optimizer(model),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment